|
@@ -227,11 +227,15 @@ class ProgressTracker(threading.Thread):
|
|
|
Periodically check the training progress from all peers. Trigger update after target_batch_size total samples
|
|
|
"""
|
|
|
loop = asyncio.get_event_loop()
|
|
|
- shutdown_checker = asyncio.create_task(asyncio.wait_for(loop.run_in_executor(None, self.shutdown_triggered.wait), None))
|
|
|
+ shutdown_checker = asyncio.create_task(
|
|
|
+ asyncio.wait_for(loop.run_in_executor(None, self.shutdown_triggered.wait), None)
|
|
|
+ )
|
|
|
|
|
|
async def _fetch_progress_unless_shutdown_triggered():
|
|
|
"""Fetch progress, avoid deadlocks if DHT was shut down before this get finished."""
|
|
|
- getter = asyncio.create_task(asyncio.wait_for(self.dht.get(self.training_progress_key, latest=True, return_future=True), None))
|
|
|
+ getter = asyncio.create_task(
|
|
|
+ asyncio.wait_for(self.dht.get(self.training_progress_key, latest=True, return_future=True), None)
|
|
|
+ )
|
|
|
await asyncio.wait({getter, shutdown_checker}, return_when=asyncio.FIRST_COMPLETED)
|
|
|
if self.shutdown_triggered.is_set():
|
|
|
return
|