from typing import Any, Tuple, Type

import torch
from torch.autograd.function import _ContextMethodMixin


class EmulatedAutogradContext(_ContextMethodMixin):
- """
- A special class that pretends to be pytorch autograd context. Used to circumvent limitatons of pytorch autograd,
- such as running several parallel backwards or transferring backward to a separate device.
- This class is not tested outside its use cases in RemoteMixtureOfExperts and we do not recommend using it elsewhere.
- """

    @property
    def saved_tensors(self):
        # `to_save` is populated by _ContextMethodMixin.save_for_backward when the forward pass is run manually
        return tuple(self.to_save)


def run_isolated_forward(func: Type[torch.autograd.Function], *args, **kwargs) -> Tuple[EmulatedAutogradContext, Any]:
    """
    Run :func: in a detached pytorch graph; return the *detached* function outputs and an EmulatedAutogradContext
    that can later be used to run a backward pass through the same graph (manually, by the user).
    """
    ctx = EmulatedAutogradContext()
    # create detached copies of every input so that we can differentiate w.r.t. them without modifying actual variables
    args = tuple(x.detach().requires_grad_(x.requires_grad) for x in args)
    kwargs = {k: x.detach().requires_grad_(x.requires_grad) for k, x in kwargs.items()}
    with torch.no_grad():
        return ctx, func.forward(ctx, *args, **kwargs)


def run_isolated_backward(func: Type[torch.autograd.Function], ctx: EmulatedAutogradContext, *grad_outputs):
    """
    Run the backward pass of :func: in an isolated graph that was previously created through run_isolated_forward.
    """
    with torch.no_grad():
        return func.backward(ctx, *grad_outputs)
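

# --- Usage sketch ---
# A minimal illustration of the intended call pattern, not part of the original module: `_Square`
# is a hypothetical toy autograd.Function introduced here purely for demonstration. The forward
# pass runs in a detached graph, and the backward pass is later replayed manually through the
# same emulated context (e.g. on another device or in a parallel worker).

class _Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)  # stored in ctx.to_save, retrieved later via ctx.saved_tensors
        return x ** 2

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_output


if __name__ == "__main__":
    x = torch.randn(3, requires_grad=True)
    # forward pass in an isolated graph: y is detached from x and carries no autograd history
    ctx, y = run_isolated_forward(_Square, x)
    # replay the backward pass manually with gradients w.r.t. the outputs
    grad_x = run_isolated_backward(_Square, ctx, torch.ones_like(y))
    assert torch.allclose(grad_x, 2 * x.detach())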