handler.py

import contextlib
from typing import AsyncIterator, Dict, Iterable, List, Sequence, Tuple, Union

import torch
from hivemind import (
    DHT,
    MSGPackSerializer,
    P2PContext,
    TensorDescriptor,
    deserialize_tensor_stream,
    deserialize_torch_tensor,
    nested_flatten,
    serialize_torch_tensor,
)
from hivemind.moe.server.connection_handler import ConnectionHandler
from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
from hivemind.proto import runtime_pb2
from hivemind.utils.asyncio import amap_in_executor, anext, as_aiter
from hivemind.utils.streaming import split_for_streaming

from src.data_structures import CHAIN_DELIMITER, ModuleUID
from src.server.backend import TransformerBackend
from src.server.task_pool import PrioritizedTaskPool
from src.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase
from src.utils.misc import DUMMY, is_dummy
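

# Each TransformerBackend exposes prioritized task pools (forward_pool, backward_pool, inference_pool).
# The handler below only parses requests, assigns a priority via the task prioritizer, and submits work
# to those pools; the transformer computation itself runs in the server runtime that consumes the pools.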
class TransformerConnectionHandler(ConnectionHandler):
    """Handles three request types: forward, backward and forward-incremental (inference)"""

    module_backends: Dict[ModuleUID, TransformerBackend]

    def __init__(
        self,
        dht: DHT,
        module_backends: Dict[str, TransformerBackend],
        inference_max_length: int,
        task_prioritizer: TaskPrioritizerBase = DummyTaskPrioritizer(),
    ):
        super().__init__(dht, module_backends)
        for module_backend in self.module_backends.values():
            assert isinstance(module_backend, TransformerBackend)
        self.inference_max_length = inference_max_length
        self._prioritizer = task_prioritizer
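
    # Collects a multi-part streamed request (used by the *_stream RPCs) into a single
    # (uid, tensors, metadata) triple, checking that every part refers to the same chain of module uids.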
    async def _gather_inputs(
        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
    ) -> Tuple[str, List[torch.Tensor], Dict]:
        expert_uid, metadata = None, None

        def _unpack(req: runtime_pb2.ExpertRequest) -> Iterable[runtime_pb2.Tensor]:
            nonlocal expert_uid, metadata

            if expert_uid is None:
                expert_uid = req.uid
            elif expert_uid != req.uid:
                raise ValueError("Expert uids differ in one request")

            if metadata is None:
                metadata = MSGPackSerializer.loads(req.metadata) if req.metadata else {}

            return req.tensors

        tensors_stream = amap_in_executor(_unpack, requests)
        inputs = await deserialize_tensor_stream(tensors_stream)
        return expert_uid, inputs, metadata
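
    # Inference protocol: the first request carries the chain of module uids and msgpack metadata
    # (max_length, points). Attention caches are allocated once for the whole session; every subsequent
    # request runs one step through all requested blocks, and the session ends when the client sends a
    # request with no tensors.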
    async def rpc_inference(
        self,
        requests: AsyncIterator[runtime_pb2.ExpertRequest],
        context: P2PContext,
    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
        """Compute a single step of inference using attention cache; update attention cache accordingly."""
        try:
            print("OPENED RPC_INFERENCE")
            request = await anext(requests)
            requested_uids = self._check_uids(request.uid)
            metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
            requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
            max_length = metadata.get("max_length")
            points = metadata.get("points", 0)

            if not requested_uids:
                raise ValueError("User must specify at least one block for inference, but got none")
            assert isinstance(max_length, int), f"rpc_inference metadata must contain int max_length, got {max_length}"
            assert isinstance(
                points, (float, int)
            ), f"rpc_inference should have number of points as a number or None, got {points}"
            if not 0 <= max_length <= self.inference_max_length:
                raise ValueError(f"Cannot allocate KV cache for {max_length} tokens, max = {self.inference_max_length}")

            point_per_piece = points / max_length if max_length > 0 else 0.0
            batch_size = request.tensors[0].size[0] if request.tensors else 1

            cache_metadata = torch.tensor(
                [[-1, -1] for _ in range(batch_size)], dtype=torch.int64
            )  # [cache_handle, prefix_length]
            prefix_length = 0

            async with self._allocate_caches(requested_backends, batch_size, max_length) as cache_handles:
                assert len(cache_handles) == len(requested_backends)
                while request.tensors:  # iterate while user is willing to supply tensors
                    hidden_states, prompts, hypo_ids = [deserialize_torch_tensor(tensor) for tensor in request.tensors]

                    # Cast inputs to backend dtype
                    hidden_states = hidden_states.to(requested_backends[0].dtype)
                    assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"

                    # parse deep prompts (optional argument)
                    if prompts is None or is_dummy(prompts):
                        prompts = [DUMMY] * len(requested_backends)
                    else:
                        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]

                    if len(requested_backends) != len(prompts):
                        raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")

                    length_increment = hidden_states.shape[1]  # how many tokens are added this step (in each seq)
                    if prefix_length + length_increment > max_length:
                        raise ValueError(
                            f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
                            f" exceeds pre-allocated maximum {max_length}"
                        )

                    # run request tensors through all requested modules, update caches
                    for backend, prompt, cache_handle in zip(requested_backends, prompts, cache_handles):
                        if not is_dummy(prompt):
                            hidden_states[:, : prompt.shape[1]] += prompt

                        cache_metadata[:, 0], cache_metadata[:, 1] = cache_handle, prefix_length
                        assert isinstance(
                            hidden_states, torch.Tensor
                        ), f"hidden states must be tensor, got {type(hidden_states)}"
                        assert (
                            hidden_states.ndim == 3
                        ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
                        assert isinstance(
                            backend.inference_pool, PrioritizedTaskPool
                        ), "petals supports only prioritized pools"
                        priority = self._prioritizer.prioritize(
                            cache_metadata,
                            hidden_states,
                            hypo_ids,
                            points=point_per_piece / len(requested_backends),
                            backend=backend,
                            type="inference",
                        )
                        (hidden_states,) = await backend.inference_pool.submit_task(
                            cache_metadata, hidden_states, hypo_ids, priority=priority
                        )

                    # serialize and send last layer outputs
                    yield runtime_pb2.ExpertResponse(
                        tensors=[
                            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
                            for result, proto in zip(
                                (hidden_states,), nested_flatten(requested_backends[-1].outputs_schema)
                            )
                        ]
                    )

                    # prepare for next step
                    prefix_length += hidden_states.shape[1]
                    request = await anext(requests)
        finally:
            print("CLOSED RPC_INFERENCE")
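
    # rpc_forward / rpc_backward process a whole chain of blocks in a single unary message, while the
    # *_stream variants accept and return the same payload split into chunks of at most
    # DEFAULT_MAX_MSG_SIZE. All four delegate the computation to the module-level _rpc_forward and
    # _rpc_backward helpers.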
    async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
        # Parse request and prepare backends
        flat_inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
        requested_uids = self._check_uids(request.uid)
        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
        metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
        points = metadata.get("points", 0)
        assert isinstance(
            points, (float, int)
        ), f"rpc_forward should have number of points as number or None, got {points}"

        hidden_states = await _rpc_forward(
            *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
        )
        assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3

        # Serialize output and respond to client
        return runtime_pb2.ExpertResponse(
            tensors=[
                serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
                for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
            ]
        )

    async def rpc_forward_stream(
        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
        # Parse requests and prepare backends
        uid_str, flat_inputs, metadata = await self._gather_inputs(requests, context)
        requested_uids = self._check_uids(uid_str)
        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
        points = metadata.get("points", 0)
        assert isinstance(
            points, (float, int)
        ), f"rpc_forward_stream should have number of points as number or None, got {points}"

        hidden_states = await _rpc_forward(
            *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
        )
        assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3, "hidden_states must be a 3d tensor"

        # Serialize the overall output
        serialized_output = [
            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
            for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
        ]

        # Split the serialized_output for streaming and respond to client
        output_split = [
            part for tensor in serialized_output for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
        ]
        async for part in as_aiter(*output_split):
            yield runtime_pb2.ExpertResponse(tensors=[part])
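
    # Backward returns the gradient w.r.t. the input hidden states and, if deep prompts were supplied,
    # the gradient w.r.t. those prompts as a second tensor; the args schema is duplicated below so that
    # both gradients are serialized with the schema's dtype and compression (see the TODOs).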
    async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
        # Parse requests and prepare backends
        flat_tensors = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
        requested_uids = self._check_uids(request.uid)
        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
        metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
        points = metadata.get("points", 0)
        assert isinstance(
            points, (float, int)
        ), f"rpc_backward should have number of points as number or None, got {points}"

        grads = await _rpc_backward(
            *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
        )

        # Modify grad_inputs_schema to support grad_prompts
        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
        grad_inputs_schema_with_prompts = (
            requested_backends[0].args_schema * len(grads),
            requested_backends[0].kwargs_schema,
        )  # TODO generalize

        # Serialize the overall grad_input and respond
        return runtime_pb2.ExpertResponse(
            tensors=[
                serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
                for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
            ]
        )

    async def rpc_backward_stream(
        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
        uids_header, flat_tensors, metadata = await self._gather_inputs(requests, context)
        requested_uids = self._check_uids(uids_header)
        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
        points = metadata.get("points", 0)
        assert isinstance(
            points, (float, int)
        ), f"rpc_backward_stream should have number of points as number or None, got {points}"

        grads = await _rpc_backward(
            *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
        )

        # Modify grad_inputs_schema to support grad_prompts
        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
        grad_inputs_schema_with_prompts = (
            requested_backends[0].args_schema * len(grads),
            requested_backends[0].kwargs_schema,
        )  # TODO generalize

        # Serialize the overall grad_inputs
        serialized_grad_inputs = [
            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
            for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
        ]
        # Split the serialized_grad_inputs for streaming and respond
        output_split = [
            part for tensor in serialized_grad_inputs for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
        ]
        async for part in as_aiter(*output_split):
            yield runtime_pb2.ExpertResponse(tensors=[part])

    def _check_uids(self, uids: str) -> Sequence[ModuleUID]:
        """Check that the requested chain of module uids is non-empty and that every block is served by this peer"""
        uids = (uids or "").split(CHAIN_DELIMITER)
        if not uids:
            raise RuntimeError("User did not provide any uids")
        for uid in uids:
            if uid not in self.module_backends:
                raise RuntimeError(f"Remote peer does not serve {uid}")
        return tuple(uids)
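
    # Every cache handle points to a buffer of shape [2, batch_size, max_length, num_heads, head_dim]
    # (keys and values) in the backend's memory cache; the AsyncExitStack ensures that all allocated
    # caches are released when the inference session ends, even if an intermediate step fails.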
    @contextlib.asynccontextmanager
    async def _allocate_caches(
        self, backends: Sequence[TransformerBackend], batch_size: int, max_length: int
    ) -> AsyncIterator[Sequence[int]]:
        """Allocate memory caches for each transformer block, return cache handles"""
        async with contextlib.AsyncExitStack() as stack:
            handles = []
            for backend in backends:
                num_heads = backend.module.self_attention.num_heads
                head_dim = backend.module.self_attention.head_dim

                cache_descriptor = TensorDescriptor(
                    size=(2, batch_size, max_length, num_heads, head_dim), dtype=backend.dtype
                )
                # [key_or_value, batch_size, max_length, num_heads, head_dim]
                handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(cache_descriptor)))

            yield handles
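

# Module-level helpers shared by the unary and streaming RPCs above. For forward, flat_tensors is
# (hidden_states, prompts); for backward, it is (inputs, grad_outputs, prompts). A dummy prompt tensor
# (see is_dummy) means "no deep prompt for this block"; a real prompt is added to the first
# prompt.shape[1] positions of the corresponding block's input.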
async def _rpc_forward(
    *flat_tensors: torch.Tensor,
    requested_backends: Sequence[TransformerBackend],
    prioritizer: TaskPrioritizerBase,
    points: int = 0,
) -> torch.Tensor:
    """
    Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream

    :param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors
    :note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy)
    :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
    :returns: hidden states after the last layer [batch_size, seq_length, hid_size]
    """
    hidden_states, prompts = flat_tensors
    dtype = requested_backends[0].dtype

    # Parse input tensors and cast them to the backend dtype
    hidden_states = hidden_states.to(dtype)
    assert hidden_states.ndim == 3
    if prompts is None or is_dummy(prompts):
        prompts = [DUMMY] * len(requested_backends)
    else:
        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]

    # Run a chain of requested backends
    for backend, prompt in zip(requested_backends, prompts):
        if not is_dummy(prompt):
            hidden_states[:, : prompt.shape[1]] += prompt

        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals supports only prioritized pools"
        priority = prioritizer.prioritize(
            hidden_states, points=points / len(requested_backends), backend=backend, type="forward"
        )
        (hidden_states,) = await backend.forward_pool.submit_task(
            hidden_states,
            priority=priority,
        )
        assert isinstance(hidden_states, torch.Tensor)
        assert (
            hidden_states.ndim == 3
        ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"

    return hidden_states
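

# Backward first re-runs the forward chain over all blocks except the last to recompute intermediate
# activations (they are not kept between the forward and backward RPCs), then walks the chain in
# reverse through each backend's backward_pool, collecting gradients w.r.t. the prompts along the way.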
async def _rpc_backward(
    *flat_tensors: torch.Tensor,
    requested_backends: Sequence[TransformerBackend],
    prioritizer: TaskPrioritizerBase,
    points: int = 0,
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
    inputs, grad_outputs, prompts = flat_tensors

    # Cast inputs & grad outputs to backend dtype
    inputs = inputs.to(requested_backends[0].dtype)
    grad_outputs = grad_outputs.to(requested_backends[-1].dtype)

    if prompts is None or is_dummy(prompts):
        prompts = [DUMMY] * len(requested_backends)
    else:
        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]

    # Run a forward chain to collect intermediate inputs
    # Note that we do not forward for the last module since we do not need its output
    inter_inputs = []
    for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
        assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
        if not is_dummy(prompt):
            inputs[:, : prompt.shape[1]] += prompt
        inter_inputs.append(inputs)

        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals supports only prioritized pools"
        priority = prioritizer.prioritize(
            inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
        )
        (inputs,) = await backend.forward_pool.submit_task(inputs, priority=priority)

        assert isinstance(inputs, torch.Tensor)

    if not is_dummy(prompts[-1]):
        inputs[:, : prompts[-1].shape[1]] += prompts[-1]
    inter_inputs.append(inputs)

    assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
    grad_prompts_reversed = []

    # Run a chain of requested backends in reverse order
    for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals supports only prioritized pools"
        priority = prioritizer.prioritize(
            inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
        )
        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, priority=priority)

        assert isinstance(grad_outputs, torch.Tensor)
        if not is_dummy(prompt):
            grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))

    grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
    return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]  # TODO un-duct-tape
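

# Usage sketch (illustrative only; the real wiring lives in the server code outside this file):
#     handler = TransformerConnectionHandler(dht, module_backends, inference_max_length=2048)
# where `dht` is a running hivemind DHT, `module_backends` maps module uids to TransformerBackend
# instances, and inference_max_length bounds how many tokens a single inference session may cache.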