
Statistics averaging (#229)

* Add statistics averaging feature
* Refactor: move hivemind.client.optim to hivemind.optim
Roman Zhytar 4 years ago
parent
commit 8c3bd93e87

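A minimal usage sketch of the new average_opt_statistics argument, modelled on the test added below in tests/test_averaging.py; the endpoint, prefix, and group size are taken from that test, and it assumes a second peer is running the same code so that a group of two can form:

# Sketch only: mirrors the usage in tests/test_averaging.py from this commit.
import torch
import hivemind

dht = hivemind.DHT(start=True, endpoint='127.0.0.1:*')

x = torch.randn(16, requires_grad=True)
opt = torch.optim.Adam([x], lr=0.05)

# average_opt_statistics names entries of the optimizer state to average,
# e.g. Adam's second-moment estimate "exp_avg_sq"
averager = hivemind.client.TrainingAverager(
    opt, average_parameters=True, average_gradients=True,
    average_opt_statistics=["exp_avg_sq"],
    dht=dht, start=True, listen_on='127.0.0.1:*',
    prefix='demo-run', target_group_size=2)

averager.step()  # blocks until an averaging round with a peer group completes
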
+ 1 - 0
hivemind/__init__.py

@@ -2,5 +2,6 @@ from hivemind.client import *
 from hivemind.dht import *
 from hivemind.server import *
 from hivemind.utils import *
+from hivemind.optim import *
 
 __version__ = '0.9.7'

+ 1 - 1
hivemind/client/__init__.py

@@ -1,4 +1,4 @@
 from hivemind.client.expert import RemoteExpert
 from hivemind.client.moe import RemoteMixtureOfExperts
 from hivemind.client.averaging import DecentralizedAverager
-from hivemind.client.optim import ParameterAveragingOptimizer, DecentralizedSGD, CollaborativeOptimizer
+from hivemind.client.averaging.training import TrainingAverager

+ 9 - 2
hivemind/client/averaging/training.py

@@ -23,6 +23,7 @@ class TrainingAverager(DecentralizedAverager):
     :param opt: a pytorch optimizer to be averaged between peers (complete with model parameters)
     :param average_parameters: whether or not to average model parameters in self.step(...)
     :param average_gradients: whether or not to average model gradients in self.step(...)
+    :param average_opt_statistics: if specified, average optimizer statistics with corresponding names in statedict
     :param initialize_optimizer: if True, this will run a speculative optimizer step with
       zero gradients to initialize all tensors. If False, please initialize the optimizer state manually.
     :param extra_tensors: if specified, these extra tensors will also be averaged and shared in load_state_from_peers.
@@ -30,9 +31,11 @@ class TrainingAverager(DecentralizedAverager):
     :param kwargs: any additional parameters will be forwarded to DecentralizedAverager
     """
     def __init__(self, opt: torch.optim.Optimizer, *, average_parameters: bool, average_gradients: bool,
-                 extra_tensors: Sequence[torch.Tensor] = (), initialize_optimizer: bool = True, **kwargs):
+                 average_opt_statistics: Sequence[str] = (), extra_tensors: Sequence[torch.Tensor] = (),
+                 initialize_optimizer: bool = True, **kwargs):
 
         self.opt, self.extra_tensors, self.local_step = opt, tuple(extra_tensors), 0
+        self.opt_statistics = tuple(average_opt_statistics)
         self.average_parameters, self.average_gradients = average_parameters, average_gradients
         self.lock_averager_step = Lock()
         if initialize_optimizer:
@@ -46,7 +49,7 @@ class TrainingAverager(DecentralizedAverager):
     def step(self, wait: bool = True, **kwargs):
         """ Average optimizer weights and gradients with peers. """
         if not wait:
-            return run_in_background(self.step, wait=False, **kwargs)
+            return run_in_background(self.step, wait=True, **kwargs)
 
         local_tensors = list(self.local_tensors())
         with self.lock_averager_step:
@@ -85,6 +88,10 @@ class TrainingAverager(DecentralizedAverager):
                         yield param.grad
                     elif replace_none:
                         yield torch.zeros_like(param)
+        for stats in self.opt_statistics:
+            for param_group in self.opt.param_groups:
+                for param in param_group['params']:
+                    yield self.opt.state[param][stats]
         yield from iter(self.extra_tensors)
 
     def get_current_state(self):

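Judging from the hunk above and its surrounding context, local_tensors() yields parameters, then gradients, then one statistics tensor per parameter for each name in average_opt_statistics, followed by extra_tensors. A minimal sketch of that ordering for a torch.optim.Adam optimizer (the "exp_avg_sq" key is Adam's second-moment estimate):

# Sketch of the tensor order seen by the averager after this change, assuming
# average_parameters=True, average_gradients=True, average_opt_statistics=["exp_avg_sq"].
import torch

x = torch.randn(16, requires_grad=True)
opt = torch.optim.Adam([x], lr=0.05)
x.sum().backward()
opt.step()  # populates opt.state[x] with "exp_avg" and "exp_avg_sq"

tensors = []
tensors += [p for g in opt.param_groups for p in g['params']]       # parameters
tensors += [p.grad for g in opt.param_groups for p in g['params']]  # gradients
for name in ["exp_avg_sq"]:                                         # requested statistics
    tensors += [opt.state[p][name] for g in opt.param_groups for p in g['params']]
# ...followed by any extra_tensors passed to the constructor
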
+ 0 - 2
hivemind/client/optim/__init__.py

@@ -1,2 +0,0 @@
-from hivemind.client.optim.simple import ParameterAveragingOptimizer, DecentralizedSGD
-from hivemind.client.optim.collaborative import CollaborativeOptimizer

+ 4 - 0
hivemind/optim/__init__.py

@@ -0,0 +1,4 @@
+from hivemind.optim.base import DecentralizedOptimizerBase
+from hivemind.optim.collaborative import CollaborativeOptimizer
+from hivemind.optim.performance_ema import PerformanceEMA
+from hivemind.optim.simple import DecentralizedSGD, ParameterAveragingOptimizer

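After this move, the optimizers are imported from hivemind.optim instead of hivemind.client.optim (and hivemind/__init__.py now re-exports them via the wildcard import added above). A brief before/after, based on the import lines changed in this commit:

# before this commit:
#   from hivemind.client.optim import ParameterAveragingOptimizer, DecentralizedSGD, CollaborativeOptimizer
# after this commit:
from hivemind.optim import ParameterAveragingOptimizer, DecentralizedSGD, CollaborativeOptimizer
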
+ 0 - 0
hivemind/client/optim/base.py → hivemind/optim/base.py


+ 2 - 2
hivemind/client/optim/collaborative.py → hivemind/optim/collaborative.py

@@ -8,10 +8,10 @@ import torch
 import numpy as np
 
 from hivemind.dht import DHT
-from hivemind.client.optim.base import DecentralizedOptimizerBase
+from hivemind.optim.base import DecentralizedOptimizerBase
 from hivemind.client.averaging.training import TrainingAverager
 from hivemind.utils import get_logger, get_dht_time, ValueWithExpiration
-from hivemind.client.optim.performance_ema import PerformanceEMA
+from hivemind.optim.performance_ema import PerformanceEMA
 
 logger = get_logger(__name__)
 LRSchedulerBase = getattr(torch.optim.lr_scheduler, '_LRScheduler', None)

+ 0 - 0
hivemind/client/optim/performance_ema.py → hivemind/optim/performance_ema.py


+ 1 - 1
hivemind/client/optim/simple.py → hivemind/optim/simple.py

@@ -6,7 +6,7 @@ import torch
 
 from hivemind.dht import DHT
 from hivemind.client.averaging import DecentralizedAverager
-from hivemind.client.optim.base import DecentralizedOptimizerBase
+from hivemind.optim.base import DecentralizedOptimizerBase
 from hivemind.utils import get_logger, get_dht_time
 
 logger = get_logger(__name__)

+ 46 - 0
tests/test_averaging.py

@@ -377,3 +377,49 @@ def test_getset_bits():
                                               prefix='test_prefix', target_group_size=2)
     averager.set_group_bits('00101011101010')
     assert averager.get_group_bits() == '00101011101010'
+
+
+@pytest.mark.forked
+def test_training_averager(n_steps: int = 10, n_dims: int = 16):
+    torch.manual_seed(42)
+
+    dht = hivemind.DHT(start=True, endpoint='127.0.0.1:*')
+    common_kwargs = {'dht': dht, 'start': True, 'listen_on': '127.0.0.1:*',
+                     'prefix': 'demo-run', 'target_group_size': 2}
+
+    x1 = torch.randn(n_dims, requires_grad=True)
+    opt1 = torch.optim.Adam([x1], lr=0.05)
+    averager1 = hivemind.client.TrainingAverager(opt1, average_gradients=True, average_parameters=True,
+                                                 average_opt_statistics=["exp_avg_sq"], **common_kwargs)
+
+    x2 = torch.randn(n_dims, requires_grad=True)
+    opt2 = torch.optim.Adam([x2], lr=0.05)
+    averager2 = hivemind.client.TrainingAverager(opt2, average_gradients=True, average_parameters=True,
+                                                 average_opt_statistics=["exp_avg_sq"], **common_kwargs)
+    a = torch.ones(n_dims)
+
+    for i in range(n_steps):
+        opt1.zero_grad()
+        opt2.zero_grad()
+        (x1 - a).pow(2).sum().backward()
+        (x2 - a).pow(2).sum().backward()
+        opt1.step()
+        opt2.step()
+
+        with torch.no_grad():
+            x_avg = 0.5 * (x1 + x2)
+            grad_avg = 0.5 * (x1.grad + x2.grad)
+            stats_avg = 0.5 * (opt1.state[x1]["exp_avg_sq"] + opt2.state[x2]["exp_avg_sq"])
+
+        # we set wait=False in order to prevent deadlock, when averager1 locks and waits for averager2
+        f1 = averager1.step(wait=False)
+        f2 = averager2.step(wait=False)
+        f1.result()
+        f2.result()
+
+        assert torch.allclose(x1, x_avg)
+        assert torch.allclose(x2, x_avg)
+        assert torch.allclose(x1.grad, grad_avg)
+        assert torch.allclose(x2.grad, grad_avg)
+        assert torch.allclose(opt1.state[x1]["exp_avg_sq"], stats_avg)
+        assert torch.allclose(opt2.state[x2]["exp_avg_sq"], stats_avg)