Ver código-fonte

Safer way to download pefts

artek0chumak 2 anos atrás
pai
commit
9ea77b0911
1 arquivo alterado com 29 adições e 43 exclusões
  1. 29 43
      src/petals/utils/peft.py

+ 29 - 43
src/petals/utils/peft.py

@@ -36,7 +36,7 @@ def load_peft(
     
     try:
         with allow_cache_reads(cache_dir):
-            path = get_file_from_repo(
+            config_path = get_file_from_repo(
                 repo_id,
                 CONFIG_NAME,
                 revision=revision,
@@ -44,21 +44,37 @@ def load_peft(
                 cache_dir=cache_dir,
                 local_files_only=True,
             )
-            config = PeftConfig.from_json_file(path)
+            assert config_path is not None
+            config = PeftConfig.from_json_file(config_path)
+
+            weight_path = get_file_from_repo(
+                repo_id,
+                SAFETENSORS_WEIGHTS_NAME,
+                revision=revision,
+                use_auth_token=use_auth_token,
+                cache_dir=cache_dir,
+                local_files_only=True,
+            )
+            assert weight_path is not None
+            return config, load_file(weight_path)
     except Exception:
         logger.warning(f"Cache for peft weights {repo_id} is corrupted, it will be downloaded again", exc_info=True)
 
     while True:
         try:
             with allow_cache_writes(cache_dir):
-                url = hf_hub_url(repo_id, CONFIG_NAME, revision=revision)
-                file_size = get_hf_file_metadata(url, token=use_auth_token).size
+                config_url = hf_hub_url(repo_id, CONFIG_NAME, revision=revision)
+                config_file_size = get_hf_file_metadata(config_url, token=use_auth_token).size
+                wieght_url = hf_hub_url(repo_id, SAFETENSORS_WEIGHTS_NAME, revision=revision)
+                wieght_file_size = get_hf_file_metadata(wieght_url, token=use_auth_token).size
+
+                file_size = config_file_size + wieght_file_size
                 if file_size is not None:
                     free_disk_space_for(repo_id, file_size, cache_dir=cache_dir, max_disk_space=max_disk_space)
                 else:
-                    logger.warning(f"Failed to fetch size of weight from peft repo {repo_id}")
+                    logger.warning(f"Failed to fetch size from peft repo {repo_id}")
 
-                path = get_file_from_repo(
+                config_path = get_file_from_repo(
                     repo_id,
                     CONFIG_NAME,
                     revision=revision,
@@ -66,41 +82,11 @@ def load_peft(
                     cache_dir=cache_dir,
                     local_files_only=False,
                 )
-                if path is None:
+                if config_path is None:
                     raise RuntimeError(f"File {CONFIG_NAME} does not exist in repo {repo_id}")
-                config = PeftConfig.from_json_file(path)
-                break
-        except Exception as e:
-            logger.warning(f"Failed to load file {CONFIG_NAME} from HF Hub (retry in {delay:.0f} sec)", exc_info=True)
-            time.sleep(delay)
-    
-    try:
-        with allow_cache_reads(cache_dir):
-            path = get_file_from_repo(
-                repo_id,
-                SAFETENSORS_WEIGHTS_NAME,
-                revision=revision,
-                use_auth_token=use_auth_token,
-                cache_dir=cache_dir,
-                local_files_only=True,
-            )
-            if path is not None:
-                return config, load_file(path)
-    except Exception:
-        logger.warning(f"Cache for peft weights {repo_id} is corrupted, it will be downloaded again", exc_info=True)
-        
-    # If not found, ensure that we have enough disk space to download them (maybe remove something)
-    while True:
-        try:
-            with allow_cache_writes(cache_dir):
-                url = hf_hub_url(repo_id, SAFETENSORS_WEIGHTS_NAME, revision=revision)
-                file_size = get_hf_file_metadata(url, token=use_auth_token).size
-                if file_size is not None:
-                    free_disk_space_for(repo_id, file_size, cache_dir=cache_dir, max_disk_space=max_disk_space)
-                else:
-                    logger.warning(f"Failed to fetch size of weight from peft repo {repo_id}")
+                config = PeftConfig.from_json_file(config_path)
 
-                path = get_file_from_repo(
+                weight_path = get_file_from_repo(
                     repo_id,
                     SAFETENSORS_WEIGHTS_NAME,
                     revision=revision,
@@ -108,9 +94,9 @@ def load_peft(
                     cache_dir=cache_dir,
                     local_files_only=False,
                 )
-                if path is None:
+                if weight_path is None:
                     raise RuntimeError(f"File {SAFETENSORS_WEIGHTS_NAME} does not exist in repo {repo_id}")
-                return config, load_file(path)
+                return config, load_file(weight_path)
         except Exception as e:
-            logger.warning(f"Failed to load file {SAFETENSORS_WEIGHTS_NAME} from HF Hub (retry in {delay:.0f} sec)", exc_info=True)
-            time.sleep(delay)
+            logger.warning(f"Failed to load file {CONFIG_NAME} from HF Hub (retry in {delay:.0f} sec)", exc_info=True)
+            time.sleep(delay)