
Clean up disk space (#152)

Alexander Borzunov · 2 years ago
commit 701ec7e53e

+ 59 - 14
src/petals/bloom/from_pretrained.py

@@ -8,6 +8,8 @@ If necessary, one can rewrite this to implement a different behavior, such as:
 """
 """
 from __future__ import annotations
 from __future__ import annotations
 
 
+import itertools
+import time
 from typing import Optional, OrderedDict, Union
 from typing import Optional, OrderedDict, Union
 
 
 import torch
 import torch
@@ -17,7 +19,8 @@ from transformers.models.bloom.configuration_bloom import BloomConfig
 from transformers.utils import get_file_from_repo

 from petals.bloom.block import WrappedBloomBlock
-from petals.utils.disk_cache import DEFAULT_CACHE_DIR
+from petals.server.block_utils import get_block_size
+from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for

 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -33,6 +36,7 @@ def load_pretrained_block(
     torch_dtype: Union[torch.dtype, str] = "auto",
     torch_dtype: Union[torch.dtype, str] = "auto",
     use_auth_token: Optional[str] = None,
     use_auth_token: Optional[str] = None,
     cache_dir: Optional[str] = None,
     cache_dir: Optional[str] = None,
+    max_disk_space: Optional[int] = None,
 ) -> WrappedBloomBlock:
 ) -> WrappedBloomBlock:
     """Load one BLOOM block from a converted model. See convert_model.py (or README.md) on how to convert it."""
     """Load one BLOOM block from a converted model. See convert_model.py (or README.md) on how to convert it."""
 
 
@@ -43,7 +47,12 @@ def load_pretrained_block(
 
 
     block = WrappedBloomBlock(config)
     block = WrappedBloomBlock(config)
     state_dict = _load_state_dict(
     state_dict = _load_state_dict(
-        converted_model_name_or_path, block_index, use_auth_token=use_auth_token, cache_dir=cache_dir
+        converted_model_name_or_path,
+        block_index,
+        config,
+        use_auth_token=use_auth_token,
+        cache_dir=cache_dir,
+        max_disk_space=max_disk_space,
     )
     )
 
 
     if torch_dtype == "auto":
     if torch_dtype == "auto":
@@ -62,20 +71,56 @@ def load_pretrained_block(

 def _load_state_dict(
     pretrained_model_name_or_path: str,
-    block_index: Optional[int] = None,
+    block_index: int,
+    config: BloomConfig,
+    *,
     use_auth_token: Optional[str] = None,
-    cache_dir: Optional[str] = None,
+    cache_dir: str,
+    max_disk_space: Optional[int] = None,
+    min_backoff: float = 5,
 ) -> OrderedDict[str, torch.Tensor]:
-    revision = BLOCK_BRANCH_PREFIX + str(block_index) if block_index is not None else CLIENT_BRANCH
-    archive_file = get_file_from_repo(
-        pretrained_model_name_or_path,
-        filename=WEIGHTS_NAME,
-        revision=revision,
-        use_auth_token=use_auth_token,
-        cache_dir=cache_dir,
-    )
-    state_dict = torch.load(archive_file, map_location="cpu")
-    return state_dict
+    revision = BLOCK_BRANCH_PREFIX + str(block_index)
+
+    # First, try to find the weights locally
+    try:
+        with allow_cache_reads(cache_dir):
+            archive_file = get_file_from_repo(
+                pretrained_model_name_or_path,
+                filename=WEIGHTS_NAME,
+                revision=revision,
+                use_auth_token=use_auth_token,
+                cache_dir=cache_dir,
+                local_files_only=True,
+            )
+            if archive_file is not None:
+                return torch.load(archive_file, map_location="cpu")
+    except Exception:
+        logger.debug(
+            f"Failed to load block {block_index} from cache. The block will be downloaded again", exc_info=True
+        )
+
+    # If not found, ensure that we have enough disk space to download them (maybe remove something)
+    for attempt_no in itertools.count():
+        try:
+            with allow_cache_writes(cache_dir):
+                block_size = get_block_size(config, "disk")
+                free_disk_space_for(
+                    pretrained_model_name_or_path, block_size, cache_dir=cache_dir, max_disk_space=max_disk_space
+                )
+
+                archive_file = get_file_from_repo(
+                    pretrained_model_name_or_path,
+                    filename=WEIGHTS_NAME,
+                    revision=revision,
+                    use_auth_token=use_auth_token,
+                    cache_dir=cache_dir,
+                    local_files_only=False,
+                )
+                return torch.load(archive_file, map_location="cpu")
+        except Exception as e:
+            delay = min_backoff * (2**attempt_no)
+            logger.warning(f"Failed to load block {block_index} from HF Hub (retry in {delay:.0f} sec)", exc_info=True)
+            time.sleep(delay)


 DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")

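Note on the retry loop above: `_load_state_dict` now retries Hub downloads indefinitely with exponential backoff (`min_backoff * 2**attempt_no`, i.e. 5 s, 10 s, 20 s, ...). Below is a minimal standalone sketch of that pattern, with a hypothetical `download()` callable standing in for the `get_file_from_repo` + `torch.load` steps; it is an illustration, not code from the commit:

import itertools
import time
from typing import Callable, TypeVar

T = TypeVar("T")


def retry_with_backoff(download: Callable[[], T], min_backoff: float = 5) -> T:
    """Call `download()` until it succeeds, sleeping min_backoff * 2**attempt seconds between failures."""
    for attempt_no in itertools.count():
        try:
            return download()
        except Exception as e:
            delay = min_backoff * (2**attempt_no)  # 5, 10, 20, 40, ... seconds
            print(f"Download failed ({e}), retrying in {delay:.0f} s")
            time.sleep(delay)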
+ 19 - 2
src/petals/cli/run_server.py

@@ -47,8 +47,18 @@ def main():
                         help='Use this many threads to pass results/exceptions from Runtime to Pools')
     parser.add_argument('--inference_max_length', type=int, default=2048,
                         help='Maximum total sequence length permitted per inference, defaults to 16384 tokens')
+
     parser.add_argument('--cache_dir', type=str, default=None,
                         help='Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.')
+    parser.add_argument("--max_disk_space", type=str, default=None,
+                        help="Maximal disk space used for caches. Example: 50GB, 100GiB (GB != GiB here). "
+                             "Default: unlimited. "
+                             "For bigscience/bloom-petals, this default means that the server may use up to "
+                             "min(free_disk_space, 350GB) in the worst case, which happens when the server runs "
+                             "for a long time and caches all model blocks after a number of rebalancings. "
+                             "However, this worst case is unlikely, expect the server to consume "
+                             "the disk space equal to 2-4x of your GPU memory on average.")
+
     parser.add_argument('--device', type=str, default=None, required=False,
                         help='all blocks will use this device in torch notation; default: cuda if available else cpu')
     parser.add_argument("--torch_dtype", type=str, default="auto",
@@ -129,7 +139,14 @@ def main():
         attn_cache_size = parse_size(attn_cache_size)
     assert isinstance(
         attn_cache_size, (int, type(None))
-    ), "unrecognized value for attention_cache_bytes, examples: 1.5GB or 1500MB or 1572864000 (bytes)"
+    ), "Unrecognized value for --attn_cache_size. Correct examples: 1.5GB or 1500MB or 1572864000 (bytes)"
+
+    max_disk_space = args.pop("max_disk_space")
+    if max_disk_space is not None:
+        max_disk_space = parse_size(max_disk_space)
+    assert isinstance(
+        max_disk_space, (int, type(None))
+    ), "Unrecognized value for --max_disk_space. Correct examples: 1.5GB or 1500MB or 1572864000 (bytes)"

     if args.pop("new_swarm"):
         args["initial_peers"] = []
@@ -138,7 +155,7 @@ def main():
     if load_in_8bit is not None:
         args["load_in_8bit"] = load_in_8bit.lower() in ["true", "1"]

-    server = Server(**args, compression=compression, attn_cache_size=attn_cache_size)
+    server = Server(**args, compression=compression, max_disk_space=max_disk_space, attn_cache_size=attn_cache_size)
     try:
         server.run()
     except KeyboardInterrupt:

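The `--max_disk_space` help stresses that GB and GiB are different units (10^9 vs 2^30 bytes). The server parses the flag with an existing `parse_size` helper; the standalone sketch below only illustrates how such strings map to bytes and is not the project's actual parser:

import re

_UNITS = {
    "B": 1,
    "KB": 10**3, "MB": 10**6, "GB": 10**9, "TB": 10**12,      # decimal (SI) units
    "KIB": 2**10, "MIB": 2**20, "GIB": 2**30, "TIB": 2**40,   # binary units
}


def parse_size_to_bytes(value: str) -> int:
    """Convert strings like '50GB', '100GiB', or '1572864000' into a number of bytes."""
    match = re.fullmatch(r"\s*([\d.]+)\s*([A-Za-z]*)\s*", value)
    if match is None:
        raise ValueError(f"Unrecognized size: {value!r}")
    number, unit = match.groups()
    return int(float(number) * _UNITS[unit.upper() or "B"])


assert parse_size_to_bytes("50GB") == 50 * 10**9     # not the same as 50 GiB
assert parse_size_to_bytes("100GiB") == 100 * 2**30
assert parse_size_to_bytes("1572864000") == 1572864000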
+ 8 - 0
src/petals/server/handler.py

@@ -54,6 +54,14 @@ class TransformerConnectionHandler(ConnectionHandler):
         self.session_timeout, self.step_timeout = session_timeout, step_timeout
         self._prioritizer = task_prioritizer

+    def shutdown(self):
+        if self.is_alive():
+            self._outer_pipe.send("_shutdown")
+            self.join(self.shutdown_timeout)
+            if self.is_alive():
+                logger.warning(f"{self.__class__.__name__} failed to shut down gracefully, sending SIGTERM")
+                self.terminate()
+
     async def _gather_inputs(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> Tuple[str, List[torch.Tensor], Dict]:

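The new `shutdown()` follows a common pattern for process-based handlers: ask the child process to exit over a control pipe, wait up to a timeout, then fall back to SIGTERM. A generic, self-contained sketch of the same pattern (the pipe attribute, timeout, and worker loop here are illustrative, not the handler's real internals):

import multiprocessing as mp
import time


class GracefulWorker(mp.Process):
    """Toy worker that exits when it receives a "_shutdown" message on its control pipe."""

    def __init__(self, shutdown_timeout: float = 5.0):
        super().__init__(daemon=True)
        self.shutdown_timeout = shutdown_timeout
        self._outer_pipe, self._inner_pipe = mp.Pipe()

    def run(self):
        while True:
            # Check for a shutdown request, then do a slice of (placeholder) work
            if self._inner_pipe.poll(timeout=0.1) and self._inner_pipe.recv() == "_shutdown":
                return
            time.sleep(0.1)

    def shutdown(self):
        if self.is_alive():
            self._outer_pipe.send("_shutdown")   # ask politely first
            self.join(self.shutdown_timeout)
            if self.is_alive():
                print(f"{self.__class__.__name__} did not exit in time, sending SIGTERM")
                self.terminate()                 # escalate to SIGTERM


if __name__ == "__main__":
    worker = GracefulWorker()
    worker.start()
    worker.shutdown()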
+ 13 - 3
src/petals/server/server.py

@@ -29,6 +29,7 @@ from petals.server.handler import TransformerConnectionHandler
 from petals.server.memory_cache import MemoryCache
 from petals.server.throughput import get_host_throughput
 from petals.utils.convert_8bit import replace_8bit_linear
+from petals.utils.disk_cache import DEFAULT_CACHE_DIR

 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -56,6 +57,7 @@ class Server:
         torch_dtype: str = "auto",
         torch_dtype: str = "auto",
         revision: str = "main",
         revision: str = "main",
         cache_dir: Optional[str] = None,
         cache_dir: Optional[str] = None,
+        max_disk_space: Optional[int] = None,
         attn_cache_size: Optional[int] = None,
         attn_cache_size: Optional[int] = None,
         alloc_timeout: float = 60,
         alloc_timeout: float = 60,
         device: Optional[Union[str, torch.device]] = None,
         device: Optional[Union[str, torch.device]] = None,
@@ -82,7 +84,6 @@ class Server:
         self.num_handlers = num_handlers
         self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
         self.inference_max_length = inference_max_length
-        self.cache_dir = cache_dir
         self.compression = compression
         self.stats_report_interval, self.update_period = stats_report_interval, update_period
         self.prefetch_batches, self.sender_threads = prefetch_batches, sender_threads
@@ -117,7 +118,8 @@ class Server:
         self.dht = DHT(initial_peers=initial_peers, start=True, num_workers=self.block_config.n_layer, **kwargs)
         visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()]
         if initial_peers == PUBLIC_INITIAL_PEERS:
-            logger.info("Connecting to the public Petals swarm")
+            logger.info(f"Connecting to the public swarm, peer_id = {self.dht.peer_id}")
+            logger.info("Please check that your server is reachable at http://health.petals.ml")
         else:
             logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")

@@ -158,6 +160,11 @@ class Server:
         logger.info(f"Attention cache for all blocks will consume up to {attn_cache_size / gib:.2f} GiB")
         logger.info(f"Attention cache for all blocks will consume up to {attn_cache_size / gib:.2f} GiB")
         self.memory_cache = MemoryCache(device, attn_cache_size, alloc_timeout)
         self.memory_cache = MemoryCache(device, attn_cache_size, alloc_timeout)
 
 
+        if cache_dir is None:
+            cache_dir = DEFAULT_CACHE_DIR
+        self.cache_dir = cache_dir
+        self.max_disk_space = max_disk_space
+
         assert isinstance(throughput, float) or throughput in ["auto", "eval"]
         assert isinstance(throughput, float) or throughput in ["auto", "eval"]
         if throughput in ["auto", "eval"]:
         if throughput in ["auto", "eval"]:
             throughput = get_host_throughput(
             throughput = get_host_throughput(
@@ -213,6 +220,7 @@ class Server:
                 inference_max_length=self.inference_max_length,
                 torch_dtype=self.torch_dtype,
                 cache_dir=self.cache_dir,
+                max_disk_space=self.max_disk_space,
                 device=self.device,
                 compression=self.compression,
                 stats_report_interval=self.stats_report_interval,
@@ -308,7 +316,8 @@ class ModuleContainer(threading.Thread):
         min_batch_size: int,
         max_batch_size: int,
         torch_dtype: torch.dtype,
-        cache_dir: Optional[str],
+        cache_dir: str,
+        max_disk_space: int,
         device: Union[str, torch.device],
         compression: CompressionType,
         update_period: float,
@@ -340,6 +349,7 @@ class ModuleContainer(threading.Thread):
                     torch_dtype=torch_dtype,
                     use_auth_token=use_auth_token,
                     cache_dir=cache_dir,
+                    max_disk_space=max_disk_space,
                 )

                 if load_in_8bit:

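With these changes, `max_disk_space` travels from the CLI flag through `Server` and `ModuleContainer` down to `load_pretrained_block`. A hedged usage sketch of the updated call (the repo name, block index, and disk budget below are illustrative, and the keyword names follow the signatures shown in this diff):

import torch

from petals.bloom.from_pretrained import load_pretrained_block
from petals.utils.disk_cache import DEFAULT_CACHE_DIR

# Load one converted BLOOM block, allowing the block cache to grow to at most 100 GiB.
block = load_pretrained_block(
    "bigscience/bloom-petals",      # converted model repo (illustrative)
    block_index=3,
    torch_dtype=torch.bfloat16,
    cache_dir=DEFAULT_CACHE_DIR,
    max_disk_space=100 * 1024**3,   # bytes
)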
+ 82 - 0
src/petals/utils/disk_cache.py

@@ -1,4 +1,86 @@
+import fcntl
 import os
+import shutil
+from contextlib import contextmanager
 from pathlib import Path
+from typing import Optional
+
+import huggingface_hub
+from hivemind.utils.logging import get_logger
+
+logger = get_logger(__file__)

 DEFAULT_CACHE_DIR = os.getenv("PETALS_CACHE", Path(Path.home(), ".cache", "petals"))
+
+BLOCKS_LOCK_FILE = "blocks.lock"
+
+
+@contextmanager
+def _blocks_lock(cache_dir: Optional[str], mode: int):
+    if cache_dir is None:
+        cache_dir = DEFAULT_CACHE_DIR
+    lock_path = Path(cache_dir, BLOCKS_LOCK_FILE)
+
+    os.makedirs(lock_path.parent, exist_ok=True)
+    with open(lock_path, "wb") as lock_fd:
+        fcntl.flock(lock_fd.fileno(), mode)
+        # The OS will release the lock when lock_fd is closed or the process is killed
+        yield
+
+
+def allow_cache_reads(cache_dir: Optional[str]):
+    """Allows simultaneous reads, guarantees that blocks won't be removed along the way (shared lock)"""
+    return _blocks_lock(cache_dir, fcntl.LOCK_SH)
+
+
+def allow_cache_writes(
+    cache_dir: Optional[str], *, reserve: Optional[int] = None, max_disk_space: Optional[int] = None
+):
+    """Allows saving new blocks and removing the old ones (exclusive lock)"""
+    return _blocks_lock(cache_dir, fcntl.LOCK_EX)
+
+
+def free_disk_space_for(
+    model_name: str,
+    size: int,
+    *,
+    cache_dir: Optional[str],
+    max_disk_space: Optional[int],
+    os_quota: int = 1024**3,  # Minimal space we should leave to keep OS function normally
+):
+    if cache_dir is None:
+        cache_dir = DEFAULT_CACHE_DIR
+    cache_info = huggingface_hub.scan_cache_dir(cache_dir)
+    model_repos = [repo for repo in cache_info.repos if repo.repo_type == "model" and repo.repo_id == model_name]
+
+    occupied_space = sum(repo.size_on_disk for repo in model_repos)
+    available_space = shutil.disk_usage(cache_dir).free - os_quota
+    if max_disk_space is not None:
+        available_space = min(available_space, max_disk_space - occupied_space)
+    if size <= available_space:
+        return
+
+    revisions = [revision for repo in model_repos for revision in repo.revisions]
+    revisions.sort(key=lambda rev: max([item.blob_last_accessed for item in rev.files], default=rev.last_modified))
+
+    # Remove as few least recently used blocks as possible
+    pending_removal = []
+    freed_space = 0
+    extra_space_needed = size - available_space
+    for rev in revisions:
+        pending_removal.append(rev.commit_hash)
+        freed_space += rev.size_on_disk
+        if freed_space >= extra_space_needed:
+            break
+
+    if pending_removal:
+        gib = 1024**3
+        logger.info(f"Removing {len(pending_removal)} blocks to free {freed_space / gib:.1f} GiB of disk space")
+        delete_strategy = cache_info.delete_revisions(*pending_removal)
+        delete_strategy.execute()
+
+    if freed_space < extra_space_needed:
+        raise RuntimeError(
+            f"Insufficient disk space to load a block. Please free {extra_space_needed - freed_space:.1f} GiB "
+            f"on the volume for {cache_dir} or increase --max_disk_space if you set it manually"
+        )
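
A minimal usage sketch of the new disk-cache helpers (paths, sizes, and the repo name are illustrative): readers hold the shared lock so cached blocks cannot be evicted under them, while a writer takes the exclusive lock and frees least-recently-used revisions before downloading. Note that when `--max_disk_space` is set, the space available for a new block is effectively min(free_disk_space - os_quota, max_disk_space - space_already_used_by_the_model).

from petals.utils.disk_cache import (
    DEFAULT_CACHE_DIR,
    allow_cache_reads,
    allow_cache_writes,
    free_disk_space_for,
)

cache_dir = DEFAULT_CACHE_DIR
block_size = 5 * 1024**3  # assume one block takes ~5 GiB on disk (illustrative)

# Several processes may read cached blocks at the same time (shared lock).
with allow_cache_reads(cache_dir):
    pass  # e.g., torch.load() a cached block here

# Only one process at a time may download or evict blocks (exclusive lock).
with allow_cache_writes(cache_dir):
    free_disk_space_for(
        "bigscience/bloom-petals", block_size, cache_dir=cache_dir, max_disk_space=300 * 1024**3
    )
    # ...download the new block here...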