@@ -1,5 +1,7 @@
+import importlib.util
 import math
 import os
+import warnings
 from abc import ABC, abstractmethod
 from concurrent.futures import ThreadPoolExecutor
 from typing import Tuple
@@ -7,6 +9,10 @@ from typing import Tuple
 import numpy as np
 import torch
 
+if importlib.util.find_spec("bitsandbytes") is not None:
+    warnings.filterwarnings("ignore", module="bitsandbytes", category=UserWarning)
+    from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise
+
 from hivemind.compression.base import CompressionBase, CompressionInfo
 from hivemind.proto import runtime_pb2
 
@@ -112,3 +118,60 @@ def quantile_qq_approximation(array: np.ndarray, n_quantiles: int, min_chunk_siz
     for job in jobs:
         job.result()
     return np.quantile(partition_quantiles, quantiles)
+
+
+BNB_MISSING_MESSAGE = """BlockwiseQuantization requires bitsandbytes to function properly.
+Please install it with `pip install bitsandbytes`
+or using the instructions from https://github.com/TimDettmers/bitsandbytes."""
+
+
+class BlockwiseQuantization(Quantization):
+    compression_type = runtime_pb2.BLOCKWISE_8BIT
+    codebook_dtype, indices_dtype = np.float32, np.uint8
+
+    def quantize(
+        self, tensor: torch.Tensor, allow_inplace: bool = False
+    ) -> Tuple[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
+        try:
+            quantized, (absmax, codebook) = quantize_blockwise(tensor)
+        except NameError:
+            raise ImportError(BNB_MISSING_MESSAGE)
+        return quantized.numpy(), (absmax.numpy(), codebook.numpy())
+
+    def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
+        quantized, (absmax, codebook) = self.quantize(tensor.detach(), allow_inplace=allow_inplace)
+
+        serialized_data = (
+            np.int64(len(absmax)).tobytes(),
+            np.int64(len(codebook)).tobytes(),
+            absmax.tobytes(),
+            codebook.tobytes(),
+            quantized.tobytes(),
+        )
+
+        return runtime_pb2.Tensor(
+            buffer=b"".join(serialized_data),
+            size=tensor.shape,
+            requires_grad=tensor.requires_grad,
+            dtype=tensor.numpy().dtype.name,
+            compression=self.compression_type,
+        )
+
+    def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
+        absmax_size = int(np.frombuffer(serialized_tensor.buffer, count=1, dtype=np.int64))
+        codebook_size = int(np.frombuffer(serialized_tensor.buffer, offset=8, count=1, dtype=np.int64))
+        absmax = np.frombuffer(serialized_tensor.buffer, offset=16, count=absmax_size, dtype=self.codebook_dtype)
+        codebook = np.frombuffer(
+            serialized_tensor.buffer, offset=16 + absmax.nbytes, count=codebook_size, dtype=self.codebook_dtype
+        )
+        quantized = np.frombuffer(
+            serialized_tensor.buffer, offset=16 + absmax.nbytes + codebook.nbytes, dtype=self.indices_dtype
+        )
+
+        absmax = torch.as_tensor(absmax)
+        codebook = torch.as_tensor(codebook)
+        quantized = torch.as_tensor(quantized).reshape(tuple(serialized_tensor.size))
+        try:
+            return dequantize_blockwise(quantized, (absmax, codebook))
+        except NameError:
+            raise ImportError(BNB_MISSING_MESSAGE)
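
Two notes on the new class follow; neither is part of the diff itself.

The wire format written by `compress` and parsed by `extract` is two `int64` length headers followed by the three arrays back to back. Here is a standalone sketch of that layout, with made-up array values standing in for real `quantize_blockwise` output:

```python
import numpy as np

# Stand-ins for what quantize_blockwise would produce (values are made up):
absmax = np.array([1.5, 2.0], dtype=np.float32)         # one scale per block
codebook = np.linspace(-1, 1, 256, dtype=np.float32)    # 8-bit lookup table
quantized = np.array([0, 127, 255, 3], dtype=np.uint8)  # indices into the codebook

buffer = b"".join((
    np.int64(len(absmax)).tobytes(),    # bytes 0..7:  absmax element count
    np.int64(len(codebook)).tobytes(),  # bytes 8..15: codebook element count
    absmax.tobytes(),                   # float32 scales
    codebook.tobytes(),                 # float32 codebook
    quantized.tobytes(),                # uint8 payload, reshaped on extract
))

# extract() reads the same fields back at offsets 0, 8, 16,
# 16 + absmax.nbytes, and 16 + absmax.nbytes + codebook.nbytes:
assert int(np.frombuffer(buffer, count=1, dtype=np.int64)) == len(absmax)
assert int(np.frombuffer(buffer, offset=8, count=1, dtype=np.int64)) == len(codebook)
```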
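And a minimal round-trip sketch of the new compression, assuming bitsandbytes is installed. The module path and the `CompressionInfo.from_tensor(...)` helper are taken from hivemind's existing compression code and may differ between versions:

```python
import torch

from hivemind.compression.base import CompressionInfo
from hivemind.compression.quantization import BlockwiseQuantization  # assumed module path

compressor = BlockwiseQuantization()
tensor = torch.randn(128, 128)

serialized = compressor.compress(tensor, CompressionInfo.from_tensor(tensor, key=0))
restored = compressor.extract(serialized)

assert restored.shape == tensor.shape
# Blockwise 8-bit quantization is lossy, so expect a small reconstruction error:
print((restored - tensor).abs().max())
```

Without bitsandbytes, the guarded import at the top of the file never binds `quantize_blockwise`/`dequantize_blockwise`; the first call then raises `NameError`, which `quantize` and `extract` convert into an `ImportError` carrying the install instructions.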