import asyncio
import random

import numpy as np
import torch
import pytest
import hivemind
from hivemind.client.averaging.allreduce import AllReduceProtocol, split_into_parts, restore_from_parts
from hivemind.client.averaging.load_balancing import load_balance_peers
from hivemind.client.averaging.key_manager import GroupKeyManager
from hivemind.utils import Endpoint


@pytest.mark.forked
@pytest.mark.asyncio
async def test_key_manager():
    key_manager = GroupKeyManager(hivemind.DHT(start=True), endpoint='localhvost',
                                  prefix='test_averaging', initial_group_bits='10110',
                                  target_group_size=2)

    t = hivemind.get_dht_time()
    key = key_manager.current_key
    await key_manager.declare_averager(key, 'localhvost', expiration_time=t + 60)
    await key_manager.declare_averager(key, 'localhvost2', expiration_time=t + 61)
    q1 = await key_manager.get_averagers(key, only_active=True)

    await key_manager.declare_averager(key, 'localhvost', expiration_time=t + 66)
    q2 = await key_manager.get_averagers(key, only_active=True)

    await key_manager.declare_averager(key, 'localhvost2', expiration_time=t + 61, looking_for_group=False)
    q3 = await key_manager.get_averagers(key, only_active=True)
    q4 = await key_manager.get_averagers(key, only_active=False)

    q5 = await key_manager.get_averagers('nonexistent_key.0b0101', only_active=False)

    assert len(q1) == 2 and ('localhvost', t + 60) in q1 and ('localhvost2', t + 61) in q1
    assert len(q2) == 2 and ('localhvost', t + 66) in q2 and ('localhvost2', t + 61) in q2
    assert len(q3) == 1 and ('localhvost', t + 66) in q3
    assert len(q4) == 2 and ('localhvost', t + 66) in q4 and ('localhvost2', t + 61) in q4
    assert len(q5) == 0


@pytest.mark.forked
def test_allreduce_once():
    dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')

    tensors1 = [torch.randn(123), torch.zeros(3)]
    tensors2 = [torch.rand(123), torch.ones(3)]
    tensors3 = [-torch.rand(123), torch.arange(3).to(torch.float32)]
    tensors4 = [torch.randn(123) ** 3, torch.arange(3).to(torch.float32) / 2]
    reference = [(tensors1[i] + tensors2[i] + tensors3[i] + tensors4[i]) / 4 for i in range(len(tensors1))]

    averagers = [hivemind.DecentralizedAverager(tensors, dht=dht, target_group_size=4, averaging_expiration=15,
                                                prefix='mygroup', listen_on='127.0.0.1:*',
                                                start=True)
                 for tensors in [tensors1, tensors2, tensors3, tensors4]]

    futures = []
    for averager in averagers:
        futures.append(averager.step(wait=False))
    for future in futures:
        result = future.result()
        for averager in averagers:
            assert averager.endpoint in result

    for averager in averagers:
        with averager.get_tensors() as averaged_tensors:
            for ref, our in zip(reference, averaged_tensors):
                assert torch.allclose(ref, our, atol=1e-6)

    for averager in averagers:
        averager.shutdown()
    dht.shutdown()


def compute_mean_std(averagers, unbiased=True):
    """ Collect each averager's local tensors and return their per-tensor means and standard deviations. """
    results = []
    for averager in averagers:
        with averager.get_tensors() as tensors:
            results.append([tensor.clone() for tensor in tensors])

    results_stacked_per_tensor = list(map(torch.stack, zip(*results)))
    means = [stack.mean(dim=0) for stack in results_stacked_per_tensor]
    stds = [stack.std(dim=0, unbiased=unbiased) for stack in results_stacked_per_tensor]
    return means, stds


@pytest.mark.forked
def test_allreduce_grid():
    dht = hivemind.DHT(start=True, endpoint='127.0.0.1:*')
    averagers = [hivemind.DecentralizedAverager(
        averaged_tensors=[torch.randn(3)], dht=dht, target_group_size=2,
        prefix='mygroup', initial_group_bits=bin(i // 2)[2:].rjust(2, '0'), start=True)
        for i in range(8)]

    [means0], [stds0] = compute_mean_std(averagers)
    assert not torch.allclose(stds0, torch.zeros_like(stds0))

    prev_means, prev_stds = means0, stds0

    for i in range(5):
        step_futures = [averager.step(wait=False) for averager in averagers]
        groups = [future.result() for future in step_futures]
        [means], [stds] = compute_mean_std(averagers)
        assert torch.allclose(means, prev_means, atol=1e-6, rtol=0)
        assert all(len(group) == 2 for group in groups)

        if i <= 2:
            assert torch.all(torch.le(stds, prev_stds))
        else:
            assert torch.allclose(stds, torch.zeros_like(stds), atol=1e-6, rtol=0)

    for averager in averagers:
        averager.shutdown()
    dht.shutdown()


@pytest.mark.forked
def test_allgather():
    dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')
    averagers = [hivemind.DecentralizedAverager([torch.ones(1)], dht=dht, target_group_size=4,
                                                averaging_expiration=15, prefix='mygroup',
                                                initial_group_bits='000', listen_on='127.0.0.1:*',
                                                start=True)
                 for _ in range(8)]
    futures = []
    for i, averager in enumerate(averagers):
        futures.append(averager.step(wait=False, gather=dict(batch_size=123 + i, foo='bar')))

    assert len(set(repr(sorted(future.result())) for future in futures)) == 2

    reference_metadata = {averager.endpoint: dict(batch_size=123 + i, foo='bar')
                          for i, averager in enumerate(averagers)}
    for future in futures:
        gathered = future.result()
        assert len(gathered) == 4
        for endpoint in gathered:
            assert gathered[endpoint] == reference_metadata[endpoint]

    for averager in averagers:
        averager.shutdown()
    dht.shutdown()


@pytest.mark.forked
@pytest.mark.asyncio
async def test_allreduce_protocol():
    """ Run the group allreduce protocol manually, without gRPC, and check that the internal logic works as intended. """
    peers = "alice", "bob", "carol"
    tensors_by_peer = {peer: [torch.randn(3, 128), torch.rand(32), torch.tensor(i, dtype=torch.float32)]
                       for i, peer in enumerate(peers)}

    group_id = random.getrandbits(160).to_bytes(length=20, byteorder='big')
    # note: part_sizes must sum to the flattened size of each peer's tensors: 3 * 128 + 32 + 1 = 417
    allreduce_protocols = [AllReduceProtocol(
        group_id=group_id, endpoint=peer, tensors=tensors_by_peer[peer],
        ordered_group_endpoints=peers, part_sizes=(150, 200, 67))
        for peer in peers]

    async def _accumulate(sender: Endpoint, recipient: Endpoint):
        sender_allreduce = allreduce_protocols[peers.index(sender)]
        recipient_allreduce = allreduce_protocols[peers.index(recipient)]
        averaged_part = await recipient_allreduce.accumulate_part(
            source=sender, remote_part=sender_allreduce.local_tensor_parts[recipient])
        sender_allreduce.register_averaged_part(source=recipient, averaged_part=averaged_part)

    await asyncio.wait({_accumulate(sender, recipient) for sender in peers for recipient in peers
                        if sender != recipient})

    reference_tensors = [
        sum(tensors_by_peer[peer][i] for peer in peers) / len(peers)
        for i in range(len(tensors_by_peer[peers[0]]))
    ]

    for peer, allreduce in zip(peers, allreduce_protocols):
        assert allreduce.future.done()
        averaged_tensors = await allreduce
        assert len(averaged_tensors) == len(reference_tensors)
        assert all(torch.allclose(our, ref, atol=1e-6, rtol=0)
                   for our, ref in zip(averaged_tensors, reference_tensors))


@pytest.mark.forked
def test_partitioning():
    for _ in range(100):
        tensors = []
        for _ in range(random.randint(1, 5)):
            ndim = random.randint(0, 4)
            shape = torch.Size([random.randint(0, 16) for _ in range(ndim)])
            make_tensor = random.choice([torch.rand, torch.randn, torch.zeros, torch.ones])
            tensors.append(make_tensor(shape))

        total_size = sum(map(torch.Tensor.numel, tensors))
        if total_size == 0:
            continue

        num_chunks = random.randint(1, min(100, sum(x.numel() for x in tensors)))
        part_sizes = load_balance_peers(total_size, [None] * num_chunks)
        chunks = split_into_parts(tensors, part_sizes)
        assert len(chunks) == num_chunks

        shapes = [tensor.shape for tensor in tensors]
        restored = restore_from_parts(chunks, shapes)
        assert len(restored) == len(tensors)
        assert all(new.shape == old.shape for new, old in zip(restored, tensors))
        assert all(torch.allclose(new, old) for new, old in zip(restored, tensors))
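

# A minimal, deterministic sketch of the same split/restore round trip, added here purely for
# illustration. It assumes split_into_parts/restore_from_parts behave exactly as they are used in
# test_partitioning above; the test name below is ours and is not part of the original suite.
def test_partitioning_small_example():
    tensor = torch.arange(6, dtype=torch.float32)
    # split a 6-element vector into two flat chunks of 2 and 4 elements
    chunks = split_into_parts([tensor], (2, 4))
    assert len(chunks) == 2
    # restoring by the original shapes should reproduce the tensor exactly
    (restored,) = restore_from_parts(chunks, [tensor.shape])
    assert torch.equal(restored, tensor)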


def get_cost(vector_size, partitions, throughputs):
    return max((vector_size - partitions[i] + (len(partitions) - 1) * partitions[i]) / max(throughputs[i], 1e-9)
               for i in range(len(partitions)))
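

# Worked example of the cost model in get_cost above (our illustration, not part of the original
# file): with vector_size=60, four equal partitions of 15 and a throughput of 0.25 per peer, each
# peer's term is (60 - 15 + 3 * 15) / 0.25 = 90 / 0.25 = 360, so the cost (the maximum over peers)
# is 360, i.e. get_cost(60, [15, 15, 15, 15], [0.25] * 4) == 360.0.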


def check_optimality(vector_size, throughputs, ref_partitions):
    partitions = list(load_balance_peers(vector_size, throughputs))
    assert get_cost(vector_size, partitions, throughputs) <= get_cost(vector_size, ref_partitions, throughputs)


@pytest.mark.forked
def test_load_balancing():
    check_optimality(60, np.array([0.25, 0.25, 0.25, 0.25]), [15, 15, 15, 15])
    check_optimality(1024, np.array([0.3, 0.5, 0.9]), [0, 255, 769])
    check_optimality(60, np.array([0.44, 0.33, 0.22]), [42, 18, 0])
    check_optimality(60, np.array([0.55, 0.44, 0.40]), [35, 16, 9])
    check_optimality(1024 * 1024, np.array([0.3, 0.5, 0.9, 0.6]), [0, 169327, 602629, 276620])
    check_optimality(1024 * 1024, np.array([0.0, 0.5, 0.0, 0.6]), [0, 428963, 0, 619613])

    assert load_balance_peers(60, np.array([0.55, 0.44, 0.40]), min_size=10) == (41, 19, 0)
    assert load_balance_peers(60, np.array([0.32, 0.55, 0.44]), min_size=10) == (0, 40, 20)
    assert load_balance_peers(2, np.array([0.55, 0.20, 0.44]), min_size=10) == (1, 0, 1)
    assert load_balance_peers(1, np.array([0.55, 0.20, 0.44]), min_size=10) == (1, 0, 0)

    assert load_balance_peers(100, (None, None)) == (50, 50)
    assert load_balance_peers(100, (None, None, None, None, None)) == (20, 20, 20, 20, 20)
    assert load_balance_peers(100, (0, 0, 0, None, None)) == (0, 0, 0, 50, 50)

    with pytest.raises(AssertionError):
        load_balance_peers(100, (0, 0, 0))

    for _ in range(10):
        vector_size = np.random.randint(1, 1024 ** 3)
        num_peers = np.random.randint(1, 256)
        scale = 1e-9 + np.random.rand() * 1e5
        throughputs = np.random.rand(num_peers) * scale + 1e-6
        min_size = np.random.choice([0, np.random.randint(0, vector_size // 10)])
        assignment = load_balance_peers(vector_size, throughputs, min_size)
        assert np.sum(assignment) == vector_size
        assert np.min(assignment) >= 0


@pytest.mark.forked
def test_too_few_peers():
    dht = hivemind.DHT(start=True, endpoint='127.0.0.1:*')
    averagers = [hivemind.DecentralizedAverager(
        averaged_tensors=[torch.randn(3)], dht=dht, target_group_size=2,
        averaging_expiration=1, request_timeout=0.5,
        prefix='mygroup', initial_group_bits=bin(i)[2:].rjust(3, '0'), start=True)
        for i in range(4)]

    step_futures = [averager.step(wait=False) for averager in averagers]
    for future in step_futures:
        assert len(future.result()) == 2

    for averager in averagers:
        averager.shutdown()
    dht.shutdown()


@pytest.mark.forked
def test_overcrowded():
    dht = hivemind.DHT(start=True, endpoint='127.0.0.1:*')
    averagers = [hivemind.DecentralizedAverager(
        averaged_tensors=[torch.randn(3)], dht=dht, target_group_size=2,
        averaging_expiration=1, request_timeout=0.5,
        prefix='mygroup', initial_group_bits='', start=True)
        for _ in range(32)]

    for _ in range(5):
        step_futures = [averager.step(wait=False, timeout=5) for averager in averagers]
        assert sum(len(future.result() or []) == 2 for future in step_futures) >= len(averagers) - 1

    for averager in averagers:
        averager.shutdown()
    dht.shutdown()


@pytest.mark.forked
def test_load_state_from_peers():
    num_calls = 0
    super_metadata = dict(x=123)
    super_tensors = (torch.randn(3), torch.randint(0, 5, (3,)))

    class TestAverager(hivemind.DecentralizedAverager):
        def get_current_state(self):
            """
            Get the current state and send it to a peer. Executed in the host process. Meant to be overridden.
            :returns: a tuple of (serializable_small_metadata, sequence of torch tensors)
            """
            nonlocal num_calls, super_metadata, super_tensors
            num_calls += 1
            return self.serializer.dumps(super_metadata), super_tensors

    dht_root = hivemind.DHT(start=True)
    initial_peers = [f'{hivemind.LOCALHOST}:{dht_root.port}']
    dht1 = hivemind.DHT(initial_peers=initial_peers, start=True)
    averager1 = TestAverager([torch.randn(3), torch.rand(5)],
                             dht=dht1, start=True,
                             prefix='demo-run', target_group_size=2)

    dht2 = hivemind.DHT(initial_peers=initial_peers, start=True)
    dht2.get('demo-run.all_averagers')
    averager2 = TestAverager([torch.randn(3), torch.rand(5)],
                             dht=dht2, start=True,
                             prefix='demo-run', target_group_size=2)

    assert num_calls == 0
    got_metadata, got_tensors = averager2.load_state_from_peers()
    assert num_calls == 1
    assert got_metadata == super_metadata
    assert all(map(torch.allclose, got_tensors, super_tensors))

    super_metadata['y'] = 123
    super_tensors[1][2] = 9
    assert num_calls == 1
    assert got_metadata != super_metadata
    assert not all(map(torch.allclose, got_tensors, super_tensors))

    got_metadata, got_tensors = averager2.load_state_from_peers()
    assert num_calls == 2
    assert got_metadata == super_metadata
    assert all(map(torch.allclose, got_tensors, super_tensors))


@pytest.mark.forked
def test_getset_bits():
    dht = hivemind.DHT(start=True, endpoint='127.0.0.1:*')
    averager = hivemind.DecentralizedAverager([torch.randn(3)], dht=dht, start=True,
                                              prefix='test_prefix', target_group_size=2)
    averager.set_group_bits('00101011101010')
    assert averager.get_group_bits() == '00101011101010'