justheuristic 3 年之前
父节点
当前提交
42d8d63979
共有 1 个文件被更改,包括 8 次插入6 次删除
  1. 8 6
      hivemind/optim/experimental/state_averager.py

+ 8 - 6
hivemind/optim/experimental/state_averager.py

@@ -411,11 +411,13 @@ class TrainingStateAverager(DecentralizedAverager):
         This method is meant to be called in the background executor.
         """
         began_running = False
-        control = None
+        averaging_control = None
 
         try:
             if averaging_round:
-                control = super().step(gather=self.local_epoch, require_trigger=True, timeout=timeout, **kwargs)
+                averaging_control = super().step(
+                    gather=self.local_epoch, require_trigger=True, timeout=timeout, wait=False, **kwargs
+                )
 
             if wait_for_trigger is not None:
                 wait_for_trigger()
@@ -446,8 +448,8 @@ class TrainingStateAverager(DecentralizedAverager):
                 if not self.reuse_tensors:
                     self._load_local_tensors_into_averager_()
                 try:
-                    control.allow_allreduce()
-                    gathered = control.result(timeout=timeout)
+                    averaging_control.allow_allreduce()
+                    gathered = averaging_control.result(timeout=timeout)
                     logger.log(self.status_loglevel, f"Averaged parameters with {len(gathered)} peers")
                 except BaseException as e:
                     logger.log(self.status_loglevel, f"Averaging failed with {type(e)}")
@@ -468,9 +470,9 @@ class TrainingStateAverager(DecentralizedAverager):
             if not began_running:
                 logger.error(f"Aborted {self.__class__.__name__}.step because wait_for_trigger raised exception")
             logger.exception(e)
-            if control is not None and not control.done():
+            if averaging_control is not None and not averaging_control.done():
                 logger.error(f"Cancelled scheduled state averaging round")
-                control.cancel()
+                averaging_control.cancel()
             self.finished_optimizer_step.set()
             self.finished_averaging_round.set()