|
@@ -338,6 +338,10 @@ class Optimizer(torch.optim.Optimizer):
|
|
|
self.tracker.update_epoch(new_epoch=self.state_averager.local_epoch)
|
|
|
self._should_check_synchronization_on_update = True
|
|
|
|
|
|
+ if not self.client_mode:
|
|
|
+ self.grad_averager.state_sharing_priority = self.local_epoch
|
|
|
+ self.state_averager.state_sharing_priority = self.local_epoch
|
|
|
+
|
|
|
logger.log(self.status_loglevel, f"Transitioning to epoch {self.local_epoch}.")
|
|
|
return loss
|
|
|
|
|
@@ -398,6 +402,9 @@ class Optimizer(torch.optim.Optimizer):
|
|
|
self.state_averager.local_epoch = self.tracker.global_epoch
|
|
|
|
|
|
self.tracker.report_local_progress(local_epoch=self.local_epoch, samples_accumulated=0)
|
|
|
+ if not self.client_mode:
|
|
|
+ self.grad_averager.state_sharing_priority = self.local_epoch
|
|
|
+ self.state_averager.state_sharing_priority = self.local_epoch
|
|
|
self.grad_averager.reset_accumulated_grads_()
|
|
|
|
|
|
def state_dict(self) -> dict:
|