123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250 |
- import math
- from typing import Dict, Any, Optional
- import torch
- from torch_optimizer.types import Betas2, Params
- from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise
- from bitsandbytes.optim.optimizer import Optimizer2State
# Names exported by `from <module> import *`: only the optimizer class below.
__all__ = ('CPULAMB8Bit',)
class CPULAMB8Bit(Optimizer2State):
    r"""
    Implements Lamb with quantized 8-bit statistics. The statistics are stored in host memory in the quantized form.
    The LAMB optimizer and block-wise quantization are described in the following papers:
    - LAMB: "Large Batch Optimization for Deep Learning: Training BERT in 76 minutes" https://arxiv.org/abs/1904.00962
    - Quantization: "8-bit Optimizers via Block-wise Quantization" https://arxiv.org/abs/2110.02861
    This specific implementation of LAMB is based on https://github.com/cybertronai/pytorch-lamb
    - bias correction defaults to False because paper v3 does not use debiasing
    - it has baked in clipping by global max_grad_norm

    Arguments:
        params: iterable of parameters to optimize or dicts defining
            parameter groups
        lr: learning rate (default: 1e-3)
        betas: coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps: term added to the denominator to improve
            numerical stability (default: 1e-6)
        weight_decay: weight decay (L2 penalty) (default: 0)
        clamp_value: clamp weight_norm in (0,clamp_value) (default: 10)
            set to a high value to avoid it (e.g 10e3)
        bias_correction: debias statistics by (1 - beta**step) (default: False)
        min_8bit_size: statistics for parameters with fewer than this many elements will not be quantized
        reuse_grad_buffers: if True, optimizer will modify gradients in-place to save memory.
            If enabled, one must ensure that .zero_grad() is called after each optimizer step.
        update_chunk_size: quantized statistics will be de-quantized in chunks of up to this many elements.
            Must be positive. It is rounded down to a whole number of quantization blocks
            (but never below one block) so that chunks stay aligned with the stored block scales.
        max_grad_norm: if not None, gradients of all parameters are clipped to this global L2 norm
            before every step (default: None). Must be non-negative.

    Raises:
        ValueError: if any hyperparameter is out of its valid range.
    """

    def __init__(
        self,
        params: Params,
        lr: float = 1e-3,
        betas: Betas2 = (0.9, 0.999),
        eps: float = 1e-6,
        weight_decay: float = 0,
        clamp_value: float = 10,
        bias_correction: bool = False,
        min_8bit_size: int = 65536,
        reuse_grad_buffers: bool = False,
        update_chunk_size: int = 2 ** 24,
        max_grad_norm: Optional[float] = None,
    ) -> None:
        if lr <= 0.0:
            raise ValueError('Invalid learning rate: {}'.format(lr))
        if eps < 0.0:
            raise ValueError('Invalid epsilon value: {}'.format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError('Invalid beta parameter at index 0: {}'.format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError('Invalid beta parameter at index 1: {}'.format(betas[1]))
        if weight_decay < 0:
            raise ValueError('Invalid weight_decay value: {}'.format(weight_decay))
        if clamp_value < 0.0:
            raise ValueError('Invalid clamp value: {}'.format(clamp_value))
        # a non-positive chunk size would either crash inside range() or skip the update loop entirely,
        # leaving an uninitialized output buffer to be applied as the parameter update
        if update_chunk_size <= 0:
            raise ValueError('Invalid update_chunk_size: {}'.format(update_chunk_size))
        # a negative max norm makes clip_grad_norm_ multiply gradients by a negative coefficient
        if max_grad_norm is not None and max_grad_norm < 0.0:
            raise ValueError('Invalid max_grad_norm: {}'.format(max_grad_norm))

        self.clamp_value = clamp_value
        self.bias_correction = bias_correction
        self.reuse_grad_buffers = reuse_grad_buffers
        self.update_chunk_size = update_chunk_size
        self.max_grad_norm = max_grad_norm

        super(CPULAMB8Bit, self).__init__(
            'cpu-lamb', params, lr, betas, eps, weight_decay, optim_bits=8, min_8bit_size=min_8bit_size, args=None,
            percentile_clipping=100, block_wise=4096, max_unorm=0)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step, clipping gradients by the global norm first if configured.

        :param closure: optional callable that re-evaluates the model and returns the loss
        :returns: the value returned by ``closure``, or None if no closure was given
        """
        loss = None
        if closure is not None:
            # The closure recomputes gradients, so it must run *before* clipping; otherwise the clipping
            # would only touch stale gradients that the closure then overwrites.
            with torch.enable_grad():
                loss = closure()

        if self.max_grad_norm is not None:
            iter_params = (param for group in self.param_groups for param in group['params'])
            torch.nn.utils.clip_grad_norm_(iter_params, self.max_grad_norm)

        super().step()
        return loss

    @torch.no_grad()
    def init_state(self, group, p, gindex, pindex):
        """Allocate first/second moment statistics for parameter ``p``.

        Statistics are float32 when 32-bit mode is configured or when ``p`` has fewer than ``min_8bit_size``
        (or 4096) elements. Otherwise they are uint8 codes plus one float32 absmax per block of
        ``config['block_wise']`` elements, with the shared 'dynamic'/'udynamic' quantization maps.
        """
        config = self.get_config(gindex, pindex, group)
        assert config['percentile_clipping'] == 100, "percentile clipping is not implemented on CPU"
        assert config['max_unorm'] == 0

        if config['optim_bits'] == 32:
            dtype = torch.float32
        elif config['optim_bits'] == 8:
            dtype = torch.uint8
        else:
            raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}')

        if p.numel() < config['min_8bit_size']:
            dtype = torch.float32

        state = self.state[p]
        state['step'] = 0

        if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
            # tensors smaller than one quantization block (block_wise=4096) are kept in full precision
            state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32,
                                               device=p.device)
            state['state2'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32,
                                               device=p.device)
        elif dtype == torch.uint8:
            # lazily build the shared quantization maps and keep them on the parameter's device
            if 'dynamic' not in self.name2qmap:
                self.fill_qmap()
            self.name2qmap['dynamic'] = self.name2qmap['dynamic'].to(p.device)
            self.name2qmap['udynamic'] = self.name2qmap['udynamic'].to(p.device)

            n = p.numel()
            blocks = (n - 1) // config['block_wise'] + 1  # ceil(n / block_wise)

            state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8,
                                               device=p.device)
            state['qmap1'] = self.name2qmap['dynamic']

            state['state2'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8,
                                               device=p.device)
            state['qmap2'] = self.name2qmap['udynamic']

            state['absmax1'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
            state['absmax2'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)

    @torch.no_grad()
    def update_step(self, group: Dict[str, Any], p: torch.Tensor, gindex: int, pindex: int):
        """Apply one LAMB update to ``p`` in place, computing statistics on host copies of ``p`` and its grad."""
        state = self.state[p]
        config = self.get_config(gindex, pindex, group)

        # .cpu() is a no-op if parameters are already on CPU
        p_cpu, grad_cpu = p.cpu(), p.grad.cpu()

        step = state['step'] = state['step'] + 1
        beta1, beta2 = group['betas']

        param_delta = self._update_moments_and_compute_delta(
            state, config, p_cpu, grad_cpu, beta1, beta2, group['eps'], group['weight_decay']
        )
        del grad_cpu  # grad_cpu is no longer needed and may be modified if self.reuse_grad_buffers

        # LAMB trust ratio: ||w|| (clamped) / ||update||, falling back to 1 when either norm is zero
        step_norm = torch.norm(param_delta)
        weight_norm = p_cpu.norm().clamp(0, self.clamp_value)
        trust_ratio = weight_norm / step_norm if weight_norm != 0 and step_norm != 0 else 1.0
        state['weight_norm'], state['step_norm'], state['trust_ratio'] = weight_norm, step_norm, trust_ratio

        # Apply bias to lr to avoid broadcast.
        bias_correction = math.sqrt(1 - beta2 ** step) / (1 - beta1 ** step) if self.bias_correction else 1
        step_size = group['lr'] * bias_correction
        p.data.add_(param_delta.to(p.device), alpha=-step_size * trust_ratio)

    def _update_moments_and_compute_delta(
        self, state: Dict, config: Dict,
        p_cpu: torch.Tensor, grad_cpu: torch.Tensor,
        beta1: float, beta2: float, eps: float, weight_decay: float
    ) -> torch.Tensor:
        """Update both moment estimates in ``state`` and return the un-scaled Adam-style update.

        The returned tensor is ``exp_avg / (sqrt(exp_avg_sq) + eps) + weight_decay * p``. It may alias
        ``grad_cpu`` when ``self.reuse_grad_buffers`` is set.
        """
        block_size = config['block_wise']
        # Chunks must begin on quantization-block boundaries so that each chunk's local blocks line up with
        # the stored absmax entries; round the chunk size down to a whole number of blocks (at least one).
        chunk_size = max(block_size, self.update_chunk_size // block_size * block_size)

        if state['state1'].dtype != torch.uint8:
            # not quantized: update normally
            exp_avg, exp_avg_sq = state['state1'], state['state2']
            exp_avg.mul_(beta1).add_(grad_cpu, alpha=1 - beta1)
            exp_avg_sq.mul_(beta2).addcmul_(grad_cpu, grad_cpu, value=1 - beta2)

            sqrt_out = grad_cpu if self.reuse_grad_buffers else None
            _denominator = torch.sqrt(exp_avg_sq, out=sqrt_out).add_(eps)
            param_delta = torch.div(exp_avg, _denominator, out=_denominator)
            if weight_decay != 0:
                param_delta.add_(p_cpu, alpha=weight_decay)
            return param_delta

        elif p_cpu.numel() <= chunk_size:
            # quantized tensor within chunk size
            exp_avg = dequantize_blockwise(
                state['state1'], (state['absmax1'], state['qmap1']), blocksize=block_size
            )
            exp_avg_sq = dequantize_blockwise(
                state['state2'], (state['absmax2'], state['qmap2']), blocksize=block_size
            )
            exp_avg.mul_(beta1).add_(grad_cpu, alpha=1 - beta1)
            exp_avg_sq.mul_(beta2).addcmul_(grad_cpu, grad_cpu, value=1 - beta2)

            quantize_blockwise(exp_avg, state['qmap1'], state['absmax1'], out=state['state1'])
            quantize_blockwise(exp_avg_sq, state['qmap2'], state['absmax2'], out=state['state2'])
            # note: quantize_blockwise also modifies qmap and absmax in-place

            param_delta = exp_avg.div_(exp_avg_sq.sqrt_().add_(eps))
            # note: this changes statistics in-place, but it's okay b/c we saved quantized version

            if weight_decay != 0:
                param_delta.add_(p_cpu, alpha=weight_decay)
            return param_delta

        else:
            # very large quantized tensor, compute updates in chunks to save RAM
            flat_p, flat_grad, flat_state1, flat_state2 = (
                tensor.view(-1) for tensor in (p_cpu, grad_cpu, state['state1'], state['state2'])
            )
            output_buffer = flat_grad if self.reuse_grad_buffers else torch.empty_like(flat_grad)

            for chunk_start in range(0, len(flat_p), chunk_size):
                chunk = slice(chunk_start, chunk_start + chunk_size)
                # chunk_size is a multiple of block_size, so these block indices match the chunk exactly;
                # the end index may run past the last block and is clamped by slicing
                chunk_blocks = slice(chunk_start // block_size, (chunk_start + chunk_size) // block_size)

                chunk_p, chunk_grad = flat_p[chunk], flat_grad[chunk]
                chunk_state1, chunk_state2 = flat_state1[chunk], flat_state2[chunk]
                chunk_absmax1, chunk_absmax2 = state['absmax1'][chunk_blocks], state['absmax2'][chunk_blocks]
                if chunk_state1.storage_offset() != 0:
                    chunk_state1, chunk_state2, chunk_absmax1, chunk_absmax2 = map(
                        torch.clone, (chunk_state1, chunk_state2, chunk_absmax1, chunk_absmax2)
                    )  # clone chunks to ensure that tensors do not have offsets

                exp_avg_chunk = dequantize_blockwise(
                    chunk_state1, (chunk_absmax1, state['qmap1']), blocksize=block_size
                )
                exp_avg_sq_chunk = dequantize_blockwise(
                    chunk_state2, (chunk_absmax2, state['qmap2']), blocksize=block_size
                )

                exp_avg_chunk.mul_(beta1).add_(chunk_grad, alpha=1 - beta1)
                exp_avg_sq_chunk.mul_(beta2).addcmul_(chunk_grad, chunk_grad, value=1 - beta2)

                # note: output_buffer cannot be modified until this line because it shares memory with grad_cpu
                del chunk_grad

                flat_state1[chunk], (state['absmax1'][chunk_blocks], state['qmap1']) = quantize_blockwise(
                    exp_avg_chunk, state['qmap1'], chunk_absmax1, out=chunk_state1
                )
                flat_state2[chunk], (state['absmax2'][chunk_blocks], state['qmap2']) = quantize_blockwise(
                    exp_avg_sq_chunk, state['qmap2'], chunk_absmax2, out=chunk_state2
                )
                # note: we need to explicitly assign new quantized tensors because of cloning earlier

                torch.div(exp_avg_chunk, exp_avg_sq_chunk.sqrt_().add_(eps), out=output_buffer[chunk])
                # note: this changes statistics in-place, but it's okay b/c we saved quantized version

                if weight_decay != 0:
                    output_buffer[chunk].add_(flat_p[chunk], alpha=weight_decay)

            param_delta = output_buffer.view_as(grad_cpu)
            return param_delta
|