Extract expert-specific methods from DHT (#192)

Implement DHT.run_coroutine
 * allow run_coroutine to be cancelled from host process
 * add basic tests for DHT.run_coroutine
 * add basic tests for DHT.store/get

Extract expert-specific functions away from DHT
* move beam search to beam_search.py
* move declare/get experts to dht_ops.py
* mark old api as deprecated
* update hivemind.DHT docstr
* update DHT scheme (DHT is no longer expert-specific)

Misc
* remove receiver_threads (always 1, never a bottleneck)
* remove timeout in declare_experts (unused for 6+ months)

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic 4 years ago
parent
commit
f132294edb
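
For context, a minimal sketch of the new DHT.run_coroutine interface this commit introduces (the coroutine body and key name are illustrative, not part of the diff):

    import hivemind

    async def fetch_something(dht: hivemind.DHT, node: hivemind.dht.DHTNode):
        # runs inside the DHT process; keep all slow operations asynchronous
        return await node.get("some_key")

    dht = hivemind.DHT(listen_on="127.0.0.1:*", start=True)
    result = dht.run_coroutine(fetch_something)  # blocks until the coroutine finishes

    future = dht.run_coroutine(fetch_something, return_future=True)
    future.cancel()  # cancellation propagates into the DHT process (see _run_coroutine below)
    dht.shutdown()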

Binary
docs/_static/dht.odp


Binary
docs/_static/dht.png


+ 3 - 4
hivemind/client/averaging/__init__.py

@@ -64,7 +64,6 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
     :param listen: if True (default), this averager will accept incoming requests from other peers and perform allreduce
             if False, the averager will register as a freeloader and attempt to fetch vectors from other averagers
     :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
-    :param receiver_threads: uses this many threads to await on input pipe. Default = 1 should be enough in most cases
     :param channel_options: options for grpc.aio.insecure_channel, e.g. [('grpc.enable_retries', 0)]
           see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
     :param kwargs: extra parameters forwarded to grpc.aio.server
@@ -91,7 +90,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                  allreduce_timeout: Optional[float] = None, averaging_alpha: float = 1.0,
                  compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE,
                  throughput: Optional[float] = None, min_vector_size: int = 0,
-                 listen: bool = True, listen_on: Endpoint = '0.0.0.0:*', receiver_threads: int = 1, daemon: bool = True,
+                 listen: bool = True, listen_on: Endpoint = '0.0.0.0:*', daemon: bool = True,
                  channel_options: Optional[Sequence[Tuple[str, Any]]] = None, **kwargs):
         assert '.' not in prefix, "group prefix must be a string without trailing '.'"
         assert throughput is None or (throughput >= 0 and np.isfinite(np.float32(throughput))), \
@@ -102,7 +101,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin

         super().__init__()
         self.dht = dht
-        self.listen, self.listen_on, self.receiver_threads, self.kwargs = listen, listen_on, receiver_threads, kwargs
+        self.listen, self.listen_on, self.kwargs = listen, listen_on, kwargs
         self.channel_options = channel_options
         self.daemon = daemon

@@ -155,7 +154,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         """ Serve DecentralizedAverager forever. This function will not return until the averager is shut down """
         loop = switch_to_uvloop()
         # initialize asyncio synchronization primitives in this event loop
-        pipe_awaiter = ThreadPoolExecutor(self.receiver_threads)
+        pipe_awaiter = ThreadPoolExecutor(max_workers=1)

         async def _run():
             grpc.aio.init_grpc_aio()

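The pipe_awaiter above wraps the blocking pipe.recv() so the event loop keeps serving; one thread suffices because requests are read one at a time. A rough sketch of the pattern, simplified from the code above (handler and the message format are illustrative):

    import asyncio
    from concurrent.futures import ThreadPoolExecutor
    from multiprocessing.connection import Connection

    async def serve_pipe(pipe: Connection, handler) -> None:
        pipe_awaiter = ThreadPoolExecutor(max_workers=1)
        loop = asyncio.get_event_loop()
        while True:
            # recv() blocks, so run it in the worker thread and await it from the loop
            method, args, kwargs = await loop.run_in_executor(pipe_awaiter, pipe.recv)
            asyncio.create_task(getattr(handler, method)(*args, **kwargs))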
+ 273 - 0
hivemind/client/beam_search.py

@@ -0,0 +1,273 @@
+import asyncio
+import heapq
+from collections import deque
+from functools import partial
+from typing import Sequence, Optional, List, Tuple, Dict, Deque, Union, Set, Iterator
+
+from hivemind.dht import DHT, DHTNode
+from hivemind.client.expert import RemoteExpert
+from hivemind.server.expert_uid import (ExpertUID, ExpertPrefix, FLAT_EXPERT, UidEndpoint, Score, Coordinate,
+                                        PREFIX_PATTERN, UID_DELIMITER, is_valid_prefix)
+from hivemind.utils import get_logger, get_dht_time, MPFuture
+
+logger = get_logger(__name__)
+
+
+class MoEBeamSearcher:
+    """
+    Utility class that uses the DHT to find the most suitable experts for RemoteMixtureOfExperts.
+    Each expert has an identifier in the form of {prefix}.{i}.{j}.{...}, e.g. "ffn_expert.98.76.54.32.10"
+    An expert identifier consists of:
+
+        * optional prefix that determines expert role, experiment name, etc.
+        * one or more integers that determine that expert's position in an N-dimensional grid
+
+    A hivemind.Server can call ``declare_experts(dht, expert_uids: List[str], endpoint)`` to make its experts visible to everyone.
+    When declaring experts, DHT will store each expert's uid and all its prefixes until :expiration: (specified at init)
+    For instance, declaring "ffn_expert.98.76.54.32.10" will store the following keys in a DHT:
+    ``"ffn_expert.98", "ffn_expert.98.76", "ffn_expert.98.76.54", ..., "ffn_expert.98.76.54.32.10"``
+
+    In order to enable fast beam search, DHT maintains dictionaries of all active suffixes for every prefix
+    (e.g. "ffn_expert.98": {76: ffn_expert.98.76...., 123: ffn_expert.98.123..., 225: ffn_expert.98.225....}))
+
+    RemoteMixtureOfExperts can use these prefixes to find top-k most suitable experts with a left-to-right beam search.
+    For instance, consider RemoteMixtureOfExperts with prefix "ffn_expert" and grid size [100, 100, 100, 100, 100].
+    This MoE can query all experts with that prefix and arbitrary indices in 0...99 along each dimension.
+    However, not every expert in such a 100^5 grid can be alive at any given moment (the grid size is redundant).
+    In order to find k best "alive" experts, MoE first ranks indices along the first dimension with its gating function.
+    It can then check which of those indices correspond to "alive" experts by querying keys such as "ffn_expert.98".
+
+    After selecting k best indices along first dimension, MoE moves to the second dimension.
+    It can find top-k index pairs (e.g. "expert.98.76") that use one of k best indices from the previous step.
+    This beam search explores one additional dimension per step and finds k best experts from across the DHT
+    in O(k * num_dimensions * dimension_size) time depending on the chosen grid dimensions.
+
+    :param dht: a running DHT daemon that is used for beam search AND local caching
+    :param uid_prefix: search for experts whose uids start with this prefix
+    :param grid_size: dimensions that form expert uid (see above)
+    :param num_workers: number of concurrent DHT coroutines per beam search
+    :param negative_caching: if True, whenever DHT is unable to find an expert or prefix, it will cache the "no key"
+      result inside the DHT for :expiration: seconds. Caching only affects beam search and has three main effects:
+
+      1. Faster beam search under node failures: if there are inconsistencies in DHT keys, such as a prefix pointing to
+         a now-defunct expert, these inconsistencies will be overwritten by the first peer that stumbles upon them. As a
+         result, beam search will not have to wait for non-existent experts until the expiration of their DHT entries;
+      2. Delayed expert availability: Without negative cache, new experts are always immediately available for beam
+         search after they are published to the DHT. With negative cache, there are rare cases (e.g. when adding new
+         experts in place of recently defunct ones) when new experts will be initially invisible, but gradually become
+         visible to more peers as those peers refresh their cache. This process takes at most :expiration: seconds;
+      3. Faster beam search in very sparse grids: there is one edge case where negative cache will improve beam search
+         performance; If an expert grid is very sparse, there can be empty indices in the first grid dimension (i.e.
+         indices {i} such that _no_ experts start with "{prefix}.{i}.*"). If so, the default beam search will
+         be very slow due to the way it forms initial beam. Beam search with negative cache enabled will run normally.
+         Though, this is a pathological case (e.g. only 90 experts in an oversized 100x100 grid) that should be avoided.
+    """
+
+    def __init__(self, dht: DHT, uid_prefix: ExpertPrefix, grid_size: Optional[Tuple[int, ...]] = None,
+                 num_workers: Optional[int] = None, negative_caching: bool = True, **kwargs):
+        if not uid_prefix.endswith(UID_DELIMITER):
+            uid_prefix += UID_DELIMITER
+            logger.info(f"Prefix must end with '{UID_DELIMITER}'. Changing to {uid_prefix}{UID_DELIMITER}")
+        assert is_valid_prefix(uid_prefix), f"Prefix '{uid_prefix}' is invalid."
+        self.dht = dht
+        self.uid_prefix, self.grid_size = uid_prefix, grid_size
+        self.negative_caching, self.num_workers, self.dht_kwargs = negative_caching, num_workers, kwargs
+
+    def get_initial_beam(self, scores: Sequence[float], beam_size: int, return_future: bool = False
+                         ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
+        """
+        :param scores: prefer suffix coordinates that have highest scores
+        :param beam_size: select this many active suffixes with highest scores
+        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
+        :returns: a list of up to beam_size tuples of (prefix score, prefix itself, dict{suffix: example expert})
+        """
+        return self.dht.run_coroutine(partial(self._get_initial_beam, prefix=self.uid_prefix, beam_size=beam_size,
+                                              scores=tuple(scores), negative_caching=self.negative_caching,
+                                              num_workers=self.num_workers), return_future)
+
+    @staticmethod
+    async def _get_initial_beam(dht: DHT, node: DHTNode, prefix: ExpertPrefix, beam_size: int,
+                                scores: Tuple[float, ...], negative_caching: bool, num_workers: Optional[int] = None
+                                ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
+        num_workers = num_workers or dht.max_workers or beam_size
+        beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = []
+        unattempted_indices: List[Coordinate] = sorted(range(len(scores)), key=scores.__getitem__)  # from worst to best
+        pending_tasks: Deque[Tuple[Coordinate, ExpertPrefix, asyncio.Task]] = deque()
+
+        while len(beam) < beam_size and (unattempted_indices or pending_tasks):
+            # dispatch additional tasks
+            while unattempted_indices and len(pending_tasks) < num_workers:
+                next_index = unattempted_indices.pop()  # note: this is best unattempted index because of sort order
+                next_best_prefix = f"{prefix}{next_index}{UID_DELIMITER}"
+                pending_tasks.append((next_index, next_best_prefix, asyncio.create_task(node.get(next_best_prefix))))
+
+            # await the next best prefix to be fetched
+            pending_best_index, pending_best_prefix, pending_task = pending_tasks.popleft()
+            try:
+                maybe_prefix_data = await pending_task
+                if maybe_prefix_data is not None and isinstance(maybe_prefix_data.value, dict):
+                    successors = {coord: UidEndpoint(*match.value) for coord, match in maybe_prefix_data.value.items()
+                                  if isinstance(coord, Coordinate) and isinstance(getattr(match, 'value', None), list)
+                                  and len(match.value) == 2}
+                    if successors:
+                        beam.append((scores[pending_best_index], pending_best_prefix, successors))
+                elif maybe_prefix_data is None and negative_caching:
+                    logger.debug(f"DHT negative caching: storing a 'no prefix' entry for {pending_best_prefix}")
+                    asyncio.create_task(node.store(pending_best_prefix, subkey=-1, value=None,
+                                                   expiration_time=get_dht_time() + dht.default_expiration))
+
+            except asyncio.CancelledError:
+                for _, _, pending_task in pending_tasks:  # each entry is a (coordinate, prefix, task) triple
+                    pending_task.cancel()
+                raise
+        return beam
+
+    def get_active_successors(self, prefixes: List[ExpertPrefix], grid_size: Optional[int] = None,
+                              return_future: bool = False) -> Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]]:
+        """
+        :param prefixes: a list of prefixes for which to find active successor uids
+        :param grid_size: if specified, only return successors if they are in range [0, grid_size)
+        :param return_future: if False (default), find and return successors. Otherwise return MPFuture and fill later.
+        :returns: for every expert, return a dict{active_next_coordinate: (matching_expert_uid, matching_endpoint)}
+        :note: if a prefix is not found, get_active_successors will return an empty dictionary for that prefix
+        """
+        assert not isinstance(prefixes, str), "Please send a list / tuple of expert prefixes."
+        for prefix in prefixes:
+            assert is_valid_prefix(prefix), f"prefix '{prefix}' is invalid, it must follow {PREFIX_PATTERN.pattern}"
+        return self.dht.run_coroutine(partial(
+            self._get_active_successors, prefixes=list(prefixes), grid_size=grid_size,
+            negative_caching=self.negative_caching, num_workers=self.num_workers), return_future=return_future)
+
+    @staticmethod
+    async def _get_active_successors(dht: DHT, node: DHTNode, prefixes: List[ExpertPrefix], grid_size: Optional[int],
+                                     negative_caching: bool, num_workers: Optional[int] = None
+                                     ) -> Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]]:
+        grid_size = grid_size or float('inf')
+        num_workers = num_workers or min(len(prefixes), dht.max_workers or len(prefixes))
+        dht_responses = await node.get_many(keys=prefixes, num_workers=num_workers)
+        successors: Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]] = {}
+        for prefix, found in dht_responses.items():
+            if found and isinstance(found.value, dict):
+                successors[prefix] = {coord: UidEndpoint(*match.value) for coord, match in found.value.items()
+                                      if isinstance(coord, Coordinate) and 0 <= coord < grid_size
+                                      and isinstance(getattr(match, 'value', None), list) and len(match.value) == 2}
+            else:
+                successors[prefix] = {}
+                if found is None and negative_caching:
+                    logger.debug(f"DHT negative caching: storing a 'no prefix' entry for {prefix}")
+                    asyncio.create_task(node.store(prefix, subkey=-1, value=None,
+                                                   expiration_time=get_dht_time() + dht.default_expiration))
+        return successors
+
+    def find_best_experts(self, grid_scores: Sequence[Sequence[float]], beam_size: int, return_future: bool = False
+                          ) -> Union[List[RemoteExpert], MPFuture[List[RemoteExpert]]]:
+        """
+        Find and return :beam_size: active experts with highest scores, using both local cache and DHT
+
+        :param grid_scores: scores predicted for each dimension in the grid
+        :type grid_scores: model scores for each grid dimension, list of arrays of shape grid_size[i]
+        :param beam_size: how many best experts should beam search return
+        :param return_future: if set to True, returns MPFuture that can be awaited to get the actual result
+        :returns: a list that contains *up to* beam_size RemoteExpert instances
+        """
+        assert (not self.grid_size or len(grid_scores) == len(self.grid_size)) and beam_size > 0
+        return self.dht.run_coroutine(partial(self._find_best_experts, prefix=self.uid_prefix, beam_size=beam_size,
+                                              grid_scores=list(grid_scores), negative_caching=self.negative_caching,
+                                              num_workers=self.num_workers), return_future)
+
+    @classmethod
+    async def _find_best_experts(
+            cls, dht: DHT, node: DHTNode, prefix: str, grid_scores: List[Tuple[float]], beam_size: int,
+            negative_caching: bool, num_workers: Optional[int] = None) -> List[RemoteExpert]:
+        num_workers = num_workers or min(beam_size, dht.max_workers or beam_size)
+
+        # form initial beam from top-k active L1 prefixes, each row is (score, uid prefix, possible suffixes)
+        beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = await cls._get_initial_beam(
+            dht, node, prefix, beam_size, grid_scores[0], negative_caching, min(beam_size, num_workers))
+
+        best_experts_heap: List[Tuple[Score, UidEndpoint]] = []  # max-heap of expert uids/endpoints ordered by scores
+        unique_experts: Set[ExpertUID] = set()
+
+        for dim_index in range(1, len(grid_scores) - 1):
+            for score, uid_endpoint in cls._iterate_matching_experts(beam, grid_scores):
+                if uid_endpoint.uid not in unique_experts:
+                    push_and_maybe_pop = heapq.heappush if len(best_experts_heap) < beam_size else heapq.heappushpop
+                    push_and_maybe_pop(best_experts_heap, (score, uid_endpoint))
+                    unique_experts.add(uid_endpoint.uid)
+
+            # form new beam using successors from the current beam
+            dim_scores = grid_scores[dim_index]
+            best_active_pairs: List[Tuple[Score, ExpertPrefix]] = heapq.nlargest(beam_size, (
+                (prefix_score + dim_scores[next_coord], f"{prefix}{next_coord}{UID_DELIMITER}")
+                for prefix_score, prefix, suffixes in beam for next_coord in suffixes.keys()
+                if isinstance(next_coord, int) and 0 <= next_coord < len(dim_scores)))
+            _, best_uid_prefixes = zip(*best_active_pairs)
+
+            # search DHT for next step suffixes
+            successors = await cls._get_active_successors(dht, node, best_uid_prefixes, grid_size=None,
+                                                          negative_caching=negative_caching, num_workers=num_workers)
+            beam = [(score, prefix, successors[prefix]) for score, prefix in best_active_pairs if successors[prefix]]
+            if not beam:
+                logger.warning(f"Beam search had to terminate prematurely because of empty beam (dim 0)")
+                break
+
+        # add best experts from the final beam
+        for score, uid_endpoint in cls._iterate_matching_experts(beam, grid_scores):
+            if uid_endpoint.uid not in unique_experts:
+                push_and_maybe_pop = heapq.heappush if len(best_experts_heap) < beam_size else heapq.heappushpop
+                push_and_maybe_pop(best_experts_heap, (score, uid_endpoint))
+                unique_experts.add(uid_endpoint.uid)
+
+        best_experts = [RemoteExpert(*uid_endpoint) for score, uid_endpoint in sorted(best_experts_heap, reverse=True)]
+        return best_experts
+
+    @staticmethod
+    def _iterate_matching_experts(beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]],
+                                  grid_scores: Sequence[Sequence[float]]) -> Iterator[Tuple[Score, UidEndpoint]]:
+        """ iterate over all exemplar experts attached to current beam """
+        for score, prefix, suffixes in beam:
+            for next_coord, match in suffixes.items():
+                if len(grid_scores) == 1 and next_coord == FLAT_EXPERT:
+                    yield score, match
+                elif isinstance(match.uid, ExpertUID) and match.uid.count(UID_DELIMITER) == len(grid_scores):
+                    expert_coords = match.uid.split(UID_DELIMITER)[1:]
+                    if all(coord.isdigit() and 0 <= int(coord) < len(grid_scores[i])
+                           for i, coord in enumerate(expert_coords)):
+                        expert_score = sum(scores[coord] for scores, coord in zip(grid_scores, map(int, expert_coords)))
+                        yield expert_score, match
+                    else:
+                        logger.warning(f"Found incompatible expert coordinates: {expert_coords}")
+                else:
+                    logger.warning(f"Found incompatible expert UID: {match.uid}")
+
+    def batch_find_best_experts(self, batch_grid_scores: Sequence[Sequence[Sequence[float]]], beam_size: int,
+                                return_future: bool = False) -> Union[List[List[RemoteExpert]], MPFuture]:
+        """
+        Find and return :beam_size: active experts with highest scores, using both local cache and DHT
+
+        :param batch_grid_scores: scores predicted for each batch example and each dimension in the grid,
+        :type batch_grid_scores: list of arrays of shape (batch_size, grid_size[i])
+        :param beam_size: how many best experts should beam search return
+        :param return_future: if set to True, returns MPFuture that can be awaited to get the actual result
+        :returns: for each batch example, a list that contains *up to* beam_size RemoteExpert instances
+        """
+        return self.dht.run_coroutine(partial(
+            self._batch_find_best_experts, prefix=self.uid_prefix, batch_grid_scores=batch_grid_scores,
+            beam_size=beam_size, negative_caching=self.negative_caching, num_workers=self.num_workers), return_future)
+
+    @classmethod
+    async def _batch_find_best_experts(
+            cls, dht: DHT, node: DHTNode, prefix: str, batch_grid_scores: Sequence[Sequence[Tuple[float]]],
+            beam_size: int, negative_caching: bool, num_workers: Optional[int]) -> Sequence[Sequence[RemoteExpert]]:
+        batch_grid_scores = [[tuple(grid_score[i]) for grid_score in batch_grid_scores]
+                             for i in range(len(batch_grid_scores[0]))]
+        coros = [cls._find_best_experts(dht, node, prefix, grid_scores, beam_size, negative_caching, num_workers)
+                 for grid_scores in batch_grid_scores]
+        return await asyncio.gather(*coros)

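A hedged usage sketch for the MoEBeamSearcher class added above (assumes a reachable DHT peer and experts already declared under the prefix; the peer address, prefix and grid size are illustrative):

    import numpy as np
    import hivemind
    from hivemind.client.beam_search import MoEBeamSearcher

    dht = hivemind.DHT(initial_peers=["192.0.2.1:1337"], start=True)
    searcher = MoEBeamSearcher(dht, uid_prefix="ffn_expert.", grid_size=(32, 32))

    # one score vector per grid dimension, e.g. produced by a gating network
    grid_scores = [np.random.randn(32), np.random.randn(32)]
    best_experts = searcher.find_best_experts(grid_scores, beam_size=4)  # up to 4 RemoteExpert objects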
+ 11 - 12
hivemind/client/moe.py

@@ -12,6 +12,8 @@ from torch.autograd.function import once_differentiable

 import hivemind
 from hivemind.client.expert import RemoteExpert, DUMMY, _get_expert_stub
+from hivemind.server.expert_uid import UID_DELIMITER
+from hivemind.client.beam_search import MoEBeamSearcher
 from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
 from hivemind.utils import nested_pack, nested_flatten, serialize_torch_tensor, deserialize_torch_tensor
 from hivemind.utils.logging import get_logger
@@ -45,11 +47,8 @@ class RemoteMixtureOfExperts(nn.Module):
                  backward_k_min: int = 1, backward_timeout: Optional[float] = None, detect_anomalies: bool = False,
                  **dht_kwargs):
         super().__init__()
-        if not uid_prefix.endswith(hivemind.dht.UID_DELIMITER):
-            uid_prefix += hivemind.dht.UID_DELIMITER
-            logger.info(f"Prefix must end with '{hivemind.dht.UID_DELIMITER}'. New prefix: '{uid_prefix}' .")
-        assert hivemind.dht.is_valid_prefix(uid_prefix), f"Prefix '{uid_prefix}' is invalid."
-        self.dht, self.grid_size, self.uid_prefix, self.dht_kwargs = dht, grid_size, uid_prefix, dht_kwargs
+        self.dht = dht
+        self.beam_search = MoEBeamSearcher(dht, uid_prefix, grid_size, **dht_kwargs)
         self.k_best, self.k_min, self.backward_k_min = k_best, k_min, backward_k_min
         self.forward_timeout, self.backward_timeout = forward_timeout, backward_timeout
         self.timeout_after_k_min = timeout_after_k_min
@@ -75,10 +74,10 @@ class RemoteMixtureOfExperts(nn.Module):
             input_for_gating = input

         # 1. compute scores and find most appropriate experts with beam search
-        grid_scores = self.proj(input_for_gating).split_with_sizes(self.grid_size, dim=-1)
+        grid_scores = self.proj(input_for_gating).split_with_sizes(self.beam_search.grid_size, dim=-1)

-        chosen_experts: List[List[RemoteExpert]] = self.dht.batch_find_best_experts(
-            self.uid_prefix, [scores.detach().cpu().numpy() for scores in grid_scores], self.k_best, **self.dht_kwargs)
+        chosen_experts: List[List[RemoteExpert]] = self.beam_search.batch_find_best_experts(
+            [scores.detach().cpu().numpy() for scores in grid_scores], self.k_best)

         if self._expert_info is None:
             try:
@@ -121,8 +120,8 @@ class RemoteMixtureOfExperts(nn.Module):

         grid_indices = torch.zeros([len(flat_experts), len(grid_scores)], dtype=torch.int64)
         for i, expert in enumerate(flat_experts):
-            expert_indices = expert.uid[len(self.uid_prefix):]
-            expert_indices = list(map(int, expert_indices.split(hivemind.dht.UID_DELIMITER)))
+            expert_indices = expert.uid[len(self.beam_search.uid_prefix):]
+            expert_indices = list(map(int, expert_indices.split(UID_DELIMITER)))
             grid_indices[i] = torch.as_tensor(expert_indices, dtype=grid_indices.dtype)

         scores_per_dim = [
@@ -140,8 +139,8 @@ class RemoteMixtureOfExperts(nn.Module):
             # grab some expert to set ensemble output shape
             proj_device = self.proj.weight.device
             dummy_scores_concat = self.proj(torch.randn(1, self.proj.in_features, device=proj_device))
-            dummy_scores = dummy_scores_concat.cpu().split_with_sizes(self.grid_size, dim=-1)
-            dummy_experts = self.loop.run_until_complete(self.beam_search(dummy_scores, k_best=1))
+            dummy_scores = dummy_scores_concat.cpu().split_with_sizes(self.beam_search.grid_size, dim=-1)
+            dummy_experts = self.beam_search.find_best_experts(dummy_scores, beam_size=1)
             self._expert_info = dummy_experts[0].info
         return self._expert_info


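The next file deprecates the DHT's expert-specific methods. A sketch of the replacement calls, assuming the helpers are exported as hivemind.declare_experts / hivemind.get_experts as the deprecation shims below suggest (uid and endpoint are illustrative):

    import hivemind

    dht = hivemind.DHT(start=True)
    # stores "ffn_expert.98.76" and all its prefixes ("ffn_expert.98", ...) until expiration
    hivemind.declare_experts(dht, ["ffn_expert.98.76"], endpoint="127.0.0.1:1337")

    # the full uid resolves to a RemoteExpert (None if not found or expired)
    [expert] = hivemind.get_experts(dht, ["ffn_expert.98.76"])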
+ 57 - 367
hivemind/dht/__init__.py

@@ -15,53 +15,27 @@ The code is organized as follows:
 from __future__ import annotations
 import asyncio
 import ctypes
-import heapq
 import multiprocessing as mp
-import re
-from collections import deque
 from concurrent.futures import ThreadPoolExecutor
-from typing import List, Tuple, Optional, Sequence, Union, Dict, Deque, NamedTuple, Iterator, Set
-
+from typing import List, Optional, Sequence, Union, Callable, Awaitable, TypeVar

+import hivemind
 from hivemind.client import RemoteExpert
 from hivemind.dht.node import DHTNode, DHTID, DHTExpiration
-from hivemind.dht.routing import get_dht_time, DHTValue, DHTKey, Subkey
-from hivemind.utils import MPFuture, Endpoint, Hostname, get_logger, switch_to_uvloop, strip_port, ValueWithExpiration
+from hivemind.dht.routing import DHTValue, DHTKey, Subkey
+from hivemind.utils.networking import Hostname, Endpoint, strip_port
+from hivemind.utils import MPFuture, get_logger, switch_to_uvloop, ValueWithExpiration, await_cancelled, get_dht_time

 logger = get_logger(__name__)

-ExpertUID, ExpertPrefix, Coordinate, Score = str, str, int, float
-UidEndpoint = NamedTuple("UidEndpoint", [('uid', ExpertUID), ('endpoint', Endpoint)])
-UID_DELIMITER = '.'  # when declaring experts, DHT store all prefixes of that expert's uid, split over this prefix
-FLAT_EXPERT = -1  # grid prefix reserved for storing 1d expert uids. Used to speed up find_best_experts in 1d case.
-UID_PATTERN = re.compile('^(([^.])+)([.](?:[0]|([1-9]([0-9]*))))+$')  # e.g. ffn_expert.98.76.54 - prefix + some dims
-PREFIX_PATTERN = re.compile('^(([^.])+)([.](?:[0]|([1-9]([0-9]*))))*[.]$')  # e.g. expert. or ffn.45. (ends with ".")
-#  formally, prefixes = {uid.split(UID_DELIMITER)[:length] for length in range(1, uid.count(UID_DELIMITER) + 2)}
-
-
-def is_valid_uid(maybe_uid: str) -> bool:
-    """ An uid must contain a string expert type, followed by one or more .-separated numeric indices """
-    return bool(UID_PATTERN.fullmatch(maybe_uid))
-
-
-def is_valid_prefix(maybe_prefix: str) -> bool:
-    """ An uid prefix must contain a string expert type, followed by optional numeric indices and a trailing period """
-    return bool(PREFIX_PATTERN.fullmatch(maybe_prefix))
-
-
-def split_uid(uid_or_prefix: Union[ExpertUID, ExpertPrefix]) -> Tuple[ExpertPrefix, Coordinate]:
-    """ Separate an expert UID or prefix into a new ExpertPrefix and integer for the last coordinate """
-    uid_or_prefix = uid_or_prefix.rstrip(UID_DELIMITER)
-    pivot = uid_or_prefix.rindex(UID_DELIMITER) + 1
-    return uid_or_prefix[:pivot], int(uid_or_prefix[pivot:])
+ReturnType = TypeVar('ReturnType')


 class DHT(mp.Process):
     """
-    High-level interface to hivemind.dht that is designed to allow RemoteMixtureOfExperts to select best experts.
-
-    * hivemind servers periodically announce their experts via DHT.declare_experts
-    * trainers find most suitable experts via DHT.find_best_experts
+    A high-level interface to a hivemind DHT that runs a single DHT node in a background process.
+    * hivemind servers periodically announce their experts via declare_experts (dht_handler.py)
+    * trainers find most suitable experts via RemoteMixtureOfExperts (beam_search.py)

     :param initial_peers: one or multiple endpoints pointing to active DHT peers. Similar format to listen_on.
     :param listen_on: an interface for incoming connections, e.g. "127.0.0.1:*", "0.0.0.0:1234" or "ipv6:[::]:*"
@@ -70,60 +44,17 @@ class DHT(mp.Process):
     :param max_workers: declare_experts and get_experts will use up to this many parallel workers
         (but no more than one per key)
     :param expiration: experts declared from this node expire after this many seconds (default = 5 minutes)
-    :param receiver_threads: uses this many threads to await on input pipe. Default = 1 should be enough in most cases
-    :param negative_caching: if True, whenever DHT is unable to find an expert or prefix, it will cache the "no key"
-      result inside the DHT for :expiration: seconds. Caching only affects beam search and has three main effects:
-
-      1. Faster beam search under node failures: if there are inconsistencies in DHT keys, such as a prefix pointing to
-         a now-defunct expert, these inconsistencies will be overwritten by the first peer that stumbles upon them. As a
-         result, beam search will not have to wait for non-existent experts until the expiration of their DHT entries;
-      2. Delayed expert availability: Without negative cache, new experts are always immediately available for beam
-         search after they are published to the DHT. With negative cache, there are rare cases (e.g. when adding new
-         experts in place of recently defunct ones) when new experts will be initially invisible, but gradually become
-         visible to more peers as those peers refresh their cache. This process takes at most :expiration: seconds;
-      3. Faster beam search in very sparse grids: there is one edge case where negative cache will improve beam search
-         performance; If an expert grid is very sparse, there can be empty indices in the first grid dimension (i.e.
-         indices {i} such that _no_ experts that start with "{prefix}.{i}.*"). If so, the default beam search will
-         be very slow due to the way it forms initial beam. Beam search with negative cache enabled will run normally.
-         Though, this is a pathological case (e.g. only 90 experts in an oversized 100x100 grid) that should be avoided.
-
     :param kwargs: any other params will be forwarded to DHTNode upon creation
-
-    Each expert has an identifier in the form of {prefix}.{i}.{j}.{...}, e.g. "ffn_expert.98.76.54.32.10"
-    An expert identifier consists of:
-
-        * optional prefix that determines expert role, experiment name, etc.
-        * one or more integers that determine that expert's position in an N-dimensional grid
-
-    A hivemind.Server can ``DHT.declare_experts(expert_uids: List[str])`` to make its experts visible to everyone.
-    When declaring experts, DHT will store each expert's uid and all its prefixes until :expiration: (specified at init)
-    For instance, declaring "ffn_expert.98.76.54.32.10" will store the following keys in a DHT:
-    ``"ffn_expert.98", "ffn_expert.98.76", "ffn_expert.98.76.54", ..., "ffn_expert.98.76.54.32.10"``
-
-    In order to enable fast beam search, DHT maintains dictionaries of all active suffixes for every prefix
-    (e.g. "ffn_expert.98": {76: ffn_expert.98.76...., 123: ffn_expert.98.123..., 225: ffn_expert.98.225....}))
-
-    RemoteMixtureOfExperts can use these prefixes to find top-k most suitable experts with a left-to-right beam search.
-    For instance, consider RemoteMixtureOfExperts with prefix "ffn_expert" and grid size [100, 100, 100, 100, 100].
-    This MoE can query all experts with that prefix and arbitrary indices in 0...99 along each dimension.
-    However, not every expert in such 100^5 grid can be alive at a given moment of time (the grid size is redundant).
-    In order to find k best "alive" experts, MoE first ranks indices along the first dimension with its gating function.
-    It can then check which of those indices correspond to "alive" experts by querying keys such as "ffn_expert.98".
-
-    After selecting k best indices along first dimension, MoE moves to the second dimension.
-    It can find top-k index pairs (e.g. "expert.98.76") that use one of k best indices from the previous step.
-    This beam search explores one additional dimension per step and finds k best experts from across the DHT
-    in O(k * num_dimensions * dimension_size) time depending on the chosen grid dimensions.
     """
     """
 
 
     def __init__(self, listen_on: Endpoint = "0.0.0.0:*", initial_peers: Sequence[Endpoint] = (), *, start: bool,
     def __init__(self, listen_on: Endpoint = "0.0.0.0:*", initial_peers: Sequence[Endpoint] = (), *, start: bool,
                  daemon: bool = True, max_workers: Optional[int] = None, parallel_rpc: Optional[int] = None,
                  daemon: bool = True, max_workers: Optional[int] = None, parallel_rpc: Optional[int] = None,
-                 receiver_threads: int = 1, negative_caching: bool = True, expiration: float = 300, **kwargs):
+                 expiration: float = 300, **kwargs):
         super().__init__()
         assert not isinstance(initial_peers, str), "please specify a list/tuple of initial peers (even if there's one)"
         self.listen_on, self.initial_peers, self.kwargs = listen_on, initial_peers, kwargs
-        self.receiver_threads, self.max_workers, self.parallel_rpc = receiver_threads, max_workers, parallel_rpc
-        self.expiration, self.negative_caching = expiration, negative_caching
+        self.max_workers, self.parallel_rpc = max_workers, parallel_rpc
+        self.default_expiration = expiration
         self._port = mp.Value(ctypes.c_int32, 0)  # initialized after dht starts
         self._pipe, self.pipe = mp.Pipe(duplex=True)
         self.ready = mp.Event()
@@ -134,7 +65,7 @@ class DHT(mp.Process):
     def run(self) -> None:
         """ Serve DHT forever. This function will not return until DHT node is shut down """
         loop = switch_to_uvloop()
-        pipe_awaiter = ThreadPoolExecutor(self.receiver_threads)
+        pipe_awaiter = ThreadPoolExecutor(max_workers=1)

         async def _run():
             node = await DHTNode.create(
@@ -201,7 +132,11 @@ class DHT(mp.Process):
               subkey: Optional[Subkey] = None, return_future: bool = False, **kwargs) -> Union[bool, MPFuture]:
         """
         Find num_replicas best nodes to store (key, value) and store it there until expiration time.
-        :note: store is a simplified interface to store_many, all kwargs are be forwarded there
+
+        :param key: msgpack-serializable key to be associated with value until expiration.
+        :param value: msgpack-serializable value to be stored under a given key until expiration.
+        :param expiration_time: absolute time when the entry should expire, based on hivemind.get_dht_time()
+        :param subkey: if specified, add a value under that subkey instead of overwriting key (see DHTNode.store_many)
         :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
         :returns: True if store succeeds, False if it fails (due to no response or newer value)
         """
@@ -221,6 +156,39 @@ class DHT(mp.Process):
                 future.set_exception(e)
             raise

+    def run_coroutine(self, coro: Callable[[DHT, DHTNode], Awaitable[ReturnType]],
+                      return_future: bool = False) -> Union[ReturnType, MPFuture[ReturnType]]:
+        """
+        Execute an asynchronous function on a DHT participant and return results. This is meant as an interface
+         for running custom functions on the DHT for special cases (e.g. declare experts, beam search)
+
+        :param coro: async function to be executed. Receives 2 arguments: this DHT daemon and a running DHTNode
+        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
+        :returns: coroutine outputs or MPFuture for these outputs
+        :note: the coroutine will be executed inside the DHT process. As such, any changes to global variables or
+          DHT fields made by this coroutine will not be accessible from the host process.
+        :note: all time-consuming operations in coro should be asynchronous (e.g. asyncio.sleep instead of time.sleep)
+          or use asyncio.get_event_loop().run_in_executor(...) to prevent coroutine from blocking background DHT tasks
+        :note: when run_coroutine is called with return_future=True, the returned MPFuture can be cancelled to interrupt the task.
+        """
+        future, _future = MPFuture.make_pair()
+        self.pipe.send(('_run_coroutine', [], dict(coro=coro, future=_future)))
+        return future if return_future else future.result()
+
+    async def _run_coroutine(self, node: DHTNode, coro: Callable[[DHT, DHTNode], Awaitable[ReturnType]],
+                             future: MPFuture[ReturnType]):
+        main_task = asyncio.create_task(coro(self, node))
+        cancel_task = asyncio.create_task(await_cancelled(future))
+        try:
+            await asyncio.wait({main_task, cancel_task}, return_when=asyncio.FIRST_COMPLETED)
+            if future.cancelled():
+                main_task.cancel()
+            else:
+                future.set_result(await main_task)
+        except BaseException as e:
+            if not future.done():
+                future.set_exception(e)
+
     def get_visible_address(self, num_peers: Optional[int] = None, peers: Sequence[Endpoint] = ()) -> Hostname:
         """
         Get this machine's visible address by requesting other peers or using pre-specified network addresses.
@@ -274,289 +242,11 @@ class DHT(mp.Process):
             future.set_exception(ValueError(f"Can't get address: DHT node has no peers and no public endpoint."
                                             f" Please ensure the node is connected or specify peers=... manually."))

-    def declare_experts(self, uids: Sequence[ExpertUID], endpoint: Endpoint, wait: bool = True,
-                        timeout: Optional[float] = None) -> Dict[ExpertUID, bool]:
-        """
-        Make experts visible to all DHT peers; update timestamps if declared previously.
+    def declare_experts(self, uids, endpoint, wait: bool = True):
+        logger.warning("dht.declare_experts is scheduled for removal in 0.9.8, please use hivemind.declare_experts.",)
+        return hivemind.declare_experts(self, uids, endpoint, wait=wait)
-        :param uids: a list of expert ids to update
-        :param endpoint: endpoint that serves these experts, usually your server endpoint (e.g. "201.111.222.333:1337")
-        :param wait: if True, awaits for declaration to finish, otherwise runs in background
-        :param timeout: waits for the procedure to finish for up to this long, None means wait indefinitely
-        :returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
-        """
-        assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
-        for uid in uids:
-            assert is_valid_uid(uid), f"{uid} is not a valid expert uid. All uids must follow {UID_PATTERN.pattern}"
-        future, _future = MPFuture.make_pair() if wait else (None, None)
-        self.pipe.send(('_declare_experts', [], dict(uids=list(uids), endpoint=endpoint, future=_future)))
-        if wait:
-            return future.result(timeout)
-
-    async def _declare_experts(self, node: DHTNode, uids: List[ExpertUID], endpoint: Endpoint,
-                               future: Optional[MPFuture]) -> Dict[ExpertUID, bool]:
-        num_workers = len(uids) if self.max_workers is None else min(len(uids), self.max_workers)
-        expiration_time = get_dht_time() + self.expiration
-        data_to_store: Dict[Tuple[ExpertPrefix, Optional[Coordinate]], DHTValue] = {}
-        for uid in uids:
-            data_to_store[uid, None] = endpoint
-            prefix = uid if uid.count(UID_DELIMITER) > 1 else f'{uid}{UID_DELIMITER}{FLAT_EXPERT}'
-            for i in range(prefix.count(UID_DELIMITER) - 1):
-                prefix, last_coord = split_uid(prefix)
-                data_to_store[prefix, last_coord] = [uid, endpoint]
-
-        keys, maybe_subkeys, values = zip(*((key, subkey, value) for (key, subkey), value in data_to_store.items()))
-        store_ok = await node.store_many(keys, values, expiration_time, subkeys=maybe_subkeys, num_workers=num_workers)
-        if future is not None:
-            future.set_result(store_ok)
-        return store_ok
-
-    def get_experts(self, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration] = None,
+    def get_experts(self, uids, expiration_time: Optional[DHTExpiration] = None,
                     return_future: bool = False) -> List[Optional[RemoteExpert]]:
-        """
-        :param uids: find experts with these ids from across the DHT
-        :param expiration_time: if specified, return experts that expire no sooner than this (based on get_dht_time)
-        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
-        :returns: a list of [RemoteExpert if found else None]
-        """
-        assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
-        future, _future = MPFuture.make_pair()
-        self.pipe.send(('_get_experts', [], dict(uids=list(uids), expiration_time=expiration_time, future=_future)))
-        return future if return_future else future.result()
-
-    async def _get_experts(self, node: DHTNode, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration],
-                           future: Optional[MPFuture] = None) -> List[Optional[RemoteExpert]]:
-        if expiration_time is None:
-            expiration_time = get_dht_time()
-        num_workers = len(uids) if self.max_workers is None else min(len(uids), self.max_workers)
-        found: Dict[ExpertUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
-
-        experts: List[Optional[RemoteExpert]] = [None] * len(uids)
-        for i, uid in enumerate(uids):
-            if found[uid] is not None and isinstance(found[uid].value, Endpoint):
-                experts[i] = RemoteExpert(uid, found[uid].value)
-        if future:
-            future.set_result(experts)
-        return experts
-
-    def get_initial_beam(self, prefix: ExpertPrefix, scores: Sequence[float], beam_size: int,
-                         num_workers: Optional[int] = None, return_future: bool = False
-                         ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
-        """
-        :param prefix: search for experts whose uids start with this prefix
-        :param scores: prefer suffix coordinates that have highest scores
-        :param beam_size: select this many active suffixes with highest scores
-        :param num_workers: maintain up to this many concurrent DHT searches
-        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
-        :returns: a list of up to beam_size tuples of (prefix score, prefix itself, dict{suffix: example expert})
-        """
-        assert is_valid_prefix(prefix), f"prefix '{prefix}' is invalid, it must follow {PREFIX_PATTERN.pattern}"
-        future, _future = MPFuture.make_pair()
-        self.pipe.send(('_get_initial_beam', [], dict(prefix=prefix, scores=tuple(scores), beam_size=beam_size,
-                                                      num_workers=num_workers, future=_future)))
-        return future if return_future else future.result()
-
-    async def _get_initial_beam(self, node, prefix: ExpertPrefix, beam_size: int, scores: Tuple[float, ...],
-                                num_workers: Optional[int] = None, future: Optional[MPFuture] = None
-                                ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
-        num_workers = num_workers or self.max_workers or beam_size
-        beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = []
-        unattempted_indices: List[Coordinate] = sorted(range(len(scores)), key=scores.__getitem__)  # from worst to best
-        pending_tasks: Deque[Tuple[Coordinate, ExpertPrefix, asyncio.Task]] = deque()
-
-        while len(beam) < beam_size and (unattempted_indices or pending_tasks):
-            # dispatch additional tasks
-            while unattempted_indices and len(pending_tasks) < num_workers:
-                next_index = unattempted_indices.pop()  # note: this is best unattempted index because of sort order
-                next_best_prefix = f"{prefix}{next_index}{UID_DELIMITER}"
-                pending_tasks.append((next_index, next_best_prefix, asyncio.create_task(node.get(next_best_prefix))))
-
-            # await the next best prefix to be fetched
-            pending_best_index, pending_best_prefix, pending_task = pending_tasks.popleft()
-            try:
-                maybe_prefix_data = await pending_task
-                if maybe_prefix_data is not None and isinstance(maybe_prefix_data.value, dict):
-                    successors = {coord: UidEndpoint(*match.value) for coord, match in maybe_prefix_data.value.items()
-                                  if isinstance(coord, Coordinate) and isinstance(getattr(match, 'value', None), list)
-                                  and len(match.value) == 2}
-                    if successors:
-                        beam.append((scores[pending_best_index], pending_best_prefix, successors))
-                elif maybe_prefix_data is None and self.negative_caching:
-                    logger.debug(f"DHT negative caching: storing a 'no prefix' entry for {pending_best_prefix}")
-                    asyncio.create_task(node.store(pending_best_prefix, subkey=-1, value=None,
-                                                   expiration_time=get_dht_time() + self.expiration))
-
-            except asyncio.CancelledError:
-                for _, pending_task in pending_tasks:
-                    pending_task.cancel()
-                raise
-        if future:
-            future.set_result(beam)
-        return beam
-
-    def get_active_successors(self, prefixes: List[ExpertPrefix], grid_size: Optional[int] = None,
-                              num_workers: Optional[int] = None, return_future: bool = False
-                              ) -> Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]]:
-        """
-        :param prefixes: a list of prefix for which to find active successor uids
-        :param grid_size: if specified, only return successors if ther are in range [0, grid_size)
-        :param num_workers: how many parallel workers to use for DHTNode.get_many
-        :param return_future: if False (default), find and return successors. Otherwise return MPFuture and fill later.
-        :returns: for every expert, return a dict{active_next_coordinate: (matching_expert_uid, matching_endpoint)}
-        :note: if a prefix is not found, get_active_successors will return an empty dictionary for that prefix
-        """
-        assert not isinstance(prefixes, str), "Please send a list / tuple of expert prefixes."
-        for prefix in prefixes:
-            assert is_valid_prefix(prefix), f"prefix '{prefix}' is invalid, it must follow {PREFIX_PATTERN.pattern}"
-        future, _future = MPFuture.make_pair()
-        self.pipe.send(('_get_active_successors', [], dict(
-            prefixes=list(prefixes), grid_size=grid_size, num_workers=num_workers, future=_future)))
-        return future if return_future else future.result()
-
-    async def _get_active_successors(self, node: DHTNode, prefixes: List[ExpertPrefix], grid_size: Optional[int] = None,
-                                     num_workers: Optional[int] = None, future: Optional[MPFuture] = None
-                                     ) -> Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]]:
-        grid_size = grid_size or float('inf')
-        num_workers = num_workers or min(len(prefixes), self.max_workers or len(prefixes))
-        dht_responses = await node.get_many(keys=prefixes, num_workers=num_workers)
-        successors: Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]] = {}
-        for prefix, found in dht_responses.items():
-            if found and isinstance(found.value, dict):
-                successors[prefix] = {coord: UidEndpoint(*match.value) for coord, match in found.value.items()
-                                      if isinstance(coord, Coordinate) and 0 <= coord < grid_size
-                                      and isinstance(getattr(match, 'value', None), list) and len(match.value) == 2}
-            else:
-                successors[prefix] = {}
-                if found is None and self.negative_caching:
-                    logger.debug(f"DHT negative caching: storing a 'no prefix' entry for {prefix}")
-                    asyncio.create_task(node.store(prefix, subkey=-1, value=None,
-                                                   expiration_time=get_dht_time() + self.expiration))
-        if future:
-            future.set_result(successors)
-        return successors
-
-    def find_best_experts(self, prefix: ExpertPrefix, grid_scores: Sequence[Sequence[float]], beam_size: int,
-                          num_workers: Optional[int] = None, return_future: bool = False
-                          ) -> Union[List[RemoteExpert], MPFuture]:
-        """
-        Find and return :beam_size: active experts with highest scores, using both local cache and DHT
-
-        :param prefix: common prefix for all expert uids in grid
-        :param grid_scores: scores predicted for each dimension in the grid
-        :type grid_scores: a list of arrays of shape grid_size[i], one per grid dimension
-        :param beam_size: how many best experts should beam search return
-        :param num_workers: use up to this many concurrent workers to search DHT
-        :param return_future: if set to True, returns MPFuture that can be awaited to get the actual result
-        :returns: a list that contains *up to* beam_size RemoteExpert instances
-        """
-        assert len(grid_scores) > 0 and beam_size > 0
-        assert is_valid_prefix(prefix), f"prefix '{prefix}' is invalid, it must follow {PREFIX_PATTERN.pattern}"
-        future, _future = MPFuture.make_pair()
-        self.pipe.send(('_find_best_experts', [], dict(prefix=prefix, grid_scores=list(map(tuple, grid_scores)),
-                                                       beam_size=beam_size, num_workers=num_workers, future=_future)))
-        return future if return_future else future.result()
-
-    async def _find_best_experts(
-            self, node: DHTNode, prefix: str, grid_scores: List[Tuple[float]], beam_size: int,
-            num_workers: Optional[int] = None, future: Optional[MPFuture] = None, **kwargs) -> List[RemoteExpert]:
-        num_workers = num_workers or min(beam_size, self.max_workers or beam_size)
-
-        # form initial beam from top-k active L1 prefixes, each row is (score, uid prefix, possible suffixes)
-        beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = await self._get_initial_beam(
-            node, prefix, beam_size, grid_scores[0], min(beam_size, num_workers))
-
-        best_experts_heap: List[Tuple[Score, UidEndpoint]] = []  # max-heap of expert uids/endpoints ordered by scores
-        unique_experts: Set[ExpertUID] = set()
-
-        for dim_index in range(1, len(grid_scores) - 1):
-            for score, uid_endpoint in self._iterate_matching_experts(beam, grid_scores):
-                if uid_endpoint.uid not in unique_experts:
-                    push_and_maybe_pop = heapq.heappush if len(best_experts_heap) < beam_size else heapq.heappushpop
-                    push_and_maybe_pop(best_experts_heap, (score, uid_endpoint))
-                    unique_experts.add(uid_endpoint.uid)
-
-            # form new beam using successors from the current beam
-            dim_scores = grid_scores[dim_index]
-            best_active_pairs: List[Tuple[Score, ExpertPrefix]] = heapq.nlargest(beam_size, (
-                (prefix_score + dim_scores[next_coord], f"{prefix}{next_coord}{UID_DELIMITER}")
-                for prefix_score, prefix, suffixes in beam for next_coord in suffixes.keys()
-                if isinstance(next_coord, int) and 0 <= next_coord < len(dim_scores)))
-            _, best_uid_prefixes = zip(*best_active_pairs)
-
-            # search DHT for next step suffixes
-            successors = await self._get_active_successors(node, best_uid_prefixes, num_workers=num_workers)
-            beam = [(score, prefix, successors[prefix]) for score, prefix in best_active_pairs if successors[prefix]]
-            if not beam:
-                logger.warning(f"Beam search had to terminate prematurely because of empty beam (dim {dim_index})")
-                break
-
-        # add best experts from the final beam
-        for score, uid_endpoint in self._iterate_matching_experts(beam, grid_scores):
-            if uid_endpoint.uid not in unique_experts:
-                push_and_maybe_pop = heapq.heappush if len(best_experts_heap) < beam_size else heapq.heappushpop
-                push_and_maybe_pop(best_experts_heap, (score, uid_endpoint))
-                unique_experts.add(uid_endpoint.uid)
-
-        best_experts = [RemoteExpert(*uid_endpoint) for score, uid_endpoint in sorted(best_experts_heap, reverse=True)]
-        if future is not None:
-            future.set_result(best_experts)
-        return best_experts
-
-    @staticmethod
-    def _iterate_matching_experts(beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]],
-                                  grid_scores: Sequence[Sequence[float]]) -> Iterator[Tuple[Score, UidEndpoint]]:
-        """ iterate over all exemplar experts attached to current beam """
-        for score, prefix, suffixes in beam:
-            for next_coord, match in suffixes.items():
-                if len(grid_scores) == 1 and next_coord == FLAT_EXPERT:
-                    yield score, match
-                elif isinstance(match.uid, ExpertUID) and match.uid.count(UID_DELIMITER) == len(grid_scores):
-                    expert_coords = match.uid.split(UID_DELIMITER)[1:]
-                    if all(coord.isdigit() and 0 <= int(coord) < len(grid_scores[i])
-                           for i, coord in enumerate(expert_coords)):
-                        expert_score = sum(scores[coord] for scores, coord in zip(grid_scores, map(int, expert_coords)))
-                        yield expert_score, match
-                    else:
-                        logger.warning(f"Found incompatible expert coordinates: {expert_coords}")
-                else:
-                    logger.warning(f"Found incompatible expert UID: {match.uid}")
-
-    def batch_find_best_experts(
-            self, prefix: str, batch_grid_scores: Sequence[Sequence[Sequence[float]]], beam_size: int, *,
-            workers_per_sample: Optional[int] = None, return_future=False) -> Union[List[List[RemoteExpert]], MPFuture]:
-        """
-        Find and return :beam_size: active experts with highest scores, using both local cache and DHT
-
-        :param prefix: common prefix for all expert uids in grid
-        :param batch_grid_scores: scores predicted for each batch example and each dimension in the grid
-        :type batch_grid_scores: a list of arrays of shape (batch_size, grid_size[i])
-        :param beam_size: how many best experts should beam search return
-        :param workers_per_sample: use up to this many concurrent workers for every sample in batch
-        :param return_future: if set to True, returns MPFuture that can be awaited to get the actual result
-        :returns: for each batch example, a list that contains *up to* beam_size RemoteExpert instances
-        """
-        future, _future = MPFuture.make_pair()
-        self.pipe.send(('_batch_find_best_experts', [], dict(prefix=prefix, batch_grid_scores=batch_grid_scores,
-                                                             beam_size=beam_size, workers_per_sample=workers_per_sample,
-                                                             future=_future)))
-        return future if return_future else future.result()
-
-    async def _batch_find_best_experts(
-            self, node: DHTNode, prefix: str, batch_grid_scores: Sequence[Sequence[Tuple[float]]], beam_size: int,
-            workers_per_sample: Optional[int] = None, future: Optional[MPFuture] = None) -> List[List[RemoteExpert]]:
-
-        batch_grid_scores = [[tuple(grid_score[i]) for grid_score in batch_grid_scores]
-                             for i in range(len(batch_grid_scores[0]))]
-        coros = [self._find_best_experts(node, prefix, grid_scores, beam_size, workers_per_sample)
-                 for grid_scores in batch_grid_scores]
-
-        best_experts_batch = await asyncio.gather(*coros)
-        if future is not None:
-            future.set_result(best_experts_batch)
-        return best_experts_batch
+        logger.warning("dht.get_experts is scheduled for removal in 0.9.8, please use hivemind.get_experts.")
+        return hivemind.get_experts(self, uids, expiration_time, return_future)
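
A minimal migration sketch for the deprecation above (assuming a started hivemind.DHT instance; the uids are placeholders):

    import hivemind

    dht = hivemind.DHT(start=True)
    # deprecated instance method: still works, but emits the warning above
    experts = dht.get_experts(['expert.1', 'expert.2'])
    # replacement: a module-level function that takes the DHT instance explicitly
    experts = hivemind.get_experts(dht, ['expert.1', 'expert.2'])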

+ 5 - 4
hivemind/server/__init__.py

@@ -13,9 +13,10 @@ import torch
 
 import hivemind
 from hivemind.dht import DHT
+from hivemind.server.expert_uid import UID_DELIMITER
 from hivemind.server.checkpoints import CheckpointSaver, load_weights, dir_is_correct
 from hivemind.server.connection_handler import ConnectionHandler
-from hivemind.server.dht_handler import DHTHandlerThread
+from hivemind.server.dht_handler import DHTHandlerThread, declare_experts, get_experts
 from hivemind.server.expert_backend import ExpertBackend
 from hivemind.server.layers import name_to_block, name_to_input
 from hivemind.server.runtime import Runtime
@@ -287,10 +288,10 @@ def generate_uids_from_pattern(num_experts: int, expert_pattern: Optional[str],
 
     def _generate_uid():
         if expert_pattern is None:
-            return f"expert{hivemind.dht.UID_DELIMITER}{attempts_per_expert * num_experts - remaining_attempts}"
+            return f"expert{UID_DELIMITER}{attempts_per_expert * num_experts - remaining_attempts}"
 
         uid = []
-        for block in expert_pattern.split(hivemind.dht.UID_DELIMITER):
+        for block in expert_pattern.split(UID_DELIMITER):
            try:
                 if '[' not in block and ']' not in block:
                     uid.append(block)
@@ -303,7 +304,7 @@ def generate_uids_from_pattern(num_experts: int, expert_pattern: Optional[str],
                 raise e
             except Exception as e:
                 raise ValueError(f"Expert pattern {expert_pattern} has invalid block {block}, {e}")
-        return hivemind.dht.UID_DELIMITER.join(uid)
+        return UID_DELIMITER.join(uid)
 
     while remaining_attempts > 0 and len(found_uids) < num_experts:
 

+ 68 - 5
hivemind/server/dht_handler.py

@@ -1,8 +1,12 @@
 import threading
-import time
+from functools import partial
+from typing import Sequence, Dict, List, Tuple, Optional
 
-from hivemind.dht import DHT
-from hivemind.utils import Endpoint, get_port
+from hivemind.dht import DHT, DHTNode, DHTExpiration, DHTValue
+from hivemind.client.expert import RemoteExpert
+from hivemind.server.expert_uid import (ExpertUID, ExpertPrefix, FLAT_EXPERT, Coordinate,
+                                        UID_DELIMITER, UID_PATTERN, is_valid_uid, split_uid)
+from hivemind.utils import Endpoint, get_dht_time, get_port
 
 
 class DHTHandlerThread(threading.Thread):
@@ -16,6 +20,65 @@ class DHTHandlerThread(threading.Thread):
         self.stop = threading.Event()
 
     def run(self) -> None:
-        self.dht.declare_experts(self.experts.keys(), self.endpoint)
+        declare_experts(self.dht, self.experts.keys(), self.endpoint)
         while not self.stop.wait(self.update_period):
-            self.dht.declare_experts(self.experts.keys(), self.endpoint)
+            declare_experts(self.dht, self.experts.keys(), self.endpoint)
+
+
+def declare_experts(dht: DHT, uids: Sequence[ExpertUID], endpoint: Endpoint,
+                    wait: bool = True) -> Dict[ExpertUID, bool]:
+    """
+    Make experts visible to all DHT peers; update timestamps if declared previously.
+
+    :param uids: a list of expert ids to update
+    :param endpoint: endpoint that serves these experts, usually your server endpoint (e.g. "201.111.222.333:1337")
+    :param wait: if True, awaits for declaration to finish, otherwise runs in background
+    :returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
+    """
+    assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
+    for uid in uids:
+        assert is_valid_uid(uid), f"{uid} is not a valid expert uid. All uids must follow {UID_PATTERN.pattern}"
+    return dht.run_coroutine(partial(_declare_experts, uids=list(uids), endpoint=endpoint), return_future=not wait)
+
+
+async def _declare_experts(dht: DHT, node: DHTNode, uids: List[ExpertUID], endpoint: Endpoint) -> Dict[ExpertUID, bool]:
+    num_workers = len(uids) if dht.max_workers is None else min(len(uids), dht.max_workers)
+    expiration_time = get_dht_time() + dht.default_expiration  # TODO use local expiration
+    data_to_store: Dict[Tuple[ExpertPrefix, Optional[Coordinate]], DHTValue] = {}
+    for uid in uids:
+        data_to_store[uid, None] = endpoint
+        prefix = uid if uid.count(UID_DELIMITER) > 1 else f'{uid}{UID_DELIMITER}{FLAT_EXPERT}'
+        for i in range(prefix.count(UID_DELIMITER) - 1):
+            prefix, last_coord = split_uid(prefix)
+            data_to_store[prefix, last_coord] = [uid, endpoint]
+
+    keys, maybe_subkeys, values = zip(*((key, subkey, value) for (key, subkey), value in data_to_store.items()))
+    store_ok = await node.store_many(keys, values, expiration_time, subkeys=maybe_subkeys, num_workers=num_workers)
+    return store_ok
+
+
+def get_experts(dht: DHT, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration] = None,
+                return_future: bool = False) -> List[Optional[RemoteExpert]]:
+    """
+    :param uids: find experts with these ids from across the DHT
+    :param expiration_time: if specified, return experts that expire no sooner than this (based on get_dht_time)
+    :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
+    :returns: a list of [RemoteExpert if found else None]
+    """
+    assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
+    return dht.run_coroutine(partial(_get_experts, uids=list(uids), expiration_time=expiration_time), return_future)
+
+
+async def _get_experts(dht: DHT, node: DHTNode, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration]
+                       ) -> List[Optional[RemoteExpert]]:
+    if expiration_time is None:
+        expiration_time = get_dht_time()
+    num_workers = len(uids) if dht.max_workers is None else min(len(uids), dht.max_workers)
+    found: Dict[ExpertUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
+
+    experts: List[Optional[RemoteExpert]] = [None] * len(uids)
+    for i, uid in enumerate(uids):
+        if found[uid] is not None and isinstance(found[uid].value, Endpoint):
+            experts[i] = RemoteExpert(uid, found[uid].value)
+    return experts
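
A short usage sketch for the two functions above (the endpoint and uids are placeholders; this mirrors the periodic round trip that DHTHandlerThread performs):

    import hivemind
    from hivemind.server.dht_handler import declare_experts, get_experts

    dht = hivemind.DHT(start=True)
    # store uid -> endpoint mappings (plus all uid prefixes) in the DHT
    assert all(declare_experts(dht, ['ffn.0', 'ffn.1'], endpoint='127.0.0.1:1337').values())
    # look the experts up again; uids that were never declared come back as None
    found = get_experts(dht, ['ffn.0', 'ffn.999'])
    assert found[0].endpoint == '127.0.0.1:1337' and found[1] is None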

+ 32 - 0
hivemind/server/expert_uid.py

@@ -0,0 +1,32 @@
+import re
+from typing import NamedTuple, Union, Tuple
+
+from hivemind.utils.networking import Endpoint
+
+
+ExpertUID, ExpertPrefix, Coordinate, Score = str, str, int, float
+UidEndpoint = NamedTuple("UidEndpoint", [('uid', ExpertUID), ('endpoint', Endpoint)])
+UID_DELIMITER = '.'  # when declaring experts, the DHT stores all prefixes of that expert's uid, split over this delimiter
+FLAT_EXPERT = -1  # grid prefix reserved for storing 1d expert uids. Used to speed up find_best_experts in 1d case.
+UID_PATTERN = re.compile('^(([^.])+)([.](?:[0]|([1-9]([0-9]*))))+$')  # e.g. ffn_expert.98.76.54 - prefix + some dims
+PREFIX_PATTERN = re.compile('^(([^.])+)([.](?:[0]|([1-9]([0-9]*))))*[.]$')  # e.g. expert. or ffn.45. (ends with ".")
+#  formally, prefixes = {uid.split(UID_DELIMITER)[:length] for length in range(1, uid.count(UID_DELIMITER) + 2)}
+
+
+def is_valid_uid(maybe_uid: str) -> bool:
+    """ An uid must contain a string expert type, followed by one or more .-separated numeric indices """
+    return bool(UID_PATTERN.fullmatch(maybe_uid))
+
+
+def is_valid_prefix(maybe_prefix: str) -> bool:
+    """ An uid prefix must contain a string expert type, followed by optional numeric indices and a trailing period """
+    return bool(PREFIX_PATTERN.fullmatch(maybe_prefix))
+
+
+def split_uid(uid_or_prefix: Union[ExpertUID, ExpertPrefix]) -> Tuple[ExpertPrefix, Coordinate]:
+    """ Separate an expert UID or prefix into a new ExpertPrefix and integer for the last coordinate """
+    uid_or_prefix = uid_or_prefix.rstrip(UID_DELIMITER)
+    pivot = uid_or_prefix.rindex(UID_DELIMITER) + 1
+    return uid_or_prefix[:pivot], int(uid_or_prefix[pivot:])
+
+
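
A quick sanity-check sketch; the expected values follow directly from the patterns defined above:

    from hivemind.server.expert_uid import is_valid_uid, is_valid_prefix, split_uid

    assert is_valid_uid('ffn.12.34')                  # expert type plus numeric coordinates
    assert not is_valid_uid('ffn.12.')                # a trailing delimiter makes it a prefix
    assert is_valid_prefix('ffn.12.') and not is_valid_prefix('ffn.12')
    assert split_uid('ffn.12.34') == ('ffn.12.', 34)  # (prefix, last coordinate)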

+ 11 - 1
hivemind/utils/asyncio.py

@@ -1,4 +1,4 @@
-from typing import TypeVar, AsyncIterator, Union, AsyncIterable
+from typing import TypeVar, AsyncIterator, Union, AsyncIterable, Awaitable
 import asyncio
 import uvloop
 T = TypeVar('T')
@@ -32,3 +32,13 @@ async def achain(*async_iters: AsyncIterable[T]) -> AsyncIterator[T]:
     for aiter in async_iters:
         async for elem in aiter:
             yield elem
+
+
+async def await_cancelled(awaitable: Awaitable) -> bool:
+    try:
+        await awaitable
+        return False
+    except asyncio.CancelledError:
+        return True
+    except BaseException:
+        return False
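
A minimal sketch of how await_cancelled can be used (the sleeping task is purely illustrative):

    import asyncio
    from hivemind.utils.asyncio import await_cancelled

    async def demo():
        task = asyncio.create_task(asyncio.sleep(10))
        task.cancel()
        # True: awaiting the cancelled task raised asyncio.CancelledError
        assert await await_cancelled(task)

    asyncio.run(demo())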

+ 7 - 5
hivemind/utils/mpfuture.py

@@ -6,12 +6,14 @@ import concurrent.futures._base as base
 
 import asyncio
 from functools import lru_cache
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Generic, TypeVar
 
 from hivemind.utils.threading import run_in_background
 
+ResultType = TypeVar('ResultType')
 
-class MPFuture(base.Future):
+
+class MPFuture(base.Future, Generic[ResultType]):
     """ Multiprocessing version of concurrent.futures.Future. Can also be awaited like asyncio.Future """
 
     TERMINAL_STATES = {base.FINISHED, base.CANCELLED, base.CANCELLED_AND_NOTIFIED}
@@ -74,7 +76,7 @@ class MPFuture(base.Future):
         except TimeoutError:
             pass
 
-    def set_result(self, result):
+    def set_result(self, result: ResultType):
         self._sync_updates()
         if self._state in self.TERMINAL_STATES:
             raise RuntimeError(f"Can't set_result to a future that is in {self._state}")
@@ -105,13 +107,13 @@ class MPFuture(base.Future):
         self._state, self._exception = base.CANCELLED, base.CancelledError()
         return self._send_updates()
 
-    def result(self, timeout: Optional[float] = None):
+    def result(self, timeout: Optional[float] = None) -> ResultType:
         self._await_terminal_state(timeout)
         if self._exception is not None:
             raise self._exception
         return self._result
 
-    def exception(self, timeout=None):
+    def exception(self, timeout=None) -> BaseException:
         self._await_terminal_state(timeout)
         if self._state == base.CANCELLED:
             raise base.CancelledError()
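
The Generic[ResultType] parameter only affects static typing; runtime behavior is unchanged. A sketch of the intended usage, with both ends of the pair in one process for simplicity (normally set_result happens in a worker process):

    from hivemind.utils.mpfuture import MPFuture

    host_side, worker_side = MPFuture.make_pair()  # two futures connected over a pipe
    worker_side.set_result(42)                     # typically called in another process
    assert host_side.result() == 42                # result() is now annotated as ResultType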

+ 4 - 2
tests/benchmark_dht.py

@@ -5,6 +5,7 @@ import time
 from tqdm import trange
 
 import hivemind
+import hivemind.server.expert_uid
 from hivemind.utils.threading import increase_file_limit
 
 logger = hivemind.get_logger(__name__)
@@ -42,7 +43,8 @@ def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_b
     for start in trange(0, num_experts, expert_batch_size):
         store_start = time.perf_counter()
         endpoints.append(random_endpoint())
-        successes = store_peer.declare_experts(expert_uids[start: start + expert_batch_size], endpoints[-1]).values()
+        store_ok = hivemind.declare_experts(store_peer, expert_uids[start: start + expert_batch_size], endpoints[-1])
+        successes = store_ok.values()
         total_store_time += time.perf_counter() - store_start
 
         total_stores += len(successes)
@@ -60,7 +62,7 @@ def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_b
 
     for start in trange(0, len(expert_uids), expert_batch_size):
         get_start = time.perf_counter()
-        get_result = get_peer.get_experts(expert_uids[start: start + expert_batch_size])
+        get_result = hivemind.get_experts(get_peer, expert_uids[start: start + expert_batch_size])
         total_get_time += time.perf_counter() - get_start
 
         for i, expert in enumerate(get_result):

+ 70 - 142
tests/test_dht.py

@@ -1,175 +1,103 @@
+import asyncio
 import random
-import numpy as np
+import time
+
 import pytest
-import asyncio
 
 import hivemind
-from hivemind import LOCALHOST, UidEndpoint, strip_port
+from hivemind import LOCALHOST, strip_port
 
 
 @pytest.mark.forked
-def test_store_get_experts():
-    peers = [hivemind.DHT(start=True)]
+def test_get_store():
+    peers = []
     for i in range(10):
         neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
         peers.append(hivemind.DHT(initial_peers=neighbors_i, start=True))
 
-    you: hivemind.dht.DHT = random.choice(peers)
-    theguyshetoldyounottoworryabout: hivemind.dht.DHT = random.choice(peers)
+    node1, node2 = random.sample(peers, 2)
+    assert node1.store('key1', 'value1', expiration_time=hivemind.get_dht_time() + 30)
+    assert node1.get('key1').value == 'value1'
+    assert node2.get('key1').value == 'value1'
+    assert node2.get('key2') is None
+
+    future = node1.get('foo', return_future=True)
+    assert future.result() is None
 
-    expert_uids = [f"my_expert.{i}" for i in range(110)]
-    batch_size = 10
-    for batch_start in range(0, len(expert_uids), batch_size):
-        you.declare_experts(expert_uids[batch_start: batch_start + batch_size], 'localhost:1234')
+    future = node1.get('foo', return_future=True)
+    future.cancel()
 
-    found = theguyshetoldyounottoworryabout.get_experts(random.sample(expert_uids, 5) + ['foo', 'bar'])
-    assert all(res is not None for res in found[:-2]), "Could not find some existing experts"
-    assert all(res is None for res in found[-2:]), "Found non-existing experts"
+    assert node2.store('key1', 123, expiration_time=hivemind.get_dht_time() + 31)
+    assert node2.store('key2', 456, expiration_time=hivemind.get_dht_time() + 32)
+    assert node1.get('key1', latest=True).value == 123
+    assert node1.get('key2').value == 456
 
-    that_guys_expert, that_guys_port = "my_other_expert.1337", random.randint(1000, 9999)
-    theguyshetoldyounottoworryabout.declare_experts([that_guys_expert], f'that_host:{that_guys_port}')
-    you_notfound, you_found = you.get_experts(['foobar', that_guys_expert])
-    assert isinstance(you_found, hivemind.RemoteExpert)
-    assert you_found.endpoint == f'that_host:{that_guys_port}'
+    assert node1.store('key2', subkey='subkey1', value=789, expiration_time=hivemind.get_dht_time() + 32)
+    assert node2.store('key2', subkey='subkey2', value='pew', expiration_time=hivemind.get_dht_time() + 32)
+    found_dict = node1.get('key2', latest=True).value
+    assert isinstance(found_dict, dict) and len(found_dict) == 2
+    assert found_dict['subkey1'].value == 789 and found_dict['subkey2'].value == 'pew'
 
     for peer in peers:
         peer.shutdown()
 
 
-@pytest.mark.forked
-def test_dht_get_address(addr=LOCALHOST, dummy_endpoint='123.45.67.89:*'):
-    node1 = hivemind.DHT(start=True, listen_on=f"0.0.0.0:*")
-    node2 = hivemind.DHT(start=True, listen_on=f"0.0.0.0:*", initial_peers=[f"{addr}:{node1.port}"])
-    node3 = hivemind.DHT(start=True, listen_on=f"0.0.0.0:*", initial_peers=[f"{addr}:{node2.port}"])
-    assert addr in node3.get_visible_address(num_peers=2)
+async def dummy_dht_coro(self, node):
+    return 'pew'
 
-    node4 = hivemind.DHT(start=True, listen_on=f"0.0.0.0:*")
-    with pytest.raises(ValueError):
-        node4.get_visible_address()
-    assert node4.get_visible_address(peers=[f'{addr}:{node1.port}']).endswith(addr)
 
-    node5 = hivemind.DHT(start=True, listen_on=f"0.0.0.0:*", endpoint=f"{dummy_endpoint}")
-    assert node5.get_visible_address() == strip_port(dummy_endpoint)
+async def dummy_dht_coro_error(self, node):
+    raise ValueError("Oops, i did it again...")
 
 
 
 
-@pytest.mark.forked
-def test_beam_search(dht_size=20, total_experts=128, batch_size=32, initial_peers=3, beam_size=4, parallel_rpc=16,
-                     grid_dims=(32, 32, 32)):
-    dht = []
-    for i in range(dht_size):
-        neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(dht, min(initial_peers, len(dht)))]
-        dht.append(hivemind.DHT(start=True, expiration=999999, initial_peers=neighbors_i, parallel_rpc=parallel_rpc))
-
-    real_experts = sorted({
-        'expert.' + '.'.join([str(random.randint(0, dim - 1)) for dim in grid_dims])
-        for _ in range(total_experts)
-    })
-    for batch_start in range(0, len(real_experts), batch_size):
-        random.choice(dht).declare_experts(
-            real_experts[batch_start: batch_start + batch_size], wait=True,
-            endpoint=f"host{batch_start // batch_size}:{random.randint(0, 65536)}")
-
-    neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(dht, min(initial_peers, len(dht)))]
-    you = hivemind.DHT(start=True, expiration=999999, initial_peers=neighbors_i, parallel_rpc=parallel_rpc)
-
-    for i in range(50):
-        topk_experts = you.find_best_experts('expert.', [np.random.randn(dim) for dim in grid_dims], beam_size=beam_size)
-        assert all(isinstance(e, hivemind.RemoteExpert) for e in topk_experts)
-        assert len(topk_experts) == beam_size
+async def dummy_dht_coro_stateful(self, node):
+    self._x_dummy = getattr(self, '_x_dummy', 123) + 1
+    return self._x_dummy
 
-    for i in range(10):
-        batch_experts = you.batch_find_best_experts('expert.', [np.random.randn(batch_size, dim) for dim in grid_dims],
-                                                    beam_size=beam_size)
-        assert isinstance(batch_experts, list) and len(batch_experts) == batch_size
-        assert all(isinstance(e, hivemind.RemoteExpert) for experts in batch_experts for e in experts)
-        assert all(len(experts) == beam_size for experts in batch_experts)
 
+async def dummy_dht_coro_long(self, node):
+    await asyncio.sleep(0.25)
+    return self._x_dummy ** 2
 
-@pytest.mark.forked
-def test_dht_single_node():
-    node = hivemind.DHT(start=True, expiration=999)
-
-    assert all(node.declare_experts(['expert.1', 'expert.2', 'expert.3'], f"{hivemind.LOCALHOST}:1337").values())
-    assert len(node.declare_experts(["ffn.1", "ffn.2"], endpoint="that_place")) == 4
-    assert len(node.declare_experts(['e.1.2.3', 'e.1.2.5', 'e.2.0'], f"{hivemind.LOCALHOST}:42")) == 7
-
-    for expert in node.get_experts(['expert.3', 'expert.2']):
-        assert expert.endpoint == f"{hivemind.LOCALHOST}:1337"
-
-    assert all(node.declare_experts(['expert.5', 'expert.2'], f"{hivemind.LOCALHOST}:1337").values())
-    found_experts = node.find_best_experts('expert.', [(0., 1., 2., 3., 4., 5., 6., 7., 8.)], beam_size=2)
-    assert len(found_experts) == 2 and [expert.uid for expert in found_experts] == ['expert.5', 'expert.3']
-
-    successors = node.get_active_successors(['e.1.2.', 'e.2.', 'e.4.5.'])
-    assert len(successors['e.1.2.']) == 2
-    assert successors['e.1.2.'][3] == UidEndpoint('e.1.2.3', f'{LOCALHOST}:42')
-    assert successors['e.1.2.'][5] == UidEndpoint('e.1.2.5', f'{LOCALHOST}:42')
-    assert len(successors['e.2.']) == 1 and successors['e.2.'][0] == UidEndpoint('e.2.0', f'{LOCALHOST}:42')
-    assert successors['e.4.5.'] == {}
-
-    initial_beam = node.get_initial_beam('expert.', (3, 2, 1, 0, -1, -2, -3), beam_size=3)
-    assert len(initial_beam) == 3
-    assert initial_beam[0][:2] == (2.0, 'expert.1.')
-    assert initial_beam[1][:2] == (1.0, 'expert.2.')
-    assert initial_beam[2][:2] == (0.0, 'expert.3.')
-
-    with pytest.raises(AssertionError):
-        node.find_best_experts('expert', [(0., 1., 2., 3., 4., 5., 6., 7., 8.)], beam_size=2)
-
-    with pytest.raises(AssertionError):
-        node.find_best_experts('expert.1', [(0., 1., 2., 3., 4., 5., 6., 7., 8.)], beam_size=2)
-
-    with pytest.raises(AssertionError):
-        node.get_active_successors(['e.1.2.', 'e.2', 'e.4.5.'])
-
-    with pytest.raises(AssertionError):
-        node.get_initial_beam('expert', (3, 2, 1, 0, -1, -2, -3), beam_size=3)
-
-
-def test_uid_patterns():
-    valid_experts = ["expert.1", "expert.0", "expert.0.0.1", "expert.1337", "ffn.12.34.56.78.90",
-                     "transformer.3.2.1.0", "transformer_encoder.2", "transformer::encoder.2", "T®@nsf0rmE®🤗.321",
-                     "🤗.321", "0.1.2", "00.1.2", "7070.3.2.1.0", "block2.1.23", "LAYER.1.0.1"]
-    valid_prefixes = ["expert.", "e.1.", "e.2.", "e.1.2.3.", "ololo.123.456.789.10."]
-    valid_prefixes.extend([f"{uid}." for uid in valid_experts])
-    valid_prefixes.extend([hivemind.split_uid(uid)[0] for uid in valid_experts])
-    for uid in valid_experts:
-        assert hivemind.is_valid_uid(uid), f"UID {uid} is valid, but was perceived as invalid"
-    for pfx in valid_prefixes:
-        assert hivemind.is_valid_prefix(pfx), f"Prefix {pfx} is valid, but was perceived as invalid"
-
-    invalid = ["", ".", "expert.-1", "xxx.a", "expert.1x", "expert_ffn.1.abc1", "some.123.01", "expert.123.01",
-               "e1", "e..1", "e", "e.1.2.3..4", "ffn.1..1", ".123", ".1.2.3.", ".expert", "transformer.encoder.2",
-               "T®@nsf0rmE®.🤗.321", "layer::123", "expert.0.1.2.suffix", "0.1.2.suffix", "expert.1 something",
-               "expert.1\n", "expert.1\n2", "expert.1 ", "expert.1\nexpert.2", "'expert.1'", '"expert.1"']
-    invalid_experts = invalid + valid_prefixes + ["0", "123456"]
-    invalid_prefixes = invalid + valid_experts + ["expert", ".🤗", ".expert"]
-    for uid in invalid_experts:
-        assert not hivemind.is_valid_uid(uid), f"UID {uid} is not valid, but was perceived as valid"
-    for pfx in invalid_prefixes:
-        assert not hivemind.is_valid_prefix(pfx), f"Prefix {pfx} is not valid, but was perceived as valid"
+
+async def dummy_dht_coro_for_cancel(self, node):
+    self._x_dummy = -100
+    await asyncio.sleep(0.5)
+    self._x_dummy = 999
 
 
 @pytest.mark.forked
-@pytest.mark.asyncio
-async def test_negative_caching():
-    peers = []
-    for i in range(10):
-        neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
-        peers.append(hivemind.DHT(initial_peers=neighbors_i, negative_caching=False, cache_locally=False, start=True))
+def test_run_coroutine():
+    dht = hivemind.DHT(start=True)
+    assert dht.run_coroutine(dummy_dht_coro) == 'pew'
+
+    with pytest.raises(ValueError):
+        res = dht.run_coroutine(dummy_dht_coro_error)
+
+    bg_task = dht.run_coroutine(dummy_dht_coro_long, return_future=True)
+    assert dht.run_coroutine(dummy_dht_coro_stateful) == 124
+    assert dht.run_coroutine(dummy_dht_coro_stateful) == 125
+    assert dht.run_coroutine(dummy_dht_coro_stateful) == 126
+    assert not hasattr(dht, '_x_dummy')
+    assert bg_task.result() == 126 ** 2
+
+    future = dht.run_coroutine(dummy_dht_coro_for_cancel, return_future=True)
+    time.sleep(0.25)
+    future.cancel()
+    assert dht.run_coroutine(dummy_dht_coro_stateful) == -99
 
-    normal_peer, writer_peer = random.sample(peers, 2)
 
-    neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
-    neg_caching_peer = hivemind.DHT(initial_peers=neighbors_i, negative_caching=True, cache_locally=False, start=True)
+@pytest.mark.forked
+def test_dht_get_address(addr=LOCALHOST, dummy_endpoint='123.45.67.89:*'):
+    node1 = hivemind.DHT(start=True, listen_on=f"0.0.0.0:*")
+    node2 = hivemind.DHT(start=True, listen_on=f"0.0.0.0:*", initial_peers=[f"{addr}:{node1.port}"])
+    node3 = hivemind.DHT(start=True, listen_on=f"0.0.0.0:*", initial_peers=[f"{addr}:{node2.port}"])
+    assert addr in node3.get_visible_address(num_peers=2)
 
-    assert all(writer_peer.declare_experts(['ffn.1.2.3', 'ffn.3.4.5'], 'myaddr:1234').values())
-    # get prefixes by the peer with negative caching. Cache "no data" entries for ffn.0.*, ffn.2.*, ffn.4.*, ffn.5.*
-    assert len(neg_caching_peer.get_initial_beam(prefix='ffn.', scores=[.1, .2, .3, .4, .5, .6], beam_size=3)) == 2
+    node4 = hivemind.DHT(start=True, listen_on=f"0.0.0.0:*")
+    with pytest.raises(ValueError):
+        node4.get_visible_address()
+    assert node4.get_visible_address(peers=[f'{addr}:{node1.port}']).endswith(addr)
 
-    node = await hivemind.DHTNode.create(initial_peers=neighbors_i)
-    fetched = await asyncio.gather(*(node.get(f'ffn.{i}.') for i in range(10)))
-    for i in range(6):
-        assert fetched[i] is not None, f"node should have cached ffn.{i}."
-    for i in range(6, len(fetched)):
-        assert fetched[i] is None, f"node shouldn't have cached ffn.{i}."
+    node5 = hivemind.DHT(start=True, listen_on=f"0.0.0.0:*", endpoint=f"{dummy_endpoint}")
+    assert node5.get_visible_address() == strip_port(dummy_endpoint)
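
Coroutines passed to DHT.run_coroutine receive (dht, node) as positional arguments, as the dummies above show; extra parameters are bound with functools.partial, the same way the new dht_handler functions do it. A sketch with an illustrative coroutine name:

    from functools import partial
    import hivemind

    async def _store_value(dht, node, key, value):
        # runs inside the DHT process with direct access to its DHTNode
        return await node.store(key, value, expiration_time=hivemind.get_dht_time() + 60)

    dht = hivemind.DHT(start=True)
    assert dht.run_coroutine(partial(_store_value, key='foo', value='bar'))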

+ 159 - 0
tests/test_dht_experts.py

@@ -0,0 +1,159 @@
+import asyncio
+import random
+
+import numpy as np
+import pytest
+
+import hivemind
+import hivemind.server.expert_uid
+from hivemind import LOCALHOST
+from hivemind.client.beam_search import MoEBeamSearcher
+from hivemind.server.expert_uid import UidEndpoint, is_valid_uid, is_valid_prefix, split_uid
+
+
+@pytest.mark.forked
+def test_store_get_experts():
+    peers = [hivemind.DHT(start=True)]
+    for i in range(10):
+        neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
+        peers.append(hivemind.DHT(initial_peers=neighbors_i, start=True))
+
+    first_peer = random.choice(peers)
+    other_peer = random.choice(peers)
+
+    expert_uids = [f"my_expert.{i}" for i in range(110)]
+    batch_size = 10
+    for batch_start in range(0, len(expert_uids), batch_size):
+        hivemind.declare_experts(first_peer, expert_uids[batch_start: batch_start + batch_size], 'localhost:1234')
+
+    found = other_peer.get_experts(random.sample(expert_uids, 5) + ['foo', 'bar'])
+    assert all(res is not None for res in found[:-2]), "Could not find some existing experts"
+    assert all(res is None for res in found[-2:]), "Found non-existing experts"
+
+    other_expert, other_port = "my_other_expert.1337", random.randint(1000, 9999)
+    hivemind.declare_experts(other_peer, [other_expert], f'that_host:{other_port}')
+    first_notfound, first_found = hivemind.get_experts(first_peer, ['foobar', other_expert])
+    assert isinstance(first_found, hivemind.RemoteExpert)
+    assert first_found.endpoint == f'that_host:{other_port}'
+
+    for peer in peers:
+        peer.shutdown()
+
+
+@pytest.mark.forked
+def test_beam_search(dht_size=20, total_experts=128, batch_size=32, initial_peers=3, beam_size=4, parallel_rpc=16,
+                     grid_dims=(32, 32, 32)):
+    dht = []
+    for i in range(dht_size):
+        neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(dht, min(initial_peers, len(dht)))]
+        dht.append(hivemind.DHT(start=True, expiration=999999, initial_peers=neighbors_i, parallel_rpc=parallel_rpc))
+
+    real_experts = sorted({
+        'expert.' + '.'.join([str(random.randint(0, dim - 1)) for dim in grid_dims])
+        for _ in range(total_experts)
+    })
+    for batch_start in range(0, len(real_experts), batch_size):
+        random.choice(dht).declare_experts(
+            real_experts[batch_start: batch_start + batch_size], wait=True,
+            endpoint=f"host{batch_start // batch_size}:{random.randint(0, 65536)}")
+
+    neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(dht, min(initial_peers, len(dht)))]
+    you = hivemind.DHT(start=True, expiration=999999, initial_peers=neighbors_i, parallel_rpc=parallel_rpc)
+    beam_search = MoEBeamSearcher(you, 'expert.', grid_dims)
+
+    for i in range(50):
+        topk_experts = beam_search.find_best_experts([np.random.randn(dim) for dim in grid_dims], beam_size)
+        assert all(isinstance(e, hivemind.RemoteExpert) for e in topk_experts)
+        assert len(topk_experts) == beam_size
+
+    for i in range(10):
+        batch_experts = beam_search.batch_find_best_experts([np.random.randn(batch_size, dim) for dim in grid_dims],
+                                                            beam_size=beam_size)
+        assert isinstance(batch_experts, list) and len(batch_experts) == batch_size
+        assert all(isinstance(e, hivemind.RemoteExpert) for experts in batch_experts for e in experts)
+        assert all(len(experts) == beam_size for experts in batch_experts)
+
+
+@pytest.mark.forked
+def test_dht_single_node():
+    node = hivemind.DHT(start=True, expiration=999)
+    beam_search = MoEBeamSearcher(node, 'expert.')
+
+    assert all(node.declare_experts(['expert.1', 'expert.2', 'expert.3'], f"{hivemind.LOCALHOST}:1337").values())
+    assert len(node.declare_experts(["ffn.1", "ffn.2"], endpoint="that_place")) == 4
+    assert len(node.declare_experts(['e.1.2.3', 'e.1.2.5', 'e.2.0'], f"{hivemind.LOCALHOST}:42")) == 7
+
+    for expert in node.get_experts(['expert.3', 'expert.2']):
+        assert expert.endpoint == f"{hivemind.LOCALHOST}:1337"
+
+    assert all(node.declare_experts(['expert.5', 'expert.2'], f"{hivemind.LOCALHOST}:1337").values())
+    found_experts = beam_search.find_best_experts([(0., 1., 2., 3., 4., 5., 6., 7., 8.)], beam_size=2)
+    assert len(found_experts) == 2 and [expert.uid for expert in found_experts] == ['expert.5', 'expert.3']
+
+    successors = beam_search.get_active_successors(['e.1.2.', 'e.2.', 'e.4.5.'])
+    assert len(successors['e.1.2.']) == 2
+    assert successors['e.1.2.'][3] == UidEndpoint('e.1.2.3', f'{LOCALHOST}:42')
+    assert successors['e.1.2.'][5] == UidEndpoint('e.1.2.5', f'{LOCALHOST}:42')
+    assert len(successors['e.2.']) == 1 and successors['e.2.'][0] == UidEndpoint('e.2.0', f'{LOCALHOST}:42')
+    assert successors['e.4.5.'] == {}
+
+    initial_beam = beam_search.get_initial_beam((3, 2, 1, 0, -1, -2, -3), beam_size=3)
+    assert len(initial_beam) == 3
+    assert initial_beam[0][:2] == (2.0, 'expert.1.')
+    assert initial_beam[1][:2] == (1.0, 'expert.2.')
+    assert initial_beam[2][:2] == (0.0, 'expert.3.')
+
+    with pytest.raises(AssertionError):
+        beam_search = MoEBeamSearcher(node, 'expert.1.ffn')
+
+    with pytest.raises(AssertionError):
+        beam_search.get_active_successors(['e.1.2.', 'e.2', 'e.4.5.'])
+
+
+def test_uid_patterns():
+    valid_experts = ["expert.1", "expert.0", "expert.0.0.1", "expert.1337", "ffn.12.34.56.78.90",
+                     "transformer.3.2.1.0", "transformer_encoder.2", "transformer::encoder.2", "T®@nsf0rmE®🤗.321",
+                     "🤗.321", "0.1.2", "00.1.2", "7070.3.2.1.0", "block2.1.23", "LAYER.1.0.1"]
+    valid_prefixes = ["expert.", "e.1.", "e.2.", "e.1.2.3.", "ololo.123.456.789.10."]
+    valid_prefixes.extend([f"{uid}." for uid in valid_experts])
+    valid_prefixes.extend([split_uid(uid)[0] for uid in valid_experts])
+    for uid in valid_experts:
+        assert is_valid_uid(uid), f"UID {uid} is valid, but was perceived as invalid"
+    for pfx in valid_prefixes:
+        assert is_valid_prefix(pfx), f"Prefix {pfx} is valid, but was perceived as invalid"
+
+    invalid = ["", ".", "expert.-1", "xxx.a", "expert.1x", "expert_ffn.1.abc1", "some.123.01", "expert.123.01",
+               "e1", "e..1", "e", "e.1.2.3..4", "ffn.1..1", ".123", ".1.2.3.", ".expert", "transformer.encoder.2",
+               "T®@nsf0rmE®.🤗.321", "layer::123", "expert.0.1.2.suffix", "0.1.2.suffix", "expert.1 something",
+               "expert.1\n", "expert.1\n2", "expert.1 ", "expert.1\nexpert.2", "'expert.1'", '"expert.1"']
+    invalid_experts = invalid + valid_prefixes + ["0", "123456"]
+    invalid_prefixes = invalid + valid_experts + ["expert", ".🤗", ".expert"]
+    for uid in invalid_experts:
+        assert not is_valid_uid(uid), f"UID {uid} is not valid, but was perceived as valid"
+    for pfx in invalid_prefixes:
+        assert not is_valid_prefix(pfx), f"Prefix {pfx} is not valid, but was perceived as valid"
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_negative_caching():
+    peers = []
+    for i in range(10):
+        neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
+        peers.append(hivemind.DHT(initial_peers=neighbors_i, cache_locally=False, start=True))
+
+    writer_peer = random.choice(peers)
+    assert all(hivemind.declare_experts(writer_peer, ['ffn.1.2.3', 'ffn.3.4.5'], 'myaddr:1234').values())
+
+    neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
+    neg_caching_peer = hivemind.DHT(initial_peers=neighbors_i, cache_locally=False, start=True)
+    beam_search = MoEBeamSearcher(neg_caching_peer, uid_prefix='ffn.', negative_caching=True)
+    # get prefixes by the peer with negative caching. Cache "no data" entries for ffn.0.*, ffn.2.*, ffn.4.*, ffn.5.*
+    assert len(beam_search.get_initial_beam(scores=[.1, .2, .3, .4, .5, .6], beam_size=3)) == 2
+
+    node = await hivemind.DHTNode.create(initial_peers=neighbors_i)
+    fetched = await asyncio.gather(*(node.get(f'ffn.{i}.') for i in range(10)))
+    for i in range(6):
+        assert fetched[i] is not None, f"node should have cached ffn.{i}."
+    for i in range(6, len(fetched)):
+        assert fetched[i] is None, f"node shouldn't have cached ffn.{i}."
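
As these tests show, expert lookup and beam search now live outside hivemind.DHT; the minimal entry point is MoEBeamSearcher (grid size and scores below are placeholders, and with no declared experts the result is simply an empty list):

    import numpy as np
    import hivemind
    from hivemind.client.beam_search import MoEBeamSearcher

    dht = hivemind.DHT(start=True)
    searcher = MoEBeamSearcher(dht, 'expert.', (4, 4))
    best = searcher.find_best_experts([np.random.randn(4), np.random.randn(4)], beam_size=2)
    assert all(isinstance(e, hivemind.RemoteExpert) for e in best)  # up to beam_size entries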

+ 4 - 5
tests/test_dht_node.py

@@ -1,15 +1,14 @@
 import asyncio
+import heapq
 import multiprocessing as mp
 import random
-import heapq
-from typing import Optional
+from itertools import product
+from typing import Optional, List, Dict
+
 import numpy as np
 import pytest
-from itertools import product
 
 import hivemind
-from typing import List, Dict
-
 from hivemind import get_dht_time, replace_port
 from hivemind.dht.node import DHTID, Endpoint, DHTNode, LOCALHOST
 from hivemind.dht.protocol import DHTProtocol, ValidationError

+ 3 - 3
tests/test_moe.py

@@ -109,10 +109,10 @@ def test_beam_search_correctness():
 
     for i in range(25):
         input = torch.randn(32)
-        grid_scores = dmoe.proj(input).split_with_sizes(dmoe.grid_size, dim=-1)
+        grid_scores = dmoe.proj(input).split_with_sizes(dmoe.beam_search.grid_size, dim=-1)
 
-        chosen_experts = dht.find_best_experts(dmoe.uid_prefix, [tensor.detach().numpy() for tensor in grid_scores],
-                                               beam_size=dmoe.k_best)
+        chosen_experts = dmoe.beam_search.find_best_experts([tensor.detach().numpy() for tensor in grid_scores],
+                                                            beam_size=dmoe.k_best)
         chosen_scores = dmoe.compute_expert_scores([dim_scores[None] for dim_scores in grid_scores],
                                                    [chosen_experts])[0]
         our_best_scores = list(chosen_scores.cpu().detach().numpy())