Bläddra i källkod

Add loading of adapter weights directly onto the target device

artek0chumak 2 år sedan
förälder
incheckning
da204f1285
1 ändrade filer med 9 tillägg och 4 borttagningar
  1. 9 4
      src/petals/utils/peft.py

+ 9 - 4
src/petals/utils/peft.py

@@ -19,10 +19,10 @@ def check_peft_repository(repo_id: str) -> bool:
     return len(list_of_files) > 0
     return len(list_of_files) > 0
 
 
 
 
-def load_specific_module(layers_name: List[str], filepath: str, framework: str = "pt"):
+def load_specific_module(layers_name: List[str], filepath: str, framework: str = "pt", device: Optional[int] = None):
     tensors = dict()
     tensors = dict()
     is_tensors_found = dict()
     is_tensors_found = dict()
-    with safe_open(filepath, framework=framework) as f:
+    with safe_open(filepath, framework=framework, device=device) as f:
         for k in f.keys():
         for k in f.keys():
             for layer_name in layers_name:
             for layer_name in layers_name:
                 if k.startswith(layer_name):
                 if k.startswith(layer_name):
@@ -34,7 +34,9 @@ def load_specific_module(layers_name: List[str], filepath: str, framework: str =
         return tensors
         return tensors
 
 
 
 
-def get_adapter_from_repo(repo_id: str, layers_name: Optional[List[str]] = None, **kwargs):
+def get_adapter_from_repo(
+    repo_id: str, layers_name: Optional[List[str]] = None, device: Optional[int] = None, **kwargs
+):
     config_path = get_file_from_repo(repo_id, CONFIG_NAME, **kwargs)
     config_path = get_file_from_repo(repo_id, CONFIG_NAME, **kwargs)
     if config_path is None:
     if config_path is None:
         raise RuntimeError(f"File {CONFIG_NAME} does not exist in repo {repo_id}")
         raise RuntimeError(f"File {CONFIG_NAME} does not exist in repo {repo_id}")
@@ -45,12 +47,13 @@ def get_adapter_from_repo(repo_id: str, layers_name: Optional[List[str]] = None,
         raise RuntimeError(f"File {SAFETENSORS_WEIGHTS_NAME} does not exist in repo {repo_id}")
         raise RuntimeError(f"File {SAFETENSORS_WEIGHTS_NAME} does not exist in repo {repo_id}")
     if layers_name is None:
     if layers_name is None:
         return config, load_file(weight_path)
         return config, load_file(weight_path)
-    return config, load_specific_module(layers_name, weight_path)
+    return config, load_specific_module(layers_name, weight_path, device=device)
 
 
 
 
 def load_peft(
 def load_peft(
     repo_id: str,
     repo_id: str,
     layers_name: Optional[List[str]] = None,
     layers_name: Optional[List[str]] = None,
+    device: Optional[int] = None,
     *,
     *,
     revision: Optional[str] = None,
     revision: Optional[str] = None,
     use_auth_token: Optional[str] = None,
     use_auth_token: Optional[str] = None,
@@ -68,6 +71,7 @@ def load_peft(
             return get_adapter_from_repo(
             return get_adapter_from_repo(
                 repo_id,
                 repo_id,
                 layers_name,
                 layers_name,
+                device,
                 revision=revision,
                 revision=revision,
                 use_auth_token=use_auth_token,
                 use_auth_token=use_auth_token,
                 cache_dir=cache_dir,
                 cache_dir=cache_dir,
@@ -93,6 +97,7 @@ def load_peft(
                 return get_adapter_from_repo(
                 return get_adapter_from_repo(
                     repo_id,
                     repo_id,
                     layers_name,
                     layers_name,
+                    device,
                     revision=revision,
                     revision=revision,
                     use_auth_token=use_auth_token,
                     use_auth_token=use_auth_token,
                     cache_dir=cache_dir,
                     cache_dir=cache_dir,