
Faster beam search through DHT sub-keys (#107)

* Added support for dictionary-like DHT keys with sub-keys (see the sketch below)
* New beam search based on the new dictionary-style prefixes
* DHTNode now uses a more careful caching policy on store: if a value was rejected, the node requests the latest value to update its cache
* Added tests for dictionary value types (storage, protocol, node)
* LocalStorage was moved to a separate file and generalized for the new value types
* Fixed a minor bug where DHTProtocol.store claimed to return None but didn't
* More tests
justheuristic committed 4 years ago · commit 9f9c4ac96b
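
To make the new sub-key layout concrete, here is a small, self-contained Python sketch of what declare_experts now stores for a single expert, mirroring _declare_experts in the diff below. The uid and endpoint are made-up example values, not part of this commit.

from typing import List, Tuple

UID_DELIMITER = '.'  # same delimiter DHT uses to split expert uids into grid coordinates

def dht_entries_for_expert(uid: str, endpoint: str) -> List[Tuple[str, str, list]]:
    """Return the (key, subkey, value) triples that declare_experts would store for one expert."""
    entries = [(uid, "expert", [uid, endpoint])]  # the full uid itself, under a dedicated "expert" subkey
    parts = uid.split(UID_DELIMITER)
    for i in range(len(parts) - 1):  # every proper prefix of the uid...
        prefix = UID_DELIMITER.join(parts[:i + 1])
        entries.append((prefix, parts[i + 1], [uid, endpoint]))  # ...keyed by the next grid index as subkey
    return entries

print(dht_entries_for_expert("ffn_expert.3.7.2", "127.0.0.1:1337"))
# [('ffn_expert.3.7.2', 'expert', ['ffn_expert.3.7.2', '127.0.0.1:1337']),
#  ('ffn_expert', '3', [...]), ('ffn_expert.3', '7', [...]), ('ffn_expert.3.7', '2', [...])]

Because every expert that shares a prefix lands under the same DHT key, beam search can fetch all active continuations of a prefix with a single get request instead of probing candidate uids one by one.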

+ 5 - 0
docs/modules/dht.rst

@@ -10,6 +10,11 @@ Here's a high level scheme of how these components interact with one another:
    :width: 640
    :align: center

+
+**Note:** hivemind.DHT is currently being updated to improve beam search latency
+(see `issue 92 <https://github.com/learning-at-home/hivemind/issues>`__). New functionality will be documented
+here by 2020.10.15 23:59:59 AOE (ping justheuristic for details).
+
 DHT and DHTNode
 ###############


+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.utils import *

-__version__ = '0.8.4'
+__version__ = '0.8.5'

+ 146 - 13
hivemind/dht/__init__.py

@@ -14,11 +14,12 @@ The code is organized as follows:
 """
 import asyncio
 import ctypes
+import heapq
 import multiprocessing as mp
 import warnings
 from collections import deque, OrderedDict
 from concurrent.futures import ThreadPoolExecutor
-from typing import List, Tuple, Optional, Sequence, OrderedDict as TOrderedDict, Union, Awaitable
+from typing import List, Tuple, Optional, Sequence, OrderedDict as TOrderedDict, Union, Awaitable, Dict, Deque, Set
 
 import uvloop
 
@@ -67,7 +68,6 @@ class DHT(mp.Process):
     This beam search explores one additional dimension per step and finds k best experts from across the DHT
     in O(k / s * log(N)) average time where s is grid sparsity rate and N is the total number of experts.
     """
-
     UID_DELIMITER = '.'  # when declaring experts, DHT store all prefixes of that expert's uid, split over this prefix
     #  formally, prefixes = {uid.split(UID_DELIMITER)[:length] for length in range(1, uid.count(UID_DELIMITER) + 2)}
 
@@ -147,7 +147,9 @@ class DHT(mp.Process):
             expiration_time = get_dht_time()
         num_workers = len(uids) if self.max_workers is None else min(len(uids), self.max_workers)
         response = await node.get_many(uids, expiration_time, num_workers=num_workers)
-        future.set_result([RemoteExpert(**expert_data) if maybe_expiration_time else None
+        # TODO expert_data['expert'] -> namedtuple with meaningful field names
+        future.set_result([RemoteExpert(*expert_data['expert'][0])
+                           if maybe_expiration_time else None and expert_data['expert'][1] is not None
                            for uid, (expert_data, maybe_expiration_time) in response.items()])
 
     def declare_experts(self, uids: List[str], endpoint: Endpoint, wait=True, timeout=None) -> Optional[List[bool]]:
@@ -169,18 +171,148 @@ class DHT(mp.Process):
     async def _declare_experts(self, node: DHTNode, uids: List[str], endpoint: Endpoint, future: Optional[MPFuture]):
         num_workers = len(uids) if self.max_workers is None else min(len(uids), self.max_workers)
         expiration_time = get_dht_time() + self.expiration
-
-        data_to_store = {}
-        for uid in uids:
+        unique_entries: Set[Tuple[str, str]] = set()
+        #                 prefix---v next_dim     uid  endpoint
+        data_to_store: List[Tuple[str, str, List[str, Endpoint]]] = []
+        for uid in uids:  # first k entries are expert uids themselves
+            data_to_store.append((uid, "expert", [uid, endpoint]))
+        for uid in uids:  # and then, add all prefixes
             uid_parts = uid.split(self.UID_DELIMITER)
-            for i in range(len(uid_parts)):
+            for i in range(len(uid_parts) - 1):
                 uid_prefix_i = self.UID_DELIMITER.join(uid_parts[:i + 1])
-                data_to_store[uid_prefix_i] = {'uid': uid, 'endpoint': endpoint}
+                if (uid_prefix_i, uid_parts[i + 1]) in unique_entries:
+                    continue
+                unique_entries.add((uid_prefix_i, uid_parts[i + 1]))
+                data_to_store.append((uid_prefix_i, uid_parts[i + 1], [uid, endpoint]))
+
+        keys, subkeys, values = map(list, zip(*data_to_store))
+        store_ok = await node.store_many(keys, values, expiration_time, subkeys=subkeys, num_workers=num_workers)
+        if future is not None:
+            future.set_result([store_ok[key, subkey] for key, subkey in zip(keys, subkeys)])
+
+    def find_best_experts(self, prefix: str, grid_scores: Sequence[Sequence[float]], beam_size: int, *,
+                          return_future=False, **kwargs) -> Union[List[RemoteExpert], MPFuture]:
+        """
+        Find and return :beam_size: active experts with highest scores, use both local cache and DHT
+
+        :param prefix: common prefix for all expert uids in grid
+        :param grid_scores: scores predicted for each dimension in the grid,
+        :type grid_scores: model scores for each grid dimension, list of arrays of shape grid_size[i]
+        :param beam_size: how many best experts should beam search return
+         After time_budget is reached, beam search won't search for more experts and instead fall back on local cache
+         Please note that any queries that fall outside the budget will still be performed in background and cached
+         for subsequent iterations as long as DHTNode.cache_locally is True
+        :param return_future: if set to True, returns MPFuture that can be awaited to get the actual result
+        :param kwargs: extra keyword parameters passed to DHTNode.get_many
+        :returns: a list that contains *up to* k_best RemoteExpert instances
+        """
+        future, _future = MPFuture.make_pair()
+        self.pipe.send(('_find_best_experts', [], dict(prefix=prefix, grid_scores=list(map(tuple, grid_scores)),
+                                                       beam_size=beam_size, future=_future, **kwargs)))
+        return future if return_future else future.result()
+
+    async def _find_best_experts(
+            self, node: DHTNode, prefix: str, grid_scores: List[Tuple[float]], beam_size: int,
+            max_workers: Optional[int] = None, future: Optional[MPFuture] = None, **kwargs) -> List[RemoteExpert]:
+        max_workers: Optional[int] = max_workers or self.max_workers or beam_size
+
+        # form initial beam from top-k active L1 prefixes, each row is (score, uid prefix, possible suffixes)
+        beam: List[Tuple[float, str, Dict[str, List[str, Endpoint]]]] = await self._get_initial_beam(
+            node, prefix, beam_size, grid_scores[0], num_workers=min(beam_size, max_workers))
+        if not beam:
+            logger.warning(f"Beam search had to terminate prematurely because of empty beam (dim 0)")
+            return []
+
+        for dim_index in range(1, len(grid_scores) - 1):
+            # select beam_size best suffixes from current beam
+            dim_scores = grid_scores[dim_index]
+            best_active_pairs: List[Tuple[float, str]] = heapq.nlargest(beam_size, (
+                (prefix_score + dim_scores[int(suffix_i)], f"{prefix}{self.UID_DELIMITER}{suffix_i}")
+                for prefix_score, prefix, suffixes in beam for suffix_i in suffixes.keys()
+                # TODO get rid of str.isdecimal
+                if str.isdecimal(suffix_i) and 0 <= int(suffix_i) < len(dim_scores)))
+
+            # search DHT for next step suffixes
+            _, best_uid_prefixes = zip(*best_active_pairs)
+            # TODO Tuple[Dict[str, List[str, Endpoint]], DHTExpiration] -> namedtuple
+            dht_responses: Dict[str, Tuple[Dict[str, List[str, Endpoint]], DHTExpiration]] = await node.get_many(
+                keys=best_uid_prefixes, num_workers=min(len(best_uid_prefixes), max_workers), **kwargs)
+            if all(expiration is None for key, (_, expiration) in dht_responses.items()):
+                logger.warning(f"Beam search had to terminate prematurely because of empty beam (dim {dim_index})")
+                break
+            beam = [(prefix_score, prefix, dht_responses[prefix][0])  # add suffix dict if it is found
+                    for prefix_score, prefix in best_active_pairs if dht_responses[prefix][1] is not None]
+
+        # select best experts from the final beam
+        dim_scores = grid_scores[-1]
+        final_best_pairs: List[Tuple[float, str, Endpoint]] = heapq.nlargest(beam_size, (
+            (prefix_score + dim_scores[int(suffix_i)], uid, endpoint)
+            for prefix_score, prefix, suffixes in beam for suffix_i, ((uid, endpoint), _) in suffixes.items()
+            if str.isdecimal(suffix_i) and 0 <= int(suffix_i) < len(dim_scores)
+        ))
+        best_experts = [RemoteExpert(uid, endpoint) for score, uid, endpoint in final_best_pairs]
+        if future is not None:
+            future.set_result(best_experts)
+        return best_experts
+
+    def batch_find_best_experts(self, prefix: str, batch_grid_scores: Sequence[Sequence[Sequence[float]]], beam_size: int, *,
+                                return_future=False, **kwargs) -> Union[List[RemoteExpert], MPFuture]:
+        """
+        Find and return :beam_size: active experts with highest scores, use both local cache and DHT
+
+        :param prefix: common prefix for all expert uids in grid
+        :param batch_grid_scores: scores predicted for each batch example and each dimension in the grid,
+        :type batch_grid_scores: model scores for each example and each grid dimension,  list of arrays of shape (batch_size, grid_size[i])
+        :param beam_size: how many best experts should beam search return
+         After time_budget is reached, beam search won't search for more experts and instead fall back on local cache
+         Please note that any queries that fall outside the budget will still be performed in background and cached
+         for subsequent iterations as long as DHTNode.cache_locally is True
+        :param return_future: if set to True, returns MPFuture that can be awaited to get the actual result
+        :param kwargs: extra keyword parameters passed to DHTNode.get_many
+        :returns: a list that contains *up to* k_best RemoteExpert instances
+        """
+        future, _future = MPFuture.make_pair()
+        self.pipe.send(('_batch_find_best_experts', [], dict(prefix=prefix, batch_grid_scores=batch_grid_scores,
+                                                             beam_size=beam_size, future=_future, **kwargs)))
+        return future if return_future else future.result()
+
+    async def _batch_find_best_experts(
+            self, node: DHTNode, prefix: str, batch_grid_scores: Sequence[Sequence[Tuple[float]]], beam_size: int,
+            max_workers: Optional[int] = None, future: Optional[MPFuture] = None, **kwargs) -> List[List[RemoteExpert]]:
+
+        batch_grid_scores = [[tuple(grid_score[i]) for grid_score in batch_grid_scores] for i in range(len(batch_grid_scores[0]))]
+        coros = [self._find_best_experts(node, prefix, grid_scores, beam_size, max_workers, **kwargs) for grid_scores in batch_grid_scores]
 
-        store_keys, store_values = zip(*data_to_store.items())
-        store_ok = await node.store_many(store_keys, store_values, expiration_time, num_workers=num_workers)
+        best_experts_batch = await asyncio.gather(*coros)
         if future is not None:
-            future.set_result([store_ok[key] for key in data_to_store.keys()])
+            future.set_result(best_experts_batch)
+        return best_experts_batch
+
+    async def _get_initial_beam(self, node, prefix: str, beam_size: int, scores: Tuple[float, ...], num_workers: int
+                                ) -> List[Tuple[float, str, Dict[str, List[str]]]]:
+        """ Fetch a list of all active level-one prefixes of a given prefix. Used for beam search """
+        beam: List[Tuple[float, str, Dict[str, List[str, Endpoint]]]] = []  # results will be stored here
+        unattempted_indices: List[int] = sorted(range(len(scores)), key=scores.__getitem__)  # order: worst to best
+        pending_tasks: Deque[Tuple[int, str, asyncio.Task]] = deque()  # up to num_workers concurrent get tasks
+
+        while len(beam) < beam_size and (unattempted_indices or pending_tasks):
+            # dispatch additional tasks
+            while unattempted_indices and len(pending_tasks) < num_workers:
+                next_index = unattempted_indices.pop()  # note: this is best unattempted index because of sort order
+                next_best_prefix = f"{prefix}{self.UID_DELIMITER}{next_index}"
+                pending_tasks.append((next_index, next_best_prefix, asyncio.create_task(node.get(next_best_prefix))))
+
+            # await the next best prefix to be fetched
+            pending_best_index, pending_best_prefix, pending_task = pending_tasks.popleft()
+            try:
+                maybe_prefix_data, maybe_expiration_time = await pending_task
+                if maybe_expiration_time is not None:
+                    beam.append((scores[pending_best_index], pending_best_prefix, maybe_prefix_data))
+            except asyncio.CancelledError:
+                for _, pending_task in pending_tasks:
+                    pending_task.cancel()
+                raise
+        return beam
 
     def first_k_active(
             self, uid_prefixes: List[str], k: int, max_prefetch: int = 1, chunk_size: Optional[int] = None,
@@ -196,6 +328,7 @@ class DHT(mp.Process):
         :returns: a ordered dict{uid_prefix -> RemoteExpert} mapping at most :k: prefixes to matching experts
             The keys in the returned dict are ordered same as in uid_prefixes.
         """
+        logger.warning("first_k_active is deprecated and will be removed in 0.8.6")
         assert not isinstance(uid_prefixes, str), "please provide a list/tuple of prefixes as the first argument"
         future, _future = MPFuture.make_pair()
         self.pipe.send(('_first_k_active', [],
@@ -220,8 +353,8 @@ class DHT(mp.Process):
             response = await pending_tasks.popleft()
             for uid_prefix in uid_prefixes[chunk_i * chunk_size: (chunk_i + 1) * chunk_size]:
                 maybe_expert_data, maybe_expiration_time = response[uid_prefix]
-                if maybe_expiration_time is not None:  # found active peer
-                    found.append((uid_prefix, RemoteExpert(**maybe_expert_data)))
+                if maybe_expiration_time is not None and len(maybe_expert_data) > 0:  # found active peer
+                    found.append((uid_prefix, RemoteExpert(*next(iter(maybe_expert_data.values()))[0])))
                     # if we found enough active experts, finish immediately
                     if len(found) >= k:
                         break
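
For orientation, here is a hypothetical call to the new beam search entry point added above. The grid sizes, the scores, and the way the DHT handle is obtained are illustrative assumptions, not part of the commit.

import hivemind

# assumes a DHT process started elsewhere in this script and experts declared via declare_experts
dht = hivemind.DHT(start=True)

grid_scores = [
    [0.1, 2.5, -0.3],        # dimension 0: three options
    [1.0, 0.0, 0.7, -1.2],   # dimension 1: four options
    [0.2, 0.9],              # dimension 2: two options
]

# returns up to beam_size RemoteExpert instances with uids such as "ffn_expert.1.0.1",
# ranked by the sum of per-dimension scores; only experts actually present in the DHT are returned
best_experts = dht.find_best_experts("ffn_expert", grid_scores, beam_size=4)

Roughly speaking, each search step issues at most beam_size get requests for the current prefixes, so the number of DHT round trips scales with the number of grid dimensions rather than with the number of candidate uids.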

+ 147 - 110
hivemind/dht/node.py

@@ -1,17 +1,18 @@
 from __future__ import annotations
 
 import asyncio
-
 import random
 from collections import defaultdict
 from dataclasses import dataclass, field
-from typing import Optional, Tuple, List, Dict, DefaultDict, Collection, Union, Set, Awaitable, Callable, Any, Iterable
-from sortedcontainers import SortedList
 from functools import partial
+from typing import Optional, Tuple, List, Dict, DefaultDict, Collection, Union, Set, Awaitable, Callable, Any
 from warnings import warn
 
-from hivemind.dht.protocol import DHTProtocol, LocalStorage
-from hivemind.dht.routing import DHTID, DHTExpiration, DHTKey, get_dht_time, DHTValue, BinaryDHTValue
+from sortedcontainers import SortedList
+
+from hivemind.dht.protocol import DHTProtocol
+from hivemind.dht.routing import DHTID, DHTExpiration, DHTKey, get_dht_time, DHTValue, BinaryDHTValue, Subkey
+from hivemind.dht.storage import CacheRefreshQueue, DictionaryDHTValue
 from hivemind.dht.traverse import traverse_dht
 from hivemind.utils import Endpoint, LOCALHOST, MSGPackSerializer, get_logger, SerializerBase
 
@@ -39,7 +40,7 @@ class DHTNode:
         nearest peers from recipient's routing table (ordered nearest-to-farthest, not including recipient itself)
         This RPC is a mixture between Kademlia FIND_NODE and FIND_VALUE with multiple keys per call.
 
-    Formally, DHTNode follows the following contract:
+    A DHTNode follows the following contract:
 
     - when asked to get(key), a node must find and return a value with highest expiration time that it found across DHT
       IF that time has not come yet. if expiration time is smaller than current get_dht_time(), node may return None;
@@ -53,9 +54,8 @@ class DHTNode:
     # fmt:off
     node_id: DHTID; is_alive: bool; port: int; num_replicas: int; num_workers: int; protocol: DHTProtocol
     refresh_timeout: float; cache_locally: bool; cache_nearest: int; cache_refresh_before_expiry: float
-    cache_refresh_available: asyncio.Event; cache_refresh_queue: LocalStorage
-    reuse_get_requests: bool; pending_get_requests: DefaultDict[DHTID, SortedList[_IntermediateResult]]
-    serializer = MSGPackSerializer  # used to pack/unpack DHT Values for transfer over network
+    cache_on_store: bool; reuse_get_requests: bool; pending_get_requests: DefaultDict[DHTID, SortedList[_SearchState]]
+    cache_refresh_task: Optional[asyncio.Task]; cache_refresh_evt: asyncio.Event; cache_refresh_queue: CacheRefreshQueue
     # fmt:on
 
     @classmethod
@@ -64,8 +64,8 @@ class DHTNode:
             bucket_size: int = 20, num_replicas: int = 5, depth_modulo: int = 5, parallel_rpc: int = None,
             wait_timeout: float = 5, refresh_timeout: Optional[float] = None, bootstrap_timeout: Optional[float] = None,
             cache_locally: bool = True, cache_nearest: int = 1, cache_size=None, cache_refresh_before_expiry: float = 5,
-            reuse_get_requests: bool = True, num_workers: int = 1, listen: bool = True,
-            listen_on: Endpoint = "0.0.0.0:*", **kwargs) -> DHTNode:
+            cache_on_store: bool = True, reuse_get_requests: bool = True, num_workers: int = 1,
+            listen: bool = True, listen_on: Endpoint = "0.0.0.0:*", **kwargs) -> DHTNode:
         """
         :param node_id: current node's identifier, determines which keys it will store locally, defaults to random id
         :param initial_peers: connects to these peers to populate routing table, defaults to no peers
@@ -81,6 +81,7 @@ class DHTNode:
           if staleness_timeout is None, DHTNode will not refresh stale buckets (which is usually okay)
         :param bootstrap_timeout: after one of peers responds, await other peers for at most this many seconds
         :param cache_locally: if True, caches all values (stored or found) in a node-local cache
+        :param cache_on_store: if True, update cache entries for a key after storing a new item for that key
         :param cache_nearest: whenever DHTNode finds a value, it will also store (cache) this value on this many
           nodes nearest nodes visited by search algorithm. Prefers nodes that are nearest to :key: but have no value yet
         :param cache_size: if specified, local cache will store up to this many records (as in LRU cache)
@@ -96,10 +97,6 @@ class DHTNode:
           see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
         :param kwargs: extra parameters used in grpc.aio.server
         """
-        if cache_refresh_before_expiry > 0 and not cache_locally:
-            logger.warning("If cache_locally is False, cache_refresh_before_expiry has no effect. To silence this"
-                           " warning, please specify cache_refresh_before_expiry=0")
-
         self = cls(_initialized_with_create=True)
         self.node_id = node_id = node_id if node_id is not None else DHTID.generate()
         self.num_replicas, self.num_workers = num_replicas, num_workers
@@ -110,12 +107,11 @@ class DHTNode:
 
         # caching policy
         self.refresh_timeout = refresh_timeout
-        self.cache_locally, self.cache_nearest = cache_locally, cache_nearest
+        self.cache_locally, self.cache_nearest, self.cache_on_store = cache_locally, cache_nearest, cache_on_store
         self.cache_refresh_before_expiry = cache_refresh_before_expiry
-        self.cache_refresh_queue = LocalStorage()
-        self.cache_refresh_available = asyncio.Event()
-        if cache_refresh_before_expiry:
-            asyncio.create_task(self._refresh_stale_cache_entries())
+        self.cache_refresh_queue = CacheRefreshQueue()
+        self.cache_refresh_evt = asyncio.Event()
+        self.cache_refresh_task = None
 
         self.protocol = await DHTProtocol.create(self.node_id, bucket_size, depth_modulo, num_replicas, wait_timeout,
                                                  parallel_rpc, cache_size, listen, listen_on, **kwargs)
@@ -211,25 +207,27 @@ class DHTNode:
             nearest_nodes_with_endpoints[query] = {node: node_to_endpoint[node] for node in nearest_nodes[:k_nearest]}
         return nearest_nodes_with_endpoints
 
-    async def store(self, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration, **kwargs) -> bool:
+    async def store(self, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration,
+                    subkey: Optional[Subkey] = None, **kwargs) -> bool:
         """
         Find num_replicas best nodes to store (key, value) and store it there at least until expiration time.
-
         :note: store is a simplified interface to store_many, all kwargs are be forwarded there
         :returns: True if store succeeds, False if it fails (due to no response or newer value)
         """
-        store_ok = await self.store_many([key], [value], [expiration_time], **kwargs)
-        return store_ok[key]
+        store_ok = await self.store_many([key], [value], [expiration_time], subkeys=[subkey], **kwargs)
+        return store_ok[(key, subkey) if subkey is not None else key]
 
     async def store_many(self, keys: List[DHTKey], values: List[DHTValue],
                          expiration_time: Union[DHTExpiration, List[DHTExpiration]],
+                         subkeys: Optional[Union[Subkey, List[Optional[Subkey]]]] = None,
                          exclude_self: bool = False, await_all_replicas=True, **kwargs) -> Dict[DHTKey, bool]:
         """
-        Traverse DHT to find up to best nodes to store multiple (key, value, expiration_time) pairs.
+        Traverse DHT to find up :num_replicas: to best nodes to store multiple (key, value, expiration_time) pairs.
 
         :param keys: arbitrary serializable keys associated with each value
         :param values: serializable "payload" for each key
         :param expiration_time: either one expiration time for all keys or individual expiration times (see class doc)
+        :param subkeys: an optional list of same shape as keys. If specified, this
         :param kwargs: any additional parameters passed to traverse_dht function (e.g. num workers)
         :param exclude_self: if True, never store value locally even if you are one of the nearest nodes
         :note: if exclude_self is True and self.cache_locally == True, value will still be __cached__ locally
@@ -239,24 +237,23 @@
         """
         if isinstance(expiration_time, DHTExpiration):
             expiration_time = [expiration_time] * len(keys)
-        assert len(keys) == len(values) == len(expiration_time), "Number of keys, values and expiration doesn't match."
+        if subkeys is None or isinstance(subkeys, Subkey):
+            subkeys = [subkeys] * len(keys)
 
-        key_ids = list(map(DHTID.generate, keys))
-        id_to_original_key = dict(zip(key_ids, keys))
-        binary_values_by_key_id = {key_id: self.serializer.dumps(value) for key_id, value in zip(key_ids, values)}
-        expiration_by_key_id = {key_id: expiration_time for key_id, expiration_time in zip(key_ids, expiration_time)}
-        unfinished_key_ids = set(key_ids)  # we use this set to ensure that each store request is finished
+        assert len(keys) == len(subkeys) == len(values) == len(expiration_time), \
+            "Either of keys, values, subkeys or expiration timestamps have different sequence lengths."
 
-        store_ok = {key: False for key in keys}  # outputs, updated during search
-        store_finished_events = {key: asyncio.Event() for key in keys}
+        key_id_to_data: DefaultDict[DHTID, List[Tuple[DHTKey, Subkey, DHTValue, DHTExpiration]]] = defaultdict(list)
+        for key, subkey, value, expiration in zip(keys, subkeys, values, expiration_time):
+            key_id_to_data[DHTID.generate(source=key)].append((key, subkey, value, expiration))
 
-        if self.cache_locally:
-            for key_id in key_ids:
-                self.protocol.cache.store(key_id, binary_values_by_key_id[key_id], expiration_by_key_id[key_id])
+        unfinished_key_ids = set(key_id_to_data.keys())  # use this set to ensure that each store request is finished
+        store_ok = {(key, subkey): None for key, subkey in zip(keys, subkeys)}  # outputs, updated during search
+        store_finished_events = {(key, subkey): asyncio.Event() for key, subkey in zip(keys, subkeys)}
 
         # pre-populate node_to_endpoint
         node_to_endpoint: Dict[DHTID, Endpoint] = dict()
-        for key_id in key_ids:
+        for key_id in unfinished_key_ids:
             node_to_endpoint.update(self.protocol.routing_table.get_nearest_neighbors(
                 key_id, self.protocol.bucket_size, exclude=self.node_id))
 
@@ -272,48 +269,73 @@ class DHTNode:
             pending_store_tasks = set()
             store_candidates = sorted(nearest_nodes + ([] if exclude_self else [self.node_id]),
                                       key=key_id.xor_distance, reverse=True)  # ordered so that .pop() returns nearest
+            [original_key, *_], current_subkeys, current_values, current_expirations = zip(*key_id_to_data[key_id])
+            binary_values: List[bytes] = list(map(self.protocol.serializer.dumps, current_values))
 
             while num_successful_stores < self.num_replicas and (store_candidates or pending_store_tasks):
-                # spawn enough tasks to cover all replicas
                 while store_candidates and num_successful_stores + len(pending_store_tasks) < self.num_replicas:
                     node_id: DHTID = store_candidates.pop()  # nearest untried candidate
+
                     if node_id == self.node_id:
-                        self.protocol.storage.store(key_id, binary_values_by_key_id[key_id],
-                                                    expiration_by_key_id[key_id])
-                        store_ok[id_to_original_key[key_id]] = True
                         num_successful_stores += 1
-                        if not await_all_replicas:
-                            store_finished_events[id_to_original_key[key_id]].set()
-
+                        for subkey, value, expiration_time in zip(current_subkeys, binary_values, current_expirations):
+                            store_ok[original_key, subkey] = self.protocol.storage.store(
+                                key_id, value, expiration_time, subkey=subkey)
+                            if not await_all_replicas:
+                                store_finished_events[original_key, subkey].set()
                     else:
                         pending_store_tasks.add(asyncio.create_task(self.protocol.call_store(
-                            node_to_endpoint[node_id], [key_id], [binary_values_by_key_id[key_id]],
-                            [expiration_by_key_id[key_id]])))
+                            node_to_endpoint[node_id], keys=[key_id] * len(current_values), values=binary_values,
+                            expiration_time=current_expirations, subkeys=current_subkeys)))
 
                 # await nearest task. If it fails, dispatch more on the next iteration
                 if pending_store_tasks:
                     finished_store_tasks, pending_store_tasks = await asyncio.wait(
                         pending_store_tasks, return_when=asyncio.FIRST_COMPLETED)
                     for task in finished_store_tasks:
-                        if task.result()[0]:  # if store succeeded
-                            store_ok[id_to_original_key[key_id]] = True
+                        if task.result() is not None:
                             num_successful_stores += 1
-                            if not await_all_replicas:
-                                store_finished_events[id_to_original_key[key_id]].set()
+                            for subkey, store_status in zip(current_subkeys, task.result()):
+                                store_ok[original_key, subkey] = store_status
+                                if not await_all_replicas:
+                                    store_finished_events[original_key, subkey].set()
 
-            store_finished_events[id_to_original_key[key_id]].set()
+            if self.cache_on_store:
+                self._update_cache_on_store(key_id, current_subkeys, binary_values, current_expirations,
+                                            store_ok=[store_ok[original_key, subkey] for subkey in current_subkeys])
+
+            for subkey, value_bytes, expiration in zip(current_subkeys, binary_values, current_expirations):
+                store_finished_events[original_key, subkey].set()
 
         store_task = asyncio.create_task(self.find_nearest_nodes(
-            queries=set(key_ids), k_nearest=self.num_replicas, node_to_endpoint=node_to_endpoint,
+            queries=set(unfinished_key_ids), k_nearest=self.num_replicas, node_to_endpoint=node_to_endpoint,
             found_callback=on_found, exclude_self=exclude_self, **kwargs))
         try:
             await asyncio.wait([evt.wait() for evt in store_finished_events.values()])  # wait for items to be stored
             assert len(unfinished_key_ids) == 0, "Internal error: traverse_dht didn't finish search"
-            return store_ok
+            return {(key, subkey) if subkey else key: status or False for (key, subkey), status in store_ok.items()}
         except asyncio.CancelledError as e:
             store_task.cancel()
             raise e
 
+    def _update_cache_on_store(self, key_id: DHTID, subkeys: List[Subkey], binary_values: List[bytes],
+                               expirations: List[DHTExpiration], store_ok: List[bool]):
+        """ Update local cache after finishing a store for one key (with perhaps several subkeys) """
+        store_succeeded = any(store_ok)
+        is_dictionary = any(subkey is not None for subkey in subkeys)
+        if store_succeeded and not is_dictionary:  # stored a new regular value, cache it!
+            stored_value_bytes, stored_expiration = max(zip(binary_values, expirations), key=lambda p: p[1])
+            self.protocol.cache.store(key_id, stored_value_bytes, stored_expiration)
+        elif not store_succeeded and not is_dictionary:  # store rejected, check if local cache is also obsolete
+            rejected_value, rejected_expiration = max(zip(binary_values, expirations), key=lambda p: p[1])
+            self.protocol.cache.store(key_id, rejected_value, rejected_expiration)  # can still be better than cache
+            if (self.protocol.cache.get(key_id)[1] or float("inf")) <= rejected_expiration:  # cache would be rejected
+                self._schedule_for_refresh(key_id, refresh_time=get_dht_time())  # fetch new key in background (asap)
+        else:  # stored a dictionary (or failed to store), either way, there can be other keys and we should update
+            for subkey, stored_value_bytes, expiration_time in zip(subkeys, binary_values, expirations):
+                self.protocol.cache.store_subkey(key_id, subkey, stored_value_bytes, expiration_time)
+            self._schedule_for_refresh(key_id, refresh_time=get_dht_time())  # fetch new key in background (asap)
+
     async def get(self, key: DHTKey, latest=False, **kwargs) -> Tuple[Optional[DHTValue], Optional[DHTExpiration]]:
         """
         Search for a key across DHT and return either first or latest entry.
@@ -350,8 +372,8 @@ class DHTNode:
     async def get_many_by_id(
             self, key_ids: Collection[DHTID], sufficient_expiration_time: Optional[DHTExpiration] = None,
             num_workers: Optional[int] = None, beam_size: Optional[int] = None, return_futures: bool = False,
-            _refresh_cache=True) -> Dict[DHTID, Union[Tuple[Optional[DHTValue], Optional[DHTExpiration]],
-                                                      Awaitable[Tuple[Optional[DHTValue], Optional[DHTExpiration]]]]]:
+            _is_refresh=False) -> Dict[DHTID, Union[Tuple[Optional[DHTValue], Optional[DHTExpiration]],
+                                                    Awaitable[Tuple[Optional[DHTValue], Optional[DHTExpiration]]]]]:
         """
         Traverse DHT to find a list of DHTIDs. For each key, return latest (value, expiration) or None if not found.
 
@@ -364,17 +386,17 @@ class DHTNode:
         :param return_futures: if True, immediately return asyncio.Future for every before interacting with the nework.
          The algorithm will populate these futures with (value, expiration) when it finds the corresponding key
          Note: canceling a future will stop search for the corresponding key
-        :param _refresh_cache: internal flag, whether or not to self._trigger_cache_refresh
+        :param _is_refresh: internal flag, set to True by an internal cache refresher (if enabled)
         :returns: for each key: value and its expiration time. If nothing is found, returns (None, None) for that key
         :note: in order to check if get returned a value, please check (expiration_time is None)
         """
         sufficient_expiration_time = sufficient_expiration_time or get_dht_time()
         beam_size = beam_size if beam_size is not None else self.protocol.bucket_size
         num_workers = num_workers if num_workers is not None else self.num_workers
-        search_results: Dict[DHTID, _IntermediateResult] = {key_id: _IntermediateResult(
-            key_id, sufficient_expiration_time, serializer=self.serializer) for key_id in key_ids}
+        search_results: Dict[DHTID, _SearchState] = {key_id: _SearchState(
+            key_id, sufficient_expiration_time, serializer=self.protocol.serializer) for key_id in key_ids}
 
-        if _refresh_cache:
+        if not _is_refresh:  # if we're already refreshing cache, there's no need to trigger subsequent refreshes
             for key_id in key_ids:
                 search_results[key_id].add_done_callback(self._trigger_cache_refresh)
 
@@ -387,7 +409,8 @@ class DHTNode:
         # stage 1: check for value in this node's local storage and cache
         for key_id in key_ids:
             search_results[key_id].add_candidate(*self.protocol.storage.get(key_id), source_node_id=self.node_id)
-            search_results[key_id].add_candidate(*self.protocol.cache.get(key_id), source_node_id=self.node_id)
+            if not _is_refresh:
+                search_results[key_id].add_candidate(*self.protocol.cache.get(key_id), source_node_id=self.node_id)
 
         # stage 2: traverse the DHT to get the remaining keys from remote peers
         unfinished_key_ids = [key_id for key_id in key_ids if not search_results[key_id].finished]
@@ -414,7 +437,7 @@ class DHTNode:
         # V-- this function will be called exactly once when traverse_dht finishes search for a given key
         async def found_callback(key_id: DHTID, nearest_nodes: List[DHTID], _visited: Set[DHTID]):
             search_results[key_id].finish_search()  # finish search whether or we found something
-            self._cache_new_result(search_results[key_id], nearest_nodes, node_to_endpoint)
+            self._cache_new_result(search_results[key_id], nearest_nodes, node_to_endpoint, _is_refresh=_is_refresh)
 
         asyncio.create_task(traverse_dht(
             queries=list(unfinished_key_ids), initial_nodes=list(node_to_endpoint),
@@ -433,9 +456,9 @@ class DHTNode:
                     search_result.future.cancel()
                 raise e
 
-    def _reuse_finished_search_result(self, finished: _IntermediateResult):
+    def _reuse_finished_search_result(self, finished: _SearchState):
         expiration_time_threshold = max(finished.expiration_time or -float('inf'), finished.sufficient_expiration_time)
-        concurrent_requests: SortedList[_IntermediateResult] = self.pending_get_requests[finished.key_id]
+        concurrent_requests: SortedList[_SearchState] = self.pending_get_requests[finished.key_id]
         # note: concurrent_requests is sorded in the order of descending sufficient_expiration_time
         while concurrent_requests and expiration_time_threshold >= concurrent_requests[-1].sufficient_expiration_time:
             concurrent_requests[-1].add_candidate(finished.binary_value, finished.expiration_time,
@@ -443,66 +466,72 @@ class DHTNode:
             concurrent_requests[-1].finish_search()
             concurrent_requests.pop(-1)
 
-    def _trigger_cache_refresh(self, result: _IntermediateResult):
+    def _trigger_cache_refresh(self, search: _SearchState):
         """ Called after get request is finished (whether it was found, not found, hit cache, cancelled, or reused) """
-        if result.found_something and result.source_node_id == self.node_id:
-            with self.protocol.cache.freeze():  # do not clear outdated cache for now...
-                if self.cache_refresh_before_expiry and result.key_id in self.protocol.cache:
-                    previous_earliest_item: Tuple[DHTID, BinaryDHTValue, DHTExpiration] = self.cache_refresh_queue.top()
-                    self.cache_refresh_queue.store(result.key_id, result.binary_value, result.expiration_time)
-                    if previous_earliest_item is None or result.expiration_time < previous_earliest_item[-1]:
-                        self.cache_refresh_available.set()  # if we new element is now earliest, notify the cache queue
+        if search.found_something and search.source_node_id == self.node_id:
+            if self.cache_refresh_before_expiry and search.key_id in self.protocol.cache:
+                self._schedule_for_refresh(search.key_id, search.expiration_time - self.cache_refresh_before_expiry)
+
+    def _schedule_for_refresh(self, key_id: DHTID, refresh_time: DHTExpiration):
+        """ Add key to a refresh queue, refresh at :refresh_time: or later """
+        if self.cache_refresh_task is None or self.cache_refresh_task.done() or self.cache_refresh_task.cancelled():
+            self.cache_refresh_task = asyncio.create_task(self._refresh_stale_cache_entries())
+            logger.debug("Spawned cache refresh task.")
+        previous_earliest_item: Tuple[DHTID, Any, DHTExpiration] = self.cache_refresh_queue.top()
+        if previous_earliest_item is None or refresh_time < previous_earliest_item[-1]:
+            self.cache_refresh_evt.set()  # if we new element is now earliest, notify the cache queue
+        self.cache_refresh_queue.store(key_id, value=refresh_time, expiration_time=refresh_time)
 
     async def _refresh_stale_cache_entries(self):
         """ periodically refresh keys near-expired keys that were accessed at least once during previous lifetime """
         while self.is_alive:
-            with self.cache_refresh_queue.freeze():
-                while len(self.cache_refresh_queue) == 0:
-                    await self.cache_refresh_available.wait()
-                    self.cache_refresh_available.clear()
-                key_id, _, nearest_expiration = self.cache_refresh_queue.top()
+            while len(self.cache_refresh_queue) == 0:
+                await self.cache_refresh_evt.wait()
+                self.cache_refresh_evt.clear()
+            key_id, _, nearest_refresh_time = self.cache_refresh_queue.top()
 
             try:
                 # step 1: await until :cache_refresh_before_expiry: seconds before earliest first element expires
-                time_to_wait = nearest_expiration - get_dht_time() - self.cache_refresh_before_expiry
-                await asyncio.wait_for(self.cache_refresh_available.wait(), timeout=time_to_wait)
+                time_to_wait = nearest_refresh_time - get_dht_time()
+                await asyncio.wait_for(self.cache_refresh_evt.wait(), timeout=time_to_wait)
                 # note: the line above will cause TimeoutError when we are ready to refresh cache
-                self.cache_refresh_available.clear()  # no timeout error => someone added new entry to queue and ...
+                self.cache_refresh_evt.clear()  # no timeout error => someone added new entry to queue and ...
                 continue  # ... and this element is earlier than nearest_expiration. we should refresh this entry first
 
             except asyncio.TimeoutError:  # caught TimeoutError => it is time to refresh the most recent cached entry
                 # step 2: find all keys that we should already refresh and remove them from queue
-                with self.cache_refresh_queue.freeze():
-                    keys_to_refresh = {key_id}
+                current_time = get_dht_time()
+                keys_to_refresh = {key_id}
+                max_expiration_time = self.protocol.cache.get(key_id)[1] or current_time
+                del self.cache_refresh_queue[key_id]  # we pledge to refresh this key_id in the nearest batch
+                while self.cache_refresh_queue:
+                    key_id, _, nearest_refresh_time = self.cache_refresh_queue.top()
+                    if nearest_refresh_time > current_time:
+                        break
                     del self.cache_refresh_queue[key_id]  # we pledge to refresh this key_id in the nearest batch
-                    while self.cache_refresh_queue:
-                        key_id, _, nearest_expiration = self.cache_refresh_queue.top()
-                        if nearest_expiration > get_dht_time() + self.cache_refresh_before_expiry:
-                            break
-                        del self.cache_refresh_queue[key_id]  # we pledge to refresh this key_id in the nearest batch
-                        keys_to_refresh.add(key_id)
+                    keys_to_refresh.add(key_id)
+                    max_expiration_time = max(max_expiration_time, self.protocol.cache.get(key_id)[1] or current_time)
 
                 # step 3: search newer versions of these keys, cache them as a side-effect of self.get_many_by_id
-                await self.get_many_by_id(
-                    keys_to_refresh, sufficient_expiration_time=nearest_expiration + self.cache_refresh_before_expiry,
-                    _refresh_cache=False)  # if we found value locally, we shouldn't trigger another refresh
+                sufficient_expiration_time = max_expiration_time + self.cache_refresh_before_expiry + 1
+                await self.get_many_by_id(keys_to_refresh, sufficient_expiration_time, _is_refresh=True)
 
-    def _cache_new_result(self, result: _IntermediateResult, nearest_nodes: List[DHTID],
-                          node_to_endpoint: Dict[DHTID, Endpoint]):
+    def _cache_new_result(self, search: _SearchState, nearest_nodes: List[DHTID],
+                          node_to_endpoint: Dict[DHTID, Endpoint], _is_refresh: bool = False):
         """ after key_id is found, update cache according to caching policy. used internally in get and get_many """
-        if result.found_something:
-            previous_expiration_time = max(self.protocol.storage.get(result.key_id)[1] or -float('inf'),
-                                           self.protocol.cache.get(result.key_id)[1] or -float('inf'))
-            if result.expiration_time > previous_expiration_time:  # if this value has better expiration
-                if self.cache_locally:
-                    self.protocol.cache.store(result.key_id, result.binary_value, result.expiration_time)
+        if search.found_something:
+            previous_expiration_time = max(self.protocol.storage.get(search.key_id)[1] or -float('inf'),
+                                           self.protocol.cache.get(search.key_id)[1] or -float('inf'))
+            if search.expiration_time > previous_expiration_time:  # if this value has better expiration
+                if self.cache_locally or _is_refresh:
+                    self.protocol.cache.store(search.key_id, search.binary_value, search.expiration_time)
                 if self.cache_nearest:
                 if self.cache_nearest:
                     num_cached_nodes = 0
                     num_cached_nodes = 0
                     for node_id in nearest_nodes:
                     for node_id in nearest_nodes:
-                        if node_id == result.source_node_id:
+                        if node_id == search.source_node_id:
                             continue
                             continue
                         asyncio.create_task(self.protocol.call_store(
                         asyncio.create_task(self.protocol.call_store(
-                            node_to_endpoint[node_id], [result.key_id], [result.binary_value], [result.expiration_time],
+                            node_to_endpoint[node_id], [search.key_id], [search.binary_value], [search.expiration_time],
                             in_cache=True))
                             in_cache=True))
                         num_cached_nodes += 1
                         num_cached_nodes += 1
                         if num_cached_nodes >= self.cache_nearest:
                         if num_cached_nodes >= self.cache_nearest:
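
Put into words, the caching policy applied in _cache_new_result above boils down to the predicate below. This is a simplified restatement under the same parameter names (cache_locally, _is_refresh), not a drop-in replacement: the real method also re-stores the binary value and optionally pushes it to nearby peers via call_store.

def should_cache_locally(found_expiration: float, known_expiration: float,
                         cache_locally: bool, is_refresh: bool) -> bool:
    # cache a freshly found value only if it outlives everything we already hold for that key,
    # and either local caching is enabled or this lookup was triggered by a scheduled cache refresh
    return found_expiration > known_expiration and (cache_locally or is_refresh)
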
@@ -523,11 +552,11 @@ class DHTNode:
 
 
 
 
 @dataclass(init=True, repr=True, frozen=False, order=False)
 @dataclass(init=True, repr=True, frozen=False, order=False)
-class _IntermediateResult:
+class _SearchState:
     """ A helper class that stores current-best GET results with metadata """
     """ A helper class that stores current-best GET results with metadata """
     key_id: DHTID
     key_id: DHTID
     sufficient_expiration_time: DHTExpiration
     sufficient_expiration_time: DHTExpiration
-    binary_value: Optional[BinaryDHTValue] = None
+    binary_value: Optional[Union[BinaryDHTValue, DictionaryDHTValue]] = None
     expiration_time: Optional[DHTExpiration] = None  # best expiration time so far
     expiration_time: Optional[DHTExpiration] = None  # best expiration time so far
     source_node_id: Optional[DHTID] = None  # node that gave us the value
     source_node_id: Optional[DHTID] = None  # node that gave us the value
     future: asyncio.Future[Tuple[Optional[DHTValue], Optional[DHTExpiration]]] = field(default_factory=asyncio.Future)
     future: asyncio.Future[Tuple[Optional[DHTValue], Optional[DHTExpiration]]] = field(default_factory=asyncio.Future)
@@ -540,25 +569,33 @@ class _IntermediateResult:
             if self.expiration_time >= self.sufficient_expiration_time:
             if self.expiration_time >= self.sufficient_expiration_time:
                 self.finish_search()
                 self.finish_search()
 
 
-    def add_done_callback(self, callback: Callable[[_IntermediateResult], Any]):
-        """ Add callback that will be called when _IntermediateSearchResult is done (found OR cancelled by user) """
+    def add_done_callback(self, callback: Callable[[_SearchState], Any]):
+        """ Add callback that will be called when _SearchState is done (found OR cancelled by user) """
         self.future.add_done_callback(lambda _future: callback(self))
         self.future.add_done_callback(lambda _future: callback(self))
 
 
     def finish_search(self):
     def finish_search(self):
         if self.future.done():
         if self.future.done():
-            return  # either user cancelled our result or someone sent it before us. Nothing more to do here.
-        deserialized_value = self.serializer.loads(self.binary_value) if self.found_something else None
-        self.future.set_result((deserialized_value, self.expiration_time))
+            return  # either user cancelled our search or someone sent it before us. Nothing more to do here.
+        elif not self.found_something:
+            self.future.set_result((None, None))
+        elif isinstance(self.binary_value, BinaryDHTValue):
+            self.future.set_result((self.serializer.loads(self.binary_value), self.expiration_time))
+        elif isinstance(self.binary_value, DictionaryDHTValue):
+            dict_value = {key: (self.serializer.loads(value), item_expiration_time)
+                          for key, value, item_expiration_time in self.binary_value.items()}
+            self.future.set_result((dict_value, self.expiration_time))
+        else:
+            logger.error(f"Invalid value type: {type(self.binary_value)}")
 
 
     @property
     @property
     def found_something(self) -> bool:
     def found_something(self) -> bool:
-        """ Whether or not we have at least some result, regardless of its expiration time """
+        """ Whether or not we have found at least some value, regardless of its expiration time """
         return self.expiration_time is not None
         return self.expiration_time is not None
 
 
     @property
     @property
     def finished(self) -> bool:
     def finished(self) -> bool:
         return self.future.done()
         return self.future.done()
 
 
-    def __lt__(self, other: _IntermediateResult):
-        """ _IntermediateResult instances will be sorted by their target expiration time """
+    def __lt__(self, other: _SearchState):
+        """ _SearchState instances will be sorted by their target expiration time """
         return self.sufficient_expiration_time < other.sufficient_expiration_time
         return self.sufficient_expiration_time < other.sufficient_expiration_time
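
From a caller's perspective, finish_search above resolves the future either to (deserialized_value, expiration_time) for regular entries or to ({subkey: (value, item_expiration)}, expiration_time) for dictionary entries. A hedged sketch of that unpacking, reusing the MSGPackSerializer and DictionaryDHTValue interfaces from this commit (the helper name is hypothetical):

from typing import Optional, Tuple, Union

from hivemind.dht.storage import DictionaryDHTValue
from hivemind.utils import MSGPackSerializer

def unpack_found_value(binary_value: Optional[Union[bytes, DictionaryDHTValue]],
                       expiration_time: Optional[float]) -> Tuple[Optional[object], Optional[float]]:
    if binary_value is None:
        return None, None  # nothing was found
    if isinstance(binary_value, DictionaryDHTValue):
        # dictionary values map sub-keys to (deserialized value, per-item expiration time)
        unpacked = {subkey: (MSGPackSerializer.loads(value), item_expiration)
                    for subkey, value, item_expiration in binary_value.items()}
        return unpacked, expiration_time
    return MSGPackSerializer.loads(binary_value), expiration_time  # regular value
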

+ 75 - 130
hivemind/dht/protocol.py

@@ -2,26 +2,28 @@
 from __future__ import annotations
 from __future__ import annotations
 
 
 import asyncio
 import asyncio
-import heapq
-from contextlib import contextmanager
-from typing import Optional, List, Tuple, Dict, Iterator, Any, Sequence, Union, Collection
+from typing import Optional, List, Tuple, Dict, Any, Sequence, Union, Collection
 from warnings import warn
 from warnings import warn
 
 
 import grpc
 import grpc
 import grpc.experimental.aio
 import grpc.experimental.aio
 
 
-from hivemind.dht.routing import RoutingTable, DHTID, BinaryDHTValue, DHTExpiration, get_dht_time
+from hivemind.dht.routing import RoutingTable, DHTID, BinaryDHTValue, DHTExpiration, Subkey
+from hivemind.dht.storage import DHTLocalStorage, DictionaryDHTValue
 from hivemind.proto import dht_pb2, dht_pb2_grpc as dht_grpc
 from hivemind.proto import dht_pb2, dht_pb2_grpc as dht_grpc
-from hivemind.utils import Endpoint, get_logger, replace_port
+from hivemind.utils import Endpoint, get_logger, replace_port, MSGPackSerializer
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)
+NOT_FOUND_VALUE, NOT_FOUND_EXPIRATION, IS_REGULAR_VALUE, IS_DICTIONARY = b'', -float('inf'), '', '___DictionaryDHTValue'
+RESERVED_SUBKEYS = {IS_REGULAR_VALUE, IS_DICTIONARY}
 
 
 
 
 class DHTProtocol(dht_grpc.DHTServicer):
 class DHTProtocol(dht_grpc.DHTServicer):
     # fmt:off
     # fmt:off
     node_id: DHTID; port: int; bucket_size: int; num_replicas: int; wait_timeout: float; node_info: dht_pb2.NodeInfo
     node_id: DHTID; port: int; bucket_size: int; num_replicas: int; wait_timeout: float; node_info: dht_pb2.NodeInfo
     channel_options: Optional[Sequence[Tuple[str, Any]]]; server: grpc.experimental.aio.Server
     channel_options: Optional[Sequence[Tuple[str, Any]]]; server: grpc.experimental.aio.Server
-    storage: LocalStorage; cache: LocalStorage; routing_table: RoutingTable; rpc_semaphore: asyncio.Semaphore
+    storage: DHTLocalStorage; cache: DHTLocalStorage; routing_table: RoutingTable; rpc_semaphore: asyncio.Semaphore
+    serializer = MSGPackSerializer  # used to pack/unpack DHT Values for transfer over network
     # fmt:on
     # fmt:on
 
 
     @classmethod
     @classmethod
@@ -44,7 +46,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
         self = cls(_initialized_with_create=True)
         self = cls(_initialized_with_create=True)
         self.node_id, self.bucket_size, self.num_replicas = node_id, bucket_size, num_replicas
         self.node_id, self.bucket_size, self.num_replicas = node_id, bucket_size, num_replicas
         self.wait_timeout, self.channel_options = wait_timeout, channel_options
         self.wait_timeout, self.channel_options = wait_timeout, channel_options
-        self.storage, self.cache = LocalStorage(), LocalStorage(maxsize=cache_size)
+        self.storage, self.cache = DHTLocalStorage(), DHTLocalStorage(maxsize=cache_size)
         self.routing_table = RoutingTable(node_id, bucket_size, depth_modulo)
         self.routing_table = RoutingTable(node_id, bucket_size, depth_modulo)
         self.rpc_semaphore = asyncio.Semaphore(parallel_rpc if parallel_rpc is not None else float('inf'))
         self.rpc_semaphore = asyncio.Semaphore(parallel_rpc if parallel_rpc is not None else float('inf'))
 
 
@@ -110,30 +112,46 @@ class DHTProtocol(dht_grpc.DHTServicer):
             asyncio.create_task(self.update_routing_table(sender_id, rpc_endpoint))
             asyncio.create_task(self.update_routing_table(sender_id, rpc_endpoint))
         return self.node_info
         return self.node_info
 
 
-    async def call_store(self, peer: Endpoint, keys: Sequence[DHTID], values: Sequence[BinaryDHTValue],
+    async def call_store(self, peer: Endpoint, keys: Sequence[DHTID],
+                         values: Sequence[Union[BinaryDHTValue, DictionaryDHTValue]],
                          expiration_time: Union[DHTExpiration, Sequence[DHTExpiration]],
                          expiration_time: Union[DHTExpiration, Sequence[DHTExpiration]],
-                         in_cache: Optional[Union[bool, Sequence[bool]]] = None) -> Sequence[bool]:
+                         subkeys: Optional[Union[Subkey, Sequence[Optional[Subkey]]]] = None,
+                         in_cache: Optional[Union[bool, Sequence[bool]]] = None) -> Optional[List[bool]]:
         """
         """
         Ask a recipient to store several (key, value : expiration_time) items or update their older value
         Ask a recipient to store several (key, value : expiration_time) items or update their older value
 
 
         :param peer: request this peer to store the data
         :param peer: request this peer to store the data
         :param keys: a list of N keys digested by DHTID.generate(source=some_dict_key)
         :param keys: a list of N keys digested by DHTID.generate(source=some_dict_key)
         :param values: a list of N serialized values (bytes) for each respective key
         :param values: a list of N serialized values (bytes) for each respective key
-        :param expiration_time: a list of N expiration timestamps for each respective key-value pair (see get_dht_time())
+        :param expiration_time: a list of N expiration timestamps for each respective key-value pair(see get_dht_time())
+        :param subkeys: a list of N optional sub-keys. If None, stores value normally. If subkey is not None:
+          1) if local storage doesn't have :key:, create a new dictionary {subkey: (value, expiration_time)}
+          2) if local storage already has a dictionary under :key:, try to add (subkey, value, exp_time) to that dictionary
+          3) if local storage associates :key: with a normal value with smaller expiration, clear :key: and perform (1)
+          4) finally, if local storage currently associates :key: with a normal value with larger expiration, do nothing
         :param in_cache: a list of booleans, True = store i-th key in cache, False = store i-th key locally
         :param in_cache: a list of booleans, True = store i-th key in cache, False = store i-th key locally
         :note: the difference between storing normally and in cache is that normal storage is guaranteed to be stored
         :note: the difference between storing normally and in cache is that normal storage is guaranteed to be stored
          until expiration time (best-effort), whereas cached storage can be evicted early due to limited cache size
          until expiration time (best-effort), whereas cached storage can be evicted early due to limited cache size
-
         :return: list of [True / False] True = stored, False = failed (found newer value or no response)
         :return: list of [True / False] True = stored, False = failed (found newer value or no response)
-         if peer did not respond (e.g. due to timeout or congestion), returns None
+                 if peer did not respond (e.g. due to timeout or congestion), returns None
         """
         """
         if isinstance(expiration_time, DHTExpiration):
         if isinstance(expiration_time, DHTExpiration):
             expiration_time = [expiration_time] * len(keys)
             expiration_time = [expiration_time] * len(keys)
+        if subkeys is None or isinstance(subkeys, Subkey):
+            subkeys = [subkeys] * len(keys)
+
         in_cache = in_cache if in_cache is not None else [False] * len(keys)  # default value (None)
         in_cache = in_cache if in_cache is not None else [False] * len(keys)  # default value (None)
         in_cache = [in_cache] * len(keys) if isinstance(in_cache, bool) else in_cache  # single bool
         in_cache = [in_cache] * len(keys) if isinstance(in_cache, bool) else in_cache  # single bool
-        keys, values, expiration_time, in_cache = map(list, [keys, values, expiration_time, in_cache])
+        keys, subkeys, values, expiration_time, in_cache = map(list, [keys, subkeys, values, expiration_time, in_cache])
+        for i in range(len(keys)):
+            if subkeys[i] is None:  # add default sub-key if not specified
+                subkeys[i] = IS_REGULAR_VALUE if not isinstance(values[i], DictionaryDHTValue) else IS_DICTIONARY
+            if isinstance(values[i], DictionaryDHTValue):
+                assert subkeys[i] == IS_DICTIONARY, "Please do not specify subkey when storing an entire dictionary"
+                values[i] = self.serializer.dumps(values[i])
+
         assert len(keys) == len(values) == len(expiration_time) == len(in_cache), "Data is not aligned"
         assert len(keys) == len(values) == len(expiration_time) == len(in_cache), "Data is not aligned"
-        store_request = dht_pb2.StoreRequest(keys=list(map(DHTID.to_bytes, keys)), values=values,
+        store_request = dht_pb2.StoreRequest(keys=list(map(DHTID.to_bytes, keys)), subkeys=subkeys, values=values,
                                              expiration_time=expiration_time, in_cache=in_cache, peer=self.node_info)
                                              expiration_time=expiration_time, in_cache=in_cache, peer=self.node_info)
         try:
         try:
             async with self.rpc_semaphore:
             async with self.rpc_semaphore:
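
The tests further down exercise this path; as a usage-level sketch (not part of the patch, function name made up), storing two sub-keys under one DHT key with the new call_store signature looks roughly like this, assuming `protocol` is a running DHTProtocol and `peer` is a reachable endpoint string such as '127.0.0.1:31337':

from hivemind.dht.routing import DHTID, get_dht_time
from hivemind.utils import MSGPackSerializer

async def store_two_subkeys(protocol, peer):
    key = DHTID.generate(source='my_dictionary_key')
    expiration = get_dht_time() + 60
    # each call adds or updates one sub-key of the dictionary stored under `key`
    await protocol.call_store(peer, keys=[key], values=[MSGPackSerializer.dumps('value_a')],
                              expiration_time=[expiration], subkeys=['subkey_a'])
    await protocol.call_store(peer, keys=[key], values=[MSGPackSerializer.dumps('value_b')],
                              expiration_time=[expiration + 10], subkeys=['subkey_b'])
    # a later call_find(peer, [key]) returns a DictionaryDHTValue with both sub-keys as the value for `key`
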
@@ -145,7 +163,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
         except grpc.experimental.aio.AioRpcError as error:
         except grpc.experimental.aio.AioRpcError as error:
             logger.warning(f"DHTProtocol failed to store at {peer}: {error.code()}")
             logger.warning(f"DHTProtocol failed to store at {peer}: {error.code()}")
             asyncio.create_task(self.update_routing_table(self.routing_table.get(endpoint=peer), peer, responded=False))
             asyncio.create_task(self.update_routing_table(self.routing_table.get(endpoint=peer), peer, responded=False))
-            return [False] * len(keys)
+            return None
 
 
     async def rpc_store(self, request: dht_pb2.StoreRequest, context: grpc.ServicerContext) -> dht_pb2.StoreResponse:
     async def rpc_store(self, request: dht_pb2.StoreRequest, context: grpc.ServicerContext) -> dht_pb2.StoreResponse:
         """ Some node wants us to store this (key, value) pair """
         """ Some node wants us to store this (key, value) pair """
@@ -153,10 +171,19 @@ class DHTProtocol(dht_grpc.DHTServicer):
             asyncio.create_task(self.rpc_ping(request.peer, context))
             asyncio.create_task(self.rpc_ping(request.peer, context))
         assert len(request.keys) == len(request.values) == len(request.expiration_time) == len(request.in_cache)
         assert len(request.keys) == len(request.values) == len(request.expiration_time) == len(request.in_cache)
         response = dht_pb2.StoreResponse(store_ok=[], peer=self.node_info)
         response = dht_pb2.StoreResponse(store_ok=[], peer=self.node_info)
-        for key_bytes, value_bytes, expiration_time, in_cache in zip(
-                request.keys, request.values, request.expiration_time, request.in_cache):
-            local_memory = self.cache if in_cache else self.storage
-            response.store_ok.append(local_memory.store(DHTID.from_bytes(key_bytes), value_bytes, expiration_time))
+        keys = map(DHTID.from_bytes, request.keys)
+        for key_id, subkey, value_bytes, expiration_time, in_cache in zip(
+                keys, request.subkeys, request.values, request.expiration_time, request.in_cache):
+            storage = self.cache if in_cache else self.storage
+            if subkey == IS_REGULAR_VALUE:  # store normal value without subkeys
+                response.store_ok.append(storage.store(key_id, value_bytes, expiration_time))
+            elif subkey == IS_DICTIONARY:   # store an entire dictionary with pre-existing subkeys
+                value_dictionary = self.serializer.loads(value_bytes)
+                assert isinstance(value_dictionary, DictionaryDHTValue)
+                response.store_ok.append(all(storage.store_subkey(key_id, subkey, subvalue, subkey_expiration)
+                                             for subkey, subvalue, subkey_expiration in value_dictionary.items()))
+            else:  # add new entry into an existing dictionary-like value or create a new dictionary with one sub-key
+                response.store_ok.append(storage.store_subkey(key_id, subkey, value_bytes, expiration_time))
         return response
         return response
 
 
     async def call_find(self, peer: Endpoint, keys: Collection[DHTID]) -> \
     async def call_find(self, peer: Endpoint, keys: Collection[DHTID]) -> \
@@ -179,16 +206,19 @@ class DHTProtocol(dht_grpc.DHTServicer):
             if response.peer and response.peer.node_id:
             if response.peer and response.peer.node_id:
                 peer_id = DHTID.from_bytes(response.peer.node_id)
                 peer_id = DHTID.from_bytes(response.peer.node_id)
                 asyncio.create_task(self.update_routing_table(peer_id, peer, responded=True))
                 asyncio.create_task(self.update_routing_table(peer_id, peer, responded=True))
-            assert len(response.values) == len(response.expiration_time) == len(response.nearest) == len(keys), \
-                "DHTProtocol: response is not aligned with keys and/or expiration times"
-
-            output = {}  # unpack data without special NOT_FOUND_* values
-            for key, value, expiration_time, nearest in zip(
-                    keys, response.values, response.expiration_time, response.nearest):
-                value = value if value != _NOT_FOUND_VALUE else None
-                expiration_time = expiration_time if expiration_time != _NOT_FOUND_EXPIRATION else None
-                nearest = dict(zip(map(DHTID.from_bytes, nearest.node_ids), nearest.endpoints))
-                output[key] = (value, expiration_time, nearest)
+            assert len(keys) == len(response.results), "DHTProtocol: response is not aligned with keys"
+
+            output = {}  # unpack data depending on its type
+            for key, result in zip(keys, response.results):
+                nearest = dict(zip(map(DHTID.from_bytes, result.nearest_node_ids), result.nearest_endpoints))
+                if result.type == dht_pb2.NOT_FOUND:
+                    output[key] = None, None, nearest
+                elif result.type == dht_pb2.FOUND_REGULAR:
+                    output[key] = result.value, result.expiration_time, nearest
+                elif result.type == dht_pb2.FOUND_DICTIONARY:
+                    output[key] = self.serializer.loads(result.value), result.expiration_time, nearest
+                else:
+                    logger.error(f"Unknown result type: {result.type}")
             return output
             return output
         except grpc.experimental.aio.AioRpcError as error:
         except grpc.experimental.aio.AioRpcError as error:
             logger.warning(f"DHTProtocol failed to find at {peer}: {error.code()}")
             logger.warning(f"DHTProtocol failed to find at {peer}: {error.code()}")
@@ -201,24 +231,27 @@ class DHTProtocol(dht_grpc.DHTServicer):
         """
         """
         if request.peer:  # if requested, add peer to the routing table
         if request.peer:  # if requested, add peer to the routing table
             asyncio.create_task(self.rpc_ping(request.peer, context))
             asyncio.create_task(self.rpc_ping(request.peer, context))
-
-        response = dht_pb2.FindResponse(values=[], expiration_time=[], nearest=[], peer=self.node_info)
-        for key_id in map(DHTID.from_bytes, request.keys):
+        response = dht_pb2.FindResponse(results=[], peer=self.node_info)
+        for i, key_id in enumerate(map(DHTID.from_bytes, request.keys)):
             maybe_value, maybe_expiration_time = self.storage.get(key_id)
             maybe_value, maybe_expiration_time = self.storage.get(key_id)
             cached_value, cached_expiration_time = self.cache.get(key_id)
             cached_value, cached_expiration_time = self.cache.get(key_id)
             if (cached_expiration_time or -float('inf')) > (maybe_expiration_time or -float('inf')):
             if (cached_expiration_time or -float('inf')) > (maybe_expiration_time or -float('inf')):
                 maybe_value, maybe_expiration_time = cached_value, cached_expiration_time
                 maybe_value, maybe_expiration_time = cached_value, cached_expiration_time
 
 
-            nearest_neighbors = self.routing_table.get_nearest_neighbors(
-                key_id, k=self.bucket_size, exclude=DHTID.from_bytes(request.peer.node_id))
-            if nearest_neighbors:
-                peer_ids, endpoints = zip(*nearest_neighbors)
-            else:
-                peer_ids, endpoints = [], []
-
-            response.values.append(maybe_value if maybe_value is not None else _NOT_FOUND_VALUE)
-            response.expiration_time.append(maybe_expiration_time if maybe_expiration_time else _NOT_FOUND_EXPIRATION)
-            response.nearest.append(dht_pb2.Peers(node_ids=list(map(DHTID.to_bytes, peer_ids)), endpoints=endpoints))
+            if maybe_expiration_time is None:  # value not found
+                item = dht_pb2.FindResult(type=dht_pb2.NOT_FOUND)
+            elif isinstance(maybe_value, DictionaryDHTValue):
+                item = dht_pb2.FindResult(type=dht_pb2.FOUND_DICTIONARY, value=self.serializer.dumps(maybe_value),
+                                          expiration_time=maybe_value.latest_expiration_time)
+            else:  # found regular value
+                item = dht_pb2.FindResult(type=dht_pb2.FOUND_REGULAR, value=maybe_value,
+                                          expiration_time=maybe_expiration_time)
+
+            for node_id, endpoint in self.routing_table.get_nearest_neighbors(
+                key_id, k=self.bucket_size, exclude=DHTID.from_bytes(request.peer.node_id)):
+                item.nearest_node_ids.append(node_id.to_bytes())
+                item.nearest_endpoints.append(endpoint)
+            response.results.append(item)
         return response
         return response
 
 
     async def update_routing_table(self, node_id: Optional[DHTID], peer_endpoint: Endpoint, responded=True):
     async def update_routing_table(self, node_id: Optional[DHTID], peer_endpoint: Endpoint, responded=True):
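
Stated on its own, the local lookup in rpc_find above follows one rule: prefer whichever of persistent storage and cache holds the later expiration time. A minimal restatement (illustration only, helper name made up):

def pick_freshest(storage_entry, cache_entry):
    # each entry is a (maybe_value, maybe_expiration_time) pair; prefer whichever expires later
    (value, expiration), (cached_value, cached_expiration) = storage_entry, cache_entry
    if (cached_expiration or -float('inf')) > (expiration or -float('inf')):
        return cached_value, cached_expiration
    return value, expiration
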
@@ -256,91 +289,3 @@ class DHTProtocol(dht_grpc.DHTServicer):
         else:  # we sent outgoing request and peer did not respond
         else:  # we sent outgoing request and peer did not respond
             if node_id is not None and node_id in self.routing_table:
             if node_id is not None and node_id in self.routing_table:
                 del self.routing_table[node_id]
                 del self.routing_table[node_id]
-
-
-_NOT_FOUND_VALUE, _NOT_FOUND_EXPIRATION = b'', -float('inf')  # internal values to represent that a value was not found
-
-
-class LocalStorage:
-    """ Local dictionary that maintains up to :maxsize: tuples of (key, value, expiration_time) """
-
-    def __init__(self, maxsize: Optional[int] = None):
-        self.cache_size = maxsize or float("inf")
-        self.data: Dict[DHTID, Tuple[BinaryDHTValue, DHTExpiration]] = dict()
-        self.expiration_heap: List[Tuple[DHTExpiration, DHTID]] = []
-        self.key_to_heap: Dict[DHTID, Tuple[DHTExpiration, DHTID]] = dict()
-        self.frozen = False  # if True, do not remove outdated elements
-
-    def _remove_outdated(self):
-        while not self.frozen and self.expiration_heap and (self.expiration_heap[0][0] < get_dht_time()
-                                                            or len(self.expiration_heap) > self.cache_size):
-            heap_entry = heapq.heappop(self.expiration_heap)
-            key = heap_entry[1]
-            if self.key_to_heap.get(key) == heap_entry:
-                del self.data[key], self.key_to_heap[key]
-
-    def store(self, key: DHTID, value: BinaryDHTValue, expiration_time: DHTExpiration) -> bool:
-        """
-        Store a (key, value) pair locally at least until expiration_time. See class docstring for details.
-        :returns: True if new value was stored, False it was rejected (current value is newer)
-        """
-        if expiration_time < get_dht_time() and not self.frozen:
-            return False
-        self.key_to_heap[key] = (expiration_time, key)
-        heapq.heappush(self.expiration_heap, (expiration_time, key))
-        if key in self.data:
-            if self.data[key][1] < expiration_time:
-                self.data[key] = (value, expiration_time)
-                return True
-            return False
-        self.data[key] = (value, expiration_time)
-        self._remove_outdated()
-        return True
-
-    def get(self, key: DHTID) -> (Optional[BinaryDHTValue], Optional[DHTExpiration]):
-        """ Get a value corresponding to a key if that (key, value) pair was previously stored here. """
-        self._remove_outdated()
-        if key in self.data:
-            return self.data[key]
-        return None, None
-
-    def items(self) -> Iterator[Tuple[DHTID, BinaryDHTValue, DHTExpiration]]:
-        """ Iterate over (key, value, expiration_time) tuples stored in this storage """
-        self._remove_outdated()
-        return ((key, value, expiration_time) for key, (value, expiration_time) in self.data.items())
-
-    def top(self) -> Optional[Tuple[DHTID, BinaryDHTValue, DHTExpiration]]:
-        """ Return the entry with earliest expiration or None if there isn't any """
-        self._remove_outdated()
-        if self.data:
-            top_entry, top_key = self.expiration_heap[0], self.expiration_heap[0][1]
-            while self.key_to_heap.get(top_key) != top_entry:
-                heapq.heappop(self.expiration_heap)  # skip leftover "ghost" entries until first real entry
-                top_entry, top_key = self.expiration_heap[0], self.expiration_heap[0][1]
-            value, expiration = self.data[top_key]
-            return top_key, value, expiration
-
-    def __contains__(self, key: DHTID):
-        self._remove_outdated()
-        return key in self.data
-
-    def __len__(self):
-        self._remove_outdated()
-        return len(self.data)
-
-    def __delitem__(self, key: DHTID):
-        if key in self.key_to_heap:
-            del self.data[key], self.key_to_heap[key]
-        # note: key may still be in self.expiration_heap, but it will not be used and eventually ._remove_outdated()
-
-    def __bool__(self):
-        return bool(self.data)
-
-    @contextmanager
-    def freeze(self):
-        """ Temporarily cease to ._remove_outdated() elements inside this context to ensure consistency """
-        prev_frozen, self.frozen = self.frozen, True
-        try:
-            yield self
-        finally:
-            self.frozen = prev_frozen

+ 1 - 1
hivemind/dht/routing.py

@@ -12,7 +12,7 @@ from typing import Tuple, Optional, List, Dict, Set, Union, Any, Sequence
 
 
 from hivemind.utils import Endpoint, PickleSerializer
 from hivemind.utils import Endpoint, PickleSerializer
 
 
-DHTKey, DHTValue, DHTExpiration, BinaryDHTID, BinaryDHTValue, = Any, Any, float, bytes, bytes  # flavour types
+DHTKey, Subkey, DHTValue, DHTExpiration, BinaryDHTID, BinaryDHTValue, = Any, str, Any, float, bytes, bytes
 get_dht_time = time.time  # time used by all dht functionality. You can replace this with any infrastructure-wide time
 get_dht_time = time.time  # time used by all dht functionality. You can replace this with any infrastructure-wide time
 
 
 
 

+ 161 - 0
hivemind/dht/storage.py

@@ -0,0 +1,161 @@
+from __future__ import annotations
+import heapq
+from contextlib import contextmanager
+from typing import Generic, Optional, Dict, Tuple, List, Iterator, TypeVar, Union, Any
+
+from hivemind.dht.routing import DHTID, DHTExpiration, get_dht_time, BinaryDHTValue, Subkey
+from hivemind.utils.serializer import MSGPackSerializer
+
+KeyType = TypeVar('KeyType')
+ValueType = TypeVar('ValueType')
+
+
+class TimedStorage(Generic[KeyType, ValueType]):
+    """ A dictionary that maintains up to :maxsize: key-value-expiration tuples until their expiration_time """
+    frozen = False  # can be set to True. If true, do not remove outdated elements
+
+    def __init__(self, maxsize: Optional[int] = None):
+        self.maxsize = maxsize or float("inf")
+        self.data: Dict[KeyType, Tuple[ValueType, DHTExpiration]] = dict()
+        self.expiration_heap: List[Tuple[DHTExpiration, KeyType]] = []
+        self.key_to_heap: Dict[KeyType, Tuple[DHTExpiration, KeyType]] = dict()
+
+    def _remove_outdated(self):
+        while not self.frozen and self.expiration_heap and (self.expiration_heap[0][0] < get_dht_time()
+                                                            or len(self.expiration_heap) > self.maxsize):
+            heap_entry = heapq.heappop(self.expiration_heap)
+            key = heap_entry[1]
+            if self.key_to_heap.get(key) == heap_entry:
+                del self.data[key], self.key_to_heap[key]
+
+    def store(self, key: KeyType, value: ValueType, expiration_time: DHTExpiration) -> bool:
+        """
+        Store a (key, value) pair locally at least until expiration_time. See class docstring for details.
+        :returns: True if new value was stored, False if it was rejected (current value is newer)
+        """
+        if expiration_time < get_dht_time() and not self.frozen:
+            return False
+        self.key_to_heap[key] = (expiration_time, key)
+        heapq.heappush(self.expiration_heap, (expiration_time, key))
+        if key in self.data:
+            if self.data[key][1] < expiration_time:
+                self.data[key] = (value, expiration_time)
+                return True
+            return False
+        self.data[key] = (value, expiration_time)
+        self._remove_outdated()
+        return True
+
+    def get(self, key: KeyType) -> (Optional[ValueType], Optional[DHTExpiration]):
+        """ Get a value corresponding to a key if that (key, value) pair was previously stored under this key. """
+        self._remove_outdated()
+        if key in self.data:
+            return self.data[key]
+        return None, None
+
+    def items(self) -> Iterator[Tuple[KeyType, ValueType, DHTExpiration]]:
+        """ Iterate over (key, value, expiration_time) tuples stored in this storage """
+        self._remove_outdated()
+        return ((key, value, expiration_time) for key, (value, expiration_time) in self.data.items())
+
+    def top(self) -> Optional[Tuple[KeyType, ValueType, DHTExpiration]]:
+        """ Return the entry with earliest expiration or None if there isn't any """
+        self._remove_outdated()
+        if self.data:
+            top_entry, top_key = self.expiration_heap[0], self.expiration_heap[0][1]
+            while self.key_to_heap.get(top_key) != top_entry:
+                heapq.heappop(self.expiration_heap)  # skip leftover "ghost" entries until first real entry
+                top_entry, top_key = self.expiration_heap[0], self.expiration_heap[0][1]
+            value, expiration = self.data[top_key]
+            return top_key, value, expiration
+
+    def __contains__(self, key: KeyType):
+        self._remove_outdated()
+        return key in self.data
+
+    def __len__(self):
+        self._remove_outdated()
+        return len(self.data)
+
+    def __delitem__(self, key: KeyType):
+        if key in self.key_to_heap:
+            del self.data[key], self.key_to_heap[key]
+        # note: key may still be in self.expiration_heap, but it will not be used and will eventually be pruned by ._remove_outdated()
+
+    def __bool__(self):
+        return bool(self.data)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({self.data})"
+
+    @contextmanager
+    def freeze(self):
+        """ Temporarily cease to ._remove_outdated() elements inside this context to ensure consistency """
+        prev_frozen, self.frozen = self.frozen, True
+        try:
+            yield self
+        finally:
+            self.frozen = prev_frozen
+
+
+@MSGPackSerializer.ext_serializable(0x50)
+class DictionaryDHTValue(TimedStorage[Subkey, BinaryDHTValue]):
+    """ a dictionary-like DHT value type that maps sub-keys to values with individual expirations """
+    latest_expiration_time = float('-inf')
+
+    def store(self, key: KeyType, value: ValueType, expiration_time: DHTExpiration) -> bool:
+        self.latest_expiration_time = max(self.latest_expiration_time, expiration_time)
+        return super().store(key, value, expiration_time)
+
+    def packb(self) -> bytes:
+        """ custom behavior for MSGPackSerializer.dumps """
+        return MSGPackSerializer.dumps([self.maxsize, self.latest_expiration_time, list(map(list, self.items()))])
+
+    @classmethod
+    def unpackb(cls, raw: bytes) -> DictionaryDHTValue:
+        maxsize, latest_expiration_time, items = MSGPackSerializer.loads(raw)
+        with DictionaryDHTValue(maxsize).freeze() as new_dict:
+            for key, value, expiration_time in items:
+                new_dict.store(key, value, expiration_time)
+            new_dict.latest_expiration_time = latest_expiration_time
+            return new_dict
+
+
+class DHTLocalStorage(TimedStorage[DHTID, Union[BinaryDHTValue, DictionaryDHTValue]]):
+    """ A dictionary-like storage that can store binary values and/or nested dictionaries until expiration """
+    def store(self, key: DHTID, value: BinaryDHTValue, expiration_time: DHTExpiration,
+              subkey: Optional[Subkey] = None) -> bool:
+        """
+        Store a (key, value) pair locally at least until expiration_time. See class docstring for details.
+        If subkey is not None, adds a subkey-value pair to a dictionary associated with :key: (see store_subkey below)
+        :returns: True if new value was stored, False it was rejected (current value is newer)
+        """
+        if subkey is not None:  # add one sub-key
+            return self.store_subkey(key, subkey, value, expiration_time)
+        else:  # store regular key
+            return super().store(key, value, expiration_time)
+
+    def store_subkey(self, key: DHTID, subkey: Subkey, value: BinaryDHTValue, expiration_time: DHTExpiration) -> bool:
+        """
+        Save a (sub-key, value) into a dictionary associated with a given key.
+         1) if self[key] is empty, create a new dictionary and add sub-key there
+         2) if self[key] is a dictionary (DictionaryDHTValue), store {sub-key: value, expiration} to that storage
+         3) if self[key] is a normal value with smaller expiration time, overwrite it with a dictionary and add sub-key
+         4) if self[key] is a normal value with a larger expiration time, reject the new sub-key and return False
+        :returns: True if new entry was stored, False if it was rejected (current value is newer)
+        """
+        previous_value, previous_expiration_time = self.get(key)
+        if isinstance(previous_value, DictionaryDHTValue):  # already a dictionary, just add new subkey
+            if expiration_time > previous_value.latest_expiration_time:
+                super().store(key, previous_value, expiration_time)  # refresh expiration time
+            return previous_value.store(subkey, value, expiration_time)
+        elif expiration_time > (previous_expiration_time or float('-inf')):  # create new dictionary, add subkey
+            new_storage = DictionaryDHTValue()
+            new_storage.store(subkey, value, expiration_time)
+            return super().store(key, new_storage, new_storage.latest_expiration_time)
+        else:
+            return False
+
+
+class CacheRefreshQueue(TimedStorage[DHTID, DHTExpiration]):
+    """ a queue of keys scheduled for refresh in future, used in DHTNode """
+    frozen = True

+ 18 - 12
hivemind/proto/dht.proto

@@ -24,10 +24,11 @@ message NodeInfo {
 message StoreRequest {
 message StoreRequest {
   // three lists of the same length representing dht keys, dht values and expiration
   // three lists of the same length representing dht keys, dht values and expiration
   repeated bytes keys = 1;             // keys in the form of DHTID.generate(raw_key).to_bytes()
   repeated bytes keys = 1;             // keys in the form of DHTID.generate(raw_key).to_bytes()
-  repeated bytes values = 2;           // binary-encoded value for i-th key
-  repeated double expiration_time = 3; // expirations for i-th key (type = DHTExpiration)
-  repeated bool in_cache = 4;          // if in_cache[i], store i-th key in cache, else store normally
-  NodeInfo peer = 5;                   // (optional) sender's own node info, same behavior as in DHT.rpc_ping
+  repeated string subkeys = 2;         // [optional] subkeys for DictionaryDHTValue type, empty string means no subkey
+  repeated bytes values = 3;           // binary-encoded value for i-th key
+  repeated double expiration_time = 4; // expirations for i-th key (type = DHTExpiration)
+  repeated bool in_cache = 5;          // if in_cache[i], store i-th key in cache, else store normally
+  NodeInfo peer = 6;                   // (optional) sender's own node info, same behavior as in DHT.rpc_ping
 }
 }
 
 
 message StoreResponse {
 message StoreResponse {
@@ -40,16 +41,21 @@ message FindRequest {
   NodeInfo peer = 2;                   // optional, same behavior as in DHT.ping
   NodeInfo peer = 2;                   // optional, same behavior as in DHT.ping
 }
 }
 
 
-message Peers {
-  // two aligned arrays: DHTIDs and Endpoints, i-th endpoint corresponds to peer with i-th node id
-  repeated bytes node_ids = 1;         // DHTID serialized with node_id.to_bytes()
-  repeated string endpoints = 2;       // e.g. 123.123.123.123:1337 or [2a21:6с8:b192:2105]:8888
+enum ResultType {NOT_FOUND = 0; FOUND_REGULAR = 1; FOUND_DICTIONARY = 2;}
+
+message FindResult {
+  ResultType type = 1;                 // NOT_FOUND | FOUND_REGULAR | FOUND_DICTIONARY
+  bytes value = 2;                     // n/a  | serialized value | serialized DictionaryDHTValue with serialized fields
+  double expiration_time = 3;          // n/a  | expiration time  | DictionaryDHTValue.latest_expiration_time
+
+  // two aligned arrays: DHTIDs and Endpoints for nearest peers (sorted by XOR distance)
+  repeated bytes nearest_node_ids = 4;      // DHTIDs serialized with node_id.to_bytes()
+  repeated string nearest_endpoints = 5;    // e.g. 123.123.123.123:1337 or [2a21:6c8:b192:2105]:8888
 }
 }
 
 
+
 message FindResponse {
 message FindResponse {
-  repeated bytes values = 1;           // value for i-th key, b'' means not found locally
-  repeated double expiration_time = 2; // expiration time for i-th value, only valid value is found
-  repeated Peers nearest = 3;          // peers ordered from nearest to farthest based on distance to i-th key
-  NodeInfo peer = 4;                   // respondent's node id, for you to update routing table
+  repeated FindResult results = 1;       // for each item, return value/expiration (if found) and nearest peers
+  NodeInfo peer = 2;                   // respondent's node id, for you to update routing table
 }
 }
 
 

+ 0 - 1
hivemind/server/checkpoint_saver.py

@@ -1,5 +1,4 @@
 import threading
 import threading
-import time
 from datetime import datetime
 from datetime import datetime
 from pathlib import Path
 from pathlib import Path
 from shutil import copytree
 from shutil import copytree

+ 43 - 7
hivemind/utils/serializer.py

@@ -1,9 +1,14 @@
 """ A unified interface for several common serialization methods """
 """ A unified interface for several common serialization methods """
 import pickle
 import pickle
 from io import BytesIO
 from io import BytesIO
+from typing import Dict, Any
 
 
 import torch
 import torch
-import umsgpack
+import msgpack
+from hivemind.utils.logging import get_logger
+
+logger = get_logger(__name__)
+
 
 
 class SerializerBase:
 class SerializerBase:
     @staticmethod
     @staticmethod
@@ -38,10 +43,41 @@ class PytorchSerializer(SerializerBase):
 
 
 
 
 class MSGPackSerializer(SerializerBase):
 class MSGPackSerializer(SerializerBase):
-    @staticmethod
-    def dumps(obj: object) -> bytes:
-        return umsgpack.dumps(obj, use_bin_type=False, strict_types=True)
+    _ExtTypes: Dict[Any, int] = {}
+    _ExtTypeCodes: Dict[int, Any] = {}
+
+    @classmethod
+    def ext_serializable(cls, type_code: int):
+        assert isinstance(type_code, int), "Please specify a (unique) int type code"
+
+        def wrap(wrapped_type: type):
+            assert callable(getattr(wrapped_type, 'packb', None)) and callable(getattr(wrapped_type, 'unpackb', None)),\
+                f"Every ext_type must have 2 methods: packb(self) -> bytes and classmethod unpackb(cls, bytes)"
+            if type_code in cls._ExtTypeCodes:
+                logger.warning(f"{cls.__name__}: type {type_code} is already registered, overwriting.")
+            cls._ExtTypeCodes[type_code], cls._ExtTypes[wrapped_type] = wrapped_type, type_code
+            return wrapped_type
+        return wrap
+
+    @classmethod
+    def _encode_ext_types(cls, obj):
+        type_code = cls._ExtTypes.get(type(obj))
+        if type_code is not None:
+            return msgpack.ExtType(type_code, obj.packb())
+        return obj
+
+    @classmethod
+    def _decode_ext_types(cls, type_code: int, data: bytes):
+        if type_code in cls._ExtTypeCodes:
+            return cls._ExtTypeCodes[type_code].unpackb(data)
+        logger.warning(f"Unknown ExtType code: {type_code}, leaving it as is.")
+        return data
+
+    @classmethod
+    def dumps(cls, obj: object) -> bytes:
+        return msgpack.dumps(obj, use_bin_type=True, default=cls._encode_ext_types, strict_types=True)
+
+    @classmethod
+    def loads(cls, buf: bytes) -> object:
+        return msgpack.loads(buf, ext_hook=cls._decode_ext_types, raw=False)
 
 
-    @staticmethod
-    def loads(buf: bytes) -> object:
-        return umsgpack.loads(buf, raw=False)

+ 1 - 1
requirements.txt

@@ -2,7 +2,7 @@ PyYAML
 torch>=1.3.0
 torch>=1.3.0
 numpy>=1.17
 numpy>=1.17
 prefetch_generator>=1.0.1
 prefetch_generator>=1.0.1
-umsgpack
+msgpack>=0.5.6
 sortedcontainers
 sortedcontainers
 uvloop>=0.14.0
 uvloop>=0.14.0
 grpcio>=1.31
 grpcio>=1.31

+ 1 - 1
tests/test_dht_experts.py

@@ -6,7 +6,7 @@ import hivemind
 from hivemind import LOCALHOST
 from hivemind import LOCALHOST
 
 
 
 
-def test_hivemind_dht():
+def test_store_get_experts():
     peers = [hivemind.DHT(start=True)]
     peers = [hivemind.DHT(start=True)]
     for i in range(10):
     for i in range(10):
         neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
         neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]

+ 91 - 8
tests/test_dht_node.py

@@ -11,6 +11,7 @@ from typing import List, Dict
 from hivemind import get_dht_time
 from hivemind import get_dht_time
 from hivemind.dht.node import DHTID, Endpoint, DHTNode, LOCALHOST, DHTProtocol
 from hivemind.dht.node import DHTID, Endpoint, DHTNode, LOCALHOST, DHTProtocol
 from hivemind.dht.protocol import DHTProtocol
 from hivemind.dht.protocol import DHTProtocol
+from hivemind.dht.storage import DictionaryDHTValue
 
 
 
 
 def run_protocol_listener(port: int, dhtid: DHTID, started: mp.synchronize.Event, ping: Optional[Endpoint] = None):
 def run_protocol_listener(port: int, dhtid: DHTID, started: mp.synchronize.Event, ping: Optional[Endpoint] = None):
@@ -85,6 +86,24 @@ def test_dht_protocol():
             dummy_port = hivemind.find_open_port()
             dummy_port = hivemind.find_open_port()
             assert loop.run_until_complete(protocol.call_find(f"{LOCALHOST}:{dummy_port}", [key])) is None
             assert loop.run_until_complete(protocol.call_find(f"{LOCALHOST}:{dummy_port}", [key])) is None
 
 
+            # store/get a dictionary with sub-keys
+            nested_key, subkey1, subkey2 = DHTID.generate(), 'foo', 'bar'
+            value1, value2 = [random.random(), {'ololo': 'pyshpysh'}], 'abacaba'
+            assert loop.run_until_complete(protocol.call_store(
+                f'{LOCALHOST}:{peer1_port}', keys=[nested_key], values=[hivemind.MSGPackSerializer.dumps(value1)],
+                expiration_time=[expiration], subkeys=[subkey1])
+            )
+            assert loop.run_until_complete(protocol.call_store(
+                f'{LOCALHOST}:{peer1_port}', keys=[nested_key], values=[hivemind.MSGPackSerializer.dumps(value2)],
+                expiration_time=[expiration + 5], subkeys=[subkey2])
+            )
+            recv_dict, recv_expiration, nodes_found = loop.run_until_complete(
+                protocol.call_find(f'{LOCALHOST}:{peer1_port}', [nested_key]))[nested_key]
+            assert isinstance(recv_dict, DictionaryDHTValue)
+            assert len(recv_dict.data) == 2 and recv_expiration == expiration + 5
+            assert recv_dict.data[subkey1] == (protocol.serializer.dumps(value1), expiration)
+            assert recv_dict.data[subkey2] == (protocol.serializer.dumps(value2), expiration + 5)
+
             if listen:
             if listen:
                 loop.run_until_complete(protocol.shutdown())
                 loop.run_until_complete(protocol.shutdown())
             print("DHTProtocol test finished successfully!")
             print("DHTProtocol test finished successfully!")
@@ -172,7 +191,8 @@ def test_dht_node():
         # note: we run everything in a separate process to re-initialize all global states from scratch
         # note: we run everything in a separate process to re-initialize all global states from scratch
         # this helps us avoid undesirable side-effects when running multiple tests in sequence
         # this helps us avoid undesirable side-effects when running multiple tests in sequence
         loop = asyncio.get_event_loop()
         loop = asyncio.get_event_loop()
-        me = loop.run_until_complete(DHTNode.create(initial_peers=random.sample(dht.keys(), 5), parallel_rpc=10))
+        me = loop.run_until_complete(DHTNode.create(initial_peers=random.sample(dht.keys(), 5), parallel_rpc=10,
+                                                    cache_refresh_before_expiry=False))
 
 
         # test 1: find self
         # test 1: find self
         nearest = loop.run_until_complete(me.find_nearest_nodes([me.node_id], k_nearest=1))[me.node_id]
         nearest = loop.run_until_complete(me.find_nearest_nodes([me.node_id], k_nearest=1))[me.node_id]
@@ -229,19 +249,24 @@ def test_dht_node():
         assert len(set.difference(set(nearest.keys()), set(all_node_ids) | {me.node_id})) == 0
         assert len(set.difference(set(nearest.keys()), set(all_node_ids) | {me.node_id})) == 0
 
 
         # test 5: node without peers
         # test 5: node without peers
-        other_node = loop.run_until_complete(DHTNode.create())
-        nearest = loop.run_until_complete(other_node.find_nearest_nodes([dummy]))[dummy]
-        assert len(nearest) == 1 and nearest[other_node.node_id] == f"{LOCALHOST}:{other_node.port}"
-        nearest = loop.run_until_complete(other_node.find_nearest_nodes([dummy], exclude_self=True))[dummy]
+        detached_node = loop.run_until_complete(DHTNode.create())
+        nearest = loop.run_until_complete(detached_node.find_nearest_nodes([dummy]))[dummy]
+        assert len(nearest) == 1 and nearest[detached_node.node_id] == f"{LOCALHOST}:{detached_node.port}"
+        nearest = loop.run_until_complete(detached_node.find_nearest_nodes([dummy], exclude_self=True))[dummy]
         assert len(nearest) == 0
         assert len(nearest) == 0
 
 
         # test 6 store and get value
         # test 6 store and get value
         true_time = get_dht_time() + 1200
         true_time = get_dht_time() + 1200
         assert loop.run_until_complete(me.store("mykey", ["Value", 10], true_time))
         assert loop.run_until_complete(me.store("mykey", ["Value", 10], true_time))
-        for node in [me, other_node]:
-            val, expiration_time = loop.run_until_complete(me.get("mykey"))
-            assert expiration_time == true_time, "Wrong time"
+        that_guy = loop.run_until_complete(DHTNode.create(initial_peers=random.sample(dht.keys(), 3), parallel_rpc=10,
+                                                          cache_refresh_before_expiry=False, cache_locally=False))
+
+        for node in [me, that_guy]:
+            val, expiration_time = loop.run_until_complete(node.get("mykey"))
             assert val == ["Value", 10], "Wrong value"
             assert val == ["Value", 10], "Wrong value"
+            assert expiration_time == true_time, f"Wrong time"
+
+        assert loop.run_until_complete(detached_node.get("mykey")) == (None, None)
 
 
         # test 7: bulk store and bulk get
         # test 7: bulk store and bulk get
         keys = 'foo', 'bar', 'baz', 'zzz'
         keys = 'foo', 'bar', 'baz', 'zzz'
@@ -252,6 +277,31 @@ def test_dht_node():
         for key, value in zip(keys, values):
         for key, value in zip(keys, values):
             assert key in response and response[key][0] == value
             assert key in response and response[key][0] == value
 
 
+        # test 8: store dictionaries as values (with sub-keys)
+        upper_key, subkey1, subkey2, subkey3 = 'ololo', 'k1', 'k2', 'k3'
+        now = get_dht_time()
+        assert loop.run_until_complete(me.store(upper_key, subkey=subkey1, value=123, expiration_time=now + 10))
+        assert loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=456, expiration_time=now + 20))
+        for node in [that_guy, me]:
+            value, time = loop.run_until_complete(node.get(upper_key))
+            assert isinstance(value, dict) and time == now + 20
+            assert value[subkey1] == (123, now + 10)
+            assert value[subkey2] == (456, now + 20)
+            assert len(value) == 2
+
+        assert not loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=345, expiration_time=now + 10))
+        assert loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=567, expiration_time=now + 30))
+        assert loop.run_until_complete(me.store(upper_key, subkey=subkey3, value=890, expiration_time=now + 50))
+        loop.run_until_complete(asyncio.sleep(0.1))  # wait for cache to refresh
+
+        for node in [that_guy, me]:
+            value, time = loop.run_until_complete(node.get(upper_key))
+            assert isinstance(value, dict) and time == now + 50, (value, time)
+            assert value[subkey1] == (123, now + 10)
+            assert value[subkey2] == (567, now + 30)
+            assert value[subkey3] == (890, now + 50)
+            assert len(value) == 3
+
         test_success.set()
         test_success.set()
 
 
     tester = mp.Process(target=_tester, daemon=True)
     tester = mp.Process(target=_tester, daemon=True)
@@ -262,6 +312,39 @@ def test_dht_node():
         proc.terminate()
         proc.terminate()
 
 
 
 
+def test_dhtnode_replicas():
+    dht_size = 20
+    initial_peers = 3
+    num_replicas = random.randint(1, 20)
+    test_success = mp.Event()
+
+    async def _tester():
+        peers = []
+        for i in range(dht_size):
+            neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(initial_peers, len(peers)))]
+            peers.append(await DHTNode.create(initial_peers=neighbors_i, num_replicas=num_replicas))
+
+        you = random.choice(peers)
+        assert await you.store('key1', 'foo', get_dht_time() + 999)
+
+        actual_key1_replicas = sum(len(peer.protocol.storage) for peer in peers)
+        assert num_replicas == actual_key1_replicas
+
+        assert await you.store('key2', 'bar', get_dht_time() + 999)
+        total_size = sum(len(peer.protocol.storage) for peer in peers)
+        actual_key2_replicas = total_size - actual_key1_replicas
+        assert num_replicas == actual_key2_replicas
+
+        assert await you.store('key2', 'baz', get_dht_time() + 1000)
+        assert sum(len(peer.protocol.storage) for peer in peers) == total_size, "total size should not have changed"
+        test_success.set()
+
+    proc = mp.Process(target=lambda: asyncio.run(_tester()))
+    proc.start()
+    proc.join()
+    assert test_success.is_set()
+
+
 def test_dhtnode_caching(T=0.05):
 def test_dhtnode_caching(T=0.05):
     test_success = mp.Event()
     test_success = mp.Event()
 
 

+ 60 - 10
tests/test_dht_storage.py

@@ -1,18 +1,19 @@
 import time
 import time
 
 
-from hivemind import DHTID, get_dht_time
-from hivemind.dht.protocol import LocalStorage
+from hivemind.dht.routing import get_dht_time
+from hivemind.dht.storage import DHTLocalStorage, DHTID, DictionaryDHTValue
+from hivemind.utils.serializer import MSGPackSerializer
 
 
 
 
 def test_store():
 def test_store():
-    d = LocalStorage()
+    d = DHTLocalStorage()
     d.store(DHTID.generate("key"), b"val", get_dht_time() + 0.5)
     d.store(DHTID.generate("key"), b"val", get_dht_time() + 0.5)
     assert d.get(DHTID.generate("key"))[0] == b"val", "Wrong value"
     assert d.get(DHTID.generate("key"))[0] == b"val", "Wrong value"
     print("Test store passed")
     print("Test store passed")
 
 
 
 
 def test_get_expired():
 def test_get_expired():
-    d = LocalStorage()
+    d = DHTLocalStorage()
     d.store(DHTID.generate("key"), b"val", get_dht_time() + 0.1)
     d.store(DHTID.generate("key"), b"val", get_dht_time() + 0.1)
     time.sleep(0.5)
     time.sleep(0.5)
     assert d.get(DHTID.generate("key")) == (None, None), "Expired value must be deleted"
     assert d.get(DHTID.generate("key")) == (None, None), "Expired value must be deleted"
@@ -20,13 +21,13 @@ def test_get_expired():
 
 
 
 
 def test_get_empty():
 def test_get_empty():
-    d = LocalStorage()
-    assert d.get(DHTID.generate(source="key")) == (None, None), "LocalStorage returned non-existent value"
+    d = DHTLocalStorage()
+    assert d.get(DHTID.generate(source="key")) == (None, None), "DHTLocalStorage returned non-existent value"
     print("Test get expired passed")
     print("Test get expired passed")
 
 
 
 
 def test_change_expiration_time():
 def test_change_expiration_time():
-    d = LocalStorage()
+    d = DHTLocalStorage()
     d.store(DHTID.generate("key"), b"val1", get_dht_time() + 1)
     d.store(DHTID.generate("key"), b"val1", get_dht_time() + 1)
     assert d.get(DHTID.generate("key"))[0] == b"val1", "Wrong value"
     assert d.get(DHTID.generate("key"))[0] == b"val1", "Wrong value"
     d.store(DHTID.generate("key"), b"val2", get_dht_time() + 200)
     d.store(DHTID.generate("key"), b"val2", get_dht_time() + 200)
@@ -36,7 +37,7 @@ def test_change_expiration_time():
 
 
 
 
 def test_maxsize_cache():
 def test_maxsize_cache():
-    d = LocalStorage(maxsize=1)
+    d = DHTLocalStorage(maxsize=1)
     d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 1)
     d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 1)
     d.store(DHTID.generate("key2"), b"val2", get_dht_time() + 200)
     d.store(DHTID.generate("key2"), b"val2", get_dht_time() + 200)
     assert d.get(DHTID.generate("key2"))[0] == b"val2", "Value with bigger exp. time must be kept"
     assert d.get(DHTID.generate("key2"))[0] == b"val2", "Value with bigger exp. time must be kept"
@@ -44,7 +45,7 @@ def test_maxsize_cache():
 
 
 
 
 def test_localstorage_top():
 def test_localstorage_top():
-    d = LocalStorage(maxsize=3)
+    d = DHTLocalStorage(maxsize=3)
     d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 1)
     d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 1)
     d.store(DHTID.generate("key2"), b"val2", get_dht_time() + 2)
     d.store(DHTID.generate("key2"), b"val2", get_dht_time() + 2)
     d.store(DHTID.generate("key3"), b"val3", get_dht_time() + 4)
     d.store(DHTID.generate("key3"), b"val3", get_dht_time() + 4)
@@ -61,8 +62,40 @@ def test_localstorage_top():
     assert d.top()[:2] == (DHTID.generate("key3"), b"val3")
     assert d.top()[:2] == (DHTID.generate("key3"), b"val3")
 
 
 
 
+def test_localstorage_nested():
+    time = get_dht_time()
+    d1 = DHTLocalStorage()
+    d2 = DictionaryDHTValue()
+    d2.store('subkey1', b'value1', time + 2)
+    d2.store('subkey2', b'value2', time + 3)
+    d2.store('subkey3', b'value3', time + 1)
+
+    assert d2.latest_expiration_time == time + 3
+    for subkey, subvalue, subexpiration in d2.items():
+        assert d1.store_subkey(DHTID.generate('foo'), subkey, subvalue, subexpiration)
+    assert d1.store(DHTID.generate('bar'), b'456', time + 2)
+    assert d1.get(DHTID.generate('foo'))[0].data == d2.data
+    assert d1.get(DHTID.generate('foo'))[1] == d2.latest_expiration_time
+    assert d1.get(DHTID.generate('foo'))[0].get('subkey1') == (b'value1', time + 2)
+    assert len(d1.get(DHTID.generate('foo'))[0]) == 3
+    assert d1.store_subkey(DHTID.generate('foo'), 'subkey4', b'value4', time + 4)
+    assert len(d1.get(DHTID.generate('foo'))[0]) == 4
+
+    assert d1.store_subkey(DHTID.generate('bar'), 'subkeyA', b'valueA', time + 1) is False  # prev has better expiration
+    assert d1.store_subkey(DHTID.generate('bar'), 'subkeyA', b'valueA', time + 3)  # new value has better expiration
+    assert d1.store_subkey(DHTID.generate('bar'), 'subkeyB', b'valueB', time + 4)  # new value has better expiration
+    assert d1.store_subkey(DHTID.generate('bar'), 'subkeyA', b'valueA+', time + 5)  # overwrite subkeyA under key bar
+    assert all(subkey in d1.get(DHTID.generate('bar'))[0] for subkey in ('subkeyA', 'subkeyB'))
+    assert len(d1.get(DHTID.generate('bar'))[0]) == 2 and d1.get(DHTID.generate('bar'))[1] == time + 5
+
+    assert d1.store(DHTID.generate('foo'), b'nothing', time + 3.5) is False  # previous value has better expiration
+    assert d1.get(DHTID.generate('foo'))[0].get('subkey2') == (b'value2', time + 3)
+    assert d1.store(DHTID.generate('foo'), b'nothing', time + 5) is True  # new value has better expiration
+    assert d1.get(DHTID.generate('foo')) == (b'nothing', time + 5)  # value should be replaced
+
+
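Note: the nested-storage test above doubles as a usage example for the new dictionary values. A condensed sketch follows; the class and method names come from the test, while the key and subkey names are made up for illustration.

    from hivemind.dht.routing import get_dht_time
    from hivemind.dht.storage import DHTLocalStorage, DHTID, DictionaryDHTValue

    now = get_dht_time()
    storage = DHTLocalStorage()
    # each subkey keeps its own expiration; the dictionary as a whole expires with its latest subkey
    storage.store_subkey(DHTID.generate('expert.3.14'), 'worker_A', b'endpoint_A', now + 60)
    storage.store_subkey(DHTID.generate('expert.3.14'), 'worker_B', b'endpoint_B', now + 90)
    value, expiration = storage.get(DHTID.generate('expert.3.14'))
    assert isinstance(value, DictionaryDHTValue) and len(value) == 2
    assert value.get('worker_A') == (b'endpoint_A', now + 60)
    assert expiration == now + 90  # == value.latest_expiration_time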
 def test_localstorage_freeze():
 def test_localstorage_freeze():
-    d = LocalStorage(maxsize=2)
+    d = DHTLocalStorage(maxsize=2)
 
 
     with d.freeze():
     with d.freeze():
         d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 0.01)
         d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 0.01)
@@ -77,3 +110,20 @@ def test_localstorage_freeze():
         d.store(DHTID.generate("key3"), b"val3", get_dht_time() + 3)  # key3 will push key1 out due to maxsize
         d.store(DHTID.generate("key3"), b"val3", get_dht_time() + 3)  # key3 will push key1 out due to maxsize
         assert DHTID.generate("key1") in d
         assert DHTID.generate("key1") in d
     assert DHTID.generate("key1") not in d
     assert DHTID.generate("key1") not in d
+
+
+def test_localstorage_serialize():
+    d1 = DictionaryDHTValue()
+    d2 = DictionaryDHTValue()
+
+    now = get_dht_time()
+    d1.store('key1', b'ololo', now + 1)
+    d2.store('key2', b'pysh', now + 1)
+    d2.store('key3', b'pyshpysh', now + 2)
+
+    data = MSGPackSerializer.dumps([d1, d2, 123321])
+    assert isinstance(data, bytes)
+    new_d1, new_d2, new_value = MSGPackSerializer.loads(data)
+    assert isinstance(new_d1, DictionaryDHTValue) and isinstance(new_d2, DictionaryDHTValue) and new_value == 123321
+    assert 'key1' in new_d1 and len(new_d1) == 1
+    assert 'key1' not in new_d2 and len(new_d2) == 2 and new_d2.get('key3') == (b'pyshpysh', now + 2)
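Note: as the serialization test above shows, DictionaryDHTValue round-trips through MSGPackSerializer alongside ordinary msgpack values, which suggests it can be sent over the wire like the primitive value types. A condensed round-trip mirroring that test:

    from hivemind.dht.routing import get_dht_time
    from hivemind.dht.storage import DictionaryDHTValue
    from hivemind.utils.serializer import MSGPackSerializer

    now = get_dht_time()
    original = DictionaryDHTValue()
    original.store('subkey', b'payload', now + 10)

    data = MSGPackSerializer.dumps([original, 42])   # dictionaries can sit inside regular msgpack containers
    restored, plain_int = MSGPackSerializer.loads(data)
    assert isinstance(restored, DictionaryDHTValue) and plain_int == 42
    assert restored.get('subkey') == (b'payload', now + 10)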

+ 0 - 2
tests/test_moe.py

@@ -15,8 +15,6 @@ def test_moe():
     with background_server(expert_uids=all_expert_uids, device='cpu', expert_cls='ffn',
     with background_server(expert_uids=all_expert_uids, device='cpu', expert_cls='ffn',
                            num_handlers=1, hidden_dim=16) as (server_endpoint, dht_endpoint):
                            num_handlers=1, hidden_dim=16) as (server_endpoint, dht_endpoint):
         dht = hivemind.DHT(start=True, expiration=999, initial_peers=[dht_endpoint])
         dht = hivemind.DHT(start=True, expiration=999, initial_peers=[dht_endpoint])
-        # declare expert uids. Server *should* declare them by itself, but it takes time.
-        assert all(dht.declare_experts(all_expert_uids, endpoint=server_endpoint))
 
 
         dmoe = hivemind.RemoteMixtureOfExperts(
         dmoe = hivemind.RemoteMixtureOfExperts(
             in_features=16, grid_size=(32, 32, 32), dht=dht, k_best=3, uid_prefix='ffn')
             in_features=16, grid_size=(32, 32, 32), dht=dht, k_best=3, uid_prefix='ffn')