Michael Diskin, 4 years ago
Commit 9df7bf3c4b
1 file changed, 18 additions and 1 deletion
examples/albert/optim.py (+18, -1)

@@ -2,10 +2,12 @@ import math
 
 import torch
 from torch.optim.optimizer import Optimizer
+import logging
 
 
 __all__ = ('Lamb',)
 
+logger = logging.getLogger(__name__)
 
 class Lamb(Optimizer):
 
@@ -55,20 +57,24 @@ class Lamb(Optimizer):
         loss = None
         if closure is not None:
             loss = closure()
+        logger.info("UUU")
 
         for group in self.param_groups:
             for p in group['params']:
                 if p.grad is None:
                     continue
                 grad = p.grad.data
+                logger.info("00")
                 if grad.is_sparse:
                     msg = (
                         'Lamb does not support sparse gradients, '
                         'please consider SparseAdam instead'
                     )
                     raise RuntimeError(msg)
+                logger.info("1")
 
                 state = self.state[p]
+                logger.info("2")
 
                 # State initialization
                 if len(state) == 0:
@@ -78,9 +84,11 @@ class Lamb(Optimizer):
                         p, )
                     # Exponential moving average of squared gradient values
                     state['exp_avg_sq'] = torch.zeros_like(p)
+                logger.info("3")
 
                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                 beta1, beta2 = group['betas']
+                logger.info("4")
 
                 state['step'] += 1
 
@@ -88,7 +96,9 @@ class Lamb(Optimizer):
                 # m_t
                 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                 # v_t
+                logger.info("5")
                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+                logger.info("6")
 
                 # Paper v3 does not use debiasing.
                 if self.debias:
@@ -96,20 +106,27 @@ class Lamb(Optimizer):
                     bias_correction /= 1 - beta1 ** state['step']
                 else:
                     bias_correction = 1
+                logger.info("7")
 
                 # Apply bias to lr to avoid broadcast.
                 step_size = group['lr'] * bias_correction
+                logger.info("8")
 
                 weight_norm = torch.norm(p.data).clamp(0, self.clamp_value)
-
+                logger.info("9")
                 adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])
+                logger.info("10")
                 adam_step.add_(p.data, alpha=group['weight_decay'])
+                logger.info("11")
 
                 adam_norm = torch.norm(adam_step).clamp_min(0.001)
+                logger.info("12")
                 trust_ratio = weight_norm / adam_norm
                 state['weight_norm'] = weight_norm
                 state['adam_norm'] = adam_norm
                 state['trust_ratio'] = trust_ratio
+                logger.info("13")
+
                 if self.adam:
                     trust_ratio = 1