|
@@ -56,9 +56,9 @@ def test_remote_sequential():
|
|
|
(approx_outputs * grad_proj).sum().backward()
|
|
|
|
|
|
assert not torch.allclose(approx_outputs, full_outputs, rtol=0, atol=1e-4), "compression was not used"
|
|
|
- assert not torch.allclose(test_inputs.grad, full_grad, rtol=0, atol=1e-4), "compression was not used"
|
|
|
+ assert not torch.allclose(test_inputs.grad, full_grad, rtol=0, atol=1e-2), "compression was not used"
|
|
|
assert abs(approx_outputs - full_outputs).mean() < 0.01
|
|
|
- assert abs(test_inputs.grad - full_grad).mean() < 0.1
|
|
|
+ assert abs(test_inputs.grad - full_grad).mean() < 0.3
|
|
|
|
|
|
|
|
|
class DummyCustomSequenceManager(RemoteSequenceManager):
|