
Fix shape validation in GradientAverager (#481)

Max Ryabinin, 3 years ago
Commit dac8940c32

+ 4 - 4
hivemind/optim/grad_averager.py

@@ -106,11 +106,11 @@ class GradientAverager(DecentralizedAverager):
                     grad.detach().cpu().clone().share_memory_() for grad in self._grads_from_parameters()
                 )
             else:
-                if all(
-                    params_grad.size() == grad.size()
-                    for param_grad, grad in zip(self._grads_from_parameters(), averaged_grad)
+                if any(
+                    param_grad.size() != grad.size()
+                    for param_grad, grad in zip(self._grads_from_parameters(), averaged_grads)
                 ):
-                    raise ValueError("Averaged gradients doesn't have same shape as gradients from parameters")
+                    raise ValueError("Averaged gradients don't have same shape as gradients from parameters")
         super().__init__(averaged_tensors=averaged_grads, dht=dht, prefix=prefix, client_mode=client_mode, **kwargs)
 
     def _grads_from_parameters(self) -> Iterator[torch.Tensor]:
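For context, a minimal standalone sketch of the logic change above (the tensor shapes below are made up for illustration): the old condition raised the ValueError only when every shape matched, i.e. exactly in the valid case, so real mismatches passed silently; the corrected condition fires as soon as any pair of shapes differs.

    import torch

    param_grads = [torch.zeros(5, 5), torch.zeros(3)]
    averaged_grads = [torch.zeros(5, 5), torch.zeros(3, 1)]  # second shape is wrong

    # old condition: True only if *all* shapes match, so the mismatch above slips through
    all(p.size() == g.size() for p, g in zip(param_grads, averaged_grads))  # -> False, no error raised

    # fixed condition: True if *any* shape differs, so the ValueError fires as intended
    any(p.size() != g.size() for p, g in zip(param_grads, averaged_grads))  # -> True, error raised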

+ 0 - 1
hivemind/optim/optimizer.py

@@ -13,7 +13,6 @@ from hivemind.compression import CompressionBase, NoCompression
 from hivemind.dht import DHT
 from hivemind.optim.grad_averager import GradientAverager, GradientAveragerFactory
 from hivemind.optim.grad_scaler import GradScaler
-from hivemind.optim.power_sgd_averager import PowerSGDGradientAverager
 from hivemind.optim.progress_tracker import LocalTrainingProgress, ProgressTracker
 from hivemind.optim.state_averager import (
     LRSchedulerBase,

+ 1 - 2
hivemind/optim/power_sgd_averager.py

@@ -1,6 +1,5 @@
 import asyncio
 import contextlib
-import multiprocessing as mp
 from enum import Enum
 from typing import Any, Iterable, Optional, Sequence
 
@@ -9,7 +8,7 @@ import torch
 from hivemind.averaging.allreduce import AveragingMode
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.load_balancing import load_balance_peers
-from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
+from hivemind.averaging.matchmaking import MatchmakingException
 from hivemind.compression import CompressionInfo, TensorRole
 from hivemind.dht import DHT
 from hivemind.optim.grad_averager import GradientAverager

+ 1 - 1
hivemind/optim/state_averager.py

@@ -14,7 +14,7 @@ from hivemind.averaging import DecentralizedAverager
 from hivemind.averaging.control import StepControl
 from hivemind.compression import CompressionInfo, TensorRole
 from hivemind.optim.grad_scaler import GradScaler
-from hivemind.utils import DHTExpiration, PerformanceEMA, get_dht_time, get_logger, nested_flatten, nested_pack
+from hivemind.utils import DHTExpiration, PerformanceEMA, get_logger, nested_flatten, nested_pack
 
 logger = get_logger(__name__)
 

+ 22 - 0
tests/test_optimizer.py

@@ -83,6 +83,28 @@ def test_grad_averager(grad_averager_factory: GradientAveragerFactory):
     assert not torch.allclose(model2.w.grad, ref_average)
 
 
+@pytest.mark.forked
+@pytest.mark.parametrize(
+    "grad_averager_factory",
+    [GradientAverager, partial(PowerSGDGradientAverager, averager_rank=1)],
+)
+def test_grad_averager_wrong_shape(grad_averager_factory: GradientAveragerFactory):
+    parameter_shape = (5, 5)
+    model = nn.ParameterDict({"w": nn.Parameter(torch.zeros(parameter_shape))})
+    dht = hivemind.DHT(start=True)
+
+    with pytest.raises(ValueError):
+        grad_averager_factory(
+            model.parameters(),
+            dht=dht,
+            prefix="test_fail",
+            target_group_size=2,
+            reuse_grad_buffers=False,
+            start=True,
+            averaged_grads=[torch.zeros(parameter_shape + (1,))],
+        )
+
+
 @pytest.mark.forked
 @pytest.mark.parametrize(
     "offload_optimizer, reuse_tensors, sync_epoch_when_averaging",