Ver Fonte

Merge branch 'main' into forward_kwargs

justheuristic há 2 anos atrás
pai
commit
22bcbb34fd
3 ficheiros alterados com 71 adições e 19 exclusões
  1. 3 3
      setup.cfg
  2. 63 12
      src/petals/server/from_pretrained.py
  3. 5 4
      src/petals/utils/peft.py

+ 3 - 3
setup.cfg

@@ -33,10 +33,10 @@ python_requires = >=3.8
 install_requires =
     torch>=1.12
     bitsandbytes==0.41.1
-    accelerate>=0.20.3,<0.21.0
+    accelerate>=0.22.0
     huggingface-hub>=0.11.1,<1.0.0
     tokenizers>=0.13.3
-    transformers>=4.31.0,<5.0.0
+    transformers>=4.31.0,<5.0.0  # if you change this, please also change version assert in petals/__init__.py
     speedtest-cli==2.1.3
     pydantic>=1.10,<2.0  # 2.0 is incompatible with hivemind yet
     hivemind==1.1.9
@@ -46,7 +46,7 @@ install_requires =
     cpufeature>=0.2.0
     packaging>=20.9
     sentencepiece>=0.1.99
-    peft>=0.4.0
+    peft>=0.5.0
     safetensors>=0.3.1
     Dijkstar>=2.6.0
 

+ 63 - 12
src/petals/server/from_pretrained.py

