|
@@ -236,11 +236,15 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
|
|
|
|
|
|
with self.performance_ema.pause(), self.lock_collaboration_state:
|
|
|
# divide accumulators by local steps to recover the true average grad w.r.t. local_samples_accumulated
|
|
|
+ logger.log(self.status_loglevel, f"-4")
|
|
|
self.apply_accumulated_grads_(scale_by=1.0 / self.local_steps_accumulated)
|
|
|
+ logger.log(self.status_loglevel, f"-3")
|
|
|
current_step, group_info = self.averager.local_step, None
|
|
|
+ logger.log(self.status_loglevel, f"-2")
|
|
|
|
|
|
if self.collaboration_state.num_peers > 1:
|
|
|
weight = self.local_samples_accumulated / self.target_batch_size
|
|
|
+ logger.log(self.status_loglevel, f"-1")
|
|
|
try:
|
|
|
group_info = self.averager.step(weight=weight, timeout=self.averaging_timeout, **kwargs)
|
|
|
if group_info:
|