@@ -11,6 +11,8 @@ def test_remote_module_call():
     backward_k_min = 1
     timeout_total = None
     backward_timeout = None
+    rtol = 1e-3
+    atol = 1e-5
 
     xx = torch.randn(32, 1024, requires_grad=True)
     logits = torch.randn(3, requires_grad=True)
@@ -33,12 +35,11 @@ def test_remote_module_call():
     grad_xx_manual_rerun, = torch.autograd.grad(torch.sum(random_proj * manual_output), xx, retain_graph=True)
     grad_logits_manual, = torch.autograd.grad(torch.sum(random_proj * manual_output), logits, retain_graph=True)
 
-    assert torch.allclose(grad_xx_manual, grad_xx_manual_rerun), "Experts are non-deterministic. This test is only " \
-                                                                 "valid for deterministic experts"
-    print('DIFF', (moe_output - manual_output))
-    assert torch.allclose(moe_output, manual_output, rtol=1e-3, atol=1e-6), "_RemoteMoECall returned incorrect output"
-    assert torch.allclose(grad_xx_moe, grad_xx_manual, rtol=1e-3, atol=1e-6), "incorrect gradient w.r.t. input"
-    assert torch.allclose(grad_logits_moe, grad_logits_manual, rtol=1e-3, atol=1e-6), "incorrect gradient w.r.t. logits"
+    assert torch.allclose(grad_xx_manual, grad_xx_manual_rerun, rtol, atol), "Experts are non-deterministic. The test" \
+                                                                             " is only valid for deterministic experts"
+    assert torch.allclose(moe_output, manual_output, rtol, atol), "_RemoteMoECall returned incorrect output"
+    assert torch.allclose(grad_xx_moe, grad_xx_manual, rtol, atol), "incorrect gradient w.r.t. input"
+    assert torch.allclose(grad_logits_moe, grad_logits_manual, rtol, atol), "incorrect gradient w.r.t. logits"
 
 
 if __name__ == '__main__':