connection_handler.py

import multiprocessing as mp
import os
import pickle
from typing import Dict

import grpc
import torch

from hivemind.moe.server.expert_backend import ExpertBackend
from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
from hivemind.utils import Endpoint, get_logger, nested_flatten
from hivemind.utils.asyncio import switch_to_uvloop
from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.utils.grpc import GRPC_KEEPALIVE_OPTIONS

logger = get_logger(__name__)
class ConnectionHandler(mp.context.ForkProcess):
    """
    A process that accepts incoming requests to experts and submits them into the corresponding TaskPool.

    :note: ConnectionHandler is designed to allow multiple handler processes to serve the same port
        (see the usage sketches at the end of this file).
    :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
    :param experts: a dict [UID -> ExpertBackend] with all active experts
    """

    def __init__(self, listen_on: Endpoint, experts: Dict[str, ExpertBackend]):
        super().__init__()
        self.listen_on, self.experts = listen_on, experts
        self.ready = mp.Event()
    def run(self):
        torch.set_num_threads(1)  # this process only routes requests; leave compute threads to the runtime
        loop = switch_to_uvloop()

        async def _run():
            grpc.aio.init_grpc_aio()
            logger.debug(f"Starting, pid {os.getpid()}")
            server = grpc.aio.server(
                options=GRPC_KEEPALIVE_OPTIONS
                + (
                    ("grpc.so_reuseport", 1),  # lets several handler processes bind the same port
                    ("grpc.max_send_message_length", -1),
                    ("grpc.max_receive_message_length", -1),
                )
            )
            runtime_grpc.add_ConnectionHandlerServicer_to_server(self, server)

            found_port = server.add_insecure_port(self.listen_on)
            assert found_port != 0, f"Failed to listen to {self.listen_on}"

            await server.start()
            self.ready.set()  # signal the parent process that the server is accepting connections
            await server.wait_for_termination()
            logger.debug(f"ConnectionHandler terminated: (pid={os.getpid()})")

        try:
            loop.run_until_complete(_run())
        except KeyboardInterrupt:
            logger.debug("Caught KeyboardInterrupt, shutting down")
    async def info(self, request: runtime_pb2.ExpertUID, context: grpc.ServicerContext):
        return runtime_pb2.ExpertInfo(serialized_info=pickle.dumps(self.experts[request.uid].get_info()))
    async def forward(self, request: runtime_pb2.ExpertRequest, context: grpc.ServicerContext):
        # deserialize inputs, run them through the expert's forward TaskPool, then
        # re-serialize the outputs with the compression declared in the expert's output schema
        inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
        future = self.experts[request.uid].forward_pool.submit_task(*inputs)
        serialized_response = [
            serialize_torch_tensor(tensor, proto.compression, allow_inplace=True)
            for tensor, proto in zip(await future, nested_flatten(self.experts[request.uid].outputs_schema))
        ]
        return runtime_pb2.ExpertResponse(tensors=serialized_response)
    async def backward(self, request: runtime_pb2.ExpertRequest, context: grpc.ServicerContext):
        # request.tensors carries both the original inputs and the gradients w.r.t. outputs;
        # the backward TaskPool returns gradients w.r.t. inputs, serialized per grad_inputs_schema
        inputs_and_grad_outputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
        future = self.experts[request.uid].backward_pool.submit_task(*inputs_and_grad_outputs)
        serialized_response = [
            serialize_torch_tensor(tensor, proto.compression, allow_inplace=True)
            for tensor, proto in zip(await future, nested_flatten(self.experts[request.uid].grad_inputs_schema))
        ]
        return runtime_pb2.ExpertResponse(tensors=serialized_response)
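

# --- Usage sketch (illustrative, not part of the hivemind API) --------------------------
# The class docstring notes that several handler processes can serve the same port: each
# process sets grpc.so_reuseport, so the OS distributes incoming connections among them.
# A hypothetical launcher could look like the helper below; `experts` is assumed to be a
# prebuilt {uid: ExpertBackend} dict such as the one hivemind's Server constructs.
def _spawn_handlers_sketch(listen_on: Endpoint, experts: Dict[str, ExpertBackend], num_handlers: int = 4):
    """Hypothetical helper: start several ConnectionHandler processes that share one port."""
    handlers = [ConnectionHandler(listen_on, experts) for _ in range(num_handlers)]
    for handler in handlers:
        handler.start()  # fork a process that runs ConnectionHandler.run()
    for handler in handlers:
        handler.ready.wait()  # block until that process's gRPC server is listening
    return handlers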
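

# --- Client-side sketch (illustrative) ---------------------------------------------------
# forward() and backward() above are ordinary gRPC methods, so a client only needs the
# generated ConnectionHandlerStub. This mirrors what hivemind's RemoteExpert does
# internally; the endpoint, expert UID, and input shape below are assumptions made for
# the example, not defaults of the library.
def _call_forward_sketch(endpoint: str = "localhost:1337", uid: str = "expert.0"):
    """Hypothetical helper: send one forward request and decode the returned tensors."""
    channel = grpc.insecure_channel(endpoint)
    stub = runtime_grpc.ConnectionHandlerStub(channel)
    request = runtime_pb2.ExpertRequest(
        uid=uid, tensors=[serialize_torch_tensor(torch.randn(1, 64))]
    )
    response = stub.forward(request)
    return [deserialize_torch_tensor(tensor) for tensor in response.tensors]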