Browse Source

debugprint

justheuristic 5 years ago
parent
commit
188ab37921
1 changed files with 1 additions and 0 deletions
  1. 1 0
      tests/test_moe.py

+ 1 - 0
tests/test_moe.py

@@ -35,6 +35,7 @@ def test_remote_module_call():
 
 
     assert torch.allclose(grad_xx_manual, grad_xx_manual_rerun), "Experts are non-deterministic. This test is only " \
     assert torch.allclose(grad_xx_manual, grad_xx_manual_rerun), "Experts are non-deterministic. This test is only " \
                                                                  "valid for deterministic experts"
                                                                  "valid for deterministic experts"
+    print('DIFF', (moe_output - manual_output))
     assert torch.allclose(moe_output, manual_output, rtol=1e-3, atol=1e-6), "_RemoteMoECall returned incorrect output"
     assert torch.allclose(moe_output, manual_output, rtol=1e-3, atol=1e-6), "_RemoteMoECall returned incorrect output"
     assert torch.allclose(grad_xx_moe, grad_xx_manual, rtol=1e-3, atol=1e-6), "incorrect gradient w.r.t. input"
     assert torch.allclose(grad_xx_moe, grad_xx_manual, rtol=1e-3, atol=1e-6), "incorrect gradient w.r.t. input"
     assert torch.allclose(grad_logits_moe, grad_logits_manual, rtol=1e-3, atol=1e-6), "incorrect gradient w.r.t. logits"
     assert torch.allclose(grad_logits_moe, grad_logits_manual, rtol=1e-3, atol=1e-6), "incorrect gradient w.r.t. logits"