@@ -56,8 +56,8 @@ class GradScaler(TorchGradScaler):
self._inner_optimizer_states[id(optimizer.opt)] = deepcopy(self._per_optimizer_states[id(optimizer)])
# note: we store unscaled optimizer state in a separate dict and not in _per_optimizer_states in order
# to avoid an edge case where full DPU peer encounters overflow in local gradients while averaging
- # gradients (i.e. after global unscale but before global step). In that case, the next call to .update
- # on user side would reset *all* optimizer states and cause .step to unscale gradients the second time.
+ # offloaded gradients (i.e. after global unscale but before global step). Due to overflow, next call to
+ # .update on user side would reset *all* optimizer states and cause .step to unscale gradients twice.
# Offloaded optimizer is not affected by overflow in on-device gradients and should not be reset.
return True
else:
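
For context on why the separate dict matters, here is a minimal sketch of the pattern this hunk relies on, not the actual implementation in this PR: a `GradScaler` subclass keeps a deep copy of the offloaded inner optimizer's scaler state outside `_per_optimizer_states`, so that an overflow-triggered `.update()` (which rebuilds `_per_optimizer_states`) cannot make a later `.step()` unscale the offloaded gradients a second time. The class name `OffloadAwareGradScaler` is illustrative, the assumption that `optimizer.opt` exposes the inner optimizer follows the assignment in the hunk above, and `_per_optimizer_states` is a private attribute of `torch.cuda.amp.GradScaler` used here only for illustration.

```python
from copy import deepcopy

from torch.cuda.amp import GradScaler as TorchGradScaler


class OffloadAwareGradScaler(TorchGradScaler):
    """Sketch: keep the offloaded inner optimizer's scaler state out of reach of update()."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Snapshots for inner (offloaded) optimizers, stored outside _per_optimizer_states
        # so that an overflow-triggered update() cannot reset them.
        self._inner_optimizer_states = {}

    def unscale_(self, optimizer):
        # Unscale on-device gradients as usual, then remember that gradients destined
        # for the inner optimizer are already unscaled.
        super().unscale_(optimizer)
        inner = getattr(optimizer, "opt", None)  # assumed handle to the offloaded inner optimizer
        if inner is not None:
            self._inner_optimizer_states[id(inner)] = deepcopy(self._per_optimizer_states[id(optimizer)])

    def step(self, optimizer, *args, **kwargs):
        # When stepping the inner optimizer, restore its snapshot first so the parent step()
        # sees the "already unscaled" stage instead of unscaling a second time.
        if id(optimizer) in self._inner_optimizer_states:
            self._per_optimizer_states[id(optimizer)] = self._inner_optimizer_states.pop(id(optimizer))
        return super().step(optimizer, *args, **kwargs)
```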