Parcourir la source

only return grad w.r.t. inputs

justheuristic il y a 5 ans
Parent
commit
c5ee3d6041
1 fichiers modifiés avec 3 ajouts et 1 suppressions
  1. 3 1
      tesseract/client/moe.py

+ 3 - 1
tesseract/client/moe.py

@@ -239,4 +239,6 @@ class _RemoteMoECall(torch.autograd.Function):
 
     @staticmethod
     def _run_expert_backward(ctx: EmulatedAutogradContext, weight: torch.Tensor, *grad_outputs: torch.Tensor):
-        return run_isolated_backward(_RemoteModuleCall, ctx, *(grad * weight for grad in grad_outputs))
+        backward_result = run_isolated_backward(_RemoteModuleCall, ctx, *(grad * weight for grad in grad_outputs))
+        grad_dummy, no_grad_uid, no_grad_hostname, no_grad_port, *grad_inputs = backward_result
+        return grad_inputs