Browse Source

hopefully handle shutdown correctly

justheuristic 3 years ago
parent
commit
e46d53b15e
1 changed files with 22 additions and 5 deletions
  1. 22 5
      hivemind/optim/experimental/progress_tracker.py

+ 22 - 5
hivemind/optim/experimental/progress_tracker.py

@@ -189,6 +189,7 @@ class ProgressTracker(threading.Thread):
     async def _progress_reporter(self):
         """Periodically publish metadata and the current number of samples accumulated towards the next epoch"""
         last_report_time = -float("inf")
+        store_task = None
         try:
             while not self.shutdown_triggered.is_set():
                 wait_timeout = max(0.0, last_report_time + self.metadata_expiration - get_dht_time())
@@ -203,21 +204,34 @@ class ProgressTracker(threading.Thread):
                 local_progress = self.local_progress
                 last_report_time = get_dht_time()
 
-                await self.dht.store(
+                store_task = asyncio.create_task(asyncio.wait_for(
+                    self.dht.store(
                     key=self.training_progress_key,
                     subkey=self._local_public_key,
                     value=local_progress.dict(),
                     expiration_time=last_report_time + self.metadata_expiration,
                     return_future=True,
-                )
+                ), timeout=self.metadata_expiration))
         finally:
             logger.log(self.status_loglevel, f"No longer reporting progress for {self.prefix}.")
+            if store_task is not None and not store_task.done():
+                store_task.cancel()
 
     async def _progress_fetcher(self):
         """
         Periodically check the training progress from all peers. Trigger update after target_batch_size total samples
         """
         loop = asyncio.get_event_loop()
+        shutdown_checker = asyncio.create_task(loop.run_in_executor(None, self.shutdown_triggered.wait))
+
+        async def _fetch_progress_unless_shutdown_triggered():
+            """Fetch progress, avoid deadlocks if DHT was shut down before this get finished."""
+            getter = asyncio.wait_for(self.dht.get(self.training_progress_key, latest=True, return_future=True), None)
+            await asyncio.wait({getter, shutdown_checker}, return_when=asyncio.FIRST_COMPLETED)
+            if self.shutdown_triggered.is_set():
+                return
+            return await getter
+
         try:
             while not self.shutdown_triggered.is_set():
                 time_to_next_update = max(0.0, self.global_progress.next_fetch_time - get_dht_time())
@@ -229,9 +243,12 @@ class ProgressTracker(threading.Thread):
                     continue
 
                 async with enter_asynchronously(self.lock_global_progress):
-                    progress_entry = await self.dht.get(self.training_progress_key, latest=True, return_future=True)
-                    metadata = progress_entry.value if isinstance(progress_entry, ValueWithExpiration) else None
+                    maybe_metadata = await _fetch_progress_unless_shutdown_triggered()
+                    if self.shutdown_triggered.is_set():
+                        break
+                    metadata = maybe_metadata.value if isinstance(maybe_metadata, ValueWithExpiration) else None
                     self.global_progress = self._parse_swarm_progress_data(metadata)
+
         finally:
             logger.log(self.status_loglevel, f"No longer fetching {self.training_progress_key}.")
 
@@ -307,7 +324,7 @@ class ProgressTracker(threading.Thread):
             next_fetch_time=current_time + time_to_next_fetch,
         )
 
-    def shutdown(self, timeout: Optional[float]=None):
+    def shutdown(self, timeout: Optional[float] = None):
         """Permanently disable all tracking activity"""
         self.shutdown_triggered.set()
         self.should_report_progress.set()