test_dht_experts.py

import random
import numpy as np
import pytest
import asyncio
import multiprocessing as mp

import hivemind
from hivemind import LOCALHOST, UidEndpoint
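
# Builds an 11-peer DHT swarm, declares 110 experts in batches from one random peer,
# and checks that another random peer finds the declared experts but not bogus UIDs.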
def test_store_get_experts():
    peers = [hivemind.DHT(start=True)]
    for i in range(10):
        neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
        peers.append(hivemind.DHT(initial_peers=neighbors_i, start=True))

    you: hivemind.dht.DHT = random.choice(peers)
    theguyshetoldyounottoworryabout: hivemind.dht.DHT = random.choice(peers)

    expert_uids = [f"my_expert.{i}" for i in range(110)]
    batch_size = 10
    for batch_start in range(0, len(expert_uids), batch_size):
        you.declare_experts(expert_uids[batch_start: batch_start + batch_size], 'localhost', 1234)

    found = theguyshetoldyounottoworryabout.get_experts(random.sample(expert_uids, 5) + ['foo', 'bar'])
    assert all(res is not None for res in found[:-2]), "Could not find some existing experts"
    assert all(res is None for res in found[-2:]), "Found non-existing experts"

    that_guys_expert, that_guys_port = "my_other_expert.1337", random.randint(1000, 9999)
    theguyshetoldyounottoworryabout.declare_experts([that_guys_expert], f'that_host:{that_guys_port}')
    you_notfound, you_found = you.get_experts(['foobar', that_guys_expert])
    assert isinstance(you_found, hivemind.RemoteExpert)
    assert you_found.endpoint == f'that_host:{that_guys_port}'

    for peer in peers:
        peer.shutdown()
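
# Populates a 20-node DHT with experts placed on a 32x32x32 grid, then runs single and
# batched beam search from a fresh peer and checks the number and type of returned experts.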
def test_beam_search(dht_size=20, total_experts=128, batch_size=32, initial_peers=3, beam_size=4, parallel_rpc=256,
                     grid_dims=(32, 32, 32)):
    dht = []
    for i in range(dht_size):
        neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(dht, min(initial_peers, len(dht)))]
        dht.append(hivemind.DHT(start=True, expiration=999999, initial_peers=neighbors_i, parallel_rpc=parallel_rpc))

    real_experts = sorted({
        'expert.' + '.'.join([str(random.randint(0, dim - 1)) for dim in grid_dims])
        for _ in range(total_experts)
    })
    for batch_start in range(0, len(real_experts), batch_size):
        random.choice(dht).declare_experts(
            real_experts[batch_start: batch_start + batch_size], wait=True,
            endpoint=f"host{batch_start // batch_size}:{random.randint(0, 65536)}")

    neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(dht, min(initial_peers, len(dht)))]
    you = hivemind.DHT(start=True, expiration=999999, initial_peers=neighbors_i, parallel_rpc=parallel_rpc)

    for i in range(50):
        topk_experts = you.find_best_experts('expert.', [np.random.randn(dim) for dim in grid_dims], beam_size=beam_size)
        assert all(isinstance(e, hivemind.RemoteExpert) for e in topk_experts)
        assert len(topk_experts) == beam_size

    for i in range(10):
        batch_experts = you.batch_find_best_experts('expert.', [np.random.randn(batch_size, dim) for dim in grid_dims],
                                                    beam_size=beam_size)
        assert isinstance(batch_experts, list) and len(batch_experts) == batch_size
        assert all(isinstance(e, hivemind.RemoteExpert) for experts in batch_experts for e in experts)
        assert all(len(experts) == beam_size for experts in batch_experts)
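
# Exercises declare_experts / get_experts / find_best_experts on a single standalone node,
# including get_active_successors, get_initial_beam, and assertions on malformed prefixes.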
def test_dht_single_node():
    node = hivemind.DHT(start=True, expiration=999)

    assert all(node.declare_experts(['expert.1', 'expert.2', 'expert.3'], f"{hivemind.LOCALHOST}:1337").values())
    assert len(node.declare_experts(["ffn.1", "ffn.2"], endpoint="that_place")) == 4
    assert len(node.declare_experts(['e.1.2.3', 'e.1.2.5', 'e.2.0'], f"{hivemind.LOCALHOST}:42")) == 7
    for expert in node.get_experts(['expert.3', 'expert.2']):
        assert expert.endpoint == f"{hivemind.LOCALHOST}:1337"

    assert all(node.declare_experts(['expert.5', 'expert.2'], f"{hivemind.LOCALHOST}:1337").values())
    found_experts = node.find_best_experts('expert.', [(0., 1., 2., 3., 4., 5., 6., 7., 8.)], beam_size=2)
    assert len(found_experts) == 2 and [expert.uid for expert in found_experts] == ['expert.5', 'expert.3']

    successors = node.get_active_successors(['e.1.2.', 'e.2.', 'e.4.5.'])
    assert len(successors['e.1.2.']) == 2
    assert successors['e.1.2.'][3] == UidEndpoint('e.1.2.3', f'{LOCALHOST}:42')
    assert successors['e.1.2.'][5] == UidEndpoint('e.1.2.5', f'{LOCALHOST}:42')
    assert len(successors['e.2.']) == 1 and successors['e.2.'][0] == UidEndpoint('e.2.0', f'{LOCALHOST}:42')
    assert successors['e.4.5.'] == {}

    initial_beam = node.get_initial_beam('expert.', (3, 2, 1, 0, -1, -2, -3), beam_size=3)
    assert len(initial_beam) == 3
    assert initial_beam[0][:2] == (2.0, 'expert.1.')
    assert initial_beam[1][:2] == (1.0, 'expert.2.')
    assert initial_beam[2][:2] == (0.0, 'expert.3.')

    with pytest.raises(AssertionError):
        node.find_best_experts('expert', [(0., 1., 2., 3., 4., 5., 6., 7., 8.)], beam_size=2)

    with pytest.raises(AssertionError):
        node.find_best_experts('expert.1', [(0., 1., 2., 3., 4., 5., 6., 7., 8.)], beam_size=2)

    with pytest.raises(AssertionError):
        node.get_active_successors(['e.1.2.', 'e.2', 'e.4.5.'])

    with pytest.raises(AssertionError):
        node.get_initial_beam('expert', (3, 2, 1, 0, -1, -2, -3), beam_size=3)
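
# Checks hivemind.is_valid_uid and hivemind.is_valid_prefix against lists of well-formed
# and malformed expert UIDs and prefixes.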
def test_uid_patterns():
    valid_experts = ["expert.1", "expert.0", "expert.0.0.1", "expert.1337", "ffn.12.34.56.78.90",
                     "transformer.3.2.1.0", "transformer_encoder.2", "transformer::encoder.2", "T®@nsf0rmE®🤗.321",
                     "🤗.321", "0.1.2", "00.1.2", "7070.3.2.1.0", "block2.1.23", "LAYER.1.0.1"]
    valid_prefixes = ["expert.", "e.1.", "e.2.", "e.1.2.3.", "ololo.123.456.789.10."]
    valid_prefixes.extend([f"{uid}." for uid in valid_experts])
    valid_prefixes.extend([hivemind.split_uid(uid)[0] for uid in valid_experts])
    for uid in valid_experts:
        assert hivemind.is_valid_uid(uid), f"UID {uid} is valid, but was perceived as invalid"
    for pfx in valid_prefixes:
        assert hivemind.is_valid_prefix(pfx), f"Prefix {pfx} is valid, but was perceived as invalid"

    invalid = ["", ".", "expert.-1", "xxx.a", "expert.1x", "expert_ffn.1.abc1", "some.123.01", "expert.123.01",
               "e1", "e..1", "e", "e.1.2.3..4", "ffn.1..1", ".123", ".1.2.3.", ".expert", "transformer.encoder.2",
               "T®@nsf0rmE®.🤗.321", "layer::123", "expert.0.1.2.suffix", "0.1.2.suffix", "expert.1 something",
               "expert.1\n", "expert.1\n2", "expert.1 ", "expert.1\nexpert.2", "'expert.1'", '"expert.1"']
    invalid_experts = invalid + valid_prefixes + ["0", "123456"]
    invalid_prefixes = invalid + valid_experts + ["expert", ".🤗", ".expert"]
    for uid in invalid_experts:
        assert not hivemind.is_valid_uid(uid), f"UID {uid} is not valid, but was perceived as valid"
    for pfx in invalid_prefixes:
        assert not hivemind.is_valid_prefix(pfx), f"Prefix {pfx} is not valid, but was perceived as valid"
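
# Verifies negative caching: after a beam-search query by the negative-caching peer, "no data"
# entries for missing prefixes are cached, so a fresh DHTNode sees records for ffn.0. through
# ffn.5. but not for ffn.6. through ffn.9.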
def test_negative_caching():
    test_success = mp.Event()
    peers = []
    for i in range(10):
        neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
        peers.append(hivemind.DHT(initial_peers=neighbors_i, negative_caching=False, cache_locally=False, start=True))

    normal_peer, writer_peer = random.sample(peers, 2)

    neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
    neg_caching_peer = hivemind.DHT(initial_peers=neighbors_i, negative_caching=True, cache_locally=False, start=True)
    assert all(writer_peer.declare_experts(['ffn.1.2.3', 'ffn.3.4.5'], 'myaddr:1234').values())
    # get prefixes by the peer with negative caching. Cache "no data" entries for ffn.0.*, ffn.2.*, ffn.4.*, ffn.5.*
    assert len(neg_caching_peer.get_initial_beam(prefix='ffn.', scores=[.1, .2, .3, .4, .5, .6], beam_size=3)) == 2

    async def _tester():
        node = await hivemind.DHTNode.create(initial_peers=neighbors_i)
        fetched = await asyncio.gather(*(node.get(f'ffn.{i}.') for i in range(10)))
        for i in range(6):
            assert fetched[i] is not None, f"node should have cached ffn.{i}."
        for i in range(6, len(fetched)):
            assert fetched[i] is None, f"node shouldn't have cached ffn.{i}."
        test_success.set()

    proc = mp.Process(target=lambda: asyncio.run(_tester()))
    proc.start()
    proc.join()

    assert test_success.is_set()