Răsfoiți Sursa

Remove libp2p handlers when ConnectionHandler, DHT, and DecentralizedAverager are shut down (#501)

Alexander Borzunov 3 ani în urmă
părinte
comite
3267fc7ab5

+ 5 - 4
hivemind/averaging/averager.py

@@ -24,8 +24,7 @@ from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
 from hivemind.averaging.partition import DEFAULT_PART_SIZE_BYTES
 from hivemind.averaging.partition import DEFAULT_PART_SIZE_BYTES
 from hivemind.compression import CompressionBase, CompressionInfo, NoCompression, deserialize_torch_tensor
 from hivemind.compression import CompressionBase, CompressionInfo, NoCompression, deserialize_torch_tensor
 from hivemind.dht import DHT, DHTID
 from hivemind.dht import DHT, DHTID
-from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
-from hivemind.p2p.p2p_daemon_bindings.utils import ControlFailure, DispatchFailure
+from hivemind.p2p import P2P, P2PContext, P2PDaemonError, P2PHandlerError, PeerID, ServicerBase
 from hivemind.proto import averaging_pb2
 from hivemind.proto import averaging_pb2
 from hivemind.utils import MPFuture, TensorDescriptor, get_logger
 from hivemind.utils import MPFuture, TensorDescriptor, get_logger
 from hivemind.utils.asyncio import (
 from hivemind.utils.asyncio import (
@@ -350,6 +349,9 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             logger.exception("Averager shutdown has no effect: the process is already not alive")
             logger.exception("Averager shutdown has no effect: the process is already not alive")
 
 
     async def _shutdown(self, timeout: Optional[float]) -> None:
     async def _shutdown(self, timeout: Optional[float]) -> None:
+        if not self.client_mode:
+            await self.remove_p2p_handlers(self._p2p, namespace=self.prefix)
+
         remaining_tasks = set()
         remaining_tasks = set()
         for group in self._running_groups.values():
         for group in self._running_groups.values():
             remaining_tasks.update(group.finalize(cancel=True))
             remaining_tasks.update(group.finalize(cancel=True))
@@ -469,8 +471,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     asyncio.CancelledError,
                     asyncio.CancelledError,
                     asyncio.InvalidStateError,
                     asyncio.InvalidStateError,
                     P2PHandlerError,
                     P2PHandlerError,
-                    DispatchFailure,
-                    ControlFailure,
+                    P2PDaemonError,
                 ) as e:
                 ) as e:
                     if step.done() or not step.allow_retries or get_dht_time() >= step.deadline:
                     if step.done() or not step.allow_retries or get_dht_time() >= step.deadline:
                         if not step.cancelled():
                         if not step.cancelled():

+ 2 - 3
hivemind/averaging/matchmaking.py

@@ -13,8 +13,7 @@ from hivemind.averaging.control import StepControl
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.key_manager import GroupKey, GroupKeyManager
 from hivemind.averaging.key_manager import GroupKey, GroupKeyManager
 from hivemind.dht import DHT, DHTID
 from hivemind.dht import DHT, DHTID
-from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
-from hivemind.p2p.p2p_daemon_bindings.utils import ControlFailure, DispatchFailure
+from hivemind.p2p import P2P, P2PContext, P2PDaemonError, P2PHandlerError, PeerID, ServicerBase
 from hivemind.proto import averaging_pb2
 from hivemind.proto import averaging_pb2
 from hivemind.utils import DHTExpiration, TimedStorage, get_dht_time, get_logger, timed_storage
 from hivemind.utils import DHTExpiration, TimedStorage, get_dht_time, get_logger, timed_storage
 from hivemind.utils.asyncio import anext, cancel_and_wait
 from hivemind.utils.asyncio import anext, cancel_and_wait
@@ -239,7 +238,7 @@ class Matchmaking:
         except asyncio.TimeoutError:
         except asyncio.TimeoutError:
             logger.debug(f"{self} - potential leader {leader} did not respond within {self.request_timeout}")
             logger.debug(f"{self} - potential leader {leader} did not respond within {self.request_timeout}")
             return None
             return None
-        except (P2PHandlerError, ControlFailure, DispatchFailure, StopAsyncIteration) as e:
+        except (P2PDaemonError, P2PHandlerError, StopAsyncIteration) as e:
             logger.debug(f"{self} - failed to request potential leader {leader}:", exc_info=True)
             logger.debug(f"{self} - failed to request potential leader {leader}:", exc_info=True)
             return None
             return None
 
 

+ 1 - 0
hivemind/dht/node.py

@@ -271,6 +271,7 @@ class DHTNode:
     async def shutdown(self):
     async def shutdown(self):
         """Process existing requests, close all connections and stop the server"""
         """Process existing requests, close all connections and stop the server"""
         self.is_alive = False
         self.is_alive = False
+        await self.protocol.shutdown()
         if self._should_shutdown_p2p:
         if self._should_shutdown_p2p:
             await self.p2p.shutdown()
             await self.p2p.shutdown()
 
 

+ 5 - 1
hivemind/dht/protocol.py

@@ -70,7 +70,7 @@ class DHTProtocol(ServicerBase):
         self.record_validator = record_validator
         self.record_validator = record_validator
         self.authorizer = authorizer
         self.authorizer = authorizer
 
 
-        if not client_mode:
+        if not self.client_mode:
             await self.add_p2p_handlers(self.p2p, AuthRPCWrapper(self, AuthRole.SERVICER, self.authorizer))
             await self.add_p2p_handlers(self.p2p, AuthRPCWrapper(self, AuthRole.SERVICER, self.authorizer))
 
 
             self.node_info = dht_pb2.NodeInfo(node_id=node_id.to_bytes())
             self.node_info = dht_pb2.NodeInfo(node_id=node_id.to_bytes())
@@ -79,6 +79,10 @@ class DHTProtocol(ServicerBase):
             self.node_info = dht_pb2.NodeInfo()
             self.node_info = dht_pb2.NodeInfo()
         return self
         return self
 
 
+    async def shutdown(self) -> None:
+        if not self.client_mode:
+            await self.remove_p2p_handlers(self.p2p)
+
     def __init__(self, *, _initialized_with_create=False):
     def __init__(self, *, _initialized_with_create=False):
         """Internal init method. Please use DHTProtocol.create coroutine to spawn new protocol instances"""
         """Internal init method. Please use DHTProtocol.create coroutine to spawn new protocol instances"""
         assert _initialized_with_create, "Please use DHTProtocol.create coroutine to spawn new protocol instances"
         assert _initialized_with_create, "Please use DHTProtocol.create coroutine to spawn new protocol instances"

+ 47 - 8
hivemind/moe/server/connection_handler.py

@@ -28,36 +28,75 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
     :param module_backends: a dict [UID -> ModuleBackend] with all active experts
     :param module_backends: a dict [UID -> ModuleBackend] with all active experts
     """
     """
 
 
-    def __init__(self, dht: DHT, module_backends: Dict[str, ModuleBackend]):
+    def __init__(
+        self,
+        dht: DHT,
+        module_backends: Dict[str, ModuleBackend],
+        *,
+        balanced: bool = True,
+        shutdown_timeout: float = 3,
+        start: bool = False,
+    ):
         super().__init__()
         super().__init__()
         self.dht, self.module_backends = dht, module_backends
         self.dht, self.module_backends = dht, module_backends
+        self.balanced, self.shutdown_timeout = balanced, shutdown_timeout
         self._p2p: Optional[P2P] = None
         self._p2p: Optional[P2P] = None
 
 
+        self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=False)
         self.ready = MPFuture()
         self.ready = MPFuture()
 
 
