
RemoteMixtureOfExperts update part II (#80)

* strict msgpack types

* strict msgpack types

* wip: RemoteCallMany without utils.autograd (still need to update and test RemoteMixtureOfExperts!)

* wip: RemoteCallMany without utils.autograd (still need to update and test RemoteMixtureOfExperts!)

* wip: RemoteCallMany without utils.autograd (still need to update and test RemoteMixtureOfExperts!)

* implement and test RemoteMixtureOfExperts

* numpy -> torch

* order preference

* preferred test order

* split tests into test_remote_expert.py and test_moe.py

* reduce test_moe size

* move test_moe to the top

* magical test order

* Update tests/test_dht.py

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>

* review {1337} => 1337

* .data.numpy => .detach.numpy

* review: full_like => full((1,))

* review: less hacky multiply

* unused import

* review: specify key order in docstring of first_k_active

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic 5 years ago
parent
current commit
6d2b8094c9
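Two of the review bullets above (".data.numpy => .detach.numpy" and "review: full_like => full((1,))") are general PyTorch idioms rather than hivemind-specific changes. A small illustrative sketch of both; the tensor names below are made up for the example and do not come from the diff:

import torch

logits = torch.randn(4, 3, requires_grad=True)

# ".data.numpy => .detach.numpy": detach() is the supported way to drop autograd tracking
scores = logits.detach().cpu().numpy()

# "full_like => full((1,))": a one-element fill value broadcasts inside torch.where,
# so there is no need to allocate a full-sized tensor of -inf
keep = torch.tensor([[True, True, False],
                     [True, False, True],
                     [False, True, True],
                     [True, True, True]])
neg_inf = torch.full((1,), float('-inf'), dtype=logits.dtype)
masked = torch.where(keep, logits, neg_inf)
weights = torch.softmax(masked, dim=-1)  # masked-out entries get zero weight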

+ 17 - 14
hivemind/client/expert.py

@@ -1,5 +1,6 @@
 import pickle
-from typing import Tuple, Optional
+from functools import lru_cache
+from typing import Tuple, Optional, Any

 import grpc
 import grpc.experimental.aio
@@ -13,6 +14,19 @@ from hivemind.utils.grpc import serialize_torch_tensor, deserialize_torch_tensor
 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert


+@lru_cache(maxsize=None)
+def _get_expert_stub(endpoint: Endpoint, aio: bool, *extra_options: Tuple[str, Any]):
+    """ Create a gRPC stub to access remote expert or use previously created stub from a process-wide cache """
+    channel_options = [
+        ('grpc.max_send_message_length', -1), ('grpc.max_receive_message_length', -1)
+    ] + list(extra_options)
+    if aio:
+        channel = grpc.experimental.aio.insecure_channel(endpoint, options=channel_options)
+    else:
+        channel = grpc.insecure_channel(endpoint, options=channel_options)
+    return runtime_grpc.ConnectionHandlerStub(channel)
+
+
 class RemoteExpert(nn.Module):
     """
     A simple module that runs forward/backward of an expert hosted on a remote machine.
@@ -28,22 +42,11 @@ class RemoteExpert(nn.Module):
     def __init__(self, uid, endpoint: Endpoint):
         super().__init__()
         self.uid, self.endpoint = uid, endpoint
-        self._channel, self._stub, self._info = None, None, None
+        self._info = None

     @property
     def stub(self):
-        if self._channel is None:
-            self._channel = grpc.insecure_channel(self.endpoint, options=[
-                ('grpc.max_send_message_length', -1),
-                ('grpc.max_receive_message_length', -1)
-            ])
-        if self._stub is None:
-            self._stub = runtime_grpc.ConnectionHandlerStub(self._channel)
-        return self._stub
-
-    def __del__(self):
-        if self._channel is not None:
-            self._channel.close()
+        return _get_expert_stub(self.endpoint, aio=False)

     def forward(self, *args, **kwargs):
         """ Call RemoteExpert for the specified inputs and return its output(s). Compatible with pytorch.autograd. """

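The change above replaces per-RemoteExpert channel bookkeeping (and the __del__ that closed it) with a process-wide cache keyed by (endpoint, aio): every RemoteExpert pointing at the same endpoint now shares one gRPC channel and stub. A minimal sketch of the same lru_cache pattern with a stand-in factory instead of the real gRPC calls; make_channel below is illustrative, not a hivemind function:

from functools import lru_cache

@lru_cache(maxsize=None)
def make_channel(endpoint: str, aio: bool):
    # stand-in for grpc.insecure_channel / grpc.experimental.aio.insecure_channel
    print(f"opening channel to {endpoint} (aio={aio})")
    return object()

first = make_channel("127.0.0.1:1337", False)
second = make_channel("127.0.0.1:1337", False)  # cache hit: nothing is printed, the same object is returned
assert first is second

One consequence (presumably why __del__ was dropped) is that cached channels now live for the lifetime of the process instead of being closed when an individual RemoteExpert is garbage-collected.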
+ 206 - 135
hivemind/client/moe.py

@@ -1,14 +1,20 @@
-from functools import partial
-from typing import Tuple, List, Optional
+from __future__ import annotations
+import time
+import asyncio
+from typing import Tuple, List, Optional, Awaitable, Set, Dict

-import numpy as np
 import torch
 import torch.nn as nn
 from torch.autograd.function import once_differentiable
+import grpc.experimental.aio

-from hivemind.client.expert import RemoteExpert, _RemoteModuleCall, DUMMY
-from hivemind.utils import nested_map, run_and_await_k, nested_pack, nested_flatten, run_in_background, \
-    run_isolated_forward, EmulatedAutogradContext, run_isolated_backward, map_with_parallel_backward
+import hivemind
+from hivemind.client.expert import RemoteExpert, DUMMY, _get_expert_stub
+from hivemind.utils import nested_map, nested_pack, nested_flatten, runtime_grpc, runtime_pb2, \
+    serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.utils.logging import get_logger
+
+logger = get_logger(__name__)


 class RemoteMixtureOfExperts(nn.Module):
@@ -25,30 +31,31 @@ class RemoteMixtureOfExperts(nn.Module):
     :param uid_prefix: common prefix for all expert uids
      expert uid follows the pattern {uid_prefix}.{0...grid_size[0]}.{0...grid_size[1]}...{0...grid_size[-1]}
     :param dht: DHT where the experts reside
-    :param num_workers: number of threads for parallel dht operation
     :param k_best: queries this many experts with highest scores
     :param k_min: makes sure at least this many experts returned output
     :param timeout_after_k_min: waits for this many seconds after k_min experts returned results.
      Any expert that didn't manage to return output after that delay is considered unavailable
-    :param expert_padding: internal value used to denote "absent expert". Should not coincide with any expert uid.
     :param allow_broadcasting: if RemoteMixtureOfExperts is fed with input dimension above 2,
      allow_broadcasting=True will flatten first d-1 input dimensions, apply RemoteMixtureOfExperts and un-flatten again
      allow_broadcasting=False will raise an error
     """

-    def __init__(self, *, in_features, grid_size: Tuple[int], dht, k_best, k_min=1,
-                 forward_timeout=None, timeout_after_k_min=1.0, backward_k_min=1, backward_timeout=None,
-                 uid_prefix='', expert_padding=None, allow_broadcasting=True):
+    def __init__(self, *, in_features, grid_size: Tuple[int, ...], dht: hivemind.DHT, k_best: int, k_min: int = 1,
+                 forward_timeout: Optional[float] = None, timeout_after_k_min: Optional[float] = None,
+                 backward_k_min: int = 1, backward_timeout: Optional[float] = None, uid_prefix='',
+                 allow_broadcasting=True, loop: asyncio.BaseEventLoop = None):
         super().__init__()
-        self.dht, self.grid_size = dht, grid_size
-        self.uid_prefix, self.expert_padding = uid_prefix, expert_padding
+        self.dht, self.grid_size, self.uid_prefix = dht, grid_size, uid_prefix
+        self.loop = loop or asyncio.new_event_loop()
+        assert not self.loop.is_running(), "Event loop is already running. If in jupyter, please apply nest_asyncio " \
+            "(pip install nest_asyncio , https://pypi.org/project/nest-asyncio ) and send loop=asyncio.new_event_loop()"
         self.k_best, self.k_min, self.backward_k_min = k_best, k_min, backward_k_min
         self.forward_timeout, self.backward_timeout = forward_timeout, backward_timeout
         self.timeout_after_k_min = timeout_after_k_min
         self.allow_broadcasting = allow_broadcasting

         self.proj = nn.Linear(in_features, sum(grid_size))  # jointly predict logits for all grid dimensions
-        self._outputs_schema = None
+        self._outputs_schema = None  # expert['info'][outputs_schema] from one of experts in the grid

     def forward(self, input: torch.Tensor, *args: torch.Tensor, **kwargs: torch.Tensor):
         """
@@ -69,25 +76,31 @@ class RemoteMixtureOfExperts(nn.Module):
 
 
         # 1. compute scores and find most appropriate experts with beam search
         grid_scores = self.proj(input).split_with_sizes(self.grid_size, dim=-1)
-        chosen_experts = self.beam_search(grid_scores, self.k_best)
-        # ^-- List[batch_size] of List[RemoteExpert] chosen for every input in batch

-        expert_logits = self.compute_expert_scores(grid_scores, chosen_experts)
+        async def _search():
+            coroutines = [asyncio.create_task(self.beam_search(
+                [dim_scores[i] for dim_scores in grid_scores], self.k_best))
+                for i in range(len(input))]
+            return list(await asyncio.gather(*coroutines))

-        expert_inputs = ((input, *args), kwargs)
-        input_schema = nested_map(lambda x: None, expert_inputs)
-        flat_inputs_per_expert = tuple(zip(*[tensor.split(1, dim=0) for tensor in nested_flatten(expert_inputs)]))
+        chosen_experts: List[List[RemoteExpert]] = self.loop.run_until_complete(_search())
+        # ^-- List[batch_size] of List[RemoteExpert] chosen for every input in batch

-        batch_jobs_args = tuple(
-            (expert_logits[i, :len(chosen_experts[i])], chosen_experts[i], self.k_min, self.timeout_after_k_min,
-             self.backward_k_min, self.forward_timeout, self.backward_timeout, input_schema, *flat_inputs_per_expert[i])
-            for i in range(len(input))
-        )
+        expert_mask, *expert_outputs = _RemoteCallMany.apply(
+            DUMMY, chosen_experts, self.k_min, self.backward_k_min, self.timeout_after_k_min,
+            self.forward_timeout, self.backward_timeout, self.loop, *nested_flatten(((input, *args), kwargs)))
+        # ^-- multiple tensors of shape [batch_size, max_experts, ...output_shape]

-        averaged_outputs_flat = map(torch.cat, zip(*map_with_parallel_backward(_RemoteMoECall, *batch_jobs_args)))
+        expert_logits = self.compute_expert_scores(grid_scores, chosen_experts)
+        masked_logits = torch.full((1,), float('-inf'), device=expert_logits.device, dtype=expert_logits.dtype)
+        expert_logits = torch.where(expert_mask, expert_logits, masked_logits)
+        expert_weights = torch.softmax(expert_logits, dim=1)
+        averaged_outputs_flat = [
+            (expert_weights[..., None] * tensor.flatten(start_dim=2)).view(tensor.shape).sum(dim=1)
+            for tensor in expert_outputs]  # ^-- multiply by softmax weights along first 2 axes
         return nested_pack(averaged_outputs_flat, self.outputs_schema)

-    def beam_search(self, grid_scores: List[torch.Tensor], k_best: int, **kwargs) -> List[List[RemoteExpert]]:
+    async def beam_search(self, grid_scores: List[torch.Tensor], k_best: int, **kwargs) -> List[RemoteExpert]:
         """
         Find and return k best experts in the grid using (exact) beam search of the product space
 
 
@@ -99,51 +112,39 @@ class RemoteMixtureOfExperts(nn.Module):
          RemoteExpert instances for *up to* k_best experts
         """
         assert len(grid_scores) == len(self.grid_size)
-        assert all(len(dim_scores.shape) == 2 for dim_scores in grid_scores)
-        batch_size = len(grid_scores[0])
-        beam = np.array([[self.uid_prefix]] * batch_size, dtype=object)  # [batch_size, up_to_beam_size]
-        scores = np.zeros([batch_size, 1], dtype=np.float64)
+        assert all(dim_scores.shape == (self.grid_size[dim_index],) for dim_index, dim_scores in enumerate(grid_scores))
+        grid_scores = [dim_scores.cpu().detach() for dim_scores in grid_scores]
 
 
-        delimiters = np.array(self.dht.UID_DELIMITER)[None, None, None]  # pre-compute numpy array for fast concat
+        beam_experts: List[RemoteExpert] = []
+        beam: List[str] = [self.uid_prefix]
+        beam_scores = torch.zeros(1)

         for dim_index, dim_scores in enumerate(grid_scores):
-            dim_scores = dim_scores.detach().cpu().numpy()
-            assert dim_scores.shape[-1] == self.grid_size[dim_index]
-
-            # create all possible successsors from current beam
-            dim_indices = np.arange(dim_scores.shape[1]).astype(str)
-            new_candidates = beam[:, :, None] + delimiters + dim_indices[None, None, :]
-            new_candidates = new_candidates.reshape([batch_size, -1])
+            # create all possible successors from current beam and sort them by total score
+            expanded_scores = beam_scores[:, None] + dim_scores[None, :]
+            sorted_indices = [(flat_i // len(dim_scores), flat_i % len(dim_scores))
+                              for flat_i in (-expanded_scores).flatten().argsort().numpy()]

-            new_scores = scores[:, :, None] + dim_scores[:, None, :]
-            new_scores = new_scores.reshape([batch_size, -1])
+            sorted_candidates = [f"{beam[row]}{self.dht.UID_DELIMITER}{col}" for row, col in sorted_indices]
+            candidate_to_indices = dict(zip(sorted_candidates, sorted_indices))

             # select k best candidates according to scores but only those that are still active
-            new_order = np.argsort(- new_scores, axis=-1)
-            top_alive_lookups = [
-                run_in_background(self.dht.first_k_active, cands[order], k_best, **kwargs)
-                for cands, order in zip(new_candidates, new_order)]
-
-            batch_cand_to_score = [
-                dict(zip(cands, cand_scores)) for cands, cand_scores in zip(new_candidates, new_scores)]
-
-            top_alive_prefixes = [result.result() for result in top_alive_lookups]
-            top_alive_scores = [list(map(cand_to_score.get, top_cands))
-                                for cand_to_score, top_cands in zip(batch_cand_to_score, top_alive_prefixes)]
-
-            # pad up to beam size
-            beam = np.array([row + [self.expert_padding] * (k_best - len(row))
-                             for row in top_alive_prefixes], dtype='object')
-            scores = np.array([row + [-float('inf')] * (k_best - len(row))
-                               for row in top_alive_scores], dtype='float32')
-
-        unique_experts = self.dht.get_experts(list(set(
-            uid for row in beam for uid in row if uid != self.expert_padding)))
+            best_alive_prefixes: Dict[str, RemoteExpert] = await self.dht.first_k_active(
+                uid_prefixes=sorted_candidates, k=k_best, return_future=True, **kwargs)
+            if not best_alive_prefixes:
+                logger.warning(f"Grid is empty: found neither of {sorted_candidates}")
+                break
+            beam = list(best_alive_prefixes.keys())
+            beam_scores = expanded_scores[tuple(zip(*map(candidate_to_indices.get, beam)))]
+            beam_experts = list(best_alive_prefixes.values())
+
         if self._outputs_schema is None:
-            self._outputs_schema = next(iter(unique_experts)).info['outputs_schema']
-        unique_experts_by_uid = {expert.uid: expert for expert in unique_experts if expert != self.expert_padding}
+            try:
+                self._outputs_schema = beam_experts[0].info['outputs_schema']
+            except grpc.RpcError as e:
+                logger.warning(f"Failed to get RemoteMixtureOfExperts.output_shape: {e}")

-        return [[unique_experts_by_uid[uid] for uid in row if uid in unique_experts_by_uid] for row in beam]
+        return beam_experts

     def compute_expert_scores(
             self, grid_scores: List[torch.Tensor], batch_experts: List[List[RemoteExpert]]) -> torch.Tensor:
@@ -164,11 +165,11 @@ class RemoteMixtureOfExperts(nn.Module):
         flat_local_indices = expert_index_in_batch - expert_strides[flat_batch_indices]
         flat_experts = [expert for row in batch_experts for expert in row]

-        grid_indices = np.zeros([len(flat_experts), len(grid_scores)], dtype=np.int64)
+        grid_indices = torch.zeros([len(flat_experts), len(grid_scores)], dtype=torch.int64)
         for i, expert in enumerate(flat_experts):
             expert_indices = expert.uid[len(self.uid_prefix) + len(self.dht.UID_DELIMITER):]
             expert_indices = list(map(int, expert_indices.split(self.dht.UID_DELIMITER)))
-            grid_indices[i] = expert_indices
+            grid_indices[i] = torch.as_tensor(expert_indices, dtype=grid_indices.dtype)
 
 
         scores_per_dim = [
             dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0)
@@ -183,86 +184,156 @@ class RemoteMixtureOfExperts(nn.Module):
     def outputs_schema(self):
         if self._outputs_schema is None:
             # grab some expert to set ensemble output shape
-            dummy_scores = self.proj(torch.randn(1, self.proj.in_features)).split_with_sizes(self.grid_size, dim=-1)
-            self._outputs_schema = self.beam_search(dummy_scores, k_best=1)[0][0].info['outputs_schema']
+            dummy_scores = self.proj(torch.randn(self.proj.in_features)).cpu().split_with_sizes(self.grid_size, dim=-1)
+            dummy_experts = self.loop.run_until_complete(self.beam_search(dummy_scores, k_best=1))
+            self._outputs_schema = dummy_experts[0].info['outputs_schema']
         return self._outputs_schema


-class _RemoteMoECall(torch.autograd.Function):
+class _RemoteCallMany(torch.autograd.Function):
     """
-    Internal autograd-friendly function that calls multiple experts on the same input and averages their outputs.
-    This function that can recover from individual failures during forward and/or backward passes.
-    For user-friendly version of this function, use RemoteMixtureOfExperts module.
+    Internal autograd-friendly function that calls multiple experts on a batch of inputs and awaits responses.
+    This function can recover from individual failures during forward and/or backward passes as long as at least
+    one expert succeeds for each input. For a user-friendly version of this function, use the RemoteMixtureOfExperts module.
+
+    Note: experts that failed during forward will be assigned zero outputs and marked as mask[i, j] = 0,
+          experts that failed during backward will be treated as constants (i.e. gradients through them are zeros)
     """

     @classmethod
-    def forward(cls, ctx, expert_logits: torch.Tensor, experts: List[RemoteExpert],
-                k_min: int, timeout_after_k_min: float, backward_k_min: int, timeout_total: Optional[float],
-                backward_timeout: Optional[float], input_schema, *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
-        expert_args, expert_kwargs = nested_pack(flat_inputs, structure=input_schema)
-        assert expert_logits.ndim == 1 and len(expert_logits) == len(experts)
-
-        # 1. call experts and await results
-        jobs = [partial(cls._run_expert_forward, expert, *expert_args, **expert_kwargs) for expert in experts]
-        results = run_and_await_k(jobs, k=k_min, timeout_after_k=timeout_after_k_min, timeout_total=timeout_total)
-
-        alive_contexts, alive_outputs, alive_ix = zip(*[(result[0], result[1], ix) for ix, result in enumerate(results)
-                                                        if not isinstance(result, BaseException)])
-        #     ^               ^            ^-- a list of indices of experts that returned outputs in time
-        #      \               \-- list of outputs of every expert that didn't die on us
-        #       \-- a list of autograd contexts, used for parallel backward
-
-        # 2. compute softmax weights for alive experts and average outputs
-        alive_ix = torch.as_tensor(alive_ix, device=expert_logits.device)
-        alive_expert_probs = torch.softmax(expert_logits[alive_ix], dim=0)
-
-        stacked_alive_outputs = tuple(map(torch.stack, zip(*alive_outputs)))
-
-        flat_average_outputs = tuple((alive_expert_probs @ stacked_out.flatten(1)).view(*stacked_out.shape[1:])
-                                     for stacked_out in stacked_alive_outputs)
-
-        # 3. save individual outputs for backward pass
-        ctx.save_for_backward(expert_logits, alive_ix, alive_expert_probs, *stacked_alive_outputs)
-        ctx._saved_non_tensors = alive_contexts, backward_k_min, backward_timeout
-        return tuple(map(torch.Tensor.detach, flat_average_outputs))
+    def forward(cls, ctx, dummy, experts_per_sample: List[List[RemoteExpert]], k_min: int, backward_k_min: int,
+                timeout_after_k_min: float, forward_timeout: Optional[float], backward_timeout: Optional[float],
+                loop: asyncio.base_events.BaseEventLoop, *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
+        assert not torch.is_grad_enabled()
+        num_samples, max_experts = len(experts_per_sample), max(map(len, experts_per_sample))
+        flat_inputs_per_sample: List[Tuple[torch.Tensor, ...]] = list(zip(*(x.split(1, dim=0) for x in flat_inputs)))
+        assert len(experts_per_sample) == len(flat_inputs_per_sample) == num_samples
+
+        async def _forward():
+            # dispatch tasks to all remote experts, await responses
+            pending_tasks = {
+                asyncio.create_task(cls._forward_one_expert((i, j), expert, flat_inputs_per_sample[i]))
+                for i in range(num_samples) for j, expert in enumerate(experts_per_sample[i])
+            }
+            alive_grid_indices, alive_flat_outputs = await cls._wait_for_responses(
+                pending_tasks, num_samples, k_min, forward_timeout, timeout_after_k_min)
+
+            # assemble responses
+            alive_ii, alive_jj = map(torch.as_tensor, zip(*alive_grid_indices))
+            mask = torch.zeros([num_samples, max_experts], dtype=torch.bool, device=flat_inputs[0].device)
+            mask[alive_ii, alive_jj] = True
+
+            alive_flat_outputs_stacked = list(map(torch.cat, zip(*alive_flat_outputs)))
+            # list of torch tensors, where i-th tensor is of shape [num_responded, *expert_outputs[i].shape]
+
+            outputs = []
+            for response_stacked in alive_flat_outputs_stacked:
+                output = torch.zeros(
+                    [num_samples, max_experts, *response_stacked.shape[1:]], device=response_stacked.device,
+                    dtype=response_stacked.dtype, requires_grad=response_stacked.requires_grad)
+                output[alive_ii, alive_jj] = response_stacked
+                outputs.append(output)
+
+            # save individual outputs for backward pass
+            ctx.save_for_backward(alive_ii, alive_jj, *flat_inputs)
+            ctx._saved_non_tensors = loop, backward_k_min, backward_timeout, timeout_after_k_min, experts_per_sample
+            return (mask,) + tuple(outputs)
+
+        return loop.run_until_complete(_forward())

     @classmethod
     @once_differentiable
-    def backward(cls, ctx, *grad_outputs_flat: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
-        """ Like normal backward, but we ignore any experts that failed during backward pass """
-        expert_logits, alive_ix, alive_expert_probas, *stacked_alive_outputs = ctx.saved_tensors
-        alive_contexts, backward_k_min, backward_timeout = ctx._saved_non_tensors
-
-        jobs = [partial(cls._run_expert_backward, ctx, prob, *grad_outputs_flat)
-                for ctx, prob in zip(alive_contexts, alive_expert_probas.split(1))]
-        results = run_and_await_k(jobs, k=backward_k_min, timeout_after_k=backward_timeout, timeout_total=None)
-        backward_survivors_in_alive_ix, survived_grad_inputs = zip(*((i, grads) for i, grads in enumerate(results)))
-        backward_survivors_in_alive_ix = torch.as_tensor(backward_survivors_in_alive_ix, device=expert_logits.device)
-        backward_survivors_ix = alive_ix[backward_survivors_in_alive_ix]
-        survived_probas = torch.softmax(expert_logits[backward_survivors_ix], dim=0)
-        weight_ratios = survived_probas / alive_expert_probas[backward_survivors_in_alive_ix]
-        flat_grad_inputs = tuple((weight_ratios @ stacked_grad_inp.flatten(1)).view(stacked_grad_inp.shape[1:])
-                                 for stacked_grad_inp in map(torch.stack, zip(*survived_grad_inputs)))
-
-        # compute grad w.r.t. logits
-        grad_wrt_probs = sum(tuple(
-            torch.sum(grad_out[None, ...] * stacked_avive_out[backward_survivors_in_alive_ix],
-                      dim=tuple(range(1, stacked_avive_out.ndim)))
-            for grad_out, stacked_avive_out in zip(grad_outputs_flat, stacked_alive_outputs)
-        ))
-        softmax_jacobian = torch.diagflat(survived_probas) - torch.ger(survived_probas, survived_probas)
-        grad_wrt_survived_logits = grad_wrt_probs @ softmax_jacobian
-        grad_wrt_logits = torch.zeros_like(expert_logits).scatter(0, backward_survivors_ix, grad_wrt_survived_logits)
-
-        return (grad_wrt_logits, None, None, None, None, None, None, None, *flat_grad_inputs)
+    def backward(cls, ctx, *raw_grads: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
+        assert not torch.is_grad_enabled()
+        loop, backward_k_min, backward_timeout, timeout_after_k_min, expert_per_sample = ctx._saved_non_tensors
+        alive_ii, alive_jj, *flat_inputs = ctx.saved_tensors
+        dummy_grad_mask, *flat_grad_outputs = raw_grads
+        num_samples, max_experts = dummy_grad_mask.shape
+
+        inputs_per_expert = zip(*(tensor[alive_ii].split(1, dim=0) for tensor in flat_inputs))
+        grad_outputs_per_expert = zip(*(tensor[alive_ii, alive_jj].split(1, dim=0) for tensor in flat_grad_outputs))
+
+        async def _backward():
+            # dispatch tasks to all remote experts, await responses
+            pending_tasks = set()
+            for i, j, inputs_ij, grad_outputs_ij in zip(alive_ii.cpu().numpy(), alive_jj.cpu().numpy(),
+                                                        inputs_per_expert, grad_outputs_per_expert):
+                pending_tasks.add(asyncio.create_task(
+                    cls._backward_one_expert((i, j), expert_per_sample[i.item()][j.item()], inputs_ij, grad_outputs_ij)
+                ))
+
+            backward_survivor_indices, survivor_grad_inputs = await cls._wait_for_responses(
+                pending_tasks, num_samples, backward_k_min, backward_timeout, timeout_after_k_min)
+
+            # assemble responses
+            backward_survivor_ii, backward_survivor_jj = map(torch.as_tensor, zip(*backward_survivor_indices))
+            survivor_grad_inputs_stacked = list(map(torch.cat, zip(*survivor_grad_inputs)))
+            # list of torch tensors, where i-th tensor is of shape [num_backward_survivors, *flat_inputs[i].shape]
+
+            grad_inputs = []
+            for i, survivor_grad_stacked in enumerate(survivor_grad_inputs_stacked):
+                grad_input_per_expert = torch.zeros(  # gradient tensor with individual contributions from each expert
+                    (num_samples, max_experts, *flat_inputs[i].shape[1:]),
+                    device=survivor_grad_stacked.device, dtype=survivor_grad_stacked.dtype)
+                grad_input_per_expert[backward_survivor_ii, backward_survivor_jj] = survivor_grad_stacked
+
+                grad_inputs.append(grad_input_per_expert.sum(dim=1))  # add up gradients from each expert
+
+            return (DUMMY, None, None, None, None, None, None, None, *grad_inputs)
+        return loop.run_until_complete(_backward())
+
+    @staticmethod
+    async def _forward_one_expert(grid_indices: Tuple[int, ...], expert: RemoteExpert, inputs: Tuple[torch.Tensor]):
+        stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint, aio=True)
+        try:
+            outputs = await stub.forward(runtime_pb2.ExpertRequest(
+                uid=expert.uid, tensors=[serialize_torch_tensor(tensor) for tensor in inputs]))
+            return grid_indices, tuple(deserialize_torch_tensor(tensor) for tensor in outputs.tensors)
+        except grpc.experimental.aio.AioRpcError as error:
+            logger.warning(f"RemoteExpert {expert} failed forward: {error.code()} (inputs: {inputs})")
 
 
     @staticmethod
     @staticmethod
-    def _run_expert_forward(expert: RemoteExpert, *args: torch.Tensor, **kwargs: torch.Tensor):
-        """ Call remote expert and return flattened outputs. Compatible with concurrent autograd. """
-        return run_isolated_forward(_RemoteModuleCall, DUMMY, expert.uid, expert.stub, *nested_flatten((args, kwargs)))
+    async def _backward_one_expert(grid_indices: Tuple[int, ...], expert: RemoteExpert,
+                                   inputs: Tuple[torch.Tensor], grad_outputs: Tuple[torch.Tensor]):
+        stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint, aio=True)
+        payload = tuple(nested_flatten((inputs, grad_outputs)))
+        try:
+            grad_inputs = await stub.backward(runtime_pb2.ExpertRequest(
+                uid=expert.uid, tensors=[serialize_torch_tensor(tensor) for tensor in payload]))
+            return grid_indices, tuple(deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors)
+        except grpc.experimental.aio.AioRpcError as error:
+            logger.warning(f"RemoteExpert {expert} failed backward: {error.code()} ({inputs}, {grad_outputs})")

     @staticmethod
-    def _run_expert_backward(ctx: EmulatedAutogradContext, weight: torch.Tensor, *grad_outputs: torch.Tensor):
-        backward_result = run_isolated_backward(_RemoteModuleCall, ctx, *(grad * weight for grad in grad_outputs))
-        grad_dummy, no_grad_uid, no_grad_stub, *grad_inputs = backward_result
-        return grad_inputs
+    async def _wait_for_responses(
+            pending_tasks: Set[Awaitable[Tuple[Tuple[int, int], Tuple[torch.Tensor, ...]]]],
+            num_samples: int, k_min: int, timeout_total: Optional[float], timeout_after_k_min: Optional[float]
+            ) -> Tuple[List[Tuple[int, int]], List[Tuple[torch.Tensor, ...]]]:
+        """ await up to k_min results and any result submitted within timeout_after_k_min, cancel stragglers """
+        timeout_total = float('inf') if timeout_total is None else timeout_total
+        timeout_after_k_min = float('inf') if timeout_after_k_min is None else timeout_after_k_min
+        num_successful_tasks = [0 for _ in range(num_samples)]
+        pending_samples = num_samples  # samples for which we have less than k_min results
+        finished_indices, finished_outputs = [], []
+        t_finish = time.perf_counter() + timeout_total
+
+        while pending_tasks and time.perf_counter() <= t_finish:
+            finished_tasks, pending_tasks = await asyncio.wait(pending_tasks, return_when=asyncio.FIRST_COMPLETED,
+                                                               timeout=t_finish - time.perf_counter())
+            for task in finished_tasks:
+                if not task.result():
+                    continue
+                task_indices, task_flat_outputs = await task
+                finished_indices.append(task_indices)
+                finished_outputs.append(task_flat_outputs)
+
+                sample_index = task_indices[0]
+                num_successful_tasks[sample_index] += 1
+                if num_successful_tasks[sample_index] == k_min:
+                    pending_samples -= 1
+                    if pending_samples <= 0:  # all tasks finished, await stragglers for at most timeout_after_k_min
+                        t_finish = min(t_finish, time.perf_counter() + timeout_after_k_min)
+
+        for task in pending_tasks:
+            task.cancel()
+        return finished_indices, finished_outputs

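In the new forward path above, _RemoteCallMany returns a boolean mask of experts that actually responded plus zero-filled output slots for the ones that did not; the mixture then masks the corresponding logits with -inf so failed experts receive zero softmax weight. A self-contained sketch of that weighting step with made-up sizes (2 samples, up to 3 experts, output dimension 4):

import torch

batch, max_experts, dim = 2, 3, 4
expert_logits = torch.randn(batch, max_experts)
expert_mask = torch.tensor([[True, True, False],
                            [True, False, False]])     # which experts returned an output
expert_output = torch.randn(batch, max_experts, dim)   # rows of failed experts are zeros in the real code

masked_logits = torch.where(expert_mask, expert_logits,
                            torch.full((1,), float('-inf'), dtype=expert_logits.dtype))
expert_weights = torch.softmax(masked_logits, dim=1)   # failed experts get weight 0
averaged = (expert_weights.unsqueeze(-1) * expert_output).sum(dim=1)
# the real forward() first flattens trailing output dims, so the same line works for any output shape
assert averaged.shape == (batch, dim)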
+ 25 - 19
hivemind/dht/__init__.py

@@ -16,9 +16,9 @@ import asyncio
 import ctypes
 import multiprocessing as mp
 import warnings
-from collections import deque
+from collections import deque, OrderedDict
 from concurrent.futures import ThreadPoolExecutor
-from typing import List, Optional, Sequence
+from typing import List, Tuple, Optional, Sequence, OrderedDict as TOrderedDict, Union, Awaitable

 import uvloop
 
 
@@ -126,17 +126,17 @@ class DHT(mp.Process):
         return self._port.value if self._port.value != 0 else None

     def get_experts(self, uids: List[str], expiration_time: Optional[DHTExpiration] = None,
-                    wait=True) -> List[Optional[RemoteExpert]]:
+                    return_future=False) -> List[Optional[RemoteExpert]]:
         """
         :param uids: find experts with these ids from across the DHT
         :param expiration_time: if specified, return experts that expire no sooner than this (based on get_dht_time)
-        :param wait: if True (default), return when experts are returned. Otherwise return a Future.
+        :param return_future: if False (default), return when experts are returned. Otherwise return MPFuture.
         :returns: a list of [RemoteExpert if found else None]
         """
         assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
         future, _future = MPFuture.make_pair()
         self.pipe.send(('_get_experts', [], dict(uids=uids, expiration_time=expiration_time, future=_future)))
-        return future.result() if wait else future
+        return future if return_future else future.result()

     async def _get_experts(
             self, node: DHTNode, uids: List[str], expiration_time: Optional[DHTExpiration], future: MPFuture):
@@ -144,8 +144,8 @@ class DHT(mp.Process):
             expiration_time = get_dht_time()
         num_workers = len(uids) if self.max_workers is None else min(len(uids), self.max_workers)
         response = await node.get_many(uids, expiration_time, num_workers=num_workers)
-        future.set_result([RemoteExpert(uid, maybe_endpoint) if maybe_expiration_time else None
-                           for uid, (maybe_endpoint, maybe_expiration_time) in response.items()])
+        future.set_result([RemoteExpert(**expert_data) if maybe_expiration_time else None
+                           for uid, (expert_data, maybe_expiration_time) in response.items()])

     def declare_experts(self, uids: List[str], endpoint: Endpoint, wait=True, timeout=None) -> Optional[List[bool]]:
         """
@@ -172,14 +172,16 @@ class DHT(mp.Process):
             uid_parts = uid.split(self.UID_DELIMITER)
             for i in range(len(uid_parts)):
                 uid_prefix_i = self.UID_DELIMITER.join(uid_parts[:i + 1])
-                data_to_store[uid_prefix_i] = endpoint
+                data_to_store[uid_prefix_i] = {'uid': uid, 'endpoint': endpoint}

         store_keys, store_values = zip(*data_to_store.items())
         store_ok = await node.store_many(store_keys, store_values, expiration_time, num_workers=num_workers)
         if future is not None:
             future.set_result([store_ok[key] for key in data_to_store.keys()])

-    def first_k_active(self, uid_prefixes: List[str], k: int, max_prefetch: int = 1, chunk_size: Optional[int] = None):
+    def first_k_active(
+            self, uid_prefixes: List[str], k: int, max_prefetch: int = 1, chunk_size: Optional[int] = None,
+            return_future=False) -> Union[TOrderedDict[str, RemoteExpert], Awaitable[TOrderedDict[str, RemoteExpert]]]:
         """
         Find k prefixes with active experts; may return less if there aren't enough; used for DMoE beam search
 
 
@@ -187,20 +189,22 @@ class DHT(mp.Process):
         :param k: return at most *this many* active prefixes
         :param max_prefetch: pre-dispatch up to *this many* tasks (each for chunk_size experts)
         :param chunk_size: dispatch this many requests in one task
-        :returns: a list of at most :k: prefixes that have at least one active expert each;
+        :param return_future: if False (default), return when experts are returned. Otherwise return MPFuture.
+        :returns: an ordered dict {uid_prefix -> RemoteExpert} mapping at most :k: prefixes to matching experts
+            The keys in the returned dict are ordered same as in uid_prefixes.
         """
         assert not isinstance(uid_prefixes, str), "please provide a list/tuple of prefixes as the first argument"
         future, _future = MPFuture.make_pair()
         self.pipe.send(('_first_k_active', [],
                         dict(uid_prefixes=uid_prefixes, k=k, max_prefetch=max_prefetch,
                              chunk_size=chunk_size or k, future=_future)))
-        return future.result()
+        return future if return_future else future.result()

     async def _first_k_active(
             self, node: DHTNode, uid_prefixes: List[str], k: int, max_prefetch: int, chunk_size: int, future: MPFuture):
         num_workers_per_chunk = min(chunk_size, self.max_workers or chunk_size)
         total_chunks = (len(uid_prefixes) - 1) // chunk_size + 1
-        active_prefixes = []
+        found: List[Tuple[str, RemoteExpert]] = []

         pending_tasks = deque(
             asyncio.create_task(node.get_many(uid_prefixes[chunk_i * chunk_size: (chunk_i + 1) * chunk_size],
@@ -212,14 +216,13 @@ class DHT(mp.Process):
             # parse task results in chronological order, launch additional tasks on demand
             response = await pending_tasks.popleft()
             for uid_prefix in uid_prefixes[chunk_i * chunk_size: (chunk_i + 1) * chunk_size]:
-                if response[uid_prefix][1] is not None:  # found active peer
-                    active_prefixes.append(uid_prefix)
+                maybe_expert_data, maybe_expiration_time = response[uid_prefix]
+                if maybe_expiration_time is not None:  # found active peer
+                    found.append((uid_prefix, RemoteExpert(**maybe_expert_data)))
                     # if we found enough active experts, finish immediately
-                    if len(active_prefixes) >= k:
+                    if len(found) >= k:
                         break
-            if len(active_prefixes) >= k:
-                for task in pending_tasks:
-                    task.cancel()
+            if len(found) >= k:
                 break

             pre_dispatch_chunk_i = chunk_i + len(pending_tasks) + 1
@@ -228,5 +231,8 @@ class DHT(mp.Process):
                     uid_prefixes[pre_dispatch_chunk_i * chunk_size: (pre_dispatch_chunk_i + 1) * chunk_size],
                     num_workers=num_workers_per_chunk)))

+        for task in pending_tasks:
+            task.cancel()
+
         # return k active prefixes or as many as we could find
-        future.set_result(active_prefixes)
+        future.set_result(OrderedDict(found))

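first_k_active now resolves prefixes to RemoteExpert instances and returns an OrderedDict whose key order follows the query, and both it and get_experts grew a return_future flag so the MoE event loop can await them instead of blocking. A usage sketch against a single fresh DHT node; the uids and endpoint below are placeholders, mirroring tests/test_dht.py:

import hivemind

dht = hivemind.DHT(start=True)
dht.declare_experts(['ffn.0.0', 'ffn.0.1'], f"{hivemind.LOCALHOST}:1337")

# blocking call: an OrderedDict of {uid_prefix -> RemoteExpert}, keys ordered as in the query
active = dht.first_k_active(['ffn.9.9', 'ffn.0.0', 'ffn.0.1'], k=2)
assert list(active.keys()) == ['ffn.0.0', 'ffn.0.1']

# return_future=True returns an MPFuture instead of blocking;
# RemoteMixtureOfExperts.beam_search awaits this inside its own event loop
future = dht.first_k_active(['ffn.0.0'], k=1, return_future=True)
print(future.result())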
+ 1 - 1
hivemind/dht/node.py

@@ -67,7 +67,7 @@ class DHTNode:
         :param depth_modulo: split full k-bucket if it contains root OR up to the nearest multiple of this value (≈b)
         :param parallel_rpc: maximum number of concurrent outgoing RPC requests emitted by DHTProtocol
           Reduce this value if your RPC requests register no response despite the peer sending the response.
-        :param wait_timeout: a kademlia rpc request is deemed lost if we did not recieve a reply in this many seconds
+        :param wait_timeout: a kademlia rpc request is deemed lost if we did not receive a reply in this many seconds
         :param refresh_timeout: refresh buckets if no node from that bucket was updated in this many seconds
           if staleness_timeout is None, DHTNode will not refresh stale buckets (which is usually okay)
         :param bootstrap_timeout: after one of peers responds, await other peers for at most this many seconds

+ 2 - 2
hivemind/server/__init__.py

@@ -41,8 +41,8 @@ class Server(threading.Thread):
         super().__init__()
         self.dht, self.experts, self.update_period = dht, expert_backends, update_period
         if get_port(listen_on) is None:
-            self.listen_on = listen_on = replace_port(listen_on, new_port=find_open_port())
-        self.port = get_port(listen_on)
+            listen_on = replace_port(listen_on, new_port=find_open_port())
+        self.listen_on, self.port = listen_on, get_port(listen_on)

         self.conn_handlers = [ConnectionHandler(listen_on, self.experts) for _ in range(num_connection_handlers)]
         if checkpoint_dir is not None:

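Context for the listen_on/port fix above: when the caller does not specify a port, the server asks find_open_port for a free one and now records both listen_on and port consistently. A common way to implement such a helper is the "bind to port 0" trick shown below; this is a sketch of the general technique, not necessarily hivemind's exact implementation:

import socket
from contextlib import closing

def pick_unused_port() -> int:
    # ask the OS for any free TCP port by binding to port 0, then read back the chosen port
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind(('', 0))
        return sock.getsockname()[1]

print(pick_unused_port())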
+ 0 - 1
hivemind/utils/__init__.py

@@ -4,6 +4,5 @@ from hivemind.utils.tensor_descr import *
 from hivemind.utils.serializer import *
 from hivemind.utils.mpfuture import *
 from hivemind.utils.threading import *
-from hivemind.utils.autograd import *
 from hivemind.utils.grpc import *
 from hivemind.utils.logging import get_logger

+ 0 - 100
hivemind/utils/autograd.py

@@ -1,100 +0,0 @@
-"""
-Temporary autograd extensions to enable inter-op parallelism during backward pass
-Note: we should get rid of this module if https://github.com/pytorch/pytorch/pull/33157 reaches a pytorch release
-"""
-from itertools import chain
-from typing import Tuple, Any
-from concurrent.futures import Future
-
-import numpy as np
-import torch
-import torch.autograd.function
-
-from hivemind.utils.threading import run_in_background
-
-
-class EmulatedAutogradContext(torch.autograd.function._ContextMethodMixin):
-    """
-    A special class that pretends to be pytorch autograd context. Used to circumvent limitatons of pytorch autograd,
-    such as running several parallel backwards or transferring backward to a separate device.
-    This class is not tested outside its use cases in RemoteMixtureOfExperts and we do not recommend using it elsewhere.
-    """
-
-    @property
-    def saved_tensors(self):
-        return tuple(self.to_save)
-
-
-def run_isolated_forward(func: torch.autograd.Function, *args) -> Tuple[EmulatedAutogradContext, Any]:
-    """
-    run :func: in a detached pytorch graph, return *detached* function outputs and an EmulatedAutogradContext that
-    can be used to run backward through the same graph (performed manually by the user).
-    """
-    ctx = EmulatedAutogradContext()
-    # create detached copies of every input so that we can differentiate w.r.t. them without modifying actual variables
-    args = tuple(x.detach().requires_grad_(x.requires_grad) if isinstance(x, torch.Tensor) else x for x in args)
-    with torch.no_grad():
-        return ctx, func.forward(ctx, *args)
-
-
-def run_isolated_backward(func: torch.autograd.Function, ctx: EmulatedAutogradContext, *grad_outputs):
-    """
-    run backward pass for :func: in an isolated graph that was previously created through run_isolated_forward
-    """
-    with torch.no_grad():
-        return func.backward(ctx, *grad_outputs)
-
-
-def map_with_parallel_backward(
-        func: torch.autograd.Function, *args_per_call: Tuple[torch.Tensor, ...]) -> Tuple[Tuple[torch.Tensor, ...]]:
-    """
-    Apply an autograd function to several sets of inputs with two extra guarantees:
-    (1) both forward and backward pass happens concurrently for each set of inputs
-    (2) any operation dependent on any individual function will wait for all functions to finish
-    :param func: torch autograd function to be called several times in parallel
-    :param args_per_call: a sequence of tuples of arguments, each tuple corresponds to one function call
-    :returns: a tuple of outputs from each func call
-
-    Note: this function currently requires that all :func: calls succeed (i.e. do not raise an exception).
-    """
-    arg_counts = list(map(len, args_per_call))
-    assert len(set(arg_counts)) == 1, "All input sets must have the same number of arguments"
-    output_strides_ph = Future()
-    flat_outputs: Tuple[torch.Tensor, ...] = _ParallelApplyFunction.apply(
-        func, len(args_per_call), arg_counts[0], output_strides_ph, *chain(*args_per_call))
-    output_strides = output_strides_ph.result()
-    return tuple(flat_outputs[output_strides[i]: output_strides[i + 1]] for i in range(len(output_strides) - 1))
-
-
-class _ParallelApplyFunction(torch.autograd.Function):
-    """
-    A special torch autograd function that runs another function several times in parallel.
-    Please do not call this function directly. Use apply_with_parallel_backward instead.
-    Unlike default pytorch behavior, the backward pass for each function will also happen in parallel.
-    """
-
-    @staticmethod
-    def forward(ctx, func: torch.autograd.Function, num_calls: int, num_args_per_call: int,
-                output_strides_ph: Future, *args_flat) -> Tuple[torch.Tensor, ...]:
-        assert num_calls * num_args_per_call == len(args_flat)
-        args_per_call = [args_flat[i * num_args_per_call: (i + 1) * num_args_per_call] for i in range(num_calls)]
-
-        futures = [run_in_background(run_isolated_forward, func, *args) for args in args_per_call]
-
-        contexts, outputs = zip(*[future.result() for future in futures])
-        output_strides = np.cumsum([0] + list(map(len, outputs)))
-        ctx._inner_func = func
-        ctx._call_contexts = contexts
-        ctx._output_strides = output_strides
-        output_strides_ph.set_result(output_strides)
-        return tuple(chain(*outputs))
-
-    @staticmethod
-    def backward(ctx, *grad_outputs_flat: torch.Tensor):
-        func, contexts, output_strides = ctx._inner_func, ctx._call_contexts, ctx._output_strides
-        grad_outputs_per_call = [grad_outputs_flat[output_strides[i]: output_strides[i + 1]]
-                                 for i in range(len(contexts))]
-        futures = [run_in_background(run_isolated_backward, func, context, *grads)
-                   for context, grads in zip(contexts, grad_outputs_per_call)]
-        flat_grads_wrt_input = tuple(grad for future in futures for grad in future.result())
-        return (None, None, None, None, *flat_grads_wrt_input)

+ 0 - 1
hivemind/utils/networking.py

@@ -1,5 +1,4 @@
 import socket
-import urllib.parse
 from contextlib import closing
 from typing import Optional
 
 

+ 1 - 1
hivemind/utils/serializer.py

@@ -40,7 +40,7 @@ class PytorchSerializer(SerializerBase):
 class MSGPackSerializer(SerializerBase):
     @staticmethod
     def dumps(obj: object) -> bytes:
-        return umsgpack.dumps(obj, use_bin_type=False)  # TODO strict https://github.com/msgpack/msgpack-python/pull/158
+        return umsgpack.dumps(obj, use_bin_type=False, strict_types=True)

     @staticmethod
     def loads(buf: bytes) -> object:

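The strict_types=True flag (the "strict msgpack types" commits at the top of this PR) stops the packer from silently coercing subtypes and, most visibly, from packing tuples as lists: unsupported exact types are routed to the default hook instead, where a serializer can tag them and restore them on load. A sketch of the difference using msgpack-python directly, assuming the umsgpack name in this module refers to that package (the removed TODO links to its repository):

import msgpack

# default behaviour: a tuple is silently packed as a list and comes back as one
assert msgpack.unpackb(msgpack.packb((1, 2, 3))) == [1, 2, 3]

# strict_types=True: the tuple is not coerced; without a `default` hook this raises TypeError
try:
    msgpack.packb((1, 2, 3), strict_types=True)
except TypeError as err:
    print("tuple was not coerced to a list:", err)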
+ 1 - 50
hivemind/utils/threading.py

@@ -1,7 +1,5 @@
 import os
-from concurrent.futures import Future, as_completed, TimeoutError, ThreadPoolExecutor
-import time
-from typing import Optional, List
+from concurrent.futures import Future, ThreadPoolExecutor

 EXECUTOR_PID, GLOBAL_EXECUTOR = None, None
 
 
@@ -13,50 +11,3 @@ def run_in_background(func: callable, *args, **kwargs) -> Future:
         GLOBAL_EXECUTOR = ThreadPoolExecutor(max_workers=os.environ.get("HIVEMIND_THREADS", float('inf')))
         EXECUTOR_PID = os.getpid()
     return GLOBAL_EXECUTOR.submit(func, *args, **kwargs)
-
-
-def run_and_await_k(jobs: List[callable], k: int,
-                    timeout_after_k: Optional[float] = 0, timeout_total: Optional[float] = None):
-    """
-    Runs all :jobs: asynchronously, awaits for at least k of them to finish
-    :param jobs: functions to call asynchronously
-    :param k: how many functions should finish for call to be successful
-    :param timeout_after_k: after reaching k finished jobs, wait for this long before cancelling
-    :param timeout_total: if specified, terminate cancel jobs after this many seconds
-    :returns: a list of either results or exceptions for each job
-    """
-    jobs = list(jobs)
-    assert k <= len(jobs), f"Can't await {k} out of {len(jobs)} jobs."
-    start_time = time.time()
-    future_to_ix = {run_in_background(job): i for i, job in enumerate(jobs)}
-    outputs = [None] * len(jobs)
-    success_count = 0
-
-    try:
-        # await first k futures for as long as it takes
-        for future in as_completed(list(future_to_ix.keys()), timeout=timeout_total):
-            success_count += int(not future.exception())
-            outputs[future_to_ix.pop(future)] = future.result() if not future.exception() else future.exception()
-            if success_count >= k:
-                break  # we have enough futures to succeed
-            if len(outputs) + len(future_to_ix) < k:
-                failed = len(jobs) - len(outputs) - len(future_to_ix)
-                raise ValueError(f"Couldn't get enough results: too many jobs failed ({failed} / {len(outputs)})")
-
-        # await stragglers for at most self.timeout_after_k_min or whatever time is left
-        if timeout_after_k is not None and timeout_total is not None:
-            time_left = min(timeout_after_k, timeout_total - time.time() + start_time)
-        else:
-            time_left = timeout_after_k if timeout_after_k is not None else timeout_total
-        for future in as_completed(list(future_to_ix.keys()), timeout=time_left):
-            success_count += int(not future.exception())
-            outputs[future_to_ix.pop(future)] = future.result() if not future.exception() else future.exception()
-
-    except TimeoutError:
-        if len(outputs) < k:
-            raise TimeoutError(f"Couldn't get enough results: time limit exceeded (got {len(outputs)} of {k})")
-    finally:
-        for future, index in future_to_ix.items():
-            future.cancel()
-            outputs[index] = future.result() if not future.exception() else future.exception()
-    return outputs
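
The k-of-n waiting helper above is removed because that logic now lives in the asyncio-based _RemoteCallMany (exercised in test_call_many in tests/test_moe.py below). For orientation, a rough sketch of the equivalent pattern with asyncio; the function name and the simplified timeout handling are illustrative, not the hivemind implementation:

    import asyncio
    from typing import Awaitable, List, Optional, Union

    async def await_first_k(coros: List[Awaitable], k: int,
                            timeout: Optional[float] = None) -> List[Union[object, BaseException]]:
        # run all coroutines, return once at least k have finished, cancel the stragglers
        tasks = [asyncio.ensure_future(coro) for coro in coros]
        done, pending = set(), set(tasks)
        while len(done) < k and pending:
            finished, pending = await asyncio.wait(pending, timeout=timeout,
                                                   return_when=asyncio.FIRST_COMPLETED)
            if not finished:
                break  # timed out before collecting k results
            done |= finished
        for task in pending:
            task.cancel()
        # mirror the old contract: return either a result or an exception per finished job
        return [task.exception() if task.exception() else task.result() for task in done]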

+ 0 - 1
tests/benchmark_dht.py

@@ -1,7 +1,6 @@
 import time
 import argparse
 import random
-from typing import Tuple
 from warnings import warn
 import hivemind
 from tqdm import trange

+ 32 - 7
tests/test_dht.py

@@ -290,14 +290,14 @@ def test_hivemind_dht():
     assert you_found.endpoint == f'that_host:{that_guys_port}'
 
     # test first_k_active
-    assert theguyshetoldyounottoworryabout.first_k_active(expert_uids, k=10) == expert_uids[:10]
+    assert list(theguyshetoldyounottoworryabout.first_k_active(expert_uids, k=10)) == expert_uids[:10]
 
     some_permuted_experts = random.sample(expert_uids, k=32)
-    assert theguyshetoldyounottoworryabout.first_k_active(some_permuted_experts, k=32) == some_permuted_experts
-    assert theguyshetoldyounottoworryabout.first_k_active(some_permuted_experts, k=1) == some_permuted_experts[:1]
+    assert list(theguyshetoldyounottoworryabout.first_k_active(some_permuted_experts, k=32)) == some_permuted_experts
+    assert list(theguyshetoldyounottoworryabout.first_k_active(some_permuted_experts, k=1)) == some_permuted_experts[:1]
     fake_and_real_experts = list(chain(*zip(
         [str(uuid.uuid4()) for _ in some_permuted_experts], some_permuted_experts)))
-    assert theguyshetoldyounottoworryabout.first_k_active(fake_and_real_experts, k=9) == some_permuted_experts[:9]
+    assert list(theguyshetoldyounottoworryabout.first_k_active(fake_and_real_experts, k=9)) == some_permuted_experts[:9]
 
     for peer in peers:
         peer.shutdown()
@@ -305,10 +305,35 @@
 
 def test_dht_single_node():
     node = hivemind.DHT(start=True)
-    assert all(node.declare_experts(['e1', 'e2', 'e3'], f"{hivemind.LOCALHOST}:{1337}"))
+    assert node.first_k_active(['e3', 'e2'], k=3) == {}
+    assert node.get_experts(['e3', 'e2']) == [None, None]
+
+    assert all(node.declare_experts(['e1', 'e2', 'e3'], f"{hivemind.LOCALHOST}:1337"))
     for expert in node.get_experts(['e3', 'e2']):
-        assert expert.endpoint == f"{hivemind.LOCALHOST}:{1337}"
-    assert node.first_k_active(['e0', 'e1', 'e3', 'e5', 'e2'], k=2) == ['e1', 'e3']
+        assert expert.endpoint == f"{hivemind.LOCALHOST}:1337"
+    active_found = node.first_k_active(['e0', 'e1', 'e3', 'e5', 'e2'], k=2)
+    assert list(active_found.keys()) == ['e1', 'e3']
+    assert all(expert.uid.startswith(prefix) for prefix, expert in active_found.items())
+
+    assert all(node.declare_experts(['e1', 'e2', 'e3'], f"{hivemind.LOCALHOST}:1337"))
+
+
+def test_first_k_active():
+    node = hivemind.DHT(start=True)
+    assert all(node.declare_experts(['e.1.2.3', 'e.1.2.4', 'e.3.4.5'], endpoint=f"{hivemind.LOCALHOST}:1337"))
+    assert all(node.declare_experts(['e.2.1.1'], endpoint=f"{hivemind.LOCALHOST}:1338"))
+
+    results = node.first_k_active(['e.0', 'e.1', 'e.2', 'e.3'], k=2)
+    assert len(results) == 2 and next(iter(results.keys())) == 'e.1'
+    assert results['e.1'].uid in ('e.1.2.3', 'e.1.2.4') and results['e.1'].endpoint == f"{hivemind.LOCALHOST}:1337"
+    assert results['e.2'].uid == 'e.2.1.1' and results['e.2'].endpoint == f"{hivemind.LOCALHOST}:1338"
+
+    results = node.first_k_active(['e', 'e.1', 'e.1.2', 'e.1.2.3'], k=10)
+    assert len(results) == 4
+    assert 'e' in results
+    for k in ('e.1', 'e.1.2', 'e.1.2.3'):
+        assert results[k].uid in ('e.1.2.3', 'e.1.2.4') and results[k].endpoint == f"{hivemind.LOCALHOST}:1337"
+
 
 
 def test_store():
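
The assertions above reflect the new contract of DHT.first_k_active: it returns an ordered mapping {queried prefix -> RemoteExpert} rather than a list of uids, with keys following the order of the query. A small usage sketch based on these tests; the uids and endpoint below are made up:

    import hivemind

    node = hivemind.DHT(start=True)
    node.declare_experts(['expert.1.2', 'expert.7.0'], endpoint='127.0.0.1:1337')

    # keys come back in query order; values are RemoteExpert handles for the matching uids
    active = node.first_k_active(['expert.0', 'expert.1', 'expert.7'], k=2)
    assert list(active.keys()) == ['expert.1', 'expert.7']
    assert active['expert.1'].uid.startswith('expert.1')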

+ 110 - 29
tests/test_moe.py

@@ -1,46 +1,127 @@
+import asyncio
+
+import grpc
+import numpy as np
+import pytest
 import torch
 import hivemind
+from hivemind.client.expert import DUMMY
 from test_utils.run_server import background_server
 
 
-def test_remote_module_call():
-    """ Check that remote_module_call returns correct outputs and gradients if called directly """
-    num_experts = 8
+def test_moe():
+    all_expert_uids = [f'ffn.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}'
+                       for _ in range(20)]
+    with background_server(expert_uids=all_expert_uids, device='cpu', expert_cls='ffn',
+                           num_handlers=1, hidden_dim=16) as (server_endpoint, dht_endpoint):
+
+        dht = hivemind.DHT(start=True, expiration=999, initial_peers=[dht_endpoint])
+        # declare expert uids. Server *should* declare them by itself, but it takes time.
+        assert all(dht.declare_experts(all_expert_uids, endpoint=server_endpoint))
+
+        dmoe = hivemind.RemoteMixtureOfExperts(
+            in_features=16, grid_size=(32, 32, 32), dht=dht, k_best=3, uid_prefix='ffn')
+
+        for i in range(10):
+            out = dmoe(torch.randn(10, 16))
+            out.sum().backward()
+
+
+def test_call_many():
     k_min = 1
     timeout_after_k_min = None
     backward_k_min = 1
-    timeout_total = None
+    forward_timeout = None
     backward_timeout = None
     rtol = 1e-3
     atol = 1e-6
 
-    xx = torch.randn(32, 1024, requires_grad=True)
-    logits = torch.randn(3, requires_grad=True)
-    random_proj = torch.randn_like(xx)
+    with background_server(num_experts=5, device='cpu', expert_cls='ffn', num_handlers=8, hidden_dim=64,
+                           no_optimizer=True, no_dht=True) as (server_endpoint, dht_endpoint):
 
-    with background_server(num_experts=num_experts, device='cpu', num_handlers=1,
+        inputs = torch.randn(4, 64, requires_grad=True)
+        inputs_clone = inputs.clone().detach().requires_grad_(True)
+        e0, e1, e2, e3, e4 = [hivemind.RemoteExpert(f'expert.{i}', server_endpoint) for i in range(5)]
+        e5 = hivemind.RemoteExpert(f'thisshouldnotexist', '127.0.0.1:80')
+
+        mask, expert_outputs = hivemind.client.moe._RemoteCallMany.apply(
+            DUMMY, [[e0, e1, e2], [e2, e4], [e1, e5, e3], []],
+            k_min, backward_k_min, timeout_after_k_min, forward_timeout, backward_timeout,
+            asyncio.new_event_loop(), inputs
+        )
+        assert mask.shape == (4, 3)
+        assert expert_outputs.shape == (4, 3, 64)
+
+        assert np.all(mask.data.numpy() == np.array([[True, True, True],
+                                                     [True, True, False],
+                                                     [True, False, True],
+                                                     [False, False, False]])), f"Incorrect mask, {mask}"
+
+        reference_outputs = torch.zeros_like(expert_outputs)
+        reference_outputs[0, 0] = e0(inputs_clone[0:1])
+        reference_outputs[0, 1] = e1(inputs_clone[0:1])
+        reference_outputs[0, 2] = e2(inputs_clone[0:1])
+        reference_outputs[1, 0] = e2(inputs_clone[1:2])
+        reference_outputs[1, 1] = e4(inputs_clone[1:2])
+        reference_outputs[2, 0] = e1(inputs_clone[2:3])
+        reference_outputs[2, 2] = e3(inputs_clone[2:3])
+
+        assert torch.allclose(expert_outputs, reference_outputs, rtol, atol)
+        proj = torch.randn(4, 64)
+        loss = (expert_outputs[(0, 1, 1, 2), (0, 2, 1, 0)] * proj).sum()
+        loss.backward()
+        our_grad = inputs.grad.data.cpu().clone()
+
+        reference_loss = (reference_outputs[(0, 1, 1, 2), (0, 2, 1, 0)] * proj).sum()
+        reference_loss.backward()
+        reference_grad = inputs_clone.grad.data.cpu().clone()
+        assert torch.allclose(our_grad, reference_grad, rtol, atol)
+
+
+def test_remote_module_call():
+    with background_server(num_experts=1, device='cpu', expert_cls='ffn', num_handlers=1, hidden_dim=1024,
                            no_optimizer=True, no_dht=True) as (server_endpoint, dht_endpoint):
-        experts = [hivemind.RemoteExpert(uid=f'expert.{i}', endpoint=server_endpoint) for i in range(num_experts)]
-        moe_output, = hivemind.client.moe._RemoteMoECall.apply(
-            logits, experts[:len(logits)], k_min, timeout_after_k_min, backward_k_min, timeout_total, backward_timeout,
-            [(None,), {}], xx)
-
-        grad_xx_moe, = torch.autograd.grad(torch.sum(random_proj * moe_output), xx, retain_graph=True)
-        grad_logits_moe, = torch.autograd.grad(torch.sum(random_proj * moe_output), logits, retain_graph=True)
-
-        # reference outputs: call all experts manually and average their outputs with softmax probabilities
-        probs = torch.softmax(logits, 0)
-        outs = [expert(xx) for expert in experts[:3]]
-        manual_output = sum(p * x for p, x in zip(probs, outs))
-        grad_xx_manual, = torch.autograd.grad(torch.sum(random_proj * manual_output), xx, retain_graph=True)
-        grad_xx_manual_rerun, = torch.autograd.grad(torch.sum(random_proj * manual_output), xx, retain_graph=True)
-        grad_logits_manual, = torch.autograd.grad(torch.sum(random_proj * manual_output), logits, retain_graph=True)
-
-    assert torch.allclose(grad_xx_manual, grad_xx_manual_rerun, rtol, atol), "Experts are non-deterministic. The test" \
-                                                                             " is only valid for deterministic experts"
-    assert torch.allclose(moe_output, manual_output, rtol, atol), "_RemoteMoECall returned incorrect output"
-    assert torch.allclose(grad_xx_moe, grad_xx_manual, rtol, atol), "incorrect gradient w.r.t. input"
-    assert torch.allclose(grad_logits_moe, grad_logits_manual, rtol, atol), "incorrect gradient w.r.t. logits"
+        real_expert = hivemind.RemoteExpert('expert.0', server_endpoint)
+        fake_expert = hivemind.RemoteExpert('oiasfjiasjf', server_endpoint)
+
+        out1 = real_expert(torch.randn(1, 1024))
+        assert out1.shape == (1, 1024)
+        dummy_x = torch.randn(3, 1024, requires_grad=True)
+        out3 = real_expert(dummy_x)
+        assert out3.shape == (3, 1024)
+        out3_again = real_expert(dummy_x[1:])
+        assert torch.allclose(out3_again, out3[1:])
+        out3_again.norm().backward()
+        assert dummy_x.grad is not None and dummy_x.grad.norm() > 0
+
+        with pytest.raises(grpc.RpcError):
+            real_expert(torch.randn(3, 11))
+        with pytest.raises(grpc.RpcError):
+            fake_expert(dummy_x)
+
+
+def test_moe_beam_search():
+    all_expert_uids = [f'ffn.{5 + i}.{10 + j}.{15 + k}' for i in range(10) for j in range(10) for k in range(10)]
+    dht = hivemind.DHT(start=True, expiration=999)
+    assert all(dht.declare_experts(all_expert_uids, endpoint='fake-endpoint'))
+
+    dmoe = hivemind.RemoteMixtureOfExperts(
+        in_features=32, grid_size=(32, 32, 32), dht=dht, k_best=4, uid_prefix='ffn')
+
+    for i in range(25):
+        input = torch.randn(32)
+        grid_scores = dmoe.proj(input).split_with_sizes(dmoe.grid_size, dim=-1)
+
+        chosen_experts = dmoe.loop.run_until_complete(dmoe.beam_search(grid_scores, k_best=dmoe.k_best))
+
+        chosen_scores = dmoe.compute_expert_scores([dim_scores[None] for dim_scores in grid_scores],
+                                                   [chosen_experts])[0]
+
+        all_scores = dmoe.compute_expert_scores([dim_scores[None] for dim_scores in grid_scores],
+                                                [[hivemind.RemoteExpert(uid, '') for uid in all_expert_uids]])[0]
+        true_best_scores = sorted(all_scores.cpu().detach().numpy(), reverse=True)[:len(chosen_experts)]
+        our_best_scores = list(chosen_scores.cpu().detach().numpy())
+        assert np.allclose(true_best_scores, our_best_scores)
 
 
 def test_determinism():
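
For context, test_moe above drives RemoteMixtureOfExperts end to end against a background server. A minimal sketch of how such a layer could sit inside a model, assuming a DHT with experts declared under the 'ffn' prefix is already running; the wrapper class is illustrative and not part of hivemind:

    import torch
    import torch.nn as nn
    import hivemind

    class MoEBlock(nn.Module):
        def __init__(self, dht: hivemind.DHT, dim: int = 16):
            super().__init__()
            # same constructor arguments as in test_moe above
            self.moe = hivemind.RemoteMixtureOfExperts(
                in_features=dim, grid_size=(32, 32, 32), dht=dht, k_best=3, uid_prefix='ffn')
            self.norm = nn.LayerNorm(dim)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # residual connection around the remote mixture-of-experts call
            return self.norm(x + self.moe(x))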

+ 14 - 8
tests/test_utils/run_server.py

@@ -11,7 +11,7 @@ import hivemind
 from test_utils.layers import name_to_block, name_to_input
 
 
-def make_dummy_server(listen_on='0.0.0.0:*', num_experts=1, expert_cls='ffn', hidden_dim=1024,
+def make_dummy_server(listen_on='0.0.0.0:*', num_experts=None, expert_uids=None, expert_cls='ffn', hidden_dim=1024,
                       num_handlers=None, expert_prefix='expert', expert_offset=0, max_batch_size=16384, device=None,
                       no_optimizer=False, no_dht=False, initial_peers=(), dht_port=None, root_port=None, verbose=True,
                       start=False, **kwargs) -> hivemind.Server:
@@ -19,11 +19,12 @@ def make_dummy_server(listen_on='0.0.0.0:*', num_experts=1, expert_cls='ffn', hi
     Instantiate a server with several identical experts. See argparse comments below for details
     :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
     :param num_experts: run this many identical experts
+    :param expert_prefix: all expert uids will be {expert_prefix}.{index}
+    :param expert_offset: expert uid will use indices in range(expert_offset, expert_offset + num_experts)
+    :param expert_uids: spawn experts with these exact uids, overrides num_experts, expert_prefix and expert_offset
     :param expert_cls: expert type from test_utils.layers, e.g. 'ffn', 'transformer', 'det_dropout' or 'nop';
     :param hidden_dim: main dimension for expert_cls
     :param num_handlers: server will use this many parallel processes to handle incoming requests
-    :param expert_prefix: all expert uids will be {expert_prefix}.{index}
-    :param expert_offset: expert uid will use indices in range(expert_offset, expert_offset + num_experts)
     :param max_batch_size: total num examples in the same batch will not exceed this value
     :param device: all experts will use this device in torch notation; default: cuda if available else cpu
     :param no_optimizer: if specified, all optimizers use learning rate=0
@@ -36,6 +37,8 @@ def make_dummy_server(listen_on='0.0.0.0:*', num_experts=1, expert_cls='ffn', hi
     :param verbose: whether to print server started / finished / terminated events
     :param start: if True, starts server right away and returns when server is ready for requests
     """
+    assert (expert_uids is None) != (num_experts is None and expert_prefix == 'expert' and expert_offset == 0), \
+        "Please provide either expert uids *or* (num_experts, expert_prefix and expert_offset), not both"
     if verbose and len(kwargs) != 0:
         print("Ignored kwargs:", kwargs)
     assert expert_cls in name_to_block
@@ -68,11 +71,15 @@ def make_dummy_server(listen_on='0.0.0.0:*', num_experts=1, expert_cls='ffn', hi
         args_schema = (hivemind.BatchTensorDescriptor.from_tensor(sample_input),)
 
     # initialize experts
+    if expert_uids is None:
+        num_experts = num_experts if num_experts is not None else 1
+        expert_uids = [f'{expert_prefix}{hivemind.DHT.UID_DELIMITER}{i + expert_offset}'
+                       for i in range(num_experts)]
+
     experts = {}
-    for i in range(num_experts):
+    for expert_uid in expert_uids:
         expert = name_to_block[expert_cls](hidden_dim)
         opt = torch.optim.SGD(expert.parameters(), 0.0 if no_optimizer else 0.05)
-        expert_uid = f'{expert_prefix}{hivemind.DHT.UID_DELIMITER}{i + expert_offset}'
         experts[expert_uid] = hivemind.ExpertBackend(name=expert_uid, expert=expert, opt=opt,
                                                      args_schema=args_schema,
                                                      outputs_schema=hivemind.BatchTensorDescriptor(hidden_dim),
@@ -108,7 +115,7 @@ def background_server(*args, shutdown_timeout=5, verbose=True, **kwargs) -> Tupl
         finally:
             if verbose:
                 print("Server failed to shutdown gracefully, terminating it the hard way...")
-            runner.terminate()
+            runner.kill()
             if verbose:
                 print("Server terminated.")
 
@@ -132,9 +139,8 @@ def _server_runner(pipe, *args, verbose, **kwargs):
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('--interface', type=str, default='0.0.0.0', required=False,
+    parser.add_argument('--listen_on', type=str, default='0.0.0.0:*', required=False,
                         help="'localhost' for local connections only, '0.0.0.0' for ipv4 '::' for ipv6")
-    parser.add_argument('--port', type=int, default=None, required=False, help="server will listen to this port")
     parser.add_argument('--num_experts', type=int, default=1, required=False, help="run this many identical experts")
     parser.add_argument('--expert_cls', type=str, default='ffn', required=False,
                         help="expert type from test_utils.layers, e.g. 'ffn', 'transformer', 'det_dropout' or 'nop'.")