added torch 1.7 support, switched to grpc.aio, bumped grpcio to 1.33, improved tests & logging (#116)

* specify force_reinstall to override the circleci cache

* switch from warnings.warn to logger.warning

* check for compression even if compression_type is None

* bump version

* require newer grpcio

* switch to asyncio testing

* use pytest.mark.forked/pytest.mark.asyncio where applicable

* detach before serializing

* set requires_grad manually

* switch from grpc.experimental.aio to grpc.aio
justheuristic 4 years ago
parent
commit
2bd481c73f

+ 4 - 3
.circleci/config.yml

@@ -11,12 +11,13 @@ jobs:
     steps:
       - checkout
       - python/load-cache
-      - run: pip uninstall -y pytest codecov  # temporary override for broken cache
-      - run: pip install codecov pytest tqdm scikit-learn
+      - run: pip uninstall -y pytest codecov
+      # note: uninstall is required because otherwise circleci cache will lose track of pytest/codecov executables
+      - run: pip install pytest pytest-forked pytest-asyncio codecov tqdm scikit-learn
       - python/install-deps
       - python/save-cache
       - run:
-          command: pip install -e .
+          command: pip install --force-reinstall -e .
           name: setup
       - run:
           command: pytest ./tests

+ 2 - 1
docs/user/contributing.md

@@ -28,9 +28,10 @@ cd hivemind
 python setup.py develop
 ```
 
-To run tests, you will also need to `pip install pytest codecov tqdm scikit-learn`.
+To run tests, you will also need to `pip install pytest pytest-forked pytest-asyncio codecov tqdm scikit-learn`.
 You can run all tests with `pytest ./tests` or choose a specific set, e.g. `pytest ./tests/test_dht.py`.
 
+
 To build docs locally,
 1. `pip install sphinx sphinx_rtd_theme recommonmark`
 2. make sure you ran setup.py (see above)

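For reference, a minimal sketch (not part of this diff) of how the two newly required plugins are meant to be combined in the test suite: `pytest-forked` runs a test in a forked child process, and `pytest-asyncio` lets a coroutine be collected and run as a test.

```
import asyncio

import pytest


@pytest.mark.forked    # pytest-forked: run this test in a forked child process
@pytest.mark.asyncio   # pytest-asyncio: run the coroutine on an event loop
async def test_example_sketch():
    await asyncio.sleep(0.01)
    assert True
```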
+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.utils import *
 
-__version__ = '0.8.9'
+__version__ = '0.8.10'

+ 0 - 1
hivemind/client/expert.py

@@ -3,7 +3,6 @@ from functools import lru_cache
 from typing import Tuple, Optional, Any, Dict
 
 import grpc
-import grpc.experimental.aio
 import torch
 import torch.nn as nn
 from torch.autograd.function import once_differentiable

+ 2 - 3
hivemind/dht/__init__.py

@@ -17,7 +17,6 @@ import ctypes
 import heapq
 import multiprocessing as mp
 import re
-import warnings
 from collections import deque
 from concurrent.futures import ThreadPoolExecutor
 from typing import List, Tuple, Optional, Sequence, Union, Dict, Deque, NamedTuple, Iterator, Set
@@ -163,11 +162,11 @@ class DHT(mp.Process):
             raise TimeoutError("Server didn't notify .ready in {timeout} seconds")
 
     def shutdown(self) -> None:
-        """ Shuts down the dht process """
+        """ Shut down a running dht process """
         if self.is_alive():
             self.terminate()
         else:
-            warnings.warn("DHT shutdown has no effect: dht process is already not alive")
+            logger.warning("DHT shutdown has no effect: dht process is already not alive")
 
     @property
     def port(self) -> Optional[int]:

+ 3 - 4
hivemind/dht/node.py

@@ -6,7 +6,6 @@ from collections import defaultdict
 from dataclasses import dataclass, field
 from functools import partial
 from typing import Optional, Tuple, List, Dict, DefaultDict, Collection, Union, Set, Awaitable, Callable, Any
-from warnings import warn
 
 from sortedcontainers import SortedList
 
@@ -36,7 +35,7 @@ class DHTNode:
 
     * ping - request peer's identifier and update routing table (same as Kademlia PING RPC)
     * store - send several (key, value, expiration_time) pairs to the same peer (like Kademlia STORE, but in bulk)
-    * find - request one or several keys, get values & expiration (if peer finds it locally) and :bucket_size: of
+    * find - request one or several keys, get values and expiration (if peer finds it locally) and :bucket_size: of
         nearest peers from recipient's routing table (ordered nearest-to-farthest, not including recipient itself)
         This RPC is a mixture between Kademlia FIND_NODE and FIND_VALUE with multiple keys per call.
 
@@ -146,7 +145,7 @@ class DHTNode:
                 finished_pings |= finished_in_time
 
             if not finished_pings:
-                warn("DHTNode bootstrap failed: none of the initial_peers responded to a ping.")
+                logger.warning("DHTNode bootstrap failed: none of the initial_peers responded to a ping.")
 
             # stage 3: traverse dht to find my own nearest neighbors and populate the routing table
             # ... maybe receive some values that we are meant to store (see protocol.update_routing_table)
@@ -189,7 +188,7 @@ class DHTNode:
         num_workers = num_workers if num_workers is not None else self.num_workers
         beam_size = beam_size if beam_size is not None else max(self.protocol.bucket_size, k_nearest)
         if k_nearest > beam_size:
-            warn("Warning: beam_size is too small, beam search is not guaranteed to find enough nodes")
+            logger.warning("Warning: beam_size is too small, beam search is not guaranteed to find enough nodes")
         if node_to_endpoint is None:
             node_to_endpoint: Dict[DHTID, Endpoint] = dict()
             for query in queries:

+ 10 - 12
hivemind/dht/protocol.py

@@ -3,10 +3,8 @@ from __future__ import annotations
 
 import asyncio
 from typing import Optional, List, Tuple, Dict, Any, Sequence, Union, Collection
-from warnings import warn
 
 import grpc
-import grpc.experimental.aio
 
 from hivemind.dht.routing import RoutingTable, DHTID, BinaryDHTValue, DHTExpiration, Subkey
 from hivemind.dht.storage import DHTLocalStorage, DictionaryDHTValue, ValueWithExpiration
@@ -19,7 +17,7 @@ logger = get_logger(__name__)
 class DHTProtocol(dht_grpc.DHTServicer):
     # fmt:off
     node_id: DHTID; port: int; bucket_size: int; num_replicas: int; wait_timeout: float; node_info: dht_pb2.NodeInfo
-    channel_options: Optional[Sequence[Tuple[str, Any]]]; server: grpc.experimental.aio.Server
+    channel_options: Optional[Sequence[Tuple[str, Any]]]; server: grpc.aio.Server
     storage: DHTLocalStorage; cache: DHTLocalStorage; routing_table: RoutingTable; rpc_semaphore: asyncio.Semaphore
     # fmt:on
 
@@ -51,8 +49,8 @@ class DHTProtocol(dht_grpc.DHTServicer):
         self.rpc_semaphore = asyncio.Semaphore(parallel_rpc if parallel_rpc is not None else float('inf'))
 
         if listen:  # set up server to process incoming rpc requests
-            grpc.experimental.aio.init_grpc_aio()
-            self.server = grpc.experimental.aio.server(**kwargs)
+            grpc.aio.init_grpc_aio()
+            self.server = grpc.aio.server(**kwargs)
             dht_grpc.add_DHTServicer_to_server(self, self.server)
 
             found_port = self.server.add_insecure_port(listen_on)
@@ -64,8 +62,8 @@ class DHTProtocol(dht_grpc.DHTServicer):
             # note: use empty node_info so peers wont add you to their routing tables
             self.node_info, self.server, self.port = dht_pb2.NodeInfo(), None, None
             if listen_on != '0.0.0.0:*' or len(kwargs) != 0:
-                warn(f"DHTProtocol has no server (due to listen=False), listen_on"
-                     f"and kwargs have no effect (unused kwargs: {kwargs})")
+                logger.warning(f"DHTProtocol has no server (due to listen=False), listen_on"
+                               f"and kwargs have no effect (unused kwargs: {kwargs})")
         return self
 
     def __init__(self, *, _initialized_with_create=False):
@@ -78,11 +76,11 @@ class DHTProtocol(dht_grpc.DHTServicer):
         if self.server:
             await self.server.stop(timeout)
         else:
-            warn("DHTProtocol has no server (due to listen=False), it doesn't need to be shut down")
+            logger.warning("DHTProtocol has no server (due to listen=False), it doesn't need to be shut down")
 
     def _get(self, peer: Endpoint) -> dht_grpc.DHTStub:
         """ get a DHTStub that sends requests to a given peer """
-        channel = grpc.experimental.aio.insecure_channel(peer, options=self.channel_options)
+        channel = grpc.aio.insecure_channel(peer, options=self.channel_options)
         return dht_grpc.DHTStub(channel)
 
     async def call_ping(self, peer: Endpoint) -> Optional[DHTID]:
@@ -96,7 +94,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
         try:
             async with self.rpc_semaphore:
                 peer_info = await self._get(peer).rpc_ping(self.node_info, timeout=self.wait_timeout)
-        except grpc.experimental.aio.AioRpcError as error:
+        except grpc.aio.AioRpcError as error:
             logger.warning(f"DHTProtocol failed to ping {peer}: {error.code()}")
             peer_info = None
         responded = bool(peer_info and peer_info.node_id)
@@ -162,7 +160,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
                 peer_id = DHTID.from_bytes(response.peer.node_id)
                 asyncio.create_task(self.update_routing_table(peer_id, peer, responded=True))
             return response.store_ok
-        except grpc.experimental.aio.AioRpcError as error:
+        except grpc.aio.AioRpcError as error:
             logger.warning(f"DHTProtocol failed to store at {peer}: {error.code()}")
             asyncio.create_task(self.update_routing_table(self.routing_table.get(endpoint=peer), peer, responded=False))
             return None
@@ -226,7 +224,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
                     logger.error(f"Unknown result type: {result.type}")
 
             return output
-        except grpc.experimental.aio.AioRpcError as error:
+        except grpc.aio.AioRpcError as error:
             logger.warning(f"DHTProtocol failed to find at {peer}: {error.code()}")
             asyncio.create_task(self.update_routing_table(self.routing_table.get(endpoint=peer), peer, responded=False))
 

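For context, a minimal sketch of the grpc.aio server/channel lifecycle that DHTProtocol and ConnectionHandler now rely on (illustrative only; a real server registers a servicer, e.g. via dht_grpc.add_DHTServicer_to_server, before starting). grpc.aio is the non-experimental home of the asyncio API in recent grpcio releases, hence the requirement bump to 1.33.2 further down in this diff.

```
import asyncio

import grpc


async def main():
    server = grpc.aio.server(options=[('grpc.so_reuseport', 1)])
    port = server.add_insecure_port('127.0.0.1:0')  # port 0: let the OS pick a free port
    await server.start()

    # clients use grpc.aio channels instead of grpc.experimental.aio ones
    async with grpc.aio.insecure_channel(f'127.0.0.1:{port}') as channel:
        await channel.channel_ready()  # wait until the connection is established

    await server.stop(1.0)  # same server.stop(timeout) pattern used in DHTProtocol.shutdown


asyncio.run(main())
```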
+ 3 - 3
hivemind/server/connection_handler.py

@@ -4,7 +4,7 @@ import os
 import pickle
 from typing import Dict
 
-import grpc.experimental.aio
+import grpc
 import torch
 import uvloop
 
@@ -35,9 +35,9 @@ class ConnectionHandler(mp.Process):
         loop = asyncio.new_event_loop()
 
         async def _run():
-            grpc.experimental.aio.init_grpc_aio()
+            grpc.aio.init_grpc_aio()
             logger.debug(f'Starting, pid {os.getpid()}')
-            server = grpc.experimental.aio.server(options=[
+            server = grpc.aio.server(options=[
                 ('grpc.so_reuseport', 1),
                 ('grpc.max_send_message_length', -1),
                 ('grpc.max_receive_message_length', -1)

+ 5 - 6
hivemind/server/task_pool.py

@@ -161,10 +161,8 @@ class TaskPool(TaskPoolBase):
 
             logger.debug(f"{self.uid}, batch  {batch_index}: aggregating inputs")
             # find or create shared arrays for current batch size
-            batch_inputs = [
-                torch.cat([task.args[i] for task in batch_tasks]).share_memory_()
-                for i in range(len(batch_tasks[0].args))
-            ]
+            batch_inputs = [torch.cat([task.args[i] for task in batch_tasks]) for i in range(len(batch_tasks[0].args))]
+            batch_inputs = [inp.detach().requires_grad_(inp.requires_grad).share_memory_() for inp in batch_inputs]
 
             logger.debug(f"{self.uid}, batch {batch_index}: sending to runtime")
             self.batch_sender.send((batch_index, batch_inputs))
@@ -187,7 +185,7 @@ class TaskPool(TaskPoolBase):
             # split batch into partitions for individual tasks
             batch_tasks = pending_batches.pop(batch_index)
             task_sizes = [self.get_task_size(task) for task in batch_tasks]
-            outputs_per_task = zip(*(torch.split_with_sizes(array, task_sizes, dim=0) for array in batch_outputs))
+            outputs_per_task = zip(*(torch.split_with_sizes(tensor, task_sizes, dim=0) for tensor in batch_outputs))
             logger.debug(f"{self.uid}, batch {batch_index}: sending outputs to handlers")
 
             # dispatch results to futures
@@ -209,7 +207,8 @@
 
     def send_outputs_from_runtime(self, batch_index: int, batch_outputs: List[torch.Tensor]):
         """ send results for a processed batch, previously loaded through load_batch_to_runtime """
-        batch_outputs = [tensor.to(device='cpu').share_memory_() for tensor in batch_outputs]
+        batch_outputs = [tensor.to(device='cpu').share_memory_().detach().requires_grad_(tensor.requires_grad)
+                         for tensor in batch_outputs]
         self.outputs_sender.send((batch_index, batch_outputs))
 
     def get_task_size(self, task: Task) -> int:

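The two task_pool changes above ("detach before serializing", "set requires_grad manually") exist because torch.cat returns a non-leaf tensor that is still attached to the autograd graph, and torch refuses to serialize non-leaf tensors with requires_grad=True when they are sent to another process. A standalone sketch of the pattern (not from the repo):

```
import torch

a = torch.randn(2, 3, requires_grad=True)
b = torch.randn(2, 3, requires_grad=True)

cat = torch.cat([a, b])  # non-leaf: still attached to the autograd graph of a and b

# detach() yields a leaf tensor sharing the same storage but no graph; requires_grad_()
# restores the flag so the receiving process still knows gradients are expected
shared = cat.detach().requires_grad_(cat.requires_grad).share_memory_()

assert shared.is_leaf and shared.requires_grad and shared.is_shared()
```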
+ 3 - 1
hivemind/utils/grpc.py

@@ -46,7 +46,7 @@ def serialize_torch_tensor(tensor: torch.Tensor, compression_type=CompressionTyp
             size=tensor.shape,
             dtype='clamped_float32',
             requires_grad=tensor.requires_grad)
-    else:
+    elif compression_type == CompressionType.NONE:
         array = tensor.numpy()
         proto = runtime_pb2.Tensor(
             compression=compression_type,
@@ -54,6 +54,8 @@ def serialize_torch_tensor(tensor: torch.Tensor, compression_type=CompressionTyp
             size=array.shape,
             dtype=array.dtype.name,
             requires_grad=tensor.requires_grad)
+    else:
+        raise ValueError(f"Unknown compression type: {compression_type}")
 
     return proto
 

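A hedged usage sketch of the stricter branching added above (the import path for CompressionType is assumed from hivemind.proto at this revision): serialize_torch_tensor now matches CompressionType.NONE explicitly and raises ValueError for compression types it does not recognize, instead of silently treating them as uncompressed.

```
import torch

from hivemind.utils.grpc import serialize_torch_tensor
from hivemind.proto.runtime_pb2 import CompressionType  # assumed location of the enum

tensor = torch.randn(4, 4, requires_grad=True)
proto = serialize_torch_tensor(tensor, compression_type=CompressionType.NONE)
assert proto.requires_grad and tuple(proto.size) == (4, 4)

try:
    serialize_torch_tensor(tensor, compression_type=-1)  # not a recognized CompressionType value
except ValueError as error:
    print(error)  # "Unknown compression type: -1"
```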
+ 2 - 2
requirements.txt

@@ -5,6 +5,6 @@ prefetch_generator>=1.0.1
 msgpack>=0.5.6
 sortedcontainers
 uvloop>=0.14.0
-grpcio>=1.31
-grpcio-tools>=1.30.0
+grpcio>=1.33.2
+grpcio-tools>=1.33.2
 configargparse>=1.2.3

+ 4 - 3
tests/benchmark_dht.py

@@ -1,13 +1,14 @@
 import argparse
 import random
 import time
-from warnings import warn
 
 from tqdm import trange
 
 import hivemind
 from hivemind.utils.threading import increase_file_limit
 
+logger = hivemind.get_logger(__file__)
+
 
 def random_endpoint() -> hivemind.Endpoint:
     return f"{random.randint(0, 256)}.{random.randint(0, 256)}.{random.randint(0, 256)}." \
@@ -53,7 +54,7 @@ def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_b
     time.sleep(wait_before_read)
 
     if time.perf_counter() - benchmark_started > expiration:
-        warn("Warning: all keys expired before benchmark started getting them. Consider increasing expiration_time")
+        logger.warning("All keys expired before benchmark started getting them. Consider increasing expiration_time")
 
     successful_gets = total_get_time = 0
 
@@ -68,7 +69,7 @@ def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_b
                 successful_gets += 1
 
     if time.perf_counter() - benchmark_started > expiration:
-        warn("Warning: keys expired midway during get requests. If that is not desired, increase expiration_time param")
+        logger.warning("keys expired midway during get requests. If that isn't desired, increase expiration_time param")
 
     print(f"Get success rate: {successful_gets / len(expert_uids) * 100:.1f} ({successful_gets} / {len(expert_uids)})")
     print(f"Mean get time: {total_get_time / len(expert_uids):.5f}, Total: {total_get_time:.5f}")

+ 12 - 15
tests/test_dht_experts.py

@@ -8,6 +8,7 @@ import hivemind
 from hivemind import LOCALHOST, UidEndpoint
 
 
+@pytest.mark.forked
 def test_store_get_experts():
     peers = [hivemind.DHT(start=True)]
     for i in range(10):
@@ -36,6 +37,7 @@ def test_store_get_experts():
         peer.shutdown()
 
 
+@pytest.mark.forked
 def test_beam_search(dht_size=20, total_experts=128, batch_size=32, initial_peers=3, beam_size=4, parallel_rpc=256,
                      grid_dims=(32, 32, 32)):
     dht = []
@@ -68,6 +70,7 @@ def test_beam_search(dht_size=20, total_experts=128, batch_size=32, initial_peer
         assert all(len(experts) == beam_size for experts in batch_experts)
 
 
+@pytest.mark.forked
 def test_dht_single_node():
     node = hivemind.DHT(start=True, expiration=999)
 
@@ -132,8 +135,9 @@ def test_uid_patterns():
         assert not hivemind.is_valid_prefix(pfx), f"Prefix {pfx} is not valid, but was perceived as valid"
 
 
-def test_negative_caching():
-    test_success = mp.Event()
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_negative_caching():
     peers = []
     for i in range(10):
         neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
@@ -148,16 +152,9 @@ def test_negative_caching():
     # get prefixes by the peer with negative caching. Cache "no data" entries for ffn.0.*, ffn.2.*, ffn.4.*, ffn.5.*
     assert len(neg_caching_peer.get_initial_beam(prefix='ffn.', scores=[.1, .2, .3, .4, .5, .6], beam_size=3)) == 2
 
-    async def _tester():
-        node = await hivemind.DHTNode.create(initial_peers=neighbors_i)
-        fetched = await asyncio.gather(*(node.get(f'ffn.{i}.') for i in range(10)))
-        for i in range(6):
-            assert fetched[i] is not None, f"node should have cached ffn.{i}."
-        for i in range(6, len(fetched)):
-            assert fetched[i] is None, f"node shouldn't have cached ffn.{i}."
-        test_success.set()
-
-    proc = mp.Process(target=lambda: asyncio.run(_tester()))
-    proc.start()
-    proc.join()
-    assert test_success.is_set()
+    node = await hivemind.DHTNode.create(initial_peers=neighbors_i)
+    fetched = await asyncio.gather(*(node.get(f'ffn.{i}.') for i in range(10)))
+    for i in range(6):
+        assert fetched[i] is not None, f"node should have cached ffn.{i}."
+    for i in range(6, len(fetched)):
+        assert fetched[i] is None, f"node shouldn't have cached ffn.{i}."

+ 304 - 349
tests/test_dht_node.py

@@ -4,12 +4,13 @@ import random
 import heapq
 from typing import Optional
 import numpy as np
+import pytest
 
 import hivemind
 from typing import List, Dict
 
 from hivemind import get_dht_time
-from hivemind.dht.node import DHTID, Endpoint, DHTNode, LOCALHOST, DHTProtocol
+from hivemind.dht.node import DHTID, Endpoint, DHTNode, LOCALHOST
 from hivemind.dht.protocol import DHTProtocol
 from hivemind.dht.storage import DictionaryDHTValue
 
@@ -29,6 +30,11 @@ def run_protocol_listener(port: int, dhtid: DHTID, started: mp.synchronize.Event
     print(f"Finished peer id={protocol.node_id} port={port}", flush=True)
 
 
+# note: we run grpc-related tests in a separate process to re-initialize all global states from scratch
+# this helps us avoid undesirable side-effects (e.g. segfaults) when running multiple tests in sequence
+
+
+@pytest.mark.forked
 def test_dht_protocol():
     # create the first peer
     peer1_port, peer1_id, peer1_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
@@ -41,121 +47,100 @@ def test_dht_protocol():
                             kwargs={'ping': f'{LOCALHOST}:{peer1_port}'}, daemon=True)
     peer2_proc.start(), peer2_started.wait()
 
-    test_success = mp.Event()
-
-    def _tester():
-        # note: we run everything in a separate process to re-initialize all global states from scratch
-        # this helps us avoid undesirable side-effects when running multiple tests in sequence
-
-        loop = asyncio.get_event_loop()
-        for listen in [False, True]:  # note: order matters, this test assumes that first run uses listen=False
-            protocol = loop.run_until_complete(DHTProtocol.create(
-                DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, listen=listen))
-            print(f"Self id={protocol.node_id}", flush=True)
-
-            assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{peer1_port}')) == peer1_id
-
-            key, value, expiration = DHTID.generate(), [random.random(), {'ololo': 'pyshpysh'}], get_dht_time() + 1e3
-            store_ok = loop.run_until_complete(protocol.call_store(
-                f'{LOCALHOST}:{peer1_port}', [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
-            )
-            assert all(store_ok), "DHT rejected a trivial store"
-
-            # peer 1 must know about peer 2
-            (recv_value_bytes, recv_expiration), nodes_found = loop.run_until_complete(
-                protocol.call_find(f'{LOCALHOST}:{peer1_port}', [key]))[key]
-            recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
-            (recv_id, recv_endpoint) = next(iter(nodes_found.items()))
-            assert recv_id == peer2_id and ':'.join(recv_endpoint.split(':')[-2:]) == f"{LOCALHOST}:{peer2_port}", \
-                f"expected id={peer2_id}, peer={LOCALHOST}:{peer2_port} but got {recv_id}, {recv_endpoint}"
-
-            assert recv_value == value and recv_expiration == expiration, \
-                f"call_find_value expected {value} (expires by {expiration}) " \
-                f"but got {recv_value} (expires by {recv_expiration})"
-
-            # peer 2 must know about peer 1, but not have a *random* nonexistent value
-            dummy_key = DHTID.generate()
-            empty_item, nodes_found_2 = loop.run_until_complete(
-                protocol.call_find(f'{LOCALHOST}:{peer2_port}', [dummy_key]))[dummy_key]
-            assert empty_item is None, "Non-existent keys shouldn't have values"
-            (recv_id, recv_endpoint) = next(iter(nodes_found_2.items()))
-            assert recv_id == peer1_id and recv_endpoint == f"{LOCALHOST}:{peer1_port}", \
-                f"expected id={peer1_id}, peer={LOCALHOST}:{peer1_port} but got {recv_id}, {recv_endpoint}"
-
-            # cause a non-response by querying a nonexistent peer
-            dummy_port = hivemind.find_open_port()
-            assert loop.run_until_complete(protocol.call_find(f"{LOCALHOST}:{dummy_port}", [key])) is None
-
-            # store/get a dictionary with sub-keys
-            nested_key, subkey1, subkey2 = DHTID.generate(), 'foo', 'bar'
-            value1, value2 = [random.random(), {'ololo': 'pyshpysh'}], 'abacaba'
-            assert loop.run_until_complete(protocol.call_store(
-                f'{LOCALHOST}:{peer1_port}', keys=[nested_key], values=[hivemind.MSGPackSerializer.dumps(value1)],
-                expiration_time=[expiration], subkeys=[subkey1])
-            )
-            assert loop.run_until_complete(protocol.call_store(
-                f'{LOCALHOST}:{peer1_port}', keys=[nested_key], values=[hivemind.MSGPackSerializer.dumps(value2)],
-                expiration_time=[expiration + 5], subkeys=[subkey2])
-            )
-            (recv_dict, recv_expiration), nodes_found = loop.run_until_complete(
-                protocol.call_find(f'{LOCALHOST}:{peer1_port}', [nested_key]))[nested_key]
-            assert isinstance(recv_dict, DictionaryDHTValue)
-            assert len(recv_dict.data) == 2 and recv_expiration == expiration + 5
-            assert recv_dict.data[subkey1] == (protocol.serializer.dumps(value1), expiration)
-            assert recv_dict.data[subkey2] == (protocol.serializer.dumps(value2), expiration + 5)
-
-            if listen:
-                loop.run_until_complete(protocol.shutdown())
-            print("DHTProtocol test finished successfully!")
-            test_success.set()
-
-    tester = mp.Process(target=_tester, daemon=True)
-    tester.start()
-    tester.join()
-    assert test_success.is_set()
+    loop = asyncio.get_event_loop()
+    for listen in [False, True]:  # note: order matters, this test assumes that first run uses listen=False
+        protocol = loop.run_until_complete(DHTProtocol.create(
+            DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, listen=listen))
+        print(f"Self id={protocol.node_id}", flush=True)
+
+        assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{peer1_port}')) == peer1_id
+
+        key, value, expiration = DHTID.generate(), [random.random(), {'ololo': 'pyshpysh'}], get_dht_time() + 1e3
+        store_ok = loop.run_until_complete(protocol.call_store(
+            f'{LOCALHOST}:{peer1_port}', [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
+        )
+        assert all(store_ok), "DHT rejected a trivial store"
+
+        # peer 1 must know about peer 2
+        (recv_value_bytes, recv_expiration), nodes_found = loop.run_until_complete(
+            protocol.call_find(f'{LOCALHOST}:{peer1_port}', [key]))[key]
+        recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
+        (recv_id, recv_endpoint) = next(iter(nodes_found.items()))
+        assert recv_id == peer2_id and ':'.join(recv_endpoint.split(':')[-2:]) == f"{LOCALHOST}:{peer2_port}", \
+            f"expected id={peer2_id}, peer={LOCALHOST}:{peer2_port} but got {recv_id}, {recv_endpoint}"
+
+        assert recv_value == value and recv_expiration == expiration, \
+            f"call_find_value expected {value} (expires by {expiration}) " \
+            f"but got {recv_value} (expires by {recv_expiration})"
+
+        # peer 2 must know about peer 1, but not have a *random* nonexistent value
+        dummy_key = DHTID.generate()
+        empty_item, nodes_found_2 = loop.run_until_complete(
+            protocol.call_find(f'{LOCALHOST}:{peer2_port}', [dummy_key]))[dummy_key]
+        assert empty_item is None, "Non-existent keys shouldn't have values"
+        (recv_id, recv_endpoint) = next(iter(nodes_found_2.items()))
+        assert recv_id == peer1_id and recv_endpoint == f"{LOCALHOST}:{peer1_port}", \
+            f"expected id={peer1_id}, peer={LOCALHOST}:{peer1_port} but got {recv_id}, {recv_endpoint}"
+
+        # cause a non-response by querying a nonexistent peer
+        dummy_port = hivemind.find_open_port()
+        assert loop.run_until_complete(protocol.call_find(f"{LOCALHOST}:{dummy_port}", [key])) is None
+
+        # store/get a dictionary with sub-keys
+        nested_key, subkey1, subkey2 = DHTID.generate(), 'foo', 'bar'
+        value1, value2 = [random.random(), {'ololo': 'pyshpysh'}], 'abacaba'
+        assert loop.run_until_complete(protocol.call_store(
+            f'{LOCALHOST}:{peer1_port}', keys=[nested_key], values=[hivemind.MSGPackSerializer.dumps(value1)],
+            expiration_time=[expiration], subkeys=[subkey1])
+        )
+        assert loop.run_until_complete(protocol.call_store(
+            f'{LOCALHOST}:{peer1_port}', keys=[nested_key], values=[hivemind.MSGPackSerializer.dumps(value2)],
+            expiration_time=[expiration + 5], subkeys=[subkey2])
+        )
+        (recv_dict, recv_expiration), nodes_found = loop.run_until_complete(
+            protocol.call_find(f'{LOCALHOST}:{peer1_port}', [nested_key]))[nested_key]
+        assert isinstance(recv_dict, DictionaryDHTValue)
+        assert len(recv_dict.data) == 2 and recv_expiration == expiration + 5
+        assert recv_dict.data[subkey1] == (protocol.serializer.dumps(value1), expiration)
+        assert recv_dict.data[subkey2] == (protocol.serializer.dumps(value2), expiration + 5)
+
+        if listen:
+            loop.run_until_complete(protocol.shutdown())
+        print("DHTProtocol test finished successfully!")
+
     peer1_proc.terminate()
     peer2_proc.terminate()
 
 
+@pytest.mark.forked
 def test_empty_table():
     """ Test RPC methods with empty routing table """
     peer_port, peer_id, peer_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
     peer_proc = mp.Process(target=run_protocol_listener, args=(peer_port, peer_id, peer_started), daemon=True)
     peer_proc.start(), peer_started.wait()
-    test_success = mp.Event()
 
-    def _tester():
-        # note: we run everything in a separate process to re-initialize all global states from scratch
-        # this helps us avoid undesirable side-effects when running multiple tests in sequence
-
-        loop = asyncio.get_event_loop()
-        protocol = loop.run_until_complete(DHTProtocol.create(
-            DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, listen=False))
-
-        key, value, expiration = DHTID.generate(), [random.random(), {'ololo': 'pyshpysh'}], get_dht_time() + 1e3
-
-        empty_item, nodes_found = loop.run_until_complete(
-            protocol.call_find(f'{LOCALHOST}:{peer_port}', [key]))[key]
-        assert empty_item is None and len(nodes_found) == 0
-        assert all(loop.run_until_complete(protocol.call_store(
-            f'{LOCALHOST}:{peer_port}', [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
-        )), "peer rejected store"
-
-        (recv_value_bytes, recv_expiration), nodes_found = loop.run_until_complete(
-            protocol.call_find(f'{LOCALHOST}:{peer_port}', [key]))[key]
-        recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
-        assert len(nodes_found) == 0
-        assert recv_value == value and recv_expiration == expiration, "call_find_value expected " \
-            f"{value} (expires by {expiration}) but got {recv_value} (expires by {recv_expiration})"
-
-        assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{peer_port}')) == peer_id
-        assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{hivemind.find_open_port()}')) is None
-        test_success.set()
-
-    tester = mp.Process(target=_tester, daemon=True)
-    tester.start()
-    tester.join()
-    assert test_success.is_set()
+    loop = asyncio.get_event_loop()
+    protocol = loop.run_until_complete(DHTProtocol.create(
+        DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, listen=False))
+
+    key, value, expiration = DHTID.generate(), [random.random(), {'ololo': 'pyshpysh'}], get_dht_time() + 1e3
+
+    empty_item, nodes_found = loop.run_until_complete(
+        protocol.call_find(f'{LOCALHOST}:{peer_port}', [key]))[key]
+    assert empty_item is None and len(nodes_found) == 0
+    assert all(loop.run_until_complete(protocol.call_store(
+        f'{LOCALHOST}:{peer_port}', [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
+    )), "peer rejected store"
+
+    (recv_value_bytes, recv_expiration), nodes_found = loop.run_until_complete(
+        protocol.call_find(f'{LOCALHOST}:{peer_port}', [key]))[key]
+    recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
+    assert len(nodes_found) == 0
+    assert recv_value == value and recv_expiration == expiration, "call_find_value expected " \
+        f"{value} (expires by {expiration}) but got {recv_value} (expires by {recv_expiration})"
+
+    assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{peer_port}')) == peer_id
+    assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{hivemind.find_open_port()}')) is None
     peer_proc.terminate()
 
 
@@ -170,6 +155,7 @@ def run_node(node_id, peers, status_pipe: mp.Pipe):
         loop.run_forever()
 
 
+@pytest.mark.forked
 def test_dht_node():
     # create dht with 50 nodes + your 51-st node
     dht: Dict[Endpoint, DHTID] = {}
@@ -185,254 +171,223 @@ def test_dht_node():
         processes.append(proc)
         dht[f"{LOCALHOST}:{port}"] = node_id
 
-    test_success = mp.Event()
-
-    def _tester():
-        # note: we run everything in a separate process to re-initialize all global states from scratch
-        # this helps us avoid undesirable side-effects when running multiple tests in sequence
-        loop = asyncio.get_event_loop()
-        me = loop.run_until_complete(DHTNode.create(initial_peers=random.sample(dht.keys(), 5), parallel_rpc=10,
-                                                    cache_refresh_before_expiry=False))
-
-        # test 1: find self
-        nearest = loop.run_until_complete(me.find_nearest_nodes([me.node_id], k_nearest=1))[me.node_id]
-        assert len(nearest) == 1 and ':'.join(nearest[me.node_id].split(':')[-2:]) == f"{LOCALHOST}:{me.port}"
-
-        # test 2: find others
-        for i in range(10):
-            ref_endpoint, query_id = random.choice(list(dht.items()))
-            nearest = loop.run_until_complete(me.find_nearest_nodes([query_id], k_nearest=1))[query_id]
-            assert len(nearest) == 1
-            found_node_id, found_endpoint = next(iter(nearest.items()))
-            assert found_node_id == query_id and ':'.join(found_endpoint.split(':')[-2:]) == ref_endpoint
-
-        # test 3: find neighbors to random nodes
-        accuracy_numerator = accuracy_denominator = 0  # top-1 nearest neighbor accuracy
-        jaccard_numerator = jaccard_denominator = 0  # jaccard similarity aka intersection over union
-        all_node_ids = list(dht.values())
-
-        for i in range(100):
-            query_id = DHTID.generate()
-            k_nearest = random.randint(1, 20)
-            exclude_self = random.random() > 0.5
-            nearest = loop.run_until_complete(
-                me.find_nearest_nodes([query_id], k_nearest=k_nearest, exclude_self=exclude_self))[query_id]
-            nearest_nodes = list(nearest)  # keys from ordered dict
-
-            assert len(nearest_nodes) == k_nearest, "beam search must return exactly k_nearest results"
-            assert me.node_id not in nearest_nodes or not exclude_self, "if exclude, results shouldn't contain self"
-            assert np.all(np.diff(query_id.xor_distance(nearest_nodes)) >= 0), "results must be sorted by distance"
-
-            ref_nearest = heapq.nsmallest(k_nearest + 1, all_node_ids, key=query_id.xor_distance)
-            if exclude_self and me.node_id in ref_nearest:
-                ref_nearest.remove(me.node_id)
-            if len(ref_nearest) > k_nearest:
-                ref_nearest.pop()
-
-            accuracy_numerator += nearest_nodes[0] == ref_nearest[0]
-            accuracy_denominator += 1
-
-            jaccard_numerator += len(set.intersection(set(nearest_nodes), set(ref_nearest)))
-            jaccard_denominator += k_nearest
-
-        accuracy = accuracy_numerator / accuracy_denominator
-        print("Top-1 accuracy:", accuracy)  # should be 98-100%
-        jaccard_index = jaccard_numerator / jaccard_denominator
-        print("Jaccard index (intersection over union):", jaccard_index)  # should be 95-100%
-        assert accuracy >= 0.9, f"Top-1 accuracy only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
-        assert jaccard_index >= 0.9, f"Jaccard index only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
-
-        # test 4: find all nodes
-        dummy = DHTID.generate()
-        nearest = loop.run_until_complete(me.find_nearest_nodes([dummy], k_nearest=len(dht) + 100))[dummy]
-        assert len(nearest) == len(dht) + 1
-        assert len(set.difference(set(nearest.keys()), set(all_node_ids) | {me.node_id})) == 0
-
-        # test 5: node without peers
-        detached_node = loop.run_until_complete(DHTNode.create())
-        nearest = loop.run_until_complete(detached_node.find_nearest_nodes([dummy]))[dummy]
-        assert len(nearest) == 1 and nearest[detached_node.node_id] == f"{LOCALHOST}:{detached_node.port}"
-        nearest = loop.run_until_complete(detached_node.find_nearest_nodes([dummy], exclude_self=True))[dummy]
-        assert len(nearest) == 0
-
-        # test 6 store and get value
-        true_time = get_dht_time() + 1200
-        assert loop.run_until_complete(me.store("mykey", ["Value", 10], true_time))
-        that_guy = loop.run_until_complete(DHTNode.create(initial_peers=random.sample(dht.keys(), 3), parallel_rpc=10,
-                                                          cache_refresh_before_expiry=False, cache_locally=False))
-
-        for node in [me, that_guy]:
-            val, expiration_time = loop.run_until_complete(node.get("mykey"))
-            assert val == ["Value", 10], "Wrong value"
-            assert expiration_time == true_time, f"Wrong time"
-
-        assert loop.run_until_complete(detached_node.get("mykey")) is None
-
-        # test 7: bulk store and bulk get
-        keys = 'foo', 'bar', 'baz', 'zzz'
-        values = 3, 2, 'batman', [1, 2, 3]
-        store_ok = loop.run_until_complete(me.store_many(keys, values, expiration_time=get_dht_time() + 999))
-        assert all(store_ok.values()), "failed to store one or more keys"
-        response = loop.run_until_complete(me.get_many(keys[::-1]))
-        for key, value in zip(keys, values):
-            assert key in response and response[key][0] == value
-
-        # test 8: store dictionaries as values (with sub-keys)
-        upper_key, subkey1, subkey2, subkey3 = 'ololo', 'k1', 'k2', 'k3'
-        now = get_dht_time()
-        assert loop.run_until_complete(me.store(upper_key, subkey=subkey1, value=123, expiration_time=now + 10))
-        assert loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=456, expiration_time=now + 20))
-        for node in [that_guy, me]:
-            value, time = loop.run_until_complete(node.get(upper_key))
-            assert isinstance(value, dict) and time == now + 20
-            assert value[subkey1] == (123, now + 10)
-            assert value[subkey2] == (456, now + 20)
-            assert len(value) == 2
-
-        assert not loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=345, expiration_time=now + 10))
-        assert loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=567, expiration_time=now + 30))
-        assert loop.run_until_complete(me.store(upper_key, subkey=subkey3, value=890, expiration_time=now + 50))
-        loop.run_until_complete(asyncio.sleep(0.1))  # wait for cache to refresh
-
-        for node in [that_guy, me]:
-            value, time = loop.run_until_complete(node.get(upper_key))
-            assert isinstance(value, dict) and time == now + 50, (value, time)
-            assert value[subkey1] == (123, now + 10)
-            assert value[subkey2] == (567, now + 30)
-            assert value[subkey3] == (890, now + 50)
-            assert len(value) == 3
-
-        test_success.set()
-
-    tester = mp.Process(target=_tester, daemon=True)
-    tester.start()
-    tester.join()
-    assert test_success.is_set()
+    loop = asyncio.get_event_loop()
+    me = loop.run_until_complete(DHTNode.create(initial_peers=random.sample(dht.keys(), 5), parallel_rpc=10,
+                                                cache_refresh_before_expiry=False))
+
+    # test 1: find self
+    nearest = loop.run_until_complete(me.find_nearest_nodes([me.node_id], k_nearest=1))[me.node_id]
+    assert len(nearest) == 1 and ':'.join(nearest[me.node_id].split(':')[-2:]) == f"{LOCALHOST}:{me.port}"
+
+    # test 2: find others
+    for i in range(10):
+        ref_endpoint, query_id = random.choice(list(dht.items()))
+        nearest = loop.run_until_complete(me.find_nearest_nodes([query_id], k_nearest=1))[query_id]
+        assert len(nearest) == 1
+        found_node_id, found_endpoint = next(iter(nearest.items()))
+        assert found_node_id == query_id and ':'.join(found_endpoint.split(':')[-2:]) == ref_endpoint
+
+    # test 3: find neighbors to random nodes
+    accuracy_numerator = accuracy_denominator = 0  # top-1 nearest neighbor accuracy
+    jaccard_numerator = jaccard_denominator = 0  # jaccard similarity aka intersection over union
+    all_node_ids = list(dht.values())
+
+    for i in range(100):
+        query_id = DHTID.generate()
+        k_nearest = random.randint(1, 20)
+        exclude_self = random.random() > 0.5
+        nearest = loop.run_until_complete(
+            me.find_nearest_nodes([query_id], k_nearest=k_nearest, exclude_self=exclude_self))[query_id]
+        nearest_nodes = list(nearest)  # keys from ordered dict
+
+        assert len(nearest_nodes) == k_nearest, "beam search must return exactly k_nearest results"
+        assert me.node_id not in nearest_nodes or not exclude_self, "if exclude, results shouldn't contain self"
+        assert np.all(np.diff(query_id.xor_distance(nearest_nodes)) >= 0), "results must be sorted by distance"
+
+        ref_nearest = heapq.nsmallest(k_nearest + 1, all_node_ids, key=query_id.xor_distance)
+        if exclude_self and me.node_id in ref_nearest:
+            ref_nearest.remove(me.node_id)
+        if len(ref_nearest) > k_nearest:
+            ref_nearest.pop()
+
+        accuracy_numerator += nearest_nodes[0] == ref_nearest[0]
+        accuracy_denominator += 1
+
+        jaccard_numerator += len(set.intersection(set(nearest_nodes), set(ref_nearest)))
+        jaccard_denominator += k_nearest
+
+    accuracy = accuracy_numerator / accuracy_denominator
+    print("Top-1 accuracy:", accuracy)  # should be 98-100%
+    jaccard_index = jaccard_numerator / jaccard_denominator
+    print("Jaccard index (intersection over union):", jaccard_index)  # should be 95-100%
+    assert accuracy >= 0.9, f"Top-1 accuracy only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
+    assert jaccard_index >= 0.9, f"Jaccard index only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
+
+    # test 4: find all nodes
+    dummy = DHTID.generate()
+    nearest = loop.run_until_complete(me.find_nearest_nodes([dummy], k_nearest=len(dht) + 100))[dummy]
+    assert len(nearest) == len(dht) + 1
+    assert len(set.difference(set(nearest.keys()), set(all_node_ids) | {me.node_id})) == 0
+
+    # test 5: node without peers
+    detached_node = loop.run_until_complete(DHTNode.create())
+    nearest = loop.run_until_complete(detached_node.find_nearest_nodes([dummy]))[dummy]
+    assert len(nearest) == 1 and nearest[detached_node.node_id] == f"{LOCALHOST}:{detached_node.port}"
+    nearest = loop.run_until_complete(detached_node.find_nearest_nodes([dummy], exclude_self=True))[dummy]
+    assert len(nearest) == 0
+
+    # test 6 store and get value
+    true_time = get_dht_time() + 1200
+    assert loop.run_until_complete(me.store("mykey", ["Value", 10], true_time))
+    that_guy = loop.run_until_complete(DHTNode.create(initial_peers=random.sample(dht.keys(), 3), parallel_rpc=10,
+                                                      cache_refresh_before_expiry=False, cache_locally=False))
+
+    for node in [me, that_guy]:
+        val, expiration_time = loop.run_until_complete(node.get("mykey"))
+        assert val == ["Value", 10], "Wrong value"
+        assert expiration_time == true_time, f"Wrong time"
+
+    assert loop.run_until_complete(detached_node.get("mykey")) is None
+
+    # test 7: bulk store and bulk get
+    keys = 'foo', 'bar', 'baz', 'zzz'
+    values = 3, 2, 'batman', [1, 2, 3]
+    store_ok = loop.run_until_complete(me.store_many(keys, values, expiration_time=get_dht_time() + 999))
+    assert all(store_ok.values()), "failed to store one or more keys"
+    response = loop.run_until_complete(me.get_many(keys[::-1]))
+    for key, value in zip(keys, values):
+        assert key in response and response[key][0] == value
+
+    # test 8: store dictionaries as values (with sub-keys)
+    upper_key, subkey1, subkey2, subkey3 = 'ololo', 'k1', 'k2', 'k3'
+    now = get_dht_time()
+    assert loop.run_until_complete(me.store(upper_key, subkey=subkey1, value=123, expiration_time=now + 10))
+    assert loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=456, expiration_time=now + 20))
+    for node in [that_guy, me]:
+        value, time = loop.run_until_complete(node.get(upper_key))
+        assert isinstance(value, dict) and time == now + 20
+        assert value[subkey1] == (123, now + 10)
+        assert value[subkey2] == (456, now + 20)
+        assert len(value) == 2
+
+    assert not loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=345, expiration_time=now + 10))
+    assert loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=567, expiration_time=now + 30))
+    assert loop.run_until_complete(me.store(upper_key, subkey=subkey3, value=890, expiration_time=now + 50))
+    loop.run_until_complete(asyncio.sleep(0.1))  # wait for cache to refresh
+
+    for node in [that_guy, me]:
+        value, time = loop.run_until_complete(node.get(upper_key))
+        assert isinstance(value, dict) and time == now + 50, (value, time)
+        assert value[subkey1] == (123, now + 10)
+        assert value[subkey2] == (567, now + 30)
+        assert value[subkey3] == (890, now + 50)
+        assert len(value) == 3
+
     for proc in processes:
         proc.terminate()
 
 
-def test_dhtnode_replicas():
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_dhtnode_replicas():
     dht_size = 20
     initial_peers = 3
     num_replicas = random.randint(1, 20)
-    test_success = mp.Event()
-
-    async def _tester():
-        peers = []
-        for i in range(dht_size):
-            neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(initial_peers, len(peers)))]
-            peers.append(await DHTNode.create(initial_peers=neighbors_i, num_replicas=num_replicas))
-
-        you = random.choice(peers)
-        assert await you.store('key1', 'foo', get_dht_time() + 999)
-
-        actual_key1_replicas = sum(len(peer.protocol.storage) for peer in peers)
-        assert num_replicas == actual_key1_replicas
-
-        assert await you.store('key2', 'bar', get_dht_time() + 999)
-        total_size = sum(len(peer.protocol.storage) for peer in peers)
-        actual_key2_replicas = total_size - actual_key1_replicas
-        assert num_replicas == actual_key2_replicas
-
-        assert await you.store('key2', 'baz', get_dht_time() + 1000)
-        assert sum(len(peer.protocol.storage) for peer in peers) == total_size, "total size should not have changed"
-        test_success.set()
-
-    proc = mp.Process(target=lambda: asyncio.run(_tester()))
-    proc.start()
-    proc.join()
-    assert test_success.is_set()
-
-
-def test_dhtnode_caching(T=0.05):
-    test_success = mp.Event()
-
-    async def _tester():
-        node2 = await hivemind.DHTNode.create(cache_refresh_before_expiry=5 * T, reuse_get_requests=False)
-        node1 = await hivemind.DHTNode.create(initial_peers=[f'localhost:{node2.port}'],
-                                              cache_refresh_before_expiry=5 * T, listen=False, reuse_get_requests=False)
-        await node2.store('k', [123, 'value'], expiration_time=hivemind.get_dht_time() + 7 * T)
-        await node2.store('k2', [654, 'value'], expiration_time=hivemind.get_dht_time() + 7 * T)
-        await node2.store('k3', [654, 'value'], expiration_time=hivemind.get_dht_time() + 15 * T)
-        await node1.get_many(['k', 'k2', 'k3', 'k4'])
-        assert len(node1.protocol.cache) == 3
-        assert len(node1.cache_refresh_queue) == 0
-
-        await node1.get_many(['k', 'k2', 'k3', 'k4'])
-        assert len(node1.cache_refresh_queue) == 3
-
-        await node2.store('k', [123, 'value'], expiration_time=hivemind.get_dht_time() + 12 * T)
-        await asyncio.sleep(4 * T)
-        await node1.get('k')
-        await asyncio.sleep(1 * T)
-
-        assert len(node1.protocol.cache) == 3
-        assert len(node1.cache_refresh_queue) == 2
-        await asyncio.sleep(3 * T)
-
-        assert len(node1.cache_refresh_queue) == 1
-
-        await asyncio.sleep(5 * T)
-        assert len(node1.cache_refresh_queue) == 0
-        await asyncio.sleep(5 * T)
-        assert len(node1.cache_refresh_queue) == 0
-
-        await node2.store('k', [123, 'value'], expiration_time=hivemind.get_dht_time() + 10 * T)
-        await node1.get('k')
-        await asyncio.sleep(1 * T)
-        assert len(node1.cache_refresh_queue) == 0
-        await node1.get('k')
-        await asyncio.sleep(1 * T)
-        assert len(node1.cache_refresh_queue) == 1
-
-        await asyncio.sleep(5 * T)
-        assert len(node1.cache_refresh_queue) == 0
-
-        await asyncio.gather(node1.shutdown(), node2.shutdown())
-        test_success.set()
-
-    proc = mp.Process(target=lambda: asyncio.run(_tester()))
-    proc.start()
-    proc.join()
-    assert test_success.is_set()
-
-
-def test_dhtnode_reuse_get():
-    test_success = mp.Event()
-
-    async def _tester():
-        peers = []
-        for i in range(10):
-            neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
-            peers.append(await hivemind.DHTNode.create(initial_peers=neighbors_i, parallel_rpc=256))
-
-        await asyncio.gather(
-            random.choice(peers).store('k1', 123, hivemind.get_dht_time() + 999),
-            random.choice(peers).store('k2', 567, hivemind.get_dht_time() + 999)
-        )
-
-        you = random.choice(peers)
-
-        futures1 = await you.get_many(['k1', 'k2'], return_futures=True)
-        assert len(you.pending_get_requests[DHTID.generate('k1')]) == 1
-        assert len(you.pending_get_requests[DHTID.generate('k2')]) == 1
-
-        futures2 = await you.get_many(['k2', 'k3'], return_futures=True)
-        assert len(you.pending_get_requests[DHTID.generate('k2')]) == 2
-
-        await asyncio.gather(*futures1.values(), *futures2.values())
-        futures3 = await you.get_many(['k3'], return_futures=True)
-        assert len(you.pending_get_requests[DHTID.generate('k1')]) == 0
-        assert len(you.pending_get_requests[DHTID.generate('k2')]) == 0
-        assert len(you.pending_get_requests[DHTID.generate('k3')]) == 1
-
-        assert (await futures1['k1'])[0] == 123
-        assert await futures1['k2'] == await futures2['k2'] and (await futures1['k2'])[0] == 567
-        assert await futures2['k3'] == await futures3['k3'] and (await futures3['k3']) is None
-        test_success.set()

-    proc = mp.Process(target=lambda: asyncio.run(_tester()))
-    proc.start()
-    proc.join()
-    assert test_success.is_set()
+    peers = []
+    for i in range(dht_size):
+        neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(initial_peers, len(peers)))]
+        peers.append(await DHTNode.create(initial_peers=neighbors_i, num_replicas=num_replicas))
+
+    you = random.choice(peers)
+    assert await you.store('key1', 'foo', get_dht_time() + 999)
+
+    actual_key1_replicas = sum(len(peer.protocol.storage) for peer in peers)
+    assert num_replicas == actual_key1_replicas
+
+    assert await you.store('key2', 'bar', get_dht_time() + 999)
+    total_size = sum(len(peer.protocol.storage) for peer in peers)
+    actual_key2_replicas = total_size - actual_key1_replicas
+    assert num_replicas == actual_key2_replicas
+
+    assert await you.store('key2', 'baz', get_dht_time() + 1000)
+    assert sum(len(peer.protocol.storage) for peer in peers) == total_size, "total size should not have changed"
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_dhtnode_caching(T=0.05):
+    node2 = await hivemind.DHTNode.create(cache_refresh_before_expiry=5 * T, reuse_get_requests=False)
+    node1 = await hivemind.DHTNode.create(initial_peers=[f'localhost:{node2.port}'],
+                                          cache_refresh_before_expiry=5 * T, listen=False, reuse_get_requests=False)
+    await node2.store('k', [123, 'value'], expiration_time=hivemind.get_dht_time() + 7 * T)
+    await node2.store('k2', [654, 'value'], expiration_time=hivemind.get_dht_time() + 7 * T)
+    await node2.store('k3', [654, 'value'], expiration_time=hivemind.get_dht_time() + 15 * T)
+    await node1.get_many(['k', 'k2', 'k3', 'k4'])
+    assert len(node1.protocol.cache) == 3
+    assert len(node1.cache_refresh_queue) == 0
+
+    await node1.get_many(['k', 'k2', 'k3', 'k4'])
+    assert len(node1.cache_refresh_queue) == 3
+
+    await node2.store('k', [123, 'value'], expiration_time=hivemind.get_dht_time() + 12 * T)
+    await asyncio.sleep(4 * T)
+    await node1.get('k')
+    await asyncio.sleep(1 * T)
+
+    assert len(node1.protocol.cache) == 3
+    assert len(node1.cache_refresh_queue) == 2
+    await asyncio.sleep(3 * T)
+
+    assert len(node1.cache_refresh_queue) == 1
+
+    await asyncio.sleep(5 * T)
+    assert len(node1.cache_refresh_queue) == 0
+    await asyncio.sleep(5 * T)
+    assert len(node1.cache_refresh_queue) == 0
+
+    await node2.store('k', [123, 'value'], expiration_time=hivemind.get_dht_time() + 10 * T)
+    await node1.get('k')
+    await asyncio.sleep(1 * T)
+    assert len(node1.cache_refresh_queue) == 0
+    await node1.get('k')
+    await asyncio.sleep(1 * T)
+    assert len(node1.cache_refresh_queue) == 1
+
+    await asyncio.sleep(5 * T)
+    assert len(node1.cache_refresh_queue) == 0
+
+    await asyncio.gather(node1.shutdown(), node2.shutdown())
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_dhtnode_reuse_get():
+    peers = []
+    for i in range(10):
+        neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
+        peers.append(await hivemind.DHTNode.create(initial_peers=neighbors_i, parallel_rpc=256))
+
+    await asyncio.gather(
+        random.choice(peers).store('k1', 123, hivemind.get_dht_time() + 999),
+        random.choice(peers).store('k2', 567, hivemind.get_dht_time() + 999)
+    )
+
+    you = random.choice(peers)
+
+    futures1 = await you.get_many(['k1', 'k2'], return_futures=True)
+    assert len(you.pending_get_requests[DHTID.generate('k1')]) == 1
+    assert len(you.pending_get_requests[DHTID.generate('k2')]) == 1
+
+    futures2 = await you.get_many(['k2', 'k3'], return_futures=True)
+    assert len(you.pending_get_requests[DHTID.generate('k2')]) == 2
+
+    await asyncio.gather(*futures1.values(), *futures2.values())
+    futures3 = await you.get_many(['k3'], return_futures=True)
+    assert len(you.pending_get_requests[DHTID.generate('k1')]) == 0
+    assert len(you.pending_get_requests[DHTID.generate('k2')]) == 0
+    assert len(you.pending_get_requests[DHTID.generate('k3')]) == 1
+
+    assert (await futures1['k1'])[0] == 123
+    assert await futures1['k2'] == await futures2['k2'] and (await futures1['k2'])[0] == 567
+    assert await futures2['k3'] == await futures3['k3'] and (await futures3['k3']) is None

+ 6 - 0
tests/test_moe.py

@@ -7,6 +7,7 @@ from hivemind.client.expert import DUMMY
 from hivemind import background_server


+@pytest.mark.forked
 def test_moe():
     all_expert_uids = [f'ffn.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}'
                        for _ in range(20)]
@@ -22,6 +23,7 @@ def test_moe():
             out.sum().backward()


+@pytest.mark.forked
 def test_call_many():
     k_min = 1
     timeout_after_k_min = None
@@ -71,6 +73,7 @@ def test_call_many():
         assert torch.allclose(our_grad, reference_grad, rtol, atol)


+@pytest.mark.forked
 def test_remote_module_call():
     with background_server(num_experts=1, device='cpu', expert_cls='ffn', num_handlers=1, hidden_dim=1024,
                            optim_cls=None, no_dht=True) as (server_endpoint, dht_endpoint):
@@ -93,6 +96,7 @@ def test_remote_module_call():
             fake_expert(dummy_x)


+@pytest.mark.forked
 def test_beam_search_correctness():
     all_expert_uids = [f'ffn.{5 + i}.{10 + j}.{15 + k}' for i in range(10) for j in range(10) for k in range(10)]
     dht = hivemind.DHT(start=True, expiration=999)
@@ -119,6 +123,7 @@ def test_beam_search_correctness():
         assert np.allclose(true_best_scores, our_best_scores)


+@pytest.mark.forked
 def test_determinism():
     rtol = 0
     atol = 1e-5
@@ -140,6 +145,7 @@ def test_determinism():
     assert torch.allclose(grad, grad_rerun, rtol, atol), "Gradients are non-deterministic."


+@pytest.mark.forked
 def test_compute_expert_scores():
     try:
         dht = hivemind.DHT(start=True)

+ 0 - 4
tests/test_routing.py

@@ -66,10 +66,6 @@ def test_routing_table_basic():
         assert bucket_index == found_bucket_index


-
-
-
-
 def test_routing_table_parameters():
     for (bucket_size, modulo, min_nbuckets, max_nbuckets) in [
         (20,          5,      45,           65),

+ 2 - 0
tests/test_training.py

@@ -1,6 +1,7 @@
 from functools import partial
 from typing import Optional

+import pytest
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -9,6 +10,7 @@ from sklearn.datasets import load_digits
 from hivemind import RemoteExpert, background_server


+@pytest.mark.forked
 def test_training(port: Optional[int] = None, max_steps: int = 100, threshold: float = 0.9):
     dataset = load_digits()
     X_train, y_train = torch.tensor(dataset['data'], dtype=torch.float), torch.tensor(dataset['target'])

+ 38 - 40
tests/test_util_modules.py

@@ -78,46 +78,44 @@ def test_mpfuture_status():
         assert future.set_running_or_notify_cancel() is False


-def test_await_mpfuture():
-    async def _run():
-        # await result
-        f1, f2 = hivemind.MPFuture.make_pair()
-
-        async def wait_and_assign():
-            assert f2.set_running_or_notify_cancel() is True
-            await asyncio.sleep(0.1)
-            f2.set_result((123, 'ololo'))
-
-        asyncio.create_task(wait_and_assign())
-        for future in [f1, f2]:
-            res = await future
-            assert res == (123, 'ololo')
-
-        # await cancel
-        f1, f2 = hivemind.MPFuture.make_pair()
-
-        async def wait_and_cancel():
-            await asyncio.sleep(0.1)
-            f1.cancel()
-
-        asyncio.create_task(wait_and_cancel())
-        for future in [f1, f2]:
-            with pytest.raises(CancelledError):
-                await future
-
-        # await exception
-        f1, f2 = hivemind.MPFuture.make_pair()
-
-        async def wait_and_raise():
-            await asyncio.sleep(0.1)
-            f1.set_exception(SystemError())
-
-        asyncio.create_task(wait_and_raise())
-        for future in [f1, f2]:
-            with pytest.raises(SystemError):
-                await future
-
-    asyncio.new_event_loop().run_until_complete(_run())
+@pytest.mark.asyncio
+async def test_await_mpfuture():
+    # await result
+    f1, f2 = hivemind.MPFuture.make_pair()
+
+    async def wait_and_assign():
+        assert f2.set_running_or_notify_cancel() is True
+        await asyncio.sleep(0.1)
+        f2.set_result((123, 'ololo'))
+
+    asyncio.create_task(wait_and_assign())
+    for future in [f1, f2]:
+        res = await future
+        assert res == (123, 'ololo')
+
+    # await cancel
+    f1, f2 = hivemind.MPFuture.make_pair()
+
+    async def wait_and_cancel():
+        await asyncio.sleep(0.1)
+        f1.cancel()
+
+    asyncio.create_task(wait_and_cancel())
+    for future in [f1, f2]:
+        with pytest.raises(CancelledError):
+            await future
+
+    # await exception
+    f1, f2 = hivemind.MPFuture.make_pair()
+
+    async def wait_and_raise():
+        await asyncio.sleep(0.1)
+        f1.set_exception(SystemError())
+
+    asyncio.create_task(wait_and_raise())
+    for future in [f1, f2]:
+        with pytest.raises(SystemError):
+            await future


 def test_vector_compression(size=(128, 128, 64), alpha=5e-08):