Kaynağa Gözat

only return grad w.r.t. inputs

justheuristic 5 yıl önce
ebeveyn
işleme
c5ee3d6041
1 değiştirilmiş dosya ile 3 ekleme ve 1 silme
  1. 3 1
      tesseract/client/moe.py

+ 3 - 1
tesseract/client/moe.py

@@ -239,4 +239,6 @@ class _RemoteMoECall(torch.autograd.Function):
 
     @staticmethod
     def _run_expert_backward(ctx: EmulatedAutogradContext, weight: torch.Tensor, *grad_outputs: torch.Tensor):
-        return run_isolated_backward(_RemoteModuleCall, ctx, *(grad * weight for grad in grad_outputs))
+        backward_result = run_isolated_backward(_RemoteModuleCall, ctx, *(grad * weight for grad in grad_outputs))
+        grad_dummy, no_grad_uid, no_grad_hostname, no_grad_port, *grad_inputs = backward_result
+        return grad_inputs