|
@@ -131,8 +131,8 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
|
|
with self.lock_collaboration_state:
|
|
with self.lock_collaboration_state:
|
|
self.averager.load_state_from_peers(**kwargs)
|
|
self.averager.load_state_from_peers(**kwargs)
|
|
self.local_samples_accumulated = self.local_steps_accumulated = 0
|
|
self.local_samples_accumulated = self.local_steps_accumulated = 0
|
|
|
|
+ self.reset_accumulated_grads_()
|
|
self.update_scheduler()
|
|
self.update_scheduler()
|
|
- self.opt.zero_grad()
|
|
|
|
|
|
|
|
def step(self, batch_size: Optional[int] = None, **kwargs):
|
|
def step(self, batch_size: Optional[int] = None, **kwargs):
|
|
"""
|
|
"""
|
|
@@ -189,8 +189,6 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
|
|
self.averager.local_step += 1
|
|
self.averager.local_step += 1
|
|
|
|
|
|
self.opt.step()
|
|
self.opt.step()
|
|
- if self.reuse_grad_buffers:
|
|
|
|
- self.opt.zero_grad()
|
|
|
|
self.reset_accumulated_grads_()
|
|
self.reset_accumulated_grads_()
|
|
self.local_samples_accumulated = self.local_steps_accumulated = 0
|
|
self.local_samples_accumulated = self.local_steps_accumulated = 0
|
|
self.collaboration_state.register_step()
|
|
self.collaboration_state.register_step()
|
|
@@ -240,9 +238,10 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
|
|
@torch.no_grad()
|
|
@torch.no_grad()
|
|
def reset_accumulated_grads_(self):
|
|
def reset_accumulated_grads_(self):
|
|
if self.reuse_grad_buffers:
|
|
if self.reuse_grad_buffers:
|
|
- return
|
|
|
|
- for grad_buf in self.accumulated_grads():
|
|
|
|
- grad_buf.zero_()
|
|
|
|
|
|
+ self.opt.zero_grad()
|
|
|
|
+ else:
|
|
|
|
+ for grad_buf in self.accumulated_grads():
|
|
|
|
+ grad_buf.zero_()
|
|
|
|
|
|
def report_training_progress(self):
|
|
def report_training_progress(self):
|
|
""" Periodically publish metadata and the current number of samples accumulated towards the next step """
|
|
""" Periodically publish metadata and the current number of samples accumulated towards the next step """
|