@@ -54,6 +54,11 @@ class GradScaler(TorchGradScaler):
         if self._is_running_global_step:
             super().unscale_(optimizer)
             self._inner_optimizer_states[id(optimizer.opt)] = deepcopy(self._per_optimizer_states[id(optimizer)])
+            # note: we store unscaled optimizer state in a separate dict and not in _per_optimizer_states in order
+            # to avoid an edge case where a full DPU peer encounters overflow in local gradients while averaging
+            # gradients (i.e. after global unscale but before global step). In that case, the next call to .update
+            # on the user side would reset *all* optimizer states and cause .step to unscale gradients a second time.
+            # The offloaded optimizer is not affected by overflow in on-device gradients and should not be reset.
             return True
         else:
             self._check_inf_per_device(optimizer)
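
To make the edge case in the comment concrete, below is a minimal sketch (not the actual implementation from this PR) of the counterpart restore path that the snapshot enables. The names optimizer.opt, _is_running_global_step and _inner_optimizer_states come from the diff above; the step() override and its restore logic are assumptions for illustration only.

from copy import deepcopy

from torch.cuda.amp import GradScaler as TorchGradScaler


class GradScaler(TorchGradScaler):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._is_running_global_step = False
        # snapshots of unscaled per-optimizer state, keyed by the inner (offloaded) optimizer;
        # kept outside _per_optimizer_states so that TorchGradScaler.update() cannot reset them
        self._inner_optimizer_states = {}

    def unscale_(self, optimizer):
        if self._is_running_global_step:
            super().unscale_(optimizer)
            # optimizer.opt is assumed to be the wrapped (offloaded) inner optimizer
            self._inner_optimizer_states[id(optimizer.opt)] = deepcopy(self._per_optimizer_states[id(optimizer)])
            return True
        else:
            self._check_inf_per_device(optimizer)
            return False

    def step(self, optimizer, *args, **kwargs):
        if self._is_running_global_step:
            # hypothetical restore path: put the snapshot taken in unscale_() back, so that even if a
            # local overflow caused .update() to reset _per_optimizer_states in the meantime, the
            # offloaded optimizer is still marked as unscaled and its gradients are not unscaled twice
            snapshot = self._inner_optimizer_states.pop(id(optimizer.opt), None)
            if snapshot is not None:
                self._per_optimizer_states[id(optimizer)] = snapshot
        return super().step(optimizer, *args, **kwargs)

Keying the snapshot by id(optimizer.opt) rather than id(optimizer) ties the saved state to the offloaded optimizer itself, which is the object whose gradients must not be unscaled a second time.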