grad_scaler.py

import contextlib
from typing import Dict, Optional

import torch
from torch.cuda.amp import GradScaler as TorchGradScaler
from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state, OptState
from torch.optim import Optimizer as TorchOptimizer

import hivemind
from hivemind.utils.logging import get_logger

logger = get_logger(__name__)


class GradScaler(TorchGradScaler):
    """
    A thin wrapper over pytorch GradScaler that supports hivemind-style training with CollaborativeOptimizer, namely:
    - bypass .unscale_ and .update calls in order to accumulate gradients over several steps
    - limit increasing gradient scale to only immediately after global optimizer steps
    - allow training with some or all master parameters in fp16
    See the usage sketch at the bottom of this file for the intended call pattern.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # set to True only while inside the running_global_step() context
        self._is_running_global_step = False
        # ids of optimizers whose inf/nan bookkeeping should be discarded on the next local update()
        self._optimizer_states_to_reset = set()

    @contextlib.contextmanager
    def running_global_step(self):
        """Inside this context, unscale_/step/update perform a real global step instead of being bypassed"""
        was_running, self._is_running_global_step = self._is_running_global_step, True
        try:
            yield
        finally:
            self._is_running_global_step = was_running

    def unscale_(self, optimizer: TorchOptimizer) -> bool:
        assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
        if self._is_running_global_step:
            # global step: actually unscale gradients and share the resulting state with the inner optimizer
            super().unscale_(optimizer)
            self._per_optimizer_states[id(optimizer.opt)] = self._per_optimizer_states[id(optimizer)]
            return True
        else:
            # local step: only record inf/nan statistics and schedule this optimizer's scaler state for reset
            self._check_inf_per_device(optimizer)
            self._optimizer_states_to_reset.add(id(optimizer))
            return False

    def step(self, optimizer: TorchOptimizer, *args, **kwargs) -> bool:
        if self._is_running_global_step:
            assert self._per_optimizer_states[id(optimizer.opt)]["stage"] == OptState.UNSCALED, \
                "InternalError: Optimizer should have called .unscale internally before invoking grad_scaler.step."
            # global step: apply the inner optimizer only if all gradients are finite
            if self.are_grads_finite(optimizer, use_cached=True):
                super().step(optimizer.opt, *args, **kwargs)
            else:
                logger.warning("Skipping global step due to gradient over/underflow")
            return True
        else:
            # local step: delegate to the base scaler and schedule this optimizer's scaler state for reset
            super().step(optimizer)
            self._optimizer_states_to_reset.add(id(optimizer))
            return False

    def update(self, new_scale: Optional[float] = None) -> bool:
        total_infs = 0
        for optimizer_state in self._per_optimizer_states.values():
            total_infs += sum(v.item() for v in optimizer_state["found_inf_per_device"].values())

        if self._is_running_global_step or total_infs != 0:
            # note: we update either during actual optimizer step or if we need to reduce scale due to NaN
            super().update(new_scale)
            return True
        else:
            # no global step and no inf/nan: keep the current scale, only reset the deferred per-optimizer states
            for opt_id in self._optimizer_states_to_reset:
                self._per_optimizer_states[opt_id] = _refresh_per_optimizer_state()
            self._optimizer_states_to_reset.clear()
            return False

    def _unscale_grads_(
        self, optimizer: TorchOptimizer, inv_scale: torch.Tensor, found_inf: torch.Tensor, allow_fp16: bool
    ) -> Dict[torch.device, torch.Tensor]:
        # note: the code below sets allow_fp16=True to allow training with master weights (partially) in fp16
        # inspired by: https://github.com/facebookresearch/fairscale/blob/945b9666/fairscale/optim/grad_scaler.py
        return super()._unscale_grads_(optimizer, inv_scale, found_inf, allow_fp16=True)

    def are_grads_finite(self, optimizer: TorchOptimizer, use_cached: bool = False) -> bool:
        """Check whether gradients are finite; if use_cached, reuse inf/nan counters from the last unscale check"""
        opt_dict = self._found_inf_per_device(optimizer) if use_cached else self._check_inf_per_device(optimizer)
        return not sum(v.item() for v in opt_dict.values())


class HivemindGradScaler(GradScaler):
    def __init__(self, *args, **kwargs):
        logger.warning("HivemindGradScaler was renamed to hivemind.GradScaler, this reference will be removed in v1.1")
        super().__init__(*args, **kwargs)
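

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal sketch of the intended call pattern, assuming a hypothetical `model`, iterable of
# `batches`, and an already-constructed `opt` (a hivemind.Optimizer or DecentralizedOptimizerBase,
# as required by the assert in unscale_). The optimizer is expected to re-enter unscale_/step/update
# inside running_global_step() whenever it performs an actual collaboration-wide step; outside of
# that context, the calls below are bypassed so gradients simply accumulate.
def _example_training_loop(model, opt, batches):  # hypothetical helper, for illustration only
    scaler = GradScaler()
    for batch in batches:
        with torch.cuda.amp.autocast():
            loss = model(batch)
        scaler.scale(loss).backward()
        scaler.unscale_(opt)  # outside a global step: only records inf/nan statistics, returns False
        scaler.step(opt)      # outside a global step: defers the real parameter update, returns False
        scaler.update()       # the scale may only grow immediately after a genuine global step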