@@ -203,8 +203,6 @@ class _RemoteMoECall(torch.autograd.Function):
         flat_average_outputs = tuple(dot_along_first_axis(alive_expert_probs, stacked_out)
                                      for stacked_out in stacked_alive_outputs)
 
-        print(f'ours {[flat_average_outputs[0].min(), flat_average_outputs[0].max(), flat_average_outputs[0].norm()]}')
-
         # 3. save individual outputs for backward pass
         ctx.save_for_backward(expert_logits, alive_ix, alive_expert_probs, *stacked_alive_outputs)
         ctx._alive_contexts = alive_contexts
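For context, `dot_along_first_axis(alive_expert_probs, stacked_out)` presumably computes a probability-weighted sum of the stacked expert outputs over their first (expert) axis, which is what makes `flat_average_outputs` a weighted average of the alive experts' outputs rather than a plain mean. A minimal sketch of such a helper, assuming `weights` is a 1-D tensor over alive experts and `stacked` stacks their outputs along dim 0 (the real helper in the repository may differ):

```python
import torch


def dot_along_first_axis(weights: torch.Tensor, stacked: torch.Tensor) -> torch.Tensor:
    """Weighted sum over the first axis: sum_i weights[i] * stacked[i].

    Sketch only: assumes `weights` has shape [num_alive_experts] and
    `stacked` has shape [num_alive_experts, *output_shape].
    """
    # Reshape weights to [num_alive_experts, 1, ..., 1] so they broadcast
    # over the per-expert output dimensions, then reduce over experts.
    expanded = weights.view(-1, *([1] * (stacked.dim() - 1)))
    return (expanded * stacked).sum(dim=0)
```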