Sfoglia il codice sorgente

Create dummy data when materializing qkv_proj

Max Ryabinin 1 anno fa
parent
commit
4159e557bf
1 file modificato con 4 aggiunte e 0 eliminazioni
  1. 4 0
      src/petals/server/from_pretrained.py

+ 4 - 0
src/petals/server/from_pretrained.py

@@ -76,6 +76,10 @@ def load_pretrained_block(
             if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
                 param = param.to(torch_dtype)
             set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype)
+        else:
+            cur_block = getattr(block, param_name)
+            dummy_value = torch.empty_like(cur_block, device="cpu")
+            set_module_tensor_to_device(block, param_name, "cpu", dummy_value)
 
     logger.info(f"Loaded {model_name} block {block_index}")
     logger.debug(f"Details: {report}")