|
@@ -1,6 +1,7 @@
|
|
|
import contextlib
|
|
|
from copy import deepcopy
|
|
|
from typing import Dict, Optional
|
|
|
+import threading
|
|
|
|
|
|
import torch
|
|
|
from torch.cuda.amp import GradScaler as TorchGradScaler
|
|
@@ -25,6 +26,7 @@ class GradScaler(TorchGradScaler):
|
|
|
super().__init__(*args, **kwargs)
|
|
|
self._is_running_global_step = False
|
|
|
self._optimizer_states_to_reset = set()
|
|
|
+ self._lock = threading.RLock()
|
|
|
|
|
|
@contextlib.contextmanager
|
|
|
def running_global_step(self):
|
|
@@ -35,46 +37,49 @@ class GradScaler(TorchGradScaler):
|
|
|
self._is_running_global_step = was_running
|
|
|
|
|
|
def unscale_(self, optimizer: TorchOptimizer) -> bool:
|
|
|
- assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
|
|
|
- if self._is_running_global_step:
|
|
|
- super().unscale_(optimizer)
|
|
|
- self._per_optimizer_states[id(optimizer.opt)] = deepcopy(self._per_optimizer_states[id(optimizer)])
|
|
|
- return True
|
|
|
- else:
|
|
|
- self._check_inf_per_device(optimizer)
|
|
|
- self._optimizer_states_to_reset.add(id(optimizer))
|
|
|
- return False
|
|
|
+ with self._lock:
|
|
|
+ assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
|
|
|
+ if self._is_running_global_step:
|
|
|
+ super().unscale_(optimizer)
|
|
|
+ self._per_optimizer_states[id(optimizer.opt)] = deepcopy(self._per_optimizer_states[id(optimizer)])
|
|
|
+ return True
|
|
|
+ else:
|
|
|
+ self._check_inf_per_device(optimizer)
|
|
|
+ self._optimizer_states_to_reset.add(id(optimizer))
|
|
|
+ return False
|
|
|
|
|
|
def step(self, optimizer: TorchOptimizer, *args, **kwargs) -> bool:
|
|
|
- if self._is_running_global_step:
|
|
|
- assert not isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
|
|
|
- assert self._per_optimizer_states[id(optimizer)]["stage"] == OptState.UNSCALED, \
|
|
|
- "InternalError: Optimizer should have called .unscale internally before invoking grad_scaler.step."
|
|
|
- if self.are_grads_finite(optimizer, use_cached=True):
|
|
|
- super().step(optimizer, *args, **kwargs)
|
|
|
+ with self._lock:
|
|
|
+ if self._is_running_global_step:
|
|
|
+ assert not isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
|
|
|
+ assert self._per_optimizer_states[id(optimizer)]["stage"] == OptState.UNSCALED, \
|
|
|
+ "InternalError: Optimizer should have called .unscale internally before invoking grad_scaler.step."
|
|
|
+ if self.are_grads_finite(optimizer, use_cached=True):
|
|
|
+ super().step(optimizer, *args, **kwargs)
|
|
|
+ else:
|
|
|
+ logger.warning("Skipping global step due to gradient over/underflow")
|
|
|
+ return True
|
|
|
else:
|
|
|
- logger.warning("Skipping global step due to gradient over/underflow")
|
|
|
- return True
|
|
|
- else:
|
|
|
- assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
|
|
|
- super().step(optimizer)
|
|
|
- self._optimizer_states_to_reset.add(id(optimizer))
|
|
|
- return False
|
|
|
+ assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
|
|
|
+ super().step(optimizer)
|
|
|
+ self._optimizer_states_to_reset.add(id(optimizer))
|
|
|
+ return False
|
|
|
|
|
|
def update(self, new_scale: Optional[float] = None) -> bool:
|
|
|
- total_infs = 0
|
|
|
- for optimizer_state in self._per_optimizer_states.values():
|
|
|
- total_infs += sum(v.item() for v in optimizer_state["found_inf_per_device"].values())
|
|
|
+ with self._lock:
|
|
|
+ total_infs = 0
|
|
|
+ for optimizer_state in self._per_optimizer_states.values():
|
|
|
+ total_infs += sum(v.item() for v in optimizer_state["found_inf_per_device"].values())
|
|
|
|
|
|
- if self._is_running_global_step or total_infs != 0:
|
|
|
- # note: we update either during actual optimizer step or if we need to reduce scale due to NaN
|
|
|
- super().update(new_scale)
|
|
|
- return True
|
|
|
- else:
|
|
|
- for opt_id in self._optimizer_states_to_reset:
|
|
|
- self._per_optimizer_states[opt_id] = _refresh_per_optimizer_state()
|
|
|
- self._optimizer_states_to_reset.clear()
|
|
|
- return False
|
|
|
+ if self._is_running_global_step or total_infs != 0:
|
|
|
+ # note: we update either during actual optimizer step or if we need to reduce scale due to NaN
|
|
|
+ super().update(new_scale)
|
|
|
+ return True
|
|
|
+ else:
|
|
|
+ for opt_id in self._optimizer_states_to_reset:
|
|
|
+ self._per_optimizer_states[opt_id] = _refresh_per_optimizer_state()
|
|
|
+ self._optimizer_states_to_reset.clear()
|
|
|
+ return False
|
|
|
|
|
|
def _unscale_grads_(
|
|
|
self, optimizer: TorchOptimizer, inv_scale: torch.Tensor, found_inf: torch.Tensor, allow_fp16: bool
|