
Extract expert-specific methods from DHT (#192)

Implement DHT.run_coroutine
 * allow run_coroutine to be cancelled from host process
 * add basic tests for DHT.run_coroutine
 * add basic tests for DHT.store/get

Extract expert-specific functions away from DHT
* move beam search to beam_search.py
* move declare/get experts to dht_handler.py
* mark old api as deprecated
* update hivemind.DHT docstring
* update DHT diagram (DHT is no longer expert-specific)

Misc
* remove receiver_threads (always 1, never a bottleneck)
* remove timeout in declare_experts (unused for 6+ months)

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic 4 years ago
parent
commit
f132294edb

Binary
docs/_static/dht.odp


Binary
docs/_static/dht.png


+ 3 - 4
hivemind/client/averaging/__init__.py

@@ -64,7 +64,6 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
     :param listen: if True (default), this averager will accept incoming requests from other peers and perform allreduce
             if False, the averager will register as a freeloader and attempt to fetch vectors from other averagers
     :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
-    :param receiver_threads: uses this many threads to await on input pipe. Default = 1 should be enough in most cases
     :param channel_options: options for grpc.aio.insecure_channel, e.g. [('grpc.enable_retries', 0)]
           see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
     :param kwargs: extra parameters forwarded to grpc.aio.server
@@ -91,7 +90,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                  allreduce_timeout: Optional[float] = None, averaging_alpha: float = 1.0,
                  compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE,
                  throughput: Optional[float] = None, min_vector_size: int = 0,
-                 listen: bool = True, listen_on: Endpoint = '0.0.0.0:*', receiver_threads: int = 1, daemon: bool = True,
+                 listen: bool = True, listen_on: Endpoint = '0.0.0.0:*', daemon: bool = True,
                  channel_options: Optional[Sequence[Tuple[str, Any]]] = None, **kwargs):
         assert '.' not in prefix, "group prefix must be a string without trailing '.'"
         assert throughput is None or (throughput >= 0 and np.isfinite(np.float32(throughput))), \
@@ -102,7 +101,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
         super().__init__()
         self.dht = dht
-        self.listen, self.listen_on, self.receiver_threads, self.kwargs = listen, listen_on, receiver_threads, kwargs
+        self.listen, self.listen_on, self.kwargs = listen, listen_on, kwargs
         self.channel_options = channel_options
         self.daemon = daemon
 
@@ -155,7 +154,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         """ Serve DecentralizedAverager forever. This function will not return until the averager is shut down """
         loop = switch_to_uvloop()
         # initialize asyncio synchronization primitives in this event loop
-        pipe_awaiter = ThreadPoolExecutor(self.receiver_threads)
+        pipe_awaiter = ThreadPoolExecutor(max_workers=1)
 
         async def _run():
             grpc.aio.init_grpc_aio()

+ 273 - 0
hivemind/client/beam_search.py

@@ -0,0 +1,273 @@
+import asyncio
+import heapq
+from collections import deque
+from functools import partial
+from typing import Sequence, Optional, List, Tuple, Dict, Deque, Union, Set, Iterator
+
+from hivemind.dht import DHT, DHTNode
+from hivemind.client.expert import RemoteExpert
+from hivemind.server.expert_uid import (ExpertUID, ExpertPrefix, FLAT_EXPERT, UidEndpoint, Score, Coordinate,
+                                        PREFIX_PATTERN, UID_DELIMITER, is_valid_prefix)
+from hivemind.utils import get_logger, get_dht_time, MPFuture
+
+logger = get_logger(__name__)
+
+
+class MoEBeamSearcher:
+    """
+    Utility class that uses DHT to find most suitable experts for RemoteMixtureOfExperts.
+    Each expert has an identifier in the form of {prefix}.{i}.{j}.{...}, e.g. "ffn_expert.98.76.54.32.10"
+    An expert identifier consists of:
+
+        * optional prefix that determines expert role, experiment name, etc.
+        * one or more integers that determine that expert's position in an N-dimensional grid
+
+    A hivemind.Server can call ``declare_experts(dht, expert_uids: List[str], endpoint)`` to make its experts visible to everyone.
+    When declaring experts, DHT will store each expert's uid and all its prefixes until :expiration: (specified at init)
+    For instance, declaring "ffn_expert.98.76.54.32.10" will store the following keys in a DHT:
+    ``"ffn_expert.98", "ffn_expert.98.76", "ffn_expert.98.76.54", ..., "ffn_expert.98.76.54.32.10"``
+
+    In order to enable fast beam search, DHT maintains dictionaries of all active suffixes for every prefix
+    (e.g. "ffn_expert.98": {76: ffn_expert.98.76...., 123: ffn_expert.98.123..., 225: ffn_expert.98.225....})
+
+    RemoteMixtureOfExperts can use these prefixes to find top-k most suitable experts with a left-to-right beam search.
+    For instance, consider RemoteMixtureOfExperts with prefix "ffn_expert" and grid size [100, 100, 100, 100, 100].
+    This MoE can query all experts with that prefix and arbitrary indices in 0...99 along each dimension.
+    However, not every expert in such 100^5 grid can be alive at a given moment of time (the grid size is redundant).
+    In order to find k best "alive" experts, MoE first ranks indices along the first dimension with its gating function.
+    It can then check which of those indices correspond to "alive" experts by querying keys such as "ffn_expert.98".
+
+    After selecting k best indices along first dimension, MoE moves to the second dimension.
+    It can find top-k index pairs (e.g. "expert.98.76") that use one of k best indices from the previous step.
+    This beam search explores one additional dimension per step and finds k best experts from across the DHT
+    in O(k * num_dimensions * dimension_size) time depending on the chosen grid dimensions.
+
+    :param dht: a running DHT daemon that is used for beam search AND local caching
+    :param uid_prefix: search for experts whose uids start with this prefix
+    :param grid_size: dimensions that form expert uid (see above)
+    :param num_workers: number of concurrent DHT coroutines per beam search
+    :param negative_caching: if True, whenever DHT is unable to find an expert or prefix, it will cache the "no key"
+      result inside the DHT for :expiration: seconds. Caching only affects beam search and has three main effects:
+
+      1. Faster beam search under node failures: if there are inconsistencies in DHT keys, such as a prefix pointing to
+         a now-defunct expert, these inconsistencies will be overwritten by the first peer that stumbles upon them. As a
+         result, beam search will not have to wait for non-existent experts until the expiration of their DHT entries;
+      2. Delayed expert availability: Without negative cache, new experts are always immediately available for beam
+         search after they are published to the DHT. With negative cache, there are rare cases (e.g. when adding new
+         experts in place of recently defunct ones) when new experts will be initially invisible, but gradually become
+         visible to more peers as those peers refresh their cache. This process takes at most :expiration: seconds;
+      3. Faster beam search in very sparse grids: there is one edge case where negative cache will improve beam search
+         performance. If an expert grid is very sparse, there can be empty indices in the first grid dimension (i.e.
+         indices {i} such that _no_ experts start with "{prefix}.{i}.*"). If so, the default beam search will
+         be very slow due to the way it forms the initial beam. Beam search with negative cache enabled will run normally.
+         However, this is a pathological case (e.g. only 90 experts in an oversized 100x100 grid) that should be avoided.
+    """
+
+    def __init__(self, dht: DHT, uid_prefix: ExpertPrefix, grid_size: Optional[Tuple[int, ...]] = None,
+                 num_workers: Optional[int] = None, negative_caching: bool = True, **kwargs):
+        if not uid_prefix.endswith(UID_DELIMITER):
+            uid_prefix += UID_DELIMITER
+            logger.info(f"Prefix must end with '{UID_DELIMITER}'. Changing prefix to {uid_prefix}")
+        assert is_valid_prefix(uid_prefix), f"Prefix '{uid_prefix}' is invalid."
+        self.dht = dht
+        self.uid_prefix, self.grid_size = uid_prefix, grid_size
+        self.negative_caching, self.num_workers, self.dht_kwargs = negative_caching, num_workers, kwargs
+
+    def get_initial_beam(self, scores: Sequence[float], beam_size: int, return_future: bool = False
+                         ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
+        """
+        :param scores: prefer suffix coordinates that have highest scores
+        :param beam_size: select this many active suffixes with highest scores
+        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
+        :returns: a list of up to beam_size tuples of (prefix score, prefix itself, dict{suffix: example expert})
+        """
+        return self.dht.run_coroutine(partial(self._get_initial_beam, prefix=self.uid_prefix, beam_size=beam_size,
+                                              scores=tuple(scores), negative_caching=self.negative_caching,
+                                              num_workers=self.num_workers), return_future)
+
+    @staticmethod
+    async def _get_initial_beam(dht: DHT, node: DHTNode, prefix: ExpertPrefix, beam_size: int,
+                                scores: Tuple[float, ...], negative_caching: bool, num_workers: Optional[int] = None
+                                ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
+        num_workers = num_workers or dht.max_workers or beam_size
+        beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = []
+        unattempted_indices: List[Coordinate] = sorted(range(len(scores)), key=scores.__getitem__)  # from worst to best
+        pending_tasks: Deque[Tuple[Coordinate, ExpertPrefix, asyncio.Task]] = deque()
+
+        while len(beam) < beam_size and (unattempted_indices or pending_tasks):
+            # dispatch additional tasks
+            while unattempted_indices and len(pending_tasks) < num_workers:
+                next_index = unattempted_indices.pop()  # note: this is best unattempted index because of sort order
+                next_best_prefix = f"{prefix}{next_index}{UID_DELIMITER}"
+                pending_tasks.append((next_index, next_best_prefix, asyncio.create_task(node.get(next_best_prefix))))
+
+            # await the next best prefix to be fetched
+            pending_best_index, pending_best_prefix, pending_task = pending_tasks.popleft()
+            try:
+                maybe_prefix_data = await pending_task
+                if maybe_prefix_data is not None and isinstance(maybe_prefix_data.value, dict):
+                    successors = {coord: UidEndpoint(*match.value) for coord, match in maybe_prefix_data.value.items()
+                                  if isinstance(coord, Coordinate) and isinstance(getattr(match, 'value', None), list)
+                                  and len(match.value) == 2}
+                    if successors:
+                        beam.append((scores[pending_best_index], pending_best_prefix, successors))
+                elif maybe_prefix_data is None and negative_caching:
+                    logger.debug(f"DHT negative caching: storing a 'no prefix' entry for {pending_best_prefix}")
+                    asyncio.create_task(node.store(pending_best_prefix, subkey=-1, value=None,
+                                                   expiration_time=get_dht_time() + dht.default_expiration))
+
+            except asyncio.CancelledError:
+                for _, _, pending_task in pending_tasks:
+                    pending_task.cancel()
+                raise
+        return beam
+
+    def get_active_successors(self, prefixes: List[ExpertPrefix], grid_size: Optional[int] = None,
+                              return_future: bool = False) -> Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]]:
+        """
+        :param prefixes: a list of prefixes for which to find active successor uids
+        :param grid_size: if specified, only return successors that are in range [0, grid_size)
+        :param return_future: if False (default), find and return successors. Otherwise return MPFuture and fill later.
+        :returns: for every expert, return a dict{active_next_coordinate: (matching_expert_uid, matching_endpoint)}
+        :note: if a prefix is not found, get_active_successors will return an empty dictionary for that prefix
+        """
+        assert not isinstance(prefixes, str), "Please send a list / tuple of expert prefixes."
+        for prefix in prefixes:
+            assert is_valid_prefix(prefix), f"prefix '{prefix}' is invalid, it must follow {PREFIX_PATTERN.pattern}"
+        return self.dht.run_coroutine(partial(
+            self._get_active_successors, prefixes=list(prefixes), grid_size=grid_size,
+            negative_caching=self.negative_caching, num_workers=self.num_workers), return_future=return_future)
+
+    @staticmethod
+    async def _get_active_successors(dht: DHT, node: DHTNode, prefixes: List[ExpertPrefix], grid_size: Optional[int],
+                                     negative_caching: bool, num_workers: Optional[int] = None
+                                     ) -> Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]]:
+        grid_size = grid_size or float('inf')
+        num_workers = num_workers or min(len(prefixes), dht.max_workers or len(prefixes))
+        dht_responses = await node.get_many(keys=prefixes, num_workers=num_workers)
+        successors: Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]] = {}
+        for prefix, found in dht_responses.items():
+            if found and isinstance(found.value, dict):
+                successors[prefix] = {coord: UidEndpoint(*match.value) for coord, match in found.value.items()
+                                      if isinstance(coord, Coordinate) and 0 <= coord < grid_size
+                                      and isinstance(getattr(match, 'value', None), list) and len(match.value) == 2}
+            else:
+                successors[prefix] = {}
+                if found is None and negative_caching:
+                    logger.debug(f"DHT negative caching: storing a 'no prefix' entry for {prefix}")
+                    asyncio.create_task(node.store(prefix, subkey=-1, value=None,
+                                                   expiration_time=get_dht_time() + dht.default_expiration))
+        return successors
+
+    def find_best_experts(self, grid_scores: Sequence[Sequence[float]], beam_size: int, return_future: bool = False
+                          ) -> Union[List[RemoteExpert], MPFuture[RemoteExpert]]:
+        """
+        Find and return :beam_size: active experts with highest scores, use both local cache and DHT
+
+        :param grid_scores: scores predicted for each dimension in the grid
+        :type grid_scores: model scores for each grid dimension, list of arrays of shape grid_size[i]
+        :param beam_size: how many best experts should beam search return
+         Please note that all DHT queries performed during the search are cached locally for subsequent iterations
+         as long as DHTNode.cache_locally is True
+        :param return_future: if set to True, returns MPFuture that can be awaited to get the actual result
+        :returns: a list that contains *up to* beam_size RemoteExpert instances
+        """
+        assert (not self.grid_size or len(grid_scores) == len(self.grid_size)) and beam_size > 0
+        return self.dht.run_coroutine(partial(self._find_best_experts, prefix=self.uid_prefix, beam_size=beam_size,
+                                              grid_scores=list(grid_scores), negative_caching=self.negative_caching,
+                                              num_workers=self.num_workers), return_future)
+
+    @classmethod
+    async def _find_best_experts(
+            cls, dht: DHT, node: DHTNode, prefix: str, grid_scores: List[Tuple[float]], beam_size: int,
+            negative_caching: bool, num_workers: Optional[int] = None) -> List[RemoteExpert]:
+        num_workers = num_workers or min(beam_size, dht.max_workers or beam_size)
+
+        # form initial beam from top-k active L1 prefixes, each row is (score, uid prefix, possible suffixes)
+        beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = await cls._get_initial_beam(
+            dht, node, prefix, beam_size, grid_scores[0], negative_caching, min(beam_size, num_workers))
+
+        best_experts_heap: List[Tuple[Score, UidEndpoint]] = []  # max-heap of expert uids/endpoints ordered by scores
+        unique_experts: Set[ExpertUID] = set()
+
+        for dim_index in range(1, len(grid_scores) - 1):
+            for score, uid_endpoint in cls._iterate_matching_experts(beam, grid_scores):
+                if uid_endpoint.uid not in unique_experts:
+                    push_and_maybe_pop = heapq.heappush if len(best_experts_heap) < beam_size else heapq.heappushpop
+                    push_and_maybe_pop(best_experts_heap, (score, uid_endpoint))
+                    unique_experts.add(uid_endpoint.uid)
+
+            # form new beam using successors from the current beam
+            dim_scores = grid_scores[dim_index]
+            best_active_pairs: List[Tuple[Score, ExpertPrefix]] = heapq.nlargest(beam_size, (
+                (prefix_score + dim_scores[next_coord], f"{prefix}{next_coord}{UID_DELIMITER}")
+                for prefix_score, prefix, suffixes in beam for next_coord in suffixes.keys()
+                if isinstance(next_coord, int) and 0 <= next_coord < len(dim_scores)))
+            _, best_uid_prefixes = zip(*best_active_pairs)
+
+            # search DHT for next step suffixes
+            successors = await cls._get_active_successors(dht, node, best_uid_prefixes, grid_size=None,
+                                                          negative_caching=negative_caching, num_workers=num_workers)
+            beam = [(score, prefix, successors[prefix]) for score, prefix in best_active_pairs if successors[prefix]]
+            if not beam:
+                logger.warning(f"Beam search had to terminate prematurely because of empty beam at dim {dim_index}")
+                break
+
+        # add best experts from the final beam
+        for score, uid_endpoint in cls._iterate_matching_experts(beam, grid_scores):
+            if uid_endpoint.uid not in unique_experts:
+                push_and_maybe_pop = heapq.heappush if len(best_experts_heap) < beam_size else heapq.heappushpop
+                push_and_maybe_pop(best_experts_heap, (score, uid_endpoint))
+                unique_experts.add(uid_endpoint.uid)
+
+        best_experts = [RemoteExpert(*uid_endpoint) for score, uid_endpoint in sorted(best_experts_heap, reverse=True)]
+        return best_experts
+
+    @staticmethod
+    def _iterate_matching_experts(beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]],
+                                  grid_scores: Sequence[Sequence[float]]) -> Iterator[Tuple[Score, UidEndpoint]]:
+        """ iterate over all exemplar experts attached to current beam """
+        for score, prefix, suffixes in beam:
+            for next_coord, match in suffixes.items():
+                if len(grid_scores) == 1 and next_coord == FLAT_EXPERT:
+                    yield score, match
+                elif isinstance(match.uid, ExpertUID) and match.uid.count(UID_DELIMITER) == len(grid_scores):
+                    expert_coords = match.uid.split(UID_DELIMITER)[1:]
+                    if all(coord.isdigit() and 0 <= int(coord) < len(grid_scores[i])
+                           for i, coord in enumerate(expert_coords)):
+                        expert_score = sum(scores[coord] for scores, coord in zip(grid_scores, map(int, expert_coords)))
+                        yield expert_score, match
+                    else:
+                        logger.warning(f"Found incompatible expert coordinates: {expert_coords}")
+                else:
+                    logger.warning(f"Found incompatible expert UID: {match.uid}")
+
+    def batch_find_best_experts(self, batch_grid_scores: Sequence[Sequence[Sequence[float]]], beam_size: int,
+                                return_future: bool = False) -> Union[List[List[RemoteExpert]], MPFuture]:
+        """
+        Find and return :beam_size: active experts with highest scores, use both local cache and DHT
+
+        :param batch_grid_scores: scores predicted for each batch example and each dimension in the grid,
+        :type batch_grid_scores: list of arrays of shape (batch_size, grid_size[i])
+        :param beam_size: how many best experts should beam search return for each batch example
+         Please note that all DHT queries performed during the search are cached locally for subsequent iterations
+         as long as DHTNode.cache_locally is True
+        :param return_future: if set to True, returns MPFuture that can be awaited to get the actual result
+        :returns: a list of lists with *up to* beam_size RemoteExpert instances for each batch example
+        """
+        return self.dht.run_coroutine(partial(
+            self._batch_find_best_experts, prefix=self.uid_prefix, batch_grid_scores=batch_grid_scores,
+            beam_size=beam_size, negative_caching=self.negative_caching, num_workers=self.num_workers), return_future)
+
+    @classmethod
+    async def _batch_find_best_experts(
+            cls, dht: DHT, node: DHTNode, prefix: str, batch_grid_scores: Sequence[Sequence[Tuple[float]]],
+            beam_size: int, negative_caching: bool, num_workers: Optional[int]) -> Sequence[Sequence[RemoteExpert]]:
+        batch_grid_scores = [[tuple(grid_score[i]) for grid_score in batch_grid_scores]
+                             for i in range(len(batch_grid_scores[0]))]
+        coros = [cls._find_best_experts(dht, node, prefix, grid_scores, beam_size, negative_caching, num_workers)
+                 for grid_scores in batch_grid_scores]
+        return await asyncio.gather(*coros)
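
A minimal usage sketch of the new MoEBeamSearcher (the prefix, grid size and scores below are illustrative; it assumes a running DHT where experts have already been declared under that prefix):

    import numpy as np
    import hivemind
    from hivemind.client.beam_search import MoEBeamSearcher

    dht = hivemind.DHT(listen_on='127.0.0.1:*', start=True)
    searcher = MoEBeamSearcher(dht, uid_prefix='ffn_expert.', grid_size=(32, 32))

    # one score vector per grid dimension, e.g. produced by a gating network
    grid_scores = [np.random.randn(32), np.random.randn(32)]
    best_experts = searcher.find_best_experts(grid_scores, beam_size=4)  # up to 4 RemoteExpert instances

    # non-blocking variant: returns an MPFuture that can be awaited or cancelled from the host process
    future = searcher.find_best_experts(grid_scores, beam_size=4, return_future=True)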

+ 11 - 12
hivemind/client/moe.py

@@ -12,6 +12,8 @@ from torch.autograd.function import once_differentiable
 
 import hivemind
 from hivemind.client.expert import RemoteExpert, DUMMY, _get_expert_stub
+from hivemind.server.expert_uid import UID_DELIMITER
+from hivemind.client.beam_search import MoEBeamSearcher
 from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
 from hivemind.utils import nested_pack, nested_flatten, serialize_torch_tensor, deserialize_torch_tensor
 from hivemind.utils.logging import get_logger
@@ -45,11 +47,8 @@ class RemoteMixtureOfExperts(nn.Module):
                  backward_k_min: int = 1, backward_timeout: Optional[float] = None, detect_anomalies: bool = False,
                  **dht_kwargs):
         super().__init__()
-        if not uid_prefix.endswith(hivemind.dht.UID_DELIMITER):
-            uid_prefix += hivemind.dht.UID_DELIMITER
-            logger.info(f"Prefix must end with '{hivemind.dht.UID_DELIMITER}'. New prefix: '{uid_prefix}' .")
-        assert hivemind.dht.is_valid_prefix(uid_prefix), f"Prefix '{uid_prefix}' is invalid."
-        self.dht, self.grid_size, self.uid_prefix, self.dht_kwargs = dht, grid_size, uid_prefix, dht_kwargs
+        self.dht = dht
+        self.beam_search = MoEBeamSearcher(dht, uid_prefix, grid_size, **dht_kwargs)
         self.k_best, self.k_min, self.backward_k_min = k_best, k_min, backward_k_min
         self.forward_timeout, self.backward_timeout = forward_timeout, backward_timeout
         self.timeout_after_k_min = timeout_after_k_min
@@ -75,10 +74,10 @@ class RemoteMixtureOfExperts(nn.Module):
             input_for_gating = input
 
         # 1. compute scores and find most appropriate experts with beam search
-        grid_scores = self.proj(input_for_gating).split_with_sizes(self.grid_size, dim=-1)
+        grid_scores = self.proj(input_for_gating).split_with_sizes(self.beam_search.grid_size, dim=-1)
 
-        chosen_experts: List[List[RemoteExpert]] = self.dht.batch_find_best_experts(
-            self.uid_prefix, [scores.detach().cpu().numpy() for scores in grid_scores], self.k_best, **self.dht_kwargs)
+        chosen_experts: List[List[RemoteExpert]] = self.beam_search.batch_find_best_experts(
+            [scores.detach().cpu().numpy() for scores in grid_scores], self.k_best)
 
         if self._expert_info is None:
             try:
@@ -121,8 +120,8 @@ class RemoteMixtureOfExperts(nn.Module):
 
         grid_indices = torch.zeros([len(flat_experts), len(grid_scores)], dtype=torch.int64)
         for i, expert in enumerate(flat_experts):
-            expert_indices = expert.uid[len(self.uid_prefix):]
-            expert_indices = list(map(int, expert_indices.split(hivemind.dht.UID_DELIMITER)))
+            expert_indices = expert.uid[len(self.beam_search.uid_prefix):]
+            expert_indices = list(map(int, expert_indices.split(UID_DELIMITER)))
             grid_indices[i] = torch.as_tensor(expert_indices, dtype=grid_indices.dtype)
 
         scores_per_dim = [
@@ -140,8 +139,8 @@ class RemoteMixtureOfExperts(nn.Module):
             # grab some expert to set ensemble output shape
             proj_device = self.proj.weight.device
             dummy_scores_concat = self.proj(torch.randn(1, self.proj.in_features, device=proj_device))
-            dummy_scores = dummy_scores_concat.cpu().split_with_sizes(self.grid_size, dim=-1)
-            dummy_experts = self.loop.run_until_complete(self.beam_search(dummy_scores, k_best=1))
+            dummy_scores = dummy_scores_concat.cpu().split_with_sizes(self.beam_search.grid_size, dim=-1)
+            dummy_experts = self.beam_search.find_best_experts(dummy_scores, beam_size=1)
             self._expert_info = dummy_experts[0].info
         return self._expert_info
 

+ 57 - 367
hivemind/dht/__init__.py

@@ -15,53 +15,27 @@ The code is organized as follows:
 from __future__ import annotations
 import asyncio
 import ctypes
-import heapq
 import multiprocessing as mp
-import re
-from collections import deque
 from concurrent.futures import ThreadPoolExecutor
-from typing import List, Tuple, Optional, Sequence, Union, Dict, Deque, NamedTuple, Iterator, Set
-
+from typing import List, Optional, Sequence, Union, Callable, Awaitable, TypeVar
 
+import hivemind
 from hivemind.client import RemoteExpert
 from hivemind.dht.node import DHTNode, DHTID, DHTExpiration
-from hivemind.dht.routing import get_dht_time, DHTValue, DHTKey, Subkey
-from hivemind.utils import MPFuture, Endpoint, Hostname, get_logger, switch_to_uvloop, strip_port, ValueWithExpiration
+from hivemind.dht.routing import DHTValue, DHTKey, Subkey
+from hivemind.utils.networking import Hostname, Endpoint, strip_port
+from hivemind.utils import MPFuture, get_logger, switch_to_uvloop, ValueWithExpiration, await_cancelled, get_dht_time
 
 logger = get_logger(__name__)
 
-ExpertUID, ExpertPrefix, Coordinate, Score = str, str, int, float
-UidEndpoint = NamedTuple("UidEndpoint", [('uid', ExpertUID), ('endpoint', Endpoint)])
-UID_DELIMITER = '.'  # when declaring experts, DHT store all prefixes of that expert's uid, split over this prefix
-FLAT_EXPERT = -1  # grid prefix reserved for storing 1d expert uids. Used to speed up find_best_experts in 1d case.
-UID_PATTERN = re.compile('^(([^.])+)([.](?:[0]|([1-9]([0-9]*))))+$')  # e.g. ffn_expert.98.76.54 - prefix + some dims
-PREFIX_PATTERN = re.compile('^(([^.])+)([.](?:[0]|([1-9]([0-9]*))))*[.]$')  # e.g. expert. or ffn.45. (ends with ".")
-#  formally, prefixes = {uid.split(UID_DELIMITER)[:length] for length in range(1, uid.count(UID_DELIMITER) + 2)}
-
-
-def is_valid_uid(maybe_uid: str) -> bool:
-    """ An uid must contain a string expert type, followed by one or more .-separated numeric indices """
-    return bool(UID_PATTERN.fullmatch(maybe_uid))
-
-
-def is_valid_prefix(maybe_prefix: str) -> bool:
-    """ An uid prefix must contain a string expert type, followed by optional numeric indices and a trailing period """
-    return bool(PREFIX_PATTERN.fullmatch(maybe_prefix))
-
-
-def split_uid(uid_or_prefix: Union[ExpertUID, ExpertPrefix]) -> Tuple[ExpertPrefix, Coordinate]:
-    """ Separate an expert UID or prefix into a new ExpertPrefix and integer for the last coordinate """
-    uid_or_prefix = uid_or_prefix.rstrip(UID_DELIMITER)
-    pivot = uid_or_prefix.rindex(UID_DELIMITER) + 1
-    return uid_or_prefix[:pivot], int(uid_or_prefix[pivot:])
+ReturnType = TypeVar('ReturnType')
 
 
 class DHT(mp.Process):
     """
-    High-level interface to hivemind.dht that is designed to allow RemoteMixtureOfExperts to select best experts.
-
-    * hivemind servers periodically announce their experts via DHT.declare_experts
-    * trainers find most suitable experts via DHT.find_best_experts
+    A high-level interface to a hivemind DHT that runs a single DHT node in a background process.
+    * hivemind servers periodically announce their experts via declare_experts (dht_handler.py)
+    * trainers find most suitable experts via RemoteMixtureOfExperts (beam_search.py)
 
     :param initial_peers: one or multiple endpoints pointing to active DHT peers. Similar format to listen_on.
     :param listen_on: an interface for incoming connections, e.g. "127.0.0.1:*", "0.0.0.0:1234" or "ipv6:[::]:*"
@@ -70,60 +44,17 @@ class DHT(mp.Process):
     :param max_workers: declare_experts and get_experts will use up to this many parallel workers
         (but no more than one per key)
     :param expiration: experts declared from this node expire after this many seconds (default = 5 minutes)
-    :param receiver_threads: uses this many threads to await on input pipe. Default = 1 should be enough in most cases
-    :param negative_caching: if True, whenever DHT is unable to find an expert or prefix, it will cache the "no key"
-      result inside the DHT for :expiration: seconds. Caching only affects beam search and has three main effects:
-
-      1. Faster beam search under node failures: if there are inconsistencies in DHT keys, such as a prefix pointing to
-         a now-defunct expert, these inconsistencies will be overwritten by the first peer that stumbles upon them. As a
-         result, beam search will not have to wait for non-existent experts until the expiration of their DHT entries;
-      2. Delayed expert availability: Without negative cache, new experts are always immediately available for beam
-         search after they are published to the DHT. With negative cache, there are rare cases (e.g. when adding new
-         experts in place of recently defunct ones) when new experts will be initially invisible, but gradually become
-         visible to more peers as those peers refresh their cache. This process takes at most :expiration: seconds;
-      3. Faster beam search in very sparse grids: there is one edge case where negative cache will improve beam search
-         performance; If an expert grid is very sparse, there can be empty indices in the first grid dimension (i.e.
-         indices {i} such that _no_ experts that start with "{prefix}.{i}.*"). If so, the default beam search will
-         be very slow due to the way it forms initial beam. Beam search with negative cache enabled will run normally.
-         Though, this is a pathological case (e.g. only 90 experts in an oversized 100x100 grid) that should be avoided.
-
     :param kwargs: any other params will be forwarded to DHTNode upon creation
-
-    Each expert has an identifier in the form of {prefix}.{i}.{j}.{...}, e.g. "ffn_expert.98.76.54.32.10"
-    An expert identifier consists of:
-
-        * optional prefix that determines expert role, experiment name, etc.
-        * one or more integers that determine that expert's position in an N-dimensional grid
-
-    A hivemind.Server can ``DHT.declare_experts(expert_uids: List[str])`` to make its experts visible to everyone.
-    When declaring experts, DHT will store each expert's uid and all its prefixes until :expiration: (specified at init)
-    For instance, declaring "ffn_expert.98.76.54.32.10" will store the following keys in a DHT:
-    ``"ffn_expert.98", "ffn_expert.98.76", "ffn_expert.98.76.54", ..., "ffn_expert.98.76.54.32.10"``
-
-    In order to enable fast beam search, DHT maintains dictionaries of all active suffixes for every prefix
-    (e.g. "ffn_expert.98": {76: ffn_expert.98.76...., 123: ffn_expert.98.123..., 225: ffn_expert.98.225....}))
-
-    RemoteMixtureOfExperts can use these prefixes to find top-k most suitable experts with a left-to-right beam search.
-    For instance, consider RemoteMixtureOfExperts with prefix "ffn_expert" and grid size [100, 100, 100, 100, 100].
-    This MoE can query all experts with that prefix and arbitrary indices in 0...99 along each dimension.
-    However, not every expert in such 100^5 grid can be alive at a given moment of time (the grid size is redundant).
-    In order to find k best "alive" experts, MoE first ranks indices along the first dimension with its gating function.
-    It can then check which of those indices correspond to "alive" experts by querying keys such as "ffn_expert.98".
-
-    After selecting k best indices along first dimension, MoE moves to the second dimension.
-    It can find top-k index pairs (e.g. "expert.98.76") that use one of k best indices from the previous step.
-    This beam search explores one additional dimension per step and finds k best experts from across the DHT
-    in O(k * num_dimensions * dimension_size) time depending on the chosen grid dimensions.
     """
 
     def __init__(self, listen_on: Endpoint = "0.0.0.0:*", initial_peers: Sequence[Endpoint] = (), *, start: bool,
                  daemon: bool = True, max_workers: Optional[int] = None, parallel_rpc: Optional[int] = None,
-                 receiver_threads: int = 1, negative_caching: bool = True, expiration: float = 300, **kwargs):
+                 expiration: float = 300, **kwargs):
         super().__init__()
         assert not isinstance(initial_peers, str), "please specify a list/tuple of initial peers (even if there's one)"
         self.listen_on, self.initial_peers, self.kwargs = listen_on, initial_peers, kwargs
-        self.receiver_threads, self.max_workers, self.parallel_rpc = receiver_threads, max_workers, parallel_rpc
-        self.expiration, self.negative_caching = expiration, negative_caching
+        self.max_workers, self.parallel_rpc = max_workers, parallel_rpc
+        self.default_expiration = expiration
         self._port = mp.Value(ctypes.c_int32, 0)  # initialized after dht starts
         self._pipe, self.pipe = mp.Pipe(duplex=True)
         self.ready = mp.Event()
@@ -134,7 +65,7 @@ class DHT(mp.Process):
     def run(self) -> None:
         """ Serve DHT forever. This function will not return until DHT node is shut down """
         loop = switch_to_uvloop()
-        pipe_awaiter = ThreadPoolExecutor(self.receiver_threads)
+        pipe_awaiter = ThreadPoolExecutor(max_workers=1)
 
         async def _run():
             node = await DHTNode.create(
@@ -201,7 +132,11 @@ class DHT(mp.Process):
               subkey: Optional[Subkey] = None, return_future: bool = False, **kwargs) -> Union[bool, MPFuture]:
         """
         Find num_replicas best nodes to store (key, value) and store it there until expiration time.
-        :note: store is a simplified interface to store_many, all kwargs are be forwarded there
+
+        :param key: msgpack-serializable key to be associated with value until expiration.
+        :param value: msgpack-serializable value to be stored under a given key until expiration.
+        :param expiration_time: absolute time when the entry should expire, based on hivemind.get_dht_time()
+        :param subkey: if specified, add a value under that subkey instead of overwriting key (see DHTNode.store_many)
         :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
         :returns: True if store succeeds, False if it fails (due to no response or newer value)
         """
@@ -221,6 +156,39 @@ class DHT(mp.Process):
                 future.set_exception(e)
             raise
 
+    def run_coroutine(self, coro: Callable[[DHT, DHTNode], Awaitable[ReturnType]],
+                      return_future: bool = False) -> Union[ReturnType, MPFuture[ReturnType]]:
+        """
+        Execute an asynchronous function on a DHT participant and return results. This is meant as an interface
+         for running custom functions on the DHT for special cases (e.g. declare experts, beam search)
+
+        :param coro: async function to be executed. Receives 2 arguments: this DHT daemon and a running DHTNode
+        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
+        :returns: coroutine outputs or MPFuture for these outputs
+        :note: the coroutine will be executed inside the DHT process. As such, any changes to global variables or
+          DHT fields made by this coroutine will not be accessible from the host process.
+        :note: all time-consuming operations in coro should be asynchronous (e.g. asyncio.sleep instead of time.sleep)
+          or use asyncio.get_event_loop().run_in_executor(...) to prevent coroutine from blocking background DHT tasks
+        :note: when run_coroutine is called with return_future=True, the returned MPFuture can be cancelled to interrupt the task.
+        """
+        future, _future = MPFuture.make_pair()
+        self.pipe.send(('_run_coroutine', [], dict(coro=coro, future=_future)))
+        return future if return_future else future.result()
+
+    async def _run_coroutine(self, node: DHTNode, coro: Callable[[DHT, DHTNode], Awaitable[ReturnType]],
+                             future: MPFuture[ReturnType]):
+        main_task = asyncio.create_task(coro(self, node))
+        cancel_task = asyncio.create_task(await_cancelled(future))
+        try:
+            await asyncio.wait({main_task, cancel_task}, return_when=asyncio.FIRST_COMPLETED)
+            if future.cancelled():
+                main_task.cancel()
+            else:
+                future.set_result(await main_task)
+        except BaseException as e:
+            if not future.done():
+                future.set_exception(e)
+
     def get_visible_address(self, num_peers: Optional[int] = None, peers: Sequence[Endpoint] = ()) -> Hostname:
         """
         Get this machine's visible address by requesting other peers or using pre-specified network addresses.
@@ -274,289 +242,11 @@ class DHT(mp.Process):
             future.set_exception(ValueError(f"Can't get address: DHT node has no peers and no public endpoint."
                                             f" Please ensure the node is connected or specify peers=... manually."))
 
-    def declare_experts(self, uids: Sequence[ExpertUID], endpoint: Endpoint, wait: bool = True,
-                        timeout: Optional[float] = None) -> Dict[ExpertUID, bool]:
-        """
-        Make experts visible to all DHT peers; update timestamps if declared previously.
+    def declare_experts(self, uids, endpoint, wait: bool = True):
+        logger.warning("dht.declare_experts is scheduled for removal in 0.9.8, please use hivemind.declare_experts.")
+        return hivemind.declare_experts(self, uids, endpoint, wait=wait)
 
-        :param uids: a list of expert ids to update
-        :param endpoint: endpoint that serves these experts, usually your server endpoint (e.g. "201.111.222.333:1337")
-        :param wait: if True, awaits for declaration to finish, otherwise runs in background
-        :param timeout: waits for the procedure to finish for up to this long, None means wait indefinitely
-        :returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
-        """
-        assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
-        for uid in uids:
-            assert is_valid_uid(uid), f"{uid} is not a valid expert uid. All uids must follow {UID_PATTERN.pattern}"
-        future, _future = MPFuture.make_pair() if wait else (None, None)
-        self.pipe.send(('_declare_experts', [], dict(uids=list(uids), endpoint=endpoint, future=_future)))
-        if wait:
-            return future.result(timeout)
-
-    async def _declare_experts(self, node: DHTNode, uids: List[ExpertUID], endpoint: Endpoint,
-                               future: Optional[MPFuture]) -> Dict[ExpertUID, bool]:
-        num_workers = len(uids) if self.max_workers is None else min(len(uids), self.max_workers)
-        expiration_time = get_dht_time() + self.expiration
-        data_to_store: Dict[Tuple[ExpertPrefix, Optional[Coordinate]], DHTValue] = {}
-        for uid in uids:
-            data_to_store[uid, None] = endpoint
-            prefix = uid if uid.count(UID_DELIMITER) > 1 else f'{uid}{UID_DELIMITER}{FLAT_EXPERT}'
-            for i in range(prefix.count(UID_DELIMITER) - 1):
-                prefix, last_coord = split_uid(prefix)
-                data_to_store[prefix, last_coord] = [uid, endpoint]
-
-        keys, maybe_subkeys, values = zip(*((key, subkey, value) for (key, subkey), value in data_to_store.items()))
-        store_ok = await node.store_many(keys, values, expiration_time, subkeys=maybe_subkeys, num_workers=num_workers)
-        if future is not None:
-            future.set_result(store_ok)
-        return store_ok
-
-    def get_experts(self, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration] = None,
+    def get_experts(self, uids, expiration_time: Optional[DHTExpiration] = None,
                     return_future: bool = False) -> List[Optional[RemoteExpert]]:
-        """
-        :param uids: find experts with these ids from across the DHT
-        :param expiration_time: if specified, return experts that expire no sooner than this (based on get_dht_time)
-        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
-        :returns: a list of [RemoteExpert if found else None]
-        """
-        assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
-        future, _future = MPFuture.make_pair()
-        self.pipe.send(('_get_experts', [], dict(uids=list(uids), expiration_time=expiration_time, future=_future)))
-        return future if return_future else future.result()
-
-    async def _get_experts(self, node: DHTNode, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration],
-                           future: Optional[MPFuture] = None) -> List[Optional[RemoteExpert]]:
-        if expiration_time is None:
-            expiration_time = get_dht_time()
-        num_workers = len(uids) if self.max_workers is None else min(len(uids), self.max_workers)
-        found: Dict[ExpertUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
-
-        experts: List[Optional[RemoteExpert]] = [None] * len(uids)
-        for i, uid in enumerate(uids):
-            if found[uid] is not None and isinstance(found[uid].value, Endpoint):
-                experts[i] = RemoteExpert(uid, found[uid].value)
-        if future:
-            future.set_result(experts)
-        return experts
-
-    def get_initial_beam(self, prefix: ExpertPrefix, scores: Sequence[float], beam_size: int,
-                         num_workers: Optional[int] = None, return_future: bool = False
-                         ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
-        """
-        :param prefix: search for experts whose uids start with this prefix
-        :param scores: prefer suffix coordinates that have highest scores
-        :param beam_size: select this many active suffixes with highest scores
-        :param num_workers: maintain up to this many concurrent DHT searches
-        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
-        :returns: a list of up to beam_size tuples of (prefix score, prefix itself, dict{suffix: example expert})
-        """
-        assert is_valid_prefix(prefix), f"prefix '{prefix}' is invalid, it must follow {PREFIX_PATTERN.pattern}"
-        future, _future = MPFuture.make_pair()
-        self.pipe.send(('_get_initial_beam', [], dict(prefix=prefix, scores=tuple(scores), beam_size=beam_size,
-                                                      num_workers=num_workers, future=_future)))
-        return future if return_future else future.result()
-
-    async def _get_initial_beam(self, node, prefix: ExpertPrefix, beam_size: int, scores: Tuple[float, ...],
-                                num_workers: Optional[int] = None, future: Optional[MPFuture] = None
-                                ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
-        num_workers = num_workers or self.max_workers or beam_size
-        beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = []
-        unattempted_indices: List[Coordinate] = sorted(range(len(scores)), key=scores.__getitem__)  # from worst to best
-        pending_tasks: Deque[Tuple[Coordinate, ExpertPrefix, asyncio.Task]] = deque()
-
-        while len(beam) < beam_size and (unattempted_indices or pending_tasks):
-            # dispatch additional tasks
-            while unattempted_indices and len(pending_tasks) < num_workers:
-                next_index = unattempted_indices.pop()  # note: this is best unattempted index because of sort order
-                next_best_prefix = f"{prefix}{next_index}{UID_DELIMITER}"
-                pending_tasks.append((next_index, next_best_prefix, asyncio.create_task(node.get(next_best_prefix))))
-
-            # await the next best prefix to be fetched
-            pending_best_index, pending_best_prefix, pending_task = pending_tasks.popleft()
-            try:
-                maybe_prefix_data = await pending_task
-                if maybe_prefix_data is not None and isinstance(maybe_prefix_data.value, dict):
-                    successors = {coord: UidEndpoint(*match.value) for coord, match in maybe_prefix_data.value.items()
-                                  if isinstance(coord, Coordinate) and isinstance(getattr(match, 'value', None), list)
-                                  and len(match.value) == 2}
-                    if successors:
-                        beam.append((scores[pending_best_index], pending_best_prefix, successors))
-                elif maybe_prefix_data is None and self.negative_caching:
-                    logger.debug(f"DHT negative caching: storing a 'no prefix' entry for {pending_best_prefix}")
-                    asyncio.create_task(node.store(pending_best_prefix, subkey=-1, value=None,
-                                                   expiration_time=get_dht_time() + self.expiration))
-
-            except asyncio.CancelledError:
-                for _, pending_task in pending_tasks:
-                    pending_task.cancel()
-                raise
-        if future:
-            future.set_result(beam)
-        return beam
-
-    def get_active_successors(self, prefixes: List[ExpertPrefix], grid_size: Optional[int] = None,
-                              num_workers: Optional[int] = None, return_future: bool = False
-                              ) -> Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]]:
-        """
-        :param prefixes: a list of prefix for which to find active successor uids
-        :param grid_size: if specified, only return successors if ther are in range [0, grid_size)
-        :param num_workers: how many parallel workers to use for DHTNode.get_many
-        :param return_future: if False (default), find and return successors. Otherwise return MPFuture and fill later.
-        :returns: for every expert, return a dict{active_next_coordinate: (matching_expert_uid, matching_endpoint)}
-        :note: if a prefix is not found, get_active_successors will return an empty dictionary for that prefix
-        """
-        assert not isinstance(prefixes, str), "Please send a list / tuple of expert prefixes."
-        for prefix in prefixes:
-            assert is_valid_prefix(prefix), f"prefix '{prefix}' is invalid, it must follow {PREFIX_PATTERN.pattern}"
-        future, _future = MPFuture.make_pair()
-        self.pipe.send(('_get_active_successors', [], dict(
-            prefixes=list(prefixes), grid_size=grid_size, num_workers=num_workers, future=_future)))
-        return future if return_future else future.result()
-
-    async def _get_active_successors(self, node: DHTNode, prefixes: List[ExpertPrefix], grid_size: Optional[int] = None,
-                                     num_workers: Optional[int] = None, future: Optional[MPFuture] = None
-                                     ) -> Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]]:
-        grid_size = grid_size or float('inf')
-        num_workers = num_workers or min(len(prefixes), self.max_workers or len(prefixes))
-        dht_responses = await node.get_many(keys=prefixes, num_workers=num_workers)
-        successors: Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]] = {}
-        for prefix, found in dht_responses.items():
-            if found and isinstance(found.value, dict):
-                successors[prefix] = {coord: UidEndpoint(*match.value) for coord, match in found.value.items()
-                                      if isinstance(coord, Coordinate) and 0 <= coord < grid_size
-                                      and isinstance(getattr(match, 'value', None), list) and len(match.value) == 2}
-            else:
-                successors[prefix] = {}
-                if found is None and self.negative_caching:
-                    logger.debug(f"DHT negative caching: storing a 'no prefix' entry for {prefix}")
-                    asyncio.create_task(node.store(prefix, subkey=-1, value=None,
-                                                   expiration_time=get_dht_time() + self.expiration))
-        if future:
-            future.set_result(successors)
-        return successors
-
-    def find_best_experts(self, prefix: ExpertPrefix, grid_scores: Sequence[Sequence[float]], beam_size: int,
-                          num_workers: Optional[int] = None, return_future: bool = False
-                          ) -> Union[List[RemoteExpert], MPFuture]:
-        """
-        Find and return :beam_size: active experts with highest scores, use both local cache and DHT
-
-        :param prefix: common prefix for all expert uids in grid
-        :param grid_scores: scores predicted for each dimension in the grid,
-        :type grid_scores: model scores for each grid dimension, list of arrays of shape grid_size[i]
-        :param beam_size: how many best experts should beam search return
-         After time_budget is reached, beam search won't search for more experts and instead fall back on local cache
-         Please note that any queries that fall outside the budget will still be performed in background and cached
-         for subsequent iterations as long as DHTNode.cache_locally is True
-        :param num_workers: use up to this many concurrent workers to search DHT
-        :param return_future: if set to True, returns MPFuture that can be awaited to get the actual result
-        :returns: a list that contains *up to* k_best RemoteExpert instances
-        """
-        assert len(grid_scores) > 0 and beam_size > 0
-        assert is_valid_prefix(prefix), f"prefix '{prefix}' is invalid, it must follow {PREFIX_PATTERN.pattern}"
-        future, _future = MPFuture.make_pair()
-        self.pipe.send(('_find_best_experts', [], dict(prefix=prefix, grid_scores=list(map(tuple, grid_scores)),
-                                                       beam_size=beam_size, num_workers=num_workers, future=_future)))
-        return future if return_future else future.result()
-
-    async def _find_best_experts(
-            self, node: DHTNode, prefix: str, grid_scores: List[Tuple[float]], beam_size: int,
-            num_workers: Optional[int] = None, future: Optional[MPFuture] = None, **kwargs) -> List[RemoteExpert]:
-        num_workers = num_workers or min(beam_size, self.max_workers or beam_size)
-
-        # form initial beam from top-k active L1 prefixes, each row is (score, uid prefix, possible suffixes)
-        beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = await self._get_initial_beam(
-            node, prefix, beam_size, grid_scores[0], min(beam_size, num_workers))
-
-        best_experts_heap: List[Tuple[Score, UidEndpoint]] = []  # max-heap of expert uids/endpoints ordered by scores
-        unique_experts: Set[ExpertUID] = set()
-
-        for dim_index in range(1, len(grid_scores) - 1):
-            for score, uid_endpoint in self._iterate_matching_experts(beam, grid_scores):
-                if uid_endpoint.uid not in unique_experts:
-                    push_and_maybe_pop = heapq.heappush if len(best_experts_heap) < beam_size else heapq.heappushpop
-                    push_and_maybe_pop(best_experts_heap, (score, uid_endpoint))
-                    unique_experts.add(uid_endpoint.uid)
-
-            # form new beam using successors from the current beam
-            dim_scores = grid_scores[dim_index]
-            best_active_pairs: List[Tuple[Score, ExpertPrefix]] = heapq.nlargest(beam_size, (
-                (prefix_score + dim_scores[next_coord], f"{prefix}{next_coord}{UID_DELIMITER}")
-                for prefix_score, prefix, suffixes in beam for next_coord in suffixes.keys()
-                if isinstance(next_coord, int) and 0 <= next_coord < len(dim_scores)))
-            _, best_uid_prefixes = zip(*best_active_pairs)
-
-            # search DHT for next step suffixes
-            successors = await self._get_active_successors(node, best_uid_prefixes, num_workers=num_workers)
-            beam = [(score, prefix, successors[prefix]) for score, prefix in best_active_pairs if successors[prefix]]
-            if not beam:
-                logger.warning(f"Beam search had to terminate prematurely because of empty beam (dim 0)")
-                break
-
-        # add best experts from the final beam
-        for score, uid_endpoint in self._iterate_matching_experts(beam, grid_scores):
-            if uid_endpoint.uid not in unique_experts:
-                push_and_maybe_pop = heapq.heappush if len(best_experts_heap) < beam_size else heapq.heappushpop
-                push_and_maybe_pop(best_experts_heap, (score, uid_endpoint))
-                unique_experts.add(uid_endpoint.uid)
-
-        best_experts = [RemoteExpert(*uid_endpoint) for score, uid_endpoint in sorted(best_experts_heap, reverse=True)]
-        if future is not None:
-            future.set_result(best_experts)
-        return best_experts
-
-    @staticmethod
-    def _iterate_matching_experts(beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]],
-                                  grid_scores: Sequence[Sequence[float]]) -> Iterator[Tuple[Score, UidEndpoint]]:
-        """ iterate over all exemplar experts attached to current beam """
-        for score, prefix, suffixes in beam:
-            for next_coord, match in suffixes.items():
-                if len(grid_scores) == 1 and next_coord == FLAT_EXPERT:
-                    yield score, match
-                elif isinstance(match.uid, ExpertUID) and match.uid.count(UID_DELIMITER) == len(grid_scores):
-                    expert_coords = match.uid.split(UID_DELIMITER)[1:]
-                    if all(coord.isdigit() and 0 <= int(coord) < len(grid_scores[i])
-                           for i, coord in enumerate(expert_coords)):
-                        expert_score = sum(scores[coord] for scores, coord in zip(grid_scores, map(int, expert_coords)))
-                        yield expert_score, match
-                    else:
-                        logger.warning(f"Found incompatible expert coordinates: {expert_coords}")
-                else:
-                    logger.warning(f"Found incompatible expert UID: {match.uid}")
-
-    def batch_find_best_experts(
-            self, prefix: str, batch_grid_scores: Sequence[Sequence[Sequence[float]]], beam_size: int, *,
-            workers_per_sample: Optional[int] = None, return_future=False) -> Union[List[List[RemoteExpert]], MPFuture]:
-        """
-        Find and return :beam_size: active experts with highest scores, use both local cache and DHT
-
-        :param prefix: common prefix for all expert uids in grid
-        :param batch_grid_scores: scores predicted for each batch example and each dimension in the grid,
-        :type batch_grid_scores: list of arrays of shape (batch_size, grid_size[i])
-        :param beam_size: how many best experts should beam search return
-         After time_budget is reached, beam search won't search for more experts and instead fall back on local cache
-         Please note that any queries that fall outside the budget will still be performed in background and cached
-         for subsequent iterations as long as DHTNode.cache_locally is True
-        :param workers_per_sample: use up to this many concurrent workers for every sample in batch
-        :param return_future: if set to True, returns MPFuture that can be awaited to get the actual result
-        :returns: a list that contains *up to* k_best RemoteExpert instances
-        """
-        future, _future = MPFuture.make_pair()
-        self.pipe.send(('_batch_find_best_experts', [], dict(prefix=prefix, batch_grid_scores=batch_grid_scores,
-                                                             beam_size=beam_size, workers_per_sample=workers_per_sample,
-                                                             future=_future)))
-        return future if return_future else future.result()
-
-    async def _batch_find_best_experts(
-            self, node: DHTNode, prefix: str, batch_grid_scores: Sequence[Sequence[Tuple[float]]], beam_size: int,
-            workers_per_sample: Optional[int] = None, future: Optional[MPFuture] = None) -> List[List[RemoteExpert]]:
-
-        batch_grid_scores = [[tuple(grid_score[i]) for grid_score in batch_grid_scores]
-                             for i in range(len(batch_grid_scores[0]))]
-        coros = [self._find_best_experts(node, prefix, grid_scores, beam_size, workers_per_sample)
-                 for grid_scores in batch_grid_scores]
-
-        best_experts_batch = await asyncio.gather(*coros)
-        if future is not None:
-            future.set_result(best_experts_batch)
-        return best_experts_batch
+        logger.warning("dht.get_experts is scheduled for removal in 0.9.8, please use hivemind.get_experts.")
+        return hivemind.get_experts(self, uids, expiration_time, return_future)
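
For reference, a minimal end-to-end sketch of the replacement API introduced by this change (constructor and helper signatures follow the tests added in this commit and should be treated as assumptions rather than a stable contract):

import numpy as np
import hivemind
from hivemind.client.beam_search import MoEBeamSearcher

dht = hivemind.DHT(start=True)

# declare_experts / get_experts are now module-level helpers that take the DHT as their first argument
assert all(hivemind.declare_experts(dht, ['expert.0.1', 'expert.2.3'], 'localhost:1337').values())
found, missing = hivemind.get_experts(dht, ['expert.0.1', 'expert.9.9'])
assert found.endpoint == 'localhost:1337' and missing is None

# beam search is now a separate class instead of a family of DHT methods
beam_search = MoEBeamSearcher(dht, 'expert.', (4, 4))
best = beam_search.find_best_experts([np.random.randn(4), np.random.randn(4)], beam_size=2)
assert all(isinstance(expert, hivemind.RemoteExpert) for expert in best)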

+ 5 - 4
hivemind/server/__init__.py

@@ -13,9 +13,10 @@ import torch
 
 import hivemind
 from hivemind.dht import DHT
+from hivemind.server.expert_uid import UID_DELIMITER
 from hivemind.server.checkpoints import CheckpointSaver, load_weights, dir_is_correct
 from hivemind.server.connection_handler import ConnectionHandler
-from hivemind.server.dht_handler import DHTHandlerThread
+from hivemind.server.dht_handler import DHTHandlerThread, declare_experts, get_experts
 from hivemind.server.expert_backend import ExpertBackend
 from hivemind.server.layers import name_to_block, name_to_input
 from hivemind.server.runtime import Runtime
@@ -287,10 +288,10 @@ def generate_uids_from_pattern(num_experts: int, expert_pattern: Optional[str],
 
     def _generate_uid():
         if expert_pattern is None:
-            return f"expert{hivemind.dht.UID_DELIMITER}{attempts_per_expert * num_experts - remaining_attempts}"
+            return f"expert{UID_DELIMITER}{attempts_per_expert * num_experts - remaining_attempts}"
 
         uid = []
-        for block in expert_pattern.split(hivemind.dht.UID_DELIMITER):
+        for block in expert_pattern.split(UID_DELIMITER):
             try:
                 if '[' not in block and ']' not in block:
                     uid.append(block)
@@ -303,7 +304,7 @@ def generate_uids_from_pattern(num_experts: int, expert_pattern: Optional[str],
                 raise e
             except Exception as e:
                 raise ValueError(f"Expert pattern {expert_pattern} has invalid block {block}, {e}")
-        return hivemind.dht.UID_DELIMITER.join(uid)
+        return UID_DELIMITER.join(uid)
 
     while remaining_attempts > 0 and len(found_uids) < num_experts:
 

+ 68 - 5
hivemind/server/dht_handler.py

@@ -1,8 +1,12 @@
 import threading
-import time
+from functools import partial
+from typing import Sequence, Dict, List, Tuple, Optional
 
-from hivemind.dht import DHT
-from hivemind.utils import Endpoint, get_port
+from hivemind.dht import DHT, DHTNode, DHTExpiration, DHTValue
+from hivemind.client.expert import RemoteExpert
+from hivemind.server.expert_uid import (ExpertUID, ExpertPrefix, FLAT_EXPERT, Coordinate,
+                                        UID_DELIMITER, UID_PATTERN, is_valid_uid, split_uid)
+from hivemind.utils import Endpoint, get_dht_time, get_port
 
 
 class DHTHandlerThread(threading.Thread):
@@ -16,6 +20,65 @@ class DHTHandlerThread(threading.Thread):
         self.stop = threading.Event()
 
     def run(self) -> None:
-        self.dht.declare_experts(self.experts.keys(), self.endpoint)
+        declare_experts(self.dht, self.experts.keys(), self.endpoint)
         while not self.stop.wait(self.update_period):
-            self.dht.declare_experts(self.experts.keys(), self.endpoint)
+            declare_experts(self.dht, self.experts.keys(), self.endpoint)
+
+
+def declare_experts(dht: DHT, uids: Sequence[ExpertUID], endpoint: Endpoint,
+                    wait: bool = True) -> Dict[ExpertUID, bool]:
+    """
+    Make experts visible to all DHT peers; update timestamps if declared previously.
+
+    :param uids: a list of expert ids to update
+    :param endpoint: endpoint that serves these experts, usually your server endpoint (e.g. "201.111.222.333:1337")
+    :param wait: if True, waits for the declaration to finish; otherwise runs it in the background
+    :returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
+    """
+    assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
+    for uid in uids:
+        assert is_valid_uid(uid), f"{uid} is not a valid expert uid. All uids must follow {UID_PATTERN.pattern}"
+    return dht.run_coroutine(partial(_declare_experts, uids=list(uids), endpoint=endpoint), return_future=not wait)
+
+
+async def _declare_experts(dht: DHT, node: DHTNode, uids: List[ExpertUID], endpoint: Endpoint) -> Dict[ExpertUID, bool]:
+    num_workers = len(uids) if dht.max_workers is None else min(len(uids), dht.max_workers)
+    expiration_time = get_dht_time() + dht.default_expiration  # TODO use local expiration
+    data_to_store: Dict[Tuple[ExpertPrefix, Optional[Coordinate]], DHTValue] = {}
+    for uid in uids:
+        data_to_store[uid, None] = endpoint
+        prefix = uid if uid.count(UID_DELIMITER) > 1 else f'{uid}{UID_DELIMITER}{FLAT_EXPERT}'
+        for i in range(prefix.count(UID_DELIMITER) - 1):
+            prefix, last_coord = split_uid(prefix)
+            data_to_store[prefix, last_coord] = [uid, endpoint]
+
+    keys, maybe_subkeys, values = zip(*((key, subkey, value) for (key, subkey), value in data_to_store.items()))
+    store_ok = await node.store_many(keys, values, expiration_time, subkeys=maybe_subkeys, num_workers=num_workers)
+    return store_ok
+
+
+def get_experts(dht: DHT, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration] = None,
+                return_future: bool = False) -> List[Optional[RemoteExpert]]:
+    """
+    :param uids: find experts with these ids from across the DHT
+    :param expiration_time: if specified, return experts that expire no sooner than this (based on get_dht_time)
+    :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
+    :returns: a list of [RemoteExpert if found else None]
+    """
+    assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
+    return dht.run_coroutine(partial(_get_experts, uids=list(uids), expiration_time=expiration_time), return_future)
+
+
+async def _get_experts(dht: DHT, node: DHTNode, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration]
+                       ) -> List[Optional[RemoteExpert]]:
+    if expiration_time is None:
+        expiration_time = get_dht_time()
+    num_workers = len(uids) if dht.max_workers is None else min(len(uids), dht.max_workers)
+    found: Dict[ExpertUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
+
+    experts: List[Optional[RemoteExpert]] = [None] * len(uids)
+    for i, uid in enumerate(uids):
+        if found[uid] is not None and isinstance(found[uid].value, Endpoint):
+            experts[i] = RemoteExpert(uid, found[uid].value)
+    return experts
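
To make the prefix bookkeeping above concrete, here is a standalone sketch of what `data_to_store` ends up containing for a single uid (the uid and endpoint are made-up illustrations; the logic mirrors the loop in `_declare_experts`):

from hivemind.server.expert_uid import split_uid, UID_DELIMITER, FLAT_EXPERT

uid, endpoint = 'ffn.3.4.5', 'host:1337'  # illustrative values
data_to_store = {(uid, None): endpoint}   # the expert itself, looked up by get_experts
prefix = uid if uid.count(UID_DELIMITER) > 1 else f'{uid}{UID_DELIMITER}{FLAT_EXPERT}'
for _ in range(prefix.count(UID_DELIMITER) - 1):
    prefix, last_coord = split_uid(prefix)
    data_to_store[prefix, last_coord] = [uid, endpoint]  # prefix key, last coordinate as subkey

# these prefix entries are what beam search walks when expanding the grid
assert data_to_store == {
    ('ffn.3.4.5', None): 'host:1337',
    ('ffn.3.4.', 5): ['ffn.3.4.5', 'host:1337'],
    ('ffn.3.', 4): ['ffn.3.4.5', 'host:1337'],
}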

+ 32 - 0
hivemind/server/expert_uid.py

@@ -0,0 +1,32 @@
+import re
+from typing import NamedTuple, Union, Tuple
+
+from hivemind.utils.networking import Endpoint
+
+
+ExpertUID, ExpertPrefix, Coordinate, Score = str, str, int, float
+UidEndpoint = NamedTuple("UidEndpoint", [('uid', ExpertUID), ('endpoint', Endpoint)])
+UID_DELIMITER = '.'  # when declaring experts, the DHT also stores every prefix of the expert's uid, split over this delimiter
+FLAT_EXPERT = -1  # grid prefix reserved for storing 1d expert uids. Used to speed up find_best_experts in 1d case.
+UID_PATTERN = re.compile('^(([^.])+)([.](?:[0]|([1-9]([0-9]*))))+$')  # e.g. ffn_expert.98.76.54 - prefix + some dims
+PREFIX_PATTERN = re.compile('^(([^.])+)([.](?:[0]|([1-9]([0-9]*))))*[.]$')  # e.g. expert. or ffn.45. (ends with ".")
+#  formally, prefixes = {uid.split(UID_DELIMITER)[:length] for length in range(1, uid.count(UID_DELIMITER) + 2)}
+
+
+def is_valid_uid(maybe_uid: str) -> bool:
+    """ An uid must contain a string expert type, followed by one or more .-separated numeric indices """
+    return bool(UID_PATTERN.fullmatch(maybe_uid))
+
+
+def is_valid_prefix(maybe_prefix: str) -> bool:
+    """ An uid prefix must contain a string expert type, followed by optional numeric indices and a trailing period """
+    return bool(PREFIX_PATTERN.fullmatch(maybe_prefix))
+
+
+def split_uid(uid_or_prefix: Union[ExpertUID, ExpertPrefix]) -> Tuple[ExpertPrefix, Coordinate]:
+    """ Separate an expert UID or prefix into a new ExpertPrefix and integer for the last coordinate """
+    uid_or_prefix = uid_or_prefix.rstrip(UID_DELIMITER)
+    pivot = uid_or_prefix.rindex(UID_DELIMITER) + 1
+    return uid_or_prefix[:pivot], int(uid_or_prefix[pivot:])
+
+
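
A quick sanity check of the new helpers (these values follow directly from the regexes and `split_uid` above):

from hivemind.server.expert_uid import is_valid_uid, is_valid_prefix, split_uid

assert is_valid_uid('ffn.12.34') and not is_valid_uid('ffn.')         # a uid ends with numeric coordinates
assert is_valid_prefix('ffn.12.') and not is_valid_prefix('ffn.12')   # a prefix ends with the delimiter
assert split_uid('ffn.12.34') == ('ffn.12.', 34)
assert split_uid('expert.7.') == ('expert.', 7)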

+ 11 - 1
hivemind/utils/asyncio.py

@@ -1,4 +1,4 @@
-from typing import TypeVar, AsyncIterator, Union, AsyncIterable
+from typing import TypeVar, AsyncIterator, Union, AsyncIterable, Awaitable
 import asyncio
 import uvloop
 T = TypeVar('T')
@@ -32,3 +32,13 @@ async def achain(*async_iters: AsyncIterable[T]) -> AsyncIterator[T]:
     for aiter in async_iters:
         async for elem in aiter:
             yield elem
+
+
+async def await_cancelled(awaitable: Awaitable) -> bool:
+    try:
+        await awaitable
+        return False
+    except asyncio.CancelledError:
+        return True
+    except BaseException:
+        return False
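
A small usage sketch for `await_cancelled` (assumed context: it distinguishes a cancelled awaitable from one that finished or failed):

import asyncio
from hivemind.utils.asyncio import await_cancelled

async def main():
    task = asyncio.create_task(asyncio.sleep(10))
    task.cancel()
    assert await await_cancelled(task)                    # cancelled -> True
    assert not await await_cancelled(asyncio.sleep(0))    # finished normally -> False

asyncio.run(main())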

+ 7 - 5
hivemind/utils/mpfuture.py

@@ -6,12 +6,14 @@ import concurrent.futures._base as base
 
 import asyncio
 from functools import lru_cache
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Generic, TypeVar
 
 from hivemind.utils.threading import run_in_background
 
+ResultType = TypeVar('ResultType')
 
-class MPFuture(base.Future):
+
+class MPFuture(base.Future, Generic[ResultType]):
     """ Multiprocessing version of concurrent.futures.Future. Can also be awaited like asyncio.Future """
 
     TERMINAL_STATES = {base.FINISHED, base.CANCELLED, base.CANCELLED_AND_NOTIFIED}
@@ -74,7 +76,7 @@ class MPFuture(base.Future):
         except TimeoutError:
             pass
 
-    def set_result(self, result):
+    def set_result(self, result: ResultType):
         self._sync_updates()
         if self._state in self.TERMINAL_STATES:
             raise RuntimeError(f"Can't set_result to a future that is in {self._state}")
@@ -105,13 +107,13 @@ class MPFuture(base.Future):
         self._state, self._exception = base.CANCELLED, base.CancelledError()
         return self._send_updates()
 
-    def result(self, timeout: Optional[float] = None):
+    def result(self, timeout: Optional[float] = None) -> ResultType:
         self._await_terminal_state(timeout)
         if self._exception is not None:
             raise self._exception
         return self._result
 
-    def exception(self, timeout=None):
+    def exception(self, timeout=None) -> BaseException:
         self._await_terminal_state(timeout)
         if self._state == base.CANCELLED:
             raise base.CancelledError()
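
The Generic parameter is purely a typing aid; for instance (an illustrative alias, not part of this commit), it lets callers state what `.result()` will eventually return:

from typing import List, Optional
from hivemind.client.expert import RemoteExpert
from hivemind.utils.mpfuture import MPFuture

# type checkers now treat .result() on such a future as List[Optional[RemoteExpert]];
# runtime behaviour of the pipe-backed future is unchanged
ExpertsFuture = MPFuture[List[Optional[RemoteExpert]]]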

+ 4 - 2
tests/benchmark_dht.py

@@ -5,6 +5,7 @@ import time
 from tqdm import trange
 
 import hivemind
+import hivemind.server.expert_uid
 from hivemind.utils.threading import increase_file_limit
 
 logger = hivemind.get_logger(__name__)
@@ -42,7 +43,8 @@ def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_b
     for start in trange(0, num_experts, expert_batch_size):
         store_start = time.perf_counter()
         endpoints.append(random_endpoint())
-        successes = store_peer.declare_experts(expert_uids[start: start + expert_batch_size], endpoints[-1]).values()
+        store_ok = hivemind.declare_experts(store_peer, expert_uids[start: start + expert_batch_size], endpoints[-1])
+        successes = store_ok.values()
         total_store_time += time.perf_counter() - store_start
 
         total_stores += len(successes)
@@ -60,7 +62,7 @@ def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_b
 
     for start in trange(0, len(expert_uids), expert_batch_size):
         get_start = time.perf_counter()
-        get_result = get_peer.get_experts(expert_uids[start: start + expert_batch_size])
+        get_result = hivemind.get_experts(get_peer, expert_uids[start: start + expert_batch_size])
         total_get_time += time.perf_counter() - get_start
 
         for i, expert in enumerate(get_result):

+ 70 - 142
tests/test_dht.py

@@ -1,175 +1,103 @@
+import asyncio
 import random
-import numpy as np
+import time
+
 import pytest
-import asyncio
 
 import hivemind
-from hivemind import LOCALHOST, UidEndpoint, strip_port
+from hivemind import LOCALHOST, strip_port
 
 
 @pytest.mark.forked
-def test_store_get_experts():
-    peers = [hivemind.DHT(start=True)]
+def test_get_store():
+    peers = []
     for i in range(10):
         neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
         peers.append(hivemind.DHT(initial_peers=neighbors_i, start=True))
 
-    you: hivemind.dht.DHT = random.choice(peers)
-    theguyshetoldyounottoworryabout: hivemind.dht.DHT = random.choice(peers)
+    node1, node2 = random.sample(peers, 2)
+    assert node1.store('key1', 'value1', expiration_time=hivemind.get_dht_time() + 30)
+    assert node1.get('key1').value == 'value1'
+    assert node2.get('key1').value == 'value1'
+    assert node2.get('key2') is None
+
+    future = node1.get('foo', return_future=True)
+    assert future.result() is None
 
-    expert_uids = [f"my_expert.{i}" for i in range(110)]
-    batch_size = 10
-    for batch_start in range(0, len(expert_uids), batch_size):
-        you.declare_experts(expert_uids[batch_start: batch_start + batch_size], 'localhost:1234')
+    future = node1.get('foo', return_future=True)
+    future.cancel()
 
-    found = theguyshetoldyounottoworryabout.get_experts(random.sample(expert_uids, 5) + ['foo', 'bar'])
-    assert all(res is not None for res in found[:-2]), "Could not find some existing experts"
-    assert all(res is None for res in found[-2:]), "Found non-existing experts"
+    assert node2.store('key1', 123, expiration_time=hivemind.get_dht_time() + 31)
+    assert node2.store('key2', 456, expiration_time=hivemind.get_dht_time() + 32)
+    assert node1.get('key1', latest=True).value == 123
+    assert node1.get('key2').value == 456
 
-    that_guys_expert, that_guys_port = "my_other_expert.1337", random.randint(1000, 9999)
-    theguyshetoldyounottoworryabout.declare_experts([that_guys_expert], f'that_host:{that_guys_port}')
-    you_notfound, you_found = you.get_experts(['foobar', that_guys_expert])
-    assert isinstance(you_found, hivemind.RemoteExpert)
-    assert you_found.endpoint == f'that_host:{that_guys_port}'
+    assert node1.store('key2', subkey='subkey1', value=789, expiration_time=hivemind.get_dht_time() + 32)
+    assert node2.store('key2', subkey='subkey2', value='pew', expiration_time=hivemind.get_dht_time() + 32)
+    found_dict = node1.get('key2', latest=True).value
+    assert isinstance(found_dict, dict) and len(found_dict) == 2
+    assert found_dict['subkey1'].value == 789 and found_dict['subkey2'].value == 'pew'
 
     for peer in peers:
         peer.shutdown()
 
 
-@pytest.mark.forked
-def test_dht_get_address(addr=LOCALHOST, dummy_endpoint='123.45.67.89:*'):
-    node1 = hivemind.DHT(start=True, listen_on=f"0.0.0.0:*")
-    node2 = hivemind.DHT(start=True, listen_on=f"0.0.0.0:*", initial_peers=[f"{addr}:{node1.port}"])
-    node3 = hivemind.DHT(start=True, listen_on=f"0.0.0.0:*", initial_peers=[f"{addr}:{node2.port}"])
-    assert addr in node3.get_visible_address(num_peers=2)
+async def dummy_dht_coro(self, node):
+    return 'pew'
 
-    node4 = hivemind.DHT(start=True, listen_on=f"0.0.0.0:*")
-    with pytest.raises(ValueError):
-        node4.get_visible_address()
-    assert node4.get_visible_address(peers=[f'{addr}:{node1.port}']).endswith(addr)
 
-    node5 = hivemind.DHT(start=True, listen_on=f"0.0.0.0:*", endpoint=f"{dummy_endpoint}")
-    assert node5.get_visible_address() == strip_port(dummy_endpoint)
+async def dummy_dht_coro_error(self, node):
+    raise ValueError("Oops, i did it again...")
 
 
-@pytest.mark.forked
-def test_beam_search(dht_size=20, total_experts=128, batch_size=32, initial_peers=3, beam_size=4, parallel_rpc=16,
-                     grid_dims=(32, 32, 32)):
-    dht = []
-    for i in range(dht_size):
-        neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(dht, min(initial_peers, len(dht)))]
-        dht.append(hivemind.DHT(start=True, expiration=999999, initial_peers=neighbors_i, parallel_rpc=parallel_rpc))
-
-    real_experts = sorted({
-        'expert.' + '.'.join([str(random.randint(0, dim - 1)) for dim in grid_dims])
-        for _ in range(total_experts)
-    })
-    for batch_start in range(0, len(real_experts), batch_size):
-        random.choice(dht).declare_experts(
-            real_experts[batch_start: batch_start + batch_size], wait=True,
-            endpoint=f"host{batch_start // batch_size}:{random.randint(0, 65536)}")
-
-    neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(dht, min(initial_peers, len(dht)))]
-    you = hivemind.DHT(start=True, expiration=999999, initial_peers=neighbors_i, parallel_rpc=parallel_rpc)
-
-    for i in range(50):
-        topk_experts = you.find_best_experts('expert.', [np.random.randn(dim) for dim in grid_dims], beam_size=beam_size)
-        assert all(isinstance(e, hivemind.RemoteExpert) for e in topk_experts)
-        assert len(topk_experts) == beam_size
+async def dummy_dht_coro_stateful(self, node):
+    self._x_dummy = getattr(self, '_x_dummy', 123) + 1
+    return self._x_dummy
 
-    for i in range(10):
-        batch_experts = you.batch_find_best_experts('expert.', [np.random.randn(batch_size, dim) for dim in grid_dims],
-                                                    beam_size=beam_size)
-        assert isinstance(batch_experts, list) and len(batch_experts) == batch_size
-        assert all(isinstance(e, hivemind.RemoteExpert) for experts in batch_experts for e in experts)
-        assert all(len(experts) == beam_size for experts in batch_experts)
 
+async def dummy_dht_coro_long(self, node):
+    await asyncio.sleep(0.25)
+    return self._x_dummy ** 2
 
-@pytest.mark.forked
-def test_dht_single_node():
-    node = hivemind.DHT(start=True, expiration=999)
-
-    assert all(node.declare_experts(['expert.1', 'expert.2', 'expert.3'], f"{hivemind.LOCALHOST}:1337").values())
-    assert len(node.declare_experts(["ffn.1", "ffn.2"], endpoint="that_place")) == 4
-    assert len(node.declare_experts(['e.1.2.3', 'e.1.2.5', 'e.2.0'], f"{hivemind.LOCALHOST}:42")) == 7
-
-    for expert in node.get_experts(['expert.3', 'expert.2']):
-        assert expert.endpoint == f"{hivemind.LOCALHOST}:1337"
-
-    assert all(node.declare_experts(['expert.5', 'expert.2'], f"{hivemind.LOCALHOST}:1337").values())
-    found_experts = node.find_best_experts('expert.', [(0., 1., 2., 3., 4., 5., 6., 7., 8.)], beam_size=2)
-    assert len(found_experts) == 2 and [expert.uid for expert in found_experts] == ['expert.5', 'expert.3']
-
-    successors = node.get_active_successors(['e.1.2.', 'e.2.', 'e.4.5.'])
-    assert len(successors['e.1.2.']) == 2
-    assert successors['e.1.2.'][3] == UidEndpoint('e.1.2.3', f'{LOCALHOST}:42')
-    assert successors['e.1.2.'][5] == UidEndpoint('e.1.2.5', f'{LOCALHOST}:42')
-    assert len(successors['e.2.']) == 1 and successors['e.2.'][0] == UidEndpoint('e.2.0', f'{LOCALHOST}:42')
-    assert successors['e.4.5.'] == {}
-
-    initial_beam = node.get_initial_beam('expert.', (3, 2, 1, 0, -1, -2, -3), beam_size=3)
-    assert len(initial_beam) == 3
-    assert initial_beam[0][:2] == (2.0, 'expert.1.')
-    assert initial_beam[1][:2] == (1.0, 'expert.2.')
-    assert initial_beam[2][:2] == (0.0, 'expert.3.')
-
-    with pytest.raises(AssertionError):
-        node.find_best_experts('expert', [(0., 1., 2., 3., 4., 5., 6., 7., 8.)], beam_size=2)
-
-    with pytest.raises(AssertionError):
-        node.find_best_experts('expert.1', [(0., 1., 2., 3., 4., 5., 6., 7., 8.)], beam_size=2)
-
-    with pytest.raises(AssertionError):
-        node.get_active_successors(['e.1.2.', 'e.2', 'e.4.5.'])
-
-    with pytest.raises(AssertionError):
-        node.get_initial_beam('expert', (3, 2, 1, 0, -1, -2, -3), beam_size=3)
-
-
-def test_uid_patterns():
-    valid_experts = ["expert.1", "expert.0", "expert.0.0.1", "expert.1337", "ffn.12.34.56.78.90",
-                     "transformer.3.2.1.0", "transformer_encoder.2", "transformer::encoder.2", "T®@nsf0rmE®🤗.321",
-                     "🤗.321", "0.1.2", "00.1.2", "7070.3.2.1.0", "block2.1.23", "LAYER.1.0.1"]
-    valid_prefixes = ["expert.", "e.1.", "e.2.", "e.1.2.3.", "ololo.123.456.789.10."]
-    valid_prefixes.extend([f"{uid}." for uid in valid_experts])
-    valid_prefixes.extend([hivemind.split_uid(uid)[0] for uid in valid_experts])
-    for uid in valid_experts:
-        assert hivemind.is_valid_uid(uid), f"UID {uid} is valid, but was perceived as invalid"
-    for pfx in valid_prefixes:
-        assert hivemind.is_valid_prefix(pfx), f"Prefix {pfx} is valid, but was perceived as invalid"
-
-    invalid = ["", ".", "expert.-1", "xxx.a", "expert.1x", "expert_ffn.1.abc1", "some.123.01", "expert.123.01",
-               "e1", "e..1", "e", "e.1.2.3..4", "ffn.1..1", ".123", ".1.2.3.", ".expert", "transformer.encoder.2",
-               "T®@nsf0rmE®.🤗.321", "layer::123", "expert.0.1.2.suffix", "0.1.2.suffix", "expert.1 something",
-               "expert.1\n", "expert.1\n2", "expert.1 ", "expert.1\nexpert.2", "'expert.1'", '"expert.1"']
-    invalid_experts = invalid + valid_prefixes + ["0", "123456"]
-    invalid_prefixes = invalid + valid_experts + ["expert", ".🤗", ".expert"]
-    for uid in invalid_experts:
-        assert not hivemind.is_valid_uid(uid), f"UID {uid} is not valid, but was perceived as valid"
-    for pfx in invalid_prefixes:
-        assert not hivemind.is_valid_prefix(pfx), f"Prefix {pfx} is not valid, but was perceived as valid"
+
+async def dummy_dht_coro_for_cancel(self, node):
+    self._x_dummy = -100
+    await asyncio.sleep(0.5)
+    self._x_dummy = 999
 
 
 @pytest.mark.forked
-@pytest.mark.asyncio
-async def test_negative_caching():
-    peers = []
-    for i in range(10):
-        neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
-        peers.append(hivemind.DHT(initial_peers=neighbors_i, negative_caching=False, cache_locally=False, start=True))
+def test_run_coroutine():
+    dht = hivemind.DHT(start=True)
+    assert dht.run_coroutine(dummy_dht_coro) == 'pew'
+
+    with pytest.raises(ValueError):
+        res = dht.run_coroutine(dummy_dht_coro_error)
+
+    bg_task = dht.run_coroutine(dummy_dht_coro_long, return_future=True)
+    assert dht.run_coroutine(dummy_dht_coro_stateful) == 124
+    assert dht.run_coroutine(dummy_dht_coro_stateful) == 125
+    assert dht.run_coroutine(dummy_dht_coro_stateful) == 126
+    assert not hasattr(dht, '_x_dummy')
+    assert bg_task.result() == 126 ** 2
+
+    future = dht.run_coroutine(dummy_dht_coro_for_cancel, return_future=True)
+    time.sleep(0.25)
+    future.cancel()
+    assert dht.run_coroutine(dummy_dht_coro_stateful) == -99
 
-    normal_peer, writer_peer = random.sample(peers, 2)
 
-    neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
-    neg_caching_peer = hivemind.DHT(initial_peers=neighbors_i, negative_caching=True, cache_locally=False, start=True)
+@pytest.mark.forked
+def test_dht_get_address(addr=LOCALHOST, dummy_endpoint='123.45.67.89:*'):
+    node1 = hivemind.DHT(start=True, listen_on=f"0.0.0.0:*")
+    node2 = hivemind.DHT(start=True, listen_on=f"0.0.0.0:*", initial_peers=[f"{addr}:{node1.port}"])
+    node3 = hivemind.DHT(start=True, listen_on=f"0.0.0.0:*", initial_peers=[f"{addr}:{node2.port}"])
+    assert addr in node3.get_visible_address(num_peers=2)
 
-    assert all(writer_peer.declare_experts(['ffn.1.2.3', 'ffn.3.4.5'], 'myaddr:1234').values())
-    # get prefixes by the peer with negative caching. Cache "no data" entries for ffn.0.*, ffn.2.*, ffn.4.*, ffn.5.*
-    assert len(neg_caching_peer.get_initial_beam(prefix='ffn.', scores=[.1, .2, .3, .4, .5, .6], beam_size=3)) == 2
+    node4 = hivemind.DHT(start=True, listen_on=f"0.0.0.0:*")
+    with pytest.raises(ValueError):
+        node4.get_visible_address()
+    assert node4.get_visible_address(peers=[f'{addr}:{node1.port}']).endswith(addr)
 
-    node = await hivemind.DHTNode.create(initial_peers=neighbors_i)
-    fetched = await asyncio.gather(*(node.get(f'ffn.{i}.') for i in range(10)))
-    for i in range(6):
-        assert fetched[i] is not None, f"node should have cached ffn.{i}."
-    for i in range(6, len(fetched)):
-        assert fetched[i] is None, f"node shouldn't have cached ffn.{i}."
+    node5 = hivemind.DHT(start=True, listen_on=f"0.0.0.0:*", endpoint=f"{dummy_endpoint}")
+    assert node5.get_visible_address() == strip_port(dummy_endpoint)

+ 159 - 0
tests/test_dht_experts.py

@@ -0,0 +1,159 @@
+import asyncio
+import random
+
+import numpy as np
+import pytest
+
+import hivemind
+import hivemind.server.expert_uid
+from hivemind import LOCALHOST
+from hivemind.client.beam_search import MoEBeamSearcher
+from hivemind.server.expert_uid import UidEndpoint, is_valid_uid, is_valid_prefix, split_uid
+
+
+@pytest.mark.forked
+def test_store_get_experts():
+    peers = [hivemind.DHT(start=True)]
+    for i in range(10):
+        neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
+        peers.append(hivemind.DHT(initial_peers=neighbors_i, start=True))
+
+    first_peer = random.choice(peers)
+    other_peer = random.choice(peers)
+
+    expert_uids = [f"my_expert.{i}" for i in range(110)]
+    batch_size = 10
+    for batch_start in range(0, len(expert_uids), batch_size):
+        hivemind.declare_experts(first_peer, expert_uids[batch_start: batch_start + batch_size], 'localhost:1234')
+
+    found = other_peer.get_experts(random.sample(expert_uids, 5) + ['foo', 'bar'])
+    assert all(res is not None for res in found[:-2]), "Could not find some existing experts"
+    assert all(res is None for res in found[-2:]), "Found non-existing experts"
+
+    other_expert, other_port = "my_other_expert.1337", random.randint(1000, 9999)
+    hivemind.declare_experts(other_peer, [other_expert], f'that_host:{other_port}')
+    first_notfound, first_found = hivemind.get_experts(first_peer, ['foobar', other_expert])
+    assert isinstance(first_found, hivemind.RemoteExpert)
+    assert first_found.endpoint == f'that_host:{other_port}'
+
+    for peer in peers:
+        peer.shutdown()
+
+
+@pytest.mark.forked
+def test_beam_search(dht_size=20, total_experts=128, batch_size=32, initial_peers=3, beam_size=4, parallel_rpc=16,
+                     grid_dims=(32, 32, 32)):
+    dht = []
+    for i in range(dht_size):
+        neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(dht, min(initial_peers, len(dht)))]
+        dht.append(hivemind.DHT(start=True, expiration=999999, initial_peers=neighbors_i, parallel_rpc=parallel_rpc))
+
+    real_experts = sorted({
+        'expert.' + '.'.join([str(random.randint(0, dim - 1)) for dim in grid_dims])
+        for _ in range(total_experts)
+    })
+    for batch_start in range(0, len(real_experts), batch_size):
+        random.choice(dht).declare_experts(
+            real_experts[batch_start: batch_start + batch_size], wait=True,
+            endpoint=f"host{batch_start // batch_size}:{random.randint(0, 65536)}")
+
+    neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(dht, min(initial_peers, len(dht)))]
+    you = hivemind.DHT(start=True, expiration=999999, initial_peers=neighbors_i, parallel_rpc=parallel_rpc)
+    beam_search = MoEBeamSearcher(you, 'expert.', grid_dims)
+
+    for i in range(50):
+        topk_experts = beam_search.find_best_experts([np.random.randn(dim) for dim in grid_dims], beam_size)
+        assert all(isinstance(e, hivemind.RemoteExpert) for e in topk_experts)
+        assert len(topk_experts) == beam_size
+
+    for i in range(10):
+        batch_experts = beam_search.batch_find_best_experts([np.random.randn(batch_size, dim) for dim in grid_dims],
+                                                            beam_size=beam_size)
+        assert isinstance(batch_experts, list) and len(batch_experts) == batch_size
+        assert all(isinstance(e, hivemind.RemoteExpert) for experts in batch_experts for e in experts)
+        assert all(len(experts) == beam_size for experts in batch_experts)
+
+
+@pytest.mark.forked
+def test_dht_single_node():
+    node = hivemind.DHT(start=True, expiration=999)
+    beam_search = MoEBeamSearcher(node, 'expert.')
+
+    assert all(node.declare_experts(['expert.1', 'expert.2', 'expert.3'], f"{hivemind.LOCALHOST}:1337").values())
+    assert len(node.declare_experts(["ffn.1", "ffn.2"], endpoint="that_place")) == 4
+    assert len(node.declare_experts(['e.1.2.3', 'e.1.2.5', 'e.2.0'], f"{hivemind.LOCALHOST}:42")) == 7
+
+    for expert in node.get_experts(['expert.3', 'expert.2']):
+        assert expert.endpoint == f"{hivemind.LOCALHOST}:1337"
+
+    assert all(node.declare_experts(['expert.5', 'expert.2'], f"{hivemind.LOCALHOST}:1337").values())
+    found_experts = beam_search.find_best_experts([(0., 1., 2., 3., 4., 5., 6., 7., 8.)], beam_size=2)
+    assert len(found_experts) == 2 and [expert.uid for expert in found_experts] == ['expert.5', 'expert.3']
+
+    successors = beam_search.get_active_successors(['e.1.2.', 'e.2.', 'e.4.5.'])
+    assert len(successors['e.1.2.']) == 2
+    assert successors['e.1.2.'][3] == UidEndpoint('e.1.2.3', f'{LOCALHOST}:42')
+    assert successors['e.1.2.'][5] == UidEndpoint('e.1.2.5', f'{LOCALHOST}:42')
+    assert len(successors['e.2.']) == 1 and successors['e.2.'][0] == UidEndpoint('e.2.0', f'{LOCALHOST}:42')
+    assert successors['e.4.5.'] == {}
+
+    initial_beam = beam_search.get_initial_beam((3, 2, 1, 0, -1, -2, -3), beam_size=3)
+    assert len(initial_beam) == 3
+    assert initial_beam[0][:2] == (2.0, 'expert.1.')
+    assert initial_beam[1][:2] == (1.0, 'expert.2.')
+    assert initial_beam[2][:2] == (0.0, 'expert.3.')
+
+    with pytest.raises(AssertionError):
+        beam_search = MoEBeamSearcher(node, 'expert.1.ffn')
+
+    with pytest.raises(AssertionError):
+        beam_search.get_active_successors(['e.1.2.', 'e.2', 'e.4.5.'])
+
+
+def test_uid_patterns():
+    valid_experts = ["expert.1", "expert.0", "expert.0.0.1", "expert.1337", "ffn.12.34.56.78.90",
+                     "transformer.3.2.1.0", "transformer_encoder.2", "transformer::encoder.2", "T®@nsf0rmE®🤗.321",
+                     "🤗.321", "0.1.2", "00.1.2", "7070.3.2.1.0", "block2.1.23", "LAYER.1.0.1"]
+    valid_prefixes = ["expert.", "e.1.", "e.2.", "e.1.2.3.", "ololo.123.456.789.10."]
+    valid_prefixes.extend([f"{uid}." for uid in valid_experts])
+    valid_prefixes.extend([split_uid(uid)[0] for uid in valid_experts])
+    for uid in valid_experts:
+        assert is_valid_uid(uid), f"UID {uid} is valid, but was perceived as invalid"
+    for pfx in valid_prefixes:
+        assert is_valid_prefix(pfx), f"Prefix {pfx} is valid, but was perceived as invalid"
+
+    invalid = ["", ".", "expert.-1", "xxx.a", "expert.1x", "expert_ffn.1.abc1", "some.123.01", "expert.123.01",
+               "e1", "e..1", "e", "e.1.2.3..4", "ffn.1..1", ".123", ".1.2.3.", ".expert", "transformer.encoder.2",
+               "T®@nsf0rmE®.🤗.321", "layer::123", "expert.0.1.2.suffix", "0.1.2.suffix", "expert.1 something",
+               "expert.1\n", "expert.1\n2", "expert.1 ", "expert.1\nexpert.2", "'expert.1'", '"expert.1"']
+    invalid_experts = invalid + valid_prefixes + ["0", "123456"]
+    invalid_prefixes = invalid + valid_experts + ["expert", ".🤗", ".expert"]
+    for uid in invalid_experts:
+        assert not is_valid_uid(uid), f"UID {uid} is not valid, but was perceived as valid"
+    for pfx in invalid_prefixes:
+        assert not is_valid_prefix(pfx), f"Prefix {pfx} is not valid, but was perceived as valid"
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_negative_caching():
+    peers = []
+    for i in range(10):
+        neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
+        peers.append(hivemind.DHT(initial_peers=neighbors_i, cache_locally=False, start=True))
+
+    writer_peer = random.choice(peers)
+    assert all(hivemind.declare_experts(writer_peer, ['ffn.1.2.3', 'ffn.3.4.5'], 'myaddr:1234').values())
+
+    neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
+    neg_caching_peer = hivemind.DHT(initial_peers=neighbors_i, cache_locally=False, start=True)
+    beam_search = MoEBeamSearcher(neg_caching_peer, uid_prefix='ffn.', negative_caching=True)
+    # get prefixes by the peer with negative caching. Cache "no data" entries for ffn.0.*, ffn.2.*, ffn.4.*, ffn.5.*
+    assert len(beam_search.get_initial_beam(scores=[.1, .2, .3, .4, .5, .6], beam_size=3)) == 2
+
+    node = await hivemind.DHTNode.create(initial_peers=neighbors_i)
+    fetched = await asyncio.gather(*(node.get(f'ffn.{i}.') for i in range(10)))
+    for i in range(6):
+        assert fetched[i] is not None, f"node should have cached ffn.{i}."
+    for i in range(6, len(fetched)):
+        assert fetched[i] is None, f"node shouldn't have cached ffn.{i}."

+ 4 - 5
tests/test_dht_node.py

@@ -1,15 +1,14 @@
 import asyncio
+import heapq
 import multiprocessing as mp
 import random
-import heapq
-from typing import Optional
+from itertools import product
+from typing import Optional, List, Dict
+
 import numpy as np
 import pytest
-from itertools import product
 
 import hivemind
-from typing import List, Dict
-
 from hivemind import get_dht_time, replace_port
 from hivemind.dht.node import DHTID, Endpoint, DHTNode, LOCALHOST
 from hivemind.dht.protocol import DHTProtocol, ValidationError

+ 3 - 3
tests/test_moe.py

@@ -109,10 +109,10 @@ def test_beam_search_correctness():
 
     for i in range(25):
         input = torch.randn(32)
-        grid_scores = dmoe.proj(input).split_with_sizes(dmoe.grid_size, dim=-1)
+        grid_scores = dmoe.proj(input).split_with_sizes(dmoe.beam_search.grid_size, dim=-1)
 
-        chosen_experts = dht.find_best_experts(dmoe.uid_prefix, [tensor.detach().numpy() for tensor in grid_scores],
-                                               beam_size=dmoe.k_best)
+        chosen_experts = dmoe.beam_search.find_best_experts([tensor.detach().numpy() for tensor in grid_scores],
+                                                            beam_size=dmoe.k_best)
         chosen_scores = dmoe.compute_expert_scores([dim_scores[None] for dim_scores in grid_scores],
                                                    [chosen_experts])[0]
         our_best_scores = list(chosen_scores.cpu().detach().numpy())