vahe1994 hace 1 año
padre
commit
d0ba4bce11
Se han modificado 1 ficheros con 2 adiciones y 1 borrados
  1. 2 1
      src/petals/server/from_pretrained.py

+ 2 - 1
src/petals/server/from_pretrained.py

@@ -82,8 +82,9 @@ def load_pretrained_block(
     print(model_name,block_prefix,revision,token,cache_dir,max_disk_space)
     print(model_name,block_prefix,revision,token,cache_dir,max_disk_space)
 
 
     print("now printing", state_dict)
     print("now printing", state_dict)
-    print("block.named_parameters()",block.named_parameters())
+    # print("block.named_parameters()",block.named_parameters())
     for param_name, _ in block.named_parameters():
     for param_name, _ in block.named_parameters():
+        print(param_name)
         assert param_name in state_dict, f"{param_name} not in state dict"
         assert param_name in state_dict, f"{param_name} not in state dict"
         param = state_dict[param_name]
         param = state_dict[param_name]
         if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
         if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):