grad_scaler.py

import contextlib
import threading
from copy import deepcopy
from typing import Dict, Optional

import torch
from torch.cuda.amp import GradScaler as TorchGradScaler
from torch.cuda.amp.grad_scaler import OptState, _refresh_per_optimizer_state
from torch.optim import Optimizer as TorchOptimizer

import hivemind
from hivemind.utils.logging import get_logger

logger = get_logger(__name__)

class GradScaler(TorchGradScaler):
    """
    A wrapper over PyTorch GradScaler made specifically for training hivemind.Optimizer with reuse_grad_buffers=True.

    :note: if not using reuse_grad_buffers=True, one can and *should* train normally without this class, e.g. using
      standard PyTorch AMP or Apex. This custom GradScaler is more memory-efficient, but requires custom training code.

    hivemind.GradScaler makes 3 modifications to regular PyTorch AMP:

    - bypass .unscale_ and .update calls in order to accumulate gradients over several steps
    - limit increasing the gradient scale to only immediately after global optimizer steps
    - allow training with some or all master parameters in float16

    :note: The above modifications are enabled automatically. One can (and should) use hivemind.GradScaler exactly
      as regular ``torch.amp.GradScaler`` (see the usage sketch at the end of this module).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._is_running_global_step = False
        self._is_ready_to_update = False
        self._inner_optimizer_states = {}  # cached unscaled state for the offloaded (inner) optimizer, keyed by its id
        self._optimizer_states_to_reset = set()  # optimizer ids whose AMP state is refreshed on the next no-op update()
        self._lock = threading.RLock()

    @contextlib.contextmanager
    def running_global_step(self):
        """Thread-safely mark this scaler as being inside a global optimizer step for the duration of the context."""
        with self._lock:
            was_running, self._is_running_global_step = self._is_running_global_step, True
            try:
                yield
            finally:
                self._is_running_global_step = was_running

    def unscale_(self, optimizer: TorchOptimizer) -> bool:
        with self._lock:
            assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
            if self._is_running_global_step:
                super().unscale_(optimizer)
                self._inner_optimizer_states[id(optimizer.opt)] = deepcopy(self._per_optimizer_states[id(optimizer)])
                # note: we store the unscaled optimizer state in a separate dict and not in _per_optimizer_states in
                # order to avoid an edge case where a full DPU peer encounters overflow in local gradients while
                # averaging offloaded gradients (i.e. after global unscale but before global step). Due to overflow,
                # the next call to .update on the user side would reset *all* optimizer states and cause .step to
                # unscale gradients twice. The offloaded optimizer is not affected by overflow in on-device gradients
                # and should not be reset.
                return True
            else:
                self._check_inf_per_device(optimizer)
                self._optimizer_states_to_reset.add(id(optimizer))
                return False

    def step(self, optimizer: TorchOptimizer, *args, **kwargs) -> bool:
        if self._is_running_global_step and not isinstance(optimizer, hivemind.Optimizer):
            # ^-- invoked privately within hivemind optimizer
            inner_optimizer = optimizer
            with self._lock:
                if self._is_ready_to_update:
                    logger.warning("Please call grad_scaler.update() after each step")
                inner_optimizer_state = self._inner_optimizer_states.pop(id(inner_optimizer), None)
                if inner_optimizer_state is not None:
                    self._per_optimizer_states[id(inner_optimizer)] = inner_optimizer_state
                assert (
                    self._per_optimizer_states[id(inner_optimizer)]["stage"] == OptState.UNSCALED
                ), "InternalError: Optimizer should have called .unscale internally before invoking grad_scaler.step"
                if self.are_grads_finite(inner_optimizer, use_cached=True):
                    super().step(inner_optimizer, *args, **kwargs)
                else:
                    logger.warning("Skipping global step due to gradient over/underflow")
                self._is_ready_to_update = True
                return True
        else:
            super().step(optimizer)
            self._optimizer_states_to_reset.add(id(optimizer))
            return False

    def update(self, new_scale: Optional[float] = None) -> bool:
        with self._lock:
            total_infs = 0
            for optimizer_state in self._per_optimizer_states.values():
                total_infs += sum(v.item() for v in optimizer_state["found_inf_per_device"].values())

            if self._is_ready_to_update or total_infs != 0:
                # note: we update either during actual optimizer step or if we need to reduce scale due to NaN
                super().update(new_scale)
                self._is_ready_to_update = False
                return True
            else:
                for opt_id in self._optimizer_states_to_reset:
                    self._per_optimizer_states[opt_id] = _refresh_per_optimizer_state()
                self._optimizer_states_to_reset.clear()
                return False

    def _unscale_grads_(
        self, optimizer: TorchOptimizer, inv_scale: torch.Tensor, found_inf: torch.Tensor, allow_fp16: bool
    ) -> Dict[torch.device, torch.Tensor]:
        # note: the code below sets allow_fp16=True to allow training with master weights (partially) in fp16
        # inspired by: https://github.com/facebookresearch/fairscale/blob/945b9666/fairscale/optim/grad_scaler.py
        return super()._unscale_grads_(optimizer, inv_scale, found_inf, allow_fp16=True)

    def are_grads_finite(self, optimizer: TorchOptimizer, use_cached: bool = False) -> bool:
        opt_dict = self._found_inf_per_device(optimizer) if use_cached else self._check_inf_per_device(optimizer)
        return not sum(v.item() for v in opt_dict.values())
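
# Rough sketch of the call order during a *global* optimizer step, inferred from the methods above.
# The real sequence lives inside hivemind.Optimizer and may differ in detail; the names below are illustrative.
#
#   with grad_scaler.running_global_step():
#       grad_scaler.unscale_(hivemind_optimizer)   # truly unscales and caches state for the offloaded inner optimizer
#       ...                                        # e.g. average gradients with other peers
#       grad_scaler.step(inner_torch_optimizer)    # applies the inner optimizer step if gradients are finite
#
#   grad_scaler.update()   # called on the user side; only now is the scale actually adjusted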

class HivemindGradScaler(GradScaler):
    def __init__(self, *args, **kwargs):
        logger.warning("HivemindGradScaler was renamed to hivemind.GradScaler, this reference will be removed in v1.1")
        super().__init__(*args, **kwargs)
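
# --- Illustrative usage sketch (not part of the original module) ---
# A minimal single training step, following the docstring's advice to drive this scaler exactly like a regular
# ``torch.cuda.amp.GradScaler``. It assumes ``opt`` was created as ``hivemind.Optimizer(..., reuse_grad_buffers=True)``;
# ``model``, ``batch`` and ``compute_loss`` are hypothetical placeholders. See the hivemind.Optimizer documentation
# for how the optimizer itself should be configured.
def _example_amp_training_step(model, batch, opt, scaler, compute_loss):
    with torch.cuda.amp.autocast():
        loss = compute_loss(model, batch)
    scaler.scale(loss).backward()  # gradients accumulate in the reused .grad buffers
    scaler.unscale_(opt)           # outside a global step this only records inf checks and returns False
    scaler.step(opt)               # hivemind.Optimizer decides whether this becomes a global step
    scaler.update()                # no-op until a global step happened or an inf/nan was found
    # note: with reuse_grad_buffers=True, do not call opt.zero_grad() / model.zero_grad() here;
    # the gradient buffers double as accumulators and are managed by hivemind.Optimizer.
    return loss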