# grad_scaler.py
  1. import contextlib
  2. from typing import Dict
  3. import torch
  4. from hivemind import DecentralizedOptimizerBase, get_logger
  5. from torch.cuda.amp import GradScaler as TorchGradScaler
  6. from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state
  7. from torch.optim import Optimizer
  8. logger = get_logger(__name__)
  9. class HivemindGradScaler(TorchGradScaler):
  10. """A thin wrapper over GradScaler that supports hivemind-style training with CollaborativeOptimizer and others"""
  11. def __init__(self, *args, **kwargs):
  12. super().__init__(*args, **kwargs)
  13. self._is_running_global_step = False
  14. self._optimizer_states_to_reset = set()
  15. @contextlib.contextmanager
  16. def running_global_step(self):
  17. was_running, self._is_running_global_step = self._is_running_global_step, True
  18. try:
  19. yield
  20. finally:
  21. self._is_running_global_step = was_running
  22. def unscale_(self, optimizer, actually_unscale: bool = False):
  23. assert isinstance(optimizer, DecentralizedOptimizerBase)
  24. if self._is_running_global_step:
  25. super().unscale_(optimizer.opt)
  26. return True
  27. else:
  28. self._check_inf_per_device(optimizer.opt)
  29. self._optimizer_states_to_reset.add(id(optimizer))
  30. return False
  31. def step(self, optimizer, *args, **kwargs):
  32. assert isinstance(optimizer, DecentralizedOptimizerBase)
  33. if self._is_running_global_step:
  34. if self.are_grads_finite(optimizer):
  35. super().step(optimizer.opt, *args, **kwargs)
  36. else:
  37. logger.warning("Skipping global step due to gradient over/underflow")
  38. return True
  39. else:
  40. super().step(optimizer)
  41. self._optimizer_states_to_reset.add(optimizer)
  42. return False
  43. def update(self, new_scale=None):
  44. total_infs = 0
  45. for optimizer_state in self._per_optimizer_states.values():
  46. total_infs += sum(v.item() for v in optimizer_state["found_inf_per_device"].values())
  47. if self._is_running_global_step or total_infs != 0:
  48. # note: we update either during actual optimizer step or if we need to reduce scale due to NaN
  49. super().update(new_scale)
  50. return True
  51. else:
  52. for opt_id in self._optimizer_states_to_reset:
  53. self._per_optimizer_states[opt_id] = _refresh_per_optimizer_state()
  54. self._optimizer_states_to_reset.clear()
  55. return False
  56. def _unscale_grads_(
  57. self, optimizer: Optimizer, inv_scale: torch.Tensor, found_inf: torch.Tensor, allow_fp16: bool
  58. ) -> Dict[torch.device, torch.Tensor]:
  59. return super()._unscale_grads_(optimizer, inv_scale, found_inf, allow_fp16=True)
  60. def are_grads_finite(self, optimizer: DecentralizedOptimizerBase):
  61. assert isinstance(optimizer, DecentralizedOptimizerBase)
  62. return not sum(v.item() for v in self._check_inf_per_device(optimizer.opt).values())