
Initial commit

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
Co-authored-by: Anton Gusev <agus179e@gmail.com>
Co-authored-by: justheuristic <justheuristic@gmail.com>
Max Ryabinin 5 years ago
Commit
fb4ef759ff

+ 37 - 0
README.md

@@ -0,0 +1,37 @@
+## Tesseract
+Distributed training of large neural networks across volunteer computers.
+
+![img](./scheme.png)
+
+__[WIP]__ - this branch is currently being updated. If you're interested in the supplementary code for the [Learning@home paper](https://arxiv.org/abs/2002.04013), you can find it at https://github.com/mryab/learning-at-home .
+
+
+## What do I need to run it?
+* One or several computers, each equipped with at least one GPU
+* Each computer should have at least two open ports (if not, consider SSH port forwarding)
+* Some popular Linux x64 distribution
+  * Tested on Ubuntu 16.04; should work fine on any popular Linux x64 distribution and even macOS;
+  * Running natively on Windows is not supported, please use a VM or Docker;
+
+## How do I run it?
+1. Clone or download this repo. `cd` to its root directory.
+2. Grab or build a working python environment. [Anaconda](https://www.anaconda.com/) works fine.
+3. Install packages from `requirements.txt`
+4. Go to [./experiments](./experiments) and follow the README.md from there
+
+
+## tesseract quick tour
+
+__Trainer process:__
+  * __`RemoteExpert`__ (`tesseract/client/remote_expert.py`) behaves like a pytorch module with autograd support but actually sends requests to a remote runtime.
+  * __`GatingFunction`__ (`tesseract/client/gating_function.py`) finds the best experts for a given input and either returns them as `RemoteExpert` instances or applies them right away.
+
+__Runtime process:__
+  * __`TesseractRuntime`__ (`tesseract/runtime/__init__.py`) aggregates batches and performs inference/training of experts according to their priority.
+  * __`TesseractServer`__ (`tesseract/server/__init__.py`) wraps the runtime and periodically publishes its experts to `TesseractNetwork`.
+
+__DHT:__
+   * __`TesseractNetwork`__ (`tesseract/network/__init__.py`) is a node of a Kademlia-based DHT that stores metadata used by trainers and runtimes.
+
+## Limitations
+WIP
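
For orientation, here is a minimal end-to-end sketch of how the pieces above fit together. It is not taken from [./experiments](./experiments); it is a usage sketch based on the signatures introduced in this commit, and the expert name, ports and layer sizes are made up for illustration:

```python
import torch
import torch.nn as nn

from tesseract import RemoteExpert, TesseractNetwork, TesseractServer
from tesseract.runtime import ExpertBackend
from tesseract.utils import BatchTensorProto

# 1. start a DHT node (the first node has no initial peers to bootstrap from)
network = TesseractNetwork(port=8081, start=True)

# 2. wrap an ordinary pytorch module into an ExpertBackend and serve it
expert = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024))
backend = ExpertBackend(name='expert.0', expert=expert,
                        opt=torch.optim.Adam(expert.parameters()),
                        args_schema=(BatchTensorProto(1024),),
                        max_batch_size=2048)
server = TesseractServer(network, {'expert.0': backend}, port=8080, start=True)
server.ready.wait()  # block until the runtime is ready to accept batches

# 3. on the trainer side, call the expert as if it were a local pytorch module
remote = RemoteExpert(uid='expert.0', host='127.0.0.1', port=8080)
output = remote(torch.randn(4, 1024))
output.sum().backward()  # forward and backward both run on the server
```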

+ 7 - 0
requirements.txt

@@ -0,0 +1,7 @@
+torch>=1.3.0
+joblib>=0.13
+numpy>=1.17
+requests>=2.22.0
+tqdm
+kademlia>=2.2
+prefetch_generator>=1.0.1


+ 33 - 0
setup.py

@@ -0,0 +1,33 @@
+from pkg_resources import parse_requirements
+from setuptools import setup
+
+with open('requirements.txt') as requirements_file:
+    install_requires = [str(requirement) for requirement in parse_requirements(requirements_file)]
+
+setup(
+    name='tesseract',
+    version='0.7',
+    description='',
+    long_description='',
+    author='Learning@home authors',
+    author_email='mryabinin@hse.ru',
+    packages=['tesseract'],
+    license='MIT',
+    install_requires=install_requires,
+    classifiers=[
+        'Development Status :: 4 - Beta',
+        'Intended Audience :: Developers',
+        'Intended Audience :: Science/Research',
+        'License :: OSI Approved :: MIT License',
+        'Programming Language :: Python :: 3',
+        'Programming Language :: Python :: 3.8',
+        'Topic :: Scientific/Engineering',
+        'Topic :: Scientific/Engineering :: Mathematics',
+        'Topic :: Scientific/Engineering :: Artificial Intelligence',
+        'Topic :: Software Development',
+        'Topic :: Software Development :: Libraries',
+        'Topic :: Software Development :: Libraries :: Python Modules',
+    ],
+    # What does your project relate to?
+    keywords='pytorch, deep learning, machine learning, gpu, distributed computing',
+)

+ 4 - 0
tesseract/__init__.py

@@ -0,0 +1,4 @@
+from .client import *
+from .network import *
+from .server import *
+from .utils import *

+ 2 - 0
tesseract/client/__init__.py

@@ -0,0 +1,2 @@
+from .gating_function import GatingFunction
+from .remote_expert import RemoteExpert

+ 153 - 0
tesseract/client/gating_function.py

@@ -0,0 +1,153 @@
+import multiprocessing as mp
+import multiprocessing.pool
+from functools import partial
+from typing import Tuple, List, Dict, Any
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from .remote_expert import RemoteExpert
+from ..utils import nested_map, check_numpy, run_and_await_k
+
+
+class GatingFunction(nn.Module):
+    def __init__(self, *, in_features, grid_size: Tuple[int], network, num_workers=None,
+                 k_best, k_min=1, timeout_after_k_min=1.0, uid_prefix='', expert_padding=None):
+        super().__init__()
+        self.network, self.grid_size = network, grid_size
+        self.uid_prefix, self.expert_padding = uid_prefix, expert_padding
+        self.k_best, self.k_min, self.timeout_after_k_min = k_best, k_min, timeout_after_k_min
+
+        self.thread_pool = mp.pool.ThreadPool(num_workers or k_best * 2)
+        self.proj = nn.Linear(in_features, sum(grid_size))  # jointly predict logits for all grid dimensions
+
+    def forward(self, input: torch.Tensor, *args, **kwargs) -> Tuple[List[List[RemoteExpert]], torch.Tensor]:
+        """
+        Choose k best experts with beam search, then call chosen experts and average their outputs.
+        :param input: 2d tensor [batch_size, in_features]; any extra *args/**kwargs tensors are also batch-first
+        :return: averaged predictions of all experts that delivered results on time
+        """
+        assert len(input.shape) == 2
+
+        # 1. compute scores and find most appropriate experts with beam search
+        grid_scores = self.proj(input).split_with_sizes(self.grid_size, dim=-1)
+        batch_experts = self.beam_search(grid_scores, self.k_best)
+        # ^-- List[batch_size] of List[RemoteExpert] chosen for every input in batch
+
+        # 2.1 call chosen experts (run them in background to save time)
+        batch_outputs_async = [
+            self.thread_pool.apply_async(self._run_experts,
+                                         args=[chosen_experts, input[i: i + 1], *(tensor[i: i + 1] for tensor in args)],
+                                         kwds={key: tensor[i: i + 1] for key, tensor in kwargs.items()})
+            for i, chosen_experts in enumerate(batch_experts)
+        ]
+
+        # 2.2 compute *differentiable* logits for each expert
+        batch_expert_logits = self._score_experts(grid_scores, batch_experts)
+        # ^-- List[batch_size] of Dict[RemoteExpert, logit] before softmax for each active expert
+
+        batch_outputs = []
+        for output_async, expert_logits in zip(batch_outputs_async, batch_expert_logits):
+            expert_outputs: Dict[RemoteExpert, Any] = output_async.get()
+            flat_experts, flat_outputs = zip(*expert_outputs.items())
+
+            # 3.1. normalize logits over only those experts that DID return output
+            flat_logits = torch.stack([expert_logits[expert] for expert in flat_experts])
+            flat_weights = torch.softmax(flat_logits, dim=-1)
+
+            # 3.2. average each output across experts
+            average_outputs = nested_map(
+                lambda *tensors: sum(x * weight for x, weight in zip(tensors, flat_weights)), *flat_outputs)
+
+            batch_outputs.append(average_outputs)
+
+        # 4. concatenate mixture outputs from individual experts
+        return nested_map(lambda *tensors: torch.cat(tensors, dim=0), *batch_outputs)
+
+    def beam_search(self, grid_scores: List[torch.Tensor], k_best: int, **kwargs) -> List[List[RemoteExpert]]:
+        """
+        Find and return k best experts in the grid using (exact) beam search of the product space
+        :param grid_scores: scores predicted for each dimension in the grid,
+        :type grid_scores: a sequence of tensors of shape[batch_size, self.grid_size[i]]
+        :param k_best: how many of the top experts participate in the computation
+        :param kwargs: extra keyword parameters passed to self.network.first_k_active
+        :returns: a list of *batch_size* lists that contain chosen experts for one sample
+            each inner list contains RemoteExpert instances for *up to* k_best experts
+        """
+        assert len(grid_scores) == len(self.grid_size)
+        assert all(len(dim_scores.shape) == 2 for dim_scores in grid_scores)
+        batch_size = len(grid_scores[0])
+        beam = np.array([[self.uid_prefix]] * batch_size, dtype=object)  # [batch_size, up_to_beam_size]
+        scores = np.zeros([batch_size, 1], dtype=np.float64)
+
+        delimeters = np.array(self.network.UID_DELIMETER)[None, None, None]  # pre-compute numpy array for fast concat
+
+        for dim_index, dim_scores in enumerate(grid_scores):
+            dim_scores = check_numpy(dim_scores)
+            assert dim_scores.shape[-1] == self.grid_size[dim_index]
+
+            # create all possible successors from current beam
+            dim_indices = np.arange(dim_scores.shape[1]).astype(str)
+            new_candidates = beam[:, :, None] + delimeters + dim_indices[None, None, :]
+            new_candidates = new_candidates.reshape([batch_size, -1])
+
+            new_scores = scores[:, :, None] + dim_scores[:, None, :]
+            new_scores = new_scores.reshape([batch_size, -1])
+
+            # select k best candidates according to scores but only those that are still active
+            new_order = np.argsort(- new_scores, axis=-1)
+            top_alive_lookups = [
+                self.thread_pool.apply_async(self.network.first_k_active, args=(cands[order], k_best), kwds=kwargs)
+                for cands, order in zip(new_candidates, new_order)]
+
+            batch_cand_to_score = [
+                dict(zip(cands, cand_scores)) for cands, cand_scores in zip(new_candidates, new_scores)]
+
+            top_alive_prefixes = [result.get() for result in top_alive_lookups]
+            top_alive_scores = [list(map(cand_to_score.get, top_cands))
+                                for cand_to_score, top_cands in zip(batch_cand_to_score, top_alive_prefixes)]
+
+            # pad up to beam size
+            beam = np.array([row + [self.expert_padding] * (k_best - len(row))
+                             for row in top_alive_prefixes], dtype='object')
+            scores = np.array([row + [-float('inf')] * (k_best - len(row))
+                               for row in top_alive_scores], dtype='float32')
+
+        unique_experts = self.network.get_experts(list(set(
+            uid for row in beam for uid in row if uid != self.expert_padding)))
+        unique_experts_by_uid = {expert.uid: expert for expert in unique_experts if expert != self.expert_padding}
+
+        return [
+            [unique_experts_by_uid[uid] for uid in row if uid in unique_experts_by_uid]
+            for row in beam]
+
+    def _run_experts(self, experts: List[RemoteExpert], *args, **kwargs) -> Dict[RemoteExpert, torch.Tensor]:
+        outputs = run_and_await_k([partial(expert, *args, **kwargs) for expert in experts],
+                                  k=self.k_min, timeout_after_k=self.timeout_after_k_min)
+        return {expert: output for expert, output in zip(experts, outputs)
+                if not isinstance(output, BaseException)}
+
+    def _score_experts(self, grid_scores: List[torch.Tensor],
+                       experts: List[List[RemoteExpert]]) -> List[Dict[RemoteExpert, torch.Tensor]]:
+        flat_experts = [expert for row in experts for expert in row]
+        flat_batch_indices = torch.tensor([i for i, row in enumerate(experts)
+                                           for uid in range(len(row))])
+
+        grid_indices = np.zeros([len(flat_experts), len(grid_scores)], dtype=np.int64)
+        for i, expert in enumerate(flat_experts):
+            expert_indices = expert.uid[len(self.uid_prefix) + len(self.network.UID_DELIMETER):]
+            expert_indices = list(map(int, expert_indices.split(self.network.UID_DELIMETER)))
+            grid_indices[i] = expert_indices
+
+        scores_per_dim = [
+            dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0)
+            for dim_scores, dim_indices in zip(grid_scores, grid_indices.T)]
+        flat_scores = torch.sum(torch.stack(scores_per_dim, dim=0), dim=0)
+
+        output_dicts = [dict() for _ in range(len(experts))]
+        for batch_i, expert, score in zip(check_numpy(flat_batch_indices),
+                                          flat_experts, flat_scores):
+            output_dicts[batch_i][expert] = score
+
+        return output_dicts
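
For reference, a hedged sketch of how `GatingFunction` might be instantiated on the trainer side. It assumes a running `TesseractNetwork` whose experts were declared under uids such as `ffn.3.7`; the prefix, grid shape and tensor sizes below are illustrative:

```python
import torch

from tesseract import GatingFunction, TesseractNetwork

# join an existing DHT where experts were declared under uids like 'ffn.3.7'
network = TesseractNetwork(('127.0.0.1', 8081), port=8082, start=True)

# a 2D grid of up to 32 x 32 experts, addressed as '<uid_prefix>.<i>.<j>'
dmoe = GatingFunction(in_features=1024, grid_size=(32, 32), network=network,
                      k_best=4, k_min=1, timeout_after_k_min=1.0, uid_prefix='ffn')

x = torch.randn(16, 1024)
averaged = dmoe(x)  # beam search over the grid, then a weighted mixture of the experts that responded
```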

+ 78 - 0
tesseract/client/remote_expert.py

@@ -0,0 +1,78 @@
+from typing import Tuple, Optional
+
+import torch
+import torch.nn as nn
+
+from ..utils import nested_flatten, DUMMY, PytorchSerializer, nested_pack, nested_compare, Connection
+
+
+class RemoteExpert(nn.Module):
+    """
+    A simple module that runs forward/backward of an expert hosted on a remote machine.
+    Works seamlessly with pytorch autograd. (this is essentially a simple RPC function)
+
+    Warning: RemoteExpert currently assumes that you provide it with correct input shapes.
+    Sending inputs of the wrong shape can cause RemoteExpert to freeze indefinitely due to an error in the runtime.
+
+    :param uid: unique expert identifier
+    :param host: hostname where TesseractServer operates
+    :param port: port to which TesseractServer listens
+    """
+
+    def __init__(self, uid, host='127.0.0.1', port=8080):
+        super().__init__()
+        self.uid, self.host, self.port = uid, host, port
+        self._info = None
+
+    def forward(self, *args, **kwargs):
+        assert len(kwargs) == len(self.info['keyword_names']), f"Keyword args should be {self.info['keyword_names']}"
+        kwargs = {key: kwargs[key] for key in self.info['keyword_names']}
+        # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
+
+        forward_inputs = (args, kwargs)
+
+        if not nested_compare(forward_inputs, self.info['forward_schema']):
+            raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
+
+        flat_outputs = _RemoteModuleCall.apply(DUMMY, self.uid, self.host, self.port, *nested_flatten(forward_inputs))
+        # Note: we send DUMMY to prevent torch from excluding expert from backward if no other inputs require grad
+        return nested_pack(flat_outputs, structure=self.info['outputs_schema'])
+
+    @property
+    def info(self):
+        if self._info is None:
+            connection = Connection.create(self.host, self.port)
+            connection.send_raw('info', PytorchSerializer.dumps(self.uid))
+            self._info = PytorchSerializer.loads(connection.recv_message()[1])
+        return self._info
+
+    def extra_repr(self):
+        return f"uid={self.uid}, host={self.host}, port={self.port}"
+
+
+class _RemoteModuleCall(torch.autograd.Function):
+    """ Internal autograd-friendly call of a remote module. For applications, use RemoteExpert instead. """
+
+    @staticmethod
+    def forward(ctx, dummy: torch.Tensor,
+                uid: str, host: str, port: int, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
+        # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
+        inputs = tuple(map(torch.Tensor.detach, inputs))  # detach to avoid pickling the computation graph
+        ctx.uid, ctx.host, ctx.port = uid, host, port
+        ctx.save_for_backward(*inputs)
+
+        connection = Connection.create(ctx.host, ctx.port)
+        connection.send_raw('fwd_', PytorchSerializer.dumps((ctx.uid, inputs)))
+        rtype, msg = connection.recv_message()
+        assert len(msg) != 0, "ExpertBackend.forward did not respond"
+        return tuple(PytorchSerializer.loads(msg))  # flattened expert outputs
+
+    @staticmethod
+    def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
+        connection = Connection.create(ctx.host, ctx.port)
+        payload = tuple(nested_flatten((ctx.saved_tensors, grad_outputs)))
+        connection.send_raw('bwd_', PytorchSerializer.dumps((ctx.uid, payload)))
+        rtype, msg = connection.recv_message()
+        assert len(msg) != 0, "ExpertBackend.backward did not respond"
+        grad_inputs = PytorchSerializer.loads(msg)
+        return (DUMMY, None, None, None, *grad_inputs)
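
To illustrate why `DUMMY` is passed into `_RemoteModuleCall`: a custom autograd function only participates in the backward pass if at least one of its inputs requires grad, so an empty `requires_grad=True` tensor keeps the remote call in the graph even when the actual inputs are plain data. A self-contained local example (no networking; `Echo` is a made-up stand-in for the remote call):

```python
import torch

class Echo(torch.autograd.Function):
    """ A stand-in for _RemoteModuleCall: passes data through and reports when backward is reached """
    @staticmethod
    def forward(ctx, dummy, x):
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        print("backward reached the (would-be remote) call")
        return None, grad_output

DUMMY = torch.empty(0, requires_grad=True)
x = torch.randn(3)                   # plain data, requires_grad=False

y = Echo.apply(torch.empty(0), x)    # no input requires grad -> the op is excluded from autograd
print(y.requires_grad)               # False: backward would never be called

y = Echo.apply(DUMMY, x)             # DUMMY keeps the op in the graph
y.sum().backward()                   # prints the message from backward
```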

+ 130 - 0
tesseract/network/__init__.py

@@ -0,0 +1,130 @@
+import asyncio
+import datetime
+import multiprocessing as mp
+from typing import Tuple, List, Optional
+
+from kademlia.network import Server
+
+from tesseract.client import RemoteExpert
+from tesseract.utils import run_in_background, repeated, SharedFuture, PickleSerializer
+
+
+class TesseractNetwork(mp.Process):
+    UID_DELIMETER = '.'  # expert uids are split by this delimiter
+    HEARTBEAT_EXPIRATION = 120  # expert is inactive iff it fails to post timestamp for *this many seconds*
+    make_key = "{}::{}".format
+
+    def __init__(self, *initial_peers: Tuple[str, int], port=8081, start=False):
+        super().__init__()
+        self.port, self.initial_peers = port, initial_peers
+        self._pipe, self.pipe = mp.Pipe(duplex=False)
+        self.server = Server()
+        if start:
+            self.start()
+
+    def run(self) -> None:
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        loop.run_until_complete(self.server.listen(self.port))
+        loop.run_until_complete(self.server.bootstrap(self.initial_peers))
+        run_in_background(repeated(loop.run_forever))
+
+        while True:
+            method, args, kwargs = self._pipe.recv()
+            getattr(self, method)(*args, **kwargs)
+
+    def get_experts(self, uids: List[str], heartbeat_expiration=HEARTBEAT_EXPIRATION) -> List[Optional[RemoteExpert]]:
+        """ Find experts across DHT using their ids; Return a list of [RemoteExpert if found else None]"""
+        future, _future = SharedFuture.make_pair()
+        self.pipe.send(('_get_experts', [], dict(uids=uids, heartbeat_expiration=heartbeat_expiration, future=_future)))
+        return future.result()
+
+    def _get_experts(self, uids: List[str], heartbeat_expiration: float, future: SharedFuture):
+        loop = asyncio.get_event_loop()
+        lookup_futures = [asyncio.run_coroutine_threadsafe(
+            self.server.get(self.make_key('expert', uid)), loop) for uid in uids]
+        current_time = datetime.datetime.now()
+
+        experts = [None] * len(uids)
+        for i, (uid, lookup) in enumerate(zip(uids, lookup_futures)):
+            if lookup.result() is not None:
+                (host, port), timestamp = PickleSerializer.loads(lookup.result())
+                if (current_time - timestamp).total_seconds() <= heartbeat_expiration:
+                    experts[i] = RemoteExpert(uid=uid, host=host, port=port)
+
+        future.set_result(experts)
+
+    def declare_experts(self, uids: List[str], addr, port, wait_timeout=0):
+        """
+        Make experts available to DHT; update timestamps if already available
+        :param uids: a list of expert ids to update
+        :param addr: hostname that can be used to call this expert
+        :param port: port that can be used to call this expert
+        :param wait_timeout: if wait_timeout > 0, waits for the procedure to finish
+        """
+        done_event = mp.Event() if wait_timeout else None
+        self.pipe.send(('_declare_experts', [], dict(uids=uids, addr=addr, port=port, done_event=done_event)))
+        if done_event is not None:
+            done_event.wait(wait_timeout)
+
+    def _declare_experts(self, uids: List[str], addr: str, port: int, done_event: Optional[mp.Event]):
+        loop = asyncio.get_event_loop()
+        timestamp = datetime.datetime.now()
+        expert_metadata = PickleSerializer.dumps(((addr, port), timestamp))
+        prefix_metadata = PickleSerializer.dumps(timestamp)
+
+        unique_prefixes = set()
+
+        for uid in uids:
+            asyncio.run_coroutine_threadsafe(self.server.set(self.make_key('expert', uid), expert_metadata), loop)
+            uid_parts = uid.split(self.UID_DELIMETER)
+            unique_prefixes.update([self.UID_DELIMETER.join(uid_parts[:i + 1]) for i in range(len(uid_parts))])
+
+        for prefix in unique_prefixes:
+            asyncio.run_coroutine_threadsafe(self.server.set(self.make_key('prefix', prefix), prefix_metadata), loop)
+
+        if done_event is not None:
+            done_event.set()
+
+    def first_k_active(self, prefixes: List[str], k: int, heartbeat_expiration=HEARTBEAT_EXPIRATION, max_prefetch=None):
+        """
+        Find k prefixes with active experts; may return fewer if there aren't enough; used for DMoE beam search
+        :param prefixes: a list of uid prefixes ordered from highest to lowest priority
+        :param k: return at most *this many* active prefixes
+        :param heartbeat_expiration: consider an expert active if its last heartbeat was sent at most this many seconds ago
+        :param max_prefetch: pre-dispatch up to *this many* asynchronous expert requests, defaults to pre-dispatch = k
+        :returns: a list of at most :k: prefixes that have at least one active expert each;
+        """
+        future, _future = SharedFuture.make_pair()
+        self.pipe.send(('_first_k_active', [], dict(prefixes=prefixes, k=k, heartbeat_expiration=heartbeat_expiration,
+                                                    max_prefetch=max_prefetch or k, future=_future)))
+        return future.result()
+
+    def _first_k_active(self, prefixes: List[str], k, heartbeat_expiration, max_prefetch, future: SharedFuture):
+        loop = asyncio.get_event_loop()
+        lookup_prefetch = [asyncio.run_coroutine_threadsafe(
+            self.server.get(self.make_key('prefix', prefix)), loop) for prefix in prefixes[:max_prefetch]]
+        current_time = datetime.datetime.now()
+
+        active_prefixes = []
+
+        for i, prefix in enumerate(prefixes):
+            lookup = lookup_prefetch[i]
+
+            if lookup.result() is not None:
+                timestamp = PickleSerializer.loads(lookup.result())
+                if (current_time - timestamp).total_seconds() <= heartbeat_expiration:
+                    active_prefixes.append(prefix)
+                    if len(active_prefixes) >= k:
+                        future.set_result(active_prefixes)
+                        return
+
+            # pre-dispatch the next request in line
+            if len(lookup_prefetch) < len(prefixes):
+                lookup_prefetch.append(
+                    asyncio.run_coroutine_threadsafe(self.server.get(
+                        self.make_key('prefix', prefixes[len(lookup_prefetch)])), loop)
+                )
+
+        # could not find enough active prefixes; return what we can
+        future.set_result(active_prefixes)
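
A small sketch of the DHT API above, assuming two nodes on one machine; all ports and uids are illustrative, and the underlying store operations are asynchronous, so lookups may need a moment to succeed:

```python
from tesseract import TesseractNetwork

# the first node bootstraps an empty DHT, the second one joins it
node1 = TesseractNetwork(port=8081, start=True)
node2 = TesseractNetwork(('127.0.0.1', 8081), port=8082, start=True)

# a server would normally do this periodically (see NetworkHandlerThread)
node1.declare_experts(['ffn.0.0', 'ffn.0.1'], addr='127.0.0.1', port=8080, wait_timeout=5)

# a trainer can then look experts up by uid or find active prefixes for beam search
experts = node2.get_experts(['ffn.0.0', 'ffn.9.9'])      # expected: [RemoteExpert, None]
active = node2.first_k_active(['ffn.0', 'ffn.1'], k=1)   # expected: ['ffn.0']
```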

+ 81 - 0
tesseract/runtime/__init__.py

@@ -0,0 +1,81 @@
+import multiprocessing as mp
+import threading
+from itertools import chain
+from selectors import DefaultSelector, EVENT_READ
+from typing import Dict
+
+import torch
+import tqdm
+from prefetch_generator import BackgroundGenerator
+
+from .expert_backend import ExpertBackend
+from .task_pool import TaskPool, TaskPoolBase
+
+
+class TesseractRuntime(threading.Thread):
+    def __init__(self, expert_backends: Dict[str, ExpertBackend], prefetch_batches=64, sender_threads: int = 1,
+                 device: torch.device = None):
+        """
+        A thread that aggregates batches from the experts' task pools and processes them on a shared device
+        :param expert_backends: a dict [expert uid -> ExpertBackend]
+        :param prefetch_batches: generate up to this many batches in advance
+        :param sender_threads: use this many threads to send processed outputs back to the task pools
+        :param device: if specified, move all experts to this device before processing batches
+        """
+        super().__init__()
+        self.expert_backends = expert_backends
+        self.pools = tuple(chain(*(expert.get_pools() for expert in expert_backends.values())))
+        self.device, self.prefetch_batches, self.sender_threads = device, prefetch_batches, sender_threads
+        self.shutdown_recv, self.shutdown_send = mp.Pipe(duplex=False)
+        self.ready = mp.Event()  # event is set iff server is currently running and ready to accept batches
+
+    def run(self):
+        progress = tqdm.tqdm(bar_format='{desc}, {rate_fmt}')
+        for pool in self.pools:
+            if not pool.is_alive():
+                pool.start()
+        if self.device is not None:
+            for expert_backend in self.expert_backends.values():
+                expert_backend.to(self.device)
+
+        with mp.pool.ThreadPool(self.sender_threads) as output_sender_pool:
+            try:
+                self.ready.set()
+                for pool, batch_index, batch in BackgroundGenerator(
+                        self.iterate_minibatches_from_pools(), self.prefetch_batches):
+                    outputs = pool.process_func(*batch)
+                    output_sender_pool.apply_async(pool.send_outputs_from_runtime, args=[batch_index, outputs])
+                    progress.update(len(outputs[0]))
+                    progress.desc = f'{pool.uid=} {len(outputs[0])=}'
+            finally:
+                self.shutdown()
+
+    SHUTDOWN_TRIGGER = "RUNTIME SHUTDOWN TRIGGERED"
+
+    def shutdown(self):
+        """ Trigger runtime to terminate, process-save """
+        self.ready.clear()
+        self.shutdown_send.send(self.SHUTDOWN_TRIGGER)  # trigger background thread to shutdown
+        for pool in self.pools:
+            if pool.is_alive():
+                pool.terminate()
+
+    def iterate_minibatches_from_pools(self, timeout=None):
+        """
+        Chooses pool according to priority, then copies exposed batch and frees the buffer
+        """
+        with DefaultSelector() as selector:
+            selector.register(self.shutdown_recv, EVENT_READ, self.SHUTDOWN_TRIGGER)
+            for pool in self.pools:
+                selector.register(pool.batch_receiver, EVENT_READ, pool)
+
+            while True:
+                # wait until at least one batch_receiver becomes available
+                ready_fds = selector.select()
+                ready_objects = {key.data for (key, events) in ready_fds}
+                if self.SHUTDOWN_TRIGGER in ready_objects:
+                    break  # someone asked us to shutdown, break from the loop
+
+                pool = max(ready_objects, key=lambda pool: pool.priority)
+
+                batch_index, batch_tensors = pool.load_batch_to_runtime(timeout, self.device)
+                yield pool, batch_index, batch_tensors
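
`TesseractServer` (below) is the intended entry point, but the runtime can also be driven directly, which makes the data flow easier to see. A sketch with illustrative names and sizes:

```python
import torch
import torch.nn as nn

from tesseract.runtime import ExpertBackend, TesseractRuntime
from tesseract.utils import BatchTensorProto

expert = nn.Linear(1024, 1024)
backend = ExpertBackend(name='expert.0', expert=expert,
                        opt=torch.optim.SGD(expert.parameters(), lr=1e-3),
                        args_schema=(BatchTensorProto(1024),), max_batch_size=1024)

runtime = TesseractRuntime({'expert.0': backend})
runtime.start()          # spawns the task pools and begins polling them by priority
runtime.ready.wait()     # block until the runtime accepts batches

# connection handlers normally feed work through the pools; here we submit a task by hand
outputs = backend.forward_pool.submit_task(torch.randn(2, 1024)).result()
print(outputs[0].shape)  # torch.Size([2, 1024]) -- a tuple with one output tensor

runtime.shutdown()
```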

+ 98 - 0
tesseract/runtime/expert_backend.py

@@ -0,0 +1,98 @@
+from typing import Dict, Sequence, Any, Tuple, Union
+
+import torch
+from torch import nn
+
+from .task_pool import TaskPool
+from ..utils import nested_flatten, nested_pack, nested_compare, BatchTensorProto, DUMMY_BATCH_SIZE, nested_map
+
+
+class ExpertBackend(nn.Module):
+    def __init__(self, name: str, expert: nn.Module, opt: torch.optim.Optimizer, *,
+                 args_schema: Tuple[BatchTensorProto, ...] = None,
+                 kwargs_schema: Dict[str, BatchTensorProto] = None,
+                 outputs_schema: Union[BatchTensorProto, Tuple[BatchTensorProto, ...]] = None,
+                 **kwargs):
+        """
+        ExpertBackend implements how a given expert processes tasks.
+        By default, there are two tasks:
+         * forward receives inputs and produces outputs
+         * backward receives gradients w.r.t. outputs, computes gradients w.r.t. inputs and trains the expert
+
+        All incoming tasks are grouped by type (forward/backward) and sent into the corresponding pool,
+        where tasks are grouped into minibatches and prepared for processing on device;
+        The results are dispatched to task authors with SharedFuture.set_result.
+
+        :param expert: nn.Module to be wrapped into a backend. Arbitrary pytorch module with a few limitations:
+            * Experts must always receive the same set of *args and **kwargs and produce output tensors of the same type
+            * All *args, **kwargs and outputs must be *tensors* where the 0-th dimension represents the batch size
+            * We recommend using experts that are ~invariant to the order in which they process batches
+
+        :param opt: torch optimizer to be applied on every backward call
+        :param args_schema: description of positional arguments to expert.forward, list of BatchTensorProto
+        :param kwargs_schema: description of keyword arguments to expert.forward, dict of BatchTensorProto
+        :param outputs_schema: description of outputs from expert.forward, nested structure of BatchTensorProto
+        :param kwargs: extra parameters to be forwarded into TaskPool.__init__
+        """
+        super().__init__()
+        self.expert, self.opt, self.name = expert, opt, name
+
+        self.args_schema = args_schema = tuple(args_schema or ())
+        self.kwargs_schema = kwargs_schema = dict(kwargs_schema or {})
+        assert args_schema or kwargs_schema, "expert must receive at least one positional or keyword input." \
+                                             " Did you forget to provide args_schema/kwargs_schema?"
+
+        if outputs_schema is None:
+            # run expert once to get outputs schema
+            dummy_args = tuple(sample.make_empty(DUMMY_BATCH_SIZE) for sample in args_schema)
+            dummy_kwargs = {key: sample.make_empty(DUMMY_BATCH_SIZE) for key, sample in kwargs_schema.items()}
+            dummy_outputs = self.expert(*dummy_args, **dummy_kwargs)
+            outputs_schema = nested_map(BatchTensorProto.from_tensor, dummy_outputs)
+
+        self.outputs_schema = outputs_schema
+        self.forward_schema = (self.args_schema, self.kwargs_schema)
+        self.backward_schema = (self.forward_schema, self.outputs_schema)  # original inputs and grad w.r.t. outputs
+        self.forward_pool = TaskPool(self.forward, uid=f'{self.name}_forward', **kwargs)
+        self.backward_pool = TaskPool(self.backward, uid=f'{self.name}_backward', **kwargs)
+
+    def forward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
+        args, kwargs = nested_pack(inputs, structure=self.forward_schema)
+
+        with torch.no_grad():
+            outputs = self.expert(*args, **kwargs)
+
+        # Note: TaskPool requires function to accept and return a **list** of values, we pack/unpack it on client side
+        return tuple(nested_flatten(outputs))
+
+    def backward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
+        (args, kwargs), grad_outputs = nested_pack(inputs, structure=self.backward_schema)
+
+        with torch.enable_grad():
+            args = [tensor.detach().requires_grad_(True) for tensor in args]
+            kwargs = {input_key: tensor.detach().requires_grad_(True) for input_key, tensor in kwargs.items()}
+
+            outputs = self.expert(*args, **kwargs)
+            assert nested_compare(outputs, grad_outputs), "outputs and grad_outputs must have the same structure"
+
+            outputs_flat = tuple(nested_flatten(outputs))
+
+            grad_outputs_flat = tuple(map(
+                lambda grad, out: grad.to(device=out.device, dtype=out.dtype, non_blocking=True),
+                nested_flatten(grad_outputs), outputs_flat))
+            torch.autograd.backward(outputs_flat, grad_tensors=grad_outputs_flat,
+                                    create_graph=False, retain_graph=False)
+            self.apply_gradients()
+
+        return tuple(x.grad if isinstance(x.grad, torch.Tensor) else torch.zeros_like(x)
+                     for x in nested_flatten((args, kwargs)))
+
+    def apply_gradients(self) -> None:
+        self.opt.step()
+        self.opt.zero_grad()
+
+    def get_pools(self) -> Sequence[TaskPool]:
+        return self.forward_pool, self.backward_pool
+
+    def get_info(self) -> Dict[str, Any]:
+        return dict(forward_schema=self.forward_schema, outputs_schema=self.outputs_schema,
+                    keyword_names=tuple(self.kwargs_schema.keys()))
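
A sketch of the schema mechanics, assuming a toy expert with one positional and one keyword tensor (all names and sizes below are illustrative). `forward`/`backward` operate on tensors flattened according to `forward_schema`/`backward_schema`:

```python
import torch
import torch.nn as nn

from tesseract.runtime.expert_backend import ExpertBackend
from tesseract.utils import BatchTensorProto, nested_flatten

class ToyExpert(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 4)

    def forward(self, x, mask):
        return self.proj(x) * mask

expert = ToyExpert()
backend = ExpertBackend(
    name='toy', expert=expert, opt=torch.optim.SGD(expert.parameters(), lr=0.1),
    args_schema=(BatchTensorProto(8),), kwargs_schema={'mask': BatchTensorProto(4)},
    max_batch_size=64)

# inputs arrive flattened in forward_schema order and are packed back before calling the expert
flat_inputs = tuple(nested_flatten(((torch.randn(2, 8),), {'mask': torch.ones(2, 4)})))
(output,) = backend.forward(*flat_inputs)
print(output.shape)  # torch.Size([2, 4])
```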

+ 204 - 0
tesseract/runtime/task_pool.py

@@ -0,0 +1,204 @@
+"""
+Task pool is responsible for receiving tasks and grouping them together for processing (but not processing itself)
+"""
+import ctypes
+import multiprocessing as mp
+import os
+import threading
+import time
+import uuid
+from collections import namedtuple
+from concurrent.futures import Future
+from queue import Empty
+from typing import List, Tuple, Dict, Any
+
+import torch
+
+from ..utils import SharedFuture
+
+Task = namedtuple("Task", ("future", "args"))
+
+
+class TaskPoolBase(mp.Process):
+    """ A pool that accepts tasks and forms batches for parallel processing, interacts with TesseractRuntime """
+
+    def __init__(self, process_func: callable):
+        super().__init__()
+        self.process_func = process_func
+        self._priority = mp.Value(ctypes.c_double, 1.0)  # higher priority = the more urgent to process this pool
+
+    def run(self):
+        raise NotImplementedError()
+
+    def submit_task(self, *args: torch.Tensor) -> Future:
+        raise NotImplementedError()
+
+    def form_batch(self, *args, **kwargs) -> List[Task]:
+        raise NotImplementedError()
+
+    def iterate_minibatches(self, *args, **kwargs):
+        while True:
+            yield self.form_batch(*args, **kwargs)
+
+    @property
+    def priority(self):
+        return self._priority.value
+
+    @priority.setter
+    def priority(self, value):
+        self._priority.value = float(value)
+
+    @property
+    def empty(self):
+        raise NotImplementedError()
+
+
+class TaskPool(TaskPoolBase):
+
+    def __init__(self, process_func: callable, max_batch_size: int, min_batch_size=1,
+                 timeout=None, pool_size=None, prefetch_batches=1, uid=None, start=False):
+        """
+        Naive implementation of a task pool that forms batches from the earliest submitted tasks
+        :param process_func: function to be applied to every formed batch; called by TesseractRuntime
+            Note: process_func should accept only *args Tensors and return a list of output Tensors
+        :param max_batch_size: process at most this many inputs in a batch (a task may contain one or several inputs)
+        :param min_batch_size: process at least this many inputs in a batch, otherwise wait for more
+        :param timeout: wait for a subsequent task for at most this many seconds
+        :param pool_size: store at most this many unprocessed tasks in a queue
+        :param prefetch_batches: prepare up to this many *batches* in background for faster off-loading to runtime
+        :param uid: pool identifier used for shared array allocation
+        :param start: if True, start automatically at the end of __init__
+        """
+
+        super().__init__(process_func)
+        self.min_batch_size, self.max_batch_size, self.timeout = min_batch_size, max_batch_size, timeout
+        self.uid = uid or uuid.uuid4()
+        self.prefetch_batches = prefetch_batches
+
+        # interaction with ConnectionHandlers
+        self.tasks = mp.Queue(maxsize=pool_size or 0)
+        self.undispatched_task_timestamps = mp.SimpleQueue()
+
+        # interaction with TesseractRuntime
+        self.batch_receiver, self.batch_sender = mp.Pipe(duplex=False)  # send/recv arrays that contain batch inputs
+        self.batch_received = mp.Event()  # runtime can notify pool that it can send next batch
+        self.outputs_receiver, self.outputs_sender = mp.Pipe(duplex=False)  # send/recv arrays that contain outputs
+
+        if start:
+            self.start()
+
+    def submit_task(self, *args: torch.Tensor) -> Future:
+        future1, future2 = SharedFuture.make_pair()
+        self.tasks.put(Task(future1, args))
+        self.undispatched_task_timestamps.put(time.time())
+        return future2
+
+    def form_batch(self) -> List[Task]:
+        batch_tasks = []
+        total_size = 0
+
+        while total_size < self.max_batch_size:
+            if total_size >= self.min_batch_size and self.tasks.empty():
+                break  # timeout reached, returning incomplete batch
+
+            try:
+                task = self.tasks.get(timeout=self.timeout)
+            except Empty:
+                exc = TimeoutError(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet.")
+                for task in batch_tasks:
+                    task.future.set_exception(exc)
+                raise exc
+
+            if task.future.set_running_or_notify_cancel():
+                batch_tasks.append(task)
+                total_size += self.get_task_size(task)
+
+        return batch_tasks
+
+    def run(self, *args, **kwargs):
+        print(f'Starting pool, {os.getpid()=}')
+        pending_batches = {}  # Dict[batch uuid, List[SharedFuture]] for each batch currently in runtime
+        output_thread = threading.Thread(target=self._pool_output_loop, args=[pending_batches],
+                                         name=f'{self.uid}-pool_output_loop')
+        try:
+            output_thread.start()
+            self._pool_input_loop(pending_batches, *args, **kwargs)
+        except BaseException as e:
+            # terminate output loop
+            self.outputs_sender.send(e)
+            output_thread.join()
+            raise e
+
+    def _pool_input_loop(self, pending_batches: Dict[Any, List[Task]], *args, **kwargs):
+        """ Infinite loop: aggregate tasks into batches and send them to runtime """
+        prev_num_tasks = 0  # number of tasks currently in shared buffer
+        batch_index = max(pending_batches.keys(), default=0)
+        batch_iterator = self.iterate_minibatches(*args, **kwargs)
+        self.batch_received.set()  # initial state: no batches/outputs pending
+
+        while True:
+            self.batch_received.wait()  # wait for runtime to receive (copy) previous batch
+
+            # SIDE-EFFECT - compute pool priority from timestamp of earliest undispatched task
+            # assumes that tasks are processed in the same order as they are created
+            for skip_i in range(prev_num_tasks):
+                finished_task_timestamp = self.undispatched_task_timestamps.get()  # earlier timestamp = higher priority
+                if skip_i == prev_num_tasks - 1:
+                    self.priority = finished_task_timestamp
+
+            batch_tasks = next(batch_iterator)
+            # save batch futures, _output_loop will deliver on them later
+            pending_batches[batch_index] = batch_tasks
+
+            # concatenate task inputs into shared-memory tensors for the current batch
+            batch_inputs = [
+                torch.cat([task.args[i] for task in batch_tasks]).share_memory_()
+                for i in range(len(batch_tasks[0].args))
+            ]
+
+            self.batch_received.clear()  # sending next batch...
+            self.batch_sender.send((batch_index, batch_inputs))
+            prev_num_tasks = len(batch_tasks)
+            batch_index += 1
+
+    def _pool_output_loop(self, pending_batches: Dict[Any, List[Task]]):
+        """ Infinite loop: receive results from runtime and dispatch them to task Futures """
+
+        while True:
+            payload = self.outputs_receiver.recv()
+            if isinstance(payload, BaseException):
+                raise payload
+            else:
+                batch_index, batch_outputs = payload
+
+            # split batch into partitions for individual tasks
+            batch_tasks = pending_batches.pop(batch_index)
+            task_sizes = [self.get_task_size(task) for task in batch_tasks]
+            outputs_per_task = zip(*(torch.split_with_sizes(array, task_sizes, dim=0) for array in batch_outputs))
+
+            # dispatch results to futures
+            for task, task_outputs in zip(batch_tasks, outputs_per_task):
+                task.future.set_result(tuple(task_outputs))
+
+    @property
+    def empty(self):
+        return not self.batch_receiver.poll()
+
+    def load_batch_to_runtime(self, timeout=None, device=None) -> Tuple[Any, List[torch.Tensor]]:
+        """ receive next batch of numpy arrays """
+        if not self.batch_receiver.poll(timeout):
+            raise TimeoutError()
+
+        batch_index, batch_inputs = self.batch_receiver.recv()
+        self.batch_received.set()  # pool can now prepare next batch
+        batch_inputs = [tensor.to(device, non_blocking=True) for tensor in batch_inputs]
+        return batch_index, batch_inputs
+
+    def send_outputs_from_runtime(self, batch_index: int, batch_outputs: List[torch.Tensor]):
+        """ send results for a processed batch, previously loaded through load_batch_to_runtime """
+        batch_outputs = [tensor.to(device='cpu').share_memory_() for tensor in batch_outputs]
+        self.outputs_sender.send((batch_index, batch_outputs))
+
+    def get_task_size(self, task: Task) -> int:
+        """ compute task processing complexity (used for batching); defaults to batch size """
+        return len(task.args[0]) if task.args else 1
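
To make the pool's contract with the runtime concrete, here is a sketch that drives one pool by hand; in real use `TesseractRuntime` performs the receive/process/send loop. The pool name and shapes are illustrative:

```python
import torch

from tesseract.runtime.task_pool import TaskPool

def double(batch):
    return [batch * 2]

pool = TaskPool(process_func=double, max_batch_size=16, uid='double_pool', start=True)

# a connection handler submits a task and blocks on the returned future
future = pool.submit_task(torch.ones(2, 4))

# the runtime side: receive the formed batch, apply process_func, send the outputs back
batch_index, (batch,) = pool.load_batch_to_runtime(timeout=5.0)
pool.send_outputs_from_runtime(batch_index, double(batch))

print(future.result()[0])  # the task's slice of the doubled batch, shape (2, 4)
pool.terminate()
```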

+ 80 - 0
tesseract/server/__init__.py

@@ -0,0 +1,80 @@
+import multiprocessing as mp
+import os
+import threading
+from socket import socket, AF_INET, SOCK_STREAM, SO_REUSEADDR, SOL_SOCKET, timeout
+from typing import Dict
+from warnings import warn
+
+from .connection_handler import handle_connection
+from .network_handler import NetworkHandlerThread
+from ..network import TesseractNetwork
+from ..runtime import TesseractRuntime, ExpertBackend
+
+
+class TesseractServer(threading.Thread):
+    def __init__(self, network: TesseractNetwork, expert_backends: Dict[str, ExpertBackend], addr='127.0.0.1',
+                 port: int = 8080, conn_handler_processes: int = 1, update_period: int = 30, start=False,
+                 **kwargs):
+        super().__init__()
+        self.network, self.experts, self.update_period = network, expert_backends, update_period
+        self.addr, self.port = addr, port
+        self.conn_handlers = self._create_connection_handlers(conn_handler_processes)
+        self.runtime = TesseractRuntime(self.experts, **kwargs)
+
+        if start:
+            self.start()
+
+    def run(self):
+        if self.network:
+            if not self.network.is_alive():
+                self.network.start()
+
+            network_thread = NetworkHandlerThread(experts=self.experts, network=self.network,
+                                                  addr=self.addr, port=self.port, update_period=self.update_period)
+            network_thread.start()
+
+        for process in self.conn_handlers:
+            if not process.is_alive():
+                process.start()
+
+        self.runtime.run()
+
+        for process in self.conn_handlers:
+            process.join()
+        if self.network:
+            network_thread.join()
+
+    @property
+    def ready(self):
+        return self.runtime.ready  # mp.Event that is true if self is ready to process batches
+
+    def _create_connection_handlers(self, num_handlers):
+        sock = socket(AF_INET, SOCK_STREAM)
+        sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
+        sock.bind(('', self.port))
+        sock.listen()
+        sock.settimeout(self.update_period)
+
+        processes = [mp.Process(target=socket_loop, name=f"socket_loop-{i}", args=(sock, self.experts))
+                     for i in range(num_handlers)]
+        return processes
+
+    def shutdown(self):
+        """ Gracefully terminate a tesseract server, process-safe """
+        self.runtime.shutdown()
+        for process in self.conn_handlers:
+            process.terminate()
+        warn("TODO shutdown network")
+
+
+def socket_loop(sock, experts):
+    """ catch connections, send tasks to processing, respond with results """
+    print(f'Spawned connection handler pid={os.getpid()}')
+    while True:
+        try:
+            handle_connection(sock.accept(), experts)
+        except KeyboardInterrupt as e:
+            print(f'Socket loop has caught {type(e)}, exiting')
+            break
+        except (timeout, BrokenPipeError, ConnectionResetError, NotImplementedError):
+            continue

+ 29 - 0
tesseract/server/connection_handler.py

@@ -0,0 +1,29 @@
+from socket import socket
+from typing import Tuple, Dict
+
+from tesseract.runtime.expert_backend import ExpertBackend
+from tesseract.utils import PytorchSerializer, Connection
+
+
+def handle_connection(connection_tuple: Tuple[socket, str], experts: Dict[str, ExpertBackend]):
+    with Connection(*connection_tuple) as connection:
+        try:
+            header = connection.recv_header()
+            payload = PytorchSerializer.loads(connection.recv_raw())
+
+            if header == 'fwd_':
+                uid, inputs = payload
+                response = experts[uid].forward_pool.submit_task(*inputs).result()
+            elif header == 'bwd_':
+                uid, inputs_and_grad_outputs = payload
+                response = experts[uid].backward_pool.submit_task(*inputs_and_grad_outputs).result()
+            elif header == 'info':
+                uid = payload
+                response = experts[uid].get_info()
+            else:
+                raise NotImplementedError(f"Unknown header: {header}")
+
+            connection.send_raw('rest', PytorchSerializer.dumps(response))
+        except RuntimeError:
+            # socket connection broken
+            pass

+ 20 - 0
tesseract/server/network_handler.py

@@ -0,0 +1,20 @@
+import threading
+import time
+
+from ..network import TesseractNetwork
+
+
+class NetworkHandlerThread(threading.Thread):
+    def __init__(self, experts, network: TesseractNetwork,
+                 update_period: int = 5, addr: str = '127.0.0.1', port: int = 8080):
+        super(NetworkHandlerThread, self).__init__()
+        self.port = port
+        self.addr = addr
+        self.experts = experts
+        self.network = network
+        self.update_period = update_period
+
+    def run(self) -> None:
+        while True:
+            self.network.declare_experts(self.experts.keys(), self.addr, self.port)
+            time.sleep(self.update_period)

+ 7 - 0
tesseract/utils/__init__.py

@@ -0,0 +1,7 @@
+from .connection import *
+from .data import *
+from .nested import *
+from .proto import *
+from .serializer import *
+from .shared_future import *
+from .threading import *

+ 54 - 0
tesseract/utils/connection.py

@@ -0,0 +1,54 @@
+from contextlib import AbstractContextManager
+from socket import socket
+from typing import Tuple
+
+
+class Connection(AbstractContextManager):
+    header_size = 4  # number of characters in all headers
+    payload_length_size = 8  # number of bytes used to encode payload length
+
+    __slots__ = ('conn', 'addr')
+
+    def __init__(self, conn: socket, addr: Tuple[str, int]):
+        self.conn, self.addr = conn, addr
+
+    @staticmethod
+    def create(host: str, port: int):
+        sock = socket()
+        addr = (host, port)
+        sock.connect(addr)
+        return Connection(sock, addr)
+
+    def send_raw(self, header: str, content: bytes):
+        self.conn.send(header.encode())
+        self.conn.send(len(content).to_bytes(self.payload_length_size, byteorder='big'))
+
+        total_sent = 0
+        while total_sent < len(content):
+            sent = self.conn.send(content[total_sent:])
+            if sent == 0:
+                raise RuntimeError("socket connection broken")
+            total_sent = total_sent + sent
+
+    def recv_header(self) -> str:
+        return self.conn.recv(self.header_size).decode()
+
+    def recv_raw(self, max_package: int = 2048) -> bytes:
+        length = int.from_bytes(self.conn.recv(self.payload_length_size), byteorder='big')
+        chunks = []
+        bytes_recd = 0
+        while bytes_recd < length:
+            chunk = self.conn.recv(min(length - bytes_recd, max_package))
+            if chunk == b'':
+                raise RuntimeError("socket connection broken")
+            chunks.append(chunk)
+            bytes_recd = bytes_recd + len(chunk)
+        ret = b''.join(chunks)
+        assert len(ret) == length
+        return ret
+
+    def recv_message(self) -> Tuple[str, bytes]:
+        return self.recv_header(), self.recv_raw()
+
+    def __exit__(self, *exc_info):
+        self.conn.close()
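
A minimal sketch of the framing used above (a 4-character header, an 8-byte big-endian payload length, then the payload), demonstrated with a throwaway echo exchange on localhost; the port is arbitrary:

```python
import threading
from socket import socket, AF_INET, SOCK_STREAM, SO_REUSEADDR, SOL_SOCKET

from tesseract.utils import Connection

def echo_once(listener):
    conn = Connection(*listener.accept())
    header, payload = conn.recv_message()   # read the 4-char header, then the length-prefixed payload
    conn.send_raw('rest', payload)          # echo the payload back under a response header

listener = socket(AF_INET, SOCK_STREAM)
listener.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
listener.bind(('127.0.0.1', 9000))
listener.listen()
threading.Thread(target=echo_once, args=(listener,)).start()

client = Connection.create('127.0.0.1', 9000)
client.send_raw('info', b'expert.0')        # headers must be exactly 4 characters
print(client.recv_message())                # ('rest', b'expert.0')
```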

+ 13 - 0
tesseract/utils/data.py

@@ -0,0 +1,13 @@
+import numpy as np
+import torch
+
+
+def check_numpy(x):
+    """ Makes sure x is a numpy array """
+    if isinstance(x, torch.Tensor):
+        return x.detach().cpu().numpy()
+    else:
+        return np.asarray(x)
+
+
+DUMMY = torch.empty(0, requires_grad=True)

+ 97 - 0
tesseract/utils/nested.py

@@ -0,0 +1,97 @@
+""" utility functions that help you process nested dicts, tuples, lists and namedtuples """
+
+
+def nested_compare(t, u):
+    """
+    Return whether the nested structures of t and u match.
+    """
+    if isinstance(t, (list, tuple)):
+        if not isinstance(u, type(t)):
+            return False
+        if len(t) != len(u):
+            return False
+        for a, b in zip(t, u):
+            if not nested_compare(a, b):
+                return False
+        return True
+
+    if isinstance(t, dict):
+        if not isinstance(u, dict):
+            return False
+        if set(t.keys()) != set(u.keys()):
+            return False
+        for k in t:
+            if not nested_compare(t[k], u[k]):
+                return False
+        return True
+
+    else:
+        return True
+
+
+def nested_flatten(t):
+    """
+    Turn nested list/tuple/dict into a flat iterator.
+    """
+    if isinstance(t, (list, tuple)):
+        for x in t:
+            yield from nested_flatten(x)
+    elif isinstance(t, dict):
+        for k, v in sorted(t.items()):
+            yield from nested_flatten(v)
+    else:
+        yield t
+
+
+def nested_pack(flat, structure):
+    """
+    Restore nested structure from flattened state
+    :param flat: result of nested_flatten
+    :param structure: used as example when recovering structure
+    :returns: nested structure like :structure: filled with elements of :flat:
+    """
+    return _nested_pack(iter(flat), structure)
+
+
+def _nested_pack(flat_iter, structure):
+    if is_namedtuple(structure):
+        return type(structure)(*[
+            _nested_pack(flat_iter, x)
+            for x in structure]
+                               )
+    elif isinstance(structure, (list, tuple)):
+        return type(structure)(
+            _nested_pack(flat_iter, x)
+            for x in structure
+        )
+    elif isinstance(structure, dict):
+        return {
+            k: _nested_pack(flat_iter, v)
+            for k, v in sorted(structure.items())
+        }
+    else:
+        return next(flat_iter)
+
+
+def is_namedtuple(x):
+    """Checks if x is a namedtuple instance. Taken from https://stackoverflow.com/a/2166841 ."""
+    t = type(x)
+    b = t.__bases__
+    if len(b) != 1 or b[0] != tuple: return False
+    f = getattr(t, '_fields', None)
+    if not isinstance(f, tuple): return False
+    return all(type(n) == str for n in f)
+
+
+def nested_map(fn, *t):
+    # Check arguments.
+    if not t:
+        raise ValueError('Expected 2+ arguments, got 1')
+    for i in range(1, len(t)):
+        if not nested_compare(t[0], t[i]):
+            msg = 'Nested structure of %r and %r differs'
+            raise ValueError(msg % (t[0], t[i]))
+
+    # Map.
+    flat = map(nested_flatten, t)
+    return nested_pack(map(fn, *flat), t[0])
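
These helpers underpin the schema handling in `ExpertBackend` and `RemoteExpert`. A quick round-trip illustration:

```python
import torch

from tesseract.utils import nested_compare, nested_flatten, nested_map, nested_pack

structure = ((torch.zeros(2),), {'mask': torch.ones(3), 'bias': torch.ones(1)})

flat = list(nested_flatten(structure))            # dict leaves come out in sorted key order
print([tuple(t.shape) for t in flat])             # [(2,), (1,), (3,)]

restored = nested_pack(flat, structure)           # rebuild the same nesting from the flat list
print(nested_compare(structure, restored))        # True

doubled = nested_map(lambda t: t * 2, structure)  # apply a function leaf by leaf
print(doubled[1]['mask'])                         # tensor([2., 2., 2.])
```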

+ 52 - 0
tesseract/utils/proto.py

@@ -0,0 +1,52 @@
+from dataclasses import dataclass, asdict
+
+import torch
+
+DUMMY_BATCH_SIZE = 3  # used for dummy runs only
+
+
+@dataclass(init=True, repr=True, frozen=True)
+class ProtoBase:
+    pass
+
+
+@dataclass(init=True, repr=True, frozen=True)
+class TensorProto(ProtoBase):
+    size: tuple
+    dtype: torch.dtype = None
+    layout: torch.layout = torch.strided
+    device: torch.device = None
+    requires_grad: bool = False
+    pin_memory: bool = False
+
+    @property
+    def shape(self):
+        return self.size
+
+    @classmethod
+    def from_tensor(cls, tensor: torch.Tensor):
+        return cls(tensor.shape, tensor.dtype, tensor.layout, tensor.device, tensor.requires_grad, tensor.is_pinned())
+
+    def make_empty(self, **kwargs):
+        properties = asdict(self)
+        properties.update(kwargs)
+        return torch.empty(**properties)
+
+
+@dataclass(repr=True, frozen=True)
+class BatchTensorProto(TensorProto):
+    """ torch Tensor with a variable 0-th dimension, used to describe batched data """
+
+    def __init__(self, *instance_size, **kwargs):  # compatibility: allow initializing with *size
+        if len(instance_size) == 1 and isinstance(instance_size[0], (list, tuple, torch.Size)):
+            instance_size = instance_size[0]  # we were given size as the only parameter instead of *parameters
+        super().__init__((None, *instance_size), **kwargs)
+
+    @classmethod
+    def from_tensor(cls, tensor: torch.Tensor):
+        return cls(*tensor.shape[1:], dtype=tensor.dtype, layout=tensor.layout,
+                   device=tensor.device, requires_grad=tensor.requires_grad, pin_memory=tensor.is_pinned())
+
+    def make_empty(self, batch_size, **kwargs):
+        assert self.shape[0] is None, "Make sure 0-th dimension is not specified (set to None)"
+        return super().make_empty(size=(batch_size, *self.shape[1:]), **kwargs)
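
A short illustration of these descriptors as they are used for `ExpertBackend` schemas:

```python
import torch

from tesseract.utils import BatchTensorProto

# describe "a float32 batch where every sample is a 1024-vector"
proto = BatchTensorProto(1024, dtype=torch.float32)
print(proto.shape)                            # (None, 1024): the 0-th dim is the batch size
print(proto.make_empty(batch_size=16).shape)  # torch.Size([16, 1024])

# or infer the descriptor from an existing batch
inferred = BatchTensorProto.from_tensor(torch.randn(8, 1024))
print(inferred.shape)                         # (None, 1024)
```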

+ 41 - 0
tesseract/utils/serializer.py

@@ -0,0 +1,41 @@
+import pickle
+from io import BytesIO
+
+import joblib
+import torch
+
+
+class JoblibSerializer:
+
+    @staticmethod
+    def dumps(obj) -> bytes:
+        s = BytesIO()
+        joblib.dump(obj, s)
+        return s.getvalue()
+
+    @staticmethod
+    def loads(buf: bytes):
+        return joblib.load(BytesIO(buf))
+
+
+class PickleSerializer:
+    @staticmethod
+    def dumps(obj) -> bytes:
+        return pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
+
+    @staticmethod
+    def loads(buf: bytes):
+        return pickle.loads(buf)
+
+
+class PytorchSerializer:
+
+    @staticmethod
+    def dumps(obj) -> bytes:
+        s = BytesIO()
+        torch.save(obj, s, pickle_protocol=pickle.HIGHEST_PROTOCOL)
+        return s.getvalue()
+
+    @staticmethod
+    def loads(buf: bytes):
+        return torch.load(BytesIO(buf))
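
All three serializers expose the same `dumps`/`loads` interface; `PytorchSerializer` is the one used on the wire because it handles tensors natively. For instance:

```python
import torch

from tesseract.utils import PytorchSerializer

payload = PytorchSerializer.dumps(('expert.0', (torch.randn(2, 3),)))
uid, (tensor,) = PytorchSerializer.loads(payload)
print(uid, tensor.shape)  # expert.0 torch.Size([2, 3])
```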

+ 105 - 0
tesseract/utils/shared_future.py

@@ -0,0 +1,105 @@
+import multiprocessing as mp
+import multiprocessing.connection
+from concurrent.futures import Future, CancelledError
+from warnings import warn
+
+
+class SharedFuture(Future):
+    """ Multiprocessing version of concurrent.futures.Future, interacts between two processes via Pipe """
+    STATES = 'pending', 'running', 'cancelled', 'finished', 'exception'
+    STATE_PENDING, STATE_RUNNING, STATE_CANCELLED, STATE_FINISHED, STATE_EXCEPTION = STATES
+
+    def __init__(self, connection: mp.connection.Connection):
+        """ manually create MPFuture. Please use MPFuture.make_pair instead """
+        self.connection = connection
+        self.state = self.STATE_PENDING
+        self._result = None
+        self._exception = None
+
+    @classmethod
+    def make_pair(cls):
+        """ Create a pair of linked futures to be used in two processes """
+        connection1, connection2 = mp.Pipe()
+        return cls(connection1), cls(connection2)
+
+    def _recv(self, timeout):
+        if self.state in (self.STATE_PENDING, self.STATE_RUNNING):
+            if not self.connection.poll(timeout):
+                raise TimeoutError()
+            try:
+                status, payload = self.connection.recv()
+            except BrokenPipeError as e:
+                status, payload = self.STATE_EXCEPTION, e
+
+            assert status in self.STATES
+            self.state = status
+
+            if status == self.STATE_FINISHED:
+                self._result = payload
+            elif status == self.STATE_EXCEPTION:
+                self._exception = payload
+            elif status in (self.STATE_RUNNING, self.STATE_CANCELLED):
+                pass  # only update self.state
+            else:
+                raise ValueError("Result status should not be self.STATE_PENDING")
+
+    def set_result(self, result):
+        try:
+            self.state, self._result = self.STATE_FINISHED, result
+            self.connection.send((self.STATE_FINISHED, result))
+            return True
+        except BrokenPipeError:
+            return False
+
+    def set_exception(self, exception: BaseException):
+        try:
+            self.state, self._exception = self.STATE_EXCEPTION, exception
+            self.connection.send((self.STATE_EXCEPTION, exception))
+            return True
+        except BrokenPipeError:
+            return False
+
+    def set_running_or_notify_cancel(self):
+        return True
+
+    def cancel(self):
+        raise NotImplementedError()
+
+    def result(self, timeout=None):
+        self._recv(timeout)
+        if self.state == self.STATE_FINISHED:
+            return self._result
+        elif self.state == self.STATE_EXCEPTION:
+            raise self._exception
+        else:
+            assert self.state == self.STATE_CANCELLED
+            raise CancelledError()
+
+    def exception(self, timeout=None):
+        self._recv(timeout)
+        return self._exception
+
+    def done(self):
+        return self.state in (self.STATE_FINISHED, self.STATE_EXCEPTION, self.STATE_CANCELLED)
+
+    def running(self):
+        return self.state == self.STATE_RUNNING
+
+    def cancelled(self):
+        warn("cancelled not implemented")
+        return False
+
+    def add_done_callback(self, callback):
+        raise NotImplementedError()
+
+    def __repr__(self):
+        try:
+            self._recv(timeout=0)
+        except TimeoutError:
+            pass
+        if self.state == self.STATE_FINISHED:
+            return "<MPFuture at 0x{:x} state=finished returned {}>".format(id(self), type(self._result))
+        elif self.state == self.STATE_EXCEPTION:
+            return "<MPFuture at 0x{:x} state=finished raised {}>".format(id(self), type(self._exception))
+        else:
+            return "<MPFuture at 0x{:x} state={}>".format(id(self), self.state)

+ 125 - 0
tesseract/utils/threading.py

@@ -0,0 +1,125 @@
+import time
+from concurrent.futures import Future, TimeoutError
+from itertools import count
+from threading import Thread, Event, Lock
+
+
+def run_in_background(func: callable, *args, **kwargs):
+    """ run f(*args, **kwargs) in background and return Future for its outputs """
+    future = Future()
+
+    def _run():
+        try:
+            future.set_result(func(*args, **kwargs))
+        except Exception as e:
+            future.set_exception(e)
+
+    Thread(target=_run).start()
+    return future
+
+
+def repeated(func: callable, n_times=None):
+    """ A function that runs a :func: forever or for a specified number of times; use with run_run_in_background """
+
+    def repeat():
+        for i in count():
+            if n_times is not None and i > n_times:
+                break
+            func()
+
+    return repeat
+
+
+def add_event_callback(event: Event, callback, timeout=None):
+    """ Add callback that will be executed asynchronously when event is set """
+    return Thread(target=lambda: (event.wait(timeout), callback())).start()
+
+
+class CountdownEvent(Event):
+    def __init__(self, count_to: int, initial=0):
+        """ An event that must be incremented :count_to: times before it is considered set """
+        super().__init__()
+        self.value = initial
+        self.count_to = count_to
+        self.lock = Lock()
+        self.increment(by=0)  # trigger set/unset depending on initial value
+
+    def increment(self, by=1):
+        with self.lock:
+            self.value += by
+            if self.value >= self.count_to:
+                super().set()
+            else:
+                super().clear()
+            return self.value
+
+    def clear(self):
+        return self.increment(by=-self.value)
+
+
+def await_first(*events: Event, k=1, timeout=None):
+    """
+    Wait until at least k (default=1) of the events are set; return True on success, raise TimeoutError otherwise.
+    Note: after k successes we manually *set* all events so that their watcher threads can terminate (avoids a leak).
+    """
+    events_done = CountdownEvent(count_to=k)
+    for event in events:
+        add_event_callback(event, callback=events_done.increment, timeout=timeout)
+
+    if events_done.wait(timeout=timeout):
+        [event.set() for event in events]
+        return True
+    else:
+        raise TimeoutError()
+
+
+def run_and_await_k(jobs, k, timeout_after_k=0, timeout_total=None):
+    """
+    Runs all :jobs: asynchronously and waits for at least k of them to finish
+    :param jobs: a sequence of no-argument functions to call
+    :param k: how many of the jobs should finish successfully
+    :param timeout_after_k: after reaching k finished jobs, wait at most this long before cancelling the rest
+    :param timeout_total: if specified, cancel jobs that are still running after this many seconds
+    :returns: a list with either a result or an exception for each job
+    """
+    assert k <= len(jobs)
+    start_time = time.time()
+    min_successful_jobs = CountdownEvent(count_to=k)
+    max_failed_jobs = CountdownEvent(count_to=len(jobs) - k + 1)
+
+    def _run_and_increment(run_job: callable):
+        try:
+            result = run_job()
+            min_successful_jobs.increment()
+            return result
+        except Exception as e:
+            max_failed_jobs.increment()
+            return e
+
+    def _run_and_await(run_job: callable):
+        # call function asynchronously. Increment counter after finished
+        future = run_in_background(_run_and_increment, run_job)
+
+        try:  # await for success counter to reach k OR for fail counter to reach n - k + 1
+            await_first(min_successful_jobs, max_failed_jobs,
+                        timeout=None if timeout_total is None else timeout_total - time.time() + start_time)
+        except TimeoutError as e:  # counter didn't reach k jobs in timeout_total
+            return future.result() if future.done() else e
+
+        try:  # await for subsequent jobs if asked to
+            return future.result(timeout=timeout_after_k)
+        except TimeoutError as e:
+            future.cancel()
+            return e
+
+        except Exception as e:  # job failed with exception. Ignore it.
+            return e
+
+    results = [run_in_background(_run_and_await, f) for f in jobs]
+    results = [result.result() for result in results]
+    if min_successful_jobs.is_set():
+        return results
+    elif max_failed_jobs.is_set():
+        raise ValueError("Could not get enough results: too many jobs failed.")
+    else:
+        raise TimeoutError("Could not get enough results: reached timeout_total.")

+ 142 - 0
tests/benchmark_throughput.py

@@ -0,0 +1,142 @@
+import argparse
+import multiprocessing as mp
+import random
+import resource
+import sys
+import time
+
+import torch
+from test_utils import layers, print_device_info, find_open_port
+
+import tesseract
+
+
+def client_process(can_start, benchmarking_failed, port, num_experts, batch_size, hid_dim, num_batches, backprop=True):
+    can_start.wait()
+    experts = [tesseract.RemoteExpert(f"expert{i}", port=port) for i in range(num_experts)]
+
+    try:
+        dummy_batch = torch.randn(batch_size, hid_dim)
+        for batch_i in range(num_batches):
+            expert = random.choice(experts)
+            out = expert(dummy_batch)
+            if backprop:
+                out.sum().backward()
+    except BaseException as e:
+        benchmarking_failed.set()
+        raise e
+
+
+def benchmark_throughput(num_experts=16, num_handlers=None, num_clients=128, num_batches_per_client=16,
+                         expert_cls='ffn', hid_dim=1024, batch_size=2048, max_batch_size=None, backprop=True,
+                         device=None, port=None):
+    assert not hasattr(torch.cuda, 'is_initialized') or not torch.cuda.is_initialized() \
+           or torch.device(device) == torch.device('cpu')
+    assert expert_cls in layers.name_to_block
+    port = port or find_open_port()
+    max_batch_size = max_batch_size or batch_size * 4
+    num_handlers = max(1, num_handlers or num_clients // 2)
+    benchmarking_failed = mp.Event()
+    can_start = mp.Event()
+    timestamps = dict(started=time.perf_counter())
+
+    try:
+        # start clients and await server
+        # Note: client processes must be launched BEFORE touching gpu, even torch.cuda.is_available can cause trouble
+        clients = [
+            mp.Process(
+                target=client_process, name=f'client_process-{i}',
+                args=(can_start, benchmarking_failed, port, num_experts, batch_size,
+                      hid_dim, num_batches_per_client, backprop))
+            for i in range(num_clients)]
+
+        for client in clients:
+            client.daemon = True
+            client.start()
+
+        timestamps['launched_clients'] = timestamps['began_launching_server'] = time.perf_counter()
+
+        # start server
+        device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
+        experts = {}
+        for i in range(num_experts):
+            expert = torch.jit.script(layers.name_to_block[expert_cls](hid_dim))
+            experts[f'expert{i}'] = tesseract.ExpertBackend(name=f'expert{i}',
+                                                            expert=expert, opt=torch.optim.Adam(expert.parameters()),
+                                                            args_schema=(tesseract.BatchTensorProto(hid_dim),),
+                                                            outputs_schema=tesseract.BatchTensorProto(hid_dim),
+                                                            max_batch_size=max_batch_size,
+                                                            )
+        timestamps['created_experts'] = time.perf_counter()
+        server = tesseract.TesseractServer(None, experts, port=port, conn_handler_processes=num_handlers, device=device)
+        server.start()
+        server.ready.wait()
+        timestamps['server_ready'] = time.perf_counter()
+        can_start.set()
+
+        for client in clients:
+            client.join()
+        timestamps['clients_finished'] = time.perf_counter()
+    except BaseException as e:
+        benchmarking_failed.set()
+        raise e
+    finally:
+        for client in clients:
+            if client.is_alive():
+                client.terminate()
+        server.shutdown()
+        timestamps['server_shutdown_finished'] = time.perf_counter()
+        server.join()
+
+    sys.stdout.flush()
+    sys.stderr.flush()
+    time_between = lambda key1, key2: \
+        abs(timestamps[key2] - timestamps[key1]) if (key1 in timestamps and key2 in timestamps) else float('nan')
+    total_examples = batch_size * num_clients * num_batches_per_client
+
+    print('\n' * 3)
+    print("Benchmark finished, status:", ["Success", "Failure"][benchmarking_failed.is_set()])
+    print(f"Server parameters: {num_experts=} {num_handlers=} {max_batch_size=} {expert_cls=} {hid_dim=} {device=}")
+    print(f"Client parameters: {num_clients=} {num_batches_per_client=} {batch_size=} {backprop=}")
+    print(f"Results: ")
+    print(f"\tServer startup took {time_between('began_launching_server', 'server_ready') :.3f} s. "
+          f"({time_between('began_launching_server', 'created_experts') :.3f} s. experts + "
+          f"{time_between('created_experts', 'server_ready') :.3f} s. networking)")
+    print(f"\tProcessed {total_examples} examples in {time_between('server_ready', 'clients_finished') :.3f}")
+    print(f"\tThroughput for {'forward + backward' if backprop else 'forward'} passes: "
+          f"{total_examples / time_between('server_ready', 'clients_finished') :.3f} samples / s.")
+    print(f"\tBenchmarking took {time_between('started', 'server_shutdown_finished') :.3f} s.")
+    if benchmarking_failed.is_set():
+        print("Note: benchmark code failed, timing/memory results only indicate time till failure!")
+    print_device_info(device)
+    print(flush=True)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--preset', type=str, default='default', required=False)
+    parser.add_argument('--num_batches_per_client', type=int, default=16, required=False)
+    args = parser.parse_args()
+
+    if args.preset in ('default', 'ffn_forward_backward'):
+        benchmark_throughput()
+    elif args.preset == 'ffn_forward':
+        benchmark_throughput(backprop=False, num_batches_per_client=args.num_batches_per_client)
+    elif args.preset == 'ffn_small_batch':
+        benchmark_throughput(backprop=False, num_experts=4, batch_size=32, max_batch_size=8192,
+                             num_batches_per_client=args.num_batches_per_client)
+    elif args.preset == 'ffn_massive':
+        soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
+        try:
+            print("Setting open file limit to soft={}, hard={}".format(max(soft, 2 ** 15), max(hard, 2 ** 15)))
+            resource.setrlimit(resource.RLIMIT_NOFILE, (max(soft, 2 ** 15), max(hard, 2 ** 15)))
+        except (ValueError, OSError):
+            print("Could not increase open file limit, currently at soft={}, hard={}".format(soft, hard))
+        benchmark_throughput(backprop=False, num_clients=512, batch_size=512,
+                             max_batch_size=8192, num_batches_per_client=args.num_batches_per_client)
+    elif args.preset == 'minimalistic':
+        benchmark_throughput(num_experts=1, num_clients=1, num_handlers=1)
+    elif args.preset == 'nop':
+        benchmark_throughput(expert_cls='nop', backprop=False, num_batches_per_client=args.num_batches_per_client)
+    else:
+        raise ValueError(f"No such benchmark preset: {args.preset}")

+ 25 - 0
tests/test_utils/__init__.py

@@ -0,0 +1,25 @@
+from socket import socket
+
+import torch
+
+
+def print_device_info(device=None):
+    # prints device stats. Code from https://stackoverflow.com/a/53374933/12891528
+    device = torch.device(device or ('cuda' if torch.cuda.is_available() else 'cpu'))
+    print('Using device:', device)
+
+    # Additional Info when using cuda
+    if device.type == 'cuda':
+        print(torch.cuda.get_device_name(0))
+        print('Memory Usage:')
+        print('Allocated:', round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1), 'GB')
+        print('Cached:   ', round(torch.cuda.memory_cached(0) / 1024 ** 3, 1), 'GB')
+
+
+def find_open_port():
+    """ Ask the OS for a free TCP port and return its number """
+    try:
+        with socket() as sock:
+            sock.bind(('', 0))
+            return sock.getsockname()[1]
+    except OSError as e:
+        raise ValueError("Could not find open port") from e

+ 68 - 0
tests/test_utils/layers.py

@@ -0,0 +1,68 @@
+import torch
+from torch import nn as nn
+
+
+class FeedforwardBlock(nn.Module):
+    def __init__(self, hid_dim):
+        super().__init__()
+        self.layers = nn.Sequential(
+            nn.Linear(hid_dim, 4 * hid_dim),
+            nn.LayerNorm(4 * hid_dim),
+            nn.ReLU(inplace=True),
+            nn.Linear(4 * hid_dim, 4 * hid_dim),
+            nn.LayerNorm(4 * hid_dim),
+            nn.ReLU(inplace=True),
+            nn.Linear(4 * hid_dim, hid_dim),
+        )
+
+    def forward(self, x):
+        return x + self.layers(x)
+
+
+class TransformerEncoderLayer(nn.Module):
+    """
+    A slight modification of torch.nn.TransformerEncoderLayer which allows for torch.jit scripting
+    """
+
+    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
+        super().__init__()
+        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+        # Implementation of Feedforward model
+        self.linear1 = nn.Linear(d_model, dim_feedforward)
+        self.dropout = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+
+        self.activation = torch.nn.GELU()
+
+    def forward(self, src):
+        src.transpose_(0, 1)
+        src2 = self.self_attn(src, src, src)[0]
+        src = src + self.dropout1(src2)
+        src = self.norm1(src)
+        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
+        src = src + self.dropout2(src2)
+        src = self.norm2(src)
+        src.transpose_(0, 1)
+        return src
+
+
+class NopExpert(nn.Sequential):
+    def __init__(self, hid_dim):
+        super().__init__()
+        self.w = nn.Parameter(torch.zeros(0), requires_grad=True)
+
+    def forward(self, x):
+        return x.clone()
+
+
+name_to_block = {'ffn': lambda hid_dim: FeedforwardBlock(hid_dim),
+                 'transformer': lambda hid_dim: TransformerEncoderLayer(hid_dim, nhead=16),
+                 'nop': lambda hid_dim: NopExpert(hid_dim)}
+name_to_input = {'ffn': lambda batch_size, hid_dim: torch.empty((batch_size, hid_dim)),
+                 'transformer': lambda batch_size, hid_dim: torch.empty((batch_size, 512, hid_dim)),
+                 'nop': lambda batch_size, hid_dim: torch.empty((batch_size, hid_dim))}
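A short sketch tying the two registries together, assuming the `tests` directory is on the path (as when running the benchmark above): the block factory builds an expert and the input factory produces a dummy batch of the matching shape.

```python
from test_utils.layers import name_to_block, name_to_input

hid_dim, batch_size = 1024, 4
expert = name_to_block['ffn'](hid_dim)                   # FeedforwardBlock(1024)
dummy_input = name_to_input['ffn'](batch_size, hid_dim)  # uninitialized (4, 1024) tensor
output = expert(dummy_input.normal_())                   # fill with noise, then run the block
assert output.shape == dummy_input.shape
```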