Artem Chumachenko 3 years ago
Commit
8e20eb9cfd
3 files changed with 57 additions and 42 deletions
  1. hivemind/optim/grad_averager.py (+6 -0)
  2. hivemind/optim/power_sgd_averager.py (+33 -22)
  3. tests/test_optimizer.py (+18 -20)

+ 6 - 0
hivemind/optim/grad_averager.py

@@ -105,6 +105,12 @@ class GradientAverager(DecentralizedAverager):
                 averaged_grads = tuple(
                     grad.detach().cpu().clone().share_memory_() for grad in self._grads_from_parameters()
                 )
+            else:
+                if any(
+                    param_grad.size() != grad.size()
+                    for param_grad, grad in zip(self._grads_from_parameters(), averaged_grads)
+                ):
+                    raise ValueError("Averaged gradients don't have the same shape as gradients from parameters")
         super().__init__(averaged_tensors=averaged_grads, dht=dht, prefix=prefix, client_mode=client_mode, **kwargs)
 
     def _grads_from_parameters(self) -> Iterator[torch.Tensor]:

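For context, the new else-branch only matters when a caller passes pre-allocated gradient buffers into GradientAverager. A minimal sketch of the intended check follows; the buffer and helper names are illustrative, not part of the commit:

    import torch
    import torch.nn as nn

    # Illustrative parameters and externally allocated buffers (hypothetical names).
    params = [nn.Parameter(torch.zeros(5, 5)), nn.Parameter(torch.zeros(5))]
    good_buffers = tuple(torch.zeros_like(p).share_memory_() for p in params)  # shapes match: accepted
    bad_buffers = tuple(torch.zeros(3) for _ in params)                        # shapes differ: rejected

    def shapes_match(parameters, buffers):
        # Mirrors the validation above: every buffer must match the corresponding parameter's shape.
        return all(p.size() == b.size() for p, b in zip(parameters, buffers))

    assert shapes_match(params, good_buffers)
    assert not shapes_match(params, bad_buffers)  # GradientAverager would raise ValueError for these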
+ 33 - 22
hivemind/optim/power_sgd_averager.py

@@ -1,6 +1,7 @@
 import asyncio
 import contextlib
 import multiprocessing as mp
+from enum import Enum
 from typing import Any, Iterable, Optional, Sequence
 
 import torch
@@ -21,6 +22,11 @@ GatheredData = Any
 logger = get_logger(__name__)
 
 
+class AllReducePhases(Enum):
+    PHASE_P = 1
+    PHASE_Q = 2
+
+
 class PowerSGDGradientAverager(GradientAverager):
     """
     A gradient averager that implements PowerSGD compression: https://arxiv.org/abs/1905.13727
@@ -97,8 +103,6 @@ class PowerSGDGradientAverager(GradientAverager):
             if idx not in self._uncompressed_gradients_indexes
         ]
 
-        self.all_reduce_phases = (b".phase_p", b".phase_q")
-
         super().__init__(
             self.parameters,
             dht=dht,
@@ -107,23 +111,23 @@ class PowerSGDGradientAverager(GradientAverager):
             accumulate_grads_on=accumulate_grads_on,
             client_mode=client_mode,
             warn=warn,
-            averaged_grads=None,
+            averaged_grads=averaged_grads,
             **kwargs,
         )
 
     @contextlib.contextmanager
     def _register_allreduce_group(self, group_info: GroupInfo):
-        """registers a given all-reduce runner to listen for incoming connections"""
+        """Register a given group for one or more all-reduce rounds"""
         try:
-            for phase in self.all_reduce_phases:
-                self._running_groups[group_info.group_id + phase] = asyncio.Future()
+            for phase in list(AllReducePhases):
+                self._running_groups[group_info.group_id + phase.name.encode()] = asyncio.Future()
             self._pending_groups_registered.set()
             yield
         finally:
-            for phase in self.all_reduce_phases:
-                maybe_future = self._running_groups.pop(group_info.group_id + phase, None)
+            for phase in list(AllReducePhases):
+                maybe_future = self._running_groups.pop(group_info.group_id + phase.name.encode(), None)
                 if maybe_future and not maybe_future.done():
-                    logger.warning(f"All-reduce group {group_info.group_id + phase} did not finish.")
+                    logger.warning(f"All-reduce group {group_info.group_id + phase.name.encode()} did not finish.")
             self._pending_groups_registered.set()
 
     async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
@@ -149,17 +153,17 @@ class PowerSGDGradientAverager(GradientAverager):
 
                 ps = [
                     torch.zeros((get_flatten_greedy_dims(grad)[0], self.rank), device="cpu")
-                    for idx, grad in enumerate(averaged_grad_via_sgd)
+                    for idx, grad in enumerate(averaged_grads_via_sgd)
                 ]
                 for p, q, m in zip(ps, self._qs, self._ms):
                     # we use reshape for all matrixes because PowerSGD works only with 2d tensors
                     torch.matmul(m.reshape(-1, q.size(0)), q, out=p)
 
-                allreduce_p_phase = AllReduceRunner(
+                allreduce_phase_p = AllReduceRunner(
                     p2p=self._p2p,
                     servicer_type=type(self),
                     prefix=self.prefix,
-                    group_id=group_info.group_id + self.all_reduce_phases[0],
+                    group_id=group_info.group_id + AllReducePhases.PHASE_P.name.encode(),
                     tensors=ps,
                     ordered_peer_ids=group_info.peer_ids,
                     peer_fractions=peer_fractions,
@@ -167,14 +171,14 @@ class PowerSGDGradientAverager(GradientAverager):
                     modes=modes,
                     **kwargs,
                 )
-                self._running_groups[group_info.group_id + self.all_reduce_phases[0]].set_result(allreduce_p_phase)
+                self._running_groups[group_info.group_id + AllReducePhases.PHASE_P.name.encode()].set_result(allreduce_phase_p)
 
                 if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
-                    async for tensor, update in azip(as_aiter(*first_all_reduced), allreduce_p_phase):
+                    async for tensor, update in azip(as_aiter(*ps), allreduce_phase_p):
                         # all-reduce is performed asynchronously while iterating
                         tensor.add_(update, alpha=self._averaging_alpha)
                 else:
-                    async for _ in allreduce_p_phase:  # trigger all-reduce by iterating
+                    async for _ in allreduce_phase_p:  # trigger all-reduce by iterating
                         raise ValueError("aux peers should not receive averaged tensors")
 
                 for p in ps:
@@ -187,11 +191,11 @@ class PowerSGDGradientAverager(GradientAverager):
                     grad for idx, grad in enumerate(averaged_grads) if idx in self._uncompressed_gradients_indexes
                 ]
 
-                allreduce_q_phase = AllReduceRunner(
+                allreduce_phase_q = AllReduceRunner(
                     p2p=self._p2p,
                     servicer_type=type(self),
                     prefix=self.prefix,
-                    group_id=group_info.group_id + self.all_reduce_phases[1],
+                    group_id=group_info.group_id + AllReducePhases.PHASE_Q.name.encode(),
                     tensors=self._qs + averaged_grad_wo_sgd,
                     ordered_peer_ids=group_info.peer_ids,
                     peer_fractions=peer_fractions,
@@ -199,28 +203,31 @@ class PowerSGDGradientAverager(GradientAverager):
                     modes=modes,
                     **kwargs,
                 )
-                self._running_groups[group_info.group_id + self.all_reduce_phases[1]].set_result(allreduce_q_phase)
+                self._running_groups[group_info.group_id + AllReducePhases.PHASE_Q.name.encode()].set_result(allreduce_phase_q)
 
                 if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
-                    async for tensor, update in azip(as_aiter(*(self._qs + averaged_grad_wo_sgd)), allreduce_q_phase):
+                    async for tensor, update in azip(as_aiter(*(self._qs + averaged_grad_wo_sgd)), allreduce_phase_q):
                         tensor.add_(update, alpha=self._averaging_alpha)
                         self.last_updated = get_dht_time()
                         self._state_updated.set()
                 else:
-                    async for _ in allreduce_q_phase:
+                    async for _ in allreduce_phase_q:
                         raise ValueError("aux peers should not receive averaged tensors")
 
-                for p, q, m, grad in zip(ps, self._qs, self._ms, averaged_grad_via_sgd):
+                for p, q, m, grad in zip(ps, self._qs, self._ms, averaged_grads_via_sgd):
                     new_m = torch.matmul(p, q.t()).reshape(m.size())
                     m.sub_(new_m)
                     grad.copy_(new_m)
 
-                return allreduce1.gathered
+                return allreduce_phase_p.gathered
         except BaseException as e:
             logger.exception(e)
             raise MatchmakingException(f"Unable to run All-Reduce: {e}")
 
     def get_current_state(self):
+        """
+        Get the current gradient averager state to serve it when requested by a newbie peer.
+        """
         with torch.no_grad(), self.lock_averaged_tensors:
             grad_averager_buffers = [q for q in self._qs]
             grad_averager_buffers_infos = [
@@ -232,6 +239,10 @@ class PowerSGDGradientAverager(GradientAverager):
         return metadata, grad_averager_buffers, grad_averager_buffers_infos
 
     def load_state_from_peers(self, **kwargs):
+        """
+        Attempt to download the latest optimizer state from peers and update gradient averager buffers.
+        :returns: whether or not the averager succeeded in loading parameters
+        """
         loaded_state = super().load_state_from_peers(**kwargs)
         if loaded_state is None:
             return

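For readers unfamiliar with the two-phase scheme, here is a rough, self-contained sketch of what the enum-keyed phases do for a single compressed gradient. The group-key construction matches the diff above; the rank-1 math is a simplified illustration (no orthogonalization or peer communication), not the averager's actual implementation:

    import torch
    from enum import Enum

    class AllReducePhases(Enum):  # same members as in the diff
        PHASE_P = 1
        PHASE_Q = 2

    # Per-phase group keys are derived exactly as in _register_allreduce_group.
    group_id = b"example-group"
    phase_keys = [group_id + phase.name.encode() for phase in AllReducePhases]
    assert phase_keys == [b"example-groupPHASE_P", b"example-groupPHASE_Q"]

    # Toy rank-1 PowerSGD round for one gradient: m is the local error-feedback
    # buffer, q the persistent factor kept in self._qs.
    rank = 1
    m = torch.randn(8, 4)
    q = torch.randn(4, rank)

    p = torch.matmul(m, q)          # phase P: p would be all-reduced across peers first
    q = torch.matmul(m.t(), p)      # phase Q: the updated q would be all-reduced second
    new_m = torch.matmul(p, q.t())  # low-rank approximation of the (averaged) gradient
    m.sub_(new_m)                   # keep the approximation error for the next round
    averaged_grad = new_m           # corresponds to grad.copy_(new_m) in the diff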
+ 18 - 20
tests/test_optimizer.py

@@ -1,6 +1,7 @@
 import ctypes
 import multiprocessing as mp
+import sys
 import time
 from functools import partial
 from typing import Callable, Optional
 
@@ -21,16 +22,22 @@ from hivemind.utils.crypto import RSAPrivateKey
 
 
 @pytest.mark.forked
-def test_grad_averager():
+@pytest.mark.parametrize(
+    "grad_averager_factory",
+    [GradientAverager, partial(PowerSGDGradientAverager, averager_rank=1)],
+)
+def test_grad_averager(grad_averager_factory: GradientAveragerFactory):
+    parameter_shape = (5, 5)
+
     dht1 = hivemind.DHT(start=True)
-    model1 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(3))})
-    averager1 = GradientAverager(
+    model1 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(parameter_shape))})
+    averager1 = grad_averager_factory(
         model1.parameters(), dht=dht1, prefix="test", target_group_size=2, reuse_grad_buffers=False, start=True
     )
 
     dht2 = hivemind.DHT(start=True, initial_peers=dht1.get_visible_maddrs())
-    model2 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(3))})
-    averager2 = GradientAverager(
+    model2 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(parameter_shape))})
+    averager2 = grad_averager_factory(
         model2.parameters(), dht=dht2, prefix="test", target_group_size=2, reuse_grad_buffers=True, start=True
     )
 
@@ -40,12 +47,12 @@ def test_grad_averager():
     for i in range(10):
         time.sleep(0.1)
         if i % 3 == 0:
-            loss1 = F.mse_loss(model1.w, torch.ones(3))
+            loss1 = F.mse_loss(model1.w, torch.ones(parameter_shape))
             loss1.backward()
             averager1.accumulate_grads_(batch_size=2)  # total: 4 times * 2 samples = 8
             model1.zero_grad()
         else:
-            loss2 = F.mse_loss(model2.w, -torch.ones(3))
+            loss2 = F.mse_loss(model2.w, -torch.ones(parameter_shape))
             loss2.backward()
             averager2.accumulate_grads_(batch_size=3)  # total: 6 times * 3 samples = 18
             # note: we do not call zero grad here because reuse_grad_buffers=True
@@ -53,11 +60,11 @@ def test_grad_averager():
     assert control1.stage == control2.stage == AveragingStage.AWAITING_TRIGGER
     peer1_samples, peer1_times, peer2_samples, peer2_times = 8, 4, 18, 6
     assert averager1.local_samples_accumulated == peer1_samples and averager1.local_times_accumulated == peer1_times
-    ref_grads1 = torch.full((3,), -2 * 1 / 3 * averager1.local_times_accumulated)
+    ref_grads1 = torch.full(parameter_shape, -2 / np.prod(parameter_shape) * averager1.local_times_accumulated)
     assert torch.allclose(next(averager1._grad_accumulators()), ref_grads1)
 
     assert averager2.local_samples_accumulated == peer2_samples and averager2.local_times_accumulated == peer2_times
-    ref_grads2 = torch.full((3,), 2 * 1 / 3 * averager2.local_times_accumulated)
+    ref_grads2 = torch.full(parameter_shape, 2 / np.prod(parameter_shape) * averager2.local_times_accumulated)
     assert torch.allclose(next(averager2._grad_accumulators()), ref_grads2)
 
     averager1.step(control=control1, wait=False)
@@ -292,12 +299,7 @@ def test_progress_tracker():
 
 
 @pytest.mark.forked
-@pytest.mark.parametrize(
-    "grad_averager_factory",
-    [GradientAverager, partial(PowerSGDGradientAverager, averager_rank=1)],
-)
 def test_optimizer(
-    grad_averager_factory: GradientAveragerFactory,
     num_peers: int = 1,
     num_clients: int = 0,
     target_batch_size: int = 32,
@@ -316,11 +318,7 @@ def test_optimizer(
 
     def run_trainer(batch_size: int, batch_time: float, client_mode: bool):
         nonlocal optimizer
-        model = nn.Sequential(
-            nn.Linear(5, 5),
-            nn.ReLU(),
-            nn.Linear(5, 1),
-        )
+        model = nn.Sequential(nn.Linear(5, 1))
 
         assert isinstance(model, torch.nn.Module), "model_arch must evaluate to a pytorch module"
 
@@ -341,7 +339,7 @@ def test_optimizer(
             delay_optimizer_step=delay_optimizer_step,
             average_state_every=average_state_every,
             client_mode=client_mode,
-            grad_averager_factory=grad_averager_factory,
+            grad_averager_factory=GradientAverager,
             verbose=False,
         )
         optimizer.load_state_from_peers()
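
With the parametrization moved to test_grad_averager, the same test body now exercises both averager types through a factory callable. Roughly, the factory is used as below; this is a standalone sketch based on the fixtures above, not an addition to the test suite:

    from functools import partial

    import torch
    import torch.nn as nn

    import hivemind
    from hivemind.optim.grad_averager import GradientAverager
    from hivemind.optim.power_sgd_averager import PowerSGDGradientAverager

    # Either callable can serve as grad_averager_factory in test_grad_averager.
    for factory in (GradientAverager, partial(PowerSGDGradientAverager, averager_rank=1)):
        dht = hivemind.DHT(start=True)
        model = nn.ParameterDict({"w": nn.Parameter(torch.zeros(5, 5))})
        averager = factory(
            model.parameters(), dht=dht, prefix="test", target_group_size=2, reuse_grad_buffers=False, start=True
        )
        averager.shutdown()
        dht.shutdown()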