vahe1994 1 anno fa
parent
commit
ad0d51e5f7
1 ha cambiato i file con 1 aggiunte e 1 eliminazioni
  1. 1 1
      src/petals/server/from_pretrained.py

+ 1 - 1
src/petals/server/from_pretrained.py

@@ -62,7 +62,7 @@ def load_pretrained_block(
     if is_gptq_quant(config):
         print("Now loading GPTQ")
         hf_quantizer = AutoHfQuantizer.from_config(config.quantization_config, pre_quantized=True)
-        hf_quantizer.optimum_quantizer.block_name_to_quantize = str(block_index)
+        hf_quantizer.optimum_quantizer.block_name_to_quantize = str(0)
         tmp_block_list = torch.nn.ModuleList([block])
         tmp_block_list.__class__.main_input_name = "input_ids"
         torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)