@@ -402,10 +402,10 @@ class Optimizer(torch.optim.Optimizer):
         self.state_averager.local_epoch = self.tracker.global_epoch
         self.tracker.report_local_progress(local_epoch=self.local_epoch, samples_accumulated=0)
+        self.grad_averager.reset_accumulated_grads_()
 
         if not self.client_mode:
             self.grad_averager.state_sharing_priority = self.local_epoch
             self.state_averager.state_sharing_priority = self.local_epoch
-        self.grad_averager.reset_accumulated_grads_()
 
     def state_dict(self) -> dict:
         state_dict = self.state_averager.optimizer.state_dict()
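
Context for the change: the hunk reorders what is presumably hivemind.Optimizer.load_state_from_peers so that stale gradient accumulators are cleared right after local progress is reported, before the peer advertises its freshly loaded state via state_sharing_priority, rather than afterwards. A minimal usage sketch of the code path this hunk touches, assuming the public API shown in hivemind's quickstart (the model, DHT bootstrap, run_id, and batch sizes below are illustrative placeholders, not part of this patch):

    import torch
    import hivemind

    model = torch.nn.Linear(16, 2)  # placeholder model for illustration
    dht = hivemind.DHT(initial_peers=[], start=True)  # fresh bootstrap node

    opt = hivemind.Optimizer(
        dht=dht,
        run_id="demo_run",            # illustrative experiment name
        batch_size_per_step=32,       # samples this peer contributes per step
        target_batch_size=1024,       # global batch size that triggers averaging
        params=model.parameters(),
        optimizer=lambda params: torch.optim.Adam(params),
        verbose=True,
    )

    # A peer joining mid-run syncs with the swarm. With this patch applied, its
    # accumulated gradients are reset before the peer begins advertising the
    # newly downloaded state, so it resumes accumulation from a clean slate.
    opt.load_state_from_peers()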