Browse Source

unified prefix scheme

justheuristic 5 years ago
parent
commit
f9798a474a
1 changed files with 3 additions and 0 deletions
  1. 3 0
      tests/test_moe.py

+ 3 - 0
tests/test_moe.py

@@ -39,3 +39,6 @@ def test_remote_module_call():
     assert torch.allclose(grad_xx_moe, grad_xx_manual, rtol=1e-3, atol=1e-6), "incorrect gradient w.r.t. input"
     assert torch.allclose(grad_logits_moe, grad_logits_manual, rtol=1e-3, atol=1e-6), "incorrect gradient w.r.t. logits"
 
+
+if __name__ == '__main__':
+    test_remote_module_call()