Browse Source

Use namedtuples for DHT values (#110)

* add test for beam search

* add tests for find_best_experts and batch_find_best_experts

* storage: use named tuples

* switch to namedtuples

* update storage tests

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic 4 years ago
parent
commit
a59fa709cc
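
The unifying idea of this commit: every call site that previously passed around an anonymous (value, expiration_time) tuple now receives the ValueWithExpiration namedtuple introduced in hivemind/dht/storage.py, and "not found" is signalled by a plain None instead of (None, None). A minimal before/after sketch of the access pattern (use() is a hypothetical consumer, not a call from this diff):

# Before this commit: lookups returned a raw 2-tuple; a miss was (None, None).
maybe_value, maybe_expiration_time = storage.get(key)
if maybe_expiration_time is not None:
    use(maybe_value, maybe_expiration_time)  # use() is a hypothetical consumer

# After this commit: lookups return Optional[ValueWithExpiration]; a miss is
# a plain None, and fields are addressable by name as well as by index.
maybe_item = storage.get(key)
if maybe_item is not None:
    use(maybe_item.value, maybe_item.expiration_time)
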

+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.utils import *

-__version__ = '0.8.6'
+__version__ = '0.8.7'

+ 14 - 12
hivemind/dht/__init__.py

@@ -19,6 +19,7 @@ import multiprocessing as mp
 import warnings
 from collections import deque, OrderedDict
 from concurrent.futures import ThreadPoolExecutor
+from itertools import chain
 from typing import List, Tuple, Optional, Sequence, OrderedDict as TOrderedDict, Union, Awaitable, Dict, Deque, Set

 import uvloop
@@ -148,9 +149,9 @@ class DHT(mp.Process):
         num_workers = len(uids) if self.max_workers is None else min(len(uids), self.max_workers)
         response = await node.get_many(uids, expiration_time, num_workers=num_workers)
         # TODO expert_data['expert'] -> namedtuple with meaningful field names
-        future.set_result([RemoteExpert(*expert_data['expert'][0])
-                           if maybe_expiration_time else None and expert_data['expert'][1] is not None
-                           for uid, (expert_data, maybe_expiration_time) in response.items()])
+        future.set_result([RemoteExpert(*expert_data.value['expert'].value)
+                           if expert_data is not None and 'expert' in expert_data.value else None
+                           for uid, expert_data in response.items()])

     def declare_experts(self, uids: List[str], endpoint: Endpoint, wait=True, timeout=None) -> Optional[List[bool]]:
         """
@@ -222,6 +223,7 @@ class DHT(mp.Process):
         if not beam:
             logger.warning(f"Beam search had to terminate prematurely because of empty beam (dim 0)")
             return []
+        # TODO warn user if indices are out of range on the _last_ level! (rationale: beam search may return <k results)

         for dim_index in range(1, len(grid_scores) - 1):
             # select beam_size best suffixes from current beam
@@ -245,11 +247,12 @@ class DHT(mp.Process):

         # select best experts from the final beam
         dim_scores = grid_scores[-1]
-        final_best_pairs: List[Tuple[float, str, Endpoint]] = heapq.nlargest(beam_size, (
+        # TODO use heap to harness all results, get rid of five-line expression
+        final_best_pairs: List[Tuple[float, str, Endpoint]] = heapq.nlargest(beam_size, chain((
             (prefix_score + dim_scores[int(suffix_i)], uid, endpoint)
             for prefix_score, prefix, suffixes in beam for suffix_i, ((uid, endpoint), _) in suffixes.items()
             if str.isdecimal(suffix_i) and 0 <= int(suffix_i) < len(dim_scores)
-        ))
+        ), ((score, *suffixes['expert']) for score, _, suffixes in beam if 'expert' in suffixes)))
         best_experts = [RemoteExpert(uid, endpoint) for score, uid, endpoint in final_best_pairs]
         if future is not None:
             future.set_result(best_experts)
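
The chain() above merges two candidate streams for the final top-k: suffix continuations of each beam prefix, plus any fully-specified experts already recorded in the beam under the 'expert' key. A tiny standalone demo of ranking across chained generators (scores and uids are made up):

import heapq
from itertools import chain

suffix_candidates = ((score, uid) for score, uid in [(1.5, 'expert.1.0'), (3.5, 'expert.2.7')])
beam_experts = ((score, uid) for score, uid in [(2.7, 'expert.9.9')])
# nlargest ranks candidates drawn from both generators at once:
assert heapq.nlargest(2, chain(suffix_candidates, beam_experts)) == \
    [(3.5, 'expert.2.7'), (2.7, 'expert.9.9')]
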
@@ -305,9 +308,9 @@ class DHT(mp.Process):
             # await the next best prefix to be fetched
             pending_best_index, pending_best_prefix, pending_task = pending_tasks.popleft()
             try:
-                maybe_prefix_data, maybe_expiration_time = await pending_task
-                if maybe_expiration_time is not None:
-                    beam.append((scores[pending_best_index], pending_best_prefix, maybe_prefix_data))
+                maybe_prefix_data = await pending_task
+                if maybe_prefix_data is not None:
+                    beam.append((scores[pending_best_index], pending_best_prefix, maybe_prefix_data.value))
             except asyncio.CancelledError:
                 for _, pending_task in pending_tasks:
                     pending_task.cancel()
@@ -328,7 +331,7 @@ class DHT(mp.Process):
         :returns: a ordered dict{uid_prefix -> RemoteExpert} mapping at most :k: prefixes to matching experts
             The keys in the returned dict are ordered same as in uid_prefixes.
         """
-        logger.warning("first_k_active is deprecated and will be removed in 0.8.7")
+        logger.warning("first_k_active is deprecated and will be removed in 0.8.8")
         assert not isinstance(uid_prefixes, str), "please provide a list/tuple of prefixes as the first argument"
         future, _future = MPFuture.make_pair()
         self.pipe.send(('_first_k_active', [],
@@ -352,9 +355,8 @@ class DHT(mp.Process):
             # parse task results in chronological order, launch additional tasks on demand
             response = await pending_tasks.popleft()
             for uid_prefix in uid_prefixes[chunk_i * chunk_size: (chunk_i + 1) * chunk_size]:
-                maybe_expert_data, maybe_expiration_time = response[uid_prefix]
-                if maybe_expiration_time is not None and len(maybe_expert_data) > 0:  # found active peer
-                    found.append((uid_prefix, RemoteExpert(*next(iter(maybe_expert_data.values()))[0])))
+                if response[uid_prefix] is not None and len(response[uid_prefix].value) > 0:  # found active peer
+                    found.append((uid_prefix, RemoteExpert(*next(iter(response[uid_prefix].value.values()))[0])))
                     # if we found enough active experts, finish immediately
                     if len(found) >= k:
                         break
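
Both rewrites in this file lean on the new response shape: for a dictionary-typed DHT record, get_many returns either None or a ValueWithExpiration whose .value maps each subkey to its own ValueWithExpiration. A hedged sketch of the nesting that _get_experts unpacks (run inside a coroutine; 'uid' is one of the requested keys):

response = await node.get_many(uids, expiration_time)
expert_data = response[uid]                       # None if uid was never declared
if expert_data is not None and 'expert' in expert_data.value:
    uid_and_endpoint, sub_expiration = expert_data.value['expert']
    expert = RemoteExpert(*uid_and_endpoint)      # same call as in the diff above
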

+ 38 - 35
hivemind/dht/node.py

@@ -12,7 +12,7 @@ from sortedcontainers import SortedList

 from hivemind.dht.protocol import DHTProtocol
 from hivemind.dht.routing import DHTID, DHTExpiration, DHTKey, get_dht_time, DHTValue, BinaryDHTValue, Subkey
-from hivemind.dht.storage import CacheRefreshQueue, DictionaryDHTValue
+from hivemind.dht.storage import CacheRefreshQueue, DictionaryDHTValue, ValueWithExpiration
 from hivemind.dht.traverse import traverse_dht
 from hivemind.utils import Endpoint, LOCALHOST, MSGPackSerializer, get_logger, SerializerBase
@@ -189,7 +189,7 @@ class DHTNode:
                 return {query: ([], False) for query in queries}

             output: Dict[DHTID, Tuple[Tuple[DHTID], bool]] = {}
-            for query, (_, _, peers) in response.items():
+            for query, (_, peers) in response.items():
                 node_to_endpoint.update(peers)
                 output[query] = tuple(peers.keys()), False  # False means "do not interrupt search"
             return output
@@ -237,8 +237,8 @@
         """
         if isinstance(expiration_time, DHTExpiration):
             expiration_time = [expiration_time] * len(keys)
-        if subkeys is None or isinstance(subkeys, Subkey):
-            subkeys = [subkeys] * len(keys)
+        if subkeys is None:
+            subkeys = [None] * len(keys)

         assert len(keys) == len(subkeys) == len(values) == len(expiration_time), \
             "Either of keys, values, subkeys or expiration timestamps have different sequence lengths."
@@ -336,13 +336,13 @@ class DHTNode:
                 self.protocol.cache.store_subkey(key_id, subkey, stored_value_bytes, expiration_time)
             self._schedule_for_refresh(key_id, refresh_time=get_dht_time())  # fetch new key in background (asap)

-    async def get(self, key: DHTKey, latest=False, **kwargs) -> Tuple[Optional[DHTValue], Optional[DHTExpiration]]:
+    async def get(self, key: DHTKey, latest=False, **kwargs) -> Optional[ValueWithExpiration[DHTValue]]:
         """
-        Search for a key across DHT and return either first or latest entry.
+        Search for a key across DHT and return either first or latest entry (if found).
         :param key: same key as in node.store(...)
         :param latest: if True, finds the latest value, otherwise finds any non-expired value (which is much faster)
         :param kwargs: parameters forwarded to get_many_by_id
-        :returns: (value, expiration time); if value was not found, returns (None, None)
+        :returns: (value, expiration time); if value was not found, returns None
         """
         if latest:
             kwargs["sufficient_expiration_time"] = float('inf')
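
A minimal usage sketch of the updated get() contract (assumes a running DHTNode; the key is made up):

result = await node.get('expert.4.2', latest=True)
if result is None:
    pass                                  # key not found or expired everywhere
else:
    value, expiration_time = result       # NamedTuple still unpacks like a tuple
    assert value == result.value          # ...but named access reads better
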
@@ -350,8 +350,8 @@ class DHTNode:
         return result[key]

     async def get_many(self, keys: Collection[DHTKey], sufficient_expiration_time: Optional[DHTExpiration] = None,
-                       **kwargs) -> Dict[DHTKey, Union[Tuple[Optional[DHTValue], Optional[DHTExpiration]],
-                                                       Awaitable[Tuple[Optional[DHTValue], Optional[DHTExpiration]]]]]:
+                       **kwargs) -> Dict[DHTKey, Union[Optional[ValueWithExpiration[DHTValue]],
+                                                       Awaitable[Optional[ValueWithExpiration[DHTValue]]]]]:
         """
         Traverse DHT to find a list of keys. For each key, return latest (value, expiration) or None if not found.

@@ -372,8 +372,8 @@ class DHTNode:
     async def get_many_by_id(
             self, key_ids: Collection[DHTID], sufficient_expiration_time: Optional[DHTExpiration] = None,
             num_workers: Optional[int] = None, beam_size: Optional[int] = None, return_futures: bool = False,
-            _is_refresh=False) -> Dict[DHTID, Union[Tuple[Optional[DHTValue], Optional[DHTExpiration]],
-                                                    Awaitable[Tuple[Optional[DHTValue], Optional[DHTExpiration]]]]]:
+            _is_refresh=False) -> Dict[DHTID, Union[Optional[ValueWithExpiration[DHTValue]],
+                                                    Awaitable[Optional[ValueWithExpiration[DHTValue]]]]]:
         """
         Traverse DHT to find a list of DHTIDs. For each key, return latest (value, expiration) or None if not found.

@@ -408,9 +408,9 @@ class DHTNode:

         # stage 1: check for value in this node's local storage and cache
         for key_id in key_ids:
-            search_results[key_id].add_candidate(*self.protocol.storage.get(key_id), source_node_id=self.node_id)
+            search_results[key_id].add_candidate(self.protocol.storage.get(key_id), source_node_id=self.node_id)
             if not _is_refresh:
-                search_results[key_id].add_candidate(*self.protocol.cache.get(key_id), source_node_id=self.node_id)
+                search_results[key_id].add_candidate(self.protocol.cache.get(key_id), source_node_id=self.node_id)

         # stage 2: traverse the DHT to get the remaining keys from remote peers
         unfinished_key_ids = [key_id for key_id in key_ids if not search_results[key_id].finished]
@@ -427,9 +427,9 @@ class DHTNode:
                 return {query: ([], False) for query in queries}

             output: Dict[DHTID, Tuple[Tuple[DHTID], bool]] = {}
-            for key_id, (maybe_value_bytes, maybe_expiration_time, peers) in response.items():
+            for key_id, (maybe_value_with_expiration, peers) in response.items():
                 node_to_endpoint.update(peers)
-                search_results[key_id].add_candidate(maybe_value_bytes, maybe_expiration_time, source_node_id=peer)
+                search_results[key_id].add_candidate(maybe_value_with_expiration, source_node_id=peer)
                 output[key_id] = tuple(peers.keys()), search_results[key_id].finished
                 # note: we interrupt search either if key is either found or finished otherwise (e.g. cancelled by user)
             return output
@@ -457,12 +457,12 @@ class DHTNode:
                 raise e

     def _reuse_finished_search_result(self, finished: _SearchState):
+        search_result = ValueWithExpiration(finished.binary_value, finished.expiration_time)
         expiration_time_threshold = max(finished.expiration_time or -float('inf'), finished.sufficient_expiration_time)
         concurrent_requests: SortedList[_SearchState] = self.pending_get_requests[finished.key_id]
-        # note: concurrent_requests is sorded in the order of descending sufficient_expiration_time
+        # note: concurrent_requests is sorted in the order of descending sufficient_expiration_time
         while concurrent_requests and expiration_time_threshold >= concurrent_requests[-1].sufficient_expiration_time:
-            concurrent_requests[-1].add_candidate(finished.binary_value, finished.expiration_time,
-                                                  source_node_id=finished.source_node_id)
+            concurrent_requests[-1].add_candidate(search_result, source_node_id=finished.source_node_id)
             concurrent_requests[-1].finish_search()
             concurrent_requests.pop(-1)

@@ -477,8 +477,8 @@ class DHTNode:
         if self.cache_refresh_task is None or self.cache_refresh_task.done() or self.cache_refresh_task.cancelled():
             self.cache_refresh_task = asyncio.create_task(self._refresh_stale_cache_entries())
             logger.debug("Spawned cache refresh task.")
-        previous_earliest_item: Tuple[DHTID, Any, DHTExpiration] = self.cache_refresh_queue.top()
-        if previous_earliest_item is None or refresh_time < previous_earliest_item[-1]:
+        earliest_key, earliest_item = self.cache_refresh_queue.top()
+        if earliest_item is None or refresh_time < earliest_item.expiration_time:
             self.cache_refresh_evt.set()  # if we new element is now earliest, notify the cache queue
         self.cache_refresh_queue.store(key_id, value=refresh_time, expiration_time=refresh_time)

@@ -488,7 +488,7 @@ class DHTNode:
             while len(self.cache_refresh_queue) == 0:
                 await self.cache_refresh_evt.wait()
                 self.cache_refresh_evt.clear()
-            key_id, _, nearest_refresh_time = self.cache_refresh_queue.top()
+            key_id, (_, nearest_refresh_time) = self.cache_refresh_queue.top()

             try:
                 # step 1: await until :cache_refresh_before_expiry: seconds before earliest first element expires
@@ -505,12 +505,14 @@ class DHTNode:
                 max_expiration_time = self.protocol.cache.get(key_id)[1] or current_time
                 del self.cache_refresh_queue[key_id]  # we pledge to refresh this key_id in the nearest batch
                 while self.cache_refresh_queue:
-                    key_id, _, nearest_refresh_time = self.cache_refresh_queue.top()
+                    key_id, (_, nearest_refresh_time) = self.cache_refresh_queue.top()
                     if nearest_refresh_time > current_time:
                         break
                     del self.cache_refresh_queue[key_id]  # we pledge to refresh this key_id in the nearest batch
                     keys_to_refresh.add(key_id)
-                    max_expiration_time = max(max_expiration_time, self.protocol.cache.get(key_id)[1] or current_time)
+                    cached_item = self.protocol.cache.get(key_id)
+                    if cached_item is not None and cached_item.expiration_time > max_expiration_time:
+                        max_expiration_time = cached_item.expiration_time

                 # step 3: search newer versions of these keys, cache them as a side-effect of self.get_many_by_id
                 sufficient_expiration_time = max_expiration_time + self.cache_refresh_before_expiry + 1
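
The explicit None check above matters because cache.get() now returns Optional[ValueWithExpiration]: the pre-namedtuple pattern of indexing the result, still visible in the unchanged `...cache.get(key_id)[1] or current_time` context line, would raise TypeError on a cache miss. A hedged sketch of the difference (names from the diff):

cached_item = self.protocol.cache.get(key_id)      # ValueWithExpiration or None
# unsafe once get() may return None:
#     expiration = self.protocol.cache.get(key_id)[1]    # TypeError on a miss
# safe, and self-documenting:
if cached_item is not None:
    expiration = cached_item.expiration_time
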
@@ -520,9 +522,10 @@ class DHTNode:
                           node_to_endpoint: Dict[DHTID, Endpoint], _is_refresh: bool = False):
         """ after key_id is found, update cache according to caching policy. used internally in get and get_many """
         if search.found_something:
-            previous_expiration_time = max(self.protocol.storage.get(search.key_id)[1] or -float('inf'),
-                                           self.protocol.cache.get(search.key_id)[1] or -float('inf'))
-            if search.expiration_time > previous_expiration_time:  # if this value has better expiration
+            _, storage_expiration_time = self.protocol.storage.get(search.key_id) or (None, -float('inf'))
+            _, cache_expiration_time = self.protocol.cache.get(search.key_id) or (None, -float('inf'))
+
+            if search.expiration_time > max(storage_expiration_time, cache_expiration_time):
                 if self.cache_locally or _is_refresh:
                     self.protocol.cache.store(search.key_id, search.binary_value, search.expiration_time)
                 if self.cache_nearest:
@@ -559,12 +562,12 @@ class _SearchState:
     binary_value: Optional[Union[BinaryDHTValue, DictionaryDHTValue]] = None
     expiration_time: Optional[DHTExpiration] = None  # best expiration time so far
     source_node_id: Optional[DHTID] = None  # node that gave us the value
-    future: asyncio.Future[Tuple[Optional[DHTValue], Optional[DHTExpiration]]] = field(default_factory=asyncio.Future)
+    future: asyncio.Future[Optional[ValueWithExpiration[DHTValue]]] = field(default_factory=asyncio.Future)
     serializer: type(SerializerBase) = MSGPackSerializer

-    def add_candidate(self, binary_value: Optional[BinaryDHTValue], expiration_time: Optional[DHTExpiration],
-                      source_node_id: Optional[DHTID]):
-        if not self.finished and (expiration_time or -float('inf')) > (self.expiration_time or -float('inf')):
+    def add_candidate(self, candidate: Optional[ValueWithExpiration[BinaryDHTValue]], source_node_id: Optional[DHTID]):
+        binary_value, expiration_time = candidate or (None, -float('inf'))
+        if not self.finished and expiration_time > (self.expiration_time or -float('inf')):
             self.binary_value, self.expiration_time, self.source_node_id = binary_value, expiration_time, source_node_id
             if self.expiration_time >= self.sufficient_expiration_time:
                 self.finish_search()
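
The `candidate or (None, -float('inf'))` default above works because a NamedTuple unpacks positionally and None is falsy. A standalone check (assumes hivemind with this commit is importable):

from hivemind.dht.storage import ValueWithExpiration

binary_value, expiration_time = None or (None, -float('inf'))
assert (binary_value, expiration_time) == (None, -float('inf'))

binary_value, expiration_time = ValueWithExpiration(b'payload', 42.0) or (None, -float('inf'))
assert binary_value == b'payload' and expiration_time == 42.0
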
@@ -577,13 +580,13 @@ class _SearchState:
         if self.future.done():
             return  # either user cancelled our search or someone sent it before us. Nothing more to do here.
         elif not self.found_something:
-            self.future.set_result((None, None))
+            self.future.set_result(None)
         elif isinstance(self.binary_value, BinaryDHTValue):
-            self.future.set_result((self.serializer.loads(self.binary_value), self.expiration_time))
+            self.future.set_result(ValueWithExpiration(self.serializer.loads(self.binary_value), self.expiration_time))
         elif isinstance(self.binary_value, DictionaryDHTValue):
-            dict_value = {key: (self.serializer.loads(value), item_expiration_time)
-                          for key, value, item_expiration_time in self.binary_value.items()}
-            self.future.set_result((dict_value, self.expiration_time))
+            dict_with_subkeys = {key: ValueWithExpiration(self.serializer.loads(value), item_expiration_time)
+                                 for key, (value, item_expiration_time) in self.binary_value.items()}
+            self.future.set_result(ValueWithExpiration(dict_with_subkeys, self.expiration_time))
         else:
             logger.error(f"Invalid value type: {type(self.binary_value)}")


+ 38 - 32
hivemind/dht/protocol.py

@@ -9,13 +9,11 @@ import grpc
 import grpc.experimental.aio

 from hivemind.dht.routing import RoutingTable, DHTID, BinaryDHTValue, DHTExpiration, Subkey
-from hivemind.dht.storage import DHTLocalStorage, DictionaryDHTValue
+from hivemind.dht.storage import DHTLocalStorage, DictionaryDHTValue, ValueWithExpiration
 from hivemind.proto import dht_pb2, dht_pb2_grpc as dht_grpc
 from hivemind.utils import Endpoint, get_logger, replace_port, MSGPackSerializer

 logger = get_logger(__name__)
-NOT_FOUND_VALUE, NOT_FOUND_EXPIRATION, IS_REGULAR_VALUE, IS_DICTIONARY = b'', -float('inf'), '', '___DictionaryDHTValue'
-RESERVED_SUBKEYS = {IS_REGULAR_VALUE, IS_DICTIONARY}


 class DHTProtocol(dht_grpc.DHTServicer):
@@ -23,9 +21,11 @@ class DHTProtocol(dht_grpc.DHTServicer):
     node_id: DHTID; port: int; bucket_size: int; num_replicas: int; wait_timeout: float; node_info: dht_pb2.NodeInfo
     channel_options: Optional[Sequence[Tuple[str, Any]]]; server: grpc.experimental.aio.Server
     storage: DHTLocalStorage; cache: DHTLocalStorage; routing_table: RoutingTable; rpc_semaphore: asyncio.Semaphore
-    serializer = MSGPackSerializer  # used to pack/unpack DHT Values for transfer over network
     # fmt:on

+    serializer = MSGPackSerializer  # used to pack/unpack DHT Values for transfer over network
+    RESERVED_SUBKEYS = IS_REGULAR_VALUE, IS_DICTIONARY = serializer.dumps(None), b''
+
     @classmethod
     async def create(
             cls, node_id: DHTID, bucket_size: int, depth_modulo: int, num_replicas: int, wait_timeout: float,
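
The two reserved subkeys double as wire-format type tags. Since every user-provided subkey is itself msgpack-serialized before it is sent, collisions should be impossible: msgpack output is never empty (so it cannot equal IS_DICTIONARY), and assuming MSGPackSerializer delegates to msgpack, only None serializes to the single nil byte used for IS_REGULAR_VALUE. A hedged illustration:

from hivemind.utils import MSGPackSerializer

IS_REGULAR_VALUE = MSGPackSerializer.dumps(None)   # the msgpack nil byte, b'\xc0'
IS_DICTIONARY = b''                                # empty bytes: not valid msgpack
assert MSGPackSerializer.dumps('alpha') not in (IS_REGULAR_VALUE, IS_DICTIONARY)
assert MSGPackSerializer.dumps(0) not in (IS_REGULAR_VALUE, IS_DICTIONARY)
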
@@ -137,17 +137,19 @@
         """
         if isinstance(expiration_time, DHTExpiration):
             expiration_time = [expiration_time] * len(keys)
-        if subkeys is None or isinstance(subkeys, Subkey):
-            subkeys = [subkeys] * len(keys)
+        if subkeys is None:
+            subkeys = [None] * len(keys)

         in_cache = in_cache if in_cache is not None else [False] * len(keys)  # default value (None)
         in_cache = [in_cache] * len(keys) if isinstance(in_cache, bool) else in_cache  # single bool
         keys, subkeys, values, expiration_time, in_cache = map(list, [keys, subkeys, values, expiration_time, in_cache])
         for i in range(len(keys)):
             if subkeys[i] is None:  # add default sub-key if not specified
-                subkeys[i] = IS_REGULAR_VALUE if not isinstance(values[i], DictionaryDHTValue) else IS_DICTIONARY
+                subkeys[i] = self.IS_DICTIONARY if isinstance(values[i], DictionaryDHTValue) else self.IS_REGULAR_VALUE
+            else:
+                subkeys[i] = self.serializer.dumps(subkeys[i])
             if isinstance(values[i], DictionaryDHTValue):
-                assert subkeys[i] == IS_DICTIONARY, "Please do not specify subkey when storing an entire dictionary"
+                assert subkeys[i] == self.IS_DICTIONARY, "Please don't specify subkey when storing an entire dictionary"
                 values[i] = self.serializer.dumps(values[i])

         assert len(keys) == len(values) == len(expiration_time) == len(in_cache), "Data is not aligned"
@@ -172,22 +174,23 @@ class DHTProtocol(dht_grpc.DHTServicer):
         assert len(request.keys) == len(request.values) == len(request.expiration_time) == len(request.in_cache)
         response = dht_pb2.StoreResponse(store_ok=[], peer=self.node_info)
         keys = map(DHTID.from_bytes, request.keys)
-        for key_id, subkey, value_bytes, expiration_time, in_cache in zip(
+        for key_id, tag, value_bytes, expiration_time, in_cache in zip(
                 keys, request.subkeys, request.values, request.expiration_time, request.in_cache):
             storage = self.cache if in_cache else self.storage
-            if subkey == IS_REGULAR_VALUE:  # store normal value without subkeys
+            if tag == self.IS_REGULAR_VALUE:  # store normal value without subkeys
                 response.store_ok.append(storage.store(key_id, value_bytes, expiration_time))
-            elif subkey == IS_DICTIONARY:   # store an entire dictionary with pre-existing subkeys
+            elif tag == self.IS_DICTIONARY:  # store an entire dictionary with several subkeys
                 value_dictionary = self.serializer.loads(value_bytes)
                 assert isinstance(value_dictionary, DictionaryDHTValue)
-                response.store_ok.append(all(storage.store_subkey(key_id, subkey, subvalue, subkey_expiration)
-                                             for subkey, subvalue, subkey_expiration in value_dictionary.items()))
-            else:  # add new entry into an existing dictionary-like value or create a new dictionary with one sub-key
+                response.store_ok.append(all(storage.store_subkey(key_id, subkey, item.value, item.expiration_time)
+                                             for subkey, item in value_dictionary.items()))
+            else:  # add a new entry into an existing dictionary value or create a new dictionary with one sub-key
+                subkey = self.serializer.loads(tag)
                 response.store_ok.append(storage.store_subkey(key_id, subkey, value_bytes, expiration_time))
         return response

-    async def call_find(self, peer: Endpoint, keys: Collection[DHTID]) -> \
-            Optional[Dict[DHTID, Tuple[Optional[BinaryDHTValue], Optional[DHTExpiration], Dict[DHTID, Endpoint]]]]:
+    async def call_find(self, peer: Endpoint, keys: Collection[DHTID]) -> Optional[
+            Dict[DHTID, Tuple[Optional[ValueWithExpiration[Union[BinaryDHTValue, DictionaryDHTValue]]], Dict[DHTID, Endpoint]]]]:
         """
         Request keys from a peer. For each key, look for its (value, expiration time) locally and
          k additional peers that are most likely to have this key (ranked by XOR distance)
@@ -211,14 +214,17 @@ class DHTProtocol(dht_grpc.DHTServicer):
             output = {}  # unpack data depending on its type
             for key, result in zip(keys, response.results):
                 nearest = dict(zip(map(DHTID.from_bytes, result.nearest_node_ids), result.nearest_endpoints))
+
                 if result.type == dht_pb2.NOT_FOUND:
-                    output[key] = None, None, nearest
+                    output[key] = None, nearest
                 elif result.type == dht_pb2.FOUND_REGULAR:
-                    output[key] = result.value, result.expiration_time, nearest
+                    output[key] = ValueWithExpiration(result.value, result.expiration_time), nearest
                 elif result.type == dht_pb2.FOUND_DICTIONARY:
-                    output[key] = self.serializer.loads(result.value), result.expiration_time, nearest
+                    deserialized_dictionary = self.serializer.loads(result.value)
+                    output[key] = ValueWithExpiration(deserialized_dictionary, result.expiration_time), nearest
                 else:
                     logger.error(f"Unknown result type: {result.type}")
+
             return output
         except grpc.experimental.aio.AioRpcError as error:
             logger.warning(f"DHTProtocol failed to find at {peer}: {error.code()}")
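
Downstream code such as DHTNode.get_many_by_id consumes the new two-element result shape roughly like this (a hedged sketch; the endpoint and key are assumed):

response = await protocol.call_find(peer_endpoint, [key_id])
if response is not None:                        # None means the RPC itself failed
    maybe_item, nearest_peers = response[key_id]
    if maybe_item is not None:
        value, expiration_time = maybe_item     # ValueWithExpiration
    for node_id, endpoint in nearest_peers.items():
        pass                                    # query closer peers next
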
@@ -233,22 +239,22 @@ class DHTProtocol(dht_grpc.DHTServicer):
             asyncio.create_task(self.rpc_ping(request.peer, context))
         response = dht_pb2.FindResponse(results=[], peer=self.node_info)
         for i, key_id in enumerate(map(DHTID.from_bytes, request.keys)):
-            maybe_value, maybe_expiration_time = self.storage.get(key_id)
-            cached_value, cached_expiration_time = self.cache.get(key_id)
-            if (cached_expiration_time or -float('inf')) > (maybe_expiration_time or -float('inf')):
-                maybe_value, maybe_expiration_time = cached_value, cached_expiration_time
+            maybe_item = self.storage.get(key_id)
+            cached_item = self.cache.get(key_id)
+            if cached_item is not None and (maybe_item is None or cached_item.expiration_time > maybe_item.expiration_time):
+                maybe_item = cached_item

-            if maybe_expiration_time is None:  # value not found
+            if maybe_item is None:  # value not found
                 item = dht_pb2.FindResult(type=dht_pb2.NOT_FOUND)
-            elif isinstance(maybe_value, DictionaryDHTValue):
-                item = dht_pb2.FindResult(type=dht_pb2.FOUND_DICTIONARY, value=self.serializer.dumps(maybe_value),
-                                          expiration_time=maybe_value.latest_expiration_time)
+            elif isinstance(maybe_item.value, DictionaryDHTValue):
+                item = dht_pb2.FindResult(type=dht_pb2.FOUND_DICTIONARY, value=self.serializer.dumps(maybe_item.value),
+                                          expiration_time=maybe_item.expiration_time)
             else:  # found regular value
-                item = dht_pb2.FindResult(type=dht_pb2.FOUND_REGULAR, value=maybe_value,
-                                          expiration_time=maybe_expiration_time)
+                item = dht_pb2.FindResult(type=dht_pb2.FOUND_REGULAR, value=maybe_item.value,
+                                          expiration_time=maybe_item.expiration_time)

             for node_id, endpoint in self.routing_table.get_nearest_neighbors(
-                key_id, k=self.bucket_size, exclude=DHTID.from_bytes(request.peer.node_id)):
+                    key_id, k=self.bucket_size, exclude=DHTID.from_bytes(request.peer.node_id)):
                 item.nearest_node_ids.append(node_id.to_bytes())
                 item.nearest_endpoints.append(endpoint)
             response.results.append(item)
@@ -268,7 +274,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
             if node_id not in self.routing_table:
                 # we just met a new node, maybe we know some values that it *should* store
                 data_to_send: List[Tuple[DHTID, BinaryDHTValue, DHTExpiration]] = []
-                for key, value, expiration_time in list(self.storage.items()):
+                for key, item in list(self.storage.items()):
                     neighbors = self.routing_table.get_nearest_neighbors(key, self.num_replicas, exclude=self.node_id)
                     if neighbors:
                         nearest_distance = neighbors[0][0].xor_distance(key)
@@ -276,7 +282,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
                         new_node_should_store = node_id.xor_distance(key) < farthest_distance
                         this_node_is_responsible = self.node_id.xor_distance(key) < nearest_distance
                     if not neighbors or (new_node_should_store and this_node_is_responsible):
-                        data_to_send.append((key, value, expiration_time))
+                        data_to_send.append((key, item.value, item.expiration_time))
                 if data_to_send:
                     asyncio.create_task(self.call_store(peer_endpoint, *zip(*data_to_send), in_cache=False))


+ 1 - 1
hivemind/dht/routing.py

@@ -12,7 +12,7 @@ from typing import Tuple, Optional, List, Dict, Set, Union, Any, Sequence

 from hivemind.utils import Endpoint, PickleSerializer

-DHTKey, Subkey, DHTValue, DHTExpiration, BinaryDHTID, BinaryDHTValue, = Any, str, Any, float, bytes, bytes
+DHTKey, Subkey, DHTValue, DHTExpiration, BinaryDHTID, BinaryDHTValue, = Any, Any, Any, float, bytes, bytes
 get_dht_time = time.time  # time used by all dht functionality. You can replace this with any infrastructure-wide time


+ 42 - 31
hivemind/dht/storage.py

@@ -1,13 +1,24 @@
 from __future__ import annotations
 import heapq
 from contextlib import contextmanager
-from typing import Generic, Optional, Dict, Tuple, List, Iterator, TypeVar, Union, Any
+from typing import Generic, Optional, Dict, Tuple, List, Iterator, TypeVar, Union, NamedTuple

 from hivemind.dht.routing import DHTID, DHTExpiration, get_dht_time, BinaryDHTValue, Subkey
 from hivemind.utils.serializer import MSGPackSerializer

 KeyType = TypeVar('KeyType')
 ValueType = TypeVar('ValueType')
+ROOT = 0
+
+
+class ValueWithExpiration(NamedTuple, Generic[ValueType]):
+    value: ValueType
+    expiration_time: DHTExpiration
+
+
+class HeapEntry(NamedTuple, Generic[KeyType]):
+    expiration_time: DHTExpiration
+    key: KeyType


 class TimedStorage(Generic[KeyType, ValueType]):
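
Both helper classes are drop-in replacements for the tuples they supersede: NamedTuple instances compare, unpack, index, and hash exactly like plain tuples, and putting expiration_time first in HeapEntry keeps heapq ordered by expiration. A standalone illustration (plain NamedTuple, without the Generic parametrization used above):

import heapq
from typing import NamedTuple

class ValueWithExpiration(NamedTuple):
    value: bytes
    expiration_time: float

item = ValueWithExpiration(b'some-value', expiration_time=1234.5)
assert item == (b'some-value', 1234.5)        # equal to the old plain tuple
assert item.expiration_time == item[1]        # named access mirrors indexing

class HeapEntry(NamedTuple):
    expiration_time: float
    key: str

heap = [HeapEntry(9.0, 'b'), HeapEntry(5.0, 'a')]
heapq.heapify(heap)            # expiration_time first => min-heap by expiration
assert heap[0].key == 'a'
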
@@ -16,17 +27,16 @@ class TimedStorage(Generic[KeyType, ValueType]):

     def __init__(self, maxsize: Optional[int] = None):
         self.maxsize = maxsize or float("inf")
-        self.data: Dict[KeyType, Tuple[ValueType, DHTExpiration]] = dict()
-        self.expiration_heap: List[Tuple[DHTExpiration, KeyType]] = []
-        self.key_to_heap: Dict[KeyType, Tuple[DHTExpiration, KeyType]] = dict()
+        self.data: Dict[KeyType, ValueWithExpiration[ValueType]] = dict()
+        self.expiration_heap: List[HeapEntry[KeyType]] = []
+        self.key_to_heap: Dict[KeyType, HeapEntry[KeyType]] = dict()

     def _remove_outdated(self):
-        while not self.frozen and self.expiration_heap and (self.expiration_heap[0][0] < get_dht_time()
+        while not self.frozen and self.expiration_heap and (self.expiration_heap[ROOT].expiration_time < get_dht_time()
                                                             or len(self.expiration_heap) > self.maxsize):
             heap_entry = heapq.heappop(self.expiration_heap)
-            key = heap_entry[1]
-            if self.key_to_heap.get(key) == heap_entry:
-                del self.data[key], self.key_to_heap[key]
+            if self.key_to_heap.get(heap_entry.key) == heap_entry:
+                del self.data[heap_entry.key], self.key_to_heap[heap_entry.key]

     def store(self, key: KeyType, value: ValueType, expiration_time: DHTExpiration) -> bool:
         """
@@ -35,39 +45,39 @@
         """
         if expiration_time < get_dht_time() and not self.frozen:
             return False
-        self.key_to_heap[key] = (expiration_time, key)
-        heapq.heappush(self.expiration_heap, (expiration_time, key))
+        self.key_to_heap[key] = HeapEntry(expiration_time, key)
+        heapq.heappush(self.expiration_heap, self.key_to_heap[key])
         if key in self.data:
-            if self.data[key][1] < expiration_time:
-                self.data[key] = (value, expiration_time)
+            if self.data[key].expiration_time < expiration_time:
+                self.data[key] = ValueWithExpiration(value, expiration_time)
                 return True
             return False
-        self.data[key] = (value, expiration_time)
+        self.data[key] = ValueWithExpiration(value, expiration_time)
         self._remove_outdated()
         return True

-    def get(self, key: KeyType) -> (Optional[ValueType], Optional[DHTExpiration]):
+    def get(self, key: KeyType) -> Optional[ValueWithExpiration[ValueType]]:
         """ Get a value corresponding to a key if that (key, value) pair was previously stored under this key. """
         self._remove_outdated()
         if key in self.data:
             return self.data[key]
-        return None, None
+        return None

-    def items(self) -> Iterator[Tuple[KeyType, ValueType, DHTExpiration]]:
+    def items(self) -> Iterator[Tuple[KeyType, ValueWithExpiration[ValueType]]]:
         """ Iterate over (key, value, expiration_time) tuples stored in this storage """
         self._remove_outdated()
-        return ((key, value, expiration_time) for key, (value, expiration_time) in self.data.items())
+        return ((key, value_and_expiration) for key, value_and_expiration in self.data.items())

-    def top(self) -> Optional[Tuple[KeyType, ValueType, DHTExpiration]]:
+    def top(self) -> Tuple[Optional[KeyType], Optional[ValueWithExpiration[ValueType]]]:
         """ Return the entry with earliest expiration or None if there isn't any """
         self._remove_outdated()
         if self.data:
-            top_entry, top_key = self.expiration_heap[0], self.expiration_heap[0][1]
-            while self.key_to_heap.get(top_key) != top_entry:
-                heapq.heappop(self.expiration_heap)  # skip leftover "ghost" entries until first real entry
-                top_entry, top_key = self.expiration_heap[0], self.expiration_heap[0][1]
-            value, expiration = self.data[top_key]
-            return top_key, value, expiration
+            # skip leftover "ghost" entries until first real entry
+            while self.key_to_heap.get(self.expiration_heap[ROOT].key) != self.expiration_heap[ROOT]:
+                heapq.heappop(self.expiration_heap)
+            top_key = self.expiration_heap[ROOT].key
+            return top_key, self.data[top_key]
+        return None, None

     def __contains__(self, key: KeyType):
         self._remove_outdated()
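
The rewritten top() relies on the storage's lazy-deletion scheme: re-storing a key pushes a fresh HeapEntry and re-points key_to_heap at it, while the stale entry lingers in the heap as a "ghost" until it reaches the root and is skipped. A micro-demo of the same pattern with plain tuples:

import heapq

heap, key_to_heap = [], {}
for expiration_time, key in [(5.0, 'k'), (9.0, 'k')]:   # 'k' is stored twice
    entry = (expiration_time, key)
    key_to_heap[key] = entry                             # only the latest entry is live
    heapq.heappush(heap, entry)

while heap and key_to_heap.get(heap[0][1]) != heap[0]:
    heapq.heappop(heap)                                  # discard the ghost (5.0, 'k')
assert heap[0] == (9.0, 'k')
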
@@ -109,7 +119,8 @@ class DictionaryDHTValue(TimedStorage[Subkey, BinaryDHTValue]):

     def packb(self) -> bytes:
         """ custom behavior for MSGPackSerializer.dumps """
-        return MSGPackSerializer.dumps([self.maxsize, self.latest_expiration_time, list(map(list, self.items()))])
+        packed_items = [[key, value, expiration_time] for key, (value, expiration_time) in self.items()]
+        return MSGPackSerializer.dumps([self.maxsize, self.latest_expiration_time, packed_items])

     @classmethod
     def unpackb(cls, raw: bytes) -> DictionaryDHTValue:
@@ -143,15 +154,15 @@ class DHTLocalStorage(TimedStorage[DHTID, Union[BinaryDHTValue, DictionaryDHTVal
          3) if self[key] is a normal value with smaller expiration time, overwrite it with a dictionary and add sub-key
         :returns: True if new entry was stored, False it was rejected (current value is newer)
         """
-        previous_value, previous_expiration_time = self.get(key)
-        if isinstance(previous_value, DictionaryDHTValue):  # already a dictionary, just add new subkey
-            if expiration_time > previous_value.latest_expiration_time:
-                super().store(key, previous_value, expiration_time)  # refresh expiration time
-            return previous_value.store(subkey, value, expiration_time)
-        elif expiration_time > (previous_expiration_time or float('-inf')):  # create new dictionary, add subkey
+        previous_value, previous_expiration_time = self.get(key) or (b'', -float('inf'))
+        if isinstance(previous_value, BinaryDHTValue) and expiration_time > previous_expiration_time:
             new_storage = DictionaryDHTValue()
             new_storage.store(subkey, value, expiration_time)
             return super().store(key, new_storage, new_storage.latest_expiration_time)
+        elif isinstance(previous_value, DictionaryDHTValue):
+            if expiration_time > previous_value.latest_expiration_time:
+                super().store(key, previous_value, expiration_time)  # refresh expiration time
+            return previous_value.store(subkey, value, expiration_time)
         else:
             return False
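
The reordered branches preserve the documented upgrade path: a subkey stored over a missing or older regular value creates a fresh DictionaryDHTValue. A hedged end-to-end sketch (assumed API, mirroring the docstring above):

from hivemind.dht.routing import DHTID, get_dht_time
from hivemind.dht.storage import DHTLocalStorage, DictionaryDHTValue

storage = DHTLocalStorage()
key = DHTID.generate()
assert storage.store_subkey(key, 'alpha', b'v1', get_dht_time() + 60)
item = storage.get(key)              # ValueWithExpiration around a dictionary
assert isinstance(item.value, DictionaryDHTValue)
assert item.value.get('alpha').value == b'v1'
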

+ 2 - 2
hivemind/proto/dht.proto

@@ -24,8 +24,8 @@ message NodeInfo {
 message StoreRequest {
   // three lists of the same length representing dht keys, dht values and expiration
   repeated bytes keys = 1;             // keys in the form of DHTID.generate(raw_key).to_bytes()
-  repeated string subkeys = 2;         // [optional] subkeys for DictionaryDHTValue type, empty string means no subkey
-  repeated bytes values = 3;           // binary-encoded value for i-th key
+  repeated bytes subkeys = 2;          // serialized subkeys for DictionaryDHTValue type. None means no subkey
+  repeated bytes values = 3;           // serialized value for i-th key
   repeated double expiration_time = 4; // expirations for i-th key (type = DHTExpiration)
   repeated bool in_cache = 5;          // if in_cache[i], store i-th key in cache, else store normally
   NodeInfo peer = 6;                   // (optional) sender's own node info, same behavior as in DHT.rpc_ping

+ 41 - 7
tests/test_dht_experts.py

@@ -1,6 +1,7 @@
 import random
 import uuid
 from itertools import chain
+import numpy as np

 import hivemind
 from hivemind import LOCALHOST
@@ -44,6 +45,38 @@ def test_store_get_experts():
         peer.shutdown()


+def test_beam_search(dht_size=20, total_experts=128, batch_size=32, initial_peers=3, beam_size=4, parallel_rpc=256,
+                     grid_dims=(32, 32, 32)):
+    dht = []
+    for i in range(dht_size):
+        neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(dht, min(initial_peers, len(dht)))]
+        dht.append(hivemind.DHT(start=True, expiration=999999, initial_peers=neighbors_i, parallel_rpc=parallel_rpc))
+
+    real_experts = sorted({
+        'expert.' + '.'.join([str(random.randint(0, dim - 1)) for dim in grid_dims])
+        for _ in range(total_experts)
+    })
+    for batch_start in range(0, len(real_experts), batch_size):
+        random.choice(dht).declare_experts(
+            real_experts[batch_start: batch_start + batch_size], wait=True,
+            endpoint=f"host{batch_start // batch_size}:{random.randint(0, 65536)}")
+
+    neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(dht, min(initial_peers, len(dht)))]
+    you = hivemind.DHT(start=True, expiration=999999, initial_peers=neighbors_i, parallel_rpc=parallel_rpc)
+
+    for i in range(50):
+        topk_experts = you.find_best_experts('expert', [np.random.randn(dim) for dim in grid_dims], beam_size=beam_size)
+        assert all(isinstance(e, hivemind.RemoteExpert) for e in topk_experts)
+        assert len(topk_experts) == beam_size
+
+    for i in range(10):
+        batch_experts = you.batch_find_best_experts('expert', [np.random.randn(batch_size, dim) for dim in grid_dims],
+                                                    beam_size=beam_size)
+        assert isinstance(batch_experts, list) and len(batch_experts) == batch_size
+        assert all(isinstance(e, hivemind.RemoteExpert) for experts in batch_experts for e in experts)
+        assert all(len(experts) == beam_size for experts in batch_experts)
+
+
 def test_first_k_active():
     node = hivemind.DHT(start=True)
     assert all(node.declare_experts(['e.1.2.3', 'e.1.2.4', 'e.3.4.5'], endpoint=f"{hivemind.LOCALHOST}:1337"))
@@ -63,14 +96,15 @@ def test_first_k_active():

 def test_dht_single_node():
     node = hivemind.DHT(start=True)
-    assert node.first_k_active(['e3', 'e2'], k=3) == {}
-    assert node.get_experts(['e3', 'e2']) == [None, None]
+    assert node.first_k_active(['e.3', 'e.2'], k=3) == {}
+    assert node.get_experts(['e.3', 'e.2']) == [None, None]
-    assert all(node.declare_experts(['e1', 'e2', 'e3'], f"{hivemind.LOCALHOST}:1337"))
-    for expert in node.get_experts(['e3', 'e2']):
+    assert all(node.declare_experts(['e.1', 'e.2', 'e.3'], f"{hivemind.LOCALHOST}:1337"))
+    for expert in node.get_experts(['e.3', 'e.2']):
         assert expert.endpoint == f"{hivemind.LOCALHOST}:1337"
-    active_found = node.first_k_active(['e0', 'e1', 'e3', 'e5', 'e2'], k=2)
-    assert list(active_found.keys()) == ['e1', 'e3']
+    active_found = node.first_k_active(['e.0', 'e.1', 'e.3', 'e.5', 'e.2'], k=2)
+    assert list(active_found.keys()) == ['e.1', 'e.3']
     assert all(expert.uid.startswith(prefix) for prefix, expert in active_found.items())

-    assert all(node.declare_experts(['e1', 'e2', 'e3'], f"{hivemind.LOCALHOST}:1337"))
+    assert all(node.declare_experts(['e.1', 'e.2', 'e.3'], f"{hivemind.LOCALHOST}:1337"))
+    assert node.find_best_experts('e', [(0., 1., 2., 3., 4., 5., 6., 7., 8.)], beam_size=4)

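For context, test_beam_search above relies on the grid uid convention used by the DHT beam search: a uid is a prefix plus one index per grid dimension, and a candidate's score is the sum of its per-dimension scores. A toy illustration of that scoring (names here are illustrative, not library API):

import numpy as np

grid_dims = (32, 32, 32)
grid_scores = [np.random.randn(dim) for dim in grid_dims]

# the greedy top-1 candidate; beam search instead keeps the best k prefixes per dimension
uid = 'expert.' + '.'.join(str(int(np.argmax(s))) for s in grid_scores)
score = sum(float(s.max()) for s in grid_scores)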
+ 9 - 9
tests/test_dht_node.py

@@ -62,7 +62,7 @@ def test_dht_protocol():
             assert all(store_ok), "DHT rejected a trivial store"

             # peer 1 must know about peer 2
-            recv_value_bytes, recv_expiration, nodes_found = loop.run_until_complete(
+            (recv_value_bytes, recv_expiration), nodes_found = loop.run_until_complete(
                 protocol.call_find(f'{LOCALHOST}:{peer1_port}', [key]))[key]
             recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
             (recv_id, recv_endpoint) = next(iter(nodes_found.items()))
@@ -75,9 +75,9 @@ def test_dht_protocol():

             # peer 2 must know about peer 1, but not have a *random* nonexistent value
             dummy_key = DHTID.generate()
-            recv_dummy_value, recv_dummy_expiration, nodes_found_2 = loop.run_until_complete(
+            empty_item, nodes_found_2 = loop.run_until_complete(
                 protocol.call_find(f'{LOCALHOST}:{peer2_port}', [dummy_key]))[dummy_key]
-            assert recv_dummy_value is None and recv_dummy_expiration is None, "Non-existent keys shouldn't have values"
+            assert empty_item is None, "Non-existent keys shouldn't have values"
             (recv_id, recv_endpoint) = next(iter(nodes_found_2.items()))
             assert recv_id == peer1_id and recv_endpoint == f"{LOCALHOST}:{peer1_port}", \
                 f"expected id={peer1_id}, peer={LOCALHOST}:{peer1_port} but got {recv_id}, {recv_endpoint}"
@@ -97,7 +97,7 @@ def test_dht_protocol():
                 f'{LOCALHOST}:{peer1_port}', keys=[nested_key], values=[hivemind.MSGPackSerializer.dumps(value2)],
                 expiration_time=[expiration + 5], subkeys=[subkey2])
             )
-            recv_dict, recv_expiration, nodes_found = loop.run_until_complete(
+            (recv_dict, recv_expiration), nodes_found = loop.run_until_complete(
                 protocol.call_find(f'{LOCALHOST}:{peer1_port}', [nested_key]))[nested_key]
             assert isinstance(recv_dict, DictionaryDHTValue)
             assert len(recv_dict.data) == 2 and recv_expiration == expiration + 5
@@ -134,14 +134,14 @@ def test_empty_table():

         key, value, expiration = DHTID.generate(), [random.random(), {'ololo': 'pyshpysh'}], get_dht_time() + 1e3

-        recv_value_bytes, recv_expiration, nodes_found = loop.run_until_complete(
+        empty_item, nodes_found = loop.run_until_complete(
             protocol.call_find(f'{LOCALHOST}:{peer_port}', [key]))[key]
-        assert recv_value_bytes is None and recv_expiration is None and len(nodes_found) == 0
+        assert empty_item is None and len(nodes_found) == 0
         assert all(loop.run_until_complete(protocol.call_store(
             f'{LOCALHOST}:{peer_port}', [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
         )), "peer rejected store"

-        recv_value_bytes, recv_expiration, nodes_found = loop.run_until_complete(
+        (recv_value_bytes, recv_expiration), nodes_found = loop.run_until_complete(
             protocol.call_find(f'{LOCALHOST}:{peer_port}', [key]))[key]
         recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
         assert len(nodes_found) == 0
@@ -266,7 +266,7 @@ def test_dht_node():
             assert val == ["Value", 10], "Wrong value"
             assert expiration_time == true_time, f"Wrong time"

-        assert loop.run_until_complete(detached_node.get("mykey")) == (None, None)
+        assert loop.run_until_complete(detached_node.get("mykey")) is None

         # test 7: bulk store and bulk get
         keys = 'foo', 'bar', 'baz', 'zzz'
@@ -429,7 +429,7 @@ def test_dhtnode_reuse_get():

         assert (await futures1['k1'])[0] == 123
         assert await futures1['k2'] == await futures2['k2'] and (await futures1['k2'])[0] == 567
-        assert await futures2['k3'] == await futures3['k3'] and (await futures3['k3']) == (None, None)
+        assert await futures2['k3'] == await futures3['k3'] and (await futures3['k3']) is None
         test_success.set()

     proc = mp.Process(target=lambda: asyncio.run(_tester()))

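The recurring pattern in the test changes above: call_find and get no longer return flat (value, expiration, ...) tuples. A found key yields a (value, expiration_time) pair plus the nearest nodes, while a missing key yields None in place of the pair. A standalone sketch of the new unpacking; the namedtuple name and field names here are assumptions for illustration, not necessarily the commit's own identifiers:

from collections import namedtuple

# illustrative stand-in for the commit's value-with-expiration namedtuple
ValueWithExpiration = namedtuple('ValueWithExpiration', ['value', 'expiration_time'])

def describe(maybe_item):
    if maybe_item is None:                 # was: value is None and expiration is None
        return 'not found'
    value, expiration_time = maybe_item    # namedtuples unpack like plain tuples
    return f'{value!r} expires at {expiration_time}'

assert describe(None) == 'not found'
assert describe(ValueWithExpiration(b'data', 1234.5)) == "b'data' expires at 1234.5"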
+ 8 - 8
tests/test_dht_storage.py

@@ -16,13 +16,13 @@ def test_get_expired():
     d = DHTLocalStorage()
     d.store(DHTID.generate("key"), b"val", get_dht_time() + 0.1)
     time.sleep(0.5)
-    assert d.get(DHTID.generate("key")) == (None, None), "Expired value must be deleted"
+    assert d.get(DHTID.generate("key")) is None, "Expired value must be deleted"
     print("Test get expired passed")


 def test_get_empty():
     d = DHTLocalStorage()
-    assert d.get(DHTID.generate(source="key")) == (None, None), "DHTLocalStorage returned non-existent value"
+    assert d.get(DHTID.generate(source="key")) is None, "DHTLocalStorage returned non-existent value"
     print("Test get empty passed")


@@ -41,7 +41,7 @@ def test_maxsize_cache():
     d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 1)
     d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 1)
     d.store(DHTID.generate("key2"), b"val2", get_dht_time() + 200)
     d.store(DHTID.generate("key2"), b"val2", get_dht_time() + 200)
     assert d.get(DHTID.generate("key2"))[0] == b"val2", "Value with bigger exp. time must be kept"
     assert d.get(DHTID.generate("key2"))[0] == b"val2", "Value with bigger exp. time must be kept"
-    assert d.get(DHTID.generate("key1"))[0] is None, "Value with less exp time, must be deleted"
+    assert d.get(DHTID.generate("key1")) is None, "Value with less exp time, must be deleted"


 def test_localstorage_top():
@@ -49,17 +49,17 @@ def test_localstorage_top():
     d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 1)
     d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 1)
     d.store(DHTID.generate("key2"), b"val2", get_dht_time() + 2)
     d.store(DHTID.generate("key2"), b"val2", get_dht_time() + 2)
     d.store(DHTID.generate("key3"), b"val3", get_dht_time() + 4)
     d.store(DHTID.generate("key3"), b"val3", get_dht_time() + 4)
-    assert d.top()[:2] == (DHTID.generate("key1"), b"val1")
+    assert d.top()[0] == DHTID.generate("key1") and d.top()[1].value == b"val1"

     d.store(DHTID.generate("key1"), b"val1_new", get_dht_time() + 3)
-    assert d.top()[:2] == (DHTID.generate("key2"), b"val2")
+    assert d.top()[0] == DHTID.generate("key2") and d.top()[1].value == b"val2"

     del d[DHTID.generate('key2')]
-    assert d.top()[:2] == (DHTID.generate("key1"), b"val1_new")
+    assert d.top()[0] == DHTID.generate("key1") and d.top()[1].value == b"val1_new"
     d.store(DHTID.generate("key2"), b"val2_new", get_dht_time() + 5)
     d.store(DHTID.generate("key2"), b"val2_new", get_dht_time() + 5)
     d.store(DHTID.generate("key4"), b"val4", get_dht_time() + 6)  # key4 will push out key1 due to maxsize
     d.store(DHTID.generate("key4"), b"val4", get_dht_time() + 6)  # key4 will push out key1 due to maxsize
 
 
-    assert d.top()[:2] == (DHTID.generate("key3"), b"val3")
+    assert d.top()[0] == DHTID.generate("key3") and d.top()[1].value == b"val3"


 def test_localstorage_nested():
@@ -71,7 +71,7 @@ def test_localstorage_nested():
     d2.store('subkey3', b'value3', time + 1)

     assert d2.latest_expiration_time == time + 3
-    for subkey, subvalue, subexpiration in d2.items():
+    for subkey, (subvalue, subexpiration) in d2.items():
     assert d1.store_subkey(DHTID.generate('foo'), subkey, subvalue, subexpiration)
     assert d1.store(DHTID.generate('bar'), b'456', time + 2)
     assert d1.get(DHTID.generate('foo'))[0].data == d2.data
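
The last hunk shows the new iteration shape: DictionaryDHTValue.items() now yields (subkey, (value, expiration_time)) pairs rather than flat triples. A small sketch of that shape; the import paths are assumed from this commit's layout:

from hivemind.dht.storage import DictionaryDHTValue  # path assumed
from hivemind.utils import get_dht_time              # path assumed

d2, now = DictionaryDHTValue(), get_dht_time()
d2.store('subkey1', b'value1', now + 2)
d2.store('subkey2', b'value2', now + 3)

assert d2.latest_expiration_time == now + 3
for subkey, (subvalue, subexpiration) in d2.items():
    assert isinstance(subvalue, bytes) and subexpiration > now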