@@ -2,10 +2,12 @@ import math
 
 import torch
 
 from torch.optim.optimizer import Optimizer
+import logging
 
 
 __all__ = ('Lamb',)
 
+logger = logging.getLogger(__name__)
 
 class Lamb(Optimizer):
@@ -55,20 +57,24 @@ class Lamb(Optimizer):
         loss = None
         if closure is not None:
             loss = closure()
 
+        logger.info("UUU")
         for group in self.param_groups:
             for p in group['params']:
                 if p.grad is None:
                     continue
                 grad = p.grad.data
+                logger.info("00")
                 if grad.is_sparse:
                     msg = (
                         'Lamb does not support sparse gradients, '
                         'please consider SparseAdam instead'
                     )
                     raise RuntimeError(msg)
 
+                logger.info("1")
                 state = self.state[p]
 
+                logger.info("2")
                 # State initialization
                 if len(state) == 0:
@@ -78,9 +84,11 @@ class Lamb(Optimizer):
                         p, )
                     # Exponential moving average of squared gradient values
                     state['exp_avg_sq'] = torch.zeros_like(p)
 
+                logger.info("3")
                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                 beta1, beta2 = group['betas']
 
+                logger.info("4")
                 state['step'] += 1
 
@@ -88,7 +96,9 @@ class Lamb(Optimizer):
                 # m_t
                 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                 # v_t
+                logger.info("5")
                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
 
+                logger.info("6")
                 # Paper v3 does not use debiasing.
                 if self.debias:
@@ -96,20 +106,27 @@ class Lamb(Optimizer):
                     bias_correction /= 1 - beta1 ** state['step']
                 else:
                     bias_correction = 1
 
+                logger.info("7")
                 # Apply bias to lr to avoid broadcast.
                 step_size = group['lr'] * bias_correction
 
+                logger.info("8")
                 weight_norm = torch.norm(p.data).clamp(0, self.clamp_value)
-
+                logger.info("9")
                 adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])
+                logger.info("10")
                 adam_step.add_(p.data, alpha=group['weight_decay'])
 
+                logger.info("11")
                 adam_norm = torch.norm(adam_step).clamp_min(0.001)
+                logger.info("12")
                 trust_ratio = weight_norm / adam_norm
                 state['weight_norm'] = weight_norm
                 state['adam_norm'] = adam_norm
                 state['trust_ratio'] = trust_ratio
 
+                logger.info("13")
+
                 if self.adam:
                     trust_ratio = 1
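
The patch routes its trace output through logging.getLogger(__name__) rather than print, so the new messages stay silent until the logging system is configured. A minimal sketch of how to surface them, assuming the patched class is importable as torch_optimizer.Lamb (adjust the import to wherever this Lamb implementation actually lives):

import logging

import torch
from torch_optimizer import Lamb  # assumed import path for the patched module

# Emit INFO-level records (including the new logger.info calls) to stderr.
logging.basicConfig(format='%(name)s %(levelname)s %(message)s', level=logging.INFO)

# Exercise one optimizer step so the trace messages fire: "UUU" once per
# step() call, then "00" through "13" once per parameter that has a gradient.
model = torch.nn.Linear(4, 2)
optimizer = Lamb(model.parameters(), lr=1e-3)
loss = model(torch.randn(8, 4)).sum()
loss.backward()
optimizer.step()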