|
@@ -219,24 +219,22 @@ class _RemoteMoECall(torch.autograd.Function):
|
|
|
jobs = [partial(cls._run_expert_backward, ctx, prob, *grad_outputs_flat)
|
|
|
for ctx, prob in zip(alive_contexts, alive_expert_probas.split(1))]
|
|
|
results = run_and_await_k(jobs, k=k_min, timeout_after_k=None, timeout_total=timeout)
|
|
|
- print('A')
|
|
|
backward_survivors_in_alive_ix, survived_grad_inputs = zip(*((i, grads) for i, grads in enumerate(results)))
|
|
|
backward_survivors_in_alive_ix = torch.as_tensor(backward_survivors_in_alive_ix, device=expert_logits.device)
|
|
|
backward_survivors_ix = alive_ix[backward_survivors_in_alive_ix]
|
|
|
- print('B')
|
|
|
survived_probas = torch.softmax(expert_logits[backward_survivors_ix], dim=0)
|
|
|
weight_ratios = survived_probas / alive_expert_probas[backward_survivors_in_alive_ix]
|
|
|
flat_grad_inputs = tuple(dot_along_first_axis(weight_ratios, stacked_grad_inp)
|
|
|
- for stacked_grad_inp in map(torch.stack, survived_grad_inputs))
|
|
|
-
|
|
|
- print('C')
|
|
|
+ for stacked_grad_inp in map(torch.stack, zip(*survived_grad_inputs)))
|
|
|
|
|
|
# compute grad w.r.t. logits
|
|
|
+ print('A')
|
|
|
grad_wrt_probs = sum(tuple(
|
|
|
torch.sum(grad_out[None, ...] * stacked_avive_out[backward_survivors_in_alive_ix],
|
|
|
dim=tuple(range(1, stacked_avive_out.ndim)))
|
|
|
- for grad_out, stacked_avive_out in zip(grad_outputs_flat, zip(*stacked_alive_outputs))
|
|
|
+ for grad_out, stacked_avive_out in zip(grad_outputs_flat, stacked_alive_outputs)
|
|
|
))
|
|
|
+ print('B')
|
|
|
softmax_jacobian = torch.diagflat(survived_probas) - torch.ger(survived_probas, survived_probas)
|
|
|
grad_wrt_logits = grad_wrt_probs @ softmax_jacobian
|
|
|
|