@@ -199,8 +199,10 @@ class _RemoteMoECall(torch.autograd.Function):
         alive_expert_probs = torch.softmax(expert_logits[alive_ix], dim=0)

         stacked_alive_outputs = tuple(map(torch.stack, zip(*alive_outputs)))
-        flat_average_outputs = tuple(dot_along_first_axis(alive_expert_probs, stacked_out)
-                                     for stacked_out in stacked_alive_outputs)
+        # flat_average_outputs = tuple(dot_along_first_axis(alive_expert_probs, stacked_out)
+        #                              for stacked_out in stacked_alive_outputs)
+        flat_average_outputs = tuple(map(
+            lambda *tensors: sum(x * weight for x, weight in zip(tensors, alive_expert_probs)), *alive_outputs))

         # 3. save individual outputs for backward pass
         ctx.save_for_backward(expert_logits, alive_ix, alive_expert_probs, *stacked_alive_outputs)
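For reference, here is a minimal standalone sketch of the weighted-averaging pattern introduced by the added lines. The expert count and tensor shapes are made up purely for illustration; `alive_outputs` and `alive_expert_probs` mirror the names in the diff:

```python
import torch

# assumed: three alive experts, each returning a tuple of two flat tensors
alive_expert_probs = torch.softmax(torch.tensor([1.0, 2.0, 0.5]), dim=0)
alive_outputs = [(torch.randn(4), torch.randn(2)) for _ in range(3)]

# map(f, *alive_outputs) groups the k-th output of every expert together,
# then each group is mixed with the corresponding expert probability
flat_average_outputs = tuple(map(
    lambda *tensors: sum(x * weight for x, weight in zip(tensors, alive_expert_probs)),
    *alive_outputs))

assert flat_average_outputs[0].shape == (4,)
assert flat_average_outputs[1].shape == (2,)
```

As far as the forward value goes, this appears equivalent to the commented-out `dot_along_first_axis` path; the stacked per-expert outputs are still built, since they are saved for the backward pass.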