Explorar el Código

Minor tweaks learned from the demo experiment (#422)

- if matchmaking encountered ControlFailure (dial failed), it would catch an error and restart completely. A more preferable behavior is to treat it as any other network errror
- awaiting for delayed updates to finish will no longer freeze forever if underlying update hanged
- rollback default max_refresh_timeout to 10 seconds

Co-authored-by: Aleksandr Borzunov <hxrussia@gmail.com>
justheuristic hace 3 años
padre
commit
896885a9f0

+ 3 - 2
hivemind/averaging/matchmaking.py

@@ -16,6 +16,7 @@ from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.key_manager import GroupKey, GroupKeyManager
 from hivemind.averaging.key_manager import GroupKey, GroupKeyManager
 from hivemind.dht import DHT, DHTID, DHTExpiration
 from hivemind.dht import DHT, DHTID, DHTExpiration
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
+from hivemind.p2p.p2p_daemon_bindings.utils import ControlFailure, DispatchFailure
 from hivemind.proto import averaging_pb2
 from hivemind.proto import averaging_pb2
 from hivemind.utils import TimedStorage, get_dht_time, get_logger, timed_storage
 from hivemind.utils import TimedStorage, get_dht_time, get_logger, timed_storage
 from hivemind.utils.asyncio import anext, cancel_and_wait
 from hivemind.utils.asyncio import anext, cancel_and_wait
@@ -240,8 +241,8 @@ class Matchmaking:
         except asyncio.TimeoutError:
         except asyncio.TimeoutError:
             logger.debug(f"{self} - potential leader {leader} did not respond within {self.request_timeout}")
             logger.debug(f"{self} - potential leader {leader} did not respond within {self.request_timeout}")
             return None
             return None
-        except (P2PHandlerError, StopAsyncIteration) as e:
-            logger.exception(f"{self} - failed to request potential leader {leader}:")
+        except (P2PHandlerError, ControlFailure, DispatchFailure, StopAsyncIteration) as e:
+            logger.debug(f"{self} - failed to request potential leader {leader}:")
             return None
             return None
 
 
         finally:
         finally:

+ 5 - 1
hivemind/optim/experimental/progress_tracker.py

@@ -83,7 +83,7 @@ class ProgressTracker(threading.Thread):
         *,
         *,
         client_mode: Optional[bool] = None,
         client_mode: Optional[bool] = None,
         min_refresh_period: float = 0.5,
         min_refresh_period: float = 0.5,
-        max_refresh_period: float = 30,
+        max_refresh_period: float = 10,
         default_refresh_period: float = 3,
         default_refresh_period: float = 3,
         expected_drift_peers: float = 3,
         expected_drift_peers: float = 3,
         expected_drift_rate: float = 0.2,
         expected_drift_rate: float = 0.2,
@@ -352,3 +352,7 @@ class ProgressTracker(threading.Thread):
             expiration_time=get_dht_time() + self.metadata_expiration,
             expiration_time=get_dht_time() + self.metadata_expiration,
             return_future=True,
             return_future=True,
         )
         )
+
+    def __del__(self):
+        if self.is_alive():
+            self.shutdown()

+ 7 - 4
hivemind/optim/experimental/state_averager.py

@@ -376,17 +376,20 @@ class TrainingStateAverager(DecentralizedAverager):
         if wait_for_delayed_updates:
         if wait_for_delayed_updates:
             for pending_update in self.pending_updates:
             for pending_update in self.pending_updates:
                 try:
                 try:
+                    timeout = (averaging_opts or {}).get("averaging_timeout", self._allreduce_timeout)
                     logger.log(self.status_loglevel, "Waiting for delayed updates to finish...")
                     logger.log(self.status_loglevel, "Waiting for delayed updates to finish...")
-                    output = pending_update.result()
+                    output = pending_update.result(timeout)
                 except BaseException:
                 except BaseException:
-                    pass  # exception will be reported below
+                    # exception will be reported below
+                    if not pending_update.done():
+                        pending_update.cancel()
 
 
         # remove finished updates, log any exceptions
         # remove finished updates, log any exceptions
         finished_updates = {pending_update for pending_update in self.pending_updates if pending_update.done()}
         finished_updates = {pending_update for pending_update in self.pending_updates if pending_update.done()}
         self.pending_updates = {pending_update for pending_update in self.pending_updates if not pending_update.done()}
         self.pending_updates = {pending_update for pending_update in self.pending_updates if not pending_update.done()}
         for finished_update in finished_updates:
         for finished_update in finished_updates:
-            if finished_update.exception():
-                logger.log(self.status_loglevel, f"Background update failed with {finished_update.exception()}")
+            if finished_update.cancelled() or finished_update.exception():
+                logger.log(self.status_loglevel, f"Background update failed: {finished_update}")
 
 
         if apply_delayed_updates:
         if apply_delayed_updates:
             if self.finished_averaging_round.is_set():
             if self.finished_averaging_round.is_set():