handler.py

import contextlib
from typing import AsyncIterator, Dict, Iterable, List, Sequence, Tuple, Union

import torch
from hivemind import (
    DHT,
    MSGPackSerializer,
    P2PContext,
    TensorDescriptor,
    deserialize_tensor_stream,
    deserialize_torch_tensor,
    nested_flatten,
    serialize_torch_tensor,
)
from hivemind.moe.server.connection_handler import ConnectionHandler
from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
from hivemind.proto import runtime_pb2
from hivemind.utils.asyncio import amap_in_executor, anext, as_aiter
from hivemind.utils.streaming import split_for_streaming

from src.data_structures import CHAIN_DELIMITER, ModuleUID
from src.server.backend import MAX_LENGTH, PrioritizedTaskPool, TransformerBackend
from src.utils.misc import DUMMY, is_dummy


class TransformerConnectionHandler(ConnectionHandler):
    """Handles three request types: forward, backward and forward-incremental (inference)"""

    module_backends: Dict[ModuleUID, TransformerBackend]

    def __init__(self, dht: DHT, module_backends: Dict[str, TransformerBackend]):
        super().__init__(dht, module_backends)
        for module_backend in self.module_backends.values():
            assert isinstance(module_backend, TransformerBackend)

    async def _gather_inputs(
        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
    ) -> Tuple[str, List[torch.Tensor], Dict]:
        expert_uid, metadata = None, None

        def _unpack(req: runtime_pb2.ExpertRequest) -> Iterable[runtime_pb2.Tensor]:
            nonlocal expert_uid, metadata

            if expert_uid is None:
                expert_uid = req.uid
            elif expert_uid != req.uid:
                raise ValueError("Expert uids differ in one request")

            if metadata is None:
                metadata = MSGPackSerializer.loads(req.metadata) if req.metadata else {}

            return req.tensors

        tensors_stream = amap_in_executor(_unpack, requests)
        inputs = await deserialize_tensor_stream(tensors_stream)
        return expert_uid, inputs, metadata

    async def rpc_inference(
        self,
        requests: AsyncIterator[runtime_pb2.ExpertRequest],
        context: P2PContext,
    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
        """Compute a single step of inference using attention cache; update attention cache accordingly."""
        try:
            print("OPENED RPC_INFERENCE")
            request = await anext(requests)
            requested_uids = self._check_uids(request.uid)
            requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

            batch_size = request.tensors[0].size[0] if request.tensors else 1

            cache_metadata = torch.tensor(
                [[-1, -1] for _ in range(batch_size)], dtype=torch.int64
            )  # [cache_handle, prefix_length]
            prefix_length = 0

            async with self._allocate_caches(requested_backends, batch_size) as cache_handles:
                assert len(cache_handles) == len(requested_backends)
                while request.tensors:  # iterate while user is willing to supply tensors
                    hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
                    metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
                    dust = metadata.get("__dust", 0.0)

                    # Cast inputs to backend dtype
                    hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]

                    # Run request tensors through all requested modules, update caches
                    for backend, cache_handle in zip(requested_backends, cache_handles):
                        cache_metadata[:, 0], cache_metadata[:, 1] = cache_handle, prefix_length
                        assert (
                            len(hidden_states) == 1 and hidden_states[0].ndim == 3
                        ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"

                        if isinstance(backend.inference_pool, PrioritizedTaskPool):
                            hidden_states = await backend.inference_pool.submit_task(
                                cache_metadata, *hidden_states, dust
                            )
                        else:
                            hidden_states = await backend.inference_pool.submit_task(cache_metadata, *hidden_states)
                        assert isinstance(hidden_states, (list, tuple))
                        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3

                    # Serialize and send last layer outputs
                    yield runtime_pb2.ExpertResponse(
                        tensors=[
                            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
                            for result, proto in zip(
                                hidden_states, nested_flatten(requested_backends[-1].outputs_schema)
                            )
                        ]
                    )

                    # Prepare for next step
                    prefix_length += hidden_states[0].shape[1]
                    request = await anext(requests)
        finally:
            print("CLOSED RPC_INFERENCE")

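    # Protocol sketch (inferred from rpc_inference above, not a formal spec): the client opens a
    # stream, sends a first ExpertRequest whose uid is a chain of module uids and whose tensors
    # hold the initial [batch, seq, hidden] hidden states, then sends one request per generation
    # step. Each step yields one ExpertResponse with the last backend's outputs; a request with
    # no tensors ends the loop, and leaving the `async with` block frees the attention caches.
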
    async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
        # Parse request and prepare backends
        flat_inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
        requested_uids = self._check_uids(request.uid)
        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
        metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
        dust = metadata.get("__dust", 0.0)

        hidden_states = await _rpc_forward(*flat_inputs, requested_backends=requested_backends, dust=dust)
        assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3

        # Serialize output and respond to client
        return runtime_pb2.ExpertResponse(
            tensors=[
                serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
                for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
            ]
        )

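    # Note on the "__dust" metadata field (as used throughout this file): requests may carry an
    # optional float under metadata["__dust"], which is forwarded to PrioritizedTaskPool.submit_task
    # as a priority hint; for backends whose pools are not PrioritizedTaskPool, the value is not
    # passed at all.
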
    async def rpc_forward_stream(
        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
        # Parse requests and prepare backends
        uid_str, flat_inputs, metadata = await self._gather_inputs(requests, context)
        requested_uids = self._check_uids(uid_str)
        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

        hidden_states = await _rpc_forward(
            *flat_inputs, requested_backends=requested_backends, dust=metadata.get("__dust", 0.0)
        )
        assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3

        # Serialize the overall output
        serialized_output = [
            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
            for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
        ]

        # Split the serialized_output for streaming and respond to client
        output_split = [
            part for tensor in serialized_output for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
        ]
        async for part in as_aiter(*output_split):
            yield runtime_pb2.ExpertResponse(tensors=[part])

    async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
        # Parse requests and prepare backends
        flat_tensors = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
        requested_uids = self._check_uids(request.uid)
        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
        metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
        dust = metadata.get("__dust", 0.0)

        grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends, dust=dust)

        # Modify grad_inputs_schema to support grad_prompts
        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
        grad_inputs_schema_with_prompts = (
            requested_backends[0].args_schema * len(grads),
            requested_backends[0].kwargs_schema,
        )  # TODO generalize

        # Serialize the overall grad_input and respond
        return runtime_pb2.ExpertResponse(
            tensors=[
                serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
                for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
            ]
        )

    async def rpc_backward_stream(
        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
        uids_header, flat_tensors, metadata = await self._gather_inputs(requests, context)
        requested_uids = self._check_uids(uids_header)
        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

        grads = await _rpc_backward(
            *flat_tensors, requested_backends=requested_backends, dust=metadata.get("__dust", 0.0)
        )

        # Modify grad_inputs_schema to support grad_prompts
        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
        grad_inputs_schema_with_prompts = (
            requested_backends[0].args_schema * len(grads),
            requested_backends[0].kwargs_schema,
        )  # TODO generalize

        # Serialize the overall grad_inputs
        serialized_grad_inputs = [
            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
            for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
        ]
        # Split the serialized_grad_inputs for streaming and respond
        output_split = [
            part for tensor in serialized_grad_inputs for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
        ]
        async for part in as_aiter(*output_split):
            yield runtime_pb2.ExpertResponse(tensors=[part])

    def _check_uids(self, uids: str) -> Sequence[ModuleUID]:
        """Parse a chained uid string and check that this server hosts every requested module"""
        uids = (uids or "").split(CHAIN_DELIMITER)
        if not uids:
            raise RuntimeError("User did not provide any uids")
        for uid in uids:
            if uid not in self.module_backends:
                raise RuntimeError(f"Remote peer does not serve {uid}")
        return tuple(uids)

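    # Example (hypothetical uids and delimiter): if CHAIN_DELIMITER were " " and the client sent
    # uid="bloom.3 bloom.4", _check_uids would return ("bloom.3", "bloom.4") provided both blocks
    # are present in self.module_backends, and raise RuntimeError otherwise.
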
    @contextlib.asynccontextmanager
    async def _allocate_caches(self, backends: Sequence[TransformerBackend], batch_size: int) -> Sequence[int]:
        """Allocate memory caches for each transformer block, return cache handles"""
        async with contextlib.AsyncExitStack() as stack:
            handles = []
            for backend in backends:
                num_heads = backend.module.self_attention.num_heads
                head_dim = backend.module.self_attention.head_dim

                cache_descriptor = TensorDescriptor(
                    size=(2, batch_size, MAX_LENGTH, num_heads, head_dim), dtype=backend.dtype
                )
                # [key_or_value, batch_size, max_length, num_heads, head_dim]

                handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(cache_descriptor)))

            yield handles


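# Cache sizing illustration (hypothetical values, not taken from this repo's config): each block's
# attention cache stores both keys and values, so with batch_size=1, MAX_LENGTH=2048, num_heads=32,
# head_dim=128 and a 2-byte dtype, one cache tensor takes 2 * 1 * 2048 * 32 * 128 * 2 bytes = 32 MiB.
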
async def _rpc_forward(
    *flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend], dust: float = 0.0
) -> torch.Tensor:
    """
    Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream

    :param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors
    :note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy)
    :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
    :returns: hidden states after the last layer [batch_size, seq_length, hid_size]
    """
    hidden_states, *prompts = flat_tensors
    dtype = requested_backends[0].dtype

    # Check input tensors and cast dtypes
    hidden_states = hidden_states.to(dtype)
    assert hidden_states.ndim == 3

    if not prompts or is_dummy(prompts[0]):
        prompts = [DUMMY] * len(requested_backends)
        pre_seq_len = 0
    else:
        prompts = [prompts[0].to(requested_backends[0].dtype)]
        prompts = [p.squeeze(0) for p in prompts[0].split(1)]
        pre_seq_len = prompts[0].shape[-2]

    # Run a chain of requested backends
    for backend, prompt in zip(requested_backends, prompts):
        if not is_dummy(prompt):
            hidden_states[:, :pre_seq_len] += prompt
        if isinstance(backend.forward_pool, PrioritizedTaskPool):
            (hidden_states,) = await backend.forward_pool.submit_task(hidden_states, dust)
        else:
            (hidden_states,) = await backend.forward_pool.submit_task(hidden_states)
        assert isinstance(hidden_states, torch.Tensor)
        assert (
            hidden_states.ndim == 3
        ), f"outputs of {type(backend)} must be a single 3d tensor of hidden states"

    # Return hidden states after the last requested layer
    return hidden_states


async def _rpc_backward(
    *flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend], dust: float = 0.0
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
    """Run backward pass on deserialized inputs, gradients and prompts, used by rpc_backward and rpc_backward_stream"""
    inputs, grad_outputs, *prompts = flat_tensors

    # Cast inputs & grad outputs to backend dtype
    inputs = inputs.to(requested_backends[0].dtype)
    grad_outputs = grad_outputs.to(requested_backends[-1].dtype)

    if not prompts or is_dummy(prompts[0]):
        prompts = [DUMMY] * len(requested_backends)
        pre_seq_len = 0
    else:
        prompts = [prompts[0].to(requested_backends[0].dtype)]
        prompts = [p.squeeze(0) for p in prompts[0].split(1)]
        pre_seq_len = prompts[0].shape[-2]

    # Run a forward chain to collect intermediate inputs
    # Note that we do not forward for the last module since we do not need its output
    inter_inputs = []
    for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
        assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
        if not is_dummy(prompt):
            inputs[:, :pre_seq_len] += prompt
        inter_inputs.append(inputs)
        if isinstance(backend.forward_pool, PrioritizedTaskPool):
            (inputs,) = await backend.forward_pool.submit_task(inputs, dust / 2.0)
        else:
            (inputs,) = await backend.forward_pool.submit_task(inputs)
        assert isinstance(inputs, torch.Tensor)

    if not is_dummy(prompts[-1]):
        inputs[:, :pre_seq_len] += prompts[-1]
    inter_inputs.append(inputs)

    assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
    grad_prompts_reversed = []

    # Run the chain of requested backends in reverse order (backward pass)
    for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
        if isinstance(backend.backward_pool, PrioritizedTaskPool):
            (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, dust / 2.0)
        else:
            (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs)
        assert isinstance(grad_outputs, torch.Tensor)
        if not is_dummy(prompt):
            grad_prompts_reversed.append(grad_outputs[:, :pre_seq_len].unsqueeze(0))

    grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
    return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]  # TODO un-duct-tape
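

# Usage sketch for the return convention above (an illustration, not part of the original API):
#
#     grads = await _rpc_backward(inputs, grad_outputs, prompts, requested_backends=backends, dust=0.0)
#     grad_inputs, grad_prompts = grads if len(grads) == 2 else (grads[0], None)
#
# i.e. callers receive gradients w.r.t. the chain inputs, plus gradients w.r.t. the prompts
# whenever real (non-dummy) prompts were supplied.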