autograd.py

from typing import Tuple, Any

import torch
from torch.autograd.function import _ContextMethodMixin


class EmulatedAutogradContext(_ContextMethodMixin):
    """
    A special class that pretends to be a pytorch autograd context. Used to circumvent limitations of pytorch autograd,
    such as running several parallel backwards or transferring backward to a separate device.
    This class is not tested outside its use cases in RemoteMixtureOfExperts and we do not recommend using it elsewhere.
    """

    @property
    def saved_tensors(self):
        # save_for_backward (inherited from _ContextMethodMixin) stores tensors in self.to_save
        return tuple(self.to_save)


def run_isolated_forward(func: torch.autograd.Function, *args) -> Tuple[EmulatedAutogradContext, Any]:
    """
    run :func: in a detached pytorch graph, return *detached* function outputs and an EmulatedAutogradContext that
    can be used to run backward through the same graph (performed manually by the user).
    """
    ctx = EmulatedAutogradContext()
    # create detached copies of every input so that we can differentiate w.r.t. them without modifying actual variables
    args = tuple(x.detach().requires_grad_(x.requires_grad) if isinstance(x, torch.Tensor) else x for x in args)
    with torch.no_grad():
        # forward runs without building an autograd graph, mirroring how pytorch calls Function.forward;
        # backward relies only on what :func: saved into ctx, not on the graph itself
        return ctx, func.forward(ctx, *args)


def run_isolated_backward(func: torch.autograd.Function, ctx: EmulatedAutogradContext, *grad_outputs):
    """
    run the backward pass for :func: using a context that was previously created by run_isolated_forward
    """
    with torch.no_grad():
        return func.backward(ctx, *grad_outputs)
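

# Usage sketch (illustration only, not part of the original module): a minimal round trip
# through run_isolated_forward / run_isolated_backward. The _Square Function below is a
# hypothetical example defined here purely for demonstration; in practice these helpers are
# driven by RemoteMixtureOfExperts, as noted in the class docstring above.
if __name__ == "__main__":
    class _Square(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)
            return x ** 2

        @staticmethod
        def backward(ctx, grad_output):
            (x,) = ctx.saved_tensors
            return grad_output * 2 * x

    x = torch.randn(3, requires_grad=True)
    # forward runs in its own detached graph; the returned output does not require grad
    ctx, output = run_isolated_forward(_Square, x)
    # backward is triggered manually with the emulated context, e.g. after transferring ctx
    # to another device or while running several backwards in parallel
    grad_x = run_isolated_backward(_Square, ctx, torch.ones_like(output))
    assert torch.allclose(grad_x, 2 * x.detach())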