@@ -1,22 +1,22 @@
+import asyncio
 import multiprocessing as mp
-import os
 import pickle
 from typing import Dict
 
-import grpc
 import torch
 
+from hivemind.p2p import P2PContext, ServicerBase
+from hivemind.dht import DHT
 from hivemind.moe.server.expert_backend import ExpertBackend
-from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
-from hivemind.utils import Endpoint, get_logger, nested_flatten
+from hivemind.proto import runtime_pb2
+from hivemind.utils import get_logger, nested_flatten, MPFuture
 from hivemind.utils.asyncio import switch_to_uvloop
 from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
-from hivemind.utils.grpc import GRPC_KEEPALIVE_OPTIONS
 
 logger = get_logger(__name__)
 
 
-class ConnectionHandler(mp.context.ForkProcess):
+class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
     """
     A process that accepts incoming requests to experts and submits them into the corresponding TaskPool.
 
@@ -25,45 +25,40 @@ class ConnectionHandler(mp.context.ForkProcess):
     :param experts: a dict [UID -> ExpertBackend] with all active experts
     """
 
-    def __init__(self, listen_on: Endpoint, experts: Dict[str, ExpertBackend]):
+    def __init__(self, dht: DHT, experts: Dict[str, ExpertBackend]):
         super().__init__()
-        self.listen_on, self.experts = listen_on, experts
-        self.ready = mp.Event()
+        self.dht, self.experts = dht, experts
+
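+        # cross-process future used by run() to signal readiness or propagate a startup exception to the parent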
+        self.ready = MPFuture()
 
     def run(self):
         torch.set_num_threads(1)
         loop = switch_to_uvloop()
 
         async def _run():
-            grpc.aio.init_grpc_aio()
-            logger.debug(f"Starting, pid {os.getpid()}")
-            server = grpc.aio.server(
-                options=GRPC_KEEPALIVE_OPTIONS
-                + (
-                    ("grpc.so_reuseport", 1),
-                    ("grpc.max_send_message_length", -1),
-                    ("grpc.max_receive_message_length", -1),
-                )
-            )
-            runtime_grpc.add_ConnectionHandlerServicer_to_server(self, server)
-
-            found_port = server.add_insecure_port(self.listen_on)
-            assert found_port != 0, f"Failed to listen to {self.listen_on}"
-
-            await server.start()
-            self.ready.set()
-            await server.wait_for_termination()
-            logger.debug(f"ConnectionHandler terminated: (pid={os.getpid()})")
+            try:
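+                # attach to the libp2p daemon already running in the DHT process and register this servicer's rpc_* methods as p2p handlers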
+                self._p2p = await self.dht.replicate_p2p()
+                await self.add_p2p_handlers(self._p2p)
+
+                logger.info(f"Connection handler started. Info: {await self._p2p._client.identify()}")
+
+                # TODO: await p2p death
+                await asyncio.Future()
+
+            except Exception as e:
+                self.ready.set_exception(e)
+                return
+            self.ready.set_result(None)
 
         try:
             loop.run_until_complete(_run())
         except KeyboardInterrupt:
             logger.debug("Caught KeyboardInterrupt, shutting down")
 
-    async def info(self, request: runtime_pb2.ExpertUID, context: grpc.ServicerContext):
+    async def rpc_info(self, request: runtime_pb2.ExpertUID, context: P2PContext) -> runtime_pb2.ExpertInfo:
         return runtime_pb2.ExpertInfo(serialized_info=pickle.dumps(self.experts[request.uid].get_info()))
 
-    async def forward(self, request: runtime_pb2.ExpertRequest, context: grpc.ServicerContext):
+    async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
         inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         future = self.experts[request.uid].forward_pool.submit_task(*inputs)
         serialized_response = [
@@ -73,7 +68,7 @@ class ConnectionHandler(mp.context.ForkProcess):
 
         return runtime_pb2.ExpertResponse(tensors=serialized_response)
 
-    async def backward(self, request: runtime_pb2.ExpertRequest, context: grpc.ServicerContext):
+    async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
         inputs_and_grad_outputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         future = self.experts[request.uid].backward_pool.submit_task(*inputs_and_grad_outputs)
         serialized_response = [