block_utils.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. from typing import Optional, Union
  2. import torch
  3. from accelerate import init_empty_weights
  4. from transformers import PretrainedConfig
  5. from petals.utils.convert_block import QuantType
  6. def resolve_block_dtype(config: PretrainedConfig, dtype: Union[str, torch.dtype]) -> torch.dtype:
  7. """If dtype is "auto", resolves it using BloomConfig. Returns `dtype` intact otherwise."""
  8. if dtype not in ("auto", None):
  9. return dtype
  10. if config.torch_dtype not in ("auto", None, torch.float32):
  11. # If config specifies float32, we override it to the default dtype below
  12. return config.torch_dtype
  13. return torch.bfloat16
  14. def get_block_size(
  15. config: PretrainedConfig,
  16. location: str,
  17. *,
  18. dtype: Optional[Union[str, torch.dtype]] = None,
  19. quant_type: QuantType = QuantType.NONE,
  20. eps: float = 0.01, # eps accounts for ~1% of metainfo for tensor descriptions, quantization tables, etc.
  21. ) -> int:
  22. if location == "memory":
  23. assert (
  24. dtype is not None and quant_type is not None
  25. ), 'get_block_size(..., location="memory") requires to specify dtype and quant_type for calculations'
  26. with init_empty_weights(include_buffers=True):
  27. block = config.block_class(config)
  28. n_params = sum(param.numel() for param in block.parameters())
  29. if location == "memory":
  30. if quant_type == QuantType.NONE:
  31. dtype = resolve_block_dtype(config, dtype)
  32. bytes_per_value = torch.finfo(dtype).bits // 8
  33. elif quant_type == QuantType.INT8:
  34. bytes_per_value = 1
  35. elif quant_type == QuantType.NF4:
  36. bytes_per_value = 4.25 / 8 # Bitness of NF4 with this config (measured empirically)
  37. else:
  38. raise ValueError(f"Unsupported quant_type={quant_type}")
  39. elif location == "disk":
  40. dtype = resolve_block_dtype(config, "auto")
  41. bytes_per_value = torch.finfo(dtype).bits // 8
  42. return round(n_params * bytes_per_value * (1 + eps))