@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import logging
 import os
+from functools import partial
 from typing import Callable, Optional, Union
 
 import torch
@@ -98,7 +99,8 @@ class Optimizer(torch.optim.Optimizer):
         load_state_timeout: float = 600.0,
         average_state_every: int = 1,
         reuse_grad_buffers: bool = False,
-        delay_optimizer_step: bool = False,
+        delay_grad_averaging: bool = False,
+        delay_optimizer_step: Optional[bool] = None,
         client_mode: bool = None,
         auxiliary: bool = False,
         averager_opts: Optional[dict] = None,
@@ -107,13 +109,15 @@ class Optimizer(torch.optim.Optimizer):
         verbose: bool = False,
     ):
         client_mode = client_mode if client_mode is not None else dht.client_mode
+        delay_optimizer_step = delay_optimizer_step if delay_optimizer_step is not None else delay_grad_averaging
+        assert not delay_grad_averaging or delay_optimizer_step, "delay_grad_averaging requires delay_optimizer_step"
         assert not (client_mode and auxiliary), "Client-mode peers cannot serve as auxiliaries"
         assert not auxiliary or batch_size_per_step is None, "Auxiliary peers should not accumulate batches"
 
         self.dht, self.prefix, self.client_mode, self.auxiliary = dht, prefix, client_mode, auxiliary
         self.batch_size_per_step, self.target_batch_size = batch_size_per_step, target_batch_size
-        self.matchmaking_time, self.delay_optimizer_step = matchmaking_time, delay_optimizer_step
-        self.average_state_every = average_state_every
+        self.matchmaking_time, self.average_state_every = matchmaking_time, average_state_every
+        self.delay_grad_averaging, self.delay_optimizer_step = delay_grad_averaging, delay_optimizer_step
         self.averaging_timeout, self.load_state_timeout = averaging_timeout, load_state_timeout
         self.shutdown_timeout = shutdown_timeout
 
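The two new constructor arguments interact: delay_optimizer_step now defaults to the value of delay_grad_averaging, and delaying gradient averaging without also delaying the optimizer step is rejected. A minimal standalone sketch of that resolution logic (the helper name resolve_delay_flags is hypothetical, not part of the diff):

from typing import Optional, Tuple

def resolve_delay_flags(delay_grad_averaging: bool, delay_optimizer_step: Optional[bool]) -> Tuple[bool, bool]:
    # If delay_optimizer_step was not set explicitly, inherit it from delay_grad_averaging.
    delay_optimizer_step = delay_optimizer_step if delay_optimizer_step is not None else delay_grad_averaging
    # Delaying gradient averaging only makes sense if the optimizer step is delayed as well.
    assert not delay_grad_averaging or delay_optimizer_step, "delay_grad_averaging requires delay_optimizer_step"
    return delay_grad_averaging, delay_optimizer_step

assert resolve_delay_flags(True, None) == (True, True)      # delaying averaging implies delaying the step
assert resolve_delay_flags(False, None) == (False, False)   # default behaviour: nothing is delayed
assert resolve_delay_flags(False, True) == (False, True)    # the optimizer step may still be delayed on its own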
@@ -282,21 +286,25 @@ class Optimizer(torch.optim.Optimizer):
                 self.scheduled_round = None
 
             swarm_not_empty = self.tracker.global_progress.num_peers > 1
+            began_averaging = False
+
             if swarm_not_empty:
                 try:
-                    group_info = self.grad_averager.step(
-                        control=self.scheduled_round, reset_accumulators=True, timeout=self.averaging_timeout
+                    self.scheduled_round = self.grad_averager.step(
+                        control=self.scheduled_round, reset_accumulators=True, wait=False
                     )
-                    logger.log(self.status_loglevel, f"Averaged gradients with {len(group_info)} peers")
+                    assert self.grad_averager.local_samples_accumulated == 0, "step should have reset accumulators"
+                    began_averaging = True
                 except BaseException as e:
-                    logger.log(self.status_loglevel, f"Averaging gradients failed with {repr(e)}, using local grads")
-                    self.grad_averager.load_accumulators_into_averager_()
+                    logger.exception(e)
 
-            else:
-                if self.scheduled_round is not None and not self.scheduled_round.done():
-                    self.scheduled_round.cancel()
-                logger.log(self.status_loglevel, f"Skipped averaging: there are no other peers")
-                self.grad_averager.load_accumulators_into_averager_()
+            if not began_averaging and self.scheduled_round is not None and not self.scheduled_round.done():
+                logger.log(self.status_loglevel, f"Cancelled pre-scheduled averaging round")
+                self.scheduled_round.cancel()
+                self.scheduled_round = None
+
+            if not self.delay_grad_averaging:
+                self._average_gradients_and_load_into_optimizer(self.scheduled_round)
 
             assert self._schema_hash == self._compute_schema_hash(), "parameters or gradients changed during iteration"
             with self.grad_averager.use_averaged_gradients(replace_model_gradients=False):
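With this change, gradient averaging becomes non-blocking: grad_averager.step(..., wait=False) returns a control handle right away, and the result is consumed later (or the stale handle cancelled) by _average_gradients_and_load_into_optimizer. Below is a minimal sketch of that schedule-now, await-or-cancel-later pattern using a plain concurrent.futures future as a stand-in; it illustrates the control flow only, not hivemind's StepControl API:

import time
from concurrent.futures import Future, ThreadPoolExecutor

def begin_averaging(executor: ThreadPoolExecutor) -> Future:
    # Stand-in for grad_averager.step(..., wait=False): schedule work, return a handle immediately.
    def _all_reduce():
        time.sleep(0.1)  # pretend to run an all-reduce in the background
        return ("peer_A", "peer_B")  # pretend group info
    return executor.submit(_all_reduce)

with ThreadPoolExecutor() as executor:
    handle, began_averaging = None, False
    try:
        handle = begin_averaging(executor)
        began_averaging = True
    except RuntimeError:
        pass  # scheduling failed; fall through and clean up below

    if not began_averaging and handle is not None and not handle.done():
        handle.cancel()  # drop the stale pre-scheduled round, mirroring the diff

    if began_averaging:
        group_info = handle.result(timeout=5)  # analogous to maybe_step_control.result(averaging_timeout)
        print(f"Averaged gradients with {len(group_info)} peers")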
@@ -310,6 +318,8 @@ class Optimizer(torch.optim.Optimizer):
                     averaging_round=swarm_not_empty and next_epoch % self.average_state_every == 0,
                     delay_averaging=not self.auxiliary,
                     grad_scaler=grad_scaler,
+                    wait_for_trigger=partial(
+                        self._average_gradients_and_load_into_optimizer, self.scheduled_round) if self.delay_grad_averaging else None,
                     averaging_opts=dict(
                         scheduled_time=get_dht_time() + self.matchmaking_time, timeout=self.averaging_timeout
                     )
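The new wait_for_trigger argument defers gradient averaging until the delayed state-averaging step actually fires; functools.partial bakes the round that was scheduled at call time into a zero-argument callback. A small illustration with hypothetical stand-in names:

from functools import partial

def average_and_load(maybe_round):
    # Stand-in for Optimizer._average_gradients_and_load_into_optimizer.
    print(f"resolving scheduled round: {maybe_round!r}")

# partial() freezes the round that existed when step() was called into a zero-argument
# callback, so the code that fires the trigger later needs no knowledge of that round.
trigger = partial(average_and_load, "scheduled_round_at_step_time")
trigger()  # prints: resolving scheduled round: 'scheduled_round_at_step_time'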
@@ -325,6 +335,25 @@ class Optimizer(torch.optim.Optimizer):
             logger.log(self.status_loglevel, f"Optimizer step done! Transitioning to epoch {self.local_epoch}.")
         return loss
 
+    def _average_gradients_and_load_into_optimizer(self, maybe_step_control: Optional[StepControl]):
+        """Run gradient averaging; on success, feed averaged gradients into optimizer; else, use local gradients"""
+        assert maybe_step_control is None or maybe_step_control.triggered
+        averaged_gradients = False
+
+        try:
+            if maybe_step_control is not None:
+                group_info = maybe_step_control.result(self.averaging_timeout)
+                logger.log(self.status_loglevel, f"Averaged gradients with {len(group_info)} peers")
+                averaged_gradients = True
+            else:
+                logger.log(self.status_loglevel, f"Skipped averaging: there are no other peers")
+        except BaseException as e:
+            logger.log(self.status_loglevel, f"Averaging gradients failed with {repr(e)}")
+
+        if not averaged_gradients:
+            logger.log(self.status_loglevel, f"Proceeding with local gradients")
+            self.grad_averager.load_accumulators_into_averager_()
+
     def zero_grad(self, set_to_none: bool = False):
         """Reset gradients from model. If these gradients are reused for accumulators, raise an error."""
         if self.grad_averager.reuse_grad_buffers: