
Convert hivemind.server to libp2p backend (#470)

Switch hivemind MoE from gRPC to libp2p.
This allows serving experts from behind NATs and firewalls and improves performance under high network latency.

Changes:
 - RemoteExpert (and the MoE client modules) now communicate with servers via libp2p (see the usage sketch below)
 - Removed the listen_on parameter from hivemind.Server and the CLI tools
 - ConnectionHandlers now use load balancing for better performance (see benchmarks in the corresponding PR)
 - Updated docs & tests
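
For illustration, a minimal end-to-end sketch of the new libp2p client path, assembled from the updated `benchmarks/benchmark_throughput.py` in this diff; the single-expert setup and the `outputs_schema` argument are assumptions made for brevity, not part of the change itself:

```python
import torch

from hivemind.dht import DHT
from hivemind.moe.client.expert import RemoteExpert, RemoteExpertInfo
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.moe.server import ExpertBackend, Server
from hivemind.moe.server.layers import name_to_block
from hivemind.p2p import P2P, PeerInfo
from hivemind.utils.tensor_descr import BatchTensorDescriptor

hid_dim = 512
dht = DHT(start=True)  # the server's DHT node; clients bootstrap from its visible multiaddrs

# server side: host a single ffn expert (mirrors the updated benchmark below)
expert_module = name_to_block["ffn"](hid_dim)
backend = ExpertBackend(
    name="expert.0",
    expert=expert_module,
    optimizer=torch.optim.Adam(expert_module.parameters()),
    args_schema=(BatchTensorDescriptor(hid_dim),),
    outputs_schema=BatchTensorDescriptor(hid_dim),  # assumption: explicit output schema, as in earlier releases
    max_batch_size=16,
)
server = Server(dht=dht, expert_backends={"expert.0": backend}, num_connection_handlers=1, device="cpu")
server.start()
server.ready.wait()

# client side: no host/port pair anymore; reach the expert over libp2p via the server's peer info
p2p = RemoteExpertWorker.run_coroutine(P2P.create(initial_peers=dht.get_visible_maddrs()))
peer_info = PeerInfo(dht.peer_id, dht.get_visible_maddrs())
expert = RemoteExpert(expert_info=RemoteExpertInfo(uid="expert.0", peer_info=peer_info), p2p=p2p)

out = expert(torch.randn(4, hid_dim))  # forward and backward now run over libp2p instead of gRPC
out.sum().backward()
```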

Co-authored-by: Denis Mazur <denismazur8@gmail.com>
Co-authored-by: Alexander Borzunov <hxrussia@gmail.com>
Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
GreenFatGuy committed 3 years ago
commit 724cdfe5e2
40 files changed: 915 additions, 603 deletions
  1. .github/workflows/run-tests.yml (+2, -2)
  2. benchmarks/benchmark_throughput.py (+35, -17)
  3. docs/user/moe.md (+9, -12)
  4. hivemind/averaging/averager.py (+2, -1)
  5. hivemind/compression/__init__.py (+5, -1)
  6. hivemind/compression/adaptive.py (+2, -2)
  7. hivemind/compression/base.py (+1, -1)
  8. hivemind/compression/serialization.py (+25, -1)
  9. hivemind/dht/dht.py (+3, -2)
  10. hivemind/hivemind_cli/config.yml (+0, -1)
  11. hivemind/hivemind_cli/run_server.py (+6, -3)
  12. hivemind/moe/client/beam_search.py (+23, -10)
  13. hivemind/moe/client/expert.py (+165, -37)
  14. hivemind/moe/client/moe.py (+28, -24)
  15. hivemind/moe/client/remote_expert_worker.py (+48, -0)
  16. hivemind/moe/client/switch_moe.py (+2, -2)
  17. hivemind/moe/server/connection_handler.py (+102, -48)
  18. hivemind/moe/server/dht_handler.py (+22, -20)
  19. hivemind/moe/server/expert_uid.py (+2, -2)
  20. hivemind/moe/server/server.py (+26, -40)
  21. hivemind/p2p/p2p_daemon.py (+13, -6)
  22. hivemind/p2p/p2p_daemon_bindings/control.py (+6, -4)
  23. hivemind/p2p/p2p_daemon_bindings/datastructures.py (+7, -1)
  24. hivemind/p2p/p2p_daemon_bindings/p2pclient.py (+5, -4)
  25. hivemind/p2p/servicer.py (+4, -2)
  26. hivemind/proto/p2pd.proto (+2, -0)
  27. hivemind/utils/__init__.py (+1, -1)
  28. hivemind/utils/asyncio.py (+7, -1)
  29. hivemind/utils/grpc.py (+0, -210)
  30. hivemind/utils/networking.py (+2, -24)
  31. hivemind/utils/streaming.py (+49, -0)
  32. setup.py (+2, -2)
  33. tests/test_compression.py (+3, -2)
  34. tests/test_connection_handler.py (+192, -0)
  35. tests/test_custom_experts.py (+20, -9)
  36. tests/test_dht_experts.py (+28, -23)
  37. tests/test_moe.py (+40, -28)
  38. tests/test_p2p_daemon_bindings.py (+8, -2)
  39. tests/test_training.py (+17, -11)
  40. tests/test_util_modules.py (+1, -47)

+ 2 - 2
.github/workflows/run-tests.yml

@@ -12,7 +12,7 @@ jobs:
    strategy:
      matrix:
        python-version: [ 3.7, 3.8, 3.9 ]
-    timeout-minutes: 12
+    timeout-minutes: 15
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python
@@ -71,7 +71,7 @@ jobs:
  codecov_in_develop_mode:

    runs-on: ubuntu-latest
-    timeout-minutes: 12
+    timeout-minutes: 15
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python

+ 35 - 17
benchmarks/benchmark_throughput.py

@@ -6,12 +6,14 @@ import time

import torch

-from hivemind.moe.client import RemoteExpert
+from hivemind.dht import DHT
+from hivemind.moe.client.expert import RemoteExpert, RemoteExpertInfo
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.moe.server import ExpertBackend, Server
from hivemind.moe.server.layers import name_to_block
+from hivemind.p2p import P2P, PeerInfo
from hivemind.utils.limits import increase_file_limit
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
-from hivemind.utils.networking import LOCALHOST, get_free_port
from hivemind.utils.tensor_descr import BatchTensorDescriptor

use_hivemind_log_handler("in_root_logger")
@@ -31,14 +33,30 @@ def print_device_info(device=None):
        logger.info(f"Cached:   {round(torch.cuda.memory_cached(0) / 1024 ** 3, 1)} GB")


-def client_process(can_start, benchmarking_failed, port, num_experts, batch_size, hid_dim, num_batches, backprop=True):
+def client_process(
+    can_start,
+    benchmarking_failed,
+    server_maddrs,
+    server_peer_id,
+    num_experts,
+    batch_size,
+    hid_dim,
+    num_batches,
+    backprop=True,
+) -> None:
    torch.set_num_threads(1)
    can_start.wait()
-    experts = [RemoteExpert(f"expert{i}", endpoint=f"{LOCALHOST}:{port}") for i in range(num_experts)]
+
+    p2p = RemoteExpertWorker.run_coroutine(P2P.create(initial_peers=server_maddrs))
+    peer_info = PeerInfo(server_peer_id, server_maddrs)
+    experts = [
+        RemoteExpert(expert_info=RemoteExpertInfo(uid=f"expert.{i}", peer_info=peer_info), p2p=p2p)
+        for i in range(num_experts)
+    ]

    try:
        dummy_batch = torch.randn(batch_size, hid_dim)
-        for batch_i in range(num_batches):
+        for _ in range(num_batches):
            expert = random.choice(experts)
            out = expert(dummy_batch)
            if backprop:
@@ -59,7 +77,6 @@ def benchmark_throughput(
    max_batch_size=None,
    backprop=True,
    device=None,
-    port=None,
):
    assert (
        not hasattr(torch.cuda, "is_initialized")
@@ -67,7 +84,6 @@ def benchmark_throughput(
        or torch.device(device) == torch.device("cpu")
    )
    assert expert_cls in name_to_block
-    port = port or get_free_port()
    max_batch_size = max_batch_size or batch_size * 4
    num_handlers = max(1, num_handlers or num_clients // 2)
    benchmarking_failed = mp.Event()
@@ -75,8 +91,7 @@ def benchmark_throughput(
    timestamps = dict(started=time.perf_counter())

    try:
-        # start clients and await server
-        # Note: client processes must be launched BEFORE touching gpu, even torch.cuda.is_available can cause trouble
+        server_dht = DHT(start=True)
        clients = [
            mp.Process(
                target=client_process,
@@ -84,30 +99,30 @@ def benchmark_throughput(
                args=(
                    can_start,
                    benchmarking_failed,
-                    port,
+                    server_dht.get_visible_maddrs(),
+                    server_dht.peer_id,
                    num_experts,
                    batch_size,
                    hid_dim,
                    num_batches_per_client,
                    backprop,
                ),
+                daemon=True,
            )
            for i in range(num_clients)
        ]

        for client in clients:
-            client.daemon = True
            client.start()

        timestamps["launched_clients"] = timestamps["began_launching_server"] = time.perf_counter()

-        # start server
        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        experts = {}
        for i in range(num_experts):
            expert = torch.jit.script(name_to_block[expert_cls](hid_dim))
-            experts[f"expert{i}"] = ExpertBackend(
-                name=f"expert{i}",
+            experts[f"expert.{i}"] = ExpertBackend(
+                name=f"expert.{i}",
                expert=expert,
                optimizer=torch.optim.Adam(expert.parameters()),
                args_schema=(BatchTensorDescriptor(hid_dim),),
@@ -115,21 +130,24 @@ def benchmark_throughput(
                max_batch_size=max_batch_size,
            )
        timestamps["created_experts"] = time.perf_counter()
+
        server = Server(
-            None,
-            experts,
-            listen_on=f"{LOCALHOST}:{port}",
+            dht=server_dht,
+            expert_backends=experts,
            num_connection_handlers=num_handlers,
            device=device,
        )
        server.start()
        server.ready.wait()
+
        timestamps["server_ready"] = time.perf_counter()
        can_start.set()

        for client in clients:
            client.join()
+
        timestamps["clients_finished"] = time.perf_counter()
+
    except BaseException as e:
        benchmarking_failed.set()
        raise e

+ 9 - 12
docs/user/moe.md

@@ -1,7 +1,7 @@
# Mixture-of-Experts

This tutorial covers the basics of Decentralized Mixture-of-Experts (DMoE).
-From the infrastructure standpoint, DMoE consists of two parts: experts hosted on peer devices, and a gating/routing function that assigns input to one of these experts.
+From the infrastructure standpoint, DMoE consists of two parts: experts hosted on peer devices, and client-side modules to access those experts.

## Host experts with a server

@@ -11,9 +11,8 @@ most of the model parameters and computation. The server can be started using ei
for now. To host a server with default experts, run this in your shell:

```sh
-hivemind-server --expert_cls ffn --hidden_dim 512 --num_experts 5 --expert_pattern "expert.[0:5]" \
-                --listen_on 0.0.0.0:1337
-# note: if you omit listen_on and/or dht_port, they will be chosen automatically and printed to stdout.
+hivemind-server --expert_cls ffn --hidden_dim 512 --num_experts 5 --expert_pattern "expert.[0:5]"
+# note: server will listen to a random port. To specify interface & port, add --host_maddrs and --announce_maddrs
```

<details style="margin-top:-24px; margin-bottom: 16px;">
@@ -22,8 +21,7 @@ hivemind-server --expert_cls ffn --hidden_dim 512 --num_experts 5 --expert_patte
```sh
[2021/07/15 18:52:01.424][INFO][moe.server.create:156] Running DHT node on ['/ip4/127.0.0.1/tcp/42513/p2p/QmacLgRkAHSqdWYdQ8TePioMxQCNV2JeD3AUDmbVd69gNL'], initial peers = []
[2021/07/15 18:52:01.424][INFO][moe.server.create:181] Generating 5 expert uids from pattern expert.[0:5]
-[2021/07/15 18:52:01.658][INFO][moe.server.run:233] Server started at 0.0.0.0:1337
-[2021/07/15 18:52:01.658][INFO][moe.server.run:234] Got 5 experts:
+[2021/07/15 18:52:01.658][INFO][moe.server.run:233] Server started with 5 experts
[2021/07/15 18:52:01.658][INFO][moe.server.run:237] expert.4: FeedforwardBlock, 2100736 parameters
[2021/07/15 18:52:01.658][INFO][moe.server.run:237] expert.0: FeedforwardBlock, 2100736 parameters
[2021/07/15 18:52:01.659][INFO][moe.server.run:237] expert.3: FeedforwardBlock, 2100736 parameters
@@ -67,8 +65,7 @@ hivemind-server --expert_cls ffn --hidden_dim 512 --num_experts 10 --expert_patt
```sh
[2021/07/15 18:53:41.700][INFO][moe.server.create:156] Running DHT node on ['/ip4/127.0.0.1/tcp/34487/p2p/QmcJ3jgbdwphLAiwGjvwrjimJJrdMyhLHf6tFj9viCFFGn'], initial peers = ['/ip4/127.0.0.1/tcp/42513/p2p/QmacLgRkAHSqdWYdQ8TePioMxQCNV2JeD3AUDmbVd69gNL']
[2021/07/15 18:53:41.700][INFO][moe.server.create:181] Generating 10 expert uids from pattern expert.[5:250]
-[2021/07/15 18:53:42.085][INFO][moe.server.run:233] Server started at 0.0.0.0:36389
-[2021/07/15 18:53:42.086][INFO][moe.server.run:234] Got 10 experts:
+[2021/07/15 18:53:42.085][INFO][moe.server.run:233] Server started with 10 experts:
[2021/07/15 18:53:42.086][INFO][moe.server.run:237] expert.55: FeedforwardBlock, 2100736 parameters
[2021/07/15 18:53:42.086][INFO][moe.server.run:237] expert.173: FeedforwardBlock, 2100736 parameters
[2021/07/15 18:53:42.086][INFO][moe.server.run:237] expert.164: FeedforwardBlock, 2100736 parameters
@@ -104,10 +101,10 @@ hivemind-server --expert_cls ffn --hidden_dim 512 --num_experts 10 --expert_patt

</details>

-By default, the server will only accept connections from your local machine. To access it globally, you should replace
-`127.0.0.1` part from initial peers with server's IP address. Hivemind supports both ipv4 and ipv6 protocols and uses the same notation
-as [libp2p](https://docs.libp2p.io/concepts/addressing/). You can find more details on multiaddresses in the
-[DHT tutorial](https://learning-at-home.readthedocs.io/en/latest/user/dht.html).
+By default, the server will only accept connections from your local network.
+To enable training over the Internet (or some other network), you should set `--host_maddrs` and `--announce_maddrs`.
+These options also allow you to select IPv4/IPv6 network protocols and TCP and QUIC transport protocols.
+You can find more details in the [DHT tutorial](https://learning-at-home.readthedocs.io/en/latest/user/dht.html).

## Train the experts


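For illustration, a hedged example of the new server flags referenced in the updated docs above (the multiaddr values are placeholders; `--announce_maddrs` is only needed when the publicly reachable address differs from the interface you bind to):

```sh
# listen on all IPv4 interfaces on TCP port 1337 and advertise a public address to peers
hivemind-server --expert_cls ffn --hidden_dim 512 --num_experts 5 --expert_pattern "expert.[0:5]" \
                --host_maddrs /ip4/0.0.0.0/tcp/1337 \
                --announce_maddrs /ip4/203.0.113.7/tcp/1337
```
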
+ 2 - 1
hivemind/averaging/averager.py

@@ -37,8 +37,8 @@ from hivemind.utils.asyncio import (
    enter_asynchronously,
    switch_to_uvloop,
)
-from hivemind.utils.grpc import combine_from_streaming, split_for_streaming
from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
+from hivemind.utils.streaming import combine_from_streaming, split_for_streaming
from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration, get_dht_time

# flavour types
@@ -709,6 +709,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                        stream = await stub.rpc_download_state(averaging_pb2.DownloadRequest())
                        current_tensor_parts, tensors = [], []

+                        # TODO merge this with hivemind.compression.deserialize_tensor_stream
                        async for message in aiter_with_timeout(stream, timeout=timeout):
                            if message.metadata:
                                metadata = self.serializer.loads(message.metadata)

+ 5 - 1
hivemind/compression/__init__.py

@@ -6,4 +6,8 @@ from hivemind.compression.adaptive import PerTensorCompression, RoleAdaptiveComp
from hivemind.compression.base import CompressionBase, CompressionInfo, NoCompression, TensorRole
from hivemind.compression.floating import Float16Compression, ScaledFloat16Compression
from hivemind.compression.quantization import Quantile8BitQuantization, Uniform8BitQuantization
-from hivemind.compression.serialization import deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.compression.serialization import (
+    deserialize_tensor_stream,
+    deserialize_torch_tensor,
+    serialize_torch_tensor,
+)

+ 2 - 2
hivemind/compression/adaptive.py

@@ -3,8 +3,8 @@ from typing import Mapping, Sequence, Union

import torch

-import hivemind
from hivemind.compression.base import CompressionBase, CompressionInfo, Key, NoCompression, TensorRole
+from hivemind.compression.serialization import deserialize_torch_tensor
from hivemind.proto import runtime_pb2


@@ -20,7 +20,7 @@ class AdaptiveCompressionBase(CompressionBase, ABC):
        return self.choose_compression(info).compress(tensor, info=info, allow_inplace=allow_inplace)

    def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
-        return hivemind.compression.deserialize_torch_tensor(serialized_tensor)
+        return deserialize_torch_tensor(serialized_tensor)


class SizeAdaptiveCompression(AdaptiveCompressionBase):

+ 1 - 1
hivemind/compression/base.py

@@ -80,7 +80,7 @@ class NoCompression(CompressionBase):
    compression_type = runtime_pb2.CompressionType.NONE

    def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
-        array = tensor.numpy()
+        array = tensor.detach().numpy()
        return runtime_pb2.Tensor(
            compression=self.compression_type,
            buffer=array.tobytes(),

+ 25 - 1
hivemind/compression/serialization.py

@@ -1,4 +1,6 @@
-from typing import Dict, Optional
+from __future__ import annotations
+
+from typing import AsyncIterator, Dict, Iterable, List, Optional

import torch

@@ -6,6 +8,7 @@ from hivemind.compression.base import CompressionBase, CompressionInfo, NoCompre
from hivemind.compression.floating import Float16Compression, ScaledFloat16Compression
from hivemind.compression.quantization import Quantile8BitQuantization, Uniform8BitQuantization
from hivemind.proto import runtime_pb2
+from hivemind.utils.streaming import combine_from_streaming

BASE_COMPRESSION_TYPES: Dict[str, CompressionBase] = dict(
    NONE=NoCompression(),
@@ -41,3 +44,24 @@ def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Ten
    """Restore a pytorch tensor from a protobuf message"""
    compression = BASE_COMPRESSION_TYPES[runtime_pb2.CompressionType.Name(serialized_tensor.compression)]
    return compression.extract(serialized_tensor).requires_grad_(serialized_tensor.requires_grad)
+
+
+async def deserialize_tensor_stream(
+    stream: AsyncIterator[Iterable[runtime_pb2.Tensor]],
+) -> List[torch.Tensor]:
+    """Async wrapper of combine_from_streaming that combines tensors from a stream of parts and deserializes them"""
+
+    tensors = []
+    tensor_parts = []
+
+    async for parts in stream:
+        for part in parts:
+            if part.dtype and tensor_parts:
+                tensors.append(deserialize_torch_tensor(combine_from_streaming(tensor_parts)))
+                tensor_parts = []
+
+            tensor_parts.append(part)
+    if tensor_parts:
+        tensors.append(deserialize_torch_tensor(combine_from_streaming(tensor_parts)))
+
+    return tensors

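For illustration, a hedged round-trip sketch of the streaming helpers used above: serialize a tensor, split it into message-sized chunks the way the new streaming RPCs do, then reassemble and deserialize it (the chunk size is an arbitrary illustrative value, not part of the diff):

```python
import torch

from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.proto import runtime_pb2
from hivemind.utils.streaming import combine_from_streaming, split_for_streaming

tensor = torch.randn(2048, 2048)
serialized = serialize_torch_tensor(tensor, runtime_pb2.CompressionType.NONE)

# split into ~64 KiB protobuf parts (the real clients use DEFAULT_MAX_MSG_SIZE as the chunk size)
parts = list(split_for_streaming(serialized, 2 ** 16))
restored = deserialize_torch_tensor(combine_from_streaming(parts))
assert torch.allclose(tensor, restored)
```
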
+ 3 - 2
hivemind/dht/dht.py

@@ -55,6 +55,7 @@ class DHT(mp.Process):
        **kwargs,
    ):
        self._parent_pid = os.getpid()
+        self._origin_pid = os.getpid()
        super().__init__()

        if not (
@@ -309,8 +310,8 @@ class DHT(mp.Process):
        Get a replica of a P2P instance used in the DHT process internally.
        The replica uses the same P2P daemon as the DHT and only works while DHT is alive.
        """
-
-        if self._p2p_replica is None:
+        if self._p2p_replica is None or self._origin_pid != os.getpid():
+            self._origin_pid = os.getpid()
            daemon_listen_maddr = self.run_coroutine(DHT._get_p2p_daemon_listen_maddr)
            self._p2p_replica = await P2P.replicate(daemon_listen_maddr)
        return self._p2p_replica

+ 0 - 1
hivemind/hivemind_cli/config.yml

@@ -1,4 +1,3 @@
-listen_on: 0.0.0.0:*
num_experts: 16
expert_cls: ffn
hidden_dim: 1024

+ 6 - 3
hivemind/hivemind_cli/run_server.py

@@ -18,8 +18,7 @@ def main():
    # fmt:off
    parser = configargparse.ArgParser(default_config_files=["config.yml"])
    parser.add('-c', '--config', required=False, is_config_file=True, help='config file path')
-    parser.add_argument('--listen_on', type=str, default='0.0.0.0:*', required=False,
-                        help="'localhost' for local connections only, '0.0.0.0' for ipv4 '[::]' for ipv6")
+
    parser.add_argument('--num_experts', type=int, default=None, required=False, help="The number of experts to serve")
    parser.add_argument('--expert_pattern', type=str, default=None, required=False,
                        help='all expert uids will follow this pattern, e.g. "myexpert.[0:256].[0:1024]" will'
@@ -32,6 +31,11 @@ def main():
                        help="expert type from test_utils.layers, e.g. 'ffn', 'transformer', 'det_dropout' or 'nop'")
    parser.add_argument('--hidden_dim', type=int, default=1024, required=False, help='main dimension for expert_cls')

+    parser.add_argument('--host_maddrs', type=list, nargs='+', default=['/ip4/0.0.0.0/tcp/0'], required=False,
+                        help='Multiaddrs to listen for external connections from other p2p instances; default: all IPv4 and TCP: /ip4/0.0.0.0/tcp/0')
+    parser.add_argument('--announce_maddrs', type=list, nargs='+', default=None, required=False,
+                        help='Visible multiaddrs the host announces for external connections from other p2p instances')
+
    parser.add_argument('--num_handlers', type=int, default=None, required=False,
                        help='server will use this many processes to handle incoming requests')
    parser.add_argument('--min_batch_size', type=int, default=1,
@@ -49,7 +53,6 @@ def main():
    parser.add_argument('--num_total_steps', type=int, required=False, help='The total number of steps for LR schedule')
    parser.add_argument('--clip_grad_norm', type=float, required=False, help='Maximum gradient norm used for clipping')

-    parser.add_argument('--no_dht', action='store_true', help='if specified, the server will not be attached to a dht')
    parser.add_argument('--initial_peers', type=str, nargs='*', required=False, default=[],
                        help='multiaddrs of one or more active DHT peers (if you want to join an existing DHT)')
    parser.add_argument('--increase_file_limit', action='store_true',

+ 23 - 10
hivemind/moe/client/beam_search.py

@@ -5,7 +5,12 @@ from functools import partial
from typing import Deque, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union

from hivemind.dht import DHT, DHTExpiration, DHTNode
-from hivemind.moe.client.expert import RemoteExpert
+from hivemind.moe.client.expert import (
+    RemoteExpert,
+    RemoteExpertInfo,
+    batch_create_remote_experts,
+    create_remote_experts,
+)
from hivemind.moe.server.expert_uid import (
    FLAT_EXPERT,
    PREFIX_PATTERN,
@@ -17,6 +22,7 @@ from hivemind.moe.server.expert_uid import (
    UidEndpoint,
    is_valid_prefix,
)
+from hivemind.p2p import PeerInfo
from hivemind.utils import MPFuture, get_dht_time, get_logger

logger = get_logger(__name__)
@@ -145,7 +151,7 @@ class MoEBeamSearcher:
                maybe_prefix_data = await pending_task
                if maybe_prefix_data is not None and isinstance(maybe_prefix_data.value, dict):
                    successors = {
-                        coord: UidEndpoint(*match.value)
+                        coord: UidEndpoint(uid=match.value[0], peer_info=PeerInfo.from_tuple(match.value[1]))
                        for coord, match in maybe_prefix_data.value.items()
                        if isinstance(coord, Coordinate)
                        and isinstance(getattr(match, "value", None), list)
@@ -212,7 +218,7 @@ class MoEBeamSearcher:
        for prefix, found in dht_responses.items():
            if found and isinstance(found.value, dict):
                successors[prefix] = {
-                    coord: UidEndpoint(*match.value)
+                    coord: UidEndpoint(uid=match.value[0], peer_info=PeerInfo.from_tuple(match.value[1]))
                    for coord, match in found.value.items()
                    if isinstance(coord, Coordinate)
                    and 0 <= coord < grid_size
@@ -230,7 +236,7 @@ class MoEBeamSearcher:

    def find_best_experts(
        self, grid_scores: Sequence[Sequence[float]], beam_size: int, return_future: bool = False
-    ) -> Union[List[RemoteExpert], MPFuture[RemoteExpert]]:
+    ) -> Union[List[RemoteExpert], MPFuture[List[RemoteExpert]]]:
        """
        Find and return :beam_size: active experts with highest scores, use both local cache and DHT

@@ -245,7 +251,7 @@ class MoEBeamSearcher:
        :returns: a list that contains *up to* k_best RemoteExpert instances
        """
        assert len(grid_scores) == len(self.grid_size) and beam_size > 0
-        return self.dht.run_coroutine(
+        result = self.dht.run_coroutine(
            partial(
                self._find_best_experts,
                prefix=self.uid_prefix,
@@ -258,6 +264,8 @@ class MoEBeamSearcher:
            return_future,
        )

+        return create_remote_experts(result, self.dht, return_future)
+
    @classmethod
    async def _find_best_experts(
        cls,
@@ -269,7 +277,7 @@ class MoEBeamSearcher:
        negative_caching: bool,
        cache_expiration: DHTExpiration,
        num_workers: Optional[int] = None,
-    ) -> List[RemoteExpert]:
+    ) -> List[RemoteExpertInfo]:
        num_workers = num_workers or min(beam_size, dht.num_workers or beam_size)

        # form initial beam from top-k active L1 prefixes, each row is (score, uid prefix, possible suffixes)
@@ -322,7 +330,10 @@ class MoEBeamSearcher:
                push_and_maybe_pop(best_experts_heap, (score, uid_endpoint))
                unique_experts.add(uid_endpoint.uid)

-        best_experts = [RemoteExpert(*uid_endpoint) for score, uid_endpoint in sorted(best_experts_heap, reverse=True)]
+        best_experts = [
+            RemoteExpertInfo(uid_endpoint.uid, uid_endpoint.peer_info)
+            for _, uid_endpoint in sorted(best_experts_heap, reverse=True)
+        ]
        return best_experts

    @staticmethod
@@ -351,7 +362,7 @@ class MoEBeamSearcher:

    def batch_find_best_experts(
        self, batch_grid_scores: Sequence[Sequence[Sequence[float]]], beam_size: int, return_future: bool = False
-    ) -> Union[List[List[RemoteExpert]], MPFuture]:
+    ) -> Union[List[List[RemoteExpert]], MPFuture[List[List[RemoteExpert]]]]:
        """
        Find and return :beam_size: active experts with highest scores, use both local cache and DHT

@@ -364,7 +375,7 @@ class MoEBeamSearcher:
        :param return_future: if set to True, returns MPFuture that can be awaited to get the actual result
        :returns: a list that contains *up to* k_best RemoteExpert instances
        """
-        return self.dht.run_coroutine(
+        result = self.dht.run_coroutine(
            partial(
                self._batch_find_best_experts,
                prefix=self.uid_prefix,
@@ -376,6 +387,8 @@ class MoEBeamSearcher:
            return_future,
        )

+        return batch_create_remote_experts(result, self.dht, return_future)
+
    @classmethod
    async def _batch_find_best_experts(
        cls,
@@ -386,7 +399,7 @@ class MoEBeamSearcher:
        beam_size: int,
        negative_caching: bool,
        num_workers: Optional[int],
-    ) -> Sequence[Sequence[RemoteExpert]]:
+    ) -> Sequence[Sequence[RemoteExpertInfo]]:
        batch_grid_scores = [
            [tuple(grid_score[i]) for grid_score in batch_grid_scores] for i in range(len(batch_grid_scores[0]))
        ]

+ 165 - 37
hivemind/moe/client/expert.py

@@ -1,43 +1,68 @@
-from typing import Any, Dict, Optional, Tuple
+from __future__ import annotations
+
+from concurrent.futures import Future
+from dataclasses import dataclass
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
from torch.autograd.function import once_differentiable

-from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
-from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
-from hivemind.utils import Endpoint, MSGPackSerializer, nested_compare, nested_flatten, nested_pack
-from hivemind.utils.grpc import ChannelCache
+from hivemind import moe
+from hivemind.compression import deserialize_tensor_stream, deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.dht import DHT
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from hivemind.p2p import P2P, PeerInfo, StubBase
+from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
+from hivemind.proto import runtime_pb2
+from hivemind.utils.asyncio import amap_in_executor, iter_as_aiter
+from hivemind.utils.mpfuture import MPFuture
+from hivemind.utils.nested import nested_compare, nested_flatten, nested_pack
+from hivemind.utils.serializer import MSGPackSerializer
+from hivemind.utils.streaming import split_for_streaming

DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert


-def _get_expert_stub(endpoint: Endpoint, *extra_options: Tuple[str, Any]):
-    """Create a gRPC stub to access remote expert or use previously created stub from a process-wide cache"""
-    channel_options = (("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)) + extra_options
-    return ChannelCache.get_stub(endpoint, runtime_grpc.ConnectionHandlerStub, aio=False, options=channel_options)
+def get_expert_stub(p2p: P2P, server_peer_info: PeerInfo) -> "ConnectionHandlerStub":
+    return moe.server.connection_handler.ConnectionHandler.get_stub(p2p, server_peer_info.peer_id)
+
+
+@dataclass(frozen=True)
+class RemoteExpertInfo:
+    """A simple data class containing uid of expert and server PeerInfo"""
+
+    uid: str
+    peer_info: PeerInfo


class RemoteExpert(nn.Module):
    """
    A simple module that runs forward/backward of an expert hosted on a remote machine.
    Works seamlessly with pytorch autograd. (this is essentially a simple RPC function)
-
    Warning: RemoteExpert currently assumes that you provide it with correct input shapes.
    Sending wrong input shapes can cause RemoteExpert to freeze indefinitely due to error in runtime.

-    :param uid: unique expert identifier
-    :param endpoint: network endpoint of a server that services that expert, e.g. "201.123.321.99:1337" or "[::]:8080"
+    :param expert_info: RemoteExpertInfo with uid and server PeerInfo
+    :param p2p: P2P instance connected to the running p2pd
    """

-    def __init__(self, uid, endpoint: Endpoint):
+    def __init__(self, expert_info: RemoteExpertInfo, p2p: P2P):
        super().__init__()
-        self.uid, self.endpoint = uid, endpoint
-        self._info = None
+        self._info, self.p2p = expert_info, p2p
+        self._rpc_info = None

    @property
-    def stub(self):
-        return _get_expert_stub(self.endpoint)
+    def uid(self):
+        return self._info.uid
+
+    @property
+    def server_peer_info(self):
+        return self._info.peer_info
+
+    @property
+    def stub(self) -> StubBase:
+        return get_expert_stub(self.p2p, self.server_peer_info)

    def forward(self, *args, **kwargs):
        """Call RemoteExpert for the specified inputs and return its output(s). Compatible with pytorch.autograd."""
@@ -52,18 +77,125 @@ class RemoteExpert(nn.Module):
            raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")

        flat_outputs = _RemoteModuleCall.apply(DUMMY, self.uid, self.stub, self.info, *nested_flatten(forward_inputs))
+
        # Note: we send DUMMY to prevent torch from excluding expert from backward if no other inputs require grad
        return nested_pack(flat_outputs, structure=self.info["outputs_schema"])

    @property
    def info(self):
-        if self._info is None:
-            outputs = self.stub.info(runtime_pb2.ExpertUID(uid=self.uid))
-            self._info = MSGPackSerializer.loads(outputs.serialized_info)
-        return self._info
+        if self._rpc_info is None:
+            outputs = RemoteExpertWorker.run_coroutine(self.stub.rpc_info(runtime_pb2.ExpertUID(uid=self.uid)))
+            self._rpc_info = MSGPackSerializer.loads(outputs.serialized_info)
+        return self._rpc_info

    def extra_repr(self):
-        return f"uid={self.uid}, endpoint={self.endpoint}"
+        return f"uid={self.uid}, server_peer_info={self.server_peer_info}"
+
+
+def _create_remote_experts(infos: Sequence[Optional[RemoteExpertInfo]], p2p: P2P) -> List[Optional[RemoteExpert]]:
+    experts: List[Optional[RemoteExpert]] = []
+    for info in infos:
+        if info is not None:
+            experts.append(RemoteExpert(info, p2p))
+        else:
+            experts.append(None)
+    return experts
+
+
+def create_remote_experts(
+    infos: Union[Sequence[Optional[RemoteExpertInfo]], MPFuture], dht: DHT, return_future: bool = False
+) -> Union[List[Optional[RemoteExpert]], Future]:
+    if return_future:
+
+        async def _unpack(infos_future: MPFuture, dht: DHT):
+            p2p = await dht.replicate_p2p()
+            return _create_remote_experts(await infos_future, p2p)
+
+        return RemoteExpertWorker.run_coroutine(_unpack(infos, dht), return_future)
+
+    p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
+    return _create_remote_experts(infos, p2p)
+
+
+def batch_create_remote_experts(
+    infos: Union[Sequence[Sequence[Optional[RemoteExpertInfo]]], MPFuture],
+    dht: DHT,
+    return_future: bool = False,
+) -> Union[List[List[Optional[RemoteExpert]]], Future]:
+    if return_future:
+
+        async def _unpack(infos_future: MPFuture, dht: DHT):
+            p2p = await dht.replicate_p2p()
+            return [_create_remote_experts(i, p2p) for i in await infos_future]
+
+        return RemoteExpertWorker.run_coroutine(_unpack(infos, dht), return_future)
+
+    return [create_remote_experts(exps, dht) for exps in infos]
+
+
+async def _backward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
+    split = (part for tensor in serialized_tensors for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE))
+
+    grad_inputs = await stub.rpc_backward_stream(
+        amap_in_executor(
+            lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor]),
+            iter_as_aiter(split),
+        ),
+    )
+    tensors_stream = amap_in_executor(lambda msg: msg.tensors, grad_inputs)
+    return await deserialize_tensor_stream(tensors_stream)
+
+
+async def _backward_unary(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
+    grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
+        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors))
+    )
+    return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]
+
+
+async def expert_backward(
+    uid: str, inputs_and_grads: Sequence[torch.Tensor], serialized_tensors: Iterable[runtime_pb2.Tensor], stub
+) -> List[torch.Tensor]:
+    size = 0
+    for t in inputs_and_grads:
+        size += t.element_size() * t.nelement()
+        if size > DEFAULT_MAX_MSG_SIZE:
+            return await _backward_stream(uid, serialized_tensors, stub)
+    else:
+        return await _backward_unary(uid, serialized_tensors, stub)
+
+
+async def _forward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
+    split = (p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE))
+
+    outputs = await stub.rpc_forward_stream(
+        amap_in_executor(
+            lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor]),
+            iter_as_aiter(split),
+        ),
+    )
+
+    tensors_stream = amap_in_executor(lambda msg: msg.tensors, outputs)
+    return await deserialize_tensor_stream(tensors_stream)
+
+
+async def _forward_unary(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
+    outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
+        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors))
+    )
+    return [deserialize_torch_tensor(t) for t in outputs.tensors]
+
+
+async def expert_forward(
+    uid: str, inputs: Sequence[torch.Tensor], serialized_tensors: Iterable[runtime_pb2.Tensor], stub
+) -> List[torch.Tensor]:
+    size = 0
+    for t in inputs:
+        size += t.element_size() * t.nelement()
+        if size > DEFAULT_MAX_MSG_SIZE:
+            return await _forward_stream(uid, serialized_tensors, stub)
+    else:
+        return await _forward_unary(uid, serialized_tensors, stub)


class _RemoteModuleCall(torch.autograd.Function):
@@ -74,7 +206,7 @@ class _RemoteModuleCall(torch.autograd.Function):
        ctx,
        dummy: torch.Tensor,
        uid: str,
-        stub: runtime_grpc.ConnectionHandlerStub,
+        stub: "ConnectionHandlerStub",
        info: Dict[str, Any],
        *inputs: torch.Tensor,
    ) -> Tuple[torch.Tensor, ...]:
@@ -83,15 +215,11 @@ class _RemoteModuleCall(torch.autograd.Function):
        inputs = tuple(tensor.cpu().detach() for tensor in inputs)
        ctx.uid, ctx.stub, ctx.info = uid, stub, info
        ctx.save_for_backward(*inputs)
-
-        serialized_tensors = [
-            serialize_torch_tensor(inp, proto.compression)
-            for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))
-        ]
-
-        outputs = stub.forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
-
-        deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
+        serialized_tensors = (
+            serialize_torch_tensor(tensor, proto.compression)
+            for tensor, proto in zip(inputs, nested_flatten(info["forward_schema"]))
+        )
+        deserialized_outputs = RemoteExpertWorker.run_coroutine(expert_forward(uid, inputs, serialized_tensors, stub))

        return tuple(deserialized_outputs)

@@ -101,12 +229,12 @@ class _RemoteModuleCall(torch.autograd.Function):
        grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
        inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs_cpu)))
        backward_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))
-        serialized_tensors = [
+        serialized_tensors = (
            serialize_torch_tensor(tensor, proto.compression)
            for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
-        ]
-
-        grad_inputs = ctx.stub.backward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
+        )
+        deserialized_grad_inputs = RemoteExpertWorker.run_coroutine(
+            expert_backward(ctx.uid, inputs_and_grad_outputs, serialized_tensors, ctx.stub)
+        )

-        deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
        return (DUMMY, None, None, None, *deserialized_grad_inputs)

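For illustration, a hedged sketch of the size check that `expert_forward` / `expert_backward` above use to decide between the unary and streaming RPCs (the helper name `needs_streaming` is ours, not part of the diff):

```python
import torch

from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE


def needs_streaming(tensors) -> bool:
    """Return True if the serialized request is expected to exceed a single p2pd message."""
    total_bytes = sum(t.element_size() * t.nelement() for t in tensors)
    return total_bytes > DEFAULT_MAX_MSG_SIZE


print(needs_streaming([torch.randn(16, 1024)]))    # small request: unary rpc_forward
print(needs_streaming([torch.randn(8192, 8192)]))  # large request: rpc_forward_stream
```
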
+ 28 - 24
hivemind/moe/client/moe.py

@@ -1,20 +1,21 @@
from __future__ import annotations

import time
+from concurrent.futures import Future
from queue import Empty, Queue
from typing import Any, Dict, List, Optional, Tuple

-import grpc
import torch
import torch.nn as nn
from torch.autograd.function import once_differentiable

-from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.compression import serialize_torch_tensor
from hivemind.dht import DHT
from hivemind.moe.client.beam_search import MoEBeamSearcher
-from hivemind.moe.client.expert import DUMMY, RemoteExpert, _get_expert_stub
+from hivemind.moe.client.expert import DUMMY, RemoteExpert, expert_backward, expert_forward, get_expert_stub
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.moe.server.expert_uid import UID_DELIMITER
-from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
+from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError
from hivemind.utils import nested_flatten, nested_map, nested_pack
from hivemind.utils.logging import get_logger

@@ -104,7 +105,7 @@ class RemoteMixtureOfExperts(nn.Module):
                    "No responding experts found during beam search. Check that UID prefixes and "
                    "the grid size are consistent with running Server instances."
                )
-            except grpc.RpcError as e:
+            except P2PDaemonError as e:
                logger.warning(f"Failed to get RemoteMixtureOfExperts.output_shape: {e}")

        expert_mask, *expert_outputs = _RemoteCallMany.apply(
@@ -178,7 +179,7 @@ class RemoteMixtureOfExperts(nn.Module):
            # grab some expert to set ensemble output shape
            proj_device = self.proj.weight.device
            dummy_scores_concat = self.proj(torch.randn(1, self.proj.in_features, device=proj_device))
-            dummy_scores = dummy_scores_concat.cpu().split_with_sizes(self.beam_search.grid_size, dim=-1)
+            dummy_scores = dummy_scores_concat.cpu().detach().split_with_sizes(self.beam_search.grid_size, dim=-1)
            dummy_experts = self.beam_search.find_best_experts(dummy_scores, beam_size=1)
            self._expert_info = dummy_experts[0].info
        return self._expert_info
@@ -223,15 +224,18 @@ class _RemoteCallMany(torch.autograd.Function):
        assert len(experts_per_sample) == len(flat_inputs_per_sample) == num_samples

        # dispatch tasks to all remote experts collect responses
-        pending_tasks: Dict[grpc.Future, Tuple[int, int]] = {}
+        pending_tasks: Dict[Future, Tuple[int, int]] = {}
        for i in range(num_samples):
            for j, expert in enumerate(experts_per_sample[i]):
-                input_tensors = [
+                stub = get_expert_stub(expert.p2p, expert.server_peer_info)
+                serialized_tensors = (
                    serialize_torch_tensor(tensor, proto.compression)
                    for tensor, proto in zip(flat_inputs_per_sample[i], nested_flatten(info["forward_schema"]))
-                ]
-                stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint)
-                new_task = stub.forward.future(runtime_pb2.ExpertRequest(uid=expert.uid, tensors=input_tensors))
+                )
+                new_task = RemoteExpertWorker.run_coroutine(
+                    expert_forward(expert.uid, flat_inputs_per_sample[i], serialized_tensors, stub),
+                    return_future=True,
+                )
                pending_tasks[new_task] = (i, j)

        responded_inds, alive_flat_outputs = cls._collect_responses(
@@ -316,14 +320,16 @@ class _RemoteCallMany(torch.autograd.Function):
        for i, j, inputs_ij, grad_outputs_ij in zip(
            alive_ii.cpu().numpy(), alive_jj.cpu().numpy(), inputs_per_expert, grad_outputs_per_expert
        ):
-            expert = expert_per_sample[i.item()][j.item()]
-            stub = _get_expert_stub(expert.endpoint)
+            expert: RemoteExpert = expert_per_sample[i.item()][j.item()]
+            stub = get_expert_stub(expert.p2p, expert.server_peer_info)
            inputs_and_grad_outputs = tuple(nested_flatten((inputs_ij, grad_outputs_ij)))
-            tensors_serialized = [
+            serialized_tensors = (
                serialize_torch_tensor(tensor, proto.compression)
                for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
-            ]
-            new_task = stub.backward.future(runtime_pb2.ExpertRequest(uid=expert.uid, tensors=tensors_serialized))
+            )
+            new_task = RemoteExpertWorker.run_coroutine(
+                expert_backward(expert.uid, inputs_and_grad_outputs, serialized_tensors, stub), return_future=True
+            )
            pending_tasks[new_task] = (i, j)

        survivor_inds, survivor_grad_inputs = cls._collect_responses(
@@ -358,7 +364,7 @@ class _RemoteCallMany(torch.autograd.Function):

    @staticmethod
    def _collect_responses(
-        task_to_indices: Dict[grpc.Future, Tuple[int, int]],
+        task_to_indices: Dict[Future, Tuple[int, int]],
        num_samples: int,
        k_min: int,
        timeout_total: Optional[float],
@@ -408,17 +414,15 @@ class _RemoteCallMany(torch.autograd.Function):
        return finished_indices, finished_outputs


-def _process_dispatched_task(task: grpc.Future, detect_anomalies: bool) -> Optional[Tuple[torch.Tensor]]:
+def _process_dispatched_task(task: Future, detect_anomalies: bool) -> Optional[Tuple[torch.Tensor]]:
    if task.exception() or task.cancelled():
        logger.warning(f"Task {task} failed: {type(task.exception())}")
        return None

-    deserialized_outputs = []
-    for tensor in task.result().tensors:
-        deserialized_tensor = deserialize_torch_tensor(tensor)
-        if detect_anomalies and not deserialized_tensor.isfinite().all():
+    outputs = task.result()
+    for tensor in outputs:
+        if detect_anomalies and not tensor.isfinite().all():
             logger.error(f"Task {task} failed: output tensor contains nan/inf values")
             logger.error(f"Task {task} failed: output tensor contains nan/inf values")
             return None
             return None
-        deserialized_outputs.append(deserialized_tensor)
 
 
-    return tuple(deserialized_outputs)
+    return outputs

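Example (not part of the diff): after this change, a forward pass through the mixture dispatches each sample's top-k experts over libp2p via futures returned by RemoteExpertWorker. A hedged usage sketch; the multiaddr, expert prefix, and grid size below are placeholders, and the constructor arguments are abbreviated and may differ from the full signature.

```python
import torch
import hivemind

# Placeholder multiaddr of an existing server's DHT node (e.g. copied from its startup logs).
initial_peers = ["/ip4/192.0.2.10/tcp/31337/p2p/QmServerPeerIdPlaceholder"]

dht = hivemind.DHT(initial_peers=initial_peers, start=True)
moe = hivemind.RemoteMixtureOfExperts(
    in_features=512, grid_size=(4, 4), dht=dht, uid_prefix="myprefix.", k_best=4
)
out = moe(torch.randn(2, 512))  # each sample is dispatched to up to k_best experts over libp2p
```
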
+ 48 - 0
hivemind/moe/client/remote_expert_worker.py

@@ -0,0 +1,48 @@
+import os
+from concurrent.futures import Future
+from queue import Queue
+from threading import Thread
+from typing import Awaitable, Optional
+
+from hivemind.utils import switch_to_uvloop
+
+
+class RemoteExpertWorker:
+    """Local thread for managing async tasks related to RemoteExpert"""
+
+    _task_queue: Queue = Queue()
+    _event_thread: Optional[Thread] = None
+    _pid: int = -1
+
+    @classmethod
+    def _run(cls):
+        loop = switch_to_uvloop()
+
+        async def receive_tasks():
+            while True:
+                cor, future = cls._task_queue.get()
+                try:
+                    result = await cor
+                except Exception as e:
+                    future.set_exception(e)
+                    continue
+                if not future.cancelled():
+                    future.set_result(result)
+
+        loop.run_until_complete(receive_tasks())
+
+    @classmethod
+    def run_coroutine(cls, coro: Awaitable, return_future: bool = False):
+        if cls._event_thread is None or cls._pid != os.getpid():
+            cls._pid = os.getpid()
+            cls._event_thread = Thread(target=cls._run, daemon=True)
+            cls._event_thread.start()
+
+        future = Future()
+        cls._task_queue.put((coro, future))
+
+        if return_future:
+            return future
+
+        result = future.result()
+        return result

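Example (not part of the diff): the worker above lets synchronous client code drive libp2p coroutines on a single background event loop. A small self-contained sketch with a stand-in coroutine instead of a real p2p call:

```python
import asyncio

from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker


async def add(a: int, b: int) -> int:
    await asyncio.sleep(0.1)  # stand-in for an async p2p call such as expert_forward
    return a + b


assert RemoteExpertWorker.run_coroutine(add(2, 3)) == 5  # blocks until the coroutine finishes
future = RemoteExpertWorker.run_coroutine(add(4, 5), return_future=True)  # concurrent.futures.Future
assert future.result(timeout=5) == 9
```
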
+ 2 - 2
hivemind/moe/client/switch_moe.py

@@ -2,12 +2,12 @@ from __future__ import annotations
 
 
 from typing import List, Tuple
 from typing import List, Tuple
 
 
-import grpc
 import torch
 import torch
 
 
 from hivemind.moe.client.expert import DUMMY, RemoteExpert
 from hivemind.moe.client.expert import DUMMY, RemoteExpert
 from hivemind.moe.client.moe import RemoteMixtureOfExperts, _RemoteCallMany
 from hivemind.moe.client.moe import RemoteMixtureOfExperts, _RemoteCallMany
 from hivemind.moe.server.expert_uid import UID_DELIMITER
 from hivemind.moe.server.expert_uid import UID_DELIMITER
+from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError
 from hivemind.utils import nested_flatten, nested_pack
 from hivemind.utils import nested_flatten, nested_pack
 from hivemind.utils.logging import get_logger
 from hivemind.utils.logging import get_logger
 
 
@@ -110,7 +110,7 @@ class RemoteSwitchMixtureOfExperts(RemoteMixtureOfExperts):
                     "No responding experts found during beam search. Check that UID prefixes and "
                     "No responding experts found during beam search. Check that UID prefixes and "
                     "the grid size are consistent with running Server instances."
                     "the grid size are consistent with running Server instances."
                 )
                 )
-            except grpc.RpcError as e:
+            except P2PDaemonError as e:
                 logger.warning(f"Failed to get RemoteSwitchMixtureOfExperts.output_shape: {e}")
                 logger.warning(f"Failed to get RemoteSwitchMixtureOfExperts.output_shape: {e}")
 
 
         expert_mask, *expert_outputs = _RemoteCallMany.apply(
         expert_mask, *expert_outputs = _RemoteCallMany.apply(

+ 102 - 48
hivemind/moe/server/connection_handler.py

@@ -1,82 +1,136 @@
+import asyncio
 import multiprocessing as mp
 import multiprocessing as mp
-import os
-from typing import Dict
+from typing import AsyncIterator, Dict, Iterable, List, Optional, Tuple, Union
 
 
-import grpc
 import torch
 import torch
 
 
-from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.compression import deserialize_tensor_stream, deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.dht import DHT
 from hivemind.moe.server.expert_backend import ExpertBackend
 from hivemind.moe.server.expert_backend import ExpertBackend
-from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
-from hivemind.utils import Endpoint, MSGPackSerializer, get_logger, nested_flatten
-from hivemind.utils.asyncio import switch_to_uvloop
-from hivemind.utils.grpc import GRPC_KEEPALIVE_OPTIONS
+from hivemind.moe.server.task_pool import TaskPool
+from hivemind.p2p import P2PContext, ServicerBase
+from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE, P2P
+from hivemind.proto import runtime_pb2
+from hivemind.utils import MPFuture, MSGPackSerializer, as_aiter, get_logger, nested_flatten
+from hivemind.utils.asyncio import amap_in_executor, switch_to_uvloop
+from hivemind.utils.streaming import split_for_streaming
+from hivemind.utils.tensor_descr import BatchTensorDescriptor
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 
 
 
-class ConnectionHandler(mp.context.ForkProcess):
+class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
     """
     """
     A process that accepts incoming requests to experts and submits them into the corresponding TaskPool.
     A process that accepts incoming requests to experts and submits them into the corresponding TaskPool.
 
 
-    :note: ConnectionHandler is designed so as to allow using multiple handler processes for the same port.
-    :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
+    :note: ConnectionHandler is designed so as to allow using multiple handler processes for the same port
+    :param dht: a running hivemind.dht.DHT, used to let other peers connect to this one
     :param experts: a dict [UID -> ExpertBackend] with all active experts
     :param experts: a dict [UID -> ExpertBackend] with all active experts
     """
     """
 
 
-    def __init__(self, listen_on: Endpoint, experts: Dict[str, ExpertBackend]):
+    def __init__(self, dht: DHT, experts: Dict[str, ExpertBackend]):
         super().__init__()
         super().__init__()
-        self.listen_on, self.experts = listen_on, experts
-        self.ready = mp.Event()
+        self.dht, self.experts = dht, experts
+        self._p2p: Optional[P2P] = None
+
+        self.ready = MPFuture()
 
 
     def run(self):
     def run(self):
         torch.set_num_threads(1)
         torch.set_num_threads(1)
         loop = switch_to_uvloop()
         loop = switch_to_uvloop()
 
 
         async def _run():
         async def _run():
-            grpc.aio.init_grpc_aio()
-            logger.debug(f"Starting, pid {os.getpid()}")
-            server = grpc.aio.server(
-                options=GRPC_KEEPALIVE_OPTIONS
-                + (
-                    ("grpc.so_reuseport", 1),
-                    ("grpc.max_send_message_length", -1),
-                    ("grpc.max_receive_message_length", -1),
-                )
-            )
-            runtime_grpc.add_ConnectionHandlerServicer_to_server(self, server)
-
-            found_port = server.add_insecure_port(self.listen_on)
-            assert found_port != 0, f"Failed to listen to {self.listen_on}"
-
-            await server.start()
-            self.ready.set()
-            await server.wait_for_termination()
-            logger.debug(f"ConnectionHandler terminated: (pid={os.getpid()})")
+            try:
+                self._p2p = await self.dht.replicate_p2p()
+                await self.add_p2p_handlers(self._p2p, balanced=True)
+
+                # wait forever
+                await asyncio.Future()
+
+            except Exception as e:
+                self.ready.set_exception(e)
+                return
+
+        self.ready.set_result(None)
 
 
         try:
         try:
             loop.run_until_complete(_run())
             loop.run_until_complete(_run())
         except KeyboardInterrupt:
         except KeyboardInterrupt:
             logger.debug("Caught KeyboardInterrupt, shutting down")
             logger.debug("Caught KeyboardInterrupt, shutting down")
 
 
-    async def info(self, request: runtime_pb2.ExpertUID, context: grpc.ServicerContext):
+    async def rpc_info(self, request: runtime_pb2.ExpertUID, context: P2PContext) -> runtime_pb2.ExpertInfo:
         return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(self.experts[request.uid].get_info()))
         return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(self.experts[request.uid].get_info()))
 
 
-    async def forward(self, request: runtime_pb2.ExpertRequest, context: grpc.ServicerContext):
-        inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-        future = self.experts[request.uid].forward_pool.submit_task(*inputs)
-        serialized_response = [
-            serialize_torch_tensor(tensor, proto.compression, allow_inplace=True)
-            for tensor, proto in zip(await future, nested_flatten(self.experts[request.uid].outputs_schema))
+    async def _gather_inputs(
+        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
+    ) -> Tuple[str, List[torch.Tensor]]:
+        expert_uid = None
+
+        def _unpack(req: runtime_pb2.ExpertRequest) -> Iterable[runtime_pb2.Tensor]:
+            nonlocal expert_uid
+
+            if expert_uid is None:
+                expert_uid = req.uid
+            elif expert_uid != req.uid:
+                raise ValueError("Expert uids differ in one request")
+
+            return req.tensors
+
+        tensors_stream = amap_in_executor(_unpack, requests)
+        inputs = await deserialize_tensor_stream(tensors_stream)
+        return expert_uid, inputs
+
+    async def _process_inputs(
+        self,
+        inputs: List[torch.Tensor],
+        pool: TaskPool,
+        schema: Union[BatchTensorDescriptor, Tuple[BatchTensorDescriptor, ...]],
+    ) -> List[runtime_pb2.Tensor]:
+        return [
+            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+            for result, proto in zip(await pool.submit_task(*inputs), nested_flatten(schema))
         ]
         ]
 
 
-        return runtime_pb2.ExpertResponse(tensors=serialized_response)
+    async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
+        inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+        expert = self.experts[request.uid]
+        return runtime_pb2.ExpertResponse(
+            tensors=await self._process_inputs(inputs, expert.forward_pool, expert.outputs_schema)
+        )
+
+    async def rpc_forward_stream(
+        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
+    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
+        uid, inputs = await self._gather_inputs(requests, context)
+        expert = self.experts[uid]
+        output_split = [
+            part
+            for tensor in await self._process_inputs(inputs, expert.forward_pool, expert.outputs_schema)
+            for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
+        ]
 
 
-    async def backward(self, request: runtime_pb2.ExpertRequest, context: grpc.ServicerContext):
-        inputs_and_grad_outputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-        future = self.experts[request.uid].backward_pool.submit_task(*inputs_and_grad_outputs)
-        serialized_response = [
-            serialize_torch_tensor(tensor, proto.compression, allow_inplace=True)
-            for tensor, proto in zip(await future, nested_flatten(self.experts[request.uid].grad_inputs_schema))
+        async for part in as_aiter(*output_split):
+            yield runtime_pb2.ExpertResponse(tensors=[part])
+
+    async def rpc_backward(
+        self, request: runtime_pb2.ExpertRequest, context: P2PContext
+    ) -> runtime_pb2.ExpertResponse:
+        inputs_and_grads = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+        expert = self.experts[request.uid]
+        return runtime_pb2.ExpertResponse(
+            tensors=await self._process_inputs(inputs_and_grads, expert.backward_pool, expert.grad_inputs_schema)
+        )
+
+    async def rpc_backward_stream(
+        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
+    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
+        uid, inputs_and_grads = await self._gather_inputs(requests, context)
+        expert = self.experts[uid]
+        output_split = [
+            part
+            for tensor in await self._process_inputs(inputs_and_grads, expert.backward_pool, expert.grad_inputs_schema)
+            for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
         ]
         ]
-        return runtime_pb2.ExpertResponse(tensors=serialized_response)
+
+        async for part in as_aiter(*output_split):
+            yield runtime_pb2.ExpertResponse(tensors=[part])

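Example (not part of the diff): the streaming RPCs above chop each serialized tensor into parts no larger than DEFAULT_MAX_MSG_SIZE and reassemble them on the receiving side. A rough round-trip sketch; the helper signatures are inferred from their usage in this file:

```python
import torch

from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
from hivemind.utils.streaming import combine_from_streaming, split_for_streaming

tensor = torch.randn(2048, 2048)  # large enough to require several chunks
serialized = serialize_torch_tensor(tensor)
parts = list(split_for_streaming(serialized, DEFAULT_MAX_MSG_SIZE))  # one protobuf part per chunk
restored = deserialize_torch_tensor(combine_from_streaming(parts))
assert torch.allclose(tensor, restored)
```
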
+ 22 - 20
hivemind/moe/server/dht_handler.py

@@ -1,9 +1,9 @@
 import threading
 import threading
 from functools import partial
 from functools import partial
-from typing import Dict, List, Optional, Sequence, Tuple
+from typing import Dict, List, Optional, Sequence, Tuple, Union
 
 
 from hivemind.dht import DHT, DHTExpiration, DHTNode, DHTValue
 from hivemind.dht import DHT, DHTExpiration, DHTNode, DHTValue
-from hivemind.moe.client.expert import RemoteExpert
+from hivemind.moe.client.expert import RemoteExpert, RemoteExpertInfo, create_remote_experts
 from hivemind.moe.server.expert_uid import (
 from hivemind.moe.server.expert_uid import (
     FLAT_EXPERT,
     FLAT_EXPERT,
     UID_DELIMITER,
     UID_DELIMITER,
@@ -14,33 +14,31 @@ from hivemind.moe.server.expert_uid import (
     is_valid_uid,
     is_valid_uid,
     split_uid,
     split_uid,
 )
 )
-from hivemind.utils import Endpoint, get_dht_time, get_port
+from hivemind.p2p import PeerID, PeerInfo
+from hivemind.utils import MPFuture, get_dht_time
 
 
 
 
 class DHTHandlerThread(threading.Thread):
 class DHTHandlerThread(threading.Thread):
-    def __init__(self, experts, dht: DHT, endpoint: Endpoint, update_period: int = 5, **kwargs):
+    def __init__(self, experts, dht: DHT, update_period: int = 5, **kwargs):
         super().__init__(**kwargs)
         super().__init__(**kwargs)
-        assert get_port(endpoint) is not None
-        self.endpoint = endpoint
         self.experts = experts
         self.experts = experts
         self.dht = dht
         self.dht = dht
         self.update_period = update_period
         self.update_period = update_period
         self.stop = threading.Event()
         self.stop = threading.Event()
 
 
     def run(self) -> None:
     def run(self) -> None:
-        declare_experts(self.dht, self.experts.keys(), self.endpoint)
+        declare_experts(self.dht, self.experts.keys())
         while not self.stop.wait(self.update_period):
         while not self.stop.wait(self.update_period):
-            declare_experts(self.dht, self.experts.keys(), self.endpoint)
+            declare_experts(self.dht, self.experts.keys())
 
 
 
 
 def declare_experts(
 def declare_experts(
-    dht: DHT, uids: Sequence[ExpertUID], endpoint: Endpoint, expiration: DHTExpiration = 300, wait: bool = True
-) -> Dict[ExpertUID, bool]:
+    dht: DHT, uids: Sequence[ExpertUID], expiration: DHTExpiration = 300, wait: bool = True
+) -> Union[Dict[ExpertUID, bool], MPFuture[Dict[ExpertUID, bool]]]:
     """
     """
     Make experts visible to all DHT peers; update timestamps if declared previously.
     Make experts visible to all DHT peers; update timestamps if declared previously.
 
 
     :param uids: a list of expert ids to update
     :param uids: a list of expert ids to update
-    :param endpoint: endpoint that serves these experts, usually your server endpoint (e.g. "201.111.222.333:1337")
     :param wait: if True, awaits for declaration to finish, otherwise runs in background
     :param wait: if True, awaits for declaration to finish, otherwise runs in background
     :param expiration: experts will be visible for this many seconds
     :param expiration: experts will be visible for this many seconds
     :returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
     :returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
@@ -48,23 +46,25 @@ def declare_experts(
     assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
     assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
     for uid in uids:
     for uid in uids:
         assert is_valid_uid(uid), f"{uid} is not a valid expert uid. All uids must follow {UID_PATTERN.pattern}"
         assert is_valid_uid(uid), f"{uid} is not a valid expert uid. All uids must follow {UID_PATTERN.pattern}"
+    addrs = tuple(str(a.decapsulate("/p2p/" + a.get("p2p"))) for a in dht.get_visible_maddrs())
     return dht.run_coroutine(
     return dht.run_coroutine(
-        partial(_declare_experts, uids=list(uids), endpoint=endpoint, expiration=expiration), return_future=not wait
+        partial(_declare_experts, uids=list(uids), peer_id=dht.peer_id, addrs=addrs, expiration=expiration),
+        return_future=not wait,
     )
     )
 
 
 
 
 async def _declare_experts(
 async def _declare_experts(
-    dht: DHT, node: DHTNode, uids: List[ExpertUID], endpoint: Endpoint, expiration: DHTExpiration
+    dht: DHT, node: DHTNode, uids: List[ExpertUID], peer_id: PeerID, addrs: Tuple[str], expiration: DHTExpiration
 ) -> Dict[ExpertUID, bool]:
 ) -> Dict[ExpertUID, bool]:
     num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
     num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
     expiration_time = get_dht_time() + expiration
     expiration_time = get_dht_time() + expiration
     data_to_store: Dict[Tuple[ExpertPrefix, Optional[Coordinate]], DHTValue] = {}
     data_to_store: Dict[Tuple[ExpertPrefix, Optional[Coordinate]], DHTValue] = {}
     for uid in uids:
     for uid in uids:
-        data_to_store[uid, None] = endpoint
+        data_to_store[uid, None] = (peer_id.to_base58(), addrs)
         prefix = uid if uid.count(UID_DELIMITER) > 1 else f"{uid}{UID_DELIMITER}{FLAT_EXPERT}"
         prefix = uid if uid.count(UID_DELIMITER) > 1 else f"{uid}{UID_DELIMITER}{FLAT_EXPERT}"
         for i in range(prefix.count(UID_DELIMITER) - 1):
         for i in range(prefix.count(UID_DELIMITER) - 1):
             prefix, last_coord = split_uid(prefix)
             prefix, last_coord = split_uid(prefix)
-            data_to_store[prefix, last_coord] = [uid, endpoint]
+            data_to_store[prefix, last_coord] = [uid, (peer_id.to_base58(), addrs)]
 
 
     keys, maybe_subkeys, values = zip(*((key, subkey, value) for (key, subkey), value in data_to_store.items()))
     keys, maybe_subkeys, values = zip(*((key, subkey, value) for (key, subkey), value in data_to_store.items()))
     store_ok = await node.store_many(keys, values, expiration_time, subkeys=maybe_subkeys, num_workers=num_workers)
     store_ok = await node.store_many(keys, values, expiration_time, subkeys=maybe_subkeys, num_workers=num_workers)
@@ -73,7 +73,7 @@ async def _declare_experts(
 
 
 def get_experts(
 def get_experts(
     dht: DHT, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration] = None, return_future: bool = False
     dht: DHT, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration] = None, return_future: bool = False
-) -> List[Optional[RemoteExpert]]:
+) -> Union[List[Optional[RemoteExpert]], MPFuture[List[Optional[RemoteExpert]]]]:
     """
     """
     :param uids: find experts with these ids from across the DHT
     :param uids: find experts with these ids from across the DHT
     :param expiration_time: if specified, return experts that expire no sooner than this (based on get_dht_time)
     :param expiration_time: if specified, return experts that expire no sooner than this (based on get_dht_time)
@@ -81,12 +81,13 @@ def get_experts(
     :returns: a list of [RemoteExpert if found else None]
     :returns: a list of [RemoteExpert if found else None]
     """
     """
     assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
     assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
-    return dht.run_coroutine(partial(_get_experts, uids=list(uids), expiration_time=expiration_time), return_future)
+    result = dht.run_coroutine(partial(_get_experts, uids=list(uids), expiration_time=expiration_time), return_future)
+    return create_remote_experts(result, dht, return_future)
 
 
 
 
 async def _get_experts(
 async def _get_experts(
     dht: DHT, node: DHTNode, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration]
     dht: DHT, node: DHTNode, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration]
-) -> List[Optional[RemoteExpert]]:
+) -> List[Optional[RemoteExpertInfo]]:
     if expiration_time is None:
     if expiration_time is None:
         expiration_time = get_dht_time()
         expiration_time = get_dht_time()
     num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
     num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
@@ -94,6 +95,7 @@ async def _get_experts(
 
 
     experts: List[Optional[RemoteExpert]] = [None] * len(uids)
     experts: List[Optional[RemoteExpert]] = [None] * len(uids)
     for i, uid in enumerate(uids):
     for i, uid in enumerate(uids):
-        if found[uid] is not None and isinstance(found[uid].value, Endpoint):
-            experts[i] = RemoteExpert(uid, found[uid].value)
+        expert_info_for_uid = found[uid]
+        if expert_info_for_uid is not None and isinstance(expert_info_for_uid.value, tuple):
+            experts[i] = RemoteExpertInfo(uid, PeerInfo.from_tuple(expert_info_for_uid.value))
     return experts
     return experts

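Example (not part of the diff): with the endpoint argument gone, a server process simply announces its expert uids and the DHT stores its (peer_id, multiaddrs) tuple. A sketch on a single standalone DHT node; the uids are made up:

```python
import hivemind
from hivemind.moe.server.dht_handler import declare_experts, get_experts

dht = hivemind.DHT(start=True)
declare_experts(dht, ["ffn_expert.0.0", "ffn_expert.0.1"], expiration=300)

found = get_experts(dht, ["ffn_expert.0.0", "ffn_expert.9.9"])
assert found[0] is not None and found[0].uid == "ffn_expert.0.0"
assert found[1] is None  # this uid was never declared
```
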
+ 2 - 2
hivemind/moe/server/expert_uid.py

@@ -1,10 +1,10 @@
 import re
 import re
 from typing import NamedTuple, Tuple, Union
 from typing import NamedTuple, Tuple, Union
 
 
-from hivemind.utils import Endpoint
+from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerInfo
 
 
 ExpertUID, ExpertPrefix, Coordinate, Score = str, str, int, float
 ExpertUID, ExpertPrefix, Coordinate, Score = str, str, int, float
-UidEndpoint = NamedTuple("UidEndpoint", [("uid", ExpertUID), ("endpoint", Endpoint)])
+UidEndpoint = NamedTuple("UidEndpoint", [("uid", ExpertUID), ("peer_info", PeerInfo)])
 UID_DELIMITER = "."  # when declaring experts, DHT store all prefixes of that expert's uid, split over this prefix
 UID_DELIMITER = "."  # when declaring experts, DHT store all prefixes of that expert's uid, split over this prefix
 FLAT_EXPERT = -1  # grid prefix reserved for storing 1d expert uids. Used to speed up find_best_experts in 1d case.
 FLAT_EXPERT = -1  # grid prefix reserved for storing 1d expert uids. Used to speed up find_best_experts in 1d case.
 UID_PATTERN = re.compile("^(([^.])+)([.](?:[0]|([1-9]([0-9]*))))+$")  # e.g. ffn_expert.98.76.54 - prefix + some dims
 UID_PATTERN = re.compile("^(([^.])+)([.](?:[0]|([1-9]([0-9]*))))+$")  # e.g. ffn_expert.98.76.54 - prefix + some dims

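Example (not part of the diff): the uid grammar encoded by UID_PATTERN is a prefix followed by one or more dot-separated non-negative integers without leading zeros, e.g.:

```python
from hivemind.moe.server.expert_uid import UID_PATTERN

assert UID_PATTERN.match("ffn_expert.98.76.54")
assert UID_PATTERN.match("expert.0")
assert not UID_PATTERN.match("expert")     # at least one coordinate is required
assert not UID_PATTERN.match("expert.01")  # no leading zeros except a bare "0"
```
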
+ 26 - 40
hivemind/moe/server/server.py

@@ -24,9 +24,9 @@ from hivemind.moe.server.layers import (
     schedule_name_to_scheduler,
     schedule_name_to_scheduler,
 )
 )
 from hivemind.moe.server.runtime import Runtime
 from hivemind.moe.server.runtime import Runtime
+from hivemind.p2p import PeerInfo
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.logging import get_logger
 from hivemind.utils.logging import get_logger
-from hivemind.utils.networking import Endpoint, get_free_port, get_port, replace_port
 from hivemind.utils.tensor_descr import DUMMY_BATCH_SIZE, BatchTensorDescriptor
 from hivemind.utils.tensor_descr import DUMMY_BATCH_SIZE, BatchTensorDescriptor
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)
@@ -41,10 +41,8 @@ class Server(threading.Thread):
      - processes incoming forward/backward requests via Runtime (created by the server)
      - processes incoming forward/backward requests via Runtime (created by the server)
      - publishes updates to expert status every :update_period: seconds
      - publishes updates to expert status every :update_period: seconds
 
 
-    :type dht: DHT or None. Server with dht=None will NOT be visible from DHT,
-     but it will still support accessing experts directly with RemoteExpert(uid=UID, endpoint="IPADDR:PORT").
+    :type dht: an instance of hivemind.DHT. Server will use DHT for all network interactions.
     :param expert_backends: dict{expert uid (str) : ExpertBackend} for all expert hosted by this server.
     :param expert_backends: dict{expert uid (str) : ExpertBackend} for all expert hosted by this server.
-    :param listen_on: server's dht address that determines how it can be accessed. Address and (optional) port
     :param num_connection_handlers: maximum number of simultaneous requests. Please note that the default value of 1
     :param num_connection_handlers: maximum number of simultaneous requests. Please note that the default value of 1
         if too small for normal functioning, we recommend 4 handlers per expert backend.
         if too small for normal functioning, we recommend 4 handlers per expert backend.
     :param update_period: how often will server attempt to publish its state (i.e. experts) to the DHT;
     :param update_period: how often will server attempt to publish its state (i.e. experts) to the DHT;
@@ -55,9 +53,8 @@ class Server(threading.Thread):
 
 
     def __init__(
     def __init__(
         self,
         self,
-        dht: Optional[DHT],
+        dht: DHT,
         expert_backends: Dict[str, ExpertBackend],
         expert_backends: Dict[str, ExpertBackend],
-        listen_on: Endpoint = "0.0.0.0:*",
         num_connection_handlers: int = 1,
         num_connection_handlers: int = 1,
         update_period: int = 30,
         update_period: int = 30,
         start=False,
         start=False,
@@ -66,22 +63,18 @@ class Server(threading.Thread):
     ):
     ):
         super().__init__()
         super().__init__()
         self.dht, self.experts, self.update_period = dht, expert_backends, update_period
         self.dht, self.experts, self.update_period = dht, expert_backends, update_period
-        if get_port(listen_on) is None:
-            listen_on = replace_port(listen_on, new_port=get_free_port())
-        self.listen_on, self.port = listen_on, get_port(listen_on)
 
 
-        self.conn_handlers = [ConnectionHandler(listen_on, self.experts) for _ in range(num_connection_handlers)]
+        self.conn_handlers = [ConnectionHandler(dht, self.experts) for _ in range(num_connection_handlers)]
         if checkpoint_dir is not None:
         if checkpoint_dir is not None:
             self.checkpoint_saver = CheckpointSaver(expert_backends, checkpoint_dir, update_period)
             self.checkpoint_saver = CheckpointSaver(expert_backends, checkpoint_dir, update_period)
         else:
         else:
             self.checkpoint_saver = None
             self.checkpoint_saver = None
         self.runtime = Runtime(self.experts, **kwargs)
         self.runtime = Runtime(self.experts, **kwargs)
 
 
-        if self.dht and self.experts:
+        if self.experts:
             self.dht_handler_thread = DHTHandlerThread(
             self.dht_handler_thread = DHTHandlerThread(
                 experts=self.experts,
                 experts=self.experts,
                 dht=self.dht,
                 dht=self.dht,
-                endpoint=self.listen_on,
                 update_period=self.update_period,
                 update_period=self.update_period,
                 daemon=True,
                 daemon=True,
             )
             )
@@ -92,7 +85,6 @@ class Server(threading.Thread):
     @classmethod
     @classmethod
     def create(
     def create(
         cls,
         cls,
-        listen_on="0.0.0.0:*",
         num_experts: int = None,
         num_experts: int = None,
         expert_uids: str = None,
         expert_uids: str = None,
         expert_pattern: str = None,
         expert_pattern: str = None,
@@ -107,7 +99,6 @@ class Server(threading.Thread):
         min_batch_size=1,
         min_batch_size=1,
         max_batch_size=4096,
         max_batch_size=4096,
         device=None,
         device=None,
-        no_dht=False,
         initial_peers=(),
         initial_peers=(),
         checkpoint_dir: Optional[Path] = None,
         checkpoint_dir: Optional[Path] = None,
         compression=CompressionType.NONE,
         compression=CompressionType.NONE,
@@ -115,10 +106,11 @@ class Server(threading.Thread):
         custom_module_path=None,
         custom_module_path=None,
         *,
         *,
         start: bool,
         start: bool,
+        **kwargs,
     ) -> Server:
     ) -> Server:
         """
         """
         Instantiate a server with several identical experts. See argparse comments below for details
         Instantiate a server with several identical experts. See argparse comments below for details
-        :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
+
         :param num_experts: run this many identical experts
         :param num_experts: run this many identical experts
         :param expert_pattern: a string pattern or a list of expert uids,  example: myprefix.[0:32].[0:256]\
         :param expert_pattern: a string pattern or a list of expert uids,  example: myprefix.[0:32].[0:256]\
            means "sample random experts between myprefix.0.0 and myprefix.255.255;
            means "sample random experts between myprefix.0.0 and myprefix.255.255;
@@ -136,7 +128,6 @@ class Server(threading.Thread):
         :param num_total_steps: the total number of steps for LR schedule
         :param num_total_steps: the total number of steps for LR schedule
         :param clip_grad_norm: maximum gradient norm used for clipping
         :param clip_grad_norm: maximum gradient norm used for clipping
 
 
-        :param no_dht: if specified, the server will not be attached to a dht
         :param initial_peers: multiaddrs of one or more active DHT peers (if you want to join an existing DHT)
         :param initial_peers: multiaddrs of one or more active DHT peers (if you want to join an existing DHT)
 
 
         :param checkpoint_dir: directory to save and load expert checkpoints
         :param checkpoint_dir: directory to save and load expert checkpoints
@@ -147,17 +138,15 @@ class Server(threading.Thread):
 
 
         :param start: if True, starts server right away and returns when server is ready for requests
         :param start: if True, starts server right away and returns when server is ready for requests
         :param stats_report_interval: interval between two reports of batch processing performance statistics
         :param stats_report_interval: interval between two reports of batch processing performance statistics
+        :param kwargs: any other params will be forwarded to DHT upon creation
         """
         """
         if custom_module_path is not None:
         if custom_module_path is not None:
             add_custom_models_from_file(custom_module_path)
             add_custom_models_from_file(custom_module_path)
         assert expert_cls in name_to_block
         assert expert_cls in name_to_block
 
 
-        if no_dht:
-            dht = None
-        else:
-            dht = DHT(initial_peers=initial_peers, start=True)
-            visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
-            logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
+        dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
+        visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
+        logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
 
 
         assert (expert_pattern is None and num_experts is None and expert_uids is not None) or (
         assert (expert_pattern is None and num_experts is None and expert_uids is not None) or (
             num_experts is not None and expert_uids is None
             num_experts is not None and expert_uids is None
@@ -221,7 +210,6 @@ class Server(threading.Thread):
         return cls(
         return cls(
             dht,
             dht,
             experts,
             experts,
-            listen_on=listen_on,
             num_connection_handlers=num_handlers,
             num_connection_handlers=num_handlers,
             device=device,
             device=device,
             checkpoint_dir=checkpoint_dir,
             checkpoint_dir=checkpoint_dir,
@@ -234,25 +222,24 @@ class Server(threading.Thread):
         Starts Server in the current thread. Initializes dht if necessary, starts connection handlers,
         Starts Server in the current thread. Initializes dht if necessary, starts connection handlers,
         runs Runtime (self.runtime) to process incoming requests.
         runs Runtime (self.runtime) to process incoming requests.
         """
         """
-        logger.info(f"Server started at {self.listen_on}")
-        logger.info(f"Got {len(self.experts)} experts:")
+        logger.info(f"Server started with {len(self.experts)} experts:")
         for expert_name, backend in self.experts.items():
         for expert_name, backend in self.experts.items():
             num_parameters = sum(p.numel() for p in backend.expert.parameters() if p.requires_grad)
             num_parameters = sum(p.numel() for p in backend.expert.parameters() if p.requires_grad)
             logger.info(f"{expert_name}: {backend.expert.__class__.__name__}, {num_parameters} parameters")
             logger.info(f"{expert_name}: {backend.expert.__class__.__name__}, {num_parameters} parameters")
 
 
-        if self.dht:
-            if not self.dht.is_alive():
-                self.dht.run_in_background(await_ready=True)
+        if not self.dht.is_alive():
+            self.dht.run_in_background(await_ready=True)
+
+        if self.experts:
+            self.dht_handler_thread.start()
 
 
-            if self.experts:
-                self.dht_handler_thread.start()
         if self.checkpoint_saver is not None:
         if self.checkpoint_saver is not None:
             self.checkpoint_saver.start()
             self.checkpoint_saver.start()
 
 
         for process in self.conn_handlers:
         for process in self.conn_handlers:
             if not process.is_alive():
             if not process.is_alive():
                 process.start()
                 process.start()
-            process.ready.wait()
+            process.ready.result()
 
 
         try:
         try:
             self.runtime.run()
             self.runtime.run()
@@ -294,7 +281,7 @@ class Server(threading.Thread):
             process.join()
             process.join()
         logger.debug("Connection handlers terminated")
         logger.debug("Connection handlers terminated")
 
 
-        if self.dht and self.experts:
+        if self.experts:
             self.dht_handler_thread.stop.set()
             self.dht_handler_thread.stop.set()
             self.dht_handler_thread.join()
             self.dht_handler_thread.join()
 
 
@@ -302,9 +289,8 @@ class Server(threading.Thread):
             self.checkpoint_saver.stop.set()
             self.checkpoint_saver.stop.set()
             self.checkpoint_saver.join()
             self.checkpoint_saver.join()
 
 
-        if self.dht is not None:
-            self.dht.shutdown()
-            self.dht.join()
+        self.dht.shutdown()
+        self.dht.join()
 
 
         logger.debug(f"Shutting down runtime")
         logger.debug(f"Shutting down runtime")
 
 
@@ -313,14 +299,14 @@ class Server(threading.Thread):
 
 
 
 
 @contextmanager
 @contextmanager
-def background_server(*args, shutdown_timeout=5, **kwargs) -> Tuple[Endpoint, List[Multiaddr]]:
-    """A context manager that creates server in a background process, awaits .ready on entry and shuts down on exit"""
+def background_server(*args, shutdown_timeout=5, **kwargs) -> PeerInfo:
+    """A context manager that creates server in a background , awaits .ready on entry and shuts down on exit"""
     pipe, runners_pipe = mp.Pipe(duplex=True)
     pipe, runners_pipe = mp.Pipe(duplex=True)
     runner = mp.Process(target=_server_runner, args=(runners_pipe, *args), kwargs=kwargs)
     runner = mp.Process(target=_server_runner, args=(runners_pipe, *args), kwargs=kwargs)
     try:
     try:
         runner.start()
         runner.start()
         # once the server is ready, runner will send us
         # once the server is ready, runner will send us
-        # either (False, exception) or (True, (server.listen_on, dht_maddrs))
+        # either (False, exception) or (True, PeerInfo(dht_peer_id, dht_maddrs))
         start_ok, data = pipe.recv()
         start_ok, data = pipe.recv()
         if start_ok:
         if start_ok:
             yield data
             yield data
@@ -344,8 +330,8 @@ def _server_runner(pipe, *args, **kwargs):
         return
         return
 
 
     try:
     try:
-        dht_maddrs = server.dht.get_visible_maddrs() if server.dht is not None else None
-        pipe.send((True, (server.listen_on, dht_maddrs)))
+        dht_maddrs = server.dht.get_visible_maddrs()
+        pipe.send((True, PeerInfo(server.dht.peer_id, dht_maddrs)))
         pipe.recv()  # wait for shutdown signal
         pipe.recv()  # wait for shutdown signal
 
 
     finally:
     finally:

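Example (not part of the diff): background_server now yields a single PeerInfo for the server's DHT node, whose multiaddrs can be passed as initial_peers to client-side DHT instances. A hedged sketch with made-up expert uids and sizes (keyword names follow Server.create above):

```python
import hivemind
from hivemind.moe.server.server import background_server

with background_server(
    expert_uids=["demo_expert.0", "demo_expert.1"], expert_cls="ffn", hidden_dim=64, num_handlers=2
) as server_peer_info:
    print("server peer id:", server_peer_info.peer_id)
    # connect a client-side DHT to the background server's node:
    client_dht = hivemind.DHT(initial_peers=server_peer_info.addrs, start=True)
```
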
+ 13 - 6
hivemind/p2p/p2p_daemon.py

@@ -341,6 +341,7 @@ class P2P:
         handler: Callable[[TInputStream, P2PContext], TOutputStream],
         handler: Callable[[TInputStream, P2PContext], TOutputStream],
         input_protobuf_type: Type[Message],
         input_protobuf_type: Type[Message],
         max_prefetch: int = 5,
         max_prefetch: int = 5,
+        balanced: bool = False,
     ) -> None:
     ) -> None:
         """
         """
         :param max_prefetch: Maximum number of items to prefetch from the request stream.
         :param max_prefetch: Maximum number of items to prefetch from the request stream.
@@ -405,7 +406,7 @@ class P2P:
                 finally:
                 finally:
                     processing_task.cancel()
                     processing_task.cancel()
 
 
-        await self.add_binary_stream_handler(name, _handle_stream)
+        await self.add_binary_stream_handler(name, _handle_stream, balanced=balanced)
 
 
     async def _iterate_protobuf_stream_handler(
     async def _iterate_protobuf_stream_handler(
         self, peer_id: PeerID, name: str, requests: TInputStream, output_protobuf_type: Type[Message]
         self, peer_id: PeerID, name: str, requests: TInputStream, output_protobuf_type: Type[Message]
@@ -447,16 +448,19 @@ class P2P:
         *,
         *,
         stream_input: bool = False,
         stream_input: bool = False,
         stream_output: bool = False,
         stream_output: bool = False,
+        balanced: bool = False,
     ) -> None:
     ) -> None:
         """
         """
         :param stream_input: If True, assume ``handler`` to take ``TInputStream``
         :param stream_input: If True, assume ``handler`` to take ``TInputStream``
                              (not just ``TInputProtobuf``) as input.
                              (not just ``TInputProtobuf``) as input.
         :param stream_output: If True, assume ``handler`` to return ``TOutputStream``
         :param stream_output: If True, assume ``handler`` to return ``TOutputStream``
                               (not ``Awaitable[TOutputProtobuf]``).
                               (not ``Awaitable[TOutputProtobuf]``).
+        :param balanced: If True, the p2p daemon will load-balance requests for this handler
+                         across all processes that registered it under the same name. Default: False
         """
         """
 
 
         if not stream_input and not stream_output:
         if not stream_input and not stream_output:
-            await self._add_protobuf_unary_handler(name, handler, input_protobuf_type)
+            await self._add_protobuf_unary_handler(name, handler, input_protobuf_type, balanced=balanced)
             return
             return
 
 
         async def _stream_handler(requests: P2P.TInputStream, context: P2PContext) -> P2P.TOutputStream:
         async def _stream_handler(requests: P2P.TInputStream, context: P2PContext) -> P2P.TOutputStream:
@@ -469,13 +473,14 @@ class P2P:
             else:
             else:
                 yield await output
                 yield await output
 
 
-        await self._add_protobuf_stream_handler(name, _stream_handler, input_protobuf_type)
+        await self._add_protobuf_stream_handler(name, _stream_handler, input_protobuf_type, balanced=balanced)
 
 
     async def _add_protobuf_unary_handler(
     async def _add_protobuf_unary_handler(
         self,
         self,
         handle_name: str,
         handle_name: str,
         handler: Callable[[TInputProtobuf, P2PContext], Awaitable[TOutputProtobuf]],
         handler: Callable[[TInputProtobuf, P2PContext], Awaitable[TOutputProtobuf]],
         input_protobuf_type: Type[Message],
         input_protobuf_type: Type[Message],
+        balanced: bool = False,
     ) -> None:
     ) -> None:
         """
         """
         Register a request-response (unary) handler. Unary requests and responses
         Register a request-response (unary) handler. Unary requests and responses
@@ -497,7 +502,7 @@ class P2P:
             response = await handler(input_serialized, context)
             response = await handler(input_serialized, context)
             return response.SerializeToString()
             return response.SerializeToString()
 
 
-        await self._client.add_unary_handler(handle_name, _unary_handler)
+        await self._client.add_unary_handler(handle_name, _unary_handler, balanced=balanced)
 
 
     async def call_protobuf_handler(
     async def call_protobuf_handler(
         self,
         self,
@@ -541,10 +546,12 @@ class P2P:
 
 
         self._listen_task = asyncio.create_task(listen())
         self._listen_task = asyncio.create_task(listen())
 
 
-    async def add_binary_stream_handler(self, name: str, handler: p2pclient.StreamHandler) -> None:
+    async def add_binary_stream_handler(
+        self, name: str, handler: p2pclient.StreamHandler, balanced: bool = False
+    ) -> None:
         if self._listen_task is None:
         if self._listen_task is None:
             self._start_listening()
             self._start_listening()
-        await self._client.stream_handler(name, handler)
+        await self._client.stream_handler(name, handler, balanced)
 
 
     async def call_binary_stream_handler(
     async def call_binary_stream_handler(
         self, peer_id: PeerID, handler_name: str
         self, peer_id: PeerID, handler_name: str

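Example (not part of the diff): the new `balanced` flag lets several processes register the same protocol with one daemon and have p2pd spread incoming requests across them, which is what ConnectionHandler relies on. A hedged sketch with an illustrative echo handler:

```python
import asyncio

from hivemind.p2p import P2P, P2PContext
from hivemind.proto import runtime_pb2


async def main():
    p2p = await P2P.create()

    async def rpc_echo(request: runtime_pb2.ExpertUID, context: P2PContext) -> runtime_pb2.ExpertUID:
        return runtime_pb2.ExpertUID(uid=request.uid)

    # With balanced=True, another process attached to the same daemon may register
    # a second replica of "/demo/echo" and the daemon will load-balance between them.
    await p2p.add_protobuf_handler("/demo/echo", rpc_echo, runtime_pb2.ExpertUID, balanced=True)
    await asyncio.sleep(1.0)  # keep the handler alive briefly for demonstration
    await p2p.shutdown()


asyncio.run(main())
```
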
+ 6 - 4
hivemind/p2p/p2p_daemon_bindings/control.py

@@ -246,10 +246,10 @@ class ControlClient:
         self._read_task = asyncio.create_task(self._read_from_persistent_conn(reader))
         self._read_task = asyncio.create_task(self._read_from_persistent_conn(reader))
         self._write_task = asyncio.create_task(self._write_to_persistent_conn(writer))
         self._write_task = asyncio.create_task(self._write_to_persistent_conn(writer))
 
 
-    async def add_unary_handler(self, proto: str, handler: TUnaryHandler):
+    async def add_unary_handler(self, proto: str, handler: TUnaryHandler, balanced: bool = False):
         call_id = uuid4()
         call_id = uuid4()
 
 
-        add_unary_handler_req = p2pd_pb.AddUnaryHandlerRequest(proto=proto)
+        add_unary_handler_req = p2pd_pb.AddUnaryHandlerRequest(proto=proto, balanced=balanced)
         req = p2pd_pb.PersistentConnectionRequest(callId=call_id.bytes, addUnaryHandler=add_unary_handler_req)
         req = p2pd_pb.PersistentConnectionRequest(callId=call_id.bytes, addUnaryHandler=add_unary_handler_req)
 
 
         if self.unary_handlers.get(proto):
         if self.unary_handlers.get(proto):
@@ -358,11 +358,13 @@ class ControlClient:
 
 
         return stream_info, reader, writer
         return stream_info, reader, writer
 
 
-    async def stream_handler(self, proto: str, handler_cb: StreamHandler) -> None:
+    async def stream_handler(self, proto: str, handler_cb: StreamHandler, balanced: bool = False) -> None:
         reader, writer = await self.daemon_connector.open_connection()
         reader, writer = await self.daemon_connector.open_connection()
 
 
         listen_path_maddr_bytes = self.listen_maddr.to_bytes()
         listen_path_maddr_bytes = self.listen_maddr.to_bytes()
-        stream_handler_req = p2pd_pb.StreamHandlerRequest(addr=listen_path_maddr_bytes, proto=[proto])
+        stream_handler_req = p2pd_pb.StreamHandlerRequest(
+            addr=listen_path_maddr_bytes, proto=[proto], balanced=balanced
+        )
         req = p2pd_pb.Request(type=p2pd_pb.Request.STREAM_HANDLER, streamHandler=stream_handler_req)
         req = p2pd_pb.Request(type=p2pd_pb.Request.STREAM_HANDLER, streamHandler=stream_handler_req)
         await write_pbmsg(writer, req)
         await write_pbmsg(writer, req)
 
 

+ 7 - 1
hivemind/p2p/p2p_daemon_bindings/datastructures.py

@@ -5,7 +5,7 @@ Author: Kevin Mai-Husan Chia
 """
 """
 
 
 import hashlib
 import hashlib
-from typing import Any, Sequence, Union
+from typing import Any, Sequence, Tuple, Union
 
 
 import base58
 import base58
 import multihash
 import multihash
@@ -128,6 +128,12 @@ class PeerInfo:
         addrs = [Multiaddr(addr) for addr in peer_info_pb.addrs]
         addrs = [Multiaddr(addr) for addr in peer_info_pb.addrs]
         return PeerInfo(peer_id, addrs)
         return PeerInfo(peer_id, addrs)
 
 
+    @classmethod
+    def from_tuple(cls, value: Tuple[str, Sequence[str]]) -> "PeerInfo":
+        peer_id = PeerID.from_base58(value[0])
+        addrs = [Multiaddr(addr) for addr in value[1]]
+        return PeerInfo(peer_id, addrs)
+
     def __str__(self):
     def __str__(self):
         return f"{self.peer_id.pretty()} {','.join(str(a) for a in self.addrs)}"
         return f"{self.peer_id.pretty()} {','.join(str(a) for a in self.addrs)}"
 
 

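Example (not part of the diff): the tuple format that declare_experts now stores in the DHT round-trips back into a PeerInfo; the peer id and address below are placeholders:

```python
from hivemind.p2p import PeerInfo

stored_value = ("QmcgpsyWgH8Y8ajJz1Cu72KnS5uo2Aa2LpzU7kinSupNKC", ("/ip4/127.0.0.1/tcp/31337",))
info = PeerInfo.from_tuple(stored_value)
print(info)  # "<peer id> /ip4/127.0.0.1/tcp/31337"
```
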
+ 5 - 4
hivemind/p2p/p2p_daemon_bindings/p2pclient.py

@@ -61,8 +61,8 @@ class Client:
         async with self.control.listen():
         async with self.control.listen():
             yield self
             yield self
 
 
-    async def add_unary_handler(self, proto: str, handler: TUnaryHandler):
-        await self.control.add_unary_handler(proto, handler)
+    async def add_unary_handler(self, proto: str, handler: TUnaryHandler, balanced: bool = False):
+        await self.control.add_unary_handler(proto, handler, balanced=balanced)
 
 
     async def call_unary_handler(self, peer_id: PeerID, proto: str, data: bytes) -> bytes:
     async def call_unary_handler(self, peer_id: PeerID, proto: str, data: bytes) -> bytes:
         return await self.control.call_unary_handler(peer_id, proto, data)
         return await self.control.call_unary_handler(peer_id, proto, data)
@@ -105,11 +105,12 @@ class Client:
         """
         """
         return await self.control.stream_open(peer_id=peer_id, protocols=protocols)
         return await self.control.stream_open(peer_id=peer_id, protocols=protocols)
 
 
-    async def stream_handler(self, proto: str, handler_cb: StreamHandler) -> None:
+    async def stream_handler(self, proto: str, handler_cb: StreamHandler, balanced: bool = False) -> None:
         """
         """
         Register a stream handler
         Register a stream handler
         :param proto: protocols that handler serves
         :param proto: protocols that handler serves
         :param handler_cb: handler callback
         :param handler_cb: handler callback
+        :param balanced: if True, the daemon load-balances this stream handler across all registered replicas. Default: False.
         :return:
         :return:
         """
         """
-        await self.control.stream_handler(proto=proto, handler_cb=handler_cb)
+        await self.control.stream_handler(proto=proto, handler_cb=handler_cb, balanced=balanced)

+ 4 - 2
hivemind/p2p/servicer.py

@@ -104,11 +104,12 @@ class ServicerBase:
         caller.__name__ = handler.method_name
         caller.__name__ = handler.method_name
         return caller
         return caller
 
 
-    async def add_p2p_handlers(self, p2p: P2P, wrapper: Any = None, *, namespace: Optional[str] = None) -> None:
+    async def add_p2p_handlers(
+        self, p2p: P2P, wrapper: Any = None, *, namespace: Optional[str] = None, balanced: bool = False
+    ) -> None:
         self._collect_rpc_handlers()
         self._collect_rpc_handlers()
 
 
         servicer = self if wrapper is None else wrapper
         servicer = self if wrapper is None else wrapper
-
         await asyncio.gather(
         await asyncio.gather(
             *[
             *[
                 p2p.add_protobuf_handler(
                 p2p.add_protobuf_handler(
@@ -117,6 +118,7 @@ class ServicerBase:
                     handler.request_type,
                     handler.request_type,
                     stream_input=handler.stream_input,
                     stream_input=handler.stream_input,
                     stream_output=handler.stream_output,
                     stream_output=handler.stream_output,
+                    balanced=balanced,
                 )
                 )
                 for handler in self._rpc_handlers
                 for handler in self._rpc_handlers
             ]
             ]

+ 2 - 0
hivemind/proto/p2pd.proto

@@ -90,6 +90,7 @@ message StreamOpenRequest {
 message StreamHandlerRequest {
 message StreamHandlerRequest {
   required bytes addr = 1;
   required bytes addr = 1;
   repeated string proto = 2;
   repeated string proto = 2;
+  required bool balanced = 3;
 }
 }
 
 
 message ErrorResponse {
 message ErrorResponse {
@@ -201,6 +202,7 @@ message CallUnaryResponse {
 
 
 message AddUnaryHandlerRequest {
 message AddUnaryHandlerRequest {
   required string proto = 1;
   required string proto = 1;
+  required bool balanced = 2;
 }
 }
 
 
 message DaemonError {
 message DaemonError {

+ 1 - 1
hivemind/utils/__init__.py

@@ -1,5 +1,4 @@
 from hivemind.utils.asyncio import *
 from hivemind.utils.asyncio import *
-from hivemind.utils.grpc import *
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from hivemind.utils.mpfuture import *
 from hivemind.utils.mpfuture import *
@@ -7,5 +6,6 @@ from hivemind.utils.nested import *
 from hivemind.utils.networking import *
 from hivemind.utils.networking import *
 from hivemind.utils.performance_ema import PerformanceEMA
 from hivemind.utils.performance_ema import PerformanceEMA
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
+from hivemind.utils.streaming import combine_from_streaming, split_for_streaming
 from hivemind.utils.tensor_descr import BatchTensorDescriptor, TensorDescriptor
 from hivemind.utils.tensor_descr import BatchTensorDescriptor, TensorDescriptor
 from hivemind.utils.timed_storage import *
 from hivemind.utils.timed_storage import *

+ 7 - 1
hivemind/utils/asyncio.py

@@ -2,7 +2,7 @@ import asyncio
 import concurrent.futures
 import concurrent.futures
 from concurrent.futures import ThreadPoolExecutor
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import AbstractAsyncContextManager, AbstractContextManager, asynccontextmanager
 from contextlib import AbstractAsyncContextManager, AbstractContextManager, asynccontextmanager
-from typing import AsyncIterable, AsyncIterator, Awaitable, Callable, ContextManager, Optional, Tuple, TypeVar, Union
+from typing import AsyncIterable, AsyncIterator, Awaitable, Callable, Iterable, Optional, Tuple, TypeVar, Union
 
 
 import uvloop
 import uvloop
 
 
@@ -29,6 +29,12 @@ async def anext(aiter: AsyncIterator[T]) -> Union[T, StopAsyncIteration]:
     return await aiter.__anext__()
     return await aiter.__anext__()
 
 
 
 
+async def iter_as_aiter(iterable: Iterable[T]) -> AsyncIterator[T]:
+    """create an asynchronous iterator from single iterable"""
+    for elem in iterable:
+        yield elem
+
+
 async def as_aiter(*args: T) -> AsyncIterator[T]:
 async def as_aiter(*args: T) -> AsyncIterator[T]:
     """create an asynchronous iterator from a sequence of values"""
     """create an asynchronous iterator from a sequence of values"""
     for arg in args:
     for arg in args:
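
Editor's note: a minimal usage sketch of the new `iter_as_aiter` helper. It wraps an ordinary iterable so it can be consumed with `async for`, e.g. when turning a generator of request chunks into a stream for the libp2p handlers.

import asyncio

from hivemind.utils.asyncio import iter_as_aiter


async def main():
    # wrap a plain iterable so it can be consumed asynchronously
    async for item in iter_as_aiter(range(3)):
        print(item)  # prints 0, 1, 2


asyncio.run(main())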

+ 0 - 210
hivemind/utils/grpc.py

@@ -1,210 +0,0 @@
-"""
-Utilities for running GRPC services: compile protobuf, patch legacy versions, etc
-"""
-
-from __future__ import annotations
-
-import os
-import threading
-from typing import Any, Dict, Iterable, Iterator, NamedTuple, Optional, Tuple, Type, TypeVar, Union
-
-import grpc
-
-from hivemind.proto import runtime_pb2
-from hivemind.utils.logging import get_logger
-from hivemind.utils.networking import Endpoint
-from hivemind.utils.timed_storage import TimedStorage, ValueWithExpiration, get_dht_time
-
-logger = get_logger(__name__)
-
-Stub = TypeVar("Stub")
-
-GRPC_KEEPALIVE_OPTIONS = (
-    ("grpc.keepalive_time_ms", 60 * 1000),
-    ("grpc.keepalive_timeout_ms", 60 * 1000),
-    ("grpc.keepalive_permit_without_calls", True),
-    ("grpc.http2.max_pings_without_data", 0),
-    ("grpc.http2.min_time_between_pings_ms", 30 * 1000),
-    ("grpc.http2.min_ping_interval_without_data_ms", 10 * 1000),
-)
-
-
-class ChannelInfo(NamedTuple):
-    target: Endpoint
-    aio: bool
-    options: Tuple[Tuple[str, str], ...]
-    credentials: Optional[grpc.ChannelCredentials]
-    compression: Optional[grpc.Compression]
-
-
-class ChannelCache(TimedStorage[ChannelInfo, Tuple[Union[grpc.Channel, grpc.aio.Channel], Dict]]):
-    """
-    A process-wide cache of gRPC channels, supports both normal and aio channels, secure/insecure channels, etc
-    Based on grpcio internal channel cache by Richard Belleville and Lidi Zheng (thanks!)
-    Unlike TimedStorage, ChannelCache actively evicts stale channels even if the cache is not accessed
-    Unlike grpc._simple_stubs.ChannelCache, this implementation supports aio and does not forcibly close active channels
-    """
-
-    MAXIMUM_CHANNELS = int(os.environ.get("GRPC_PYTHON_MANAGED_CHANNEL_MAXIMUM", 4096))
-    EVICTION_PERIOD_SECONDS = float(os.environ.get("GRPC_PYTHON_MANAGED_CHANNEL_EVICTION_SECONDS", 10 * 60))
-    logger.debug(f"Eviction period = {EVICTION_PERIOD_SECONDS}s, max channels = {MAXIMUM_CHANNELS}")
-
-    _singleton: Optional[ChannelCache] = None
-    _singleton_pid: int = os.getpid()
-    _lock: threading.RLock = threading.RLock()
-    _update_eviction_evt: threading.Event = threading.Event()
-
-    def __init__(self, _created_as_singleton=False):
-        assert _created_as_singleton, f"Please use {self.__class__.__name__}.get_singleton()"
-        super().__init__(maxsize=self.MAXIMUM_CHANNELS)
-        self._is_active = True
-        self._nearest_expiration_time = float("inf")
-        self._eviction_thread = threading.Thread(target=self._evict_stale_channels_in_background, daemon=True)
-        self._eviction_thread.start()
-
-    @classmethod
-    def get_singleton(cls):
-        """Get or create the channel cache for the current process"""
-        with cls._lock:
-            if cls._singleton is None or cls._singleton_pid != os.getpid():
-                if cls._singleton is not None:
-                    cls._singleton._stop_background_thread()
-                cls._singleton, cls._singleton_pid = cls(_created_as_singleton=True), os.getpid()
-            return cls._singleton
-
-    @classmethod
-    def get_stub(
-        cls,
-        target: Endpoint,
-        stub_type: Type[Stub],
-        *,
-        aio: bool,
-        options: Tuple[Tuple[str, Any]] = (),
-        channel_credentials: Optional[grpc.ChannelCredentials] = None,
-        compression: Optional[grpc.Compression] = None,
-    ) -> Stub:
-        """
-        Create a grpc channel with given options or reuse pre-existing one
-
-        :param target: the recipient's address and port
-        :param stub_type: a gRPC stub (client) to be instantiated
-        :param aio: if True, returns grpc.Channel, otherwise returns grpc.aio.Channel
-        :param options: see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html
-        :param channel_credentials: if specified, create a secure channel usin these credentials (default = insecure)
-        :param compression: see https://github.com/grpc/grpc/tree/master/examples/python/compression
-        """
-        cache = cls.get_singleton()
-        with cls._lock:
-            key = ChannelInfo(target, aio, tuple(options), channel_credentials, compression)
-            entry: ValueWithExpiration = super(cls, cache).get(key)
-
-            if entry is not None:
-                channel, stubs = entry.value
-            else:
-                channel = cls._create_channel(*key)
-                stubs = {}
-
-            channel._channel.check_connectivity_state(True)
-
-            if stub_type not in stubs:
-                stubs[stub_type] = stub_type(channel)
-
-            # either cache channel or update expiration of an existing channel
-            expiration_time = get_dht_time() + cls.EVICTION_PERIOD_SECONDS
-            super(cls, cache).store(key, (channel, stubs), expiration_time)
-
-            if expiration_time < cache._nearest_expiration_time:
-                cache._nearest_expiration_time = expiration_time
-                cls._update_eviction_evt.set()
-
-            return stubs[stub_type]
-
-    @classmethod
-    def _create_channel(
-        cls,
-        target: Endpoint,
-        aio: bool,
-        extra_options: Tuple[Tuple[str, Any], ...],
-        channel_credentials: Optional[grpc.ChannelCredentials],
-        compression: Optional[grpc.Compression],
-    ) -> Union[grpc.Channel, grpc.aio.Channel]:
-        namespace = grpc.aio if aio else grpc
-
-        options = extra_options + GRPC_KEEPALIVE_OPTIONS
-
-        if channel_credentials is None:
-            logger.debug(
-                f"Creating insecure {namespace} channel with options '{options}' " f"and compression '{compression}'"
-            )
-            return namespace.insecure_channel(target, options=options, compression=compression)
-        else:
-            logger.debug(
-                f"Creating secure {namespace} channel with credentials '{channel_credentials}', "
-                f"options '{options}' and compression '{compression}'"
-            )
-            return namespace.secure_channel(
-                target, credentials=channel_credentials, options=options, compression=compression
-            )
-
-    def _evict_stale_channels_in_background(self):
-        while self._is_active:
-            now = get_dht_time()
-            time_to_wait = max(0.0, self._nearest_expiration_time - now)
-            interrupted_early = self._update_eviction_evt.wait(time_to_wait if time_to_wait != float("inf") else None)
-            if interrupted_early:
-                self._update_eviction_evt.clear()
-                continue
-
-            with self._lock:
-                self._remove_outdated()
-                _, entry = super().top()
-                self._nearest_expiration_time = entry.expiration_time if entry is not None else float("inf")
-
-    def _stop_background_thread(self):
-        with self._lock:
-            self._is_active = False
-            self._update_eviction_evt.set()
-
-    def store(self, *args, **kwargs) -> ValueError:
-        raise ValueError(f"Please use {self.__class__.__name__}.get_stub to get or create stubs")
-
-    def get(self, *args, **kwargs) -> ValueError:
-        raise ValueError(f"Please use {self.__class__.__name__}.get_stub to get or create stubs")
-
-    def top(self) -> ValueError:
-        raise ValueError(f"Please use {self.__class__.__name__}.get_stub to get or create stubs")
-
-
-STREAMING_CHUNK_SIZE_BYTES = 2**16
-
-
-def split_for_streaming(
-    serialized_tensor: runtime_pb2.Tensor,
-    chunk_size_bytes: int = STREAMING_CHUNK_SIZE_BYTES,
-) -> Iterator[runtime_pb2.Tensor]:
-    """Split serialized_tensor into multiple chunks for gRPC streaming"""
-    buffer = memoryview(serialized_tensor.buffer)
-    num_chunks = len(range(0, len(buffer), chunk_size_bytes))
-    yield runtime_pb2.Tensor(
-        compression=serialized_tensor.compression,
-        buffer=buffer[:chunk_size_bytes].tobytes(),
-        chunks=num_chunks,
-        size=serialized_tensor.size,
-        dtype=serialized_tensor.dtype,
-        requires_grad=serialized_tensor.requires_grad,
-    )
-    for chunk_start in range(chunk_size_bytes, len(buffer), chunk_size_bytes):
-        yield runtime_pb2.Tensor(buffer=buffer[chunk_start : chunk_start + chunk_size_bytes].tobytes())
-
-
-def combine_from_streaming(stream: Iterable[runtime_pb2.Tensor]) -> runtime_pb2.Tensor:
-    """Restore a result of split_into_chunks into a single serialized tensor"""
-    stream = iter(stream)
-    first_chunk = next(stream)
-    serialized_tensor = runtime_pb2.Tensor()
-    serialized_tensor.CopyFrom(first_chunk)
-    buffer_chunks = [first_chunk.buffer]
-    for tensor_part in stream:
-        buffer_chunks.append(tensor_part.buffer)
-    serialized_tensor.buffer = b"".join(buffer_chunks)
-    return serialized_tensor

+ 2 - 24
hivemind/utils/networking.py

@@ -1,35 +1,13 @@
 import socket
 import socket
 from contextlib import closing
 from contextlib import closing
 from ipaddress import ip_address
 from ipaddress import ip_address
-from typing import Optional, Sequence
+from typing import Sequence
 
 
 from multiaddr import Multiaddr
 from multiaddr import Multiaddr
 
 
-Hostname, Port = str, int  # flavour types
-Endpoint = str  # e.g. 1.2.3.4:1337 or [2a21:6с8:b192:2105]:8888, https://networkengineering.stackexchange.com/a/9435
 LOCALHOST = "127.0.0.1"
 LOCALHOST = "127.0.0.1"
 
 
 
 
-def get_port(endpoint: Endpoint) -> Optional[Port]:
-    """get port or None if port is undefined"""
-    # TODO: find a standard way to get port, make sure it works in malformed ports
-    try:
-        return int(endpoint[endpoint.rindex(":") + 1 :], base=10)
-    except ValueError:  # :* or not specified
-        return None
-
-
-def replace_port(endpoint: Endpoint, new_port: Port) -> Endpoint:
-    assert endpoint.endswith(":*") or get_port(endpoint) is not None, endpoint
-    return f"{endpoint[:endpoint.rindex(':')]}:{new_port}"
-
-
-def strip_port(endpoint: Endpoint) -> Hostname:
-    """Removes port from the end of endpoint. If port is not specified, does nothing"""
-    maybe_port = endpoint[endpoint.rindex(":") + 1 :]
-    return endpoint[: endpoint.rindex(":")] if maybe_port.isdigit() or maybe_port == "*" else endpoint
-
-
 def get_free_port(params=(socket.AF_INET, socket.SOCK_STREAM), opt=(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)):
 def get_free_port(params=(socket.AF_INET, socket.SOCK_STREAM), opt=(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)):
     """
     """
     Finds a tcp port that can be occupied with a socket with *params and use *opt options.
     Finds a tcp port that can be occupied with a socket with *params and use *opt options.
@@ -48,7 +26,7 @@ def get_free_port(params=(socket.AF_INET, socket.SOCK_STREAM), opt=(socket.SOL_S
 
 
 def choose_ip_address(
 def choose_ip_address(
     maddrs: Sequence[Multiaddr], prefer_global: bool = True, protocol_priority: Sequence[str] = ("ip4", "ip6")
     maddrs: Sequence[Multiaddr], prefer_global: bool = True, protocol_priority: Sequence[str] = ("ip4", "ip6")
-) -> Hostname:
+) -> str:
     """
     """
     Currently, some components of hivemind are not converted to work over libp2p and use classical networking.
     Currently, some components of hivemind are not converted to work over libp2p and use classical networking.
     To allow other peers reach a server when needed, these components announce a machine's IP address.
     To allow other peers reach a server when needed, these components announce a machine's IP address.
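
Editor's note: a small usage sketch of `choose_ip_address` (the multiaddrs below are illustrative; with the default `prefer_global=True`, the helper is expected to pick the globally routable address over the loopback one).

from multiaddr import Multiaddr

from hivemind.utils.networking import choose_ip_address

maddrs = [
    Multiaddr("/ip4/127.0.0.1/tcp/31337"),
    Multiaddr("/ip4/8.8.8.8/tcp/31337"),
]
print(choose_ip_address(maddrs))  # expected to print "8.8.8.8"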

+ 49 - 0
hivemind/utils/streaming.py

@@ -0,0 +1,49 @@
+"""
+Utilities for streaming tensors
+"""
+
+from __future__ import annotations
+
+from typing import Iterable, Iterator, TypeVar
+
+from hivemind.proto import runtime_pb2
+from hivemind.utils.logging import get_logger
+
+logger = get_logger(__name__)
+
+STREAMING_CHUNK_SIZE_BYTES = 2**16
+
+
+def split_for_streaming(
+    serialized_tensor: runtime_pb2.Tensor,
+    chunk_size_bytes: int = STREAMING_CHUNK_SIZE_BYTES,
+) -> Iterator[runtime_pb2.Tensor]:
+    """Split serialized_tensor into multiple chunks for streaming"""
+    buffer = memoryview(serialized_tensor.buffer)
+    num_chunks = len(range(0, len(buffer), chunk_size_bytes))
+    yield runtime_pb2.Tensor(
+        compression=serialized_tensor.compression,
+        buffer=buffer[:chunk_size_bytes].tobytes(),
+        chunks=num_chunks,
+        size=serialized_tensor.size,
+        dtype=serialized_tensor.dtype,
+        requires_grad=serialized_tensor.requires_grad,
+    )
+    for chunk_start in range(chunk_size_bytes, len(buffer), chunk_size_bytes):
+        yield runtime_pb2.Tensor(buffer=buffer[chunk_start : chunk_start + chunk_size_bytes].tobytes())
+
+
+def combine_from_streaming(stream: Iterable[runtime_pb2.Tensor]) -> runtime_pb2.Tensor:
+    """Restore a result of split_into_chunks into a single serialized tensor"""
+    stream = iter(stream)
+    first_chunk = next(stream)
+    serialized_tensor = runtime_pb2.Tensor()
+    serialized_tensor.CopyFrom(first_chunk)
+    buffer_chunks = [first_chunk.buffer]
+    for tensor_part in stream:
+        buffer_chunks.append(tensor_part.buffer)
+    serialized_tensor.buffer = b"".join(buffer_chunks)
+    return serialized_tensor
+
+
+StreamMessage = TypeVar("StreamMessage")
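
Editor's note: the relocated helpers round-trip cleanly with the compression utilities, as exercised in tests/test_compression.py further down; a minimal sketch:

import torch

from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.utils.streaming import combine_from_streaming, split_for_streaming

tensor = torch.randn(512, 128)
serialized = serialize_torch_tensor(tensor)  # uncompressed by default

# split into <= 2**16-byte chunks, e.g. to stay under the daemon's message size limit
chunks = list(split_for_streaming(serialized, chunk_size_bytes=2**16))

restored = deserialize_torch_tensor(combine_from_streaming(chunks))
assert torch.allclose(restored, tensor)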

+ 2 - 2
setup.py

@@ -13,14 +13,14 @@ from setuptools import find_packages, setup
 from setuptools.command.build_py import build_py
 from setuptools.command.build_py import build_py
 from setuptools.command.develop import develop
 from setuptools.command.develop import develop
 
 
-P2PD_VERSION = "v0.3.8"
+P2PD_VERSION = "v0.3.9"
 
 
 P2PD_SOURCE_URL = f"https://github.com/learning-at-home/go-libp2p-daemon/archive/refs/tags/{P2PD_VERSION}.tar.gz"
 P2PD_SOURCE_URL = f"https://github.com/learning-at-home/go-libp2p-daemon/archive/refs/tags/{P2PD_VERSION}.tar.gz"
 P2PD_BINARY_URL = f"https://github.com/learning-at-home/go-libp2p-daemon/releases/download/{P2PD_VERSION}/"
 P2PD_BINARY_URL = f"https://github.com/learning-at-home/go-libp2p-daemon/releases/download/{P2PD_VERSION}/"
 
 
 # The value is sha256 of the binary from the release page
 # The value is sha256 of the binary from the release page
 EXECUTABLES = {
 EXECUTABLES = {
-    "p2pd": "785058526d993f699c674dc2f9b66d565a52315a18b79b629998fab3ebd8e20f",
+    "p2pd": "8f9434f4717f6e851430f75f07e283d5ddeb2c7cde1b3648e677d813703f4e40",
 }
 }
 
 
 
 

+ 3 - 2
tests/test_compression.py

@@ -20,6 +20,7 @@ from hivemind.compression import (
 )
 )
 from hivemind.compression.adaptive import AdaptiveCompressionBase
 from hivemind.compression.adaptive import AdaptiveCompressionBase
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.proto.runtime_pb2 import CompressionType
+from hivemind.utils.streaming import combine_from_streaming, split_for_streaming
 
 
 from test_utils.dht_swarms import launch_dht_instances
 from test_utils.dht_swarms import launch_dht_instances
 
 
@@ -47,9 +48,9 @@ def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
 def test_serialize_tensor():
 def test_serialize_tensor():
     def _check(tensor, compression, rtol=1e-5, atol=1e-8, chunk_size=30 * 1024):
     def _check(tensor, compression, rtol=1e-5, atol=1e-8, chunk_size=30 * 1024):
         serialized_tensor = serialize_torch_tensor(tensor, compression)
         serialized_tensor = serialize_torch_tensor(tensor, compression)
-        chunks = list(hivemind.split_for_streaming(serialized_tensor, chunk_size))
+        chunks = list(split_for_streaming(serialized_tensor, chunk_size))
         assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
         assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
-        restored = hivemind.combine_from_streaming(chunks)
+        restored = combine_from_streaming(chunks)
         assert torch.allclose(deserialize_torch_tensor(restored), tensor, rtol=rtol, atol=atol)
         assert torch.allclose(deserialize_torch_tensor(restored), tensor, rtol=rtol, atol=atol)
 
 
     tensor = torch.randn(512, 12288)
     tensor = torch.randn(512, 12288)

+ 192 - 0
tests/test_connection_handler.py

@@ -0,0 +1,192 @@
+from __future__ import annotations
+
+import asyncio
+import math
+from typing import Any, Dict
+
+import pytest
+import torch
+
+from hivemind.compression import deserialize_tensor_stream, deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.dht import DHT
+from hivemind.moe.server.connection_handler import ConnectionHandler
+from hivemind.moe.server.expert_backend import ExpertBackend
+from hivemind.moe.server.task_pool import TaskPool
+from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, P2PHandlerError
+from hivemind.proto import runtime_pb2
+from hivemind.utils.asyncio import amap_in_executor, iter_as_aiter
+from hivemind.utils.serializer import MSGPackSerializer
+from hivemind.utils.streaming import split_for_streaming
+from hivemind.utils.tensor_descr import BatchTensorDescriptor
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_connection_handler_info():
+    handler = ConnectionHandler(
+        DHT(start=True),
+        dict(expert1=DummyExpertBackend("expert1", k=1), expert2=DummyExpertBackend("expert2", k=2)),
+    )
+    handler.start()
+
+    client_dht = DHT(start=True, client_mode=True, initial_peers=handler.dht.get_visible_maddrs())
+    client_stub = ConnectionHandler.get_stub(await client_dht.replicate_p2p(), handler.dht.peer_id)
+
+    # info
+    response = await client_stub.rpc_info(runtime_pb2.ExpertUID(uid="expert1"))
+    assert MSGPackSerializer.loads(response.serialized_info) == dict(name="expert1")
+
+    response = await client_stub.rpc_info(runtime_pb2.ExpertUID(uid="expert2"))
+    assert MSGPackSerializer.loads(response.serialized_info) == dict(name="expert2")
+
+    with pytest.raises(P2PHandlerError):
+        await client_stub.rpc_info(runtime_pb2.ExpertUID(uid="expert999"))
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_connection_handler_forward():
+    handler = ConnectionHandler(
+        DHT(start=True),
+        dict(expert1=DummyExpertBackend("expert1", k=1), expert2=DummyExpertBackend("expert2", k=2)),
+    )
+    handler.start()
+
+    client_dht = DHT(start=True, client_mode=True, initial_peers=handler.dht.get_visible_maddrs())
+    client_stub = ConnectionHandler.get_stub(await client_dht.replicate_p2p(), handler.dht.peer_id)
+
+    inputs = torch.randn(1, 2)
+    inputs_long = torch.randn(2**21, 2)
+
+    # forward unary
+    response = await client_stub.rpc_forward(
+        runtime_pb2.ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(inputs)])
+    )
+    outputs = deserialize_torch_tensor(response.tensors[0])
+    assert len(response.tensors) == 1
+    assert torch.allclose(outputs, inputs * 1)
+
+    response = await client_stub.rpc_forward(
+        runtime_pb2.ExpertRequest(uid="expert2", tensors=[serialize_torch_tensor(inputs)])
+    )
+    outputs = deserialize_torch_tensor(response.tensors[0])
+    assert len(response.tensors) == 1
+    assert torch.allclose(outputs, inputs * 2)
+
+    # forward streaming
+    split = (
+        p for t in [serialize_torch_tensor(inputs_long)] for p in split_for_streaming(t, chunk_size_bytes=2**16)
+    )
+    output_generator = await client_stub.rpc_forward_stream(
+        amap_in_executor(
+            lambda tensor_part: runtime_pb2.ExpertRequest(uid="expert2", tensors=[tensor_part]),
+            iter_as_aiter(split),
+        ),
+    )
+    outputs_list = [part async for part in output_generator]
+    assert len(outputs_list) == math.ceil(inputs_long.numel() * 4 / DEFAULT_MAX_MSG_SIZE)
+
+    results = await deserialize_tensor_stream(amap_in_executor(lambda r: r.tensors, iter_as_aiter(outputs_list)))
+    assert len(results) == 1
+    assert torch.allclose(results[0], inputs_long * 2)
+
+    # forward errors
+    with pytest.raises(P2PHandlerError):
+        # no such expert: fails with P2PHandlerError KeyError('expert3')
+        await client_stub.rpc_forward(
+            runtime_pb2.ExpertRequest(uid="expert3", tensors=[serialize_torch_tensor(inputs)])
+        )
+
+    with pytest.raises(P2PHandlerError):
+        # bad input shape: P2PHandlerError("AssertionError") raised by DummyPool.submit_task
+        await client_stub.rpc_forward(
+            runtime_pb2.ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(torch.arange(5))])
+        )
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_connection_handler_backward():
+    handler = ConnectionHandler(
+        DHT(start=True),
+        dict(expert1=DummyExpertBackend("expert1", k=1), expert2=DummyExpertBackend("expert2", k=2)),
+    )
+    handler.start()
+
+    client_dht = DHT(start=True, client_mode=True, initial_peers=handler.dht.get_visible_maddrs())
+    client_stub = ConnectionHandler.get_stub(await client_dht.replicate_p2p(), handler.dht.peer_id)
+
+    inputs = torch.randn(1, 2)
+    inputs_long = torch.randn(2**21, 2)
+
+    # backward unary
+    response = await client_stub.rpc_backward(
+        runtime_pb2.ExpertRequest(
+            uid="expert2", tensors=[serialize_torch_tensor(inputs * -1), serialize_torch_tensor(inputs)]
+        )
+    )
+    outputs = deserialize_torch_tensor(response.tensors[0])
+    assert len(response.tensors) == 1
+    assert torch.allclose(outputs, inputs * -2)
+
+    # backward streaming
+    split = (
+        p
+        for t in [serialize_torch_tensor(inputs_long * 3), serialize_torch_tensor(inputs_long * 0)]
+        for p in split_for_streaming(t, chunk_size_bytes=2**16)
+    )
+    output_generator = await client_stub.rpc_backward_stream(
+        amap_in_executor(
+            lambda tensor_part: runtime_pb2.ExpertRequest(uid="expert1", tensors=[tensor_part]),
+            iter_as_aiter(split),
+        ),
+    )
+    results = await deserialize_tensor_stream(amap_in_executor(lambda r: r.tensors, output_generator))
+    assert len(results) == 1
+    assert torch.allclose(results[0], inputs_long * 3)
+
+    # backward errors
+    with pytest.raises(P2PHandlerError):
+        # bad input schema: fails with P2PHandlerError IndexError('tuple index out of range')
+        await client_stub.rpc_backward(runtime_pb2.ExpertRequest(uid="expert2", tensors=[]))
+
+    with pytest.raises(P2PHandlerError):
+        # backward fails: empty stream
+        output_generator = await client_stub.rpc_backward_stream(
+            amap_in_executor(
+                lambda tensor_part: runtime_pb2.ExpertRequest(uid="expert2", tensors=[tensor_part]),
+                iter_as_aiter([]),
+            ),
+        )
+        results = await deserialize_tensor_stream(amap_in_executor(lambda r: r.tensors, output_generator))
+        assert len(results) == 1
+        assert torch.allclose(results[0], inputs_long * 3)
+
+    # check that handler did not crash after failed request
+    await client_stub.rpc_forward(runtime_pb2.ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(inputs)]))
+
+    handler.terminate()
+    handler.join()
+
+
+class DummyPool(TaskPool):
+    def __init__(self, k: float):
+        self.k = k
+
+    async def submit_task(self, *inputs: torch.Tensor):
+        await asyncio.sleep(0.01)
+        assert inputs[0].shape[-1] == 2
+        return [inputs[0] * self.k]
+
+
+class DummyExpertBackend(ExpertBackend):
+    def __init__(self, name: str, k: float):
+        self.name = name
+        self.outputs_schema = [BatchTensorDescriptor.from_tensor(torch.randn(1, 2))]
+        self.grad_inputs_schema = [BatchTensorDescriptor.from_tensor(torch.randn(1, 2))]
+        self.forward_pool = DummyPool(k)
+        self.backward_pool = DummyPool(k)
+
+    def get_info(self) -> Dict[str, Any]:
+        """Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""
+        return dict(name=self.name)

+ 20 - 9
tests/test_custom_experts.py

@@ -3,7 +3,8 @@ import os
 import pytest
 import pytest
 import torch
 import torch
 
 
-from hivemind import RemoteExpert
+from hivemind.dht import DHT
+from hivemind.moe.client.expert import RemoteExpertInfo, create_remote_experts
 from hivemind.moe.server import background_server
 from hivemind.moe.server import background_server
 
 
 CUSTOM_EXPERTS_PATH = os.path.join(os.path.dirname(__file__), "test_utils", "custom_networks.py")
 CUSTOM_EXPERTS_PATH = os.path.join(os.path.dirname(__file__), "test_utils", "custom_networks.py")
@@ -17,11 +18,16 @@ def test_custom_expert(hid_dim=16):
         device="cpu",
         device="cpu",
         hidden_dim=hid_dim,
         hidden_dim=hid_dim,
         num_handlers=2,
         num_handlers=2,
-        no_dht=True,
         custom_module_path=CUSTOM_EXPERTS_PATH,
         custom_module_path=CUSTOM_EXPERTS_PATH,
-    ) as (server_endpoint, _):
-        expert0 = RemoteExpert("expert.0", server_endpoint)
-        expert1 = RemoteExpert("expert.1", server_endpoint)
+    ) as server_peer_info:
+        dht = DHT(initial_peers=server_peer_info.addrs, start=True)
+        expert0, expert1 = create_remote_experts(
+            [
+                RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info),
+                RemoteExpertInfo(uid="expert.1", peer_info=server_peer_info),
+            ],
+            dht=dht,
+        )
 
 
         for batch_size in (1, 4):
         for batch_size in (1, 4):
             batch = torch.randn(batch_size, hid_dim)
             batch = torch.randn(batch_size, hid_dim)
@@ -43,11 +49,16 @@ def test_multihead_expert(hid_dim=16):
         device="cpu",
         device="cpu",
         hidden_dim=hid_dim,
         hidden_dim=hid_dim,
         num_handlers=2,
         num_handlers=2,
-        no_dht=True,
         custom_module_path=CUSTOM_EXPERTS_PATH,
         custom_module_path=CUSTOM_EXPERTS_PATH,
-    ) as (server_endpoint, _):
-        expert0 = RemoteExpert("expert.0", server_endpoint)
-        expert1 = RemoteExpert("expert.1", server_endpoint)
+    ) as server_peer_info:
+        dht = DHT(initial_peers=server_peer_info.addrs, start=True)
+        expert0, expert1 = create_remote_experts(
+            [
+                RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info),
+                RemoteExpertInfo(uid="expert.1", peer_info=server_peer_info),
+            ],
+            dht=dht,
+        )
 
 
         for batch_size in (1, 4):
         for batch_size in (1, 4):
             batch = (
             batch = (

+ 28 - 23
tests/test_dht_experts.py

@@ -6,11 +6,11 @@ import numpy as np
 import pytest
 import pytest
 
 
 import hivemind
 import hivemind
-from hivemind import LOCALHOST
 from hivemind.dht import DHTNode
 from hivemind.dht import DHTNode
 from hivemind.moe.client.beam_search import MoEBeamSearcher
 from hivemind.moe.client.beam_search import MoEBeamSearcher
 from hivemind.moe.server import declare_experts, get_experts
 from hivemind.moe.server import declare_experts, get_experts
 from hivemind.moe.server.expert_uid import UidEndpoint, is_valid_prefix, is_valid_uid, split_uid
 from hivemind.moe.server.expert_uid import UidEndpoint, is_valid_prefix, is_valid_uid, split_uid
+from hivemind.p2p import PeerInfo
 
 
 
 
 @pytest.mark.forked
 @pytest.mark.forked
@@ -25,17 +25,18 @@ def test_store_get_experts(n_peers=10):
     expert_uids = [f"my_expert.{i}" for i in range(50)]
     expert_uids = [f"my_expert.{i}" for i in range(50)]
     batch_size = 10
     batch_size = 10
     for batch_start in range(0, len(expert_uids), batch_size):
     for batch_start in range(0, len(expert_uids), batch_size):
-        declare_experts(first_peer, expert_uids[batch_start : batch_start + batch_size], "localhost:1234")
+        declare_experts(first_peer, expert_uids[batch_start : batch_start + batch_size])
 
 
     found = get_experts(other_peer, random.sample(expert_uids, 5) + ["foo", "bar"])
     found = get_experts(other_peer, random.sample(expert_uids, 5) + ["foo", "bar"])
     assert all(res is not None for res in found[:-2]), "Could not find some existing experts"
     assert all(res is not None for res in found[:-2]), "Could not find some existing experts"
     assert all(res is None for res in found[-2:]), "Found non-existing experts"
     assert all(res is None for res in found[-2:]), "Found non-existing experts"
 
 
-    other_expert, other_port = "my_other_expert.1337", random.randint(1000, 9999)
-    declare_experts(other_peer, [other_expert], f"that_host:{other_port}")
+    other_expert = "my_other_expert.1337"
+    declare_experts(other_peer, [other_expert])
     first_notfound, first_found = get_experts(first_peer, ["foobar", other_expert])
     first_notfound, first_found = get_experts(first_peer, ["foobar", other_expert])
     assert isinstance(first_found, hivemind.RemoteExpert)
     assert isinstance(first_found, hivemind.RemoteExpert)
-    assert first_found.endpoint == f"that_host:{other_port}"
+    assert first_found.server_peer_info.peer_id == other_peer.peer_id
+    assert first_notfound is None
 
 
     # test graceful shutdown
     # test graceful shutdown
     first_peer.shutdown()
     first_peer.shutdown()
@@ -43,30 +44,31 @@ def test_store_get_experts(n_peers=10):
     time.sleep(1.0)
     time.sleep(1.0)
     remaining_peer1 = random.choice([peer for peer in peers if peer.is_alive()])
     remaining_peer1 = random.choice([peer for peer in peers if peer.is_alive()])
     remaining_peer2 = random.choice([peer for peer in peers if peer.is_alive()])
     remaining_peer2 = random.choice([peer for peer in peers if peer.is_alive()])
-    assert all(declare_experts(remaining_peer1, ["new_expert.1"], "dummy"))
-    assert get_experts(remaining_peer2, ["new_expert.1"])[0].endpoint == "dummy"
+    assert all(declare_experts(remaining_peer1, ["new_expert.1"]))
+    assert get_experts(remaining_peer2, ["new_expert.1"])[0].server_peer_info.peer_id == remaining_peer1.peer_id
 
 
 
 
 @pytest.mark.forked
 @pytest.mark.forked
 def test_beam_search(
 def test_beam_search(
     n_peers=20, total_experts=128, batch_size=32, beam_size=4, parallel_rpc=4, grid_dims=(32, 32, 32)
     n_peers=20, total_experts=128, batch_size=32, beam_size=4, parallel_rpc=4, grid_dims=(32, 32, 32)
 ):
 ):
-    dht = [hivemind.DHT(start=True)]
-    initial_peers = dht[0].get_visible_maddrs()
-    dht += [hivemind.DHT(initial_peers=initial_peers, start=True) for _ in range(n_peers - 1)]
+    dht_instances = [hivemind.DHT(start=True)]
+    initial_peers = dht_instances[0].get_visible_maddrs()
+    dht_instances += [hivemind.DHT(initial_peers=initial_peers, start=True) for _ in range(n_peers - 1)]
 
 
     real_experts = sorted(
     real_experts = sorted(
         {"expert." + ".".join([str(random.randint(0, dim - 1)) for dim in grid_dims]) for _ in range(total_experts)}
         {"expert." + ".".join([str(random.randint(0, dim - 1)) for dim in grid_dims]) for _ in range(total_experts)}
     )
     )
     for batch_start in range(0, len(real_experts), batch_size):
     for batch_start in range(0, len(real_experts), batch_size):
+        dht = random.choice(dht_instances)
         declare_experts(
         declare_experts(
-            random.choice(dht),
+            dht,
             real_experts[batch_start : batch_start + batch_size],
             real_experts[batch_start : batch_start + batch_size],
-            wait=True,
-            endpoint=f"host{batch_start // batch_size}:{random.randint(0, 65536)}",
         )
         )
 
 
-    neighbors = sum([peer.get_visible_maddrs() for peer in random.sample(dht, min(3, len(dht)))], [])
+    neighbors = sum(
+        [peer.get_visible_maddrs() for peer in random.sample(dht_instances, min(3, len(dht_instances)))], []
+    )
     you = hivemind.DHT(start=True, initial_peers=neighbors, parallel_rpc=parallel_rpc)
     you = hivemind.DHT(start=True, initial_peers=neighbors, parallel_rpc=parallel_rpc)
     beam_search = MoEBeamSearcher(you, "expert.", grid_dims)
     beam_search = MoEBeamSearcher(you, "expert.", grid_dims)
 
 
@@ -89,22 +91,25 @@ def test_dht_single_node():
     node = hivemind.DHT(start=True)
     node = hivemind.DHT(start=True)
     beam_search = MoEBeamSearcher(node, "expert.", grid_size=(10,))
     beam_search = MoEBeamSearcher(node, "expert.", grid_size=(10,))
 
 
-    assert all(declare_experts(node, ["expert.1", "expert.2", "expert.3"], f"{hivemind.LOCALHOST}:1337").values())
-    assert len(declare_experts(node, ["ffn.1", "ffn.2"], endpoint="that_place")) == 4
-    assert len(declare_experts(node, ["e.1.2.3", "e.1.2.5", "e.2.0"], f"{hivemind.LOCALHOST}:42")) == 7
+    assert all(declare_experts(node, ["expert.1", "expert.2", "expert.3"]).values())
+    assert len(declare_experts(node, ["ffn.1", "ffn.2"])) == 4
+    assert len(declare_experts(node, ["e.1.2.3", "e.1.2.5", "e.2.0"])) == 7
 
 
     for expert in get_experts(node, ["expert.3", "expert.2"]):
     for expert in get_experts(node, ["expert.3", "expert.2"]):
-        assert expert.endpoint == f"{hivemind.LOCALHOST}:1337"
+        assert expert.server_peer_info.peer_id == node.peer_id
 
 
-    assert all(declare_experts(node, ["expert.5", "expert.2"], f"{hivemind.LOCALHOST}:1337").values())
+    assert all(declare_experts(node, ["expert.5", "expert.2"]).values())
     found_experts = beam_search.find_best_experts([(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0)], beam_size=2)
     found_experts = beam_search.find_best_experts([(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0)], beam_size=2)
     assert len(found_experts) == 2 and [expert.uid for expert in found_experts] == ["expert.5", "expert.3"]
     assert len(found_experts) == 2 and [expert.uid for expert in found_experts] == ["expert.5", "expert.3"]
 
 
     successors = beam_search.get_active_successors(["e.1.2.", "e.2.", "e.4.5."])
     successors = beam_search.get_active_successors(["e.1.2.", "e.2.", "e.4.5."])
     assert len(successors["e.1.2."]) == 2
     assert len(successors["e.1.2."]) == 2
-    assert successors["e.1.2."][3] == UidEndpoint("e.1.2.3", f"{LOCALHOST}:42")
-    assert successors["e.1.2."][5] == UidEndpoint("e.1.2.5", f"{LOCALHOST}:42")
-    assert len(successors["e.2."]) == 1 and successors["e.2."][0] == UidEndpoint("e.2.0", f"{LOCALHOST}:42")
+
+    peer_info = PeerInfo(node.peer_id, [a.decapsulate("/p2p/" + a.get("p2p")) for a in node.get_visible_maddrs()])
+
+    assert successors["e.1.2."][3] == UidEndpoint("e.1.2.3", peer_info)
+    assert successors["e.1.2."][5] == UidEndpoint("e.1.2.5", peer_info)
+    assert len(successors["e.2."]) == 1 and successors["e.2."][0] == UidEndpoint("e.2.0", peer_info)
     assert successors["e.4.5."] == {}
     assert successors["e.4.5."] == {}
 
 
     initial_beam = beam_search.get_initial_beam((3, 2, 1, 0, -1, -2, -3), beam_size=3)
     initial_beam = beam_search.get_initial_beam((3, 2, 1, 0, -1, -2, -3), beam_size=3)
@@ -194,7 +199,7 @@ async def test_negative_caching(n_peers=10):
     peers += [hivemind.DHT(initial_peers=initial_peers, start=True, **dht_kwargs) for _ in range(n_peers - 1)]
     peers += [hivemind.DHT(initial_peers=initial_peers, start=True, **dht_kwargs) for _ in range(n_peers - 1)]
 
 
     writer_peer = random.choice(peers)
     writer_peer = random.choice(peers)
-    assert all(declare_experts(writer_peer, ["ffn.1.2.3", "ffn.3.4.5"], "myaddr:1234").values())
+    assert all(declare_experts(writer_peer, ["ffn.1.2.3", "ffn.3.4.5"]).values())
 
 
     neighbors = sum([peer.get_visible_maddrs() for peer in random.sample(peers, min(3, len(peers)))], [])
     neighbors = sum([peer.get_visible_maddrs() for peer in random.sample(peers, min(3, len(peers)))], [])
     neg_caching_peer = hivemind.DHT(initial_peers=neighbors, start=True, **dht_kwargs)
     neg_caching_peer = hivemind.DHT(initial_peers=neighbors, start=True, **dht_kwargs)
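
Editor's note: with the endpoint argument removed, experts are now announced under the declaring peer's own identity, and clients resolve them to a PeerInfo rather than a host:port string. A minimal sketch of the updated API (the expert names are arbitrary):

import hivemind
from hivemind.moe.server import declare_experts, get_experts

dht = hivemind.DHT(start=True)

# no endpoint argument anymore: the local peer's info is stored alongside each UID
assert all(declare_experts(dht, ["demo_expert.0", "demo_expert.1"]).values())

found, missing = get_experts(dht, ["demo_expert.0", "demo_expert.9"])
assert found.server_peer_info.peer_id == dht.peer_id
assert missing is None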

+ 40 - 28
tests/test_moe.py

@@ -1,13 +1,14 @@
-import grpc
 import numpy as np
 import numpy as np
 import pytest
 import pytest
 import torch
 import torch
 
 
 from hivemind.dht import DHT
 from hivemind.dht import DHT
-from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
-from hivemind.moe.client.moe import DUMMY, _RemoteCallMany
+from hivemind.moe.client.expert import RemoteExpert, RemoteExpertInfo, create_remote_experts
+from hivemind.moe.client.moe import DUMMY, RemoteMixtureOfExperts, _RemoteCallMany
+from hivemind.moe.client.switch_moe import RemoteSwitchMixtureOfExperts
 from hivemind.moe.server import ExpertBackend, Server, background_server, declare_experts
 from hivemind.moe.server import ExpertBackend, Server, background_server, declare_experts
 from hivemind.moe.server.layers import name_to_block
 from hivemind.moe.server.layers import name_to_block
+from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError
 from hivemind.utils.tensor_descr import BatchTensorDescriptor
 from hivemind.utils.tensor_descr import BatchTensorDescriptor
 
 
 
 
@@ -18,8 +19,8 @@ def test_moe():
     ]
     ]
     with background_server(
     with background_server(
         expert_uids=all_expert_uids, device="cpu", expert_cls="ffn", num_handlers=1, hidden_dim=16
         expert_uids=all_expert_uids, device="cpu", expert_cls="ffn", num_handlers=1, hidden_dim=16
-    ) as (server_endpoint, dht_maddrs):
-        dht = DHT(start=True, initial_peers=dht_maddrs)
+    ) as server_peer_info:
+        dht = DHT(start=True, initial_peers=server_peer_info.addrs)
 
 
         dmoe = RemoteMixtureOfExperts(in_features=16, grid_size=(4, 4, 4), dht=dht, k_best=3, uid_prefix="ffn.")
         dmoe = RemoteMixtureOfExperts(in_features=16, grid_size=(4, 4, 4), dht=dht, k_best=3, uid_prefix="ffn.")
 
 
@@ -35,9 +36,8 @@ def test_no_experts():
     ]
     ]
     with background_server(
     with background_server(
         expert_uids=all_expert_uids, device="cpu", expert_cls="nop_delay", num_handlers=1, hidden_dim=16
         expert_uids=all_expert_uids, device="cpu", expert_cls="nop_delay", num_handlers=1, hidden_dim=16
-    ) as (server_endpoint, dht_maddrs):
-        dht = DHT(start=True, initial_peers=dht_maddrs)
-
+    ) as server_peer_info:
+        dht = DHT(start=True, initial_peers=server_peer_info.addrs)
         dmoe = RemoteSwitchMixtureOfExperts(
         dmoe = RemoteSwitchMixtureOfExperts(
             in_features=16,
             in_features=16,
             grid_size=(4, 4, 4),
             grid_size=(4, 4, 4),
@@ -71,12 +71,16 @@ def test_call_many(hidden_dim=16):
         num_handlers=1,
         num_handlers=1,
         hidden_dim=hidden_dim,
         hidden_dim=hidden_dim,
         optim_cls=None,
         optim_cls=None,
-        no_dht=True,
-    ) as (server_endpoint, _):
+    ) as server_peer_info:
         inputs = torch.randn(4, hidden_dim, requires_grad=True)
         inputs = torch.randn(4, hidden_dim, requires_grad=True)
         inputs_clone = inputs.clone().detach().requires_grad_(True)
         inputs_clone = inputs.clone().detach().requires_grad_(True)
-        e0, e1, e2, e3, e4 = [RemoteExpert(f"expert.{i}", server_endpoint) for i in range(5)]
-        e5 = RemoteExpert(f"thisshouldnotexist", "127.0.0.1:80")
+
+        dht = DHT(initial_peers=server_peer_info.addrs, start=True)
+        e0, e1, e2, e3, e4 = create_remote_experts(
+            [RemoteExpertInfo(uid=f"expert.{i}", peer_info=server_peer_info) for i in range(5)],
+            dht,
+        )
+        e5 = RemoteExpert(RemoteExpertInfo(f"thisshouldnotexist", server_peer_info), None)
 
 
         mask, expert_outputs = _RemoteCallMany.apply(
         mask, expert_outputs = _RemoteCallMany.apply(
             DUMMY,
             DUMMY,
@@ -129,11 +133,15 @@ def test_remote_module_call(hidden_dim=16):
         num_handlers=1,
         num_handlers=1,
         hidden_dim=hidden_dim,
         hidden_dim=hidden_dim,
         optim_cls=None,
         optim_cls=None,
-        no_dht=True,
-    ) as (server_endpoint, _):
-        real_expert = RemoteExpert("expert.0", server_endpoint)
-        fake_expert = RemoteExpert("oiasfjiasjf", server_endpoint)
-
+    ) as server_peer_info:
+        dht = DHT(initial_peers=server_peer_info.addrs, start=True)
+        real_expert, fake_expert = create_remote_experts(
+            [
+                RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info),
+                RemoteExpertInfo(uid="oiasfjiasjf", peer_info=server_peer_info),
+            ],
+            dht=dht,
+        )
         out1 = real_expert(torch.randn(1, hidden_dim))
         out1 = real_expert(torch.randn(1, hidden_dim))
         assert out1.shape == (1, hidden_dim)
         assert out1.shape == (1, hidden_dim)
         dummy_x = torch.randn(3, hidden_dim, requires_grad=True)
         dummy_x = torch.randn(3, hidden_dim, requires_grad=True)
@@ -144,9 +152,9 @@ def test_remote_module_call(hidden_dim=16):
         out3_again.norm().backward()
         out3_again.norm().backward()
         assert dummy_x.grad is not None and dummy_x.grad.norm() > 0
         assert dummy_x.grad is not None and dummy_x.grad.norm() > 0
 
 
-        with pytest.raises(grpc.RpcError):
+        with pytest.raises(P2PDaemonError):
             real_expert(torch.randn(3, 11))
             real_expert(torch.randn(3, 11))
-        with pytest.raises(grpc.RpcError):
+        with pytest.raises(P2PDaemonError):
             fake_expert(dummy_x)
             fake_expert(dummy_x)
 
 
 
 
@@ -154,11 +162,11 @@ def test_remote_module_call(hidden_dim=16):
 def test_beam_search_correctness():
 def test_beam_search_correctness():
     all_expert_uids = [f"ffn.{5 + i}.{10 + j}.{15 + k}" for i in range(10) for j in range(10) for k in range(10)]
     all_expert_uids = [f"ffn.{5 + i}.{10 + j}.{15 + k}" for i in range(10) for j in range(10) for k in range(10)]
     dht = DHT(start=True)
     dht = DHT(start=True)
-    assert all(declare_experts(dht, all_expert_uids, endpoint="fake-endpoint"))
+    assert all(declare_experts(dht, all_expert_uids))
 
 
     dmoe = RemoteMixtureOfExperts(in_features=32, grid_size=(32, 32, 32), dht=dht, k_best=4, uid_prefix="ffn.")
     dmoe = RemoteMixtureOfExperts(in_features=32, grid_size=(32, 32, 32), dht=dht, k_best=4, uid_prefix="ffn.")
 
 
-    for i in range(25):
+    for _ in range(25):
         input = torch.randn(32)
         input = torch.randn(32)
         grid_scores = dmoe.proj(input).split_with_sizes(dmoe.beam_search.grid_size, dim=-1)
         grid_scores = dmoe.proj(input).split_with_sizes(dmoe.beam_search.grid_size, dim=-1)
 
 
@@ -173,7 +181,7 @@ def test_beam_search_correctness():
         # reference: independently find :beam_size: best experts with exhaustive search
         # reference: independently find :beam_size: best experts with exhaustive search
         all_scores = dmoe.compute_expert_scores(
         all_scores = dmoe.compute_expert_scores(
             [dim_scores.unsqueeze(0) for dim_scores in grid_scores],
             [dim_scores.unsqueeze(0) for dim_scores in grid_scores],
-            [[RemoteExpert(uid, "") for uid in all_expert_uids]],
+            [[RemoteExpert(RemoteExpertInfo(uid, None), None) for uid in all_expert_uids]],
         )[0]
         )[0]
         true_best_scores = sorted(all_scores.cpu().detach().numpy(), reverse=True)[: len(chosen_experts)]
         true_best_scores = sorted(all_scores.cpu().detach().numpy(), reverse=True)[: len(chosen_experts)]
 
 
@@ -194,9 +202,12 @@ def test_determinism(hidden_dim=16):
         num_handlers=1,
         num_handlers=1,
         hidden_dim=hidden_dim,
         hidden_dim=hidden_dim,
         optim_cls=None,
         optim_cls=None,
-        no_dht=True,
-    ) as (server_endpoint, _):
-        expert = RemoteExpert(uid=f"expert.0", endpoint=server_endpoint)
+    ) as server_peer_info:
+        dht = DHT(initial_peers=server_peer_info.addrs, start=True)
+        expert = create_remote_experts(
+            [RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info)],
+            dht=dht,
+        )[0]
 
 
         out = expert(xx, mask)
         out = expert(xx, mask)
         out_rerun = expert(xx, mask)
         out_rerun = expert(xx, mask)
@@ -220,7 +231,7 @@ def test_compute_expert_scores():
         jj = [[2, 2, 1], [0, 1, 2, 0, 1], [0], [1, 2]]
         jj = [[2, 2, 1], [0, 1, 2, 0, 1], [0], [1, 2]]
         batch_experts = [
         batch_experts = [
             [
             [
-                RemoteExpert(uid=f"expert.{ii[batch_i][expert_i]}.{jj[batch_i][expert_i]}", endpoint="[::]:1337")
+                RemoteExpert(RemoteExpertInfo(f"expert.{ii[batch_i][expert_i]}.{jj[batch_i][expert_i]}", None), None)
                 for expert_i in range(len(ii[batch_i]))
                 for expert_i in range(len(ii[batch_i]))
             ]
             ]
             for batch_i in range(len(ii))
             for batch_i in range(len(ii))
@@ -261,9 +272,10 @@ def test_client_anomaly_detection():
     server.start()
     server.start()
     try:
     try:
         server.ready.wait()
         server.ready.wait()
+        client_side_dht = DHT(initial_peers=dht.get_visible_maddrs(), start=True)
 
 
         dmoe = RemoteMixtureOfExperts(
         dmoe = RemoteMixtureOfExperts(
-            in_features=16, grid_size=(3,), dht=dht, k_best=3, uid_prefix="expert.", detect_anomalies=True
+            in_features=16, grid_size=(3,), dht=client_side_dht, k_best=3, uid_prefix="expert.", detect_anomalies=True
         )
         )
 
 
         input = torch.randn(1, 16)
         input = torch.randn(1, 16)
@@ -280,7 +292,7 @@ def test_client_anomaly_detection():
             inf_loss.backward()
             inf_loss.backward()
 
 
         dmoe = RemoteMixtureOfExperts(
         dmoe = RemoteMixtureOfExperts(
-            in_features=16, grid_size=(4,), dht=dht, k_best=4, uid_prefix="expert.", detect_anomalies=True
+            in_features=16, grid_size=(4,), dht=client_side_dht, k_best=4, uid_prefix="expert.", detect_anomalies=True
         )
         )
         output = dmoe(input)
         output = dmoe(input)
         assert output.isfinite().all()
         assert output.isfinite().all()

+ 8 - 2
tests/test_p2p_daemon_bindings.py

@@ -560,13 +560,19 @@ async def test_client_stream_handler_success(p2pcs):
 
 
     writer.close()
     writer.close()
 
 
-    # test case: registering twice can override the previous registration
+    # test case: registering twice can't override the previous registration without balanced flag
     event_third = asyncio.Event()
     event_third = asyncio.Event()
 
 
     async def handler_third(stream_info, reader, writer):
     async def handler_third(stream_info, reader, writer):
         event_third.set()
         event_third.set()
 
 
-    await p2pcs[1].stream_handler(another_proto, handler_third)
+    # p2p raises now for doubled stream handlers
+    with pytest.raises(ControlFailure):
+        await p2pcs[1].stream_handler(another_proto, handler_third)
+
+    # add in balanced mode: handler should be placed in round robin queue
+    # and become the next to be called
+    await p2pcs[1].stream_handler(another_proto, handler_third, balanced=True)
     assert another_proto in p2pcs[1].control.handlers
     assert another_proto in p2pcs[1].control.handlers
     # ensure the handler is override
     # ensure the handler is override
     assert handler_third == p2pcs[1].control.handlers[another_proto]
     assert handler_third == p2pcs[1].control.handlers[another_proto]

+ 17 - 11
tests/test_training.py

@@ -8,7 +8,8 @@ import torch.nn.functional as F
 from sklearn.datasets import load_digits
 from sklearn.datasets import load_digits
 
 
 from hivemind import DHT
 from hivemind import DHT
-from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
+from hivemind.moe.client import RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
+from hivemind.moe.client.expert import RemoteExpertInfo, create_remote_experts
 from hivemind.moe.server import background_server
 from hivemind.moe.server import background_server
 from hivemind.optim import DecentralizedAdam, DecentralizedSGD
 from hivemind.optim import DecentralizedAdam, DecentralizedSGD
 
 
@@ -19,12 +20,17 @@ def test_training(max_steps: int = 100, threshold: float = 0.9):
     X_train, y_train = torch.tensor(dataset["data"], dtype=torch.float), torch.tensor(dataset["target"])
     X_train, y_train = torch.tensor(dataset["data"], dtype=torch.float), torch.tensor(dataset["target"])
     SGD = partial(torch.optim.SGD, lr=0.05)
     SGD = partial(torch.optim.SGD, lr=0.05)
 
 
-    with background_server(num_experts=2, device="cpu", optim_cls=SGD, hidden_dim=64, num_handlers=1, no_dht=True) as (
-        server_endpoint,
-        _,
-    ):
-        expert1 = RemoteExpert("expert.0", server_endpoint)
-        expert2 = RemoteExpert("expert.1", server_endpoint)
+    with background_server(
+        num_experts=2, device="cpu", optim_cls=SGD, hidden_dim=64, num_handlers=1
+    ) as server_peer_info:
+        dht = DHT(initial_peers=server_peer_info.addrs, start=True)
+        expert1, expert2 = create_remote_experts(
+            [
+                RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info),
+                RemoteExpertInfo(uid="expert.1", peer_info=server_peer_info),
+            ],
+            dht=dht,
+        )
         model = nn.Sequential(expert2, nn.ReLU(), expert1, nn.Linear(64, 2))
         model = nn.Sequential(expert2, nn.ReLU(), expert1, nn.Linear(64, 2))
 
 
         opt = SGD(model.parameters(), lr=0.05)
         opt = SGD(model.parameters(), lr=0.05)
@@ -54,8 +60,8 @@ def test_moe_training(max_steps: int = 100, threshold: float = 0.9, num_experts=
     all_expert_uids = [f"expert.{i}" for i in range(num_experts)]
     all_expert_uids = [f"expert.{i}" for i in range(num_experts)]
     with background_server(
     with background_server(
         expert_uids=all_expert_uids, device="cpu", optim_cls=SGD, hidden_dim=64, num_handlers=1
         expert_uids=all_expert_uids, device="cpu", optim_cls=SGD, hidden_dim=64, num_handlers=1
-    ) as (server_endpoint, dht_maddrs):
-        dht = DHT(start=True, initial_peers=dht_maddrs)
+    ) as server_peer_info:
+        dht = DHT(start=True, initial_peers=server_peer_info.addrs)
 
 
         moe = RemoteMixtureOfExperts(in_features=64, grid_size=(num_experts,), dht=dht, uid_prefix="expert.", k_best=2)
         moe = RemoteMixtureOfExperts(in_features=64, grid_size=(num_experts,), dht=dht, uid_prefix="expert.", k_best=2)
         model = nn.Sequential(moe, nn.Linear(64, 2))
         model = nn.Sequential(moe, nn.Linear(64, 2))
@@ -107,8 +113,8 @@ def test_switch_training(max_steps: int = 10, threshold: float = 0.9, num_expert
     all_expert_uids = [f"expert.{i}" for i in range(num_experts)]
     all_expert_uids = [f"expert.{i}" for i in range(num_experts)]
     with background_server(
     with background_server(
         expert_uids=all_expert_uids, device="cpu", optim_cls=SGD, hidden_dim=64, num_handlers=1
         expert_uids=all_expert_uids, device="cpu", optim_cls=SGD, hidden_dim=64, num_handlers=1
-    ) as (server_endpoint, dht_maddrs):
-        dht = DHT(start=True, initial_peers=dht_maddrs)
+    ) as server_peer_info:
+        dht = DHT(start=True, initial_peers=server_peer_info.addrs)
 
 
         model = SwitchNetwork(dht, 64, 2, num_experts)
         model = SwitchNetwork(dht, 64, 2, num_experts)
         opt = SGD(model.parameters(), lr=0.05)
         opt = SGD(model.parameters(), lr=0.05)

+ 1 - 47
tests/test_util_modules.py

@@ -11,9 +11,7 @@ import torch
 
 
 import hivemind
 import hivemind
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
-from hivemind.proto.dht_pb2_grpc import DHTStub
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
 from hivemind.utils import BatchTensorDescriptor, DHTExpiration, HeapEntry, MSGPackSerializer, ValueWithExpiration
 from hivemind.utils import BatchTensorDescriptor, DHTExpiration, HeapEntry, MSGPackSerializer, ValueWithExpiration
 from hivemind.utils.asyncio import (
 from hivemind.utils.asyncio import (
     achain,
     achain,
@@ -330,50 +328,6 @@ def test_many_futures():
     p.join()
     p.join()
 
 
 
 
-@pytest.mark.forked
-@pytest.mark.asyncio
-async def test_channel_cache():
-    hivemind.ChannelCache.MAXIMUM_CHANNELS = 3
-    hivemind.ChannelCache.EVICTION_PERIOD_SECONDS = 0.1
-
-    c1 = hivemind.ChannelCache.get_stub("localhost:1337", DHTStub, aio=False)
-    c2 = hivemind.ChannelCache.get_stub("localhost:1337", DHTStub, aio=True)
-    c3 = hivemind.ChannelCache.get_stub("localhost:1338", DHTStub, aio=False)
-    c3_again = hivemind.ChannelCache.get_stub("localhost:1338", DHTStub, aio=False)
-    c1_again = hivemind.ChannelCache.get_stub("localhost:1337", DHTStub, aio=False)
-    c4 = hivemind.ChannelCache.get_stub("localhost:1339", DHTStub, aio=True)
-    c2_anew = hivemind.ChannelCache.get_stub("localhost:1337", DHTStub, aio=True)
-    c1_yetagain = hivemind.ChannelCache.get_stub("localhost:1337", DHTStub, aio=False)
-
-    await asyncio.sleep(0.2)
-    c1_anew = hivemind.ChannelCache.get_stub(target="localhost:1337", aio=False, stub_type=DHTStub)
-    c1_anew_again = hivemind.ChannelCache.get_stub(target="localhost:1337", aio=False, stub_type=DHTStub)
-    c1_otherstub = hivemind.ChannelCache.get_stub(target="localhost:1337", aio=False, stub_type=ConnectionHandlerStub)
-    await asyncio.sleep(0.05)
-    c1_otherstub_again = hivemind.ChannelCache.get_stub(
-        target="localhost:1337", aio=False, stub_type=ConnectionHandlerStub
-    )
-    all_channels = [c1, c2, c3, c4, c3_again, c1_again, c2_anew, c1_yetagain, c1_anew, c1_anew_again, c1_otherstub]
-
-    assert all(isinstance(c, DHTStub) for c in all_channels[:-1])
-    assert isinstance(all_channels[-1], ConnectionHandlerStub)
-    assert "aio" in repr(c2.rpc_find)
-    assert "aio" not in repr(c1.rpc_find)
-
-    duplicates = {
-        (c1, c1_again),
-        (c1, c1_yetagain),
-        (c1_again, c1_yetagain),
-        (c3, c3_again),
-        (c1_anew, c1_anew_again),
-        (c1_otherstub, c1_otherstub_again),
-    }
-    for i in range(len(all_channels)):
-        for j in range(i + 1, len(all_channels)):
-            ci, cj = all_channels[i], all_channels[j]
-            assert (ci is cj) == ((ci, cj) in duplicates), (i, j)
-
-
 def test_serialize_tuple():
 def test_serialize_tuple():
     test_pairs = (
     test_pairs = (
         ((1, 2, 3), [1, 2, 3]),
         ((1, 2, 3), [1, 2, 3]),
@@ -419,7 +373,7 @@ def test_split_parts():
     for combined in combined_incomplete, combined_incomplete2, combined_incomplete3:
     for combined in combined_incomplete, combined_incomplete2, combined_incomplete3:
         with pytest.raises(RuntimeError):
         with pytest.raises(RuntimeError):
             deserialize_torch_tensor(combined)
             deserialize_torch_tensor(combined)
-            # note: we rely on this being RuntimeError in hivemind.averaging.allreduce.AllreduceRunner
+            # note: we rely on this being RuntimeError in hivemind.averaging.allreduce.AllReduceRunner
 
 
 
 
 def test_generic_data_classes():
 def test_generic_data_classes():