@@ -256,8 +256,8 @@ class Optimizer(torch.optim.Optimizer):
             logger.log(self.status_loglevel, f"Discarding failed matchmaking results: {self.scheduled_round}")
             self.scheduled_round = None
 
-        need_averaging = self.tracker.global_progress.num_peers > 1
-        if need_averaging:
+        swarm_not_empty = self.tracker.global_progress.num_peers > 1
+        if swarm_not_empty:
             try:
                 group_info = self.grad_averager.step(
                     control=self.scheduled_round, reset_accumulators=True, timeout=self.averaging_timeout
@@ -276,18 +276,19 @@ class Optimizer(torch.optim.Optimizer):
         assert self._schema_hash == self._compute_schema_hash(), "parameters or gradients changed during iteration"
         with self.grad_averager.use_averaged_gradients(replace_model_gradients=False):
             # note: we do not need to replace because the offloaded optimizer is already using averaged grads
+            next_epoch = max(self.local_epoch + 1, self.tracker.global_epoch)
             self.state_averager.step(
                 increment_epoch=True,
                 optimizer_step=True,
                 delay_optimizer_step=self.delay_optimizer_step,
                 grad_scaler=grad_scaler,
-                averaging_round=need_averaging and self.tracker.global_epoch % self.average_state_every == 0,
+                averaging_round=swarm_not_empty and next_epoch % self.average_state_every == 0,
                 delay_averaging=True,
                 averaging_opts=dict(
                     scheduled_time=get_dht_time() + self.matchmaking_time, timeout=self.averaging_timeout
                 )
-                if need_averaging
+                if swarm_not_empty
                 else None,
             )
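
For readers outside the PR context: the rename makes explicit that `num_peers > 1` only means there are other peers to average with, and the new `next_epoch` keys the state-averaging schedule on the epoch being entered (`max(self.local_epoch + 1, self.tracker.global_epoch)`) rather than the current global epoch. Below is a minimal sketch of that scheduling decision under the semantics visible in the diff; the helper `should_average_state` is hypothetical and not part of hivemind's API.

```python
# Assumed names, not hivemind's actual API: should_average_state is a
# hypothetical helper; local_epoch, global_epoch, average_state_every and
# num_peers mirror the attributes used in the diff.

def should_average_state(local_epoch: int, global_epoch: int,
                         average_state_every: int, num_peers: int) -> bool:
    """Return True if the epoch we are about to enter is a state-averaging round."""
    swarm_not_empty = num_peers > 1  # averaging is only possible with other peers
    # A peer that fell behind jumps straight to the swarm's global epoch;
    # otherwise it advances by one. Keying the modulo check on this next
    # epoch keeps all peers averaging on the same rounds.
    next_epoch = max(local_epoch + 1, global_epoch)
    return swarm_not_empty and next_epoch % average_state_every == 0

# Entering epoch 5 with average_state_every=5 triggers an averaging round...
assert should_average_state(local_epoch=4, global_epoch=4, average_state_every=5, num_peers=3)
# ...whereas the pre-patch check (global_epoch % 5 == 0, i.e. 4 % 5) would not.
# A straggler at local_epoch=2 catching up to global_epoch=10 also joins in:
assert should_average_state(local_epoch=2, global_epoch=10, average_state_every=5, num_peers=3)
# A lone peer never schedules an averaging round:
assert not should_average_state(local_epoch=4, global_epoch=4, average_state_every=5, num_peers=1)
```

If this reading is right, the pre-patch condition checked the epoch being left rather than the one being entered, so peers could disagree by one step on when a state-averaging round was due; computing `next_epoch` first aligns the decision across the swarm.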