Use namedtuples for DHT values (#110)

* add test for beam search

* add tests for find_best_experts and batch_find_best_experts

* storage: use named tuples

* switch to namedtuples

* update storage tests

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic 4 years ago
parent commit a59fa709cc

+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.utils import *
 
-__version__ = '0.8.6'
+__version__ = '0.8.7'

+ 14 - 12
hivemind/dht/__init__.py

@@ -19,6 +19,7 @@ import multiprocessing as mp
 import warnings
 from collections import deque, OrderedDict
 from concurrent.futures import ThreadPoolExecutor
+from itertools import chain
 from typing import List, Tuple, Optional, Sequence, OrderedDict as TOrderedDict, Union, Awaitable, Dict, Deque, Set
 
 import uvloop
@@ -148,9 +149,9 @@ class DHT(mp.Process):
         num_workers = len(uids) if self.max_workers is None else min(len(uids), self.max_workers)
         response = await node.get_many(uids, expiration_time, num_workers=num_workers)
         # TODO expert_data['expert'] -> namedtuple with meaningful field names
-        future.set_result([RemoteExpert(*expert_data['expert'][0])
-                           if maybe_expiration_time else None and expert_data['expert'][1] is not None
-                           for uid, (expert_data, maybe_expiration_time) in response.items()])
+        future.set_result([RemoteExpert(*expert_data.value['expert'].value)
+                           if expert_data is not None and 'expert' in expert_data.value else None
+                           for uid, expert_data in response.items()])
 
     def declare_experts(self, uids: List[str], endpoint: Endpoint, wait=True, timeout=None) -> Optional[List[bool]]:
         """
@@ -222,6 +223,7 @@ class DHT(mp.Process):
         if not beam:
             logger.warning(f"Beam search had to terminate prematurely because of empty beam (dim 0)")
             return []
+        # TODO warn user if indices are out of range on the _last_ level! (rationale: beam search may return <k results)
 
         for dim_index in range(1, len(grid_scores) - 1):
             # select beam_size best suffixes from current beam
@@ -245,11 +247,12 @@ class DHT(mp.Process):
 
         # select best experts from the final beam
         dim_scores = grid_scores[-1]
-        final_best_pairs: List[Tuple[float, str, Endpoint]] = heapq.nlargest(beam_size, (
+        # TODO use a heap to gather all results, get rid of this five-line expression
+        final_best_pairs: List[Tuple[float, str, Endpoint]] = heapq.nlargest(beam_size, chain((
             (prefix_score + dim_scores[int(suffix_i)], uid, endpoint)
             for prefix_score, prefix, suffixes in beam for suffix_i, ((uid, endpoint), _) in suffixes.items()
             if str.isdecimal(suffix_i) and 0 <= int(suffix_i) < len(dim_scores)
-        ))
+        ), ((score, *suffixes['expert']) for score, _, suffixes in beam if 'expert' in suffixes)))
         best_experts = [RemoteExpert(uid, endpoint) for score, uid, endpoint in final_best_pairs]
         if future is not None:
             future.set_result(best_experts)
@@ -305,9 +308,9 @@ class DHT(mp.Process):
             # await the next best prefix to be fetched
             pending_best_index, pending_best_prefix, pending_task = pending_tasks.popleft()
             try:
-                maybe_prefix_data, maybe_expiration_time = await pending_task
-                if maybe_expiration_time is not None:
-                    beam.append((scores[pending_best_index], pending_best_prefix, maybe_prefix_data))
+                maybe_prefix_data = await pending_task
+                if maybe_prefix_data is not None:
+                    beam.append((scores[pending_best_index], pending_best_prefix, maybe_prefix_data.value))
             except asyncio.CancelledError:
                 for _, pending_task in pending_tasks:
                     pending_task.cancel()
@@ -328,7 +331,7 @@ class DHT(mp.Process):
        :returns: an ordered dict {uid_prefix -> RemoteExpert} mapping at most :k: prefixes to matching experts
            The keys in the returned dict are ordered the same as in uid_prefixes.
         """
-        logger.warning("first_k_active is deprecated and will be removed in 0.8.7")
+        logger.warning("first_k_active is deprecated and will be removed in 0.8.8")
         assert not isinstance(uid_prefixes, str), "please provide a list/tuple of prefixes as the first argument"
         future, _future = MPFuture.make_pair()
         self.pipe.send(('_first_k_active', [],
@@ -352,9 +355,8 @@ class DHT(mp.Process):
             # parse task results in chronological order, launch additional tasks on demand
             response = await pending_tasks.popleft()
             for uid_prefix in uid_prefixes[chunk_i * chunk_size: (chunk_i + 1) * chunk_size]:
-                maybe_expert_data, maybe_expiration_time = response[uid_prefix]
-                if maybe_expiration_time is not None and len(maybe_expert_data) > 0:  # found active peer
-                    found.append((uid_prefix, RemoteExpert(*next(iter(maybe_expert_data.values()))[0])))
+                if response[uid_prefix] is not None and len(response[uid_prefix].value) > 0:  # found active peer
+                    found.append((uid_prefix, RemoteExpert(*next(iter(response[uid_prefix].value.values()))[0])))
                     # if we found enough active experts, finish immediately
                     if len(found) >= k:
                         break
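With this change, each get_many entry is a ValueWithExpiration namedtuple (or None) instead of a (value, expiration_time) pair, which is exactly what the rewritten _get_experts parser above relies on. A minimal sketch of the new calling convention, assuming node is a DHTNode instance and key was stored earlier:

    async def fetch(node, key):
        maybe_item = await node.get(key)
        if maybe_item is None:
            return None  # key not found or already expired
        # ValueWithExpiration is a NamedTuple: attribute access and tuple unpacking both work
        value, expiration_time = maybe_item
        assert value is maybe_item.value and expiration_time == maybe_item.expiration_time
        return value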

+ 38 - 35
hivemind/dht/node.py

@@ -12,7 +12,7 @@ from sortedcontainers import SortedList
 
 from hivemind.dht.protocol import DHTProtocol
 from hivemind.dht.routing import DHTID, DHTExpiration, DHTKey, get_dht_time, DHTValue, BinaryDHTValue, Subkey
-from hivemind.dht.storage import CacheRefreshQueue, DictionaryDHTValue
+from hivemind.dht.storage import CacheRefreshQueue, DictionaryDHTValue, ValueWithExpiration
 from hivemind.dht.traverse import traverse_dht
 from hivemind.utils import Endpoint, LOCALHOST, MSGPackSerializer, get_logger, SerializerBase
 
@@ -189,7 +189,7 @@ class DHTNode:
                 return {query: ([], False) for query in queries}
 
             output: Dict[DHTID, Tuple[Tuple[DHTID], bool]] = {}
-            for query, (_, _, peers) in response.items():
+            for query, (_, peers) in response.items():
                 node_to_endpoint.update(peers)
                 output[query] = tuple(peers.keys()), False  # False means "do not interrupt search"
             return output
@@ -237,8 +237,8 @@ class DHTNode:
         """
         if isinstance(expiration_time, DHTExpiration):
             expiration_time = [expiration_time] * len(keys)
-        if subkeys is None or isinstance(subkeys, Subkey):
-            subkeys = [subkeys] * len(keys)
+        if subkeys is None:
+            subkeys = [None] * len(keys)
 
         assert len(keys) == len(subkeys) == len(values) == len(expiration_time), \
             "Either of keys, values, subkeys or expiration timestamps have different sequence lengths."
@@ -336,13 +336,13 @@ class DHTNode:
                 self.protocol.cache.store_subkey(key_id, subkey, stored_value_bytes, expiration_time)
             self._schedule_for_refresh(key_id, refresh_time=get_dht_time())  # fetch new key in background (asap)
 
-    async def get(self, key: DHTKey, latest=False, **kwargs) -> Tuple[Optional[DHTValue], Optional[DHTExpiration]]:
+    async def get(self, key: DHTKey, latest=False, **kwargs) -> Optional[ValueWithExpiration[DHTValue]]:
         """
-        Search for a key across DHT and return either first or latest entry.
+        Search for a key across DHT and return either first or latest entry (if found).
         :param key: same key as in node.store(...)
         :param latest: if True, finds the latest value, otherwise finds any non-expired value (which is much faster)
         :param kwargs: parameters forwarded to get_many_by_id
-        :returns: (value, expiration time); if value was not found, returns (None, None)
+        :returns: (value, expiration time); if value was not found, returns None
         """
         if latest:
             kwargs["sufficient_expiration_time"] = float('inf')
@@ -350,8 +350,8 @@ class DHTNode:
         return result[key]
 
     async def get_many(self, keys: Collection[DHTKey], sufficient_expiration_time: Optional[DHTExpiration] = None,
-                       **kwargs) -> Dict[DHTKey, Union[Tuple[Optional[DHTValue], Optional[DHTExpiration]],
-                                                       Awaitable[Tuple[Optional[DHTValue], Optional[DHTExpiration]]]]]:
+                       **kwargs) -> Dict[DHTKey, Union[Optional[ValueWithExpiration[DHTValue]],
+                                                       Awaitable[Optional[ValueWithExpiration[DHTValue]]]]]:
         """
         Traverse DHT to find a list of keys. For each key, return latest (value, expiration) or None if not found.
 
@@ -372,8 +372,8 @@ class DHTNode:
     async def get_many_by_id(
             self, key_ids: Collection[DHTID], sufficient_expiration_time: Optional[DHTExpiration] = None,
             num_workers: Optional[int] = None, beam_size: Optional[int] = None, return_futures: bool = False,
-            _is_refresh=False) -> Dict[DHTID, Union[Tuple[Optional[DHTValue], Optional[DHTExpiration]],
-                                                    Awaitable[Tuple[Optional[DHTValue], Optional[DHTExpiration]]]]]:
+            _is_refresh=False) -> Dict[DHTID, Union[Optional[ValueWithExpiration[DHTValue]],
+                                                    Awaitable[Optional[ValueWithExpiration[DHTValue]]]]]:
         """
         Traverse DHT to find a list of DHTIDs. For each key, return latest (value, expiration) or None if not found.
 
@@ -408,9 +408,9 @@ class DHTNode:
 
         # stage 1: check for value in this node's local storage and cache
         for key_id in key_ids:
-            search_results[key_id].add_candidate(*self.protocol.storage.get(key_id), source_node_id=self.node_id)
+            search_results[key_id].add_candidate(self.protocol.storage.get(key_id), source_node_id=self.node_id)
             if not _is_refresh:
-                search_results[key_id].add_candidate(*self.protocol.cache.get(key_id), source_node_id=self.node_id)
+                search_results[key_id].add_candidate(self.protocol.cache.get(key_id), source_node_id=self.node_id)
 
         # stage 2: traverse the DHT to get the remaining keys from remote peers
         unfinished_key_ids = [key_id for key_id in key_ids if not search_results[key_id].finished]
@@ -427,9 +427,9 @@ class DHTNode:
                 return {query: ([], False) for query in queries}
 
             output: Dict[DHTID, Tuple[Tuple[DHTID], bool]] = {}
-            for key_id, (maybe_value_bytes, maybe_expiration_time, peers) in response.items():
+            for key_id, (maybe_value_with_expiration, peers) in response.items():
                 node_to_endpoint.update(peers)
-                search_results[key_id].add_candidate(maybe_value_bytes, maybe_expiration_time, source_node_id=peer)
+                search_results[key_id].add_candidate(maybe_value_with_expiration, source_node_id=peer)
                 output[key_id] = tuple(peers.keys()), search_results[key_id].finished
                # note: we interrupt the search if the key is either found or otherwise finished (e.g. cancelled by user)
             return output
@@ -457,12 +457,12 @@ class DHTNode:
                 raise e
 
     def _reuse_finished_search_result(self, finished: _SearchState):
+        search_result = ValueWithExpiration(finished.binary_value, finished.expiration_time)
         expiration_time_threshold = max(finished.expiration_time or -float('inf'), finished.sufficient_expiration_time)
         concurrent_requests: SortedList[_SearchState] = self.pending_get_requests[finished.key_id]
-        # note: concurrent_requests is sorded in the order of descending sufficient_expiration_time
+        # note: concurrent_requests is sorted in the order of descending sufficient_expiration_time
         while concurrent_requests and expiration_time_threshold >= concurrent_requests[-1].sufficient_expiration_time:
-            concurrent_requests[-1].add_candidate(finished.binary_value, finished.expiration_time,
-                                                  source_node_id=finished.source_node_id)
+            concurrent_requests[-1].add_candidate(search_result, source_node_id=finished.source_node_id)
             concurrent_requests[-1].finish_search()
             concurrent_requests.pop(-1)
 
@@ -477,8 +477,8 @@ class DHTNode:
         if self.cache_refresh_task is None or self.cache_refresh_task.done() or self.cache_refresh_task.cancelled():
             self.cache_refresh_task = asyncio.create_task(self._refresh_stale_cache_entries())
             logger.debug("Spawned cache refresh task.")
-        previous_earliest_item: Tuple[DHTID, Any, DHTExpiration] = self.cache_refresh_queue.top()
-        if previous_earliest_item is None or refresh_time < previous_earliest_item[-1]:
+        earliest_key, earliest_item = self.cache_refresh_queue.top()
+        if earliest_item is None or refresh_time < earliest_item.expiration_time:
            self.cache_refresh_evt.set()  # if the new element is now the earliest, notify the cache queue
         self.cache_refresh_queue.store(key_id, value=refresh_time, expiration_time=refresh_time)
 
@@ -488,7 +488,7 @@ class DHTNode:
             while len(self.cache_refresh_queue) == 0:
                 await self.cache_refresh_evt.wait()
                 self.cache_refresh_evt.clear()
-            key_id, _, nearest_refresh_time = self.cache_refresh_queue.top()
+            key_id, (_, nearest_refresh_time) = self.cache_refresh_queue.top()
 
             try:
                 # step 1: await until :cache_refresh_before_expiry: seconds before earliest first element expires
@@ -505,12 +505,14 @@ class DHTNode:
                 max_expiration_time = self.protocol.cache.get(key_id)[1] or current_time
                 del self.cache_refresh_queue[key_id]  # we pledge to refresh this key_id in the nearest batch
                 while self.cache_refresh_queue:
-                    key_id, _, nearest_refresh_time = self.cache_refresh_queue.top()
+                    key_id, (_, nearest_refresh_time) = self.cache_refresh_queue.top()
                     if nearest_refresh_time > current_time:
                         break
                     del self.cache_refresh_queue[key_id]  # we pledge to refresh this key_id in the nearest batch
                     keys_to_refresh.add(key_id)
-                    max_expiration_time = max(max_expiration_time, self.protocol.cache.get(key_id)[1] or current_time)
+                    cached_item = self.protocol.cache.get(key_id)
+                    if cached_item is not None and cached_item.expiration_time > max_expiration_time:
+                        max_expiration_time = cached_item.expiration_time
 
                 # step 3: search newer versions of these keys, cache them as a side-effect of self.get_many_by_id
                 sufficient_expiration_time = max_expiration_time + self.cache_refresh_before_expiry + 1
@@ -520,9 +522,10 @@ class DHTNode:
                           node_to_endpoint: Dict[DHTID, Endpoint], _is_refresh: bool = False):
         """ after key_id is found, update cache according to caching policy. used internally in get and get_many """
         if search.found_something:
-            previous_expiration_time = max(self.protocol.storage.get(search.key_id)[1] or -float('inf'),
-                                           self.protocol.cache.get(search.key_id)[1] or -float('inf'))
-            if search.expiration_time > previous_expiration_time:  # if this value has better expiration
+            _, storage_expiration_time = self.protocol.storage.get(search.key_id) or (None, -float('inf'))
+            _, cache_expiration_time = self.protocol.cache.get(search.key_id) or (None, -float('inf'))
+
+            if search.expiration_time > max(storage_expiration_time, cache_expiration_time):
                 if self.cache_locally or _is_refresh:
                     self.protocol.cache.store(search.key_id, search.binary_value, search.expiration_time)
                 if self.cache_nearest:
@@ -559,12 +562,12 @@ class _SearchState:
     binary_value: Optional[Union[BinaryDHTValue, DictionaryDHTValue]] = None
     expiration_time: Optional[DHTExpiration] = None  # best expiration time so far
     source_node_id: Optional[DHTID] = None  # node that gave us the value
-    future: asyncio.Future[Tuple[Optional[DHTValue], Optional[DHTExpiration]]] = field(default_factory=asyncio.Future)
+    future: asyncio.Future[Optional[ValueWithExpiration[DHTValue]]] = field(default_factory=asyncio.Future)
     serializer: type(SerializerBase) = MSGPackSerializer
 
-    def add_candidate(self, binary_value: Optional[BinaryDHTValue], expiration_time: Optional[DHTExpiration],
-                      source_node_id: Optional[DHTID]):
-        if not self.finished and (expiration_time or -float('inf')) > (self.expiration_time or -float('inf')):
+    def add_candidate(self, candidate: Optional[ValueWithExpiration[BinaryDHTValue]], source_node_id: Optional[DHTID]):
+        binary_value, expiration_time = candidate or (None, -float('inf'))
+        if not self.finished and expiration_time > (self.expiration_time or -float('inf')):
             self.binary_value, self.expiration_time, self.source_node_id = binary_value, expiration_time, source_node_id
             if self.expiration_time >= self.sufficient_expiration_time:
                 self.finish_search()
@@ -577,13 +580,13 @@ class _SearchState:
         if self.future.done():
             return  # either user cancelled our search or someone sent it before us. Nothing more to do here.
         elif not self.found_something:
-            self.future.set_result((None, None))
+            self.future.set_result(None)
         elif isinstance(self.binary_value, BinaryDHTValue):
-            self.future.set_result((self.serializer.loads(self.binary_value), self.expiration_time))
+            self.future.set_result(ValueWithExpiration(self.serializer.loads(self.binary_value), self.expiration_time))
         elif isinstance(self.binary_value, DictionaryDHTValue):
-            dict_value = {key: (self.serializer.loads(value), item_expiration_time)
-                          for key, value, item_expiration_time in self.binary_value.items()}
-            self.future.set_result((dict_value, self.expiration_time))
+            dict_with_subkeys = {key: ValueWithExpiration(self.serializer.loads(value), item_expiration_time)
+                                 for key, (value, item_expiration_time) in self.binary_value.items()}
+            self.future.set_result(ValueWithExpiration(dict_with_subkeys, self.expiration_time))
         else:
             logger.error(f"Invalid value type: {type(self.binary_value)}")
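For dictionary values, finish_search now wraps each subkey entry in its own ValueWithExpiration nested inside the outer one. A hypothetical illustration of the resulting shape (subkeys, values and timestamps below are made up):

    from hivemind.dht.storage import ValueWithExpiration

    result = ValueWithExpiration(
        value={'alpha': ValueWithExpiration(value=[1, 2, 3], expiration_time=1601.0),
               'beta': ValueWithExpiration(value=[4, 5, 6], expiration_time=1602.0)},
        expiration_time=1602.0)  # the dictionary's own expiration time
    assert result.value['alpha'].value == [1, 2, 3]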
 

+ 38 - 32
hivemind/dht/protocol.py

@@ -9,13 +9,11 @@ import grpc
 import grpc.experimental.aio
 
 from hivemind.dht.routing import RoutingTable, DHTID, BinaryDHTValue, DHTExpiration, Subkey
-from hivemind.dht.storage import DHTLocalStorage, DictionaryDHTValue
+from hivemind.dht.storage import DHTLocalStorage, DictionaryDHTValue, ValueWithExpiration
 from hivemind.proto import dht_pb2, dht_pb2_grpc as dht_grpc
 from hivemind.utils import Endpoint, get_logger, replace_port, MSGPackSerializer
 
 logger = get_logger(__name__)
-NOT_FOUND_VALUE, NOT_FOUND_EXPIRATION, IS_REGULAR_VALUE, IS_DICTIONARY = b'', -float('inf'), '', '___DictionaryDHTValue'
-RESERVED_SUBKEYS = {IS_REGULAR_VALUE, IS_DICTIONARY}
 
 
 class DHTProtocol(dht_grpc.DHTServicer):
@@ -23,9 +21,11 @@ class DHTProtocol(dht_grpc.DHTServicer):
     node_id: DHTID; port: int; bucket_size: int; num_replicas: int; wait_timeout: float; node_info: dht_pb2.NodeInfo
     channel_options: Optional[Sequence[Tuple[str, Any]]]; server: grpc.experimental.aio.Server
     storage: DHTLocalStorage; cache: DHTLocalStorage; routing_table: RoutingTable; rpc_semaphore: asyncio.Semaphore
-    serializer = MSGPackSerializer  # used to pack/unpack DHT Values for transfer over network
     # fmt:on
 
+    serializer = MSGPackSerializer  # used to pack/unpack DHT Values for transfer over network
+    RESERVED_SUBKEYS = IS_REGULAR_VALUE, IS_DICTIONARY = serializer.dumps(None), b''
+
     @classmethod
     async def create(
             cls, node_id: DHTID, bucket_size: int, depth_modulo: int, num_replicas: int, wait_timeout: float,
@@ -137,17 +137,19 @@ class DHTProtocol(dht_grpc.DHTServicer):
         """
         if isinstance(expiration_time, DHTExpiration):
             expiration_time = [expiration_time] * len(keys)
-        if subkeys is None or isinstance(subkeys, Subkey):
-            subkeys = [subkeys] * len(keys)
+        if subkeys is None:
+            subkeys = [None] * len(keys)
 
         in_cache = in_cache if in_cache is not None else [False] * len(keys)  # default value (None)
         in_cache = [in_cache] * len(keys) if isinstance(in_cache, bool) else in_cache  # single bool
         keys, subkeys, values, expiration_time, in_cache = map(list, [keys, subkeys, values, expiration_time, in_cache])
         for i in range(len(keys)):
             if subkeys[i] is None:  # add default sub-key if not specified
-                subkeys[i] = IS_REGULAR_VALUE if not isinstance(values[i], DictionaryDHTValue) else IS_DICTIONARY
+                subkeys[i] = self.IS_DICTIONARY if isinstance(values[i], DictionaryDHTValue) else self.IS_REGULAR_VALUE
+            else:
+                subkeys[i] = self.serializer.dumps(subkeys[i])
             if isinstance(values[i], DictionaryDHTValue):
-                assert subkeys[i] == IS_DICTIONARY, "Please do not specify subkey when storing an entire dictionary"
+                assert subkeys[i] == self.IS_DICTIONARY, "Please don't specify subkey when storing an entire dictionary"
                 values[i] = self.serializer.dumps(values[i])
 
         assert len(keys) == len(values) == len(expiration_time) == len(in_cache), "Data is not aligned"
@@ -172,22 +174,23 @@ class DHTProtocol(dht_grpc.DHTServicer):
         assert len(request.keys) == len(request.values) == len(request.expiration_time) == len(request.in_cache)
         response = dht_pb2.StoreResponse(store_ok=[], peer=self.node_info)
         keys = map(DHTID.from_bytes, request.keys)
-        for key_id, subkey, value_bytes, expiration_time, in_cache in zip(
+        for key_id, tag, value_bytes, expiration_time, in_cache in zip(
                 keys, request.subkeys, request.values, request.expiration_time, request.in_cache):
             storage = self.cache if in_cache else self.storage
-            if subkey == IS_REGULAR_VALUE:  # store normal value without subkeys
+            if tag == self.IS_REGULAR_VALUE:  # store normal value without subkeys
                 response.store_ok.append(storage.store(key_id, value_bytes, expiration_time))
-            elif subkey == IS_DICTIONARY:   # store an entire dictionary with pre-existing subkeys
+            elif tag == self.IS_DICTIONARY:  # store an entire dictionary with several subkeys
                 value_dictionary = self.serializer.loads(value_bytes)
                 assert isinstance(value_dictionary, DictionaryDHTValue)
-                response.store_ok.append(all(storage.store_subkey(key_id, subkey, subvalue, subkey_expiration)
-                                             for subkey, subvalue, subkey_expiration in value_dictionary.items()))
-            else:  # add new entry into an existing dictionary-like value or create a new dictionary with one sub-key
+                response.store_ok.append(all(storage.store_subkey(key_id, subkey, item.value, item.expiration_time)
+                                             for subkey, item in value_dictionary.items()))
+            else:  # add a new entry into an existing dictionary value or create a new dictionary with one sub-key
+                subkey = self.serializer.loads(tag)
                 response.store_ok.append(storage.store_subkey(key_id, subkey, value_bytes, expiration_time))
         return response
 
-    async def call_find(self, peer: Endpoint, keys: Collection[DHTID]) -> \
-            Optional[Dict[DHTID, Tuple[Optional[BinaryDHTValue], Optional[DHTExpiration], Dict[DHTID, Endpoint]]]]:
+    async def call_find(self, peer: Endpoint, keys: Collection[DHTID]) -> Optional[
+            Dict[DHTID, Tuple[Optional[ValueWithExpiration[Union[BinaryDHTValue, DictionaryDHTValue]]], Dict[DHTID, Endpoint]]]]:
         """
         Request keys from a peer. For each key, look for its (value, expiration time) locally and
          k additional peers that are most likely to have this key (ranked by XOR distance)
@@ -211,14 +214,17 @@ class DHTProtocol(dht_grpc.DHTServicer):
             output = {}  # unpack data depending on its type
             for key, result in zip(keys, response.results):
                 nearest = dict(zip(map(DHTID.from_bytes, result.nearest_node_ids), result.nearest_endpoints))
+
                 if result.type == dht_pb2.NOT_FOUND:
-                    output[key] = None, None, nearest
+                    output[key] = None, nearest
                 elif result.type == dht_pb2.FOUND_REGULAR:
-                    output[key] = result.value, result.expiration_time, nearest
+                    output[key] = ValueWithExpiration(result.value, result.expiration_time), nearest
                 elif result.type == dht_pb2.FOUND_DICTIONARY:
-                    output[key] = self.serializer.loads(result.value), result.expiration_time, nearest
+                    deserialized_dictionary = self.serializer.loads(result.value)
+                    output[key] = ValueWithExpiration(deserialized_dictionary, result.expiration_time), nearest
                 else:
                     logger.error(f"Unknown result type: {result.type}")
+
             return output
         except grpc.experimental.aio.AioRpcError as error:
             logger.warning(f"DHTProtocol failed to find at {peer}: {error.code()}")
@@ -233,22 +239,22 @@ class DHTProtocol(dht_grpc.DHTServicer):
             asyncio.create_task(self.rpc_ping(request.peer, context))
         response = dht_pb2.FindResponse(results=[], peer=self.node_info)
         for i, key_id in enumerate(map(DHTID.from_bytes, request.keys)):
-            maybe_value, maybe_expiration_time = self.storage.get(key_id)
-            cached_value, cached_expiration_time = self.cache.get(key_id)
-            if (cached_expiration_time or -float('inf')) > (maybe_expiration_time or -float('inf')):
-                maybe_value, maybe_expiration_time = cached_value, cached_expiration_time
+            maybe_item = self.storage.get(key_id)
+            cached_item = self.cache.get(key_id)
+            if cached_item is not None and (maybe_item is None or cached_item.expiration_time > maybe_item.expiration_time):
+                maybe_item = cached_item
 
-            if maybe_expiration_time is None:  # value not found
+            if maybe_item is None:  # value not found
                 item = dht_pb2.FindResult(type=dht_pb2.NOT_FOUND)
-            elif isinstance(maybe_value, DictionaryDHTValue):
-                item = dht_pb2.FindResult(type=dht_pb2.FOUND_DICTIONARY, value=self.serializer.dumps(maybe_value),
-                                          expiration_time=maybe_value.latest_expiration_time)
+            elif isinstance(maybe_item.value, DictionaryDHTValue):
+                item = dht_pb2.FindResult(type=dht_pb2.FOUND_DICTIONARY, value=self.serializer.dumps(maybe_item.value),
+                                          expiration_time=maybe_item.expiration_time)
             else:  # found regular value
-                item = dht_pb2.FindResult(type=dht_pb2.FOUND_REGULAR, value=maybe_value,
-                                          expiration_time=maybe_expiration_time)
+                item = dht_pb2.FindResult(type=dht_pb2.FOUND_REGULAR, value=maybe_item.value,
+                                          expiration_time=maybe_item.expiration_time)
 
             for node_id, endpoint in self.routing_table.get_nearest_neighbors(
-                key_id, k=self.bucket_size, exclude=DHTID.from_bytes(request.peer.node_id)):
+                    key_id, k=self.bucket_size, exclude=DHTID.from_bytes(request.peer.node_id)):
                 item.nearest_node_ids.append(node_id.to_bytes())
                 item.nearest_endpoints.append(endpoint)
             response.results.append(item)
@@ -268,7 +274,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
             if node_id not in self.routing_table:
                 # we just met a new node, maybe we know some values that it *should* store
                 data_to_send: List[Tuple[DHTID, BinaryDHTValue, DHTExpiration]] = []
-                for key, value, expiration_time in list(self.storage.items()):
+                for key, item in list(self.storage.items()):
                     neighbors = self.routing_table.get_nearest_neighbors(key, self.num_replicas, exclude=self.node_id)
                     if neighbors:
                         nearest_distance = neighbors[0][0].xor_distance(key)
@@ -276,7 +282,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
                         new_node_should_store = node_id.xor_distance(key) < farthest_distance
                         this_node_is_responsible = self.node_id.xor_distance(key) < nearest_distance
                     if not neighbors or (new_node_should_store and this_node_is_responsible):
-                        data_to_send.append((key, value, expiration_time))
+                        data_to_send.append((key, item.value, item.expiration_time))
                 if data_to_send:
                     asyncio.create_task(self.call_store(peer_endpoint, *zip(*data_to_send), in_cache=False))
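The reserved subkey tags work because msgpack never emits an empty byte string, so IS_DICTIONARY = b'' cannot collide with any serialized user subkey, while IS_REGULAR_VALUE is simply the serialized form of None. A sketch of the tagging logic, assuming hivemind's MSGPackSerializer behaves like plain msgpack here:

    import msgpack

    IS_REGULAR_VALUE = msgpack.dumps(None)  # b'\xc0', the msgpack nil byte
    IS_DICTIONARY = b''                     # never produced by msgpack itself

    def encode_subkey(subkey, value_is_dictionary: bool) -> bytes:
        # mirrors the tagging in call_store / rpc_store (sketch)
        if subkey is None:
            return IS_DICTIONARY if value_is_dictionary else IS_REGULAR_VALUE
        return msgpack.dumps(subkey)  # real subkeys travel in serialized form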
 

+ 1 - 1
hivemind/dht/routing.py

@@ -12,7 +12,7 @@ from typing import Tuple, Optional, List, Dict, Set, Union, Any, Sequence
 
 from hivemind.utils import Endpoint, PickleSerializer
 
-DHTKey, Subkey, DHTValue, DHTExpiration, BinaryDHTID, BinaryDHTValue, = Any, str, Any, float, bytes, bytes
+DHTKey, Subkey, DHTValue, DHTExpiration, BinaryDHTID, BinaryDHTValue, = Any, Any, Any, float, bytes, bytes
 get_dht_time = time.time  # time used by all dht functionality. You can replace this with any infrastructure-wide time
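Relaxing Subkey from str to Any is what enables the serialized-subkey scheme above: any msgpack-serializable, hashable object can now index a DictionaryDHTValue. A hedged usage sketch (subkeys and lifetimes are made up):

    from hivemind.dht.routing import get_dht_time
    from hivemind.dht.storage import DictionaryDHTValue

    d = DictionaryDHTValue()
    d.store(42, b'int subkey', get_dht_time() + 10)      # non-string subkeys now work
    d.store('named', b'str subkey', get_dht_time() + 10)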
 
 

+ 42 - 31
hivemind/dht/storage.py

@@ -1,13 +1,24 @@
 from __future__ import annotations
 import heapq
 from contextlib import contextmanager
-from typing import Generic, Optional, Dict, Tuple, List, Iterator, TypeVar, Union, Any
+from typing import Generic, Optional, Dict, Tuple, List, Iterator, TypeVar, Union, NamedTuple
 
 from hivemind.dht.routing import DHTID, DHTExpiration, get_dht_time, BinaryDHTValue, Subkey
 from hivemind.utils.serializer import MSGPackSerializer
 
 KeyType = TypeVar('KeyType')
 ValueType = TypeVar('ValueType')
+ROOT = 0
+
+
+class ValueWithExpiration(NamedTuple, Generic[ValueType]):
+    value: ValueType
+    expiration_time: DHTExpiration
+
+
+class HeapEntry(NamedTuple, Generic[KeyType]):
+    expiration_time: DHTExpiration
+    key: KeyType
 
 
 class TimedStorage(Generic[KeyType, ValueType]):
@@ -16,17 +27,16 @@ class TimedStorage(Generic[KeyType, ValueType]):
 
     def __init__(self, maxsize: Optional[int] = None):
         self.maxsize = maxsize or float("inf")
-        self.data: Dict[KeyType, Tuple[ValueType, DHTExpiration]] = dict()
-        self.expiration_heap: List[Tuple[DHTExpiration, KeyType]] = []
-        self.key_to_heap: Dict[KeyType, Tuple[DHTExpiration, KeyType]] = dict()
+        self.data: Dict[KeyType, ValueWithExpiration[ValueType]] = dict()
+        self.expiration_heap: List[HeapEntry[KeyType]] = []
+        self.key_to_heap: Dict[KeyType, HeapEntry[KeyType]] = dict()
 
     def _remove_outdated(self):
-        while not self.frozen and self.expiration_heap and (self.expiration_heap[0][0] < get_dht_time()
+        while not self.frozen and self.expiration_heap and (self.expiration_heap[ROOT].expiration_time < get_dht_time()
                                                             or len(self.expiration_heap) > self.maxsize):
             heap_entry = heapq.heappop(self.expiration_heap)
-            key = heap_entry[1]
-            if self.key_to_heap.get(key) == heap_entry:
-                del self.data[key], self.key_to_heap[key]
+            if self.key_to_heap.get(heap_entry.key) == heap_entry:
+                del self.data[heap_entry.key], self.key_to_heap[heap_entry.key]
 
     def store(self, key: KeyType, value: ValueType, expiration_time: DHTExpiration) -> bool:
         """
@@ -35,39 +45,39 @@ class TimedStorage(Generic[KeyType, ValueType]):
         """
         if expiration_time < get_dht_time() and not self.frozen:
             return False
-        self.key_to_heap[key] = (expiration_time, key)
-        heapq.heappush(self.expiration_heap, (expiration_time, key))
+        self.key_to_heap[key] = HeapEntry(expiration_time, key)
+        heapq.heappush(self.expiration_heap, self.key_to_heap[key])
         if key in self.data:
-            if self.data[key][1] < expiration_time:
-                self.data[key] = (value, expiration_time)
+            if self.data[key].expiration_time < expiration_time:
+                self.data[key] = ValueWithExpiration(value, expiration_time)
                 return True
             return False
-        self.data[key] = (value, expiration_time)
+        self.data[key] = ValueWithExpiration(value, expiration_time)
         self._remove_outdated()
         return True
 
-    def get(self, key: KeyType) -> (Optional[ValueType], Optional[DHTExpiration]):
+    def get(self, key: KeyType) -> Optional[ValueWithExpiration[ValueType]]:
         """ Get a value corresponding to a key if that (key, value) pair was previously stored under this key. """
         self._remove_outdated()
         if key in self.data:
             return self.data[key]
-        return None, None
+        return None
 
-    def items(self) -> Iterator[Tuple[KeyType, ValueType, DHTExpiration]]:
+    def items(self) -> Iterator[Tuple[KeyType, ValueWithExpiration[ValueType]]]:
         """ Iterate over (key, value, expiration_time) tuples stored in this storage """
         self._remove_outdated()
-        return ((key, value, expiration_time) for key, (value, expiration_time) in self.data.items())
+        return ((key, value_and_expiration) for key, value_and_expiration in self.data.items())
 
-    def top(self) -> Optional[Tuple[KeyType, ValueType, DHTExpiration]]:
+    def top(self) -> Tuple[Optional[KeyType], Optional[ValueWithExpiration[ValueType]]]:
         """ Return the entry with earliest expiration or None if there isn't any """
         self._remove_outdated()
         if self.data:
-            top_entry, top_key = self.expiration_heap[0], self.expiration_heap[0][1]
-            while self.key_to_heap.get(top_key) != top_entry:
-                heapq.heappop(self.expiration_heap)  # skip leftover "ghost" entries until first real entry
-                top_entry, top_key = self.expiration_heap[0], self.expiration_heap[0][1]
-            value, expiration = self.data[top_key]
-            return top_key, value, expiration
+            # skip leftover "ghost" entries until first real entry
+            while self.key_to_heap.get(self.expiration_heap[ROOT].key) != self.expiration_heap[ROOT]:
+                heapq.heappop(self.expiration_heap)
+            top_key = self.expiration_heap[ROOT].key
+            return top_key, self.data[top_key]
+        return None, None
 
     def __contains__(self, key: KeyType):
         self._remove_outdated()
@@ -109,7 +119,8 @@ class DictionaryDHTValue(TimedStorage[Subkey, BinaryDHTValue]):
 
     def packb(self) -> bytes:
         """ custom behavior for MSGPackSerializer.dumps """
-        return MSGPackSerializer.dumps([self.maxsize, self.latest_expiration_time, list(map(list, self.items()))])
+        packed_items = [[key, value, expiration_time] for key, (value, expiration_time) in self.items()]
+        return MSGPackSerializer.dumps([self.maxsize, self.latest_expiration_time, packed_items])
 
     @classmethod
     def unpackb(cls, raw: bytes) -> DictionaryDHTValue:
@@ -143,15 +154,15 @@ class DHTLocalStorage(TimedStorage[DHTID, Union[BinaryDHTValue, DictionaryDHTVal
          3) if self[key] is a normal value with smaller expiration time, overwrite it with a dictionary and add sub-key
         :returns: True if the new entry was stored, False if it was rejected (current value is newer)
         """
-        previous_value, previous_expiration_time = self.get(key)
-        if isinstance(previous_value, DictionaryDHTValue):  # already a dictionary, just add new subkey
-            if expiration_time > previous_value.latest_expiration_time:
-                super().store(key, previous_value, expiration_time)  # refresh expiration time
-            return previous_value.store(subkey, value, expiration_time)
-        elif expiration_time > (previous_expiration_time or float('-inf')):  # create new dictionary, add subkey
+        previous_value, previous_expiration_time = self.get(key) or (b'', -float('inf'))
+        if isinstance(previous_value, BinaryDHTValue) and expiration_time > previous_expiration_time:
             new_storage = DictionaryDHTValue()
             new_storage.store(subkey, value, expiration_time)
             return super().store(key, new_storage, new_storage.latest_expiration_time)
+        elif isinstance(previous_value, DictionaryDHTValue):
+            if expiration_time > previous_value.latest_expiration_time:
+                super().store(key, previous_value, expiration_time)  # refresh expiration time
+            return previous_value.store(subkey, value, expiration_time)
         else:
             return False
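A short sketch of the refactored TimedStorage interface after this change (key name and lifetime are arbitrary):

    from hivemind.dht.routing import get_dht_time
    from hivemind.dht.storage import TimedStorage

    storage = TimedStorage()
    storage.store('key', 'value', get_dht_time() + 10)

    item = storage.get('key')            # ValueWithExpiration or None
    if item is not None:
        print(item.value, item.expiration_time)

    top_key, top_item = storage.top()    # (None, None) when storage is empty
    for key, (value, expiration_time) in storage.items():
        pass  # items() now yields (key, ValueWithExpiration) pairs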
 

+ 2 - 2
hivemind/proto/dht.proto

@@ -24,8 +24,8 @@ message NodeInfo {
 message StoreRequest {
   // three lists of the same length representing dht keys, dht values and expiration
   repeated bytes keys = 1;             // keys in the form of DHTID.generate(raw_key).to_bytes()
-  repeated string subkeys = 2;         // [optional] subkeys for DictionaryDHTValue type, empty string means no subkey
-  repeated bytes values = 3;           // binary-encoded value for i-th key
+  repeated bytes subkeys = 2;          // serialized subkeys for DictionaryDHTValue type. None means no subkey
+  repeated bytes values = 3;           // serialized value for i-th key
   repeated double expiration_time = 4; // expirations for i-th key (type = DHTExpiration)
   repeated bool in_cache = 5;          // if in_cache[i], store i-th key in cache, else store normally
   NodeInfo peer = 6;                   // (optional) sender's own node info, same behavior as in DHT.rpc_ping
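Because serialized subkeys are arbitrary msgpack bytes rather than valid UTF-8, the proto field switches from string to bytes. A hypothetical StoreRequest built directly against the generated message (all field values below are made up):

    from hivemind.proto import dht_pb2

    request = dht_pb2.StoreRequest(
        keys=[b'\x01' * 20],        # DHTID.to_bytes() output (length illustrative)
        subkeys=[b'\xc0'],          # serialized None, i.e. a regular value
        values=[b'payload'],
        expiration_time=[1700000000.0],
        in_cache=[False])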

+ 41 - 7
tests/test_dht_experts.py

@@ -1,6 +1,7 @@
 import random
 import uuid
 from itertools import chain
+import numpy as np
 
 import hivemind
 from hivemind import LOCALHOST
@@ -44,6 +45,38 @@ def test_store_get_experts():
         peer.shutdown()
 
 
+def test_beam_search(dht_size=20, total_experts=128, batch_size=32, initial_peers=3, beam_size=4, parallel_rpc=256,
+                     grid_dims=(32, 32, 32)):
+    dht = []
+    for i in range(dht_size):
+        neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(dht, min(initial_peers, len(dht)))]
+        dht.append(hivemind.DHT(start=True, expiration=999999, initial_peers=neighbors_i, parallel_rpc=parallel_rpc))
+
+    real_experts = sorted({
+        'expert.' + '.'.join([str(random.randint(0, dim - 1)) for dim in grid_dims])
+        for _ in range(total_experts)
+    })
+    for batch_start in range(0, len(real_experts), batch_size):
+        random.choice(dht).declare_experts(
+            real_experts[batch_start: batch_start + batch_size], wait=True,
+            endpoint=f"host{batch_start // batch_size}:{random.randint(0, 65536)}")
+
+    neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(dht, min(initial_peers, len(dht)))]
+    you = hivemind.DHT(start=True, expiration=999999, initial_peers=neighbors_i, parallel_rpc=parallel_rpc)
+
+    for i in range(50):
+        topk_experts = you.find_best_experts('expert', [np.random.randn(dim) for dim in grid_dims], beam_size=beam_size)
+        assert all(isinstance(e, hivemind.RemoteExpert) for e in topk_experts)
+        assert len(topk_experts) == beam_size
+
+    for i in range(10):
+        batch_experts = you.batch_find_best_experts('expert', [np.random.randn(batch_size, dim) for dim in grid_dims],
+                                                    beam_size=beam_size)
+        assert isinstance(batch_experts, list) and len(batch_experts) == batch_size
+        assert all(isinstance(e, hivemind.RemoteExpert) for experts in batch_experts for e in experts)
+        assert all(len(experts) == beam_size for experts in batch_experts)
+
+
 def test_first_k_active():
     node = hivemind.DHT(start=True)
     assert all(node.declare_experts(['e.1.2.3', 'e.1.2.4', 'e.3.4.5'], endpoint=f"{hivemind.LOCALHOST}:1337"))
@@ -63,14 +96,15 @@ def test_first_k_active():
 
 def test_dht_single_node():
     node = hivemind.DHT(start=True)
-    assert node.first_k_active(['e3', 'e2'], k=3) == {}
-    assert node.get_experts(['e3', 'e2']) == [None, None]
+    assert node.first_k_active(['e.3', 'e.2'], k=3) == {}
+    assert node.get_experts(['e.3', 'e.2']) == [None, None]
 
-    assert all(node.declare_experts(['e1', 'e2', 'e3'], f"{hivemind.LOCALHOST}:1337"))
-    for expert in node.get_experts(['e3', 'e2']):
+    assert all(node.declare_experts(['e.1', 'e.2', 'e.3'], f"{hivemind.LOCALHOST}:1337"))
+    for expert in node.get_experts(['e.3', 'e.2']):
         assert expert.endpoint == f"{hivemind.LOCALHOST}:1337"
-    active_found = node.first_k_active(['e0', 'e1', 'e3', 'e5', 'e2'], k=2)
-    assert list(active_found.keys()) == ['e1', 'e3']
+    active_found = node.first_k_active(['e.0', 'e.1', 'e.3', 'e.5', 'e.2'], k=2)
+    assert list(active_found.keys()) == ['e.1', 'e.3']
     assert all(expert.uid.startswith(prefix) for prefix, expert in active_found.items())
 
-    assert all(node.declare_experts(['e1', 'e2', 'e3'], f"{hivemind.LOCALHOST}:1337"))
+    assert all(node.declare_experts(['e.1', 'e.2', 'e.3'], f"{hivemind.LOCALHOST}:1337"))
+    assert node.find_best_experts('e', [(0., 1., 2., 3., 4., 5., 6., 7., 8.)], beam_size=4)

+ 9 - 9
tests/test_dht_node.py

@@ -62,7 +62,7 @@ def test_dht_protocol():
             assert all(store_ok), "DHT rejected a trivial store"
 
             # peer 1 must know about peer 2
-            recv_value_bytes, recv_expiration, nodes_found = loop.run_until_complete(
+            (recv_value_bytes, recv_expiration), nodes_found = loop.run_until_complete(
                 protocol.call_find(f'{LOCALHOST}:{peer1_port}', [key]))[key]
             recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
             (recv_id, recv_endpoint) = next(iter(nodes_found.items()))
@@ -75,9 +75,9 @@ def test_dht_protocol():
 
             # peer 2 must know about peer 1, but not have a *random* nonexistent value
             dummy_key = DHTID.generate()
-            recv_dummy_value, recv_dummy_expiration, nodes_found_2 = loop.run_until_complete(
+            empty_item, nodes_found_2 = loop.run_until_complete(
                 protocol.call_find(f'{LOCALHOST}:{peer2_port}', [dummy_key]))[dummy_key]
-            assert recv_dummy_value is None and recv_dummy_expiration is None, "Non-existent keys shouldn't have values"
+            assert empty_item is None, "Non-existent keys shouldn't have values"
             (recv_id, recv_endpoint) = next(iter(nodes_found_2.items()))
             assert recv_id == peer1_id and recv_endpoint == f"{LOCALHOST}:{peer1_port}", \
                 f"expected id={peer1_id}, peer={LOCALHOST}:{peer1_port} but got {recv_id}, {recv_endpoint}"
@@ -97,7 +97,7 @@ def test_dht_protocol():
                 f'{LOCALHOST}:{peer1_port}', keys=[nested_key], values=[hivemind.MSGPackSerializer.dumps(value2)],
                 expiration_time=[expiration + 5], subkeys=[subkey2])
             )
-            recv_dict, recv_expiration, nodes_found = loop.run_until_complete(
+            (recv_dict, recv_expiration), nodes_found = loop.run_until_complete(
                 protocol.call_find(f'{LOCALHOST}:{peer1_port}', [nested_key]))[nested_key]
             assert isinstance(recv_dict, DictionaryDHTValue)
             assert len(recv_dict.data) == 2 and recv_expiration == expiration + 5
@@ -134,14 +134,14 @@ def test_empty_table():
 
         key, value, expiration = DHTID.generate(), [random.random(), {'ololo': 'pyshpysh'}], get_dht_time() + 1e3
 
-        recv_value_bytes, recv_expiration, nodes_found = loop.run_until_complete(
+        empty_item, nodes_found = loop.run_until_complete(
             protocol.call_find(f'{LOCALHOST}:{peer_port}', [key]))[key]
-        assert recv_value_bytes is None and recv_expiration is None and len(nodes_found) == 0
+        assert empty_item is None and len(nodes_found) == 0
         assert all(loop.run_until_complete(protocol.call_store(
             f'{LOCALHOST}:{peer_port}', [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
         )), "peer rejected store"
 
-        recv_value_bytes, recv_expiration, nodes_found = loop.run_until_complete(
+        (recv_value_bytes, recv_expiration), nodes_found = loop.run_until_complete(
             protocol.call_find(f'{LOCALHOST}:{peer_port}', [key]))[key]
         recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
         assert len(nodes_found) == 0
@@ -266,7 +266,7 @@ def test_dht_node():
             assert val == ["Value", 10], "Wrong value"
             assert expiration_time == true_time, f"Wrong time"
 
-        assert loop.run_until_complete(detached_node.get("mykey")) == (None, None)
+        assert loop.run_until_complete(detached_node.get("mykey")) is None
 
         # test 7: bulk store and bulk get
         keys = 'foo', 'bar', 'baz', 'zzz'
@@ -429,7 +429,7 @@ def test_dhtnode_reuse_get():
 
         assert (await futures1['k1'])[0] == 123
         assert await futures1['k2'] == await futures2['k2'] and (await futures1['k2'])[0] == 567
-        assert await futures2['k3'] == await futures3['k3'] and (await futures3['k3']) == (None, None)
+        assert await futures2['k3'] == await futures3['k3'] and (await futures3['k3']) is None
         test_success.set()
 
     proc = mp.Process(target=lambda: asyncio.run(_tester()))

+ 8 - 8
tests/test_dht_storage.py

@@ -16,13 +16,13 @@ def test_get_expired():
     d = DHTLocalStorage()
     d.store(DHTID.generate("key"), b"val", get_dht_time() + 0.1)
     time.sleep(0.5)
-    assert d.get(DHTID.generate("key")) == (None, None), "Expired value must be deleted"
+    assert d.get(DHTID.generate("key")) is None, "Expired value must be deleted"
     print("Test get expired passed")
 
 
 def test_get_empty():
     d = DHTLocalStorage()
-    assert d.get(DHTID.generate(source="key")) == (None, None), "DHTLocalStorage returned non-existent value"
+    assert d.get(DHTID.generate(source="key")) is None, "DHTLocalStorage returned non-existent value"
     print("Test get expired passed")
 
 
@@ -41,7 +41,7 @@ def test_maxsize_cache():
     d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 1)
     d.store(DHTID.generate("key2"), b"val2", get_dht_time() + 200)
     assert d.get(DHTID.generate("key2"))[0] == b"val2", "Value with bigger exp. time must be kept"
-    assert d.get(DHTID.generate("key1"))[0] is None, "Value with less exp time, must be deleted"
+    assert d.get(DHTID.generate("key1")) is None, "Value with less exp time, must be deleted"
 
 
 def test_localstorage_top():
@@ -49,17 +49,17 @@ def test_localstorage_top():
     d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 1)
     d.store(DHTID.generate("key2"), b"val2", get_dht_time() + 2)
     d.store(DHTID.generate("key3"), b"val3", get_dht_time() + 4)
-    assert d.top()[:2] == (DHTID.generate("key1"), b"val1")
+    assert d.top()[0] == DHTID.generate("key1") and d.top()[1].value == b"val1"
 
     d.store(DHTID.generate("key1"), b"val1_new", get_dht_time() + 3)
-    assert d.top()[:2] == (DHTID.generate("key2"), b"val2")
+    assert d.top()[0] == DHTID.generate("key2") and d.top()[1].value == b"val2"
 
     del d[DHTID.generate('key2')]
-    assert d.top()[:2] == (DHTID.generate("key1"), b"val1_new")
+    assert d.top()[0] == DHTID.generate("key1") and d.top()[1].value == b"val1_new"
     d.store(DHTID.generate("key2"), b"val2_new", get_dht_time() + 5)
     d.store(DHTID.generate("key4"), b"val4", get_dht_time() + 6)  # key4 will push out key1 due to maxsize
 
-    assert d.top()[:2] == (DHTID.generate("key3"), b"val3")
+    assert d.top()[0] == DHTID.generate("key3") and d.top()[1].value == b"val3"
 
 
 def test_localstorage_nested():
@@ -71,7 +71,7 @@ def test_localstorage_nested():
     d2.store('subkey3', b'value3', time + 1)
 
     assert d2.latest_expiration_time == time + 3
-    for subkey, subvalue, subexpiration in d2.items():
+    for subkey, (subvalue, subexpiration) in d2.items():
         assert d1.store_subkey(DHTID.generate('foo'), subkey, subvalue, subexpiration)
     assert d1.store(DHTID.generate('bar'), b'456', time + 2)
     assert d1.get(DHTID.generate('foo'))[0].data == d2.data
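The rewritten store_subkey covers three cases: a missing or stale regular value is replaced by a fresh one-entry dictionary, an existing dictionary receives the new subkey (refreshing the outer expiration if needed), and a write that is older than the stored entry or already expired is rejected. A compact sketch of all three (key name and offsets are made up):

    from hivemind.dht.routing import DHTID, get_dht_time
    from hivemind.dht.storage import DHTLocalStorage, DictionaryDHTValue

    d, now, key = DHTLocalStorage(), get_dht_time(), DHTID.generate('demo')
    d.store(key, b'plain', now + 1)                       # regular value
    d.store_subkey(key, 'sub1', b'nested', now + 2)       # newer: promoted to a dictionary
    assert isinstance(d.get(key).value, DictionaryDHTValue)
    d.store_subkey(key, 'sub2', b'more', now + 3)         # dictionary exists: adds a subkey
    assert not d.store_subkey(key, 'sub1', b'late', now)  # already expired: rejected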