Browse the source code

Add support for quantization with bitsandbytes (#490)

* Add support for quantization with bitsandbytes

* Extend the compression benchmark

* Add a test for blockwise compression

* Add a note to README about bitsandbytes

* Install bitsandbytes in tests as well

* Verify outputs consistently in test_moe.py
(to make the test less flaky)

* Pass device="cpu" in test_background_server_identity_path
This ensures that the server can actually launch in a GPU-enabled environment; otherwise, initializing the CUDA context in the parent process prevents the server from starting.

* Filter bitsandbytes warnings

(cherry picked from commit 131f82c97ea67510d552bb7a68138ad27cbfa5d4)
Max Ryabinin, 2 years ago
parent
commit
11a02608a3

+ 3 - 0
.github/workflows/run-benchmarks.yml

@@ -26,6 +26,9 @@ jobs:
           python -m pip install --upgrade pip
           python -m pip install --upgrade pip
           pip install -r requirements.txt
           pip install -r requirements.txt
           pip install -r requirements-dev.txt
           pip install -r requirements-dev.txt
+      - name: Build bitsandbytes
+        run: |
+          pip install bitsandbytes==0.32.3
       - name: Build hivemind
       - name: Build hivemind
         run: |
         run: |
           pip install .
           pip install .

+ 6 - 0
.github/workflows/run-tests.yml

@@ -29,6 +29,9 @@ jobs:
           python -m pip install --upgrade pip
           python -m pip install --upgrade pip
           pip install -r requirements.txt
           pip install -r requirements.txt
           pip install -r requirements-dev.txt
           pip install -r requirements-dev.txt
+      - name: Build bitsandbytes
+        run: |
+          pip install bitsandbytes==0.32.3
       - name: Build hivemind
       - name: Build hivemind
         run: |
         run: |
           pip install .
           pip install .
@@ -88,6 +91,9 @@ jobs:
           python -m pip install --upgrade pip
           python -m pip install --upgrade pip
           pip install -r requirements.txt
           pip install -r requirements.txt
           pip install -r requirements-dev.txt
           pip install -r requirements-dev.txt
+      - name: Build bitsandbytes
+        run: |
+          pip install bitsandbytes==0.32.3
       - name: Build hivemind
       - name: Build hivemind
         run: |
         run: |
           pip install -e . --no-use-pep517
           pip install -e . --no-use-pep517

+ 4 - 0
README.md

@@ -53,6 +53,10 @@ If your versions of Python and PyTorch match the requirements, you can install h
 pip install hivemind
 pip install hivemind
 ```
 ```
 
 
+Also, if you want to use blockwise 8-bit compression from [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) 
+during data transfer, you can install it with `pip install hivemind[bitsandbytes]`. 
+After that, you can use the `BlockwiseQuantization` class in [hivemind.compression](./hivemind/compression).
+
 ### From source
 ### From source
 
 
 To install hivemind from source, simply run the following:
 To install hivemind from source, simply run the following:

+ 20 - 9
benchmarks/benchmark_tensor_compression.py

@@ -11,26 +11,37 @@ use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 
 
 
-def benchmark_compression(tensor: torch.Tensor, compression_type: CompressionType) -> float:
+def benchmark_compression(tensor: torch.Tensor, compression_type: CompressionType) -> [float, float, int]:
     t = time.time()
     t = time.time()
-    deserialize_torch_tensor(serialize_torch_tensor(tensor, compression_type))
-    return time.time() - t
+    serialized = serialize_torch_tensor(tensor, compression_type)
+    result = deserialize_torch_tensor(serialized)
+    return time.time() - t, (tensor - result).square().mean(), serialized.ByteSize()
 
 
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser = argparse.ArgumentParser()
-    parser.add_argument("--size", type=int, default=10000000, required=False)
+    parser.add_argument("--size", type=int, default=10_000_000, required=False)
     parser.add_argument("--seed", type=int, default=7348, required=False)
     parser.add_argument("--seed", type=int, default=7348, required=False)
     parser.add_argument("--num_iters", type=int, default=30, required=False)
     parser.add_argument("--num_iters", type=int, default=30, required=False)
 
 
     args = parser.parse_args()
     args = parser.parse_args()
 
 
     torch.manual_seed(args.seed)
     torch.manual_seed(args.seed)
-    X = torch.randn(args.size)
+    X = torch.randn(args.size, dtype=torch.float32)
 
 
     for name, compression_type in CompressionType.items():
     for name, compression_type in CompressionType.items():
-        tm = 0
+        total_time = 0
+        compression_error = 0
+        total_size = 0
         for i in range(args.num_iters):
         for i in range(args.num_iters):
-            tm += benchmark_compression(X, compression_type)
-        tm /= args.num_iters
-        logger.info(f"Compression type: {name}, time: {tm}")
+            iter_time, iter_distortion, size = benchmark_compression(X, compression_type)
+            total_time += iter_time
+            compression_error += iter_distortion
+            total_size += size
+        total_time /= args.num_iters
+        compression_error /= args.num_iters
+        total_size /= args.num_iters
+        logger.info(
+            f"Compression type: {name}, time: {total_time:.5f}, compression error: {compression_error:.5f}, "
+            f"size: {int(total_size):d}"
+        )

+ 1 - 1
hivemind/compression/__init__.py

@@ -5,7 +5,7 @@ Compression strategies that reduce the network communication in .averaging, .opt
 from hivemind.compression.adaptive import PerTensorCompression, RoleAdaptiveCompression, SizeAdaptiveCompression
 from hivemind.compression.adaptive import PerTensorCompression, RoleAdaptiveCompression, SizeAdaptiveCompression
 from hivemind.compression.base import CompressionBase, CompressionInfo, NoCompression, TensorRole
 from hivemind.compression.base import CompressionBase, CompressionInfo, NoCompression, TensorRole
 from hivemind.compression.floating import Float16Compression, ScaledFloat16Compression
 from hivemind.compression.floating import Float16Compression, ScaledFloat16Compression
-from hivemind.compression.quantization import Quantile8BitQuantization, Uniform8BitQuantization
+from hivemind.compression.quantization import BlockwiseQuantization, Quantile8BitQuantization, Uniform8BitQuantization
 from hivemind.compression.serialization import (
 from hivemind.compression.serialization import (
     deserialize_tensor_stream,
     deserialize_tensor_stream,
     deserialize_torch_tensor,
     deserialize_torch_tensor,

+ 63 - 0
hivemind/compression/quantization.py

@@ -1,5 +1,7 @@
+import importlib.util
 import math
 import math
 import os
 import os
+import warnings
 from abc import ABC, abstractmethod
 from abc import ABC, abstractmethod
 from concurrent.futures import ThreadPoolExecutor
 from concurrent.futures import ThreadPoolExecutor
 from typing import Tuple
 from typing import Tuple
@@ -7,6 +9,10 @@ from typing import Tuple
 import numpy as np
 import numpy as np
 import torch
 import torch
 
 
+if importlib.util.find_spec("bitsandbytes") is not None:
+    warnings.filterwarnings("ignore", module="bitsandbytes", category=UserWarning)
+    from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise
+
 from hivemind.compression.base import CompressionBase, CompressionInfo
 from hivemind.compression.base import CompressionBase, CompressionInfo
 from hivemind.proto import runtime_pb2
 from hivemind.proto import runtime_pb2
 
 
@@ -112,3 +118,60 @@ def quantile_qq_approximation(array: np.ndarray, n_quantiles: int, min_chunk_siz
     for job in jobs:
     for job in jobs:
         job.result()
         job.result()
     return np.quantile(partition_quantiles, quantiles)
     return np.quantile(partition_quantiles, quantiles)
+
+
+BNB_MISSING_MESSAGE = """BlockwiseQuantization requires bitsandbytes to function properly. 
+Please install it with `pip install bitsandbytes` 
+or using the instruction from https://github.com/TimDettmers/bitsandbytes."""
+
+
+class BlockwiseQuantization(Quantization):
+    compression_type = runtime_pb2.BLOCKWISE_8BIT
+    codebook_dtype, indices_dtype = np.float32, np.uint8
+
+    def quantize(
+        self, tensor: torch.Tensor, allow_inplace: bool = False
+    ) -> Tuple[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
+        try:
+            quantized, (absmax, codebook) = quantize_blockwise(tensor)
+        except NameError:
+            raise ImportError(BNB_MISSING_MESSAGE)
+        return quantized.numpy(), (absmax.numpy(), codebook.numpy())
+
+    def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
+        quantized, (absmax, codebook) = self.quantize(tensor.detach(), allow_inplace=allow_inplace)
+
+        serialized_data = (
+            np.int64(len(absmax)).tobytes(),
+            np.int64(len(codebook)).tobytes(),
+            absmax.tobytes(),
+            codebook.tobytes(),
+            quantized.tobytes(),
+        )
+
+        return runtime_pb2.Tensor(
+            buffer=b"".join(serialized_data),
+            size=tensor.shape,
+            requires_grad=tensor.requires_grad,
+            dtype=tensor.numpy().dtype.name,
+            compression=self.compression_type,
+        )
+
+    def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
+        absmax_size = int(np.frombuffer(serialized_tensor.buffer, count=1, dtype=np.int64))
+        codebook_size = int(np.frombuffer(serialized_tensor.buffer, offset=8, count=1, dtype=np.int64))
+        absmax = np.frombuffer(serialized_tensor.buffer, offset=16, count=absmax_size, dtype=self.codebook_dtype)
+        codebook = np.frombuffer(
+            serialized_tensor.buffer, offset=16 + absmax.nbytes, count=codebook_size, dtype=self.codebook_dtype
+        )
+        quantized = np.frombuffer(
+            serialized_tensor.buffer, offset=16 + absmax.nbytes + codebook.nbytes, dtype=self.indices_dtype
+        )
+
+        absmax = torch.as_tensor(absmax)
+        codebook = torch.as_tensor(codebook)
+        quantized = torch.as_tensor(quantized).reshape(tuple(serialized_tensor.size))
+        try:
+            return dequantize_blockwise(quantized, (absmax, codebook))
+        except NameError:
+            raise ImportError(BNB_MISSING_MESSAGE)

+ 7 - 6
hivemind/compression/serialization.py

@@ -6,21 +6,22 @@ import torch
 
 
 from hivemind.compression.base import CompressionBase, CompressionInfo, NoCompression
 from hivemind.compression.base import CompressionBase, CompressionInfo, NoCompression
 from hivemind.compression.floating import Float16Compression, ScaledFloat16Compression
 from hivemind.compression.floating import Float16Compression, ScaledFloat16Compression
-from hivemind.compression.quantization import Quantile8BitQuantization, Uniform8BitQuantization
+from hivemind.compression.quantization import BlockwiseQuantization, Quantile8BitQuantization, Uniform8BitQuantization
 from hivemind.proto import runtime_pb2
 from hivemind.proto import runtime_pb2
 from hivemind.utils.streaming import combine_from_streaming
 from hivemind.utils.streaming import combine_from_streaming
 
 
-BASE_COMPRESSION_TYPES: Dict[str, CompressionBase] = dict(
+_BASE_COMPRESSION_TYPES: Dict[str, CompressionBase] = dict(
     NONE=NoCompression(),
     NONE=NoCompression(),
     FLOAT16=Float16Compression(),
     FLOAT16=Float16Compression(),
     MEANSTD_16BIT=ScaledFloat16Compression(),
     MEANSTD_16BIT=ScaledFloat16Compression(),
     QUANTILE_8BIT=Quantile8BitQuantization(),
     QUANTILE_8BIT=Quantile8BitQuantization(),
     UNIFORM_8BIT=Uniform8BitQuantization(),
     UNIFORM_8BIT=Uniform8BitQuantization(),
+    BLOCKWISE_8BIT=BlockwiseQuantization(),
 )
 )
 
 
 for key in runtime_pb2.CompressionType.keys():
 for key in runtime_pb2.CompressionType.keys():
-    assert key in BASE_COMPRESSION_TYPES, f"Compression type {key} does not have a registered deserializer"
-    actual_compression_type = BASE_COMPRESSION_TYPES[key].compression_type
+    assert key in _BASE_COMPRESSION_TYPES, f"Compression type {key} does not have a registered deserializer"
+    actual_compression_type = _BASE_COMPRESSION_TYPES[key].compression_type
     assert (
     assert (
         runtime_pb2.CompressionType.Name(actual_compression_type) == key
         runtime_pb2.CompressionType.Name(actual_compression_type) == key
     ), f"Compression strategy for {key} has inconsistent type"
     ), f"Compression strategy for {key} has inconsistent type"
@@ -35,14 +36,14 @@ def serialize_torch_tensor(
 ) -> runtime_pb2.Tensor:
 ) -> runtime_pb2.Tensor:
     """Serialize a given tensor into a protobuf message using the specified compression strategy"""
     """Serialize a given tensor into a protobuf message using the specified compression strategy"""
     assert tensor.device == torch.device("cpu")
     assert tensor.device == torch.device("cpu")
-    compression = BASE_COMPRESSION_TYPES[runtime_pb2.CompressionType.Name(compression_type)]
+    compression = _BASE_COMPRESSION_TYPES[runtime_pb2.CompressionType.Name(compression_type)]
     info = info or CompressionInfo.from_tensor(tensor, **kwargs)
     info = info or CompressionInfo.from_tensor(tensor, **kwargs)
     return compression.compress(tensor, info, allow_inplace)
     return compression.compress(tensor, info, allow_inplace)
 
 
 
 
 def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
 def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
     """Restore a pytorch tensor from a protobuf message"""
     """Restore a pytorch tensor from a protobuf message"""
-    compression = BASE_COMPRESSION_TYPES[runtime_pb2.CompressionType.Name(serialized_tensor.compression)]
+    compression = _BASE_COMPRESSION_TYPES[runtime_pb2.CompressionType.Name(serialized_tensor.compression)]
     return compression.extract(serialized_tensor).requires_grad_(serialized_tensor.requires_grad)
     return compression.extract(serialized_tensor).requires_grad_(serialized_tensor.requires_grad)
 
 
 
 

+ 1 - 0
hivemind/proto/runtime.proto

@@ -26,6 +26,7 @@ enum CompressionType{
   FLOAT16 = 2;
   FLOAT16 = 2;
   QUANTILE_8BIT = 3;
   QUANTILE_8BIT = 3;
   UNIFORM_8BIT = 4;
   UNIFORM_8BIT = 4;
+  BLOCKWISE_8BIT = 5;
 }
 }
 
 
 message Tensor {
 message Tensor {

+ 3 - 2
setup.py

@@ -23,7 +23,6 @@ EXECUTABLES = {
     "p2pd": "1252a2a2095040cef8e317f5801df8b8c93559711783a2496a0aff2f3e177e39",
     "p2pd": "1252a2a2095040cef8e317f5801df8b8c93559711783a2496a0aff2f3e177e39",
 }
 }
 
 
-
 here = os.path.abspath(os.path.dirname(__file__))
 here = os.path.abspath(os.path.dirname(__file__))
 
 
 
 
@@ -140,7 +139,9 @@ with open("requirements-dev.txt") as dev_requirements_file:
 with open("requirements-docs.txt") as docs_requirements_file:
 with open("requirements-docs.txt") as docs_requirements_file:
     extras["docs"] = list(map(str, parse_requirements(docs_requirements_file)))
     extras["docs"] = list(map(str, parse_requirements(docs_requirements_file)))
 
 
-extras["all"] = extras["dev"] + extras["docs"]
+extras["bitsandbytes"] = ["bitsandbytes==0.32.3"]
+
+extras["all"] = extras["dev"] + extras["docs"] + extras["bitsandbytes"]
 
 
 setup(
 setup(
     name="hivemind",
     name="hivemind",

+ 1 - 1
tests/test_cli_scripts.py

@@ -35,7 +35,7 @@ def test_dht_connection_successful():
         dht_client_proc.stderr.readline()
         dht_client_proc.stderr.readline()
     first_report_msg = dht_client_proc.stderr.readline()
     first_report_msg = dht_client_proc.stderr.readline()
 
 
-    assert "2 DHT nodes (including this one) are in the local routing table" in first_report_msg
+    assert "2 DHT nodes (including this one) are in the local routing table" in first_report_msg, first_report_msg
 
 
     # ensure we get the output of dht_proc after the start of dht_client_proc
     # ensure we get the output of dht_proc after the start of dht_client_proc
     sleep(dht_refresh_period)
     sleep(dht_refresh_period)

+ 2 - 0
tests/test_compression.py

@@ -38,6 +38,8 @@ def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
     assert error.square().mean() < beta
     assert error.square().mean() < beta
     error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.UNIFORM_8BIT)) - X
     error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.UNIFORM_8BIT)) - X
     assert error.square().mean() < beta
     assert error.square().mean() < beta
+    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.BLOCKWISE_8BIT)) - X
+    assert error.square().mean() < beta
 
 
     zeros = torch.zeros(5, 5)
     zeros = torch.zeros(5, 5)
     for compression_type in CompressionType.values():
     for compression_type in CompressionType.values():

