|
@@ -62,7 +62,7 @@ def load_pretrained_block(
|
|
|
if is_gptq_quant(config):
|
|
|
print("Now loading GPTQ")
|
|
|
hf_quantizer = AutoHfQuantizer.from_config(config.quantization_config, pre_quantized=True)
|
|
|
- # hf_quantizer.optimum_quantizer.block_name_to_quantize = str(0)
|
|
|
+ hf_quantizer.optimum_quantizer.block_name_to_quantize = str(0)
|
|
|
tmp_block_list = torch.nn.ModuleList([block])
|
|
|
tmp_block_list.__class__.main_input_name = "input_ids"
|
|
|
torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
|