import math
from typing import Dict, Any, Optional

import torch
from torch_optimizer.types import Betas2, Params

from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise
from bitsandbytes.optim.optimizer import Optimizer2State

__all__ = ('CPULAMB8Bit',)


class CPULAMB8Bit(Optimizer2State):
    r"""
    Implements LAMB with quantized 8-bit statistics. The statistics are stored in host memory in quantized form.

    The LAMB optimizer and block-wise quantization are described in the following papers:
    - LAMB: "Large Batch Optimization for Deep Learning: Training BERT in 76 minutes" https://arxiv.org/abs/1904.00962
    - Quantization: "8-bit Optimizers via Block-wise Quantization" https://arxiv.org/abs/2110.02861

    This specific implementation of LAMB is based on https://github.com/cybertronai/pytorch-lamb
    - bias correction defaults to False because paper v3 does not use debiasing
    - it has baked-in gradient clipping by a global max_grad_norm

    Arguments:
        params: iterable of parameters to optimize or dicts defining parameter groups
        lr: learning rate (default: 1e-3)
        betas: coefficients used for computing running averages of the gradient and its square
            (default: (0.9, 0.999))
        eps: term added to the denominator to improve numerical stability (default: 1e-6)
        weight_decay: weight decay (L2 penalty) (default: 0)
        clamp_value: clamp weight_norm to (0, clamp_value) (default: 10);
            set to a high value (e.g. 10e3) to effectively disable it
        bias_correction: debias statistics by (1 - beta**step) (default: False)
        min_8bit_size: statistics for parameters with fewer than this many elements will not be quantized
        reuse_grad_buffers: if True, the optimizer will modify gradients in-place to save memory;
            if enabled, one must ensure that .zero_grad() is called after each optimizer step
        update_chunk_size: quantized statistics will be de-quantized in chunks of up to this many elements
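
    Example (a minimal usage sketch; the toy model and hyperparameters below are illustrative only and
    assume that bitsandbytes and torch-optimizer are installed)::

        model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.ReLU(), torch.nn.Linear(1024, 10))
        opt = CPULAMB8Bit(model.parameters(), lr=1e-3, weight_decay=0.01, reuse_grad_buffers=True)
        loss = model(torch.randn(32, 1024)).pow(2).mean()
        loss.backward()
        opt.step()
        opt.zero_grad()  # mandatory after each step when reuse_grad_buffers=True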
""" def __init__( self, params: Params, lr: float = 1e-3, betas: Betas2 = (0.9, 0.999), eps: float = 1e-6, weight_decay: float = 0, clamp_value: float = 10, bias_correction: bool = False, min_8bit_size: int = 65536, reuse_grad_buffers: bool = False, update_chunk_size: int = 2 ** 24, max_grad_norm: Optional[float] = None, ) -> None: if lr <= 0.0: raise ValueError('Invalid learning rate: {}'.format(lr)) if eps < 0.0: raise ValueError('Invalid epsilon value: {}'.format(eps)) if not 0.0 <= betas[0] < 1.0: raise ValueError( 'Invalid beta parameter at index 0: {}'.format(betas[0]) ) if not 0.0 <= betas[1] < 1.0: raise ValueError( 'Invalid beta parameter at index 1: {}'.format(betas[1]) ) if weight_decay < 0: raise ValueError( 'Invalid weight_decay value: {}'.format(weight_decay) ) if clamp_value < 0.0: raise ValueError('Invalid clamp value: {}'.format(clamp_value)) self.clamp_value = clamp_value self.bias_correction = bias_correction self.reuse_grad_buffers = reuse_grad_buffers self.update_chunk_size = update_chunk_size self.max_grad_norm = max_grad_norm super(CPULAMB8Bit, self).__init__( 'cpu-lamb', params, lr, betas, eps, weight_decay, optim_bits=8, min_8bit_size=min_8bit_size, args=None, percentile_clipping=100, block_wise=4096, max_unorm=0) @torch.no_grad() def step(self, closure=None): if self.max_grad_norm is not None: iter_params = (param for group in self.param_groups for param in group['params']) torch.nn.utils.clip_grad_norm_(iter_params, self.max_grad_norm) return super().step(closure=closure) @torch.no_grad() def init_state(self, group, p, gindex, pindex): config = self.get_config(gindex, pindex, group) assert config['percentile_clipping'] == 100, "percentile clipping is not implemented on CPU" assert config['max_unorm'] == 0 if config['optim_bits'] == 32: dtype = torch.float32 elif config['optim_bits'] == 8: dtype = torch.uint8 else: raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}') if p.numel() < config['min_8bit_size']: dtype = torch.float32 state = self.state[p] state['step'] = 0 if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096): state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device) state['state2'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device) elif dtype == torch.uint8: if state['step'] == 0: if 'dynamic' not in self.name2qmap: self.fill_qmap() self.name2qmap['dynamic'] = self.name2qmap['dynamic'].to(p.device) self.name2qmap['udynamic'] = self.name2qmap['udynamic'].to(p.device) n = p.numel() blocks = (n - 1) // config['block_wise'] + 1 state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device) state['qmap1'] = self.name2qmap['dynamic'] state['state2'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device) state['qmap2'] = self.name2qmap['udynamic'] state['absmax1'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device) state['absmax2'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device) @torch.no_grad() def update_step(self, group: Dict[str, Any], p: torch.Tensor, gindex: int, pindex: int): state = self.state[p] config = self.get_config(gindex, pindex, group) p_cpu, grad_cpu = p.cpu(), p.grad.cpu() # this is a no-op if parameters are already on CPU step = state['step'] = state['step'] + 1 beta1, beta2 = group['betas'] param_delta = self._update_moments_and_compute_delta( state, config, p_cpu, 
            grad_cpu, beta1, beta2, group['eps'], group['weight_decay']
        )
        del grad_cpu  # grad_cpu is no longer needed and may be modified if self.reuse_grad_buffers

        step_norm = torch.norm(param_delta)
        weight_norm = p_cpu.norm().clamp(0, self.clamp_value)
        trust_ratio = weight_norm / step_norm if weight_norm != 0 and step_norm != 0 else 1.0
        state['weight_norm'], state['step_norm'], state['trust_ratio'] = weight_norm, step_norm, trust_ratio

        # apply bias correction to lr to avoid broadcast
        bias_correction = math.sqrt(1 - beta2 ** step) / (1 - beta1 ** step) if self.bias_correction else 1
        step_size = group['lr'] * bias_correction
        p.data.add_(param_delta.to(p.device), alpha=-step_size * trust_ratio)

    def _update_moments_and_compute_delta(
        self, state: Dict, config: Dict, p_cpu: torch.Tensor, grad_cpu: torch.Tensor,
        beta1: float, beta2: float, eps: float, weight_decay: float,
    ) -> torch.Tensor:
        step, block_size, chunk_size = state['step'], config['block_wise'], self.update_chunk_size

        if state['state1'].dtype != torch.uint8:
            # not quantized: update normally
            exp_avg, exp_avg_sq = state['state1'], state['state2']
            exp_avg.mul_(beta1).add_(grad_cpu, alpha=1 - beta1)
            exp_avg_sq.mul_(beta2).addcmul_(grad_cpu, grad_cpu, value=1 - beta2)

            sqrt_out = grad_cpu if self.reuse_grad_buffers else None
            _denominator = torch.sqrt(exp_avg_sq, out=sqrt_out).add_(eps)
            param_delta = torch.div(exp_avg, _denominator, out=_denominator)
            if weight_decay != 0:
                param_delta.add_(p_cpu, alpha=weight_decay)
            return param_delta

        elif p_cpu.numel() <= chunk_size:
            # quantized tensor that fits within one chunk
            exp_avg = dequantize_blockwise(
                state['state1'], (state['absmax1'], state['qmap1']), blocksize=block_size
            )
            exp_avg_sq = dequantize_blockwise(
                state['state2'], (state['absmax2'], state['qmap2']), blocksize=block_size
            )

            exp_avg.mul_(beta1).add_(grad_cpu, alpha=1 - beta1)
            exp_avg_sq.mul_(beta2).addcmul_(grad_cpu, grad_cpu, value=1 - beta2)

            quantize_blockwise(exp_avg, state['qmap1'], state['absmax1'], out=state['state1'])
            quantize_blockwise(exp_avg_sq, state['qmap2'], state['absmax2'], out=state['state2'])
            # note: quantize_blockwise also modifies qmap and absmax in-place

            param_delta = exp_avg.div_(exp_avg_sq.sqrt_().add_(eps))
            # note: this changes statistics in-place, but it's okay b/c we saved the quantized version

            if weight_decay != 0:
                param_delta.add_(p_cpu, alpha=weight_decay)
            return param_delta

        else:
            # very large quantized tensor: compute updates in chunks to save RAM
            flat_p, flat_grad, flat_state1, flat_state2 = (
                tensor.view(-1) for tensor in (p_cpu, grad_cpu, state['state1'], state['state2'])
            )
            output_buffer = flat_grad if self.reuse_grad_buffers else torch.empty_like(flat_grad)

            for chunk_start in range(0, len(flat_p), chunk_size):
                chunk = slice(chunk_start, chunk_start + chunk_size)
                chunk_blocks = slice(chunk_start // block_size, (chunk_start + chunk_size) // block_size)

                chunk_grad = flat_grad[chunk]
                chunk_state1, chunk_state2 = flat_state1[chunk], flat_state2[chunk]
                chunk_absmax1, chunk_absmax2 = state['absmax1'][chunk_blocks], state['absmax2'][chunk_blocks]
                if chunk_state1.storage_offset() != 0:
                    # clone chunks to ensure that tensors do not have storage offsets
                    chunk_state1, chunk_state2, chunk_absmax1, chunk_absmax2 = map(
                        torch.clone, (chunk_state1, chunk_state2, chunk_absmax1, chunk_absmax2)
                    )

                exp_avg_chunk = dequantize_blockwise(
                    chunk_state1, (chunk_absmax1, state['qmap1']), blocksize=block_size
                )
                exp_avg_sq_chunk = dequantize_blockwise(
                    chunk_state2, (chunk_absmax2, state['qmap2']), blocksize=block_size
                )
                exp_avg_chunk.mul_(beta1).add_(chunk_grad, alpha=1 - beta1)
                exp_avg_sq_chunk.mul_(beta2).addcmul_(chunk_grad, chunk_grad, value=1 - beta2)
                # note: output_buffer cannot be modified until this line because it shares memory with grad_cpu
                del chunk_grad

                flat_state1[chunk], (state['absmax1'][chunk_blocks], state['qmap1']) = quantize_blockwise(
                    exp_avg_chunk, state['qmap1'], chunk_absmax1, out=chunk_state1
                )
                flat_state2[chunk], (state['absmax2'][chunk_blocks], state['qmap2']) = quantize_blockwise(
                    exp_avg_sq_chunk, state['qmap2'], chunk_absmax2, out=chunk_state2
                )
                # note: we need to explicitly assign new quantized tensors because of the cloning earlier

                torch.div(exp_avg_chunk, exp_avg_sq_chunk.sqrt_().add_(eps), out=output_buffer[chunk])
                # note: this changes statistics in-place, but it's okay b/c we saved the quantized version

                if weight_decay != 0:
                    output_buffer[chunk].add_(flat_p[chunk], alpha=weight_decay)

            param_delta = output_buffer.view_as(grad_cpu)
            return param_delta
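

# The function below is an illustrative, full-precision reference for a single LAMB step, added only to
# document the math that update_step / _update_moments_and_compute_delta approximate with 8-bit statistics.
# It is a minimal sketch, not used anywhere in this module; its name and arguments are hypothetical.
def _reference_lamb_step_sketch(param: torch.Tensor, grad: torch.Tensor, exp_avg: torch.Tensor,
                                exp_avg_sq: torch.Tensor, lr: float, beta1: float, beta2: float,
                                eps: float, weight_decay: float, clamp_value: float) -> None:
    # update the first and second moments in full precision (the optimizer above keeps these quantized)
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

    # Adam-style update direction; weight decay is added to the delta before the trust ratio, as above
    param_delta = exp_avg / exp_avg_sq.sqrt().add(eps)
    if weight_decay != 0:
        param_delta = param_delta.add(param, alpha=weight_decay)

    # LAMB trust ratio: scale the step by ||w|| / ||delta||, with the weight norm clamped to clamp_value
    weight_norm = param.norm().clamp(0, clamp_value)
    step_norm = param_delta.norm()
    trust_ratio = (weight_norm / step_norm).item() if weight_norm != 0 and step_norm != 0 else 1.0
    param.add_(param_delta, alpha=-lr * trust_ratio)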