@@ -197,7 +197,7 @@ class _RemoteMoECall(torch.autograd.Function):
         stacked_alive_outputs = tuple(map(torch.stack, zip(*alive_outputs)))
-        flat_average_outputs = tuple(dot_along_first_axis(alive_expert_probs, stacked_out)
+        flat_average_outputs = tuple((alive_expert_probs @ stacked_out.flatten(1)).view(*stacked_out.shape[1:])
                                      for stacked_out in stacked_alive_outputs)

         # 3. save individual outputs for backward pass
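
For context, a minimal sketch (not part of the diff) of what the new averaging line computes; the shapes and values below are invented for illustration:

import torch

# hypothetical setup: 4 alive experts, each returning two output tensors
alive_expert_probs = torch.softmax(torch.randn(4), dim=0)
alive_outputs = [(torch.randn(3, 5), torch.randn(3, 7)) for _ in range(4)]

# stack each output slot across experts, then mix each stack with one matmul
stacked_alive_outputs = tuple(map(torch.stack, zip(*alive_outputs)))
flat_average_outputs = tuple((alive_expert_probs @ s.flatten(1)).view(*s.shape[1:])
                             for s in stacked_alive_outputs)
print([t.shape for t in flat_average_outputs])  # [torch.Size([3, 5]), torch.Size([3, 7])]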
@@ -222,7 +222,7 @@ class _RemoteMoECall(torch.autograd.Function):
         backward_survivors_ix = alive_ix[backward_survivors_in_alive_ix]
         survived_probas = torch.softmax(expert_logits[backward_survivors_ix], dim=0)
         weight_ratios = survived_probas / alive_expert_probas[backward_survivors_in_alive_ix]
-        flat_grad_inputs = tuple(dot_along_first_axis(weight_ratios, stacked_grad_inp)
+        flat_grad_inputs = tuple((weight_ratios @ stacked_grad_inp.flatten(1)).view(stacked_grad_inp.shape[1:])
                                  for stacked_grad_inp in map(torch.stack, zip(*survived_grad_inputs)))

         # compute grad w.r.t. logits
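
Note the small asymmetry between the two new lines: the forward pass unpacks the shape (.view(*stacked_out.shape[1:])) while the backward pass passes it whole (.view(stacked_grad_inp.shape[1:])). Both are valid, since Tensor.view accepts either a torch.Size or unpacked ints; a quick check:

import torch

x = torch.randn(4, 3, 5)
flat = torch.softmax(torch.randn(4), dim=0) @ x.flatten(1)
assert torch.equal(flat.view(x.shape[1:]), flat.view(*x.shape[1:]))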
@@ -247,7 +247,3 @@ class _RemoteMoECall(torch.autograd.Function):
         backward_result = run_isolated_backward(_RemoteModuleCall, ctx, *(grad * weight for grad in grad_outputs))
         grad_dummy, no_grad_uid, no_grad_hostname, no_grad_port, *grad_inputs = backward_result
         return grad_inputs
-
-
-def dot_along_first_axis(x, y):
-    return (x.view(-1, *[1] * (y.ndim - 1)) * y).sum(0)
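
The removed helper computed the same weighted sum by broadcasting the weights and reducing over the first axis, which materializes a temporary as large as the whole stack; the inlined matmul form lowers to a single mat-vec product instead. A sketch verifying the two agree (values invented):

import torch

def dot_along_first_axis(x, y):  # the helper removed above
    return (x.view(-1, *[1] * (y.ndim - 1)) * y).sum(0)

w = torch.softmax(torch.randn(8), dim=0)
y = torch.randn(8, 16, 32)
assert torch.allclose(dot_along_first_axis(w, y),
                      (w @ y.flatten(1)).view(*y.shape[1:]), atol=1e-6)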