sequential_autograd.py

  1. """
  2. A PyTorch autograd function that runs forward/backward on a sequence of remote servers in a fault-tolerant manner
  3. """
  4. import asyncio
  5. import itertools
  6. from collections import deque
  7. from typing import List, Optional, Sequence, Tuple
  8. import torch
  9. from hivemind import MSGPackSerializer
  10. from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
  11. from hivemind.utils.logging import get_logger
  12. from petals.client.remote_forward_backward import run_remote_backward, run_remote_forward
  13. from petals.client.routing.sequence_manager import RemoteSequenceManager, maybe_log_traceback
  14. from petals.data_structures import CHAIN_DELIMITER, RemoteSpanInfo
  15. from petals.server.handler import TransformerConnectionHandler
  16. from petals.utils.misc import DUMMY, is_dummy
  17. logger = get_logger(__name__)
  18. MAX_TOKENS_IN_BATCH = 1024


async def sequential_forward(
    inputs: torch.Tensor,
    prompts: torch.Tensor,
    sequence_manager: RemoteSequenceManager,
    start_index: int = 0,
    end_index: Optional[int] = None,
) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]:
    """
    Constructs a routing path from <start_index> to <end_index>.
    Performs chained forward for each subsequence of blocks on the path.
    If some subsequence fails, reconstructs the remaining path and tries to finish the forward.
    """

    assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"

    inputs_device = inputs.device
    inputs_dtype = inputs.dtype
    inputs = inputs.cpu()
    prompts = prompts.cpu()

    end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
    assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)
    assert is_dummy(prompts) or len(prompts) == len(
        sequence_manager.block_uids
    )  # should be n_layers - 1 but add extra prompts for convenience

    sequences = deque()
    intermediate_inputs = []
    done_sequences = []

    block_idx = start_index
    while block_idx < end_index:
        for attempt_no in itertools.count():
            logger.debug(f"Forward: block {block_idx}, attempt {attempt_no}")
            span = None
            try:
                if not sequences or attempt_no >= 1:
                    sequences = deque(sequence_manager.make_sequence(block_idx, end_index, mode="max_throughput"))
                    # make_sequence() could return a longer sequence
                    sequences[-1].end = min(sequences[-1].end, end_index)
                    logger.debug(f"Found path from block {block_idx} to {end_index} via {len(sequences)} servers")

                span = sequences.popleft()

                stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id)
                inputs_and_prompts = [inputs, prompts[span.start : span.end]]

                span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                metadata = sequence_manager.get_request_metadata("rpc_forward", span_uids, *inputs_and_prompts)
                (outputs,) = await run_remote_forward(
                    span_uids,
                    stub,
                    sequence_manager.rpc_info,
                    *inputs_and_prompts,
                    config=sequence_manager.config,
                    metadata=MSGPackSerializer.dumps(metadata),
                )

                assert isinstance(outputs, torch.Tensor)
                assert outputs.shape == inputs.shape, f"Expected output {inputs.shape}, got {outputs.shape}"

                # Save intermediate inputs and subsequences if the forward is already done for them
                intermediate_inputs.append(inputs)
                done_sequences.append(span)

                inputs = outputs
                block_idx = span.end
                sequence_manager.on_request_success(span.peer_id)
                break
            except Exception as e:
                sequence_manager.on_request_failure(span.peer_id if span is not None else None)
                if attempt_no + 1 == sequence_manager.config.max_retries:
                    raise
                delay = sequence_manager.get_retry_delay(attempt_no)
                logger.warning(
                    f"Caught exception when running forward via {span} (retry in {delay:.0f} sec): {repr(e)}"
                )
                maybe_log_traceback(e)
                await asyncio.sleep(delay)

    outputs = inputs.to(device=inputs_device, dtype=inputs_dtype)
    intermediate_inputs = [tensor.to(device=inputs_device, dtype=inputs_dtype) for tensor in intermediate_inputs]
    return outputs, intermediate_inputs, done_sequences
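
# Note (added commentary, not in the original code): the three return values together form the
# state that sequential_backward needs to retrace or re-route the same path. `outputs` is the
# activation leaving the last completed block, `intermediate_inputs[i]` is the activation that
# entered `done_sequences[i]`, and each RemoteSpanInfo records which server handled which block range.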


async def sequential_backward(
    grad_outputs: Sequence[torch.Tensor],
    intermediate_inputs: List[torch.Tensor],
    prompts: torch.Tensor,
    forward_sequences: List[RemoteSpanInfo],
    sequence_manager: RemoteSequenceManager,
) -> Tuple[Sequence[torch.Tensor], torch.Tensor]:
    """
    Performs chained backward for each forward subsequence.
    If some subsequence fails, reconstructs the particular sub-path and recovers the backward.
    """
    assert len(intermediate_inputs) == len(forward_sequences)

    grad_outputs_device = grad_outputs[0].device if grad_outputs else None
    grad_outputs_dtype = grad_outputs[0].dtype if grad_outputs else None
    prompts_device = prompts.device
    prompts_dtype = prompts.dtype

    grad_outputs = [tensor.cpu() for tensor in grad_outputs]
    intermediate_inputs = [tensor.cpu() for tensor in intermediate_inputs]
    prompts = prompts.cpu()

    grad_prompts_reversed = []
    while len(forward_sequences) > 0 and len(intermediate_inputs) > 0:
        inputs = intermediate_inputs.pop()
        span = forward_sequences.pop()

        for attempt_no in itertools.count():
            logger.debug(f"Backward: block {span.end - 1}, attempt {attempt_no}")
            try:
                if attempt_no >= 1:
                    _, backup_inputs, backup_sequences = await sequential_forward(
                        inputs, prompts, sequence_manager, start_index=span.start, end_index=span.end
                    )
                    assert len(backup_inputs) == len(backup_sequences)
                    assert backup_sequences[0].start == span.start
                    assert backup_sequences[-1].end == span.end

                    intermediate_inputs.extend(backup_inputs)
                    forward_sequences.extend(backup_sequences)

                    inputs = intermediate_inputs.pop()
                    span = forward_sequences.pop()

                span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id)
                metadata = sequence_manager.get_request_metadata(
                    "rpc_backward", span_uids, *inputs, *grad_outputs, peer_id=span.peer_id
                )
                grad_outputs, *span_grad_prompts = await run_remote_backward(
                    span_uids,
                    stub,
                    sequence_manager.rpc_info,
                    inputs,
                    grad_outputs,
                    prompts[span.start : span.end],
                    config=sequence_manager.config,
                    metadata=MSGPackSerializer.dumps(metadata),
                )
                grad_outputs = [grad_outputs]
                grad_prompts_reversed.extend(span_grad_prompts)

                sequence_manager.on_request_success(span.peer_id)
                break
            except Exception as e:
                sequence_manager.on_request_failure(span.peer_id if span is not None else None)
                if attempt_no + 1 == sequence_manager.config.max_retries:
                    raise
                delay = sequence_manager.get_retry_delay(attempt_no)
                logger.warning(
                    f"Caught exception when running backward via {span} (retry in {delay:.0f} sec): {repr(e)}"
                )
                maybe_log_traceback(e)
                await asyncio.sleep(delay)

    # For now, we do not support mixed dummy and grad prompts
    # Concat in num_layer dimension
    grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else None

    if grad_outputs_dtype is not None:
        grad_outputs = [tensor.to(device=grad_outputs_device, dtype=grad_outputs_dtype) for tensor in grad_outputs]
    if grad_prompts is not None:
        grad_prompts = grad_prompts.to(device=prompts_device, dtype=prompts_dtype)
    return grad_outputs, grad_prompts
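
# Illustrative round trip (a sketch, not part of the original module). Inside an async context
# with an existing RemoteSequenceManager (assumed here to be named `sequence_manager`), the two
# coroutines above compose as follows; tensor names and shapes are placeholders:
#
#     outputs, intermediates, spans = await sequential_forward(hidden_states, prompts, sequence_manager)
#     grad_outputs, grad_prompts = await sequential_backward(
#         (torch.ones_like(outputs),), intermediates, prompts, spans, sequence_manager
#     )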


async def _gather_forward(input_batches, prompt_batches, sequence_manager):
    """Wrapper for asyncio.gather to perform parallel sequential forwards"""
    return await asyncio.gather(
        *[
            sequential_forward(input_batch, prompt_batch, sequence_manager)
            for input_batch, prompt_batch in zip(input_batches, prompt_batches)
        ]
    )


async def _gather_backward(
    grad_output_batches, intermediate_input_batches, prompt_batches, forward_sequences, sequence_manager
):
    """Wrapper for asyncio.gather to perform parallel sequential backwards"""
    return await asyncio.gather(
        *[
            sequential_backward((grad_output,), input_batch, prompt_batch, spans, sequence_manager)
            for grad_output, input_batch, prompt_batch, spans in zip(
                grad_output_batches, intermediate_input_batches, prompt_batches, forward_sequences
            )
        ]
    )


class _RemoteSequentialAutogradFunction(torch.autograd.Function):
    """
    PyTorch autograd function that provides forward and backward calls for the entire sequence of remote transformer blocks.
    This function splits input data into batches of at most <MAX_TOKENS_IN_BATCH> tokens and processes them in parallel.
    """

    @staticmethod
    def forward(ctx, inputs: torch.Tensor, prompts: torch.Tensor, sequence_manager: RemoteSequenceManager):
        batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)
        input_batches: Sequence[torch.Tensor] = inputs.detach().split(batch_size)
        if is_dummy(prompts):
            prompt_batches = [DUMMY] * len(input_batches)
        else:
            prompt_batches: Sequence[torch.Tensor] = prompts.detach().split(batch_size, dim=1)

        sequence_manager.rpc_info  # lazy init
        outputs = RemoteExpertWorker.run_coroutine(_gather_forward(input_batches, prompt_batches, sequence_manager))
        assert len(outputs) == len(input_batches)

        output_batches = [output[0] for output in outputs]
        intermediate_input_batches = [output[1] for output in outputs]
        sequences_for_batches = [output[2] for output in outputs]

        ctx.prompt_batches = prompt_batches
        ctx.sequence_manager = sequence_manager
        ctx.intermediate_input_batches = intermediate_input_batches
        ctx.sequences_for_batches = sequences_for_batches
        return torch.cat(output_batches, dim=0)

    @staticmethod
    def backward(ctx, grad_outputs: torch.Tensor):
        intermediate_input_batches: List[Sequence[torch.Tensor]] = ctx.intermediate_input_batches
        forward_sequences: List[Sequence[RemoteSpanInfo]] = ctx.sequences_for_batches
        ctx.sequence_manager.rpc_info  # lazy init

        batch_size = max(MAX_TOKENS_IN_BATCH // grad_outputs.shape[1], 1)
        grad_output_batches: Sequence[torch.Tensor] = grad_outputs.split(batch_size)
        assert len(intermediate_input_batches) == len(grad_output_batches) == len(forward_sequences)

        outputs = RemoteExpertWorker.run_coroutine(
            _gather_backward(
                grad_output_batches,
                intermediate_input_batches,
                ctx.prompt_batches,
                forward_sequences,
                ctx.sequence_manager,
            )
        )
        grad_input_batches = [output[0][0] for output in outputs]
        grad_prompt_batches = [output[1] for output in outputs]

        grad_inputs = torch.cat(grad_input_batches, dim=0)
        dummy_grad_prompts = [grad_prompt is None for grad_prompt in grad_prompt_batches]
        grad_prompts = torch.cat(grad_prompt_batches, dim=1) if not any(dummy_grad_prompts) else None
        return (grad_inputs, grad_prompts, None)
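
# Example usage (an assumed sketch, not part of the original module). Given an initialized
# RemoteSequenceManager `sequence_manager` for the remote blocks, the autograd function is
# driven through .apply(), matching the forward() signature above; names and shapes below are
# placeholders:
#
#     hidden_states = torch.randn(2, 16, hidden_size, requires_grad=True)  # (batch, seq_len, hidden), ndim == 3
#     prompts = DUMMY  # no trainable prompts; otherwise a per-block tensor with leading dims (num_blocks, batch, ...)
#     outputs = _RemoteSequentialAutogradFunction.apply(hidden_states, prompts, sequence_manager)
#     outputs.sum().backward()  # runs the batched sequential_backward over the spans saved in forward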