@@ -49,7 +49,7 @@ def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
 def _check(tensor, compression, rtol=1e-5, atol=1e-8, chunk_size=30 * 1024):
     serialized_tensor = serialize_torch_tensor(tensor, compression)
     chunks = list(split_for_streaming(serialized_tensor, chunk_size))
-    assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
+    assert len(chunks) == max((len(serialized_tensor.buffer) - 1) // chunk_size + 1, 1)
     restored = combine_from_streaming(chunks)
     result = deserialize_torch_tensor(restored)
     assert torch.allclose(result, tensor, rtol=rtol, atol=atol)
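The old assertion computes ceiling division as `(n - 1) // chunk_size + 1`, which returns 0 when the serialized buffer is empty, because Python's floor division rounds toward negative infinity. The updated assertion encodes that `split_for_streaming` still yields one (empty) chunk for a zero-size tensor, hence the `max(..., 1)` guard. A minimal standalone sketch of the arithmetic (not part of the test suite itself):

```python
# Ceiling division via floor division, as used in the assertion above.
chunk_size = 30 * 1024

for buf_len in (0, 1, chunk_size, chunk_size + 1):
    old = (buf_len - 1) // chunk_size + 1           # 0, 1, 1, 2 -- 0 is wrong for an empty buffer
    new = max((buf_len - 1) // chunk_size + 1, 1)   # 1, 1, 1, 2 -- at least one (possibly empty) chunk
    print(f"buf_len={buf_len}: old={old}, new={new}")
```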
@@ -69,10 +69,11 @@ def test_serialize_tensor():


 @pytest.mark.parametrize("use_legacy_bfloat16", [True, False])
+@pytest.mark.parametrize("tensor_size", [(4096, 16), (0, 0)])
 @pytest.mark.forked
-def test_serialize_bfloat16(use_legacy_bfloat16: bool):
+def test_serialize_bfloat16(use_legacy_bfloat16: bool, tensor_size: tuple):
     hivemind.compression.base.USE_LEGACY_BFLOAT16 = use_legacy_bfloat16
-    tensor = torch.randn(4096, 16, dtype=torch.bfloat16)
+    tensor = torch.randn(tensor_size, dtype=torch.bfloat16)
     _check(tensor, CompressionType.NONE)
     _check(tensor, CompressionType.BLOCKWISE_8BIT, rtol=0.1, atol=0.01, chunk_size=1024)
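Stacked `pytest.mark.parametrize` decorators take the cross product of their cases, so the test now runs four times: both `use_legacy_bfloat16` modes against both a regular and an empty tensor. Since `torch.randn` accepts a size tuple, the `(0, 0)` case produces a zero-element tensor that exercises the empty-buffer path guarded by the new assertion in `_check`. An illustrative sketch of what that case constructs:

```python
import torch

# torch.randn accepts a size tuple, so (0, 0) builds an empty bfloat16 tensor;
# serialization and the chunk-count assertion must round-trip numel() == 0 cleanly.
tensor = torch.randn((0, 0), dtype=torch.bfloat16)
assert tensor.numel() == 0 and tensor.shape == (0, 0)
assert torch.allclose(tensor, tensor)  # allclose on empty tensors is trivially True
```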