connection_handler.py

import asyncio
import multiprocessing as mp
from typing import AsyncIterator, Dict, Iterable, List, Optional, Tuple, Union

import torch

from hivemind.compression import deserialize_tensor_stream, deserialize_torch_tensor, serialize_torch_tensor
from hivemind.dht import DHT
from hivemind.moe.server.module_backend import ModuleBackend
from hivemind.moe.server.task_pool import TaskPool
from hivemind.p2p import P2PContext, ServicerBase
from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE, P2P
from hivemind.proto import runtime_pb2
from hivemind.utils import MPFuture, MSGPackSerializer, as_aiter, get_logger, nested_flatten
from hivemind.utils.asyncio import amap_in_executor, switch_to_uvloop
from hivemind.utils.streaming import split_for_streaming
from hivemind.utils.tensor_descr import BatchTensorDescriptor

logger = get_logger(__name__)


class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
    """
    A process that accepts incoming requests to experts and submits them into the corresponding TaskPool.

    :note: ConnectionHandler is designed to allow running multiple handler processes for the same port
    :param dht: a running hivemind.dht.DHT, used to let other peers connect to this one
    :param experts: a dict [UID -> ModuleBackend] with all active experts
    """

    def __init__(self, dht: DHT, experts: Dict[str, ModuleBackend]):
        super().__init__()
        self.dht, self.experts = dht, experts
        self._p2p: Optional[P2P] = None
        self.ready = MPFuture()
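
    # run() executes in the forked child process: it pins torch to a single thread,
    # switches to a uvloop event loop, replicates the DHT's p2p daemon, registers the
    # rpc_* handlers below, and reports success (or a startup error) through self.ready.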

    def run(self):
        torch.set_num_threads(1)
        loop = switch_to_uvloop()

        async def _run():
            try:
                self._p2p = await self.dht.replicate_p2p()
                await self.add_p2p_handlers(self._p2p, balanced=True)
            except Exception as e:
                self.ready.set_exception(e)
                return
            self.ready.set_result(None)

            # wait forever
            await asyncio.Future()

        try:
            loop.run_until_complete(_run())
        except KeyboardInterrupt:
            logger.debug("Caught KeyboardInterrupt, shutting down")
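
    # rpc_info returns the expert's metadata (e.g. its argument/output schemas) from
    # ModuleBackend.get_info(), serialized with MessagePack.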

    async def rpc_info(self, request: runtime_pb2.ExpertUID, context: P2PContext) -> runtime_pb2.ExpertInfo:
        return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(self.experts[request.uid].get_info()))
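
    # _gather_inputs consumes a stream of ExpertRequest messages (every part must refer to
    # the same expert uid) and reassembles the chunked, serialized tensors back into full
    # torch tensors via deserialize_tensor_stream.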

    async def _gather_inputs(
        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
    ) -> Tuple[str, List[torch.Tensor]]:
        expert_uid = None

        def _unpack(req: runtime_pb2.ExpertRequest) -> Iterable[runtime_pb2.Tensor]:
            nonlocal expert_uid

            if expert_uid is None:
                expert_uid = req.uid
            elif expert_uid != req.uid:
                raise ValueError("Expert uids differ in one request")

            return req.tensors

        tensors_stream = amap_in_executor(_unpack, requests)
        inputs = await deserialize_tensor_stream(tensors_stream)
        return expert_uid, inputs
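
    # _process_inputs submits a batch to the given TaskPool, awaits the result from the
    # runtime, and serializes each output tensor with the compression specified by the
    # matching descriptor in the (possibly nested) schema.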

    async def _process_inputs(
        self,
        inputs: List[torch.Tensor],
        pool: TaskPool,
        schema: Union[BatchTensorDescriptor, Tuple[BatchTensorDescriptor, ...]],
    ) -> List[runtime_pb2.Tensor]:
        return [
            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
            for result, proto in zip(await pool.submit_task(*inputs), nested_flatten(schema))
        ]
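
    # rpc_forward/rpc_backward serve unary requests that fit into a single p2p message;
    # the *_stream variants below accept chunked request streams and split their replies
    # into parts no larger than DEFAULT_MAX_MSG_SIZE.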

    async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
        inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
        expert = self.experts[request.uid]
        return runtime_pb2.ExpertResponse(
            tensors=await self._process_inputs(inputs, expert.forward_pool, expert.outputs_schema)
        )

    async def rpc_forward_stream(
        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
        uid, inputs = await self._gather_inputs(requests, context)
        expert = self.experts[uid]
        output_split = [
            part
            for tensor in await self._process_inputs(inputs, expert.forward_pool, expert.outputs_schema)
            for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
        ]

        async for part in as_aiter(*output_split):
            yield runtime_pb2.ExpertResponse(tensors=[part])

    async def rpc_backward(
        self, request: runtime_pb2.ExpertRequest, context: P2PContext
    ) -> runtime_pb2.ExpertResponse:
        inputs_and_grads = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
        expert = self.experts[request.uid]
        return runtime_pb2.ExpertResponse(
            tensors=await self._process_inputs(inputs_and_grads, expert.backward_pool, expert.grad_inputs_schema)
        )

    async def rpc_backward_stream(
        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
        uid, inputs_and_grads = await self._gather_inputs(requests, context)
        expert = self.experts[uid]
        output_split = [
            part
            for tensor in await self._process_inputs(inputs_and_grads, expert.backward_pool, expert.grad_inputs_schema)
            for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
        ]

        async for part in as_aiter(*output_split):
            yield runtime_pb2.ExpertResponse(tensors=[part])
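

# The sketch below is an illustration added for this write-up, not part of hivemind's API:
# it shows one plausible way to run several ConnectionHandler processes over a shared DHT,
# as the class docstring suggests. It assumes `backends` maps expert uids to fully
# constructed ModuleBackend instances (building those is outside the scope of this file).
def _example_start_handlers(dht: DHT, backends: Dict[str, ModuleBackend], num_handlers: int = 4):
    handlers = [ConnectionHandler(dht, backends) for _ in range(num_handlers)]
    for handler in handlers:
        handler.start()  # mp.Process.start(): fork and invoke run() in the child process
    for handler in handlers:
        handler.ready.result()  # block until the child reports readiness; re-raises startup errors
    return handlers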