
Faster beam search: part1/2 (#97)

* [WIP] hivemind.dht.DHT.get_best_experts

* bump

* unused variable

* implement early stopping event for get_many

* implement early stopping event for get_many

* cancel wait_for_termination

* add __contains__ check to LocalStorage

* doc: add single call guarantee

* LocalStorage: add strict typing

* rollback

* add LocalStorage.top

* rollback

* add strict types and builtins

* LocalStorage: add builtins and freeze

* remove_outdated is now private. Rationale: it doesn't always remove outdated elements, better not cause confusion

* caching rev.2 - now it works ;)

* sphinx formatting fix

* add test for caching policy

* pep8

* use separate tester process

* edge case: frozen LocalStorage

* move LocalStorage tests to a separate file

* test for LocalStorage.top

* add tests for new localstorage functionality

* separate tests for DHTNode and experts-related stuff

* separate tests for DHTNode and experts-related stuff

* typo

* add option to reuse pending get requests

* == None -> is None

* WIP that breaks tests

* WIP that breaks tests

* split get_many_by_id into sub-functions, add tests

* remove debugprint

* circleci u mad?

* circleci u mad?

* restore caching

* rm debugprint

* better expiration_time_threshold in DHTNode reuse

* TEMPORARY: fix broken circleci cache

* recache?

* WIP: override cache?

* WIP: override cache?

* WIP: override cache?

* WIP: override cache?

* pep

* add cmd

* rollback auto refactor

* dev1

* bugfix: do not finish_search if there are concurrent workers

* bugfix: do not finish_search if there are concurrent workers

* bugfix: do not finish_search if there are concurrent workers

* remove find_best_experts for now

* we can probably still finish search for a query if it has no concurrent workers

* update benchmark_dht

* rollback changes

* rollback changes

* rollback changes

* unused import

* auto refactor typo

* typo

* update benchmarks

* fix broken sphinx url

* misc renames

* misc renames

* Update docs/user/contributing.md

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>

* Update docs/user/contributing.md

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>

* address review by mryab

* address review by mryab

* address review by mryab

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic 5 years ago
parent
commit
9c1e14aca1

+ 1 - 0
.circleci/config.yml

@@ -9,6 +9,7 @@ jobs:
     steps:
       - checkout
       - python/load-cache
+      - run: pip uninstall -y pytest codecov  # temporary override for broken cache
       - run: pip install codecov pytest tqdm scikit-learn
       - python/install-deps
       - python/save-cache

+ 12 - 11
docs/user/contributing.md

@@ -1,4 +1,4 @@
-## Contributing
+## Developer zone
 
 #### Collaborating best practices:
 Hivemind is still in the early stage of development, we expect only a handful of collaborators with individual roles.
@@ -19,7 +19,7 @@ Hivemind is still in the early stage of development, we expect only a handful of
    * If you face any challenges or want feedback, please submit a [draft](https://github.blog/2019-02-14-introducing-draft-pull-requests/) pull request.
 
 
-#### Contributor's manual
+#### Developer quickstart
 
 First, install hivemind in the development mode, preferably with python 3.8 on linux/mac_OS.
 ```
@@ -98,23 +98,24 @@ to measure performance impact of changes to hivemind.dht. It spawns a DHT with `
 then chooses one peer that will declare `num_experts` total experts in batches of `expert_batch_size`.
 Then, another peer will consecutively get all peers and check if they are there.
 
-Here's a run with 1024 participants on the same machine that was used benchmark_throughput:
+Here's a run with 1024 participants on the same machine that was used for benchmark_throughput:
 
+`python benchmark_dht.py --num_peers 1024 --num_experts 16384 --expert_batch_size 64 --expiration 99999 --increase_file_limit`
 <details style="margin-top:-24px; margin-bottom: 16px;">
   <summary>Console outputs</summary>
   
   ```sh
 Increasing file limit - soft 1024=>32768, hard 1048576=>32768
 Creating peers...
-100%|███████████████████████████████████████████████████| 1024/1024 [01:51<00:00,  9.22it/s]
+100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1024/1024 [01:45<00:00,  9.74it/s]
 Sampled 16384 unique ids (after deduplication)
 Storing peers to dht in batches of 64...
-100%|█████████████████████████████████████████████████████| 256/256 [13:00<00:00,  3.05s/it]
-Store success rate: 100.0% (48904 / 48904)
-Mean store time: 0.015967, Total: 780.85
-100%|█████████████████████████████████████████████████████| 256/256 [02:01<00:00,  2.11it/s]
-Get success rate: 100.0 (16383 / 16384)
-Mean get time: 0.00740, Total: 121.29011
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 256/256 [12:07<00:00,  2.84s/it]
+Store success rate: 100.0% (48920 / 48920)
+Mean store time: 0.01487, Total: 727.46
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 256/256 [01:48<00:00,  2.35it/s]
+Get success rate: 100.0 (16384 / 16384)
+Mean get time: 0.00664, Total: 108.73952
 Node survival rate: 100.000%
   ```
 </details>
@@ -125,6 +126,6 @@ If one wants to account for these factors, one must introduce them manually by c
   
 
 #### Tips & tricks
-* You can find a wealth of pytorch debugging tricks at [their contributing page](https://github.com/pytorch/pytorch/blob/master/CONTRIBUTING.md).
+* You can find a wealth of pytorch debugging tricks at [their contributing page](https://tinyurl.com/pytorch-contributing).
 * Hivemind is optimized for development in pycharm CE 2019.3 or newer.
   * When working on tests, please mark "tests" as sources root.

+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.utils import *
 
-__version__ = '0.8.2'
+__version__ = '0.8.3'

+ 4 - 2
hivemind/dht/__init__.py

@@ -25,7 +25,9 @@ import uvloop
 from hivemind.client import RemoteExpert
 from hivemind.dht.node import DHTNode, DHTID, DHTExpiration
 from hivemind.dht.routing import get_dht_time
-from hivemind.utils import MPFuture, Endpoint
+from hivemind.utils import MPFuture, Endpoint, get_logger
+
+logger = get_logger(__name__)
 
 
 class DHT(mp.Process):
@@ -155,7 +157,7 @@ class DHT(mp.Process):
         :param uids: a list of expert ids to update
         :param endpoint: endpoint that serves these experts, usually your server endpoint (e.g. "201.111.222.333:1337")
         :param wait: if True, awaits for declaration to finish, otherwise runs in background
-        :param timeout: waits for the procedure to finish, None means wait indeninitely
+        :param timeout: waits for the procedure to finish for up to this long, None means wait indefinitely
         :returns: if wait, returns a list of booleans, (True = store succeeded, False = store rejected)
         """
         assert not isinstance(uids, str), "Please send a list / tuple of expert uids."

+ 239 - 77
hivemind/dht/node.py

@@ -1,15 +1,21 @@
 from __future__ import annotations
 
 import asyncio
+
 import random
-from collections import namedtuple
-from typing import Optional, Tuple, List, Dict, Collection, Union, Set
+from collections import defaultdict
+from dataclasses import dataclass, field
+from typing import Optional, Tuple, List, Dict, DefaultDict, Collection, Union, Set, Awaitable, Callable, Any, Iterable
+from sortedcontainers import SortedList
+from functools import partial
 from warnings import warn
 
-from hivemind.dht.protocol import DHTProtocol
-from hivemind.dht.routing import DHTID, DHTExpiration, DHTKey, get_dht_time, DHTValue
+from hivemind.dht.protocol import DHTProtocol, LocalStorage
+from hivemind.dht.routing import DHTID, DHTExpiration, DHTKey, get_dht_time, DHTValue, BinaryDHTValue
 from hivemind.dht.traverse import traverse_dht
-from hivemind.utils import Endpoint, LOCALHOST, MSGPackSerializer
+from hivemind.utils import Endpoint, LOCALHOST, MSGPackSerializer, get_logger, SerializerBase
+
+logger = get_logger(__name__)
 
 
 class DHTNode:
@@ -45,8 +51,10 @@ class DHTNode:
 
     """
     # fmt:off
-    node_id: DHTID; port: int; num_replicas: int; cache_locally: bool; cache_nearest: int; num_workers: int
-    refresh_timeout: float; protocol: DHTProtocol
+    node_id: DHTID; is_alive: bool; port: int; num_replicas: int; num_workers: int; protocol: DHTProtocol
+    refresh_timeout: float; cache_locally: bool; cache_nearest: int; cache_refresh_before_expiry: float
+    cache_refresh_available: asyncio.Event; cache_refresh_queue: LocalStorage
+    reuse_get_requests: bool; pending_get_requests: DefaultDict[DHTID, SortedList[_IntermediateResult]]
     serializer = MSGPackSerializer  # used to pack/unpack DHT Values for transfer over network
     # fmt:on
 
@@ -55,8 +63,9 @@ class DHTNode:
             cls, node_id: Optional[DHTID] = None, initial_peers: List[Endpoint] = (),
             bucket_size: int = 20, num_replicas: int = 5, depth_modulo: int = 5, parallel_rpc: int = None,
             wait_timeout: float = 5, refresh_timeout: Optional[float] = None, bootstrap_timeout: Optional[float] = None,
-            num_workers: int = 1, cache_locally: bool = True, cache_nearest: int = 1, cache_size=None,
-            listen: bool = True, listen_on: Endpoint = "0.0.0.0:*", **kwargs) -> DHTNode:
+            cache_locally: bool = True, cache_nearest: int = 1, cache_size=None, cache_refresh_before_expiry: float = 5,
+            reuse_get_requests: bool = True, num_workers: int = 1, listen: bool = True,
+            listen_on: Endpoint = "0.0.0.0:*", **kwargs) -> DHTNode:
         """
         """
         :param node_id: current node's identifier, determines which keys it will store locally, defaults to random id
         :param node_id: current node's identifier, determines which keys it will store locally, defaults to random id
         :param initial_peers: connects to these peers to populate routing table, defaults to no peers
         :param initial_peers: connects to these peers to populate routing table, defaults to no peers
@@ -71,11 +80,15 @@ class DHTNode:
         :param refresh_timeout: refresh buckets if no node from that bucket was updated in this many seconds
           if staleness_timeout is None, DHTNode will not refresh stale buckets (which is usually okay)
         :param bootstrap_timeout: after one of peers responds, await other peers for at most this many seconds
-        :param num_workers: concurrent workers in traverse_dht (see traverse_dht num_workers param)
         :param cache_locally: if True, caches all values (stored or found) in a node-local cache
         :param cache_nearest: whenever DHTNode finds a value, it will also store (cache) this value on this many
           nodes nearest nodes visited by search algorithm. Prefers nodes that are nearest to :key: but have no value yet
         :param cache_size: if specified, local cache will store up to this many records (as in LRU cache)
+        :param cache_refresh_before_expiry: if nonzero, refreshes locally cached values
+          if they are accessed this many seconds before expiration time.
+        :param reuse_get_requests: if True, DHTNode allows only one traverse_dht procedure for every key
+          all concurrent get requests for the same key will reuse the procedure that is currently in progress
+        :param num_workers: concurrent workers in traverse_dht (see traverse_dht num_workers param)
         :param listen: if True (default), this node will accept incoming request and otherwise be a DHT "citzen"
           if False, this node will refuse any incoming request, effectively being only a "client"
         :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
@@ -83,11 +96,26 @@ class DHTNode:
           see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
         :param kwargs: extra parameters used in grpc.aio.server
         """
+        if cache_refresh_before_expiry > 0 and not cache_locally:
+            logger.warning("If cache_locally is False, cache_refresh_before_expiry has no effect. To silence this"
+                           " warning, please specify cache_refresh_before_expiry=0")
+
         self = cls(_initialized_with_create=True)
         self.node_id = node_id = node_id if node_id is not None else DHTID.generate()
         self.num_replicas, self.num_workers = num_replicas, num_workers
-        self.cache_locally, self.cache_nearest = cache_locally, cache_nearest
+        self.is_alive = True  # if set to False, cancels all background jobs such as routing table refresh
+
+        self.reuse_get_requests = reuse_get_requests
+        self.pending_get_requests = defaultdict(partial(SortedList, key=lambda _res: - _res.sufficient_expiration_time))
+
+        # caching policy
         self.refresh_timeout = refresh_timeout
+        self.cache_locally, self.cache_nearest = cache_locally, cache_nearest
+        self.cache_refresh_before_expiry = cache_refresh_before_expiry
+        self.cache_refresh_queue = LocalStorage()
+        self.cache_refresh_available = asyncio.Event()
+        if cache_refresh_before_expiry:
+            asyncio.create_task(self._refresh_stale_cache_entries())
 
         self.protocol = await DHTProtocol.create(self.node_id, bucket_size, depth_modulo, num_replicas, wait_timeout,
                                                  parallel_rpc, cache_size, listen, listen_on, **kwargs)
@@ -129,7 +157,9 @@ class DHTNode:
 
     async def shutdown(self, timeout=None):
         """ Process existing requests, close all connections and stop the server """
-        await self.protocol.shutdown(timeout)
+        self.is_alive = False
+        if self.protocol.server:
+            await self.protocol.shutdown(timeout)
 
     async def find_nearest_nodes(
             self, queries: Collection[DHTID], k_nearest: Optional[int] = None, beam_size: Optional[int] = None,
@@ -157,15 +187,15 @@ class DHTNode:
                 node_to_endpoint.update(
                     self.protocol.routing_table.get_nearest_neighbors(query, beam_size, exclude=self.node_id))
 
-        async def get_neighbors(peer: DHTID, queries: Collection[DHTID]) -> Dict[DHTID, Tuple[List[DHTID], bool]]:
+        async def get_neighbors(peer: DHTID, queries: Collection[DHTID]) -> Dict[DHTID, Tuple[Tuple[DHTID], bool]]:
             response = await self.protocol.call_find(node_to_endpoint[peer], queries)
             if not response:
                 return {query: ([], False) for query in queries}
 
-            output: Dict[DHTID, Tuple[List[DHTID], bool]] = {}
+            output: Dict[DHTID, Tuple[Tuple[DHTID], bool]] = {}
             for query, (_, _, peers) in response.items():
                 node_to_endpoint.update(peers)
-                output[query] = list(peers.keys()), False  # False means "do not interrupt search"
+                output[query] = tuple(peers.keys()), False  # False means "do not interrupt search"
             return output
 
         nearest_nodes_per_query, visited_nodes = await traverse_dht(
@@ -289,7 +319,7 @@ class DHTNode:
         Search for a key across DHT and return either first or latest entry.
         :param key: same key as in node.store(...)
         :param latest: if True, finds the latest value, otherwise finds any non-expired value (which is much faster)
-        :param kwargs: parameters forwarded to get_many
+        :param kwargs: parameters forwarded to get_many_by_id
         :returns: (value, expiration time); if value was not found, returns (None, None)
         """
         if latest:
@@ -297,100 +327,190 @@ class DHTNode:
         result = await self.get_many([key])
         return result[key]
 
-    async def get_many(
-            self, keys: Collection[DHTKey], sufficient_expiration_time: Optional[DHTExpiration] = None,
-            num_workers: Optional[int] = None, beam_size: Optional[int] = None
-    ) -> Dict[DHTKey, Tuple[Optional[DHTValue], Optional[DHTExpiration]]]:
+    async def get_many(self, keys: Collection[DHTKey], sufficient_expiration_time: Optional[DHTExpiration] = None,
+                       **kwargs) -> Dict[DHTKey, Union[Tuple[Optional[DHTValue], Optional[DHTExpiration]],
+                                                       Awaitable[Tuple[Optional[DHTValue], Optional[DHTExpiration]]]]]:
         """
         """
+        Traverse DHT to find a list of keys. For each key, return latest (value, expiration) or None if not found.
+
         :param keys: traverse the DHT and find the value for each of these keys (or (None, None) if not key found)
+        :param sufficient_expiration_time: if the search finds a value that expires after this time,
+            default = time of call, find any value that did not expire by the time of call
+            If min_expiration_time=float('inf'), this method will find a value with _latest_ expiration
+        :param kwargs: for full list of parameters, see DHTNode.get_many_by_id
+        :returns: for each key: value and its expiration time. If nothing is found, returns (None, None) for that key
+        :note: in order to check if get returned a value, please check if (expiration_time is None)
+        """
+        keys = tuple(keys)
+        key_ids = [DHTID.generate(key) for key in keys]
+        id_to_original_key = dict(zip(key_ids, keys))
+        results_by_id = await self.get_many_by_id(key_ids, sufficient_expiration_time, **kwargs)
+        return {id_to_original_key[key]: result_or_future for key, result_or_future in results_by_id.items()}
+
+    async def get_many_by_id(
+            self, key_ids: Collection[DHTID], sufficient_expiration_time: Optional[DHTExpiration] = None,
+            num_workers: Optional[int] = None, beam_size: Optional[int] = None, return_futures: bool = False,
+            _refresh_cache=True) -> Dict[DHTID, Union[Tuple[Optional[DHTValue], Optional[DHTExpiration]],
+                                                      Awaitable[Tuple[Optional[DHTValue], Optional[DHTExpiration]]]]]:
+        """
+        Traverse DHT to find a list of DHTIDs. For each key, return latest (value, expiration) or None if not found.
+
+        :param key_ids: traverse the DHT and find the value for each of these keys (or (None, None) if not key found)
         :param sufficient_expiration_time: if the search finds a value that expires after this time,
             default = time of call, find any value that did not expire by the time of call
             If min_expiration_time=float('inf'), this method will find a value with _latest_ expiration
         :param beam_size: maintains up to this many nearest nodes when crawling dht, default beam_size = bucket_size
         :param num_workers: override for default num_workers, see traverse_dht num_workers param
-        :returns: for each key: value and its expiration time. If nothing is found , returns (None, None) for that key
+        :param return_futures: if True, immediately return asyncio.Future for every key before interacting with the network.
+         The algorithm will populate these futures with (value, expiration) when it finds the corresponding key
+         Note: canceling a future will stop search for the corresponding key
+        :param _refresh_cache: internal flag, whether or not to self._trigger_cache_refresh
+        :returns: for each key: value and its expiration time. If nothing is found, returns (None, None) for that key
         :note: in order to check if get returned a value, please check (expiration_time is None)
         """
-        key_ids = [DHTID.generate(key) for key in keys]
-        id_to_original_key = dict(zip(key_ids, keys))
         sufficient_expiration_time = sufficient_expiration_time or get_dht_time()
         beam_size = beam_size if beam_size is not None else self.protocol.bucket_size
         num_workers = num_workers if num_workers is not None else self.num_workers
+        search_results: Dict[DHTID, _IntermediateResult] = {key_id: _IntermediateResult(
+            key_id, sufficient_expiration_time, serializer=self.serializer) for key_id in key_ids}
 
-        # search metadata
-        unfinished_key_ids = set(key_ids)  # track key ids for which the search is not terminated
-        node_to_endpoint: Dict[DHTID, Endpoint] = dict()  # global routing table for all queries
+        if _refresh_cache:
+            for key_id in key_ids:
+                search_results[key_id].add_done_callback(self._trigger_cache_refresh)
 
-        SearchResult = namedtuple("SearchResult", ["binary_value", "expiration_time", "source_node_id"])
-        latest_results = {key_id: SearchResult(b'', -float('inf'), None) for key_id in key_ids}
+        # if we have concurrent get request for some of the same keys, subscribe to their results
+        if self.reuse_get_requests:
+            for key_id, search_result in search_results.items():
+                self.pending_get_requests[key_id].add(search_result)
+                search_result.add_done_callback(self._reuse_finished_search_result)
 
-        # stage 1: value can be stored in our local cache
+        # stage 1: check for value in this node's local storage and cache
         for key_id in key_ids:
-            maybe_value, maybe_expiration_time = self.protocol.storage.get(key_id)
-            if maybe_expiration_time is None:
-                maybe_value, maybe_expiration_time = self.protocol.cache.get(key_id)
-            if maybe_expiration_time is not None and maybe_expiration_time > latest_results[key_id].expiration_time:
-                latest_results[key_id] = SearchResult(maybe_value, maybe_expiration_time, self.node_id)
-                if maybe_expiration_time >= sufficient_expiration_time:
-                    unfinished_key_ids.remove(key_id)
-
-        # stage 2: traverse the DHT for any unfinished keys
+            search_results[key_id].add_candidate(*self.protocol.storage.get(key_id), source_node_id=self.node_id)
+            search_results[key_id].add_candidate(*self.protocol.cache.get(key_id), source_node_id=self.node_id)
+
+        # stage 2: traverse the DHT to get the remaining keys from remote peers
+        unfinished_key_ids = [key_id for key_id in key_ids if not search_results[key_id].finished]
+        node_to_endpoint: Dict[DHTID, Endpoint] = dict()  # global routing table for all keys
         for key_id in unfinished_key_ids:
             node_to_endpoint.update(self.protocol.routing_table.get_nearest_neighbors(
                 key_id, self.protocol.bucket_size, exclude=self.node_id))
 
-        async def get_neighbors(peer: DHTID, queries: Collection[DHTID]) -> Dict[DHTID, Tuple[List[DHTID], bool]]:
+        # V-- this function will be called every time traverse_dht decides to request neighbors from a remote peer
+        async def get_neighbors(peer: DHTID, queries: Collection[DHTID]) -> Dict[DHTID, Tuple[Tuple[DHTID], bool]]:
             queries = list(queries)
             response = await self.protocol.call_find(node_to_endpoint[peer], queries)
             if not response:
                 return {query: ([], False) for query in queries}
 
-            output: Dict[DHTID, Tuple[List[DHTID], bool]] = {}
-            for key_id, (maybe_value, maybe_expiration_time, peers) in response.items():
+            output: Dict[DHTID, Tuple[Tuple[DHTID], bool]] = {}
+            for key_id, (maybe_value_bytes, maybe_expiration_time, peers) in response.items():
                 node_to_endpoint.update(peers)
-                if maybe_expiration_time is not None and maybe_expiration_time > latest_results[key_id].expiration_time:
-                    latest_results[key_id] = SearchResult(maybe_value, maybe_expiration_time, peer)
-                should_interrupt = (latest_results[key_id].expiration_time >= sufficient_expiration_time)
-                output[key_id] = list(peers.keys()), should_interrupt
+                search_results[key_id].add_candidate(maybe_value_bytes, maybe_expiration_time, source_node_id=peer)
+                output[key_id] = tuple(peers.keys()), search_results[key_id].finished
+                # note: we interrupt the search if the key is either found or otherwise finished (e.g. cancelled by user)
             return output
 
-        nearest_nodes_per_query, visited_nodes = await traverse_dht(
+        # V-- this function will be called exactly once when traverse_dht finishes search for a given key
+        async def found_callback(key_id: DHTID, nearest_nodes: List[DHTID], _visited: Set[DHTID]):
+            search_results[key_id].finish_search()  # finish search whether or not we found something
+            self._cache_new_result(search_results[key_id], nearest_nodes, node_to_endpoint)
+
+        asyncio.create_task(traverse_dht(
             queries=list(unfinished_key_ids), initial_nodes=list(node_to_endpoint),
             beam_size=beam_size, num_workers=num_workers, queries_per_call=int(len(unfinished_key_ids) ** 0.5),
-            get_neighbors=get_neighbors, visited_nodes={key_id: {self.node_id} for key_id in unfinished_key_ids})
-
-        # stage 3: cache any new results depending on caching parameters
-        for key_id, nearest_nodes in nearest_nodes_per_query.items():
-            latest_value_bytes, latest_expiration_time, latest_node_id = latest_results[key_id]
-            should_cache = latest_expiration_time >= sufficient_expiration_time  # if we found a newer value, cache it
-            if should_cache and self.cache_locally:
-                self.protocol.cache.store(key_id, latest_value_bytes, latest_expiration_time)
-
-            if should_cache and self.cache_nearest:
-                num_cached_nodes = 0
-                for node_id in nearest_nodes:
-                    if node_id == latest_node_id:
-                        continue
-                    asyncio.create_task(self.protocol.call_store(
-                        node_to_endpoint[node_id], [key_id], [latest_value_bytes], [latest_expiration_time],
-                        in_cache=True))
-                    num_cached_nodes += 1
-                    if num_cached_nodes >= self.cache_nearest:
-                        break
-
-        # stage 4: deserialize data and assemble function output
-        find_result: Dict[DHTKey, Tuple[Optional[DHTValue], Optional[DHTExpiration]]] = {}
-        for key_id, (latest_value_bytes, latest_expiration_time, _) in latest_results.items():
-            if latest_expiration_time != -float('inf'):
-                latest_value = self.serializer.loads(latest_value_bytes)
-                find_result[id_to_original_key[key_id]] = (latest_value, latest_expiration_time)
-            else:
-                find_result[id_to_original_key[key_id]] = None, None
-        return find_result
+            get_neighbors=get_neighbors, visited_nodes={key_id: {self.node_id} for key_id in unfinished_key_ids},
+            found_callback=found_callback, await_all_tasks=False))
+
+        if return_futures:
+            return {key_id: search_result.future for key_id, search_result in search_results.items()}
+        else:
+            try:
+                # note: this should be the first time we await anything, so there's no need to wrap the entire function in try/except
+                return {key_id: await search_result.future for key_id, search_result in search_results.items()}
+            except asyncio.CancelledError as e:  # terminate remaining tasks ASAP
+                for key_id, search_result in search_results.items():
+                    search_result.future.cancel()
+                raise e
+
+    def _reuse_finished_search_result(self, finished: _IntermediateResult):
+        expiration_time_threshold = max(finished.expiration_time or -float('inf'), finished.sufficient_expiration_time)
+        concurrent_requests: SortedList[_IntermediateResult] = self.pending_get_requests[finished.key_id]
+        # note: concurrent_requests is sorted in the order of descending sufficient_expiration_time
+        while concurrent_requests and expiration_time_threshold >= concurrent_requests[-1].sufficient_expiration_time:
+            concurrent_requests[-1].add_candidate(finished.binary_value, finished.expiration_time,
+                                                  source_node_id=finished.source_node_id)
+            concurrent_requests[-1].finish_search()
+            concurrent_requests.pop(-1)
+
+    def _trigger_cache_refresh(self, result: _IntermediateResult):
+        """ Called after get request is finished (whether it was found, not found, hit cache, cancelled, or reused) """
+        if result.found_something and result.source_node_id == self.node_id:
+            with self.protocol.cache.freeze():  # do not clear outdated cache for now...
+                if self.cache_refresh_before_expiry and result.key_id in self.protocol.cache:
+                    previous_earliest_item: Tuple[DHTID, BinaryDHTValue, DHTExpiration] = self.cache_refresh_queue.top()
+                    self.cache_refresh_queue.store(result.key_id, result.binary_value, result.expiration_time)
+                    if previous_earliest_item is None or result.expiration_time < previous_earliest_item[-1]:
+                        self.cache_refresh_available.set()  # if the new element is now the earliest, notify the cache refresh task
+
+    async def _refresh_stale_cache_entries(self):
+        """ periodically refresh keys near-expired keys that were accessed at least once during previous lifetime """
+        while self.is_alive:
+            with self.cache_refresh_queue.freeze():
+                while len(self.cache_refresh_queue) == 0:
+                    await self.cache_refresh_available.wait()
+                    self.cache_refresh_available.clear()
+                key_id, _, nearest_expiration = self.cache_refresh_queue.top()
+
+            try:
+                # step 1: await until :cache_refresh_before_expiry: seconds before the earliest element expires
+                time_to_wait = nearest_expiration - get_dht_time() - self.cache_refresh_before_expiry
+                await asyncio.wait_for(self.cache_refresh_available.wait(), timeout=time_to_wait)
+                # note: the line above will cause TimeoutError when we are ready to refresh cache
+                self.cache_refresh_available.clear()  # no timeout error => someone added new entry to queue and ...
+                continue  # ... and this element is earlier than nearest_expiration. we should refresh this entry first
+
+            except asyncio.TimeoutError:  # caught TimeoutError => it is time to refresh the most recent cached entry
+                # step 2: find all keys that we should already refresh and remove them from queue
+                with self.cache_refresh_queue.freeze():
+                    keys_to_refresh = {key_id}
+                    del self.cache_refresh_queue[key_id]  # we pledge to refresh this key_id in the nearest batch
+                    while self.cache_refresh_queue:
+                        key_id, _, nearest_expiration = self.cache_refresh_queue.top()
+                        if nearest_expiration > get_dht_time() + self.cache_refresh_before_expiry:
+                            break
+                        del self.cache_refresh_queue[key_id]  # we pledge to refresh this key_id in the nearest batch
+                        keys_to_refresh.add(key_id)
+
+                # step 3: search newer versions of these keys, cache them as a side-effect of self.get_many_by_id
+                await self.get_many_by_id(
+                    keys_to_refresh, sufficient_expiration_time=nearest_expiration + self.cache_refresh_before_expiry,
+                    _refresh_cache=False)  # if we found value locally, we shouldn't trigger another refresh
+
+    def _cache_new_result(self, result: _IntermediateResult, nearest_nodes: List[DHTID],
+                          node_to_endpoint: Dict[DHTID, Endpoint]):
+        """ after key_id is found, update cache according to caching policy. used internally in get and get_many """
+        if result.found_something:
+            previous_expiration_time = max(self.protocol.storage.get(result.key_id)[1] or -float('inf'),
+                                           self.protocol.cache.get(result.key_id)[1] or -float('inf'))
+            if result.expiration_time > previous_expiration_time:  # if this value has better expiration
+                if self.cache_locally:
+                    self.protocol.cache.store(result.key_id, result.binary_value, result.expiration_time)
+                if self.cache_nearest:
+                    num_cached_nodes = 0
+                    for node_id in nearest_nodes:
+                        if node_id == result.source_node_id:
+                            continue
+                        asyncio.create_task(self.protocol.call_store(
+                            node_to_endpoint[node_id], [result.key_id], [result.binary_value], [result.expiration_time],
+                            in_cache=True))
+                        num_cached_nodes += 1
+                        if num_cached_nodes >= self.cache_nearest:
+                            break
 
     async def _refresh_routing_table(self, *, period: Optional[float]) -> None:
         """ Tries to find new nodes for buckets that were unused for more than self.staleness_timeout """
-        while period is not None:  # if None run once, otherwise run forever
+        while self.is_alive and period is not None:  # if None run once, otherwise run forever
             refresh_time = get_dht_time()
             staleness_threshold = refresh_time - period
             stale_buckets = [bucket for bucket in self.protocol.routing_table.buckets
@@ -400,3 +520,45 @@ class DHTNode:
                 await self.find_nearest_nodes(refresh_id)
 
             await asyncio.sleep(max(0.0, period - (get_dht_time() - refresh_time)))
+
+
+@dataclass(init=True, repr=True, frozen=False, order=False)
+class _IntermediateResult:
+    """ A helper class that stores current-best GET results with metadata """
+    key_id: DHTID
+    sufficient_expiration_time: DHTExpiration
+    binary_value: Optional[BinaryDHTValue] = None
+    expiration_time: Optional[DHTExpiration] = None  # best expiration time so far
+    source_node_id: Optional[DHTID] = None  # node that gave us the value
+    future: asyncio.Future[Tuple[Optional[DHTValue], Optional[DHTExpiration]]] = field(default_factory=asyncio.Future)
+    serializer: type(SerializerBase) = MSGPackSerializer
+
+    def add_candidate(self, binary_value: Optional[BinaryDHTValue], expiration_time: Optional[DHTExpiration],
+                      source_node_id: Optional[DHTID]):
+        if not self.finished and (expiration_time or -float('inf')) > (self.expiration_time or -float('inf')):
+            self.binary_value, self.expiration_time, self.source_node_id = binary_value, expiration_time, source_node_id
+            if self.expiration_time >= self.sufficient_expiration_time:
+                self.finish_search()
+
+    def add_done_callback(self, callback: Callable[[_IntermediateResult], Any]):
+        """ Add callback that will be called when _IntermediateSearchResult is done (found OR cancelled by user) """
+        self.future.add_done_callback(lambda _future: callback(self))
+
+    def finish_search(self):
+        if self.future.done():
+            return  # either user cancelled our result or someone sent it before us. Nothing more to do here.
+        deserialized_value = self.serializer.loads(self.binary_value) if self.found_something else None
+        self.future.set_result((deserialized_value, self.expiration_time))
+
+    @property
+    def found_something(self) -> bool:
+        """ Whether or not we have at least some result, regardless of its expiration time """
+        return self.expiration_time is not None
+
+    @property
+    def finished(self) -> bool:
+        return self.future.done()
+
+    def __lt__(self, other: _IntermediateResult):
+        """ _IntermediateResult instances will be sorted by their target expiration time """
+        return self.sufficient_expiration_time < other.sufficient_expiration_time
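The reworked `get_many` / `get_many_by_id` above either resolves each key to a `(value, expiration)` pair or, with `return_futures=True`, hands back one future per key. Below is a hedged sketch of that flow, not taken from this diff: the keys and values are made up, the exact shape of the `store()` call and the `asyncio.run` wrapper are assumptions.

```python
import asyncio
from hivemind.dht.node import DHTNode
from hivemind.dht.routing import get_dht_time
from hivemind.utils import LOCALHOST

async def demo():
    # two local nodes: bob bootstraps from alice (ports are assigned automatically)
    alice = await DHTNode.create()
    bob = await DHTNode.create(initial_peers=[f"{LOCALHOST}:{alice.port}"])

    # store a value that stays valid for 300 DHT-seconds (store() shape assumed here)
    await alice.store("population", 42, expiration_time=get_dht_time() + 300)

    # return_futures=True returns one awaitable per key before any network traffic;
    # cancelling a future stops the search for that key (see the docstring above)
    futures = await bob.get_many(["population", "never_stored"], return_futures=True)
    value, expiration = await futures["population"]   # -> (42, <expiration time>)
    missing = await futures["never_stored"]           # -> (None, None)

    await bob.shutdown()
    await alice.shutdown()

asyncio.run(demo())
```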

+ 50 - 12
hivemind/dht/protocol.py

@@ -3,6 +3,7 @@ from __future__ import annotations
 
 import asyncio
 import heapq
+from contextlib import contextmanager
 from typing import Optional, List, Tuple, Dict, Iterator, Any, Sequence, Union, Collection
 from warnings import warn
 
@@ -265,16 +266,17 @@ class LocalStorage:
 
     def __init__(self, maxsize: Optional[int] = None):
         self.cache_size = maxsize or float("inf")
-        self.data = dict()
-        self.expiration_heap = []
-        self.key_to_heap = dict()
-
-    def remove_outdated(self):
-        while self.expiration_heap and (self.expiration_heap[0][0] < get_dht_time()
-                                        or len(self.expiration_heap) > self.cache_size):
+        self.data: Dict[DHTID, Tuple[BinaryDHTValue, DHTExpiration]] = dict()
+        self.expiration_heap: List[Tuple[DHTExpiration, DHTID]] = []
+        self.key_to_heap: Dict[DHTID, Tuple[DHTExpiration, DHTID]] = dict()
+        self.frozen = False  # if True, do not remove outdated elements
+
+    def _remove_outdated(self):
+        while not self.frozen and self.expiration_heap and (self.expiration_heap[0][0] < get_dht_time()
+                                                            or len(self.expiration_heap) > self.cache_size):
             heap_entry = heapq.heappop(self.expiration_heap)
             key = heap_entry[1]
-            if self.key_to_heap[key] == heap_entry:
+            if self.key_to_heap.get(key) == heap_entry:
                 del self.data[key], self.key_to_heap[key]
 
     def store(self, key: DHTID, value: BinaryDHTValue, expiration_time: DHTExpiration) -> bool:
@@ -282,7 +284,7 @@ class LocalStorage:
         Store a (key, value) pair locally at least until expiration_time. See class docstring for details.
         :returns: True if new value was stored, False it was rejected (current value is newer)
         """
-        if expiration_time < get_dht_time():
+        if expiration_time < get_dht_time() and not self.frozen:
             return False
         self.key_to_heap[key] = (expiration_time, key)
         heapq.heappush(self.expiration_heap, (expiration_time, key))
@@ -292,17 +294,53 @@ class LocalStorage:
                 return True
             return False
         self.data[key] = (value, expiration_time)
-        self.remove_outdated()
+        self._remove_outdated()
         return True
 
     def get(self, key: DHTID) -> (Optional[BinaryDHTValue], Optional[DHTExpiration]):
         """ Get a value corresponding to a key if that (key, value) pair was previously stored here. """
-        self.remove_outdated()
+        self._remove_outdated()
         if key in self.data:
             return self.data[key]
         return None, None
 
     def items(self) -> Iterator[Tuple[DHTID, BinaryDHTValue, DHTExpiration]]:
         """ Iterate over (key, value, expiration_time) tuples stored in this storage """
-        self.remove_outdated()
+        self._remove_outdated()
         return ((key, value, expiration_time) for key, (value, expiration_time) in self.data.items())
+
+    def top(self) -> Optional[Tuple[DHTID, BinaryDHTValue, DHTExpiration]]:
+        """ Return the entry with earliest expiration or None if there isn't any """
+        self._remove_outdated()
+        if self.data:
+            top_entry, top_key = self.expiration_heap[0], self.expiration_heap[0][1]
+            while self.key_to_heap.get(top_key) != top_entry:
+                heapq.heappop(self.expiration_heap)  # skip leftover "ghost" entries until first real entry
+                top_entry, top_key = self.expiration_heap[0], self.expiration_heap[0][1]
+            value, expiration = self.data[top_key]
+            return top_key, value, expiration
+
+    def __contains__(self, key: DHTID):
+        self._remove_outdated()
+        return key in self.data
+
+    def __len__(self):
+        self._remove_outdated()
+        return len(self.data)
+
+    def __delitem__(self, key: DHTID):
+        if key in self.key_to_heap:
+            del self.data[key], self.key_to_heap[key]
+        # note: key may still be in self.expiration_heap, but it will not be used and will eventually be popped by ._remove_outdated()
+
+    def __bool__(self):
+        return bool(self.data)
+
+    @contextmanager
+    def freeze(self):
+        """ Temporarily cease to ._remove_outdated() elements inside this context to ensure consistency """
+        prev_frozen, self.frozen = self.frozen, True
+        try:
+            yield self
+        finally:
+            self.frozen = prev_frozen
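For reference, here is a small hedged sketch exercising the new LocalStorage helpers added above (`top`, `__contains__`, `__len__`, `__delitem__` and the `freeze` context manager). The key and value below are made up; `DHTID.generate` and `get_dht_time` are the same helpers used elsewhere in this diff.

```python
from hivemind.dht.protocol import LocalStorage
from hivemind.dht.routing import DHTID, get_dht_time

storage = LocalStorage(maxsize=128)
key = DHTID.generate("some_key")
storage.store(key, b"some_value", expiration_time=get_dht_time() + 30)

assert key in storage                       # new __contains__ check
top_key, value, expiration = storage.top()  # entry with the earliest expiration time

with storage.freeze():                      # outdated entries are not evicted in here
    assert len(storage) == 1
    del storage[key]                        # new __delitem__
```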

+ 16 - 11
hivemind/dht/traverse.py

@@ -11,7 +11,7 @@ ROOT = 0  # alias for heap root
 
 async def simple_traverse_dht(query_id: DHTID, initial_nodes: Collection[DHTID], beam_size: int,
                               get_neighbors: Callable[[DHTID], Awaitable[Tuple[Collection[DHTID], bool]]],
-                              visited_nodes: Collection[DHTID] = ()) -> Tuple[List[DHTID], Set[DHTID]]:
+                              visited_nodes: Collection[DHTID] = ()) -> Tuple[Tuple[DHTID], Set[DHTID]]:
     """
     """
     Traverse the DHT graph using get_neighbors function, find :beam_size: nearest nodes according to DHTID.xor_distance.
     Traverse the DHT graph using get_neighbors function, find :beam_size: nearest nodes according to DHTID.xor_distance.
 
 
@@ -64,7 +64,7 @@ async def simple_traverse_dht(query_id: DHTID, initial_nodes: Collection[DHTID],
 
 async def traverse_dht(
         queries: Collection[DHTID], initial_nodes: List[DHTID], beam_size: int, num_workers: int, queries_per_call: int,
-        get_neighbors: Callable[[DHTID, Collection[DHTID]], Awaitable[Dict[DHTID, Tuple[List[DHTID], bool]]]],
+        get_neighbors: Callable[[DHTID, Collection[DHTID]], Awaitable[Dict[DHTID, Tuple[Tuple[DHTID], bool]]]],
         found_callback: Optional[Callable[[DHTID, List[DHTID], Set[DHTID]], Awaitable[Any]]] = None,
         await_all_tasks: bool = True, visited_nodes: Optional[Dict[DHTID, Set[DHTID]]] = (),
 ) -> Tuple[Dict[DHTID, List[DHTID]], Dict[DHTID, Set[DHTID]]]:
@@ -90,8 +90,9 @@ async def traverse_dht(
         The search terminates iff each query is either stopped via should_stop or finds beam_size nearest nodes.
 
     :param found_callback: if specified, call this callback for each finished query the moment it finishes or is stopped
-        More specifically, run asyncio.create_task(found_found_callback(query, nearest_to_query, visited_for_query))
+        More specifically, run asyncio.create_task(found_callback(query, nearest_to_query, visited_for_query))
         Using this callback allows one to process results faster before traverse_dht is finishes for all queries.
+        It is guaranteed that found_callback will be called exactly once on each query in queries.
 
     :param await_all_tasks: if True, wait for all tasks to finish before returning, otherwise returns after finding
         nearest neighbors and finishes the remaining tasks (callbacks and queries to known-but-unvisited nodes)
@@ -133,10 +134,14 @@ async def traverse_dht(
 
     def heuristic_priority(heap_query: DHTID):
         """ Workers prioritize expanding nodes that lead to under-explored queries (by other workers) """
-        if len(candidate_nodes[heap_query]) == 0:
-            return float('inf'), float('inf')
-        else:  # prefer candidates in heaps with least number of concurrent workers, break ties by distance to query
+        if has_candidates(heap_query):
+            # prefer candidates in heaps with least number of concurrent workers, break ties by distance to query
             return active_workers[heap_query], candidate_nodes[heap_query][ROOT][0]
+        return float('inf'), float('inf')  # try not to explore vertices with no candidates
+
+    def has_candidates(query: DHTID):
+        """ Whether this query's heap contains at least one candidate node that can be explored """
+        return candidate_nodes[query] and candidate_nodes[query][ROOT][0] <= upper_bound(query)
 
     def upper_bound(query: DHTID):
         """ Any node that is farther from query than upper_bound(query) will not be added to heaps """
@@ -156,7 +161,8 @@ async def traverse_dht(
             # select the heap based on priority
             chosen_query: DHTID = min(unfinished_queries, key=heuristic_priority)
 
-            if len(candidate_nodes[chosen_query]) == 0:  # if there are no peers to explore...
+            # if there are no peers to explore...
+            if not has_candidates(chosen_query):
                 other_workers_pending = active_workers.most_common(1)[0][1] > 0
                 if other_workers_pending:  # ... wait for other workers (if any) or add more peers
                     heap_updated_event.clear()
@@ -169,10 +175,9 @@ async def traverse_dht(
 
             # select vertex to be explored
             chosen_distance_to_query, chosen_peer = heapq.heappop(candidate_nodes[chosen_query])
-            if chosen_peer in visited_nodes[chosen_query]:
-                continue
-            if chosen_distance_to_query > upper_bound(chosen_query):
-                finish_search(chosen_query)
+            if chosen_peer in visited_nodes[chosen_query] or chosen_distance_to_query > upper_bound(chosen_query):
+                if chosen_distance_to_query > upper_bound(chosen_query) and active_workers[chosen_query] == 0:
+                    finish_search(chosen_query)
                 continue
 
             # find additional queries to pack in the same request
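To make the new selection rule concrete: a worker picks the query with the fewest concurrent workers, breaks ties by the distance of that query's nearest unexplored candidate, and skips queries whose best remaining candidate already lies beyond upper_bound(query). Below is a standalone toy sketch of this rule; the dictionaries and distances are invented for illustration, and only has_candidates / heuristic_priority mirror the diff above.

import heapq
from collections import Counter

ROOT = 0  # index of the heap root, mirroring the constant used in traverse_dht

# invented state for two concurrent queries
active_workers = Counter({'query_A': 2, 'query_B': 0})
candidate_nodes = {'query_A': [(17, 'peer1'), (42, 'peer2')],   # (distance_to_query, peer)
                   'query_B': [(93, 'peer3')]}
upper_bounds = {'query_A': 50, 'query_B': 80}

for heap in candidate_nodes.values():
    heapq.heapify(heap)

def upper_bound(query):
    return upper_bounds[query]

def has_candidates(query):
    return bool(candidate_nodes[query]) and candidate_nodes[query][ROOT][0] <= upper_bound(query)

def heuristic_priority(heap_query):
    if has_candidates(heap_query):
        # fewer concurrent workers first, then the closest unexplored candidate
        return active_workers[heap_query], candidate_nodes[heap_query][ROOT][0]
    return float('inf'), float('inf')

chosen_query = min(candidate_nodes, key=heuristic_priority)
print(chosen_query)  # 'query_A': query_B's only candidate (93) lies beyond its upper bound (80)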

+ 1 - 0
requirements.txt

@@ -3,6 +3,7 @@ torch>=1.3.0
 numpy>=1.17
 prefetch_generator>=1.0.1
 umsgpack
+sortedcontainers
 uvloop>=0.14.0
 grpcio>=1.31
 grpcio-tools>=1.30.0
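The new sortedcontainers dependency is presumably what backs the expiration-ordered structures touched by this PR (LocalStorage.top and the cache refresh queue exercised in the tests below). A minimal sketch of that pattern, assuming nothing about hivemind internals beyond the idea of keeping records sorted by expiration time:

import time
from sortedcontainers import SortedList

records = SortedList(key=lambda record: record[2])   # (key, value, expiration), ordered by expiration
records.add(('key1', b'val1', time.time() + 1.0))
records.add(('key2', b'val2', time.time() + 0.5))

print(records[0][0])  # 'key2': the entry closest to expiring

# dropping expired entries only ever touches the left end of the list
now = time.time()
while records and records[0][2] < now:
    records.pop(0)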

+ 76 - 0
tests/test_dht_experts.py

@@ -0,0 +1,76 @@
+import random
+import uuid
+from itertools import chain
+
+import hivemind
+from hivemind import LOCALHOST
+
+
+def test_hivemind_dht():
+    peers = [hivemind.DHT(start=True)]
+    for i in range(10):
+        neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
+        peers.append(hivemind.DHT(initial_peers=neighbors_i, start=True))
+
+    you: hivemind.dht.DHT = random.choice(peers)
+    theguyshetoldyounottoworryabout: hivemind.dht.DHT = random.choice(peers)
+
+    expert_uids = [str(uuid.uuid4()) for _ in range(110)]
+    batch_size = 10
+    for batch_start in range(0, len(expert_uids), batch_size):
+        you.declare_experts(expert_uids[batch_start: batch_start + batch_size], 'localhost', 1234)
+
+    found = theguyshetoldyounottoworryabout.get_experts(random.sample(expert_uids, 5) + ['foo', 'bar'])
+    assert all(res is not None for res in found[:-2]), "Could not find some existing experts"
+    assert all(res is None for res in found[-2:]), "Found non-existing experts"
+
+    that_guys_expert, that_guys_port = str(uuid.uuid4()), random.randint(1000, 9999)
+    theguyshetoldyounottoworryabout.declare_experts([that_guys_expert], f'that_host:{that_guys_port}')
+    you_notfound, you_found = you.get_experts(['foobar', that_guys_expert])
+    assert isinstance(you_found, hivemind.RemoteExpert)
+    assert you_found.endpoint == f'that_host:{that_guys_port}'
+
+    # test first_k_active
+    assert list(theguyshetoldyounottoworryabout.first_k_active(expert_uids, k=10)) == expert_uids[:10]
+
+    some_permuted_experts = random.sample(expert_uids, k=32)
+    assert list(theguyshetoldyounottoworryabout.first_k_active(some_permuted_experts, k=32)) == some_permuted_experts
+    assert list(theguyshetoldyounottoworryabout.first_k_active(some_permuted_experts, k=1)) == some_permuted_experts[:1]
+    fake_and_real_experts = list(chain(*zip(
+        [str(uuid.uuid4()) for _ in some_permuted_experts], some_permuted_experts)))
+    assert list(theguyshetoldyounottoworryabout.first_k_active(fake_and_real_experts, k=9)) == some_permuted_experts[:9]
+
+    for peer in peers:
+        peer.shutdown()
+
+
+def test_first_k_active():
+    node = hivemind.DHT(start=True)
+    assert all(node.declare_experts(['e.1.2.3', 'e.1.2.4', 'e.3.4.5'], endpoint=f"{hivemind.LOCALHOST}:1337"))
+    assert all(node.declare_experts(['e.2.1.1'], endpoint=f"{hivemind.LOCALHOST}:1338"))
+
+    results = node.first_k_active(['e.0', 'e.1', 'e.2', 'e.3'], k=2)
+    assert len(results) == 2 and next(iter(results.keys())) == 'e.1'
+    assert results['e.1'].uid in ('e.1.2.3', 'e.1.2.4') and results['e.1'].endpoint == f"{hivemind.LOCALHOST}:1337"
+    assert results['e.2'].uid == 'e.2.1.1' and results['e.2'].endpoint == f"{hivemind.LOCALHOST}:1338"
+
+    results = node.first_k_active(['e', 'e.1', 'e.1.2', 'e.1.2.3'], k=10)
+    assert len(results) == 4
+    assert 'e' in results
+    for k in ('e.1', 'e.1.2', 'e.1.2.3'):
+        assert results[k].uid in ('e.1.2.3', 'e.1.2.4') and results[k].endpoint == f"{hivemind.LOCALHOST}:1337"
+
+
+def test_dht_single_node():
+    node = hivemind.DHT(start=True)
+    assert node.first_k_active(['e3', 'e2'], k=3) == {}
+    assert node.get_experts(['e3', 'e2']) == [None, None]
+
+    assert all(node.declare_experts(['e1', 'e2', 'e3'], f"{hivemind.LOCALHOST}:1337"))
+    for expert in node.get_experts(['e3', 'e2']):
+        assert expert.endpoint == f"{hivemind.LOCALHOST}:1337"
+    active_found = node.first_k_active(['e0', 'e1', 'e3', 'e5', 'e2'], k=2)
+    assert list(active_found.keys()) == ['e1', 'e3']
+    assert all(expert.uid.startswith(prefix) for prefix, expert in active_found.items())
+
+    assert all(node.declare_experts(['e1', 'e2', 'e3'], f"{hivemind.LOCALHOST}:1337"))
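The tests above pin down the first_k_active contract that the PR's beam-search machinery presumably builds on: given an ordered list of UID prefixes, it returns an ordered mapping from the first k prefixes that have at least one active expert to a matching RemoteExpert. A hedged usage sketch; the 'expert.*' names and the port are illustrative, not part of the test suite:

import hivemind

dht = hivemind.DHT(start=True)
dht.declare_experts(['expert.0.3', 'expert.2.7'], endpoint=f'{hivemind.LOCALHOST}:9000')

# keep only the prefixes that actually lead to live experts, preserving preference order
candidate_prefixes = [f'expert.{i}' for i in range(4)]
active = dht.first_k_active(candidate_prefixes, k=2)
assert list(active) == ['expert.0', 'expert.2']
assert all(isinstance(expert, hivemind.RemoteExpert) for expert in active.values())

dht.shutdown()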

+ 80 - 100
tests/test_dht.py → tests/test_dht_node.py

@@ -1,10 +1,7 @@
-import time
 import asyncio
 import multiprocessing as mp
 import random
 import heapq
-import uuid
-from itertools import chain
 from typing import Optional
 import numpy as np
 
@@ -13,7 +10,7 @@ from typing import List, Dict
 
 from hivemind import get_dht_time
 from hivemind.dht.node import DHTID, Endpoint, DHTNode, LOCALHOST, DHTProtocol
-from hivemind.dht.protocol import LocalStorage
+from hivemind.dht.protocol import DHTProtocol
 
 
 def run_protocol_listener(port: int, dhtid: DHTID, started: mp.synchronize.Event, ping: Optional[Endpoint] = None):
@@ -265,111 +262,94 @@ def test_dht_node():
         proc.terminate()
 
 
-def test_hivemind_dht():
-    peers = [hivemind.DHT(start=True)]
-    for i in range(10):
-        neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
-        peers.append(hivemind.DHT(initial_peers=neighbors_i, start=True))
-
-    you: hivemind.dht.DHT = random.choice(peers)
-    theguyshetoldyounottoworryabout: hivemind.dht.DHT = random.choice(peers)
-
-    expert_uids = [str(uuid.uuid4()) for _ in range(110)]
-    batch_size = 10
-    for batch_start in range(0, len(expert_uids), batch_size):
-        you.declare_experts(expert_uids[batch_start: batch_start + batch_size], 'localhost', 1234)
-
-    found = theguyshetoldyounottoworryabout.get_experts(random.sample(expert_uids, 5) + ['foo', 'bar'])
-    assert all(res is not None for res in found[:-2]), "Could not find some existing experts"
-    assert all(res is None for res in found[-2:]), "Found non-existing experts"
-
-    that_guys_expert, that_guys_port = str(uuid.uuid4()), random.randint(1000, 9999)
-    theguyshetoldyounottoworryabout.declare_experts([that_guys_expert], f'that_host:{that_guys_port}')
-    you_notfound, you_found = you.get_experts(['foobar', that_guys_expert])
-    assert isinstance(you_found, hivemind.RemoteExpert)
-    assert you_found.endpoint == f'that_host:{that_guys_port}'
-
-    # test first_k_active
-    assert list(theguyshetoldyounottoworryabout.first_k_active(expert_uids, k=10)) == expert_uids[:10]
-
-    some_permuted_experts = random.sample(expert_uids, k=32)
-    assert list(theguyshetoldyounottoworryabout.first_k_active(some_permuted_experts, k=32)) == some_permuted_experts
-    assert list(theguyshetoldyounottoworryabout.first_k_active(some_permuted_experts, k=1)) == some_permuted_experts[:1]
-    fake_and_real_experts = list(chain(*zip(
-        [str(uuid.uuid4()) for _ in some_permuted_experts], some_permuted_experts)))
-    assert list(theguyshetoldyounottoworryabout.first_k_active(fake_and_real_experts, k=9)) == some_permuted_experts[:9]
-
-    for peer in peers:
-        peer.shutdown()
-
-
-def test_dht_single_node():
-    node = hivemind.DHT(start=True)
-    assert node.first_k_active(['e3', 'e2'], k=3) == {}
-    assert node.get_experts(['e3', 'e2']) == [None, None]
-
-    assert all(node.declare_experts(['e1', 'e2', 'e3'], f"{hivemind.LOCALHOST}:1337"))
-    for expert in node.get_experts(['e3', 'e2']):
-        assert expert.endpoint == f"{hivemind.LOCALHOST}:1337"
-    active_found = node.first_k_active(['e0', 'e1', 'e3', 'e5', 'e2'], k=2)
-    assert list(active_found.keys()) == ['e1', 'e3']
-    assert all(expert.uid.startswith(prefix) for prefix, expert in active_found.items())
-
-    assert all(node.declare_experts(['e1', 'e2', 'e3'], f"{hivemind.LOCALHOST}:1337"))
-
-
-def test_first_k_active():
-    node = hivemind.DHT(start=True)
-    assert all(node.declare_experts(['e.1.2.3', 'e.1.2.4', 'e.3.4.5'], endpoint=f"{hivemind.LOCALHOST}:1337"))
-    assert all(node.declare_experts(['e.2.1.1'], endpoint=f"{hivemind.LOCALHOST}:1338"))
+def test_dhtnode_caching(T=0.05):
+    test_success = mp.Event()
 
 
-    results = node.first_k_active(['e.0', 'e.1', 'e.2', 'e.3'], k=2)
-    assert len(results) == 2 and next(iter(results.keys())) == 'e.1'
-    assert results['e.1'].uid in ('e.1.2.3', 'e.1.2.4') and results['e.1'].endpoint == f"{hivemind.LOCALHOST}:1337"
-    assert results['e.2'].uid == 'e.2.1.1' and results['e.2'].endpoint == f"{hivemind.LOCALHOST}:1338"
+    async def _tester():
+        node2 = await hivemind.DHTNode.create(cache_refresh_before_expiry=5 * T, reuse_get_requests=False)
+        node1 = await hivemind.DHTNode.create(initial_peers=[f'localhost:{node2.port}'],
+                                              cache_refresh_before_expiry=5 * T, listen=False, reuse_get_requests=False)
+        await node2.store('k', [123, 'value'], expiration_time=hivemind.get_dht_time() + 7 * T)
+        await node2.store('k2', [654, 'value'], expiration_time=hivemind.get_dht_time() + 7 * T)
+        await node2.store('k3', [654, 'value'], expiration_time=hivemind.get_dht_time() + 15 * T)
+        await node1.get_many(['k', 'k2', 'k3', 'k4'])
+        assert len(node1.protocol.cache) == 3
+        assert len(node1.cache_refresh_queue) == 0
+
+        await node1.get_many(['k', 'k2', 'k3', 'k4'])
+        assert len(node1.cache_refresh_queue) == 3
+
+        await node2.store('k', [123, 'value'], expiration_time=hivemind.get_dht_time() + 12 * T)
+        await asyncio.sleep(4 * T)
+        await node1.get('k')
+        await asyncio.sleep(1 * T)
+
+        assert len(node1.protocol.cache) == 3
+        assert len(node1.cache_refresh_queue) == 2
+        await asyncio.sleep(3 * T)
+
+        assert len(node1.cache_refresh_queue) == 1
+
+        await asyncio.sleep(5 * T)
+        assert len(node1.cache_refresh_queue) == 0
+        await asyncio.sleep(5 * T)
+        assert len(node1.cache_refresh_queue) == 0
+
+        await node2.store('k', [123, 'value'], expiration_time=hivemind.get_dht_time() + 10 * T)
+        await node1.get('k')
+        await asyncio.sleep(1 * T)
+        assert len(node1.cache_refresh_queue) == 0
+        await node1.get('k')
+        await asyncio.sleep(1 * T)
+        assert len(node1.cache_refresh_queue) == 1
+
+        await asyncio.sleep(5 * T)
+        assert len(node1.cache_refresh_queue) == 0
+
+        await asyncio.gather(node1.shutdown(), node2.shutdown())
+        test_success.set()
 
 
-    results = node.first_k_active(['e', 'e.1', 'e.1.2', 'e.1.2.3'], k=10)
-    assert len(results) == 4
-    assert 'e' in results
-    for k in ('e.1', 'e.1.2', 'e.1.2.3'):
-        assert results[k].uid in ('e.1.2.3', 'e.1.2.4') and results[k].endpoint == f"{hivemind.LOCALHOST}:1337"
+    proc = mp.Process(target=lambda: asyncio.run(_tester()))
+    proc.start()
+    proc.join()
+    assert test_success.is_set()
 
 
 
 
+def test_dhtnode_reuse_get():
+    test_success = mp.Event()
 
 
-def test_store():
-    d = LocalStorage()
-    d.store(DHTID.generate("key"), b"val", get_dht_time() + 0.5)
-    assert d.get(DHTID.generate("key"))[0] == b"val", "Wrong value"
-    print("Test store passed")
+    async def _tester():
+        peers = []
+        for i in range(10):
+            neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
+            peers.append(await hivemind.DHTNode.create(initial_peers=neighbors_i, parallel_rpc=256))
 
 
+        await asyncio.gather(
+            random.choice(peers).store('k1', 123, hivemind.get_dht_time() + 999),
+            random.choice(peers).store('k2', 567, hivemind.get_dht_time() + 999)
+        )
 
 
-def test_get_expired():
-    d = LocalStorage()
-    d.store(DHTID.generate("key"), b"val", get_dht_time() + 0.1)
-    time.sleep(0.5)
-    assert d.get(DHTID.generate("key")) == (None, None), "Expired value must be deleted"
-    print("Test get expired passed")
+        you = random.choice(peers)
 
 
+        futures1 = await you.get_many(['k1', 'k2'], return_futures=True)
+        assert len(you.pending_get_requests[DHTID.generate('k1')]) == 1
+        assert len(you.pending_get_requests[DHTID.generate('k2')]) == 1
 
 
-def test_get_empty():
-    d = LocalStorage()
-    assert d.get(DHTID.generate(source="key")) == (None, None), "LocalStorage returned non-existent value"
-    print("Test get expired passed")
+        futures2 = await you.get_many(['k2', 'k3'], return_futures=True)
+        assert len(you.pending_get_requests[DHTID.generate('k2')]) == 2
 
 
+        await asyncio.gather(*futures1.values(), *futures2.values())
+        futures3 = await you.get_many(['k3'], return_futures=True)
+        assert len(you.pending_get_requests[DHTID.generate('k1')]) == 0
+        assert len(you.pending_get_requests[DHTID.generate('k2')]) == 0
+        assert len(you.pending_get_requests[DHTID.generate('k3')]) == 1
 
 
-def test_change_expiration_time():
-    d = LocalStorage()
-    d.store(DHTID.generate("key"), b"val1", get_dht_time() + 1)
-    assert d.get(DHTID.generate("key"))[0] == b"val1", "Wrong value"
-    d.store(DHTID.generate("key"), b"val2", get_dht_time() + 200)
-    time.sleep(1)
-    assert d.get(DHTID.generate("key"))[0] == b"val2", "Value must be changed, but still kept in table"
-    print("Test change expiration time passed")
-
+        assert (await futures1['k1'])[0] == 123
+        assert await futures1['k2'] == await futures2['k2'] and (await futures1['k2'])[0] == 567
+        assert await futures2['k3'] == await futures3['k3'] and (await futures3['k3']) == (None, None)
+        test_success.set()
 
 
-def test_maxsize_cache():
-    d = LocalStorage(maxsize=1)
-    d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 1)
-    d.store(DHTID.generate("key2"), b"val2", get_dht_time() + 200)
-    assert d.get(DHTID.generate("key2"))[0] == b"val2", "Value with bigger exp. time must be kept"
-    assert d.get(DHTID.generate("key1"))[0] is None, "Value with less exp time, must be deleted"
+    proc = mp.Process(target=lambda: asyncio.run(_tester()))
+    proc.start()
+    proc.join()
+    assert test_success.is_set()
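Stripped of the multiprocessing harness, test_dhtnode_reuse_get above checks two things: concurrent lookups for the same key share a single pending request, and return_futures=True hands back awaitable futures instead of resolved values. A smaller sketch of the same usage with two localhost nodes; reuse_get_requests is assumed to keep its default (enabled), as the test implies:

import asyncio
import hivemind

async def main():
    node = await hivemind.DHTNode.create()
    peer = await hivemind.DHTNode.create(initial_peers=[f'{hivemind.LOCALHOST}:{node.port}'])
    await node.store('answer', 42, expiration_time=hivemind.get_dht_time() + 60)

    # two overlapping requests for the same key are expected to share one pending lookup
    futures_a = await peer.get_many(['answer'], return_futures=True)
    futures_b = await peer.get_many(['answer'], return_futures=True)
    value, expiration = await futures_a['answer']
    assert value == 42 and (await futures_b['answer'])[0] == 42

    await asyncio.gather(node.shutdown(), peer.shutdown())

asyncio.run(main())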

+ 79 - 0
tests/test_dht_storage.py

@@ -0,0 +1,79 @@
+import time
+
+from hivemind import DHTID, get_dht_time
+from hivemind.dht.protocol import LocalStorage
+
+
+def test_store():
+    d = LocalStorage()
+    d.store(DHTID.generate("key"), b"val", get_dht_time() + 0.5)
+    assert d.get(DHTID.generate("key"))[0] == b"val", "Wrong value"
+    print("Test store passed")
+
+
+def test_get_expired():
+    d = LocalStorage()
+    d.store(DHTID.generate("key"), b"val", get_dht_time() + 0.1)
+    time.sleep(0.5)
+    assert d.get(DHTID.generate("key")) == (None, None), "Expired value must be deleted"
+    print("Test get expired passed")
+
+
+def test_get_empty():
+    d = LocalStorage()
+    assert d.get(DHTID.generate(source="key")) == (None, None), "LocalStorage returned non-existent value"
+    print("Test get expired passed")
+
+
+def test_change_expiration_time():
+    d = LocalStorage()
+    d.store(DHTID.generate("key"), b"val1", get_dht_time() + 1)
+    assert d.get(DHTID.generate("key"))[0] == b"val1", "Wrong value"
+    d.store(DHTID.generate("key"), b"val2", get_dht_time() + 200)
+    time.sleep(1)
+    assert d.get(DHTID.generate("key"))[0] == b"val2", "Value must be changed, but still kept in table"
+    print("Test change expiration time passed")
+
+
+def test_maxsize_cache():
+    d = LocalStorage(maxsize=1)
+    d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 1)
+    d.store(DHTID.generate("key2"), b"val2", get_dht_time() + 200)
+    assert d.get(DHTID.generate("key2"))[0] == b"val2", "Value with bigger exp. time must be kept"
+    assert d.get(DHTID.generate("key1"))[0] is None, "Value with less exp time, must be deleted"
+
+
+def test_localstorage_top():
+    d = LocalStorage(maxsize=3)
+    d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 1)
+    d.store(DHTID.generate("key2"), b"val2", get_dht_time() + 2)
+    d.store(DHTID.generate("key3"), b"val3", get_dht_time() + 4)
+    assert d.top()[:2] == (DHTID.generate("key1"), b"val1")
+
+    d.store(DHTID.generate("key1"), b"val1_new", get_dht_time() + 3)
+    assert d.top()[:2] == (DHTID.generate("key2"), b"val2")
+
+    del d[DHTID.generate('key2')]
+    assert d.top()[:2] == (DHTID.generate("key1"), b"val1_new")
+    d.store(DHTID.generate("key2"), b"val2_new", get_dht_time() + 5)
+    d.store(DHTID.generate("key4"), b"val4", get_dht_time() + 6)  # key4 will push out key1 due to maxsize
+
+    assert d.top()[:2] == (DHTID.generate("key3"), b"val3")
+
+
+def test_localstorage_freeze():
+    d = LocalStorage(maxsize=2)
+
+    with d.freeze():
+        d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 0.01)
+        assert DHTID.generate("key1") in d
+        time.sleep(0.03)
+        assert DHTID.generate("key1") in d
+    assert DHTID.generate("key1") not in d
+
+    with d.freeze():
+        d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 1)
+        d.store(DHTID.generate("key2"), b"val2", get_dht_time() + 2)
+        d.store(DHTID.generate("key3"), b"val3", get_dht_time() + 3)  # key3 will push key1 out due to maxsize
+        assert DHTID.generate("key1") in d
+    assert DHTID.generate("key1") not in d