```python
import torch
from torch_optimizer import Lamb


class LambWithGradientClipping(Lamb):
    """A version of LAMB that clips gradients based on their norm."""

    def __init__(self, *args, max_grad_norm: float, **kwargs):
        self.max_grad_norm = max_grad_norm
        super().__init__(*args, **kwargs)

    def step(self, *args, **kwargs):
        # Gather every parameter across all parameter groups.
        iter_params = (param for group in self.param_groups for param in group['params'])
        # Clip gradient norms in place before delegating to the LAMB update.
        torch.nn.utils.clip_grad_norm_(iter_params, self.max_grad_norm)
        return super().step(*args, **kwargs)
```
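For reference, here is a minimal usage sketch; the toy model, batch, and hyperparameter values (`lr`, `max_grad_norm`) are illustrative assumptions, not part of the original.

```python
import torch
import torch.nn.functional as F

# Hypothetical setup: a toy linear classifier, chosen only for illustration.
model = torch.nn.Linear(10, 2)
optimizer = LambWithGradientClipping(model.parameters(), lr=1e-3, max_grad_norm=1.0)

# A random batch of 8 examples with 10 features and 2 classes.
inputs = torch.randn(8, 10)
targets = torch.randint(0, 2, (8,))

optimizer.zero_grad()
loss = F.cross_entropy(model(inputs), targets)
loss.backward()
optimizer.step()  # gradient norms are clipped to max_grad_norm before the LAMB update
```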