+        if start:
+            self.run_in_background(await_ready=True)
+
     def run(self):
     def run(self):
         torch.set_num_threads(1)
         torch.set_num_threads(1)
         loop = switch_to_uvloop()
         loop = switch_to_uvloop()
+        stop = asyncio.Event()
+        loop.add_reader(self._inner_pipe.fileno(), stop.set)
 
 
         async def _run():
         async def _run():
             try:
             try:
                 self._p2p = await self.dht.replicate_p2p()
                 self._p2p = await self.dht.replicate_p2p()
-                await self.add_p2p_handlers(self._p2p, balanced=True)
-
-                # wait forever
-                await asyncio.Future()
-
+                await self.add_p2p_handlers(self._p2p, balanced=self.balanced)
+                self.ready.set_result(None)
             except Exception as e:
             except Exception as e:
+                logger.error("ConnectionHandler failed to start:", exc_info=True)
                 self.ready.set_exception(e)
                 self.ready.set_exception(e)
-                return
 
 
-        self.ready.set_result(None)
+            try:
+                await stop.wait()
+            finally:
+                await self.remove_p2p_handlers(self._p2p)
 
 
         try:
         try:
             loop.run_until_complete(_run())
             loop.run_until_complete(_run())
         except KeyboardInterrupt:
         except KeyboardInterrupt:
             logger.debug("Caught KeyboardInterrupt, shutting down")
             logger.debug("Caught KeyboardInterrupt, shutting down")
 
 
+    def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None:
+        """
+        Starts ConnectionHandler in a background process. If :await_ready:, this method will wait until
+        it is ready to process incoming requests or for :timeout: seconds max.
+        """
+        self.start()
+        if await_ready:
+            self.wait_until_ready(timeout)
+
+    def wait_until_ready(self, timeout: Optional[float] = None) -> None:
+        self.ready.result(timeout=timeout)
+
+    def shutdown(self):
+        if self.is_alive():
+            self._outer_pipe.send("_shutdown")
+            self.join(self.shutdown_timeout)
+            if self.is_alive():
+                logger.warning(
+                    "ConnectionHandler did not shut down within the grace period; terminating it the hard way"
+                )
+                self.terminate()
+        else:
+            logger.warning("ConnectionHandler shutdown had no effect, the process is already dead")
+
     async def rpc_info(self, request: runtime_pb2.ExpertUID, context: P2PContext) -> runtime_pb2.ExpertInfo:
     async def rpc_info(self, request: runtime_pb2.ExpertUID, context: P2PContext) -> runtime_pb2.ExpertInfo:
         module_info = self.module_backends[request.uid].get_info()
         module_info = self.module_backends[request.uid].get_info()
         return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(module_info))
         return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(module_info))

