Pārlūkot izejas kodu

Add support for quantization with bitsandbytes (#490)

* Add support for quantization with bitsandbytes

* Extend the compression benchmark

* Add a test for blockwise compression

* Add a note to README about bitsandbytes

* Install bitsandbytes in tests as well

* Verify outputs consistently in test_moe.py
(to make the test less flaky)

* Pass device="cpu" in test_background_server_identity_path
This ensures that the server can actually launch in a GPU-enabled environment: otherwise initializing the CUDA context in a parent process prevents it

* Filter bitsandbytes warnings
Max Ryabinin 2 gadi atpakaļ
vecāks
revīzija
131f82c97e

+ 3 - 0
.github/workflows/run-benchmarks.yml

@@ -26,6 +26,9 @@ jobs:
           python -m pip install --upgrade pip
           pip install -r requirements.txt
           pip install -r requirements-dev.txt
+      - name: Build bitsandbytes
+        run: |
+          pip install bitsandbytes==0.32.3
       - name: Build hivemind
         run: |
           pip install .

+ 6 - 0
.github/workflows/run-tests.yml

@@ -29,6 +29,9 @@ jobs:
           python -m pip install --upgrade pip
           pip install -r requirements.txt
           pip install -r requirements-dev.txt
+      - name: Build bitsandbytes
+        run: |
+          pip install bitsandbytes==0.32.3
       - name: Build hivemind
         run: |
           pip install .
@@ -88,6 +91,9 @@ jobs:
           python -m pip install --upgrade pip
           pip install -r requirements.txt
           pip install -r requirements-dev.txt
+      - name: Build bitsandbytes
+        run: |
+          pip install bitsandbytes==0.32.3
       - name: Build hivemind
         run: |
           pip install -e . --no-use-pep517

+ 4 - 0
README.md

@@ -53,6 +53,10 @@ If your versions of Python and PyTorch match the requirements, you can install h
 pip install hivemind
 ```
 
+Also, if you want to use blockwise 8-bit compression from [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
+during data transfer, you can install it with `pip install hivemind[bitsandbytes]`.
+After that, you can use the `BlockwiseQuantization` class in [hivemind.compression](./hivemind/compression).
+
 ### From source
 
 To install hivemind from source, simply run the following:

+ 20 - 9
benchmarks/benchmark_tensor_compression.py

@@ -11,26 +11,37 @@ use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
 
 
-def benchmark_compression(tensor: torch.Tensor, compression_type: CompressionType) -> float:
+def benchmark_compression(tensor: torch.Tensor, compression_type: CompressionType) -> [float, float, int]:
     t = time.time()
-    deserialize_torch_tensor(serialize_torch_tensor(tensor, compression_type))
-    return time.time() - t
+    serialized = serialize_torch_tensor(tensor, compression_type)
+    result = deserialize_torch_tensor(serialized)
+    return time.time() - t, (tensor - result).square().mean(), serialized.ByteSize()
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--size", type=int, default=10000000, required=False)
+    parser.add_argument("--size", type=int, default=10_000_000, required=False)
     parser.add_argument("--seed", type=int, default=7348, required=False)
     parser.add_argument("--num_iters", type=int, default=30, required=False)
 
     args = parser.parse_args()
 
     torch.manual_seed(args.seed)
-    X = torch.randn(args.size)
+    X = torch.randn(args.size, dtype=torch.float32)
 
     for name, compression_type in CompressionType.items():
-        tm = 0
+        total_time = 0
+        compression_error = 0
+        total_size = 0
         for i in range(args.num_iters):
-            tm += benchmark_compression(X, compression_type)
-        tm /= args.num_iters
-        logger.info(f"Compression type: {name}, time: {tm}")
+            iter_time, iter_distortion, size = benchmark_compression(X, compression_type)
+            total_time += iter_time
+            compression_error += iter_distortion
+            total_size += size
+        total_time /= args.num_iters
+        compression_error /= args.num_iters
+        total_size /= args.num_iters
+        logger.info(
+            f"Compression type: {name}, time: {total_time:.5f}, compression error: {compression_error:.5f}, "
+            f"size: {int(total_size):d}"
+        )

+ 1 - 1
hivemind/compression/__init__.py

@@ -5,7 +5,7 @@ Compression strategies that reduce the network communication in .averaging, .opt
 from hivemind.compression.adaptive import PerTensorCompression, RoleAdaptiveCompression, SizeAdaptiveCompression
 from hivemind.compression.base import CompressionBase, CompressionInfo, NoCompression, TensorRole
 from hivemind.compression.floating import Float16Compression, ScaledFloat16Compression
-from hivemind.compression.quantization import Quantile8BitQuantization, Uniform8BitQuantization
+from hivemind.compression.quantization import BlockwiseQuantization, Quantile8BitQuantization, Uniform8BitQuantization
 from hivemind.compression.serialization import (
     deserialize_tensor_stream,
     deserialize_torch_tensor,

+ 63 - 0
hivemind/compression/quantization.py

@@ -1,5 +1,7 @@
+import importlib.util
 import math
 import os
+import warnings
 from abc import ABC, abstractmethod
 from concurrent.futures import ThreadPoolExecutor
 from typing import Tuple
@@ -7,6 +9,10 @@ from typing import Tuple
 import numpy as np
 import torch
 
+if importlib.util.find_spec("bitsandbytes") is not None:
+    warnings.filterwarnings("ignore", module="bitsandbytes", category=UserWarning)
+    from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise
+
 from hivemind.compression.base import CompressionBase, CompressionInfo
 from hivemind.proto import runtime_pb2
 
@@ -112,3 +118,60 @@ def quantile_qq_approximation(array: np.ndarray, n_quantiles: int, min_chunk_siz
     for job in jobs:
         job.result()
     return np.quantile(partition_quantiles, quantiles)
+
+
+BNB_MISSING_MESSAGE = """BlockwiseQuantization requires bitsandbytes to function properly. 
+Please install it with `pip install bitsandbytes` 
+or using the instruction from https://github.com/TimDettmers/bitsandbytes."""  # text of the ImportError raised by BlockwiseQuantization
+
+
+class BlockwiseQuantization(Quantization):
+    compression_type = runtime_pb2.BLOCKWISE_8BIT
+    codebook_dtype, indices_dtype = np.float32, np.uint8
+
+    def quantize(
+        self, tensor: torch.Tensor, allow_inplace: bool = False
+    ) -> Tuple[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
+        try:
+            quantized, (absmax, codebook) = quantize_blockwise(tensor)
+        except NameError:
+            raise ImportError(BNB_MISSING_MESSAGE)
+        return quantized.numpy(), (absmax.numpy(), codebook.numpy())
+
+    def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
+        quantized, (absmax, codebook) = self.quantize(tensor.detach(), allow_inplace=allow_inplace)
+
+        serialized_data = (
+            np.int64(len(absmax)).tobytes(),
+            np.int64(len(codebook)).tobytes(),
+            absmax.tobytes(),
+            codebook.tobytes(),
+            quantized.tobytes(),
+        )
+
+        return runtime_pb2.Tensor(
+            buffer=b"".join(serialized_data),
+            size=tensor.shape,
+            requires_grad=tensor.requires_grad,
+            dtype=tensor.numpy().dtype.name,
+            compression=self.compression_type,
+        )
+
+    def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
+        absmax_size = int(np.frombuffer(serialized_tensor.buffer, count=1, dtype=np.int64))
+        codebook_size = int(np.frombuffer(serialized_tensor.buffer, offset=8, count=1, dtype=np.int64))
+        absmax = np.frombuffer(serialized_tensor.buffer, offset=16, count=absmax_size, dtype=self.codebook_dtype)
+        codebook = np.frombuffer(
+            serialized_tensor.buffer, offset=16 + absmax.nbytes, count=codebook_size, dtype=self.codebook_dtype
+        )
+        quantized = np.frombuffer(
+            serialized_tensor.buffer, offset=16 + absmax.nbytes + codebook.nbytes, dtype=self.indices_dtype
+        )
+
+        absmax = torch.as_tensor(absmax)
+        codebook = torch.as_tensor(codebook)
+        quantized = torch.as_tensor(quantized).reshape(tuple(serialized_tensor.size))
+        try:
+            return dequantize_blockwise(quantized, (absmax, codebook))
+        except NameError:
+            raise ImportError(BNB_MISSING_MESSAGE)

+ 7 - 6
hivemind/compression/serialization.py

@@ -6,21 +6,22 @@ import torch
 
 from hivemind.compression.base import CompressionBase, CompressionInfo, NoCompression
 from hivemind.compression.floating import Float16Compression, ScaledFloat16Compression
-from hivemind.compression.quantization import Quantile8BitQuantization, Uniform8BitQuantization
+from hivemind.compression.quantization import BlockwiseQuantization, Quantile8BitQuantization, Uniform8BitQuantization
 from hivemind.proto import runtime_pb2
 from hivemind.utils.streaming import combine_from_streaming
 
-BASE_COMPRESSION_TYPES: Dict[str, CompressionBase] = dict(
+_BASE_COMPRESSION_TYPES: Dict[str, CompressionBase] = dict(
     NONE=NoCompression(),
     FLOAT16=Float16Compression(),
     MEANSTD_16BIT=ScaledFloat16Compression(),
     QUANTILE_8BIT=Quantile8BitQuantization(),
     UNIFORM_8BIT=Uniform8BitQuantization(),
+    BLOCKWISE_8BIT=BlockwiseQuantization(),
 )
 
 for key in runtime_pb2.CompressionType.keys():
-    assert key in BASE_COMPRESSION_TYPES, f"Compression type {key} does not have a registered deserializer"
-    actual_compression_type = BASE_COMPRESSION_TYPES[key].compression_type
+    assert key in _BASE_COMPRESSION_TYPES, f"Compression type {key} does not have a registered deserializer"
+    actual_compression_type = _BASE_COMPRESSION_TYPES[key].compression_type
     assert (
         runtime_pb2.CompressionType.Name(actual_compression_type) == key
     ), f"Compression strategy for {key} has inconsistent type"
@@ -35,14 +36,14 @@ def serialize_torch_tensor(
 ) -> runtime_pb2.Tensor:
     """Serialize a given tensor into a protobuf message using the specified compression strategy"""
     assert tensor.device == torch.device("cpu")
-    compression = BASE_COMPRESSION_TYPES[runtime_pb2.CompressionType.Name(compression_type)]
+    compression = _BASE_COMPRESSION_TYPES[runtime_pb2.CompressionType.Name(compression_type)]
     info = info or CompressionInfo.from_tensor(tensor, **kwargs)
     return compression.compress(tensor, info, allow_inplace)
 
 
 def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
     """Restore a pytorch tensor from a protobuf message"""
-    compression = BASE_COMPRESSION_TYPES[runtime_pb2.CompressionType.Name(serialized_tensor.compression)]
+    compression = _BASE_COMPRESSION_TYPES[runtime_pb2.CompressionType.Name(serialized_tensor.compression)]
     return compression.extract(serialized_tensor).requires_grad_(serialized_tensor.requires_grad)
 
 

+ 1 - 0
hivemind/proto/runtime.proto

@@ -26,6 +26,7 @@ enum CompressionType{
   FLOAT16 = 2;
   QUANTILE_8BIT = 3;
   UNIFORM_8BIT = 4;
+  BLOCKWISE_8BIT = 5;
 }
 
 message Tensor {

+ 3 - 2
setup.py

@@ -23,7 +23,6 @@ EXECUTABLES = {
     "p2pd": "1252a2a2095040cef8e317f5801df8b8c93559711783a2496a0aff2f3e177e39",
 }
 
-
 here = os.path.abspath(os.path.dirname(__file__))
 
 
@@ -140,7 +139,9 @@ with open("requirements-dev.txt") as dev_requirements_file:
 with open("requirements-docs.txt") as docs_requirements_file:
     extras["docs"] = list(map(str, parse_requirements(docs_requirements_file)))
 
-extras["all"] = extras["dev"] + extras["docs"]
+extras["bitsandbytes"] = ["bitsandbytes==0.32.3"]
+
+extras["all"] = extras["dev"] + extras["docs"] + extras["bitsandbytes"]
 
 setup(
     name="hivemind",

+ 1 - 1
tests/test_cli_scripts.py

@@ -35,7 +35,7 @@ def test_dht_connection_successful():
         dht_client_proc.stderr.readline()
     first_report_msg = dht_client_proc.stderr.readline()
 
-    assert "2 DHT nodes (including this one) are in the local routing table" in first_report_msg
+    assert "2 DHT nodes (including this one) are in the local routing table" in first_report_msg, first_report_msg
 
     # ensure we get the output of dht_proc after the start of dht_client_proc
     sleep(dht_refresh_period)

+ 2 - 0
tests/test_compression.py

@@ -38,6 +38,8 @@ def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
     assert error.square().mean() < beta
     error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.UNIFORM_8BIT)) - X
     assert error.square().mean() < beta
+    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.BLOCKWISE_8BIT)) - X
+    assert error.square().mean() < beta
 
     zeros = torch.zeros(5, 5)
     for compression_type in CompressionType.values():

+ 1 - 1
tests/test_moe.py

@@ -162,7 +162,7 @@ def test_remote_module_call(hidden_dim=16):
 
         # check that the server is still alive after processing a malformed request
         out3_yet_again = real_expert(dummy_x[1:])
-        assert torch.allclose(out3_yet_again, out3[1:])
+        assert torch.allclose(out3_yet_again, out3[1:], atol=1e-5, rtol=0)
 
 
 @pytest.mark.forked

+ 15 - 7
tests/test_start_server.py

@@ -1,5 +1,6 @@
 import os
 import re
+from functools import partial
 from subprocess import PIPE, Popen
 from tempfile import TemporaryDirectory
 
@@ -10,10 +11,11 @@ def test_background_server_identity_path():
     with TemporaryDirectory() as tempdir:
         id_path = os.path.join(tempdir, "id")
 
-        with background_server(num_experts=1, identity_path=id_path) as server_info_1, background_server(
-            num_experts=1, identity_path=id_path
-        ) as server_info_2, background_server(num_experts=1, identity_path=None) as server_info_3:
+        server_runner = partial(background_server, num_experts=1, device="cpu", hidden_dim=1)
 
+        with server_runner(identity_path=id_path) as server_info_1, server_runner(
+            identity_path=id_path
+        ) as server_info_2, server_runner(identity_path=None) as server_info_3:
             assert server_info_1.peer_id == server_info_2.peer_id
             assert server_info_1.peer_id != server_info_3.peer_id
             assert server_info_3.peer_id == server_info_3.peer_id
@@ -33,9 +35,11 @@ def test_cli_run_server_identity_path():
         )
 
         # Skip line "Generating new identity (libp2p private key) in {path to file}"
+        server_1_proc.stderr.readline()
         line = server_1_proc.stderr.readline()
-        line = server_1_proc.stderr.readline()
-        addrs_1 = set(re.search(pattern, line).group(1).split(", "))
+        addrs_pattern_result = re.search(pattern, line)
+        assert addrs_pattern_result is not None, line
+        addrs_1 = set(addrs_pattern_result.group(1).split(", "))
         ids_1 = set(a.split("/")[-1] for a in addrs_1)
 
         assert len(ids_1) == 1
@@ -48,7 +52,9 @@ def test_cli_run_server_identity_path():
         )
 
         line = server_2_proc.stderr.readline()
-        addrs_2 = set(re.search(pattern, line).group(1).split(", "))
+        addrs_pattern_result = re.search(pattern, line)
+        assert addrs_pattern_result is not None, line
+        addrs_2 = set(addrs_pattern_result.group(1).split(", "))
         ids_2 = set(a.split("/")[-1] for a in addrs_2)
 
         assert len(ids_2) == 1
@@ -61,7 +67,9 @@ def test_cli_run_server_identity_path():
         )
 
         line = server_3_proc.stderr.readline()
-        addrs_3 = set(re.search(pattern, line).group(1).split(", "))
+        addrs_pattern_result = re.search(pattern, line)
+        assert addrs_pattern_result is not None, line
+        addrs_3 = set(addrs_pattern_result.group(1).split(", "))
         ids_3 = set(a.split("/")[-1] for a in addrs_3)
 
         assert len(ids_3) == 1