style and fixes

Artem Chumachenko, 3 years ago
commit b36f5643a8

+ 1 - 1
hivemind/averaging/averager.py

@@ -549,7 +549,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     **kwargs,
                 )
                 self._running_groups[group_info.group_id].set_result(allreduce)
-                # ^--- maybe this can be extracted into a method that checks if register_... context is active.
+                # TODO maybe this can be extracted into a method that checks if register_... context is active.
 
                 if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
                     iter_results = allreduce.run()
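
The TODO above can be read, purely as a hypothetical sketch, as asking for something along these lines (the function name and the shape of running_groups are assumptions, not part of this commit; in the averager itself this would be a method consulting self._running_groups):

from concurrent.futures import Future
from typing import Dict

def is_group_registered(running_groups: Dict[bytes, Future], group_id: bytes) -> bool:
    # hypothetical helper: True if an allreduce group was registered
    # (via the register_... context) and has not been discarded yet
    return group_id in running_groups

running_groups: Dict[bytes, Future] = {b"group-id": Future()}
print(is_group_registered(running_groups, b"group-id"))  # True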

+ 7 - 3
hivemind/optim/grad_averager.py

@@ -1,16 +1,20 @@
 import contextlib
-from typing import Any, Callable, Iterable, Iterator, Optional, Sequence, Type, TypeVar, Union
+from typing import Any, Callable, Iterable, Iterator, Optional, Sequence, TypeVar
 
 import torch
 
-import hivemind
 from hivemind.averaging import DecentralizedAverager
 from hivemind.averaging.control import StepControl
+from hivemind.dht import DHT
 from hivemind.utils import DHTExpiration, get_dht_time, get_logger
 
 logger = get_logger(__name__)
 
 
+TGradientAverager = TypeVar("TGradientAverager", bound="GradientAverager")
+GradientAveragerFactory = Callable[..., TGradientAverager]
+
+
 class GradientAverager(DecentralizedAverager):
     """
     An auxiliary averaging class that is responsible for accumulating gradients and aggregating them with peers.
@@ -70,7 +74,7 @@ class GradientAverager(DecentralizedAverager):
         self,
         parameters: Iterable[torch.nn.Parameter],
         *,
-        dht: hivemind.DHT,
+        dht: DHT,
         prefix: str,
         reuse_grad_buffers: bool = False,
         accumulate_grads_on: Optional[torch.device] = None,
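
The new GradientAveragerFactory alias covers anything that returns a GradientAverager (or a subclass) when called: the class itself, or a functools.partial with some options pre-bound. A minimal sketch of the intended usage, assuming a toy model and a fresh DHT (the model, prefix and rank values are illustrative):

from functools import partial

import torch

from hivemind.dht import DHT
from hivemind.optim.grad_averager import GradientAverager, GradientAveragerFactory
from hivemind.optim.power_sgd_averager import PowerSGDGradientAverager

model = torch.nn.Linear(16, 4)  # toy model, purely illustrative
dht = DHT(start=True)

# both the class itself and a partially-applied subclass satisfy the alias
factory: GradientAveragerFactory = partial(PowerSGDGradientAverager, averager_rank=1)
averager = factory(model.parameters(), dht=dht, prefix="demo_grad_averager", start=True)
assert isinstance(averager, GradientAverager)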

+ 10 - 8
hivemind/optim/optimizer.py

@@ -11,7 +11,7 @@ import torch
 from hivemind.averaging.control import AveragingStage, StepControl
 from hivemind.compression import CompressionBase, NoCompression
 from hivemind.dht import DHT
-from hivemind.optim.grad_averager import GradientAverager
+from hivemind.optim.grad_averager import GradientAverager, GradientAveragerFactory
 from hivemind.optim.grad_scaler import GradScaler
 from hivemind.optim.power_sgd_averager import PowerSGDGradientAverager
 from hivemind.optim.progress_tracker import LocalTrainingProgress, ProgressTracker
@@ -189,7 +189,7 @@ class Optimizer(torch.optim.Optimizer):
         client_mode: bool = None,
         auxiliary: bool = False,
         grad_compression: CompressionBase = NoCompression(),
-        grad_averager: Optional[Callable[..., GradientAverager]] = GradientAverager,
+        grad_averager_factory: Optional[GradientAveragerFactory] = GradientAverager,
         state_averaging_compression: CompressionBase = NoCompression(),
         load_state_compression: CompressionBase = NoCompression(),
         average_opt_statistics: Sequence[str] = (),
@@ -228,7 +228,9 @@ class Optimizer(torch.optim.Optimizer):
         if use_local_updates:
             assert not reuse_grad_buffers, "if local_updates is True, gradients will not be accumulated"
             assert not delay_grad_averaging, "if local_updates is True, gradients will not be averaged"
-            assert grad_averager is None, "if local_updates is True, provided gradient_averager will not be used"
+            assert (
+                grad_averager_factory is None
+            ), "if local_updates is True, provided gradient_averager will not be used"
 
         self.dht, self.run_id, self.client_mode, self.auxiliary = dht, run_id, client_mode, auxiliary
         self.batch_size_per_step, self.target_batch_size = batch_size_per_step, target_batch_size
@@ -259,9 +261,9 @@ class Optimizer(torch.optim.Optimizer):
             extra_tensors=extra_tensors,
             **averager_opts or {},
         )
-        if grad_averager is not None and not use_local_updates:
+        if grad_averager_factory is not None and not use_local_updates:
             self.grad_averager = self._make_gradient_averager(
-                reuse_grad_buffers=reuse_grad_buffers, grad_averager=grad_averager
+                reuse_grad_buffers=reuse_grad_buffers, grad_averager_factory=grad_averager_factory
             )
         else:
             self.grad_averager = None
@@ -294,9 +296,9 @@ class Optimizer(torch.optim.Optimizer):
             **kwargs,
         )
 
-    def _make_gradient_averager(self, grad_averager, **kwargs) -> GradientAverager:
+    def _make_gradient_averager(self, grad_averager_factory, **kwargs) -> GradientAverager:
         assert hasattr(self, "state_averager"), "must initialize state averager first"
-        grad_averager = grad_averager(
+        grad_averager = grad_averager_factory(
             dht=self.dht,
             prefix=f"{self.run_id}_grad_averager",
             parameters=self.state_averager.main_parameters,
@@ -409,7 +411,7 @@ class Optimizer(torch.optim.Optimizer):
             self._maybe_schedule_state_averaging()
 
         else:
-            # grad_averager=None: update parameters on every step independently of other peers
+            # use_local_updates=True: update parameters on every step independently of other peers
             if not self.auxiliary:
                 if grad_scaler is not None:
                     with grad_scaler.running_global_step():
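
After the rename, callers pick a custom gradient averager through grad_averager_factory rather than grad_averager. A minimal sketch of such a caller, mirroring the parametrized test further below (the model, run id, learning rate and batch sizes are illustrative):

from functools import partial

import torch

import hivemind
from hivemind.optim.power_sgd_averager import PowerSGDGradientAverager

model = torch.nn.Linear(16, 4)  # toy model, purely illustrative
dht = hivemind.DHT(start=True)

opt = hivemind.Optimizer(
    dht=dht,
    run_id="demo_run",
    params=model.parameters(),
    optimizer=partial(torch.optim.SGD, lr=0.1),
    target_batch_size=32,
    batch_size_per_step=4,
    grad_averager_factory=partial(PowerSGDGradientAverager, averager_rank=1),
    verbose=False,
)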

+ 7 - 14
hivemind/optim/power_sgd_averager.py

@@ -80,14 +80,9 @@ class PowerSGDGradientAverager(GradientAverager):
         self._uncompressed_gradients_indexes = set(
             i
             for i, grad in enumerate(self._grads_from_parameters())
-            if len(tuple(grad.size())) <= 1
-            or (
-                1
-                - self.rank
-                * sum(get_flatten_greedy_dims(grad))
-                / (get_flatten_greedy_dims(grad)[0] * get_flatten_greedy_dims(grad)[1])
-                < min_compression_ratio
-            )  # compute how much parameters can we left via factorization
+            if grad.ndim <= 1
+            or (1 - self.rank * sum(get_flatten_greedy_dims(grad)) / grad.numel()) < min_compression_ratio
+            # estimate how many parameters we can save via factorization
         )
         self._ms = [
             torch.zeros_like(grad, device="cpu").share_memory_()
@@ -144,7 +139,6 @@ class PowerSGDGradientAverager(GradientAverager):
             )
 
             async with enter_asynchronously(self.get_tensors()) as averaged_grads:
-                # make this two pairs list for better mapping between m buffers and gradients
                 averaged_grads_via_sgd = [
                     grad for idx, grad in enumerate(averaged_grads) if idx not in self._uncompressed_gradients_indexes
                 ]
@@ -156,7 +150,7 @@ class PowerSGDGradientAverager(GradientAverager):
                     for idx, grad in enumerate(averaged_grad_via_sgd)
                 ]
                 for p, q, m in zip(ps, self._qs, self._ms):
-                    # we use reshape for all matrixes because sgd works only with 2d tensors
+                    # we use reshape for all matrices because PowerSGD works only with 2d tensors
                     torch.matmul(m.reshape(-1, q.size(0)), q, out=p)
 
                 allreduce_p_phase = AllReduceRunner(
@@ -206,13 +200,12 @@ class PowerSGDGradientAverager(GradientAverager):
                 self._running_groups[group_info.group_id + self.all_reduce_phases[1]].set_result(allreduce_q_phase)
 
                 if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
-                    async for tensor, update in azip(as_aiter(*self._qs), allreduce_q_phase):
-                        # all-reduce is performed asynchronously while iterating
+                    async for tensor, update in azip(as_aiter(*(self._qs + averaged_grad_wo_sgd)), allreduce_q_phase):
                         tensor.add_(update, alpha=self._averaging_alpha)
                         self.last_updated = get_dht_time()
                         self._state_updated.set()
                 else:
-                    async for _ in allreduce_q_phase:  # trigger all-reduce by iterating
+                    async for _ in allreduce_q_phase:
                         raise ValueError("aux peers should not receive averaged tensors")
 
                 for p, q, m, grad in zip(ps, self._qs, self._ms, averaged_grad_via_sgd):
@@ -245,7 +238,7 @@ class PowerSGDGradientAverager(GradientAverager):
         logger.info("Starting loading gradient averager buffers from peers")
 
         if len(flat_tensors) != len(self._qs):
-            logger.error("Failed to load state from peer, received parameters, extras or metadata")
+            logger.error("Failed to load state from peer, received invalid parameters, extras or metadata")
             return
 
         with torch.no_grad(), self.lock_averaged_tensors:
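
The simplified condition in the first hunk of this file keeps a gradient out of PowerSGD compression whenever rank-r factorization would not shrink it enough: a gradient flattened to shape (m, n) costs r * (m + n) values after factorization instead of m * n, so the saving is 1 - r * (m + n) / (m * n). A standalone sketch of that check, with a naive stand-in for get_flatten_greedy_dims and an assumed 0.5 threshold:

from typing import Tuple

import torch

def flatten_dims(t: torch.Tensor) -> Tuple[int, int]:
    # naive stand-in for hivemind's get_flatten_greedy_dims:
    # fold every dimension except the last one together
    return t.numel() // t.shape[-1], t.shape[-1]

def stays_uncompressed(grad: torch.Tensor, rank: int, min_compression_ratio: float = 0.5) -> bool:
    # 1-d tensors, and tensors where factorization saves too little, are averaged as-is
    if grad.ndim <= 1:
        return True
    saving = 1 - rank * sum(flatten_dims(grad)) / grad.numel()
    return saving < min_compression_ratio

grads = [torch.zeros(128), torch.zeros(256, 256), torch.zeros(4, 2)]
print([stays_uncompressed(g, rank=1) for g in grads])  # [True, False, True]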

+ 0 - 1
hivemind/utils/__init__.py

@@ -2,7 +2,6 @@ from hivemind.utils.asyncio import *
 from hivemind.utils.grpc import *
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
-from hivemind.utils.math import *
 from hivemind.utils.mpfuture import *
 from hivemind.utils.nested import *
 from hivemind.utils.networking import *

+ 4 - 4
tests/test_optimizer.py

@@ -12,7 +12,7 @@ import torch.nn.functional as F
 
 import hivemind
 from hivemind.averaging.control import AveragingStage
-from hivemind.optim.grad_averager import GradientAverager
+from hivemind.optim.grad_averager import GradientAverager, GradientAveragerFactory
 from hivemind.optim.optimizer import Optimizer
 from hivemind.optim.power_sgd_averager import PowerSGDGradientAverager
 from hivemind.optim.progress_tracker import ProgressTracker
@@ -293,11 +293,11 @@ def test_progress_tracker():
 
 @pytest.mark.forked
 @pytest.mark.parametrize(
-    "grad_averager",
+    "grad_averager_factory",
     [GradientAverager, partial(PowerSGDGradientAverager, averager_rank=1)],
 )
 def test_optimizer(
-    grad_averager: Optional[Callable[..., GradientAverager]],
+    grad_averager_factory: GradientAveragerFactory,
     num_peers: int = 1,
     num_clients: int = 0,
     target_batch_size: int = 32,
@@ -341,7 +341,7 @@ def test_optimizer(
             delay_optimizer_step=delay_optimizer_step,
             average_state_every=average_state_every,
             client_mode=client_mode,
-            grad_averager=grad_averager,
+            grad_averager_factory=grad_averager_factory,
             verbose=False,
         )
         optimizer.load_state_from_peers()