
update separately

justheuristic, 3 years ago
parent commit 0f2b356a42

+ 1 - 0
hivemind/optim/experimental/optimizer.py

@@ -246,6 +246,7 @@ class Optimizer(torch.optim.Optimizer):
             if grad_scaler is not None:
                 with grad_scaler.running_global_step():
                     assert grad_scaler.unscale_(self)
+                    assert grad_scaler.update()
 
             if self.scheduled_round is not None and (self.scheduled_round.triggered or self.scheduled_round.done()):
                 logger.log(self.status_loglevel, f"Discarding failed matchmaking results: {self.scheduled_round}")
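
For context: the added assert grad_scaler.update() brings the per-step sequence in line with the stock PyTorch AMP recipe, where update() follows unscale_() and step() on every iteration so the loss scale can react to inf/nan gradients immediately. Below is a minimal, self-contained sketch of that stock recipe in plain torch; the model, optimizer, and loop here are placeholders, not hivemind code (hivemind's GradScaler wrapper additionally returns truthy values from these calls, which is what makes the asserts in the diff possible).

    import torch

    model = torch.nn.Linear(8, 1).cuda()
    opt = torch.optim.SGD(model.parameters(), lr=1e-2)
    scaler = torch.cuda.amp.GradScaler()

    for _ in range(3):
        opt.zero_grad()
        with torch.cuda.amp.autocast():
            loss = model(torch.randn(4, 8, device="cuda")).pow(2).mean()
        scaler.scale(loss).backward()  # backprop on the scaled loss
        scaler.unscale_(opt)           # unscale grads in-place, e.g. for clipping
        scaler.step(opt)               # runs opt.step() unless inf/nan grads were found
        scaler.update()                # adjust the loss scale for the next iteration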

+ 0 - 5
hivemind/optim/experimental/state_averager.py

@@ -399,11 +399,6 @@ class TrainingStateAverager(DecentralizedAverager):
                 else:
                     with grad_scaler.running_global_step():
                         assert grad_scaler.step(self.optimizer)
-
-            if grad_scaler is not None:
-                with grad_scaler.running_global_step():
-                    assert grad_scaler.update()
-
             self._update_scheduler()
 
             if zero_grad:
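
Net effect of the two hunks, matching the commit title "update separately": TrainingStateAverager now only calls grad_scaler.step(...), while the loss-scale update happens in Optimizer immediately after unscale_(), independently of when the averager applies the optimizer step. A hedged sketch of the resulting split; only running_global_step, unscale_, step, and update come from the diff, while the class skeletons and method signatures below are illustrative, not the real API.

    # Illustrative skeletons only; the real classes live in optimizer.py and
    # state_averager.py and contain far more logic.
    class Optimizer:
        def step(self, grad_scaler=None):
            if grad_scaler is not None:
                with grad_scaler.running_global_step():
                    assert grad_scaler.unscale_(self)
                    assert grad_scaler.update()  # scale now updated on every local step
            ...  # matchmaking and averaging logic elided

    class TrainingStateAverager:
        def apply_optimizer(self, grad_scaler=None, zero_grad=False):
            if grad_scaler is not None:
                with grad_scaler.running_global_step():
                    assert grad_scaler.step(self.optimizer)  # step only, no update()
            self._update_scheduler()
            if zero_grad:
                ...  # zero out gradients, as in the context above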