vahe1994 1 year ago
parent
commit
2df7a328fd
2 files changed with 2 additions and 4 deletions
  1. src/petals/server/from_pretrained.py (+1 −2)
  2. src/petals/utils/convert_block.py (+1 −2)

src/petals/server/from_pretrained.py (+1 −2)

@@ -63,7 +63,7 @@ def load_pretrained_block(
         print("Now loading GPTQ")
         hf_quantizer = AutoHfQuantizer.from_config(config.quantization_config, pre_quantized=True)
         hf_quantizer.optimum_quantizer.block_name_to_quantize = str(block_index)
-        tmp_block_list = block
+        tmp_block_list = torch.nn.ModuleList([block])
         tmp_block_list.__class__.main_input_name = "input_ids"
         torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
         device_map = hf_quantizer.update_device_map("cuda")
@@ -80,7 +80,6 @@ def load_pretrained_block(
         max_disk_space=max_disk_space,
     )
     print("now printing", block)
-    print("state dict", state_dict)
     for param_name, _ in block.named_parameters():
         assert param_name in state_dict, f"{param_name} not in state dict"
         param = state_dict[param_name]
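The key change in this file wraps the single transformer block in a one-element torch.nn.ModuleList before handing it to the GPTQ quantizer, instead of passing the bare block (the following line then patches main_input_name onto the wrapper so it looks model-like to hf_quantizer); the second hunk simply drops a verbose debug print of the full state dict. A minimal sketch of what the wrapper changes, using a torch.nn.Linear stand-in rather than the actual Petals block:

    import torch

    # A bare module exposes its parameter names directly...
    block = torch.nn.Linear(4, 4)
    print([name for name, _ in block.named_parameters()])
    # ['weight', 'bias']

    # ...while a one-element ModuleList prefixes them with the list index
    # "0", so the quantizer sees an indexable container of blocks rather
    # than a single bare layer.
    wrapped = torch.nn.ModuleList([block])
    print([name for name, _ in wrapped.named_parameters()])
    # ['0.weight', '0.bias']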

src/petals/utils/convert_block.py (+1 −2)

@@ -22,8 +22,7 @@ class QuantType(Enum):
     NF4 = 2  # 4-bit as in the QLoRA paper
 
 def is_gptq_quant(config):
-    return hasattr(config, 'quantization_config') and hasattr(config.quantization_config,
-                                                                  "quant_method") and config.quantization_config.quant_method == "gptq"
+    return hasattr(config, 'quantization_config') and ("quant_method" in config.quantization_config)
 
 def convert_block(
     block: nn.Module,
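The rewritten is_gptq_quant treats config.quantization_config as dict-like and tests only for the presence of a quant_method key; note that, unlike the old attribute-based version, it no longer verifies that the method is specifically "gptq". A minimal sketch of the new behavior, with hypothetical SimpleNamespace stand-ins for the real config objects:

    from types import SimpleNamespace

    def is_gptq_quant(config):
        return hasattr(config, 'quantization_config') and (
            "quant_method" in config.quantization_config
        )

    gptq = SimpleNamespace(quantization_config={"quant_method": "gptq", "bits": 4})
    plain = SimpleNamespace()
    awq = SimpleNamespace(quantization_config={"quant_method": "awq"})

    print(is_gptq_quant(gptq))   # True
    print(is_gptq_quant(plain))  # False
    print(is_gptq_quant(awq))    # True: any declared quant_method passes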