from typing import Any, Tuple, Type

import torch
from torch.autograd.function import _ContextMethodMixin


class EmulatedAutogradContext(_ContextMethodMixin):
    """
    A special class that pretends to be a pytorch autograd context.
    Used to circumvent limitations of pytorch autograd, such as running several parallel backwards or transferring
    backward to a separate device.
    This class is not tested outside its use cases in RemoteMixtureOfExperts and we do not recommend using it elsewhere.
    """

    @property
    def saved_tensors(self) -> Tuple[Any, ...]:
        """Return the values passed to ``save_for_backward``, or an empty tuple if it was never called."""
        # _ContextMethodMixin.save_for_backward stores its arguments as ``self.to_save``; fall back to an empty
        # tuple (as a real autograd context does) instead of raising AttributeError when nothing was saved.
        return tuple(getattr(self, "to_save", ()))


def run_isolated_forward(func: Type[torch.autograd.Function], *args) -> Tuple[EmulatedAutogradContext, Any]:
    """
    run :func: in a detached pytorch graph, return *detached* function outputs and an EmulatedAutogradContext that
    can be used to run backward through the same graph (performed manually by the user).

    :param func: an autograd Function subclass; its static ``forward(ctx, *args)`` is called directly
    :param args: positional inputs for ``func.forward``; tensors are replaced by detached copies
    :returns: a pair (ctx, outputs of ``func.forward``), computed under ``torch.no_grad()``
    """
    ctx = EmulatedAutogradContext()
    # create detached copies of every input so that we can differentiate w.r.t. them without modifying actual variables
    args = tuple(x.detach().requires_grad_(x.requires_grad) if isinstance(x, torch.Tensor) else x for x in args)
    # a real autograd context exposes one flag per forward input; backward implementations frequently consult it
    ctx.needs_input_grad = tuple(isinstance(x, torch.Tensor) and x.requires_grad for x in args)
    with torch.no_grad():
        return ctx, func.forward(ctx, *args)


def run_isolated_backward(func: Type[torch.autograd.Function], ctx: EmulatedAutogradContext, *grad_outputs):
    """
    run backward pass for :func: in an isolated graph that was previously created through run_isolated_forward

    :param func: the same autograd Function subclass that was passed to run_isolated_forward
    :param ctx: the context returned by run_isolated_forward
    :param grad_outputs: gradients w.r.t. the outputs of ``func.forward``
    :returns: whatever ``func.backward`` returns, computed under ``torch.no_grad()``
    """
    with torch.no_grad():
        return func.backward(ctx, *grad_outputs)