Переглянути джерело

Use common folder for all caches, make it a volume in Dockerfile (#141)

Alexander Borzunov 2 роки тому
батько
коміт
e99bf36647

+ 4 - 1
Dockerfile

@@ -19,9 +19,12 @@ RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -
 ENV PATH="/opt/conda/bin:${PATH}"
 
 RUN conda install python~=3.10 pip && \
-    pip install --no-cache-dir "torch>=1.12" torchvision torchaudio && \
+    pip install --no-cache-dir "torch>=1.12" && \
     conda clean --all && rm -rf ~/.cache/pip
 
+VOLUME /cache
+ENV PETALS_CACHE=/cache
+
 COPY . petals/
 RUN pip install -e petals[dev]
 

+ 1 - 1
README.md

@@ -40,7 +40,7 @@ Connect your own GPU and increase Petals capacity:
 (conda) $ python -m petals.cli.run_server bigscience/bloom-petals
 
 # Or using a GPU-enabled Docker image
-sudo docker run --net host --ipc host --gpus all --rm learningathome/petals:main \
+sudo docker run --net host --ipc host --gpus all --volume petals-cache:/cache --rm learningathome/petals:main \
     python -m petals.cli.run_server bigscience/bloom-petals
 ```
 

+ 5 - 0
src/petals/bloom/from_pretrained.py

@@ -16,6 +16,7 @@ from transformers.modeling_utils import WEIGHTS_NAME
 from transformers.utils.hub import cached_path, hf_bucket_url
 
 from petals.bloom import BloomBlock, BloomConfig
+from petals.utils.disk_cache import DEFAULT_CACHE_DIR
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -37,8 +38,12 @@ def load_pretrained_block(
     cache_dir: Optional[str] = None,
 ) -> BloomBlock:
     """Load one BloomBlock from a converted model. See convert_model.py (or README.md) on how to convert it."""
+
     if config is None:
         config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
+    if cache_dir is None:
+        cache_dir = DEFAULT_CACHE_DIR
+
     block = BloomBlock(config, layer_number=block_index)
     state_dict = _load_state_dict(
         converted_model_name_or_path, block_index, use_auth_token=use_auth_token, cache_dir=cache_dir

+ 1 - 1
src/petals/server/backend.py

@@ -7,7 +7,7 @@ from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.utils import get_logger
 
 from petals.bloom.from_pretrained import BloomBlock
-from petals.server.cache import MemoryCache
+from petals.server.memory_cache import MemoryCache
 from petals.server.task_pool import PrioritizedTaskPool
 from petals.utils.misc import is_dummy
 

+ 0 - 0
src/petals/server/cache.py → src/petals/server/memory_cache.py


+ 7 - 2
src/petals/server/server.py

@@ -24,8 +24,8 @@ from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
 from petals.dht_utils import declare_active_modules, get_remote_module_infos
 from petals.server import block_selection
 from petals.server.backend import TransformerBackend
-from petals.server.cache import MemoryCache
 from petals.server.handler import TransformerConnectionHandler
+from petals.server.memory_cache import MemoryCache
 from petals.server.throughput import get_host_throughput
 from petals.utils.convert_8bit import replace_8bit_linear
 
@@ -160,7 +160,12 @@ class Server:
         assert isinstance(throughput, float) or throughput in ["auto", "eval"]
         if throughput in ["auto", "eval"]:
             throughput = get_host_throughput(
-                self.block_config, device, torch_dtype, load_in_8bit=load_in_8bit, force_eval=(throughput == "eval")
+                self.block_config,
+                device,
+                torch_dtype,
+                load_in_8bit=load_in_8bit,
+                force_eval=(throughput == "eval"),
+                cache_dir=cache_dir,
             )
         self.throughput = throughput
 

+ 8 - 8
src/petals/server/throughput.py

@@ -2,11 +2,10 @@ import fcntl
 import json
 import os
 import subprocess
-import tempfile
 import time
 from hashlib import sha256
 from pathlib import Path
-from typing import Union
+from typing import Optional, Union
 
 import torch
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
@@ -15,15 +14,12 @@ from petals.bloom.block import BloomBlock
 from petals.bloom.model import BloomConfig
 from petals.bloom.ops import build_alibi_tensor
 from petals.utils.convert_8bit import replace_8bit_linear
+from petals.utils.disk_cache import DEFAULT_CACHE_DIR
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 
-DEFAULT_CACHE_PATH = Path(Path.home(), ".cache", "petals", "throughput_v2.json")
-DEFAULT_LOCK_PATH = Path(tempfile.gettempdir(), "petals", "throughput.lock")
-
-
 def get_host_throughput(
     config: BloomConfig,
     device: torch.device,
@@ -31,8 +27,7 @@ def get_host_throughput(
     *,
     load_in_8bit: bool,
     force_eval: bool = False,
-    cache_path: str = DEFAULT_CACHE_PATH,
-    lock_path: str = DEFAULT_LOCK_PATH,
+    cache_dir: Optional[str] = None,
 ) -> float:
     # Resolve default dtypes
     if dtype == "auto" or dtype is None:
@@ -40,6 +35,11 @@ def get_host_throughput(
         if dtype == "auto" or dtype is None:
             dtype = torch.float32
 
+    if cache_dir is None:
+        cache_dir = DEFAULT_CACHE_DIR
+    lock_path = Path(cache_dir, "throughput.lock")
+    cache_path = Path(cache_dir, "throughput_v2.json")
+
     # We use the system-wide lock since only one process at a time can measure the host throughput
     os.makedirs(lock_path.parent, exist_ok=True)
     with open(lock_path, "wb") as lock_fd:

+ 4 - 0
src/petals/utils/disk_cache.py

@@ -0,0 +1,4 @@
+import os
+from pathlib import Path
+
+DEFAULT_CACHE_DIR = os.getenv("PETALS_CACHE", Path(Path.home(), ".cache", "petals"))