@@ -229,7 +229,7 @@ class _RemoteMoECall(torch.autograd.Function):
*survived_grad_inputs))
grad_logits = None # TODO
- return grad_logits, None, None, None, None, None, None, *flat_grad_inputs
+ return grad_logits, None, None, None, None, None, None, None, *flat_grad_inputs
@staticmethod
def _run_expert_forward(expert: RemoteExpert, *args: torch.Tensor, **kwargs: torch.Tensor):