@@ -30,92 +30,107 @@ logger = get_logger(__name__)

class Optimizer(torch.optim.Optimizer):
    """
- Hivemind Optimizer wraps your regular PyTorch Optimizer for training collaboratively with peers.
- By default, Optimizer is configured to be exactly **equivalent to synchronous training** with target_batch_size;
+ hivemind.Optimizer wraps your regular PyTorch Optimizer for training collaboratively with peers.
+
+ By default, Optimizer is configured to be exactly **equivalent to synchronous training** with target_batch_size.
There are advanced options to make training semi-asynchronous (delay_optimizer_step and delay_grad_averaging)
- or even fully asynchronous (local_updates=True). However, these options require careful tuning.
+ or even fully asynchronous (use_local_updates=True).

- :example: The Optimizer can be used as a drop-in replacement for your regular PyTorch Optimizer:
+ :example: The Optimizer can be used as a drop-in replacement for a regular PyTorch Optimizer:

>>> model = transformers.AutoModel("albert-xxlarge-v2")
>>> dht = hivemind.DHT(initial_peers=INITIAL_PEERS, start=True)
- >>> opt = hivemind.Optimizer(dht, run_id="run_42", optimizer=lambda params: torch.optim.Adam(params, ...),
- >>> params=model.parameters(), target_batch_size=4096, batch_size_per_step=4)
- >>> # alternative: opt = hivemind.Optimizer(dht, run_id="run_42", optimizer=torch.optim.Adam(model.parameters())
+ >>> opt = hivemind.Optimizer(dht=dht, run_id="run_42", batch_size_per_step=4, target_batch_size=4096,
+ >>>     params=model.parameters(), optimizer=lambda params: torch.optim.Adam(params))
>>> while True:
>>>     loss = compute_loss_on_batch(model, batch_size=4)
>>>     opt.zero_grad()
>>>     loss.backward()
>>>     opt.step()  # <-- train collaboratively with any peers that use the same prefix (run_42)

- However, unlike regular optimizers, calling opt.step with hivemind.Optimizer can do one of the following:
-
- - accumulate a minibatch of gradients towards the (global) target batch size, without updating parameters yet;
- - after accumulating the target batch size, all-reduce gradients with peers and perform optimizer step;
- - if your peer lags behind the rest of the swarm, it will download latest state from other peers;
-
- :example: the optimizer has many keyword arguments that may be difficult to understand in one go. Here's quickstart
- that will help you setup your first synchronous optimizer.
-
- >>> hivemind.Optimizer(
- >>> dht=hivemind.DHT(initial_peers=ADDRESS_HERE, client_mode=TRUE_IF_BEHIND_FIREWALL_OR_UNRELIABLE, start=True),
- >>> run_id="a_unique_name_that_every_participant_will_see_when_training",
- >>> batch_size_per_step=ACTUAL_BATCH_SIZE_OF_THIS_PEER,
- >>> target_batch_size=LARGE_GLOBAL_BATCH, # global batch will be this or *slightly* larger due to stragglers;
- >>> # peers should finish averaging in roughly half the time they need to accumulate this batch between them
- >>> optimizer=lambda params: AnyPyTorchOptimizer(params, **config_that_makes_sense_for_target_batch_size),
- >>> # ^-- scale learning rate for your target_batch_size; good reference: https://arxiv.org/abs/1904.00962
- >>> offload_optimizer=True, # this saves GPU memory; large-batch training does not need optimizer that often
- >>> scheduler=lambda opt: AnyPytTorchScheduler(opt, **config_that_makes_sense_for_target_batch_size),
- >>> # scheduler.step will be called once every time peers collectively accumulate target_batch_size
- >>> matchmaking_time=15.0, averaging_timeout=60.0, # <-- if the network is fast reduce to 3-5s and 10-15s
- >>> # increase matchmaking_time if at least 25% of the time you see "averaged gradients with <...> peers",
- >>> # ... but N is less than 0.9x the actual number of peers. Increase averaging_timeout if half of the epochs
- >>> # ... print "Proceeding with local gradients" instead of "Averaged gradients with N peers"
+ By default, peers will perform the following steps:
+
+ * accumulate a minibatch of gradients towards the (global) target batch size, without updating parameters yet;
+ * after peers collectively accumulate target_batch_size, average gradients with peers and perform optimizer step;
+ * if your peer lags behind the rest of the swarm, it will download parameters and optimizer state from others;
+
+ Unlike regular training, your device may join midway through training, when other peers already made some progress.
+ For this reason, any learning rate schedulers, curriculum and other **time-dependent features should be based on**
+ ``optimizer.local_epoch`` (and not the number of calls to opt.step). Otherwise, peers that joined training late
+ may end up having different learning rates. To do so automatically, specify the ``scheduler=...`` parameter below.
+
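For illustration, here is a minimal sketch of a curriculum keyed to ``opt.local_epoch``; the ``max_seq_len`` argument and the schedule itself are hypothetical placeholders, not hivemind features:

>>> while True:
>>>     max_seq_len = min(512, 128 + 32 * opt.local_epoch)  # grows with the collaboration-wide epoch
>>>     loss = compute_loss_on_batch(model, batch_size=4, max_seq_len=max_seq_len)
>>>     opt.zero_grad()
>>>     loss.backward()
>>>     opt.step()  # local_epoch only advances when the swarm accumulates target_batch_size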
+ :What is an epoch?: Optimizer uses the term ``epoch`` to describe intervals between synchronizations. One epoch
+ corresponds to processing a certain number of training samples (``target_batch_size``) in total across all peers.
+ Like in PyTorch LR Scheduler, **epoch does not necessarily correspond to a full pass over the training data.**
+ At the end of each epoch, peers perform synchronous actions such as averaging gradients for a global optimizer update,
+ updating the learning rate scheduler or simply averaging parameters (if using local updates).
+ The purpose of this is to ensure that changing the number of peers does not require changing hyperparameters.
+ For instance, if the number of peers doubles, they will run all-reduce more frequently to adjust for faster training.
+
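As a concrete illustration (the numbers are made up for this example): with ``target_batch_size=4096`` and 32 peers each calling opt.step after ``batch_size_per_step=4`` samples, an epoch boundary is reached after roughly 4096 / (32 * 4) = 32 steps per peer; with 64 peers, it is reached after roughly 16 steps per peer, so synchronization happens about twice as often in wall-clock time while each global update still uses ~4096 samples.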
+ :Configuration guide: This guide will help you set up your first collaborative training run. It covers the most
+ important basic options, but ignores features that require significant changes to the training code.
+
+ >>> dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=IF_BEHIND_FIREWALL_OR_VERY_UNRELIABLE, start=True)
+ >>> opt = hivemind.Optimizer(
+ >>>     dht=dht, run_id="a_unique_name_that_every_participant_will_see_when_training",
+ >>>     batch_size_per_step=ACTUAL_BATCH_SIZE_OF_THIS_PEER, target_batch_size=LARGE_GLOBAL_BATCH,
+ >>>     # ^--- Each global optimizer step will use gradients from 1x-1.1x of target_batch_size (due to latency);
+ >>>     # It is recommended to train with very large batch sizes to reduce the % of time spent on communication.
+ >>>
+ >>>     params=params, optimizer=lambda params: AnyPyTorchOptimizer(params, **hyperparams_for_target_batch_size),
+ >>>     # tune learning rate for your target_batch_size. Here's a good reference: https://arxiv.org/abs/1904.00962
+ >>>     scheduler=lambda opt: AnyPyTorchScheduler(opt, **hyperparams_for_target_batch_size),
+ >>>     # scheduler.step will be called automatically each time peers collectively accumulate target_batch_size
+ >>>
+ >>>     offload_optimizer=True,  # saves GPU memory, but increases RAM usage; generally a good practice to use this.
+ >>>     delay_grad_averaging=OPTIONAL, delay_optimizer_step=OPTIONAL,  # train faster, but with 1 round of staleness;
+ >>>     # setting both to True is equivalent to Delayed Parameter Updates (see https://arxiv.org/abs/2101.06840)
+ >>>
>>>     grad_compression=hivemind.Float16Compression(), state_averaging_compression=hivemind.Float16Compression(),
- >>> # it is generally fine to use pure 16-bit or even lower precision during communication with no precaution;
- >>> # See hivemind/examples/albert for an example of mixed 8-bit compression.
- >>> delay_grad_averaging=SHOULD_I_USE_DPU, delay_optimizer_step=SHOULD_I_USE_DPU, # DPU stands for Delayed Para-
- >>> # -meter Updates, running allreduce and optimizer step in background. See https://arxiv.org/abs/2101.06840
- >>> verbose=True # periodically report the training progress to the console
+ >>>     # ^-- it is usually fine to use pure 16-bit or even lower precision during communication with no precaution;
+ >>>     # See hivemind/examples/albert for a working example of mixed 8/16-bit compression.
+ >>>
+ >>>     matchmaking_time=15.0,  # 3-5s for small local runs, 10-15s for training over the internet or with many peers
+ >>>     averaging_timeout=60.0,  # around 2x the actual time it takes to run all-reduce
+ >>>     verbose=True  # periodically report the training progress to the console (e.g. "Averaged with N peers")
>>> )  # and you're done!

- :note: hivemind.Optimizer can be used the same way any other pytorch optimizer, but there is one caveat:
- learning rate schedulers, curriculum and other **time-dependent features should depend on Optimizer.local_epoch**
- (and not the number ot calls to opt.step). This is because peers are allowed to join midway through training,
- when others have already made some progress and changed their learning rates accordingly.

- :param dht: a running hivemind.DHT instance connected to other peers
+ :param dht: a running hivemind.DHT instance connected to other peers.
:param run_id: a unique identifier of this training run, used as a common prefix for all DHT keys.
**Note:** peers with the same run_id should *generally* train the same model and use compatible configurations.
Some options can be safely changed by individual peers: ``batch_size_per_step``, ``client_mode``, ``auxiliary``,
``reuse_grad_buffers``, ``offload_optimizer``, and ``verbose``. In some cases, other options may also be tuned
individually by each peer, but they should be changed with caution to avoid deadlocks or convergence issues.

- :param target_batch_size: global batch size that must be accumulated before the swarm transitions to the next epoch
- :param batch_size_per_step: before each call to .step, user should accumulate gradients over this many samples
+ :param target_batch_size: global batch size that must be accumulated before the swarm transitions to the next epoch.
+ The actual batch may be *slightly* larger due to asynchrony (e.g. peers submit more gradients in the last second).
+ :param batch_size_per_step: you should accumulate gradients over this many samples between calls to optimizer.step.

- :param optimizer: a callable(parameters) -> pytorch.optim.Optimizer or a pre-initialized PyTorch optimizer
- **Note:** some advanced options like offload_optimizer, delay_optimizer_step, or delay_grad_averaging are not
- supported if hivemind.optimizer is created with a pre-initialized optimizer and require optimizer factory
- :param params: parameters or param groups for the optimizer; required if optimizer is a callable(params)
+ :param params: parameters or param groups for the optimizer; required if optimizer is a callable(params).
+ :param optimizer: a callable(parameters) -> torch.optim.Optimizer or a pre-initialized PyTorch optimizer.
+ **Note:** some advanced options like offload_optimizer, delay_optimizer_step, or delay_grad_averaging require
+ the callable form and will not work if hivemind.Optimizer is created with a pre-existing PyTorch Optimizer.
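For reference, a minimal sketch of both forms, reusing the model and dht from the example above (hyperparameters are placeholders):

>>> # factory form: required for offload_optimizer, delay_optimizer_step and delay_grad_averaging
>>> opt = hivemind.Optimizer(dht=dht, run_id="run_42", target_batch_size=4096, batch_size_per_step=4,
>>>     params=model.parameters(), optimizer=lambda params: torch.optim.Adam(params))
>>> # pre-initialized form: simpler, but incompatible with the advanced options above
>>> opt = hivemind.Optimizer(dht=dht, run_id="run_42", target_batch_size=4096, batch_size_per_step=4,
>>>     optimizer=torch.optim.Adam(model.parameters()))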
:param scheduler: callable(optimizer) -> PyTorch LRScheduler or a pre-initialized PyTorch scheduler.
The learning rate scheduler will adjust learning rate based on global epoch, not the number of
local calls to optimizer.step; this is required to keep different peers synchronized.

- :param matchmaking_time: when looking for group, wait for peers to join for up to this many seconds
- :param averaging_timeout: if an averaging step hangs for this long, it will be cancelled.
- :param load_state_timeout: wait for at most this many seconds before giving up on load_state_from_peers
+ :param matchmaking_time: when looking for group, wait for peers to join for up to this many seconds.
+ Increase if you see "averaged gradients with N peers" where N is below 0.9x the real swarm size on >=25% of epochs.
+ When training over a low-latency network, decreasing matchmaking_time allows training with smaller batch sizes.
+ :param averaging_timeout: if an averaging step hangs for this long, it will be cancelled automatically.
+ Increase averaging_timeout if you see "Proceeding with local gradients" at least 25% of the time.
+ Do not set this timeout too high, as it may cause your optimizer to hang after some types of network errors.
+ :param load_state_timeout: wait for at most this many seconds before giving up on load_state_from_peers.
:param reuse_grad_buffers: if True, use model's .grad buffers for gradient accumulation.
This is more memory efficient, but it requires that the user does *NOT* call model/opt zero_grad at all

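For illustration, a minimal sketch of a training loop when ``opt`` was built with ``reuse_grad_buffers=True`` (note the absence of zero_grad; ``compute_loss_on_batch`` is the same placeholder as above):

>>> loss = compute_loss_on_batch(model, batch_size=4)
>>> loss.backward()  # gradients accumulate directly in the model's .grad buffers
>>> opt.step()       # do not call opt.zero_grad() or model.zero_grad() here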
:param offload_optimizer: offload the optimizer to host memory, saving GPU memory for parameters and gradients
:param delay_optimizer_step: run optimizer in background, apply results in future .step; requires offload_optimizer
:param delay_grad_averaging: average gradients in background; requires offload_optimizer and delay_optimizer_step
+
:param delay_state_averaging: if enabled (default), average parameters and extra tensors in a background thread;
if set to False, average parameters synchronously within the corresponding hivemind.Optimizer.step call.
- The above 3 options (offload_optimizer, delay_optimizer_step and delay_grad_averaging) require that the optimizer
- is created with: ``hivemind.Optimizer(..., optimizer=callable_optimizer_factory, params=model.parameters())``

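As a rough sketch, the delayed-updates configuration described above might look like this (same placeholders as earlier; all three flags require the optimizer-factory form):

>>> opt = hivemind.Optimizer(dht=dht, run_id="run_42", target_batch_size=4096, batch_size_per_step=4,
>>>     params=model.parameters(), optimizer=lambda params: torch.optim.Adam(params),
>>>     offload_optimizer=True, delay_optimizer_step=True, delay_grad_averaging=True)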
:param average_state_every: average state (parameters, chosen opt tensors) with peers every this many **epochs**.
Increasing this reduces the communication overhead, but can cause parameters to diverge if set too large.
@@ -315,6 +330,11 @@ class Optimizer(torch.optim.Optimizer):

@property
def local_epoch(self) -> int:
+ """
|
|
|
+ This worker's current epoch, kept synchronized with peers. If peer's local_epoch lags behind others, it will
|
|
|
+ automatically re-synchronize by downloading state from another peer.
|
|
|
+ An epoch corresponds to accumulating target_batch_size across all active devices.
|
|
|
+ """
|
|
|
return self.state_averager.local_epoch

@property
@@ -335,9 +355,9 @@ class Optimizer(torch.optim.Optimizer):
Update training progress after accumulating another local batch size. Depending on the configuration, this will
report progress to peers, run global or local optimizer step, average parameters or schedule background tasks.

- :param closure: A closure that reevaluates the model and returns the loss
- :param batch_size: optional override for batch_size_per_step from init
- :param grad_scaler: if amp is enabled, this **must** be a hivemind-aware gradient scaler
+ :param closure: A closure that reevaluates the model and returns the loss.
+ :param batch_size: optional override for batch_size_per_step from init.
+ :param grad_scaler: if amp is enabled, this **must** be a hivemind-aware gradient scaler.
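For illustration, a minimal sketch of overriding the per-step batch size, e.g. for a final, smaller batch (``last_batch_size`` is a placeholder):

>>> loss = compute_loss_on_batch(model, batch_size=last_batch_size)
>>> opt.zero_grad()
>>> loss.backward()
>>> opt.step(batch_size=last_batch_size)  # overrides batch_size_per_step for this call only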
:note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
"""
if grad_scaler is not None and not isinstance(grad_scaler, GradScaler):