
Import bitsandbytes only if it is used (#546)

When bitsandbytes is installed, hivemind always imports it at module load time, even when it is not used or cannot be loaded. This leads to a number of undesirable behaviors:

- `import hivemind` fails if bitsandbytes is installed but fails to resolve CUDA dynamic libraries or load its own dynamic libraries. This happens all the time:
    - on CPU-only hosts with CUDA installed (e.g., Colabs)
    - on hosts with misconfigured CUDA (e.g., most nora envs)
    - on macOS (even when the daemon is compiled correctly and everything should otherwise work), etc.
- Users see the irrelevant bitsandbytes welcome message even if they don't use it. This message is currently impossible to suppress. Multiple users have reported that this is unexpected when they import hivemind/petals.

This PR fixes that, so hivemind's behavior now matches, e.g., HF transformers (bitsandbytes is imported only when it is actually needed).
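
For context, a minimal sketch of the deferred-import pattern this PR switches to (the method shape and the BNB_MISSING_MESSAGE constant mirror the diff below; the message text and the standalone function are simplified for illustration):

import torch

BNB_MISSING_MESSAGE = "bitsandbytes is required for blockwise quantization"  # placeholder text

def quantize_blockwise_lazily(tensor: torch.Tensor):
    try:
        # Python caches modules in sys.modules, so the real import cost is paid only on the first call
        from bitsandbytes.functional import quantize_blockwise
    except ImportError:
        raise ImportError(BNB_MISSING_MESSAGE)
    quantized, (absmax, codebook) = quantize_blockwise(tensor)
    return quantized.numpy(), (absmax.numpy(), codebook.numpy())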
Alexander Borzunov · 2 years ago
Parent commit: f5ca10ab23
1 file changed with 13 additions and 11 deletions
      hivemind/compression/quantization.py

+ 13 - 11
hivemind/compression/quantization.py

@@ -1,4 +1,3 @@
-import importlib.util
 import math
 import os
 import warnings
@@ -9,13 +8,11 @@ from typing import Tuple
 import numpy as np
 import torch
 
-if importlib.util.find_spec("bitsandbytes") is not None:
-    warnings.filterwarnings("ignore", module="bitsandbytes", category=UserWarning)
-    from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise
-
 from hivemind.compression.base import CompressionBase, CompressionInfo
 from hivemind.proto import runtime_pb2
 
+warnings.filterwarnings("ignore", module="bitsandbytes", category=UserWarning)
+
 EXECUTOR = ThreadPoolExecutor(max_workers=int(os.environ.get("QUANTIZATION_THREADS", 128)))
 
 
@@ -133,9 +130,12 @@ class BlockwiseQuantization(Quantization):
         self, tensor: torch.Tensor, allow_inplace: bool = False
     ) -> Tuple[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
         try:
-            quantized, (absmax, codebook) = quantize_blockwise(tensor)
-        except NameError:
+            # This runs actual import only on the 1st call, copies references after that
+            from bitsandbytes.functional import quantize_blockwise
+        except ImportError:
             raise ImportError(BNB_MISSING_MESSAGE)
+
+        quantized, (absmax, codebook) = quantize_blockwise(tensor)
         return quantized.numpy(), (absmax.numpy(), codebook.numpy())
 
     def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
@@ -163,6 +163,11 @@ class BlockwiseQuantization(Quantization):
         )
 
     def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
+        try:
+            from bitsandbytes.functional import dequantize_blockwise
+        except ImportError:
+            raise ImportError(BNB_MISSING_MESSAGE)
+
         absmax_size = int(np.frombuffer(serialized_tensor.buffer, count=1, dtype=np.int64))
         codebook_size = int(np.frombuffer(serialized_tensor.buffer, offset=8, count=1, dtype=np.int64))
         absmax = np.frombuffer(serialized_tensor.buffer, offset=16, count=absmax_size, dtype=self.codebook_dtype)
@@ -176,9 +181,6 @@ class BlockwiseQuantization(Quantization):
         absmax = torch.as_tensor(absmax)
         codebook = torch.as_tensor(codebook)
         quantized = torch.as_tensor(quantized).reshape(tuple(serialized_tensor.size))
-        try:
-            result = dequantize_blockwise(quantized, (absmax, codebook))  # Always returns a float32 tensor
-        except NameError:
-            raise ImportError(BNB_MISSING_MESSAGE)
+        result = dequantize_blockwise(quantized, (absmax, codebook))  # Always returns a float32 tensor
         result = result.to(dtype=getattr(torch, serialized_tensor.dtype))
         return result
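
A quick way to sanity-check the new behavior in a fresh interpreter (assuming a hivemind build with this change and bitsandbytes are both installed, and nothing else has imported bitsandbytes yet):

import sys
import importlib

# Importing the quantization module no longer triggers the top-level bitsandbytes import
importlib.import_module("hivemind.compression.quantization")
assert "bitsandbytes" not in sys.modules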