
GRPC connection handlers (#61)

* Add connection handlers with grpc

* Implement grpc for client/server

* Cache channel, get rid of warnings

* Awaitable interactions with TaskPool

* [personal] parallel gRPC handlers (#69)

* spawn multiple connection handlers

* remove reserve_port

* minor: make preset_minimalistic actually care about num_batches_per_client

* RemoteExpert can no longer be pickled

* fix broken moe.py after changes in RemoteModuleCall

* write-through

* fix moe.py (broken by changes in _RemoteModuleCall)

* fix a bug where connection handlers set ready flag prematurely

* fix wrong gradient type if not all experts survived

* connection_handler now sets ready only when actually ready

* create stub in a lazy manner

* rollback changes to DHT

* Update TODO, remove message limits

* Connection is gone 🦀🦀🦀

* Cleanup

* Switch to absolute imports (#70)

Co-authored-by: xtinkt <ant.sinitsin@gmail.com>
Co-authored-by: justheuristic <justheuristic@gmail.com>
Max Ryabinin 5 years ago
Parent
Current commit
f1565ef7af

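Summary: this PR replaces the old socket-based Connection protocol with gRPC end to end. The server now spawns one or more ConnectionHandler processes that share a port, tensors travel as protobuf messages, and the client-side RemoteExpert creates its gRPC stub lazily. As an orientation sketch only (it assumes a hivemind server is already listening on 127.0.0.1:8080 and serves an expert named 'expert.0' that accepts a single 1024-dim input; the names and shapes are illustrative, not part of this diff):

import torch
import hivemind

# RemoteExpert builds its gRPC channel and stub lazily on first use (see hivemind/client/expert.py below)
expert = hivemind.RemoteExpert(uid='expert.0', host='127.0.0.1', port=8080)

dummy_input = torch.randn(1, 1024)
output = expert(dummy_input)   # forward() serializes tensors and calls ConnectionHandler.forward over gRPC
output.sum().backward()        # backward() goes through the same stub and deserializes the returned gradients
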
+ 1 - 1
.circleci/config.yml

@@ -27,7 +27,7 @@ jobs:
          command: sudo python setup.py develop
          name: setup
      - run:
-          command: pytest ./tests --full-trace
+          command: for test_file in tests/test*.py; do pytest $test_file --full-trace; done
          name: tests
      - run:
          command: python tests/benchmark_throughput.py --preset minimalistic

+ 5 - 5
hivemind/__init__.py

@@ -1,7 +1,7 @@
-from .client import *
-from .dht import *
-from .server import Server
-from .utils import *
-from .runtime import *
+from hivemind.client import *
+from hivemind.dht import *
+from hivemind.server import Server
+from hivemind.utils import *
+from hivemind.runtime import *
 
 
__version__ = '0.7.1'

+ 2 - 2
hivemind/client/__init__.py

@@ -1,2 +1,2 @@
-from .moe import RemoteMixtureOfExperts
-from .expert import RemoteExpert
+from hivemind.client.expert import RemoteExpert
+from hivemind.client.moe import RemoteMixtureOfExperts

+ 41 - 19
hivemind/client/expert.py

@@ -1,10 +1,14 @@
+import pickle
from typing import Tuple, Optional

+import grpc
+import grpc.experimental.aio
import torch
import torch.nn as nn
from torch.autograd.function import once_differentiable

-from ..utils import nested_flatten, DUMMY, PytorchSerializer, nested_pack, nested_compare, Connection
+from hivemind.utils import nested_flatten, DUMMY, nested_pack, nested_compare
+from hivemind.utils.grpc import serialize_torch_tensor, deserialize_torch_tensor, runtime_pb2, runtime_grpc
 
 
 
 
class RemoteExpert(nn.Module):
@@ -23,12 +27,29 @@ class RemoteExpert(nn.Module):
    def __init__(self, uid, host='127.0.0.1', port=8080):
        super().__init__()
        self.uid, self.host, self.port = uid, host, port
+        self._channel, self._stub = None, None
        self._info = None

+    @property
+    def stub(self):
+        if self._channel is None:
+            self._channel = grpc.insecure_channel(f'{self.host}:{self.port}', options=[
+                ('grpc.max_send_message_length', -1),
+                ('grpc.max_receive_message_length', -1)
+            ])
+        if self._stub is None:
+            self._stub = runtime_grpc.ConnectionHandlerStub(self._channel)
+        return self._stub
+
+    def __del__(self):
+        if self._channel is not None:
+            self._channel.close()
+
    def forward(self, *args, **kwargs):
        """ Call RemoteExpert for the specified inputs and return its output(s). Compatible with pytorch.autograd. """
        assert len(kwargs) == len(self.info['keyword_names']), f"Keyword args should be {self.info['keyword_names']}"
        kwargs = {key: kwargs[key] for key in self.info['keyword_names']}
+
        # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors

        forward_inputs = (args, kwargs)
@@ -36,16 +57,16 @@ class RemoteExpert(nn.Module):
        if not nested_compare(forward_inputs, self.info['forward_schema']):
            raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")

-        flat_outputs = _RemoteModuleCall.apply(DUMMY, self.uid, self.host, self.port, *nested_flatten(forward_inputs))
+        flat_outputs = _RemoteModuleCall.apply(DUMMY, self.uid, self.host, self.port, self.stub,
+                                               *nested_flatten(forward_inputs))
        # Note: we send DUMMY to prevent torch from excluding expert from backward if no other inputs require grad
        return nested_pack(flat_outputs, structure=self.info['outputs_schema'])

    @property
    def info(self):
        if self._info is None:
-            connection = Connection.create(self.host, self.port)
-            connection.send_raw('info', PytorchSerializer.dumps(self.uid))
-            self._info = PytorchSerializer.loads(connection.recv_message()[1])
+            outputs = self.stub.info(runtime_pb2.ExpertUID(uid=self.uid))
+            self._info = pickle.loads(outputs.serialized_info)
        return self._info

    def extra_repr(self):
@@ -56,26 +77,27 @@ class _RemoteModuleCall(torch.autograd.Function):
    """ Internal autograd-friendly call of a remote module. For applications, use RemoteExpert instead. """

    @staticmethod
-    def forward(ctx, dummy: torch.Tensor,
-                uid: str, host: str, port: int, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
+    def forward(ctx, dummy: torch.Tensor, uid: str, host: str, port: int, stub: runtime_grpc.ConnectionHandlerStub,
+                *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
        inputs = tuple(map(torch.Tensor.detach, inputs))  # detach to avoid pickling the computation graph
-        ctx.uid, ctx.host, ctx.port = uid, host, port
+        ctx.uid, ctx.host, ctx.port, ctx.stub = uid, host, port, stub
        ctx.save_for_backward(*inputs)

-        connection = Connection.create(ctx.host, ctx.port)
-        connection.send_raw('fwd_', PytorchSerializer.dumps((ctx.uid, inputs)))
-        rtype, msg = connection.recv_message()
-        assert len(msg) != 0, "ExpertBackend.forward did not respond"
-        return tuple(PytorchSerializer.loads(msg))  # flattened expert outputs
+        outputs = stub.forward(
+            runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=[serialize_torch_tensor(tensor) for tensor in inputs]))
+
+        deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
+
+        return tuple(deserialized_outputs)
 
 
    @staticmethod
    @once_differentiable
    def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
-        connection = Connection.create(ctx.host, ctx.port)
        payload = tuple(nested_flatten((ctx.saved_tensors, grad_outputs)))
-        connection.send_raw('bwd_', PytorchSerializer.dumps((ctx.uid, payload)))
-        rtype, msg = connection.recv_message()
-        assert len(msg) != 0, "ExpertBackend.backward did not respond"
-        grad_inputs = PytorchSerializer.loads(msg)
-        return (DUMMY, None, None, None, *grad_inputs)
+
+        grad_inputs = ctx.stub.backward(
+            runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=[serialize_torch_tensor(tensor) for tensor in payload]))
+
+        deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
+        return (DUMMY, None, None, None, None, *deserialized_grad_inputs)

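The lazy stub property above is also why "RemoteExpert can no longer be pickled": a live grpc channel is not picklable, so it is created on first access, cached, and closed in __del__. A stripped-down sketch of the same pattern, independent of hivemind (stub_factory stands in for any generated *Stub class and is a placeholder name, not part of this diff):

import grpc

class LazyStubClient:
    """Create the gRPC channel/stub on first access and reuse them afterwards."""

    def __init__(self, host: str, port: int, stub_factory):
        self.host, self.port, self.stub_factory = host, port, stub_factory
        self._channel, self._stub = None, None  # not picklable once created

    @property
    def stub(self):
        if self._channel is None:
            # unlimited message sizes, same channel options as in the diff above
            self._channel = grpc.insecure_channel(f'{self.host}:{self.port}', options=[
                ('grpc.max_send_message_length', -1),
                ('grpc.max_receive_message_length', -1),
            ])
        if self._stub is None:
            self._stub = self.stub_factory(self._channel)
        return self._stub

    def __del__(self):
        if self._channel is not None:
            self._channel.close()
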
+ 12 - 11
hivemind/client/moe.py

@@ -1,16 +1,14 @@
-import multiprocessing as mp
-import multiprocessing.pool
from functools import partial
-from typing import Tuple, List, Dict, Optional
+from typing import Tuple, List, Optional
 
 
import numpy as np
import torch
import torch.nn as nn
from torch.autograd.function import once_differentiable

-from .expert import RemoteExpert, _RemoteModuleCall
-from ..utils import nested_map, check_numpy, run_and_await_k, nested_pack, nested_flatten, DUMMY, run_in_background
-from ..utils import run_isolated_forward, EmulatedAutogradContext, run_isolated_backward, map_with_parallel_backward
+from hivemind.client.expert import RemoteExpert, _RemoteModuleCall
+from hivemind.utils import nested_map, run_and_await_k, nested_pack, nested_flatten, DUMMY, run_in_background, \
+    run_isolated_forward, EmulatedAutogradContext, run_isolated_backward, map_with_parallel_backward
 
 
 
 
class RemoteMixtureOfExperts(nn.Module):
@@ -37,6 +35,7 @@ class RemoteMixtureOfExperts(nn.Module):
     allow_broadcasting=True will flatten first d-1 input dimensions, apply RemoteMixtureOfExperts and un-flatten again
     allow_broadcasting=False will raise an error
    """
+
    def __init__(self, *, in_features, grid_size: Tuple[int], dht, k_best, k_min=1,
                 forward_timeout=None, timeout_after_k_min=1.0, backward_k_min=1, backward_timeout=None,
                 uid_prefix='', expert_padding=None, allow_broadcasting=True):
@@ -107,7 +106,7 @@ class RemoteMixtureOfExperts(nn.Module):
        delimeters = np.array(self.dht.UID_DELIMETER)[None, None, None]  # pre-compute numpy array for fast concat

        for dim_index, dim_scores in enumerate(grid_scores):
-            dim_scores = check_numpy(dim_scores)
+            dim_scores = dim_scores.detach().cpu().numpy()
            assert dim_scores.shape[-1] == self.grid_size[dim_index]

            # create all possible successsors from current beam
@@ -194,6 +193,7 @@ class _RemoteMoECall(torch.autograd.Function):
    This function that can recover from individual failures during forward and/or backward passes.
    For user-friendly version of this function, use RemoteMixtureOfExperts module.
    """
+
    @classmethod
    def forward(cls, ctx, expert_logits: torch.Tensor, experts: List[RemoteExpert],
                k_min: int, timeout_after_k_min: float, backward_k_min: int, timeout_total: Optional[float],
@@ -250,18 +250,19 @@ class _RemoteMoECall(torch.autograd.Function):
            for grad_out, stacked_avive_out in zip(grad_outputs_flat, stacked_alive_outputs)
        ))
        softmax_jacobian = torch.diagflat(survived_probas) - torch.ger(survived_probas, survived_probas)
-        grad_wrt_logits = grad_wrt_probs @ softmax_jacobian
+        grad_wrt_survived_logits = grad_wrt_probs @ softmax_jacobian
+        grad_wrt_logits = torch.zeros_like(expert_logits).scatter(0, backward_survivors_ix, grad_wrt_survived_logits)
 
 
        return (grad_wrt_logits, None, None, None, None, None, None, None, *flat_grad_inputs)

    @staticmethod
    def _run_expert_forward(expert: RemoteExpert, *args: torch.Tensor, **kwargs: torch.Tensor):
        """ Call remote expert and return flattened outputs. Compatible with concurrent autograd. """
-        flat_inputs = nested_flatten((args, kwargs))
-        return run_isolated_forward(_RemoteModuleCall, DUMMY, expert.uid, expert.host, expert.port, *flat_inputs)
+        return run_isolated_forward(_RemoteModuleCall, DUMMY, expert.uid, expert.host, expert.port, expert.stub,
+                                    *nested_flatten((args, kwargs)))
 
 
    @staticmethod
    def _run_expert_backward(ctx: EmulatedAutogradContext, weight: torch.Tensor, *grad_outputs: torch.Tensor):
        backward_result = run_isolated_backward(_RemoteModuleCall, ctx, *(grad * weight for grad in grad_outputs))
-        grad_dummy, no_grad_uid, no_grad_hostname, no_grad_port, *grad_inputs = backward_result
+        grad_dummy, no_grad_uid, no_grad_hostname, no_grad_port, no_grad_stub, *grad_inputs = backward_result
        return grad_inputs

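The grad_wrt_logits change above is the "fix wrong gradient type if not all experts survived" item: gradients are computed only for the experts that answered the backward call, then scattered back into a zero tensor shaped like the full expert_logits. A toy illustration of that scatter with made-up numbers (not taken from this diff):

import torch

expert_logits = torch.randn(5)                     # logits for 5 candidate experts
backward_survivors_ix = torch.tensor([0, 2, 4])    # only these experts responded in time
grad_wrt_survived_logits = torch.tensor([0.1, -0.3, 0.2])

# place each survivor's gradient at its original position, zeros elsewhere
grad_wrt_logits = torch.zeros_like(expert_logits).scatter(0, backward_survivors_ix, grad_wrt_survived_logits)
print(grad_wrt_logits)  # tensor([ 0.1000,  0.0000, -0.3000,  0.0000,  0.2000])
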
+ 5 - 5
hivemind/dht/__init__.py

@@ -17,13 +17,13 @@ import ctypes
import multiprocessing as mp
import warnings
from typing import List, Optional
-import uvloop
 
 
-from .node import DHTNode, DHTID, DHTExpiration
-from .routing import get_dht_time
+import uvloop
 
 
-from ..client import RemoteExpert
-from ..utils import SharedFuture, find_open_port, Endpoint, Port, run_in_background, LOCALHOST
+from hivemind.client import RemoteExpert
+from hivemind.dht.node import DHTNode, DHTID, DHTExpiration
+from hivemind.dht.routing import get_dht_time
+from hivemind.utils import SharedFuture, Endpoint, run_in_background
 
 
 
 
class DHT(mp.Process):

+ 27 - 27
hivemind/dht/dht.proto

@@ -4,52 +4,52 @@ syntax = "proto3";
// For more info, see https://learning-at-home.readthedocs.io/en/latest/modules/dht.html or help(hivemind.dht.DHTNode)

service DHT {
-    // find out recipient's DHTID and possibly update its routing table
-    rpc rpc_ping(NodeInfo) returns (NodeInfo);
+  // find out recipient's DHTID and possibly update its routing table
+  rpc rpc_ping(NodeInfo) returns (NodeInfo);
 
 
-    // request a node to store one or multiple data items (key - value - expiration)
-    rpc rpc_store(StoreRequest) returns (StoreResponse);
+  // request a node to store one or multiple data items (key - value - expiration)
+  rpc rpc_store(StoreRequest) returns (StoreResponse);
 
 
-    // for given keys, request values (if stored) or a list of peers that are likely to have them
-    rpc rpc_find(FindRequest) returns (FindResponse);
+  // for given keys, request values (if stored) or a list of peers that are likely to have them
+  rpc rpc_find(FindRequest) returns (FindResponse);
}

message NodeInfo {
-    // note: both node_id and port are optional: if specified, ask peer to add you to its routing table;
-    // if either node_id or port is absent, simply request recipient info (for client-only mode)
-    bytes node_id = 1;                // sender's own node id serialized with DHTID.to_bytes()
-    int32 rpc_port = 2;               // port to which sender listens for DHT RPCs
+  // note: both node_id and port are optional: if specified, ask peer to add you to its routing table;
+  // if either node_id or port is absent, simply request recipient info (for client-only mode)
+  bytes node_id = 1;                // sender's own node id serialized with DHTID.to_bytes()
+  int32 rpc_port = 2;               // port to which sender listens for DHT RPCs
}

message StoreRequest {
-    // three lists of the same length representing dht keys, dht values and expiration
-    repeated bytes keys = 1;          // keys in the form of DHTID.generate(raw_key).to_bytes()
-    repeated bytes values = 2;        // binary-encoded value for i-th key
-    repeated double expiration = 3;   // expirations for i-th key (type = DHTExpiration)
-    repeated bool in_cache = 4;       // if in_cache[i], store i-th key in cache, else store normally
-    NodeInfo peer = 5;                // (optional) sender's own node info, same behavior as in DHT.rpc_ping
+  // three lists of the same length representing dht keys, dht values and expiration
+  repeated bytes keys = 1;          // keys in the form of DHTID.generate(raw_key).to_bytes()
+  repeated bytes values = 2;        // binary-encoded value for i-th key
+  repeated double expiration = 3;   // expirations for i-th key (type = DHTExpiration)
+  repeated bool in_cache = 4;       // if in_cache[i], store i-th key in cache, else store normally
+  NodeInfo peer = 5;                // (optional) sender's own node info, same behavior as in DHT.rpc_ping
}

message StoreResponse {
-    repeated bool store_ok = 1;       // for every key, True means store accepted, False means store rejected/failed
-    NodeInfo peer = 2;                // respondent's node id, for you to update routing table
+  repeated bool store_ok = 1;       // for every key, True means store accepted, False means store rejected/failed
+  NodeInfo peer = 2;                // respondent's node id, for you to update routing table
}

message FindRequest {
-    repeated bytes keys = 1;          // a list of DHTID search keys encoded as bytes
-    NodeInfo peer = 2;                // optional, same behavior as in DHT.ping
+  repeated bytes keys = 1;          // a list of DHTID search keys encoded as bytes
+  NodeInfo peer = 2;                // optional, same behavior as in DHT.ping
}

message Peers {
-   // two aligned arrays: DHTIDs and Endpoints, i-th endpoint corresponds to peer with i-th node id
-   repeated bytes node_ids = 1;       // DHTID serialized with node_id.to_bytes()
-   repeated string endpoints = 2;     // e.g. 123.123.123.123:1337 or [2a21:6с8:b192:2105]:8888
+  // two aligned arrays: DHTIDs and Endpoints, i-th endpoint corresponds to peer with i-th node id
+  repeated bytes node_ids = 1;       // DHTID serialized with node_id.to_bytes()
+  repeated string endpoints = 2;     // e.g. 123.123.123.123:1337 or [2a21:6с8:b192:2105]:8888
}

message FindResponse {
-    repeated bytes values = 1;        // value for i-th key, b'' means not found locally
-    repeated double expiration = 2;   // expiration time for i-th value, only valid value is found
-    repeated Peers nearest = 3;       // peers ordered from nearest to farthest based on distance to i-th key
-    NodeInfo peer = 4;                // respondent's node id, for you to update routing table
+  repeated bytes values = 1;        // value for i-th key, b'' means not found locally
+  repeated double expiration = 2;   // expiration time for i-th value, only valid value is found
+  repeated Peers nearest = 3;       // peers ordered from nearest to farthest based on distance to i-th key
+  NodeInfo peer = 4;                // respondent's node id, for you to update routing table
}


+ 5 - 4
hivemind/dht/node.py

@@ -1,14 +1,15 @@
from __future__ import annotations
+
import asyncio
import random
from collections import namedtuple
from typing import Optional, Tuple, List, Dict, Collection, Union, Set
from warnings import warn

-from .protocol import DHTProtocol
-from .routing import DHTID, DHTExpiration, DHTKey, get_dht_time, DHTValue
-from .traverse import traverse_dht
-from ..utils import Endpoint, LOCALHOST, MSGPackSerializer
+from hivemind.dht.protocol import DHTProtocol
+from hivemind.dht.routing import DHTID, DHTExpiration, DHTKey, get_dht_time, DHTValue
+from hivemind.dht.traverse import traverse_dht
+from hivemind.utils import Endpoint, LOCALHOST, MSGPackSerializer
 
 
 
 
class DHTNode:

+ 2 - 2
hivemind/dht/protocol.py

@@ -11,8 +11,8 @@ from warnings import warn
import grpc
import grpc.experimental.aio

-from .routing import RoutingTable, DHTID, BinaryDHTValue, DHTExpiration, get_dht_time
-from ..utils import Endpoint, compile_grpc, get_logger
+from hivemind.dht.routing import RoutingTable, DHTID, BinaryDHTValue, DHTExpiration, get_dht_time
+from hivemind.utils import Endpoint, compile_grpc, get_logger
 
 
logger = get_logger(__name__)


+ 7 - 8
hivemind/dht/routing.py

@@ -2,16 +2,18 @@
from __future__ import annotations

import hashlib
+import heapq
import os
import random
-
import time
-import heapq
from collections.abc import Iterable
from itertools import chain
-from typing import Tuple, Optional, List, Dict, Set, Union, Any, Sequence, Iterator
+from typing import Tuple, Optional, List, Dict, Set, Union, Any, Sequence
 
 
-from ..utils import Endpoint, PickleSerializer
+from hivemind.utils import Endpoint, PickleSerializer
+
+DHTKey, DHTValue, DHTExpiration, BinaryDHTID, BinaryDHTValue, = Any, Any, float, bytes, bytes  # flavour types
+get_dht_time = time.time  # time used by all dht functionality. You can replace this with any infrastructure-wide time
 
 
 
 
class RoutingTable:
@@ -160,6 +162,7 @@ class KBucket:
    A bucket containing up to :size: of DHTIDs in [lower, upper) semi-interval.
    Maps DHT node ids to their endpoints (hostname, addr)
    """
+
    def __init__(self, lower: int, upper: int, size: int, depth: int = 0):
        assert upper - lower == 2 ** (DHTID.HASH_NBYTES * 8 - depth)
        self.lower, self.upper, self.size, self.depth = lower, upper, size, depth
@@ -288,7 +291,3 @@ class DHTID(int):
 
 
    def __bytes__(self):
        return self.to_bytes()
-
-
-DHTKey, DHTValue, DHTExpiration, BinaryDHTID, BinaryDHTValue, = Any, Any, float, bytes, bytes  # flavour types
-get_dht_time = time.time  # time used by all dht functionality. You can replace this with any infrastructure-wide time

+ 8 - 10
hivemind/dht/traverse.py

@@ -2,9 +2,9 @@
import asyncio
import heapq
from collections import Counter
-from warnings import warn
from typing import Dict, Awaitable, Callable, Any, Tuple, List, Set, Collection, Optional
-from .routing import DHTID
+
+from hivemind.dht.routing import DHTID
 
 
ROOT = 0  # alias for heap root

@@ -107,13 +107,13 @@ async def traverse_dht(
    if len(queries) == 0:
        return {}, dict(visited_nodes)

-    unfinished_queries = set(queries)                             # all queries that haven't triggered finish_search yet
-    candidate_nodes: Dict[DHTID, List[Tuple[int, DHTID]]] = {}    # heap: unvisited nodes, ordered nearest-to-farthest
-    nearest_nodes: Dict[DHTID, List[Tuple[int, DHTID]]] = {}      # heap: top-k nearest nodes, farthest-to-nearest
-    known_nodes: Dict[DHTID, Set[DHTID]] = {}                     # all nodes ever added to the heap (for deduplication)
+    unfinished_queries = set(queries)  # all queries that haven't triggered finish_search yet
+    candidate_nodes: Dict[DHTID, List[Tuple[int, DHTID]]] = {}  # heap: unvisited nodes, ordered nearest-to-farthest
+    nearest_nodes: Dict[DHTID, List[Tuple[int, DHTID]]] = {}  # heap: top-k nearest nodes, farthest-to-nearest
+    known_nodes: Dict[DHTID, Set[DHTID]] = {}  # all nodes ever added to the heap (for deduplication)
    visited_nodes: Dict[DHTID, Set[DHTID]] = dict(visited_nodes)  # where we requested get_neighbors for a given query
-    pending_tasks = set()                                         # all active tasks (get_neighbors and found_callback)
-    active_workers = Counter({q: 0 for q in queries})             # count workers that search for this query
+    pending_tasks = set()  # all active tasks (get_neighbors and found_callback)
+    active_workers = Counter({q: 0 for q in queries})  # count workers that search for this query
 
 
    search_finished_event = asyncio.Event()  # used to immediately stop all workers when the search is finished
    heap_updated_event = asyncio.Event()  # if a worker has no nodes to explore, it will await other workers
@@ -228,5 +228,3 @@ async def traverse_dht(
        for query in queries
    }
    return nearest_neighbors_per_query, visited_nodes
-
-

+ 2 - 2
hivemind/runtime/__init__.py

@@ -7,8 +7,8 @@ from typing import Dict
import torch
from prefetch_generator import BackgroundGenerator

-from .expert_backend import ExpertBackend
-from .task_pool import TaskPool, TaskPoolBase
+from hivemind.runtime.expert_backend import ExpertBackend
+from hivemind.runtime.task_pool import TaskPool, TaskPoolBase
from hivemind.utils import get_logger

logger = get_logger(__name__)

+ 2 - 2
hivemind/runtime/expert_backend.py

@@ -3,8 +3,8 @@ from typing import Dict, Sequence, Any, Tuple, Union
import torch
from torch import nn

-from .task_pool import TaskPool
-from ..utils import nested_flatten, nested_pack, nested_compare, BatchTensorDescriptor, DUMMY_BATCH_SIZE, nested_map
+from hivemind.runtime.task_pool import TaskPool
+from hivemind.utils import nested_flatten, nested_pack, nested_compare, BatchTensorDescriptor, DUMMY_BATCH_SIZE, nested_map
 
 
 
 
class ExpertBackend(nn.Module):

+ 11 - 38
hivemind/server/__init__.py

@@ -1,16 +1,12 @@
import multiprocessing as mp
-import os
import threading
-from socket import socket, AF_INET, SOCK_STREAM, SO_REUSEADDR, SOL_SOCKET, timeout
from typing import Dict, Optional

-import torch
-
-from .connection_handler import handle_connection
-from .dht_handler import DHTHandlerThread
-from .checkpoint_saver import CheckpointSaver
-from ..dht import DHT
-from ..runtime import Runtime, ExpertBackend
+from hivemind.dht import DHT
+from hivemind.runtime import Runtime, ExpertBackend
+from hivemind.server.checkpoint_saver import CheckpointSaver
+from hivemind.server.connection_handler import ConnectionHandler
+from hivemind.server.dht_handler import DHTHandlerThread
 
 
 
 
class Server(threading.Thread):
@@ -37,12 +33,12 @@ class Server(threading.Thread):
    """

    def __init__(self, dht: Optional[DHT], expert_backends: Dict[str, ExpertBackend], addr='127.0.0.1',
-                 port: int = 8080, conn_handler_processes: int = 1, update_period: int = 30, start=False,
-                 checkpoint_dir=None, **kwargs):
+                 port: int = 8080, conn_handler_processes: int = 1, update_period: int = 30, start=False, checkpoint_dir=None, **kwargs):
        super().__init__()
        self.dht, self.experts, self.update_period = dht, expert_backends, update_period
        self.addr, self.port = addr, port
-        self.conn_handlers = self._create_connection_handlers(conn_handler_processes)
+        self.conn_handlers = [ConnectionHandler(f"{self.addr}:{port}", self.experts)
+                              for _ in range(conn_handler_processes)]
        if checkpoint_dir is not None:
            self.checkpoint_saver = CheckpointSaver(expert_backends, checkpoint_dir, update_period)
        else:
@@ -71,6 +67,9 @@ class Server(threading.Thread):
            if not process.is_alive():
                process.start()

+        for process in self.conn_handlers:
+            process.ready.wait()
+
        self.runtime.run()

        for process in self.conn_handlers:
@@ -104,18 +103,6 @@ class Server(threading.Thread):
        """
        return self.runtime.ready  # mp.Event that is true if self is ready to process batches

-    def _create_connection_handlers(self, num_handlers):
-        sock = socket(AF_INET, SOCK_STREAM)
-        sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
-        sock.bind(('', self.port))
-        sock.listen(1024)
-        sock.settimeout(self.update_period)
-
-        processes = [mp.context.ForkProcess(
-            target=socket_loop, name=f"socket_loop-{i}", args=(sock, self.experts), daemon=True)
-            for i in range(num_handlers)]
-        return processes
-
    def shutdown(self):
        """
        Gracefully terminate a hivemind server, process-safe.
@@ -130,17 +117,3 @@ class Server(threading.Thread):
            self.dht.shutdown()

        self.runtime.shutdown()
-
-
-def socket_loop(sock, experts):
-    """ catch connections, send tasks to processing, respond with results """
-    torch.set_num_threads(1)
-    print(f'Spawned connection handler pid={os.getpid()}')
-    while True:
-        try:
-            handle_connection(sock.accept(), experts)
-        except KeyboardInterrupt as e:
-            print(f'Socket loop has caught {type(e)}, exiting')
-            break
-        except (timeout, BrokenPipeError, ConnectionResetError, NotImplementedError):
-            continue

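With the socket loop gone, a Server now owns a list of ConnectionHandler processes that share one port via grpc.so_reuseport, and run() waits for every handler's ready event before starting the runtime. A hedged usage sketch, assuming expert_backends is a dict of name -> ExpertBackend built elsewhere (for example as in tests/test_utils/run_server.py):

import hivemind

def start_server(expert_backends, dht=None, port=8080, num_handlers=4):
    """Start a hivemind Server with several gRPC connection handler processes on one port."""
    server = hivemind.Server(dht, expert_backends, addr='127.0.0.1', port=port,
                             conn_handler_processes=num_handlers)
    server.start()        # Server subclasses threading.Thread
    server.ready.wait()   # mp.Event: set once the runtime can process batches
    return server
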
+ 1 - 1
hivemind/server/checkpoint_saver.py

@@ -8,7 +8,7 @@ from typing import Dict
 
 
import torch

-from ..runtime import ExpertBackend
+from hivemind.runtime import ExpertBackend
 
 
 
 
class CheckpointSaver(threading.Thread):

+ 35 - 0
hivemind/server/connection_handler.proto

@@ -0,0 +1,35 @@
+syntax = "proto3";
+
+
+service ConnectionHandler {
+  // Listens to incoming requests for expert computation
+  rpc info(ExpertUID) returns (ExpertInfo);
+  rpc forward(ExpertRequest) returns (ExpertResponse);
+  rpc backward(ExpertRequest) returns (ExpertResponse);
+}
+
+
+message ExpertUID {
+  string uid = 1;
+}
+
+message ExpertInfo {
+  bytes serialized_info = 1;
+}
+
+message ExpertRequest {
+  string uid = 1;
+  repeated Tensor tensors = 2;
+}
+
+message ExpertResponse {
+  repeated Tensor tensors = 2;
+}
+
+message Tensor {
+  bytes buffer = 1;
+  repeated uint32 size = 2;
+  bool requires_grad = 3;
+  string dtype = 4;
+}
+

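The generated classes for this proto are compiled at import time (see hivemind/utils/grpc.py further down) and exposed as runtime_pb2 / runtime_grpc. A small sketch of building the request message a client sends for a forward pass; the expert name and tensor shape are placeholders, not part of this diff:

import torch
from hivemind.utils.grpc import runtime_pb2, serialize_torch_tensor

request = runtime_pb2.ExpertRequest(
    uid='expert.0',                                          # hypothetical expert uid
    tensors=[serialize_torch_tensor(torch.randn(1, 1024))],  # each tensor becomes one runtime_pb2.Tensor
)
print(len(request.tensors[0].buffer), request.tensors[0].dtype)  # raw byte length and dtype name, e.g. 'float32'
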
+ 69 - 27
hivemind/server/connection_handler.py

@@ -1,29 +1,71 @@
-from socket import socket
-from typing import Tuple, Dict
+import asyncio
+import multiprocessing as mp
+import os
+import pickle
+from typing import Dict
+
+import grpc.experimental.aio
+import torch
+import uvloop
 
 
from hivemind.runtime.expert_backend import ExpertBackend
-from hivemind.utils import PytorchSerializer, Connection
-
-
-def handle_connection(connection_tuple: Tuple[socket, str], experts: Dict[str, ExpertBackend]):
-    with Connection(*connection_tuple) as connection:
-        try:
-            header = connection.recv_header()
-            payload = PytorchSerializer.loads(connection.recv_raw())
-
-            if header == 'fwd_':
-                uid, inputs = payload
-                response = experts[uid].forward_pool.submit_task(*inputs).result()
-            elif header == 'bwd_':
-                uid, inputs_and_grad_outputs = payload
-                response = experts[uid].backward_pool.submit_task(*inputs_and_grad_outputs).result()
-            elif header == 'info':
-                uid = payload
-                response = experts[uid].get_info()
-            else:
-                raise NotImplementedError(f"Unknown header: {header}")
-
-            connection.send_raw('rest', PytorchSerializer.dumps(response))
-        except RuntimeError:
-            # socket connection broken
-            pass
+from hivemind.utils import get_logger, serialize_torch_tensor, deserialize_torch_tensor, Endpoint, runtime_pb2, runtime_grpc
+
+logger = get_logger(__name__)
+
+
+class ConnectionHandler(mp.Process):
+    """
+    A process that accepts incoming requests to experts and submits them into the corresponding TaskPool.
+
+    :note: ConnectionHandler is designed so as to allow using multiple handler processes for the same port.
+    :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
+    :param experts: a dict [UID -> ExpertBackend] with all active experts
+    """
+
+    def __init__(self, listen_on: Endpoint, experts: Dict[str, ExpertBackend]):
+        super().__init__()
+        self.listen_on, self.experts = listen_on, experts
+        self.ready = mp.Event()
+
+    def run(self):
+        torch.set_num_threads(1)
+        uvloop.install()
+        loop = asyncio.new_event_loop()
+
+        async def _run():
+            grpc.experimental.aio.init_grpc_aio()
+            logger.debug(f'Starting, pid {os.getpid()}')
+            server = grpc.experimental.aio.server(options=[
+                ('grpc.so_reuseport', 1),
+                ('grpc.max_send_message_length', -1),
+                ('grpc.max_receive_message_length', -1)
+            ])
+            runtime_grpc.add_ConnectionHandlerServicer_to_server(self, server)
+
+            found_port = server.add_insecure_port(self.listen_on)
+            assert found_port != 0, f"Failed to listen to {self.listen_on}"
+
+            await server.start()
+            self.ready.set()
+            await server.wait_for_termination()
+            logger.debug(f"ConnectionHandler terminated: (pid={os.getpid()})")
+
+        loop.run_until_complete(_run())
+
+    async def info(self, request: runtime_pb2.ExpertUID, context: grpc.ServicerContext):
+        return runtime_pb2.ExpertInfo(serialized_info=pickle.dumps(self.experts[request.uid].get_info()))
+
+    async def forward(self, request: runtime_pb2.ExpertRequest, context: grpc.ServicerContext):
+        inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+        future = self.experts[request.uid].forward_pool.submit_task(*inputs)
+        response = await future.async_result()
+        serialized_response = [serialize_torch_tensor(tensor) for tensor in response]
+        return runtime_pb2.ExpertResponse(tensors=serialized_response)
+
+    async def backward(self, request: runtime_pb2.ExpertRequest, context: grpc.ServicerContext):
+        inputs_and_grad_outputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+        future = self.experts[request.uid].backward_pool.submit_task(*inputs_and_grad_outputs)
+        response = await future.async_result()
+        serialized_response = [serialize_torch_tensor(tensor) for tensor in response]
+        return runtime_pb2.ExpertResponse(tensors=serialized_response)

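For manual testing, the handler can also be exercised through the raw generated stub instead of RemoteExpert. A hedged sketch, assuming a ConnectionHandler listens on 127.0.0.1:8080 and serves an expert named 'expert.0' with a single 1024-dim input (both are assumptions, not part of this diff):

import pickle

import grpc
import torch

from hivemind.utils.grpc import runtime_grpc, runtime_pb2, serialize_torch_tensor, deserialize_torch_tensor

channel = grpc.insecure_channel('127.0.0.1:8080')
stub = runtime_grpc.ConnectionHandlerStub(channel)

# ask the handler for the expert's schema; the reply carries a pickled info dict
info = pickle.loads(stub.info(runtime_pb2.ExpertUID(uid='expert.0')).serialized_info)

# run one forward pass through the expert's TaskPool
request = runtime_pb2.ExpertRequest(uid='expert.0', tensors=[serialize_torch_tensor(torch.randn(1, 1024))])
outputs = [deserialize_torch_tensor(t) for t in stub.forward(request).tensors]
channel.close()
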
+ 2 - 3
hivemind/server/dht_handler.py

@@ -1,12 +1,11 @@
import threading
import time

-from ..dht import DHT
+from hivemind.dht import DHT
 
 
 
 
class DHTHandlerThread(threading.Thread):
-    def __init__(self, experts, dht: DHT,
-                 update_period: int = 5, addr: str = '127.0.0.1', port: int = 8080):
+    def __init__(self, experts, dht: DHT, update_period: int = 5, addr: str = '127.0.0.1', port: int = 8080):
        super(DHTHandlerThread, self).__init__()
        self.port = port
        self.addr = addr

+ 10 - 10
hivemind/utils/__init__.py

@@ -1,10 +1,10 @@
-from .connection import *
-from .data import *
-from .nested import *
-from .tensor_descr import *
-from .serializer import *
-from .shared_future import *
-from .threading import *
-from .autograd import *
-from .grpc import *
-from .logging import get_logger
+from hivemind.utils.connection import *
+from hivemind.utils.data import *
+from hivemind.utils.nested import *
+from hivemind.utils.tensor_descr import *
+from hivemind.utils.serializer import *
+from hivemind.utils.shared_future import *
+from hivemind.utils.threading import *
+from hivemind.utils.autograd import *
+from hivemind.utils.grpc import *
+from hivemind.utils.logging import get_logger

+ 3 - 1
hivemind/utils/autograd.py

@@ -10,7 +10,7 @@ import numpy as np
 import torch
import torch
import torch.autograd.function

+from hivemind.utils.threading import run_in_background
 
 
 
 
 class EmulatedAutogradContext(torch.autograd.function._ContextMethodMixin):
class EmulatedAutogradContext(torch.autograd.function._ContextMethodMixin):
     such as running several parallel backwards or transferring backward to a separate device.
    such as running several parallel backwards or transferring backward to a separate device.
    This class is not tested outside its use cases in RemoteMixtureOfExperts and we do not recommend using it elsewhere.
    """
     @property
    @property
    def saved_tensors(self):
        return tuple(self.to_save)
     Please do not call this function directly. Use apply_with_parallel_backward instead.
    Please do not call this function directly. Use apply_with_parallel_backward instead.
    Unlike default pytorch behavior, the backward pass for each function will also happen in parallel.
    """
     @staticmethod
    @staticmethod
    def forward(ctx, func: torch.autograd.Function, num_calls: int, num_args_per_call: int,
                output_strides_ph: Future, *args_flat) -> Tuple[torch.Tensor, ...]:
+ 1 - 53
hivemind/utils/connection.py

@@ -1,63 +1,11 @@
import socket
-from contextlib import AbstractContextManager, closing
-from typing import Tuple
+from contextlib import closing
 
 
Hostname, Port = str, int  # flavour types
Endpoint = str  # e.g. 1.2.3.4:1337 or [2a21:6с8:b192:2105]:8888, https://networkengineering.stackexchange.com/a/9435
LOCALHOST = '127.0.0.1'

 
 
-class Connection(AbstractContextManager):
-    header_size = 4  # number of characters in all headers
-    payload_length_size = 8  # number of bytes used to encode payload length
-
-    __slots__ = ('conn', 'addr')
-
-    def __init__(self, conn: socket, addr: Tuple[Hostname, Port]):
-        self.conn, self.addr = conn, addr
-
-    @staticmethod
-    def create(host: str, port: int):
-        sock = socket.socket()
-        addr = (host, port)
-        sock.connect(addr)
-        return Connection(sock, addr)
-
-    def send_raw(self, header: str, content: bytes):
-        self.conn.send(header.encode())
-        self.conn.send(len(content).to_bytes(self.payload_length_size, byteorder='big'))
-
-        total_sent = 0
-        while total_sent < len(content):
-            sent = self.conn.send(content[total_sent:])
-            if sent == 0:
-                raise RuntimeError("socket connection broken")
-            total_sent = total_sent + sent
-
-    def recv_header(self) -> str:
-        return self.conn.recv(self.header_size).decode()
-
-    def recv_raw(self, max_package: int = 2048) -> bytes:
-        length = int.from_bytes(self.conn.recv(self.payload_length_size), byteorder='big')
-        chunks = []
-        bytes_recd = 0
-        while bytes_recd < length:
-            chunk = self.conn.recv(min(length - bytes_recd, max_package))
-            if chunk == b'':
-                raise RuntimeError("socket connection broken")
-            chunks.append(chunk)
-            bytes_recd = bytes_recd + len(chunk)
-        ret = b''.join(chunks)
-        assert len(ret) == length
-        return ret
-
-    def recv_message(self) -> Tuple[str, bytes]:
-        return self.recv_header(), self.recv_raw()
-
-    def __exit__(self, *exc_info):
-        self.conn.close()
-
-
def find_open_port(params=(socket.AF_INET, socket.SOCK_STREAM), opt=(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)):
    """ Finds a tcp port that can be occupied with a socket with *params and use *opt options """
    try:

+ 0 - 10
hivemind/utils/data.py

@@ -1,13 +1,3 @@
-import numpy as np
import torch

-
-def check_numpy(x):
-    """ Makes sure x is a numpy array """
-    if isinstance(x, torch.Tensor):
-        return x.detach().cpu().numpy()
-    else:
-        return np.asarray(x)
-
-
DUMMY = torch.empty(0, requires_grad=True)

+ 24 - 1
hivemind/utils/grpc.py

@@ -5,9 +5,12 @@ import functools
import os
import sys
import tempfile
-from typing import Tuple
from argparse import Namespace
+from typing import Tuple
+
import grpc_tools.protoc
+import numpy as np
+import torch
 
 
 
 
@functools.lru_cache(maxsize=None)
@@ -42,3 +45,23 @@ def compile_grpc(proto: str, *args: str) -> Tuple[Namespace, Namespace]:
        finally:
            if sys.path.pop() != build_dir:
                raise ImportError("Something changed sys.path while compile_grpc was in progress.")
+
+
+with open(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'server', 'connection_handler.proto')) as f_proto:
+    runtime_pb2, runtime_grpc = compile_grpc(f_proto.read())
+
+
+def serialize_torch_tensor(tensor: torch.Tensor) -> runtime_pb2.Tensor:
+    array = tensor.numpy()
+    proto = runtime_pb2.Tensor(
+        buffer=array.tobytes(),
+        size=array.shape,
+        dtype=array.dtype.name,
+        requires_grad=tensor.requires_grad)
+    return proto
+
+
+def deserialize_torch_tensor(tensor: runtime_pb2.Tensor) -> torch.Tensor:
+    # TODO avoid copying the array (need to silence pytorch warning, because array is not writable)
+    array = np.frombuffer(tensor.buffer, dtype=np.dtype(tensor.dtype)).copy()
+    return torch.as_tensor(array).view(tuple(tensor.size)).requires_grad_(tensor.requires_grad)

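A quick round-trip check of the two helpers defined above (CPU tensors only, since serialize_torch_tensor calls .numpy() on the input):

import torch
from hivemind.utils.grpc import serialize_torch_tensor, deserialize_torch_tensor

original = torch.randn(3, 4)
proto = serialize_torch_tensor(original)
restored = deserialize_torch_tensor(proto)

assert restored.shape == original.shape and torch.allclose(restored, original)
print(proto.dtype, tuple(proto.size))  # 'float32' (3, 4)
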
+ 1 - 1
hivemind/utils/serializer.py

@@ -57,7 +57,7 @@ class MSGPackSerializer(SerializerBase):
 
 
    @staticmethod
    def dumps(obj: object) -> bytes:
-        return umsgpack.dumps(obj, use_bin_type=False) # TODO strict https://github.com/msgpack/msgpack-python/pull/158
+        return umsgpack.dumps(obj, use_bin_type=False)  # TODO strict https://github.com/msgpack/msgpack-python/pull/158
 
 
    @staticmethod
    def loads(buf: bytes) -> object:

+ 48 - 6
hivemind/utils/shared_future.py

@@ -2,6 +2,7 @@ import multiprocessing as mp
import multiprocessing.connection
from concurrent.futures import Future, CancelledError
from warnings import warn
+import asyncio
 
 
 
 
class SharedFuture(Future):
@@ -22,14 +23,21 @@ class SharedFuture(Future):
        connection1, connection2 = mp.Pipe()
        return cls(connection1), cls(connection2)

+    def poll_and_recv(self, timeout):
+        available = self.connection.poll(timeout)
+        if not available:
+            raise TimeoutError
+        try:
+            status, payload = self.connection.recv()
+            self.connection.close()
+        except BrokenPipeError as e:
+            status, payload = self.STATE_EXCEPTION, e
+        return status, payload
+
    def _recv(self, timeout):
+
        if self.state in (self.STATE_PENDING, self.STATE_RUNNING):
-            if not self.connection.poll(timeout):
-                raise TimeoutError()
-            try:
-                status, payload = self.connection.recv()
-            except BrokenPipeError as e:
-                status, payload = self.STATE_EXCEPTION, e
+            status, payload = self.poll_and_recv(timeout)
 
 
            assert status in self.STATES
            self.state = status
@@ -47,6 +55,7 @@ class SharedFuture(Future):
        try:
            self.state, self._result = self.STATE_FINISHED, result
            self.connection.send((self.STATE_FINISHED, result))
+            self.connection.close()
            return True
        except BrokenPipeError:
            return False
@@ -55,6 +64,7 @@ class SharedFuture(Future):
        try:
            self.state, self._exception = self.STATE_EXCEPTION, exception
            self.connection.send((self.STATE_EXCEPTION, exception))
+            self.connection.close()
            return True
        except BrokenPipeError:
            return False
@@ -103,3 +113,35 @@ class SharedFuture(Future):
            return "<MPFuture at 0x{:x} state=finished raised {}>".format(id(self), type(self._exception))
        else:
            return "<MPFuture at 0x{:x} state={}>".format(id(self), self.state)
+
+    async def _async_recv(self, timeout):
+        loop = asyncio.get_running_loop()
+
+        if self.state in (self.STATE_PENDING, self.STATE_RUNNING):
+            status, payload = await loop.run_in_executor(None, self.poll_and_recv, timeout)
+
+            assert status in self.STATES
+            self.state = status
+
+            if status == self.STATE_FINISHED:
+                self._result = payload
+            elif status == self.STATE_EXCEPTION:
+                self._exception = payload
+            elif status in (self.STATE_RUNNING, self.STATE_CANCELLED):
+                pass  # only update self.state
+            else:
+                raise ValueError("Result status should not be self.STATE_PENDING")
+
+    async def async_result(self, timeout=None):
+        await self._async_recv(timeout)
+        if self.state == self.STATE_FINISHED:
+            return self._result
+        elif self.state == self.STATE_EXCEPTION:
+            raise self._exception
+        else:
+            assert self.state == self.STATE_CANCELLED
+            raise CancelledError()
+
+    async def async_exception(self, timeout=None):
+        await self._async_recv(timeout)
+        return self._exception

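async_result() is what makes "awaitable interactions with TaskPool" possible: the gRPC handler awaits a future whose result arrives over a pipe from another process, and the blocking poll runs in a thread-pool executor so the event loop stays free. A minimal sketch, assuming a SharedFuture can be handed to a child process the same way hivemind's TaskPool does:

import asyncio
import multiprocessing as mp

from hivemind.utils import SharedFuture

def worker(future):
    future.set_result(42)  # runs in another process; the result travels through the pipe

async def main():
    local_future, remote_future = SharedFuture.make_pair()
    process = mp.Process(target=worker, args=(remote_future,))
    process.start()
    result = await local_future.async_result()  # polls the pipe in an executor instead of blocking the loop
    print(result)                               # 42
    process.join()

asyncio.run(main())
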
+ 2 - 1
hivemind/utils/threading.py

@@ -14,9 +14,11 @@ def run_in_background(func: callable, *args, **kwargs) -> Future:
 
 
def run_forever(func: callable, *args, **kwargs):
    """ A function that runs a :func: in background forever. Returns a future that catches exceptions """
+
    def repeat():
        while True:
            func(*args, **kwargs)
+
    return run_in_background(repeat)


@@ -65,4 +67,3 @@ def run_and_await_k(jobs: List[callable], k: int,
            future.cancel()
            outputs[index] = future.result() if not future.exception() else future.exception()
    return outputs
-

+ 0 - 1
requirements.txt

@@ -2,7 +2,6 @@ torch>=1.3.0
joblib>=0.13
numpy>=1.17
prefetch_generator>=1.0.1
-pytest
umsgpack
uvloop>=0.14.0
grpcio

+ 3 - 2
tests/benchmark_throughput.py

@@ -6,9 +6,9 @@ import time
 
 
import torch
from test_utils import layers, print_device_info, increase_file_limit
-from hivemind import find_open_port
 
 
import hivemind
+from hivemind import find_open_port
 
 
 
 
def client_process(can_start, benchmarking_failed, port, num_experts, batch_size, hid_dim, num_batches, backprop=True):
@@ -142,7 +142,8 @@ if __name__ == "__main__":
        benchmark_throughput(backprop=False, num_clients=512, batch_size=512,
                             max_batch_size=8192, num_batches_per_client=args.num_batches_per_client)
    elif args.preset == 'minimalistic':
-        benchmark_throughput(num_experts=1, num_clients=1, num_handlers=1)
+        benchmark_throughput(num_experts=1, num_clients=1, num_handlers=1,
+                             num_batches_per_client=args.num_batches_per_client)
    elif args.preset == 'nop':
        benchmark_throughput(expert_cls='nop', backprop=False, num_batches_per_client=args.num_batches_per_client)
    else:

+ 8 - 7
tests/test_utils/run_server.py

@@ -1,11 +1,12 @@
-import resource
-from contextlib import contextmanager
-import multiprocessing as mp
import argparse
+import multiprocessing as mp
+from contextlib import contextmanager
 
 
+import resource
import torch
+
import hivemind
-from .layers import name_to_block, name_to_input
+from tests.test_utils.layers import name_to_block, name_to_input
 
 
 
 
def make_dummy_server(interface='0.0.0.0', port=None, num_experts=1, expert_cls='ffn', hidden_dim=1024,
@@ -147,12 +148,12 @@ if __name__ == '__main__':
    parser.add_argument('--no_optimizer', action='store_true', help='if specified, all optimizers use learning rate=0')
    parser.add_argument('--no_dht', action='store_true', help='if specified, the server will not be attached to a dht')
    parser.add_argument('--initial_peers', type=str, default="[]", required=False, help='a list of peers that will'
-                        ' introduce this node to the dht, e.g. [("1.2.3.4", 1337), ("127.0.0.1", 4321)]')
+                                                                                        ' introduce this node to the dht, e.g. [("1.2.3.4", 1337), ("127.0.0.1", 4321)]')
    parser.add_argument('--dht_port', type=int, default=None, required=False, help='DHT node will listen on this port')
    parser.add_argument('--root_port', type=int, default=None, required=False, help='If this server does not have peers'
-                        ', it will create a virtual dht node on this port. You can then use this node as initial peer.')
+                                                                                    ', it will create a virtual dht node on this port. You can then use this node as initial peer.')
    parser.add_argument('--increase_file_limit', action='store_true', help='On *nix, this will increase the max number'
-                        ' of processes a server can spawn before hitting "Too many open files"; Use at your own risk.')
+                                                                           ' of processes a server can spawn before hitting "Too many open files"; Use at your own risk.')
 
 
    args = vars(parser.parse_args())