check for exact synchronization once per step

justheuristic, 3 years ago
parent commit e924917889

+ 6 - 6
hivemind/optim/experimental/grad_averager.py

@@ -170,6 +170,12 @@ class GradientAverager(DecentralizedAverager):
         elif len(kwargs) > 0:
             raise RuntimeError(f"Averaging with a pre-scheduled group, parameters {kwargs} will have no effect.")
         assert not control.triggered, f"This {type(control)} instance was already used."
+        if self._new_averaged_grads and self.warn:
+            logger.warning(
+                "[warn=True] Starting new averaging round, but previous round results were not used. "
+                "This may be a sign of incorrect optimizer behavior."
+            )
+
         self.load_accumulators_into_averager_()
         self._accumulators_used_in_step = True
         self._new_averaged_grads = True
@@ -184,12 +190,6 @@ class GradientAverager(DecentralizedAverager):
     @torch.no_grad()
     def load_accumulators_into_averager_(self):
         """load locally accumulated gradients into the averager for aggregation"""
-        if self._new_averaged_grads and self.warn:
-            logger.warning(
-                "[warn=True] Starting new averaging round, but previous round results were not used."
-                "This may be a sign of incorrect optimizer behavior."
-            )
-            self._new_averaged_grads = False  # warn once per round
         # divide locally accumulated gradients by the number of times they were accumulated
         grad_scale = (1.0 / self.local_times_accumulated) if self.local_times_accumulated != 0 else 0.0
         with self.get_tensors() as averaged_grads:
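
The effect of this move is that the "previous round results were not used" warning now fires when a new averaging round begins, not inside load_accumulators_into_averager_(), so direct calls to the loader (such as the no-peers fallback in optimizer.py below) stay silent. A minimal sketch of the relocated check; the enclosing method and the name begin_averaging_round are assumptions, since the hunk header does not show them:

import logging

logger = logging.getLogger(__name__)


class _UnusedRoundWarningSketch:
    """Simplified stand-in for GradientAverager: only the relocated warning logic is modeled."""

    def __init__(self, warn: bool = True):
        self.warn = warn
        self._new_averaged_grads = False  # True while the last round's averaged grads are still unused

    def begin_averaging_round(self):
        # hypothetical entry point; in the real class the added lines sit in the method
        # elided by the hunk header, just before load_accumulators_into_averager_()
        if self._new_averaged_grads and self.warn:
            logger.warning("Starting new averaging round, but previous round results were not used.")
        self.load_accumulators_into_averager_()
        self._new_averaged_grads = True

    def load_accumulators_into_averager_(self):
        """Direct calls (e.g. the no-peers fallback path) no longer emit the warning."""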

+ 8 - 8
hivemind/optim/experimental/optimizer.py

@@ -125,7 +125,7 @@ class Optimizer(torch.optim.Optimizer):
         )
         self.grad_averager = self._make_gradient_averager(reuse_grad_buffers=reuse_grad_buffers, **averager_opts or {})
         self.tracker = self._make_progress_tracker(target_batch_size, **tracker_opts or {})
-        self._last_synchronized_time = get_dht_time()
+        self._should_check_synchronization_on_update = True  # used in self.should_load_state_from_peers
         self._schema_hash = self._compute_schema_hash()
         self._parent_pid = os.getpid()
 
@@ -209,11 +209,10 @@ class Optimizer(torch.optim.Optimizer):
           - the remaining (non-transitioned) peers no longer have target_batch_size between them
         If this is the case, peer should transition to the next epoch and does *not* need to re-load state.
         """
-        just_transitioned = self.grad_averager.local_samples_accumulated == 0
-        if just_transitioned:
-            return self.local_epoch != self.tracker.global_epoch
-        else:
-            return self.local_epoch < self.tracker.global_epoch - 1
+        if self._should_check_synchronization_on_update and self.tracker.updated_progress_this_epoch.is_set():
+            self._should_check_synchronization_on_update = False
+            return self.local_epoch != self.tracker.global_epoch  # require exact synchronization once per step
+        return self.local_epoch < self.tracker.global_epoch - 1  # catch up if a peer just switched to next epoch
 
     def step(
         self,
@@ -294,7 +293,7 @@ class Optimizer(torch.optim.Optimizer):
                     self.grad_averager.load_accumulators_into_averager_()
 
             else:
-                if self.scheduled_round is not None:
+                if self.scheduled_round is not None and not self.scheduled_round.done():
                     self.scheduled_round.cancel()
                 logger.log(self.status_loglevel, f"Skipped averaging: there are no other peers")
                 self.grad_averager.load_accumulators_into_averager_()
@@ -321,8 +320,9 @@ class Optimizer(torch.optim.Optimizer):
             if not self.auxiliary:
                 self.grad_averager.reset_accumulated_grads_()
                 self.tracker.update_epoch(new_epoch=self.state_averager.local_epoch)
+                self._should_check_synchronization_on_update = True
 
-            logger.log(self.status_loglevel, f"Optimizer step done! Beginning next epoch {self.local_epoch}.")
+            logger.log(self.status_loglevel, f"Optimizer step done! Transitioning to epoch {self.local_epoch}.")
         return loss
 
     def zero_grad(self, set_to_none: bool = False):
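
The rewritten should_load_state_from_peers replaces the samples-accumulated heuristic with a per-step exact check: the first time fresh global progress arrives after an epoch transition, the peer must match the global epoch exactly; after that it only reloads state if it has fallen more than one epoch behind. A self-contained sketch of that logic, with attribute names taken from the diff and the surrounding Optimizer machinery omitted:

import threading


class _SyncCheckSketch:
    """Stand-in for the Optimizer attributes involved in should_load_state_from_peers."""

    def __init__(self):
        self.local_epoch = 0
        self.global_epoch = 0  # stands in for self.tracker.global_epoch
        self.updated_progress_this_epoch = threading.Event()  # stands in for the tracker's new event
        self._should_check_synchronization_on_update = True  # re-armed after every successful step

    def should_load_state_from_peers(self) -> bool:
        if self._should_check_synchronization_on_update and self.updated_progress_this_epoch.is_set():
            self._should_check_synchronization_on_update = False  # exact check runs at most once per step
            return self.local_epoch != self.global_epoch
        return self.local_epoch < self.global_epoch - 1  # otherwise reload only if >1 epoch behind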

+ 3 - 1
hivemind/optim/experimental/progress_tracker.py

@@ -114,7 +114,7 @@ class ProgressTracker(threading.Thread):
         metadata, _expiration = self.dht.get(self.training_progress_key, latest=True) or (None, -float("inf"))
         self.global_progress = self._parse_swarm_progress_data(metadata)
         self.lock_global_progress, self.global_state_updated = threading.Lock(), threading.Event()
-        self.should_report_progress = threading.Event()
+        self.should_report_progress, self.updated_progress_this_epoch = threading.Event(), threading.Event()
         self.shutdown_triggered, self.shutdown_complete = threading.Event(), threading.Event()
         super().__init__(name=f"{self.__class__.__name__}({self.prefix})", daemon=daemon)
         if start:
@@ -178,6 +178,7 @@ class ProgressTracker(threading.Thread):
             self.global_progress.samples_accumulated = 0
             self.global_progress.eta_next_epoch = float("inf")
         self.report_local_progress(new_epoch, samples_accumulated=0)
+        self.updated_progress_this_epoch.clear()
         return new_epoch
 
     def run(self):
@@ -252,6 +253,7 @@ class ProgressTracker(threading.Thread):
                         break
                     metadata = maybe_metadata.value if isinstance(maybe_metadata, ValueWithExpiration) else None
                     self.global_progress = self._parse_swarm_progress_data(metadata)
+                    self.updated_progress_this_epoch.set()
 
         finally:
             logger.log(self.status_loglevel, f"No longer fetching {self.training_progress_key}.")
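
The new updated_progress_this_epoch event is what arms the exact-synchronization check above: it is cleared when the tracker reports a new epoch and set once fresh global progress has been fetched from the DHT. A reduced sketch of that lifecycle, assuming a heavily simplified ProgressTracker with the DHT fetch stubbed out:

import threading


class _ProgressEventSketch:
    """Reduced ProgressTracker: only the lifecycle of updated_progress_this_epoch is shown."""

    def __init__(self):
        self.updated_progress_this_epoch = threading.Event()

    def update_epoch(self, new_epoch: int) -> int:
        # after reporting local progress for the new epoch, forget any stale global progress
        self.updated_progress_this_epoch.clear()
        return new_epoch

    def on_global_progress_fetched(self, metadata) -> None:
        # in the real run() loop this follows self._parse_swarm_progress_data(metadata)
        self.updated_progress_this_epoch.set()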