@@ -125,7 +125,7 @@ class Optimizer(torch.optim.Optimizer):
         )
         self.grad_averager = self._make_gradient_averager(reuse_grad_buffers=reuse_grad_buffers, **averager_opts or {})
         self.tracker = self._make_progress_tracker(target_batch_size, **tracker_opts or {})
-        self._last_synchronized_time = get_dht_time()
+        self._should_check_synchronization_on_update = True  # used in self.should_load_state_from_peers
         self._schema_hash = self._compute_schema_hash()
         self._parent_pid = os.getpid()
@@ -209,11 +209,10 @@ class Optimizer(torch.optim.Optimizer):
          - the remaining (non-transitioned) peers no longer have target_batch_size between them
         If this is the case, peer should transition to the next epoch and does *not* need to re-load state.
         """
-        just_transitioned = self.grad_averager.local_samples_accumulated == 0
-        if just_transitioned:
-            return self.local_epoch != self.tracker.global_epoch
-        else:
-            return self.local_epoch < self.tracker.global_epoch - 1
+        if self._should_check_synchronization_on_update and self.tracker.updated_progress_this_epoch.is_set():
+            self._should_check_synchronization_on_update = False
+            return self.local_epoch != self.tracker.global_epoch  # require exact synchronization once per step
+        return self.local_epoch < self.tracker.global_epoch - 1  # catch up if a peer just switched to next epoch

     def step(
         self,
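A minimal standalone sketch of the rule introduced in this hunk, with the optimizer and tracker state collapsed into plain attributes and arguments (the SyncPolicy name and its members are illustrative, not part of hivemind's API): the strict equality check runs at most once per epoch, right after fresh progress arrives, while every other call only asks the peer to reload state if it has fallen more than one epoch behind.

class SyncPolicy:
    """Toy rendering of the new should_load_state_from_peers logic; all names here are illustrative."""

    def __init__(self):
        self.check_on_next_update = True  # plays the role of _should_check_synchronization_on_update

    def should_load_state(self, local_epoch: int, global_epoch: int, progress_just_updated: bool) -> bool:
        if self.check_on_next_update and progress_just_updated:
            self.check_on_next_update = False
            return local_epoch != global_epoch  # require exact synchronization once per step
        return local_epoch < global_epoch - 1  # otherwise, only catch up if left more than one epoch behind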
@@ -294,7 +293,7 @@ class Optimizer(torch.optim.Optimizer):
                 self.grad_averager.load_accumulators_into_averager_()

             else:
-                if self.scheduled_round is not None:
+                if self.scheduled_round is not None and not self.scheduled_round.done():
                     self.scheduled_round.cancel()
                 logger.log(self.status_loglevel, f"Skipped averaging: there are no other peers")
                 self.grad_averager.load_accumulators_into_averager_()
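The added "not self.scheduled_round.done()" guard is the usual check-before-cancel pattern for future-like objects: once a scheduled round has already finished, cancelling it cannot succeed and, depending on the future implementation, may raise or log an error. The snippet below illustrates the pattern with a stdlib future as a stand-in (it is not hivemind's own future class):

from concurrent.futures import ThreadPoolExecutor

with ThreadPoolExecutor(max_workers=1) as pool:
    future = pool.submit(sum, range(10))
    future.result()             # the scheduled work has already completed
    assert future.done()
    assert not future.cancel()  # cancelling a finished future fails (returns False)
    if future is not None and not future.done():
        future.cancel()         # only cancel rounds that are still pending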
@@ -321,8 +320,9 @@ class Optimizer(torch.optim.Optimizer):
             if not self.auxiliary:
                 self.grad_averager.reset_accumulated_grads_()
                 self.tracker.update_epoch(new_epoch=self.state_averager.local_epoch)
+                self._should_check_synchronization_on_update = True

-            logger.log(self.status_loglevel, f"Optimizer step done! Beginning next epoch {self.local_epoch}.")
+            logger.log(self.status_loglevel, f"Optimizer step done! Transitioning to epoch {self.local_epoch}.")
         return loss

     def zero_grad(self, set_to_none: bool = False):
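Re-arming _should_check_synchronization_on_update right after tracker.update_epoch(...) means the strict check from the second hunk fires exactly once per epoch transition. Reusing the toy SyncPolicy sketch from above (illustrative code, not hivemind's), the lifecycle over one epoch looks roughly like this:

policy = SyncPolicy()

# fresh progress arrives mid-epoch: the strict check runs once and the peer is exactly in sync
assert policy.should_load_state(local_epoch=5, global_epoch=5, progress_just_updated=True) is False

# later in the same epoch the peer is one epoch behind: the lenient rule tolerates it
assert policy.should_load_state(local_epoch=5, global_epoch=6, progress_just_updated=True) is False

# after update_epoch(...) the flag is re-armed, so the next progress update is checked strictly again
policy.check_on_next_update = True
assert policy.should_load_state(local_epoch=5, global_epoch=6, progress_just_updated=True) is True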