|
@@ -246,7 +246,7 @@ class _RemoteMoECall(torch.autograd.Function):
|
|
|
softmax_jacobian = torch.diagflat(survived_probas) - torch.ger(survived_probas, survived_probas)
|
|
|
grad_wrt_logits = grad_wrt_probs @ softmax_jacobian
|
|
|
|
|
|
- return grad_wrt_logits, None, None, None, None, None, None, None, *flat_grad_inputs
|
|
|
+ return (grad_wrt_logits, None, None, None, None, None, None, None, *flat_grad_inputs)
|
|
|
|
|
|
@staticmethod
|
|
|
def _run_expert_forward(expert: RemoteExpert, *args: torch.Tensor, **kwargs: torch.Tensor):
|