Bladeren bron

Add balanced RPC handlers for the connection handler

Pavel Samygin 3 jaren geleden
bovenliggende
commit
5eab59bafb

+ 1 - 1
benchmarks/benchmark_throughput_p2p.py

@@ -134,7 +134,7 @@ def benchmark_throughput(
         server = hivemind.moe.Server(
             dht=server_dht,
             expert_backends=experts,
-            num_connection_handlers=1,  # TODO: support greater number
+            num_connection_handlers=num_handlers,
             device=device,
         )
         server.start()

+ 1 - 1
hivemind/moe/server/connection_handler.py

@@ -41,7 +41,7 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
         async def _run():
             try:
                 self._p2p = await self.dht.replicate_p2p()
-                await self.add_p2p_handlers(self._p2p)
+                await self.add_p2p_handlers(self._p2p, balanced=True)
 
                 await asyncio.Future()
 

+ 1 - 1
hivemind/moe/server/server.py

@@ -67,7 +67,7 @@ class Server(threading.Thread):
         super().__init__()
         self.dht, self.experts, self.update_period = dht, expert_backends, update_period
 
-        self.conn_handlers = [ConnectionHandler(dht, self.experts) for _ in range(1)]
+        self.conn_handlers = [ConnectionHandler(dht, self.experts) for _ in range(num_connection_handlers)]
         if checkpoint_dir is not None:
             self.checkpoint_saver = CheckpointSaver(expert_backends, checkpoint_dir, update_period)
         else:

+ 9 - 6
hivemind/p2p/p2p_daemon.py

@@ -315,6 +315,7 @@ class P2P:
         handler: Callable[[TInputStream, P2PContext], TOutputStream],
         input_protobuf_type: Type[Message],
         max_prefetch: int = 5,
+        balanced: bool = False,
     ) -> None:
         """
         :param max_prefetch: Maximum number of items to prefetch from the request stream.
@@ -379,7 +380,7 @@ class P2P:
                 finally:
                     processing_task.cancel()
 
-        await self.add_binary_stream_handler(name, _handle_stream)
+        await self.add_binary_stream_handler(name, _handle_stream, balanced=balanced)
 
     async def _iterate_protobuf_stream_handler(
         self, peer_id: PeerID, name: str, requests: TInputStream, output_protobuf_type: Type[Message]
@@ -421,6 +422,7 @@ class P2P:
         *,
         stream_input: bool = False,
         stream_output: bool = False,
+        balanced: bool = False,
     ) -> None:
         """
         :param stream_input: If True, assume ``handler`` to take ``TInputStream``
@@ -430,7 +432,7 @@ class P2P:
         """
 
         if not stream_input and not stream_output:
-            await self._add_protobuf_unary_handler(name, handler, input_protobuf_type)
+            await self._add_protobuf_unary_handler(name, handler, input_protobuf_type, balanced=balanced)
             return
 
         async def _stream_handler(requests: P2P.TInputStream, context: P2PContext) -> P2P.TOutputStream:
@@ -443,13 +445,14 @@ class P2P:
             else:
                 yield await output
 
-        await self._add_protobuf_stream_handler(name, _stream_handler, input_protobuf_type)
+        await self._add_protobuf_stream_handler(name, _stream_handler, input_protobuf_type, balanced=balanced)
 
     async def _add_protobuf_unary_handler(
         self,
         handle_name: str,
         handler: Callable[[TInputProtobuf, P2PContext], Awaitable[TOutputProtobuf]],
         input_protobuf_type: Type[Message],
+        balanced: bool = False,
     ) -> None:
         """
         Register a request-response (unary) handler. Unary requests and responses
@@ -471,7 +474,7 @@ class P2P:
             response = await handler(input_serialized, context)
             return response.SerializeToString()
 
