|
@@ -39,3 +39,6 @@ def test_remote_module_call():
|
|
|
assert torch.allclose(grad_xx_moe, grad_xx_manual, rtol=1e-3, atol=1e-6), "incorrect gradient w.r.t. input"
|
|
|
assert torch.allclose(grad_logits_moe, grad_logits_manual, rtol=1e-3, atol=1e-6), "incorrect gradient w.r.t. logits"
|
|
|
|
|
|
+
|
|
|
+if __name__ == '__main__':
|
|
|
+ test_remote_module_call()
|