sequential_autograd.py

import asyncio
import logging
from typing import List, Optional, Sequence, Tuple

import torch
from hivemind import serialize_torch_tensor
from hivemind.moe.client.expert import expert_backward, expert_forward
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.p2p import StubBase
from hivemind.utils.nested import nested_compare, nested_flatten, nested_pack

from src.client.sequence_manager import RemoteSequenceManager
from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
from src.server.handler import TransformerConnectionHandler
from src.utils.misc import DUMMY, is_dummy

MAX_TOKENS_IN_BATCH = 1024
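# Inputs are re-batched so that each sub-batch carries roughly MAX_TOKENS_IN_BATCH tokens;
# _RemoteSequentialAutogradFunction below derives the per-split batch size from the sequence length.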


async def run_expert_forward(
    uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, **kwargs
) -> Tuple[torch.Tensor, ...]:
    """
    Serializes input tensors and calls "expert_forward".
    Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198
    but without the RemoteExpertWorker.run_coroutine() call that leads to a deadlock here.
    """

    # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
    # detach to avoid pickling the computation graph
    assert len(kwargs) == len(rpc_info["keyword_names"]), f"Keyword args should be {rpc_info['keyword_names']}"
    kwargs = {key: kwargs[key] for key in rpc_info["keyword_names"]}

    # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
    forward_inputs = (inputs, kwargs)

    # Modify forward_schema to support prompts
    args_schema, kwargs_schema = rpc_info["forward_schema"]
    # TODO: remove this assert once an arbitrary number of input tensors is supported
    assert len(args_schema) == 1 and len(inputs) == 2
    forward_schema_with_prompts = (tuple(args_schema * len(inputs)), kwargs_schema)

    if not nested_compare(forward_inputs, forward_schema_with_prompts):
        raise TypeError("Inputs do not match expert input schema. Did you pass the right number of parameters?")

    forward_inputs = nested_flatten(forward_inputs)
    inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs)

    # Asynchronous serialization
    loop = asyncio.get_running_loop()
    serialized_tensors = await asyncio.gather(
        *(
            loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
            for tensor, proto in zip(inputs, nested_flatten(forward_schema_with_prompts))
        )
    )

    deserialized_outputs = await expert_forward(uid, inputs, serialized_tensors, stub)
    flat_outputs = tuple(deserialized_outputs)

    return nested_pack(flat_outputs, structure=rpc_info["outputs_schema"])
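

# Illustrative call pattern (a sketch with hypothetical variable names): for a span of blocks identified
# by `span_uids` and reachable via `stub`, the helper is awaited with the hidden states and the prompt
# slice for that span, e.g.
#
#     (outputs,) = await run_expert_forward(span_uids, stub, rpc_info, hidden_states, prompts_slice)
#
# The two positional tensors must match the duplicated args schema asserted above.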


async def run_expert_backward(
    uid: ModuleUID,
    stub: StubBase,
    rpc_info: RPCInfo,
    inputs: torch.Tensor,
    grad_outputs: List[torch.Tensor],
    *extra_tensors: torch.Tensor,
) -> Sequence[torch.Tensor]:
    """
    Serializes grad outputs and calls "expert_backward".
    Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221
    but without the RemoteExpertWorker.run_coroutine() call that leads to a deadlock here.
    """

    grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
    inputs_and_grad_outputs = tuple(nested_flatten((inputs, grad_outputs_cpu, *extra_tensors)))

    # Modify forward_schema to support prompts
    args_schema, kwargs_schema = rpc_info["forward_schema"]
    assert len(args_schema) == 1 and isinstance(inputs, torch.Tensor)
    # TODO: generalize this
    prompts_schema = next(iter(args_schema))
    backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"], prompts_schema)))

    # Asynchronous serialization
    loop = asyncio.get_running_loop()
    serialized_tensors = await asyncio.gather(
        *(
            loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
            for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
        )
    )

    deserialized_grad_inputs = await expert_backward(uid, inputs_and_grad_outputs, serialized_tensors, stub)
    return deserialized_grad_inputs
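

# Note: the returned sequence is expected to contain the gradient w.r.t. the span's input hidden states
# first, followed by the gradients w.r.t. the prompts for that span (see how sequential_backward unpacks
# it below).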


async def sequential_forward(
    inputs: torch.Tensor,
    prompts: torch.Tensor,
    sequence_manager: RemoteSequenceManager,
    start_index: int = 0,
    end_index: Optional[int] = None,
) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]:
    """
    Constructs a routing path from <start_index> to <end_index>.
    Performs a chained forward pass for each subsequence of blocks on the path.
    If some subsequence fails, reconstructs the remaining path and tries to finish the forward pass.
    """

    assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3

    end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
    assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)
    assert is_dummy(prompts) or len(prompts) == len(
        sequence_manager.block_uids
    )  # should be n_layers - 1 but we add extra prompts for convenience

    sequences = sequence_manager.make_sequence(start_index, end_index)
    intermediate_inputs = []
    done_sequences = []

    while len(sequences) > 0:
        while True:
            try:
                span = sequences.pop(0)
                span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
                inputs_and_prompts = [inputs, prompts[span.start : span.end]]

                (outputs,) = await run_expert_forward(span_uids, stub, sequence_manager.rpc_info, *inputs_and_prompts)

                assert isinstance(outputs, torch.Tensor)
                assert outputs.shape == inputs.shape, f"Expected output {inputs.shape}, got {outputs.shape}"

                # Save intermediate inputs and subsequences if the forward is already done for them
                intermediate_inputs.append(inputs)
                done_sequences.append(span)

                inputs = outputs
                break
            except Exception as e:
                logging.warning(f"Caught {e} when running forward for chain {span.start}-{span.end}", exc_info=True)
                backup_sequences = sequence_manager.make_sequence(span.start)
                assert backup_sequences[0].start == span.start
                sequences = backup_sequences

    return outputs, intermediate_inputs, done_sequences
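

# A minimal usage sketch (hypothetical variable names): run a forward pass over all remaining blocks
# and keep the per-span checkpoints needed for a later backward pass:
#
#     outputs, intermediate_inputs, done_sequences = await sequential_forward(
#         hidden_states, prompts, sequence_manager
#     )
#
# Here intermediate_inputs[i] is the tensor that was fed into the span done_sequences[i], which is
# exactly what sequential_backward consumes in reverse order.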


async def sequential_backward(
    grad_outputs: Sequence[torch.Tensor],
    intermediate_inputs: List[torch.Tensor],
    prompts: torch.Tensor,
    forward_sequences: List[RemoteSpanInfo],
    sequence_manager: RemoteSequenceManager,
) -> Tuple[Sequence[torch.Tensor], Optional[torch.Tensor]]:
    """
    Performs a chained backward pass for each forward subsequence.
    If some subsequence fails, reconstructs the particular sub-path and recovers the backward pass.
    """

    assert len(intermediate_inputs) == len(forward_sequences)

    grad_prompts_reversed = []
    while len(forward_sequences) > 0 and len(intermediate_inputs) > 0:
        while True:
            inputs = intermediate_inputs.pop(-1)
            span = forward_sequences.pop(-1)
            span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
            try:
                stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
                grad_outputs, *span_grad_prompts = await run_expert_backward(
                    span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs, prompts[span.start : span.end]
                )
                grad_outputs = [grad_outputs]
                grad_prompts_reversed.extend(span_grad_prompts)
                break
            except Exception as e:
                logging.warning(f"Caught {e} when running backward for chain {span.start}-{span.end}", exc_info=True)
                _, backup_intermediate_inputs, backup_forward_sequences = await sequential_forward(
                    inputs, prompts, sequence_manager, start_index=span.start, end_index=span.end
                )

                assert len(intermediate_inputs) == len(forward_sequences)
                assert backup_forward_sequences[0].start == span.start
                assert backup_forward_sequences[-1].end == span.end

                forward_sequences.extend(backup_forward_sequences)
                intermediate_inputs.extend(backup_intermediate_inputs)

    # For now, we do not support mixed dummy and grad prompts
    # Concat in the num_layer dimension
    grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else None
    return grad_outputs, grad_prompts
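

# sequential_backward returns the gradient w.r.t. the inputs of the earliest span (i.e. the original
# inputs of the forward pass) together with the prompt gradients concatenated along the layer dimension,
# or None if the servers returned no prompt gradients.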


async def _gather_forward(input_batches, prompt_batches, sequence_manager):
    """Wrapper for asyncio.gather to perform parallel sequential forwards"""
    return await asyncio.gather(
        *[
            sequential_forward(input_batch, prompt_batch, sequence_manager)
            for input_batch, prompt_batch in zip(input_batches, prompt_batches)
        ]
    )


async def _gather_backward(
    grad_output_batches, intermediate_input_batches, prompt_batches, forward_sequences, sequence_manager
):
    """Wrapper for asyncio.gather to perform parallel sequential backwards"""
    return await asyncio.gather(
        *[
            sequential_backward((grad_output,), input_batch, prompt_batch, spans, sequence_manager)
            for grad_output, input_batch, prompt_batch, spans in zip(
                grad_output_batches, intermediate_input_batches, prompt_batches, forward_sequences
            )
        ]
    )


class _RemoteSequentialAutogradFunction(torch.autograd.Function):
    """
    A PyTorch autograd function that provides forward and backward calls for the entire sequence of remote transformer blocks.
    This function splits input data into batches of at most <MAX_TOKENS_IN_BATCH> tokens and performs efficient parallel processing.
    """

    @staticmethod
    def forward(ctx, inputs: torch.Tensor, prompts: torch.Tensor, sequence_manager: RemoteSequenceManager):
        batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)
        input_batches: Sequence[torch.Tensor] = inputs.detach().split(batch_size)
        if is_dummy(prompts):
            prompt_batches = [DUMMY] * len(input_batches)
        else:
            prompt_batches: Sequence[torch.Tensor] = prompts.detach().split(batch_size, dim=1)

        sequence_manager.rpc_info  # lazy init
        outputs = RemoteExpertWorker.run_coroutine(_gather_forward(input_batches, prompt_batches, sequence_manager))
        assert len(outputs) == len(input_batches)

        output_batches = [output[0] for output in outputs]
        intermediate_input_batches = [output[1] for output in outputs]
        sequences_for_batches = [output[2] for output in outputs]

        ctx.prompt_batches = prompt_batches
        ctx.sequence_manager = sequence_manager
        ctx.intermediate_input_batches = intermediate_input_batches
        ctx.sequences_for_batches = sequences_for_batches
        return torch.cat(output_batches, dim=0)

    @staticmethod
    def backward(ctx, grad_outputs: torch.Tensor):
        intermediate_input_batches: List[Sequence[torch.Tensor]] = ctx.intermediate_input_batches
        forward_sequences: List[Sequence[RemoteSpanInfo]] = ctx.sequences_for_batches
        ctx.sequence_manager.rpc_info  # lazy init

        batch_size = max(MAX_TOKENS_IN_BATCH // grad_outputs.shape[1], 1)
        grad_output_batches: Sequence[torch.Tensor] = grad_outputs.split(batch_size)
        assert len(intermediate_input_batches) == len(grad_output_batches) == len(forward_sequences)

        outputs = RemoteExpertWorker.run_coroutine(
            _gather_backward(
                grad_output_batches,
                intermediate_input_batches,
                ctx.prompt_batches,
                forward_sequences,
                ctx.sequence_manager,
            )
        )
        grad_input_batches = [output[0][0] for output in outputs]
        grad_prompt_batches = [output[1] for output in outputs]

        grad_inputs = torch.cat(grad_input_batches, dim=0)
        dummy_grad_prompts = [grad_prompt is None for grad_prompt in grad_prompt_batches]
        grad_prompts = torch.cat(grad_prompt_batches, dim=1) if not any(dummy_grad_prompts) else None
        return (grad_inputs, grad_prompts, None)
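

# A minimal usage sketch (assuming an already initialized RemoteSequenceManager `sequence_manager`;
# shapes are illustrative): the autograd function is applied to the full hidden-state tensor, and
# gradients flow back through the remote blocks via the backward() above.
#
#     hidden_states = torch.randn(2, 16, 4096, requires_grad=True)  # (batch, seq_len, hidden_size)
#     outputs = _RemoteSequentialAutogradFunction.apply(hidden_states, DUMMY, sequence_manager)
#     outputs.sum().backward()  # populates hidden_states.grad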