@@ -4,6 +4,7 @@ import random
 import numpy as np
 import torch
 import pytest
+import time
 import hivemind
 from hivemind.client.averaging.allreduce import AllReduceProtocol, split_into_parts, restore_from_parts
 from hivemind.client.averaging.load_balancing import load_balance_peers
@@ -423,3 +424,63 @@ def test_training_averager(n_steps: int = 10, n_dims: int = 16):
     assert torch.allclose(x2.grad, grad_avg)
     assert torch.allclose(opt1.state[x1]["exp_avg_sq"], stats_avg)
     assert torch.allclose(opt2.state[x2]["exp_avg_sq"], stats_avg)
+
+
+@pytest.mark.forked
+def test_lr_scheduler(n_steps: int = 100, n_dims: int = 16, time_to_wait: float = 0.3):
+    torch.manual_seed(42)
+
+    dht_root = hivemind.DHT(start=True)
+    initial_peers = [f"127.0.0.1:{dht_root.port}"]
+
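+    # toy scheduler: multiplies the learning rate by 0.95 on every scheduler step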
+    def primitive_lr_cls(opt):
+        lmbda = lambda epoch: 0.95
+        return torch.optim.lr_scheduler.MultiplicativeLR(opt, lmbda, verbose=False)
+
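+    # shared config: with max_allowed_epoch_difference=0, peers are expected to stay on the same epoch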
+    sgd_kwargs = {'prefix': 'demo-run', 'target_group_size': 2,
+                  'verbose': True, 'lr': 0.01, 'max_allowed_epoch_difference': 0,
+                  'total_steps_in_epoch': 40, 'scheduler_cls': primitive_lr_cls,
+                  'report_progress_expiration': 60}
+
+    x1 = torch.randn(n_dims, requires_grad=True)
+    sgd1 = hivemind.DecentralizedSGD(
+        [x1],
+        dht=hivemind.DHT(start=True, initial_peers=initial_peers),
+        **sgd_kwargs
+    )
+    x2 = torch.randn(n_dims, requires_grad=True)
+    sgd2 = hivemind.DecentralizedSGD(
+        [x2],
+        dht=hivemind.DHT(start=True, initial_peers=initial_peers),
+        **sgd_kwargs
+    )
+    target = torch.ones(n_dims)
+
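+    # train two peers in lockstep: their epochs and learning rates should stay in sync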
+    for _ in range(n_steps):
+        sgd1.zero_grad()
+        sgd2.zero_grad()
+        (x1 - target).pow(2).sum().backward()
+        (x2 - target).pow(2).sum().backward()
+        sgd1.step()
+        sgd2.step()
+        time.sleep(time_to_wait)
+        assert sgd1.local_epoch == sgd2.local_epoch
+        assert all(x['lr'] == y['lr'] for x, y in zip(sgd1.opt.param_groups, sgd2.opt.param_groups))
+
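+    # a peer that joins late must catch up with the group's current epoch and learning rate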
+    x3 = torch.randn(n_dims, requires_grad=True)
+    sgd3 = hivemind.DecentralizedSGD(
+        [x3],
+        dht=hivemind.DHT(start=True, initial_peers=initial_peers),
+        **sgd_kwargs
+    )
+    assert sgd3.local_epoch == sgd2.local_epoch
+    assert sgd3.local_epoch == sgd1.local_epoch
+    assert all(x['lr'] == y['lr'] for x, y in zip(sgd3.opt.param_groups, sgd1.opt.param_groups))
+
+    sgd1.shutdown()
+    sgd2.shutdown()
+    sgd3.shutdown()