vahe1994 committed 1 year ago
Commit 19e851567d
2 files changed, 17 insertions(+), 16 deletions(-)
  1. src/petals/server/server.py (+14 -14)
  2. src/petals/utils/convert_block.py (+3 -2)

src/petals/server/server.py (+14 -14)

@@ -497,20 +497,20 @@ class ModuleContainer(threading.Thread):
                     cache_dir=cache_dir,
                     max_disk_space=max_disk_space,
                 )
-                if not is_gptq_quant(block_config):
-                    block = convert_block(
-                        block,
-                        block_index,
-                        block_config,
-                        tensor_parallel_devices,
-                        device,
-                        quant_type,
-                        adapters=server_info.adapters,
-                        freeze=True,
-                        token=token,
-                        cache_dir=cache_dir,
-                        max_disk_space=max_disk_space,
-                    )
+                # if not is_gptq_quant(block_config):
+                block = convert_block(
+                    block,
+                    block_index,
+                    block_config,
+                    tensor_parallel_devices,
+                    device,
+                    quant_type,
+                    adapters=server_info.adapters,
+                    freeze=True,
+                    token=token,
+                    cache_dir=cache_dir,
+                    max_disk_space=max_disk_space,
+                )
                 blocks[module_uid] = TransformerBackend(
                     module_uid,
                     block,
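
Note: the server-side change removes the is_gptq_quant guard around convert_block, so every block is now converted unconditionally and the GPTQ special-casing moves into convert_block itself (second file below). For reference, a detection helper like is_gptq_quant could look roughly like the sketch below; the attribute names are assumptions for illustration, not the actual Petals implementation.

    # Hypothetical sketch (not the Petals code) of a GPTQ detection helper
    # like is_gptq_quant; assumes the block config exposes an HF-style
    # quantization_config carrying a "quant_method" field.
    def is_gptq_quant(config) -> bool:
        quant_config = getattr(config, "quantization_config", None)
        if quant_config is None:
            return False
        if isinstance(quant_config, dict):
            return quant_config.get("quant_method") == "gptq"
        return getattr(quant_config, "quant_method", None) == "gptq"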

src/petals/utils/convert_block.py (+3 -2)

@@ -54,13 +54,14 @@ def convert_block(
 
     block = make_tensor_parallel(block, config, tensor_parallel_devices, output_device=output_device)
 
-    if quant_type != QuantType.NONE:
+    if quant_type != QuantType.NONE and not is_gptq_quant(config):
+        print("I'm still quantizing ")
         block = quantize_module(block, quant_type=quant_type)
 
     for shard, device in zip(block.module_shards, block.devices):
         shard.to(device)
 
-    if adapters:
+    if adapters and not is_gptq_quant(config):
         from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft
 
         create_lora_adapter(block)
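
With both guards in place, convert_block now skips re-quantization and LoRA adapter attachment for GPTQ-quantized checkpoints while keeping the previous behavior for everything else. A minimal sketch of the resulting predicates, under the same assumed is_gptq_quant helper as above (QuantType here is a stand-in enum, not the Petals definition):

    # Minimal sketch of the guard logic after this commit: GPTQ-quantized
    # blocks skip both re-quantization and LoRA adapter attachment.
    from enum import Enum

    class QuantType(Enum):
        NONE = "none"
        INT8 = "int8"
        NF4 = "nf4"

    def should_quantize(quant_type: QuantType, config) -> bool:
        # Re-quantize only when a quant type is requested and the
        # checkpoint is not already GPTQ-quantized.
        return quant_type != QuantType.NONE and not is_gptq_quant(config)

    def should_attach_adapters(adapters, config) -> bool:
        # After this change, LoRA adapters are only attached to non-GPTQ blocks.
        return bool(adapters) and not is_gptq_quant(config)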