|
@@ -195,6 +195,7 @@ class ProgressTracker(threading.Thread):
|
|
|
async def _progress_reporter(self):
|
|
|
"""Periodically publish metadata and the current number of samples accumulated towards the next epoch"""
|
|
|
last_report_time = -float("inf")
|
|
|
+ last_report_epoch = -float("inf")
|
|
|
store_task = None
|
|
|
try:
|
|
|
while not self.shutdown_triggered.is_set():
|
|
@@ -209,19 +210,23 @@ class ProgressTracker(threading.Thread):
|
|
|
|
|
|
local_progress = self.local_progress
|
|
|
last_report_time = get_dht_time()
|
|
|
-
|
|
|
- store_task = asyncio.create_task(
|
|
|
- asyncio.wait_for(
|
|
|
- self.dht.store(
|
|
|
- key=self.training_progress_key,
|
|
|
- subkey=self._local_public_key,
|
|
|
- value=local_progress.dict(),
|
|
|
- expiration_time=last_report_time + self.metadata_expiration,
|
|
|
- return_future=True,
|
|
|
- ),
|
|
|
- timeout=self.metadata_expiration,
|
|
|
+ if local_progress.samples_accumulated > 0:
|
|
|
+ last_report_epoch = self.global_epoch
|
|
|
+
|
|
|
+ if last_report_epoch >= self.global_epoch - 1:
|
|
|
+ # report progress if peer is synchronized and actively reporting samples. Do not report aux peers.
|
|
|
+ store_task = asyncio.create_task(
|
|
|
+ asyncio.wait_for(
|
|
|
+ self.dht.store(
|
|
|
+ key=self.training_progress_key,
|
|
|
+ subkey=self._local_public_key,
|
|
|
+ value=local_progress.dict(),
|
|
|
+ expiration_time=last_report_time + self.metadata_expiration,
|
|
|
+ return_future=True,
|
|
|
+ ),
|
|
|
+ timeout=self.metadata_expiration,
|
|
|
+ )
|
|
|
)
|
|
|
- )
|
|
|
finally:
|
|
|
logger.log(self.status_loglevel, f"No longer reporting progress for {self.prefix}")
|
|
|
if store_task is not None:
|