
Statistics averaging (#229)

* Add statistics averaging feature
* Refactor: move hivemind.client.optim to hivemind.optim
Roman Zhytar, 4 years ago
parent
commit
8c3bd93e87
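
For context: the new average_opt_statistics argument lets TrainingAverager average named entries of the optimizer state dict (e.g. Adam's exp_avg_sq) alongside parameters and gradients. A minimal usage sketch, with the peer/DHT settings borrowed from the test added in this commit (a real run needs target_group_size peers calling step() so an averaging round can form):

    import torch
    import hivemind

    dht = hivemind.DHT(start=True, endpoint='127.0.0.1:*')

    x = torch.randn(16, requires_grad=True)
    opt = torch.optim.Adam([x], lr=0.05)

    averager = hivemind.client.TrainingAverager(
        opt, average_parameters=True, average_gradients=True,
        average_opt_statistics=["exp_avg_sq"],  # optimizer state entries to average
        dht=dht, start=True, listen_on='127.0.0.1:*',
        prefix='demo-run', target_group_size=2)

    # ... compute gradients and call opt.step() as usual, then periodically:
    future = averager.step(wait=False)  # averages params, grads and exp_avg_sq with peers
    future.result()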

+ 1 - 0
hivemind/__init__.py

@@ -2,5 +2,6 @@ from hivemind.client import *
 from hivemind.dht import *
 from hivemind.server import *
 from hivemind.utils import *
+from hivemind.optim import *
 
 __version__ = '0.9.7'

+ 1 - 1
hivemind/client/__init__.py

@@ -1,4 +1,4 @@
 from hivemind.client.expert import RemoteExpert
 from hivemind.client.moe import RemoteMixtureOfExperts
 from hivemind.client.averaging import DecentralizedAverager
-from hivemind.client.optim import ParameterAveragingOptimizer, DecentralizedSGD, CollaborativeOptimizer
+from hivemind.client.averaging.training import TrainingAverager

+ 9 - 2
hivemind/client/averaging/training.py

@@ -23,6 +23,7 @@ class TrainingAverager(DecentralizedAverager):
     :param opt: a pytorch optimizer to be averaged between peers (complete with model parameters)
     :param average_parameters: whether or not to average model parameters in self.step(...)
     :param average_gradients: whether or not to average model gradients in self.step(...)
+    :param average_opt_statistics: if specified, average optimizer statistics with the corresponding names in the optimizer state dict
     :param initialize_optimizer: if True, this will run a speculative optimizer step with
       zero gradients to initialize all tensors. If False, please initialize the optimizer state manually.
     :param extra_tensors: if specified, these extra tensors will also be averaged and shared in load_state_from_peers.
@@ -30,9 +31,11 @@ class TrainingAverager(DecentralizedAverager):
     :param kwargs: any additional parameters will be forwarded to DecentralizedAverager
     """
     def __init__(self, opt: torch.optim.Optimizer, *, average_parameters: bool, average_gradients: bool,
-                 extra_tensors: Sequence[torch.Tensor] = (), initialize_optimizer: bool = True, **kwargs):
+                 average_opt_statistics: Sequence[str] = (), extra_tensors: Sequence[torch.Tensor] = (),
+                 initialize_optimizer: bool = True, **kwargs):
 
         self.opt, self.extra_tensors, self.local_step = opt, tuple(extra_tensors), 0
+        self.opt_statistics = tuple(average_opt_statistics)
         self.average_parameters, self.average_gradients = average_parameters, average_gradients
         self.lock_averager_step = Lock()
         if initialize_optimizer:
@@ -46,7 +49,7 @@ class TrainingAverager(DecentralizedAverager):
     def step(self, wait: bool = True, **kwargs):
         """ Average optimizer weights and gradients with peers. """
         if not wait:
-            return run_in_background(self.step, wait=False, **kwargs)
+            return run_in_background(self.step, wait=True, **kwargs)
 
         local_tensors = list(self.local_tensors())
         with self.lock_averager_step:
@@ -85,6 +88,10 @@ class TrainingAverager(DecentralizedAverager):
                         yield param.grad
                     elif replace_none:
                         yield torch.zeros_like(param)
+        for stats in self.opt_statistics:
+            for param_group in self.opt.param_groups:
+                for param in param_group['params']:
+                    yield self.opt.state[param][stats]
         yield from iter(self.extra_tensors)
 
     def get_current_state(self):
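
The added loop in local_tensors() yields, for every name in average_opt_statistics, the matching opt.state entry of each parameter, so those tensors join the averaging round. A small sketch of what that state looks like for Adam (illustrative only; with initialize_optimizer=True the averager performs this initialization itself via a step with zero gradients):

    import torch

    x = torch.randn(16, requires_grad=True)
    opt = torch.optim.Adam([x], lr=0.05)
    x.grad = torch.zeros_like(x)
    opt.step()  # populates opt.state[x] with 'step', 'exp_avg' and 'exp_avg_sq'

    for stats in ("exp_avg_sq",):               # mirrors the new loop above
        for param_group in opt.param_groups:
            for param in param_group['params']:
                tensor = opt.state[param][stats]  # this tensor is averaged with peers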

+ 0 - 2
hivemind/client/optim/__init__.py

@@ -1,2 +0,0 @@
-from hivemind.client.optim.simple import ParameterAveragingOptimizer, DecentralizedSGD
-from hivemind.client.optim.collaborative import CollaborativeOptimizer

+ 4 - 0
hivemind/optim/__init__.py

@@ -0,0 +1,4 @@
+from hivemind.optim.base import DecentralizedOptimizerBase
+from hivemind.optim.collaborative import CollaborativeOptimizer
+from hivemind.optim.performance_ema import PerformanceEMA
+from hivemind.optim.simple import DecentralizedSGD, ParameterAveragingOptimizer
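
With the package moved out of hivemind.client, the optimizers now import from hivemind.optim, and the new star-import in hivemind/__init__.py also re-exports them at the package root. A before/after sketch:

    # before this commit:
    # from hivemind.client.optim import CollaborativeOptimizer, DecentralizedSGD, ParameterAveragingOptimizer

    # after this commit:
    from hivemind.optim import CollaborativeOptimizer, DecentralizedSGD, ParameterAveragingOptimizer
    # or, via the new star-import in hivemind/__init__.py:
    from hivemind import CollaborativeOptimizer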

+ 0 - 0
hivemind/client/optim/base.py → hivemind/optim/base.py


+ 2 - 2
hivemind/client/optim/collaborative.py → hivemind/optim/collaborative.py

@@ -8,10 +8,10 @@ import torch
 import numpy as np
 
 from hivemind.dht import DHT
-from hivemind.client.optim.base import DecentralizedOptimizerBase
+from hivemind.optim.base import DecentralizedOptimizerBase
 from hivemind.client.averaging.training import TrainingAverager
 from hivemind.utils import get_logger, get_dht_time, ValueWithExpiration
-from hivemind.client.optim.performance_ema import PerformanceEMA
+from hivemind.optim.performance_ema import PerformanceEMA
 
 logger = get_logger(__name__)
 LRSchedulerBase = getattr(torch.optim.lr_scheduler, '_LRScheduler', None)

+ 0 - 0
hivemind/client/optim/performance_ema.py → hivemind/optim/performance_ema.py


+ 1 - 1
hivemind/client/optim/simple.py → hivemind/optim/simple.py

@@ -6,7 +6,7 @@ import torch
 
 from hivemind.dht import DHT
 from hivemind.client.averaging import DecentralizedAverager
-from hivemind.client.optim.base import DecentralizedOptimizerBase
+from hivemind.optim.base import DecentralizedOptimizerBase
 from hivemind.utils import get_logger, get_dht_time
 
 logger = get_logger(__name__)

+ 46 - 0
tests/test_averaging.py

@@ -377,3 +377,49 @@ def test_getset_bits():
                                               prefix='test_prefix', target_group_size=2)
     averager.set_group_bits('00101011101010')
     assert averager.get_group_bits() == '00101011101010'
+
+
+@pytest.mark.forked
+def test_training_averager(n_steps: int = 10, n_dims: int = 16):
+    torch.manual_seed(42)
+
+    dht = hivemind.DHT(start=True, endpoint='127.0.0.1:*')
+    common_kwargs = {'dht': dht, 'start': True, 'listen_on': '127.0.0.1:*',
+                     'prefix': 'demo-run', 'target_group_size': 2}
+
+    x1 = torch.randn(n_dims, requires_grad=True)
+    opt1 = torch.optim.Adam([x1], lr=0.05)
+    averager1 = hivemind.client.TrainingAverager(opt1, average_gradients=True, average_parameters=True,
+                                                 average_opt_statistics=["exp_avg_sq"], **common_kwargs)
+
+    x2 = torch.randn(n_dims, requires_grad=True)
+    opt2 = torch.optim.Adam([x2], lr=0.05)
+    averager2 = hivemind.client.TrainingAverager(opt2, average_gradients=True, average_parameters=True,
+                                                 average_opt_statistics=["exp_avg_sq"], **common_kwargs)
+    a = torch.ones(n_dims)
+
+    for i in range(n_steps):
+        opt1.zero_grad()
+        opt2.zero_grad()
+        (x1 - a).pow(2).sum().backward()
+        (x2 - a).pow(2).sum().backward()
+        opt1.step()
+        opt2.step()
+
+        with torch.no_grad():
+            x_avg = 0.5 * (x1 + x2)
+            grad_avg = 0.5 * (x1.grad + x2.grad)
+            stats_avg = 0.5 * (opt1.state[x1]["exp_avg_sq"] + opt2.state[x2]["exp_avg_sq"])
+
+        # we set wait=False to avoid a deadlock where averager1 acquires its lock and waits for averager2
+        f1 = averager1.step(wait=False)
+        f2 = averager2.step(wait=False)
+        f1.result()
+        f2.result()
+
+        assert torch.allclose(x1, x_avg)
+        assert torch.allclose(x2, x_avg)
+        assert torch.allclose(x1.grad, grad_avg)
+        assert torch.allclose(x2.grad, grad_avg)
+        assert torch.allclose(opt1.state[x1]["exp_avg_sq"], stats_avg)
+        assert torch.allclose(opt2.state[x2]["exp_avg_sq"], stats_avg)