
only return grad w.r.t. inputs

justheuristic 5 years ago
parent
commit c5ee3d6041
1 changed file with 3 additions and 1 deletion

+ 3 - 1
tesseract/client/moe.py

@@ -239,4 +239,6 @@ class _RemoteMoECall(torch.autograd.Function):
 
     @staticmethod
     def _run_expert_backward(ctx: EmulatedAutogradContext, weight: torch.Tensor, *grad_outputs: torch.Tensor):
-        return run_isolated_backward(_RemoteModuleCall, ctx, *(grad * weight for grad in grad_outputs))
+        backward_result = run_isolated_backward(_RemoteModuleCall, ctx, *(grad * weight for grad in grad_outputs))
+        grad_dummy, no_grad_uid, no_grad_hostname, no_grad_port, *grad_inputs = backward_result
+        return grad_inputs
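
The unpacking added here implies that _RemoteModuleCall.forward takes a dummy tensor, the expert uid, hostname and port, and then the actual input tensors, so the tuple produced by run_isolated_backward has one gradient slot per forward argument and only the trailing slots are useful to the MoE caller. Below is a minimal, hypothetical sketch of that pattern; the class, the x * 2 computation, and all names are illustrative stand-ins, not the project's actual code.

import torch


class _FakeRemoteCall(torch.autograd.Function):
    """Toy stand-in for _RemoteModuleCall (hypothetical, for illustration only).

    Its forward signature mirrors what the unpacking above implies:
    (dummy, uid, hostname, port, *inputs), so autograd expects backward()
    to return one gradient slot per argument, in the same order.
    """

    @staticmethod
    def forward(ctx, dummy, uid, hostname, port, *inputs):
        # placeholder for the actual remote expert call
        return tuple(x * 2.0 for x in inputs)

    @staticmethod
    def backward(ctx, *grad_outputs):
        grad_inputs = tuple(2.0 * g for g in grad_outputs)
        # one slot per forward argument: dummy, uid, hostname, port, then inputs;
        # the metadata arguments are non-differentiable, so their slots are None
        return (None, None, None, None, *grad_inputs)


x = torch.randn(4, requires_grad=True)
dummy = torch.empty(0, requires_grad=True)
(out,) = _FakeRemoteCall.apply(dummy, "expert.ffn.0", "127.0.0.1", 8080, x)

# the raw backward tuple (here produced by calling backward() directly; in the
# real code it comes from run_isolated_backward) carries four leading slots
# that the caller never needs, so they are unpacked and discarded
backward_result = _FakeRemoteCall.backward(None, torch.ones_like(out))
grad_dummy, no_grad_uid, no_grad_hostname, no_grad_port, *grad_inputs = backward_result
print(grad_inputs)  # only the gradients w.r.t. the actual inputs are kept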