handler.py

import contextlib
from typing import AsyncIterator, Dict, Sequence

import torch
from hivemind import DHT, P2PContext, TensorDescriptor, deserialize_torch_tensor, nested_flatten, serialize_torch_tensor
from hivemind.moe.server.connection_handler import ConnectionHandler
from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
from hivemind.proto import runtime_pb2
from hivemind.utils import as_aiter
from hivemind.utils.asyncio import anext
from hivemind.utils.streaming import split_for_streaming

from src.data_structures import CHAIN_DELIMITER, ModuleUID
from src.server.backend import MAX_LENGTH, TransformerBackend
from src.utils.misc import DUMMY, is_dummy


class TransformerConnectionHandler(ConnectionHandler):
    """Handles three request types: forward, backward and forward-incremental (inference)"""

    module_backends: Dict[ModuleUID, TransformerBackend]

    def __init__(self, dht: DHT, module_backends: Dict[str, TransformerBackend]):
        super().__init__(dht, module_backends)
        for module_backend in self.module_backends.values():
            assert isinstance(module_backend, TransformerBackend)

    async def rpc_inference(
        self,
        requests: AsyncIterator[runtime_pb2.ExpertRequest],
        context: P2PContext,
    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
        """Compute a single step of inference using attention cache; update attention cache accordingly."""
        try:
            print("OPENED RPC_INFERENCE")
            request = await anext(requests)
            requested_uids = self._check_header(request)
            requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

            batch_size = request.tensors[0].size[0] if request.tensors else 1
            cache_metadata = torch.tensor(
                [[-1, -1] for _ in range(batch_size)], dtype=torch.int64
            )  # [cache_handle, prefix_length]
            prefix_length = 0

            async with self._allocate_caches(requested_backends, batch_size) as cache_handles:
                assert len(cache_handles) == len(requested_backends)
                while request.tensors:  # iterate while user is willing to supply tensors
                    hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]

                    # Cast inputs to backend dtype
                    hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]

                    # Run request tensors through all requested modules, update caches
                    for backend, cache_handle in zip(requested_backends, cache_handles):
                        cache_metadata[:, 0], cache_metadata[:, 1] = cache_handle, prefix_length
                        assert (
                            len(hidden_states) == 1 and hidden_states[0].ndim == 3
                        ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"

                        hidden_states = await backend.inference_pool.submit_task(cache_metadata, *hidden_states)
                        assert isinstance(hidden_states, (list, tuple))
                        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3

                    # Serialize and send the last layer's outputs
                    yield runtime_pb2.ExpertResponse(
                        tensors=[
                            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
                            for result, proto in zip(
                                hidden_states, nested_flatten(requested_backends[-1].outputs_schema)
                            )
                        ]
                    )

                    # Prepare for the next step
                    prefix_length += hidden_states[0].shape[1]
                    request = await anext(requests)
        finally:
            print("CLOSED RPC_INFERENCE")

    async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
        # Parse request and prepare backends
        inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
        requested_uids = self._check_header(request)
        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

        hidden_states = await _rpc_forward(inputs, requested_backends)

        # Serialize the overall output and respond
        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
        return runtime_pb2.ExpertResponse(
            tensors=[
                serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
                for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
            ]
        )

    async def rpc_forward_stream(
        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
        # Parse requests and prepare backends
        uids_header, inputs = await self._gather_inputs(requests, context)
        requested_uids = self._check_header_str(uids_header)
        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

        hidden_states = await _rpc_forward(inputs, requested_backends)

        # Serialize the overall output
        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
        serialized_output = [
            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
            for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
        ]

        # Split the serialized output for streaming and respond
        output_split = [
            part for tensor in serialized_output for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
        ]
        async for part in as_aiter(*output_split):
            yield runtime_pb2.ExpertResponse(tensors=[part])
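
    # NOTE: rpc_forward_stream exists for payloads larger than one libp2p message. Serialized tensors
    # are cut into parts of at most DEFAULT_MAX_MSG_SIZE bytes with split_for_streaming and sent one
    # part per ExpertResponse; the receiver is expected to reassemble them (hivemind provides a
    # matching combine_from_streaming helper for that).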

    async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
        # Parse requests and prepare backends
        inputs, prompts, grad_outputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
        requested_uids = self._check_header(request)
        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

        grads = await _rpc_backward(inputs, prompts, grad_outputs, requested_backends)

        # Modify grad_inputs_schema to support grad_prompts
        assert len(requested_backends[0].args_schema) == 1 and len(grads) == 2
        grad_inputs_schema_with_prompts = (
            requested_backends[0].args_schema * len(grads),
            requested_backends[0].kwargs_schema,
        )

        # Serialize the overall grad_input and respond
        return runtime_pb2.ExpertResponse(
            tensors=[
                serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
                for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
            ]
        )
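
    # NOTE: `grads` holds exactly two tensors: the gradient w.r.t. the input hidden states and the
    # gradient w.r.t. the prompts (or a dummy tensor when no prompts were used). Because args_schema
    # describes a single hidden-states tensor, multiplying it by len(grads) == 2 reuses the same
    # descriptor (dtype/compression) for both gradients, e.g. it would expand
    # (hidden_schema,) into ((hidden_schema, hidden_schema), kwargs_schema).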

    async def rpc_backward_stream(
        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
        uids_header, (inputs, prompts, grad_outputs) = await self._gather_inputs(requests, context)
        requested_uids = self._check_header_str(uids_header)
        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

        grads = await _rpc_backward(inputs, prompts, grad_outputs, requested_backends)

        # Modify grad_inputs_schema to support grad_prompts
        assert len(requested_backends[0].args_schema) == 1 and len(grads) == 2
        grad_inputs_schema_with_prompts = (
            requested_backends[0].args_schema * len(grads),
            requested_backends[0].kwargs_schema,
        )

        # Serialize the overall grad_inputs
        serialized_grad_inputs = [
            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
            for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
        ]
        # Split the serialized_grad_inputs for streaming and respond
        output_split = [
            part for tensor in serialized_grad_inputs for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
        ]
        async for part in as_aiter(*output_split):
            yield runtime_pb2.ExpertResponse(tensors=[part])

    def _check_header(self, request: runtime_pb2.ExpertRequest) -> Sequence[ModuleUID]:
        """Check that the chain of uids in the request is non-empty and fully served by this peer"""
        uids = (request.uid or "").split(CHAIN_DELIMITER)
        if not uids:
            raise RuntimeError("User did not provide any uids")
        for uid in uids:
            if uid not in self.module_backends:
                raise RuntimeError(f"Remote peer does not serve {uid}")
        return tuple(uids)

    def _check_header_str(self, header) -> Sequence[ModuleUID]:
        """Check that the chain of uids in a streaming request's header is non-empty and fully served by this peer"""
        uids = (header or "").split(CHAIN_DELIMITER)
        if not uids:
            raise RuntimeError("User did not provide any uids")
        for uid in uids:
            if uid not in self.module_backends:
                raise RuntimeError(f"Remote peer does not serve {uid}")
        return tuple(uids)
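
    # Illustrative example (the exact uid format and CHAIN_DELIMITER value are defined in
    # src.data_structures): a header naming three consecutive blocks, such as
    # "model.block3<DELIM>model.block4<DELIM>model.block5", resolves to a tuple of three
    # TransformerBackends that the rpc_* methods above run as a pipeline, in the given order.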

    @contextlib.asynccontextmanager
    async def _allocate_caches(self, backends: Sequence[TransformerBackend], batch_size: int) -> Sequence[int]:
        """Allocate memory caches for each transformer block, return cache handles"""
        async with contextlib.AsyncExitStack() as stack:
            handles = []
            for backend in backends:
                num_heads = backend.module.self_attention.num_heads
                head_dim = backend.module.self_attention.head_dim

                cache_descriptor = TensorDescriptor(
                    size=(2, batch_size, MAX_LENGTH, num_heads, head_dim), dtype=backend.dtype
                )  # [key_or_value, batch_size, max_length, num_heads, head_dim]

                handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(cache_descriptor)))

            yield handles
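
    # Memory math for the allocation above: each handle corresponds to one tensor of shape
    # [2 (keys/values), batch_size, MAX_LENGTH, num_heads, head_dim] in backend.dtype, i.e. roughly
    # 2 * batch_size * MAX_LENGTH * num_heads * head_dim * dtype_size bytes per requested block.
    # With purely illustrative numbers (batch_size=1, MAX_LENGTH=2048, num_heads=112, head_dim=128,
    # 16-bit dtype) that is about 112 MiB per block, which is why caches are allocated per inference
    # session and released by the AsyncExitStack as soon as rpc_inference finishes.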


async def _rpc_forward(inputs, requested_backends):
    # Cast inputs to backend dtype
    inputs = [tensor.to(requested_backends[0].dtype) for tensor in inputs]
    assert len(inputs) == 2 and inputs[0].ndim == 3
    hidden_states, prompts = inputs

    if is_dummy(prompts):
        prompts = [DUMMY] * len(requested_backends)
    else:
        pre_seq_len = prompts.shape[2]

    # Run a chain of requested backends
    for backend, prompt in zip(requested_backends, prompts):
        if not is_dummy(prompt):
            hidden_states[:, :pre_seq_len] += prompt

        (hidden_states,) = await backend.forward_pool.submit_task(hidden_states)
        assert isinstance(hidden_states, torch.Tensor)
        assert (
            hidden_states.ndim == 3
        ), f"outputs of {type(backend)} must be a single 3d tensor of hidden states"

    return [hidden_states]
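
# NOTE on prompt handling in _rpc_forward: when `prompts` is not a dummy tensor, it is expected to hold
# one prompt per requested block, each slice shaped [batch_size, pre_seq_len, hidden_size]; that prompt
# is added to the first pre_seq_len positions of the hidden states right before the corresponding block
# runs. This lets a client train prefix/prompt parameters while the block weights on the server stay
# frozen (an interpretation based on this file; the client side is implemented elsewhere).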


async def _rpc_backward(inputs, prompts, grad_outputs, requested_backends):
    # Cast inputs & grad outputs to backend dtype
    inputs = inputs.to(requested_backends[0].dtype)
    prompts = prompts.to(requested_backends[0].dtype)
    grad_outputs = grad_outputs.to(requested_backends[-1].dtype)

    if is_dummy(prompts):
        prompts = [DUMMY] * len(requested_backends)
    else:
        pre_seq_len = prompts.shape[2]
        prompts = [p.squeeze(0) for p in prompts.split(1)]

    # Run a forward chain to collect intermediate inputs
    # Note that we do not forward for the last module since we do not need its output
    inter_inputs = [inputs]
    for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
        assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
        if not is_dummy(prompt):
            inputs[:, :pre_seq_len] += prompt
        (inputs,) = await backend.forward_pool.submit_task(inputs)
        assert isinstance(inputs, torch.Tensor)
        inter_inputs.append(inputs)

    # Run the chain of requested backends backwards, from the last block to the first
    grad_prompts_reversed = []
    for inp, prompt, backend in zip(inter_inputs[::-1], prompts[::-1], requested_backends[::-1]):
        if not is_dummy(prompt):
            inp[:, :pre_seq_len] += prompt
        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs)
        assert isinstance(grad_outputs, torch.Tensor)

        if not is_dummy(prompt):
            grad_prompts_reversed.append(grad_outputs[:, :pre_seq_len].unsqueeze(0))
        else:
            grad_prompts_reversed.append(DUMMY)

    is_dummy_grad_prompts = [is_dummy(grad_param) for grad_param in grad_prompts_reversed]
    # Restore the original (first-to-last) block order before concatenating prompt gradients
    grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if not any(is_dummy_grad_prompts) else DUMMY
    grads = [grad_outputs, grad_prompts]
    return grads
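
# NOTE on the backward pass: the server keeps no activations between forward and backward RPCs, so
# _rpc_backward first re-runs the forward chain through all but the last block to rebuild the
# intermediate inputs, then walks the chain in reverse, feeding each block's input and the incoming
# grad_outputs to its backward_pool. The result is [gradient w.r.t. the original inputs, gradient
# w.r.t. the prompts], with DUMMY standing in for the prompt gradient when no prompts were sent.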