@@ -511,9 +511,9 @@ class Optimizer(torch.optim.Optimizer):
                 logger.exception(e)

         if not began_averaging_gradients and self.scheduled_grads is not None and not self.scheduled_grads.done():
-            logger.log(self.status_loglevel, f"Cancelled pre-scheduled gradient averaging round")
-            self.scheduled_grads.cancel()
-            self.scheduled_grads = None
+            logger.log(self.status_loglevel, f"Tagging along for a pre-scheduled gradient averaging round")
+            self._tag_along_with_zero_weight(self.scheduled_grads)
+            self.scheduled_grads = None
         return began_averaging_gradients

     def _check_and_accumulate_gradients(self, batch_size: int, grad_scaler: Optional[GradScaler]) -> bool:
@@ -649,13 +649,8 @@ class Optimizer(torch.optim.Optimizer):
         # note: we tag along for the next all-reduce because the run may have already started and cancelling it
         # will cause peers to restart matchmaking and may stall the entire collaboration for a few seconds.
         if self.scheduled_grads is not None and not self.scheduled_grads.done():
-            self.scheduled_grads.weight = 0
-            self.scheduled_grads.allow_allreduce()
-            try:
-                self.scheduled_grads.result(self.averaging_timeout)
-            except BaseException as e:
-                logger.exception(e)
-
+            self._tag_along_with_zero_weight(self.scheduled_grads)
+            self.scheduled_grads = None
         self.state_averager.step(wait_for_delayed_updates=True)

         with self.tracker.pause_updates():
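
(Editorial aside, not part of the patch: the first hunk replaces outright cancellation of a pre-scheduled round with tagging along, and the second moves the inline tag-along logic into the new `_tag_along_with_zero_weight` helper defined later in this diff. Joining with weight 0 is safe because the averager computes a weighted mean over participants, so a zero-weight peer donates bandwidth without affecting the result. A minimal sketch of that property, assuming the all-reduce output is sum(w_i * x_i) / sum(w_i); `weighted_allreduce` is a toy stand-in, not hivemind's API.)

```python
import torch

def weighted_allreduce(tensors, weights):
    """Toy stand-in for a weighted all-reduce: sum(w_i * x_i) / sum(w_i)."""
    total = sum(w * t for w, t in zip(weights, tensors))
    return total / sum(weights)

active = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])]
tag_along = torch.tensor([100.0, 100.0])  # stale gradients from the tagging-along peer

with_peer = weighted_allreduce(active + [tag_along], [1.0, 1.0, 0.0])
without_peer = weighted_allreduce(active, [1.0, 1.0])
assert torch.allclose(with_peer, without_peer)  # both equal tensor([2., 3.])
```
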
@@ -683,9 +678,6 @@ class Optimizer(torch.optim.Optimizer):
         if not self.client_mode:
             self.grad_averager.state_sharing_priority = self.local_epoch

-    def _finish_background_averaging(self):
-        self.scheduled_grads = self.scheduled_state = None
-
     def state_dict(self) -> dict:
         state_dict = self.state_averager.optimizer.state_dict()
         state_dict["state"]["local_epoch"] = self.local_epoch
@@ -725,6 +717,19 @@ class Optimizer(torch.optim.Optimizer):
     def __repr__(self):
         return f"{self.__class__.__name__}(prefix={self.run_id}, epoch={self.local_epoch})"

+    def _tag_along_with_zero_weight(self, control: StepControl):
+        """Wait for a running averaging round to finish with zero weight."""
+        if not control.triggered:
+            control.weight = 0
+            control.allow_allreduce()
+        if not control.done():
+            try:
+                control.result(self.averaging_timeout)
+            except BaseException as e:
+                logger.exception(e)
+            if not control.done():
+                control.cancel()
+
     def shutdown(self):
         logger.log(self.status_loglevel, "Sending goodbye to peers...")
         self.tracker.shutdown(self.shutdown_timeout)
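
(Editorial aside, not part of the patch: the helper above handles two phases of a pending `StepControl`. If the round has not been triggered yet, it first sets the local weight to zero and releases the trigger via `allow_allreduce()`; it then waits up to `self.averaging_timeout` for the result and cancels only as a last resort. The shutdown hunk below applies the same split at a coarser grain. A hypothetical caller-side sketch follows; `retire_round` is an illustrative name, and the import path for `StepControl`/`AveragingStage` is assumed to match the identifiers used in this diff.)

```python
from typing import Optional

from hivemind.averaging.control import AveragingStage, StepControl  # assumed path

def retire_round(optimizer, control: Optional[StepControl]) -> None:
    """Dispose of a pending averaging round without stalling other peers."""
    if control is None or control.done():
        return  # nothing pending
    if control.stage == AveragingStage.LOOKING_FOR_GROUP:
        control.cancel()  # no group formed yet, so cancelling costs nothing
    else:
        # a group may already exist: join at zero weight instead of cancelling,
        # mirroring the shutdown logic in the hunk below
        optimizer._tag_along_with_zero_weight(control)
```
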
@@ -733,9 +738,8 @@ class Optimizer(torch.optim.Optimizer):
         if scheduled_round is not None:
             if scheduled_round.stage == AveragingStage.LOOKING_FOR_GROUP:
                 scheduled_round.cancel()
-            if not scheduled_round.triggered:
-                scheduled_round.weight = 0
-                scheduled_round.allow_allreduce()
+            else:
+                self._tag_along_with_zero_weight(scheduled_round)

         logger.log(self.status_loglevel, "Shutting down averagers...")
         self.state_averager.shutdown()