Эх сурвалжийг харах

Merge branch 'master' into hive-rl

foksly 3 жил өмнө
parent
commit
4bb482c87f
97 өөрчлөгдсөн 2084 нэмэгдсэн , 2162 устгасан
  1. 1 1
      .github/workflows/check-style.yml
  2. 6 2
      .github/workflows/run-tests.yml
  3. 1 1
      README.md
  4. 5 6
      benchmarks/benchmark_averaging.py
  5. 0 1
      benchmarks/benchmark_dht.py
  6. 38 20
      benchmarks/benchmark_throughput.py
  7. 1 17
      docs/modules/optim.rst
  8. 5 5
      docs/modules/server.rst
  9. 9 12
      docs/user/moe.md
  10. 14 5
      examples/albert/README.md
  11. 1 1
      examples/albert/arguments.py
  12. 2 1
      examples/albert/run_trainer.py
  13. 2 1
      examples/albert/run_training_monitor.py
  14. 1 23
      examples/albert/utils.py
  15. 2 12
      hivemind/__init__.py
  16. 2 4
      hivemind/averaging/allreduce.py
  17. 69 71
      hivemind/averaging/averager.py
  18. 1 1
      hivemind/averaging/key_manager.py
  19. 2 2
      hivemind/averaging/matchmaking.py
  20. 5 1
      hivemind/compression/__init__.py
  21. 2 2
      hivemind/compression/adaptive.py
  22. 1 1
      hivemind/compression/base.py
  23. 25 1
      hivemind/compression/serialization.py
  24. 1 1
      hivemind/dht/__init__.py
  25. 12 4
      hivemind/dht/dht.py
  26. 2 2
      hivemind/dht/node.py
  27. 1 1
      hivemind/dht/routing.py
  28. 0 1
      hivemind/hivemind_cli/config.yml
  29. 76 0
      hivemind/hivemind_cli/run_dht.py
  30. 13 4
      hivemind/hivemind_cli/run_server.py
  31. 1 1
      hivemind/moe/__init__.py
  32. 59 55
      hivemind/moe/client/beam_search.py
  33. 158 37
      hivemind/moe/client/expert.py
  34. 29 25
      hivemind/moe/client/moe.py
  35. 48 0
      hivemind/moe/client/remote_expert_worker.py
  36. 3 3
      hivemind/moe/client/switch_moe.py
  37. 4 2
      hivemind/moe/expert_uid.py
  38. 1 1
      hivemind/moe/server/__init__.py
  39. 9 9
      hivemind/moe/server/checkpoints.py
  40. 106 51
      hivemind/moe/server/connection_handler.py
  41. 35 26
      hivemind/moe/server/dht_handler.py
  42. 1 1
      hivemind/moe/server/layers/dropout.py
  43. 58 0
      hivemind/moe/server/layers/optim.py
  44. 46 84
      hivemind/moe/server/module_backend.py
  45. 11 11
      hivemind/moe/server/runtime.py
  46. 66 69
      hivemind/moe/server/server.py
  47. 0 4
      hivemind/optim/__init__.py
  48. 0 34
      hivemind/optim/adaptive.py
  49. 0 44
      hivemind/optim/base.py
  50. 0 558
      hivemind/optim/collaborative.py
  51. 20 7
      hivemind/optim/grad_averager.py
  52. 1 1
      hivemind/optim/grad_scaler.py
  53. 12 4
      hivemind/optim/optimizer.py
  54. 222 0
      hivemind/optim/power_sgd_averager.py
  55. 0 229
      hivemind/optim/simple.py
  56. 1 1
      hivemind/optim/state_averager.py
  57. 52 19
      hivemind/p2p/p2p_daemon.py
  58. 6 4
      hivemind/p2p/p2p_daemon_bindings/control.py
  59. 5 4
      hivemind/p2p/p2p_daemon_bindings/p2pclient.py
  60. 4 2
      hivemind/p2p/servicer.py
  61. 24 0
      hivemind/proto/crypto.proto
  62. 0 11
      hivemind/proto/dht.proto
  63. 6 4
      hivemind/proto/p2pd.proto
  64. 0 8
      hivemind/proto/runtime.proto
  65. 2 2
      hivemind/utils/__init__.py
  66. 7 8
      hivemind/utils/asyncio.py
  67. 9 6
      hivemind/utils/crypto.py
  68. 0 210
      hivemind/utils/grpc.py
  69. 24 0
      hivemind/utils/math.py
  70. 1 1
      hivemind/utils/mpfuture.py
  71. 25 41
      hivemind/utils/networking.py
  72. 46 0
      hivemind/utils/streaming.py
  73. 1 1
      pyproject.toml
  74. 1 1
      requirements-dev.txt
  75. 0 1
      requirements.txt
  76. 38 30
      setup.py
  77. 1 2
      tests/test_allreduce.py
  78. 14 16
      tests/test_allreduce_fault_tolerance.py
  79. 16 80
      tests/test_averaging.py
  80. 63 0
      tests/test_cli_scripts.py
  81. 3 2
      tests/test_compression.py
  82. 192 0
      tests/test_connection_handler.py
  83. 21 9
      tests/test_custom_experts.py
  84. 1 1
      tests/test_dht.py
  85. 29 29
      tests/test_dht_experts.py
  86. 10 8
      tests/test_expert_backend.py
  87. 46 33
      tests/test_moe.py
  88. 84 12
      tests/test_optimizer.py
  89. 31 3
      tests/test_p2p_daemon.py
  90. 8 2
      tests/test_p2p_daemon_bindings.py
  91. 1 1
      tests/test_routing.py
  92. 83 0
      tests/test_start_server.py
  93. 18 96
      tests/test_training.py
  94. 1 53
      tests/test_util_modules.py
  95. 0 0
      tests/test_utils/__init__.py
  96. 18 0
      tests/test_utils/networking.py
  97. 2 1
      tests/test_utils/p2p_daemon.py

+ 1 - 1
.github/workflows/check-style.yml

@@ -13,7 +13,7 @@ jobs:
       - uses: psf/black@stable
         with:
           options: "--check --diff"
-          version: "22.1.0"
+          version: "22.3.0"
   isort:
     runs-on: ubuntu-latest
     steps:

+ 6 - 2
.github/workflows/run-tests.yml

@@ -12,7 +12,7 @@ jobs:
     strategy:
       matrix:
         python-version: [ 3.7, 3.8, 3.9 ]
-    timeout-minutes: 10
+    timeout-minutes: 15
     steps:
       - uses: actions/checkout@v2
       - name: Set up Python
@@ -42,6 +42,10 @@ jobs:
     timeout-minutes: 10
     steps:
       - uses: actions/checkout@v2
+      - uses: actions/setup-go@v3
+        with:
+          go-version: '1.16'
+          check-latest: true
       - name: Set up Python
         uses: actions/setup-python@v2
         with:
@@ -67,7 +71,7 @@ jobs:
   codecov_in_develop_mode:
 
     runs-on: ubuntu-latest
-    timeout-minutes: 10
+    timeout-minutes: 15
     steps:
       - uses: actions/checkout@v2
       - name: Set up Python

+ 1 - 1
README.md

@@ -63,7 +63,7 @@ By default, hivemind uses the precompiled binary of
 the [go-libp2p-daemon](https://github.com/learning-at-home/go-libp2p-daemon) library. If you face compatibility issues
 or want to build the binary yourself, you can recompile it by running `pip install . --global-option="--buildgo"`.
 Before running the compilation, please ensure that your machine has a recent version
-of [Go toolchain](https://golang.org/doc/install) (1.15 or higher).
+of [Go toolchain](https://golang.org/doc/install) (1.15 or 1.16 are supported).
 
 ### System requirements
 

+ 5 - 6
benchmarks/benchmark_averaging.py

@@ -6,10 +6,9 @@ import time
 import torch
 
 import hivemind
-from hivemind.proto import runtime_pb2
+from hivemind.compression import Float16Compression
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
-from hivemind.utils.networking import LOCALHOST
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
@@ -38,7 +37,7 @@ def benchmark_averaging(
     num_peers: int,
     target_group_size: int,
     num_rounds: int,
-    averaging_expiration: float,
+    min_matchmaking_time: float,
     request_timeout: float,
     round_timeout: float,
     hid_size: int,
@@ -64,9 +63,9 @@ def benchmark_averaging(
             dht,
             prefix="my_tensor",
             initial_group_bits=initial_bits,
-            compression_type=runtime_pb2.CompressionType.FLOAT16,
+            compression=Float16Compression(),
             target_group_size=target_group_size,
-            averaging_expiration=averaging_expiration,
+            min_matchmaking_time=min_matchmaking_time,
             request_timeout=request_timeout,
             start=True,
         )
@@ -108,7 +107,7 @@ if __name__ == "__main__":
     parser.add_argument("--num_rounds", type=int, default=5, required=False)
     parser.add_argument("--hid_size", type=int, default=256, required=False)
     parser.add_argument("--num_layers", type=int, default=3, required=False)
-    parser.add_argument("--averaging_expiration", type=float, default=5, required=False)
+    parser.add_argument("--min_matchmaking_time", type=float, default=5, required=False)
     parser.add_argument("--round_timeout", type=float, default=15, required=False)
     parser.add_argument("--request_timeout", type=float, default=1, required=False)
     parser.add_argument("--spawn_dtime", type=float, default=0.1, required=False)

+ 0 - 1
benchmarks/benchmark_dht.py

@@ -3,7 +3,6 @@ import asyncio
 import random
 import time
 import uuid
-from logging import shutdown
 from typing import Tuple
 
 import numpy as np

+ 38 - 20
benchmarks/benchmark_throughput.py

@@ -6,12 +6,15 @@ import time
 
 import torch
 
-from hivemind.moe.client import RemoteExpert
-from hivemind.moe.server import ExpertBackend, Server
+from hivemind.dht import DHT
+from hivemind.moe.client.expert import RemoteExpert
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from hivemind.moe.expert_uid import ExpertInfo
+from hivemind.moe.server import ModuleBackend, Server
 from hivemind.moe.server.layers import name_to_block
+from hivemind.p2p import P2P
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
-from hivemind.utils.networking import LOCALHOST, get_free_port
 from hivemind.utils.tensor_descr import BatchTensorDescriptor
 
 use_hivemind_log_handler("in_root_logger")
@@ -31,14 +34,29 @@ def print_device_info(device=None):
         logger.info(f"Cached:   {round(torch.cuda.memory_cached(0) / 1024 ** 3, 1)} GB")
 
 
-def client_process(can_start, benchmarking_failed, port, num_experts, batch_size, hid_dim, num_batches, backprop=True):
+def client_process(
+    can_start,
+    benchmarking_failed,
+    server_maddrs,
+    server_peer_id,
+    num_experts,
+    batch_size,
+    hid_dim,
+    num_batches,
+    backprop=True,
+) -> None:
     torch.set_num_threads(1)
     can_start.wait()
-    experts = [RemoteExpert(f"expert{i}", endpoint=f"{LOCALHOST}:{port}") for i in range(num_experts)]
+
+    p2p = RemoteExpertWorker.run_coroutine(P2P.create(initial_peers=server_maddrs))
+    experts = [
+        RemoteExpert(expert_info=ExpertInfo(uid=f"expert.{i}", peer_id=server_peer_id), p2p=p2p)
+        for i in range(num_experts)
+    ]
 
     try:
         dummy_batch = torch.randn(batch_size, hid_dim)
-        for batch_i in range(num_batches):
+        for _ in range(num_batches):
             expert = random.choice(experts)
             out = expert(dummy_batch)
             if backprop:
@@ -59,7 +77,6 @@ def benchmark_throughput(
     max_batch_size=None,
     backprop=True,
     device=None,
-    port=None,
 ):
     assert (
         not hasattr(torch.cuda, "is_initialized")
@@ -67,7 +84,6 @@ def benchmark_throughput(
         or torch.device(device) == torch.device("cpu")
     )
     assert expert_cls in name_to_block
-    port = port or get_free_port()
     max_batch_size = max_batch_size or batch_size * 4
     num_handlers = max(1, num_handlers or num_clients // 2)
     benchmarking_failed = mp.Event()
@@ -75,8 +91,7 @@ def benchmark_throughput(
     timestamps = dict(started=time.perf_counter())
 
     try:
-        # start clients and await server
-        # Note: client processes must be launched BEFORE touching gpu, even torch.cuda.is_available can cause trouble
+        server_dht = DHT(start=True)
         clients = [
             mp.Process(
                 target=client_process,
@@ -84,52 +99,55 @@ def benchmark_throughput(
                 args=(
                     can_start,
                     benchmarking_failed,
-                    port,
+                    server_dht.get_visible_maddrs(),
+                    server_dht.peer_id,
                     num_experts,
                     batch_size,
                     hid_dim,
                     num_batches_per_client,
                     backprop,
                 ),
+                daemon=True,
             )
             for i in range(num_clients)
         ]
 
         for client in clients:
-            client.daemon = True
             client.start()
 
         timestamps["launched_clients"] = timestamps["began_launching_server"] = time.perf_counter()
 
-        # start server
         device = device or ("cuda" if torch.cuda.is_available() else "cpu")
-        experts = {}
+        module_backends = {}
         for i in range(num_experts):
             expert = torch.jit.script(name_to_block[expert_cls](hid_dim))
-            experts[f"expert{i}"] = ExpertBackend(
-                name=f"expert{i}",
-                expert=expert,
+            module_backends[f"expert.{i}"] = ModuleBackend(
+                name=f"expert.{i}",
+                module=expert,
                 optimizer=torch.optim.Adam(expert.parameters()),
                 args_schema=(BatchTensorDescriptor(hid_dim),),
                 outputs_schema=BatchTensorDescriptor(hid_dim),
                 max_batch_size=max_batch_size,
             )
         timestamps["created_experts"] = time.perf_counter()
+
         server = Server(
-            None,
-            experts,
-            listen_on=f"{LOCALHOST}:{port}",
+            dht=server_dht,
+            module_backends=module_backends,
             num_connection_handlers=num_handlers,
             device=device,
         )
         server.start()
         server.ready.wait()
+
         timestamps["server_ready"] = time.perf_counter()
         can_start.set()
 
         for client in clients:
             client.join()
+
         timestamps["clients_finished"] = time.perf_counter()
+
     except BaseException as e:
         benchmarking_failed.set()
         raise e

+ 1 - 17
docs/modules/optim.rst

@@ -21,20 +21,4 @@
 
 .. currentmodule:: hivemind.optim.grad_scaler
 .. autoclass:: GradScaler
-   :member-order: bysource
-
-
-**CollaborativeOptimizer**
---------------------------
-
-
-.. automodule:: hivemind.optim.collaborative
-.. currentmodule:: hivemind.optim
-
-.. autoclass:: CollaborativeOptimizer
-   :members: step
-   :member-order: bysource
-
-.. autoclass:: CollaborativeAdaptiveOptimizer
-   :members:
-   :member-order: bysource
+   :member-order: bysource

+ 5 - 5
docs/modules/server.rst

@@ -9,9 +9,9 @@ or as a part of **hivemind.moe.client.RemoteMixtureOfExperts** that finds the mo
 The hivemind.moe.server module is organized as follows:
 
 - Server_ is the main class that publishes experts, accepts incoming requests, and passes them to Runtime_ for compute.
-- ExpertBackend_ is a wrapper for `torch.nn.Module <https://pytorch.org/docs/stable/generated/torch.nn.Module.html>`_ \
+- ModuleBackend_ is a wrapper for `torch.nn.Module <https://pytorch.org/docs/stable/generated/torch.nn.Module.html>`_ \
   that can be accessed by remote clients. It has two TaskPool_ s for forward and backward requests.
-- Runtime_ balances the device (GPU) usage between several ExpertBackend_ instances that each service one expert.
+- Runtime_ balances the device (GPU) usage between several ModuleBackend_ instances that each service one expert.
 - TaskPool_ stores incoming requests for a batch-parallel computation (e.g. forward pass), groups them into batches \
   and offers those batches to Runtime_ for processing.
 
@@ -25,9 +25,9 @@ The hivemind.moe.server module is organized as follows:
    :members:
    :member-order: bysource
 
-.. _ExpertBackend:
-.. autoclass:: ExpertBackend
-    :members: forward, backward, apply_gradients, get_info, get_pools
+.. _ModuleBackend:
+.. autoclass:: ModuleBackend
+    :members: forward, backward, on_backward, get_info, get_pools
     :member-order: bysource
 
 .. currentmodule:: hivemind.moe.server.runtime

+ 9 - 12
docs/user/moe.md

@@ -1,7 +1,7 @@
 # Mixture-of-Experts
 
 This tutorial covers the basics of Decentralized Mixture-of-Experts (DMoE).
-From the infrastructure standpoint, DMoE consists of two parts: experts hosted on peer devices, and a gating/routing function that assigns input to one of these experts.
+From the infrastructure standpoint, DMoE consists of two parts: experts hosted on peer devices, and client-side modules to access those experts.
 
 ## Host experts with a server
 
@@ -11,9 +11,8 @@ most of the model parameters and computation. The server can be started using ei
 for now. To host a server with default experts, run this in your shell:
 
 ```sh
-hivemind-server --expert_cls ffn --hidden_dim 512 --num_experts 5 --expert_pattern "expert.[0:5]" \
-                --listen_on 0.0.0.0:1337
-# note: if you omit listen_on and/or dht_port, they will be chosen automatically and printed to stdout.
+hivemind-server --expert_cls ffn --hidden_dim 512 --num_experts 5 --expert_pattern "expert.[0:5]"
+# note: server will listen to a random port. To specify interface & port, add --host_maddrs and --announce_maddrs
 ```
 
 <details style="margin-top:-24px; margin-bottom: 16px;">
@@ -22,8 +21,7 @@ hivemind-server --expert_cls ffn --hidden_dim 512 --num_experts 5 --expert_patte
 ```sh
 [2021/07/15 18:52:01.424][INFO][moe.server.create:156] Running DHT node on ['/ip4/127.0.0.1/tcp/42513/p2p/QmacLgRkAHSqdWYdQ8TePioMxQCNV2JeD3AUDmbVd69gNL'], initial peers = []
 [2021/07/15 18:52:01.424][INFO][moe.server.create:181] Generating 5 expert uids from pattern expert.[0:5]
-[2021/07/15 18:52:01.658][INFO][moe.server.run:233] Server started at 0.0.0.0:1337
-[2021/07/15 18:52:01.658][INFO][moe.server.run:234] Got 5 experts:
+[2021/07/15 18:52:01.658][INFO][moe.server.run:233] Server started with 5 experts
 [2021/07/15 18:52:01.658][INFO][moe.server.run:237] expert.4: FeedforwardBlock, 2100736 parameters
 [2021/07/15 18:52:01.658][INFO][moe.server.run:237] expert.0: FeedforwardBlock, 2100736 parameters
 [2021/07/15 18:52:01.659][INFO][moe.server.run:237] expert.3: FeedforwardBlock, 2100736 parameters
@@ -67,8 +65,7 @@ hivemind-server --expert_cls ffn --hidden_dim 512 --num_experts 10 --expert_patt
 ```sh
 [2021/07/15 18:53:41.700][INFO][moe.server.create:156] Running DHT node on ['/ip4/127.0.0.1/tcp/34487/p2p/QmcJ3jgbdwphLAiwGjvwrjimJJrdMyhLHf6tFj9viCFFGn'], initial peers = ['/ip4/127.0.0.1/tcp/42513/p2p/QmacLgRkAHSqdWYdQ8TePioMxQCNV2JeD3AUDmbVd69gNL']
 [2021/07/15 18:53:41.700][INFO][moe.server.create:181] Generating 10 expert uids from pattern expert.[5:250]
-[2021/07/15 18:53:42.085][INFO][moe.server.run:233] Server started at 0.0.0.0:36389
-[2021/07/15 18:53:42.086][INFO][moe.server.run:234] Got 10 experts:
+[2021/07/15 18:53:42.085][INFO][moe.server.run:233] Server started with 10 experts:
 [2021/07/15 18:53:42.086][INFO][moe.server.run:237] expert.55: FeedforwardBlock, 2100736 parameters
 [2021/07/15 18:53:42.086][INFO][moe.server.run:237] expert.173: FeedforwardBlock, 2100736 parameters
 [2021/07/15 18:53:42.086][INFO][moe.server.run:237] expert.164: FeedforwardBlock, 2100736 parameters
@@ -104,10 +101,10 @@ hivemind-server --expert_cls ffn --hidden_dim 512 --num_experts 10 --expert_patt
 
 </details>
 
-By default, the server will only accept connections from your local machine. To access it globally, you should replace
-`127.0.0.1` part from initial peers with server's IP address. Hivemind supports both ipv4 and ipv6 protocols and uses the same notation
-as [libp2p](https://docs.libp2p.io/concepts/addressing/). You can find more details on multiaddresses in the 
-[DHT tutorial](https://learning-at-home.readthedocs.io/en/latest/user/dht.html).
+By default, the server will only accept connections from your local network. 
+To enable training over the Internet (or some other network), you should set `--host_maddrs` and `--announce_maddrs`.
+These options also allow you to select IPv4/IPv6 network protocols and TCP and QUIC transport protocols.
+You can find more details in the [DHT tutorial](https://learning-at-home.readthedocs.io/en/latest/user/dht.html).
 
 ## Train the experts
 

+ 14 - 5
examples/albert/README.md

@@ -3,7 +3,7 @@
 This tutorial will walk you through the steps to set up collaborative training with the ALBERT-large-v2 model and the
 WikiText103 dataset. It uses Hugging Face [datasets](https://github.com/huggingface/datasets)
 and [transformers](https://github.com/huggingface/transformers/) libraries to compute local updates,
-using `hivemind.CollaborativeOptimizer` to exchange information between peers.
+using `hivemind.Optimizer` to exchange information between peers.
 
 ## Preparation
 
@@ -27,8 +27,8 @@ Run the first DHT peer to welcome trainers and record training statistics (e.g.,
 
 ```
 $ ./run_training_monitor.py --wandb_project Demo-run
-Oct 14 16:26:36.083 [INFO] Running a DHT peer. To connect other peers to this one over the Internet,
-use --initial_peers /ip4/1.2.3.4/tcp/1337/p2p/XXXX /ip4/1.2.3.4/udp/31337/quic/p2p/XXXX
+Oct 14 16:26:36.083 [INFO] Running a DHT instance. To connect other peers to this one, use
+ --initial_peers /ip4/1.2.3.4/tcp/1337/p2p/XXXX /ip4/1.2.3.4/udp/31337/quic/p2p/XXXX
 Oct 14 16:26:36.083 [INFO] Full list of visible multiaddresses: ...
 wandb: Currently logged in as: XXX (use `wandb login --relogin` to force relogin)
 wandb: Tracking run with wandb version 0.10.32
@@ -130,11 +130,20 @@ monitors on different servers and list all of them as `--initial_peers`. The sys
 as at least one externally accessible participant is available. For short- to mid-term experiments you can host the
 monitor on a [free-tier VM](https://www.quora.com/Are-there-any-free-online-virtual-machines).
 
+By default, the training monitor changes its address on restart, so you may launch two monitors on the same machine.
+If you'd like to fix the monitor's address (e.g., before sending it to your collaborators),
+you need to **(a)** make it listen a specific TCP/UDP port and **(b)** provide a path for storing the identity file
+(which allows [libp2p](https://libp2p.io/) to reuse the same peer ID after restart). You may do that like this:
+
+```bash
+./run_training_monitor.py --wandb_project YOUR_WANDB_PROJECT --host_maddrs /ip4/0.0.0.0/tcp/31337 --identity_path ./identity.key
+```
+
 ### Tuning for hardware/network
 
 The optimal training parameters for each peer depend on its GPU and internet connection. If a peer cannot accept
 incoming connections (e.g. when in colab or behind a firewall), add `--client_mode` to the training script (see example
-below). In case of high network latency, you may want to increase `--averaging_expiration` by a few seconds or
+below). In case of high network latency, you may want to increase `--matchmaking_time` by a few seconds or
 set `--batch_size_lead` to start averaging a bit earlier than the rest of the collaboration. GPU-wise, each peer should
 be able to process one local microbatch each 0.5–1 seconds (see trainer's progress bar). To achieve that, we
 recommend tuning `--per_device_train_batch_size` and `--gradient_accumulation_steps`.
@@ -173,7 +182,7 @@ Here's an example of a full trainer script for Google Colab:
 !ulimit -n 4096 && ./hivemind/examples/albert/run_trainer.py \
     --initial_peers ONE_OR_MORE_PEERS \
     --logging_dir ./logs --logging_first_step --output_dir ./outputs --overwrite_output_dir \
-    --client_mode --averaging_expiration 10 --batch_size_lead 300 --gradient_accumulation_steps 1
+    --client_mode --matchmaking_time 10 --batch_size_lead 300 --gradient_accumulation_steps 1
 ```
 
 ### Using IPFS

+ 1 - 1
examples/albert/arguments.py

@@ -38,7 +38,7 @@ class BaseTrainingArguments:
         default=None,
         metadata={
             "help": "Path to a pre-generated private key file. If defined, makes the peer ID deterministic. "
-            "May be generated using ``./p2p-keygen`` from ``go-libp2p-daemon``."
+            "If the file does not exist yet, writes a new private key to this file."
         },
     )
 

+ 2 - 1
examples/albert/run_trainer.py

@@ -19,6 +19,7 @@ from transformers.trainer_utils import is_main_process
 
 from hivemind import DHT, Float16Compression, Optimizer, get_dht_time
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from hivemind.utils.networking import log_visible_maddrs
 
 import utils
 from arguments import (
@@ -227,7 +228,7 @@ def main():
         announce_maddrs=collaboration_args.announce_maddrs,
         identity_path=collaboration_args.identity_path,
     )
-    utils.log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=collaboration_args.use_ipfs)
+    log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=collaboration_args.use_ipfs)
 
     total_batch_size_per_step = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
     if torch.cuda.device_count() != 0:

+ 2 - 1
examples/albert/run_training_monitor.py

@@ -14,6 +14,7 @@ from transformers import AlbertConfig, AlbertForPreTraining, HfArgumentParser, g
 import hivemind
 from hivemind.optim.state_averager import TrainingStateAverager
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from hivemind.utils.networking import log_visible_maddrs
 
 import utils
 from arguments import AveragerArguments, BaseTrainingArguments, OptimizerArguments
@@ -168,7 +169,7 @@ if __name__ == "__main__":
         announce_maddrs=monitor_args.announce_maddrs,
         identity_path=monitor_args.identity_path,
     )
-    utils.log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=monitor_args.use_ipfs)
+    log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=monitor_args.use_ipfs)
 
     if monitor_args.wandb_project is not None:
         wandb.init(project=monitor_args.wandb_project)

+ 1 - 23
examples/albert/utils.py

@@ -1,13 +1,11 @@
 from typing import Dict, List, Tuple
 
-from multiaddr import Multiaddr
 from pydantic import BaseModel, StrictFloat, confloat, conint
 
-from hivemind import choose_ip_address
 from hivemind.dht.crypto import RSASignatureValidator
 from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
 from hivemind.dht.validation import RecordValidatorBase
-from hivemind.utils.logging import TextStyle, get_logger
+from hivemind.utils.logging import get_logger
 
 logger = get_logger(__name__)
 
@@ -28,23 +26,3 @@ def make_validators(run_id: str) -> Tuple[List[RecordValidatorBase], bytes]:
     signature_validator = RSASignatureValidator()
     validators = [SchemaValidator(MetricSchema, prefix=run_id), signature_validator]
     return validators, signature_validator.local_public_key
-
-
-def log_visible_maddrs(visible_maddrs: List[Multiaddr], only_p2p: bool) -> None:
-    if only_p2p:
-        unique_addrs = {addr["p2p"] for addr in visible_maddrs}
-        initial_peers_str = " ".join(f"/p2p/{addr}" for addr in unique_addrs)
-    else:
-        available_ips = [Multiaddr(addr) for addr in visible_maddrs if "ip4" in addr or "ip6" in addr]
-        if available_ips:
-            preferred_ip = choose_ip_address(available_ips)
-            selected_maddrs = [addr for addr in visible_maddrs if preferred_ip in str(addr)]
-        else:
-            selected_maddrs = visible_maddrs
-        initial_peers_str = " ".join(str(addr) for addr in selected_maddrs)
-
-    logger.info(
-        f"Running a DHT peer. To connect other peers to this one over the Internet, use "
-        f"{TextStyle.BOLD}{TextStyle.BLUE}--initial_peers {initial_peers_str}{TextStyle.RESET}"
-    )
-    logger.info(f"Full list of visible multiaddresses: {' '.join(str(addr) for addr in visible_maddrs)}")

+ 2 - 12
hivemind/__init__.py

@@ -2,24 +2,14 @@ from hivemind.averaging import DecentralizedAverager
 from hivemind.compression import *
 from hivemind.dht import DHT
 from hivemind.moe import (
-    ExpertBackend,
+    ModuleBackend,
     RemoteExpert,
     RemoteMixtureOfExperts,
     RemoteSwitchMixtureOfExperts,
     Server,
     register_expert_class,
 )
-from hivemind.optim import (
-    CollaborativeAdaptiveOptimizer,
-    CollaborativeOptimizer,
-    DecentralizedAdam,
-    DecentralizedOptimizer,
-    DecentralizedOptimizerBase,
-    DecentralizedSGD,
-    GradScaler,
-    Optimizer,
-    TrainingAverager,
-)
+from hivemind.optim import GradScaler, Optimizer, TrainingAverager
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, PeerInfo
 from hivemind.utils import *
 

+ 2 - 4
hivemind/averaging/allreduce.py

@@ -1,6 +1,6 @@
 import asyncio
 from enum import Enum
-from typing import Any, AsyncIterator, Dict, Optional, Sequence, Set, Tuple, Type
+from typing import AsyncIterator, Optional, Sequence, Set, Tuple, Type
 
 import torch
 
@@ -50,7 +50,6 @@ class AllReduceRunner(ServicerBase):
     :param peer_fractions: for each peer, a target fraction of vector elements that this peer should average
       (the actual number of values by peer will be nearly proportional, but there are no exact guarantees)
     :param modes: AveragingMode for each peer in ordered_peer_ids (normal, client-only or auxiliary)
-    :param gathered: additional user-defined data collected from this group
     :param sender_timeout: during all_reduce, any sender that fails to send tensor chunk within this many seconds from
       previous chunk will be marked as failed and excluded from averaging. default: equal to next_chunk_timeout
     :param reducer_timeout: during all_reduce, any reducer that fails to send results chunk within this many seconds
@@ -73,7 +72,6 @@ class AllReduceRunner(ServicerBase):
         ordered_peer_ids: Sequence[PeerID],
         peer_fractions: Tuple[float, ...],
         modes: Optional[Sequence[AveragingMode]] = None,
-        gathered: Optional[Dict[PeerID, Any]] = None,
         sender_timeout: Optional[float] = None,
         reducer_timeout: Optional[float] = None,
         **kwargs,
@@ -99,7 +97,7 @@ class AllReduceRunner(ServicerBase):
             assert mode != AveragingMode.CLIENT or frac == 0, "client-mode peer should have zero all-reduce fraction"
 
         self.group_id, self.ordered_peer_ids = group_id, ordered_peer_ids
-        self.modes, self.peer_fractions, self.gathered = modes, peer_fractions, gathered
+        self.modes, self.peer_fractions = modes, peer_fractions
 
         if weight is None:
             weight = float(modes[self.ordered_peer_ids.index(self.peer_id)] != AveragingMode.AUX)

+ 69 - 71
hivemind/averaging/averager.py

@@ -22,13 +22,7 @@ from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.load_balancing import load_balance_peers
 from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
 from hivemind.averaging.partition import DEFAULT_PART_SIZE_BYTES
-from hivemind.compression import (
-    CompressionBase,
-    CompressionInfo,
-    NoCompression,
-    deserialize_torch_tensor,
-    serialize_torch_tensor,
-)
+from hivemind.compression import CompressionBase, CompressionInfo, NoCompression, deserialize_torch_tensor
 from hivemind.dht import DHT, DHTID
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
 from hivemind.p2p.p2p_daemon_bindings.utils import ControlFailure, DispatchFailure
@@ -36,7 +30,6 @@ from hivemind.proto import averaging_pb2
 from hivemind.utils import MPFuture, TensorDescriptor, get_logger
 from hivemind.utils.asyncio import (
     achain,
-    afirst,
     aiter_with_timeout,
     anext,
     as_aiter,
@@ -44,8 +37,8 @@ from hivemind.utils.asyncio import (
     enter_asynchronously,
     switch_to_uvloop,
 )
-from hivemind.utils.grpc import combine_from_streaming, split_for_streaming
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
+from hivemind.utils.streaming import combine_from_streaming, split_for_streaming
 from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration, get_dht_time
 
 # flavour types
@@ -109,7 +102,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     """
 
     _matchmaking: Matchmaking
-    _pending_group_assembled: asyncio.Event
+    _pending_groups_registered: asyncio.Event
     _state_updated: asyncio.Event
     _p2p: P2P
     serializer = MSGPackSerializer
@@ -124,7 +117,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         target_group_size: Optional[int] = None,
         min_group_size: int = 2,
         initial_group_bits: str = "",
-        averaging_expiration: Optional[float] = None,
         min_matchmaking_time: float = 5.0,
         request_timeout: float = 3.0,
         averaging_alpha: float = 1.0,
@@ -152,11 +144,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         assert all(bit in "01" for bit in initial_group_bits)
         assert not client_mode or not auxiliary, "auxiliary peers must accept incoming connections"
 
-        if averaging_expiration is not None:
-            logger.warning("averaging_expiration is deprecated and will be removed soon, use min_matchmaking_time")
-            assert min_matchmaking_time == 5.0, "Can't set both averaging_expiration and min_matchmaking_time"
-            min_matchmaking_time = averaging_expiration
-
         super().__init__()
         self.dht = dht
         self.prefix = prefix
@@ -207,7 +194,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             reducer_timeout=reducer_timeout,
         )
         self._averaging_alpha, self._allreduce_timeout = averaging_alpha, allreduce_timeout
-        self._running_groups: Dict[GroupID, AllReduceRunner] = {}  # one or more assembled groups that run all-reduce
+        self._running_groups: Dict[GroupID, asyncio.Future[AllReduceRunner]] = {}
 
         self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)  # a control pipe used to communicate with daemon
 
@@ -309,8 +296,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     asyncio.create_task(self._declare_for_download_periodically())
 
                 self._state_updated = asyncio.Event()
-                self._pending_group_assembled = asyncio.Event()
-                self._pending_group_assembled.set()
+                self._pending_groups_registered = asyncio.Event()
+                self._pending_groups_registered.set()
             except Exception as e:
                 # Loglevel is DEBUG since normally the exception is propagated to the caller
                 logger.debug(e, exc_info=True)
@@ -441,7 +428,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
             while not step.done():
                 try:
-                    self._pending_group_assembled.clear()
+                    self._pending_groups_registered.clear()
                     step.stage = AveragingStage.LOOKING_FOR_GROUP
                     matchmaking_task = asyncio.create_task(find_peers_or_notify_cancel())
                     check_cancel_task = asyncio.create_task(step.wait_for_cancel())
@@ -458,17 +445,21 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     if group_info is None:
                         raise AllreduceException("Averaging step failed: could not find a group")
 
-                    step.stage = AveragingStage.RUNNING_ALLREDUCE
-
-                    step.set_result(
-                        await asyncio.wait_for(
-                            self._run_allreduce(
-                                group_info, tensor_infos=self.tensor_infos, weight=step.weight, **self.allreduce_kwargs
-                            ),
-                            timeout=self._allreduce_timeout,
+                    with self._register_allreduce_group(group_info):
+                        step.stage = AveragingStage.RUNNING_ALLREDUCE
+
+                        step.set_result(
+                            await asyncio.wait_for(
+                                self._aggregate_with_group(
+                                    group_info,
+                                    tensor_infos=self.tensor_infos,
+                                    weight=step.weight,
+                                    **self.allreduce_kwargs,
+                                ),
+                                timeout=self._allreduce_timeout,
+                            )
                         )
-                    )
-                    # averaging is finished, loop will now exit
+                        # averaging is finished, loop will now exit
 
                 except (
                     AllreduceException,
@@ -503,8 +494,21 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     )
                 )
 
-    async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
-        """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
+    @contextlib.contextmanager
+    def _register_allreduce_group(self, group_info: GroupInfo):
+        """Register a given group for one or more all-reduce rounds"""
+        try:
+            self._running_groups[group_info.group_id] = asyncio.Future()
+            self._pending_groups_registered.set()
+            yield
+        finally:
+            maybe_future = self._running_groups.pop(group_info.group_id, None)
+            if maybe_future is not None and not maybe_future.done():
+                logger.warning(f"All-reduce group {group_info.group_id} did not finish.")
+            self._pending_groups_registered.set()
+
+    async def _aggregate_with_group(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
+        """Run aggregation in a given group and update tensors in place, return gathered metadata"""
         try:
             bandwidths, mode_ids, user_gathered_bytes = zip(*map(self.serializer.loads, group_info.gathered))
             user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered_bytes)))
@@ -519,47 +523,39 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             )
 
             async with enter_asynchronously(self.get_tensors()) as local_tensors:
-                allreduce = AllReduceRunner(
-                    p2p=self._p2p,
-                    servicer_type=type(self),
-                    prefix=self.prefix,
-                    group_id=group_info.group_id,
-                    tensors=local_tensors,
-                    ordered_peer_ids=group_info.peer_ids,
-                    peer_fractions=peer_fractions,
-                    gathered=user_gathered,
-                    modes=modes,
-                    **kwargs,
-                )
-
-                with self.register_allreduce_group(group_info.group_id, allreduce):
-                    if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
-                        iter_results = allreduce.run()
-                        async for tensor, update in azip(as_aiter(*local_tensors), iter_results):
-                            # all-reduce is performed asynchronously while iterating
-                            tensor.add_(update, alpha=self._averaging_alpha)
-                        self._state_updated.set()
-
-                    else:
-                        async for _ in allreduce:  # trigger all-reduce by iterating
-                            raise ValueError("aux peers should not receive averaged tensors")
-
-                return allreduce.gathered
+                await self._run_allreduce_inplace_(local_tensors, group_info, peer_fractions=peer_fractions, **kwargs)
+                return user_gathered
         except BaseException as e:
             if isinstance(e, Exception):
                 logger.exception(e)
             raise MatchmakingException(f"Unable to run All-Reduce: {e}")
 
-    @contextlib.contextmanager
-    def register_allreduce_group(self, group_id: GroupID, allreduce: AllReduceRunner):
-        """registers a given all-reduce runner to listen for incoming connections"""
-        try:
-            self._running_groups[group_id] = allreduce
-            self._pending_group_assembled.set()
-            yield
-        finally:
-            self._running_groups.pop(group_id, None)
-            self._pending_group_assembled.set()
+    async def _run_allreduce_inplace_(
+        self, tensors: Sequence[torch.Tensor], group_info: GroupInfo, group_id: Optional[bytes] = None, **kwargs
+    ):
+        """Run one allreduce process to average tensors inplace. Can be called more than a few times in one aggregation process"""
+        group_id = group_info.group_id if group_id is None else group_id
+
+        runner = AllReduceRunner(
+            p2p=self._p2p,
+            servicer_type=type(self),
+            prefix=self.prefix,
+            tensors=tensors,
+            group_id=group_id,
+            ordered_peer_ids=group_info.peer_ids,
+            **kwargs,
+        )
+        assert group_id in self._running_groups, f"Group id {group_id} was not registered in _register_allreduce_group"
+        self._running_groups[group_id].set_result(runner)
+
+        if runner.modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
+            async for tensor, update in azip(as_aiter(*tensors), runner):
+                tensor.add_(update, alpha=self._averaging_alpha)
+                self.last_updated = get_dht_time()
+                self._state_updated.set()
+        else:
+            async for _ in runner:
+                raise ValueError("aux peers should not receive averaged tensors")
 
     @contextlib.contextmanager
     def get_tensors(self) -> Sequence[torch.Tensor]:
@@ -586,13 +582,14 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         if request.group_id not in self._running_groups:
             # this handles a special case when leader accepted us to group AND began allreduce right away,
             # but his response with group_id was delayed and other peers got to us first
-            await self._pending_group_assembled.wait()
+            await self._pending_groups_registered.wait()
 
-        group = self._running_groups.get(request.group_id)
-        if group is None:
+        future = self._running_groups.get(request.group_id)
+        if future is None:
             yield averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
             return
 
+        group = await future
         async for message in group.rpc_aggregate_part(achain(as_aiter(request), stream), context):
             yield message
 
@@ -706,6 +703,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                         stream = await stub.rpc_download_state(averaging_pb2.DownloadRequest())
                         current_tensor_parts, tensors = [], []
 
+                        # TODO merge this with hivemind.compression.deserialize_tensor_stream
                         async for message in aiter_with_timeout(stream, timeout=timeout):
                             if message.metadata:
                                 metadata = self.serializer.loads(message.metadata)

+ 1 - 1
hivemind/averaging/key_manager.py

@@ -7,7 +7,7 @@ import numpy as np
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.dht import DHT
 from hivemind.p2p import PeerID
-from hivemind.utils import DHTExpiration, ValueWithExpiration, get_dht_time, get_logger
+from hivemind.utils import DHTExpiration, get_logger
 
 GroupKey = str
 GROUP_PATTERN = re.compile("^(([^.])+)[.]0b[01]*$")  # e.g. bert_exp4_averaging.0b01001101

+ 2 - 2
hivemind/averaging/matchmaking.py

@@ -12,11 +12,11 @@ from typing import AsyncIterator, Dict, Optional, Set, Tuple, Type
 from hivemind.averaging.control import StepControl
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.key_manager import GroupKey, GroupKeyManager
-from hivemind.dht import DHT, DHTID, DHTExpiration
+from hivemind.dht import DHT, DHTID
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
 from hivemind.p2p.p2p_daemon_bindings.utils import ControlFailure, DispatchFailure
 from hivemind.proto import averaging_pb2
-from hivemind.utils import TimedStorage, get_dht_time, get_logger, timed_storage
+from hivemind.utils import DHTExpiration, TimedStorage, get_dht_time, get_logger, timed_storage
 from hivemind.utils.asyncio import anext, cancel_and_wait
 
 logger = get_logger(__name__)

+ 5 - 1
hivemind/compression/__init__.py

@@ -6,4 +6,8 @@ from hivemind.compression.adaptive import PerTensorCompression, RoleAdaptiveComp
 from hivemind.compression.base import CompressionBase, CompressionInfo, NoCompression, TensorRole
 from hivemind.compression.floating import Float16Compression, ScaledFloat16Compression
 from hivemind.compression.quantization import Quantile8BitQuantization, Uniform8BitQuantization
-from hivemind.compression.serialization import deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.compression.serialization import (
+    deserialize_tensor_stream,
+    deserialize_torch_tensor,
+    serialize_torch_tensor,
+)

+ 2 - 2
hivemind/compression/adaptive.py

@@ -3,8 +3,8 @@ from typing import Mapping, Sequence, Union
 
 import torch
 
-import hivemind
 from hivemind.compression.base import CompressionBase, CompressionInfo, Key, NoCompression, TensorRole
+from hivemind.compression.serialization import deserialize_torch_tensor
 from hivemind.proto import runtime_pb2
 
 
@@ -20,7 +20,7 @@ class AdaptiveCompressionBase(CompressionBase, ABC):
         return self.choose_compression(info).compress(tensor, info=info, allow_inplace=allow_inplace)
 
     def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
-        return hivemind.compression.deserialize_torch_tensor(serialized_tensor)
+        return deserialize_torch_tensor(serialized_tensor)
 
 
 class SizeAdaptiveCompression(AdaptiveCompressionBase):

+ 1 - 1
hivemind/compression/base.py

@@ -80,7 +80,7 @@ class NoCompression(CompressionBase):
     compression_type = runtime_pb2.CompressionType.NONE
 
     def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
-        array = tensor.numpy()
+        array = tensor.detach().numpy()
         return runtime_pb2.Tensor(
             compression=self.compression_type,
             buffer=array.tobytes(),

+ 25 - 1
hivemind/compression/serialization.py

@@ -1,4 +1,6 @@
-from typing import Dict, Optional
+from __future__ import annotations
+
+from typing import AsyncIterator, Dict, Iterable, List, Optional
 
 import torch
 
@@ -6,6 +8,7 @@ from hivemind.compression.base import CompressionBase, CompressionInfo, NoCompre
 from hivemind.compression.floating import Float16Compression, ScaledFloat16Compression
 from hivemind.compression.quantization import Quantile8BitQuantization, Uniform8BitQuantization
 from hivemind.proto import runtime_pb2
+from hivemind.utils.streaming import combine_from_streaming
 
 BASE_COMPRESSION_TYPES: Dict[str, CompressionBase] = dict(
     NONE=NoCompression(),
@@ -41,3 +44,24 @@ def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Ten
     """Restore a pytorch tensor from a protobuf message"""
     compression = BASE_COMPRESSION_TYPES[runtime_pb2.CompressionType.Name(serialized_tensor.compression)]
     return compression.extract(serialized_tensor).requires_grad_(serialized_tensor.requires_grad)
+
+
+async def deserialize_tensor_stream(
+    stream: AsyncIterator[Iterable[runtime_pb2.Tensor]],
+) -> List[torch.Tensor]:
+    """Async wrapper of combine_from_streaming that combines tensors from a stream of parts and deserializes them"""
+
+    tensors = []
+    tensor_parts = []
+
+    async for parts in stream:
+        for part in parts:
+            if part.dtype and tensor_parts:
+                tensors.append(deserialize_torch_tensor(combine_from_streaming(tensor_parts)))
+                tensor_parts = []
+
+            tensor_parts.append(part)
+    if tensor_parts:
+        tensors.append(deserialize_torch_tensor(combine_from_streaming(tensor_parts)))
+
+    return tensors

+ 1 - 1
hivemind/dht/__init__.py

@@ -15,5 +15,5 @@ The code is organized as follows:
 
 from hivemind.dht.dht import DHT
 from hivemind.dht.node import DEFAULT_NUM_WORKERS, DHTNode
-from hivemind.dht.routing import DHTID, DHTExpiration, DHTKey, DHTValue, Subkey
+from hivemind.dht.routing import DHTID, DHTValue
 from hivemind.dht.validation import CompositeValidator, RecordValidatorBase

+ 12 - 4
hivemind/dht/dht.py

@@ -55,6 +55,7 @@ class DHT(mp.Process):
         **kwargs,
     ):
         self._parent_pid = os.getpid()
+        self._origin_pid = os.getpid()
         super().__init__()
 
         if not (
@@ -168,6 +169,7 @@ class DHT(mp.Process):
         :param kwargs: parameters forwarded to DHTNode.get_many_by_id
         :returns: (value, expiration time); if value was not found, returns None
         """
+        assert os.getpid() != self.pid, "calling *external* DHT interface from inside DHT will result in a deadlock"
         future = MPFuture()
         self._outer_pipe.send(("_get", [], dict(key=key, latest=latest, future=future, **kwargs)))
         return future if return_future else future.result()
@@ -201,6 +203,7 @@ class DHT(mp.Process):
         :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
         :returns: True if store succeeds, False if it fails (due to no response or newer value)
         """
+        assert os.getpid() != self.pid, "calling *external* DHT interface from inside DHT will result in a deadlock"
         future = MPFuture()
         self._outer_pipe.send(
             (
@@ -243,8 +246,9 @@ class DHT(mp.Process):
           DHT fields made by this coroutine will not be accessible from the host process.
         :note: all time-consuming operations in coro should be asynchronous (e.g. asyncio.sleep instead of time.sleep)
           or use asyncio.get_event_loop().run_in_executor(...) to prevent coroutine from blocking background DHT tasks
-        :note: when run_coroutine is called with wait=False, MPFuture can be cancelled to interrupt the task.
+        :note: when run_coroutine is called with return_future=False, MPFuture can be cancelled to interrupt the task.
         """
+        assert os.getpid() != self.pid, "calling *external* DHT interface from inside DHT will result in a deadlock"
         future = MPFuture()
         self._outer_pipe.send(("_run_coroutine", [], dict(coro=coro, future=future)))
         return future if return_future else future.result()
@@ -274,7 +278,11 @@ class DHT(mp.Process):
     @property
     def peer_id(self) -> PeerID:
         if self._peer_id is None:
-            self._peer_id = self.run_coroutine(DHT._get_peer_id)
+            if os.getpid() == self.pid:
+                self._peer_id = self._node.peer_id
+            else:
+                # note: we cannot run_coroutine from the same pid because it would deadlock the event loop
+                self._peer_id = self.run_coroutine(DHT._get_peer_id)
         return self._peer_id
 
     @staticmethod
@@ -309,8 +317,8 @@ class DHT(mp.Process):
         Get a replica of a P2P instance used in the DHT process internally.
         The replica uses the same P2P daemon as the DHT and only works while DHT is alive.
         """
-
-        if self._p2p_replica is None:
+        if self._p2p_replica is None or self._origin_pid != os.getpid():
+            self._origin_pid = os.getpid()
             daemon_listen_maddr = self.run_coroutine(DHT._get_p2p_daemon_listen_maddr)
             self._p2p_replica = await P2P.replicate(daemon_listen_maddr)
         return self._p2p_replica

+ 2 - 2
hivemind/dht/node.py

@@ -146,7 +146,7 @@ class DHTNode:
         :param cache_locally: if True, caches all values (stored or found) in a node-local cache
         :param cache_on_store: if True, update cache entries for a key after storing a new item for that key
         :param cache_nearest: whenever DHTNode finds a value, it will also store (cache) this value on this many
-          nodes nearest nodes visited by search algorithm. Prefers nodes that are nearest to :key: but have no value yet
+          nearest nodes visited by search algorithm. Prefers nodes that are nearest to :key: but have no value yet
         :param cache_size: if specified, local cache will store up to this many records (as in LRU cache)
         :param cache_refresh_before_expiry: if nonzero, refreshes locally cached values
           if they are accessed this many seconds before expiration time.
@@ -341,7 +341,7 @@ class DHTNode:
     ) -> bool:
         """
         Find num_replicas best nodes to store (key, value) and store it there at least until expiration time.
-        :note: store is a simplified interface to store_many, all kwargs are be forwarded there
+        :note: store is a simplified interface to store_many, all kwargs are forwarded there
         :returns: True if store succeeds, False if it fails (due to no response or newer value)
         """
         store_ok = await self.store_many([key], [value], [expiration_time], subkeys=[subkey], **kwargs)

+ 1 - 1
hivemind/dht/routing.py

@@ -10,7 +10,7 @@ from itertools import chain
 from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
 
 from hivemind.p2p import PeerID
-from hivemind.utils import DHTExpiration, MSGPackSerializer, get_dht_time
+from hivemind.utils import MSGPackSerializer, get_dht_time
 
 DHTKey = Subkey = DHTValue = Any
 BinaryDHTID = BinaryDHTValue = bytes

+ 0 - 1
hivemind/hivemind_cli/config.yml

@@ -1,4 +1,3 @@
-listen_on: 0.0.0.0:*
 num_experts: 16
 expert_cls: ffn
 hidden_dim: 1024

+ 76 - 0
hivemind/hivemind_cli/run_dht.py

@@ -0,0 +1,76 @@
+import time
+from argparse import ArgumentParser
+
+from hivemind.dht import DHT, DHTNode
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from hivemind.utils.networking import log_visible_maddrs
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__name__)
+
+
+async def report_status(dht: DHT, node: DHTNode):
+    logger.info(
+        f"{len(node.protocol.routing_table.uid_to_peer_id) + 1} DHT nodes (including this one) "
+        f"are in the local routing table "
+    )
+    logger.debug(f"Routing table contents: {node.protocol.routing_table}")
+    logger.info(f"Local storage contains {len(node.protocol.storage)} keys")
+    logger.debug(f"Local storage contents: {node.protocol.storage}")
+
+
+def main():
+    parser = ArgumentParser()
+    parser.add_argument(
+        "--initial_peers",
+        nargs="*",
+        help="Multiaddrs of the peers that will welcome you into the existing DHT. "
+        "Example: /ip4/203.0.113.1/tcp/31337/p2p/XXXX /ip4/203.0.113.2/tcp/7777/p2p/YYYY",
+    )
+    parser.add_argument(
+        "--host_maddrs",
+        nargs="*",
+        default=["/ip4/0.0.0.0/tcp/0"],
+        help="Multiaddrs to listen for external connections from other DHT instances. "
+        "Defaults to all IPv4 interfaces and the TCP protocol: /ip4/0.0.0.0/tcp/0",
+    )
+    parser.add_argument(
+        "--announce_maddrs",
+        nargs="*",
+        help="Visible multiaddrs the host announces for external connections from other DHT instances",
+    )
+    parser.add_argument(
+        "--use_ipfs",
+        action="store_true",
+        help='Use IPFS to find initial_peers. If enabled, you only need to provide the "/p2p/XXXX" '
+        "part of the multiaddrs for the initial_peers "
+        "(no need to specify a particular IPv4/IPv6 host and port)",
+    )
+    parser.add_argument(
+        "--identity_path",
+        help="Path to a private key file. If defined, makes the peer ID deterministic. "
+        "If the file does not exist, writes a new private key to this file.",
+    )
+    parser.add_argument(
+        "--refresh_period", type=int, default=30, help="Period (in seconds) for fetching the keys from DHT"
+    )
+
+    args = parser.parse_args()
+
+    dht = DHT(
+        start=True,
+        initial_peers=args.initial_peers,
+        host_maddrs=args.host_maddrs,
+        announce_maddrs=args.announce_maddrs,
+        use_ipfs=args.use_ipfs,
+        identity_path=args.identity_path,
+    )
+    log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=args.use_ipfs)
+
+    while True:
+        dht.run_coroutine(report_status, return_future=False)
+        time.sleep(args.refresh_period)
+
+
+if __name__ == "__main__":
+    main()

+ 13 - 4
hivemind/hivemind_cli/run_server.py

@@ -18,8 +18,7 @@ def main():
     # fmt:off
     parser = configargparse.ArgParser(default_config_files=["config.yml"])
     parser.add('-c', '--config', required=False, is_config_file=True, help='config file path')
-    parser.add_argument('--listen_on', type=str, default='0.0.0.0:*', required=False,
-                        help="'localhost' for local connections only, '0.0.0.0' for ipv4 '[::]' for ipv6")
+
     parser.add_argument('--num_experts', type=int, default=None, required=False, help="The number of experts to serve")
     parser.add_argument('--expert_pattern', type=str, default=None, required=False,
                         help='all expert uids will follow this pattern, e.g. "myexpert.[0:256].[0:1024]" will'
@@ -32,6 +31,11 @@ def main():
                         help="expert type from test_utils.layers, e.g. 'ffn', 'transformer', 'det_dropout' or 'nop'")
     parser.add_argument('--hidden_dim', type=int, default=1024, required=False, help='main dimension for expert_cls')
 
+    parser.add_argument('--host_maddrs', type=list, nargs='+', default=['/ip4/0.0.0.0/tcp/0'], required=False,
+                        help='Multiaddrs to listen for external connections from other p2p instances; default: all IPv4 and TCP: /ip4/0.0.0.0/tcp/0')
+    parser.add_argument('--announce_maddrs', type=list, nargs='+', default=None, required=False,
+                        help='Visible multiaddrs the host announces for external connections from other p2p instances')
+
     parser.add_argument('--num_handlers', type=int, default=None, required=False,
                         help='server will use this many processes to handle incoming requests')
     parser.add_argument('--min_batch_size', type=int, default=1,
@@ -46,10 +50,14 @@ def main():
                         help='LR scheduler type to use')
     parser.add_argument('--num_warmup_steps', type=int, required=False,
                         help='The number of warmup steps for LR schedule')
-    parser.add_argument('--num_total_steps', type=int, required=False, help='The total number of steps for LR schedule')
+    parser.add_argument('--update_period', type=float, required=False, default=30,
+                        help='Server will report experts to DHT once in this many seconds')
+    parser.add_argument('--expiration', type=float, required=False, default=None,
+                        help='DHT entries will expire after this many seconds')
+    parser.add_argument('--num_training_steps', type=int, required=False, help='The total number of steps for LR schedule')
+
     parser.add_argument('--clip_grad_norm', type=float, required=False, help='Maximum gradient norm used for clipping')
 
-    parser.add_argument('--no_dht', action='store_true', help='if specified, the server will not be attached to a dht')
     parser.add_argument('--initial_peers', type=str, nargs='*', required=False, default=[],
                         help='multiaddrs of one or more active DHT peers (if you want to join an existing DHT)')
     parser.add_argument('--increase_file_limit', action='store_true',
@@ -62,6 +70,7 @@ def main():
 
     parser.add_argument('--custom_module_path', type=str, required=False,
                         help='Path of a file with custom nn.modules, wrapped into special decorator')
+    parser.add_argument('--identity_path', type=str, required=False, help='Path to identity file to be used in P2P')
 
     # fmt:on
     args = vars(parser.parse_args())

+ 1 - 1
hivemind/moe/__init__.py

@@ -1,6 +1,6 @@
 from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
 from hivemind.moe.server import (
-    ExpertBackend,
+    ModuleBackend,
     Server,
     background_server,
     declare_experts,

+ 59 - 55
hivemind/moe/client/beam_search.py

@@ -4,20 +4,22 @@ from collections import deque
 from functools import partial
 from typing import Deque, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
 
-from hivemind.dht import DHT, DHTExpiration, DHTNode
-from hivemind.moe.client.expert import RemoteExpert
-from hivemind.moe.server.expert_uid import (
+from hivemind.dht import DHT, DHTNode
+from hivemind.moe.client.expert import RemoteExpert, batch_create_remote_experts, create_remote_experts
+from hivemind.moe.expert_uid import (
     FLAT_EXPERT,
     PREFIX_PATTERN,
     UID_DELIMITER,
     Coordinate,
+    ExpertInfo,
     ExpertPrefix,
     ExpertUID,
     Score,
-    UidEndpoint,
     is_valid_prefix,
+    is_valid_uid,
 )
-from hivemind.utils import MPFuture, get_dht_time, get_logger
+from hivemind.p2p import PeerID
+from hivemind.utils import DHTExpiration, MPFuture, ValueWithExpiration, get_dht_time, get_logger
 
 logger = get_logger(__name__)
 
@@ -94,7 +96,7 @@ class MoEBeamSearcher:
 
     def get_initial_beam(
         self, scores: Sequence[float], beam_size: int, return_future: bool = False
-    ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
+    ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, ExpertInfo]]]:
         """
         :param scores: prefer suffix coordinates that have highest scores
         :param beam_size: select this many active suffixes with highest scores
@@ -124,9 +126,9 @@ class MoEBeamSearcher:
         negative_caching: bool,
         cache_expiration: DHTExpiration,
         num_workers: Optional[int] = None,
-    ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
+    ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, ExpertInfo]]]:
         num_workers = num_workers or dht.num_workers or beam_size
-        beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = []
+        beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, ExpertInfo]]] = []
         unattempted_indices: List[Coordinate] = sorted(
             range(len(scores)), key=scores.__getitem__
         )  # from worst to best
@@ -144,13 +146,7 @@ class MoEBeamSearcher:
             try:
                 maybe_prefix_data = await pending_task
                 if maybe_prefix_data is not None and isinstance(maybe_prefix_data.value, dict):
-                    successors = {
-                        coord: UidEndpoint(*match.value)
-                        for coord, match in maybe_prefix_data.value.items()
-                        if isinstance(coord, Coordinate)
-                        and isinstance(getattr(match, "value", None), list)
-                        and len(match.value) == 2
-                    }
+                    successors = MoEBeamSearcher._select_valid_entries(maybe_prefix_data)
                     if successors:
                         beam.append((scores[pending_best_index], pending_best_prefix, successors))
                 elif maybe_prefix_data is None and negative_caching:
@@ -172,7 +168,7 @@ class MoEBeamSearcher:
 
     def get_active_successors(
         self, prefixes: List[ExpertPrefix], grid_size: Optional[int] = None, return_future: bool = False
-    ) -> Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]]:
+    ) -> Dict[ExpertPrefix, Dict[Coordinate, ExpertInfo]]:
         """
         :param prefixes: a list of prefix for which to find active successor uids
         :param grid_size: if specified, only return successors if ther are in range [0, grid_size)
@@ -195,6 +191,22 @@ class MoEBeamSearcher:
             return_future=return_future,
         )
 
+    @staticmethod
+    def _select_valid_entries(entry: ValueWithExpiration, grid_size: Optional[int] = None):
+        if not isinstance(entry, ValueWithExpiration) or not isinstance(entry.value, dict):
+            return {}
+        return {
+            coord: ExpertInfo(uid=match.value[0], peer_id=PeerID.from_base58(match.value[1]))
+            for coord, match in entry.value.items()
+            if isinstance(coord, Coordinate)
+            and (grid_size is None or 0 <= coord < grid_size)
+            and isinstance(match, ValueWithExpiration)
+            and isinstance(match.value, tuple)
+            and len(match.value) == 2
+            and is_valid_uid(match.value[0])
+            and isinstance(match.value[1], str)
+        }
+
     @staticmethod
     async def _get_active_successors(
         dht: DHT,
@@ -204,33 +216,23 @@ class MoEBeamSearcher:
         negative_caching: bool,
         cache_expiration: DHTExpiration,
         num_workers: Optional[int] = None,
-    ) -> Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]]:
+    ) -> Dict[ExpertPrefix, Dict[Coordinate, ExpertInfo]]:
         grid_size = grid_size or float("inf")
         num_workers = num_workers or min(len(prefixes), dht.num_workers or len(prefixes))
         dht_responses = await node.get_many(keys=prefixes, num_workers=num_workers)
-        successors: Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]] = {}
+        successors: Dict[ExpertPrefix, Dict[Coordinate, ExpertInfo]] = {}
         for prefix, found in dht_responses.items():
-            if found and isinstance(found.value, dict):
-                successors[prefix] = {
-                    coord: UidEndpoint(*match.value)
-                    for coord, match in found.value.items()
-                    if isinstance(coord, Coordinate)
-                    and 0 <= coord < grid_size
-                    and isinstance(getattr(match, "value", None), list)
-                    and len(match.value) == 2
-                }
-            else:
-                successors[prefix] = {}
-                if found is None and negative_caching:
-                    logger.debug(f"DHT negative caching: storing a 'no prefix' entry for {prefix}")
-                    asyncio.create_task(
-                        node.store(prefix, subkey=-1, value=None, expiration_time=get_dht_time() + cache_expiration)
-                    )
+            successors[prefix] = MoEBeamSearcher._select_valid_entries(found, grid_size)
+            if not successors[prefix] and negative_caching:
+                logger.debug(f"DHT negative caching: storing a 'no prefix' entry for {prefix}")
+                asyncio.create_task(
+                    node.store(prefix, subkey=-1, value=None, expiration_time=get_dht_time() + cache_expiration)
+                )
         return successors
 
     def find_best_experts(
         self, grid_scores: Sequence[Sequence[float]], beam_size: int, return_future: bool = False
-    ) -> Union[List[RemoteExpert], MPFuture[RemoteExpert]]:
+    ) -> Union[List[RemoteExpert], MPFuture[List[RemoteExpert]]]:
         """
         Find and return :beam_size: active experts with highest scores, use both local cache and DHT
 
@@ -240,12 +242,11 @@ class MoEBeamSearcher:
          After time_budget is reached, beam search won't search for more experts and instead fall back on local cache
          Please note that any queries that fall outside the budget will still be performed in background and cached
          for subsequent iterations as long as DHTNode.cache_locally is True
-        :param num_workers: use up to this many concurrent workers to search DHT
         :param return_future: if set to True, returns MPFuture that can be awaited to get the actual result
         :returns: a list that contains *up to* k_best RemoteExpert instances
         """
         assert len(grid_scores) == len(self.grid_size) and beam_size > 0
-        return self.dht.run_coroutine(
+        result = self.dht.run_coroutine(
             partial(
                 self._find_best_experts,
                 prefix=self.uid_prefix,
@@ -257,6 +258,7 @@ class MoEBeamSearcher:
             ),
             return_future,
         )
+        return create_remote_experts(result, self.dht, return_future)
 
     @classmethod
     async def _find_best_experts(
@@ -269,23 +271,23 @@ class MoEBeamSearcher:
         negative_caching: bool,
         cache_expiration: DHTExpiration,
         num_workers: Optional[int] = None,
-    ) -> List[RemoteExpert]:
+    ) -> List[ExpertInfo]:
         num_workers = num_workers or min(beam_size, dht.num_workers or beam_size)
 
         # form initial beam from top-k active L1 prefixes, each row is (score, uid prefix, possible suffixes)
-        beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = await cls._get_initial_beam(
+        beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, ExpertInfo]]] = await cls._get_initial_beam(
             dht, node, prefix, beam_size, grid_scores[0], negative_caching, min(beam_size, num_workers)
         )
 
