justheuristic пре 3 година
родитељ
комит
d5c410bb1f
3 измењених фајлова са 17 додато и 5 уклоњено
  1. 4 0
      cli/run_server.py
  2. 5 4
      src/bloom/from_pretrained.py
  3. 8 1
      src/server/server.py

+ 4 - 0
cli/run_server.py

@@ -55,6 +55,7 @@ def main():
     parser.add_argument('--custom_module_path', type=str, required=False,
                         help='Path of a file with custom nn.modules, wrapped into special decorator')
     parser.add_argument('--identity_path', type=str, required=False, help='Path to identity file to be used in P2P')
+    parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
 
     # fmt:on
     args = vars(parser.parse_args())
@@ -66,6 +67,9 @@ def main():
     compression_type = args.pop("compression")
     compression = getattr(CompressionType, compression_type)
 
+    use_auth_token = args.pop("use_auth_token")
+    args["use_auth_token"] = True if use_auth_token in ("True", "true", "") else use_auth_token
+
     server = Server.create(**args, start=True, compression=compression)
 
     try:

+ 5 - 4
src/bloom/from_pretrained.py

@@ -34,12 +34,13 @@ def load_pretrained_block(
     block_index: int,
     config: Optional[DistributedBloomConfig] = None,
     torch_dtype: Union[torch.dtype, str] = "auto",
+    use_auth_token: Optional[str] = None,
 ) -> BloomBlock:
     """Load one BloomBlock from a converted model. See convert_model.py (or README.md) on how to convert it."""
     if config is None:
-        config = DistributedBloomConfig.from_pretrained(converted_model_name_or_path)
+        config = DistributedBloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
     block = BloomBlock(config, layer_number=block_index)
-    state_dict = _load_state_dict(converted_model_name_or_path, block_index)
+    state_dict = _load_state_dict(converted_model_name_or_path, block_index, use_auth_token=use_auth_token)
     block.load_state_dict(state_dict)
 
     if torch_dtype == "auto":
@@ -57,7 +58,7 @@ def load_pretrained_block(
 
 
 def _load_state_dict(
-    pretrained_model_name_or_path: str, block_index: Optional[int] = None
+    pretrained_model_name_or_path: str, block_index: Optional[int] = None, use_auth_token: Optional[str] = None
 ) -> OrderedDict[str, torch.Tensor]:
     revision = BLOCK_BRANCH_PREFIX + str(block_index) if block_index is not None else CLIENT_BRANCH
     archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, revision=revision, mirror=None)
@@ -70,7 +71,7 @@ def _load_state_dict(
         proxies=None,
         resume_download=RESUME_DOWNLOAD,
         local_files_only=LOCAL_FILES_ONLY,
-        use_auth_token=True,
+        use_auth_token=use_auth_token,
         user_agent=USER_AGENT,
     )
     state_dict = torch.load(resolved_archive_file, map_location="cpu")

+ 8 - 1
src/server/server.py

@@ -100,6 +100,7 @@ class Server(threading.Thread):
         custom_module_path=None,
         update_period: float = 30,
         expiration: Optional[float] = None,
+        use_auth_token: Optional[str] = None,
         *,
         start: bool,
         **kwargs,
@@ -137,7 +138,13 @@ class Server(threading.Thread):
         blocks = {}
         for block_index in block_indices:
             module_uid = f"{prefix}.{block_index}"
-            block = load_pretrained_block(converted_model_name_or_path, block_index, block_config, torch_dtype)
+            block = load_pretrained_block(
+                converted_model_name_or_path,
+                block_index,
+                block_config,
+                torch_dtype=torch_dtype,
+                use_auth_token=use_auth_token,
+            )
             for param in block.parameters():
                 param.requires_grad = False