|
@@ -224,7 +224,6 @@ class _RemoteMoECall(torch.autograd.Function):
|
|
|
backward_survivors_ix = alive_ix[backward_survivors_in_alive_ix]
|
|
|
survived_probas = torch.softmax(expert_logits[backward_survivors_ix], dim=0)
|
|
|
weight_ratios = survived_probas / alive_expert_probas[backward_survivors_in_alive_ix]
|
|
|
- print('>>>', survived_grad_inputs)
|
|
|
flat_grad_inputs = tuple(dot_along_first_axis(weight_ratios, stacked_grad_inp)
|
|
|
for stacked_grad_inp in map(torch.stack, survived_grad_inputs))
|
|
|
|
|
@@ -247,6 +246,7 @@ class _RemoteMoECall(torch.autograd.Function):
|
|
|
|
|
|
@staticmethod
|
|
|
def _run_expert_backward(ctx: EmulatedAutogradContext, weight: torch.Tensor, *grad_outputs: torch.Tensor):
|
|
|
+ print('!!!', [g.shape for g in grad_outputs])
|
|
|
backward_result = run_isolated_backward(_RemoteModuleCall, ctx, *(grad * weight for grad in grad_outputs))
|
|
|
grad_dummy, no_grad_uid, no_grad_hostname, no_grad_port, *grad_inputs = backward_result
|
|
|
return grad_inputs
|