@@ -8,14 +8,17 @@ If necessary, one can rewrite this to implement a different behavior, such as:
 """
 import json
 import time
+from contextlib import suppress
 from typing import Dict, Optional, Union
 
+import safetensors
 import torch
 import torch.nn as nn
 from accelerate import init_empty_weights
 from accelerate.utils import set_module_tensor_to_device
 from hivemind.utils.logging import get_logger
 from huggingface_hub import get_hf_file_metadata, hf_hub_url
+from huggingface_hub.utils import EntryNotFoundError
 from transformers import PretrainedConfig
 from transformers.utils import get_file_from_repo
 
@@ -61,7 +64,7 @@ def load_pretrained_block(
     )
 
     # dummy load, check that keys match
-    report = block.load_state_dict(state_dict, strict=True)
+    report = block.load_state_dict(state_dict, strict=False)
     assert not report.missing_keys, f"Some block weights are missing: {report.missing_keys}"
 
     for param_name, _ in block.named_parameters():
@@ -71,7 +74,8 @@ def load_pretrained_block(
             param = param.to(torch_dtype)
         set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype)
 
-    logger.info(f"Loaded {model_name} block {block_index}, {report}")
+    logger.info(f"Loaded {model_name} block {block_index}")
+    logger.debug(f"Details: {report}")
     return block
 
 
@@ -90,11 +94,14 @@ def _load_state_dict_from_repo(
     if always_needs_auth(model_name) and token is None:
         token = True
 
-    index_file = get_file_from_repo(
-        model_name, filename="pytorch_model.bin.index.json", use_auth_token=token, cache_dir=cache_dir
-    )
-    if index_file is not None:  # Sharded model
-        with open(index_file) as f:
+    index_file = _find_index_file(model_name, revision=revision, token=token, cache_dir=cache_dir)
+    if index_file.endswith(".index.json"):  # Sharded model
+        path = get_file_from_repo(model_name, filename=index_file, use_auth_token=token, cache_dir=cache_dir)
+        if path is None:
+            # _find_index_file() told that a file exists but we can't get it (e.g., it just disappeared)
+            raise ValueError(f"Failed to get file {index_file}")
+
+        with open(path) as f:
             index = json.load(f)
         filenames = {
             filename for param_name, filename in index["weight_map"].items() if param_name.startswith(block_prefix)
@@ -102,14 +109,15 @@ def _load_state_dict_from_repo(
         if not filenames:
             raise RuntimeError(f"Block {block_prefix}* not found in the index: {index['weight_map']}")
     else:  # Non-sharded model
-        filenames = {"pytorch_model.bin"}
+        filenames = {index_file}
     logger.debug(f"Loading {block_prefix}* from {filenames}")
 
     state_dict = {}
     for filename in filenames:
-        shard_state_dict = _load_state_dict_from_file(
+        shard_state_dict = _load_state_dict_from_repo_file(
             model_name,
             filename,
+            block_prefix=block_prefix,
             revision=revision,
             token=token,
             cache_dir=cache_dir,
@@ -124,10 +132,42 @@ def _load_state_dict_from_repo(
     return state_dict
 
 
-def _load_state_dict_from_file(
+INDEX_FILES = ["model.safetensors.index.json", "model.safetensors", "pytorch_model.bin.index.json", "pytorch_model.bin"]
+
+
+def _find_index_file(
+    model_name: str, *, revision: Optional[str] = None, token: Optional[Union[str, bool]] = None, cache_dir: str
+) -> str:
+    # If we have cached weights (e.g., Pickle from older Petals versions), reuse them
+    for filename in INDEX_FILES:
+        path = get_file_from_repo(
+            model_name,
+            filename,
+            revision=revision,
+            use_auth_token=token,
+            cache_dir=cache_dir,
+            local_files_only=True,
+        )
+        if path is not None:
+            return filename
+
+    # If we don't, prefer Safetensors when possible
+    # (we don't download files here since we can't account for max_disk_space in case of large files)
+    for filename in INDEX_FILES:
+        with suppress(EntryNotFoundError):
+            get_hf_file_metadata(hf_hub_url(model_name, filename, revision=revision), token=token)
+            return filename
+
+    raise ValueError(
+        f"Repo {model_name} does not contain weights in a supported format: files {INDEX_FILES} do not exist"
+    )
+
+
+def _load_state_dict_from_repo_file(
     model_name: str,
     filename: str,
     *,
+    block_prefix: Optional[str] = None,
     revision: Optional[str] = None,
     token: Optional[Union[str, bool]] = None,
     cache_dir: str,
@@ -146,7 +186,7 @@ def _load_state_dict_from_file(
                 local_files_only=True,
             )
             if path is not None:
-                return torch.load(path, map_location="cpu")
+                return _load_state_dict_from_local_file(path, block_prefix=block_prefix)
     except Exception:
         logger.warning(f"Cache for file {filename} is corrupted, it will be downloaded again", exc_info=True)
 
@@ -171,7 +211,18 @@ def _load_state_dict_from_file(
                 )
                 if path is None:
                     raise RuntimeError(f"File {filename} does not exist in repo {model_name}")
-                return torch.load(path, map_location="cpu")
+                return _load_state_dict_from_local_file(path, block_prefix=block_prefix)
         except Exception as e:
             logger.warning(f"Failed to load file {filename} from HF Hub (retry in {delay:.0f} sec)", exc_info=True)
             time.sleep(delay)
+
+
+def _load_state_dict_from_local_file(path: str, *, block_prefix: Optional[str] = None) -> StateDict:
+    if path.endswith(".bin"):
+        return torch.load(path, map_location="cpu")
+
+    if path.endswith(".safetensors"):
+        with safetensors.safe_open(path, framework="pt", device="cpu") as f:
+            return {key: f.get_tensor(key) for key in f.keys() if block_prefix is None or key.startswith(block_prefix)}
+
+    raise ValueError(f"Unknown weight format: {path}")

+ 5 - 4
src/petals/utils/peft.py

@@ -10,8 +10,9 @@ import transformers
 from accelerate import init_empty_weights
 from hivemind.utils.logging import get_logger
 from huggingface_hub import HfFileSystem, get_hf_file_metadata, hf_hub_url
+from peft.config import PeftConfig
 from peft.tuners import lora
-from peft.utils import COMMON_LAYERS_PATTERN, CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME, PeftConfig
+from peft.utils import COMMON_LAYERS_PATTERN, CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME
 from safetensors import safe_open
 from safetensors.torch import load_file
 from transformers.utils import get_file_from_repo
@@ -155,15 +156,15 @@ class AdapterContextMixin:
 using_adapter = AdapterContextMixin.using_adapter
 
 
-class LoraLinear(lora.Linear, AdapterContextMixin):
+class LoraLinear(AdapterContextMixin, lora.Linear):
     """LoRA linear layer that uses adapter selected via using_adapter"""
 
 
-class LoraLinear8bitLt(lora.Linear8bitLt, AdapterContextMixin):
+class LoraLinear8bitLt(AdapterContextMixin, lora.Linear8bitLt):
     """LoRA linear 8-bit with outliers that uses adapter selected via using_adapter"""
 
 
-class LoraLinear4bit(lora.Linear4bit, AdapterContextMixin):
+class LoraLinear4bit(AdapterContextMixin, lora.Linear4bit):
     """LoRA linear 4-bit that uses adapter selected via using_adapter"""