
Reset gradient buffers when synchronizing with peers (#222)

justheuristic 4 years ago
parent
commit
32b87bf3fe
1 changed file with 5 additions and 6 deletions

+ 5 - 6
hivemind/client/optim/collaborative.py

@@ -131,8 +131,8 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         with self.lock_collaboration_state:
             self.averager.load_state_from_peers(**kwargs)
             self.local_samples_accumulated = self.local_steps_accumulated = 0
+            self.reset_accumulated_grads_()
             self.update_scheduler()
-            self.opt.zero_grad()

     def step(self, batch_size: Optional[int] = None, **kwargs):
         """
@@ -189,8 +189,6 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                 self.averager.local_step += 1

             self.opt.step()
-            if self.reuse_grad_buffers:
-                self.opt.zero_grad()
             self.reset_accumulated_grads_()
             self.local_samples_accumulated = self.local_steps_accumulated = 0
             self.collaboration_state.register_step()
@@ -240,9 +238,10 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
     @torch.no_grad()
     def reset_accumulated_grads_(self):
         if self.reuse_grad_buffers:
-            return
-        for grad_buf in self.accumulated_grads():
-            grad_buf.zero_()
+            self.opt.zero_grad()
+        else:
+            for grad_buf in self.accumulated_grads():
+                grad_buf.zero_()

     def report_training_progress(self):
         """ Periodically publish metadata and the current number of samples accumulated towards the next step """