justheuristic 5 gadi atpakaļ
vecāks
revīzija
5cddd77c93
1 mainītis faili ar 1 papildinājumiem un 1 dzēšanām
  1. 1 1
      tesseract/client/moe.py

+ 1 - 1
tesseract/client/moe.py

@@ -203,7 +203,7 @@ class _RemoteMoECall(torch.autograd.Function):
                                      for stacked_out in stacked_alive_outputs)
         flat_average_outputs_ = tuple(map(
             lambda *tensors: sum(x * weight for x, weight in zip(tensors, alive_expert_probs)), *alive_outputs))
-        assert torch.allcloce(flat_average_outputs_[0], flat_average_outputs[0])
+        assert torch.allclose(flat_average_outputs_[0], flat_average_outputs[0])
         assert False
         # 3. save individual outputs for backward pass
         ctx.save_for_backward(expert_logits, alive_ix, alive_expert_probs, *stacked_alive_outputs)