clipped_lamb.py

import torch
from torch_optimizer import Lamb


class LambWithGradientClipping(Lamb):
    """A version of LAMB that clips gradients based on their norm."""

    def __init__(self, *args, max_grad_norm: float, **kwargs):
        self.max_grad_norm = max_grad_norm
        super().__init__(*args, **kwargs)

    def step(self, *args, **kwargs):
        # Clip the global gradient norm over all parameters in every
        # param group, then delegate to the regular LAMB update.
        iter_params = (param for group in self.param_groups for param in group['params'])
        torch.nn.utils.clip_grad_norm_(iter_params, self.max_grad_norm)
        return super().step(*args, **kwargs)
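A minimal usage sketch follows. The toy linear model, batch, learning rate, and max_grad_norm value are illustrative assumptions rather than part of the original file, and the snippet presumes the torch-optimizer package that provides Lamb is installed.

# Usage sketch: the model, data, lr, and max_grad_norm below are
# illustrative assumptions, not values from the original file.
model = torch.nn.Linear(16, 2)
optimizer = LambWithGradientClipping(
    model.parameters(), lr=1e-3, max_grad_norm=1.0
)

inputs = torch.randn(8, 16)
targets = torch.randint(0, 2, (8,))
loss = torch.nn.functional.cross_entropy(model(inputs), targets)

optimizer.zero_grad()
loss.backward()
optimizer.step()  # clips the global grad norm, then applies the LAMB update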