
Faster beam search: part1/2 (#97)

* [WIP] hivemind.dht.DHT.get_best_experts

* bump

* unused variable

* implement early stopping event for get_many

* implement early stopping event for get_many

* cancel wait_for_termination

* add __contains__ check to LocalStorage

* doc: add single call guarantee

* LocalStorage: add strict typing

* rollback

* add LocalStorage.top

* rollback

* add strict types and builtins

* LocalStorage: add builtins and freeze

* remove_outdated is now private. Rationale: it doesn't always remove outdated elements, better not cause confusion

* caching rev.2 - now it works ;)

* sphinx formatting fix

* add test for caching policy

* pep8

* use separate tester process

* edge case: frozen LocalStorage

* move LocalStorage tests to a separate file

* test for LocalStorage.top

* add tests for new localstorage functionality

* separate tests for DHTNode and experts-related stuff

* separate tests for DHTNode and experts-related stuff

* typo

* add option to reuse pending get requests

* == None -> is None

* WIP that breaks tests

* WIP that breaks tests

* split get_many_by_id into sub-functions, add tests

* remove debugprint

* circleci u mad?

* circleci u mad?

* restore caching

* rm debugprint

* better expiration_time_threshold in DHTNode reuse

* TEMPORARY: fix broken circleci cache

* recache?

* WIP: override cache?

* WIP: override cache?

* WIP: override cache?

* WIP: override cache?

* pep

* add cmd

* rollback auto refactor

* dev1

* bugfix: do not finish_search if there are concurrent workers

* bugfix: do not finish_search if there are concurrent workers

* bugfix: do not finish_search if there are concurrent workers

* remove find_best_experts for now

* we can probably still finish search for a query if it has no concurrent workers

* update benchmark_dht

* rollback changes

* rollback changes

* rollback changes

* unused import

* auto refactor typo

* typo

* update benchmarks

* fix broken sphinx url

* misc renames

* misc renames

* Update docs/user/contributing.md

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>

* Update docs/user/contributing.md

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>

* address review by mryab

* address review by mryab

* address review by mryab

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic committed 5 years ago
parent commit 9c1e14aca1

+ 1 - 0
.circleci/config.yml

@@ -9,6 +9,7 @@ jobs:
     steps:
       - checkout
       - python/load-cache
+      - run: pip uninstall -y pytest codecov  # temporary override for broken cache
       - run: pip install codecov pytest tqdm scikit-learn
       - python/install-deps
       - python/save-cache

+ 12 - 11
docs/user/contributing.md

@@ -1,4 +1,4 @@
-## Contributing
+## Developer zone
 
 #### Collaborating best practices:
 Hivemind is still in the early stage of development, we expect only a handful of collaborators with individual roles.
@@ -19,7 +19,7 @@ Hivemind is still in the early stage of development, we expect only a handful of
    * If you face any challenges or want feedback, please submit a [draft](https://github.blog/2019-02-14-introducing-draft-pull-requests/) pull request.
 
 
-#### Contributor's manual
+#### Developer quickstart
 
 First, install hivemind in the development mode, preferably with python 3.8 on linux/mac_OS.
 ```
@@ -98,23 +98,24 @@ to measure performance impact of changes to hivemind.dht. It spawns a DHT with `
 then chooses one peer that will declare `num_experts` total experts in batches of `expert_batch_size`.
 Then, another peer will consecutively get all peers and check if they are there.
 
-Here's a run with 1024 participants on the same machine that was used benchmark_throughput:
+Here's a run with 1024 participants on the same machine that was used for benchmark_throughput:
 
+`python benchmark_dht.py --num_peers 1024 --num_experts 16384 --expert_batch_size 64 --expiration 99999 --increase_file_limit`
 <details style="margin-top:-24px; margin-bottom: 16px;">
   <summary>Console outputs</summary>
   
   ```sh
 Increasing file limit - soft 1024=>32768, hard 1048576=>32768
 Creating peers...
-100%|███████████████████████████████████████████████████| 1024/1024 [01:51<00:00,  9.22it/s]
+100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1024/1024 [01:45<00:00,  9.74it/s]
 Sampled 16384 unique ids (after deduplication)
 Storing peers to dht in batches of 64...
-100%|█████████████████████████████████████████████████████| 256/256 [13:00<00:00,  3.05s/it]
-Store success rate: 100.0% (48904 / 48904)
-Mean store time: 0.015967, Total: 780.85
-100%|█████████████████████████████████████████████████████| 256/256 [02:01<00:00,  2.11it/s]
-Get success rate: 100.0 (16383 / 16384)
-Mean get time: 0.00740, Total: 121.29011
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 256/256 [12:07<00:00,  2.84s/it]
+Store success rate: 100.0% (48920 / 48920)
+Mean store time: 0.01487, Total: 727.46
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 256/256 [01:48<00:00,  2.35it/s]
+Get success rate: 100.0 (16384 / 16384)
+Mean get time: 0.00664, Total: 108.73952
 Node survival rate: 100.000%
   ```
 </details>
@@ -125,6 +126,6 @@ If one wants to account for these factors, one must introduce them manually by c
   
 
 #### Tips & tricks
-* You can find a wealth of pytorch debugging tricks at [their contributing page](https://github.com/pytorch/pytorch/blob/master/CONTRIBUTING.md).
+* You can find a wealth of pytorch debugging tricks at [their contributing page](https://tinyurl.com/pytorch-contributing).
 * Hivemind is optimized for development in pycharm CE 2019.3 or newer.
   * When working on tests, please mark "tests" as sources root.

+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.utils import *
 
-__version__ = '0.8.2'
+__version__ = '0.8.3'

+ 4 - 2
hivemind/dht/__init__.py

@@ -25,7 +25,9 @@ import uvloop
 from hivemind.client import RemoteExpert
 from hivemind.dht.node import DHTNode, DHTID, DHTExpiration
 from hivemind.dht.routing import get_dht_time
-from hivemind.utils import MPFuture, Endpoint
+from hivemind.utils import MPFuture, Endpoint, get_logger
+
+logger = get_logger(__name__)
 
 
 class DHT(mp.Process):
@@ -155,7 +157,7 @@ class DHT(mp.Process):
         :param uids: a list of expert ids to update
         :param endpoint: endpoint that serves these experts, usually your server endpoint (e.g. "201.111.222.333:1337")
         :param wait: if True, awaits for declaration to finish, otherwise runs in background
-        :param timeout: waits for the procedure to finish, None means wait indeninitely
+        :param timeout: waits for the procedure to finish for up to this long, None means wait indefinitely
         :returns: if wait, returns a list of booleans, (True = store succeeded, False = store rejected)
         """
         assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
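
Below is a condensed sketch of the declare_experts / get_experts round-trip documented above, modeled on the new tests/test_dht_experts.py in this PR (the endpoint and expert uids are arbitrary):

```python
import hivemind

# start a standalone DHT process and register a few experts on it
node = hivemind.DHT(start=True)
assert all(node.declare_experts(['e1', 'e2', 'e3'], f"{hivemind.LOCALHOST}:1337"))

# get_experts returns a RemoteExpert handle per found uid and None for missing uids
expert_e3, missing = node.get_experts(['e3', 'nonexistent_expert'])
assert isinstance(expert_e3, hivemind.RemoteExpert)
assert expert_e3.endpoint == f"{hivemind.LOCALHOST}:1337"
assert missing is None

node.shutdown()
```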

+ 239 - 77
hivemind/dht/node.py

@@ -1,15 +1,21 @@
 from __future__ import annotations
 
 import asyncio
+
 import random
-from collections import namedtuple
-from typing import Optional, Tuple, List, Dict, Collection, Union, Set
+from collections import defaultdict
+from dataclasses import dataclass, field
+from typing import Optional, Tuple, List, Dict, DefaultDict, Collection, Union, Set, Awaitable, Callable, Any, Iterable
+from sortedcontainers import SortedList
+from functools import partial
 from warnings import warn
 
-from hivemind.dht.protocol import DHTProtocol
-from hivemind.dht.routing import DHTID, DHTExpiration, DHTKey, get_dht_time, DHTValue
+from hivemind.dht.protocol import DHTProtocol, LocalStorage
+from hivemind.dht.routing import DHTID, DHTExpiration, DHTKey, get_dht_time, DHTValue, BinaryDHTValue
 from hivemind.dht.traverse import traverse_dht
-from hivemind.utils import Endpoint, LOCALHOST, MSGPackSerializer
+from hivemind.utils import Endpoint, LOCALHOST, MSGPackSerializer, get_logger, SerializerBase
+
+logger = get_logger(__name__)
 
 
 class DHTNode:
@@ -45,8 +51,10 @@ class DHTNode:
 
     """
     # fmt:off
-    node_id: DHTID; port: int; num_replicas: int; cache_locally: bool; cache_nearest: int; num_workers: int
-    refresh_timeout: float; protocol: DHTProtocol
+    node_id: DHTID; is_alive: bool; port: int; num_replicas: int; num_workers: int; protocol: DHTProtocol
+    refresh_timeout: float; cache_locally: bool; cache_nearest: int; cache_refresh_before_expiry: float
+    cache_refresh_available: asyncio.Event; cache_refresh_queue: LocalStorage
+    reuse_get_requests: bool; pending_get_requests: DefaultDict[DHTID, SortedList[_IntermediateResult]]
     serializer = MSGPackSerializer  # used to pack/unpack DHT Values for transfer over network
     # fmt:on
 
@@ -55,8 +63,9 @@ class DHTNode:
             cls, node_id: Optional[DHTID] = None, initial_peers: List[Endpoint] = (),
             bucket_size: int = 20, num_replicas: int = 5, depth_modulo: int = 5, parallel_rpc: int = None,
             wait_timeout: float = 5, refresh_timeout: Optional[float] = None, bootstrap_timeout: Optional[float] = None,
-            num_workers: int = 1, cache_locally: bool = True, cache_nearest: int = 1, cache_size=None,
-            listen: bool = True, listen_on: Endpoint = "0.0.0.0:*", **kwargs) -> DHTNode:
+            cache_locally: bool = True, cache_nearest: int = 1, cache_size=None, cache_refresh_before_expiry: float = 5,
+            reuse_get_requests: bool = True, num_workers: int = 1, listen: bool = True,
+            listen_on: Endpoint = "0.0.0.0:*", **kwargs) -> DHTNode:
         """
         :param node_id: current node's identifier, determines which keys it will store locally, defaults to random id
         :param initial_peers: connects to these peers to populate routing table, defaults to no peers
@@ -71,11 +80,15 @@ class DHTNode:
         :param refresh_timeout: refresh buckets if no node from that bucket was updated in this many seconds
           if staleness_timeout is None, DHTNode will not refresh stale buckets (which is usually okay)
         :param bootstrap_timeout: after one of peers responds, await other peers for at most this many seconds
-        :param num_workers: concurrent workers in traverse_dht (see traverse_dht num_workers param)
         :param cache_locally: if True, caches all values (stored or found) in a node-local cache
         :param cache_nearest: whenever DHTNode finds a value, it will also store (cache) this value on this many
           nodes nearest nodes visited by search algorithm. Prefers nodes that are nearest to :key: but have no value yet
         :param cache_size: if specified, local cache will store up to this many records (as in LRU cache)
+        :param cache_refresh_before_expiry: if nonzero, refreshes locally cached values
+          if they are accessed this many seconds before expiration time.
+        :param reuse_get_requests: if True, DHTNode allows only one traverse_dht procedure for every key;
+          all concurrent get requests for the same key will reuse the procedure that is currently in progress
+        :param num_workers: concurrent workers in traverse_dht (see traverse_dht num_workers param)
         :param listen: if True (default), this node will accept incoming request and otherwise be a DHT "citzen"
           if False, this node will refuse any incoming request, effectively being only a "client"
         :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
@@ -83,11 +96,26 @@ class DHTNode:
           see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
         :param kwargs: extra parameters used in grpc.aio.server
         """
+        if cache_refresh_before_expiry > 0 and not cache_locally:
+            logger.warning("If cache_locally is False, cache_refresh_before_expiry has no effect. To silence this"
+                           " warning, please specify cache_refresh_before_expiry=0")
+
         self = cls(_initialized_with_create=True)
         self.node_id = node_id = node_id if node_id is not None else DHTID.generate()
         self.num_replicas, self.num_workers = num_replicas, num_workers
-        self.cache_locally, self.cache_nearest = cache_locally, cache_nearest
+        self.is_alive = True  # if set to False, cancels all background jobs such as routing table refresh
+
+        self.reuse_get_requests = reuse_get_requests
+        self.pending_get_requests = defaultdict(partial(SortedList, key=lambda _res: - _res.sufficient_expiration_time))
+
+        # caching policy
         self.refresh_timeout = refresh_timeout
+        self.cache_locally, self.cache_nearest = cache_locally, cache_nearest
+        self.cache_refresh_before_expiry = cache_refresh_before_expiry
+        self.cache_refresh_queue = LocalStorage()
+        self.cache_refresh_available = asyncio.Event()
+        if cache_refresh_before_expiry:
+            asyncio.create_task(self._refresh_stale_cache_entries())
 
         self.protocol = await DHTProtocol.create(self.node_id, bucket_size, depth_modulo, num_replicas, wait_timeout,
                                                  parallel_rpc, cache_size, listen, listen_on, **kwargs)
@@ -129,7 +157,9 @@ class DHTNode:
 
     async def shutdown(self, timeout=None):
         """ Process existing requests, close all connections and stop the server """
-        await self.protocol.shutdown(timeout)
+        self.is_alive = False
+        if self.protocol.server:
+            await self.protocol.shutdown(timeout)
 
     async def find_nearest_nodes(
             self, queries: Collection[DHTID], k_nearest: Optional[int] = None, beam_size: Optional[int] = None,
@@ -157,15 +187,15 @@ class DHTNode:
                 node_to_endpoint.update(
                     self.protocol.routing_table.get_nearest_neighbors(query, beam_size, exclude=self.node_id))
 
-        async def get_neighbors(peer: DHTID, queries: Collection[DHTID]) -> Dict[DHTID, Tuple[List[DHTID], bool]]:
+        async def get_neighbors(peer: DHTID, queries: Collection[DHTID]) -> Dict[DHTID, Tuple[Tuple[DHTID], bool]]:
             response = await self.protocol.call_find(node_to_endpoint[peer], queries)
             if not response:
                 return {query: ([], False) for query in queries}
 
-            output: Dict[DHTID, Tuple[List[DHTID], bool]] = {}
+            output: Dict[DHTID, Tuple[Tuple[DHTID], bool]] = {}
             for query, (_, _, peers) in response.items():
                 node_to_endpoint.update(peers)
-                output[query] = list(peers.keys()), False  # False means "do not interrupt search"
+                output[query] = tuple(peers.keys()), False  # False means "do not interrupt search"
             return output
 
         nearest_nodes_per_query, visited_nodes = await traverse_dht(
@@ -289,7 +319,7 @@ class DHTNode:
         Search for a key across DHT and return either first or latest entry.
         :param key: same key as in node.store(...)
         :param latest: if True, finds the latest value, otherwise finds any non-expired value (which is much faster)
-        :param kwargs: parameters forwarded to get_many
+        :param kwargs: parameters forwarded to get_many_by_id
         :returns: (value, expiration time); if value was not found, returns (None, None)
         """
         if latest:
@@ -297,100 +327,190 @@ class DHTNode:
         result = await self.get_many([key])
         return result[key]
 
-    async def get_many(
-            self, keys: Collection[DHTKey], sufficient_expiration_time: Optional[DHTExpiration] = None,
-            num_workers: Optional[int] = None, beam_size: Optional[int] = None
-    ) -> Dict[DHTKey, Tuple[Optional[DHTValue], Optional[DHTExpiration]]]:
+    async def get_many(self, keys: Collection[DHTKey], sufficient_expiration_time: Optional[DHTExpiration] = None,
+                       **kwargs) -> Dict[DHTKey, Union[Tuple[Optional[DHTValue], Optional[DHTExpiration]],
+                                                       Awaitable[Tuple[Optional[DHTValue], Optional[DHTExpiration]]]]]:
         """
+        Traverse DHT to find a list of keys. For each key, return latest (value, expiration) or None if not found.
+
         :param keys: traverse the DHT and find the value for each of these keys (or (None, None) if not key found)
+        :param sufficient_expiration_time: if the search finds a value that expires after this time,
+            default = time of call, find any value that did not expire by the time of call
+            If sufficient_expiration_time=float('inf'), this method will find a value with the _latest_ expiration
+        :param kwargs: for full list of parameters, see DHTNode.get_many_by_id
+        :returns: for each key: value and its expiration time. If nothing is found, returns (None, None) for that key
+        :note: in order to check if get returned a value, please check if (expiration_time is None)
+        """
+        keys = tuple(keys)
+        key_ids = [DHTID.generate(key) for key in keys]
+        id_to_original_key = dict(zip(key_ids, keys))
+        results_by_id = await self.get_many_by_id(key_ids, sufficient_expiration_time, **kwargs)
+        return {id_to_original_key[key]: result_or_future for key, result_or_future in results_by_id.items()}
+
+    async def get_many_by_id(
+            self, key_ids: Collection[DHTID], sufficient_expiration_time: Optional[DHTExpiration] = None,
+            num_workers: Optional[int] = None, beam_size: Optional[int] = None, return_futures: bool = False,
+            _refresh_cache=True) -> Dict[DHTID, Union[Tuple[Optional[DHTValue], Optional[DHTExpiration]],
+                                                      Awaitable[Tuple[Optional[DHTValue], Optional[DHTExpiration]]]]]:
+        """
+        Traverse DHT to find a list of DHTIDs. For each key, return latest (value, expiration) or None if not found.
+
+        :param key_ids: traverse the DHT and find the value for each of these keys (or (None, None) if not key found)
         :param sufficient_expiration_time: if the search finds a value that expires after this time,
             default = time of call, find any value that did not expire by the time of call
             If min_expiration_time=float('inf'), this method will find a value with _latest_ expiration
         :param beam_size: maintains up to this many nearest nodes when crawling dht, default beam_size = bucket_size
         :param num_workers: override for default num_workers, see traverse_dht num_workers param
-        :returns: for each key: value and its expiration time. If nothing is found , returns (None, None) for that key
+        :param return_futures: if True, immediately return asyncio.Future for every key before interacting with the network.
+         The algorithm will populate these futures with (value, expiration) when it finds the corresponding key
+         Note: canceling a future will stop search for the corresponding key
+        :param _refresh_cache: internal flag, whether or not to call self._trigger_cache_refresh
+        :returns: for each key: value and its expiration time. If nothing is found, returns (None, None) for that key
         :note: in order to check if get returned a value, please check (expiration_time is None)
         """
-        key_ids = [DHTID.generate(key) for key in keys]
-        id_to_original_key = dict(zip(key_ids, keys))
         sufficient_expiration_time = sufficient_expiration_time or get_dht_time()
         beam_size = beam_size if beam_size is not None else self.protocol.bucket_size
         num_workers = num_workers if num_workers is not None else self.num_workers
+        search_results: Dict[DHTID, _IntermediateResult] = {key_id: _IntermediateResult(
+            key_id, sufficient_expiration_time, serializer=self.serializer) for key_id in key_ids}
 
-        # search metadata
-        unfinished_key_ids = set(key_ids)  # track key ids for which the search is not terminated
-        node_to_endpoint: Dict[DHTID, Endpoint] = dict()  # global routing table for all queries
+        if _refresh_cache:
+            for key_id in key_ids:
+                search_results[key_id].add_done_callback(self._trigger_cache_refresh)
 
-        SearchResult = namedtuple("SearchResult", ["binary_value", "expiration_time", "source_node_id"])
-        latest_results = {key_id: SearchResult(b'', -float('inf'), None) for key_id in key_ids}
+        # if we have concurrent get request for some of the same keys, subscribe to their results
+        if self.reuse_get_requests:
+            for key_id, search_result in search_results.items():
+                self.pending_get_requests[key_id].add(search_result)
+                search_result.add_done_callback(self._reuse_finished_search_result)
 
-        # stage 1: value can be stored in our local cache
+        # stage 1: check for value in this node's local storage and cache
         for key_id in key_ids:
-            maybe_value, maybe_expiration_time = self.protocol.storage.get(key_id)
-            if maybe_expiration_time is None:
-                maybe_value, maybe_expiration_time = self.protocol.cache.get(key_id)
-            if maybe_expiration_time is not None and maybe_expiration_time > latest_results[key_id].expiration_time:
-                latest_results[key_id] = SearchResult(maybe_value, maybe_expiration_time, self.node_id)
-                if maybe_expiration_time >= sufficient_expiration_time:
-                    unfinished_key_ids.remove(key_id)
-
-        # stage 2: traverse the DHT for any unfinished keys
+            search_results[key_id].add_candidate(*self.protocol.storage.get(key_id), source_node_id=self.node_id)
+            search_results[key_id].add_candidate(*self.protocol.cache.get(key_id), source_node_id=self.node_id)
+
+        # stage 2: traverse the DHT to get the remaining keys from remote peers
+        unfinished_key_ids = [key_id for key_id in key_ids if not search_results[key_id].finished]
+        node_to_endpoint: Dict[DHTID, Endpoint] = dict()  # global routing table for all keys
         for key_id in unfinished_key_ids:
             node_to_endpoint.update(self.protocol.routing_table.get_nearest_neighbors(
                 key_id, self.protocol.bucket_size, exclude=self.node_id))
 
-        async def get_neighbors(peer: DHTID, queries: Collection[DHTID]) -> Dict[DHTID, Tuple[List[DHTID], bool]]:
+        # V-- this function will be called every time traverse_dht decides to request neighbors from a remote peer
+        async def get_neighbors(peer: DHTID, queries: Collection[DHTID]) -> Dict[DHTID, Tuple[Tuple[DHTID], bool]]:
             queries = list(queries)
             response = await self.protocol.call_find(node_to_endpoint[peer], queries)
             if not response:
                 return {query: ([], False) for query in queries}
 
-            output: Dict[DHTID, Tuple[List[DHTID], bool]] = {}
-            for key_id, (maybe_value, maybe_expiration_time, peers) in response.items():
+            output: Dict[DHTID, Tuple[Tuple[DHTID], bool]] = {}
+            for key_id, (maybe_value_bytes, maybe_expiration_time, peers) in response.items():
                 node_to_endpoint.update(peers)
-                if maybe_expiration_time is not None and maybe_expiration_time > latest_results[key_id].expiration_time:
-                    latest_results[key_id] = SearchResult(maybe_value, maybe_expiration_time, peer)
-                should_interrupt = (latest_results[key_id].expiration_time >= sufficient_expiration_time)
-                output[key_id] = list(peers.keys()), should_interrupt
+                search_results[key_id].add_candidate(maybe_value_bytes, maybe_expiration_time, source_node_id=peer)
+                output[key_id] = tuple(peers.keys()), search_results[key_id].finished
+                # note: we interrupt the search if the key is either found or otherwise finished (e.g. cancelled by user)
             return output
 
-        nearest_nodes_per_query, visited_nodes = await traverse_dht(
+        # V-- this function will be called exactly once when traverse_dht finishes search for a given key
+        async def found_callback(key_id: DHTID, nearest_nodes: List[DHTID], _visited: Set[DHTID]):
+            search_results[key_id].finish_search()  # finish search whether or not we found something
+            self._cache_new_result(search_results[key_id], nearest_nodes, node_to_endpoint)
+
+        asyncio.create_task(traverse_dht(
             queries=list(unfinished_key_ids), initial_nodes=list(node_to_endpoint),
             beam_size=beam_size, num_workers=num_workers, queries_per_call=int(len(unfinished_key_ids) ** 0.5),
-            get_neighbors=get_neighbors, visited_nodes={key_id: {self.node_id} for key_id in unfinished_key_ids})
-
-        # stage 3: cache any new results depending on caching parameters
-        for key_id, nearest_nodes in nearest_nodes_per_query.items():
-            latest_value_bytes, latest_expiration_time, latest_node_id = latest_results[key_id]
-            should_cache = latest_expiration_time >= sufficient_expiration_time  # if we found a newer value, cache it
-            if should_cache and self.cache_locally:
-                self.protocol.cache.store(key_id, latest_value_bytes, latest_expiration_time)
-
-            if should_cache and self.cache_nearest:
-                num_cached_nodes = 0
-                for node_id in nearest_nodes:
-                    if node_id == latest_node_id:
-                        continue
-                    asyncio.create_task(self.protocol.call_store(
-                        node_to_endpoint[node_id], [key_id], [latest_value_bytes], [latest_expiration_time],
-                        in_cache=True))
-                    num_cached_nodes += 1
-                    if num_cached_nodes >= self.cache_nearest:
-                        break
-
-        # stage 4: deserialize data and assemble function output
-        find_result: Dict[DHTKey, Tuple[Optional[DHTValue], Optional[DHTExpiration]]] = {}
-        for key_id, (latest_value_bytes, latest_expiration_time, _) in latest_results.items():
-            if latest_expiration_time != -float('inf'):
-                latest_value = self.serializer.loads(latest_value_bytes)
-                find_result[id_to_original_key[key_id]] = (latest_value, latest_expiration_time)
-            else:
-                find_result[id_to_original_key[key_id]] = None, None
-        return find_result
+            get_neighbors=get_neighbors, visited_nodes={key_id: {self.node_id} for key_id in unfinished_key_ids},
+            found_callback=found_callback, await_all_tasks=False))
+
+        if return_futures:
+            return {key_id: search_result.future for key_id, search_result in search_results.items()}
+        else:
+            try:
+                # note: this should be the first time we await something, so there's no need to "try" the entire function
+                return {key_id: await search_result.future for key_id, search_result in search_results.items()}
+            except asyncio.CancelledError as e:  # terminate remaining tasks ASAP
+                for key_id, search_result in search_results.items():
+                    search_result.future.cancel()
+                raise e
+
+    def _reuse_finished_search_result(self, finished: _IntermediateResult):
+        expiration_time_threshold = max(finished.expiration_time or -float('inf'), finished.sufficient_expiration_time)
+        concurrent_requests: SortedList[_IntermediateResult] = self.pending_get_requests[finished.key_id]
+        # note: concurrent_requests is sorted in the order of descending sufficient_expiration_time
+        while concurrent_requests and expiration_time_threshold >= concurrent_requests[-1].sufficient_expiration_time:
+            concurrent_requests[-1].add_candidate(finished.binary_value, finished.expiration_time,
+                                                  source_node_id=finished.source_node_id)
+            concurrent_requests[-1].finish_search()
+            concurrent_requests.pop(-1)
+
+    def _trigger_cache_refresh(self, result: _IntermediateResult):
+        """ Called after get request is finished (whether it was found, not found, hit cache, cancelled, or reused) """
+        if result.found_something and result.source_node_id == self.node_id:
+            with self.protocol.cache.freeze():  # do not clear outdated cache for now...
+                if self.cache_refresh_before_expiry and result.key_id in self.protocol.cache:
+                    previous_earliest_item: Tuple[DHTID, BinaryDHTValue, DHTExpiration] = self.cache_refresh_queue.top()
+                    self.cache_refresh_queue.store(result.key_id, result.binary_value, result.expiration_time)
+                    if previous_earliest_item is None or result.expiration_time < previous_earliest_item[-1]:
+                        self.cache_refresh_available.set()  # if the new element is now earliest, notify the cache queue
+
+    async def _refresh_stale_cache_entries(self):
+        """ periodically refresh near-expired keys that were accessed at least once during their previous lifetime """
+        while self.is_alive:
+            with self.cache_refresh_queue.freeze():
+                while len(self.cache_refresh_queue) == 0:
+                    await self.cache_refresh_available.wait()
+                    self.cache_refresh_available.clear()
+                key_id, _, nearest_expiration = self.cache_refresh_queue.top()
+
+            try:
+                # step 1: await until :cache_refresh_before_expiry: seconds before earliest first element expires
+                time_to_wait = nearest_expiration - get_dht_time() - self.cache_refresh_before_expiry
+                await asyncio.wait_for(self.cache_refresh_available.wait(), timeout=time_to_wait)
+                # note: the line above will cause TimeoutError when we are ready to refresh cache
+                self.cache_refresh_available.clear()  # no timeout error => someone added new entry to queue and ...
+                continue  # ... and this element is earlier than nearest_expiration. we should refresh this entry first
+
+            except asyncio.TimeoutError:  # caught TimeoutError => it is time to refresh the most recent cached entry
+                # step 2: find all keys that we should already refresh and remove them from queue
+                with self.cache_refresh_queue.freeze():
+                    keys_to_refresh = {key_id}
+                    del self.cache_refresh_queue[key_id]  # we pledge to refresh this key_id in the nearest batch
+                    while self.cache_refresh_queue:
+                        key_id, _, nearest_expiration = self.cache_refresh_queue.top()
+                        if nearest_expiration > get_dht_time() + self.cache_refresh_before_expiry:
+                            break
+                        del self.cache_refresh_queue[key_id]  # we pledge to refresh this key_id in the nearest batch
+                        keys_to_refresh.add(key_id)
+
+                # step 3: search newer versions of these keys, cache them as a side-effect of self.get_many_by_id
+                await self.get_many_by_id(
+                    keys_to_refresh, sufficient_expiration_time=nearest_expiration + self.cache_refresh_before_expiry,
+                    _refresh_cache=False)  # if we found value locally, we shouldn't trigger another refresh
+
+    def _cache_new_result(self, result: _IntermediateResult, nearest_nodes: List[DHTID],
+                          node_to_endpoint: Dict[DHTID, Endpoint]):
+        """ after key_id is found, update cache according to caching policy. used internally in get and get_many """
+        if result.found_something:
+            previous_expiration_time = max(self.protocol.storage.get(result.key_id)[1] or -float('inf'),
+                                           self.protocol.cache.get(result.key_id)[1] or -float('inf'))
+            if result.expiration_time > previous_expiration_time:  # if this value has better expiration
+                if self.cache_locally:
+                    self.protocol.cache.store(result.key_id, result.binary_value, result.expiration_time)
+                if self.cache_nearest:
+                    num_cached_nodes = 0
+                    for node_id in nearest_nodes:
+                        if node_id == result.source_node_id:
+                            continue
+                        asyncio.create_task(self.protocol.call_store(
+                            node_to_endpoint[node_id], [result.key_id], [result.binary_value], [result.expiration_time],
+                            in_cache=True))
+                        num_cached_nodes += 1
+                        if num_cached_nodes >= self.cache_nearest:
+                            break
 
     async def _refresh_routing_table(self, *, period: Optional[float]) -> None:
         """ Tries to find new nodes for buckets that were unused for more than self.staleness_timeout """
-        while period is not None:  # if None run once, otherwise run forever
+        while self.is_alive and period is not None:  # if None run once, otherwise run forever
             refresh_time = get_dht_time()
             staleness_threshold = refresh_time - period
             stale_buckets = [bucket for bucket in self.protocol.routing_table.buckets
@@ -400,3 +520,45 @@ class DHTNode:
                 await self.find_nearest_nodes(refresh_id)
 
             await asyncio.sleep(max(0.0, period - (get_dht_time() - refresh_time)))
+
+
+@dataclass(init=True, repr=True, frozen=False, order=False)
+class _IntermediateResult:
+    """ A helper class that stores current-best GET results with metadata """
+    key_id: DHTID
+    sufficient_expiration_time: DHTExpiration
+    binary_value: Optional[BinaryDHTValue] = None
+    expiration_time: Optional[DHTExpiration] = None  # best expiration time so far
+    source_node_id: Optional[DHTID] = None  # node that gave us the value
+    future: asyncio.Future[Tuple[Optional[DHTValue], Optional[DHTExpiration]]] = field(default_factory=asyncio.Future)
+    serializer: type(SerializerBase) = MSGPackSerializer
+
+    def add_candidate(self, binary_value: Optional[BinaryDHTValue], expiration_time: Optional[DHTExpiration],
+                      source_node_id: Optional[DHTID]):
+        if not self.finished and (expiration_time or -float('inf')) > (self.expiration_time or -float('inf')):
+            self.binary_value, self.expiration_time, self.source_node_id = binary_value, expiration_time, source_node_id
+            if self.expiration_time >= self.sufficient_expiration_time:
+                self.finish_search()
+
+    def add_done_callback(self, callback: Callable[[_IntermediateResult], Any]):
+        """ Add callback that will be called when _IntermediateResult is done (found OR cancelled by user) """
+        self.future.add_done_callback(lambda _future: callback(self))
+
+    def finish_search(self):
+        if self.future.done():
+            return  # either user cancelled our result or someone sent it before us. Nothing more to do here.
+        deserialized_value = self.serializer.loads(self.binary_value) if self.found_something else None
+        self.future.set_result((deserialized_value, self.expiration_time))
+
+    @property
+    def found_something(self) -> bool:
+        """ Whether or not we have at least some result, regardless of its expiration time """
+        return self.expiration_time is not None
+
+    @property
+    def finished(self) -> bool:
+        return self.future.done()
+
+    def __lt__(self, other: _IntermediateResult):
+        """ _IntermediateResult instances will be sorted by their target expiration time """
+        return self.sufficient_expiration_time < other.sufficient_expiration_time
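
For reference, the reworked get path can be driven either synchronously or through the new return_futures flag. A minimal sketch modeled on tests/test_dht_node.py from this PR (two local nodes, illustrative key 'k1'):

```python
import asyncio
import hivemind

async def demo():
    # two DHT nodes on localhost; node2 stores a value, node1 retrieves it
    node2 = await hivemind.DHTNode.create()
    node1 = await hivemind.DHTNode.create(initial_peers=[f'localhost:{node2.port}'])

    await node2.store('k1', 123, expiration_time=hivemind.get_dht_time() + 60)

    # return_futures=True hands back one asyncio future per key right away;
    # awaiting a future yields (value, expiration_time), or (None, None) if not found
    futures = await node1.get_many(['k1', 'missing_key'], return_futures=True)
    value, expiration_time = await futures['k1']
    assert value == 123
    assert await futures['missing_key'] == (None, None)

    await asyncio.gather(node1.shutdown(), node2.shutdown())

asyncio.run(demo())
```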

+ 50 - 12
hivemind/dht/protocol.py

@@ -3,6 +3,7 @@ from __future__ import annotations
 
 import asyncio
 import heapq
+from contextlib import contextmanager
 from typing import Optional, List, Tuple, Dict, Iterator, Any, Sequence, Union, Collection
 from warnings import warn
 
@@ -265,16 +266,17 @@ class LocalStorage:
 
     def __init__(self, maxsize: Optional[int] = None):
         self.cache_size = maxsize or float("inf")
-        self.data = dict()
-        self.expiration_heap = []
-        self.key_to_heap = dict()
-
-    def remove_outdated(self):
-        while self.expiration_heap and (self.expiration_heap[0][0] < get_dht_time()
-                                        or len(self.expiration_heap) > self.cache_size):
+        self.data: Dict[DHTID, Tuple[BinaryDHTValue, DHTExpiration]] = dict()
+        self.expiration_heap: List[Tuple[DHTExpiration, DHTID]] = []
+        self.key_to_heap: Dict[DHTID, Tuple[DHTExpiration, DHTID]] = dict()
+        self.frozen = False  # if True, do not remove outdated elements
+
+    def _remove_outdated(self):
+        while not self.frozen and self.expiration_heap and (self.expiration_heap[0][0] < get_dht_time()
+                                                            or len(self.expiration_heap) > self.cache_size):
             heap_entry = heapq.heappop(self.expiration_heap)
             key = heap_entry[1]
-            if self.key_to_heap[key] == heap_entry:
+            if self.key_to_heap.get(key) == heap_entry:
                 del self.data[key], self.key_to_heap[key]
 
     def store(self, key: DHTID, value: BinaryDHTValue, expiration_time: DHTExpiration) -> bool:
@@ -282,7 +284,7 @@ class LocalStorage:
         Store a (key, value) pair locally at least until expiration_time. See class docstring for details.
         :returns: True if new value was stored, False it was rejected (current value is newer)
         """
-        if expiration_time < get_dht_time():
+        if expiration_time < get_dht_time() and not self.frozen:
             return False
         self.key_to_heap[key] = (expiration_time, key)
         heapq.heappush(self.expiration_heap, (expiration_time, key))
@@ -292,17 +294,53 @@ class LocalStorage:
                 return True
             return False
         self.data[key] = (value, expiration_time)
-        self.remove_outdated()
+        self._remove_outdated()
         return True
 
     def get(self, key: DHTID) -> (Optional[BinaryDHTValue], Optional[DHTExpiration]):
         """ Get a value corresponding to a key if that (key, value) pair was previously stored here. """
-        self.remove_outdated()
+        self._remove_outdated()
         if key in self.data:
             return self.data[key]
         return None, None
 
     def items(self) -> Iterator[Tuple[DHTID, BinaryDHTValue, DHTExpiration]]:
         """ Iterate over (key, value, expiration_time) tuples stored in this storage """
-        self.remove_outdated()
+        self._remove_outdated()
         return ((key, value, expiration_time) for key, (value, expiration_time) in self.data.items())
+
+    def top(self) -> Optional[Tuple[DHTID, BinaryDHTValue, DHTExpiration]]:
+        """ Return the entry with earliest expiration or None if there isn't any """
+        self._remove_outdated()
+        if self.data:
+            top_entry, top_key = self.expiration_heap[0], self.expiration_heap[0][1]
+            while self.key_to_heap.get(top_key) != top_entry:
+                heapq.heappop(self.expiration_heap)  # skip leftover "ghost" entries until first real entry
+                top_entry, top_key = self.expiration_heap[0], self.expiration_heap[0][1]
+            value, expiration = self.data[top_key]
+            return top_key, value, expiration
+
+    def __contains__(self, key: DHTID):
+        self._remove_outdated()
+        return key in self.data
+
+    def __len__(self):
+        self._remove_outdated()
+        return len(self.data)
+
+    def __delitem__(self, key: DHTID):
+        if key in self.key_to_heap:
+            del self.data[key], self.key_to_heap[key]
+        # note: key may still be in self.expiration_heap, but it will not be used and will eventually be removed by ._remove_outdated()
+
+    def __bool__(self):
+        return bool(self.data)
+
+    @contextmanager
+    def freeze(self):
+        """ Temporarily stop removing outdated elements (._remove_outdated) inside this context to ensure consistency """
+        prev_frozen, self.frozen = self.frozen, True
+        try:
+            yield self
+        finally:
+            self.frozen = prev_frozen
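
The new LocalStorage helpers (top, __contains__, __delitem__, freeze) are exercised by tests/test_dht_storage.py below; here is a short usage sketch under the same assumptions (maxsize and expiration offsets are arbitrary):

```python
import time

from hivemind import DHTID, get_dht_time
from hivemind.dht.protocol import LocalStorage

storage = LocalStorage(maxsize=2)

# while frozen, expired or excess entries are not evicted, so reads stay consistent
with storage.freeze():
    storage.store(DHTID.generate("key1"), b"val1", get_dht_time() + 0.01)
    time.sleep(0.05)
    assert DHTID.generate("key1") in storage  # already expired, but kept until unfrozen
assert DHTID.generate("key1") not in storage  # evicted on the next access

# top() returns the (key, value, expiration_time) entry with the earliest expiration
storage.store(DHTID.generate("key2"), b"val2", get_dht_time() + 10)
storage.store(DHTID.generate("key3"), b"val3", get_dht_time() + 20)
key_id, value, expiration_time = storage.top()
assert key_id == DHTID.generate("key2") and value == b"val2"
```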

+ 16 - 11
hivemind/dht/traverse.py

@@ -11,7 +11,7 @@ ROOT = 0  # alias for heap root
 
 async def simple_traverse_dht(query_id: DHTID, initial_nodes: Collection[DHTID], beam_size: int,
                               get_neighbors: Callable[[DHTID], Awaitable[Tuple[Collection[DHTID], bool]]],
-                              visited_nodes: Collection[DHTID] = ()) -> Tuple[List[DHTID], Set[DHTID]]:
+                              visited_nodes: Collection[DHTID] = ()) -> Tuple[Tuple[DHTID], Set[DHTID]]:
     """
     Traverse the DHT graph using get_neighbors function, find :beam_size: nearest nodes according to DHTID.xor_distance.
 
@@ -64,7 +64,7 @@ async def simple_traverse_dht(query_id: DHTID, initial_nodes: Collection[DHTID],
 
 async def traverse_dht(
         queries: Collection[DHTID], initial_nodes: List[DHTID], beam_size: int, num_workers: int, queries_per_call: int,
-        get_neighbors: Callable[[DHTID, Collection[DHTID]], Awaitable[Dict[DHTID, Tuple[List[DHTID], bool]]]],
+        get_neighbors: Callable[[DHTID, Collection[DHTID]], Awaitable[Dict[DHTID, Tuple[Tuple[DHTID], bool]]]],
         found_callback: Optional[Callable[[DHTID, List[DHTID], Set[DHTID]], Awaitable[Any]]] = None,
         await_all_tasks: bool = True, visited_nodes: Optional[Dict[DHTID, Set[DHTID]]] = (),
 ) -> Tuple[Dict[DHTID, List[DHTID]], Dict[DHTID, Set[DHTID]]]:
@@ -90,8 +90,9 @@ async def traverse_dht(
         The search terminates iff each query is either stopped via should_stop or finds beam_size nearest nodes.
 
     :param found_callback: if specified, call this callback for each finished query the moment it finishes or is stopped
-        More specifically, run asyncio.create_task(found_found_callback(query, nearest_to_query, visited_for_query))
+        More specifically, run asyncio.create_task(found_callback(query, nearest_to_query, visited_for_query))
         Using this callback allows one to process results faster before traverse_dht is finishes for all queries.
+        It is guaranteed that found_callback will be called exactly once on each query in queries.
 
     :param await_all_tasks: if True, wait for all tasks to finish before returning, otherwise returns after finding
         nearest neighbors and finishes the remaining tasks (callbacks and queries to known-but-unvisited nodes)
@@ -133,10 +134,14 @@ async def traverse_dht(
 
     def heuristic_priority(heap_query: DHTID):
         """ Workers prioritize expanding nodes that lead to under-explored queries (by other workers) """
-        if len(candidate_nodes[heap_query]) == 0:
-            return float('inf'), float('inf')
-        else:  # prefer candidates in heaps with least number of concurrent workers, break ties by distance to query
+        if has_candidates(heap_query):
+            # prefer candidates in heaps with least number of concurrent workers, break ties by distance to query
             return active_workers[heap_query], candidate_nodes[heap_query][ROOT][0]
+        return float('inf'), float('inf')  # try not to explore vertices with no candidates
+
+    def has_candidates(query: DHTID):
+        """ Whether this query's heap contains at least one candidate node that can be explored """
+        return candidate_nodes[query] and candidate_nodes[query][ROOT][0] <= upper_bound(query)
 
     def upper_bound(query: DHTID):
         """ Any node that is farther from query than upper_bound(query) will not be added to heaps """
@@ -156,7 +161,8 @@ async def traverse_dht(
             # select the heap based on priority
             chosen_query: DHTID = min(unfinished_queries, key=heuristic_priority)
 
-            if len(candidate_nodes[chosen_query]) == 0:  # if there are no peers to explore...
+            # if there are no peers to explore...
+            if not has_candidates(chosen_query):
                 other_workers_pending = active_workers.most_common(1)[0][1] > 0
                 if other_workers_pending:  # ... wait for other workers (if any) or add more peers
                     heap_updated_event.clear()
@@ -169,10 +175,9 @@ async def traverse_dht(
 
             # select vertex to be explored
             chosen_distance_to_query, chosen_peer = heapq.heappop(candidate_nodes[chosen_query])
-            if chosen_peer in visited_nodes[chosen_query]:
-                continue
-            if chosen_distance_to_query > upper_bound(chosen_query):
-                finish_search(chosen_query)
+            if chosen_peer in visited_nodes[chosen_query] or chosen_distance_to_query > upper_bound(chosen_query):
+                if chosen_distance_to_query > upper_bound(chosen_query) and active_workers[chosen_query] == 0:
+                    finish_search(chosen_query)
                 continue
 
             # find additional queries to pack in the same request

+ 1 - 0
requirements.txt

@@ -3,6 +3,7 @@ torch>=1.3.0
 numpy>=1.17
 prefetch_generator>=1.0.1
 umsgpack
+sortedcontainers
 uvloop>=0.14.0
 grpcio>=1.31
 grpcio-tools>=1.30.0
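
sortedcontainers is a new dependency: DHTNode.pending_get_requests (node.py above) keeps concurrent get requests in a SortedList ordered by descending sufficient_expiration_time, so the least demanding request sits at the tail. A toy illustration of that ordering; the dict-based request objects here are stand-ins, not hivemind types:

```python
from sortedcontainers import SortedList

# sort descending by sufficient_expiration_time, as pending_get_requests does
pending = SortedList(key=lambda request: -request['sufficient_expiration_time'])
pending.add({'key': 'k1', 'sufficient_expiration_time': 25.0})
pending.add({'key': 'k1', 'sufficient_expiration_time': 10.0})

# the least demanding request is at the tail; it is the first to be satisfied and popped
assert pending[-1]['sufficient_expiration_time'] == 10.0
pending.pop(-1)
```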

+ 76 - 0
tests/test_dht_experts.py

@@ -0,0 +1,76 @@
+import random
+import uuid
+from itertools import chain
+
+import hivemind
+from hivemind import LOCALHOST
+
+
+def test_hivemind_dht():
+    peers = [hivemind.DHT(start=True)]
+    for i in range(10):
+        neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
+        peers.append(hivemind.DHT(initial_peers=neighbors_i, start=True))
+
+    you: hivemind.dht.DHT = random.choice(peers)
+    theguyshetoldyounottoworryabout: hivemind.dht.DHT = random.choice(peers)
+
+    expert_uids = [str(uuid.uuid4()) for _ in range(110)]
+    batch_size = 10
+    for batch_start in range(0, len(expert_uids), batch_size):
+        you.declare_experts(expert_uids[batch_start: batch_start + batch_size], 'localhost', 1234)
+
+    found = theguyshetoldyounottoworryabout.get_experts(random.sample(expert_uids, 5) + ['foo', 'bar'])
+    assert all(res is not None for res in found[:-2]), "Could not find some existing experts"
+    assert all(res is None for res in found[-2:]), "Found non-existing experts"
+
+    that_guys_expert, that_guys_port = str(uuid.uuid4()), random.randint(1000, 9999)
+    theguyshetoldyounottoworryabout.declare_experts([that_guys_expert], f'that_host:{that_guys_port}')
+    you_notfound, you_found = you.get_experts(['foobar', that_guys_expert])
+    assert isinstance(you_found, hivemind.RemoteExpert)
+    assert you_found.endpoint == f'that_host:{that_guys_port}'
+
+    # test first_k_active
+    assert list(theguyshetoldyounottoworryabout.first_k_active(expert_uids, k=10)) == expert_uids[:10]
+
+    some_permuted_experts = random.sample(expert_uids, k=32)
+    assert list(theguyshetoldyounottoworryabout.first_k_active(some_permuted_experts, k=32)) == some_permuted_experts
+    assert list(theguyshetoldyounottoworryabout.first_k_active(some_permuted_experts, k=1)) == some_permuted_experts[:1]
+    fake_and_real_experts = list(chain(*zip(
+        [str(uuid.uuid4()) for _ in some_permuted_experts], some_permuted_experts)))
+    assert list(theguyshetoldyounottoworryabout.first_k_active(fake_and_real_experts, k=9)) == some_permuted_experts[:9]
+
+    for peer in peers:
+        peer.shutdown()
+
+
+def test_first_k_active():
+    node = hivemind.DHT(start=True)
+    assert all(node.declare_experts(['e.1.2.3', 'e.1.2.4', 'e.3.4.5'], endpoint=f"{hivemind.LOCALHOST}:1337"))
+    assert all(node.declare_experts(['e.2.1.1'], endpoint=f"{hivemind.LOCALHOST}:1338"))
+
+    results = node.first_k_active(['e.0', 'e.1', 'e.2', 'e.3'], k=2)
+    assert len(results) == 2 and next(iter(results.keys())) == 'e.1'
+    assert results['e.1'].uid in ('e.1.2.3', 'e.1.2.4') and results['e.1'].endpoint == f"{hivemind.LOCALHOST}:1337"
+    assert results['e.2'].uid == 'e.2.1.1' and results['e.2'].endpoint == f"{hivemind.LOCALHOST}:1338"
+
+    results = node.first_k_active(['e', 'e.1', 'e.1.2', 'e.1.2.3'], k=10)
+    assert len(results) == 4
+    assert 'e' in results
+    for k in ('e.1', 'e.1.2', 'e.1.2.3'):
+        assert results[k].uid in ('e.1.2.3', 'e.1.2.4') and results[k].endpoint == f"{hivemind.LOCALHOST}:1337"
+
+
+def test_dht_single_node():
+    node = hivemind.DHT(start=True)
+    assert node.first_k_active(['e3', 'e2'], k=3) == {}
+    assert node.get_experts(['e3', 'e2']) == [None, None]
+
+    assert all(node.declare_experts(['e1', 'e2', 'e3'], f"{hivemind.LOCALHOST}:1337"))
+    for expert in node.get_experts(['e3', 'e2']):
+        assert expert.endpoint == f"{hivemind.LOCALHOST}:1337"
+    active_found = node.first_k_active(['e0', 'e1', 'e3', 'e5', 'e2'], k=2)
+    assert list(active_found.keys()) == ['e1', 'e3']
+    assert all(expert.uid.startswith(prefix) for prefix, expert in active_found.items())
+
+    assert all(node.declare_experts(['e1', 'e2', 'e3'], f"{hivemind.LOCALHOST}:1337"))

+ 80 - 100
tests/test_dht.py → tests/test_dht_node.py

@@ -1,10 +1,7 @@
-import time
 import asyncio
 import multiprocessing as mp
 import random
 import heapq
-import uuid
-from itertools import chain
 from typing import Optional
 import numpy as np
 
@@ -13,7 +10,7 @@ from typing import List, Dict
 
 from hivemind import get_dht_time
 from hivemind.dht.node import DHTID, Endpoint, DHTNode, LOCALHOST, DHTProtocol
-from hivemind.dht.protocol import LocalStorage
+from hivemind.dht.protocol import DHTProtocol
 
 
 def run_protocol_listener(port: int, dhtid: DHTID, started: mp.synchronize.Event, ping: Optional[Endpoint] = None):
@@ -265,111 +262,94 @@ def test_dht_node():
         proc.terminate()
 
 
-def test_hivemind_dht():
-    peers = [hivemind.DHT(start=True)]
-    for i in range(10):
-        neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
-        peers.append(hivemind.DHT(initial_peers=neighbors_i, start=True))
-
-    you: hivemind.dht.DHT = random.choice(peers)
-    theguyshetoldyounottoworryabout: hivemind.dht.DHT = random.choice(peers)
-
-    expert_uids = [str(uuid.uuid4()) for _ in range(110)]
-    batch_size = 10
-    for batch_start in range(0, len(expert_uids), batch_size):
-        you.declare_experts(expert_uids[batch_start: batch_start + batch_size], 'localhost', 1234)
-
-    found = theguyshetoldyounottoworryabout.get_experts(random.sample(expert_uids, 5) + ['foo', 'bar'])
-    assert all(res is not None for res in found[:-2]), "Could not find some existing experts"
-    assert all(res is None for res in found[-2:]), "Found non-existing experts"
-
-    that_guys_expert, that_guys_port = str(uuid.uuid4()), random.randint(1000, 9999)
-    theguyshetoldyounottoworryabout.declare_experts([that_guys_expert], f'that_host:{that_guys_port}')
-    you_notfound, you_found = you.get_experts(['foobar', that_guys_expert])
-    assert isinstance(you_found, hivemind.RemoteExpert)
-    assert you_found.endpoint == f'that_host:{that_guys_port}'
-
-    # test first_k_active
-    assert list(theguyshetoldyounottoworryabout.first_k_active(expert_uids, k=10)) == expert_uids[:10]
-
-    some_permuted_experts = random.sample(expert_uids, k=32)
-    assert list(theguyshetoldyounottoworryabout.first_k_active(some_permuted_experts, k=32)) == some_permuted_experts
-    assert list(theguyshetoldyounottoworryabout.first_k_active(some_permuted_experts, k=1)) == some_permuted_experts[:1]
-    fake_and_real_experts = list(chain(*zip(
-        [str(uuid.uuid4()) for _ in some_permuted_experts], some_permuted_experts)))
-    assert list(theguyshetoldyounottoworryabout.first_k_active(fake_and_real_experts, k=9)) == some_permuted_experts[:9]
-
-    for peer in peers:
-        peer.shutdown()
-
-
-def test_dht_single_node():
-    node = hivemind.DHT(start=True)
-    assert node.first_k_active(['e3', 'e2'], k=3) == {}
-    assert node.get_experts(['e3', 'e2']) == [None, None]
-
-    assert all(node.declare_experts(['e1', 'e2', 'e3'], f"{hivemind.LOCALHOST}:1337"))
-    for expert in node.get_experts(['e3', 'e2']):
-        assert expert.endpoint == f"{hivemind.LOCALHOST}:1337"
-    active_found = node.first_k_active(['e0', 'e1', 'e3', 'e5', 'e2'], k=2)
-    assert list(active_found.keys()) == ['e1', 'e3']
-    assert all(expert.uid.startswith(prefix) for prefix, expert in active_found.items())
-
-    assert all(node.declare_experts(['e1', 'e2', 'e3'], f"{hivemind.LOCALHOST}:1337"))
-
-
-def test_first_k_active():
-    node = hivemind.DHT(start=True)
-    assert all(node.declare_experts(['e.1.2.3', 'e.1.2.4', 'e.3.4.5'], endpoint=f"{hivemind.LOCALHOST}:1337"))
-    assert all(node.declare_experts(['e.2.1.1'], endpoint=f"{hivemind.LOCALHOST}:1338"))
+def test_dhtnode_caching(T=0.05):
+    test_success = mp.Event()
 
-    results = node.first_k_active(['e.0', 'e.1', 'e.2', 'e.3'], k=2)
-    assert len(results) == 2 and next(iter(results.keys())) == 'e.1'
-    assert results['e.1'].uid in ('e.1.2.3', 'e.1.2.4') and results['e.1'].endpoint == f"{hivemind.LOCALHOST}:1337"
-    assert results['e.2'].uid == 'e.2.1.1' and results['e.2'].endpoint == f"{hivemind.LOCALHOST}:1338"
+    async def _tester():
+        node2 = await hivemind.DHTNode.create(cache_refresh_before_expiry=5 * T, reuse_get_requests=False)
+        node1 = await hivemind.DHTNode.create(initial_peers=[f'localhost:{node2.port}'],
+                                              cache_refresh_before_expiry=5 * T, listen=False, reuse_get_requests=False)
+        await node2.store('k', [123, 'value'], expiration_time=hivemind.get_dht_time() + 7 * T)
+        await node2.store('k2', [654, 'value'], expiration_time=hivemind.get_dht_time() + 7 * T)
+        await node2.store('k3', [654, 'value'], expiration_time=hivemind.get_dht_time() + 15 * T)
+        await node1.get_many(['k', 'k2', 'k3', 'k4'])
+        assert len(node1.protocol.cache) == 3
+        assert len(node1.cache_refresh_queue) == 0
+
+        await node1.get_many(['k', 'k2', 'k3', 'k4'])
+        assert len(node1.cache_refresh_queue) == 3
+
+        await node2.store('k', [123, 'value'], expiration_time=hivemind.get_dht_time() + 12 * T)
+        await asyncio.sleep(4 * T)
+        await node1.get('k')
+        await asyncio.sleep(1 * T)
+
+        assert len(node1.protocol.cache) == 3
+        assert len(node1.cache_refresh_queue) == 2
+        await asyncio.sleep(3 * T)
+
+        assert len(node1.cache_refresh_queue) == 1
+
+        await asyncio.sleep(5 * T)
+        assert len(node1.cache_refresh_queue) == 0
+        await asyncio.sleep(5 * T)
+        assert len(node1.cache_refresh_queue) == 0
+
+        await node2.store('k', [123, 'value'], expiration_time=hivemind.get_dht_time() + 10 * T)
+        await node1.get('k')
+        await asyncio.sleep(1 * T)
+        assert len(node1.cache_refresh_queue) == 0
+        await node1.get('k')
+        await asyncio.sleep(1 * T)
+        assert len(node1.cache_refresh_queue) == 1
+
+        await asyncio.sleep(5 * T)
+        assert len(node1.cache_refresh_queue) == 0
+
+        await asyncio.gather(node1.shutdown(), node2.shutdown())
+        test_success.set()
 
-    results = node.first_k_active(['e', 'e.1', 'e.1.2', 'e.1.2.3'], k=10)
-    assert len(results) == 4
-    assert 'e' in results
-    for k in ('e.1', 'e.1.2', 'e.1.2.3'):
-        assert results[k].uid in ('e.1.2.3', 'e.1.2.4') and results[k].endpoint == f"{hivemind.LOCALHOST}:1337"
+    proc = mp.Process(target=lambda: asyncio.run(_tester()))
+    proc.start()
+    proc.join()
+    assert test_success.is_set()
 
 
+def test_dhtnode_reuse_get():
+    test_success = mp.Event()
 
-def test_store():
-    d = LocalStorage()
-    d.store(DHTID.generate("key"), b"val", get_dht_time() + 0.5)
-    assert d.get(DHTID.generate("key"))[0] == b"val", "Wrong value"
-    print("Test store passed")
+    async def _tester():
+        peers = []
+        for i in range(10):
+            neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
+            peers.append(await hivemind.DHTNode.create(initial_peers=neighbors_i, parallel_rpc=256))
 
+        await asyncio.gather(
+            random.choice(peers).store('k1', 123, hivemind.get_dht_time() + 999),
+            random.choice(peers).store('k2', 567, hivemind.get_dht_time() + 999)
+        )
 
-def test_get_expired():
-    d = LocalStorage()
-    d.store(DHTID.generate("key"), b"val", get_dht_time() + 0.1)
-    time.sleep(0.5)
-    assert d.get(DHTID.generate("key")) == (None, None), "Expired value must be deleted"
-    print("Test get expired passed")
+        you = random.choice(peers)
 
+        futures1 = await you.get_many(['k1', 'k2'], return_futures=True)
+        assert len(you.pending_get_requests[DHTID.generate('k1')]) == 1
+        assert len(you.pending_get_requests[DHTID.generate('k2')]) == 1
 
-def test_get_empty():
-    d = LocalStorage()
-    assert d.get(DHTID.generate(source="key")) == (None, None), "LocalStorage returned non-existent value"
-    print("Test get expired passed")
+        futures2 = await you.get_many(['k2', 'k3'], return_futures=True)
+        assert len(you.pending_get_requests[DHTID.generate('k2')]) == 2
 
+        await asyncio.gather(*futures1.values(), *futures2.values())
+        futures3 = await you.get_many(['k3'], return_futures=True)
+        assert len(you.pending_get_requests[DHTID.generate('k1')]) == 0
+        assert len(you.pending_get_requests[DHTID.generate('k2')]) == 0
+        assert len(you.pending_get_requests[DHTID.generate('k3')]) == 1
 
-def test_change_expiration_time():
-    d = LocalStorage()
-    d.store(DHTID.generate("key"), b"val1", get_dht_time() + 1)
-    assert d.get(DHTID.generate("key"))[0] == b"val1", "Wrong value"
-    d.store(DHTID.generate("key"), b"val2", get_dht_time() + 200)
-    time.sleep(1)
-    assert d.get(DHTID.generate("key"))[0] == b"val2", "Value must be changed, but still kept in table"
-    print("Test change expiration time passed")
-
+        assert (await futures1['k1'])[0] == 123
+        assert await futures1['k2'] == await futures2['k2'] and (await futures1['k2'])[0] == 567
+        assert await futures2['k3'] == await futures3['k3'] and (await futures3['k3']) == (None, None)
+        test_success.set()
 
-def test_maxsize_cache():
-    d = LocalStorage(maxsize=1)
-    d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 1)
-    d.store(DHTID.generate("key2"), b"val2", get_dht_time() + 200)
-    assert d.get(DHTID.generate("key2"))[0] == b"val2", "Value with bigger exp. time must be kept"
-    assert d.get(DHTID.generate("key1"))[0] is None, "Value with less exp time, must be deleted"
+    proc = mp.Process(target=lambda: asyncio.run(_tester()))
+    proc.start()
+    proc.join()
+    assert test_success.is_set()

+ 79 - 0
tests/test_dht_storage.py

@@ -0,0 +1,79 @@
+import time
+
+from hivemind import DHTID, get_dht_time
+from hivemind.dht.protocol import LocalStorage
+
+
+def test_store():
+    d = LocalStorage()
+    d.store(DHTID.generate("key"), b"val", get_dht_time() + 0.5)
+    assert d.get(DHTID.generate("key"))[0] == b"val", "Wrong value"
+    print("Test store passed")
+
+
+def test_get_expired():
+    d = LocalStorage()
+    d.store(DHTID.generate("key"), b"val", get_dht_time() + 0.1)
+    time.sleep(0.5)
+    assert d.get(DHTID.generate("key")) == (None, None), "Expired value must be deleted"
+    print("Test get expired passed")
+
+
+def test_get_empty():
+    d = LocalStorage()
+    assert d.get(DHTID.generate(source="key")) == (None, None), "LocalStorage returned non-existent value"
+    print("Test get expired passed")
+
+
+def test_change_expiration_time():
+    d = LocalStorage()
+    d.store(DHTID.generate("key"), b"val1", get_dht_time() + 1)
+    assert d.get(DHTID.generate("key"))[0] == b"val1", "Wrong value"
+    d.store(DHTID.generate("key"), b"val2", get_dht_time() + 200)
+    time.sleep(1)
+    assert d.get(DHTID.generate("key"))[0] == b"val2", "Value must be changed, but still kept in table"
+    print("Test change expiration time passed")
+
+
+def test_maxsize_cache():
+    d = LocalStorage(maxsize=1)
+    d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 1)
+    d.store(DHTID.generate("key2"), b"val2", get_dht_time() + 200)
+    assert d.get(DHTID.generate("key2"))[0] == b"val2", "Value with bigger exp. time must be kept"
+    assert d.get(DHTID.generate("key1"))[0] is None, "Value with less exp time, must be deleted"
+
+
+def test_localstorage_top():
+    d = LocalStorage(maxsize=3)
+    d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 1)
+    d.store(DHTID.generate("key2"), b"val2", get_dht_time() + 2)
+    d.store(DHTID.generate("key3"), b"val3", get_dht_time() + 4)
+    assert d.top()[:2] == (DHTID.generate("key1"), b"val1")
+
+    d.store(DHTID.generate("key1"), b"val1_new", get_dht_time() + 3)
+    assert d.top()[:2] == (DHTID.generate("key2"), b"val2")
+
+    del d[DHTID.generate('key2')]
+    assert d.top()[:2] == (DHTID.generate("key1"), b"val1_new")
+    d.store(DHTID.generate("key2"), b"val2_new", get_dht_time() + 5)
+    d.store(DHTID.generate("key4"), b"val4", get_dht_time() + 6)  # key4 will push out key1 due to maxsize
+
+    assert d.top()[:2] == (DHTID.generate("key3"), b"val3")
+
+
+def test_localstorage_freeze():
+    d = LocalStorage(maxsize=2)
+
+    with d.freeze():
+        d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 0.01)
+        assert DHTID.generate("key1") in d
+        time.sleep(0.03)
+        assert DHTID.generate("key1") in d
+    assert DHTID.generate("key1") not in d
+
+    with d.freeze():
+        d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 1)
+        d.store(DHTID.generate("key2"), b"val2", get_dht_time() + 2)
+        d.store(DHTID.generate("key3"), b"val3", get_dht_time() + 3)  # key3 will push key1 out due to maxsize
+        assert DHTID.generate("key1") in d
+    assert DHTID.generate("key1") not in d