|
@@ -239,4 +239,6 @@ class _RemoteMoECall(torch.autograd.Function):
|
|
|
|
|
|
@staticmethod
|
|
|
def _run_expert_backward(ctx: EmulatedAutogradContext, weight: torch.Tensor, *grad_outputs: torch.Tensor):
|
|
|
- return run_isolated_backward(_RemoteModuleCall, ctx, *(grad * weight for grad in grad_outputs))
|
|
|
+ backward_result = run_isolated_backward(_RemoteModuleCall, ctx, *(grad * weight for grad in grad_outputs))
|
|
|
+ grad_dummy, no_grad_uid, no_grad_hostname, no_grad_port, *grad_inputs = backward_result
|
|
|
+ return grad_inputs
|