|
@@ -254,7 +254,8 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
|
|
|
mean_samples_per_worker = self.target_batch_size / self.collaboration_state.num_peers
|
|
|
weight = self.local_samples_accumulated / mean_samples_per_worker
|
|
|
try:
|
|
|
- group_info = self.averager.step(weight=weight, timeout=self.averaging_timeout, **kwargs)
|
|
|
+ group_info = self.averager.step(
|
|
|
+ weight=weight, gather=current_step, timeout=self.averaging_timeout, **kwargs)
|
|
|
if group_info:
|
|
|
logger.log(self.status_loglevel, f"Averaged tensors successfully with {len(group_info)} peers")
|
|
|
|