|
@@ -1,4 +1,3 @@
|
|
|
-import importlib.util
|
|
|
import math
|
|
|
import os
|
|
|
import warnings
|
|
@@ -9,13 +8,11 @@ from typing import Tuple
|
|
|
import numpy as np
|
|
|
import torch
|
|
|
|
|
|
-if importlib.util.find_spec("bitsandbytes") is not None:
|
|
|
- warnings.filterwarnings("ignore", module="bitsandbytes", category=UserWarning)
|
|
|
- from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise
|
|
|
-
|
|
|
from hivemind.compression.base import CompressionBase, CompressionInfo
|
|
|
from hivemind.proto import runtime_pb2
|
|
|
|
|
|
+warnings.filterwarnings("ignore", module="bitsandbytes", category=UserWarning)
|
|
|
+
|
|
|
EXECUTOR = ThreadPoolExecutor(max_workers=int(os.environ.get("QUANTIZATION_THREADS", 128)))
|
|
|
|
|
|
|
|
@@ -133,9 +130,12 @@ class BlockwiseQuantization(Quantization):
|
|
|
self, tensor: torch.Tensor, allow_inplace: bool = False
|
|
|
) -> Tuple[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
|
|
|
try:
|
|
|
- quantized, (absmax, codebook) = quantize_blockwise(tensor)
|
|
|
- except NameError:
|
|
|
+ # This runs actual import only on the 1st call, copies references after that
|
|
|
+ from bitsandbytes.functional import quantize_blockwise
|
|
|
+ except ImportError:
|
|
|
raise ImportError(BNB_MISSING_MESSAGE)
|
|
|
+
|
|
|
+ quantized, (absmax, codebook) = quantize_blockwise(tensor)
|
|
|
return quantized.numpy(), (absmax.numpy(), codebook.numpy())
|
|
|
|
|
|
def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
|
|
@@ -163,6 +163,11 @@ class BlockwiseQuantization(Quantization):
|
|
|
)
|
|
|
|
|
|
def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
|
|
|
+ try:
|
|
|
+ from bitsandbytes.functional import dequantize_blockwise
|
|
|
+ except ImportError:
|
|
|
+ raise ImportError(BNB_MISSING_MESSAGE)
|
|
|
+
|
|
|
absmax_size = int(np.frombuffer(serialized_tensor.buffer, count=1, dtype=np.int64))
|
|
|
codebook_size = int(np.frombuffer(serialized_tensor.buffer, offset=8, count=1, dtype=np.int64))
|
|
|
absmax = np.frombuffer(serialized_tensor.buffer, offset=16, count=absmax_size, dtype=self.codebook_dtype)
|
|
@@ -176,9 +181,6 @@ class BlockwiseQuantization(Quantization):
|
|
|
absmax = torch.as_tensor(absmax)
|
|
|
codebook = torch.as_tensor(codebook)
|
|
|
quantized = torch.as_tensor(quantized).reshape(tuple(serialized_tensor.size))
|
|
|
- try:
|
|
|
- result = dequantize_blockwise(quantized, (absmax, codebook)) # Always returns a float32 tensor
|
|
|
- except NameError:
|
|
|
- raise ImportError(BNB_MISSING_MESSAGE)
|
|
|
+ result = dequantize_blockwise(quantized, (absmax, codebook)) # Always returns a float32 tensor
|
|
|
result = result.to(dtype=getattr(torch, serialized_tensor.dtype))
|
|
|
return result
|