
Support torch.bfloat16 in hivemind.compression (#524)

This PR implements bfloat16 support for `CompressionType.NONE` and `CompressionType.BLOCKWISE_8BIT`.

This is important for the Petals client; see https://github.com/bigscience-workshop/petals/issues/79
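
For illustration, a minimal round-trip sketch of what this enables, assuming the public helpers used in the tests below (`serialize_torch_tensor`, `deserialize_torch_tensor`) and the `CompressionType` enum from the protobuf definitions:

```python
import torch
from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.proto.runtime_pb2 import CompressionType  # import paths assumed from the test file

tensor = torch.randn(4096, 16, dtype=torch.bfloat16)

# CompressionType.NONE keeps both the exact values and the bfloat16 dtype.
restored = deserialize_torch_tensor(serialize_torch_tensor(tensor, CompressionType.NONE))
assert restored.dtype == torch.bfloat16 and torch.equal(restored, tensor)
```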
Alexander Borzunov, 2 years ago
commit 1e4af434f3

+ 15 - 6
hivemind/compression/base.py

@@ -80,18 +80,27 @@ class NoCompression(CompressionBase):
     compression_type = runtime_pb2.CompressionType.NONE
 
     def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
-        array = tensor.detach().numpy()
+        tensor = tensor.detach()
+        dtype_name = str(tensor.dtype).lstrip("torch.")
+        if tensor.dtype == torch.bfloat16:
+            tensor = tensor.to(torch.float32)
+
         return runtime_pb2.Tensor(
             compression=self.compression_type,
-            buffer=array.tobytes(),
-            size=array.shape,
-            dtype=array.dtype.name,
+            buffer=tensor.numpy().tobytes(),
+            size=tensor.shape,
+            dtype=dtype_name,
             requires_grad=tensor.requires_grad,
         )
 
     def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
-        array = np.frombuffer(serialized_tensor.buffer, dtype=np.dtype(serialized_tensor.dtype))
-        return torch.as_tensor(array).reshape(tuple(serialized_tensor.size))
+        if serialized_tensor.dtype == "bfloat16":
+            array = np.frombuffer(serialized_tensor.buffer, dtype=np.float32)
+            tensor = torch.as_tensor(array, dtype=torch.bfloat16)
+        else:
+            array = np.frombuffer(serialized_tensor.buffer, dtype=np.dtype(serialized_tensor.dtype))
+            tensor = torch.as_tensor(array)
+        return tensor.reshape(tuple(serialized_tensor.size))
 
     def estimate_compression_ratio(self, info: CompressionInfo) -> float:
         return 1.0
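
Note (not part of the diff): NumPy has no native bfloat16 dtype, which is why the code above bridges through float32. A standalone sketch of that round trip:

```python
import numpy as np
import torch

tensor = torch.randn(4, 4, dtype=torch.bfloat16)

# Serialize: upcast to float32 (exact, since float32 can represent every bfloat16 value),
# then pass through NumPy; the original dtype name is stored next to the buffer.
buffer = tensor.to(torch.float32).numpy().tobytes()

# Deserialize: read the buffer back as float32 and downcast to bfloat16.
array = np.frombuffer(buffer, dtype=np.float32)
restored = torch.as_tensor(array, dtype=torch.bfloat16).reshape(tensor.shape)
assert torch.equal(restored, tensor)  # lossless round trip
```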

+ 12 - 5
hivemind/compression/quantization.py

@@ -120,8 +120,8 @@ def quantile_qq_approximation(array: np.ndarray, n_quantiles: int, min_chunk_siz
     return np.quantile(partition_quantiles, quantiles)
 
 
-BNB_MISSING_MESSAGE = """BlockwiseQuantization requires bitsandbytes to function properly. 
-Please install it with `pip install bitsandbytes` 
+BNB_MISSING_MESSAGE = """BlockwiseQuantization requires bitsandbytes to function properly.
+Please install it with `pip install bitsandbytes`
 or using the instruction from https://github.com/TimDettmers/bitsandbytes."""
 
 
@@ -139,7 +139,12 @@ class BlockwiseQuantization(Quantization):
         return quantized.numpy(), (absmax.numpy(), codebook.numpy())
 
     def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
-        quantized, (absmax, codebook) = self.quantize(tensor.detach(), allow_inplace=allow_inplace)
+        tensor = tensor.detach()
+        dtype_name = str(tensor.dtype).lstrip("torch.")
+        if tensor.dtype == torch.bfloat16:
+            tensor = tensor.to(torch.float32)
+
+        quantized, (absmax, codebook) = self.quantize(tensor, allow_inplace=allow_inplace)
 
         serialized_data = (
             np.int64(len(absmax)).tobytes(),
@@ -153,7 +158,7 @@ class BlockwiseQuantization(Quantization):
             buffer=b"".join(serialized_data),
             size=tensor.shape,
             requires_grad=tensor.requires_grad,
-            dtype=tensor.numpy().dtype.name,
+            dtype=dtype_name,
             compression=self.compression_type,
         )
 
@@ -172,6 +177,8 @@ class BlockwiseQuantization(Quantization):
         codebook = torch.as_tensor(codebook)
         quantized = torch.as_tensor(quantized).reshape(tuple(serialized_tensor.size))
         try:
-            return dequantize_blockwise(quantized, (absmax, codebook))
+            result = dequantize_blockwise(quantized, (absmax, codebook))  # Always returns a float32 tensor
         except NameError:
             raise ImportError(BNB_MISSING_MESSAGE)
+        result = result.to(dtype=getattr(torch, serialized_tensor.dtype))
+        return result
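
Note (not part of the diff): `dequantize_blockwise` always yields float32, so the cast above restores the serialized dtype. A hedged usage sketch, assuming the same import paths as the test file and an installed bitsandbytes:

```python
import torch
from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.proto.runtime_pb2 import CompressionType  # assumed import path

tensor = torch.randn(4096, 16, dtype=torch.bfloat16)
restored = deserialize_torch_tensor(serialize_torch_tensor(tensor, CompressionType.BLOCKWISE_8BIT))

assert restored.dtype == torch.bfloat16  # dtype restored after dequantization
# 8-bit blockwise quantization is lossy, hence the loose tolerances (matching the test below):
assert torch.allclose(restored.float(), tensor.float(), rtol=0.1, atol=0.01)
```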

+ 17 - 7
tests/test_compression.py

@@ -46,15 +46,18 @@ def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
         assert deserialize_torch_tensor(serialize_torch_tensor(zeros, compression_type)).isfinite().all()
 
 
+def _check(tensor, compression, rtol=1e-5, atol=1e-8, chunk_size=30 * 1024):
+    serialized_tensor = serialize_torch_tensor(tensor, compression)
+    chunks = list(split_for_streaming(serialized_tensor, chunk_size))
+    assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
+    restored = combine_from_streaming(chunks)
+    result = deserialize_torch_tensor(restored)
+    assert torch.allclose(result, tensor, rtol=rtol, atol=atol)
+    assert result.dtype == tensor.dtype
+
+
 @pytest.mark.forked
 def test_serialize_tensor():
-    def _check(tensor, compression, rtol=1e-5, atol=1e-8, chunk_size=30 * 1024):
-        serialized_tensor = serialize_torch_tensor(tensor, compression)
-        chunks = list(split_for_streaming(serialized_tensor, chunk_size))
-        assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
-        restored = combine_from_streaming(chunks)
-        assert torch.allclose(deserialize_torch_tensor(restored), tensor, rtol=rtol, atol=atol)
-
     tensor = torch.randn(512, 12288)
     for chunk_size in [1024, 64 * 1024, 64 * 1024 + 1, 10**9]:
         _check(tensor, CompressionType.NONE, chunk_size=chunk_size)
@@ -65,6 +68,13 @@ def test_serialize_tensor():
     _check(torch.tensor(1.0), CompressionType.FLOAT16)
 
 
+@pytest.mark.forked
+def test_serialize_bfloat16():
+    tensor = torch.randn(4096, 16, dtype=torch.bfloat16)
+    _check(tensor, CompressionType.NONE)
+    _check(tensor, CompressionType.BLOCKWISE_8BIT, rtol=0.1, atol=0.01, chunk_size=1024)
+
+
 @pytest.mark.forked
 def test_allreduce_compression():
     """this test ensures that compression works correctly when multiple tensors have different compression types"""