Forráskód Böngészése

Minor tweaks learned from the demo experiment (#422)

- if matchmaking encountered ControlFailure (dial failed), it would catch an error and restart completely. The preferable behavior is to treat it like any other network error
- waiting for delayed updates to finish will no longer freeze forever if the underlying update hangs
- roll back the default max_refresh_period to 10 seconds

Co-authored-by: Aleksandr Borzunov <hxrussia@gmail.com>
justheuristic 3 éve
szülő
commit
896885a9f0

+ 3 - 2
hivemind/averaging/matchmaking.py

@@ -16,6 +16,7 @@ from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.key_manager import GroupKey, GroupKeyManager
 from hivemind.dht import DHT, DHTID, DHTExpiration
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
+from hivemind.p2p.p2p_daemon_bindings.utils import ControlFailure, DispatchFailure
 from hivemind.proto import averaging_pb2
 from hivemind.utils import TimedStorage, get_dht_time, get_logger, timed_storage
 from hivemind.utils.asyncio import anext, cancel_and_wait
@@ -240,8 +241,8 @@ class Matchmaking:
         except asyncio.TimeoutError:
             logger.debug(f"{self} - potential leader {leader} did not respond within {self.request_timeout}")
             return None
-        except (P2PHandlerError, StopAsyncIteration) as e:
-            logger.exception(f"{self} - failed to request potential leader {leader}:")
+        except (P2PHandlerError, ControlFailure, DispatchFailure, StopAsyncIteration) as e:
+            logger.debug(f"{self} - failed to request potential leader {leader}:")
             return None
 
         finally:

+ 5 - 1
hivemind/optim/experimental/progress_tracker.py

@@ -83,7 +83,7 @@ class ProgressTracker(threading.Thread):
         *,
         client_mode: Optional[bool] = None,
         min_refresh_period: float = 0.5,
-        max_refresh_period: float = 30,
+        max_refresh_period: float = 10,
         default_refresh_period: float = 3,
         expected_drift_peers: float = 3,
         expected_drift_rate: float = 0.2,
@@ -352,3 +352,7 @@ class ProgressTracker(threading.Thread):
             expiration_time=get_dht_time() + self.metadata_expiration,
             return_future=True,
         )
+
+    def __del__(self):
+        if self.is_alive():
+            self.shutdown()

+ 7 - 4
hivemind/optim/experimental/state_averager.py

@@ -376,17 +376,20 @@ class TrainingStateAverager(DecentralizedAverager):
         if wait_for_delayed_updates:
             for pending_update in self.pending_updates:
                 try:
+                    timeout = (averaging_opts or {}).get("averaging_timeout", self._allreduce_timeout)
                     logger.log(self.status_loglevel, "Waiting for delayed updates to finish...")
-                    output = pending_update.result()
+                    output = pending_update.result(timeout)
                 except BaseException:
-                    pass  # exception will be reported below
+                    # exception will be reported below
+                    if not pending_update.done():
+                        pending_update.cancel()
 
         # remove finished updates, log any exceptions
         finished_updates = {pending_update for pending_update in self.pending_updates if pending_update.done()}
         self.pending_updates = {pending_update for pending_update in self.pending_updates if not pending_update.done()}
         for finished_update in finished_updates:
-            if finished_update.exception():
-                logger.log(self.status_loglevel, f"Background update failed with {finished_update.exception()}")
+            if finished_update.cancelled() or finished_update.exception():
+                logger.log(self.status_loglevel, f"Background update failed: {finished_update}")
 
         if apply_delayed_updates:
             if self.finished_averaging_round.is_set():