Bladeren bron

only return grad w.r.t. inputs

justheuristic 5 jaren geleden
bovenliggende
commit
c5ee3d6041
1 gewijzigde bestanden met toevoegingen van 3 en 1 verwijderingen
  1. 3 1
      tesseract/client/moe.py

+ 3 - 1
tesseract/client/moe.py

@@ -239,4 +239,6 @@ class _RemoteMoECall(torch.autograd.Function):
 
     @staticmethod
     def _run_expert_backward(ctx: EmulatedAutogradContext, weight: torch.Tensor, *grad_outputs: torch.Tensor):
-        return run_isolated_backward(_RemoteModuleCall, ctx, *(grad * weight for grad in grad_outputs))
+        backward_result = run_isolated_backward(_RemoteModuleCall, ctx, *(grad * weight for grad in grad_outputs))
+        grad_dummy, no_grad_uid, no_grad_hostname, no_grad_port, *grad_inputs = backward_result
+        return grad_inputs