# convert_block.py

"""
Tools for converting transformer blocks, applying quantization and/or tensor parallelism
"""
import os
import re
from enum import Enum
from typing import List, Optional, Sequence

import tensor_parallel as tp
import torch
import torch.nn as nn
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from tensor_parallel.slicing_configs import get_bloom_config
from transformers import PretrainedConfig

from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__name__)


class QuantType(Enum):
    NONE = 0
    INT8 = 1  # 8-bit as in the LLM.int8() paper
    NF4 = 2  # 4-bit as in the QLoRA paper


def convert_block(
    block: nn.Module,
    block_index: int,
    config: PretrainedConfig,
    tensor_parallel_devices: Sequence[torch.device],
    output_device: torch.device,
    quant_type: QuantType,
    freeze: bool = True,
    adapters: Optional[List[str]] = None,
    **kwargs,
) -> tp.TensorParallel:
    """
    Optimize a transformer block for use in a Petals server: apply tensor parallelism and/or LLM.int8() / NF4 quantization

    :note: some optimizations will modify the input block in-place!
    :param block: a single transformer block, either pre-trained or newly initialized
    :param config: HF transformers config for the full model
    :param tensor_parallel_devices: if specified, use tensor parallelism to split the model between these devices
    :note: if there is only a single device, the model will still be wrapped with TensorParallel (for uniformity)
    :param output_device: if tensor parallelism is used, gather the block outputs on this device
    :param quant_type: quantization type
    :param freeze: if True (default), make all module parameters non-trainable
    :return: a module that acts like the original block, but runs with all specified optimizations
    """
    if freeze:
        block.requires_grad_(False)

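    # Wrap the block with TensorParallel even if there is only one device, so the rest of this
    # function can uniformly access block.module_shards and block.devices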
    block = make_tensor_parallel(block, config, tensor_parallel_devices, output_device=output_device)

    if quant_type != QuantType.NONE:
        block = quantize_module(block, quant_type=quant_type)

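    # Move each shard of the (possibly quantized) block to its target device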
    for shard, device in zip(block.module_shards, block.devices):
        shard.to(device)

    if adapters:
        create_lora_adapter(block)
        for adapter_name in adapters:
            adapter_config, adapter_state_dict = load_peft(
                adapter_name,
                block_idx=block_index,
                **kwargs,
            )
            add_adapter_to_block(block, block_index, adapter_name, adapter_config, adapter_state_dict)

    return block
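
# Example (sketch, not part of the original module): a server loading one pre-trained block onto
# a single GPU with NF4 quantization might call convert_block like this; `loaded_block` and
# `model_config` are placeholders for objects created elsewhere.
#
#     converted = convert_block(
#         loaded_block,
#         block_index=3,
#         config=model_config,
#         tensor_parallel_devices=[torch.device("cuda:0")],
#         output_device=torch.device("cuda:0"),
#         quant_type=QuantType.NF4,
#     )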


def quantize_module(model: nn.Module, *, quant_type: QuantType) -> nn.Module:
    # Import bitsandbytes only when necessary, so Petals runs on platforms not supported by bitsandbytes
    os.environ["BITSANDBYTES_NOWELCOME"] = "1"
    import bitsandbytes as bnb

    for n, module in model.named_children():
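        # Recurse into container modules so that nested Linear layers are quantized as well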
        if len(list(module.children())) > 0:
            quantize_module(module, quant_type=quant_type)

        if isinstance(module, torch.nn.Linear) and n not in ["lm_head", "score"]:
            assert module.weight.device.type == "cpu", f"expected linear layers on CPU, got {module.weight.device}"
            if quant_type == QuantType.INT8:
                model._modules[n] = bnb.nn.Linear8bitLt(
                    module.in_features,
                    module.out_features,
                    module.bias is not None,
                    has_fp16_weights=False,
                    threshold=6.0,  # Default from the LLM.int8() paper
                )
                model._modules[n].weight = bnb.nn.Int8Params(
                    module.weight.data, requires_grad=False, has_fp16_weights=False
                ).to(module.weight.dtype)
            elif quant_type == QuantType.NF4:
                compress_statistics = True
                model._modules[n] = bnb.nn.LinearNF4(
                    module.in_features,
                    module.out_features,
                    module.bias is not None,
                    compress_statistics=compress_statistics,
                )
                model._modules[n].weight = bnb.nn.Params4bit(
                    module.weight.data,
                    requires_grad=False,
                    quant_type="nf4",
                    blocksize=64,
                    compress_statistics=compress_statistics,
                ).to(module.weight.dtype)
            else:
                raise ValueError(f"Unsupported quant_type='{quant_type}'")
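            # Reuse the original bias tensor as-is: only the weight matrix is quantized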
            model._modules[n].bias = module.bias

    return model
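
# Example (sketch, not part of the original module): quantize_module can also be applied to a CPU
# block on its own; the replaced layers store the weights in quantized form once the module is
# later moved to a CUDA device.
#
#     block = quantize_module(block, quant_type=QuantType.INT8)
#     block.to(torch.device("cuda:0"))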


def make_tensor_parallel(
    block: nn.Module, model_config: PretrainedConfig, devices: Sequence[torch.device], output_device: torch.device
) -> nn.Module:
    if model_config.model_type == "bloom":
        tp_config = get_bloom_config(model_config, devices)
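        # A standalone transformer block contains no word embeddings, so that slicing rule is not needed here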
        del tp_config.state_rules[re.compile(".*word_embeddings.weight$")]
    else:
        if len(devices) > 1:
            logger.warning("Tensor parallelism is not tested for models other than BLOOM yet, proceed with caution")
        tp_config = None

    tp_block = tp.TensorParallel(block, devices, config=tp_config, output_device=output_device, delay_init=True)
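    # Sanity check: the attention heads distributed across shards must add up to the model's total head count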
    total_heads = 0
    for tp_shard in tp_block.module_shards:
        for submodule in tp_shard.modules():
            if isinstance(submodule, model_config.attn_class):
                total_heads += submodule.num_heads
    assert total_heads == model_config.num_attention_heads
    return tp_block


def check_device_balance(devices: Sequence[torch.device]):
    if not all(device.type == "cuda" for device in devices):
        logger.warning("Running tensor parallelism on non-GPU devices; proceed at your own risk")
        return

    unique_device_capabilities = set(map(torch.cuda.get_device_capability, devices))
    if len(unique_device_capabilities) > 1:
        logger.warning(
            f"Found GPUs with uneven capabilities: {unique_device_capabilities}. "
            f"Using GPUs with different performance will cause the server to wait for the slowest GPU."
        )

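    # With even sharding, usable memory is bounded by the smallest device times the number of devices;
    # anything above that minimum on the larger devices counts as wasted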
    memory_per_device = tuple(torch.cuda.get_device_properties(device).total_memory for device in devices)
    used_memory = min(memory_per_device) * len(memory_per_device)
    wasted_memory_rate = (sum(memory_per_device) - used_memory) / sum(memory_per_device)
    if wasted_memory_rate > 0.05:
        logger.warning(
            f"GPU devices have highly uneven memory, {wasted_memory_rate * 100:.2f}% memory is wasted. "
            f"Consider running high-memory GPUs in a separate server."
        )
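
# Example (sketch, not part of the original module): a server configured for tensor parallelism
# might validate its GPU set before converting blocks, e.g.
#
#     devices = [torch.device(f"cuda:{i}") for i in range(torch.cuda.device_count())]
#     check_device_balance(devices)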