justheuristic 3 years ago
parent commit a45edcb716

+ 17 - 8
hivemind/optim/experimental/optimizer.py

@@ -194,9 +194,23 @@ class Optimizer(torch.optim.Optimizer):
     def local_epoch(self) -> int:
         return self.state_averager.local_epoch

-    def should_load_state_from_peers(self, new_epoch: bool = False) -> bool:
-        """If true, peer will discard local progress and attempt to download state from peers."""
-        if new_epoch:
+    def should_load_state_from_peers(self) -> bool:
+        """
+        If true, the peer will discard local progress and attempt to download state from peers.
+        This method allows a peer to continue training in two cases:
+         - the peer is on the same epoch as other collaborators - keep training normally;
+         - the peer was on the same epoch and accumulated some grads, but some collaborators
+           have just transitioned to the next epoch - this peer should also transition.
+
+        :note: The latter case occurs due to the lack of network synchrony: the first peer that
+          detects enough samples will transition to the next epoch and start counting samples anew.
+          Other peers may take some time before they check the DHT and observe that
+           - the global epoch is now one step ahead of their local epoch, and
+           - the remaining (non-transitioned) peers no longer have target_batch_size between them.
+          In that case, the peer should transition to the next epoch but does *not* need to re-load state.
+        """
+        just_transitioned = self.grad_averager.local_samples_accumulated == 0
+        if just_transitioned:
             return self.local_epoch != self.tracker.global_epoch
         else:
             return self.local_epoch < self.tracker.global_epoch - 1
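
For reference, here is a minimal standalone sketch of the decision rule introduced in the hunk above. It is not hivemind's actual API: the free function and the example numbers are hypothetical, with the arguments standing in for self.local_epoch, self.tracker.global_epoch and self.grad_averager.local_samples_accumulated.

def should_load_state_from_peers(local_epoch: int, global_epoch: int, local_samples_accumulated: int) -> bool:
    # a fresh accumulator means the peer just transitioned (or just started):
    # any mismatch with the global epoch then means it is genuinely out of sync
    just_transitioned = local_samples_accumulated == 0
    if just_transitioned:
        return local_epoch != global_epoch
    else:
        # mid-accumulation: tolerate being exactly one epoch behind, since other
        # collaborators may have transitioned slightly earlier; reload only if
        # the peer has fallen two or more epochs behind
        return local_epoch < global_epoch - 1

# a peer that accumulated grads on epoch 4 while the swarm moved to epoch 5
# should transition locally without re-downloading state
assert should_load_state_from_peers(local_epoch=4, global_epoch=5, local_samples_accumulated=32) is False
# the same peer with an empty accumulator is considered out of sync
assert should_load_state_from_peers(local_epoch=4, global_epoch=5, local_samples_accumulated=0) is True
# a peer that is two epochs behind must reload regardless of accumulation
assert should_load_state_from_peers(local_epoch=3, global_epoch=5, local_samples_accumulated=32) is True
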
@@ -308,11 +322,6 @@ class Optimizer(torch.optim.Optimizer):
                 self.grad_averager.reset_accumulated_grads_()
                 self.tracker.update_epoch(new_epoch=self.state_averager.local_epoch)

-                if self.should_load_state_from_peers(new_epoch=True):
-                    logger.log(self.status_loglevel, "Peer ended up out of sync after averaging.")
-                    self.load_state_from_peers()
-                    return loss
-
             logger.log(self.status_loglevel, f"Optimizer step done! Beginning next epoch {self.local_epoch}.")
         return loss