|
@@ -197,9 +197,9 @@ class _RemoteMoECall(torch.autograd.Function):
|
|
|
# 2. compute softmax weights for alive experts and average outputs
|
|
|
alive_ix = torch.as_tensor(alive_ix, device=expert_logits.device)
|
|
|
alive_expert_probs = torch.softmax(expert_logits[alive_ix], dim=0)
|
|
|
- print(f'{alive_expert_probs=}')
|
|
|
|
|
|
stacked_alive_outputs = tuple(map(torch.stack, zip(*alive_outputs)))
|
|
|
+ print(f'>> {[outs[0].norm() for outs in alive_outputs]}')
|
|
|
flat_average_outputs = tuple(dot_along_first_axis(alive_expert_probs, stacked_out)
|
|
|
for stacked_out in stacked_alive_outputs)
|
|
|
|