|
@@ -227,7 +227,7 @@ class ProgressTracker(threading.Thread):
|
|
|
Periodically check the training progress from all peers. Trigger update after target_batch_size total samples
|
|
|
"""
|
|
|
loop = asyncio.get_event_loop()
|
|
|
- shutdown_checker = loop.run_in_executor(None, self.shutdown_triggered.wait)
|
|
|
+ shutdown_checker = asyncio.create_task(asyncio.wait_for(loop.run_in_executor(None, self.shutdown_triggered.wait), None))
|
|
|
|
|
|
async def _fetch_progress_unless_shutdown_triggered():
|
|
|
"""Fetch progress, avoid deadlocks if DHT was shut down before this get finished."""
|