
await the running all-reduce instead of cancelling it;
hide network errors that happened while reporting other network errors.

justheuristic, 3 years ago
commit f495efb846
1 changed file with 20 additions and 16 deletions

hivemind/optim/experimental/optimizer.py  + 20 - 16

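In short: instead of cancelling a pre-scheduled gradient averaging round that may already be running, the optimizer now joins it with zero weight, so the rest of the group does not have to restart matchmaking. Below is a minimal, standalone sketch of that pattern; the control object mimics the StepControl interface used in the diff (triggered, weight, allow_allreduce, done, result, cancel), while the free function and its averaging_timeout parameter are illustrative rather than the actual Optimizer method.

import logging

logger = logging.getLogger(__name__)


def tag_along_with_zero_weight(control, averaging_timeout: float) -> None:
    """Sketch: let a pre-scheduled averaging round finish without contributing to it."""
    if not control.triggered:
        control.weight = 0         # participate with zero weight: no effect on the averaged result
        control.allow_allreduce()  # let the already-matched group proceed to all-reduce
    if not control.done():
        try:
            control.result(averaging_timeout)  # wait for the round to finish (or time out)
        except BaseException as e:
            logger.exception(e)                # report the first error ...
            if not control.done():
                control.cancel()               # ... and cancel quietly instead of raising again
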
@@ -511,9 +511,9 @@ class Optimizer(torch.optim.Optimizer):
                 logger.exception(e)
 
         if not began_averaging_gradients and self.scheduled_grads is not None and not self.scheduled_grads.done():
-            logger.log(self.status_loglevel, f"Cancelled pre-scheduled gradient averaging round")
-            self.scheduled_grads.cancel()
-        self.scheduled_grads = None
+            logger.log(self.status_loglevel, f"Tagging along for a pre-scheduled gradient averaging round")
+            self._tag_along_with_zero_weight(self.scheduled_grads)
+            self.scheduled_grads = None
         return began_averaging_gradients
 
     def _check_and_accumulate_gradients(self, batch_size: int, grad_scaler: Optional[GradScaler]) -> bool:
@@ -649,13 +649,8 @@ class Optimizer(torch.optim.Optimizer):
         # note: we tag along for the next all-reduce because the run may have already started and cancelling it
         # will cause peers to restart matchmaking and may stall the entire collaboration for a few seconds.
         if self.scheduled_grads is not None and not self.scheduled_grads.done():
-            self.scheduled_grads.weight = 0
-            self.scheduled_grads.allow_allreduce()
-            try:
-                self.scheduled_grads.result(self.averaging_timeout)
-            except BaseException as e:
-                logger.exception(e)
-
+            self._tag_along_with_zero_weight(self.scheduled_grads)
+            self.scheduled_grads = None
         self.state_averager.step(wait_for_delayed_updates=True)
 
         with self.tracker.pause_updates():
@@ -683,9 +678,6 @@ class Optimizer(torch.optim.Optimizer):
                 if not self.client_mode:
                     self.grad_averager.state_sharing_priority = self.local_epoch
 
-    def _finish_background_averaging(self):
-        self.scheduled_grads = self.scheduled_state = None
-
     def state_dict(self) -> dict:
         state_dict = self.state_averager.optimizer.state_dict()
         state_dict["state"]["local_epoch"] = self.local_epoch
@@ -725,6 +717,19 @@ class Optimizer(torch.optim.Optimizer):
     def __repr__(self):
         return f"{self.__class__.__name__}(prefix={self.run_id}, epoch={self.local_epoch})"
 
+    def _tag_along_with_zero_weight(self, control: StepControl):
+        """Wait for a running averaging round to finish with zero weight."""
+        if not control.triggered:
+            control.weight = 0
+            control.allow_allreduce()
+        if not control.done():
+            try:
+                control.result(self.averaging_timeout)
+            except BaseException as e:
+                logger.exception(e)
+                if not control.done():
+                    control.cancel()
+
     def shutdown(self):
         logger.log(self.status_loglevel, "Sending goodbye to peers...")
         self.tracker.shutdown(self.shutdown_timeout)
@@ -733,9 +738,8 @@ class Optimizer(torch.optim.Optimizer):
             if scheduled_round is not None:
                 if scheduled_round.stage == AveragingStage.LOOKING_FOR_GROUP:
                     scheduled_round.cancel()
-                if not scheduled_round.triggered:
-                    scheduled_round.weight = 0
-                    scheduled_round.allow_allreduce()
+                else:
+                    self._tag_along_with_zero_weight(scheduled_round)
 
         logger.log(self.status_loglevel, "Shutting down averagers...")
         self.state_averager.shutdown()
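
The shutdown hunk applies the same idea: a round that is still looking for a group can be cancelled outright, whereas one that may already be running is joined with zero weight. A hedged sketch of that decision, reusing the helper sketched above (the wrapper function is hypothetical, and the import path for AveragingStage is an assumption that may differ between hivemind versions):

from hivemind.averaging.control import AveragingStage  # assumed import path


def finish_scheduled_round(scheduled_round, averaging_timeout: float) -> None:
    """Sketch: wind down a pending averaging round on shutdown."""
    if scheduled_round is None:
        return
    if scheduled_round.stage == AveragingStage.LOOKING_FOR_GROUP:
        scheduled_round.cancel()  # no group has formed yet, so cancelling cannot stall other peers
    else:
        # matchmaking may have completed: bow out via zero-weight participation instead
        tag_along_with_zero_weight(scheduled_round, averaging_timeout)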