|
@@ -228,13 +228,11 @@ class _RemoteMoECall(torch.autograd.Function):
|
|
|
for stacked_grad_inp in map(torch.stack, zip(*survived_grad_inputs)))
|
|
|
|
|
|
# compute grad w.r.t. logits
|
|
|
- print('A')
|
|
|
grad_wrt_probs = sum(tuple(
|
|
|
torch.sum(grad_out[None, ...] * stacked_avive_out[backward_survivors_in_alive_ix],
|
|
|
dim=tuple(range(1, stacked_avive_out.ndim)))
|
|
|
for grad_out, stacked_avive_out in zip(grad_outputs_flat, stacked_alive_outputs)
|
|
|
))
|
|
|
- print('B')
|
|
|
softmax_jacobian = torch.diagflat(survived_probas) - torch.ger(survived_probas, survived_probas)
|
|
|
grad_wrt_logits = grad_wrt_probs @ softmax_jacobian
|
|
|
|