|
@@ -63,7 +63,7 @@ def load_pretrained_block(
|
|
|
print("Now loading GPTQ")
|
|
|
hf_quantizer = AutoHfQuantizer.from_config(config.quantization_config, pre_quantized=True)
|
|
|
hf_quantizer.optimum_quantizer.block_name_to_quantize = str(block_index)
|
|
|
- tmp_block_list = block
|
|
|
+ tmp_block_list = torch.nn.ModuleList([block])
|
|
|
tmp_block_list.__class__.main_input_name = "input_ids"
|
|
|
torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
|
|
|
device_map = hf_quantizer.update_device_map("cuda")
|
|
@@ -80,7 +80,6 @@ def load_pretrained_block(
|
|
|
max_disk_space=max_disk_space,
|
|
|
)
|
|
|
print("now printing", block)
|
|
|
- print("state dict", state_dict)
|
|
|
for param_name, _ in block.named_parameters():
|
|
|
assert param_name in state_dict, f"{param_name} not in state dict"
|
|
|
param = state_dict[param_name]
|