|
@@ -396,9 +396,11 @@ class Optimizer(torch.optim.Optimizer):
|
|
|
next_epoch = max(self.local_epoch + 1, self.tracker.global_epoch)
|
|
|
swarm_not_empty = self.tracker.global_progress.num_peers > 1
|
|
|
should_perform_optimizer_step = not self.auxiliary and not self.use_local_updates
|
|
|
- should_average_state = (swarm_not_empty and
|
|
|
- next_epoch % self.average_state_every == 0 and
|
|
|
- not self.state_averager.averaging_in_progress)
|
|
|
+ should_average_state = (
|
|
|
+ swarm_not_empty
|
|
|
+ and next_epoch % self.average_state_every == 0
|
|
|
+ and not self.state_averager.averaging_in_progress
|
|
|
+ )
|
|
|
|
|
|
if should_average_state and self.scheduled_state is not None:
|
|
|
if self.scheduled_state.triggered or self.scheduled_state.done():
|