@@ -95,22 +95,28 @@ class Optimizer(torch.optim.Optimizer):
         scheduler: Optional[Union[LRSchedulerBase, SchedulerFactory]] = None,
         matchmaking_time: Optional[float] = 15.0,
         averaging_timeout: Optional[float] = 300.0,
-        average_state_every: int = 1,
         load_state_timeout: float = 600.0,
+        average_state_every: int = 1,
         reuse_grad_buffers: bool = False,
         delay_optimizer_step: bool = False,
         client_mode: bool = None,
+        auxiliary: bool = False,
         averager_opts: Optional[dict] = None,
         tracker_opts: Optional[dict] = None,
+        shutdown_timeout: float = 5,
         verbose: bool = False,
     ):
-        self.dht, self.prefix = dht, prefix
-        self.averaging_timeout, self.load_state_timeout = averaging_timeout, load_state_timeout
+        client_mode = client_mode if client_mode is not None else dht.client_mode
+        assert not (client_mode and auxiliary), "Client-mode peers cannot serve as auxiliaries"
+        assert not auxiliary or batch_size_per_step is None, "Auxiliary peers should not accumulate batches"
+
+        self.dht, self.prefix, self.client_mode, self.auxiliary = dht, prefix, client_mode, auxiliary
         self.batch_size_per_step, self.target_batch_size = batch_size_per_step, target_batch_size
         self.matchmaking_time, self.delay_optimizer_step = matchmaking_time, delay_optimizer_step
         self.average_state_every = average_state_every
+        self.averaging_timeout, self.load_state_timeout = averaging_timeout, load_state_timeout
+        self.shutdown_timeout = shutdown_timeout

-        self.client_mode = client_mode if client_mode is not None else self.dht.client_mode
         self.status_loglevel = logging.INFO if verbose else logging.DEBUG
         self.scheduled_round: Optional[StepControl] = None
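
A minimal sketch of the constructor contract added above, assuming a running hivemind.DHT full peer named `dht` and eliding the other required arguments (params, optimizer, etc., which are hypothetical here); the two new asserts reject client-mode auxiliaries and auxiliaries that declare a batch size:

    # hypothetical usage sketch, not part of this diff
    aux_opt = hivemind.Optimizer(
        dht=dht,                   # must be a full peer: client_mode + auxiliary is rejected
        prefix="my_run",           # same prefix as the trainers it assists
        target_batch_size=4096,
        batch_size_per_step=None,  # auxiliaries must not accumulate batches
        auxiliary=True,
        shutdown_timeout=5,        # propagated to both averagers below
    )
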
@@ -132,8 +138,10 @@ class Optimizer(torch.optim.Optimizer):
             dht=self.dht,
             prefix=f"{self.prefix}_state_averager",
             allreduce_timeout=self.averaging_timeout,
+            shutdown_timeout=self.shutdown_timeout,
             status_loglevel=self.status_loglevel,
             client_mode=self.client_mode,
+            auxiliary=self.auxiliary,
             offload_optimizer=True,
             custom_gradients=True,
             start=True,
@@ -147,7 +155,9 @@ class Optimizer(torch.optim.Optimizer):
             prefix=f"{self.prefix}_grad_averager",
             parameters=self.state_averager.main_parameters,
             allreduce_timeout=self.averaging_timeout,
+            shutdown_timeout=self.shutdown_timeout,
             client_mode=self.client_mode,
+            auxiliary=self.auxiliary,
             start=True,
             **kwargs,
         )
@@ -207,9 +217,11 @@ class Optimizer(torch.optim.Optimizer):
         :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
         """
         if grad_scaler is not None and not isinstance(grad_scaler, GradScaler):
-            raise ValueError("hivemind.Optimizer requires a hivemind-aware gradient scaler (hivemind.GradScaler).")
-        if self.batch_size_per_step is None and batch_size is None:
+            raise ValueError("hivemind.Optimizer requires a hivemind-aware gradient scaler (hivemind.GradScaler)")
+        if self.batch_size_per_step is None and batch_size is None and not self.auxiliary:
             raise ValueError("Please either set batch_size_per_step parameter at init or when calling .step")
+        if self.auxiliary and (closure is not None or batch_size is not None or grad_scaler is not None):
+            raise ValueError("Auxiliary peers should not have batch size, run closures, or use grad_scaler")
         batch_size = batch_size if batch_size is not None else self.batch_size_per_step

         loss = None
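
For illustration, the calls this validation accepts and rejects, assuming `trainer_opt` is a regular peer and `aux_opt` was constructed with auxiliary=True (both hypothetical):

    trainer_opt.step(batch_size=32)  # ok: trainers may pass batch_size per call
    trainer_opt.step()               # ok only if batch_size_per_step was set at init
    aux_opt.step()                   # ok: auxiliary peers call step() bare
    aux_opt.step(batch_size=32)      # ValueError: no batch sizes, closures,
                                     # or grad scalers on auxiliary peers
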
@@ -217,7 +229,7 @@ class Optimizer(torch.optim.Optimizer):
             with torch.enable_grad():
                 loss = closure()

-        if self.should_load_state_from_peers():
+        if not self.auxiliary and self.should_load_state_from_peers():
             logger.log(self.status_loglevel, "Peer is out of sync.")
             self.load_state_from_peers()
             return loss
@@ -228,9 +240,10 @@ class Optimizer(torch.optim.Optimizer):
             self.grad_averager.reset_accumulated_grads_()
             return loss

-        self.grad_averager.accumulate_grads_(batch_size)
-        self.tracker.report_local_progress(self.local_epoch, self.grad_averager.local_samples_accumulated)
-        self.state_averager.step(apply_delayed_updates=True)
+        if not self.auxiliary:
+            self.grad_averager.accumulate_grads_(batch_size)
+            self.tracker.report_local_progress(self.local_epoch, self.grad_averager.local_samples_accumulated)
+            self.state_averager.step(apply_delayed_updates=True)

         if self.tracker.estimated_next_update_time - get_dht_time() <= self.matchmaking_time:
             if self.scheduled_round is None or self.scheduled_round.triggered or self.scheduled_round.done():
@@ -280,11 +293,11 @@ class Optimizer(torch.optim.Optimizer):

             self.state_averager.step(
                 increment_epoch=True,
-                optimizer_step=True,
+                optimizer_step=not self.auxiliary,
                 delay_optimizer_step=self.delay_optimizer_step,
-                grad_scaler=grad_scaler,
                 averaging_round=swarm_not_empty and next_epoch % self.average_state_every == 0,
-                delay_averaging=True,
+                delay_averaging=not self.auxiliary,
+                grad_scaler=grad_scaler,
                 averaging_opts=dict(
                     scheduled_time=get_dht_time() + self.matchmaking_time, timeout=self.averaging_timeout
                 )
@@ -292,25 +305,18 @@ class Optimizer(torch.optim.Optimizer):
                 else None,
             )

-            self.grad_averager.reset_accumulated_grads_()
-            self.tracker.update_epoch(new_epoch=self.state_averager.local_epoch)
+            if not self.auxiliary:
+                self.grad_averager.reset_accumulated_grads_()
+                self.tracker.update_epoch(new_epoch=self.state_averager.local_epoch)

-            if self.should_load_state_from_peers(new_epoch=True):
-                logger.log(self.status_loglevel, "Peer ended up out of sync after averaging.")
-                self.load_state_from_peers()
-                return loss
+                if self.should_load_state_from_peers(new_epoch=True):
+                    logger.log(self.status_loglevel, "Peer ended up out of sync after averaging.")
+                    self.load_state_from_peers()
+                    return loss

             logger.log(self.status_loglevel, f"Optimizer step done! Beginning next epoch {self.local_epoch}.")
         return loss

-    def step_aux(self, **kwargs):
-        """
-        Find and assist other peers in averaging without sending local gradients.
-
-        :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
-        """
-        raise NotImplementedError("Auxiliary step for hivemind.Optimizer is not implemented yet.")
-
     def zero_grad(self, set_to_none: bool = False):
         """Reset gradients from model. If these gradients are reused for accumulators, raise an error."""
         if self.grad_averager.reuse_grad_buffers:
@@ -390,7 +396,7 @@ class Optimizer(torch.optim.Optimizer):

     def shutdown(self):
         logger.debug("Sending goodbye to peers...")
-        self.tracker.shutdown()
+        self.tracker.shutdown(self.shutdown_timeout)
         logger.debug("Shutting down averager...")
         self.state_averager.step(wait_for_delayed_update=True)
         self.state_averager.shutdown()
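
Taken together, a sketch of an auxiliary peer's lifecycle under this diff (`aux_opt` as in the earlier sketch; the loop shape is an assumption, not a prescribed API):

    try:
        while True:
            aux_opt.step()  # assist matchmaking and all-reduce, no local gradients
    except KeyboardInterrupt:
        pass
    finally:
        aux_opt.shutdown()  # tracker goodbye is now bounded by shutdown_timeout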