connection_handler.py

import multiprocessing as mp
import os
import pickle
from typing import Dict

import grpc
import torch

from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
from hivemind.moe.server.expert_backend import ExpertBackend
from hivemind.utils import get_logger, Endpoint, nested_flatten
from hivemind.utils.asyncio import switch_to_uvloop
from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
from hivemind.utils.grpc import GRPC_KEEPALIVE_OPTIONS

logger = get_logger(__name__)


class ConnectionHandler(mp.context.ForkProcess):
    """
    A process that accepts incoming requests to experts and submits them into the corresponding TaskPool.

    :note: ConnectionHandler is designed to allow using multiple handler processes for the same port
        (a minimal usage sketch is given after the class definition).
    :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
    :param experts: a dict [UID -> ExpertBackend] with all active experts
    """

    def __init__(self, listen_on: Endpoint, experts: Dict[str, ExpertBackend]):
        super().__init__()
        self.listen_on, self.experts = listen_on, experts
        self.ready = mp.Event()  # set once the gRPC server is up and listening

    def run(self):
        # each handler process uses a single torch thread to avoid oversubscribing CPU cores
        torch.set_num_threads(1)
        loop = switch_to_uvloop()

        async def _run():
            grpc.aio.init_grpc_aio()
            logger.debug(f'Starting, pid {os.getpid()}')
            server = grpc.aio.server(options=GRPC_KEEPALIVE_OPTIONS + (
                # SO_REUSEPORT lets several handler processes listen on the same port;
                # unlimited message sizes allow arbitrarily large tensors in requests/responses
                ('grpc.so_reuseport', 1),
                ('grpc.max_send_message_length', -1),
                ('grpc.max_receive_message_length', -1),
            ))
            runtime_grpc.add_ConnectionHandlerServicer_to_server(self, server)

            found_port = server.add_insecure_port(self.listen_on)
            assert found_port != 0, f"Failed to listen to {self.listen_on}"

            await server.start()
            self.ready.set()

            await server.wait_for_termination()
            logger.debug(f"ConnectionHandler terminated: (pid={os.getpid()})")

        try:
            loop.run_until_complete(_run())
        except KeyboardInterrupt:
            logger.debug('Caught KeyboardInterrupt, shutting down')

    async def info(self, request: runtime_pb2.ExpertUID, context: grpc.ServicerContext):
        # return pickled metadata (e.g. schemas) describing the requested expert
        return runtime_pb2.ExpertInfo(serialized_info=pickle.dumps(self.experts[request.uid].get_info()))

    async def forward(self, request: runtime_pb2.ExpertRequest, context: grpc.ServicerContext):
        inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
        # enqueue the task; the runtime batches it with other requests and runs the expert
        future = self.experts[request.uid].forward_pool.submit_task(*inputs)
        serialized_response = [
            serialize_torch_tensor(tensor, proto.compression, allow_inplace=True)
            for tensor, proto in zip(await future, nested_flatten(self.experts[request.uid].outputs_schema))
        ]
        return runtime_pb2.ExpertResponse(tensors=serialized_response)

    async def backward(self, request: runtime_pb2.ExpertRequest, context: grpc.ServicerContext):
        inputs_and_grad_outputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
        future = self.experts[request.uid].backward_pool.submit_task(*inputs_and_grad_outputs)
        serialized_response = [
            serialize_torch_tensor(tensor, proto.compression, allow_inplace=True)
            for tensor, proto in zip(await future, nested_flatten(self.experts[request.uid].grad_inputs_schema))
        ]
        return runtime_pb2.ExpertResponse(tensors=serialized_response)
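
Below is a minimal usage sketch, not part of the original module. It assumes an `experts` dict of `{uid: ExpertBackend}` has already been built elsewhere (normally hivemind's server code constructs both the experts and the handlers); the helper names `launch_handlers` and `call_forward`, the endpoint, and the expert uid are placeholders. The first helper shows how several handler processes can share one port via `grpc.so_reuseport`; the second shows a client-side `forward` call through the gRPC stub generated alongside `add_ConnectionHandlerServicer_to_server`.

# --- usage sketch, illustration only (names below are hypothetical) -------------------

def launch_handlers(experts: Dict[str, ExpertBackend], listen_on: Endpoint = "0.0.0.0:1337", num_handlers: int = 4):
    # grpc.so_reuseport (set in ConnectionHandler.run) lets every handler bind the same
    # port, so the OS spreads incoming connections across the processes
    handlers = [ConnectionHandler(listen_on, experts) for _ in range(num_handlers)]
    for handler in handlers:
        handler.start()
    for handler in handlers:
        handler.ready.wait()  # block until the gRPC server inside that child process is listening
    return handlers


def call_forward(endpoint: Endpoint, uid: str, input_tensors):
    # reach any handler through the generated stub; tensors are sent uncompressed by default
    channel = grpc.insecure_channel(endpoint)
    stub = runtime_grpc.ConnectionHandlerStub(channel)
    request = runtime_pb2.ExpertRequest(uid=uid, tensors=[serialize_torch_tensor(t) for t in input_tensors])
    response = stub.forward(request)
    return [deserialize_torch_tensor(t) for t in response.tensors]

Waiting on `ready` before sending traffic matters because the event is only set after `server.start()` returns inside the child process; until then, connections to the advertised port may fail.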