|
@@ -82,8 +82,9 @@ def load_pretrained_block(
|
|
print(model_name,block_prefix,revision,token,cache_dir,max_disk_space)
|
|
print(model_name,block_prefix,revision,token,cache_dir,max_disk_space)
|
|
|
|
|
|
print("now printing", state_dict)
|
|
print("now printing", state_dict)
|
|
- print("block.named_parameters()",block.named_parameters())
|
|
|
|
|
|
+ # print("block.named_parameters()",block.named_parameters())
|
|
for param_name, _ in block.named_parameters():
|
|
for param_name, _ in block.named_parameters():
|
|
|
|
+ print(param_name)
|
|
assert param_name in state_dict, f"{param_name} not in state dict"
|
|
assert param_name in state_dict, f"{param_name} not in state dict"
|
|
param = state_dict[param_name]
|
|
param = state_dict[param_name]
|
|
if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
|
|
if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
|