sequential_autograd.py

import asyncio
import logging
from typing import List, Optional, Sequence, Tuple

import torch
from hivemind import serialize_torch_tensor
from hivemind.moe.client.expert import expert_backward, expert_forward
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.p2p import StubBase
from hivemind.utils.nested import nested_compare, nested_flatten, nested_pack

from src.client.sequence_manager import RemoteSequenceManager
from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
from src.server.handler import TransformerConnectionHandler

MAX_TOKENS_IN_BATCH = 1024
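
# Note: batches are sized so that each one carries at most MAX_TOKENS_IN_BATCH tokens.
# For example, with inputs of shape [batch, seq_len=128, hidden], forward() below splits the
# batch dimension into chunks of max(1024 // 128, 1) = 8 sequences each.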


async def run_expert_forward(
    uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, **kwargs
) -> Tuple[torch.Tensor, ...]:
    """
    Serializes input tensors and calls "expert_forward".
    Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198
    but without the RemoteExpertWorker.run_coroutine() call, which would deadlock here.
    """
    # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
    # detach to avoid pickling the computation graph
    assert len(kwargs) == len(rpc_info["keyword_names"]), f"Keyword args should be {rpc_info['keyword_names']}"
    kwargs = {key: kwargs[key] for key in rpc_info["keyword_names"]}

    # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
    forward_inputs = (inputs, kwargs)

    if not nested_compare(forward_inputs, rpc_info["forward_schema"]):
        raise TypeError("Inputs do not match expert input schema. Did you pass the right number of parameters?")

    forward_inputs = nested_flatten(forward_inputs)
    inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs)

    # TODO: figure out whether we should use run_in_executor here
    serialized_tensors = (
        serialize_torch_tensor(tensor, proto.compression)
        for tensor, proto in zip(inputs, nested_flatten(rpc_info["forward_schema"]))
    )

    deserialized_outputs = await expert_forward(uid, inputs, serialized_tensors, stub)
    flat_outputs = tuple(deserialized_outputs)
    return nested_pack(flat_outputs, structure=rpc_info["outputs_schema"])


async def run_expert_backward(
    uid: ModuleUID,
    stub: StubBase,
    rpc_info: RPCInfo,
    intermediate_inputs: List[torch.Tensor],
    grad_outputs: List[torch.Tensor],
) -> Sequence[torch.Tensor]:
    """
    Serializes grad outputs and calls "expert_backward".
    Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221
    but without the RemoteExpertWorker.run_coroutine() call, which would deadlock here.
    """
    grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
    inputs_and_grad_outputs = tuple(nested_flatten((intermediate_inputs, grad_outputs_cpu)))
    backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"])))
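    # Note: the backward payload is the flattened forward inputs followed by the incoming gradients,
    # and the serialization schema mirrors that order: (forward_schema, outputs_schema).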
    serialized_tensors = (
        serialize_torch_tensor(tensor, proto.compression)
        for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
    )

    deserialized_grad_inputs = await expert_backward(uid, inputs_and_grad_outputs, serialized_tensors, stub)
    return deserialized_grad_inputs


async def sequential_forward(
    inputs: torch.Tensor, sequence_manager: RemoteSequenceManager, start_index: int = 0, end_index: Optional[int] = None
) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]:
    """
    Constructs a routing path from <start_index> to <end_index>.
    Performs a chained forward pass over each subsequence of blocks on the path.
    If some subsequence fails, reconstructs the remaining path and tries to finish the forward pass.
    """
    assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3

    end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
    assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)

    sequences = sequence_manager.make_sequence(start_index, end_index)
    intermediate_inputs = []
    done_sequences = []

    while len(sequences) > 0:
        while True:
            try:
                span = sequences.pop(0)
                span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
                (outputs,) = await run_expert_forward(span_uids, stub, sequence_manager.rpc_info, inputs)

                assert isinstance(outputs, torch.Tensor)
                assert outputs.shape == inputs.shape, f"Expected output {inputs.shape}, got {outputs.shape}"

                # Save intermediate inputs and subsequences if the forward is already done for them
                intermediate_inputs.append(inputs)
                done_sequences.append(span)

                inputs = outputs
                break
            except Exception as e:
                logging.debug(f"Caught {e} when running forward for chain {span.start}-{span.end}", exc_info=True)
                backup_sequences = sequence_manager.make_sequence(span.start)
                assert backup_sequences[0].start == span.start
                sequences = backup_sequences

    return outputs, intermediate_inputs, done_sequences
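
# Usage sketch (hypothetical names, for illustration only): assuming an already initialized
# RemoteSequenceManager `manager` and hidden states `hidden` of shape [batch, seq_len, hidden_size],
# a full forward pass outside of autograd could be run as:
#
#     outputs, activations, spans = RemoteExpertWorker.run_coroutine(
#         sequential_forward(hidden, manager)
#     )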


async def sequential_backward(
    grad_outputs: Sequence[torch.Tensor],
    intermediate_inputs: Sequence[torch.Tensor],
    forward_sequences: Sequence[RemoteSpanInfo],
    sequence_manager: RemoteSequenceManager,
) -> Sequence[torch.Tensor]:
    """
    Performs a chained backward pass over each forward subsequence.
    If some subsequence fails, reconstructs the corresponding sub-path and resumes the backward pass.
    """
    assert len(intermediate_inputs) == len(forward_sequences)
    # TODO think about grads w.r.t. deep prompts
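    # Recovery strategy: spans are processed in reverse order of the forward pass. If the backward
    # call for a span fails, that span is recomputed with sequential_forward (possibly over a new
    # route), and the recomputed activations and spans are pushed back so the retry can proceed.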
    while len(forward_sequences) > 0 and len(intermediate_inputs) > 0:
        while True:
            try:
                inputs = intermediate_inputs.pop(-1)
                span = forward_sequences.pop(-1)

                span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)

                grad_outputs = await run_expert_backward(
                    span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs
                )
                break
            except Exception as e:
                logging.warning(f"Caught {e} when running backward for chain {span.start}-{span.end}", exc_info=True)

                _, backup_intermediate_inputs, backup_forward_sequences = await sequential_forward(
                    inputs, sequence_manager, start_index=span.start, end_index=span.end
                )

                assert len(intermediate_inputs) == len(forward_sequences)
                assert backup_forward_sequences[0].start == span.start
                assert backup_forward_sequences[-1].end == span.end

                forward_sequences.extend(backup_forward_sequences)
                intermediate_inputs.extend(backup_intermediate_inputs)

    return grad_outputs


async def _gather_forward(input_batches, sequence_manager):
    """Wrapper for asyncio.gather to perform parallel sequential forwards"""
    return await asyncio.gather(*[sequential_forward(input_batch, sequence_manager) for input_batch in input_batches])


async def _gather_backward(grad_output_batches, intermediate_input_batches, forward_sequences, sequence_manager):
    """Wrapper for asyncio.gather to perform parallel sequential backwards"""
    return await asyncio.gather(
        *[
            sequential_backward((grad_output,), input_batch, spans, sequence_manager)
            for grad_output, input_batch, spans in zip(
                grad_output_batches, intermediate_input_batches, forward_sequences
            )
        ]
    )


class _RemoteSequentialAutogradFunction(torch.autograd.Function):
    """
    PyTorch autograd function that provides forward and backward calls for the entire sequence of remote transformer blocks.
    This function splits input data into batches of at most <MAX_TOKENS_IN_BATCH> tokens and processes them in parallel.
    """

    @staticmethod
    def forward(ctx, inputs: torch.Tensor, sequence_manager: RemoteSequenceManager):
        batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)
        input_batches: Sequence[torch.Tensor] = inputs.split(batch_size)

        sequence_manager.rpc_info  # lazy init
        outputs = RemoteExpertWorker.run_coroutine(_gather_forward(input_batches, sequence_manager))
        assert len(outputs) == len(input_batches)

        output_batches = [output[0] for output in outputs]
        intermediate_input_batches = [output[1] for output in outputs]
        sequences_for_batches = [output[2] for output in outputs]

        ctx.sequence_manager = sequence_manager
        ctx.intermediate_input_batches = intermediate_input_batches
        ctx.sequences_for_batches = sequences_for_batches
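
        # Remote calls are opaque to torch.autograd, so the activations that enter each span and the
        # routes that served them are stashed on ctx; backward() below replays them explicitly.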

        return torch.cat(output_batches, dim=0)

    @staticmethod
    def backward(ctx, grad_outputs: torch.Tensor):
        intermediate_input_batches: List[Sequence[torch.Tensor]] = ctx.intermediate_input_batches
        forward_sequences: List[Sequence[RemoteSpanInfo]] = ctx.sequences_for_batches
        ctx.sequence_manager.rpc_info  # lazy init

        batch_size = max(MAX_TOKENS_IN_BATCH // grad_outputs.shape[1], 1)
        grad_output_batches: Sequence[torch.Tensor] = grad_outputs.split(batch_size)
        assert len(intermediate_input_batches) == len(grad_output_batches) == len(forward_sequences)

        grad_input_batches = RemoteExpertWorker.run_coroutine(
            _gather_backward(grad_output_batches, intermediate_input_batches, forward_sequences, ctx.sequence_manager)
        )
        grad_inputs = [grad_input_batch[0] for grad_input_batch in grad_input_batches]
        grad_inputs = torch.cat(grad_inputs, dim=0)

        return (grad_inputs, None)
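
# Usage sketch (hypothetical names, for illustration only): given hidden states of shape
# [batch, seq_len, hidden_size] and an initialized RemoteSequenceManager, the whole remote
# sequence can participate in autograd like any local module:
#
#     hidden = torch.randn(2, 128, hidden_size, requires_grad=True)
#     outputs = _RemoteSequentialAutogradFunction.apply(hidden, sequence_manager)
#     outputs.sum().backward()  # grad_inputs flow back through the remote spans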