handler.py

# Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me.
import contextlib
from typing import AsyncIterator, Dict, Sequence

import torch
from hivemind import DHT, P2PContext, TensorDescriptor, deserialize_torch_tensor, nested_flatten, serialize_torch_tensor
from hivemind.moe.server.connection_handler import ConnectionHandler
from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
from hivemind.proto import runtime_pb2
from hivemind.utils import as_aiter
from hivemind.utils.asyncio import anext
from hivemind.utils.streaming import split_for_streaming

from src.data_structures import CHAIN_DELIMITER, ModuleUID
from src.server.backend import MAX_LENGTH, TransformerBackend


class TransformerConnectionHandler(ConnectionHandler):
    """Handles three request types: forward, backward and forward-incremental (inference)"""

    module_backends: Dict[ModuleUID, TransformerBackend]

    def __init__(self, dht: DHT, module_backends: Dict[str, TransformerBackend]):
        super().__init__(dht, module_backends)
        for module_backend in self.module_backends.values():
            assert isinstance(module_backend, TransformerBackend)

    async def rpc_inference(
        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
        """Compute a single step of inference using attention cache; update attention cache accordingly."""
        try:
            print("OPENED RPC_INFERENCE")
            request = await anext(requests)
            requested_uids = self._check_header(request)
            requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

            cache_metadata = torch.tensor([[-1, -1]], dtype=torch.int64)  # [cache_handle, prefix_length]
            prefix_length = 0

            async with self._allocate_caches(requested_backends) as cache_handles:
                assert len(cache_handles) == len(requested_backends)
                while request.tensors:  # iterate while user is willing to supply tensors
                    hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]

                    # run request tensors through all requested modules, update caches
                    for backend, cache_handle in zip(requested_backends, cache_handles):
                        cache_metadata[0, 0], cache_metadata[0, 1] = cache_handle, prefix_length
                        assert (
                            len(hidden_states) == 1 and hidden_states[0].ndim == 3
                        ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"

                        hidden_states = await backend.inference_pool.submit_task(cache_metadata, *hidden_states)
                        assert isinstance(hidden_states, (list, tuple))
                        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3

                    # serialize and send last layer outputs
                    yield runtime_pb2.ExpertResponse(
                        tensors=[
                            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
                            for result, proto in zip(
                                hidden_states, nested_flatten(requested_backends[-1].outputs_schema)
                            )
                        ]
                    )

                    # prepare for next step
                    prefix_length += hidden_states[0].shape[1]
                    request = await anext(requests)
        finally:
            print("CLOSED RPC_INFERENCE")
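
    # Note on the inference protocol (descriptive comment, not in the original file): the client keeps
    # this stream open and sends one ExpertRequest per generation step; each step's hidden states are
    # pushed through every requested backend while reusing the attention caches allocated above.
    # `cache_metadata` packs [cache_handle, prefix_length] into a 1x2 int64 tensor for the inference pool,
    # and `prefix_length` advances by the number of newly processed tokens (hidden_states[0].shape[1]).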

    async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
        """Run hidden states through a chain of requested transformer blocks and return the final outputs."""
        # Parse request and prepare backends
        hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
        requested_uids = self._check_header(request)
        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

        # Run a chain of requested backends
        for backend in requested_backends:
            assert isinstance(hidden_states, (list, tuple))
            assert (
                len(hidden_states) == 1 and hidden_states[0].ndim == 3
            ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
            hidden_states = await backend.forward_pool.submit_task(*hidden_states)

        # Serialize the overall output and respond
        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
        return runtime_pb2.ExpertResponse(
            tensors=[
                serialize_torch_tensor(result, proto.compression, allow_inplace=True)
                for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
            ]
        )

    async def rpc_forward_stream(
        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
        """Streaming version of rpc_forward: gather inputs from a request stream, respond with chunked outputs."""
        # Parse requests and prepare backends
        uids_header, hidden_states = await self._gather_inputs(requests, context)
        requested_uids = self._check_header_str(uids_header)
        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

        # Run a chain of requested backends
        for backend in requested_backends:
            assert isinstance(hidden_states, (list, tuple))
            assert (
                len(hidden_states) == 1 and hidden_states[0].ndim == 3
            ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
            hidden_states = await backend.forward_pool.submit_task(*hidden_states)

        # Serialize the overall output
        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
        serialized_output = [
            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
            for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
        ]

        # Split the serialized_output for streaming and respond
        output_split = [
            part for tensor in serialized_output for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
        ]
        async for part in as_aiter(*output_split):
            yield runtime_pb2.ExpertResponse(tensors=[part])
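
    # Streaming note (descriptive comment, not in the original file): each serialized output tensor is cut
    # into chunks of at most DEFAULT_MAX_MSG_SIZE bytes by split_for_streaming, and every chunk is sent as
    # its own ExpertResponse. The client is expected to reassemble the chunks on its side (presumably with
    # hivemind.utils.streaming.combine_from_streaming, the counterpart of split_for_streaming).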

    async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
        """Run a backward pass through a chain of requested transformer blocks; return grads w.r.t. the inputs."""
        # Parse requests and prepare backends
        inputs, grads = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
        requested_uids = self._check_header(request)
        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

        # Run a forward chain to collect intermediate inputs
        # Note that we do not forward for the last module since we do not need its output
        inter_inputs = [inputs]
        for backend in requested_backends[:-1]:
            assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
            inputs = await backend.forward_pool.submit_task(inputs)
            assert isinstance(inputs, (list, tuple)) and len(inputs) == 1
            inputs = inputs[0]
            inter_inputs.append(inputs)

        # Run a backward chain for requested backends in reverse order
        for inp, backend in zip(inter_inputs[::-1], requested_backends[::-1]):
            inputs_and_grads = [inp, grads]
            grads = await backend.backward_pool.submit_task(*inputs_and_grads)
            assert isinstance(grads, (list, tuple)) and len(grads) == 1
            grads = grads[0]

        # Serialize the overall grad_input and respond
        return runtime_pb2.ExpertResponse(
            tensors=[
                serialize_torch_tensor(result, proto.compression, allow_inplace=True)
                for result, proto in zip([grads], nested_flatten(requested_backends[0].grad_inputs_schema))
            ]
        )
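
    # Backward scheme (descriptive comment, not in the original file): intermediate activations are not
    # stored between calls, so the forward chain above is re-run over all but the last backend to recover
    # each block's input; the backward chain then walks the blocks in reverse, feeding (input, grad_output)
    # pairs and propagating grads toward the first block.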

    async def rpc_backward_stream(
        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
        """Streaming version of rpc_backward: gather inputs and grads from a request stream, respond with chunked grads."""
        uids_header, inputs_and_grads = await self._gather_inputs(requests, context)
        inputs, grads = inputs_and_grads
        requested_uids = self._check_header_str(uids_header)
        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

        # Run a forward chain to collect intermediate inputs
        # Note that we do not forward for the last module since we do not need its outputs
        inter_inputs = [inputs]
        for backend in requested_backends[:-1]:
            assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
            inputs = await backend.forward_pool.submit_task(inputs)
            assert isinstance(inputs, (list, tuple)) and len(inputs) == 1
            inputs = inputs[0]
            inter_inputs.append(inputs)

        # Run a backward chain for requested backends
        for inp, backend in zip(inter_inputs[::-1], requested_backends[::-1]):
            inputs_and_grads = [inp, grads]
            grads = await backend.backward_pool.submit_task(*inputs_and_grads)
            assert isinstance(grads, (list, tuple)) and len(grads) == 1
            grads = grads[0]

        # Serialize the overall grad_inputs
        serialized_grad_inputs = [
            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
            for result, proto in zip([grads], nested_flatten(requested_backends[0].grad_inputs_schema))
        ]

        # Split the serialized_grad_inputs for streaming and respond
        output_split = [
            part for tensor in serialized_grad_inputs for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
        ]
        async for part in as_aiter(*output_split):
            yield runtime_pb2.ExpertResponse(tensors=[part])

    def _check_header(self, request: runtime_pb2.ExpertRequest) -> Sequence[ModuleUID]:
        """Check that request.uid specifies a valid chain of modules served by this peer; return their uids"""
        uids = [uid for uid in (request.uid or "").split(CHAIN_DELIMITER) if uid]
        if not uids:
            raise RuntimeError("User did not provide any uids")
        for uid in uids:
            if uid not in self.module_backends:
                raise RuntimeError(f"Remote peer does not serve {uid}")
        return tuple(uids)

    def _check_header_str(self, header: str) -> Sequence[ModuleUID]:
        """Same as _check_header, but for a raw uid-chain string (used by the streaming handlers)"""
        uids = [uid for uid in (header or "").split(CHAIN_DELIMITER) if uid]
        if not uids:
            raise RuntimeError("User did not provide any uids")
        for uid in uids:
            if uid not in self.module_backends:
                raise RuntimeError(f"Remote peer does not serve {uid}")
        return tuple(uids)
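
    # Example header (descriptive comment, not in the original file): a chain of two blocks is requested
    # as a single string, e.g. CHAIN_DELIMITER.join(["block_uid_0", "block_uid_1"]) (placeholder uids);
    # every uid must be present in self.module_backends, otherwise the checks above raise RuntimeError.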

    @contextlib.asynccontextmanager
    async def _allocate_caches(self, backends: Sequence[TransformerBackend]) -> AsyncIterator[Sequence[int]]:
        """Allocate memory caches for each transformer block, return cache handles"""
        async with contextlib.AsyncExitStack() as stack:
            handles = []
            for backend in backends:
                num_heads = backend.module.self_attention.num_heads
                head_dim = backend.module.self_attention.head_dim
                cache_descriptor = TensorDescriptor(size=(2, 1, MAX_LENGTH, num_heads, head_dim), dtype=torch.float32)
                # [key_or_value, batch_size, max_length, num_heads, head_dim]
                handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(cache_descriptor)))
            yield handles
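

# Usage sketch (illustrative comment, not part of the original file). Assuming `dht` is a running
# hivemind DHT and `backends` maps module UIDs to TransformerBackend instances (both placeholder names),
# a server could construct and run this handler roughly as follows; ConnectionHandler is a
# multiprocessing-based hivemind process, so the exact launch API is an assumption here:
#
#   handler = TransformerConnectionHandler(dht, backends)
#   handler.run()  # or handler.start() to serve RPCs from a separate process
#
# Each allocated attention cache holds 2 (key/value) x 1 (batch) x MAX_LENGTH x num_heads x head_dim
# float32 entries, i.e. 8 * MAX_LENGTH * num_heads * head_dim bytes per requested transformer block.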