Artem Chumachenko, 3 years ago
parent
commit
2e763be1de

+ 1 - 2
docs/user/quickstart.md

@@ -58,7 +58,6 @@ opt = hivemind.Optimizer(
     batch_size_per_step=32,   # each call to opt.step adds this many samples towards the next epoch
     target_batch_size=10000,  # after peers collectively process this many samples, average weights and begin the next epoch 
     optimizer=opt,            # wrap the SGD optimizer defined above
-    use_local_updates=True,   # perform optimizer steps with local gradients, average parameters in background
     matchmaking_time=3.0,     # when averaging parameters, gather peers in background for up to this many seconds
     averaging_timeout=10.0,   # give up on averaging if not successful in this many seconds
     verbose=True              # print logs incessantly
@@ -111,6 +110,7 @@ from tqdm.auto import tqdm
 
 import hivemind
 
+
 # Create dataset and model, same as in the basic tutorial
 # For this basic tutorial, we download only the training set
 transform = transforms.Compose(
@@ -134,7 +134,6 @@ opt = hivemind.Optimizer(
     batch_size_per_step=32,   # each call to opt.step adds this many samples towards the next epoch
     target_batch_size=10000,  # after peers collectively process this many samples, average weights and begin the next epoch
     optimizer=opt,            # wrap the SGD optimizer defined above
-    use_local_updates=True,   # perform optimizer steps with local gradients, average parameters in background
     matchmaking_time=3.0,     # when averaging parameters, gather peers in background for up to this many seconds
     averaging_timeout=10.0,   # give up on averaging if not successful in this many seconds
     verbose=True              # print logs incessantly

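Note: with use_local_updates gone from the snippet, the fully asynchronous mode is now selected by passing grad_averager=None (see the optimizer.py changes below). A minimal sketch of that variant, assuming the dht, wrapped SGD optimizer, and run id defined earlier in the quickstart (the run id value here is a placeholder):

    opt = hivemind.Optimizer(
        dht=dht,                  # DHT instance created earlier in the quickstart
        run_id="my_run",          # placeholder; reuse the run id from the rest of the tutorial
        batch_size_per_step=32,   # each call to opt.step adds this many samples towards the next epoch
        target_batch_size=10000,  # after peers collectively process this many samples, average weights and begin the next epoch
        optimizer=opt,            # wrap the SGD optimizer defined above
        grad_averager=None,       # fully asynchronous: apply local gradients on every step, average parameters in background
        matchmaking_time=3.0,     # when averaging parameters, gather peers in background for up to this many seconds
        averaging_timeout=10.0,   # give up on averaging if not successful in this many seconds
        verbose=True              # print logs incessantly
    )
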
+ 13 - 5
hivemind/optim/grad_averager.py

@@ -1,5 +1,5 @@
 import contextlib
-from typing import Iterable, Iterator, Optional, Sequence
+from typing import Callable, Iterable, Iterator, Optional, Sequence, TypeVar
 
 import torch
 
@@ -11,6 +11,10 @@ from hivemind.utils import DHTExpiration, get_dht_time, get_logger
 logger = get_logger(__name__)
 
 
+TGradientAverager = TypeVar("TGradientAverager", bound="GradientAverager")
+GradientAveragerFactory = Callable[..., TGradientAverager]
+
+
 class GradientAverager(DecentralizedAverager):
     """
     An auxiliary averaging class that is responsible for accumulating gradients and aggregating them with peers.
@@ -75,7 +79,7 @@ class GradientAverager(DecentralizedAverager):
         accumulate_grads_on: Optional[torch.device] = None,
         client_mode: bool = None,
         warn: bool = True,
-        grad_extra_tensors: Sequence[torch.Tensor] = (),
+        averaged_grads: Optional[Sequence[torch.Tensor]] = None,
         **kwargs,
     ):
         if reuse_grad_buffers and accumulate_grads_on is not None:
@@ -96,9 +100,7 @@ class GradientAverager(DecentralizedAverager):
         self._new_averaged_grads = False
 
         with torch.no_grad():
-            if grad_extra_tensors:
-                averaged_grads = grad_extra_tensors
-            else:
+            if not averaged_grads:
                 averaged_grads = tuple(
                     grad.detach().cpu().clone().share_memory_() for grad in self._grads_from_parameters()
                 )
@@ -228,3 +230,9 @@ class GradientAverager(DecentralizedAverager):
     def notify_used_averaged_gradients(self):
         """Notify averager that the results of a previous averaging round are accounted for"""
         self._new_averaged_grads = False
+
+    @classmethod
+    def get_factory(cls, **defaults) -> GradientAveragerFactory:
+        def _factory(**kwargs):
+            return cls(**defaults, **kwargs)
+        return _factory

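The new get_factory classmethod pre-binds averager-specific arguments and returns a GradientAveragerFactory; hivemind.Optimizer later calls the factory with dht, prefix, parameters, and the remaining constructor arguments. A short illustrative sketch (not part of the commit itself):

    from hivemind.optim.grad_averager import GradientAverager
    from hivemind.optim.power_ef_averager import PowerEFGradientAverager

    # pre-bind averager-specific options; hivemind.Optimizer supplies dht, prefix, parameters, etc.
    plain_factory = GradientAverager.get_factory()
    power_ef_factory = PowerEFGradientAverager.get_factory(averager_rank=1)

    # either factory can be passed as grad_averager= when constructing hivemind.Optimizer;
    # passing grad_averager=None skips gradient averaging entirely (purely local updates)
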
+ 25 - 35
hivemind/optim/optimizer.py

@@ -11,7 +11,7 @@ import torch
 from hivemind.averaging.control import AveragingStage, StepControl
 from hivemind.compression import CompressionBase, NoCompression
 from hivemind.dht import DHT
-from hivemind.optim.grad_averager import GradientAverager
+from hivemind.optim.grad_averager import GradientAverager, GradientAveragerFactory
 from hivemind.optim.grad_scaler import GradScaler
 from hivemind.optim.power_ef_averager import PowerEFGradientAverager
 from hivemind.optim.progress_tracker import LocalTrainingProgress, ProgressTracker
@@ -35,7 +35,7 @@ class Optimizer(torch.optim.Optimizer):
 
     By default, Optimizer is configured to be exactly **equivalent to synchronous training** with target_batch_size.
     There are advanced options to make training semi-asynchronous (delay_optimizer_step and delay_gradient_averaging)
-    or even fully asynchronous (use_local_updates=True).
+    or even fully asynchronous (grad_averager=None).
 
     :example: The Optimizer can be used as a drop-in replacement for a regular PyTorch Optimizer:
 
@@ -140,10 +140,6 @@ class Optimizer(torch.optim.Optimizer):
       hardly ever skip averaging rounds, they can average state less frequently. In turn, network failures, lossy
       gradient compression and local_updates cause parameters to diverge faster and requires more frequent averaging.
 
-    :param use_local_updates: if enabled, peers will update parameters on each .step using local gradients;
-      if not enabled (default), accumulate gradients to target_batch_size, and then call .step with averaged gradients.
-      Even if use_local_updates=True, learning rate scheduler will still be called once per target_batch_size.
-
     :param client_mode: if True, this peer will not accept incoming connections (firewall-compatible mode)
     :param auxiliary: if True, optimizer.step will only assist other peers in averaging (for cpu-only workers)
 
@@ -184,17 +180,16 @@ class Optimizer(torch.optim.Optimizer):
         delay_grad_averaging: bool = False,
         delay_state_averaging: bool = True,
         average_state_every: int = 1,
-        use_local_updates: bool = False,
         client_mode: bool = None,
         auxiliary: bool = False,
         grad_compression: CompressionBase = NoCompression(),
-        grad_rank_averager: Optional[str] = None,
+        grad_averager: Optional[GradientAveragerFactory] = GradientAverager.get_factory(),
+        use_ext_grad_buffer: bool = False,
         state_averaging_compression: CompressionBase = NoCompression(),
         load_state_compression: CompressionBase = NoCompression(),
         average_opt_statistics: Sequence[str] = (),
         extra_tensors: Sequence[torch.Tensor] = (),
         averager_opts: Optional[dict] = None,
-        grad_averager_opts: Optional[dict] = dict(),
         tracker_opts: Optional[dict] = None,
         performance_ema_alpha: float = 0.1,
         shutdown_timeout: float = 5,
@@ -223,10 +218,14 @@ class Optimizer(torch.optim.Optimizer):
                 "(A) hivemind.Optimizer(..., params=params, optimizer=lambda params: create_opt(params)\n"
                 "(B) hivemind.Optimizer(..., optimizer=pre_initialize_optimizer)"
             )
-        if use_local_updates:
-            assert not reuse_grad_buffers, "if local_updates is True, gradients will not be accumulated"
-            assert not delay_grad_averaging, "if local_updates is True, gradients will not be averaged"
+        if grad_averager is None:
+            assert not reuse_grad_buffers, "if grad_averager is None, gradients will not be accumulated"
+            assert not delay_grad_averaging, "if grad_averager is None, gradients will not be averaged"
 
+        params = list(params) if params is not None else optimizer.param_groups
+        if all(isinstance(p, torch.Tensor) for p in params):
+            params = (dict(params=params),)
+
         self.dht, self.run_id, self.client_mode, self.auxiliary = dht, run_id, client_mode, auxiliary
         self.batch_size_per_step, self.target_batch_size = batch_size_per_step, target_batch_size
         self.delay_state_averaging, self.average_state_every = delay_state_averaging, average_state_every
@@ -244,23 +243,19 @@ class Optimizer(torch.optim.Optimizer):
         self.tracker = self._make_progress_tracker(
             target_batch_size, performance_ema_alpha=performance_ema_alpha, **tracker_opts or {}
         )
-        if grad_rank_averager == "power_ef" and not use_local_updates:
-            assert len(extra_tensors) == 0
-            grad_extra_tensors = [
-                torch.zeros_like(param, device="cpu")
-                for param_group in optimizer.param_groups
-                for param in param_group["params"]
+        averaged_grads = None
+        if use_ext_grad_buffer:
+            assert grad_averager is not None, "use_ext_grad_buffer requires a gradient averager"
+            averaged_grads = [
+                torch.zeros_like(param, device="cpu").share_memory_()
+                for param_group in params for param in param_group["params"]
             ]
-            for tensor in grad_extra_tensors:
-                if tensor is not None:
-                    tensor.share_memory_()
-            grad_averager_opts["grad_extra_tensors"] = grad_extra_tensors
-            extra_tensors = [e for e in extra_tensors] + [eg for eg in grad_extra_tensors]
+            extra_tensors = list(extra_tensors) + list(averaged_grads)
         self.state_averager = self._make_state_averager(
             optimizer=optimizer,
             params=params,
             scheduler=scheduler,
-            delta_rule_averaging=use_local_updates and self.delay_state_averaging,
+            delta_rule_averaging=grad_averager is None and self.delay_state_averaging,
             compression=state_averaging_compression,
             state_compression=load_state_compression,
             average_opt_statistics=average_opt_statistics,
@@ -268,12 +263,11 @@ class Optimizer(torch.optim.Optimizer):
             extra_tensors=extra_tensors,
             **averager_opts or {},
         )
-        if not use_local_updates:
+        if grad_averager:
             self.grad_averager = self._make_gradient_averager(
                 reuse_grad_buffers=reuse_grad_buffers,
-                grad_rank_averager=grad_rank_averager,
-                compression=grad_compression,
-                **grad_averager_opts or {},
+                grad_averager=grad_averager,
+                averaged_grads=averaged_grads,
             )
         else:
             self.grad_averager = None
@@ -307,13 +301,9 @@ class Optimizer(torch.optim.Optimizer):
             **kwargs,
         )
 
-    def _make_gradient_averager(self, grad_rank_averager, **kwargs) -> GradientAverager:
+    def _make_gradient_averager(self, grad_averager, **kwargs) -> GradientAverager:
         assert hasattr(self, "state_averager"), "must initialize state averager first"
-        if grad_rank_averager == "power_ef":
-            grad_averager_type = PowerEFGradientAverager
-        else:
-            grad_averager_type = GradientAverager
-        grad_averager = grad_averager_type(
+        grad_averager = grad_averager(
             dht=self.dht,
             prefix=f"{self.run_id}_grad_averager",
             parameters=self.state_averager.main_parameters,
@@ -426,7 +416,7 @@ class Optimizer(torch.optim.Optimizer):
             self._maybe_schedule_state_averaging()
 
         else:
-            # use_local_updates=True: update parameters on every step independently of other peers
+            # grad_averager=None: update parameters on every step independently of other peers
             if not self.auxiliary:
                 if grad_scaler is not None:
                     with grad_scaler.running_global_step():

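Taken together, the optimizer.py changes replace grad_rank_averager, grad_averager_opts, and use_local_updates with a single grad_averager factory plus the use_ext_grad_buffer flag. A hedged end-to-end sketch of a PowerEF-compressed setup (the run id, learning rate, and toy model are illustrative placeholders; the model mirrors the one used in the updated test):

    import torch
    import hivemind
    from hivemind.optim.power_ef_averager import PowerEFGradientAverager

    dht = hivemind.DHT(start=True)    # standalone peer, just for this sketch
    model = torch.nn.Sequential(torch.nn.Linear(5, 5), torch.nn.ReLU(), torch.nn.Linear(5, 1))

    opt = hivemind.Optimizer(
        dht=dht,
        run_id="power_ef_demo",       # placeholder run id
        params=model.parameters(),
        optimizer=lambda params: torch.optim.SGD(params, lr=0.01),
        batch_size_per_step=32,
        target_batch_size=10000,
        grad_averager=PowerEFGradientAverager.get_factory(averager_rank=1),
        use_ext_grad_buffer=True,     # allocate shared CPU gradient buffers, also exposed via extra_tensors
    )
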
+ 21 - 21
hivemind/optim/power_ef_averager.py

@@ -58,7 +58,7 @@ class PowerEFGradientAverager(GradientAverager):
         client_mode: bool = None,
         warn: bool = True,
         min_comprasion_ratio: float = 0.5,
-        grad_extra_tensors: Sequence[torch.Tensor] = (),
+        averaged_grads: Optional[Sequence[torch.Tensor]] = None,
         **kwargs,
     ):
         self.rank = averager_rank
@@ -71,17 +71,19 @@ class PowerEFGradientAverager(GradientAverager):
                 self.rank * (grad.size(0) + np.prod(grad.size()[1:])) / np.prod(grad.size()) > 1 - min_comprasion_ratio
             )
         )
-        self._gradient_rests = list(torch.zeros_like(grad, device="cpu") for grad in self._grads_from_parameters())
+        self._gradient_residual = list(torch.zeros_like(grad, device="cpu") for grad in self._grads_from_parameters())
         self._qs = list(
             torch.rand((grad.reshape((grad.size(0), -1)).size(1), self.rank), device="cpu")
             for idx, grad in enumerate(self._grads_from_parameters())
             if idx not in self._uncompressed_gradients
         )
-        for tensor in self._qs + self._gradient_rests:
+        for tensor in self._qs + self._gradient_residual:
             if tensor is not None:
                 assert tensor.grad_fn is None, "averaged_tensors must be either parameters or leaf tensors"
                 tensor.share_memory_()
 
+        self.all_reduce_phases = (b".phase1", b".phase2")
+
         super().__init__(
             self.parameters,
             dht=dht,
@@ -90,7 +92,7 @@ class PowerEFGradientAverager(GradientAverager):
             accumulate_grads_on=accumulate_grads_on,
             client_mode=client_mode,
             warn=warn,
-            grad_extra_tensors=grad_extra_tensors,
+            averaged_grads=averaged_grads,
             **kwargs,
         )
 
@@ -98,17 +100,15 @@ class PowerEFGradientAverager(GradientAverager):
     def _register_allreduce_group(self, group_info: GroupInfo):
         """registers a given all-reduce runner to listen for incoming connections"""
         try:
-            self._running_groups[group_info.group_id + b".phase1"] = asyncio.Future()
-            self._running_groups[group_info.group_id + b".phase2"] = asyncio.Future()
+            for phase in self.all_reduce_phases:
+                self._running_groups[group_info.group_id + phase] = asyncio.Future()
             self._pending_groups_registered.set()
             yield
         finally:
-            maybe_future = self._running_groups.pop(group_info.group_id + b".phase1", None)
-            if maybe_future and not maybe_future.done():
-                logger.warning(f"All-reduce group {group_info.group_id + b'.phase1'} did not finish.")
-            maybe_future = self._running_groups.pop(group_info.group_id + b".phase2", None)
-            if maybe_future and not maybe_future.done():
-                logger.warning(f"All-reduce group {group_info.group_id + b'.phase2'} did not finish.")
+            for phase in self.all_reduce_phases:
+                maybe_future = self._running_groups.pop(group_info.group_id + phase, None)
+                if maybe_future and not maybe_future.done():
+                    logger.warning(f"All-reduce group {group_info.group_id + phase} did not finish.")
             self._pending_groups_registered.set()
 
     async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
@@ -126,7 +126,7 @@ class PowerEFGradientAverager(GradientAverager):
             )
 
             async with enter_asynchronously(self.get_tensors()) as averaged_grads:
-                cs = [rest for idx, rest in enumerate(self._gradient_rests) if idx not in self._uncompressed_gradients]
+                cs = [rest for idx, rest in enumerate(self._gradient_residual) if idx not in self._uncompressed_gradients]
                 ps = [
                     torch.zeros((grad.size(0), self.rank), device="cpu")
                     for idx, grad in enumerate(averaged_grads)
@@ -135,13 +135,13 @@ class PowerEFGradientAverager(GradientAverager):
                 for p, q, rest in zip(ps, self._qs, cs):
                     torch.matmul(rest.reshape(-1, q.size(0)), q, out=p)
                 first_all_reduced = ps + [
-                    rest for idx, rest in enumerate(self._gradient_rests) if idx in self._uncompressed_gradients
+                    rest for idx, rest in enumerate(self._gradient_residual) if idx in self._uncompressed_gradients
                 ]
                 allreduce1 = AllReduceRunner(
                     p2p=self._p2p,
                     servicer_type=type(self),
                     prefix=self.prefix,
-                    group_id=group_info.group_id + b".phase1",
+                    group_id=group_info.group_id + self.all_reduce_phases[0],
                     tensors=first_all_reduced,
                     ordered_peer_ids=group_info.peer_ids,
                     peer_fractions=peer_fractions,
@@ -149,7 +149,7 @@ class PowerEFGradientAverager(GradientAverager):
                     modes=modes,
                     **kwargs,
                 )
-                self._running_groups[group_info.group_id + b".phase1"].set_result(allreduce1)
+                self._running_groups[group_info.group_id + self.all_reduce_phases[0]].set_result(allreduce1)
 
                 if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
                     async for tensor, update in azip(as_aiter(*first_all_reduced), allreduce1):
@@ -171,7 +171,7 @@ class PowerEFGradientAverager(GradientAverager):
                     p2p=self._p2p,
                     servicer_type=type(self),
                     prefix=self.prefix,
-                    group_id=group_info.group_id + b".phase2",
+                    group_id=group_info.group_id + self.all_reduce_phases[1],
                     tensors=self._qs,
                     ordered_peer_ids=group_info.peer_ids,
                     peer_fractions=peer_fractions,
@@ -179,7 +179,7 @@ class PowerEFGradientAverager(GradientAverager):
                     modes=modes,
                     **kwargs,
                 )
-                self._running_groups[group_info.group_id + b".phase2"].set_result(allreduce2)
+                self._running_groups[group_info.group_id + self.all_reduce_phases[1]].set_result(allreduce2)
 
                 if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
                     async for tensor, update in azip(as_aiter(*self._qs), allreduce2):
@@ -196,7 +196,7 @@ class PowerEFGradientAverager(GradientAverager):
                     new_c = torch.matmul(p, q.t())
                     c.copy_(new_c.reshape(c.size()))
 
-                for rest, grad in zip(self._gradient_rests, averaged_grads):
+                for rest, grad in zip(self._gradient_residual, averaged_grads):
                     torch.add(grad, rest, out=grad)
 
                 return allreduce1.gathered
@@ -212,8 +212,8 @@ class PowerEFGradientAverager(GradientAverager):
         # divide locally accumulated gradients by the number of times they were accumulated
         grad_scale = (1.0 / self.local_times_accumulated) if self.local_times_accumulated != 0 else 0.0
         with self.get_tensors() as averaged_grads:
-            for grad_acc, averaged_grad, rest in zip(self._grad_accumulators(), averaged_grads, self._gradient_rests):
-                torch.sub(grad_acc * grad_scale, averaged_grad, out=rest)
+            for grad_acc, averaged_grad, rest in zip(self._grad_accumulators(), averaged_grads, self._gradient_residual):
+                rest.copy_(grad_acc).mul_(grad_scale).sub_(averaged_grad)
 
 
 @torch.jit.script

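For orientation, the two all-reduce phases implement a PowerSGD-style rank-r factorization of each compressible residual M: phase 1 all-reduces P = M·Q alongside the residuals that are cheaper to send uncompressed, phase 2 all-reduces the refreshed Q, and the product P·Qᵀ becomes the correction added to the averaged gradients. A rough standalone torch sketch of the factorization itself (illustrative only; the step that refreshes Q between the phases is not shown in this diff and is assumed here to follow the usual PowerSGD update):

    import torch

    rank = 1
    residual = torch.randn(8, 16)                  # stand-in for one compressible gradient residual
    m = residual.reshape(residual.size(0), -1)     # flatten trailing dimensions, as the averager does
    q = torch.rand(m.size(1), rank)                # right factor, same shape convention as self._qs

    p = m @ q                                      # phase 1 payload: P = M Q (all-reduced across peers)
    q = m.t() @ p                                  # assumed refresh of Q before phase 2 (PowerSGD-style)
    approx = (p @ q.t()).reshape(residual.size())  # rank-r reconstruction P Qᵀ used as the correction
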
+ 17 - 2
tests/test_optimizer.py

@@ -2,6 +2,7 @@ import ctypes
 import multiprocessing as mp
 import time
 from functools import partial
+from typing import Optional
 
 import numpy as np
 import pytest
@@ -11,7 +12,8 @@ import torch.nn.functional as F
 
 import hivemind
 from hivemind.averaging.control import AveragingStage
-from hivemind.optim.grad_averager import GradientAverager
+from hivemind.optim.grad_averager import GradientAverager, GradientAveragerFactory
+from hivemind.optim.power_ef_averager import PowerEFGradientAverager
 from hivemind.optim.optimizer import Optimizer
 from hivemind.optim.progress_tracker import ProgressTracker
 from hivemind.optim.state_averager import TrainingStateAverager
@@ -286,7 +288,15 @@ def test_progress_tracker():
 
 
 @pytest.mark.forked
+@pytest.mark.parametrize(
+    "grad_averager",
+    [
+        GradientAverager.get_factory(),
+        PowerEFGradientAverager.get_factory(averager_rank=1),
+    ],
+)
 def test_optimizer(
+    grad_averager: GradientAveragerFactory,
     num_peers: int = 1,
     num_clients: int = 0,
     target_batch_size: int = 32,
@@ -305,7 +315,11 @@ def test_optimizer(
 
     def run_trainer(batch_size: int, batch_time: float, client_mode: bool):
         nonlocal optimizer
-        model = nn.Linear(5, 1)
+        model = nn.Sequential(
+            nn.Linear(5, 5),
+            nn.ReLU(),
+            nn.Linear(5, 1),
+        )
 
         assert isinstance(model, torch.nn.Module), "model_arch must evaluate to a pytorch module"
 
@@ -326,6 +340,7 @@ def test_optimizer(
             delay_optimizer_step=delay_optimizer_step,
             average_state_every=average_state_every,
             client_mode=client_mode,
+            grad_averager=grad_averager,
             verbose=False,
         )
         optimizer.load_state_from_peers()
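
The parametrization runs the same test once per averager factory. Assuming pytest and pytest-forked are installed (the forked marker requires that plugin), one way to run just this test from Python:

    import pytest

    # run only the optimizer test(s) matching the name filter, verbosely; one entry per grad_averager variant
    raise SystemExit(pytest.main(["-v", "-k", "test_optimizer", "tests/test_optimizer.py"]))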