# test_dht_experts.py

import asyncio
import random
import time

import numpy as np
import pytest

import hivemind
from hivemind.dht import DHTNode
from hivemind.moe.client.beam_search import MoEBeamSearcher
from hivemind.moe.server import declare_experts, get_experts
from hivemind.moe.server.expert_uid import UidEndpoint, is_valid_prefix, is_valid_uid, split_uid
from hivemind.p2p import PeerInfo


@pytest.mark.forked
def test_store_get_experts(n_peers=10):
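    """Declare experts from one DHT peer and look them up from another, before and after some peers shut down."""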
    peers = [hivemind.DHT(start=True)]
    initial_peers = peers[0].get_visible_maddrs()
    peers += [hivemind.DHT(initial_peers=initial_peers, start=True) for _ in range(n_peers - 1)]

    first_peer = random.choice(peers)
    other_peer = random.choice(peers)

    expert_uids = [f"my_expert.{i}" for i in range(50)]
    batch_size = 10
    for batch_start in range(0, len(expert_uids), batch_size):
        declare_experts(first_peer, expert_uids[batch_start : batch_start + batch_size], first_peer.peer_id)

    found = get_experts(other_peer, random.sample(expert_uids, 5) + ["foo", "bar"])
    assert all(res is not None for res in found[:-2]), "Could not find some existing experts"
    assert all(res is None for res in found[-2:]), "Found non-existing experts"

    other_expert = "my_other_expert.1337"
    declare_experts(other_peer, [other_expert], other_peer.peer_id)
    first_notfound, first_found = get_experts(first_peer, ["foobar", other_expert])
    assert isinstance(first_found, hivemind.RemoteExpert)
    assert first_found.server_peer_info.peer_id == other_peer.peer_id
    assert first_notfound is None

    # test graceful shutdown
    first_peer.shutdown()
    other_peer.shutdown()
    time.sleep(1.0)
    remaining_peer1 = random.choice([peer for peer in peers if peer.is_alive()])
    remaining_peer2 = random.choice([peer for peer in peers if peer.is_alive()])
    assert all(declare_experts(remaining_peer1, ["new_expert.1"], remaining_peer1.peer_id))
    assert get_experts(remaining_peer2, ["new_expert.1"])[0].server_peer_info.peer_id == remaining_peer1.peer_id


@pytest.mark.forked
def test_beam_search(
    n_peers=20, total_experts=128, batch_size=32, beam_size=4, parallel_rpc=4, grid_dims=(32, 32, 32)
):
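    """Declare a random subset of a grid of experts across the swarm, then check that beam search
    returns exactly beam_size RemoteExperts for both single and batched score queries."""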
    dht_instances = [hivemind.DHT(start=True)]
    initial_peers = dht_instances[0].get_visible_maddrs()
    dht_instances += [hivemind.DHT(initial_peers=initial_peers, start=True) for _ in range(n_peers - 1)]

    real_experts = sorted(
        {"expert." + ".".join([str(random.randint(0, dim - 1)) for dim in grid_dims]) for _ in range(total_experts)}
    )
    for batch_start in range(0, len(real_experts), batch_size):
        dht = random.choice(dht_instances)
        declare_experts(
            dht,
            real_experts[batch_start : batch_start + batch_size],
            peer_id=dht.peer_id,
        )

    neighbors = sum(
        [peer.get_visible_maddrs() for peer in random.sample(dht_instances, min(3, len(dht_instances)))], []
    )
    you = hivemind.DHT(start=True, initial_peers=neighbors, parallel_rpc=parallel_rpc)
    beam_search = MoEBeamSearcher(you, "expert.", grid_dims)

    for i in range(10):
        topk_experts = beam_search.find_best_experts([np.random.randn(dim) for dim in grid_dims], beam_size)
        assert all(isinstance(e, hivemind.RemoteExpert) for e in topk_experts)
        assert len(topk_experts) == beam_size

    for i in range(10):
        batch_experts = beam_search.batch_find_best_experts(
            [np.random.randn(batch_size, dim) for dim in grid_dims], beam_size=beam_size
        )
        assert isinstance(batch_experts, list) and len(batch_experts) == batch_size
        assert all(isinstance(e, hivemind.RemoteExpert) for experts in batch_experts for e in experts)
        assert all(len(experts) == beam_size for experts in batch_experts)


@pytest.mark.forked
def test_dht_single_node():
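    """Exercise declare/get, active-successor lookup, and initial beam construction on a single-node DHT,
    including the AssertionErrors raised for malformed prefixes."""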
    node = hivemind.DHT(start=True)
    beam_search = MoEBeamSearcher(node, "expert.", grid_size=(10,))

    assert all(declare_experts(node, ["expert.1", "expert.2", "expert.3"], node.peer_id).values())
    assert len(declare_experts(node, ["ffn.1", "ffn.2"], node.peer_id)) == 4
    assert len(declare_experts(node, ["e.1.2.3", "e.1.2.5", "e.2.0"], node.peer_id)) == 7

    for expert in get_experts(node, ["expert.3", "expert.2"]):
        assert expert.server_peer_info.peer_id == node.peer_id

    assert all(declare_experts(node, ["expert.5", "expert.2"], node.peer_id).values())
    found_experts = beam_search.find_best_experts([(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0)], beam_size=2)
    assert len(found_experts) == 2 and [expert.uid for expert in found_experts] == ["expert.5", "expert.3"]

    successors = beam_search.get_active_successors(["e.1.2.", "e.2.", "e.4.5."])
    assert len(successors["e.1.2."]) == 2
    peer_info = PeerInfo(node.peer_id, [a.decapsulate("/p2p/" + a.get("p2p")) for a in node.get_visible_maddrs()])
    assert successors["e.1.2."][3] == UidEndpoint("e.1.2.3", peer_info)
    assert successors["e.1.2."][5] == UidEndpoint("e.1.2.5", peer_info)
    assert len(successors["e.2."]) == 1 and successors["e.2."][0] == UidEndpoint("e.2.0", peer_info)
    assert successors["e.4.5."] == {}

    initial_beam = beam_search.get_initial_beam((3, 2, 1, 0, -1, -2, -3), beam_size=3)
    assert len(initial_beam) == 3
    assert initial_beam[0][:2] == (2.0, "expert.1.")
    assert initial_beam[1][:2] == (1.0, "expert.2.")
    assert initial_beam[2][:2] == (0.0, "expert.3.")

    with pytest.raises(AssertionError):
        beam_search = MoEBeamSearcher(node, "expert.1.ffn", (2, 2))

    with pytest.raises(AssertionError):
        beam_search.get_active_successors(["e.1.2.", "e.2", "e.4.5."])


def test_uid_patterns():
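    """Check that is_valid_uid/is_valid_prefix accept well-formed expert UIDs and prefixes and reject malformed ones."""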
    valid_experts = [
        "expert.1",
        "expert.0",
        "expert.0.0.1",
        "expert.1337",
        "ffn.12.34.56.78.90",
        "transformer.3.2.1.0",
        "transformer_encoder.2",
        "transformer::encoder.2",
        "T®@nsf0rmE®🤗.321",
        "🤗.321",
        "0.1.2",
        "00.1.2",
        "7070.3.2.1.0",
        "block2.1.23",
        "LAYER.1.0.1",
    ]
    valid_prefixes = ["expert.", "e.1.", "e.2.", "e.1.2.3.", "ololo.123.456.789.10."]
    valid_prefixes.extend([f"{uid}." for uid in valid_experts])
    valid_prefixes.extend([split_uid(uid)[0] for uid in valid_experts])

    for uid in valid_experts:
        assert is_valid_uid(uid), f"UID {uid} is valid, but was perceived as invalid"
    for pfx in valid_prefixes:
        assert is_valid_prefix(pfx), f"Prefix {pfx} is valid, but was perceived as invalid"

    invalid = [
        "",
        ".",
        "expert.-1",
        "xxx.a",
        "expert.1x",
        "expert_ffn.1.abc1",
        "some.123.01",
        "expert.123.01",
        "e1",
        "e..1",
        "e",
        "e.1.2.3..4",
        "ffn.1..1",
        ".123",
        ".1.2.3.",
        ".expert",
        "transformer.encoder.2",
        "T®@nsf0rmE®.🤗.321",
        "layer::123",
        "expert.0.1.2.suffix",
        "0.1.2.suffix",
        "expert.1 something",
        "expert.1\n",
        "expert.1\n2",
        "expert.1 ",
        "expert.1\nexpert.2",
        "'expert.1'",
        '"expert.1"',
    ]
    invalid_experts = invalid + valid_prefixes + ["0", "123456"]
    invalid_prefixes = invalid + valid_experts + ["expert", ".🤗", ".expert"]

    for uid in invalid_experts:
        assert not is_valid_uid(uid), f"UID {uid} is not valid, but was perceived as valid"
    for pfx in invalid_prefixes:
        assert not is_valid_prefix(pfx), f"Prefix {pfx} is not valid, but was perceived as valid"


@pytest.mark.forked
@pytest.mark.asyncio
async def test_negative_caching(n_peers=10):
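    """Check that a beam searcher with negative_caching=True leaves 'no data' entries for queried-but-empty prefixes."""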
    dht_kwargs = {"cache_locally": False}
    peers = [hivemind.DHT(start=True, **dht_kwargs)]
    initial_peers = peers[0].get_visible_maddrs()
    peers += [hivemind.DHT(initial_peers=initial_peers, start=True, **dht_kwargs) for _ in range(n_peers - 1)]

    writer_peer = random.choice(peers)
    assert all(declare_experts(writer_peer, ["ffn.1.2.3", "ffn.3.4.5"], writer_peer.peer_id).values())

    neighbors = sum([peer.get_visible_maddrs() for peer in random.sample(peers, min(3, len(peers)))], [])
    neg_caching_peer = hivemind.DHT(initial_peers=neighbors, start=True, **dht_kwargs)
    beam_search = MoEBeamSearcher(neg_caching_peer, uid_prefix="ffn.", grid_size=(10, 10, 10), negative_caching=True)
    # get prefixes by the peer with negative caching. Cache "no data" entries for ffn.0.*, ffn.2.*, ffn.4.*, ffn.5.*
    assert len(beam_search.get_initial_beam(scores=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], beam_size=3)) == 2
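    # a fresh DHT node joining via the same neighbors should now see an entry for every prefix queried
    # above: real declarations under ffn.1. and ffn.3., "no data" markers for ffn.0./2./4./5.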
    node = await DHTNode.create(initial_peers=neighbors)
    fetched = await asyncio.gather(*(node.get(f"ffn.{i}.") for i in range(10)))
    for i in range(6):
        assert fetched[i] is not None, f"node should have cached ffn.{i}."
    for i in range(6, len(fetched)):
        assert fetched[i] is None, f"node shouldn't have cached ffn.{i}."

    await node.shutdown()
    neg_caching_peer.shutdown()
    for peer in peers:
        peer.shutdown()