Ver código-fonte

found on calm

justheuristic 3 anos atrás
pai
commit
10118ab370
1 arquivo alterado com 17 adições e 12 exclusões
  1. 17 12
      hivemind/optim/experimental/progress_tracker.py

+ 17 - 12
hivemind/optim/experimental/progress_tracker.py

@@ -195,6 +195,7 @@ class ProgressTracker(threading.Thread):
     async def _progress_reporter(self):
         """Periodically publish metadata and the current number of samples accumulated towards the next epoch"""
         last_report_time = -float("inf")
+        last_report_epoch = -float("inf")
         store_task = None
         try:
             while not self.shutdown_triggered.is_set():
@@ -209,19 +210,23 @@ class ProgressTracker(threading.Thread):
 
                 local_progress = self.local_progress
                 last_report_time = get_dht_time()
-
-                store_task = asyncio.create_task(
-                    asyncio.wait_for(
-                        self.dht.store(
-                            key=self.training_progress_key,
-                            subkey=self._local_public_key,
-                            value=local_progress.dict(),
-                            expiration_time=last_report_time + self.metadata_expiration,
-                            return_future=True,
-                        ),
-                        timeout=self.metadata_expiration,
+                if local_progress.samples_accumulated > 0:
+                    last_report_epoch = self.global_epoch
+
+                if last_report_epoch >= self.global_epoch - 1:
+                    # report progress if peer is synchronized and actively reporting samples. Do not report aux peers.
+                    store_task = asyncio.create_task(
+                        asyncio.wait_for(
+                            self.dht.store(
+                                key=self.training_progress_key,
+                                subkey=self._local_public_key,
+                                value=local_progress.dict(),
+                                expiration_time=last_report_time + self.metadata_expiration,
+                                return_future=True,
+                            ),
+                            timeout=self.metadata_expiration,
+                        )
                     )
-                )
         finally:
             logger.log(self.status_loglevel, f"No longer reporting progress for {self.prefix}")
             if store_task is not None: