Jelajahi Sumber

fix a bug that incorrectly accounted for step tolerance in CollaborativeOptimizer

justheuristic 3 tahun lalu
induk
melakukan
9f5fc866c8
1 mengubah file dengan 18 tambahan dan 3 penghapusan
  1. 18 3
      hivemind/optim/collaborative.py

+ 18 - 3
hivemind/optim/collaborative.py

@@ -257,6 +257,14 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                     group_info = self.averager.step(weight=weight, timeout=self.averaging_timeout, **kwargs)
                     if group_info:
                         logger.log(self.status_loglevel, f"Averaged tensors successfully with {len(group_info)} peers")
+
+                        # update our current step if we averaged with another peer that was at a more recent step
+                        for peer, peer_step in group_info.items():
+                            if isinstance(peer_step, int):
+                                current_step = max(current_step, peer_step)
+                            else:
+                                logger.warning(f"Peer {peer} sent malformed data about current step: {peer_step}")
+
                 except BaseException as e:
                     logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}.")
 
@@ -293,12 +301,19 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         self.collaboration_state_updated.set()
 
         with self.lock_collaboration_state:
-            # divide accumulators by local steps to recover the true average grad w.r.t. local_samples_accumulated
             current_step, group_info = self.averager.local_step, None
+
             try:
-                group_info = self.averager.step(timeout=self.averaging_timeout, **kwargs)
+                group_info = self.averager.step(timeout=self.averaging_timeout, gather=current_step, **kwargs)
                 if group_info:
                     logger.log(self.status_loglevel, f"Averaged tensors successfully with {len(group_info)} peers")
+
+                    # update our current step if we averaged with another peer that was at a more recent step
+                    for peer, peer_step in group_info.items():
+                        if isinstance(peer_step, int):
+                            current_step = max(current_step, peer_step)
+                        else:
+                            logger.warning(f"Peer {peer} sent malformed data about current step: {peer_step}")
             except BaseException as e:
                 logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}.")
 
@@ -431,7 +446,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
         for state in valid_peer_states:
             total_samples_per_second += state.samples_per_second
-            if state.step == global_optimizer_step:
+            if state.step >= global_optimizer_step - self.step_tolerance:
                 total_samples_accumulated += state.samples_accumulated
                 estimated_current_samples += (
                     state.samples_accumulated + max(0, current_time - state.time) * state.samples_per_second