Prechádzať zdrojové kódy

Import bitsandbytes only if it's going to be used (#180)

Alexander Borzunov 2 rokov pred
rodič
commit
6dd9a938bd
1 zmenil súbory, kde vykonal 6 pridanie a 3 odobranie
  1. 6 3
      src/petals/utils/convert_block.py

+ 6 - 3
src/petals/utils/convert_block.py

@@ -4,7 +4,6 @@ Tools for converting transformer blocks, applying quantization and/or tensor par
 import re
 from typing import Sequence
 
-import bitsandbytes as bnb
 import tensor_parallel as tp
 import torch
 import torch.nn as nn
@@ -14,7 +13,6 @@ from transformers import BloomConfig
 from transformers.models.bloom.modeling_bloom import BloomAttention
 
 from petals.bloom.block import WrappedBloomBlock
-from petals.utils.linear8bitlt_patch import CustomLinear8bitLt
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -75,6 +73,12 @@ def replace_8bit_linear(model: nn.Module, threshold=6.0):
             `int8_threshold` for outlier detection as described in the formentioned paper. This parameters is set to
             `6.0` as described by the paper.
     """
+
+    # Import bitsandbytes only when necessary, so Petals runs on platforms not supported by bitsandbytes
+    import bitsandbytes as bnb
+
+    from petals.utils.linear8bitlt_patch import CustomLinear8bitLt
+
     for n, module in model.named_children():
         if len(list(module.children())) > 0:
             replace_8bit_linear(module, threshold)
@@ -98,7 +102,6 @@ def replace_8bit_linear(model: nn.Module, threshold=6.0):
 def make_tensor_parallel(
     block: WrappedBloomBlock, model_config: BloomConfig, devices: Sequence[torch.device], output_device: torch.device
 ):
-    assert isinstance(block, (WrappedBloomBlock, CustomLinear8bitLt))
     tp_config = get_bloom_config(model_config, devices)
     del tp_config.state_rules[re.compile(".*word_embeddings.weight$")]
     tp_block = tp.TensorParallel(block, devices, config=tp_config, output_device=output_device, delay_init=True)