
load state timeout

justheuristic, 4 years ago
parent commit 6975af0b7f
1 file changed, 11 additions and 2 deletions

+ 11 - 2
hivemind/optim/collaborative.py

@@ -85,6 +85,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
     :param averaging_expiration: peer's requests for averaging will be valid for this many seconds
     :param metadata_expiration: peer's metadata (e.g. samples processed) is stored onto DHT for this many seconds
     :param averaging_timeout: if an averaging step hangs for this long, it will be cancelled.
+    :param load_state_timeout: wait for at most this many seconds before giving up on load_state_from_peers
     :param scheduler: if specified, use this scheduler to update optimizer learning rate
     :param reuse_grad_buffers: if True, use model's .grad buffers for gradient accumulation.
       This is more memory efficient, but it requires that the user does *NOT* call model/opt zero_grad at all
@@ -114,6 +115,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         performance_ema_alpha: float = 0.1,
         metadata_expiration: float = 60.0,
         averaging_timeout: Optional[float] = None,
+        load_state_timeout: float = 600.0,
         step_tolerance: int = 1,
         reuse_grad_buffers: bool = False,
         accumulate_grads_on: Optional[torch.device] = None,
@@ -137,7 +139,8 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             default_refresh_period,
         )
         self.expected_drift_peers, self.expected_drift_rate = expected_drift_peers, expected_drift_rate
-        self.averaging_timeout, self.metadata_expiration = averaging_timeout, metadata_expiration
+        self.averaging_timeout, self.load_state_timeout, self.metadata_expiration = \
+            averaging_timeout, load_state_timeout, metadata_expiration
         self._grads, self.reuse_grad_buffers, self.accumulate_grads_on = None, reuse_grad_buffers, accumulate_grads_on
         self.client_mode, self.step_tolerance = client_mode, step_tolerance
         self.status_loglevel = logging.INFO if verbose else logging.DEBUG
@@ -185,7 +188,13 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
     def load_state_from_peers(self, **kwargs):
         """Attempt to fetch the newest collaboration state from other peers"""
         with self.lock_collaboration_state:
-            self.averager.load_state_from_peers(**kwargs)
+            while True:
+                try:
+                    self.averager.load_state_from_peers(timeout=self.load_state_timeout, **kwargs)
+                    break
+                except BaseException as e:
+                    logger.exception(f"Failed to load state from peers: {e}, will retry now")
+
             self.local_samples_accumulated = self.local_steps_accumulated = 0
             self.reset_accumulated_grads_()
             self.update_scheduler()
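
For context, a minimal usage sketch of the new knob follows. Only load_state_timeout comes from this commit; the toy model, the local DHT node, and the other constructor arguments (prefix, target_batch_size, batch_size_per_step, and so on) are illustrative assumptions, not part of the change.

import torch
from hivemind import DHT
from hivemind.optim.collaborative import CollaborativeOptimizer

# Toy model and base optimizer, placeholders for illustration only
model = torch.nn.Linear(16, 4)
base_opt = torch.optim.SGD(model.parameters(), lr=0.1)

# Single local DHT node; a real run would bootstrap from existing peers
dht = DHT(start=True)

opt = CollaborativeOptimizer(
    opt=base_opt,
    dht=dht,
    prefix="demo_run",             # assumed experiment name
    target_batch_size=4096,        # assumed collaboration-wide batch size
    batch_size_per_step=32,        # assumed per-worker batch size
    load_state_timeout=600.0,      # new in this commit: cap each state-download attempt at 10 minutes
    verbose=True,
)

# A worker that joins late (or falls behind) downloads the latest state from peers.
# With this change, each attempt is bounded by load_state_timeout; on timeout or any
# other failure the optimizer logs the error and retries instead of hanging forever
# on a single unresponsive peer. (Requires other peers that are serving state.)
opt.load_state_from_peers()

Note that because the retry loop catches BaseException, a worker keeps trying until a download succeeds; load_state_timeout only bounds each individual attempt, not the total wait.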