Browse Source

Update RTFD

justheuristic 3 năm trước cách đây
mục cha
commit
6beee923e7

+ 29 - 3
docs/modules/optim.rst

@@ -1,14 +1,40 @@
 **hivemind.optim**
 ==================
 
-.. automodule:: hivemind.optim
-.. currentmodule:: hivemind.optim
-
 .. raw:: html
 
   This module contains decentralized optimizers that wrap regular pytorch optimizers to collaboratively train a shared model. Depending on the exact type, optimizer may average model parameters with peers, exchange gradients, or follow a more complicated distributed training strategy.
   <br><br>
 
+.. automodule:: hivemind.optim.experimental.optimizer
+.. currentmodule:: hivemind.optim.experimental.optimizer
+
+**hivemind.Optimizer**
+----------------------
+
+.. autoclass:: Optimizer
+   :members: step, zero_grad, load_state_from_peers, param_groups, shutdown
+   :member-order: bysource
+
+.. currentmodule:: hivemind.optim.grad_scaler
+.. autoclass:: GradScaler
+   :member-order: bysource
+
+
+**CollaborativeOptimizer**
+--------------------------
+
+.. raw:: html
+
+  CollaborativeOptimizer is a legacy version of hivemind.Optimizer. **For new projects, please use hivemind.Optimizer.**
+  Currently, hivemind.Optimizer supports all the features of CollaborativeOptimizer and then some.
+  CollaborativeOptimizer will still be supported for awhile, but will eventually be deprecated.
+  <br><br>
+
+
+.. automodule:: hivemind.optim.collaborative
+.. currentmodule:: hivemind.optim
+
 .. autoclass:: CollaborativeOptimizer
    :members: step
    :member-order: bysource

+ 63 - 38
hivemind/optim/experimental/optimizer.py

@@ -33,14 +33,14 @@ class Optimizer(torch.optim.Optimizer):
     Hivemind Optimizer wraps your regular PyTorch Optimizer for training collaboratively with peers.
     By default, Optimizer is configured to be exactly **equivalent to synchronous training** with target_batch_size;
     There are advanced options make training semi-asynchronous (delay_optimizer_step and delay_gradient_averaging)
-     or even fully asynchronous (local_updates=True). However, these options require careful tuning.
+    or even fully asynchronous (local_updates=True). However, these options require careful tuning.
 
-    The Optimizer is meant as a drop-in replacement for your regular PyTorch Optimizer:
+    :example: The Optimizer is meant as a drop-in replacement for your regular PyTorch Optimizer:
 
     >>> model = transformers.AutoModel("albert-xxlarge-v2")
     >>> dht = hivemind.DHT(initial_peers=INITIAL_PEERS, start=True)
-    >>> opt = hivemind.Optimizer(dht, run_id="run_42", optimizer=torch.optim.Adam, params=model.parameters(),
-    >>>                          target_batch_size=4096, batch_size_per_step=4)  # recommended way to create Optimizer
+    >>> opt = hivemind.Optimizer(dht, run_id="run_42", optimizer=lambda params: torch.optim.Adam(params, ...),
+                                 params=model.parameters(), target_batch_size=4096, batch_size_per_step=4)
     >>> # alternative: opt = hivemind.Optimizer(dht, run_id="run_42", optimizer=torch.optim.Adam(model.parameters())
     >>> while True:
     >>>     loss = compute_loss_on_batch(model, batch_size=4)
@@ -49,33 +49,58 @@ class Optimizer(torch.optim.Optimizer):
     >>>     opt.step()  # <-- train collaboratively with any peers that use the same prefix (run_42)
 
     However, unlike regular optimizers, calling opt.step with hivemind.Optimizer can do one of the following:
-    - accumulate a minibatch of gradients towards the (global) target batch size, without updating parameters yet;
-    - after accumulating the target batch size, all-reduce gradients with peers and perform optimizer step;
-    - if your peer lags behind the rest of the swarm, it will download latest state from other peers;
 
-    :note: hivemind.Optimizer can be used the same way any other pytorch optimizer, but there is one limitation:
+     - accumulate a minibatch of gradients towards the (global) target batch size, without updating parameters yet;
+     - after accumulating the target batch size, all-reduce gradients with peers and perform optimizer step;
+     - if your peer lags behind the rest of the swarm, it will download latest state from other peers;
+
+    :example: the optimizer has many keyword arguments that may be difficult to understand in one go. Here's quickstart
+      that will help you setup your first synchronous optimizer.
+
+    >>> hivemind.Optimizer(
+    >>>    dht=hivemind.DHT(initial_peers=ADDRESS_HERE, client_mode=TRUE_IF_BEHIND_FIREWALL_OR_UNRELIABLE, start=True),
+    >>>    run_id="a_unique_name_that_every_participant_will_see_when_training",
+    >>>    batch_size_per_step=ACTUAL_BATCH_SIZE_OF_THIS_PEER,
+    >>>    target_batch_size=LARGE_GLOBAL_BATCH,  # global batch will be this or *slightly* larger due to stragglers;
+    >>>      #  peers should finish averaging in roughly half the time they need to accumulate this batch between them
+    >>>    optimizer=lambda params: AnyPyTorchOptimizer(params, **config_that_makes_sense_for_target_batch_size),
+    >>>      # ^-- scale learning rate for your target_batch_size; good reference: https://arxiv.org/abs/1904.00962
+    >>>    offload_optimizer=True,  # this saves GPU memory; large-batch training does not need optimizer that often
+    >>>    scheduler=lambda opt: AnyPytTorchScheduler(opt, **config_that_makes_sense_for_target_batch_size),
+    >>>      # scheduler.step will be called once every time peers collectively accumulate target_batch_size
+    >>>    matchmaking_time=15.0, averaging_timeout=60.0,  # <-- if the network is fast reduce to 3-5s and 10-15s
+    >>>      # increase matchmaking_time if at least 25% of the time you see "averaged gradients with <...> peers",
+    >>>      # ... but N is less than 0.9x the actual number of peers. Increase averaging_timeout if half of the epochs
+    >>>      # ... print "Proceeding with local gradients" instead of "Averaged gradients with N peers"
+    >>>    grad_compression=hivemind.Float16Compression(),  state_averaging_compression=hivemind.Float16Compression(),
+    >>>      # it is generally fine to use pure 16-bit or even lower precision during communication with no precaution;
+    >>>      # See hivemind/examples/albert for an example of mixed 8-bit compression.
+    >>>    delay_grad_averaging=SHOULD_I_USE_DPU, delay_optimizer_step=SHOULD_I_USE_DPU, # DPU stands for Delayed Para-
+    >>>      # -meter Updates, running allreduce and optimizer step in background. See https://arxiv.org/abs/2101.06840
+    >>>    verbose=True  # periodically report the training progress to the console
+    >>> )  # and you're done!
+
+    :note: hivemind.Optimizer can be used the same way any other pytorch optimizer, but there is one caveat:
       learning rate schedulers, curriculum and other **time-dependent features should depend on Optimizer.local_epoch**
       (and not the number ot calls to opt.step). This is because peers are allowed to join midway through training,
       when others have already made some progress and changed their learning rates accordingly.
 
     :param dht: a running hivemind.DHT instance connected to other peers
-    :param run_id: a unique identifier of this training run, used as a common prefix for all DHT keys
-
-    :note: peers with the same run_id should *generally* train the same model and use the same optimizer configuration.
-      Some options can be safely changed by individual peers: `batch_size_per_step`, `client_mode`, `auxiliary`,
-      `reuse_grad_buffers`, `offload_optimizer`, and `verbose`. In some cases, other options may also be tuned
+    :param run_id: a unique identifier of this training run, used as a common prefix for all DHT keys.
+      **Note:** peers with the same run_id should *generally* train the same model and use compatible configurations.
+      Some options can be safely changed by individual peers: ``batch_size_per_step``, ``client_mode``, ``auxiliary``,
+      ``reuse_grad_buffers``, ``offload_optimizer``, and ``verbose``. In some cases, other options may also be tuned
       individually by each peer, but they should be changed with caution to avoid deadlocks or convergence issues.
 
     :param target_batch_size: global batch size that must be accumulated before the swarm transitions to the next epoch
     :param batch_size_per_step: before each call to .step, user should accumulate gradients over this many samples
 
     :param optimizer: a callable(parameters) -> pytorch.optim.Optimizer or a pre-initialized PyTorch optimizer
+      **Note:** some advanced options like offload_optimizer, delay_optimizer_step, or delay_grad_averaging are not
+      supported if hivemind.optimizer is created with a pre-initialized optimizer and require optimizer factory
     :param params: parameters or param groups for the optimizer; required if optimizer is a callable(params)
-    :note: creating hivemind.Optimizer with params=model.parameters() and optimizer=lambda params: make_optim(params)
-      is required for advanced options: offload_optimizer, delay_optimizer_step and delay_grad_averaging.
-
-    :param scheduler: callable(optimizer) -> PyTorch LRScheduler or a pre-initialized PyTorch scheduler
-    :note: the learning rate scheduler will adjust learning rate based on collaboration-wide epoch, not the number of
+    :param scheduler: callable(optimizer) -> PyTorch LRScheduler or a pre-initialized PyTorch scheduler.
+      The learning rate scheduler will adjust learning rate based on global epoch, not the number of
       local calls to optimizer.step; this is required to keep different peers synchronized.
 
     :param matchmaking_time: when looking for group, wait for peers to join for up to this many seconds
@@ -87,24 +112,23 @@ class Optimizer(torch.optim.Optimizer):
     :param offload_optimizer: offload the optimizer to host memory, saving GPU memory for parameters and gradients
     :param delay_optimizer_step: run optimizer in background, apply results in future .step; requires offload_optimizer
     :param delay_grad_averaging: average gradients in background; requires offload_optimizer and delay_optimizer_step
-    :note: offload_optimizer, delay_optimizer_step and delay_grad_averaging require that the optimizer is
-      created as follows: `hivemind.Optimizer(..., optimizer=callable_optimizer_factory, params=model.parameters())`
-
     :param delay_state_averaging: if enabled (default), average parameters and extra tensors in a background thread;
       if set to False, average parameters synchronously within the corresponding hivemind.Optimizer.step call.
-    :param average_state_every: average state (parameters, chosen opt statistics) with peers every this many **epochs**
-      This reduces the communication overhead increasing, but can cause parameters to diverge if too large
-    :note: The maximal average_state_every=num_epochs depends on how often peers diverge from each other. If peers
-      hardly ever skip averaging rounds, they can average state less frequently. Network failures, lossy gradient
-      compression and local_updates cause parameters to diverge faster and requires more frequent averaging.
+      The above 3 options (offload_optimizer, delay_optimizer_step and delay_grad_averaging) require that the optimizer
+      is created with: ``hivemind.Optimizer(..., optimizer=callable_optimizer_factory, params=model.parameters())``
+
+    :param average_state_every: average state (parameters, chosen opt tensors) with peers every this many **epochs**.
+      This reduces the communication overhead increasing, but can cause parameters to diverge if too large.
+      The maximal average_state_every=num_epochs depends on how often peers diverge from each other. If peers
+      hardly ever skip averaging rounds, they can average state less frequently. In turn, network failures, lossy
+      gradient compression and local_updates cause parameters to diverge faster and requires more frequent averaging.
 
     :param use_local_updates: if enabled, peers will update parameters on each .step using local gradients;
-      if not enabled (default), accumulate gradients to target_batch_size, and then call .step with averaged gradients
-    :note: even if use_local_updates=True, learning rate scheduler will still be called once per target_batch_size.
+      if not enabled (default), accumulate gradients to target_batch_size, and then call .step with averaged gradients.
+      Even if use_local_updates=True, learning rate scheduler will still be called once per target_batch_size.
 
     :param client_mode: if True, this peer will not accept incoming connections (firewall-compatible mode)
     :param auxiliary: if True, optimizer.step will only assist other peers in averaging (for cpu-only workers)