-        best_experts_heap: List[Tuple[Score, UidEndpoint]] = []  # max-heap of expert uids/endpoints ordered by scores
+        best_experts_heap: List[Tuple[Score, ExpertInfo]] = []  # max-heap of expert infos ordered by scores
         unique_experts: Set[ExpertUID] = set()
 
         for dim_index in range(1, len(grid_scores) - 1):
-            for score, uid_endpoint in cls._iterate_matching_experts(beam, grid_scores):
-                if uid_endpoint.uid not in unique_experts:
+            for score, expert_info in cls._iterate_matching_experts(beam, grid_scores):
+                if expert_info.uid not in unique_experts:
                     push_and_maybe_pop = heapq.heappush if len(best_experts_heap) < beam_size else heapq.heappushpop
-                    push_and_maybe_pop(best_experts_heap, (score, uid_endpoint))
-                    unique_experts.add(uid_endpoint.uid)
+                    push_and_maybe_pop(best_experts_heap, (score, expert_info))
+                    unique_experts.add(expert_info.uid)
 
             # form new beam using successors from the current beam
             dim_scores = grid_scores[dim_index]
@@ -298,6 +300,7 @@ class MoEBeamSearcher:
                     if isinstance(next_coord, int) and 0 <= next_coord < len(dim_scores)
                 ),
             )
+
             _, best_uid_prefixes = zip(*best_active_pairs)
 
             # search DHT for next step suffixes
@@ -316,19 +319,18 @@ class MoEBeamSearcher:
                 break
 
         # add best experts from the final beam
-        for score, uid_endpoint in cls._iterate_matching_experts(beam, grid_scores):
-            if uid_endpoint.uid not in unique_experts:
+        for score, expert_info in cls._iterate_matching_experts(beam, grid_scores):
+            if expert_info.uid not in unique_experts:
                 push_and_maybe_pop = heapq.heappush if len(best_experts_heap) < beam_size else heapq.heappushpop
-                push_and_maybe_pop(best_experts_heap, (score, uid_endpoint))
-                unique_experts.add(uid_endpoint.uid)
+                push_and_maybe_pop(best_experts_heap, (score, expert_info))
+                unique_experts.add(expert_info.uid)
 
-        best_experts = [RemoteExpert(*uid_endpoint) for score, uid_endpoint in sorted(best_experts_heap, reverse=True)]
-        return best_experts
+        return [expert_info for _, expert_info in sorted(best_experts_heap, reverse=True)]
 
     @staticmethod
     def _iterate_matching_experts(
-        beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]], grid_scores: Sequence[Sequence[float]]
-    ) -> Iterator[Tuple[Score, UidEndpoint]]:
+        beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, ExpertInfo]]], grid_scores: Sequence[Sequence[float]]
+    ) -> Iterator[Tuple[Score, ExpertInfo]]:
         """iterate over all exemplar experts attached to current beam"""
         for score, prefix, suffixes in beam:
             for next_coord, match in suffixes.items():
@@ -351,7 +353,7 @@ class MoEBeamSearcher:
 
     def batch_find_best_experts(
         self, batch_grid_scores: Sequence[Sequence[Sequence[float]]], beam_size: int, return_future: bool = False
-    ) -> Union[List[List[RemoteExpert]], MPFuture]:
+    ) -> Union[List[List[RemoteExpert]], MPFuture[List[List[RemoteExpert]]]]:
         """
         Find and return :beam_size: active experts with highest scores, use both local cache and DHT
 
@@ -364,7 +366,7 @@ class MoEBeamSearcher:
         :param return_future: if set to True, returns MPFuture that can be awaited to get the actual result
         :returns: a list that contains *up to* k_best RemoteExpert instances
         """
-        return self.dht.run_coroutine(
+        result = self.dht.run_coroutine(
             partial(
                 self._batch_find_best_experts,
                 prefix=self.uid_prefix,
@@ -376,6 +378,8 @@ class MoEBeamSearcher:
             return_future,
         )
 
+        return batch_create_remote_experts(result, self.dht, return_future)
+
     @classmethod
     async def _batch_find_best_experts(
         cls,
@@ -386,7 +390,7 @@ class MoEBeamSearcher:
         beam_size: int,
         negative_caching: bool,
         num_workers: Optional[int],
-    ) -> Sequence[Sequence[RemoteExpert]]:
+    ) -> Sequence[Sequence[ExpertInfo]]:
         batch_grid_scores = [
             [tuple(grid_score[i]) for grid_score in batch_grid_scores] for i in range(len(batch_grid_scores[0]))
         ]

+ 158 - 37
hivemind/moe/client/expert.py

@@ -1,43 +1,61 @@
-from typing import Any, Dict, Optional, Tuple
+from __future__ import annotations
+
+from concurrent.futures import Future
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
 
 import torch
 import torch.nn as nn
 from torch.autograd.function import once_differentiable
 
-from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
-from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
-from hivemind.utils import Endpoint, MSGPackSerializer, nested_compare, nested_flatten, nested_pack
-from hivemind.utils.grpc import ChannelCache
+from hivemind import moe
+from hivemind.compression import deserialize_tensor_stream, deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.dht import DHT
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from hivemind.moe.expert_uid import ExpertInfo
+from hivemind.p2p import P2P, PeerID, StubBase
+from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
+from hivemind.proto import runtime_pb2
+from hivemind.utils.asyncio import amap_in_executor, iter_as_aiter
+from hivemind.utils.mpfuture import MPFuture
+from hivemind.utils.nested import nested_compare, nested_flatten, nested_pack
+from hivemind.utils.serializer import MSGPackSerializer
+from hivemind.utils.streaming import split_for_streaming
 
 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert
 
 
-def _get_expert_stub(endpoint: Endpoint, *extra_options: Tuple[str, Any]):
-    """Create a gRPC stub to access remote expert or use previously created stub from a process-wide cache"""
-    channel_options = (("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)) + extra_options
-    return ChannelCache.get_stub(endpoint, runtime_grpc.ConnectionHandlerStub, aio=False, options=channel_options)
+def get_server_stub(p2p: P2P, server_peer_id: PeerID) -> "ConnectionHandlerStub":
+    """Create an RPC stub that can send requests to any expert on the specified remote server"""
+    return moe.server.connection_handler.ConnectionHandler.get_stub(p2p, server_peer_id)
 
 
 class RemoteExpert(nn.Module):
     """
     A simple module that runs forward/backward of an expert hosted on a remote machine.
     Works seamlessly with pytorch autograd. (this is essentially a simple RPC function)
-
     Warning: RemoteExpert currently assumes that you provide it with correct input shapes.
     Sending wrong input shapes can cause RemoteExpert to freeze indefinitely due to error in runtime.
 
-    :param uid: unique expert identifier
-    :param endpoint: network endpoint of a server that services that expert, e.g. "201.123.321.99:1337" or "[::]:8080"
+    :param expert_info: RemoteExpertInfo with uid and server PeerInfo
+    :param p2p: P2P instance connected to the running p2pd
     """
 
-    def __init__(self, uid, endpoint: Endpoint):
+    def __init__(self, expert_info: ExpertInfo, p2p: P2P):
         super().__init__()
-        self.uid, self.endpoint = uid, endpoint
-        self._info = None
+        self._info, self.p2p = expert_info, p2p
+        self._rpc_info = None
+
+    @property
+    def uid(self):
+        return self._info.uid
+
+    @property
+    def peer_id(self) -> PeerID:
+        return self._info.peer_id
 
     @property
-    def stub(self):
-        return _get_expert_stub(self.endpoint)
+    def stub(self) -> StubBase:
+        return get_server_stub(self.p2p, self.peer_id)
 
     def forward(self, *args, **kwargs):
         """Call RemoteExpert for the specified inputs and return its output(s). Compatible with pytorch.autograd."""
@@ -52,18 +70,125 @@ class RemoteExpert(nn.Module):
             raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
 
         flat_outputs = _RemoteModuleCall.apply(DUMMY, self.uid, self.stub, self.info, *nested_flatten(forward_inputs))
+
         # Note: we send DUMMY to prevent torch from excluding expert from backward if no other inputs require grad
         return nested_pack(flat_outputs, structure=self.info["outputs_schema"])
 
     @property
     def info(self):
-        if self._info is None:
-            outputs = self.stub.info(runtime_pb2.ExpertUID(uid=self.uid))
-            self._info = MSGPackSerializer.loads(outputs.serialized_info)
-        return self._info
+        if self._rpc_info is None:
+            outputs = RemoteExpertWorker.run_coroutine(self.stub.rpc_info(runtime_pb2.ExpertUID(uid=self.uid)))
+            self._rpc_info = MSGPackSerializer.loads(outputs.serialized_info)
+        return self._rpc_info
 
     def extra_repr(self):
-        return f"uid={self.uid}, endpoint={self.endpoint}"
+        return f"uid={self.uid}, server_peer_id={self.peer_id}"
+
+
+def _create_remote_experts(infos: Sequence[Optional[ExpertInfo]], p2p: P2P) -> List[Optional[RemoteExpert]]:
+    experts: List[Optional[RemoteExpert]] = []
+    for info in infos:
+        if info is not None:
+            experts.append(RemoteExpert(info, p2p))
+        else:
+            experts.append(None)
+    return experts
+
+
+def create_remote_experts(
+    infos: Union[Sequence[Optional[ExpertInfo]], MPFuture], dht: DHT, return_future: bool = False
+) -> Union[List[Optional[RemoteExpert]], Future]:
+    if return_future:
+
+        async def _unpack(infos_future: MPFuture, dht: DHT):
+            p2p = await dht.replicate_p2p()
+            return _create_remote_experts(await infos_future, p2p)
+
+        return RemoteExpertWorker.run_coroutine(_unpack(infos, dht), return_future)
+
+    p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
+    return _create_remote_experts(infos, p2p)
+
+
+def batch_create_remote_experts(
+    infos: Union[Sequence[Sequence[Optional[ExpertInfo]]], MPFuture],
+    dht: DHT,
+    return_future: bool = False,
+) -> Union[List[List[Optional[RemoteExpert]]], Future]:
+    if return_future:
+
+        async def _unpack(infos_future: MPFuture, dht: DHT):
+            p2p = await dht.replicate_p2p()
+            return [_create_remote_experts(i, p2p) for i in await infos_future]
+
+        return RemoteExpertWorker.run_coroutine(_unpack(infos, dht), return_future)
+
+    return [create_remote_experts(exps, dht) for exps in infos]
+
+
+async def _backward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
+    split = (part for tensor in serialized_tensors for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE))
+
+    grad_inputs = await stub.rpc_backward_stream(
+        amap_in_executor(
+            lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor]),
+            iter_as_aiter(split),
+        ),
+    )
+    tensors_stream = amap_in_executor(lambda msg: msg.tensors, grad_inputs)
+    return await deserialize_tensor_stream(tensors_stream)
+
+
+async def _backward_unary(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
+    grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
+        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors))
+    )
+    return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]
+
+
+async def expert_backward(
+    uid: str, inputs_and_grads: Sequence[torch.Tensor], serialized_tensors: Iterable[runtime_pb2.Tensor], stub
+) -> List[torch.Tensor]:
+    size = 0
+    for t in inputs_and_grads:
+        size += t.element_size() * t.nelement()
+        if size > DEFAULT_MAX_MSG_SIZE:
+            return await _backward_stream(uid, serialized_tensors, stub)
+    else:
+        return await _backward_unary(uid, serialized_tensors, stub)
+
+
+async def _forward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
+    split = (p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE))
+
+    outputs = await stub.rpc_forward_stream(
+        amap_in_executor(
+            lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor]),
+            iter_as_aiter(split),
+        ),
+    )
+
+    tensors_stream = amap_in_executor(lambda msg: msg.tensors, outputs)
+    return await deserialize_tensor_stream(tensors_stream)
+
+
+async def _forward_unary(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
+    outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
+        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors))
+    )
+    return [deserialize_torch_tensor(t) for t in outputs.tensors]
+
+
+async def expert_forward(
+    uid: str, inputs: Sequence[torch.Tensor], serialized_tensors: Iterable[runtime_pb2.Tensor], stub
+) -> List[torch.Tensor]:
+    size = 0
+    for t in inputs:
+        size += t.element_size() * t.nelement()
+        if size > DEFAULT_MAX_MSG_SIZE:
+            return await _forward_stream(uid, serialized_tensors, stub)
+    else:
+        return await _forward_unary(uid, serialized_tensors, stub)
 
 
 class _RemoteModuleCall(torch.autograd.Function):
