
added torch 1.7 support, switched to grpc.aio and bumped grpcio to 1.33, improved tests & logging (#116)

* specify force_reinstall to override the circleci cache

* switch from warnings.warn to logger.warning

* check for compression even if compression_type is None

* bump version

* require newer grpcio

* switch to asyncio testing

* use pytest.mark.forked/pytest.mark.asyncio where applicable

* detach before serializing

* set requires_grad manually

* switch from grpc.experimental.aio to grpc.aio
justheuristic 4 years ago
Parent
Current commit
2bd481c73f

+ 4 - 3
.circleci/config.yml

@@ -11,12 +11,13 @@ jobs:
     steps:
       - checkout
       - python/load-cache
-      - run: pip uninstall -y pytest codecov  # temporary override for broken cache
-      - run: pip install codecov pytest tqdm scikit-learn
+      - run: pip uninstall -y pytest codecov
+      # note: uninstall is required because otherwise circleci cache will lose track of pytest/codecov executables
+      - run: pip install pytest pytest-forked pytest-asyncio codecov tqdm scikit-learn
       - python/install-deps
       - python/save-cache
       - run:
-          command: pip install -e .
+          command: pip install --force-reinstall -e .
           name: setup
       - run:
           command: pytest ./tests

+ 2 - 1
docs/user/contributing.md

@@ -28,9 +28,10 @@ cd hivemind
 python setup.py develop
 ``` 
 
-To run tests, you will also need to `pip install pytest codecov tqdm scikit-learn`.
+To run tests, you will also need to `pip install pytest pytest-forked pytest-asyncio codecov tqdm scikit-learn`.
 You can run all tests with `pytest ./tests` or choose a specific set, e.g. `pytest ./tests/test_dht.py`.
 
+
 To build docs locally,
 1. `pip install sphinx sphinx_rtd_theme recommonmark`
 2. make sure you ran setup.py (see above)
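Note (editorial, not part of this diff): the two new plugins added above are used together throughout the updated tests. A minimal sketch of how the markers combine — pytest-forked runs each test in a forked child process, pytest-asyncio lets the test body be a coroutine:

```python
# illustrative only; mirrors the marker combination used in tests/test_dht_node.py below
import asyncio
import pytest


@pytest.mark.forked    # run in a forked subprocess so global grpc/asyncio state cannot leak between tests
@pytest.mark.asyncio   # provide an event loop and await the test coroutine
async def test_isolated_async_behavior():
    await asyncio.sleep(0)
    assert True
```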

+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.utils import *
 
-__version__ = '0.8.9'
+__version__ = '0.8.10'

+ 0 - 1
hivemind/client/expert.py

@@ -3,7 +3,6 @@ from functools import lru_cache
 from typing import Tuple, Optional, Any, Dict
 
 import grpc
-import grpc.experimental.aio
 import torch
 import torch.nn as nn
 from torch.autograd.function import once_differentiable

+ 2 - 3
hivemind/dht/__init__.py

@@ -17,7 +17,6 @@ import ctypes
 import heapq
 import multiprocessing as mp
 import re
-import warnings
 from collections import deque
 from concurrent.futures import ThreadPoolExecutor
 from typing import List, Tuple, Optional, Sequence, Union, Dict, Deque, NamedTuple, Iterator, Set
@@ -163,11 +162,11 @@ class DHT(mp.Process):
             raise TimeoutError("Server didn't notify .ready in {timeout} seconds")
 
     def shutdown(self) -> None:
-        """ Shuts down the dht process """
+        """ Shut down a running dht process """
         if self.is_alive():
             self.terminate()
         else:
-            warnings.warn("DHT shutdown has no effect: dht process is already not alive")
+            logger.warning("DHT shutdown has no effect: dht process is already not alive")
 
     @property
     def port(self) -> Optional[int]:
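Side note on the warnings.warn → logger.warning switch seen here and below: log records pass through the logging configuration every time, whereas the warnings machinery deduplicates by call site. A standalone comparison (plain logging used for brevity; hivemind wraps this in its get_logger helper):

```python
# standalone illustration, not part of this commit
import logging
import warnings

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("hivemind.dht")

for _ in range(3):
    warnings.warn("shown once: the default warnings filter deduplicates repeats from the same line")
    logger.warning("shown every time and routed through the configured handlers and levels")
```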

+ 3 - 4
hivemind/dht/node.py

@@ -6,7 +6,6 @@ from collections import defaultdict
 from dataclasses import dataclass, field
 from functools import partial
 from typing import Optional, Tuple, List, Dict, DefaultDict, Collection, Union, Set, Awaitable, Callable, Any
-from warnings import warn
 
 from sortedcontainers import SortedList
 
@@ -36,7 +35,7 @@ class DHTNode:
 
     * ping - request peer's identifier and update routing table (same as Kademlia PING RPC)
     * store - send several (key, value, expiration_time) pairs to the same peer (like Kademlia STORE, but in bulk)
-    * find - request one or several keys, get values & expiration (if peer finds it locally) and :bucket_size: of
+    * find - request one or several keys, get values and expiration (if peer finds it locally) and :bucket_size: of
         nearest peers from recipient's routing table (ordered nearest-to-farthest, not including recipient itself)
         This RPC is a mixture between Kademlia FIND_NODE and FIND_VALUE with multiple keys per call.
 
@@ -146,7 +145,7 @@ class DHTNode:
                 finished_pings |= finished_in_time
 
             if not finished_pings:
-                warn("DHTNode bootstrap failed: none of the initial_peers responded to a ping.")
+                logger.warning("DHTNode bootstrap failed: none of the initial_peers responded to a ping.")
 
             # stage 3: traverse dht to find my own nearest neighbors and populate the routing table
             # ... maybe receive some values that we are meant to store (see protocol.update_routing_table)
@@ -189,7 +188,7 @@ class DHTNode:
         num_workers = num_workers if num_workers is not None else self.num_workers
         beam_size = beam_size if beam_size is not None else max(self.protocol.bucket_size, k_nearest)
         if k_nearest > beam_size:
-            warn("Warning: beam_size is too small, beam search is not guaranteed to find enough nodes")
+            logger.warning("Warning: beam_size is too small, beam search is not guaranteed to find enough nodes")
         if node_to_endpoint is None:
             node_to_endpoint: Dict[DHTID, Endpoint] = dict()
             for query in queries:

+ 10 - 12
hivemind/dht/protocol.py

@@ -3,10 +3,8 @@ from __future__ import annotations
 
 import asyncio
 from typing import Optional, List, Tuple, Dict, Any, Sequence, Union, Collection
-from warnings import warn
 
 import grpc
-import grpc.experimental.aio
 
 from hivemind.dht.routing import RoutingTable, DHTID, BinaryDHTValue, DHTExpiration, Subkey
 from hivemind.dht.storage import DHTLocalStorage, DictionaryDHTValue, ValueWithExpiration
@@ -19,7 +17,7 @@ logger = get_logger(__name__)
 class DHTProtocol(dht_grpc.DHTServicer):
     # fmt:off
     node_id: DHTID; port: int; bucket_size: int; num_replicas: int; wait_timeout: float; node_info: dht_pb2.NodeInfo
-    channel_options: Optional[Sequence[Tuple[str, Any]]]; server: grpc.experimental.aio.Server
+    channel_options: Optional[Sequence[Tuple[str, Any]]]; server: grpc.aio.Server
     storage: DHTLocalStorage; cache: DHTLocalStorage; routing_table: RoutingTable; rpc_semaphore: asyncio.Semaphore
     # fmt:on
 
@@ -51,8 +49,8 @@ class DHTProtocol(dht_grpc.DHTServicer):
         self.rpc_semaphore = asyncio.Semaphore(parallel_rpc if parallel_rpc is not None else float('inf'))
 
         if listen:  # set up server to process incoming rpc requests
-            grpc.experimental.aio.init_grpc_aio()
-            self.server = grpc.experimental.aio.server(**kwargs)
+            grpc.aio.init_grpc_aio()
+            self.server = grpc.aio.server(**kwargs)
             dht_grpc.add_DHTServicer_to_server(self, self.server)
 
             found_port = self.server.add_insecure_port(listen_on)
@@ -64,8 +62,8 @@ class DHTProtocol(dht_grpc.DHTServicer):
             # note: use empty node_info so peers wont add you to their routing tables
             self.node_info, self.server, self.port = dht_pb2.NodeInfo(), None, None
             if listen_on != '0.0.0.0:*' or len(kwargs) != 0:
-                warn(f"DHTProtocol has no server (due to listen=False), listen_on"
-                     f"and kwargs have no effect (unused kwargs: {kwargs})")
+                logger.warning(f"DHTProtocol has no server (due to listen=False), listen_on"
+                               f"and kwargs have no effect (unused kwargs: {kwargs})")
         return self
 
     def __init__(self, *, _initialized_with_create=False):
@@ -78,11 +76,11 @@ class DHTProtocol(dht_grpc.DHTServicer):
         if self.server:
             await self.server.stop(timeout)
         else:
-            warn("DHTProtocol has no server (due to listen=False), it doesn't need to be shut down")
+            logger.warning("DHTProtocol has no server (due to listen=False), it doesn't need to be shut down")
 
     def _get(self, peer: Endpoint) -> dht_grpc.DHTStub:
         """ get a DHTStub that sends requests to a given peer """
-        channel = grpc.experimental.aio.insecure_channel(peer, options=self.channel_options)
+        channel = grpc.aio.insecure_channel(peer, options=self.channel_options)
         return dht_grpc.DHTStub(channel)
 
     async def call_ping(self, peer: Endpoint) -> Optional[DHTID]:
@@ -96,7 +94,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
         try:
             async with self.rpc_semaphore:
                 peer_info = await self._get(peer).rpc_ping(self.node_info, timeout=self.wait_timeout)
-        except grpc.experimental.aio.AioRpcError as error:
+        except grpc.aio.AioRpcError as error:
             logger.warning(f"DHTProtocol failed to ping {peer}: {error.code()}")
             peer_info = None
         responded = bool(peer_info and peer_info.node_id)
@@ -162,7 +160,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
                 peer_id = DHTID.from_bytes(response.peer.node_id)
                 asyncio.create_task(self.update_routing_table(peer_id, peer, responded=True))
             return response.store_ok
-        except grpc.experimental.aio.AioRpcError as error:
+        except grpc.aio.AioRpcError as error:
             logger.warning(f"DHTProtocol failed to store at {peer}: {error.code()}")
             asyncio.create_task(self.update_routing_table(self.routing_table.get(endpoint=peer), peer, responded=False))
             return None
@@ -226,7 +224,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
                     logger.error(f"Unknown result type: {result.type}")
 
             return output
-        except grpc.experimental.aio.AioRpcError as error:
+        except grpc.aio.AioRpcError as error:
             logger.warning(f"DHTProtocol failed to find at {peer}: {error.code()}")
             asyncio.create_task(self.update_routing_table(self.routing_table.get(endpoint=peer), peer, responded=False))
 

+ 3 - 3
hivemind/server/connection_handler.py

@@ -4,7 +4,7 @@ import os
 import pickle
 from typing import Dict
 
-import grpc.experimental.aio
+import grpc
 import torch
 import uvloop
 
@@ -35,9 +35,9 @@ class ConnectionHandler(mp.Process):
         loop = asyncio.new_event_loop()
 
         async def _run():
-            grpc.experimental.aio.init_grpc_aio()
+            grpc.aio.init_grpc_aio()
             logger.debug(f'Starting, pid {os.getpid()}')
-            server = grpc.experimental.aio.server(options=[
+            server = grpc.aio.server(options=[
                 ('grpc.so_reuseport', 1),
                 ('grpc.max_send_message_length', -1),
                 ('grpc.max_receive_message_length', -1)

+ 5 - 6
hivemind/server/task_pool.py

@@ -161,10 +161,8 @@ class TaskPool(TaskPoolBase):
 
             logger.debug(f"{self.uid}, batch  {batch_index}: aggregating inputs")
             # find or create shared arrays for current batch size
-            batch_inputs = [
-                torch.cat([task.args[i] for task in batch_tasks]).share_memory_()
-                for i in range(len(batch_tasks[0].args))
-            ]
+            batch_inputs = [torch.cat([task.args[i] for task in batch_tasks]) for i in range(len(batch_tasks[0].args))]
+            batch_inputs = [inp.detach().requires_grad_(inp.requires_grad).share_memory_() for inp in batch_inputs]
 
             logger.debug(f"{self.uid}, batch {batch_index}: sending to runtime")
             self.batch_sender.send((batch_index, batch_inputs))
@@ -187,7 +185,7 @@ class TaskPool(TaskPoolBase):
             # split batch into partitions for individual tasks
             batch_tasks = pending_batches.pop(batch_index)
             task_sizes = [self.get_task_size(task) for task in batch_tasks]
-            outputs_per_task = zip(*(torch.split_with_sizes(array, task_sizes, dim=0) for array in batch_outputs))
+            outputs_per_task = zip(*(torch.split_with_sizes(tensor, task_sizes, dim=0) for tensor in batch_outputs))
             logger.debug(f"{self.uid}, batch {batch_index}: sending outputs to handlers")
 
             # dispatch results to futures
@@ -209,7 +207,8 @@ class TaskPool(TaskPoolBase):
 
     def send_outputs_from_runtime(self, batch_index: int, batch_outputs: List[torch.Tensor]):
         """ send results for a processed batch, previously loaded through load_batch_to_runtime """
-        batch_outputs = [tensor.to(device='cpu').share_memory_() for tensor in batch_outputs]
+        batch_outputs = [tensor.to(device='cpu').share_memory_().detach().requires_grad_(tensor.requires_grad)
+                         for tensor in batch_outputs]
         self.outputs_sender.send((batch_index, batch_outputs))
 
     def get_task_size(self, task: Task) -> int:
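The two changes above implement the "detach before serializing" and "set requires_grad manually" bullets from the commit message: detach() drops the autograd graph (which cannot be sent across process boundaries), requires_grad_() restores the gradient flag on the resulting leaf tensor, and share_memory_() moves its storage to shared memory. A standalone sketch of the idea:

```python
# standalone illustration of the detach-before-sharing pattern used above
import torch

x = torch.randn(4, 8, requires_grad=True)
y = x * 2                      # non-leaf tensor carrying autograd history

shared = y.detach().requires_grad_(y.requires_grad).share_memory_()

assert shared.requires_grad    # gradient flag preserved explicitly
assert shared.grad_fn is None  # no autograd graph attached, safe to send to another process
assert shared.is_shared()      # storage now lives in shared memory
```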

+ 3 - 1
hivemind/utils/grpc.py

@@ -46,7 +46,7 @@ def serialize_torch_tensor(tensor: torch.Tensor, compression_type=CompressionTyp
             size=tensor.shape,
             dtype='clamped_float32',
             requires_grad=tensor.requires_grad)
-    else:
+    elif compression_type == CompressionType.NONE:
         array = tensor.numpy()
         proto = runtime_pb2.Tensor(
             compression=compression_type,
@@ -54,6 +54,8 @@ def serialize_torch_tensor(tensor: torch.Tensor, compression_type=CompressionTyp
             size=array.shape,
             dtype=array.dtype.name,
             requires_grad=tensor.requires_grad)
+    else:
+        raise ValueError(f"Unknown compression type: {compression_type}")
 
     return proto
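The explicit elif/else above changes the failure mode: an unrecognized compression value now raises instead of silently falling through to the uncompressed branch. A hedged usage sketch (import paths assumed from the surrounding module):

```python
# usage sketch; not part of the diff
import torch
from hivemind.proto.runtime_pb2 import CompressionType   # assumed import path
from hivemind.utils.grpc import serialize_torch_tensor

tensor = torch.randn(3, 3)   # no requires_grad: the NONE branch calls tensor.numpy() internally

proto = serialize_torch_tensor(tensor, CompressionType.NONE)
assert tuple(proto.size) == (3, 3) and proto.dtype == 'float32'

try:
    serialize_torch_tensor(tensor, compression_type=-1)   # not a recognised CompressionType
except ValueError as err:
    print(err)                                            # raised by the new else branch
```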
 

+ 2 - 2
requirements.txt

@@ -5,6 +5,6 @@ prefetch_generator>=1.0.1
 msgpack>=0.5.6
 sortedcontainers
 uvloop>=0.14.0
-grpcio>=1.31
-grpcio-tools>=1.30.0
+grpcio>=1.33.2
+grpcio-tools>=1.33.2
 configargparse>=1.2.3

+ 4 - 3
tests/benchmark_dht.py

@@ -1,13 +1,14 @@
 import argparse
 import random
 import time
-from warnings import warn
 
 from tqdm import trange
 
 import hivemind
 from hivemind.utils.threading import increase_file_limit
 
+logger = hivemind.get_logger(__file__)
+
 
 def random_endpoint() -> hivemind.Endpoint:
     return f"{random.randint(0, 256)}.{random.randint(0, 256)}.{random.randint(0, 256)}." \
@@ -53,7 +54,7 @@ def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_b
     time.sleep(wait_before_read)
 
     if time.perf_counter() - benchmark_started > expiration:
-        warn("Warning: all keys expired before benchmark started getting them. Consider increasing expiration_time")
+        logger.warning("All keys expired before benchmark started getting them. Consider increasing expiration_time")
 
     successful_gets = total_get_time = 0
 
@@ -68,7 +69,7 @@ def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_b
                 successful_gets += 1
 
     if time.perf_counter() - benchmark_started > expiration:
-        warn("Warning: keys expired midway during get requests. If that is not desired, increase expiration_time param")
+        logger.warning("keys expired midway during get requests. If that isn't desired, increase expiration_time param")
 
     print(f"Get success rate: {successful_gets / len(expert_uids) * 100:.1f} ({successful_gets} / {len(expert_uids)})")
     print(f"Mean get time: {total_get_time / len(expert_uids):.5f}, Total: {total_get_time:.5f}")

+ 12 - 15
tests/test_dht_experts.py

@@ -8,6 +8,7 @@ import hivemind
 from hivemind import LOCALHOST, UidEndpoint
 
 
+@pytest.mark.forked
 def test_store_get_experts():
     peers = [hivemind.DHT(start=True)]
     for i in range(10):
@@ -36,6 +37,7 @@ def test_store_get_experts():
         peer.shutdown()
 
 
+@pytest.mark.forked
 def test_beam_search(dht_size=20, total_experts=128, batch_size=32, initial_peers=3, beam_size=4, parallel_rpc=256,
                      grid_dims=(32, 32, 32)):
     dht = []
@@ -68,6 +70,7 @@ def test_beam_search(dht_size=20, total_experts=128, batch_size=32, initial_peer
         assert all(len(experts) == beam_size for experts in batch_experts)
 
 
+@pytest.mark.forked
 def test_dht_single_node():
     node = hivemind.DHT(start=True, expiration=999)
 
@@ -132,8 +135,9 @@ def test_uid_patterns():
         assert not hivemind.is_valid_prefix(pfx), f"Prefix {pfx} is not valid, but was perceived as valid"
 
 
-def test_negative_caching():
-    test_success = mp.Event()
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_negative_caching():
     peers = []
     for i in range(10):
         neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
@@ -148,16 +152,9 @@ def test_negative_caching():
     # get prefixes by the peer with negative caching. Cache "no data" entries for ffn.0.*, ffn.2.*, ffn.4.*, ffn.5.*
     assert len(neg_caching_peer.get_initial_beam(prefix='ffn.', scores=[.1, .2, .3, .4, .5, .6], beam_size=3)) == 2
 
-    async def _tester():
-        node = await hivemind.DHTNode.create(initial_peers=neighbors_i)
-        fetched = await asyncio.gather(*(node.get(f'ffn.{i}.') for i in range(10)))
-        for i in range(6):
-            assert fetched[i] is not None, f"node should have cached ffn.{i}."
-        for i in range(6, len(fetched)):
-            assert fetched[i] is None, f"node shouldn't have cached ffn.{i}."
-        test_success.set()
-
-    proc = mp.Process(target=lambda: asyncio.run(_tester()))
-    proc.start()
-    proc.join()
-    assert test_success.is_set()
+    node = await hivemind.DHTNode.create(initial_peers=neighbors_i)
+    fetched = await asyncio.gather(*(node.get(f'ffn.{i}.') for i in range(10)))
+    for i in range(6):
+        assert fetched[i] is not None, f"node should have cached ffn.{i}."
+    for i in range(6, len(fetched)):
+        assert fetched[i] is None, f"node shouldn't have cached ffn.{i}."

+ 304 - 349
tests/test_dht_node.py

@@ -4,12 +4,13 @@ import random
 import heapq
 from typing import Optional
 import numpy as np
+import pytest
 
 import hivemind
 from typing import List, Dict
 
 from hivemind import get_dht_time
-from hivemind.dht.node import DHTID, Endpoint, DHTNode, LOCALHOST, DHTProtocol
+from hivemind.dht.node import DHTID, Endpoint, DHTNode, LOCALHOST
 from hivemind.dht.protocol import DHTProtocol
 from hivemind.dht.storage import DictionaryDHTValue
 
@@ -29,6 +30,11 @@ def run_protocol_listener(port: int, dhtid: DHTID, started: mp.synchronize.Event
     print(f"Finished peer id={protocol.node_id} port={port}", flush=True)
 
 
+# note: we run grpc-related tests in a separate process to re-initialize all global states from scratch
+# this helps us avoid undesirable side-effects (e.g. segfaults) when running multiple tests in sequence
+
+
+@pytest.mark.forked
 def test_dht_protocol():
     # create the first peer
     peer1_port, peer1_id, peer1_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
@@ -41,121 +47,100 @@ def test_dht_protocol():
                             kwargs={'ping': f'{LOCALHOST}:{peer1_port}'}, daemon=True)
     peer2_proc.start(), peer2_started.wait()
 
-    test_success = mp.Event()
-
-    def _tester():
-        # note: we run everything in a separate process to re-initialize all global states from scratch
-        # this helps us avoid undesirable side-effects when running multiple tests in sequence
-
-        loop = asyncio.get_event_loop()
-        for listen in [False, True]:  # note: order matters, this test assumes that first run uses listen=False
-            protocol = loop.run_until_complete(DHTProtocol.create(
-                DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, listen=listen))
-            print(f"Self id={protocol.node_id}", flush=True)
-
-            assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{peer1_port}')) == peer1_id
-
-            key, value, expiration = DHTID.generate(), [random.random(), {'ololo': 'pyshpysh'}], get_dht_time() + 1e3
-            store_ok = loop.run_until_complete(protocol.call_store(
-                f'{LOCALHOST}:{peer1_port}', [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
-            )
-            assert all(store_ok), "DHT rejected a trivial store"
-
-            # peer 1 must know about peer 2
-            (recv_value_bytes, recv_expiration), nodes_found = loop.run_until_complete(
-                protocol.call_find(f'{LOCALHOST}:{peer1_port}', [key]))[key]
-            recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
-            (recv_id, recv_endpoint) = next(iter(nodes_found.items()))
-            assert recv_id == peer2_id and ':'.join(recv_endpoint.split(':')[-2:]) == f"{LOCALHOST}:{peer2_port}", \
-                f"expected id={peer2_id}, peer={LOCALHOST}:{peer2_port} but got {recv_id}, {recv_endpoint}"
-
-            assert recv_value == value and recv_expiration == expiration, \
-                f"call_find_value expected {value} (expires by {expiration}) " \
-                f"but got {recv_value} (expires by {recv_expiration})"
-
-            # peer 2 must know about peer 1, but not have a *random* nonexistent value
-            dummy_key = DHTID.generate()
-            empty_item, nodes_found_2 = loop.run_until_complete(
-                protocol.call_find(f'{LOCALHOST}:{peer2_port}', [dummy_key]))[dummy_key]
-            assert empty_item is None, "Non-existent keys shouldn't have values"
-            (recv_id, recv_endpoint) = next(iter(nodes_found_2.items()))
-            assert recv_id == peer1_id and recv_endpoint == f"{LOCALHOST}:{peer1_port}", \
-                f"expected id={peer1_id}, peer={LOCALHOST}:{peer1_port} but got {recv_id}, {recv_endpoint}"
-
-            # cause a non-response by querying a nonexistent peer
-            dummy_port = hivemind.find_open_port()
-            assert loop.run_until_complete(protocol.call_find(f"{LOCALHOST}:{dummy_port}", [key])) is None
-
-            # store/get a dictionary with sub-keys
-            nested_key, subkey1, subkey2 = DHTID.generate(), 'foo', 'bar'
-            value1, value2 = [random.random(), {'ololo': 'pyshpysh'}], 'abacaba'
-            assert loop.run_until_complete(protocol.call_store(
-                f'{LOCALHOST}:{peer1_port}', keys=[nested_key], values=[hivemind.MSGPackSerializer.dumps(value1)],
-                expiration_time=[expiration], subkeys=[subkey1])
-            )
-            assert loop.run_until_complete(protocol.call_store(
-                f'{LOCALHOST}:{peer1_port}', keys=[nested_key], values=[hivemind.MSGPackSerializer.dumps(value2)],
-                expiration_time=[expiration + 5], subkeys=[subkey2])
-            )
-            (recv_dict, recv_expiration), nodes_found = loop.run_until_complete(
-                protocol.call_find(f'{LOCALHOST}:{peer1_port}', [nested_key]))[nested_key]
-            assert isinstance(recv_dict, DictionaryDHTValue)
-            assert len(recv_dict.data) == 2 and recv_expiration == expiration + 5
-            assert recv_dict.data[subkey1] == (protocol.serializer.dumps(value1), expiration)
-            assert recv_dict.data[subkey2] == (protocol.serializer.dumps(value2), expiration + 5)
-
-            if listen:
-                loop.run_until_complete(protocol.shutdown())
-            print("DHTProtocol test finished successfully!")
-            test_success.set()
-
-    tester = mp.Process(target=_tester, daemon=True)
-    tester.start()
-    tester.join()
-    assert test_success.is_set()
+    loop = asyncio.get_event_loop()
+    for listen in [False, True]:  # note: order matters, this test assumes that first run uses listen=False
+        protocol = loop.run_until_complete(DHTProtocol.create(
+            DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, listen=listen))
+        print(f"Self id={protocol.node_id}", flush=True)
+
+        assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{peer1_port}')) == peer1_id
+
+        key, value, expiration = DHTID.generate(), [random.random(), {'ololo': 'pyshpysh'}], get_dht_time() + 1e3
+        store_ok = loop.run_until_complete(protocol.call_store(
+            f'{LOCALHOST}:{peer1_port}', [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
+        )
+        assert all(store_ok), "DHT rejected a trivial store"
+
+        # peer 1 must know about peer 2
+        (recv_value_bytes, recv_expiration), nodes_found = loop.run_until_complete(
+            protocol.call_find(f'{LOCALHOST}:{peer1_port}', [key]))[key]
+        recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
+        (recv_id, recv_endpoint) = next(iter(nodes_found.items()))
+        assert recv_id == peer2_id and ':'.join(recv_endpoint.split(':')[-2:]) == f"{LOCALHOST}:{peer2_port}", \
+            f"expected id={peer2_id}, peer={LOCALHOST}:{peer2_port} but got {recv_id}, {recv_endpoint}"
+
+        assert recv_value == value and recv_expiration == expiration, \
+            f"call_find_value expected {value} (expires by {expiration}) " \
+            f"but got {recv_value} (expires by {recv_expiration})"
+
+        # peer 2 must know about peer 1, but not have a *random* nonexistent value
+        dummy_key = DHTID.generate()
+        empty_item, nodes_found_2 = loop.run_until_complete(
+            protocol.call_find(f'{LOCALHOST}:{peer2_port}', [dummy_key]))[dummy_key]
+        assert empty_item is None, "Non-existent keys shouldn't have values"
+        (recv_id, recv_endpoint) = next(iter(nodes_found_2.items()))
+        assert recv_id == peer1_id and recv_endpoint == f"{LOCALHOST}:{peer1_port}", \
+            f"expected id={peer1_id}, peer={LOCALHOST}:{peer1_port} but got {recv_id}, {recv_endpoint}"
+
+        # cause a non-response by querying a nonexistent peer
+        dummy_port = hivemind.find_open_port()
+        assert loop.run_until_complete(protocol.call_find(f"{LOCALHOST}:{dummy_port}", [key])) is None
+
+        # store/get a dictionary with sub-keys
+        nested_key, subkey1, subkey2 = DHTID.generate(), 'foo', 'bar'
+        value1, value2 = [random.random(), {'ololo': 'pyshpysh'}], 'abacaba'
+        assert loop.run_until_complete(protocol.call_store(
+            f'{LOCALHOST}:{peer1_port}', keys=[nested_key], values=[hivemind.MSGPackSerializer.dumps(value1)],
+            expiration_time=[expiration], subkeys=[subkey1])
+        )
+        assert loop.run_until_complete(protocol.call_store(
+            f'{LOCALHOST}:{peer1_port}', keys=[nested_key], values=[hivemind.MSGPackSerializer.dumps(value2)],
+            expiration_time=[expiration + 5], subkeys=[subkey2])
+        )
+        (recv_dict, recv_expiration), nodes_found = loop.run_until_complete(
+            protocol.call_find(f'{LOCALHOST}:{peer1_port}', [nested_key]))[nested_key]
+        assert isinstance(recv_dict, DictionaryDHTValue)
+        assert len(recv_dict.data) == 2 and recv_expiration == expiration + 5
+        assert recv_dict.data[subkey1] == (protocol.serializer.dumps(value1), expiration)
+        assert recv_dict.data[subkey2] == (protocol.serializer.dumps(value2), expiration + 5)
+
+        if listen:
+            loop.run_until_complete(protocol.shutdown())
+        print("DHTProtocol test finished successfully!")
+
     peer1_proc.terminate()
     peer2_proc.terminate()
 
 
+@pytest.mark.forked
 def test_empty_table():
     """ Test RPC methods with empty routing table """
     peer_port, peer_id, peer_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
     peer_proc = mp.Process(target=run_protocol_listener, args=(peer_port, peer_id, peer_started), daemon=True)
     peer_proc.start(), peer_started.wait()
-    test_success = mp.Event()
 
-    def _tester():
-        # note: we run everything in a separate process to re-initialize all global states from scratch
-        # this helps us avoid undesirable side-effects when running multiple tests in sequence
-
-        loop = asyncio.get_event_loop()
-        protocol = loop.run_until_complete(DHTProtocol.create(
-            DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, listen=False))
-
-        key, value, expiration = DHTID.generate(), [random.random(), {'ololo': 'pyshpysh'}], get_dht_time() + 1e3
-
-        empty_item, nodes_found = loop.run_until_complete(
-            protocol.call_find(f'{LOCALHOST}:{peer_port}', [key]))[key]
-        assert empty_item is None and len(nodes_found) == 0
-        assert all(loop.run_until_complete(protocol.call_store(
-            f'{LOCALHOST}:{peer_port}', [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
-        )), "peer rejected store"
-
-        (recv_value_bytes, recv_expiration), nodes_found = loop.run_until_complete(
-            protocol.call_find(f'{LOCALHOST}:{peer_port}', [key]))[key]
-        recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
-        assert len(nodes_found) == 0
-        assert recv_value == value and recv_expiration == expiration, "call_find_value expected " \
-            f"{value} (expires by {expiration}) but got {recv_value} (expires by {recv_expiration})"
-
-        assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{peer_port}')) == peer_id
-        assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{hivemind.find_open_port()}')) is None
-        test_success.set()
-
-    tester = mp.Process(target=_tester, daemon=True)
-    tester.start()
-    tester.join()
-    assert test_success.is_set()
+    loop = asyncio.get_event_loop()
+    protocol = loop.run_until_complete(DHTProtocol.create(
+        DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, listen=False))
+
+    key, value, expiration = DHTID.generate(), [random.random(), {'ololo': 'pyshpysh'}], get_dht_time() + 1e3
+
+    empty_item, nodes_found = loop.run_until_complete(
+        protocol.call_find(f'{LOCALHOST}:{peer_port}', [key]))[key]
+    assert empty_item is None and len(nodes_found) == 0
+    assert all(loop.run_until_complete(protocol.call_store(
+        f'{LOCALHOST}:{peer_port}', [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
+    )), "peer rejected store"
+
+    (recv_value_bytes, recv_expiration), nodes_found = loop.run_until_complete(
+        protocol.call_find(f'{LOCALHOST}:{peer_port}', [key]))[key]
+    recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
+    assert len(nodes_found) == 0
+    assert recv_value == value and recv_expiration == expiration, "call_find_value expected " \
+        f"{value} (expires by {expiration}) but got {recv_value} (expires by {recv_expiration})"
+
+    assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{peer_port}')) == peer_id
+    assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{hivemind.find_open_port()}')) is None
     peer_proc.terminate()
 
 
@@ -170,6 +155,7 @@ def run_node(node_id, peers, status_pipe: mp.Pipe):
         loop.run_forever()
 
 
+@pytest.mark.forked
 def test_dht_node():
     # create dht with 50 nodes + your 51-st node
     dht: Dict[Endpoint, DHTID] = {}
@@ -185,254 +171,223 @@ def test_dht_node():
         processes.append(proc)
         dht[f"{LOCALHOST}:{port}"] = node_id
 
-    test_success = mp.Event()
-
-    def _tester():
-        # note: we run everything in a separate process to re-initialize all global states from scratch
-        # this helps us avoid undesirable side-effects when running multiple tests in sequence
-        loop = asyncio.get_event_loop()
-        me = loop.run_until_complete(DHTNode.create(initial_peers=random.sample(dht.keys(), 5), parallel_rpc=10,
-                                                    cache_refresh_before_expiry=False))
-
-        # test 1: find self
-        nearest = loop.run_until_complete(me.find_nearest_nodes([me.node_id], k_nearest=1))[me.node_id]
-        assert len(nearest) == 1 and ':'.join(nearest[me.node_id].split(':')[-2:]) == f"{LOCALHOST}:{me.port}"
-
-        # test 2: find others
-        for i in range(10):
-            ref_endpoint, query_id = random.choice(list(dht.items()))
-            nearest = loop.run_until_complete(me.find_nearest_nodes([query_id], k_nearest=1))[query_id]
-            assert len(nearest) == 1
-            found_node_id, found_endpoint = next(iter(nearest.items()))
-            assert found_node_id == query_id and ':'.join(found_endpoint.split(':')[-2:]) == ref_endpoint
-
-        # test 3: find neighbors to random nodes
-        accuracy_numerator = accuracy_denominator = 0  # top-1 nearest neighbor accuracy
-        jaccard_numerator = jaccard_denominator = 0  # jaccard similarity aka intersection over union
-        all_node_ids = list(dht.values())
-
-        for i in range(100):
-            query_id = DHTID.generate()
-            k_nearest = random.randint(1, 20)
-            exclude_self = random.random() > 0.5
-            nearest = loop.run_until_complete(
-                me.find_nearest_nodes([query_id], k_nearest=k_nearest, exclude_self=exclude_self))[query_id]
-            nearest_nodes = list(nearest)  # keys from ordered dict
-
-            assert len(nearest_nodes) == k_nearest, "beam search must return exactly k_nearest results"
-            assert me.node_id not in nearest_nodes or not exclude_self, "if exclude, results shouldn't contain self"
-            assert np.all(np.diff(query_id.xor_distance(nearest_nodes)) >= 0), "results must be sorted by distance"
-
-            ref_nearest = heapq.nsmallest(k_nearest + 1, all_node_ids, key=query_id.xor_distance)
-            if exclude_self and me.node_id in ref_nearest:
-                ref_nearest.remove(me.node_id)
-            if len(ref_nearest) > k_nearest:
-                ref_nearest.pop()
-
-            accuracy_numerator += nearest_nodes[0] == ref_nearest[0]
-            accuracy_denominator += 1
-
-            jaccard_numerator += len(set.intersection(set(nearest_nodes), set(ref_nearest)))
-            jaccard_denominator += k_nearest
-
-        accuracy = accuracy_numerator / accuracy_denominator
-        print("Top-1 accuracy:", accuracy)  # should be 98-100%
-        jaccard_index = jaccard_numerator / jaccard_denominator
-        print("Jaccard index (intersection over union):", jaccard_index)  # should be 95-100%
-        assert accuracy >= 0.9, f"Top-1 accuracy only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
-        assert jaccard_index >= 0.9, f"Jaccard index only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
-
-        # test 4: find all nodes
-        dummy = DHTID.generate()
-        nearest = loop.run_until_complete(me.find_nearest_nodes([dummy], k_nearest=len(dht) + 100))[dummy]
-        assert len(nearest) == len(dht) + 1
-        assert len(set.difference(set(nearest.keys()), set(all_node_ids) | {me.node_id})) == 0
-
-        # test 5: node without peers
-        detached_node = loop.run_until_complete(DHTNode.create())
-        nearest = loop.run_until_complete(detached_node.find_nearest_nodes([dummy]))[dummy]
-        assert len(nearest) == 1 and nearest[detached_node.node_id] == f"{LOCALHOST}:{detached_node.port}"
-        nearest = loop.run_until_complete(detached_node.find_nearest_nodes([dummy], exclude_self=True))[dummy]
-        assert len(nearest) == 0
-
-        # test 6 store and get value
-        true_time = get_dht_time() + 1200
-        assert loop.run_until_complete(me.store("mykey", ["Value", 10], true_time))
-        that_guy = loop.run_until_complete(DHTNode.create(initial_peers=random.sample(dht.keys(), 3), parallel_rpc=10,
-                                                          cache_refresh_before_expiry=False, cache_locally=False))
-
-        for node in [me, that_guy]:
-            val, expiration_time = loop.run_until_complete(node.get("mykey"))
-            assert val == ["Value", 10], "Wrong value"
-            assert expiration_time == true_time, f"Wrong time"
-
-        assert loop.run_until_complete(detached_node.get("mykey")) is None
-
-        # test 7: bulk store and bulk get
-        keys = 'foo', 'bar', 'baz', 'zzz'
-        values = 3, 2, 'batman', [1, 2, 3]
-        store_ok = loop.run_until_complete(me.store_many(keys, values, expiration_time=get_dht_time() + 999))
-        assert all(store_ok.values()), "failed to store one or more keys"
-        response = loop.run_until_complete(me.get_many(keys[::-1]))
-        for key, value in zip(keys, values):
-            assert key in response and response[key][0] == value
-
-        # test 8: store dictionaries as values (with sub-keys)
-        upper_key, subkey1, subkey2, subkey3 = 'ololo', 'k1', 'k2', 'k3'
-        now = get_dht_time()
-        assert loop.run_until_complete(me.store(upper_key, subkey=subkey1, value=123, expiration_time=now + 10))
-        assert loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=456, expiration_time=now + 20))
-        for node in [that_guy, me]:
-            value, time = loop.run_until_complete(node.get(upper_key))
-            assert isinstance(value, dict) and time == now + 20
-            assert value[subkey1] == (123, now + 10)
-            assert value[subkey2] == (456, now + 20)
-            assert len(value) == 2
-
-        assert not loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=345, expiration_time=now + 10))
-        assert loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=567, expiration_time=now + 30))
-        assert loop.run_until_complete(me.store(upper_key, subkey=subkey3, value=890, expiration_time=now + 50))
-        loop.run_until_complete(asyncio.sleep(0.1))  # wait for cache to refresh
-
-        for node in [that_guy, me]:
-            value, time = loop.run_until_complete(node.get(upper_key))
-            assert isinstance(value, dict) and time == now + 50, (value, time)
-            assert value[subkey1] == (123, now + 10)
-            assert value[subkey2] == (567, now + 30)
-            assert value[subkey3] == (890, now + 50)
-            assert len(value) == 3
-
-        test_success.set()
-
-    tester = mp.Process(target=_tester, daemon=True)
-    tester.start()
-    tester.join()
-    assert test_success.is_set()
+    loop = asyncio.get_event_loop()
+    me = loop.run_until_complete(DHTNode.create(initial_peers=random.sample(dht.keys(), 5), parallel_rpc=10,
+                                                cache_refresh_before_expiry=False))
+
+    # test 1: find self
+    nearest = loop.run_until_complete(me.find_nearest_nodes([me.node_id], k_nearest=1))[me.node_id]
+    assert len(nearest) == 1 and ':'.join(nearest[me.node_id].split(':')[-2:]) == f"{LOCALHOST}:{me.port}"
+
+    # test 2: find others
+    for i in range(10):
+        ref_endpoint, query_id = random.choice(list(dht.items()))
+        nearest = loop.run_until_complete(me.find_nearest_nodes([query_id], k_nearest=1))[query_id]
+        assert len(nearest) == 1
+        found_node_id, found_endpoint = next(iter(nearest.items()))
+        assert found_node_id == query_id and ':'.join(found_endpoint.split(':')[-2:]) == ref_endpoint
+
+    # test 3: find neighbors to random nodes
+    accuracy_numerator = accuracy_denominator = 0  # top-1 nearest neighbor accuracy
+    jaccard_numerator = jaccard_denominator = 0  # jaccard similarity aka intersection over union
+    all_node_ids = list(dht.values())
+
+    for i in range(100):
+        query_id = DHTID.generate()
+        k_nearest = random.randint(1, 20)
+        exclude_self = random.random() > 0.5
+        nearest = loop.run_until_complete(
+            me.find_nearest_nodes([query_id], k_nearest=k_nearest, exclude_self=exclude_self))[query_id]
+        nearest_nodes = list(nearest)  # keys from ordered dict
+
+        assert len(nearest_nodes) == k_nearest, "beam search must return exactly k_nearest results"
+        assert me.node_id not in nearest_nodes or not exclude_self, "if exclude, results shouldn't contain self"
+        assert np.all(np.diff(query_id.xor_distance(nearest_nodes)) >= 0), "results must be sorted by distance"
+
+        ref_nearest = heapq.nsmallest(k_nearest + 1, all_node_ids, key=query_id.xor_distance)
+        if exclude_self and me.node_id in ref_nearest:
+            ref_nearest.remove(me.node_id)
+        if len(ref_nearest) > k_nearest:
+            ref_nearest.pop()
+
+        accuracy_numerator += nearest_nodes[0] == ref_nearest[0]
+        accuracy_denominator += 1
+
+        jaccard_numerator += len(set.intersection(set(nearest_nodes), set(ref_nearest)))
+        jaccard_denominator += k_nearest
+
+    accuracy = accuracy_numerator / accuracy_denominator
+    print("Top-1 accuracy:", accuracy)  # should be 98-100%
+    jaccard_index = jaccard_numerator / jaccard_denominator
+    print("Jaccard index (intersection over union):", jaccard_index)  # should be 95-100%
+    assert accuracy >= 0.9, f"Top-1 accuracy only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
+    assert jaccard_index >= 0.9, f"Jaccard index only {jaccard_index} ({jaccard_numerator} / {jaccard_denominator})"
+
+    # test 4: find all nodes
+    dummy = DHTID.generate()
+    nearest = loop.run_until_complete(me.find_nearest_nodes([dummy], k_nearest=len(dht) + 100))[dummy]
+    assert len(nearest) == len(dht) + 1
+    assert len(set.difference(set(nearest.keys()), set(all_node_ids) | {me.node_id})) == 0
+
+    # test 5: node without peers
+    detached_node = loop.run_until_complete(DHTNode.create())
+    nearest = loop.run_until_complete(detached_node.find_nearest_nodes([dummy]))[dummy]
+    assert len(nearest) == 1 and nearest[detached_node.node_id] == f"{LOCALHOST}:{detached_node.port}"
+    nearest = loop.run_until_complete(detached_node.find_nearest_nodes([dummy], exclude_self=True))[dummy]
+    assert len(nearest) == 0
+
+    # test 6: store and get value
+    true_time = get_dht_time() + 1200
+    assert loop.run_until_complete(me.store("mykey", ["Value", 10], true_time))
+    that_guy = loop.run_until_complete(DHTNode.create(initial_peers=random.sample(dht.keys(), 3), parallel_rpc=10,
+                                                      cache_refresh_before_expiry=False, cache_locally=False))
+
+    for node in [me, that_guy]:
+        val, expiration_time = loop.run_until_complete(node.get("mykey"))
+        assert val == ["Value", 10], "Wrong value"
+        assert expiration_time == true_time, f"Wrong time"
+
+    assert loop.run_until_complete(detached_node.get("mykey")) is None
+
+    # test 7: bulk store and bulk get
+    keys = 'foo', 'bar', 'baz', 'zzz'
+    values = 3, 2, 'batman', [1, 2, 3]
+    store_ok = loop.run_until_complete(me.store_many(keys, values, expiration_time=get_dht_time() + 999))
+    assert all(store_ok.values()), "failed to store one or more keys"
+    response = loop.run_until_complete(me.get_many(keys[::-1]))
+    for key, value in zip(keys, values):
+        assert key in response and response[key][0] == value
+
+    # test 8: store dictionaries as values (with sub-keys)
+    upper_key, subkey1, subkey2, subkey3 = 'ololo', 'k1', 'k2', 'k3'
+    now = get_dht_time()
+    assert loop.run_until_complete(me.store(upper_key, subkey=subkey1, value=123, expiration_time=now + 10))
+    assert loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=456, expiration_time=now + 20))
+    for node in [that_guy, me]:
+        value, time = loop.run_until_complete(node.get(upper_key))
+        assert isinstance(value, dict) and time == now + 20
+        assert value[subkey1] == (123, now + 10)
+        assert value[subkey2] == (456, now + 20)
+        assert len(value) == 2
+
+    assert not loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=345, expiration_time=now + 10))
+    assert loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=567, expiration_time=now + 30))
+    assert loop.run_until_complete(me.store(upper_key, subkey=subkey3, value=890, expiration_time=now + 50))
+    loop.run_until_complete(asyncio.sleep(0.1))  # wait for cache to refresh
+
+    for node in [that_guy, me]:
+        value, time = loop.run_until_complete(node.get(upper_key))
+        assert isinstance(value, dict) and time == now + 50, (value, time)
+        assert value[subkey1] == (123, now + 10)
+        assert value[subkey2] == (567, now + 30)
+        assert value[subkey3] == (890, now + 50)
+        assert len(value) == 3
+
     for proc in processes:
         proc.terminate()
 
 
-def test_dhtnode_replicas():
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_dhtnode_replicas():
     dht_size = 20
     initial_peers = 3
     num_replicas = random.randint(1, 20)
-    test_success = mp.Event()
-
-    async def _tester():
-        peers = []
-        for i in range(dht_size):
-            neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(initial_peers, len(peers)))]
-            peers.append(await DHTNode.create(initial_peers=neighbors_i, num_replicas=num_replicas))
-
-        you = random.choice(peers)
-        assert await you.store('key1', 'foo', get_dht_time() + 999)
-
-        actual_key1_replicas = sum(len(peer.protocol.storage) for peer in peers)
-        assert num_replicas == actual_key1_replicas
-
-        assert await you.store('key2', 'bar', get_dht_time() + 999)
-        total_size = sum(len(peer.protocol.storage) for peer in peers)
-        actual_key2_replicas = total_size - actual_key1_replicas
-        assert num_replicas == actual_key2_replicas
-
-        assert await you.store('key2', 'baz', get_dht_time() + 1000)
-        assert sum(len(peer.protocol.storage) for peer in peers) == total_size, "total size should not have changed"
-        test_success.set()
-
-    proc = mp.Process(target=lambda: asyncio.run(_tester()))
-    proc.start()
-    proc.join()
-    assert test_success.is_set()
-
-
-def test_dhtnode_caching(T=0.05):
-    test_success = mp.Event()
-
-    async def _tester():
-        node2 = await hivemind.DHTNode.create(cache_refresh_before_expiry=5 * T, reuse_get_requests=False)
-        node1 = await hivemind.DHTNode.create(initial_peers=[f'localhost:{node2.port}'],
-                                              cache_refresh_before_expiry=5 * T, listen=False, reuse_get_requests=False)
-        await node2.store('k', [123, 'value'], expiration_time=hivemind.get_dht_time() + 7 * T)
-        await node2.store('k2', [654, 'value'], expiration_time=hivemind.get_dht_time() + 7 * T)
-        await node2.store('k3', [654, 'value'], expiration_time=hivemind.get_dht_time() + 15 * T)
-        await node1.get_many(['k', 'k2', 'k3', 'k4'])
-        assert len(node1.protocol.cache) == 3
-        assert len(node1.cache_refresh_queue) == 0
-
-        await node1.get_many(['k', 'k2', 'k3', 'k4'])
-        assert len(node1.cache_refresh_queue) == 3
-
-        await node2.store('k', [123, 'value'], expiration_time=hivemind.get_dht_time() + 12 * T)
-        await asyncio.sleep(4 * T)
-        await node1.get('k')
-        await asyncio.sleep(1 * T)
-
-        assert len(node1.protocol.cache) == 3
-        assert len(node1.cache_refresh_queue) == 2
-        await asyncio.sleep(3 * T)
-
-        assert len(node1.cache_refresh_queue) == 1
-
-        await asyncio.sleep(5 * T)
-        assert len(node1.cache_refresh_queue) == 0
-        await asyncio.sleep(5 * T)
-        assert len(node1.cache_refresh_queue) == 0
-
-        await node2.store('k', [123, 'value'], expiration_time=hivemind.get_dht_time() + 10 * T)
-        await node1.get('k')
-        await asyncio.sleep(1 * T)
-        assert len(node1.cache_refresh_queue) == 0
-        await node1.get('k')
-        await asyncio.sleep(1 * T)
-        assert len(node1.cache_refresh_queue) == 1
-
-        await asyncio.sleep(5 * T)
-        assert len(node1.cache_refresh_queue) == 0
-
-        await asyncio.gather(node1.shutdown(), node2.shutdown())
-        test_success.set()
-
-    proc = mp.Process(target=lambda: asyncio.run(_tester()))
-    proc.start()
-    proc.join()
-    assert test_success.is_set()
-
-
-def test_dhtnode_reuse_get():
-    test_success = mp.Event()
-
-    async def _tester():
-        peers = []
-        for i in range(10):
-            neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
-            peers.append(await hivemind.DHTNode.create(initial_peers=neighbors_i, parallel_rpc=256))
-
-        await asyncio.gather(
-            random.choice(peers).store('k1', 123, hivemind.get_dht_time() + 999),
-            random.choice(peers).store('k2', 567, hivemind.get_dht_time() + 999)
-        )
-
-        you = random.choice(peers)
-
-        futures1 = await you.get_many(['k1', 'k2'], return_futures=True)
-        assert len(you.pending_get_requests[DHTID.generate('k1')]) == 1
-        assert len(you.pending_get_requests[DHTID.generate('k2')]) == 1
-
-        futures2 = await you.get_many(['k2', 'k3'], return_futures=True)
-        assert len(you.pending_get_requests[DHTID.generate('k2')]) == 2
-
-        await asyncio.gather(*futures1.values(), *futures2.values())
-        futures3 = await you.get_many(['k3'], return_futures=True)
-        assert len(you.pending_get_requests[DHTID.generate('k1')]) == 0
-        assert len(you.pending_get_requests[DHTID.generate('k2')]) == 0
-        assert len(you.pending_get_requests[DHTID.generate('k3')]) == 1
-
-        assert (await futures1['k1'])[0] == 123
-        assert await futures1['k2'] == await futures2['k2'] and (await futures1['k2'])[0] == 567
-        assert await futures2['k3'] == await futures3['k3'] and (await futures3['k3']) is None
-        test_success.set()
 
-    proc = mp.Process(target=lambda: asyncio.run(_tester()))
-    proc.start()
-    proc.join()
-    assert test_success.is_set()
+    peers = []
+    for i in range(dht_size):
+        neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(initial_peers, len(peers)))]
+        peers.append(await DHTNode.create(initial_peers=neighbors_i, num_replicas=num_replicas))
+
+    you = random.choice(peers)
+    assert await you.store('key1', 'foo', get_dht_time() + 999)
+
+    actual_key1_replicas = sum(len(peer.protocol.storage) for peer in peers)
+    assert num_replicas == actual_key1_replicas
+
+    assert await you.store('key2', 'bar', get_dht_time() + 999)
+    total_size = sum(len(peer.protocol.storage) for peer in peers)
+    actual_key2_replicas = total_size - actual_key1_replicas
+    assert num_replicas == actual_key2_replicas
+
+    assert await you.store('key2', 'baz', get_dht_time() + 1000)
+    assert sum(len(peer.protocol.storage) for peer in peers) == total_size, "total size should not have changed"
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_dhtnode_caching(T=0.05):
+    node2 = await hivemind.DHTNode.create(cache_refresh_before_expiry=5 * T, reuse_get_requests=False)
+    node1 = await hivemind.DHTNode.create(initial_peers=[f'localhost:{node2.port}'],
+                                          cache_refresh_before_expiry=5 * T, listen=False, reuse_get_requests=False)
+    await node2.store('k', [123, 'value'], expiration_time=hivemind.get_dht_time() + 7 * T)
+    await node2.store('k2', [654, 'value'], expiration_time=hivemind.get_dht_time() + 7 * T)
+    await node2.store('k3', [654, 'value'], expiration_time=hivemind.get_dht_time() + 15 * T)
+    await node1.get_many(['k', 'k2', 'k3', 'k4'])
+    assert len(node1.protocol.cache) == 3
+    assert len(node1.cache_refresh_queue) == 0
+
+    await node1.get_many(['k', 'k2', 'k3', 'k4'])
+    assert len(node1.cache_refresh_queue) == 3
+
+    await node2.store('k', [123, 'value'], expiration_time=hivemind.get_dht_time() + 12 * T)
+    await asyncio.sleep(4 * T)
+    await node1.get('k')
+    await asyncio.sleep(1 * T)
+
+    assert len(node1.protocol.cache) == 3
+    assert len(node1.cache_refresh_queue) == 2
+    await asyncio.sleep(3 * T)
+
+    assert len(node1.cache_refresh_queue) == 1
+
+    await asyncio.sleep(5 * T)
+    assert len(node1.cache_refresh_queue) == 0
+    await asyncio.sleep(5 * T)
+    assert len(node1.cache_refresh_queue) == 0
+
+    await node2.store('k', [123, 'value'], expiration_time=hivemind.get_dht_time() + 10 * T)
+    await node1.get('k')
+    await asyncio.sleep(1 * T)
+    assert len(node1.cache_refresh_queue) == 0
+    await node1.get('k')
+    await asyncio.sleep(1 * T)
+    assert len(node1.cache_refresh_queue) == 1
+
+    await asyncio.sleep(5 * T)
+    assert len(node1.cache_refresh_queue) == 0
+
+    await asyncio.gather(node1.shutdown(), node2.shutdown())
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_dhtnode_reuse_get():
+    peers = []
+    for i in range(10):
+        neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
+        peers.append(await hivemind.DHTNode.create(initial_peers=neighbors_i, parallel_rpc=256))
+
+    await asyncio.gather(
+        random.choice(peers).store('k1', 123, hivemind.get_dht_time() + 999),
+        random.choice(peers).store('k2', 567, hivemind.get_dht_time() + 999)
+    )
+
+    you = random.choice(peers)
+
+    futures1 = await you.get_many(['k1', 'k2'], return_futures=True)
+    assert len(you.pending_get_requests[DHTID.generate('k1')]) == 1
+    assert len(you.pending_get_requests[DHTID.generate('k2')]) == 1
+
+    futures2 = await you.get_many(['k2', 'k3'], return_futures=True)
+    assert len(you.pending_get_requests[DHTID.generate('k2')]) == 2
+
+    await asyncio.gather(*futures1.values(), *futures2.values())
+    futures3 = await you.get_many(['k3'], return_futures=True)
+    assert len(you.pending_get_requests[DHTID.generate('k1')]) == 0
+    assert len(you.pending_get_requests[DHTID.generate('k2')]) == 0
+    assert len(you.pending_get_requests[DHTID.generate('k3')]) == 1
+
+    assert (await futures1['k1'])[0] == 123
+    assert await futures1['k2'] == await futures2['k2'] and (await futures1['k2'])[0] == 567
+    assert await futures2['k3'] == await futures3['k3'] and (await futures3['k3']) is None

+ 6 - 0
tests/test_moe.py

@@ -7,6 +7,7 @@ from hivemind.client.expert import DUMMY
 from hivemind import background_server
 
 
+@pytest.mark.forked
 def test_moe():
     all_expert_uids = [f'ffn.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}'
                        for _ in range(20)]
@@ -22,6 +23,7 @@ def test_moe():
             out.sum().backward()
 
 
+@pytest.mark.forked
 def test_call_many():
     k_min = 1
     timeout_after_k_min = None
@@ -71,6 +73,7 @@ def test_call_many():
         assert torch.allclose(our_grad, reference_grad, rtol, atol)
 
 
+@pytest.mark.forked
 def test_remote_module_call():
     with background_server(num_experts=1, device='cpu', expert_cls='ffn', num_handlers=1, hidden_dim=1024,
                            optim_cls=None, no_dht=True) as (server_endpoint, dht_endpoint):
@@ -93,6 +96,7 @@ def test_remote_module_call():
             fake_expert(dummy_x)
 
 
+@pytest.mark.forked
 def test_beam_search_correctness():
     all_expert_uids = [f'ffn.{5 + i}.{10 + j}.{15 + k}' for i in range(10) for j in range(10) for k in range(10)]
     dht = hivemind.DHT(start=True, expiration=999)
@@ -119,6 +123,7 @@ def test_beam_search_correctness():
         assert np.allclose(true_best_scores, our_best_scores)
 
 
+@pytest.mark.forked
 def test_determinism():
     rtol = 0
     atol = 1e-5
@@ -140,6 +145,7 @@ def test_determinism():
     assert torch.allclose(grad, grad_rerun, rtol, atol), "Gradients are non-deterministic."
 
 
+@pytest.mark.forked
 def test_compute_expert_scores():
     try:
         dht = hivemind.DHT(start=True)

+ 0 - 4
tests/test_routing.py

@@ -66,10 +66,6 @@ def test_routing_table_basic():
         assert bucket_index == found_bucket_index
 
 
-
-
-
-
 def test_routing_table_parameters():
     for (bucket_size, modulo, min_nbuckets, max_nbuckets) in [
         (20,          5,      45,           65),

+ 2 - 0
tests/test_training.py

@@ -1,6 +1,7 @@
 from functools import partial
 from typing import Optional
 
+import pytest
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -9,6 +10,7 @@ from sklearn.datasets import load_digits
 from hivemind import RemoteExpert, background_server
 
 
+@pytest.mark.forked
 def test_training(port: Optional[int] = None, max_steps: int = 100, threshold: float = 0.9):
     dataset = load_digits()
     X_train, y_train = torch.tensor(dataset['data'], dtype=torch.float), torch.tensor(dataset['target'])

+ 38 - 40
tests/test_util_modules.py

@@ -78,46 +78,44 @@ def test_mpfuture_status():
         assert future.set_running_or_notify_cancel() is False
 
 
-def test_await_mpfuture():
-    async def _run():
-        # await result
-        f1, f2 = hivemind.MPFuture.make_pair()
-
-        async def wait_and_assign():
-            assert f2.set_running_or_notify_cancel() is True
-            await asyncio.sleep(0.1)
-            f2.set_result((123, 'ololo'))
-
-        asyncio.create_task(wait_and_assign())
-        for future in [f1, f2]:
-            res = await future
-            assert res == (123, 'ololo')
-
-        # await cancel
-        f1, f2 = hivemind.MPFuture.make_pair()
-
-        async def wait_and_cancel():
-            await asyncio.sleep(0.1)
-            f1.cancel()
-
-        asyncio.create_task(wait_and_cancel())
-        for future in [f1, f2]:
-            with pytest.raises(CancelledError):
-                await future
-
-        # await exception
-        f1, f2 = hivemind.MPFuture.make_pair()
-
-        async def wait_and_raise():
-            await asyncio.sleep(0.1)
-            f1.set_exception(SystemError())
-
-        asyncio.create_task(wait_and_raise())
-        for future in [f1, f2]:
-            with pytest.raises(SystemError):
-                await future
-
-    asyncio.new_event_loop().run_until_complete(_run())
+@pytest.mark.asyncio
+async def test_await_mpfuture():
+    # await result
+    f1, f2 = hivemind.MPFuture.make_pair()
+
+    async def wait_and_assign():
+        assert f2.set_running_or_notify_cancel() is True
+        await asyncio.sleep(0.1)
+        f2.set_result((123, 'ololo'))
+
+    asyncio.create_task(wait_and_assign())
+    for future in [f1, f2]:
+        res = await future
+        assert res == (123, 'ololo')
+
+    # await cancel
+    f1, f2 = hivemind.MPFuture.make_pair()
+
+    async def wait_and_cancel():
+        await asyncio.sleep(0.1)
+        f1.cancel()
+
+    asyncio.create_task(wait_and_cancel())
+    for future in [f1, f2]:
+        with pytest.raises(CancelledError):
+            await future
+
+    # await exception
+    f1, f2 = hivemind.MPFuture.make_pair()
+
+    async def wait_and_raise():
+        await asyncio.sleep(0.1)
+        f1.set_exception(SystemError())
+
+    asyncio.create_task(wait_and_raise())
+    for future in [f1, f2]:
+        with pytest.raises(SystemError):
+            await future
 
 
 def test_vector_compression(size=(128, 128, 64), alpha=5e-08):