justheuristic 5 years ago
Parent
Commit
4aaf179e43
1 changed file with 7 additions and 6 deletions

+ 7 - 6
tests/test_moe.py

@@ -11,6 +11,8 @@ def test_remote_module_call():
     backward_k_min = 1
     timeout_total = None
     backward_timeout = None
+    rtol = 1e-3
+    atol = 1e-5
 
     xx = torch.randn(32, 1024, requires_grad=True)
     logits = torch.randn(3, requires_grad=True)
@@ -33,12 +35,11 @@ def test_remote_module_call():
         grad_xx_manual_rerun, = torch.autograd.grad(torch.sum(random_proj * manual_output), xx, retain_graph=True)
         grad_logits_manual, = torch.autograd.grad(torch.sum(random_proj * manual_output), logits, retain_graph=True)
 
-    assert torch.allclose(grad_xx_manual, grad_xx_manual_rerun), "Experts are non-deterministic. This test is only " \
-                                                                 "valid for deterministic experts"
-    print('DIFF', (moe_output - manual_output))
-    assert torch.allclose(moe_output, manual_output, rtol=1e-3, atol=1e-6), "_RemoteMoECall returned incorrect output"
-    assert torch.allclose(grad_xx_moe, grad_xx_manual, rtol=1e-3, atol=1e-6), "incorrect gradient w.r.t. input"
-    assert torch.allclose(grad_logits_moe, grad_logits_manual, rtol=1e-3, atol=1e-6), "incorrect gradient w.r.t. logits"
+    assert torch.allclose(grad_xx_manual, grad_xx_manual_rerun, rtol, atol), "Experts are non-deterministic. The test" \
+                                                                             " is only valid for deterministic experts"
+    assert torch.allclose(moe_output, manual_output, rtol, atol), "_RemoteMoECall returned incorrect output"
+    assert torch.allclose(grad_xx_moe, grad_xx_manual, rtol, atol), "incorrect gradient w.r.t. input"
+    assert torch.allclose(grad_logits_moe, grad_logits_manual, rtol, atol), "incorrect gradient w.r.t. logits"
 
 
 if __name__ == '__main__':
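
Note on the check in this diff: comparing `torch.autograd.grad(torch.sum(random_proj * output), ...)` for a single random projection is a cheap proxy for comparing the full Jacobians of two implementations; if the outputs agree and the projected gradients agree for a random vector, the backward passes almost surely match up to the chosen tolerances. A minimal self-contained sketch of that pattern, where `f_reference` and `f_candidate` are hypothetical stand-ins for the manual and MoE forward passes in the test:

import torch

def f_reference(x):
    # stand-in for the "manual" forward pass
    return torch.tanh(x) * 2.0

def f_candidate(x):
    # stand-in for the forward pass under test; mathematically identical
    return 2.0 * torch.tanh(x)

x = torch.randn(8, 16, requires_grad=True)
out_ref, out_cand = f_reference(x), f_candidate(x)

# random projection vector: sum(v * f(x)) reduces the Jacobian comparison
# to a single vector-Jacobian product per implementation
v = torch.randn_like(out_ref)
grad_ref, = torch.autograd.grad(torch.sum(v * out_ref), x, retain_graph=True)
grad_cand, = torch.autograd.grad(torch.sum(v * out_cand), x, retain_graph=True)

# same tolerances as introduced in the commit, passed positionally to allclose
rtol, atol = 1e-3, 1e-5
assert torch.allclose(out_ref, out_cand, rtol, atol), "outputs differ"
assert torch.allclose(grad_ref, grad_cand, rtol, atol), "gradients differ"

As in the committed test, this only makes sense for deterministic forward passes: a stochastic expert would fail the rerun check (grad_xx_manual vs. grad_xx_manual_rerun) before the cross-implementation comparison is even meaningful.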