connection_handler.py

import asyncio
import multiprocessing as mp
import pickle
from typing import AsyncIterator, Dict

import torch

from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.dht import DHT
from hivemind.moe.server.expert_backend import ExpertBackend
from hivemind.p2p import P2PContext, ServicerBase
from hivemind.proto import runtime_pb2
from hivemind.utils import MPFuture, as_aiter, get_logger, nested_flatten
from hivemind.utils.asyncio import switch_to_uvloop

logger = get_logger(__name__)
class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
    """
    A process that accepts incoming requests to experts and submits them into the corresponding TaskPool.

    :note: ConnectionHandler is designed to allow using multiple handler processes for the same port.
    :param dht: a running DHT instance whose p2p daemon is replicated to serve incoming requests
    :param experts: a dict [UID -> ExpertBackend] with all active experts
    """

    def __init__(self, dht: DHT, experts: Dict[str, ExpertBackend]):
        super().__init__()
        self.dht, self.experts = dht, experts
        self.ready = MPFuture()
    def run(self):
        torch.set_num_threads(1)
        loop = switch_to_uvloop()

        async def _run():
            try:
                self._p2p = await self.dht.replicate_p2p()
                await self.add_p2p_handlers(self._p2p)
            except Exception as e:
                self.ready.set_exception(e)
                return
            self.ready.set_result(None)  # signal the parent process that handlers are registered
            await asyncio.Future()  # block forever, serving requests until the process is terminated

        try:
            loop.run_until_complete(_run())
        except KeyboardInterrupt:
            logger.debug("Caught KeyboardInterrupt, shutting down")
    async def rpc_info(self, request: runtime_pb2.ExpertUID, context: P2PContext) -> runtime_pb2.ExpertInfo:
        return runtime_pb2.ExpertInfo(serialized_info=pickle.dumps(self.experts[request.uid].get_info()))
    async def rpc_forward(
        self, request: runtime_pb2.ExpertRequest, context: P2PContext
    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
        # deserialize inputs, run the forward pass via the expert's task pool, and serialize outputs per schema
        inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
        future = self.experts[request.uid].forward_pool.submit_task(*inputs)
        serialized_response = [
            serialize_torch_tensor(tensor, proto.compression, allow_inplace=True)
            for tensor, proto in zip(await future, nested_flatten(self.experts[request.uid].outputs_schema))
        ]
        yield runtime_pb2.ExpertResponse(tensors=serialized_response)
    async def rpc_backward(
        self, request: runtime_pb2.ExpertRequest, context: P2PContext
    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
        # the request carries both inputs and grad_outputs; the backward pool returns gradients w.r.t. inputs
        inputs_and_grad_outputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
        future = self.experts[request.uid].backward_pool.submit_task(*inputs_and_grad_outputs)
        serialized_response = [
            serialize_torch_tensor(tensor, proto.compression, allow_inplace=True)
            for tensor, proto in zip(await future, nested_flatten(self.experts[request.uid].grad_inputs_schema))
        ]
        yield runtime_pb2.ExpertResponse(tensors=serialized_response)
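
# Usage sketch (illustrative only, not part of this module): assuming an `experts` dict of
# ExpertBackend objects has already been built by the surrounding server setup code (omitted
# here), a handler process could be launched roughly like this:
#
#     dht = DHT(start=True)
#     handler = ConnectionHandler(dht, experts)
#     handler.start()           # forks the process, which enters run()
#     handler.ready.result()    # blocks until p2p handlers are registered, or re-raises the startup error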