+ 2 - 2
hivemind/p2p/__init__.py

@@ -1,3 +1,3 @@
-from hivemind.p2p.p2p_daemon import P2P, P2PContext, P2PDaemonError, P2PHandlerError
-from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo
+from hivemind.p2p.p2p_daemon import P2P, P2PContext
+from hivemind.p2p.p2p_daemon_bindings import P2PDaemonError, P2PHandlerError, PeerID, PeerInfo
 from hivemind.p2p.servicer import ServicerBase, StubBase
 from hivemind.p2p.servicer import ServicerBase, StubBase

+ 16 - 0
hivemind/p2p/p2p_daemon.py

@@ -475,6 +475,19 @@ class P2P:
 
 
         await self._add_protobuf_stream_handler(name, _stream_handler, input_protobuf_type, balanced=balanced)
         await self._add_protobuf_stream_handler(name, _stream_handler, input_protobuf_type, balanced=balanced)
 
 
+    async def remove_protobuf_handler(
+        self,
+        name: str,
+        *,
+        stream_input: bool = False,
+        stream_output: bool = False,
+    ) -> None:
+        if not stream_input and not stream_output:
+            await self._client.remove_unary_handler(name)
+            return
+
+        await self.remove_binary_stream_handler(name)
+
     async def _add_protobuf_unary_handler(
     async def _add_protobuf_unary_handler(
         self,
         self,
         handle_name: str,
         handle_name: str,
@@ -553,6 +566,9 @@ class P2P:
             self._start_listening()
             self._start_listening()
         await self._client.stream_handler(name, handler, balanced)
         await self._client.stream_handler(name, handler, balanced)
 
 
+    async def remove_binary_stream_handler(self, name: str) -> None:
+        await self._client.remove_stream_handler(name)
+
     async def call_binary_stream_handler(
     async def call_binary_stream_handler(
         self, peer_id: PeerID, handler_name: str
         self, peer_id: PeerID, handler_name: str
     ) -> Tuple[StreamInfo, asyncio.StreamReader, asyncio.StreamWriter]:
     ) -> Tuple[StreamInfo, asyncio.StreamReader, asyncio.StreamWriter]:

+ 2 - 0
hivemind/p2p/p2p_daemon_bindings/__init__.py

@@ -0,0 +1,2 @@
+from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo
+from hivemind.p2p.p2p_daemon_bindings.utils import P2PDaemonError, P2PHandlerError

+ 55 - 22
hivemind/p2p/p2p_daemon_bindings/control.py

@@ -12,7 +12,14 @@ from uuid import UUID, uuid4
 from multiaddr import Multiaddr, protocols
 from multiaddr import Multiaddr, protocols
 
 
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
-from hivemind.p2p.p2p_daemon_bindings.utils import DispatchFailure, raise_if_failed, read_pbmsg_safe, write_pbmsg
+from hivemind.p2p.p2p_daemon_bindings.utils import (
+    DispatchFailure,
+    P2PDaemonError,
+    P2PHandlerError,
+    raise_if_failed,
+    read_pbmsg_safe,
+    write_pbmsg,
+)
 from hivemind.proto import p2pd_pb2 as p2pd_pb
 from hivemind.proto import p2pd_pb2 as p2pd_pb
 from hivemind.utils.logging import get_logger
 from hivemind.utils.logging import get_logger
 
 
@@ -249,20 +256,37 @@ class ControlClient:
         self._read_task = asyncio.create_task(self._read_from_persistent_conn(reader))
         self._read_task = asyncio.create_task(self._read_from_persistent_conn(reader))
         self._write_task = asyncio.create_task(self._write_to_persistent_conn(writer))
         self._write_task = asyncio.create_task(self._write_to_persistent_conn(writer))
 
 
-    async def add_unary_handler(self, proto: str, handler: TUnaryHandler, balanced: bool = False):
+    async def add_unary_handler(self, proto: str, handler: TUnaryHandler, balanced: bool = False) -> None:
+        if proto in self.unary_handlers:
+            raise P2PDaemonError(f"Handler for protocol {proto} already registered")
+        self.unary_handlers[proto] = handler
+
         call_id = uuid4()
         call_id = uuid4()
+        req = p2pd_pb.PersistentConnectionRequest(
+            callId=call_id.bytes,
+            addUnaryHandler=p2pd_pb.AddUnaryHandlerRequest(proto=proto, balanced=balanced),
+        )
 
 
-        add_unary_handler_req = p2pd_pb.AddUnaryHandlerRequest(proto=proto, balanced=balanced)
-        req = p2pd_pb.PersistentConnectionRequest(callId=call_id.bytes, addUnaryHandler=add_unary_handler_req)
+        self._pending_calls[call_id] = asyncio.Future()
+        await self._pending_messages.put(req)
+        await self._pending_calls[call_id]
 
 
-        if self.unary_handlers.get(proto):
-            raise P2PDaemonError(f"Handler for protocol {proto} already registered")
-        self.unary_handlers[proto] = handler
+    async def remove_unary_handler(self, proto: str) -> None:
+        if proto not in self.unary_handlers:
+            raise P2PDaemonError(f"Handler for protocol {proto} is not registered")
+
+        call_id = uuid4()
+        req = p2pd_pb.PersistentConnectionRequest(
+            callId=call_id.bytes,
+            removeUnaryHandler=p2pd_pb.RemoveUnaryHandlerRequest(proto=proto),
+        )
 
 
         self._pending_calls[call_id] = asyncio.Future()
         self._pending_calls[call_id] = asyncio.Future()
         await self._pending_messages.put(req)
         await self._pending_messages.put(req)
         await self._pending_calls[call_id]
         await self._pending_calls[call_id]
 
 
+        del self.unary_handlers[proto]
+
     async def call_unary_handler(self, peer_id: PeerID, proto: str, data: bytes) -> bytes:
     async def call_unary_handler(self, peer_id: PeerID, proto: str, data: bytes) -> bytes:
         call_id = uuid4()
         call_id = uuid4()
         call_unary_req = p2pd_pb.CallUnaryRequest(
         call_unary_req = p2pd_pb.CallUnaryRequest(
@@ -362,13 +386,18 @@ class ControlClient:
         return stream_info, reader, writer
         return stream_info, reader, writer
 
 
     async def stream_handler(self, proto: str, handler_cb: StreamHandler, balanced: bool = False) -> None:
     async def stream_handler(self, proto: str, handler_cb: StreamHandler, balanced: bool = False) -> None:
+        self.handlers[proto] = handler_cb
+
         reader, writer = await self.daemon_connector.open_connection()
         reader, writer = await self.daemon_connector.open_connection()
 
 
-        listen_path_maddr_bytes = self.listen_maddr.to_bytes()
-        stream_handler_req = p2pd_pb.StreamHandlerRequest(
-            addr=listen_path_maddr_bytes, proto=[proto], balanced=balanced
+        req = p2pd_pb.Request(
+            type=p2pd_pb.Request.STREAM_HANDLER,
+            streamHandler=p2pd_pb.StreamHandlerRequest(
+                addr=self.listen_maddr.to_bytes(),
+                proto=[proto],
+                balanced=balanced,
+            ),
         )
         )
-        req = p2pd_pb.Request(type=p2pd_pb.Request.STREAM_HANDLER, streamHandler=stream_handler_req)
         await write_pbmsg(writer, req)
         await write_pbmsg(writer, req)
 
 
         resp = p2pd_pb.Response()  # type: ignore
         resp = p2pd_pb.Response()  # type: ignore
@@ -376,17 +405,21 @@ class ControlClient:
         writer.close()
         writer.close()
         raise_if_failed(resp)
         raise_if_failed(resp)
 
 
-        # if success, add the handler to the dict
-        self.handlers[proto] = handler_cb
-
+    async def remove_stream_handler(self, proto: str) -> None:
+        reader, writer = await self.daemon_connector.open_connection()
 
 
-class P2PHandlerError(Exception):
-    """
-    Raised if remote handled a request with an exception
-    """
+        req = p2pd_pb.Request(
+            type=p2pd_pb.Request.REMOVE_STREAM_HANDLER,
+            removeStreamHandler=p2pd_pb.RemoveStreamHandlerRequest(
+                addr=self.listen_maddr.to_bytes(),
+                proto=[proto],
+            ),
+        )
+        await write_pbmsg(writer, req)
 
 
+        resp = p2pd_pb.Response()  # type: ignore
+        await read_pbmsg_safe(reader, resp)
+        writer.close()
+        raise_if_failed(resp)
 
 
-class P2PDaemonError(Exception):
-    """
-    Raised if daemon failed to handle request
-    """
+        del self.handlers[proto]

+ 7 - 1
hivemind/p2p/p2p_daemon_bindings/p2pclient.py

@@ -61,9 +61,12 @@ class Client:
         async with self.control.listen():
         async with self.control.listen():
             yield self
             yield self
 
 
-    async def add_unary_handler(self, proto: str, handler: TUnaryHandler, balanced: bool = False):
+    async def add_unary_handler(self, proto: str, handler: TUnaryHandler, balanced: bool = False) -> None:
         await self.control.add_unary_handler(proto, handler, balanced=balanced)
         await self.control.add_unary_handler(proto, handler, balanced=balanced)
 
 
+    async def remove_unary_handler(self, proto: str) -> None:
+        await self.control.remove_unary_handler(proto)
+
     async def call_unary_handler(self, peer_id: PeerID, proto: str, data: bytes) -> bytes:
     async def call_unary_handler(self, peer_id: PeerID, proto: str, data: bytes) -> bytes:
         return await self.control.call_unary_handler(peer_id, proto, data)
         return await self.control.call_unary_handler(peer_id, proto, data)
 
 
@@ -114,3 +117,6 @@ class Client:
         :return:
         :return:
         """
         """
         await self.control.stream_handler(proto=proto, handler_cb=handler_cb, balanced=balanced)
         await self.control.stream_handler(proto=proto, handler_cb=handler_cb, balanced=balanced)
+
+    async def remove_stream_handler(self, proto: str) -> None:
+        await self.control.remove_stream_handler(proto=proto)

+ 14 - 2
hivemind/p2p/p2p_daemon_bindings/utils.py

@@ -13,11 +13,23 @@ from hivemind.proto import p2pd_pb2 as p2pd_pb
 DEFAULT_MAX_BITS: int = 64
 DEFAULT_MAX_BITS: int = 64
 
 
 
 
-class ControlFailure(Exception):
+class P2PHandlerError(Exception):
+    """
+    Raised if remote handled a request with an exception
+    """
+
+
+class P2PDaemonError(Exception):
+    """
+    Raised if daemon failed to handle request
+    """
+
+
+class ControlFailure(P2PDaemonError):
     pass
     pass
 
 
 
 
-class DispatchFailure(Exception):
+class DispatchFailure(P2PDaemonError):
     pass
     pass
 
 
 
 

+ 14 - 0
hivemind/p2p/servicer.py

@@ -124,6 +124,20 @@ class ServicerBase:
             ]
             ]
         )
         )
 
 
