inference_session.py

from __future__ import annotations

import asyncio
import itertools
import logging
import time
from typing import AsyncIterator, List, Optional

import torch
from hivemind import (
    P2P,
    MSGPackSerializer,
    anext,
    deserialize_torch_tensor,
    get_logger,
    nested_flatten,
    serialize_torch_tensor,
)
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.p2p import StubBase
from hivemind.proto import runtime_pb2

from src.client.sequence_manager import RemoteSequenceManager
from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
from src.server.handler import TransformerConnectionHandler
from src.utils.misc import DUMMY, is_dummy

logger = get_logger(__file__)

class _ServerInferenceSession:
    """
    An interface to a single multi-step *inference* session for a set of blocks on a specific server.

    :note: This class is *not* fault-tolerant out of the box.
    """

    def __init__(
        self,
        uid: ModuleUID,
        rpc_info: RPCInfo,
        inputs_queue: asyncio.Queue,
        outputs_aiter: AsyncIterator,
        *,
        timeout: float,
        max_length: int,
        points: int = 0,
    ):
        self.uid, self.rpc_info = uid, rpc_info
        self.num_blocks = uid.count(CHAIN_DELIMITER) + 1
        self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
        self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
        self.timeout = timeout
        self._serialized_metadata = MSGPackSerializer.dumps(dict(max_length=max_length, points=points))
        self.stepped = False
        self.closed = False

    @classmethod
    async def create(
        cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: float, **metadata
    ) -> _ServerInferenceSession:
        """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
        inputs_queue = asyncio.Queue()
        outputs_stream = await asyncio.wait_for(
            stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)),
            timeout,
        )
        return cls(uid, rpc_info, inputs_queue, outputs_stream, timeout=timeout, **metadata)

    @staticmethod
    async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[float] = None) -> AsyncIterator:
        while True:
            next_input_message = await asyncio.wait_for(queue.get(), input_timeout)
            yield next_input_message
            if not next_input_message.uid and not next_input_message.tensors:
                break  # this message means "done sending"

    def step(
        self,
        new_hidden_states: torch.Tensor,
        prompts: Optional[torch.Tensor] = None,
        hypo_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Inference step: send a chunk of input tensors and receive a chunk of outputs.

        :param prompts: optional DEEP prompts, added to a prefix of each layer's outputs;
          if specified, deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size]
        """
        if self.closed:
            raise Exception("Session is closed, cannot perform step")

        if prompts is None or is_dummy(prompts):
            prompts = DUMMY
        else:
            assert prompts.ndim == 4, "deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size]"
            assert prompts.shape[0] == self.num_blocks
            assert prompts.shape[1] in (new_hidden_states.shape[0], 1)
            assert prompts.shape[2] <= new_hidden_states.shape[1]
            assert prompts.shape[3] == new_hidden_states.shape[2]

        if hypo_ids is None or is_dummy(hypo_ids):
            hypo_ids = DUMMY
        else:
            assert len(hypo_ids) == len(new_hidden_states)
            assert hypo_ids.dtype == torch.int64

        # serialize inputs and put them into the queue
        inputs = (new_hidden_states, prompts, hypo_ids)
        outputs_serialized = RemoteExpertWorker.run_coroutine(
            self._step(
                runtime_pb2.ExpertRequest(
                    uid=self.uid,
                    tensors=[
                        serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
                        for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["inference_schema"]))
                    ],
                    metadata=self._serialized_metadata if not self.stepped else None,
                )
            )
        )
        outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors))
        assert outputs[0].shape == inputs[0].shape, f"expected outputs[0] to be hidden states but got {outputs[0]}"
        return outputs[0]

    async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
        """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
        await self._inputs_queue.put(inputs_serialized)
        self.stepped = True
        return await asyncio.wait_for(anext(self._outputs_stream), self.timeout)

    def close(self):
        """Finish a given inference session, close the underlying connection"""
        if self._outputs_stream is None:
            return  # already closed
        RemoteExpertWorker.run_coroutine(self._aclose_stream())
        self._outputs_stream = self._inputs_queue = None
        self.closed = True

    async def _aclose_stream(self):
        """Close the inference session. This code is meant to be run inside RemoteExpertWorker"""
        if self._outputs_stream is None:
            return  # already closed
        if self.stepped:
            await self._inputs_queue.put(runtime_pb2.ExpertRequest())  # empty request will trigger end of session
            try:
                await anext(self._outputs_stream)
            except StopAsyncIteration:
                pass

    def __del__(self):
        self.close()

    def __enter__(self):
        assert not self.closed
        return self

    def __exit__(self, *exc_details):
        self.close()
class InferenceSession:
    """
    An interface to a multi-step *inference* session for a sequence of remote transformer blocks
    """

    def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, max_length: int, **metadata):
        self._sequence_manager = sequence_manager
        self._p2p = p2p
        self._closed = False
        self._chosen_spans = []
        self._server_sessions = []
        self._server_inputs = []  # Used in case of server failures to regenerate attention caches on new servers
        self._position = 0
        self._max_length = max_length
        self._metadata = metadata

    def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[_ServerInferenceSession]:
        server_sessions = []
        try:
            for span in chosen_spans:
                stub = TransformerConnectionHandler.get_stub(self._p2p, span.peer_id)
                span_uids = CHAIN_DELIMITER.join(self._sequence_manager.block_uids[span.start : span.end])
                session = RemoteExpertWorker.run_coroutine(
                    _ServerInferenceSession.create(
                        stub,
                        span_uids,
                        rpc_info=self._sequence_manager.rpc_info,
                        timeout=self._sequence_manager.timeout,
                        max_length=self._max_length,
                        **self._metadata,
                    )
                )
                server_sessions.append(session)
                session.__enter__()
            return server_sessions
        except:
            self._exit_server_sessions(server_sessions)
            raise

    def _exit_server_sessions(self, server_sessions: List[_ServerInferenceSession]) -> None:
        for session in reversed(server_sessions):
            try:
                session.__exit__(None, None, None)
            except Exception:
                logger.debug("Caught exception while closing connection to server:", exc_info=True)

    def __enter__(self) -> "InferenceSession":
        assert not self._closed and not self._chosen_spans
        return self

    def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
        assert not self._closed
        if torch.is_grad_enabled():
            logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")

        n_blocks = len(self._sequence_manager)
        if prompts is None or is_dummy(prompts):
            prompts = DUMMY
        else:
            assert prompts.ndim == 4 and prompts.shape[0] == n_blocks

        inputs_device = inputs.device
        inputs_dtype = inputs.dtype
        inputs = inputs.cpu()
        prompts = prompts.cpu()

        n_input_tokens = inputs.shape[1]
        if self._position + n_input_tokens > self._max_length:
            raise ValueError(
                f"Maximum length exceeded: prefix {self._position} + current {n_input_tokens} "
                f"exceeds pre-allocated maximum {self._max_length}"
            )

        server_idx = 0
        block_idx = 0
        recovery_until = -1  # Recovery mode is disabled until a failure happens
        while block_idx < n_blocks:
            for attempt_no in itertools.count():
                logger.debug(f"Inference: block {block_idx}, attempt {attempt_no}")
                try:
                    if attempt_no >= 1:
                        self._sequence_manager.update_()
                    if not self._chosen_spans or not self._server_sessions or attempt_no >= 1:
                        # If there is a failed server session, this code closes it
                        self._exit_server_sessions(self._server_sessions[server_idx : server_idx + 1])

                        n_prev_spans = len(self._chosen_spans)
                        update_end = self._chosen_spans[server_idx].end if server_idx < n_prev_spans else n_blocks
                        if attempt_no >= 1 and update_end > recovery_until:
                            logger.info(
                                f"Due to a server failure, remote attention caches "
                                f"from block {block_idx} to {update_end} will be regenerated"
                            )
                        recovery_until = max(recovery_until, update_end)

                        updated_spans = self._sequence_manager.make_sequence(block_idx, update_end)
                        # make_sequence() could return a longer sequence
                        updated_spans[-1].end = min(updated_spans[-1].end, update_end)
                        updated_sessions = self._enter_server_sessions(updated_spans)
                        logger.debug(
                            f"Found path from block {block_idx} to {update_end} via {len(updated_spans)} servers"
                        )

                        # If there is a failed span, this code replaces it, otherwise it just adds new ones
                        self._chosen_spans[server_idx : server_idx + 1] = updated_spans
                        self._server_sessions[server_idx : server_idx + 1] = updated_sessions
                        recovery_inputs = self._server_inputs[server_idx] if server_idx < n_prev_spans else None
                        self._server_inputs[server_idx : server_idx + 1] = [recovery_inputs] + [None] * (
                            len(updated_spans) - 1
                        )
                        assert len(self._chosen_spans) == len(self._server_sessions) == len(self._server_inputs), (
                            f"Broken state: {len(self._chosen_spans)} spans, {len(self._server_sessions)} sessions, "
                            f"{len(self._server_inputs)} inputs"
                        )

                    session = self._server_sessions[server_idx]
                    span = self._chosen_spans[server_idx]

                    if self._server_inputs[server_idx] is None:
                        self._server_inputs[server_idx] = inputs
                    elif self._server_inputs[server_idx].shape[1] == self._position:
                        self._server_inputs[server_idx] = torch.cat(
                            [self._server_inputs[server_idx], inputs[:, -n_input_tokens:]], dim=1
                        )
                    assert self._server_inputs[server_idx].shape[1] == self._position + n_input_tokens, (
                        f"Broken input cache: server_idx={server_idx} shape={self._server_inputs[server_idx].shape} "
                        f"position={self._position} n_input_tokens={n_input_tokens}"
                    )

                    if not session.stepped:
                        inputs = self._server_inputs[server_idx]  # Pass full inputs including prefix
                    else:
                        inputs = inputs[:, -n_input_tokens:]  # No need to pass prefix further

                    outputs = session.step(inputs, prompts[span.start : span.end], **kwargs)
                    assert (
                        inputs.shape == outputs.shape
                    ), f"Shape mismatch: inputs.shape={inputs.shape}, outputs.shape={outputs.shape}"

                    inputs = outputs
                    server_idx += 1
                    block_idx = span.end
                    break
                except Exception as e:
                    delay = self._sequence_manager.get_retry_delay(attempt_no)
                    logger.warning(
                        f"Caught exception when running inference from block {block_idx} "
                        f"(retry in {delay:.0f} sec): {repr(e)}"
                    )
                    traceback_level = logging.DEBUG if str(e) else logging.WARNING
                    logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
                    time.sleep(delay)

        self._position += n_input_tokens
        outputs = inputs.to(device=inputs_device, dtype=inputs_dtype)
        return outputs

    def close(self, *exc_details):
        """Finish a given inference session, close the underlying connection"""
        if not self._closed:
            self._server_inputs.clear()
            self._exit_server_sessions(self._server_sessions)
            self._server_sessions.clear()
            self._chosen_spans.clear()
            self._closed = True

    def __exit__(self, *exc_details):
        self.close(*exc_details)

    def __del__(self):
        self.close()
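
# Illustrative usage sketch for InferenceSession. Here `sequence_manager` and `p2p` are
# assumed to be created elsewhere by the client-side wrapper that owns the remote sequence,
# and `embed` / `sample_next_token` are hypothetical helpers shown only to outline the loop:
#
#     with InferenceSession(sequence_manager, p2p, max_length=2048) as sess:
#         hidden = sess.step(prefix_embeddings)                          # [batch, prefix_len, hid_size]
#         for _ in range(max_new_tokens):
#             next_embedding = embed(sample_next_token(hidden[:, -1]))   # [batch, 1, hid_size]
#             hidden = sess.step(next_embedding)                         # reuses server-side attention caches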