grad_scaler.py

import contextlib
from typing import Dict, Optional

import torch
from torch.cuda.amp import GradScaler as TorchGradScaler
from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state
from torch.optim import Optimizer as TorchOptimizer

from hivemind.optim import DecentralizedOptimizerBase, Optimizer
from hivemind.utils.logging import get_logger

logger = get_logger(__name__)


class GradScaler(TorchGradScaler):
    """
    A thin wrapper over the PyTorch GradScaler that supports hivemind-style training with CollaborativeOptimizer, namely:

    - bypass .unscale_ and .update calls in order to accumulate gradients over several steps
    - limit increasing the gradient scale to immediately after global optimizer steps
    - allow training with some or all master parameters in fp16

    See the illustrative usage sketch at the bottom of this file.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._is_running_global_step = False
        self._optimizer_states_to_reset = set()

    @contextlib.contextmanager
    def running_global_step(self):
        """Mark the enclosed code as a global optimizer step, enabling the real unscale_/step/update logic."""
        was_running, self._is_running_global_step = self._is_running_global_step, True
        try:
            yield
        finally:
            self._is_running_global_step = was_running

    def unscale_(self, optimizer: TorchOptimizer) -> bool:
        assert isinstance(optimizer, (Optimizer, DecentralizedOptimizerBase))
        if self._is_running_global_step:
            super().unscale_(optimizer.opt)
            return True
        else:
            # local (accumulation) step: only check for inf/nan and defer the actual unscaling
            self._check_inf_per_device(optimizer.opt)
            self._optimizer_states_to_reset.add(id(optimizer))
            return False

    def step(self, optimizer: TorchOptimizer, *args, **kwargs) -> bool:
        if self._is_running_global_step:
            if self.are_grads_finite(optimizer):
                super().step(optimizer, *args, **kwargs)
            else:
                logger.warning("Skipping global step due to gradient over/underflow")
            return True
        else:
            super().step(optimizer)
            self._optimizer_states_to_reset.add(id(optimizer))
            return False

    def update(self, new_scale: Optional[float] = None) -> bool:
        total_infs = 0
        for optimizer_state in self._per_optimizer_states.values():
            total_infs += sum(v.item() for v in optimizer_state["found_inf_per_device"].values())

        if self._is_running_global_step or total_infs != 0:
            # note: we update either during the actual optimizer step or if we need to reduce scale due to NaN
            super().update(new_scale)
            return True
        else:
            # reset the deferred per-optimizer states so the next accumulation round starts clean
            for opt_id in self._optimizer_states_to_reset:
                self._per_optimizer_states[opt_id] = _refresh_per_optimizer_state()
            self._optimizer_states_to_reset.clear()
            return False

    def _unscale_grads_(
        self, optimizer: TorchOptimizer, inv_scale: torch.Tensor, found_inf: torch.Tensor, allow_fp16: bool
    ) -> Dict[torch.device, torch.Tensor]:
        # note: the code below sets allow_fp16=True to allow training with master weights (partially) in fp16
        # inspired by: https://github.com/facebookresearch/fairscale/blob/945b9666/fairscale/optim/grad_scaler.py
        return super()._unscale_grads_(optimizer, inv_scale, found_inf, allow_fp16=True)

    def are_grads_finite(self, optimizer: TorchOptimizer) -> bool:
        return not sum(v.item() for v in self._check_inf_per_device(optimizer).values())


class HivemindGradScaler(GradScaler):
    """Deprecated alias of hivemind.GradScaler, kept for backward compatibility."""

    def __init__(self, *args, **kwargs):
        logger.warning("HivemindGradScaler was renamed to hivemind.GradScaler, this reference will be removed in v1.1")
        super().__init__(*args, **kwargs)
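

if __name__ == "__main__":
    # --- Minimal usage sketch (illustrative only, not part of the library API) -----------------
    # Shows one way this GradScaler can slot into a standard torch.cuda.amp training loop with
    # hivemind.Optimizer, mirroring regular PyTorch AMP usage. The model, the hyperparameters
    # (run_id, target_batch_size, batch_size_per_step) and the exact unscale_/step calling order
    # are assumptions for this sketch; consult the hivemind documentation for the canonical
    # recipe for your version. Requires a CUDA device and a running DHT.
    from functools import partial

    import hivemind
    import torch.nn as nn

    model = nn.Linear(16, 1).cuda()
    dht = hivemind.DHT(start=True)
    opt = hivemind.Optimizer(
        dht=dht,
        run_id="grad_scaler_demo",    # experiment name, assumed for the sketch
        target_batch_size=4096,       # global batch size that triggers a global step (assumed)
        batch_size_per_step=32,       # samples processed per local step (assumed)
        optimizer=partial(torch.optim.Adam, lr=1e-3),
        params=model.parameters(),
        verbose=True,
    )
    scaler = GradScaler()

    for _ in range(10):
        batch = torch.randn(32, 16, device="cuda")
        with torch.cuda.amp.autocast():
            loss = model(batch).pow(2).mean()
        scaler.scale(loss).backward()
        # During local (accumulation) steps, the three calls below merely record state and return
        # False; the real unscaling, optimizer step and scale update only happen once the run
        # reaches target_batch_size and the scaler is used inside running_global_step().
        scaler.unscale_(opt)
        scaler.step(opt)
        scaler.update()
        opt.zero_grad()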