|
@@ -403,6 +403,7 @@ class TrainingStateAverager(DecentralizedAverager):
|
|
|
zero_grad: bool,
|
|
|
averaging_round: bool,
|
|
|
grad_scaler: Optional[GradScaler],
|
|
|
+ timeout: Optional[float] = None,
|
|
|
**kwargs,
|
|
|
):
|
|
|
"""
|
|
@@ -410,7 +411,12 @@ class TrainingStateAverager(DecentralizedAverager):
|
|
|
This method is meant to be called in the background executor.
|
|
|
"""
|
|
|
began_running = False
|
|
|
+ control = None
|
|
|
+
|
|
|
try:
|
|
|
+ if averaging_round:
|
|
|
+ control = super().step(gather=self.local_epoch, require_trigger=True, timeout=timeout, **kwargs)
|
|
|
+
|
|
|
if wait_for_trigger is not None:
|
|
|
wait_for_trigger()
|
|
|
began_running = True
|
|
@@ -440,7 +446,8 @@ class TrainingStateAverager(DecentralizedAverager):
|
|
|
if not self.reuse_tensors:
|
|
|
self._load_local_tensors_into_averager_()
|
|
|
try:
|
|
|
- gathered = super().step(gather=self.local_epoch, **kwargs)
|
|
|
+ control.allow_allreduce()
|
|
|
+ gathered = control.result(timeout=timeout)
|
|
|
logger.log(self.status_loglevel, f"Averaged parameters with {len(gathered)} peers")
|
|
|
except BaseException as e:
|
|
|
logger.log(self.status_loglevel, f"Averaging failed with {type(e)}")
|
|
@@ -459,8 +466,11 @@ class TrainingStateAverager(DecentralizedAverager):
|
|
|
|
|
|
except Exception as e:
|
|
|
if not began_running:
|
|
|
- logger.error(f"Aborted {self.__class__.__name__}.step because wait_for_trigger raised exception.")
|
|
|
+ logger.error(f"Aborted {self.__class__.__name__}.step because wait_for_trigger raised exception")
|
|
|
logger.exception(e)
|
|
|
+ if control is not None and not control.done():
|
|
|
+ logger.error(f"Cancelled scheduled state averaging round")
|
|
|
+ control.cancel()
|
|
|
self.finished_optimizer_step.set()
|
|
|
self.finished_averaging_round.set()
|
|
|
|