Kaynağa Gözat

use main as client branch

justheuristic 3 yıl önce
ebeveyn
işleme
6c437c9249
1 değiştirilmiş dosya ile 4 ekleme ve 4 silme
  1. 4 4
      src/bloom/from_pretrained.py

+ 4 - 4
src/bloom/from_pretrained.py

@@ -15,12 +15,12 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from transformers.modeling_utils import WEIGHTS_NAME
 from transformers.utils.hub import cached_path, hf_bucket_url
 
-from src.bloom import BloomBlock, DistributedBloomConfig
+from src.bloom import BloomBlock, BloomConfig
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
-CLIENT_BRANCH = "client"
+CLIENT_BRANCH = "main"
 BLOCK_BRANCH_PREFIX = "block_"
 USER_AGENT = {"file_type": "model", "framework": "pytorch", "from_auto_class": False}
 FORCE_DOWNLOAD = False
@@ -31,13 +31,13 @@ LOCAL_FILES_ONLY = False
 def load_pretrained_block(
     converted_model_name_or_path: str,
     block_index: int,
-    config: Optional[DistributedBloomConfig] = None,
+    config: Optional[BloomConfig] = None,
     torch_dtype: Union[torch.dtype, str] = "auto",
     use_auth_token: Optional[str] = None,
 ) -> BloomBlock:
     """Load one BloomBlock from a converted model. See convert_model.py (or README.md) on how to convert it."""
     if config is None:
-        config = DistributedBloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
+        config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
     block = BloomBlock(config, layer_number=block_index)
     state_dict = _load_state_dict(converted_model_name_or_path, block_index, use_auth_token=use_auth_token)
     block.load_state_dict(state_dict)