
Reset gradient buffers when synchronizing with peers (#222)

justheuristic 4 years ago
parent
commit
32b87bf3fe
1 changed file with 5 additions and 6 deletions
  1. hivemind/client/optim/collaborative.py  +5 -6

hivemind/client/optim/collaborative.py  +5 -6

@@ -131,8 +131,8 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         with self.lock_collaboration_state:
             self.averager.load_state_from_peers(**kwargs)
             self.local_samples_accumulated = self.local_steps_accumulated = 0
+            self.reset_accumulated_grads_()
             self.update_scheduler()
-            self.opt.zero_grad()
 
     def step(self, batch_size: Optional[int] = None, **kwargs):
         """
@@ -189,8 +189,6 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                 self.averager.local_step += 1
 
             self.opt.step()
-            if self.reuse_grad_buffers:
-                self.opt.zero_grad()
             self.reset_accumulated_grads_()
             self.local_samples_accumulated = self.local_steps_accumulated = 0
             self.collaboration_state.register_step()
@@ -240,9 +238,10 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
     @torch.no_grad()
     def reset_accumulated_grads_(self):
         if self.reuse_grad_buffers:
-            return
-        for grad_buf in self.accumulated_grads():
-            grad_buf.zero_()
+            self.opt.zero_grad()
+        else:
+            for grad_buf in self.accumulated_grads():
+                grad_buf.zero_()
 
     def report_training_progress(self):
         """ Periodically publish metadata and the current number of samples accumulated towards the next step """