autograd.py

from typing import Tuple, Any

import torch
from torch.autograd.function import _ContextMethodMixin


class EmulatedAutogradContext(_ContextMethodMixin):
    """
    A special class that pretends to be a pytorch autograd context. It is used to circumvent limitations of pytorch
    autograd, such as running several backward passes in parallel or transferring a backward pass to a separate device.
    This class is not tested outside its use cases in RemoteMixtureOfExperts and we do not recommend using it elsewhere.
    """

    @property
    def saved_tensors(self):
        return tuple(self.to_save)
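
# A minimal sanity check of the context above (a sketch, assuming that
# _ContextMethodMixin.save_for_backward() stores its arguments on `self.to_save`,
# which is what lets the saved_tensors property re-expose them as a tuple):
#
#     ctx = EmulatedAutogradContext()
#     ctx.save_for_backward(torch.zeros(2), torch.ones(2))
#     assert len(ctx.saved_tensors) == 2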


def run_isolated_forward(func: torch.autograd.Function, *args, **kwargs) -> Tuple[EmulatedAutogradContext, Any]:
    """
    Run :func: in a detached pytorch graph, return *detached* function outputs and an EmulatedAutogradContext that
    can be used to run backward through the same graph (manually by the user).
    """
    ctx = EmulatedAutogradContext()
    # create detached copies of every input so that we can differentiate w.r.t. them without modifying actual variables
    args = tuple(x.detach().requires_grad_(x.requires_grad) for x in args)
    kwargs = {k: x.detach().requires_grad_(x.requires_grad) for k, x in kwargs.items()}
    with torch.no_grad():
        return ctx, func.forward(ctx, *args, **kwargs)


def run_isolated_backward(func: torch.autograd.Function, ctx: EmulatedAutogradContext, *grad_outputs):
    """
    Run the backward pass for :func: in an isolated graph that was previously created through run_isolated_forward.
    """
    with torch.no_grad():
        return func.backward(ctx, *grad_outputs)
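

if __name__ == "__main__":
    # Usage sketch. The toy _Square function below is an illustrative assumption (it is not
    # part of this module); it only demonstrates how run_isolated_forward and
    # run_isolated_backward can be chained to run a Function's forward and backward
    # manually, outside the global autograd graph.
    class _Square(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)
            return x ** 2

        @staticmethod
        def backward(ctx, grad_output):
            (x,) = ctx.saved_tensors
            return 2 * x * grad_output

    x = torch.randn(3, requires_grad=True)
    ctx, out = run_isolated_forward(_Square, x)  # out is detached from x and from the global graph
    grad_x = run_isolated_backward(_Square, ctx, torch.ones_like(out))
    print(grad_x)  # numerically equal to 2 * x, the gradient of x ** 2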