
init auto_gptq

vahe1994 1 year ago
commit 554779f654
3 changed files with 35 additions and 14 deletions
  1. src/petals/server/from_pretrained.py (+14, -0)
  2. src/petals/server/server.py (+18, -14)
  3. src/petals/utils/convert_block.py (+3, -0)

src/petals/server/from_pretrained.py (+14, -0)

@@ -21,6 +21,7 @@ from huggingface_hub import get_hf_file_metadata, hf_hub_url
 from huggingface_hub.utils import EntryNotFoundError
 from transformers import PretrainedConfig, PreTrainedModel
 from transformers.utils import get_file_from_repo
+from transformers.quantizers import AutoHfQuantizer
 
 from petals.constants import DTYPE_MAP
 from petals.models.mixtral import WrappedMixtralBlock
@@ -28,6 +29,7 @@ from petals.server.block_utils import get_model_block, resolve_block_dtype
 from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
 from petals.utils.hf_auth import always_needs_auth
+from petals.utils.convert_block import is_gptq_quant
 
 logger = get_logger(__name__)
 
@@ -55,6 +57,18 @@ def load_pretrained_block(
         block = get_model_block(config, layer_idx=block_index)
 
     block_prefix = f"{config.block_prefix}.{block_index}."
+
+    if is_gptq_quant(config):
+        hf_quantizer = AutoHfQuantizer.from_config(config.quantization_config, pre_quantized=True)
+        hf_quantizer.optimum_quantizer.block_name_to_quantize = str(block_index)
+        tmp_block_list = torch.nn.ModuleList([block])
+        tmp_block_list.__class__.main_input_name = "input_ids"
+        torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
+        device_map = hf_quantizer.update_device_map("cuda")
+        hf_quantizer.preprocess_model(
+            model=tmp_block_list, device_map=device_map, keep_in_fp32_modules=False,
+        )
+
     state_dict = _load_state_dict_from_repo(
         model_name,
         block_prefix,

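For reference, a commented restatement of the new GPTQ branch as a standalone sketch. The function name and the float16 default are illustrative; the AutoHfQuantizer calls mirror the hunk above.

    import torch
    from transformers.quantizers import AutoHfQuantizer

    def prepare_gptq_block(block, config, block_index, torch_dtype=torch.float16):
        # Build an HF quantizer from the checkpoint's GPTQ config; pre_quantized=True
        # means the weights on disk are already quantized and only need loading.
        hf_quantizer = AutoHfQuantizer.from_config(config.quantization_config, pre_quantized=True)
        # A Petals server loads one transformer block at a time, so the block to
        # quantize is addressed by its index rather than a full module path.
        hf_quantizer.optimum_quantizer.block_name_to_quantize = str(block_index)
        # preprocess_model() expects a model-like module with a main_input_name,
        # so the lone block is wrapped in a throwaway ModuleList.
        tmp_block_list = torch.nn.ModuleList([block])
        tmp_block_list.__class__.main_input_name = "input_ids"
        torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
        device_map = hf_quantizer.update_device_map("cuda")
        # Replaces the block's nn.Linear submodules with GPTQ QuantLinear shells
        # that the quantized state dict is then loaded into.
        hf_quantizer.preprocess_model(
            model=tmp_block_list, device_map=device_map, keep_in_fp32_modules=False
        )
        return block, torch_dtype
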
src/petals/server/server.py (+18, -14)

@@ -33,7 +33,7 @@ from petals.server.memory_cache import MemoryCache
 from petals.server.reachability import ReachabilityProtocol, check_direct_reachability, validate_reachability
 from petals.server.throughput import get_dtype_name, get_server_throughput
 from petals.utils.auto_config import AutoDistributedConfig
-from petals.utils.convert_block import QuantType, check_device_balance, convert_block
+from petals.utils.convert_block import QuantType, check_device_balance, convert_block, is_gptq_quant
 from petals.utils.dht import declare_active_modules, get_remote_module_infos
 from petals.utils.misc import get_size_in_bytes
 from petals.utils.ping import PingAggregator
@@ -428,6 +428,8 @@ class Server:
         self.dht.join()
 
 
+
+
 class ModuleContainer(threading.Thread):
     """Serves a set of specific Bloom layers for inference, forward, and backward. Announces itself over the DHT."""
 
@@ -495,19 +497,20 @@ class ModuleContainer(threading.Thread):
                     cache_dir=cache_dir,
                     max_disk_space=max_disk_space,
                 )
-                block = convert_block(
-                    block,
-                    block_index,
-                    block_config,
-                    tensor_parallel_devices,
-                    device,
-                    quant_type,
-                    adapters=server_info.adapters,
-                    freeze=True,
-                    token=token,
-                    cache_dir=cache_dir,
-                    max_disk_space=max_disk_space,
-                )
+                if not is_gptq_quant(block_config):
+                    block = convert_block(
+                        block,
+                        block_index,
+                        block_config,
+                        tensor_parallel_devices,
+                        device,
+                        quant_type,
+                        adapters=server_info.adapters,
+                        freeze=True,
+                        token=token,
+                        cache_dir=cache_dir,
+                        max_disk_space=max_disk_space,
+                    )
                 blocks[module_uid] = TransformerBackend(
                     module_uid,
                     block,
@@ -554,6 +557,7 @@ class ModuleContainer(threading.Thread):
             **kwargs,
         )
 
+
     def __init__(
         self,
         dht: DHT,

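Why the guard exists, as a small self-contained sketch: convert_block() applies bitsandbytes quantization (QuantType.INT8/NF4), LoRA adapters, and tensor parallelism, which must not touch weights that were already replaced by GPTQ QuantLinear modules at load time. Here SimpleNamespace stands in for a real PretrainedConfig, and maybe_convert() is a stub for the branch above.

    from types import SimpleNamespace
    from petals.utils.convert_block import is_gptq_quant

    gptq_config = SimpleNamespace(quantization_config=SimpleNamespace(quant_method="gptq"))
    plain_config = SimpleNamespace()

    def maybe_convert(block, block_config):
        if not is_gptq_quant(block_config):
            block = f"converted({block})"  # stands in for convert_block(...)
        return block

    print(maybe_convert("block7", gptq_config))   # block7 -- left as loaded
    print(maybe_convert("block7", plain_config))  # converted(block7)
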
src/petals/utils/convert_block.py (+3, -0)

@@ -21,6 +21,9 @@ class QuantType(Enum):
     INT8 = 1  # 8-bit as in the LLM.int8() paper
     NF4 = 2  # 4-bit as in the QLoRA paper
 
+def is_gptq_quant(config):
+    quant_config = getattr(config, "quantization_config", None)
+    return getattr(quant_config, "quant_method", None) == "gptq"
 
 def convert_block(
     block: nn.Module,
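
A hedged illustration of what the helper keys on (the GPTQConfig values are illustrative): transformers parses a GPTQ checkpoint's quantization section into a GPTQConfig object whose quant_method is a string enum that compares equal to "gptq".

    from types import SimpleNamespace
    from transformers import GPTQConfig
    from petals.utils.convert_block import is_gptq_quant

    # A config carrying a GPTQ quantization section matches...
    config = SimpleNamespace(quantization_config=GPTQConfig(bits=4, group_size=128))
    assert is_gptq_quant(config)

    # ...while a config without one (or with a non-GPTQ method) does not.
    assert not is_gptq_quant(SimpleNamespace())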