Bläddra i källkod

minor bugfix with all-reduce scheduling

justheuristic 3 år sedan
förälder
incheckning
076b1f6a3b
1 ändrade filer med 3 tillägg och 5 borttagningar
  1. 3 5
      hivemind/optim/experimental/optimizer.py

+ 3 - 5
hivemind/optim/experimental/optimizer.py

@@ -385,7 +385,7 @@ class Optimizer(torch.optim.Optimizer):
                 began_averaging_gradients = self._begin_averaging_gradients(grad_scaler)
                 if not began_averaging_gradients:
                     pass  # failed to start gradient averaging due to an internal error
-                if self.delay_grad_averaging:
+                elif self.delay_grad_averaging:
                     # if using delayed grad averaging, send this to state_averager as a pre-condition for optimizer step
                     wait_for_trigger = partial(self._average_gradients_and_load_into_optimizer, self.scheduled_grads)
                 else:
@@ -437,13 +437,13 @@ class Optimizer(torch.optim.Optimizer):
             with grad_scaler.running_global_step():
                 assert grad_scaler.unscale_(self)
 
+        began_averaging_gradients = False
         if self.scheduled_grads is not None and (self.scheduled_grads.triggered or self.scheduled_grads.done()):
             logger.log(self.status_loglevel, f"Not using pre-scheduled group for state averaging because it "
                                              f"was already used elsewhere: {self.scheduled_state}")
             self.scheduled_grads = None
 
-        began_averaging_gradients = False
-        if self.tracker.global_progress.num_peers > 1:
+        elif self.tracker.global_progress.num_peers > 1:
             try:
                 self.scheduled_grads = self.grad_averager.step(
                     control=self.scheduled_grads, reset_accumulators=True, wait=False
@@ -582,8 +582,6 @@ class Optimizer(torch.optim.Optimizer):
 
     def load_state_from_peers(self, **kwargs):
         """Attempt to fetch the newest collaboration state from other peers"""
-        self.state_averager.step(wait_for_delayed_updates=True)
-        self._finish_background_averaging()
 
         with self.tracker.pause_updates():
             while True: