Pārlūkot izejas kodu

add revision support

dbaranchuk 3 gadi atpakaļ
vecāks
revīzija
5200dc7029
2 mainītis faili ar 5 papildinājumiem un 1 dzēšanām
  1. 3 0
      cli/run_server.py
  2. 2 1
      src/server/server.py

+ 3 - 0
cli/run_server.py

@@ -42,6 +42,9 @@ def main():
     parser.add_argument("--torch_dtype", type=str, default="auto",
                         help="Use this dtype to store block weights and do computations. "
                              "By default, respect the dtypes in the pre-trained state dict.")
+    parser.add_argument('--revision', type=str, default='main',
+                        help="The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models"
+                             "and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.")
 
     parser.add_argument('--throughput',
                         type=lambda value: value if value in ['auto', 'eval'] else float(value),

+ 2 - 1
src/server/server.py

@@ -105,6 +105,7 @@ class Server(threading.Thread):
         min_batch_size: int = 1,
         max_batch_size: int = 4096,
         torch_dtype: str = "auto",
+        revision: str = "main",
         cache_dir: Optional[str] = None,
         cache_size_bytes: Optional[int] = None,
         device: Optional[Union[str, torch.device]] = None,
@@ -151,7 +152,7 @@ class Server(threading.Thread):
         assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
 
         block_config = BloomConfig.from_pretrained(
-            converted_model_name_or_path, use_auth_token=use_auth_token, revision="client"
+            converted_model_name_or_path, use_auth_token=use_auth_token, revision=revision
         )
 
         if block_indices is not None: