Browse Source

[WIP] main hivemind.Optimizer

justheuristic 3 years ago
parent
commit
f34d02df43
2 changed files with 383 additions and 8 deletions
  1. 10 8
      hivemind/optim/experimental/grad_averager.py
  2. 373 0
      hivemind/optim/experimental/optimizer.py

+ 10 - 8
hivemind/optim/experimental/grad_averager.py

@@ -207,15 +207,17 @@ class GradientAverager(DecentralizedAverager):
 
     @contextlib.contextmanager
     @torch.no_grad()
-    def use_averaged_gradients(self):
+    def use_averaged_gradients(self, replace_model_gradients: bool = True):
         self._new_averaged_grads = False
         with self.get_tensors() as averaged_grads:
+            assert len(averaged_grads) == len(self.parameters)
             try:
-                assert len(averaged_grads) == len(self.parameters)
-                old_grads = [param.grad for param in self.parameters]
-                for param, new_grad in zip(self.parameters, averaged_grads):
-                    param.grad = new_grad
-                yield
+                if replace_model_gradients:
+                    old_grads = [param.grad for param in self.parameters]
+                    for param, new_grad in zip(self.parameters, averaged_grads):
+                        param.grad = new_grad
+                yield averaged_grads
             finally:
-                for param, old_grad in zip(self.parameters, old_grads):
-                    param.grad = old_grad
+                if replace_model_gradients:
+                    for param, old_grad in zip(self.parameters, old_grads):
+                        param.grad = old_grad

+ 373 - 0
hivemind/optim/experimental/optimizer.py

