from_pretrained.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. """
  2. Utils for fetching pretrained model parts. Currently, this relies on huggingface transformers' from_pretrained code.
  3. If necessary, one can rewrite this to implement a different behavior, such as:
  4. - loading files from a local data source (e.g. S3)
  5. - loading files via BitTorrent ( https://pypi.org/project/libtorrent/ ) or IPFS ( https://docs.ipfs.io/how-to )
  6. - fetching the weights over IPoAC, using a fleet of trained pigeons ( http://www.faqs.org/rfcs/rfc1149.html )
  7. """
  8. import json
  9. import time
  10. from typing import Dict, Optional, Union
  11. import torch
  12. import torch.nn as nn
  13. from accelerate import init_empty_weights
  14. from accelerate.utils import set_module_tensor_to_device
  15. from hivemind.utils.logging import get_logger
  16. from huggingface_hub import get_hf_file_metadata, hf_hub_url
  17. from transformers import PretrainedConfig
  18. from transformers.utils import get_file_from_repo
  19. from petals.constants import DTYPE_MAP
  20. from petals.server.block_utils import resolve_block_dtype
  21. from petals.utils.auto_config import AutoDistributedConfig
  22. from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
  23. from petals.utils.hf_auth import always_needs_auth
  24. logger = get_logger(__name__)
  25. def load_pretrained_block(
  26. model_name: str,
  27. block_index: int,
  28. *,
  29. config: Optional[PretrainedConfig] = None,
  30. torch_dtype: Union[torch.dtype, str] = "auto",
  31. revision: Optional[str] = None,
  32. token: Optional[Union[str, bool]] = None,
  33. cache_dir: Optional[str] = None,
  34. max_disk_space: Optional[int] = None,
  35. ) -> nn.Module:
  36. if config is None:
  37. config = AutoDistributedConfig.from_pretrained(model_name, use_auth_token=token)
  38. if cache_dir is None:
  39. cache_dir = DEFAULT_CACHE_DIR
  40. assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
  41. torch_dtype = resolve_block_dtype(config, torch_dtype)
  42. with init_empty_weights():
  43. block = config.block_class(config)
  44. block_prefix = f"{config.block_prefix}.{block_index}."
  45. state_dict = _load_state_dict_from_repo(
  46. model_name,
  47. block_prefix,
  48. revision=revision,
  49. token=token,
  50. cache_dir=cache_dir,
  51. max_disk_space=max_disk_space,
  52. )
  53. # dummy load, check that keys match
  54. report = block.load_state_dict(state_dict, strict=False)
  55. assert not report.missing_keys, f"Some block weights are missing: {report.missing_keys}"
  56. for param_name, _ in block.named_parameters():
  57. assert param_name in state_dict, f"{param_name} not in state dict"
  58. param = state_dict[param_name]
  59. if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
  60. param = param.to(torch_dtype)
  61. set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype)
  62. logger.info(f"Loaded {model_name} block {block_index}, {report}")
  63. return block
  64. StateDict = Dict[str, torch.Tensor]
  65. def _load_state_dict_from_repo(
  66. model_name: str,
  67. block_prefix: str,
  68. *,
  69. revision: Optional[str] = None,
  70. token: Optional[Union[str, bool]] = None,
  71. cache_dir: str,
  72. max_disk_space: Optional[int] = None,
  73. ) -> StateDict:
  74. if always_needs_auth(model_name) and token is None:
  75. token = True
  76. index_file = get_file_from_repo(
  77. model_name, filename="pytorch_model.bin.index.json", use_auth_token=token, cache_dir=cache_dir
  78. )
  79. if index_file is not None: # Sharded model
  80. with open(index_file) as f:
  81. index = json.load(f)
  82. filenames = {
  83. filename for param_name, filename in index["weight_map"].items() if param_name.startswith(block_prefix)
  84. }
  85. if not filenames:
  86. raise RuntimeError(f"Block {block_prefix}* not found in the index: {index['weight_map']}")
  87. else: # Non-sharded model
  88. filenames = {"pytorch_model.bin"}
  89. logger.debug(f"Loading {block_prefix}* from {filenames}")
  90. state_dict = {}
  91. for filename in filenames:
  92. shard_state_dict = _load_state_dict_from_file(
  93. model_name,
  94. filename,
  95. revision=revision,
  96. token=token,
  97. cache_dir=cache_dir,
  98. max_disk_space=max_disk_space,
  99. )
  100. shard_state_dict = {
  101. param_name[len(block_prefix) :]: param
  102. for param_name, param in shard_state_dict.items()
  103. if param_name.startswith(block_prefix)
  104. } # Remove unused parameters from memory
  105. state_dict.update(shard_state_dict)
  106. return state_dict
  107. def _load_state_dict_from_file(
  108. model_name: str,
  109. filename: str,
  110. *,
  111. revision: Optional[str] = None,
  112. token: Optional[Union[str, bool]] = None,
  113. cache_dir: str,
  114. max_disk_space: Optional[int] = None,
  115. delay: float = 30,
  116. ) -> StateDict:
  117. # First, try to find the weights locally
  118. try:
  119. with allow_cache_reads(cache_dir):
  120. path = get_file_from_repo(
  121. model_name,
  122. filename,
  123. revision=revision,
  124. use_auth_token=token,
  125. cache_dir=cache_dir,
  126. local_files_only=True,
  127. )
  128. if path is not None:
  129. return torch.load(path, map_location="cpu")
  130. except Exception:
  131. logger.warning(f"Cache for file {filename} is corrupted, it will be downloaded again", exc_info=True)
  132. # If not found, ensure that we have enough disk space to download them (maybe remove something)
  133. while True:
  134. try:
  135. with allow_cache_writes(cache_dir):
  136. url = hf_hub_url(model_name, filename, revision=revision)
  137. file_size = get_hf_file_metadata(url, token=token).size
  138. if file_size is not None:
  139. free_disk_space_for(file_size, cache_dir=cache_dir, max_disk_space=max_disk_space)
  140. else:
  141. logger.warning(f"Failed to fetch size of file {filename} from repo {model_name}")
  142. path = get_file_from_repo(
  143. model_name,
  144. filename,
  145. revision=revision,
  146. use_auth_token=token,
  147. cache_dir=cache_dir,
  148. local_files_only=False,
  149. )
  150. if path is None:
  151. raise RuntimeError(f"File {filename} does not exist in repo {model_name}")
  152. return torch.load(path, map_location="cpu")
  153. except Exception as e:
  154. logger.warning(f"Failed to load file {filename} from HF Hub (retry in {delay:.0f} sec)", exc_info=True)
  155. time.sleep(delay)