浏览代码

PR attribution (from CollaborativeOptimizer)

Alexey Bukhtiyarov 3 年之前
父节点
当前提交
bf7a2084d2
共有 1 个文件被更改,包括 1 次插入、1 次删除
  1. 1 1
      hivemind/optim/experimental/optimizer.py

+ 1 - 1
hivemind/optim/experimental/optimizer.py

@@ -402,10 +402,10 @@ class Optimizer(torch.optim.Optimizer):
                 self.state_averager.local_epoch = self.tracker.global_epoch
 
             self.tracker.report_local_progress(local_epoch=self.local_epoch, samples_accumulated=0)
+            self.grad_averager.reset_accumulated_grads_()
             if not self.client_mode:
                 self.grad_averager.state_sharing_priority = self.local_epoch
                 self.state_averager.state_sharing_priority = self.local_epoch
-            self.grad_averager.reset_accumulated_grads_()
 
     def state_dict(self) -> dict:
         state_dict = self.state_averager.optimizer.state_dict()