|
@@ -376,17 +376,20 @@ class TrainingStateAverager(DecentralizedAverager):
|
|
|
if wait_for_delayed_updates:
|
|
|
for pending_update in self.pending_updates:
|
|
|
try:
|
|
|
+ timeout = (averaging_opts or {}).get("averaging_timeout", self._allreduce_timeout)
|
|
|
logger.log(self.status_loglevel, "Waiting for delayed updates to finish...")
|
|
|
- output = pending_update.result()
|
|
|
+ output = pending_update.result(timeout)
|
|
|
except BaseException:
|
|
|
- pass # exception will be reported below
|
|
|
+ # exception will be reported below
|
|
|
+ if not pending_update.done():
|
|
|
+ pending_update.cancel()
|
|
|
|
|
|
# remove finished updates, log any exceptions
|
|
|
finished_updates = {pending_update for pending_update in self.pending_updates if pending_update.done()}
|
|
|
self.pending_updates = {pending_update for pending_update in self.pending_updates if not pending_update.done()}
|
|
|
for finished_update in finished_updates:
|
|
|
- if finished_update.exception():
|
|
|
- logger.log(self.status_loglevel, f"Background update failed with {finished_update.exception()}")
|
|
|
+ if finished_update.cancelled() or finished_update.exception():
|
|
|
+ logger.log(self.status_loglevel, f"Background update failed: {finished_update}")
|
|
|
|
|
|
if apply_delayed_updates:
|
|
|
if self.finished_averaging_round.is_set():
|