@@ -46,15 +46,18 @@ def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
         assert deserialize_torch_tensor(serialize_torch_tensor(zeros, compression_type)).isfinite().all()


+def _check(tensor, compression, rtol=1e-5, atol=1e-8, chunk_size=30 * 1024):
+    serialized_tensor = serialize_torch_tensor(tensor, compression)
+    chunks = list(split_for_streaming(serialized_tensor, chunk_size))
+    assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
+    restored = combine_from_streaming(chunks)
+    result = deserialize_torch_tensor(restored)
+    assert torch.allclose(result, tensor, rtol=rtol, atol=atol)
+    assert result.dtype == tensor.dtype
+
+
 @pytest.mark.forked
 def test_serialize_tensor():
-    def _check(tensor, compression, rtol=1e-5, atol=1e-8, chunk_size=30 * 1024):
-        serialized_tensor = serialize_torch_tensor(tensor, compression)
-        chunks = list(split_for_streaming(serialized_tensor, chunk_size))
-        assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
-        restored = combine_from_streaming(chunks)
-        assert torch.allclose(deserialize_torch_tensor(restored), tensor, rtol=rtol, atol=atol)
-
     tensor = torch.randn(512, 12288)
     for chunk_size in [1024, 64 * 1024, 64 * 1024 + 1, 10**9]:
         _check(tensor, CompressionType.NONE, chunk_size=chunk_size)
@@ -65,6 +68,13 @@ def test_serialize_tensor():
     _check(torch.tensor(1.0), CompressionType.FLOAT16)


+@pytest.mark.forked
+def test_serialize_bfloat16():
+    tensor = torch.randn(4096, 16, dtype=torch.bfloat16)
+    _check(tensor, CompressionType.NONE)
+    _check(tensor, CompressionType.BLOCKWISE_8BIT, rtol=0.1, atol=0.01, chunk_size=1024)
+
+
 @pytest.mark.forked
 def test_allreduce_compression():
     """this test ensures that compression works correctly when multiple tensors have different compression types"""
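For context, the chunk-count assertion in the hoisted `_check` helper, `len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1`, is integer ceiling division: the number of fixed-size pieces needed to cover the serialized buffer. A minimal standalone sketch in plain Python; `expected_num_chunks` is an illustrative helper, not part of hivemind:

def expected_num_chunks(num_bytes: int, chunk_size: int) -> int:
    # Integer ceiling division: 1..chunk_size bytes -> 1 chunk,
    # chunk_size + 1 .. 2 * chunk_size bytes -> 2 chunks, and so on.
    return (num_bytes - 1) // chunk_size + 1

assert expected_num_chunks(1024, 1024) == 1       # exact fit: one chunk
assert expected_num_chunks(1025, 1024) == 2       # one byte over: two chunks
assert expected_num_chunks(30 * 1024, 7) == 4389  # 30720 bytes in 7-byte chunks, rounded up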