@@ -15,12 +15,12 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from transformers.modeling_utils import WEIGHTS_NAME
 from transformers.utils.hub import cached_path, hf_bucket_url

-from src.bloom import BloomBlock, DistributedBloomConfig
+from src.bloom import BloomBlock, BloomConfig

 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)

-CLIENT_BRANCH = "client"
+CLIENT_BRANCH = "main"
 BLOCK_BRANCH_PREFIX = "block_"
 USER_AGENT = {"file_type": "model", "framework": "pytorch", "from_auto_class": False}
 FORCE_DOWNLOAD = False
@@ -31,13 +31,13 @@ LOCAL_FILES_ONLY = False
 def load_pretrained_block(
     converted_model_name_or_path: str,
     block_index: int,
-    config: Optional[DistributedBloomConfig] = None,
+    config: Optional[BloomConfig] = None,
     torch_dtype: Union[torch.dtype, str] = "auto",
     use_auth_token: Optional[str] = None,
 ) -> BloomBlock:
     """Load one BloomBlock from a converted model. See convert_model.py (or README.md) on how to convert it."""
     if config is None:
-        config = DistributedBloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
+        config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
     block = BloomBlock(config, layer_number=block_index)
     state_dict = _load_state_dict(converted_model_name_or_path, block_index, use_auth_token=use_auth_token)
     block.load_state_dict(state_dict)
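
For reference, a minimal usage sketch of `load_pretrained_block` after this rename. The module path `src.bloom.from_pretrained` and the checkpoint name below are assumptions for illustration, not part of the diff:

```python
import torch

from src.bloom import BloomConfig  # class this diff switches to
from src.bloom.from_pretrained import load_pretrained_block  # assumed module path

# Placeholder Hub repository produced by convert_model.py; substitute your own.
CONVERTED_MODEL = "your-username/bloom-converted"

# Passing the config explicitly is optional: if it is None, the function now
# resolves it via BloomConfig.from_pretrained(...) as shown in the diff.
config = BloomConfig.from_pretrained(CONVERTED_MODEL)
block = load_pretrained_block(
    CONVERTED_MODEL,
    block_index=0,
    config=config,
    torch_dtype=torch.bfloat16,
)

# The result is a regular BloomBlock module.
print(type(block).__name__, sum(p.numel() for p in block.parameters()))
```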