|
@@ -57,7 +57,7 @@ def load_pretrained_block(
|
|
|
block = get_model_block(config, layer_idx=block_index)
|
|
|
|
|
|
block_prefix = f"{config.block_prefix}.{block_index}."
|
|
|
- print(config)
|
|
|
+ # print(config)
|
|
|
print(is_gptq_quant(config))
|
|
|
if is_gptq_quant(config):
|
|
|
print("Now loading GPTQ")
|
|
@@ -66,10 +66,12 @@ def load_pretrained_block(
|
|
|
tmp_block_list = torch.nn.ModuleList([block])
|
|
|
tmp_block_list.__class__.main_input_name = "input_ids"
|
|
|
torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
|
|
|
- device_map = hf_quantizer.update_device_map("cuda")
|
|
|
+ device_map = hf_quantizer.update_device_map("cpu")
|
|
|
hf_quantizer.preprocess_model(
|
|
|
model=tmp_block_list, device_map=device_map, keep_in_fp32_modules=False,
|
|
|
)
|
|
|
+ if model_name[-1] != '.':
|
|
|
+ model_name += '.'
|
|
|
|
|
|
state_dict = _load_state_dict_from_repo(
|
|
|
model_name,
|
|
@@ -79,17 +81,21 @@ def load_pretrained_block(
|
|
|
cache_dir=cache_dir,
|
|
|
max_disk_space=max_disk_space,
|
|
|
)
|
|
|
- print(model_name,block_prefix,revision,token,cache_dir,max_disk_space)
|
|
|
+ print(model_name, block_prefix, revision,token,cache_dir,max_disk_space)
|
|
|
|
|
|
- print("now printing", state_dict)
|
|
|
+ # print("now printing", state_dict)
|
|
|
# print("block.named_parameters()",block.named_parameters())
|
|
|
- for param_name, _ in block.named_parameters():
|
|
|
- print(param_name)
|
|
|
- assert param_name in state_dict, f"{param_name} not in state dict"
|
|
|
- param = state_dict[param_name]
|
|
|
- if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
|
|
|
- param = param.to(torch_dtype)
|
|
|
- set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype)
|
|
|
+ if is_gptq_quant(config):
|
|
|
+ print("loading state_dict")
|
|
|
+ block.load_state_dict(state_dict, assign=True, strict=False, device_map='cpu')
|
|
|
+ else:
|
|
|
+ for param_name, _ in block.named_parameters():
|
|
|
+ print(param_name)
|
|
|
+ assert param_name in state_dict, f"{param_name} not in state dict"
|
|
|
+ param = state_dict[param_name]
|
|
|
+ if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
|
|
|
+ param = param.to(torch_dtype)
|
|
|
+ set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype)
|
|
|
|
|
|
logger.info(f"Loaded {model_name} block {block_index}")
|
|
|
return block
|