Răsfoiți Sursa

Allow free_disk_space_for() to remove arbitrary files from Petals cache (#339)

Before this PR, `free_disk_space_for()` was able to remove **(a)** only entire cached revisions (= git commits/branches) and **(b)** only from the repository we're loading right now.

This PR allows this function to remove arbitrary individual files from any repository in the cache.

This is useful for the transition to Petals 1.2.0+, since it now uses the original repos instead of the ones with converted models (see #323). In particular, the cache for `bigscience/bloom-petals` is now deprecated and should be removed in favor of `bigscience/bloom`. This is also useful as a way to free space before loading LoRA adapters (#335).
Alexander Borzunov 2 ani în urmă
părinte
comite
4d9c26fe5c
2 a modificat fișierele cu 15 adăugiri și 20 ștergeri
  1. 1 1
      src/petals/server/from_pretrained.py
  2. 14 19
      src/petals/utils/disk_cache.py

+ 1 - 1
src/petals/server/from_pretrained.py

@@ -153,7 +153,7 @@ def _load_state_dict_from_file(
                 url = hf_hub_url(model_name, filename, revision=revision)
                 file_size = get_hf_file_metadata(url, token=use_auth_token).size
                 if file_size is not None:
-                    free_disk_space_for(model_name, file_size, cache_dir=cache_dir, max_disk_space=max_disk_space)
+                    free_disk_space_for(file_size, cache_dir=cache_dir, max_disk_space=max_disk_space)
                 else:
                     logger.warning(f"Failed to fetch size of file {filename} from repo {model_name}")
 

+ 14 - 19
src/petals/utils/disk_cache.py

@@ -33,15 +33,12 @@ def allow_cache_reads(cache_dir: Optional[str]):
     return _blocks_lock(cache_dir, fcntl.LOCK_SH)
 
 
-def allow_cache_writes(
-    cache_dir: Optional[str], *, reserve: Optional[int] = None, max_disk_space: Optional[int] = None
-):
+def allow_cache_writes(cache_dir: Optional[str]):
     """Allows saving new blocks and removing the old ones (exclusive lock)"""
     return _blocks_lock(cache_dir, fcntl.LOCK_EX)
 
 
 def free_disk_space_for(
-    model_name: str,
     size: int,
     *,
     cache_dir: Optional[str],
@@ -51,35 +48,33 @@ def free_disk_space_for(
     if cache_dir is None:
         cache_dir = DEFAULT_CACHE_DIR
     cache_info = huggingface_hub.scan_cache_dir(cache_dir)
-    model_repos = [repo for repo in cache_info.repos if repo.repo_type == "model" and repo.repo_id == model_name]
 
-    occupied_space = sum(repo.size_on_disk for repo in model_repos)
     available_space = shutil.disk_usage(cache_dir).free - os_quota
     if max_disk_space is not None:
-        available_space = min(available_space, max_disk_space - occupied_space)
+        available_space = min(available_space, max_disk_space - cache_info.size_on_disk)
 
     gib = 1024**3
     logger.debug(f"Disk space: required {size / gib:.1f} GiB, available {available_space / gib:.1f} GiB")
     if size <= available_space:
         return
 
-    revisions = [revision for repo in model_repos for revision in repo.revisions]
-    revisions.sort(key=lambda rev: max([item.blob_last_accessed for item in rev.files], default=rev.last_modified))
+    cached_files = [file for repo in cache_info.repos for revision in repo.revisions for file in revision.files]
 
-    # Remove as few least recently used shards as possible
-    pending_removal = []
+    # Remove as few least recently used files as possible
+    removed_files = []
     freed_space = 0
     extra_space_needed = size - available_space
-    for rev in revisions:
-        pending_removal.append(rev.commit_hash)
-        freed_space += rev.size_on_disk
+    for file in sorted(cached_files, key=lambda file: file.blob_last_accessed):
+        os.remove(file.file_path)  # Remove symlink
+        os.remove(file.blob_path)  # Remove contents
+
+        removed_files.append(file)
+        freed_space += file.size_on_disk
         if freed_space >= extra_space_needed:
             break
-
-    if pending_removal:
-        logger.info(f"Removing {len(pending_removal)} shards to free {freed_space / gib:.1f} GiB of disk space")
-        delete_strategy = cache_info.delete_revisions(*pending_removal)
-        delete_strategy.execute()
+    if removed_files:
+        logger.info(f"Removed {len(removed_files)} files to free {freed_space / gib:.1f} GiB of disk space")
+        logger.debug(f"Removed paths: {[str(file.file_path) for file in removed_files]}")
 
     if freed_space < extra_space_needed:
         raise RuntimeError(