+    async def remove_p2p_handlers(self, p2p: P2P, *, namespace: Optional[str] = None) -> None:
+        self._collect_rpc_handlers()
+
+        await asyncio.gather(
+            *[
+                p2p.remove_protobuf_handler(
+                    self._get_handle_name(namespace, handler.method_name),
+                    stream_input=handler.stream_input,
+                    stream_output=handler.stream_output,
+                )
+                for handler in self._rpc_handlers
+            ]
+        )
+
     @classmethod
     @classmethod
     def get_stub(cls, p2p: P2P, peer: PeerID, *, namespace: Optional[str] = None) -> StubBase:
     def get_stub(cls, p2p: P2P, peer: PeerID, *, namespace: Optional[str] = None) -> StubBase:
         cls._collect_rpc_handlers()
         cls._collect_rpc_handlers()

+ 12 - 1
hivemind/proto/p2pd.proto

@@ -12,12 +12,12 @@ message Request {
     CONNECT                  = 1;
     CONNECT                  = 1;
     STREAM_OPEN              = 2;
     STREAM_OPEN              = 2;
     STREAM_HANDLER           = 3;
     STREAM_HANDLER           = 3;
+    REMOVE_STREAM_HANDLER    = 10;
     DHT                      = 4;
     DHT                      = 4;
     LIST_PEERS               = 5;
     LIST_PEERS               = 5;
     CONNMANAGER              = 6;
     CONNMANAGER              = 6;
     DISCONNECT               = 7;
     DISCONNECT               = 7;
     PUBSUB                   = 8;
     PUBSUB                   = 8;
-
     PERSISTENT_CONN_UPGRADE  = 9;
     PERSISTENT_CONN_UPGRADE  = 9;
   }
   }
 
 
