# test_averaging.py
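"""Tests for hivemind's decentralized averaging: DHT-based averager discovery, group all-reduce
via DecentralizedAverager and AllReduceProtocol, tensor partitioning, and peer load balancing."""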
import asyncio
import random

import numpy as np
import torch
import pytest

import hivemind
from hivemind.client.averaging.allreduce import AllReduceProtocol, split_into_parts, restore_from_parts
from hivemind.client.averaging.load_balancing import load_balance_peers
from hivemind.utils import Endpoint
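

# test_getset_averagers: declare_averager entries should be returned by get_averagers with their latest
# expiration time, and only_active=True should hide peers that are no longer looking for a group.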
@pytest.mark.forked
def test_getset_averagers():
    dht = hivemind.DHT(start=True)
    t = hivemind.get_dht_time()
    dht.declare_averager(group_key='bucket.0b10110', endpoint='localhvost', expiration_time=t + 60)
    dht.declare_averager(group_key='bucket.0b10110', endpoint='localhvost2', expiration_time=t + 61)
    q1 = dht.get_averagers('bucket.0b10110', only_active=True)
    dht.declare_averager(group_key='bucket.0b10110', endpoint='localhvost', expiration_time=t + 66)
    q2 = dht.get_averagers('bucket.0b10110', only_active=True)
    dht.declare_averager(group_key='bucket.0b10110', endpoint='localhvost2', looking_for_group=False,
                         expiration_time=t + 61)
    q3 = dht.get_averagers('bucket.0b10110', only_active=True)
    q4 = dht.get_averagers('bucket.0b10110', only_active=False)
    assert len(q1) == 2 and ('localhvost', t + 60) in q1 and ('localhvost2', t + 61) in q1
    assert len(q2) == 2 and ('localhvost', t + 66) in q2 and ('localhvost2', t + 61) in q2
    assert len(q3) == 1 and ('localhvost', t + 66) in q3
    assert len(q4) == 2 and ('localhvost', t + 66) in q4 and ('localhvost2', t + 61) in q4
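

# test_allreduce_once: four averagers sharing the same prefix and group bits should form a single group of
# target_group_size=4, and each should end up with the elementwise mean of all four tensor sets.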
@pytest.mark.forked
def test_allreduce_once():
    dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')

    tensors1 = [torch.randn(123), torch.zeros(3)]
    tensors2 = [torch.rand(123), torch.ones(3)]
    tensors3 = [-torch.rand(123), torch.arange(3).to(torch.float32)]
    tensors4 = [torch.randn(123) ** 3, torch.arange(3).to(torch.float32) / 2]
    reference = [(tensors1[i] + tensors2[i] + tensors3[i] + tensors4[i]) / 4 for i in range(len(tensors1))]

    averagers = [hivemind.DecentralizedAverager(tensors, dht=dht, target_group_size=4, averaging_expiration=15,
                                                prefix='mygroup', initial_group_bits='0110', listen_on='127.0.0.1:*',
                                                start=True)
                 for tensors in [tensors1, tensors2, tensors3, tensors4]]

    futures = []
    for averager in averagers:
        futures.append(averager.step(wait=False))
    for future in futures:
        result = future.result()
        for averager in averagers:
            assert averager.endpoint in result

    for averager in averagers:
        with averager.get_tensors() as averaged_tensors:
            for ref, our in zip(reference, averaged_tensors):
                assert torch.allclose(ref, our, atol=1e-6)
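

# test_allgather: eight averagers with target_group_size=4 should split into two groups; step(gather=...)
# should return the gathered metadata of every peer in the caller's group.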
@pytest.mark.forked
def test_allgather():
    dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')
    averagers = [hivemind.DecentralizedAverager(torch.ones(1), dht=dht, target_group_size=4, averaging_expiration=15,
                                                prefix='mygroup', initial_group_bits='000', listen_on='127.0.0.1:*',
                                                start=True)
                 for _ in range(8)]

    futures = []
    for i, averager in enumerate(averagers):
        futures.append(averager.step(wait=False, gather=dict(batch_size=123 + i, foo='bar')))

    assert len(set(repr(sorted(future.result())) for future in futures)) == 2

    reference_metadata = {averager.endpoint: dict(batch_size=123 + i, foo='bar')
                          for i, averager in enumerate(averagers)}
    for future in futures:
        gathered = future.result()
        assert len(gathered) == 4
        for endpoint in gathered:
            assert gathered[endpoint] == reference_metadata[endpoint]


@pytest.mark.forked
@pytest.mark.asyncio
async def test_allreduce_protocol():
    """ Run group allreduce protocol manually without grpc, see if the internal logic is working as intended """
    peers = "alice", "bob", "carol"
    tensors_by_peer = {peer: [torch.randn(3, 128), torch.rand(32), torch.tensor(i, dtype=torch.float32)]
                       for i, peer in enumerate(peers)}

    group_id = random.getrandbits(160).to_bytes(length=20, byteorder='big')
    allreduce_protocols = [AllReduceProtocol(
        group_id=group_id, endpoint=peer, tensors=tensors_by_peer[peer],
        ordered_group_endpoints=peers, part_sizes=(150, 200, 67))
        for peer in peers]

    async def _accumulate(sender: Endpoint, recipient: Endpoint):
        sender_allreduce = allreduce_protocols[peers.index(sender)]
        recipient_allreduce = allreduce_protocols[peers.index(recipient)]
        averaged_part = await recipient_allreduce.accumulate_part(
            source=sender, remote_part=sender_allreduce.local_tensor_parts[recipient])
        sender_allreduce.register_averaged_part(source=recipient, averaged_part=averaged_part)

    await asyncio.wait({_accumulate(sender, recipient) for sender in peers for recipient in peers
                        if sender != recipient})

    reference_tensors = [
        sum(tensors_by_peer[peer][i] for peer in peers) / len(peers)
        for i in range(len(tensors_by_peer[peers[0]]))
    ]

    for peer, allreduce in zip(peers, allreduce_protocols):
        assert allreduce.future.done()
        averaged_tensors = await allreduce
        assert len(averaged_tensors) == len(reference_tensors)
        assert all(torch.allclose(our, ref, atol=1e-6, rtol=0)
                   for our, ref in zip(averaged_tensors, reference_tensors))
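

# test_partitioning: splitting random tensors into flat chunks with split_into_parts and reassembling
# them with restore_from_parts should reproduce the original shapes and values exactly.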
@pytest.mark.forked
def test_partitioning():
    for _ in range(100):
        tensors = []
        for _ in range(random.randint(1, 5)):
            ndim = random.randint(0, 4)
            shape = torch.Size([random.randint(0, 16) for _ in range(ndim)])
            make_tensor = random.choice([torch.rand, torch.randn, torch.zeros, torch.ones])
            tensors.append(make_tensor(shape))

        total_size = sum(map(torch.Tensor.numel, tensors))
        if total_size == 0:
            continue
        num_chunks = random.randint(1, min(1000, sum(x.numel() for x in tensors)))
        part_sizes = load_balance_peers(total_size, [None] * num_chunks)
        chunks = split_into_parts(tensors, part_sizes)
        assert len(chunks) == num_chunks

        shapes = [tensor.shape for tensor in tensors]
        restored = restore_from_parts(chunks, shapes)
        assert len(restored) == len(tensors)
        assert all(new.shape == old.shape for new, old in zip(restored, tensors))
        assert all(torch.allclose(new, old) for new, old in zip(restored, tensors))
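

# get_cost models per-peer communication time: (vector_size - partitions[i]) + (len(partitions) - 1) * partitions[i]
# units of traffic divided by that peer's throughput, with the group only as fast as its slowest peer.
# check_optimality verifies that load_balance_peers is at least as good as a hand-picked reference split.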
def get_cost(vector_size, partitions, throughputs):
    return max((vector_size - partitions[i] + (len(partitions) - 1) * partitions[i]) / max(throughputs[i], 1e-9)
               for i in range(len(partitions)))


def check_optimality(vector_size, throughputs, ref_partitions):
    partitions = list(load_balance_peers(vector_size, throughputs))
    assert get_cost(vector_size, partitions, throughputs) <= get_cost(vector_size, ref_partitions, throughputs)
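

# test_load_balancing: load_balance_peers should match or beat the reference splits, respect min_size,
# give zero-throughput peers nothing (raising if nobody can take load), split evenly across peers with
# throughput None, and always return non-negative parts that sum to vector_size.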
@pytest.mark.forked
def test_load_balancing():
    check_optimality(60, np.array([0.25, 0.25, 0.25, 0.25]), [15, 15, 15, 15])
    check_optimality(1024, np.array([0.3, 0.5, 0.9]), [0, 255, 769])
    check_optimality(60, np.array([0.44, 0.33, 0.22]), [42, 18, 0])
    check_optimality(60, np.array([0.55, 0.44, 0.40]), [35, 16, 9])
    check_optimality(1024 * 1024, np.array([0.3, 0.5, 0.9, 0.6]), [0, 169327, 602629, 276620])
    check_optimality(1024 * 1024, np.array([0.0, 0.5, 0.0, 0.6]), [0, 428963, 0, 619613])

    assert load_balance_peers(60, np.array([0.55, 0.44, 0.40]), min_size=10) == (41, 19, 0)
    assert load_balance_peers(60, np.array([0.32, 0.55, 0.44]), min_size=10) == (0, 40, 20)
    assert load_balance_peers(2, np.array([0.55, 0.20, 0.44]), min_size=10) == (1, 0, 1)
    assert load_balance_peers(1, np.array([0.55, 0.20, 0.44]), min_size=10) == (1, 0, 0)

    assert load_balance_peers(100, (None, None)) == (50, 50)
    assert load_balance_peers(100, (None, None, None, None, None)) == (20, 20, 20, 20, 20)
    assert load_balance_peers(100, (0, 0, 0, None, None)) == (0, 0, 0, 50, 50)

    with pytest.raises(AssertionError):
        load_balance_peers(100, (0, 0, 0))

    for _ in range(10):
        vector_size = np.random.randint(1, 1024 ** 3)
        num_peers = np.random.randint(1, 256)
        scale = 1e-9 + np.random.rand() * 1e5
        throughputs = np.random.rand(num_peers) * scale + 1e-6
        min_size = np.random.choice([0, np.random.randint(0, vector_size // 10)])
        assignment = load_balance_peers(vector_size, throughputs, min_size)
        assert np.sum(assignment) == vector_size
        assert np.min(assignment) >= 0