-    :note: client_mode=True and auxiliary=True are mutually exclusive; auxiliary also requires batch_size_per_step=None
 
     :param grad_compression: compression strategy used for averaging gradients, default = no compression
     :param state_averaging_compression: compression for averaging params and state tensors, default = no compression
@@ -117,12 +141,8 @@ class Optimizer(torch.optim.Optimizer):
     :param performance_ema_alpha: moving average alpha  in ProgressTracer, TrainingStateAverager and Optimizer
     :param verbose: if True, report internal events such as accumilating gradients and running background tasks
 
-    Internally, hivemind.Optimizer consists of 4 components:
-    - DHT, a decentralized key-value storage used for coordination across the swarm
-    - GradientAverager that is responsible for aggregating gradients with peers for global steps (can be disabled)
-    - TrainingStateAverager holds parameters and optimizer/scheduler statistics, keeping them weakly synchronized
-     by averaging with peers. It can also download these variable from other peers if your peer is out of sync.
-    - ProgressTracker that uses DHT to track the global training progress: the number of steps or samples accumulated
+    :note: in a large-scale training, peers will inevitably fail and you will see error messages. hivemind.Optimizer
+      is designed to recover from such failures, but will sometimes need a minute or two to re-adjust.
 
     """
 
@@ -312,7 +332,8 @@ class Optimizer(torch.optim.Optimizer):
         grad_scaler: Optional[GradScaler] = None,
     ):
         """
-        Report accumulating gradients w.r.t. batch_size additional samples, optionally update model parameters
+        Update training progress after accumulating another local batch size. Depending on the configuration, this will
+        report progress to peers, run global or local optimizer step, average parameters or schedule background tasks.
 
         :param closure: A closure that reevaluates the model and returns the loss
         :param batch_size: optional override for batch_size_per_step from init
@@ -561,7 +582,7 @@ class Optimizer(torch.optim.Optimizer):
         self.grad_averager.notify_used_averaged_gradients()
 
     def zero_grad(self, set_to_none: bool = False):
-        """Reset gradients from model. If these gradients are reused for accumulators, raise an error."""
+        """Reset gradients from model. If reuse_grad_buffers=True, this will raise an error."""
         if self.use_gradient_averaging and self.grad_averager.reuse_grad_buffers:
             raise ValueError(
                 f"When running {self.__class__.__name__} with reuse_grad_buffers=True, user should never "
@@ -597,7 +618,11 @@ class Optimizer(torch.optim.Optimizer):
         return self.local_epoch < self.tracker.global_epoch - 1  # catch up if a peer just switched to next epoch
 
     def load_state_from_peers(self, **kwargs):
-        """Attempt to fetch the newest collaboration state from other peers"""
+        """
+        Attempt to load the newest collaboration state from other peers within the same run_id.
+
+        If successful, this will update parameters, optimizer state, local epoch and learning rate schedule in-place.
+        """
         self._finish_background_averaging()
         self.state_averager.step(wait_for_delayed_updates=True)
 

+ 6 - 1
hivemind/optim/grad_scaler.py

@@ -16,7 +16,12 @@ logger = get_logger(__name__)
 
 class GradScaler(TorchGradScaler):
     """
-    A thin wrapper over pytorch GradScaler that supports hivemind-style training with CollaborativeOptimizer, namely:
+    A wrapper over pytorch GradScaler made specifically for training hivemind.Optimizer with reuse_grad_buffers=True.
+
+    :note: if not using reuse_grad_buffers=True, one can and *should* train normally without this class, e.g. using
+      standard PyTorch AMP or Apex. This custom GradScaler is more memory-efficient, but requires custom training code.
+
+    GradScaler removes several:
     - bypass .unscale_ and .update calls in order to accumulate gradients over several steps
     - limit increasing gradient scale to only immediately after global optimizer steps
     - allow training with some or all master parameters in fp16