
Fix bfloat16 serialization for tensors with zero elements (#560)

Follow-up to #553.

(cherry picked from commit 3164928dbb8350553792ad1662dfa8c635ce951d)
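For context on the bug: when a tensor has zero elements, shape.numel() returns 0, so the legacy-mode check below divided by zero. A minimal standalone sketch of the failure (not hivemind code, just the arithmetic involved):

    import torch

    buffer = b""                    # serialized payload of an empty tensor
    shape = torch.Size((0, 0))      # zero-element shape, so numel() == 0
    len(buffer) // shape.numel()    # raises ZeroDivisionError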
Alexander Borzunov committed 2 years ago
commit 8d8520d0d9
2 changed files with 6 additions and 4 deletions:
  1. hivemind/compression/base.py (2 additions, 1 deletion)
  2. tests/test_compression.py (4 additions, 3 deletions)

hivemind/compression/base.py (2 additions, 1 deletion)

@@ -105,7 +105,8 @@ class NoCompression(CompressionBase):
     def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
         shape = torch.Size(serialized_tensor.size)
         if serialized_tensor.dtype == "bfloat16":
-            if len(serialized_tensor.buffer) // shape.numel() == 4:  # legacy mode: convert to fp32
+            numel = shape.numel()
+            if numel > 0 and len(serialized_tensor.buffer) // numel == 4:  # legacy mode: convert to fp32
                 array = np.frombuffer(serialized_tensor.buffer, dtype=np.float32)
                 tensor = torch.as_tensor(array, dtype=torch.bfloat16)
             else:  # efficient mode: send bfloat16 data directly

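With the numel > 0 guard, a zero-element tensor can no longer hit the division and falls through to the efficient branch. A hedged round-trip sketch of the fixed behavior; the import paths mirror those used in hivemind's test suite and may differ across versions:

    import torch
    from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
    from hivemind.proto.runtime_pb2 import CompressionType

    # An empty bfloat16 tensor should now survive a serialize/deserialize round trip.
    empty = torch.randn((0, 0), dtype=torch.bfloat16)
    restored = deserialize_torch_tensor(serialize_torch_tensor(empty, CompressionType.NONE))
    assert restored.shape == (0, 0) and restored.dtype == torch.bfloat16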
tests/test_compression.py (4 additions, 3 deletions)

@@ -49,7 +49,7 @@ def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
 def _check(tensor, compression, rtol=1e-5, atol=1e-8, chunk_size=30 * 1024):
     serialized_tensor = serialize_torch_tensor(tensor, compression)
     chunks = list(split_for_streaming(serialized_tensor, chunk_size))
-    assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
+    assert len(chunks) == max((len(serialized_tensor.buffer) - 1) // chunk_size + 1, 1)
     restored = combine_from_streaming(chunks)
     result = deserialize_torch_tensor(restored)
     assert torch.allclose(result, tensor, rtol=rtol, atol=atol)
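The adjusted assertion handles an empty buffer: with len(serialized_tensor.buffer) == 0, the ceiling-division formula gives (0 - 1) // chunk_size + 1 == 0, yet split_for_streaming still yields one (empty) chunk, so the expected count is clamped to at least 1. A quick arithmetic check with illustrative buffer lengths:

    chunk_size = 30 * 1024
    for buffer_len in (0, 1, chunk_size, chunk_size + 1):
        old = (buffer_len - 1) // chunk_size + 1          # 0, 1, 1, 2
        new = max((buffer_len - 1) // chunk_size + 1, 1)  # 1, 1, 1, 2
        print(buffer_len, old, new)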
@@ -69,10 +69,11 @@ def test_serialize_tensor():
 
 
 @pytest.mark.parametrize("use_legacy_bfloat16", [True, False])
+@pytest.mark.parametrize("tensor_size", [(4096, 16), (0, 0)])
 @pytest.mark.forked
-def test_serialize_bfloat16(use_legacy_bfloat16: bool):
+def test_serialize_bfloat16(use_legacy_bfloat16: bool, tensor_size: tuple):
     hivemind.compression.base.USE_LEGACY_BFLOAT16 = use_legacy_bfloat16
-    tensor = torch.randn(4096, 16, dtype=torch.bfloat16)
+    tensor = torch.randn(tensor_size, dtype=torch.bfloat16)
     _check(tensor, CompressionType.NONE)
     _check(tensor, CompressionType.BLOCKWISE_8BIT, rtol=0.1, atol=0.01, chunk_size=1024)
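With the extra parametrize decorator, each run of test_serialize_bfloat16 now covers four combinations: a regular (4096, 16) tensor and a degenerate (0, 0) tensor, each in both legacy (fp32-widening) and efficient bfloat16 serialization modes, checked under both NONE and BLOCKWISE_8BIT compression.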