xtinkt 4 years ago
Parent
Commit
9e67069c3f
2 changed files with 57 additions and 1 deletion
  1. hivemind/optim/averaged.py   +1 -1
  2. tests/test_averaging.py      +56 -0

+ 1 - 1
hivemind/optim/averaged.py

@@ -117,7 +117,7 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
         return step_result
 
     def _sync_if_needed(self):
-        if not self.is_synchronized:
+        if not self.is_synchronized():
             logger.warning("Peer is out of sync.")
             self.load_states_from_peers(**kwargs)
             return
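
The one-line fix above is subtle: the change implies `is_synchronized` is a method here rather than a property, and in Python a bound method object is always truthy, so `not self.is_synchronized` (without the call) is always False and the resync branch is dead code. A minimal sketch of the failure mode, using a hypothetical `Peer` class for illustration only:

    class Peer:
        def is_synchronized(self):
            return False  # pretend this peer has fallen behind its group

    peer = Peer()

    # Bug: `not <bound method>` is always False, so this branch never runs.
    assert bool(peer.is_synchronized) is True  # truthy method object, not its result
    if not peer.is_synchronized:
        print("unreachable")

    # Fix: call the method so the actual return value is tested.
    if not peer.is_synchronized():
        print("Peer is out of sync.")  # now reachable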

+ 56 - 0
tests/test_averaging.py

@@ -4,6 +4,7 @@ import random
 import numpy as np
 import torch
 import pytest
+import time
 import hivemind
 from hivemind.client.averaging.allreduce import AllReduceProtocol, split_into_parts, restore_from_parts
 from hivemind.client.averaging.load_balancing import load_balance_peers
@@ -423,3 +424,58 @@ def test_training_averager(n_steps: int = 10, n_dims: int = 16):
         assert torch.allclose(x2.grad, grad_avg)
         assert torch.allclose(opt1.state[x1]["exp_avg_sq"], stats_avg)
         assert torch.allclose(opt2.state[x2]["exp_avg_sq"], stats_avg)
+
+
+@pytest.mark.forked
+def test_lr_scheduler(n_steps: int = 100, n_dims: int = 16, time_to_wait: float = 0.3):
+    torch.manual_seed(42)
+
+    dht_root = hivemind.DHT(start=True)
+    initial_peers = [f"127.0.0.1:{dht_root.port}"]
+
+    def primitive_lr_cls(opt):
+        lmbda = lambda epoch: 0.95
+        return torch.optim.lr_scheduler.MultiplicativeLR(opt, lmbda, verbose=False)
+
+    sgd_kwargs = {'prefix': 'demo-run', 'target_group_size': 2,
+                  'verbose': True, 'lr': 0.01, 'max_allowed_epoch_difference': 0,
+                  'total_steps_in_epoch': 40, 'scheduler_cls': primitive_lr_cls,
+                  'report_progress_expiration': 60}
+
+    x1 = torch.randn(n_dims, requires_grad=True)
+    sgd1 = hivemind.DecentralizedSGD(
+        [x1],
+        dht=hivemind.DHT(start=True, initial_peers=initial_peers),
+        **sgd_kwargs
+    )
+    x2 = torch.randn(n_dims, requires_grad=True)
+    sgd2 = hivemind.DecentralizedSGD(
+        [x2],
+        dht=hivemind.DHT(start=True, initial_peers=initial_peers),
+        **sgd_kwargs
+    )
+    target = torch.ones(n_dims)
+
+    for i in range(n_steps):
+        sgd1.zero_grad()
+        sgd2.zero_grad()
+        (x1 - target).pow(2).sum().backward()
+        (x2 - target).pow(2).sum().backward()
+        sgd1.step()
+        sgd2.step()
+        time.sleep(time_to_wait)
+    assert sgd1.local_epoch == sgd2.local_epoch
+    assert all([x['lr'] == y['lr'] for x, y in zip(sgd1.opt.param_groups, sgd2.opt.param_groups)])
+
+    x3 = torch.randn(n_dims, requires_grad=True)
+    sgd3 = hivemind.DecentralizedSGD(
+        [x3],
+        dht=hivemind.DHT(start=True, initial_peers=initial_peers),
+        **sgd_kwargs
+    )
+    assert sgd3.local_epoch == sgd2.local_epoch
+    assert sgd3.local_epoch == sgd1.local_epoch
+    assert all([x['lr'] == y['lr'] for x, y in zip(sgd3.opt.param_groups, sgd1.opt.param_groups)])
+
+    sgd1.shutdown()
+    sgd2.shutdown()
+    sgd3.shutdown()
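
For reference, `scheduler_cls` is a factory: it receives the wrapped torch optimizer and returns a scheduler, and the test's `primitive_lr_cls` decays every param group's lr by 5% per epoch. The kwargs suggest DecentralizedSGD advances this scheduler once per `total_steps_in_epoch` (here 40) local steps, which is why both peers must end on the same lr. A standalone sketch of the same schedule on a plain torch.optim.SGD, with no hivemind involved, just to show the expected decay:

    import torch

    params = [torch.nn.Parameter(torch.zeros(16))]
    opt = torch.optim.SGD(params, lr=0.01)
    # Same factory as primitive_lr_cls: multiply each group's lr by 0.95 per epoch.
    scheduler = torch.optim.lr_scheduler.MultiplicativeLR(opt, lambda epoch: 0.95)

    for epoch in range(3):
        opt.step()        # in the test, 40 DecentralizedSGD steps make up one epoch
        scheduler.step()  # scheduler.step() comes after optimizer.step()
        print(opt.param_groups[0]['lr'])  # 0.0095 -> 0.009025 -> 0.00857375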