Browse source code

RemoteMixtureOfExperts update part II (#80)

* strict msgpack types

* strict msgpack types

* wip: RemoteCallMany without utils.autograd (still need to update and test RemoteMixtureOfExperts!)

* wip: RemoteCallMany without utils.autograd (still need to update and test RemoteMixtureOfExperts!)

* wip: RemoteCallMany without utils.autograd (still need to update and test RemoteMixtureOfExperts!)

* implement and test RemoteMixtureOfExperts

* numpy -> torch

* order preference

* preferred test order

* split tests into test_remote_expert.py and test_moe.py

* reduce test_moe size

* move test_moe to the top

* magical test order

* Update tests/test_dht.py

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>

* review {1337} => 1337

* .data.numpy => .detach.numpy

* review: full_like => full((1,))

* review: less hacky multiply

* unused import

* review: specify key order in docstring of first_k_active

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic 5 years ago
parent
commit
6d2b8094c9

+ 17 - 14
hivemind/client/expert.py

@@ -1,5 +1,6 @@
 import pickle
-from typing import Tuple, Optional
+from functools import lru_cache
+from typing import Tuple, Optional, Any
 
 import grpc
 import grpc.experimental.aio
@@ -13,6 +14,19 @@ from hivemind.utils.grpc import serialize_torch_tensor, deserialize_torch_tensor
 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert
 
 
+@lru_cache(maxsize=None)
+def _get_expert_stub(endpoint: Endpoint, aio: bool, *extra_options: Tuple[str, Any]):
+    """ Create a gRPC stub to access remote expert or use previously created stub from a process-wide cache """
+    channel_options = [
+        ('grpc.max_send_message_length', -1), ('grpc.max_receive_message_length', -1)
+    ] + list(extra_options)
+    if aio:
+        channel = grpc.experimental.aio.insecure_channel(endpoint, options=channel_options)
+    else:
+        channel = grpc.insecure_channel(endpoint, options=channel_options)
+    return runtime_grpc.ConnectionHandlerStub(channel)
+
+
 class RemoteExpert(nn.Module):
     """
     A simple module that runs forward/backward of an expert hosted on a remote machine.
@@ -28,22 +42,11 @@ class RemoteExpert(nn.Module):
     def __init__(self, uid, endpoint: Endpoint):
         super().__init__()
         self.uid, self.endpoint = uid, endpoint
-        self._channel, self._stub, self._info = None, None, None
+        self._info = None
 
     @property
     def stub(self):
-        if self._channel is None:
-            self._channel = grpc.insecure_channel(self.endpoint, options=[
-                ('grpc.max_send_message_length', -1),
-                ('grpc.max_receive_message_length', -1)
-            ])
-        if self._stub is None:
-            self._stub = runtime_grpc.ConnectionHandlerStub(self._channel)
-        return self._stub
-
-    def __del__(self):
-        if self._channel is not None:
-            self._channel.close()
+        return _get_expert_stub(self.endpoint, aio=False)
 
     def forward(self, *args, **kwargs):
         """ Call RemoteExpert for the specified inputs and return its output(s). Compatible with pytorch.autograd. """

+ 206 - 135
hivemind/client/moe.py

@@ -1,14 +1,20 @@
-from functools import partial
-from typing import Tuple, List, Optional
+from __future__ import annotations
+import time
+import asyncio
+from typing import Tuple, List, Optional, Awaitable, Set, Dict
 
-import numpy as np
 import torch
 import torch.nn as nn
 from torch.autograd.function import once_differentiable
+import grpc.experimental.aio
 
-from hivemind.client.expert import RemoteExpert, _RemoteModuleCall, DUMMY
-from hivemind.utils import nested_map, run_and_await_k, nested_pack, nested_flatten, run_in_background, \
-    run_isolated_forward, EmulatedAutogradContext, run_isolated_backward, map_with_parallel_backward
+import hivemind
+from hivemind.client.expert import RemoteExpert, DUMMY, _get_expert_stub
+from hivemind.utils import nested_map, nested_pack, nested_flatten, runtime_grpc, runtime_pb2, \
+    serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.utils.logging import get_logger
+
+logger = get_logger(__name__)
 
 
 class RemoteMixtureOfExperts(nn.Module):
@@ -25,30 +31,31 @@ class RemoteMixtureOfExperts(nn.Module):
     :param uid_prefix: common prefix for all expert uids
      expert uid follows the pattern {uid_prefix}.{0...grid_size[0]}.{0...grid_size[1]}...{0...grid_size[-1]}
     :param dht: DHT where the experts reside
-    :param num_workers: number of threads for parallel dht operation
     :param k_best: queries this many experts with highest scores
     :param k_min: makes sure at least this many experts returned output
     :param timeout_after_k_min: waits for this many seconds after k_min experts returned results.
      Any expert that didn't manage to return output after that delay is considered unavailable
-    :param expert_padding: internal value used to denote "absent expert". Should not coincide with any expert uid.
     :param allow_broadcasting: if RemoteMixtureOfExperts is fed an input with more than 2 dimensions,
      allow_broadcasting=True will flatten first d-1 input dimensions, apply RemoteMixtureOfExperts and un-flatten again
      allow_broadcasting=False will raise an error
     """
 
-    def __init__(self, *, in_features, grid_size: Tuple[int], dht, k_best, k_min=1,
-                 forward_timeout=None, timeout_after_k_min=1.0, backward_k_min=1, backward_timeout=None,
-                 uid_prefix='', expert_padding=None, allow_broadcasting=True):
+    def __init__(self, *, in_features, grid_size: Tuple[int, ...], dht: hivemind.DHT, k_best: int, k_min: int = 1,
+                 forward_timeout: Optional[float] = None, timeout_after_k_min: Optional[float] = None,
+                 backward_k_min: int = 1, backward_timeout: Optional[float] = None, uid_prefix='',
+                 allow_broadcasting=True, loop: asyncio.BaseEventLoop = None):
         super().__init__()
-        self.dht, self.grid_size = dht, grid_size
-        self.uid_prefix, self.expert_padding = uid_prefix, expert_padding
+        self.dht, self.grid_size, self.uid_prefix = dht, grid_size, uid_prefix
+        self.loop = loop or asyncio.new_event_loop()
+        assert not self.loop.is_running(), "Event loop is already running. If in jupyter, please apply nest_asyncio " \
+            "(pip install nest_asyncio , https://pypi.org/project/nest-asyncio ) and send loop=asyncio.new_event_loop()"
         self.k_best, self.k_min, self.backward_k_min = k_best, k_min, backward_k_min
         self.forward_timeout, self.backward_timeout = forward_timeout, backward_timeout
         self.timeout_after_k_min = timeout_after_k_min
         self.allow_broadcasting = allow_broadcasting
 
         self.proj = nn.Linear(in_features, sum(grid_size))  # jointly predict logits for all grid dimensions
-        self._outputs_schema = None
+        self._outputs_schema = None  # expert.info['outputs_schema'] from one of the experts in the grid
 
     def forward(self, input: torch.Tensor, *args: torch.Tensor, **kwargs: torch.Tensor):
         """
@@ -69,25 +76,31 @@ class RemoteMixtureOfExperts(nn.Module):
 
         # 1. compute scores and find most appropriate experts with beam search
         grid_scores = self.proj(input).split_with_sizes(self.grid_size, dim=-1)
-        chosen_experts = self.beam_search(grid_scores, self.k_best)
-        # ^-- List[batch_size] of List[RemoteExpert] chosen for every input in batch
 
-        expert_logits = self.compute_expert_scores(grid_scores, chosen_experts)
+        async def _search():
+            coroutines = [asyncio.create_task(self.beam_search(
+                [dim_scores[i] for dim_scores in grid_scores], self.k_best))
+                for i in range(len(input))]
+            return list(await asyncio.gather(*coroutines))
 
-        expert_inputs = ((input, *args), kwargs)
-        input_schema = nested_map(lambda x: None, expert_inputs)
-        flat_inputs_per_expert = tuple(zip(*[tensor.split(1, dim=0) for tensor in nested_flatten(expert_inputs)]))
+        chosen_experts: List[List[RemoteExpert]] = self.loop.run_until_complete(_search())
+        # ^-- List[batch_size] of List[RemoteExpert] chosen for every input in batch
 
-        batch_jobs_args = tuple(
-            (expert_logits[i, :len(chosen_experts[i])], chosen_experts[i], self.k_min, self.timeout_after_k_min,
-             self.backward_k_min, self.forward_timeout, self.backward_timeout, input_schema, *flat_inputs_per_expert[i])
-            for i in range(len(input))
-        )
+        expert_mask, *expert_outputs = _RemoteCallMany.apply(
+            DUMMY, chosen_experts, self.k_min, self.backward_k_min, self.timeout_after_k_min,
+            self.forward_timeout, self.backward_timeout, self.loop, *nested_flatten(((input, *args), kwargs)))
+        # ^-- multiple tensors of shape [batch_size, max_experts, ...output_shape]
 
-        averaged_outputs_flat = map(torch.cat, zip(*map_with_parallel_backward(_RemoteMoECall, *batch_jobs_args)))
+        expert_logits = self.compute_expert_scores(grid_scores, chosen_experts)
+        masked_logits = torch.full((1,), float('-inf'), device=expert_logits.device, dtype=expert_logits.dtype)
+        expert_logits = torch.where(expert_mask, expert_logits, masked_logits)
+        expert_weights = torch.softmax(expert_logits, dim=1)
+        averaged_outputs_flat = [
+            (expert_weights[..., None] * tensor.flatten(start_dim=2)).view(tensor.shape).sum(dim=1)
+            for tensor in expert_outputs]  # ^-- multiply by softmax weights along first 2 axes
         return nested_pack(averaged_outputs_flat, self.outputs_schema)
 
-    def beam_search(self, grid_scores: List[torch.Tensor], k_best: int, **kwargs) -> List[List[RemoteExpert]]:
+    async def beam_search(self, grid_scores: List[torch.Tensor], k_best: int, **kwargs) -> List[RemoteExpert]:
         """
         Find and return k best experts in the grid using (exact) beam search of the product space
 
@@ -99,51 +112,39 @@ class RemoteMixtureOfExperts(nn.Module):
          RemoteExpert instances for *up to* k_best experts
         """
         assert len(grid_scores) == len(self.grid_size)
-        assert all(len(dim_scores.shape) == 2 for dim_scores in grid_scores)
-        batch_size = len(grid_scores[0])
-        beam = np.array([[self.uid_prefix]] * batch_size, dtype=object)  # [batch_size, up_to_beam_size]
-        scores = np.zeros([batch_size, 1], dtype=np.float64)
+        assert all(dim_scores.shape == (self.grid_size[dim_index],) for dim_index, dim_scores in enumerate(grid_scores))
+        grid_scores = [dim_scores.cpu().detach() for dim_scores in grid_scores]
 
-        delimiters = np.array(self.dht.UID_DELIMITER)[None, None, None]  # pre-compute numpy array for fast concat
+        beam_experts: List[RemoteExpert] = []
+        beam: List[str] = [self.uid_prefix]
+        beam_scores = torch.zeros(1)
 
         for dim_index, dim_scores in enumerate(grid_scores):
-            dim_scores = dim_scores.detach().cpu().numpy()
-            assert dim_scores.shape[-1] == self.grid_size[dim_index]
-
-            # create all possible successsors from current beam
-            dim_indices = np.arange(dim_scores.shape[1]).astype(str)
-            new_candidates = beam[:, :, None] + delimiters + dim_indices[None, None, :]
-            new_candidates = new_candidates.reshape([batch_size, -1])
+            # create all possible successors from current beam and sort them by total score
+            expanded_scores = beam_scores[:, None] + dim_scores[None, :]
+            sorted_indices = [(flat_i // len(dim_scores), flat_i % len(dim_scores))
+                              for flat_i in (-expanded_scores).flatten().argsort().numpy()]
 
-            new_scores = scores[:, :, None] + dim_scores[:, None, :]
-            new_scores = new_scores.reshape([batch_size, -1])
+            sorted_candidates = [f"{beam[row]}{self.dht.UID_DELIMITER}{col}" for row, col in sorted_indices]
+            candidate_to_indices = dict(zip(sorted_candidates, sorted_indices))
 
             # select k best candidates according to scores but only those that are still active
-            new_order = np.argsort(- new_scores, axis=-1)
-            top_alive_lookups = [
-                run_in_background(self.dht.first_k_active, cands[order], k_best, **kwargs)
-                for cands, order in zip(new_candidates, new_order)]
-
-            batch_cand_to_score = [
-                dict(zip(cands, cand_scores)) for cands, cand_scores in zip(new_candidates, new_scores)]
-
-            top_alive_prefixes = [result.result() for result in top_alive_lookups]
-            top_alive_scores = [list(map(cand_to_score.get, top_cands))
-                                for cand_to_score, top_cands in zip(batch_cand_to_score, top_alive_prefixes)]
-
-            # pad up to beam size
-            beam = np.array([row + [self.expert_padding] * (k_best - len(row))
-                             for row in top_alive_prefixes], dtype='object')
-            scores = np.array([row + [-float('inf')] * (k_best - len(row))
-                               for row in top_alive_scores], dtype='float32')
-
-        unique_experts = self.dht.get_experts(list(set(
-            uid for row in beam for uid in row if uid != self.expert_padding)))
+            best_alive_prefixes: Dict[str, RemoteExpert] = await self.dht.first_k_active(
+                uid_prefixes=sorted_candidates, k=k_best, return_future=True, **kwargs)
+            if not best_alive_prefixes:
+                logger.warning(f"Grid is empty: found neither of {sorted_candidates}")
+                break
+            beam = list(best_alive_prefixes.keys())
+            beam_scores = expanded_scores[tuple(zip(*map(candidate_to_indices.get, beam)))]
+            beam_experts = list(best_alive_prefixes.values())
+
         if self._outputs_schema is None:
-            self._outputs_schema = next(iter(unique_experts)).info['outputs_schema']
-        unique_experts_by_uid = {expert.uid: expert for expert in unique_experts if expert != self.expert_padding}
+            try:
+                self._outputs_schema = beam_experts[0].info['outputs_schema']
+            except grpc.RpcError as e:
+                logger.warning(f"Failed to get RemoteMixtureOfExperts.output_shape: {e}")
 
-        return [[unique_experts_by_uid[uid] for uid in row if uid in unique_experts_by_uid] for row in beam]
+        return beam_experts
 
     def compute_expert_scores(
             self, grid_scores: List[torch.Tensor], batch_experts: List[List[RemoteExpert]]) -> torch.Tensor:
@@ -164,11 +165,11 @@ class RemoteMixtureOfExperts(nn.Module):
         flat_local_indices = expert_index_in_batch - expert_strides[flat_batch_indices]
         flat_experts = [expert for row in batch_experts for expert in row]
 
-        grid_indices = np.zeros([len(flat_experts), len(grid_scores)], dtype=np.int64)
+        grid_indices = torch.zeros([len(flat_experts), len(grid_scores)], dtype=torch.int64)
         for i, expert in enumerate(flat_experts):
             expert_indices = expert.uid[len(self.uid_prefix) + len(self.dht.UID_DELIMITER):]
             expert_indices = list(map(int, expert_indices.split(self.dht.UID_DELIMITER)))
-            grid_indices[i] = expert_indices
+            grid_indices[i] = torch.as_tensor(expert_indices, dtype=grid_indices.dtype)
 
         scores_per_dim = [
             dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0)
@@ -183,86 +184,156 @@ class RemoteMixtureOfExperts(nn.Module):
     def outputs_schema(self):
         if self._outputs_schema is None:
             # grab some expert to set ensemble output shape
-            dummy_scores = self.proj(torch.randn(1, self.proj.in_features)).split_with_sizes(self.grid_size, dim=-1)
-            self._outputs_schema = self.beam_search(dummy_scores, k_best=1)[0][0].info['outputs_schema']
+            dummy_scores = self.proj(torch.randn(self.proj.in_features)).cpu().split_with_sizes(self.grid_size, dim=-1)
+            dummy_experts = self.loop.run_until_complete(self.beam_search(dummy_scores, k_best=1))
+            self._outputs_schema = dummy_experts[0].info['outputs_schema']
         return self._outputs_schema
 
 
-class _RemoteMoECall(torch.autograd.Function):
+class _RemoteCallMany(torch.autograd.Function):
     """
-    Internal autograd-friendly function that calls multiple experts on the same input and averages their outputs.
-    This function that can recover from individual failures during forward and/or backward passes.
-    For user-friendly version of this function, use RemoteMixtureOfExperts module.
+    Internal autograd-friendly function that calls multiple experts on a batch of inputs and awaits responses
+    This function can recover from individual failures during forward and/or backward pass as long as at least
+    one expert succeeds for each input. For user-friendly version of this function, use RemoteMixtureOfExperts module.
+
+    Note: experts that failed during forward will be assigned zero outputs and marked as mask[i, j] = 0,
+          experts that failed during backward will be treated as constants (i.e. gradients through them are zeros)
     """
 
     @classmethod
-    def forward(cls, ctx, expert_logits: torch.Tensor, experts: List[RemoteExpert],
-                k_min: int, timeout_after_k_min: float, backward_k_min: int, timeout_total: Optional[float],
-                backward_timeout: Optional[float], input_schema, *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
-        expert_args, expert_kwargs = nested_pack(flat_inputs, structure=input_schema)
-        assert expert_logits.ndim == 1 and len(expert_logits) == len(experts)
-
-        # 1. call experts and await results
-        jobs = [partial(cls._run_expert_forward, expert, *expert_args, **expert_kwargs) for expert in experts]
-        results = run_and_await_k(jobs, k=k_min, timeout_after_k=timeout_after_k_min, timeout_total=timeout_total)
-
-        alive_contexts, alive_outputs, alive_ix = zip(*[(result[0], result[1], ix) for ix, result in enumerate(results)
-                                                        if not isinstance(result, BaseException)])
-        #     ^               ^            ^-- a list of indices of experts that returned outputs in time
-        #      \               \-- list of outputs of every expert that didn't die on us
-        #       \-- a list of autograd contexts, used for parallel backward
-
-        # 2. compute softmax weights for alive experts and average outputs
-        alive_ix = torch.as_tensor(alive_ix, device=expert_logits.device)
-        alive_expert_probs = torch.softmax(expert_logits[alive_ix], dim=0)
-
-        stacked_alive_outputs = tuple(map(torch.stack, zip(*alive_outputs)))
-
-        flat_average_outputs = tuple((alive_expert_probs @ stacked_out.flatten(1)).view(*stacked_out.shape[1:])
-                                     for stacked_out in stacked_alive_outputs)
-
-        # 3. save individual outputs for backward pass
-        ctx.save_for_backward(expert_logits, alive_ix, alive_expert_probs, *stacked_alive_outputs)
-        ctx._saved_non_tensors = alive_contexts, backward_k_min, backward_timeout
-        return tuple(map(torch.Tensor.detach, flat_average_outputs))
+    def forward(cls, ctx, dummy, experts_per_sample: List[List[RemoteExpert]], k_min: int, backward_k_min: int,
+                timeout_after_k_min: float, forward_timeout: Optional[float], backward_timeout: Optional[float],
+                loop: asyncio.base_events.BaseEventLoop, *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
+        assert not torch.is_grad_enabled()
+        num_samples, max_experts = len(experts_per_sample), max(map(len, experts_per_sample))
+        flat_inputs_per_sample: List[Tuple[torch.Tensor, ...]] = list(zip(*(x.split(1, dim=0) for x in flat_inputs)))
+        assert len(experts_per_sample) == len(flat_inputs_per_sample) == num_samples
+
+        async def _forward():
+            # dispatch tasks to all remote experts, await responses
+            pending_tasks = {
+                asyncio.create_task(cls._forward_one_expert((i, j), expert, flat_inputs_per_sample[i]))
+                for i in range(num_samples) for j, expert in enumerate(experts_per_sample[i])
+            }
+            alive_grid_indices, alive_flat_outputs = await cls._wait_for_responses(
+                pending_tasks, num_samples, k_min, forward_timeout, timeout_after_k_min)
+
+            # assemble responses
+            alive_ii, alive_jj = map(torch.as_tensor, zip(*alive_grid_indices))
+            mask = torch.zeros([num_samples, max_experts], dtype=torch.bool, device=flat_inputs[0].device)
+            mask[alive_ii, alive_jj] = True
+
+            alive_flat_outputs_stacked = list(map(torch.cat, zip(*alive_flat_outputs)))
+            # list of torch tensors, where i-th tensor is of shape [num_responded, *expert_outputs[i].shape]
+
+            outputs = []
+            for response_stacked in alive_flat_outputs_stacked:
+                output = torch.zeros(
+                    [num_samples, max_experts, *response_stacked.shape[1:]], device=response_stacked.device,
+                    dtype=response_stacked.dtype, requires_grad=response_stacked.requires_grad)
+                output[alive_ii, alive_jj] = response_stacked
+                outputs.append(output)
+
+            # save individual outputs for backward pass
+            ctx.save_for_backward(alive_ii, alive_jj, *flat_inputs)
+            ctx._saved_non_tensors = loop, backward_k_min, backward_timeout, timeout_after_k_min, experts_per_sample
+            return (mask,) + tuple(outputs)
+
+        return loop.run_until_complete(_forward())
 
     @classmethod
     @once_differentiable
-    def backward(cls, ctx, *grad_outputs_flat: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
-        """ Like normal backward, but we ignore any experts that failed during backward pass """
-        expert_logits, alive_ix, alive_expert_probas, *stacked_alive_outputs = ctx.saved_tensors
-        alive_contexts, backward_k_min, backward_timeout = ctx._saved_non_tensors
-
-        jobs = [partial(cls._run_expert_backward, ctx, prob, *grad_outputs_flat)
-                for ctx, prob in zip(alive_contexts, alive_expert_probas.split(1))]
-        results = run_and_await_k(jobs, k=backward_k_min, timeout_after_k=backward_timeout, timeout_total=None)
-        backward_survivors_in_alive_ix, survived_grad_inputs = zip(*((i, grads) for i, grads in enumerate(results)))
-        backward_survivors_in_alive_ix = torch.as_tensor(backward_survivors_in_alive_ix, device=expert_logits.device)
-        backward_survivors_ix = alive_ix[backward_survivors_in_alive_ix]
-        survived_probas = torch.softmax(expert_logits[backward_survivors_ix], dim=0)
-        weight_ratios = survived_probas / alive_expert_probas[backward_survivors_in_alive_ix]
-        flat_grad_inputs = tuple((weight_ratios @ stacked_grad_inp.flatten(1)).view(stacked_grad_inp.shape[1:])
-                                 for stacked_grad_inp in map(torch.stack, zip(*survived_grad_inputs)))
-
-        # compute grad w.r.t. logits
-        grad_wrt_probs = sum(tuple(
-            torch.sum(grad_out[None, ...] * stacked_avive_out[backward_survivors_in_alive_ix],
-                      dim=tuple(range(1, stacked_avive_out.ndim)))
-            for grad_out, stacked_avive_out in zip(grad_outputs_flat, stacked_alive_outputs)
-        ))
-        softmax_jacobian = torch.diagflat(survived_probas) - torch.ger(survived_probas, survived_probas)
-        grad_wrt_survived_logits = grad_wrt_probs @ softmax_jacobian
-        grad_wrt_logits = torch.zeros_like(expert_logits).scatter(0, backward_survivors_ix, grad_wrt_survived_logits)
-
-        return (grad_wrt_logits, None, None, None, None, None, None, None, *flat_grad_inputs)
+    def backward(cls, ctx, *raw_grads: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
+        assert not torch.is_grad_enabled()
+        loop, backward_k_min, backward_timeout, timeout_after_k_min, expert_per_sample = ctx._saved_non_tensors
+        alive_ii, alive_jj, *flat_inputs = ctx.saved_tensors
+        dummy_grad_mask, *flat_grad_outputs = raw_grads
+        num_samples, max_experts = dummy_grad_mask.shape
+
+        inputs_per_expert = zip(*(tensor[alive_ii].split(1, dim=0) for tensor in flat_inputs))
+        grad_outputs_per_expert = zip(*(tensor[alive_ii, alive_jj].split(1, dim=0) for tensor in flat_grad_outputs))
+
+        async def _backward():
+            # dispatch tasks to all remote experts, await responses
+            pending_tasks = set()
+            for i, j, inputs_ij, grad_outputs_ij in zip(alive_ii.cpu().numpy(), alive_jj.cpu().numpy(),
+                                                        inputs_per_expert, grad_outputs_per_expert):
+                pending_tasks.add(asyncio.create_task(
+                    cls._backward_one_expert((i, j), expert_per_sample[i.item()][j.item()], inputs_ij, grad_outputs_ij)
+                ))
+
+            backward_survivor_indices, survivor_grad_inputs = await cls._wait_for_responses(
+                pending_tasks, num_samples, backward_k_min, backward_timeout, timeout_after_k_min)
+
+            # assemble responses
+            backward_survivor_ii, backward_survivor_jj = map(torch.as_tensor, zip(*backward_survivor_indices))
+            survivor_grad_inputs_stacked = list(map(torch.cat, zip(*survivor_grad_inputs)))
+            # list of torch tensors, where i-th tensor is of shape [num_backward_survivors, *flat_inputs[i].shape]
+
+            grad_inputs = []
+            for i, survivor_grad_stacked in enumerate(survivor_grad_inputs_stacked):
+                grad_input_per_expert = torch.zeros(  # gradient tensor with individual contributions from each expert
+                    (num_samples, max_experts, *flat_inputs[i].shape[1:]),
+                    device=survivor_grad_stacked.device, dtype=survivor_grad_stacked.dtype)
+                grad_input_per_expert[backward_survivor_ii, backward_survivor_jj] = survivor_grad_stacked
+
+                grad_inputs.append(grad_input_per_expert.sum(dim=1))  # add up gradients from each expert
+
+            return (DUMMY, None, None, None, None, None, None, None, *grad_inputs)
+        return loop.run_until_complete(_backward())
+
+    @staticmethod
+    async def _forward_one_expert(grid_indices: Tuple[int, ...], expert: RemoteExpert, inputs: Tuple[torch.Tensor]):
+        stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint, aio=True)
+        try:
+            outputs = await stub.forward(runtime_pb2.ExpertRequest(
+                uid=expert.uid, tensors=[serialize_torch_tensor(tensor) for tensor in inputs]))
+            return grid_indices, tuple(deserialize_torch_tensor(tensor) for tensor in outputs.tensors)
+        except grpc.experimental.aio.AioRpcError as error:
+            logger.warning(f"RemoteExpert {expert} failed forward: {error.code()} (inputs: {inputs})")
 
     @staticmethod
-    def _run_expert_forward(expert: RemoteExpert, *args: torch.Tensor, **kwargs: torch.Tensor):
-        """ Call remote expert and return flattened outputs. Compatible with concurrent autograd. """
-        return run_isolated_forward(_RemoteModuleCall, DUMMY, expert.uid, expert.stub, *nested_flatten((args, kwargs)))
+    async def _backward_one_expert(grid_indices: Tuple[int, ...], expert: RemoteExpert,
+                                   inputs: Tuple[torch.Tensor], grad_outputs: Tuple[torch.Tensor]):
+        stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint, aio=True)
+        payload = tuple(nested_flatten((inputs, grad_outputs)))
+        try:
+            grad_inputs = await stub.backward(runtime_pb2.ExpertRequest(
+                uid=expert.uid, tensors=[serialize_torch_tensor(tensor) for tensor in payload]))
+            return grid_indices, tuple(deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors)
+        except grpc.experimental.aio.AioRpcError as error:
+            logger.warning(f"RemoteExpert {expert} failed backward: {error.code()} ({inputs}, {grad_outputs})")
 
     @staticmethod
-    def _run_expert_backward(ctx: EmulatedAutogradContext, weight: torch.Tensor, *grad_outputs: torch.Tensor):
-        backward_result = run_isolated_backward(_RemoteModuleCall, ctx, *(grad * weight for grad in grad_outputs))
-        grad_dummy, no_grad_uid, no_grad_stub, *grad_inputs = backward_result
-        return grad_inputs
+    async def _wait_for_responses(
+            pending_tasks: Set[Awaitable[Tuple[Tuple[int, int], Tuple[torch.Tensor, ...]]]],
+            num_samples: int, k_min: int, timeout_total: Optional[float], timeout_after_k_min: Optional[float]
+            ) -> Tuple[List[Tuple[int, int]], List[Tuple[torch.Tensor, ...]]]:
+        """ await up to k_min results and any result submitted within timeout_after_k_min, cancel stragglers """
+        timeout_total = float('inf') if timeout_total is None else timeout_total
+        timeout_after_k_min = float('inf') if timeout_after_k_min is None else timeout_after_k_min
+        num_successful_tasks = [0 for _ in range(num_samples)]
+        pending_samples = num_samples  # samples for which we have less than k_min results
+        finished_indices, finished_outputs = [], []
+        t_finish = time.perf_counter() + timeout_total
+
+        while pending_tasks and time.perf_counter() <= t_finish:
+            finished_tasks, pending_tasks = await asyncio.wait(pending_tasks, return_when=asyncio.FIRST_COMPLETED,
+                                                               timeout=t_finish - time.perf_counter())
+            for task in finished_tasks:
+                if not task.result():
+                    continue
+                task_indices, task_flat_outputs = await task
+                finished_indices.append(task_indices)
+                finished_outputs.append(task_flat_outputs)
+
+                sample_index = task_indices[0]
+                num_successful_tasks[sample_index] += 1
+                if num_successful_tasks[sample_index] == k_min:
+                    pending_samples -= 1
+                    if pending_samples <= 0:  # every sample has at least k_min results; await stragglers for at most timeout_after_k_min
+                        t_finish = min(t_finish, time.perf_counter() + timeout_after_k_min)
+
+        for task in pending_tasks:
+            task.cancel()
+        return finished_indices, finished_outputs
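
Two notes on the rewritten moe.py, each with a small self-contained sketch; shapes, grid sizes and uids below are made up for illustration and are not hivemind defaults.

The new beam_search expands the beam one grid dimension at a time, scoring each candidate uid by the sum of its per-dimension scores; in the real code only the candidates that dht.first_k_active reports as alive survive to the next dimension. The sketch keeps just the scoring/expansion logic and pretends every candidate is active:

    import torch

    grid_scores = [torch.tensor([0.1, 0.9]), torch.tensor([0.3, 0.2, 0.5])]  # a 2 x 3 expert grid
    k_best = 2
    beam, beam_scores = [''], torch.zeros(1)
    for dim_scores in grid_scores:
        expanded = beam_scores[:, None] + dim_scores[None, :]      # [len(beam), grid_size[dim]]
        order = (-expanded).flatten().argsort().tolist()           # candidate indices, best first
        candidates = [(f"{beam[i // len(dim_scores)]}.{i % len(dim_scores)}",
                       expanded.flatten()[i].item()) for i in order]
        survivors = candidates[:k_best]                            # here: pretend every uid is active
        beam = [uid for uid, _ in survivors]
        beam_scores = torch.tensor([score for _, score in survivors])
    print(beam)  # ['.1.2', '.1.0'] -- the two best-scoring expert uids

The forward pass then mixes expert outputs with softmax weights computed over masked logits, so experts that failed (mask[i, j] == False) contribute exactly zero to the mixture:

    import torch

    batch_size, max_experts, hidden = 2, 3, 4
    expert_logits = torch.randn(batch_size, max_experts)
    expert_mask = torch.tensor([[True, True, False],
                                [True, False, False]])              # which experts responded
    expert_outputs = torch.randn(batch_size, max_experts, hidden)   # stand-in for _RemoteCallMany outputs

    masked_logits = torch.where(expert_mask, expert_logits, torch.full((1,), float('-inf')))
    expert_weights = torch.softmax(masked_logits, dim=1)            # failed experts get weight 0
    mixture = (expert_weights[..., None] * expert_outputs).sum(dim=1)  # [batch_size, hidden]
    assert torch.allclose(expert_weights[~expert_mask], torch.zeros(1))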

+ 25 - 19
hivemind/dht/__init__.py

@@ -16,9 +16,9 @@ import asyncio
 import ctypes
 import multiprocessing as mp
 import warnings
-from collections import deque
+from collections import deque, OrderedDict
 from concurrent.futures import ThreadPoolExecutor
-from typing import List, Optional, Sequence
+from typing import List, Tuple, Optional, Sequence, OrderedDict as TOrderedDict, Union, Awaitable
 
 import uvloop
 
@@ -126,17 +126,17 @@ class DHT(mp.Process):
         return self._port.value if self._port.value != 0 else None
 
     def get_experts(self, uids: List[str], expiration_time: Optional[DHTExpiration] = None,
-                    wait=True) -> List[Optional[RemoteExpert]]:
+                    return_future=False) -> List[Optional[RemoteExpert]]:
         """
         :param uids: find experts with these ids from across the DHT
         :param expiration_time: if specified, return experts that expire no sooner than this (based on get_dht_time)
-        :param wait: if True (default), return when experts are returned. Otherwise return a Future.
+        :param return_future: if False (default), return when experts are returned. Otherwise return MPFuture.
         :returns: a list of [RemoteExpert if found else None]
         """
         assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
         future, _future = MPFuture.make_pair()
         self.pipe.send(('_get_experts', [], dict(uids=uids, expiration_time=expiration_time, future=_future)))
-        return future.result() if wait else future
+        return future if return_future else future.result()
 
     async def _get_experts(
             self, node: DHTNode, uids: List[str], expiration_time: Optional[DHTExpiration], future: MPFuture):
@@ -144,8 +144,8 @@ class DHT(mp.Process):
             expiration_time = get_dht_time()
         num_workers = len(uids) if self.max_workers is None else min(len(uids), self.max_workers)
         response = await node.get_many(uids, expiration_time, num_workers=num_workers)
-        future.set_result([RemoteExpert(uid, maybe_endpoint) if maybe_expiration_time else None
-                           for uid, (maybe_endpoint, maybe_expiration_time) in response.items()])
+        future.set_result([RemoteExpert(**expert_data) if maybe_expiration_time else None
+                           for uid, (expert_data, maybe_expiration_time) in response.items()])
 
     def declare_experts(self, uids: List[str], endpoint: Endpoint, wait=True, timeout=None) -> Optional[List[bool]]:
         """
@@ -172,14 +172,16 @@ class DHT(mp.Process):
             uid_parts = uid.split(self.UID_DELIMITER)
             for i in range(len(uid_parts)):
                 uid_prefix_i = self.UID_DELIMITER.join(uid_parts[:i + 1])
-                data_to_store[uid_prefix_i] = endpoint
+                data_to_store[uid_prefix_i] = {'uid': uid, 'endpoint': endpoint}
 
         store_keys, store_values = zip(*data_to_store.items())
         store_ok = await node.store_many(store_keys, store_values, expiration_time, num_workers=num_workers)
         if future is not None:
             future.set_result([store_ok[key] for key in data_to_store.keys()])
 
-    def first_k_active(self, uid_prefixes: List[str], k: int, max_prefetch: int = 1, chunk_size: Optional[int] = None):
+    def first_k_active(
+            self, uid_prefixes: List[str], k: int, max_prefetch: int = 1, chunk_size: Optional[int] = None,
+            return_future=False) -> Union[TOrderedDict[str, RemoteExpert], Awaitable[TOrderedDict[str, RemoteExpert]]]:
         """
         Find k prefixes with active experts; may return less if there aren't enough; used for DMoE beam search
 
@@ -187,20 +189,22 @@ class DHT(mp.Process):
         :param k: return at most *this many* active prefixes
         :param max_prefetch: pre-dispatch up to *this many* tasks (each for chunk_size experts)
         :param chunk_size: dispatch this many requests in one task
-        :returns: a list of at most :k: prefixes that have at least one active expert each;
+        :param return_future: if False (default), return when experts are returned. Otherwise return MPFuture.
+        :returns: an ordered dict {uid_prefix -> RemoteExpert} mapping at most :k: prefixes to matching experts
+            The keys in the returned dict are ordered the same as in uid_prefixes.
         """
         assert not isinstance(uid_prefixes, str), "please provide a list/tuple of prefixes as the first argument"
         future, _future = MPFuture.make_pair()
         self.pipe.send(('_first_k_active', [],
                         dict(uid_prefixes=uid_prefixes, k=k, max_prefetch=max_prefetch,
                              chunk_size=chunk_size or k, future=_future)))
-        return future.result()
+        return future if return_future else future.result()
 
     async def _first_k_active(
             self, node: DHTNode, uid_prefixes: List[str], k: int, max_prefetch: int, chunk_size: int, future: MPFuture):
         num_workers_per_chunk = min(chunk_size, self.max_workers or chunk_size)
         total_chunks = (len(uid_prefixes) - 1) // chunk_size + 1
-        active_prefixes = []
+        found: List[Tuple[str, RemoteExpert]] = []
 
         pending_tasks = deque(
             asyncio.create_task(node.get_many(uid_prefixes[chunk_i * chunk_size: (chunk_i + 1) * chunk_size],
@@ -212,14 +216,13 @@ class DHT(mp.Process):
             # parse task results in chronological order, launch additional tasks on demand
             response = await pending_tasks.popleft()
             for uid_prefix in uid_prefixes[chunk_i * chunk_size: (chunk_i + 1) * chunk_size]:
-                if response[uid_prefix][1] is not None:  # found active peer
-                    active_prefixes.append(uid_prefix)
+                maybe_expert_data, maybe_expiration_time = response[uid_prefix]
+                if maybe_expiration_time is not None:  # found active peer
+                    found.append((uid_prefix, RemoteExpert(**maybe_expert_data)))
                     # if we found enough active experts, finish immediately
-                    if len(active_prefixes) >= k:
+                    if len(found) >= k:
                         break
-            if len(active_prefixes) >= k:
-                for task in pending_tasks:
-                    task.cancel()
+            if len(found) >= k:
                 break
 
             pre_dispatch_chunk_i = chunk_i + len(pending_tasks) + 1
@@ -228,5 +231,8 @@ class DHT(mp.Process):
                     uid_prefixes[pre_dispatch_chunk_i * chunk_size: (pre_dispatch_chunk_i + 1) * chunk_size],
                     num_workers=num_workers_per_chunk)))
 
+        for task in pending_tasks:
+            task.cancel()
+
         # return k active prefixes or as many as we could find
-        future.set_result(active_prefixes)
+        future.set_result(OrderedDict(found))
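
Usage sketch for the changed first_k_active contract (assumes a locally started DHT; the uids and endpoint below are made up): the method now returns an ordered dict whose keys follow the order of uid_prefixes, and return_future=True yields an MPFuture instead of blocking.

    import hivemind

    dht = hivemind.DHT(start=True)
    dht.declare_experts(['ffn.0.0', 'ffn.1.0'], endpoint='127.0.0.1:1337')

    active = dht.first_k_active(['ffn.0', 'ffn.1', 'ffn.2'], k=2)
    for uid_prefix, expert in active.items():   # OrderedDict[str, RemoteExpert], not a list of prefixes
        print(uid_prefix, expert.uid, expert.endpoint)

    future = dht.first_k_active(['ffn.0', 'ffn.1'], k=1, return_future=True)  # non-blocking variant
    print(future.result())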

+ 1 - 1
hivemind/dht/node.py

@@ -67,7 +67,7 @@ class DHTNode:
         :param depth_modulo: split full k-bucket if it contains root OR up to the nearest multiple of this value (≈b)
         :param parallel_rpc: maximum number of concurrent outgoing RPC requests emitted by DHTProtocol
           Reduce this value if your RPC requests register no response despite the peer sending the response.
-        :param wait_timeout: a kademlia rpc request is deemed lost if we did not recieve a reply in this many seconds
+        :param wait_timeout: a kademlia rpc request is deemed lost if we did not receive a reply in this many seconds
         :param refresh_timeout: refresh buckets if no node from that bucket was updated in this many seconds
           if staleness_timeout is None, DHTNode will not refresh stale buckets (which is usually okay)
         :param bootstrap_timeout: after one of peers responds, await other peers for at most this many seconds

+ 2 - 2
hivemind/server/__init__.py

@@ -41,8 +41,8 @@ class Server(threading.Thread):
         super().__init__()
         self.dht, self.experts, self.update_period = dht, expert_backends, update_period
         if get_port(listen_on) is None:
-            self.listen_on = listen_on = replace_port(listen_on, new_port=find_open_port())
-        self.port = get_port(listen_on)
+            listen_on = replace_port(listen_on, new_port=find_open_port())
+        self.listen_on, self.port = listen_on, get_port(listen_on)
 
         self.conn_handlers = [ConnectionHandler(listen_on, self.experts) for _ in range(num_connection_handlers)]
         if checkpoint_dir is not None:

+ 0 - 1
hivemind/utils/__init__.py

@@ -4,6 +4,5 @@ from hivemind.utils.tensor_descr import *
 from hivemind.utils.serializer import *
 from hivemind.utils.mpfuture import *
 from hivemind.utils.threading import *
-from hivemind.utils.autograd import *
 from hivemind.utils.grpc import *
 from hivemind.utils.logging import get_logger

+ 0 - 100
hivemind/utils/autograd.py

@@ -1,100 +0,0 @@
-"""
-Temporary autograd extensions to enable inter-op parallelism during backward pass
-Note: we should get rid of this module if https://github.com/pytorch/pytorch/pull/33157 reaches a pytorch release
-"""
-from itertools import chain
-from typing import Tuple, Any
-from concurrent.futures import Future
-
-import numpy as np
-import torch
-import torch.autograd.function
-
-from hivemind.utils.threading import run_in_background
-
-
-class EmulatedAutogradContext(torch.autograd.function._ContextMethodMixin):
-    """
-    A special class that pretends to be pytorch autograd context. Used to circumvent limitatons of pytorch autograd,
-    such as running several parallel backwards or transferring backward to a separate device.
-    This class is not tested outside its use cases in RemoteMixtureOfExperts and we do not recommend using it elsewhere.
-    """
-
-    @property
-    def saved_tensors(self):
-        return tuple(self.to_save)
-
-
-def run_isolated_forward(func: torch.autograd.Function, *args) -> Tuple[EmulatedAutogradContext, Any]:
-    """
-    run :func: in a detached pytorch graph, return *detached* function outputs and an EmulatedAutogradContext that
-    can be used to run backward through the same graph (performed manually by the user).
-    """
-    ctx = EmulatedAutogradContext()
-    # create detached copies of every input so that we can differentiate w.r.t. them without modifying actual variables
-    args = tuple(x.detach().requires_grad_(x.requires_grad) if isinstance(x, torch.Tensor) else x for x in args)
-    with torch.no_grad():
-        return ctx, func.forward(ctx, *args)
-
-
-def run_isolated_backward(func: torch.autograd.Function, ctx: EmulatedAutogradContext, *grad_outputs):
-    """
-    run backward pass for :func: in an isolated graph that was previously created through run_isolated_forward
-    """
-    with torch.no_grad():
-        return func.backward(ctx, *grad_outputs)
-
-
-def map_with_parallel_backward(
-        func: torch.autograd.Function, *args_per_call: Tuple[torch.Tensor, ...]) -> Tuple[Tuple[torch.Tensor, ...]]:
-    """
-    Apply an autograd function to several sets of inputs with two extra guarantees:
-    (1) both forward and backward pass happens concurrently for each set of inputs
-    (2) any operation dependent on any individual function will wait for all functions to finish
-    :param func: torch autograd function to be called several times in parallel
-    :param args_per_call: a sequence of tuples of arguments, each tuple corresponds to one function call
-    :returns: a tuple of outputs from each func call
-
-    Note: this function currently requires that all :func: calls succeed (i.e. do not raise an exception).
-    """
-    arg_counts = list(map(len, args_per_call))
-    assert len(set(arg_counts)) == 1, "All input sets must have the same number of arguments"
-    output_strides_ph = Future()
-    flat_outputs: Tuple[torch.Tensor, ...] = _ParallelApplyFunction.apply(
-        func, len(args_per_call), arg_counts[0], output_strides_ph, *chain(*args_per_call))
-    output_strides = output_strides_ph.result()
-    return tuple(flat_outputs[output_strides[i]: output_strides[i + 1]] for i in range(len(output_strides) - 1))
-
-
-class _ParallelApplyFunction(torch.autograd.Function):
-    """
-    A special torch autograd function that runs another function several times in parallel.
-    Please do not call this function directly. Use apply_with_parallel_backward instead.
-    Unlike default pytorch behavior, the backward pass for each function will also happen in parallel.
-    """
-
-    @staticmethod
-    def forward(ctx, func: torch.autograd.Function, num_calls: int, num_args_per_call: int,
-                output_strides_ph: Future, *args_flat) -> Tuple[torch.Tensor, ...]:
-        assert num_calls * num_args_per_call == len(args_flat)
-        args_per_call = [args_flat[i * num_args_per_call: (i + 1) * num_args_per_call] for i in range(num_calls)]
-
-        futures = [run_in_background(run_isolated_forward, func, *args) for args in args_per_call]
-
-        contexts, outputs = zip(*[future.result() for future in futures])
-        output_strides = np.cumsum([0] + list(map(len, outputs)))
-        ctx._inner_func = func
-        ctx._call_contexts = contexts
-        ctx._output_strides = output_strides
-        output_strides_ph.set_result(output_strides)
-        return tuple(chain(*outputs))
-
-    @staticmethod
-    def backward(ctx, *grad_outputs_flat: torch.Tensor):
-        func, contexts, output_strides = ctx._inner_func, ctx._call_contexts, ctx._output_strides
-        grad_outputs_per_call = [grad_outputs_flat[output_strides[i]: output_strides[i + 1]]
-                                 for i in range(len(contexts))]
-        futures = [run_in_background(run_isolated_backward, func, context, *grads)
-                   for context, grads in zip(contexts, grad_outputs_per_call)]
-        flat_grads_wrt_input = tuple(grad for future in futures for grad in future.result())
-        return (None, None, None, None, *flat_grads_wrt_input)

+ 0 - 1
hivemind/utils/networking.py

@@ -1,5 +1,4 @@
 import socket
-import urllib.parse
 from contextlib import closing
 from typing import Optional
 

+ 1 - 1
hivemind/utils/serializer.py

@@ -40,7 +40,7 @@ class PytorchSerializer(SerializerBase):
 class MSGPackSerializer(SerializerBase):
     @staticmethod
     def dumps(obj: object) -> bytes:
-        return umsgpack.dumps(obj, use_bin_type=False)  # TODO strict https://github.com/msgpack/msgpack-python/pull/158
+        return umsgpack.dumps(obj, use_bin_type=False, strict_types=True)
 
     @staticmethod
     def loads(buf: bytes) -> object:

+ 1 - 50
hivemind/utils/threading.py

@@ -1,7 +1,5 @@
 import os
-from concurrent.futures import Future, as_completed, TimeoutError, ThreadPoolExecutor
-import time
-from typing import Optional, List
+from concurrent.futures import Future, ThreadPoolExecutor
 
 EXECUTOR_PID, GLOBAL_EXECUTOR = None, None
 
@@ -13,50 +11,3 @@ def run_in_background(func: callable, *args, **kwargs) -> Future:
         GLOBAL_EXECUTOR = ThreadPoolExecutor(max_workers=os.environ.get("HIVEMIND_THREADS", float('inf')))
         EXECUTOR_PID = os.getpid()
     return GLOBAL_EXECUTOR.submit(func, *args, **kwargs)
-
-
-def run_and_await_k(jobs: List[callable], k: int,
-                    timeout_after_k: Optional[float] = 0, timeout_total: Optional[float] = None):
-    """
-    Runs all :jobs: asynchronously, awaits for at least k of them to finish
-    :param jobs: functions to call asynchronously
-    :param k: how many functions should finish for call to be successful
-    :param timeout_after_k: after reaching k finished jobs, wait for this long before cancelling
-    :param timeout_total: if specified, terminate cancel jobs after this many seconds
-    :returns: a list of either results or exceptions for each job
-    """
-    jobs = list(jobs)
-    assert k <= len(jobs), f"Can't await {k} out of {len(jobs)} jobs."
-    start_time = time.time()
-    future_to_ix = {run_in_background(job): i for i, job in enumerate(jobs)}
-    outputs = [None] * len(jobs)
-    success_count = 0
-
-    try:
-        # await first k futures for as long as it takes
-        for future in as_completed(list(future_to_ix.keys()), timeout=timeout_total):
-            success_count += int(not future.exception())
-            outputs[future_to_ix.pop(future)] = future.result() if not future.exception() else future.exception()
-            if success_count >= k:
-                break  # we have enough futures to succeed
-            if len(outputs) + len(future_to_ix) < k:
-                failed = len(jobs) - len(outputs) - len(future_to_ix)
-                raise ValueError(f"Couldn't get enough results: too many jobs failed ({failed} / {len(outputs)})")
-
-        # await stragglers for at most self.timeout_after_k_min or whatever time is left
-        if timeout_after_k is not None and timeout_total is not None:
-            time_left = min(timeout_after_k, timeout_total - time.time() + start_time)
-        else:
-            time_left = timeout_after_k if timeout_after_k is not None else timeout_total
-        for future in as_completed(list(future_to_ix.keys()), timeout=time_left):
-            success_count += int(not future.exception())
-            outputs[future_to_ix.pop(future)] = future.result() if not future.exception() else future.exception()
-
-    except TimeoutError:
-        if len(outputs) < k:
-            raise TimeoutError(f"Couldn't get enough results: time limit exceeded (got {len(outputs)} of {k})")
-    finally:
-        for future, index in future_to_ix.items():
-            future.cancel()
-            outputs[index] = future.result() if not future.exception() else future.exception()
-    return outputs
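
Context for dropping run_and_await_k: its role is now played by _RemoteCallMany._wait_for_responses, which uses asyncio.wait with FIRST_COMPLETED instead of a thread pool. A tiny standalone sketch of the underlying wait-for-k-then-grace-period pattern (task count, k and the grace period are arbitrary):

    import asyncio, random

    async def demo(num_tasks: int = 5, k: int = 2, grace: float = 0.05):
        async def job(i):
            await asyncio.sleep(random.random())
            return i

        pending = {asyncio.create_task(job(i)) for i in range(num_tasks)}
        results = []
        while pending and len(results) < k:            # gather results until k jobs have finished
            done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
            results.extend(task.result() for task in done)
        if pending:                                    # short grace period, then cancel stragglers
            done, pending = await asyncio.wait(pending, timeout=grace)
            results.extend(task.result() for task in done)
            for task in pending:
                task.cancel()
        return results

    print(asyncio.run(demo()))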

+ 0 - 1
tests/benchmark_dht.py

@@ -1,7 +1,6 @@
 import time
 import argparse
 import random
-from typing import Tuple
 from warnings import warn
 import hivemind
 from tqdm import trange

+ 32 - 7
tests/test_dht.py

@@ -290,14 +290,14 @@ def test_hivemind_dht():
     assert you_found.endpoint == f'that_host:{that_guys_port}'
 
     # test first_k_active
-    assert theguyshetoldyounottoworryabout.first_k_active(expert_uids, k=10) == expert_uids[:10]
+    assert list(theguyshetoldyounottoworryabout.first_k_active(expert_uids, k=10)) == expert_uids[:10]
 
     some_permuted_experts = random.sample(expert_uids, k=32)
-    assert theguyshetoldyounottoworryabout.first_k_active(some_permuted_experts, k=32) == some_permuted_experts
-    assert theguyshetoldyounottoworryabout.first_k_active(some_permuted_experts, k=1) == some_permuted_experts[:1]
+    assert list(theguyshetoldyounottoworryabout.first_k_active(some_permuted_experts, k=32)) == some_permuted_experts
+    assert list(theguyshetoldyounottoworryabout.first_k_active(some_permuted_experts, k=1)) == some_permuted_experts[:1]
     fake_and_real_experts = list(chain(*zip(
         [str(uuid.uuid4()) for _ in some_permuted_experts], some_permuted_experts)))
-    assert theguyshetoldyounottoworryabout.first_k_active(fake_and_real_experts, k=9) == some_permuted_experts[:9]
+    assert list(theguyshetoldyounottoworryabout.first_k_active(fake_and_real_experts, k=9)) == some_permuted_experts[:9]
 
     for peer in peers:
         peer.shutdown()
@@ -305,10 +305,35 @@ def test_hivemind_dht():
 
 def test_dht_single_node():
     node = hivemind.DHT(start=True)
-    assert all(node.declare_experts(['e1', 'e2', 'e3'], f"{hivemind.LOCALHOST}:{1337}"))
+    assert node.first_k_active(['e3', 'e2'], k=3) == {}
+    assert node.get_experts(['e3', 'e2']) == [None, None]
+
+    assert all(node.declare_experts(['e1', 'e2', 'e3'], f"{hivemind.LOCALHOST}:1337"))
     for expert in node.get_experts(['e3', 'e2']):
-        assert expert.endpoint == f"{hivemind.LOCALHOST}:{1337}"
-    assert node.first_k_active(['e0', 'e1', 'e3', 'e5', 'e2'], k=2) == ['e1', 'e3']
+        assert expert.endpoint == f"{hivemind.LOCALHOST}:1337"
+    active_found = node.first_k_active(['e0', 'e1', 'e3', 'e5', 'e2'], k=2)
+    assert list(active_found.keys()) == ['e1', 'e3']
+    assert all(expert.uid.startswith(prefix) for prefix, expert in active_found.items())
+
+    assert all(node.declare_experts(['e1', 'e2', 'e3'], f"{hivemind.LOCALHOST}:1337"))
+
+
+def test_first_k_active():
+    node = hivemind.DHT(start=True)
+    assert all(node.declare_experts(['e.1.2.3', 'e.1.2.4', 'e.3.4.5'], endpoint=f"{hivemind.LOCALHOST}:1337"))
+    assert all(node.declare_experts(['e.2.1.1'], endpoint=f"{hivemind.LOCALHOST}:1338"))
+
+    results = node.first_k_active(['e.0', 'e.1', 'e.2', 'e.3'], k=2)
+    assert len(results) == 2 and next(iter(results.keys())) == 'e.1'
+    assert results['e.1'].uid in ('e.1.2.3', 'e.1.2.4') and results['e.1'].endpoint == f"{hivemind.LOCALHOST}:1337"
+    assert results['e.2'].uid == 'e.2.1.1' and results['e.2'].endpoint == f"{hivemind.LOCALHOST}:1338"
+
+    results = node.first_k_active(['e', 'e.1', 'e.1.2', 'e.1.2.3'], k=10)
+    assert len(results) == 4
+    assert 'e' in results
+    for k in ('e.1', 'e.1.2', 'e.1.2.3'):
+        assert results[k].uid in ('e.1.2.3', 'e.1.2.4') and results[k].endpoint == f"{hivemind.LOCALHOST}:1337"
+
 
 
 def test_store():

+ 110 - 29
tests/test_moe.py

@@ -1,46 +1,127 @@
+import asyncio
+
+import grpc
+import numpy as np
+import pytest
 import torch
 import hivemind
+from hivemind.client.expert import DUMMY
 from test_utils.run_server import background_server
 
 
-def test_remote_module_call():
-    """ Check that remote_module_call returns correct outputs and gradients if called directly """
-    num_experts = 8
+def test_moe():
+    all_expert_uids = [f'ffn.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}'
+                       for _ in range(20)]
+    with background_server(expert_uids=all_expert_uids, device='cpu', expert_cls='ffn',
+                           num_handlers=1, hidden_dim=16) as (server_endpoint, dht_endpoint):
+
+        dht = hivemind.DHT(start=True, expiration=999, initial_peers=[dht_endpoint])
+        # declare expert uids. Server *should* declare them by itself, but it takes time.
+        assert all(dht.declare_experts(all_expert_uids, endpoint=server_endpoint))
+
+        dmoe = hivemind.RemoteMixtureOfExperts(
+            in_features=16, grid_size=(32, 32, 32), dht=dht, k_best=3, uid_prefix='ffn')
+
+        for i in range(10):
+            out = dmoe(torch.randn(10, 16))
+            out.sum().backward()
+
+
+def test_call_many():
     k_min = 1
     timeout_after_k_min = None
     backward_k_min = 1
-    timeout_total = None
+    forward_timeout = None
     backward_timeout = None
     rtol = 1e-3
     atol = 1e-6
 
-    xx = torch.randn(32, 1024, requires_grad=True)
-    logits = torch.randn(3, requires_grad=True)
-    random_proj = torch.randn_like(xx)
+    with background_server(num_experts=5, device='cpu', expert_cls='ffn', num_handlers=8, hidden_dim=64,
+                           no_optimizer=True, no_dht=True) as (server_endpoint, dht_endpoint):
 
-    with background_server(num_experts=num_experts, device='cpu', num_handlers=1,
+        inputs = torch.randn(4, 64, requires_grad=True)
+        inputs_clone = inputs.clone().detach().requires_grad_(True)
+        e0, e1, e2, e3, e4 = [hivemind.RemoteExpert(f'expert.{i}', server_endpoint) for i in range(5)]
+        e5 = hivemind.RemoteExpert(f'thisshouldnotexist', '127.0.0.1:80')
+
+        mask, expert_outputs = hivemind.client.moe._RemoteCallMany.apply(
+            DUMMY, [[e0, e1, e2], [e2, e4], [e1, e5, e3], []],
+            k_min, backward_k_min, timeout_after_k_min, forward_timeout, backward_timeout,
+            asyncio.new_event_loop(), inputs
+        )
+        assert mask.shape == (4, 3)
+        assert expert_outputs.shape == (4, 3, 64)
+
+        assert np.all(mask.data.numpy() == np.array([[True, True, True],
+                                                     [True, True, False],
+                                                     [True, False, True],
+                                                     [False, False, False]])), f"Incorrect mask, {mask}"
+
+        reference_outputs = torch.zeros_like(expert_outputs)
+        reference_outputs[0, 0] = e0(inputs_clone[0:1])
+        reference_outputs[0, 1] = e1(inputs_clone[0:1])
+        reference_outputs[0, 2] = e2(inputs_clone[0:1])
+        reference_outputs[1, 0] = e2(inputs_clone[1:2])
+        reference_outputs[1, 1] = e4(inputs_clone[1:2])
+        reference_outputs[2, 0] = e1(inputs_clone[2:3])
+        reference_outputs[2, 2] = e3(inputs_clone[2:3])
+
+        assert torch.allclose(expert_outputs, reference_outputs, rtol, atol)
+        proj = torch.randn(4, 64)
+        loss = (expert_outputs[(0, 1, 1, 2), (0, 2, 1, 0)] * proj).sum()
+        loss.backward()
+        our_grad = inputs.grad.data.cpu().clone()
+
+        reference_loss = (reference_outputs[(0, 1, 1, 2), (0, 2, 1, 0)] * proj).sum()
+        reference_loss.backward()
+        reference_grad = inputs_clone.grad.data.cpu().clone()
+        assert torch.allclose(our_grad, reference_grad, rtol, atol)
+
+
+def test_remote_module_call():
+    with background_server(num_experts=1, device='cpu', expert_cls='ffn', num_handlers=1, hidden_dim=1024,
                            no_optimizer=True, no_dht=True) as (server_endpoint, dht_endpoint):
-        experts = [hivemind.RemoteExpert(uid=f'expert.{i}', endpoint=server_endpoint) for i in range(num_experts)]
-        moe_output, = hivemind.client.moe._RemoteMoECall.apply(
-            logits, experts[:len(logits)], k_min, timeout_after_k_min, backward_k_min, timeout_total, backward_timeout,
-            [(None,), {}], xx)
-
-        grad_xx_moe, = torch.autograd.grad(torch.sum(random_proj * moe_output), xx, retain_graph=True)
-        grad_logits_moe, = torch.autograd.grad(torch.sum(random_proj * moe_output), logits, retain_graph=True)
-
-        # reference outputs: call all experts manually and average their outputs with softmax probabilities
-        probs = torch.softmax(logits, 0)
-        outs = [expert(xx) for expert in experts[:3]]
-        manual_output = sum(p * x for p, x in zip(probs, outs))
-        grad_xx_manual, = torch.autograd.grad(torch.sum(random_proj * manual_output), xx, retain_graph=True)
-        grad_xx_manual_rerun, = torch.autograd.grad(torch.sum(random_proj * manual_output), xx, retain_graph=True)
-        grad_logits_manual, = torch.autograd.grad(torch.sum(random_proj * manual_output), logits, retain_graph=True)
-
-    assert torch.allclose(grad_xx_manual, grad_xx_manual_rerun, rtol, atol), "Experts are non-deterministic. The test" \
-                                                                             " is only valid for deterministic experts"
-    assert torch.allclose(moe_output, manual_output, rtol, atol), "_RemoteMoECall returned incorrect output"
-    assert torch.allclose(grad_xx_moe, grad_xx_manual, rtol, atol), "incorrect gradient w.r.t. input"
-    assert torch.allclose(grad_logits_moe, grad_logits_manual, rtol, atol), "incorrect gradient w.r.t. logits"
+        real_expert = hivemind.RemoteExpert('expert.0', server_endpoint)
+        fake_expert = hivemind.RemoteExpert('oiasfjiasjf', server_endpoint)
+
+        out1 = real_expert(torch.randn(1, 1024))
+        assert out1.shape == (1, 1024)
+        dummy_x = torch.randn(3, 1024, requires_grad=True)
+        out3 = real_expert(dummy_x)
+        assert out3.shape == (3, 1024)
+        out3_again = real_expert(dummy_x[1:])
+        assert torch.allclose(out3_again, out3[1:])
+        out3_again.norm().backward()
+        assert dummy_x.grad is not None and dummy_x.grad.norm() > 0
+
+        with pytest.raises(grpc.RpcError):
+            real_expert(torch.randn(3, 11))
+        with pytest.raises(grpc.RpcError):
+            fake_expert(dummy_x)
+
+
+def test_moe_beam_search():
+    all_expert_uids = [f'ffn.{5 + i}.{10 + j}.{15 + k}' for i in range(10) for j in range(10) for k in range(10)]
+    dht = hivemind.DHT(start=True, expiration=999)
+    assert all(dht.declare_experts(all_expert_uids, endpoint='fake-endpoint'))
+
+    dmoe = hivemind.RemoteMixtureOfExperts(
+        in_features=32, grid_size=(32, 32, 32), dht=dht, k_best=4, uid_prefix='ffn')
+
+    for i in range(25):
+        input = torch.randn(32)
+        grid_scores = dmoe.proj(input).split_with_sizes(dmoe.grid_size, dim=-1)
+
+        chosen_experts = dmoe.loop.run_until_complete(dmoe.beam_search(grid_scores, k_best=dmoe.k_best))
+
+        chosen_scores = dmoe.compute_expert_scores([dim_scores[None] for dim_scores in grid_scores],
+                                                   [chosen_experts])[0]
+
+        all_scores = dmoe.compute_expert_scores([dim_scores[None] for dim_scores in grid_scores],
+                                                [[hivemind.RemoteExpert(uid, '') for uid in all_expert_uids]])[0]
+        true_best_scores = sorted(all_scores.cpu().detach().numpy(), reverse=True)[:len(chosen_experts)]
+        our_best_scores = list(chosen_scores.cpu().detach().numpy())
+        assert np.allclose(true_best_scores, our_best_scores)
 
 
 def test_determinism():

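Note on usage: the tests above drive _RemoteCallMany and beam_search directly, but the layer is meant to be used as a regular torch module. Below is a minimal sketch only, assuming experts under the 'ffn' prefix are actually being served at the declared endpoint; the uids, endpoint, dimensions and the high-level forward call are illustrative assumptions, not taken from this commit.

import torch
import hivemind

# assumed setup: a local DHT node plus experts hosted elsewhere under the 'ffn' prefix
# (uids, endpoint and dimensions are illustrative, not taken from this commit)
dht = hivemind.DHT(start=True)
expert_uids = [f'ffn.{i}.{j}.{k}' for i in range(4) for j in range(4) for k in range(4)]
dht.declare_experts(expert_uids, endpoint='127.0.0.1:1337')

dmoe = hivemind.RemoteMixtureOfExperts(in_features=64, grid_size=(4, 4, 4),
                                       dht=dht, k_best=2, uid_prefix='ffn')
x = torch.randn(2, 64, requires_grad=True)
output = dmoe(x)          # beam search over the grid, then _RemoteCallMany on the chosen experts
output.sum().backward()   # gradients flow back to x through the experts that responded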
+ 14 - 8
tests/test_utils/run_server.py

@@ -11,7 +11,7 @@ import hivemind
 from test_utils.layers import name_to_block, name_to_input
 
 
-def make_dummy_server(listen_on='0.0.0.0:*', num_experts=1, expert_cls='ffn', hidden_dim=1024,
+def make_dummy_server(listen_on='0.0.0.0:*', num_experts=None, expert_uids=None, expert_cls='ffn', hidden_dim=1024,
                       num_handlers=None, expert_prefix='expert', expert_offset=0, max_batch_size=16384, device=None,
                       no_optimizer=False, no_dht=False, initial_peers=(), dht_port=None, root_port=None, verbose=True,
                       start=False, **kwargs) -> hivemind.Server:
@@ -19,11 +19,12 @@ def make_dummy_server(listen_on='0.0.0.0:*', num_experts=1, expert_cls='ffn', hi
     Instantiate a server with several identical experts. See argparse comments below for details
     :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
     :param num_experts: run this many identical experts
+    :param expert_prefix: all expert uids will be {expert_prefix}.{index}
+    :param expert_offset: expert uid will use indices in range(expert_offset, expert_offset + num_experts)
+    :param expert_uids: spawn experts with these exact uids, overrides num_experts, expert_prefix and expert_offset
     :param expert_cls: expert type from test_utils.layers, e.g. 'ffn', 'transformer', 'det_dropout' or 'nop';
     :param hidden_dim: main dimension for expert_cls
     :param num_handlers: server will use this many parallel processes to handle incoming requests
-    :param expert_prefix: all expert uids will be {expert_prefix}.{index}
-    :param expert_offset: expert uid will use indices in range(expert_offset, expert_offset + num_experts)
     :param max_batch_size: total num examples in the same batch will not exceed this value
     :param device: all experts will use this device in torch notation; default: cuda if available else cpu
     :param no_optimizer: if specified, all optimizers use learning rate=0
@@ -36,6 +37,8 @@ def make_dummy_server(listen_on='0.0.0.0:*', num_experts=1, expert_cls='ffn', hi
     :param verbose: whether to print server started / finished / terminated events
     :param start: if True, starts server right away and returns when server is ready for requests
     """
+    assert (expert_uids is None) != (num_experts is None and expert_prefix == 'expert' and expert_offset == 0), \
+        "Please provide either expert uids *or* (num_experts, expert_prefix and expert_offset), not both"
     if verbose and len(kwargs) != 0:
         print("Ignored kwargs:", kwargs)
     assert expert_cls in name_to_block
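A quick sketch of the two call styles the updated signature is meant to support; the argument values and the import path are illustrative assumptions.

from test_utils.run_server import make_dummy_server

# style 1: auto-generated uids expert.0 ... expert.4
server_a = make_dummy_server(num_experts=5, expert_cls='ffn', hidden_dim=64)

# style 2: explicit uids, overriding num_experts / expert_prefix / expert_offset
server_b = make_dummy_server(expert_uids=['ffn.5.10.15', 'ffn.6.11.16'],
                             expert_cls='ffn', hidden_dim=64)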
@@ -68,11 +71,15 @@ def make_dummy_server(listen_on='0.0.0.0:*', num_experts=1, expert_cls='ffn', hi
         args_schema = (hivemind.BatchTensorDescriptor.from_tensor(sample_input),)
 
     # initialize experts
+    if expert_uids is None:
+        num_experts = num_experts if num_experts is not None else 1
+        expert_uids = [f'{expert_prefix}{hivemind.DHT.UID_DELIMITER}{i + expert_offset}'
+                       for i in range(num_experts)]
+
     experts = {}
-    for i in range(num_experts):
+    for expert_uid in expert_uids:
         expert = name_to_block[expert_cls](hidden_dim)
         opt = torch.optim.SGD(expert.parameters(), 0.0 if no_optimizer else 0.05)
-        expert_uid = f'{expert_prefix}{hivemind.DHT.UID_DELIMITER}{i + expert_offset}'
         experts[expert_uid] = hivemind.ExpertBackend(name=expert_uid, expert=expert, opt=opt,
                                                      args_schema=args_schema,
                                                      outputs_schema=hivemind.BatchTensorDescriptor(hidden_dim),
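For reference, the uid construction in the hunk above yields names like those used throughout the tests; a standalone sketch, assuming hivemind.DHT.UID_DELIMITER is '.', consistent with uids such as expert.0 elsewhere in this commit.

expert_prefix, expert_offset, num_experts = 'expert', 0, 3
uid_delimiter = '.'  # stands in for hivemind.DHT.UID_DELIMITER, assumed to be '.' here
expert_uids = [f'{expert_prefix}{uid_delimiter}{i + expert_offset}' for i in range(num_experts)]
assert expert_uids == ['expert.0', 'expert.1', 'expert.2']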
@@ -108,7 +115,7 @@ def background_server(*args, shutdown_timeout=5, verbose=True, **kwargs) -> Tupl
         finally:
             if verbose:
                 print("Server failed to shutdown gracefully, terminating it the hard way...")
-            runner.terminate()
+            runner.kill()
             if verbose:
                 print("Server terminated.")
 
@@ -132,9 +139,8 @@ def _server_runner(pipe, *args, verbose, **kwargs):
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('--interface', type=str, default='0.0.0.0', required=False,
+    parser.add_argument('--listen_on', type=str, default='0.0.0.0:*', required=False,
                         help="'localhost' for local connections only, '0.0.0.0' for ipv4 '::' for ipv6")
-    parser.add_argument('--port', type=int, default=None, required=False, help="server will listen to this port")
     parser.add_argument('--num_experts', type=int, default=1, required=False, help="run this many identical experts")
     parser.add_argument('--expert_cls', type=str, default='ffn', required=False,
                         help="expert type from test_utils.layers, e.g. 'ffn', 'transformer', 'det_dropout' or 'nop'.")