瀏覽代碼

unified prefix scheme

justheuristic 5 年之前
父節點
當前提交
f9798a474a
共有 1 個文件被更改,包括 3 次插入0 次删除
  1. 3 0
      tests/test_moe.py

+ 3 - 0
tests/test_moe.py

@@ -39,3 +39,6 @@ def test_remote_module_call():
     assert torch.allclose(grad_xx_moe, grad_xx_manual, rtol=1e-3, atol=1e-6), "incorrect gradient w.r.t. input"
     assert torch.allclose(grad_logits_moe, grad_logits_manual, rtol=1e-3, atol=1e-6), "incorrect gradient w.r.t. logits"
 
+
+if __name__ == '__main__':
+    test_remote_module_call()