vahe1994 committed 1 year ago
commit c3014d4d8f
1 changed file with 2 additions and 1 deletion

+ 2 - 1
src/petals/server/from_pretrained.py

@@ -63,7 +63,7 @@ def load_pretrained_block(
         print("Now loading GPTQ")
         hf_quantizer = AutoHfQuantizer.from_config(config.quantization_config, pre_quantized=True)
         hf_quantizer.optimum_quantizer.block_name_to_quantize = str(block_index)
-        tmp_block_list = torch.nn.ModuleList([block])
+        tmp_block_list = block
         tmp_block_list.__class__.main_input_name = "input_ids"
         torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
         device_map = hf_quantizer.update_device_map("cuda")
@@ -80,6 +80,7 @@ def load_pretrained_block(
         max_disk_space=max_disk_space,
     )
     print("now printing", block)
+    print("state dict", state_dict)
     for param_name, _ in block.named_parameters():
         assert param_name in state_dict, f"{param_name} not in state dict"
         param = state_dict[param_name]
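
For context, here is a minimal sketch of the GPTQ setup path this commit touches, assuming a transformers version that exposes AutoHfQuantizer; it is not the petals implementation. The names config, block_index, block, and torch_dtype mirror those visible in load_pretrained_block above, while the helper name prepare_gptq_block and the import path are assumptions.

import torch
from transformers.quantizers import AutoHfQuantizer

def prepare_gptq_block(block: torch.nn.Module, config, block_index: int, torch_dtype: torch.dtype):
    # Build a quantizer for weights that are already GPTQ-quantized on disk.
    hf_quantizer = AutoHfQuantizer.from_config(config.quantization_config, pre_quantized=True)
    # Restrict quantization handling to this single transformer block.
    hf_quantizer.optimum_quantizer.block_name_to_quantize = str(block_index)
    # After this commit the bare block (no ModuleList wrapper) is used, so the
    # model-level attribute is patched onto the block's own class.
    block.__class__.main_input_name = "input_ids"
    # Let the quantizer adjust the dtype and device placement used for loading.
    torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
    device_map = hf_quantizer.update_device_map("cuda")
    return hf_quantizer, torch_dtype, device_map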