
Switch to new beam search in client (#111)

* remove test for first_k_active

* remove explicit test

* remove first_k_active

* switch test for beam search correctness

* cast grid_scores to numpy before serializing

* get rid of event loop in moe.py

* remove aio stub

* check for None result

* less verbose logging on error

* call many in separate thread

* grpc as_completed

* merge as_completed

* silence warning for mean/std

* update dht scheme

* update DHTNode description

* allow_broadcasting is dead

* avoid warning

* use valid expert names

* fix duplicates

* update benchmark_dht.py

* prefixes now end with '.'

* fix timeout

* staticmethod (@mryab)

* get rid of lambdas (@mryab)

* more realistic test for determinism

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic 4 years ago
parent
commit
621ac4c0e9
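
For context, a minimal usage sketch of the new client-side flow introduced by this commit (the "ffn_expert." prefix, endpoints, and grid size below are illustrative, not taken from the diff): beam search now runs inside hivemind.DHT, and RemoteMixtureOfExperts simply hands numpy grid scores to DHT.batch_find_best_experts / DHT.find_best_experts.

    import numpy as np
    import hivemind

    # start a standalone DHT peer and declare two experts under an illustrative prefix;
    # the expert endpoint is a placeholder, no server needs to be running for the DHT records
    dht = hivemind.DHT(listen_on="127.0.0.1:*", start=True)
    dht.declare_experts(["ffn_expert.0.3", "ffn_expert.2.1"], endpoint="127.0.0.1:1337")

    # one score vector per grid dimension; RemoteMixtureOfExperts casts its torch scores
    # to numpy before sending them across the process boundary
    grid_scores = [np.random.randn(4), np.random.randn(4)]
    best_experts = dht.find_best_experts("ffn_expert.", grid_scores, beam_size=2)
    print([expert.uid for expert in best_experts])  # up to beam_size RemoteExpert instances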

BIN
docs/_static/dht.odp


BIN
docs/_static/dht.png


+ 0 - 4
docs/modules/dht.rst

@@ -11,10 +11,6 @@ Here's a high level scheme of how these components interact with one another:
    :align: center
 
 
-**Note:** hivemind.DHT is currently being updated to improve beam search latency
-(see `issue 92 <https://github.com/learning-at-home/hivemind/issues>`__). New functionality will be documented
-here by 2020.10.15 23:59:59 AOE (ping justheuristic for details).
-
 DHT and DHTNode
 ###############
 

+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.utils import *
 
-__version__ = '0.8.7'
+__version__ = '0.8.8'

+ 3 - 8
hivemind/client/expert.py

@@ -9,7 +9,6 @@ import torch.nn as nn
 from torch.autograd.function import once_differentiable
 
 from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
-from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils import nested_flatten, nested_pack, nested_compare, Endpoint
 from hivemind.utils.grpc import serialize_torch_tensor, deserialize_torch_tensor
 
@@ -17,16 +16,12 @@ DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autogra
 
 
 @lru_cache(maxsize=None)
-def _get_expert_stub(endpoint: Endpoint, aio: bool, *extra_options: Tuple[str, Any]):
+def _get_expert_stub(endpoint: Endpoint, *extra_options: Tuple[str, Any]):
     """ Create a gRPC stub to access remote expert or use previously created stub from a process-wide cache """
     channel_options = [
         ('grpc.max_send_message_length', -1), ('grpc.max_receive_message_length', -1)
     ] + list(extra_options)
-    if aio:
-        channel = grpc.experimental.aio.insecure_channel(endpoint, options=channel_options)
-    else:
-        channel = grpc.insecure_channel(endpoint, options=channel_options)
-    return runtime_grpc.ConnectionHandlerStub(channel)
+    return runtime_grpc.ConnectionHandlerStub(grpc.insecure_channel(endpoint, options=channel_options))
 
 
 class RemoteExpert(nn.Module):
@@ -48,7 +43,7 @@ class RemoteExpert(nn.Module):
 
     @property
     def stub(self):
-        return _get_expert_stub(self.endpoint, aio=False)
+        return _get_expert_stub(self.endpoint)
 
     def forward(self, *args, **kwargs):
         """ Call RemoteExpert for the specified inputs and return its output(s). Compatible with pytorch.autograd. """

+ 129 - 185
hivemind/client/moe.py

@@ -1,10 +1,11 @@
 from __future__ import annotations
 
-import asyncio
 import time
-from typing import Tuple, List, Optional, Awaitable, Set, Dict, Any
+from queue import Queue, Empty
+from typing import Tuple, List, Optional, Dict, Any
+
+import grpc
 
-import grpc.experimental.aio
 import torch
 import torch.nn as nn
 from torch.autograd.function import once_differentiable
@@ -28,42 +29,37 @@ class RemoteMixtureOfExperts(nn.Module):
      the missing experts
 
     :param in_features: common input size for experts and gating function
-    :param grid_size: hivemind dimensions that form expert uid (see below)
-    :param uid_prefix: common prefix for all expert uids
-     expert uid follows the pattern {uid_prefix}.{0...grid_size[0]}.{0...grid_size[1]}...{0...grid_size[-1]}
-    :param dht: DHT where the experts reside
-    :param k_best: queries this many experts with highest scores
-    :param k_min: makes sure at least this many experts returned output
-    :param timeout_after_k_min: waits for this many seconds after k_min experts returned results.
+    :param grid_size: dimensions that form expert uid (see below)
+    :param uid_prefix: common prefix for all expert uids (must end with '.')
+    :note: expert uid follows the pattern {uid_prefix}.{0...grid_size[0]}.{0...grid_size[1]}...{0...grid_size[-1]}
+    :param dht: a DHT instance used to search for best experts
+    :param k_best: average this many highest-scoring experts to compute activations
+    :param k_min: make sure at least this many experts returned output (i.e. didn't fail)
+    :param timeout_after_k_min: wait for this many seconds after k_min experts returned results.
      Any expert that didn't manage to return output after that delay is considered unavailable
-    :param allow_broadcasting: if RemoteMixtureOfExperts if fed with input dimension above 2,
-     allow_broadcasting=True will flatten first d-1 input dimensions, apply RemoteMixtureOfExperts and un-flatten again
-     allow_broadcasting=False will raise an error
     """
 
-    def __init__(self, *, in_features, grid_size: Tuple[int, ...], dht: hivemind.DHT, k_best: int, k_min: int = 1,
-                 forward_timeout: Optional[float] = None, timeout_after_k_min: Optional[float] = None,
-                 backward_k_min: int = 1, backward_timeout: Optional[float] = None, uid_prefix='',
-                 allow_broadcasting=True, loop: asyncio.BaseEventLoop = None):
+    def __init__(self, *, in_features, grid_size: Tuple[int, ...], dht: hivemind.DHT, uid_prefix: str, k_best: int,
+                 k_min: int = 1, forward_timeout: Optional[float] = None, timeout_after_k_min: Optional[float] = None,
+                 backward_k_min: int = 1, backward_timeout: Optional[float] = None, **dht_kwargs):
         super().__init__()
-        self.dht, self.grid_size, self.uid_prefix = dht, grid_size, uid_prefix
-        self.loop = loop or asyncio.new_event_loop()
-        # fmt:off
-        assert not self.loop.is_running(), "Event loop is already running. If in jupyter, please apply nest_asyncio " \
-            "(pip install nest_asyncio , https://pypi.org/project/nest-asyncio ) and send loop=asyncio.new_event_loop()"
-        # fmt:on
+        if not uid_prefix.endswith(hivemind.dht.UID_DELIMITER):
+            uid_prefix += hivemind.dht.UID_DELIMITER
+            logger.info(f"Prefix must end with '{hivemind.dht.UID_DELIMITER}'. New prefix: '{uid_prefix}' .")
+        assert hivemind.dht.is_valid_prefix(uid_prefix), f"Prefix '{uid_prefix}' is invalid."
+        self.dht, self.grid_size, self.uid_prefix, self.dht_kwargs = dht, grid_size, uid_prefix, dht_kwargs
         self.k_best, self.k_min, self.backward_k_min = k_best, k_min, backward_k_min
         self.forward_timeout, self.backward_timeout = forward_timeout, backward_timeout
         self.timeout_after_k_min = timeout_after_k_min
-        self.allow_broadcasting = allow_broadcasting
 
         self.proj = nn.Linear(in_features, sum(grid_size))  # jointly predict logits for all grid dimensions
         self._expert_info = None  # expert['info'] from one of experts in the grid
 
     def forward(self, input: torch.Tensor, *args: torch.Tensor, **kwargs: torch.Tensor):
         """
-        Choose k best experts with beam search, then call chosen experts and average their outputs. Input tensor is averaged over all
-        dimensions except first and last (we assume that extra dimensions represent sequence length or image dimensions)
+        Choose k best experts with beam search, then call chosen experts and average their outputs.
+        Input tensor is averaged over all dimensions except for first and last
+        (we assume that extra dimensions represent sequence length or image height/width)
 
         :param input: a tensor of values that are used to estimate gating function, batch-first.
         :param args: extra positional parameters that will be passed to each expert after input, batch-first
@@ -78,18 +74,18 @@ class RemoteMixtureOfExperts(nn.Module):
         # 1. compute scores and find most appropriate experts with beam search
         grid_scores = self.proj(input_for_gating).split_with_sizes(self.grid_size, dim=-1)
 
-        async def _search():
-            coroutines = [asyncio.create_task(self.beam_search(
-                [dim_scores[i] for dim_scores in grid_scores], self.k_best))
-                for i in range(len(input))]
-            return list(await asyncio.gather(*coroutines))
+        chosen_experts: List[List[RemoteExpert]] = self.dht.batch_find_best_experts(
+            self.uid_prefix, [scores.detach().cpu().numpy() for scores in grid_scores], self.k_best, **self.dht_kwargs)
 
-        chosen_experts: List[List[RemoteExpert]] = self.loop.run_until_complete(_search())
-        # ^-- List[batch_size] of List[RemoteExpert] chosen for every input in batch
+        if self._expert_info is None:
+            try:
+                self._expert_info = next((expert.info for experts_i in chosen_experts for expert in experts_i))
+            except grpc.RpcError as e:
+                logger.warning(f"Failed to get RemoteMixtureOfExperts.output_shape: {e}")
 
         expert_mask, *expert_outputs = _RemoteCallMany.apply(
             DUMMY, chosen_experts, self.k_min, self.backward_k_min, self.timeout_after_k_min, self.forward_timeout,
-            self.backward_timeout, self.loop, self.info, *nested_flatten(((input, *args), kwargs)))
+            self.backward_timeout, self.info, *nested_flatten(((input, *args), kwargs)))
         # ^-- multiple tensors of shape [batch_size, max_experts, ...output_shape]
 
         expert_logits = self.compute_expert_scores(grid_scores, chosen_experts)
@@ -101,52 +97,6 @@ class RemoteMixtureOfExperts(nn.Module):
             for tensor in expert_outputs]  # ^-- multiply by softmax weights along first 2 axes
         return nested_pack(averaged_outputs_flat, self.info['outputs_schema'])
 
-    async def beam_search(self, grid_scores: List[torch.Tensor], k_best: int, **kwargs) -> List[RemoteExpert]:
-        """
-        Find and return k best experts in the grid using (exact) beam search of the product space
-
-        :param grid_scores: scores predicted for each dimension in the grid,
-        :type grid_scores: a sequence of tensors of shape[batch_size, self.grid_size[i]]
-        :param k_best: how many of the top experts participate in the computation
-        :param kwargs: extra keyword parameters passed to self.dht.first_k_active
-        :returns: a list of *batch_size* lists that contain chosen experts for one sample each inner list contains \
-         RemoteExpert instances for *up to* k_best experts
-        """
-        assert len(grid_scores) == len(self.grid_size)
-        assert all(dim_scores.shape == (self.grid_size[dim_index],) for dim_index, dim_scores in enumerate(grid_scores))
-        grid_scores = [dim_scores.cpu().detach() for dim_scores in grid_scores]
-
-        beam_experts: List[RemoteExpert] = []
-        beam: List[str] = [self.uid_prefix]
-        beam_scores = torch.zeros(1)
-
-        for dim_index, dim_scores in enumerate(grid_scores):
-            # create all possible successors from current beam and sort them by total score
-            expanded_scores = beam_scores[:, None] + dim_scores[None, :]
-            sorted_indices = [(flat_i // len(dim_scores), flat_i % len(dim_scores))
-                              for flat_i in (-expanded_scores).flatten().argsort().numpy()]
-
-            sorted_candidates = [f"{beam[row]}{self.dht.UID_DELIMITER}{col}" for row, col in sorted_indices]
-            candidate_to_indices = dict(zip(sorted_candidates, sorted_indices))
-
-            # select k best candidates according to scores but only those that are still active
-            best_alive_prefixes: Dict[str, RemoteExpert] = await self.dht.first_k_active(
-                uid_prefixes=sorted_candidates, k=k_best, return_future=True, **kwargs)
-            if not best_alive_prefixes:
-                logger.warning(f"Grid is empty: found neither of {sorted_candidates}")
-                break
-            beam = list(best_alive_prefixes.keys())
-            beam_scores = expanded_scores[tuple(zip(*map(candidate_to_indices.get, beam)))]
-            beam_experts = list(best_alive_prefixes.values())
-
-        if self._expert_info is None:
-            try:
-                self._expert_info = beam_experts[0].info
-            except grpc.RpcError as e:
-                logger.warning(f"Failed to get RemoteMixtureOfExperts.output_shape: {e}")
-
-        return beam_experts
-
     def compute_expert_scores(
             self, grid_scores: List[torch.Tensor], batch_experts: List[List[RemoteExpert]]) -> torch.Tensor:
         """
@@ -168,8 +118,8 @@ class RemoteMixtureOfExperts(nn.Module):
 
         grid_indices = torch.zeros([len(flat_experts), len(grid_scores)], dtype=torch.int64)
         for i, expert in enumerate(flat_experts):
-            expert_indices = expert.uid[len(self.uid_prefix) + len(self.dht.UID_DELIMITER):]
-            expert_indices = list(map(int, expert_indices.split(self.dht.UID_DELIMITER)))
+            expert_indices = expert.uid[len(self.uid_prefix):]
+            expert_indices = list(map(int, expert_indices.split(hivemind.dht.UID_DELIMITER)))
             grid_indices[i] = torch.as_tensor(expert_indices, dtype=grid_indices.dtype)
 
         scores_per_dim = [
@@ -206,117 +156,98 @@ class _RemoteCallMany(torch.autograd.Function):
     @classmethod
     def forward(cls, ctx, dummy, experts_per_sample: List[List[RemoteExpert]], k_min: int, backward_k_min: int,
                 timeout_after_k_min: float, forward_timeout: Optional[float], backward_timeout: Optional[float],
-                loop: asyncio.base_events.BaseEventLoop, info: Dict[str, Any], *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
+                info: Dict[str, Any], *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
         assert not torch.is_grad_enabled()
         num_samples, max_experts = len(experts_per_sample), max(map(len, experts_per_sample))
         flat_inputs_per_sample: List[Tuple[torch.Tensor, ...]] = list(zip(*(x.split(1, dim=0) for x in flat_inputs)))
         assert len(experts_per_sample) == len(flat_inputs_per_sample) == num_samples
 
-        async def _forward():
-            # dispatch tasks to all remote experts, await responses
-            pending_tasks = {
-                asyncio.create_task(cls._forward_one_expert((i, j), expert, info, flat_inputs_per_sample[i]))
-                for i in range(num_samples) for j, expert in enumerate(experts_per_sample[i])
-            }
-            alive_grid_indices, alive_flat_outputs = await cls._wait_for_responses(
-                pending_tasks, num_samples, k_min, forward_timeout, timeout_after_k_min)
-
-            # assemble responses
-            alive_ii, alive_jj = map(torch.as_tensor, zip(*alive_grid_indices))
-            mask = torch.zeros([num_samples, max_experts], dtype=torch.bool, device=flat_inputs[0].device)
-            mask[alive_ii, alive_jj] = True
-
-            alive_flat_outputs_stacked = list(map(torch.cat, zip(*alive_flat_outputs)))
-            # list of torch tensors, where i-th tensor is of shape [num_responded, *expert_outputs[i].shape]
-
-            outputs = []
-            for response_stacked in alive_flat_outputs_stacked:
-                output = torch.zeros(
-                    [num_samples, max_experts, *response_stacked.shape[1:]], device=response_stacked.device,
-                    dtype=response_stacked.dtype, requires_grad=response_stacked.requires_grad)
-                output[alive_ii, alive_jj] = response_stacked
-                outputs.append(output)
-
-            # save individual outputs for backward pass
-            ctx.save_for_backward(alive_ii, alive_jj, *flat_inputs)
-            ctx._saved_non_tensors = loop, info, backward_k_min, backward_timeout,\
-                                     timeout_after_k_min, experts_per_sample
-            return (mask,) + tuple(outputs)
-
-        return loop.run_until_complete(_forward())
+        # dispatch tasks to all remote experts, collect responses
+        pending_tasks: Dict[grpc.Future, Tuple[int, int]] = {}
+        for i in range(num_samples):
+            for j, expert in enumerate(experts_per_sample[i]):
+                input_tensors = [serialize_torch_tensor(tensor, proto.compression) for tensor, proto in zip(
+                                 flat_inputs_per_sample[i], nested_flatten(info['forward_schema']))]
+                stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint)
+                new_task = stub.forward.future(runtime_pb2.ExpertRequest(uid=expert.uid, tensors=input_tensors))
+                pending_tasks[new_task] = (i, j)
+
+        alive_grid_indices, alive_flat_outputs = cls._collect_responses(
+            pending_tasks, num_samples, k_min, forward_timeout, timeout_after_k_min)
+        if len(alive_grid_indices) == 0:
+            raise TimeoutError("Forward pass: no alive experts responded within timeout.")
+
+        # assemble responses
+        alive_ii, alive_jj = map(torch.as_tensor, zip(*alive_grid_indices))
+        mask = torch.zeros([num_samples, max_experts], dtype=torch.bool, device=flat_inputs[0].device)
+        mask[alive_ii, alive_jj] = True
+
+        alive_flat_outputs_stacked = list(map(torch.cat, zip(*alive_flat_outputs)))
+        # list of torch tensors, where i-th tensor is of shape [num_responded, *expert_outputs[i].shape]
+
+        outputs = []
+        for response_stacked in alive_flat_outputs_stacked:
+            output = torch.zeros(
+                [num_samples, max_experts, *response_stacked.shape[1:]], device=response_stacked.device,
+                dtype=response_stacked.dtype, requires_grad=response_stacked.requires_grad)
+            output[alive_ii, alive_jj] = response_stacked
+            outputs.append(output)
+
+        # save individual outputs for backward pass
+        ctx.save_for_backward(alive_ii, alive_jj, *flat_inputs)
+        ctx._saved_non_tensors = info, backward_k_min, backward_timeout, timeout_after_k_min, experts_per_sample
+        return (mask,) + tuple(outputs)
 
     @classmethod
     @once_differentiable
     def backward(cls, ctx, *raw_grads: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
         assert not torch.is_grad_enabled()
-        loop, info, backward_k_min, backward_timeout, timeout_after_k_min, expert_per_sample = ctx._saved_non_tensors
+        info, backward_k_min, backward_timeout, timeout_after_k_min, expert_per_sample = ctx._saved_non_tensors
         alive_ii, alive_jj, *flat_inputs = ctx.saved_tensors
         dummy_grad_mask, *flat_grad_outputs = raw_grads
         num_samples, max_experts = dummy_grad_mask.shape
 
         inputs_per_expert = zip(*(tensor[alive_ii].split(1, dim=0) for tensor in flat_inputs))
         grad_outputs_per_expert = zip(*(tensor[alive_ii, alive_jj].split(1, dim=0) for tensor in flat_grad_outputs))
-
-        async def _backward():
-            # dispatch tasks to all remote experts, await responses
-            pending_tasks = set()
-            for i, j, inputs_ij, grad_outputs_ij in zip(alive_ii.cpu().numpy(), alive_jj.cpu().numpy(),
-                                                        inputs_per_expert, grad_outputs_per_expert):
-                pending_tasks.add(asyncio.create_task(cls._backward_one_expert(
-                    (i, j), expert_per_sample[i.item()][j.item()], info, inputs_ij, grad_outputs_ij)))
-
-            backward_survivor_indices, survivor_grad_inputs = await cls._wait_for_responses(
-                pending_tasks, num_samples, backward_k_min, backward_timeout, timeout_after_k_min)
-
-            # assemble responses
-            backward_survivor_ii, backward_survivor_jj = map(torch.as_tensor, zip(*backward_survivor_indices))
-            survivor_grad_inputs_stacked = list(map(torch.cat, zip(*survivor_grad_inputs)))
-            # list of torch tensors, where i-th tensor is of shape [num_backward_survivors, *flat_inputs[i].shape]
-
-            grad_inputs = []
-            for i, survivor_grad_stacked in enumerate(survivor_grad_inputs_stacked):
-                grad_input_per_expert = torch.zeros(  # gradient tensor with individual contributions from each expert
-                    (num_samples, max_experts, *flat_inputs[i].shape[1:]),
-                    device=survivor_grad_stacked.device, dtype=survivor_grad_stacked.dtype)
-                grad_input_per_expert[backward_survivor_ii, backward_survivor_jj] = survivor_grad_stacked
-
-                grad_inputs.append(grad_input_per_expert.sum(dim=1))  # add up gradients from each expert
-
-            return (DUMMY, None, None, None, None, None, None, None, None, *grad_inputs)
-
-        return loop.run_until_complete(_backward())
-
-    @staticmethod
-    async def _forward_one_expert(
-            grid_indices: Tuple[int, ...], expert: RemoteExpert, info: Dict[str, Any], inputs: Tuple[torch.Tensor]):
-        stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint, aio=True)
-        try:
-            outputs = await stub.forward(runtime_pb2.ExpertRequest(
-                uid=expert.uid, tensors=[serialize_torch_tensor(tensor, proto.compression) for tensor, proto in 
-                                         zip(inputs, nested_flatten(info['forward_schema']))]))
-            return grid_indices, tuple(deserialize_torch_tensor(tensor) for tensor in outputs.tensors)
-        except grpc.experimental.aio.AioRpcError as error:
-            logger.warning(f"RemoteExpert {expert} failed forward: {error.code()} (inputs: {inputs})")
-
-    @staticmethod
-    async def _backward_one_expert(grid_indices: Tuple[int, ...], expert: RemoteExpert, info: Dict[str, Any],
-                                   inputs: Tuple[torch.Tensor], grad_outputs: Tuple[torch.Tensor]):
-        stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint, aio=True)
-        inputs_and_grad_outputs = tuple(nested_flatten((inputs, grad_outputs)))
         backward_schema = tuple(nested_flatten((info["forward_schema"], info["outputs_schema"])))
-        try:
-            grad_inputs = await stub.backward(runtime_pb2.ExpertRequest(
-                uid=expert.uid, tensors=[serialize_torch_tensor(tensor, proto.compression)
-                                         for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)]))
-            return grid_indices, tuple(deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors)
-        except grpc.experimental.aio.AioRpcError as error:
-            logger.warning(f"RemoteExpert {expert} failed backward: {error.code()} ({inputs}, {grad_outputs})")
+
+        # dispatch tasks to all remote experts, collect responses
+        pending_tasks = {}
+        for i, j, inputs_ij, grad_outputs_ij in zip(alive_ii.cpu().numpy(), alive_jj.cpu().numpy(),
+                                                    inputs_per_expert, grad_outputs_per_expert):
+            expert = expert_per_sample[i.item()][j.item()]
+            stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint)
+            inputs_and_grad_outputs = tuple(nested_flatten((inputs_ij, grad_outputs_ij)))
+            tensors_serialized = [serialize_torch_tensor(tensor, proto.compression)
+                                  for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)]
+            new_task = stub.backward.future(runtime_pb2.ExpertRequest(uid=expert.uid, tensors=tensors_serialized))
+            pending_tasks[new_task] = (i, j)
+
+        backward_survivor_indices, survivor_grad_inputs = cls._collect_responses(
+            pending_tasks, num_samples, backward_k_min, backward_timeout, timeout_after_k_min)
+        if len(backward_survivor_indices) == 0:
+            raise TimeoutError("Backward pass: no alive experts responded within timeout.")
+
+        # assemble responses
+        backward_survivor_ii, backward_survivor_jj = map(torch.as_tensor, zip(*backward_survivor_indices) or ([], []))
+        survivor_grad_inputs_stacked = list(map(torch.cat, zip(*survivor_grad_inputs)))
+        # list of torch tensors, where i-th tensor is of shape [num_backward_survivors, *flat_inputs[i].shape]
+
+        grad_inputs = []
+        for i, survivor_grad_stacked in enumerate(survivor_grad_inputs_stacked):
+            grad_input_per_expert = torch.zeros(  # gradient tensor with individual contributions from each expert
+                (num_samples, max_experts, *flat_inputs[i].shape[1:]),
+                device=survivor_grad_stacked.device, dtype=survivor_grad_stacked.dtype)
+            grad_input_per_expert[backward_survivor_ii, backward_survivor_jj] = survivor_grad_stacked
+
+            grad_inputs.append(grad_input_per_expert.sum(dim=1))  # add up gradients from each expert
+
+        return (DUMMY, None, None, None, None, None, None, None, *grad_inputs)
 
     @staticmethod
-    async def _wait_for_responses(
-            pending_tasks: Set[Awaitable[Tuple[Tuple[int, int], Tuple[torch.Tensor, ...]]]],
-            num_samples: int, k_min: int, timeout_total: Optional[float], timeout_after_k_min: Optional[float]
-    ) -> Tuple[List[Tuple[int, int]], List[Tuple[torch.Tensor, ...]]]:
+    def _collect_responses(task_to_indices: Dict[grpc.Future, Tuple[int, int]], num_samples: int, k_min: int,
+                           timeout_total: Optional[float], timeout_after_k_min: Optional[float]
+                           ) -> Tuple[List[Tuple[int, int]], List[Tuple[torch.Tensor, ...]]]:
         """ await up to k_min results and any result submitted within timeout_after_k_min, cancel stragglers """
         timeout_total = float('inf') if timeout_total is None else timeout_total
         timeout_after_k_min = float('inf') if timeout_after_k_min is None else timeout_after_k_min
@@ -324,24 +255,37 @@ class _RemoteCallMany(torch.autograd.Function):
         pending_samples = num_samples  # samples for which we have less than k_min results
         finished_indices, finished_outputs = [], []
         t_finish = time.perf_counter() + timeout_total
+        pending_tasks = set(task_to_indices.keys())
+        finished_tasks = Queue()
 
-        while pending_tasks and time.perf_counter() <= t_finish:
-            finished_tasks, pending_tasks = await asyncio.wait(pending_tasks, return_when=asyncio.FIRST_COMPLETED,
-                                                               timeout=t_finish - time.perf_counter())
-            for task in finished_tasks:
-                if not task.result():
+        try:
+            # the algorithm below is essentially futures.as_completed, but for grpc.Future
+            for task in pending_tasks:
+                task.add_done_callback(finished_tasks.put)
+
+            for _ in range(len(task_to_indices)):
+                timeout = max(0.0, t_finish - time.perf_counter()) if t_finish != float('inf') else None
+                task = finished_tasks.get(timeout=timeout)
+                pending_tasks.discard(task)
+
+                if task.exception() or task.cancelled():
+                    logger.warning(f"Task {task} failed: {type(task.exception())}")
                     continue
-                task_indices, task_flat_outputs = await task
-                finished_indices.append(task_indices)
-                finished_outputs.append(task_flat_outputs)
 
-                sample_index = task_indices[0]
+                finished_indices.append(task_to_indices[task])
+                finished_outputs.append(tuple(deserialize_torch_tensor(tensor) for tensor in task.result().tensors))
+
+                # count how many successes we have for each input sample
+                sample_index = task_to_indices[task][0]
                 num_successful_tasks[sample_index] += 1
                 if num_successful_tasks[sample_index] == k_min:
                     pending_samples -= 1
                     if pending_samples <= 0:  # all tasks finished, await stragglers for at most timeout_after_k_min
                         t_finish = min(t_finish, time.perf_counter() + timeout_after_k_min)
 
-        for task in pending_tasks:
-            task.cancel()
+        except Empty:
+            pass  # we reached t_finish, this is normal behavior
+        finally:
+            for task in pending_tasks:
+                task.cancel()
         return finished_indices, finished_outputs
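
The _collect_responses helper above is essentially futures.as_completed rebuilt for grpc.Future. Below is a runnable stand-in of the same pattern using concurrent.futures, whose Future exposes the same add_done_callback / result / exception / cancel interface; the slow_call worker and the one-second deadline are illustrative.

    import time
    from concurrent.futures import ThreadPoolExecutor
    from queue import Queue, Empty

    def slow_call(i):
        time.sleep(0.1 * i)
        return i * i

    with ThreadPoolExecutor(max_workers=4) as pool:
        task_to_index = {pool.submit(slow_call, i): i for i in range(4)}
        finished = Queue()
        for task in task_to_index:
            task.add_done_callback(finished.put)   # same hook that grpc.Future provides

        deadline, results = time.perf_counter() + 1.0, {}
        try:
            for _ in range(len(task_to_index)):
                task = finished.get(timeout=max(0.0, deadline - time.perf_counter()))
                if not task.cancelled() and task.exception() is None:
                    results[task_to_index[task]] = task.result()
        except Empty:
            pass  # deadline reached: remaining tasks are treated as stragglers
        finally:
            for task in task_to_index:
                task.cancel()  # cancel stragglers, as _collect_responses does above

    print(results)  # e.g. {0: 0, 1: 1, 2: 4, 3: 9} when every call beats the deadline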

+ 242 - 183
hivemind/dht/__init__.py

@@ -16,26 +16,53 @@ import asyncio
 import ctypes
 import heapq
 import multiprocessing as mp
+import re
 import warnings
-from collections import deque, OrderedDict
+from collections import deque
 from concurrent.futures import ThreadPoolExecutor
-from itertools import chain
-from typing import List, Tuple, Optional, Sequence, OrderedDict as TOrderedDict, Union, Awaitable, Dict, Deque, Set
+from typing import List, Tuple, Optional, Sequence, Union, Dict, Deque, NamedTuple, Iterator, Set
 
 import uvloop
 
 from hivemind.client import RemoteExpert
 from hivemind.dht.node import DHTNode, DHTID, DHTExpiration
-from hivemind.dht.routing import get_dht_time
+from hivemind.dht.routing import get_dht_time, DHTValue
+from hivemind.dht.storage import ValueWithExpiration
 from hivemind.utils import MPFuture, Endpoint, get_logger
 
 logger = get_logger(__name__)
 
+ExpertUID, ExpertPrefix, Coordinate, Score = str, str, int, float
+UidEndpoint = NamedTuple("UidEndpoint", [('uid', ExpertUID), ('endpoint', Endpoint)])
+UID_DELIMITER = '.'  # when declaring experts, DHT stores all prefixes of that expert's uid, split over this delimiter
+FLAT_EXPERT = -1     # grid prefix reserved for storing 1d expert uids. Used to speed up find_best_experts in 1d case.
+UID_PATTERN = re.compile('^(([^.])+)([.](?:[0]|([1-9]([0-9]*))))+$')  # e.g. ffn_expert.98.76.54 - prefix + some dims
+PREFIX_PATTERN = re.compile('^(([^.])+)([.](?:[0]|([1-9]([0-9]*))))*[.]$')  # e.g. expert. or ffn.45. (ends with ".")
+#  formally, prefixes = {uid.split(UID_DELIMITER)[:length] for length in range(1, uid.count(UID_DELIMITER) + 2)}
+
+
+def is_valid_uid(maybe_uid: str) -> bool:
+    return bool(UID_PATTERN.fullmatch(maybe_uid))
+
+
+def is_valid_prefix(maybe_prefix: str) -> bool:
+    return bool(PREFIX_PATTERN.fullmatch(maybe_prefix))
+
+
+def split_uid(uid_or_prefix: Union[ExpertUID, ExpertPrefix]) -> Tuple[ExpertPrefix, Coordinate]:
+    """ Separate an expert UID or prefix into a new ExpertPrefix and integer for the last coordinate """
+    uid_or_prefix = uid_or_prefix.rstrip(UID_DELIMITER)
+    pivot = uid_or_prefix.rindex(UID_DELIMITER) + 1
+    return uid_or_prefix[:pivot], int(uid_or_prefix[pivot:])
+
 
 class DHT(mp.Process):
     """
     High-level interface to hivemind.dht that is designed to allow RemoteMixtureOfExperts to select best experts.
 
+    * hivemind servers periodically announce their experts via DHT.declare_experts
+    * trainers find most suitable experts via DHT.find_best_experts
+
     :param initial_peers: one or multiple endpoints pointing to active DHT peers. Similar format to listen_on.
     :param listen_on: an interface for incoming connections, e.g. "127.0.0.1:*", "0.0.0.0:1234" or "ipv6:[::]:*"
     :param start: if True, automatically starts the background process on creation. Otherwise await manual start
@@ -55,7 +82,10 @@ class DHT(mp.Process):
     A hivemind.Server can ``DHT.declare_experts(expert_uids: List[str])`` to make its experts visible to everyone.
     When declaring experts, DHT will store each expert's uid and all its prefixes until :expiration: (specified at init)
     For instance, declaring "ffn_expert.98.76.54.32.10" will store the following keys in a DHT:
-    ``"ffn_expert", "ffn_expert.98", "ffn_expert.98.76", ..., "ffn_expert.98.76.54.32.10"``
+    ``"ffn_expert.98", "ffn_expert.98.76", "ffn_expert.98.76.54", ..., "ffn_expert.98.76.54.32.10"``
+
+    In order to enable fast beam search, DHT maintains dictionaries of all active suffixes for every prefix
+    (e.g. "ffn_expert.98": {76: ffn_expert.98.76...., 123: ffn_expert.98.123..., 225: ffn_expert.98.225....})
 
     RemoteMixtureOfExperts can use these prefixes to find top-k most suitable experts with a left-to-right beam search.
     For instance, consider RemoteMixtureOfExperts with prefix "ffn_expert" and grid size [100, 100, 100, 100, 100].
@@ -63,14 +93,12 @@ class DHT(mp.Process):
     However, not every expert in such 100^5 grid can be alive at a given moment of time (the grid size is redundant).
     In order to find k best "alive" experts, MoE first ranks indices along the first dimension with its gating function.
     It can then check which of those indices correspond to "alive" experts by querying keys such as "ffn_expert.98".
-    This is done using DHT.first_k_active function. After selecting k best indices along first dimension, MoE moves
-    to the second dimension. It can find top-k pairs of indices (e.g. "expert.98.76") that start with one of k first
-    indices from the previous step. Finally, MoE will use DHT.get_experts(uids: List[str]) search for specific experts.
+
+    After selecting k best indices along first dimension, MoE moves to the second dimension.
+    It can find top-k index pairs (e.g. "expert.98.76") that use one of k best indices from the previous step.
     This beam search explores one additional dimension per step and finds k best experts from across the DHT
-    in O(k / s * log(N)) average time where s is grid sparsity rate and N is the total number of experts.
+    in O(k * num_dimensions * dimension_size) time depending on the chosen grid dimensions.
     """
-    UID_DELIMITER = '.'  # when declaring experts, DHT store all prefixes of that expert's uid, split over this prefix
-    #  formally, prefixes = {uid.split(UID_DELIMITER)[:length] for length in range(1, uid.count(UID_DELIMITER) + 2)}
 
     def __init__(self, listen_on: Endpoint = "0.0.0.0:*", initial_peers: Sequence[Endpoint] = (), *, start: bool,
                  daemon: bool = True, max_workers: Optional[int] = None, parallel_rpc: Optional[int] = None,
@@ -129,31 +157,8 @@ class DHT(mp.Process):
     def port(self) -> Optional[int]:
         return self._port.value if self._port.value != 0 else None
 
-    def get_experts(self, uids: List[str], expiration_time: Optional[DHTExpiration] = None,
-                    return_future=False) -> List[Optional[RemoteExpert]]:
-        """
-        :param uids: find experts with these ids from across the DHT
-        :param expiration_time: if specified, return experts that expire no sooner than this (based on get_dht_time)
-        :param return_future: if False (default), return when experts are returned. Otherwise return MPFuture.
-        :returns: a list of [RemoteExpert if found else None]
-        """
-        assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
-        future, _future = MPFuture.make_pair()
-        self.pipe.send(('_get_experts', [], dict(uids=uids, expiration_time=expiration_time, future=_future)))
-        return future if return_future else future.result()
-
-    async def _get_experts(
-            self, node: DHTNode, uids: List[str], expiration_time: Optional[DHTExpiration], future: MPFuture):
-        if expiration_time is None:
-            expiration_time = get_dht_time()
-        num_workers = len(uids) if self.max_workers is None else min(len(uids), self.max_workers)
-        response = await node.get_many(uids, expiration_time, num_workers=num_workers)
-        # TODO expert_data['expert'] -> namedtuple with meaningful field names
-        future.set_result([RemoteExpert(*expert_data.value['expert'].value)
-                           if expert_data is not None and 'expert' in expert_data.value else None
-                           for uid, expert_data in response.items()])
-
-    def declare_experts(self, uids: List[str], endpoint: Endpoint, wait=True, timeout=None) -> Optional[List[bool]]:
+    def declare_experts(self, uids: Sequence[ExpertUID], endpoint: Endpoint, wait: bool = True,
+                        timeout: Optional[float] = None) -> Dict[ExpertUID, bool]:
         """
         Make experts visible to all DHT peers; update timestamps if declared previously.
 
@@ -161,38 +166,151 @@ class DHT(mp.Process):
         :param endpoint: endpoint that serves these experts, usually your server endpoint (e.g. "201.111.222.333:1337")
         :param wait: if True, awaits for declaration to finish, otherwise runs in background
         :param timeout: waits for the procedure to finish for up to this long, None means wait indefinitely
-        :returns: if wait, returns a list of booleans, (True = store succeeded, False = store rejected)
+        :returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
         """
         assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
+        for uid in uids:
+            assert is_valid_uid(uid), f"{uid} is not a valid expert uid. All uids must follow {UID_PATTERN.pattern}"
         future, _future = MPFuture.make_pair() if wait else (None, None)
         self.pipe.send(('_declare_experts', [], dict(uids=list(uids), endpoint=endpoint, future=_future)))
         if wait:
             return future.result(timeout)
 
-    async def _declare_experts(self, node: DHTNode, uids: List[str], endpoint: Endpoint, future: Optional[MPFuture]):
+    async def _declare_experts(self, node: DHTNode, uids: List[ExpertUID], endpoint: Endpoint,
+                               future: Optional[MPFuture]) -> Dict[ExpertUID, bool]:
         num_workers = len(uids) if self.max_workers is None else min(len(uids), self.max_workers)
         expiration_time = get_dht_time() + self.expiration
-        unique_entries: Set[Tuple[str, str]] = set()
-        #                 prefix---v next_dim     uid  endpoint
-        data_to_store: List[Tuple[str, str, List[str, Endpoint]]] = []
-        for uid in uids:  # first k entries are expert uids themselves
-            data_to_store.append((uid, "expert", [uid, endpoint]))
-        for uid in uids:  # and then, add all prefixes
-            uid_parts = uid.split(self.UID_DELIMITER)
-            for i in range(len(uid_parts) - 1):
-                uid_prefix_i = self.UID_DELIMITER.join(uid_parts[:i + 1])
-                if (uid_prefix_i, uid_parts[i + 1]) in unique_entries:
-                    continue
-                unique_entries.add((uid_prefix_i, uid_parts[i + 1]))
-                data_to_store.append((uid_prefix_i, uid_parts[i + 1], [uid, endpoint]))
-
-        keys, subkeys, values = map(list, zip(*data_to_store))
-        store_ok = await node.store_many(keys, values, expiration_time, subkeys=subkeys, num_workers=num_workers)
+        data_to_store: Dict[Tuple[ExpertPrefix, Optional[Coordinate]], DHTValue] = {}
+        for uid in uids:
+            data_to_store[uid, None] = endpoint
+            prefix = uid if uid.count(UID_DELIMITER) > 1 else f'{uid}{UID_DELIMITER}{FLAT_EXPERT}'
+            for i in range(prefix.count(UID_DELIMITER) - 1):
+                prefix, last_coord = split_uid(prefix)
+                data_to_store[prefix, last_coord] = [uid, endpoint]
+
+        keys, maybe_subkeys, values = zip(*((key, subkey, value) for (key, subkey), value in data_to_store.items()))
+        store_ok = await node.store_many(keys, values, expiration_time, subkeys=maybe_subkeys, num_workers=num_workers)
         if future is not None:
-            future.set_result([store_ok[key, subkey] for key, subkey in zip(keys, subkeys)])
+            future.set_result(store_ok)
+        return store_ok
+
+    def get_experts(self, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration] = None,
+                    return_future: bool = False) -> List[Optional[RemoteExpert]]:
+        """
+        :param uids: find experts with these ids from across the DHT
+        :param expiration_time: if specified, return experts that expire no sooner than this (based on get_dht_time)
+        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
+        :returns: a list of [RemoteExpert if found else None]
+        """
+        assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
+        future, _future = MPFuture.make_pair()
+        self.pipe.send(('_get_experts', [], dict(uids=list(uids), expiration_time=expiration_time, future=_future)))
+        return future if return_future else future.result()
 
-    def find_best_experts(self, prefix: str, grid_scores: Sequence[Sequence[float]], beam_size: int, *,
-                          return_future=False, **kwargs) -> Union[List[RemoteExpert], MPFuture]:
+    async def _get_experts(self, node: DHTNode, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration],
+                           future: Optional[MPFuture] = None) -> List[Optional[RemoteExpert]]:
+        if expiration_time is None:
+            expiration_time = get_dht_time()
+        num_workers = len(uids) if self.max_workers is None else min(len(uids), self.max_workers)
+        found: Dict[ExpertUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
+
+        experts: List[Optional[RemoteExpert]] = [None] * len(uids)
+        for i, uid in enumerate(uids):
+            if found[uid] is not None and isinstance(found[uid].value, Endpoint):
+                experts[i] = RemoteExpert(uid, found[uid].value)
+        if future:
+            future.set_result(experts)
+        return experts
+
+    def get_initial_beam(self, prefix: ExpertPrefix, scores: Sequence[float], beam_size: int,
+                         num_workers: Optional[int] = None, return_future: bool = False
+                         ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
+        """
+        :param prefix: search for experts whose uids start with this prefix
+        :param scores: prefer suffix coordinates that have highest scores
+        :param beam_size: select this many active suffixes with highest scores
+        :param num_workers: maintain up to this many concurrent DHT searches
+        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
+        :returns: a list of up to beam_size tuples of (prefix score, prefix itself, dict{suffix: example expert})
+        """
+        assert is_valid_prefix(prefix), f"prefix '{prefix}' is invalid, it must follow {PREFIX_PATTERN.pattern}"
+        future, _future = MPFuture.make_pair()
+        self.pipe.send(('_get_initial_beam', [], dict(prefix=prefix, scores=tuple(scores), beam_size=beam_size,
+                                                      num_workers=num_workers, future=_future)))
+        return future if return_future else future.result()
+
+    async def _get_initial_beam(self, node, prefix: ExpertPrefix, beam_size: int, scores: Tuple[float, ...],
+                                num_workers: Optional[int] = None, future: Optional[MPFuture] = None
+                                ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
+        num_workers = num_workers or self.max_workers or beam_size
+        beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = []
+        unattempted_indices: List[Coordinate] = sorted(range(len(scores)), key=scores.__getitem__)  # from worst to best
+        pending_tasks: Deque[Tuple[Coordinate, ExpertPrefix, asyncio.Task]] = deque()
+
+        while len(beam) < beam_size and (unattempted_indices or pending_tasks):
+            # dispatch additional tasks
+            while unattempted_indices and len(pending_tasks) < num_workers:
+                next_index = unattempted_indices.pop()  # note: this is best unattempted index because of sort order
+                next_best_prefix = f"{prefix}{next_index}{UID_DELIMITER}"
+                pending_tasks.append((next_index, next_best_prefix, asyncio.create_task(node.get(next_best_prefix))))
+
+            # await the next best prefix to be fetched
+            pending_best_index, pending_best_prefix, pending_task = pending_tasks.popleft()
+            try:
+                maybe_prefix_data = await pending_task
+                if maybe_prefix_data is not None and isinstance(maybe_prefix_data.value, dict):
+                    successors = {coord: UidEndpoint(*match.value) for coord, match in maybe_prefix_data.value.items()
+                                  if isinstance(coord, Coordinate) and isinstance(getattr(match, 'value', None), list)
+                                  and len(match.value) == 2}
+                    beam.append((scores[pending_best_index], pending_best_prefix, successors))
+            except asyncio.CancelledError:
+                for _, pending_task in pending_tasks:
+                    pending_task.cancel()
+                raise
+        if future:
+            future.set_result(beam)
+        return beam
+
+    def get_active_successors(self, prefixes: List[ExpertPrefix], grid_size: Optional[int] = None,
+                              num_workers: Optional[int] = None, return_future: bool = False
+                              ) -> Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]]:
+        """
+        :param prefixes: a list of prefixes for which to find active successor uids
+        :param grid_size: if specified, only return successors if they are in range [0, grid_size)
+        :param num_workers: how many parallel workers to use for DHTNode.get_many
+        :param return_future: if False (default), find and return successors. Otherwise return MPFuture and fill later.
+        :returns: for every expert, return a dict{active_next_coordinate: (matching_expert_uid, matching_endpoint)}
+        :note: if a prefix is not found, get_active_successors will return an empty dictionary for that prefix
+        """
+        assert not isinstance(prefixes, str), "Please send a list / tuple of expert prefixes."
+        for prefix in prefixes:
+            assert is_valid_prefix(prefix), f"prefix '{prefix}' is invalid, it must follow {PREFIX_PATTERN.pattern}"
+        future, _future = MPFuture.make_pair()
+        self.pipe.send(('_get_active_successors', [], dict(
+            prefixes=list(prefixes), grid_size=grid_size, num_workers=num_workers, future=_future)))
+        return future if return_future else future.result()
+
+    async def _get_active_successors(self, node: DHTNode, prefixes: List[ExpertPrefix], grid_size: Optional[int] = None,
+                                     num_workers: Optional[int] = None, future: Optional[MPFuture] = None
+                                     ) -> Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]]:
+        grid_size = grid_size or float('inf')
+        num_workers = num_workers or min(len(prefixes), self.max_workers or len(prefixes))
+        dht_responses = await node.get_many(keys=prefixes, num_workers=num_workers)
+        successors: Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]] = {}
+        for prefix, found in dht_responses.items():
+            if found and isinstance(found.value, dict):
+                successors[prefix] = {coord: UidEndpoint(*match.value) for coord, match in found.value.items()
+                                      if isinstance(coord, Coordinate) and 0 <= coord < grid_size
+                                      and isinstance(getattr(match, 'value', None), list) and len(match.value) == 2}
+            else:
+                successors[prefix] = {}
+        if future:
+            future.set_result(successors)
+        return successors
+
+    def find_best_experts(self, prefix: ExpertPrefix, grid_scores: Sequence[Sequence[float]], beam_size: int,
+                          num_workers: Optional[int] = None, return_future: bool = False
+                          ) -> Union[List[RemoteExpert], MPFuture]:
         """
         Find and return :beam_size: active experts with highest scores, use both local cache and DHT
 
@@ -203,174 +321,115 @@ class DHT(mp.Process):
          After time_budget is reached, beam search won't search for more experts and instead fall back on local cache
          Please note that any queries that fall outside the budget will still be performed in background and cached
          for subsequent iterations as long as DHTNode.cache_locally is True
+        :param num_workers: use up to this many concurrent workers to search DHT
         :param return_future: if set to True, returns MPFuture that can be awaited to get the actual result
-        :param kwargs: extra keyword parameters passed to DHTNode.get_many
         :returns: a list that contains *up to* k_best RemoteExpert instances
         """
+        assert len(grid_scores) > 0 and beam_size > 0
+        assert is_valid_prefix(prefix), f"prefix '{prefix}' is invalid, it must follow {PREFIX_PATTERN.pattern}"
         future, _future = MPFuture.make_pair()
         self.pipe.send(('_find_best_experts', [], dict(prefix=prefix, grid_scores=list(map(tuple, grid_scores)),
-                                                       beam_size=beam_size, future=_future, **kwargs)))
+                                                       beam_size=beam_size, num_workers=num_workers, future=_future)))
         return future if return_future else future.result()
 
     async def _find_best_experts(
             self, node: DHTNode, prefix: str, grid_scores: List[Tuple[float]], beam_size: int,
-            max_workers: Optional[int] = None, future: Optional[MPFuture] = None, **kwargs) -> List[RemoteExpert]:
-        max_workers: Optional[int] = max_workers or self.max_workers or beam_size
+            num_workers: Optional[int] = None, future: Optional[MPFuture] = None, **kwargs) -> List[RemoteExpert]:
+        num_workers = num_workers or min(beam_size, self.max_workers or beam_size)
 
         # form initial beam from top-k active L1 prefixes, each row is (score, uid prefix, possible suffixes)
-        beam: List[Tuple[float, str, Dict[str, List[str, Endpoint]]]] = await self._get_initial_beam(
-            node, prefix, beam_size, grid_scores[0], num_workers=min(beam_size, max_workers))
-        if not beam:
-            logger.warning(f"Beam search had to terminate prematurely because of empty beam (dim 0)")
-            return []
-        # TODO warn user if indices are out of range on the _last_ level! (rationale: beam search may return <k results)
+        beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = await self._get_initial_beam(
+            node, prefix, beam_size, grid_scores[0], min(beam_size, num_workers))
+
+        best_experts_heap: List[Tuple[Score, UidEndpoint]] = []  # max-heap of expert uids/endpoints ordered by scores
+        unique_experts: Set[ExpertUID] = set()
 
         for dim_index in range(1, len(grid_scores) - 1):
-            # select beam_size best suffixes from current beam
+            for score, uid_endpoint in self._iterate_matching_experts(beam, grid_scores):
+                if uid_endpoint.uid not in unique_experts:
+                    push_and_maybe_pop = heapq.heappush if len(best_experts_heap) < beam_size else heapq.heappushpop
+                    push_and_maybe_pop(best_experts_heap, (score, uid_endpoint))
+                    unique_experts.add(uid_endpoint.uid)
+
+            # form new beam using successors from the current beam
             dim_scores = grid_scores[dim_index]
-            best_active_pairs: List[Tuple[float, str]] = heapq.nlargest(beam_size, (
-                (prefix_score + dim_scores[int(suffix_i)], f"{prefix}{self.UID_DELIMITER}{suffix_i}")
-                for prefix_score, prefix, suffixes in beam for suffix_i in suffixes.keys()
-                # TODO get rid of str.isdecimal
-                if str.isdecimal(suffix_i) and 0 <= int(suffix_i) < len(dim_scores)))
+            best_active_pairs: List[Tuple[Score, ExpertPrefix]] = heapq.nlargest(beam_size, (
+                (prefix_score + dim_scores[next_coord], f"{prefix}{next_coord}{UID_DELIMITER}")
+                for prefix_score, prefix, suffixes in beam for next_coord in suffixes.keys()
+                if isinstance(next_coord, int) and 0 <= next_coord < len(dim_scores)))
+            _, best_uid_prefixes = zip(*best_active_pairs)
 
             # search DHT for next step suffixes
-            _, best_uid_prefixes = zip(*best_active_pairs)
-            # TODO Tuple[Dict[str, List[str, Endpoint]], DHTExpiration] -> namedtuple
-            dht_responses: Dict[str, Tuple[Dict[str, List[str, Endpoint]], DHTExpiration]] = await node.get_many(
-                keys=best_uid_prefixes, num_workers=min(len(best_uid_prefixes), max_workers), **kwargs)
-            if all(expiration is None for key, (_, expiration) in dht_responses.items()):
-                logger.warning(f"Beam search had to terminate prematurely because of empty beam (dim {dim_index})")
+            successors = await self._get_active_successors(node, best_uid_prefixes, num_workers=num_workers)
+            beam = [(score, prefix, successors[prefix]) for score, prefix in best_active_pairs if successors[prefix]]
+            if not beam:
+                logger.warning(f"Beam search had to terminate prematurely because of empty beam (dim {dim_index})")
                 break
-            beam = [(prefix_score, prefix, dht_responses[prefix][0])  # add suffix dict if it is found
-                    for prefix_score, prefix in best_active_pairs if dht_responses[prefix][1] is not None]
-
-        # select best experts from the final beam
-        dim_scores = grid_scores[-1]
-        # TODO use heap to harness all results, get rid of five-line expression
-        final_best_pairs: List[Tuple[float, str, Endpoint]] = heapq.nlargest(beam_size, chain((
-            (prefix_score + dim_scores[int(suffix_i)], uid, endpoint)
-            for prefix_score, prefix, suffixes in beam for suffix_i, ((uid, endpoint), _) in suffixes.items()
-            if str.isdecimal(suffix_i) and 0 <= int(suffix_i) < len(dim_scores)
-        ), ((score, *suffixes['expert']) for score, _, suffixes in beam if 'expert' in suffixes)))
-        best_experts = [RemoteExpert(uid, endpoint) for score, uid, endpoint in final_best_pairs]
+
+        # add best experts from the final beam
+        for score, uid_endpoint in self._iterate_matching_experts(beam, grid_scores):
+            if uid_endpoint.uid not in unique_experts:
+                push_and_maybe_pop = heapq.heappush if len(best_experts_heap) < beam_size else heapq.heappushpop
+                push_and_maybe_pop(best_experts_heap, (score, uid_endpoint))
+                unique_experts.add(uid_endpoint.uid)
+
+        best_experts = [RemoteExpert(*uid_endpoint) for score, uid_endpoint in sorted(best_experts_heap, reverse=True)]
         if future is not None:
             future.set_result(best_experts)
         return best_experts
 
-    def batch_find_best_experts(self, prefix: str, batch_grid_scores: Sequence[Sequence[Sequence[float]]], beam_size: int, *,
-                                return_future=False, **kwargs) -> Union[List[RemoteExpert], MPFuture]:
+    @staticmethod
+    def _iterate_matching_experts(beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]],
+                                  grid_scores: Sequence[Sequence[float]]) -> Iterator[Tuple[Score, UidEndpoint]]:
+        """ iterate over all exemplar experts attached to current beam """
+        for score, prefix, suffixes in beam:
+            for next_coord, match in suffixes.items():
+                if len(grid_scores) == 1 and next_coord == FLAT_EXPERT:
+                    yield score, match
+                elif isinstance(match.uid, ExpertUID) and match.uid.count(UID_DELIMITER) == len(grid_scores):
+                    expert_coords = match.uid.split(UID_DELIMITER)[1:]
+                    if all(coord.isdigit() and 0 <= int(coord) < len(grid_scores[i])
+                           for i, coord in enumerate(expert_coords)):
+                        expert_score = sum(scores[coord] for scores, coord in zip(grid_scores, map(int, expert_coords)))
+                        yield expert_score, match
+                    else:
+                        logger.warning(f"Found incompatible expert coordinates: {expert_coords}")
+                else:
+                    logger.warning(f"Found incompatible expert UID: {match.uid}")
+
+    def batch_find_best_experts(
+            self, prefix: str, batch_grid_scores: Sequence[Sequence[Sequence[float]]], beam_size: int, *,
+            workers_per_sample: Optional[int] = None, return_future=False) -> Union[List[List[RemoteExpert]], MPFuture]:
         """
         For every sample in batch, find and return :beam_size: active experts with the highest scores, using both local cache and DHT
 
         :param prefix: common prefix for all expert uids in grid
         :param batch_grid_scores: scores predicted for each batch example and each dimension in the grid,
-        :type batch_grid_scores: model scores for each example and each grid dimension,  list of arrays of shape (batch_size, grid_size[i])
+        :type batch_grid_scores: list of arrays of shape (batch_size, grid_size[i])
         :param beam_size: how many best experts should beam search return
          Please note that any DHT queries still pending when beam search returns will keep running in the background
          and will be cached for subsequent iterations as long as DHTNode.cache_locally is True
+        :param workers_per_sample: use up to this many concurrent workers for every sample in batch
         :param return_future: if set to True, returns MPFuture that can be awaited to get the actual result
-        :param kwargs: extra keyword parameters passed to DHTNode.get_many
         :returns: a list of lists; for every sample in batch, *up to* beam_size RemoteExpert instances
         """
         future, _future = MPFuture.make_pair()
         self.pipe.send(('_batch_find_best_experts', [], dict(prefix=prefix, batch_grid_scores=batch_grid_scores,
-                                                             beam_size=beam_size, future=_future, **kwargs)))
+                                                             beam_size=beam_size, workers_per_sample=workers_per_sample,
+                                                             future=_future)))
         return future if return_future else future.result()
 
     async def _batch_find_best_experts(
             self, node: DHTNode, prefix: str, batch_grid_scores: Sequence[Sequence[Tuple[float]]], beam_size: int,
-            max_workers: Optional[int] = None, future: Optional[MPFuture] = None, **kwargs) -> List[List[RemoteExpert]]:
+            workers_per_sample: Optional[int] = None, future: Optional[MPFuture] = None) -> List[List[RemoteExpert]]:
 
-        batch_grid_scores = [[tuple(grid_score[i]) for grid_score in batch_grid_scores] for i in range(len(batch_grid_scores[0]))]
-        coros = [self._find_best_experts(node, prefix, grid_scores, beam_size, max_workers, **kwargs) for grid_scores in batch_grid_scores]
+        batch_grid_scores = [[tuple(grid_score[i]) for grid_score in batch_grid_scores]
+                             for i in range(len(batch_grid_scores[0]))]
+        coros = [self._find_best_experts(node, prefix, grid_scores, beam_size, workers_per_sample)
+                 for grid_scores in batch_grid_scores]
 
         best_experts_batch = await asyncio.gather(*coros)
         if future is not None:
             future.set_result(best_experts_batch)
         return best_experts_batch
-
-    async def _get_initial_beam(self, node, prefix: str, beam_size: int, scores: Tuple[float, ...], num_workers: int
-                                ) -> List[Tuple[float, str, Dict[str, List[str]]]]:
-        """ Fetch a list of all active level-one prefixes of a given prefix. Used for beam search """
-        beam: List[Tuple[float, str, Dict[str, List[str, Endpoint]]]] = []  # results will be stored here
-        unattempted_indices: List[int] = sorted(range(len(scores)), key=scores.__getitem__)  # order: worst to best
-        pending_tasks: Deque[Tuple[int, str, asyncio.Task]] = deque()  # up to num_workers concurrent get tasks
-
-        while len(beam) < beam_size and (unattempted_indices or pending_tasks):
-            # dispatch additional tasks
-            while unattempted_indices and len(pending_tasks) < num_workers:
-                next_index = unattempted_indices.pop()  # note: this is best unattempted index because of sort order
-                next_best_prefix = f"{prefix}{self.UID_DELIMITER}{next_index}"
-                pending_tasks.append((next_index, next_best_prefix, asyncio.create_task(node.get(next_best_prefix))))
-
-            # await the next best prefix to be fetched
-            pending_best_index, pending_best_prefix, pending_task = pending_tasks.popleft()
-            try:
-                maybe_prefix_data = await pending_task
-                if maybe_prefix_data is not None:
-                    beam.append((scores[pending_best_index], pending_best_prefix, maybe_prefix_data.value))
-            except asyncio.CancelledError:
-                for _, pending_task in pending_tasks:
-                    pending_task.cancel()
-                raise
-        return beam
-
-    def first_k_active(
-            self, uid_prefixes: List[str], k: int, max_prefetch: int = 1, chunk_size: Optional[int] = None,
-            return_future=False) -> Union[TOrderedDict[str, RemoteExpert], Awaitable[TOrderedDict[str, RemoteExpert]]]:
-        """
-        Find k prefixes with active experts; may return less if there aren't enough; used for DMoE beam search
-
-        :param uid_prefixes: a list of uid prefixes ordered from highest to lowest priority
-        :param k: return at most *this many* active prefixes
-        :param max_prefetch: pre-dispatch up to *this many* tasks (each for chunk_size experts)
-        :param chunk_size: dispatch this many requests in one task
-        :param return_future: if False (default), return when experts are returned. Otherwise return MPFuture.
-        :returns: a ordered dict{uid_prefix -> RemoteExpert} mapping at most :k: prefixes to matching experts
-            The keys in the returned dict are ordered same as in uid_prefixes.
-        """
-        logger.warning("first_k_active is deprecated and will be removed in 0.8.8")
-        assert not isinstance(uid_prefixes, str), "please provide a list/tuple of prefixes as the first argument"
-        future, _future = MPFuture.make_pair()
-        self.pipe.send(('_first_k_active', [],
-                        dict(uid_prefixes=uid_prefixes, k=k, max_prefetch=max_prefetch,
-                             chunk_size=chunk_size or k, future=_future)))
-        return future if return_future else future.result()
-
-    async def _first_k_active(
-            self, node: DHTNode, uid_prefixes: List[str], k: int, max_prefetch: int, chunk_size: int, future: MPFuture):
-        num_workers_per_chunk = min(chunk_size, self.max_workers or chunk_size)
-        total_chunks = (len(uid_prefixes) - 1) // chunk_size + 1
-        found: List[Tuple[str, RemoteExpert]] = []
-
-        pending_tasks = deque(
-            asyncio.create_task(node.get_many(uid_prefixes[chunk_i * chunk_size: (chunk_i + 1) * chunk_size],
-                                              num_workers=num_workers_per_chunk))
-            for chunk_i in range(min(max_prefetch + 1, total_chunks))
-        )  # pre-dispatch first task and up to max_prefetch additional tasks
-
-        for chunk_i in range(total_chunks):
-            # parse task results in chronological order, launch additional tasks on demand
-            response = await pending_tasks.popleft()
-            for uid_prefix in uid_prefixes[chunk_i * chunk_size: (chunk_i + 1) * chunk_size]:
-                if response[uid_prefix] is not None and len(response[uid_prefix].value) > 0:  # found active peer
-                    found.append((uid_prefix, RemoteExpert(*next(iter(response[uid_prefix].value.values()))[0])))
-                    # if we found enough active experts, finish immediately
-                    if len(found) >= k:
-                        break
-            if len(found) >= k:
-                break
-
-            pre_dispatch_chunk_i = chunk_i + len(pending_tasks) + 1
-            if pre_dispatch_chunk_i < total_chunks:
-                pending_tasks.append(asyncio.create_task(node.get_many(
-                    uid_prefixes[pre_dispatch_chunk_i * chunk_size: (pre_dispatch_chunk_i + 1) * chunk_size],
-                    num_workers=num_workers_per_chunk)))
-
-        for task in pending_tasks:
-            task.cancel()
-
-        # return k active prefixes or as many as we could find
-        future.set_result(OrderedDict(found))

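Taken together with the tests further down in this commit, the new client-side flow is: declare experts under a dot-terminated prefix, then run beam search over per-dimension grid scores. Per _iterate_matching_experts above, an expert's final score is the sum of its per-dimension scores, e.g. with grid_scores = [[0.1, 0.9], [0.3, 0.7]] the uid 'expert.1.0' scores 0.9 + 0.3 = 1.2. A minimal sketch based on those tests (the endpoint and uids are placeholders):

    import numpy as np
    import hivemind

    dht = hivemind.DHT(start=True)
    # experts live on a 2x2 grid; the trailing '.' marks a prefix rather than a full uid
    dht.declare_experts(['expert.1.0', 'expert.0.1'], endpoint='127.0.0.1:1337')

    # one score array per grid dimension; highest-scoring reachable experts come first
    best = dht.find_best_experts('expert.', [np.array([0.1, 0.9]), np.array([0.3, 0.7])], beam_size=2)

    # batched variant: one array of shape (batch_size, grid_size[i]) per dimension
    best_batch = dht.batch_find_best_experts(
        'expert.', [np.random.randn(4, 2), np.random.randn(4, 2)], beam_size=2)
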
+ 20 - 8
hivemind/dht/node.py

@@ -21,18 +21,18 @@ logger = get_logger(__name__)
 
 class DHTNode:
     """
-    A low-level class that represents a DHT participant. Please see DHTNode.create for parameters
+    Asyncio-based class that represents one DHT participant. Created via await DHTNode.create(...)
     Each DHTNode has an identifier, local storage and access to other nodes via DHTProtocol.
 
     :note: Hivemind DHT is optimized to store a lot of temporary metadata that is regularly updated.
-     For example, an expert alive timestamp that emitted by the Server responsible for that expert.
-     Such metadata does not require regular maintenance by peers, persistence on shutdown.
+     For example, expert heartbeat emitted by a hivemind.Server responsible for that expert.
+     Such metadata does not require regular maintenance by peers or persistence on shutdown.
      Instead, DHTNode is designed to rapidly send bulk data and resolve conflicts.
 
-    Every (key, value) pair in this DHT has an expiration time - float computed as get_dht_time(), UnixTime by default
+    Every (key, value) pair in this DHT has an expiration time - float computed as get_dht_time() (UnixTime by default)
     DHT nodes always prefer values with higher expiration time and may delete any value past its expiration.
 
-    Compared to Kademlia RPC protocol, hivemind DHT has 3 RPCs:
+    Similar to Kademlia RPC protocol, hivemind DHT has 3 RPCs:
 
     * ping - request peer's identifier and update routing table (same as Kademlia PING RPC)
     * store - send several (key, value, expiration_time) pairs to the same peer (like Kademlia STORE, but in bulk)
@@ -46,9 +46,21 @@ class DHTNode:
       IF that time has not come yet. If the expiration time is smaller than the current get_dht_time(), the node may return None;
     - when requested to store(key: value, expiration_time), a node must store (key => value) until that expiration time
       or until DHTNode gets the same key with greater expiration time. If a node is asked to store a key but it already
-      has the same key with newer expiration, the older key will not be stored. Return True if stored, False if refused;
-    - when requested to store(key: value, expiration_time, in_cache=True), stores (key => value) in a separate "cache".
-      Cache operates same as regular storage, but it has a limited size and evicts least recently used nodes when full;
+      has the same key with newer expiration, store will be rejected. Store returns True if accepted, False if rejected;
+    - when requested to store(key: value, expiration_time, subkey=subkey), adds a sub-key to a dictionary value type.
+      Dictionary values can have multiple sub-keys stored by different peers with individual expiration times. A subkey
+      will be accepted into the dictionary either if there is no such sub-key yet or if the new subkey's expiration is
+      later than the previous expiration under that subkey. See DHTProtocol.call_store for details.
+
+    DHTNode also features several (optional) caching policies:
+
+    - cache_locally: after GET, store the result in node's own local cache
+    - cache_nearest: after GET, send the result to this many nearest nodes that don't have that value yet (see Kademlia)
+    - cache_on_store: after STORE, either save or remove that key from node's own cache depending on store status
+    - cache_refresh_before_expiry: if a value in cache was used and is about to expire, try to GET it this many seconds
+      before expiration. The motivation is that frequently used keys should always be kept in cache to avoid latency.
+    - reuse_get_requests: if there are several concurrent GET requests for the same key, when one of them finishes,
+      DHTNode will attempt to reuse its result for the remaining requests. Useful for batch-parallel requests.
 
     """
     # fmt:off

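A minimal sketch of the storage semantics described above, assuming the create/store/get coroutines keep the signatures implied by this docstring (the import paths and default arguments are assumptions, not part of this diff):

    import asyncio
    from hivemind.dht.node import DHTNode
    from hivemind.utils import get_dht_time   # DHT wall clock, UnixTime by default

    async def demo():
        node = await DHTNode.create()                      # standalone node; real setups pass initial_peers etc.
        accepted = await node.store('some_key', 'some_value', get_dht_time() + 60)
        assert accepted                                    # True if stored, False if a newer expiration already exists
        found = await node.get('some_key')                 # resolves to the stored value until it expires, then to None
        assert found is not None

    asyncio.run(demo())
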
+ 3 - 3
hivemind/server/__init__.py

@@ -276,10 +276,10 @@ def generate_uids_from_pattern(num_experts: int, expert_pattern: Optional[str],
 
     def _generate_uid():
         if expert_pattern is None:
-            return f"expert{hivemind.DHT.UID_DELIMITER}{attempts_per_expert * num_experts - remaining_attempts}"
+            return f"expert{hivemind.dht.UID_DELIMITER}{attempts_per_expert * num_experts - remaining_attempts}"
 
         uid = []
-        for block in expert_pattern.split(hivemind.DHT.UID_DELIMITER):
+        for block in expert_pattern.split(hivemind.dht.UID_DELIMITER):
             try:
                 if '[' not in block and ']' not in block:
                     uid.append(block)
@@ -292,7 +292,7 @@ def generate_uids_from_pattern(num_experts: int, expert_pattern: Optional[str],
                 raise e
             except Exception as e:
                 raise ValueError(f"Expert pattern {expert_pattern} has invalid block {block} , {e}")
-        return hivemind.DHT.UID_DELIMITER.join(uid)
+        return hivemind.dht.UID_DELIMITER.join(uid)
 
     while remaining_attempts > 0 and len(found_uids) < num_experts:
 

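For reference, the expert_pattern handled above is a string of UID_DELIMITER-separated blocks where a literal block is kept verbatim and a '[start:end]' block is replaced by a random coordinate from that range. A simplified sketch of that parsing (the half-open range and the example pattern are assumptions; the real function also retries and de-duplicates):

    import random

    UID_DELIMITER = '.'   # same delimiter as hivemind.dht.UID_DELIMITER

    def sample_uid(pattern: str = "ffn.[0:256].[0:256]") -> str:
        """Draw one uid from a pattern, e.g. 'ffn.[0:256].[0:256]' -> 'ffn.17.203'."""
        blocks = []
        for block in pattern.split(UID_DELIMITER):
            if '[' not in block and ']' not in block:
                blocks.append(block)                                  # literal block such as 'ffn'
            else:
                start, end = map(int, block.strip('[]').split(':'))
                blocks.append(str(random.randrange(start, end)))      # random coordinate from the range
        return UID_DELIMITER.join(blocks)
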
+ 3 - 3
hivemind/utils/grpc.py

@@ -68,9 +68,9 @@ def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Ten
         stats_size[-1] = 1
         stats_count = np.prod(stats_size)
         means, stds = serialized_tensor.buffer[-8*stats_count:-4*stats_count], serialized_tensor.buffer[-4*stats_count:]
-        means = torch.as_tensor(np.frombuffer(means, dtype=np.float32)).view(*stats_size)
-        stds = torch.as_tensor(np.frombuffer(stds, dtype=np.float32)).view(*stats_size)
-        array = np.frombuffer(serialized_tensor.buffer[:-8 * stats_count], dtype=np.float16)
+        means = torch.as_tensor(np.frombuffer(means, dtype=np.float32).copy()).view(*stats_size)
+        stds = torch.as_tensor(np.frombuffer(stds, dtype=np.float32).copy()).view(*stats_size)
+        array = np.frombuffer(serialized_tensor.buffer[:-8 * stats_count], dtype=np.float16).copy()
         tensor = torch.as_tensor(array).to(torch.float32).view(*serialized_tensor.size).mul_(stds).add_(means)
     elif serialized_tensor.compression == CompressionType.FLOAT16:
         array = np.frombuffer(serialized_tensor.buffer, dtype=np.float16).copy()

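Why the added .copy() calls: np.frombuffer returns a read-only view over the incoming bytes, and wrapping a non-writable array with torch.as_tensor emits PyTorch's non-writable-array warning. Copying first gives a writable buffer. A generic illustration, independent of hivemind internals:

    import numpy as np
    import torch

    buffer = np.float32([1.0, 2.0, 3.0]).tobytes()
    array = np.frombuffer(buffer, dtype=np.float32)   # zero-copy, read-only view of the bytes
    assert not array.flags.writeable
    tensor = torch.as_tensor(array.copy())            # the copy is writable, so no warning is raised
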
+ 4 - 4
hivemind/utils/mpfuture.py

@@ -6,7 +6,7 @@ import concurrent.futures._base as base
 
 import asyncio
 from functools import lru_cache
-from typing import Optional
+from typing import Optional, Tuple
 
 from hivemind.utils.threading import run_in_background
 
@@ -22,7 +22,7 @@ class MPFuture(base.Future):
         self.connection = connection
 
     @classmethod
-    def make_pair(cls):
+    def make_pair(cls) -> Tuple[MPFuture, MPFuture]:
         """ Create a pair of linked futures to be used in two processes """
         connection1, connection2 = mp.Pipe()
         return cls(connection1), cls(connection2)
@@ -105,7 +105,7 @@ class MPFuture(base.Future):
         self._state, self._exception = base.CANCELLED, base.CancelledError()
         return self._send_updates()
 
-    def result(self, timeout=None):
+    def result(self, timeout: Optional[float] = None):
         self._await_terminal_state(timeout)
         if self._exception is not None:
             raise self._exception
@@ -117,7 +117,7 @@ class MPFuture(base.Future):
             raise base.CancelledError()
         return self._exception
 
-    def done(self):
+    def done(self) -> bool:
         self._sync_updates()
         return self._state in self.TERMINAL_STATES
 

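make_pair is used throughout this commit in the same way: one half stays with the caller, the other is shipped to a worker (here via DHT.pipe) that eventually calls set_result, and the caller blocks on result(). A sketch with both halves kept in a single process purely for illustration:

    from hivemind.utils.mpfuture import MPFuture

    future, _future = MPFuture.make_pair()       # two futures linked by a multiprocessing Pipe
    _future.set_result(42)                       # normally done in the worker process
    assert future.result(timeout=1.0) == 42      # the linked half observes the result
    assert future.done()
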
+ 3 - 3
tests/benchmark_dht.py

@@ -41,11 +41,11 @@ def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_b
     for start in trange(0, num_experts, expert_batch_size):
         store_start = time.perf_counter()
         endpoints.append(random_endpoint())
-        success_list = store_peer.declare_experts(expert_uids[start: start + expert_batch_size], endpoints[-1])
+        successes = store_peer.declare_experts(expert_uids[start: start + expert_batch_size], endpoints[-1]).values()
         total_store_time += time.perf_counter() - store_start
 
-        total_stores += len(success_list)
-        successful_stores += sum(success_list)
+        total_stores += len(successes)
+        successful_stores += sum(successes)
         time.sleep(wait_after_request)
 
     print(f"Store success rate: {successful_stores / total_stores * 100:.1f}% ({successful_stores} / {total_stores})")

+ 64 - 44
tests/test_dht_experts.py

@@ -1,10 +1,9 @@
 import random
-import uuid
-from itertools import chain
 import numpy as np
+import pytest
 
 import hivemind
-from hivemind import LOCALHOST
+from hivemind import LOCALHOST, UidEndpoint
 
 
 def test_store_get_experts():
@@ -16,7 +15,7 @@ def test_store_get_experts():
     you: hivemind.dht.DHT = random.choice(peers)
     theguyshetoldyounottoworryabout: hivemind.dht.DHT = random.choice(peers)
 
-    expert_uids = [str(uuid.uuid4()) for _ in range(110)]
+    expert_uids = [f"my_expert.{i}" for i in range(110)]
     batch_size = 10
     for batch_start in range(0, len(expert_uids), batch_size):
         you.declare_experts(expert_uids[batch_start: batch_start + batch_size], 'localhost', 1234)
@@ -25,22 +24,12 @@ def test_store_get_experts():
     assert all(res is not None for res in found[:-2]), "Could not find some existing experts"
     assert all(res is None for res in found[-2:]), "Found non-existing experts"
 
-    that_guys_expert, that_guys_port = str(uuid.uuid4()), random.randint(1000, 9999)
+    that_guys_expert, that_guys_port = "my_other_expert.1337", random.randint(1000, 9999)
     theguyshetoldyounottoworryabout.declare_experts([that_guys_expert], f'that_host:{that_guys_port}')
     you_notfound, you_found = you.get_experts(['foobar', that_guys_expert])
     assert isinstance(you_found, hivemind.RemoteExpert)
     assert you_found.endpoint == f'that_host:{that_guys_port}'
 
-    # test first_k_active
-    assert list(theguyshetoldyounottoworryabout.first_k_active(expert_uids, k=10)) == expert_uids[:10]
-
-    some_permuted_experts = random.sample(expert_uids, k=32)
-    assert list(theguyshetoldyounottoworryabout.first_k_active(some_permuted_experts, k=32)) == some_permuted_experts
-    assert list(theguyshetoldyounottoworryabout.first_k_active(some_permuted_experts, k=1)) == some_permuted_experts[:1]
-    fake_and_real_experts = list(chain(*zip(
-        [str(uuid.uuid4()) for _ in some_permuted_experts], some_permuted_experts)))
-    assert list(theguyshetoldyounottoworryabout.first_k_active(fake_and_real_experts, k=9)) == some_permuted_experts[:9]
-
     for peer in peers:
         peer.shutdown()
 
@@ -65,46 +54,77 @@ def test_beam_search(dht_size=20, total_experts=128, batch_size=32, initial_peer
     you = hivemind.DHT(start=True, expiration=999999, initial_peers=neighbors_i, parallel_rpc=parallel_rpc)
 
     for i in range(50):
-        topk_experts = you.find_best_experts('expert', [np.random.randn(dim) for dim in grid_dims], beam_size=beam_size)
+        topk_experts = you.find_best_experts('expert.', [np.random.randn(dim) for dim in grid_dims], beam_size=beam_size)
         assert all(isinstance(e, hivemind.RemoteExpert) for e in topk_experts)
         assert len(topk_experts) == beam_size
 
     for i in range(10):
-        batch_experts = you.batch_find_best_experts('expert', [np.random.randn(batch_size, dim) for dim in grid_dims],
+        batch_experts = you.batch_find_best_experts('expert.', [np.random.randn(batch_size, dim) for dim in grid_dims],
                                                     beam_size=beam_size)
         assert isinstance(batch_experts, list) and len(batch_experts) == batch_size
         assert all(isinstance(e, hivemind.RemoteExpert) for experts in batch_experts for e in experts)
         assert all(len(experts) == beam_size for experts in batch_experts)
 
 
-def test_first_k_active():
-    node = hivemind.DHT(start=True)
-    assert all(node.declare_experts(['e.1.2.3', 'e.1.2.4', 'e.3.4.5'], endpoint=f"{hivemind.LOCALHOST}:1337"))
-    assert all(node.declare_experts(['e.2.1.1'], endpoint=f"{hivemind.LOCALHOST}:1338"))
-
-    results = node.first_k_active(['e.0', 'e.1', 'e.2', 'e.3'], k=2)
-    assert len(results) == 2 and next(iter(results.keys())) == 'e.1'
-    assert results['e.1'].uid in ('e.1.2.3', 'e.1.2.4') and results['e.1'].endpoint == f"{hivemind.LOCALHOST}:1337"
-    assert results['e.2'].uid == 'e.2.1.1' and results['e.2'].endpoint == f"{hivemind.LOCALHOST}:1338"
-
-    results = node.first_k_active(['e', 'e.1', 'e.1.2', 'e.1.2.3'], k=10)
-    assert len(results) == 4
-    assert 'e' in results
-    for k in ('e.1', 'e.1.2', 'e.1.2.3'):
-        assert results[k].uid in ('e.1.2.3', 'e.1.2.4') and results[k].endpoint == f"{hivemind.LOCALHOST}:1337"
-
-
 def test_dht_single_node():
-    node = hivemind.DHT(start=True)
-    assert node.first_k_active(['e.3', 'e.2'], k=3) == {}
-    assert node.get_experts(['e.3', 'e.2']) == [None, None]
+    node = hivemind.DHT(start=True, expiration=999)
+
+    assert all(node.declare_experts(['expert.1', 'expert.2', 'expert.3'], f"{hivemind.LOCALHOST}:1337").values())
+    assert len(node.declare_experts(["ffn.1", "ffn.2"], endpoint="that_place")) == 4
+    assert len(node.declare_experts(['e.1.2.3', 'e.1.2.5', 'e.2.0'], f"{hivemind.LOCALHOST}:42")) == 7
 
-    assert all(node.declare_experts(['e.1', 'e.2', 'e.3'], f"{hivemind.LOCALHOST}:1337"))
-    for expert in node.get_experts(['e.3', 'e.2']):
+    for expert in node.get_experts(['expert.3', 'expert.2']):
         assert expert.endpoint == f"{hivemind.LOCALHOST}:1337"
-    active_found = node.first_k_active(['e.0', 'e.1', 'e.3', 'e.5', 'e.2'], k=2)
-    assert list(active_found.keys()) == ['e.1', 'e.3']
-    assert all(expert.uid.startswith(prefix) for prefix, expert in active_found.items())
 
-    assert all(node.declare_experts(['e.1', 'e.2', 'e.3'], f"{hivemind.LOCALHOST}:1337"))
-    assert node.find_best_experts('e', [(0., 1., 2., 3., 4., 5., 6., 7., 8.)], beam_size=4)
+    assert all(node.declare_experts(['expert.5', 'expert.2'], f"{hivemind.LOCALHOST}:1337").values())
+    found_experts = node.find_best_experts('expert.', [(0., 1., 2., 3., 4., 5., 6., 7., 8.)], beam_size=2)
+    assert len(found_experts) == 2 and [expert.uid for expert in found_experts] == ['expert.5', 'expert.3']
+
+    successors = node.get_active_successors(['e.1.2.', 'e.2.', 'e.4.5.'])
+    assert len(successors['e.1.2.']) == 2
+    assert successors['e.1.2.'][3] == UidEndpoint('e.1.2.3', f'{LOCALHOST}:42')
+    assert successors['e.1.2.'][5] == UidEndpoint('e.1.2.5', f'{LOCALHOST}:42')
+    assert len(successors['e.2.']) == 1 and successors['e.2.'][0] == UidEndpoint('e.2.0', f'{LOCALHOST}:42')
+    assert successors['e.4.5.'] == {}
+
+    initial_beam = node.get_initial_beam('expert.', (3, 2, 1, 0, -1, -2, -3), beam_size=3)
+    assert len(initial_beam) == 3
+    assert initial_beam[0][:2] == (2.0, 'expert.1.')
+    assert initial_beam[1][:2] == (1.0, 'expert.2.')
+    assert initial_beam[2][:2] == (0.0, 'expert.3.')
+
+    with pytest.raises(AssertionError):
+        node.find_best_experts('expert', [(0., 1., 2., 3., 4., 5., 6., 7., 8.)], beam_size=2)
+
+    with pytest.raises(AssertionError):
+        node.find_best_experts('expert.1', [(0., 1., 2., 3., 4., 5., 6., 7., 8.)], beam_size=2)
+
+    with pytest.raises(AssertionError):
+        node.get_active_successors(['e.1.2.', 'e.2', 'e.4.5.'])
+
+    with pytest.raises(AssertionError):
+        node.get_initial_beam('expert', (3, 2, 1, 0, -1, -2, -3), beam_size=3)
+
+
+def test_uid_patterns():
+    valid_experts = ["expert.1", "expert.0", "expert.0.0.1", "expert.1337", "ffn.12.34.56.78.90",
+                     "transformer.3.2.1.0", "transformer_encoder.2", "transformer::encoder.2", "T®@nsf0rmE®🤗.321",
+                     "🤗.321", "0.1.2", "00.1.2", "7070.3.2.1.0", "block2.1.23", "LAYER.1.0.1"]
+    valid_prefixes = ["expert.", "e.1.", "e.2.", "e.1.2.3.", "ololo.123.456.789.10."]
+    valid_prefixes.extend([f"{uid}." for uid in valid_experts])
+    valid_prefixes.extend([hivemind.split_uid(uid)[0] for uid in valid_experts])
+    for uid in valid_experts:
+        assert hivemind.is_valid_uid(uid), f"UID {uid} is valid, but was perceived as invalid"
+    for pfx in valid_prefixes:
+        assert hivemind.is_valid_prefix(pfx), f"Prefix {pfx} is valid, but was perceived as invalid"
+
+    invalid = ["", ".", "expert.-1", "xxx.a", "expert.1x", "expert_ffn.1.abc1", "some.123.01", "expert.123.01",
+               "e1", "e..1", "e", "e.1.2.3..4", "ffn.1..1", ".123", ".1.2.3.", ".expert", "transformer.encoder.2",
+               "T®@nsf0rmE®.🤗.321", "layer::123", "expert.0.1.2.suffix", "0.1.2.suffix", "expert.1 something",
+               "expert.1\n", "expert.1\n2", "expert.1 ", "expert.1\nexpert.2", "'expert.1'", '"expert.1"']
+    invalid_experts = invalid + valid_prefixes + ["0", "123456"]
+    invalid_prefixes = invalid + valid_experts + ["expert", ".🤗", ".expert"]
+    for uid in invalid_experts:
+        assert not hivemind.is_valid_uid(uid), f"UID {uid} is not valid, but was perceived as valid"
+    for pfx in invalid_prefixes:
+        assert not hivemind.is_valid_prefix(pfx), f"Prefix {pfx} is not valid, but was perceived as valid"

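In short, the grammar these tests pin down: an expert uid is a name followed by one or more dot-separated decimal coordinates, and a prefix is the same thing terminated by the delimiter. A few spot checks using the helpers exercised above (the uids are illustrative):

    import hivemind

    assert hivemind.is_valid_uid("ffn.12.34.56")    # name + dot-separated decimal coordinates
    assert not hivemind.is_valid_uid("ffn.12.")     # the trailing delimiter turns a uid into a prefix
    assert hivemind.is_valid_prefix("ffn.12.")      # valid prefixes always end with the delimiter
    assert not hivemind.is_valid_prefix("ffn.12")   # a bare uid is not a prefix
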
+ 13 - 14
tests/test_moe.py

@@ -1,5 +1,3 @@
-import asyncio
-
 import grpc
 import numpy as np
 import pytest
@@ -17,7 +15,7 @@ def test_moe():
         dht = hivemind.DHT(start=True, expiration=999, initial_peers=[dht_endpoint])
 
         dmoe = hivemind.RemoteMixtureOfExperts(
-            in_features=16, grid_size=(32, 32, 32), dht=dht, k_best=3, uid_prefix='ffn')
+            in_features=16, grid_size=(32, 32, 32), dht=dht, k_best=3, uid_prefix='ffn.')
 
         for i in range(10):
             out = dmoe(torch.randn(10, 16))
@@ -31,7 +29,7 @@ def test_call_many():
     forward_timeout = None
     backward_timeout = None
     rtol = 1e-3
-    atol = 1e-6
+    atol = 1e-5
 
     with background_server(num_experts=5, device='cpu', expert_cls='ffn', num_handlers=8, hidden_dim=64,
                            optim_cls=None, no_dht=True) as (server_endpoint, dht_endpoint):
@@ -42,8 +40,7 @@ def test_call_many():
 
         mask, expert_outputs = hivemind.client.moe._RemoteCallMany.apply(
             DUMMY, [[e0, e1, e2], [e2, e4], [e1, e5, e3], []],
-            k_min, backward_k_min, timeout_after_k_min, forward_timeout, backward_timeout,
-            asyncio.new_event_loop(), e1.info, inputs
+            k_min, backward_k_min, timeout_after_k_min, forward_timeout, backward_timeout, e1.info, inputs
         )
         assert mask.shape == (4, 3)
         assert expert_outputs.shape == (4, 3, 64)
@@ -96,33 +93,35 @@ def test_remote_module_call():
             fake_expert(dummy_x)
 
 
-def test_moe_beam_search():
+def test_beam_search_correctness():
     all_expert_uids = [f'ffn.{5 + i}.{10 + j}.{15 + k}' for i in range(10) for j in range(10) for k in range(10)]
     dht = hivemind.DHT(start=True, expiration=999)
     assert all(dht.declare_experts(all_expert_uids, endpoint='fake-endpoint'))
 
     dmoe = hivemind.RemoteMixtureOfExperts(
-        in_features=32, grid_size=(32, 32, 32), dht=dht, k_best=4, uid_prefix='ffn')
+        in_features=32, grid_size=(32, 32, 32), dht=dht, k_best=4, uid_prefix='ffn.')
 
     for i in range(25):
         input = torch.randn(32)
         grid_scores = dmoe.proj(input).split_with_sizes(dmoe.grid_size, dim=-1)
 
-        chosen_experts = dmoe.loop.run_until_complete(dmoe.beam_search(grid_scores, k_best=dmoe.k_best))
-
+        chosen_experts = dht.find_best_experts(dmoe.uid_prefix, [tensor.detach().numpy() for tensor in grid_scores],
+                                               beam_size=dmoe.k_best)
         chosen_scores = dmoe.compute_expert_scores([dim_scores[None] for dim_scores in grid_scores],
                                                    [chosen_experts])[0]
+        our_best_scores = list(chosen_scores.cpu().detach().numpy())
 
-        all_scores = dmoe.compute_expert_scores([dim_scores[None] for dim_scores in grid_scores],
+        # reference: independently find :beam_size: best experts with exhaustive search
+        all_scores = dmoe.compute_expert_scores([dim_scores.unsqueeze(0) for dim_scores in grid_scores],
                                                 [[hivemind.RemoteExpert(uid, '') for uid in all_expert_uids]])[0]
         true_best_scores = sorted(all_scores.cpu().detach().numpy(), reverse=True)[:len(chosen_experts)]
-        our_best_scores = list(chosen_scores.cpu().detach().numpy())
+
         assert np.allclose(true_best_scores, our_best_scores)
 
 
 def test_determinism():
     rtol = 0
-    atol = 1e-6
+    atol = 1e-5
 
     xx = torch.randn(32, 1024, requires_grad=True)
     mask = torch.randint(0, 1, (32, 1024))
@@ -146,7 +145,7 @@ def test_compute_expert_scores():
         dht = hivemind.DHT(start=True)
         moe = hivemind.client.moe.RemoteMixtureOfExperts(
             dht=dht, in_features=1024, grid_size=(40,), k_best=4, k_min=1, timeout_after_k_min=1,
-            uid_prefix='expert')
+            uid_prefix='expert.')
         gx, gy = torch.randn(4, 5, requires_grad=True), torch.randn(4, 3, requires_grad=True)
         ii = [[4, 0, 2], [3, 1, 1, 1, 3], [0], [3, 2]]
         jj = [[2, 2, 1], [0, 1, 2, 0, 1], [0], [1, 2]]