connection_handler.py

import asyncio
import multiprocessing as mp
import os
import pickle
from typing import Dict

import grpc
import torch
import uvloop

from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
from hivemind.server.expert_backend import ExpertBackend
from hivemind.utils import get_logger, serialize_torch_tensor, deserialize_torch_tensor, Endpoint, nested_flatten
from hivemind.utils.grpc import GRPC_KEEPALIVE_OPTIONS
from hivemind.utils.asyncio import switch_to_uvloop

logger = get_logger(__name__)


class ConnectionHandler(mp.context.ForkProcess):
    """
    A process that accepts incoming requests to experts and submits them into the corresponding TaskPool.

    :note: ConnectionHandler is designed so as to allow using multiple handler processes for the same port.
    :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
    :param experts: a dict [UID -> ExpertBackend] with all active experts
    """

    def __init__(self, listen_on: Endpoint, experts: Dict[str, ExpertBackend]):
        super().__init__()
        self.listen_on, self.experts = listen_on, experts
        self.ready = mp.Event()

    def run(self):
        torch.set_num_threads(1)  # keep each handler process to a single intra-op torch thread
        loop = switch_to_uvloop()

        async def _run():
            grpc.aio.init_grpc_aio()
            logger.debug(f'Starting, pid {os.getpid()}')
            # SO_REUSEPORT lets several handler processes bind the same port
            server = grpc.aio.server(options=GRPC_KEEPALIVE_OPTIONS + (
                ('grpc.so_reuseport', 1),
                ('grpc.max_send_message_length', -1),
                ('grpc.max_receive_message_length', -1)
            ))
            runtime_grpc.add_ConnectionHandlerServicer_to_server(self, server)
            found_port = server.add_insecure_port(self.listen_on)
            assert found_port != 0, f"Failed to listen to {self.listen_on}"
            await server.start()
            self.ready.set()

            await server.wait_for_termination()
            logger.debug(f"ConnectionHandler terminated: (pid={os.getpid()})")

        loop.run_until_complete(_run())

    async def info(self, request: runtime_pb2.ExpertUID, context: grpc.ServicerContext):
        # return pickled metadata produced by ExpertBackend.get_info()
        return runtime_pb2.ExpertInfo(serialized_info=pickle.dumps(self.experts[request.uid].get_info()))

    async def forward(self, request: runtime_pb2.ExpertRequest, context: grpc.ServicerContext):
        # deserialize inputs, run them through the expert's forward pool,
        # then serialize outputs using the compression declared in outputs_schema
        inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
        future = self.experts[request.uid].forward_pool.submit_task(*inputs)
        serialized_response = [serialize_torch_tensor(tensor, proto.compression, allow_inplace=True) for tensor, proto
                               in zip(await future, nested_flatten(self.experts[request.uid].outputs_schema))]
        return runtime_pb2.ExpertResponse(tensors=serialized_response)

    async def backward(self, request: runtime_pb2.ExpertRequest, context: grpc.ServicerContext):
        # same flow as forward, but through the backward pool and grad_inputs_schema
        inputs_and_grad_outputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
        future = self.experts[request.uid].backward_pool.submit_task(*inputs_and_grad_outputs)
        serialized_response = [serialize_torch_tensor(tensor, proto.compression, allow_inplace=True) for tensor, proto
                               in zip(await future, nested_flatten(self.experts[request.uid].grad_inputs_schema))]
        return runtime_pb2.ExpertResponse(tensors=serialized_response)
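

# Minimal client-side sketch: querying a running ConnectionHandler over gRPC.
# The endpoint "localhost:1337", the expert UID "expert.0", and the function
# name below are illustrative assumptions, not part of the module above; the
# stub class and message fields come from hivemind's generated runtime protos.
def _example_query_expert_info(endpoint: str = "localhost:1337", uid: str = "expert.0"):
    channel = grpc.insecure_channel(endpoint)              # a synchronous channel suffices for a one-off call
    stub = runtime_grpc.ConnectionHandlerStub(channel)     # generated stub for the ConnectionHandler service
    response = stub.info(runtime_pb2.ExpertUID(uid=uid))   # invokes ConnectionHandler.info on the server
    return pickle.loads(response.serialized_info)          # metadata produced by ExpertBackend.get_info()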