autograd.py

  1. """
  2. Temporary autograd extensions to enable inter-op parallelism during backward pass
  3. Note: we should get rid of this module if https://github.com/pytorch/pytorch/pull/33157 reaches a pytorch release
  4. """
  5. from itertools import chain
  6. from typing import Tuple, Any
  7. from concurrent.futures import Future
  8. import numpy as np
  9. import torch
  10. import torch.autograd.function
  11. from .threading import run_in_background


class EmulatedAutogradContext(torch.autograd.function._ContextMethodMixin):
    """
    A special class that pretends to be a pytorch autograd context. It is used to circumvent limitations of pytorch
    autograd, such as running several parallel backward passes or transferring backward to a separate device.
    This class is not tested outside its use cases in RemoteMixtureOfExperts and we do not recommend using it elsewhere.
    """

    @property
    def saved_tensors(self):
        return tuple(self.to_save)
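
# A minimal sketch of how the emulated context behaves: _ContextMethodMixin.save_for_backward stores its
# arguments in `to_save` (which is why the property above reads from it), so tensors saved during a manual
# forward call come back through `saved_tensors`, e.g.
#
#   ctx = EmulatedAutogradContext()
#   ctx.save_for_backward(torch.ones(3), torch.zeros(3))
#   first_saved, second_saved = ctx.saved_tensors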


def run_isolated_forward(func: torch.autograd.Function, *args) -> Tuple[EmulatedAutogradContext, Any]:
    """
    run :func: in a detached pytorch graph, return *detached* function outputs and an EmulatedAutogradContext that
    can be used to run backward through the same graph (performed manually by the user).
    """
    ctx = EmulatedAutogradContext()
    # create detached copies of every input so that we can differentiate w.r.t. them without modifying actual variables
    args = tuple(x.detach().requires_grad_(x.requires_grad) if isinstance(x, torch.Tensor) else x for x in args)
    with torch.no_grad():
        return ctx, func.forward(ctx, *args)


def run_isolated_backward(func: torch.autograd.Function, ctx: EmulatedAutogradContext, *grad_outputs):
    """
    run backward pass for :func: in an isolated graph that was previously created through run_isolated_forward
    """
    with torch.no_grad():
        return func.backward(ctx, *grad_outputs)
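
# A minimal usage sketch for the two functions above; the toy _Square function is hypothetical and only serves
# to illustrate a torch.autograd.Function that could be passed here:
#
#   class _Square(torch.autograd.Function):
#       @staticmethod
#       def forward(ctx, x):
#           ctx.save_for_backward(x)
#           return x ** 2
#
#       @staticmethod
#       def backward(ctx, grad_output):
#           x, = ctx.saved_tensors
#           return 2 * x * grad_output
#
#   x = torch.rand(3, requires_grad=True)
#   ctx, out = run_isolated_forward(_Square, x)                        # `out` is detached from `x`
#   grad_wrt_x = run_isolated_backward(_Square, ctx, torch.ones_like(out))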


def map_with_parallel_backward(
        func: torch.autograd.Function, *args_per_call: Tuple[torch.Tensor, ...]) -> Tuple[Tuple[torch.Tensor, ...]]:
    """
    Apply an autograd function to several sets of inputs with two extra guarantees:
    (1) both the forward and the backward pass happen concurrently for each set of inputs
    (2) any operation that depends on any individual function call will wait for all calls to finish
    :param func: torch autograd function to be called several times in parallel
    :param args_per_call: a sequence of tuples of arguments, each tuple corresponds to one function call
    :returns: a tuple of outputs from each func call
    Note: this function currently requires that all :func: calls succeed (i.e. do not raise an exception).
    """
    arg_counts = list(map(len, args_per_call))
    assert len(set(arg_counts)) == 1, "All input sets must have the same number of arguments"
    # the placeholder Future lets forward report how the flattened outputs are split between individual calls
    output_strides_ph = Future()
    flat_outputs: Tuple[torch.Tensor, ...] = _ParallelApplyFunction.apply(
        func, len(args_per_call), arg_counts[0], output_strides_ph, *chain(*args_per_call))
    output_strides = output_strides_ph.result()
    return tuple(flat_outputs[output_strides[i]: output_strides[i + 1]] for i in range(len(output_strides) - 1))
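
# A minimal usage sketch; the toy _ScaleAndShift function is hypothetical, for illustration only. Note that
# :func: is expected to return a tuple of tensors from forward and a tuple of gradients from backward, since
# _ParallelApplyFunction measures output lengths and flattens/regroups them:
#
#   class _ScaleAndShift(torch.autograd.Function):
#       @staticmethod
#       def forward(ctx, x, shift):
#           return (x * 2 + shift,)
#
#       @staticmethod
#       def backward(ctx, grad_output):
#           return grad_output * 2, grad_output.clone()
#
#   inputs = [(torch.rand(3, requires_grad=True), torch.rand(3)) for _ in range(4)]
#   (out0,), (out1,), (out2,), (out3,) = map_with_parallel_backward(_ScaleAndShift, *inputs)
#   (out0.sum() + out1.sum() + out2.sum() + out3.sum()).backward()     # all four backward calls run concurrently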


class _ParallelApplyFunction(torch.autograd.Function):
    """
    A special torch autograd function that runs another function several times in parallel.
    Please do not call this function directly; use map_with_parallel_backward instead.
    Unlike the default pytorch behavior, the backward pass for each inner call also happens in parallel.
    """

    @staticmethod
    def forward(ctx, func: torch.autograd.Function, num_calls: int, num_args_per_call: int,
                output_strides_ph: Future, *args_flat) -> Tuple[torch.Tensor, ...]:
        assert num_calls * num_args_per_call == len(args_flat)
        # split the flat argument list back into one tuple of arguments per call
        args_per_call = [args_flat[i * num_args_per_call: (i + 1) * num_args_per_call] for i in range(num_calls)]
        # run each call's forward pass in a background thread, each on its own detached graph
        futures = [run_in_background(run_isolated_forward, func, *args) for args in args_per_call]
        contexts, outputs = zip(*[future.result() for future in futures])
        # output_strides[i] is the index of the first output of the i-th call within the flattened output tuple
        output_strides = np.cumsum([0] + list(map(len, outputs)))
        ctx._inner_func = func
        ctx._call_contexts = contexts
        ctx._output_strides = output_strides
        output_strides_ph.set_result(output_strides)
        return tuple(chain(*outputs))

    @staticmethod
    def backward(ctx, *grad_outputs_flat: torch.Tensor):
        func, contexts, output_strides = ctx._inner_func, ctx._call_contexts, ctx._output_strides
        # regroup the flat gradients by call, mirroring the way outputs were flattened in forward
        grad_outputs_per_call = [grad_outputs_flat[output_strides[i]: output_strides[i + 1]]
                                 for i in range(len(contexts))]
        # run every call's backward pass concurrently, each through its own isolated graph
        futures = [run_in_background(run_isolated_backward, func, context, *grads)
                   for context, grads in zip(contexts, grad_outputs_per_call)]
        flat_grads_wrt_input = tuple(grad for future in futures for grad in future.result())
        # the four Nones match the non-differentiable forward arguments: func, num_calls, num_args_per_call, output_strides_ph
        return None, None, None, None, *flat_grads_wrt_input