Fix --token arg (#378)

Alexander Borzunov, 2 years ago
commit 3218534745
4 changed files with 13 additions and 9 deletions
  1. src/petals/cli/run_server.py (+5 -1)
  2. src/petals/server/from_pretrained.py (+3 -3)
  3. src/petals/server/server.py (+2 -2)
  4. src/petals/utils/peft.py (+3 -3)

src/petals/cli/run_server.py (+5 -1)

@@ -25,6 +25,11 @@ def main():
                        help="path or name of a pretrained model, converted with cli/convert_model.py")
     group.add_argument('model', nargs='?', type=str, help="same as --converted_model_name_or_path")
 
+    group = parser.add_mutually_exclusive_group(required=False)
+    group.add_argument("--token", type=str, default=None, help="Hugging Face hub auth token for .from_pretrained()")
+    group.add_argument("--use_auth_token", action="store_true", dest="token",
+                       help="Read token saved by `huggingface-cli login`")
+
     parser.add_argument('--num_blocks', type=int, default=None, help="The number of blocks to serve")
     parser.add_argument('--block_indices', type=str, default=None, help="Specific block indices to serve")
     parser.add_argument('--dht_prefix', type=str, default=None, help="Announce all blocks with this DHT prefix")
@@ -132,7 +137,6 @@ def main():
     parser.add_argument("--mean_balance_check_period", type=float, default=60,
                         help="Check the swarm's balance every N seconds (and rebalance it if necessary)")
 
-    parser.add_argument("--token", action='store_true', help="Hugging Face hub auth token for .from_pretrained()")
     parser.add_argument('--quant_type', type=str, default=None, choices=[choice.name.lower() for choice in QuantType],
                         help="Quantize blocks to 8-bit (int8 from the LLM.int8() paper) or "
                              "4-bit (nf4 from the QLoRA paper) formats to save GPU memory. "

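For illustration (not part of the commit): both new flags write to the same dest="token" inside a mutually exclusive group, so the parsed value is either None (neither flag given), an explicit token string (--token), or True (--use_auth_token, meaning "reuse the token saved by `huggingface-cli login`"). A minimal standalone sketch of that argparse behavior, with hypothetical example values:

import argparse

parser = argparse.ArgumentParser()
group = parser.add_mutually_exclusive_group(required=False)
group.add_argument("--token", type=str, default=None,
                   help="Hugging Face hub auth token for .from_pretrained()")
group.add_argument("--use_auth_token", action="store_true", dest="token",
                   help="Read token saved by `huggingface-cli login`")

print(parser.parse_args([]).token)                     # None -> anonymous access
print(parser.parse_args(["--token", "hf_xxx"]).token)  # "hf_xxx" -> explicit token string
print(parser.parse_args(["--use_auth_token"]).token)   # True -> use the saved login token
# Passing both flags at once is rejected by argparse because the group is mutually exclusive.

This is why the signatures in the files below widen token from Optional[str] to Optional[Union[str, bool]].
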
src/petals/server/from_pretrained.py (+3 -3)

@@ -34,7 +34,7 @@ def load_pretrained_block(
     config: Optional[PretrainedConfig] = None,
     torch_dtype: Union[torch.dtype, str] = "auto",
     revision: Optional[str] = None,
-    token: Optional[str] = None,
+    token: Optional[Union[str, bool]] = None,
     cache_dir: Optional[str] = None,
     max_disk_space: Optional[int] = None,
 ) -> nn.Module:
@@ -82,7 +82,7 @@ def _load_state_dict_from_repo(
     block_prefix: str,
     *,
     revision: Optional[str] = None,
-    token: Optional[str] = None,
+    token: Optional[Union[str, bool]] = None,
     cache_dir: str,
     max_disk_space: Optional[int] = None,
 ) -> StateDict:
@@ -125,7 +125,7 @@ def _load_state_dict_from_file(
     filename: str,
     *,
     revision: Optional[str] = None,
-    token: Optional[str] = None,
+    token: Optional[Union[str, bool]] = None,
     cache_dir: str,
     max_disk_space: Optional[int] = None,
     delay: float = 30,
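
The hunks in this file only widen the annotations; the function bodies are not shown in the diff. As a hedged illustration of why Union[str, bool] suffices, huggingface_hub accepts the same set of values for its token parameter (True means "use the token saved by `huggingface-cli login`", a string is used verbatim, None stays anonymous). The repo id and filename below are examples, not taken from the diff:

from huggingface_hub import hf_hub_download

token = True  # or "hf_xxx", or None
# Illustrative download only; the exact hub calls inside _load_state_dict_from_file
# are not visible in this diff.
weights_path = hf_hub_download(
    repo_id="bigscience/bloom-560m",
    filename="pytorch_model.bin",
    revision=None,
    token=token,
)
print(weights_path)
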

src/petals/server/server.py (+2 -2)

@@ -77,7 +77,7 @@ class Server:
         balance_quality: float = 0.75,
         mean_balance_check_period: float = 120,
         mean_block_selection_delay: float = 2.5,
-        token: Optional[str] = None,
+        token: Optional[Union[str, bool]] = None,
         quant_type: Optional[QuantType] = None,
         tensor_parallel_devices: Optional[Sequence[torch.device]] = None,
         skip_reachability_check: bool = False,
@@ -409,7 +409,7 @@ class ModuleContainer(threading.Thread):
         update_period: float,
         expiration: Optional[float],
         revision: Optional[str],
-        token: Optional[str],
+        token: Optional[Union[str, bool]],
         quant_type: QuantType,
         tensor_parallel_devices: Sequence[torch.device],
         should_validate_reachability: bool,

src/petals/utils/peft.py (+3 -3)

@@ -1,7 +1,7 @@
 import contextlib
 import re
 import time
-from typing import Optional, Sequence
+from typing import Optional, Sequence, Union
 
 import bitsandbytes as bnb
 import torch
@@ -50,7 +50,7 @@ def get_adapter_from_repo(
     block_idx: Optional[int] = None,
     device: Optional[int] = None,
     *,
-    token: Optional[str] = None,
+    token: Optional[Union[str, bool]] = None,
     **kwargs,
 ):
     config_path = get_file_from_repo(repo_id, CONFIG_NAME, use_auth_token=token, **kwargs)
@@ -72,7 +72,7 @@ def load_peft(
     device: Optional[int] = None,
     *,
     revision: Optional[str] = None,
-    token: Optional[str] = None,
+    token: Optional[Union[str, bool]] = None,
     cache_dir: str,
     max_disk_space: Optional[int] = None,
     delay: float = 30,
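
For context (not part of the commit): get_adapter_from_repo above forwards token as use_auth_token to transformers' get_file_from_repo, which likewise accepts either a token string or True. A hedged sketch of that downstream call, using a placeholder repo id and "adapter_config.json" as the assumed value of peft's CONFIG_NAME:

from transformers.utils import get_file_from_repo

# "your-org/your-adapter" is a hypothetical repo id used only for illustration.
config_path = get_file_from_repo(
    "your-org/your-adapter",
    "adapter_config.json",
    use_auth_token=True,  # same effect as launching the server with --use_auth_token
)
if config_path is None:
    # get_file_from_repo returns None when the file does not exist in the repo
    raise FileNotFoundError("adapter_config.json not found in the adapter repo")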