
do not .detach non-tensor parameters

justheuristic 5 years ago
parent
commit
49e4459ec8
1 changed file with 4 additions and 5 deletions

+ 4 - 5
tesseract/utils/autograd.py

@@ -15,17 +15,16 @@ class EmulatedAutogradContext(_ContextMethodMixin):
         return tuple(self.to_save)
 
 
-def run_isolated_forward(func: torch.autograd.Function, *args, **kwargs) -> Tuple[EmulatedAutogradContext, Any]:
+def run_isolated_forward(func: torch.autograd.Function, *args) -> Tuple[EmulatedAutogradContext, Any]:
     """
     run :func: in a detached pytorch graph, return *detached* function outputs and an EmulatedAutogradContext that
-    can be used to run backward through the same graph (manually by the user).
+    can be used to run backward through the same graph (performed manually by the user).
     """
     ctx = EmulatedAutogradContext()
     # create detached copies of every input so that we can differentiate w.r.t. them without modifying actual variables
-    args = tuple(x.detach().requires_grad_(x.requires_grad) for x in args)
-    kwargs = {k: x.detach().requires_grad_(x.requires_grad) for k, x in kwargs.items()}
+    args = tuple(x.detach().requires_grad_(x.requires_grad) if isinstance(x, torch.Tensor) else x for x in args)
     with torch.no_grad():
-        return ctx, func.forward(ctx, *args, **kwargs)
+        return ctx, func.forward(ctx, *args)
 
 
 def run_isolated_backward(func: torch.autograd.Function, ctx: EmulatedAutogradContext, *grad_outputs):
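For context, the per-argument detach only makes sense for tensor inputs; passing a plain Python value (e.g. a float flag) into the old code would raise AttributeError. The following is a minimal, self-contained sketch of the patched filtering logic, using a hypothetical toy ScaledSum Function that is not part of the repository:

import torch

class ScaledSum(torch.autograd.Function):
    """Toy Function: one tensor input plus one plain-Python scale factor."""

    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        ctx.x_shape = x.shape
        return x.sum() * scale

    @staticmethod
    def backward(ctx, grad_output):
        # gradient w.r.t. x; the non-tensor `scale` receives no gradient
        return grad_output.expand(ctx.x_shape) * ctx.scale, None

x = torch.randn(3, requires_grad=True)

# The old per-argument detach crashes on the float argument:
#   AttributeError: 'float' object has no attribute 'detach'
# The patched line detaches only actual tensors and passes everything else through:
args = tuple(
    a.detach().requires_grad_(a.requires_grad) if isinstance(a, torch.Tensor) else a
    for a in (x, 2.0)
)
out = ScaledSum.apply(*args)
out.backward()
print(args[0].grad)  # tensor([2., 2., 2.]) -- gradient lands on the detached copy, not on x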