
Add graceful shutdown to DHT and Averager (#301)

- DHT.shutdown and DecentralizedAverager.shutdown no longer terminate the daemon process the hard way
- made internal daemon logic of DHT and DecentralizedAverager more similar
- added a basic test that ensures shutdown actually shuts processes down
- added several minor clarifications

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic committed 4 years ago
commit e9956b84f6
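
For context, the user-facing effect of this commit: DHT.shutdown() and DecentralizedAverager.shutdown() keep their signatures but now ask the daemon process to exit cleanly before falling back to terminate(). A minimal usage sketch, assuming the top-level hivemind exports at this revision (argument names are taken from the diff below; the exact set of required constructor arguments may differ):

import torch
import hivemind

# start a DHT node and an averager backed by it
dht = hivemind.DHT(start=True, shutdown_timeout=3)
averager = hivemind.DecentralizedAverager(
    averaged_tensors=[torch.zeros(8)], dht=dht, prefix='demo',
    target_group_size=2, start=True, shutdown_timeout=5)

# graceful shutdown: each call sends '_shutdown' to the daemon, joins it for
# up to shutdown_timeout seconds, and only then resorts to terminate()
averager.shutdown()
dht.shutdown()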

+ 5 - 1
CONTRIBUTING.md

@@ -102,6 +102,10 @@ a new test to make sure it's not reintroduced by future changes.
 To run tests, you need to install hivemind in development mode with additional dependencies: `pip install -e .[dev]`.
 You can run all tests with `pytest tests/` or choose a specific subset, e.g., `pytest tests/test_dht.py`.
 
+When investigating test behavior, please note that pytest automatically wraps all hivemind tests with fixtures defined
+in a global configuration file [`tests/conftest.py`](./tests/conftest.py), some of which will run automatically. 
+For more information, refer to the [pytest documentation on fixtures](https://docs.pytest.org/en/6.2.x/fixture.html).
+
 ## Building documentation
 
 Any function exposed to a user must have a docstring compatible
@@ -140,4 +144,4 @@ This guide was inspired by several influential Python open source projects liste
 
 * [PyTorch](https://github.com/pytorch/pytorch/blob/master/CONTRIBUTING.md)
 * [Scikit-learn](https://scikit-learn.org/dev/developers/contributing.html)
-* [transformers](https://github.com/huggingface/transformers/blob/master/CONTRIBUTING.md)
+* [transformers](https://github.com/huggingface/transformers/blob/master/CONTRIBUTING.md)
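
The note added to CONTRIBUTING.md above refers to fixtures such as the cleanup_children fixture introduced by this commit (see tests/conftest.py further below). As a purely illustrative example, not part of this change, an autouse fixture that wraps every test without being requested explicitly could look like this:

import time

import pytest


@pytest.fixture(autouse=True)
def report_test_duration():
    # runs around every test automatically because of autouse=True
    start = time.monotonic()
    yield
    print(f"test finished in {time.monotonic() - start:.3f}s")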

+ 41 - 23
hivemind/client/averaging/__init__.py

@@ -75,6 +75,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
           local tensors for averaging
     :param allow_state_sharing: if set to True, other peers can download this peer's state. Can be overwritten
       with averager.allow_state_sharing = True / False
+    :param shutdown_timeout: when calling .shutdown, wait for up to this many seconds before terminating
 
     Example:
 
@@ -90,6 +91,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
     """
     _matchmaking: Matchmaking
     _pending_group_assembled: asyncio.Event
+    _server: grpc.aio.Server
     serializer = MSGPackSerializer
 
     def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: DHT, *, start: bool,
@@ -100,7 +102,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                  throughput: Optional[float] = None, min_vector_size: int = 0,
                  auxiliary: bool = False, allow_state_sharing: Optional[bool] = None,
                  listen: bool = True, listen_on: Endpoint = '0.0.0.0:*', daemon: bool = True,
-                 channel_options: Optional[Sequence[Tuple[str, Any]]] = None, **kwargs):
+                 channel_options: Optional[Sequence[Tuple[str, Any]]] = None,
+                 shutdown_timeout: float = 5, **kwargs):
         assert '.' not in prefix, "group prefix must be a string without trailing '.'"
         assert throughput is None or (throughput >= 0 and np.isfinite(np.float32(throughput))), \
             "throughput must be a non-negative float32"
@@ -130,7 +133,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             tensor.share_memory_()
         self.total_size = sum(map(torch.Tensor.numel, self._averaged_tensors))
         self.schema_hash = compute_schema_hash(self._averaged_tensors)
-        self._throughput = throughput
+        self.shutdown_timeout = shutdown_timeout
+        self.throughput = throughput
 
         self.matchmaking_kwargs = dict(
             prefix=prefix, initial_group_bits=initial_group_bits, target_group_size=target_group_size,
@@ -140,7 +144,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         self._averaging_alpha, self._allreduce_timeout = averaging_alpha, allreduce_timeout
         self._running_groups: Dict[GroupID, AllReduceRunner] = {}  # one or more assembled groups that run all-reduce
 
-        self._pipe, self.pipe = mp.Pipe(duplex=True)  # a control pipe used to communicate with a background process
+        self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)  # a control pipe used to communicate with daemon
         self._port = mp.Value(ctypes.c_uint32, 0)  # assigned when averager starts, accessible via self.port
 
         self._allow_state_sharing = mp.Value(ctypes.c_bool, 0)
@@ -154,7 +158,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         # note: we create a background thread weakref and with daemon=True to ensure garbage collection
         background_fetcher = threading.Thread(
             daemon=True, target=_background_thread_fetch_current_state,
-            args=[self.serializer, self.pipe, weakref.WeakMethod(self.get_current_state)])
+            args=[self.serializer, self._outer_pipe, weakref.WeakMethod(self.get_current_state)])
         background_fetcher.start()
         if start:
             self.run_in_background(await_ready=True)
@@ -205,12 +209,12 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                 grpc.aio.init_grpc_aio()
 
                 if self.listen:
-                    server = grpc.aio.server(**self.kwargs, options=GRPC_KEEPALIVE_OPTIONS)
-                    averaging_pb2_grpc.add_DecentralizedAveragingServicer_to_server(self, server)
-                    found_port = server.add_insecure_port(self.listen_on)
+                    self._server = grpc.aio.server(**self.kwargs, options=GRPC_KEEPALIVE_OPTIONS)
+                    averaging_pb2_grpc.add_DecentralizedAveragingServicer_to_server(self, self._server)
+                    found_port = self._server.add_insecure_port(self.listen_on)
                     assert found_port != 0, f"Failed to listen to {self.listen_on}"
                     self._port.value = found_port
-                    await server.start()
+                    await self._server.start()
                 else:
                     logger.debug(f"The averager is running in client mode.")
 
@@ -224,8 +228,11 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                 self.ready.set()
 
                 while True:
-                    method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv)
-                    asyncio.create_task(getattr(self, method)(*args, **kwargs))
+                    method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
+                    task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
+                    if method == '_shutdown':
+                        await task
+                        break
 
             loop.run_until_complete(_run())
 
@@ -240,15 +247,26 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
     def shutdown(self) -> None:
         """ Shut down the averager process """
-        # TODO notify peers before terminating
-        if self._parent_pid != os.getpid() or self.is_alive():
-            self._pipe.send(('_SHUTDOWN', None))
-            self.terminate()
+        if self.is_alive():
+            self._outer_pipe.send(('_shutdown', [None], {}))  # shut down the daemon process
+            self._inner_pipe.send(('_SHUTDOWN', None))  # shut down background thread in master
+            self.join(self.shutdown_timeout)
+            if self.is_alive():
+                logger.warning("Averager did not shut down within the grace period; terminating it the hard way.")
+                self.terminate()
         else:
-            logger.warning("DHT shutdown has no effect: the process is not alive")
+            logger.exception("Averager shutdown has no effect: the process is already not alive")
+
+    async def _shutdown(self, timeout: Optional[float] = None) -> None:
+        remaining_tasks = set()
+        for group in self._running_groups.values():
+            remaining_tasks.update(group.finalize(cancel=True))
+        if self.listen:
+            remaining_tasks.add(self._server.stop(timeout))
+        await asyncio.gather(*remaining_tasks)
 
     def __del__(self):
-        if self._parent_pid != os.getpid() or self.is_alive():
+        if self._parent_pid == os.getpid() and self.is_alive():
             self.shutdown()
 
     def step(self, gather: Optional[GatheredData] = None, weight: Optional[float] = None,
@@ -274,8 +292,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
         future, _future = MPFuture.make_pair()
         gather_binary = self.serializer.dumps(gather)  # serialize here to avoid loading modules in the averager process
-        self.pipe.send(('_step', [], dict(future=_future, gather_binary=gather_binary, weight=weight,
-                                          allow_retries=allow_retries, timeout=timeout)))
+        self._outer_pipe.send(('_step', [], dict(future=_future, gather_binary=gather_binary, weight=weight,
+                                                 allow_retries=allow_retries, timeout=timeout)))
         return future.result() if wait else future
 
     async def _step(self, *, future: MPFuture, gather_binary: bytes, weight: float,
@@ -286,7 +304,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             while not future.done():
                 try:
                     self._pending_group_assembled.clear()
-                    data_for_gather = self.serializer.dumps([weight, self._throughput, self.mode.value, gather_binary]) 
+                    data_for_gather = self.serializer.dumps([weight, self.throughput, self.mode.value, gather_binary])
                     group_info = await self._matchmaking.look_for_group(timeout=timeout,
                                                                         data_for_gather=data_for_gather)
                     if group_info is None:
@@ -446,7 +464,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
     async def _get_current_state_from_host_process(self):
         """ Executed in the averager process inside rpc_download_state """
         future, _future = MPFuture.make_pair()
-        self._pipe.send(('_TRIGGER_GET_CURRENT_STATE', _future))
+        self._inner_pipe.send(('_TRIGGER_GET_CURRENT_STATE', _future))
         return await future
 
     def load_state_from_peers(self, wait=True) -> Optional[Tuple[Any, Sequence[torch.Tensor]]]:
@@ -460,7 +478,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         The exact contents of both metadata and tensors are determined by get_current_state method
         """
         future, _future = MPFuture.make_pair()
-        self.pipe.send(('_load_state_from_peers', [], dict(future=_future)))
+        self._outer_pipe.send(('_load_state_from_peers', [], dict(future=_future)))
         return future.result() if wait else future
 
     async def _load_state_from_peers(self, future: MPFuture):
@@ -520,7 +538,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         :returns: averager's current group key bits (without prefix)
         """
         future, _future = MPFuture.make_pair()
-        self.pipe.send(('_get_group_bits', [], dict(future=_future)))
+        self._outer_pipe.send(('_get_group_bits', [], dict(future=_future)))
         return future.result() if wait else future
 
     async def _get_group_bits(self, future: MPFuture):
@@ -533,7 +551,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         """
         future, _future = MPFuture.make_pair()
         assert all(bit in '01' for bit in group_bits)
-        self.pipe.send(('_set_group_bits', [], dict(group_bits=group_bits, future=_future)))
+        self._outer_pipe.send(('_set_group_bits', [], dict(group_bits=group_bits, future=_future)))
         return future.result() if wait else future
 
     async def _set_group_bits(self, group_bits: str, future: MPFuture):
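
The diff above can be hard to follow hunk by hunk, so here is a distilled, standalone sketch of the shutdown protocol this commit introduces for the DecentralizedAverager daemon (and, further below, for DHT). Class and method names are toy placeholders, not the hivemind API: the parent sends a '_shutdown' command over the control pipe, the daemon's event loop awaits that task and exits, and the parent only calls terminate() if the grace period runs out.

import asyncio
import multiprocessing as mp


class ToyDaemon(mp.Process):
    def __init__(self, shutdown_timeout: float = 3.0):
        super().__init__(daemon=True)
        self.shutdown_timeout = shutdown_timeout
        self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)

    def run(self):
        loop = asyncio.new_event_loop()

        async def _run():
            while True:
                # receive (method, args, kwargs) commands sent by the parent process
                method, args, kwargs = await loop.run_in_executor(None, self._inner_pipe.recv)
                task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
                if method == '_shutdown':
                    await task  # finish cleanup, then leave the event loop
                    break

        loop.run_until_complete(_run())

    async def _ping(self):
        print("pong")

    async def _shutdown(self):
        print("cleaning up: stop servers, notify peers, ...")

    def shutdown(self):
        if self.is_alive():
            self._outer_pipe.send(('_shutdown', [], {}))  # graceful path first
            self.join(self.shutdown_timeout)
            if self.is_alive():
                self.terminate()  # hard fallback, as before this commit


if __name__ == '__main__':
    proc = ToyDaemon()
    proc.start()
    proc._outer_pipe.send(('_ping', [], {}))
    proc.shutdown()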

+ 15 - 7
hivemind/client/averaging/allreduce.py

@@ -112,11 +112,6 @@ class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
             self.finalize(exception=e)
             for task in pending_tasks:
                 task.cancel()
-            code = averaging_pb2.CANCELLED if isinstance(e, asyncio.CancelledError) else averaging_pb2.INTERNAL_ERROR
-            logger.debug(f"{self} - notifying peers about {averaging_pb2.MessageCode.Name(code)}")
-            for peer_endpoint, mode in zip(self.ordered_group_endpoints, self.modes):
-                if peer_endpoint != self.endpoint and mode != AveragingMode.CLIENT:
-                    asyncio.create_task(self._send_error_to_peer(peer_endpoint, code))
             raise
 
     async def _communicate_with_peer(self, peer_endpoint: Endpoint):
@@ -210,7 +205,20 @@ class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
         await stream.done_writing()
 
     def finalize(self, *, cancel: bool = False, exception: Optional[BaseException] = None):
+        """ finish or terminate AllReduceRunner, propagate any errors / cancellations to peers. """
         assert not cancel or not exception, "finalize accepts either exception or cancel, but not both"
+        pending_tasks = set()
+        if cancel or exception:
+            # propagate error to peers
+            if cancel or isinstance(exception, asyncio.CancelledError):
+                code = averaging_pb2.CANCELLED
+            else:
+                code = averaging_pb2.INTERNAL_ERROR
+            logger.debug(f"{self} - notifying peers about {averaging_pb2.MessageCode.Name(code)}")
+            for peer_endpoint, mode in zip(self.ordered_group_endpoints, self.modes):
+                if peer_endpoint != self.endpoint and mode != AveragingMode.CLIENT:
+                    pending_tasks.add(asyncio.create_task(self._send_error_to_peer(peer_endpoint, code)))
+
         if not self._future.done():
             if cancel:
                 logger.debug(f"{self} - cancelled")
@@ -223,7 +231,7 @@ class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
                 self._future.set_result(None)
             self.tensor_part_container.finalize()
             self.tensor_part_reducer.finalize()
-            return True
+            return pending_tasks
         else:
             logger.debug(f"{self} - could not finish: allreduce is already finished: {self._future}")
-            return False
+            return pending_tasks

+ 45 - 29
hivemind/dht/__init__.py

@@ -16,12 +16,11 @@ from __future__ import annotations
 import asyncio
 import ctypes
 import multiprocessing as mp
+import os
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
-from typing import Iterable, List, Optional, Sequence, Union, Callable, Awaitable, TypeVar
+from typing import Iterable, Optional, Sequence, Union, Callable, Awaitable, TypeVar
 
-import hivemind
-from hivemind.client import RemoteExpert
 from hivemind.dht.node import DHTNode, DHTID, DHTExpiration
 from hivemind.dht.routing import DHTValue, DHTKey, Subkey
 from hivemind.dht.validation import CompositeValidator, RecordValidatorBase
@@ -46,19 +45,22 @@ class DHT(mp.Process):
     :param max_workers: declare_experts and get_experts will use up to this many parallel workers
         (but no more than one per key)
     :param expiration: experts declared from this node expire after this many seconds (default = 5 minutes)
+    :param shutdown_timeout: when calling .shutdown, wait for up to this many seconds before terminating
     :param kwargs: any other params will be forwarded to DHTNode upon creation
     """
+    _node: DHTNode
 
     def __init__(self, listen_on: Endpoint = "0.0.0.0:*", initial_peers: Sequence[Endpoint] = (), *, start: bool,
                  daemon: bool = True, max_workers: Optional[int] = None, parallel_rpc: Optional[int] = None,
-                 record_validators: Iterable[RecordValidatorBase] = (), **kwargs):
+                 record_validators: Iterable[RecordValidatorBase] = (), shutdown_timeout: float = 3, **kwargs):
         super().__init__()
         assert not isinstance(initial_peers, str), "please specify a list/tuple of initial peers (even if there's one)"
         self.listen_on, self.initial_peers, self.kwargs = listen_on, initial_peers, kwargs
         self.max_workers, self.parallel_rpc = max_workers, parallel_rpc
         self._record_validator = CompositeValidator(record_validators)
         self._port = mp.Value(ctypes.c_int32, 0)  # initialized after dht starts
-        self._pipe, self.pipe = mp.Pipe(duplex=True)
+        self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)
+        self.shutdown_timeout = shutdown_timeout
         self.ready = mp.Event()
         self.daemon = daemon
         if start:
@@ -70,17 +72,20 @@ class DHT(mp.Process):
 
         with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
             async def _run():
-                node = await DHTNode.create(
+                self._node = await DHTNode.create(
                     initial_peers=list(self.initial_peers), listen_on=self.listen_on, parallel_rpc=self.parallel_rpc,
                     num_workers=self.max_workers or 1, record_validator=self._record_validator,
                     **self.kwargs)
-                if node.port is not None:
-                    self._port.value = node.port
+                if self._node.port is not None:
+                    self._port.value = self._node.port
                 self.ready.set()
 
                 while True:
-                    method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv)
-                    asyncio.create_task(getattr(self, method)(node, *args, **kwargs))
+                    method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
+                    task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
+                    if method == '_shutdown':
+                        await task
+                        break
 
             coro = _run()
             loop.run_until_complete(coro)
@@ -97,10 +102,17 @@ class DHT(mp.Process):
     def shutdown(self) -> None:
         """ Shut down a running dht process """
         if self.is_alive():
-            self.terminate()
+            self._outer_pipe.send(('_shutdown', [], {}))
+            self.join(self.shutdown_timeout)
+            if self.is_alive():
+                logger.warning("DHT did not shut down within the grace period; terminating it the hard way.")
+                self.terminate()
         else:
             logger.warning("DHT shutdown has no effect: dht process is already not alive")
 
+    async def _shutdown(self):
+        await self._node.shutdown()
+
     @property
     def port(self) -> Optional[int]:
         return self._port.value if self._port.value != 0 else None
@@ -116,12 +128,12 @@ class DHT(mp.Process):
         :returns: (value, expiration time); if value was not found, returns None
         """
         future, _future = MPFuture.make_pair()
-        self.pipe.send(('_get', [], dict(key=key, latest=latest, future=_future, **kwargs)))
+        self._outer_pipe.send(('_get', [], dict(key=key, latest=latest, future=_future, **kwargs)))
         return future if return_future else future.result()
 
-    async def _get(self, node: DHTNode, key: DHTKey, latest: bool, future: MPFuture, **kwargs):
+    async def _get(self, key: DHTKey, latest: bool, future: MPFuture, **kwargs):
         try:
-            result = await node.get(key, latest=latest, **kwargs)
+            result = await self._node.get(key, latest=latest, **kwargs)
             if not future.done():
                 future.set_result(result)
         except BaseException as e:
@@ -142,14 +154,14 @@ class DHT(mp.Process):
         :returns: True if store succeeds, False if it fails (due to no response or newer value)
         """
         future, _future = MPFuture.make_pair()
-        self.pipe.send(('_store', [], dict(key=key, value=value, expiration_time=expiration_time, subkey=subkey,
-                                           future=_future, **kwargs)))
+        self._outer_pipe.send(('_store', [], dict(key=key, value=value, expiration_time=expiration_time, subkey=subkey,
+                                                  future=_future, **kwargs)))
         return future if return_future else future.result()
 
-    async def _store(self, node: DHTNode, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration,
+    async def _store(self, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration,
                      subkey: Optional[Subkey], future: MPFuture, **kwargs):
         try:
-            result = await node.store(key, value, expiration_time, subkey=subkey, **kwargs)
+            result = await self._node.store(key, value, expiration_time, subkey=subkey, **kwargs)
             if not future.done():
                 future.set_result(result)
         except BaseException as e:
@@ -173,12 +185,12 @@ class DHT(mp.Process):
         :note: when run_coroutine is called with wait=False, MPFuture can be cancelled to interrupt the task.
         """
         future, _future = MPFuture.make_pair()
-        self.pipe.send(('_run_coroutine', [], dict(coro=coro, future=_future)))
+        self._outer_pipe.send(('_run_coroutine', [], dict(coro=coro, future=_future)))
         return future if return_future else future.result()
 
-    async def _run_coroutine(self, node: DHTNode, coro: Callable[[DHT, DHTNode], Awaitable[ReturnType]],
+    async def _run_coroutine(self, coro: Callable[[DHT, DHTNode], Awaitable[ReturnType]],
                              future: MPFuture[ReturnType]):
-        main_task = asyncio.create_task(coro(self, node))
+        main_task = asyncio.create_task(coro(self, self._node))
         cancel_task = asyncio.create_task(await_cancelled(future))
         try:
             await asyncio.wait({main_task, cancel_task}, return_when=asyncio.FIRST_COMPLETED)
@@ -215,21 +227,21 @@ class DHT(mp.Process):
         assert num_peers is None or peers == (), "please specify either a num_peers or the list of peers, not both"
         assert not isinstance(peers, str) and isinstance(peers, Sequence), "Please send a list / tuple of endpoints"
         future, _future = MPFuture.make_pair()
-        self.pipe.send(('_get_visible_address', [], dict(num_peers=num_peers, peers=peers, future=_future)))
+        self._outer_pipe.send(('_get_visible_address', [], dict(num_peers=num_peers, peers=peers, future=_future)))
         return future.result()
 
-    async def _get_visible_address(self, node: DHTNode, num_peers: Optional[int], peers: Sequence[Endpoint],
+    async def _get_visible_address(self, num_peers: Optional[int], peers: Sequence[Endpoint],
                                    future: Optional[MPFuture]):
-        if not peers and (num_peers or not node.protocol.node_info.endpoint):
+        if not peers and (num_peers or not self._node.protocol.node_info.endpoint):
             # if we can't resolve the endpoint locally, ask one random peer
-            peers_and_endpoints = node.protocol.routing_table.get_nearest_neighbors(
-                DHTID.generate(), num_peers or 1, exclude=node.node_id)
+            peers_and_endpoints = self._node.protocol.routing_table.get_nearest_neighbors(
+                DHTID.generate(), num_peers or 1, exclude=self._node.node_id)
             peers = tuple(endpoint for node_id, endpoint in peers_and_endpoints)
 
         chosen_address = None
         if peers:
             possible_endpoints: Sequence[Optional[Endpoint]] = await asyncio.gather(*(
-                node.protocol.get_outgoing_request_endpoint(peer) for peer in peers))
+                self._node.protocol.get_outgoing_request_endpoint(peer) for peer in peers))
 
             for endpoint in possible_endpoints:
                 if endpoint is None:
@@ -244,8 +256,8 @@ class DHT(mp.Process):
             if chosen_address is None:
                 logger.warning(f"None of the selected peers responded with an address ({peers})")
 
-        if node.protocol.node_info.endpoint:
-            address = strip_port(node.protocol.node_info.endpoint)
+        if self._node.protocol.node_info.endpoint:
+            address = strip_port(self._node.protocol.node_info.endpoint)
             if chosen_address is not None and address != chosen_address:
                 logger.warning(f"Node was manually given endpoint {address} , but other peers report {chosen_address}")
             chosen_address = chosen_address or address
@@ -255,3 +267,7 @@ class DHT(mp.Process):
         else:
             future.set_exception(ValueError(f"Can't get address: DHT node has no peers and no public endpoint."
                                             f" Please ensure the node is connected or specify peers=... manually."))
+
+    def __del__(self):
+        if self._parent_pid == os.getpid() and self.is_alive():
+            self.shutdown()

+ 2 - 1
requirements-dev.txt

@@ -4,4 +4,5 @@ pytest-asyncio
 pytest-cov
 codecov
 tqdm
-scikit-learn
+scikit-learn
+psutil

+ 9 - 0
tests/conftest.py

@@ -0,0 +1,9 @@
+import pytest
+import psutil
+
+
+@pytest.fixture(autouse=True, scope='session')
+def cleanup_children():
+    yield
+    for child in psutil.Process().children(recursive=True):
+        child.terminate()
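
The cleanup_children fixture above politely terminates any child processes that the test session leaves behind. A hedged variant, not part of this commit, that escalates to SIGKILL for children ignoring SIGTERM could use psutil.wait_procs:

import psutil


def terminate_children(grace_period: float = 3.0):
    # send SIGTERM to all descendants, then SIGKILL whatever is still alive
    children = psutil.Process().children(recursive=True)
    for child in children:
        child.terminate()
    _, alive = psutil.wait_procs(children, timeout=grace_period)
    for child in alive:
        child.kill()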

+ 1 - 0
tests/test_averaging.py

@@ -3,6 +3,7 @@ import random
 import numpy as np
 import torch
 import pytest
+
 import hivemind
 from hivemind.client.averaging.allreduce import AveragingMode
 from hivemind.client.averaging.load_balancing import load_balance_peers

+ 1 - 0
tests/test_dht.py

@@ -8,6 +8,7 @@ import hivemind
 from hivemind import LOCALHOST, strip_port
 
 
+
 @pytest.mark.forked
 def test_get_store():
     peers = []

+ 14 - 2
tests/test_dht_experts.py

@@ -1,5 +1,6 @@
 import asyncio
 import random
+import time
 
 import numpy as np
 import pytest
@@ -36,8 +37,14 @@ def test_store_get_experts():
     assert isinstance(first_found, hivemind.RemoteExpert)
     assert first_found.endpoint == f'that_host:{other_port}'
 
-    for peer in peers:
-        peer.shutdown()
+    # test graceful shutdown
+    first_peer.shutdown()
+    other_peer.shutdown()
+    time.sleep(1.0)
+    remaining_peer1 = random.choice([peer for peer in peers if peer.is_alive()])
+    remaining_peer2 = random.choice([peer for peer in peers if peer.is_alive()])
+    assert all(hivemind.declare_experts(remaining_peer1, ['new_expert.1'], 'dummy'))
+    assert hivemind.get_experts(remaining_peer2, ['new_expert.1'])[0].endpoint == 'dummy'
 
 
 @pytest.mark.forked
@@ -156,3 +163,8 @@ async def test_negative_caching():
         assert fetched[i] is not None, f"node should have cached ffn.{i}."
     for i in range(6, len(fetched)):
         assert fetched[i] is None, f"node shouldn't have cached ffn.{i}."
+
+    await node.shutdown()
+    neg_caching_peer.shutdown()
+    for peer in peers:
+        peer.shutdown()