|
@@ -201,7 +201,6 @@ class _RemoteMoECall(torch.autograd.Function):
|
|
|
stacked_alive_outputs = tuple(map(torch.stack, alive_outputs))
|
|
|
flat_average_outputs = tuple(dot_along_first_axis(alive_expert_probs, stacked_out)
|
|
|
for stacked_out in stacked_alive_outputs)
|
|
|
- print('!!!!', flat_average_outputs, flush=True)
|
|
|
|
|
|
# 3. save individual outputs for backward pass
|
|
|
ctx.save_for_backward(expert_logits, alive_ix, alive_expert_probs, *stacked_alive_outputs)
|
|
@@ -254,4 +253,4 @@ class _RemoteMoECall(torch.autograd.Function):
|
|
|
|
|
|
|
|
|
def dot_along_first_axis(x, y):
|
|
|
- (x.view(-1, *[1] * (y.ndim - 1))).sum(0)
|
|
|
+ return (x.view(-1, *[1] * (y.ndim - 1))).sum(0)
|