handler.py

import contextlib
from typing import AsyncIterator, Dict, List, Optional, Sequence, Union

import torch
from hivemind import (
    DHT,
    MSGPackSerializer,
    P2PContext,
    TensorDescriptor,
    deserialize_torch_tensor,
    nested_flatten,
    serialize_torch_tensor,
)
from hivemind.moe.server.connection_handler import ConnectionHandler
from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
from hivemind.proto import runtime_pb2
from hivemind.utils import as_aiter
from hivemind.utils.asyncio import anext
from hivemind.utils.streaming import split_for_streaming

from src.data_structures import CHAIN_DELIMITER, ModuleUID
from src.server.backend import TransformerBackend
from src.utils.misc import DUMMY, is_dummy


class TransformerConnectionHandler(ConnectionHandler):
    """Handles three request types: forward, backward and forward-incremental (inference)"""

    module_backends: Dict[ModuleUID, TransformerBackend]

    def __init__(self, dht: DHT, module_backends: Dict[str, TransformerBackend], inference_max_length: int):
        super().__init__(dht, module_backends)
        for module_backend in self.module_backends.values():
            assert isinstance(module_backend, TransformerBackend)
        self.inference_max_length = inference_max_length

    async def rpc_inference(
        self,
        requests: AsyncIterator[runtime_pb2.ExpertRequest],
        context: P2PContext,
    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
        """Compute a single step of inference using attention cache; update attention cache accordingly."""
        try:
            print("OPENED RPC_INFERENCE")
            request = await anext(requests)
            requested_uids = self._check_uids(request.uid)
            metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
            requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
            max_length = metadata.get("max_length")

            if not requested_uids:
                raise ValueError("User must specify at least one block for inference, but got none")
            assert isinstance(max_length, int), f"rpc_inference metadata must contain int max_length, got {max_length}"
            if not 0 <= max_length <= self.inference_max_length:
                raise ValueError(f"Cannot allocate KV cache for {max_length} tokens, max = {self.inference_max_length}")

            batch_size = request.tensors[0].size[0] if request.tensors else 1
            cache_metadata = torch.tensor(
                [[-1, -1] for _ in range(batch_size)], dtype=torch.int64
            )  # [cache_handle, prefix_length]
            prefix_length = 0
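
            # Allocate one attention cache per requested block, then serve inference steps in a loop
            # until the client sends a request without tensors (see `while request.tensors` below)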
            async with self._allocate_caches(requested_backends, batch_size, max_length) as cache_handles:
                assert len(cache_handles) == len(requested_backends)
                while request.tensors:  # iterate while user is willing to supply tensors
                    hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
                    length_increment = hidden_states[0].shape[1]  # how many tokens are added this step (in each seq)
                    if prefix_length + length_increment > max_length:
                        raise ValueError(
                            f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
                            f" exceeds pre-allocated maximum {max_length}"
                        )

                    # Cast inputs to backend dtype
                    hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]

                    # Run request tensors through all requested modules, update caches
                    for backend, cache_handle in zip(requested_backends, cache_handles):
                        cache_metadata[:, 0], cache_metadata[:, 1] = cache_handle, prefix_length
                        assert (
                            len(hidden_states) == 1 and hidden_states[0].ndim == 3
                        ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
                        hidden_states = await backend.inference_pool.submit_task(cache_metadata, *hidden_states)
                        assert isinstance(hidden_states, (list, tuple))
                        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3

                    # Serialize and send last layer outputs
                    yield runtime_pb2.ExpertResponse(
                        tensors=[
                            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
                            for result, proto in zip(
                                hidden_states, nested_flatten(requested_backends[-1].outputs_schema)
                            )
                        ]
                    )

                    # Prepare for the next step
                    prefix_length += hidden_states[0].shape[1]
                    request = await anext(requests)
        finally:
            print("CLOSED RPC_INFERENCE")

    async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
        # Parse request and prepare backends
        flat_inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
        requested_uids = self._check_uids(request.uid)
        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

        hidden_states = await _rpc_forward(*flat_inputs, requested_backends=requested_backends)
        assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3

        # Serialize output and respond to client
        return runtime_pb2.ExpertResponse(
            tensors=[
                serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
                for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
            ]
        )

    async def rpc_forward_stream(
        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
        # Parse requests and prepare backends
        uid_str, flat_inputs = await self._gather_inputs(requests, context)
        requested_uids = self._check_uids(uid_str)
        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

        hidden_states = await _rpc_forward(*flat_inputs, requested_backends=requested_backends)
        assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3, "hidden_states must be a 3d tensor"

        # Serialize the overall output
        serialized_output = [
            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
            for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
        ]

        # Split the serialized_output for streaming and respond to client
        output_split = [
            part for tensor in serialized_output for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
        ]
        async for part in as_aiter(*output_split):
            yield runtime_pb2.ExpertResponse(tensors=[part])

    async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
        # Parse requests and prepare backends
        flat_tensors = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
        requested_uids = self._check_uids(request.uid)
        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

        grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends)

        # Modify grad_inputs_schema to support grad_prompts
        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
        grad_inputs_schema_with_prompts = (
            requested_backends[0].args_schema * len(grads),
            requested_backends[0].kwargs_schema,
        )  # TODO generalize

        # Serialize the overall grad_input and respond
        return runtime_pb2.ExpertResponse(
            tensors=[
                serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
                for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
            ]
        )

    async def rpc_backward_stream(
        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
        uids_header, flat_tensors = await self._gather_inputs(requests, context)
        requested_uids = self._check_uids(uids_header)
        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

        grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends)

        # Modify grad_inputs_schema to support grad_prompts
        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
        grad_inputs_schema_with_prompts = (
            requested_backends[0].args_schema * len(grads),
            requested_backends[0].kwargs_schema,
        )  # TODO generalize

        # Serialize the overall grad_inputs
        serialized_grad_inputs = [
            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
            for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
        ]

        # Split the serialized_grad_inputs for streaming and respond
        output_split = [
            part for tensor in serialized_grad_inputs for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
        ]
        async for part in as_aiter(*output_split):
            yield runtime_pb2.ExpertResponse(tensors=[part])

    def _check_uids(self, uids: str) -> Sequence[ModuleUID]:
        """Check that the requested chain of module uids is non-empty and that every uid is served by this server"""
        uids = (uids or "").split(CHAIN_DELIMITER)
        if not uids:
            raise RuntimeError("User did not provide any uids")
        for uid in uids:
            if uid not in self.module_backends:
                raise RuntimeError(f"Remote peer does not serve {uid}")
        return tuple(uids)

    @contextlib.asynccontextmanager
    async def _allocate_caches(
        self, backends: Sequence[TransformerBackend], batch_size: int, max_length: int
    ) -> AsyncIterator[Sequence[int]]:
        """Allocate memory caches for each transformer block, yield cache handles"""
        async with contextlib.AsyncExitStack() as stack:
            handles = []
            for backend in backends:
                num_heads = backend.module.self_attention.num_heads
                head_dim = backend.module.self_attention.head_dim

                cache_descriptor = TensorDescriptor(
                    size=(2, batch_size, max_length, num_heads, head_dim), dtype=backend.dtype
                )
                # [key_or_value, batch_size, max_length, num_heads, head_dim]

                handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(cache_descriptor)))

            yield handles


async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend]) -> torch.Tensor:
    """
    Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream

    :param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors
    :note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy)
    :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
    :returns: hidden states after the last layer [batch_size, seq_length, hid_size]
    """
    hidden_states, *prompts = flat_tensors
    dtype = requested_backends[0].dtype

    # Check input tensors and cast dtypes
    hidden_states = hidden_states.to(dtype)
    assert hidden_states.ndim == 3
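
    # If a real prompt tensor was sent, it apparently stacks one prompt per requested block:
    # shape [num_backends, batch_size, pre_seq_len, hid_size] (inferred from the split/squeeze below)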
    if not prompts or is_dummy(prompts[0]):
        prompts = [DUMMY] * len(requested_backends)
        pre_seq_len = 0
    else:
        prompts = [prompts[0].to(requested_backends[0].dtype)]
        prompts = [p.squeeze(0) for p in prompts[0].split(1)]
        pre_seq_len = prompts[0].shape[-2]

    # Run a chain of requested backends
    for backend, prompt in zip(requested_backends, prompts):
        if not is_dummy(prompt):
            hidden_states[:, :pre_seq_len] += prompt
        (hidden_states,) = await backend.forward_pool.submit_task(hidden_states)
        assert isinstance(hidden_states, torch.Tensor)
        assert (
            hidden_states.ndim == 3
        ), f"inputs to {type(backend)} must be a single 3d tensor of hidden states"

    return hidden_states


async def _rpc_backward(
    *flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend]
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
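    """
    Run backward pass on deserialized inputs and grad outputs, used by rpc_backward and rpc_backward_stream

    :param flat_tensors: first layer inputs, gradients w.r.t. the last layer outputs and an optional stack of prompts
        (one per block); a missing prompt stack is replaced with dummy tensors (see is_dummy)
    :param requested_backends: a sequence of transformer blocks in the same order as they appear in the forward pass
    :returns: [grad_inputs] if no prompts were given, otherwise [grad_inputs, grad_prompts]
    """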
    inputs, grad_outputs, *prompts = flat_tensors

    # Cast inputs & grad outputs to backend dtype
    inputs = inputs.to(requested_backends[0].dtype)
    grad_outputs = grad_outputs.to(requested_backends[-1].dtype)

    if not prompts or is_dummy(prompts[0]):
        prompts = [DUMMY] * len(requested_backends)
        pre_seq_len = 0
    else:
        prompts = [prompts[0].to(requested_backends[0].dtype)]
        prompts = [p.squeeze(0) for p in prompts[0].split(1)]
        pre_seq_len = prompts[0].shape[-2]

    # Run a forward chain to collect intermediate inputs
    # Note that we do not forward for the last module since we do not need its output
    inter_inputs = []
    for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
        assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
        if not is_dummy(prompt):
            inputs[:, :pre_seq_len] += prompt
        inter_inputs.append(inputs)
        (inputs,) = await backend.forward_pool.submit_task(inputs)
        assert isinstance(inputs, torch.Tensor)

    if not is_dummy(prompts[-1]):
        inputs[:, :pre_seq_len] += prompts[-1]
    inter_inputs.append(inputs)

    assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
    grad_prompts_reversed = []

    # Run the requested backends backward, in reverse order, collecting prompt gradients along the way
    for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs)
        assert isinstance(grad_outputs, torch.Tensor)
        if not is_dummy(prompt):
            grad_prompts_reversed.append(grad_outputs[:, :pre_seq_len].unsqueeze(0))

    grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
    return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]  # TODO un-duct-tape