@@ -26,6 +26,7 @@ message Request {
   optional ConnectRequest connect = 2;
   optional ConnectRequest connect = 2;
   optional StreamOpenRequest streamOpen = 3;
   optional StreamOpenRequest streamOpen = 3;
   optional StreamHandlerRequest streamHandler = 4;
   optional StreamHandlerRequest streamHandler = 4;
+  optional RemoveStreamHandlerRequest removeStreamHandler = 9;
   optional DHTRequest dht = 5;
   optional DHTRequest dht = 5;
   optional ConnManagerRequest connManager = 6;
   optional ConnManagerRequest connManager = 6;
   optional DisconnectRequest disconnect = 7;
   optional DisconnectRequest disconnect = 7;
@@ -52,6 +53,7 @@ message PersistentConnectionRequest {
 
 
   oneof message {
   oneof message {
     AddUnaryHandlerRequest addUnaryHandler = 2;
     AddUnaryHandlerRequest addUnaryHandler = 2;
+    RemoveUnaryHandlerRequest removeUnaryHandler = 6;
     CallUnaryRequest  callUnary = 3;
     CallUnaryRequest  callUnary = 3;
     CallUnaryResponse unaryResponse = 4;
     CallUnaryResponse unaryResponse = 4;
     Cancel cancel = 5;
     Cancel cancel = 5;
@@ -93,6 +95,11 @@ message StreamHandlerRequest {
   required bool balanced = 3;
   required bool balanced = 3;
 }
 }
 
 
+message RemoveStreamHandlerRequest {
+  required bytes addr = 1;
+  repeated string proto = 2;
+}
+
 message ErrorResponse {
 message ErrorResponse {
   required string msg = 1;
   required string msg = 1;
 }
 }
@@ -205,6 +212,10 @@ message AddUnaryHandlerRequest {
   required bool balanced = 2;
   required bool balanced = 2;
 }
 }
 
 
+message RemoveUnaryHandlerRequest {
+  required string proto = 1;
+}
+
 message DaemonError {
 message DaemonError {
   optional string message = 1;
   optional string message = 1;
 }
 }

+ 2 - 2
setup.py

@@ -13,14 +13,14 @@ from setuptools import find_packages, setup
 from setuptools.command.build_py import build_py
 from setuptools.command.build_py import build_py
 from setuptools.command.develop import develop
 from setuptools.command.develop import develop
 
 
-P2PD_VERSION = "v0.3.10"
+P2PD_VERSION = "v0.3.11"
 
 
 P2PD_SOURCE_URL = f"https://github.com/learning-at-home/go-libp2p-daemon/archive/refs/tags/{P2PD_VERSION}.tar.gz"
 P2PD_SOURCE_URL = f"https://github.com/learning-at-home/go-libp2p-daemon/archive/refs/tags/{P2PD_VERSION}.tar.gz"
 P2PD_BINARY_URL = f"https://github.com/learning-at-home/go-libp2p-daemon/releases/download/{P2PD_VERSION}/"
 P2PD_BINARY_URL = f"https://github.com/learning-at-home/go-libp2p-daemon/releases/download/{P2PD_VERSION}/"
 
 
 # The value is sha256 of the binary from the release page
 # The value is sha256 of the binary from the release page
 EXECUTABLES = {
 EXECUTABLES = {
-    "p2pd": "a9728685fd020dd5f0292e64b82740ac1643bbe9f793ec6d0b765c7efc28bcec",
+    "p2pd": "1252a2a2095040cef8e317f5801df8b8c93559711783a2496a0aff2f3e177e39",
 }
 }
 
 
 
 

+ 31 - 31
tests/test_connection_handler.py

@@ -20,19 +20,25 @@ from hivemind.utils.streaming import split_for_streaming
 from hivemind.utils.tensor_descr import BatchTensorDescriptor
 from hivemind.utils.tensor_descr import BatchTensorDescriptor
 
 
 
 
-@pytest.mark.forked
-@pytest.mark.asyncio
-async def test_connection_handler_info():
-    handler = ConnectionHandler(
-        DHT(start=True),
-        dict(expert1=DummyModuleBackend("expert1", k=1), expert2=DummyModuleBackend("expert2", k=2)),
-    )
-    handler.start()
+@pytest.fixture
+async def client_stub():
+    handler_dht = DHT(start=True)
+    module_backends = {"expert1": DummyModuleBackend("expert1", k=1), "expert2": DummyModuleBackend("expert2", k=2)}
+    handler = ConnectionHandler(handler_dht, module_backends, start=True)
 
 
     client_dht = DHT(start=True, client_mode=True, initial_peers=handler.dht.get_visible_maddrs())
     client_dht = DHT(start=True, client_mode=True, initial_peers=handler.dht.get_visible_maddrs())
     client_stub = ConnectionHandler.get_stub(await client_dht.replicate_p2p(), handler.dht.peer_id)
     client_stub = ConnectionHandler.get_stub(await client_dht.replicate_p2p(), handler.dht.peer_id)
 
 
-    # info
+    yield client_stub
+
+    client_dht.shutdown()
+    handler.shutdown()
+    handler_dht.shutdown()
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_connection_handler_info(client_stub):
     response = await client_stub.rpc_info(runtime_pb2.ExpertUID(uid="expert1"))
     response = await client_stub.rpc_info(runtime_pb2.ExpertUID(uid="expert1"))
     assert MSGPackSerializer.loads(response.serialized_info) == dict(name="expert1")
     assert MSGPackSerializer.loads(response.serialized_info) == dict(name="expert1")
 
 
@@ -45,16 +51,7 @@ async def test_connection_handler_info():
 
 
 @pytest.mark.forked
 @pytest.mark.forked
 @pytest.mark.asyncio
 @pytest.mark.asyncio
-async def test_connection_handler_forward():
-    handler = ConnectionHandler(
-        DHT(start=True),
-        dict(expert1=DummyModuleBackend("expert1", k=1), expert2=DummyModuleBackend("expert2", k=2)),
-    )
-    handler.start()
-
-    client_dht = DHT(start=True, client_mode=True, initial_peers=handler.dht.get_visible_maddrs())
-    client_stub = ConnectionHandler.get_stub(await client_dht.replicate_p2p(), handler.dht.peer_id)
-
+async def test_connection_handler_forward(client_stub):
     inputs = torch.randn(1, 2)
     inputs = torch.randn(1, 2)
     inputs_long = torch.randn(2**21, 2)
     inputs_long = torch.randn(2**21, 2)
 
 
@@ -106,16 +103,7 @@ async def test_connection_handler_forward():
 
 
 @pytest.mark.forked
 @pytest.mark.forked
 @pytest.mark.asyncio
 @pytest.mark.asyncio
-async def test_connection_handler_backward():
-    handler = ConnectionHandler(
-        DHT(start=True),
-        dict(expert1=DummyModuleBackend("expert1", k=1), expert2=DummyModuleBackend("expert2", k=2)),
-    )
-    handler.start()
-
-    client_dht = DHT(start=True, client_mode=True, initial_peers=handler.dht.get_visible_maddrs())
-    client_stub = ConnectionHandler.get_stub(await client_dht.replicate_p2p(), handler.dht.peer_id)
-
+async def test_connection_handler_backward(client_stub):
     inputs = torch.randn(1, 2)
     inputs = torch.randn(1, 2)
     inputs_long = torch.randn(2**21, 2)
     inputs_long = torch.randn(2**21, 2)
 
 
@@ -165,8 +153,20 @@ async def test_connection_handler_backward():
     # check that handler did not crash after failed request
     # check that handler did not crash after failed request
     await client_stub.rpc_forward(runtime_pb2.ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(inputs)]))
     await client_stub.rpc_forward(runtime_pb2.ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(inputs)]))
 
 
