|
@@ -34,12 +34,13 @@ def load_pretrained_block(
|
|
|
block_index: int,
|
|
|
config: Optional[DistributedBloomConfig] = None,
|
|
|
torch_dtype: Union[torch.dtype, str] = "auto",
|
|
|
+ use_auth_token: Optional[str]=None
|
|
|
) -> BloomBlock:
|
|
|
"""Load one BloomBlock from a converted model. See convert_model.py (or README.md) on how to convert it."""
|
|
|
if config is None:
|
|
|
- config = DistributedBloomConfig.from_pretrained(converted_model_name_or_path)
|
|
|
+ config = DistributedBloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
|
|
|
block = BloomBlock(config, layer_number=block_index)
|
|
|
- state_dict = _load_state_dict(converted_model_name_or_path, block_index)
|
|
|
+ state_dict = _load_state_dict(converted_model_name_or_path, block_index, use_auth_token=use_auth_token)
|
|
|
block.load_state_dict(state_dict)
|
|
|
|
|
|
if torch_dtype == "auto":
|
|
@@ -57,7 +58,7 @@ def load_pretrained_block(
|
|
|
|
|
|
|
|
|
def _load_state_dict(
|
|
|
- pretrained_model_name_or_path: str, block_index: Optional[int] = None
|
|
|
+ pretrained_model_name_or_path: str, block_index: Optional[int] = None, use_auth_token: Optional[str] = None
|
|
|
) -> OrderedDict[str, torch.Tensor]:
|
|
|
revision = BLOCK_BRANCH_PREFIX + str(block_index) if block_index is not None else CLIENT_BRANCH
|
|
|
archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, revision=revision, mirror=None)
|
|
@@ -70,7 +71,7 @@ def _load_state_dict(
|
|
|
proxies=None,
|
|
|
resume_download=RESUME_DOWNLOAD,
|
|
|
local_files_only=LOCAL_FILES_ONLY,
|
|
|
- use_auth_token=True,
|
|
|
+ use_auth_token=use_auth_token,
|
|
|
user_agent=USER_AGENT,
|
|
|
)
|
|
|
state_dict = torch.load(resolved_archive_file, map_location="cpu")
|