View source code

GRPC connection handlers (#61)

* Add connection handlers with grpc

* Implement grpc for client/server

* Cache channel, get rid of warnings

* Awaitable interactions with TaskPool

* [personal] parallel gRPC handlers (#69)

* spawn multiple connection handlers

* remove reserve_port

* minor: make preset_minimalistic actually care about num_batches_per_client

* RemoteExpert can no longer be pickled

* fix broken moe.py after changes in RemoteModuleCall

* write-through

* fix moe.py (broken by changes in _RemoteModuleCall)

* fix a bug where connection handlers set ready flag prematurely

* fix wrong gradient type if not all experts survived

* connection_handler now sets ready only when actually ready

* create stub in a lazy manner

* rollback changes to DHT

* Update TODO, remove message limits

* Connection is gone 🦀🦀🦀

* Cleanup

* Switch to absolute imports (#70)

Co-authored-by: xtinkt <ant.sinitsin@gmail.com>
Co-authored-by: justheuristic <justheuristic@gmail.com>
Max Ryabinin 5 years ago
parent
commit
f1565ef7af

+ 1 - 1
.circleci/config.yml

@@ -27,7 +27,7 @@ jobs:
           command: sudo python setup.py develop
           name: setup
       - run:
-          command: pytest ./tests --full-trace
+          command: for test_file in tests/test*.py; do pytest $test_file --full-trace; done
           name: tests
       - run:
           command: python tests/benchmark_throughput.py --preset minimalistic

+ 5 - 5
hivemind/__init__.py

@@ -1,7 +1,7 @@
-from .client import *
-from .dht import *
-from .server import Server
-from .utils import *
-from .runtime import *
+from hivemind.client import *
+from hivemind.dht import *
+from hivemind.server import Server
+from hivemind.utils import *
+from hivemind.runtime import *
 
 __version__ = '0.7.1'

+ 2 - 2
hivemind/client/__init__.py

@@ -1,2 +1,2 @@
-from .moe import RemoteMixtureOfExperts
-from .expert import RemoteExpert
+from hivemind.client.expert import RemoteExpert
+from hivemind.client.moe import RemoteMixtureOfExperts

+ 41 - 19
hivemind/client/expert.py

@@ -1,10 +1,14 @@
+import pickle
 from typing import Tuple, Optional
 
+import grpc
+import grpc.experimental.aio
 import torch
 import torch.nn as nn
 from torch.autograd.function import once_differentiable
 
-from ..utils import nested_flatten, DUMMY, PytorchSerializer, nested_pack, nested_compare, Connection
+from hivemind.utils import nested_flatten, DUMMY, nested_pack, nested_compare
+from hivemind.utils.grpc import serialize_torch_tensor, deserialize_torch_tensor, runtime_pb2, runtime_grpc
 
 
 class RemoteExpert(nn.Module):
@@ -23,12 +27,29 @@ class RemoteExpert(nn.Module):
     def __init__(self, uid, host='127.0.0.1', port=8080):
         super().__init__()
         self.uid, self.host, self.port = uid, host, port
+        self._channel, self._stub = None, None
         self._info = None
 
+    @property
+    def stub(self):
+        if self._channel is None:
+            self._channel = grpc.insecure_channel(f'{self.host}:{self.port}', options=[
+                ('grpc.max_send_message_length', -1),
+                ('grpc.max_receive_message_length', -1)
+            ])
+        if self._stub is None:
+            self._stub = runtime_grpc.ConnectionHandlerStub(self._channel)
+        return self._stub
+
+    def __del__(self):
+        if self._channel is not None:
+            self._channel.close()
+
     def forward(self, *args, **kwargs):
         """ Call RemoteExpert for the specified inputs and return its output(s). Compatible with pytorch.autograd. """
         assert len(kwargs) == len(self.info['keyword_names']), f"Keyword args should be {self.info['keyword_names']}"
         kwargs = {key: kwargs[key] for key in self.info['keyword_names']}
+
         # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
 
         forward_inputs = (args, kwargs)
@@ -36,16 +57,16 @@ class RemoteExpert(nn.Module):
         if not nested_compare(forward_inputs, self.info['forward_schema']):
             raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
 
-        flat_outputs = _RemoteModuleCall.apply(DUMMY, self.uid, self.host, self.port, *nested_flatten(forward_inputs))
+        flat_outputs = _RemoteModuleCall.apply(DUMMY, self.uid, self.host, self.port, self.stub,
+                                               *nested_flatten(forward_inputs))
         # Note: we send DUMMY to prevent torch from excluding expert from backward if no other inputs require grad
         return nested_pack(flat_outputs, structure=self.info['outputs_schema'])
 
     @property
     def info(self):
         if self._info is None:
-            connection = Connection.create(self.host, self.port)
-            connection.send_raw('info', PytorchSerializer.dumps(self.uid))
-            self._info = PytorchSerializer.loads(connection.recv_message()[1])
+            outputs = self.stub.info(runtime_pb2.ExpertUID(uid=self.uid))
+            self._info = pickle.loads(outputs.serialized_info)
         return self._info
 
     def extra_repr(self):
@@ -56,26 +77,27 @@ class _RemoteModuleCall(torch.autograd.Function):
     """ Internal autograd-friendly call of a remote module. For applications, use RemoteExpert instead. """
 
     @staticmethod
-    def forward(ctx, dummy: torch.Tensor,
-                uid: str, host: str, port: int, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
+    def forward(ctx, dummy: torch.Tensor, uid: str, host: str, port: int, stub: runtime_grpc.ConnectionHandlerStub,
+                *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
         inputs = tuple(map(torch.Tensor.detach, inputs))  # detach to avoid pickling the computation graph
-        ctx.uid, ctx.host, ctx.port = uid, host, port
+        ctx.uid, ctx.host, ctx.port, ctx.stub = uid, host, port, stub
         ctx.save_for_backward(*inputs)
 
-        connection = Connection.create(ctx.host, ctx.port)
-        connection.send_raw('fwd_', PytorchSerializer.dumps((ctx.uid, inputs)))
-        rtype, msg = connection.recv_message()
-        assert len(msg) != 0, "ExpertBackend.forward did not respond"
-        return tuple(PytorchSerializer.loads(msg))  # flattened expert outputs
+        outputs = stub.forward(
+            runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=[serialize_torch_tensor(tensor) for tensor in inputs]))
+
+        deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
+
+        return tuple(deserialized_outputs)
 
     @staticmethod
     @once_differentiable
     def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
-        connection = Connection.create(ctx.host, ctx.port)
         payload = tuple(nested_flatten((ctx.saved_tensors, grad_outputs)))
-        connection.send_raw('bwd_', PytorchSerializer.dumps((ctx.uid, payload)))
-        rtype, msg = connection.recv_message()
-        assert len(msg) != 0, "ExpertBackend.backward did not respond"
-        grad_inputs = PytorchSerializer.loads(msg)
-        return (DUMMY, None, None, None, *grad_inputs)
+
+        grad_inputs = ctx.stub.backward(
+            runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=[serialize_torch_tensor(tensor) for tensor in payload]))
+
+        deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
+        return (DUMMY, None, None, None, None, *deserialized_grad_inputs)
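
A minimal usage sketch (not part of the commit) of the new lazily-created gRPC stub in expert.py above; it assumes a hivemind server with an expert registered under uid "expert.0" (taking a single [batch, 1024] input) is already listening on 127.0.0.1:8080.

import torch
import hivemind

# Assumptions: a server on 127.0.0.1:8080 serves an expert "expert.0" with hidden_dim=1024.
# The gRPC channel and ConnectionHandlerStub are created lazily on first access to .stub;
# both forward() and backward() go through that stub instead of the removed Connection class.
expert = hivemind.RemoteExpert(uid="expert.0", host="127.0.0.1", port=8080)
x = torch.randn(4, 1024)
y = expert(x)            # _RemoteModuleCall.forward -> stub.forward(ExpertRequest)
y.sum().backward()       # _RemoteModuleCall.backward -> stub.backward(ExpertRequest)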

+ 12 - 11
hivemind/client/moe.py

@@ -1,16 +1,14 @@
-import multiprocessing as mp
-import multiprocessing.pool
 from functools import partial
-from typing import Tuple, List, Dict, Optional
+from typing import Tuple, List, Optional
 
 import numpy as np
 import torch
 import torch.nn as nn
 from torch.autograd.function import once_differentiable
 
-from .expert import RemoteExpert, _RemoteModuleCall
-from ..utils import nested_map, check_numpy, run_and_await_k, nested_pack, nested_flatten, DUMMY, run_in_background
-from ..utils import run_isolated_forward, EmulatedAutogradContext, run_isolated_backward, map_with_parallel_backward
+from hivemind.client.expert import RemoteExpert, _RemoteModuleCall
+from hivemind.utils import nested_map, run_and_await_k, nested_pack, nested_flatten, DUMMY, run_in_background, \
+    run_isolated_forward, EmulatedAutogradContext, run_isolated_backward, map_with_parallel_backward
 
 
 class RemoteMixtureOfExperts(nn.Module):
@@ -37,6 +35,7 @@ class RemoteMixtureOfExperts(nn.Module):
      allow_broadcasting=True will flatten first d-1 input dimensions, apply RemoteMixtureOfExperts and un-flatten again
      allow_broadcasting=False will raise an error
     """
+
     def __init__(self, *, in_features, grid_size: Tuple[int], dht, k_best, k_min=1,
                  forward_timeout=None, timeout_after_k_min=1.0, backward_k_min=1, backward_timeout=None,
                  uid_prefix='', expert_padding=None, allow_broadcasting=True):
@@ -107,7 +106,7 @@ class RemoteMixtureOfExperts(nn.Module):
         delimeters = np.array(self.dht.UID_DELIMETER)[None, None, None]  # pre-compute numpy array for fast concat
 
         for dim_index, dim_scores in enumerate(grid_scores):
-            dim_scores = check_numpy(dim_scores)
+            dim_scores = dim_scores.detach().cpu().numpy()
             assert dim_scores.shape[-1] == self.grid_size[dim_index]
 
             # create all possible successsors from current beam
@@ -194,6 +193,7 @@ class _RemoteMoECall(torch.autograd.Function):
     This function that can recover from individual failures during forward and/or backward passes.
     For user-friendly version of this function, use RemoteMixtureOfExperts module.
     """
+
     @classmethod
     def forward(cls, ctx, expert_logits: torch.Tensor, experts: List[RemoteExpert],
                 k_min: int, timeout_after_k_min: float, backward_k_min: int, timeout_total: Optional[float],
@@ -250,18 +250,19 @@ class _RemoteMoECall(torch.autograd.Function):
             for grad_out, stacked_avive_out in zip(grad_outputs_flat, stacked_alive_outputs)
         ))
         softmax_jacobian = torch.diagflat(survived_probas) - torch.ger(survived_probas, survived_probas)
-        grad_wrt_logits = grad_wrt_probs @ softmax_jacobian
+        grad_wrt_survived_logits = grad_wrt_probs @ softmax_jacobian
+        grad_wrt_logits = torch.zeros_like(expert_logits).scatter(0, backward_survivors_ix, grad_wrt_survived_logits)
 
         return (grad_wrt_logits, None, None, None, None, None, None, None, *flat_grad_inputs)
 
     @staticmethod
     def _run_expert_forward(expert: RemoteExpert, *args: torch.Tensor, **kwargs: torch.Tensor):
         """ Call remote expert and return flattened outputs. Compatible with concurrent autograd. """
-        flat_inputs = nested_flatten((args, kwargs))
-        return run_isolated_forward(_RemoteModuleCall, DUMMY, expert.uid, expert.host, expert.port, *flat_inputs)
+        return run_isolated_forward(_RemoteModuleCall, DUMMY, expert.uid, expert.host, expert.port, expert.stub,
+                                    *nested_flatten((args, kwargs)))
 
     @staticmethod
     def _run_expert_backward(ctx: EmulatedAutogradContext, weight: torch.Tensor, *grad_outputs: torch.Tensor):
         backward_result = run_isolated_backward(_RemoteModuleCall, ctx, *(grad * weight for grad in grad_outputs))
-        grad_dummy, no_grad_uid, no_grad_hostname, no_grad_port, *grad_inputs = backward_result
+        grad_dummy, no_grad_uid, no_grad_hostname, no_grad_port, no_grad_stub, *grad_inputs = backward_result
         return grad_inputs
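
To illustrate the gradient fix above (gradients of surviving experts are now scattered back into a full-size logits vector, so experts that did not respond receive zero gradient), here is a standalone sketch with made-up values, not part of the commit.

import torch

# Standalone illustration of the scatter-based fix; all values below are invented.
expert_logits = torch.zeros(4)                        # scores for 4 candidate experts
backward_survivors_ix = torch.tensor([0, 2])          # only experts 0 and 2 responded in backward
grad_wrt_survived_logits = torch.tensor([0.5, -1.0])  # gradient w.r.t. the survivors' logits
grad_wrt_logits = torch.zeros_like(expert_logits).scatter(0, backward_survivors_ix, grad_wrt_survived_logits)
print(grad_wrt_logits)  # tensor([ 0.5000,  0.0000, -1.0000,  0.0000])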

+ 5 - 5
hivemind/dht/__init__.py

@@ -17,13 +17,13 @@ import ctypes
 import multiprocessing as mp
 import warnings
 from typing import List, Optional
-import uvloop
 
-from .node import DHTNode, DHTID, DHTExpiration
-from .routing import get_dht_time
+import uvloop
 
-from ..client import RemoteExpert
-from ..utils import SharedFuture, find_open_port, Endpoint, Port, run_in_background, LOCALHOST
+from hivemind.client import RemoteExpert
+from hivemind.dht.node import DHTNode, DHTID, DHTExpiration
+from hivemind.dht.routing import get_dht_time
+from hivemind.utils import SharedFuture, Endpoint, run_in_background
 
 
 class DHT(mp.Process):

+ 27 - 27
hivemind/dht/dht.proto

@@ -4,52 +4,52 @@ syntax = "proto3";
 // For more info, see https://learning-at-home.readthedocs.io/en/latest/modules/dht.html or help(hivemind.dht.DHTNode)
 
 service DHT {
-    // find out recipient's DHTID and possibly update its routing table
-    rpc rpc_ping(NodeInfo) returns (NodeInfo);
+  // find out recipient's DHTID and possibly update its routing table
+  rpc rpc_ping(NodeInfo) returns (NodeInfo);
 
-    // request a node to store one or multiple data items (key - value - expiration)
-    rpc rpc_store(StoreRequest) returns (StoreResponse);
+  // request a node to store one or multiple data items (key - value - expiration)
+  rpc rpc_store(StoreRequest) returns (StoreResponse);
 
-    // for given keys, request values (if stored) or a list of peers that are likely to have them
-    rpc rpc_find(FindRequest) returns (FindResponse);
+  // for given keys, request values (if stored) or a list of peers that are likely to have them
+  rpc rpc_find(FindRequest) returns (FindResponse);
 }
 
 message NodeInfo {
-    // note: both node_id and port are optional: if specified, ask peer to add you to its routing table;
-    // if either node_id or port is absent, simply request recipient info (for client-only mode)
-    bytes node_id = 1;                // sender's own node id serialized with DHTID.to_bytes()
-    int32 rpc_port = 2;               // port to which sender listens for DHT RPCs
+  // note: both node_id and port are optional: if specified, ask peer to add you to its routing table;
+  // if either node_id or port is absent, simply request recipient info (for client-only mode)
+  bytes node_id = 1;                // sender's own node id serialized with DHTID.to_bytes()
+  int32 rpc_port = 2;               // port to which sender listens for DHT RPCs
 }
 
 message StoreRequest {
-    // three lists of the same length representing dht keys, dht values and expiration
-    repeated bytes keys = 1;          // keys in the form of DHTID.generate(raw_key).to_bytes()
-    repeated bytes values = 2;        // binary-encoded value for i-th key
-    repeated double expiration = 3;   // expirations for i-th key (type = DHTExpiration)
-    repeated bool in_cache = 4;       // if in_cache[i], store i-th key in cache, else store normally
-    NodeInfo peer = 5;                // (optional) sender's own node info, same behavior as in DHT.rpc_ping
+  // three lists of the same length representing dht keys, dht values and expiration
+  repeated bytes keys = 1;          // keys in the form of DHTID.generate(raw_key).to_bytes()
+  repeated bytes values = 2;        // binary-encoded value for i-th key
+  repeated double expiration = 3;   // expirations for i-th key (type = DHTExpiration)
+  repeated bool in_cache = 4;       // if in_cache[i], store i-th key in cache, else store normally
+  NodeInfo peer = 5;                // (optional) sender's own node info, same behavior as in DHT.rpc_ping
 }
 
 message StoreResponse {
-    repeated bool store_ok = 1;       // for every key, True means store accepted, False means store rejected/failed
-    NodeInfo peer = 2;                // respondent's node id, for you to update routing table
+  repeated bool store_ok = 1;       // for every key, True means store accepted, False means store rejected/failed
+  NodeInfo peer = 2;                // respondent's node id, for you to update routing table
 }
 
 message FindRequest {
-    repeated bytes keys = 1;          // a list of DHTID search keys encoded as bytes
-    NodeInfo peer = 2;                // optional, same behavior as in DHT.ping
+  repeated bytes keys = 1;          // a list of DHTID search keys encoded as bytes
+  NodeInfo peer = 2;                // optional, same behavior as in DHT.ping
 }
 
 message Peers {
-   // two aligned arrays: DHTIDs and Endpoints, i-th endpoint corresponds to peer with i-th node id
-   repeated bytes node_ids = 1;       // DHTID serialized with node_id.to_bytes()
-   repeated string endpoints = 2;     // e.g. 123.123.123.123:1337 or [2a21:6с8:b192:2105]:8888
+  // two aligned arrays: DHTIDs and Endpoints, i-th endpoint corresponds to peer with i-th node id
+  repeated bytes node_ids = 1;       // DHTID serialized with node_id.to_bytes()
+  repeated string endpoints = 2;     // e.g. 123.123.123.123:1337 or [2a21:6с8:b192:2105]:8888
 }
 
 message FindResponse {
-    repeated bytes values = 1;        // value for i-th key, b'' means not found locally
-    repeated double expiration = 2;   // expiration time for i-th value, only valid value is found
-    repeated Peers nearest = 3;       // peers ordered from nearest to farthest based on distance to i-th key
-    NodeInfo peer = 4;                // respondent's node id, for you to update routing table
+  repeated bytes values = 1;        // value for i-th key, b'' means not found locally
+  repeated double expiration = 2;   // expiration time for i-th value, only valid value is found
+  repeated Peers nearest = 3;       // peers ordered from nearest to farthest based on distance to i-th key
+  NodeInfo peer = 4;                // respondent's node id, for you to update routing table
 }
 

+ 5 - 4
hivemind/dht/node.py

@@ -1,14 +1,15 @@
 from __future__ import annotations
+
 import asyncio
 import random
 from collections import namedtuple
 from typing import Optional, Tuple, List, Dict, Collection, Union, Set
 from warnings import warn
 
-from .protocol import DHTProtocol
-from .routing import DHTID, DHTExpiration, DHTKey, get_dht_time, DHTValue
-from .traverse import traverse_dht
-from ..utils import Endpoint, LOCALHOST, MSGPackSerializer
+from hivemind.dht.protocol import DHTProtocol
+from hivemind.dht.routing import DHTID, DHTExpiration, DHTKey, get_dht_time, DHTValue
+from hivemind.dht.traverse import traverse_dht
+from hivemind.utils import Endpoint, LOCALHOST, MSGPackSerializer
 
 
 class DHTNode:

+ 2 - 2
hivemind/dht/protocol.py

@@ -11,8 +11,8 @@ from warnings import warn
 import grpc
 import grpc.experimental.aio
 
-from .routing import RoutingTable, DHTID, BinaryDHTValue, DHTExpiration, get_dht_time
-from ..utils import Endpoint, compile_grpc, get_logger
+from hivemind.dht.routing import RoutingTable, DHTID, BinaryDHTValue, DHTExpiration, get_dht_time
+from hivemind.utils import Endpoint, compile_grpc, get_logger
 
 logger = get_logger(__name__)
 

+ 7 - 8
hivemind/dht/routing.py

@@ -2,16 +2,18 @@
 from __future__ import annotations
 
 import hashlib
+import heapq
 import os
 import random
-
 import time
-import heapq
 from collections.abc import Iterable
 from itertools import chain
-from typing import Tuple, Optional, List, Dict, Set, Union, Any, Sequence, Iterator
+from typing import Tuple, Optional, List, Dict, Set, Union, Any, Sequence
 
-from ..utils import Endpoint, PickleSerializer
+from hivemind.utils import Endpoint, PickleSerializer
+
+DHTKey, DHTValue, DHTExpiration, BinaryDHTID, BinaryDHTValue, = Any, Any, float, bytes, bytes  # flavour types
+get_dht_time = time.time  # time used by all dht functionality. You can replace this with any infrastructure-wide time
 
 
 class RoutingTable:
@@ -160,6 +162,7 @@ class KBucket:
     A bucket containing up to :size: of DHTIDs in [lower, upper) semi-interval.
     Maps DHT node ids to their endpoints (hostname, addr)
     """
+
     def __init__(self, lower: int, upper: int, size: int, depth: int = 0):
         assert upper - lower == 2 ** (DHTID.HASH_NBYTES * 8 - depth)
         self.lower, self.upper, self.size, self.depth = lower, upper, size, depth
@@ -288,7 +291,3 @@ class DHTID(int):
 
     def __bytes__(self):
         return self.to_bytes()
-
-
-DHTKey, DHTValue, DHTExpiration, BinaryDHTID, BinaryDHTValue, = Any, Any, float, bytes, bytes  # flavour types
-get_dht_time = time.time  # time used by all dht functionality. You can replace this with any infrastructure-wide time

+ 8 - 10
hivemind/dht/traverse.py

@@ -2,9 +2,9 @@
 import asyncio
 import heapq
 from collections import Counter
-from warnings import warn
 from typing import Dict, Awaitable, Callable, Any, Tuple, List, Set, Collection, Optional
-from .routing import DHTID
+
+from hivemind.dht.routing import DHTID
 
 ROOT = 0  # alias for heap root
 
@@ -107,13 +107,13 @@ async def traverse_dht(
     if len(queries) == 0:
         return {}, dict(visited_nodes)
 
-    unfinished_queries = set(queries)                             # all queries that haven't triggered finish_search yet
-    candidate_nodes: Dict[DHTID, List[Tuple[int, DHTID]]] = {}    # heap: unvisited nodes, ordered nearest-to-farthest
-    nearest_nodes: Dict[DHTID, List[Tuple[int, DHTID]]] = {}      # heap: top-k nearest nodes, farthest-to-nearest
-    known_nodes: Dict[DHTID, Set[DHTID]] = {}                     # all nodes ever added to the heap (for deduplication)
+    unfinished_queries = set(queries)  # all queries that haven't triggered finish_search yet
+    candidate_nodes: Dict[DHTID, List[Tuple[int, DHTID]]] = {}  # heap: unvisited nodes, ordered nearest-to-farthest
+    nearest_nodes: Dict[DHTID, List[Tuple[int, DHTID]]] = {}  # heap: top-k nearest nodes, farthest-to-nearest
+    known_nodes: Dict[DHTID, Set[DHTID]] = {}  # all nodes ever added to the heap (for deduplication)
     visited_nodes: Dict[DHTID, Set[DHTID]] = dict(visited_nodes)  # where we requested get_neighbors for a given query
-    pending_tasks = set()                                         # all active tasks (get_neighbors and found_callback)
-    active_workers = Counter({q: 0 for q in queries})             # count workers that search for this query
+    pending_tasks = set()  # all active tasks (get_neighbors and found_callback)
+    active_workers = Counter({q: 0 for q in queries})  # count workers that search for this query
 
     search_finished_event = asyncio.Event()  # used to immediately stop all workers when the search is finished
     heap_updated_event = asyncio.Event()  # if a worker has no nodes to explore, it will await other workers
@@ -228,5 +228,3 @@ async def traverse_dht(
         for query in queries
     }
     return nearest_neighbors_per_query, visited_nodes
-
-

+ 2 - 2
hivemind/runtime/__init__.py

@@ -7,8 +7,8 @@ from typing import Dict
 import torch
 from prefetch_generator import BackgroundGenerator
 
-from .expert_backend import ExpertBackend
-from .task_pool import TaskPool, TaskPoolBase
+from hivemind.runtime.expert_backend import ExpertBackend
+from hivemind.runtime.task_pool import TaskPool, TaskPoolBase
 from hivemind.utils import get_logger
 
 logger = get_logger(__name__)

+ 2 - 2
hivemind/runtime/expert_backend.py

@@ -3,8 +3,8 @@ from typing import Dict, Sequence, Any, Tuple, Union
 import torch
 from torch import nn
 
-from .task_pool import TaskPool
-from ..utils import nested_flatten, nested_pack, nested_compare, BatchTensorDescriptor, DUMMY_BATCH_SIZE, nested_map
+from hivemind.runtime.task_pool import TaskPool
+from hivemind.utils import nested_flatten, nested_pack, nested_compare, BatchTensorDescriptor, DUMMY_BATCH_SIZE, nested_map
 
 
 class ExpertBackend(nn.Module):

+ 11 - 38
hivemind/server/__init__.py

@@ -1,16 +1,12 @@
 import multiprocessing as mp
-import os
 import threading
-from socket import socket, AF_INET, SOCK_STREAM, SO_REUSEADDR, SOL_SOCKET, timeout
 from typing import Dict, Optional
 
-import torch
-
-from .connection_handler import handle_connection
-from .dht_handler import DHTHandlerThread
-from .checkpoint_saver import CheckpointSaver
-from ..dht import DHT
-from ..runtime import Runtime, ExpertBackend
+from hivemind.dht import DHT
+from hivemind.runtime import Runtime, ExpertBackend
+from hivemind.server.checkpoint_saver import CheckpointSaver
+from hivemind.server.connection_handler import ConnectionHandler
+from hivemind.server.dht_handler import DHTHandlerThread
 
 
 class Server(threading.Thread):
@@ -37,12 +33,12 @@ class Server(threading.Thread):
     """
 
     def __init__(self, dht: Optional[DHT], expert_backends: Dict[str, ExpertBackend], addr='127.0.0.1',
-                 port: int = 8080, conn_handler_processes: int = 1, update_period: int = 30, start=False,
-                 checkpoint_dir=None, **kwargs):
+                 port: int = 8080, conn_handler_processes: int = 1, update_period: int = 30, start=False, checkpoint_dir=None, **kwargs):
         super().__init__()
         self.dht, self.experts, self.update_period = dht, expert_backends, update_period
         self.addr, self.port = addr, port
-        self.conn_handlers = self._create_connection_handlers(conn_handler_processes)
+        self.conn_handlers = [ConnectionHandler(f"{self.addr}:{port}", self.experts)
+                              for _ in range(conn_handler_processes)]
         if checkpoint_dir is not None:
             self.checkpoint_saver = CheckpointSaver(expert_backends, checkpoint_dir, update_period)
         else:
@@ -71,6 +67,9 @@ class Server(threading.Thread):
             if not process.is_alive():
                 process.start()
 
+        for process in self.conn_handlers:
+            process.ready.wait()
+
         self.runtime.run()
 
         for process in self.conn_handlers:
@@ -104,18 +103,6 @@ class Server(threading.Thread):
         """
         return self.runtime.ready  # mp.Event that is true if self is ready to process batches
 
-    def _create_connection_handlers(self, num_handlers):
-        sock = socket(AF_INET, SOCK_STREAM)
-        sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
-        sock.bind(('', self.port))
-        sock.listen(1024)
-        sock.settimeout(self.update_period)
-
-        processes = [mp.context.ForkProcess(
-            target=socket_loop, name=f"socket_loop-{i}", args=(sock, self.experts), daemon=True)
-            for i in range(num_handlers)]
-        return processes
-
     def shutdown(self):
         """
         Gracefully terminate a hivemind server, process-safe.
@@ -130,17 +117,3 @@ class Server(threading.Thread):
             self.dht.shutdown()
 
         self.runtime.shutdown()
-
-
-def socket_loop(sock, experts):
-    """ catch connections, send tasks to processing, respond with results """
-    torch.set_num_threads(1)
-    print(f'Spawned connection handler pid={os.getpid()}')
-    while True:
-        try:
-            handle_connection(sock.accept(), experts)
-        except KeyboardInterrupt as e:
-            print(f'Socket loop has caught {type(e)}, exiting')
-            break
-        except (timeout, BrokenPipeError, ConnectionResetError, NotImplementedError):
-            continue

+ 1 - 1
hivemind/server/checkpoint_saver.py

@@ -8,7 +8,7 @@ from typing import Dict
 
 import torch
 
-from ..runtime import ExpertBackend
+from hivemind.runtime import ExpertBackend
 
 
 class CheckpointSaver(threading.Thread):

+ 35 - 0
hivemind/server/connection_handler.proto

@@ -0,0 +1,35 @@
+syntax = "proto3";
+
+
+service ConnectionHandler {
+  // Listens to incoming requests for expert computation
+  rpc info(ExpertUID) returns (ExpertInfo);
+  rpc forward(ExpertRequest) returns (ExpertResponse);
+  rpc backward(ExpertRequest) returns (ExpertResponse);
+}
+
+
+message ExpertUID {
+  string uid = 1;
+}
+
+message ExpertInfo {
+  bytes serialized_info = 1;
+}
+
+message ExpertRequest {
+  string uid = 1;
+  repeated Tensor tensors = 2;
+}
+
+message ExpertResponse {
+  repeated Tensor tensors = 2;
+}
+
+message Tensor {
+  bytes buffer = 1;
+  repeated uint32 size = 2;
+  bool requires_grad = 3;
+  string dtype = 4;
+}
+
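
As a hedged illustration of how this proto is consumed (not part of the commit): the file is compiled at import time by hivemind/utils/grpc.py further down, so a bare client can query an expert's schema without going through RemoteExpert, assuming a ConnectionHandler serving "expert.0" is listening on 127.0.0.1:8080.

import pickle
import grpc
from hivemind.utils.grpc import runtime_pb2, runtime_grpc

# Assumption: a ConnectionHandler for expert "expert.0" is listening on 127.0.0.1:8080.
channel = grpc.insecure_channel("127.0.0.1:8080")
stub = runtime_grpc.ConnectionHandlerStub(channel)
response = stub.info(runtime_pb2.ExpertUID(uid="expert.0"))
info = pickle.loads(response.serialized_info)  # schema dict produced by ExpertBackend.get_info()
print(info.keys())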

+ 69 - 27
hivemind/server/connection_handler.py

@@ -1,29 +1,71 @@
-from socket import socket
-from typing import Tuple, Dict
+import asyncio
+import multiprocessing as mp
+import os
+import pickle
+from typing import Dict
+
+import grpc.experimental.aio
+import torch
+import uvloop
 
 from hivemind.runtime.expert_backend import ExpertBackend
-from hivemind.utils import PytorchSerializer, Connection
-
-
-def handle_connection(connection_tuple: Tuple[socket, str], experts: Dict[str, ExpertBackend]):
-    with Connection(*connection_tuple) as connection:
-        try:
-            header = connection.recv_header()
-            payload = PytorchSerializer.loads(connection.recv_raw())
-
-            if header == 'fwd_':
-                uid, inputs = payload
-                response = experts[uid].forward_pool.submit_task(*inputs).result()
-            elif header == 'bwd_':
-                uid, inputs_and_grad_outputs = payload
-                response = experts[uid].backward_pool.submit_task(*inputs_and_grad_outputs).result()
-            elif header == 'info':
-                uid = payload
-                response = experts[uid].get_info()
-            else:
-                raise NotImplementedError(f"Unknown header: {header}")
-
-            connection.send_raw('rest', PytorchSerializer.dumps(response))
-        except RuntimeError:
-            # socket connection broken
-            pass
+from hivemind.utils import get_logger, serialize_torch_tensor, deserialize_torch_tensor, Endpoint, runtime_pb2, runtime_grpc
+
+logger = get_logger(__name__)
+
+
+class ConnectionHandler(mp.Process):
+    """
+    A process that accepts incoming requests to experts and submits them into the corresponding TaskPool.
+
+    :note: ConnectionHandler is designed so as to allow using multiple handler processes for the same port.
+    :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
+    :param experts: a dict [UID -> ExpertBackend] with all active experts
+    """
+
+    def __init__(self, listen_on: Endpoint, experts: Dict[str, ExpertBackend]):
+        super().__init__()
+        self.listen_on, self.experts = listen_on, experts
+        self.ready = mp.Event()
+
+    def run(self):
+        torch.set_num_threads(1)
+        uvloop.install()
+        loop = asyncio.new_event_loop()
+
+        async def _run():
+            grpc.experimental.aio.init_grpc_aio()
+            logger.debug(f'Starting, pid {os.getpid()}')
+            server = grpc.experimental.aio.server(options=[
+                ('grpc.so_reuseport', 1),
+                ('grpc.max_send_message_length', -1),
+                ('grpc.max_receive_message_length', -1)
+            ])
+            runtime_grpc.add_ConnectionHandlerServicer_to_server(self, server)
+
+            found_port = server.add_insecure_port(self.listen_on)
+            assert found_port != 0, f"Failed to listen to {self.listen_on}"
+
+            await server.start()
+            self.ready.set()
+            await server.wait_for_termination()
+            logger.debug(f"ConnectionHandler terminated: (pid={os.getpid()})")
+
+        loop.run_until_complete(_run())
+
+    async def info(self, request: runtime_pb2.ExpertUID, context: grpc.ServicerContext):
+        return runtime_pb2.ExpertInfo(serialized_info=pickle.dumps(self.experts[request.uid].get_info()))
+
+    async def forward(self, request: runtime_pb2.ExpertRequest, context: grpc.ServicerContext):
+        inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+        future = self.experts[request.uid].forward_pool.submit_task(*inputs)
+        response = await future.async_result()
+        serialized_response = [serialize_torch_tensor(tensor) for tensor in response]
+        return runtime_pb2.ExpertResponse(tensors=serialized_response)
+
+    async def backward(self, request: runtime_pb2.ExpertRequest, context: grpc.ServicerContext):
+        inputs_and_grad_outputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+        future = self.experts[request.uid].backward_pool.submit_task(*inputs_and_grad_outputs)
+        response = await future.async_result()
+        serialized_response = [serialize_torch_tensor(tensor) for tensor in response]
+        return runtime_pb2.ExpertResponse(tensors=serialized_response)
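
A short sketch (assumptions noted in comments, not part of the commit) of how these handlers are meant to be spawned: because the aio server is created with grpc.so_reuseport, several ConnectionHandler processes can bind the same port, which is what Server.run() above relies on.

from hivemind.server.connection_handler import ConnectionHandler

# Sketch only: an empty experts dict means the handlers start but serve nothing;
# in Server, the dict maps expert uids to ExpertBackend instances.
handlers = [ConnectionHandler("127.0.0.1:8080", experts={}) for _ in range(4)]
for handler in handlers:
    handler.start()        # each process builds its own aio gRPC server with so_reuseport
for handler in handlers:
    handler.ready.wait()   # mirrors Server.run(): proceed only once every handler is listening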

+ 2 - 3
hivemind/server/dht_handler.py

@@ -1,12 +1,11 @@
 import threading
 import time
 
-from ..dht import DHT
+from hivemind.dht import DHT
 
 
 class DHTHandlerThread(threading.Thread):
-    def __init__(self, experts, dht: DHT,
-                 update_period: int = 5, addr: str = '127.0.0.1', port: int = 8080):
+    def __init__(self, experts, dht: DHT, update_period: int = 5, addr: str = '127.0.0.1', port: int = 8080):
         super(DHTHandlerThread, self).__init__()
         self.port = port
         self.addr = addr

+ 10 - 10
hivemind/utils/__init__.py

@@ -1,10 +1,10 @@
-from .connection import *
-from .data import *
-from .nested import *
-from .tensor_descr import *
-from .serializer import *
-from .shared_future import *
-from .threading import *
-from .autograd import *
-from .grpc import *
-from .logging import get_logger
+from hivemind.utils.connection import *
+from hivemind.utils.data import *
+from hivemind.utils.nested import *
+from hivemind.utils.tensor_descr import *
+from hivemind.utils.serializer import *
+from hivemind.utils.shared_future import *
+from hivemind.utils.threading import *
+from hivemind.utils.autograd import *
+from hivemind.utils.grpc import *
+from hivemind.utils.logging import get_logger

+ 3 - 1
hivemind/utils/autograd.py

@@ -10,7 +10,7 @@ import numpy as np
 import torch
 import torch.autograd.function
 
-from .threading import run_in_background
+from hivemind.utils.threading import run_in_background
 
 
 class EmulatedAutogradContext(torch.autograd.function._ContextMethodMixin):
@@ -19,6 +19,7 @@ class EmulatedAutogradContext(torch.autograd.function._ContextMethodMixin):
     such as running several parallel backwards or transferring backward to a separate device.
     This class is not tested outside its use cases in RemoteMixtureOfExperts and we do not recommend using it elsewhere.
     """
+
     @property
     def saved_tensors(self):
         return tuple(self.to_save)
@@ -71,6 +72,7 @@ class _ParallelApplyFunction(torch.autograd.Function):
     Please do not call this function directly. Use apply_with_parallel_backward instead.
     Unlike default pytorch behavior, the backward pass for each function will also happen in parallel.
     """
+
     @staticmethod
     def forward(ctx, func: torch.autograd.Function, num_calls: int, num_args_per_call: int,
                 output_strides_ph: Future, *args_flat) -> Tuple[torch.Tensor, ...]:

+ 1 - 53
hivemind/utils/connection.py

@@ -1,63 +1,11 @@
 import socket
-from contextlib import AbstractContextManager, closing
-from typing import Tuple
+from contextlib import closing
 
 Hostname, Port = str, int  # flavour types
 Endpoint = str  # e.g. 1.2.3.4:1337 or [2a21:6с8:b192:2105]:8888, https://networkengineering.stackexchange.com/a/9435
 LOCALHOST = '127.0.0.1'
 
 
-class Connection(AbstractContextManager):
-    header_size = 4  # number of characters in all headers
-    payload_length_size = 8  # number of bytes used to encode payload length
-
-    __slots__ = ('conn', 'addr')
-
-    def __init__(self, conn: socket, addr: Tuple[Hostname, Port]):
-        self.conn, self.addr = conn, addr
-
-    @staticmethod
-    def create(host: str, port: int):
-        sock = socket.socket()
-        addr = (host, port)
-        sock.connect(addr)
-        return Connection(sock, addr)
-
-    def send_raw(self, header: str, content: bytes):
-        self.conn.send(header.encode())
-        self.conn.send(len(content).to_bytes(self.payload_length_size, byteorder='big'))
-
-        total_sent = 0
-        while total_sent < len(content):
-            sent = self.conn.send(content[total_sent:])
-            if sent == 0:
-                raise RuntimeError("socket connection broken")
-            total_sent = total_sent + sent
-
-    def recv_header(self) -> str:
-        return self.conn.recv(self.header_size).decode()
-
-    def recv_raw(self, max_package: int = 2048) -> bytes:
-        length = int.from_bytes(self.conn.recv(self.payload_length_size), byteorder='big')
-        chunks = []
-        bytes_recd = 0
-        while bytes_recd < length:
-            chunk = self.conn.recv(min(length - bytes_recd, max_package))
-            if chunk == b'':
-                raise RuntimeError("socket connection broken")
-            chunks.append(chunk)
-            bytes_recd = bytes_recd + len(chunk)
-        ret = b''.join(chunks)
-        assert len(ret) == length
-        return ret
-
-    def recv_message(self) -> Tuple[str, bytes]:
-        return self.recv_header(), self.recv_raw()
-
-    def __exit__(self, *exc_info):
-        self.conn.close()
-
-
 def find_open_port(params=(socket.AF_INET, socket.SOCK_STREAM), opt=(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)):
     """ Finds a tcp port that can be occupied with a socket with *params and use *opt options """
     try:

+ 0 - 10
hivemind/utils/data.py

@@ -1,13 +1,3 @@
-import numpy as np
 import torch
 
-
-def check_numpy(x):
-    """ Makes sure x is a numpy array """
-    if isinstance(x, torch.Tensor):
-        return x.detach().cpu().numpy()
-    else:
-        return np.asarray(x)
-
-
 DUMMY = torch.empty(0, requires_grad=True)

+ 24 - 1
hivemind/utils/grpc.py

@@ -5,9 +5,12 @@ import functools
 import os
 import sys
 import tempfile
-from typing import Tuple
 from argparse import Namespace
+from typing import Tuple
+
 import grpc_tools.protoc
+import numpy as np
+import torch
 
 
 @functools.lru_cache(maxsize=None)
@@ -42,3 +45,23 @@ def compile_grpc(proto: str, *args: str) -> Tuple[Namespace, Namespace]:
         finally:
             if sys.path.pop() != build_dir:
                 raise ImportError("Something changed sys.path while compile_grpc was in progress.")
+
+
+with open(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'server', 'connection_handler.proto')) as f_proto:
+    runtime_pb2, runtime_grpc = compile_grpc(f_proto.read())
+
+
+def serialize_torch_tensor(tensor: torch.Tensor) -> runtime_pb2.Tensor:
+    array = tensor.numpy()
+    proto = runtime_pb2.Tensor(
+        buffer=array.tobytes(),
+        size=array.shape,
+        dtype=array.dtype.name,
+        requires_grad=tensor.requires_grad)
+    return proto
+
+
+def deserialize_torch_tensor(tensor: runtime_pb2.Tensor) -> torch.Tensor:
+    # TODO avoid copying the array (need to silence pytorch warning, because array is not writable)
+    array = np.frombuffer(tensor.buffer, dtype=np.dtype(tensor.dtype)).copy()
+    return torch.as_tensor(array).view(tuple(tensor.size)).requires_grad_(tensor.requires_grad)
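
A minimal round-trip sketch for the serialization helpers added above (not part of the commit):

import torch
from hivemind.utils.grpc import serialize_torch_tensor, deserialize_torch_tensor

# Round-trip a tensor through runtime_pb2.Tensor: raw bytes + shape + dtype + requires_grad.
original = torch.randn(2, 3)
proto = serialize_torch_tensor(original)
restored = deserialize_torch_tensor(proto)
assert restored.shape == original.shape and torch.allclose(original, restored)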

+ 1 - 1
hivemind/utils/serializer.py

@@ -57,7 +57,7 @@ class MSGPackSerializer(SerializerBase):
 
     @staticmethod
     def dumps(obj: object) -> bytes:
-        return umsgpack.dumps(obj, use_bin_type=False) # TODO strict https://github.com/msgpack/msgpack-python/pull/158
+        return umsgpack.dumps(obj, use_bin_type=False)  # TODO strict https://github.com/msgpack/msgpack-python/pull/158
 
     @staticmethod
     def loads(buf: bytes) -> object:

+ 48 - 6
hivemind/utils/shared_future.py

@@ -2,6 +2,7 @@ import multiprocessing as mp
 import multiprocessing.connection
 from concurrent.futures import Future, CancelledError
 from warnings import warn
+import asyncio
 
 
 class SharedFuture(Future):
@@ -22,14 +23,21 @@ class SharedFuture(Future):
         connection1, connection2 = mp.Pipe()
         return cls(connection1), cls(connection2)
 
+    def poll_and_recv(self, timeout):
+        available = self.connection.poll(timeout)
+        if not available:
+            raise TimeoutError
+        try:
+            status, payload = self.connection.recv()
+            self.connection.close()
+        except BrokenPipeError as e:
+            status, payload = self.STATE_EXCEPTION, e
+        return status, payload
+
     def _recv(self, timeout):
+
         if self.state in (self.STATE_PENDING, self.STATE_RUNNING):
-            if not self.connection.poll(timeout):
-                raise TimeoutError()
-            try:
-                status, payload = self.connection.recv()
-            except BrokenPipeError as e:
-                status, payload = self.STATE_EXCEPTION, e
+            status, payload = self.poll_and_recv(timeout)
 
             assert status in self.STATES
             self.state = status
@@ -47,6 +55,7 @@ class SharedFuture(Future):
         try:
             self.state, self._result = self.STATE_FINISHED, result
             self.connection.send((self.STATE_FINISHED, result))
+            self.connection.close()
             return True
         except BrokenPipeError:
             return False
@@ -55,6 +64,7 @@ class SharedFuture(Future):
         try:
             self.state, self._exception = self.STATE_EXCEPTION, exception
             self.connection.send((self.STATE_EXCEPTION, exception))
+            self.connection.close()
             return True
         except BrokenPipeError:
             return False
@@ -103,3 +113,35 @@ class SharedFuture(Future):
             return "<MPFuture at 0x{:x} state=finished raised {}>".format(id(self), type(self._exception))
         else:
             return "<MPFuture at 0x{:x} state={}>".format(id(self), self.state)
+
+    async def _async_recv(self, timeout):
+        loop = asyncio.get_running_loop()
+
+        if self.state in (self.STATE_PENDING, self.STATE_RUNNING):
+            status, payload = await loop.run_in_executor(None, self.poll_and_recv, timeout)
+
+            assert status in self.STATES
+            self.state = status
+
+            if status == self.STATE_FINISHED:
+                self._result = payload
+            elif status == self.STATE_EXCEPTION:
+                self._exception = payload
+            elif status in (self.STATE_RUNNING, self.STATE_CANCELLED):
+                pass  # only update self.state
+            else:
+                raise ValueError("Result status should not be self.STATE_PENDING")
+
+    async def async_result(self, timeout=None):
+        await self._async_recv(timeout)
+        if self.state == self.STATE_FINISHED:
+            return self._result
+        elif self.state == self.STATE_EXCEPTION:
+            raise self._exception
+        else:
+            assert self.state == self.STATE_CANCELLED
+            raise CancelledError()
+
+    async def async_exception(self, timeout=None):
+        await self._async_recv(timeout)
+        return self._exception
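
A hedged sketch of the new awaitable API (not part of the commit): poll_and_recv runs in the default executor, so awaiting a SharedFuture does not block the event loop.

import asyncio
from hivemind.utils.shared_future import SharedFuture

async def main():
    sender, receiver = SharedFuture.make_pair()
    sender.set_result(42)                 # normally fulfilled by another process (e.g. a TaskPool)
    print(await receiver.async_result())  # 42; ConnectionHandler awaits TaskPool results this way

asyncio.run(main())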

+ 2 - 1
hivemind/utils/threading.py

@@ -14,9 +14,11 @@ def run_in_background(func: callable, *args, **kwargs) -> Future:
 
 def run_forever(func: callable, *args, **kwargs):
     """ A function that runs a :func: in background forever. Returns a future that catches exceptions """
+
     def repeat():
         while True:
             func(*args, **kwargs)
+
     return run_in_background(repeat)
 
 
@@ -65,4 +67,3 @@ def run_and_await_k(jobs: List[callable], k: int,
             future.cancel()
             outputs[index] = future.result() if not future.exception() else future.exception()
     return outputs
-

+ 0 - 1
requirements.txt

@@ -2,7 +2,6 @@ torch>=1.3.0
 joblib>=0.13
 numpy>=1.17
 prefetch_generator>=1.0.1
-pytest
 umsgpack
 uvloop>=0.14.0
 grpcio

+ 3 - 2
tests/benchmark_throughput.py

@@ -6,9 +6,9 @@ import time
 
 import torch
 from test_utils import layers, print_device_info, increase_file_limit
-from hivemind import find_open_port
 
 import hivemind
+from hivemind import find_open_port
 
 
 def client_process(can_start, benchmarking_failed, port, num_experts, batch_size, hid_dim, num_batches, backprop=True):
@@ -142,7 +142,8 @@ if __name__ == "__main__":
         benchmark_throughput(backprop=False, num_clients=512, batch_size=512,
                              max_batch_size=8192, num_batches_per_client=args.num_batches_per_client)
     elif args.preset == 'minimalistic':
-        benchmark_throughput(num_experts=1, num_clients=1, num_handlers=1)
+        benchmark_throughput(num_experts=1, num_clients=1, num_handlers=1,
+                             num_batches_per_client=args.num_batches_per_client)
     elif args.preset == 'nop':
         benchmark_throughput(expert_cls='nop', backprop=False, num_batches_per_client=args.num_batches_per_client)
     else:

+ 8 - 7
tests/test_utils/run_server.py

@@ -1,11 +1,12 @@
-import resource
-from contextlib import contextmanager
-import multiprocessing as mp
 import argparse
+import multiprocessing as mp
+from contextlib import contextmanager
 
+import resource
 import torch
+
 import hivemind
-from .layers import name_to_block, name_to_input
+from tests.test_utils.layers import name_to_block, name_to_input
 
 
 def make_dummy_server(interface='0.0.0.0', port=None, num_experts=1, expert_cls='ffn', hidden_dim=1024,
@@ -147,12 +148,12 @@ if __name__ == '__main__':
     parser.add_argument('--no_optimizer', action='store_true', help='if specified, all optimizers use learning rate=0')
     parser.add_argument('--no_dht', action='store_true', help='if specified, the server will not be attached to a dht')
     parser.add_argument('--initial_peers', type=str, default="[]", required=False, help='a list of peers that will'
-                        ' introduce this node to the dht, e.g. [("1.2.3.4", 1337), ("127.0.0.1", 4321)]')
+                                                                                        ' introduce this node to the dht, e.g. [("1.2.3.4", 1337), ("127.0.0.1", 4321)]')
     parser.add_argument('--dht_port', type=int, default=None, required=False, help='DHT node will listen on this port')
     parser.add_argument('--root_port', type=int, default=None, required=False, help='If this server does not have peers'
-                        ', it will create a virtual dht node on this port. You can then use this node as initial peer.')
+                                                                                    ', it will create a virtual dht node on this port. You can then use this node as initial peer.')
     parser.add_argument('--increase_file_limit', action='store_true', help='On *nix, this will increase the max number'
-                        ' of processes a server can spawn before hitting "Too many open files"; Use at your own risk.')
+                                                                           ' of processes a server can spawn before hitting "Too many open files"; Use at your own risk.')
 
     args = vars(parser.parse_args())