-    handler.terminate()
-    handler.join()
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_connection_handler_shutdown():
+    # Here, all handlers will have the common hivemind.DHT and hivemind.P2P instances
+    handler_dht = DHT(start=True)
+    module_backends = {"expert1": DummyModuleBackend("expert1", k=1), "expert2": DummyModuleBackend("expert2", k=2)}
+
+    for _ in range(3):
+        handler = ConnectionHandler(handler_dht, module_backends, balanced=False, start=True)
+        # The line above would raise an exception if the previous handlers were not removed from hivemind.P2P
+        handler.shutdown()
+
+    handler_dht.shutdown()
 
 
 
 
 class DummyPool(TaskPool):
 class DummyPool(TaskPool):

+ 82 - 42
tests/test_p2p_servicer.py

@@ -3,7 +3,7 @@ from typing import AsyncIterator
 
 
 import pytest
 import pytest
 
 
-from hivemind.p2p import P2P, P2PContext, ServicerBase
+from hivemind.p2p import P2P, P2PContext, P2PDaemonError, ServicerBase
 from hivemind.proto import test_pb2
 from hivemind.proto import test_pb2
 from hivemind.utils.asyncio import anext
 from hivemind.utils.asyncio import anext
 
 
@@ -17,35 +17,37 @@ async def server_client():
     await asyncio.gather(server.shutdown(), client.shutdown())
     await asyncio.gather(server.shutdown(), client.shutdown())
 
 
 
 
