
Reset gradient buffers when synchronizing with peers (#222)

justheuristic committed 4 years ago
commit 32b87bf3fe
1 changed file with 5 additions and 6 deletions

+ 5 - 6
hivemind/client/optim/collaborative.py

@@ -131,8 +131,8 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         with self.lock_collaboration_state:
             self.averager.load_state_from_peers(**kwargs)
             self.local_samples_accumulated = self.local_steps_accumulated = 0
+            self.reset_accumulated_grads_()
             self.update_scheduler()
-            self.opt.zero_grad()
 
     def step(self, batch_size: Optional[int] = None, **kwargs):
         """
@@ -189,8 +189,6 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                 self.averager.local_step += 1
 
             self.opt.step()
-            if self.reuse_grad_buffers:
-                self.opt.zero_grad()
             self.reset_accumulated_grads_()
             self.local_samples_accumulated = self.local_steps_accumulated = 0
             self.collaboration_state.register_step()
@@ -240,9 +238,10 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
     @torch.no_grad()
     def reset_accumulated_grads_(self):
         if self.reuse_grad_buffers:
-            return
-        for grad_buf in self.accumulated_grads():
-            grad_buf.zero_()
+            self.opt.zero_grad()
+        else:
+            for grad_buf in self.accumulated_grads():
+                grad_buf.zero_()
 
     def report_training_progress(self):
         """ Periodically publish metadata and the current number of samples accumulated towards the next step """