|
@@ -189,6 +189,7 @@ class ProgressTracker(threading.Thread):
|
|
|
async def _progress_reporter(self):
|
|
|
"""Periodically publish metadata and the current number of samples accumulated towards the next epoch"""
|
|
|
last_report_time = -float("inf")
|
|
|
+ store_task = None
|
|
|
try:
|
|
|
while not self.shutdown_triggered.is_set():
|
|
|
wait_timeout = max(0.0, last_report_time + self.metadata_expiration - get_dht_time())
|
|
@@ -203,21 +204,34 @@ class ProgressTracker(threading.Thread):
|
|
|
local_progress = self.local_progress
|
|
|
last_report_time = get_dht_time()
|
|
|
|
|
|
- await self.dht.store(
|
|
|
+ store_task = asyncio.create_task(asyncio.wait_for(
|
|
|
+ self.dht.store(
|
|
|
key=self.training_progress_key,
|
|
|
subkey=self._local_public_key,
|
|
|
value=local_progress.dict(),
|
|
|
expiration_time=last_report_time + self.metadata_expiration,
|
|
|
return_future=True,
|
|
|
- )
|
|
|
+ ), timeout=self.metadata_expiration))
|
|
|
finally:
|
|
|
logger.log(self.status_loglevel, f"No longer reporting progress for {self.prefix}.")
|
|
|
+ if store_task is not None and not store_task.done():
|
|
|
+ store_task.cancel()
|
|
|
|
|
|
async def _progress_fetcher(self):
|
|
|
"""
|
|
|
Periodically check the training progress from all peers. Trigger update after target_batch_size total samples
|
|
|
"""
|
|
|
loop = asyncio.get_event_loop()
|
|
|
+ shutdown_checker = asyncio.create_task(loop.run_in_executor(None, self.shutdown_triggered.wait))
|
|
|
+
|
|
|
async def _fetch_progress_unless_shutdown_triggered():
    """
    Fetch the latest training progress entry, giving up as soon as shutdown is triggered.

    The DHT lookup is raced against ``shutdown_checker`` (created by the enclosing method from
    ``self.shutdown_triggered.wait``) so that this coroutine does not hang if the DHT was shut
    down before the lookup finished.

    :returns: the result of ``self.dht.get`` for the training progress key,
      or None if shutdown was triggered before or while waiting
    """
    # Schedule the lookup as a Task exactly once. asyncio.wait only accepts Task/Future objects
    # (bare coroutines are rejected on Python 3.11+), and on older versions it would wrap the
    # coroutine in a hidden task, making the later `await getter` re-await a consumed coroutine.
    # NOTE(review): wait_for(..., None) presumably exists to turn the future returned by
    # dht.get(return_future=True) into a coroutine that create_task accepts -- confirm.
    getter = asyncio.create_task(
        asyncio.wait_for(self.dht.get(self.training_progress_key, latest=True, return_future=True), None)
    )
    await asyncio.wait({getter, shutdown_checker}, return_when=asyncio.FIRST_COMPLETED)
    if self.shutdown_triggered.is_set():
        # Shutdown wins even if the lookup also completed: the caller discards the result anyway.
        return
    # shutdown_checker has not fired, so the lookup task is the one that completed;
    # awaiting it returns its result (or re-raises its exception).
    return await getter
|
|
|
+
|
|
|
try:
|
|
|
while not self.shutdown_triggered.is_set():
|
|
|
time_to_next_update = max(0.0, self.global_progress.next_fetch_time - get_dht_time())
|
|
@@ -229,9 +243,12 @@ class ProgressTracker(threading.Thread):
|
|
|
continue
|
|
|
|
|
|
async with enter_asynchronously(self.lock_global_progress):
|
|
|
- progress_entry = await self.dht.get(self.training_progress_key, latest=True, return_future=True)
|
|
|
- metadata = progress_entry.value if isinstance(progress_entry, ValueWithExpiration) else None
|
|
|
+ maybe_metadata = await _fetch_progress_unless_shutdown_triggered()
|
|
|
+ if self.shutdown_triggered.is_set():
|
|
|
+ break
|
|
|
+ metadata = maybe_metadata.value if isinstance(maybe_metadata, ValueWithExpiration) else None
|
|
|
self.global_progress = self._parse_swarm_progress_data(metadata)
|
|
|
+
|
|
|
finally:
|
|
|
logger.log(self.status_loglevel, f"No longer fetching {self.training_progress_key}.")
|
|
|
|
|
@@ -307,7 +324,7 @@ class ProgressTracker(threading.Thread):
|
|
|
next_fetch_time=current_time + time_to_next_fetch,
|
|
|
)
|
|
|
|
|
|
- def shutdown(self, timeout: Optional[float]=None):
|
|
|
+ def shutdown(self, timeout: Optional[float] = None):
|
|
|
"""Permanently disable all tracking activity"""
|
|
|
self.shutdown_triggered.set()
|
|
|
self.should_report_progress.set()
|