test_dht_experts.py
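
"""Tests for declaring and looking up mixture-of-experts servers via the hivemind DHT.

Covered here: storing and fetching expert entries (declare_experts / get_experts),
grid-based beam search over declared experts (MoEBeamSearcher), expert UID and prefix
validation, and negative caching of prefixes that hold no experts.

Note: the @pytest.mark.forked and @pytest.mark.asyncio markers assume the pytest-forked
and pytest-asyncio plugins are available.
"""
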
import asyncio
import random
import time

import numpy as np
import pytest

import hivemind
from hivemind.dht import DHTNode
from hivemind.moe.client.beam_search import MoEBeamSearcher
from hivemind.moe.expert_uid import ExpertInfo, is_valid_prefix, is_valid_uid, split_uid
from hivemind.moe.server import declare_experts, get_experts


@pytest.mark.forked
def test_store_get_experts(n_peers=10):
    """Declare experts on one DHT peer, fetch them from another, and check that lookups survive peer shutdown."""
    peers = [hivemind.DHT(start=True)]
    initial_peers = peers[0].get_visible_maddrs()
    peers += [hivemind.DHT(initial_peers=initial_peers, start=True) for _ in range(n_peers - 1)]

    first_peer = random.choice(peers)
    other_peer = random.choice(peers)

    expert_uids = [f"my_expert.{i}" for i in range(50)]
    batch_size = 10
    for batch_start in range(0, len(expert_uids), batch_size):
        declare_experts(first_peer, expert_uids[batch_start : batch_start + batch_size])

    found = get_experts(other_peer, random.sample(expert_uids, 5) + ["foo", "bar"])
    assert all(res is not None for res in found[:-2]), "Could not find some existing experts"
    assert all(res is None for res in found[-2:]), "Found non-existing experts"

    other_expert = "my_other_expert.1337"
    declare_experts(other_peer, [other_expert])
    first_notfound, first_found = get_experts(first_peer, ["foobar", other_expert])
    assert isinstance(first_found, hivemind.RemoteExpert)
    assert first_found.peer_id == other_peer.peer_id
    assert first_notfound is None

    # test graceful shutdown
    first_peer.shutdown()
    other_peer.shutdown()
    time.sleep(1.0)
    remaining_peer1 = random.choice([peer for peer in peers if peer.is_alive()])
    remaining_peer2 = random.choice([peer for peer in peers if peer.is_alive()])
    assert all(declare_experts(remaining_peer1, ["new_expert.1"]))
    assert get_experts(remaining_peer2, ["new_expert.1"])[0].peer_id == remaining_peer1.peer_id


@pytest.mark.forked
def test_beam_search(
    n_peers=20, total_experts=128, batch_size=32, beam_size=4, parallel_rpc=4, grid_dims=(32, 32, 32)
):
    """Declare a grid of experts across a DHT swarm and retrieve the best ones via MoE beam search."""
    dht_instances = [hivemind.DHT(start=True)]
    initial_peers = dht_instances[0].get_visible_maddrs()
    dht_instances += [hivemind.DHT(initial_peers=initial_peers, start=True) for _ in range(n_peers - 1)]

    real_experts = sorted(
        {"expert." + ".".join([str(random.randint(0, dim - 1)) for dim in grid_dims]) for _ in range(total_experts)}
    )
    for batch_start in range(0, len(real_experts), batch_size):
        dht = random.choice(dht_instances)
        declare_experts(dht, real_experts[batch_start : batch_start + batch_size])

    neighbors = sum(
        [peer.get_visible_maddrs() for peer in random.sample(dht_instances, min(3, len(dht_instances)))], []
    )
    you = hivemind.DHT(start=True, initial_peers=neighbors, parallel_rpc=parallel_rpc)
    beam_search = MoEBeamSearcher(you, "expert.", grid_dims)

    for _ in range(10):
        topk_experts = beam_search.find_best_experts([np.random.randn(dim) for dim in grid_dims], beam_size)
        assert all(isinstance(e, hivemind.RemoteExpert) for e in topk_experts)
        assert len(topk_experts) == beam_size

    for _ in range(10):
        batch_experts = beam_search.batch_find_best_experts(
            [np.random.randn(batch_size, dim) for dim in grid_dims], beam_size=beam_size
        )
        assert isinstance(batch_experts, list) and len(batch_experts) == batch_size
        assert all(isinstance(e, hivemind.RemoteExpert) for experts in batch_experts for e in experts)
        assert all(len(experts) == beam_size for experts in batch_experts)


@pytest.mark.forked
def test_dht_single_node():
    """Run declare/get, successor lookup, and beam search against a single-node DHT."""
    node = hivemind.DHT(start=True)
    beam_search = MoEBeamSearcher(node, "expert.", grid_size=(10,))

    assert all(declare_experts(node, ["expert.1", "expert.2", "expert.3"]).values())
    assert len(declare_experts(node, ["ffn.1", "ffn.2"])) == 4
    assert len(declare_experts(node, ["e.1.2.3", "e.1.2.5", "e.2.0"])) == 7

    for expert in get_experts(node, ["expert.3", "expert.2"]):
        assert expert.peer_id == node.peer_id

    assert all(declare_experts(node, ["expert.5", "expert.2"]).values())
    found_experts = beam_search.find_best_experts([(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0)], beam_size=2)
    assert len(found_experts) == 2 and [expert.uid for expert in found_experts] == ["expert.5", "expert.3"]

    successors = beam_search.get_active_successors(["e.1.2.", "e.2.", "e.4.5."])
    assert len(successors["e.1.2."]) == 2
    assert successors["e.1.2."][3] == ExpertInfo("e.1.2.3", node.peer_id)
    assert successors["e.1.2."][5] == ExpertInfo("e.1.2.5", node.peer_id)
    assert len(successors["e.2."]) == 1 and successors["e.2."][0] == ExpertInfo("e.2.0", node.peer_id)
    assert successors["e.4.5."] == {}

    initial_beam = beam_search.get_initial_beam((3, 2, 1, 0, -1, -2, -3), beam_size=3)
    assert len(initial_beam) == 3
    assert initial_beam[0][:2] == (2.0, "expert.1.")
    assert initial_beam[1][:2] == (1.0, "expert.2.")
    assert initial_beam[2][:2] == (0.0, "expert.3.")

    with pytest.raises(AssertionError):
        beam_search = MoEBeamSearcher(node, "expert.1.ffn", (2, 2))

    with pytest.raises(AssertionError):
        beam_search.get_active_successors(["e.1.2.", "e.2", "e.4.5."])


def test_uid_patterns():
    valid_experts = [
        "expert.1",
        "expert.0",
        "expert.0.0.1",
        "expert.1337",
        "ffn.12.34.56.78.90",
        "transformer.3.2.1.0",
        "transformer_encoder.2",
        "transformer::encoder.2",
        "T®@nsf0rmE®🤗.321",
        "🤗.321",
        "0.1.2",
        "00.1.2",
        "7070.3.2.1.0",
        "block2.1.23",
        "LAYER.1.0.1",
    ]
    valid_prefixes = ["expert.", "e.1.", "e.2.", "e.1.2.3.", "ololo.123.456.789.10."]
    valid_prefixes.extend([f"{uid}." for uid in valid_experts])
    valid_prefixes.extend([split_uid(uid)[0] for uid in valid_experts])

    for uid in valid_experts:
        assert is_valid_uid(uid), f"UID {uid} is valid, but was perceived as invalid"
    for pfx in valid_prefixes:
        assert is_valid_prefix(pfx), f"Prefix {pfx} is valid, but was perceived as invalid"

    invalid = [
        "",
        ".",
        "expert.-1",
        "xxx.a",
        "expert.1x",
        "expert_ffn.1.abc1",
        "some.123.01",
        "expert.123.01",
        "e1",
        "e..1",
        "e",
        "e.1.2.3..4",
        "ffn.1..1",
        ".123",
        ".1.2.3.",
        ".expert",
        "transformer.encoder.2",
        "T®@nsf0rmE®.🤗.321",
        "layer::123",
        "expert.0.1.2.suffix",
        "0.1.2.suffix",
        "expert.1 something",
        "expert.1\n",
        "expert.1\n2",
        "expert.1 ",
        "expert.1\nexpert.2",
        "'expert.1'",
        '"expert.1"',
    ]
    invalid_experts = invalid + valid_prefixes + ["0", "123456"]
    invalid_prefixes = invalid + valid_experts + ["expert", ".🤗", ".expert"]

    for uid in invalid_experts:
        assert not is_valid_uid(uid), f"UID {uid} is not valid, but was perceived as valid"
    for pfx in invalid_prefixes:
        assert not is_valid_prefix(pfx), f"Prefix {pfx} is not valid, but was perceived as valid"


@pytest.mark.forked
@pytest.mark.asyncio
async def test_negative_caching(n_peers=10):
    """Check that a negative-caching peer stores "no data" entries for expert prefixes that turn out to be empty."""
    dht_kwargs = {"cache_locally": False}

    peers = [hivemind.DHT(start=True, **dht_kwargs)]
    initial_peers = peers[0].get_visible_maddrs()
    peers += [hivemind.DHT(initial_peers=initial_peers, start=True, **dht_kwargs) for _ in range(n_peers - 1)]

    writer_peer = random.choice(peers)
    assert all(declare_experts(writer_peer, ["ffn.1.2.3", "ffn.3.4.5"]).values())

    neighbors = sum([peer.get_visible_maddrs() for peer in random.sample(peers, min(3, len(peers)))], [])
    neg_caching_peer = hivemind.DHT(initial_peers=neighbors, start=True, **dht_kwargs)
    beam_search = MoEBeamSearcher(neg_caching_peer, uid_prefix="ffn.", grid_size=(10, 10, 10), negative_caching=True)
    # fetch prefixes through the negative-caching peer; this caches "no data" entries for ffn.0.*, ffn.2.*, ffn.4.*, ffn.5.*
    assert len(beam_search.get_initial_beam(scores=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], beam_size=3)) == 2

    node = await DHTNode.create(initial_peers=neighbors)
    fetched = await asyncio.gather(*(node.get(f"ffn.{i}.") for i in range(10)))
    for i in range(6):
        assert fetched[i] is not None, f"node should have cached ffn.{i}."
    for i in range(6, len(fetched)):
        assert fetched[i] is None, f"node shouldn't have cached ffn.{i}."

    await node.shutdown()
    neg_caching_peer.shutdown()
    for peer in peers:
        peer.shutdown()
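

# To run just this file locally (a sketch: the plugin package names are inferred from the
# markers above, and the path depends on where the file sits in your checkout):
#   pip install pytest pytest-forked pytest-asyncio hivemind
#   pytest test_dht_experts.py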