pre-schedule state averaging

justheuristic 3 years ago
parent commit 242ef7fcaf

+ 3 - 3
hivemind/optim/experimental/optimizer.py

@@ -286,7 +286,7 @@ class Optimizer(torch.optim.Optimizer):
                 self.scheduled_round = None
 
             swarm_not_empty = self.tracker.global_progress.num_peers > 1
-            began_averaging = False
+            began_averaging_gradients = False
 
             if swarm_not_empty:
                 try:
@@ -294,11 +294,11 @@ class Optimizer(torch.optim.Optimizer):
                         control=self.scheduled_round, reset_accumulators=True, wait=False
                     )
                     assert self.grad_averager.local_samples_accumulated == 0, "step should have reset accumulators"
-                    began_averaging = True
+                    began_averaging_gradients = True
                 except BaseException as e:
                     logger.exception(e)
 
-            if not began_averaging and self.scheduled_round is not None and not self.scheduled_round.done():
+            if not began_averaging_gradients and self.scheduled_round is not None and not self.scheduled_round.done():
                 logger.log(self.status_loglevel, f"Cancelled pre-scheduled averaging round")
                 self.scheduled_round.cancel()
                 self.scheduled_round = None
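
In the hunk above, `began_averaging` is renamed to `began_averaging_gradients` to distinguish gradient averaging from the state averaging this commit also pre-schedules. A condensed sketch of the schedule-then-use-or-cancel flow on the optimizer side; the `grad_averager.schedule_step` call that produces `self.scheduled_round` is not shown in the hunk and is an assumption here, as are the exact signatures:

    # Sketch, not the committed code: pre-schedule a gradient averaging round,
    # then either consume it during step() or cancel it so peers are not left waiting.
    scheduled_round = grad_averager.schedule_step()  # assumed helper returning a StepControl

    began_averaging_gradients = False
    try:
        grad_averager.step(control=scheduled_round, reset_accumulators=True, wait=False)
        began_averaging_gradients = True
    except BaseException as e:
        logger.exception(e)

    # If the scheduled round never started, cancel it explicitly.
    if not began_averaging_gradients and scheduled_round is not None and not scheduled_round.done():
        scheduled_round.cancel()
        scheduled_round = None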

+ 12 - 2
hivemind/optim/experimental/state_averager.py

@@ -403,6 +403,7 @@ class TrainingStateAverager(DecentralizedAverager):
         zero_grad: bool,
         averaging_round: bool,
         grad_scaler: Optional[GradScaler],
+        timeout: Optional[float] = None,
         **kwargs,
     ):
         """
@@ -410,7 +411,12 @@ class TrainingStateAverager(DecentralizedAverager):
         This method is meant to be called in the background executor.
         """
         began_running = False
+        control = None
+
         try:
+            if averaging_round:
+                control = super().step(gather=self.local_epoch, require_trigger=True, timeout=timeout, **kwargs)
+
             if wait_for_trigger is not None:
                 wait_for_trigger()
             began_running = True
@@ -440,7 +446,8 @@ class TrainingStateAverager(DecentralizedAverager):
                 if not self.reuse_tensors:
                     self._load_local_tensors_into_averager_()
                 try:
-                    gathered = super().step(gather=self.local_epoch, **kwargs)
+                    control.allow_allreduce()
+                    gathered = control.result(timeout=timeout)
                     logger.log(self.status_loglevel, f"Averaged parameters with {len(gathered)} peers")
                 except BaseException as e:
                     logger.log(self.status_loglevel, f"Averaging failed with {type(e)}")
@@ -459,8 +466,11 @@ class TrainingStateAverager(DecentralizedAverager):
 
         except Exception as e:
             if not began_running:
-                logger.error(f"Aborted {self.__class__.__name__}.step because wait_for_trigger raised exception.")
+                logger.error(f"Aborted {self.__class__.__name__}.step because wait_for_trigger raised exception")
             logger.exception(e)
+            if control is not None and not control.done():
+                logger.error(f"Cancelled scheduled state averaging round")
+                control.cancel()
             self.finished_optimizer_step.set()
             self.finished_averaging_round.set()
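
The state-averager change implements the commit title: instead of running `super().step(...)` synchronously at averaging time, the background-executor method now pre-schedules matchmaking up front and holds the all-reduce behind a trigger. A minimal sketch of that control flow, assuming only the `StepControl` methods visible in the diff (`allow_allreduce`, `result`, `done`, `cancel`); `wait=False` and `prepare_local_tensors` are illustrative stand-ins for the kwargs and the optimizer/tensor-loading work elided from the hunks:

    control = None
    try:
        # Pre-schedule: begin matchmaking early, hold the all-reduce behind a trigger.
        control = averager.step(gather=local_epoch, require_trigger=True, timeout=timeout, wait=False)

        prepare_local_tensors()  # stand-in: run the optimizer step, load local tensors into the averager

        control.allow_allreduce()                   # trigger: local tensors are now ready
        gathered = control.result(timeout=timeout)  # block until the round completes
    except BaseException as e:
        logger.exception(e)
        # Do not leave a half-scheduled round behind: peers would wait for a trigger
        # that will never fire.
        if control is not None and not control.done():
            control.cancel()

Overlapping matchmaking with the local optimizer step hides group-forming latency, and the explicit cancel on failure keeps other peers in the scheduled group from stalling.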