@@ -468,7 +468,7 @@ class Optimizer(torch.optim.Optimizer):
                 averaging_opts=dict(timeout=self.averaging_timeout) if should_average_state else None,
             )

-            if not should_average_state and self.scheduled_state is not None and not self.scheduled_state.done():
+            if not should_average_state and self.scheduled_state is not None and not self.scheduled_state.triggered:
                 self.scheduled_state.cancel()
                 self.scheduled_state = None

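Note on the guard change above: done() only becomes true once a round has finished, while
triggered becomes true as soon as the round is committed to an all-reduce. The old guard could
therefore cancel a round that peers were already counting on. A toy model of that ordering
(illustrative Stage/ToyControl names, not hivemind's actual StepControl API):

    from enum import Enum, auto

    class Stage(Enum):
        LOOKING_FOR_GROUP = auto()   # still matchmaking: safe to cancel
        RUNNING_ALLREDUCE = auto()   # committed: peers already count on us
        FINISHED = auto()

    class ToyControl:
        def __init__(self, stage: Stage):
            self.stage = stage

        @property
        def triggered(self) -> bool:
            # true as soon as the round is committed to an all-reduce
            return self.stage in (Stage.RUNNING_ALLREDUCE, Stage.FINISHED)

        def done(self) -> bool:
            # true only after the round has finished
            return self.stage is Stage.FINISHED

    in_flight = ToyControl(Stage.RUNNING_ALLREDUCE)
    assert not in_flight.done()  # old guard (`not done()`) would cancel this in-flight round
    assert in_flight.triggered   # new guard (`not triggered`) leaves it alone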
@@ -513,7 +513,7 @@ class Optimizer(torch.optim.Optimizer):
         if not began_averaging_gradients and self.scheduled_grads is not None and not self.scheduled_grads.done():
             logger.log(self.status_loglevel, f"Cancelled pre-scheduled gradient averaging round")
             self.scheduled_grads.cancel()
-        self.scheduled_grads = None
+            self.scheduled_grads = None
         return began_averaging_gradients

     def _check_and_accumulate_gradients(self, batch_size: int, grad_scaler: Optional[GradScaler]) -> bool:
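This hunk is whitespace-only: self.scheduled_grads = None moves inside the cancellation branch.
The effect is that a round which did begin averaging in the background keeps its handle, so later
code (the tag-along in load_state_from_peers below, and shutdown) can still reach it. A condensed
sketch of the tail of _begin_averaging_gradients after the change, with illustrative names:

    def finish_begin_averaging(scheduled_grads, began_averaging_gradients: bool):
        if not began_averaging_gradients and scheduled_grads is not None and not scheduled_grads.done():
            scheduled_grads.cancel()
            scheduled_grads = None  # new placement: forget only the rounds we cancelled
        # the old placement cleared scheduled_grads here, unconditionally, so a round
        # that DID begin in the background lost its only handle and could no longer
        # be awaited, tagged along with, or cancelled during shutdown
        return scheduled_grads, began_averaging_gradients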
@@ -646,7 +646,16 @@ class Optimizer(torch.optim.Optimizer):

         If successful, this will update parameters, optimizer state, local epoch and learning rate schedule in-place.
         """
-        self._finish_background_averaging()
+        # note: we tag along for the next all-reduce because the run may have already started and cancelling it
+        # will cause peers to restart matchmaking and may stall the entire collaboration for a few seconds.
+        if self.scheduled_grads is not None and not self.scheduled_grads.done():
+            self.scheduled_grads.weight = 0
+            self.scheduled_grads.allow_allreduce()
+            try:
+                self.scheduled_grads.result(self.averaging_timeout)
+            except BaseException as e:
+                logger.exception(e)
+
         self.state_averager.step(wait_for_delayed_updates=True)

         with self.tracker.pause_updates():
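The block inlined above is a "tag along with zero weight" pattern: instead of cancelling a round
that may already have a matched group, the local peer participates with weight 0, so the group's
average is unaffected and no peer is forced back into matchmaking; once the round completes, the
freshly loaded state can safely overwrite the local tensors. A hypothetical standalone helper
capturing the pattern (not part of this diff):

    import logging

    logger = logging.getLogger(__name__)

    def tag_along_with_zero_weight(control, timeout: float) -> None:
        """Let an in-flight averaging round finish without contributing to it."""
        control.weight = 0           # our tensors get zero weight in the weighted mean
        control.allow_allreduce()    # let the already-matched group proceed
        try:
            control.result(timeout)  # block until the round finishes or times out
        except BaseException as e:
            logger.exception(e)      # a failed round should not abort state loading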
@@ -675,13 +684,6 @@ class Optimizer(torch.optim.Optimizer):
                 self.grad_averager.state_sharing_priority = self.local_epoch

     def _finish_background_averaging(self):
-        for scheduled_round in self.scheduled_grads, self.scheduled_state:
-            if scheduled_round is not None:
-                if scheduled_round.stage == AveragingStage.LOOKING_FOR_GROUP:
-                    scheduled_round.cancel()
-                if not scheduled_round.triggered:
-                    scheduled_round.weight = 0
-                    scheduled_round.allow_allreduce()
         self.scheduled_grads = self.scheduled_state = None

     def state_dict(self) -> dict:
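For reference, the method that survives this hunk (assembled from the context lines): the
cancel-or-tag-along logic moves to its two call sites (load_state_from_peers above and shutdown
below), leaving only the reference reset here:

    def _finish_background_averaging(self):
        self.scheduled_grads = self.scheduled_state = None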
@@ -727,7 +729,14 @@ class Optimizer(torch.optim.Optimizer):
         logger.log(self.status_loglevel, "Sending goodbye to peers...")
         self.tracker.shutdown(self.shutdown_timeout)
         self.state_averager.step(wait_for_delayed_updates=True)
-        self._finish_background_averaging()
+        for scheduled_round in self.scheduled_grads, self.scheduled_state:
+            if scheduled_round is not None:
+                if scheduled_round.stage == AveragingStage.LOOKING_FOR_GROUP:
+                    scheduled_round.cancel()
+                if not scheduled_round.triggered:
+                    scheduled_round.weight = 0
+                    scheduled_round.allow_allreduce()
+
         logger.log(self.status_loglevel, "Shutting down averagers...")
         self.state_averager.shutdown()
         if self.use_gradient_averaging:
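The per-round shutdown policy added above, restated as a standalone sketch (the function name is
illustrative; stage, triggered, weight and allow_allreduce are the control attributes used in the
diff itself):

    # assumes hivemind's import path for AveragingStage, as used by this module
    from hivemind.averaging.control import AveragingStage

    def settle_scheduled_round(scheduled_round) -> None:
        # mirrors the loop body added to shutdown() above
        if scheduled_round is None:
            return
        if scheduled_round.stage == AveragingStage.LOOKING_FOR_GROUP:
            scheduled_round.cancel()  # no group formed yet: cancelling stalls nobody
        if not scheduled_round.triggered:
            # not yet committed to an all-reduce: release the round with zero weight
            # so any peers already matched with us can proceed without our tensors
            scheduled_round.weight = 0
            scheduled_round.allow_allreduce()
        # a triggered round is already in-flight and completes on its own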