+class UnaryUnaryServicer(ServicerBase):
+    async def rpc_square(self, request: test_pb2.TestRequest, _context: P2PContext) -> test_pb2.TestResponse:
+        return test_pb2.TestResponse(number=request.number**2)
+
+
 @pytest.mark.asyncio
 @pytest.mark.asyncio
 async def test_unary_unary(server_client):
 async def test_unary_unary(server_client):
-    class ExampleServicer(ServicerBase):
-        async def rpc_square(self, request: test_pb2.TestRequest, _context: P2PContext) -> test_pb2.TestResponse:
-            return test_pb2.TestResponse(number=request.number**2)
-
     server, client = server_client
     server, client = server_client
-    servicer = ExampleServicer()
+    servicer = UnaryUnaryServicer()
     await servicer.add_p2p_handlers(server)
     await servicer.add_p2p_handlers(server)
-    stub = ExampleServicer.get_stub(client, server.peer_id)
+    stub = UnaryUnaryServicer.get_stub(client, server.peer_id)
 
 
     assert await stub.rpc_square(test_pb2.TestRequest(number=10)) == test_pb2.TestResponse(number=100)
     assert await stub.rpc_square(test_pb2.TestRequest(number=10)) == test_pb2.TestResponse(number=100)
 
 
 
 
+class StreamUnaryServicer(ServicerBase):
+    async def rpc_sum(
+        self, stream: AsyncIterator[test_pb2.TestRequest], _context: P2PContext
+    ) -> test_pb2.TestResponse:
+        result = 0
+        async for item in stream:
+            result += item.number
+        return test_pb2.TestResponse(number=result)
+
+
 @pytest.mark.asyncio
 @pytest.mark.asyncio
 async def test_stream_unary(server_client):
 async def test_stream_unary(server_client):
-    class ExampleServicer(ServicerBase):
-        async def rpc_sum(
-            self, stream: AsyncIterator[test_pb2.TestRequest], _context: P2PContext
-        ) -> test_pb2.TestResponse:
-            result = 0
-            async for item in stream:
-                result += item.number
-            return test_pb2.TestResponse(number=result)
-
     server, client = server_client
     server, client = server_client
-    servicer = ExampleServicer()
+    servicer = StreamUnaryServicer()
     await servicer.add_p2p_handlers(server)
     await servicer.add_p2p_handlers(server)
-    stub = ExampleServicer.get_stub(client, server.peer_id)
+    stub = StreamUnaryServicer.get_stub(client, server.peer_id)
 
 
     async def generate_requests() -> AsyncIterator[test_pb2.TestRequest]:
     async def generate_requests() -> AsyncIterator[test_pb2.TestRequest]:
         for i in range(10):
         for i in range(10):
@@ -54,42 +56,40 @@ async def test_stream_unary(server_client):
     assert await stub.rpc_sum(generate_requests()) == test_pb2.TestResponse(number=45)
     assert await stub.rpc_sum(generate_requests()) == test_pb2.TestResponse(number=45)
 
 
 
 
