
fixes and reworks

Artem Chumachenko · 3 years ago · commit 33c2afa1c2

+ 1 - 0
docs/user/quickstart.md

@@ -58,6 +58,7 @@ opt = hivemind.Optimizer(
     batch_size_per_step=32,   # each call to opt.step adds this many samples towards the next epoch
     target_batch_size=10000,  # after peers collectively process this many samples, average weights and begin the next epoch
     optimizer=opt,            # wrap the SGD optimizer defined above
+    use_local_updates=True,   # perform optimizer steps with local gradients, average parameters in background
     matchmaking_time=3.0,     # when averaging parameters, gather peers in background for up to this many seconds
     averaging_timeout=10.0,   # give up on averaging if not successful in this many seconds
     verbose=True              # print logs incessantly
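
A minimal sketch of the loop that drives this configuration, assuming model and the rest of the quickstart setup exist; fetch_batch() is a hypothetical stand-in for your own data loading:

    import torch.nn.functional as F

    while True:
        inputs, targets = fetch_batch()  # hypothetical data helper
        loss = F.cross_entropy(model(inputs), targets)
        loss.backward()
        opt.step()       # contributes batch_size_per_step samples towards target_batch_size
        opt.zero_grad()  # reset local grads; skip this only if reuse_grad_buffers=True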

+ 2 - 2
hivemind/averaging/averager.py

@@ -509,14 +509,14 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
 
     @contextlib.contextmanager
     def _register_allreduce_group(self, group_info: GroupInfo):
-        """registers a given all-reduce runner to listen for incoming connections"""
+        """Register a given group for one or more all-reduce rounds"""
         try:
             self._running_groups[group_info.group_id] = asyncio.Future()
             self._pending_groups_registered.set()
             yield
         finally:
             maybe_future = self._running_groups.pop(group_info.group_id, None)
-            if maybe_future and not maybe_future.done():
+            if maybe_future is not None and not maybe_future.done():
                 logger.warning(f"All-reduce group {group_info.group_id} did not finish.")
             self._pending_groups_registered.set()
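
The pattern above can be illustrated standalone; a minimal sketch with simplified names (not hivemind's actual API):

    import asyncio
    import contextlib

    class GroupRegistry:
        """Tracks pending all-reduce groups; peers await the future to obtain the runner."""

        def __init__(self):
            self._running_groups = {}

        @contextlib.contextmanager
        def register(self, group_id: bytes):
            try:
                self._running_groups[group_id] = asyncio.Future()
                yield
            finally:
                future = self._running_groups.pop(group_id, None)
                # Future objects are always truthy, so the old `if maybe_future`
                # check worked only by accident; `is not None` states the intent
                if future is not None and not future.done():
                    print(f"group {group_id!r} did not finish")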
 
 

+ 1 - 11
hivemind/optim/grad_averager.py

@@ -11,10 +11,6 @@ from hivemind.utils import DHTExpiration, get_dht_time, get_logger
 logger = get_logger(__name__)


-TGradientAverager = TypeVar("TGradientAverager", bound="GradientAverager")
-GradientAveragerFactory = Callable[[Type[TGradientAverager], Any], TGradientAverager]
-
-
 class GradientAverager(DecentralizedAverager):
     """
     An auxiliary averaging class that is responsible for accumulating gradients and aggregating them with peers.
@@ -40,6 +36,7 @@ class GradientAverager(DecentralizedAverager):
       if True, the averager will only join existing groups where at least one peer has client_mode=False.
       By default, this flag is copied from DHTNode inside the ``dht`` instance.
     :param warn: if True, warn when the averager did not reset accumulators after use or did not use averaging results
+    :param averaged_grads: if provided, these tensors will be used as the buffers for averaged gradients
     :param kwargs: see DecentralizedAverager keyword arguments for additional parameters


@@ -230,10 +227,3 @@ class GradientAverager(DecentralizedAverager):
     def notify_used_averaged_gradients(self):
         """Notify averager that the results of a previous averaging round are accounted for"""
         self._new_averaged_grads = False
-
-    @classmethod
-    def get_factory(cls, **kwargs1) -> GradientAveragerFactory:
-        def _factory(**kwargs2):
-            return cls(**kwargs1, **kwargs2)
-
-        return _factory
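
A sketch of the new keyword in use, assuming a locally running DHT; the prefix and the buffer-sharing use case are illustrative, not prescribed by the library:

    import torch
    import hivemind
    from hivemind.optim.grad_averager import GradientAverager

    dht = hivemind.DHT(start=True)
    model = torch.nn.Linear(16, 4)

    # pre-allocate the tensors that will hold averaged gradients, e.g. to share
    # them with another component instead of letting the averager allocate its own
    external_buffers = [torch.zeros_like(p) for p in model.parameters()]

    averager = GradientAverager(
        model.parameters(),
        dht=dht,
        prefix="demo_grad_averager",
        averaged_grads=external_buffers,
        start=True,
    )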

+ 13 - 11
hivemind/optim/optimizer.py

@@ -11,7 +11,7 @@ import torch
 from hivemind.averaging.control import AveragingStage, StepControl
 from hivemind.compression import CompressionBase, NoCompression
 from hivemind.dht import DHT
-from hivemind.optim.grad_averager import GradientAverager, GradientAveragerFactory
+from hivemind.optim.grad_averager import GradientAverager
 from hivemind.optim.grad_scaler import GradScaler
 from hivemind.optim.power_sgd_averager import PowerSGDGradientAverager
 from hivemind.optim.progress_tracker import LocalTrainingProgress, ProgressTracker
@@ -35,7 +35,7 @@ class Optimizer(torch.optim.Optimizer):
 
 
     By default, Optimizer is configured to be exactly **equivalent to synchronous training** with target_batch_size.
     There are advanced options that make training semi-asynchronous (delay_optimizer_step and delay_gradient_averaging)
-    or even fully asynchronous (grad_averager=None).
+    or even fully asynchronous (use_local_updates=True).

     :example: The Optimizer can be used as a drop-in replacement for a regular PyTorch Optimizer:

@@ -140,10 +140,15 @@ class Optimizer(torch.optim.Optimizer):
       hardly ever skip averaging rounds, they can average state less frequently. In turn, network failures, lossy
       gradient compression and local_updates cause parameters to diverge faster and require more frequent averaging.

+    :param use_local_updates: if enabled, peers will update parameters on each .step using local gradients;
+      if not enabled (default), gradients are accumulated to target_batch_size and only then averaged and applied.
+      Even with use_local_updates=True, the learning rate scheduler is still called once per target_batch_size.
+
     :param client_mode: if True, this peer will not accept incoming connections (firewall-compatible mode)
     :param auxiliary: if True, optimizer.step will only assist other peers in averaging (for cpu-only workers)

     :param grad_compression: compression strategy used for averaging gradients, default = no compression
+    :param grad_averager: if provided, a callable that creates the gradient averager, determining the averaging strategy (e.g. PowerSGD)
     :param state_averaging_compression: compression for averaging params and state tensors, default = no compression
     :param load_state_compression: compression strategy for loading state from peers, default = no compression
     :param average_opt_statistics: names of optimizer statistics from state dict that should be averaged with peers
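
A hedged construction sketch for the new flag, mirroring the quickstart; run_id and hyperparameters here are placeholders. Note that the constructor below asserts grad_averager is None in this mode:

    import torch
    import hivemind

    dht = hivemind.DHT(start=True)
    model = torch.nn.Linear(16, 4)

    opt = hivemind.Optimizer(
        dht=dht,
        run_id="demo_local_updates",
        params=model.parameters(),
        optimizer=lambda params: torch.optim.SGD(params, lr=0.1),
        batch_size_per_step=32,
        target_batch_size=10000,
        use_local_updates=True,  # step with local gradients; average parameters in background
        grad_averager=None,      # required in this mode: the default factory would trip the assert
    )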
@@ -180,10 +185,11 @@ class Optimizer(torch.optim.Optimizer):
         delay_grad_averaging: bool = False,
         delay_state_averaging: bool = True,
         average_state_every: int = 1,
+        use_local_updates: bool = False,
         client_mode: bool = None,
         auxiliary: bool = False,
         grad_compression: CompressionBase = NoCompression(),
-        grad_averager: Optional[GradientAveragerFactory] = PowerSGDGradientAverager.get_factory(averager_rank=32),
+        grad_averager: Optional[Callable[..., GradientAverager]] = GradientAverager,
         state_averaging_compression: CompressionBase = NoCompression(),
         load_state_compression: CompressionBase = NoCompression(),
         average_opt_statistics: Sequence[str] = (),
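
With the new default, grad_averager accepts any callable that builds a GradientAverager; functools.partial takes over the role of the removed get_factory helper (a sketch, reusing dht and model from the previous example):

    from functools import partial
    from hivemind.optim.power_sgd_averager import PowerSGDGradientAverager

    opt = hivemind.Optimizer(
        dht=dht,
        run_id="demo_powersgd",
        params=model.parameters(),
        optimizer=lambda params: torch.optim.SGD(params, lr=0.1),
        batch_size_per_step=32,
        target_batch_size=10000,
        grad_averager=partial(PowerSGDGradientAverager, averager_rank=32),
    )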
@@ -219,13 +225,10 @@ class Optimizer(torch.optim.Optimizer):
                 "(A) hivemind.Optimizer(..., params=params, optimizer=lambda params: create_opt(params)\n"
                 "(A) hivemind.Optimizer(..., params=params, optimizer=lambda params: create_opt(params)\n"
                 "(B) hivemind.Optimizer(..., optimizer=pre_initialize_optimizer)"
                 "(B) hivemind.Optimizer(..., optimizer=pre_initialize_optimizer)"
             )
             )
-        if grad_averager is None:
+        if use_local_updates:
             assert not reuse_grad_buffers, "if local_updates is True, gradients will not be accumulated"
             assert not reuse_grad_buffers, "if local_updates is True, gradients will not be accumulated"
             assert not delay_grad_averaging, "if local_updates is True, gradients will not be averaged"
             assert not delay_grad_averaging, "if local_updates is True, gradients will not be averaged"
-
-        params = list(params) if params is not None else optimizer.param_groups
-        if all(isinstance(p, torch.Tensor) for p in params):
-            params = (dict(params=params),)
+            assert grad_averager is None, "if local_updates is True, provided gradient_averager will not be used"
 
 
         self.dht, self.run_id, self.client_mode, self.auxiliary = dht, run_id, client_mode, auxiliary
         self.dht, self.run_id, self.client_mode, self.auxiliary = dht, run_id, client_mode, auxiliary
         self.batch_size_per_step, self.target_batch_size = batch_size_per_step, target_batch_size
         self.batch_size_per_step, self.target_batch_size = batch_size_per_step, target_batch_size
@@ -244,7 +247,6 @@ class Optimizer(torch.optim.Optimizer):
         self.tracker = self._make_progress_tracker(
             target_batch_size, performance_ema_alpha=performance_ema_alpha, **tracker_opts or {}
         )
-        averaged_grads = None
         self.state_averager = self._make_state_averager(
             optimizer=optimizer,
             params=params,
@@ -257,9 +259,9 @@ class Optimizer(torch.optim.Optimizer):
             extra_tensors=extra_tensors,
             **averager_opts or {},
         )
-        if grad_averager:
+        if grad_averager is not None and not use_local_updates:
             self.grad_averager = self._make_gradient_averager(
-                reuse_grad_buffers=reuse_grad_buffers, grad_averager=grad_averager, averaged_grads=averaged_grads
+                reuse_grad_buffers=reuse_grad_buffers, grad_averager=grad_averager
             )
         else:
             self.grad_averager = None

+ 84 - 60
hivemind/optim/power_sgd_averager.py

@@ -8,7 +8,6 @@ from typing import Any, Iterable, Optional, Sequence
 import numpy as np
 import torch

-import hivemind
 from hivemind.averaging.allreduce import AllreduceException, AllReduceRunner, AveragingMode, GroupID
 from hivemind.averaging.control import AveragingStage, StepControl
 from hivemind.averaging.group_info import GroupInfo
@@ -37,6 +36,7 @@ from hivemind.utils.asyncio import (
     switch_to_uvloop,
 )
 from hivemind.utils.grpc import combine_from_streaming, split_for_streaming
+from hivemind.utils.math import orthogonalize_
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration, get_dht_time

@@ -47,39 +47,80 @@ logger = get_logger(__name__)
 
 
 
 
 class PowerSGDGradientAverager(GradientAverager):
+    """
+    A gradient averager that implements PowerSGD compression: https://arxiv.org/abs/1905.13727
+    For the basic properties and guarantees of gradient averagers, please refer to the base class docstring.
+    Put simply, this method approximates large gradient tensors (m, n) with a product of two
+    smaller matrices, (m, r) by (r, n), where r is a parameter chosen by the user (see averager_rank).
+
+    As a result, PowerSGD only needs to aggregate O((m + n) * r) tensors instead of O(m * n).
+    A high r, e.g. sqrt(max(m, n)), typically reduces communication by 2-8x without affecting convergence.
+    A low r, e.g. 1-8, accelerates communication further, but may converge worse depending on the task.
+
+    To maintain convergence with low r, this averager uses the error feedback strategy. Put simply,
+    if some part of the gradient is "lost in compression", it will be added to the next iteration.
+    This has two implications: (a) it needs more RAM in order to store the "feedback buffers"
+    and (b) if devices stay alive only for one step, training with small rank may converge slower.
+    This is because error feedback takes multiple steps to kick in.
+
+    Since not all gradients are matrices, PowerSGD views 3d+ tensors via tensor.flatten(1, -1).
+    If a tensor has fewer than 2 dimensions or does not compress efficiently, it will be aggregated
+    normally, i.e. without PowerSGD. See min_compression_ratio for details.
+
+    :note: due to the above rule, PowerSGD is *not* shape-invariant. For instance, a
+      matrix of shape (256, 256) will be compressed differently if you .reshape it to (32, 32, 32).
+
+    :param parameters: pytorch parameters for which to aggregate gradients
+    :param averager_rank: rank r of the approximation used to compress gradient matrices
+    :param min_compression_ratio: apply PowerSGD to a tensor only if it reduces communication
+      by at least this fraction; otherwise aggregate the tensor as is
+    :param dht: a DHT instance connected to the rest of the swarm. See hivemind.DHT docs
+    :param prefix: a unique DHT key used for matchmaking, e.g. an experiment name with optional suffixes
+    :param reuse_grad_buffers: if True, use the model's .grad buffers for accumulating gradients over multiple steps.
+      This is more memory efficient, but it requires that the user does *not* call zero_grad or clip_by_whatever at all
+    :param accumulate_grads_on: if specified, accumulate gradients on this device. By default, this will use the same
+      device as model parameters. One can specify a different device (e.g. 'cpu' vs 'cuda') to save device memory at
+      the cost of extra time per step. If reuse_grad_buffers is True, this parameter has no effect.
+    :param client_mode: if False, this averager will accept incoming requests from other peers.
+      If True, the averager will only join existing groups where at least one peer has client_mode=False.
+      By default, this flag is copied from DHTNode inside the ``dht`` instance.
+    """
     def __init__(
         self,
         parameters: Iterable[torch.nn.Parameter],
         averager_rank: int,
         *,
-        dht: hivemind.DHT,
+        dht: DHT,
         prefix: str,
         reuse_grad_buffers: bool = False,
         accumulate_grads_on: Optional[torch.device] = None,
         client_mode: bool = None,
         warn: bool = True,
-        min_comprasion_ratio: float = 0.5,
+        min_compression_ratio: float = 0.5,
         averaged_grads: Optional[Sequence[torch.Tensor]] = None,
         **kwargs,
     ):
         self.rank = averager_rank
         self.parameters = tuple(parameters)
-        self._uncompressed_gradients = set(
+        self._uncompressed_gradients_indexes = set(
             i
             for i, grad in enumerate(self._grads_from_parameters())
             if len(tuple(grad.size())) == 1
             or (
-                self.rank * (grad.size(0) + np.prod(grad.size()[1:])) / np.prod(grad.size()) > 1 - min_comprasion_ratio
-            )
+                # skip compression for tensors whose factorization would save less than min_compression_ratio of communication
+                1 - self.rank * (grad.size(0) + np.prod(grad.size()[1:])) / np.prod(grad.size()) < min_compression_ratio
+            )
         )
-        self._ms = list(torch.zeros_like(grad, device="cpu").share_memory_() for grad in self._grads_from_parameters())
-        self._qs = list(
-            torch.rand((grad.reshape((grad.size(0), -1)).size(1), self.rank), device="cpu").share_memory_()
+        self._ms = [
+            torch.zeros_like(grad, device="cpu").share_memory_()
             for idx, grad in enumerate(self._grads_from_parameters())
-            if idx not in self._uncompressed_gradients
-        )
+            if idx not in self._uncompressed_gradients_indexes
+        ]
+        self._qs = [
+            torch.rand((np.prod(grad.size()[1:]), self.rank), device="cpu").share_memory_()
+            for idx, grad in enumerate(self._grads_from_parameters())
+            if idx not in self._uncompressed_gradients_indexes
+        ]

-        self.all_reduce_phases = (b".phase1", b".phase2")
+        self.all_reduce_phases = (b".phase_p", b".phase_q")

         super().__init__(
             self.parameters,
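
The arithmetic of one compression round can be sketched for a single matrix, leaving out the networking; in the averager, the two matmuls below are separated by the .phase_p and .phase_q all-reduce rounds, and q persists across steps:

    import torch
    from hivemind.utils.math import orthogonalize_

    m, n, rank = 1024, 512, 8
    grad = torch.randn(m, n)  # this step's gradient
    M = torch.zeros(m, n)     # error-feedback accumulator (an "m" buffer)
    q = torch.rand(n, rank)   # persistent factor (a "q" buffer)

    M += grad                 # fold the fresh gradient into the accumulator
    p = M @ q                 # phase p: these (m, rank) tensors get all-reduced
    orthogonalize_(p)         # orthonormalize the columns of p in-place
    q = M.t() @ p             # phase q: these (n, rank) tensors get all-reduced
    approx = p @ q.t()        # rank-r approximation of the averaged gradient
    M -= approx               # error feedback: carry the compression loss forward

    # communication per step: rank * (m + n) = 12288 numbers instead of
    # m * n = 524288, i.e. a ratio of about 0.023, roughly 43x smaller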
@@ -123,99 +164,93 @@ class PowerSGDGradientAverager(GradientAverager):
             )

             async with enter_asynchronously(self.get_tensors()) as averaged_grads:
-                for grad, m in zip(averaged_grads, self._ms):
+                # pair each m buffer with the gradient it compresses (uncompressed gradients are handled separately)
+                averaged_grads_via_sgd = [
+                    grad for idx, grad in enumerate(averaged_grads) if idx not in self._uncompressed_gradients_indexes
+                ]
+                for grad, m in zip(averaged_grads_via_sgd, self._ms):
                     m.add_(grad.to(m.device))
 
 
-                averaged_sgd_ms = [m for idx, m in enumerate(self._ms) if idx not in self._uncompressed_gradients]
-                averaged_sgd_grad = [
-                    grad for idx, grad in enumerate(averaged_grads) if idx not in self._uncompressed_gradients
-                ]
                 ps = [
                     torch.zeros((grad.size(0), self.rank), device="cpu")
-                    for idx, grad in enumerate(averaged_grads)
-                    if idx not in self._uncompressed_gradients
+                    for grad in averaged_grads_via_sgd
                 ]
-                for p, q, m in zip(ps, self._qs, averaged_sgd_ms):
+                for p, q, m in zip(ps, self._qs, self._ms):
+                    # reshape each m to 2d: the PowerSGD factorization operates on matrices
                     torch.matmul(m.reshape(-1, q.size(0)), q, out=p)
-                first_all_reduced = ps + [m for idx, m in enumerate(self._ms) if idx in self._uncompressed_gradients]
-                allreduce1 = AllReduceRunner(
+
+                allreduce_p_phase = AllReduceRunner(
                     p2p=self._p2p,
                     servicer_type=type(self),
                     prefix=self.prefix,
                     group_id=group_info.group_id + self.all_reduce_phases[0],
-                    tensors=first_all_reduced,
+                    tensors=ps,
                     ordered_peer_ids=group_info.peer_ids,
                     peer_fractions=peer_fractions,
                     gathered=user_gathered,
                     modes=modes,
                     **kwargs,
                 )
-                self._running_groups[group_info.group_id + self.all_reduce_phases[0]].set_result(allreduce1)
+                self._running_groups[group_info.group_id + self.all_reduce_phases[0]].set_result(allreduce_p_phase)

                 if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
-                    async for tensor, update in azip(as_aiter(*first_all_reduced), allreduce1):
+                    async for tensor, update in azip(as_aiter(*ps), allreduce_p_phase):
                         # all-reduce is performed asynchronously while iterating
                         tensor.add_(update, alpha=self._averaging_alpha)
                 else:
-                    async for _ in allreduce1:  # trigger all-reduce by iterating
+                    async for _ in allreduce_p_phase:  # trigger all-reduce by iterating
                         raise ValueError("aux peers should not receive averaged tensors")
 
 
-                # orth ps
                 for p in ps:
-                    orthogonalize(p)
+                    orthogonalize_(p)

-                # compute qs
-                for p, q, m in zip(ps, self._qs, averaged_sgd_ms):
+                for p, q, m in zip(ps, self._qs, self._ms):
                     torch.matmul(m.reshape(-1, q.size(0)).t(), p, out=q)
 
 
-                allreduce2 = AllReduceRunner(
+                averaged_grad_wo_sgd = [
+                    grad for idx, grad in enumerate(averaged_grads) if idx in self._uncompressed_gradients_indexes
+                ]
+
+                allreduce_q_phase = AllReduceRunner(
                     p2p=self._p2p,
                     servicer_type=type(self),
                     prefix=self.prefix,
                     group_id=group_info.group_id + self.all_reduce_phases[1],
-                    tensors=self._qs,
+                    tensors=self._qs + averaged_grad_wo_sgd,
                     ordered_peer_ids=group_info.peer_ids,
                     peer_fractions=peer_fractions,
                     gathered=user_gathered,
                     modes=modes,
                     **kwargs,
                 )
-                self._running_groups[group_info.group_id + self.all_reduce_phases[1]].set_result(allreduce2)
+                self._running_groups[group_info.group_id + self.all_reduce_phases[1]].set_result(allreduce_q_phase)
 
 
                 if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
-                    async for tensor, update in azip(as_aiter(*self._qs), allreduce2):
+                    async for tensor, update in azip(as_aiter(*(self._qs + averaged_grad_wo_sgd)), allreduce_q_phase):
                         # all-reduce is performed asynchronously while iterating
                         tensor.add_(update, alpha=self._averaging_alpha)
                         self.last_updated = get_dht_time()
                         self._state_updated.set()
                 else:
-                    async for _ in allreduce2:  # trigger all-reduce by iterating
+                    async for _ in allreduce_q_phase:  # trigger all-reduce by iterating
                         raise ValueError("aux peers should not receive averaged tensors")
 
 
-                # recompute grads
-                for p, q, m, grad in zip(ps, self._qs, averaged_sgd_ms, averaged_sgd_grad):
-                    new_m = torch.matmul(p, q.t())
-                    m.sub_(new_m.reshape(m.size()))
-                    grad.copy_(new_m.reshape(grad.size()))
-
-                for idx, (m, grad) in enumerate(zip(self._ms, averaged_grads)):
-                    if idx in self._uncompressed_gradients:
-                        grad.copy_(m)
-                        m.data[...] = 0
+                for p, q, m, grad in zip(ps, self._qs, self._ms, averaged_grads_via_sgd):
+                    new_m = torch.matmul(p, q.t()).reshape(m.size())
+                    m.sub_(new_m)
+                    grad.copy_(new_m)
 
 
-                return allreduce1.gathered
+                return allreduce_p_phase.gathered
         except BaseException as e:
             logger.exception(e)
             raise MatchmakingException(f"Unable to run All-Reduce: {e}")
-        finally:
-            pass

     def get_current_state(self):
         with torch.no_grad(), self.lock_averaged_tensors:
-            grad_averager_buffers = list(q for q in self._qs)
+            grad_averager_buffers = [q for q in self._qs]
             grad_averager_buffers_infos = [
                 CompressionInfo.from_tensor(buffer, key=f"buffer_q_{key}", role=TensorRole.GRADIENT)
-                for buffer, key in zip(grad_averager_buffers, range(len(grad_averager_buffers)))
+                for key, buffer in enumerate(grad_averager_buffers)
             ]

         metadata = dict(group_bits=self.get_group_bits())
@@ -236,14 +271,3 @@ class PowerSGDGradientAverager(GradientAverager):
         with torch.no_grad(), self.lock_averaged_tensors:
             for local_q, loaded_q in zip(self._qs, flat_tensors):
                 local_q.copy_(loaded_q, non_blocking=True)
-
-
-@torch.jit.script
-def orthogonalize(matrix, eps=torch.tensor(1e-8)):
-    n, m = matrix.shape
-    for i in range(m):
-        col = matrix[:, i : i + 1]
-        col /= torch.sqrt(torch.sum(col**2)) + eps
-        if i + 1 < m:
-            rest = matrix[:, i + 1 :]
-            rest -= torch.sum(col * rest, dim=0) * col

+ 1 - 0
hivemind/utils/__init__.py

@@ -2,6 +2,7 @@ from hivemind.utils.asyncio import *
 from hivemind.utils.grpc import *
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from hivemind.utils.math import *
 from hivemind.utils.mpfuture import *
 from hivemind.utils.nested import *
 from hivemind.utils.networking import *

+ 14 - 0
hivemind/utils/math.py

@@ -0,0 +1,14 @@
+import torch
+import torch.nn.functional as F
+
+
+@torch.jit.script
+def orthogonalize_(matrix, eps: float = 1e-8):
+    """Orthogonalize a 2d tensor in-place over the last dimension"""
+    n, m = matrix.shape
+    for i in range(m):
+        col = matrix[:, i]
+        F.normalize(col, dim=0, eps=eps, out=col)
+        if i + 1 < m:
+            rest = matrix[:, i + 1 :]
+            rest.addmm_(col[:, None], (col @ rest)[None, :], alpha=-1)
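
A quick sanity check for the new helper (also re-exported via hivemind.utils by the __init__ change above):

    import torch
    from hivemind.utils.math import orthogonalize_

    x = torch.randn(64, 8)
    orthogonalize_(x)  # in-place modified Gram-Schmidt over the columns
    assert torch.allclose(x.t() @ x, torch.eye(8), atol=1e-4)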

+ 4 - 4
tests/test_optimizer.py

@@ -2,7 +2,7 @@ import ctypes
 import multiprocessing as mp
 import time
 from functools import partial
-from typing import Optional
+from typing import Callable, Optional

 import numpy as np
 import pytest
@@ -12,7 +12,7 @@ import torch.nn.functional as F
 
 
 import hivemind
 from hivemind.averaging.control import AveragingStage
-from hivemind.optim.grad_averager import GradientAverager, GradientAveragerFactory
+from hivemind.optim.grad_averager import GradientAverager
 from hivemind.optim.optimizer import Optimizer
 from hivemind.optim.power_sgd_averager import PowerSGDGradientAverager
 from hivemind.optim.progress_tracker import ProgressTracker
@@ -294,10 +294,10 @@ def test_progress_tracker():
 @pytest.mark.forked
 @pytest.mark.parametrize(
     "grad_averager",
-    [GradientAverager.get_factory(), PowerSGDGradientAverager.get_factory(averager_rank=1)],
+    [GradientAverager, partial(PowerSGDGradientAverager, averager_rank=1)],
 )
 def test_optimizer(
-    grad_averager: GradientAveragerFactory,
+    grad_averager: Optional[Callable[..., GradientAverager]],
     num_peers: int = 1,
     num_clients: int = 0,
     target_batch_size: int = 32,