import contextlib
from typing import AsyncIterator, Dict, List, Optional, Sequence, Union

import torch
from hivemind import (
    DHT,
    MSGPackSerializer,
    P2PContext,
    TensorDescriptor,
    deserialize_torch_tensor,
    nested_flatten,
    serialize_torch_tensor,
)
from hivemind.moe.server.connection_handler import ConnectionHandler
from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
from hivemind.proto import runtime_pb2
from hivemind.utils import as_aiter
from hivemind.utils.asyncio import anext
from hivemind.utils.streaming import split_for_streaming

from src.data_structures import CHAIN_DELIMITER, ModuleUID
from src.server.backend import MAX_LENGTH, TransformerBackend
from src.utils.misc import DUMMY, is_dummy


class TransformerConnectionHandler(ConnectionHandler):
    """Handles three request types: forward, backward and forward-incremental (inference)"""

    module_backends: Dict[ModuleUID, TransformerBackend]

    def __init__(self, dht: DHT, module_backends: Dict[str, TransformerBackend]):
        super().__init__(dht, module_backends)
        for module_backend in self.module_backends.values():
            assert isinstance(module_backend, TransformerBackend)

    async def rpc_inference(
        self,
        requests: AsyncIterator[runtime_pb2.ExpertRequest],
        context: P2PContext,
    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
        """Compute a single step of inference using attention cache; update attention cache accordingly."""
        try:
            print("OPENED RPC_INFERENCE")
            request = await anext(requests)
            requested_uids = self._check_uids(request.uid)
            requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

            batch_size = request.tensors[0].size[0] if request.tensors else 1

            cache_metadata = torch.tensor(
                [[-1, -1] for _ in range(batch_size)], dtype=torch.int64
            )  # [cache_handle, prefix_length]
            prefix_length = 0

            async with self._allocate_caches(requested_backends, batch_size) as cache_handles:
                assert len(cache_handles) == len(requested_backends)
                while request.tensors:  # iterate while user is willing to supply tensors
                    hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]

                    # Cast inputs to backend dtype
                    hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]

                    # Run request tensors through all requested modules, update caches
                    for backend, cache_handle in zip(requested_backends, cache_handles):
                        cache_metadata[:, 0], cache_metadata[:, 1] = cache_handle, prefix_length
                        assert (
                            len(hidden_states) == 1 and hidden_states[0].ndim == 3
                        ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"

                        hidden_states = await backend.inference_pool.submit_task(cache_metadata, *hidden_states)
                        assert isinstance(hidden_states, (list, tuple))
                        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3

                    # Serialize and send last layer outputs
                    yield runtime_pb2.ExpertResponse(
                        tensors=[
                            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
                            for result, proto in zip(
                                hidden_states, nested_flatten(requested_backends[-1].outputs_schema)
                            )
                        ]
                    )

                    # Prepare for next step
                    prefix_length += hidden_states[0].shape[1]
                    request = await anext(requests)
        finally:
            print("CLOSED RPC_INFERENCE")
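
    # Worked example of the loop above (illustrative numbers only): for batch_size=2, the first
    # request carrying a [2, 5, hid_size] hidden-state tensor runs every backend with cache_metadata
    # rows [cache_handle, 0]; prefix_length then becomes 5, so a follow-up request of shape
    # [2, 1, hid_size] runs with rows [cache_handle, 5], letting each backend append the new tokens
    # to its attention cache right after the already-processed prefix.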

    async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
        # Parse request and prepare backends
        flat_inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
        requested_uids = self._check_uids(request.uid)
        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

        hidden_states = await _rpc_forward(*flat_inputs, requested_backends=requested_backends)
        assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3

        # Serialize output and respond to client
        return runtime_pb2.ExpertResponse(
            tensors=[
                serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
                for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
            ]
        )

    async def rpc_forward_stream(
        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
        # Parse requests and prepare backends
        uid_str, flat_inputs = await self._gather_inputs(requests, context)
        requested_uids = self._check_uids(uid_str)
        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

        hidden_states = await _rpc_forward(*flat_inputs, requested_backends=requested_backends)
        assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3

        # Serialize the overall output
        serialized_output = [
            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
            for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
        ]

        # Split the serialized_output for streaming and respond to client
        output_split = [
            part for tensor in serialized_output for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
        ]
        async for part in as_aiter(*output_split):
            yield runtime_pb2.ExpertResponse(tensors=[part])
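
    # Note on the streaming path above: split_for_streaming() splits each serialized tensor into
    # chunks of at most DEFAULT_MAX_MSG_SIZE bytes, and every chunk is yielded as a separate
    # ExpertResponse; the client side is expected to reassemble and deserialize them (e.g. with
    # hivemind.utils.streaming.combine_from_streaming).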

    async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
        # Parse requests and prepare backends
        flat_tensors = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
        requested_uids = self._check_uids(request.uid)
        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

        grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends)

        # Modify grad_inputs_schema to support grad_prompts
        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
        grad_inputs_schema_with_prompts = (
            requested_backends[0].args_schema * len(grads),
            requested_backends[0].kwargs_schema,
        )  # TODO generalize

        # Serialize the overall grad_input and respond
        return runtime_pb2.ExpertResponse(
            tensors=[
                serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
                for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
            ]
        )

    async def rpc_backward_stream(
        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
        uids_header, flat_tensors = await self._gather_inputs(requests, context)
        requested_uids = self._check_uids(uids_header)
        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

        grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends)

        # Modify grad_inputs_schema to support grad_prompts
        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
        grad_inputs_schema_with_prompts = (
            requested_backends[0].args_schema * len(grads),
            requested_backends[0].kwargs_schema,
        )  # TODO generalize

        # Serialize the overall grad_inputs
        serialized_grad_inputs = [
            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
            for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
        ]
        # Split the serialized_grad_inputs for streaming and respond
        output_split = [
            part for tensor in serialized_grad_inputs for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
        ]
        async for part in as_aiter(*output_split):
            yield runtime_pb2.ExpertResponse(tensors=[part])

    def _check_uids(self, uids: str) -> Sequence[ModuleUID]:
        """Check that the requested module UIDs are valid and that this server hosts each of them"""
        uids = (uids or "").split(CHAIN_DELIMITER)
        if not uids:
            raise RuntimeError("User did not provide any uids")
        for uid in uids:
            if uid not in self.module_backends:
                raise RuntimeError(f"Remote peer does not serve {uid}")
        return tuple(uids)
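
    # For reference, `uids` arrives as a single string of module UIDs joined by CHAIN_DELIMITER,
    # e.g. something like "bloom.3 bloom.4 bloom.5" (illustrative names; the actual UID format and
    # delimiter are defined in src.data_structures).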

    @contextlib.asynccontextmanager
    async def _allocate_caches(self, backends: Sequence[TransformerBackend], batch_size: int) -> Sequence[int]:
        """Allocate memory caches for each transformer block, return cache handles"""
        async with contextlib.AsyncExitStack() as stack:
            handles = []
            for backend in backends:
                num_heads = backend.module.self_attention.num_heads
                head_dim = backend.module.self_attention.head_dim

                cache_descriptor = TensorDescriptor(
                    size=(2, batch_size, MAX_LENGTH, num_heads, head_dim), dtype=backend.dtype
                )
                # [key_or_value, batch_size, max_length, num_heads, head_dim]

                handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(cache_descriptor)))

            yield handles
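
# Back-of-the-envelope sizing for the caches allocated in _allocate_caches (illustrative values, not
# taken from this file): with a hypothetical MAX_LENGTH of 2048 tokens, 112 attention heads,
# head_dim 128 and a 2-byte dtype such as bfloat16, one block's key/value cache for batch_size=1 is
#     2 * 1 * 2048 * 112 * 128 * 2 bytes ≈ 112 MiB,
# growing linearly with batch_size and with the number of requested blocks.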


async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend]) -> torch.Tensor:
    """
    Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream

    :param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors
    :note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy)
    :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
    :returns: hidden states after the last layer [batch_size, seq_length, hid_size]
    """
    hidden_states, *prompts = flat_tensors
    dtype = requested_backends[0].dtype

    # Check and parse input tensors, cast dtypes
    hidden_states = hidden_states.to(dtype)
    assert hidden_states.ndim == 3
    if not prompts or is_dummy(prompts[0]):
        prompts = [DUMMY] * len(requested_backends)
        pre_seq_len = 0
    else:
        prompts = [prompts[0].to(requested_backends[0].dtype)]
        prompts = [p.squeeze(0) for p in prompts[0].split(1)]
        pre_seq_len = prompts[0].shape[-2]

    # Run a chain of requested backends
    for backend, prompt in zip(requested_backends, prompts):
        if not is_dummy(prompt):
            hidden_states[:, :pre_seq_len] += prompt
        (hidden_states,) = await backend.forward_pool.submit_task(hidden_states)
        assert isinstance(hidden_states, torch.Tensor)
        assert (
            hidden_states.ndim == 3
        ), f"outputs of {type(backend)} must be a single 3d tensor of hidden states"

    return hidden_states
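
# Shape convention for the optional prompts handled in _rpc_forward (illustrative, inferred from the
# code above): the client sends a single prompt tensor of shape
# [num_backends, batch_size, pre_seq_len, hid_size]; .split(1) + .squeeze(0) yield one
# [batch_size, pre_seq_len, hid_size] prefix per requested block, which is added to the first
# pre_seq_len positions of the hidden states before that block runs.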


async def _rpc_backward(
    *flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend]
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
    """
    Run backward pass on deserialized inputs, grad outputs and optional prompts; used by rpc_backward
    and rpc_backward_stream. Returns gradients w.r.t. inputs and, if prompts were sent, w.r.t. prompts.
    """
    inputs, grad_outputs, *prompts = flat_tensors

    # Cast inputs & grad outputs to backend dtype
    inputs = inputs.to(requested_backends[0].dtype)
    grad_outputs = grad_outputs.to(requested_backends[-1].dtype)

    if not prompts or is_dummy(prompts[0]):
        prompts = [DUMMY] * len(requested_backends)
        pre_seq_len = 0
    else:
        prompts = [prompts[0].to(requested_backends[0].dtype)]
        prompts = [p.squeeze(0) for p in prompts[0].split(1)]
        pre_seq_len = prompts[0].shape[-2]

    # Run a forward chain to collect intermediate inputs
    # Note that we do not forward for the last module since we do not need its output
    inter_inputs = []
    for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
        assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
        if not is_dummy(prompt):
            inputs[:, :pre_seq_len] += prompt
        inter_inputs.append(inputs)
        (inputs,) = await backend.forward_pool.submit_task(inputs)
        assert isinstance(inputs, torch.Tensor)

    if not is_dummy(prompts[-1]):
        inputs[:, :pre_seq_len] += prompts[-1]
    inter_inputs.append(inputs)

    assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
    grad_prompts_reversed = []

    # Run the backward chain over requested backends in reverse order
    for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs)
        assert isinstance(grad_outputs, torch.Tensor)
        if not is_dummy(prompt):
            grad_prompts_reversed.append(grad_outputs[:, :pre_seq_len].unsqueeze(0))

    grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
    return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]  # TODO un-duct-tape
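
# Return convention of _rpc_backward, as consumed by rpc_backward / rpc_backward_stream above:
#   * no prompts sent -> [grad_inputs], a single [batch_size, seq_length, hid_size] tensor;
#   * prompts sent    -> [grad_inputs, grad_prompts], where grad_prompts concatenates the per-block
#     prompt gradients along dim 0, i.e. roughly [num_prompted_blocks, batch_size, pre_seq_len, hid_size].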