|
@@ -541,9 +541,9 @@ class TrainingStateAverager(DecentralizedAverager):
|
|
if not began_running:
|
|
if not began_running:
|
|
logger.error(f"Aborted {self.__class__.__name__}.step because wait_for_trigger raised exception")
|
|
logger.error(f"Aborted {self.__class__.__name__}.step because wait_for_trigger raised exception")
|
|
logger.exception(e)
|
|
logger.exception(e)
|
|
- if averaging_control is not None and not averaging_control.done():
|
|
|
|
- logger.error(f"Cancelled scheduled state averaging round")
|
|
|
|
- averaging_control.cancel()
|
|
|
|
|
|
+ if averaging_control is not None and not averaging_control.triggered:
|
|
|
|
+ averaging_control.weight = 0.0
|
|
|
|
+ averaging_control.allow_allreduce()
|
|
self.finished_optimizer_step.set()
|
|
self.finished_optimizer_step.set()
|
|
self.finished_averaging_round.set()
|
|
self.finished_averaging_round.set()
|
|
|
|
|