|
@@ -203,7 +203,7 @@ class _RemoteMoECall(torch.autograd.Function):
|
|
|
for stacked_out in stacked_alive_outputs)
|
|
|
flat_average_outputs_ = tuple(map(
|
|
|
lambda *tensors: sum(x * weight for x, weight in zip(tensors, alive_expert_probs)), *alive_outputs))
|
|
|
- assert torch.allcloce(flat_average_outputs_[0], flat_average_outputs[0])
|
|
|
+ assert torch.allclose(flat_average_outputs_[0], flat_average_outputs[0])
|
|
|
assert False
|
|
|
# 3. save individual outputs for backward pass
|
|
|
ctx.save_for_backward(expert_logits, alive_ix, alive_expert_probs, *stacked_alive_outputs)
|