@@ -276,7 +276,7 @@ class Optimizer(torch.optim.Optimizer):
         with self.tracker.pause_updates():
             logger.log(self.status_loglevel, f"Beginning global optimizer step #{self.tracker.global_epoch}")
-            # divide accumulators by local steps to recover the true average grad w.r.t. local_samples_accumulated
+
             if grad_scaler is not None:
                 with grad_scaler.running_global_step():
                     assert grad_scaler.unscale_(self)
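
For context on the comment removed in this hunk: it describes the common pattern of dividing locally accumulated gradients by the number of accumulation steps, so the optimizer works with an average gradient rather than a sum. Below is a minimal sketch of that pattern in plain PyTorch; the function and argument names are illustrative and are not part of this Optimizer's actual API.

```python
import torch

def average_accumulated_grads(parameters, local_steps_accumulated: int) -> None:
    """Divide each parameter's accumulated .grad by the number of local steps.

    After several accumulated micro-batches, .grad holds a sum of per-batch
    gradients; dividing by the step count recovers the mean gradient before
    it is averaged across peers or passed to the underlying optimizer step.
    """
    if local_steps_accumulated <= 1:
        return  # nothing accumulated, or a single step needs no rescaling
    with torch.no_grad():
        for param in parameters:
            if param.grad is not None:
                param.grad.div_(local_steps_accumulated)
```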