# test_averaging.py

import asyncio
import random
import time

import torch
import pytest

import hivemind
from hivemind.client.averaging.allreduce import AllReduceProtocol, split_into_parts, restore_from_parts
from hivemind.utils import Endpoint
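

# Check that declare_averager() entries can be read back via get_averagers(), that re-declaring an endpoint
# updates its expiration time, and that only_active=True hides peers that are no longer looking for a group.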
@pytest.mark.forked
def test_getset_averagers():
    dht = hivemind.DHT(start=True)

    t = hivemind.get_dht_time()
    dht.declare_averager(group_key='bucket.0b10110', endpoint='localhvost', expiration_time=t + 60)
    dht.declare_averager(group_key='bucket.0b10110', endpoint='localhvost2', expiration_time=t + 61)
    q1 = dht.get_averagers('bucket.0b10110', only_active=True)

    dht.declare_averager(group_key='bucket.0b10110', endpoint='localhvost', expiration_time=t + 66)
    q2 = dht.get_averagers('bucket.0b10110', only_active=True)

    dht.declare_averager(group_key='bucket.0b10110', endpoint='localhvost2', looking_for_group=False,
                         expiration_time=t + 61)
    q3 = dht.get_averagers('bucket.0b10110', only_active=True)
    q4 = dht.get_averagers('bucket.0b10110', only_active=False)

    assert len(q1) == 2 and ('localhvost', t + 60) in q1 and ('localhvost2', t + 61) in q1
    assert len(q2) == 2 and ('localhvost', t + 66) in q2 and ('localhvost2', t + 61) in q2
    assert len(q3) == 1 and ('localhvost', t + 66) in q3
    assert len(q4) == 2 and ('localhvost', t + 66) in q4 and ('localhvost2', t + 61) in q4
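

# End-to-end check: four DecentralizedAverager instances form a single group of target_group_size=4
# and each peer's tensors converge to the elementwise mean across all four peers.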
@pytest.mark.forked
def test_allreduce_once():
    dht = hivemind.DHT(start=True)

    tensors1 = [torch.randn(123), torch.zeros(3)]
    tensors2 = [torch.rand(123), torch.ones(3)]
    tensors3 = [-torch.rand(123), torch.arange(3).to(torch.float32)]
    tensors4 = [torch.randn(123) ** 3, torch.arange(3).to(torch.float32) / 2]
    reference = [(tensors1[i] + tensors2[i] + tensors3[i] + tensors4[i]) / 4 for i in range(len(tensors1))]

    averagers = [hivemind.DecentralizedAverager(tensors, dht=dht, target_group_size=4, averaging_expiration=15,
                                                prefix='mygroup', initial_group_bits='0110', listen_on='127.0.0.1:*',
                                                start=True)
                 for tensors in [tensors1, tensors2, tensors3, tensors4]]

    futures = []
    for averager in averagers:
        futures.append(averager.step(wait=False))
    for future in futures:
        assert future.result() is True

    for averager in averagers:
        with averager.get_tensors() as averaged_tensors:
            for ref, our in zip(reference, averaged_tensors):
                assert torch.allclose(ref, our, atol=1e-6)
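

# Exercise the protocol internals directly: every peer sends its local tensor part to every other peer,
# receives the averaged part back, and all peers must end up with identical averaged tensors.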
@pytest.mark.forked
@pytest.mark.asyncio
async def test_allreduce_protocol():
    """ Run group allreduce protocol manually without grpc, see if the internal logic is working as intended """
    peers = "alice", "bob", "carol"
    tensors_by_peer = {peer: [torch.randn(3, 128), torch.rand(32), torch.tensor(i, dtype=torch.float32)]
                       for i, peer in enumerate(peers)}

    group_id = random.getrandbits(160).to_bytes(length=20, byteorder='big')
    allreduce_protocols = [AllReduceProtocol(
        group_id=group_id, endpoint=peer, tensors=tensors_by_peer[peer], ordered_group_endpoints=peers)
        for peer in peers]

    async def _accumulate(sender: Endpoint, recipient: Endpoint):
        sender_allreduce = allreduce_protocols[peers.index(sender)]
        recipient_allreduce = allreduce_protocols[peers.index(recipient)]
        averaged_part = await recipient_allreduce.accumulate_part(
            source=sender, remote_part=sender_allreduce.local_tensor_parts[recipient])
        sender_allreduce.register_averaged_part(source=recipient, averaged_part=averaged_part)

    await asyncio.wait({_accumulate(sender, recipient) for sender in peers for recipient in peers
                        if sender != recipient})

    reference_tensors = [
        sum(tensors_by_peer[peer][i] for peer in peers) / len(peers)
        for i in range(len(tensors_by_peer[peers[0]]))
    ]

    for peer, allreduce in zip(peers, allreduce_protocols):
        assert allreduce.future.done()
        averaged_tensors = await allreduce
        assert len(averaged_tensors) == len(reference_tensors)
        assert all(torch.allclose(our, ref, atol=1e-6, rtol=0)
                   for our, ref in zip(averaged_tensors, reference_tensors))
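

# Round-trip check for tensor partitioning: split_into_parts followed by restore_from_parts
# must reproduce the original tensors and their shapes for randomly sized inputs.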
@pytest.mark.forked
def test_partitioning():
    for _ in range(100):
        tensors = []
        for _ in range(random.randint(1, 5)):
            ndim = random.randint(0, 4)
            shape = torch.Size([random.randint(0, 16) for _ in range(ndim)])
            make_tensor = random.choice([torch.rand, torch.randn, torch.zeros, torch.ones])
            tensors.append(make_tensor(shape))

        total_size = sum(map(torch.Tensor.numel, tensors))
        if total_size == 0:
            continue

        num_chunks = random.randint(1, min(1000, sum(x.numel() for x in tensors)))
        chunks = split_into_parts(tensors, group_size=num_chunks)
        assert len(chunks) == num_chunks

        shapes = [tensor.shape for tensor in tensors]
        restored = restore_from_parts(chunks, shapes)
        assert len(restored) == len(tensors)
        assert all(new.shape == old.shape for new, old in zip(restored, tensors))
        assert all(torch.allclose(new, old) for new, old in zip(restored, tensors))