vahe1994 committed 1 year ago
parent commit d220120b78
1 changed file with 17 additions and 11 deletions

src/petals/server/from_pretrained.py  +17 -11

@@ -57,7 +57,7 @@ def load_pretrained_block(
         block = get_model_block(config, layer_idx=block_index)
 
     block_prefix = f"{config.block_prefix}.{block_index}."
-    print(config)
+    # print(config)
     print(is_gptq_quant(config))
     if is_gptq_quant(config):
         print("Now loading GPTQ")
@@ -66,10 +66,12 @@ def load_pretrained_block(
         tmp_block_list = torch.nn.ModuleList([block])
         tmp_block_list.__class__.main_input_name = "input_ids"
         torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
-        device_map = hf_quantizer.update_device_map("cuda")
+        device_map = hf_quantizer.update_device_map("cpu")
         hf_quantizer.preprocess_model(
             model=tmp_block_list, device_map=device_map, keep_in_fp32_modules=False,
         )
+        if model_name[-1] != '.':
+            model_name += '.'
 
     state_dict = _load_state_dict_from_repo(
         model_name,
@@ -79,17 +81,21 @@ def load_pretrained_block(
         cache_dir=cache_dir,
         max_disk_space=max_disk_space,
     )
-    print(model_name,block_prefix,revision,token,cache_dir,max_disk_space)
+    print(model_name, block_prefix, revision,token,cache_dir,max_disk_space)
 
-    print("now printing", state_dict)
+    # print("now printing", state_dict)
     # print("block.named_parameters()",block.named_parameters())
-    for param_name, _ in block.named_parameters():
-        print(param_name)
-        assert param_name in state_dict, f"{param_name} not in state dict"
-        param = state_dict[param_name]
-        if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
-            param = param.to(torch_dtype)
-        set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype)
+    if is_gptq_quant(config):
+        print("loading state_dict")
+        block.load_state_dict(state_dict, assign=True, strict=False, device_map='cpu')
+    else:
+        for param_name, _ in block.named_parameters():
+            print(param_name)
+            assert param_name in state_dict, f"{param_name} not in state dict"
+            param = state_dict[param_name]
+            if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
+                param = param.to(torch_dtype)
+            set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype)
 
     logger.info(f"Loaded {model_name} block {block_index}")
     return block
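
For reference, a minimal standalone sketch of the load_state_dict(..., strict=False, assign=True) pattern that the new GPTQ branch relies on. It uses a toy module rather than a petals block and assumes torch >= 2.1 (where the assign keyword was introduced); it is an illustration, not the repository's code.

import torch

# Toy stand-in for a transformer block (illustration only, not petals code).
class ToyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(4, 4)

block = ToyBlock()

# A (possibly partial) state dict, e.g. as produced after quantizer preprocessing.
state_dict = {"proj.weight": torch.zeros(4, 4), "proj.bias": torch.zeros(4)}

# assign=True rebinds the tensors from state_dict onto the module instead of
# copying them into the pre-allocated parameters; strict=False tolerates keys
# that are missing or unexpected (common when loading quantized checkpoints).
missing, unexpected = block.load_state_dict(state_dict, strict=False, assign=True)
print("missing:", missing, "unexpected:", unexpected)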