|
@@ -224,7 +224,7 @@ class _RemoteMoECall(torch.autograd.Function):
|
|
|
backward_survivors_ix = alive_ix[backward_survivors_in_alive_ix]
|
|
|
survived_probas = torch.softmax(expert_logits[backward_survivors_ix], dim=0)
|
|
|
weight_ratios = survived_probas / alive_expert_probas[backward_survivors_in_alive_ix]
|
|
|
-
|
|
|
+ print('>>>', survived_grad_inputs)
|
|
|
flat_grad_inputs = tuple(dot_along_first_axis(weight_ratios, stacked_grad_inp)
|
|
|
for stacked_grad_inp in map(torch.stack, survived_grad_inputs))
|
|
|
|