
Import bitsandbytes only if it is used (#546)

When bitsandbytes is installed, hivemind always tries to import it, even when it doesn't work in the current environment. This leads to a number of undesirable behaviors:

- `import hivemind` fails if bitsandbytes is installed but fails to resolve CUDA dynamic libraries or load its own dynamic libraries. This happens all the time:
    - on CPU-only hosts with CUDA installed (e.g., Colabs)
    - on hosts with misconfigured CUDA (e.g., most nora envs)
    - on macOS (even when the daemon is compiled correctly and it should work), etc.
- Users see the irrelevant bitsandbytes welcome message even if they don't use it, and this message is currently impossible to suppress. Multiple users have reported that this is unexpected when they import hivemind/petals.

This PR fixes it, so hivemind's behavior now matches that of, e.g., HF transformers (bitsandbytes is imported only when it's actually needed).
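
For reference, the deferred-import pattern the PR adopts looks roughly like this. A minimal sketch with placeholder names (`heavy_dep`, `run_feature`, `do_work` are hypothetical), not hivemind's exact code:

```python
def run_feature(x):
    try:
        # The import happens only when the feature is actually used, so a broken
        # or missing optional dependency cannot break importing the package itself
        from heavy_dep import do_work  # hypothetical optional dependency
    except ImportError:
        raise ImportError("Please install heavy_dep to use this feature")
    return do_work(x)
```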
Alexander Borzunov committed 2 years ago · commit f5ca10ab23

1 changed file with 13 additions and 11 deletions

hivemind/compression/quantization.py (+13 −11)

```diff
@@ -1,4 +1,3 @@
-import importlib.util
 import math
 import os
 import warnings
@@ -9,13 +8,11 @@ from typing import Tuple
 import numpy as np
 import torch
 
-if importlib.util.find_spec("bitsandbytes") is not None:
-    warnings.filterwarnings("ignore", module="bitsandbytes", category=UserWarning)
-    from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise
-
 from hivemind.compression.base import CompressionBase, CompressionInfo
 from hivemind.proto import runtime_pb2
 
+warnings.filterwarnings("ignore", module="bitsandbytes", category=UserWarning)
+
 EXECUTOR = ThreadPoolExecutor(max_workers=int(os.environ.get("QUANTIZATION_THREADS", 128)))
 
 
```
```diff
@@ -133,9 +130,12 @@ class BlockwiseQuantization(Quantization):
         self, tensor: torch.Tensor, allow_inplace: bool = False
     ) -> Tuple[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
         try:
-            quantized, (absmax, codebook) = quantize_blockwise(tensor)
-        except NameError:
+            # This runs the actual import only on the 1st call; later calls just copy references
+            from bitsandbytes.functional import quantize_blockwise
+        except ImportError:
             raise ImportError(BNB_MISSING_MESSAGE)
+
+        quantized, (absmax, codebook) = quantize_blockwise(tensor)
         return quantized.numpy(), (absmax.numpy(), codebook.numpy())
 
     def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
```
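
The comment in the hunk above leans on Python's module cache: the first `from bitsandbytes.functional import ...` runs the real import, while every later call just looks the module up in `sys.modules` and copies a reference. A stdlib-only sketch (independent of hivemind) showing why the per-call cost is negligible:

```python
import sys
import timeit

def fast_sqrt(x):
    # First call triggers the real import; afterwards Python fetches the
    # cached module from sys.modules and binds a local reference to sqrt
    from math import sqrt
    return sqrt(x)

fast_sqrt(2.0)
assert "math" in sys.modules  # the module object stays cached after the first import
print(timeit.timeit(lambda: fast_sqrt(2.0), number=100_000))  # roughly dict-lookup overhead per call
```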
```diff
@@ -163,6 +163,11 @@ class BlockwiseQuantization(Quantization):
         )
 
     def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
+        try:
+            from bitsandbytes.functional import dequantize_blockwise
+        except ImportError:
+            raise ImportError(BNB_MISSING_MESSAGE)
+
         absmax_size = int(np.frombuffer(serialized_tensor.buffer, count=1, dtype=np.int64))
         codebook_size = int(np.frombuffer(serialized_tensor.buffer, offset=8, count=1, dtype=np.int64))
         absmax = np.frombuffer(serialized_tensor.buffer, offset=16, count=absmax_size, dtype=self.codebook_dtype)
@@ -176,9 +181,6 @@ class BlockwiseQuantization(Quantization):
         absmax = torch.as_tensor(absmax)
         codebook = torch.as_tensor(codebook)
         quantized = torch.as_tensor(quantized).reshape(tuple(serialized_tensor.size))
-        try:
-            result = dequantize_blockwise(quantized, (absmax, codebook))  # Always returns a float32 tensor
-        except NameError:
-            raise ImportError(BNB_MISSING_MESSAGE)
+        result = dequantize_blockwise(quantized, (absmax, codebook))  # Always returns a float32 tensor
         result = result.to(dtype=getattr(torch, serialized_tensor.dtype))
         return result
```
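
With this change, a missing or broken bitsandbytes surfaces only when blockwise compression is actually used. A hedged usage sketch of the round trip; `CompressionInfo.from_tensor` is an assumption here based on how hivemind's compression tests build tensor metadata, so check `hivemind/compression/base.py` for the exact signature in your version:

```python
import torch
from hivemind.compression.base import CompressionInfo
from hivemind.compression.quantization import BlockwiseQuantization

tensor = torch.randn(64, 64)
info = CompressionInfo.from_tensor(tensor, key="demo")  # assumed helper, see note above

quantizer = BlockwiseQuantization()
try:
    serialized = quantizer.compress(tensor, info)   # returns a runtime_pb2.Tensor
    restored = quantizer.extract(serialized)
    print("mean abs error:", (restored - tensor).abs().mean().item())
except ImportError as e:
    # After this PR, the error is raised here instead of at `import hivemind`
    print(e)
```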