@@ -155,7 +155,6 @@ class Optimizer(torch.optim.Optimizer):
         extra_tensors: Sequence[torch.Tensor] = (),
         averager_opts: Optional[dict] = None,
         tracker_opts: Optional[dict] = None,
-        preschedule_state_averaging: bool = False,
         performance_ema_alpha: float = 0.1,
         shutdown_timeout: float = 5,
         verbose: bool = False,
@@ -190,7 +189,6 @@ class Optimizer(torch.optim.Optimizer):
         self.delay_state_averaging, self.average_state_every = delay_state_averaging, average_state_every
         self.matchmaking_time, self.offload_optimizer = matchmaking_time, offload_optimizer
         self.delay_grad_averaging, self.delay_optimizer_step = delay_grad_averaging, delay_optimizer_step
-        self.preschedule_state_averaging = preschedule_state_averaging

         self.averaging_timeout, self.load_state_timeout = averaging_timeout, load_state_timeout
         self.shutdown_timeout = shutdown_timeout
@@ -350,8 +348,7 @@ class Optimizer(torch.optim.Optimizer):
                     return loss  # local gradients were reset due to overflow, must start over

             self._maybe_schedule_gradient_averaging()
-            if self.preschedule_state_averaging:
-                self._maybe_schedule_state_averaging()
+            self._maybe_schedule_state_averaging()

         else:
             # use_local_updates=True: update parameters on every step independently of other peers
@@ -362,8 +359,7 @@ class Optimizer(torch.optim.Optimizer):

                 new_samples_accumulated = self.tracker.local_progress.samples_accumulated + batch_size
                 self.tracker.report_local_progress(self.local_epoch, new_samples_accumulated)
-                if self.preschedule_state_averaging:
-                    self._maybe_schedule_state_averaging()
+                self._maybe_schedule_state_averaging()

                 self.state_averager.step(
                     increment_epoch=False,
@@ -400,7 +396,9 @@ class Optimizer(torch.optim.Optimizer):
             next_epoch = max(self.local_epoch + 1, self.tracker.global_epoch)
             swarm_not_empty = self.tracker.global_progress.num_peers > 1
             should_perform_optimizer_step = not self.auxiliary and not self.use_local_updates
-            should_average_state = swarm_not_empty and next_epoch % self.average_state_every == 0
+            should_average_state = (swarm_not_empty and
+                                    next_epoch % self.average_state_every == 0 and
+                                    not self.state_averager.averaging_in_progress)

             if should_average_state and self.scheduled_state is not None:
                 if self.scheduled_state.triggered or self.scheduled_state.done():
@@ -410,7 +408,6 @@ class Optimizer(torch.optim.Optimizer):
                         f"was already used elsewhere: {self.scheduled_state}",
                     )
                     self.scheduled_state = None
-
                 self.delay_before_state_averaging.update(task_size=1, interval=time.perf_counter() - _epoch_start_time)

             self.state_averager.step(
@@ -502,10 +499,11 @@ class Optimizer(torch.optim.Optimizer):

     def _maybe_schedule_state_averaging(self) -> None:
         """If next epoch is coming soon, schedule the next state averaging at estimated parameter averaging start"""
-        assert self.preschedule_state_averaging
         next_epoch = max(self.local_epoch + 1, self.tracker.global_epoch)
         if next_epoch % self.average_state_every != 0:
             return  # averaging is not performed at this epoch
+        if self.state_averager.averaging_in_progress:
+            return  # previous run is still in progress

         estimated_time = self.tracker.estimated_next_update_time
         estimated_time += self.delay_before_state_averaging.ema_seconds_per_sample
@@ -599,6 +597,7 @@ class Optimizer(torch.optim.Optimizer):
     def load_state_from_peers(self, **kwargs):
         """Attempt to fetch the newest collaboration state from other peers"""
         self._finish_background_averaging()
+        self.state_averager.step(wait_for_delayed_updates=True)

         with self.tracker.pause_updates():
             while True:
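
For context on the behavioral change: this patch removes the opt-in preschedule_state_averaging flag, so state averaging is now always pre-scheduled, and double scheduling is instead prevented by the state averager's averaging_in_progress property, checked both when computing should_average_state and at the top of _maybe_schedule_state_averaging. Below is a minimal, runnable sketch of that guard logic. The stub classes and the helper name should_schedule_state_averaging are illustrative stand-ins, not hivemind's API; only averaging_in_progress, average_state_every, and global_epoch are taken from the diff itself.

# Illustrative sketch only: stubs stand in for hivemind's state averager and progress tracker.
from dataclasses import dataclass


@dataclass
class StubStateAverager:
    averaging_in_progress: bool = False  # True while a previous averaging round is still running


@dataclass
class StubTracker:
    global_epoch: int = 0


def should_schedule_state_averaging(local_epoch: int, average_state_every: int,
                                    averager: StubStateAverager, tracker: StubTracker) -> bool:
    """Mirror the checks that _maybe_schedule_state_averaging performs after this patch."""
    next_epoch = max(local_epoch + 1, tracker.global_epoch)
    if next_epoch % average_state_every != 0:
        return False  # averaging is not performed at this epoch
    if averager.averaging_in_progress:
        return False  # previous run is still in progress
    return True


# The averaging_in_progress guard replaces the removed preschedule_state_averaging flag:
assert should_schedule_state_averaging(1, 2, StubStateAverager(), StubTracker())
assert not should_schedule_state_averaging(1, 2, StubStateAverager(averaging_in_progress=True), StubTracker())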