
Faster beam search through DHT sub-keys (#107)

* Added support for dictionary-like DHT keys 
* New beam search based on dictionary prefixes (see the usage sketch after this list)
* DHTNode now uses a more careful caching policy on store. If a value was rejected, the node will request a new value to update its cache
* Add tests for dictionary value types (storage, protocol, node)
* LocalStorage is moved to a separate file and generalized for new value types
* Fixed minor bug where DHTProtocol.store claimed to return None but didn't
* More tests
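
Below is a minimal usage sketch of the new API (illustrative only: the endpoint, grid shape and scores are made up, and the DHT(start=True) constructor is assumed from the library's usual entry point; declare_experts and find_best_experts follow the signatures in the diff below):

import hivemind

dht = hivemind.DHT(start=True)  # assumption: standalone DHT process, no initial_peers

# declare a 2x2 grid of experts with uids "ffn.<i>.<j>", all served from one made-up endpoint
expert_uids = [f"ffn.{i}.{j}" for i in range(2) for j in range(2)]
dht.declare_experts(expert_uids, endpoint="127.0.0.1:1337")

# one score vector per grid dimension; beam search walks prefixes "ffn.<i>" -> "ffn.<i>.<j>"
grid_scores = [[0.1, 0.9], [0.5, 0.4]]
best_experts = dht.find_best_experts("ffn", grid_scores, beam_size=2)
print([expert.uid for expert in best_experts])  # up to beam_size RemoteExpert instances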
justheuristic 4 years ago
parent
revision
9f9c4ac96b

+ 5 - 0
docs/modules/dht.rst

@@ -10,6 +10,11 @@ Here's a high level scheme of how these components interact with one another:
    :width: 640
    :align: center
 
+
+**Note:** hivemind.DHT is currently being updated to improve beam search latency
+(see `issue 92 <https://github.com/learning-at-home/hivemind/issues>`__). New functionality will be documented
+here by 2020.10.15 23:59:59 AOE (ping justheuristic for details).
+
 DHT and DHTNode
 ###############
 

+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.utils import *
 
-__version__ = '0.8.4'
+__version__ = '0.8.5'

+ 146 - 13
hivemind/dht/__init__.py

@@ -14,11 +14,12 @@ The code is organized as follows:
 """
 import asyncio
 import ctypes
+import heapq
 import multiprocessing as mp
 import warnings
 from collections import deque, OrderedDict
 from concurrent.futures import ThreadPoolExecutor
-from typing import List, Tuple, Optional, Sequence, OrderedDict as TOrderedDict, Union, Awaitable
+from typing import List, Tuple, Optional, Sequence, OrderedDict as TOrderedDict, Union, Awaitable, Dict, Deque, Set
 
 import uvloop
 
@@ -67,7 +68,6 @@ class DHT(mp.Process):
     This beam search explores one additional dimension per step and finds k best experts from across the DHT
     in O(k / s * log(N)) average time where s is grid sparsity rate and N is the total number of experts.
     """
-
UID_DELIMITER = '.'  # when declaring experts, DHT stores all prefixes of that expert's uid, split over this delimiter
     #  formally, prefixes = {uid.split(UID_DELIMITER)[:length] for length in range(1, uid.count(UID_DELIMITER) + 2)}
 
@@ -147,7 +147,9 @@ class DHT(mp.Process):
             expiration_time = get_dht_time()
         num_workers = len(uids) if self.max_workers is None else min(len(uids), self.max_workers)
         response = await node.get_many(uids, expiration_time, num_workers=num_workers)
-        future.set_result([RemoteExpert(**expert_data) if maybe_expiration_time else None
+        # TODO expert_data['expert'] -> namedtuple with meaningful field names
+        future.set_result([RemoteExpert(*expert_data['expert'][0])
+                           if maybe_expiration_time is not None and expert_data['expert'][1] is not None else None
                            for uid, (expert_data, maybe_expiration_time) in response.items()])
 
     def declare_experts(self, uids: List[str], endpoint: Endpoint, wait=True, timeout=None) -> Optional[List[bool]]:
@@ -169,18 +171,148 @@ class DHT(mp.Process):
     async def _declare_experts(self, node: DHTNode, uids: List[str], endpoint: Endpoint, future: Optional[MPFuture]):
         num_workers = len(uids) if self.max_workers is None else min(len(uids), self.max_workers)
         expiration_time = get_dht_time() + self.expiration
-
-        data_to_store = {}
-        for uid in uids:
+        unique_entries: Set[Tuple[str, str]] = set()
+        #                 prefix---v next_dim     uid  endpoint
+        data_to_store: List[Tuple[str, str, List[str, Endpoint]]] = []
+        for uid in uids:  # first k entries are expert uids themselves
+            data_to_store.append((uid, "expert", [uid, endpoint]))
+        for uid in uids:  # and then, add all prefixes
             uid_parts = uid.split(self.UID_DELIMITER)
-            for i in range(len(uid_parts)):
+            for i in range(len(uid_parts) - 1):
                 uid_prefix_i = self.UID_DELIMITER.join(uid_parts[:i + 1])
-                data_to_store[uid_prefix_i] = {'uid': uid, 'endpoint': endpoint}
+                if (uid_prefix_i, uid_parts[i + 1]) in unique_entries:
+                    continue
+                unique_entries.add((uid_prefix_i, uid_parts[i + 1]))
+                data_to_store.append((uid_prefix_i, uid_parts[i + 1], [uid, endpoint]))
+
+        keys, subkeys, values = map(list, zip(*data_to_store))
+        store_ok = await node.store_many(keys, values, expiration_time, subkeys=subkeys, num_workers=num_workers)
+        if future is not None:
+            future.set_result([store_ok[key, subkey] for key, subkey in zip(keys, subkeys)])
+
+    def find_best_experts(self, prefix: str, grid_scores: Sequence[Sequence[float]], beam_size: int, *,
+                          return_future=False, **kwargs) -> Union[List[RemoteExpert], MPFuture]:
+        """
+        Find and return :beam_size: active experts with the highest scores, using both local cache and DHT
+
+        :param prefix: common prefix for all expert uids in grid
+        :param grid_scores: scores predicted for each dimension in the grid,
+        :type grid_scores: model scores for each grid dimension, list of arrays of shape grid_size[i]
+        :param beam_size: how many best experts should beam search return
+         Please note that all intermediate DHT queries made during the search are cached
+         for subsequent iterations as long as DHTNode.cache_locally is True
+        :param return_future: if set to True, returns MPFuture that can be awaited to get the actual result
+        :param kwargs: extra keyword parameters passed to DHTNode.get_many
+        :returns: a list that contains *up to* :beam_size: RemoteExpert instances
+        """
+        future, _future = MPFuture.make_pair()
+        self.pipe.send(('_find_best_experts', [], dict(prefix=prefix, grid_scores=list(map(tuple, grid_scores)),
+                                                       beam_size=beam_size, future=_future, **kwargs)))
+        return future if return_future else future.result()
+
+    async def _find_best_experts(
+            self, node: DHTNode, prefix: str, grid_scores: List[Tuple[float]], beam_size: int,
+            max_workers: Optional[int] = None, future: Optional[MPFuture] = None, **kwargs) -> List[RemoteExpert]:
+        max_workers: Optional[int] = max_workers or self.max_workers or beam_size
+
+        # form initial beam from top-k active L1 prefixes, each row is (score, uid prefix, possible suffixes)
+        beam: List[Tuple[float, str, Dict[str, List[str, Endpoint]]]] = await self._get_initial_beam(
+            node, prefix, beam_size, grid_scores[0], num_workers=min(beam_size, max_workers))
+        if not beam:
+            logger.warning(f"Beam search had to terminate prematurely because of empty beam (dim 0)")
+            return []
+
+        for dim_index in range(1, len(grid_scores) - 1):
+            # select beam_size best suffixes from current beam
+            dim_scores = grid_scores[dim_index]
+            best_active_pairs: List[Tuple[float, str]] = heapq.nlargest(beam_size, (
+                (prefix_score + dim_scores[int(suffix_i)], f"{prefix}{self.UID_DELIMITER}{suffix_i}")
+                for prefix_score, prefix, suffixes in beam for suffix_i in suffixes.keys()
+                # TODO get rid of str.isdecimal
+                if str.isdecimal(suffix_i) and 0 <= int(suffix_i) < len(dim_scores)))
+
+            # search DHT for next step suffixes
+            _, best_uid_prefixes = zip(*best_active_pairs)
+            # TODO Tuple[Dict[str, List[str, Endpoint]], DHTExpiration] -> namedtuple
+            dht_responses: Dict[str, Tuple[Dict[str, List[str, Endpoint]], DHTExpiration]] = await node.get_many(
+                keys=best_uid_prefixes, num_workers=min(len(best_uid_prefixes), max_workers), **kwargs)
+            if all(expiration is None for key, (_, expiration) in dht_responses.items()):
+                logger.warning(f"Beam search had to terminate prematurely because of empty beam (dim {dim_index})")
+                break
+            beam = [(prefix_score, prefix, dht_responses[prefix][0])  # add suffix dict if it is found
+                    for prefix_score, prefix in best_active_pairs if dht_responses[prefix][1] is not None]
+
+        # select best experts from the final beam
+        dim_scores = grid_scores[-1]
+        final_best_pairs: List[Tuple[float, str, Endpoint]] = heapq.nlargest(beam_size, (
+            (prefix_score + dim_scores[int(suffix_i)], uid, endpoint)
+            for prefix_score, prefix, suffixes in beam for suffix_i, ((uid, endpoint), _) in suffixes.items()
+            if str.isdecimal(suffix_i) and 0 <= int(suffix_i) < len(dim_scores)
+        ))
+        best_experts = [RemoteExpert(uid, endpoint) for score, uid, endpoint in final_best_pairs]
+        if future is not None:
+            future.set_result(best_experts)
+        return best_experts
+
+    def batch_find_best_experts(self, prefix: str, batch_grid_scores: Sequence[Sequence[Sequence[float]]], beam_size: int, *,
+                                return_future=False, **kwargs) -> Union[List[List[RemoteExpert]], MPFuture]:
+        """
+        For each example in a batch, find and return :beam_size: active experts with the highest scores, using both local cache and DHT
+
+        :param prefix: common prefix for all expert uids in grid
+        :param batch_grid_scores: scores predicted for each batch example and each dimension in the grid,
+        :type batch_grid_scores: model scores for each example and each grid dimension,  list of arrays of shape (batch_size, grid_size[i])
+        :param beam_size: how many best experts should beam search return for each batch example
+         Please note that all intermediate DHT queries made during the search are cached
+         for subsequent iterations as long as DHTNode.cache_locally is True
+        :param return_future: if set to True, returns MPFuture that can be awaited to get the actual result
+        :param kwargs: extra keyword parameters passed to DHTNode.get_many
+        :returns: a list of lists, each containing *up to* :beam_size: RemoteExpert instances for the corresponding example
+        """
+        future, _future = MPFuture.make_pair()
+        self.pipe.send(('_batch_find_best_experts', [], dict(prefix=prefix, batch_grid_scores=batch_grid_scores,
+                                                             beam_size=beam_size, future=_future, **kwargs)))
+        return future if return_future else future.result()
+
+    async def _batch_find_best_experts(
+            self, node: DHTNode, prefix: str, batch_grid_scores: Sequence[Sequence[Tuple[float]]], beam_size: int,
+            max_workers: Optional[int] = None, future: Optional[MPFuture] = None, **kwargs) -> List[List[RemoteExpert]]:
+
+        batch_grid_scores = [[tuple(grid_score[i]) for grid_score in batch_grid_scores] for i in range(len(batch_grid_scores[0]))]
+        coros = [self._find_best_experts(node, prefix, grid_scores, beam_size, max_workers, **kwargs) for grid_scores in batch_grid_scores]
 
-        store_keys, store_values = zip(*data_to_store.items())
-        store_ok = await node.store_many(store_keys, store_values, expiration_time, num_workers=num_workers)
+        best_experts_batch = await asyncio.gather(*coros)
         if future is not None:
-            future.set_result([store_ok[key] for key in data_to_store.keys()])
+            future.set_result(best_experts_batch)
+        return best_experts_batch
+
+    async def _get_initial_beam(self, node, prefix: str, beam_size: int, scores: Tuple[float, ...], num_workers: int
+                                ) -> List[Tuple[float, str, Dict[str, List[str]]]]:
+        """ Fetch a list of all active level-one prefixes of a given prefix. Used for beam search """
+        beam: List[Tuple[float, str, Dict[str, List[str, Endpoint]]]] = []  # results will be stored here
+        unattempted_indices: List[int] = sorted(range(len(scores)), key=scores.__getitem__)  # order: worst to best
+        pending_tasks: Deque[Tuple[int, str, asyncio.Task]] = deque()  # up to num_workers concurrent get tasks
+
+        while len(beam) < beam_size and (unattempted_indices or pending_tasks):
+            # dispatch additional tasks
+            while unattempted_indices and len(pending_tasks) < num_workers:
+                next_index = unattempted_indices.pop()  # note: this is best unattempted index because of sort order
+                next_best_prefix = f"{prefix}{self.UID_DELIMITER}{next_index}"
+                pending_tasks.append((next_index, next_best_prefix, asyncio.create_task(node.get(next_best_prefix))))
+
+            # await the next best prefix to be fetched
+            pending_best_index, pending_best_prefix, pending_task = pending_tasks.popleft()
+            try:
+                maybe_prefix_data, maybe_expiration_time = await pending_task
+                if maybe_expiration_time is not None:
+                    beam.append((scores[pending_best_index], pending_best_prefix, maybe_prefix_data))
+            except asyncio.CancelledError:
+                for _, pending_task in pending_tasks:
+                    pending_task.cancel()
+                raise
+        return beam
 
     def first_k_active(
             self, uid_prefixes: List[str], k: int, max_prefetch: int = 1, chunk_size: Optional[int] = None,
@@ -196,6 +328,7 @@ class DHT(mp.Process):
         :returns: a ordered dict{uid_prefix -> RemoteExpert} mapping at most :k: prefixes to matching experts
             The keys in the returned dict are ordered same as in uid_prefixes.
         """
+        logger.warning("first_k_active is deprecated and will be removed in 0.8.6")
         assert not isinstance(uid_prefixes, str), "please provide a list/tuple of prefixes as the first argument"
         future, _future = MPFuture.make_pair()
         self.pipe.send(('_first_k_active', [],
@@ -220,8 +353,8 @@ class DHT(mp.Process):
             response = await pending_tasks.popleft()
             for uid_prefix in uid_prefixes[chunk_i * chunk_size: (chunk_i + 1) * chunk_size]:
                 maybe_expert_data, maybe_expiration_time = response[uid_prefix]
-                if maybe_expiration_time is not None:  # found active peer
-                    found.append((uid_prefix, RemoteExpert(**maybe_expert_data)))
+                if maybe_expiration_time is not None and len(maybe_expert_data) > 0:  # found active peer
+                    found.append((uid_prefix, RemoteExpert(*next(iter(maybe_expert_data.values()))[0])))
                     # if we found enough active experts, finish immediately
                     if len(found) >= k:
                         break

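
For reference, this is roughly the key / sub-key layout that _declare_experts above builds for a single expert uid (the uid and endpoint are made-up examples). Each prefix key maps to a dictionary of {next_dimension_index: ([uid, endpoint], expiration)}, which is what _get_initial_beam and _find_best_experts read back during beam search:

# illustrative only: entries appended to data_to_store for uid "ffn.3.7" served at "127.0.0.1:1337"
uid, endpoint = "ffn.3.7", "127.0.0.1:1337"
data_to_store = [
    ("ffn.3.7", "expert", [uid, endpoint]),  # the expert itself, under the special "expert" sub-key
    ("ffn",     "3",      [uid, endpoint]),  # level-1 prefix: sub-key is the index along the next grid dimension
    ("ffn.3",   "7",      [uid, endpoint]),  # level-2 prefix: sub-key is the index along the next grid dimension
]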
+ 147 - 110
hivemind/dht/node.py

@@ -1,17 +1,18 @@
 from __future__ import annotations
 
 import asyncio
-
 import random
 from collections import defaultdict
 from dataclasses import dataclass, field
-from typing import Optional, Tuple, List, Dict, DefaultDict, Collection, Union, Set, Awaitable, Callable, Any, Iterable
-from sortedcontainers import SortedList
 from functools import partial
+from typing import Optional, Tuple, List, Dict, DefaultDict, Collection, Union, Set, Awaitable, Callable, Any
 from warnings import warn
 
-from hivemind.dht.protocol import DHTProtocol, LocalStorage
-from hivemind.dht.routing import DHTID, DHTExpiration, DHTKey, get_dht_time, DHTValue, BinaryDHTValue
+from sortedcontainers import SortedList
+
+from hivemind.dht.protocol import DHTProtocol
+from hivemind.dht.routing import DHTID, DHTExpiration, DHTKey, get_dht_time, DHTValue, BinaryDHTValue, Subkey
+from hivemind.dht.storage import CacheRefreshQueue, DictionaryDHTValue
 from hivemind.dht.traverse import traverse_dht
 from hivemind.utils import Endpoint, LOCALHOST, MSGPackSerializer, get_logger, SerializerBase
 
@@ -39,7 +40,7 @@ class DHTNode:
         nearest peers from recipient's routing table (ordered nearest-to-farthest, not including recipient itself)
         This RPC is a mixture between Kademlia FIND_NODE and FIND_VALUE with multiple keys per call.
 
-    Formally, DHTNode follows the following contract:
+    A DHTNode adheres to the following contract:
 
     - when asked to get(key), a node must find and return a value with highest expiration time that it found across DHT
       IF that time has not come yet. if expiration time is smaller than current get_dht_time(), node may return None;
@@ -53,9 +54,8 @@ class DHTNode:
     # fmt:off
     node_id: DHTID; is_alive: bool; port: int; num_replicas: int; num_workers: int; protocol: DHTProtocol
     refresh_timeout: float; cache_locally: bool; cache_nearest: int; cache_refresh_before_expiry: float
-    cache_refresh_available: asyncio.Event; cache_refresh_queue: LocalStorage
-    reuse_get_requests: bool; pending_get_requests: DefaultDict[DHTID, SortedList[_IntermediateResult]]
-    serializer = MSGPackSerializer  # used to pack/unpack DHT Values for transfer over network
+    cache_on_store: bool; reuse_get_requests: bool; pending_get_requests: DefaultDict[DHTID, SortedList[_SearchState]]
+    cache_refresh_task: Optional[asyncio.Task]; cache_refresh_evt: asyncio.Event; cache_refresh_queue: CacheRefreshQueue
     # fmt:on
 
     @classmethod
@@ -64,8 +64,8 @@ class DHTNode:
             bucket_size: int = 20, num_replicas: int = 5, depth_modulo: int = 5, parallel_rpc: int = None,
             wait_timeout: float = 5, refresh_timeout: Optional[float] = None, bootstrap_timeout: Optional[float] = None,
             cache_locally: bool = True, cache_nearest: int = 1, cache_size=None, cache_refresh_before_expiry: float = 5,
-            reuse_get_requests: bool = True, num_workers: int = 1, listen: bool = True,
-            listen_on: Endpoint = "0.0.0.0:*", **kwargs) -> DHTNode:
+            cache_on_store: bool = True, reuse_get_requests: bool = True, num_workers: int = 1,
+            listen: bool = True, listen_on: Endpoint = "0.0.0.0:*", **kwargs) -> DHTNode:
         """
         :param node_id: current node's identifier, determines which keys it will store locally, defaults to random id
         :param initial_peers: connects to these peers to populate routing table, defaults to no peers
@@ -81,6 +81,7 @@ class DHTNode:
           if staleness_timeout is None, DHTNode will not refresh stale buckets (which is usually okay)
         :param bootstrap_timeout: after one of peers responds, await other peers for at most this many seconds
         :param cache_locally: if True, caches all values (stored or found) in a node-local cache
+        :param cache_on_store: if True, update cache entries for a key after storing a new item for that key
         :param cache_nearest: whenever DHTNode finds a value, it will also store (cache) this value on this many
           nodes nearest nodes visited by search algorithm. Prefers nodes that are nearest to :key: but have no value yet
         :param cache_size: if specified, local cache will store up to this many records (as in LRU cache)
@@ -96,10 +97,6 @@ class DHTNode:
           see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
         :param kwargs: extra parameters used in grpc.aio.server
         """
-        if cache_refresh_before_expiry > 0 and not cache_locally:
-            logger.warning("If cache_locally is False, cache_refresh_before_expiry has no effect. To silence this"
-                           " warning, please specify cache_refresh_before_expiry=0")
-
         self = cls(_initialized_with_create=True)
         self.node_id = node_id = node_id if node_id is not None else DHTID.generate()
         self.num_replicas, self.num_workers = num_replicas, num_workers
@@ -110,12 +107,11 @@ class DHTNode:
 
         # caching policy
         self.refresh_timeout = refresh_timeout
-        self.cache_locally, self.cache_nearest = cache_locally, cache_nearest
+        self.cache_locally, self.cache_nearest, self.cache_on_store = cache_locally, cache_nearest, cache_on_store
         self.cache_refresh_before_expiry = cache_refresh_before_expiry
-        self.cache_refresh_queue = LocalStorage()
-        self.cache_refresh_available = asyncio.Event()
-        if cache_refresh_before_expiry:
-            asyncio.create_task(self._refresh_stale_cache_entries())
+        self.cache_refresh_queue = CacheRefreshQueue()
+        self.cache_refresh_evt = asyncio.Event()
+        self.cache_refresh_task = None
 
         self.protocol = await DHTProtocol.create(self.node_id, bucket_size, depth_modulo, num_replicas, wait_timeout,
                                                  parallel_rpc, cache_size, listen, listen_on, **kwargs)
@@ -211,25 +207,27 @@ class DHTNode:
             nearest_nodes_with_endpoints[query] = {node: node_to_endpoint[node] for node in nearest_nodes[:k_nearest]}
         return nearest_nodes_with_endpoints
 
-    async def store(self, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration, **kwargs) -> bool:
+    async def store(self, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration,
+                    subkey: Optional[Subkey] = None, **kwargs) -> bool:
         """
         Find num_replicas best nodes to store (key, value) and store it there at least until expiration time.
-
         :note: store is a simplified interface to store_many, all kwargs are forwarded there
         :returns: True if store succeeds, False if it fails (due to no response or newer value)
         """
-        store_ok = await self.store_many([key], [value], [expiration_time], **kwargs)
-        return store_ok[key]
+        store_ok = await self.store_many([key], [value], [expiration_time], subkeys=[subkey], **kwargs)
+        return store_ok[(key, subkey) if subkey is not None else key]
 
     async def store_many(self, keys: List[DHTKey], values: List[DHTValue],
                          expiration_time: Union[DHTExpiration, List[DHTExpiration]],
+                         subkeys: Optional[Union[Subkey, List[Optional[Subkey]]]] = None,
                          exclude_self: bool = False, await_all_replicas=True, **kwargs) -> Dict[DHTKey, bool]:
         """
-        Traverse DHT to find up to best nodes to store multiple (key, value, expiration_time) pairs.
+        Traverse DHT to find up to :num_replicas: best nodes to store multiple (key, value, expiration_time) pairs.
 
         :param keys: arbitrary serializable keys associated with each value
         :param values: serializable "payload" for each key
         :param expiration_time: either one expiration time for all keys or individual expiration times (see class doc)
+        :param subkeys: an optional list of the same shape as keys. If specified, each value is stored under its
+          respective subkey as part of a dictionary value for that key; entries with subkey=None are stored normally
         :param kwargs: any additional parameters passed to traverse_dht function (e.g. num workers)
         :param exclude_self: if True, never store value locally even if you are one of the nearest nodes
         :note: if exclude_self is True and self.cache_locally == True, value will still be __cached__ locally
@@ -239,24 +237,23 @@ class DHTNode:
         """
         if isinstance(expiration_time, DHTExpiration):
             expiration_time = [expiration_time] * len(keys)
-        assert len(keys) == len(values) == len(expiration_time), "Number of keys, values and expiration doesn't match."
+        if subkeys is None or isinstance(subkeys, Subkey):
+            subkeys = [subkeys] * len(keys)
 
-        key_ids = list(map(DHTID.generate, keys))
-        id_to_original_key = dict(zip(key_ids, keys))
-        binary_values_by_key_id = {key_id: self.serializer.dumps(value) for key_id, value in zip(key_ids, values)}
-        expiration_by_key_id = {key_id: expiration_time for key_id, expiration_time in zip(key_ids, expiration_time)}
-        unfinished_key_ids = set(key_ids)  # we use this set to ensure that each store request is finished
+        assert len(keys) == len(subkeys) == len(values) == len(expiration_time), \
+            "Keys, values, subkeys and expiration timestamps must have the same sequence length"
 
-        store_ok = {key: False for key in keys}  # outputs, updated during search
-        store_finished_events = {key: asyncio.Event() for key in keys}
+        key_id_to_data: DefaultDict[DHTID, List[Tuple[DHTKey, Subkey, DHTValue, DHTExpiration]]] = defaultdict(list)
+        for key, subkey, value, expiration in zip(keys, subkeys, values, expiration_time):
+            key_id_to_data[DHTID.generate(source=key)].append((key, subkey, value, expiration))
 
-        if self.cache_locally:
-            for key_id in key_ids:
-                self.protocol.cache.store(key_id, binary_values_by_key_id[key_id], expiration_by_key_id[key_id])
+        unfinished_key_ids = set(key_id_to_data.keys())  # use this set to ensure that each store request is finished
+        store_ok = {(key, subkey): None for key, subkey in zip(keys, subkeys)}  # outputs, updated during search
+        store_finished_events = {(key, subkey): asyncio.Event() for key, subkey in zip(keys, subkeys)}
 
         # pre-populate node_to_endpoint
         node_to_endpoint: Dict[DHTID, Endpoint] = dict()
-        for key_id in key_ids:
+        for key_id in unfinished_key_ids:
             node_to_endpoint.update(self.protocol.routing_table.get_nearest_neighbors(
                 key_id, self.protocol.bucket_size, exclude=self.node_id))
 
@@ -272,48 +269,73 @@ class DHTNode:
             pending_store_tasks = set()
             store_candidates = sorted(nearest_nodes + ([] if exclude_self else [self.node_id]),
                                       key=key_id.xor_distance, reverse=True)  # ordered so that .pop() returns nearest
+            [original_key, *_], current_subkeys, current_values, current_expirations = zip(*key_id_to_data[key_id])
+            binary_values: List[bytes] = list(map(self.protocol.serializer.dumps, current_values))
 
             while num_successful_stores < self.num_replicas and (store_candidates or pending_store_tasks):
-                # spawn enough tasks to cover all replicas
                 while store_candidates and num_successful_stores + len(pending_store_tasks) < self.num_replicas:
                     node_id: DHTID = store_candidates.pop()  # nearest untried candidate
+
                     if node_id == self.node_id:
-                        self.protocol.storage.store(key_id, binary_values_by_key_id[key_id],
-                                                    expiration_by_key_id[key_id])
-                        store_ok[id_to_original_key[key_id]] = True
                         num_successful_stores += 1
-                        if not await_all_replicas:
-                            store_finished_events[id_to_original_key[key_id]].set()
-
+                        for subkey, value, expiration_time in zip(current_subkeys, binary_values, current_expirations):
+                            store_ok[original_key, subkey] = self.protocol.storage.store(
+                                key_id, value, expiration_time, subkey=subkey)
+                            if not await_all_replicas:
+                                store_finished_events[original_key, subkey].set()
                     else:
                         pending_store_tasks.add(asyncio.create_task(self.protocol.call_store(
-                            node_to_endpoint[node_id], [key_id], [binary_values_by_key_id[key_id]],
-                            [expiration_by_key_id[key_id]])))
+                            node_to_endpoint[node_id], keys=[key_id] * len(current_values), values=binary_values,
+                            expiration_time=current_expirations, subkeys=current_subkeys)))
 
                 # await nearest task. If it fails, dispatch more on the next iteration
                 if pending_store_tasks:
                     finished_store_tasks, pending_store_tasks = await asyncio.wait(
                         pending_store_tasks, return_when=asyncio.FIRST_COMPLETED)
                     for task in finished_store_tasks:
-                        if task.result()[0]:  # if store succeeded
-                            store_ok[id_to_original_key[key_id]] = True
+                        if task.result() is not None:
                             num_successful_stores += 1
-                            if not await_all_replicas:
-                                store_finished_events[id_to_original_key[key_id]].set()
+                            for subkey, store_status in zip(current_subkeys, task.result()):
+                                store_ok[original_key, subkey] = store_status
+                                if not await_all_replicas:
+                                    store_finished_events[original_key, subkey].set()
 
-            store_finished_events[id_to_original_key[key_id]].set()
+            if self.cache_on_store:
+                self._update_cache_on_store(key_id, current_subkeys, binary_values, current_expirations,
+                                            store_ok=[store_ok[original_key, subkey] for subkey in current_subkeys])
+
+            for subkey, value_bytes, expiration in zip(current_subkeys, binary_values, current_expirations):
+                store_finished_events[original_key, subkey].set()
 
         store_task = asyncio.create_task(self.find_nearest_nodes(
-            queries=set(key_ids), k_nearest=self.num_replicas, node_to_endpoint=node_to_endpoint,
+            queries=set(unfinished_key_ids), k_nearest=self.num_replicas, node_to_endpoint=node_to_endpoint,
             found_callback=on_found, exclude_self=exclude_self, **kwargs))
         try:
             await asyncio.wait([evt.wait() for evt in store_finished_events.values()])  # wait for items to be stored
             assert len(unfinished_key_ids) == 0, "Internal error: traverse_dht didn't finish search"
-            return store_ok
+            return {(key, subkey) if subkey else key: status or False for (key, subkey), status in store_ok.items()}
         except asyncio.CancelledError as e:
             store_task.cancel()
             raise e
 
+    def _update_cache_on_store(self, key_id: DHTID, subkeys: List[Subkey], binary_values: List[bytes],
+                               expirations: List[DHTExpiration], store_ok: List[bool]):
+        """ Update local cache after finishing a store for one key (with perhaps several subkeys) """
+        store_succeeded = any(store_ok)
+        is_dictionary = any(subkey is not None for subkey in subkeys)
+        if store_succeeded and not is_dictionary:  # stored a new regular value, cache it!
+            stored_value_bytes, stored_expiration = max(zip(binary_values, expirations), key=lambda p: p[1])
+            self.protocol.cache.store(key_id, stored_value_bytes, stored_expiration)
+        elif not store_succeeded and not is_dictionary:  # store rejected, check if local cache is also obsolete
+            rejected_value, rejected_expiration = max(zip(binary_values, expirations), key=lambda p: p[1])
+            self.protocol.cache.store(key_id, rejected_value, rejected_expiration)  # can still be better than cache
+            if (self.protocol.cache.get(key_id)[1] or float("inf")) <= rejected_expiration:  # cache would be rejected
+                self._schedule_for_refresh(key_id, refresh_time=get_dht_time())  # fetch new key in background (asap)
+        else:  # stored a dictionary (or failed to store), either way, there can be other keys and we should update
+            for subkey, stored_value_bytes, expiration_time in zip(subkeys, binary_values, expirations):
+                self.protocol.cache.store_subkey(key_id, subkey, stored_value_bytes, expiration_time)
+            self._schedule_for_refresh(key_id, refresh_time=get_dht_time())  # fetch new key in background (asap)
+
     async def get(self, key: DHTKey, latest=False, **kwargs) -> Tuple[Optional[DHTValue], Optional[DHTExpiration]]:
         """
         Search for a key across DHT and return either first or latest entry.
@@ -350,8 +372,8 @@ class DHTNode:
     async def get_many_by_id(
             self, key_ids: Collection[DHTID], sufficient_expiration_time: Optional[DHTExpiration] = None,
             num_workers: Optional[int] = None, beam_size: Optional[int] = None, return_futures: bool = False,
-            _refresh_cache=True) -> Dict[DHTID, Union[Tuple[Optional[DHTValue], Optional[DHTExpiration]],
-                                                      Awaitable[Tuple[Optional[DHTValue], Optional[DHTExpiration]]]]]:
+            _is_refresh=False) -> Dict[DHTID, Union[Tuple[Optional[DHTValue], Optional[DHTExpiration]],
+                                                    Awaitable[Tuple[Optional[DHTValue], Optional[DHTExpiration]]]]]:
         """
         Traverse DHT to find a list of DHTIDs. For each key, return latest (value, expiration) or None if not found.
 
@@ -364,17 +386,17 @@ class DHTNode:
         :param return_futures: if True, immediately return an asyncio.Future for every key before interacting with the network.
          The algorithm will populate these futures with (value, expiration) when it finds the corresponding key
          Note: canceling a future will stop search for the corresponding key
-        :param _refresh_cache: internal flag, whether or not to self._trigger_cache_refresh
+        :param _is_refresh: internal flag, set to True by an internal cache refresher (if enabled)
         :returns: for each key: value and its expiration time. If nothing is found, returns (None, None) for that key
         :note: in order to check if get returned a value, please check (expiration_time is None)
         """
         sufficient_expiration_time = sufficient_expiration_time or get_dht_time()
         beam_size = beam_size if beam_size is not None else self.protocol.bucket_size
         num_workers = num_workers if num_workers is not None else self.num_workers
-        search_results: Dict[DHTID, _IntermediateResult] = {key_id: _IntermediateResult(
-            key_id, sufficient_expiration_time, serializer=self.serializer) for key_id in key_ids}
+        search_results: Dict[DHTID, _SearchState] = {key_id: _SearchState(
+            key_id, sufficient_expiration_time, serializer=self.protocol.serializer) for key_id in key_ids}
 
-        if _refresh_cache:
+        if not _is_refresh:  # if we're already refreshing cache, there's no need to trigger subsequent refreshes
             for key_id in key_ids:
                 search_results[key_id].add_done_callback(self._trigger_cache_refresh)
 
@@ -387,7 +409,8 @@ class DHTNode:
         # stage 1: check for value in this node's local storage and cache
         for key_id in key_ids:
             search_results[key_id].add_candidate(*self.protocol.storage.get(key_id), source_node_id=self.node_id)
-            search_results[key_id].add_candidate(*self.protocol.cache.get(key_id), source_node_id=self.node_id)
+            if not _is_refresh:
+                search_results[key_id].add_candidate(*self.protocol.cache.get(key_id), source_node_id=self.node_id)
 
         # stage 2: traverse the DHT to get the remaining keys from remote peers
         unfinished_key_ids = [key_id for key_id in key_ids if not search_results[key_id].finished]
@@ -414,7 +437,7 @@ class DHTNode:
         # V-- this function will be called exactly once when traverse_dht finishes search for a given key
         async def found_callback(key_id: DHTID, nearest_nodes: List[DHTID], _visited: Set[DHTID]):
             search_results[key_id].finish_search()  # finish search whether or not we found something
-            self._cache_new_result(search_results[key_id], nearest_nodes, node_to_endpoint)
+            self._cache_new_result(search_results[key_id], nearest_nodes, node_to_endpoint, _is_refresh=_is_refresh)
 
         asyncio.create_task(traverse_dht(
             queries=list(unfinished_key_ids), initial_nodes=list(node_to_endpoint),
@@ -433,9 +456,9 @@ class DHTNode:
                     search_result.future.cancel()
                 raise e
 
-    def _reuse_finished_search_result(self, finished: _IntermediateResult):
+    def _reuse_finished_search_result(self, finished: _SearchState):
         expiration_time_threshold = max(finished.expiration_time or -float('inf'), finished.sufficient_expiration_time)
-        concurrent_requests: SortedList[_IntermediateResult] = self.pending_get_requests[finished.key_id]
+        concurrent_requests: SortedList[_SearchState] = self.pending_get_requests[finished.key_id]
         # note: concurrent_requests is sorted in the order of descending sufficient_expiration_time
         while concurrent_requests and expiration_time_threshold >= concurrent_requests[-1].sufficient_expiration_time:
             concurrent_requests[-1].add_candidate(finished.binary_value, finished.expiration_time,
@@ -443,66 +466,72 @@ class DHTNode:
             concurrent_requests[-1].finish_search()
             concurrent_requests.pop(-1)
 
-    def _trigger_cache_refresh(self, result: _IntermediateResult):
+    def _trigger_cache_refresh(self, search: _SearchState):
         """ Called after get request is finished (whether it was found, not found, hit cache, cancelled, or reused) """
-        if result.found_something and result.source_node_id == self.node_id:
-            with self.protocol.cache.freeze():  # do not clear outdated cache for now...
-                if self.cache_refresh_before_expiry and result.key_id in self.protocol.cache:
-                    previous_earliest_item: Tuple[DHTID, BinaryDHTValue, DHTExpiration] = self.cache_refresh_queue.top()
-                    self.cache_refresh_queue.store(result.key_id, result.binary_value, result.expiration_time)
-                    if previous_earliest_item is None or result.expiration_time < previous_earliest_item[-1]:
-                        self.cache_refresh_available.set()  # if we new element is now earliest, notify the cache queue
+        if search.found_something and search.source_node_id == self.node_id:
+            if self.cache_refresh_before_expiry and search.key_id in self.protocol.cache:
+                self._schedule_for_refresh(search.key_id, search.expiration_time - self.cache_refresh_before_expiry)
+
+    def _schedule_for_refresh(self, key_id: DHTID, refresh_time: DHTExpiration):
+        """ Add key to a refresh queue, refresh at :refresh_time: or later """
+        if self.cache_refresh_task is None or self.cache_refresh_task.done() or self.cache_refresh_task.cancelled():
+            self.cache_refresh_task = asyncio.create_task(self._refresh_stale_cache_entries())
+            logger.debug("Spawned cache refresh task.")
+        previous_earliest_item: Tuple[DHTID, Any, DHTExpiration] = self.cache_refresh_queue.top()
+        if previous_earliest_item is None or refresh_time < previous_earliest_item[-1]:
+            self.cache_refresh_evt.set()  # if the new element is now earliest, notify the cache queue
+        self.cache_refresh_queue.store(key_id, value=refresh_time, expiration_time=refresh_time)
 
     async def _refresh_stale_cache_entries(self):
         """ periodically refresh near-expired keys that were accessed at least once during previous lifetime """
         while self.is_alive:
-            with self.cache_refresh_queue.freeze():
-                while len(self.cache_refresh_queue) == 0:
-                    await self.cache_refresh_available.wait()
-                    self.cache_refresh_available.clear()
-                key_id, _, nearest_expiration = self.cache_refresh_queue.top()
+            while len(self.cache_refresh_queue) == 0:
+                await self.cache_refresh_evt.wait()
+                self.cache_refresh_evt.clear()
+            key_id, _, nearest_refresh_time = self.cache_refresh_queue.top()
 
             try:
                 # step 1: await until :cache_refresh_before_expiry: seconds before earliest first element expires
-                time_to_wait = nearest_expiration - get_dht_time() - self.cache_refresh_before_expiry
-                await asyncio.wait_for(self.cache_refresh_available.wait(), timeout=time_to_wait)
+                time_to_wait = nearest_refresh_time - get_dht_time()
+                await asyncio.wait_for(self.cache_refresh_evt.wait(), timeout=time_to_wait)
                 # note: the line above will cause TimeoutError when we are ready to refresh cache
-                self.cache_refresh_available.clear()  # no timeout error => someone added new entry to queue and ...
+                self.cache_refresh_evt.clear()  # no timeout error => someone added new entry to queue and ...
                 continue  # ... and this element is earlier than nearest_expiration. we should refresh this entry first
 
             except asyncio.TimeoutError:  # caught TimeoutError => it is time to refresh the most recent cached entry
                 # step 2: find all keys that we should already refresh and remove them from queue
-                with self.cache_refresh_queue.freeze():
-                    keys_to_refresh = {key_id}
+                current_time = get_dht_time()
+                keys_to_refresh = {key_id}
+                max_expiration_time = self.protocol.cache.get(key_id)[1] or current_time
+                del self.cache_refresh_queue[key_id]  # we pledge to refresh this key_id in the nearest batch
+                while self.cache_refresh_queue:
+                    key_id, _, nearest_refresh_time = self.cache_refresh_queue.top()
+                    if nearest_refresh_time > current_time:
+                        break
                     del self.cache_refresh_queue[key_id]  # we pledge to refresh this key_id in the nearest batch
-                    while self.cache_refresh_queue:
-                        key_id, _, nearest_expiration = self.cache_refresh_queue.top()
-                        if nearest_expiration > get_dht_time() + self.cache_refresh_before_expiry:
-                            break
-                        del self.cache_refresh_queue[key_id]  # we pledge to refresh this key_id in the nearest batch
-                        keys_to_refresh.add(key_id)
+                    keys_to_refresh.add(key_id)
+                    max_expiration_time = max(max_expiration_time, self.protocol.cache.get(key_id)[1] or current_time)
 
                 # step 3: search newer versions of these keys, cache them as a side-effect of self.get_many_by_id
-                await self.get_many_by_id(
-                    keys_to_refresh, sufficient_expiration_time=nearest_expiration + self.cache_refresh_before_expiry,
-                    _refresh_cache=False)  # if we found value locally, we shouldn't trigger another refresh
+                sufficient_expiration_time = max_expiration_time + self.cache_refresh_before_expiry + 1
+                await self.get_many_by_id(keys_to_refresh, sufficient_expiration_time, _is_refresh=True)
 
-    def _cache_new_result(self, result: _IntermediateResult, nearest_nodes: List[DHTID],
-                          node_to_endpoint: Dict[DHTID, Endpoint]):
+    def _cache_new_result(self, search: _SearchState, nearest_nodes: List[DHTID],
+                          node_to_endpoint: Dict[DHTID, Endpoint], _is_refresh: bool = False):
         """ after key_id is found, update cache according to caching policy. used internally in get and get_many """
-        if result.found_something:
-            previous_expiration_time = max(self.protocol.storage.get(result.key_id)[1] or -float('inf'),
-                                           self.protocol.cache.get(result.key_id)[1] or -float('inf'))
-            if result.expiration_time > previous_expiration_time:  # if this value has better expiration
-                if self.cache_locally:
-                    self.protocol.cache.store(result.key_id, result.binary_value, result.expiration_time)
+        if search.found_something:
+            previous_expiration_time = max(self.protocol.storage.get(search.key_id)[1] or -float('inf'),
+                                           self.protocol.cache.get(search.key_id)[1] or -float('inf'))
+            if search.expiration_time > previous_expiration_time:  # if this value has better expiration
+                if self.cache_locally or _is_refresh:
+                    self.protocol.cache.store(search.key_id, search.binary_value, search.expiration_time)
                 if self.cache_nearest:
                     num_cached_nodes = 0
                     for node_id in nearest_nodes:
-                        if node_id == result.source_node_id:
+                        if node_id == search.source_node_id:
                             continue
                         asyncio.create_task(self.protocol.call_store(
-                            node_to_endpoint[node_id], [result.key_id], [result.binary_value], [result.expiration_time],
+                            node_to_endpoint[node_id], [search.key_id], [search.binary_value], [search.expiration_time],
                             in_cache=True))
                         num_cached_nodes += 1
                         if num_cached_nodes >= self.cache_nearest:
@@ -523,11 +552,11 @@ class DHTNode:
 
 
 @dataclass(init=True, repr=True, frozen=False, order=False)
-class _IntermediateResult:
+class _SearchState:
     """ A helper class that stores current-best GET results with metadata """
     key_id: DHTID
     sufficient_expiration_time: DHTExpiration
-    binary_value: Optional[BinaryDHTValue] = None
+    binary_value: Optional[Union[BinaryDHTValue, DictionaryDHTValue]] = None
     expiration_time: Optional[DHTExpiration] = None  # best expiration time so far
     source_node_id: Optional[DHTID] = None  # node that gave us the value
     future: asyncio.Future[Tuple[Optional[DHTValue], Optional[DHTExpiration]]] = field(default_factory=asyncio.Future)
@@ -540,25 +569,33 @@ class _IntermediateResult:
             if self.expiration_time >= self.sufficient_expiration_time:
                 self.finish_search()
 
-    def add_done_callback(self, callback: Callable[[_IntermediateResult], Any]):
-        """ Add callback that will be called when _IntermediateSearchResult is done (found OR cancelled by user) """
+    def add_done_callback(self, callback: Callable[[_SearchState], Any]):
+        """ Add callback that will be called when _SearchState is done (found OR cancelled by user) """
         self.future.add_done_callback(lambda _future: callback(self))
 
     def finish_search(self):
         if self.future.done():
-            return  # either user cancelled our result or someone sent it before us. Nothing more to do here.
-        deserialized_value = self.serializer.loads(self.binary_value) if self.found_something else None
-        self.future.set_result((deserialized_value, self.expiration_time))
+            return  # either user cancelled our search or someone sent it before us. Nothing more to do here.
+        elif not self.found_something:
+            self.future.set_result((None, None))
+        elif isinstance(self.binary_value, BinaryDHTValue):
+            self.future.set_result((self.serializer.loads(self.binary_value), self.expiration_time))
+        elif isinstance(self.binary_value, DictionaryDHTValue):
+            dict_value = {key: (self.serializer.loads(value), item_expiration_time)
+                          for key, value, item_expiration_time in self.binary_value.items()}
+            self.future.set_result((dict_value, self.expiration_time))
+        else:
+            logger.error(f"Invalid value type: {type(self.binary_value)}")
 
     @property
     def found_something(self) -> bool:
-        """ Whether or not we have at least some result, regardless of its expiration time """
+        """ Whether or not we have found at least some value, regardless of its expiration time """
         return self.expiration_time is not None
 
     @property
     def finished(self) -> bool:
         return self.future.done()
 
-    def __lt__(self, other: _IntermediateResult):
-        """ _IntermediateResult instances will be sorted by their target expiration time """
+    def __lt__(self, other: _SearchState):
+        """ _SearchState instances will be sorted by their target expiration time """
         return self.sufficient_expiration_time < other.sufficient_expiration_time

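
A short sketch of the node-level behaviour after this change (a hypothetical standalone node; key names, values and the 60-second expiration are made up): several sub-keys stored under one key come back from get() as a dictionary of {subkey: (value, expiration)}, assembled by _SearchState.finish_search above.

import asyncio
from hivemind.dht.node import DHTNode
from hivemind.dht.routing import get_dht_time

async def demo():
    node = await DHTNode.create()  # standalone node without initial_peers, for illustration only
    expiration = get_dht_time() + 60

    # store two sub-keys under the same key; each sub-key keeps its own expiration time
    store_ok = await node.store_many(keys=["grid.0", "grid.0"], values=["expertA", "expertB"],
                                     expiration_time=expiration, subkeys=["3", "7"])
    print(store_ok)  # expected: {("grid.0", "3"): True, ("grid.0", "7"): True}

    # reading the key back returns the entire dictionary value
    value, expiration_time = await node.get("grid.0", latest=True)
    print(value)  # expected: {"3": ("expertA", <expiration>), "7": ("expertB", <expiration>)}

asyncio.run(demo())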
+ 75 - 130
hivemind/dht/protocol.py

@@ -2,26 +2,28 @@
 from __future__ import annotations
 
 import asyncio
-import heapq
-from contextlib import contextmanager
-from typing import Optional, List, Tuple, Dict, Iterator, Any, Sequence, Union, Collection
+from typing import Optional, List, Tuple, Dict, Any, Sequence, Union, Collection
 from warnings import warn
 
 import grpc
 import grpc.experimental.aio
 
-from hivemind.dht.routing import RoutingTable, DHTID, BinaryDHTValue, DHTExpiration, get_dht_time
+from hivemind.dht.routing import RoutingTable, DHTID, BinaryDHTValue, DHTExpiration, Subkey
+from hivemind.dht.storage import DHTLocalStorage, DictionaryDHTValue
 from hivemind.proto import dht_pb2, dht_pb2_grpc as dht_grpc
-from hivemind.utils import Endpoint, get_logger, replace_port
+from hivemind.utils import Endpoint, get_logger, replace_port, MSGPackSerializer
 
 logger = get_logger(__name__)
+NOT_FOUND_VALUE, NOT_FOUND_EXPIRATION, IS_REGULAR_VALUE, IS_DICTIONARY = b'', -float('inf'), '', '___DictionaryDHTValue'
+RESERVED_SUBKEYS = {IS_REGULAR_VALUE, IS_DICTIONARY}
 
 
 class DHTProtocol(dht_grpc.DHTServicer):
     # fmt:off
     node_id: DHTID; port: int; bucket_size: int; num_replicas: int; wait_timeout: float; node_info: dht_pb2.NodeInfo
     channel_options: Optional[Sequence[Tuple[str, Any]]]; server: grpc.experimental.aio.Server
-    storage: LocalStorage; cache: LocalStorage; routing_table: RoutingTable; rpc_semaphore: asyncio.Semaphore
+    storage: DHTLocalStorage; cache: DHTLocalStorage; routing_table: RoutingTable; rpc_semaphore: asyncio.Semaphore
+    serializer = MSGPackSerializer  # used to pack/unpack DHT Values for transfer over network
     # fmt:on
 
     @classmethod
@@ -44,7 +46,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
         self = cls(_initialized_with_create=True)
         self.node_id, self.bucket_size, self.num_replicas = node_id, bucket_size, num_replicas
         self.wait_timeout, self.channel_options = wait_timeout, channel_options
-        self.storage, self.cache = LocalStorage(), LocalStorage(maxsize=cache_size)
+        self.storage, self.cache = DHTLocalStorage(), DHTLocalStorage(maxsize=cache_size)
         self.routing_table = RoutingTable(node_id, bucket_size, depth_modulo)
         self.rpc_semaphore = asyncio.Semaphore(parallel_rpc if parallel_rpc is not None else float('inf'))
 
@@ -110,30 +112,46 @@ class DHTProtocol(dht_grpc.DHTServicer):
             asyncio.create_task(self.update_routing_table(sender_id, rpc_endpoint))
         return self.node_info
 
-    async def call_store(self, peer: Endpoint, keys: Sequence[DHTID], values: Sequence[BinaryDHTValue],
+    async def call_store(self, peer: Endpoint, keys: Sequence[DHTID],
+                         values: Sequence[Union[BinaryDHTValue, DictionaryDHTValue]],
                          expiration_time: Union[DHTExpiration, Sequence[DHTExpiration]],
-                         in_cache: Optional[Union[bool, Sequence[bool]]] = None) -> Sequence[bool]:
+                         subkeys: Optional[Union[Subkey, Sequence[Optional[Subkey]]]] = None,
+                         in_cache: Optional[Union[bool, Sequence[bool]]] = None) -> Optional[List[bool]]:
         """
         Ask a recipient to store several (key, value : expiration_time) items or update their older value
 
         :param peer: request this peer to store the data
         :param keys: a list of N keys digested by DHTID.generate(source=some_dict_key)
         :param values: a list of N serialized values (bytes) for each respective key
-        :param expiration_time: a list of N expiration timestamps for each respective key-value pair (see get_dht_time())
+        :param expiration_time: a list of N expiration timestamps for each respective key-value pair (see get_dht_time())
+        :param subkeys: a list of N optional sub-keys. If None, stores value normally. If subkey is not None:
+          1) if local storage doesn't have :key:, create a new dictionary {subkey: (value, expiration_time)}
+          2) if local storage already has a dictionary under :key:, try to add (subkey, value, exp_time) to that dictionary
+          3) if local storage associates :key: with a normal value with smaller expiration, clear :key: and perform (1)
+          4) finally, if local storage currently associates :key: with a normal value with larger expiration, do nothing
         :param in_cache: a list of booleans, True = store i-th key in cache, False = store i-th key locally
         :note: the difference between storing normally and in cache is that normal storage is guaranteed to be stored
          until expiration time (best-effort), whereas cached storage can be evicted early due to limited cache size
-
         :return: list of [True / False] True = stored, False = failed (found newer value or no response)
-         if peer did not respond (e.g. due to timeout or congestion), returns None
+                 if peer did not respond (e.g. due to timeout or congestion), returns None
         """
         if isinstance(expiration_time, DHTExpiration):
             expiration_time = [expiration_time] * len(keys)
+        if subkeys is None or isinstance(subkeys, Subkey):
+            subkeys = [subkeys] * len(keys)
+
         in_cache = in_cache if in_cache is not None else [False] * len(keys)  # default value (None)
         in_cache = [in_cache] * len(keys) if isinstance(in_cache, bool) else in_cache  # single bool
-        keys, values, expiration_time, in_cache = map(list, [keys, values, expiration_time, in_cache])
+        keys, subkeys, values, expiration_time, in_cache = map(list, [keys, subkeys, values, expiration_time, in_cache])
+        for i in range(len(keys)):
+            if subkeys[i] is None:  # add default sub-key if not specified
+                subkeys[i] = IS_REGULAR_VALUE if not isinstance(values[i], DictionaryDHTValue) else IS_DICTIONARY
+            if isinstance(values[i], DictionaryDHTValue):
+                assert subkeys[i] == IS_DICTIONARY, "Please do not specify subkey when storing an entire dictionary"
+                values[i] = self.serializer.dumps(values[i])
+
         assert len(keys) == len(values) == len(expiration_time) == len(in_cache), "Data is not aligned"
-        store_request = dht_pb2.StoreRequest(keys=list(map(DHTID.to_bytes, keys)), values=values,
+        store_request = dht_pb2.StoreRequest(keys=list(map(DHTID.to_bytes, keys)), subkeys=subkeys, values=values,
                                              expiration_time=expiration_time, in_cache=in_cache, peer=self.node_info)
         try:
             async with self.rpc_semaphore:
@@ -145,7 +163,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
         except grpc.experimental.aio.AioRpcError as error:
             logger.warning(f"DHTProtocol failed to store at {peer}: {error.code()}")
             asyncio.create_task(self.update_routing_table(self.routing_table.get(endpoint=peer), peer, responded=False))
-            return [False] * len(keys)
+            return None
 
     async def rpc_store(self, request: dht_pb2.StoreRequest, context: grpc.ServicerContext) -> dht_pb2.StoreResponse:
         """ Some node wants us to store this (key, value) pair """
@@ -153,10 +171,19 @@ class DHTProtocol(dht_grpc.DHTServicer):
             asyncio.create_task(self.rpc_ping(request.peer, context))
         assert len(request.keys) == len(request.values) == len(request.expiration_time) == len(request.in_cache)
         response = dht_pb2.StoreResponse(store_ok=[], peer=self.node_info)
-        for key_bytes, value_bytes, expiration_time, in_cache in zip(
-                request.keys, request.values, request.expiration_time, request.in_cache):
-            local_memory = self.cache if in_cache else self.storage
-            response.store_ok.append(local_memory.store(DHTID.from_bytes(key_bytes), value_bytes, expiration_time))
+        keys = map(DHTID.from_bytes, request.keys)
+        for key_id, subkey, value_bytes, expiration_time, in_cache in zip(
+                keys, request.subkeys, request.values, request.expiration_time, request.in_cache):
+            storage = self.cache if in_cache else self.storage
+            if subkey == IS_REGULAR_VALUE:  # store normal value without subkeys
+                response.store_ok.append(storage.store(key_id, value_bytes, expiration_time))
+            elif subkey == IS_DICTIONARY:   # store an entire dictionary with pre-existing subkeys
+                value_dictionary = self.serializer.loads(value_bytes)
+                assert isinstance(value_dictionary, DictionaryDHTValue)
+                response.store_ok.append(all(storage.store_subkey(key_id, subkey, subvalue, subkey_expiration)
+                                             for subkey, subvalue, subkey_expiration in value_dictionary.items()))
+            else:  # add new entry into an existing dictionary-like value or create a new dictionary with one sub-key
+                response.store_ok.append(storage.store_subkey(key_id, subkey, value_bytes, expiration_time))
         return response
 
     async def call_find(self, peer: Endpoint, keys: Collection[DHTID]) -> \
@@ -179,16 +206,19 @@ class DHTProtocol(dht_grpc.DHTServicer):
             if response.peer and response.peer.node_id:
                 peer_id = DHTID.from_bytes(response.peer.node_id)
                 asyncio.create_task(self.update_routing_table(peer_id, peer, responded=True))
-            assert len(response.values) == len(response.expiration_time) == len(response.nearest) == len(keys), \
-                "DHTProtocol: response is not aligned with keys and/or expiration times"
-
-            output = {}  # unpack data without special NOT_FOUND_* values
-            for key, value, expiration_time, nearest in zip(
-                    keys, response.values, response.expiration_time, response.nearest):
-                value = value if value != _NOT_FOUND_VALUE else None
-                expiration_time = expiration_time if expiration_time != _NOT_FOUND_EXPIRATION else None
-                nearest = dict(zip(map(DHTID.from_bytes, nearest.node_ids), nearest.endpoints))
-                output[key] = (value, expiration_time, nearest)
+            assert len(keys) == len(response.results), "DHTProtocol: response is not aligned with keys"
+
+            output = {}  # unpack data depending on its type
+            for key, result in zip(keys, response.results):
+                nearest = dict(zip(map(DHTID.from_bytes, result.nearest_node_ids), result.nearest_endpoints))
+                if result.type == dht_pb2.NOT_FOUND:
+                    output[key] = None, None, nearest
+                elif result.type == dht_pb2.FOUND_REGULAR:
+                    output[key] = result.value, result.expiration_time, nearest
+                elif result.type == dht_pb2.FOUND_DICTIONARY:
+                    output[key] = self.serializer.loads(result.value), result.expiration_time, nearest
+                else:
+                    logger.error(f"Unknown result type: {result.type}")
             return output
         except grpc.experimental.aio.AioRpcError as error:
             logger.warning(f"DHTProtocol failed to find at {peer}: {error.code()}")
@@ -201,24 +231,27 @@ class DHTProtocol(dht_grpc.DHTServicer):
         """
         if request.peer:  # if requested, add peer to the routing table
             asyncio.create_task(self.rpc_ping(request.peer, context))
-
-        response = dht_pb2.FindResponse(values=[], expiration_time=[], nearest=[], peer=self.node_info)
-        for key_id in map(DHTID.from_bytes, request.keys):
+        response = dht_pb2.FindResponse(results=[], peer=self.node_info)
+        for i, key_id in enumerate(map(DHTID.from_bytes, request.keys)):
             maybe_value, maybe_expiration_time = self.storage.get(key_id)
             cached_value, cached_expiration_time = self.cache.get(key_id)
             if (cached_expiration_time or -float('inf')) > (maybe_expiration_time or -float('inf')):
                 maybe_value, maybe_expiration_time = cached_value, cached_expiration_time
 
-            nearest_neighbors = self.routing_table.get_nearest_neighbors(
-                key_id, k=self.bucket_size, exclude=DHTID.from_bytes(request.peer.node_id))
-            if nearest_neighbors:
-                peer_ids, endpoints = zip(*nearest_neighbors)
-            else:
-                peer_ids, endpoints = [], []
-
-            response.values.append(maybe_value if maybe_value is not None else _NOT_FOUND_VALUE)
-            response.expiration_time.append(maybe_expiration_time if maybe_expiration_time else _NOT_FOUND_EXPIRATION)
-            response.nearest.append(dht_pb2.Peers(node_ids=list(map(DHTID.to_bytes, peer_ids)), endpoints=endpoints))
+            if maybe_expiration_time is None:  # value not found
+                item = dht_pb2.FindResult(type=dht_pb2.NOT_FOUND)
+            elif isinstance(maybe_value, DictionaryDHTValue):
+                item = dht_pb2.FindResult(type=dht_pb2.FOUND_DICTIONARY, value=self.serializer.dumps(maybe_value),
+                                          expiration_time=maybe_value.latest_expiration_time)
+            else:  # found regular value
+                item = dht_pb2.FindResult(type=dht_pb2.FOUND_REGULAR, value=maybe_value,
+                                          expiration_time=maybe_expiration_time)
+
+            for node_id, endpoint in self.routing_table.get_nearest_neighbors(
+                key_id, k=self.bucket_size, exclude=DHTID.from_bytes(request.peer.node_id)):
+                item.nearest_node_ids.append(node_id.to_bytes())
+                item.nearest_endpoints.append(endpoint)
+            response.results.append(item)
         return response
 
     async def update_routing_table(self, node_id: Optional[DHTID], peer_endpoint: Endpoint, responded=True):
@@ -256,91 +289,3 @@ class DHTProtocol(dht_grpc.DHTServicer):
         else:  # we sent outgoing request and peer did not respond
             if node_id is not None and node_id in self.routing_table:
                 del self.routing_table[node_id]
-
-
-_NOT_FOUND_VALUE, _NOT_FOUND_EXPIRATION = b'', -float('inf')  # internal values to represent that a value was not found
-
-
-class LocalStorage:
-    """ Local dictionary that maintains up to :maxsize: tuples of (key, value, expiration_time) """
-
-    def __init__(self, maxsize: Optional[int] = None):
-        self.cache_size = maxsize or float("inf")
-        self.data: Dict[DHTID, Tuple[BinaryDHTValue, DHTExpiration]] = dict()
-        self.expiration_heap: List[Tuple[DHTExpiration, DHTID]] = []
-        self.key_to_heap: Dict[DHTID, Tuple[DHTExpiration, DHTID]] = dict()
-        self.frozen = False  # if True, do not remove outdated elements
-
-    def _remove_outdated(self):
-        while not self.frozen and self.expiration_heap and (self.expiration_heap[0][0] < get_dht_time()
-                                                            or len(self.expiration_heap) > self.cache_size):
-            heap_entry = heapq.heappop(self.expiration_heap)
-            key = heap_entry[1]
-            if self.key_to_heap.get(key) == heap_entry:
-                del self.data[key], self.key_to_heap[key]
-
-    def store(self, key: DHTID, value: BinaryDHTValue, expiration_time: DHTExpiration) -> bool:
-        """
-        Store a (key, value) pair locally at least until expiration_time. See class docstring for details.
-        :returns: True if new value was stored, False it was rejected (current value is newer)
-        """
-        if expiration_time < get_dht_time() and not self.frozen:
-            return False
-        self.key_to_heap[key] = (expiration_time, key)
-        heapq.heappush(self.expiration_heap, (expiration_time, key))
-        if key in self.data:
-            if self.data[key][1] < expiration_time:
-                self.data[key] = (value, expiration_time)
-                return True
-            return False
-        self.data[key] = (value, expiration_time)
-        self._remove_outdated()
-        return True
-
-    def get(self, key: DHTID) -> (Optional[BinaryDHTValue], Optional[DHTExpiration]):
-        """ Get a value corresponding to a key if that (key, value) pair was previously stored here. """
-        self._remove_outdated()
-        if key in self.data:
-            return self.data[key]
-        return None, None
-
-    def items(self) -> Iterator[Tuple[DHTID, BinaryDHTValue, DHTExpiration]]:
-        """ Iterate over (key, value, expiration_time) tuples stored in this storage """
-        self._remove_outdated()
-        return ((key, value, expiration_time) for key, (value, expiration_time) in self.data.items())
-
-    def top(self) -> Optional[Tuple[DHTID, BinaryDHTValue, DHTExpiration]]:
-        """ Return the entry with earliest expiration or None if there isn't any """
-        self._remove_outdated()
-        if self.data:
-            top_entry, top_key = self.expiration_heap[0], self.expiration_heap[0][1]
-            while self.key_to_heap.get(top_key) != top_entry:
-                heapq.heappop(self.expiration_heap)  # skip leftover "ghost" entries until first real entry
-                top_entry, top_key = self.expiration_heap[0], self.expiration_heap[0][1]
-            value, expiration = self.data[top_key]
-            return top_key, value, expiration
-
-    def __contains__(self, key: DHTID):
-        self._remove_outdated()
-        return key in self.data
-
-    def __len__(self):
-        self._remove_outdated()
-        return len(self.data)
-
-    def __delitem__(self, key: DHTID):
-        if key in self.key_to_heap:
-            del self.data[key], self.key_to_heap[key]
-        # note: key may still be in self.expiration_heap, but it will not be used and eventually ._remove_outdated()
-
-    def __bool__(self):
-        return bool(self.data)
-
-    @contextmanager
-    def freeze(self):
-        """ Temporarily cease to ._remove_outdated() elements inside this context to ensure consistency """
-        prev_frozen, self.frozen = self.frozen, True
-        try:
-            yield self
-        finally:
-            self.frozen = prev_frozen
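
To make the new store/find calls above easier to follow, here is a minimal usage sketch (not part of this diff), based on the calls exercised in tests/test_dht_node.py below. It assumes a running DHTProtocol instance `protocol` and the endpoint `peer` of another node; key and sub-key names are placeholders.

# sketch: store two sub-keys under one DHT key, then fetch them back as a DictionaryDHTValue
from hivemind.dht.routing import DHTID, get_dht_time
from hivemind.dht.storage import DictionaryDHTValue

async def store_and_find_subkeys(protocol, peer):
    key, expiration = DHTID.generate('my_dictionary_key'), get_dht_time() + 60

    # each call returns [True] / [False] per key, or None if the peer did not respond
    assert await protocol.call_store(
        peer, keys=[key], values=[protocol.serializer.dumps('value_a')],
        expiration_time=[expiration], subkeys=['subkey_a'])
    assert await protocol.call_store(
        peer, keys=[key], values=[protocol.serializer.dumps('value_b')],
        expiration_time=[expiration + 5], subkeys=['subkey_b'])

    # call_find returns (value, expiration_time, nearest peers) per key;
    # for dictionary-typed keys the value is a deserialized DictionaryDHTValue
    found_value, found_expiration, nearest = (await protocol.call_find(peer, [key]))[key]
    assert isinstance(found_value, DictionaryDHTValue) and found_expiration == expiration + 5
    assert found_value.get('subkey_a') == (protocol.serializer.dumps('value_a'), expiration)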

+ 1 - 1
hivemind/dht/routing.py

@@ -12,7 +12,7 @@ from typing import Tuple, Optional, List, Dict, Set, Union, Any, Sequence
 
 from hivemind.utils import Endpoint, PickleSerializer
 
-DHTKey, DHTValue, DHTExpiration, BinaryDHTID, BinaryDHTValue, = Any, Any, float, bytes, bytes  # flavour types
+DHTKey, Subkey, DHTValue, DHTExpiration, BinaryDHTID, BinaryDHTValue, = Any, str, Any, float, bytes, bytes
 get_dht_time = time.time  # time used by all dht functionality. You can replace this with any infrastructure-wide time
 
 

+ 161 - 0
hivemind/dht/storage.py

@@ -0,0 +1,161 @@
+from __future__ import annotations
+import heapq
+from contextlib import contextmanager
+from typing import Generic, Optional, Dict, Tuple, List, Iterator, TypeVar, Union, Any
+
+from hivemind.dht.routing import DHTID, DHTExpiration, get_dht_time, BinaryDHTValue, Subkey
+from hivemind.utils.serializer import MSGPackSerializer
+
+KeyType = TypeVar('KeyType')
+ValueType = TypeVar('ValueType')
+
+
+class TimedStorage(Generic[KeyType, ValueType]):
+    """ A dictionary that maintains up to :maxsize: key-value-expiration tuples until their expiration_time """
+    frozen = False  # can be set to True. If true, do not remove outdated elements
+
+    def __init__(self, maxsize: Optional[int] = None):
+        self.maxsize = maxsize or float("inf")
+        self.data: Dict[KeyType, Tuple[ValueType, DHTExpiration]] = dict()
+        self.expiration_heap: List[Tuple[DHTExpiration, KeyType]] = []
+        self.key_to_heap: Dict[KeyType, Tuple[DHTExpiration, KeyType]] = dict()
+
+    def _remove_outdated(self):
+        while not self.frozen and self.expiration_heap and (self.expiration_heap[0][0] < get_dht_time()
+                                                            or len(self.expiration_heap) > self.maxsize):
+            heap_entry = heapq.heappop(self.expiration_heap)
+            key = heap_entry[1]
+            if self.key_to_heap.get(key) == heap_entry:
+                del self.data[key], self.key_to_heap[key]
+
+    def store(self, key: KeyType, value: ValueType, expiration_time: DHTExpiration) -> bool:
+        """
+        Store a (key, value) pair locally at least until expiration_time. See class docstring for details.
+        :returns: True if the new value was stored, False if it was rejected (the current value is newer)
+        """
+        if expiration_time < get_dht_time() and not self.frozen:
+            return False
+        self.key_to_heap[key] = (expiration_time, key)
+        heapq.heappush(self.expiration_heap, (expiration_time, key))
+        if key in self.data:
+            if self.data[key][1] < expiration_time:
+                self.data[key] = (value, expiration_time)
+                return True
+            return False
+        self.data[key] = (value, expiration_time)
+        self._remove_outdated()
+        return True
+
+    def get(self, key: KeyType) -> (Optional[ValueType], Optional[DHTExpiration]):
+        """ Get a value corresponding to a key if that (key, value) pair was previously stored under this key. """
+        self._remove_outdated()
+        if key in self.data:
+            return self.data[key]
+        return None, None
+
+    def items(self) -> Iterator[Tuple[KeyType, ValueType, DHTExpiration]]:
+        """ Iterate over (key, value, expiration_time) tuples stored in this storage """
+        self._remove_outdated()
+        return ((key, value, expiration_time) for key, (value, expiration_time) in self.data.items())
+
+    def top(self) -> Optional[Tuple[KeyType, ValueType, DHTExpiration]]:
+        """ Return the entry with earliest expiration or None if there isn't any """
+        self._remove_outdated()
+        if self.data:
+            top_entry, top_key = self.expiration_heap[0], self.expiration_heap[0][1]
+            while self.key_to_heap.get(top_key) != top_entry:
+                heapq.heappop(self.expiration_heap)  # skip leftover "ghost" entries until first real entry
+                top_entry, top_key = self.expiration_heap[0], self.expiration_heap[0][1]
+            value, expiration = self.data[top_key]
+            return top_key, value, expiration
+
+    def __contains__(self, key: KeyType):
+        self._remove_outdated()
+        return key in self.data
+
+    def __len__(self):
+        self._remove_outdated()
+        return len(self.data)
+
+    def __delitem__(self, key: KeyType):
+        if key in self.key_to_heap:
+            del self.data[key], self.key_to_heap[key]
+        # note: key may still be in self.expiration_heap, but it will not be used and will eventually be popped by ._remove_outdated()
+
+    def __bool__(self):
+        return bool(self.data)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({self.data})"
+
+    @contextmanager
+    def freeze(self):
+        """ Temporarily cease to ._remove_outdated() elements inside this context to ensure consistency """
+        prev_frozen, self.frozen = self.frozen, True
+        try:
+            yield self
+        finally:
+            self.frozen = prev_frozen
+
+
+@MSGPackSerializer.ext_serializable(0x50)
+class DictionaryDHTValue(TimedStorage[Subkey, BinaryDHTValue]):
+    """ a dictionary-like DHT value type that maps sub-keys to values with individual expirations """
+    latest_expiration_time = float('-inf')
+
+    def store(self, key: Subkey, value: BinaryDHTValue, expiration_time: DHTExpiration) -> bool:
+        self.latest_expiration_time = max(self.latest_expiration_time, expiration_time)
+        return super().store(key, value, expiration_time)
+
+    def packb(self) -> bytes:
+        """ custom behavior for MSGPackSerializer.dumps """
+        return MSGPackSerializer.dumps([self.maxsize, self.latest_expiration_time, list(map(list, self.items()))])
+
+    @classmethod
+    def unpackb(cls, raw: bytes) -> DictionaryDHTValue:
+        maxsize, latest_expiration_time, items = MSGPackSerializer.loads(raw)
+        with DictionaryDHTValue(maxsize).freeze() as new_dict:
+            for key, value, expiration_time in items:
+                new_dict.store(key, value, expiration_time)
+            new_dict.latest_expiration_time = latest_expiration_time
+            return new_dict
+
+
+class DHTLocalStorage(TimedStorage[DHTID, Union[BinaryDHTValue, DictionaryDHTValue]]):
+    """ A dictionary-like storage that can store binary values and/or nested dictionaries until expiration """
+    def store(self, key: DHTID, value: BinaryDHTValue, expiration_time: DHTExpiration,
+              subkey: Optional[Subkey] = None) -> bool:
+        """
+        Store a (key, value) pair locally at least until expiration_time. See class docstring for details.
+        If subkey is not None, adds a subkey-value pair to a dictionary associated with :key: (see store_subkey below)
+        :returns: True if the new value was stored, False if it was rejected (the current value is newer)
+        """
+        if subkey is not None:  # add one sub-key
+            return self.store_subkey(key, subkey, value, expiration_time)
+        else:  # store regular key
+            return super().store(key, value, expiration_time)
+
+    def store_subkey(self, key: DHTID, subkey: Subkey, value: BinaryDHTValue, expiration_time: DHTExpiration) -> bool:
+        """
+        Save a (sub-key, value) into a dictionary associated with a given key.
+         1) if self[key] is empty, create a new dictionary and add sub-key there
+         2) if self[key] is a dictionary (DictionaryDHTValue), store {sub-key: value, expiration} to that storage
+         3) if self[key] is a normal value with smaller expiration time, overwrite it with a dictionary and add sub-key
+        :returns: True if the new entry was stored, False if it was rejected (the current value is newer)
+        """
+        previous_value, previous_expiration_time = self.get(key)
+        if isinstance(previous_value, DictionaryDHTValue):  # already a dictionary, just add new subkey
+            if expiration_time > previous_value.latest_expiration_time:
+                super().store(key, previous_value, expiration_time)  # refresh expiration time
+            return previous_value.store(subkey, value, expiration_time)
+        elif expiration_time > (previous_expiration_time or float('-inf')):  # create new dictionary, add subkey
+            new_storage = DictionaryDHTValue()
+            new_storage.store(subkey, value, expiration_time)
+            return super().store(key, new_storage, new_storage.latest_expiration_time)
+        else:
+            return False
+
+
+class CacheRefreshQueue(TimedStorage[DHTID, DHTExpiration]):
+    """ a queue of keys scheduled for refresh in future, used in DHTNode """
+    frozen = True
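
A quick illustration (not part of this diff) of how DHTLocalStorage combines regular values and sub-keys; the key name and expiration offsets are placeholders:

# sketch: sub-keys are aggregated into a DictionaryDHTValue; whichever side expires later wins
from hivemind.dht.routing import DHTID, get_dht_time
from hivemind.dht.storage import DHTLocalStorage, DictionaryDHTValue

storage = DHTLocalStorage()
key, now = DHTID.generate('some_key'), get_dht_time()

# no value yet -> a new DictionaryDHTValue is created with a single sub-key
assert storage.store_subkey(key, 'alpha', b'A', now + 10)
assert isinstance(storage.get(key)[0], DictionaryDHTValue)

# the key already holds a dictionary -> the sub-key is added and the key's expiration tracks the latest sub-key
assert storage.store_subkey(key, 'beta', b'B', now + 20)
assert storage.get(key)[1] == now + 20

# a regular value replaces the whole dictionary only if it expires later
assert storage.store(key, b'plain', now + 15) is False  # rejected: the dictionary expires later
assert storage.store(key, b'plain', now + 30) is True   # accepted: later expiration wins
assert storage.get(key) == (b'plain', now + 30)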

+ 18 - 12
hivemind/proto/dht.proto

@@ -24,10 +24,11 @@ message NodeInfo {
 message StoreRequest {
   // three lists of the same length representing dht keys, dht values and expiration
   repeated bytes keys = 1;             // keys in the form of DHTID.generate(raw_key).to_bytes()
-  repeated bytes values = 2;           // binary-encoded value for i-th key
-  repeated double expiration_time = 3; // expirations for i-th key (type = DHTExpiration)
-  repeated bool in_cache = 4;          // if in_cache[i], store i-th key in cache, else store normally
-  NodeInfo peer = 5;                   // (optional) sender's own node info, same behavior as in DHT.rpc_ping
+  repeated string subkeys = 2;         // [optional] subkeys for DictionaryDHTValue type, empty string means no subkey
+  repeated bytes values = 3;           // binary-encoded value for i-th key
+  repeated double expiration_time = 4; // expirations for i-th key (type = DHTExpiration)
+  repeated bool in_cache = 5;          // if in_cache[i], store i-th key in cache, else store normally
+  NodeInfo peer = 6;                   // (optional) sender's own node info, same behavior as in DHT.rpc_ping
 }
 
 message StoreResponse {
@@ -40,16 +41,21 @@ message FindRequest {
   NodeInfo peer = 2;                   // optional, same behavior as in DHT.ping
 }
 
-message Peers {
-  // two aligned arrays: DHTIDs and Endpoints, i-th endpoint corresponds to peer with i-th node id
-  repeated bytes node_ids = 1;         // DHTID serialized with node_id.to_bytes()
-  repeated string endpoints = 2;       // e.g. 123.123.123.123:1337 or [2a21:6с8:b192:2105]:8888
+enum ResultType {NOT_FOUND = 0; FOUND_REGULAR = 1; FOUND_DICTIONARY = 2;}
+
+message FindResult {
+  ResultType type = 1;                 // NOT_FOUND |  FOUND_REGULAR   | FOUND_DICTIONARY
+  bytes value = 2;                     //    n/a    | serialized value | serialized DictionaryDHTValue with serialized fields
+  double expiration_time = 3;          //    n/a    | expiration time  | DictionaryDHTValue.latest_expiration_time
+
+  // two aligned arrays: DHTIDs and Endpoints for nearest peers (sorted by XOR distance)
+  repeated bytes nearest_node_ids = 4;      // DHTIDs serialized with node_id.to_bytes()
+  repeated string nearest_endpoints = 5;    // e.g. 123.123.123.123:1337 or [2a21:6c8:b192:2105]:8888
 }
 
+
 message FindResponse {
-  repeated bytes values = 1;           // value for i-th key, b'' means not found locally
-  repeated double expiration_time = 2; // expiration time for i-th value, only valid value is found
-  repeated Peers nearest = 3;          // peers ordered from nearest to farthest based on distance to i-th key
-  NodeInfo peer = 4;                   // respondent's node id, for you to update routing table
+  repeated FindResult results = 1;       // for each item, return value/expiration (if found) and nearest peers
+  NodeInfo peer = 2;                   // respondent's node id, for you to update routing table
 }
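
For reference, here is how the new messages above are produced from Python once the .proto is compiled. This is a sketch mirroring the rpc_find changes earlier in this diff; it assumes the generated module is importable as hivemind.proto.dht_pb2 and uses placeholder values.

# sketch: building a FindResult / FindResponse with the new schema
from hivemind.proto import dht_pb2

item = dht_pb2.FindResult(type=dht_pb2.FOUND_REGULAR, value=b'serialized_value', expiration_time=1234.5)
item.nearest_node_ids.append(b'\x01\x02\x03')           # DHTID.to_bytes() of a nearby peer
item.nearest_endpoints.append('123.123.123.123:1337')   # endpoint aligned with the node id above

response = dht_pb2.FindResponse(results=[item])
assert response.results[0].type == dht_pb2.FOUND_REGULAR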
 

+ 0 - 1
hivemind/server/checkpoint_saver.py

@@ -1,5 +1,4 @@
 import threading
-import time
 from datetime import datetime
 from pathlib import Path
 from shutil import copytree

+ 43 - 7
hivemind/utils/serializer.py

@@ -1,9 +1,14 @@
 """ A unified interface for several common serialization methods """
 import pickle
 from io import BytesIO
+from typing import Dict, Any
 
 import torch
-import umsgpack
+import msgpack
+from hivemind.utils.logging import get_logger
+
+logger = get_logger(__name__)
+
 
 class SerializerBase:
     @staticmethod
@@ -38,10 +43,41 @@ class PytorchSerializer(SerializerBase):
 
 
 class MSGPackSerializer(SerializerBase):
-    @staticmethod
-    def dumps(obj: object) -> bytes:
-        return umsgpack.dumps(obj, use_bin_type=False, strict_types=True)
+    _ExtTypes: Dict[Any, int] = {}
+    _ExtTypeCodes: Dict[int, Any] = {}
+
+    @classmethod
+    def ext_serializable(cls, type_code: int):
+        assert isinstance(type_code, int), "Please specify a (unique) int type code"
+
+        def wrap(wrapped_type: type):
+            assert callable(getattr(wrapped_type, 'packb', None)) and callable(getattr(wrapped_type, 'unpackb', None)),\
+                f"Every ext_type must have 2 methods: packb(self) -> bytes and classmethod unpackb(cls, bytes)"
+            if type_code in cls._ExtTypeCodes:
+                logger.warning(f"{cls.__name__}: type {type_code} is already registered, overwriting.")
+            cls._ExtTypeCodes[type_code], cls._ExtTypes[wrapped_type] = wrapped_type, type_code
+            return wrapped_type
+        return wrap
+
+    @classmethod
+    def _encode_ext_types(cls, obj):
+        type_code = cls._ExtTypes.get(type(obj))
+        if type_code is not None:
+            return msgpack.ExtType(type_code, obj.packb())
+        return obj
+
+    @classmethod
+    def _decode_ext_types(cls, type_code: int, data: bytes):
+        if type_code in cls._ExtTypeCodes:
+            return cls._ExtTypeCodes[type_code].unpackb(data)
+        logger.warning(f"Unknown ExtType code: {type_code}, leaving it as is.")
+        return data
+
+    @classmethod
+    def dumps(cls, obj: object) -> bytes:
+        return msgpack.dumps(obj, use_bin_type=True, default=cls._encode_ext_types, strict_types=True)
+
+    @classmethod
+    def loads(cls, buf: bytes) -> object:
+        return msgpack.loads(buf, ext_hook=cls._decode_ext_types, raw=False)
 
-    @staticmethod
-    def loads(buf: bytes) -> object:
-        return umsgpack.loads(buf, raw=False)

+ 1 - 1
requirements.txt

@@ -2,7 +2,7 @@ PyYAML
 torch>=1.3.0
 numpy>=1.17
 prefetch_generator>=1.0.1
-umsgpack
+msgpack>=0.5.6
 sortedcontainers
 uvloop>=0.14.0
 grpcio>=1.31

+ 1 - 1
tests/test_dht_experts.py

@@ -6,7 +6,7 @@ import hivemind
 from hivemind import LOCALHOST
 
 
-def test_hivemind_dht():
+def test_store_get_experts():
     peers = [hivemind.DHT(start=True)]
     for i in range(10):
         neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]

+ 91 - 8
tests/test_dht_node.py

@@ -11,6 +11,7 @@ from typing import List, Dict
 from hivemind import get_dht_time
 from hivemind.dht.node import DHTID, Endpoint, DHTNode, LOCALHOST, DHTProtocol
 from hivemind.dht.protocol import DHTProtocol
+from hivemind.dht.storage import DictionaryDHTValue
 
 
 def run_protocol_listener(port: int, dhtid: DHTID, started: mp.synchronize.Event, ping: Optional[Endpoint] = None):
@@ -85,6 +86,24 @@ def test_dht_protocol():
             dummy_port = hivemind.find_open_port()
             assert loop.run_until_complete(protocol.call_find(f"{LOCALHOST}:{dummy_port}", [key])) is None
 
+            # store/get a dictionary with sub-keys
+            nested_key, subkey1, subkey2 = DHTID.generate(), 'foo', 'bar'
+            value1, value2 = [random.random(), {'ololo': 'pyshpysh'}], 'abacaba'
+            assert loop.run_until_complete(protocol.call_store(
+                f'{LOCALHOST}:{peer1_port}', keys=[nested_key], values=[hivemind.MSGPackSerializer.dumps(value1)],
+                expiration_time=[expiration], subkeys=[subkey1])
+            )
+            assert loop.run_until_complete(protocol.call_store(
+                f'{LOCALHOST}:{peer1_port}', keys=[nested_key], values=[hivemind.MSGPackSerializer.dumps(value2)],
+                expiration_time=[expiration + 5], subkeys=[subkey2])
+            )
+            recv_dict, recv_expiration, nodes_found = loop.run_until_complete(
+                protocol.call_find(f'{LOCALHOST}:{peer1_port}', [nested_key]))[nested_key]
+            assert isinstance(recv_dict, DictionaryDHTValue)
+            assert len(recv_dict.data) == 2 and recv_expiration == expiration + 5
+            assert recv_dict.data[subkey1] == (protocol.serializer.dumps(value1), expiration)
+            assert recv_dict.data[subkey2] == (protocol.serializer.dumps(value2), expiration + 5)
+
             if listen:
                 loop.run_until_complete(protocol.shutdown())
             print("DHTProtocol test finished successfully!")
@@ -172,7 +191,8 @@ def test_dht_node():
         # note: we run everything in a separate process to re-initialize all global states from scratch
         # this helps us avoid undesirable side-effects when running multiple tests in sequence
         loop = asyncio.get_event_loop()
-        me = loop.run_until_complete(DHTNode.create(initial_peers=random.sample(dht.keys(), 5), parallel_rpc=10))
+        me = loop.run_until_complete(DHTNode.create(initial_peers=random.sample(dht.keys(), 5), parallel_rpc=10,
+                                                    cache_refresh_before_expiry=False))
 
         # test 1: find self
         nearest = loop.run_until_complete(me.find_nearest_nodes([me.node_id], k_nearest=1))[me.node_id]
@@ -229,19 +249,24 @@ def test_dht_node():
         assert len(set.difference(set(nearest.keys()), set(all_node_ids) | {me.node_id})) == 0
 
         # test 5: node without peers
-        other_node = loop.run_until_complete(DHTNode.create())
-        nearest = loop.run_until_complete(other_node.find_nearest_nodes([dummy]))[dummy]
-        assert len(nearest) == 1 and nearest[other_node.node_id] == f"{LOCALHOST}:{other_node.port}"
-        nearest = loop.run_until_complete(other_node.find_nearest_nodes([dummy], exclude_self=True))[dummy]
+        detached_node = loop.run_until_complete(DHTNode.create())
+        nearest = loop.run_until_complete(detached_node.find_nearest_nodes([dummy]))[dummy]
+        assert len(nearest) == 1 and nearest[detached_node.node_id] == f"{LOCALHOST}:{detached_node.port}"
+        nearest = loop.run_until_complete(detached_node.find_nearest_nodes([dummy], exclude_self=True))[dummy]
         assert len(nearest) == 0
 
         # test 6 store and get value
         true_time = get_dht_time() + 1200
         assert loop.run_until_complete(me.store("mykey", ["Value", 10], true_time))
-        for node in [me, other_node]:
-            val, expiration_time = loop.run_until_complete(me.get("mykey"))
-            assert expiration_time == true_time, "Wrong time"
+        that_guy = loop.run_until_complete(DHTNode.create(initial_peers=random.sample(dht.keys(), 3), parallel_rpc=10,
+                                                          cache_refresh_before_expiry=False, cache_locally=False))
+
+        for node in [me, that_guy]:
+            val, expiration_time = loop.run_until_complete(node.get("mykey"))
             assert val == ["Value", 10], "Wrong value"
+            assert expiration_time == true_time, f"Wrong time"
+
+        assert loop.run_until_complete(detached_node.get("mykey")) == (None, None)
 
         # test 7: bulk store and bulk get
         keys = 'foo', 'bar', 'baz', 'zzz'
@@ -252,6 +277,31 @@ def test_dht_node():
         for key, value in zip(keys, values):
             assert key in response and response[key][0] == value
 
+        # test 8: store dictionaries as values (with sub-keys)
+        upper_key, subkey1, subkey2, subkey3 = 'ololo', 'k1', 'k2', 'k3'
+        now = get_dht_time()
+        assert loop.run_until_complete(me.store(upper_key, subkey=subkey1, value=123, expiration_time=now + 10))
+        assert loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=456, expiration_time=now + 20))
+        for node in [that_guy, me]:
+            value, time = loop.run_until_complete(node.get(upper_key))
+            assert isinstance(value, dict) and time == now + 20
+            assert value[subkey1] == (123, now + 10)
+            assert value[subkey2] == (456, now + 20)
+            assert len(value) == 2
+
+        assert not loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=345, expiration_time=now + 10))
+        assert loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=567, expiration_time=now + 30))
+        assert loop.run_until_complete(me.store(upper_key, subkey=subkey3, value=890, expiration_time=now + 50))
+        loop.run_until_complete(asyncio.sleep(0.1))  # wait for cache to refresh
+
+        for node in [that_guy, me]:
+            value, time = loop.run_until_complete(node.get(upper_key))
+            assert isinstance(value, dict) and time == now + 50, (value, time)
+            assert value[subkey1] == (123, now + 10)
+            assert value[subkey2] == (567, now + 30)
+            assert value[subkey3] == (890, now + 50)
+            assert len(value) == 3
+
         test_success.set()
 
     tester = mp.Process(target=_tester, daemon=True)
@@ -262,6 +312,39 @@ def test_dht_node():
         proc.terminate()
 
 
+def test_dhtnode_replicas():
+    dht_size = 20
+    initial_peers = 3
+    num_replicas = random.randint(1, 20)
+    test_success = mp.Event()
+
+    async def _tester():
+        peers = []
+        for i in range(dht_size):
+            neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(initial_peers, len(peers)))]
+            peers.append(await DHTNode.create(initial_peers=neighbors_i, num_replicas=num_replicas))
+
+        you = random.choice(peers)
+        assert await you.store('key1', 'foo', get_dht_time() + 999)
+
+        actual_key1_replicas = sum(len(peer.protocol.storage) for peer in peers)
+        assert num_replicas == actual_key1_replicas
+
+        assert await you.store('key2', 'bar', get_dht_time() + 999)
+        total_size = sum(len(peer.protocol.storage) for peer in peers)
+        actual_key2_replicas = total_size - actual_key1_replicas
+        assert num_replicas == actual_key2_replicas
+
+        assert await you.store('key2', 'baz', get_dht_time() + 1000)
+        assert sum(len(peer.protocol.storage) for peer in peers) == total_size, "total size should not have changed"
+        test_success.set()
+
+    proc = mp.Process(target=lambda: asyncio.run(_tester()))
+    proc.start()
+    proc.join()
+    assert test_success.is_set()
+
+
 def test_dhtnode_caching(T=0.05):
     test_success = mp.Event()
 

+ 60 - 10
tests/test_dht_storage.py

@@ -1,18 +1,19 @@
 import time
 
-from hivemind import DHTID, get_dht_time
-from hivemind.dht.protocol import LocalStorage
+from hivemind.dht.routing import get_dht_time
+from hivemind.dht.storage import DHTLocalStorage, DHTID, DictionaryDHTValue
+from hivemind.utils.serializer import MSGPackSerializer
 
 
 def test_store():
-    d = LocalStorage()
+    d = DHTLocalStorage()
     d.store(DHTID.generate("key"), b"val", get_dht_time() + 0.5)
     assert d.get(DHTID.generate("key"))[0] == b"val", "Wrong value"
     print("Test store passed")
 
 
 def test_get_expired():
-    d = LocalStorage()
+    d = DHTLocalStorage()
     d.store(DHTID.generate("key"), b"val", get_dht_time() + 0.1)
     time.sleep(0.5)
     assert d.get(DHTID.generate("key")) == (None, None), "Expired value must be deleted"
@@ -20,13 +21,13 @@ def test_get_expired():
 
 
 def test_get_empty():
-    d = LocalStorage()
-    assert d.get(DHTID.generate(source="key")) == (None, None), "LocalStorage returned non-existent value"
+    d = DHTLocalStorage()
+    assert d.get(DHTID.generate(source="key")) == (None, None), "DHTLocalStorage returned non-existent value"
     print("Test get expired passed")
 
 
 def test_change_expiration_time():
-    d = LocalStorage()
+    d = DHTLocalStorage()
     d.store(DHTID.generate("key"), b"val1", get_dht_time() + 1)
     assert d.get(DHTID.generate("key"))[0] == b"val1", "Wrong value"
     d.store(DHTID.generate("key"), b"val2", get_dht_time() + 200)
@@ -36,7 +37,7 @@ def test_change_expiration_time():
 
 
 def test_maxsize_cache():
-    d = LocalStorage(maxsize=1)
+    d = DHTLocalStorage(maxsize=1)
     d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 1)
     d.store(DHTID.generate("key2"), b"val2", get_dht_time() + 200)
     assert d.get(DHTID.generate("key2"))[0] == b"val2", "Value with bigger exp. time must be kept"
@@ -44,7 +45,7 @@ def test_maxsize_cache():
 
 
 def test_localstorage_top():
-    d = LocalStorage(maxsize=3)
+    d = DHTLocalStorage(maxsize=3)
     d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 1)
     d.store(DHTID.generate("key2"), b"val2", get_dht_time() + 2)
     d.store(DHTID.generate("key3"), b"val3", get_dht_time() + 4)
@@ -61,8 +62,40 @@ def test_localstorage_top():
     assert d.top()[:2] == (DHTID.generate("key3"), b"val3")
 
 
+def test_localstorage_nested():
+    time = get_dht_time()
+    d1 = DHTLocalStorage()
+    d2 = DictionaryDHTValue()
+    d2.store('subkey1', b'value1', time + 2)
+    d2.store('subkey2', b'value2', time + 3)
+    d2.store('subkey3', b'value3', time + 1)
+
+    assert d2.latest_expiration_time == time + 3
+    for subkey, subvalue, subexpiration in d2.items():
+        assert d1.store_subkey(DHTID.generate('foo'), subkey, subvalue, subexpiration)
+    assert d1.store(DHTID.generate('bar'), b'456', time + 2)
+    assert d1.get(DHTID.generate('foo'))[0].data == d2.data
+    assert d1.get(DHTID.generate('foo'))[1] == d2.latest_expiration_time
+    assert d1.get(DHTID.generate('foo'))[0].get('subkey1') == (b'value1', time + 2)
+    assert len(d1.get(DHTID.generate('foo'))[0]) == 3
+    assert d1.store_subkey(DHTID.generate('foo'), 'subkey4', b'value4', time + 4)
+    assert len(d1.get(DHTID.generate('foo'))[0]) == 4
+
+    assert d1.store_subkey(DHTID.generate('bar'), 'subkeyA', b'valueA', time + 1) is False  # prev has better expiration
+    assert d1.store_subkey(DHTID.generate('bar'), 'subkeyA', b'valueA', time + 3)  # new value has better expiration
+    assert d1.store_subkey(DHTID.generate('bar'), 'subkeyB', b'valueB', time + 4)  # new value has better expiration
+    assert d1.store_subkey(DHTID.generate('bar'), 'subkeyA', b'valueA+', time + 5)  # overwrite subkeyA under key bar
+    assert all(subkey in d1.get(DHTID.generate('bar'))[0] for subkey in ('subkeyA', 'subkeyB'))
+    assert len(d1.get(DHTID.generate('bar'))[0]) == 2 and d1.get(DHTID.generate('bar'))[1] == time + 5
+
+    assert d1.store(DHTID.generate('foo'), b'nothing', time + 3.5) is False  # previous value has better expiration
+    assert d1.get(DHTID.generate('foo'))[0].get('subkey2') == (b'value2', time + 3)
+    assert d1.store(DHTID.generate('foo'), b'nothing', time + 5) is True  # new value has better expiration
+    assert d1.get(DHTID.generate('foo')) == (b'nothing', time + 5)  # value should be replaced
+
+
 def test_localstorage_freeze():
-    d = LocalStorage(maxsize=2)
+    d = DHTLocalStorage(maxsize=2)
 
     with d.freeze():
         d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 0.01)
@@ -77,3 +110,20 @@ def test_localstorage_freeze():
         d.store(DHTID.generate("key3"), b"val3", get_dht_time() + 3)  # key3 will push key1 out due to maxsize
         assert DHTID.generate("key1") in d
     assert DHTID.generate("key1") not in d
+
+
+def test_localstorage_serialize():
+    d1 = DictionaryDHTValue()
+    d2 = DictionaryDHTValue()
+
+    now = get_dht_time()
+    d1.store('key1', b'ololo', now + 1)
+    d2.store('key2', b'pysh', now + 1)
+    d2.store('key3', b'pyshpysh', now + 2)
+
+    data = MSGPackSerializer.dumps([d1, d2, 123321])
+    assert isinstance(data, bytes)
+    new_d1, new_d2, new_value = MSGPackSerializer.loads(data)
+    assert isinstance(new_d1, DictionaryDHTValue) and isinstance(new_d2, DictionaryDHTValue) and new_value == 123321
+    assert 'key1' in new_d1 and len(new_d1) == 1
+    assert 'key1' not in new_d2 and len(new_d2) == 2 and new_d2.get('key3') == (b'pyshpysh', now + 2)

+ 0 - 2
tests/test_moe.py

@@ -15,8 +15,6 @@ def test_moe():
     with background_server(expert_uids=all_expert_uids, device='cpu', expert_cls='ffn',
                            num_handlers=1, hidden_dim=16) as (server_endpoint, dht_endpoint):
         dht = hivemind.DHT(start=True, expiration=999, initial_peers=[dht_endpoint])
-        # declare expert uids. Server *should* declare them by itself, but it takes time.
-        assert all(dht.declare_experts(all_expert_uids, endpoint=server_endpoint))
 
         dmoe = hivemind.RemoteMixtureOfExperts(
             in_features=16, grid_size=(32, 32, 32), dht=dht, k_best=3, uid_prefix='ffn')