Browse Source

ignore_stale_updates

justheuristic 3 years ago
parent
commit
821f9ea41f
1 changed files with 18 additions and 9 deletions
  1. 18 9
      hivemind/optim/collaborative.py

+ 18 - 9
hivemind/optim/collaborative.py

@@ -80,6 +80,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
     :note: The expected collaboration drift parameters are used to adjust the frequency with which this optimizer will
       refresh the collaboration-wide statistics (to avoid missing the moment when to run the next step)
     :param bandwidth: peer's network bandwidth for the purpose of load balancing (recommended: internet speed in mbps)
+    :param step_tolerance: a peer can temporarily be delayed by this many steps without being deemed out of sync
     :param performance_ema_alpha: smoothing value used to estimate this peer's performance (training samples per second)
     :param averaging_expiration: peer's requests for averaging will be valid for this many seconds
     :param metadata_expiration: peer's metadata (e.g. samples processed) is stored onto DHT for this many seconds
@@ -115,6 +116,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         metadata_expiration: float = 60.0,
         averaging_timeout: Optional[float] = None,
         load_state_timeout: float = 600.0,
+        step_tolerance: int = 1,
         reuse_grad_buffers: bool = False,
         accumulate_grads_on: Optional[torch.device] = None,
         client_mode: bool = False,
@@ -141,7 +143,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         self.load_state_timeout = load_state_timeout
         self.metadata_expiration = metadata_expiration
         self._grads, self.reuse_grad_buffers, self.accumulate_grads_on = None, reuse_grad_buffers, accumulate_grads_on
-        self.client_mode = client_mode
+        self.client_mode, self.step_tolerance = client_mode, step_tolerance
         self.status_loglevel = logging.INFO if verbose else logging.DEBUG
         self.averager = self._make_averager(**kwargs)
 
@@ -177,10 +179,18 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
     def local_step(self) -> int:
         return self.averager.local_step
 
+    @local_step.setter
+    def local_step(self, new_value: int):
+        self.averager.local_step = new_value
+
     @property
     def is_synchronized(self) -> bool:
         return self.local_step >= self.collaboration_state.optimizer_step
 
+    @property
+    def is_within_tolerance(self) -> bool:
+        return self.local_step >= self.collaboration_state.optimizer_step - self.step_tolerance
+
     def is_alive(self) -> bool:
         return self.averager.is_alive()
 
@@ -213,10 +223,13 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             self.batch_size_per_step = batch_size
         batch_size = batch_size if batch_size is not None else self.batch_size_per_step
 
-        if not self.is_synchronized:
+        if not self.is_synchronized and not self.is_within_tolerance:
             logger.log(self.status_loglevel, "Peer is out of sync.")
             self.load_state_from_peers()
             return
+        elif not self.is_synchronized and self.is_within_tolerance:
+            self.averager.local_step = self.collaboration_state.optimizer_step
+            logger.log(self.status_loglevel, f"Catching up with collaboration step {self.local_step}.")
 
         if self.last_step_time is not None and get_dht_time() - self.last_step_time > self.metadata_expiration:
             logger.warning(
@@ -236,14 +249,10 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             return
 
         logger.log(self.status_loglevel, f"Beginning global optimizer step #{self.collaboration_state.optimizer_step}")
-        self.collaboration_state = self._fetch_state()
-        self.collaboration_state_updated.set()
-
-        if not self.is_synchronized:
-            self.load_state_from_peers()
-            return
-
         with self.performance_ema.pause(), self.lock_collaboration_state:
+            self.collaboration_state = self._fetch_state()
+            self.collaboration_state_updated.set()
+
             # divide accumulators by local steps to recover the true average grad w.r.t. local_samples_accumulated
             self.apply_accumulated_grads_(scale_by=1.0 / self.local_steps_accumulated)
             current_step, group_info = self.averager.local_step, None