|
@@ -194,9 +194,23 @@ class Optimizer(torch.optim.Optimizer):
|
|
def local_epoch(self) -> int:
|
|
def local_epoch(self) -> int:
|
|
return self.state_averager.local_epoch
|
|
return self.state_averager.local_epoch
|
|
|
|
|
|
- def should_load_state_from_peers(self, new_epoch: bool = False) -> bool:
|
|
|
|
- """If true, peer will discard local progress and attempt to download state from peers."""
|
|
|
|
- if new_epoch:
|
|
|
|
|
|
+ def should_load_state_from_peers(self) -> bool:
|
|
|
|
+ """
|
|
|
|
+ If true, peer will discard local progress and attempt to download state from peers.
|
|
|
|
+ This method allows a peer to continue training in two cases:
|
|
|
|
+ - peer is on the same epoch as other collaborators - keep training normally
|
|
|
|
+ - peer was on the same epoch and accumulated some grads, but some collaborators
|
|
|
|
+ have just transitioned to the next epoch - this peer should also transition.
|
|
|
|
+
|
|
|
|
+ :note: The latter case occurs due to the lack of network synchrony: the first peer that
|
|
|
|
+ detects enough samples will transition to the next step and start counting samples anew.
|
|
|
|
+ Some other peers may take time before they check with the DHT and observe that
|
|
|
|
+ - the global epoch is technically one epoch ahead of the current one and
|
|
|
|
+ - the remaining (non-transitioned) peers no longer have target_batch_size between them
|
|
|
|
+ If this is the case, the peer should transition to the next epoch and does *not* need to re-load state.
|
|
|
|
+ """
|
|
|
|
+ just_transitioned = self.grad_averager.local_samples_accumulated == 0
|
|
|
|
+ if just_transitioned:
|
|
return self.local_epoch != self.tracker.global_epoch
|
|
return self.local_epoch != self.tracker.global_epoch
|
|
else:
|
|
else:
|
|
return self.local_epoch < self.tracker.global_epoch - 1
|
|
return self.local_epoch < self.tracker.global_epoch - 1
|
|
@@ -308,11 +322,6 @@ class Optimizer(torch.optim.Optimizer):
|
|
self.grad_averager.reset_accumulated_grads_()
|
|
self.grad_averager.reset_accumulated_grads_()
|
|
self.tracker.update_epoch(new_epoch=self.state_averager.local_epoch)
|
|
self.tracker.update_epoch(new_epoch=self.state_averager.local_epoch)
|
|
|
|
|
|
- if self.should_load_state_from_peers(new_epoch=True):
|
|
|
|
- logger.log(self.status_loglevel, "Peer ended up out of sync after averaging.")
|
|
|
|
- self.load_state_from_peers()
|
|
|
|
- return loss
|
|
|
|
-
|
|
|
|
logger.log(self.status_loglevel, f"Optimizer step done! Beginning next epoch {self.local_epoch}.")
|
|
logger.log(self.status_loglevel, f"Optimizer step done! Beginning next epoch {self.local_epoch}.")
|
|
return loss
|
|
return loss
|
|
|
|
|