test_dht_experts.py

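"""Integration tests for storing and discovering hivemind MoE experts via the DHT.

Covers declaring/fetching expert endpoints across a swarm of DHT peers, grid beam search
with MoEBeamSearcher (multi-peer, batched, and single-node cases), expert UID and prefix
validation, and negative caching of prefixes that hold no experts.
"""
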
import asyncio
import random
import time

import numpy as np
import pytest

import hivemind
from hivemind import LOCALHOST
from hivemind.dht import DHTNode
from hivemind.moe.client.beam_search import MoEBeamSearcher
from hivemind.moe.server import declare_experts, get_experts
from hivemind.moe.server.expert_uid import UidEndpoint, is_valid_uid, is_valid_prefix, split_uid


@pytest.mark.forked
def test_store_get_experts(n_peers=10):
    # spawn a small swarm of DHT peers, all bootstrapped from the first one
    peers = [hivemind.DHT(start=True)]
    initial_peers = peers[0].get_visible_maddrs()
    peers += [hivemind.DHT(initial_peers=initial_peers, start=True) for _ in range(n_peers - 1)]

    first_peer = random.choice(peers)
    other_peer = random.choice(peers)

    # declare experts from one peer and look them up from another
    expert_uids = [f"my_expert.{i}" for i in range(50)]
    batch_size = 10
    for batch_start in range(0, len(expert_uids), batch_size):
        declare_experts(first_peer, expert_uids[batch_start : batch_start + batch_size], "localhost:1234")

    found = get_experts(other_peer, random.sample(expert_uids, 5) + ["foo", "bar"])
    assert all(res is not None for res in found[:-2]), "Could not find some existing experts"
    assert all(res is None for res in found[-2:]), "Found non-existing experts"

    other_expert, other_port = "my_other_expert.1337", random.randint(1000, 9999)
    declare_experts(other_peer, [other_expert], f"that_host:{other_port}")
    first_notfound, first_found = get_experts(first_peer, ["foobar", other_expert])
    assert isinstance(first_found, hivemind.RemoteExpert)
    assert first_found.endpoint == f"that_host:{other_port}"

    # test graceful shutdown
    first_peer.shutdown()
    other_peer.shutdown()
    time.sleep(1.0)
    remaining_peer1 = random.choice([peer for peer in peers if peer.is_alive()])
    remaining_peer2 = random.choice([peer for peer in peers if peer.is_alive()])
    assert all(declare_experts(remaining_peer1, ["new_expert.1"], "dummy").values())
    assert get_experts(remaining_peer2, ["new_expert.1"])[0].endpoint == "dummy"


@pytest.mark.forked
def test_beam_search(
    n_peers=20, total_experts=128, batch_size=32, beam_size=4, parallel_rpc=4, grid_dims=(32, 32, 32)
):
    dht = [hivemind.DHT(start=True)]
    initial_peers = dht[0].get_visible_maddrs()
    dht += [hivemind.DHT(initial_peers=initial_peers, start=True) for _ in range(n_peers - 1)]

    # declare experts with random grid coordinates in batches, each batch from a random peer
    real_experts = sorted(
        {"expert." + ".".join([str(random.randint(0, dim - 1)) for dim in grid_dims]) for _ in range(total_experts)}
    )
    for batch_start in range(0, len(real_experts), batch_size):
        declare_experts(
            random.choice(dht),
            real_experts[batch_start : batch_start + batch_size],
            wait=True,
            endpoint=f"host{batch_start // batch_size}:{random.randint(0, 65536)}",
        )

    # a fresh peer joins the swarm and runs beam search over the expert grid
    neighbors = sum([peer.get_visible_maddrs() for peer in random.sample(dht, min(3, len(dht)))], [])
    you = hivemind.DHT(start=True, initial_peers=neighbors, parallel_rpc=parallel_rpc)
    beam_search = MoEBeamSearcher(you, "expert.", grid_dims)

    for i in range(10):
        topk_experts = beam_search.find_best_experts([np.random.randn(dim) for dim in grid_dims], beam_size)
        assert all(isinstance(e, hivemind.RemoteExpert) for e in topk_experts)
        assert len(topk_experts) == beam_size

    for i in range(10):
        batch_experts = beam_search.batch_find_best_experts(
            [np.random.randn(batch_size, dim) for dim in grid_dims], beam_size=beam_size
        )
        assert isinstance(batch_experts, list) and len(batch_experts) == batch_size
        assert all(isinstance(e, hivemind.RemoteExpert) for experts in batch_experts for e in experts)
        assert all(len(experts) == beam_size for experts in batch_experts)


@pytest.mark.forked
def test_dht_single_node():
    node = hivemind.DHT(start=True)
    beam_search = MoEBeamSearcher(node, "expert.", grid_size=(10,))

    assert all(declare_experts(node, ["expert.1", "expert.2", "expert.3"], f"{hivemind.LOCALHOST}:1337").values())
    assert len(declare_experts(node, ["ffn.1", "ffn.2"], endpoint="that_place")) == 4
    assert len(declare_experts(node, ["e.1.2.3", "e.1.2.5", "e.2.0"], f"{hivemind.LOCALHOST}:42")) == 7

    for expert in get_experts(node, ["expert.3", "expert.2"]):
        assert expert.endpoint == f"{hivemind.LOCALHOST}:1337"

    assert all(declare_experts(node, ["expert.5", "expert.2"], f"{hivemind.LOCALHOST}:1337").values())
    found_experts = beam_search.find_best_experts([(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0)], beam_size=2)
    assert len(found_experts) == 2 and [expert.uid for expert in found_experts] == ["expert.5", "expert.3"]

    successors = beam_search.get_active_successors(["e.1.2.", "e.2.", "e.4.5."])
    assert len(successors["e.1.2."]) == 2
    assert successors["e.1.2."][3] == UidEndpoint("e.1.2.3", f"{LOCALHOST}:42")
    assert successors["e.1.2."][5] == UidEndpoint("e.1.2.5", f"{LOCALHOST}:42")
    assert len(successors["e.2."]) == 1 and successors["e.2."][0] == UidEndpoint("e.2.0", f"{LOCALHOST}:42")
    assert successors["e.4.5."] == {}

    initial_beam = beam_search.get_initial_beam((3, 2, 1, 0, -1, -2, -3), beam_size=3)
    assert len(initial_beam) == 3
    assert initial_beam[0][:2] == (2.0, "expert.1.")
    assert initial_beam[1][:2] == (1.0, "expert.2.")
    assert initial_beam[2][:2] == (0.0, "expert.3.")

    # malformed prefixes (e.g. missing the trailing delimiter) are rejected with AssertionError
    with pytest.raises(AssertionError):
        beam_search = MoEBeamSearcher(node, "expert.1.ffn", (2, 2))

    with pytest.raises(AssertionError):
        beam_search.get_active_successors(["e.1.2.", "e.2", "e.4.5."])


def test_uid_patterns():
    valid_experts = [
        "expert.1",
        "expert.0",
        "expert.0.0.1",
        "expert.1337",
        "ffn.12.34.56.78.90",
        "transformer.3.2.1.0",
        "transformer_encoder.2",
        "transformer::encoder.2",
        "T®@nsf0rmE®🤗.321",
        "🤗.321",
        "0.1.2",
        "00.1.2",
        "7070.3.2.1.0",
        "block2.1.23",
        "LAYER.1.0.1",
    ]
    valid_prefixes = ["expert.", "e.1.", "e.2.", "e.1.2.3.", "ololo.123.456.789.10."]
    valid_prefixes.extend([f"{uid}." for uid in valid_experts])
    valid_prefixes.extend([split_uid(uid)[0] for uid in valid_experts])

    for uid in valid_experts:
        assert is_valid_uid(uid), f"UID {uid} is valid, but was perceived as invalid"
    for pfx in valid_prefixes:
        assert is_valid_prefix(pfx), f"Prefix {pfx} is valid, but was perceived as invalid"

    invalid = [
        "",
        ".",
        "expert.-1",
        "xxx.a",
        "expert.1x",
        "expert_ffn.1.abc1",
        "some.123.01",
        "expert.123.01",
        "e1",
        "e..1",
        "e",
        "e.1.2.3..4",
        "ffn.1..1",
        ".123",
        ".1.2.3.",
        ".expert",
        "transformer.encoder.2",
        "T®@nsf0rmE®.🤗.321",
        "layer::123",
        "expert.0.1.2.suffix",
        "0.1.2.suffix",
        "expert.1 something",
        "expert.1\n",
        "expert.1\n2",
        "expert.1 ",
        "expert.1\nexpert.2",
        "'expert.1'",
        '"expert.1"',
    ]
    invalid_experts = invalid + valid_prefixes + ["0", "123456"]
    invalid_prefixes = invalid + valid_experts + ["expert", ".🤗", ".expert"]

    for uid in invalid_experts:
        assert not is_valid_uid(uid), f"UID {uid} is not valid, but was perceived as valid"
    for pfx in invalid_prefixes:
        assert not is_valid_prefix(pfx), f"Prefix {pfx} is not valid, but was perceived as valid"


@pytest.mark.forked
@pytest.mark.asyncio
async def test_negative_caching(n_peers=10):
    dht_kwargs = {"cache_locally": False}

    peers = [hivemind.DHT(start=True, **dht_kwargs)]
    initial_peers = peers[0].get_visible_maddrs()
    peers += [hivemind.DHT(initial_peers=initial_peers, start=True, **dht_kwargs) for _ in range(n_peers - 1)]

    writer_peer = random.choice(peers)
    assert all(declare_experts(writer_peer, ["ffn.1.2.3", "ffn.3.4.5"], "myaddr:1234").values())

    neighbors = sum([peer.get_visible_maddrs() for peer in random.sample(peers, min(3, len(peers)))], [])
    neg_caching_peer = hivemind.DHT(initial_peers=neighbors, start=True, **dht_kwargs)
    beam_search = MoEBeamSearcher(neg_caching_peer, uid_prefix="ffn.", grid_size=(10, 10, 10), negative_caching=True)
    # fetch prefixes through the peer with negative caching enabled; only ffn.1.* and ffn.3.* exist,
    # so this also caches "no data" entries for ffn.0.*, ffn.2.*, ffn.4.*, and ffn.5.*
    assert len(beam_search.get_initial_beam(scores=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], beam_size=3)) == 2

    node = await DHTNode.create(initial_peers=neighbors)
    fetched = await asyncio.gather(*(node.get(f"ffn.{i}.") for i in range(10)))
    for i in range(6):
        assert fetched[i] is not None, f"node should have cached ffn.{i}."
    for i in range(6, len(fetched)):
        assert fetched[i] is None, f"node shouldn't have cached ffn.{i}."

    await node.shutdown()
    neg_caching_peer.shutdown()
    for peer in peers:
        peer.shutdown()
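

# Note: the @pytest.mark.forked and @pytest.mark.asyncio marks above rely on the
# pytest-forked and pytest-asyncio plugins; a typical invocation (exact path is an
# assumption) would be: pytest test_dht_experts.py -v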