import asyncio
import random
import time
from itertools import product

import torch
import pytest

import hivemind
from hivemind.client.allreduce import GroupAllReduce, split_into_parts, restore_from_parts
from hivemind.utils import LOCALHOST


@pytest.mark.forked
@pytest.mark.asyncio
async def test_allreduce_direct():
    # WARNING! this test uses an early interface that will change by the time DecentralizedAverager is finished
    dht = hivemind.DHT(start=True)

    tensors1 = [torch.randn(123), torch.zeros(3)]
    tensors2 = [torch.rand(123), torch.ones(3)]
    tensors3 = [-torch.rand(123), torch.arange(3).to(torch.float32)]
    reference = [(tensors1[i] + tensors2[i] + tensors3[i]) / 3 for i in range(len(tensors1))]
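
    # all three averagers share the same DHT; averager1 acts as the group leader below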
    averager1 = hivemind.DecentralizedAverager(tensors1, dht=dht, start=True, max_size=3, timeout=5)
    averager2 = hivemind.DecentralizedAverager(tensors2, dht=dht, start=True, max_size=3, timeout=5)
    averager3 = hivemind.DecentralizedAverager(tensors3, dht=dht, start=True, max_size=3, timeout=5)
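
    # averager1 starts a new group (leader_endpoint=None); the short sleep presumably gives the leader
    # time to set up before the followers request to join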
    future1 = averager1.group_allreduce(my_endpoint=f"{LOCALHOST}:{averager1.port}",
                                        leader_endpoint=None, return_future=True)
    time.sleep(0.1)

    future2 = averager2.group_allreduce(my_endpoint=f"{LOCALHOST}:{averager2.port}",
                                        leader_endpoint=f"{LOCALHOST}:{averager1.port}",
                                        return_future=True)
    future3 = averager3.group_allreduce(my_endpoint=f"{LOCALHOST}:{averager3.port}",
                                        leader_endpoint=f"{LOCALHOST}:{averager1.port}",
                                        return_future=True)
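
    # every participant should receive the same elementwise average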
    for future in future1, future2, future3:
        for ref, our in zip(reference, await future):
            assert torch.allclose(ref, our)


@pytest.mark.forked
@pytest.mark.asyncio
async def test_allreduce_protocol():
    """ Run the group allreduce protocol manually without gRPC and check that the internal logic works as intended """
    peers = "alice", "bob", "carol"
    expiration_offsets = 4, 0, 1

    tensors_by_peer = {peer: [torch.randn(3, 128), torch.rand(32), torch.tensor(i, dtype=torch.float32)]
                       for i, peer in enumerate(peers)}

    alice, bob, carol = allreduce_protocols = [
        GroupAllReduce(endpoint=peer, expiration=hivemind.get_dht_time() + offset, tensors=tensors_by_peer[peer])
        for peer, offset in zip(peers, expiration_offsets)]
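
    # bob acts as the group leader: it starts a group and admits alice and carol as followers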
    bob.start_new_group()
    bob.add_peer_to_group(alice.info.endpoint)
    alice.join_group(bob, bob.group_id)
    bob.add_peer_to_group(carol.info.endpoint)
    carol.join_group(bob, bob.group_id)
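
    # once the group is full, the leader fixes the final (ordered) list of member endpoints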
    bob.leader_begin_allreduce()
    ordered_group_endpoints = await bob.assembled_group
    assert len(ordered_group_endpoints) == len(peers)
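
    # followers are handed the same ordered endpoints and start their side of the all-reduce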
    carol.follower_begin_allreduce(ordered_group_endpoints)
    alice.follower_begin_allreduce(ordered_group_endpoints)
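
    # each peer splits its local tensors into one part per group member;
    # chunks_by_peer[sender][recipient] is the part that sender contributes to recipient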
    chunks_by_peer = {protocol.info.endpoint: {
        peer: part for peer, part in zip(peers, split_into_parts(protocol.local_tensors, len(ordered_group_endpoints)))
    } for protocol in allreduce_protocols}
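
    # deliver all parts in random order: the accumulated result must not depend on arrival order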
    all_pairs = list(product(allreduce_protocols, peers))
    random.shuffle(all_pairs)
    await asyncio.gather(*(
        peer_allreduce.accumulate(source_peer, chunks_by_peer[source_peer][peer_allreduce.info.endpoint])
        for peer_allreduce, source_peer in all_pairs))
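
    # once every protocol has accumulated all parts, stitch the averaged parts back into full tensors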
    averaged_parts = await asyncio.gather(*(protocol.averaged_part for protocol in allreduce_protocols))
    tensor_shapes = [tensor.shape for tensor in alice.local_tensors]
    averaged_tensors = restore_from_parts(averaged_parts, tensor_shapes)

    reference_tensors = [
        sum(tensors_by_peer[peer][i] for peer in peers) / len(peers)
        for i in range(len(tensors_by_peer[peers[0]]))
    ]

    assert len(averaged_tensors) == len(reference_tensors)
    assert all(torch.allclose(our, ref, atol=1e-6, rtol=0)
               for our, ref in zip(averaged_tensors, reference_tensors))


@pytest.mark.forked
def test_chunks():
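    """ split_into_parts and restore_from_parts should round-trip arbitrary lists of tensors losslessly """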
    for _ in range(100):
        tensors = []
        for _ in range(random.randint(1, 5)):
            ndim = random.randint(0, 4)
            shape = torch.Size([random.randint(0, 16) for _ in range(ndim)])
            make_tensor = random.choice([torch.rand, torch.randn, torch.zeros, torch.ones])
            tensors.append(make_tensor(shape))
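
        # skip the degenerate case where every sampled tensor is empty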
        total_size = sum(map(torch.Tensor.numel, tensors))
        if total_size == 0:
            continue

        num_chunks = random.randint(1, min(1000, total_size))
        chunks = split_into_parts(tensors, group_size=num_chunks)
        assert len(chunks) == num_chunks

        shapes = [tensor.shape for tensor in tensors]
        restored = restore_from_parts(chunks, shapes)
        assert len(restored) == len(tensors)
        assert all(new.shape == old.shape for new, old in zip(restored, tensors))
        assert all(torch.allclose(new, old) for new, old in zip(restored, tensors))