+ 1 - 1
tests/test_moe.py

@@ -162,7 +162,7 @@ def test_remote_module_call(hidden_dim=16):
 
 
         # check that the server is still alive after processing a malformed request
         # check that the server is still alive after processing a malformed request
         out3_yet_again = real_expert(dummy_x[1:])
         out3_yet_again = real_expert(dummy_x[1:])
-        assert torch.allclose(out3_yet_again, out3[1:])
+        assert torch.allclose(out3_yet_again, out3[1:], atol=1e-5, rtol=0)
 
 
 
 
 @pytest.mark.forked
 @pytest.mark.forked

+ 15 - 7
tests/test_start_server.py

@@ -1,5 +1,6 @@
 import os
 import os
 import re
 import re
+from functools import partial
 from subprocess import PIPE, Popen
 from subprocess import PIPE, Popen
 from tempfile import TemporaryDirectory
 from tempfile import TemporaryDirectory
 
 
@@ -10,10 +11,11 @@ def test_background_server_identity_path():
     with TemporaryDirectory() as tempdir:
     with TemporaryDirectory() as tempdir:
         id_path = os.path.join(tempdir, "id")
         id_path = os.path.join(tempdir, "id")
 
 
-        with background_server(num_experts=1, identity_path=id_path) as server_info_1, background_server(
-            num_experts=1, identity_path=id_path
-        ) as server_info_2, background_server(num_experts=1, identity_path=None) as server_info_3:
+        server_runner = partial(background_server, num_experts=1, device="cpu", hidden_dim=1)
 
 
+        with server_runner(identity_path=id_path) as server_info_1, server_runner(
+            identity_path=id_path
+        ) as server_info_2, server_runner(identity_path=None) as server_info_3:
             assert server_info_1.peer_id == server_info_2.peer_id
             assert server_info_1.peer_id == server_info_2.peer_id
             assert server_info_1.peer_id != server_info_3.peer_id
             assert server_info_1.peer_id != server_info_3.peer_id
             assert server_info_3.peer_id == server_info_3.peer_id
             assert server_info_3.peer_id == server_info_3.peer_id