-        await self._client.add_unary_handler(handle_name, _unary_handler)
+        await self._client.add_unary_handler(handle_name, _unary_handler, balanced=balanced)
 
     async def call_protobuf_handler(
         self,
@@ -515,10 +518,10 @@ class P2P:
 
         self._listen_task = asyncio.create_task(listen())
 
-    async def add_binary_stream_handler(self, name: str, handler: p2pclient.StreamHandler) -> None:
+    async def add_binary_stream_handler(self, name: str, handler: p2pclient.StreamHandler, balanced: bool = False) -> None:
         if self._listen_task is None:
             self._start_listening()
-        await self._client.stream_handler(name, handler)
+        await self._client.stream_handler(name, handler, balanced)
 
     async def call_binary_stream_handler(
         self, peer_id: PeerID, handler_name: str

+ 4 - 4
hivemind/p2p/p2p_daemon_bindings/control.py

@@ -246,10 +246,10 @@ class ControlClient:
         self._read_task = asyncio.create_task(self._read_from_persistent_conn(reader))
         self._write_task = asyncio.create_task(self._write_to_persistent_conn(writer))
 
-    async def add_unary_handler(self, proto: str, handler: TUnaryHandler):
+    async def add_unary_handler(self, proto: str, handler: TUnaryHandler, balanced: bool = False):
         call_id = uuid4()
 
-        add_unary_handler_req = p2pd_pb.AddUnaryHandlerRequest(proto=proto)
+        add_unary_handler_req = p2pd_pb.AddUnaryHandlerRequest(proto=proto, balanced=balanced)
         req = p2pd_pb.PersistentConnectionRequest(callId=call_id.bytes, addUnaryHandler=add_unary_handler_req)
 
         if self.unary_handlers.get(proto):
@@ -358,11 +358,11 @@ class ControlClient:
 
         return stream_info, reader, writer
 
-    async def stream_handler(self, proto: str, handler_cb: StreamHandler) -> None:
+    async def stream_handler(self, proto: str, handler_cb: StreamHandler, balanced: bool = False) -> None:
         reader, writer = await self.daemon_connector.open_connection()
 
         listen_path_maddr_bytes = self.listen_maddr.to_bytes()
-        stream_handler_req = p2pd_pb.StreamHandlerRequest(addr=listen_path_maddr_bytes, proto=[proto])
+        stream_handler_req = p2pd_pb.StreamHandlerRequest(addr=listen_path_maddr_bytes, proto=[proto], balanced=balanced)
         req = p2pd_pb.Request(type=p2pd_pb.Request.STREAM_HANDLER, streamHandler=stream_handler_req)
         await write_pbmsg(writer, req)
 

+ 4 - 4
hivemind/p2p/p2p_daemon_bindings/p2pclient.py

@@ -61,8 +61,8 @@ class Client:
         async with self.control.listen():
             yield self
 
-    async def add_unary_handler(self, proto: str, handler: TUnaryHandler):
-        await self.control.add_unary_handler(proto, handler)
+    async def add_unary_handler(self, proto: str, handler: TUnaryHandler, balanced: bool = False):
+        await self.control.add_unary_handler(proto, handler, balanced=balanced)
 
     async def call_unary_handler(self, peer_id: PeerID, proto: str, data: bytes) -> bytes:
         return await self.control.call_unary_handler(peer_id, proto, data)
@@ -105,11 +105,11 @@ class Client:
         """
         return await self.control.stream_open(peer_id=peer_id, protocols=protocols)
 
-    async def stream_handler(self, proto: str, handler_cb: StreamHandler) -> None:
+    async def stream_handler(self, proto: str, handler_cb: StreamHandler, balanced: bool = False) -> None:
         """
         Register a stream handler
         :param proto: protocols that handler serves
         :param handler_cb: handler callback
         :return:
         """
-        await self.control.stream_handler(proto=proto, handler_cb=handler_cb)
+        await self.control.stream_handler(proto=proto, handler_cb=handler_cb, balanced=balanced)

+ 2 - 1
hivemind/p2p/servicer.py

@@ -104,7 +104,7 @@ class ServicerBase:
         caller.__name__ = handler.method_name
         return caller
 
-    async def add_p2p_handlers(self, p2p: P2P, wrapper: Any = None, *, namespace: Optional[str] = None) -> None:
+    async def add_p2p_handlers(self, p2p: P2P, wrapper: Any = None, *, namespace: Optional[str] = None, balanced: bool = False) -> None:
         self._collect_rpc_handlers()
 
         servicer = self if wrapper is None else wrapper
@@ -116,6 +116,7 @@ class ServicerBase:
                     handler.request_type,
                     stream_input=handler.stream_input,
                     stream_output=handler.stream_output,
+                    balanced=balanced
                 )
                 for handler in self._rpc_handlers
             ]

+ 3 - 1
hivemind/proto/p2pd.proto

@@ -15,7 +15,7 @@ message Request {
     DHT                      = 4;
     LIST_PEERS               = 5;
     CONNMANAGER              = 6;
-    DISCONNECT               = 7;      
+    DISCONNECT               = 7;
     PUBSUB                   = 8;
 
     PERSISTENT_CONN_UPGRADE  = 9;
@@ -90,6 +90,7 @@ message StreamOpenRequest {
 message StreamHandlerRequest {
   required bytes addr = 1;
   repeated string proto = 2;
+  required bool balanced = 3;
 }
 
 message ErrorResponse {
@@ -201,6 +202,7 @@ message CallUnaryResponse {
 
 message AddUnaryHandlerRequest {
   required string proto = 1;
+  required bool balanced = 2;
 }
 
 message DaemonError {