justheuristic 5 lat temu
rodzic
commit
c8a8c5cbf7
1 zmieniony plik, 1 dodanie i 2 usunięcia
  1. 1 2
      tests/test_moe.py

+ 1 - 2
tests/test_moe.py

@@ -29,14 +29,13 @@ def test_remote_module_call():
         probs = torch.softmax(logits, 0)
         outs = [expert(xx) for expert in experts[:3]]
         manual_output = sum(p * x for p, x in zip(probs, outs))
-        print(f'ref {[manual_output.min(), manual_output.max(), manual_output.norm()]}')
         grad_xx_manual, = torch.autograd.grad(torch.sum(random_proj * manual_output), xx, retain_graph=True)
         grad_xx_manual_rerun, = torch.autograd.grad(torch.sum(random_proj * manual_output), xx, retain_graph=True)
         grad_logits_manual, = torch.autograd.grad(torch.sum(random_proj * manual_output), logits, retain_graph=True)
 
-    assert torch.allclose(moe_output, manual_output), "_RemoteMoECall returned incorrect output"
     assert torch.allclose(grad_xx_manual, grad_xx_manual_rerun), "Experts are non-deterministic. This test is only " \
                                                                  "valid for deterministic experts"
+    assert torch.allclose(moe_output, manual_output, rtol=1e-3, atol=1e-6), "_RemoteMoECall returned incorrect output"
     assert torch.allclose(grad_xx_moe, grad_xx_manual, rtol=1e-3, atol=1e-6), "incorrect gradient w.r.t. input"
     assert torch.allclose(grad_logits_moe, grad_logits_manual, rtol=1e-3, atol=1e-6), "incorrect gradient w.r.t. logits"