
Fix issue 71: DHT.declare_experts hangs without peers (#72)

* add test for single-node dht (reported by unconst)

* minor: support visited_ids=None

* minor: support visited_ids=None

* handle exclude_self better

* typo

* finalize, add self to routing table

* rollback

* rollback

* rollback

* wip

* finalize
justheuristic 5 years ago
commit c3cbbdb8c5
3 changed files with 48 additions and 24 deletions
  1. hivemind/dht/node.py      +37 -21
  2. hivemind/dht/traverse.py  +3 -3
  3. tests/test_dht.py         +8 -0
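
For context, a minimal reproduction of the reported hang, mirroring the new test added in tests/test_dht.py below (a sketch that relies only on the entry points the test itself uses):

    import hivemind

    # Before this fix, declare_experts on a DHT with zero peers would hang:
    # the node searched for storage targets but never considered itself.
    node = hivemind.DHT(start=True)
    assert all(node.declare_experts(['e1', 'e2'], hivemind.LOCALHOST, 1337))  # used to hang here
    node.shutdown()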

+ 37 - 21
hivemind/dht/node.py

@@ -168,18 +168,18 @@ class DHTNode:
                 output[query] = list(peers.keys()), False  # False means "do not interrupt search"
             return output
 
-        nearest_nodes, visited_nodes = await traverse_dht(
+        nearest_nodes_per_query, visited_nodes = await traverse_dht(
             queries, initial_nodes=list(node_to_endpoint), beam_size=beam_size, num_workers=num_workers,
             queries_per_call=int(len(queries) ** 0.5), get_neighbors=get_neighbors,
             visited_nodes={query: {self.node_id} for query in queries}, **kwargs)
 
-        nearest_nodes_per_query = {}
-        for query, nearest_nodes in nearest_nodes.items():
+        nearest_nodes_with_endpoints = {}
+        for query, nearest_nodes in nearest_nodes_per_query.items():
             if not exclude_self:
                 nearest_nodes = sorted(nearest_nodes + [self.node_id], key=query.xor_distance)
                 node_to_endpoint[self.node_id] = f"{LOCALHOST}:{self.port}"
-            nearest_nodes_per_query[query] = {node: node_to_endpoint[node] for node in nearest_nodes[:k_nearest]}
-        return nearest_nodes_per_query
+            nearest_nodes_with_endpoints[query] = {node: node_to_endpoint[node] for node in nearest_nodes[:k_nearest]}
+        return nearest_nodes_with_endpoints
 
     async def store(self, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration, **kwargs) -> bool:
         """
@@ -233,28 +233,44 @@ class DHTNode:
             """ This will be called once per key when find_nearest_nodes is done for a particular node """
             # note: we use callbacks instead of returned values to call store immediately without waiting for stragglers
             assert key_id in unfinished_key_ids, "Internal error: traverse_dht finished the same query twice"
+            assert self.node_id not in nearest_nodes
             unfinished_key_ids.remove(key_id)
 
-            # ensure k nodes and (optionally) exclude self
-            nearest_nodes = [node_id for node_id in nearest_nodes if (not exclude_self or node_id != self.node_id)]
-            store_args = [key_id], [binary_values_by_key_id[key_id]], [expiration_by_key_id[key_id]]
-            store_tasks = {asyncio.create_task(self.protocol.call_store(node_to_endpoint[nearest_node_id], *store_args))
-                           for nearest_node_id in nearest_nodes[:self.num_replicas]}
-            backup_nodes = nearest_nodes[self.num_replicas:]  # used in case previous nodes didn't respond
-
-            # parse responses and issue additional stores if someone fails
-            while store_tasks:
-                finished_store_tasks, store_tasks = await asyncio.wait(store_tasks, return_when=asyncio.FIRST_COMPLETED)
-                for task in finished_store_tasks:
-                    if task.result()[0]:  # if store succeeded
+            # ensure num_replicas nodes stored the value; optionally consider self.node_id as a candidate
+            num_successful_stores = 0
+            pending_store_tasks = set()
+            store_candidates = sorted(nearest_nodes + ([] if exclude_self else [self.node_id]),
+                                      key=key_id.xor_distance, reverse=True)  # ordered so that .pop() returns nearest
+
+            while num_successful_stores < self.num_replicas and (store_candidates or pending_store_tasks):
+                # spawn enough tasks to cover all replicas
+                while store_candidates and num_successful_stores + len(pending_store_tasks) < self.num_replicas:
+                    node_id: DHTID = store_candidates.pop()  # nearest untried candidate
+                    if node_id == self.node_id:
+                        self.protocol.storage.store(key_id, binary_values_by_key_id[key_id],
+                                                    expiration_by_key_id[key_id])
                         store_ok[id_to_original_key[key_id]] = True
+                        num_successful_stores += 1
                         if not await_all_replicas:
                             store_finished_events[id_to_original_key[key_id]].set()
-                    elif backup_nodes:
-                        store_tasks.add(asyncio.create_task(
-                            self.protocol.call_store(node_to_endpoint[backup_nodes.pop(0)], *store_args)))
 
-                store_finished_events[id_to_original_key[key_id]].set()
+                    else:
+                        pending_store_tasks.add(asyncio.create_task(self.protocol.call_store(
+                            node_to_endpoint[node_id], [key_id], [binary_values_by_key_id[key_id]],
+                            [expiration_by_key_id[key_id]])))
+
+                # await the first task to finish; if it failed, the next iteration dispatches a replacement
+                if pending_store_tasks:
+                    finished_store_tasks, pending_store_tasks = await asyncio.wait(
+                        pending_store_tasks, return_when=asyncio.FIRST_COMPLETED)
+                    for task in finished_store_tasks:
+                        if task.result()[0]:  # if store succeeded
+                            store_ok[id_to_original_key[key_id]] = True
+                            num_successful_stores += 1
+                            if not await_all_replicas:
+                                store_finished_events[id_to_original_key[key_id]].set()
+
+            store_finished_events[id_to_original_key[key_id]].set()
 
         asyncio.create_task(self.find_nearest_nodes(
             queries=set(key_ids), k_nearest=self.num_replicas, node_to_endpoint=node_to_endpoint,
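
Distilled from the hunk above, the replication strategy is: keep at most num_replicas store attempts in flight and, whenever one fails, pop the next-nearest untried candidate. A self-contained sketch of that pattern (store_with_replicas and try_store are illustrative names, not hivemind API):

    import asyncio
    from typing import Awaitable, Callable, List

    async def store_with_replicas(candidates: List[str], num_replicas: int,
                                  try_store: Callable[[str], Awaitable[bool]]) -> int:
        """Run up to num_replicas concurrent store attempts over nearest-first candidates."""
        candidates = list(reversed(candidates))  # reversed so .pop() yields the nearest
        successes, pending = 0, set()
        while successes < num_replicas and (candidates or pending):
            # spawn enough tasks to cover all remaining replicas
            while candidates and successes + len(pending) < num_replicas:
                pending.add(asyncio.create_task(try_store(candidates.pop())))
            # react to the earliest result instead of waiting for stragglers
            done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
            successes += sum(task.result() for task in done)
        return successes

Awaiting with FIRST_COMPLETED is what lets a failed replica be replaced immediately; it is the same reason the surrounding code reports results through callbacks rather than waiting for every straggler.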

+ 3 - 3
hivemind/dht/traverse.py

@@ -105,13 +105,13 @@ async def traverse_dht(
         visited nodes: { query -> a set of all nodes that received requests for a given query }
     """
     if len(queries) == 0:
-        return {}, dict(visited_nodes)
+        return {}, dict(visited_nodes or {})
 
     unfinished_queries = set(queries)  # all queries that haven't triggered finish_search yet
     candidate_nodes: Dict[DHTID, List[Tuple[int, DHTID]]] = {}  # heap: unvisited nodes, ordered nearest-to-farthest
     nearest_nodes: Dict[DHTID, List[Tuple[int, DHTID]]] = {}  # heap: top-k nearest nodes, farthest-to-nearest
     known_nodes: Dict[DHTID, Set[DHTID]] = {}  # all nodes ever added to the heap (for deduplication)
-    visited_nodes: Dict[DHTID, Set[DHTID]] = dict(visited_nodes)  # where we requested get_neighbors for a given query
+    visited_nodes: Dict[DHTID, Set[DHTID]] = dict(visited_nodes or {})  # nodes that were chosen for get_neighbors call
     pending_tasks = set()  # all active tasks (get_neighbors and found_callback)
     active_workers = Counter({q: 0 for q in queries})  # count workers that search for this query
 
@@ -149,7 +149,7 @@ async def traverse_dht(
             search_finished_event.set()
         if found_callback:
             nearest_neighbors = [peer for _, peer in heapq.nlargest(beam_size, nearest_nodes[query])]
-            pending_tasks.add(asyncio.create_task(found_callback(query, nearest_neighbors, set(visited_nodes))))
+            pending_tasks.add(asyncio.create_task(found_callback(query, nearest_neighbors, set(visited_nodes[query]))))
 
     async def worker():
         while unfinished_queries:
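
The one-line change in finish_search is subtle: visited_nodes is a Dict[DHTID, Set[DHTID]], so set(visited_nodes) produces the set of query ids rather than the peers visited for that query. A toy illustration with plain strings:

    visited_nodes = {"query_a": {"node_1", "node_2"}, "query_b": {"node_3"}}
    set(visited_nodes)             # {'query_a', 'query_b'}  <- the old bug: dict keys only
    set(visited_nodes["query_a"])  # {'node_1', 'node_2'}    <- what found_callback expects

Together with the `visited_nodes or {}` defaults above, this also lets callers omit visited_nodes entirely (the "support visited_ids=None" commits).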

+ 8 - 0
tests/test_dht.py

@@ -300,6 +300,14 @@ def test_hivemind_dht():
         peer.shutdown()
 
 
+def test_dht_single_node():
+    node = hivemind.DHT(start=True)
+    assert all(node.declare_experts(['e1', 'e2', 'e3'], hivemind.LOCALHOST, 1337))
+    for expert in node.get_experts(['e3', 'e2']):
+        assert expert.host == hivemind.LOCALHOST and expert.port == 1337
+    assert node.first_k_active(['e0', 'e1', 'e3', 'e5', 'e2'], k=2) == ['e1', 'e3']
+
+
 def test_store():
     d = LocalStorage()
     d.store(DHTID.generate("key"), b"val", get_dht_time() + 0.5)