@@ -33,14 +33,14 @@ class Optimizer(torch.optim.Optimizer):

Hivemind Optimizer wraps your regular PyTorch Optimizer for training collaboratively with peers.

By default, Optimizer is configured to be exactly **equivalent to synchronous training** with target_batch_size;
there are advanced options that make training semi-asynchronous (delay_optimizer_step and delay_grad_averaging)
- or even fully asynchronous (local_updates=True). However, these options require careful tuning.
+ or even fully asynchronous (use_local_updates=True). However, these options require careful tuning.

- The Optimizer is meant as a drop-in replacement for your regular PyTorch Optimizer:
+ :example: The Optimizer is meant as a drop-in replacement for your regular PyTorch Optimizer:

>>> model = transformers.AutoModel("albert-xxlarge-v2")
>>> dht = hivemind.DHT(initial_peers=INITIAL_PEERS, start=True)
- >>> opt = hivemind.Optimizer(dht, run_id="run_42", optimizer=torch.optim.Adam, params=model.parameters(),
- >>>     target_batch_size=4096, batch_size_per_step=4)  # recommended way to create Optimizer
+ >>> opt = hivemind.Optimizer(dht, run_id="run_42", optimizer=lambda params: torch.optim.Adam(params, ...),
+ >>>     params=model.parameters(), target_batch_size=4096, batch_size_per_step=4)
>>> # alternative: opt = hivemind.Optimizer(dht, run_id="run_42", optimizer=torch.optim.Adam(model.parameters()))
>>> while True:
>>>     loss = compute_loss_on_batch(model, batch_size=4)
@@ -49,33 +49,58 @@ class Optimizer(torch.optim.Optimizer):
>>>     opt.step()  # <-- train collaboratively with any peers that use the same prefix (run_42)

However, unlike regular optimizers, calling opt.step with hivemind.Optimizer can do one of the following:
- - accumulate a minibatch of gradients towards the (global) target batch size, without updating parameters yet;
- - after accumulating the target batch size, all-reduce gradients with peers and perform optimizer step;
- - if your peer lags behind the rest of the swarm, it will download latest state from other peers;
- :note: hivemind.Optimizer can be used the same way any other pytorch optimizer, but there is one limitation:
+ - accumulate a minibatch of gradients towards the (global) target batch size, without updating parameters yet;
+ - after accumulating the target batch size, all-reduce gradients with peers and perform an optimizer step;
+ - if your peer lags behind the rest of the swarm, it will download the latest state from other peers.
+
+ :example: the Optimizer has many keyword arguments that may be difficult to understand in one go. Here's a quickstart
+ that will help you set up your first synchronous optimizer.
+
+ >>> hivemind.Optimizer(
+ >>>     dht=hivemind.DHT(initial_peers=ADDRESS_HERE, client_mode=TRUE_IF_BEHIND_FIREWALL_OR_UNRELIABLE, start=True),
+ >>>     run_id="a_unique_name_that_every_participant_will_see_when_training",
+ >>>     batch_size_per_step=ACTUAL_BATCH_SIZE_OF_THIS_PEER,
+ >>>     target_batch_size=LARGE_GLOBAL_BATCH,  # global batch will be this or *slightly* larger due to stragglers;
+ >>>     # peers should finish averaging in roughly half the time they need to accumulate this batch between them
+ >>>     optimizer=lambda params: AnyPyTorchOptimizer(params, **config_that_makes_sense_for_target_batch_size),
+ >>>     # ^-- scale the learning rate for your target_batch_size; a good reference: https://arxiv.org/abs/1904.00962
+ >>>     offload_optimizer=True,  # this saves GPU memory; large-batch training does not need the optimizer that often
+ >>>     scheduler=lambda opt: AnyPyTorchScheduler(opt, **config_that_makes_sense_for_target_batch_size),
+ >>>     # scheduler.step will be called once every time peers collectively accumulate target_batch_size
+ >>>     matchmaking_time=15.0, averaging_timeout=60.0,  # <-- if the network is fast, reduce to 3-5s and 10-15s
+ >>>     # increase matchmaking_time if at least 25% of the time you see "averaged gradients with N peers"
+ >>>     # ... where N is less than 0.9x the actual number of peers; increase averaging_timeout if half of the epochs
+ >>>     # ... print "Proceeding with local gradients" instead of "Averaged gradients with N peers"
+ >>>     grad_compression=hivemind.Float16Compression(), state_averaging_compression=hivemind.Float16Compression(),
+ >>>     # it is generally fine to use pure 16-bit or even lower precision during communication without extra precautions;
+ >>>     # see hivemind/examples/albert for an example of mixed 8-bit compression.
+ >>>     delay_grad_averaging=SHOULD_I_USE_DPU, delay_optimizer_step=SHOULD_I_USE_DPU,  # DPU stands for Delayed
+ >>>     # Parameter Updates: running allreduce and the optimizer step in background, see https://arxiv.org/abs/2101.06840
+ >>>     verbose=True  # periodically report the training progress to the console
+ >>> )  # and you're done!
+
+ :note: hivemind.Optimizer can be used the same way as any other PyTorch optimizer, but there is one caveat:
learning rate schedulers, curriculum and other **time-dependent features should depend on Optimizer.local_epoch**
(and not the number of calls to opt.step). This is because peers are allowed to join midway through training,
when others have already made some progress and changed their learning rates accordingly.
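
For illustration, a minimal sketch of epoch-dependent logic keyed to the collaboration-wide counter (get_sequence_length is a hypothetical user-defined curriculum schedule, not part of hivemind; compute_loss_on_batch is the placeholder helper from the example above):

>>> # curriculum/warmup decisions should read opt.local_epoch, not a count of opt.step() calls
>>> seq_length = get_sequence_length(opt.local_epoch)  # hypothetical: longer sequences at later epochs
>>> loss = compute_loss_on_batch(model, batch_size=4, max_length=seq_length)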

:param dht: a running hivemind.DHT instance connected to other peers
- :param run_id: a unique identifier of this training run, used as a common prefix for all DHT keys
-
- :note: peers with the same run_id should *generally* train the same model and use the same optimizer configuration.
- Some options can be safely changed by individual peers: `batch_size_per_step`, `client_mode`, `auxiliary`,
- `reuse_grad_buffers`, `offload_optimizer`, and `verbose`. In some cases, other options may also be tuned
+ :param run_id: a unique identifier of this training run, used as a common prefix for all DHT keys.
+ **Note:** peers with the same run_id should *generally* train the same model and use compatible configurations.
+ Some options can be safely changed by individual peers: ``batch_size_per_step``, ``client_mode``, ``auxiliary``,
+ ``reuse_grad_buffers``, ``offload_optimizer``, and ``verbose``. In some cases, other options may also be tuned
individually by each peer, but they should be changed with caution to avoid deadlocks or convergence issues.

:param target_batch_size: global batch size that must be accumulated before the swarm transitions to the next epoch
:param batch_size_per_step: before each call to .step, user should accumulate gradients over this many samples
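
For intuition, a back-of-the-envelope sketch using the numbers from the example above (batch_size_per_step=4, target_batch_size=4096):

>>> 4096 // 4  # ~1024 local .step() calls, summed over all peers, per global optimizer step (slightly more with stragglers)
1024
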
:param optimizer: a callable(parameters) -> pytorch.optim.Optimizer or a pre-initialized PyTorch optimizer
+ **Note:** some advanced options like offload_optimizer, delay_optimizer_step, or delay_grad_averaging are not
+ supported if hivemind.Optimizer is created with a pre-initialized optimizer; they require an optimizer factory.
:param params: parameters or param groups for the optimizer; required if optimizer is a callable(params)
- :note: creating hivemind.Optimizer with params=model.parameters() and optimizer=lambda params: make_optim(params)
- is required for advanced options: offload_optimizer, delay_optimizer_step and delay_grad_averaging.
-
- :param scheduler: callable(optimizer) -> PyTorch LRScheduler or a pre-initialized PyTorch scheduler
- :note: the learning rate scheduler will adjust learning rate based on collaboration-wide epoch, not the number of
+ :param scheduler: callable(optimizer) -> PyTorch LRScheduler or a pre-initialized PyTorch scheduler.
+ The learning rate scheduler will adjust learning rate based on global epoch, not the number of
local calls to optimizer.step; this is required to keep different peers synchronized.

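To make the factory pattern concrete, here is a hedged sketch (the Adam learning rate and the LambdaLR warmup are placeholder choices, not recommendations):

>>> opt = hivemind.Optimizer(
>>>     dht=dht, run_id="run_42", target_batch_size=4096, batch_size_per_step=4,
>>>     params=model.parameters(),
>>>     optimizer=lambda params: torch.optim.Adam(params, lr=1e-3),  # factory: called once with the params above
>>>     scheduler=lambda inner_opt: torch.optim.lr_scheduler.LambdaLR(inner_opt, lambda epoch: min(1.0, epoch / 1000)),
>>>     # ^-- the scheduler factory receives the inner optimizer; its .step() follows the global epoch counter
>>> )

Passing factories rather than pre-built objects is what allows hivemind.Optimizer to re-create and offload its internal state, which is why the advanced options noted above require this form.
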
:param matchmaking_time: when looking for a group, wait for peers to join for up to this many seconds
@@ -87,24 +112,23 @@ class Optimizer(torch.optim.Optimizer):
:param offload_optimizer: offload the optimizer to host memory, saving GPU memory for parameters and gradients
:param delay_optimizer_step: run optimizer in background, apply results in future .step; requires offload_optimizer
:param delay_grad_averaging: average gradients in background; requires offload_optimizer and delay_optimizer_step
- :note: offload_optimizer, delay_optimizer_step and delay_grad_averaging require that the optimizer is
- created as follows: `hivemind.Optimizer(..., optimizer=callable_optimizer_factory, params=model.parameters())`
-
:param delay_state_averaging: if enabled (default), average parameters and extra tensors in a background thread;
if set to False, average parameters synchronously within the corresponding hivemind.Optimizer.step call.
- :param average_state_every: average state (parameters, chosen opt statistics) with peers every this many **epochs**
- This reduces the communication overhead increasing, but can cause parameters to diverge if too large
- :note: The maximal average_state_every=num_epochs depends on how often peers diverge from each other. If peers
- hardly ever skip averaging rounds, they can average state less frequently. Network failures, lossy gradient
- compression and local_updates cause parameters to diverge faster and requires more frequent averaging.
+ The above 3 options (offload_optimizer, delay_optimizer_step and delay_grad_averaging) require that the optimizer
+ is created with: ``hivemind.Optimizer(..., optimizer=callable_optimizer_factory, params=model.parameters())``
+
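Building on the factory form above, a sketch with the delayed-update options enabled (whether DPU actually helps depends on your hardware and network; this is an illustration, not a recommendation):

>>> opt = hivemind.Optimizer(
>>>     dht=dht, run_id="run_42", target_batch_size=4096, batch_size_per_step=4,
>>>     optimizer=lambda params: torch.optim.Adam(params), params=model.parameters(),
>>>     offload_optimizer=True,      # keep optimizer statistics in host memory
>>>     delay_optimizer_step=True,   # run the optimizer step in background (requires offload_optimizer)
>>>     delay_grad_averaging=True,   # also overlap gradient averaging with compute (requires the two options above)
>>> )
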
+ :param average_state_every: average state (parameters, chosen opt tensors) with peers every this many **epochs**.
+ This reduces the communication overhead, but can cause parameters to diverge if set too large.
+ The maximal average_state_every=num_epochs depends on how often peers diverge from each other. If peers
+ hardly ever skip averaging rounds, they can average state less frequently. In turn, network failures, lossy
+ gradient compression and use_local_updates cause parameters to diverge faster and require more frequent averaging.

:param use_local_updates: if enabled, peers will update parameters on each .step using local gradients;
- if not enabled (default), accumulate gradients to target_batch_size, and then call .step with averaged gradients
- :note: even if use_local_updates=True, learning rate scheduler will still be called once per target_batch_size.
+ if not enabled (default), accumulate gradients to target_batch_size, and then call .step with averaged gradients.
+ Even if use_local_updates=True, the learning rate scheduler will still be called once per target_batch_size.

:param client_mode: if True, this peer will not accept incoming connections (firewall-compatible mode)
:param auxiliary: if True, optimizer.step will only assist other peers in averaging (for cpu-only workers)
- :note: client_mode=True and auxiliary=True are mutually exclusive; auxiliary also requires batch_size_per_step=None

:param grad_compression: compression strategy used for averaging gradients, default = no compression
:param state_averaging_compression: compression for averaging params and state tensors, default = no compression
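
For example, a sketch of enabling compressed communication with the Float16Compression strategy used in the quickstart above (other strategies live in hivemind.compression):

>>> opt = hivemind.Optimizer(
>>>     dht=dht, run_id="run_42", target_batch_size=4096, batch_size_per_step=4,
>>>     optimizer=lambda params: torch.optim.Adam(params), params=model.parameters(),
>>>     grad_compression=hivemind.Float16Compression(),             # compress gradients during all-reduce
>>>     state_averaging_compression=hivemind.Float16Compression(),  # compress parameters/state during averaging
>>> )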
@@ -117,12 +141,8 @@ class Optimizer(torch.optim.Optimizer):
:param performance_ema_alpha: moving average alpha in ProgressTracker, TrainingStateAverager and Optimizer
:param verbose: if True, report internal events such as accumulating gradients and running background tasks

- Internally, hivemind.Optimizer consists of 4 components:
- - DHT, a decentralized key-value storage used for coordination across the swarm
- - GradientAverager that is responsible for aggregating gradients with peers for global steps (can be disabled)
- - TrainingStateAverager holds parameters and optimizer/scheduler statistics, keeping them weakly synchronized
- by averaging with peers. It can also download these variable from other peers if your peer is out of sync.
- - ProgressTracker that uses DHT to track the global training progress: the number of steps or samples accumulated
+ :note: in large-scale training, peers will inevitably fail and you will see error messages. hivemind.Optimizer
+ is designed to recover from such failures, but will sometimes need a minute or two to re-adjust.

"""
|
|
|
|
|
@@ -312,7 +332,8 @@ class Optimizer(torch.optim.Optimizer):
|
|
|
grad_scaler: Optional[GradScaler] = None,
):
"""
- Report accumulating gradients w.r.t. batch_size additional samples, optionally update model parameters
+ Update training progress after accumulating another local batch of samples. Depending on the configuration, this will
+ report progress to peers, run a global or local optimizer step, average parameters, or schedule background tasks.

:param closure: A closure that reevaluates the model and returns the loss
:param batch_size: optional override for batch_size_per_step from init
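
For instance, a hedged usage sketch for a step where fewer samples than usual were processed (compute_loss_on_batch is the placeholder helper from the class docstring example):

>>> loss = compute_loss_on_batch(model, batch_size=2)  # this particular batch is smaller than batch_size_per_step
>>> loss.backward()
>>> opt.step(batch_size=2)  # report the actual number of samples instead of the default from init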
@@ -561,7 +582,7 @@ class Optimizer(torch.optim.Optimizer):
self.grad_averager.notify_used_averaged_gradients()

def zero_grad(self, set_to_none: bool = False):
- """Reset gradients from model. If these gradients are reused for accumulators, raise an error."""
+ """Reset gradients from model. If reuse_grad_buffers=True, this will raise an error."""
if self.use_gradient_averaging and self.grad_averager.reuse_grad_buffers:
raise ValueError(
f"When running {self.__class__.__name__} with reuse_grad_buffers=True, user should never "
@@ -597,7 +618,11 @@ class Optimizer(torch.optim.Optimizer):
return self.local_epoch < self.tracker.global_epoch - 1  # catch up if a peer just switched to next epoch

def load_state_from_peers(self, **kwargs):
- """Attempt to fetch the newest collaboration state from other peers"""
+ """
+ Attempt to load the newest collaboration state from other peers within the same run_id.
+
+ If successful, this will update parameters, optimizer state, local epoch and learning rate schedule in-place.
+ """
self._finish_background_averaging()
self.state_averager.step(wait_for_delayed_updates=True)