@@ -33,9 +35,11 @@ def test_cli_run_server_identity_path():
         )
         )
 
 
         # Skip line "Generating new identity (libp2p private key) in {path to file}"
         # Skip line "Generating new identity (libp2p private key) in {path to file}"
+        server_1_proc.stderr.readline()
         line = server_1_proc.stderr.readline()
         line = server_1_proc.stderr.readline()
-        line = server_1_proc.stderr.readline()
-        addrs_1 = set(re.search(pattern, line).group(1).split(", "))
+        addrs_pattern_result = re.search(pattern, line)
+        assert addrs_pattern_result is not None, line
+        addrs_1 = set(addrs_pattern_result.group(1).split(", "))
         ids_1 = set(a.split("/")[-1] for a in addrs_1)
         ids_1 = set(a.split("/")[-1] for a in addrs_1)
 
 
         assert len(ids_1) == 1
         assert len(ids_1) == 1
@@ -48,7 +52,9 @@ def test_cli_run_server_identity_path():
         )
         )
 
 
         line = server_2_proc.stderr.readline()
         line = server_2_proc.stderr.readline()
-        addrs_2 = set(re.search(pattern, line).group(1).split(", "))
+        addrs_pattern_result = re.search(pattern, line)
+        assert addrs_pattern_result is not None, line
+        addrs_2 = set(addrs_pattern_result.group(1).split(", "))
         ids_2 = set(a.split("/")[-1] for a in addrs_2)
         ids_2 = set(a.split("/")[-1] for a in addrs_2)
 
 
         assert len(ids_2) == 1
         assert len(ids_2) == 1
@@ -61,7 +67,9 @@ def test_cli_run_server_identity_path():
         )
         )
 
 
         line = server_3_proc.stderr.readline()
         line = server_3_proc.stderr.readline()
-        addrs_3 = set(re.search(pattern, line).group(1).split(", "))
+        addrs_pattern_result = re.search(pattern, line)
+        assert addrs_pattern_result is not None, line
+        addrs_3 = set(addrs_pattern_result.group(1).split(", "))
         ids_3 = set(a.split("/")[-1] for a in addrs_3)
         ids_3 = set(a.split("/")[-1] for a in addrs_3)
 
 
         assert len(ids_3) == 1
         assert len(ids_3) == 1