
Add graceful shutdown to DHT and Averager (#301)

- DHT.shutdown and DecentralizedAverager.shutdown no longer terminate the daemon process the hard way
- made internal daemon logic of DHT and DecentralizedAverager more similar
- added a basic test that ensures shutdown actually shuts processes down
- added several minor clarifications

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic committed 4 years ago
commit e9956b84f6

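For context, a minimal usage sketch of the behavior this commit introduces: `shutdown()` now asks the daemon process to exit cleanly and only falls back to `terminate()` if it does not exit within `shutdown_timeout`. The tensor shape, prefix, and group size below are illustrative and not part of the change itself.

```python
import torch
import hivemind

if __name__ == '__main__':
    dht = hivemind.DHT(listen_on='0.0.0.0:*', start=True)
    averager = hivemind.DecentralizedAverager(
        averaged_tensors=[torch.zeros(8)], dht=dht, prefix='demo_run',
        target_group_size=2, start=True, shutdown_timeout=5)
    try:
        pass  # training / averaging steps would go here
    finally:
        # graceful shutdown: a '_shutdown' request goes over the control pipe,
        # the daemon gets shutdown_timeout seconds to exit, terminate() is only the fallback
        averager.shutdown()
        dht.shutdown()
```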
+ 5 - 1
CONTRIBUTING.md

@@ -102,6 +102,10 @@ a new test to make sure it's not reintroduced by future changes.
 To run tests, you need to install hivemind in development mode with additional dependencies: `pip install -e .[dev]`.
 You can run all tests with `pytest tests/` or choose a specific subset, e.g., `pytest tests/test_dht.py`.

+When investigating test behavior, please note that pytest automatically wraps all hivemind tests with fixtures defined
+in a global configuration file [`tests/conftest.py`](./tests/conftest.py), some of which will run automatically. 
+For more information, refer to the [pytest documentation on fixtures](https://docs.pytest.org/en/6.2.x/fixture.html).
+
 ## Building documentation

 Any function exposed to a user must have a docstring compatible
@@ -140,4 +144,4 @@ This guide was inspired by several influential Python open source projects liste
 
 
 * [PyTorch](https://github.com/pytorch/pytorch/blob/master/CONTRIBUTING.md)
 * [Scikit-learn](https://scikit-learn.org/dev/developers/contributing.html)
-* [transformers](https://github.com/huggingface/transformers/blob/master/CONTRIBUTING.md)
+* [transformers](https://github.com/huggingface/transformers/blob/master/CONTRIBUTING.md)

+ 41 - 23
hivemind/client/averaging/__init__.py

@@ -75,6 +75,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
           local tensors for averaging
     :param allow_state_sharing: if set to True, other peers can download this peer's state. Can be overwritten
       with averager.allow_state_sharing = True / False
+    :param shutdown_timeout: when calling .shutdown, wait for up to this many seconds before terminating
 
 
     Example:

@@ -90,6 +91,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
     """
     _matchmaking: Matchmaking
     _pending_group_assembled: asyncio.Event
+    _server: grpc.aio.Server
     serializer = MSGPackSerializer

     def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: DHT, *, start: bool,
@@ -100,7 +102,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                  throughput: Optional[float] = None, min_vector_size: int = 0,
                  auxiliary: bool = False, allow_state_sharing: Optional[bool] = None,
                  listen: bool = True, listen_on: Endpoint = '0.0.0.0:*', daemon: bool = True,
-                 channel_options: Optional[Sequence[Tuple[str, Any]]] = None, **kwargs):
+                 channel_options: Optional[Sequence[Tuple[str, Any]]] = None,
+                 shutdown_timeout: float = 5, **kwargs):
         assert '.' not in prefix, "group prefix must be a string without trailing '.'"
         assert throughput is None or (throughput >= 0 and np.isfinite(np.float32(throughput))), \
             "throughput must be a non-negative float32"
@@ -130,7 +133,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             tensor.share_memory_()
         self.total_size = sum(map(torch.Tensor.numel, self._averaged_tensors))
         self.schema_hash = compute_schema_hash(self._averaged_tensors)
-        self._throughput = throughput
+        self.shutdown_timeout = shutdown_timeout
+        self.throughput = throughput
 
 
         self.matchmaking_kwargs = dict(
             prefix=prefix, initial_group_bits=initial_group_bits, target_group_size=target_group_size,
@@ -140,7 +144,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         self._averaging_alpha, self._allreduce_timeout = averaging_alpha, allreduce_timeout
         self._running_groups: Dict[GroupID, AllReduceRunner] = {}  # one or more assembled groups that run all-reduce

-        self._pipe, self.pipe = mp.Pipe(duplex=True)  # a control pipe used to communicate with a background process
+        self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)  # a control pipe used to communicate with daemon
         self._port = mp.Value(ctypes.c_uint32, 0)  # assigned when averager starts, accessible via self.port

         self._allow_state_sharing = mp.Value(ctypes.c_bool, 0)
@@ -154,7 +158,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         # note: we create a background thread weakref and with daemon=True to ensure garbage collection
         background_fetcher = threading.Thread(
             daemon=True, target=_background_thread_fetch_current_state,
-            args=[self.serializer, self.pipe, weakref.WeakMethod(self.get_current_state)])
+            args=[self.serializer, self._outer_pipe, weakref.WeakMethod(self.get_current_state)])
         background_fetcher.start()
         if start:
             self.run_in_background(await_ready=True)
@@ -205,12 +209,12 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                 grpc.aio.init_grpc_aio()

                 if self.listen:
-                    server = grpc.aio.server(**self.kwargs, options=GRPC_KEEPALIVE_OPTIONS)
-                    averaging_pb2_grpc.add_DecentralizedAveragingServicer_to_server(self, server)
-                    found_port = server.add_insecure_port(self.listen_on)
+                    self._server = grpc.aio.server(**self.kwargs, options=GRPC_KEEPALIVE_OPTIONS)
+                    averaging_pb2_grpc.add_DecentralizedAveragingServicer_to_server(self, self._server)
+                    found_port = self._server.add_insecure_port(self.listen_on)
                     assert found_port != 0, f"Failed to listen to {self.listen_on}"
                     self._port.value = found_port
-                    await server.start()
+                    await self._server.start()
                 else:
                     logger.debug(f"The averager is running in client mode.")

@@ -224,8 +228,11 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                 self.ready.set()

                 while True:
-                    method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv)
-                    asyncio.create_task(getattr(self, method)(*args, **kwargs))
+                    method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
+                    task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
+                    if method == '_shutdown':
+                        await task
+                        break
 
 
             loop.run_until_complete(_run())

@@ -240,15 +247,26 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
 
     def shutdown(self) -> None:
         """ Shut down the averager process """
-        # TODO notify peers before terminating
-        if self._parent_pid != os.getpid() or self.is_alive():
-            self._pipe.send(('_SHUTDOWN', None))
-            self.terminate()
+        if self.is_alive():
+            self._outer_pipe.send(('_shutdown', [None], {}))  # shut down the daemon process
+            self._inner_pipe.send(('_SHUTDOWN', None))  # shut down background thread in master
+            self.join(self.shutdown_timeout)
+            if self.is_alive():
+                logger.warning("Averager did not shut down within the grace period; terminating it the hard way.")
+                self.terminate()
         else:
-            logger.warning("DHT shutdown has no effect: the process is not alive")
+            logger.exception("Averager shutdown has no effect: the process is already not alive")
+
+    async def _shutdown(self, timeout: Optional[float] = None) -> None:
+        remaining_tasks = set()
+        for group in self._running_groups.values():
+            remaining_tasks.update(group.finalize(cancel=True))
+        if self.listen:
+            remaining_tasks.add(self._server.stop(timeout))
+        await asyncio.gather(*remaining_tasks)
 
 
     def __del__(self):
-        if self._parent_pid != os.getpid() or self.is_alive():
+        if self._parent_pid == os.getpid() and self.is_alive():
             self.shutdown()

     def step(self, gather: Optional[GatheredData] = None, weight: Optional[float] = None,
@@ -274,8 +292,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
 
         future, _future = MPFuture.make_pair()
         gather_binary = self.serializer.dumps(gather)  # serialize here to avoid loading modules in the averager process
-        self.pipe.send(('_step', [], dict(future=_future, gather_binary=gather_binary, weight=weight,
-                                          allow_retries=allow_retries, timeout=timeout)))
+        self._outer_pipe.send(('_step', [], dict(future=_future, gather_binary=gather_binary, weight=weight,
+                                                 allow_retries=allow_retries, timeout=timeout)))
         return future.result() if wait else future

     async def _step(self, *, future: MPFuture, gather_binary: bytes, weight: float,
@@ -286,7 +304,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             while not future.done():
                 try:
                     self._pending_group_assembled.clear()
-                    data_for_gather = self.serializer.dumps([weight, self._throughput, self.mode.value, gather_binary]) 
+                    data_for_gather = self.serializer.dumps([weight, self.throughput, self.mode.value, gather_binary])
                     group_info = await self._matchmaking.look_for_group(timeout=timeout,
                                                                         data_for_gather=data_for_gather)
                     if group_info is None:
@@ -446,7 +464,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
     async def _get_current_state_from_host_process(self):
         """ Executed in the averager process inside rpc_download_state """
         future, _future = MPFuture.make_pair()
-        self._pipe.send(('_TRIGGER_GET_CURRENT_STATE', _future))
+        self._inner_pipe.send(('_TRIGGER_GET_CURRENT_STATE', _future))
         return await future

     def load_state_from_peers(self, wait=True) -> Optional[Tuple[Any, Sequence[torch.Tensor]]]:
@@ -460,7 +478,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         The exact contents of both metadata and tensors are determined by get_current_state method
         """
         future, _future = MPFuture.make_pair()
-        self.pipe.send(('_load_state_from_peers', [], dict(future=_future)))
+        self._outer_pipe.send(('_load_state_from_peers', [], dict(future=_future)))
         return future.result() if wait else future

     async def _load_state_from_peers(self, future: MPFuture):
@@ -520,7 +538,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         :returns: averager's current group key bits (without prefix)
         """
         future, _future = MPFuture.make_pair()
-        self.pipe.send(('_get_group_bits', [], dict(future=_future)))
+        self._outer_pipe.send(('_get_group_bits', [], dict(future=_future)))
         return future.result() if wait else future

     async def _get_group_bits(self, future: MPFuture):
@@ -533,7 +551,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         """
         future, _future = MPFuture.make_pair()
         assert all(bit in '01' for bit in group_bits)
-        self.pipe.send(('_set_group_bits', [], dict(group_bits=group_bits, future=_future)))
+        self._outer_pipe.send(('_set_group_bits', [], dict(group_bits=group_bits, future=_future)))
         return future.result() if wait else future

     async def _set_group_bits(self, group_bits: str, future: MPFuture):

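The daemon loop above follows one pattern in both classes: the parent process sends a method name over a control pipe, the daemon dispatches it as an asyncio task, and a `_shutdown` request ends the loop before the `terminate()` fallback ever fires. A stripped-down sketch of that pattern, with illustrative names rather than the actual hivemind classes:

```python
import asyncio
import multiprocessing as mp
from concurrent.futures import ThreadPoolExecutor


class ToyDaemon(mp.Process):
    """Illustrative only: mimics the pipe-driven event loop shared by DHT and DecentralizedAverager."""

    def __init__(self):
        super().__init__(daemon=True)
        self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)

    def run(self):
        loop = asyncio.new_event_loop()

        async def _run():
            with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
                while True:
                    # block on the pipe in a worker thread so the event loop stays responsive
                    method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
                    task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
                    if method == '_shutdown':
                        await task  # let cleanup finish, then leave the loop
                        break

        loop.run_until_complete(_run())

    async def _shutdown(self):
        pass  # close servers / cancel running all-reduce groups here

    def shutdown(self, timeout: float = 5.0):
        """Ask the daemon to exit; fall back to terminate() only if it ignores the request."""
        if self.is_alive():
            self._outer_pipe.send(('_shutdown', [], {}))
            self.join(timeout)
            if self.is_alive():
                self.terminate()


if __name__ == '__main__':
    daemon = ToyDaemon()
    daemon.start()
    daemon.shutdown()
```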
+ 15 - 7
hivemind/client/averaging/allreduce.py

@@ -112,11 +112,6 @@ class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
             self.finalize(exception=e)
             for task in pending_tasks:
                 task.cancel()
-            code = averaging_pb2.CANCELLED if isinstance(e, asyncio.CancelledError) else averaging_pb2.INTERNAL_ERROR
-            logger.debug(f"{self} - notifying peers about {averaging_pb2.MessageCode.Name(code)}")
-            for peer_endpoint, mode in zip(self.ordered_group_endpoints, self.modes):
-                if peer_endpoint != self.endpoint and mode != AveragingMode.CLIENT:
-                    asyncio.create_task(self._send_error_to_peer(peer_endpoint, code))
             raise

     async def _communicate_with_peer(self, peer_endpoint: Endpoint):
@@ -210,7 +205,20 @@ class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
         await stream.done_writing()

     def finalize(self, *, cancel: bool = False, exception: Optional[BaseException] = None):
+        """ finish or terminate AllReduceRunner, propagate any errors / cancellations to peers. """
         assert not cancel or not exception, "finalize accepts either exception or cancel, but not both"
+        pending_tasks = set()
+        if cancel or exception:
+            # propagate error to peers
+            if cancel or isinstance(exception, asyncio.CancelledError):
+                code = averaging_pb2.CANCELLED
+            else:
+                code = averaging_pb2.INTERNAL_ERROR
+            logger.debug(f"{self} - notifying peers about {averaging_pb2.MessageCode.Name(code)}")
+            for peer_endpoint, mode in zip(self.ordered_group_endpoints, self.modes):
+                if peer_endpoint != self.endpoint and mode != AveragingMode.CLIENT:
+                    pending_tasks.add(asyncio.create_task(self._send_error_to_peer(peer_endpoint, code)))
+
         if not self._future.done():
             if cancel:
                 logger.debug(f"{self} - cancelled")
@@ -223,7 +231,7 @@ class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
                 self._future.set_result(None)
             self.tensor_part_container.finalize()
             self.tensor_part_reducer.finalize()
-            return True
+            return pending_tasks
         else:
             logger.debug(f"{self} - could not finish: allreduce is already finished: {self._future}")
-            return False
+            return pending_tasks

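Note the contract change above: `finalize()` used to return a bool, but now returns the set of peer-notification tasks so shutdown code can await them before the event loop is torn down. A hedged sketch of the caller-side pattern (the `AllReduceRunner` type is only referenced, not reimplemented):

```python
import asyncio
from typing import Dict, Set


async def cancel_running_groups(running_groups: Dict[bytes, "AllReduceRunner"]) -> None:
    """Cancel all in-flight all-reduce rounds and wait until peers have been notified.

    Mirrors what DecentralizedAverager._shutdown does with the new finalize() return value.
    """
    remaining_tasks: Set[asyncio.Task] = set()
    for group in running_groups.values():
        # finalize(cancel=True) now returns peer-notification tasks instead of a bool
        remaining_tasks.update(group.finalize(cancel=True))
    await asyncio.gather(*remaining_tasks)
```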
+ 45 - 29
hivemind/dht/__init__.py

@@ -16,12 +16,11 @@ from __future__ import annotations
 import asyncio
 import ctypes
 import multiprocessing as mp
+import os
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
-from typing import Iterable, List, Optional, Sequence, Union, Callable, Awaitable, TypeVar
+from typing import Iterable, Optional, Sequence, Union, Callable, Awaitable, TypeVar
 
 
-import hivemind
-from hivemind.client import RemoteExpert
 from hivemind.dht.node import DHTNode, DHTID, DHTExpiration
 from hivemind.dht.routing import DHTValue, DHTKey, Subkey
 from hivemind.dht.validation import CompositeValidator, RecordValidatorBase
@@ -46,19 +45,22 @@ class DHT(mp.Process):
     :param max_workers: declare_experts and get_experts will use up to this many parallel workers
         (but no more than one per key)
     :param expiration: experts declared from this node expire after this many seconds (default = 5 minutes)
+    :param shutdown_timeout: when calling .shutdown, wait for up to this many seconds before terminating
     :param kwargs: any other params will be forwarded to DHTNode upon creation
     """
+    _node: DHTNode
 
 
     def __init__(self, listen_on: Endpoint = "0.0.0.0:*", initial_peers: Sequence[Endpoint] = (), *, start: bool,
                  daemon: bool = True, max_workers: Optional[int] = None, parallel_rpc: Optional[int] = None,
-                 record_validators: Iterable[RecordValidatorBase] = (), **kwargs):
+                 record_validators: Iterable[RecordValidatorBase] = (), shutdown_timeout: float = 3, **kwargs):
         super().__init__()
         assert not isinstance(initial_peers, str), "please specify a list/tuple of initial peers (even if there's one)"
         self.listen_on, self.initial_peers, self.kwargs = listen_on, initial_peers, kwargs
         self.max_workers, self.parallel_rpc = max_workers, parallel_rpc
         self._record_validator = CompositeValidator(record_validators)
         self._port = mp.Value(ctypes.c_int32, 0)  # initialized after dht starts
-        self._pipe, self.pipe = mp.Pipe(duplex=True)
+        self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)
+        self.shutdown_timeout = shutdown_timeout
         self.ready = mp.Event()
         self.daemon = daemon
         if start:
@@ -70,17 +72,20 @@ class DHT(mp.Process):
 
 
         with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
             async def _run():
-                node = await DHTNode.create(
+                self._node = await DHTNode.create(
                     initial_peers=list(self.initial_peers), listen_on=self.listen_on, parallel_rpc=self.parallel_rpc,
                     num_workers=self.max_workers or 1, record_validator=self._record_validator,
                     **self.kwargs)
-                if node.port is not None:
-                    self._port.value = node.port
+                if self._node.port is not None:
+                    self._port.value = self._node.port
                 self.ready.set()

                 while True:
-                    method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv)
-                    asyncio.create_task(getattr(self, method)(node, *args, **kwargs))
+                    method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
+                    task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
+                    if method == '_shutdown':
+                        await task
+                        break
 
 
             coro = _run()
             loop.run_until_complete(coro)
@@ -97,10 +102,17 @@ class DHT(mp.Process):
     def shutdown(self) -> None:
         """ Shut down a running dht process """
         if self.is_alive():
-            self.terminate()
+            self._outer_pipe.send(('_shutdown', [], {}))
+            self.join(self.shutdown_timeout)
+            if self.is_alive():
+                logger.warning("DHT did not shut down within the grace period; terminating it the hard way.")
+                self.terminate()
         else:
             logger.warning("DHT shutdown has no effect: dht process is already not alive")

+    async def _shutdown(self):
+        await self._node.shutdown()
+
     @property
     def port(self) -> Optional[int]:
         return self._port.value if self._port.value != 0 else None
@@ -116,12 +128,12 @@ class DHT(mp.Process):
         :returns: (value, expiration time); if value was not found, returns None
         """
         future, _future = MPFuture.make_pair()
-        self.pipe.send(('_get', [], dict(key=key, latest=latest, future=_future, **kwargs)))
+        self._outer_pipe.send(('_get', [], dict(key=key, latest=latest, future=_future, **kwargs)))
         return future if return_future else future.result()

-    async def _get(self, node: DHTNode, key: DHTKey, latest: bool, future: MPFuture, **kwargs):
+    async def _get(self, key: DHTKey, latest: bool, future: MPFuture, **kwargs):
         try:
-            result = await node.get(key, latest=latest, **kwargs)
+            result = await self._node.get(key, latest=latest, **kwargs)
             if not future.done():
                 future.set_result(result)
         except BaseException as e:
@@ -142,14 +154,14 @@ class DHT(mp.Process):
         :returns: True if store succeeds, False if it fails (due to no response or newer value)
         """
         future, _future = MPFuture.make_pair()
-        self.pipe.send(('_store', [], dict(key=key, value=value, expiration_time=expiration_time, subkey=subkey,
-                                           future=_future, **kwargs)))
+        self._outer_pipe.send(('_store', [], dict(key=key, value=value, expiration_time=expiration_time, subkey=subkey,
+                                                  future=_future, **kwargs)))
         return future if return_future else future.result()

-    async def _store(self, node: DHTNode, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration,
+    async def _store(self, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration,
                      subkey: Optional[Subkey], future: MPFuture, **kwargs):
         try:
-            result = await node.store(key, value, expiration_time, subkey=subkey, **kwargs)
+            result = await self._node.store(key, value, expiration_time, subkey=subkey, **kwargs)
             if not future.done():
                 future.set_result(result)
         except BaseException as e:
@@ -173,12 +185,12 @@ class DHT(mp.Process):
         :note: when run_coroutine is called with wait=False, MPFuture can be cancelled to interrupt the task.
         """
         future, _future = MPFuture.make_pair()
-        self.pipe.send(('_run_coroutine', [], dict(coro=coro, future=_future)))
+        self._outer_pipe.send(('_run_coroutine', [], dict(coro=coro, future=_future)))
         return future if return_future else future.result()

-    async def _run_coroutine(self, node: DHTNode, coro: Callable[[DHT, DHTNode], Awaitable[ReturnType]],
+    async def _run_coroutine(self, coro: Callable[[DHT, DHTNode], Awaitable[ReturnType]],
                              future: MPFuture[ReturnType]):
-        main_task = asyncio.create_task(coro(self, node))
+        main_task = asyncio.create_task(coro(self, self._node))
         cancel_task = asyncio.create_task(await_cancelled(future))
         try:
             await asyncio.wait({main_task, cancel_task}, return_when=asyncio.FIRST_COMPLETED)
@@ -215,21 +227,21 @@ class DHT(mp.Process):
         assert num_peers is None or peers == (), "please specify either a num_peers or the list of peers, not both"
         assert not isinstance(peers, str) and isinstance(peers, Sequence), "Please send a list / tuple of endpoints"
         future, _future = MPFuture.make_pair()
-        self.pipe.send(('_get_visible_address', [], dict(num_peers=num_peers, peers=peers, future=_future)))
+        self._outer_pipe.send(('_get_visible_address', [], dict(num_peers=num_peers, peers=peers, future=_future)))
         return future.result()

-    async def _get_visible_address(self, node: DHTNode, num_peers: Optional[int], peers: Sequence[Endpoint],
+    async def _get_visible_address(self, num_peers: Optional[int], peers: Sequence[Endpoint],
                                    future: Optional[MPFuture]):
-        if not peers and (num_peers or not node.protocol.node_info.endpoint):
+        if not peers and (num_peers or not self._node.protocol.node_info.endpoint):
             # if we can't resolve the endpoint locally, ask one random peer
-            peers_and_endpoints = node.protocol.routing_table.get_nearest_neighbors(
-                DHTID.generate(), num_peers or 1, exclude=node.node_id)
+            peers_and_endpoints = self._node.protocol.routing_table.get_nearest_neighbors(
+                DHTID.generate(), num_peers or 1, exclude=self._node.node_id)
             peers = tuple(endpoint for node_id, endpoint in peers_and_endpoints)

         chosen_address = None
         if peers:
             possible_endpoints: Sequence[Optional[Endpoint]] = await asyncio.gather(*(
-                node.protocol.get_outgoing_request_endpoint(peer) for peer in peers))
+                self._node.protocol.get_outgoing_request_endpoint(peer) for peer in peers))
 
 
             for endpoint in possible_endpoints:
                 if endpoint is None:
@@ -244,8 +256,8 @@ class DHT(mp.Process):
             if chosen_address is None:
                 logger.warning(f"None of the selected peers responded with an address ({peers})")

-        if node.protocol.node_info.endpoint:
-            address = strip_port(node.protocol.node_info.endpoint)
+        if self._node.protocol.node_info.endpoint:
+            address = strip_port(self._node.protocol.node_info.endpoint)
             if chosen_address is not None and address != chosen_address:
                 logger.warning(f"Node was manually given endpoint {address} , but other peers report {chosen_address}")
             chosen_address = chosen_address or address
@@ -255,3 +267,7 @@ class DHT(mp.Process):
         else:
             future.set_exception(ValueError(f"Can't get address: DHT node has no peers and no public endpoint."
                                             f" Please ensure the node is connected or specify peers=... manually."))
+
+    def __del__(self):
+        if self._parent_pid == os.getpid() and self.is_alive():
+            self.shutdown()

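One point worth noting: although `_get`, `_store`, and `_run_coroutine` no longer receive the node as an argument (the daemon now keeps it as `self._node`), the public `run_coroutine` contract is unchanged and user coroutines still receive `(dht, node)`. A small usage sketch, with an illustrative coroutine name:

```python
import hivemind


async def _read_node_port(_dht: hivemind.DHT, node) -> int:
    # runs inside the DHT daemon process, with direct access to the live DHTNode
    return node.port


if __name__ == '__main__':
    dht = hivemind.DHT(start=True)
    print(dht.run_coroutine(_read_node_port))  # blocks until the coroutine completes in the daemon
    dht.shutdown()  # graceful shutdown added by this commit
```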
+ 2 - 1
requirements-dev.txt

@@ -4,4 +4,5 @@ pytest-asyncio
 pytest-cov
 codecov
 tqdm
-scikit-learn
+scikit-learn
+psutil

+ 9 - 0
tests/conftest.py

@@ -0,0 +1,9 @@
+import pytest
+import psutil
+
+
+@pytest.fixture(autouse=True, scope='session')
+def cleanup_children():
+    yield
+    for child in psutil.Process().children(recursive=True):
+        child.terminate()

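The fixture above is session-scoped and autouse, so it runs once after the whole test session and reaps any child processes that are still alive, regardless of whether individual tests called `shutdown()`. A hypothetical test that relies on it (not part of this commit):

```python
import pytest
import hivemind


@pytest.mark.forked
def test_daemon_is_reaped_even_without_shutdown():
    dht = hivemind.DHT(start=True)
    assert dht.is_alive()
    # intentionally no dht.shutdown(): any process still alive at the end of the
    # session is terminated by the cleanup_children fixture in tests/conftest.py
```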
+ 1 - 0
tests/test_averaging.py

@@ -3,6 +3,7 @@ import random
 import numpy as np
 import torch
 import pytest
+
 import hivemind
 from hivemind.client.averaging.allreduce import AveragingMode
 from hivemind.client.averaging.load_balancing import load_balance_peers

+ 1 - 0
tests/test_dht.py

@@ -8,6 +8,7 @@ import hivemind
 from hivemind import LOCALHOST, strip_port


+
 @pytest.mark.forked
 def test_get_store():
     peers = []

+ 14 - 2
tests/test_dht_experts.py

@@ -1,5 +1,6 @@
 import asyncio
 import random
+import time
 
 
 import numpy as np
 import pytest
@@ -36,8 +37,14 @@ def test_store_get_experts():
     assert isinstance(first_found, hivemind.RemoteExpert)
     assert first_found.endpoint == f'that_host:{other_port}'

-    for peer in peers:
-        peer.shutdown()
+    # test graceful shutdown
+    first_peer.shutdown()
+    other_peer.shutdown()
+    time.sleep(1.0)
+    remaining_peer1 = random.choice([peer for peer in peers if peer.is_alive()])
+    remaining_peer2 = random.choice([peer for peer in peers if peer.is_alive()])
+    assert all(hivemind.declare_experts(remaining_peer1, ['new_expert.1'], 'dummy'))
+    assert hivemind.get_experts(remaining_peer2, ['new_expert.1'])[0].endpoint == 'dummy'
 
 
 
 
 @pytest.mark.forked
@@ -156,3 +163,8 @@ async def test_negative_caching():
         assert fetched[i] is not None, f"node should have cached ffn.{i}."
     for i in range(6, len(fetched)):
         assert fetched[i] is None, f"node shouldn't have cached ffn.{i}."
+
+    await node.shutdown()
+    neg_caching_peer.shutdown()
+    for peer in peers:
+        peer.shutdown()