grad_scaler.py

import contextlib
import threading
from copy import deepcopy
from typing import Dict, Optional

import torch
from torch.cuda.amp import GradScaler as TorchGradScaler
from torch.cuda.amp.grad_scaler import OptState, _refresh_per_optimizer_state
from torch.optim import Optimizer as TorchOptimizer

import hivemind
from hivemind.utils.logging import get_logger

logger = get_logger(__name__)

class GradScaler(TorchGradScaler):
    """
    A wrapper over pytorch GradScaler made specifically for training hivemind.Optimizer with reuse_grad_buffers=True.

    :note: if not using reuse_grad_buffers=True, one can and *should* train normally without this class, e.g. using
      standard PyTorch AMP or Apex. This custom GradScaler is more memory-efficient, but requires custom training code.

    hivemind.GradScaler makes 3 modifications to the regular PyTorch AMP:

    - bypass .unscale_ and .update calls in order to accumulate gradients over several steps
    - limit increasing gradient scale to only immediately after global optimizer steps
    - allow training with some or all master parameters in float16

    :note: The above modifications will be enabled automatically. One can (and should) use hivemind.GradScaler exactly
      as regular ``torch.amp.GradScaler`` (see the usage sketch at the end of this file).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._is_running_global_step = False
        self._is_ready_to_update = False
        self._optimizer_states_to_reset = set()
        self._lock = threading.RLock()

    @contextlib.contextmanager
    def running_global_step(self):
        with self._lock:
            was_running, self._is_running_global_step = self._is_running_global_step, True
            try:
                yield
            finally:
                self._is_running_global_step = was_running
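
    # A rough sketch (illustrative only, inferred from the methods below rather than copied from
    # hivemind.Optimizer) of how a global optimizer step is expected to drive this scaler:
    #
    #     with grad_scaler.running_global_step():
    #         grad_scaler.unscale_(hivemind_optimizer)      # unscale + copy state to the inner optimizer
    #         grad_scaler.step(hivemind_optimizer.opt)      # real parameter update on the wrapped torch optimizer
    #         grad_scaler.update()                          # the scale may only increase right after this point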

    def unscale_(self, optimizer: TorchOptimizer) -> bool:
        with self._lock:
            assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
            if self._is_running_global_step:
                # global step: actually unscale gradients and copy the "already unscaled" state to the
                # inner optimizer (optimizer.opt), which is what grad_scaler.step will later operate on
                super().unscale_(optimizer)
                self._per_optimizer_states[id(optimizer.opt)] = deepcopy(self._per_optimizer_states[id(optimizer)])
                return True
            else:
                # regular step: only check gradients for inf/nan; the actual unscaling is bypassed
                # so that gradients can keep accumulating over several steps
                self._check_inf_per_device(optimizer)
                self._optimizer_states_to_reset.add(id(optimizer))
                return False

    def step(self, optimizer: TorchOptimizer, *args, **kwargs) -> bool:
        if self._is_running_global_step:
            # global step: perform the actual parameter update on the inner (non-hivemind) optimizer
            with self._lock:
                if self._is_ready_to_update:
                    logger.warning("Please call grad_scaler.update() after each step.")
                assert not isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
                assert (
                    self._per_optimizer_states[id(optimizer)]["stage"] == OptState.UNSCALED
                ), "InternalError: Optimizer should have called .unscale internally before invoking grad_scaler.step."
                if self.are_grads_finite(optimizer, use_cached=True):
                    super().step(optimizer, *args, **kwargs)
                else:
                    logger.warning("Skipping global step due to gradient over/underflow")
                self._is_ready_to_update = True
                return True
        else:
            # regular step: hand control to hivemind.Optimizer.step, which typically accumulates gradients
            # until the collaboration reaches its target batch size and triggers a global step
            assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
            super().step(optimizer)
            self._optimizer_states_to_reset.add(id(optimizer))
            return False

    def update(self, new_scale: Optional[float] = None) -> bool:
        with self._lock:
            total_infs = 0
            for optimizer_state in self._per_optimizer_states.values():
                total_infs += sum(v.item() for v in optimizer_state["found_inf_per_device"].values())

            if self._is_ready_to_update or total_infs != 0:
                # note: we update either during actual optimizer step or if we need to reduce scale due to NaN
                super().update(new_scale)
                self._is_ready_to_update = False
                return True
            else:
                # no global step has happened: discard the inf/nan checks cached since the last update
                for opt_id in self._optimizer_states_to_reset:
                    self._per_optimizer_states[opt_id] = _refresh_per_optimizer_state()
                self._optimizer_states_to_reset.clear()
                return False

    def _unscale_grads_(
        self, optimizer: TorchOptimizer, inv_scale: torch.Tensor, found_inf: torch.Tensor, allow_fp16: bool
    ) -> Dict[torch.device, torch.Tensor]:
        # note: the code below sets allow_fp16=True to allow training with master weights (partially) in fp16
        # inspired by: https://github.com/facebookresearch/fairscale/blob/945b9666/fairscale/optim/grad_scaler.py
        return super()._unscale_grads_(optimizer, inv_scale, found_inf, allow_fp16=True)

    def are_grads_finite(self, optimizer: TorchOptimizer, use_cached: bool = False) -> bool:
        opt_dict = self._found_inf_per_device(optimizer) if use_cached else self._check_inf_per_device(optimizer)
        return not sum(v.item() for v in opt_dict.values())

class HivemindGradScaler(GradScaler):
    def __init__(self, *args, **kwargs):
        logger.warning("HivemindGradScaler was renamed to hivemind.GradScaler, this reference will be removed in v1.1")
        super().__init__(*args, **kwargs)
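
# The block below is an illustrative usage sketch rather than part of the original module: the model, run_id,
# and all hyperparameters are placeholders, and a CUDA device plus a reachable DHT are assumed. From the user's
# perspective, hivemind.GradScaler is driven exactly like a regular torch.cuda.amp.GradScaler.
if __name__ == "__main__":
    model = torch.nn.Linear(16, 1).cuda()
    dht = hivemind.DHT(start=True)
    opt = hivemind.Optimizer(
        dht=dht,
        run_id="demo_run",  # placeholder experiment name shared by all collaborating peers
        batch_size_per_step=32,  # samples processed locally per opt.step() call
        target_batch_size=4096,  # global batch size that triggers a collaborative (global) step
        optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
        reuse_grad_buffers=True,  # required for hivemind.GradScaler; do NOT call zero_grad() in this mode
    )
    scaler = GradScaler()

    for _ in range(10):
        batch = torch.randn(32, 16, device="cuda")
        with torch.cuda.amp.autocast():
            loss = model(batch).pow(2).mean()
        scaler.scale(loss).backward()
        scaler.step(opt)  # accumulates gradients locally; the real update happens on global steps
        scaler.update()  # bypassed between global steps, adjusts the loss scale right after them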