|
@@ -15,17 +15,16 @@ class EmulatedAutogradContext(_ContextMethodMixin):
|
|
|
return tuple(self.to_save)
|
|
|
|
|
|
|
|
|
-def run_isolated_forward(func: torch.autograd.Function, *args, **kwargs) -> Tuple[EmulatedAutogradContext, Any]:
|
|
|
+def run_isolated_forward(func: torch.autograd.Function, *args) -> Tuple[EmulatedAutogradContext, Any]:
|
|
|
"""
|
|
|
run :func: in a detached pytorch graph, return *detached* function outputs and an EmulatedAutogradContext that
|
|
|
- can be used to run backward through the same graph (manually by the user).
|
|
|
+ can be used to run backward through the same graph (performed manually by the user).
|
|
|
"""
|
|
|
ctx = EmulatedAutogradContext()
|
|
|
# create detached copies of every input so that we can differentiate w.r.t. them without modifying actual variables
|
|
|
- args = tuple(x.detach().requires_grad_(x.requires_grad) for x in args)
|
|
|
- kwargs = {k: x.detach().requires_grad_(x.requires_grad) for k, x in kwargs.items()}
|
|
|
+ args = tuple(x.detach().requires_grad_(x.requires_grad) if isinstance(x, torch.Tensor) else x for x in args)
|
|
|
with torch.no_grad():
|
|
|
- return ctx, func.forward(ctx, *args, **kwargs)
|
|
|
+ return ctx, func.forward(ctx, *args)
|
|
|
|
|
|
|
|
|
def run_isolated_backward(func: torch.autograd.Function, ctx: EmulatedAutogradContext, *grad_outputs):
|