
Import bitsandbytes only if it is used (#546)

When bitsandbytes is installed, hivemind always tries to import it, even when it is not needed or does not work. This leads to a number of undesirable behaviors:

- `import hivemind` fails if bitsandbytes is installed but cannot resolve CUDA dynamic libraries or load its own native libraries. This happens all the time:
    - on CPU-only hosts with CUDA installed (e.g., Colabs)
    - on hosts with misconfigured CUDA (e.g., most nora envs)
    - on macOS (even when the daemon is compiled correctly and hivemind should work), etc.
- Users see the irrelevant bitsandbytes welcome message even if they never use it, and this message is currently impossible to suppress. Multiple users have reported that this is unexpected when they import hivemind/petals.

This PR fixes this: hivemind's behavior now matches that of, e.g., HF transformers (bitsandbytes is imported only when it is actually needed), as sketched below.
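To make the mechanism concrete, here is a minimal standalone sketch of the lazy-import pattern this PR adopts (simplified, not the PR's code verbatim; the free function `quantize` is illustrative, while `BNB_MISSING_MESSAGE` and `quantize_blockwise` match names in the diff below):

```python
BNB_MISSING_MESSAGE = "Please install bitsandbytes to use blockwise quantization"  # placeholder text

def quantize(tensor):
    try:
        # The import statement runs the real import only once; afterwards Python
        # finds the module in sys.modules, so repeated calls are cheap.
        from bitsandbytes.functional import quantize_blockwise
    except ImportError:
        raise ImportError(BNB_MISSING_MESSAGE)
    return quantize_blockwise(tensor)
```

Because nothing imports bitsandbytes at module level anymore, `import hivemind` can no longer fail on a broken bitsandbytes install, and the welcome message appears only when quantization is actually used.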
Alexander Borzunov 2 years ago
commit f5ca10ab23
1 changed file with 13 additions and 11 deletions

hivemind/compression/quantization.py  +13 −11

@@ -1,4 +1,3 @@
-import importlib.util
 import math
 import os
 import warnings
@@ -9,13 +8,11 @@ from typing import Tuple
 import numpy as np
 import torch
 
-if importlib.util.find_spec("bitsandbytes") is not None:
-    warnings.filterwarnings("ignore", module="bitsandbytes", category=UserWarning)
-    from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise
-
 from hivemind.compression.base import CompressionBase, CompressionInfo
 from hivemind.proto import runtime_pb2
 
+warnings.filterwarnings("ignore", module="bitsandbytes", category=UserWarning)
+
 EXECUTOR = ThreadPoolExecutor(max_workers=int(os.environ.get("QUANTIZATION_THREADS", 128)))
 
 
@@ -133,9 +130,12 @@ class BlockwiseQuantization(Quantization):
         self, tensor: torch.Tensor, allow_inplace: bool = False
     ) -> Tuple[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
         try:
-            quantized, (absmax, codebook) = quantize_blockwise(tensor)
-        except NameError:
+            # The actual import runs only on the first call; later calls reuse the module cached in sys.modules
+            from bitsandbytes.functional import quantize_blockwise
+        except ImportError:
             raise ImportError(BNB_MISSING_MESSAGE)
+
+        quantized, (absmax, codebook) = quantize_blockwise(tensor)
         return quantized.numpy(), (absmax.numpy(), codebook.numpy())
 
     def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
@@ -163,6 +163,11 @@ class BlockwiseQuantization(Quantization):
         )
 
     def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
+        try:
+            from bitsandbytes.functional import dequantize_blockwise
+        except ImportError:
+            raise ImportError(BNB_MISSING_MESSAGE)
+
         absmax_size = int(np.frombuffer(serialized_tensor.buffer, count=1, dtype=np.int64))
         codebook_size = int(np.frombuffer(serialized_tensor.buffer, offset=8, count=1, dtype=np.int64))
         absmax = np.frombuffer(serialized_tensor.buffer, offset=16, count=absmax_size, dtype=self.codebook_dtype)
@@ -176,9 +181,6 @@ class BlockwiseQuantization(Quantization):
         absmax = torch.as_tensor(absmax)
         codebook = torch.as_tensor(codebook)
         quantized = torch.as_tensor(quantized).reshape(tuple(serialized_tensor.size))
-        try:
-            result = dequantize_blockwise(quantized, (absmax, codebook))  # Always returns a float32 tensor
-        except NameError:
-            raise ImportError(BNB_MISSING_MESSAGE)
+        result = dequantize_blockwise(quantized, (absmax, codebook))  # Always returns a float32 tensor
         result = result.to(dtype=getattr(torch, serialized_tensor.dtype))
         return result
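For context, here is a hedged sketch of the resulting behavior (not part of the commit; assumes a standard hivemind install after this change):

```python
import sys

from hivemind.compression.quantization import BlockwiseQuantization

# Importing the module no longer pulls in bitsandbytes (unless something
# else in the process already imported it).
assert "bitsandbytes" not in sys.modules

quantizer = BlockwiseQuantization()
# The first quantize()/extract() call performs the deferred
# `from bitsandbytes.functional import ...` and either succeeds (caching the
# module in sys.modules) or raises ImportError(BNB_MISSING_MESSAGE).
```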