test_routing.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. import random
  2. import heapq
  3. import operator
  4. from itertools import chain, zip_longest
  5. from hivemind import LOCALHOST
  6. from hivemind.dht.routing import RoutingTable, DHTID
  7. from hivemind.utils.serializer import PickleSerializer
  8. def test_ids_basic():
  9. # basic functionality tests
  10. for i in range(100):
  11. id1, id2 = DHTID.generate(), DHTID.generate()
  12. assert DHTID.MIN <= id1 < DHTID.MAX and DHTID.MIN <= id2 <= DHTID.MAX
  13. assert DHTID.xor_distance(id1, id1) == DHTID.xor_distance(id2, id2) == 0
  14. assert DHTID.xor_distance(id1, id2) > 0 or (id1 == id2)
  15. assert len(PickleSerializer.dumps(id1)) - len(PickleSerializer.dumps(int(id1))) < 40
  16. assert DHTID.from_bytes(bytes(id1)) == id1 and DHTID.from_bytes(id2.to_bytes()) == id2
  17. def test_ids_depth():
  18. for i in range(100):
  19. ids = [random.randint(0, 4096) for i in range(random.randint(1, 256))]
  20. ours = DHTID.longest_common_prefix_length(*map(DHTID, ids))
  21. ids_bitstr = [
  22. "".join(bin(bite)[2:].rjust(8, '0') for bite in uid.to_bytes(20, 'big'))
  23. for uid in ids
  24. ]
  25. reference = len(shared_prefix(*ids_bitstr))
  26. assert reference == ours, f"ours {ours} != reference {reference}, ids: {ids}"
  27. def test_routing_table_basic():
  28. node_id = DHTID.generate()
  29. routing_table = RoutingTable(node_id, bucket_size=20, depth_modulo=5)
  30. for phony_neighbor_port in random.sample(range(10000), 100):
  31. phony_id = DHTID.generate()
  32. routing_table.add_or_update_node(phony_id, f'{LOCALHOST}:{phony_neighbor_port}')
  33. assert routing_table[phony_id] == f'{LOCALHOST}:{phony_neighbor_port}'
  34. assert routing_table.buckets[0].lower == DHTID.MIN and routing_table.buckets[-1].upper == DHTID.MAX
  35. for bucket in routing_table.buckets:
  36. assert len(bucket.replacement_nodes) == 0, "There should be no replacement nodes in a table with 100 entries"
  37. assert 3 <= len(routing_table.buckets) <= 10, len(routing_table.buckets)
  38. def test_routing_table_parameters():
  39. for (bucket_size, modulo, min_nbuckets, max_nbuckets) in [
  40. (20, 5, 45, 65),
  41. (50, 5, 35, 45),
  42. (20, 10, 650, 800),
  43. (20, 1, 7, 15),
  44. ]:
  45. node_id = DHTID.generate()
  46. routing_table = RoutingTable(node_id, bucket_size=bucket_size, depth_modulo=modulo)
  47. for phony_neighbor_port in random.sample(range(1_000_000), 10_000):
  48. routing_table.add_or_update_node(DHTID.generate(), f'{LOCALHOST}:{phony_neighbor_port}')
  49. for bucket in routing_table.buckets:
  50. assert len(bucket.replacement_nodes) == 0 or len(bucket.nodes_to_addr) <= bucket.size
  51. assert min_nbuckets <= len(routing_table.buckets) <= max_nbuckets, (
  52. f"Unexpected number of buckets: {min_nbuckets} <= {len(routing_table.buckets)} <= {max_nbuckets}")
  53. def test_routing_table_search():
  54. for table_size, lower_active, upper_active in [
  55. (10, 10, 10), (10_000, 800, 1100)
  56. ]:
  57. node_id = DHTID.generate()
  58. routing_table = RoutingTable(node_id, bucket_size=20, depth_modulo=5)
  59. num_added = 0
  60. total_nodes = 0
  61. for phony_neighbor_port in random.sample(range(1_000_000), table_size):
  62. routing_table.add_or_update_node(DHTID.generate(), f'{LOCALHOST}:{phony_neighbor_port}')
  63. new_total = sum(len(bucket.nodes_to_addr) for bucket in routing_table.buckets)
  64. num_added += new_total > total_nodes
  65. total_nodes = new_total
  66. num_replacements = sum(len(bucket.replacement_nodes) for bucket in routing_table.buckets)
  67. all_active_neighbors = list(chain(
  68. *(bucket.nodes_to_addr.keys() for bucket in routing_table.buckets)
  69. ))
  70. assert lower_active <= len(all_active_neighbors) <= upper_active
  71. assert len(all_active_neighbors) == num_added
  72. assert num_added + num_replacements == table_size
  73. # random queries
  74. for i in range(500):
  75. k = random.randint(1, 100)
  76. query_id = DHTID.generate()
  77. exclude = query_id if random.random() < 0.5 else None
  78. our_knn, our_addrs = zip(*routing_table.get_nearest_neighbors(query_id, k=k, exclude=exclude))
  79. reference_knn = heapq.nsmallest(k, all_active_neighbors, key=query_id.xor_distance)
  80. assert all(our == ref for our, ref in zip_longest(our_knn, reference_knn))
  81. assert all(our_addr == routing_table[our_node] for our_node, our_addr in zip(our_knn, our_addrs))
  82. # queries from table
  83. for i in range(500):
  84. k = random.randint(1, 100)
  85. query_id = random.choice(all_active_neighbors)
  86. our_knn, our_addrs = zip(*routing_table.get_nearest_neighbors(query_id, k=k, exclude=query_id))
  87. reference_knn = heapq.nsmallest(
  88. k + 1, all_active_neighbors,
  89. key=lambda uid: query_id.xor_distance(uid))
  90. if query_id in reference_knn:
  91. reference_knn.remove(query_id)
  92. assert len(our_knn) == len(reference_knn)
  93. assert all(query_id.xor_distance(our) == query_id.xor_distance(ref)
  94. for our, ref in zip_longest(our_knn, reference_knn))
  95. assert routing_table.get_nearest_neighbors(query_id, k=k, exclude=None)[0][0] == query_id
  96. def shared_prefix(*strings: str):
  97. for i in range(min(map(len, strings))):
  98. if len(set(map(operator.itemgetter(i), strings))) != 1:
  99. return strings[0][:i]
  100. return min(strings, key=len)