dbaranchuk 3 years ago
parent
commit
e297ae606f
1 changed files with 8 additions and 6 deletions
  1. 8 6
      src/bloom/from_pretrained.py

+ 8 - 6
src/bloom/from_pretrained.py

@@ -34,13 +34,15 @@ def load_pretrained_block(
     config: Optional[BloomConfig] = None,
     torch_dtype: Union[torch.dtype, str] = "auto",
     use_auth_token: Optional[str] = None,
-    cache_dir: Optional[str] = None
+    cache_dir: Optional[str] = None,
 ) -> BloomBlock:
     """Load one BloomBlock from a converted model. See convert_model.py (or README.md) on how to convert it."""
     if config is None:
         config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
     block = BloomBlock(config, layer_number=block_index)
-    state_dict = _load_state_dict(converted_model_name_or_path, block_index, use_auth_token=use_auth_token, cache_dir=cache_dir)
+    state_dict = _load_state_dict(
+        converted_model_name_or_path, block_index, use_auth_token=use_auth_token, cache_dir=cache_dir
+    )
     block.load_state_dict(state_dict)
 
     if torch_dtype == "auto":
@@ -58,10 +60,10 @@ def load_pretrained_block(
 
 
 def _load_state_dict(
-    pretrained_model_name_or_path: str, 
-    block_index: Optional[int] = None, 
-    use_auth_token: Optional[str] = None, 
-    cache_dir: Optional[str] = None
+    pretrained_model_name_or_path: str,
+    block_index: Optional[int] = None,
+    use_auth_token: Optional[str] = None,
+    cache_dir: Optional[str] = None,
 ) -> OrderedDict[str, torch.Tensor]:
     revision = BLOCK_BRANCH_PREFIX + str(block_index) if block_index is not None else CLIENT_BRANCH
     archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, revision=revision, mirror=None)