@@ -57,24 +57,20 @@ class Lamb(Optimizer):
         loss = None
         if closure is not None:
             loss = closure()
-        logger.info("UUU")
 
         for group in self.param_groups:
             for p in group['params']:
                 if p.grad is None:
                     continue
                 grad = p.grad.data
-                logger.info("00")
                 if grad.is_sparse:
                     msg = (
                         'Lamb does not support sparse gradients, '
                         'please consider SparseAdam instead'
                     )
                     raise RuntimeError(msg)
-                logger.info("1")
 
                 state = self.state[p]
-                logger.info("2")
 
                 # State initialization
                 if len(state) == 0:
@@ -84,11 +80,9 @@ class Lamb(Optimizer):
                         p, )
                     # Exponential moving average of squared gradient values
                     state['exp_avg_sq'] = torch.zeros_like(p)
-                logger.info("3")
 
                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                 beta1, beta2 = group['betas']
-                logger.info("4")
 
                 state['step'] += 1
 
@@ -96,28 +90,21 @@ class Lamb(Optimizer):
                 # m_t
                 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                 # v_t
-                logger.info("5")
                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
-                logger.info("6")
 
 
                 # Apply bias to lr to avoid broadcast.
                 step_size = group['lr']
 
                 weight_norm = torch.norm(p.data).clamp(0, self.clamp_value)
-                logger.info("9")
                 adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])
-                logger.info("10")
                 adam_step.add_(p.data, alpha=group['weight_decay'])
-                logger.info("11")
 
                 adam_norm = torch.norm(adam_step).clamp_min(0.001)
-                logger.info("12")
                 trust_ratio = weight_norm / adam_norm
                 state['weight_norm'] = weight_norm
                 state['adam_norm'] = adam_norm
                 state['trust_ratio'] = trust_ratio
-                logger.info("13")
 
                 if self.adam:
                     trust_ratio = 1
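
For orientation, the computation these hunks leave untouched is LAMB's layer-wise trust-ratio scaling. The snippet below is a minimal, self-contained sketch of that step, not code from this patch: the helper name lamb_trust_ratio, its standalone signature, and the clamp_value default of 10 are illustrative assumptions, while the tensor operations mirror the context lines above.

import torch

def lamb_trust_ratio(p, exp_avg, exp_avg_sq, eps, weight_decay, clamp_value=10.0):
    # Layer-wise weight norm, clamped to [0, clamp_value] as in the code above
    # (the default of 10 here is an assumption, not read from this patch).
    weight_norm = torch.norm(p).clamp(0, clamp_value)
    # Adam-style step: first moment divided by the root of the second moment,
    # with weight decay added directly onto the step, as above.
    adam_step = exp_avg / exp_avg_sq.sqrt().add(eps)
    adam_step = adam_step.add(p, alpha=weight_decay)
    # Floor the step norm, mirroring clamp_min(0.001), so the ratio stays finite.
    adam_norm = torch.norm(adam_step).clamp_min(0.001)
    # LAMB scales this layer's learning rate by weight_norm / adam_norm.
    trust_ratio = weight_norm / adam_norm
    return adam_step, trust_ratio

When self.adam is set, the code above pins trust_ratio to 1, which reduces the update to plain Adam behaviour for that layer.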