Răsfoiți Sursa

fetch a specific bloom block without downloading the entire model

justheuristic 3 ani în urmă
părinte
comite
a6fca51212
3 a modificat fișierele cu 69 adăugiri și 2 ștergeri
  1. 1 1
      src/bloom/__init__.py
  2. 67 0
      src/bloom/from_pretrained.py
  3. 1 1
      src/bloom/model.py

+ 1 - 1
src/bloom/__init__.py

@@ -1 +1 @@
-from src.bloom.model import BloomForCausalLM, BloomModel, DistributedBloomConfig
+from src.bloom.model import BloomForCausalLM, BloomModel, DistributedBloomConfig, BloomBlock

+ 67 - 0
src/bloom/from_pretrained.py

@@ -0,0 +1,67 @@
+"""
+Utils for fetching pre-trained model parts. Currently, this relies on huggingface transformers' from_pretrained code.
+If necessary, one can rewrite this to implement a different behavior, such as:
+ - loading files from a local data source (e.g. S3)
+ - load files via BitTorrent ( https://pypi.org/project/libtorrent/ ) or IPFS( https://docs.ipfs.io/how-to )
+ - fetch the weights over IPoAC, using a fleet of trained pigeons ( http://www.faqs.org/rfcs/rfc1149.html )
+
+"""
+from typing import Optional, OrderedDict
+
+import torch
+from hivemind.utils.logging import use_hivemind_log_handler, get_logger
+from transformers.utils.hub import hf_bucket_url, cached_path
+
+from src.bloom import BloomForCausalLM, DistributedBloomConfig, BloomBlock
+from transformers.modeling_utils import WEIGHTS_NAME
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+CLIENT_BRANCH = "client"
+BLOCK_BRANCH_PREFIX = "block_"
+USER_AGENT = {'file_type': 'model', 'framework': 'pytorch', 'from_auto_class': False}
+cls = BloomForCausalLM
+FORCE_DOWNLOAD = False
+RESUME_DOWNLOAD = False
+LOCAL_FILES_ONLY = False
+
+
+def load_pretrained_block(
+        converted_model_name_or_path: str, block_index: int, config: Optional[DistributedBloomConfig] = None):
+    """Load one BloomBlock from a converted model. See convert_model.py (or README.md) on how to convert it."""
+    if config is None:
+        config = DistributedBloomConfig.from_pretrained(converted_model_name_or_path)
+    block = BloomBlock(config, layer_number=block_index)
+    state_dict = _load_state_dict(converted_model_name_or_path, block_index)
+    with torch.no_grad():
+        for name, param in block.named_parameters():
+            assert name in state_dict, f"{name} not in state dict"
+            param.data = param.data.to(state_dict[name].dtype)
+    report = block.load_state_dict(state_dict, strict=True)
+    logger.info(f"Loaded {converted_model_name_or_path} block {block_index}, {report}")
+    return block
+
+
+def _load_state_dict(
+        pretrained_model_name_or_path: str, block_index: Optional[int] = None) -> OrderedDict[str, torch.Tensor]:
+    revision = BLOCK_BRANCH_PREFIX + str(block_index) if block_index is not None else CLIENT_BRANCH
+    archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, revision=revision, mirror=None)
+
+    # Load from URL or cache if already cached
+    resolved_archive_file = cached_path(
+        archive_file,
+        cache_dir=None,
+        force_download=FORCE_DOWNLOAD,
+        proxies=None,
+        resume_download=RESUME_DOWNLOAD,
+        local_files_only=LOCAL_FILES_ONLY,
+        use_auth_token=True,
+        user_agent=USER_AGENT,
+    )
+    state_dict = torch.load(resolved_archive_file, map_location='cpu')
+    return state_dict
+
+
+
+

+ 1 - 1
src/bloom/model.py

@@ -25,7 +25,7 @@ from src.bloom.block import BloomBlock
 from src.bloom.ops import build_alibi_tensor
 
 use_hivemind_log_handler("in_root_logger")
-logger = logging.get_logger(__name__)
+logger = logging.get_logger(__file__)
 
 _CHECKPOINT_FOR_DOC = "bigscience/Bloom"
 _CONFIG_FOR_DOC = "MemoryEfficientBloomConfig"