@@ -74,7 +199,7 @@ class _RemoteModuleCall(torch.autograd.Function):
         ctx,
         dummy: torch.Tensor,
         uid: str,
-        stub: runtime_grpc.ConnectionHandlerStub,
+        stub: "ConnectionHandlerStub",
         info: Dict[str, Any],
         *inputs: torch.Tensor,
     ) -> Tuple[torch.Tensor, ...]:
@@ -83,15 +208,11 @@ class _RemoteModuleCall(torch.autograd.Function):
         inputs = tuple(tensor.cpu().detach() for tensor in inputs)
         ctx.uid, ctx.stub, ctx.info = uid, stub, info
         ctx.save_for_backward(*inputs)
-
-        serialized_tensors = [
-            serialize_torch_tensor(inp, proto.compression)
-            for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))
-        ]
-
-        outputs = stub.forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
-
-        deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
+        serialized_tensors = (
+            serialize_torch_tensor(tensor, proto.compression)
+            for tensor, proto in zip(inputs, nested_flatten(info["forward_schema"]))
+        )
+        deserialized_outputs = RemoteExpertWorker.run_coroutine(expert_forward(uid, inputs, serialized_tensors, stub))
 
         return tuple(deserialized_outputs)
 
@@ -101,12 +222,12 @@ class _RemoteModuleCall(torch.autograd.Function):
         grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
         inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs_cpu)))
         backward_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))
-        serialized_tensors = [
+        serialized_tensors = (
             serialize_torch_tensor(tensor, proto.compression)
             for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
-        ]
-
-        grad_inputs = ctx.stub.backward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
+        )
+        deserialized_grad_inputs = RemoteExpertWorker.run_coroutine(
+            expert_backward(ctx.uid, inputs_and_grad_outputs, serialized_tensors, ctx.stub)
+        )
 
-        deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
         return (DUMMY, None, None, None, *deserialized_grad_inputs)

+ 29 - 25
hivemind/moe/client/moe.py

@@ -1,20 +1,21 @@
 from __future__ import annotations
 
 import time
+from concurrent.futures import Future
 from queue import Empty, Queue
 from typing import Any, Dict, List, Optional, Tuple
 
-import grpc
 import torch
 import torch.nn as nn
 from torch.autograd.function import once_differentiable
 
-from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.compression import serialize_torch_tensor
 from hivemind.dht import DHT
 from hivemind.moe.client.beam_search import MoEBeamSearcher
-from hivemind.moe.client.expert import DUMMY, RemoteExpert, _get_expert_stub
-from hivemind.moe.server.expert_uid import UID_DELIMITER
-from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
+from hivemind.moe.client.expert import DUMMY, RemoteExpert, expert_backward, expert_forward, get_server_stub
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from hivemind.moe.expert_uid import UID_DELIMITER
+from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError
 from hivemind.utils import nested_flatten, nested_map, nested_pack
 from hivemind.utils.logging import get_logger
 
@@ -104,7 +105,7 @@ class RemoteMixtureOfExperts(nn.Module):
                     "No responding experts found during beam search. Check that UID prefixes and "
                     "the grid size are consistent with running Server instances."
                 )
