@@ -155,6 +155,7 @@ class Optimizer(torch.optim.Optimizer):
         extra_tensors: Sequence[torch.Tensor] = (),
         averager_opts: Optional[dict] = None,
         tracker_opts: Optional[dict] = None,
+        preschedule_state_averaging: bool = False,
         performance_ema_alpha: float = 0.1,
         shutdown_timeout: float = 5,
         verbose: bool = False,
@@ -189,6 +190,8 @@ class Optimizer(torch.optim.Optimizer):
         self.delay_state_averaging, self.average_state_every = delay_state_averaging, average_state_every
         self.matchmaking_time, self.offload_optimizer = matchmaking_time, offload_optimizer
         self.delay_grad_averaging, self.delay_optimizer_step = delay_grad_averaging, delay_optimizer_step
+        self.preschedule_state_averaging = preschedule_state_averaging
+
         self.averaging_timeout, self.load_state_timeout = averaging_timeout, load_state_timeout
         self.shutdown_timeout = shutdown_timeout
 
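For context, a minimal sketch of how a caller would opt into prescheduling once this change lands. Note that the flag defaults to False, so prescheduling now has to be requested explicitly; every other argument below is illustrative and abridged, not part of this diff:

    import torch
    import hivemind

    dht = hivemind.DHT(start=True)
    model = torch.nn.Linear(16, 1)
    opt = hivemind.Optimizer(
        dht=dht,
        run_id="demo_run",                 # identifier shared by all peers in the run
        batch_size_per_step=32,            # samples this peer processes per local step
        target_batch_size=4096,            # global samples per optimizer epoch
        optimizer=lambda params: torch.optim.Adam(params, lr=1e-3),
        params=model.parameters(),
        preschedule_state_averaging=True,  # opt in; defaults to False after this change
    )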
@@ -347,7 +350,8 @@ class Optimizer(torch.optim.Optimizer):
                 return loss  # local gradients were reset due to overflow, must start over
 
             self._maybe_schedule_gradient_averaging()
-            self._maybe_schedule_state_averaging()
+            if self.preschedule_state_averaging:
+                self._maybe_schedule_state_averaging()
 
         else:
             # use_local_updates=True: update parameters on every step independently of other peers
@@ -358,7 +362,8 @@ class Optimizer(torch.optim.Optimizer):
 
             new_samples_accumulated = self.tracker.local_progress.samples_accumulated + batch_size
             self.tracker.report_local_progress(self.local_epoch, new_samples_accumulated)
-            self._maybe_schedule_state_averaging()
+            if self.preschedule_state_averaging:
+                self._maybe_schedule_state_averaging()
 
             self.state_averager.step(
                 increment_epoch=False,
@@ -399,8 +404,11 @@ class Optimizer(torch.optim.Optimizer):
 
             if should_average_state and self.scheduled_state is not None:
                 if self.scheduled_state.triggered or self.scheduled_state.done():
-                    logger.log(self.status_loglevel, f"Not using pre-scheduled group for state averaging because it"
-                               f"was already used elsewhere: {self.scheduled_state}")
+                    logger.log(
+                        self.status_loglevel,
+                        f"Not using pre-scheduled group for state averaging because it "
+                        f"was already used elsewhere: {self.scheduled_state}",
+                    )
                     self.scheduled_state = None
 
                 self.delay_before_state_averaging.update(task_size=1, interval=time.perf_counter() - _epoch_start_time)
@@ -417,6 +425,10 @@ class Optimizer(torch.optim.Optimizer):
                 averaging_opts=dict(timeout=self.averaging_timeout) if should_average_state else None,
             )
 
+            if not should_average_state and self.scheduled_state is not None and not self.scheduled_state.done():
+                self.scheduled_state.cancel()
+                self.scheduled_state = None
+
             self.tracker.update_epoch(new_epoch=self.state_averager.local_epoch)
             self._should_check_synchronization_on_update = True
             # the above line ensures that peers check for *strict* synchronization once per epoch
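The guard added at the end of this hunk covers an epoch that finishes without consuming its pre-scheduled state-averaging round. The same condition as a standalone helper (name hypothetical), spelling out why cancellation is limited to pending rounds:

    def _discard_unused_state_round(self) -> None:
        # Hypothetical equivalent of the inline guard above, called only when this
        # epoch did not average state: cancel the pre-scheduled round while it is
        # still pending, so matchmaking peers do not wait for a participant that
        # will never trigger all-reduce. The done() check skips rounds that
        # already completed and need no cleanup.
        if self.scheduled_state is not None and not self.scheduled_state.done():
            self.scheduled_state.cancel()
            self.scheduled_state = None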
@@ -439,8 +451,11 @@ class Optimizer(torch.optim.Optimizer):
 
         began_averaging_gradients = False
         if self.scheduled_grads is not None and (self.scheduled_grads.triggered or self.scheduled_grads.done()):
-            logger.log(self.status_loglevel, f"Not using pre-scheduled group for state averaging because it"
-                       f"was already used elsewhere: {self.scheduled_state}")
+            logger.log(
+                self.status_loglevel,
+                f"Not using pre-scheduled group for gradient averaging because it "
+                f"was already used elsewhere: {self.scheduled_grads}",
+            )
             self.scheduled_grads = None
 
         elif self.tracker.global_progress.num_peers > 1:
@@ -487,6 +502,7 @@ class Optimizer(torch.optim.Optimizer):
 
     def _maybe_schedule_state_averaging(self) -> None:
         """If next epoch is coming soon, schedule the next state averaging at estimated parameter averaging start"""
+        assert self.preschedule_state_averaging
         next_epoch = max(self.local_epoch + 1, self.tracker.global_epoch)
         if next_epoch % self.average_state_every != 0:
             return  # averaging is not performed at this epoch
@@ -582,6 +598,7 @@ class Optimizer(torch.optim.Optimizer):
 
     def load_state_from_peers(self, **kwargs):
         """Attempt to fetch the newest collaboration state from other peers"""
+        self._finish_background_averaging()
 
         with self.tracker.pause_updates():
             while True:
@@ -611,6 +628,8 @@ class Optimizer(torch.optim.Optimizer):
     def _finish_background_averaging(self):
         for scheduled_round in self.scheduled_grads, self.scheduled_state:
             if scheduled_round is not None:
+                if scheduled_round.stage == AveragingStage.LOOKING_FOR_GROUP:
+                    scheduled_round.cancel()
                 if not scheduled_round.triggered:
                     scheduled_round.weight = 0
                     scheduled_round.allow_allreduce()
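One dependency worth noting: the new check references AveragingStage, which upstream hivemind defines in hivemind.averaging.control, so it must be in scope in this module. A short sketch, under that assumption:

    # AveragingStage tracks the lifecycle of a scheduled StepControl round;
    # LOOKING_FOR_GROUP means matchmaking has not found peers yet, so the
    # round can be cancelled outright without leaving a formed group waiting.
    from hivemind.averaging.control import AveragingStage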