瀏覽代碼

Increase tolerances in test_tp_block (#196)

deflapify tests
justheuristic 2 年之前
父節點
當前提交
c2cb6d19ae
共有 1 個文件被更改,包括 2 次插入2 次删除
  1. 2 2
      tests/test_tensor_parallel.py

+ 2 - 2
tests/test_tensor_parallel.py

@@ -40,7 +40,7 @@ def test_tp_block(devices, custom_config):
     y_ours, cache_ours = block_tp(test_inputs2, use_cache=True, layer_past=layer_past)
     y_ours.backward(grad_proj)
 
-    assert torch.allclose(y_prefix, y_prefix_ref, atol=1e-6)
-    assert torch.allclose(y_ours, y_ref, atol=1e-6)
+    assert torch.allclose(y_prefix, y_prefix_ref, atol=1e-5)
+    assert torch.allclose(y_ours, y_ref, atol=1e-5)
     assert torch.allclose(test_inputs1.grad, test_inputs2.grad, atol=1e-4)
     assert torch.allclose(test_prefix1.grad, test_prefix2.grad, atol=1e-4)