|
@@ -451,7 +451,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
|
|
|
|
|
|
for state in valid_peer_states:
|
|
|
total_samples_per_second += state.samples_per_second
|
|
|
- if state.step >= self.local_step and current_time - state.time < self.staleness_timeout:
|
|
|
+ if state.step == self.local_step and current_time - state.time < self.staleness_timeout:
|
|
|
total_samples_accumulated += state.samples_accumulated
|
|
|
estimated_current_samples += (
|
|
|
state.samples_accumulated + max(0, current_time - state.time) * state.samples_per_second
|