Prechádzať zdrojové kódy

Increase tolerances in test_tp_block (#196)

deflapify tests
justheuristic 2 rokov pred
rodič
commit
c2cb6d19ae
1 zmenil súbory, kde vykonal 2 pridanie a 2 odobranie
  1. 2 2
      tests/test_tensor_parallel.py

+ 2 - 2
tests/test_tensor_parallel.py

@@ -40,7 +40,7 @@ def test_tp_block(devices, custom_config):
     y_ours, cache_ours = block_tp(test_inputs2, use_cache=True, layer_past=layer_past)
     y_ours.backward(grad_proj)
 
-    assert torch.allclose(y_prefix, y_prefix_ref, atol=1e-6)
-    assert torch.allclose(y_ours, y_ref, atol=1e-6)
+    assert torch.allclose(y_prefix, y_prefix_ref, atol=1e-5)
+    assert torch.allclose(y_ours, y_ref, atol=1e-5)
     assert torch.allclose(test_inputs1.grad, test_inputs2.grad, atol=1e-4)
     assert torch.allclose(test_prefix1.grad, test_prefix2.grad, atol=1e-4)