Michael Diskin, 4 years ago
parent revision ec7dedd111
1 changed file with 0 additions and 13 deletions

+ 0 - 13   examples/albert/optim.py

@@ -57,24 +57,20 @@ class Lamb(Optimizer):
         loss = None
         if closure is not None:
             loss = closure()
-        logger.info("UUU")
 
         for group in self.param_groups:
             for p in group['params']:
                 if p.grad is None:
                     continue
                 grad = p.grad.data
-                logger.info("00")
                 if grad.is_sparse:
                     msg = (
                         'Lamb does not support sparse gradients, '
                         'please consider SparseAdam instead'
                     )
                     raise RuntimeError(msg)
-                logger.info("1")
 
                 state = self.state[p]
-                logger.info("2")
 
                 # State initialization
                 if len(state) == 0:
@@ -84,11 +80,9 @@ class Lamb(Optimizer):
                         p, )
                     # Exponential moving average of squared gradient values
                     state['exp_avg_sq'] = torch.zeros_like(p)
-                logger.info("3")
 
                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                 beta1, beta2 = group['betas']
-                logger.info("4")
 
                 state['step'] += 1
 
@@ -96,28 +90,21 @@ class Lamb(Optimizer):
                 # m_t
                 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                 # v_t
-                logger.info("5")
                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
-                logger.info("6")
 
 
                 # Apply bias to lr to avoid broadcast.
                 step_size = group['lr']
 
                 weight_norm = torch.norm(p.data).clamp(0, self.clamp_value)
-                logger.info("9")
                 adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])
-                logger.info("10")
                 adam_step.add_(p.data, alpha=group['weight_decay'])
-                logger.info("11")
 
                 adam_norm = torch.norm(adam_step).clamp_min(0.001)
-                logger.info("12")
                 trust_ratio = weight_norm / adam_norm
                 state['weight_norm'] = weight_norm
                 state['adam_norm'] = adam_norm
                 state['trust_ratio'] = trust_ratio
-                logger.info("13")
 
                 if self.adam:
                     trust_ratio = 1
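
For reference, a minimal usage sketch of the Lamb optimizer whose step() this commit cleans up. The import path below is hypothetical, and the constructor keywords are only inferred from the param-group keys visible in the hunk ('lr', 'betas', 'eps', 'weight_decay'); the real signature and defaults live in examples/albert/optim.py.

    # Minimal sketch, assuming Lamb follows the standard torch.optim.Optimizer
    # interface (as the diff context suggests) and accepts these keywords.
    import torch

    from examples.albert.optim import Lamb  # hypothetical import path

    model = torch.nn.Linear(16, 4)
    opt = Lamb(model.parameters(), lr=1e-3, betas=(0.9, 0.999),
               eps=1e-6, weight_decay=0.01)

    for _ in range(3):
        opt.zero_grad()
        loss = model(torch.randn(8, 16)).pow(2).mean()
        loss.backward()
        opt.step()  # runs the step() shown above, now without the debug logging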