@@ -0,0 +1,373 @@
+from __future__ import annotations
+
+import logging
+from typing import Optional, Sequence, Union
+
+import torch
+from torch import nn
+
+from hivemind import get_dht_time
+from hivemind.averaging.control import StepControl
+from hivemind.dht import DHT
+from hivemind.optim.experimental.grad_averager import GradientAverager
+from hivemind.optim.experimental.progress_tracker import ProgressTracker
+from hivemind.optim.experimental.state_averager import (
+    LRSchedulerBase,
+    OptimizerFactory,
+    Parameters,
+    ParamGroups,
+    SchedulerFactory,
+    TorchOptimizer,
+    TrainingStateAverager,
+)
+from hivemind.optim.grad_scaler import HivemindGradScaler
+from hivemind.utils import get_logger
+
+logger = get_logger(__name__)
+
+
+class Optimizer(torch.optim.Optimizer):
+    """
+    Hivemind Optimizer wraps your regular PyTorch Optimizer for training in a swarm of peers. It can be configured with
+     synchronous, delayed or asynchronous updates to trade between optimization guarantees and compute utilization.
+
+    The Optimizer is meant as a drop-in replacement for your regular PyTorch code:
+
+    >>> model = transformers.AutoModel("albert-xxlarge-v2")
+    >>> dht = hivemind.DHT(initial_peers=INITIAL_PEERS, start=True)
+    >>> opt = hivemind.Optimizer(model.parameters(), optim_cls=torch.optim.Adam, prefix="run_42",
+    >>>                          target_batch_size=4096, average_gradients=True, batch_size_per_step=4)
+    >>> while True:
+    >>>     loss = compute_loss_on_batch(model, batch_size=4)
+    >>>     opt.zero_grad()
+    >>>     loss.backward()
+    >>>     opt.step()  # <-- train collaboratively with any peers that use the same prefix (run_42)
+
+    However, unlike regular optimizers, calling opt.step with hivemind.Optimizer can do one of the following:
+    - accumulate a minibatch of data towards the (global) target batch size without changing parameters (yet),
+    - after accumulating the target batch size, all-reduce gradients with peers and perform optimizer step,
+    - if, for any reason, your peer lags behind the rest of the swarm, it will load state from up-to-date peers.
+
+    :note: Hivemind.Optimizer can be used the same way any other pytorch optimizer, but there is one limitation:
+      learning rate schedulers, curriculum and other time-dependent features should use opt.global_step (and not the
+      number of local forward-backward cycles). This is because any device can join midway through training, when
+      other peers have already made some progress and changed their learning rate accordingly.
+
+    TODO yozh, the doc below still needs update
+    #TODO forward timeout to state averager
+    #TODO option to offload optimizer and DPU
+
+    :param opt: a standard pytorch optimizer, preferably a large-batch one such as LAMB, LARS, etc.
+    :param dht: a running hivemind.DHT daemon connected to other peers
+    :param prefix: a common prefix for all metadata stored by CollaborativeOptimizer in the DHT
+    :param target_batch_size: perform optimizer step after all peers collectively accumulate this many samples
+    :param batch_size_per_step: before each call to .step, user should accumulate gradients over this many samples
+    :param bandwidth: peer's network bandwidth for the purpose of load balancing (recommended: internet speed in mbps)
+    :param averaging_timeout: if an averaging step hangs for this long, it will be cancelled.
+    :param load_state_timeout: wait for at most this many seconds before giving up on load_state_from_peers
+    :param scheduler: if specified, use this scheduler to update optimizer learning rate
+    :param epoch_tolerance: a peer can temporarily be delayed by this many steps without being deemed out of sync
+    :param reuse_grad_buffers: if True, use model's .grad buffers for gradient accumulation.
+      This is more memory efficient, but it requires that the user does *NOT* call model/opt zero_grad at all
+    :param client_mode: if True, runs training without incoming connections, in a firewall-compatible mode
+    :param kwargs: additional parameters forwarded to DecentralizedAverager
+    :note: If you are using CollaborativeOptimizer with lr_scheduler, it is recommended to pass this scheduler
+      explicitly into this class. Otherwise, scheduler may not be synchronized between peers.
+
+
+    Internally, hivemind.Optimizer consists of 4 components:
+    - DHT, a decentralized key-value storage used for coordination across the swarm
+    - GradientAverager that is responsible for aggregating gradients with peers for global steps (can be disabled)
+    - TrainingStateAverager holds parameters and optimizer/scheduler statistics, keeping them weakly synchronized
+     by averaging with peers. It can also download these variable from other peers if your peer is out of sync.
+    - ProgressTracker that uses DHT to track the global training progress: the number of steps or samples accumulated
+
+    """
+
+    def __init__(
+        self,
+        *,
+        dht: DHT,
+        prefix: str,
+        target_batch_size: int,
+        batch_size_per_step: Optional[int] = None,
+        optimizer: Union[TorchOptimizer, OptimizerFactory],
+        param_groups: Optional[Union[Parameters, ParamGroups]] = None,
+        scheduler: Optional[Union[LRSchedulerBase, SchedulerFactory]] = None,
+        matchmaking_time: Optional[float] = 15.0,
+        averaging_timeout: Optional[float] = 300.0,
+        load_state_timeout: float = 600.0,
+        reuse_grad_buffers: bool = False,
+        epoch_tolerance: int = 1,
+        delay_optimizer_step: bool = False,
+        client_mode: bool = None,
+        averager_opts: Optional[dict] = None,
+        tracker_opts: Optional[dict] = None,
+        verbose: bool = False,
+    ):
+        self.dht, self.prefix, self.epoch_tolerance = dht, prefix, epoch_tolerance
+        self.averaging_timeout, self.load_state_timeout = averaging_timeout, load_state_timeout
+        self.batch_size_per_step, self.target_batch_size = batch_size_per_step, target_batch_size
+        self.matchmaking_time, self.delay_optimizer_step = matchmaking_time, delay_optimizer_step
+
+        self.client_mode = client_mode if client_mode is not None else self.dht.client_mode
+        self.status_loglevel = logging.INFO if verbose else logging.DEBUG
+        self.scheduled_round: Optional[StepControl] = None
+
+        self.state_averager = self._make_state_averager(
+            optimizer=optimizer, param_groups=param_groups, scheduler=scheduler, **averager_opts or {}
+        )
+        self.grad_averager = self._make_gradient_averager(reuse_grad_buffers=reuse_grad_buffers, **averager_opts or {})
+        self.tracker = self._make_progress_tracker(target_batch_size, **tracker_opts or {})
+        self._schema_hash = self._compute_schema_hash()
+
+        self._step_supports_amp_scaling = self.grad_averager.reuse_grad_buffers
+        # note: the line above is used by pytorch AMP GradScaler to enable custom behavior needed when reusing gradient
+        # buffers over multiple steps (to avoid repeated unscaling). Without reuse_grad_buffers, this is not needed.
+
+    def _make_state_averager(self, **kwargs) -> TrainingStateAverager:
+        return TrainingStateAverager(
+            dht=self.dht,
+            prefix=f"{self.prefix}_state_averager",
+            allreduce_timeout=self.averaging_timeout,
+            status_loglevel=self.status_loglevel,
+            client_mode=self.client_mode,
+            offload_optimizer=True,
+            custom_gradients=True,
+            start=True,
+            **kwargs,
+        )
+
+    def _make_gradient_averager(self, **kwargs) -> GradientAverager:
+        assert hasattr(self, "state_averager"), "must initialize state averager first"
+        grad_averager = GradientAverager(
+            dht=self.dht,
+            prefix=f"{self.prefix}_grad_averager",
+            parameters=self.state_averager.main_parameters,
+            allreduce_timeout=self.averaging_timeout,
+            client_mode=self.client_mode,
+            start=True,
+            **kwargs,
+        )
+        optimized_param_groups = self.state_averager.optimizer.param_groups
+        optimized_parameters = [param for group in optimized_param_groups for param in group["params"]]
+        with grad_averager.get_tensors() as averaged_gradients:
+            assert len(averaged_gradients) == len(optimized_parameters)
+            for opt_param, averaged_grad in zip(optimized_parameters, averaged_gradients):
+                opt_param.grad = averaged_grad
+        return grad_averager
+
+    def _make_progress_tracker(self, target_batch_size: int, **kwargs) -> ProgressTracker:
+        return ProgressTracker(
+            dht=self.dht,
+            prefix=self.prefix,
+            target_batch_size=target_batch_size,
+            client_mode=self.client_mode,
+            status_loglevel=self.status_loglevel,
+            start=True,
+            **kwargs,
+        )
+
+    def _compute_schema_hash(self) -> int:
+        optimized_param_groups = self.state_averager.optimizer.param_groups
+        optimized_parameters = [param for group in optimized_param_groups for param in group["params"]]
+        param_shapes = tuple(tuple(param.shape) for param in optimized_parameters)
+        grad_ids = tuple(id(param.grad) for param in optimized_parameters)
+        return hash((grad_ids, param_shapes))
+
+    def is_alive(self) -> bool:
+        return self.state_averager.is_alive()
+
+    @property
+    def local_epoch(self) -> int:
+        return self.state_averager.local_epoch
+
+    @property
+    def should_load_state_from_peers(self) -> bool:
+        """If true, peer will discard local progress and attempt to download state from peers."""
+        return self.local_epoch < self.tracker.global_epoch - self.epoch_tolerance
+
+    def step(self, batch_size: Optional[int] = None, grad_scaler: Optional[HivemindGradScaler] = None, **kwargs):
+        """
+        Report accumulating gradients w.r.t. batch_size additional samples, optionally update model parameters
+
+        :param batch_size: optional override for batch_size_per_step from init
+        :param grad_scaler: if amp is enabled, this **must** be a hivemind-aware gradient scaler
+        :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
+        """
+        if grad_scaler is not None and not isinstance(grad_scaler, HivemindGradScaler):
+            raise ValueError("CollaborativeOptimizer requires a hivemind-aware gradient scaler (HivemindGradScaler).")
+        if self.batch_size_per_step is None and batch_size is None:
+            raise ValueError("Please either set batch_size_per_step parameter at init or when calling .step")
+        batch_size = batch_size if batch_size is not None else self.batch_size_per_step
+
+        if self.should_load_state_from_peers:
+            self.load_state_from_peers()
+            return
+
+        if grad_scaler is not None and not grad_scaler.are_grads_finite(self):
+            raise NotImplementedError()
+            logger.log(self.status_loglevel, "Encountered incorrect value in fp16 grads, resetting local gradients")
+            self.tracker.report_local_progress(self.local_epoch, samples_accumulated=0)
+            self.grad_averager.reset_accumulated_grads_()
+            return
+
+        self.grad_averager.accumulate_grads_(batch_size)
+        self.tracker.report_local_progress(self.local_epoch, self.grad_averager.local_samples_accumulated)
+        self.state_averager.step(apply_delayed_updates=True)
+
+        if self.tracker.estimated_next_update_time - get_dht_time() <= self.matchmaking_time:
+            if self.scheduled_round is None or self.scheduled_round.triggered or self.scheduled_round.done():
+                eta_seconds = self.tracker.estimated_next_update_time - get_dht_time()
+                eta_seconds = max(eta_seconds, self.grad_averager.matchmaking_kwargs["min_matchmaking_time"])
+                logger.log(self.status_loglevel, f"Pre-scheduling next averaging round in {eta_seconds:.2f}s.")
+                scheduled_time = self.tracker.estimated_next_update_time
+                if self.client_mode:
+                    scheduled_time = get_dht_time() + self.averaging_timeout
+                self.scheduled_round = self.grad_averager.schedule_step(scheduled_time, timeout=self.averaging_timeout)
+
+        if not self.tracker.ready_to_update_epoch:
+            return
+
+        with self.tracker.pause_updates():
+            logger.log(self.status_loglevel, f"Beginning global optimizer step #{self.tracker.global_epoch}")
+            # divide accumulators by local steps to recover the true average grad w.r.t. local_samples_accumulated
+            if grad_scaler is not None:
+                with grad_scaler.running_global_step():
+                    assert grad_scaler.unscale_(self.opt)
+
+            if self.scheduled_round is not None and self.scheduled_round.triggered or self.scheduled_round.done():
+                logger.log(self.status_loglevel, f"Discarding failed matchmaking results: {self.scheduled_round}")
+                self.scheduled_round = None
+
+            logger.info(  # TODO
+                f"BEFORE: {self.grad_averager.local_samples_accumulated}, {repr([grad.norm() / self.grad_averager.local_times_accumulated for grad in self.grad_averager._grad_accumulators()])}"
+            )
+
+            need_averaging = self.tracker.global_progress.num_peers > 1
+            if need_averaging:
+                try:
+                    group_info = self.grad_averager.step(
+                        control=self.scheduled_round, reset_accumulators=True, timeout=self.averaging_timeout
+                    )
+                    logger.log(self.status_loglevel, f"Averaged gradients with {len(group_info)} peers")
+                except BaseException as e:
+                    logger.log(self.status_loglevel, f"Averaging failed with {repr(e)}")
+
+            else:
+                if self.scheduled_round is not None:
+                    self.scheduled_round.cancel()
+                logger.log(self.status_loglevel, f"Skipped averaging: there are no other peers")
+
+            assert self._schema_hash == self._compute_schema_hash(), "parameters or gradients changed during iteration"
+            with self.grad_averager.use_averaged_gradients(replace_model_gradients=False):
+                # note: we do not need to replace because the offloaded optimizer is already using averaged grads
+
+                self.state_averager.step(
+                    increment_epoch=True,
+                    optimizer_step=True,
+                    delay_optimizer_step=self.delay_optimizer_step,
+                    grad_scaler=grad_scaler,
+                    averaging_round=need_averaging,
+                    delay_averaging=True,
+                    averaging_opts=dict(
+                        scheduled_time=get_dht_time() + self.matchmaking_time, timeout=self.averaging_timeout
+                    )
+                    if need_averaging
+                    else None,
+                )
+
+            self.grad_averager.reset_accumulated_grads_()
+            self.tracker.update_epoch(new_epoch=self.state_averager.local_epoch)
+            logger.log(self.status_loglevel, f"Optimizer step done! Beginning next epoch {self.local_epoch}.")
+
+    def step_aux(self, **kwargs):
+        """
+        Find and assist other peers in averaging without sending local gradients.
+
+        :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
+        """
+        raise NotImplementedError("Auxiliary step for hivemind.Optimizer is not implemented yet.")
+
+    def zero_grad(self, set_to_none: bool = False):
+        """Reset gradients from model. If these gradients are reused for accumulators, raise an error."""
+        if self.grad_averager.reuse_grad_buffers:
+            raise ValueError(
+                f"When running {self.__class__.__name__} with reuse_grad_buffers=True, user should never "
+                f"call zero_grad manually. Gradients will be refreshed internally."
+            )
+        for param in self.grad_averager.parameters:
+            if param.grad is None:
+                pass
+            elif set_to_none:
+                param.grad = None
+            else:
+                param.grad.zero_()
+
+    def load_state_from_peers(self, **kwargs):
+        """Attempt to fetch the newest collaboration state from other peers"""
+        if self.scheduled_round is not None and not self.scheduled_round.done():
+            self.scheduled_round.cancel()
+
+        with self.tracker.pause_updates():
+            while True:
+                try:
+                    self.state_averager.load_state_from_peers(timeout=self.load_state_timeout, **kwargs)
+                    break
+                except KeyboardInterrupt:
+                    raise
+                except BaseException as e:
+                    logger.exception(f"Failed to load state from peers: {e}, retrying ...")
+                    continue
+
+            if self.tracker.global_epoch - self.epoch_tolerance <= self.local_epoch < self.tracker.global_epoch:
+                logger.log(self.status_loglevel, f"Catching up with collaboration step {self.tracker.global_epoch}.")
+                self.state_averager.local_epoch = self.tracker.global_epoch
+
+            self.tracker.report_local_progress(local_epoch=self.local_epoch, samples_accumulated=0)
+            self.grad_averager.reset_accumulated_grads_()
+
+    def state_dict(self) -> dict:
+        state_dict = self.state_averager.optimizer.state_dict()
+        state_dict["state"]["local_epoch"] = self.local_epoch
+        return state_dict
+
+    def load_state_dict(self, state_dict: dict):
+        if "local_epoch" in state_dict["state"]:
+            self.state_averager.local_epoch = state_dict["state"].pop("local_epoch")
+        return self.state_averager.optimizer.load_state_dict(state_dict)
+
+    @property
+    def state(self):
+        return dict(self.state_averager.optimizer.state, local_epoch=self.local_epoch)
+
+    @property
+    def opt(self) -> TorchOptimizer:
+        # for compatibility with HivemindGradScaler
+        return self.state_averager.optimizer
+
+    @property
+    def param_groups(self) -> ParamGroups:
+        return self.state_averager.optimizer.param_groups
+
+    def add_param_group(self, param_group: dict) -> None:
+        raise ValueError(
+            f"{self.__class__.__name__} does not support calling add_param_group after creation."
+            f"Please provide all parameter groups at init."
+        )
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(prefix={self.prefix}, epoch={self.local_epoch})"
+
+    def shutdown(self):
+        logger.debug("Sending goodbye to peers...")
+        self.tracker.shutdown()
+        logger.debug("Shutting down averager...")
+        self.state_averager.step(wait_for_delayed_update=True)
+        self.state_averager.shutdown()
+        self.grad_averager.shutdown()
+        logger.debug(f"{self.__class__.__name__} is shut down.")
+
+    def __del__(self):
+        if self.is_alive():  # TODO check os.getpid!!!
+            self.shutdown()