Эх сурвалжийг харах

PR attribution (from CollaborativeOptimizer)

Alexey Bukhtiyarov 3 жил өмнө
parent
commit
bf7a2084d2

+ 1 - 1
hivemind/optim/experimental/optimizer.py

@@ -402,10 +402,10 @@ class Optimizer(torch.optim.Optimizer):
                 self.state_averager.local_epoch = self.tracker.global_epoch
 
             self.tracker.report_local_progress(local_epoch=self.local_epoch, samples_accumulated=0)
+            self.grad_averager.reset_accumulated_grads_()
             if not self.client_mode:
                 self.grad_averager.state_sharing_priority = self.local_epoch
                 self.state_averager.state_sharing_priority = self.local_epoch
-            self.grad_averager.reset_accumulated_grads_()
 
     def state_dict(self) -> dict:
         state_dict = self.state_averager.optimizer.state_dict()