@@ -123,6 +123,7 @@ class Optimizer(torch.optim.Optimizer):
 
         self.status_loglevel = logging.INFO if verbose else logging.DEBUG
         self.scheduled_round: Optional[StepControl] = None
+        self.previous_round: Optional[StepControl] = None
 
         self.state_averager = self._make_state_averager(
             optimizer=optimizer, params=params, scheduler=scheduler, **averager_opts or {}
@@ -263,6 +264,10 @@ class Optimizer(torch.optim.Optimizer):
 
         if self.tracker.estimated_next_update_time - get_dht_time() <= self.matchmaking_time:
             if self.scheduled_round is None or self.scheduled_round.triggered or self.scheduled_round.done():
+                if self.delay_grad_averaging:
+                    # wait for previous averaging to finish before starting a new one
+                    self.state_averager.step(wait_for_delayed_update=True)
+
                 eta_seconds = self.tracker.estimated_next_update_time - get_dht_time()
                 eta_seconds = max(eta_seconds, self.grad_averager.matchmaking_kwargs["min_matchmaking_time"])
                 logger.log(self.status_loglevel, f"Pre-scheduling next averaging round in {eta_seconds:.2f}s.")
@@ -274,7 +279,11 @@ class Optimizer(torch.optim.Optimizer):
         if not self.tracker.ready_to_update_epoch:
             return loss
 
+        assert self._schema_hash == self._compute_schema_hash(), "parameters or gradients changed during iteration"
+
         with self.tracker.pause_updates():
+            # note: we do not need to replace grads because we explicitly load grads into the optimizer
+
             logger.log(self.status_loglevel, f"Beginning global optimizer step #{self.tracker.global_epoch}")
 
             if grad_scaler is not None:
@@ -287,7 +296,6 @@ class Optimizer(torch.optim.Optimizer):
 
             swarm_not_empty = self.tracker.global_progress.num_peers > 1
             began_averaging_gradients = False
-
             if swarm_not_empty:
                 try:
                     self.scheduled_round = self.grad_averager.step(
@@ -306,27 +314,24 @@ class Optimizer(torch.optim.Optimizer):
             if not self.delay_grad_averaging:
                 self._average_gradients_and_load_into_optimizer(self.scheduled_round)
 
-            assert self._schema_hash == self._compute_schema_hash(), "parameters or gradients changed during iteration"
-            with self.grad_averager.use_averaged_gradients(replace_model_gradients=False):
-                # note: we do not need to replace because the offloaded optimizer is already using averaged grads
-                next_epoch = max(self.local_epoch + 1, self.tracker.global_epoch)
-
-                self.state_averager.step(
-                    increment_epoch=True,
-                    optimizer_step=not self.auxiliary,
-                    delay_optimizer_step=self.delay_optimizer_step,
-                    averaging_round=swarm_not_empty and next_epoch % self.average_state_every == 0,
-                    delay_averaging=not self.auxiliary,
-                    grad_scaler=grad_scaler,
-                    wait_for_trigger=partial(self._average_gradients_and_load_into_optimizer, self.scheduled_round)
-                    if self.delay_grad_averaging
-                    else None,
-                    averaging_opts=dict(
-                        scheduled_time=get_dht_time() + self.matchmaking_time, timeout=self.averaging_timeout
-                    )
-                    if swarm_not_empty
-                    else None,
+            next_epoch = max(self.local_epoch + 1, self.tracker.global_epoch)
+
+            self.state_averager.step(
+                increment_epoch=True,
+                optimizer_step=not self.auxiliary,
+                delay_optimizer_step=self.delay_optimizer_step,
+                averaging_round=swarm_not_empty and next_epoch % self.average_state_every == 0,
+                delay_averaging=not self.auxiliary,
+                grad_scaler=grad_scaler,
+                wait_for_trigger=partial(self._average_gradients_and_load_into_optimizer, self.scheduled_round)
+                if self.delay_grad_averaging
+                else None,
+                averaging_opts=dict(
+                    scheduled_time=get_dht_time() + self.matchmaking_time, timeout=self.averaging_timeout
                 )
+                if swarm_not_empty
+                else None,
+            )
 
             if not self.auxiliary:
                 self.grad_averager.reset_accumulated_grads_()
@@ -355,6 +360,8 @@ class Optimizer(torch.optim.Optimizer):
             logger.log(self.status_loglevel, f"Proceeding with local gradients")
             self.grad_averager.load_accumulators_into_averager_()
 
+        self.grad_averager.notify_used_averaged_gradients()
+
     def zero_grad(self, set_to_none: bool = False):
         """Reset gradients from model. If these gradients are reused for accumulators, raise an error."""
         if self.grad_averager.reuse_grad_buffers:
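For readers unfamiliar with the delayed-averaging flow in this patch: `state_averager.step(...)` runs in the background, `wait_for_trigger=partial(self._average_gradients_and_load_into_optimizer, self.scheduled_round)` defers gradient averaging until that background step actually needs the averaged gradients, and the new `state_averager.step(wait_for_delayed_update=True)` call blocks until the previous delayed step has finished before a new round is pre-scheduled. The sketch below is a minimal, self-contained illustration of this delayed-trigger pattern using plain threads; `DelayedStepper`, `average_gradients`, and `apply_optimizer_step` are hypothetical stand-ins, not hivemind APIs.

```python
# Toy illustration of a delayed step with a wait_for_trigger callback.
# All names here are hypothetical stand-ins, not hivemind APIs.
import threading
from functools import partial
from typing import Callable, Optional


class DelayedStepper:
    """Runs an 'optimizer step' in a background thread, optionally waiting on a trigger first."""

    def __init__(self) -> None:
        self._thread: Optional[threading.Thread] = None

    def step(
        self,
        wait_for_trigger: Optional[Callable[[], None]] = None,
        wait_for_delayed_update: bool = False,
    ) -> None:
        if wait_for_delayed_update and self._thread is not None:
            self._thread.join()  # wait for the previous delayed step before starting a new one

        def _run() -> None:
            if wait_for_trigger is not None:
                wait_for_trigger()  # e.g. average gradients and load them into the optimizer
            apply_optimizer_step()

        self._thread = threading.Thread(target=_run, daemon=True)
        self._thread.start()

    def shutdown(self) -> None:
        if self._thread is not None:
            self._thread.join()


def average_gradients(round_id: int) -> None:
    print(f"averaging gradients for round {round_id}")


def apply_optimizer_step() -> None:
    print("applying optimizer step with (averaged) gradients")


if __name__ == "__main__":
    stepper = DelayedStepper()
    # schedule a delayed step whose trigger first averages gradients, mirroring
    # wait_for_trigger=partial(self._average_gradients_and_load_into_optimizer, ...)
    stepper.step(wait_for_trigger=partial(average_gradients, 1))
    # before scheduling the next round, block on the previous delayed update
    stepper.step(wait_for_delayed_update=True)
    stepper.shutdown()
```

Joining the previous background thread before scheduling a new round is what keeps two delayed updates from overlapping, which is the situation the added `wait_for_delayed_update=True` call in the patch guards against.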