@@ -3,50 +3,17 @@ import math
 import torch
 from torch.optim.optimizer import Optimizer
 
-from .types import Betas2, OptFloat, OptLossClosure, Params
 
 __all__ = ('Lamb',)
 
 
 class Lamb(Optimizer):
-    r"""Implements Lamb algorithm.
-
-    It has been proposed in `Large Batch Optimization for Deep Learning:
-    Training BERT in 76 minutes`__.
-
-    Arguments:
-        params: iterable of parameters to optimize or dicts defining
-            parameter groups
-        lr: learning rate (default: 1e-3)
-        betas: coefficients used for computing
-            running averages of gradient and its square (default: (0.9, 0.999))
-        eps: term added to the denominator to improve
-            numerical stability (default: 1e-8)
-        weight_decay: weight decay (L2 penalty) (default: 0)
-        clamp_value: clamp weight_norm in (0,clamp_value) (default: 10)
-            set to a high value to avoid it (e.g 10e3)
-        adam: always use trust ratio = 1, which turns this
-            into Adam. Useful for comparison purposes. (default: False)
-        debias: debias adam by (1 - beta**step) (default: False)
-
-    Example:
-        >>> import torch_optimizer as optim
-        >>> optimizer = optim.Lamb(model.parameters(), lr=0.1)
-        >>> optimizer.zero_grad()
-        >>> loss_fn(model(input), target).backward()
-        >>> optimizer.step()
-
-    __ https://arxiv.org/abs/1904.00962
-
-    Note:
-        Reference code: https://github.com/cybertronai/pytorch-lamb
-    """
 
     def __init__(
         self,
-        params: Params,
+        params,
         lr: float = 1e-3,
-        betas: Betas2 = (0.9, 0.999),
+        betas = (0.9, 0.999),
         eps: float = 1e-6,
         weight_decay: float = 0,
         clamp_value: float = 10,
@@ -79,7 +46,7 @@ class Lamb(Optimizer):
 
         super(Lamb, self).__init__(params, defaults)
 
-    def step(self, closure: OptLossClosure = None) -> OptFloat:
+    def step(self, closure = None):
         r"""Performs a single optimization step.
 
         Arguments:
|