Artem Chumachenko 3 years ago
parent
commit
2e763be1de

+ 1 - 2
docs/user/quickstart.md

@@ -58,7 +58,6 @@ opt = hivemind.Optimizer(
     batch_size_per_step=32,   # each call to opt.step adds this many samples towards the next epoch
     target_batch_size=10000,  # after peers collectively process this many samples, average weights and begin the next epoch
     optimizer=opt,            # wrap the SGD optimizer defined above
-    use_local_updates=True,   # perform optimizer steps with local gradients, average parameters in background
     matchmaking_time=3.0,     # when averaging parameters, gather peers in background for up to this many seconds
     averaging_timeout=10.0,   # give up on averaging if not successful in this many seconds
     verbose=True              # print logs incessently
@@ -111,6 +110,7 @@ from tqdm.auto import tqdm
 
 import hivemind
 
+
 # Create dataset and model, same as in the basic tutorial
 # For this basic tutorial, we download only the training set
 transform = transforms.Compose(
@@ -134,7 +134,6 @@ opt = hivemind.Optimizer(
     batch_size_per_step=32,   # each call to opt.step adds this many samples towards the next epoch
     target_batch_size=10000,  # after peers collectively process this many samples, average weights and begin the next epoch
     optimizer=opt,            # wrap the SGD optimizer defined above
-    use_local_updates=True,   # perform optimizer steps with local gradients, average parameters in background
     matchmaking_time=3.0,     # when averaging parameters, gather peers in background for up to this many seconds
     averaging_timeout=10.0,   # give up on averaging if not successful in this many seconds
     verbose=True              # print logs incessently
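
Since `use_local_updates=True` is dropped from the quickstart snippet, here is a hedged sketch of how the same fully-local behaviour would now be requested, based on the `optimizer.py` changes in this commit (`dht` and the wrapped `opt` are assumed to be set up as earlier in the quickstart, and the `run_id` value is illustrative):

```python
# Sketch only: local updates are now selected by disabling gradient averaging.
opt = hivemind.Optimizer(
    dht=dht,                  # assumed: the hivemind.DHT created earlier in the quickstart
    run_id="my_run",          # illustrative experiment name
    batch_size_per_step=32,
    target_batch_size=10000,
    optimizer=opt,            # wrap the SGD optimizer defined above
    grad_averager=None,       # replaces use_local_updates=True (no gradient averaging)
    matchmaking_time=3.0,
    averaging_timeout=10.0,
    verbose=True,
)
```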

+ 13 - 5
hivemind/optim/grad_averager.py

@@ -1,5 +1,5 @@
 import contextlib
-from typing import Iterable, Iterator, Optional, Sequence
+from typing import Any, Callable, Union, Iterable, Iterator, Optional, Sequence, Type, TypeVar
 
 import torch
 
@@ -11,6 +11,10 @@ from hivemind.utils import DHTExpiration, get_dht_time, get_logger
 logger = get_logger(__name__)
 
 
+TGradientAverager = TypeVar('TGradientAverager', bound='GradientAverager')
+GradientAveragerFactory = Callable[[Type[TGradientAverager], Any], TGradientAverager]
+
+
 class GradientAverager(DecentralizedAverager):
     """
     An auxiliary averaging class that is responsible for accumulating gradients and aggregating them with peers.
@@ -75,7 +79,7 @@ class GradientAverager(DecentralizedAverager):
         accumulate_grads_on: Optional[torch.device] = None,
         client_mode: bool = None,
         warn: bool = True,
-        grad_extra_tensors: Sequence[torch.Tensor] = (),
+        averaged_grads: Sequence[torch.Tensor] = (),
         **kwargs,
     ):
         if reuse_grad_buffers and accumulate_grads_on is not None:
@@ -96,9 +100,7 @@ class GradientAverager(DecentralizedAverager):
         self._new_averaged_grads = False
 
         with torch.no_grad():
-            if grad_extra_tensors:
-                averaged_grads = grad_extra_tensors
-            else:
+            if not averaged_grads:
                 averaged_grads = tuple(
                     grad.detach().cpu().clone().share_memory_() for grad in self._grads_from_parameters()
                 )
@@ -228,3 +230,9 @@ class GradientAverager(DecentralizedAverager):
     def notify_used_averaged_gradients(self):
         """Notify averager that the results of a previous averaging round are accounted for"""
         self._new_averaged_grads = False
+
+    @classmethod
+    def get_factory(cls, **kwargs1) -> GradientAveragerFactory:
+        def _factory(**kwargs2):
+            return cls(**kwargs1, **kwargs2)
+        return _factory
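
A minimal usage sketch of the factory pattern added here: `get_factory` binds averager-specific keyword arguments up front and returns a `GradientAveragerFactory` that `hivemind.Optimizer` can invoke later with the run-specific arguments (`dht`, `prefix`, `parameters`, and so on). The objects named in the commented call are assumed for illustration.

```python
# Illustrative sketch of the deferred-construction pattern introduced by get_factory.
from hivemind.optim.grad_averager import GradientAverager
from hivemind.optim.power_ef_averager import PowerEFGradientAverager

# Bind averager-specific arguments now; nothing is constructed yet.
plain_factory = GradientAverager.get_factory()
power_ef_factory = PowerEFGradientAverager.get_factory(averager_rank=1)

# hivemind.Optimizer._make_gradient_averager later calls the factory roughly as:
#   grad_averager = plain_factory(
#       dht=dht, prefix=f"{run_id}_grad_averager",
#       parameters=params, reuse_grad_buffers=False,
#   )
```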

+ 23 - 33
hivemind/optim/optimizer.py

@@ -11,7 +11,7 @@ import torch
 from hivemind.averaging.control import AveragingStage, StepControl
 from hivemind.compression import CompressionBase, NoCompression
 from hivemind.dht import DHT
-from hivemind.optim.grad_averager import GradientAverager
+from hivemind.optim.grad_averager import GradientAverager, GradientAveragerFactory
 from hivemind.optim.grad_scaler import GradScaler
 from hivemind.optim.power_ef_averager import PowerEFGradientAverager
 from hivemind.optim.progress_tracker import LocalTrainingProgress, ProgressTracker
@@ -35,7 +35,7 @@ class Optimizer(torch.optim.Optimizer):
 
     By default, Optimizer is configured to be exactly **equivalent to synchronous training** with target_batch_size.
     There are advanced options make training semi-asynchronous (delay_optimizer_step and delay_gradient_averaging)
-    or even fully asynchronous (use_local_updates=True).
+    or even fully asynchronous (grad_averager=None).
 
     :example: The Optimizer can be used as a drop-in replacement for a regular PyTorch Optimizer:
 
@@ -140,10 +140,6 @@ class Optimizer(torch.optim.Optimizer):
       hardly ever skip averaging rounds, they can average state less frequently. In turn, network failures, lossy
       gradient compression and local_updates cause parameters to diverge faster and requires more frequent averaging.
 
-    :param use_local_updates: if enabled, peers will update parameters on each .step using local gradients;
-      if not enabled (default), accumulate gradients to target_batch_size, and then call .step with averaged gradients.
-      Even if use_local_updates=True, learning rate scheduler will still be called once per target_batch_size.
-
     :param client_mode: if True, this peer will not accept incoming connections (firewall-compatible mode)
     :param auxiliary: if True, optimizer.step will only assist other peers in averaging (for cpu-only workers)
 
@@ -184,17 +180,16 @@ class Optimizer(torch.optim.Optimizer):
         delay_grad_averaging: bool = False,
         delay_state_averaging: bool = True,
         average_state_every: int = 1,
-        use_local_updates: bool = False,
         client_mode: bool = None,
         auxiliary: bool = False,
         grad_compression: CompressionBase = NoCompression(),
-        grad_rank_averager: Optional[str] = None,
+        grad_averager: Optional[GradientAveragerFactory] = GradientAverager.get_factory(),
+        use_ext_grad_buffer: bool = False,
         state_averaging_compression: CompressionBase = NoCompression(),
         load_state_compression: CompressionBase = NoCompression(),
         average_opt_statistics: Sequence[str] = (),
         extra_tensors: Sequence[torch.Tensor] = (),
         averager_opts: Optional[dict] = None,
-        grad_averager_opts: Optional[dict] = dict(),
         tracker_opts: Optional[dict] = None,
         performance_ema_alpha: float = 0.1,
         shutdown_timeout: float = 5,
@@ -223,10 +218,14 @@
                 "(A) hivemind.Optimizer(..., params=params, optimizer=lambda params: create_opt(params)\n"
                 "(B) hivemind.Optimizer(..., optimizer=pre_initialize_optimizer)"
             )
-        if use_local_updates:
+        if grad_averager is None:
             assert not reuse_grad_buffers, "if local_updates is True, gradients will not be accumulated"
             assert not delay_grad_averaging, "if local_updates is True, gradients will not be averaged"
 
+        params = list(params) if params is not None else optimizer.param_groups
+        if all(isinstance(p, torch.Tensor) for p in params):
+            params = (dict(params=params),)
+
         self.dht, self.run_id, self.client_mode, self.auxiliary = dht, run_id, client_mode, auxiliary
         self.batch_size_per_step, self.target_batch_size = batch_size_per_step, target_batch_size
         self.delay_state_averaging, self.average_state_every = delay_state_averaging, average_state_every
@@ -244,23 +243,19 @@ class Optimizer(torch.optim.Optimizer):
         self.tracker = self._make_progress_tracker(
             target_batch_size, performance_ema_alpha=performance_ema_alpha, **tracker_opts or {}
         )
-        if grad_rank_averager == "power_ef" and not use_local_updates:
-            assert len(extra_tensors) == 0
-            grad_extra_tensors = [
-                torch.zeros_like(param, device="cpu")
-                for param_group in optimizer.param_groups
-                for param in param_group["params"]
+        averaged_grads = None
+        if use_ext_grad_buffer:
+            assert grad_averager is not None, "Use external gradient buffers only with working gradient averager."
+            averaged_grads = [
+                torch.zeros_like(param, device="cpu").share_memory_()
+                for param_group in params for param in param_group["params"]
             ]
-            for tensor in grad_extra_tensors:
-                if tensor is not None:
-                    tensor.share_memory_()
-            grad_averager_opts["grad_extra_tensors"] = grad_extra_tensors
-            extra_tensors = [e for e in extra_tensors] + [eg for eg in grad_extra_tensors]
+            extra_tensors = [e for e in extra_tensors] + [ag for ag in averaged_grads]
         self.state_averager = self._make_state_averager(
             optimizer=optimizer,
             params=params,
             scheduler=scheduler,
-            delta_rule_averaging=use_local_updates and self.delay_state_averaging,
+            delta_rule_averaging=grad_averager is None and self.delay_state_averaging,
             compression=state_averaging_compression,
             state_compression=load_state_compression,
             average_opt_statistics=average_opt_statistics,
@@ -268,12 +263,11 @@ class Optimizer(torch.optim.Optimizer):
             extra_tensors=extra_tensors,
             **averager_opts or {},
         )
-        if not use_local_updates:
+        if grad_averager:
             self.grad_averager = self._make_gradient_averager(
                 reuse_grad_buffers=reuse_grad_buffers,
-                grad_rank_averager=grad_rank_averager,
-                compression=grad_compression,
-                **grad_averager_opts or {},
+                grad_averager=grad_averager,
+                averaged_grads=averaged_grads
             )
         else:
             self.grad_averager = None
@@ -307,13 +301,9 @@ class Optimizer(torch.optim.Optimizer):
             **kwargs,
         )
 
-    def _make_gradient_averager(self, grad_rank_averager, **kwargs) -> GradientAverager:
+    def _make_gradient_averager(self, grad_averager, **kwargs) -> GradientAverager:
         assert hasattr(self, "state_averager"), "must initialize state averager first"
-        if grad_rank_averager == "power_ef":
-            grad_averager_type = PowerEFGradientAverager
-        else:
-            grad_averager_type = GradientAverager
-        grad_averager = grad_averager_type(
+        grad_averager = grad_averager(
             dht=self.dht,
             prefix=f"{self.run_id}_grad_averager",
             parameters=self.state_averager.main_parameters,
@@ -426,7 +416,7 @@ class Optimizer(torch.optim.Optimizer):
             self._maybe_schedule_state_averaging()
 
         else:
-            # use_local_updates=True: update parameters on every step independently of other peers
+            # grad_averager=None: update parameters on every step independently of other peers
             if not self.auxiliary:
                 if grad_scaler is not None:
                     with grad_scaler.running_global_step():
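
A hedged sketch of the three configurations that replace the removed `use_local_updates` and `grad_rank_averager` flags; `dht`, `params`, and `make_opt` are assumed to exist, and the keyword values are illustrative rather than recommended settings.

```python
# Sketch only: the reworked grad_averager argument of hivemind.Optimizer.
import hivemind
from hivemind.optim.grad_averager import GradientAverager
from hivemind.optim.power_ef_averager import PowerEFGradientAverager

common = dict(dht=dht, run_id="demo", params=params, optimizer=make_opt,
              batch_size_per_step=32, target_batch_size=10000)

# 1) Plain gradient averaging: the default, matching the previous behaviour.
opt_plain = hivemind.Optimizer(grad_averager=GradientAverager.get_factory(), **common)

# 2) PowerEF low-rank gradient averaging, formerly grad_rank_averager="power_ef";
#    use_ext_grad_buffer=True allocates shared CPU buffers for the averaged gradients.
opt_power_ef = hivemind.Optimizer(
    grad_averager=PowerEFGradientAverager.get_factory(averager_rank=1),
    use_ext_grad_buffer=True,
    **common,
)

# 3) Fully asynchronous local updates, formerly use_local_updates=True.
opt_local = hivemind.Optimizer(grad_averager=None, **common)
```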

+ 21 - 21
hivemind/optim/power_ef_averager.py

@@ -58,7 +58,7 @@ class PowerEFGradientAverager(GradientAverager):
         client_mode: bool = None,
         warn: bool = True,
         min_comprasion_ratio: float = 0.5,
-        grad_extra_tensors: Sequence[torch.Tensor] = (),
+        averaged_grads: Optional[Sequence[torch.Tensor]] = None,
         **kwargs,
     ):
         self.rank = averager_rank
@@ -71,17 +71,19 @@ class PowerEFGradientAverager(GradientAverager):
                 self.rank * (grad.size(0) + np.prod(grad.size()[1:])) / np.prod(grad.size()) > 1 - min_comprasion_ratio
             )
         )
-        self._gradient_rests = list(torch.zeros_like(grad, device="cpu") for grad in self._grads_from_parameters())
+        self._gradient_residual = list(torch.zeros_like(grad, device="cpu") for grad in self._grads_from_parameters())
         self._qs = list(
             torch.rand((grad.reshape((grad.size(0), -1)).size(1), self.rank), device="cpu")
             for idx, grad in enumerate(self._grads_from_parameters())
             if idx not in self._uncompressed_gradients
         )
-        for tensor in self._qs + self._gradient_rests:
+        for tensor in self._qs + self._gradient_residual:
             if tensor is not None:
                 assert tensor.grad_fn is None, "averaged_tensors must be either parameters or leaf tensors"
                 tensor.share_memory_()
 
+        self.all_reduce_phases = (b".phase1", b".phase2")
+
         super().__init__(
             self.parameters,
             dht=dht,
@@ -90,7 +92,7 @@ class PowerEFGradientAverager(GradientAverager):
             accumulate_grads_on=accumulate_grads_on,
             client_mode=client_mode,
             warn=warn,
-            grad_extra_tensors=grad_extra_tensors,
+            averaged_grads=averaged_grads,
             **kwargs,
         )
 
@@ -98,17 +100,15 @@ class PowerEFGradientAverager(GradientAverager):
     def _register_allreduce_group(self, group_info: GroupInfo):
         """registers a given all-reduce runner to listen for incoming connections"""
         try:
-            self._running_groups[group_info.group_id + b".phase1"] = asyncio.Future()
-            self._running_groups[group_info.group_id + b".phase2"] = asyncio.Future()
+            for phase in self.all_reduce_phases:
+                self._running_groups[group_info.group_id + phase] = asyncio.Future()
             self._pending_groups_registered.set()
             yield
         finally:
-            maybe_future = self._running_groups.pop(group_info.group_id + b".phase1", None)
-            if maybe_future and not maybe_future.done():
-                logger.warning(f"All-reduce group {group_info.group_id + b'.phase1'} did not finish.")
-            maybe_future = self._running_groups.pop(group_info.group_id + b".phase2", None)
-            if maybe_future and not maybe_future.done():
-                logger.warning(f"All-reduce group {group_info.group_id + b'.phase2'} did not finish.")
+            for phase in self.all_reduce_phases:
+                maybe_future = self._running_groups.pop(group_info.group_id + phase, None)
+                if maybe_future and not maybe_future.done():
+                    logger.warning(f"All-reduce group {group_info.group_id + phase} did not finish.")
             self._pending_groups_registered.set()
 
     async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
@@ -126,7 +126,7 @@ class PowerEFGradientAverager(GradientAverager):
             )
 
             async with enter_asynchronously(self.get_tensors()) as averaged_grads:
-                cs = [rest for idx, rest in enumerate(self._gradient_rests) if idx not in self._uncompressed_gradients]
+                cs = [rest for idx, rest in enumerate(self._gradient_residual) if idx not in self._uncompressed_gradients]
                 ps = [
                     torch.zeros((grad.size(0), self.rank), device="cpu")
                     for idx, grad in enumerate(averaged_grads)
@@ -135,13 +135,13 @@ class PowerEFGradientAverager(GradientAverager):
                 for p, q, rest in zip(ps, self._qs, cs):
                     torch.matmul(rest.reshape(-1, q.size(0)), q, out=p)
                 first_all_reduced = ps + [
-                    rest for idx, rest in enumerate(self._gradient_rests) if idx in self._uncompressed_gradients
+                    rest for idx, rest in enumerate(self._gradient_residual) if idx in self._uncompressed_gradients
                 ]
                 allreduce1 = AllReduceRunner(
                     p2p=self._p2p,
                     servicer_type=type(self),
                     prefix=self.prefix,
-                    group_id=group_info.group_id + b".phase1",
+                    group_id=group_info.group_id + self.all_reduce_phases[0],
                     tensors=first_all_reduced,
                     ordered_peer_ids=group_info.peer_ids,
                     peer_fractions=peer_fractions,
@@ -149,7 +149,7 @@ class PowerEFGradientAverager(GradientAverager):
                     modes=modes,
                     **kwargs,
                 )
-                self._running_groups[group_info.group_id + b".phase1"].set_result(allreduce1)
+                self._running_groups[group_info.group_id + self.all_reduce_phases[0]].set_result(allreduce1)
 
                 if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
                     async for tensor, update in azip(as_aiter(*first_all_reduced), allreduce1):
@@ -171,7 +171,7 @@ class PowerEFGradientAverager(GradientAverager):
                     p2p=self._p2p,
                     servicer_type=type(self),
                     prefix=self.prefix,
-                    group_id=group_info.group_id + b".phase2",
+                    group_id=group_info.group_id + self.all_reduce_phases[1],
                     tensors=self._qs,
                     ordered_peer_ids=group_info.peer_ids,
                     peer_fractions=peer_fractions,
@@ -179,7 +179,7 @@ class PowerEFGradientAverager(GradientAverager):
                     modes=modes,
                     **kwargs,
                 )
-                self._running_groups[group_info.group_id + b".phase2"].set_result(allreduce2)
+                self._running_groups[group_info.group_id + self.all_reduce_phases[1]].set_result(allreduce2)
 
                 if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
                     async for tensor, update in azip(as_aiter(*self._qs), allreduce2):
@@ -196,7 +196,7 @@ class PowerEFGradientAverager(GradientAverager):
                     new_c = torch.matmul(p, q.t())
                     c.copy_(new_c.reshape(c.size()))
 
-                for rest, grad in zip(self._gradient_rests, averaged_grads):
+                for rest, grad in zip(self._gradient_residual, averaged_grads):
                     torch.add(grad, rest, out=grad)
 
                 return allreduce1.gathered
@@ -212,8 +212,8 @@ class PowerEFGradientAverager(GradientAverager):
         # divide locally accumulated gradients by the number of times they were accumulated
         grad_scale = (1.0 / self.local_times_accumulated) if self.local_times_accumulated != 0 else 0.0
         with self.get_tensors() as averaged_grads:
-            for grad_acc, averaged_grad, rest in zip(self._grad_accumulators(), averaged_grads, self._gradient_rests):
-                torch.sub(grad_acc * grad_scale, averaged_grad, out=rest)
+            for grad_acc, averaged_grad, rest in zip(self._grad_accumulators(), averaged_grads, self._gradient_residual):
+                rest.copy_(grad_acc, non_blocking=False).mul_(grad_scale).sub_(averaged_grad)
 
 
 @torch.jit.script
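
For reference, a small standalone sketch (plain tensors rather than hivemind buffers) showing that the rewritten residual update computes the same value as the old `torch.sub` form while writing directly into the shared residual buffer instead of allocating a temporary:

```python
import torch

grad_acc = torch.randn(4, 4)
averaged_grad = torch.randn(4, 4)
grad_scale = 0.5
rest = torch.empty_like(grad_acc)

# old form: builds grad_acc * grad_scale as a temporary, then subtracts into rest
torch.sub(grad_acc * grad_scale, averaged_grad, out=rest)
expected = rest.clone()

# new form from this commit: same result, computed in place in the residual buffer
rest.copy_(grad_acc, non_blocking=False).mul_(grad_scale).sub_(averaged_grad)
assert torch.allclose(rest, expected)
```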

+ 17 - 2
tests/test_optimizer.py

@@ -2,6 +2,7 @@ import ctypes
 import multiprocessing as mp
 import time
 from functools import partial
+from typing import Optional
 
 import numpy as np
 import pytest
@@ -11,7 +12,8 @@
 
 import hivemind
 from hivemind.averaging.control import AveragingStage
-from hivemind.optim.grad_averager import GradientAverager
+from hivemind.optim.grad_averager import GradientAverager, GradientAveragerFactory
+from hivemind.optim.power_ef_averager import PowerEFGradientAverager
 from hivemind.optim.optimizer import Optimizer
 from hivemind.optim.progress_tracker import ProgressTracker
 from hivemind.optim.state_averager import TrainingStateAverager
@@ -286,7 +288,15 @@ def test_progress_tracker():
 
 
 @pytest.mark.forked
+@pytest.mark.parametrize(
+    "grad_averager",
+    [
+        (GradientAverager.get_factory(),),
+        (PowerEFGradientAverager.get_factory(averager_rank=1),)
+    ],
+)
 def test_optimizer(
+    grad_averager: GradientAveragerFactory,
     num_peers: int = 1,
     num_clients: int = 0,
     target_batch_size: int = 32,
@@ -305,7 +315,11 @@
 
     def run_trainer(batch_size: int, batch_time: float, client_mode: bool):
         nonlocal optimizer
-        model = nn.Linear(5, 1)
+        model = nn.Sequential(
+            nn.Linear(5, 5),
+            nn.ReLU(),
+            nn.Linear(5, 1),
+        )
 
         assert isinstance(model, torch.nn.Module), "model_arch must evaluate to a pytorch module"
 
@@ -326,6 +340,7 @@ def test_optimizer(
             delay_optimizer_step=delay_optimizer_step,
             average_state_every=average_state_every,
             client_mode=client_mode,
+            grad_averager=GradientAverager,
             verbose=False,
         )
         optimizer.load_state_from_peers()