
Merge pull request #23 from learning-at-home/manage_technical_debt

Manage technical debt+refactor gating function (support parallel backward)
Max Ryabinin 5 years ago
Parent
Commit
19d7299cbc

+ 3 - 0
.circleci/config.yml

@@ -15,6 +15,9 @@ jobs:
       - run:
           command: sudo python setup.py develop
           name: setup
+      - run:
+          command: nosetests tests/*
+          name: tests
       - run:
           command: python tests/benchmark_throughput.py --preset minimalistic
           name: benchmark

+ 1 - 53
README.md

@@ -9,56 +9,4 @@ Distributed training of large neural networks across volunteer computers.
 
 **[WIP]** - this branch is a work in progress. If you're interested in
 supplementary code for [Learning@home paper](https://arxiv.org/abs/2002.04013),
-you can find it at https://github.com/mryab/learning-at-home.
-
-## What do I need to run it?
-
-- One or several computers, each equipped with at least one GPU
-- Each computer should have at least two open ports (if not, consider ssh port
-  forwarding)
-- Some popular Linux x64 distribution
-  - Tested on Ubuntu16.04, should work fine on any popular linux64 and even
-    MacOS;
-  - Running on Windows natively is not supported, please use vm or docker;
-
-## How do I run it?
-
-Currently, there is no way to do it easily. There are some tests (you can check [`./tests/benchmark_throughput.py`](./tests/benchmark_throughput.py)
- or look into CI logs) and we want to expand them. If you want to
-do something complex with it, please contact us by opening an issue (less preferred: [telegram](https://t.me/justheuristic)).
-
-## `tesseract` quick tour
-
-**Trainer process:**
-
-- **`RemoteExpert`**(`tesseract/client/remote_expert.py`) behaves like a pytorch
-  module with autograd support but actually sends request to a remote runtime.
-- **`GatingFunction`**(`tesseract/client/gating_function.py`) finds best experts
-  for a given input and either returns them as `RemoteExpert` or applies them
-  right away.
-
-**Runtime process:**
-
-- **`TesseractRuntime`** (`tesseract/runtime/__init__.py`) aggregates batches
-  and performs inference/training of experts according to their priority.
-- **`TesseractServer`** (`tesseract/server/__init__.py`) wraps runtime and
-  periodically uploads experts into `TesseractNetwork`.
-
-**DHT:**
-
-- **`TesseractNetwork`**(`tesseract/network/__init__.py`) is a node of
-  Kademlia-based DHT that stores metadata used by trainer and runtime.
-
-## Limitations
-
-**DHT**:
-
-- DHT functionality is severely limited by its inability to traverse NAT.
-- Because of this all the features that require DHT are in deep pre-alpha state
-  and cannot be used without special setup.
-
-**Runtime**:
-* You can achieve 4x less network load by passing quantized uint8 activations across experts.
-    Implement your own quantization or wait for tesseract v0.8.
-* Currently runtime can form batches that exceed maximal batch_size by task_size - 1. 
-    We will fix that in the nearest patch.
+you can find it at https://github.com/mryab/learning-at-home.

+ 1 - 1
docs/modules/client.rst

@@ -14,6 +14,6 @@
 .. autoclass:: RemoteExpert
    :members: forward
 
-.. autoclass:: GatingFunction
+.. autoclass:: RemoteMixtureOfExperts
    :members:
    :member-order: bysource

+ 51 - 0
docs/user/quickstart.md

@@ -4,3 +4,54 @@ This will eventually become a tutorial on how to host a tesseract node or connec
 
 ![img](https://media.giphy.com/media/3oz8xtBx06mcZWoNJm/giphy.gif)
 
+## What do I need to run it?
+
+- One or several computers, each equipped with at least one GPU
+- Each computer should have at least two open ports (if not, consider ssh port
+  forwarding)
+- Some popular Linux x64 distribution
+  - Tested on Ubuntu 16.04; should work fine on any popular 64-bit Linux distribution and even
+    MacOS;
+  - Running natively on Windows is not supported, please use a VM or Docker;
+
+## How do I run it?
+
+Currently, there is no easy way to do it. There are some tests (you can check [`./tests/benchmark_throughput.py`](https://github.com/learning-at-home/tesseract/blob/master/tests/benchmark_throughput.py)
+ or look into the CI logs) and we want to expand them; a short usage sketch also follows after the limitations below. If you want to
+do something complex with it, please contact us by opening an issue (less preferred: [telegram](https://t.me/justheuristic)).
+
+## `tesseract` quick tour
+
+**Trainer process:**
+
+- **`RemoteExpert`**(`tesseract/client/expert.py`) behaves like a pytorch
+  module with autograd support but actually sends requests to a remote runtime.
+- **`RemoteMixtureOfExperts`**(`tesseract/client/moe.py`) finds the best experts
+  for a given input and either returns them as `RemoteExpert` instances or applies them
+  right away.
+
+**Runtime process:**
+
+- **`TesseractRuntime`** (`tesseract/runtime/__init__.py`) aggregates batches
+  and performs inference/training of experts according to their priority.
+- **`TesseractServer`** (`tesseract/server/__init__.py`) wraps runtime and
+  periodically uploads experts into `TesseractNetwork`.
+
+**DHT:**
+
+- **`TesseractNetwork`**(`tesseract/network/__init__.py`) is a node of
+  Kademlia-based DHT that stores metadata used by trainer and runtime.
+
+## Limitations
+
+**DHT**:
+
+- DHT functionality is severely limited by its inability to traverse NAT.
+- Because of this, all features that require the DHT are in a deep pre-alpha state
+  and cannot be used without special setup.
+
+**Runtime**:
+* You can achieve 4x less network load by passing quantized uint8 activations between experts.
+    Implement your own quantization or wait for tesseract v0.8.
+* Currently the runtime can form batches that exceed the maximal batch_size by up to task_size - 1.
+    We will fix this in an upcoming patch.
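To make the quick tour above concrete, here is a minimal trainer-side sketch modelled on this PR's tests (`tests/test_moe.py`). It assumes a `TesseractServer` is already hosting experts named `expert.0 ... expert.N` on the local machine; the port below is a placeholder.

```python
import torch
import tesseract

# call a single known expert directly; it behaves like an autograd-aware torch module
expert = tesseract.RemoteExpert(uid='expert.0', port=8080)  # placeholder port of a running TesseractServer
output = expert(torch.randn(4, 1024))
output.sum().backward()  # gradients flow back through the remote call

# or let the mixture-of-experts layer pick experts through the DHT
dht = tesseract.TesseractNetwork(port=tesseract.find_open_port(), start=True)
moe = tesseract.client.moe.RemoteMixtureOfExperts(
    network=dht, in_features=1024, grid_size=[40], k_best=4, uid_prefix='expert')
averaged = moe(torch.randn(4, 1024))  # averaged predictions of the experts that responded in time
dht.shutdown()
```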

+ 2 - 1
requirements.txt

@@ -4,4 +4,5 @@ numpy>=1.17
 requests>=2.22.0
 tqdm
 kademlia>=2.2
-prefetch_generator>=1.0.1
+prefetch_generator>=1.0.1
+nose>=1.3.0

+ 2 - 2
tesseract/client/__init__.py

@@ -1,2 +1,2 @@
-from .gating_function import GatingFunction
-from .remote_expert import RemoteExpert
+from .moe import RemoteMixtureOfExperts
+from .expert import RemoteExpert

+ 2 - 0
tesseract/client/remote_expert.py → tesseract/client/expert.py

@@ -2,6 +2,7 @@ from typing import Tuple, Optional
 
 import torch
 import torch.nn as nn
+from torch.autograd.function import once_differentiable
 
 from ..utils import nested_flatten, DUMMY, PytorchSerializer, nested_pack, nested_compare, Connection
 
@@ -69,6 +70,7 @@ class _RemoteModuleCall(torch.autograd.Function):
         return tuple(PytorchSerializer.loads(msg))  # flattened expert outputs
 
     @staticmethod
+    @once_differentiable
     def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
         connection = Connection.create(ctx.host, ctx.port)
         payload = tuple(nested_flatten((ctx.saved_tensors, grad_outputs)))
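For context, `once_differentiable` (a standard torch decorator) runs the wrapped `backward` with gradient tracking disabled and raises only if something later tries to differentiate through that backward pass a second time. A standalone illustration, unrelated to the actual remote call above:

```python
import torch
from torch.autograd.function import once_differentiable


class Cube(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 3

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return 3 * x ** 2 * grad_output


x = torch.randn(3, requires_grad=True)
grad_x, = torch.autograd.grad(Cube.apply(x).sum(), x, create_graph=True)
# first-order gradients work as usual; calling backward() through grad_x would raise,
# because the backward pass above is marked as only once differentiable
```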

+ 0 - 170
tesseract/client/gating_function.py

@@ -1,170 +0,0 @@
-import multiprocessing as mp
-import multiprocessing.pool
-from functools import partial
-from typing import Tuple, List, Dict, Any
-
-import numpy as np
-import torch
-import torch.nn as nn
-
-from .remote_expert import RemoteExpert
-from ..utils import nested_map, check_numpy, run_and_await_k
-
-
-class GatingFunction(nn.Module):
-    """
-    A torch module that selects experts across the network and averages their predictions
-
-    :param in_features: common input size for experts and gating function
-    :param grid_size: tesseract dimensions that form expert uid (see below)
-    :param uid_prefix: common prefix for all expert uids
-     expert uid follows the pattern {uid_prefix}{0...grid_size[0]}.{0...grid_size[1]}...{0...grid_size[-1]}
-    :param network: TesseractNetwork where the experts reside
-    :param num_workers: number of threads for parallel network operation
-    :param k_best: queries this many experts with highest scores
-    :param k_min: makes sure at least this many experts returned output
-    :param timeout_after_k_min: waits for this many seconds after k_min experts returned results.
-     Any expert that didn't manage to return output after that delay is considered unavailable
-    :param expert_padding: internal value used to denote "absent expert". Should not coincide with any expert uid.
-    """
-    def __init__(self, *, in_features, grid_size: Tuple[int], network, num_workers=None,
-                 k_best, k_min=1, timeout_after_k_min=1.0, uid_prefix='', expert_padding=None):
-        super().__init__()
-        self.network, self.grid_size = network, grid_size
-        self.uid_prefix, self.expert_padding = uid_prefix, expert_padding
-        self.k_best, self.k_min, self.timeout_after_k_min = k_best, k_min, timeout_after_k_min
-
-        self.thread_pool = mp.pool.ThreadPool(num_workers or k_best * 2)
-        self.proj = nn.Linear(in_features, sum(grid_size))  # jointly predict logits for all grid dimensions
-
-    def forward(self, input: torch.Tensor, *args, **kwargs) -> Tuple[List[List[RemoteExpert]], torch.Tensor]:
-        """
-        Choose k best experts with beam search, then call chosen experts and average their outputs.
-
-        :param batch: named tensors, each tensor has 0-th axis dedicated to batch (aka batch-first
-        :returns: averaged predictions of all experts that delivered on time
-        """
-        assert len(input.shape) == 2
-
-        # 1. compute scores and find most appropriate experts with beam search
-        grid_scores = self.proj(input).split_with_sizes(self.grid_size, dim=-1)
-        batch_experts = self.beam_search(grid_scores, self.k_best)
-        # ^-- List[batch_size] of List[RemoteExpert] chosen for every input in batch
-
-        # 2.1 call chosen experts (run them in background to save time)
-        batch_outputs_async = [
-            self.thread_pool.apply_async(self._run_experts,
-                                         args=[chosen_experts, input[i: i + 1], *(tensor[i: i + 1] for tensor in args)],
-                                         kwds={key: tensor[i: i + 1] for key, tensor in kwargs.items()})
-            for i, chosen_experts in enumerate(batch_experts)
-        ]
-
-        # 2.2 compute *differentiable* logits for each expert
-        batch_expert_logits = self._score_experts(grid_scores, batch_experts)
-        # ^-- List[batch_size] of Dict[RemoteExpert, logit] before softmax for each active expert
-
-        batch_outputs = []
-        for output_async, expert_logits in zip(batch_outputs_async, batch_expert_logits):
-            expert_outputs: Dict[RemoteExpert, Any] = output_async.get()
-            flat_experts, flat_outputs = zip(*expert_outputs.items())
-
-            # 3.1. normalize logits over only those experts that DID return output
-            flat_logits = torch.stack([expert_logits[expert] for expert in flat_experts])
-            flat_weights = torch.softmax(flat_logits, dim=-1)
-
-            # 3.2. average each output across experts
-            average_outputs = nested_map(
-                lambda *tensors: sum(x * weight for x, weight in zip(tensors, flat_weights)), *flat_outputs)
-
-            batch_outputs.append(average_outputs)
-
-        # 4. concatenate mixture outputs from individual experts
-        return nested_map(lambda *tensors: torch.cat(tensors, dim=0), *batch_outputs)
-
-    def beam_search(self, grid_scores: List[torch.Tensor], k_best: int, **kwargs) -> List[List[RemoteExpert]]:
-        """
-        Find and return k best experts in the grid using (exact) beam search of the product space
-
-        :param grid_scores: scores predicted for each dimension in the grid,
-        :type grid_scores: a sequence of tensors of shape[batch_size, self.grid_size[i]]
-        :param k_best: how many of the top experts participate in the computation
-        :param kwargs: extra keyword parameters passed to self.network.first_k_active
-        :returns: a list of *batch_size* lists that contain chosen experts for one sample each inner list contains \
-         RemoteExpert instances for *up to* k_best experts
-        """
-        assert len(grid_scores) == len(self.grid_size)
-        assert all(len(dim_scores.shape) == 2 for dim_scores in grid_scores)
-        batch_size = len(grid_scores[0])
-        beam = np.array([[self.uid_prefix]] * batch_size, dtype=object)  # [batch_size, up_to_beam_size]
-        scores = np.zeros([batch_size, 1], dtype=np.float64)
-
-        delimeters = np.array(self.network.UID_DELIMETER)[None, None, None]  # pre-compute numpy array for fast concat
-
-        for dim_index, dim_scores in enumerate(grid_scores):
-            dim_scores = check_numpy(dim_scores)
-            assert dim_scores.shape[-1] == self.grid_size[dim_index]
-
-            # create all possible successsors from current beam
-            dim_indices = np.arange(dim_scores.shape[1]).astype(str)
-            new_candidates = beam[:, :, None] + delimeters + dim_indices[None, None, :]
-            new_candidates = new_candidates.reshape([batch_size, -1])
-
-            new_scores = scores[:, :, None] + dim_scores[:, None, :]
-            new_scores = new_scores.reshape([batch_size, -1])
-
-            # select k best candidates according to scores but only those that are still active
-            new_order = np.argsort(- new_scores, axis=-1)
-            top_alive_lookups = [
-                self.thread_pool.apply_async(self.network.first_k_active, args=(cands[order], k_best), kwds=kwargs)
-                for cands, order in zip(new_candidates, new_order)]
-
-            batch_cand_to_score = [
-                dict(zip(cands, cand_scores)) for cands, cand_scores in zip(new_candidates, new_scores)]
-
-            top_alive_prefixes = [result.get() for result in top_alive_lookups]
-            top_alive_scores = [list(map(cand_to_score.get, top_cands))
-                                for cand_to_score, top_cands in zip(batch_cand_to_score, top_alive_prefixes)]
-
-            # pad up to beam size
-            beam = np.array([row + [self.expert_padding] * (k_best - len(row))
-                             for row in top_alive_prefixes], dtype='object')
-            scores = np.array([row + [-float('inf')] * (k_best - len(row))
-                               for row in top_alive_scores], dtype='float32')
-
-        unique_experts = self.network.get_experts(list(set(
-            uid for row in beam for uid in row if uid != self.expert_padding)))
-        unique_experts_by_uid = {expert.uid: expert for expert in unique_experts if expert != self.expert_padding}
-
-        return [
-            [unique_experts_by_uid[uid] for uid in row if uid in unique_experts_by_uid]
-            for row in beam]
-
-    def _run_experts(self, experts: List[RemoteExpert], *args, **kwargs) -> Dict[RemoteExpert, torch.Tensor]:
-        outputs = run_and_await_k([partial(expert, *args, **kwargs) for expert in experts],
-                                  k=self.k_min, timeout_after_k=self.timeout_after_k_min)
-        return {expert: output for expert, output in zip(experts, outputs)
-                if not isinstance(output, BaseException)}
-
-    def _score_experts(self, grid_scores: List[torch.Tensor],
-                       experts: List[List[RemoteExpert]]) -> List[Dict[RemoteExpert, torch.Tensor]]:
-        flat_experts = [expert for row in experts for expert in row]
-        flat_batch_indices = torch.tensor([i for i, row in enumerate(experts)
-                                           for uid in range(len(row))])
-
-        grid_indices = np.zeros([len(flat_experts), len(grid_scores)], dtype=np.int64)
-        for i, expert in enumerate(flat_experts):
-            expert_indices = expert.uid[len(self.uid_prefix) + len(self.network.UID_DELIMETER):]
-            expert_indices = list(map(int, expert_indices.split(self.network.UID_DELIMETER)))
-            grid_indices[i] = expert_indices
-
-        scores_per_dim = [
-            dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0)
-            for dim_scores, dim_indices in zip(grid_scores, grid_indices.T)]
-        flat_scores = torch.sum(torch.stack(scores_per_dim, dim=0), dim=0)
-
-        output_dicts = [dict() for _ in range(len(experts))]
-        for batch_i, expert, score in zip(check_numpy(flat_batch_indices),
-                                          flat_experts, flat_scores):
-            output_dicts[batch_i][expert] = score
-
-        return output_dicts

+ 261 - 0
tesseract/client/moe.py

@@ -0,0 +1,261 @@
+import multiprocessing as mp
+import multiprocessing.pool
+from functools import partial
+from typing import Tuple, List, Dict, Optional
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.autograd.function import once_differentiable
+
+from .expert import RemoteExpert, _RemoteModuleCall
+from ..utils import nested_map, check_numpy, run_and_await_k, nested_pack, nested_flatten, DUMMY, run_in_background
+from ..utils import run_isolated_forward, EmulatedAutogradContext, run_isolated_backward, map_with_parallel_backward
+
+
+class RemoteMixtureOfExperts(nn.Module):
+    """
+    A torch module that performs mixture of experts inference with a local gating function and multiple remote experts.
+    Natively supports pytorch autograd.
+
+    :note: By default, not all experts are guaranteed to perform a forward pass. Moreover, not all of those that ran
+     the forward pass are guaranteed to perform a backward pass. In the latter case, the gradient will be averaged
+     without the missing experts.
+
+    :param in_features: common input size for experts and gating function
+    :param grid_size: tesseract dimensions that form expert uid (see below)
+    :param uid_prefix: common prefix for all expert uids
+     expert uid follows the pattern {uid_prefix}.{0...grid_size[0]}.{0...grid_size[1]}...{0...grid_size[-1]}
+    :param network: TesseractNetwork where the experts reside
+    :param num_workers: number of threads for parallel network operation
+    :param k_best: queries this many experts with highest scores
+    :param k_min: makes sure at least this many experts returned output
+    :param timeout_after_k_min: waits for this many seconds after k_min experts returned results.
+     Any expert that didn't manage to return output after that delay is considered unavailable
+    :param expert_padding: internal value used to denote "absent expert". Should not coincide with any expert uid.
+    :param allow_broadcasting: if RemoteMixtureOfExperts is fed an input with more than 2 dimensions,
+     allow_broadcasting=True will flatten the first d-1 input dimensions, apply RemoteMixtureOfExperts and un-flatten again;
+     allow_broadcasting=False will raise an error
+    """
+    def __init__(self, *, in_features, grid_size: Tuple[int], network, k_best, k_min=1,
+                 forward_timeout=None, timeout_after_k_min=1.0, backward_k_min=1, backward_timeout=None,
+                 uid_prefix='', expert_padding=None, allow_broadcasting=True):
+        super().__init__()
+        self.network, self.grid_size = network, grid_size
+        self.uid_prefix, self.expert_padding = uid_prefix, expert_padding
+        self.k_best, self.k_min, self.backward_k_min = k_best, k_min, backward_k_min
+        self.forward_timeout, self.timeout_after_k_min, self.backward_timeout = forward_timeout, timeout_after_k_min, backward_timeout
+        self.allow_broadcasting = allow_broadcasting
+
+        self.proj = nn.Linear(in_features, sum(grid_size))  # jointly predict logits for all grid dimensions
+        self._outputs_schema = None
+
+    def forward(self, input: torch.Tensor, *args: torch.Tensor, **kwargs: torch.Tensor):
+        """
+        Choose k best experts with beam search, then call chosen experts and average their outputs.
+        :param input: a tensor of values that are used to estimate gating function, batch-first
+        :param args: extra positional parameters that will be passed to each expert after input, batch-first
+        :param kwargs: extra keyword parameters that will be passed to each expert, batch-first
+        :returns: averaged predictions of all experts that delivered result on time, nested structure of batch-first
+        """
+        if self.allow_broadcasting and input.ndim != 2:
+            # flatten extra dimensions, apply the function and then un-flatten them back to normal like nn.Linear does
+            flattened_dims = input.shape[:-1]
+            input_flat = input.view(-1, input.shape[-1])
+            args_flat = [tensor.view(-1, *tensor.shape[len(flattened_dims):]) for tensor in args]
+            kwargs_flat = {key: tensor.view(-1, *tensor.shape[len(flattened_dims):]) for key, tensor in kwargs.items()}
+            out_flat = self.forward(input_flat, *args_flat, **kwargs_flat)
+            return nested_map(lambda tensor: tensor.view(*flattened_dims, *tensor.shape[1:]), out_flat)
+
+        # 1. compute scores and find most appropriate experts with beam search
+        grid_scores = self.proj(input).split_with_sizes(self.grid_size, dim=-1)
+        chosen_experts = self.beam_search(grid_scores, self.k_best)
+        # ^-- List[batch_size] of List[RemoteExpert] chosen for every input in batch
+
+        expert_logits = self.compute_expert_scores(grid_scores, chosen_experts)
+
+        expert_inputs = ((input, *args), kwargs)
+        input_schema = nested_map(lambda x: None, expert_inputs)
+        flat_inputs_per_expert = tuple(zip(*[tensor.split(1, dim=0) for tensor in nested_flatten(expert_inputs)]))
+
+        batch_jobs_args = tuple(
+            (expert_logits[i, :len(chosen_experts[i])], chosen_experts[i], self.k_min, self.timeout_after_k_min,
+             self.backward_k_min, self.forward_timeout, self.backward_timeout, input_schema, *flat_inputs_per_expert[i])
+            for i in range(len(input))
+        )
+
+        averaged_outputs_flat = map(torch.cat, zip(*map_with_parallel_backward(_RemoteMoECall, *batch_jobs_args)))
+        return nested_pack(averaged_outputs_flat, self.outputs_schema)
+
+    def beam_search(self, grid_scores: List[torch.Tensor], k_best: int, **kwargs) -> List[List[RemoteExpert]]:
+        """
+        Find and return k best experts in the grid using (exact) beam search of the product space
+
+        :param grid_scores: scores predicted for each dimension in the grid,
+        :type grid_scores: a sequence of tensors of shape[batch_size, self.grid_size[i]]
+        :param k_best: how many of the top experts participate in the computation
+        :param kwargs: extra keyword parameters passed to self.network.first_k_active
+        :returns: a list of *batch_size* lists that contain chosen experts for one sample each inner list contains \
+         RemoteExpert instances for *up to* k_best experts
+        """
+        assert len(grid_scores) == len(self.grid_size)
+        assert all(len(dim_scores.shape) == 2 for dim_scores in grid_scores)
+        batch_size = len(grid_scores[0])
+        beam = np.array([[self.uid_prefix]] * batch_size, dtype=object)  # [batch_size, up_to_beam_size]
+        scores = np.zeros([batch_size, 1], dtype=np.float64)
+
+        delimeters = np.array(self.network.UID_DELIMETER)[None, None, None]  # pre-compute numpy array for fast concat
+
+        for dim_index, dim_scores in enumerate(grid_scores):
+            dim_scores = check_numpy(dim_scores)
+            assert dim_scores.shape[-1] == self.grid_size[dim_index]
+
+            # create all possible successors from current beam
+            dim_indices = np.arange(dim_scores.shape[1]).astype(str)
+            new_candidates = beam[:, :, None] + delimeters + dim_indices[None, None, :]
+            new_candidates = new_candidates.reshape([batch_size, -1])
+
+            new_scores = scores[:, :, None] + dim_scores[:, None, :]
+            new_scores = new_scores.reshape([batch_size, -1])
+
+            # select k best candidates according to scores but only those that are still active
+            new_order = np.argsort(- new_scores, axis=-1)
+            top_alive_lookups = [
+                run_in_background(self.network.first_k_active, cands[order], k_best, **kwargs)
+                for cands, order in zip(new_candidates, new_order)]
+
+            batch_cand_to_score = [
+                dict(zip(cands, cand_scores)) for cands, cand_scores in zip(new_candidates, new_scores)]
+
+            top_alive_prefixes = [result.result() for result in top_alive_lookups]
+            top_alive_scores = [list(map(cand_to_score.get, top_cands))
+                                for cand_to_score, top_cands in zip(batch_cand_to_score, top_alive_prefixes)]
+
+            # pad up to beam size
+            beam = np.array([row + [self.expert_padding] * (k_best - len(row))
+                             for row in top_alive_prefixes], dtype='object')
+            scores = np.array([row + [-float('inf')] * (k_best - len(row))
+                               for row in top_alive_scores], dtype='float32')
+
+        unique_experts = self.network.get_experts(list(set(
+            uid for row in beam for uid in row if uid != self.expert_padding)))
+        if self._outputs_schema is None:
+            self._outputs_schema = next(iter(unique_experts)).info['outputs_schema']
+        unique_experts_by_uid = {expert.uid: expert for expert in unique_experts if expert != self.expert_padding}
+
+        return [[unique_experts_by_uid[uid] for uid in row if uid in unique_experts_by_uid] for row in beam]
+
+    def compute_expert_scores(
+            self, grid_scores: List[torch.Tensor], batch_experts: List[List[RemoteExpert]]) -> torch.Tensor:
+        """ TODO(jheuristic) docstring here """
+        expert_counts = list(map(len, batch_experts))
+        batch_size = len(batch_experts)
+        max_num_experts = max(expert_counts)
+        total_num_experts = sum(expert_counts)
+        expert_index_in_batch = torch.arange(total_num_experts, device=grid_scores[0].device)
+        expert_strides = torch.cumsum(torch.as_tensor([0] + expert_counts, device=grid_scores[0].device), dim=-1)[:-1]
+        flat_batch_indices = (expert_index_in_batch >= expert_strides[:, None]).to(torch.int32).sum(0) - 1
+        flat_local_indices = expert_index_in_batch - expert_strides[flat_batch_indices]
+        flat_experts = [expert for row in batch_experts for expert in row]
+
+        grid_indices = np.zeros([len(flat_experts), len(grid_scores)], dtype=np.int64)
+        for i, expert in enumerate(flat_experts):
+            expert_indices = expert.uid[len(self.uid_prefix) + len(self.network.UID_DELIMETER):]
+            expert_indices = list(map(int, expert_indices.split(self.network.UID_DELIMETER)))
+            grid_indices[i] = expert_indices
+
+        scores_per_dim = [
+            dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0)
+            for dim_scores, dim_indices in zip(grid_scores, grid_indices.T)]
+        flat_scores = torch.sum(torch.stack(scores_per_dim, dim=0), dim=0)
+
+        scores = torch.full((batch_size, max_num_experts), fill_value=-float('inf'), device=grid_scores[0].device)
+        scores[flat_batch_indices, flat_local_indices] = flat_scores  # backprop-able w.r.t. flat_scores
+        return scores
+
+    @property
+    def outputs_schema(self):
+        if self._outputs_schema is None:
+            # grab some expert to set ensemble output shape
+            dummy_scores = self.proj(torch.randn(1, self.proj.in_features)).split_with_sizes(self.grid_size, dim=-1)
+            self._outputs_schema = self.beam_search(dummy_scores, k_best=1)[0][0].info['outputs_schema']
+        return self._outputs_schema
+
+
+class _RemoteMoECall(torch.autograd.Function):
+    """
+    Internal autograd-friendly function that calls multiple experts on the same input and averages their outputs.
+    This function can recover from individual failures during forward and/or backward passes.
+    For user-friendly version of this function, use RemoteMixtureOfExperts module.
+    """
+    @classmethod
+    def forward(cls, ctx, expert_logits: torch.Tensor, experts: List[RemoteExpert],
+                k_min: int, timeout_after_k_min: float, backward_k_min: int, timeout_total: Optional[float],
+                backward_timeout: Optional[float], input_schema, *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
+        expert_args, expert_kwargs = nested_pack(flat_inputs, structure=input_schema)
+        assert expert_logits.ndim == 1 and len(expert_logits) == len(experts)
+
+        # 1. call experts and await results
+        jobs = [partial(cls._run_expert_forward, expert, *expert_args, **expert_kwargs) for expert in experts]
+        results = run_and_await_k(jobs, k=k_min, timeout_after_k=timeout_after_k_min, timeout_total=timeout_total)
+
+        alive_contexts, alive_outputs, alive_ix = zip(*[(result[0], result[1], ix) for ix, result in enumerate(results)
+                                                        if not isinstance(result, BaseException)])
+        #     ^               ^            ^-- a list of indices of experts that returned outputs in time
+        #      \               \-- list of outputs of every expert that didn't die on us
+        #       \-- a list of autograd contexts, used for parallel backward
+
+        # 2. compute softmax weights for alive experts and average outputs
+        alive_ix = torch.as_tensor(alive_ix, device=expert_logits.device)
+        alive_expert_probs = torch.softmax(expert_logits[alive_ix], dim=0)
+
+        stacked_alive_outputs = tuple(map(torch.stack, zip(*alive_outputs)))
+
+        flat_average_outputs = tuple((alive_expert_probs @ stacked_out.flatten(1)).view(*stacked_out.shape[1:])
+                                     for stacked_out in stacked_alive_outputs)
+
+        # 3. save individual outputs for backward pass
+        ctx.save_for_backward(expert_logits, alive_ix, alive_expert_probs, *stacked_alive_outputs)
+        ctx._saved_non_tensors = alive_contexts, backward_k_min, backward_timeout
+        return tuple(map(torch.Tensor.detach, flat_average_outputs))
+
+    @classmethod
+    @once_differentiable
+    def backward(cls, ctx, *grad_outputs_flat: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
+        """ Like normal backward, but we ignore any experts that failed during backward pass """
+        expert_logits, alive_ix, alive_expert_probas, *stacked_alive_outputs = ctx.saved_tensors
+        alive_contexts, backward_k_min, backward_timeout = ctx._saved_non_tensors
+
+        jobs = [partial(cls._run_expert_backward, ctx, prob, *grad_outputs_flat)
+                for ctx, prob in zip(alive_contexts, alive_expert_probas.split(1))]
+        results = run_and_await_k(jobs, k=backward_k_min, timeout_after_k=backward_timeout, timeout_total=None)
+        backward_survivors_in_alive_ix, survived_grad_inputs = zip(*((i, grads) for i, grads in enumerate(results)
+                                                                     if not isinstance(grads, BaseException)))
+        backward_survivors_in_alive_ix = torch.as_tensor(backward_survivors_in_alive_ix, device=expert_logits.device)
+        backward_survivors_ix = alive_ix[backward_survivors_in_alive_ix]
+        survived_probas = torch.softmax(expert_logits[backward_survivors_ix], dim=0)
+        weight_ratios = survived_probas / alive_expert_probas[backward_survivors_in_alive_ix]
+        flat_grad_inputs = tuple((weight_ratios @ stacked_grad_inp.flatten(1)).view(stacked_grad_inp.shape[1:])
+                                 for stacked_grad_inp in map(torch.stack, zip(*survived_grad_inputs)))
+
+        # compute grad w.r.t. logits
+        grad_wrt_probs = sum(tuple(
+            torch.sum(grad_out[None, ...] * stacked_alive_out[backward_survivors_in_alive_ix],
+                      dim=tuple(range(1, stacked_alive_out.ndim)))
+            for grad_out, stacked_alive_out in zip(grad_outputs_flat, stacked_alive_outputs)
+        ))
+        softmax_jacobian = torch.diagflat(survived_probas) - torch.ger(survived_probas, survived_probas)
+        grad_wrt_logits = grad_wrt_probs @ softmax_jacobian
+
+        return grad_wrt_logits, None, None, None, None, None, None, None, *flat_grad_inputs
+
+    @staticmethod
+    def _run_expert_forward(expert: RemoteExpert, *args: torch.Tensor, **kwargs: torch.Tensor):
+        """ Call remote expert and return flattened outputs. Compatible with concurrent autograd. """
+        flat_inputs = nested_flatten((args, kwargs))
+        return run_isolated_forward(_RemoteModuleCall, DUMMY, expert.uid, expert.host, expert.port, *flat_inputs)
+
+    @staticmethod
+    def _run_expert_backward(ctx: EmulatedAutogradContext, weight: torch.Tensor, *grad_outputs: torch.Tensor):
+        backward_result = run_isolated_backward(_RemoteModuleCall, ctx, *(grad * weight for grad in grad_outputs))
+        grad_dummy, no_grad_uid, no_grad_hostname, no_grad_port, *grad_inputs = backward_result
+        return grad_inputs
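As a sanity reference for the logits gradient above: `_RemoteMoECall.backward` multiplies the incoming gradient w.r.t. the surviving expert probabilities by the softmax Jacobian `diag(p) - p pᵀ`. A standalone check against autograd (not part of the PR):

```python
import torch

logits = torch.randn(5, requires_grad=True)
probs = torch.softmax(logits, dim=0)
grad_wrt_probs = torch.randn(5)

# manual gradient, mirroring the last lines of _RemoteMoECall.backward
softmax_jacobian = torch.diagflat(probs) - torch.ger(probs, probs)
manual_grad = grad_wrt_probs @ softmax_jacobian

# autograd reference
auto_grad, = torch.autograd.grad(probs, logits, grad_outputs=grad_wrt_probs)
assert torch.allclose(manual_grad, auto_grad, atol=1e-6)
```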

+ 21 - 7
tesseract/network/__init__.py

@@ -7,7 +7,7 @@ from typing import Tuple, List, Optional
 from kademlia.network import Server
 
 from tesseract.client import RemoteExpert
-from tesseract.utils import run_in_background, repeated, SharedFuture, PickleSerializer
+from tesseract.utils import run_forever, SharedFuture, PickleSerializer
 
 
 class TesseractNetwork(mp.Process):
@@ -15,29 +15,43 @@ class TesseractNetwork(mp.Process):
     HEARTBEAT_EXPIRATION = 120  # expert is inactive iff it fails to post timestamp for *this many seconds*
     make_key = "{}::{}".format
 
-    def __init__(self, *initial_peers: Tuple[str, int], port=8081, start=False):
+    def __init__(self, *initial_peers: Tuple[str, int], port=8081, start=False, daemon=True):
         super().__init__()
         self.port, self.initial_peers = port, initial_peers
         self._pipe, self.pipe = mp.Pipe(duplex=False)
+        self.ready = mp.Event()
         self.server = Server()
+        self.daemon = daemon
         if start:
-            self.start()
+            self.run_in_background(await_ready=True)
 
     def run(self) -> None:
         loop = asyncio.new_event_loop()
         asyncio.set_event_loop(loop)
         loop.run_until_complete(self.server.listen(self.port))
         loop.run_until_complete(self.server.bootstrap(self.initial_peers))
-        run_in_background(repeated(loop.run_forever))
+        run_forever(loop.run_forever)
+        self.ready.set()
 
         while True:
             method, args, kwargs = self._pipe.recv()
             getattr(self, method)(*args, **kwargs)
 
+    def run_in_background(self, await_ready=True, timeout=None):
+        """
+        Starts TesseractNetwork in a background process. If await_ready is True, this method will wait until the
+        background network is ready to process incoming requests, or for at most :timeout: seconds.
+        """
+        self.start()
+        if await_ready and not self.ready.wait(timeout=timeout):
+            raise TimeoutError("TesseractServer didn't notify .ready in {timeout} seconds")
+
     def shutdown(self) -> None:
         """ Shuts down the network process """
-        warnings.warn("TODO shutdown network gracefully")
-        self.terminate()
+        if self.is_alive():
+            self.kill()
+        else:
+            warnings.warn("Network shutdown has no effect: network process is already not alive")
 
     def get_experts(self, uids: List[str], heartbeat_expiration=HEARTBEAT_EXPIRATION) -> List[Optional[RemoteExpert]]:
         """ Find experts across DHT using their ids; Return a list of [RemoteExpert if found else None]"""
@@ -69,7 +83,7 @@ class TesseractNetwork(mp.Process):
         :param wait_timeout: if wait_timeout > 0, waits for the procedure to finish
         """
         done_event = mp.Event() if wait_timeout else None
-        self.pipe.send(('_declare_experts', [], dict(uids=uids, addr=addr, port=port, done_event=done_event)))
+        self.pipe.send(('_declare_experts', [], dict(uids=list(uids), addr=addr, port=port, done_event=done_event)))
         if done_event is not None:
             done_event.wait(wait_timeout)
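A brief sketch of the new start/stop flow for `TesseractNetwork` (standalone; the expert uids below are placeholders):

```python
import tesseract

network = tesseract.TesseractNetwork(port=tesseract.find_open_port())
network.run_in_background(await_ready=True)   # blocks until the DHT node is listening
try:
    # look up experts by uid; returns a RemoteExpert for each uid that is found, None otherwise
    found = network.get_experts(['expert.0', 'expert.999'])
finally:
    network.shutdown()
```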
 

+ 4 - 2
tesseract/server/__init__.py

@@ -52,7 +52,7 @@ class TesseractServer(threading.Thread):
         """
         if self.network:
             if not self.network.is_alive():
-                self.network.start()
+                self.network.run_in_background(await_ready=True)
 
             network_thread = NetworkHandlerThread(experts=self.experts, network=self.network,
                                                   addr=self.addr, port=self.port, update_period=self.update_period)
@@ -111,10 +111,12 @@ class TesseractServer(threading.Thread):
         self.ready.clear()
         for process in self.conn_handlers:
             process.terminate()
-        self.runtime.shutdown()
+
         if self.network is not None:
             self.network.shutdown()
 
+        self.runtime.shutdown()
+
 
 def socket_loop(sock, experts):
     """ catch connections, send tasks to processing, respond with results """

+ 1 - 0
tesseract/utils/__init__.py

@@ -5,3 +5,4 @@ from .proto import *
 from .serializer import *
 from .shared_future import *
 from .threading import *
+from .autograd import *

+ 97 - 0
tesseract/utils/autograd.py

@@ -0,0 +1,97 @@
+"""
+Temporary autograd extensions to enable inter-op parallelism during backward pass
+Note: we should get rid of this module if https://github.com/pytorch/pytorch/pull/33157 reaches a pytorch release
+"""
+from itertools import chain
+from typing import Tuple, Any
+from concurrent.futures import Future
+
+import numpy as np
+import torch
+import torch.autograd.function
+
+from .threading import run_in_background
+
+
+class EmulatedAutogradContext(torch.autograd.function._ContextMethodMixin):
+    """
+    A special class that pretends to be a pytorch autograd context. Used to circumvent limitations of pytorch autograd,
+    such as running several parallel backwards or transferring backward to a separate device.
+    This class is not tested outside its use cases in RemoteMixtureOfExperts and we do not recommend using it elsewhere.
+    """
+    @property
+    def saved_tensors(self):
+        return tuple(self.to_save)
+
+
+def run_isolated_forward(func: torch.autograd.Function, *args) -> Tuple[EmulatedAutogradContext, Any]:
+    """
+    run :func: in a detached pytorch graph, return *detached* function outputs and an EmulatedAutogradContext that
+    can be used to run backward through the same graph (performed manually by the user).
+    """
+    ctx = EmulatedAutogradContext()
+    # create detached copies of every input so that we can differentiate w.r.t. them without modifying actual variables
+    args = tuple(x.detach().requires_grad_(x.requires_grad) if isinstance(x, torch.Tensor) else x for x in args)
+    with torch.no_grad():
+        return ctx, func.forward(ctx, *args)
+
+
+def run_isolated_backward(func: torch.autograd.Function, ctx: EmulatedAutogradContext, *grad_outputs):
+    """
+    run backward pass for :func: in an isolated graph that was previously created through run_isolated_forward
+    """
+    with torch.no_grad():
+        return func.backward(ctx, *grad_outputs)
+
+
+def map_with_parallel_backward(
+        func: torch.autograd.Function, *args_per_call: Tuple[torch.Tensor, ...]) -> Tuple[Tuple[torch.Tensor, ...]]:
+    """
+    Apply an autograd function to several sets of inputs with two extra guarantees:
+    (1) both the forward and backward passes happen concurrently for each set of inputs
+    (2) any operation dependent on any individual function will wait for all functions to finish
+    :param func: torch autograd function to be called several times in parallel
+    :param args_per_call: a sequence of tuples of arguments, each tuple corresponds to one function call
+    :returns: a tuple of outputs from each func call
+
+    Note: this function currently requires that all :func: calls succeed (i.e. do not raise an exception).
+    """
+    arg_counts = list(map(len, args_per_call))
+    assert len(set(arg_counts)) == 1, "All input sets must have the same number of arguments"
+    output_strides_ph = Future()
+    flat_outputs: Tuple[torch.Tensor, ...] = _ParallelApplyFunction.apply(
+        func, len(args_per_call), arg_counts[0], output_strides_ph, *chain(*args_per_call))
+    output_strides = output_strides_ph.result()
+    return tuple(flat_outputs[output_strides[i]: output_strides[i + 1]] for i in range(len(output_strides) - 1))
+
+
+class _ParallelApplyFunction(torch.autograd.Function):
+    """
+    A special torch autograd function that runs another function several times in parallel.
+    Please do not call this function directly. Use map_with_parallel_backward instead.
+    Unlike default pytorch behavior, the backward pass for each function will also happen in parallel.
+    """
+    @staticmethod
+    def forward(ctx, func: torch.autograd.Function, num_calls: int, num_args_per_call: int,
+                output_strides_ph: Future, *args_flat) -> Tuple[torch.Tensor, ...]:
+        assert num_calls * num_args_per_call == len(args_flat)
+        args_per_call = [args_flat[i * num_args_per_call: (i + 1) * num_args_per_call] for i in range(num_calls)]
+
+        futures = [run_in_background(run_isolated_forward, func, *args) for args in args_per_call]
+
+        contexts, outputs = zip(*[future.result() for future in futures])
+        output_strides = np.cumsum([0] + list(map(len, outputs)))
+        ctx._inner_func = func
+        ctx._call_contexts = contexts
+        ctx._output_strides = output_strides
+        output_strides_ph.set_result(output_strides)
+        return tuple(chain(*outputs))
+
+    @staticmethod
+    def backward(ctx, *grad_outputs_flat: torch.Tensor):
+        func, contexts, output_strides = ctx._inner_func, ctx._call_contexts, ctx._output_strides
+        grad_outputs_per_call = [grad_outputs_flat[output_strides[i]: output_strides[i + 1]] for i in range(len(contexts))]
+        futures = [run_in_background(run_isolated_backward, func, context, *grads)
+                   for context, grads in zip(contexts, grad_outputs_per_call)]
+        flat_grads_wrt_input = tuple(grad for future in futures for grad in future.result())
+        return None, None, None, None, *flat_grads_wrt_input
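A minimal standalone sketch of the isolated forward/backward helpers applied to a toy autograd function (illustrative only):

```python
import torch
from tesseract.utils.autograd import run_isolated_forward, run_isolated_backward


class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 2

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return 2 * x * grad_output


x = torch.randn(3, requires_grad=True)
ctx, y = run_isolated_forward(Square, x)                          # forward pass in a detached graph
grad_x = run_isolated_backward(Square, ctx, torch.ones_like(y))   # manual backward through that graph
assert torch.allclose(grad_x, 2 * x)
```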

+ 9 - 0
tesseract/utils/connection.py

@@ -52,3 +52,12 @@ class Connection(AbstractContextManager):
 
     def __exit__(self, *exc_info):
         self.conn.close()
+
+
+def find_open_port():
+    try:
+        sock = socket()
+        sock.bind(('', 0))
+        return sock.getsockname()[1]
+    except OSError:
+        raise ValueError("Could not find open port")

+ 49 - 107
tesseract/utils/threading.py

@@ -1,125 +1,67 @@
+import os
+from concurrent.futures import Future, ThreadPoolExecutor, as_completed, TimeoutError
 import time
-from concurrent.futures import Future, TimeoutError
-from itertools import count
-from threading import Thread, Event, Lock
+from typing import Optional, List
 
+GLOBAL_EXECUTOR = ThreadPoolExecutor(max_workers=os.environ.get("TESSERACT_THREADS", float('inf')))
 
-def run_in_background(func: callable, *args, **kwargs):
-    """ run f(*args, **kwargs) in background and return Future for its outputs """
-    future = Future()
 
-    def _run():
-        try:
-            future.set_result(func(*args, **kwargs))
-        except Exception as e:
-            future.set_exception(e)
+def run_in_background(func: callable, *args, **kwargs) -> Future:
+    """ run func(*args, **kwargs) in background and return Future for its outputs """
 
-    Thread(target=_run).start()
-    return future
+    return GLOBAL_EXECUTOR.submit(func, *args, **kwargs)
 
 
-def repeated(func: callable, n_times=None):
-    """ A function that runs a :func: forever or for a specified number of times; use with run_run_in_background """
-
+def run_forever(func: callable, *args, **kwargs):
+    """ A function that runs a :func: in background forever. Returns a future that catches exceptions """
     def repeat():
-        for i in count():
-            if n_times is not None and i > n_times:
-                break
-            func()
-
-    return repeat
-
-
-def add_event_callback(event: Event, callback, timeout=None):
-    """ Add callback that will be executed asynchronously when event is set """
-    return Thread(target=lambda: (event.wait(timeout), callback())).start()
-
+        while True:
+            func(*args, **kwargs)
+    return run_in_background(repeat)
 
-class CountdownEvent(Event):
-    def __init__(self, count_to: int, initial=0):
-        """ An event that must be incremented :count_to: times before it is considered set """
-        super().__init__()
-        self.value = initial
-        self.count_to = count_to
-        self.lock = Lock()
-        self.increment(by=0)  # trigger set/unset depending on initial value
 
-    def increment(self, by=1):
-        with self.lock:
-            self.value += by
-            if self.value >= self.count_to:
-                super().set()
-            else:
-                super().clear()
-            return self.value
-
-    def clear(self):
-        return self.increment(by=-self.value)
-
-
-def await_first(*events: Event, k=1, timeout=None):
-    """
-    wait until first k (default=1) events are set, return True if event was set fast
-    # Note: after k successes we manually *set* all events to avoid memory leak.
-    """
-    events_done = CountdownEvent(count_to=k)
-    for event in events:
-        add_event_callback(event, callback=events_done.increment, timeout=timeout)
-
-    if events_done.wait(timeout=timeout):
-        [event.set() for event in events]
-        return True
-    else:
-        raise TimeoutError()
-
-
-def run_and_await_k(jobs: callable, k, timeout_after_k=0, timeout_total=None):
+def run_and_await_k(jobs: List[callable], k: int,
+                    timeout_after_k: Optional[float] = 0, timeout_total: Optional[float] = None):
     """
     Runs all :jobs: asynchronously, awaits for at least k of them to finish
-    :param jobs: functions to call
-    :param k: how many functions should finish
+    :param jobs: functions to call asynchronously
+    :param k: how many functions should finish for call to be successful
     :param timeout_after_k: after reaching k finished jobs, wait for this long before cancelling
     :param timeout_total: if specified, terminate cancel jobs after this many seconds
     :returns: a list of either results or exceptions for each job
     """
-    assert k <= len(jobs)
+    jobs = list(jobs)
+    assert k <= len(jobs), f"Can't await {k} out of {len(jobs)} jobs."
     start_time = time.time()
-    min_successful_jobs = CountdownEvent(count_to=k)
-    max_failed_jobs = CountdownEvent(count_to=len(jobs) - k + 1)
-
-    def _run_and_increment(run_job: callable):
-        try:
-            result = run_job()
-            min_successful_jobs.increment()
-            return result
-        except Exception as e:
-            max_failed_jobs.increment()
-            return e
-
-    def _run_and_await(run_job: callable):
-        # call function asynchronously. Increment counter after finished
-        future = run_in_background(_run_and_increment, run_job)
-
-        try:  # await for success counter to reach k OR for fail counter to reach n - k + 1
-            await_first(min_successful_jobs, max_failed_jobs,
-                        timeout=None if timeout_total is None else timeout_total - time.time() + start_time)
-        except TimeoutError as e:  # counter didn't reach k jobs in timeout_total
-            return future.result() if future.done() else e
-
-        try:  # await for subsequent jobs if asked to
-            return future.result(timeout=timeout_after_k)
-        except TimeoutError as e:
+    future_to_ix = {run_in_background(job): i for i, job in enumerate(jobs)}
+    outputs = [None] * len(jobs)
+    success_count = 0
+
+    try:
+        # await first k futures for as long as it takes
+        for future in as_completed(list(future_to_ix.keys()), timeout=timeout_total):
+            success_count += int(not future.exception())
+            outputs[future_to_ix.pop(future)] = future.result() if not future.exception() else future.exception()
+            if success_count >= k:
+                break  # we have enough futures to succeed
+            if success_count + len(future_to_ix) < k:
+                failed = len(jobs) - success_count - len(future_to_ix)
+                raise ValueError(f"Couldn't get enough results: too many jobs failed ({failed} / {len(jobs)})")
+
+        # await stragglers for at most timeout_after_k or whatever total time is left
+        if timeout_after_k is not None and timeout_total is not None:
+            time_left = min(timeout_after_k, timeout_total - time.time() + start_time)
+        else:
+            time_left = timeout_after_k if timeout_after_k is not None else timeout_total
+        for future in as_completed(list(future_to_ix.keys()), timeout=time_left):
+            success_count += int(not future.exception())
+            outputs[future_to_ix.pop(future)] = future.result() if not future.exception() else future.exception()
+
+    except TimeoutError:
+        if success_count < k:
+            raise TimeoutError(f"Couldn't get enough results: time limit exceeded (got {success_count} of {k})")
+    finally:
+        for future, index in future_to_ix.items():
             future.cancel()
-            return e
-
-        except Exception as e:  # job failed with exception. Ignore it.
-            return e
-
-    results = [run_in_background(_run_and_await, f) for f in jobs]
-    results = [result.result() for result in results]
-    if min_successful_jobs.is_set():
-        return results
-    elif max_failed_jobs.is_set():
-        raise ValueError("Could not get enough results: too many jobs failed.")
-    else:
-        raise TimeoutError("Could not get enough results: reached timeout_total.")
+            outputs[index] = future.result() if not future.exception() else future.exception()
+    return outputs
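An illustrative call of the rewritten `run_and_await_k` (standalone; the delays and timeouts are arbitrary):

```python
import time
from tesseract.utils.threading import run_and_await_k


def job(delay, value):
    time.sleep(delay)
    return value


jobs = [lambda d=d: job(d, d) for d in (0.05, 0.1, 0.2)]
# run all jobs in the global thread pool; succeed once at least k=2 have finished,
# then wait up to 1 more second for the remaining stragglers
results = run_and_await_k(jobs, k=2, timeout_after_k=1.0, timeout_total=5.0)
# `results` is aligned with `jobs`: each slot holds the job's return value
# or the exception it raised
print(results)  # -> [0.05, 0.1, 0.2]
```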

+ 2 - 1
tests/benchmark_throughput.py

@@ -6,7 +6,8 @@ import sys
 import time
 
 import torch
-from test_utils import layers, print_device_info, find_open_port
+from test_utils import layers, print_device_info
+from tesseract import find_open_port
 
 import tesseract
 

+ 72 - 0
tests/test_moe.py

@@ -0,0 +1,72 @@
+import torch
+import tesseract
+from test_utils.run_server import background_server
+
+
+def test_remote_module_call():
+    """ Check that remote_module_call returns correct outputs and gradients if called directly """
+    num_experts = 8
+    k_min = 1
+    timeout_after_k_min = None
+    backward_k_min = 1
+    timeout_total = None
+    backward_timeout = None
+    rtol = 1e-3
+    atol = 1e-6
+
+    xx = torch.randn(32, 1024, requires_grad=True)
+    logits = torch.randn(3, requires_grad=True)
+    random_proj = torch.randn_like(xx)
+
+    with background_server(num_experts=num_experts,  device='cpu',
+                           no_optimizer=True, no_network=True) as (localhost, server_port, network_port):
+        experts = [tesseract.RemoteExpert(uid=f'expert.{i}', port=server_port) for i in range(num_experts)]
+        moe_output, = tesseract.client.moe._RemoteMoECall.apply(
+            logits, experts[:len(logits)], k_min, timeout_after_k_min, backward_k_min, timeout_total, backward_timeout,
+            [(None,), {}], xx)
+
+        grad_xx_moe, = torch.autograd.grad(torch.sum(random_proj * moe_output), xx, retain_graph=True)
+        grad_logits_moe, = torch.autograd.grad(torch.sum(random_proj * moe_output), logits, retain_graph=True)
+
+        # reference outputs: call all experts manually and average their outputs with softmax probabilities
+        probs = torch.softmax(logits, 0)
+        outs = [expert(xx) for expert in experts[:3]]
+        manual_output = sum(p * x for p, x in zip(probs, outs))
+        grad_xx_manual, = torch.autograd.grad(torch.sum(random_proj * manual_output), xx, retain_graph=True)
+        grad_xx_manual_rerun, = torch.autograd.grad(torch.sum(random_proj * manual_output), xx, retain_graph=True)
+        grad_logits_manual, = torch.autograd.grad(torch.sum(random_proj * manual_output), logits, retain_graph=True)
+
+    assert torch.allclose(grad_xx_manual, grad_xx_manual_rerun, rtol, atol), "Experts are non-deterministic. The test" \
+                                                                             " is only valid for deterministic experts"
+    assert torch.allclose(moe_output, manual_output, rtol, atol), "_RemoteMoECall returned incorrect output"
+    assert torch.allclose(grad_xx_moe, grad_xx_manual, rtol, atol), "incorrect gradient w.r.t. input"
+    assert torch.allclose(grad_logits_moe, grad_logits_manual, rtol, atol), "incorrect gradient w.r.t. logits"
+
+
+def test_compute_expert_scores():
+    try:
+        dht = tesseract.TesseractNetwork(port=tesseract.find_open_port(), start=True)
+        moe = tesseract.client.moe.RemoteMixtureOfExperts(
+            network=dht, in_features=1024, grid_size=[40], k_best=4, k_min=1, timeout_after_k_min=1,
+            uid_prefix='expert')
+        gx, gy = torch.randn(4, 5, requires_grad=True), torch.randn(4, 3, requires_grad=True)
+        ii = [[4, 0, 2], [3, 1, 1, 1, 3], [0], [3, 2]]
+        jj = [[2, 2, 1], [0, 1, 2, 0, 1], [0], [1, 2]]
+        batch_experts = [
+            [tesseract.RemoteExpert(uid=f'expert.{ii[b][e]}.{jj[b][e]}') for e in range(len(ii[b]))]
+            for b in range(len(ii))
+        ]  # note: these experts do not exist on any server; we use them only to test compute_expert_scores
+        logits = moe.compute_expert_scores([gx, gy], batch_experts)
+        torch.softmax(logits, dim=-1).norm(dim=-1).mean().backward()
+        assert gx.grad.norm().item() > 0 and gy.grad.norm().item() > 0, "compute_expert_scores didn't backprop"
+
+        for b in range(len(ii)):
+            for e in range(len(ii[b])):
+                assert torch.allclose(logits[b, e], gx[b, ii[b][e]] + gy[b, jj[b][e]]), "compute_expert_scores returned incorrect score"
+    finally:
+        dht.shutdown()
+
+
+if __name__ == '__main__':
+    test_remote_module_call()
+    test_compute_expert_scores()

+ 0 - 11
tests/test_utils/__init__.py

@@ -1,5 +1,3 @@
-from socket import socket
-
 import torch
 
 
@@ -14,12 +12,3 @@ def print_device_info(device=None):
         print('Memory Usage:')
         print('Allocated:', round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1), 'GB')
         print('Cached:   ', round(torch.cuda.memory_cached(0) / 1024 ** 3, 1), 'GB')
-
-
-def find_open_port():
-    try:
-        sock = socket()
-        sock.bind(('', 0))
-        return sock.getsockname()[1]
-    except:
-        raise ValueError("Could not find open port")

+ 132 - 0
tests/test_utils/run_server.py

@@ -0,0 +1,132 @@
+import resource
+from contextlib import contextmanager
+import multiprocessing as mp
+import argparse
+
+import torch
+import tesseract
+from .layers import name_to_block
+
+
+def make_dummy_server(host='0.0.0.0', port=None, num_experts=1, expert_cls='ffn', hidden_dim=1024, num_handlers=None,
+                      expert_prefix='expert', expert_offset=0, max_batch_size=16384, device=None, no_optimizer=False,
+                      no_network=False, initial_peers=(), network_port=None, root_port=None, verbose=True, start=False,
+                      UID_DELIMETER=tesseract.TesseractNetwork.UID_DELIMETER, **kwargs) -> tesseract.TesseractServer:
+    """ A context manager that creates server in a background thread, awaits .ready on entry and shutdowns on exit """
+    if verbose and len(kwargs) != 0:
+        print("Ignored kwargs:", kwargs)
+    assert expert_cls in name_to_block
+    num_handlers = num_handlers if num_handlers is not None else num_experts * 8
+    device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
+
+    # initialize network
+    network = None
+    if not no_network:
+        if not len(initial_peers):
+            print("No initial peers provided. Starting additional network as an initial peer.")
+            dht_root = tesseract.TesseractNetwork(
+                *initial_peers, port=root_port or tesseract.find_open_port(), start=True)
+            print(f"Initializing DHT with port {dht_root.port}")
+            initial_peers = (('localhost', dht_root.port), )
+        else:
+            print("Bootstrapping dht with peers:", initial_peers)
+            if root_port is not None:
+                print(f"Warning: root_port={root_port} will not be used since we already have peers.")
+
+        network = tesseract.TesseractNetwork(
+            *initial_peers, port=network_port or tesseract.find_open_port(), start=True)
+        if verbose:
+            print(f"Running network node on port {network.port}")
+
+    # initialize experts
+    experts = {}
+    for i in range(num_experts):
+        expert = torch.jit.script(name_to_block[expert_cls](hidden_dim))
+        opt = torch.optim.SGD(expert.parameters(), 0.0) if no_optimizer else torch.optim.Adam(expert.parameters())
+        expert_uid = f'{expert_prefix}{UID_DELIMETER}{i + expert_offset}'
+        experts[expert_uid] = tesseract.ExpertBackend(name=expert_uid, expert=expert, opt=opt,
+                                                      args_schema=(tesseract.BatchTensorProto(hidden_dim),),
+                                                      outputs_schema=tesseract.BatchTensorProto(hidden_dim),
+                                                      max_batch_size=max_batch_size,
+                                                      )
+    # actually start server
+    server = tesseract.TesseractServer(
+        network, experts, addr=host, port=port or tesseract.find_open_port(),
+        conn_handler_processes=num_handlers, device=device)
+
+    if start:
+        server.run_in_background(await_ready=True)
+        if verbose:
+            print(f"Server started at {server.addr}:{server.port}")
+            print(f"Got {num_experts} active experts of type {expert_cls}: {list(experts.keys())}")
+    return server
+
+
+@contextmanager
+def background_server(*args, verbose=True, **kwargs):
+    """ Runs server in a background process and returns a reference to it. """
+    recv_addr, send_addr = mp.Pipe(duplex=True)
+    trigger_shutdown = mp.Event()
+
+    def server_runner():
+        try:
+            server = make_dummy_server(*args, verbose=verbose, start=True, **kwargs)
+            network_port = server.network.port if server.network is not None else None
+            send_addr.send((server.addr, server.port, network_port))
+            trigger_shutdown.wait()
+        finally:
+            if verbose:
+                print("Shutting down server...")
+            trigger_shutdown.set()  # if server failed internally, set the shutdown trigger anyway
+            server.shutdown()
+            if verbose:
+                print("Server shut down successfully.")
+
+    try:
+        runner = mp.Process(target=server_runner)
+        runner.start()
+        yield recv_addr.recv()  # yield tuple (hostname, server port, network port)
+
+    finally:
+        trigger_shutdown.set()
+        runner.terminate()
+        runner.join()
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--host', type=str, default='0.0.0.0', required=False)
+    parser.add_argument('--port', type=int, default=None, required=False)
+    parser.add_argument('--num_experts', type=int, default=1, required=False)
+    parser.add_argument('--expert_cls', type=str, default='ffn', required=False)
+    parser.add_argument('--hidden_dim', type=int, default=1024, required=False)
+    parser.add_argument('--num_handlers', type=int, default=None, required=False)
+    parser.add_argument('--expert_prefix', type=str, default='expert', required=False)
+    parser.add_argument('--expert_offset', type=int, default=0, required=False)
+    parser.add_argument('--max_batch_size', type=int, default=16384, required=False)
+    parser.add_argument('--device', type=str, default=None, required=False)
+    parser.add_argument('--no_optimizer', action='store_true')
+    parser.add_argument('--no_network', action='store_true')
+    parser.add_argument('--initial_peers', type=str, default="[]", required=False)
+    parser.add_argument('--network_port', type=int, default=None, required=False)
+    parser.add_argument('--root_port', type=int, default=None, required=False)
+
+    parser.add_argument('--increase_file_limit', action='store_true')
+
+    args = vars(parser.parse_args())
+
+    if args.pop('increase_file_limit'):
+        soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
+        try:
+            print("Setting open file limit to soft={}, hard={}".format(max(soft, 2 ** 15), max(hard, 2 ** 15)))
+            resource.setrlimit(resource.RLIMIT_NOFILE, (max(soft, 2 ** 15), max(hard, 2 ** 15)))
+        except:
+            print("Could not increase open file limit, currently at soft={}, hard={}".format(soft, hard))
+
+    args['initial_peers'] = eval(args['initial_peers'])
+
+    try:
+        server = make_dummy_server(**args, start=True, verbose=True)
+        server.join()
+    finally:
+        server.shutdown()