
Update transformers to 4.31.0 and peft to 0.4.0 (#371)

Alexander Borzunov, 2 years ago
parent
commit
c735dd7ba3

+ 1 - 1
.github/workflows/run-tests.yaml

@@ -10,7 +10,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: [ '3.7', '3.8', '3.9', '3.10' ]
+        python-version: [ '3.8', '3.9', '3.10' ]
       fail-fast: false
     timeout-minutes: 15
     steps:

+ 1 - 1
README.md

@@ -31,7 +31,7 @@ print(tokenizer.decode(outputs[0]))  # A cat sat on a mat...
 
 ### Connect your GPU and increase Petals capacity
 
-Run these commands in an [Anaconda](https://www.anaconda.com) env (requires Linux and Python 3.7+):
+Run these commands in an [Anaconda](https://www.anaconda.com) env (requires Linux and Python 3.8+):
 
 ```bash
 conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia

+ 4 - 4
setup.cfg

@@ -15,9 +15,9 @@ classifiers =
     Intended Audience :: Science/Research
     License :: OSI Approved :: MIT License
     Programming Language :: Python :: 3
-    Programming Language :: Python :: 3.7
     Programming Language :: Python :: 3.8
     Programming Language :: Python :: 3.9
+    Programming Language :: Python :: 3.10
     Topic :: Scientific/Engineering
     Topic :: Scientific/Engineering :: Mathematics
     Topic :: Scientific/Engineering :: Artificial Intelligence
@@ -29,14 +29,14 @@ classifiers =
 package_dir =
     = src
 packages = find:
-python_requires = >=3.7
+python_requires = >=3.8
 install_requires =
     torch>=1.12
     bitsandbytes==0.40.1.post1
     accelerate>=0.16.0,<0.21.0
     huggingface-hub>=0.11.1,<1.0.0
     tokenizers>=0.13.3
-    transformers>=4.30.1,<4.31.0
+    transformers>=4.31.0,<5.0.0
     speedtest-cli==2.1.3
     pydantic>=1.10,<2.0  # 2.0 is incompatible with hivemind==1.1.8
     hivemind==1.1.8
@@ -46,7 +46,7 @@ install_requires =
     cpufeature>=0.2.0
     packaging>=20.9
     sentencepiece>=0.1.99
-    peft@git+https://github.com/huggingface/peft@5884bdbea49e5e71e2cd06ecfa484bb635063735
+    peft>=0.4.0
     safetensors>=0.3.1
     Dijkstar>=2.6.0
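
These pins move transformers from the 4.30.x window to 4.31.0 and above (capped below 5.0.0) and replace the peft git-commit pin with the released peft 0.4.0. As a quick sanity check (not part of the commit), a sketch like this can confirm that an existing environment already meets the new lower bounds; the package names and minimum versions are copied from the updated setup.cfg:

```python
# Sketch: verify installed versions against the new lower bounds from setup.cfg.
from importlib.metadata import PackageNotFoundError, version
from packaging.version import parse

pins = {"transformers": "4.31.0", "peft": "0.4.0"}  # lower bounds only; upper bounds omitted

for name, minimum in pins.items():
    try:
        installed = version(name)
    except PackageNotFoundError:
        print(f"{name}: not installed (>= {minimum} required)")
        continue
    status = "OK" if parse(installed) >= parse(minimum) else f"needs >= {minimum}"
    print(f"{name} {installed}: {status}")
```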
 

+ 2 - 2
src/petals/__init__.py

@@ -16,8 +16,8 @@ __version__ = "1.2.0.dev3"
 
 if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
     assert (
-        version.parse("4.30.1") <= version.parse(transformers.__version__) < version.parse("5.0.0")
-    ), "Please install a proper transformers version: pip install transformers>=4.30.1,<5.0.0"
+        version.parse("4.31.0") <= version.parse(transformers.__version__) < version.parse("5.0.0")
+    ), "Please install a proper transformers version: pip install transformers>=4.31.0,<5.0.0"
 
 
 def _override_bfloat16_mode_default():
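
The import-time assert now requires transformers 4.31.0 or newer. As the surrounding code shows, the check is guarded by the PETALS_IGNORE_DEPENDENCY_VERSION environment variable; a minimal usage sketch (e.g. for testing against an unreleased transformers build):

```python
# Sketch: skip the version assert above by setting the guard variable
# before petals is imported for the first time.
import os

os.environ["PETALS_IGNORE_DEPENDENCY_VERSION"] = "1"

import petals  # the transformers>=4.31.0,<5.0.0 assert is bypassed
```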

+ 1 - 1
src/petals/cli/run_server.py

@@ -132,7 +132,7 @@ def main():
     parser.add_argument("--mean_balance_check_period", type=float, default=60,
                         help="Check the swarm's balance every N seconds (and rebalance it if necessary)")
 
-    parser.add_argument("--use_auth_token", action='store_true', help="auth token for from_pretrained")
+    parser.add_argument("--token", action='store_true', help="Hugging Face hub auth token for .from_pretrained()")
     parser.add_argument('--quant_type', type=str, default=None, choices=[choice.name.lower() for choice in QuantType],
                         help="Quantize blocks to 8-bit (int8 from the LLM.int8() paper) or "
                              "4-bit (nf4 from the QLoRA paper) formats to save GPU memory. "

+ 4 - 12
src/petals/models/bloom/model.py

@@ -20,9 +20,7 @@ logger = get_logger(__name__)
 class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
     """BloomModel, but all transformer layers are hosted by the swarm"""
 
-    _keys_to_ignore_on_load_missing = (
-        BloomModel._keys_to_ignore_on_load_missing + PTuneMixin._keys_to_ignore_on_load_missing
-    )
+    _keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing
     _keys_to_ignore_on_load_unexpected = [r"^h\."]
 
     config_class = DistributedBloomConfig
@@ -93,11 +91,8 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
 
 
 class DistributedBloomForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, BloomForCausalLM):
-    _keys_to_ignore_on_load_missing = (
-        BloomForCausalLM._keys_to_ignore_on_load_missing
-        + DistributedBloomModel._keys_to_ignore_on_load_missing
-        + [r"^lm_head\."]  # Missing since they are shared with input embeddings
-    )
+    _keys_to_ignore_on_load_missing = DistributedBloomModel._keys_to_ignore_on_load_missing
+    _keys_to_ignore_on_load_missing += [r"^lm_head\."]  # Missing since they are shared with input embeddings
     _keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected
 
     config_class = DistributedBloomConfig
@@ -115,10 +110,7 @@ class DistributedBloomForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, Bl
 
 
 class DistributedBloomForSequenceClassification(FromPretrainedMixin, BloomForSequenceClassification):
-    _keys_to_ignore_on_load_missing = (
-        BloomForSequenceClassification._keys_to_ignore_on_load_missing
-        + DistributedBloomModel._keys_to_ignore_on_load_missing
-    )
+    _keys_to_ignore_on_load_missing = DistributedBloomModel._keys_to_ignore_on_load_missing
     _keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected
 
     config_class = DistributedBloomConfig

+ 4 - 5
src/petals/models/llama/model.py

@@ -21,7 +21,7 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
     """LlamaModel, but all transformer layers are hosted by the swarm"""
 
     _keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing
-    _keys_to_ignore_on_load_unexpected = LlamaModel._keys_to_ignore_on_load_unexpected + [r"^model\.layers\."]
+    _keys_to_ignore_on_load_unexpected = [r"^model\.layers\."]
 
     config_class = DistributedLlamaConfig
 
@@ -115,6 +115,8 @@ class DistributedLlamaForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, Ll
     def __init__(self, config: DistributedLlamaConfig):
         LlamaPreTrainedModel.__init__(self, config)
         self.model = DistributedLlamaModel(config)
+        self.pretraining_tp = config.pretraining_tp
+        self.vocab_size = config.vocab_size
         self.lm_head = LMHead(config)
 
         # Initialize weights and apply final processing
@@ -129,10 +131,7 @@ class DistributedLlamaForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, Ll
 
 
 class DistributedLlamaForSequenceClassification(FromPretrainedMixin, LlamaForSequenceClassification):
-    _keys_to_ignore_on_load_missing = (
-        LlamaForSequenceClassification._keys_to_ignore_on_load_missing
-        + DistributedLlamaModel._keys_to_ignore_on_load_missing
-    )
+    _keys_to_ignore_on_load_missing = DistributedLlamaModel._keys_to_ignore_on_load_missing
     _keys_to_ignore_on_load_unexpected = DistributedLlamaModel._keys_to_ignore_on_load_unexpected
 
     config_class = DistributedLlamaConfig
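
The two added attributes mirror fields that LlamaConfig exposes as of the Llama 2 support in transformers 4.31 and that the upstream causal-LM code reads from the model; this is an assumption based on the release, not stated in the diff itself. A quick check of the config fields (defaults assumed from transformers 4.31):

```python
# Assumes transformers>=4.31.0, where LlamaConfig gained the pretraining_tp field.
from transformers import LlamaConfig

config = LlamaConfig()
print(config.pretraining_tp)  # 1 by default (tensor-parallel degree used during pretraining)
print(config.vocab_size)      # 32000 by default
```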

+ 10 - 10
src/petals/server/from_pretrained.py

@@ -34,12 +34,12 @@ def load_pretrained_block(
     config: Optional[PretrainedConfig] = None,
     torch_dtype: Union[torch.dtype, str] = "auto",
     revision: Optional[str] = None,
-    use_auth_token: Optional[str] = None,
+    token: Optional[str] = None,
     cache_dir: Optional[str] = None,
     max_disk_space: Optional[int] = None,
 ) -> nn.Module:
     if config is None:
-        config = AutoDistributedConfig.from_pretrained(model_name, use_auth_token=use_auth_token)
+        config = AutoDistributedConfig.from_pretrained(model_name, token=token)
     if cache_dir is None:
         cache_dir = DEFAULT_CACHE_DIR
 
@@ -54,7 +54,7 @@ def load_pretrained_block(
         model_name,
         block_prefix,
         revision=revision,
-        use_auth_token=use_auth_token,
+        token=token,
         cache_dir=cache_dir,
         max_disk_space=max_disk_space,
     )
@@ -82,12 +82,12 @@ def _load_state_dict_from_repo(
     block_prefix: str,
     *,
     revision: Optional[str] = None,
-    use_auth_token: Optional[str] = None,
+    token: Optional[str] = None,
     cache_dir: str,
     max_disk_space: Optional[int] = None,
 ) -> StateDict:
     index_file = get_file_from_repo(
-        model_name, filename="pytorch_model.bin.index.json", use_auth_token=use_auth_token, cache_dir=cache_dir
+        model_name, filename="pytorch_model.bin.index.json", use_auth_token=token, cache_dir=cache_dir
     )
     if index_file is not None:  # Sharded model
         with open(index_file) as f:
@@ -107,7 +107,7 @@ def _load_state_dict_from_repo(
             model_name,
             filename,
             revision=revision,
-            use_auth_token=use_auth_token,
+            token=token,
             cache_dir=cache_dir,
             max_disk_space=max_disk_space,
         )
@@ -125,7 +125,7 @@ def _load_state_dict_from_file(
     filename: str,
     *,
     revision: Optional[str] = None,
-    use_auth_token: Optional[str] = None,
+    token: Optional[str] = None,
     cache_dir: str,
     max_disk_space: Optional[int] = None,
     delay: float = 30,
@@ -137,7 +137,7 @@ def _load_state_dict_from_file(
                 model_name,
                 filename,
                 revision=revision,
-                use_auth_token=use_auth_token,
+                use_auth_token=token,
                 cache_dir=cache_dir,
                 local_files_only=True,
             )
@@ -151,7 +151,7 @@ def _load_state_dict_from_file(
         try:
             with allow_cache_writes(cache_dir):
                 url = hf_hub_url(model_name, filename, revision=revision)
-                file_size = get_hf_file_metadata(url, token=use_auth_token).size
+                file_size = get_hf_file_metadata(url, token=token).size
                 if file_size is not None:
                     free_disk_space_for(file_size, cache_dir=cache_dir, max_disk_space=max_disk_space)
                 else:
@@ -161,7 +161,7 @@ def _load_state_dict_from_file(
                     model_name,
                     filename,
                     revision=revision,
-                    use_auth_token=use_auth_token,
+                    use_auth_token=token,
                     cache_dir=cache_dir,
                     local_files_only=False,
                 )
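
Throughout this file the public signatures switch to the new `token` keyword, while calls into the older transformers helper get_file_from_repo keep passing the value as use_auth_token, and calls into huggingface_hub (get_hf_file_metadata) already use token. A minimal sketch of that forwarding pattern with illustrative names (not the actual Petals helpers):

```python
# Sketch of the rename-and-forward pattern: new-style callers pass `token`,
# and legacy helpers still receive it as `use_auth_token`.
from typing import Optional

def legacy_download(filename: str, *, use_auth_token: Optional[str] = None) -> str:
    """Stand-in for an older helper that still expects use_auth_token."""
    return f"downloaded {filename} (auth={'yes' if use_auth_token else 'no'})"

def load_block_file(filename: str, *, token: Optional[str] = None) -> str:
    # New-style signature: accept `token`, translate it for the legacy helper.
    return legacy_download(filename, use_auth_token=token)

print(load_block_file("pytorch_model.bin.index.json", token="hf_example"))
```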

+ 8 - 8
src/petals/server/server.py

@@ -77,7 +77,7 @@ class Server:
         balance_quality: float = 0.75,
         mean_balance_check_period: float = 120,
         mean_block_selection_delay: float = 2.5,
-        use_auth_token: Optional[str] = None,
+        token: Optional[str] = None,
         quant_type: Optional[QuantType] = None,
         tensor_parallel_devices: Optional[Sequence[torch.device]] = None,
         skip_reachability_check: bool = False,
@@ -98,14 +98,14 @@ class Server:
         self.compression = compression
         self.stats_report_interval, self.update_period = stats_report_interval, update_period
         self.prefetch_batches, self.sender_threads = prefetch_batches, sender_threads
-        self.revision, self.use_auth_token = revision, use_auth_token
+        self.revision, self.token = revision, token
 
         if custom_module_path is not None:
             add_custom_models_from_file(custom_module_path)
 
         self.block_config = AutoDistributedConfig.from_pretrained(
             converted_model_name_or_path,
-            use_auth_token=use_auth_token,
+            token=token,
             revision=revision,
         )
 
@@ -271,7 +271,7 @@ class Server:
                 self.block_config,
                 self.torch_dtype,
                 self.adapters,
-                use_auth_token=self.use_auth_token,
+                token=self.token,
                 cache_dir=self.cache_dir,
                 max_disk_space=self.max_disk_space,
             )
@@ -316,7 +316,7 @@ class Server:
                 prefetch_batches=self.prefetch_batches,
                 sender_threads=self.sender_threads,
                 revision=self.revision,
-                use_auth_token=self.use_auth_token,
+                token=self.token,
                 quant_type=self.quant_type,
                 tensor_parallel_devices=self.tensor_parallel_devices,
                 should_validate_reachability=self.should_validate_reachability,
@@ -409,7 +409,7 @@ class ModuleContainer(threading.Thread):
         update_period: float,
         expiration: Optional[float],
         revision: Optional[str],
-        use_auth_token: Optional[str],
+        token: Optional[str],
         quant_type: QuantType,
         tensor_parallel_devices: Sequence[torch.device],
         should_validate_reachability: bool,
@@ -443,7 +443,7 @@ class ModuleContainer(threading.Thread):
                     config=block_config,
                     torch_dtype=torch_dtype,
                     revision=revision,
-                    use_auth_token=use_auth_token,
+                    token=token,
                     cache_dir=cache_dir,
                     max_disk_space=max_disk_space,
                 )
@@ -456,7 +456,7 @@ class ModuleContainer(threading.Thread):
                     quant_type,
                     adapters=server_info.adapters,
                     freeze=True,
-                    use_auth_token=use_auth_token,
+                    token=token,
                     cache_dir=cache_dir,
                     max_disk_space=max_disk_space,
                 )

+ 15 - 8
src/petals/utils/peft.py

@@ -45,13 +45,20 @@ def load_specific_module(block_idx: int, filepath: str, framework: str = "pt", d
         return tensors
 
 
-def get_adapter_from_repo(repo_id: str, block_idx: Optional[int] = None, device: Optional[int] = None, **kwargs):
-    config_path = get_file_from_repo(repo_id, CONFIG_NAME, **kwargs)
+def get_adapter_from_repo(
+    repo_id: str,
+    block_idx: Optional[int] = None,
+    device: Optional[int] = None,
+    *,
+    token: Optional[str] = None,
+    **kwargs,
+):
+    config_path = get_file_from_repo(repo_id, CONFIG_NAME, use_auth_token=token, **kwargs)
     if config_path is None:
         raise RuntimeError(f"File {CONFIG_NAME} does not exist in repo {repo_id}")
     config = PeftConfig.from_json_file(config_path)
 
-    weight_path = get_file_from_repo(repo_id, SAFETENSORS_WEIGHTS_NAME, **kwargs)
+    weight_path = get_file_from_repo(repo_id, SAFETENSORS_WEIGHTS_NAME, use_auth_token=token, **kwargs)
     if weight_path is None:
         raise RuntimeError(f"File {SAFETENSORS_WEIGHTS_NAME} does not exist in repo {repo_id}")
     if block_idx is None:
@@ -65,7 +72,7 @@ def load_peft(
     device: Optional[int] = None,
     *,
     revision: Optional[str] = None,
-    use_auth_token: Optional[str] = None,
+    token: Optional[str] = None,
     cache_dir: str,
     max_disk_space: Optional[int] = None,
     delay: float = 30,
@@ -82,7 +89,7 @@ def load_peft(
                 block_idx,
                 device,
                 revision=revision,
-                use_auth_token=use_auth_token,
+                token=token,
                 cache_dir=cache_dir,
                 local_files_only=False,
             )
@@ -93,9 +100,9 @@ def load_peft(
         try:
             with allow_cache_writes(cache_dir):
                 config_url = hf_hub_url(repo_id, CONFIG_NAME, revision=revision)
-                config_file_size = get_hf_file_metadata(config_url, token=use_auth_token).size
+                config_file_size = get_hf_file_metadata(config_url, token=token).size
                 weight_url = hf_hub_url(repo_id, SAFETENSORS_WEIGHTS_NAME, revision=revision)
-                weight_file_size = get_hf_file_metadata(weight_url, token=use_auth_token).size
+                weight_file_size = get_hf_file_metadata(weight_url, token=token).size
 
                 file_size = config_file_size + weight_file_size
                 if file_size is not None:
@@ -108,7 +115,7 @@ def load_peft(
                     block_idx,
                     device,
                     revision=revision,
-                    use_auth_token=use_auth_token,
+                    token=token,
                     cache_dir=cache_dir,
                     local_files_only=False,
                 )
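
The adapter-loading helpers follow the same rename: get_adapter_from_repo and load_peft now take a keyword-only `token`, forwarding it as use_auth_token to get_file_from_repo and as token to get_hf_file_metadata. A hedged call sketch, assuming the signature shown above (the repo id and cache path are placeholders, and the exact return value is defined elsewhere in peft.py):

```python
# Hypothetical usage of the renamed keyword; values are placeholders.
from petals.utils.peft import load_peft

adapter = load_peft(
    "some-user/some-lora-adapter",  # assumed adapter repo with adapter config + safetensors weights
    block_idx=3,                    # load only the weights belonging to transformer block 3
    token=None,                     # or a Hugging Face access token string for private/gated repos
    cache_dir="/tmp/petals_cache",  # cache_dir is keyword-only and required in the signature above
)
```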