@@ -225,7 +225,8 @@ class _RemoteMoECall(torch.autograd.Function):
         survived_expert_probas = torch.softmax(expert_logits[survived_ix], dim=0)

         flat_grad_inputs = tuple(map(
-            lambda *tensors: sum(x * weight for x, weight in zip(tensors, survived_expert_probas)),
+            lambda *tensors: sum(x * (weight / old_weight) for x, weight, old_weight
+                             in zip(tensors, survived_expert_probas, alive_expert_probas[survived_backward])),
             *survived_grad_inputs))

         grad_logits = None # TODO
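# -----------------------------------------------------------------------------
# Minimal sketch (an assumption-laden illustration, not the project's actual
# backward pass) of the reweighting introduced above. Each surviving expert's
# gradients were produced under its forward-pass weight (old_weight), so when
# recombining only the experts that also survived the backward pass, each
# contribution is rescaled by new_weight / old_weight. The helper name
# `recombine_grads` and all toy values below are illustrative, not from the
# codebase.
import torch


def recombine_grads(survived_grad_inputs, survived_expert_probas, old_expert_probas):
    """Weighted sum of per-expert gradients, rescaled from forward-pass weights
    to the renormalized weights of the experts that survived backward."""
    # survived_grad_inputs: one tuple of gradient tensors per surviving expert;
    # zip(*...) transposes it so we sum across experts for each input position.
    return tuple(
        sum(g * (w / w_old)
            for g, w, w_old in zip(per_expert, survived_expert_probas, old_expert_probas))
        for per_expert in zip(*survived_grad_inputs)
    )


# Toy usage: two surviving experts, each returning a gradient for one input tensor.
grads_a, grads_b = (torch.ones(3),), (torch.full((3,), 2.0),)
new_w = torch.tensor([0.6, 0.4])   # softmax over logits of experts that survived backward
old_w = torch.tensor([0.5, 0.3])   # forward-pass probabilities of those same experts
print(recombine_grads([grads_a, grads_b], new_w, old_w))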