from_pretrained.py 3.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. """
  2. Utils for fetching pretrained model parts. Currently, this relies on huggingface transformers' from_pretrained code.
  3. If necessary, one can rewrite this to implement a different behavior, such as:
  4. - loading files from a local data source (e.g. S3)
  5. - loading files via BitTorrent ( https://pypi.org/project/libtorrent/ ) or IPFS ( https://docs.ipfs.io/how-to )
  6. - fetching the weights over IPoAC, using a fleet of trained pigeons ( http://www.faqs.org/rfcs/rfc1149.html )
  7. """
  8. from __future__ import annotations
  9. from typing import Optional, OrderedDict, Union
  10. import torch
  11. from hivemind.utils.logging import get_logger, use_hivemind_log_handler
  12. from transformers.modeling_utils import WEIGHTS_NAME
  13. from transformers.utils.hub import cached_path, hf_bucket_url
  14. from src.bloom import BloomBlock, BloomConfig
  15. use_hivemind_log_handler("in_root_logger")
  16. logger = get_logger(__file__)
  17. CLIENT_BRANCH = "main"
  18. BLOCK_BRANCH_PREFIX = "block_"
  19. USER_AGENT = {"file_type": "model", "framework": "pytorch", "from_auto_class": False}
  20. FORCE_DOWNLOAD = False
  21. RESUME_DOWNLOAD = False
  22. LOCAL_FILES_ONLY = False
  23. def load_pretrained_block(
  24. converted_model_name_or_path: str,
  25. block_index: int,
  26. config: Optional[BloomConfig] = None,
  27. torch_dtype: Union[torch.dtype, str] = "auto",
  28. use_auth_token: Optional[str] = None,
  29. ) -> BloomBlock:
  30. """Load one BloomBlock from a converted model. See convert_model.py (or README.md) on how to convert it."""
  31. if config is None:
  32. config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
  33. block = BloomBlock(config, layer_number=block_index)
  34. state_dict = _load_state_dict(converted_model_name_or_path, block_index, use_auth_token=use_auth_token)
  35. block.load_state_dict(state_dict)
  36. if torch_dtype == "auto":
  37. with torch.no_grad():
  38. for name, param in block.named_parameters():
  39. assert name in state_dict, f"{name} not in state dict"
  40. param.data = param.data.to(state_dict[name].dtype)
  41. else:
  42. assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
  43. block = block.to(dtype=torch_dtype)
  44. report = block.load_state_dict(state_dict, strict=True)
  45. logger.info(f"Loaded {converted_model_name_or_path} block {block_index}, {report}")
  46. return block
  47. def _load_state_dict(
  48. pretrained_model_name_or_path: str, block_index: Optional[int] = None, use_auth_token: Optional[str] = None
  49. ) -> OrderedDict[str, torch.Tensor]:
  50. revision = BLOCK_BRANCH_PREFIX + str(block_index) if block_index is not None else CLIENT_BRANCH
  51. archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, revision=revision, mirror=None)
  52. # Load from URL or cache if already cached
  53. resolved_archive_file = cached_path(
  54. archive_file,
  55. cache_dir=None,
  56. force_download=FORCE_DOWNLOAD,
  57. proxies=None,
  58. resume_download=RESUME_DOWNLOAD,
  59. local_files_only=LOCAL_FILES_ONLY,
  60. use_auth_token=use_auth_token,
  61. user_agent=USER_AGENT,
  62. )
  63. state_dict = torch.load(resolved_archive_file, map_location="cpu")
  64. return state_dict
  65. DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")