inference_session.py

from __future__ import annotations

import asyncio
import contextlib
from typing import AsyncIterator, List, Optional

import torch
from hivemind import (
    P2P,
    MSGPackSerializer,
    anext,
    deserialize_torch_tensor,
    get_logger,
    nested_flatten,
    serialize_torch_tensor,
    use_hivemind_log_handler,
)
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.p2p import StubBase
from hivemind.proto import runtime_pb2

from src.client.sequence_manager import RemoteSequenceManager
from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
from src.server.handler import TransformerConnectionHandler
from src.utils.misc import DUMMY, is_dummy

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)


class RemoteTransformerBlockInferenceSession:
    """
    An interface to a single multi-step *inference* session for a specific remote module on a specific server

    :note: this inference session is *not* fault-tolerant out of the box
    """

    def __init__(
        self,
        uid: ModuleUID,
        rpc_info: RPCInfo,
        inputs_queue: asyncio.Queue,
        outputs_aiter: AsyncIterator,
        *,
        max_length: int,
        points: int = 0,
    ):
        self.uid, self.rpc_info = uid, rpc_info
        self.num_blocks = uid.count(CHAIN_DELIMITER) + 1
        # warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
        # using them in any other EventLoop may cause side-effects including headaches, diarrhea, and loss of sleep
        self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
        self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
        self._serialized_metadata = MSGPackSerializer.dumps(dict(max_length=max_length, points=points))
        self.stepped = False
        self.closed = False

    @classmethod
    async def _create(
        cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: Optional[float] = None, **metadata
    ) -> RemoteTransformerBlockInferenceSession:
        """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
        inputs_queue = asyncio.Queue()
        outputs_stream = await stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue, timeout), timeout=timeout)
        return cls(uid, rpc_info, inputs_queue, outputs_stream, **metadata)

    @staticmethod
    async def _read_inputs_from_queue(queue: asyncio.Queue, timeout: Optional[float]) -> AsyncIterator:
        while True:
            next_input_message = await asyncio.wait_for(queue.get(), timeout)
            yield next_input_message
            if not next_input_message.uid and not next_input_message.tensors:
                break  # this message means "done sending"
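
    # Note: each serialized request put into `inputs_queue` is streamed to the server's rpc_inference handler;
    # an empty ExpertRequest (no uid, no tensors) acts as the end-of-stream sentinel that closes the session.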

    def step(
        self,
        new_hidden_states: torch.Tensor,
        prompts: Optional[torch.Tensor] = None,
        hypo_ids: Optional[torch.Tensor] = None,
    ):
        """
        Inference step: send a chunk of input tensors and receive a chunk of outputs

        :param prompts: optional DEEP prompts, added to a prefix of each layer's outputs;
          if specified, deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size]
        """
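        # Shape example (hypothetical numbers): for a span of 4 blocks, a batch of 2 sequences, a 16-token
        # prefix, and hidden size 14336, deep prompts would have shape [4, 2, 16, 14336], while
        # new_hidden_states for a single generation step would have shape [2, 1, 14336].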
        if self.closed:
            raise Exception("Session is closed, cannot perform step")

        if prompts is None or is_dummy(prompts):
            prompts = DUMMY
        else:
            assert prompts.ndim == 4, "deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size]"
            assert prompts.shape[0] == self.num_blocks
            assert prompts.shape[1] in (new_hidden_states.shape[0], 1)
            assert prompts.shape[2] <= new_hidden_states.shape[1]
            assert prompts.shape[3] == new_hidden_states.shape[2]

        if hypo_ids is None or is_dummy(hypo_ids):
            hypo_ids = DUMMY
        else:
            assert len(hypo_ids) == len(new_hidden_states)
            assert hypo_ids.dtype == torch.int64

        # serialize inputs and put them into the queue
        inputs = (new_hidden_states, prompts, hypo_ids)
        outputs_serialized = RemoteExpertWorker.run_coroutine(
            self._step(
                runtime_pb2.ExpertRequest(
                    uid=self.uid,
                    tensors=[
                        serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
                        for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["inference_schema"]))
                    ],
                    metadata=self._serialized_metadata if not self.stepped else None,
                )
            )
        )
        outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors))
        assert outputs[0].shape == inputs[0].shape, f"expected outputs[0] to be hidden states but got {outputs[0]}"
        return outputs[0]

    async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
        """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
        await self._inputs_queue.put(inputs_serialized)
        self.stepped = True
        return await anext(self._outputs_stream)

    def close(self):
        """Finish a given inference session, close the underlying connection"""
        if self._outputs_stream is None:
            return  # already closed
        RemoteExpertWorker.run_coroutine(self._aclose_stream())
        self._outputs_stream = self._inputs_queue = None
        self.closed = True

    async def _aclose_stream(self):
        """Close the inference session. This code is meant to be run inside RemoteExpertWorker"""
        if self._outputs_stream is None:
            return  # already closed
        if self.stepped:
            await self._inputs_queue.put(runtime_pb2.ExpertRequest())  # empty request will trigger end of session
            try:
                await anext(self._outputs_stream)
            except StopAsyncIteration:
                pass

    def __del__(self):
        self.close()

    def __enter__(self):
        assert not self.closed
        return self

    def __exit__(self, *exc_details):
        self.close()


class RemoteSequentialInferenceSession:
    """
    An interface to a multi-step *inference* session for a sequence of remote transformer blocks
    """

    def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, timeout: Optional[float] = None, **metadata):
        self.sequence_manager = sequence_manager
        self.p2p = p2p
        self.closed = False
        self.chosen_spans: List[RemoteSpanInfo] = []
        self.stack = contextlib.ExitStack()
        self.inference_sessions: List[RemoteTransformerBlockInferenceSession] = []
        self.metadata = metadata
        self.timeout = timeout

    def __enter__(self):
        assert not self.closed and not self.chosen_spans
        self.stack.__enter__()
        # TODO(yozh) replace this code with a fault-tolerant chain that can be reconstructed if some peers fail
        self.chosen_spans.extend(self.sequence_manager.make_sequence())

        for chosen_span in self.chosen_spans:
            stub = TransformerConnectionHandler.get_stub(self.p2p, chosen_span.peer_id)
            span_uids: str = CHAIN_DELIMITER.join(self.sequence_manager.block_uids[chosen_span.start : chosen_span.end])
            inference_session = RemoteExpertWorker.run_coroutine(
                RemoteTransformerBlockInferenceSession._create(
                    stub, span_uids, rpc_info=self.sequence_manager.rpc_info, timeout=self.timeout, **self.metadata
                )
            )
            self.inference_sessions.append(inference_session)
            self.stack.enter_context(inference_session)

        return self

    def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs):
        assert not self.closed
        if torch.is_grad_enabled():
            logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
        if prompts is None or is_dummy(prompts):
            prompts = DUMMY
        else:
            assert prompts.ndim == 4 and prompts.shape[0] == len(self.sequence_manager)

        for session, chosen_span in zip(self.inference_sessions, self.chosen_spans):
            # each session receives only the slice of deep prompts that covers its own span of blocks
            outputs = session.step(inputs, prompts[chosen_span.start : chosen_span.end], **kwargs)
            assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
            inputs = outputs
        return inputs

    def close(self, *exc_details):
        """Finish a given inference session, close the underlying connection"""
        if not self.closed:
            self.stack.__exit__(*exc_details or (None, None, None))
            self.inference_sessions.clear()
            self.closed = True

    def __exit__(self, *exc_details):
        self.close(*exc_details)

    def __del__(self):
        self.close()
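

# A minimal usage sketch (illustrative; the variable names below are hypothetical): assumes an already
# initialized RemoteSequenceManager `sequence_manager`, a running hivemind P2P instance `p2p`, and
# `hidden_states` of shape [batch_size, seq_len, hid_size]:
#
#     with RemoteSequentialInferenceSession(sequence_manager, p2p, max_length=128) as session:
#         for i in range(hidden_states.shape[1]):
#             outputs = session.step(hidden_states[:, i : i + 1, :])  # feed one token position per step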