|
@@ -88,7 +88,7 @@ def test_remote_module_call():
|
|
|
out3 = real_expert(dummy_x)
|
|
|
assert out3.shape == (3, 1024)
|
|
|
out3_again = real_expert(dummy_x[1:])
|
|
|
- assert torch.allclose(out3_again, out3[1:], atol=1e-6, rtol=0)
|
|
|
+ assert torch.allclose(out3_again, out3[1:], atol=1e-5, rtol=0)
|
|
|
out3_again.norm().backward()
|
|
|
assert dummy_x.grad is not None and dummy_x.grad.norm() > 0
|
|
|
|