grad_scaler.py

import contextlib
import threading
from copy import deepcopy
from typing import Dict, Optional

import torch
from torch.cuda.amp import GradScaler as TorchGradScaler
from torch.cuda.amp.grad_scaler import OptState, _refresh_per_optimizer_state
from torch.optim import Optimizer as TorchOptimizer

import hivemind
from hivemind.utils.logging import get_logger

logger = get_logger(__name__)


class GradScaler(TorchGradScaler):
    """
    A thin wrapper over pytorch GradScaler that supports hivemind-style training with CollaborativeOptimizer, namely:

    - bypass .unscale_ and .update calls in order to accumulate gradients over several steps
    - limit increasing gradient scale to only immediately after global optimizer steps
    - allow training with some or all master parameters in fp16
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._is_running_global_step = False
        self._is_ready_to_update = False
        self._optimizer_states_to_reset = set()
        self._lock = threading.RLock()

    @contextlib.contextmanager
    def running_global_step(self):
        with self._lock:
            was_running, self._is_running_global_step = self._is_running_global_step, True
            try:
                yield
            finally:
                self._is_running_global_step = was_running

    def unscale_(self, optimizer: TorchOptimizer) -> bool:
        with self._lock:
            assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
            if self._is_running_global_step:
                super().unscale_(optimizer)
                self._per_optimizer_states[id(optimizer.opt)] = deepcopy(self._per_optimizer_states[id(optimizer)])
                return True
            else:
                self._check_inf_per_device(optimizer)
                self._optimizer_states_to_reset.add(id(optimizer))
                return False

    def step(self, optimizer: TorchOptimizer, *args, **kwargs) -> bool:
        if self._is_running_global_step:
            with self._lock:
                if self._is_ready_to_update:
                    logger.warning("Please call grad_scaler.update() after each step.")
                assert not isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
                assert (
                    self._per_optimizer_states[id(optimizer)]["stage"] == OptState.UNSCALED
                ), "InternalError: Optimizer should have called .unscale internally before invoking grad_scaler.step."
                if self.are_grads_finite(optimizer, use_cached=True):
                    super().step(optimizer, *args, **kwargs)
                else:
                    logger.warning("Skipping global step due to gradient over/underflow")
                self._is_ready_to_update = True
                return True
        else:
            assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
            super().step(optimizer)
            self._optimizer_states_to_reset.add(id(optimizer))
            return False

    def update(self, new_scale: Optional[float] = None) -> bool:
        with self._lock:
            total_infs = 0
            for optimizer_state in self._per_optimizer_states.values():
                total_infs += sum(v.item() for v in optimizer_state["found_inf_per_device"].values())

            if self._is_ready_to_update or total_infs != 0:
                # note: we update either during actual optimizer step or if we need to reduce scale due to NaN
                super().update(new_scale)
                self._is_ready_to_update = False
                return True
            else:
                for opt_id in self._optimizer_states_to_reset:
                    self._per_optimizer_states[opt_id] = _refresh_per_optimizer_state()
                self._optimizer_states_to_reset.clear()
                return False

    def _unscale_grads_(
        self, optimizer: TorchOptimizer, inv_scale: torch.Tensor, found_inf: torch.Tensor, allow_fp16: bool
    ) -> Dict[torch.device, torch.Tensor]:
        # note: the code below sets allow_fp16=True to allow training with master weights (partially) in fp16
        # inspired by: https://github.com/facebookresearch/fairscale/blob/945b9666/fairscale/optim/grad_scaler.py
        return super()._unscale_grads_(optimizer, inv_scale, found_inf, allow_fp16=True)

    def are_grads_finite(self, optimizer: TorchOptimizer, use_cached: bool = False) -> bool:
        opt_dict = self._found_inf_per_device(optimizer) if use_cached else self._check_inf_per_device(optimizer)
        return not sum(v.item() for v in opt_dict.values())


class HivemindGradScaler(GradScaler):
    def __init__(self, *args, **kwargs):
        logger.warning("HivemindGradScaler was renamed to hivemind.GradScaler, this reference will be removed in v1.1")
        super().__init__(*args, **kwargs)
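

# ---------------------------------------------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module's API): one way a mixed-precision training step can be
# wired through the GradScaler above. The names `model`, `hivemind_optimizer`, `loss_fn`, and `batch` are
# placeholders supplied by the caller; how the collaborative optimizer obtains the scaler for its internal global
# step (the part it wraps in `running_global_step()`) depends on the hivemind version and is not shown here.
def _example_amp_step(model, hivemind_optimizer, scaler: GradScaler, loss_fn, batch):
    with torch.cuda.amp.autocast():
        loss = loss_fn(model(batch))
    scaler.scale(loss).backward()
    # while gradients are still being accumulated locally, unscale_/step/update are effectively bypassed:
    # the scaler only records inf/nan checks and later resets the per-optimizer state inside update()
    scaler.unscale_(hivemind_optimizer)
    scaler.step(hivemind_optimizer)
    scaler.update()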