
Convert hivemind.server to libp2p backend (#470)

Switch hivemind MoE from gRPC to libp2p.
This allows serving experts from behind NATs and firewalls and improves performance on high-latency connections.

Changes:
 - RemoteExpert (and MoE clients) now communicate with servers via libp2p (see the client-side sketch below)
 - Removed the listen_on parameter from hivemind.Server and the CLI tools
 - ConnectionHandlers now use load balancing for better performance (see benchmarks in the corresponding PR)
 - Updated docs & tests
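
A minimal client-side sketch of the new flow in Python (illustrative only; it mirrors the updated benchmark in this diff). `server_maddrs`, `server_peer_id`, the expert uid "expert.0", and the input shape are placeholders: take the real values from the server's DHT (`dht.get_visible_maddrs()`, `dht.peer_id`), and a Server hosting "expert.0" must be running for the forward call to succeed.

```python
import torch

from hivemind.moe.client.expert import RemoteExpert, RemoteExpertInfo
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.p2p import P2P, PeerInfo

server_maddrs, server_peer_id = ..., ...  # placeholders, see the note above

# create a local libp2p instance that can reach the server
p2p = RemoteExpertWorker.run_coroutine(P2P.create(initial_peers=server_maddrs))

# a RemoteExpert is now addressed by (uid, server PeerInfo) instead of an ip:port endpoint
peer_info = PeerInfo(server_peer_id, server_maddrs)
expert = RemoteExpert(RemoteExpertInfo(uid="expert.0", peer_info=peer_info), p2p)

output = expert(torch.randn(4, 512))  # forward (and backward) travel over libp2p
```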

Co-authored-by: Denis Mazur <denismazur8@gmail.com>
Co-authored-by: Alexander Borzunov <hxrussia@gmail.com>
Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
GreenFatGuy, 3 years ago
Parent
Commit 724cdfe5e2
40 changed files with 915 additions and 603 deletions
  1. .github/workflows/run-tests.yml (+2 -2)
  2. benchmarks/benchmark_throughput.py (+35 -17)
  3. docs/user/moe.md (+9 -12)
  4. hivemind/averaging/averager.py (+2 -1)
  5. hivemind/compression/__init__.py (+5 -1)
  6. hivemind/compression/adaptive.py (+2 -2)
  7. hivemind/compression/base.py (+1 -1)
  8. hivemind/compression/serialization.py (+25 -1)
  9. hivemind/dht/dht.py (+3 -2)
  10. hivemind/hivemind_cli/config.yml (+0 -1)
  11. hivemind/hivemind_cli/run_server.py (+6 -3)
  12. hivemind/moe/client/beam_search.py (+23 -10)
  13. hivemind/moe/client/expert.py (+165 -37)
  14. hivemind/moe/client/moe.py (+28 -24)
  15. hivemind/moe/client/remote_expert_worker.py (+48 -0)
  16. hivemind/moe/client/switch_moe.py (+2 -2)
  17. hivemind/moe/server/connection_handler.py (+102 -48)
  18. hivemind/moe/server/dht_handler.py (+22 -20)
  19. hivemind/moe/server/expert_uid.py (+2 -2)
  20. hivemind/moe/server/server.py (+26 -40)
  21. hivemind/p2p/p2p_daemon.py (+13 -6)
  22. hivemind/p2p/p2p_daemon_bindings/control.py (+6 -4)
  23. hivemind/p2p/p2p_daemon_bindings/datastructures.py (+7 -1)
  24. hivemind/p2p/p2p_daemon_bindings/p2pclient.py (+5 -4)
  25. hivemind/p2p/servicer.py (+4 -2)
  26. hivemind/proto/p2pd.proto (+2 -0)
  27. hivemind/utils/__init__.py (+1 -1)
  28. hivemind/utils/asyncio.py (+7 -1)
  29. hivemind/utils/grpc.py (+0 -210)
  30. hivemind/utils/networking.py (+2 -24)
  31. hivemind/utils/streaming.py (+49 -0)
  32. setup.py (+2 -2)
  33. tests/test_compression.py (+3 -2)
  34. tests/test_connection_handler.py (+192 -0)
  35. tests/test_custom_experts.py (+20 -9)
  36. tests/test_dht_experts.py (+28 -23)
  37. tests/test_moe.py (+40 -28)
  38. tests/test_p2p_daemon_bindings.py (+8 -2)
  39. tests/test_training.py (+17 -11)
  40. tests/test_util_modules.py (+1 -47)

+ 2 - 2
.github/workflows/run-tests.yml

@@ -12,7 +12,7 @@ jobs:
     strategy:
       matrix:
         python-version: [ 3.7, 3.8, 3.9 ]
-    timeout-minutes: 12
+    timeout-minutes: 15
     steps:
       - uses: actions/checkout@v2
       - name: Set up Python
@@ -71,7 +71,7 @@ jobs:
   codecov_in_develop_mode:
 
     runs-on: ubuntu-latest
-    timeout-minutes: 12
+    timeout-minutes: 15
     steps:
       - uses: actions/checkout@v2
       - name: Set up Python

+ 35 - 17
benchmarks/benchmark_throughput.py

@@ -6,12 +6,14 @@ import time
 
 import torch
 
-from hivemind.moe.client import RemoteExpert
+from hivemind.dht import DHT
+from hivemind.moe.client.expert import RemoteExpert, RemoteExpertInfo
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.moe.server import ExpertBackend, Server
 from hivemind.moe.server.layers import name_to_block
+from hivemind.p2p import P2P, PeerInfo
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
-from hivemind.utils.networking import LOCALHOST, get_free_port
 from hivemind.utils.tensor_descr import BatchTensorDescriptor
 
 use_hivemind_log_handler("in_root_logger")
@@ -31,14 +33,30 @@ def print_device_info(device=None):
         logger.info(f"Cached:   {round(torch.cuda.memory_cached(0) / 1024 ** 3, 1)} GB")
 
 
-def client_process(can_start, benchmarking_failed, port, num_experts, batch_size, hid_dim, num_batches, backprop=True):
+def client_process(
+    can_start,
+    benchmarking_failed,
+    server_maddrs,
+    server_peer_id,
+    num_experts,
+    batch_size,
+    hid_dim,
+    num_batches,
+    backprop=True,
+) -> None:
     torch.set_num_threads(1)
     can_start.wait()
-    experts = [RemoteExpert(f"expert{i}", endpoint=f"{LOCALHOST}:{port}") for i in range(num_experts)]
+
+    p2p = RemoteExpertWorker.run_coroutine(P2P.create(initial_peers=server_maddrs))
+    peer_info = PeerInfo(server_peer_id, server_maddrs)
+    experts = [
+        RemoteExpert(expert_info=RemoteExpertInfo(uid=f"expert.{i}", peer_info=peer_info), p2p=p2p)
+        for i in range(num_experts)
+    ]
 
     try:
         dummy_batch = torch.randn(batch_size, hid_dim)
-        for batch_i in range(num_batches):
+        for _ in range(num_batches):
             expert = random.choice(experts)
             out = expert(dummy_batch)
             if backprop:
@@ -59,7 +77,6 @@ def benchmark_throughput(
     max_batch_size=None,
     backprop=True,
     device=None,
-    port=None,
 ):
     assert (
         not hasattr(torch.cuda, "is_initialized")
@@ -67,7 +84,6 @@ def benchmark_throughput(
         or torch.device(device) == torch.device("cpu")
     )
     assert expert_cls in name_to_block
-    port = port or get_free_port()
     max_batch_size = max_batch_size or batch_size * 4
     num_handlers = max(1, num_handlers or num_clients // 2)
     benchmarking_failed = mp.Event()
@@ -75,8 +91,7 @@ def benchmark_throughput(
     timestamps = dict(started=time.perf_counter())
 
     try:
-        # start clients and await server
-        # Note: client processes must be launched BEFORE touching gpu, even torch.cuda.is_available can cause trouble
+        server_dht = DHT(start=True)
         clients = [
             mp.Process(
                 target=client_process,
@@ -84,30 +99,30 @@ def benchmark_throughput(
                 args=(
                     can_start,
                     benchmarking_failed,
-                    port,
+                    server_dht.get_visible_maddrs(),
+                    server_dht.peer_id,
                     num_experts,
                     batch_size,
                     hid_dim,
                     num_batches_per_client,
                     backprop,
                 ),
+                daemon=True,
             )
             for i in range(num_clients)
         ]
 
         for client in clients:
-            client.daemon = True
             client.start()
 
         timestamps["launched_clients"] = timestamps["began_launching_server"] = time.perf_counter()
 
-        # start server
         device = device or ("cuda" if torch.cuda.is_available() else "cpu")
         experts = {}
         for i in range(num_experts):
             expert = torch.jit.script(name_to_block[expert_cls](hid_dim))
-            experts[f"expert{i}"] = ExpertBackend(
-                name=f"expert{i}",
+            experts[f"expert.{i}"] = ExpertBackend(
+                name=f"expert.{i}",
                 expert=expert,
                 optimizer=torch.optim.Adam(expert.parameters()),
                 args_schema=(BatchTensorDescriptor(hid_dim),),
@@ -115,21 +130,24 @@ def benchmark_throughput(
                 max_batch_size=max_batch_size,
             )
         timestamps["created_experts"] = time.perf_counter()
+
         server = Server(
-            None,
-            experts,
-            listen_on=f"{LOCALHOST}:{port}",
+            dht=server_dht,
+            expert_backends=experts,
             num_connection_handlers=num_handlers,
             device=device,
         )
         server.start()
         server.ready.wait()
+
         timestamps["server_ready"] = time.perf_counter()
         can_start.set()
 
         for client in clients:
             client.join()
+
         timestamps["clients_finished"] = time.perf_counter()
+
     except BaseException as e:
         benchmarking_failed.set()
         raise e
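
For reference, a condensed server-side sketch distilled from the benchmark above (hedged: "expert.0", the hidden dimension, and max_batch_size are arbitrary example values, and outputs_schema is spelled out explicitly here):

```python
import torch

from hivemind.dht import DHT
from hivemind.moe.server import ExpertBackend, Server
from hivemind.moe.server.layers import name_to_block
from hivemind.utils.tensor_descr import BatchTensorDescriptor

hid_dim = 512
dht = DHT(start=True)  # replaces the old listen_on endpoint: peers now find the server via its DHT

expert = torch.jit.script(name_to_block["ffn"](hid_dim))
backends = {
    "expert.0": ExpertBackend(
        name="expert.0",
        expert=expert,
        optimizer=torch.optim.Adam(expert.parameters()),
        args_schema=(BatchTensorDescriptor(hid_dim),),
        outputs_schema=BatchTensorDescriptor(hid_dim),
        max_batch_size=64,
    )
}

server = Server(dht=dht, expert_backends=backends, num_connection_handlers=1, device="cpu")
server.start()
server.ready.wait()

# hand these to clients so they can build PeerInfo(server_peer_id, server_maddrs)
print(dht.get_visible_maddrs(), dht.peer_id)
```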

+ 9 - 12
docs/user/moe.md

@@ -1,7 +1,7 @@
 # Mixture-of-Experts
 
 This tutorial covers the basics of Decentralized Mixture-of-Experts (DMoE).
-From the infrastructure standpoint, DMoE consists of two parts: experts hosted on peer devices, and a gating/routing function that assigns input to one of these experts.
+From the infrastructure standpoint, DMoE consists of two parts: experts hosted on peer devices, and client-side modules to access those experts.
 
 ## Host experts with a server
 
@@ -11,9 +11,8 @@ most of the model parameters and computation. The server can be started using ei
 for now. To host a server with default experts, run this in your shell:
 
 ```sh
-hivemind-server --expert_cls ffn --hidden_dim 512 --num_experts 5 --expert_pattern "expert.[0:5]" \
-                --listen_on 0.0.0.0:1337
-# note: if you omit listen_on and/or dht_port, they will be chosen automatically and printed to stdout.
+hivemind-server --expert_cls ffn --hidden_dim 512 --num_experts 5 --expert_pattern "expert.[0:5]"
+# note: server will listen to a random port. To specify interface & port, add --host_maddrs and --announce_maddrs
 ```
 
 <details style="margin-top:-24px; margin-bottom: 16px;">
@@ -22,8 +21,7 @@ hivemind-server --expert_cls ffn --hidden_dim 512 --num_experts 5 --expert_patte
 ```sh
 [2021/07/15 18:52:01.424][INFO][moe.server.create:156] Running DHT node on ['/ip4/127.0.0.1/tcp/42513/p2p/QmacLgRkAHSqdWYdQ8TePioMxQCNV2JeD3AUDmbVd69gNL'], initial peers = []
 [2021/07/15 18:52:01.424][INFO][moe.server.create:181] Generating 5 expert uids from pattern expert.[0:5]
-[2021/07/15 18:52:01.658][INFO][moe.server.run:233] Server started at 0.0.0.0:1337
-[2021/07/15 18:52:01.658][INFO][moe.server.run:234] Got 5 experts:
+[2021/07/15 18:52:01.658][INFO][moe.server.run:233] Server started with 5 experts
 [2021/07/15 18:52:01.658][INFO][moe.server.run:237] expert.4: FeedforwardBlock, 2100736 parameters
 [2021/07/15 18:52:01.658][INFO][moe.server.run:237] expert.0: FeedforwardBlock, 2100736 parameters
 [2021/07/15 18:52:01.659][INFO][moe.server.run:237] expert.3: FeedforwardBlock, 2100736 parameters
@@ -67,8 +65,7 @@ hivemind-server --expert_cls ffn --hidden_dim 512 --num_experts 10 --expert_patt
 ```sh
 [2021/07/15 18:53:41.700][INFO][moe.server.create:156] Running DHT node on ['/ip4/127.0.0.1/tcp/34487/p2p/QmcJ3jgbdwphLAiwGjvwrjimJJrdMyhLHf6tFj9viCFFGn'], initial peers = ['/ip4/127.0.0.1/tcp/42513/p2p/QmacLgRkAHSqdWYdQ8TePioMxQCNV2JeD3AUDmbVd69gNL']
 [2021/07/15 18:53:41.700][INFO][moe.server.create:181] Generating 10 expert uids from pattern expert.[5:250]
-[2021/07/15 18:53:42.085][INFO][moe.server.run:233] Server started at 0.0.0.0:36389
-[2021/07/15 18:53:42.086][INFO][moe.server.run:234] Got 10 experts:
+[2021/07/15 18:53:42.085][INFO][moe.server.run:233] Server started with 10 experts:
 [2021/07/15 18:53:42.086][INFO][moe.server.run:237] expert.55: FeedforwardBlock, 2100736 parameters
 [2021/07/15 18:53:42.086][INFO][moe.server.run:237] expert.173: FeedforwardBlock, 2100736 parameters
 [2021/07/15 18:53:42.086][INFO][moe.server.run:237] expert.164: FeedforwardBlock, 2100736 parameters
@@ -104,10 +101,10 @@ hivemind-server --expert_cls ffn --hidden_dim 512 --num_experts 10 --expert_patt
 
 </details>
 
-By default, the server will only accept connections from your local machine. To access it globally, you should replace
-`127.0.0.1` part from initial peers with server's IP address. Hivemind supports both ipv4 and ipv6 protocols and uses the same notation
-as [libp2p](https://docs.libp2p.io/concepts/addressing/). You can find more details on multiaddresses in the 
-[DHT tutorial](https://learning-at-home.readthedocs.io/en/latest/user/dht.html).
+By default, the server will only accept connections from your local network. 
+To enable training over the Internet (or some other network), you should set `--host_maddrs` and `--announce_maddrs`.
+These options also allow you to select IPv4/IPv6 network protocols and TCP and QUIC transport protocols.
+You can find more details in the [DHT tutorial](https://learning-at-home.readthedocs.io/en/latest/user/dht.html).
 
 ## Train the experts
 

+ 2 - 1
hivemind/averaging/averager.py

@@ -37,8 +37,8 @@ from hivemind.utils.asyncio import (
     enter_asynchronously,
     switch_to_uvloop,
 )
-from hivemind.utils.grpc import combine_from_streaming, split_for_streaming
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
+from hivemind.utils.streaming import combine_from_streaming, split_for_streaming
 from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration, get_dht_time
 
 # flavour types
@@ -709,6 +709,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                         stream = await stub.rpc_download_state(averaging_pb2.DownloadRequest())
                         current_tensor_parts, tensors = [], []
 
+                        # TODO merge this with hivemind.compression.deserialize_tensor_stream
                         async for message in aiter_with_timeout(stream, timeout=timeout):
                             if message.metadata:
                                 metadata = self.serializer.loads(message.metadata)

+ 5 - 1
hivemind/compression/__init__.py

@@ -6,4 +6,8 @@ from hivemind.compression.adaptive import PerTensorCompression, RoleAdaptiveComp
 from hivemind.compression.base import CompressionBase, CompressionInfo, NoCompression, TensorRole
 from hivemind.compression.floating import Float16Compression, ScaledFloat16Compression
 from hivemind.compression.quantization import Quantile8BitQuantization, Uniform8BitQuantization
-from hivemind.compression.serialization import deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.compression.serialization import (
+    deserialize_tensor_stream,
+    deserialize_torch_tensor,
+    serialize_torch_tensor,
+)

+ 2 - 2
hivemind/compression/adaptive.py

@@ -3,8 +3,8 @@ from typing import Mapping, Sequence, Union
 
 import torch
 
-import hivemind
 from hivemind.compression.base import CompressionBase, CompressionInfo, Key, NoCompression, TensorRole
+from hivemind.compression.serialization import deserialize_torch_tensor
 from hivemind.proto import runtime_pb2
 
 
@@ -20,7 +20,7 @@ class AdaptiveCompressionBase(CompressionBase, ABC):
         return self.choose_compression(info).compress(tensor, info=info, allow_inplace=allow_inplace)
 
     def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
-        return hivemind.compression.deserialize_torch_tensor(serialized_tensor)
+        return deserialize_torch_tensor(serialized_tensor)
 
 
 class SizeAdaptiveCompression(AdaptiveCompressionBase):

+ 1 - 1
hivemind/compression/base.py

@@ -80,7 +80,7 @@ class NoCompression(CompressionBase):
     compression_type = runtime_pb2.CompressionType.NONE
 
     def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
-        array = tensor.numpy()
+        array = tensor.detach().numpy()
         return runtime_pb2.Tensor(
             compression=self.compression_type,
             buffer=array.tobytes(),

+ 25 - 1
hivemind/compression/serialization.py

@@ -1,4 +1,6 @@
-from typing import Dict, Optional
+from __future__ import annotations
+
+from typing import AsyncIterator, Dict, Iterable, List, Optional
 
 import torch
 
@@ -6,6 +8,7 @@ from hivemind.compression.base import CompressionBase, CompressionInfo, NoCompre
 from hivemind.compression.floating import Float16Compression, ScaledFloat16Compression
 from hivemind.compression.quantization import Quantile8BitQuantization, Uniform8BitQuantization
 from hivemind.proto import runtime_pb2
+from hivemind.utils.streaming import combine_from_streaming
 
 BASE_COMPRESSION_TYPES: Dict[str, CompressionBase] = dict(
     NONE=NoCompression(),
@@ -41,3 +44,24 @@ def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Ten
     """Restore a pytorch tensor from a protobuf message"""
     compression = BASE_COMPRESSION_TYPES[runtime_pb2.CompressionType.Name(serialized_tensor.compression)]
     return compression.extract(serialized_tensor).requires_grad_(serialized_tensor.requires_grad)
+
+
+async def deserialize_tensor_stream(
+    stream: AsyncIterator[Iterable[runtime_pb2.Tensor]],
+) -> List[torch.Tensor]:
+    """Async wrapper of combine_from_streaming that combines tensors from a stream of parts and deserializes them"""
+
+    tensors = []
+    tensor_parts = []
+
+    async for parts in stream:
+        for part in parts:
+            if part.dtype and tensor_parts:
+                tensors.append(deserialize_torch_tensor(combine_from_streaming(tensor_parts)))
+                tensor_parts = []
+
+            tensor_parts.append(part)
+    if tensor_parts:
+        tensors.append(deserialize_torch_tensor(combine_from_streaming(tensor_parts)))
+
+    return tensors
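
A small round-trip sketch showing how the new `deserialize_tensor_stream` pairs with `split_for_streaming` (a hedged example: the tensor shape and the 64 KiB chunk size are arbitrary, and the fake stream below stands in for messages received over the network):

```python
import asyncio

import torch

from hivemind.compression import deserialize_tensor_stream, serialize_torch_tensor
from hivemind.utils.streaming import split_for_streaming


async def roundtrip():
    tensor = torch.randn(1000, 1000)
    # split one serialized tensor into parts small enough for a single message
    parts = list(split_for_streaming(serialize_torch_tensor(tensor), 2 ** 16))

    async def stream():
        # mimic a network stream: each incoming message carries a list of tensor parts
        for part in parts:
            yield [part]

    (restored,) = await deserialize_tensor_stream(stream())
    assert torch.allclose(tensor, restored)


asyncio.run(roundtrip())
```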

+ 3 - 2
hivemind/dht/dht.py

@@ -55,6 +55,7 @@ class DHT(mp.Process):
         **kwargs,
     ):
         self._parent_pid = os.getpid()
+        self._origin_pid = os.getpid()
         super().__init__()
 
         if not (
@@ -309,8 +310,8 @@ class DHT(mp.Process):
         Get a replica of a P2P instance used in the DHT process internally.
         The replica uses the same P2P daemon as the DHT and only works while DHT is alive.
         """
-
-        if self._p2p_replica is None:
+        if self._p2p_replica is None or self._origin_pid != os.getpid():
+            self._origin_pid = os.getpid()
             daemon_listen_maddr = self.run_coroutine(DHT._get_p2p_daemon_listen_maddr)
             self._p2p_replica = await P2P.replicate(daemon_listen_maddr)
         return self._p2p_replica

+ 0 - 1
hivemind/hivemind_cli/config.yml

@@ -1,4 +1,3 @@
-listen_on: 0.0.0.0:*
 num_experts: 16
 expert_cls: ffn
 hidden_dim: 1024

+ 6 - 3
hivemind/hivemind_cli/run_server.py

@@ -18,8 +18,7 @@ def main():
     # fmt:off
     parser = configargparse.ArgParser(default_config_files=["config.yml"])
     parser.add('-c', '--config', required=False, is_config_file=True, help='config file path')
-    parser.add_argument('--listen_on', type=str, default='0.0.0.0:*', required=False,
-                        help="'localhost' for local connections only, '0.0.0.0' for ipv4 '[::]' for ipv6")
+
     parser.add_argument('--num_experts', type=int, default=None, required=False, help="The number of experts to serve")
     parser.add_argument('--expert_pattern', type=str, default=None, required=False,
                         help='all expert uids will follow this pattern, e.g. "myexpert.[0:256].[0:1024]" will'
@@ -32,6 +31,11 @@ def main():
                         help="expert type from test_utils.layers, e.g. 'ffn', 'transformer', 'det_dropout' or 'nop'")
     parser.add_argument('--hidden_dim', type=int, default=1024, required=False, help='main dimension for expert_cls')
 
+    parser.add_argument('--host_maddrs', type=list, nargs='+', default=['/ip4/0.0.0.0/tcp/0'], required=False,
+                        help='Multiaddrs to listen for external connections from other p2p instances; default: all IPv4 and TCP: /ip4/0.0.0.0/tcp/0')
+    parser.add_argument('--announce_maddrs', type=list, nargs='+', default=None, required=False,
+                        help='Visible multiaddrs the host announces for external connections from other p2p instances')
+
     parser.add_argument('--num_handlers', type=int, default=None, required=False,
                         help='server will use this many processes to handle incoming requests')
     parser.add_argument('--min_batch_size', type=int, default=1,
@@ -49,7 +53,6 @@ def main():
     parser.add_argument('--num_total_steps', type=int, required=False, help='The total number of steps for LR schedule')
     parser.add_argument('--clip_grad_norm', type=float, required=False, help='Maximum gradient norm used for clipping')
 
-    parser.add_argument('--no_dht', action='store_true', help='if specified, the server will not be attached to a dht')
     parser.add_argument('--initial_peers', type=str, nargs='*', required=False, default=[],
                         help='multiaddrs of one or more active DHT peers (if you want to join an existing DHT)')
     parser.add_argument('--increase_file_limit', action='store_true',
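
The new `--host_maddrs` / `--announce_maddrs` flags correspond to keyword arguments accepted by the DHT that backs the server. A hedged Python sketch (the multiaddrs are illustration-only values; 203.0.113.5 is a documentation IP):

```python
from hivemind.dht import DHT

dht = DHT(
    host_maddrs=["/ip4/0.0.0.0/tcp/31337"],          # listen on all IPv4 interfaces, TCP port 31337
    announce_maddrs=["/ip4/203.0.113.5/tcp/31337"],  # address announced to peers outside the local network
    start=True,
)
print(dht.get_visible_maddrs())  # multiaddrs that other peers can pass as --initial_peers
```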

+ 23 - 10
hivemind/moe/client/beam_search.py

@@ -5,7 +5,12 @@ from functools import partial
 from typing import Deque, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
 
 from hivemind.dht import DHT, DHTExpiration, DHTNode
-from hivemind.moe.client.expert import RemoteExpert
+from hivemind.moe.client.expert import (
+    RemoteExpert,
+    RemoteExpertInfo,
+    batch_create_remote_experts,
+    create_remote_experts,
+)
 from hivemind.moe.server.expert_uid import (
     FLAT_EXPERT,
     PREFIX_PATTERN,
@@ -17,6 +22,7 @@ from hivemind.moe.server.expert_uid import (
     UidEndpoint,
     is_valid_prefix,
 )
+from hivemind.p2p import PeerInfo
 from hivemind.utils import MPFuture, get_dht_time, get_logger
 
 logger = get_logger(__name__)
@@ -145,7 +151,7 @@ class MoEBeamSearcher:
                 maybe_prefix_data = await pending_task
                 if maybe_prefix_data is not None and isinstance(maybe_prefix_data.value, dict):
                     successors = {
-                        coord: UidEndpoint(*match.value)
+                        coord: UidEndpoint(uid=match.value[0], peer_info=PeerInfo.from_tuple(match.value[1]))
                         for coord, match in maybe_prefix_data.value.items()
                         if isinstance(coord, Coordinate)
                         and isinstance(getattr(match, "value", None), list)
@@ -212,7 +218,7 @@ class MoEBeamSearcher:
         for prefix, found in dht_responses.items():
             if found and isinstance(found.value, dict):
                 successors[prefix] = {
-                    coord: UidEndpoint(*match.value)
+                    coord: UidEndpoint(uid=match.value[0], peer_info=PeerInfo.from_tuple(match.value[1]))
                     for coord, match in found.value.items()
                     if isinstance(coord, Coordinate)
                     and 0 <= coord < grid_size
@@ -230,7 +236,7 @@ class MoEBeamSearcher:
 
     def find_best_experts(
         self, grid_scores: Sequence[Sequence[float]], beam_size: int, return_future: bool = False
-    ) -> Union[List[RemoteExpert], MPFuture[RemoteExpert]]:
+    ) -> Union[List[RemoteExpert], MPFuture[List[RemoteExpert]]]:
         """
         Find and return :beam_size: active experts with highest scores, use both local cache and DHT
 
@@ -245,7 +251,7 @@ class MoEBeamSearcher:
         :returns: a list that contains *up to* k_best RemoteExpert instances
         """
         assert len(grid_scores) == len(self.grid_size) and beam_size > 0
-        return self.dht.run_coroutine(
+        result = self.dht.run_coroutine(
             partial(
                 self._find_best_experts,
                 prefix=self.uid_prefix,
@@ -258,6 +264,8 @@ class MoEBeamSearcher:
             return_future,
         )
 
+        return create_remote_experts(result, self.dht, return_future)
+
     @classmethod
     async def _find_best_experts(
         cls,
@@ -269,7 +277,7 @@ class MoEBeamSearcher:
         negative_caching: bool,
         cache_expiration: DHTExpiration,
         num_workers: Optional[int] = None,
-    ) -> List[RemoteExpert]:
+    ) -> List[RemoteExpertInfo]:
         num_workers = num_workers or min(beam_size, dht.num_workers or beam_size)
 
         # form initial beam from top-k active L1 prefixes, each row is (score, uid prefix, possible suffixes)
@@ -322,7 +330,10 @@ class MoEBeamSearcher:
                 push_and_maybe_pop(best_experts_heap, (score, uid_endpoint))
                 unique_experts.add(uid_endpoint.uid)
 
-        best_experts = [RemoteExpert(*uid_endpoint) for score, uid_endpoint in sorted(best_experts_heap, reverse=True)]
+        best_experts = [
+            RemoteExpertInfo(uid_endpoint.uid, uid_endpoint.peer_info)
+            for _, uid_endpoint in sorted(best_experts_heap, reverse=True)
+        ]
         return best_experts
 
     @staticmethod
@@ -351,7 +362,7 @@ class MoEBeamSearcher:
 
     def batch_find_best_experts(
         self, batch_grid_scores: Sequence[Sequence[Sequence[float]]], beam_size: int, return_future: bool = False
-    ) -> Union[List[List[RemoteExpert]], MPFuture]:
+    ) -> Union[List[List[RemoteExpert]], MPFuture[List[List[RemoteExpert]]]]:
         """
         Find and return :beam_size: active experts with highest scores, use both local cache and DHT
 
@@ -364,7 +375,7 @@ class MoEBeamSearcher:
         :param return_future: if set to True, returns MPFuture that can be awaited to get the actual result
         :returns: a list that contains *up to* k_best RemoteExpert instances
         """
-        return self.dht.run_coroutine(
+        result = self.dht.run_coroutine(
             partial(
                 self._batch_find_best_experts,
                 prefix=self.uid_prefix,
@@ -376,6 +387,8 @@ class MoEBeamSearcher:
             return_future,
         )
 
+        return batch_create_remote_experts(result, self.dht, return_future)
+
     @classmethod
     async def _batch_find_best_experts(
         cls,
@@ -386,7 +399,7 @@ class MoEBeamSearcher:
         beam_size: int,
         negative_caching: bool,
         num_workers: Optional[int],
-    ) -> Sequence[Sequence[RemoteExpert]]:
+    ) -> Sequence[Sequence[RemoteExpertInfo]]:
         batch_grid_scores = [
             [tuple(grid_score[i]) for grid_score in batch_grid_scores] for i in range(len(batch_grid_scores[0]))
         ]

+ 165 - 37
hivemind/moe/client/expert.py

@@ -1,43 +1,68 @@
-from typing import Any, Dict, Optional, Tuple
+from __future__ import annotations
+
+from concurrent.futures import Future
+from dataclasses import dataclass
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
 
 import torch
 import torch.nn as nn
 from torch.autograd.function import once_differentiable
 
-from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
-from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
-from hivemind.utils import Endpoint, MSGPackSerializer, nested_compare, nested_flatten, nested_pack
-from hivemind.utils.grpc import ChannelCache
+from hivemind import moe
+from hivemind.compression import deserialize_tensor_stream, deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.dht import DHT
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from hivemind.p2p import P2P, PeerInfo, StubBase
+from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
+from hivemind.proto import runtime_pb2
+from hivemind.utils.asyncio import amap_in_executor, iter_as_aiter
+from hivemind.utils.mpfuture import MPFuture
+from hivemind.utils.nested import nested_compare, nested_flatten, nested_pack
+from hivemind.utils.serializer import MSGPackSerializer
+from hivemind.utils.streaming import split_for_streaming
 
 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert
 
 
-def _get_expert_stub(endpoint: Endpoint, *extra_options: Tuple[str, Any]):
-    """Create a gRPC stub to access remote expert or use previously created stub from a process-wide cache"""
-    channel_options = (("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)) + extra_options
-    return ChannelCache.get_stub(endpoint, runtime_grpc.ConnectionHandlerStub, aio=False, options=channel_options)
+def get_expert_stub(p2p: P2P, server_peer_info: PeerInfo) -> "ConnectionHandlerStub":
+    return moe.server.connection_handler.ConnectionHandler.get_stub(p2p, server_peer_info.peer_id)
+
+
+@dataclass(frozen=True)
+class RemoteExpertInfo:
+    """A simple data class containing uid of expert and server PeerInfo"""
+
+    uid: str
+    peer_info: PeerInfo
 
 
 class RemoteExpert(nn.Module):
     """
     A simple module that runs forward/backward of an expert hosted on a remote machine.
     Works seamlessly with pytorch autograd. (this is essentially a simple RPC function)
-
     Warning: RemoteExpert currently assumes that you provide it with correct input shapes.
     Sending wrong input shapes can cause RemoteExpert to freeze indefinitely due to error in runtime.
 
-    :param uid: unique expert identifier
-    :param endpoint: network endpoint of a server that services that expert, e.g. "201.123.321.99:1337" or "[::]:8080"
+    :param expert_info: RemoteExpertInfo with uid and server PeerInfo
+    :param p2p: P2P instance connected to the running p2pd
     """
 
-    def __init__(self, uid, endpoint: Endpoint):
+    def __init__(self, expert_info: RemoteExpertInfo, p2p: P2P):
         super().__init__()
-        self.uid, self.endpoint = uid, endpoint
-        self._info = None
+        self._info, self.p2p = expert_info, p2p
+        self._rpc_info = None
 
     @property
-    def stub(self):
-        return _get_expert_stub(self.endpoint)
+    def uid(self):
+        return self._info.uid
+
+    @property
+    def server_peer_info(self):
+        return self._info.peer_info
+
+    @property
+    def stub(self) -> StubBase:
+        return get_expert_stub(self.p2p, self.server_peer_info)
 
     def forward(self, *args, **kwargs):
         """Call RemoteExpert for the specified inputs and return its output(s). Compatible with pytorch.autograd."""
@@ -52,18 +77,125 @@ class RemoteExpert(nn.Module):
             raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
 
         flat_outputs = _RemoteModuleCall.apply(DUMMY, self.uid, self.stub, self.info, *nested_flatten(forward_inputs))
+
         # Note: we send DUMMY to prevent torch from excluding expert from backward if no other inputs require grad
         return nested_pack(flat_outputs, structure=self.info["outputs_schema"])
 
     @property
     def info(self):
-        if self._info is None:
-            outputs = self.stub.info(runtime_pb2.ExpertUID(uid=self.uid))
-            self._info = MSGPackSerializer.loads(outputs.serialized_info)
-        return self._info
+        if self._rpc_info is None:
+            outputs = RemoteExpertWorker.run_coroutine(self.stub.rpc_info(runtime_pb2.ExpertUID(uid=self.uid)))
+            self._rpc_info = MSGPackSerializer.loads(outputs.serialized_info)
+        return self._rpc_info
 
     def extra_repr(self):
-        return f"uid={self.uid}, endpoint={self.endpoint}"
+        return f"uid={self.uid}, server_peer_info={self.server_peer_info}"
+
+
+def _create_remote_experts(infos: Sequence[Optional[RemoteExpertInfo]], p2p: P2P) -> List[Optional[RemoteExpert]]:
+    experts: List[Optional[RemoteExpert]] = []
+    for info in infos:
+        if info is not None:
+            experts.append(RemoteExpert(info, p2p))
+        else:
+            experts.append(None)
+    return experts
+
+
+def create_remote_experts(
+    infos: Union[Sequence[Optional[RemoteExpertInfo]], MPFuture], dht: DHT, return_future: bool = False
+) -> Union[List[Optional[RemoteExpert]], Future]:
+    if return_future:
+
+        async def _unpack(infos_future: MPFuture, dht: DHT):
+            p2p = await dht.replicate_p2p()
+            return _create_remote_experts(await infos_future, p2p)
+
+        return RemoteExpertWorker.run_coroutine(_unpack(infos, dht), return_future)
+
+    p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
+    return _create_remote_experts(infos, p2p)
+
+
+def batch_create_remote_experts(
+    infos: Union[Sequence[Sequence[Optional[RemoteExpertInfo]]], MPFuture],
+    dht: DHT,
+    return_future: bool = False,
+) -> Union[List[List[Optional[RemoteExpert]]], Future]:
+    if return_future:
+
+        async def _unpack(infos_future: MPFuture, dht: DHT):
+            p2p = await dht.replicate_p2p()
+            return [_create_remote_experts(i, p2p) for i in await infos_future]
+
+        return RemoteExpertWorker.run_coroutine(_unpack(infos, dht), return_future)
+
+    return [create_remote_experts(exps, dht) for exps in infos]
+
+
+async def _backward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
+    split = (part for tensor in serialized_tensors for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE))
+
+    grad_inputs = await stub.rpc_backward_stream(
+        amap_in_executor(
+            lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor]),
+            iter_as_aiter(split),
+        ),
+    )
+    tensors_stream = amap_in_executor(lambda msg: msg.tensors, grad_inputs)
+    return await deserialize_tensor_stream(tensors_stream)
+
+
+async def _backward_unary(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
+    grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
+        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors))
+    )
+    return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]
+
+
+async def expert_backward(
+    uid: str, inputs_and_grads: Sequence[torch.Tensor], serialized_tensors: Iterable[runtime_pb2.Tensor], stub
+) -> List[torch.Tensor]:
+    size = 0
+    for t in inputs_and_grads:
+        size += t.element_size() * t.nelement()
+        if size > DEFAULT_MAX_MSG_SIZE:
+            return await _backward_stream(uid, serialized_tensors, stub)
+    else:
+        return await _backward_unary(uid, serialized_tensors, stub)
+
+
+async def _forward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
+    split = (p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE))
+
+    outputs = await stub.rpc_forward_stream(
+        amap_in_executor(
+            lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor]),
+            iter_as_aiter(split),
+        ),
+    )
+
+    tensors_stream = amap_in_executor(lambda msg: msg.tensors, outputs)
+    return await deserialize_tensor_stream(tensors_stream)
+
+
+async def _forward_unary(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
+    outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
+        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors))
+    )
+    return [deserialize_torch_tensor(t) for t in outputs.tensors]
+
+
+async def expert_forward(
+    uid: str, inputs: Sequence[torch.Tensor], serialized_tensors: Iterable[runtime_pb2.Tensor], stub
+) -> List[torch.Tensor]:
+    size = 0
+    for t in inputs:
+        size += t.element_size() * t.nelement()
+        if size > DEFAULT_MAX_MSG_SIZE:
+            return await _forward_stream(uid, serialized_tensors, stub)
+    else:
+        return await _forward_unary(uid, serialized_tensors, stub)
 
 
 class _RemoteModuleCall(torch.autograd.Function):
@@ -74,7 +206,7 @@ class _RemoteModuleCall(torch.autograd.Function):
         ctx,
         dummy: torch.Tensor,
         uid: str,
-        stub: runtime_grpc.ConnectionHandlerStub,
+        stub: "ConnectionHandlerStub",
         info: Dict[str, Any],
         *inputs: torch.Tensor,
     ) -> Tuple[torch.Tensor, ...]:
@@ -83,15 +215,11 @@ class _RemoteModuleCall(torch.autograd.Function):
         inputs = tuple(tensor.cpu().detach() for tensor in inputs)
         ctx.uid, ctx.stub, ctx.info = uid, stub, info
         ctx.save_for_backward(*inputs)
-
-        serialized_tensors = [
-            serialize_torch_tensor(inp, proto.compression)
-            for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))
-        ]
-
-        outputs = stub.forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
-
-        deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
+        serialized_tensors = (
+            serialize_torch_tensor(tensor, proto.compression)
+            for tensor, proto in zip(inputs, nested_flatten(info["forward_schema"]))
+        )
+        deserialized_outputs = RemoteExpertWorker.run_coroutine(expert_forward(uid, inputs, serialized_tensors, stub))
 
         return tuple(deserialized_outputs)
 
@@ -101,12 +229,12 @@ class _RemoteModuleCall(torch.autograd.Function):
         grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
         inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs_cpu)))
         backward_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))
-        serialized_tensors = [
+        serialized_tensors = (
             serialize_torch_tensor(tensor, proto.compression)
             for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
-        ]
-
-        grad_inputs = ctx.stub.backward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
+        )
+        deserialized_grad_inputs = RemoteExpertWorker.run_coroutine(
+            expert_backward(ctx.uid, inputs_and_grad_outputs, serialized_tensors, ctx.stub)
+        )
 
-        deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
         return (DUMMY, None, None, None, *deserialized_grad_inputs)
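
An illustrative sketch of the unary-vs-streaming dispatch rule used by `expert_forward` / `expert_backward` above. The tensor shapes are arbitrary; the only assumption is that the default message size limit lies between roughly 1 KiB and 64 MiB:

```python
import torch

from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE


def should_stream(tensors) -> bool:
    # same rule as expert_forward/expert_backward: if the raw payload exceeds the
    # single-message size limit, use the streaming RPC; otherwise send one unary request
    return sum(t.element_size() * t.nelement() for t in tensors) > DEFAULT_MAX_MSG_SIZE


print(should_stream([torch.randn(16, 16)]))      # False: ~1 KiB fits into a single message
print(should_stream([torch.randn(4096, 4096)]))  # True: ~64 MiB goes through the streaming RPC
```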

+ 28 - 24
hivemind/moe/client/moe.py

@@ -1,20 +1,21 @@
 from __future__ import annotations
 
 import time
+from concurrent.futures import Future
 from queue import Empty, Queue
 from typing import Any, Dict, List, Optional, Tuple
 
-import grpc
 import torch
 import torch.nn as nn
 from torch.autograd.function import once_differentiable
 
-from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.compression import serialize_torch_tensor
 from hivemind.dht import DHT
 from hivemind.moe.client.beam_search import MoEBeamSearcher
-from hivemind.moe.client.expert import DUMMY, RemoteExpert, _get_expert_stub
+from hivemind.moe.client.expert import DUMMY, RemoteExpert, expert_backward, expert_forward, get_expert_stub
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.moe.server.expert_uid import UID_DELIMITER
-from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
+from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError
 from hivemind.utils import nested_flatten, nested_map, nested_pack
 from hivemind.utils.logging import get_logger
 
@@ -104,7 +105,7 @@ class RemoteMixtureOfExperts(nn.Module):
                     "No responding experts found during beam search. Check that UID prefixes and "
                     "the grid size are consistent with running Server instances."
                 )
-            except grpc.RpcError as e:
+            except P2PDaemonError as e:
                 logger.warning(f"Failed to get RemoteMixtureOfExperts.output_shape: {e}")
 
         expert_mask, *expert_outputs = _RemoteCallMany.apply(
@@ -178,7 +179,7 @@ class RemoteMixtureOfExperts(nn.Module):
             # grab some expert to set ensemble output shape
             proj_device = self.proj.weight.device
             dummy_scores_concat = self.proj(torch.randn(1, self.proj.in_features, device=proj_device))
-            dummy_scores = dummy_scores_concat.cpu().split_with_sizes(self.beam_search.grid_size, dim=-1)
+            dummy_scores = dummy_scores_concat.cpu().detach().split_with_sizes(self.beam_search.grid_size, dim=-1)
             dummy_experts = self.beam_search.find_best_experts(dummy_scores, beam_size=1)
             self._expert_info = dummy_experts[0].info
         return self._expert_info
@@ -223,15 +224,18 @@ class _RemoteCallMany(torch.autograd.Function):
         assert len(experts_per_sample) == len(flat_inputs_per_sample) == num_samples
 
         # dispatch tasks to all remote experts collect responses
-        pending_tasks: Dict[grpc.Future, Tuple[int, int]] = {}
+        pending_tasks: Dict[Future, Tuple[int, int]] = {}
         for i in range(num_samples):
             for j, expert in enumerate(experts_per_sample[i]):
-                input_tensors = [
+                stub = get_expert_stub(expert.p2p, expert.server_peer_info)
+                serialized_tensors = (
                     serialize_torch_tensor(tensor, proto.compression)
                     for tensor, proto in zip(flat_inputs_per_sample[i], nested_flatten(info["forward_schema"]))
-                ]
-                stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint)
-                new_task = stub.forward.future(runtime_pb2.ExpertRequest(uid=expert.uid, tensors=input_tensors))
+                )
+                new_task = RemoteExpertWorker.run_coroutine(
+                    expert_forward(expert.uid, flat_inputs_per_sample[i], serialized_tensors, stub),
+                    return_future=True,
+                )
                 pending_tasks[new_task] = (i, j)
 
         responded_inds, alive_flat_outputs = cls._collect_responses(
@@ -316,14 +320,16 @@ class _RemoteCallMany(torch.autograd.Function):
         for i, j, inputs_ij, grad_outputs_ij in zip(
             alive_ii.cpu().numpy(), alive_jj.cpu().numpy(), inputs_per_expert, grad_outputs_per_expert
         ):
-            expert = expert_per_sample[i.item()][j.item()]
-            stub = _get_expert_stub(expert.endpoint)
+            expert: RemoteExpert = expert_per_sample[i.item()][j.item()]
+            stub = get_expert_stub(expert.p2p, expert.server_peer_info)
             inputs_and_grad_outputs = tuple(nested_flatten((inputs_ij, grad_outputs_ij)))
-            tensors_serialized = [
+            serialized_tensors = (
                 serialize_torch_tensor(tensor, proto.compression)
                 for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
-            ]
-            new_task = stub.backward.future(runtime_pb2.ExpertRequest(uid=expert.uid, tensors=tensors_serialized))
+            )
+            new_task = RemoteExpertWorker.run_coroutine(
+                expert_backward(expert.uid, inputs_and_grad_outputs, serialized_tensors, stub), return_future=True
+            )
             pending_tasks[new_task] = (i, j)
 
         survivor_inds, survivor_grad_inputs = cls._collect_responses(
@@ -358,7 +364,7 @@ class _RemoteCallMany(torch.autograd.Function):
 
     @staticmethod
     def _collect_responses(
-        task_to_indices: Dict[grpc.Future, Tuple[int, int]],
+        task_to_indices: Dict[Future, Tuple[int, int]],
         num_samples: int,
         k_min: int,
         timeout_total: Optional[float],
@@ -408,17 +414,15 @@ class _RemoteCallMany(torch.autograd.Function):
         return finished_indices, finished_outputs
 
 
-def _process_dispatched_task(task: grpc.Future, detect_anomalies: bool) -> Optional[Tuple[torch.Tensor]]:
+def _process_dispatched_task(task: Future, detect_anomalies: bool) -> Optional[Tuple[torch.Tensor]]:
     if task.exception() or task.cancelled():
         logger.warning(f"Task {task} failed: {type(task.exception())}")
         return None
 
-    deserialized_outputs = []
-    for tensor in task.result().tensors:
-        deserialized_tensor = deserialize_torch_tensor(tensor)
-        if detect_anomalies and not deserialized_tensor.isfinite().all():
+    outputs = task.result()
+    for tensor in outputs:
+        if detect_anomalies and not tensor.isfinite().all():
             logger.error(f"Task {task} failed: output tensor contains nan/inf values")
             return None
-        deserialized_outputs.append(deserialized_tensor)
 
-    return tuple(deserialized_outputs)
+    return outputs

+ 48 - 0
hivemind/moe/client/remote_expert_worker.py

@@ -0,0 +1,48 @@
+import os
+from concurrent.futures import Future
+from queue import Queue
+from threading import Thread
+from typing import Awaitable, Optional
+
+from hivemind.utils import switch_to_uvloop
+
+
+class RemoteExpertWorker:
+    """Local thread for managing async tasks related to RemoteExpert"""
+
+    _task_queue: Queue = Queue()
+    _event_thread: Optional[Thread] = None
+    _pid: int = -1
+
+    @classmethod
+    def _run(cls):
+        loop = switch_to_uvloop()
+
+        async def receive_tasks():
+            while True:
+                cor, future = cls._task_queue.get()
+                try:
+                    result = await cor
+                except Exception as e:
+                    future.set_exception(e)
+                    continue
+                if not future.cancelled():
+                    future.set_result(result)
+
+        loop.run_until_complete(receive_tasks())
+
+    @classmethod
+    def run_coroutine(cls, coro: Awaitable, return_future: bool = False):
+        if cls._event_thread is None or cls._pid != os.getpid():
+            cls._pid = os.getpid()
+            cls._event_thread = Thread(target=cls._run, daemon=True)
+            cls._event_thread.start()
+
+        future = Future()
+        cls._task_queue.put((coro, future))
+
+        if return_future:
+            return future
+
+        result = future.result()
+        return result
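
A brief usage sketch for the new worker thread (hedged: `P2P.create()` with default arguments spawns a local p2p daemon, and `p2p.shutdown()` is used here only as an example of an awaitable to schedule):

```python
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.p2p import P2P

# run an async call from synchronous code, blocking until the result is ready
p2p = RemoteExpertWorker.run_coroutine(P2P.create())

# or get a concurrent.futures.Future right away and collect the result later
shutdown_future = RemoteExpertWorker.run_coroutine(p2p.shutdown(), return_future=True)
shutdown_future.result()
```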

+ 2 - 2
hivemind/moe/client/switch_moe.py

@@ -2,12 +2,12 @@ from __future__ import annotations
 
 from typing import List, Tuple
 
-import grpc
 import torch
 
 from hivemind.moe.client.expert import DUMMY, RemoteExpert
 from hivemind.moe.client.moe import RemoteMixtureOfExperts, _RemoteCallMany
 from hivemind.moe.server.expert_uid import UID_DELIMITER
+from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError
 from hivemind.utils import nested_flatten, nested_pack
 from hivemind.utils.logging import get_logger
 
@@ -110,7 +110,7 @@ class RemoteSwitchMixtureOfExperts(RemoteMixtureOfExperts):
                     "No responding experts found during beam search. Check that UID prefixes and "
                     "the grid size are consistent with running Server instances."
                 )
-            except grpc.RpcError as e:
+            except P2PDaemonError as e:
                 logger.warning(f"Failed to get RemoteSwitchMixtureOfExperts.output_shape: {e}")
 
         expert_mask, *expert_outputs = _RemoteCallMany.apply(

+ 102 - 48
hivemind/moe/server/connection_handler.py

@@ -1,82 +1,136 @@
+import asyncio
 import multiprocessing as mp
-import os
-from typing import Dict
+from typing import AsyncIterator, Dict, Iterable, List, Optional, Tuple, Union
 
-import grpc
 import torch
 
-from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.compression import deserialize_tensor_stream, deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.dht import DHT
 from hivemind.moe.server.expert_backend import ExpertBackend
-from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
-from hivemind.utils import Endpoint, MSGPackSerializer, get_logger, nested_flatten
-from hivemind.utils.asyncio import switch_to_uvloop
-from hivemind.utils.grpc import GRPC_KEEPALIVE_OPTIONS
+from hivemind.moe.server.task_pool import TaskPool
+from hivemind.p2p import P2PContext, ServicerBase
+from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE, P2P
+from hivemind.proto import runtime_pb2
+from hivemind.utils import MPFuture, MSGPackSerializer, as_aiter, get_logger, nested_flatten
+from hivemind.utils.asyncio import amap_in_executor, switch_to_uvloop
+from hivemind.utils.streaming import split_for_streaming
+from hivemind.utils.tensor_descr import BatchTensorDescriptor
 
 logger = get_logger(__name__)
 
 
-class ConnectionHandler(mp.context.ForkProcess):
+class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
     """
     A process that accepts incoming requests to experts and submits them into the corresponding TaskPool.
 
-    :note: ConnectionHandler is designed so as to allow using multiple handler processes for the same port.
-    :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
+    :note: ConnectionHandler is designed so as to allow using multiple handler processes for the same port
+    :param dht: a running hivemind.dht.DHT, used to let other peers connect to this one
     :param experts: a dict [UID -> ExpertBackend] with all active experts
     """
 
-    def __init__(self, listen_on: Endpoint, experts: Dict[str, ExpertBackend]):
+    def __init__(self, dht: DHT, experts: Dict[str, ExpertBackend]):
         super().__init__()
-        self.listen_on, self.experts = listen_on, experts
-        self.ready = mp.Event()
+        self.dht, self.experts = dht, experts
+        self._p2p: Optional[P2P] = None
+
+        self.ready = MPFuture()
 
     def run(self):
         torch.set_num_threads(1)
         loop = switch_to_uvloop()
 
         async def _run():
-            grpc.aio.init_grpc_aio()
-            logger.debug(f"Starting, pid {os.getpid()}")
-            server = grpc.aio.server(
-                options=GRPC_KEEPALIVE_OPTIONS
-                + (
-                    ("grpc.so_reuseport", 1),
-                    ("grpc.max_send_message_length", -1),
-                    ("grpc.max_receive_message_length", -1),
-                )
-            )
-            runtime_grpc.add_ConnectionHandlerServicer_to_server(self, server)
-
-            found_port = server.add_insecure_port(self.listen_on)
-            assert found_port != 0, f"Failed to listen to {self.listen_on}"
-
-            await server.start()
-            self.ready.set()
-            await server.wait_for_termination()
-            logger.debug(f"ConnectionHandler terminated: (pid={os.getpid()})")
+            try:
+                self._p2p = await self.dht.replicate_p2p()
+                await self.add_p2p_handlers(self._p2p, balanced=True)
+
+                # wait forever
+                await asyncio.Future()
+
+            except Exception as e:
+                self.ready.set_exception(e)
+                return
+
+        self.ready.set_result(None)
 
         try:
             loop.run_until_complete(_run())
         except KeyboardInterrupt:
             logger.debug("Caught KeyboardInterrupt, shutting down")
 
-    async def info(self, request: runtime_pb2.ExpertUID, context: grpc.ServicerContext):
+    async def rpc_info(self, request: runtime_pb2.ExpertUID, context: P2PContext) -> runtime_pb2.ExpertInfo:
         return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(self.experts[request.uid].get_info()))
 
-    async def forward(self, request: runtime_pb2.ExpertRequest, context: grpc.ServicerContext):
-        inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-        future = self.experts[request.uid].forward_pool.submit_task(*inputs)
-        serialized_response = [
-            serialize_torch_tensor(tensor, proto.compression, allow_inplace=True)
-            for tensor, proto in zip(await future, nested_flatten(self.experts[request.uid].outputs_schema))
+    async def _gather_inputs(
+        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
+    ) -> Tuple[str, List[torch.Tensor]]:
+        expert_uid = None
+
+        def _unpack(req: runtime_pb2.ExpertRequest) -> Iterable[runtime_pb2.Tensor]:
+            nonlocal expert_uid
+
+            if expert_uid is None:
+                expert_uid = req.uid
+            elif expert_uid != req.uid:
+                raise ValueError("Expert uids differ in one request")
+
+            return req.tensors
+
+        tensors_stream = amap_in_executor(_unpack, requests)
+        inputs = await deserialize_tensor_stream(tensors_stream)
+        return expert_uid, inputs
+
+    async def _process_inputs(
+        self,
+        inputs: List[torch.Tensor],
+        pool: TaskPool,
+        schema: Union[BatchTensorDescriptor, Tuple[BatchTensorDescriptor, ...]],
+    ) -> List[runtime_pb2.Tensor]:
+        return [
+            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+            for result, proto in zip(await pool.submit_task(*inputs), nested_flatten(schema))
         ]
 
-        return runtime_pb2.ExpertResponse(tensors=serialized_response)
+    async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
+        inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+        expert = self.experts[request.uid]
+        return runtime_pb2.ExpertResponse(
+            tensors=await self._process_inputs(inputs, expert.forward_pool, expert.outputs_schema)
+        )
+
+    async def rpc_forward_stream(
+        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
+    ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
+        uid, inputs = await self._gather_inputs(requests, context)
+        expert = self.experts[uid]
+        output_split = [
+            part
+            for tensor in await self._process_inputs(inputs, expert.forward_pool, expert.outputs_schema)
+            for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
+        ]
 
-    async def backward(self, request: runtime_pb2.ExpertRequest, context: grpc.ServicerContext):
-        inputs_and_grad_outputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-        future = self.experts[request.uid].backward_pool.submit_task(*inputs_and_grad_outputs)
-        serialized_response = [
-            serialize_torch_tensor(tensor, proto.compression, allow_inplace=True)
-            for tensor, proto in zip(await future, nested_flatten(self.experts[request.uid].grad_inputs_schema))
+        async for part in as_aiter(*output_split):
+            yield runtime_pb2.ExpertResponse(tensors=[part])
+
+    async def rpc_backward(
+        self, request: runtime_pb2.ExpertRequest, context: P2PContext
+    ) -> runtime_pb2.ExpertResponse:
+        inputs_and_grads = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+        expert = self.experts[request.uid]
+        return runtime_pb2.ExpertResponse(
+            tensors=await self._process_inputs(inputs_and_grads, expert.backward_pool, expert.grad_inputs_schema)
+        )
+
+    async def rpc_backward_stream(
+        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
+    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
+        uid, inputs_and_grads = await self._gather_inputs(requests, context)
+        expert = self.experts[uid]
+        output_split = [
+            part
+            for tensor in await self._process_inputs(inputs_and_grads, expert.backward_pool, expert.grad_inputs_schema)
+            for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
         ]
-        return runtime_pb2.ExpertResponse(tensors=serialized_response)
+
+        async for part in as_aiter(*output_split):
+            yield runtime_pb2.ExpertResponse(tensors=[part])

+ 22 - 20
hivemind/moe/server/dht_handler.py

@@ -1,9 +1,9 @@
 import threading
 from functools import partial
-from typing import Dict, List, Optional, Sequence, Tuple
+from typing import Dict, List, Optional, Sequence, Tuple, Union
 
 from hivemind.dht import DHT, DHTExpiration, DHTNode, DHTValue
-from hivemind.moe.client.expert import RemoteExpert
+from hivemind.moe.client.expert import RemoteExpert, RemoteExpertInfo, create_remote_experts
 from hivemind.moe.server.expert_uid import (
     FLAT_EXPERT,
     UID_DELIMITER,
@@ -14,33 +14,31 @@ from hivemind.moe.server.expert_uid import (
     is_valid_uid,
     split_uid,
 )
-from hivemind.utils import Endpoint, get_dht_time, get_port
+from hivemind.p2p import PeerID, PeerInfo
+from hivemind.utils import MPFuture, get_dht_time
 
 
 class DHTHandlerThread(threading.Thread):
-    def __init__(self, experts, dht: DHT, endpoint: Endpoint, update_period: int = 5, **kwargs):
+    def __init__(self, experts, dht: DHT, update_period: int = 5, **kwargs):
         super().__init__(**kwargs)
-        assert get_port(endpoint) is not None
-        self.endpoint = endpoint
         self.experts = experts
         self.dht = dht
         self.update_period = update_period
         self.stop = threading.Event()
 
     def run(self) -> None:
-        declare_experts(self.dht, self.experts.keys(), self.endpoint)
+        declare_experts(self.dht, self.experts.keys())
         while not self.stop.wait(self.update_period):
-            declare_experts(self.dht, self.experts.keys(), self.endpoint)
+            declare_experts(self.dht, self.experts.keys())
 
 
 def declare_experts(
-    dht: DHT, uids: Sequence[ExpertUID], endpoint: Endpoint, expiration: DHTExpiration = 300, wait: bool = True
-) -> Dict[ExpertUID, bool]:
+    dht: DHT, uids: Sequence[ExpertUID], expiration: DHTExpiration = 300, wait: bool = True
+) -> Union[Dict[ExpertUID, bool], MPFuture[Dict[ExpertUID, bool]]]:
     """
     Make experts visible to all DHT peers; update timestamps if declared previously.
 
     :param uids: a list of expert ids to update
-    :param endpoint: endpoint that serves these experts, usually your server endpoint (e.g. "201.111.222.333:1337")
     :param wait: if True, awaits for declaration to finish, otherwise runs in background
     :param expiration: experts will be visible for this many seconds
     :returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
@@ -48,23 +46,25 @@ def declare_experts(
     assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
     for uid in uids:
         assert is_valid_uid(uid), f"{uid} is not a valid expert uid. All uids must follow {UID_PATTERN.pattern}"
+    addrs = tuple(str(a.decapsulate("/p2p/" + a.get("p2p"))) for a in dht.get_visible_maddrs())
     return dht.run_coroutine(
-        partial(_declare_experts, uids=list(uids), endpoint=endpoint, expiration=expiration), return_future=not wait
+        partial(_declare_experts, uids=list(uids), peer_id=dht.peer_id, addrs=addrs, expiration=expiration),
+        return_future=not wait,
     )
 
 
 async def _declare_experts(
-    dht: DHT, node: DHTNode, uids: List[ExpertUID], endpoint: Endpoint, expiration: DHTExpiration
+    dht: DHT, node: DHTNode, uids: List[ExpertUID], peer_id: PeerID, addrs: Tuple[str], expiration: DHTExpiration
 ) -> Dict[ExpertUID, bool]:
     num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
     expiration_time = get_dht_time() + expiration
     data_to_store: Dict[Tuple[ExpertPrefix, Optional[Coordinate]], DHTValue] = {}
     for uid in uids:
-        data_to_store[uid, None] = endpoint
+        data_to_store[uid, None] = (peer_id.to_base58(), addrs)
         prefix = uid if uid.count(UID_DELIMITER) > 1 else f"{uid}{UID_DELIMITER}{FLAT_EXPERT}"
         for i in range(prefix.count(UID_DELIMITER) - 1):
             prefix, last_coord = split_uid(prefix)
-            data_to_store[prefix, last_coord] = [uid, endpoint]
+            data_to_store[prefix, last_coord] = [uid, (peer_id.to_base58(), addrs)]
 
     keys, maybe_subkeys, values = zip(*((key, subkey, value) for (key, subkey), value in data_to_store.items()))
     store_ok = await node.store_many(keys, values, expiration_time, subkeys=maybe_subkeys, num_workers=num_workers)
@@ -73,7 +73,7 @@ async def _declare_experts(
 
 def get_experts(
     dht: DHT, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration] = None, return_future: bool = False
-) -> List[Optional[RemoteExpert]]:
+) -> Union[List[Optional[RemoteExpert]], MPFuture[List[Optional[RemoteExpert]]]]:
     """
     :param uids: find experts with these ids from across the DHT
     :param expiration_time: if specified, return experts that expire no sooner than this (based on get_dht_time)
@@ -81,12 +81,13 @@ def get_experts(
     :returns: a list of [RemoteExpert if found else None]
     """
     assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
-    return dht.run_coroutine(partial(_get_experts, uids=list(uids), expiration_time=expiration_time), return_future)
+    result = dht.run_coroutine(partial(_get_experts, uids=list(uids), expiration_time=expiration_time), return_future)
+    return create_remote_experts(result, dht, return_future)
 
 
 async def _get_experts(
     dht: DHT, node: DHTNode, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration]
-) -> List[Optional[RemoteExpert]]:
+) -> List[Optional[RemoteExpertInfo]]:
     if expiration_time is None:
         expiration_time = get_dht_time()
     num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
@@ -94,6 +95,7 @@ async def _get_experts(
 
     experts: List[Optional[RemoteExpert]] = [None] * len(uids)
     for i, uid in enumerate(uids):
-        if found[uid] is not None and isinstance(found[uid].value, Endpoint):
-            experts[i] = RemoteExpert(uid, found[uid].value)
+        expert_info_for_uid = found[uid]
+        if expert_info_for_uid is not None and isinstance(expert_info_for_uid.value, tuple):
+            experts[i] = RemoteExpertInfo(uid, PeerInfo.from_tuple(expert_info_for_uid.value))
     return experts

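With the endpoint argument gone, declaring and finding experts only requires DHT instances; a minimal sketch of the updated flow (expert uids are placeholders):

from hivemind.dht import DHT
from hivemind.moe.server import declare_experts, get_experts

server_dht = DHT(start=True)
client_dht = DHT(initial_peers=server_dht.get_visible_maddrs(), start=True)

# the server's peer id and multiaddrs are stored in the DHT instead of an explicit endpoint
declare_experts(server_dht, ["ffn.0", "ffn.1"])
(expert,) = get_experts(client_dht, ["ffn.0"])  # a RemoteExpert, or None if not found
assert expert.server_peer_info.peer_id == server_dht.peer_id
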
+ 2 - 2
hivemind/moe/server/expert_uid.py

@@ -1,10 +1,10 @@
 import re
 from typing import NamedTuple, Tuple, Union
 
-from hivemind.utils import Endpoint
+from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerInfo
 
 ExpertUID, ExpertPrefix, Coordinate, Score = str, str, int, float
-UidEndpoint = NamedTuple("UidEndpoint", [("uid", ExpertUID), ("endpoint", Endpoint)])
+UidEndpoint = NamedTuple("UidEndpoint", [("uid", ExpertUID), ("peer_info", PeerInfo)])
 UID_DELIMITER = "."  # when declaring experts, DHT store all prefixes of that expert's uid, split over this prefix
 FLAT_EXPERT = -1  # grid prefix reserved for storing 1d expert uids. Used to speed up find_best_experts in 1d case.
 UID_PATTERN = re.compile("^(([^.])+)([.](?:[0]|([1-9]([0-9]*))))+$")  # e.g. ffn_expert.98.76.54 - prefix + some dims

+ 26 - 40
hivemind/moe/server/server.py

@@ -24,9 +24,9 @@ from hivemind.moe.server.layers import (
     schedule_name_to_scheduler,
 )
 from hivemind.moe.server.runtime import Runtime
+from hivemind.p2p import PeerInfo
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.logging import get_logger
-from hivemind.utils.networking import Endpoint, get_free_port, get_port, replace_port
 from hivemind.utils.tensor_descr import DUMMY_BATCH_SIZE, BatchTensorDescriptor
 
 logger = get_logger(__name__)
@@ -41,10 +41,8 @@ class Server(threading.Thread):
      - processes incoming forward/backward requests via Runtime (created by the server)
      - publishes updates to expert status every :update_period: seconds
 
-    :type dht: DHT or None. Server with dht=None will NOT be visible from DHT,
-     but it will still support accessing experts directly with RemoteExpert(uid=UID, endpoint="IPADDR:PORT").
+    :type dht: an instance of hivemind.DHT. Server will use DHT for all network interactions.
     :param expert_backends: dict{expert uid (str) : ExpertBackend} for all expert hosted by this server.
-    :param listen_on: server's dht address that determines how it can be accessed. Address and (optional) port
     :param num_connection_handlers: maximum number of simultaneous requests. Please note that the default value of 1
         if too small for normal functioning, we recommend 4 handlers per expert backend.
     :param update_period: how often will server attempt to publish its state (i.e. experts) to the DHT;
@@ -55,9 +53,8 @@ class Server(threading.Thread):
 
     def __init__(
         self,
-        dht: Optional[DHT],
+        dht: DHT,
         expert_backends: Dict[str, ExpertBackend],
-        listen_on: Endpoint = "0.0.0.0:*",
         num_connection_handlers: int = 1,
         update_period: int = 30,
         start=False,
@@ -66,22 +63,18 @@ class Server(threading.Thread):
     ):
         super().__init__()
         self.dht, self.experts, self.update_period = dht, expert_backends, update_period
-        if get_port(listen_on) is None:
-            listen_on = replace_port(listen_on, new_port=get_free_port())
-        self.listen_on, self.port = listen_on, get_port(listen_on)
 
-        self.conn_handlers = [ConnectionHandler(listen_on, self.experts) for _ in range(num_connection_handlers)]
+        self.conn_handlers = [ConnectionHandler(dht, self.experts) for _ in range(num_connection_handlers)]
         if checkpoint_dir is not None:
             self.checkpoint_saver = CheckpointSaver(expert_backends, checkpoint_dir, update_period)
         else:
             self.checkpoint_saver = None
         self.runtime = Runtime(self.experts, **kwargs)
 
-        if self.dht and self.experts:
+        if self.experts:
             self.dht_handler_thread = DHTHandlerThread(
                 experts=self.experts,
                 dht=self.dht,
-                endpoint=self.listen_on,
                 update_period=self.update_period,
                 daemon=True,
             )
@@ -92,7 +85,6 @@ class Server(threading.Thread):
     @classmethod
     def create(
         cls,
-        listen_on="0.0.0.0:*",
         num_experts: int = None,
         expert_uids: str = None,
         expert_pattern: str = None,
@@ -107,7 +99,6 @@ class Server(threading.Thread):
         min_batch_size=1,
         max_batch_size=4096,
         device=None,
-        no_dht=False,
         initial_peers=(),
         checkpoint_dir: Optional[Path] = None,
         compression=CompressionType.NONE,
@@ -115,10 +106,11 @@ class Server(threading.Thread):
         custom_module_path=None,
         *,
         start: bool,
+        **kwargs,
     ) -> Server:
         """
         Instantiate a server with several identical experts. See argparse comments below for details
-        :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
+
         :param num_experts: run this many identical experts
         :param expert_pattern: a string pattern or a list of expert uids,  example: myprefix.[0:32].[0:256]\
            means "sample random experts between myprefix.0.0 and myprefix.255.255;
@@ -136,7 +128,6 @@ class Server(threading.Thread):
         :param num_total_steps: the total number of steps for LR schedule
         :param clip_grad_norm: maximum gradient norm used for clipping
 
-        :param no_dht: if specified, the server will not be attached to a dht
         :param initial_peers: multiaddrs of one or more active DHT peers (if you want to join an existing DHT)
 
         :param checkpoint_dir: directory to save and load expert checkpoints
@@ -147,17 +138,15 @@ class Server(threading.Thread):
 
         :param start: if True, starts server right away and returns when server is ready for requests
         :param stats_report_interval: interval between two reports of batch processing performance statistics
+        :param kwargs: any other params will be forwarded to DHT upon creation
         """
         if custom_module_path is not None:
             add_custom_models_from_file(custom_module_path)
         assert expert_cls in name_to_block
 
-        if no_dht:
-            dht = None
-        else:
-            dht = DHT(initial_peers=initial_peers, start=True)
-            visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
-            logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
+        dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
+        visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
+        logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
 
         assert (expert_pattern is None and num_experts is None and expert_uids is not None) or (
             num_experts is not None and expert_uids is None
@@ -221,7 +210,6 @@ class Server(threading.Thread):
         return cls(
             dht,
             experts,
-            listen_on=listen_on,
             num_connection_handlers=num_handlers,
             device=device,
             checkpoint_dir=checkpoint_dir,
@@ -234,25 +222,24 @@ class Server(threading.Thread):
         Starts Server in the current thread. Initializes dht if necessary, starts connection handlers,
         runs Runtime (self.runtime) to process incoming requests.
         """
-        logger.info(f"Server started at {self.listen_on}")
-        logger.info(f"Got {len(self.experts)} experts:")
+        logger.info(f"Server started with {len(self.experts)} experts:")
         for expert_name, backend in self.experts.items():
             num_parameters = sum(p.numel() for p in backend.expert.parameters() if p.requires_grad)
             logger.info(f"{expert_name}: {backend.expert.__class__.__name__}, {num_parameters} parameters")
 
-        if self.dht:
-            if not self.dht.is_alive():
-                self.dht.run_in_background(await_ready=True)
+        if not self.dht.is_alive():
+            self.dht.run_in_background(await_ready=True)
+
+        if self.experts:
+            self.dht_handler_thread.start()
 
-            if self.experts:
-                self.dht_handler_thread.start()
         if self.checkpoint_saver is not None:
             self.checkpoint_saver.start()
 
         for process in self.conn_handlers:
             if not process.is_alive():
                 process.start()
-            process.ready.wait()
+            process.ready.result()
 
         try:
             self.runtime.run()
@@ -294,7 +281,7 @@ class Server(threading.Thread):
             process.join()
         logger.debug("Connection handlers terminated")
 
-        if self.dht and self.experts:
+        if self.experts:
             self.dht_handler_thread.stop.set()
             self.dht_handler_thread.join()
 
@@ -302,9 +289,8 @@ class Server(threading.Thread):
             self.checkpoint_saver.stop.set()
             self.checkpoint_saver.join()
 
-        if self.dht is not None:
-            self.dht.shutdown()
-            self.dht.join()
+        self.dht.shutdown()
+        self.dht.join()
 
         logger.debug(f"Shutting down runtime")
 
@@ -313,14 +299,14 @@ class Server(threading.Thread):
 
 
 @contextmanager
-def background_server(*args, shutdown_timeout=5, **kwargs) -> Tuple[Endpoint, List[Multiaddr]]:
-    """A context manager that creates server in a background process, awaits .ready on entry and shuts down on exit"""
+def background_server(*args, shutdown_timeout=5, **kwargs) -> PeerInfo:
+    """A context manager that creates server in a background , awaits .ready on entry and shuts down on exit"""
     pipe, runners_pipe = mp.Pipe(duplex=True)
     runner = mp.Process(target=_server_runner, args=(runners_pipe, *args), kwargs=kwargs)
     try:
         runner.start()
         # once the server is ready, runner will send us
-        # either (False, exception) or (True, (server.listen_on, dht_maddrs))
+        # either (False, exception) or (True, PeerInfo(dht_peer_id, dht_maddrs))
         start_ok, data = pipe.recv()
         if start_ok:
             yield data
@@ -344,8 +330,8 @@ def _server_runner(pipe, *args, **kwargs):
         return
 
     try:
-        dht_maddrs = server.dht.get_visible_maddrs() if server.dht is not None else None
-        pipe.send((True, (server.listen_on, dht_maddrs)))
+        dht_maddrs = server.dht.get_visible_maddrs()
+        pipe.send((True, PeerInfo(server.dht.peer_id, dht_maddrs)))
         pipe.recv()  # wait for shutdown signal
 
     finally:

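After this change, `background_server` yields a single PeerInfo for the server's DHT peer instead of an (endpoint, maddrs) pair; a minimal usage sketch (the expert uid, expert_cls and hidden_dim values are placeholders borrowed from the tests below):

import torch

from hivemind.dht import DHT
from hivemind.moe.client.expert import RemoteExpertInfo, create_remote_experts
from hivemind.moe.server import background_server

with background_server(expert_uids=["ffn.0"], expert_cls="ffn", hidden_dim=16, device="cpu") as server_peer_info:
    dht = DHT(initial_peers=server_peer_info.addrs, start=True)
    (expert,) = create_remote_experts([RemoteExpertInfo(uid="ffn.0", peer_info=server_peer_info)], dht=dht)
    print(expert(torch.randn(2, 16)).shape)  # expected: torch.Size([2, 16])
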
+ 13 - 6
hivemind/p2p/p2p_daemon.py

@@ -341,6 +341,7 @@ class P2P:
         handler: Callable[[TInputStream, P2PContext], TOutputStream],
         input_protobuf_type: Type[Message],
         max_prefetch: int = 5,
+        balanced: bool = False,
     ) -> None:
         """
         :param max_prefetch: Maximum number of items to prefetch from the request stream.
@@ -405,7 +406,7 @@ class P2P:
                 finally:
                     processing_task.cancel()
 
-        await self.add_binary_stream_handler(name, _handle_stream)
+        await self.add_binary_stream_handler(name, _handle_stream, balanced=balanced)
 
     async def _iterate_protobuf_stream_handler(
         self, peer_id: PeerID, name: str, requests: TInputStream, output_protobuf_type: Type[Message]
@@ -447,16 +448,19 @@ class P2P:
         *,
         stream_input: bool = False,
         stream_output: bool = False,
+        balanced: bool = False,
     ) -> None:
         """
         :param stream_input: If True, assume ``handler`` to take ``TInputStream``
                              (not just ``TInputProtobuf``) as input.
         :param stream_output: If True, assume ``handler`` to return ``TOutputStream``
                               (not ``Awaitable[TOutputProtobuf]``).
+        :param balanced: If True, requests to this handler will be load-balanced on the p2pd side across all
+                         Python handlers registered under this name. Default: False
         """
 
         if not stream_input and not stream_output:
-            await self._add_protobuf_unary_handler(name, handler, input_protobuf_type)
+            await self._add_protobuf_unary_handler(name, handler, input_protobuf_type, balanced=balanced)
             return
 
         async def _stream_handler(requests: P2P.TInputStream, context: P2PContext) -> P2P.TOutputStream:
@@ -469,13 +473,14 @@ class P2P:
             else:
                 yield await output
 
-        await self._add_protobuf_stream_handler(name, _stream_handler, input_protobuf_type)
+        await self._add_protobuf_stream_handler(name, _stream_handler, input_protobuf_type, balanced=balanced)
 
     async def _add_protobuf_unary_handler(
         self,
         handle_name: str,
         handler: Callable[[TInputProtobuf, P2PContext], Awaitable[TOutputProtobuf]],
         input_protobuf_type: Type[Message],
+        balanced: bool = False,
     ) -> None:
         """
         Register a request-response (unary) handler. Unary requests and responses
@@ -497,7 +502,7 @@ class P2P:
             response = await handler(input_serialized, context)
             return response.SerializeToString()
 
-        await self._client.add_unary_handler(handle_name, _unary_handler)
+        await self._client.add_unary_handler(handle_name, _unary_handler, balanced=balanced)
 
     async def call_protobuf_handler(
         self,
@@ -541,10 +546,12 @@ class P2P:
 
         self._listen_task = asyncio.create_task(listen())
 
-    async def add_binary_stream_handler(self, name: str, handler: p2pclient.StreamHandler) -> None:
+    async def add_binary_stream_handler(
+        self, name: str, handler: p2pclient.StreamHandler, balanced: bool = False
+    ) -> None:
         if self._listen_task is None:
             self._start_listening()
-        await self._client.stream_handler(name, handler)
+        await self._client.stream_handler(name, handler, balanced)
 
     async def call_binary_stream_handler(
         self, peer_id: PeerID, handler_name: str

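The `balanced` flag is what allows several connection handler processes to register the same protocol on one daemon and have p2pd spread incoming requests among them; a rough sketch at the protobuf level (the protocol name, handler body, and the assumption that runtime_pb2.ExpertInfo mirrors rpc_info's response type are illustrative only):

from hivemind.p2p import P2PContext
from hivemind.proto import runtime_pb2

async def rpc_info_demo(request: runtime_pb2.ExpertUID, context: P2PContext) -> runtime_pb2.ExpertInfo:
    # placeholder handler: echo the requested uid back as the serialized info payload
    return runtime_pb2.ExpertInfo(serialized_info=request.uid.encode())

async def register_balanced_handler(dht):
    p2p = await dht.replicate_p2p()  # attach to the DHT's running p2p daemon
    # with balanced=True, p2pd distributes requests among every client that registered this protocol name
    await p2p.add_protobuf_handler("demo.rpc_info", rpc_info_demo, runtime_pb2.ExpertUID, balanced=True)
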
+ 6 - 4
hivemind/p2p/p2p_daemon_bindings/control.py

@@ -246,10 +246,10 @@ class ControlClient:
         self._read_task = asyncio.create_task(self._read_from_persistent_conn(reader))
         self._write_task = asyncio.create_task(self._write_to_persistent_conn(writer))
 
-    async def add_unary_handler(self, proto: str, handler: TUnaryHandler):
+    async def add_unary_handler(self, proto: str, handler: TUnaryHandler, balanced: bool = False):
         call_id = uuid4()
 
-        add_unary_handler_req = p2pd_pb.AddUnaryHandlerRequest(proto=proto)
+        add_unary_handler_req = p2pd_pb.AddUnaryHandlerRequest(proto=proto, balanced=balanced)
         req = p2pd_pb.PersistentConnectionRequest(callId=call_id.bytes, addUnaryHandler=add_unary_handler_req)
 
         if self.unary_handlers.get(proto):
@@ -358,11 +358,13 @@ class ControlClient:
 
         return stream_info, reader, writer
 
-    async def stream_handler(self, proto: str, handler_cb: StreamHandler) -> None:
+    async def stream_handler(self, proto: str, handler_cb: StreamHandler, balanced: bool = False) -> None:
         reader, writer = await self.daemon_connector.open_connection()
 
         listen_path_maddr_bytes = self.listen_maddr.to_bytes()
-        stream_handler_req = p2pd_pb.StreamHandlerRequest(addr=listen_path_maddr_bytes, proto=[proto])
+        stream_handler_req = p2pd_pb.StreamHandlerRequest(
+            addr=listen_path_maddr_bytes, proto=[proto], balanced=balanced
+        )
         req = p2pd_pb.Request(type=p2pd_pb.Request.STREAM_HANDLER, streamHandler=stream_handler_req)
         await write_pbmsg(writer, req)
 

+ 7 - 1
hivemind/p2p/p2p_daemon_bindings/datastructures.py

@@ -5,7 +5,7 @@ Author: Kevin Mai-Husan Chia
 """
 
 import hashlib
-from typing import Any, Sequence, Union
+from typing import Any, Sequence, Tuple, Union
 
 import base58
 import multihash
@@ -128,6 +128,12 @@ class PeerInfo:
         addrs = [Multiaddr(addr) for addr in peer_info_pb.addrs]
         return PeerInfo(peer_id, addrs)
 
+    @classmethod
+    def from_tuple(cls, value: Tuple[str, Sequence[str]]) -> "PeerInfo":
+        peer_id = PeerID.from_base58(value[0])
+        addrs = [Multiaddr(addr) for addr in value[1]]
+        return PeerInfo(peer_id, addrs)
+
     def __str__(self):
         return f"{self.peer_id.pretty()} {','.join(str(a) for a in self.addrs)}"
 

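PeerInfo.from_tuple is the inverse of the (peer_id_base58, (multiaddr_str, ...)) encoding that declare_experts now stores in the DHT; a quick round-trip sketch using a throwaway DHT instance:

from hivemind.dht import DHT
from hivemind.p2p import PeerInfo

dht = DHT(start=True)
addrs = tuple(str(a.decapsulate("/p2p/" + a.get("p2p"))) for a in dht.get_visible_maddrs())
value = (dht.peer_id.to_base58(), addrs)  # the tuple stored under each expert uid
assert PeerInfo.from_tuple(value).peer_id == dht.peer_id
dht.shutdown()
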
+ 5 - 4
hivemind/p2p/p2p_daemon_bindings/p2pclient.py

@@ -61,8 +61,8 @@ class Client:
         async with self.control.listen():
             yield self
 
-    async def add_unary_handler(self, proto: str, handler: TUnaryHandler):
-        await self.control.add_unary_handler(proto, handler)
+    async def add_unary_handler(self, proto: str, handler: TUnaryHandler, balanced: bool = False):
+        await self.control.add_unary_handler(proto, handler, balanced=balanced)
 
     async def call_unary_handler(self, peer_id: PeerID, proto: str, data: bytes) -> bytes:
         return await self.control.call_unary_handler(peer_id, proto, data)
@@ -105,11 +105,12 @@ class Client:
         """
         return await self.control.stream_open(peer_id=peer_id, protocols=protocols)
 
-    async def stream_handler(self, proto: str, handler_cb: StreamHandler) -> None:
+    async def stream_handler(self, proto: str, handler_cb: StreamHandler, balanced: bool = False) -> None:
         """
         Register a stream handler
         :param proto: protocols that handler serves
         :param handler_cb: handler callback
+        :param balanced: if True, the stream handler will be load-balanced on the p2pd side. Default: False.
         :return:
         """
-        await self.control.stream_handler(proto=proto, handler_cb=handler_cb)
+        await self.control.stream_handler(proto=proto, handler_cb=handler_cb, balanced=balanced)

+ 4 - 2
hivemind/p2p/servicer.py

@@ -104,11 +104,12 @@ class ServicerBase:
         caller.__name__ = handler.method_name
         return caller
 
-    async def add_p2p_handlers(self, p2p: P2P, wrapper: Any = None, *, namespace: Optional[str] = None) -> None:
+    async def add_p2p_handlers(
+        self, p2p: P2P, wrapper: Any = None, *, namespace: Optional[str] = None, balanced: bool = False
+    ) -> None:
         self._collect_rpc_handlers()
 
         servicer = self if wrapper is None else wrapper
-
         await asyncio.gather(
             *[
                 p2p.add_protobuf_handler(
@@ -117,6 +118,7 @@ class ServicerBase:
                     handler.request_type,
                     stream_input=handler.stream_input,
                     stream_output=handler.stream_output,
+                    balanced=balanced,
                 )
                 for handler in self._rpc_handlers
             ]

+ 2 - 0
hivemind/proto/p2pd.proto

@@ -90,6 +90,7 @@ message StreamOpenRequest {
 message StreamHandlerRequest {
   required bytes addr = 1;
   repeated string proto = 2;
+  required bool balanced = 3;
 }
 
 message ErrorResponse {
@@ -201,6 +202,7 @@ message CallUnaryResponse {
 
 message AddUnaryHandlerRequest {
   required string proto = 1;
+  required bool balanced = 2;
 }
 
 message DaemonError {

+ 1 - 1
hivemind/utils/__init__.py

@@ -1,5 +1,4 @@
 from hivemind.utils.asyncio import *
-from hivemind.utils.grpc import *
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from hivemind.utils.mpfuture import *
@@ -7,5 +6,6 @@ from hivemind.utils.nested import *
 from hivemind.utils.networking import *
 from hivemind.utils.performance_ema import PerformanceEMA
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
+from hivemind.utils.streaming import combine_from_streaming, split_for_streaming
 from hivemind.utils.tensor_descr import BatchTensorDescriptor, TensorDescriptor
 from hivemind.utils.timed_storage import *

+ 7 - 1
hivemind/utils/asyncio.py

@@ -2,7 +2,7 @@ import asyncio
 import concurrent.futures
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import AbstractAsyncContextManager, AbstractContextManager, asynccontextmanager
-from typing import AsyncIterable, AsyncIterator, Awaitable, Callable, ContextManager, Optional, Tuple, TypeVar, Union
+from typing import AsyncIterable, AsyncIterator, Awaitable, Callable, Iterable, Optional, Tuple, TypeVar, Union
 
 import uvloop
 
@@ -29,6 +29,12 @@ async def anext(aiter: AsyncIterator[T]) -> Union[T, StopAsyncIteration]:
     return await aiter.__anext__()
 
 
+async def iter_as_aiter(iterable: Iterable[T]) -> AsyncIterator[T]:
+    """create an asynchronous iterator from single iterable"""
+    for elem in iterable:
+        yield elem
+
+
 async def as_aiter(*args: T) -> AsyncIterator[T]:
     """create an asynchronous iterator from a sequence of values"""
     for arg in args:

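A small sketch of the difference between the two helpers: iter_as_aiter wraps a single iterable, while the existing as_aiter takes its values as separate arguments:

import asyncio

from hivemind.utils.asyncio import as_aiter, iter_as_aiter

async def demo():
    assert [x async for x in iter_as_aiter(range(3))] == [x async for x in as_aiter(0, 1, 2)] == [0, 1, 2]

asyncio.run(demo())
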
+ 0 - 210
hivemind/utils/grpc.py

@@ -1,210 +0,0 @@
-"""
-Utilities for running GRPC services: compile protobuf, patch legacy versions, etc
-"""
-
-from __future__ import annotations
-
-import os
-import threading
-from typing import Any, Dict, Iterable, Iterator, NamedTuple, Optional, Tuple, Type, TypeVar, Union
-
-import grpc
-
-from hivemind.proto import runtime_pb2
-from hivemind.utils.logging import get_logger
-from hivemind.utils.networking import Endpoint
-from hivemind.utils.timed_storage import TimedStorage, ValueWithExpiration, get_dht_time
-
-logger = get_logger(__name__)
-
-Stub = TypeVar("Stub")
-
-GRPC_KEEPALIVE_OPTIONS = (
-    ("grpc.keepalive_time_ms", 60 * 1000),
-    ("grpc.keepalive_timeout_ms", 60 * 1000),
-    ("grpc.keepalive_permit_without_calls", True),
-    ("grpc.http2.max_pings_without_data", 0),
-    ("grpc.http2.min_time_between_pings_ms", 30 * 1000),
-    ("grpc.http2.min_ping_interval_without_data_ms", 10 * 1000),
-)
-
-
-class ChannelInfo(NamedTuple):
-    target: Endpoint
-    aio: bool
-    options: Tuple[Tuple[str, str], ...]
-    credentials: Optional[grpc.ChannelCredentials]
-    compression: Optional[grpc.Compression]
-
-
-class ChannelCache(TimedStorage[ChannelInfo, Tuple[Union[grpc.Channel, grpc.aio.Channel], Dict]]):
-    """
-    A process-wide cache of gRPC channels, supports both normal and aio channels, secure/insecure channels, etc
-    Based on grpcio internal channel cache by Richard Belleville and Lidi Zheng (thanks!)
-    Unlike TimedStorage, ChannelCache actively evicts stale channels even if the cache is not accessed
-    Unlike grpc._simple_stubs.ChannelCache, this implementation supports aio and does not forcibly close active channels
-    """
-
-    MAXIMUM_CHANNELS = int(os.environ.get("GRPC_PYTHON_MANAGED_CHANNEL_MAXIMUM", 4096))
-    EVICTION_PERIOD_SECONDS = float(os.environ.get("GRPC_PYTHON_MANAGED_CHANNEL_EVICTION_SECONDS", 10 * 60))
-    logger.debug(f"Eviction period = {EVICTION_PERIOD_SECONDS}s, max channels = {MAXIMUM_CHANNELS}")
-
-    _singleton: Optional[ChannelCache] = None
-    _singleton_pid: int = os.getpid()
-    _lock: threading.RLock = threading.RLock()
-    _update_eviction_evt: threading.Event = threading.Event()
-
-    def __init__(self, _created_as_singleton=False):
-        assert _created_as_singleton, f"Please use {self.__class__.__name__}.get_singleton()"
-        super().__init__(maxsize=self.MAXIMUM_CHANNELS)
-        self._is_active = True
-        self._nearest_expiration_time = float("inf")
-        self._eviction_thread = threading.Thread(target=self._evict_stale_channels_in_background, daemon=True)
-        self._eviction_thread.start()
-
-    @classmethod
-    def get_singleton(cls):
-        """Get or create the channel cache for the current process"""
-        with cls._lock:
-            if cls._singleton is None or cls._singleton_pid != os.getpid():
-                if cls._singleton is not None:
-                    cls._singleton._stop_background_thread()
-                cls._singleton, cls._singleton_pid = cls(_created_as_singleton=True), os.getpid()
-            return cls._singleton
-
-    @classmethod
-    def get_stub(
-        cls,
-        target: Endpoint,
-        stub_type: Type[Stub],
-        *,
-        aio: bool,
-        options: Tuple[Tuple[str, Any]] = (),
-        channel_credentials: Optional[grpc.ChannelCredentials] = None,
-        compression: Optional[grpc.Compression] = None,
-    ) -> Stub:
-        """
-        Create a grpc channel with given options or reuse pre-existing one
-
-        :param target: the recipient's address and port
-        :param stub_type: a gRPC stub (client) to be instantiated
-        :param aio: if True, returns grpc.Channel, otherwise returns grpc.aio.Channel
-        :param options: see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html
-        :param channel_credentials: if specified, create a secure channel usin these credentials (default = insecure)
-        :param compression: see https://github.com/grpc/grpc/tree/master/examples/python/compression
-        """
-        cache = cls.get_singleton()
-        with cls._lock:
-            key = ChannelInfo(target, aio, tuple(options), channel_credentials, compression)
-            entry: ValueWithExpiration = super(cls, cache).get(key)
-
-            if entry is not None:
-                channel, stubs = entry.value
-            else:
-                channel = cls._create_channel(*key)
-                stubs = {}
-
-            channel._channel.check_connectivity_state(True)
-
-            if stub_type not in stubs:
-                stubs[stub_type] = stub_type(channel)
-
-            # either cache channel or update expiration of an existing channel
-            expiration_time = get_dht_time() + cls.EVICTION_PERIOD_SECONDS
-            super(cls, cache).store(key, (channel, stubs), expiration_time)
-
-            if expiration_time < cache._nearest_expiration_time:
-                cache._nearest_expiration_time = expiration_time
-                cls._update_eviction_evt.set()
-
-            return stubs[stub_type]
-
-    @classmethod
-    def _create_channel(
-        cls,
-        target: Endpoint,
-        aio: bool,
-        extra_options: Tuple[Tuple[str, Any], ...],
-        channel_credentials: Optional[grpc.ChannelCredentials],
-        compression: Optional[grpc.Compression],
-    ) -> Union[grpc.Channel, grpc.aio.Channel]:
-        namespace = grpc.aio if aio else grpc
-
-        options = extra_options + GRPC_KEEPALIVE_OPTIONS
-
-        if channel_credentials is None:
-            logger.debug(
-                f"Creating insecure {namespace} channel with options '{options}' " f"and compression '{compression}'"
-            )
-            return namespace.insecure_channel(target, options=options, compression=compression)
-        else:
-            logger.debug(
-                f"Creating secure {namespace} channel with credentials '{channel_credentials}', "
-                f"options '{options}' and compression '{compression}'"
-            )
-            return namespace.secure_channel(
-                target, credentials=channel_credentials, options=options, compression=compression
-            )
-
-    def _evict_stale_channels_in_background(self):
-        while self._is_active:
-            now = get_dht_time()
-            time_to_wait = max(0.0, self._nearest_expiration_time - now)
-            interrupted_early = self._update_eviction_evt.wait(time_to_wait if time_to_wait != float("inf") else None)
-            if interrupted_early:
-                self._update_eviction_evt.clear()
-                continue
-
-            with self._lock:
-                self._remove_outdated()
-                _, entry = super().top()
-                self._nearest_expiration_time = entry.expiration_time if entry is not None else float("inf")
-
-    def _stop_background_thread(self):
-        with self._lock:
-            self._is_active = False
-            self._update_eviction_evt.set()
-
-    def store(self, *args, **kwargs) -> ValueError:
-        raise ValueError(f"Please use {self.__class__.__name__}.get_stub to get or create stubs")
-
-    def get(self, *args, **kwargs) -> ValueError:
-        raise ValueError(f"Please use {self.__class__.__name__}.get_stub to get or create stubs")
-
-    def top(self) -> ValueError:
-        raise ValueError(f"Please use {self.__class__.__name__}.get_stub to get or create stubs")
-
-
-STREAMING_CHUNK_SIZE_BYTES = 2**16
-
-
-def split_for_streaming(
-    serialized_tensor: runtime_pb2.Tensor,
-    chunk_size_bytes: int = STREAMING_CHUNK_SIZE_BYTES,
-) -> Iterator[runtime_pb2.Tensor]:
-    """Split serialized_tensor into multiple chunks for gRPC streaming"""
-    buffer = memoryview(serialized_tensor.buffer)
-    num_chunks = len(range(0, len(buffer), chunk_size_bytes))
-    yield runtime_pb2.Tensor(
-        compression=serialized_tensor.compression,
-        buffer=buffer[:chunk_size_bytes].tobytes(),
-        chunks=num_chunks,
-        size=serialized_tensor.size,
-        dtype=serialized_tensor.dtype,
-        requires_grad=serialized_tensor.requires_grad,
-    )
-    for chunk_start in range(chunk_size_bytes, len(buffer), chunk_size_bytes):
-        yield runtime_pb2.Tensor(buffer=buffer[chunk_start : chunk_start + chunk_size_bytes].tobytes())
-
-
-def combine_from_streaming(stream: Iterable[runtime_pb2.Tensor]) -> runtime_pb2.Tensor:
-    """Restore a result of split_into_chunks into a single serialized tensor"""
-    stream = iter(stream)
-    first_chunk = next(stream)
-    serialized_tensor = runtime_pb2.Tensor()
-    serialized_tensor.CopyFrom(first_chunk)
-    buffer_chunks = [first_chunk.buffer]
-    for tensor_part in stream:
-        buffer_chunks.append(tensor_part.buffer)
-    serialized_tensor.buffer = b"".join(buffer_chunks)
-    return serialized_tensor

+ 2 - 24
hivemind/utils/networking.py

@@ -1,35 +1,13 @@
 import socket
 from contextlib import closing
 from ipaddress import ip_address
-from typing import Optional, Sequence
+from typing import Sequence
 
 from multiaddr import Multiaddr
 
-Hostname, Port = str, int  # flavour types
-Endpoint = str  # e.g. 1.2.3.4:1337 or [2a21:6с8:b192:2105]:8888, https://networkengineering.stackexchange.com/a/9435
 LOCALHOST = "127.0.0.1"
 
 
-def get_port(endpoint: Endpoint) -> Optional[Port]:
-    """get port or None if port is undefined"""
-    # TODO: find a standard way to get port, make sure it works in malformed ports
-    try:
-        return int(endpoint[endpoint.rindex(":") + 1 :], base=10)
-    except ValueError:  # :* or not specified
-        return None
-
-
-def replace_port(endpoint: Endpoint, new_port: Port) -> Endpoint:
-    assert endpoint.endswith(":*") or get_port(endpoint) is not None, endpoint
-    return f"{endpoint[:endpoint.rindex(':')]}:{new_port}"
-
-
-def strip_port(endpoint: Endpoint) -> Hostname:
-    """Removes port from the end of endpoint. If port is not specified, does nothing"""
-    maybe_port = endpoint[endpoint.rindex(":") + 1 :]
-    return endpoint[: endpoint.rindex(":")] if maybe_port.isdigit() or maybe_port == "*" else endpoint
-
-
 def get_free_port(params=(socket.AF_INET, socket.SOCK_STREAM), opt=(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)):
     """
     Finds a tcp port that can be occupied with a socket with *params and use *opt options.
@@ -48,7 +26,7 @@ def get_free_port(params=(socket.AF_INET, socket.SOCK_STREAM), opt=(socket.SOL_S
 
 def choose_ip_address(
     maddrs: Sequence[Multiaddr], prefer_global: bool = True, protocol_priority: Sequence[str] = ("ip4", "ip6")
-) -> Hostname:
+) -> str:
     """
     Currently, some components of hivemind are not converted to work over libp2p and use classical networking.
    To allow other peers to reach a server when needed, these components announce a machine's IP address.

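For the components that still rely on plain IP networking, choose_ip_address picks an announceable address from a peer's multiaddrs; a quick sketch (the addresses are made up, and with the default prefer_global=True the globally routable one is expected to win):

from multiaddr import Multiaddr

from hivemind.utils.networking import choose_ip_address

maddrs = [Multiaddr("/ip4/127.0.0.1/tcp/1337"), Multiaddr("/ip4/8.8.8.8/tcp/1337")]
print(choose_ip_address(maddrs))  # expected: 8.8.8.8
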
+ 49 - 0
hivemind/utils/streaming.py

@@ -0,0 +1,49 @@
+"""
+Utilities for streaming tensors
+"""
+
+from __future__ import annotations
+
+from typing import Iterable, Iterator, TypeVar
+
+from hivemind.proto import runtime_pb2
+from hivemind.utils.logging import get_logger
+
+logger = get_logger(__name__)
+
+STREAMING_CHUNK_SIZE_BYTES = 2**16
+
+
+def split_for_streaming(
+    serialized_tensor: runtime_pb2.Tensor,
+    chunk_size_bytes: int = STREAMING_CHUNK_SIZE_BYTES,
+) -> Iterator[runtime_pb2.Tensor]:
+    """Split serialized_tensor into multiple chunks for streaming"""
+    buffer = memoryview(serialized_tensor.buffer)
+    num_chunks = len(range(0, len(buffer), chunk_size_bytes))
+    yield runtime_pb2.Tensor(
+        compression=serialized_tensor.compression,
+        buffer=buffer[:chunk_size_bytes].tobytes(),
+        chunks=num_chunks,
+        size=serialized_tensor.size,
+        dtype=serialized_tensor.dtype,
+        requires_grad=serialized_tensor.requires_grad,
+    )
+    for chunk_start in range(chunk_size_bytes, len(buffer), chunk_size_bytes):
+        yield runtime_pb2.Tensor(buffer=buffer[chunk_start : chunk_start + chunk_size_bytes].tobytes())
+
+
+def combine_from_streaming(stream: Iterable[runtime_pb2.Tensor]) -> runtime_pb2.Tensor:
+    """Restore a result of split_into_chunks into a single serialized tensor"""
+    stream = iter(stream)
+    first_chunk = next(stream)
+    serialized_tensor = runtime_pb2.Tensor()
+    serialized_tensor.CopyFrom(first_chunk)
+    buffer_chunks = [first_chunk.buffer]
+    for tensor_part in stream:
+        buffer_chunks.append(tensor_part.buffer)
+    serialized_tensor.buffer = b"".join(buffer_chunks)
+    return serialized_tensor
+
+
+StreamMessage = TypeVar("StreamMessage")

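The relocated helpers behave exactly like their former counterparts in hivemind.utils.grpc; a round-trip sketch (tensor shape and chunk size are arbitrary):

import torch

from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.utils.streaming import combine_from_streaming, split_for_streaming

tensor = torch.randn(512, 128)
chunks = list(split_for_streaming(serialize_torch_tensor(tensor), chunk_size_bytes=2**16))
assert torch.allclose(deserialize_torch_tensor(combine_from_streaming(chunks)), tensor)
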
+ 2 - 2
setup.py

@@ -13,14 +13,14 @@ from setuptools import find_packages, setup
 from setuptools.command.build_py import build_py
 from setuptools.command.develop import develop
 
-P2PD_VERSION = "v0.3.8"
+P2PD_VERSION = "v0.3.9"
 
 P2PD_SOURCE_URL = f"https://github.com/learning-at-home/go-libp2p-daemon/archive/refs/tags/{P2PD_VERSION}.tar.gz"
 P2PD_BINARY_URL = f"https://github.com/learning-at-home/go-libp2p-daemon/releases/download/{P2PD_VERSION}/"
 
 # The value is sha256 of the binary from the release page
 EXECUTABLES = {
-    "p2pd": "785058526d993f699c674dc2f9b66d565a52315a18b79b629998fab3ebd8e20f",
+    "p2pd": "8f9434f4717f6e851430f75f07e283d5ddeb2c7cde1b3648e677d813703f4e40",
 }
 
 

+ 3 - 2
tests/test_compression.py

@@ -20,6 +20,7 @@ from hivemind.compression import (
 )
 from hivemind.compression.adaptive import AdaptiveCompressionBase
 from hivemind.proto.runtime_pb2 import CompressionType
+from hivemind.utils.streaming import combine_from_streaming, split_for_streaming
 
 from test_utils.dht_swarms import launch_dht_instances
 
@@ -47,9 +48,9 @@ def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
 def test_serialize_tensor():
     def _check(tensor, compression, rtol=1e-5, atol=1e-8, chunk_size=30 * 1024):
         serialized_tensor = serialize_torch_tensor(tensor, compression)
-        chunks = list(hivemind.split_for_streaming(serialized_tensor, chunk_size))
+        chunks = list(split_for_streaming(serialized_tensor, chunk_size))
         assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
-        restored = hivemind.combine_from_streaming(chunks)
+        restored = combine_from_streaming(chunks)
         assert torch.allclose(deserialize_torch_tensor(restored), tensor, rtol=rtol, atol=atol)
 
     tensor = torch.randn(512, 12288)

+ 192 - 0
tests/test_connection_handler.py

@@ -0,0 +1,192 @@
+from __future__ import annotations
+
+import asyncio
+import math
+from typing import Any, Dict
+
+import pytest
+import torch
+
+from hivemind.compression import deserialize_tensor_stream, deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.dht import DHT
+from hivemind.moe.server.connection_handler import ConnectionHandler
+from hivemind.moe.server.expert_backend import ExpertBackend
+from hivemind.moe.server.task_pool import TaskPool
+from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, P2PHandlerError
+from hivemind.proto import runtime_pb2
+from hivemind.utils.asyncio import amap_in_executor, iter_as_aiter
+from hivemind.utils.serializer import MSGPackSerializer
+from hivemind.utils.streaming import split_for_streaming
+from hivemind.utils.tensor_descr import BatchTensorDescriptor
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_connection_handler_info():
+    handler = ConnectionHandler(
+        DHT(start=True),
+        dict(expert1=DummyExpertBackend("expert1", k=1), expert2=DummyExpertBackend("expert2", k=2)),
+    )
+    handler.start()
+
+    client_dht = DHT(start=True, client_mode=True, initial_peers=handler.dht.get_visible_maddrs())
+    client_stub = ConnectionHandler.get_stub(await client_dht.replicate_p2p(), handler.dht.peer_id)
+
+    # info
+    response = await client_stub.rpc_info(runtime_pb2.ExpertUID(uid="expert1"))
+    assert MSGPackSerializer.loads(response.serialized_info) == dict(name="expert1")
+
+    response = await client_stub.rpc_info(runtime_pb2.ExpertUID(uid="expert2"))
+    assert MSGPackSerializer.loads(response.serialized_info) == dict(name="expert2")
+
+    with pytest.raises(P2PHandlerError):
+        await client_stub.rpc_info(runtime_pb2.ExpertUID(uid="expert999"))
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_connection_handler_forward():
+    handler = ConnectionHandler(
+        DHT(start=True),
+        dict(expert1=DummyExpertBackend("expert1", k=1), expert2=DummyExpertBackend("expert2", k=2)),
+    )
+    handler.start()
+
+    client_dht = DHT(start=True, client_mode=True, initial_peers=handler.dht.get_visible_maddrs())
+    client_stub = ConnectionHandler.get_stub(await client_dht.replicate_p2p(), handler.dht.peer_id)
+
+    inputs = torch.randn(1, 2)
+    inputs_long = torch.randn(2**21, 2)
+
+    # forward unary
+    response = await client_stub.rpc_forward(
+        runtime_pb2.ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(inputs)])
+    )
+    outputs = deserialize_torch_tensor(response.tensors[0])
+    assert len(response.tensors) == 1
+    assert torch.allclose(outputs, inputs * 1)
+
+    response = await client_stub.rpc_forward(
+        runtime_pb2.ExpertRequest(uid="expert2", tensors=[serialize_torch_tensor(inputs)])
+    )
+    outputs = deserialize_torch_tensor(response.tensors[0])
+    assert len(response.tensors) == 1
+    assert torch.allclose(outputs, inputs * 2)
+
+    # forward streaming
+    split = (
+        p for t in [serialize_torch_tensor(inputs_long)] for p in split_for_streaming(t, chunk_size_bytes=2**16)
+    )
+    output_generator = await client_stub.rpc_forward_stream(
+        amap_in_executor(
+            lambda tensor_part: runtime_pb2.ExpertRequest(uid="expert2", tensors=[tensor_part]),
+            iter_as_aiter(split),
+        ),
+    )
+    outputs_list = [part async for part in output_generator]
+    assert len(outputs_list) == math.ceil(inputs_long.numel() * 4 / DEFAULT_MAX_MSG_SIZE)
+
+    results = await deserialize_tensor_stream(amap_in_executor(lambda r: r.tensors, iter_as_aiter(outputs_list)))
+    assert len(results) == 1
+    assert torch.allclose(results[0], inputs_long * 2)
+
+    # forward errors
+    with pytest.raises(P2PHandlerError):
+        # no such expert: fails with P2PHandlerError KeyError('expert3')
+        await client_stub.rpc_forward(
+            runtime_pb2.ExpertRequest(uid="expert3", tensors=[serialize_torch_tensor(inputs)])
+        )
+
+    with pytest.raises(P2PHandlerError):
+        # bad input shape: P2PHandlerError("AssertionError") raised by DummyPool.submit_task
+        await client_stub.rpc_forward(
+            runtime_pb2.ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(torch.arange(5))])
+        )
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_connection_handler_backward():
+    handler = ConnectionHandler(
+        DHT(start=True),
+        dict(expert1=DummyExpertBackend("expert1", k=1), expert2=DummyExpertBackend("expert2", k=2)),
+    )
+    handler.start()
+
+    client_dht = DHT(start=True, client_mode=True, initial_peers=handler.dht.get_visible_maddrs())
+    client_stub = ConnectionHandler.get_stub(await client_dht.replicate_p2p(), handler.dht.peer_id)
+
+    inputs = torch.randn(1, 2)
+    inputs_long = torch.randn(2**21, 2)
+
+    # backward unary
+    response = await client_stub.rpc_backward(
+        runtime_pb2.ExpertRequest(
+            uid="expert2", tensors=[serialize_torch_tensor(inputs * -1), serialize_torch_tensor(inputs)]
+        )
+    )
+    outputs = deserialize_torch_tensor(response.tensors[0])
+    assert len(response.tensors) == 1
+    assert torch.allclose(outputs, inputs * -2)
+
+    # backward streaming
+    split = (
+        p
+        for t in [serialize_torch_tensor(inputs_long * 3), serialize_torch_tensor(inputs_long * 0)]
+        for p in split_for_streaming(t, chunk_size_bytes=2**16)
+    )
+    output_generator = await client_stub.rpc_backward_stream(
+        amap_in_executor(
+            lambda tensor_part: runtime_pb2.ExpertRequest(uid="expert1", tensors=[tensor_part]),
+            iter_as_aiter(split),
+        ),
+    )
+    results = await deserialize_tensor_stream(amap_in_executor(lambda r: r.tensors, output_generator))
+    assert len(results) == 1
+    assert torch.allclose(results[0], inputs_long * 3)
+
+    # backward errors
+    with pytest.raises(P2PHandlerError):
+        # bad input schema: fails with P2PHandlerError IndexError('tuple index out of range')
+        await client_stub.rpc_backward(runtime_pb2.ExpertRequest(uid="expert2", tensors=[]))
+
+    with pytest.raises(P2PHandlerError):
+        # backward fails: empty stream
+        output_generator = await client_stub.rpc_backward_stream(
+            amap_in_executor(
+                lambda tensor_part: runtime_pb2.ExpertRequest(uid="expert2", tensors=[tensor_part]),
+                iter_as_aiter([]),
+            ),
+        )
+        results = await deserialize_tensor_stream(amap_in_executor(lambda r: r.tensors, output_generator))
+        assert len(results) == 1
+        assert torch.allclose(results[0], inputs_long * 3)
+
+    # check that handler did not crash after failed request
+    await client_stub.rpc_forward(runtime_pb2.ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(inputs)]))
+
+    handler.terminate()
+    handler.join()
+
+
+class DummyPool(TaskPool):
+    def __init__(self, k: float):
+        self.k = k
+
+    async def submit_task(self, *inputs: torch.Tensor):
+        await asyncio.sleep(0.01)
+        assert inputs[0].shape[-1] == 2
+        return [inputs[0] * self.k]
+
+
+class DummyExpertBackend(ExpertBackend):
+    def __init__(self, name: str, k: float):
+        self.name = name
+        self.outputs_schema = [BatchTensorDescriptor.from_tensor(torch.randn(1, 2))]
+        self.grad_inputs_schema = [BatchTensorDescriptor.from_tensor(torch.randn(1, 2))]
+        self.forward_pool = DummyPool(k)
+        self.backward_pool = DummyPool(k)
+
+    def get_info(self) -> Dict[str, Any]:
+        """Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""
+        return dict(name=self.name)

+ 20 - 9
tests/test_custom_experts.py

@@ -3,7 +3,8 @@ import os
 import pytest
 import torch
 
-from hivemind import RemoteExpert
+from hivemind.dht import DHT
+from hivemind.moe.client.expert import RemoteExpertInfo, create_remote_experts
 from hivemind.moe.server import background_server
 
 CUSTOM_EXPERTS_PATH = os.path.join(os.path.dirname(__file__), "test_utils", "custom_networks.py")
@@ -17,11 +18,16 @@ def test_custom_expert(hid_dim=16):
         device="cpu",
         hidden_dim=hid_dim,
         num_handlers=2,
-        no_dht=True,
         custom_module_path=CUSTOM_EXPERTS_PATH,
-    ) as (server_endpoint, _):
-        expert0 = RemoteExpert("expert.0", server_endpoint)
-        expert1 = RemoteExpert("expert.1", server_endpoint)
+    ) as server_peer_info:
+        dht = DHT(initial_peers=server_peer_info.addrs, start=True)
+        expert0, expert1 = create_remote_experts(
+            [
+                RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info),
+                RemoteExpertInfo(uid="expert.1", peer_info=server_peer_info),
+            ],
+            dht=dht,
+        )
 
         for batch_size in (1, 4):
             batch = torch.randn(batch_size, hid_dim)
@@ -43,11 +49,16 @@ def test_multihead_expert(hid_dim=16):
         device="cpu",
         hidden_dim=hid_dim,
         num_handlers=2,
-        no_dht=True,
         custom_module_path=CUSTOM_EXPERTS_PATH,
-    ) as (server_endpoint, _):
-        expert0 = RemoteExpert("expert.0", server_endpoint)
-        expert1 = RemoteExpert("expert.1", server_endpoint)
+    ) as server_peer_info:
+        dht = DHT(initial_peers=server_peer_info.addrs, start=True)
+        expert0, expert1 = create_remote_experts(
+            [
+                RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info),
+                RemoteExpertInfo(uid="expert.1", peer_info=server_peer_info),
+            ],
+            dht=dht,
+        )
 
         for batch_size in (1, 4):
             batch = (

+ 28 - 23
tests/test_dht_experts.py

@@ -6,11 +6,11 @@ import numpy as np
 import pytest
 
 import hivemind
-from hivemind import LOCALHOST
 from hivemind.dht import DHTNode
 from hivemind.moe.client.beam_search import MoEBeamSearcher
 from hivemind.moe.server import declare_experts, get_experts
 from hivemind.moe.server.expert_uid import UidEndpoint, is_valid_prefix, is_valid_uid, split_uid
+from hivemind.p2p import PeerInfo
 
 
 @pytest.mark.forked
@@ -25,17 +25,18 @@ def test_store_get_experts(n_peers=10):
     expert_uids = [f"my_expert.{i}" for i in range(50)]
     batch_size = 10
     for batch_start in range(0, len(expert_uids), batch_size):
-        declare_experts(first_peer, expert_uids[batch_start : batch_start + batch_size], "localhost:1234")
+        declare_experts(first_peer, expert_uids[batch_start : batch_start + batch_size])
 
     found = get_experts(other_peer, random.sample(expert_uids, 5) + ["foo", "bar"])
     assert all(res is not None for res in found[:-2]), "Could not find some existing experts"
     assert all(res is None for res in found[-2:]), "Found non-existing experts"
 
-    other_expert, other_port = "my_other_expert.1337", random.randint(1000, 9999)
-    declare_experts(other_peer, [other_expert], f"that_host:{other_port}")
+    other_expert = "my_other_expert.1337"
+    declare_experts(other_peer, [other_expert])
     first_notfound, first_found = get_experts(first_peer, ["foobar", other_expert])
     assert isinstance(first_found, hivemind.RemoteExpert)
-    assert first_found.endpoint == f"that_host:{other_port}"
+    assert first_found.server_peer_info.peer_id == other_peer.peer_id
+    assert first_notfound is None
 
     # test graceful shutdown
     first_peer.shutdown()
@@ -43,30 +44,31 @@ def test_store_get_experts(n_peers=10):
     time.sleep(1.0)
     remaining_peer1 = random.choice([peer for peer in peers if peer.is_alive()])
     remaining_peer2 = random.choice([peer for peer in peers if peer.is_alive()])
-    assert all(declare_experts(remaining_peer1, ["new_expert.1"], "dummy"))
-    assert get_experts(remaining_peer2, ["new_expert.1"])[0].endpoint == "dummy"
+    assert all(declare_experts(remaining_peer1, ["new_expert.1"]))
+    assert get_experts(remaining_peer2, ["new_expert.1"])[0].server_peer_info.peer_id == remaining_peer1.peer_id
 
 
 @pytest.mark.forked
 def test_beam_search(
     n_peers=20, total_experts=128, batch_size=32, beam_size=4, parallel_rpc=4, grid_dims=(32, 32, 32)
 ):
-    dht = [hivemind.DHT(start=True)]
-    initial_peers = dht[0].get_visible_maddrs()
-    dht += [hivemind.DHT(initial_peers=initial_peers, start=True) for _ in range(n_peers - 1)]
+    dht_instances = [hivemind.DHT(start=True)]
+    initial_peers = dht_instances[0].get_visible_maddrs()
+    dht_instances += [hivemind.DHT(initial_peers=initial_peers, start=True) for _ in range(n_peers - 1)]
 
     real_experts = sorted(
         {"expert." + ".".join([str(random.randint(0, dim - 1)) for dim in grid_dims]) for _ in range(total_experts)}
     )
     for batch_start in range(0, len(real_experts), batch_size):
+        dht = random.choice(dht_instances)
         declare_experts(
-            random.choice(dht),
+            dht,
             real_experts[batch_start : batch_start + batch_size],
-            wait=True,
-            endpoint=f"host{batch_start // batch_size}:{random.randint(0, 65536)}",
         )
 
-    neighbors = sum([peer.get_visible_maddrs() for peer in random.sample(dht, min(3, len(dht)))], [])
+    neighbors = sum(
+        [peer.get_visible_maddrs() for peer in random.sample(dht_instances, min(3, len(dht_instances)))], []
+    )
     you = hivemind.DHT(start=True, initial_peers=neighbors, parallel_rpc=parallel_rpc)
     beam_search = MoEBeamSearcher(you, "expert.", grid_dims)
 
@@ -89,22 +91,25 @@ def test_dht_single_node():
     node = hivemind.DHT(start=True)
     beam_search = MoEBeamSearcher(node, "expert.", grid_size=(10,))
 
-    assert all(declare_experts(node, ["expert.1", "expert.2", "expert.3"], f"{hivemind.LOCALHOST}:1337").values())
-    assert len(declare_experts(node, ["ffn.1", "ffn.2"], endpoint="that_place")) == 4
-    assert len(declare_experts(node, ["e.1.2.3", "e.1.2.5", "e.2.0"], f"{hivemind.LOCALHOST}:42")) == 7
+    assert all(declare_experts(node, ["expert.1", "expert.2", "expert.3"]).values())
+    assert len(declare_experts(node, ["ffn.1", "ffn.2"])) == 4
+    assert len(declare_experts(node, ["e.1.2.3", "e.1.2.5", "e.2.0"])) == 7
 
     for expert in get_experts(node, ["expert.3", "expert.2"]):
-        assert expert.endpoint == f"{hivemind.LOCALHOST}:1337"
+        assert expert.server_peer_info.peer_id == node.peer_id
 
-    assert all(declare_experts(node, ["expert.5", "expert.2"], f"{hivemind.LOCALHOST}:1337").values())
+    assert all(declare_experts(node, ["expert.5", "expert.2"]).values())
     found_experts = beam_search.find_best_experts([(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0)], beam_size=2)
     assert len(found_experts) == 2 and [expert.uid for expert in found_experts] == ["expert.5", "expert.3"]
 
     successors = beam_search.get_active_successors(["e.1.2.", "e.2.", "e.4.5."])
     assert len(successors["e.1.2."]) == 2
-    assert successors["e.1.2."][3] == UidEndpoint("e.1.2.3", f"{LOCALHOST}:42")
-    assert successors["e.1.2."][5] == UidEndpoint("e.1.2.5", f"{LOCALHOST}:42")
-    assert len(successors["e.2."]) == 1 and successors["e.2."][0] == UidEndpoint("e.2.0", f"{LOCALHOST}:42")
+
+    peer_info = PeerInfo(node.peer_id, [a.decapsulate("/p2p/" + a.get("p2p")) for a in node.get_visible_maddrs()])
+
+    assert successors["e.1.2."][3] == UidEndpoint("e.1.2.3", peer_info)
+    assert successors["e.1.2."][5] == UidEndpoint("e.1.2.5", peer_info)
+    assert len(successors["e.2."]) == 1 and successors["e.2."][0] == UidEndpoint("e.2.0", peer_info)
     assert successors["e.4.5."] == {}
 
     initial_beam = beam_search.get_initial_beam((3, 2, 1, 0, -1, -2, -3), beam_size=3)
@@ -194,7 +199,7 @@ async def test_negative_caching(n_peers=10):
     peers += [hivemind.DHT(initial_peers=initial_peers, start=True, **dht_kwargs) for _ in range(n_peers - 1)]
 
     writer_peer = random.choice(peers)
-    assert all(declare_experts(writer_peer, ["ffn.1.2.3", "ffn.3.4.5"], "myaddr:1234").values())
+    assert all(declare_experts(writer_peer, ["ffn.1.2.3", "ffn.3.4.5"]).values())
 
     neighbors = sum([peer.get_visible_maddrs() for peer in random.sample(peers, min(3, len(peers)))], [])
     neg_caching_peer = hivemind.DHT(initial_peers=neighbors, start=True, **dht_kwargs)

+ 40 - 28
tests/test_moe.py

@@ -1,13 +1,14 @@
-import grpc
 import numpy as np
 import pytest
 import torch
 
 from hivemind.dht import DHT
-from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
-from hivemind.moe.client.moe import DUMMY, _RemoteCallMany
+from hivemind.moe.client.expert import RemoteExpert, RemoteExpertInfo, create_remote_experts
+from hivemind.moe.client.moe import DUMMY, RemoteMixtureOfExperts, _RemoteCallMany
+from hivemind.moe.client.switch_moe import RemoteSwitchMixtureOfExperts
 from hivemind.moe.server import ExpertBackend, Server, background_server, declare_experts
 from hivemind.moe.server.layers import name_to_block
+from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError
 from hivemind.utils.tensor_descr import BatchTensorDescriptor
 
 
@@ -18,8 +19,8 @@ def test_moe():
     ]
     with background_server(
         expert_uids=all_expert_uids, device="cpu", expert_cls="ffn", num_handlers=1, hidden_dim=16
-    ) as (server_endpoint, dht_maddrs):
-        dht = DHT(start=True, initial_peers=dht_maddrs)
+    ) as server_peer_info:
+        dht = DHT(start=True, initial_peers=server_peer_info.addrs)
 
         dmoe = RemoteMixtureOfExperts(in_features=16, grid_size=(4, 4, 4), dht=dht, k_best=3, uid_prefix="ffn.")
 
@@ -35,9 +36,8 @@ def test_no_experts():
     ]
     with background_server(
         expert_uids=all_expert_uids, device="cpu", expert_cls="nop_delay", num_handlers=1, hidden_dim=16
-    ) as (server_endpoint, dht_maddrs):
-        dht = DHT(start=True, initial_peers=dht_maddrs)
-
+    ) as server_peer_info:
+        dht = DHT(start=True, initial_peers=server_peer_info.addrs)
         dmoe = RemoteSwitchMixtureOfExperts(
             in_features=16,
             grid_size=(4, 4, 4),
@@ -71,12 +71,16 @@ def test_call_many(hidden_dim=16):
         num_handlers=1,
         hidden_dim=hidden_dim,
         optim_cls=None,
-        no_dht=True,
-    ) as (server_endpoint, _):
+    ) as server_peer_info:
         inputs = torch.randn(4, hidden_dim, requires_grad=True)
         inputs_clone = inputs.clone().detach().requires_grad_(True)
-        e0, e1, e2, e3, e4 = [RemoteExpert(f"expert.{i}", server_endpoint) for i in range(5)]
-        e5 = RemoteExpert(f"thisshouldnotexist", "127.0.0.1:80")
+
+        dht = DHT(initial_peers=server_peer_info.addrs, start=True)
+        e0, e1, e2, e3, e4 = create_remote_experts(
+            [RemoteExpertInfo(uid=f"expert.{i}", peer_info=server_peer_info) for i in range(5)],
+            dht,
+        )
+        e5 = RemoteExpert(RemoteExpertInfo(f"thisshouldnotexist", server_peer_info), None)
 
         mask, expert_outputs = _RemoteCallMany.apply(
             DUMMY,
@@ -129,11 +133,15 @@ def test_remote_module_call(hidden_dim=16):
         num_handlers=1,
         hidden_dim=hidden_dim,
         optim_cls=None,
-        no_dht=True,
-    ) as (server_endpoint, _):
-        real_expert = RemoteExpert("expert.0", server_endpoint)
-        fake_expert = RemoteExpert("oiasfjiasjf", server_endpoint)
-
+    ) as server_peer_info:
+        dht = DHT(initial_peers=server_peer_info.addrs, start=True)
+        real_expert, fake_expert = create_remote_experts(
+            [
+                RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info),
+                RemoteExpertInfo(uid="oiasfjiasjf", peer_info=server_peer_info),
+            ],
+            dht=dht,
+        )
         out1 = real_expert(torch.randn(1, hidden_dim))
         assert out1.shape == (1, hidden_dim)
         dummy_x = torch.randn(3, hidden_dim, requires_grad=True)
@@ -144,9 +152,9 @@ def test_remote_module_call(hidden_dim=16):
         out3_again.norm().backward()
         assert dummy_x.grad is not None and dummy_x.grad.norm() > 0
 
-        with pytest.raises(grpc.RpcError):
+        with pytest.raises(P2PDaemonError):
             real_expert(torch.randn(3, 11))
-        with pytest.raises(grpc.RpcError):
+        with pytest.raises(P2PDaemonError):
             fake_expert(dummy_x)
 
 
@@ -154,11 +162,11 @@ def test_remote_module_call(hidden_dim=16):
 def test_beam_search_correctness():
     all_expert_uids = [f"ffn.{5 + i}.{10 + j}.{15 + k}" for i in range(10) for j in range(10) for k in range(10)]
     dht = DHT(start=True)
-    assert all(declare_experts(dht, all_expert_uids, endpoint="fake-endpoint"))
+    assert all(declare_experts(dht, all_expert_uids))
 
     dmoe = RemoteMixtureOfExperts(in_features=32, grid_size=(32, 32, 32), dht=dht, k_best=4, uid_prefix="ffn.")
 
-    for i in range(25):
+    for _ in range(25):
         input = torch.randn(32)
         grid_scores = dmoe.proj(input).split_with_sizes(dmoe.beam_search.grid_size, dim=-1)
 
@@ -173,7 +181,7 @@ def test_beam_search_correctness():
         # reference: independently find :beam_size: best experts with exhaustive search
         all_scores = dmoe.compute_expert_scores(
             [dim_scores.unsqueeze(0) for dim_scores in grid_scores],
-            [[RemoteExpert(uid, "") for uid in all_expert_uids]],
+            [[RemoteExpert(RemoteExpertInfo(uid, None), None) for uid in all_expert_uids]],
         )[0]
         true_best_scores = sorted(all_scores.cpu().detach().numpy(), reverse=True)[: len(chosen_experts)]
 
@@ -194,9 +202,12 @@ def test_determinism(hidden_dim=16):
         num_handlers=1,
         hidden_dim=hidden_dim,
         optim_cls=None,
-        no_dht=True,
-    ) as (server_endpoint, _):
-        expert = RemoteExpert(uid=f"expert.0", endpoint=server_endpoint)
+    ) as server_peer_info:
+        dht = DHT(initial_peers=server_peer_info.addrs, start=True)
+        expert = create_remote_experts(
+            [RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info)],
+            dht=dht,
+        )[0]
 
         out = expert(xx, mask)
         out_rerun = expert(xx, mask)
@@ -220,7 +231,7 @@ def test_compute_expert_scores():
         jj = [[2, 2, 1], [0, 1, 2, 0, 1], [0], [1, 2]]
         batch_experts = [
             [
-                RemoteExpert(uid=f"expert.{ii[batch_i][expert_i]}.{jj[batch_i][expert_i]}", endpoint="[::]:1337")
+                RemoteExpert(RemoteExpertInfo(f"expert.{ii[batch_i][expert_i]}.{jj[batch_i][expert_i]}", None), None)
                 for expert_i in range(len(ii[batch_i]))
             ]
             for batch_i in range(len(ii))
@@ -261,9 +272,10 @@ def test_client_anomaly_detection():
     server.start()
     try:
         server.ready.wait()
+        client_side_dht = DHT(initial_peers=dht.get_visible_maddrs(), start=True)
 
         dmoe = RemoteMixtureOfExperts(
-            in_features=16, grid_size=(3,), dht=dht, k_best=3, uid_prefix="expert.", detect_anomalies=True
+            in_features=16, grid_size=(3,), dht=client_side_dht, k_best=3, uid_prefix="expert.", detect_anomalies=True
         )
 
         input = torch.randn(1, 16)
@@ -280,7 +292,7 @@ def test_client_anomaly_detection():
             inf_loss.backward()
 
         dmoe = RemoteMixtureOfExperts(
-            in_features=16, grid_size=(4,), dht=dht, k_best=4, uid_prefix="expert.", detect_anomalies=True
+            in_features=16, grid_size=(4,), dht=client_side_dht, k_best=4, uid_prefix="expert.", detect_anomalies=True
         )
         output = dmoe(input)
         assert output.isfinite().all()
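
Condensed into a usage sketch, the new client-side flow from these tests looks roughly like this (uid, dimensions, and server arguments are illustrative; shape mismatches and missing experts now surface as P2PDaemonError instead of grpc.RpcError):

import torch
from hivemind.dht import DHT
from hivemind.moe.client.expert import RemoteExpertInfo, create_remote_experts
from hivemind.moe.server import background_server

with background_server(num_experts=1, device="cpu", hidden_dim=16, num_handlers=1) as server_peer_info:
    # the context manager now yields a single PeerInfo instead of (server_endpoint, dht_maddrs)
    dht = DHT(initial_peers=server_peer_info.addrs, start=True)
    (expert,) = create_remote_experts(
        [RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info)], dht=dht
    )
    output = expert(torch.randn(2, 16))  # forward and backward go over a libp2p stream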

+ 8 - 2
tests/test_p2p_daemon_bindings.py

@@ -560,13 +560,19 @@ async def test_client_stream_handler_success(p2pcs):
 
     writer.close()
 
-    # test case: registering twice can override the previous registration
+    # test case: registering twice cannot override the previous registration unless the balanced flag is set
     event_third = asyncio.Event()
 
     async def handler_third(stream_info, reader, writer):
         event_third.set()
 
-    await p2pcs[1].stream_handler(another_proto, handler_third)
+    # the p2p daemon now raises for duplicate stream handlers
+    with pytest.raises(ControlFailure):
+        await p2pcs[1].stream_handler(another_proto, handler_third)
+
+    # add in balanced mode: the handler is placed in a round-robin queue
+    # and becomes the next one to be called
+    await p2pcs[1].stream_handler(another_proto, handler_third, balanced=True)
     assert another_proto in p2pcs[1].control.handlers
     # ensure the handler is overridden
     assert handler_third == p2pcs[1].control.handlers[another_proto]
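
In short, the bindings now reject a plain duplicate registration for a protocol and only share it between handlers when balanced=True is passed. A minimal sketch, assuming an already-connected p2pclient instance and an arbitrary protocol id (the ControlFailure import path is also an assumption):

from hivemind.p2p.p2p_daemon_bindings.utils import ControlFailure

PROTO = "/hivemind/example/1.0.0"

async def handle_call(stream_info, reader, writer):
    writer.close()

async def register(client):
    await client.stream_handler(PROTO, handle_call)  # first registration succeeds
    try:
        await client.stream_handler(PROTO, handle_call)  # plain re-registration is rejected
    except ControlFailure:
        pass
    # balanced=True adds the handler to a round-robin pool for this protocol instead
    await client.stream_handler(PROTO, handle_call, balanced=True)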

+ 17 - 11
tests/test_training.py

@@ -8,7 +8,8 @@ import torch.nn.functional as F
 from sklearn.datasets import load_digits
 
 from hivemind import DHT
-from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
+from hivemind.moe.client import RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
+from hivemind.moe.client.expert import RemoteExpertInfo, create_remote_experts
 from hivemind.moe.server import background_server
 from hivemind.optim import DecentralizedAdam, DecentralizedSGD
 
@@ -19,12 +20,17 @@ def test_training(max_steps: int = 100, threshold: float = 0.9):
     X_train, y_train = torch.tensor(dataset["data"], dtype=torch.float), torch.tensor(dataset["target"])
     SGD = partial(torch.optim.SGD, lr=0.05)
 
-    with background_server(num_experts=2, device="cpu", optim_cls=SGD, hidden_dim=64, num_handlers=1, no_dht=True) as (
-        server_endpoint,
-        _,
-    ):
-        expert1 = RemoteExpert("expert.0", server_endpoint)
-        expert2 = RemoteExpert("expert.1", server_endpoint)
+    with background_server(
+        num_experts=2, device="cpu", optim_cls=SGD, hidden_dim=64, num_handlers=1
+    ) as server_peer_info:
+        dht = DHT(initial_peers=server_peer_info.addrs, start=True)
+        expert1, expert2 = create_remote_experts(
+            [
+                RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info),
+                RemoteExpertInfo(uid="expert.1", peer_info=server_peer_info),
+            ],
+            dht=dht,
+        )
         model = nn.Sequential(expert2, nn.ReLU(), expert1, nn.Linear(64, 2))
 
         opt = SGD(model.parameters(), lr=0.05)
@@ -54,8 +60,8 @@ def test_moe_training(max_steps: int = 100, threshold: float = 0.9, num_experts=
     all_expert_uids = [f"expert.{i}" for i in range(num_experts)]
     with background_server(
         expert_uids=all_expert_uids, device="cpu", optim_cls=SGD, hidden_dim=64, num_handlers=1
-    ) as (server_endpoint, dht_maddrs):
-        dht = DHT(start=True, initial_peers=dht_maddrs)
+    ) as server_peer_info:
+        dht = DHT(start=True, initial_peers=server_peer_info.addrs)
 
         moe = RemoteMixtureOfExperts(in_features=64, grid_size=(num_experts,), dht=dht, uid_prefix="expert.", k_best=2)
         model = nn.Sequential(moe, nn.Linear(64, 2))
@@ -107,8 +113,8 @@ def test_switch_training(max_steps: int = 10, threshold: float = 0.9, num_expert
     all_expert_uids = [f"expert.{i}" for i in range(num_experts)]
     with background_server(
         expert_uids=all_expert_uids, device="cpu", optim_cls=SGD, hidden_dim=64, num_handlers=1
-    ) as (server_endpoint, dht_maddrs):
-        dht = DHT(start=True, initial_peers=dht_maddrs)
+    ) as server_peer_info:
+        dht = DHT(start=True, initial_peers=server_peer_info.addrs)
 
         model = SwitchNetwork(dht, 64, 2, num_experts)
         opt = SGD(model.parameters(), lr=0.05)
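
Putting it together, a training client now bootstraps its DHT from the server's multiaddrs and builds the mixture layer on top of it, roughly as follows (expert count and dimensions mirror the tests above; this is a sketch rather than a verbatim excerpt):

import torch.nn as nn
from hivemind import DHT
from hivemind.moe.client import RemoteMixtureOfExperts
from hivemind.moe.server import background_server

expert_uids = [f"expert.{i}" for i in range(4)]
with background_server(expert_uids=expert_uids, device="cpu", hidden_dim=64, num_handlers=1) as server_peer_info:
    dht = DHT(initial_peers=server_peer_info.addrs, start=True)
    moe = RemoteMixtureOfExperts(in_features=64, grid_size=(4,), dht=dht, uid_prefix="expert.", k_best=2)
    model = nn.Sequential(moe, nn.Linear(64, 2))
    # train `model` as usual: expert forward/backward calls are routed over libp2p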

+ 1 - 47
tests/test_util_modules.py

@@ -11,9 +11,7 @@ import torch
 
 import hivemind
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
-from hivemind.proto.dht_pb2_grpc import DHTStub
 from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
 from hivemind.utils import BatchTensorDescriptor, DHTExpiration, HeapEntry, MSGPackSerializer, ValueWithExpiration
 from hivemind.utils.asyncio import (
     achain,
@@ -330,50 +328,6 @@ def test_many_futures():
     p.join()
 
 
-@pytest.mark.forked
-@pytest.mark.asyncio
-async def test_channel_cache():
-    hivemind.ChannelCache.MAXIMUM_CHANNELS = 3
-    hivemind.ChannelCache.EVICTION_PERIOD_SECONDS = 0.1
-
-    c1 = hivemind.ChannelCache.get_stub("localhost:1337", DHTStub, aio=False)
-    c2 = hivemind.ChannelCache.get_stub("localhost:1337", DHTStub, aio=True)
-    c3 = hivemind.ChannelCache.get_stub("localhost:1338", DHTStub, aio=False)
-    c3_again = hivemind.ChannelCache.get_stub("localhost:1338", DHTStub, aio=False)
-    c1_again = hivemind.ChannelCache.get_stub("localhost:1337", DHTStub, aio=False)
-    c4 = hivemind.ChannelCache.get_stub("localhost:1339", DHTStub, aio=True)
-    c2_anew = hivemind.ChannelCache.get_stub("localhost:1337", DHTStub, aio=True)
-    c1_yetagain = hivemind.ChannelCache.get_stub("localhost:1337", DHTStub, aio=False)
-
-    await asyncio.sleep(0.2)
-    c1_anew = hivemind.ChannelCache.get_stub(target="localhost:1337", aio=False, stub_type=DHTStub)
-    c1_anew_again = hivemind.ChannelCache.get_stub(target="localhost:1337", aio=False, stub_type=DHTStub)
-    c1_otherstub = hivemind.ChannelCache.get_stub(target="localhost:1337", aio=False, stub_type=ConnectionHandlerStub)
-    await asyncio.sleep(0.05)
-    c1_otherstub_again = hivemind.ChannelCache.get_stub(
-        target="localhost:1337", aio=False, stub_type=ConnectionHandlerStub
-    )
-    all_channels = [c1, c2, c3, c4, c3_again, c1_again, c2_anew, c1_yetagain, c1_anew, c1_anew_again, c1_otherstub]
-
-    assert all(isinstance(c, DHTStub) for c in all_channels[:-1])
-    assert isinstance(all_channels[-1], ConnectionHandlerStub)
-    assert "aio" in repr(c2.rpc_find)
-    assert "aio" not in repr(c1.rpc_find)
-
-    duplicates = {
-        (c1, c1_again),
-        (c1, c1_yetagain),
-        (c1_again, c1_yetagain),
-        (c3, c3_again),
-        (c1_anew, c1_anew_again),
-        (c1_otherstub, c1_otherstub_again),
-    }
-    for i in range(len(all_channels)):
-        for j in range(i + 1, len(all_channels)):
-            ci, cj = all_channels[i], all_channels[j]
-            assert (ci is cj) == ((ci, cj) in duplicates), (i, j)
-
-
 def test_serialize_tuple():
     test_pairs = (
         ((1, 2, 3), [1, 2, 3]),
@@ -419,7 +373,7 @@ def test_split_parts():
     for combined in combined_incomplete, combined_incomplete2, combined_incomplete3:
         with pytest.raises(RuntimeError):
             deserialize_torch_tensor(combined)
-            # note: we rely on this being RuntimeError in hivemind.averaging.allreduce.AllreduceRunner
+            # note: we rely on this being RuntimeError in hivemind.averaging.allreduce.AllReduceRunner
 
 
 def test_generic_data_classes():
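
For context, the streaming helpers behind test_split_parts can be used roughly as follows (function names follow the new hivemind.utils.streaming module and the chunk size is arbitrary; treat this as a sketch rather than an API reference):

import torch
from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.utils.streaming import combine_from_streaming, split_for_streaming

tensor = torch.randn(512, 128)
parts = list(split_for_streaming(serialize_torch_tensor(tensor), 2 ** 16))  # 64 KiB chunks

restored = deserialize_torch_tensor(combine_from_streaming(parts))
assert torch.allclose(restored, tensor)

# combining an incomplete stream produces a message that deserialize_torch_tensor
# rejects with RuntimeError, which is exactly what test_split_parts relies on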