@@ -121,6 +121,7 @@ class Optimizer(torch.optim.Optimizer):
     :param averaging_timeout: if an averaging step hangs for this long, it will be cancelled automatically.
       Increase averaging_timeout if you see "Proceeding with local gradients" at least 25% of the time.
       Do not set this timeout too high, as it may cause your optimizer to hang after some types of network errors.
+    :param allreduce_timeout: timeout for a single attempt to run all-reduce, default: equal to averaging_timeout.
     :param load_state_timeout: wait for at most this many seconds before giving up on load_state_from_peers.
     :param reuse_grad_buffers: if True, use model's .grad buffers for gradient accumulation.
       This is more memory efficient, but it requires that the user does *NOT* call model/opt zero_grad at all
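
For context, a minimal sketch of how these timeouts might be passed when constructing the optimizer (assuming the patch above is applied; the run_id, batch sizes, and timeout values below are placeholders, not values taken from this diff):

import torch
import hivemind

dht = hivemind.DHT(start=True)
model = torch.nn.Linear(16, 2)
opt = hivemind.Optimizer(
    dht=dht,
    run_id="demo_run",                 # placeholder experiment name
    target_batch_size=1024,            # global batch size that triggers an epoch
    batch_size_per_step=32,            # samples processed per local step
    optimizer=torch.optim.SGD(model.parameters(), lr=0.1),
    matchmaking_time=15.0,
    averaging_timeout=60.0,            # cancel an averaging step that hangs this long
    allreduce_timeout=30.0,            # new: cap each individual all-reduce attempt
    load_state_timeout=600.0,
)
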
@@ -173,6 +174,7 @@ class Optimizer(torch.optim.Optimizer):
         scheduler: Optional[Union[LRSchedulerBase, SchedulerFactory]] = None,
         matchmaking_time: Optional[float] = 15.0,
         averaging_timeout: Optional[float] = 60.0,
+        allreduce_timeout: Optional[float] = None,
         load_state_timeout: float = 600.0,
         reuse_grad_buffers: bool = False,
         offload_optimizer: Optional[bool] = None,
@@ -197,6 +199,7 @@ class Optimizer(torch.optim.Optimizer):
         client_mode = client_mode if client_mode is not None else dht.client_mode
         delay_optimizer_step = delay_optimizer_step if delay_optimizer_step is not None else delay_grad_averaging
         offload_optimizer = offload_optimizer if offload_optimizer is not None else (params is not None)
+        allreduce_timeout = allreduce_timeout if allreduce_timeout is not None else averaging_timeout
         assert not delay_grad_averaging or delay_optimizer_step, "delay_grad_averaging requires delay_optimizer_step"
         assert not (client_mode and auxiliary), "Client-mode peers cannot serve as auxiliaries"
         assert not auxiliary or batch_size_per_step is None, "Auxiliary peers should not accumulate batches"
@@ -225,8 +228,8 @@ class Optimizer(torch.optim.Optimizer):
         self.matchmaking_time, self.offload_optimizer = matchmaking_time, offload_optimizer
         self.delay_grad_averaging, self.delay_optimizer_step = delay_grad_averaging, delay_optimizer_step

-        self.averaging_timeout, self.load_state_timeout = averaging_timeout, load_state_timeout
-        self.shutdown_timeout = shutdown_timeout
+        self.averaging_timeout, self.allreduce_timeout = averaging_timeout, allreduce_timeout
+        self.load_state_timeout, self.shutdown_timeout = load_state_timeout, shutdown_timeout

         self.status_loglevel = logging.INFO if verbose else logging.DEBUG
         self.scheduled_grads: Optional[StepControl] = None
@@ -271,7 +274,7 @@ class Optimizer(torch.optim.Optimizer):
             dht=self.dht,
             prefix=f"{self.run_id}_state_averager",
             min_matchmaking_time=self.matchmaking_time,
-            allreduce_timeout=self.averaging_timeout,
+            allreduce_timeout=self.allreduce_timeout,
             shutdown_timeout=self.shutdown_timeout,
             offload_optimizer=self.offload_optimizer,
             custom_gradients=self.offload_optimizer,
@@ -289,7 +292,7 @@ class Optimizer(torch.optim.Optimizer):
             prefix=f"{self.run_id}_grad_averager",
             parameters=self.state_averager.main_parameters,
             min_matchmaking_time=self.matchmaking_time,
-            allreduce_timeout=self.averaging_timeout,
+            allreduce_timeout=self.allreduce_timeout,
             shutdown_timeout=self.shutdown_timeout,
             client_mode=self.client_mode,
             auxiliary=self.auxiliary,
@@ -508,8 +511,8 @@ class Optimizer(torch.optim.Optimizer):
             logger.exception(e)

         if not began_averaging_gradients and self.scheduled_grads is not None and not self.scheduled_grads.done():
-            logger.log(self.status_loglevel, f"Cancelled pre-scheduled gradient averaging round")
-            self.scheduled_grads.cancel()
+            logger.log(self.status_loglevel, f"Tagging along for a pre-scheduled gradient averaging round")
+            self._tag_along_with_zero_weight(self.scheduled_grads)
             self.scheduled_grads = None
         return began_averaging_gradients

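
For context, a standalone illustration (plain Python, not hivemind code) of why tagging along with zero weight is safe: a zero-weight peer leaves the weighted average untouched, while still letting the already-formed group complete its round instead of repeating matchmaking:

def weighted_average(peer_tensors, weights):
    """Reference result of a weighted all-reduce over one group of peers."""
    total = sum(weights)
    return [sum(w * x for w, x in zip(weights, column)) / total for column in zip(*peer_tensors)]

peers = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]  # third peer only tags along
weights = [1.0, 1.0, 0.0]                          # ...so it gets zero weight
print(weighted_average(peers, weights))            # [2.0, 3.0], unaffected by peer 3
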
@@ -542,6 +545,7 @@ class Optimizer(torch.optim.Optimizer):

     def _maybe_schedule_state_averaging(self) -> None:
         """If next epoch is coming soon, schedule the next state averaging at estimated parameter averaging start"""
+        return
         next_epoch = max(self.local_epoch + 1, self.tracker.global_epoch)
         if next_epoch % self.average_state_every != 0:
             return  # averaging is not performed at this epoch
@@ -643,7 +647,11 @@ class Optimizer(torch.optim.Optimizer):

         If successful, this will update parameters, optimizer state, local epoch and learning rate schedule in-place.
         """
-        self._finish_background_averaging()
+        # note: we tag along for the next all-reduce because the run may have already started and cancelling it
+        # will cause peers to restart matchmaking and may stall the entire collaboration for a few seconds.
+        if self.scheduled_grads is not None and not self.scheduled_grads.done():
+            self._tag_along_with_zero_weight(self.scheduled_grads)
+            self.scheduled_grads = None
         self.state_averager.step(wait_for_delayed_updates=True)

         with self.tracker.pause_updates():
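
For context, a hedged usage sketch (reusing the placeholder `opt` from the earlier sketch): a peer that joins late or detects it has fallen behind can pull the latest state; with the change above, any pre-scheduled gradient round is finished with zero weight first rather than cancelled:

opt.load_state_from_peers()                        # download parameters, optimizer state and epoch
print("local epoch after sync:", opt.local_epoch)
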
@@ -671,25 +679,6 @@ class Optimizer(torch.optim.Optimizer):
         if not self.client_mode:
             self.grad_averager.state_sharing_priority = self.local_epoch

-    def _finish_background_averaging(self):
-        for scheduled_round in self.scheduled_grads, self.scheduled_state:
-            if scheduled_round is not None:
-                if scheduled_round.stage == AveragingStage.LOOKING_FOR_GROUP:
-                    scheduled_round.cancel()
-                if not scheduled_round.triggered:
-                    scheduled_round.weight = 0
-                    scheduled_round.allow_allreduce()
-        for scheduled_round in self.scheduled_grads, self.scheduled_state:
-            if scheduled_round is not None and not scheduled_round.done():
-                try:
-                    time_to_deadline = scheduled_round.deadline - get_dht_time()
-                    scheduled_round.result(timeout=max(0.0, time_to_deadline))
-                except BaseException as e:
-                    logger.log(self.status_loglevel, f"Caught {e} while averaging gradients")
-                if not scheduled_round.done():
-                    scheduled_round.cancel()
-        self.scheduled_grads = self.scheduled_state = None
-
     def state_dict(self) -> dict:
         state_dict = self.state_averager.optimizer.state_dict()
         state_dict["state"]["local_epoch"] = self.local_epoch
@@ -729,11 +718,30 @@ class Optimizer(torch.optim.Optimizer):
     def __repr__(self):
         return f"{self.__class__.__name__}(prefix={self.run_id}, epoch={self.local_epoch})"

+    def _tag_along_with_zero_weight(self, control: StepControl):
+        """Wait for a running averaging round to finish with zero weight."""
+        if not control.triggered:
+            control.weight = 0
+            control.allow_allreduce()
+        if not control.done():
+            try:
+                control.result(self.averaging_timeout)
+            except BaseException as e:
+                logger.exception(e)
+                if not control.done():
+                    control.cancel()
+
     def shutdown(self):
         logger.log(self.status_loglevel, "Sending goodbye to peers...")
         self.tracker.shutdown(self.shutdown_timeout)
         self.state_averager.step(wait_for_delayed_updates=True)
-        self._finish_background_averaging()
+        for scheduled_round in self.scheduled_grads, self.scheduled_state:
+            if scheduled_round is not None:
+                if scheduled_round.stage == AveragingStage.LOOKING_FOR_GROUP:
+                    scheduled_round.cancel()
+                else:
+                    self._tag_along_with_zero_weight(scheduled_round)
+
         logger.log(self.status_loglevel, "Shutting down averagers...")
         self.state_averager.shutdown()
         if self.use_gradient_averaging:
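
For context, a standalone illustration (plain Python, not hivemind code) of the shutdown decision above: a round still looking for a group can be cancelled safely, while a round whose group has already formed is finished with zero weight so the other peers are not disrupted:

from enum import Enum, auto

class Stage(Enum):                       # stand-in for hivemind's AveragingStage
    LOOKING_FOR_GROUP = auto()
    RUNNING_ALLREDUCE = auto()

def finish_round(stage: Stage) -> str:
    if stage is Stage.LOOKING_FOR_GROUP:
        return "cancel"                  # no group formed yet, nobody depends on us
    return "tag along with weight 0"     # group formed: participate without contributing

print(finish_round(Stage.LOOKING_FOR_GROUP))   # cancel
print(finish_round(Stage.RUNNING_ALLREDUCE))   # tag along with weight 0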