Denis Mazur 4 years ago
parent
commit
f5a106a2e7
2 changed files with 28 additions and 33 deletions
  1. +2 -2
      hivemind/moe/server/__init__.py
  2. +26 -31
      hivemind/moe/server/connection_handler.py

+ 2 - 2
hivemind/moe/server/__init__.py

@@ -71,7 +71,7 @@ class Server(threading.Thread):
             listen_on = replace_port(listen_on, new_port=get_free_port())
         self.listen_on, self.port = listen_on, get_port(listen_on)
 
-        self.conn_handlers = [ConnectionHandler(listen_on, self.experts) for _ in range(num_connection_handlers)]
+        self.conn_handlers = [ConnectionHandler(dht, self.experts) for _ in range(1)]
         if checkpoint_dir is not None:
             self.checkpoint_saver = CheckpointSaver(expert_backends, checkpoint_dir, update_period)
         else:
@@ -253,7 +253,7 @@ class Server(threading.Thread):
         for process in self.conn_handlers:
             if not process.is_alive():
                 process.start()
-            process.ready.wait()
+            process.ready.result()
 
         try:
             self.runtime.run()
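
The readiness change above is the subtle part of this hunk: `mp.Event.wait()` only tells the parent that a flag was set, whereas `MPFuture.result()` also re-raises any exception the child recorded with `set_exception()`, so a handler that dies during startup fails loudly instead of hanging the server. A minimal stand-alone sketch of that pattern, using plain `multiprocessing` primitives rather than hivemind's actual `MPFuture`:

```python
import multiprocessing as mp


def start_handler(conn, fail):
    """Stand-in for ConnectionHandler.run(): report success or the startup error."""
    try:
        if fail:
            raise RuntimeError("failed to register p2p handlers")  # simulated failure
        conn.send(("ok", None))
    except Exception as e:
        conn.send(("error", e))  # a bare mp.Event could not carry this exception


if __name__ == "__main__":
    parent_conn, child_conn = mp.Pipe()
    mp.Process(target=start_handler, args=(child_conn, True)).start()
    status, exc = parent_conn.recv()  # analogous to self.ready.result()
    if status == "error":
        raise exc  # the parent sees the child's startup failure immediately
```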

+ 26 - 31
hivemind/moe/server/connection_handler.py

@@ -1,22 +1,22 @@
+import asyncio
 import multiprocessing as mp
-import os
 import pickle
 from typing import Dict
 
-import grpc
 import torch
 
+from hivemind.p2p import P2PContext, ServicerBase
+from hivemind.dht import DHT
 from hivemind.moe.server.expert_backend import ExpertBackend
-from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
-from hivemind.utils import Endpoint, get_logger, nested_flatten
+from hivemind.proto import runtime_pb2
+from hivemind.utils import get_logger, nested_flatten, MPFuture
 from hivemind.utils.asyncio import switch_to_uvloop
 from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
-from hivemind.utils.grpc import GRPC_KEEPALIVE_OPTIONS
 
 logger = get_logger(__name__)
 
 
-class ConnectionHandler(mp.context.ForkProcess):
+class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
     """
     A process that accepts incoming requests to experts and submits them into the corresponding TaskPool.
 
@@ -25,45 +25,40 @@ class ConnectionHandler(mp.context.ForkProcess):
     :param experts: a dict [UID -> ExpertBackend] with all active experts
     """
 
-    def __init__(self, listen_on: Endpoint, experts: Dict[str, ExpertBackend]):
+    def __init__(self, dht: DHT, experts: Dict[str, ExpertBackend]):
         super().__init__()
-        self.listen_on, self.experts = listen_on, experts
-        self.ready = mp.Event()
+        self.dht, self.experts = dht, experts
+
+        self.ready = MPFuture()
 
     def run(self):
         torch.set_num_threads(1)
         loop = switch_to_uvloop()
 
         async def _run():
-            grpc.aio.init_grpc_aio()
-            logger.debug(f"Starting, pid {os.getpid()}")
-            server = grpc.aio.server(
-                options=GRPC_KEEPALIVE_OPTIONS
-                + (
-                    ("grpc.so_reuseport", 1),
-                    ("grpc.max_send_message_length", -1),
-                    ("grpc.max_receive_message_length", -1),
-                )
-            )
-            runtime_grpc.add_ConnectionHandlerServicer_to_server(self, server)
-
-            found_port = server.add_insecure_port(self.listen_on)
-            assert found_port != 0, f"Failed to listen to {self.listen_on}"
-
-            await server.start()
-            self.ready.set()
-            await server.wait_for_termination()
-            logger.debug(f"ConnectionHandler terminated: (pid={os.getpid()})")
+            try:
+                self._p2p = await self.dht.replicate_p2p()
+                await self.add_p2p_handlers(self._p2p)
+
+                logger.info(f"Connection handler started. Info: {await self._p2p._client.identify()}")
+
+                # signal readiness only once the handlers are actually registered
+                self.ready.set_result(None)
+
+                # TODO: await p2p death
+                await asyncio.Future()
+            except Exception as e:
+                self.ready.set_exception(e)
 
         try:
             loop.run_until_complete(_run())
         except KeyboardInterrupt:
             logger.debug("Caught KeyboardInterrupt, shutting down")
 
-    async def info(self, request: runtime_pb2.ExpertUID, context: grpc.ServicerContext):
+    async def rpc_info(self, request: runtime_pb2.ExpertUID, context: P2PContext) -> runtime_pb2.ExpertInfo:
         return runtime_pb2.ExpertInfo(serialized_info=pickle.dumps(self.experts[request.uid].get_info()))
 
-    async def forward(self, request: runtime_pb2.ExpertRequest, context: grpc.ServicerContext):
+    async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
         inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         future = self.experts[request.uid].forward_pool.submit_task(*inputs)
         serialized_response = [
@@ -73,7 +68,7 @@ class ConnectionHandler(mp.context.ForkProcess):
 
         return runtime_pb2.ExpertResponse(tensors=serialized_response)
 
-    async def backward(self, request: runtime_pb2.ExpertRequest, context: grpc.ServicerContext):
+    async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
         inputs_and_grad_outputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         future = self.experts[request.uid].backward_pool.submit_task(*inputs_and_grad_outputs)
         serialized_response = [
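
For context, this is roughly how a caller would reach the renamed `rpc_*` handlers once the process is ready. This is a hypothetical sketch, not part of the commit: the `get_stub(p2p, peer_id)` helper is assumed to be provided by `ServicerBase` (mirroring the old gRPC stub pattern), and the request fields follow the `runtime_pb2` usage visible in the handlers above.

```python
from hivemind.moe.server.connection_handler import ConnectionHandler
from hivemind.proto import runtime_pb2
from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor


async def call_forward(p2p, server_peer_id, expert_uid, inputs):
    # hypothetical: assumes ServicerBase exposes a get_stub(p2p, peer_id) helper
    stub = ConnectionHandler.get_stub(p2p, server_peer_id)
    request = runtime_pb2.ExpertRequest(
        uid=expert_uid, tensors=[serialize_torch_tensor(t) for t in inputs]
    )
    # dispatched over libp2p to ConnectionHandler.rpc_forward on the server
    response = await stub.rpc_forward(request)
    return [deserialize_torch_tensor(t) for t in response.tensors]
```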