Преглед изворни кода

Fix checking for nonexistent keys

Max Ryabinin пре 1 година
родитељ
комит
9cb4c721e7
1 измењених фајлова са 6 додато и 5 уклоњено
  1. 6 5
      src/petals/server/from_pretrained.py

+ 6 - 5
src/petals/server/from_pretrained.py

@@ -70,11 +70,12 @@ def load_pretrained_block(
     assert not report.missing_keys, f"Some block weights are missing: {report.missing_keys}"
 
     for param_name, _ in block.named_parameters():
-        assert param_name in state_dict, f"{param_name} not in state dict"
-        param = state_dict[param_name]
-        if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
-            param = param.to(torch_dtype)
-        set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype)
+        if param_name != "self_attn.qkv_proj.weight":
+            assert param_name in state_dict, f"{param_name} not in state dict"
+            param = state_dict[param_name]
+            if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
+                param = param.to(torch_dtype)
+            set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype)
 
     logger.info(f"Loaded {model_name} block {block_index}")
     logger.debug(f"Details: {report}")