inference_session.py
from __future__ import annotations

import asyncio
import contextlib
from typing import AsyncIterator, List, Optional

import torch
from hivemind import (
    P2P,
    anext,
    deserialize_torch_tensor,
    get_logger,
    nested_flatten,
    serialize_torch_tensor,
    use_hivemind_log_handler,
)
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.p2p import StubBase
from hivemind.proto import runtime_pb2

from src.client.sequence_manager import RemoteSequenceManager
from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
from src.server.handler import TransformerConnectionHandler

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)


class RemoteTransformerBlockInferenceSession:
    """
    An interface to a single multi-step *inference* session for a specific remote module on a specific server

    :note: this inference session is *not* fault-tolerant out of the box
    """

    def __init__(self, uid: ModuleUID, rpc_info: RPCInfo, inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator):
        self.uid, self.rpc_info = uid, rpc_info
        # warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
        # using them in any other EventLoop may cause side effects, including headaches, diarrhea, and loss of sleep
        self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
        self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
        self.stepped = False
        self.closed = False

    @classmethod
    async def _create(
        cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: Optional[float] = None
    ) -> RemoteTransformerBlockInferenceSession:
        """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
        inputs_queue = asyncio.Queue()
        outputs_stream = await stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue, timeout), timeout=timeout)
        return cls(uid, rpc_info, inputs_queue, outputs_stream)

    @staticmethod
    async def _read_inputs_from_queue(queue: asyncio.Queue, timeout: Optional[float]) -> AsyncIterator:
        while True:
            next_input_message = await asyncio.wait_for(queue.get(), timeout)
            yield next_input_message
            if not next_input_message.uid and not next_input_message.tensors:
                break  # this message means "done sending"
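
    # Note: rpc_inference is a bidirectional stream; the generator above feeds
    # requests from inputs_queue into it, and an empty ExpertRequest (no uid and
    # no tensors) acts as the end-of-stream sentinel (see _aclose_stream below).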

    def step(self, new_hidden_states: torch.Tensor):
        """Inference step: send a chunk of input tensors and receive a chunk of outputs"""
        if self.closed:
            raise Exception("Session is closed, cannot perform step")
        # serialize inputs and put them into the queue
        inputs = (new_hidden_states,)
        outputs_serialized = RemoteExpertWorker.run_coroutine(
            self._step(
                runtime_pb2.ExpertRequest(
                    uid=self.uid,
                    tensors=[
                        serialize_torch_tensor(tensor, proto.compression)
                        for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["forward_schema"]))
                    ],
                )
            )
        )
        outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors))
        assert outputs[0].shape == inputs[0].shape, f"expected outputs[0] to be hidden states but got {outputs[0]}"
        return outputs[0]
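
    # Shape sketch (hypothetical: `session` is an open instance of this class and
    # hidden_size depends on the model being served):
    #   hidden = torch.zeros(1, 1, hidden_size)
    #   hidden = session.step(hidden)  # returns hidden states of the same shape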

    async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
        """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
        await self._inputs_queue.put(inputs_serialized)
        self.stepped = True
        return await anext(self._outputs_stream)

    def close(self):
        """Finish a given inference session, close the underlying connection"""
        if self._outputs_stream is None:
            return  # already closed
        RemoteExpertWorker.run_coroutine(self._aclose_stream())
        self._outputs_stream = self._inputs_queue = None
        self.closed = True

    async def _aclose_stream(self):
        """Close the inference session. This code is meant to be run inside RemoteExpertWorker"""
        if self._outputs_stream is None:
            return  # already closed
        if self.stepped:
            await self._inputs_queue.put(runtime_pb2.ExpertRequest())  # empty request will trigger end of session
            try:
                await anext(self._outputs_stream)
            except StopAsyncIteration:
                pass
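
    # Closing handshake: the empty ExpertRequest above matches the break condition
    # in _read_inputs_from_queue, signaling end-of-input to the server; draining
    # the output stream afterwards lets the RPC finish cleanly.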

    def __del__(self):
        self.close()

    def __enter__(self):
        assert not self.closed
        return self

    def __exit__(self, *exc_details):
        self.close()
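
# Usage sketch for a single-block session (hypothetical names: p2p, peer_id, uid,
# rpc_info, and hidden_states must come from the surrounding client code). This
# mirrors what RemoteSequentialInferenceSession.__enter__ does for each span:
#
#   stub = TransformerConnectionHandler.get_stub(p2p, peer_id)
#   session = RemoteExpertWorker.run_coroutine(
#       RemoteTransformerBlockInferenceSession._create(stub, uid, rpc_info)
#   )
#   with session:
#       outputs = session.step(hidden_states)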


class RemoteSequentialInferenceSession:
    """
    An interface to a multi-step *inference* session for a sequence of remote transformer blocks
    """

    def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, timeout: Optional[float] = None):
        self.sequence_manager = sequence_manager
        self.p2p = p2p
        self.closed = False
        self.chosen_spans: List[RemoteSpanInfo] = []
        self.stack = contextlib.ExitStack()
        self.inference_sessions: List[RemoteTransformerBlockInferenceSession] = []
        self.timeout = timeout

    def __enter__(self):
        assert not self.closed and not self.chosen_spans
        self.stack.__enter__()
        # TODO(yozh) replace this code with a fault-tolerant chain that can be reconstructed if some peers fail
        self.chosen_spans.extend(self.sequence_manager.make_sequence())

        for chosen_span in self.chosen_spans:
            stub = TransformerConnectionHandler.get_stub(self.p2p, chosen_span.peer_id)
            span_uids: str = CHAIN_DELIMITER.join(self.sequence_manager.block_uids[chosen_span.start : chosen_span.end])
            inference_session = RemoteExpertWorker.run_coroutine(
                RemoteTransformerBlockInferenceSession._create(
                    stub, span_uids, rpc_info=self.sequence_manager.rpc_info, timeout=self.timeout
                )
            )
            self.inference_sessions.append(inference_session)
            self.stack.enter_context(inference_session)

        return self
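
    # Note: each chosen span is a contiguous run of blocks hosted by one peer; its
    # block UIDs are joined with CHAIN_DELIMITER so that the whole span runs under
    # a single RPC inference session on that server.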

    def step(self, inputs: torch.Tensor):
        assert not self.closed
        if torch.is_grad_enabled():
            logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
        for session in self.inference_sessions:
            outputs = session.step(inputs)
            assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
            inputs = outputs
        return inputs
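
    # Note: each step traverses the chosen spans strictly in order, feeding one
    # server's output hidden states to the next, so per-step latency is roughly
    # the sum of the per-span round trips (there is no pipelining here).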

    def close(self, *exc_details):
        """Finish a given inference session, close the underlying connection"""
        if not self.closed:
            self.stack.__exit__(*exc_details or (None, None, None))
            self.inference_sessions.clear()
            self.closed = True

    def __exit__(self, *exc_details):
        self.close(*exc_details)

    def __del__(self):
        self.close()
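
# Usage sketch for a full sequential session (hypothetical setup: sequence_manager
# and p2p must be constructed by the caller; hidden_size, num_steps, and the
# timeout value are placeholders for model- and application-specific values):
#
#   with RemoteSequentialInferenceSession(sequence_manager, p2p, timeout=30.0) as session:
#       hidden = torch.zeros(1, 1, hidden_size)
#       for _ in range(num_steps):
#           hidden = session.step(hidden)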