test_averaging.py
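"""Tests for hivemind's decentralized averaging: group key management, the all-reduce protocol and tensor
partitioning, peer load balancing, allgather of peer metadata, loading state from peers, and the
TrainingAverager / DecentralizedSGD training wrappers."""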

import asyncio
import random
import time

import numpy as np
import pytest
import torch

import hivemind
from hivemind.client.averaging.allreduce import AllReduceProtocol, split_into_parts, restore_from_parts
from hivemind.client.averaging.load_balancing import load_balance_peers
from hivemind.client.averaging.key_manager import GroupKeyManager
from hivemind.utils import Endpoint


@pytest.mark.forked
@pytest.mark.asyncio
async def test_key_manager():
    key_manager = GroupKeyManager(hivemind.DHT(start=True), endpoint='localhvost',
                                  prefix='test_averaging', initial_group_bits='10110',
                                  target_group_size=2)

    t = hivemind.get_dht_time()
    key = key_manager.current_key
    await key_manager.declare_averager(key, 'localhvost', expiration_time=t + 60)
    await key_manager.declare_averager(key, 'localhvost2', expiration_time=t + 61)
    q1 = await key_manager.get_averagers(key, only_active=True)

    await key_manager.declare_averager(key, 'localhvost', expiration_time=t + 66)
    q2 = await key_manager.get_averagers(key, only_active=True)

    await key_manager.declare_averager(key, 'localhvost2', expiration_time=t + 61, looking_for_group=False)
    q3 = await key_manager.get_averagers(key, only_active=True)
    q4 = await key_manager.get_averagers(key, only_active=False)

    q5 = await key_manager.get_averagers('nonexistent_key.0b0101', only_active=False)

    assert len(q1) == 2 and ('localhvost', t + 60) in q1 and ('localhvost2', t + 61) in q1
    assert len(q2) == 2 and ('localhvost', t + 66) in q2 and ('localhvost2', t + 61) in q2
    assert len(q3) == 1 and ('localhvost', t + 66) in q3
    assert len(q4) == 2 and ('localhvost', t + 66) in q4 and ('localhvost2', t + 61) in q4
    assert len(q5) == 0


@pytest.mark.forked
@pytest.mark.parametrize("n_client_mode_peers", [0, 2])
def test_allreduce_once(n_client_mode_peers):
    dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')

    n_peers = 4
    should_listen = [False] * n_client_mode_peers + [True] * (n_peers - n_client_mode_peers)
    random.shuffle(should_listen)

    tensors1 = [torch.randn(123), torch.zeros(3)]
    tensors2 = [torch.rand(123), torch.ones(3)]
    tensors3 = [-torch.rand(123), torch.arange(3).to(torch.float32)]
    tensors4 = [torch.randn(123) ** 3, torch.arange(3).to(torch.float32) / 2]
    reference = [(tensors1[i] + tensors2[i] + tensors3[i] + tensors4[i]) / 4 for i in range(len(tensors1))]

    averagers = [hivemind.DecentralizedAverager(tensors, dht=dht, target_group_size=4, averaging_expiration=15,
                                                prefix='mygroup', listen=listen, listen_on='127.0.0.1:*',
                                                start=True)
                 for tensors, listen in zip([tensors1, tensors2, tensors3, tensors4], should_listen)]

    futures = []
    for averager in averagers:
        futures.append(averager.step(wait=False))
    for future in futures:
        result = future.result()
        for averager in averagers:
            assert averager.endpoint in result

    for averager in averagers:
        with averager.get_tensors() as averaged_tensors:
            for ref, our in zip(reference, averaged_tensors):
                assert torch.allclose(ref, our, atol=1e-6)

    for averager in averagers:
        averager.shutdown()
    dht.shutdown()


@pytest.mark.forked
def test_allreduce_weighted(n_client_mode_peers: int = 2):
    dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')

    n_peers = 4
    should_listen = [False] * n_client_mode_peers + [True] * (n_peers - n_client_mode_peers)
    random.shuffle(should_listen)

    tensors1 = [torch.randn(123), torch.zeros(3)]
    tensors2 = [torch.rand(123), torch.ones(3)]
    tensors3 = [-torch.rand(123), torch.arange(3).to(torch.float32)]
    tensors4 = [torch.randn(123) ** 3, torch.arange(3).to(torch.float32) / 2]

    averagers = [hivemind.DecentralizedAverager(tensors, dht=dht, target_group_size=4, averaging_expiration=15,
                                                prefix='mygroup', listen=listen, listen_on='127.0.0.1:*',
                                                start=True)
                 for tensors, listen in zip([tensors1, tensors2, tensors3, tensors4], should_listen)]
    weights = list(map(float, np.random.rand(len(averagers)) * 10 + 0.01))
    reference = [(tensors1[i] * weights[0] + tensors2[i] * weights[1] + tensors3[i] * weights[2]
                  + tensors4[i] * weights[3]) / sum(weights) for i in range(len(tensors1))]

    futures = []
    for averager, weight in zip(averagers, weights):
        futures.append(averager.step(weight=weight, wait=False))
    for future in futures:
        future.result()

    for future, averager in zip(futures, averagers):
        with averager.get_tensors() as averaged_tensors:
            for ref, our in zip(reference, averaged_tensors):
                assert torch.allclose(ref, our, atol=1e-6)

    for averager in averagers:
        averager.shutdown()
    dht.shutdown()


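# Helper for test_allreduce_grid below: snapshots every averager's local tensors and returns the
# per-coordinate mean and standard deviation across peers.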
def compute_mean_std(averagers, unbiased=True):
    results = []
    for averager in averagers:
        with averager.get_tensors() as tensors:
            results.append([tensor.clone() for tensor in tensors])

    results_stacked_per_tensor = list(map(torch.stack, zip(*results)))
    means = [stack.mean(dim=0) for stack in results_stacked_per_tensor]
    stds = [stack.std(dim=0, unbiased=unbiased) for stack in results_stacked_per_tensor]
    return means, stds


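# Eight peers are laid out on a 2-bit grid (initial_group_bits '00' .. '11') and average in pairs.
# Repeated steps should preserve the global mean while the spread across peers shrinks toward zero.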
@pytest.mark.forked
def test_allreduce_grid():
    dht = hivemind.DHT(start=True, endpoint='127.0.0.1:*')
    averagers = [hivemind.DecentralizedAverager(
        averaged_tensors=[torch.randn(3)], dht=dht, target_group_size=2,
        prefix='mygroup', initial_group_bits=bin(i // 2)[2:].rjust(2, '0'), start=True)
        for i in range(8)]

    [means0], [stds0] = compute_mean_std(averagers)
    assert not torch.allclose(stds0, torch.zeros_like(stds0))

    prev_means, prev_stds = means0, stds0

    for i in range(5):
        step_futures = [averager.step(wait=False) for averager in averagers]
        groups = [future.result() for future in step_futures]
        [means], [stds] = compute_mean_std(averagers)
        assert torch.allclose(means, prev_means, atol=1e-6, rtol=0)
        assert all(len(group) == 2 for group in groups)

        if i <= 2:
            assert torch.all(torch.le(stds, prev_stds))
        else:
            assert torch.allclose(stds, torch.zeros_like(stds), atol=1e-6, rtol=0)

    for averager in averagers:
        averager.shutdown()
    dht.shutdown()


@pytest.mark.forked
def test_allgather():
    dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')
    averagers = [hivemind.DecentralizedAverager([torch.ones(1)], dht=dht, target_group_size=4, averaging_expiration=15,
                                                prefix='mygroup', initial_group_bits='000', listen_on='127.0.0.1:*',
                                                start=True)
                 for _ in range(8)]

    futures = []
    for i, averager in enumerate(averagers):
        futures.append(averager.step(wait=False, gather=dict(batch_size=123 + i, foo='bar')))

    assert len(set(repr(sorted(future.result())) for future in futures)) == 2

    reference_metadata = {averager.endpoint: dict(batch_size=123 + i, foo='bar')
                          for i, averager in enumerate(averagers)}
    for future in futures:
        gathered = future.result()
        assert len(gathered) == 4
        for endpoint in gathered:
            assert gathered[endpoint] == reference_metadata[endpoint]

    for averager in averagers:
        averager.shutdown()
    dht.shutdown()


@pytest.mark.forked
@pytest.mark.asyncio
async def test_allreduce_protocol():
    """ Run group allreduce protocol manually without grpc, see if the internal logic is working as intended """
    peers = "alice", "bob", "carol", "colab"

    tensors_by_peer = {peer: [torch.randn(3, 128), torch.rand(32), torch.tensor(i, dtype=torch.float32)]
                       for i, peer in enumerate(peers)}

    group_id = random.getrandbits(160).to_bytes(length=20, byteorder='big')
    # part_sizes sum to 417 = 3 * 128 + 32 + 1, the total number of scalars per peer; "colab" owns an empty part
    allreduce_protocols = [AllReduceProtocol(
        group_id=group_id, endpoint=peer, tensors=tensors_by_peer[peer],
        ordered_group_endpoints=peers, part_sizes=(150, 200, 67, 0))
        for peer in peers]

    async def _accumulate(sender: Endpoint, recipient: Endpoint):
        sender_allreduce = allreduce_protocols[peers.index(sender)]
        recipient_allreduce = allreduce_protocols[peers.index(recipient)]
        averaged_part = await recipient_allreduce.accumulate_part(
            source=sender, remote_part=sender_allreduce.local_tensor_parts[recipient])
        sender_allreduce.register_averaged_part(source=recipient, averaged_part=averaged_part)

    await asyncio.wait({_accumulate(sender, recipient) for sender in peers for recipient in peers
                        if recipient != "colab"})

    reference_tensors = [
        sum(tensors_by_peer[peer][i] for peer in peers) / len(peers)
        for i in range(len(tensors_by_peer[peers[0]]))
    ]

    for peer, allreduce in zip(peers, allreduce_protocols):
        assert allreduce.future.done()
        averaged_tensors = await allreduce
        assert len(averaged_tensors) == len(reference_tensors)
        assert all(torch.allclose(our, ref, atol=1e-6, rtol=0)
                   for our, ref in zip(averaged_tensors, reference_tensors))


@pytest.mark.forked
def test_partitioning():
    for _ in range(100):
        tensors = []
        for _ in range(random.randint(1, 5)):
            ndim = random.randint(0, 4)
            shape = torch.Size([random.randint(0, 16) for _ in range(ndim)])
            make_tensor = random.choice([torch.rand, torch.randn, torch.zeros, torch.ones])
            tensors.append(make_tensor(shape))

        total_size = sum(map(torch.Tensor.numel, tensors))
        if total_size == 0:
            continue

        num_chunks = random.randint(1, min(100, sum(x.numel() for x in tensors)))
        part_sizes = load_balance_peers(total_size, [None] * num_chunks)
        chunks = split_into_parts(tensors, part_sizes)
        assert len(chunks) == num_chunks

        shapes = [tensor.shape for tensor in tensors]
        restored = restore_from_parts(chunks, shapes)
        assert len(restored) == len(tensors)
        assert all(new.shape == old.shape for new, old in zip(restored, tensors))
        assert all(torch.allclose(new, old) for new, old in zip(restored, tensors))


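# Helpers for test_load_balancing. get_cost scores a partitioning by the busiest peer under a simple
# communication model: peer i presumably sends the (vector_size - partitions[i]) elements that belong to
# other peers' partitions, plus (len(partitions) - 1) copies of its own averaged partition, at rate
# throughputs[i]. For example, with vector_size=60, partitions=[15, 15, 15, 15] and throughputs=[0.25] * 4,
# every peer costs (60 - 15 + 3 * 15) / 0.25 = 360.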
def get_cost(vector_size, partitions, throughputs):
    return max((vector_size - partitions[i] + (len(partitions) - 1) * partitions[i]) / max(throughputs[i], 1e-9)
               for i in range(len(partitions)))


def check_optimality(vector_size, throughputs, ref_partitions):
    partitions = list(load_balance_peers(vector_size, throughputs))
    assert get_cost(vector_size, partitions, throughputs) <= get_cost(vector_size, ref_partitions, throughputs)


@pytest.mark.forked
def test_load_balancing():
    check_optimality(60, np.array([0.25, 0.25, 0.25, 0.25]), [15, 15, 15, 15])
    check_optimality(1024, np.array([0.3, 0.5, 0.9]), [0, 255, 769])
    check_optimality(60, np.array([0.44, 0.33, 0.22]), [42, 18, 0])
    check_optimality(60, np.array([0.55, 0.44, 0.40]), [35, 16, 9])
    check_optimality(1024 * 1024, np.array([0.3, 0.5, 0.9, 0.6]), [0, 169327, 602629, 276620])
    check_optimality(1024 * 1024, np.array([0.0, 0.5, 0.0, 0.6]), [0, 428963, 0, 619613])

    assert load_balance_peers(60, np.array([0.55, 0.44, 0.40]), min_size=10) == (41, 19, 0)
    assert load_balance_peers(60, np.array([0.32, 0.55, 0.44]), min_size=10) == (0, 40, 20)
    assert load_balance_peers(2, np.array([0.55, 0.20, 0.44]), min_size=10) == (1, 0, 1)
    assert load_balance_peers(1, np.array([0.55, 0.20, 0.44]), min_size=10) == (1, 0, 0)

    assert load_balance_peers(100, (None, None)) == (50, 50)
    assert load_balance_peers(100, (None, None, None, None, None)) == (20, 20, 20, 20, 20)
    assert load_balance_peers(100, (0, 0, 0, None, None)) == (0, 0, 0, 50, 50)

    with pytest.raises(AssertionError):
        load_balance_peers(100, (0, 0, 0))

    for i in range(10):
        vector_size = np.random.randint(1, 1024 ** 3)
        num_peers = np.random.randint(1, 256)
        scale = 1e-9 + np.random.rand() * 1e5
        throughputs = np.random.rand(num_peers) * scale + 1e-6
        min_size = np.random.choice([0, np.random.randint(0, vector_size // 10)])
        assignment = load_balance_peers(vector_size, throughputs, min_size)
        assert np.sum(assignment) == vector_size
        assert np.min(assignment) >= 0


@pytest.mark.forked
def test_too_few_peers():
    dht = hivemind.DHT(start=True, endpoint='127.0.0.1:*')
    averagers = [hivemind.DecentralizedAverager(
        averaged_tensors=[torch.randn(3)], dht=dht, target_group_size=2,
        averaging_expiration=1, request_timeout=0.5,
        prefix='mygroup', initial_group_bits=bin(i)[2:].rjust(3, '0'), start=True)
        for i in range(4)]
    step_futures = [averager.step(wait=False) for averager in averagers]
    for future in step_futures:
        assert len(future.result()) == 2

    for averager in averagers:
        averager.shutdown()
    dht.shutdown()


@pytest.mark.forked
def test_overcrowded(num_peers=16):
    dht = hivemind.DHT(start=True, endpoint='127.0.0.1:*')
    averagers = [hivemind.DecentralizedAverager(
        averaged_tensors=[torch.randn(3)], dht=dht, target_group_size=2,
        averaging_expiration=1, request_timeout=0.5,
        prefix='mygroup', initial_group_bits='', start=True)
        for _ in range(num_peers)]
    for t in range(5):
        step_futures = [averager.step(wait=False, timeout=5) for averager in averagers]
        assert sum(len(future.result() or []) == 2 for future in step_futures) >= len(averagers) - 1

    for averager in averagers:
        averager.shutdown()
    dht.shutdown()


@pytest.mark.forked
def test_load_state_from_peers():
    num_calls = 0
    super_metadata = dict(x=123)
    super_tensors = (torch.randn(3), torch.randint(0, 5, (3,)))

    class TestAverager(hivemind.DecentralizedAverager):
        def get_current_state(self):
            """
            Get the current state and send it to a peer. Executed in the host process. Meant to be overridden.
            :returns: a tuple of (serializable_small_metadata, sequence of torch tensors)
            """
            nonlocal num_calls, super_metadata, super_tensors
            num_calls += 1
            return super_metadata, super_tensors

    dht_root = hivemind.DHT(start=True)
    initial_peers = [f'{hivemind.LOCALHOST}:{dht_root.port}']
    dht1 = hivemind.DHT(initial_peers=initial_peers, start=True)
    averager1 = TestAverager([torch.randn(3), torch.rand(5)],
                             dht=dht1, start=True,
                             prefix='demo-run', target_group_size=2)

    dht2 = hivemind.DHT(initial_peers=initial_peers, start=True)
    dht2.get('demo-run.all_averagers')
    averager2 = TestAverager([torch.randn(3), torch.rand(5)],
                             dht=dht2, start=True,
                             prefix='demo-run', target_group_size=2)

    assert num_calls == 0
    got_metadata, got_tensors = averager2.load_state_from_peers()
    assert num_calls == 1
    assert got_metadata == super_metadata
    assert all(map(torch.allclose, got_tensors, super_tensors))

    super_metadata['y'] = 123
    super_tensors[1][2] = 9
    assert num_calls == 1
    assert got_metadata != super_metadata
    assert not all(map(torch.allclose, got_tensors, super_tensors))

    got_metadata, got_tensors = averager2.load_state_from_peers()
    assert num_calls == 2
    assert got_metadata == super_metadata
    assert all(map(torch.allclose, got_tensors, super_tensors))


@pytest.mark.forked
def test_getset_bits():
    dht = hivemind.DHT(start=True, endpoint='127.0.0.1:*')
    averager = hivemind.DecentralizedAverager([torch.randn(3)], dht=dht, start=True,
                                              prefix='test_prefix', target_group_size=2)
    averager.set_group_bits('00101011101010')
    assert averager.get_group_bits() == '00101011101010'


@pytest.mark.forked
def test_training_averager(n_steps: int = 10, n_dims: int = 16):
    torch.manual_seed(42)

    dht = hivemind.DHT(start=True, endpoint='127.0.0.1:*')
    common_kwargs = {'dht': dht, 'start': True, 'listen_on': '127.0.0.1:*',
                     'prefix': 'demo-run', 'target_group_size': 2}

    x1 = torch.randn(n_dims, requires_grad=True)
    opt1 = torch.optim.Adam([x1], lr=0.05)
    averager1 = hivemind.client.TrainingAverager(opt1, average_gradients=True, average_parameters=True,
                                                 average_opt_statistics=["exp_avg_sq"], **common_kwargs)

    x2 = torch.randn(n_dims, requires_grad=True)
    opt2 = torch.optim.Adam([x2], lr=0.05)
    averager2 = hivemind.client.TrainingAverager(opt2, average_gradients=True, average_parameters=True,
                                                 average_opt_statistics=["exp_avg_sq"], **common_kwargs)

    a = torch.ones(n_dims)
    for i in range(n_steps):
        opt1.zero_grad()
        opt2.zero_grad()
        (x1 - a).pow(2).sum().backward()
        (x2 - a).pow(2).sum().backward()
        opt1.step()
        opt2.step()

        with torch.no_grad():
            x_avg = 0.5 * (x1 + x2)
            grad_avg = 0.5 * (x1.grad + x2.grad)
            stats_avg = 0.5 * (opt1.state[x1]["exp_avg_sq"] + opt2.state[x2]["exp_avg_sq"])

        # use wait=False to avoid a deadlock where averager1 holds its averaging lock while waiting for averager2
        f1 = averager1.step(wait=False)
        f2 = averager2.step(wait=False)
        f1.result()
        f2.result()

        assert torch.allclose(x1, x_avg)
        assert torch.allclose(x2, x_avg)
        assert torch.allclose(x1.grad, grad_avg)
        assert torch.allclose(x2.grad, grad_avg)
        assert torch.allclose(opt1.state[x1]["exp_avg_sq"], stats_avg)
        assert torch.allclose(opt2.state[x2]["exp_avg_sq"], stats_avg)


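# Two DecentralizedSGD peers train in lockstep and must stay in sync on local_epoch and learning rate;
# a third peer that joins after training should pick up the same epoch and scheduler state from the swarm.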
@pytest.mark.forked
def test_lr_scheduler(n_steps: int = 100, n_dims: int = 16, time_to_wait: float = 0.3):
    torch.manual_seed(42)
    dht_root = hivemind.DHT(start=True)
    initial_peers = [f"127.0.0.1:{dht_root.port}"]

    def primitive_lr_cls(opt):
        lmbda = lambda epoch: 0.95
        return torch.optim.lr_scheduler.MultiplicativeLR(opt, lmbda, verbose=False)

    sgd_kwargs = {'prefix': 'demo-run', 'target_group_size': 2,
                  'verbose': True, 'lr': 0.01, 'max_allowed_epoch_difference': 0,
                  'total_steps_in_epoch': 40, 'scheduler_cls': primitive_lr_cls,
                  'report_progress_expiration': 60}

    x1 = torch.randn(n_dims, requires_grad=True)
    sgd1 = hivemind.DecentralizedSGD(
        [x1],
        dht=hivemind.DHT(start=True, initial_peers=initial_peers),
        **sgd_kwargs
    )
    x2 = torch.randn(n_dims, requires_grad=True)
    sgd2 = hivemind.DecentralizedSGD(
        [x2],
        dht=hivemind.DHT(start=True, initial_peers=initial_peers),
        **sgd_kwargs
    )

    target = torch.ones(n_dims)
    for i in range(n_steps):
        sgd1.zero_grad()
        sgd2.zero_grad()
        (x1 - target).pow(2).sum().backward()
        (x2 - target).pow(2).sum().backward()
        sgd1.step()
        sgd2.step()
        time.sleep(time_to_wait)

    assert sgd1.local_epoch == sgd2.local_epoch
    assert all([x['lr'] == y['lr'] for x, y in zip(sgd1.opt.param_groups, sgd2.opt.param_groups)])

    x3 = torch.randn(n_dims, requires_grad=True)
    sgd3 = hivemind.DecentralizedSGD(
        [x3],
        dht=hivemind.DHT(start=True, initial_peers=initial_peers),
        **sgd_kwargs
    )
    assert sgd3.local_epoch == sgd2.local_epoch
    assert all([x['lr'] == y['lr'] for x, y in zip(sgd2.opt.param_groups, sgd3.opt.param_groups)])

    sgd1.shutdown()
    sgd2.shutdown()
    sgd3.shutdown()