@@ -257,6 +257,14 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                 group_info = self.averager.step(weight=weight, timeout=self.averaging_timeout, **kwargs)
                 if group_info:
                     logger.log(self.status_loglevel, f"Averaged tensors successfully with {len(group_info)} peers")
+
+                    # update our current step if we averaged with another peer that was at a more recent step
+                    for peer, peer_step in group_info.items():
+                        if isinstance(peer_step, int):
+                            current_step = max(current_step, peer_step)
+                        else:
+                            logger.warning(f"Peer {peer} sent malformed data about current step: {peer_step}")
+
             except BaseException as e:
                 logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}.")
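Note on the `isinstance` guard added here: the gathered step values arrive from other peers, and in Python 3 comparing an `int` against an arbitrary object raises `TypeError`, so a single malformed peer could otherwise abort the whole reconciliation. A self-contained sketch of the failure mode and the guarded merge (the peer names and step values are made up for illustration):

    # Unguarded merge: one bad peer crashes the loop, because max() cannot
    # compare an int with a str in Python 3.
    current_step = 300
    gathered = {"peer_a": 305, "peer_b": "oops"}
    try:
        for peer_step in gathered.values():
            current_step = max(current_step, peer_step)
    except TypeError as e:
        print(f"unguarded merge crashed: {e}")

    # Guarded merge, as in the patch: skip and report anything that is not an int.
    current_step = 300
    for peer, peer_step in gathered.items():
        if isinstance(peer_step, int):
            current_step = max(current_step, peer_step)
        else:
            print(f"Peer {peer} sent malformed data about current step: {peer_step}")
    print(current_step)  # 305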
@@ -293,12 +301,19 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         self.collaboration_state_updated.set()

         with self.lock_collaboration_state:
-            # divide accumulators by local steps to recover the true average grad w.r.t. local_samples_accumulated
             current_step, group_info = self.averager.local_step, None
+
             try:
-                group_info = self.averager.step(timeout=self.averaging_timeout, **kwargs)
+                group_info = self.averager.step(timeout=self.averaging_timeout, gather=current_step, **kwargs)
                 if group_info:
                     logger.log(self.status_loglevel, f"Averaged tensors successfully with {len(group_info)} peers")
+
+                    # update our current step if we averaged with another peer that was at a more recent step
+                    for peer, peer_step in group_info.items():
+                        if isinstance(peer_step, int):
+                            current_step = max(current_step, peer_step)
+                        else:
+                            logger.warning(f"Peer {peer} sent malformed data about current step: {peer_step}")
             except BaseException as e:
                 logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}.")
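This hunk also threads `gather=current_step` into the averaging call: each peer attaches its own step number when joining the round, and on success the averager hands back a mapping from groupmate endpoints to the values they sent, which is exactly what the loop above consumes. A minimal sketch of that round-trip with the averager stubbed out (the helper name, endpoints, and step values are illustrative, not part of hivemind's API):

    from typing import Any, Dict, Optional

    def merge_gathered_steps(current_step: int, group_info: Optional[Dict[str, Any]]) -> int:
        """Fold peers' gathered step numbers into our own, skipping junk values."""
        for peer, peer_step in (group_info or {}).items():
            if isinstance(peer_step, int):
                current_step = max(current_step, peer_step)
            else:
                print(f"Peer {peer} sent malformed data about current step: {peer_step}")
        return current_step

    # Shape of what averager.step(..., gather=current_step) hands back on success:
    # each groupmate's endpoint mapped to the value that peer passed as `gather`.
    fake_group_info = {"1.2.3.4:1337": 301, "5.6.7.8:1337": 305}
    assert merge_gathered_steps(301, fake_group_info) == 305
    assert merge_gathered_steps(301, None) == 301  # failed round: keep the local step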
@@ -431,7 +446,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):

         for state in valid_peer_states:
             total_samples_per_second += state.samples_per_second
-            if state.step == global_optimizer_step:
+            if state.step >= global_optimizer_step - self.step_tolerance:
                 total_samples_accumulated += state.samples_accumulated
                 estimated_current_samples += (
                     state.samples_accumulated + max(0, current_time - state.time) * state.samples_per_second
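In this last hunk, the strict equality is relaxed so that peers whose reported step trails the global optimizer step by at most `self.step_tolerance` still have their accumulated samples counted, rather than being discarded the moment they fall a single step behind. A self-contained sketch of the filter (the `PeerState` dataclass and the tolerance value are illustrative stand-ins for the real structures):

    import time
    from dataclasses import dataclass

    @dataclass
    class PeerState:  # stand-in for the real per-peer training state
        step: int
        samples_accumulated: int
        samples_per_second: float
        time: float

    step_tolerance = 1  # the patch reads this from self.step_tolerance
    global_optimizer_step = 100
    current_time = time.time()
    peers = [
        PeerState(100, 512, 40.0, current_time - 2.0),
        PeerState(99, 256, 35.0, current_time - 1.0),  # one step behind: now counted
        PeerState(97, 128, 30.0, current_time),        # beyond tolerance: still skipped
    ]

    total_samples_accumulated = estimated_current_samples = 0.0
    for state in peers:
        if state.step >= global_optimizer_step - step_tolerance:
            total_samples_accumulated += state.samples_accumulated
            estimated_current_samples += (
                state.samples_accumulated + max(0, current_time - state.time) * state.samples_per_second
            )
    print(total_samples_accumulated)  # 768.0 (with `==` it would be 512.0)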
|