test_dht_experts.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. import random
  2. import uuid
  3. from itertools import chain
  4. import numpy as np
  5. import hivemind
  6. from hivemind import LOCALHOST
  7. def test_store_get_experts():
  8. peers = [hivemind.DHT(start=True)]
  9. for i in range(10):
  10. neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
  11. peers.append(hivemind.DHT(initial_peers=neighbors_i, start=True))
  12. you: hivemind.dht.DHT = random.choice(peers)
  13. theguyshetoldyounottoworryabout: hivemind.dht.DHT = random.choice(peers)
  14. expert_uids = [str(uuid.uuid4()) for _ in range(110)]
  15. batch_size = 10
  16. for batch_start in range(0, len(expert_uids), batch_size):
  17. you.declare_experts(expert_uids[batch_start: batch_start + batch_size], 'localhost', 1234)
  18. found = theguyshetoldyounottoworryabout.get_experts(random.sample(expert_uids, 5) + ['foo', 'bar'])
  19. assert all(res is not None for res in found[:-2]), "Could not find some existing experts"
  20. assert all(res is None for res in found[-2:]), "Found non-existing experts"
  21. that_guys_expert, that_guys_port = str(uuid.uuid4()), random.randint(1000, 9999)
  22. theguyshetoldyounottoworryabout.declare_experts([that_guys_expert], f'that_host:{that_guys_port}')
  23. you_notfound, you_found = you.get_experts(['foobar', that_guys_expert])
  24. assert isinstance(you_found, hivemind.RemoteExpert)
  25. assert you_found.endpoint == f'that_host:{that_guys_port}'
  26. # test first_k_active
  27. assert list(theguyshetoldyounottoworryabout.first_k_active(expert_uids, k=10)) == expert_uids[:10]
  28. some_permuted_experts = random.sample(expert_uids, k=32)
  29. assert list(theguyshetoldyounottoworryabout.first_k_active(some_permuted_experts, k=32)) == some_permuted_experts
  30. assert list(theguyshetoldyounottoworryabout.first_k_active(some_permuted_experts, k=1)) == some_permuted_experts[:1]
  31. fake_and_real_experts = list(chain(*zip(
  32. [str(uuid.uuid4()) for _ in some_permuted_experts], some_permuted_experts)))
  33. assert list(theguyshetoldyounottoworryabout.first_k_active(fake_and_real_experts, k=9)) == some_permuted_experts[:9]
  34. for peer in peers:
  35. peer.shutdown()
  36. def test_beam_search(dht_size=20, total_experts=128, batch_size=32, initial_peers=3, beam_size=4, parallel_rpc=256,
  37. grid_dims=(32, 32, 32)):
  38. dht = []
  39. for i in range(dht_size):
  40. neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(dht, min(initial_peers, len(dht)))]
  41. dht.append(hivemind.DHT(start=True, expiration=999999, initial_peers=neighbors_i, parallel_rpc=parallel_rpc))
  42. real_experts = sorted({
  43. 'expert.' + '.'.join([str(random.randint(0, dim - 1)) for dim in grid_dims])
  44. for _ in range(total_experts)
  45. })
  46. for batch_start in range(0, len(real_experts), batch_size):
  47. random.choice(dht).declare_experts(
  48. real_experts[batch_start: batch_start + batch_size], wait=True,
  49. endpoint=f"host{batch_start // batch_size}:{random.randint(0, 65536)}")
  50. neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(dht, min(initial_peers, len(dht)))]
  51. you = hivemind.DHT(start=True, expiration=999999, initial_peers=neighbors_i, parallel_rpc=parallel_rpc)
  52. for i in range(50):
  53. topk_experts = you.find_best_experts('expert', [np.random.randn(dim) for dim in grid_dims], beam_size=beam_size)
  54. assert all(isinstance(e, hivemind.RemoteExpert) for e in topk_experts)
  55. assert len(topk_experts) == beam_size
  56. for i in range(10):
  57. batch_experts = you.batch_find_best_experts('expert', [np.random.randn(batch_size, dim) for dim in grid_dims],
  58. beam_size=beam_size)
  59. assert isinstance(batch_experts, list) and len(batch_experts) == batch_size
  60. assert all(isinstance(e, hivemind.RemoteExpert) for experts in batch_experts for e in experts)
  61. assert all(len(experts) == beam_size for experts in batch_experts)
  62. def test_first_k_active():
  63. node = hivemind.DHT(start=True)
  64. assert all(node.declare_experts(['e.1.2.3', 'e.1.2.4', 'e.3.4.5'], endpoint=f"{hivemind.LOCALHOST}:1337"))
  65. assert all(node.declare_experts(['e.2.1.1'], endpoint=f"{hivemind.LOCALHOST}:1338"))
  66. results = node.first_k_active(['e.0', 'e.1', 'e.2', 'e.3'], k=2)
  67. assert len(results) == 2 and next(iter(results.keys())) == 'e.1'
  68. assert results['e.1'].uid in ('e.1.2.3', 'e.1.2.4') and results['e.1'].endpoint == f"{hivemind.LOCALHOST}:1337"
  69. assert results['e.2'].uid == 'e.2.1.1' and results['e.2'].endpoint == f"{hivemind.LOCALHOST}:1338"
  70. results = node.first_k_active(['e', 'e.1', 'e.1.2', 'e.1.2.3'], k=10)
  71. assert len(results) == 4
  72. assert 'e' in results
  73. for k in ('e.1', 'e.1.2', 'e.1.2.3'):
  74. assert results[k].uid in ('e.1.2.3', 'e.1.2.4') and results[k].endpoint == f"{hivemind.LOCALHOST}:1337"
  75. def test_dht_single_node():
  76. node = hivemind.DHT(start=True)
  77. assert node.first_k_active(['e.3', 'e.2'], k=3) == {}
  78. assert node.get_experts(['e.3', 'e.2']) == [None, None]
  79. assert all(node.declare_experts(['e.1', 'e.2', 'e.3'], f"{hivemind.LOCALHOST}:1337"))
  80. for expert in node.get_experts(['e.3', 'e.2']):
  81. assert expert.endpoint == f"{hivemind.LOCALHOST}:1337"
  82. active_found = node.first_k_active(['e.0', 'e.1', 'e.3', 'e.5', 'e.2'], k=2)
  83. assert list(active_found.keys()) == ['e.1', 'e.3']
  84. assert all(expert.uid.startswith(prefix) for prefix, expert in active_found.items())
  85. assert all(node.declare_experts(['e.1', 'e.2', 'e.3'], f"{hivemind.LOCALHOST}:1337"))
  86. assert node.find_best_experts('e', [(0., 1., 2., 3., 4., 5., 6., 7., 8.)], beam_size=4)