-            except grpc.RpcError as e:
+            except P2PDaemonError as e:
                 logger.warning(f"Failed to get RemoteMixtureOfExperts.output_shape: {e}")
 
         expert_mask, *expert_outputs = _RemoteCallMany.apply(
@@ -178,7 +179,7 @@ class RemoteMixtureOfExperts(nn.Module):
             # grab some expert to set ensemble output shape
             proj_device = self.proj.weight.device
             dummy_scores_concat = self.proj(torch.randn(1, self.proj.in_features, device=proj_device))
-            dummy_scores = dummy_scores_concat.cpu().split_with_sizes(self.beam_search.grid_size, dim=-1)
+            dummy_scores = dummy_scores_concat.cpu().detach().split_with_sizes(self.beam_search.grid_size, dim=-1)
             dummy_experts = self.beam_search.find_best_experts(dummy_scores, beam_size=1)
             self._expert_info = dummy_experts[0].info
         return self._expert_info
@@ -223,15 +224,18 @@ class _RemoteCallMany(torch.autograd.Function):
         assert len(experts_per_sample) == len(flat_inputs_per_sample) == num_samples
 
         # dispatch tasks to all remote experts collect responses
-        pending_tasks: Dict[grpc.Future, Tuple[int, int]] = {}
+        pending_tasks: Dict[Future, Tuple[int, int]] = {}
         for i in range(num_samples):
             for j, expert in enumerate(experts_per_sample[i]):
-                input_tensors = [
+                stub = get_server_stub(expert.p2p, expert.peer_id)
+                serialized_tensors = (
                     serialize_torch_tensor(tensor, proto.compression)
                     for tensor, proto in zip(flat_inputs_per_sample[i], nested_flatten(info["forward_schema"]))
-                ]
-                stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint)
-                new_task = stub.forward.future(runtime_pb2.ExpertRequest(uid=expert.uid, tensors=input_tensors))
+                )
+                new_task = RemoteExpertWorker.run_coroutine(
+                    expert_forward(expert.uid, flat_inputs_per_sample[i], serialized_tensors, stub),
+                    return_future=True,
+                )
                 pending_tasks[new_task] = (i, j)
 
         responded_inds, alive_flat_outputs = cls._collect_responses(
@@ -316,14 +320,16 @@ class _RemoteCallMany(torch.autograd.Function):
         for i, j, inputs_ij, grad_outputs_ij in zip(
             alive_ii.cpu().numpy(), alive_jj.cpu().numpy(), inputs_per_expert, grad_outputs_per_expert
         ):
-            expert = expert_per_sample[i.item()][j.item()]
-            stub = _get_expert_stub(expert.endpoint)
+            expert: RemoteExpert = expert_per_sample[i.item()][j.item()]
+            stub = get_server_stub(expert.p2p, expert.peer_id)
             inputs_and_grad_outputs = tuple(nested_flatten((inputs_ij, grad_outputs_ij)))
-            tensors_serialized = [
+            serialized_tensors = (
                 serialize_torch_tensor(tensor, proto.compression)
                 for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
-            ]
-            new_task = stub.backward.future(runtime_pb2.ExpertRequest(uid=expert.uid, tensors=tensors_serialized))
+            )
+            new_task = RemoteExpertWorker.run_coroutine(
+                expert_backward(expert.uid, inputs_and_grad_outputs, serialized_tensors, stub), return_future=True
+            )
             pending_tasks[new_task] = (i, j)
 
         survivor_inds, survivor_grad_inputs = cls._collect_responses(
@@ -358,7 +364,7 @@ class _RemoteCallMany(torch.autograd.Function):
 
     @staticmethod
     def _collect_responses(
-        task_to_indices: Dict[grpc.Future, Tuple[int, int]],
+        task_to_indices: Dict[Future, Tuple[int, int]],
         num_samples: int,
         k_min: int,
         timeout_total: Optional[float],
@@ -408,17 +414,15 @@ class _RemoteCallMany(torch.autograd.Function):
         return finished_indices, finished_outputs
 
 
-def _process_dispatched_task(task: grpc.Future, detect_anomalies: bool) -> Optional[Tuple[torch.Tensor]]:
+def _process_dispatched_task(task: Future, detect_anomalies: bool) -> Optional[Tuple[torch.Tensor]]:
     if task.exception() or task.cancelled():
         logger.warning(f"Task {task} failed: {type(task.exception())}")
         return None
 
-    deserialized_outputs = []
-    for tensor in task.result().tensors:
-        deserialized_tensor = deserialize_torch_tensor(tensor)
-        if detect_anomalies and not deserialized_tensor.isfinite().all():
+    outputs = task.result()
+    for tensor in outputs:
+        if detect_anomalies and not tensor.isfinite().all():
             logger.error(f"Task {task} failed: output tensor contains nan/inf values")
             return None
-        deserialized_outputs.append(deserialized_tensor)
 
-    return tuple(deserialized_outputs)
+    return outputs

+ 48 - 0
hivemind/moe/client/remote_expert_worker.py

@@ -0,0 +1,48 @@
+import os
+from concurrent.futures import Future
+from queue import Queue
+from threading import Thread
+from typing import Awaitable, Optional
+
+from hivemind.utils import switch_to_uvloop
+
+
+class RemoteExpertWorker:
+    """Local thread for managing async tasks related to RemoteExpert"""
+
+    _task_queue: Queue = Queue()
+    _event_thread: Optional[Thread] = None
+    _pid: int = -1
+
+    @classmethod
+    def _run(cls):
+        loop = switch_to_uvloop()
+
+        async def receive_tasks():
+            while True:
+                cor, future = cls._task_queue.get()
+                try:
+                    result = await cor
+                except Exception as e:
+                    future.set_exception(e)
+                    continue
+                if not future.cancelled():
+                    future.set_result(result)
+
+        loop.run_until_complete(receive_tasks())
+
+    @classmethod
+    def run_coroutine(cls, coro: Awaitable, return_future: bool = False):
+        if cls._event_thread is None or cls._pid != os.getpid():
+            cls._pid = os.getpid()
+            cls._event_thread = Thread(target=cls._run, daemon=True)
+            cls._event_thread.start()
+
+        future = Future()
+        cls._task_queue.put((coro, future))
+
+        if return_future:
+            return future
+
+        result = future.result()
+        return result

+ 3 - 3
hivemind/moe/client/switch_moe.py

@@ -2,12 +2,12 @@ from __future__ import annotations
 
 from typing import List, Tuple
 
-import grpc
 import torch
 
 from hivemind.moe.client.expert import DUMMY, RemoteExpert
 from hivemind.moe.client.moe import RemoteMixtureOfExperts, _RemoteCallMany
-from hivemind.moe.server.expert_uid import UID_DELIMITER
+from hivemind.moe.expert_uid import UID_DELIMITER
+from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError
 from hivemind.utils import nested_flatten, nested_pack
 from hivemind.utils.logging import get_logger
 
@@ -110,7 +110,7 @@ class RemoteSwitchMixtureOfExperts(RemoteMixtureOfExperts):
                     "No responding experts found during beam search. Check that UID prefixes and "
                     "the grid size are consistent with running Server instances."
                 )
-            except grpc.RpcError as e:
+            except P2PDaemonError as e:
                 logger.warning(f"Failed to get RemoteSwitchMixtureOfExperts.output_shape: {e}")
 
         expert_mask, *expert_outputs = _RemoteCallMany.apply(

+ 4 - 2
hivemind/moe/server/expert_uid.py → hivemind/moe/expert_uid.py

@@ -1,10 +1,12 @@
+from __future__ import annotations
+
 import re
 from typing import NamedTuple, Tuple, Union
 
-from hivemind.utils import Endpoint
+from hivemind.p2p import PeerID
 
 ExpertUID, ExpertPrefix, Coordinate, Score = str, str, int, float
-UidEndpoint = NamedTuple("UidEndpoint", [("uid", ExpertUID), ("endpoint", Endpoint)])
+ExpertInfo = NamedTuple("ExpertInfo", [("uid", ExpertUID), ("peer_id", PeerID)])
 UID_DELIMITER = "."  # when declaring experts, DHT store all prefixes of that expert's uid, split over this prefix
 FLAT_EXPERT = -1  # grid prefix reserved for storing 1d expert uids. Used to speed up find_best_experts in 1d case.
 UID_PATTERN = re.compile("^(([^.])+)([.](?:[0]|([1-9]([0-9]*))))+$")  # e.g. ffn_expert.98.76.54 - prefix + some dims

+ 1 - 1
hivemind/moe/server/__init__.py

@@ -1,4 +1,4 @@
 from hivemind.moe.server.dht_handler import declare_experts, get_experts
-from hivemind.moe.server.expert_backend import ExpertBackend
 from hivemind.moe.server.layers import register_expert_class
+from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.moe.server.server import Server, background_server

+ 9 - 9
hivemind/moe/server/checkpoints.py

@@ -8,7 +8,7 @@ from typing import Dict
 
 import torch
 
-from hivemind.moe.server.expert_backend import ExpertBackend
+from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.utils.logging import get_logger
 
 logger = get_logger(__name__)
@@ -34,23 +34,23 @@ def copy_tree(src: str, dst: str):
 
 
 class CheckpointSaver(threading.Thread):
-    def __init__(self, expert_backends: Dict[str, ExpertBackend], checkpoint_dir: Path, update_period: int):
+    def __init__(self, module_backends: Dict[str, ModuleBackend], checkpoint_dir: Path, update_period: float):
         super().__init__()
         assert is_directory(checkpoint_dir)
-        self.expert_backends = expert_backends
+        self.module_backends = module_backends
         self.update_period = update_period
         self.checkpoint_dir = checkpoint_dir
         self.stop = threading.Event()
 
         # create expert directories to ensure that the directory is writable and checkpoints can be loaded
-        store_experts(self.expert_backends, self.checkpoint_dir)
+        store_experts(self.module_backends, self.checkpoint_dir)
 
     def run(self) -> None:
         while not self.stop.wait(self.update_period):
-            store_experts(self.expert_backends, self.checkpoint_dir)
+            store_experts(self.module_backends, self.checkpoint_dir)
 
 
-def store_experts(experts: Dict[str, ExpertBackend], checkpoint_dir: Path):
+def store_experts(experts: Dict[str, ModuleBackend], checkpoint_dir: Path):
     logger.debug(f"Storing experts at {checkpoint_dir.absolute()}")
     assert is_directory(checkpoint_dir)
     timestamp = datetime.now().isoformat(sep="_")
@@ -59,17 +59,17 @@ def store_experts(experts: Dict[str, ExpertBackend], checkpoint_dir: Path):
             expert_dir = Path(tmpdirname) / expert_name
             expert_dir.mkdir()
             checkpoint_name = expert_dir / f"checkpoint_{timestamp}.pt"
-            torch.save(expert_backend.get_full_state(), checkpoint_name)
+            torch.save(expert_backend.state_dict(), checkpoint_name)
             os.symlink(checkpoint_name, expert_dir / "checkpoint_last.pt")
         copy_tree(tmpdirname, str(checkpoint_dir))
 
 
-def load_experts(experts: Dict[str, ExpertBackend], checkpoint_dir: Path):
+def load_experts(experts: Dict[str, ModuleBackend], checkpoint_dir: Path):
     assert is_directory(checkpoint_dir)
     for expert_name, expert in experts.items():
         checkpoints_folder = checkpoint_dir / expert_name
         latest_checkpoint = checkpoints_folder / "checkpoint_last.pt"
         if latest_checkpoint.exists():
-            expert.load_full_state(torch.load(latest_checkpoint))
+            expert.load_state_dict(torch.load(latest_checkpoint))
         else:
             logger.warning(f"Failed to load checkpoint for expert {expert_name}")

+ 106 - 51
hivemind/moe/server/connection_handler.py

@@ -1,82 +1,137 @@
+import asyncio
 import multiprocessing as mp
-import os
-from typing import Dict
+from typing import AsyncIterator, Dict, Iterable, List, Optional, Tuple, Union
 
-import grpc
 import torch
 
-from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
-from hivemind.moe.server.expert_backend import ExpertBackend
-from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
-from hivemind.utils import Endpoint, MSGPackSerializer, get_logger, nested_flatten
-from hivemind.utils.asyncio import switch_to_uvloop
-from hivemind.utils.grpc import GRPC_KEEPALIVE_OPTIONS
+from hivemind.compression import deserialize_tensor_stream, deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.dht import DHT
+from hivemind.moe.server.module_backend import ModuleBackend
+from hivemind.moe.server.task_pool import TaskPool
+from hivemind.p2p import P2PContext, ServicerBase
+from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE, P2P
+from hivemind.proto import runtime_pb2
+from hivemind.utils import MPFuture, MSGPackSerializer, as_aiter, get_logger, nested_flatten
+from hivemind.utils.asyncio import amap_in_executor, switch_to_uvloop
+from hivemind.utils.streaming import split_for_streaming
+from hivemind.utils.tensor_descr import BatchTensorDescriptor
 
 logger = get_logger(__name__)
 
 
-class ConnectionHandler(mp.context.ForkProcess):
+class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
     """
     A process that accepts incoming requests to experts and submits them into the corresponding TaskPool.
 
-    :note: ConnectionHandler is designed so as to allow using multiple handler processes for the same port.
-    :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
-    :param experts: a dict [UID -> ExpertBackend] with all active experts
+    :note: ConnectionHandler is designed so as to allow using multiple handler processes for the same port
+    :param dht: a running hivemind.dht.DHT, used to let other peers connect to this one
+    :param module_backends: a dict [UID -> ModuleBackend] with all active experts
     """
 
-    def __init__(self, listen_on: Endpoint, experts: Dict[str, ExpertBackend]):
+    def __init__(self, dht: DHT, module_backends: Dict[str, ModuleBackend]):
         super().__init__()
-        self.listen_on, self.experts = listen_on, experts
-        self.ready = mp.Event()
+        self.dht, self.module_backends = dht, module_backends
+        self._p2p: Optional[P2P] = None
+
+        self.ready = MPFuture()
 
     def run(self):
         torch.set_num_threads(1)
         loop = switch_to_uvloop()
 
         async def _run():
-            grpc.aio.init_grpc_aio()
-            logger.debug(f"Starting, pid {os.getpid()}")
-            server = grpc.aio.server(
-                options=GRPC_KEEPALIVE_OPTIONS
-                + (
-                    ("grpc.so_reuseport", 1),
-                    ("grpc.max_send_message_length", -1),
-                    ("grpc.max_receive_message_length", -1),
-                )
-            )
-            runtime_grpc.add_ConnectionHandlerServicer_to_server(self, server)
-
-            found_port = server.add_insecure_port(self.listen_on)
-            assert found_port != 0, f"Failed to listen to {self.listen_on}"
-
-            await server.start()
-            self.ready.set()
-            await server.wait_for_termination()
-            logger.debug(f"ConnectionHandler terminated: (pid={os.getpid()})")
+            try:
+                self._p2p = await self.dht.replicate_p2p()
+                await self.add_p2p_handlers(self._p2p, balanced=True)
+
+                # wait forever
+                await asyncio.Future()
+
+            except Exception as e:
+                self.ready.set_exception(e)
+                return
+
+        self.ready.set_result(None)
 
         try:
             loop.run_until_complete(_run())
         except KeyboardInterrupt:
             logger.debug("Caught KeyboardInterrupt, shutting down")
 
-    async def info(self, request: runtime_pb2.ExpertUID, context: grpc.ServicerContext):
-        return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(self.experts[request.uid].get_info()))
+    async def rpc_info(self, request: runtime_pb2.ExpertUID, context: P2PContext) -> runtime_pb2.ExpertInfo:
+        module_info = self.module_backends[request.uid].get_info()
+        return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(module_info))
+
+    async def _gather_inputs(
+        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
+    ) -> Tuple[str, List[torch.Tensor]]:
+        expert_uid = None
+
+        def _unpack(req: runtime_pb2.ExpertRequest) -> Iterable[runtime_pb2.Tensor]:
+            nonlocal expert_uid
+
+            if expert_uid is None:
+                expert_uid = req.uid
+            elif expert_uid != req.uid:
+                raise ValueError("Expert uids differ in one request")
+
+            return req.tensors
+
+        tensors_stream = amap_in_executor(_unpack, requests)
+        inputs = await deserialize_tensor_stream(tensors_stream)
+        return expert_uid, inputs
+
+    async def _process_inputs(
+        self,
+        inputs: List[torch.Tensor],
+        pool: TaskPool,
+        schema: Union[BatchTensorDescriptor, Tuple[BatchTensorDescriptor, ...]],
+    ) -> List[runtime_pb2.Tensor]:
+        return [
+            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+            for result, proto in zip(await pool.submit_task(*inputs), nested_flatten(schema))
+        ]
 
-    async def forward(self, request: runtime_pb2.ExpertRequest, context: grpc.ServicerContext):
+    async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
         inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-        future = self.experts[request.uid].forward_pool.submit_task(*inputs)
-        serialized_response = [
-            serialize_torch_tensor(tensor, proto.compression, allow_inplace=True)
-            for tensor, proto in zip(await future, nested_flatten(self.experts[request.uid].outputs_schema))
+        expert = self.module_backends[request.uid]
+        return runtime_pb2.ExpertResponse(
+            tensors=await self._process_inputs(inputs, expert.forward_pool, expert.outputs_schema)
+        )
+
+    async def rpc_forward_stream(
+        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
+    ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
+        uid, inputs = await self._gather_inputs(requests, context)
+        expert = self.module_backends[uid]
+        output_split = [
+            part
+            for tensor in await self._process_inputs(inputs, expert.forward_pool, expert.outputs_schema)
+            for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
         ]
 
-        return runtime_pb2.ExpertResponse(tensors=serialized_response)
-
-    async def backward(self, request: runtime_pb2.ExpertRequest, context: grpc.ServicerContext):
-        inputs_and_grad_outputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-        future = self.experts[request.uid].backward_pool.submit_task(*inputs_and_grad_outputs)
-        serialized_response = [
-            serialize_torch_tensor(tensor, proto.compression, allow_inplace=True)
-            for tensor, proto in zip(await future, nested_flatten(self.experts[request.uid].grad_inputs_schema))
+        async for part in as_aiter(*output_split):
+            yield runtime_pb2.ExpertResponse(tensors=[part])
+
+    async def rpc_backward(
+        self, request: runtime_pb2.ExpertRequest, context: P2PContext
+    ) -> runtime_pb2.ExpertResponse:
+        inputs_and_grads = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+        expert = self.module_backends[request.uid]
+        return runtime_pb2.ExpertResponse(
+            tensors=await self._process_inputs(inputs_and_grads, expert.backward_pool, expert.grad_inputs_schema)
+        )
+
+    async def rpc_backward_stream(
+        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
+    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
+        uid, inputs_and_grads = await self._gather_inputs(requests, context)
+        expert = self.module_backends[uid]
+        output_split = [
+            part
+            for tensor in await self._process_inputs(inputs_and_grads, expert.backward_pool, expert.grad_inputs_schema)
+            for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
         ]
-        return runtime_pb2.ExpertResponse(tensors=serialized_response)
+
+        async for part in as_aiter(*output_split):
+            yield runtime_pb2.ExpertResponse(tensors=[part])

+ 35 - 26
hivemind/moe/server/dht_handler.py

@@ -1,70 +1,77 @@
 import threading
 from functools import partial
-from typing import Dict, List, Optional, Sequence, Tuple
+from typing import Dict, List, Optional, Sequence, Tuple, Union
 
-from hivemind.dht import DHT, DHTExpiration, DHTNode, DHTValue
-from hivemind.moe.client.expert import RemoteExpert
-from hivemind.moe.server.expert_uid import (
+from hivemind.dht import DHT, DHTNode, DHTValue
+from hivemind.moe.client.expert import RemoteExpert, create_remote_experts
+from hivemind.moe.expert_uid import (
     FLAT_EXPERT,
     UID_DELIMITER,
     UID_PATTERN,
     Coordinate,
+    ExpertInfo,
     ExpertPrefix,
     ExpertUID,
     is_valid_uid,
     split_uid,
 )
-from hivemind.utils import Endpoint, get_dht_time, get_port
+from hivemind.p2p import PeerID
+from hivemind.utils import MAX_DHT_TIME_DISCREPANCY_SECONDS, DHTExpiration, MPFuture, get_dht_time
 
 
 class DHTHandlerThread(threading.Thread):
-    def __init__(self, experts, dht: DHT, endpoint: Endpoint, update_period: int = 5, **kwargs):
+    def __init__(
+        self, module_backends, dht: DHT, update_period: float = 30, expiration: Optional[int] = None, **kwargs
+    ):
         super().__init__(**kwargs)
-        assert get_port(endpoint) is not None
-        self.endpoint = endpoint
-        self.experts = experts
+        if expiration is None:
+            expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
+        self.module_backends = module_backends
         self.dht = dht
         self.update_period = update_period
+        self.expiration = expiration
         self.stop = threading.Event()
 
     def run(self) -> None:
-        declare_experts(self.dht, self.experts.keys(), self.endpoint)
+        declare_experts(self.dht, self.module_backends.keys(), expiration_time=get_dht_time() + self.expiration)
         while not self.stop.wait(self.update_period):
-            declare_experts(self.dht, self.experts.keys(), self.endpoint)
+            declare_experts(self.dht, self.module_backends.keys(), expiration_time=get_dht_time() + self.expiration)
 
 
 def declare_experts(
-    dht: DHT, uids: Sequence[ExpertUID], endpoint: Endpoint, expiration: DHTExpiration = 300, wait: bool = True
-) -> Dict[ExpertUID, bool]:
+    dht: DHT, uids: Sequence[ExpertUID], expiration_time: DHTExpiration, wait: bool = True
+) -> Union[Dict[ExpertUID, bool], MPFuture[Dict[ExpertUID, bool]]]:
     """
     Make experts visible to all DHT peers; update timestamps if declared previously.
 
     :param uids: a list of expert ids to update
-    :param endpoint: endpoint that serves these experts, usually your server endpoint (e.g. "201.111.222.333:1337")
     :param wait: if True, awaits for declaration to finish, otherwise runs in background
-    :param expiration: experts will be visible for this many seconds
+    :param expiration_time: experts will be visible for this many seconds
     :returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
     """
     assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
+    if not isinstance(uids, list):
+        uids = list(uids)
     for uid in uids:
         assert is_valid_uid(uid), f"{uid} is not a valid expert uid. All uids must follow {UID_PATTERN.pattern}"
     return dht.run_coroutine(
-        partial(_declare_experts, uids=list(uids), endpoint=endpoint, expiration=expiration), return_future=not wait
+        partial(_declare_experts, uids=uids, expiration_time=expiration_time), return_future=not wait
     )
 
 
 async def _declare_experts(
-    dht: DHT, node: DHTNode, uids: List[ExpertUID], endpoint: Endpoint, expiration: DHTExpiration
+    dht: DHT, node: DHTNode, uids: List[ExpertUID], expiration_time: DHTExpiration
 ) -> Dict[ExpertUID, bool]:
     num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
-    expiration_time = get_dht_time() + expiration
     data_to_store: Dict[Tuple[ExpertPrefix, Optional[Coordinate]], DHTValue] = {}
+    peer_id_base58 = dht.peer_id.to_base58()
+
     for uid in uids:
-        data_to_store[uid, None] = endpoint
+        data_to_store[uid, None] = peer_id_base58
         prefix = uid if uid.count(UID_DELIMITER) > 1 else f"{uid}{UID_DELIMITER}{FLAT_EXPERT}"
         for i in range(prefix.count(UID_DELIMITER) - 1):
             prefix, last_coord = split_uid(prefix)
-            data_to_store[prefix, last_coord] = [uid, endpoint]
+            data_to_store[prefix, last_coord] = (uid, peer_id_base58)
 
     keys, maybe_subkeys, values = zip(*((key, subkey, value) for (key, subkey), value in data_to_store.items()))
     store_ok = await node.store_many(keys, values, expiration_time, subkeys=maybe_subkeys, num_workers=num_workers)
@@ -73,7 +80,7 @@ async def _declare_experts(
 
 def get_experts(
     dht: DHT, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration] = None, return_future: bool = False
-) -> List[Optional[RemoteExpert]]:
+) -> Union[List[Optional[RemoteExpert]], MPFuture[List[Optional[RemoteExpert]]]]:
     """
     :param uids: find experts with these ids from across the DHT
     :param expiration_time: if specified, return experts that expire no sooner than this (based on get_dht_time)
@@ -81,19 +88,21 @@ def get_experts(
     :returns: a list of [RemoteExpert if found else None]
     """
     assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
-    return dht.run_coroutine(partial(_get_experts, uids=list(uids), expiration_time=expiration_time), return_future)
+    result = dht.run_coroutine(partial(_get_experts, uids=list(uids), expiration_time=expiration_time), return_future)
+    return create_remote_experts(result, dht, return_future)
 
 
 async def _get_experts(
     dht: DHT, node: DHTNode, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration]
-) -> List[Optional[RemoteExpert]]:
+) -> List[Optional[ExpertInfo]]:
     if expiration_time is None:
         expiration_time = get_dht_time()
     num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
     found: Dict[ExpertUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
 
-    experts: List[Optional[RemoteExpert]] = [None] * len(uids)
+    experts: List[Optional[ExpertInfo]] = [None] * len(uids)
     for i, uid in enumerate(uids):
-        if found[uid] is not None and isinstance(found[uid].value, Endpoint):
-            experts[i] = RemoteExpert(uid, found[uid].value)
+        server_peer_id = found[uid]
+        if server_peer_id is not None and isinstance(server_peer_id.value, str):
+            experts[i] = ExpertInfo(uid, PeerID.from_base58(server_peer_id.value))
     return experts

+ 1 - 1
hivemind/moe/server/layers/dropout.py

@@ -19,7 +19,7 @@ class DeterministicDropoutFunction(torch.autograd.Function):
 class DeterministicDropout(nn.Module):
     """
     Custom dropout layer which accepts dropout mask as an input (drop_prob is only used for scaling input activations).
-    Can be used with RemoteExpert/ExpertBackend to ensure that dropout mask is the same at forward and backward steps
+    Can be used with RemoteExpert/ModuleBackend to ensure that dropout mask is the same at forward and backward steps
     """
 
     def __init__(self, drop_prob):

+ 58 - 0
hivemind/moe/server/layers/optim.py

@@ -0,0 +1,58 @@
+import torch
+
+
+class OptimizerWrapper(torch.optim.Optimizer):
+    """A wrapper for pytorch.optim.Optimizer that forwards all methods to the wrapped optimizer"""
+
+    def __init__(self, optim: torch.optim.Optimizer):
+        super().__init__(optim.param_groups, optim.defaults)
+        self.optim = optim
+
+    @property
+    def defaults(self):
+        return self.optim.defaults
+
+    @property
+    def state(self):
+        return self.optim.state
+
+    def __getstate__(self):
+        return self.optim.__getstate__()
+
+    def __setstate__(self, state):
+        self.optim.__setstate__(state)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({repr(self.optim)})"
+
+    def state_dict(self):
+        return self.optim.state_dict()
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        return self.optim.load_state_dict(state_dict)
+
+    def step(self, *args, **kwargs):
+        return self.optim.step(*args, **kwargs)
+
+    def zero_grad(self, *args, **kwargs):
+        return self.optim.zero_grad(*args, **kwargs)
+
+    @property
+    def param_groups(self):
+        return self.optim.param_groups
+
+    def add_param_group(self, param_group: dict) -> None:
+        return self.optim.add_param_group(param_group)
+
+
+class ClippingWrapper(OptimizerWrapper):
+    """A wrapper of torch.Optimizer that clips gradients by global norm before each step"""
+
+    def __init__(self, optim: torch.optim.Optimizer, clip_grad_norm: float):
+        super().__init__(optim)
+        self.clip_grad_norm = clip_grad_norm
+
+    def step(self, *args, **kwargs):
+        parameters = tuple(param for group in self.param_groups for param in group["params"])
+        torch.nn.utils.clip_grad_norm_(parameters, self.clip_grad_norm)
+        return super().step(*args, **kwargs)

+ 46 - 84
hivemind/moe/server/expert_backend.py → hivemind/moe/server/module_backend.py

@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, Sequence, Tuple, Union
+from typing import Any, Dict, Optional, Sequence, Tuple, Union
 
 import torch
 from torch import nn
@@ -8,19 +8,20 @@ from hivemind.utils.logging import get_logger
 from hivemind.utils.nested import nested_compare, nested_flatten, nested_map, nested_pack
 from hivemind.utils.tensor_descr import DUMMY_BATCH_SIZE, BatchTensorDescriptor
 
+LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
 logger = get_logger(__name__)
 
 
-class ExpertBackend:
+class ModuleBackend:
     """
-    ExpertBackend is a wrapper around torch module that allows it to run tasks asynchronously with Runtime
-    By default, ExpertBackend handles three types of requests:
+    ModuleBackend is a wrapper around torch module that allows it to run tasks asynchronously with Runtime
+    By default, ModuleBackend handles three types of requests:
 
      - forward - receive inputs and compute outputs. Concurrent requests will be batched for better GPU utilization.
      - backward - receive gradients w.r.t. outputs, compute gradients w.r.t. inputs and **update expert**. Also batched.
      - get_info - return expert metadata. Not batched.
 
-    :param expert: nn.Module to be wrapped into a backend. Arbitrary pytorch module with a few limitations:
+    :param module: nn.Module to be wrapped into a backend. Arbitrary pytorch module with a few limitations:
 
      - Experts must always receive the same set of args and kwargs and produce output tensors of same type
      - All args, kwargs and outputs must be **tensors** where 0-th dimension represents to batch size
@@ -34,49 +35,37 @@ class ExpertBackend:
     :param args_schema: description of positional arguments to expert.forward, list of BatchTensorProto
     :param kwargs_schema: description of keyword arguments to expert.forward, dict of BatchTensorProto
     :param outputs_schema: description of outputs from expert.forward, nested structure of BatchTensorProto
-    :param num_warmup_steps: the number of warmup steps for LR schedule
-    :param num_total_steps: the total number of steps for LR schedule
-    :param clip_grad_norm: maximum gradient norm used for clipping
     :param kwargs: extra parameters to be forwarded into TaskPool.__init__
     """
 
     def __init__(
         self,
         name: str,
-        expert: nn.Module,
-        optimizer: torch.optim.Optimizer,
+        module: nn.Module,
         *,
-        scheduler: Callable = None,
+        optimizer: Optional[torch.optim.Optimizer] = None,
+        scheduler: Optional[LRSchedulerBase] = None,
         args_schema: Tuple[BatchTensorDescriptor, ...] = None,
         kwargs_schema: Dict[str, BatchTensorDescriptor] = None,
         outputs_schema: Union[BatchTensorDescriptor, Tuple[BatchTensorDescriptor, ...]] = None,
-        num_warmup_steps: int = None,
-        num_total_steps: int = None,
-        clip_grad_norm: float = None,
         **kwargs,
     ):
         super().__init__()
-        self.expert, self.optimizer, self.name = expert, optimizer, name
-
-        if scheduler is None:
-            self.scheduler = None
-        else:
-            assert optimizer is not None and num_warmup_steps is not None and num_total_steps is not None
-            self.scheduler = scheduler(self.optimizer, num_warmup_steps, num_total_steps)
-        self.clip_grad_norm = clip_grad_norm
+        self.name, self.module, self.optimizer, self.scheduler = name, module, optimizer, scheduler
 
         self.args_schema = args_schema = tuple(args_schema or ())
         self.kwargs_schema = kwargs_schema = dict(kwargs_schema or {})
         assert args_schema or kwargs_schema, (
-            "expert must receive at least one positional or keyword input."
+            f"Module must take at least one positional or keyword input."
             " Did you forget to provide args_schema/kwargs_schema?"
         )
+        assert optimizer is not None or scheduler is None, "scheduler should only be used if optimizer is not None"
 
         if outputs_schema is None:
             # run expert once to get outputs schema
             dummy_args = tuple(sample.make_zeros(DUMMY_BATCH_SIZE) for sample in args_schema)
             dummy_kwargs = {key: sample.make_zeros(DUMMY_BATCH_SIZE) for key, sample in kwargs_schema.items()}
-            dummy_outputs = self.expert(*dummy_args, **dummy_kwargs)
+            dummy_outputs = self.module(*dummy_args, **dummy_kwargs)
             outputs_schema = nested_map(BatchTensorDescriptor.from_tensor, dummy_outputs)
 
         self.forward_schema = (self.args_schema, self.kwargs_schema)  # inputs for forward
@@ -87,22 +76,17 @@ class ExpertBackend:
         self.forward_pool = TaskPool(self.forward, name=f"{self.name}_forward", **kwargs)
         self.backward_pool = TaskPool(self.backward, name=f"{self.name}_backward", **kwargs)
 
-        self.update_count = 0
-        self.examples_processed = 0
-
     def forward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         """
         Apply forward pass to an aggregated batch of requests. Used by Runtime, do not call this manually;
-        To submit a request for asynchronous processing, please use ``ExpertBackend.forward_pool.submit_task``.
+        To submit a request for asynchronous processing, please use ``ModuleBackend.forward_pool.submit_task``.
+
+        .. warning: if the underlying module performs non-gradient updates (e.g. batchnorm), it will be updated twice:
+           once during forward pass, and again during backward. This behavior is similar to gradient checkpointing.
 
         Subclassing:
            This method receives a sequence of torch tensors following ``nested_flatten(self.forward_schema)``;
-
            It should return gradients w.r.t. inputs that follow ``nested_flatten(self.outputs_schema)``;
-
-           .. todo we handle layer states (e.g. batchnorm stats) incorrectly, updating them twice.
-           .. For now, either register all buffers as outputs or avoid stateful experts
-
         """
         args, kwargs = nested_pack(inputs, structure=self.forward_schema)
 
@@ -110,7 +94,7 @@ class ExpertBackend:
             raise RuntimeError("Batch should contain more than 0 samples")
 
         with torch.no_grad():
-            outputs = self.expert(*args, **kwargs)
+            outputs = self.module(*args, **kwargs)
 
         # Note: TaskPool requires function to accept and return a flat tuple of values, we pack/unpack it on client side
         return tuple(nested_flatten(outputs))
@@ -118,7 +102,7 @@ class ExpertBackend:
     def backward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         """
         Apply backward pass to an aggregated batch of requests. Used by Runtime, do not call this manually
-        To submit a request for asynchronous processing, please use ``ExpertBackend.backward_pool.submit_task``.
+        To submit a request for asynchronous processing, please use ``ModuleBackend.backward_pool.submit_task``.
 
         Subclassing:
            This method receives a sequence of torch tensors following ``nested_flatten(self.backward_schema)``;
@@ -128,9 +112,7 @@ class ExpertBackend:
            Runtime doesn't guarantee that backward will be performed in the same order and for the same data
            as forward, so we recommend stateless backward pass that re-runs expert forward pass inside backward.
 
-           .. todo correct state handling (see forward)
-
-           Please make sure to call ``ExpertBackend.apply_gradients`` here, otherwise the expert will not train
+           Please make sure to call ``ModuleBackend.on_backward`` after each call to backward
         """
         (args, kwargs), grad_outputs = nested_pack(inputs, structure=self.backward_schema)
 
@@ -148,7 +130,7 @@ class ExpertBackend:
 
             batch_size = args[0].size(0)
 
-            outputs = self.expert(*args, **kwargs)
+            outputs = self.module(*args, **kwargs)
             assert nested_compare(outputs, grad_outputs), "outputs and grad_outputs must have the same structure"
 
             outputs_flat = tuple(nested_flatten(outputs))
@@ -163,65 +145,45 @@ class ExpertBackend:
             torch.autograd.backward(
                 outputs_flat, grad_tensors=grad_outputs_flat, create_graph=False, retain_graph=False
             )
-            self.apply_gradients(batch_size)
+            self.on_backward(batch_size)
 
         return tuple(
             x.grad if isinstance(x.grad, torch.Tensor) else torch.zeros_like(x) for x in nested_flatten((args, kwargs))
         )
 
-    def apply_gradients(self, batch_size) -> None:
+    def on_backward(self, batch_size: int) -> None:
         """
-        Train the expert for one step. This method is called by ``ExpertBackend.backward`` after computing gradients.
+        Train the expert for one step. This method is called by ``ModuleBackend.backward`` after computing gradients.
         """
-        if self.clip_grad_norm is not None:
-            torch.nn.utils.clip_grad_norm_(self.expert.parameters(), self.clip_grad_norm)
-
-        self.optimizer.step()
-        self.optimizer.zero_grad()
+        if self.optimizer is not None:
+            self.optimizer.step()
+            self.optimizer.zero_grad()
 
         if self.scheduler is not None:
             self.scheduler.step()
 
-        self.update_count += 1
-        self.examples_processed += batch_size
-
-    def get_stats(self) -> Dict:
-        """
-        Return current expert training statistics (number of updates, number of processed examples after
-        last optimizer step)
-        """
-        return {"updates": self.update_count, "examples_processed": self.examples_processed}
-
-    def get_full_state(self) -> Dict:
-        """
-        Return the current state of the expert (including batch processing statistics)
-        """
-        full_state = {
-            "stats": self.get_stats(),
-            "model": self.expert.state_dict(),
-            "optimizer": self.optimizer.state_dict(),
-            "scheduler": {} if self.scheduler is None else self.scheduler.state_dict(),
-        }
+    def state_dict(self) -> Dict:
+        """Return the current state of the module, optimizer, and scheduler"""
+        full_state = dict(module=self.module.state_dict())
+        if self.optimizer is not None:
+            full_state["optimizer"] = self.optimizer.state_dict()
+        if self.scheduler is not None:
+            full_state["scheduler"] = self.scheduler.state_dict()
         return full_state
 
-    def load_full_state(self, state_dict: Dict):
-        if "stats" in state_dict:
-            self.update_count = state_dict["stats"]["updates"]
-            self.examples_processed = state_dict["stats"]["examples_processed"]
-        else:
-            logger.warning(f"Batch processing stats missing for expert {self.name}")
-
-        self.expert.load_state_dict(state_dict["model"])
+    def load_state_dict(self, state_dict: Dict):
+        self.module.load_state_dict(state_dict["module"])
+        if self.optimizer is not None:
+            if "optimizer" in state_dict:
+                self.optimizer.load_state_dict(state_dict["optimizer"])
+            else:
+                logger.warning(f"Optimizer state missing for {self.name}")
 
-        if "optimizer" in state_dict:
-            self.optimizer.load_state_dict(state_dict["optimizer"])
-        else:
-            logger.warning(f"Optimizer state missing for expert {self.name}")
-
-        if self.scheduler is not None and "scheduler" in state_dict:
-            self.scheduler.load_state_dict(state_dict["scheduler"])
-        else:
-            logger.warning(f"Learning rate scheduler state missing for expert {self.name}")
+        if self.scheduler is not None:
+            if "scheduler" in state_dict:
+                self.scheduler.load_state_dict(state_dict["scheduler"])
+            else:
+                logger.warning(f"Learning rate scheduler state missing for {self.name}")
 
     def get_info(self) -> Dict[str, Any]:
         """Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""

+ 11 - 11
hivemind/moe/server/runtime.py

@@ -12,7 +12,7 @@ from typing import Dict, NamedTuple, Optional
 import torch
 from prefetch_generator import BackgroundGenerator
 
-from hivemind.moe.server.expert_backend import ExpertBackend
+from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.utils import get_logger
 
 logger = get_logger(__name__)
@@ -20,20 +20,20 @@ logger = get_logger(__name__)
 
 class Runtime(threading.Thread):
     """
-    A group of processes that processes incoming requests for multiple experts on a shared device.
+    A group of processes that processes incoming requests for multiple module backends on a shared device.
     Runtime is usually created and managed by Server, humans need not apply.
 
     For debugging, you can start runtime manually with .start() or .run()
 
-    >>> expert_backends = {'expert_name': ExpertBackend(**kwargs)}
-    >>> runtime = Runtime(expert_backends)
+    >>> module_backends = {'expert_name': ModuleBackend(**kwargs)}
+    >>> runtime = Runtime(module_backends)
     >>> runtime.start()  # start runtime in background thread. To start in current thread, use runtime.run()
     >>> runtime.ready.wait()  # await for runtime to load all experts on device and create request pools
-    >>> future = runtime.expert_backends['expert_name'].forward_pool.submit_task(*expert_inputs)
+    >>> future = runtime.module_backends['expert_name'].forward_pool.submit_task(*module_inputs)
     >>> print("Returned:", future.result())
     >>> runtime.shutdown()
 
-    :param expert_backends: a dict [expert uid -> ExpertBackend]
+    :param module_backends: a dict [expert uid -> ModuleBackend]
     :param prefetch_batches: form up to this many batches in advance
     :param sender_threads: dispatches outputs from finished batches using this many asynchronous threads
     :param device: if specified, moves all experts and data to this device via .to(device=device).
@@ -46,15 +46,15 @@ class Runtime(threading.Thread):
 
     def __init__(
         self,
-        expert_backends: Dict[str, ExpertBackend],
+        module_backends: Dict[str, ModuleBackend],
         prefetch_batches=64,
         sender_threads: int = 1,
         device: torch.device = None,
         stats_report_interval: Optional[int] = None,
     ):
         super().__init__()
-        self.expert_backends = expert_backends
-        self.pools = tuple(chain(*(expert.get_pools() for expert in expert_backends.values())))
+        self.module_backends = module_backends
+        self.pools = tuple(chain(*(backend.get_pools() for backend in module_backends.values())))
         self.device, self.prefetch_batches, self.sender_threads = device, prefetch_batches, sender_threads
         self.shutdown_recv, self.shutdown_send = mp.Pipe(duplex=False)
         self.shutdown_trigger = mp.Event()
@@ -69,8 +69,8 @@ class Runtime(threading.Thread):
             if not pool.is_alive():
                 pool.start()
         if self.device is not None:
-            for expert_backend in self.expert_backends.values():
-                expert_backend.expert.to(self.device)
+            for backend in self.module_backends.values():
+                backend.module.to(self.device)
 
         with mp.pool.ThreadPool(self.sender_threads) as output_sender_pool:
             try:

+ 66 - 69
hivemind/moe/server/server.py

@@ -6,27 +6,27 @@ import threading
 from contextlib import contextmanager
 from functools import partial
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional
 
 import torch
-from multiaddr import Multiaddr
 
 from hivemind.dht import DHT
+from hivemind.moe.expert_uid import UID_DELIMITER
 from hivemind.moe.server.checkpoints import CheckpointSaver, is_directory, load_experts
 from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.moe.server.dht_handler import DHTHandlerThread, get_experts
-from hivemind.moe.server.expert_backend import ExpertBackend
-from hivemind.moe.server.expert_uid import UID_DELIMITER
 from hivemind.moe.server.layers import (
     add_custom_models_from_file,
     name_to_block,
     name_to_input,
     schedule_name_to_scheduler,
 )
+from hivemind.moe.server.layers.optim import ClippingWrapper
+from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.moe.server.runtime import Runtime
+from hivemind.p2p import PeerInfo
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.logging import get_logger
-from hivemind.utils.networking import Endpoint, get_free_port, get_port, replace_port
 from hivemind.utils.tensor_descr import DUMMY_BATCH_SIZE, BatchTensorDescriptor
 
 logger = get_logger(__name__)
@@ -34,55 +34,51 @@ logger = get_logger(__name__)
 
 class Server(threading.Thread):
     """
-    Server allows you to host "experts" - pytorch subnetworks used by Decentralized Mixture of Experts.
+    Server allows you to host "experts" - pytorch subnetworks that can be accessed remotely by peers.
     After creation, a server should be started: see Server.run or Server.run_in_background.
 
     A working server does two things:
      - processes incoming forward/backward requests via Runtime (created by the server)
      - publishes updates to expert status every :update_period: seconds
 
-    :type dht: DHT or None. Server with dht=None will NOT be visible from DHT,
-     but it will still support accessing experts directly with RemoteExpert(uid=UID, endpoint="IPADDR:PORT").
-    :param expert_backends: dict{expert uid (str) : ExpertBackend} for all expert hosted by this server.
-    :param listen_on: server's dht address that determines how it can be accessed. Address and (optional) port
+    :type dht: an instance of hivemind.DHT. Server will use DHT for all network interactions.
+    :param module_backends: dict{expert uid (str) : ModuleBackend} for all expert hosted by this server.
     :param num_connection_handlers: maximum number of simultaneous requests. Please note that the default value of 1
         if too small for normal functioning, we recommend 4 handlers per expert backend.
     :param update_period: how often will server attempt to publish its state (i.e. experts) to the DHT;
         if dht is None, this parameter is ignored.
+    :param expiration: when server declares its experts to the DHT, these entries will expire after this many seconds
     :param start: if True, the server will immediately start as a background thread and returns control after server
         is ready (see .ready below)
     """
 
     def __init__(
         self,
-        dht: Optional[DHT],
-        expert_backends: Dict[str, ExpertBackend],
-        listen_on: Endpoint = "0.0.0.0:*",
+        dht: DHT,
+        module_backends: Dict[str, ModuleBackend],
         num_connection_handlers: int = 1,
-        update_period: int = 30,
+        update_period: float = 30,
+        expiration: Optional[float] = None,
         start=False,
         checkpoint_dir=None,
         **kwargs,
     ):
         super().__init__()
-        self.dht, self.experts, self.update_period = dht, expert_backends, update_period
-        if get_port(listen_on) is None:
-            listen_on = replace_port(listen_on, new_port=get_free_port())
-        self.listen_on, self.port = listen_on, get_port(listen_on)
+        self.dht, self.module_backends, self.update_period = dht, module_backends, update_period
 
-        self.conn_handlers = [ConnectionHandler(listen_on, self.experts) for _ in range(num_connection_handlers)]
+        self.conn_handlers = [ConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)]
         if checkpoint_dir is not None:
-            self.checkpoint_saver = CheckpointSaver(expert_backends, checkpoint_dir, update_period)
+            self.checkpoint_saver = CheckpointSaver(module_backends, checkpoint_dir, update_period)
         else:
             self.checkpoint_saver = None
-        self.runtime = Runtime(self.experts, **kwargs)
+        self.runtime = Runtime(self.module_backends, **kwargs)
 
-        if self.dht and self.experts:
+        if self.module_backends:
             self.dht_handler_thread = DHTHandlerThread(
-                experts=self.experts,
+                module_backends=self.module_backends,
                 dht=self.dht,
-                endpoint=self.listen_on,
                 update_period=self.update_period,
+                expiration=expiration,
                 daemon=True,
             )
 
@@ -92,7 +88,6 @@ class Server(threading.Thread):
     @classmethod
     def create(
         cls,
-        listen_on="0.0.0.0:*",
         num_experts: int = None,
         expert_uids: str = None,
         expert_pattern: str = None,
@@ -101,24 +96,26 @@ class Server(threading.Thread):
         optim_cls=torch.optim.Adam,
         scheduler: str = "none",
         num_warmup_steps=None,
-        num_total_steps=None,
+        num_training_steps=None,
         clip_grad_norm=None,
         num_handlers=None,
         min_batch_size=1,
         max_batch_size=4096,
         device=None,
-        no_dht=False,
         initial_peers=(),
         checkpoint_dir: Optional[Path] = None,
         compression=CompressionType.NONE,
         stats_report_interval: Optional[int] = None,
         custom_module_path=None,
+        update_period: float = 30,
+        expiration: Optional[float] = None,
         *,
         start: bool,
+        **kwargs,
     ) -> Server:
         """
-        Instantiate a server with several identical experts. See argparse comments below for details
-        :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
+        Instantiate a server with several identical modules. See argparse comments below for details
+
         :param num_experts: run this many identical experts
         :param expert_pattern: a string pattern or a list of expert uids,  example: myprefix.[0:32].[0:256]\
            means "sample random experts between myprefix.0.0 and myprefix.255.255;
@@ -133,31 +130,28 @@ class Server(threading.Thread):
         :param optim_cls: uses this optimizer to train all experts
         :param scheduler: if not `none`, the name of the expert LR scheduler
         :param num_warmup_steps: the number of warmup steps for LR schedule
-        :param num_total_steps: the total number of steps for LR schedule
+        :param num_training_steps: the total number of steps for LR schedule
         :param clip_grad_norm: maximum gradient norm used for clipping
 
-        :param no_dht: if specified, the server will not be attached to a dht
         :param initial_peers: multiaddrs of one or more active DHT peers (if you want to join an existing DHT)
 
         :param checkpoint_dir: directory to save and load expert checkpoints
 
         :param compression: if specified, use this compression to pack all inputs, outputs and gradients by all experts
             hosted on this server. For a more fine-grained compression, start server in python and specify compression
-            for each BatchTensorProto in ExpertBackend for the respective experts.
+            for each BatchTensorProto in ModuleBackend for the respective experts.
 
         :param start: if True, starts server right away and returns when server is ready for requests
         :param stats_report_interval: interval between two reports of batch processing performance statistics
+        :param kwargs: any other params will be forwarded to DHT upon creation
         """
         if custom_module_path is not None:
             add_custom_models_from_file(custom_module_path)
         assert expert_cls in name_to_block
 
-        if no_dht:
-            dht = None
-        else:
-            dht = DHT(initial_peers=initial_peers, start=True)
-            visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
-            logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
+        dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
+        visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
+        logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
 
         assert (expert_pattern is None and num_experts is None and expert_uids is not None) or (
             num_experts is not None and expert_uids is None
@@ -187,7 +181,6 @@ class Server(threading.Thread):
 
         num_experts = len(expert_uids)
         num_handlers = num_handlers if num_handlers is not None else num_experts * 8
-        optim_cls = optim_cls if optim_cls is not None else partial(torch.optim.SGD, lr=0.0)
         device = device or ("cuda" if torch.cuda.is_available() else "cpu")
 
         sample_input = name_to_input[expert_cls](DUMMY_BATCH_SIZE, hidden_dim)
@@ -196,21 +189,26 @@ class Server(threading.Thread):
         else:
             args_schema = (BatchTensorDescriptor.from_tensor(sample_input, compression),)
 
-        scheduler = schedule_name_to_scheduler[scheduler]
+        scheduler_cls = schedule_name_to_scheduler[scheduler]
+        if scheduler_cls is not None:
+            scheduler_cls = partial(
+                scheduler_cls, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps
+            )
 
         # initialize experts
         experts = {}
         for expert_uid in expert_uids:
             expert = name_to_block[expert_cls](hidden_dim)
-            experts[expert_uid] = ExpertBackend(
+            optimizer = optim_cls(expert.parameters()) if optim_cls is not None else None
+            scheduler = scheduler_cls(optimizer) if scheduler_cls is not None else None
+            if clip_grad_norm is not None:
+                optimizer = ClippingWrapper(optimizer, clip_grad_norm)
+            experts[expert_uid] = ModuleBackend(
                 name=expert_uid,
-                expert=expert,
+                module=expert,
                 args_schema=args_schema,
-                optimizer=optim_cls(expert.parameters()),
+                optimizer=optimizer,
                 scheduler=scheduler,
-                num_warmup_steps=num_warmup_steps,
-                num_total_steps=num_total_steps,
-                clip_grad_norm=clip_grad_norm,
                 min_batch_size=min_batch_size,
                 max_batch_size=max_batch_size,
             )
@@ -221,11 +219,12 @@ class Server(threading.Thread):
         return cls(
             dht,
             experts,
-            listen_on=listen_on,
             num_connection_handlers=num_handlers,
             device=device,
             checkpoint_dir=checkpoint_dir,
             stats_report_interval=stats_report_interval,
+            update_period=update_period,
+            expiration=expiration,
             start=start,
         )
 
@@ -234,25 +233,24 @@ class Server(threading.Thread):
         Starts Server in the current thread. Initializes dht if necessary, starts connection handlers,
         runs Runtime (self.runtime) to process incoming requests.
         """
-        logger.info(f"Server started at {self.listen_on}")
-        logger.info(f"Got {len(self.experts)} experts:")
-        for expert_name, backend in self.experts.items():
-            num_parameters = sum(p.numel() for p in backend.expert.parameters() if p.requires_grad)
-            logger.info(f"{expert_name}: {backend.expert.__class__.__name__}, {num_parameters} parameters")
-
-        if self.dht:
-            if not self.dht.is_alive():
-                self.dht.run_in_background(await_ready=True)
-
-            if self.experts:
-                self.dht_handler_thread.start()
+        logger.info(f"Server started with {len(self.module_backends)} modules:")
+        for expert_name, backend in self.module_backends.items():
+            num_parameters = sum(p.numel() for p in backend.module.parameters() if p.requires_grad)
+            logger.info(f"{expert_name}: {backend.module.__class__.__name__}, {num_parameters} parameters")
+
+        if not self.dht.is_alive():
+            self.dht.run_in_background(await_ready=True)
+
+        if self.module_backends:
+            self.dht_handler_thread.start()
+
         if self.checkpoint_saver is not None:
             self.checkpoint_saver.start()
 
         for process in self.conn_handlers:
             if not process.is_alive():
                 process.start()
-            process.ready.wait()
+            process.ready.result()
 
         try:
             self.runtime.run()
@@ -294,7 +292,7 @@ class Server(threading.Thread):
             process.join()
         logger.debug("Connection handlers terminated")
 
-        if self.dht and self.experts:
+        if self.module_backends:
             self.dht_handler_thread.stop.set()
             self.dht_handler_thread.join()
 
@@ -302,9 +300,8 @@ class Server(threading.Thread):
             self.checkpoint_saver.stop.set()
             self.checkpoint_saver.join()
 
-        if self.dht is not None:
-            self.dht.shutdown()
-            self.dht.join()
+        self.dht.shutdown()
+        self.dht.join()
 
         logger.debug(f"Shutting down runtime")
 
@@ -313,14 +310,14 @@ class Server(threading.Thread):
 
 
 @contextmanager
-def background_server(*args, shutdown_timeout=5, **kwargs) -> Tuple[Endpoint, List[Multiaddr]]:
-    """A context manager that creates server in a background process, awaits .ready on entry and shuts down on exit"""
+def background_server(*args, shutdown_timeout=5, **kwargs) -> PeerInfo:
+    """A context manager that creates server in a background , awaits .ready on entry and shuts down on exit"""
     pipe, runners_pipe = mp.Pipe(duplex=True)
     runner = mp.Process(target=_server_runner, args=(runners_pipe, *args), kwargs=kwargs)
     try:
         runner.start()
         # once the server is ready, runner will send us
-        # either (False, exception) or (True, (server.listen_on, dht_maddrs))
+        # either (False, exception) or (True, PeerInfo(dht_peer_id, dht_maddrs))
         start_ok, data = pipe.recv()
         if start_ok:
             yield data
@@ -344,8 +341,8 @@ def _server_runner(pipe, *args, **kwargs):
         return
 
     try:
-        dht_maddrs = server.dht.get_visible_maddrs() if server.dht is not None else None
-        pipe.send((True, (server.listen_on, dht_maddrs)))
+        dht_maddrs = server.dht.get_visible_maddrs()
+        pipe.send((True, PeerInfo(server.dht.peer_id, dht_maddrs)))
         pipe.recv()  # wait for shutdown signal
 
     finally:

+ 0 - 4
hivemind/optim/__init__.py

@@ -1,7 +1,3 @@
-from hivemind.optim.adaptive import CollaborativeAdaptiveOptimizer
-from hivemind.optim.base import DecentralizedOptimizerBase
-from hivemind.optim.collaborative import CollaborativeOptimizer
 from hivemind.optim.grad_scaler import GradScaler, HivemindGradScaler
 from hivemind.optim.optimizer import Optimizer
-from hivemind.optim.simple import DecentralizedAdam, DecentralizedOptimizer, DecentralizedSGD
 from hivemind.optim.training_averager import TrainingAverager

+ 0 - 34
hivemind/optim/adaptive.py

@@ -1,34 +0,0 @@
-from typing import Sequence
-
-import torch.optim
-
-from hivemind.optim.collaborative import CollaborativeOptimizer
-from hivemind.optim.training_averager import TrainingAverager
-
-
-class CollaborativeAdaptiveOptimizer(CollaborativeOptimizer):
-    """
-    Behaves exactly as CollaborativeOptimizer except:
-
-    * averages adaptive learning rates of an optimizer
-    * doesn't average gradients
-
-    :param average_opt_statistics: average optimizer statistics with corresponding names in statedict
-    :param kwargs: options for CollaborativeOptimizer
-    """
-
-    def __init__(self, opt: torch.optim.Optimizer, average_opt_statistics: Sequence[str], **kwargs):
-        super().__init__(opt, average_opt_statistics=average_opt_statistics, **kwargs)
-
-    def _make_averager(self, average_opt_statistics, **kwargs):
-        return TrainingAverager(
-            self.opt,
-            dht=self.dht,
-            average_parameters=True,
-            average_gradients=False,
-            average_opt_statistics=average_opt_statistics,
-            prefix=f"{self.prefix}_averaging",
-            allreduce_timeout=self.averaging_timeout,
-            client_mode=self.client_mode,
-            **kwargs,
-        )

+ 0 - 44
hivemind/optim/base.py

@@ -1,44 +0,0 @@
-from warnings import warn
-
-import torch
-
-from hivemind.dht import DHT
-
-
-class DecentralizedOptimizerBase(torch.optim.Optimizer):
-    """A shared interface for all hivemind optimizers. Cooperates with DHT peers to train a shared model"""
-
-    def __init__(self, opt: torch.optim.Optimizer, dht: DHT):
-        self.opt, self.dht = opt, dht
-        warn(
-            "DecentralizedOptimizerBase and its subclasses have been deprecated and will be removed "
-            "in hivemind 1.1.0. Use hivemind.Optimizer instead",
-            FutureWarning,
-            stacklevel=2,
-        )
-
-    @property
-    def state(self):
-        return self.opt.state
-
-    @property
-    def param_groups(self):
-        return self.opt.param_groups
-
-    def add_param_group(self, param_group: dict) -> None:
-        raise ValueError(
-            f"{self.__class__.__name__} does not support calling add_param_group after creation."
-            f"Please provide all parameter groups at init."
-        )
-
-    def state_dict(self) -> dict:
-        return self.opt.state_dict()
-
-    def load_state_dict(self, state_dict: dict):
-        return self.opt.load_state_dict(state_dict)
-
-    def __repr__(self):
-        return f"{self.__class__.__name__}(opt={repr(self.opt)}, dht={repr(self.dht)})"
-
-    def shutdown(self):
-        raise NotImplementedError()

+ 0 - 558
hivemind/optim/collaborative.py

@@ -1,558 +0,0 @@
-from __future__ import annotations
-
-import logging
-from dataclasses import dataclass
-from threading import Event, Lock, Thread
-from typing import Dict, Iterator, Optional
-
-import numpy as np
-import torch
-from pydantic import BaseModel, StrictBool, StrictFloat, confloat, conint
-
-from hivemind.dht import DHT
-from hivemind.dht.crypto import RSASignatureValidator
-from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
-from hivemind.optim.base import DecentralizedOptimizerBase
-from hivemind.optim.grad_scaler import HivemindGradScaler
-from hivemind.optim.training_averager import TrainingAverager
-from hivemind.utils import get_dht_time, get_logger
-from hivemind.utils.performance_ema import PerformanceEMA
-
-logger = get_logger(__name__)
-LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
-
-
-@dataclass(frozen=False)
-class CollaborationState:
-    optimizer_step: int
-    samples_accumulated: int
-    target_batch_size: int
-    num_peers: int
-    num_clients: int
-    eta_next_step: float
-    next_fetch_time: float
-
-    @property
-    def ready_for_step(self):
-        return self.samples_accumulated >= self.target_batch_size or get_dht_time() >= self.eta_next_step
-
-    def register_step(self, local_step: int):
-        self.optimizer_step = max(local_step, self.optimizer_step)
-        self.samples_accumulated = 0
-        self.eta_next_step = float("inf")
-
-
-class TrainingState(BaseModel):
-    peer_id: bytes
-    step: conint(ge=0, strict=True)
-    samples_accumulated: conint(ge=0, strict=True)
-    samples_per_second: confloat(ge=0.0, strict=True)
-    time: StrictFloat
-    client_mode: StrictBool
-
-
-class TrainingProgressSchema(BaseModel):
-    progress: Dict[BytesWithPublicKey, Optional[TrainingState]]
-
-
-class CollaborativeOptimizer(DecentralizedOptimizerBase):
-    """
-    An optimizer that performs model updates after collaboratively accumulating a target (large) batch size across peers.
-
-    These optimizers use DHT to track how much progress did the collaboration make towards target batch size.
-    Once enough samples were accumulated, optimizers will compute a weighted average of their statistics.
-
-    :note: **For new projects, please use hivemind.Optimizer**. CollaborativeOptimizer is an older version of that.
-      Currently, hivemind.Optimizer supports all the features of CollaborativeOptimizer and many advanced ones.
-      CollaborativeOptimizer will still be supported for a while, but it will be deprecated in v1.1.0.
-
-    :note: This optimizer behaves unlike regular pytorch optimizers in two ways:
-
-      * calling .step will periodically zero-out gradients w.r.t. model parameters after each step
-      * it may take multiple .step calls without updating model parameters, waiting for peers to accumulate enough samples
-
-
-    :param opt: a standard pytorch optimizer, preferably a large-batch one such as LAMB, LARS, etc.
-    :param dht: a running hivemind.DHT daemon connected to other peers
-    :param prefix: a common prefix for all metadata stored by CollaborativeOptimizer in the DHT
-    :param target_batch_size: perform optimizer step after all peers collectively accumulate this many samples
-    :param batch_size_per_step: before each call to .step, user should accumulate gradients over this many samples
-    :param min_refresh_period: wait for at least this many seconds before fetching new collaboration state
-    :param max_refresh_period: wait for at most this many seconds before fetching new collaboration state
-    :param default_refresh_period: if no peers are detected, attempt to fetch collaboration state this often (seconds)
-    :param expected_drift_peers: assume that this many new peers can join between steps
-    :param expected_drift_rate: assumes that this fraction of current collaboration can join/leave between steps
-    :note: The expected collaboration drift parameters are used to adjust the frequency with which this optimizer will
-      refresh the collaboration-wide statistics (to avoid missing the moment when to run the next step)
-    :param bandwidth: peer's network bandwidth for the purpose of load balancing (recommended: internet speed in mbps)
-    :param step_tolerance: a peer can temporarily be delayed by this many steps without being deemed out of sync
-    :param performance_ema_alpha: smoothing value used to estimate this peer's performance (training samples per second)
-    :param averaging_expiration: peer's requests for averaging will be valid for this many seconds
-    :param metadata_expiration: peer's metadata (e.g. samples processed) is stored onto DHT for this many seconds
-    :param averaging_timeout: if an averaging step hangs for this long, it will be cancelled.
-    :param load_state_timeout: wait for at most this many seconds before giving up on load_state_from_peers
-    :param scheduler: if specified, use this scheduler to update optimizer learning rate
-    :param reuse_grad_buffers: if True, use model's .grad buffers for gradient accumulation.
-      This is more memory efficient, but it requires that the user does *NOT* call model/opt zero_grad at all
-    :param accumulate_grads_on: if specified, accumulate gradients on this device. By default, this will use the same
-     device as model parameters. One can specify a different device (e.g. 'cpu' vs 'cuda') to save device memory at
-     the cost of extra time per step. If reuse_gradient_accumulators is True, this parameter has no effect.
-    :param client_mode: if True, runs training without incoming connections, in a firewall-compatible mode
-    :param kwargs: additional parameters forwarded to DecentralizedAverager
-    :note: If you are using CollaborativeOptimizer with lr_scheduler, it is recommended to pass this scheduler
-      explicitly into this class. Otherwise, scheduler may not be synchronized between peers.
-    """
-
-    def __init__(
-        self,
-        opt: torch.optim.Optimizer,
-        *,
-        dht: DHT,
-        prefix: str,
-        target_batch_size: int,
-        batch_size_per_step: Optional[int] = None,
-        scheduler: Optional[LRSchedulerBase] = None,
-        min_refresh_period: float = 0.5,
-        max_refresh_period: float = 30,
-        default_refresh_period: float = 3,
-        expected_drift_peers: float = 3,
-        expected_drift_rate: float = 0.2,
-        performance_ema_alpha: float = 0.1,
-        metadata_expiration: float = 60.0,
-        averaging_timeout: Optional[float] = None,
-        load_state_timeout: float = 600.0,
-        step_tolerance: int = 1,
-        reuse_grad_buffers: bool = False,
-        accumulate_grads_on: Optional[torch.device] = None,
-        client_mode: bool = False,
-        verbose: bool = False,
-        **kwargs,
-    ):
-        super().__init__(opt, dht)
-
-        signature_validator = RSASignatureValidator()
-        self._local_public_key = signature_validator.local_public_key
-        dht.add_validators([SchemaValidator(TrainingProgressSchema, prefix=prefix), signature_validator])
-
-        if reuse_grad_buffers and accumulate_grads_on is not None:
-            logger.warning("Setting 'accumulate_grads_on' has no effect if reuse_grad_buffers=True")
-        self.prefix, self.scheduler = prefix, scheduler
-        self.target_batch_size, self.batch_size_per_step = target_batch_size, batch_size_per_step
-        self.min_refresh_period, self.max_refresh_period, self.default_refresh_period = (
-            min_refresh_period,
-            max_refresh_period,
-            default_refresh_period,
-        )
-        self.expected_drift_peers, self.expected_drift_rate = expected_drift_peers, expected_drift_rate
-        self.averaging_timeout = averaging_timeout
-        self.load_state_timeout = load_state_timeout
-        self.metadata_expiration = metadata_expiration
-        self._grads, self.reuse_grad_buffers, self.accumulate_grads_on = None, reuse_grad_buffers, accumulate_grads_on
-        self.client_mode, self.step_tolerance = client_mode, step_tolerance
-        self.status_loglevel = logging.INFO if verbose else logging.DEBUG
-        self.averager = self._make_averager(**kwargs)
-
-        self._step_supports_amp_scaling = self.reuse_grad_buffers  # enable custom execution with torch GradScaler
-
-        self.training_progress_key = f"{self.prefix}_progress"
-        self.local_samples_accumulated = 0  # a number of local samples accumulated since last optimizer update
-        self.local_updates_accumulated = 0  # a number of calls to step() since last optimizer update
-        self.performance_ema = PerformanceEMA(alpha=performance_ema_alpha)
-        self.last_step_time = None
-
-        self.collaboration_state = self._fetch_state()
-        self.lock_collaboration_state, self.collaboration_state_updated = Lock(), Event()
-        self.lock_local_progress, self.should_report_progress = Lock(), Event()
-        self.progress_reporter = Thread(target=self.report_training_progress, daemon=True, name=f"{self}.reporter")
-        self.progress_reporter.start()
-        self.collaboration_state_updater = Thread(
-            target=self.check_collaboration_state_periodically, daemon=True, name=f"{self}.collaboration_state_updater"
-        )
-        self.collaboration_state_updater.start()
-
-    def _make_averager(self, **kwargs):
-        return TrainingAverager(
-            self.opt,
-            dht=self.dht,
-            average_parameters=True,
-            average_gradients=True,
-            prefix=f"{self.prefix}_averaging",
-            allreduce_timeout=self.averaging_timeout,
-            client_mode=self.client_mode,
-            **kwargs,
-        )
-
-    @property
-    def local_step(self) -> int:
-        return self.averager.local_step
-
-    @property
-    def is_synchronized(self) -> bool:
-        return self.local_step >= self.collaboration_state.optimizer_step
-
-    @property
-    def is_within_tolerance(self) -> bool:
-        return self.local_step >= self.collaboration_state.optimizer_step - self.step_tolerance
-
-    def is_alive(self) -> bool:
-        return self.averager.is_alive()
-
-    def load_state_from_peers(self, **kwargs):
-        """Attempt to fetch the newest collaboration state from other peers"""
-        with self.lock_collaboration_state:
-            while True:
-                try:
-                    self.averager.load_state_from_peers(timeout=self.load_state_timeout, **kwargs)
-                    break
-                except KeyboardInterrupt:
-                    raise
-                except BaseException as e:
-                    logger.exception(f"Failed to load state from peers: {e}, retrying ...")
-                    continue
-
-            self.local_samples_accumulated = self.local_updates_accumulated = 0
-            self.reset_accumulated_grads_()
-            self.update_scheduler()
-
-    def state_dict(self) -> dict:
-        state_dict = super().state_dict()
-        state_dict["state"]["collaborative_step"] = self.local_step
-        return state_dict
-
-    def load_state_dict(self, state_dict: dict):
-        if "collaborative_step" in state_dict["state"]:
-            self.averager.local_step = state_dict["state"].pop("collaborative_step")
-        return super().load_state_dict(state_dict)
-
-    def step(self, batch_size: Optional[int] = None, grad_scaler: Optional[HivemindGradScaler] = None, **kwargs):
-        """
-        Report accumulating gradients w.r.t. batch_size additional samples, optionally update model parameters
-
-        :param batch_size: optional override for batch_size_per_step from init
-        :param grad_scaler: if amp is enabled, this **must** be a hivemind-aware gradient scaler
-        :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
-        """
-        if grad_scaler is not None and not isinstance(grad_scaler, HivemindGradScaler):
-            raise ValueError("CollaborativeOptimizer requires a hivemind-aware gradient scaler (HivemindGradScaler)")
-        if self.batch_size_per_step is None:
-            if batch_size is None:
-                raise ValueError("Please either set batch_size_per_step parameter at init or when calling .step")
-            logger.log(self.status_loglevel, f"Setting default batch_size_per_step to {batch_size}")
-            self.batch_size_per_step = batch_size
-        batch_size = batch_size if batch_size is not None else self.batch_size_per_step
-
-        if not self.is_synchronized and not self.is_within_tolerance:
-            logger.log(self.status_loglevel, "Peer is out of sync")
-            self.load_state_from_peers()
-            return
-        elif not self.is_synchronized and self.is_within_tolerance:
-            self.averager.local_step = self.collaboration_state.optimizer_step
-            logger.log(self.status_loglevel, f"Catching up with collaboration step {self.local_step}")
-
-        if grad_scaler is not None and not grad_scaler.are_grads_finite(self):
-            logger.log(self.status_loglevel, "Encountered incorrect value in fp16 grads, resetting local gradients")
-            self.local_samples_accumulated = self.local_steps_accumulated = 0
-            self.reset_accumulated_grads_()
-            self.should_report_progress.set()
-            return
-
-        if self.last_step_time is not None and get_dht_time() - self.last_step_time > self.metadata_expiration:
-            logger.warning(
-                f"Training step took {get_dht_time() - self.last_step_time}, "
-                f"but metadata expired in {self.metadata_expiration} s."
-            )
-
-        self.accumulate_grads_(batch_size)
-
-        with self.lock_local_progress:
-            self.local_samples_accumulated += batch_size
-            self.local_updates_accumulated += 1
-            self.performance_ema.update(task_size=batch_size)
-            self.should_report_progress.set()
-
-        if not self.collaboration_state.ready_for_step:
-            return
-
-        logger.log(self.status_loglevel, f"Beginning global optimizer step #{self.collaboration_state.optimizer_step}")
-        with self.performance_ema.pause(), self.lock_collaboration_state:
-            self.collaboration_state = self._fetch_state()
-            self.collaboration_state_updated.set()
-
-            # divide accumulators by local steps to recover the true average grad w.r.t. local_samples_accumulated
-            self.apply_accumulated_grads_(scale_by=1.0 / self.local_updates_accumulated)
-            if grad_scaler is not None:
-                with grad_scaler.running_global_step():
-                    assert grad_scaler.unscale_(self)
-
-            current_step, group_info = self.averager.local_step, None
-
-            if self.collaboration_state.num_peers > 1:
-                mean_samples_per_worker = self.target_batch_size / self.collaboration_state.num_peers
-                weight = self.local_samples_accumulated / mean_samples_per_worker
-                try:
-                    group_info = self.averager.step(
-                        weight=weight, gather=current_step, timeout=self.averaging_timeout, **kwargs
-                    )
-                    if group_info:
-                        logger.log(self.status_loglevel, f"Averaged tensors successfully with {len(group_info)} peers")
-
-                        # update our current step if we averaged with another peer that was at a more recent step
-                        for peer, peer_step in group_info.items():
-                            if isinstance(peer_step, int):
-                                current_step = max(current_step, peer_step)
-                            else:
-                                logger.warning(f"Peer {peer} sent malformed data about current step: {peer_step}")
-
-                except BaseException as e:
-                    logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}")
-
-            else:
-                logger.log(
-                    self.status_loglevel,
-                    f"Skipped averaging: collaboration consists of " f"{self.collaboration_state.num_peers} peer(s)",
-                )
-
-            if grad_scaler is not None:
-                with grad_scaler.running_global_step():
-                    assert grad_scaler.step(self)
-            else:
-                self.opt.step()
-
-            self.reset_accumulated_grads_()
-            self.local_samples_accumulated = self.local_updates_accumulated = 0
-            self.collaboration_state.register_step(current_step + 1)
-            self.averager.local_step = current_step + 1
-            self.collaboration_state_updated.set()
-            self.update_scheduler()
-
-            if grad_scaler is not None:
-                with grad_scaler.running_global_step():
-                    assert grad_scaler.update()
-
-            if not self.averager.client_mode:
-                self.averager.state_sharing_priority = self.local_step
-
-        logger.log(self.status_loglevel, f"Optimizer step: done!")
-
-        return group_info
-
-    def step_aux(self, **kwargs):
-        """
-        Find and assist other peers in averaging without sending local gradients.
-
-        :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
-        """
-
-        if not self.collaboration_state.ready_for_step:
-            return
-
-        logger.log(self.status_loglevel, f"Beginning global optimizer step #{self.collaboration_state.optimizer_step}")
-        self.collaboration_state = self._fetch_state()
-        self.collaboration_state_updated.set()
-
-        with self.lock_collaboration_state:
-            current_step, group_info = self.averager.local_step, None
-
-            try:
-                group_info = self.averager.step(timeout=self.averaging_timeout, gather=current_step, **kwargs)
-                if group_info:
-                    logger.log(self.status_loglevel, f"Averaged tensors successfully with {len(group_info)} peers")
-
-                    # update our current step if we averaged with another peer that was at a more recent step
-                    for peer, peer_step in group_info.items():
-                        if isinstance(peer_step, int):
-                            current_step = max(current_step, peer_step)
-                        else:
-                            logger.warning(f"Peer {peer} sent malformed data about current step: {peer_step}")
-            except BaseException as e:
-                logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}")
-
-            self.collaboration_state.register_step(current_step + 1)
-            self.averager.local_step = current_step + 1
-            self.collaboration_state_updated.set()
-
-        logger.log(self.status_loglevel, f"Optimizer step: done!")
-
-        return group_info
-
-    def _grad_buffers(self) -> Iterator[torch.Tensor]:
-        """pytorch-internal gradient buffers"""
-        for param_group in self.opt.param_groups:
-            for param in param_group["params"]:
-                if param.grad is None:
-                    yield torch.zeros_like(param)
-                else:
-                    yield param.grad
-
-    @torch.no_grad()
-    def accumulated_grads(self) -> Iterator[torch.Tensor]:
-        """local gradient accumulators"""
-        if self.reuse_grad_buffers:
-            yield from self._grad_buffers()
-            return
-
-        if self._grads is None:
-            self._grads = [torch.zeros_like(grad, device=self.accumulate_grads_on) for grad in self._grad_buffers()]
-        yield from self._grads
-
-    @torch.no_grad()
-    def accumulate_grads_(self, batch_size: int):
-        """add current gradients to grad accumulators (if any)"""
-        if self.reuse_grad_buffers:
-            # user is responsible for accumulating gradients in .grad buffers
-            assert batch_size == self.batch_size_per_step, "Custom batch size is not supported if reuse_grad_buffers"
-        else:
-            alpha = float(batch_size) / self.batch_size_per_step
-            for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
-                grad_acc.add_(grad_buf.to(grad_acc.device), alpha=alpha)
-
-    @torch.no_grad()
-    def apply_accumulated_grads_(self, scale_by: Optional[float] = None):
-        if not self.reuse_grad_buffers:
-            for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
-                grad_buf.copy_(grad_acc.to(grad_buf.device), non_blocking=True)
-        if scale_by is not None:
-            for grad_buf in self._grad_buffers():
-                grad_buf.mul_(scale_by)
-
-    @torch.no_grad()
-    def reset_accumulated_grads_(self):
-        for grad_buf in self.accumulated_grads():
-            grad_buf.zero_()
-
-    def report_training_progress(self):
-        """Periodically publish metadata and the current number of samples accumulated towards the next step"""
-        while self.is_alive():
-            self.should_report_progress.wait()
-            self.should_report_progress.clear()
-            with self.lock_local_progress:
-                current_time = get_dht_time()
-                local_state_info = TrainingState(
-                    peer_id=self.averager.peer_id.to_bytes(),
-                    step=self.local_step,
-                    samples_accumulated=self.local_samples_accumulated,
-                    samples_per_second=self.performance_ema.samples_per_second,
-                    time=current_time,
-                    client_mode=self.averager.client_mode,
-                )
-
-            self.dht.store(
-                key=self.training_progress_key,
-                subkey=self._local_public_key,
-                value=local_state_info.dict(),
-                expiration_time=current_time + self.metadata_expiration,
-                return_future=True,
-            )
-
-    def check_collaboration_state_periodically(self):
-        """
-        Periodically check the training progress from all peers. Trigger update after target_batch_size total samples
-        """
-        while self.is_alive():
-            time_to_next_update = max(0.0, self.collaboration_state.next_fetch_time - get_dht_time())
-            if self.collaboration_state_updated.wait(time_to_next_update):
-                self.collaboration_state_updated.clear()
-                continue  # if state was updated externally, reset timer
-
-            with self.lock_collaboration_state:
-                self.collaboration_state = self._fetch_state()
-
-    def _fetch_state(self) -> CollaborationState:
-        """Read performance statistics reported by peers, estimate progress towards next batch"""
-        response, _expiration = self.dht.get(self.training_progress_key, latest=True) or (None, -float("inf"))
-        current_time = get_dht_time()
-
-        if not isinstance(response, dict) or len(response) == 0:
-            logger.log(self.status_loglevel, f"Found no active peers: {response}")
-            samples_left_to_target_batch_size = max(0, self.target_batch_size - self.local_samples_accumulated)
-            local_eta_next_step = samples_left_to_target_batch_size / self.performance_ema.samples_per_second
-
-            return CollaborationState(
-                self.local_step,
-                self.local_samples_accumulated,
-                self.target_batch_size,
-                num_peers=0,
-                num_clients=0,
-                eta_next_step=current_time + local_eta_next_step,
-                next_fetch_time=current_time + self.default_refresh_period,
-            )
-
-        valid_peer_states = [
-            TrainingState.parse_obj(peer_state.value)
-            for peer_state in response.values()
-            if peer_state.value is not None
-        ]
-
-        num_peers = len(valid_peer_states)
-        num_clients = sum(state.client_mode for state in valid_peer_states)
-        global_optimizer_step = self.local_step
-        for state in valid_peer_states:
-            if not state.client_mode:
-                global_optimizer_step = max(global_optimizer_step, state.step)
-
-        total_samples_accumulated = estimated_current_samples = total_samples_per_second = 0
-
-        for state in valid_peer_states:
-            total_samples_per_second += state.samples_per_second
-            if state.step == global_optimizer_step:
-                total_samples_accumulated += state.samples_accumulated
-                estimated_current_samples += (
-                    state.samples_accumulated + max(0, current_time - state.time) * state.samples_per_second
-                )
-            # note: we deliberately count only valid peers for samples_accumulated, but all peers for performance;
-            # the rationale behind this is that outdated peers will synchronize and begin contributing shortly.
-
-        estimated_samples_remaining = self.target_batch_size - estimated_current_samples
-        estimated_time_to_next_step = max(0, estimated_samples_remaining) / total_samples_per_second
-
-        expected_max_peers = max(num_peers + self.expected_drift_peers, num_peers * (1 + self.expected_drift_rate))
-        time_to_next_fetch = float(
-            np.clip(
-                a=estimated_time_to_next_step * num_peers / expected_max_peers,
-                a_min=self.min_refresh_period,
-                a_max=self.max_refresh_period,
-            )
-        )
-        logger.log(
-            self.status_loglevel,
-            f"{self.prefix} accumulated {total_samples_accumulated} samples from "
-            f"{num_peers} peers for step #{global_optimizer_step}. "
-            f"ETA {estimated_time_to_next_step:.2f} sec (refresh in {time_to_next_fetch:.2f} sec)",
-        )
-        return CollaborationState(
-            global_optimizer_step,
-            total_samples_accumulated,
-            target_batch_size=self.target_batch_size,
-            num_peers=num_peers,
-            num_clients=num_clients,
-            eta_next_step=current_time + estimated_time_to_next_step,
-            next_fetch_time=current_time + time_to_next_fetch,
-        )
-
-    def zero_grad(self, *args, **kwargs):
-        if self.reuse_grad_buffers:
-            raise ValueError(
-                f"When running {self.__class__.__name__} with reuse_grad_buffers=True, user should never "
-                f"call zero_grad manually. Gradients will be refreshed internally."
-            )
-        return self.opt.zero_grad(*args, **kwargs)
-
-    def update_scheduler(self):
-        if self.scheduler:
-            while self.scheduler._step_count < self.local_step:
-                self.scheduler.step()
-
-    def shutdown(self):
-        logger.debug("Shutting down averager...")
-        self.averager.shutdown()
-        logger.debug("Sending goodbye to peers...")
-        self.dht.store(
-            self.training_progress_key,
-            subkey=self._local_public_key,
-            value=None,
-            expiration_time=get_dht_time() + self.metadata_expiration,
-        )
-        logger.debug(f"{self.__class__.__name__} is shut down")
-
-    def __del__(self):
-        self.shutdown()

+ 20 - 7
hivemind/optim/grad_averager.py

@@ -1,16 +1,20 @@
 import contextlib
-from typing import Iterable, Iterator, Optional
+from typing import Callable, Iterable, Iterator, Optional, Sequence, TypeVar
 
 import torch
 
-import hivemind
 from hivemind.averaging import DecentralizedAverager
 from hivemind.averaging.control import StepControl
-from hivemind.utils import DHTExpiration, get_dht_time, get_logger
+from hivemind.dht import DHT
+from hivemind.utils import DHTExpiration, get_logger
 
 logger = get_logger(__name__)
 
 
+TGradientAverager = TypeVar("TGradientAverager", bound="GradientAverager")
+GradientAveragerFactory = Callable[..., TGradientAverager]
+
+
 class GradientAverager(DecentralizedAverager):
     """
     An auxiliary averaging class that is responsible for accumulating gradients and aggregating them with peers.
@@ -36,6 +40,7 @@ class GradientAverager(DecentralizedAverager):
       if True, the averager will only join existing groups where at least one peer has client_mode=False.
       By default, this flag is copied from DHTNode inside the ``dht`` instance.
     :param warn: if True, warn when the averager did not reset accumulators after use or did not use averaging results
+    :param averaged_grads: if provided, it will be used as a set of averagable gradients
     :param kwargs: see DecentralizedAverager keyword arguments for additional parameters
 
 
@@ -69,12 +74,13 @@ class GradientAverager(DecentralizedAverager):
         self,
         parameters: Iterable[torch.nn.Parameter],
         *,
-        dht: hivemind.DHT,
+        dht: DHT,
         prefix: str,
         reuse_grad_buffers: bool = False,
         accumulate_grads_on: Optional[torch.device] = None,
         client_mode: bool = None,
         warn: bool = True,
+        averaged_grads: Sequence[torch.Tensor] = (),
         **kwargs,
     ):
         if reuse_grad_buffers and accumulate_grads_on is not None:
@@ -95,9 +101,16 @@ class GradientAverager(DecentralizedAverager):
         self._new_averaged_grads = False
 
         with torch.no_grad():
-            averaged_grads = tuple(
-                grad.detach().cpu().clone().share_memory_() for grad in self._grads_from_parameters()
-            )
+            if not averaged_grads:
+                averaged_grads = tuple(
+                    grad.detach().cpu().clone().share_memory_() for grad in self._grads_from_parameters()
+                )
+            else:
+                if any(
+                    param_grad.size() != grad.size()
+                    for param_grad, grad in zip(self._grads_from_parameters(), averaged_grads)
+                ):
+                    raise ValueError("Averaged gradients don't have same shape as gradients from parameters")
         super().__init__(averaged_tensors=averaged_grads, dht=dht, prefix=prefix, client_mode=client_mode, **kwargs)
 
     def _grads_from_parameters(self) -> Iterator[torch.Tensor]:

+ 1 - 1
hivemind/optim/grad_scaler.py

@@ -50,7 +50,7 @@ class GradScaler(TorchGradScaler):
 
     def unscale_(self, optimizer: TorchOptimizer) -> bool:
         with self._lock:
-            assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
+            assert isinstance(optimizer, hivemind.Optimizer)
             if self._is_running_global_step:
                 super().unscale_(optimizer)
                 self._inner_optimizer_states[id(optimizer.opt)] = deepcopy(self._per_optimizer_states[id(optimizer)])

+ 12 - 4
hivemind/optim/optimizer.py

@@ -11,7 +11,7 @@ import torch
 from hivemind.averaging.control import AveragingStage, StepControl
 from hivemind.compression import CompressionBase, NoCompression
 from hivemind.dht import DHT
-from hivemind.optim.grad_averager import GradientAverager
+from hivemind.optim.grad_averager import GradientAverager, GradientAveragerFactory
 from hivemind.optim.grad_scaler import GradScaler
 from hivemind.optim.progress_tracker import LocalTrainingProgress, ProgressTracker
 from hivemind.optim.state_averager import (
@@ -147,6 +147,7 @@ class Optimizer(torch.optim.Optimizer):
     :param auxiliary: if True, optimizer.step will only assist other peers in averaging (for cpu-only workers)
 
     :param grad_compression: compression strategy used for averaging gradients, default = no compression
+    :param grad_averager_factory: if provided, creates gradient averager with required averaging strategy
     :param state_averaging_compression: compression for averaging params and state tensors, default = no compression
     :param load_state_compression: compression strategy for loading state from peers, default = no compression
     :param average_opt_statistics: names of optimizer statistics from state dict that should be averaged with peers
@@ -187,6 +188,7 @@ class Optimizer(torch.optim.Optimizer):
         client_mode: bool = None,
         auxiliary: bool = False,
         grad_compression: CompressionBase = NoCompression(),
+        grad_averager_factory: Optional[GradientAveragerFactory] = None,
         state_averaging_compression: CompressionBase = NoCompression(),
         load_state_compression: CompressionBase = NoCompression(),
         average_opt_statistics: Sequence[str] = (),
@@ -226,6 +228,9 @@ class Optimizer(torch.optim.Optimizer):
         if use_local_updates:
             assert not reuse_grad_buffers, "if local_updates is True, gradients will not be accumulated"
             assert not delay_grad_averaging, "if local_updates is True, gradients will not be averaged"
+            assert (
+                grad_averager_factory is None
+            ), "if local_updates is True, provided grad_averager_factory will not be used"
 
         self.dht, self.run_id, self.client_mode, self.auxiliary = dht, run_id, client_mode, auxiliary
         self.batch_size_per_step, self.target_batch_size = batch_size_per_step, target_batch_size
@@ -258,7 +263,7 @@ class Optimizer(torch.optim.Optimizer):
         )
         if not use_local_updates:
             self.grad_averager = self._make_gradient_averager(
-                reuse_grad_buffers=reuse_grad_buffers, compression=grad_compression, **averager_opts or {}
+                grad_averager_factory, reuse_grad_buffers=reuse_grad_buffers, compression=grad_compression
             )
         else:
             self.grad_averager = None
@@ -291,9 +296,10 @@ class Optimizer(torch.optim.Optimizer):
             **kwargs,
         )
 
-    def _make_gradient_averager(self, **kwargs) -> GradientAverager:
+    def _make_gradient_averager(self, factory: Optional[GradientAveragerFactory], **kwargs) -> GradientAverager:
         assert hasattr(self, "state_averager"), "must initialize state averager first"
-        grad_averager = GradientAverager(
+        factory = factory if factory is not None else GradientAverager
+        grad_averager = factory(
             dht=self.dht,
             prefix=f"{self.run_id}_grad_averager",
             parameters=self.state_averager.main_parameters,
@@ -685,6 +691,8 @@ class Optimizer(torch.optim.Optimizer):
             while True:
                 try:
                     self.state_averager.load_state_from_peers(timeout=self.load_state_timeout, **kwargs)
+                    if self.grad_averager is not None:
+                        self.grad_averager.load_state_from_peers(timeout=self.load_state_timeout, **kwargs)
                     break
                 except KeyboardInterrupt:
                     raise

+ 222 - 0
hivemind/optim/power_sgd_averager.py

@@ -0,0 +1,222 @@
+import asyncio
+import contextlib
+from enum import Enum
+from typing import Any, Iterable, Optional, Sequence
+
+import torch
+
+from hivemind.averaging.allreduce import AveragingMode
+from hivemind.averaging.group_info import GroupInfo
+from hivemind.averaging.load_balancing import load_balance_peers
+from hivemind.averaging.matchmaking import MatchmakingException
+from hivemind.compression import CompressionInfo, TensorRole
+from hivemind.dht import DHT
+from hivemind.optim.grad_averager import GradientAverager
+from hivemind.utils import get_logger
+from hivemind.utils.asyncio import enter_asynchronously
+from hivemind.utils.math import get_flatten_greedy_dims, orthogonalize_
+
+GatheredData = Any
+logger = get_logger(__name__)
+
+
+class AllReducePhases(Enum):
+    PHASE_P = 1
+    PHASE_Q = 2
+
+
+class PowerSGDGradientAverager(GradientAverager):
+    """
+    A gradient averager that implements PowerSGD compression: https://arxiv.org/abs/1905.13727
+    For basic properties and guaranties of gradient averagers, please refer to the base class docstring.
+    Put simply, this method approximates large gradient tensors (m,n) with a product of two
+    smaller matrices (m,r) by (r,n), where r is a parameter chosen by the user (see averager_rank).
+
+    As a result, PowerSGD only needs to aggregate O((m + n) * r) tensors instead of O(m * n).
+    High r, e.g. sqrt(max(m, n)) typically reduce communication by 2-8x without affecting convergence.
+    Low r, e.g. 1-8, further accelerate communication, but may converge worse depending on the task.
+
+    To maintain convergence with low r, this averager uses the error feedback strategy. Put simply,
+    if some part of the gradient is "lost in compression", it will be added to the next iteration.
+    This has two implications: (a) it needs more RAM in order to store the "feedback buffers"
+    and (b) if devices stay alive only for one step, training with small rank may converge slower.
+    This is because error feedback takes multiple steps to kick in.
+
+    Since not all gradients are matrices, PowerSGD views 3d+ tensors via tensor.flatten(1, -1).
+    If a tensor has less than 2 dimensions or does not compress efficiently, it will be aggregated
+    normally, i.e. without powerSGD. See min_compression_ratio for details.
+
+    :note: due to the above rule, PowerSGD is *not* shape-invariant. For instance, a
+     matrix of shape (256, 256) be compressed differently if you .reshape it to (32, 32, 32).
+
+    :param parameters: pytorch parameters for which to aggregate gradients
+    :param averager_rank: rank of compressed gradients
+    :param dht: a DHT isntance connected to the rest of the swarm. See hivemind.DHT docs
+    :param prefix: a unique DHT key used for matchmaking. E.g. this can be your experiment name with optional suffixes
+    :param reuse_grad_buffers: if True, use model's .grad buffers for accumulating gradients over multiple steps.
+      This is more memory efficient, but it requires that the user does *not* call zero_grad or clip_by_whatever at all
+    :param accumulate_grads_on: if specified, accumulate gradients on this device. By default, this will use the same
+      device as model parameters. One can specify a different device (e.g. 'cpu' vs 'cuda') to save device memory at
+      the cost of extra time per step. If reuse_grad_buffers is True, this parameter has no effect.
+    :param client_mode: if False, this averager will accept incoming requests from other peers.
+      if True, the averager will only join existing groups where at least one peer has client_mode=False.
+      By default, this flag is copied from DHTNode inside the ``dht`` instance.
+    :param warn: if True, warn when the averager did not reset accumulators after use or did not use averaging results
+    :param min_compression_ratio: apply PowerSGD to a tensor only if it reduces communication by at least this factor,
+      otherwise aggregate tensors as is
+    :param averaged_grads: if provided, it will be used as a set of averagable gradients
+    """
+
+    def __init__(
+        self,
+        parameters: Iterable[torch.nn.Parameter],
+        averager_rank: int,
+        *,
+        dht: DHT,
+        prefix: str,
+        reuse_grad_buffers: bool = False,
+        accumulate_grads_on: Optional[torch.device] = None,
+        client_mode: bool = None,
+        warn: bool = True,
+        min_compression_ratio: float = 0.5,
+        averaged_grads: Optional[Sequence[torch.Tensor]] = None,
+        **kwargs,
+    ):
+        self.rank = averager_rank
+        self.parameters = tuple(parameters)
+        self._uncompressed_gradients_indexes = set(
+            i
+            for i, grad in enumerate(self._grads_from_parameters())
+            if grad.ndim <= 1
+            or (1 - self.rank * sum(get_flatten_greedy_dims(grad)) / grad.numel()) < min_compression_ratio
+            # compute how much parameters are left after factorization
+        )
+        self._ms = [
+            torch.zeros_like(grad, device="cpu").share_memory_()
+            for idx, grad in enumerate(self._grads_from_parameters())
+            if idx not in self._uncompressed_gradients_indexes
+        ]
+        self._qs = [
+            torch.rand((get_flatten_greedy_dims(grad)[1], self.rank), device="cpu").share_memory_()
+            for idx, grad in enumerate(self._grads_from_parameters())
+            if idx not in self._uncompressed_gradients_indexes
+        ]
+
+        super().__init__(
+            self.parameters,
+            dht=dht,
+            prefix=prefix,
+            reuse_grad_buffers=reuse_grad_buffers,
+            accumulate_grads_on=accumulate_grads_on,
+            client_mode=client_mode,
+            warn=warn,
+            averaged_grads=averaged_grads,
+            **kwargs,
+        )
+
+    @contextlib.contextmanager
+    def _register_allreduce_group(self, group_info: GroupInfo):
+        """Register a given group for one or more all-reduce rounds"""
+        try:
+            for phase in list(AllReducePhases):
+                self._running_groups[group_info.group_id + phase.name.encode()] = asyncio.Future()
+            self._pending_groups_registered.set()
+            yield
+        finally:
+            for phase in list(AllReducePhases):
+                maybe_future = self._running_groups.pop(group_info.group_id + phase.name.encode(), None)
+                if maybe_future and not maybe_future.done():
+                    logger.warning(f"All-reduce group {group_info.group_id + phase.name.encode()} did not finish.")
+            self._pending_groups_registered.set()
+
+    async def _aggregate_with_group(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
+        """Run aggregation in a given group and update tensors in place, return gathered metadata"""
+        try:
+            bandwidths, mode_ids, user_gathered_bytes = zip(*map(self.serializer.loads, group_info.gathered))
+            user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered_bytes)))
+            modes = tuple(map(AveragingMode, mode_ids))
+
+            download_bandwidths = [
+                thr if mode != AveragingMode.CLIENT else 0.0 for thr, mode in zip(bandwidths, modes)
+            ]
+            peer_fractions = await asyncio.get_event_loop().run_in_executor(
+                None, load_balance_peers, self.total_size, download_bandwidths, min_vector_size
+            )
+
+            async with enter_asynchronously(self.get_tensors()) as averaged_grads:
+                averaged_grads_via_sgd = [
+                    grad for idx, grad in enumerate(averaged_grads) if idx not in self._uncompressed_gradients_indexes
+                ]
+                for grad, m in zip(averaged_grads_via_sgd, self._ms):
+                    m.add_(grad.to(m.device))
+
+                ps = [
+                    torch.zeros((get_flatten_greedy_dims(grad)[0], self.rank), device="cpu")
+                    for idx, grad in enumerate(averaged_grads_via_sgd)
+                ]
+                for p, q, m in zip(ps, self._qs, self._ms):
+                    # we use reshape for all matrixes because PowerSGD works only with 2d tensors
+                    torch.matmul(m.reshape(-1, q.size(0)), q, out=p)
+
+                p_group_id = group_info.group_id + AllReducePhases.PHASE_P.name.encode()
+                q_groud_id = group_info.group_id + AllReducePhases.PHASE_Q.name.encode()
+
+                await self._run_allreduce_inplace_(ps, group_info, p_group_id, peer_fractions=peer_fractions, **kwargs)
+
+                for p in ps:
+                    orthogonalize_(p)
+
+                for p, q, m in zip(ps, self._qs, self._ms):
+                    torch.matmul(m.reshape(-1, q.size(0)).t(), p, out=q)
+
+                phase_q_tensors = self._qs + [
+                    grad for idx, grad in enumerate(averaged_grads) if idx in self._uncompressed_gradients_indexes
+                ]
+
+                await self._run_allreduce_inplace_(
+                    phase_q_tensors, group_info, q_groud_id, peer_fractions=peer_fractions, **kwargs
+                )
+
+                for p, q, m, grad in zip(ps, self._qs, self._ms, averaged_grads_via_sgd):
+                    new_m = torch.matmul(p, q.t()).reshape(m.size())
+                    m.sub_(new_m)
+                    grad.copy_(new_m)
+
+                return user_gathered
+        except BaseException as e:
+            logger.exception(e)
+            raise MatchmakingException(f"Unable to run All-Reduce: {e}")
+
+    def get_current_state(self):
+        """
+        Get current gradient averager state and when requested by a newbie peer.
+        """
+        with torch.no_grad(), self.lock_averaged_tensors:
+            grad_averager_buffers = [q for q in self._qs]
+            grad_averager_buffers_infos = [
+                CompressionInfo.from_tensor(buffer, key=f"buffer_q_{key}", role=TensorRole.GRADIENT)
+                for buffer, key in zip(grad_averager_buffers, enumerate(grad_averager_buffers))
+            ]
+
+        metadata = dict(group_bits=self.get_group_bits())
+        return metadata, grad_averager_buffers, grad_averager_buffers_infos
+
+    def load_state_from_peers(self, **kwargs):
+        """
+        Attempt to download the latest optimizer state from peers and update gradient averager buffers.
+        :returns: whether or the averager succeeded in loading parameters
+        """
+        loaded_state = super().load_state_from_peers(**kwargs)
+        if loaded_state is None:
+            return
+
+        metadata, flat_tensors = loaded_state
+        logger.info("Starting loading gradient averager buffers from peers")
+
+        if len(flat_tensors) != len(self._qs):
+            logger.error("Failed to load state from peer, received invalid parameters, extras or metadata")
+            return
+
+        with torch.no_grad(), self.lock_averaged_tensors:
+            for local_q, loaded_q in zip(self._qs, flat_tensors):
+                local_q.copy_(loaded_q, non_blocking=True)

+ 0 - 229
hivemind/optim/simple.py

@@ -1,229 +0,0 @@
-import time
-from threading import Event, Lock, Thread
-from typing import Optional, Sequence, Tuple
-
-import torch
-
-from hivemind.dht import DHT
-from hivemind.optim.base import DecentralizedOptimizerBase
-from hivemind.optim.training_averager import TrainingAverager
-from hivemind.utils import get_dht_time, get_logger
-
-logger = get_logger(__name__)
-
-
-class DecentralizedOptimizer(DecentralizedOptimizerBase):
-    """
-    A simple optimizer that trains a shared model by averaging with peers in variety of ways. Supports
-    parameter/gradient averaging and syncing adaptive learning rates or any other internal statistics of optimizer.
-
-    :param opt: a pytorch optimizer configured to update model parameters.
-    :param dht: a running hivemind DHT daemon connected to other peers
-    :param average_parameters: whether to average model parameters
-    :param average_gradients: whether to average gradients
-    :param average_opt_statistics: if specified, average optimizer states with corresponding names in state_dict
-    :param averaging_steps_period: performs averaging after this many optimizer steps
-    :param averaging_time_period: if specified, optimizer will attempt to average weights at regular intervals of this
-      many seconds. (averaging step will only occur if the optimizer ran `averaging_steps_period` steps in that interval)
-    :param prefix: all DHT keys that point to optimization metadata will have this prefix
-    :param target_group_size: maximum group size for averaging (see DecentralizedAverager)
-    :param timeout: if DecentralizedAverager step is unable to form group in this many seconds, cancel step
-    :param kwargs: additional parameters passed to TrainingAverager
-    :note: if you're using an optimizer with adaptive learning rates (such as Adam), make sure to specify
-      necessary fields' names in `average_opt_statistics`. Otherwise you may encounter poor convergence.
-    :note: the base optimizer cannot add param groups after the DecentralizedOptimizer is created
-    """
-
-    def __init__(
-        self,
-        opt: torch.optim.Optimizer,
-        dht: DHT,
-        *,
-        prefix: str,
-        target_group_size: int,
-        average_parameters: bool,
-        average_gradients: bool,
-        average_opt_statistics: Sequence[str] = (),
-        averaging_steps_period: int = 1,
-        averaging_time_period: float = 0,
-        timeout: Optional[float] = None,
-        verbose: bool = False,
-        **kwargs,
-    ):
-        super().__init__(opt, dht)
-        assert averaging_steps_period > 0 and averaging_time_period >= 0, "Averaging period must be positive."
-        self.local_step, self.averaging_step_period = 0, averaging_steps_period
-
-        self.averager = TrainingAverager(
-            opt,
-            average_parameters=average_parameters,
-            average_gradients=average_gradients,
-            average_opt_statistics=average_opt_statistics,
-            dht=dht,
-            start=True,
-            prefix=prefix,
-            target_group_size=target_group_size,
-            **kwargs,
-        )
-        self.lock_parameters, self.update_event, self.stop_event = Lock(), Event(), Event()
-        self.lock_parameters.acquire()  # this lock is only released when averager can modify tensors in background
-
-        self.background_averaging_thread = Thread(
-            name=f"{self.__class__.__name__}",
-            daemon=True,
-            target=self._average_parameters_in_background,
-            args=[self.lock_parameters, self.update_event, self.stop_event, self.averager],
-            kwargs=dict(averaging_period=averaging_time_period, timeout=timeout, verbose=verbose),
-        )
-        self.background_averaging_thread.start()
-
-    def step(self, *args, **kwargs):
-        loss = self.opt.step(*args, **kwargs)
-        if self.lock_parameters.locked():
-            self.lock_parameters.release()
-        try:
-            self.local_step += 1
-            if self.local_step % self.averaging_step_period == 0:
-                self.update_event.set()
-            self.averager.pending_updates_done.wait()
-
-            if not self.averager.client_mode:
-                self.averager.state_sharing_priority = get_dht_time()
-
-            return loss
-        finally:
-            self.lock_parameters.acquire()
-
-    def zero_grad(self, *args, **kwargs):
-        return self.opt.zero_grad(*args, **kwargs)
-
-    def __del__(self):
-        self.stop_event.set()
-        self.update_event.set()
-
-    def shutdown(self):
-        self.stop_event.set()
-        self.update_event.set()
-        self.averager.shutdown()
-
-    @staticmethod
-    @torch.no_grad()
-    def _average_parameters_in_background(
-        lock_parameters: Lock,
-        update_event: Event,
-        stop_event: Event,
-        averager: TrainingAverager,
-        averaging_period: float,
-        verbose: bool,
-        **kwargs,
-    ):
-        """Iteratively find groups of peers, average parameters with these peers and update local model parameters."""
-        while not stop_event.is_set():
-            update_event.wait()
-            update_event.clear()
-            if stop_event.is_set():
-                break
-
-            if averaging_period:
-                current_time = get_dht_time()
-                # note: we use global DHT time to make sure peers start averaging at the ~same time (to form groups)
-                time_to_nearest_interval = max(0.0, averaging_period - current_time % averaging_period)
-                time.sleep(time_to_nearest_interval)
-
-            if verbose:
-                logger.info(f"Starting a new averaging round with current parameters")
-            try:
-                group_info = averager.step(lock_parameters, **kwargs)
-                if verbose:
-                    if group_info is not None:
-                        logger.info(f"Finished averaging round in with {len(group_info)} peers")
-                    else:
-                        logger.warning(f"Averaging round failed: could not find group")
-            except Exception as e:
-                logger.error(f"Averaging round failed: caught {e}")
-
-
-class DecentralizedSGD(DecentralizedOptimizer):
-    """
-    Decentralized Stochastic Gradient Descent algorithm like in Lian et al (2017) [1] based on Moshpit All-Reduce [2].
-
-    :param dht: a running hivemind DHT daemon connected to other peers
-    :param prefix: all DHT keys that point to optimization metadata will have this prefix
-    :param target_group_size: maximum group size for averaging (see DecentralizedAverager)
-    :param kwargs: additional parameters passed to DecentralizedOptimizer
-
-    - [1] Can Decentralized Algorithms Outperform Centralized Algorithms? A Case Study for Parallel Stochastic Gradient
-     Descent - https://proceedings.neurips.cc/paper/2017/hash/f75526659f31040afeb61cb7133e4e6d-Abstract.html
-    - [2] Moshpit SGD: Communication-Efficient Decentralized Training on Heterogeneous Unreliable Devices
-     https://arxiv.org/abs/2103.03239
-    """
-
-    def __init__(
-        self,
-        params,
-        lr: float,
-        *,
-        dht: DHT,
-        prefix: str,
-        target_group_size: int,
-        momentum: float = 0,
-        dampening: float = 0,
-        weight_decay: float = 0,
-        nesterov: bool = False,
-        **kwargs,
-    ):
-        opt = torch.optim.SGD(params, lr, momentum, dampening, weight_decay, nesterov)
-        super().__init__(
-            opt,
-            dht,
-            prefix=prefix,
-            target_group_size=target_group_size,
-            average_parameters=True,
-            average_gradients=False,
-            **kwargs,
-        )
-
-
-class DecentralizedAdam(DecentralizedOptimizer):
-    """
-    Decentralized Adam/AmsGrad as proposed in [1], [2]
-
-    :param dht: a running hivemind DHT daemon connected to other peers
-    :param prefix: all DHT keys that point to optimization metadata will have this prefix
-    :param target_group_size: maximum group size for averaging (see DecentralizedAverager)
-    :param averaging_steps_period: performs averaging after this many optimizer steps
-    :param kwargs: additional parameters passed to DecentralizedOptimizer
-
-    - [1] On the Convergence of Decentralized Adaptive Gradient Methods
-    - [2] Toward Communication Efficient Adaptive Gradient Method - https://dl.acm.org/doi/abs/10.1145/3412815.3416891
-    """
-
-    def __init__(
-        self,
-        params,
-        lr: float,
-        *,
-        dht: DHT,
-        prefix: str,
-        target_group_size: int,
-        averaging_steps_period: int,
-        betas: Tuple[float, float] = (0.9, 0.999),
-        eps: float = 1e-8,
-        weight_decay: float = 0,
-        amsgrad: bool = False,
-        **kwargs,
-    ):
-        opt = torch.optim.Adam(params, lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
-        opt_statistics = ("max_exp_avg_sq",) if amsgrad else ("exp_avg_sq",)
-
-        super().__init__(
-            opt,
-            dht,
-            prefix=prefix,
-            target_group_size=target_group_size,
-            average_parameters=True,
-            average_gradients=False,
-            average_opt_statistics=opt_statistics,
-            averaging_steps_period=averaging_steps_period,
-            **kwargs,
-        )

+ 1 - 1
hivemind/optim/state_averager.py

@@ -14,7 +14,7 @@ from hivemind.averaging import DecentralizedAverager
 from hivemind.averaging.control import StepControl
 from hivemind.compression import CompressionInfo, TensorRole
 from hivemind.optim.grad_scaler import GradScaler
-from hivemind.utils import DHTExpiration, PerformanceEMA, get_dht_time, get_logger, nested_flatten, nested_pack
+from hivemind.utils import DHTExpiration, PerformanceEMA, get_logger, nested_flatten, nested_pack
 
 logger = get_logger(__name__)
 

+ 52 - 19
hivemind/p2p/p2p_daemon.py

@@ -3,12 +3,13 @@ import json
 import logging
 import os
 import secrets
+import warnings
 from collections.abc import AsyncIterable as AsyncIterableABC
 from contextlib import closing, suppress
 from dataclasses import dataclass
 from datetime import datetime
 from importlib.resources import path
-from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional, Sequence, Tuple, Type, TypeVar, Union
+from typing import Any, AsyncIterator, Awaitable, Callable, List, Optional, Sequence, Tuple, Type, TypeVar, Union
 
 from google.protobuf.message import Message
 from multiaddr import Multiaddr
@@ -17,8 +18,10 @@ import hivemind.hivemind_cli as cli
 import hivemind.p2p.p2p_daemon_bindings.p2pclient as p2pclient
 from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, P2PDaemonError, P2PHandlerError
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
+from hivemind.proto import crypto_pb2
 from hivemind.proto.p2pd_pb2 import RPCError
 from hivemind.utils.asyncio import as_aiter, asingle
+from hivemind.utils.crypto import RSAPrivateKey
 from hivemind.utils.logging import get_logger, golog_level_to_python, loglevel, python_level_to_golog
 
 logger = get_logger(__name__)
@@ -89,16 +92,16 @@ class P2P:
         identity_path: Optional[str] = None,
         idle_timeout: float = 30,
         nat_port_map: bool = True,
-        quic: bool = False,
         relay_hop_limit: int = 0,
         startup_timeout: float = 15,
         tls: bool = True,
         use_auto_relay: bool = False,
         use_ipfs: bool = False,
         use_relay: bool = True,
-        use_relay_hop: bool = False,
-        use_relay_discovery: bool = False,
         persistent_conn_max_msg_size: int = DEFAULT_MAX_MSG_SIZE,
+        quic: Optional[bool] = None,
+        use_relay_hop: Optional[bool] = None,
+        use_relay_discovery: Optional[bool] = None,
     ) -> "P2P":
         """
         Start a new p2pd process and connect to it.
@@ -112,20 +115,20 @@ class P2P:
                          Details: https://pkg.go.dev/github.com/libp2p/go-libp2p-kad-dht#ModeOpt
         :param force_reachability: Force reachability mode (public/private)
         :param host_maddrs: Multiaddrs to listen for external connections from other p2p instances
-        :param identity_path: Path to a pre-generated private key file. If defined, makes the peer ID deterministic.
-                              May be generated using ``./p2p-keygen`` from ``go-libp2p-daemon``.
+        :param identity_path: Path to a private key file. If defined, makes the peer ID deterministic.
+                              If the file does not exist yet, writes a new private key to this file.
         :param idle_timeout: kill daemon if client has been idle for a given number of
                              seconds before opening persistent streams
         :param nat_port_map: Enables NAT port mapping
-        :param quic: Enables the QUIC transport
         :param relay_hop_limit: sets the hop limit for hop relays
         :param startup_timeout: raise a P2PDaemonError if the daemon does not start in ``startup_timeout`` seconds
         :param tls: Enables TLS1.3 channel security protocol
         :param use_auto_relay: enables autorelay
         :param use_ipfs: Bootstrap to IPFS (incompatible with initial_peers)
         :param use_relay: enables circuit relay
-        :param use_relay_hop: enables hop for relay
-        :param use_relay_discovery: enables passive discovery for relay
+        :param quic: Deprecated, has no effect since libp2p 0.17.0
+        :param use_relay_hop: Deprecated, has no effect since libp2p 0.17.0
+        :param use_relay_discovery: Deprecated, has no effect since libp2p 0.17.0
         :return: a wrapper for the p2p daemon
         """
 
@@ -133,6 +136,14 @@ class P2P:
             initial_peers and use_ipfs
         ), "User-defined initial_peers and use_ipfs=True are incompatible, please choose one option"
 
+        if not all(arg is None for arg in [quic, use_relay_hop, use_relay_discovery]):
+            warnings.warn(
+                "Parameters `quic`, `use_relay_hop`, and `use_relay_discovery` of hivemind.P2P "
+                "have no effect since libp2p 0.17.0 and will be removed in hivemind 1.2.0+",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+
         self = cls()
         with path(cli, P2PD_FILENAME) as p:
             p2pd_path = p
@@ -147,7 +158,7 @@ class P2P:
                     raise ValueError("Please specify an explicit port in announce_maddrs: port 0 is not supported")
 
         need_bootstrap = bool(initial_peers) or use_ipfs
-        process_kwargs = cls.DHT_MODE_MAPPING.get(dht_mode, {"dht": 0})
+        process_kwargs = cls.DHT_MODE_MAPPING[dht_mode].copy()
         process_kwargs.update(cls.FORCE_REACHABILITY_MAPPING.get(force_reachability, {}))
         for param, value in [
             ("bootstrapPeers", initial_peers),
@@ -156,7 +167,11 @@ class P2P:
         ]:
             if value:
                 process_kwargs[param] = self._maddrs_to_str(value)
+
         if identity_path is not None:
+            if not os.path.isfile(identity_path):
+                logger.info(f"Generating new identity (libp2p private key) in `{identity_path}`")
+                self.generate_identity(identity_path)
             process_kwargs["id"] = identity_path
 
         proc_args = self._make_process_args(
@@ -168,10 +183,7 @@ class P2P:
             idleTimeout=f"{idle_timeout}s",
             listen=self._daemon_listen_maddr,
             natPortMap=nat_port_map,
-            quic=quic,
             relay=use_relay,
-            relayDiscovery=use_relay_discovery,
-            relayHop=use_relay_hop,
             relayHopLimit=relay_hop_limit,
             tls=tls,
             persistentConnMaxMsgSize=persistent_conn_max_msg_size,
@@ -205,6 +217,20 @@ class P2P:
         await self._ping_daemon()
         return self
 
+    @staticmethod
+    def generate_identity(identity_path: str) -> None:
+        private_key = RSAPrivateKey()
+        protobuf = crypto_pb2.PrivateKey(key_type=crypto_pb2.KeyType.RSA, data=private_key.to_bytes())
+
+        try:
+            with open(identity_path, "wb") as f:
+                f.write(protobuf.SerializeToString())
+        except FileNotFoundError:
+            raise FileNotFoundError(
+                f"The directory `{os.path.dirname(identity_path)}` for saving the identity does not exist"
+            )
+        os.chmod(identity_path, 0o400)
+
     @classmethod
     async def replicate(cls, daemon_listen_maddr: Multiaddr) -> "P2P":
         """
@@ -315,6 +341,7 @@ class P2P:
         handler: Callable[[TInputStream, P2PContext], TOutputStream],
         input_protobuf_type: Type[Message],
         max_prefetch: int = 5,
+        balanced: bool = False,
     ) -> None:
         """
         :param max_prefetch: Maximum number of items to prefetch from the request stream.
@@ -379,7 +406,7 @@ class P2P:
                 finally:
                     processing_task.cancel()
 
-        await self.add_binary_stream_handler(name, _handle_stream)
+        await self.add_binary_stream_handler(name, _handle_stream, balanced=balanced)
 
     async def _iterate_protobuf_stream_handler(
         self, peer_id: PeerID, name: str, requests: TInputStream, output_protobuf_type: Type[Message]
@@ -421,16 +448,19 @@ class P2P:
         *,
         stream_input: bool = False,
         stream_output: bool = False,
+        balanced: bool = False,
     ) -> None:
         """
         :param stream_input: If True, assume ``handler`` to take ``TInputStream``
                              (not just ``TInputProtobuf``) as input.
         :param stream_output: If True, assume ``handler`` to return ``TOutputStream``
                               (not ``Awaitable[TOutputProtobuf]``).
+        :param balanced: If True, handler will be balanced on p2pd side between all handlers in python.
+                         Default: False
         """
 
         if not stream_input and not stream_output:
-            await self._add_protobuf_unary_handler(name, handler, input_protobuf_type)
+            await self._add_protobuf_unary_handler(name, handler, input_protobuf_type, balanced=balanced)
             return
 
         async def _stream_handler(requests: P2P.TInputStream, context: P2PContext) -> P2P.TOutputStream:
@@ -443,13 +473,14 @@ class P2P:
             else:
                 yield await output
 
-        await self._add_protobuf_stream_handler(name, _stream_handler, input_protobuf_type)
+        await self._add_protobuf_stream_handler(name, _stream_handler, input_protobuf_type, balanced=balanced)
 
     async def _add_protobuf_unary_handler(
         self,
         handle_name: str,
         handler: Callable[[TInputProtobuf, P2PContext], Awaitable[TOutputProtobuf]],
         input_protobuf_type: Type[Message],
+        balanced: bool = False,
     ) -> None:
         """
         Register a request-response (unary) handler. Unary requests and responses
@@ -471,7 +502,7 @@ class P2P:
             response = await handler(input_serialized, context)
             return response.SerializeToString()
 
-        await self._client.add_unary_handler(handle_name, _unary_handler)
+        await self._client.add_unary_handler(handle_name, _unary_handler, balanced=balanced)
 
     async def call_protobuf_handler(
         self,
@@ -515,10 +546,12 @@ class P2P:
 
         self._listen_task = asyncio.create_task(listen())
 
-    async def add_binary_stream_handler(self, name: str, handler: p2pclient.StreamHandler) -> None:
+    async def add_binary_stream_handler(
+        self, name: str, handler: p2pclient.StreamHandler, balanced: bool = False
+    ) -> None:
         if self._listen_task is None:
             self._start_listening()
-        await self._client.stream_handler(name, handler)
+        await self._client.stream_handler(name, handler, balanced)
 
     async def call_binary_stream_handler(
         self, peer_id: PeerID, handler_name: str

+ 6 - 4
hivemind/p2p/p2p_daemon_bindings/control.py

@@ -246,10 +246,10 @@ class ControlClient:
         self._read_task = asyncio.create_task(self._read_from_persistent_conn(reader))
         self._write_task = asyncio.create_task(self._write_to_persistent_conn(writer))
 
-    async def add_unary_handler(self, proto: str, handler: TUnaryHandler):
+    async def add_unary_handler(self, proto: str, handler: TUnaryHandler, balanced: bool = False):
         call_id = uuid4()
 
-        add_unary_handler_req = p2pd_pb.AddUnaryHandlerRequest(proto=proto)
+        add_unary_handler_req = p2pd_pb.AddUnaryHandlerRequest(proto=proto, balanced=balanced)
         req = p2pd_pb.PersistentConnectionRequest(callId=call_id.bytes, addUnaryHandler=add_unary_handler_req)
 
         if self.unary_handlers.get(proto):
@@ -358,11 +358,13 @@ class ControlClient:
 
         return stream_info, reader, writer
 
-    async def stream_handler(self, proto: str, handler_cb: StreamHandler) -> None:
+    async def stream_handler(self, proto: str, handler_cb: StreamHandler, balanced: bool = False) -> None:
         reader, writer = await self.daemon_connector.open_connection()
 
         listen_path_maddr_bytes = self.listen_maddr.to_bytes()
-        stream_handler_req = p2pd_pb.StreamHandlerRequest(addr=listen_path_maddr_bytes, proto=[proto])
+        stream_handler_req = p2pd_pb.StreamHandlerRequest(
+            addr=listen_path_maddr_bytes, proto=[proto], balanced=balanced
+        )
         req = p2pd_pb.Request(type=p2pd_pb.Request.STREAM_HANDLER, streamHandler=stream_handler_req)
         await write_pbmsg(writer, req)
 

+ 5 - 4
hivemind/p2p/p2p_daemon_bindings/p2pclient.py

@@ -61,8 +61,8 @@ class Client:
         async with self.control.listen():
             yield self
 
-    async def add_unary_handler(self, proto: str, handler: TUnaryHandler):
-        await self.control.add_unary_handler(proto, handler)
+    async def add_unary_handler(self, proto: str, handler: TUnaryHandler, balanced: bool = False):
+        await self.control.add_unary_handler(proto, handler, balanced=balanced)
 
     async def call_unary_handler(self, peer_id: PeerID, proto: str, data: bytes) -> bytes:
         return await self.control.call_unary_handler(peer_id, proto, data)
@@ -105,11 +105,12 @@ class Client:
         """
         return await self.control.stream_open(peer_id=peer_id, protocols=protocols)
 
-    async def stream_handler(self, proto: str, handler_cb: StreamHandler) -> None:
+    async def stream_handler(self, proto: str, handler_cb: StreamHandler, balanced: bool = False) -> None:
         """
         Register a stream handler
         :param proto: protocols that handler serves
         :param handler_cb: handler callback
+        :param balanced: flag if stream handler should be balanced on p2pd side. Default: False.
         :return:
         """
-        await self.control.stream_handler(proto=proto, handler_cb=handler_cb)
+        await self.control.stream_handler(proto=proto, handler_cb=handler_cb, balanced=balanced)

+ 4 - 2
hivemind/p2p/servicer.py

@@ -104,11 +104,12 @@ class ServicerBase:
         caller.__name__ = handler.method_name
         return caller
 
-    async def add_p2p_handlers(self, p2p: P2P, wrapper: Any = None, *, namespace: Optional[str] = None) -> None:
+    async def add_p2p_handlers(
+        self, p2p: P2P, wrapper: Any = None, *, namespace: Optional[str] = None, balanced: bool = False
+    ) -> None:
         self._collect_rpc_handlers()
 
         servicer = self if wrapper is None else wrapper
-
         await asyncio.gather(
             *[
                 p2p.add_protobuf_handler(
@@ -117,6 +118,7 @@ class ServicerBase:
                     handler.request_type,
                     stream_input=handler.stream_input,
                     stream_output=handler.stream_output,
+                    balanced=balanced,
                 )
                 for handler in self._rpc_handlers
             ]

+ 24 - 0
hivemind/proto/crypto.proto

@@ -0,0 +1,24 @@
+// Originally taken from: https://github.com/mhchia/py-libp2p-daemon-bindings
+// Licence: MIT
+// Author: Kevin Mai-Husan Chia
+
+syntax = "proto2";
+
+package crypto.pb;
+
+enum KeyType {
+  RSA = 0;
+  Ed25519 = 1;
+  Secp256k1 = 2;
+  ECDSA = 3;
+}
+
+message PublicKey {
+  required KeyType key_type = 1;
+  required bytes data = 2;
+}
+
+message PrivateKey {
+  required KeyType key_type = 1;
+  required bytes data = 2;
+}

+ 0 - 11
hivemind/proto/dht.proto

@@ -4,17 +4,6 @@ import "auth.proto";
 // this protocol defines how Hivemind nodes form a distributed hash table.
 // For more info, see https://learning-at-home.readthedocs.io/en/latest/modules/dht.html or help(hivemind.dht.DHTNode)
 
-service DHT {
-  // find out recipient's DHTID and possibly update its routing table
-  rpc rpc_ping(PingRequest) returns (PingResponse);
-
-  // request a node to store one or multiple data items (key - value - expiration)
-  rpc rpc_store(StoreRequest) returns (StoreResponse);
-
-  // for given keys, request values (if stored) or a list of peers that are likely to have them
-  rpc rpc_find(FindRequest) returns (FindResponse);
-}
-
 message NodeInfo {
   // note: both node_id and port are optional: if specified, ask peer to add you to its routing table;
   // if either node_id or port is absent, simply request recipient info (for client-only mode)

+ 6 - 4
hivemind/proto/p2pd.proto

@@ -1,6 +1,6 @@
-//Originally taken from: https://github.com/mhchia/py-libp2p-daemon-bindings
-//Licence: MIT
-//Author: Kevin Mai-Husan Chia
+// Originally taken from: https://github.com/mhchia/py-libp2p-daemon-bindings
+// Licence: MIT
+// Author: Kevin Mai-Husan Chia
 
 syntax = "proto2";
 
@@ -15,7 +15,7 @@ message Request {
     DHT                      = 4;
     LIST_PEERS               = 5;
     CONNMANAGER              = 6;
-    DISCONNECT               = 7;      
+    DISCONNECT               = 7;
     PUBSUB                   = 8;
 
     PERSISTENT_CONN_UPGRADE  = 9;
@@ -90,6 +90,7 @@ message StreamOpenRequest {
 message StreamHandlerRequest {
   required bytes addr = 1;
   repeated string proto = 2;
+  required bool balanced = 3;
 }
 
 message ErrorResponse {
@@ -201,6 +202,7 @@ message CallUnaryResponse {
 
 message AddUnaryHandlerRequest {
   required string proto = 1;
+  required bool balanced = 2;
 }
 
 message DaemonError {

+ 0 - 8
hivemind/proto/runtime.proto

@@ -1,14 +1,6 @@
 syntax = "proto3";
 
 
-service ConnectionHandler {
-  // Listens to incoming requests for expert computation
-  rpc info(ExpertUID) returns (ExpertInfo);
-  rpc forward(ExpertRequest) returns (ExpertResponse);
-  rpc backward(ExpertRequest) returns (ExpertResponse);
-}
-
-
 message ExpertUID {
   string uid = 1;
 }

+ 2 - 2
hivemind/utils/__init__.py

@@ -1,11 +1,11 @@
 from hivemind.utils.asyncio import *
-from hivemind.utils.grpc import *
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from hivemind.utils.mpfuture import *
 from hivemind.utils.nested import *
-from hivemind.utils.networking import *
+from hivemind.utils.networking import log_visible_maddrs
 from hivemind.utils.performance_ema import PerformanceEMA
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
+from hivemind.utils.streaming import combine_from_streaming, split_for_streaming
 from hivemind.utils.tensor_descr import BatchTensorDescriptor, TensorDescriptor
 from hivemind.utils.timed_storage import *

+ 7 - 8
hivemind/utils/asyncio.py

@@ -2,7 +2,7 @@ import asyncio
 import concurrent.futures
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import AbstractAsyncContextManager, AbstractContextManager, asynccontextmanager
-from typing import AsyncIterable, AsyncIterator, Awaitable, Callable, ContextManager, Optional, Tuple, TypeVar, Union
+from typing import AsyncIterable, AsyncIterator, Awaitable, Callable, Iterable, Optional, Tuple, TypeVar, Union
 
 import uvloop
 
@@ -29,6 +29,12 @@ async def anext(aiter: AsyncIterator[T]) -> Union[T, StopAsyncIteration]:
     return await aiter.__anext__()
 
 
+async def iter_as_aiter(iterable: Iterable[T]) -> AsyncIterator[T]:
+    """create an asynchronous iterator from single iterable"""
+    for elem in iterable:
+        yield elem
+
+
 async def as_aiter(*args: T) -> AsyncIterator[T]:
     """create an asynchronous iterator from a sequence of values"""
     for arg in args:
@@ -72,13 +78,6 @@ async def asingle(aiter: AsyncIterable[T]) -> T:
     return item
 
 
-async def afirst(aiter: AsyncIterable[T], default: Optional[T] = None) -> Optional[T]:
-    """Returns the first item of ``aiter`` or ``default`` if ``aiter`` is empty."""
-    async for item in aiter:
-        return item
-    return default
-
-
 async def await_cancelled(awaitable: Awaitable) -> bool:
     try:
         await awaitable

+ 9 - 6
hivemind/utils/crypto.py

@@ -60,19 +60,22 @@ class RSAPrivateKey(PrivateKey):
     def get_public_key(self) -> RSAPublicKey:
         return RSAPublicKey(self._private_key.public_key())
 
+    def to_bytes(self) -> bytes:
+        return self._private_key.private_bytes(
+            encoding=serialization.Encoding.DER,
+            format=serialization.PrivateFormat.TraditionalOpenSSL,
+            encryption_algorithm=serialization.NoEncryption(),
+        )
+
     def __getstate__(self):
         state = self.__dict__.copy()
         # Serializes the private key to make the class instances picklable
-        state["_private_key"] = self._private_key.private_bytes(
-            encoding=serialization.Encoding.PEM,
-            format=serialization.PrivateFormat.OpenSSH,
-            encryption_algorithm=serialization.NoEncryption(),
-        )
+        state["_private_key"] = self.to_bytes()
         return state
 
     def __setstate__(self, state):
         self.__dict__.update(state)
-        self._private_key = serialization.load_ssh_private_key(self._private_key, password=None)
+        self._private_key = serialization.load_der_private_key(self._private_key, password=None)
 
 
 class RSAPublicKey(PublicKey):

+ 0 - 210
hivemind/utils/grpc.py

@@ -1,210 +0,0 @@
-"""
-Utilities for running GRPC services: compile protobuf, patch legacy versions, etc
-"""
-
-from __future__ import annotations
-
-import os
-import threading
-from typing import Any, Dict, Iterable, Iterator, NamedTuple, Optional, Tuple, Type, TypeVar, Union
-
-import grpc
-
-from hivemind.proto import runtime_pb2
-from hivemind.utils.logging import get_logger
-from hivemind.utils.networking import Endpoint
-from hivemind.utils.timed_storage import TimedStorage, ValueWithExpiration, get_dht_time
-
-logger = get_logger(__name__)
-
-Stub = TypeVar("Stub")
-
-GRPC_KEEPALIVE_OPTIONS = (
-    ("grpc.keepalive_time_ms", 60 * 1000),
-    ("grpc.keepalive_timeout_ms", 60 * 1000),
-    ("grpc.keepalive_permit_without_calls", True),
-    ("grpc.http2.max_pings_without_data", 0),
-    ("grpc.http2.min_time_between_pings_ms", 30 * 1000),
-    ("grpc.http2.min_ping_interval_without_data_ms", 10 * 1000),
-)
-
-
-class ChannelInfo(NamedTuple):
-    target: Endpoint
-    aio: bool
-    options: Tuple[Tuple[str, str], ...]
-    credentials: Optional[grpc.ChannelCredentials]
-    compression: Optional[grpc.Compression]
-
-
-class ChannelCache(TimedStorage[ChannelInfo, Tuple[Union[grpc.Channel, grpc.aio.Channel], Dict]]):
-    """
-    A process-wide cache of gRPC channels, supports both normal and aio channels, secure/insecure channels, etc
-    Based on grpcio internal channel cache by Richard Belleville and Lidi Zheng (thanks!)
-    Unlike TimedStorage, ChannelCache actively evicts stale channels even if the cache is not accessed
-    Unlike grpc._simple_stubs.ChannelCache, this implementation supports aio and does not forcibly close active channels
-    """
-
-    MAXIMUM_CHANNELS = int(os.environ.get("GRPC_PYTHON_MANAGED_CHANNEL_MAXIMUM", 4096))
-    EVICTION_PERIOD_SECONDS = float(os.environ.get("GRPC_PYTHON_MANAGED_CHANNEL_EVICTION_SECONDS", 10 * 60))
-    logger.debug(f"Eviction period = {EVICTION_PERIOD_SECONDS}s, max channels = {MAXIMUM_CHANNELS}")
-
-    _singleton: Optional[ChannelCache] = None
-    _singleton_pid: int = os.getpid()
-    _lock: threading.RLock = threading.RLock()
-    _update_eviction_evt: threading.Event = threading.Event()
-
-    def __init__(self, _created_as_singleton=False):
-        assert _created_as_singleton, f"Please use {self.__class__.__name__}.get_singleton()"
-        super().__init__(maxsize=self.MAXIMUM_CHANNELS)
-        self._is_active = True
-        self._nearest_expiration_time = float("inf")
-        self._eviction_thread = threading.Thread(target=self._evict_stale_channels_in_background, daemon=True)
-        self._eviction_thread.start()
-
-    @classmethod
-    def get_singleton(cls):
-        """Get or create the channel cache for the current process"""
-        with cls._lock:
-            if cls._singleton is None or cls._singleton_pid != os.getpid():
-                if cls._singleton is not None:
-                    cls._singleton._stop_background_thread()
-                cls._singleton, cls._singleton_pid = cls(_created_as_singleton=True), os.getpid()
-            return cls._singleton
-
-    @classmethod
-    def get_stub(
-        cls,
-        target: Endpoint,
-        stub_type: Type[Stub],
-        *,
-        aio: bool,
-        options: Tuple[Tuple[str, Any]] = (),
-        channel_credentials: Optional[grpc.ChannelCredentials] = None,
-        compression: Optional[grpc.Compression] = None,
-    ) -> Stub:
-        """
-        Create a grpc channel with given options or reuse pre-existing one
-
-        :param target: the recipient's address and port
-        :param stub_type: a gRPC stub (client) to be instantiated
-        :param aio: if True, returns grpc.Channel, otherwise returns grpc.aio.Channel
-        :param options: see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html
-        :param channel_credentials: if specified, create a secure channel usin these credentials (default = insecure)
-        :param compression: see https://github.com/grpc/grpc/tree/master/examples/python/compression
-        """
-        cache = cls.get_singleton()
-        with cls._lock:
-            key = ChannelInfo(target, aio, tuple(options), channel_credentials, compression)
-            entry: ValueWithExpiration = super(cls, cache).get(key)
-
-            if entry is not None:
-                channel, stubs = entry.value
-            else:
-                channel = cls._create_channel(*key)
-                stubs = {}
-
-            channel._channel.check_connectivity_state(True)
-
-            if stub_type not in stubs:
-                stubs[stub_type] = stub_type(channel)
-
-            # either cache channel or update expiration of an existing channel
-            expiration_time = get_dht_time() + cls.EVICTION_PERIOD_SECONDS
-            super(cls, cache).store(key, (channel, stubs), expiration_time)
-
-            if expiration_time < cache._nearest_expiration_time:
-                cache._nearest_expiration_time = expiration_time
-                cls._update_eviction_evt.set()
-
-            return stubs[stub_type]
-
-    @classmethod
-    def _create_channel(
-        cls,
-        target: Endpoint,
-        aio: bool,
-        extra_options: Tuple[Tuple[str, Any], ...],
-        channel_credentials: Optional[grpc.ChannelCredentials],
-        compression: Optional[grpc.Compression],
-    ) -> Union[grpc.Channel, grpc.aio.Channel]:
-        namespace = grpc.aio if aio else grpc
-
-        options = extra_options + GRPC_KEEPALIVE_OPTIONS
-
-        if channel_credentials is None:
-            logger.debug(
-                f"Creating insecure {namespace} channel with options '{options}' " f"and compression '{compression}'"
-            )
-            return namespace.insecure_channel(target, options=options, compression=compression)
-        else:
-            logger.debug(
-                f"Creating secure {namespace} channel with credentials '{channel_credentials}', "
-                f"options '{options}' and compression '{compression}'"
-            )
-            return namespace.secure_channel(
-                target, credentials=channel_credentials, options=options, compression=compression
-            )
-
-    def _evict_stale_channels_in_background(self):
-        while self._is_active:
-            now = get_dht_time()
-            time_to_wait = max(0.0, self._nearest_expiration_time - now)
-            interrupted_early = self._update_eviction_evt.wait(time_to_wait if time_to_wait != float("inf") else None)
-            if interrupted_early:
-                self._update_eviction_evt.clear()
-                continue
-
-            with self._lock:
-                self._remove_outdated()
-                _, entry = super().top()
-                self._nearest_expiration_time = entry.expiration_time if entry is not None else float("inf")
-
-    def _stop_background_thread(self):
-        with self._lock:
-            self._is_active = False
-            self._update_eviction_evt.set()
-
-    def store(self, *args, **kwargs) -> ValueError:
-        raise ValueError(f"Please use {self.__class__.__name__}.get_stub to get or create stubs")
-
-    def get(self, *args, **kwargs) -> ValueError:
-        raise ValueError(f"Please use {self.__class__.__name__}.get_stub to get or create stubs")
-
-    def top(self) -> ValueError:
-        raise ValueError(f"Please use {self.__class__.__name__}.get_stub to get or create stubs")
-
-
-STREAMING_CHUNK_SIZE_BYTES = 2**16
-
-
-def split_for_streaming(
-    serialized_tensor: runtime_pb2.Tensor,
-    chunk_size_bytes: int = STREAMING_CHUNK_SIZE_BYTES,
-) -> Iterator[runtime_pb2.Tensor]:
-    """Split serialized_tensor into multiple chunks for gRPC streaming"""
-    buffer = memoryview(serialized_tensor.buffer)
-    num_chunks = len(range(0, len(buffer), chunk_size_bytes))
-    yield runtime_pb2.Tensor(
-        compression=serialized_tensor.compression,
-        buffer=buffer[:chunk_size_bytes].tobytes(),
-        chunks=num_chunks,
-        size=serialized_tensor.size,
-        dtype=serialized_tensor.dtype,
-        requires_grad=serialized_tensor.requires_grad,
-    )
-    for chunk_start in range(chunk_size_bytes, len(buffer), chunk_size_bytes):
-        yield runtime_pb2.Tensor(buffer=buffer[chunk_start : chunk_start + chunk_size_bytes].tobytes())
-
-
-def combine_from_streaming(stream: Iterable[runtime_pb2.Tensor]) -> runtime_pb2.Tensor:
-    """Restore a result of split_into_chunks into a single serialized tensor"""
-    stream = iter(stream)
-    first_chunk = next(stream)
-    serialized_tensor = runtime_pb2.Tensor()
-    serialized_tensor.CopyFrom(first_chunk)
-    buffer_chunks = [first_chunk.buffer]
-    for tensor_part in stream:
-        buffer_chunks.append(tensor_part.buffer)
-    serialized_tensor.buffer = b"".join(buffer_chunks)
-    return serialized_tensor

+ 24 - 0
hivemind/utils/math.py

@@ -0,0 +1,24 @@
+import torch
+import torch.nn.functional as F
+
+
+@torch.jit.script
+def orthogonalize_(matrix, eps: float = 1e-8):
+    """Orthogonalize a 2d tensor in-place over the last dimension"""
+    n, m = matrix.shape
+    for i in range(m):
+        col = matrix[:, i]
+        F.normalize(col, dim=0, eps=eps, out=col)
+        if i + 1 < m:
+            rest = matrix[:, i + 1 :]
+            rest.addmm_(col[:, None], (col @ rest)[None, :], alpha=-1)
+
+
+def get_flatten_greedy_dims(tensor: torch.Tensor, max_ndim: int = 2):
+    """get dims to flatten tensor upto max_ndim dimensions by merging small axes together"""
+    dims = list(tensor.shape)
+    while len(dims) > max_ndim:
+        squeeze_ix = min(range(len(dims) - 1), key=lambda i: dims[i] * dims[i + 1])
+        squeezed_dim = dims.pop(squeeze_ix)
+        dims[squeeze_ix] *= squeezed_dim
+    return dims

+ 1 - 1
hivemind/utils/mpfuture.py

@@ -9,7 +9,7 @@ import threading
 import uuid
 from contextlib import nullcontext
 from enum import Enum, auto
-from typing import Any, Callable, Dict, Generic, Optional, Type, TypeVar
+from typing import Any, Callable, Dict, Generic, Optional, TypeVar
 from weakref import ref
 
 import torch  # used for py3.7-compatible shared memory

+ 25 - 41
hivemind/utils/networking.py

@@ -1,54 +1,18 @@
-import socket
-from contextlib import closing
 from ipaddress import ip_address
-from typing import Optional, Sequence
+from typing import List, Sequence
 
 from multiaddr import Multiaddr
 
-Hostname, Port = str, int  # flavour types
-Endpoint = str  # e.g. 1.2.3.4:1337 or [2a21:6с8:b192:2105]:8888, https://networkengineering.stackexchange.com/a/9435
-LOCALHOST = "127.0.0.1"
-
-
-def get_port(endpoint: Endpoint) -> Optional[Port]:
-    """get port or None if port is undefined"""
-    # TODO: find a standard way to get port, make sure it works in malformed ports
-    try:
-        return int(endpoint[endpoint.rindex(":") + 1 :], base=10)
-    except ValueError:  # :* or not specified
-        return None
-
-
-def replace_port(endpoint: Endpoint, new_port: Port) -> Endpoint:
-    assert endpoint.endswith(":*") or get_port(endpoint) is not None, endpoint
-    return f"{endpoint[:endpoint.rindex(':')]}:{new_port}"
-
-
-def strip_port(endpoint: Endpoint) -> Hostname:
-    """Removes port from the end of endpoint. If port is not specified, does nothing"""
-    maybe_port = endpoint[endpoint.rindex(":") + 1 :]
-    return endpoint[: endpoint.rindex(":")] if maybe_port.isdigit() or maybe_port == "*" else endpoint
+from hivemind.utils.logging import TextStyle, get_logger
 
+LOCALHOST = "127.0.0.1"
 
-def get_free_port(params=(socket.AF_INET, socket.SOCK_STREAM), opt=(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)):
-    """
-    Finds a tcp port that can be occupied with a socket with *params and use *opt options.
-
-    :note: Using this function is discouraged since it often leads to a race condition
-           with the "Address is already in use" error if the code is run in parallel.
-    """
-    try:
-        with closing(socket.socket(*params)) as sock:
-            sock.bind(("", 0))
-            sock.setsockopt(*opt)
-            return sock.getsockname()[1]
-    except Exception as e:
-        raise e
+logger = get_logger(__name__)
 
 
 def choose_ip_address(
     maddrs: Sequence[Multiaddr], prefer_global: bool = True, protocol_priority: Sequence[str] = ("ip4", "ip6")
-) -> Hostname:
+) -> str:
     """
     Currently, some components of hivemind are not converted to work over libp2p and use classical networking.
     To allow other peers reach a server when needed, these components announce a machine's IP address.
@@ -74,3 +38,23 @@ def choose_ip_address(
                         return value_for_protocol
 
     raise ValueError(f"No IP address found among given multiaddrs: {maddrs}")
+
+
+def log_visible_maddrs(visible_maddrs: List[Multiaddr], only_p2p: bool) -> None:
+    if only_p2p:
+        unique_addrs = {addr["p2p"] for addr in visible_maddrs}
+        initial_peers = " ".join(f"/p2p/{addr}" for addr in unique_addrs)
+    else:
+        available_ips = [Multiaddr(addr) for addr in visible_maddrs if "ip4" in addr or "ip6" in addr]
+        if available_ips:
+            preferred_ip = choose_ip_address(available_ips)
+            selected_maddrs = [addr for addr in visible_maddrs if preferred_ip in str(addr)]
+        else:
+            selected_maddrs = visible_maddrs
+        initial_peers = " ".join(str(addr) for addr in selected_maddrs)
+
+    logger.info(
+        f"Running a DHT instance. To connect other peers to this one, use "
+        f"{TextStyle.BOLD}{TextStyle.BLUE}--initial_peers {initial_peers}{TextStyle.RESET}"
+    )
+    logger.info(f"Full list of visible multiaddresses: {' '.join(str(addr) for addr in visible_maddrs)}")

+ 46 - 0
hivemind/utils/streaming.py

@@ -0,0 +1,46 @@
+"""
+Utilities for streaming tensors
+"""
+
+from __future__ import annotations
+
+from typing import Iterable, Iterator
+
+from hivemind.proto import runtime_pb2
+from hivemind.utils.logging import get_logger
+
+logger = get_logger(__name__)
+
+STREAMING_CHUNK_SIZE_BYTES = 2**16
+
+
+def split_for_streaming(
+    serialized_tensor: runtime_pb2.Tensor,
+    chunk_size_bytes: int = STREAMING_CHUNK_SIZE_BYTES,
+) -> Iterator[runtime_pb2.Tensor]:
+    """Split serialized_tensor into multiple chunks for streaming"""
+    buffer = memoryview(serialized_tensor.buffer)
+    num_chunks = len(range(0, len(buffer), chunk_size_bytes))
+    yield runtime_pb2.Tensor(
+        compression=serialized_tensor.compression,
+        buffer=buffer[:chunk_size_bytes].tobytes(),
+        chunks=num_chunks,
+        size=serialized_tensor.size,
+        dtype=serialized_tensor.dtype,
+        requires_grad=serialized_tensor.requires_grad,
+    )
+    for chunk_start in range(chunk_size_bytes, len(buffer), chunk_size_bytes):
+        yield runtime_pb2.Tensor(buffer=buffer[chunk_start : chunk_start + chunk_size_bytes].tobytes())
+
+
+def combine_from_streaming(stream: Iterable[runtime_pb2.Tensor]) -> runtime_pb2.Tensor:
+    """Restore a result of split_into_chunks into a single serialized tensor"""
+    stream = iter(stream)
+    first_chunk = next(stream)
+    serialized_tensor = runtime_pb2.Tensor()
+    serialized_tensor.CopyFrom(first_chunk)
+    buffer_chunks = [first_chunk.buffer]
+    for tensor_part in stream:
+        buffer_chunks.append(tensor_part.buffer)
+    serialized_tensor.buffer = b"".join(buffer_chunks)
+    return serialized_tensor

+ 1 - 1
pyproject.toml

@@ -1,6 +1,6 @@
 [tool.black]
 line-length = 119
-required-version = "22.1.0"
+required-version = "22.3.0"
 
 [tool.isort]
 profile = "black"

+ 1 - 1
requirements-dev.txt

@@ -6,6 +6,6 @@ coverage==6.0.2  # see https://github.com/pytest-dev/pytest-cov/issues/520
 tqdm
 scikit-learn
 torchvision
-black==22.1.0
+black==22.3.0
 isort==5.10.1
 psutil

+ 0 - 1
requirements.txt

@@ -6,7 +6,6 @@ prefetch_generator>=1.0.1
 msgpack>=0.5.6
 sortedcontainers
 uvloop>=0.14.0
-grpcio>=1.33.2
 grpcio-tools>=1.33.2
 protobuf>=3.12.2
 configargparse>=1.2.3

+ 38 - 30
setup.py

@@ -3,7 +3,6 @@ import glob
 import hashlib
 import os
 import re
-import shlex
 import subprocess
 import tarfile
 import tempfile
@@ -14,20 +13,25 @@ from setuptools import find_packages, setup
 from setuptools.command.build_py import build_py
 from setuptools.command.develop import develop
 
-P2PD_VERSION = "v0.3.6"
-P2PD_CHECKSUM = "627d0c3b475a29331fdfd1667e828f6d"
-LIBP2P_TAR_URL = f"https://github.com/learning-at-home/go-libp2p-daemon/archive/refs/tags/{P2PD_VERSION}.tar.gz"
-P2PD_BINARY_URL = f"https://github.com/learning-at-home/go-libp2p-daemon/releases/download/{P2PD_VERSION}/p2pd"
+P2PD_VERSION = "v0.3.9"
+
+P2PD_SOURCE_URL = f"https://github.com/learning-at-home/go-libp2p-daemon/archive/refs/tags/{P2PD_VERSION}.tar.gz"
+P2PD_BINARY_URL = f"https://github.com/learning-at-home/go-libp2p-daemon/releases/download/{P2PD_VERSION}/"
+
+# The value is sha256 of the binary from the release page
+EXECUTABLES = {
+    "p2pd": "8f9434f4717f6e851430f75f07e283d5ddeb2c7cde1b3648e677d813703f4e40",
+}
+
 
 here = os.path.abspath(os.path.dirname(__file__))
 
 
-def md5(fname, chunk_size=4096):
-    hash_md5 = hashlib.md5()
-    with open(fname, "rb") as f:
-        for chunk in iter(lambda: f.read(chunk_size), b""):
-            hash_md5.update(chunk)
-    return hash_md5.hexdigest()
+def sha256(path):
+    if not os.path.exists(path):
+        return None
+    with open(path, "rb") as f:
+        return hashlib.sha256(f.read()).hexdigest()
 
 
 def proto_compile(output_path):
@@ -37,7 +41,6 @@ def proto_compile(output_path):
         "grpc_tools.protoc",
         "--proto_path=hivemind/proto",
         f"--python_out={output_path}",
-        f"--grpc_python_out={output_path}",
     ] + glob.glob("hivemind/proto/*.proto")
 
     code = grpc_tools.protoc.main(cli_args)
@@ -64,32 +67,36 @@ def build_p2p_daemon():
 
     with tempfile.TemporaryDirectory() as tempdir:
         dest = os.path.join(tempdir, "libp2p-daemon.tar.gz")
-        urllib.request.urlretrieve(LIBP2P_TAR_URL, dest)
+        urllib.request.urlretrieve(P2PD_SOURCE_URL, dest)
 
         with tarfile.open(dest, "r:gz") as tar:
             tar.extractall(tempdir)
 
-        result = subprocess.run(
-            f'go build -o {shlex.quote(os.path.join(here, "hivemind", "hivemind_cli", "p2pd"))}',
-            cwd=os.path.join(tempdir, f"go-libp2p-daemon-{P2PD_VERSION[1:]}", "p2pd"),
-            shell=True,
-        )
-
-        if result.returncode:
-            raise RuntimeError(
-                "Failed to build or install libp2p-daemon:" f" exited with status code: {result.returncode}"
+        for executable in EXECUTABLES:
+            result = subprocess.run(
+                ["go", "build", "-o", os.path.join(here, "hivemind", "hivemind_cli", executable)],
+                cwd=os.path.join(tempdir, f"go-libp2p-daemon-{P2PD_VERSION.lstrip('v')}", executable),
             )
+            if result.returncode != 0:
+                raise RuntimeError(f"Failed to build {executable}: exited with status code: {result.returncode}")
 
 
 def download_p2p_daemon():
-    install_path = os.path.join(here, "hivemind", "hivemind_cli")
-    binary_path = os.path.join(install_path, "p2pd")
-    if not os.path.exists(binary_path) or md5(binary_path) != P2PD_CHECKSUM:
-        print("Downloading Peer to Peer Daemon")
-        urllib.request.urlretrieve(P2PD_BINARY_URL, binary_path)
-        os.chmod(binary_path, 0o777)
-        if md5(binary_path) != P2PD_CHECKSUM:
-            raise RuntimeError(f"Downloaded p2pd binary from {P2PD_BINARY_URL} does not match with md5 checksum")
+    for executable, expected_hash in EXECUTABLES.items():
+        binary_path = os.path.join(here, "hivemind", "hivemind_cli", executable)
+
+        if sha256(binary_path) != expected_hash:
+            binary_url = os.path.join(P2PD_BINARY_URL, executable)
+            print(f"Downloading {binary_url}")
+
+            urllib.request.urlretrieve(binary_url, binary_path)
+            os.chmod(binary_path, 0o777)
+
+            actual_hash = sha256(binary_path)
+            if actual_hash != expected_hash:
+                raise RuntimeError(
+                    f"The sha256 checksum for {executable} does not match (expected: {expected_hash}, actual: {actual_hash})"
+                )
 
 
 class BuildPy(build_py):
@@ -170,6 +177,7 @@ setup(
     ],
     entry_points={
         "console_scripts": [
+            "hivemind-dht = hivemind.hivemind_cli.run_dht:main",
             "hivemind-server = hivemind.hivemind_cli.run_server:main",
         ]
     },

+ 1 - 2
tests/test_allreduce.py

@@ -10,8 +10,7 @@ from hivemind import Quantile8BitQuantization, aenumerate
 from hivemind.averaging.allreduce import AllReduceRunner, AveragingMode
 from hivemind.averaging.partition import TensorPartContainer, TensorPartReducer
 from hivemind.compression import deserialize_torch_tensor
-from hivemind.p2p import P2P, StubBase
-from hivemind.proto.runtime_pb2 import CompressionType
+from hivemind.p2p import P2P
 
 
 @pytest.mark.forked

+ 14 - 16
tests/test_allreduce_fault_tolerance.py

@@ -1,14 +1,10 @@
 from __future__ import annotations
 
-import asyncio
 from enum import Enum, auto
-from typing import AsyncIterator
 
 import pytest
-import torch
 
 import hivemind
-from hivemind.averaging.allreduce import AllReduceRunner, AveragingMode
 from hivemind.averaging.averager import *
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.load_balancing import load_balance_peers
@@ -35,7 +31,7 @@ class FaultyAverager(hivemind.DecentralizedAverager):
         self.fault = fault
         super().__init__(*args, **kwargs)
 
-    async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
+    async def _aggregate_with_group(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
         """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
         try:
             bandwidths, mode_ids, user_gathered_bytes = zip(*map(self.serializer.loads, group_info.gathered))
@@ -60,24 +56,26 @@ class FaultyAverager(hivemind.DecentralizedAverager):
                     tensors=local_tensors,
                     ordered_peer_ids=group_info.peer_ids,
                     peer_fractions=peer_fractions,
-                    gathered=user_gathered,
                     modes=modes,
                     fault=self.fault,
                     **kwargs,
                 )
 
-                with self.register_allreduce_group(group_info.group_id, allreduce):
-                    if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
-                        async for tensor, update in azip(as_aiter(*local_tensors), allreduce):
-                            # all-reduce is performed asynchronously while iterating
-                            tensor.add_(update, alpha=self._averaging_alpha)
-                        self._state_updated.set()
+                self._running_groups[group_info.group_id].set_result(allreduce)
+                # TODO maybe this can be extracted into a method that checks if register_... context is active.
 
-                    else:
-                        async for _ in allreduce:  # trigger all-reduce by iterating
-                            raise ValueError("aux peers should not receive averaged tensors")
+                if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
+                    iter_results = allreduce.run()
+                    async for tensor, update in azip(as_aiter(*local_tensors), iter_results):
+                        # all-reduce is performed asynchronously while iterating
+                        tensor.add_(update, alpha=self._averaging_alpha)
+                    self._state_updated.set()
+
+                else:
+                    async for _ in allreduce:  # trigger all-reduce by iterating
+                        raise ValueError("aux peers should not receive averaged tensors")
 
-                return allreduce.gathered
+                return user_gathered
         except BaseException as e:
             logger.exception(e)
             raise MatchmakingException(f"Unable to run All-Reduce: {e}")

+ 16 - 80
tests/test_averaging.py

@@ -6,7 +6,7 @@ import pytest
 import torch
 
 import hivemind
-import hivemind.averaging.averager
+from hivemind.averaging import DecentralizedAverager
 from hivemind.averaging.allreduce import AveragingMode
 from hivemind.averaging.control import AveragingStage
 from hivemind.averaging.key_manager import GroupKeyManager
@@ -78,11 +78,11 @@ def _test_allreduce_once(n_clients, n_aux):
 
     dht_instances = launch_dht_instances(len(peer_tensors))
     averagers = [
-        hivemind.averaging.DecentralizedAverager(
+        DecentralizedAverager(
             tensors,
             dht=dht,
             target_group_size=4,
-            averaging_expiration=15,
+            min_matchmaking_time=15,
             prefix="mygroup",
             client_mode=mode == AveragingMode.CLIENT,
             auxiliary=mode == AveragingMode.AUX,
@@ -135,11 +135,11 @@ def test_allreduce_weighted(n_client_mode_peers: int = 2):
 
     dht_instances = launch_dht_instances(4)
     averagers = [
-        hivemind.averaging.DecentralizedAverager(
+        DecentralizedAverager(
             tensors,
             dht=dht,
             target_group_size=4,
-            averaging_expiration=15,
+            min_matchmaking_time=15,
             prefix="mygroup",
             client_mode=client_mode,
             start=True,
@@ -185,7 +185,7 @@ def compute_mean_std(averagers, unbiased=True):
 def test_allreduce_grid():
     dht_instances = launch_dht_instances(8)
     averagers = [
-        hivemind.averaging.DecentralizedAverager(
+        DecentralizedAverager(
             averaged_tensors=[torch.randn(3)],
             dht=dht,
             target_group_size=2,
@@ -221,11 +221,11 @@ def test_allreduce_grid():
 def test_allgather(n_averagers=8, target_group_size=4):
     dht_instances = launch_dht_instances(n_averagers)
     averagers = [
-        hivemind.averaging.DecentralizedAverager(
+        DecentralizedAverager(
             [torch.ones(1)],
             dht=dht,
             target_group_size=target_group_size,
-            averaging_expiration=15,
+            min_matchmaking_time=15,
             prefix="mygroup",
             initial_group_bits="000",
             start=True,
@@ -297,11 +297,11 @@ def test_load_balancing():
 def test_too_few_peers():
     dht_instances = launch_dht_instances(4)
     averagers = [
-        hivemind.averaging.DecentralizedAverager(
+        DecentralizedAverager(
             averaged_tensors=[torch.randn(3)],
             dht=dht,
             target_group_size=2,
-            averaging_expiration=1,
+            min_matchmaking_time=1,
             request_timeout=0.5,
             prefix="mygroup",
             initial_group_bits=bin(i)[2:].rjust(3, "0"),
@@ -327,11 +327,11 @@ def test_too_few_peers():
 def test_overcrowded(num_peers=16):
     dht_instances = launch_dht_instances(num_peers)
     averagers = [
-        hivemind.averaging.DecentralizedAverager(
+        DecentralizedAverager(
             averaged_tensors=[torch.randn(3)],
             dht=dht,
             target_group_size=2,
-            averaging_expiration=1,
+            min_matchmaking_time=1,
             request_timeout=0.5,
             prefix="mygroup",
             initial_group_bits="",
@@ -353,7 +353,7 @@ def test_load_state_from_peers():
     super_metadata = dict(x=123)
     super_tensors = (torch.randn(3), torch.randint(0, 5, (3,)))
 
-    class TestAverager(hivemind.averaging.DecentralizedAverager):
+    class TestAverager(DecentralizedAverager):
         def get_current_state(self):
             """
             Get current state and send it to a peer. executed in the host process. Meant to be overriden.
@@ -455,7 +455,7 @@ def test_load_state_priority():
 @pytest.mark.forked
 def test_getset_bits():
     dht = hivemind.DHT(start=True)
-    averager = hivemind.averaging.DecentralizedAverager(
+    averager = DecentralizedAverager(
         [torch.randn(3)],
         dht=dht,
         start=True,
@@ -469,7 +469,7 @@ def test_getset_bits():
 @pytest.mark.forked
 def test_averaging_trigger():
     averagers = tuple(
-        hivemind.averaging.DecentralizedAverager(
+        DecentralizedAverager(
             averaged_tensors=[torch.randn(3)],
             dht=dht,
             min_matchmaking_time=0.5,
@@ -514,7 +514,7 @@ def test_averaging_trigger():
 @pytest.mark.forked
 def test_averaging_cancel():
     averagers = tuple(
-        hivemind.averaging.DecentralizedAverager(
+        DecentralizedAverager(
             averaged_tensors=[torch.randn(3)],
             dht=dht,
             min_matchmaking_time=0.5,
@@ -540,67 +540,3 @@ def test_averaging_cancel():
 
     for averager in averagers:
         averager.shutdown()
-
-
-@pytest.mark.forked
-def test_training_averager(n_steps: int = 10, n_dims: int = 16):
-    torch.manual_seed(42)
-
-    dht_instances = launch_dht_instances(2)
-    common_kwargs = {
-        "start": True,
-        "prefix": "demo-run",
-        "target_group_size": 2,
-    }
-
-    x1 = torch.randn(n_dims, requires_grad=True)
-    opt1 = torch.optim.Adam([x1], lr=0.05)
-    averager1 = hivemind.TrainingAverager(
-        opt1,
-        average_gradients=True,
-        average_parameters=True,
-        average_opt_statistics=["exp_avg_sq"],
-        dht=dht_instances[0],
-        **common_kwargs
-    )
-
-    x2 = torch.randn(n_dims, requires_grad=True)
-    opt2 = torch.optim.Adam([x2], lr=0.05)
-    averager2 = hivemind.TrainingAverager(
-        opt2,
-        average_gradients=True,
-        average_parameters=True,
-        average_opt_statistics=["exp_avg_sq"],
-        dht=dht_instances[1],
-        **common_kwargs
-    )
-    a = torch.ones(n_dims)
-
-    for i in range(n_steps):
-        opt1.zero_grad()
-        opt2.zero_grad()
-        (x1 - a).pow(2).sum().backward()
-        (x2 - a).pow(2).sum().backward()
-        opt1.step()
-        opt2.step()
-
-        with torch.no_grad():
-            x_avg = 0.5 * (x1 + x2)
-            grad_avg = 0.5 * (x1.grad + x2.grad)
-            stats_avg = 0.5 * (opt1.state[x1]["exp_avg_sq"] + opt2.state[x2]["exp_avg_sq"])
-
-        # we set wait=False in order to prevent deadlock, when averager1 locks and waits for averager2
-        f1 = averager1.step(wait=False)
-        f2 = averager2.step(wait=False)
-        f1.result()
-        f2.result()
-
-        assert torch.allclose(x1, x_avg)
-        assert torch.allclose(x2, x_avg)
-        assert torch.allclose(x1.grad, grad_avg)
-        assert torch.allclose(x2.grad, grad_avg)
-        assert torch.allclose(opt1.state[x1]["exp_avg_sq"], stats_avg)
-        assert torch.allclose(opt2.state[x2]["exp_avg_sq"], stats_avg)
-
-    for instance in [averager1, averager2] + dht_instances:
-        instance.shutdown()

+ 63 - 0
tests/test_cli_scripts.py

@@ -0,0 +1,63 @@
+import re
+from subprocess import PIPE, Popen
+from time import sleep
+
+DHT_START_PATTERN = re.compile(r"Running a DHT instance. To connect other peers to this one, use (.+)$")
+
+
+def test_dht_connection_successful():
+    dht_refresh_period = 1
+
+    dht_proc = Popen(
+        ["hivemind-dht", "--host_maddrs", "/ip4/127.0.0.1/tcp/0", "--refresh_period", str(dht_refresh_period)],
+        stderr=PIPE,
+        text=True,
+        encoding="utf-8",
+    )
+
+    first_line = dht_proc.stderr.readline()
+    second_line = dht_proc.stderr.readline()
+    dht_pattern_match = DHT_START_PATTERN.search(first_line)
+    assert dht_pattern_match is not None, first_line
+    assert "Full list of visible multiaddresses:" in second_line, second_line
+
+    initial_peers = dht_pattern_match.group(1).split(" ")
+
+    dht_client_proc = Popen(
+        ["hivemind-dht", *initial_peers, "--host_maddrs", "/ip4/127.0.0.1/tcp/0"],
+        stderr=PIPE,
+        text=True,
+        encoding="utf-8",
+    )
+
+    # skip first two lines with connectivity info
+    for _ in range(2):
+        dht_client_proc.stderr.readline()
+    first_report_msg = dht_client_proc.stderr.readline()
+
+    assert "2 DHT nodes (including this one) are in the local routing table" in first_report_msg
+
+    # ensure we get the output of dht_proc after the start of dht_client_proc
+    sleep(dht_refresh_period)
+
+    # expect that one of the next logging outputs from the first peer shows a new connection
+    for _ in range(5):
+        first_report_msg = dht_proc.stderr.readline()
+        second_report_msg = dht_proc.stderr.readline()
+
+        if (
+            "2 DHT nodes (including this one) are in the local routing table" in first_report_msg
+            and "Local storage contains 0 keys" in second_report_msg
+        ):
+            break
+    else:
+        assert (
+            "2 DHT nodes (including this one) are in the local routing table" in first_report_msg
+            and "Local storage contains 0 keys" in second_report_msg
+        )
+
+    dht_proc.terminate()
+    dht_client_proc.terminate()
+
+    dht_proc.wait()
+    dht_client_proc.wait()

+ 3 - 2
tests/test_compression.py

@@ -20,6 +20,7 @@ from hivemind.compression import (
 )
 from hivemind.compression.adaptive import AdaptiveCompressionBase
 from hivemind.proto.runtime_pb2 import CompressionType
+from hivemind.utils.streaming import combine_from_streaming, split_for_streaming
 
 from test_utils.dht_swarms import launch_dht_instances
 
@@ -47,9 +48,9 @@ def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
 def test_serialize_tensor():
     def _check(tensor, compression, rtol=1e-5, atol=1e-8, chunk_size=30 * 1024):
         serialized_tensor = serialize_torch_tensor(tensor, compression)
-        chunks = list(hivemind.split_for_streaming(serialized_tensor, chunk_size))
+        chunks = list(split_for_streaming(serialized_tensor, chunk_size))
         assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
-        restored = hivemind.combine_from_streaming(chunks)
+        restored = combine_from_streaming(chunks)
         assert torch.allclose(deserialize_torch_tensor(restored), tensor, rtol=rtol, atol=atol)
 
     tensor = torch.randn(512, 12288)

+ 192 - 0
tests/test_connection_handler.py

@@ -0,0 +1,192 @@
+from __future__ import annotations
+
+import asyncio
+import math
+from typing import Any, Dict
+
+import pytest
+import torch
+
+from hivemind.compression import deserialize_tensor_stream, deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.dht import DHT
+from hivemind.moe.server.connection_handler import ConnectionHandler
+from hivemind.moe.server.module_backend import ModuleBackend
+from hivemind.moe.server.task_pool import TaskPool
+from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, P2PHandlerError
+from hivemind.proto import runtime_pb2
+from hivemind.utils.asyncio import amap_in_executor, iter_as_aiter
+from hivemind.utils.serializer import MSGPackSerializer
+from hivemind.utils.streaming import split_for_streaming
+from hivemind.utils.tensor_descr import BatchTensorDescriptor
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_connection_handler_info():
+    handler = ConnectionHandler(
+        DHT(start=True),
+        dict(expert1=DummyModuleBackend("expert1", k=1), expert2=DummyModuleBackend("expert2", k=2)),
+    )
+    handler.start()
+
+    client_dht = DHT(start=True, client_mode=True, initial_peers=handler.dht.get_visible_maddrs())
+    client_stub = ConnectionHandler.get_stub(await client_dht.replicate_p2p(), handler.dht.peer_id)
+
+    # info
+    response = await client_stub.rpc_info(runtime_pb2.ExpertUID(uid="expert1"))
+    assert MSGPackSerializer.loads(response.serialized_info) == dict(name="expert1")
+
+    response = await client_stub.rpc_info(runtime_pb2.ExpertUID(uid="expert2"))
+    assert MSGPackSerializer.loads(response.serialized_info) == dict(name="expert2")
+
+    with pytest.raises(P2PHandlerError):
+        await client_stub.rpc_info(runtime_pb2.ExpertUID(uid="expert999"))
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_connection_handler_forward():
+    handler = ConnectionHandler(
+        DHT(start=True),
+        dict(expert1=DummyModuleBackend("expert1", k=1), expert2=DummyModuleBackend("expert2", k=2)),
+    )
+    handler.start()
+
+    client_dht = DHT(start=True, client_mode=True, initial_peers=handler.dht.get_visible_maddrs())
+    client_stub = ConnectionHandler.get_stub(await client_dht.replicate_p2p(), handler.dht.peer_id)
+
+    inputs = torch.randn(1, 2)
+    inputs_long = torch.randn(2**21, 2)
+
+    # forward unary
+    response = await client_stub.rpc_forward(
+        runtime_pb2.ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(inputs)])
+    )
+    outputs = deserialize_torch_tensor(response.tensors[0])
+    assert len(response.tensors) == 1
+    assert torch.allclose(outputs, inputs * 1)
+
+    response = await client_stub.rpc_forward(
+        runtime_pb2.ExpertRequest(uid="expert2", tensors=[serialize_torch_tensor(inputs)])
+    )
+    outputs = deserialize_torch_tensor(response.tensors[0])
+    assert len(response.tensors) == 1
+    assert torch.allclose(outputs, inputs * 2)
+
+    # forward streaming
+    split = (
+        p for t in [serialize_torch_tensor(inputs_long)] for p in split_for_streaming(t, chunk_size_bytes=2**16)
+    )
+    output_generator = await client_stub.rpc_forward_stream(
+        amap_in_executor(
+            lambda tensor_part: runtime_pb2.ExpertRequest(uid="expert2", tensors=[tensor_part]),
+            iter_as_aiter(split),
+        ),
+    )
+    outputs_list = [part async for part in output_generator]
+    assert len(outputs_list) == math.ceil(inputs_long.numel() * 4 / DEFAULT_MAX_MSG_SIZE)
+
+    results = await deserialize_tensor_stream(amap_in_executor(lambda r: r.tensors, iter_as_aiter(outputs_list)))
+    assert len(results) == 1
+    assert torch.allclose(results[0], inputs_long * 2)
+
+    # forward errors
+    with pytest.raises(P2PHandlerError):
+        # no such expert: fails with P2PHandlerError KeyError('expert3')
+        await client_stub.rpc_forward(
+            runtime_pb2.ExpertRequest(uid="expert3", tensors=[serialize_torch_tensor(inputs)])
+        )
+
+    with pytest.raises(P2PHandlerError):
+        # bad input shape: P2PHandlerError("AssertionError") raised by DummyPool.submit_task
+        await client_stub.rpc_forward(
+            runtime_pb2.ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(torch.arange(5))])
+        )
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_connection_handler_backward():
+    handler = ConnectionHandler(
+        DHT(start=True),
+        dict(expert1=DummyModuleBackend("expert1", k=1), expert2=DummyModuleBackend("expert2", k=2)),
+    )
+    handler.start()
+
+    client_dht = DHT(start=True, client_mode=True, initial_peers=handler.dht.get_visible_maddrs())
+    client_stub = ConnectionHandler.get_stub(await client_dht.replicate_p2p(), handler.dht.peer_id)
+
+    inputs = torch.randn(1, 2)
+    inputs_long = torch.randn(2**21, 2)
+
+    # backward unary
+    response = await client_stub.rpc_backward(
+        runtime_pb2.ExpertRequest(
+            uid="expert2", tensors=[serialize_torch_tensor(inputs * -1), serialize_torch_tensor(inputs)]
+        )
+    )
+    outputs = deserialize_torch_tensor(response.tensors[0])
+    assert len(response.tensors) == 1
+    assert torch.allclose(outputs, inputs * -2)
+
+    # backward streaming
+    split = (
+        p
+        for t in [serialize_torch_tensor(inputs_long * 3), serialize_torch_tensor(inputs_long * 0)]
+        for p in split_for_streaming(t, chunk_size_bytes=2**16)
+    )
+    output_generator = await client_stub.rpc_backward_stream(
+        amap_in_executor(
+            lambda tensor_part: runtime_pb2.ExpertRequest(uid="expert1", tensors=[tensor_part]),
+            iter_as_aiter(split),
+        ),
+    )
+    results = await deserialize_tensor_stream(amap_in_executor(lambda r: r.tensors, output_generator))
+    assert len(results) == 1
+    assert torch.allclose(results[0], inputs_long * 3)
+
+    # backward errors
+    with pytest.raises(P2PHandlerError):
+        # bad input schema: fails with P2PHandlerError IndexError('tuple index out of range')
+        await client_stub.rpc_backward(runtime_pb2.ExpertRequest(uid="expert2", tensors=[]))
+
+    with pytest.raises(P2PHandlerError):
+        # backward fails: empty stream
+        output_generator = await client_stub.rpc_backward_stream(
+            amap_in_executor(
+                lambda tensor_part: runtime_pb2.ExpertRequest(uid="expert2", tensors=[tensor_part]),
+                iter_as_aiter([]),
+            ),
+        )
+        results = await deserialize_tensor_stream(amap_in_executor(lambda r: r.tensors, output_generator))
+        assert len(results) == 1
+        assert torch.allclose(results[0], inputs_long * 3)
+
+    # check that handler did not crash after failed request
+    await client_stub.rpc_forward(runtime_pb2.ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(inputs)]))
+
+    handler.terminate()
+    handler.join()
+
+
+class DummyPool(TaskPool):
+    def __init__(self, k: float):
+        self.k = k
+
+    async def submit_task(self, *inputs: torch.Tensor):
+        await asyncio.sleep(0.01)
+        assert inputs[0].shape[-1] == 2
+        return [inputs[0] * self.k]
+
+
+class DummyModuleBackend(ModuleBackend):
+    def __init__(self, name: str, k: float):
+        self.name = name
+        self.outputs_schema = [BatchTensorDescriptor.from_tensor(torch.randn(1, 2))]
+        self.grad_inputs_schema = [BatchTensorDescriptor.from_tensor(torch.randn(1, 2))]
+        self.forward_pool = DummyPool(k)
+        self.backward_pool = DummyPool(k)
+
+    def get_info(self) -> Dict[str, Any]:
+        """Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""
+        return dict(name=self.name)

+ 21 - 9
tests/test_custom_experts.py

@@ -3,7 +3,9 @@ import os
 import pytest
 import torch
 
-from hivemind import RemoteExpert
+from hivemind.dht import DHT
+from hivemind.moe.client.expert import create_remote_experts
+from hivemind.moe.expert_uid import ExpertInfo
 from hivemind.moe.server import background_server
 
 CUSTOM_EXPERTS_PATH = os.path.join(os.path.dirname(__file__), "test_utils", "custom_networks.py")
@@ -17,11 +19,16 @@ def test_custom_expert(hid_dim=16):
         device="cpu",
         hidden_dim=hid_dim,
         num_handlers=2,
-        no_dht=True,
         custom_module_path=CUSTOM_EXPERTS_PATH,
-    ) as (server_endpoint, _):
-        expert0 = RemoteExpert("expert.0", server_endpoint)
-        expert1 = RemoteExpert("expert.1", server_endpoint)
+    ) as server_peer_info:
+        dht = DHT(initial_peers=server_peer_info.addrs, start=True)
+        expert0, expert1 = create_remote_experts(
+            [
+                ExpertInfo(uid="expert.0", peer_id=server_peer_info.peer_id),
+                ExpertInfo(uid="expert.1", peer_id=server_peer_info.peer_id),
+            ],
+            dht=dht,
+        )
 
         for batch_size in (1, 4):
             batch = torch.randn(batch_size, hid_dim)
@@ -43,11 +50,16 @@ def test_multihead_expert(hid_dim=16):
         device="cpu",
         hidden_dim=hid_dim,
         num_handlers=2,
-        no_dht=True,
         custom_module_path=CUSTOM_EXPERTS_PATH,
-    ) as (server_endpoint, _):
-        expert0 = RemoteExpert("expert.0", server_endpoint)
-        expert1 = RemoteExpert("expert.1", server_endpoint)
+    ) as server_peer_info:
+        dht = DHT(initial_peers=server_peer_info.addrs, start=True)
+        expert0, expert1 = create_remote_experts(
+            [
+                ExpertInfo(uid="expert.0", peer_id=server_peer_info.peer_id),
+                ExpertInfo(uid="expert.1", peer_id=server_peer_info.peer_id),
+            ],
+            dht=dht,
+        )
 
         for batch_size in (1, 4):
             batch = (

+ 1 - 1
tests/test_dht.py

@@ -7,9 +7,9 @@ import pytest
 from multiaddr import Multiaddr
 
 import hivemind
-from hivemind.utils.networking import get_free_port
 
 from test_utils.dht_swarms import launch_dht_instances
+from test_utils.networking import get_free_port
 
 
 @pytest.mark.asyncio

+ 29 - 29
tests/test_dht_experts.py

@@ -6,11 +6,11 @@ import numpy as np
 import pytest
 
 import hivemind
-from hivemind import LOCALHOST
-from hivemind.dht import DHTNode
+from hivemind import get_dht_time
+from hivemind.dht.node import DHTNode
 from hivemind.moe.client.beam_search import MoEBeamSearcher
-from hivemind.moe.server import declare_experts, get_experts
-from hivemind.moe.server.expert_uid import UidEndpoint, is_valid_prefix, is_valid_uid, split_uid
+from hivemind.moe.expert_uid import ExpertInfo, is_valid_prefix, is_valid_uid, split_uid
+from hivemind.moe.server.dht_handler import declare_experts, get_experts
 
 
 @pytest.mark.forked
@@ -25,17 +25,18 @@ def test_store_get_experts(n_peers=10):
     expert_uids = [f"my_expert.{i}" for i in range(50)]
     batch_size = 10
     for batch_start in range(0, len(expert_uids), batch_size):
-        declare_experts(first_peer, expert_uids[batch_start : batch_start + batch_size], "localhost:1234")
+        declare_experts(first_peer, expert_uids[batch_start : batch_start + batch_size], get_dht_time() + 30)
 
     found = get_experts(other_peer, random.sample(expert_uids, 5) + ["foo", "bar"])
     assert all(res is not None for res in found[:-2]), "Could not find some existing experts"
     assert all(res is None for res in found[-2:]), "Found non-existing experts"
 
-    other_expert, other_port = "my_other_expert.1337", random.randint(1000, 9999)
-    declare_experts(other_peer, [other_expert], f"that_host:{other_port}")
+    other_expert = "my_other_expert.1337"
+    declare_experts(other_peer, [other_expert], get_dht_time() + 30)
     first_notfound, first_found = get_experts(first_peer, ["foobar", other_expert])
     assert isinstance(first_found, hivemind.RemoteExpert)
-    assert first_found.endpoint == f"that_host:{other_port}"
+    assert first_found.peer_id == other_peer.peer_id
+    assert first_notfound is None
 
     # test graceful shutdown
     first_peer.shutdown()
@@ -43,30 +44,28 @@ def test_store_get_experts(n_peers=10):
     time.sleep(1.0)
     remaining_peer1 = random.choice([peer for peer in peers if peer.is_alive()])
     remaining_peer2 = random.choice([peer for peer in peers if peer.is_alive()])
-    assert all(declare_experts(remaining_peer1, ["new_expert.1"], "dummy"))
-    assert get_experts(remaining_peer2, ["new_expert.1"])[0].endpoint == "dummy"
+    assert all(declare_experts(remaining_peer1, ["new_expert.1"], expiration_time=get_dht_time() + 30))
+    assert get_experts(remaining_peer2, ["new_expert.1"])[0].peer_id == remaining_peer1.peer_id
 
 
 @pytest.mark.forked
 def test_beam_search(
     n_peers=20, total_experts=128, batch_size=32, beam_size=4, parallel_rpc=4, grid_dims=(32, 32, 32)
 ):
-    dht = [hivemind.DHT(start=True)]
-    initial_peers = dht[0].get_visible_maddrs()
-    dht += [hivemind.DHT(initial_peers=initial_peers, start=True) for _ in range(n_peers - 1)]
+    dht_instances = [hivemind.DHT(start=True)]
+    initial_peers = dht_instances[0].get_visible_maddrs()
+    dht_instances += [hivemind.DHT(initial_peers=initial_peers, start=True) for _ in range(n_peers - 1)]
 
     real_experts = sorted(
         {"expert." + ".".join([str(random.randint(0, dim - 1)) for dim in grid_dims]) for _ in range(total_experts)}
     )
     for batch_start in range(0, len(real_experts), batch_size):
-        declare_experts(
-            random.choice(dht),
-            real_experts[batch_start : batch_start + batch_size],
-            wait=True,
-            endpoint=f"host{batch_start // batch_size}:{random.randint(0, 65536)}",
-        )
+        dht = random.choice(dht_instances)
+        declare_experts(dht, real_experts[batch_start : batch_start + batch_size], get_dht_time() + 30)
 
-    neighbors = sum([peer.get_visible_maddrs() for peer in random.sample(dht, min(3, len(dht)))], [])
+    neighbors = sum(
+        [peer.get_visible_maddrs() for peer in random.sample(dht_instances, min(3, len(dht_instances)))], []
+    )
     you = hivemind.DHT(start=True, initial_peers=neighbors, parallel_rpc=parallel_rpc)
     beam_search = MoEBeamSearcher(you, "expert.", grid_dims)
 
@@ -89,22 +88,23 @@ def test_dht_single_node():
     node = hivemind.DHT(start=True)
     beam_search = MoEBeamSearcher(node, "expert.", grid_size=(10,))
 
-    assert all(declare_experts(node, ["expert.1", "expert.2", "expert.3"], f"{hivemind.LOCALHOST}:1337").values())
-    assert len(declare_experts(node, ["ffn.1", "ffn.2"], endpoint="that_place")) == 4
-    assert len(declare_experts(node, ["e.1.2.3", "e.1.2.5", "e.2.0"], f"{hivemind.LOCALHOST}:42")) == 7
+    assert all(declare_experts(node, ["expert.1", "expert.2", "expert.3"], get_dht_time() + 30).values())
+    assert len(declare_experts(node, ["ffn.1", "ffn.2"], get_dht_time() + 30)) == 4
+    assert len(declare_experts(node, ["e.1.2.3", "e.1.2.5", "e.2.0"], get_dht_time() + 30)) == 7
 
     for expert in get_experts(node, ["expert.3", "expert.2"]):
-        assert expert.endpoint == f"{hivemind.LOCALHOST}:1337"
+        assert expert.peer_id == node.peer_id
 
-    assert all(declare_experts(node, ["expert.5", "expert.2"], f"{hivemind.LOCALHOST}:1337").values())
+    assert all(declare_experts(node, ["expert.5", "expert.2"], get_dht_time() + 30).values())
     found_experts = beam_search.find_best_experts([(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0)], beam_size=2)
     assert len(found_experts) == 2 and [expert.uid for expert in found_experts] == ["expert.5", "expert.3"]
 
     successors = beam_search.get_active_successors(["e.1.2.", "e.2.", "e.4.5."])
     assert len(successors["e.1.2."]) == 2
-    assert successors["e.1.2."][3] == UidEndpoint("e.1.2.3", f"{LOCALHOST}:42")
-    assert successors["e.1.2."][5] == UidEndpoint("e.1.2.5", f"{LOCALHOST}:42")
-    assert len(successors["e.2."]) == 1 and successors["e.2."][0] == UidEndpoint("e.2.0", f"{LOCALHOST}:42")
+
+    assert successors["e.1.2."][3] == ExpertInfo("e.1.2.3", node.peer_id)
+    assert successors["e.1.2."][5] == ExpertInfo("e.1.2.5", node.peer_id)
+    assert len(successors["e.2."]) == 1 and successors["e.2."][0] == ExpertInfo("e.2.0", node.peer_id)
     assert successors["e.4.5."] == {}
 
     initial_beam = beam_search.get_initial_beam((3, 2, 1, 0, -1, -2, -3), beam_size=3)
@@ -194,7 +194,7 @@ async def test_negative_caching(n_peers=10):
     peers += [hivemind.DHT(initial_peers=initial_peers, start=True, **dht_kwargs) for _ in range(n_peers - 1)]
 
     writer_peer = random.choice(peers)
-    assert all(declare_experts(writer_peer, ["ffn.1.2.3", "ffn.3.4.5"], "myaddr:1234").values())
+    assert all(declare_experts(writer_peer, ["ffn.1.2.3", "ffn.3.4.5"], get_dht_time() + 30).values())
 
     neighbors = sum([peer.get_visible_maddrs() for peer in random.sample(peers, min(3, len(peers)))], [])
     neg_caching_peer = hivemind.DHT(initial_peers=neighbors, start=True, **dht_kwargs)

+ 10 - 8
tests/test_expert_backend.py

@@ -5,7 +5,7 @@ import pytest
 import torch
 from torch.nn import Linear
 
-from hivemind import BatchTensorDescriptor, ExpertBackend
+from hivemind import BatchTensorDescriptor, ModuleBackend
 from hivemind.moe.server.checkpoints import load_experts, store_experts
 from hivemind.moe.server.layers.lr_schedule import get_linear_schedule_with_warmup
 
@@ -22,13 +22,15 @@ def example_experts():
     opt = torch.optim.SGD(expert.parameters(), PEAK_LR)
 
     args_schema = (BatchTensorDescriptor(1),)
-    expert_backend = ExpertBackend(
+    expert_backend = ModuleBackend(
         name=EXPERT_NAME,
-        expert=expert,
+        module=expert,
         optimizer=opt,
-        scheduler=get_linear_schedule_with_warmup,
-        num_warmup_steps=BACKWARD_PASSES_BEFORE_SAVE,
-        num_total_steps=BACKWARD_PASSES_BEFORE_SAVE + BACKWARD_PASSES_AFTER_SAVE,
+        scheduler=get_linear_schedule_with_warmup(
+            opt,
+            num_warmup_steps=BACKWARD_PASSES_BEFORE_SAVE,
+            num_training_steps=BACKWARD_PASSES_BEFORE_SAVE + BACKWARD_PASSES_AFTER_SAVE,
+        ),
         args_schema=args_schema,
         outputs_schema=BatchTensorDescriptor(1),
         max_batch_size=1,
@@ -39,7 +41,7 @@ def example_experts():
 
 @pytest.mark.forked
 def test_save_load_checkpoints(example_experts):
-    expert = example_experts[EXPERT_NAME].expert
+    expert = example_experts[EXPERT_NAME].module
 
     with TemporaryDirectory() as tmpdir:
         tmp_path = Path(tmpdir)
@@ -79,7 +81,7 @@ def test_restore_update_count(example_experts):
             expert_backend.backward(batch, loss_grad)
 
         load_experts(example_experts, tmp_path)
-        assert expert_backend.update_count == BACKWARD_PASSES_BEFORE_SAVE
+        assert expert_backend.scheduler._step_count == BACKWARD_PASSES_BEFORE_SAVE + 1
 
 
 @pytest.mark.forked

+ 46 - 33
tests/test_moe.py

@@ -1,14 +1,16 @@
-import grpc
 import numpy as np
 import pytest
 import torch
 
 from hivemind.dht import DHT
-from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
-from hivemind.moe.client.moe import DUMMY, _RemoteCallMany
-from hivemind.moe.server import ExpertBackend, Server, background_server, declare_experts
+from hivemind.moe.client.expert import RemoteExpert, create_remote_experts
+from hivemind.moe.client.moe import DUMMY, RemoteMixtureOfExperts, _RemoteCallMany
+from hivemind.moe.client.switch_moe import RemoteSwitchMixtureOfExperts
+from hivemind.moe.expert_uid import ExpertInfo
+from hivemind.moe.server import ModuleBackend, Server, background_server, declare_experts
 from hivemind.moe.server.layers import name_to_block
-from hivemind.utils.tensor_descr import BatchTensorDescriptor
+from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError
+from hivemind.utils import BatchTensorDescriptor, get_dht_time
 
 
 @pytest.mark.forked
@@ -18,8 +20,8 @@ def test_moe():
     ]
     with background_server(
         expert_uids=all_expert_uids, device="cpu", expert_cls="ffn", num_handlers=1, hidden_dim=16
-    ) as (server_endpoint, dht_maddrs):
-        dht = DHT(start=True, initial_peers=dht_maddrs)
+    ) as server_peer_info:
+        dht = DHT(start=True, initial_peers=server_peer_info.addrs)
 
         dmoe = RemoteMixtureOfExperts(in_features=16, grid_size=(4, 4, 4), dht=dht, k_best=3, uid_prefix="ffn.")
 
@@ -35,9 +37,8 @@ def test_no_experts():
     ]
     with background_server(
         expert_uids=all_expert_uids, device="cpu", expert_cls="nop_delay", num_handlers=1, hidden_dim=16
-    ) as (server_endpoint, dht_maddrs):
-        dht = DHT(start=True, initial_peers=dht_maddrs)
-
+    ) as server_peer_info:
+        dht = DHT(start=True, initial_peers=server_peer_info.addrs)
         dmoe = RemoteSwitchMixtureOfExperts(
             in_features=16,
             grid_size=(4, 4, 4),
@@ -71,12 +72,16 @@ def test_call_many(hidden_dim=16):
         num_handlers=1,
         hidden_dim=hidden_dim,
         optim_cls=None,
-        no_dht=True,
-    ) as (server_endpoint, _):
+    ) as server_peer_info:
         inputs = torch.randn(4, hidden_dim, requires_grad=True)
         inputs_clone = inputs.clone().detach().requires_grad_(True)
-        e0, e1, e2, e3, e4 = [RemoteExpert(f"expert.{i}", server_endpoint) for i in range(5)]
-        e5 = RemoteExpert(f"thisshouldnotexist", "127.0.0.1:80")
+
+        dht = DHT(initial_peers=server_peer_info.addrs, start=True)
+        e0, e1, e2, e3, e4 = create_remote_experts(
+            [ExpertInfo(uid=f"expert.{i}", peer_id=server_peer_info.peer_id) for i in range(5)],
+            dht,
+        )
+        e5 = RemoteExpert(ExpertInfo(f"thisshouldnotexist", server_peer_info), None)
 
         mask, expert_outputs = _RemoteCallMany.apply(
             DUMMY,
@@ -129,11 +134,15 @@ def test_remote_module_call(hidden_dim=16):
         num_handlers=1,
         hidden_dim=hidden_dim,
         optim_cls=None,
-        no_dht=True,
-    ) as (server_endpoint, _):
-        real_expert = RemoteExpert("expert.0", server_endpoint)
-        fake_expert = RemoteExpert("oiasfjiasjf", server_endpoint)
-
+    ) as server_peer_info:
+        dht = DHT(initial_peers=server_peer_info.addrs, start=True)
+        real_expert, fake_expert = create_remote_experts(
+            [
+                ExpertInfo(uid="expert.0", peer_id=server_peer_info.peer_id),
+                ExpertInfo(uid="oiasfjiasjf", peer_id=server_peer_info.peer_id),
+            ],
+            dht=dht,
+        )
         out1 = real_expert(torch.randn(1, hidden_dim))
         assert out1.shape == (1, hidden_dim)
         dummy_x = torch.randn(3, hidden_dim, requires_grad=True)
@@ -144,9 +153,9 @@ def test_remote_module_call(hidden_dim=16):
         out3_again.norm().backward()
         assert dummy_x.grad is not None and dummy_x.grad.norm() > 0
 
-        with pytest.raises(grpc.RpcError):
+        with pytest.raises(P2PDaemonError):
             real_expert(torch.randn(3, 11))
-        with pytest.raises(grpc.RpcError):
+        with pytest.raises(P2PDaemonError):
             fake_expert(dummy_x)
 
 
@@ -154,11 +163,11 @@ def test_remote_module_call(hidden_dim=16):
 def test_beam_search_correctness():
     all_expert_uids = [f"ffn.{5 + i}.{10 + j}.{15 + k}" for i in range(10) for j in range(10) for k in range(10)]
     dht = DHT(start=True)
-    assert all(declare_experts(dht, all_expert_uids, endpoint="fake-endpoint"))
+    assert all(declare_experts(dht, all_expert_uids, expiration_time=get_dht_time() + 30))
 
     dmoe = RemoteMixtureOfExperts(in_features=32, grid_size=(32, 32, 32), dht=dht, k_best=4, uid_prefix="ffn.")
 
-    for i in range(25):
+    for _ in range(25):
         input = torch.randn(32)
         grid_scores = dmoe.proj(input).split_with_sizes(dmoe.beam_search.grid_size, dim=-1)
 
@@ -173,7 +182,7 @@ def test_beam_search_correctness():
         # reference: independently find :beam_size: best experts with exhaustive search
         all_scores = dmoe.compute_expert_scores(
             [dim_scores.unsqueeze(0) for dim_scores in grid_scores],
-            [[RemoteExpert(uid, "") for uid in all_expert_uids]],
+            [[RemoteExpert(ExpertInfo(uid, None), None) for uid in all_expert_uids]],
         )[0]
         true_best_scores = sorted(all_scores.cpu().detach().numpy(), reverse=True)[: len(chosen_experts)]
 
@@ -194,9 +203,12 @@ def test_determinism(hidden_dim=16):
         num_handlers=1,
         hidden_dim=hidden_dim,
         optim_cls=None,
-        no_dht=True,
-    ) as (server_endpoint, _):
-        expert = RemoteExpert(uid=f"expert.0", endpoint=server_endpoint)
+    ) as server_peer_info:
+        dht = DHT(initial_peers=server_peer_info.addrs, start=True)
+        expert = create_remote_experts(
+            [ExpertInfo(uid="expert.0", peer_id=server_peer_info.peer_id)],
+            dht=dht,
+        )[0]
 
         out = expert(xx, mask)
         out_rerun = expert(xx, mask)
@@ -220,7 +232,7 @@ def test_compute_expert_scores():
         jj = [[2, 2, 1], [0, 1, 2, 0, 1], [0], [1, 2]]
         batch_experts = [
             [
-                RemoteExpert(uid=f"expert.{ii[batch_i][expert_i]}.{jj[batch_i][expert_i]}", endpoint="[::]:1337")
+                RemoteExpert(ExpertInfo(f"expert.{ii[batch_i][expert_i]}.{jj[batch_i][expert_i]}", None), None)
                 for expert_i in range(len(ii[batch_i]))
             ]
             for batch_i in range(len(ii))
@@ -245,25 +257,26 @@ def test_client_anomaly_detection():
     experts = {}
     for i in range(4):
         expert = name_to_block["ffn"](HID_DIM)
-        experts[f"expert.{i}"] = ExpertBackend(
+        experts[f"expert.{i}"] = ModuleBackend(
             name=f"expert.{i}",
-            expert=expert,
+            module=expert,
             optimizer=torch.optim.Adam(expert.parameters()),
             args_schema=(BatchTensorDescriptor(HID_DIM),),
             outputs_schema=BatchTensorDescriptor(HID_DIM),
             max_batch_size=16,
         )
 
-    experts["expert.3"].expert.ffn.weight.data[0, 0] = float("nan")
+    experts["expert.3"].module.ffn.weight.data[0, 0] = float("nan")
 
     dht = DHT(start=True)
     server = Server(dht, experts, num_connection_handlers=1)
     server.start()
     try:
         server.ready.wait()
+        client_side_dht = DHT(initial_peers=dht.get_visible_maddrs(), start=True)
 
         dmoe = RemoteMixtureOfExperts(
-            in_features=16, grid_size=(3,), dht=dht, k_best=3, uid_prefix="expert.", detect_anomalies=True
+            in_features=16, grid_size=(3,), dht=client_side_dht, k_best=3, uid_prefix="expert.", detect_anomalies=True
         )
 
         input = torch.randn(1, 16)
@@ -280,7 +293,7 @@ def test_client_anomaly_detection():
             inf_loss.backward()
 
         dmoe = RemoteMixtureOfExperts(
-            in_features=16, grid_size=(4,), dht=dht, k_best=4, uid_prefix="expert.", detect_anomalies=True
+            in_features=16, grid_size=(4,), dht=client_side_dht, k_best=4, uid_prefix="expert.", detect_anomalies=True
         )
         output = dmoe(input)
         assert output.isfinite().all()

+ 84 - 12
tests/test_optimizer.py

@@ -11,24 +11,31 @@ import torch.nn.functional as F
 
 import hivemind
 from hivemind.averaging.control import AveragingStage
-from hivemind.optim.grad_averager import GradientAverager
+from hivemind.optim.grad_averager import GradientAverager, GradientAveragerFactory
 from hivemind.optim.optimizer import Optimizer
+from hivemind.optim.power_sgd_averager import PowerSGDGradientAverager
 from hivemind.optim.progress_tracker import ProgressTracker
 from hivemind.optim.state_averager import TrainingStateAverager
 from hivemind.utils.crypto import RSAPrivateKey
 
 
 @pytest.mark.forked
-def test_grad_averager():
+@pytest.mark.parametrize(
+    "grad_averager_factory",
+    [GradientAverager, partial(PowerSGDGradientAverager, averager_rank=1)],
+)
+def test_grad_averager(grad_averager_factory: GradientAveragerFactory):
+    parameter_shape = (5, 5)
+
     dht1 = hivemind.DHT(start=True)
-    model1 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(3))})
-    averager1 = GradientAverager(
+    model1 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(parameter_shape))})
+    averager1 = grad_averager_factory(
         model1.parameters(), dht=dht1, prefix="test", target_group_size=2, reuse_grad_buffers=False, start=True
     )
 
     dht2 = hivemind.DHT(start=True, initial_peers=dht1.get_visible_maddrs())
-    model2 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(3))})
-    averager2 = GradientAverager(
+    model2 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(parameter_shape))})
+    averager2 = grad_averager_factory(
         model2.parameters(), dht=dht2, prefix="test", target_group_size=2, reuse_grad_buffers=True, start=True
     )
 
@@ -38,12 +45,12 @@ def test_grad_averager():
     for i in range(10):
         time.sleep(0.1)
         if i % 3 == 0:
-            loss1 = F.mse_loss(model1.w, torch.ones(3))
+            loss1 = F.mse_loss(model1.w, torch.ones(parameter_shape))
             loss1.backward()
             averager1.accumulate_grads_(batch_size=2)  # total: 4 times * 2 samples = 8
             model1.zero_grad()
         else:
-            loss2 = F.mse_loss(model2.w, -torch.ones(3))
+            loss2 = F.mse_loss(model2.w, -torch.ones(parameter_shape))
             loss2.backward()
             averager2.accumulate_grads_(batch_size=3)  # total: 6 times * 3 samples = 18
             # note: we do not call zero grad here because reuse_grad_buffers=True
@@ -51,11 +58,11 @@ def test_grad_averager():
     assert control1.stage == control2.stage == AveragingStage.AWAITING_TRIGGER
     peer1_samples, peer1_times, peer2_samples, peer2_times = 8, 4, 18, 6
     assert averager1.local_samples_accumulated == peer1_samples and averager1.local_times_accumulated == peer1_times
-    ref_grads1 = torch.full((3,), -2 * 1 / 3 * averager1.local_times_accumulated)
+    ref_grads1 = torch.full(parameter_shape, -2 / np.prod(parameter_shape) * averager1.local_times_accumulated)
     assert torch.allclose(next(averager1._grad_accumulators()), ref_grads1)
 
     assert averager2.local_samples_accumulated == peer2_samples and averager2.local_times_accumulated == peer2_times
-    ref_grads2 = torch.full((3,), 2 * 1 / 3 * averager2.local_times_accumulated)
+    ref_grads2 = torch.full(parameter_shape, 2 / np.prod(parameter_shape) * averager2.local_times_accumulated)
     assert torch.allclose(next(averager2._grad_accumulators()), ref_grads2)
 
     averager1.step(control=control1, wait=False)
@@ -76,6 +83,28 @@ def test_grad_averager():
     assert not torch.allclose(model2.w.grad, ref_average)
 
 
+@pytest.mark.forked
+@pytest.mark.parametrize(
+    "grad_averager_factory",
+    [GradientAverager, partial(PowerSGDGradientAverager, averager_rank=1)],
+)
+def test_grad_averager_wrong_shape(grad_averager_factory: GradientAveragerFactory):
+    parameter_shape = (5, 5)
+    model = nn.ParameterDict({"w": nn.Parameter(torch.zeros(parameter_shape))})
+    dht = hivemind.DHT(start=True)
+
+    with pytest.raises(ValueError):
+        grad_averager_factory(
+            model.parameters(),
+            dht=dht,
+            prefix="test_fail",
+            target_group_size=2,
+            reuse_grad_buffers=False,
+            start=True,
+            averaged_grads=[torch.zeros(parameter_shape + (1,))],
+        )
+
+
 @pytest.mark.forked
 @pytest.mark.parametrize(
     "offload_optimizer, reuse_tensors, sync_epoch_when_averaging",
@@ -162,7 +191,11 @@ def test_load_state_from_peers():
     )
 
     avgr1 = TrainingStateAverager(
-        dht=dht1, params=model1.parameters(), allow_state_sharing=False, start=True, **common_kwargs
+        dht=dht1,
+        params=model1.parameters(),
+        allow_state_sharing=False,
+        start=True,
+        **common_kwargs,
     )
 
     avgr2 = TrainingStateAverager(dht=dht2, params=model2.parameters(), start=True, **common_kwargs)
@@ -286,12 +319,45 @@ def test_progress_tracker():
 
 
 @pytest.mark.forked
+@pytest.mark.parametrize(
+    "use_local_updates, delay_state_averaging, delay_optimizer_step, delay_grad_averaging, reuse_grad_buffers",
+    # fmt: off
+    [
+        (False, False, False, False, False),
+        (False, True, False, False, False),
+        (False, True, True, True, False),
+        (False, False, False, False, True),
+        (False, True, True, True, True),
+        (False, True, True, False, True),
+        (True, False, False, False, False),
+        (True, True, False, False, False,),
+    ],
+    # fmt: on
+)
 def test_optimizer(
+    use_local_updates: bool,
+    delay_state_averaging: bool,
+    delay_optimizer_step: bool,
+    delay_grad_averaging: bool,
+    reuse_grad_buffers: bool,
+):
+    _test_optimizer(
+        use_local_updates=use_local_updates,
+        delay_state_averaging=delay_state_averaging,
+        delay_grad_averaging=delay_grad_averaging,
+        delay_optimizer_step=delay_optimizer_step,
+        reuse_grad_buffers=reuse_grad_buffers,
+    )
+
+
+def _test_optimizer(
     num_peers: int = 1,
     num_clients: int = 0,
     target_batch_size: int = 32,
     total_epochs: int = 3,
+    use_local_updates: bool = False,
     reuse_grad_buffers: bool = True,
+    delay_state_averaging: bool = True,
     delay_grad_averaging: bool = True,
     delay_optimizer_step: bool = True,
     average_state_every: int = 1,
@@ -319,9 +385,11 @@ def test_optimizer(
             dht=hivemind.DHT(initial_peers=dht.get_visible_maddrs(), client_mode=client_mode, start=True),
             tracker_opts=dict(private_key=RSAPrivateKey(), max_refresh_period=1.0),
             averager_opts=dict(request_timeout=0.5),
+            use_local_updates=use_local_updates,
             matchmaking_time=1.0,
             averaging_timeout=5.0,
             reuse_grad_buffers=reuse_grad_buffers,
+            delay_state_averaging=delay_state_averaging,
             delay_grad_averaging=delay_grad_averaging,
             delay_optimizer_step=delay_optimizer_step,
             average_state_every=average_state_every,
@@ -380,6 +448,10 @@ def test_optimizer(
     assert 4 / 0.3 * 0.8 <= optimizer.tracker.performance_ema.samples_per_second <= 4 / 0.3 * 1.2
 
     assert not optimizer.state_averager.is_alive()
-    assert not optimizer.grad_averager.is_alive()
     assert not optimizer.tracker.is_alive()
+    if not use_local_updates:
+        assert not optimizer.grad_averager.is_alive()
+    else:
+        assert optimizer.grad_averager is None
+
     assert optimizer.scheduled_grads is None or optimizer.scheduled_grads.done()

+ 31 - 3
tests/test_p2p_daemon.py

@@ -1,6 +1,8 @@
 import asyncio
 import multiprocessing as mp
+import os
 import subprocess
+import tempfile
 from contextlib import closing
 from functools import partial
 from typing import List
@@ -11,9 +13,10 @@ from multiaddr import Multiaddr
 
 from hivemind.p2p import P2P, P2PDaemonError, P2PHandlerError
 from hivemind.proto import dht_pb2, test_pb2
-from hivemind.utils.networking import get_free_port
 from hivemind.utils.serializer import MSGPackSerializer
 
+from test_utils.networking import get_free_port
+
 
 def is_process_running(pid: int) -> bool:
     return subprocess.run(["ps", "-p", str(pid)], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL).returncode == 0
@@ -45,6 +48,31 @@ async def test_startup_error_message():
         await P2P.create(startup_timeout=0.01)  # Test that startup_timeout works
 
 
+@pytest.mark.asyncio
+async def test_identity():
+    with tempfile.TemporaryDirectory() as tempdir:
+        id1_path = os.path.join(tempdir, "id1")
+        id2_path = os.path.join(tempdir, "id2")
+        p2ps = await asyncio.gather(*[P2P.create(identity_path=path) for path in [None, None, id1_path, id2_path]])
+
+        # We create the second daemon with id2 separately
+        # to avoid a race condition while saving a newly generated identity
+        p2ps.append(await P2P.create(identity_path=id2_path))
+
+        # Using the same identity (if any) should lead to the same peer ID
+        assert p2ps[-2].peer_id == p2ps[-1].peer_id
+
+        # The rest of peer IDs should be different
+        peer_ids = {instance.peer_id for instance in p2ps}
+        assert len(peer_ids) == 4
+
+        for instance in p2ps:
+            await instance.shutdown()
+
+    with pytest.raises(FileNotFoundError, match=r"The directory.+does not exist"):
+        P2P.generate_identity(id1_path)
+
+
 @pytest.mark.parametrize(
     "host_maddrs",
     [
@@ -55,11 +83,11 @@ async def test_startup_error_message():
 )
 @pytest.mark.asyncio
 async def test_transports(host_maddrs: List[Multiaddr]):
-    server = await P2P.create(quic=True, host_maddrs=host_maddrs)
+    server = await P2P.create(host_maddrs=host_maddrs)
     peers = await server.list_peers()
     assert len(peers) == 0
 
-    client = await P2P.create(quic=True, host_maddrs=host_maddrs, initial_peers=await server.get_visible_maddrs())
+    client = await P2P.create(host_maddrs=host_maddrs, initial_peers=await server.get_visible_maddrs())
     await client.wait_for_at_least_n_peers(1)
 
     peers = await client.list_peers()

+ 8 - 2
tests/test_p2p_daemon_bindings.py

@@ -560,13 +560,19 @@ async def test_client_stream_handler_success(p2pcs):
 
     writer.close()
 
-    # test case: registering twice can override the previous registration
+    # test case: registering twice can't override the previous registration without balanced flag
     event_third = asyncio.Event()
 
     async def handler_third(stream_info, reader, writer):
         event_third.set()
 
-    await p2pcs[1].stream_handler(another_proto, handler_third)
+    # p2p raises now for doubled stream handlers
+    with pytest.raises(ControlFailure):
+        await p2pcs[1].stream_handler(another_proto, handler_third)
+
+    # add in balanced mode: handler should be placed in round robin queue
+    # and become the next to be called
+    await p2pcs[1].stream_handler(another_proto, handler_third, balanced=True)
     assert another_proto in p2pcs[1].control.handlers
     # ensure the handler is override
     assert handler_third == p2pcs[1].control.handlers[another_proto]

+ 1 - 1
tests/test_routing.py

@@ -3,8 +3,8 @@ import operator
 import random
 from itertools import chain, zip_longest
 
-from hivemind import LOCALHOST
 from hivemind.dht.routing import DHTID, RoutingTable
+from hivemind.utils.networking import LOCALHOST
 
 
 def test_ids_basic():

+ 83 - 0
tests/test_start_server.py

@@ -0,0 +1,83 @@
+import os
+import re
+from subprocess import PIPE, Popen
+from tempfile import TemporaryDirectory
+
+from hivemind.moe.server import background_server
+
+
+def test_background_server_identity_path():
+    with TemporaryDirectory() as tempdir:
+        id_path = os.path.join(tempdir, "id")
+
+        with background_server(num_experts=1, identity_path=id_path) as server_info_1, background_server(
+            num_experts=1, identity_path=id_path
+        ) as server_info_2, background_server(num_experts=1, identity_path=None) as server_info_3:
+
+            assert server_info_1.peer_id == server_info_2.peer_id
+            assert server_info_1.peer_id != server_info_3.peer_id
+            assert server_info_3.peer_id == server_info_3.peer_id
+
+
+def test_cli_run_server_identity_path():
+    pattern = r"Running DHT node on \[(.+)\],"
+
+    with TemporaryDirectory() as tempdir:
+        id_path = os.path.join(tempdir, "id")
+
+        server_1_proc = Popen(
+            ["hivemind-server", "--num_experts", "1", "--identity_path", id_path],
+            stderr=PIPE,
+            text=True,
+            encoding="utf-8",
+        )
+
+        # Skip line "Generating new identity (libp2p private key) in {path to file}"
+        line = server_1_proc.stderr.readline()
+        line = server_1_proc.stderr.readline()
+        addrs_1 = set(re.search(pattern, line).group(1).split(", "))
+        ids_1 = set(a.split("/")[-1] for a in addrs_1)
+
+        assert len(ids_1) == 1
+
+        server_2_proc = Popen(
+            ["hivemind-server", "--num_experts", "1", "--identity_path", id_path],
+            stderr=PIPE,
+            text=True,
+            encoding="utf-8",
+        )
+
+        line = server_2_proc.stderr.readline()
+        addrs_2 = set(re.search(pattern, line).group(1).split(", "))
+        ids_2 = set(a.split("/")[-1] for a in addrs_2)
+
+        assert len(ids_2) == 1
+
+        server_3_proc = Popen(
+            ["hivemind-server", "--num_experts", "1"],
+            stderr=PIPE,
+            text=True,
+            encoding="utf-8",
+        )
+
+        line = server_3_proc.stderr.readline()
+        addrs_3 = set(re.search(pattern, line).group(1).split(", "))
+        ids_3 = set(a.split("/")[-1] for a in addrs_3)
+
+        assert len(ids_3) == 1
+
+        assert ids_1 == ids_2
+        assert ids_1 != ids_3
+        assert ids_2 != ids_3
+
+        assert addrs_1 != addrs_2
+        assert addrs_1 != addrs_3
+        assert addrs_2 != addrs_3
+
+        server_1_proc.terminate()
+        server_2_proc.terminate()
+        server_3_proc.terminate()
+
+        server_1_proc.wait()
+        server_2_proc.wait()
+        server_3_proc.wait()

+ 18 - 96
tests/test_training.py

@@ -1,4 +1,3 @@
-import time
 from functools import partial
 
 import pytest
@@ -8,9 +7,10 @@ import torch.nn.functional as F
 from sklearn.datasets import load_digits
 
 from hivemind import DHT
-from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
+from hivemind.moe.client import RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
+from hivemind.moe.client.expert import create_remote_experts
+from hivemind.moe.expert_uid import ExpertInfo
 from hivemind.moe.server import background_server
-from hivemind.optim import DecentralizedAdam, DecentralizedSGD
 
 
 @pytest.mark.forked
@@ -19,12 +19,17 @@ def test_training(max_steps: int = 100, threshold: float = 0.9):
     X_train, y_train = torch.tensor(dataset["data"], dtype=torch.float), torch.tensor(dataset["target"])
     SGD = partial(torch.optim.SGD, lr=0.05)
 
-    with background_server(num_experts=2, device="cpu", optim_cls=SGD, hidden_dim=64, num_handlers=1, no_dht=True) as (
-        server_endpoint,
-        _,
-    ):
-        expert1 = RemoteExpert("expert.0", server_endpoint)
-        expert2 = RemoteExpert("expert.1", server_endpoint)
+    with background_server(
+        num_experts=2, device="cpu", optim_cls=SGD, hidden_dim=64, num_handlers=1
+    ) as server_peer_info:
+        dht = DHT(initial_peers=server_peer_info.addrs, start=True)
+        expert1, expert2 = create_remote_experts(
+            [
+                ExpertInfo(uid="expert.0", peer_id=server_peer_info.peer_id),
+                ExpertInfo(uid="expert.1", peer_id=server_peer_info.peer_id),
+            ],
+            dht=dht,
+        )
         model = nn.Sequential(expert2, nn.ReLU(), expert1, nn.Linear(64, 2))
 
         opt = SGD(model.parameters(), lr=0.05)
@@ -54,8 +59,8 @@ def test_moe_training(max_steps: int = 100, threshold: float = 0.9, num_experts=
     all_expert_uids = [f"expert.{i}" for i in range(num_experts)]
     with background_server(
         expert_uids=all_expert_uids, device="cpu", optim_cls=SGD, hidden_dim=64, num_handlers=1
-    ) as (server_endpoint, dht_maddrs):
-        dht = DHT(start=True, initial_peers=dht_maddrs)
+    ) as server_peer_info:
+        dht = DHT(start=True, initial_peers=server_peer_info.addrs)
 
         moe = RemoteMixtureOfExperts(in_features=64, grid_size=(num_experts,), dht=dht, uid_prefix="expert.", k_best=2)
         model = nn.Sequential(moe, nn.Linear(64, 2))
@@ -107,8 +112,8 @@ def test_switch_training(max_steps: int = 10, threshold: float = 0.9, num_expert
     all_expert_uids = [f"expert.{i}" for i in range(num_experts)]
     with background_server(
         expert_uids=all_expert_uids, device="cpu", optim_cls=SGD, hidden_dim=64, num_handlers=1
-    ) as (server_endpoint, dht_maddrs):
-        dht = DHT(start=True, initial_peers=dht_maddrs)
+    ) as server_peer_info:
+        dht = DHT(start=True, initial_peers=server_peer_info.addrs)
 
         model = SwitchNetwork(dht, 64, 2, num_experts)
         opt = SGD(model.parameters(), lr=0.05)
@@ -126,86 +131,3 @@ def test_switch_training(max_steps: int = 10, threshold: float = 0.9, num_expert
 
         assert model.moe.grid_utilization.min().item() > (1 / num_experts) / 2
         assert accuracy >= threshold, f"too small accuracy: {accuracy}"
-
-
-@pytest.mark.forked
-def test_decentralized_optimizer_step():
-    dht_root = DHT(start=True)
-    initial_peers = dht_root.get_visible_maddrs()
-
-    param1 = torch.nn.Parameter(torch.zeros(32, 32), requires_grad=True)
-    opt1 = DecentralizedSGD(
-        [param1],
-        lr=0.1,
-        dht=DHT(initial_peers=initial_peers, start=True),
-        prefix="foo",
-        target_group_size=2,
-        verbose=True,
-    )
-
-    param2 = torch.nn.Parameter(torch.ones(32, 32), requires_grad=True)
-    opt2 = DecentralizedSGD(
-        [param2],
-        lr=0.05,
-        dht=DHT(initial_peers=initial_peers, start=True),
-        prefix="foo",
-        target_group_size=2,
-        verbose=True,
-    )
-
-    assert not torch.allclose(param1, param2)
-
-    (param1.sum() + 300 * param2.sum()).backward()
-
-    for i in range(5):
-        time.sleep(0.1)
-        opt1.step()
-        opt2.step()
-        opt1.zero_grad()
-        opt2.zero_grad()
-
-    assert torch.allclose(param1, param2)
-    reference = 0.5 * (0.0 - 0.1 * 1.0) + 0.5 * (1.0 - 0.05 * 300)
-    assert torch.allclose(param1, torch.full_like(param1, reference))
-
-
-@pytest.mark.skip(reason="Skipped until a more stable averager implementation is ready (TODO @justheuristic)")
-@pytest.mark.forked
-def test_decentralized_optimizer_averaging():
-    dht_root = DHT(start=True)
-    initial_peers = dht_root.get_visible_maddrs()
-
-    param1 = torch.nn.Parameter(torch.zeros(32, 32), requires_grad=True)
-    opt1 = DecentralizedAdam(
-        [param1],
-        lr=0.1,
-        averaging_steps_period=1,
-        dht=DHT(initial_peers=initial_peers, start=True),
-        prefix="foo",
-        target_group_size=2,
-        verbose=True,
-    )
-
-    param2 = torch.nn.Parameter(torch.ones(32, 32), requires_grad=True)
-    opt2 = DecentralizedAdam(
-        [param2],
-        lr=0.05,
-        averaging_steps_period=1,
-        dht=DHT(initial_peers=initial_peers, start=True),
-        prefix="foo",
-        target_group_size=2,
-        verbose=True,
-    )
-
-    assert not torch.allclose(param1, param2, atol=1e-3, rtol=0)
-    (param1.sum() + param2.sum()).backward()
-
-    for _ in range(100):
-        time.sleep(0.1)
-        opt1.step()
-        opt2.step()
-        opt1.zero_grad()
-        opt2.zero_grad()
-
-    assert torch.allclose(param1, param2, atol=1e-3, rtol=0)
-    assert torch.allclose(opt1.state[param1]["exp_avg_sq"], opt2.state[param2]["exp_avg_sq"], atol=1e-3, rtol=0)

+ 1 - 53
tests/test_util_modules.py

@@ -11,14 +11,11 @@ import torch
 
 import hivemind
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
-from hivemind.proto.dht_pb2_grpc import DHTStub
 from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
 from hivemind.utils import BatchTensorDescriptor, DHTExpiration, HeapEntry, MSGPackSerializer, ValueWithExpiration
 from hivemind.utils.asyncio import (
     achain,
     aenumerate,
-    afirst,
     aiter_with_timeout,
     amap_in_executor,
     anext,
@@ -330,50 +327,6 @@ def test_many_futures():
     p.join()
 
 
-@pytest.mark.forked
-@pytest.mark.asyncio
-async def test_channel_cache():
-    hivemind.ChannelCache.MAXIMUM_CHANNELS = 3
-    hivemind.ChannelCache.EVICTION_PERIOD_SECONDS = 0.1
-
-    c1 = hivemind.ChannelCache.get_stub("localhost:1337", DHTStub, aio=False)
-    c2 = hivemind.ChannelCache.get_stub("localhost:1337", DHTStub, aio=True)
-    c3 = hivemind.ChannelCache.get_stub("localhost:1338", DHTStub, aio=False)
-    c3_again = hivemind.ChannelCache.get_stub("localhost:1338", DHTStub, aio=False)
-    c1_again = hivemind.ChannelCache.get_stub("localhost:1337", DHTStub, aio=False)
-    c4 = hivemind.ChannelCache.get_stub("localhost:1339", DHTStub, aio=True)
-    c2_anew = hivemind.ChannelCache.get_stub("localhost:1337", DHTStub, aio=True)
-    c1_yetagain = hivemind.ChannelCache.get_stub("localhost:1337", DHTStub, aio=False)
-
-    await asyncio.sleep(0.2)
-    c1_anew = hivemind.ChannelCache.get_stub(target="localhost:1337", aio=False, stub_type=DHTStub)
-    c1_anew_again = hivemind.ChannelCache.get_stub(target="localhost:1337", aio=False, stub_type=DHTStub)
-    c1_otherstub = hivemind.ChannelCache.get_stub(target="localhost:1337", aio=False, stub_type=ConnectionHandlerStub)
-    await asyncio.sleep(0.05)
-    c1_otherstub_again = hivemind.ChannelCache.get_stub(
-        target="localhost:1337", aio=False, stub_type=ConnectionHandlerStub
-    )
-    all_channels = [c1, c2, c3, c4, c3_again, c1_again, c2_anew, c1_yetagain, c1_anew, c1_anew_again, c1_otherstub]
-
-    assert all(isinstance(c, DHTStub) for c in all_channels[:-1])
-    assert isinstance(all_channels[-1], ConnectionHandlerStub)
-    assert "aio" in repr(c2.rpc_find)
-    assert "aio" not in repr(c1.rpc_find)
-
-    duplicates = {
-        (c1, c1_again),
-        (c1, c1_yetagain),
-        (c1_again, c1_yetagain),
-        (c3, c3_again),
-        (c1_anew, c1_anew_again),
-        (c1_otherstub, c1_otherstub_again),
-    }
-    for i in range(len(all_channels)):
-        for j in range(i + 1, len(all_channels)):
-            ci, cj = all_channels[i], all_channels[j]
-            assert (ci is cj) == ((ci, cj) in duplicates), (i, j)
-
-
 def test_serialize_tuple():
     test_pairs = (
         ((1, 2, 3), [1, 2, 3]),
@@ -419,7 +372,7 @@ def test_split_parts():
     for combined in combined_incomplete, combined_incomplete2, combined_incomplete3:
         with pytest.raises(RuntimeError):
             deserialize_torch_tensor(combined)
-            # note: we rely on this being RuntimeError in hivemind.averaging.allreduce.AllreduceRunner
+            # note: we rely on this being RuntimeError in hivemind.averaging.allreduce.AllReduceRunner
 
 
 def test_generic_data_classes():
@@ -476,11 +429,6 @@ async def test_asyncio_utils():
     with pytest.raises(ValueError):
         await asingle(as_aiter(1, 2, 3))
 
-    assert await afirst(as_aiter(1)) == 1
-    assert await afirst(as_aiter()) is None
-    assert await afirst(as_aiter(), -1) == -1
-    assert await afirst(as_aiter(1, 2, 3)) == 1
-
     async def iterate_with_delays(delays):
         for i, delay in enumerate(delays):
             await asyncio.sleep(delay)

+ 0 - 0
tests/test_utils/__init__.py


+ 18 - 0
tests/test_utils/networking.py

@@ -0,0 +1,18 @@
+import socket
+from contextlib import closing
+
+
+def get_free_port(params=(socket.AF_INET, socket.SOCK_STREAM), opt=(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)):
+    """
+    Finds a tcp port that can be occupied with a socket with *params and use *opt options.
+
+    :note: Using this function is discouraged since it often leads to a race condition
+           with the "Address is already in use" error if the code is run in parallel.
+    """
+    try:
+        with closing(socket.socket(*params)) as sock:
+            sock.bind(("", 0))
+            sock.setsockopt(*opt)
+            return sock.getsockname()[1]
+    except Exception as e:
+        raise e

+ 2 - 1
tests/test_utils/p2p_daemon.py

@@ -10,9 +10,10 @@ from typing import NamedTuple
 from multiaddr import Multiaddr, protocols
 from pkg_resources import resource_filename
 
-from hivemind import get_free_port
 from hivemind.p2p.p2p_daemon_bindings.p2pclient import Client
 
+from test_utils.networking import get_free_port
+
 TIMEOUT_DURATION = 30  # seconds
 P2PD_PATH = resource_filename("hivemind", "hivemind_cli/p2pd")