|
@@ -100,17 +100,9 @@ class Lamb(Optimizer):
|
|
|
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
|
|
|
logger.info("6")
|
|
|
|
|
|
- # Paper v3 does not use debiasing.
|
|
|
- if self.debias:
|
|
|
- bias_correction = math.sqrt(1 - beta2 ** state['step'])
|
|
|
- bias_correction /= 1 - beta1 ** state['step']
|
|
|
- else:
|
|
|
- bias_correction = 1
|
|
|
- logger.info("7")
|
|
|
|
|
|
# Apply bias to lr to avoid broadcast.
|
|
|
- step_size = group['lr'] * bias_correction
|
|
|
- logger.info("8")
|
|
|
+ step_size = group['lr']
|
|
|
|
|
|
weight_norm = torch.norm(p.data).clamp(0, self.clamp_value)
|
|
|
logger.info("9")
|