remote_block.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. from __future__ import annotations
  2. import asyncio
  3. from functools import partial
  4. from typing import List, Optional, Union, Sequence, AsyncIterator, Dict, Any
  5. import torch
  6. from hivemind.moe.client.expert import RemoteExpert, RemoteExpertWorker
  7. from hivemind.moe.expert_uid import ExpertUID, ExpertInfo as RemoteModuleInfo
  8. from hivemind.p2p import P2P, PeerID, StubBase
  9. from hivemind.proto import runtime_pb2
  10. from hivemind.dht import DHT, DHTNode, DHTValue
  11. from hivemind.utils import MPFuture, DHTExpiration, get_dht_time, as_aiter, anext, nested_flatten
  12. from hivemind.compression import serialize_torch_tensor, deserialize_torch_tensor
  13. from src.server.handler import TransformerConnectionHandler
  14. class RemoteTransformerBlock(RemoteExpert):
  15. """A class that interacts with a remote module on a specific server for forward/backward or inference"""
  16. @property
  17. def stub(self) -> StubBase:
  18. return TransformerConnectionHandler.get_stub(self.p2p, self.peer_id)
  19. def begin_inference_session(self) -> RemoteTransformerBlockInferenceSession:
  20. """Initialize a new inference session with the specified remote server"""
  21. return RemoteExpertWorker.run_coroutine(RemoteTransformerBlockInferenceSession._create(self))
  22. class RemoteTransformerBlockInferenceSession:
  23. """An interface to a single multi-step *inference* session for a specific remote module with a specific server"""
  24. def __init__(self, uid: ExpertUID, info: Dict[str, Any], inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator):
  25. self.uid, self.info = uid, info
  26. # warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
  27. # using them in any other EventLoop may cause side-effects including, headaches, diarrhea, and loss of sleep
  28. self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
  29. self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
  30. self.closed = False
  31. @classmethod
  32. async def _create(
  33. cls, remote_module: RemoteTransformerBlock, timeout: Optional[float] = None
  34. ) -> RemoteTransformerBlockInferenceSession:
  35. """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
  36. inputs_queue = asyncio.Queue()
  37. outputs_stream = await remote_module.stub.rpc_inference(
  38. cls._read_inputs_from_queue(inputs_queue, timeout), timeout=timeout
  39. )
  40. return cls(remote_module.uid, remote_module.info, inputs_queue, outputs_stream)
  41. @staticmethod
  42. async def _read_inputs_from_queue(queue: asyncio.Queue, timeout: Optional[float]) -> AsyncIterator:
  43. while True:
  44. next_input_message = await asyncio.wait_for(queue.get(), timeout)
  45. yield next_input_message
  46. if not next_input_message.uid and not next_input_message.tensors:
  47. break # this message means "done sending"
  48. def step(self, new_hidden_states: torch.Tensor):
  49. """Inference step: send a chunk of input tensors and receive a chunk of outputs"""
  50. if self.closed:
  51. raise Exception("Session is closed, cannot perform step")
  52. # serialize inputs and put them into the queue
  53. inputs = (new_hidden_states,)
  54. outputs_serialized = RemoteExpertWorker.run_coroutine(self._step(
  55. runtime_pb2.ExpertRequest(uid=self.uid, tensors=[
  56. serialize_torch_tensor(tensor, proto.compression)
  57. for tensor, proto in zip(inputs, nested_flatten(self.info["forward_schema"]))
  58. ])
  59. ))
  60. outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors))
  61. assert outputs[0].shape == inputs[0].shape, f"expected outputs[0] to be hidden states but got {outputs[0]}"
  62. return outputs[0]
  63. async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
  64. """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
  65. await self._inputs_queue.put(inputs_serialized)
  66. return await anext(self._outputs_stream)
  67. def close(self):
  68. """Finish a given inference session, close the underlying connection"""
  69. if self._outputs_stream is None:
  70. return # already closed
  71. RemoteExpertWorker.run_coroutine(self._aclose_stream())
  72. self._outputs_stream = self._inputs_queue = None
  73. self.closed = True
  74. async def _aclose_stream(self):
  75. """Close the inference session. This code is meant to be run inside RemoteExpertWorker"""
  76. if self._outputs_stream is None:
  77. return # already closed
  78. await self._inputs_queue.put(runtime_pb2.ExpertRequest()) # empty request will trigger end of session
  79. try:
  80. await anext(self._outputs_stream)
  81. except StopAsyncIteration:
  82. pass
  83. def __del__(self):
  84. self.close()
  85. def __enter__(self):
  86. assert not self.closed
  87. return self
  88. def __exit__(self, *exc_details):
  89. self.close()
  90. def get_remote_module(
  91. dht: DHT, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration] = None, return_future: bool = False
  92. ) -> Union[List[Optional[RemoteTransformerBlock]], MPFuture[List[Optional[RemoteTransformerBlock]]]]:
  93. """
  94. :param uids: find experts with these ids from across the DHT
  95. :param expiration_time: if specified, return experts that expire no sooner than this (based on get_dht_time)
  96. :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
  97. :returns: a list of [RemoteTransformerBlock if found else None]
  98. """
  99. assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
  100. infos = dht.run_coroutine(
  101. partial(_get_remote_module_infos, uids=list(uids), expiration_time=expiration_time),
  102. return_future)
  103. if return_future:
  104. async def _unpack(infos_future: MPFuture, dht: DHT):
  105. p2p = await dht.replicate_p2p()
  106. return _create_remote_modules_from_infos(await infos_future, p2p)
  107. return RemoteExpertWorker.run_coroutine(_unpack(infos, dht), return_future)
  108. p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
  109. return _create_remote_modules_from_infos(infos, p2p)
  110. async def _get_remote_module_infos(
  111. dht: DHT, node: DHTNode, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration]
  112. ) -> List[Optional[RemoteModuleInfo]]:
  113. if expiration_time is None:
  114. expiration_time = get_dht_time()
  115. num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
  116. found: Dict[ExpertUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
  117. experts: List[Optional[RemoteModuleInfo]] = [None] * len(uids)
  118. for i, uid in enumerate(uids):
  119. server_peer_id = found[uid]
  120. if server_peer_id is not None and isinstance(server_peer_id.value, str):
  121. experts[i] = RemoteModuleInfo(uid, PeerID.from_base58(server_peer_id.value))
  122. return experts
  123. def _create_remote_modules_from_infos(infos: Sequence[Optional[RemoteModuleInfo]], p2p: P2P
  124. ) -> List[Optional[RemoteTransformerBlock]]:
  125. experts: List[Optional[RemoteTransformerBlock]] = []
  126. for info in infos:
  127. if info is not None:
  128. experts.append(RemoteTransformerBlock(info, p2p))
  129. else:
  130. experts.append(None)
  131. return experts