test_averaging.py

import random

import numpy as np
import torch
import pytest

import hivemind
from hivemind.client.averaging.allreduce import AveragingMode
from hivemind.client.averaging.load_balancing import load_balance_peers
from hivemind.client.averaging.key_manager import GroupKeyManager
from hivemind.proto.runtime_pb2 import CompressionType


@pytest.mark.forked
@pytest.mark.asyncio
async def test_key_manager():
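    """
    Declare averagers for a shared group key and check that get_averagers returns the expected
    (endpoint, expiration_time) pairs, both with and without the only_active filter.
    """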
    key_manager = GroupKeyManager(hivemind.DHT(start=True), endpoint='localhvost',
                                  prefix='test_averaging', initial_group_bits='10110',
                                  target_group_size=2)
    t = hivemind.get_dht_time()
    key = key_manager.current_key

    await key_manager.declare_averager(key, 'localhvost', expiration_time=t + 60)
    await key_manager.declare_averager(key, 'localhvost2', expiration_time=t + 61)
    q1 = await key_manager.get_averagers(key, only_active=True)

    await key_manager.declare_averager(key, 'localhvost', expiration_time=t + 66)
    q2 = await key_manager.get_averagers(key, only_active=True)

    await key_manager.declare_averager(key, 'localhvost2', expiration_time=t + 61, looking_for_group=False)
    q3 = await key_manager.get_averagers(key, only_active=True)
    q4 = await key_manager.get_averagers(key, only_active=False)
    q5 = await key_manager.get_averagers('nonexistent_key.0b0101', only_active=False)

    assert len(q1) == 2 and ('localhvost', t + 60) in q1 and ('localhvost2', t + 61) in q1
    assert len(q2) == 2 and ('localhvost', t + 66) in q2 and ('localhvost2', t + 61) in q2
    assert len(q3) == 1 and ('localhvost', t + 66) in q3
    assert len(q4) == 2 and ('localhvost', t + 66) in q4 and ('localhvost2', t + 61) in q4
    assert len(q5) == 0


def _test_allreduce_once(n_clients, n_aux):
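    """
    Run one round of all-reduce between 4 peers with a random mix of NODE, CLIENT and AUX modes.
    Every non-auxiliary peer should end up with the plain average of all non-auxiliary tensors.
    """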
    dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')
    n_peers = 4
    modes = [AveragingMode.CLIENT] * n_clients + [AveragingMode.AUX] * n_aux + [AveragingMode.NODE] * (n_peers - n_clients - n_aux)
    random.shuffle(modes)

    tensors1 = [torch.randn(123), torch.zeros(3)]
    tensors2 = [torch.rand(123), torch.ones(3)]
    tensors3 = [-torch.rand(123), torch.arange(3).to(torch.float32)]
    tensors4 = [torch.randn(123) ** 3, torch.arange(3).to(torch.float32) / 2]
    peer_tensors = [tensors1, tensors2, tensors3, tensors4]

    reference = [sum(tensors[i] for tensors, mode in zip(peer_tensors, modes)
                     if mode != AveragingMode.AUX) / max(1, n_peers - n_aux) for i in range(len(tensors1))]

    averagers = [hivemind.DecentralizedAverager(tensors, dht=dht, target_group_size=4, averaging_expiration=15,
                                                prefix='mygroup', listen=mode != AveragingMode.CLIENT,
                                                listen_on='127.0.0.1:*', auxiliary=mode == AveragingMode.AUX,
                                                start=True)
                 for tensors, mode in zip(peer_tensors, modes)]

    futures = []
    for averager in averagers:
        futures.append(averager.step(wait=False))
    for future in futures:
        result = future.result()
        for averager in averagers:
            assert averager.endpoint in result

    for averager in averagers:
        if averager.mode != AveragingMode.AUX:
            with averager.get_tensors() as averaged_tensors:
                for ref, our in zip(reference, averaged_tensors):
                    assert torch.allclose(ref, our, atol=1e-6)

    for averager in averagers:
        averager.shutdown()
    dht.shutdown()


@pytest.mark.forked
@pytest.mark.parametrize("n_clients", [0, 1, 2])
@pytest.mark.parametrize("n_aux", [0, 1, 2])
def test_allreduce_once(n_clients, n_aux):
    _test_allreduce_once(n_clients, n_aux)


@pytest.mark.forked
@pytest.mark.parametrize("n_clients, n_aux", [(0, 4), (1, 3), (0, 3)])
def test_allreduce_once_edge_cases(n_clients, n_aux):
    _test_allreduce_once(n_clients, n_aux)


@pytest.mark.forked
def test_allreduce_weighted(n_client_mode_peers: int = 2):
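    """
    Each peer calls step(weight=...) with a random positive weight; all peers (client-mode or not)
    should converge to the weighted average of the original tensors.
    """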
    dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')
    n_peers = 4
    should_listen = [False] * n_client_mode_peers + [True] * (n_peers - n_client_mode_peers)
    random.shuffle(should_listen)

    tensors1 = [torch.randn(123), torch.zeros(3)]
    tensors2 = [torch.rand(123), torch.ones(3)]
    tensors3 = [-torch.rand(123), torch.arange(3).to(torch.float32)]
    tensors4 = [torch.randn(123) ** 3, torch.arange(3).to(torch.float32) / 2]

    averagers = [hivemind.DecentralizedAverager(tensors, dht=dht, target_group_size=4, averaging_expiration=15,
                                                prefix='mygroup', listen=listen, listen_on='127.0.0.1:*',
                                                start=True)
                 for tensors, listen in zip([tensors1, tensors2, tensors3, tensors4], should_listen)]
    weights = list(map(float, np.random.rand(len(averagers)) * 10 + 0.01))
    reference = [(tensors1[i] * weights[0] + tensors2[i] * weights[1] + tensors3[i] * weights[2]
                  + tensors4[i] * weights[3]) / sum(weights) for i in range(len(tensors1))]

    futures = []
    for averager, weight in zip(averagers, weights):
        futures.append(averager.step(weight=weight, wait=False))
    for future in futures:
        future.result()

    for future, averager in zip(futures, averagers):
        with averager.get_tensors() as averaged_tensors:
            for ref, our in zip(reference, averaged_tensors):
                assert torch.allclose(ref, our, atol=1e-6)

    for averager in averagers:
        averager.shutdown()
    dht.shutdown()


@pytest.mark.forked
def test_allreduce_compression():
    """ this test ensures that compression works correctly when multiple tensors have different compression types """
    dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')

    tensors1 = [torch.linspace(0, 500, 1000) ** 0.5, torch.randn(1000)]
    tensors2 = [torch.linspace(300, 800, 1000) ** 0.5, torch.randn(1000)]
    results = {}

    FLOAT16, UINT8 = CompressionType.FLOAT16, CompressionType.UNIFORM_8BIT

    for compression_type_pair in [(FLOAT16, FLOAT16), (FLOAT16, UINT8), (UINT8, FLOAT16), (UINT8, UINT8)]:
        averager1 = hivemind.DecentralizedAverager([x.clone() for x in tensors1], dht=dht,
                                                   compression_type=compression_type_pair, listen=False,
                                                   target_group_size=2, prefix='mygroup', start=True)
        averager2 = hivemind.DecentralizedAverager([x.clone() for x in tensors2], dht=dht,
                                                   compression_type=compression_type_pair,
                                                   target_group_size=2, prefix='mygroup', start=True)

        for future in averager1.step(wait=False), averager2.step(wait=False):
            future.result()

        with averager1.get_tensors() as averaged_tensors:
            results[compression_type_pair] = averaged_tensors

    assert torch.allclose(results[UINT8, FLOAT16][0], results[UINT8, UINT8][0])
    assert torch.allclose(results[UINT8, FLOAT16][1], results[FLOAT16, FLOAT16][1])
    assert torch.allclose(results[UINT8, UINT8][1], results[FLOAT16, UINT8][1])
    assert torch.allclose(results[FLOAT16, UINT8][0], results[FLOAT16, FLOAT16][0])

    assert not torch.allclose(results[UINT8, FLOAT16][1], results[UINT8, UINT8][1])
    assert not torch.allclose(results[UINT8, FLOAT16][0], results[FLOAT16, FLOAT16][0])
    assert not torch.allclose(results[UINT8, UINT8][0], results[FLOAT16, UINT8][0])
    assert not torch.allclose(results[FLOAT16, UINT8][1], results[FLOAT16, FLOAT16][1])

    reference = [(tensors1[i] + tensors2[i]) / 2 for i in range(len(tensors1))]
    for i in range(2):
        assert 0 < torch.mean(torch.square(results[FLOAT16, FLOAT16][i] - reference[i])).item() <= 1e-5
        assert 1e-5 < torch.mean(torch.square(results[UINT8, UINT8][i] - reference[i])).item() <= 1e-2


def compute_mean_std(averagers, unbiased=True):
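    """ Stack each averaged tensor across averagers and return its element-wise mean and std. """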
    results = []
    for averager in averagers:
        with averager.get_tensors() as tensors:
            results.append([tensor.clone() for tensor in tensors])

    results_stacked_per_tensor = list(map(torch.stack, zip(*results)))
    means = [stack.mean(dim=0) for stack in results_stacked_per_tensor]
    stds = [stack.std(dim=0, unbiased=unbiased) for stack in results_stacked_per_tensor]
    return means, stds


@pytest.mark.forked
def test_allreduce_grid():
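    """
    Eight averagers start pairwise in four 2-bit group buckets; repeated rounds of pairwise averaging
    should keep the global mean unchanged while driving the spread between peers toward zero.
    """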
    dht = hivemind.DHT(start=True, endpoint='127.0.0.1:*')
    averagers = [hivemind.DecentralizedAverager(
        averaged_tensors=[torch.randn(3)], dht=dht, target_group_size=2,
        prefix='mygroup', initial_group_bits=bin(i // 2)[2:].rjust(2, '0'), start=True)
        for i in range(8)]

    [means0], [stds0] = compute_mean_std(averagers)
    assert not torch.allclose(stds0, torch.zeros_like(stds0))

    prev_means, prev_stds = means0, stds0

    for i in range(5):
        step_futures = [averager.step(wait=False) for averager in averagers]
        groups = [future.result() for future in step_futures]
        [means], [stds] = compute_mean_std(averagers)
        assert torch.allclose(means, prev_means, atol=1e-6, rtol=0)
        assert all(len(group) == 2 for group in groups)

        if i <= 2:
            assert torch.all(torch.le(stds, prev_stds))
        else:
            assert torch.allclose(stds, torch.zeros_like(stds), atol=1e-6, rtol=0)

    for averager in averagers:
        averager.shutdown()
    dht.shutdown()


@pytest.mark.forked
def test_allgather():
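    """
    Each of 8 peers passes its own metadata to step(gather=...); the peers should split into two
    groups of 4, and every peer should receive the metadata of all of its groupmates.
    """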
    dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')
    averagers = [hivemind.DecentralizedAverager([torch.ones(1)], dht=dht, target_group_size=4, averaging_expiration=15,
                                                prefix='mygroup', initial_group_bits='000', listen_on='127.0.0.1:*',
                                                start=True)
                 for _ in range(8)]

    futures = []
    for i, averager in enumerate(averagers):
        futures.append(averager.step(wait=False, gather=dict(batch_size=123 + i, foo='bar')))

    assert len(set(repr(sorted(future.result())) for future in futures)) == 2

    reference_metadata = {averager.endpoint: dict(batch_size=123 + i, foo='bar')
                          for i, averager in enumerate(averagers)}
    for future in futures:
        gathered = future.result()
        assert len(gathered) == 4
        for endpoint in gathered:
            assert gathered[endpoint] == reference_metadata[endpoint]

    for averager in averagers:
        averager.shutdown()
    dht.shutdown()


def get_cost(vector_size, partitions, throughputs):
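    """
    Estimated duration of one all-reduce round for a given partitioning: peer i handles the
    vector_size - partitions[i] elements assigned to other peers and serves its own partitions[i]
    elements to the remaining len(partitions) - 1 peers, at its own throughput; the round is
    bounded by the slowest peer.
    """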
    return max((vector_size - partitions[i] + (len(partitions) - 1) * partitions[i]) / max(throughputs[i], 1e-9)
               for i in range(len(partitions)))


def check_optimality(vector_size, throughputs, ref_partitions):
    partitions = list(load_balance_peers(vector_size, throughputs))
    assert get_cost(vector_size, partitions, throughputs) <= get_cost(vector_size, ref_partitions, throughputs)


@pytest.mark.forked
def test_load_balancing():
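    """
    Compare load_balance_peers against hand-picked reference partitions via get_cost, cover the
    min_size and zero/None-throughput edge cases, and fuzz-test that every assignment is
    non-negative and sums to vector_size.
    """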
    check_optimality(60, np.array([0.25, 0.25, 0.25, 0.25]), [15, 15, 15, 15])
    check_optimality(1024, np.array([0.3, 0.5, 0.9]), [0, 255, 769])
    check_optimality(60, np.array([0.44, 0.33, 0.22]), [42, 18, 0])
    check_optimality(60, np.array([0.55, 0.44, 0.40]), [35, 16, 9])
    check_optimality(1024 * 1024, np.array([0.3, 0.5, 0.9, 0.6]), [0, 169327, 602629, 276620])
    check_optimality(1024 * 1024, np.array([0.0, 0.5, 0.0, 0.6]), [0, 428963, 0, 619613])

    assert load_balance_peers(60, np.array([0.55, 0.44, 0.40]), min_size=10) == (41, 19, 0)
    assert load_balance_peers(60, np.array([0.32, 0.55, 0.44]), min_size=10) == (0, 40, 20)
    assert load_balance_peers(2, np.array([0.55, 0.20, 0.44]), min_size=10) == (1, 0, 1)
    assert load_balance_peers(1, np.array([0.55, 0.20, 0.44]), min_size=10) == (1, 0, 0)

    assert load_balance_peers(100, (None, None)) == (50, 50)
    assert load_balance_peers(100, (None, None, None, None, None)) == (20, 20, 20, 20, 20)
    assert load_balance_peers(100, (0, 0, 0, None, None)) == (0, 0, 0, 50, 50)

    with pytest.raises(AssertionError):
        load_balance_peers(100, (0, 0, 0))

    for i in range(10):
        vector_size = np.random.randint(1, 1024 ** 3)
        num_peers = np.random.randint(1, 256)
        scale = 1e-9 + np.random.rand() * 1e5
        throughputs = np.random.rand(num_peers) * scale + 1e-6
        min_size = np.random.choice([0, np.random.randint(0, vector_size // 10)])
        assignment = load_balance_peers(vector_size, throughputs, min_size)
        assert np.sum(assignment) == vector_size
        assert np.min(assignment) >= 0


@pytest.mark.forked
def test_too_few_peers():
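    """
    Four averagers start in distinct 3-bit buckets with a short averaging_expiration; each step
    is still expected to finish in a group of exactly two peers.
    """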
    dht = hivemind.DHT(start=True, endpoint='127.0.0.1:*')
    averagers = [hivemind.DecentralizedAverager(
        averaged_tensors=[torch.randn(3)], dht=dht, target_group_size=2,
        averaging_expiration=1, request_timeout=0.5,
        prefix='mygroup', initial_group_bits=bin(i)[2:].rjust(3, '0'), start=True)
        for i in range(4)]

    step_futures = [averager.step(wait=False) for averager in averagers]
    for future in step_futures:
        assert len(future.result()) == 2

    for averager in averagers:
        averager.shutdown()
    dht.shutdown()


@pytest.mark.forked
def test_overcrowded(num_peers=16):
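    """
    16 peers with identical (empty) group bits compete for groups of 2; in every round, at most
    one peer is allowed to end up without a group.
    """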
    dht = hivemind.DHT(start=True, endpoint='127.0.0.1:*')
    averagers = [hivemind.DecentralizedAverager(
        averaged_tensors=[torch.randn(3)], dht=dht, target_group_size=2,
        averaging_expiration=1, request_timeout=0.5,
        prefix='mygroup', initial_group_bits='', start=True)
        for _ in range(num_peers)]

    for t in range(5):
        step_futures = [averager.step(wait=False, timeout=5) for averager in averagers]
        assert sum(len(future.result() or []) == 2 for future in step_futures) >= len(averagers) - 1

    for averager in averagers:
        averager.shutdown()
    dht.shutdown()


@pytest.mark.forked
def test_load_state_from_peers():
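    """
    averager2 downloads state from averager1 via load_state_from_peers: get_current_state should be
    called once per request, later metadata/tensor updates should be picked up by the next call,
    and disabling allow_state_sharing should make the request return None.
    """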
    num_calls = 0
    super_metadata = dict(x=123)
    super_tensors = (torch.randn(3), torch.randint(0, 5, (3,)))

    class TestAverager(hivemind.DecentralizedAverager):
        def get_current_state(self):
            """
            Get current state and send it to a peer. Executed in the host process. Meant to be overridden.
            :returns: a tuple of (serializable_small_metadata, sequence of torch tensors)
            """
            nonlocal num_calls, super_metadata, super_tensors
            num_calls += 1
            return super_metadata, super_tensors

    dht_root = hivemind.DHT(start=True)
    initial_peers = [f'{hivemind.LOCALHOST}:{dht_root.port}']
    dht1 = hivemind.DHT(initial_peers=initial_peers, start=True)
    averager1 = TestAverager([torch.randn(3), torch.rand(5)],
                             dht=dht1, start=True,
                             prefix='demo-run', target_group_size=2)

    dht2 = hivemind.DHT(initial_peers=initial_peers, start=True)
    dht2.get('demo-run.all_averagers')
    averager2 = TestAverager([torch.randn(3), torch.rand(5)],
                             dht=dht2, start=True,
                             prefix='demo-run', target_group_size=2)

    assert num_calls == 0
    got_metadata, got_tensors = averager2.load_state_from_peers()
    assert num_calls == 1
    assert got_metadata == super_metadata
    assert all(map(torch.allclose, got_tensors, super_tensors))

    super_metadata['y'] = 123
    super_tensors[1][2] = 9
    assert num_calls == 1
    assert got_metadata != super_metadata
    assert not all(map(torch.allclose, got_tensors, super_tensors))

    got_metadata, got_tensors = averager2.load_state_from_peers()
    assert num_calls == 2
    assert got_metadata == super_metadata
    assert all(map(torch.allclose, got_tensors, super_tensors))

    averager1.allow_state_sharing = False
    assert averager2.load_state_from_peers() is None
    averager1.allow_state_sharing = True
    got_metadata, got_tensors = averager2.load_state_from_peers()
    assert num_calls == 3
    assert got_metadata == super_metadata


@pytest.mark.forked
def test_getset_bits():
    dht = hivemind.DHT(start=True, endpoint='127.0.0.1:*')
    averager = hivemind.DecentralizedAverager([torch.randn(3)], dht=dht, start=True,
                                              prefix='test_prefix', target_group_size=2)
    averager.set_group_bits('00101011101010')
    assert averager.get_group_bits() == '00101011101010'


@pytest.mark.forked
def test_training_averager(n_steps: int = 10, n_dims: int = 16):
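    """
    Two peers run Adam toward the same target; after every TrainingAverager.step, their parameters,
    gradients and "exp_avg_sq" statistics should all match the corresponding pairwise averages.
    """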
    torch.manual_seed(42)
    dht = hivemind.DHT(start=True, endpoint='127.0.0.1:*')
    common_kwargs = {'dht': dht, 'start': True, 'listen_on': '127.0.0.1:*',
                     'prefix': 'demo-run', 'target_group_size': 2}

    x1 = torch.randn(n_dims, requires_grad=True)
    opt1 = torch.optim.Adam([x1], lr=0.05)
    averager1 = hivemind.client.TrainingAverager(opt1, average_gradients=True, average_parameters=True,
                                                 average_opt_statistics=["exp_avg_sq"], **common_kwargs)

    x2 = torch.randn(n_dims, requires_grad=True)
    opt2 = torch.optim.Adam([x2], lr=0.05)
    averager2 = hivemind.client.TrainingAverager(opt2, average_gradients=True, average_parameters=True,
                                                 average_opt_statistics=["exp_avg_sq"], **common_kwargs)

    a = torch.ones(n_dims)

    for i in range(n_steps):
        opt1.zero_grad()
        opt2.zero_grad()
        (x1 - a).pow(2).sum().backward()
        (x2 - a).pow(2).sum().backward()
        opt1.step()
        opt2.step()

        with torch.no_grad():
            x_avg = 0.5 * (x1 + x2)
            grad_avg = 0.5 * (x1.grad + x2.grad)
            stats_avg = 0.5 * (opt1.state[x1]["exp_avg_sq"] + opt2.state[x2]["exp_avg_sq"])
        # use wait=False to avoid a deadlock in which averager1 holds its lock while waiting for averager2
        f1 = averager1.step(wait=False)
        f2 = averager2.step(wait=False)
        f1.result()
        f2.result()

        assert torch.allclose(x1, x_avg)
        assert torch.allclose(x2, x_avg)
        assert torch.allclose(x1.grad, grad_avg)
        assert torch.allclose(x2.grad, grad_avg)
        assert torch.allclose(opt1.state[x1]["exp_avg_sq"], stats_avg)
        assert torch.allclose(opt2.state[x2]["exp_avg_sq"], stats_avg)