lamb_8bit.py

import math
from typing import Dict, Any, Optional

import torch
from torch_optimizer.types import Betas2, Params
from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise
from bitsandbytes.optim.optimizer import Optimizer2State

__all__ = ('CPULAMB8Bit',)


class CPULAMB8Bit(Optimizer2State):
    r"""
    Implements LAMB with quantized 8-bit statistics. The statistics are stored in host memory in quantized form.

    The LAMB optimizer and block-wise quantization are described in the following papers:
    - LAMB: "Large Batch Optimization for Deep Learning: Training BERT in 76 minutes" https://arxiv.org/abs/1904.00962
    - Quantization: "8-bit Optimizers via Block-wise Quantization" https://arxiv.org/abs/2110.02861

    This specific implementation of LAMB is based on https://github.com/cybertronai/pytorch-lamb
    - bias correction defaults to False because v3 of the LAMB paper does not use debiasing
    - clipping by a global max_grad_norm is built in

    Arguments:
        params: iterable of parameters to optimize or dicts defining parameter groups
        lr: learning rate (default: 1e-3)
        betas: coefficients used for computing running averages of the gradient and its square
            (default: (0.9, 0.999))
        eps: term added to the denominator to improve numerical stability (default: 1e-6)
        weight_decay: weight decay (L2 penalty) (default: 0)
        clamp_value: clamp the weight norm to (0, clamp_value) (default: 10);
            set to a high value (e.g. 10e3) to effectively disable clamping
        bias_correction: debias statistics by (1 - beta**step) (default: False)
        min_8bit_size: statistics for parameters with fewer than this many elements will not be quantized
        reuse_grad_buffers: if True, the optimizer will modify gradients in-place to save memory;
            if enabled, one must ensure that .zero_grad() is called after each optimizer step
        update_chunk_size: quantized statistics will be de-quantized in chunks of up to this many elements
        max_grad_norm: if set, clip the global gradient norm to this value before each step (default: None)
    """
    def __init__(
        self,
        params: Params,
        lr: float = 1e-3,
        betas: Betas2 = (0.9, 0.999),
        eps: float = 1e-6,
        weight_decay: float = 0,
        clamp_value: float = 10,
        bias_correction: bool = False,
        min_8bit_size: int = 65536,
        reuse_grad_buffers: bool = False,
        update_chunk_size: int = 2 ** 24,
        max_grad_norm: Optional[float] = None,
    ) -> None:
        if lr <= 0.0:
            raise ValueError('Invalid learning rate: {}'.format(lr))
        if eps < 0.0:
            raise ValueError('Invalid epsilon value: {}'.format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError('Invalid beta parameter at index 0: {}'.format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError('Invalid beta parameter at index 1: {}'.format(betas[1]))
        if weight_decay < 0:
            raise ValueError('Invalid weight_decay value: {}'.format(weight_decay))
        if clamp_value < 0.0:
            raise ValueError('Invalid clamp value: {}'.format(clamp_value))

        self.clamp_value = clamp_value
        self.bias_correction = bias_correction
        self.reuse_grad_buffers = reuse_grad_buffers
        self.update_chunk_size = update_chunk_size
        self.max_grad_norm = max_grad_norm
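        # Note on the base-class arguments below: optim_bits=8 requests 8-bit statistics,
        # percentile_clipping=100 and max_unorm=0 disable those features (init_state asserts this),
        # and block_wise=4096 is later used as the quantization block size via config['block_wise'].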
        super(CPULAMB8Bit, self).__init__(
            'cpu-lamb', params, lr, betas, eps, weight_decay, optim_bits=8, min_8bit_size=min_8bit_size,
            args=None, percentile_clipping=100, block_wise=4096, max_unorm=0)

    @torch.no_grad()
    def step(self, closure=None):
        if self.max_grad_norm is not None:
            iter_params = (param for group in self.param_groups for param in group['params'])
            torch.nn.utils.clip_grad_norm_(iter_params, self.max_grad_norm)
        return super().step(closure=closure)

    @torch.no_grad()
    def init_state(self, group, p, gindex, pindex):
        config = self.get_config(gindex, pindex, group)
        assert config['percentile_clipping'] == 100, "percentile clipping is not implemented on CPU"
        assert config['max_unorm'] == 0

        if config['optim_bits'] == 32:
            dtype = torch.float32
        elif config['optim_bits'] == 8:
            dtype = torch.uint8
        else:
            raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}')

        if p.numel() < config['min_8bit_size']:
            dtype = torch.float32

        state = self.state[p]
        state['step'] = 0

        if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
            state['state1'] = torch.zeros_like(
                p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device)
            state['state2'] = torch.zeros_like(
                p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device)
        elif dtype == torch.uint8:
            if state['step'] == 0:
                if 'dynamic' not in self.name2qmap:
                    self.fill_qmap()
                self.name2qmap['dynamic'] = self.name2qmap['dynamic'].to(p.device)
                self.name2qmap['udynamic'] = self.name2qmap['udynamic'].to(p.device)
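            # Block-wise quantization (https://arxiv.org/abs/2110.02861) keeps one float32 absmax
            # scale per block of config['block_wise'] elements; dequantization roughly recovers
            # qmap[code] * absmax[block] for each element.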
            n = p.numel()
            blocks = (n - 1) // config['block_wise'] + 1

            state['state1'] = torch.zeros_like(
                p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device)
            state['qmap1'] = self.name2qmap['dynamic']

            state['state2'] = torch.zeros_like(
                p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device)
            state['qmap2'] = self.name2qmap['udynamic']

            state['absmax1'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
            state['absmax2'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)

    @torch.no_grad()
    def update_step(self, group: Dict[str, Any], p: torch.Tensor, gindex: int, pindex: int):
        state = self.state[p]
        config = self.get_config(gindex, pindex, group)

        p_cpu, grad_cpu = p.cpu(), p.grad.cpu()  # this is a no-op if the parameters are already on CPU

        step = state['step'] = state['step'] + 1
        beta1, beta2 = group['betas']

        param_delta = self._update_moments_and_compute_delta(
            state, config, p_cpu, grad_cpu, beta1, beta2, group['eps'], group['weight_decay']
        )
        del grad_cpu  # grad_cpu is no longer needed and may be modified if self.reuse_grad_buffers
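        # LAMB trust ratio: scale the update by ||w|| / ||update||, with the weight norm clamped to
        # at most self.clamp_value; if either norm is zero, fall back to a ratio of 1.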
        step_norm = torch.norm(param_delta)
        weight_norm = p_cpu.norm().clamp(0, self.clamp_value)
        trust_ratio = weight_norm / step_norm if weight_norm != 0 and step_norm != 0 else 1.0
        state['weight_norm'], state['step_norm'], state['trust_ratio'] = weight_norm, step_norm, trust_ratio

        # Apply the bias correction to lr to avoid broadcasting it over the parameter tensor.
        bias_correction = math.sqrt(1 - beta2 ** step) / (1 - beta1 ** step) if self.bias_correction else 1
        step_size = group['lr'] * bias_correction
        p.data.add_(param_delta.to(p.device), alpha=-step_size * trust_ratio)

    def _update_moments_and_compute_delta(
        self, state: Dict, config: Dict,
        p_cpu: torch.Tensor, grad_cpu: torch.Tensor,
        beta1: float, beta2: float, eps: float, weight_decay: float,
    ) -> torch.Tensor:
        step, block_size, chunk_size = state['step'], config['block_wise'], self.update_chunk_size

        if state['state1'].dtype != torch.uint8:
            # not quantized: update normally
            exp_avg, exp_avg_sq = state['state1'], state['state2']
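            # Adam-style moment updates: m <- beta1 * m + (1 - beta1) * g and
            # v <- beta2 * v + (1 - beta2) * g**2 (the same update is used in the quantized branches below).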
            exp_avg.mul_(beta1).add_(grad_cpu, alpha=1 - beta1)
            exp_avg_sq.mul_(beta2).addcmul_(grad_cpu, grad_cpu, value=1 - beta2)

            sqrt_out = grad_cpu if self.reuse_grad_buffers else None
            _denominator = torch.sqrt(exp_avg_sq, out=sqrt_out).add_(eps)
            param_delta = torch.div(exp_avg, _denominator, out=_denominator)
            if weight_decay != 0:
                param_delta.add_(p_cpu, alpha=weight_decay)
            return param_delta
        elif p_cpu.numel() <= chunk_size:
            # quantized tensor fits within one chunk
            exp_avg = dequantize_blockwise(
                state['state1'], (state['absmax1'], state['qmap1']), blocksize=block_size
            )
            exp_avg_sq = dequantize_blockwise(
                state['state2'], (state['absmax2'], state['qmap2']), blocksize=block_size
            )

            exp_avg.mul_(beta1).add_(grad_cpu, alpha=1 - beta1)
            exp_avg_sq.mul_(beta2).addcmul_(grad_cpu, grad_cpu, value=1 - beta2)

            quantize_blockwise(exp_avg, state['qmap1'], state['absmax1'], out=state['state1'])
            quantize_blockwise(exp_avg_sq, state['qmap2'], state['absmax2'], out=state['state2'])
            # note: quantize_blockwise also modifies qmap and absmax in-place

            param_delta = exp_avg.div_(exp_avg_sq.sqrt_().add_(eps))
            # note: this changes statistics in-place, but it's okay b/c we saved the quantized version
            if weight_decay != 0:
                param_delta.add_(p_cpu, alpha=weight_decay)
            return param_delta
        else:
            # very large quantized tensor, compute updates in chunks to save RAM
            flat_p, flat_grad, flat_state1, flat_state2 = (
                tensor.view(-1) for tensor in (p_cpu, grad_cpu, state['state1'], state['state2'])
            )
            output_buffer = flat_grad if self.reuse_grad_buffers else torch.empty_like(flat_grad)

            for chunk_index, chunk_start in enumerate(range(0, len(flat_p), chunk_size)):
                chunk = slice(chunk_start, chunk_start + chunk_size)
                chunk_blocks = slice(chunk_start // block_size, (chunk_start + chunk_size) // block_size)
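                # note: chunk boundaries coincide with quantization block boundaries as long as
                # update_chunk_size is a multiple of block_size (true for the defaults 2 ** 24 and 4096)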
                chunk_p, chunk_grad = flat_p[chunk], flat_grad[chunk]
                chunk_state1, chunk_state2 = flat_state1[chunk], flat_state2[chunk]
                chunk_absmax1, chunk_absmax2 = state['absmax1'][chunk_blocks], state['absmax2'][chunk_blocks]
                if chunk_state1.storage_offset() != 0:
                    # clone chunks to ensure that the tensors do not have storage offsets
                    chunk_state1, chunk_state2, chunk_absmax1, chunk_absmax2 = map(
                        torch.clone, (chunk_state1, chunk_state2, chunk_absmax1, chunk_absmax2)
                    )

                exp_avg_chunk = dequantize_blockwise(
                    chunk_state1, (chunk_absmax1, state['qmap1']), blocksize=block_size
                )
                exp_avg_sq_chunk = dequantize_blockwise(
                    chunk_state2, (chunk_absmax2, state['qmap2']), blocksize=block_size
                )

                exp_avg_chunk.mul_(beta1).add_(chunk_grad, alpha=1 - beta1)
                exp_avg_sq_chunk.mul_(beta2).addcmul_(chunk_grad, chunk_grad, value=1 - beta2)

                # note: output_buffer cannot be modified before this line because it shares memory with grad_cpu
                del chunk_grad

                flat_state1[chunk], (state['absmax1'][chunk_blocks], state['qmap1']) = quantize_blockwise(
                    exp_avg_chunk, state['qmap1'], chunk_absmax1, out=chunk_state1
                )
                flat_state2[chunk], (state['absmax2'][chunk_blocks], state['qmap2']) = quantize_blockwise(
                    exp_avg_sq_chunk, state['qmap2'], chunk_absmax2, out=chunk_state2
                )
                # note: we need to assign the quantized tensors back explicitly because of the cloning above

                torch.div(exp_avg_chunk, exp_avg_sq_chunk.sqrt_().add_(eps), out=output_buffer[chunk])
                # note: this changes statistics in-place, but it's okay b/c we saved the quantized version
                if weight_decay != 0:
                    output_buffer[chunk].add_(flat_p[chunk], alpha=weight_decay)

            param_delta = output_buffer.view_as(grad_cpu)
            return param_delta
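

if __name__ == '__main__':
    # Minimal usage sketch (added for illustration; not part of the original module). It assumes a
    # bitsandbytes version whose quantize_blockwise/dequantize_blockwise accept the tuple-based
    # quant_state used above, and it keeps everything on CPU.
    model = torch.nn.Linear(512, 256)  # the weight has 512 * 256 = 131072 elements >= min_8bit_size,
                                       # so its statistics are stored in quantized 8-bit form
    opt = CPULAMB8Bit(model.parameters(), lr=1e-3, weight_decay=0.01, max_grad_norm=1.0)
    for _ in range(3):
        batch = torch.randn(32, 512)
        loss = model(batch).pow(2).mean()
        loss.backward()
        opt.step()
        opt.zero_grad()  # required after every step if reuse_grad_buffers=True
        print(f'loss={loss.item():.4f}')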