+class UnaryStreamServicer(ServicerBase):
+    async def rpc_count(
+        self, request: test_pb2.TestRequest, _context: P2PContext
+    ) -> AsyncIterator[test_pb2.TestResponse]:
+        for i in range(request.number):
+            yield test_pb2.TestResponse(number=i)
+
+
 @pytest.mark.asyncio
 @pytest.mark.asyncio
 async def test_unary_stream(server_client):
 async def test_unary_stream(server_client):
-    class ExampleServicer(ServicerBase):
-        async def rpc_count(
-            self, request: test_pb2.TestRequest, _context: P2PContext
-        ) -> AsyncIterator[test_pb2.TestResponse]:
-            for i in range(request.number):
-                yield test_pb2.TestResponse(number=i)
-
     server, client = server_client
     server, client = server_client
-    servicer = ExampleServicer()
+    servicer = UnaryStreamServicer()
     await servicer.add_p2p_handlers(server)
     await servicer.add_p2p_handlers(server)
-    stub = ExampleServicer.get_stub(client, server.peer_id)
+    stub = UnaryStreamServicer.get_stub(client, server.peer_id)
 
 
     stream = await stub.rpc_count(test_pb2.TestRequest(number=10))
     stream = await stub.rpc_count(test_pb2.TestRequest(number=10))
-    i = 0
-    async for item in stream:
-        assert item == test_pb2.TestResponse(number=i)
-        i += 1
-    assert i == 10
+    assert [item.number async for item in stream] == list(range(10))
+
+
+class StreamStreamServicer(ServicerBase):
+    async def rpc_powers(
+        self, stream: AsyncIterator[test_pb2.TestRequest], _context: P2PContext
+    ) -> AsyncIterator[test_pb2.TestResponse]:
+        async for item in stream:
+            yield test_pb2.TestResponse(number=item.number**2)
+            yield test_pb2.TestResponse(number=item.number**3)
 
 
 
 
 @pytest.mark.asyncio
 @pytest.mark.asyncio
 async def test_stream_stream(server_client):
 async def test_stream_stream(server_client):
-    class ExampleServicer(ServicerBase):
-        async def rpc_powers(
-            self, stream: AsyncIterator[test_pb2.TestRequest], _context: P2PContext
-        ) -> AsyncIterator[test_pb2.TestResponse]:
-            async for item in stream:
-                yield test_pb2.TestResponse(number=item.number**2)
-                yield test_pb2.TestResponse(number=item.number**3)
-
     server, client = server_client
     server, client = server_client
-    servicer = ExampleServicer()
+    servicer = StreamStreamServicer()
     await servicer.add_p2p_handlers(server)
     await servicer.add_p2p_handlers(server)
-    stub = ExampleServicer.get_stub(client, server.peer_id)
+    stub = StreamStreamServicer.get_stub(client, server.peer_id)
 
 
     async def generate_requests() -> AsyncIterator[test_pb2.TestRequest]:
     async def generate_requests() -> AsyncIterator[test_pb2.TestRequest]:
         for i in range(10):
         for i in range(10):
@@ -153,3 +153,43 @@ async def test_unary_stream_cancel(server_client, cancel_reason):
 
 
     await asyncio.sleep(0.25)
     await asyncio.sleep(0.25)
     assert handler_cancelled
     assert handler_cancelled
+
+
+@pytest.mark.asyncio
+async def test_removing_unary_handlers(server_client):
+    server1, client = server_client
+    server2 = await P2P.replicate(server1.daemon_listen_maddr)
+    servicer = UnaryUnaryServicer()
+    stub = UnaryUnaryServicer.get_stub(client, server1.peer_id)
+
+    for server in [server1, server2, server1]:
+        await servicer.add_p2p_handlers(server)
+        assert await stub.rpc_square(test_pb2.TestRequest(number=10)) == test_pb2.TestResponse(number=100)
+
+        await servicer.remove_p2p_handlers(server)
+        with pytest.raises((P2PDaemonError, ConnectionError)):
+            await stub.rpc_square(test_pb2.TestRequest(number=10))
+
+    await asyncio.gather(server2.shutdown())
+
+
+@pytest.mark.asyncio
+async def test_removing_stream_handlers(server_client):
+    server1, client = server_client
+    server2 = await P2P.replicate(server1.daemon_listen_maddr)
+    servicer = UnaryStreamServicer()
+    stub = UnaryStreamServicer.get_stub(client, server1.peer_id)
+
+    for server in [server1, server2, server1]:
+        await servicer.add_p2p_handlers(server)
+        stream = await stub.rpc_count(test_pb2.TestRequest(number=10))
+        assert [item.number async for item in stream] == list(range(10))
+
+        await servicer.remove_p2p_handlers(server)
+        with pytest.raises((P2PDaemonError, ConnectionError)):
+            stream = await stub.rpc_count(test_pb2.TestRequest(number=10))
+            outputs = [item.number async for item in stream]
+            if not outputs:
+                raise P2PDaemonError("Daemon has reset the connection")
+
+    await asyncio.gather(server2.shutdown())