
Merge branch 'master' into typing_fixes

Michael Diskin 3 years ago
parent
commit
21119a3551
100 changed files with 7516 additions and 2483 deletions
  1. .github/workflows/check-style.yml (+4, -1)
  2. .github/workflows/push-docker-image.yml (+0, -1)
  3. .github/workflows/run-benchmarks.yml (+37, -0)
  4. .github/workflows/run-tests.yml (+5, -5)
  5. .gitignore (+1, -0)
  6. README.md (+46, -42)
  7. benchmarks/benchmark_averaging.py (+5, -2)
  8. benchmarks/benchmark_dht.py (+181, -78)
  9. benchmarks/benchmark_optimizer.py (+162, -0)
  10. benchmarks/benchmark_tensor_compression.py (+3, -2)
  11. benchmarks/benchmark_throughput.py (+16, -15)
  12. docs/_static/dht.odp (BIN)
  13. docs/_static/dht.png (BIN)
  14. docs/conf.py (+1, -1)
  15. docs/index.rst (+3, -3)
  16. docs/modules/optim.rst (+26, -4)
  17. docs/modules/server.rst (+8, -6)
  18. docs/user/dht.md (+1, -1)
  19. docs/user/quickstart.md (+35, -32)
  20. examples/albert/README.md (+32, -32)
  21. examples/albert/arguments.py (+23, -19)
  22. examples/albert/requirements.txt (+5, -5)
  23. examples/albert/run_trainer.py (+108, -96)
  24. examples/albert/run_training_monitor.py (+23, -25)
  25. examples/albert/tokenize_wikitext103.py (+1, -1)
  26. examples/albert/utils.py (+1, -7)
  27. hivemind/__init__.py (+6, -2)
  28. hivemind/averaging/__init__.py (+0, -1)
  29. hivemind/averaging/allreduce.py (+196, -84)
  30. hivemind/averaging/averager.py (+277, -139)
  31. hivemind/averaging/control.py (+165, -0)
  32. hivemind/averaging/key_manager.py (+33, -97)
  33. hivemind/averaging/load_balancing.py (+1, -1)
  34. hivemind/averaging/matchmaking.py (+103, -95)
  35. hivemind/averaging/partition.py (+92, -43)
  36. hivemind/compression/__init__.py (+52, -0)
  37. hivemind/compression/adaptive.py (+67, -0)
  38. hivemind/compression/base.py (+92, -0)
  39. hivemind/compression/floating.py (+92, -0)
  40. hivemind/compression/quantization.py (+114, -0)
  41. hivemind/dht/__init__.py (+4, -302)
  42. hivemind/dht/dht.py (+324, -0)
  43. hivemind/dht/node.py (+18, -7)
  44. hivemind/dht/protocol.py (+1, -1)
  45. hivemind/dht/routing.py (+2, -2)
  46. hivemind/dht/schema.py (+1, -1)
  47. hivemind/hivemind_cli/run_server.py (+9, -6)
  48. hivemind/moe/__init__.py (+8, -1)
  49. hivemind/moe/client/beam_search.py (+3, -3)
  50. hivemind/moe/client/expert.py (+3, -4)
  51. hivemind/moe/client/moe.py (+7, -7)
  52. hivemind/moe/server/__init__.py (+3, -355)
  53. hivemind/moe/server/connection_handler.py (+3, -4)
  54. hivemind/moe/server/dht_handler.py (+2, -2)
  55. hivemind/moe/server/expert_backend.py (+4, -3)
  56. hivemind/moe/server/expert_uid.py (+2, -73)
  57. hivemind/moe/server/server.py (+419, -0)
  58. hivemind/optim/__init__.py (+3, -0)
  59. hivemind/optim/adaptive.py (+1, -1)
  60. hivemind/optim/base.py (+8, -0)
  61. hivemind/optim/collaborative.py (+135, -61)
  62. hivemind/optim/grad_averager.py (+226, -0)
  63. hivemind/optim/grad_scaler.py (+125, -0)
  64. hivemind/optim/optimizer.py (+779, -0)
  65. hivemind/optim/performance_ema.py (+0, -41)
  66. hivemind/optim/progress_tracker.py (+363, -0)
  67. hivemind/optim/simple.py (+9, -5)
  68. hivemind/optim/state_averager.py (+723, -0)
  69. hivemind/optim/training_averager.py (+48, -13)
  70. hivemind/p2p/__init__.py (+1, -1)
  71. hivemind/p2p/p2p_daemon.py (+227, -112)
  72. hivemind/p2p/p2p_daemon_bindings/control.py (+209, -3)
  73. hivemind/p2p/p2p_daemon_bindings/datastructures.py (+9, -0)
  74. hivemind/p2p/p2p_daemon_bindings/p2pclient.py (+40, -3)
  75. hivemind/p2p/servicer.py (+22, -33)
  76. hivemind/proto/averaging.proto (+1, -1)
  77. hivemind/proto/p2pd.proto (+59, -10)
  78. hivemind/utils/__init__.py (+2, -2)
  79. hivemind/utils/asyncio.py (+90, -11)
  80. hivemind/utils/compression.py (+0, -209)
  81. hivemind/utils/logging.py (+185, -17)
  82. hivemind/utils/mpfuture.py (+12, -6)
  83. hivemind/utils/networking.py (+7, -2)
  84. hivemind/utils/performance_ema.py (+70, -0)
  85. hivemind/utils/serializer.py (+2, -2)
  86. hivemind/utils/tensor_descr.py (+64, -8)
  87. requirements-dev.txt (+2, -1)
  88. requirements-docs.txt (+4, -2)
  89. setup.py (+5, -5)
  90. tests/conftest.py (+24, -3)
  91. tests/test_allreduce.py (+5, -5)
  92. tests/test_allreduce_fault_tolerance.py (+213, -0)
  93. tests/test_averaging.py (+129, -61)
  94. tests/test_compression.py (+213, -0)
  95. tests/test_dht.py (+17, -1)
  96. tests/test_dht_node.py (+34, -215)
  97. tests/test_dht_protocol.py (+163, -0)
  98. tests/test_moe.py (+31, -34)
  99. tests/test_optimizer.py (+385, -0)
  100. tests/test_p2p_daemon.py (+41, -4)

+ 4 - 1
.github/workflows/check-style.yml

@@ -1,6 +1,9 @@
 name: Check style
 
-on: [ push ]
+on:
+  push:
+    branches: [ master ]
+  pull_request:
 
 jobs:
   black:

+ 0 - 1
.github/workflows/push-docker-image.yml

@@ -8,7 +8,6 @@ on:
   pull_request:
     branches: [ master ]
 
-
 jobs:
   build:
     runs-on: ubuntu-latest

+ 37 - 0
.github/workflows/run-benchmarks.yml

@@ -0,0 +1,37 @@
+name: Benchmarks
+
+on:
+  push:
+    branches: [ master ]
+  pull_request:
+
+jobs:
+  run_benchmarks:
+
+    runs-on: ubuntu-latest
+    timeout-minutes: 10
+    steps:
+      - uses: actions/checkout@v2
+      - name: Set up Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: 3.9
+      - name: Cache dependencies
+        uses: actions/cache@v2
+        with:
+          path: ~/.cache/pip
+          key: Key-v1-3.9-${{ hashFiles('requirements.txt') }}-${{ hashFiles('requirements-dev.txt') }}
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -r requirements.txt
+          pip install -r requirements-dev.txt
+      - name: Build hivemind
+        run: |
+          pip install .
+      - name: Benchmark
+        run: |
+          cd benchmarks
+          python benchmark_throughput.py --preset minimalistic
+          python benchmark_tensor_compression.py
+          python benchmark_dht.py

+ 5 - 5
.github/workflows/run-tests.yml

@@ -1,7 +1,9 @@
 name: Tests
 
-on: [ push ]
-
+on:
+  push:
+    branches: [ master ]
+  pull_request:
 
 jobs:
   run_tests:
@@ -34,7 +36,6 @@ jobs:
         run: |
           cd tests
           pytest --durations=0 --durations-min=1.0 -v
-
   build_and_test_p2pd:
     runs-on: ubuntu-latest
     timeout-minutes: 10
@@ -61,7 +62,6 @@ jobs:
         run: |
           cd tests
           pytest -k "p2p" -v
-
   codecov_in_develop_mode:
 
     runs-on: ubuntu-latest
@@ -84,7 +84,7 @@ jobs:
           pip install -r requirements-dev.txt
       - name: Build hivemind
         run: |
-          pip install -e .
+          pip install -e . --no-use-pep517
       - name: Test
         run: |
           pytest --cov=hivemind -v tests

+ 1 - 0
.gitignore

@@ -54,6 +54,7 @@ coverage.xml
 .project
 .pydevproject
 .idea
+.vscode
 .ipynb_checkpoints
 
 # Rope

+ 46 - 42
README.md

@@ -1,7 +1,7 @@
 ## Hivemind: decentralized deep learning in PyTorch
 
 [![Documentation Status](https://readthedocs.org/projects/learning-at-home/badge/?version=latest)](https://learning-at-home.readthedocs.io/en/latest/?badge=latest)
-[![PyPI version](https://img.shields.io/pypi/v/hivemind.svg)](https://pypi.org/project/hivemind/)
+[![PyPI version](https://img.shields.io/pypi/v/hivemind.svg?color=blue)](https://pypi.org/project/hivemind/)
 [![Discord](https://img.shields.io/static/v1?style=default&label=Discord&logo=discord&message=join)](https://discord.gg/uGugx9zYvN)
 [![CI status](https://github.com/learning-at-home/hivemind/actions/workflows/run-tests.yml/badge.svg?branch=master)](https://github.com/learning-at-home/hivemind/actions)
 ![Codecov](https://img.shields.io/codecov/c/github/learning-at-home/hivemind)
@@ -12,6 +12,10 @@ large model on hundreds of computers from different universities, companies, and
 
 ![img](https://i.imgur.com/GPxolxb.gif)
 
+## Live Demo
+
+Check out our NeurIPS 2021 demonstration ["Training Transformers Together"](https://training-transformers-together.github.io/) to see hivemind in action, join an ongoing collaborative experiment, and learn more about the technologies behind it!
+
 ## Key Features
 
 * Distributed training without a master node: Distributed Hash Table allows connecting computers in a decentralized
@@ -23,8 +27,8 @@ large model on hundreds of computers from different universities, companies, and
 * Train neural networks of arbitrary size: parts of their layers are distributed across the participants with the
   Decentralized Mixture-of-Experts ([paper](https://arxiv.org/abs/2002.04013)).
 
-To learn more about the ideas behind this library, see https://learning-at-home.github.io or read
-the [NeurIPS 2020 paper](https://arxiv.org/abs/2002.04013).
+To learn more about the ideas behind this library,
+see the [full list](https://github.com/learning-at-home/hivemind/tree/refer-to-discord-in-docs#citation) of our papers below.
 
 ## Installation
 
@@ -52,7 +56,7 @@ cd hivemind
 pip install .
 ```
 
-If you would like to verify that your installation is working properly, you can install with `pip install -e .[dev]`
+If you would like to verify that your installation is working properly, you can install with `pip install .[dev]`
 instead. Then, you can run the tests with `pytest tests/`.
 
 By default, hivemind uses the precompiled binary of
@@ -65,8 +69,8 @@ of [Go toolchain](https://golang.org/doc/install) (1.15 or higher).
 
 - __Linux__ is the default OS for which hivemind is developed and tested. We recommend Ubuntu 18.04+ (64-bit), but
   other 64-bit distros should work as well. Legacy 32-bit is not recommended.
-- __macOS 10.x__ mostly works but requires building hivemind from source, and some edge cases may fail. To ensure full
-  compatibility, we recommend using [our Docker image](https://hub.docker.com/r/learningathome/hivemind).
+- __macOS 10.x__ can run hivemind using [Docker](https://docs.docker.com/desktop/mac/install/).
+  We recommend using [our Docker image](https://hub.docker.com/r/learningathome/hivemind).
 - __Windows 10+ (experimental)__ can run hivemind
   using [WSL](https://docs.microsoft.com/ru-ru/windows/wsl/install-win10). You can configure WSL to use GPU by
   following sections 1–3 of [this guide](https://docs.nvidia.com/cuda/wsl-user-guide/index.html) by NVIDIA. After
@@ -83,17 +87,17 @@ of [Go toolchain](https://golang.org/doc/install) (1.15 or higher).
 * API reference and additional tutorials are available
   at [learning-at-home.readthedocs.io](https://learning-at-home.readthedocs.io)
 
-If you have any questions about installing and using hivemind, you can ask them in
+If you have any questions about installing and using hivemind, feel free to ask them in
 [our Discord chat](https://discord.gg/uGugx9zYvN) or file an [issue](https://github.com/learning-at-home/hivemind/issues).
 
 ## Contributing
 
 Hivemind is currently at the active development stage, and we welcome all contributions. Everything, from bug fixes and
-documentation improvements to entirely new features, is equally appreciated.
+documentation improvements to entirely new features, is appreciated.
 
 If you want to contribute to hivemind but don't know where to start, take a look at the
 unresolved [issues](https://github.com/learning-at-home/hivemind/issues). Open a new issue or
-join [our chat room](https://discord.gg/xC7ucM8j) in case you want to discuss new functionality or report a possible
+join [our chat room](https://discord.gg/uGugx9zYvN) in case you want to discuss new functionality or report a possible
 bug. Bug fixes are always welcome, but new features should be preferably discussed with maintainers beforehand.
 
 If you want to start contributing to the source code of hivemind, please see
@@ -105,9 +109,9 @@ our [guide](https://learning-at-home.readthedocs.io/en/latest/user/contributing.
 
 If you found hivemind or its underlying algorithms useful for your research, please cite the following source:
 
-```
+```bibtex
 @misc{hivemind,
-  author = {Learning@home team},
+  author = {Learning{@}home team},
   title = {{H}ivemind: a {L}ibrary for {D}ecentralized {D}eep {L}earning},
   year = 2020,
   howpublished = {\url{https://github.com/learning-at-home/hivemind}},
@@ -118,17 +122,17 @@ Also, you can cite [the paper](https://arxiv.org/abs/2002.04013) that inspired t
 (prototype implementation of hivemind available
 at [mryab/learning-at-home](https://github.com/mryab/learning-at-home)):
 
-```
+```bibtex
 @inproceedings{ryabinin2020crowdsourced,
- author = {Ryabinin, Max and Gusev, Anton},
- booktitle = {Advances in Neural Information Processing Systems},
- editor = {H. Larochelle and M. Ranzato and R. Hadsell and M. F. Balcan and H. Lin},
- pages = {3659--3672},
- publisher = {Curran Associates, Inc.},
- title = {Towards Crowdsourced Training of Large Neural Networks using Decentralized Mixture-of-Experts},
- url = {https://proceedings.neurips.cc/paper/2020/file/25ddc0f8c9d3e22e03d3076f98d83cb2-Paper.pdf},
- volume = {33},
- year = {2020}
+  author = {Ryabinin, Max and Gusev, Anton},
+  booktitle = {Advances in Neural Information Processing Systems},
+  editor = {H. Larochelle and M. Ranzato and R. Hadsell and M. F. Balcan and H. Lin},
+  pages = {3659--3672},
+  publisher = {Curran Associates, Inc.},
+  title = {Towards Crowdsourced Training of Large Neural Networks using Decentralized Mixture-of-Experts},
+  url = {https://proceedings.neurips.cc/paper/2020/file/25ddc0f8c9d3e22e03d3076f98d83cb2-Paper.pdf},
+  volume = {33},
+  year = {2020}
 }
 ```
 
@@ -137,40 +141,40 @@ at [mryab/learning-at-home](https://github.com/mryab/learning-at-home)):
 
 ["Moshpit SGD: Communication-Efficient Decentralized Training on Heterogeneous Unreliable Devices"](https://arxiv.org/abs/2103.03239)
 
-```
+```bibtex
 @misc{ryabinin2021moshpit,
-      title={Moshpit SGD: Communication-Efficient Decentralized Training on Heterogeneous Unreliable Devices}, 
-      author={Max Ryabinin and Eduard Gorbunov and Vsevolod Plokhotnyuk and Gennady Pekhimenko},
-      year={2021},
-      eprint={2103.03239},
-      archivePrefix={arXiv},
-      primaryClass={cs.LG}
+  title = {Moshpit SGD: Communication-Efficient Decentralized Training on Heterogeneous Unreliable Devices},
+  author = {Max Ryabinin and Eduard Gorbunov and Vsevolod Plokhotnyuk and Gennady Pekhimenko},
+  year = {2021},
+  eprint = {2103.03239},
+  archivePrefix = {arXiv},
+  primaryClass = {cs.LG}
 }
 ```
 
 ["Distributed Deep Learning in Open Collaborations"](https://arxiv.org/abs/2106.10207)
 
-```
+```bibtex
 @misc{diskin2021distributed,
-      title={Distributed Deep Learning in Open Collaborations}, 
-      author={Michael Diskin and Alexey Bukhtiyarov and Max Ryabinin and Lucile Saulnier and Quentin Lhoest and Anton Sinitsin and Dmitry Popov and Dmitry Pyrkin and Maxim Kashirin and Alexander Borzunov and Albert Villanova del Moral and Denis Mazur and Ilia Kobelev and Yacine Jernite and Thomas Wolf and Gennady Pekhimenko},
-      year={2021},
-      eprint={2106.10207},
-      archivePrefix={arXiv},
-      primaryClass={cs.LG}
+  title = {Distributed Deep Learning in Open Collaborations},
+  author = {Michael Diskin and Alexey Bukhtiyarov and Max Ryabinin and Lucile Saulnier and Quentin Lhoest and Anton Sinitsin and Dmitry Popov and Dmitry Pyrkin and Maxim Kashirin and Alexander Borzunov and Albert Villanova del Moral and Denis Mazur and Ilia Kobelev and Yacine Jernite and Thomas Wolf and Gennady Pekhimenko},
+  year = {2021},
+  eprint = {2106.10207},
+  archivePrefix = {arXiv},
+  primaryClass = {cs.LG}
 }
 ```
 
 ["Secure Distributed Training at Scale"](https://arxiv.org/abs/2106.11257)
 
-```
+```bibtex
 @misc{gorbunov2021secure,
-      title={Secure Distributed Training at Scale}, 
-      author={Eduard Gorbunov and Alexander Borzunov and Michael Diskin and Max Ryabinin},
-      year={2021},
-      eprint={2106.11257},
-      archivePrefix={arXiv},
-      primaryClass={cs.LG}
+  title = {Secure Distributed Training at Scale},
+  author = {Eduard Gorbunov and Alexander Borzunov and Michael Diskin and Max Ryabinin},
+  year = {2021},
+  eprint = {2106.11257},
+  archivePrefix = {arXiv},
+  primaryClass = {cs.LG}
 }
 ```
 

+ 5 - 2
benchmarks/benchmark_averaging.py

@@ -7,8 +7,11 @@ import torch
 
 import hivemind
 from hivemind.proto import runtime_pb2
-from hivemind.utils import LOCALHOST, get_logger, increase_file_limit
+from hivemind.utils.limits import increase_file_limit
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from hivemind.utils.networking import LOCALHOST
 
+use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
 
 
@@ -80,7 +83,7 @@ def benchmark_averaging(
             with lock_stats:
                 successful_steps += int(success)
                 total_steps += 1
-            logger.info(f"Averager {index}: {'finished' if success else 'failed'} step {step}")
+            logger.info(f"Averager {index}: {'finished' if success else 'failed'} step #{step}")
         logger.info(f"Averager {index}: done.")
 
     threads = []

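The benchmarks in this commit all switch to the same explicit logging setup instead of importing helpers from the top-level package. A minimal sketch of that pattern, using only the imports shown in the diff above:

```python
# minimal sketch of the shared logging setup (module paths as in the diff above)
from hivemind.utils.logging import get_logger, use_hivemind_log_handler

# route hivemind's records through the root logger so that benchmark output
# and library output share one handler and format
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__name__)

logger.info("logging configured")
```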
+ 181 - 78
benchmarks/benchmark_dht.py

@@ -1,33 +1,132 @@
 import argparse
+import asyncio
 import random
 import time
+import uuid
+from logging import shutdown
+from typing import Tuple
 
+import numpy as np
 from tqdm import trange
 
 import hivemind
-from hivemind.moe.server import declare_experts, get_experts
 from hivemind.utils.limits import increase_file_limit
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
-logger = hivemind.get_logger(__name__)
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__name__)
 
 
-def random_endpoint() -> hivemind.Endpoint:
-    return (
-        f"{random.randint(0, 256)}.{random.randint(0, 256)}.{random.randint(0, 256)}."
-        f"{random.randint(0, 256)}:{random.randint(0, 65535)}"
-    )
+class NodeKiller:
+    """Auxiliary class that kills dht nodes over a pre-defined schedule"""
+
+    def __init__(self, shutdown_peers: list, shutdown_timestamps: list):
+        self.shutdown_peers = set(shutdown_peers)
+        self.shutdown_timestamps = shutdown_timestamps
+        self.current_iter = 0
+        self.timestamp_iter = 0
+        self.lock = asyncio.Lock()
+
+    async def check_and_kill(self):
+        async with self.lock:
+            if (
+                self.shutdown_timestamps != None
+                and self.timestamp_iter < len(self.shutdown_timestamps)
+                and self.current_iter == self.shutdown_timestamps[self.timestamp_iter]
+            ):
+                shutdown_peer = random.sample(self.shutdown_peers, 1)[0]
+                shutdown_peer.shutdown()
+                self.shutdown_peers.remove(shutdown_peer)
+                self.timestamp_iter += 1
+            self.current_iter += 1
+
+
+async def store_and_get_task(
+    peers: list,
+    total_num_rounds: int,
+    num_store_peers: int,
+    num_get_peers: int,
+    wait_after_iteration: float,
+    delay: float,
+    expiration: float,
+    latest: bool,
+    node_killer: NodeKiller,
+) -> Tuple[list, list, list, list, int, int]:
+    """Iteratively choose random peers to store data onto the dht, then retreive with another random subset of peers"""
+
+    total_stores = total_gets = 0
+    successful_stores = []
+    successful_gets = []
+    store_times = []
+    get_times = []
+
+    for _ in range(total_num_rounds):
+        key = uuid.uuid4().hex
+
+        store_start = time.perf_counter()
+        store_peers = random.sample(peers, min(num_store_peers, len(peers)))
+        store_subkeys = [uuid.uuid4().hex for _ in store_peers]
+        store_values = {subkey: uuid.uuid4().hex for subkey in store_subkeys}
+        store_tasks = [
+            peer.store(
+                key,
+                subkey=subkey,
+                value=store_values[subkey],
+                expiration_time=hivemind.get_dht_time() + expiration,
+                return_future=True,
+            )
+            for peer, subkey in zip(store_peers, store_subkeys)
+        ]
+        store_result = await asyncio.gather(*store_tasks)
+        await node_killer.check_and_kill()
+
+        store_times.append(time.perf_counter() - store_start)
+
+        total_stores += len(store_result)
+        successful_stores_per_iter = sum(store_result)
+        successful_stores.append(successful_stores_per_iter)
+        await asyncio.sleep(delay)
+
+        get_start = time.perf_counter()
+        get_peers = random.sample(peers, min(num_get_peers, len(peers)))
+        get_tasks = [peer.get(key, latest, return_future=True) for peer in get_peers]
+        get_result = await asyncio.gather(*get_tasks)
+        get_times.append(time.perf_counter() - get_start)
+
+        successful_gets_per_iter = 0
+
+        total_gets += len(get_result)
+        for result in get_result:
+            if result != None:
+                attendees, expiration = result
+                if len(attendees.keys()) == successful_stores_per_iter:
+                    get_ok = True
+                    for key in attendees:
+                        if attendees[key][0] != store_values[key]:
+                            get_ok = False
+                            break
+                    successful_gets_per_iter += get_ok
 
+        successful_gets.append(successful_gets_per_iter)
+        await asyncio.sleep(wait_after_iteration)
 
-def benchmark_dht(
+    return store_times, get_times, successful_stores, successful_gets, total_stores, total_gets
+
+
+async def benchmark_dht(
     num_peers: int,
     initial_peers: int,
-    num_experts: int,
-    expert_batch_size: int,
     random_seed: int,
-    wait_after_request: float,
-    wait_before_read: float,
+    num_threads: int,
+    total_num_rounds: int,
+    num_store_peers: int,
+    num_get_peers: int,
+    wait_after_iteration: float,
+    delay: float,
     wait_timeout: float,
     expiration: float,
+    latest: bool,
+    failure_rate: float,
 ):
     random.seed(random_seed)
 
@@ -40,88 +139,92 @@ def benchmark_dht(
         peer = hivemind.DHT(initial_peers=neighbors, start=True, wait_timeout=wait_timeout)
         peers.append(peer)
 
-    store_peer, get_peer = peers[-2:]
-
-    expert_uids = list(
-        set(
-            f"expert.{random.randint(0, 999)}.{random.randint(0, 999)}.{random.randint(0, 999)}"
-            for _ in range(num_experts)
-        )
-    )
-    logger.info(f"Sampled {len(expert_uids)} unique ids (after deduplication)")
-    random.shuffle(expert_uids)
-
-    logger.info(f"Storing experts to dht in batches of {expert_batch_size}...")
-    successful_stores = total_stores = total_store_time = 0
     benchmark_started = time.perf_counter()
-    endpoints = []
-
-    for start in trange(0, num_experts, expert_batch_size):
-        store_start = time.perf_counter()
-        endpoints.append(random_endpoint())
-        store_ok = declare_experts(
-            store_peer, expert_uids[start : start + expert_batch_size], endpoints[-1], expiration=expiration
+    logger.info("Creating store and get tasks...")
+    shutdown_peers = random.sample(peers, min(int(failure_rate * num_peers), num_peers))
+    assert len(shutdown_peers) != len(peers)
+    remaining_peers = list(set(peers) - set(shutdown_peers))
+    shutdown_timestamps = random.sample(
+        range(0, num_threads * total_num_rounds), min(len(shutdown_peers), num_threads * total_num_rounds)
+    )
+    shutdown_timestamps.sort()
+    node_killer = NodeKiller(shutdown_peers, shutdown_timestamps)
+    task_list = [
+        asyncio.create_task(
+            store_and_get_task(
+                remaining_peers,
+                total_num_rounds,
+                num_store_peers,
+                num_get_peers,
+                wait_after_iteration,
+                delay,
+                expiration,
+                latest,
+                node_killer,
+            )
         )
-        successes = store_ok.values()
-        total_store_time += time.perf_counter() - store_start
-
-        total_stores += len(successes)
-        successful_stores += sum(successes)
-        time.sleep(wait_after_request)
+        for _ in trange(num_threads)
+    ]
+
+    store_and_get_result = await asyncio.gather(*task_list)
+    benchmark_total_time = time.perf_counter() - benchmark_started
+    total_store_times = []
+    total_get_times = []
+    total_successful_stores = []
+    total_successful_gets = []
+    total_stores = total_gets = 0
+    for result in store_and_get_result:
+        store_times, get_times, successful_stores, successful_gets, stores, gets = result
+
+        total_store_times.extend(store_times)
+        total_get_times.extend(get_times)
+        total_successful_stores.extend(successful_stores)
+        total_successful_gets.extend(successful_gets)
+        total_stores += stores
+        total_gets += gets
 
+    alive_peers = [peer.is_alive() for peer in peers]
     logger.info(
-        f"Store success rate: {successful_stores / total_stores * 100:.1f}% ({successful_stores} / {total_stores})"
+        f"Store wall time (sec.): mean({np.mean(total_store_times):.3f}) "
+        + f"std({np.std(total_store_times, ddof=1):.3f}) max({np.max(total_store_times):.3f})"
     )
-    logger.info(f"Mean store time: {total_store_time / total_stores:.5}, Total: {total_store_time:.5}")
-    time.sleep(wait_before_read)
-
-    if time.perf_counter() - benchmark_started > expiration:
-        logger.warning("All keys expired before benchmark started getting them. Consider increasing expiration_time")
-
-    successful_gets = total_get_time = 0
-
-    for start in trange(0, len(expert_uids), expert_batch_size):
-        get_start = time.perf_counter()
-        get_result = get_experts(get_peer, expert_uids[start : start + expert_batch_size])
-        total_get_time += time.perf_counter() - get_start
-
-        for i, expert in enumerate(get_result):
-            if (
-                expert is not None
-                and expert.uid == expert_uids[start + i]
-                and expert.endpoint == endpoints[start // expert_batch_size]
-            ):
-                successful_gets += 1
-
-    if time.perf_counter() - benchmark_started > expiration:
-        logger.warning(
-            "keys expired midway during get requests. If that isn't desired, increase expiration_time param"
-        )
-
     logger.info(
-        f"Get success rate: {successful_gets / len(expert_uids) * 100:.1f} ({successful_gets} / {len(expert_uids)})"
+        f"Get wall time (sec.): mean({np.mean(total_get_times):.3f}) "
+        + f"std({np.std(total_get_times, ddof=1):.3f}) max({np.max(total_get_times):.3f})"
+    )
+    logger.info(f"Average store time per worker: {sum(total_store_times) / num_threads:.3f} sec.")
+    logger.info(f"Average get time per worker: {sum(total_get_times) / num_threads:.3f} sec.")
+    logger.info(f"Total benchmark time: {benchmark_total_time:.5f} sec.")
+    logger.info(
+        "Store success rate: "
+        + f"{sum(total_successful_stores) / total_stores * 100:.1f}% ({sum(total_successful_stores)}/{total_stores})"
+    )
+    logger.info(
+        "Get success rate: "
+        + f"{sum(total_successful_gets) / total_gets * 100:.1f}% ({sum(total_successful_gets)}/{total_gets})"
     )
-    logger.info(f"Mean get time: {total_get_time / len(expert_uids):.5f}, Total: {total_get_time:.5f}")
-
-    alive_peers = [peer.is_alive() for peer in peers]
     logger.info(f"Node survival rate: {len(alive_peers) / len(peers) * 100:.3f}%")
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--num_peers", type=int, default=32, required=False)
-    parser.add_argument("--initial_peers", type=int, default=1, required=False)
-    parser.add_argument("--num_experts", type=int, default=256, required=False)
-    parser.add_argument("--expert_batch_size", type=int, default=32, required=False)
-    parser.add_argument("--expiration", type=float, default=300, required=False)
-    parser.add_argument("--wait_after_request", type=float, default=0, required=False)
-    parser.add_argument("--wait_before_read", type=float, default=0, required=False)
+    parser.add_argument("--num_peers", type=int, default=16, required=False)
+    parser.add_argument("--initial_peers", type=int, default=4, required=False)
+    parser.add_argument("--random_seed", type=int, default=30, required=False)
+    parser.add_argument("--num_threads", type=int, default=10, required=False)
+    parser.add_argument("--total_num_rounds", type=int, default=16, required=False)
+    parser.add_argument("--num_store_peers", type=int, default=8, required=False)
+    parser.add_argument("--num_get_peers", type=int, default=8, required=False)
+    parser.add_argument("--wait_after_iteration", type=float, default=0, required=False)
+    parser.add_argument("--delay", type=float, default=0, required=False)
     parser.add_argument("--wait_timeout", type=float, default=5, required=False)
-    parser.add_argument("--random_seed", type=int, default=random.randint(1, 1000))
+    parser.add_argument("--expiration", type=float, default=300, required=False)
+    parser.add_argument("--latest", type=bool, default=True, required=False)
+    parser.add_argument("--failure_rate", type=float, default=0.1, required=False)
     parser.add_argument("--increase_file_limit", action="store_true")
     args = vars(parser.parse_args())
 
     if args.pop("increase_file_limit", False):
         increase_file_limit()
 
-    benchmark_dht(**args)
+    asyncio.run(benchmark_dht(**args))

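At its core, the rewritten benchmark repeats the same asynchronous store/get round across many coroutines and peers. A condensed sketch of that round, using only the `DHT` calls that appear in the diff above (the two-peer setup, key, and subkey names are illustrative, not part of the benchmark):

```python
import asyncio

import hivemind

# two local peers; the benchmark spawns --num_peers of these
first = hivemind.DHT(start=True)
second = hivemind.DHT(initial_peers=first.get_visible_maddrs(), start=True)


async def one_round() -> None:
    # store a subkey through one peer; return_future=True yields an awaitable future
    store_ok = await first.store(
        "benchmark_key",
        subkey="subkey_0",
        value="some_value",
        expiration_time=hivemind.get_dht_time() + 300,
        return_future=True,
    )

    # read it back through the other peer; latest=True requests a fresh lookup
    result = await second.get("benchmark_key", latest=True, return_future=True)
    if result is not None:
        subkeys, expiration = result
        # each subkey maps to (value, expiration_time); index 0 is the value, as in the benchmark
        print(store_ok, subkeys["subkey_0"][0])


asyncio.run(one_round())
first.shutdown()
second.shutdown()
```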
+ 162 - 0
benchmarks/benchmark_optimizer.py

@@ -0,0 +1,162 @@
+import multiprocessing as mp
+import random
+import time
+from contextlib import nullcontext
+from dataclasses import dataclass
+from functools import partial
+from typing import Callable
+
+import torch
+import torchvision
+from torch import nn as nn
+from torch.nn import functional as F
+from torch.utils.data import Dataset
+
+import hivemind
+from hivemind.optim.optimizer import Optimizer
+from hivemind.utils.crypto import RSAPrivateKey
+
+
+@dataclass(frozen=True)
+class TrainingArguments:
+    seed: int = 42
+    run_id: str = "my_exp"
+
+    num_peers: int = 8
+    num_clients: int = 3
+    target_batch_size: int = 256
+    reuse_grad_buffers: bool = True
+    delay_grad_averaging: bool = True
+    delay_optimizer_step: bool = True
+    average_state_every: int = 1
+    use_amp: bool = False
+
+    lr_base: float = 0.1
+    lr_gamma: int = 0.1
+    lr_step_size: int = 10
+    max_epoch: int = 25
+
+    batch_size_min: int = 2
+    batch_size_max: int = 16
+    batch_time_min: float = 1.0
+    batch_time_max: float = 4.5
+    batch_time_std: float = 0.5
+
+    matchmaking_time: float = 5.0
+    max_refresh_period: float = 5.0
+    averaging_timeout: float = 15.0
+    winddown_time: float = 5.0
+    verbose: bool = True
+
+    device: str = "cpu"
+    make_dataset: Callable[[], Dataset] = lambda: torchvision.datasets.MNIST(train=True, root=".", download=True)
+    make_model: Callable[[int, int], nn.Module] = lambda num_features, num_classes: nn.Sequential(
+        nn.Linear(num_features, 64), nn.ReLU(), nn.Linear(64, num_classes)
+    )
+
+
+def benchmark_optimizer(args: TrainingArguments):
+    random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    torch.set_num_threads(1)
+
+    dht = hivemind.DHT(start=True)
+
+    train_dataset = args.make_dataset()
+    num_features = train_dataset.data[0].numel()
+    num_classes = len(train_dataset.classes)
+    X_train = torch.as_tensor(train_dataset.data, dtype=torch.float32)
+    X_train = X_train.sub_(X_train.mean((0, 1, 2))).div_(X_train.std((0, 1, 2))).reshape((-1, num_features))
+    y_train = torch.as_tensor(train_dataset.targets, dtype=torch.int64)
+    del train_dataset
+
+    def run_trainer(batch_size: int, batch_time: float, client_mode: bool, verbose: bool):
+        model = args.make_model(num_features, num_classes).to(args.device)
+
+        assert isinstance(model, torch.nn.Module), "model_arch must evaluate to a pytorch module"
+
+        optimizer = Optimizer(
+            run_id=args.run_id,
+            target_batch_size=args.target_batch_size,
+            batch_size_per_step=batch_size,
+            params=model.parameters(),
+            optimizer=partial(torch.optim.SGD, lr=args.lr_base),
+            scheduler=partial(torch.optim.lr_scheduler.StepLR, gamma=args.lr_gamma, step_size=args.lr_step_size),
+            dht=hivemind.DHT(initial_peers=dht.get_visible_maddrs(), client_mode=client_mode, start=True),
+            tracker_opts=dict(private_key=RSAPrivateKey(), max_refresh_period=args.max_refresh_period),
+            matchmaking_time=args.matchmaking_time,
+            averaging_timeout=args.averaging_timeout,
+            reuse_grad_buffers=args.reuse_grad_buffers,
+            delay_grad_averaging=args.delay_grad_averaging,
+            delay_optimizer_step=args.delay_optimizer_step,
+            average_state_every=args.average_state_every,
+            client_mode=client_mode,
+            verbose=verbose,
+        )
+
+        if args.use_amp and args.reuse_grad_buffers:
+            grad_scaler = hivemind.GradScaler()
+        else:
+            # check that hivemind.Optimizer supports regular PyTorch grad scaler as well
+            grad_scaler = torch.cuda.amp.GradScaler(enabled=args.use_amp)
+
+        prev_time = time.perf_counter()
+
+        while optimizer.local_epoch < args.max_epoch:
+            time.sleep(max(0.0, prev_time + random.gauss(batch_time, args.batch_time_std) - time.perf_counter()))
+
+            batch = torch.randint(0, len(X_train), (batch_size,))
+
+            with torch.cuda.amp.autocast() if args.use_amp else nullcontext():
+                loss = F.cross_entropy(model(X_train[batch].to(args.device)), y_train[batch].to(args.device))
+                grad_scaler.scale(loss).backward()
+
+            grad_scaler.unscale_(optimizer)
+
+            if args.use_amp:
+                grad_scaler.step(optimizer)
+            else:
+                optimizer.step()
+
+            grad_scaler.update()
+
+            if not args.reuse_grad_buffers:
+                optimizer.zero_grad()
+
+            prev_time = time.perf_counter()
+
+        time.sleep(args.winddown_time)
+        optimizer.shutdown()
+
+    peers = []
+
+    for index in range(args.num_peers):
+        batch_size = random.randint(args.batch_size_min, args.batch_size_max)
+        batch_time = random.uniform(args.batch_time_min, args.batch_time_max)
+        peers.append(
+            mp.Process(
+                target=run_trainer,
+                name=f"trainer-{index}",
+                daemon=False,
+                kwargs=dict(
+                    batch_size=batch_size,
+                    batch_time=batch_time,
+                    client_mode=(index >= args.num_peers - args.num_clients),
+                    verbose=args.verbose and (index == 0),
+                ),
+            )
+        )
+
+    try:
+        for peer in peers[1:]:
+            peer.start()
+        peers[0].run()
+        for peer in peers[1:]:
+            peer.join()
+    finally:
+        for peer in peers[1:]:
+            peer.kill()
+
+
+if __name__ == "__main__":
+    benchmark_optimizer(TrainingArguments())

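All knobs of this benchmark live in the frozen `TrainingArguments` dataclass, so alternative setups can be expressed by overriding fields at construction time. A hypothetical smoke-test invocation, assuming the script above is importable as `benchmark_optimizer`:

```python
# hypothetical smaller configuration for a quick local run; field names
# are taken from the TrainingArguments dataclass above
from benchmark_optimizer import TrainingArguments, benchmark_optimizer

if __name__ == "__main__":
    benchmark_optimizer(
        TrainingArguments(
            num_peers=4,        # fewer trainer processes than the default 8
            num_clients=1,      # one of them runs in client mode
            target_batch_size=128,
            max_epoch=5,
            verbose=True,
        )
    )
```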
+ 3 - 2
benchmarks/benchmark_tensor_compression.py

@@ -3,10 +3,11 @@ import time
 
 import torch
 
+from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
-from hivemind.utils.logging import get_logger
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
+use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
 
 

+ 16 - 15
benchmarks/benchmark_throughput.py

@@ -6,12 +6,15 @@ import time
 
 import torch
 
-import hivemind
-from hivemind import find_open_port
-from hivemind.moe.server import layers
+from hivemind.moe.client import RemoteExpert
+from hivemind.moe.server import ExpertBackend, Server
+from hivemind.moe.server.layers import name_to_block
 from hivemind.utils.limits import increase_file_limit
-from hivemind.utils.logging import get_logger
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from hivemind.utils.networking import LOCALHOST, get_free_port
+from hivemind.utils.tensor_descr import BatchTensorDescriptor
 
+use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
 
 
@@ -31,9 +34,7 @@ def print_device_info(device=None):
 def client_process(can_start, benchmarking_failed, port, num_experts, batch_size, hid_dim, num_batches, backprop=True):
     torch.set_num_threads(1)
     can_start.wait()
-    experts = [
-        hivemind.RemoteExpert(f"expert{i}", endpoint=f"{hivemind.LOCALHOST}:{port}") for i in range(num_experts)
-    ]
+    experts = [RemoteExpert(f"expert{i}", endpoint=f"{LOCALHOST}:{port}") for i in range(num_experts)]
 
     try:
         dummy_batch = torch.randn(batch_size, hid_dim)
@@ -65,8 +66,8 @@ def benchmark_throughput(
         or not torch.cuda.is_initialized()
         or torch.device(device) == torch.device("cpu")
     )
-    assert expert_cls in layers.name_to_block
-    port = port or find_open_port()
+    assert expert_cls in name_to_block
+    port = port or get_free_port()
     max_batch_size = max_batch_size or batch_size * 4
     num_handlers = max(1, num_handlers or num_clients // 2)
     benchmarking_failed = mp.Event()
@@ -104,20 +105,20 @@ def benchmark_throughput(
         device = device or ("cuda" if torch.cuda.is_available() else "cpu")
         experts = {}
         for i in range(num_experts):
-            expert = torch.jit.script(layers.name_to_block[expert_cls](hid_dim))
-            experts[f"expert{i}"] = hivemind.ExpertBackend(
+            expert = torch.jit.script(name_to_block[expert_cls](hid_dim))
+            experts[f"expert{i}"] = ExpertBackend(
                 name=f"expert{i}",
                 expert=expert,
                 optimizer=torch.optim.Adam(expert.parameters()),
-                args_schema=(hivemind.BatchTensorDescriptor(hid_dim),),
-                outputs_schema=hivemind.BatchTensorDescriptor(hid_dim),
+                args_schema=(BatchTensorDescriptor(hid_dim),),
+                outputs_schema=BatchTensorDescriptor(hid_dim),
                 max_batch_size=max_batch_size,
             )
         timestamps["created_experts"] = time.perf_counter()
-        server = hivemind.moe.Server(
+        server = Server(
             None,
             experts,
-            listen_on=f"{hivemind.LOCALHOST}:{port}",
+            listen_on=f"{LOCALHOST}:{port}",
             num_connection_handlers=num_handlers,
             device=device,
         )

BIN
docs/_static/dht.odp


BIN
docs/_static/dht.png


+ 1 - 1
docs/conf.py

@@ -203,7 +203,7 @@ todo_include_todos = True
 
 
 def setup(app):
-    app.add_stylesheet("fix_rtd.css")
+    app.add_css_file("fix_rtd.css")
     app.add_config_value(
         "recommonmark_config",
         {

+ 3 - 3
docs/index.rst

@@ -9,9 +9,9 @@ of computers, whether you're running a very capable computer or a less reliable
 Learn how to create or join a Hivemind run in the `quickstart tutorial <./user/quickstart.html>`__ or browse the API
 documentation below.
 
-| Hivemind is currently in active development, so expect some adventures. If you encounter any issues, please let us know
-  `on github <https://github.com/learning-at-home/hivemind/issues>`__.
-
+| Hivemind is currently in active development, so expect some adventures. If you have any questions, feel free to ask them
+  in `our Discord chat <https://discord.com/invite/uGugx9zYvN>`_ or
+  file an `issue <https://github.com/learning-at-home/hivemind/issues>`__.
 
 **Table of contents:**
 ~~~~~~~~~~~~~~~~~~~~~~

+ 26 - 4
docs/modules/optim.rst

@@ -1,14 +1,36 @@
 **hivemind.optim**
 ==================
 
-.. automodule:: hivemind.optim
-.. currentmodule:: hivemind.optim
-
 .. raw:: html
 
-  This module contains decentralized optimizers that wrap regular pytorch optimizers to collaboratively train a shared model. Depending on the exact type, optimizer may average model parameters with peers, exchange gradients, or follow a more complicated distributed training strategy.
+  This module contains decentralized optimizers that wrap your regular PyTorch Optimizer to train with peers.
+  Depending on the exact configuration, Optimizer may perform large synchronous updates equivalent to training
+  with a larger combined batch size, or perform asynchronous local updates and average model parameters with peers.
+
   <br><br>
 
+.. automodule:: hivemind.optim.optimizer
+.. currentmodule:: hivemind.optim.optimizer
+
+**hivemind.Optimizer**
+----------------------
+
+.. autoclass:: Optimizer
+   :members: step, local_epoch, zero_grad, load_state_from_peers, param_groups, shutdown
+   :member-order: bysource
+
+.. currentmodule:: hivemind.optim.grad_scaler
+.. autoclass:: GradScaler
+   :member-order: bysource
+
+
+**CollaborativeOptimizer**
+--------------------------
+
+
+.. automodule:: hivemind.optim.collaborative
+.. currentmodule:: hivemind.optim
+
 .. autoclass:: CollaborativeOptimizer
    :members: step
    :member-order: bysource

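For context, the `GradScaler` documented here follows the same call sequence as a regular `torch.cuda.amp.GradScaler`, as exercised in `benchmarks/benchmark_optimizer.py` earlier in this commit. A hedged sketch of that loop, assuming a CUDA device is available (the model, learning rate, and `run_id` are placeholders):

```python
from functools import partial

import torch

import hivemind

# placeholder model and a standalone DHT; assumes a CUDA device is available
model = torch.nn.Linear(16, 2).cuda()
dht = hivemind.DHT(start=True)

opt = hivemind.Optimizer(
    dht=dht,
    run_id="amp_demo",
    batch_size_per_step=32,
    target_batch_size=4096,
    params=model.parameters(),
    optimizer=partial(torch.optim.SGD, lr=0.1),
    reuse_grad_buffers=True,  # the benchmark pairs hivemind.GradScaler with reuse_grad_buffers=True
    verbose=True,
)
grad_scaler = hivemind.GradScaler()

for _ in range(100):
    with torch.cuda.amp.autocast():
        loss = model(torch.randn(32, 16, device="cuda")).square().mean()
    grad_scaler.scale(loss).backward()
    grad_scaler.unscale_(opt)  # unscale before stepping, as in the benchmark
    grad_scaler.step(opt)
    grad_scaler.update()
    # no zero_grad() here: with reuse_grad_buffers=True the optimizer manages gradient buffers

opt.shutdown()
dht.shutdown()
```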
+ 8 - 6
docs/modules/server.rst

@@ -9,9 +9,9 @@ or as a part of **hivemind.moe.client.RemoteMixtureOfExperts** that finds the mo
 The hivemind.moe.server module is organized as follows:
 
 - Server_ is the main class that publishes experts, accepts incoming requests, and passes them to Runtime_ for compute.
-- Runtime_ balances the device (GPU) usage between several ExpertBackend_ instances that each service one expert.
 - ExpertBackend_ is a wrapper for `torch.nn.Module <https://pytorch.org/docs/stable/generated/torch.nn.Module.html>`_ \
   that can be accessed by remote clients. It has two TaskPool_ s for forward and backward requests.
+- Runtime_ balances the device (GPU) usage between several ExpertBackend_ instances that each service one expert.
 - TaskPool_ stores incoming requests for a batch-parallel computation (e.g. forward pass), groups them into batches \
   and offers those batches to Runtime_ for processing.
 
@@ -25,16 +25,18 @@ The hivemind.moe.server module is organized as follows:
    :members:
    :member-order: bysource
 
-.. _Runtime:
-.. autoclass:: Runtime
-    :members:
-    :member-order: bysource
-
 .. _ExpertBackend:
 .. autoclass:: ExpertBackend
     :members: forward, backward, apply_gradients, get_info, get_pools
     :member-order: bysource
 
+.. currentmodule:: hivemind.moe.server.runtime
+
+.. _Runtime:
+.. autoclass:: Runtime
+    :members:
+    :member-order: bysource
+
 .. currentmodule:: hivemind.moe.server.task_pool
 
 .. _TaskPool:

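To make the reorganized structure concrete, here is a hedged sketch of wiring one expert into a `Server`, mirroring the calls from `benchmarks/benchmark_throughput.py` in this commit; the `"ffn"` block name, the port handling, and the start/stop helpers are assumptions rather than something this diff shows.

```python
import torch

from hivemind.moe.server import ExpertBackend, Server
from hivemind.moe.server.layers import name_to_block
from hivemind.utils.networking import LOCALHOST, get_free_port
from hivemind.utils.tensor_descr import BatchTensorDescriptor

hid_dim = 64
# "ffn" is assumed to be one of the registered expert blocks in hivemind.moe.server.layers
expert = torch.jit.script(name_to_block["ffn"](hid_dim))

backend = ExpertBackend(
    name="expert0",
    expert=expert,
    optimizer=torch.optim.Adam(expert.parameters()),
    args_schema=(BatchTensorDescriptor(hid_dim),),
    outputs_schema=BatchTensorDescriptor(hid_dim),
    max_batch_size=256,
)

# the first argument is the DHT (None = standalone, as in the throughput benchmark)
server = Server(
    None,
    {"expert0": backend},
    listen_on=f"{LOCALHOST}:{get_free_port()}",
    num_connection_handlers=1,
    device="cpu",
)
server.run_in_background(await_ready=True)  # assumed helper; Server.run() would block instead
server.shutdown()
```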
+ 1 - 1
docs/user/dht.md

@@ -18,7 +18,7 @@ dht2 = DHT(initial_peers=dht.get_visible_maddrs(), start=True)
 ```
 
 Note that `initial_peers` contains the address of the first DHT node.
-This implies that the resulting node will have shared key-value with the first node, __as well as any other
+This implies that the new node will share the key-value data with the first node, __as well as any other
 nodes connected to it.__ When the two nodes are connected, subsequent peers can use any one of them (or both)
 as `initial_peers` to connect to the shared "dictionary".
 
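A minimal sketch of this shared key-value behaviour, based on the standard `DHT.store` / `DHT.get` interface used elsewhere in this commit (the key and value are arbitrary):

```python
import hivemind

dht = hivemind.DHT(start=True)
dht2 = hivemind.DHT(initial_peers=dht.get_visible_maddrs(), start=True)

# a value stored through the first node...
assert dht.store("my_key", "my_value", expiration_time=hivemind.get_dht_time() + 600)

# ...can be read back through the second one, since they share the same key-value space
value, expiration_time = dht2.get("my_key")
print(value)  # "my_value"

dht.shutdown()
dht2.shutdown()
```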

+ 35 - 32
docs/user/quickstart.md

@@ -11,7 +11,7 @@ You can also install the bleeding edge version from GitHub:
 ```
 git clone https://github.com/learning-at-home/hivemind
 cd hivemind
-pip install -e .
+pip install -e . --no-use-pep517
 ```
  
 ## Decentralized Training
@@ -47,26 +47,27 @@ model = nn.Sequential(nn.Conv2d(3, 16, (5, 5)), nn.MaxPool2d(2, 2), nn.ReLU(),
                       nn.Flatten(), nn.Linear(32 * 5 * 5, 10))
 opt = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
 
-
 # Create DHT: a decentralized key-value storage shared between peers
 dht = hivemind.DHT(start=True)
 print("To join the training, use initial_peers =", [str(addr) for addr in dht.get_visible_maddrs()])
 
 # Set up a decentralized optimizer that will average with peers in background
-opt = hivemind.optim.DecentralizedOptimizer(
-    opt,                      # wrap the SGD optimizer defined above
-    dht,                      # use a DHT that is connected with other peers
-    average_parameters=True,  # periodically average model weights in opt.step
-    average_gradients=False,  # do not average accumulated gradients
-    prefix='my_cifar_run',    # unique identifier of this collaborative run
-    target_group_size=16,     # maximum concurrent peers for this run
+opt = hivemind.Optimizer(
+    dht=dht,                  # use a DHT that is connected with other peers
+    run_id='my_cifar_run',    # unique identifier of this collaborative run
+    batch_size_per_step=32,   # each call to opt.step adds this many samples towards the next epoch
+    target_batch_size=10000,  # after peers collectively process this many samples, average weights and begin the next epoch 
+    optimizer=opt,            # wrap the SGD optimizer defined above
+    use_local_updates=True,   # perform optimizer steps with local gradients, average parameters in background
+    matchmaking_time=3.0,     # when averaging parameters, gather peers in background for up to this many seconds
+    averaging_timeout=10.0,   # give up on averaging if not successful in this many seconds
     verbose=True              # print logs incessently
 )
-# Note: if you intend to use GPU, switch to it only after the decentralized optimizer is created
 
+# Note: if you intend to use GPU, switch to it only after the decentralized optimizer is created
 with tqdm() as progressbar:
     while True:
-        for x_batch, y_batch in torch.utils.data.DataLoader(trainset, shuffle=True, batch_size=256):
+        for x_batch, y_batch in torch.utils.data.DataLoader(trainset, shuffle=True, batch_size=32):
             opt.zero_grad()
             loss = F.cross_entropy(model(x_batch), y_batch)
             loss.backward()
@@ -78,7 +79,7 @@ with tqdm() as progressbar:
 
 
 As you can see, this code is regular PyTorch with one notable exception: it wraps your regular optimizer with a
-`DecentralizedOptimizer`. This optimizer uses `DHT` to find other peers and tries to exchange weights them. When you run
+`hivemind.Optimizer`. This optimizer uses `DHT` to find other peers and tries to exchange parameters with them. When you run
 the code (please do so), you will see the following output:
 
 ```shell
@@ -86,7 +87,7 @@ To join the training, use initial_peers = ['/ip4/127.0.0.1/tcp/XXX/p2p/YYY']
 [...] Starting a new averaging round with current parameters.
 ```
 
-This is `DecentralizedOptimizer` telling you that it's looking for peers. Since there are no peers, we'll need to create 
+This is `hivemind.Optimizer` telling you that it's looking for peers. Since there are no peers, we'll need to create 
 them ourselves.
 
 Copy the entire script (or notebook) and modify this line:
@@ -123,26 +124,28 @@ model = nn.Sequential(nn.Conv2d(3, 16, (5, 5)), nn.MaxPool2d(2, 2), nn.ReLU(),
 opt = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
 
 # Create DHT: a decentralized key-value storage shared between peers
-dht = hivemind.DHT(initial_peers=[COPY_FROM_ANOTHER_PEER_OUTPUTS], start=True)
+dht = hivemind.DHT(initial_peers=[COPY_FROM_OTHER_PEERS_OUTPUTS], start=True)
 print("To join the training, use initial_peers =", [str(addr) for addr in dht.get_visible_maddrs()])
 
 # Set up a decentralized optimizer that will average with peers in background
-opt = hivemind.optim.DecentralizedOptimizer(
-    opt,                      # wrap the SGD optimizer defined above
-    dht,                      # use a DHT that is connected with other peers
-    average_parameters=True,  # periodically average model weights in opt.step
-    average_gradients=False,  # do not average accumulated gradients
-    prefix='my_cifar_run',    # unique identifier of this collaborative run
-    target_group_size=16,     # maximum concurrent peers for this run
+opt = hivemind.Optimizer(
+    dht=dht,                  # use a DHT that is connected with other peers
+    run_id='my_cifar_run',    # unique identifier of this collaborative run
+    batch_size_per_step=32,   # each call to opt.step adds this many samples towards the next epoch
+    target_batch_size=10000,  # after peers collectively process this many samples, average weights and begin the next epoch
+    optimizer=opt,            # wrap the SGD optimizer defined above
+    use_local_updates=True,   # perform optimizer steps with local gradients, average parameters in background
+    matchmaking_time=3.0,     # when averaging parameters, gather peers in background for up to this many seconds
+    averaging_timeout=10.0,   # give up on averaging if not successful in this many seconds
     verbose=True              # print logs incessently
 )
 
-opt.averager.load_state_from_peers()
+opt.load_state_from_peers()
 
-# Note: if you intend to use GPU, switch to it only after the decentralized optimizer is created
+# Note: if you intend to use GPU, switch to it only after the optimizer is created
 with tqdm() as progressbar:
     while True:
-        for x_batch, y_batch in torch.utils.data.DataLoader(trainset, shuffle=True, batch_size=256):
+        for x_batch, y_batch in torch.utils.data.DataLoader(trainset, shuffle=True, batch_size=32):
             opt.zero_grad()
             loss = F.cross_entropy(model(x_batch), y_batch)
             loss.backward()
@@ -166,22 +169,22 @@ This message means that the optimizer has averaged model parameters with another
 during one of the calls to `opt.step()`. You can start more peers by replicating the same code as the second peer,
 using either the first or second peer as `initial_peers`.
 
-The only issue with this code is that each new peer starts with a different untrained network blends its un-trained
-parameters with other peers, reseting their progress. You can see this effect as a spike increase in training loss
-immediately after new peer joins training. To avoid this problem, the second peer can download the
-current model/optimizer state from an existing peer right before it begins training on minibatches:
+Each new peer starts with an untrained network and must download the latest training state before it can contribute.
+By default, a peer will automatically detect that it is out of sync and start ``Downloading parameters from peer <...>``.
+To avoid wasting the first optimizer step, one can manually download the latest model/optimizer state right before it begins training on minibatches:
 ```python
-opt.averager.load_state_from_peers()
+opt.load_state_from_peers()
 ```
 
 Congrats, you've just started a pocket-sized experiment with decentralized deep learning!
 
-However, this is just the bare minimum of what hivemind can do. In [this example](https://github.com/learning-at-home/hivemind/tree/master/examples/albert),
+However, this is only the basics of what hivemind can do. In [this example](https://github.com/learning-at-home/hivemind/tree/master/examples/albert),
 we show how to use a more advanced version of DecentralizedOptimizer to collaboratively train a large Transformer over the internet.
 
 If you want to learn more about each individual component,
 - Learn how to use `hivemind.DHT` using this basic [DHT tutorial](https://learning-at-home.readthedocs.io/en/latest/user/dht.html),
-- Learn the underlying math behind DecentralizedOptimizer in
-  [(Li et al. 2020)](https://arxiv.org/abs/2005.00124) and [(Ryabinin et al. 2021)](https://arxiv.org/abs/2103.03239).
+- Read more on how to use `hivemind.Optimizer` in its [documentation page](https://learning-at-home.readthedocs.io/en/latest/modules/optim.html), 
+- Learn the underlying math behind hivemind.Optimizer in [Diskin et al., (2021)](https://arxiv.org/abs/2106.10207), 
+  [Li et al. (2020)](https://arxiv.org/abs/2005.00124) and [Ryabinin et al. (2021)](https://arxiv.org/abs/2103.03239).
 - Read about setting up Mixture-of-Experts training in [this guide](https://learning-at-home.readthedocs.io/en/latest/user/moe.html),
  

+ 32 - 32
examples/albert/README.md

@@ -9,7 +9,7 @@ using `hivemind.CollaborativeOptimizer` to exchange information between peers.
 
 * Install hivemind: `pip install git+https://github.com/learning-at-home/hivemind.git`
 * Dependencies: `pip install -r requirements.txt`
-* Preprocess data: `python tokenize_wikitext103.py`
+* Preprocess data: `./tokenize_wikitext103.py`
 * Upload the data to a publicly available location or ask volunteers to preprocess it locally
 
 ## Running an experiment
@@ -20,16 +20,16 @@ Run the first DHT peer to welcome trainers and record training statistics (e.g.,
 
 - In this example, we use [wandb.ai](https://wandb.ai/site) to plot training metrics. If you're unfamiliar with Weights
   & Biases, here's a [quickstart tutorial](https://docs.wandb.ai/quickstart).
-- Run `python run_training_monitor.py --experiment_prefix NAME_YOUR_EXPERIMENT --wandb_project WANDB_PROJECT_HERE`
-- `NAME_YOUR_EXPERIMENT` must be a unique name of this training run, e.g. `my-first-albert`. It cannot contain `.`
-  due to naming conventions.
-- `WANDB_PROJECT_HERE` is a name of wandb project used to track training metrics. Multiple experiments can have the
-  same project name.
+- Run `./run_training_monitor.py --wandb_project YOUR_WANDB_PROJECT`
+
+  - `YOUR_WANDB_PROJECT` is a name of wandb project used to track training metrics. Multiple experiments can have the
+    same project name.
 
 ```
-$ python run_training_monitor.py --experiment_prefix my-albert-v1 --wandb_project Demo-run
-[2021/06/17 16:26:36.083][INFO][root.log_visible_maddrs:54] Running a DHT peer. To connect other peers to this one over the Internet, 
+$ ./run_training_monitor.py --wandb_project Demo-run
+Oct 14 16:26:36.083 [INFO] Running a DHT peer. To connect other peers to this one over the Internet,
 use --initial_peers /ip4/1.2.3.4/tcp/1337/p2p/XXXX /ip4/1.2.3.4/udp/31337/quic/p2p/XXXX
+Oct 14 16:26:36.083 [INFO] Full list of visible multiaddresses: ...
 wandb: Currently logged in as: XXX (use `wandb login --relogin` to force relogin)
 wandb: Tracking run with wandb version 0.10.32
 wandb: Syncing run dry-mountain-2
@@ -37,12 +37,12 @@ wandb:  View project at https://wandb.ai/XXX/Demo-run
 wandb:  View run at https://wandb.ai/XXX/Demo-run/runs/YYY
 wandb: Run data is saved locally in /path/to/run/data
 wandb: Run `wandb offline` to turn off syncing.
-[2021/04/19 02:26:41.064][INFO][optim.collaborative.fetch_collaboration_state:323] Found no active peers: None
-[2021/04/19 02:26:44.068][INFO][optim.collaborative.fetch_collaboration_state:323] Found no active peers: None
+Oct 14 16:26:41.064 [INFO] Found no active peers: None
+Oct 14 16:26:44.068 [INFO] Found no active peers: None
 ...
-[2021/04/19 02:37:37.246][INFO][__main__.<module>:194] Step #1  loss = 11.05164
-[2021/04/19 02:39:37.441][INFO][__main__.<module>:194] Step #2  loss = 11.03771
-[2021/04/19 02:40:37.541][INFO][__main__.<module>:194] Step #3  loss = 11.02886
+Oct 14 16:37:37.246 [INFO] Step #1  loss = 11.05164
+Oct 14 16:39:37.441 [INFO] Step #2  loss = 11.03771
+Oct 14 16:40:37.541 [INFO] Step #3  loss = 11.02886
 ```
 
 ### GPU trainers
@@ -55,9 +55,9 @@ To join the collaboration with a GPU trainer,
   (see [default paths](./arguments.py#L117-L134) for reference)
 - Run:
   ```bash
-  python run_trainer.py \
-  --experiment_prefix SAME_AS_IN_RUN_TRAINING_MONITOR --initial_peers ONE_OR_MORE_PEERS --seed 42 \
-  --logging_first_step --logging_steps 100  --output_dir ./outputs --overwrite_output_dir --logging_dir ./logs
+  ./run_trainer.py \
+      --initial_peers ONE_OR_MORE_PEERS \
+      --logging_first_step --output_dir ./outputs --overwrite_output_dir --logging_dir ./logs
   ```
 
   Here, `ONE_OR_MORE_PEERS` stands for multiaddresses of one or multiple existing peers (training monitors or existing
@@ -87,17 +87,18 @@ See the ["Tips and tricks"](#tips-and-tricks) section for more information on se
 As the peer begins training, it will periodically report training logs in the following form:
 
 ```
-[...][INFO][...] Collaboration accumulated 448 samples from 17 peers; ETA 18.88 seconds (refresh in 15.73s.)
-[...][INFO][...] Collaboration accumulated 4096 samples from 16 peers; ETA 0.00 seconds (refresh in 0.50s.)
-[...][INFO][optim.collaborative.step:195] Averaged tensors successfully with 17 peers
-[...][INFO][optim.collaborative.step:211] Optimizer step: done!
-06/17/2021 18:58:23 - INFO - __main__ -   Step 0
-06/17/2021 18:58:23 - INFO - __main__ -   Your current contribution: 892 samples
-06/17/2021 18:58:23 - INFO - __main__ -   Local loss: 11.023
-
+Dec 28 00:15:31.482 [INFO] albert accumulated 4056 samples for epoch #0 from 2 peers. ETA 0.75 sec (refresh in 0.50 sec)
+Dec 28 00:15:31.990 [INFO] albert accumulated 4072 samples for epoch #0 from 2 peers. ETA 0.24 sec (refresh in 0.50 sec)
+...
+Dec 28 00:15:32.857 [INFO] Step #1
+Dec 28 00:15:32.857 [INFO] Your current contribution: 2144 samples
+Dec 28 00:15:32.857 [INFO] Performance: 20.924 samples/sec
+Dec 28 00:15:32.857 [INFO] Local loss: 11.06709
+Dec 28 00:15:33.580 [INFO] Averaged gradients with 2 peers
+Dec 28 00:15:38.336 [INFO] Averaged parameters with 2 peers
 ```
 
-__Sanity check:__ a healthy peer will periodically report `Averaged tensors successfully with [N > 1]` peers.
+__Sanity check:__ a healthy peer will periodically report `Averaged gradients/parameters with [N > 1]` peers.
 
 For convenience, you can view (and share!) the learning curves of your collaborative experiments in wandb:
 
@@ -135,7 +136,7 @@ incoming connections (e.g. when in colab or behind a firewall), add `--client_mo
 below). In case of high network latency, you may want to increase `--averaging_expiration` by a few seconds or
 set `--batch_size_lead` to start averaging a bit earlier than the rest of the collaboration. GPU-wise, each peer should
 be able to process one local microbatch each 0.5–1 seconds (see trainer's progress bar). To achieve that, we
-recommend tuning `--per_device_train_batch_size` and `--gradient_accumulation_steps`. 
+recommend tuning `--per_device_train_batch_size` and `--gradient_accumulation_steps`.
 
 The example trainer supports
 multiple GPUs via DataParallel. However, using advanced distributed training strategies (
@@ -155,7 +156,7 @@ collaborative experiment. Here's how to best use them:
 - Most free GPUs are running behind a firewall, which requires you to run trainer with `--client_mode` (see example
   below). Such peers can only exchange gradients if there is at least one non-client-mode peer (GPU server or desktop
   with public IP). We recommend using a few preemptible instances with the cheapest GPU you can find. For example, we
-  tested this code on preemptible 
+  tested this code on preemptible
   [`g4dn.xlarge`](https://aws.amazon.com/blogs/aws/now-available-ec2-instances-g4-with-nvidia-t4-tensor-core-gpus/)
   nodes for around $0.15/h apiece with 8 AWS nodes and up to 61 Colab/Kaggle participants.
 - You can create starter notebooks to make it more convenient for collaborators to join your training
@@ -168,11 +169,10 @@ Here's an example of a full trainer script for Google Colab:
 !pip install transformers datasets sentencepiece torch_optimizer==0.1.0
 !git clone https://github.com/learning-at-home/hivemind && cd hivemind && pip install -e .
 !curl -L YOUR_HOSTED_DATA | tar xzf -
-!ulimit -n 4096 && python ./hivemind/examples/albert/run_trainer.py \
- --client_mode --initial_peers ONE_OR_MORE_PEERS  --averaging_expiration 10 \
- --batch_size_lead 300 --per_device_train_batch_size 4 --gradient_accumulation_steps 1 \
- --logging_first_step --logging_steps 100  --output_dir ./outputs --overwrite_output_dir --logging_dir ./logs \
- --experiment_prefix EXPERIMENT_NAME_HERE --seed 42
+!ulimit -n 4096 && ./hivemind/examples/albert/run_trainer.py \
+    --initial_peers ONE_OR_MORE_PEERS \
+    --logging_dir ./logs --logging_first_step --output_dir ./outputs --overwrite_output_dir \
+    --client_mode --averaging_expiration 10 --batch_size_lead 300 --gradient_accumulation_steps 1
 ```
 
 ### Using IPFS

+ 23 - 19
examples/albert/arguments.py

@@ -7,13 +7,13 @@ from transformers import TrainingArguments
 @dataclass
 class BaseTrainingArguments:
     experiment_prefix: str = field(
-        metadata={"help": "A unique 'name' of this experiment, used to store metadata on the DHT"}
+        default="albert", metadata={"help": "A unique 'name' of this experiment, used to store metadata on the DHT"}
     )
     initial_peers: List[str] = field(
         default_factory=list,
         metadata={
             "help": "Multiaddrs of the peers that will welcome you into the existing collaboration. "
-            "Example: /ip4/203.0.113.1/tcp/31337/p2p/XXXX /ip4/203.0.113.2/udp/7777/quic/p2p/YYYY"
+            "Example: /ip4/203.0.113.1/tcp/31337/p2p/XXXX /ip4/203.0.113.2/tcp/7777/p2p/YYYY"
         },
     )
     use_ipfs: bool = field(
@@ -24,27 +24,32 @@ class BaseTrainingArguments:
         },
     )
     host_maddrs: List[str] = field(
-        default_factory=lambda: ["/ip4/0.0.0.0/tcp/0", "/ip4/0.0.0.0/udp/0/quic"],
+        default_factory=lambda: ["/ip4/0.0.0.0/tcp/0"],
         metadata={
             "help": "Multiaddrs to listen for external connections from other p2p instances. "
-            "Defaults to all IPv4 interfaces with TCP and QUIC (over UDP) protocols: "
-            "/ip4/0.0.0.0/tcp/0 /ip4/0.0.0.0/udp/0/quic"
+            "Defaults to all IPv4 interfaces and the TCP protocol: /ip4/0.0.0.0/tcp/0"
         },
     )
     announce_maddrs: List[str] = field(
         default_factory=list,
         metadata={"help": "Visible multiaddrs the host announces for external connections from other p2p instances"},
     )
+    identity_path: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Path to a pre-generated private key file. If defined, makes the peer ID deterministic. "
+            "May be generated using ``./p2p-keygen`` from ``go-libp2p-daemon``."
+        },
+    )
 
 
 @dataclass
 class AveragerArguments:
-    averaging_expiration: float = field(
-        default=5.0, metadata={"help": "Averaging group will wait for stragglers for at most this many seconds"}
-    )
-    averaging_timeout: float = field(
-        default=30.0, metadata={"help": "Give up on averaging step after this many seconds"}
-    )
+    target_group_size: int = field(default=256, metadata={"help": "Maximum group size for all-reduce"})
+
+
+@dataclass
+class ProgressTrackerArguments:
     min_refresh_period: float = field(
         default=0.5, metadata={"help": "Wait for at least this many seconds before fetching new collaboration state"}
     )
@@ -60,17 +65,13 @@ class AveragerArguments:
     expected_drift_rate: float = field(
         default=0.2, metadata={"help": "Trainer assumes that this fraction of current size can join per step"}
     )
-    performance_ema_alpha: float = field(
-        default=0.1, metadata={"help": "Uses this alpha for moving average estimate of samples per second"}
-    )
-    target_group_size: int = field(default=256, metadata={"help": "Maximum group size for all-reduce"})
     metadata_expiration: float = field(
         default=120, metadata={"help": "Peer's metadata will be removed if not updated in this many seconds"}
     )
 
 
 @dataclass
-class CollaborativeOptimizerArguments:
+class OptimizerArguments:
     target_batch_size: int = field(
         default=4096,
         metadata={"help": "Perform optimizer step after all peers collectively accumulate this many samples"},
@@ -87,13 +88,16 @@ class CollaborativeOptimizerArguments:
         default=100.0,
         metadata={"help": "Available network bandwidth, in mbps (used for load balancing in all-reduce)"},
     )
-    compression: str = field(
-        default="FLOAT16", metadata={"help": "Use this compression when averaging parameters/gradients"}
+    averaging_timeout: float = field(
+        default=60.0, metadata={"help": "Give up on averaging step after this many seconds"}
+    )
+    matchmaking_time: float = field(
+        default=5.0, metadata={"help": "When looking for group, wait for requests for at least this many seconds"}
     )
 
 
 @dataclass
-class CollaborationArguments(CollaborativeOptimizerArguments, BaseTrainingArguments):
+class CollaborationArguments(OptimizerArguments, BaseTrainingArguments):
     statistics_expiration: float = field(
         default=600, metadata={"help": "Statistics will be removed if not updated in this many seconds"}
     )
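
For orientation, these dataclasses are not used directly; `run_trainer.py` (further down in this diff) feeds them to `transformers.HfArgumentParser`. A minimal sketch of that parsing step, assuming the class names above and omitting the training/dataset argument groups:

```python
# Sketch of how the argument dataclasses are consumed (mirrors run_trainer.py below).
# CollaborationArguments combines OptimizerArguments and BaseTrainingArguments.
from transformers import HfArgumentParser

from arguments import CollaborationArguments  # defined in this file

parser = HfArgumentParser(CollaborationArguments)
(collaboration_args,) = parser.parse_args_into_dataclasses()

print(collaboration_args.experiment_prefix)  # "albert" unless overridden via --experiment_prefix
print(collaboration_args.initial_peers)      # list of multiaddr strings from --initial_peers
```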

+ 5 - 5
examples/albert/requirements.txt

@@ -1,7 +1,7 @@
-transformers>=4.6.0
-datasets>=1.5.0
-torch_optimizer>=0.1.0
-wandb>=0.10.26
+transformers==4.6.0
+datasets==1.5.0
+torch_optimizer==0.1.0
+wandb==0.10.26
 sentencepiece
 requests
-nltk>=3.6.2
+nltk==3.6.7

+ 108 - 96
examples/albert/run_trainer.py

@@ -1,8 +1,8 @@
-#!/usr/bin/env python
+#!/usr/bin/env python3
 
-import logging
 import os
 import pickle
+import sys
 from dataclasses import asdict
 from pathlib import Path
 
@@ -17,34 +17,29 @@ from transformers.optimization import get_linear_schedule_with_warmup
 from transformers.trainer import Trainer
 from transformers.trainer_utils import is_main_process
 
-import hivemind
-from hivemind.utils.compression import CompressionType
+from hivemind import DHT, Float16Compression, Optimizer, get_dht_time
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
 import utils
-from arguments import AlbertTrainingArguments, AveragerArguments, CollaborationArguments, DatasetArguments
+from arguments import (
+    AlbertTrainingArguments,
+    AveragerArguments,
+    CollaborationArguments,
+    DatasetArguments,
+    ProgressTrackerArguments,
+)
 
-logger = logging.getLogger(__name__)
-LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__name__)
 
+LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
 
-def setup_logging(training_args):
-    logging.basicConfig(
-        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
-        datefmt="%m/%d/%Y %H:%M:%S",
-        level=logging.INFO if is_main_process(training_args.local_rank) else logging.WARN,
-    )
 
-    # Log on each process the small summary:
-    logger.warning(
-        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
-        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
-    )
-    # Set the verbosity to info of the Transformers logger (on main process only):
-    if is_main_process(training_args.local_rank):
+def setup_transformers_logging(process_rank: int):
+    if is_main_process(process_rank):
         transformers.utils.logging.set_verbosity_info()
-        transformers.utils.logging.enable_default_handler()
-        transformers.utils.logging.enable_explicit_format()
-    logger.info("Training/evaluation parameters %s", training_args)
+        transformers.utils.logging.disable_default_handler()
+        transformers.utils.logging.enable_propagation()
 
 
 def get_model(training_args, config, tokenizer):
@@ -64,36 +59,6 @@ def get_model(training_args, config, tokenizer):
     return model
 
 
-def get_optimizer_and_scheduler(training_args, model):
-    no_decay = ["bias", "LayerNorm.weight"]
-    optimizer_grouped_parameters = [
-        {
-            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
-            "weight_decay": training_args.weight_decay,
-        },
-        {
-            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
-            "weight_decay": 0.0,
-        },
-    ]
-
-    opt = Lamb(
-        optimizer_grouped_parameters,
-        lr=training_args.learning_rate,
-        betas=(training_args.adam_beta1, training_args.adam_beta2),
-        eps=training_args.adam_epsilon,
-        weight_decay=training_args.weight_decay,
-        clamp_value=training_args.clamp_value,
-        debias=True,
-    )
-
-    scheduler = get_linear_schedule_with_warmup(
-        opt, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.max_steps
-    )
-
-    return opt, scheduler
-
-
 class CollaborativeCallback(transformers.TrainerCallback):
     """
     This callback monitors and reports collaborative training progress.
@@ -102,8 +67,8 @@ class CollaborativeCallback(transformers.TrainerCallback):
 
     def __init__(
         self,
-        dht: hivemind.DHT,
-        optimizer: hivemind.CollaborativeOptimizer,
+        dht: DHT,
+        optimizer: Optimizer,
         model: torch.nn.Module,
         local_public_key: bytes,
         statistics_expiration: float,
@@ -111,7 +76,7 @@ class CollaborativeCallback(transformers.TrainerCallback):
     ):
         super().__init__()
         self.model = model
-        self.dht, self.collaborative_optimizer = dht, optimizer
+        self.dht, self.optimizer = dht, optimizer
         self.local_public_key = local_public_key
         self.statistics_expiration = statistics_expiration
         self.last_reported_collaboration_step = -1
@@ -126,7 +91,7 @@ class CollaborativeCallback(transformers.TrainerCallback):
         self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs
     ):
         logger.info("Loading state from peers")
-        self.collaborative_optimizer.load_state_from_peers()
+        self.optimizer.load_state_from_peers()
 
     def on_step_end(
         self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs
@@ -136,40 +101,43 @@ class CollaborativeCallback(transformers.TrainerCallback):
             self.restore_from_backup(self.latest_backup)
             return control
 
+        local_progress = self.optimizer.local_progress
+
         if state.log_history:
             self.loss += state.log_history[-1]["loss"]
             self.steps += 1
-            if self.collaborative_optimizer.local_step != self.last_reported_collaboration_step:
-                self.last_reported_collaboration_step = self.collaborative_optimizer.local_step
+
+            if self.optimizer.local_epoch != self.last_reported_collaboration_step:
+                self.last_reported_collaboration_step = self.optimizer.local_epoch
                 self.total_samples_processed += self.samples
-                samples_per_second = self.collaborative_optimizer.performance_ema.samples_per_second
+                samples_per_second = local_progress.samples_per_second
                 statistics = utils.LocalMetrics(
-                    step=self.collaborative_optimizer.local_step,
+                    step=self.optimizer.local_epoch,
                     samples_per_second=samples_per_second,
                     samples_accumulated=self.samples,
                     loss=self.loss,
                     mini_steps=self.steps,
                 )
-                logger.info(f"Step {self.collaborative_optimizer.local_step}")
+                logger.info(f"Step #{self.optimizer.local_epoch}")
                 logger.info(f"Your current contribution: {self.total_samples_processed} samples")
-                logger.info(f"Performance: {samples_per_second} samples per second.")
+                logger.info(f"Performance: {samples_per_second:.3f} samples/sec")
                 if self.steps:
-                    logger.info(f"Local loss: {self.loss / self.steps}")
-                if self.collaborative_optimizer.local_step % self.backup_every_steps == 0:
+                    logger.info(f"Local loss: {self.loss / self.steps:.5f}")
+                if self.optimizer.local_epoch % self.backup_every_steps == 0:
                     self.latest_backup = self.backup_state()
 
                 self.loss = 0
                 self.steps = 0
-                if self.collaborative_optimizer.is_synchronized:
+                if self.optimizer.is_synchronized_with_peers():
                     self.dht.store(
-                        key=self.collaborative_optimizer.prefix + "_metrics",
+                        key=self.optimizer.run_id + "_metrics",
                         subkey=self.local_public_key,
                         value=statistics.dict(),
-                        expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
+                        expiration_time=get_dht_time() + self.statistics_expiration,
                         return_future=True,
                     )
 
-        self.samples = self.collaborative_optimizer.local_samples_accumulated
+        self.samples = local_progress.samples_accumulated
 
         return control
 
@@ -182,19 +150,17 @@ class CollaborativeCallback(transformers.TrainerCallback):
 
     @torch.no_grad()
     def backup_state(self) -> bytes:
-        return pickle.dumps(
-            {"model": self.model.state_dict(), "optimizer": self.collaborative_optimizer.opt.state_dict()}
-        )
+        return pickle.dumps({"model": self.model.state_dict(), "optimizer": self.optimizer.state_dict()})
 
     @torch.no_grad()
     def restore_from_backup(self, backup: bytes):
         state = pickle.loads(backup)
         self.model.load_state_dict(state["model"])
-        self.collaborative_optimizer.opt.load_state_dict(state["optimizer"])
+        self.optimizer.load_state_dict(state["optimizer"])
 
 
 class NoOpScheduler(LRSchedulerBase):
-    """Dummy scheduler for transformers.Trainer. The real scheduler is defined in CollaborativeOptimizer.scheduler"""
+    """Dummy scheduler for transformers.Trainer. The real scheduler is defined in Optimizer.scheduler"""
 
     def get_lr(self):
         return [group["lr"] for group in self.optimizer.param_groups]
@@ -204,7 +170,6 @@ class NoOpScheduler(LRSchedulerBase):
             return self.optimizer.scheduler.print_lr(*args, **kwargs)
 
     def step(self):
-        logger.debug("Called NoOpScheduler.step")
         self._last_lr = self.get_lr()
 
     def state_dict(self):
@@ -215,20 +180,34 @@ class NoOpScheduler(LRSchedulerBase):
 
 
 def main():
-    parser = HfArgumentParser((AlbertTrainingArguments, DatasetArguments, CollaborationArguments, AveragerArguments))
-    training_args, dataset_args, collaboration_args, averager_args = parser.parse_args_into_dataclasses()
-
+    parser = HfArgumentParser(
+        (
+            AlbertTrainingArguments,
+            DatasetArguments,
+            CollaborationArguments,
+            AveragerArguments,
+            ProgressTrackerArguments,
+        )
+    )
+    training_args, dataset_args, collaboration_args, averager_args, tracker_args = parser.parse_args_into_dataclasses()
     logger.info(f"Found {len(collaboration_args.initial_peers)} initial peers: {collaboration_args.initial_peers}")
-    if len(collaboration_args.initial_peers) == 0:
-        raise ValueError("Please specify at least one network endpoint in initial peers.")
 
-    setup_logging(training_args)
+    setup_transformers_logging(training_args.local_rank)
+    logger.info(f"Training/evaluation parameters:\n{training_args}")
 
     # Set seed before initializing model.
     set_seed(training_args.seed)
 
     config = AlbertConfig.from_pretrained(dataset_args.config_path, cache_dir=dataset_args.cache_dir)
-    tokenizer = AlbertTokenizerFast.from_pretrained(dataset_args.tokenizer_path, cache_dir=dataset_args.cache_dir)
+    try:
+        tokenizer = AlbertTokenizerFast.from_pretrained(dataset_args.tokenizer_path, cache_dir=dataset_args.cache_dir)
+    except OSError:
+        logger.fatal(
+            f"No tokenizer data found in {dataset_args.tokenizer_path}, "
+            f"please run ./tokenize_wikitext103.py before running this"
+        )
+        sys.exit(1)
+
     model = get_model(training_args, config, tokenizer)
     model.to(training_args.device)
 
@@ -236,11 +215,9 @@ def main():
     # This data collator will take care of randomly masking the tokens.
     data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer)
 
-    opt, scheduler = get_optimizer_and_scheduler(training_args, model)
-
     validators, local_public_key = utils.make_validators(collaboration_args.experiment_prefix)
 
-    dht = hivemind.DHT(
+    dht = DHT(
         start=True,
         initial_peers=collaboration_args.initial_peers,
         client_mode=collaboration_args.client_mode,
@@ -248,6 +225,7 @@ def main():
         use_ipfs=collaboration_args.use_ipfs,
         host_maddrs=collaboration_args.host_maddrs,
         announce_maddrs=collaboration_args.announce_maddrs,
+        identity_path=collaboration_args.identity_path,
     )
     utils.log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=collaboration_args.use_ipfs)
 
@@ -257,19 +235,53 @@ def main():
 
     adjusted_target_batch_size = collaboration_args.target_batch_size - collaboration_args.batch_size_lead
 
-    collaborative_optimizer = hivemind.CollaborativeOptimizer(
-        opt=opt,
+    # We need to pass a lambda factory instead of a ready-made optimizer instance
+    # so that hivemind.Optimizer(..., offload_optimizer=True) can construct it internally
+    opt = lambda params: Lamb(
+        params,
+        lr=training_args.learning_rate,
+        betas=(training_args.adam_beta1, training_args.adam_beta2),
+        eps=training_args.adam_epsilon,
+        weight_decay=training_args.weight_decay,
+        clamp_value=training_args.clamp_value,
+        debias=True,
+    )
+
+    no_decay = ["bias", "LayerNorm.weight"]
+    params = [
+        {
+            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+            "weight_decay": training_args.weight_decay,
+        },
+        {
+            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
+            "weight_decay": 0.0,
+        },
+    ]
+
+    scheduler = lambda opt: get_linear_schedule_with_warmup(
+        opt, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.max_steps
+    )
+
+    optimizer = Optimizer(
         dht=dht,
-        scheduler=scheduler,
-        prefix=collaboration_args.experiment_prefix,
-        compression_type=CompressionType.Value(collaboration_args.compression),
-        batch_size_per_step=total_batch_size_per_step,
-        bandwidth=collaboration_args.bandwidth,
+        run_id=collaboration_args.experiment_prefix,
         target_batch_size=adjusted_target_batch_size,
+        batch_size_per_step=total_batch_size_per_step,
+        optimizer=opt,
+        params=params,
+        scheduler=scheduler,
+        matchmaking_time=collaboration_args.matchmaking_time,
+        averaging_timeout=collaboration_args.averaging_timeout,
+        offload_optimizer=True,
+        delay_optimizer_step=True,
+        delay_grad_averaging=True,
         client_mode=collaboration_args.client_mode,
+        grad_compression=Float16Compression(),
+        state_averaging_compression=Float16Compression(),
+        averager_opts={"bandwidth": collaboration_args.bandwidth, **asdict(averager_args)},
+        tracker_opts=asdict(tracker_args),
         verbose=True,
-        start=True,
-        **asdict(averager_args),
     )
 
     class TrainerWithIndependentShuffling(Trainer):
@@ -285,11 +297,11 @@ def main():
         data_collator=data_collator,
         train_dataset=tokenized_datasets["train"] if training_args.do_train else None,
         eval_dataset=tokenized_datasets["validation"] if training_args.do_eval else None,
-        optimizers=(collaborative_optimizer, NoOpScheduler(collaborative_optimizer)),
+        optimizers=(optimizer, NoOpScheduler(optimizer)),
         callbacks=[
             CollaborativeCallback(
                 dht,
-                collaborative_optimizer,
+                optimizer,
                 model,
                 local_public_key,
                 collaboration_args.statistics_expiration,
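
The central change in `run_trainer.py` is replacing `hivemind.CollaborativeOptimizer` with `hivemind.Optimizer`, which receives an optimizer *factory* (a callable over params) rather than a pre-built instance. Below is a stripped-down sketch of the same construction with a toy model instead of ALBERT; the keyword arguments mirror the diff, while the concrete values (learning rate, batch sizes) are illustrative only:

```python
# Stripped-down sketch of the hivemind.Optimizer setup above, with a toy model.
# Keyword arguments are taken from the diff; concrete values are illustrative.
import torch
from hivemind import DHT, Float16Compression, Optimizer

model = torch.nn.Linear(16, 2)
dht = DHT(initial_peers=[], start=True)  # pass the run's multiaddrs to join existing peers

optimizer = Optimizer(
    dht=dht,
    run_id="albert",                     # must match across all peers of the same run
    target_batch_size=4096,              # optimizer step after peers collectively reach this
    batch_size_per_step=4,               # samples processed per local optimizer.step() call
    optimizer=lambda params: torch.optim.Adam(params, lr=1e-3),  # factory, not an instance
    params=list(model.parameters()),
    offload_optimizer=True,
    delay_optimizer_step=True,
    delay_grad_averaging=True,
    grad_compression=Float16Compression(),
    state_averaging_compression=Float16Compression(),
    verbose=True,
)

# During training it is used like a regular PyTorch optimizer:
#   loss.backward(); optimizer.step(); optimizer.zero_grad()
```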

+ 23 - 25
examples/albert/run_training_monitor.py

@@ -1,6 +1,5 @@
-#!/usr/bin/env python
+#!/usr/bin/env python3
 
-import logging
 import time
 from dataclasses import asdict, dataclass, field
 from ipaddress import ip_address
@@ -13,12 +12,14 @@ from torch_optimizer import Lamb
 from transformers import AlbertConfig, AlbertForPreTraining, HfArgumentParser
 
 import hivemind
-from hivemind.utils.compression import CompressionType
+from hivemind.optim.state_averager import TrainingStateAverager
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
 import utils
-from arguments import AveragerArguments, BaseTrainingArguments, CollaborativeOptimizerArguments
+from arguments import AveragerArguments, BaseTrainingArguments, OptimizerArguments
 
-logger = logging.getLogger(__name__)
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__name__)
 
 
 @dataclass
@@ -55,14 +56,14 @@ class TrainingMonitorArguments(BaseTrainingArguments):
     upload_interval: Optional[float] = field(
         default=None, metadata={"help": "Frequency (in seconds) of uploading the model to Hub"}
     )
-    store_checkpoins: bool = field(default=False, metadata={"help": "If True, enables CheckpointHandler"})
+    store_checkpoints: bool = field(default=False, metadata={"help": "If True, enables CheckpointHandler"})
 
 
 class CheckpointHandler:
     def __init__(
         self,
         monitor_args: TrainingMonitorArguments,
-        collab_optimizer_args: CollaborativeOptimizerArguments,
+        optimizer_args: OptimizerArguments,
         averager_args: AveragerArguments,
         dht: hivemind.DHT,
     ):
@@ -95,17 +96,13 @@ class CheckpointHandler:
             debias=True,
         )
 
-        adjusted_target_batch_size = collab_optimizer_args.target_batch_size - collab_optimizer_args.batch_size_lead
-
-        self.collaborative_optimizer = hivemind.CollaborativeOptimizer(
-            opt=opt,
+        self.state_averager = TrainingStateAverager(
             dht=dht,
+            optimizer=opt,
             prefix=experiment_prefix,
-            compression_type=CompressionType.Value(collab_optimizer_args.compression),
-            bandwidth=collab_optimizer_args.bandwidth,
-            target_batch_size=adjusted_target_batch_size,
-            client_mode=collab_optimizer_args.client_mode,
-            verbose=True,
+            state_compression=hivemind.Float16Compression(),
+            bandwidth=optimizer_args.bandwidth,
+            client_mode=optimizer_args.client_mode,
             start=True,
             **asdict(averager_args),
         )
@@ -121,7 +118,7 @@ class CheckpointHandler:
 
     def save_state(self, cur_step):
         logger.info("Saving state from peers")
-        self.collaborative_optimizer.load_state_from_peers()
+        self.state_averager.load_state_from_peers()
         self.previous_step = cur_step
 
     def is_time_to_upload(self):
@@ -134,20 +131,20 @@ class CheckpointHandler:
 
     def upload_checkpoint(self, current_loss):
         logger.info("Saving optimizer")
-        torch.save(self.collaborative_optimizer.opt.state_dict(), f"{self.repo_path}/optimizer_state.pt")
+        torch.save(self.state_averager.optimizer.state_dict(), f"{self.repo_path}/optimizer_state.pt")
         self.previous_timestamp = time.time()
         logger.info("Started uploading to Model Hub")
         self.model.push_to_hub(
             repo_name=self.repo_path,
             repo_url=self.repo_url,
-            commit_message=f"Step {current_step}, loss {current_loss:.3f}",
+            commit_message=f"Step #{current_step}, loss {current_loss:.3f}",
         )
         logger.info("Finished uploading to Model Hub")
 
 
 if __name__ == "__main__":
-    parser = HfArgumentParser((TrainingMonitorArguments, CollaborativeOptimizerArguments, AveragerArguments))
-    monitor_args, collab_optimizer_args, averager_args = parser.parse_args_into_dataclasses()
+    parser = HfArgumentParser((TrainingMonitorArguments, OptimizerArguments, AveragerArguments))
+    monitor_args, optimizer_args, averager_args = parser.parse_args_into_dataclasses()
 
     if monitor_args.use_google_dns:
         request = requests.get("https://api.ipify.org")
@@ -156,7 +153,7 @@ if __name__ == "__main__":
         address = request.text
         logger.info(f"Received public IP address of this machine: {address}")
         version = ip_address(address).version
-        monitor_args.announce_maddrs += [f"/ip{version}/{address}/tcp/0", f"/ip{version}/{address}/udp/0/quic"]
+        monitor_args.announce_maddrs += [f"/ip{version}/{address}/tcp/0"]
 
     experiment_prefix = monitor_args.experiment_prefix
     validators, local_public_key = utils.make_validators(experiment_prefix)
@@ -168,6 +165,7 @@ if __name__ == "__main__":
         use_ipfs=monitor_args.use_ipfs,
         host_maddrs=monitor_args.host_maddrs,
         announce_maddrs=monitor_args.announce_maddrs,
+        identity_path=monitor_args.identity_path,
     )
     utils.log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=monitor_args.use_ipfs)
 
@@ -175,8 +173,8 @@ if __name__ == "__main__":
         wandb.init(project=monitor_args.wandb_project)
 
     current_step = 0
-    if monitor_args.store_checkpoins:
-        checkpoint_handler = CheckpointHandler(monitor_args, collab_optimizer_args, averager_args, dht)
+    if monitor_args.store_checkpoints:
+        checkpoint_handler = CheckpointHandler(monitor_args, optimizer_args, averager_args, dht)
 
     while True:
         metrics_dict = dht.get(experiment_prefix + "_metrics", latest=True)
@@ -218,7 +216,7 @@ if __name__ == "__main__":
                         }
                     )
 
-                if monitor_args.store_checkpoins:
+                if monitor_args.store_checkpoints:
                     if checkpoint_handler.is_time_to_save_state(current_step):
                         checkpoint_handler.save_state(current_step)
                         if checkpoint_handler.is_time_to_upload():
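
The monitor loop above works by polling the DHT key `<experiment_prefix>_metrics`, where each trainer's `CollaborativeCallback` (see `run_trainer.py` in this diff) stores its `LocalMetrics` under its public key as a subkey. A minimal read-side sketch, assuming the default `albert` prefix and at least one reachable peer:

```python
# Minimal sketch of the metrics polling performed by the monitor above.
# Assumes the default experiment_prefix ("albert") and a reachable peer.
import time
from hivemind import DHT

dht = DHT(initial_peers=[], start=True)  # pass the run's --initial_peers here

while True:
    metrics_entry = dht.get("albert_metrics", latest=True)
    if metrics_entry is not None:
        # metrics_entry.value maps each trainer's public key to its stored metrics dict
        for peer_key, metrics in metrics_entry.value.items():
            print(peer_key, metrics.value)  # step, samples_per_second, loss, etc.
    time.sleep(30)
```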

+ 1 - 1
examples/albert/tokenize_wikitext103.py

@@ -1,4 +1,4 @@
-#!/usr/bin/env python
+#!/usr/bin/env python3
 """ This script builds a pre-tokenized compressed representation of WikiText-103 using huggingface/datasets """
 import random
 from functools import partial

+ 1 - 7
examples/albert/utils.py

@@ -7,7 +7,7 @@ from hivemind import choose_ip_address
 from hivemind.dht.crypto import RSASignatureValidator
 from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
 from hivemind.dht.validation import RecordValidatorBase
-from hivemind.utils.logging import get_logger
+from hivemind.utils.logging import TextStyle, get_logger
 
 logger = get_logger(__name__)
 
@@ -30,12 +30,6 @@ def make_validators(experiment_prefix: str) -> Tuple[List[RecordValidatorBase],
     return validators, signature_validator.local_public_key
 
 
-class TextStyle:
-    BOLD = "\033[1m"
-    BLUE = "\033[34m"
-    RESET = "\033[0m"
-
-
 def log_visible_maddrs(visible_maddrs: List[Multiaddr], only_p2p: bool) -> None:
     if only_p2p:
         unique_addrs = {addr["p2p"] for addr in visible_maddrs}

+ 6 - 2
hivemind/__init__.py

@@ -1,4 +1,5 @@
-from hivemind.averaging import DecentralizedAverager, TrainingAverager
+from hivemind.averaging import DecentralizedAverager
+from hivemind.compression import *
 from hivemind.dht import DHT
 from hivemind.moe import (
     ExpertBackend,
@@ -15,8 +16,11 @@ from hivemind.optim import (
     DecentralizedOptimizer,
     DecentralizedOptimizerBase,
     DecentralizedSGD,
+    GradScaler,
+    Optimizer,
+    TrainingAverager,
 )
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, PeerInfo
 from hivemind.utils import *
 
-__version__ = "1.0.0.dev0"
+__version__ = "1.1.0dev0"
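
In short, `TrainingAverager` is now re-exported via `hivemind.optim` (it is no longer part of `hivemind.averaging`, as the next hunk shows), `Optimizer` and `GradScaler` join the top-level namespace, and compression classes come from the new `hivemind.compression` package. A quick sketch of the resulting import surface:

```python
# Import surface after this change (names taken from the diff above).
import hivemind
from hivemind import DHT, DecentralizedAverager, Float16Compression, GradScaler, Optimizer
from hivemind.optim import TrainingAverager  # moved out of hivemind.averaging

print(hivemind.__version__)  # "1.1.0dev0" after this change
```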

+ 0 - 1
hivemind/averaging/__init__.py

@@ -1,2 +1 @@
 from hivemind.averaging.averager import DecentralizedAverager
-from hivemind.averaging.training import TrainingAverager

+ 196 - 84
hivemind/averaging/allreduce.py

@@ -1,15 +1,22 @@
 import asyncio
 from enum import Enum
-from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Type
+from typing import Any, AsyncIterator, Dict, Optional, Sequence, Set, Tuple, Type
 
 import torch
 
-from hivemind.averaging.partition import AllreduceException, TensorPartContainer, TensorPartReducer
+from hivemind.averaging.partition import AllreduceException, BannedException, TensorPartContainer, TensorPartReducer
+from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase, StubBase
 from hivemind.proto import averaging_pb2
 from hivemind.utils import get_logger
-from hivemind.utils.asyncio import achain, aenumerate, aiter, amap_in_executor, anext, asingle
-from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.utils.asyncio import (
+    achain,
+    aiter_with_timeout,
+    amap_in_executor,
+    anext,
+    as_aiter,
+    attach_event_on_finished,
+)
 
 # flavour types
 GroupID = bytes
@@ -37,15 +44,21 @@ class AllReduceRunner(ServicerBase):
     :param prefix: namespace for servicer's RPCs (typically, equal to prefix for group keys)
     :param group_id: unique identifier of this specific all-reduce run
     :param tensors: local tensors that should be averaged with groupmates
-    :param tensors: local tensors that should be averaged with groupmates
+    :param weight: scalar weight of this peer's tensors in the average (doesn't need to sum up to 1)
     :param peer_id: your peer_id, must be included in ordered_peer_ids
     :param ordered_peer_ids: group peer_ids ordered s.t. i-th peer_id is responsible for averaging i-th part
     :param peer_fractions: for each peer, a target fraction of vector elements that this peer should average
       (the actual number of values by peer will be nearly proportional, but there are no exact guarantees)
     :param modes: AveragingMode for each peer in ordered_peer_ids (normal, client-only or auxiliary)
-    :param weights: scaling coefficients for weighted averaging (default = equal weights for all non-aux peers)
     :param gathered: additional user-defined data collected from this group
-    :param kwargs: additional paramters (e.g. part_size_bytes) will be passed to TensorPartContainer
+    :param sender_timeout: during all_reduce, any sender that fails to send tensor chunk within this many seconds from
+      previous chunk will be marked as failed and excluded from averaging. default: equal to next_chunk_timeout
+    :param reducer_timeout: during all_reduce, any reducer that fails to send results chunk within this many seconds
+      from previous chunk will be marked as failed and excluded from averaging. default: 2 x sender_timeout
+    :param kwargs: additional parameters (e.g. part_size_bytes) will be passed to TensorPartContainer
+    :note: Full-mode peers send and receive tensor parts concurrently, assuming a full-duplex TCP stream. In turn,
+      non-averaging peers receive results only after they finish sending, which helps them avoid
+      throughput issues in case of asymmetric high-latency connections (e.g. ACK compression).
     """
 
     def __init__(
@@ -56,16 +69,23 @@ class AllReduceRunner(ServicerBase):
         prefix: Optional[str],
         group_id: GroupID,
         tensors: Sequence[torch.Tensor],
+        weight: Optional[float] = None,
         ordered_peer_ids: Sequence[PeerID],
         peer_fractions: Tuple[float, ...],
-        weights: Optional[Sequence[float]] = None,
         modes: Optional[Sequence[AveragingMode]] = None,
         gathered: Optional[Dict[PeerID, Any]] = None,
+        sender_timeout: Optional[float] = None,
+        reducer_timeout: Optional[float] = None,
         **kwargs,
     ):
         self._p2p = p2p
         self.peer_id = p2p.peer_id
         assert self.peer_id in ordered_peer_ids, "peer_id is not a part of the group"
+        if reducer_timeout is not None and (sender_timeout is None or reducer_timeout <= sender_timeout):
+            raise ValueError(
+                "If reducer_timeout is enabled, sender_timeout must be shorter than reducer_timeout. "
+                "Otherwise, there is a chance that reducers will be banned while they await senders."
+            )
 
         if not issubclass(servicer_type, ServicerBase):
             raise TypeError("`servicer_type` is expected to be a ServicerBase subclass")
@@ -73,31 +93,42 @@ class AllReduceRunner(ServicerBase):
         self._prefix = prefix
 
         modes = modes or tuple(AveragingMode.CLIENT if frac == 0 else AveragingMode.NODE for frac in peer_fractions)
-        weights = weights or tuple(int(mode != AveragingMode.AUX) for mode in modes)
-        assert len(weights) == len(modes) == len(ordered_peer_ids), "lists have inconsistent length"
+        assert len(modes) == len(ordered_peer_ids), "lists have inconsistent length"
         assert any(mode != AveragingMode.CLIENT for mode in modes), "cannot run allreduce without reducers"
-        for mode, frac, weight in zip(modes, peer_fractions, weights):
+        for mode, frac in zip(modes, peer_fractions):
             assert mode != AveragingMode.CLIENT or frac == 0, "client-mode peer should have zero all-reduce fraction"
-            assert mode != AveragingMode.AUX or weight == 0, "auxiliary peer should have zero averaging weight"
 
         self.group_id, self.ordered_peer_ids = group_id, ordered_peer_ids
         self.modes, self.peer_fractions, self.gathered = modes, peer_fractions, gathered
 
+        if weight is None:
+            weight = float(modes[self.ordered_peer_ids.index(self.peer_id)] != AveragingMode.AUX)
+        self.weight = weight
+
         self._future = asyncio.Future()
 
-        self.sender_peer_ids, self.sender_weights = [], []
-        for peer_id, weight, mode in zip(self.ordered_peer_ids, weights, modes):
+        self.sender_peer_ids = []
+        for peer_id, mode in zip(self.ordered_peer_ids, modes):
             if mode != AveragingMode.AUX:
                 self.sender_peer_ids.append(peer_id)
-                self.sender_weights.append(weight)
+
+        self.sender_timeout, self.reducer_timeout = sender_timeout, reducer_timeout
+        self.all_senders_started = asyncio.Event()
+        self.banned_senders: Set[PeerID] = set()  # peers that did not send data by next_chunk_timeout
+        self.banlock = asyncio.Lock()
+
+        self.active_senders: Set[PeerID] = set()  # peers that began sending data via rpc_aggregate_part
+        if self.peer_id in self.sender_peer_ids:
+            self.active_senders.add(self.peer_id)
+        if len(self.active_senders) == len(self.sender_peer_ids):
+            self.all_senders_started.set()
 
         peer_id_index = self.ordered_peer_ids.index(self.peer_id)
-        self.tensor_part_container = TensorPartContainer(tensors, peer_fractions, **kwargs)
+        self.tensor_part_container = TensorPartContainer(tensors, peer_fractions, return_deltas=True, **kwargs)
         self.parts_for_local_averaging = self.tensor_part_container.get_raw_input_parts(peer_id_index)
         self.tensor_part_reducer = TensorPartReducer(
             tuple(part.shape for part in self.parts_for_local_averaging),
             len(self.sender_peer_ids),
-            self.sender_weights,
         )
 
     def __repr__(self):
@@ -116,9 +147,16 @@ class AllReduceRunner(ServicerBase):
     def _get_peer_stub(self, peer: PeerID) -> StubBase:
         return self._servicer_type.get_stub(self._p2p, peer, namespace=self._prefix)
 
+    def should_delay_results(self, peer_id: PeerID) -> bool:
+        return self.peer_fractions[self.ordered_peer_ids.index(peer_id)] == 0
+
     async def run(self) -> AsyncIterator[torch.Tensor]:
         """Run all-reduce, return differences between averaged and original tensors as they are computed"""
         pending_tasks = set()
+
+        if self.tensor_part_container.num_parts_by_peer[self.ordered_peer_ids.index(self.peer_id)] != 0:
+            pending_tasks.add(asyncio.create_task(self._handle_missing_senders()))
+
         try:
             if len(self.sender_peer_ids) == 0:
                 logger.debug(f"{self} - finished all-reduce early: all peers are auxiliaries ({self.modes})")
@@ -131,6 +169,7 @@ class AllReduceRunner(ServicerBase):
 
                 async for averaged_tensor_delta in self.tensor_part_container.iterate_output_tensors():
                     yield averaged_tensor_delta  # delta = averaged_tensor - original_tensor
+
                 self.finalize()
 
             else:  # auxiliary peer
@@ -143,31 +182,69 @@ class AllReduceRunner(ServicerBase):
                 task.cancel()
             raise
 
+        finally:
+            for task in pending_tasks:
+                try:
+                    await task
+                except asyncio.CancelledError:
+                    pass
+                except Exception as inner_exc:
+                    logger.debug(f"Task {task} failed with {inner_exc}", exc_info=True)
+
+    async def _handle_missing_senders(self):
+        """Detect senders that should have sent tensors for averaging, but did not send anything within timeout"""
+        try:
+            await asyncio.wait_for(self.all_senders_started.wait(), self.sender_timeout)
+        except asyncio.TimeoutError:
+            for peer_id in self.sender_peer_ids:
+                if peer_id not in self.active_senders and peer_id not in self.banned_senders:
+                    await self._ban_sender(peer_id)
+
     async def _communicate_with_peer(self, peer_id: PeerID):
         """Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors"""
         peer_index = self.ordered_peer_ids.index(peer_id)
         if peer_id == self.peer_id:
             sender_index = self.sender_peer_ids.index(peer_id)
             for part_index, tensor_part in enumerate(self.parts_for_local_averaging):
-                averaged_part = await self.tensor_part_reducer.accumulate_part(sender_index, part_index, tensor_part)
+                averaged_part = await self.tensor_part_reducer.accumulate_part(
+                    sender_index, part_index, tensor_part, weight=self.weight
+                )
                 self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part - tensor_part)
 
         else:
-            loop = asyncio.get_event_loop()
-            code = None
-            stream = self._get_peer_stub(peer_id).rpc_aggregate_part(self._generate_input_for_peer(peer_index))
-            async for part_index, msg in aenumerate(stream):
-                if code is None:
-                    code = msg.code
-                averaged_part_delta = await loop.run_in_executor(None, deserialize_torch_tensor, msg.tensor_part)
-                self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part_delta)
-
-            if code != averaging_pb2.AVERAGED_PART:
-                raise AllreduceException(
-                    f"peer {peer_id} returned {averaging_pb2.MessageCode.Name(code)} "
-                    f"instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)}"
-                    f", allreduce failed"
-                )
+            try:
+                done_sending = asyncio.Event()
+                inputs_aiter = attach_event_on_finished(self._generate_input_for_peer(peer_index), done_sending)
+                stream = await self._get_peer_stub(peer_id).rpc_aggregate_part(inputs_aiter)
+
+                if self.should_delay_results(self.peer_id):
+                    await done_sending.wait()
+
+                part_index = 0
+
+                def _try_deserialize(msg):
+                    if msg.code != averaging_pb2.AVERAGED_PART:
+                        raise AllreduceException(f"{peer_id} sent {averaging_pb2.MessageCode.Name(msg.code)}")
+                    return deserialize_torch_tensor(msg.tensor_part), msg
+
+                async for delta, msg in amap_in_executor(
+                    _try_deserialize,
+                    aiter_with_timeout(stream, self.reducer_timeout),
+                    max_prefetch=self.tensor_part_container.prefetch,
+                ):
+                    self.tensor_part_container.register_processed_part(peer_index, part_index, delta)
+                    part_index += 1
+
+                if part_index != self.tensor_part_container.num_parts_by_peer[peer_index]:
+                    raise AllreduceException(
+                        f"peer {peer_id} sent {part_index} parts, but we expected "
+                        f"{self.tensor_part_container.num_parts_by_peer[peer_index]}"
+                    )
+            except BaseException as e:
+                if isinstance(e, Exception):
+                    logger.debug(f"Caught {repr(e)} when communicating to {peer_id}", exc_info=True)
+                self.tensor_part_container.register_failed_reducer(peer_index)
+                raise
 
     async def _generate_input_for_peer(self, peer_index: int) -> AsyncIterator[averaging_pb2.AveragingData]:
         parts_aiter = self.tensor_part_container.iterate_input_parts_for(peer_index)
@@ -176,79 +253,116 @@ class AllReduceRunner(ServicerBase):
             code=averaging_pb2.PART_FOR_AVERAGING,
             group_id=self.group_id,
             tensor_part=first_part,
+            weight=self.weight,
         )
         async for part in parts_aiter:
-            yield averaging_pb2.AveragingData(tensor_part=part)
+            yield averaging_pb2.AveragingData(tensor_part=part, weight=self.weight)
 
     async def rpc_aggregate_part(
         self, stream: AsyncIterator[averaging_pb2.AveragingData], context: P2PContext
     ) -> AsyncIterator[averaging_pb2.AveragingData]:
         """a peer sends us a part of his tensor; we should average it with other peers and return the difference"""
-        request: averaging_pb2.AveragingData = await anext(stream)
-        reason_to_reject = self._check_reasons_to_reject(request)
-        if reason_to_reject:
-            yield reason_to_reject
-            return
+        sender_index = self.sender_peer_ids.index(context.remote_id)
+        self.active_senders.add(context.remote_id)
+        if len(self.active_senders) == len(self.sender_peer_ids):
+            self.all_senders_started.set()
 
-        elif request.code == averaging_pb2.PART_FOR_AVERAGING:
-            try:
-                sender_index = self.sender_peer_ids.index(context.remote_id)
-                async for msg in self._accumulate_parts_streaming(achain(aiter(request), stream), sender_index):
-                    yield msg
+        try:
+            request: averaging_pb2.AveragingData = await asyncio.wait_for(anext(stream), self.sender_timeout)
+            reason_to_reject = self._check_reasons_to_reject(request, context)
+            if reason_to_reject:
+                yield reason_to_reject
+                return
+
+            elif request.code == averaging_pb2.PART_FOR_AVERAGING:
+                stream = aiter_with_timeout(achain(as_aiter(request), stream), self.sender_timeout)
+                if not self.should_delay_results(context.remote_id):
+                    async for msg in self._accumulate_parts_streaming(stream, sender_index):
+                        yield msg
+
+                else:
+                    done_receiving = asyncio.Event()
+                    delayed_results = asyncio.Queue()
+
+                    async def _accumulate_parts():
+                        try:
+                            async for msg in self._accumulate_parts_streaming(
+                                attach_event_on_finished(stream, done_receiving), sender_index
+                            ):
+                                delayed_results.put_nowait(msg)
+                        finally:
+                            delayed_results.put_nowait(None)
+
+                    accumulate_task = asyncio.create_task(_accumulate_parts())
+
+                    await done_receiving.wait()
+
+                    while True:
+                        next_result = await delayed_results.get()
+                        if next_result is None:
+                            break
+                        yield next_result
+                    await accumulate_task
 
-            except Exception as e:
-                self.finalize(exception=e)
+            else:
                 yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
-        else:
-            error_code = averaging_pb2.MessageCode.Name(request.code)
-            logger.debug(f"{self} - peer {context.remote_id} sent {error_code}, allreduce cannot continue")
-            self.finalize(exception=AllreduceException(f"peer {context.remote_id} sent {error_code}."))
-            yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
+                raise AllreduceException(f"{context.remote_id} sent {averaging_pb2.MessageCode.Name(request.code)}")
 
-    def _check_reasons_to_reject(self, request: averaging_pb2.AveragingData) -> Optional[averaging_pb2.AveragingData]:
+        except BaseException as e:
+            await self._ban_sender(context.remote_id)
+            if isinstance(e, Exception):
+                logger.debug(f"Caught {repr(e)} when communicating with {context.remote_id}", exc_info=True)
+                yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
+            else:
+                raise  # CancelledError, StopIteration and similar
+
+    async def _ban_sender(self, peer_id: PeerID):
+        async with self.banlock:
+            if peer_id not in self.banned_senders:
+                self.banned_senders.add(peer_id)
+                self.tensor_part_reducer.on_sender_failed(self.sender_peer_ids.index(peer_id))
+
+    def _check_reasons_to_reject(
+        self, request: averaging_pb2.AveragingData, context: P2PContext
+    ) -> Optional[averaging_pb2.AveragingData]:
         if request.group_id != self.group_id:
             return averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
         elif self._future.cancelled():
             return averaging_pb2.AveragingData(code=averaging_pb2.CANCELLED)
         elif self._future.done():
             return averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
+        elif context.remote_id not in self.sender_peer_ids:
+            return averaging_pb2.AveragingData(code=averaging_pb2.PROTOCOL_VIOLATION)
 
     async def _accumulate_parts_streaming(self, stream: AsyncIterator[averaging_pb2.AveragingData], sender_index: int):
-        loop = asyncio.get_event_loop()
-        async for part_index, (tensor_part, part_compression) in aenumerate(
-            amap_in_executor(
-                lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg.tensor_part.compression),
+        part_index = 0
+        try:
+            loop = asyncio.get_event_loop()
+            async for tensor_part, weight, part_compression in amap_in_executor(
+                lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg.weight, msg.tensor_part.compression),
                 stream,
                 max_prefetch=self.tensor_part_container.prefetch,
-            )
-        ):
-            averaged_part = await self.tensor_part_reducer.accumulate_part(sender_index, part_index, tensor_part)
-
-            serialized_delta = await loop.run_in_executor(
-                None, lambda: serialize_torch_tensor(averaged_part - tensor_part, part_compression)
-            )
-            yield averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=serialized_delta)
-
-    async def _send_error_to_peer(self, peer_id: PeerID, code: averaging_pb2.MessageCode):
-        error = averaging_pb2.AveragingData(group_id=self.group_id, code=code)
-        # In case of reporting the error, we expect the response stream to contain exactly one item
-        await asingle(self._get_peer_stub(peer_id).rpc_aggregate_part(aiter(error)))
+            ):
+                try:
+                    averaged_part = await self.tensor_part_reducer.accumulate_part(
+                        sender_index, part_index, tensor_part, weight=weight
+                    )
+                    part_index += 1
+                except BannedException:
+                    logger.debug(f"Sender {sender_index} is already banned")
+                    break  # sender was banned, we no longer need to aggregate it
+
+                serialized_delta = await loop.run_in_executor(
+                    None, lambda: serialize_torch_tensor(averaged_part - tensor_part, part_compression)
+                )
+                yield averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=serialized_delta)
+        finally:
+            if part_index != self.tensor_part_reducer.num_parts:
+                await self._ban_sender(self.sender_peer_ids[sender_index])
 
     def finalize(self, *, cancel: bool = False, exception: Optional[BaseException] = None):
         """finish or terminate AllReduceRunner, propagate any errors / cancellations to peers."""
         assert not cancel or not exception, "finalize accepts either exception or cancel, but not both"
-        pending_tasks = set()
-        if cancel or exception:
-            # propagate error to peers
-            if cancel or isinstance(exception, asyncio.CancelledError):
-                code = averaging_pb2.CANCELLED
-            else:
-                code = averaging_pb2.INTERNAL_ERROR
-            logger.debug(f"{self} - notifying peers about {averaging_pb2.MessageCode.Name(code)}")
-            for peer_id, mode in zip(self.ordered_peer_ids, self.modes):
-                if peer_id != self.peer_id and mode != AveragingMode.CLIENT:
-                    pending_tasks.add(asyncio.create_task(self._send_error_to_peer(peer_id, code)))
-
         if not self._future.done():
             if cancel:
                 logger.debug(f"{self} - cancelled")
@@ -261,7 +375,5 @@ class AllReduceRunner(ServicerBase):
                 self._future.set_result(None)
             self.tensor_part_container.finalize()
             self.tensor_part_reducer.finalize()
-            return pending_tasks
         else:
-            logger.debug(f"{self} - could not finish: allreduce is already finished: {self._future}")
-            return pending_tasks
+            logger.debug(f"{self} - attempted to finalize allreduce that is already finished: {self._future}")

+ 277 - 139
hivemind/averaging/averager.py

@@ -7,9 +7,9 @@ import contextlib
 import ctypes
 import multiprocessing as mp
 import os
+import random
 import threading
 import weakref
-from concurrent.futures.thread import ThreadPoolExecutor
 from dataclasses import asdict
 from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Union
 
@@ -17,16 +17,33 @@ import numpy as np
 import torch
 
 from hivemind.averaging.allreduce import AllreduceException, AllReduceRunner, AveragingMode, GroupID
+from hivemind.averaging.control import AveragingStage, StepControl
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.load_balancing import load_balance_peers
 from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
 from hivemind.averaging.partition import DEFAULT_PART_SIZE_BYTES
+from hivemind.compression import (
+    CompressionBase,
+    CompressionInfo,
+    NoCompression,
+    deserialize_torch_tensor,
+    serialize_torch_tensor,
+)
 from hivemind.dht import DHT, DHTID
-from hivemind.p2p import P2PContext, P2PHandlerError, PeerID, ServicerBase
-from hivemind.proto import averaging_pb2, runtime_pb2
+from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
+from hivemind.p2p.p2p_daemon_bindings.utils import ControlFailure, DispatchFailure
+from hivemind.proto import averaging_pb2
 from hivemind.utils import MPFuture, TensorDescriptor, get_logger
-from hivemind.utils.asyncio import achain, aiter, anext, switch_to_uvloop
-from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.utils.asyncio import (
+    achain,
+    afirst,
+    aiter_with_timeout,
+    anext,
+    as_aiter,
+    azip,
+    enter_asynchronously,
+    switch_to_uvloop,
+)
 from hivemind.utils.grpc import combine_from_streaming, split_for_streaming
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration, get_dht_time
@@ -49,14 +66,14 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     :param prefix: a shared prefix for all group keys
     :param target_group_size: attempts to form groups with up to this many peers (recommended: a power of 2, e.g. 16)
     :param initial_group_bits: a string of bits ('0' and '1') that define the initial group key (bucket index)
-    :param averaging_expiration: attempt to find a group for this many seconds, otherwise try again
-      note - this expiration time only applies to looking for group, passing tensors in allreduce may take more time
-    :param compression_type: optionally compress tensors with this compression algorithm before sending them to peers
-    :param allreduce_timeout: spend at most this many seconds for allreduce (after group is formed)
+    :param min_matchmaking_time: when looking for group, wait for requests for at least this many seconds
+    :param compression: optionally compress tensors with this compression algorithm before running all-reduce
+    :param state_compression: a separate compression strategy for load_state_from_peers (default = no compression)
+    :param tensor_infos: CompressionInfo for each respective tensor; this determines how the tensor will be compressed
     :param averaging_alpha: optional "learning rate" for averaging. If specified, local parameters will be shifted
       towards the (estimated) average by this coefficient. By default, local parameters are set equal to average.
     :param request_timeout: when looking for group, wait for a response from leader for at most this many seconds.
-    :note: request_timeout must be smaller than averaging_expiration to avoid potential deadlocks.
+    :note: request_timeout must be smaller than min_matchmaking_time to avoid potential deadlocks.
     :param part_size_bytes: tensors for AllReduce are processed in parts of up to this size (after compression)
     :param bandwidth: if specified, this value represents the network bandwidth available to averager.
           By default, the averager is assumed to have the average bandwidth of his group.
@@ -68,6 +85,14 @@ class DecentralizedAverager(mp.Process, ServicerBase):
           local tensors for averaging
     :param allow_state_sharing: if set to True, other peers can download this peer's state. Can be overwritten
       with averager.allow_state_sharing = True / False
+    :param declare_state_period: re-declare averager as a donor for load_state_from_peers every this many seconds
+    :param allreduce_timeout: spend at most this many seconds for allreduce (after group is formed)
+    :param next_chunk_timeout: during all-reduce and load_state_from_peers, if peer does not send next data chunk in
+      this number of seconds, consider it failed and proceed with remaining peers. default: no timeout
+    :param sender_timeout: during all_reduce, any sender that fails to send tensor chunk within this many seconds from
+      previous chunk will be marked as failed and excluded from averaging. default: equal to next_chunk_timeout
+    :param reducer_timeout: during all_reduce, any reducer that fails to send results chunk within this many seconds
+      from previous chunk will be marked as failed and excluded from averaging. default: 2 * sender_timeout
     :param shutdown_timeout: when calling .shutdown, wait for up to this many seconds before terminating
 
     Example:
@@ -85,6 +110,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
     _matchmaking: Matchmaking
     _pending_group_assembled: asyncio.Event
+    _state_updated: asyncio.Event
+    _p2p: P2P
     serializer = MSGPackSerializer
 
     def __init__(
@@ -94,19 +121,26 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         *,
         start: bool,
         prefix: str,
-        target_group_size: int,
+        target_group_size: Optional[int] = None,
         min_group_size: int = 2,
-        initial_group_bits: Optional[str] = None,
-        averaging_expiration: float = 15,
-        request_timeout: float = 3,
+        initial_group_bits: str = "",
+        averaging_expiration: Optional[float] = None,
+        min_matchmaking_time: float = 5.0,
+        request_timeout: float = 3.0,
         averaging_alpha: float = 1.0,
         part_size_bytes: int = DEFAULT_PART_SIZE_BYTES,
         allreduce_timeout: Optional[float] = None,
-        compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE,
+        next_chunk_timeout: Optional[float] = None,
+        sender_timeout: Optional[float] = None,
+        reducer_timeout: Optional[float] = None,
+        compression: CompressionBase = NoCompression(),
+        state_compression: CompressionBase = NoCompression(),
+        tensor_infos: Optional[Sequence[CompressionInfo]] = None,
         bandwidth: Optional[float] = None,
         min_vector_size: int = 0,
         auxiliary: bool = False,
         allow_state_sharing: Optional[bool] = None,
+        declare_state_period: float = 30,
         client_mode: Optional[bool] = None,
         daemon: bool = True,
         shutdown_timeout: float = 5,
@@ -115,17 +149,25 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         assert bandwidth is None or (
             bandwidth >= 0 and np.isfinite(np.float32(bandwidth))
         ), "bandwidth must be a non-negative float32"
-        if not is_power_of_two(target_group_size):
-            logger.warning("It is recommended to set target_group_size to a power of 2.")
-        assert initial_group_bits is None or all(bit in "01" for bit in initial_group_bits)
+        assert all(bit in "01" for bit in initial_group_bits)
         assert not client_mode or not auxiliary, "auxiliary peers must accept incoming connections"
 
+        if averaging_expiration is not None:
+            logger.warning("averaging_expiration is deprecated and will be removed soon, use min_matchmaking_time")
+            assert min_matchmaking_time == 5.0, "Can't set both averaging_expiration and min_matchmaking_time"
+            min_matchmaking_time = averaging_expiration
+
         super().__init__()
         self.dht = dht
         self.prefix = prefix
 
         if client_mode is None:
             client_mode = dht.client_mode
+        if sender_timeout is None:
+            sender_timeout = next_chunk_timeout
+        if reducer_timeout is None:
+            reducer_timeout = 2 * sender_timeout if sender_timeout is not None else None
+
         self.client_mode = client_mode
 
         self._parent_pid = os.getpid()
@@ -139,13 +181,13 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
         self._averaged_tensors = tuple(averaged_tensors)
         self.lock_averaged_tensors = mp.Lock()
-        self.last_updated: DHTExpiration = -float("inf")
         for tensor in self._averaged_tensors:
             assert tensor.grad_fn is None, "averaged_tensors must be either parameters or leaf tensors"
             tensor.share_memory_()
         self.total_size = sum(map(torch.Tensor.numel, self._averaged_tensors))
         self.schema_hash = compute_schema_hash(self._averaged_tensors)
         self.shutdown_timeout = shutdown_timeout
+        self.next_chunk_timeout = next_chunk_timeout
         self.bandwidth = bandwidth
 
         self.matchmaking_kwargs = dict(
@@ -154,11 +196,15 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             initial_group_bits=initial_group_bits,
             target_group_size=target_group_size,
             min_group_size=min_group_size,
-            averaging_expiration=averaging_expiration,
             request_timeout=request_timeout,
+            min_matchmaking_time=min_matchmaking_time,
         )
         self.allreduce_kwargs = dict(
-            compression_type=compression_type, part_size_bytes=part_size_bytes, min_vector_size=min_vector_size
+            compression=compression,
+            part_size_bytes=part_size_bytes,
+            min_vector_size=min_vector_size,
+            sender_timeout=sender_timeout,
+            reducer_timeout=reducer_timeout,
         )
         self._averaging_alpha, self._allreduce_timeout = averaging_alpha, allreduce_timeout
         self._running_groups: Dict[GroupID, AllReduceRunner] = {}  # one or more assembled groups that run all-reduce
@@ -166,11 +212,16 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)  # a control pipe used to communicate with daemon
 
         self._allow_state_sharing = mp.Value(ctypes.c_bool, 0)
+        self._state_sharing_priority = mp.Value(ctypes.c_double, 0)
+
         if allow_state_sharing is None:
             allow_state_sharing = not client_mode and not auxiliary
         self.allow_state_sharing = allow_state_sharing
+        self.declare_state_period = declare_state_period
+        self.state_compression = state_compression
+        self.tensor_infos = tensor_infos
 
-        self.ready = mp.Event()  # whether the averager process has started (and ready for incoming requests)
+        self._ready = MPFuture()
         # note: we create a background thread weakref and with daemon=True to ensure garbage collection
         background_fetcher = threading.Thread(
             daemon=True,
@@ -189,14 +240,38 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     @allow_state_sharing.setter
     def allow_state_sharing(self, value: bool):
         if value and self.client_mode:
-            raise ValueError("Cannot allow state sharing: averager in client mode cannot share its state.")
+            raise ValueError("Cannot allow state sharing: averager in client mode cannot share its state")
+        else:
+            old_value, self._allow_state_sharing.value = self._allow_state_sharing.value, value
+            if value != old_value:
+                self._outer_pipe.send(("_trigger_declare_load_state", [], {}))
+
+    @property
+    def state_sharing_priority(self) -> float:
+        """Others will preferentially downloading state from peers with highest priority."""
+        return float(self._state_sharing_priority.value)
+
+    @state_sharing_priority.setter
+    def state_sharing_priority(self, value: float):
+        if value and self.client_mode:
+            raise ValueError("State sharing priority is unused: averager in client mode cannot share its state")
         else:
-            self._allow_state_sharing.value = value
+            old_value, self._state_sharing_priority.value = self._state_sharing_priority.value, value
+            if self.allow_state_sharing and value != old_value:
+                self._outer_pipe.send(("_trigger_declare_load_state", [], {}))
+
+    async def _trigger_declare_load_state(self):
+        # note: previously tried to set mp.Event instead of this. Awaiting it in executor caused degradation in py39
+        self._state_updated.set()
 
     @property
     def peer_id(self) -> PeerID:
         return self.dht.peer_id
 
+    @property
+    def request_timeout(self):
+        return self._matchmaking.request_timeout
+
     def run(self):
         """
         Run averager function in a background thread; this is needed to avoid a heisenbug with broken OMP on fork
@@ -211,14 +286,17 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         """Serve DecentralizedAverager forever. This function will not return until the averager is shut down"""
         loop = switch_to_uvloop()
         # initialize asyncio synchronization primitives in this event loop
-        with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
 
-            async def _run():
+        pipe_semaphore = asyncio.Semaphore(value=0)
+        loop.add_reader(self._inner_pipe.fileno(), pipe_semaphore.release)
+
+        async def _run():
+            try:
                 self._p2p = await self.dht.replicate_p2p()
                 if not self.client_mode:
                     await self.add_p2p_handlers(self._p2p, namespace=self.prefix)
                 else:
-                    logger.debug(f"The averager is running in client mode.")
+                    logger.debug("The averager is running in client mode")
 
                 self._matchmaking = Matchmaking(
                     self._p2p,
@@ -230,45 +308,65 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                 if not self.client_mode:
                     asyncio.create_task(self._declare_for_download_periodically())
 
+                self._state_updated = asyncio.Event()
                 self._pending_group_assembled = asyncio.Event()
                 self._pending_group_assembled.set()
-                self.ready.set()
-
-                while True:
-                    method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
-                    task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
-                    if method == "_shutdown":
-                        await task
-                        break
-
-            loop.run_until_complete(_run())
+            except Exception as e:
+                # Loglevel is DEBUG since normally the exception is propagated to the caller
+                logger.debug(e, exc_info=True)
+                self._ready.set_exception(e)
+                return
+            self._ready.set_result(None)
 
-    def run_in_background(self, await_ready=True, timeout=None):
+            while True:
+                try:
+                    await asyncio.wait_for(pipe_semaphore.acquire(), timeout=self.request_timeout)
+                except asyncio.TimeoutError:
+                    pass
+                if not self._inner_pipe.poll():
+                    continue
+                try:
+                    method, args, kwargs = self._inner_pipe.recv()
+                except (OSError, ConnectionError, RuntimeError) as e:
+                    logger.exception(e)
+                    await asyncio.sleep(self.request_timeout)
+                    continue
+                task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
+                if method == "_shutdown":
+                    await task
+                    break
+
+        loop.run_until_complete(_run())
+
+    def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None:
         """
         Starts averager in a background process. if await_ready, this method will wait until background dht
         is ready to process incoming requests or for :timeout: seconds max.
         """
         self.start()
-        if await_ready and not self.ready.wait(timeout=timeout):
-            raise TimeoutError(f"Server didn't notify .ready in {timeout} seconds")
+        if await_ready:
+            self.wait_until_ready(timeout)
+
+    def wait_until_ready(self, timeout: Optional[float] = None) -> None:
+        self._ready.result(timeout=timeout)
 
     def shutdown(self) -> None:
         """Shut down the averager process"""
         if self.is_alive():
-            self._outer_pipe.send(("_shutdown", [None], {}))  # shut down the daemon process
+            self._outer_pipe.send(("_shutdown", [self.shutdown_timeout], {}))  # shut down the daemon process
             self._inner_pipe.send(("_SHUTDOWN", None))  # shut down background thread in master
             self.join(self.shutdown_timeout)
             if self.is_alive():
-                logger.warning("Averager did not shut down within the grace period; terminating it the hard way.")
+                logger.warning("Averager did not shut down within the grace period; terminating it the hard way")
                 self.terminate()
         else:
             logger.exception("Averager shutdown has no effect: the process is already not alive")
 
-    async def _shutdown(self, timeout: Optional[float] = None) -> None:
+    async def _shutdown(self, timeout: Optional[float]) -> None:
         remaining_tasks = set()
         for group in self._running_groups.values():
             remaining_tasks.update(group.finalize(cancel=True))
-        await asyncio.gather(*remaining_tasks)
+        await asyncio.wait_for(asyncio.gather(*remaining_tasks), timeout)
 
     def __del__(self):
         if self._parent_pid == os.getpid() and self.is_alive():
@@ -277,67 +375,97 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     def step(
         self,
         gather: Optional[GatheredData] = None,
+        scheduled_time: Optional[DHTExpiration] = None,
         weight: Optional[float] = None,
         timeout: Optional[float] = None,
         allow_retries: bool = True,
+        require_trigger: bool = False,
         wait: bool = True,
-    ) -> Union[Optional[Dict[PeerID, GatheredData]], MPFuture]:
+    ) -> Union[Optional[Dict[PeerID, GatheredData]], StepControl]:
         """
         Set up the averager to look for a group and run one round of averaging; returns the gathered group info on success, None on failure
 
         :param gather: optionally send this information to all peers in the next group and gather it from every groupmate
           (this operation is known as all-gather). The gathered data will be available as the output of this function.
+        :param scheduled_time: when matchmaking, assume that all-reduce will begin at this moment.
+          By default, schedule all-reduce to start at the current time plus min_matchmaking_time seconds
         :param weight: averaging weight for this peer, int or float, must be strictly positive
         :param allow_retries: if averager fails to run one round of allreduce, this option will allow it to try again
           within the specified timeout
-        :param timeout: if averager was unable to *find* a group in this many seconds, consider allreduce failedK
-        :param wait: if True (default), return when finished. Otherwise return MPFuture and run in background.
+        :param require_trigger: if True, await for user to call .allow_allreduce() before running all-reduce
+        :param timeout: if averager was unable to *find* a group in this many seconds, consider allreduce failed
+        :param wait: if True (default), return when finished. Otherwise return StepControl and run in background.
         :returns: on success, update averaged_tensors and return group info; on failure, return None
         """
         if self.mode == AveragingMode.AUX and weight is not None:
-            logger.warning("Averager is running in auxiliary mode, weight is unused.")
+            logger.warning("Averager is running in auxiliary mode, weight is unused")
+        if scheduled_time is None:
+            scheduled_time = get_dht_time() + self.matchmaking_kwargs["min_matchmaking_time"]
         if weight is None:
             weight = float(self.mode != AveragingMode.AUX)
+        deadline = get_dht_time() + timeout if timeout is not None else float("inf")
         assert isinstance(weight, (int, float)) and weight >= 0, f"Expected a positive int/float, got {type(weight)}"
-
-        future = MPFuture()
-        gather_binary = self.serializer.dumps(
-            gather
-        )  # serialize here to avoid loading modules in the averager process
-        self._outer_pipe.send(
-            (
-                "_step",
-                [],
-                dict(
-                    future=future,
-                    gather_binary=gather_binary,
-                    weight=weight,
-                    allow_retries=allow_retries,
-                    timeout=timeout,
-                ),
-            )
+        assert not (wait and require_trigger), "Non-asynchronous step cannot wait for trigger (use wait=False)"
+        assert scheduled_time < deadline, "Scheduled start time does not fit within timeout"
+
+        user_data_for_gather = self.serializer.dumps(gather)  # serialize here to avoid imports in the averager process
+        data_for_gather = self.serializer.dumps([self.bandwidth, self.mode.value, user_data_for_gather])
+        step = StepControl(
+            scheduled_time=scheduled_time,
+            deadline=deadline,
+            allow_retries=allow_retries,
+            weight=weight,
+            data_for_gather=data_for_gather,
         )
-        return future.result() if wait else future
 
-    async def _step(
-        self, *, future: MPFuture, gather_binary: bytes, weight: float, allow_retries: bool, timeout: Optional[float]
-    ):
-        start_time = get_dht_time()
+        future_for_init = MPFuture()
+        self._outer_pipe.send(("_step", [], dict(step=step, future_for_init=future_for_init)))
+        step.attach(*future_for_init.result())
+
+        if not require_trigger:
+            step.allow_allreduce()
+        return step.result() if wait else step
 
+    async def _step(self, *, step: StepControl, future_for_init: MPFuture):
         try:
-            while not future.done():
+            trigger, cancel = MPFuture(), MPFuture()
+            step.attach(trigger, cancel)
+            future_for_init.set_result((trigger, cancel))
+
+            async def find_peers_or_notify_cancel():
+                group_info = await self._matchmaking.look_for_group(step)
+                if not step.triggered:
+                    step.stage = AveragingStage.AWAITING_TRIGGER
+                    await step.wait_for_trigger()
+                return group_info
+
+            while not step.done():
                 try:
                     self._pending_group_assembled.clear()
-                    data_for_gather = self.serializer.dumps([weight, self.bandwidth, self.mode.value, gather_binary])
-                    group_info = await self._matchmaking.look_for_group(
-                        timeout=timeout, data_for_gather=data_for_gather
-                    )
+                    step.stage = AveragingStage.LOOKING_FOR_GROUP
+                    matchmaking_task = asyncio.create_task(find_peers_or_notify_cancel())
+                    check_cancel_task = asyncio.create_task(step.wait_for_cancel())
+
+                    await asyncio.wait({matchmaking_task, check_cancel_task}, return_when=asyncio.FIRST_COMPLETED)
+                    if step.cancelled():
+                        matchmaking_task.cancel()
+                        raise asyncio.CancelledError()
+                    else:
+                        check_cancel_task.cancel()
+
+                    group_info = await matchmaking_task
+
                     if group_info is None:
-                        raise AllreduceException("Averaging step failed: could not find a group.")
+                        raise AllreduceException("Averaging step failed: could not find a group")
 
-                    future.set_result(
+                    step.stage = AveragingStage.RUNNING_ALLREDUCE
+
+                    step.set_result(
                         await asyncio.wait_for(
-                            self._run_allreduce(group_info, **self.allreduce_kwargs), self._allreduce_timeout
+                            self._run_allreduce(
+                                group_info, tensor_infos=self.tensor_infos, weight=step.weight, **self.allreduce_kwargs
+                            ),
+                            timeout=self._allreduce_timeout,
                         )
                     )
                     # averaging is finished, loop will now exit
@@ -350,21 +478,25 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     asyncio.CancelledError,
                     asyncio.InvalidStateError,
                     P2PHandlerError,
+                    DispatchFailure,
+                    ControlFailure,
                 ) as e:
-                    time_elapsed = get_dht_time() - start_time
-                    if not allow_retries or (timeout is not None and timeout < time_elapsed):
-                        logger.exception(f"Averager caught {repr(e)}")
-                        future.set_exception(e)
+                    if step.done() or not step.allow_retries or get_dht_time() >= step.deadline:
+                        if not step.cancelled():
+                            logger.exception(e)
+                        if not step.done():
+                            step.set_exception(e)
                     else:
-                        logger.warning(f"Averager caught {repr(e)}, retrying")
+                        logger.warning(f"{self.__class__.__name__} caught {repr(e)}, retrying")
 
         except BaseException as e:
-            if not future.done():
-                future.set_exception(e)
+            if not step.done():
+                step.set_exception(e)
             raise
         finally:
-            if not future.done():
-                future.set_exception(
+            step.stage = AveragingStage.FINISHED
+            if not step.done():
+                step.set_exception(
                     RuntimeError(
                         "Internal sanity check failed: averager.step left future pending."
                         " Please report this to hivemind issues."
@@ -374,8 +506,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
         """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
         try:
-            weights, bandwidths, mode_ids, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
-            user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered)))
+            bandwidths, mode_ids, user_gathered_bytes = zip(*map(self.serializer.loads, group_info.gathered))
+            user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered_bytes)))
             modes = tuple(map(AveragingMode, mode_ids))
 
             # compute optimal part sizes from peer bandwidths; TODO: replace with proper load balancing
@@ -386,7 +518,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                 None, load_balance_peers, self.total_size, download_bandwidths, min_vector_size
             )
 
-            async with self.get_tensors_async() as local_tensors:
+            async with enter_asynchronously(self.get_tensors()) as local_tensors:
                 allreduce = AllReduceRunner(
                     p2p=self._p2p,
                     servicer_type=type(self),
@@ -395,26 +527,27 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     tensors=local_tensors,
                     ordered_peer_ids=group_info.peer_ids,
                     peer_fractions=peer_fractions,
-                    weights=weights,
                     gathered=user_gathered,
                     modes=modes,
                     **kwargs,
                 )
 
                 with self.register_allreduce_group(group_info.group_id, allreduce):
-
-                    # actually run all-reduce
-                    averaging_outputs = [output async for output in allreduce]
-
                     if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
-                        assert len(local_tensors) == len(self._averaged_tensors)
-                        for tensor, update in zip(local_tensors, averaging_outputs):
+                        iter_results = allreduce.run()
+                        async for tensor, update in azip(as_aiter(*local_tensors), iter_results):
+                            # all-reduce is performed asynchronously while iterating
                             tensor.add_(update, alpha=self._averaging_alpha)
-                        self.last_updated = get_dht_time()
+                        self._state_updated.set()
+
+                    else:
+                        async for _ in allreduce:  # trigger all-reduce by iterating
+                            raise ValueError("aux peers should not receive averaged tensors")
 
                 return allreduce.gathered
         except BaseException as e:
-            logger.exception(e)
+            if isinstance(e, Exception):
+                logger.exception(e)
             raise MatchmakingException(f"Unable to run All-Reduce: {e}")
 
     @contextlib.contextmanager
@@ -437,16 +570,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         """
         with self.lock_averaged_tensors:
             yield self._averaged_tensors
-        self.last_updated = get_dht_time()
-
-    @contextlib.asynccontextmanager
-    async def get_tensors_async(self) -> Sequence[torch.Tensor]:
-        """Like get_tensors, but uses an asynchronous contextmanager"""
-        try:
-            await asyncio.get_event_loop().run_in_executor(None, self.lock_averaged_tensors.acquire)
-            yield self._averaged_tensors
-        finally:
-            self.lock_averaged_tensors.release()
 
     async def rpc_join_group(
         self, request: averaging_pb2.JoinRequest, context: P2PContext
@@ -470,26 +593,36 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             yield averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
             return
 
-        async for message in group.rpc_aggregate_part(achain(aiter(request), stream), context):
+        async for message in group.rpc_aggregate_part(achain(as_aiter(request), stream), context):
             yield message
 
     async def _declare_for_download_periodically(self):
         download_key = f"{self._matchmaking.group_key_manager.prefix}.all_averagers"
+        sharing_was_allowed = self.allow_state_sharing
         while True:
-            if self.allow_state_sharing:
+            expiration_time = get_dht_time() + self.declare_state_period
+            if self.allow_state_sharing or sharing_was_allowed:
+                # notify either if sharing is allowed or if it was just switched off (to overwrite previous message)
                 asyncio.create_task(
                     asyncio.wait_for(
                         self.dht.store(
                             download_key,
                             subkey=self.peer_id.to_bytes(),
-                            value=self.last_updated,
-                            expiration_time=get_dht_time() + self._matchmaking.averaging_expiration,
+                            value=self.state_sharing_priority if self.allow_state_sharing else None,
+                            expiration_time=expiration_time,
                             return_future=True,
                         ),
-                        timeout=self._matchmaking.averaging_expiration,
+                        timeout=expiration_time - get_dht_time(),
                     )
                 )
-            await asyncio.sleep(self._matchmaking.averaging_expiration)
+                sharing_was_allowed = self.allow_state_sharing
+
+            # report again either in declare_state_period seconds or after the field was changed by the user
+            self._state_updated.clear()
+            try:
+                await asyncio.wait_for(self._state_updated.wait(), timeout=max(0.0, expiration_time - get_dht_time()))
+            except asyncio.TimeoutError:
+                pass
 
     async def rpc_download_state(
         self, _request: averaging_pb2.DownloadRequest, _context: P2PContext
@@ -503,24 +636,27 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         """
         if not self.allow_state_sharing:
             return  # deny request and direct peer to the next prospective averager
-        metadata, tensors = await self._get_current_state_from_host_process()
+        metadata, tensors, infos = await self._get_current_state_from_host_process()
+        if infos is None:
+            infos = [CompressionInfo.from_tensor(tensor, key=i) for i, tensor in enumerate(tensors)]
+        assert len(tensors) == len(infos)
 
-        for tensor in tensors:
-            for part in split_for_streaming(serialize_torch_tensor(tensor)):
+        for tensor, info in zip(tensors, infos):
+            for part in split_for_streaming(self.state_compression.compress(tensor, info, allow_inplace=False)):
                 if metadata is not None:
                     yield averaging_pb2.DownloadData(tensor_part=part, metadata=metadata)
                     metadata = None
                 else:
                     yield averaging_pb2.DownloadData(tensor_part=part)
 
-    def get_current_state(self) -> Tuple[Any, Sequence[torch.Tensor]]:
+    def get_current_state(self) -> Tuple[Any, Sequence[torch.Tensor], Sequence[CompressionInfo]]:
         """
         Get current state and send it to a peer. Executed in the host process. Meant to be overridden.
         :returns: a tuple of (small metadata, sequence of torch tensors)
         :note: metadata must be serializable with self.serializer (default = MSGPackSerializer)
         """
         with self.get_tensors() as tensors:
-            return dict(group_key=self.get_group_bits()), tensors
+            return dict(group_key=self.get_group_bits()), tensors, self.tensor_infos
 
     async def _get_current_state_from_host_process(self):
         """Executed in the averager process inside rpc_download_state"""
@@ -528,7 +664,9 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         self._inner_pipe.send(("_TRIGGER_GET_CURRENT_STATE", future))
         return await future
 
-    def load_state_from_peers(self, wait=True) -> Optional[Tuple[Any, Sequence[torch.Tensor]]]:
+    def load_state_from_peers(
+        self, wait: bool = True, timeout: Optional[float] = None
+    ) -> Optional[Tuple[Any, Sequence[torch.Tensor]]]:
         """
         Try to download the latest optimizer state from one of the existing peers.
         :returns: on success, return a 2-tuple with (metadata, tensors), where
@@ -539,21 +677,23 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         The exact contents of both metadata and tensors are determined by get_current_state method
         """
         future = MPFuture()
-        self._outer_pipe.send(("_load_state_from_peers", [], dict(future=future)))
-        return future.result() if wait else future
+        self._outer_pipe.send(("_load_state_from_peers", [], dict(timeout=timeout, future=future)))
+        return future.result(timeout=timeout) if wait else future
 
-    async def _load_state_from_peers(self, future: MPFuture):
+    async def _load_state_from_peers(self, future: MPFuture, timeout: Optional[float] = None):
+        if timeout is not None:
+            timeout = self.next_chunk_timeout if self.next_chunk_timeout is not None else self.request_timeout
         try:
             key_manager = self._matchmaking.group_key_manager
             peer_priority, _ = self.dht.get(f"{key_manager.prefix}.all_averagers", latest=True) or ({}, None)
             peer_priority = {
-                PeerID(peer_id): float(info.value)
+                PeerID(peer_id): (float(info.value), random.random())  # using randomness as a tie breaker
                 for peer_id, info in peer_priority.items()
                 if isinstance(info, ValueWithExpiration) and isinstance(info.value, (float, int))
             }
 
             if not isinstance(peer_priority, dict) or len(peer_priority) == 0:
-                logger.info(f"Averager could not load state from peers: peer dict empty or corrupted {peer_priority}.")
+                logger.info(f"Averager could not load state from peers: peer dict empty or corrupted {peer_priority}")
                 future.set_result(None)
                 return
 
@@ -563,9 +703,10 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     logger.info(f"Downloading parameters from peer {peer}")
                     try:
                         stub = self.get_stub(self._p2p, peer, namespace=self.prefix)
-                        stream = stub.rpc_download_state(averaging_pb2.DownloadRequest())
+                        stream = await stub.rpc_download_state(averaging_pb2.DownloadRequest())
                         current_tensor_parts, tensors = [], []
-                        async for message in stream:
+
+                        async for message in aiter_with_timeout(stream, timeout=timeout):
                             if message.metadata:
                                 metadata = self.serializer.loads(message.metadata)
                             if message.tensor_part.dtype and current_tensor_parts:
@@ -577,19 +718,17 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                             tensors.append(deserialize_torch_tensor(combine_from_streaming(current_tensor_parts)))
 
                         if not metadata:
-                            logger.debug(f"Peer {peer} did not send its state.")
+                            logger.debug(f"Peer {peer} did not send its state")
                             continue
 
                         logger.info(f"Finished downloading state from {peer}")
                         future.set_result((metadata, tensors))
-                        self.last_updated = get_dht_time()
                         return
-                    except BaseException as e:
+                    except Exception as e:
                         logger.exception(f"Failed to download state from {peer} - {repr(e)}")
 
         finally:
             if not future.done():
-                logger.warning("Averager could not load state from peers: all requests have failed.")
                 future.set_result(None)
 
     def get_group_bits(self, wait: bool = True):
@@ -623,11 +762,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                 future.set_exception(e)
 
 
-def is_power_of_two(n):
-    """Check whether n is a power of 2"""
-    return (n != 0) and (n & (n - 1) == 0)
-
-
 def _background_thread_fetch_current_state(
     serializer: SerializerBase, pipe: mp.connection.Connection, get_current_state_ref: weakref.WeakMethod
 ):
@@ -652,7 +786,11 @@ def _background_thread_fetch_current_state(
             get_current_state = get_current_state_ref()
             if get_current_state is None:
                 break
-            state_metadata, state_tensors = get_current_state()
+            state = get_current_state()
+            assert 0 < len(state) <= 3
+            if len(state) != 3:
+                state = tuple(state + (None,) * (3 - len(state)))
+            state_metadata, state_tensors, tensor_infos = state
             del get_current_state
 
             state_metadata = serializer.dumps(state_metadata)
@@ -660,7 +798,7 @@ def _background_thread_fetch_current_state(
                 tensor.cpu().detach().requires_grad_(tensor.requires_grad) for tensor in state_tensors
             )
             # note: we cast tensors to CPU on host side to avoid initializing cuda in the guest process
-            future.set_result((state_metadata, state_tensors))
+            future.set_result((state_metadata, state_tensors, tensor_infos))
         except BaseException as e:
             future.set_exception(e)
             logger.warning(e)
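
As a usage note for the new step() API in this file: with wait=False the call returns a StepControl, and with require_trigger=True the all-reduce is held back until the user calls allow_allreduce(). Below is a minimal sketch of that flow; the tensor shape, the "demo_run" prefix, the sleep/timeout values and the single-peer DHT are placeholder assumptions (a real run needs at least min_group_size peers and a DHT bootstrapped via initial_peers), not values taken from this patch.

```python
import time
import torch

from hivemind.averaging import DecentralizedAverager
from hivemind.dht import DHT
from hivemind.utils.timed_storage import get_dht_time

# Placeholder setup: a single local peer with one small tensor to average.
dht = DHT(start=True)
averager = DecentralizedAverager(
    averaged_tensors=[torch.randn(16)],
    dht=dht,
    start=True,
    prefix="demo_run",                 # shared prefix for all group keys (placeholder name)
    min_matchmaking_time=5.0,          # replaces the deprecated averaging_expiration
)

# Schedule one averaging round in the background and trigger all-reduce manually.
control = averager.step(
    wait=False,                        # return a StepControl instead of blocking
    require_trigger=True,              # hold all-reduce until allow_allreduce() is called
    scheduled_time=get_dht_time() + 10,
)

control.weight = 2.0                   # may still be adjusted before all-reduce begins
time.sleep(1.0)                        # e.g. finish accumulating local updates here
control.allow_allreduce()              # let the averager proceed once a group is found

try:
    gathered = control.result(timeout=60)  # gathered metadata on success; None or exception on failure
finally:
    averager.shutdown()
    dht.shutdown()
```

Until allow_allreduce() is called, a matched step sits in the AWAITING_TRIGGER stage, which is what lets callers overlap matchmaking with local computation.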

+ 165 - 0
hivemind/averaging/control.py

@@ -0,0 +1,165 @@
+import os
+import struct
+from enum import Enum
+from typing import Optional
+
+import numpy as np
+import torch
+
+from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger
+
+logger = get_logger(__name__)
+
+
+class AveragingStage(Enum):
+    IDLE = 0  # still initializing
+    LOOKING_FOR_GROUP = 1  # running decentralized matchmaking, can't run allreduce yet
+    AWAITING_TRIGGER = 2  # waiting for user to set the trigger that allows running allreduce
+    RUNNING_ALLREDUCE = 3  # exchanging tensors with groupmates
+    FINISHED = 4  # either done or failed with exception
+
+
+class StepControl(MPFuture):
+    """
+    An auxiliary data structure that allows user to control stages and track progress in a single averaging step
+
+    :param scheduled_time: estimated time when averaging should begin. Will be used for scheduling
+    :param deadline: if averaging is still in progress at this time, it should be stopped due to TimeoutError
+    :param allow_retries: if True, allow running matchmaking and all-reduce again if previous attempt fails
+    :param weight: averaging weight, can be changed afterwards
+    :param data_for_gather: send this data to all peers in the next group and gather it from groupmates
+    """
+
+    # indices for the shared buffer
+    _SCHEDULED_TIME, _WEIGHT, _STAGE, _BEGAN_ALLREDUCE = slice(0, 8), slice(8, 16), 16, 17
+
+    def __init__(
+        self,
+        scheduled_time: DHTExpiration,
+        deadline: float,
+        allow_retries: bool,
+        weight: float,
+        data_for_gather: bytes,
+    ):
+        super().__init__()
+        self._data_for_gather, self._deadline, self._allow_retries = data_for_gather, deadline, allow_retries
+        self._trigger: Optional[MPFuture] = None
+        self._cancel: Optional[MPFuture] = None
+
+        # Buffer contents:
+        # scheduled_time (double) | weight (double) | stage (AveragingStage, 1 byte) | began_allreduce: (bool, 1 byte)
+        self._shared_buffer = torch.zeros([18], dtype=torch.uint8).share_memory_()
+        self.stage = AveragingStage.IDLE
+        self.scheduled_time = scheduled_time
+        self.weight = weight
+        self.began_allreduce = False
+
+    def attach(self, trigger: MPFuture, cancel: MPFuture):
+        assert self._trigger is None and self._cancel is None, "Futures are already attached"
+        self._trigger, self._cancel = trigger, cancel
+
+    def allow_allreduce(self):
+        """Allow averager to begin all-reduce when it finds a group. Meant to be triggered by user."""
+        assert self._trigger is not None, "StepControl does not have an attached trigger"
+        if self._trigger.done():
+            logger.warning("Trigger is already set")
+        else:
+            self._trigger.set_result(None)
+
+    async def wait_for_trigger(self):
+        assert self._trigger is not None, "StepControl does not have an attached trigger"
+        await self._trigger
+
+    @property
+    def triggered(self) -> bool:
+        assert self._trigger is not None, "StepControl does not have an attached trigger"
+        return self._trigger.done()
+
+    @property
+    def scheduled_time(self) -> DHTExpiration:
+        return struct.unpack("d", self._shared_buffer[StepControl._SCHEDULED_TIME].numpy().data)[0]
+
+    @scheduled_time.setter
+    def scheduled_time(self, scheduled_time):
+        if self.began_allreduce:
+            logger.warning("Changing scheduled time has no effect after all-reduce has already started")
+        if scheduled_time >= self.deadline:
+            logger.warning("Changing scheduled time to after deadline, averaging will likely fail due to timeout")
+        struct.pack_into("d", self._shared_buffer[StepControl._SCHEDULED_TIME].numpy().data, 0, float(scheduled_time))
+
+    @property
+    def weight(self) -> float:
+        return struct.unpack("d", self._shared_buffer[StepControl._WEIGHT].numpy().data)[0]
+
+    @weight.setter
+    def weight(self, weight: float):
+        assert weight >= 0 and np.isfinite(weight)
+        if self.began_allreduce:
+            logger.warning("Changing weights has no effect after all-reduce has already started")
+        struct.pack_into("d", self._shared_buffer[StepControl._WEIGHT].numpy().data, 0, float(weight))
+
+    @property
+    def stage(self) -> AveragingStage:
+        return AveragingStage(self._shared_buffer[StepControl._STAGE].item())
+
+    @stage.setter
+    def stage(self, stage: AveragingStage):
+        if stage == AveragingStage.RUNNING_ALLREDUCE:
+            self.began_allreduce = True
+        self._shared_buffer[StepControl._STAGE] = stage.value
+
+    @property
+    def began_allreduce(self) -> bool:
+        return bool(self._shared_buffer[StepControl._BEGAN_ALLREDUCE].item())
+
+    @began_allreduce.setter
+    def began_allreduce(self, value: bool):
+        self._shared_buffer[StepControl._BEGAN_ALLREDUCE] = int(value)
+
+    @property
+    def data_for_gather(self) -> bytes:
+        return self._data_for_gather
+
+    @property
+    def deadline(self) -> DHTExpiration:
+        return self._deadline
+
+    def get_timeout(self) -> Optional[DHTExpiration]:
+        return max(0.0, self.deadline - get_dht_time())
+
+    @property
+    def allow_retries(self) -> bool:
+        return self._allow_retries
+
+    def __getstate__(self):
+        return dict(
+            super().__getstate__(),
+            _trigger=self._trigger,
+            _cancel=self._cancel,
+            _shared_buffer=self._shared_buffer,
+            immutable_params=(self._data_for_gather, self._deadline, self._allow_retries),
+        )
+
+    def __setstate__(self, state):
+        super().__setstate__(state)
+        self._trigger, self._cancel, self._shared_buffer = state["_trigger"], state["_cancel"], state["_shared_buffer"]
+        self._data_for_gather, self._deadline, self._allow_retries = state["immutable_params"]
+
+    def __del__(self):
+        if os.getpid() == self._origin_pid and not self.triggered:
+            logger.warning(
+                "Deleted an averaging StepControl, but the step was not triggered. This may cause other "
+                "peers to fail an averaging round via TimeoutError."
+            )
+        super().__del__()
+
+    def cancel(self) -> bool:
+        if self._trigger is not None:
+            self._trigger.cancel()
+        if self._cancel is not None:
+            self._cancel.set_result(None)
+        return super().cancel()
+
+    async def wait_for_cancel(self):
+        """Await for step to be cancelled by the user. Should be called from insider the averager."""
+        await self._cancel
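
A side note on the StepControl internals above: scheduled_time, weight, stage and began_allreduce live in one 18-byte shared uint8 tensor, with the two floats stored as packed doubles, so the user process and the averager process observe the same step state. The snippet below is a standalone sketch of just that packing technique under the same field layout; it is not the StepControl class itself and skips any cross-process locking.

```python
import struct
import torch

# Field layout, mirroring StepControl above: two packed doubles followed by two single-byte flags.
SCHEDULED_TIME, WEIGHT, STAGE, BEGAN_ALLREDUCE = slice(0, 8), slice(8, 16), 16, 17

buffer = torch.zeros([18], dtype=torch.uint8).share_memory_()  # visible to forked child processes

def write_double(field: slice, value: float) -> None:
    # pack a float64 into the corresponding 8-byte slice of the shared buffer
    struct.pack_into("d", buffer[field].numpy().data, 0, float(value))

def read_double(field: slice) -> float:
    return struct.unpack("d", buffer[field].numpy().data)[0]

write_double(SCHEDULED_TIME, 1_700_000_000.0)  # placeholder DHT timestamp
write_double(WEIGHT, 1.0)
buffer[STAGE] = 3                              # e.g. AveragingStage.RUNNING_ALLREDUCE.value
buffer[BEGAN_ALLREDUCE] = 1

assert read_double(SCHEDULED_TIME) == 1_700_000_000.0
assert read_double(WEIGHT) == 1.0
```

Reads and writes go through struct.pack_into / struct.unpack on a numpy view of the tensor, exactly as in the property getters and setters above.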

+ 33 - 97
hivemind/averaging/key_manager.py

@@ -1,4 +1,3 @@
-import asyncio
 import random
 import re
 from typing import List, Optional, Tuple
@@ -12,6 +11,7 @@ from hivemind.utils import DHTExpiration, ValueWithExpiration, get_dht_time, get
 
 GroupKey = str
 GROUP_PATTERN = re.compile("^(([^.])+)[.]0b[01]*$")  # e.g. bert_exp4_averaging.0b01001101
+DEFAULT_NUM_BUCKETS = 256
 logger = get_logger(__name__)
 
 
@@ -25,31 +25,20 @@ class GroupKeyManager:
     Utility class that declares and fetches averaging-related keys using a DHT
     """
 
-    RESERVED_KEY_FOR_NBITS = "::NBITS"
-
     def __init__(
         self,
         dht: DHT,
         prefix: str,
-        initial_group_bits: Optional[str],
-        target_group_size: int,
-        insufficient_size: Optional[int] = None,
-        excessive_size: Optional[int] = None,
-        nbits_expiration: float = 60,
-        nbits_rewrite_grace_period: float = 15,
+        initial_group_bits: str,
+        target_group_size: Optional[int],
     ):
-        assert initial_group_bits is None or all(bit in "01" for bit in initial_group_bits)
-        if initial_group_bits is None:
-            search_result = dht.get(f"{prefix}.0b", latest=True)
-            initial_group_nbits = self.get_suggested_nbits(search_result) or 0
-            initial_group_bits = "".join(random.choice("01") for _ in range(initial_group_nbits))
+        assert all(bit in "01" for bit in initial_group_bits)
+        if target_group_size is not None and not is_power_of_two(target_group_size):
+            logger.warning("It is recommended to set target_group_size to a power of 2")
+
         self.dht, self.prefix, self.group_bits = dht, prefix, initial_group_bits
-        self.peer_id = dht.peer_id
         self.target_group_size = target_group_size
-        self.insufficient_size = insufficient_size or max(1, target_group_size // 2)
-        self.excessive_size = excessive_size or target_group_size * 3
-        self.nbits_expiration, self.nbits_grace_period = nbits_expiration, nbits_rewrite_grace_period
-        self.suggested_nbits: Optional[int] = None
+        self.peer_id = dht.peer_id
 
     @property
     def current_key(self) -> GroupKey:
@@ -91,94 +80,41 @@ class GroupKeyManager:
         assert is_valid_group(group_key), f"Group key {group_key} is invalid, must follow {GROUP_PATTERN}"
         result = await self.dht.get(group_key, latest=True, return_future=True)
         if result is None or not isinstance(result.value, dict):
-            logger.debug(f"Allreduce group not found: {group_key}, creating new group.")
+            logger.debug(f"Allreduce group not found: {group_key}, creating new group")
             return []
-        averagers = [
-            (PeerID(key), looking_for_group.expiration_time)
-            for key, looking_for_group in result.value.items()
-            if key != self.RESERVED_KEY_FOR_NBITS and (not only_active or looking_for_group.value)
-        ]
-        num_active_averagers = sum(
-            1
-            for key, looking_for_group in result.value.items()
-            if key != self.RESERVED_KEY_FOR_NBITS and looking_for_group.value
-        )
-
-        suggested_nbits = self.get_suggested_nbits(result)
-        if (
-            suggested_nbits is not None
-            and suggested_nbits != len(self.group_bits)
-            and suggested_nbits != self.suggested_nbits
-        ):
-            self.suggested_nbits = suggested_nbits
-            logger.warning(f"{self.peer_id} - another averager suggested {self.suggested_nbits}-bit keys")
-        elif num_active_averagers >= self.excessive_size:
-            self.suggested_nbits = max(suggested_nbits or 0, len(self.group_bits) + 1)
-            logger.warning(f"{self.peer_id} - too many peers in bucket, switching to {self.suggested_nbits}-bit keys")
+        averagers = []
+        for key, looking_for_group in result.value.items():
+            try:
+                if only_active and not looking_for_group.value:
+                    continue
+                averagers.append((PeerID(key), looking_for_group.expiration_time))
+            except Exception as e:
+                logger.warning(f"Could not parse group key {key} ({looking_for_group}, exc={e})")
         return averagers
 
-    async def declare_nbits(self, group_key: GroupKey, nbits: int, expiration_time: DHTExpiration) -> bool:
-        """notify other peers that they can run averaging at this depth"""
-        return await self.dht.store(
-            key=group_key,
-            subkey=self.RESERVED_KEY_FOR_NBITS,
-            value=nbits,
-            expiration_time=expiration_time,
-            return_future=True,
-        )
-
-    @classmethod
-    def get_suggested_nbits(cls, search_result: Optional[ValueWithExpiration]) -> Optional[int]:
-        if (
-            isinstance(search_result, ValueWithExpiration)
-            and cls.RESERVED_KEY_FOR_NBITS in search_result.value
-            and isinstance(search_result.value[cls.RESERVED_KEY_FOR_NBITS].value, int)
-        ):
-            return search_result.value[cls.RESERVED_KEY_FOR_NBITS].value
-        else:
-            return None
-
     async def update_key_on_group_assembled(self, group_info: GroupInfo, is_leader: bool = True):
         """this function is triggered every time an averager finds an allreduce group"""
         rng = random.Random(group_info.group_id)
         index = group_info.peer_ids.index(self.peer_id)
-        generalized_index = rng.sample(range(self.target_group_size), group_info.group_size)[index]
-        nbits = int(np.ceil(np.log2(self.target_group_size)))
+        num_buckets = self.target_group_size
+        if num_buckets is None:
+            num_buckets = next_power_of_two(group_info.group_size)
+        generalized_index = rng.sample(range(num_buckets), group_info.group_size)[index]
+        nbits = int(np.ceil(np.log2(num_buckets)))
         new_bits = bin(generalized_index)[2:].rjust(nbits, "0")
         self.group_bits = (self.group_bits + new_bits)[-len(self.group_bits) :] if self.group_bits else ""
         logger.debug(f"{self.peer_id} - updated group key to {self.group_bits}")
 
-        if is_leader and self.insufficient_size < group_info.group_size < self.excessive_size:
-            asyncio.create_task(self.notify_stragglers())
-        if self.suggested_nbits is not None and self.suggested_nbits != len(self.group_bits):
-            num_extra_bits = max(0, self.suggested_nbits - len(self.group_bits))
-            self.group_bits = "".join((random.choice("01") for _ in range(num_extra_bits))) + self.group_bits
-            self.group_bits = self.group_bits[-self.suggested_nbits :]
-        self.suggested_nbits = None
-
     async def update_key_on_not_enough_peers(self):
         """this function is triggered whenever averager fails to assemble group within timeout"""
-        new_nbits = self.suggested_nbits if self.suggested_nbits is not None else len(self.group_bits) - 1
-        prev_nbits, self.group_bits = self.group_bits, self.group_bits[-new_nbits:] if new_nbits else ""
-        if self.group_bits != prev_nbits:
-            logger.warning(f"{self.peer_id} - switching to {len(self.group_bits)}-bit keys")
-        self.suggested_nbits = None
-
-    async def notify_stragglers(self):
-        """Find averagers that have fewer nbits and redirect them to your current nbits"""
-        for nbits in reversed(range(1, len(self.group_bits) - 1)):
-            preceding_key = f"{self.prefix}.0b{self.group_bits[-nbits:] if nbits else ''}"
-            preceding_data, _ = await self.dht.get(preceding_key, latest=False, return_future=True) or ({}, None)
-
-            if len(preceding_data) > 0 and self.RESERVED_KEY_FOR_NBITS not in preceding_data:
-                await self.declare_nbits(preceding_key, len(self.group_bits), get_dht_time() + self.nbits_expiration)
-                break
-
-        root_data, _ = await self.dht.get(f"{self.prefix}.0b", latest=False, return_future=True) or ({}, None)
-        if (
-            isinstance(root_data, dict)
-            and root_data.get(self.RESERVED_KEY_FOR_NBITS, (None, -float("inf")))[1]
-            > get_dht_time() + self.nbits_grace_period
-        ):
-            return
-        await self.declare_nbits(f"{self.prefix}.0b", len(self.group_bits), get_dht_time() + self.nbits_expiration)
+        pass  # to be implemented in subclasses
+
+
+def is_power_of_two(n):
+    """Check whether n is a power of 2"""
+    return (n != 0) and (n & (n - 1) == 0)
+
+
+def next_power_of_two(n):
+    """Round n up to the nearest power of 2"""
+    return 1 if n == 0 else 2 ** (n - 1).bit_length()
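
To make the rewritten update_key_on_group_assembled logic above easier to follow, here is a standalone trace of how a peer derives its next group-key bits from the shared group id; the group id, group size, peer index and previous bits are hypothetical inputs chosen for illustration.

```python
import random
import numpy as np

def is_power_of_two(n):
    """Check whether n is a power of 2"""
    return (n != 0) and (n & (n - 1) == 0)

def next_power_of_two(n):
    """Round n up to the nearest power of 2"""
    return 1 if n == 0 else 2 ** (n - 1).bit_length()

# Hypothetical inputs: a 3-peer group, no explicit target_group_size, 4 previously used bits.
group_id, group_size, my_index, group_bits = b"example-group-id", 3, 1, "0110"

num_buckets = next_power_of_two(group_size)   # 4 buckets when target_group_size is None
rng = random.Random(group_id)                 # the shared group id seeds the same shuffle on every peer
generalized_index = rng.sample(range(num_buckets), group_size)[my_index]
nbits = int(np.ceil(np.log2(num_buckets)))
new_bits = bin(generalized_index)[2:].rjust(nbits, "0")

# Shift the old key left and append the freshly drawn bucket bits on the right.
group_bits = (group_bits + new_bits)[-len(group_bits):] if group_bits else ""
assert len(group_bits) == 4 and is_power_of_two(num_buckets)
print(f"next group key suffix: 0b{group_bits}")
```

Because every peer seeds random.Random with the same group_id, all group members compute the same bucket permutation and therefore agree on their new keys without any extra communication.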

+ 1 - 1
hivemind/averaging/load_balancing.py

@@ -80,7 +80,7 @@ def optimize_parts_lp(vector_size: int, bandwidths: np.ndarray, min_size: int =
             peer_scores[peer_scores < min_size / float(vector_size)] = 0.0
         peer_scores = np.round(peer_scores, LOAD_BALANCING_LP_DECIMALS)
     else:
-        logger.error(f"Failed to solve load-balancing for bandwidths {bandwidths}.")
+        logger.error(f"Failed to solve load-balancing for bandwidths {bandwidths}")
         peer_scores = np.ones(group_size, c.dtype)
 
     return peer_scores[np.argsort(permutation)]

+ 103 - 95
hivemind/averaging/matchmaking.py

@@ -9,13 +9,15 @@ import random
 from math import isfinite
 from typing import AsyncIterator, Dict, Optional, Set, Tuple, Type
 
+from hivemind.averaging.control import StepControl
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.key_manager import GroupKey, GroupKeyManager
 from hivemind.dht import DHT, DHTID, DHTExpiration
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
+from hivemind.p2p.p2p_daemon_bindings.utils import ControlFailure, DispatchFailure
 from hivemind.proto import averaging_pb2
 from hivemind.utils import TimedStorage, get_dht_time, get_logger, timed_storage
-from hivemind.utils.asyncio import anext
+from hivemind.utils.asyncio import anext, cancel_and_wait
 
 logger = get_logger(__name__)
 
@@ -41,18 +43,18 @@ class Matchmaking:
         *,
         servicer_type: Type[ServicerBase],
         prefix: str,
-        target_group_size: int,
+        target_group_size: Optional[int],
         min_group_size: int,
+        min_matchmaking_time: float,
         request_timeout: float,
         client_mode: bool,
-        initial_group_bits: Optional[str] = None,
-        averaging_expiration: float = 15,
+        initial_group_bits: str = "",
     ):
         assert "." not in prefix, "group prefix must be a string without ."
-        if request_timeout is None or request_timeout >= averaging_expiration:
+        if request_timeout is None or request_timeout >= min_matchmaking_time:
             logger.warning(
-                "It is recommended to use request_timeout smaller than averaging_expiration. Otherwise,"
-                "matchmaking can cause deadlocks in some rare cases. Please see Matchmaking docstring."
+                "It is recommended to use request_timeout smaller than min_matchmaking_time. Otherwise,"
+                " matchmaking can cause deadlocks in some rare cases. Please see Matchmaking docstring."
             )
 
         super().__init__()
@@ -67,7 +69,7 @@ class Matchmaking:
         self.schema_hash = schema_hash
         self.group_key_manager = GroupKeyManager(dht, prefix, initial_group_bits, target_group_size)
         self.target_group_size, self.min_group_size = target_group_size, min_group_size
-        self.averaging_expiration, self.request_timeout = averaging_expiration, request_timeout
+        self.min_matchmaking_time, self.request_timeout = min_matchmaking_time, request_timeout
         self.client_mode = client_mode
 
         self.lock_looking_for_group = asyncio.Lock()
@@ -78,8 +80,18 @@ class Matchmaking:
 
         self.current_leader: Optional[PeerID] = None  # iff i am a follower, this is a link to my current leader
         self.current_followers: Dict[PeerID, averaging_pb2.JoinRequest] = {}  # my current followers excluding myself
-        self.potential_leaders = PotentialLeaders(self.peer_id, averaging_expiration, target_group_size)
-        self.data_for_gather: Optional[bytes] = None
+        self.potential_leaders = PotentialLeaders(self.peer_id, min_matchmaking_time, target_group_size)
+        self.step_control: Optional[StepControl] = None
+
+    @contextlib.asynccontextmanager
+    async def looking_for_group(self, step_control: StepControl):
+        async with self.lock_looking_for_group:
+            assert self.step_control is None
+            try:
+                self.step_control = step_control
+                yield
+            finally:
+                self.step_control = None
 
     @property
     def is_looking_for_group(self):
@@ -98,10 +110,9 @@ class Matchmaking:
             f" current key = {self.group_key_manager.current_key}, client_mode={self.client_mode})"
         )
 
-    async def look_for_group(self, *, data_for_gather: bytes, timeout: Optional[float] = None) -> Optional[GroupInfo]:
+    async def look_for_group(self, step: StepControl) -> Optional[GroupInfo]:
         """
-        :param data_for_gather: optionally send this data to all peers in the next group and gather it from groupmates
-        :param timeout: maximum time that may be spent looking for group (does not include allreduce itself)
+        :param step: step parameters and user control structure for the current step
         :returns: an assembled group if successful, None if failed; does NOT perform the actual averaging
         Iterate over the averagers from a given group_identifier that have higher leadership priority than yourself.
         """
@@ -110,11 +121,10 @@ class Matchmaking:
                 "Another look_for_group is already in progress. The current run will be scheduled after"
                 " the existing group is either assembled or disbanded."
             )
-        async with self.lock_looking_for_group:
-            self.data_for_gather = data_for_gather
-            request_leaders_task = asyncio.create_task(self._request_join_potential_leaders(timeout))
+        async with self.looking_for_group(step):
+            request_leaders_task = asyncio.create_task(self._request_join_potential_leaders(step))
             try:
-                return await asyncio.wait_for(self.assembled_group, timeout=timeout)
+                return await asyncio.wait_for(self.assembled_group, timeout=step.get_timeout())
             except asyncio.TimeoutError:
                 return None
 
@@ -127,26 +137,25 @@ class Matchmaking:
                 raise
 
             finally:
-                if not request_leaders_task.done():
-                    request_leaders_task.cancel()
-                if not self.assembled_group.done():
-                    self.assembled_group.cancel()
+                await cancel_and_wait(request_leaders_task)
+                self.assembled_group.cancel()
+
                 while len(self.current_followers) > 0:
                     await self.follower_was_discarded.wait()
                     self.follower_was_discarded.clear()
                 # note: the code above ensures that we send all followers away before creating new future
                 self.assembled_group = asyncio.Future()
                 self.was_accepted_to_group.clear()
-                self.data_for_gather = None
 
-    async def _request_join_potential_leaders(self, timeout: Optional[float]) -> GroupInfo:
+    async def _request_join_potential_leaders(self, step: StepControl) -> GroupInfo:
         """Request leaders from queue until we find the first runner. This coroutine is meant to run in background."""
-        async with self.potential_leaders.begin_search(self.group_key_manager, timeout, declare=not self.client_mode):
+        assert self.is_looking_for_group
+        async with self.potential_leaders.begin_search(step, self.group_key_manager, declare=not self.client_mode):
             while True:
                 try:
                     next_leader = await self.potential_leaders.pop_next_leader()  # throws TimeoutError on expiration
 
-                    group = await self.request_join_group(next_leader, self.potential_leaders.request_expiration_time)
+                    group = await self._request_join_group(next_leader)
                     if group is not None:
                         return group
 
@@ -167,33 +176,32 @@ class Matchmaking:
                         self.assembled_group.set_exception(e)
                     raise e
 
-    async def request_join_group(self, leader: PeerID, expiration_time: DHTExpiration) -> Optional[GroupInfo]:
+    async def _request_join_group(self, leader: PeerID) -> Optional[GroupInfo]:
         """
         :param leader: request this peer to be your leader for allreduce
-        :param expiration_time: inform leader that we intend to begin averaging before this expiration_time
         :returns: if the leader accepted us and started AllReduce, return that AllReduce. Otherwise, return None
         :note: this function does not guarantee that your group leader is the same as :leader: parameter
           The originally specified leader can disband group and redirect us to a different leader
         """
         assert self.is_looking_for_group and self.current_leader is None
-        stream: AsyncIterator[averaging_pb2.MessageFromLeader] = None
+        stream: Optional[AsyncIterator[averaging_pb2.MessageFromLeader]] = None
         try:
             async with self.lock_request_join_group:
                 leader_stub = self._servicer_type.get_stub(self._p2p, leader, namespace=self._prefix)
-
-                stream = leader_stub.rpc_join_group(
+                request_expiration_time = self.get_request_expiration_time()
+                stream = await leader_stub.rpc_join_group(
                     averaging_pb2.JoinRequest(
                         schema_hash=self.schema_hash,
-                        expiration=expiration_time,
+                        expiration=request_expiration_time,
                         client_mode=self.client_mode,
-                        gather=self.data_for_gather,
+                        gather=self.step_control.data_for_gather,
                         group_key=self.group_key_manager.current_key,
                     )
-                ).__aiter__()
+                )
                 message = await asyncio.wait_for(anext(stream), timeout=self.request_timeout)
 
                 if message.code == averaging_pb2.ACCEPTED:
-                    logger.debug(f"{self.peer_id} - joining the group of {leader}; waiting for peers")
+                    logger.debug(f"{self.peer_id} - joining the group of {leader}, waiting for peers")
                     self.current_leader = leader
                     self.was_accepted_to_group.set()
                     if len(self.current_followers) > 0:
@@ -205,7 +213,7 @@ class Matchmaking:
                 return None
 
             async with self.potential_leaders.pause_search():
-                time_to_expiration = max(expiration_time - get_dht_time(), 0.0)
+                time_to_expiration = max(0.0, request_expiration_time - get_dht_time())
                 message = await asyncio.wait_for(anext(stream), time_to_expiration + self.request_timeout)
 
                 if message.code == averaging_pb2.BEGIN_ALLREDUCE:
@@ -218,8 +226,11 @@ class Matchmaking:
                     if suggested_leader != self.peer_id:
                         logger.debug(f"{self} - leader disbanded group and redirected us to {suggested_leader}")
                         self.current_leader = None
-                        await stream.aclose()
-                        return await self.request_join_group(suggested_leader, expiration_time)
+                        try:
+                            await stream.aclose()
+                        except RuntimeError as e:
+                            logger.debug(e, exc_info=True)
+                        return await self._request_join_group(suggested_leader)
                 logger.debug(f"{self} - leader disbanded group")
                 return None
 
@@ -228,15 +239,26 @@ class Matchmaking:
         except asyncio.TimeoutError:
             logger.debug(f"{self} - potential leader {leader} did not respond within {self.request_timeout}")
             return None
-        except (P2PHandlerError, StopAsyncIteration) as e:
-            logger.error(f"{self} - failed to request potential leader {leader}: {e}")
+        except (P2PHandlerError, ControlFailure, DispatchFailure, StopAsyncIteration) as e:
+            logger.debug(f"{self} - failed to request potential leader {leader}:", exc_info=True)
             return None
 
         finally:
             self.was_accepted_to_group.clear()
             self.current_leader = None
             if stream is not None:
-                await stream.aclose()
+                try:
+                    await stream.aclose()
+                except RuntimeError as e:
+                    logger.debug(e, exc_info=True)
+
+    def get_request_expiration_time(self) -> float:
+        """Returns the averager's current expiration time, which is used to send join requests to leaders"""
+        if isfinite(self.potential_leaders.declared_expiration_time):
+            return self.potential_leaders.declared_expiration_time
+        else:
+            scheduled_time = max(self.step_control.scheduled_time, get_dht_time() + self.min_matchmaking_time)
+            return min(scheduled_time, self.potential_leaders.search_end_time)
 
     async def rpc_join_group(
         self, request: averaging_pb2.JoinRequest, context: P2PContext
@@ -252,7 +274,11 @@ class Matchmaking:
                 self.current_followers[context.remote_id] = request
                 yield averaging_pb2.MessageFromLeader(code=averaging_pb2.ACCEPTED)
 
-                if len(self.current_followers) + 1 >= self.target_group_size and not self.assembled_group.done():
+                if (
+                    self.target_group_size is not None
+                    and len(self.current_followers) + 1 >= self.target_group_size
+                    and not self.assembled_group.done()
+                ):
                     # outcome 1: we have assembled a full group and are ready for allreduce
                     await self.leader_assemble_group()
 
@@ -338,7 +364,7 @@ class Matchmaking:
             )
         elif context.remote_id == self.peer_id or context.remote_id in self.current_followers:
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.DUPLICATE_PEER_ID)
-        elif len(self.current_followers) + 1 >= self.target_group_size:
+        elif self.target_group_size is not None and len(self.current_followers) + 1 >= self.target_group_size:
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_IS_FULL)
         else:
             return None
@@ -353,11 +379,11 @@ class Matchmaking:
         random.shuffle(ordered_peer_ids)
 
         gathered = tuple(
-            self.data_for_gather if peer_id == self.peer_id else self.current_followers[peer_id].gather
+            self.step_control.data_for_gather if peer_id == self.peer_id else self.current_followers[peer_id].gather
             for peer_id in ordered_peer_ids
         )
 
-        logger.debug(f"{self.peer_id} - assembled group of {len(ordered_peer_ids)} peers.")
+        logger.debug(f"{self.peer_id} - assembled group of {len(ordered_peer_ids)} peers")
         group_info = GroupInfo(group_id, tuple(ordered_peer_ids), gathered)
         await self.group_key_manager.update_key_on_group_assembled(group_info, is_leader=True)
         self.assembled_group.set_result(group_info)
@@ -374,7 +400,7 @@ class Matchmaking:
         assert self.peer_id in ordered_peer_ids, "Leader sent us group_peer_ids that does not contain us!"
         assert len(ordered_peer_ids) == len(msg.gathered)
 
-        logger.debug(f"{self.peer_id} - follower assembled group with leader {leader}.")
+        logger.debug(f"{self.peer_id} - follower assembled group with leader {leader}")
         group_info = GroupInfo(group_id, tuple(ordered_peer_ids), tuple(msg.gathered))
         await self.group_key_manager.update_key_on_group_assembled(group_info)
         self.assembled_group.set_result(group_info)
@@ -389,8 +415,8 @@ class Matchmaking:
 class PotentialLeaders:
     """An utility class that searches for averagers that could become our leaders"""
 
-    def __init__(self, peer_id: PeerID, averaging_expiration: DHTExpiration, target_group_size: Optional[int]):
-        self.peer_id, self.averaging_expiration = peer_id, averaging_expiration
+    def __init__(self, peer_id: PeerID, min_matchmaking_time: DHTExpiration, target_group_size: Optional[int]):
+        self.peer_id, self.min_matchmaking_time = peer_id, min_matchmaking_time
         self.target_group_size = target_group_size
         self.running, self.update_triggered, self.update_finished = asyncio.Event(), asyncio.Event(), asyncio.Event()
         self.declared_expiration, self.lock_search, self.lock_declare = asyncio.Event(), asyncio.Lock(), asyncio.Lock()
@@ -402,21 +428,20 @@ class PotentialLeaders:
         self.search_end_time = float("inf")
 
     @contextlib.asynccontextmanager
-    async def begin_search(self, key_manager: GroupKeyManager, timeout: Optional[float], declare: bool = True):
+    async def begin_search(self, step: StepControl, key_manager: GroupKeyManager, declare: bool = True):
         async with self.lock_search:
             self.running.set()
-            self.search_end_time = get_dht_time() + timeout if timeout is not None else float("inf")
+            self.search_end_time = step.deadline if step.deadline is not None else float("inf")
             update_queue_task = asyncio.create_task(self._update_queue_periodically(key_manager))
             if declare:
-                declare_averager_task = asyncio.create_task(self._declare_averager_periodically(key_manager))
+                declare_averager_task = asyncio.create_task(self._declare_averager_periodically(step, key_manager))
 
             try:
                 yield self
             finally:
-                if not update_queue_task.done():
-                    update_queue_task.cancel()
-                if declare and not declare_averager_task.done():
-                    declare_averager_task.cancel()
+                await cancel_and_wait(update_queue_task)
+                if declare:
+                    await cancel_and_wait(declare_averager_task)
 
                 for field in (
                     self.past_attempts,
@@ -469,51 +494,38 @@ class PotentialLeaders:
             self.past_attempts.add((maybe_next_leader, entry.expiration_time))
             return maybe_next_leader
 
-    @property
-    def request_expiration_time(self) -> float:
-        """this averager's current expiration time - used to send join requests to leaders"""
-        if isfinite(self.declared_expiration_time):
-            return self.declared_expiration_time
-        else:
-            return min(get_dht_time() + self.averaging_expiration, self.search_end_time)
-
-    async def _update_queue_periodically(self, key_manager: GroupKeyManager):
-        try:
-            DISCREPANCY = timed_storage.MAX_DHT_TIME_DISCREPANCY_SECONDS
-            while get_dht_time() < self.search_end_time:
-                new_peers = await key_manager.get_averagers(key_manager.current_key, only_active=True)
-                self.max_assured_time = max(
-                    self.max_assured_time, get_dht_time() + self.averaging_expiration - DISCREPANCY
-                )
+    async def _update_queue_periodically(self, key_manager: GroupKeyManager) -> None:
+        DISCREPANCY = timed_storage.MAX_DHT_TIME_DISCREPANCY_SECONDS
+        while get_dht_time() < self.search_end_time:
+            new_peers = await key_manager.get_averagers(key_manager.current_key, only_active=True)
+            self.max_assured_time = max(
+                self.max_assured_time, get_dht_time() + self.min_matchmaking_time - DISCREPANCY
+            )
 
-                self.leader_queue.clear()
-                for peer, peer_expiration_time in new_peers:
-                    if peer == self.peer_id or (peer, peer_expiration_time) in self.past_attempts:
-                        continue
-                    self.leader_queue.store(peer, peer_expiration_time, peer_expiration_time)
-                    self.max_assured_time = max(self.max_assured_time, peer_expiration_time - DISCREPANCY)
+            self.leader_queue.clear()
+            for peer, peer_expiration_time in new_peers:
+                if peer == self.peer_id or (peer, peer_expiration_time) in self.past_attempts:
+                    continue
+                self.leader_queue.store(peer, peer_expiration_time, peer_expiration_time)
+                self.max_assured_time = max(self.max_assured_time, peer_expiration_time - DISCREPANCY)
 
-                self.update_finished.set()
+            self.update_finished.set()
 
-                await asyncio.wait(
-                    {self.running.wait(), self.update_triggered.wait()},
-                    return_when=asyncio.ALL_COMPLETED,
-                    timeout=self.search_end_time - get_dht_time() if isfinite(self.search_end_time) else None,
-                )
-                self.update_triggered.clear()
-        except (concurrent.futures.CancelledError, asyncio.CancelledError):
-            return  # note: this is a compatibility layer for python3.7
-        except Exception as e:
-            logger.error(f"{self.peer_id} - caught {type(e)}: {e}")
-            raise
+            await asyncio.wait(
+                {self.running.wait(), self.update_triggered.wait()},
+                return_when=asyncio.ALL_COMPLETED,
+                timeout=self.search_end_time - get_dht_time() if isfinite(self.search_end_time) else None,
+            )
+            self.update_triggered.clear()
 
-    async def _declare_averager_periodically(self, key_manager: GroupKeyManager):
+    async def _declare_averager_periodically(self, step: StepControl, key_manager: GroupKeyManager) -> None:
         async with self.lock_declare:
             try:
                 while True:
                     await self.running.wait()
-
-                    new_expiration_time = min(get_dht_time() + self.averaging_expiration, self.search_end_time)
+                    new_expiration_time = float(
+                        min(max(step.scheduled_time, get_dht_time() + self.min_matchmaking_time), self.search_end_time)
+                    )
                     self.declared_group_key = group_key = key_manager.current_key
                     self.declared_expiration_time = new_expiration_time
                     self.declared_expiration.set()
@@ -521,10 +533,6 @@ class PotentialLeaders:
                     await asyncio.sleep(self.declared_expiration_time - get_dht_time())
                     if self.running.is_set() and len(self.leader_queue) == 0:
                         await key_manager.update_key_on_not_enough_peers()
-            except (concurrent.futures.CancelledError, asyncio.CancelledError):
-                pass  # note: this is a compatibility layer for python3.7
-            except Exception as e:  # note: we catch exceptions here because otherwise they are never printed
-                logger.error(f"{self.peer_id} - caught {type(e)}: {e}")
             finally:
                 if self.declared_group_key is not None:
                     prev_declared_key, prev_expiration_time = self.declared_group_key, self.declared_expiration_time
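
The expiration logic added above derives the join-request deadline from the user-scheduled step rather than a fixed averaging_expiration. The following standalone sketch (plain Python with toy timestamps, not part of the patch) mirrors the min/max expression used in get_request_expiration_time and _declare_averager_periodically:

    # Sketch of the deadline rule: give peers at least min_matchmaking_time to join,
    # respect the user-scheduled step time, and never exceed the overall search deadline.
    def request_expiration_time(now, scheduled_time, min_matchmaking_time, search_end_time):
        return min(max(scheduled_time, now + min_matchmaking_time), search_end_time)

    now = 1000.0
    # step scheduled well ahead: declare ourselves until the scheduled time
    assert request_expiration_time(now, now + 30.0, 5.0, float("inf")) == now + 30.0
    # step already due: still leave min_matchmaking_time for the group to assemble
    assert request_expiration_time(now, now - 1.0, 5.0, float("inf")) == now + 5.0
    # an explicit deadline (e.g. StepControl.deadline) caps everything
    assert request_expiration_time(now, now + 30.0, 5.0, now + 10.0) == now + 10.0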

+ 92 - 43
hivemind/averaging/partition.py

@@ -3,27 +3,31 @@ Auxiliary data structures for AllReduceRunner
 """
 import asyncio
 from collections import deque
-from typing import AsyncIterable, AsyncIterator, Optional, Sequence, Tuple, TypeVar, Union
+from typing import AsyncIterable, AsyncIterator, Optional, Sequence, Tuple, TypeVar
 
 import numpy as np
 import torch
 
-from hivemind.proto.runtime_pb2 import CompressionType, Tensor
-from hivemind.utils.asyncio import amap_in_executor
-from hivemind.utils.compression import get_nbytes_per_value, serialize_torch_tensor
+from hivemind.compression import CompressionBase, CompressionInfo, NoCompression
+from hivemind.proto import runtime_pb2
+from hivemind.utils import amap_in_executor, as_aiter, get_logger
 
 T = TypeVar("T")
 DEFAULT_PART_SIZE_BYTES = 2 ** 19
+logger = get_logger(__name__)
 
 
 class TensorPartContainer:
     """
     Auxiliary data structure for averaging, responsible for splitting tensors into parts and reassembling them.
     The class is designed to avoid excessive memory allocation and run all heavy computation in background
+
     :param tensors: local tensors to be split and aggregated
     :param peer_fractions: for each peer, a target fraction of vector elements that this peer should average
-    :param compression_type: optionally compress tensors with this compression algorithm before sending them to peers
+    :param compression: optionally compress tensors with this compression algorithm before sending them to peers
     :param part_size_bytes: greedily split tensors into parts of up to this many bytes (after compression)
+    :param tensor_infos: CompressionInfo for each respective tensor; this determines how the tensor will be compressed
+    :param return_deltas: if True, output tensors are differences (aggregated tensor - local tensor)
     :param prefetch: when compressing, pre-compute this many compressed tensors in background
     """
 
@@ -31,20 +35,24 @@ class TensorPartContainer:
         self,
         tensors: Sequence[torch.Tensor],
         peer_fractions: Sequence[float],
-        compression_type: Union["CompressionType", Sequence["CompressionType"]] = CompressionType.NONE,
+        compression: CompressionBase = NoCompression(),
         part_size_bytes: int = DEFAULT_PART_SIZE_BYTES,
+        tensor_infos: Optional[Sequence[CompressionInfo]] = None,
+        return_deltas: bool = True,
         prefetch: int = 1,
     ):
-        if not isinstance(compression_type, Sequence):
-            compression_type = [compression_type] * len(tensors)
-        assert len(compression_type) == len(tensors), "compression types do not match the number of tensors"
+        if tensor_infos is None:
+            tensor_infos = tuple(CompressionInfo.from_tensor(x, key=i) for i, x in enumerate(tensors))
+        assert len(tensor_infos) == len(tensors), "compression types do not match the number of tensors"
         self.local_tensors, self.peer_fractions, self.group_size = tensors, peer_fractions, len(peer_fractions)
-        self.compression_type, self.part_size_bytes, self.prefetch = compression_type, part_size_bytes, prefetch
+        self.compression, self.part_size_bytes, self.tensor_infos = compression, part_size_bytes, tensor_infos
         self.total_size = sum(tensor.numel() for tensor in tensors)
-        self._input_parts_by_peer: List[Deque[Tuple[torch.Tensor, Type[CompressionType]]]] = [
-            deque() for _ in range(self.group_size)
-        ]
-        self._output_parts_by_peer: List[Deque[torch.Tensor]] = [deque() for _ in range(self.group_size)]
+        self.failed_size = 0
+        self.return_deltas = return_deltas
+        self.prefetch = prefetch
+
+        self._input_parts_by_peer = [deque() for _ in range(self.group_size)]
+        self._output_parts_by_peer = [deque() for _ in range(self.group_size)]
         self._inputs_consumed_by_peer = [False for _ in range(self.group_size)]
         self._output_part_available = [asyncio.Event() for _ in range(self.group_size)]
         self._outputs_registered_by_peer = [0 for _ in range(self.group_size)]
@@ -58,11 +66,13 @@ class TensorPartContainer:
         pivots = (np.cumsum(peer_fractions) / np.sum(peer_fractions) * self.total_size).astype(np.int64)
         pivots[-1] = self.total_size
 
-        for tensor, tensor_compression in zip(self.local_tensors, compression_type):
-            part_size_values = int(part_size_bytes / get_nbytes_per_value(tensor.dtype, tensor_compression))
+        for tensor, info in zip(self.local_tensors, self.tensor_infos):
+            bytes_per_value = tensor.element_size() * compression.estimate_compression_ratio(info)
+            part_size_values = int(part_size_bytes / bytes_per_value)
             tensor_parts = tensor.detach().view(-1).split(part_size_values)
             self.num_parts_by_tensor.append(len(tensor_parts))
-            for part in tensor_parts:
+            for part_index, part in enumerate(tensor_parts):
+                part_info = info.get_part(part_index, part_size_values)
                 if current_length + len(part) > pivots[current_peer_index]:
                     # switch to next peer; if a part lands between parts of two or
                     # more peers, assign that part to the peer with highest intersection
@@ -73,9 +83,9 @@ class TensorPartContainer:
                         current_peer_part_end = min(current_length + len(part), pivots[current_peer_index])
                         peer_intersections.append(current_peer_part_end - pivots[current_peer_index - 1])
                     assigned_peer_index = prev_peer_index + np.argmax(peer_intersections)
-                    self._input_parts_by_peer[assigned_peer_index].append((part, tensor_compression))
+                    self._input_parts_by_peer[assigned_peer_index].append((part, part_info))
                 else:
-                    self._input_parts_by_peer[current_peer_index].append((part, tensor_compression))
+                    self._input_parts_by_peer[current_peer_index].append((part, part_info))
                 current_length += len(part)
 
         assert current_length == self.total_size
@@ -87,21 +97,16 @@ class TensorPartContainer:
         assert not self._inputs_consumed_by_peer[peer_index], "input parts of a given peer are already deallocated."
         self._inputs_consumed_by_peer[peer_index] = True
         input_parts = tuple(part for part, compression in self._input_parts_by_peer[peer_index])
-        self._input_parts_by_peer[peer_index].clear()
         return input_parts
 
     @torch.no_grad()
-    async def iterate_input_parts_for(self, peer_index: int) -> AsyncIterator[Tensor]:
+    async def iterate_input_parts_for(self, peer_index: int) -> AsyncIterator[runtime_pb2.Tensor]:
         """iterate serialized tensor parts for a peer at a given index. Run serialization in background."""
         assert not self._inputs_consumed_by_peer[peer_index], "input parts of a given peer are already deallocated."
         self._inputs_consumed_by_peer[peer_index] = True
-
-        async def _aiterate_parts():
-            for _ in range(self.num_parts_by_peer[peer_index]):
-                yield self._input_parts_by_peer[peer_index].popleft()
-
+        parts_aiter = as_aiter(*self._input_parts_by_peer[peer_index])
         async for serialized_part in amap_in_executor(
-            lambda x_and_compr: serialize_torch_tensor(*x_and_compr), _aiterate_parts(), max_prefetch=self.prefetch
+            lambda x_and_info: self.compression.compress(*x_and_info), parts_aiter, max_prefetch=self.prefetch
         ):
             yield serialized_part
 
@@ -119,6 +124,16 @@ class TensorPartContainer:
         self._outputs_registered_by_peer[peer_index] += 1
         self._output_part_available[peer_index].set()
 
+    def register_failed_reducer(self, peer_index: int):
+        """
+        A given peer failed to aggregate a certain part; use our local part instead and keep track of failed parts
+        """
+        for part_index in range(self._outputs_registered_by_peer[peer_index], self.num_parts_by_peer[peer_index]):
+            part_and_info = self._input_parts_by_peer[peer_index][part_index]
+            part_result_or_delta = torch.zeros_like(part_and_info[0]) if self.return_deltas else part_and_info[0]
+            self.register_processed_part(peer_index, part_index, part_result_or_delta)
+            self.failed_size += part_result_or_delta.numel()
+
     async def iterate_output_tensors(self) -> AsyncIterable[torch.Tensor]:
         """iterate over the outputs of averaging (whether they are average, delta or other aggregation result)"""
         assert not self._outputs_consumed, "output tensors are already iterated and no longer available."
@@ -135,7 +150,7 @@ class TensorPartContainer:
                     self._output_part_available[peer_index].clear()
                     await self._output_part_available[peer_index].wait()
                     if self.finished.is_set():
-                        raise AllreduceException("All-reduce was terminated during iteration.")
+                        raise AllreduceException("All-reduce was terminated during iteration")
 
                 tensor_parts.append(self._output_parts_by_peer[peer_index].popleft())
                 num_parts_processed += 1
@@ -151,9 +166,11 @@ class TensorPartContainer:
         if not self.finished.is_set():
             for peer_index in range(self.group_size):
                 self._inputs_consumed_by_peer[peer_index] = True
+                self._output_part_available[peer_index].set()
                 self._input_parts_by_peer[peer_index].clear()
                 self._output_parts_by_peer[peer_index].clear()
-                self._output_part_available[peer_index].set()
+            if self.failed_size != 0:
+                logger.warning(f"Averaging: received {(1. - self.failed_size / self.total_size) * 100:.1f}% results")
             self._outputs_consumed = True
             self.finished.set()
 
@@ -163,26 +180,27 @@ class TensorPartReducer:
     Auxiliary data structure responsible for running asynchronous all-reduce
     :param part_shapes: a sequence of shapes of torch tensors that will be averaged by this reducer
     :param num_senders: total number of peers in a given all-reduce group that will send gradients
-    :param weights: relative importance of each sender, used for weighted average (default = equal weights)
     :note: even if local peer is not sending data, local parts will be used for shape information
     """
 
-    def __init__(self, part_shapes: Sequence[torch.Size], num_senders: int, weights: Optional[Sequence[float]] = None):
+    def __init__(self, part_shapes: Sequence[torch.Size], num_senders: int):
         self.part_shapes, self.num_senders, self.num_parts = part_shapes, num_senders, len(part_shapes)
-        self.weights = tuple(weights or (1 for _ in range(num_senders)))
-        assert len(self.weights) == self.num_senders, "The number of weights is inconsistent with num_senders"
-        assert all(isinstance(weight, (int, float)) for weight in self.weights)
         self.current_part_index = -1  # index in local_parts of the part that should be loaded next
         self.current_part_accumulated_from = 0  # number of peers from which the current part was accumulated
         self.accumulator: Optional[torch.Tensor] = None  # contains the sum of current tensor part from group peers
         self.denominator = 0.0  # total weight accumulated from all peers for current part
         self.current_part_future = asyncio.Future()
         self.finished = asyncio.Event()
+
+        self.num_parts_received = [0 for _ in range(self.num_senders)]
+        self.sender_failed_after = [float("inf") for _ in range(self.num_senders)]
+        self.num_current_senders = self.num_senders
+
         self.reset_accumulators()
 
     def reset_accumulators(self):
         """(re)create averaging buffers for the next part in line, prepopulate with local tensor part"""
-        assert self.current_part_accumulated_from == self.num_senders or self.current_part_index == -1
+        assert self.current_part_accumulated_from == self.num_current_senders or self.current_part_index == -1
         if self.current_part_index >= self.num_parts - 1:
             self.finalize()
             return
@@ -190,32 +208,53 @@ class TensorPartReducer:
         self.current_part_index += 1
         self.current_part_accumulated_from = 0
         self.current_part_future = asyncio.Future()
+        self.num_current_senders = sum(
+            self.current_part_index < failed_index for failed_index in self.sender_failed_after
+        )
         self.accumulator = torch.zeros(self.part_shapes[self.current_part_index])
         self.denominator = 0.0
 
-    async def accumulate_part(self, sender_index: int, part_index: int, tensor_part: torch.Tensor) -> torch.Tensor:
+    async def accumulate_part(
+        self, sender_index: int, part_index: int, tensor_part: torch.Tensor, weight: float = 1.0
+    ) -> torch.Tensor:
         """Add vector part to accumulator, wait for all other vectors to be added, then return the average part"""
         assert 0 <= sender_index < self.num_senders, "invalid sender index"
         assert 0 <= part_index < self.num_parts, "invalid part index"
+        self.num_parts_received[sender_index] += 1
 
         while part_index > self.current_part_index:
             # wait for previous parts to finish processing ...
             await asyncio.wait({self.current_part_future, self.finished.wait()}, return_when=asyncio.FIRST_COMPLETED)
             if self.finished.is_set():
                 raise AllreduceException(f"attempted to aggregate part in a finalized {self.__class__.__name__}")
+
+        if self.sender_failed_after[sender_index] != float("inf"):
+            raise BannedException(f"sender {sender_index} was banned in background")
         assert part_index == self.current_part_index
 
         current_part_future = self.current_part_future
 
-        self.accumulator.add_(tensor_part, alpha=self.weights[sender_index])
-        self.denominator += self.weights[sender_index]
-        self.current_part_accumulated_from += 1
+        if part_index < self.sender_failed_after[sender_index]:
+            self.accumulator.add_(tensor_part, alpha=weight)
+            self.current_part_accumulated_from += 1
+            self.denominator += weight
+            self.check_current_part_finished()
+        return await current_part_future
 
-        assert self.current_part_accumulated_from <= self.num_senders
-        if self.current_part_accumulated_from == self.num_senders:
-            current_part_future.set_result(self.accumulator.div_(self.denominator))
+    def on_sender_failed(self, sender_index: int):
+        """Exclude that sender's data for averaging any parts that it did not submit yet."""
+        self.sender_failed_after[sender_index] = self.num_parts_received[sender_index]
+        if self.finished.is_set():
+            return
+        if self.current_part_index == self.num_parts_received[sender_index]:
+            self.num_current_senders -= 1
+            self.check_current_part_finished()
+
+    def check_current_part_finished(self):
+        assert self.current_part_accumulated_from <= self.num_current_senders
+        if self.current_part_accumulated_from == self.num_current_senders:
+            self.current_part_future.set_result(self.accumulator.div_(self.denominator))
             self.reset_accumulators()
-        return await current_part_future
 
     def finalize(self):
         if not self.finished.is_set():
@@ -224,9 +263,19 @@ class TensorPartReducer:
                 del self.accumulator
             self.finished.set()
 
+            if self.num_parts != 0 and self.num_senders != 0:
+                parts_expected = self.num_parts * self.num_senders
+                parts_received = sum(self.num_parts_received)
+                if parts_expected != parts_received:
+                    logger.warning(f"Reducer: received {parts_received / parts_expected * 100:.1f}% of input tensors")
+
     def __del__(self):
         self.finalize()
 
 
 class AllreduceException(Exception):
     """A special exception that is raised when allreduce can't continue normally (e.g. disconnected/protocol error)"""
+
+
+class BannedException(AllreduceException):
+    """An exception that indicates that a given sender was banned and will no longer be aggregated"""

+ 52 - 0
hivemind/compression/__init__.py

@@ -0,0 +1,52 @@
+"""
+Compression strategies that reduce the network communication in .averaging, .optim and .moe
+"""
+
+import warnings
+from typing import Dict, Optional
+
+import torch
+
+from hivemind.compression.adaptive import PerTensorCompression, RoleAdaptiveCompression, SizeAdaptiveCompression
+from hivemind.compression.base import CompressionBase, CompressionInfo, NoCompression, TensorRole
+from hivemind.compression.floating import Float16Compression, ScaledFloat16Compression
+from hivemind.compression.quantization import Quantile8BitQuantization, Uniform8BitQuantization
+from hivemind.proto import runtime_pb2
+
+warnings.filterwarnings("ignore", message="The given NumPy array is not writeable", category=UserWarning)
+
+
+BASE_COMPRESSION_TYPES: Dict[str, CompressionBase] = dict(
+    NONE=NoCompression(),
+    FLOAT16=Float16Compression(),
+    MEANSTD_16BIT=ScaledFloat16Compression(),
+    QUANTILE_8BIT=Quantile8BitQuantization(),
+    UNIFORM_8BIT=Uniform8BitQuantization(),
+)
+
+for key in runtime_pb2.CompressionType.keys():
+    assert key in BASE_COMPRESSION_TYPES, f"Compression type {key} does not have a registered deserializer."
+    actual_compression_type = BASE_COMPRESSION_TYPES[key].compression_type
+    assert (
+        runtime_pb2.CompressionType.Name(actual_compression_type) == key
+    ), f"Compression strategy for {key} has inconsistent type"
+
+
+def serialize_torch_tensor(
+    tensor: torch.Tensor,
+    compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE,
+    info: Optional[CompressionInfo] = None,
+    allow_inplace: bool = False,
+    **kwargs,
+) -> runtime_pb2.Tensor:
+    """Serialize a given tensor into a protobuf message using the specified compression strategy"""
+    assert tensor.device == torch.device("cpu")
+    compression = BASE_COMPRESSION_TYPES[runtime_pb2.CompressionType.Name(compression_type)]
+    info = info or CompressionInfo.from_tensor(tensor, **kwargs)
+    return compression.compress(tensor, info, allow_inplace)
+
+
+def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
+    """Restore a pytorch tensor from a protobuf message"""
+    compression = BASE_COMPRESSION_TYPES[runtime_pb2.CompressionType.Name(serialized_tensor.compression)]
+    return compression.extract(serialized_tensor).requires_grad_(serialized_tensor.requires_grad)
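
A quick round-trip through the serialize_torch_tensor / deserialize_torch_tensor helpers defined above (toy tensor, assuming this patch is installed):

    import torch

    from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
    from hivemind.proto import runtime_pb2

    tensor = torch.randn(128, 16)

    # the default NONE compression is lossless
    restored = deserialize_torch_tensor(serialize_torch_tensor(tensor))
    assert torch.allclose(restored, tensor)

    # FLOAT16 halves the payload at the cost of roughly 1e-3 relative precision
    lossy = deserialize_torch_tensor(serialize_torch_tensor(tensor, runtime_pb2.CompressionType.FLOAT16))
    assert torch.allclose(lossy, tensor, rtol=1e-3, atol=1e-3)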

+ 67 - 0
hivemind/compression/adaptive.py

@@ -0,0 +1,67 @@
+from abc import ABC, abstractmethod
+from typing import Mapping, Sequence, Union
+
+import torch
+
+import hivemind
+from hivemind.compression.base import CompressionBase, CompressionInfo, Key, NoCompression, TensorRole
+from hivemind.proto import runtime_pb2
+
+
+class AdaptiveCompressionBase(CompressionBase, ABC):
+    @abstractmethod
+    def choose_compression(self, info: CompressionInfo) -> CompressionBase:
+        ...
+
+    def estimate_compression_ratio(self, info: CompressionInfo) -> float:
+        return self.choose_compression(info).estimate_compression_ratio(info)
+
+    def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
+        return self.choose_compression(info).compress(tensor, info=info, allow_inplace=allow_inplace)
+
+    def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
+        return hivemind.compression.deserialize_torch_tensor(serialized_tensor)
+
+
+class SizeAdaptiveCompression(AdaptiveCompressionBase):
+    """Apply compression strategy 1 if tensor has more than :threshold: elements and strategy 2 otherwise"""
+
+    def __init__(self, threshold: int, less: CompressionBase, greater_equal: CompressionBase):
+        self.threshold, self.less, self.greater_equal = threshold, less, greater_equal
+
+    def choose_compression(self, info: CompressionInfo) -> CompressionBase:
+        return self.greater_equal if info.descriptor.numel() >= self.threshold else self.less
+
+
+class RoleAdaptiveCompression(AdaptiveCompressionBase):
+    """Compress a tensor based on its role in training. Any non-specified compressions will use the "default" option"""
+
+    def __init__(
+        self,
+        *,
+        activation: CompressionBase = None,
+        parameter: CompressionBase = None,
+        gradient: CompressionBase = None,
+        optimizer: CompressionBase = None,
+        default: CompressionBase = NoCompression()
+    ):
+        self.role_compressions = {
+            TensorRole.ACTIVATION: activation or default,
+            TensorRole.PARAMETER: parameter or default,
+            TensorRole.GRADIENT: gradient or default,
+            TensorRole.OPTIMIZER: optimizer or default,
+            TensorRole.UNSPECIFIED: default,
+        }
+
+    def choose_compression(self, info: CompressionInfo) -> CompressionBase:
+        return self.role_compressions[info.role]
+
+
+class PerTensorCompression(AdaptiveCompressionBase):
+    """Manually specify the compression strategy depending on tensor key"""
+
+    def __init__(self, tensor_compressions: Union[Sequence[CompressionBase], Mapping[Key, CompressionBase]]):
+        self.tensor_compressions = tensor_compressions
+
+    def choose_compression(self, info: CompressionInfo) -> CompressionBase:
+        return self.tensor_compressions[info.key]
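
One way to compose the adaptive strategies above, e.g. for gradient averaging; the threshold and role assignments below are illustrative choices, not something this patch prescribes:

    import torch

    from hivemind.compression import (
        CompressionInfo,
        Float16Compression,
        NoCompression,
        RoleAdaptiveCompression,
        SizeAdaptiveCompression,
        TensorRole,
        Uniform8BitQuantization,
    )

    # large gradients -> 8-bit quantization, small ones (biases, layernorms) -> fp16;
    # parameters and optimizer statistics -> fp16; anything unspecified -> uncompressed
    size_adaptive = SizeAdaptiveCompression(
        threshold=2 ** 16, less=Float16Compression(), greater_equal=Uniform8BitQuantization()
    )
    compression = RoleAdaptiveCompression(
        gradient=size_adaptive,
        parameter=Float16Compression(),
        optimizer=Float16Compression(),
        default=NoCompression(),
    )

    grad = torch.randn(1024, 1024)
    info = CompressionInfo.from_tensor(grad, key="layer.weight", role=TensorRole.GRADIENT)
    assert compression.choose_compression(info) is size_adaptive
    assert isinstance(size_adaptive.choose_compression(info), Uniform8BitQuantization)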

+ 92 - 0
hivemind/compression/base.py

@@ -0,0 +1,92 @@
+import dataclasses
+from abc import ABC, abstractmethod
+from enum import Enum, auto
+from typing import Any, Optional
+
+import numpy as np
+import torch
+
+from hivemind.proto import runtime_pb2
+from hivemind.utils.tensor_descr import TensorDescriptor
+
+Key = Any
+
+
+class TensorRole(Enum):
+    ACTIVATION = auto()
+    PARAMETER = auto()
+    GRADIENT = auto()
+    OPTIMIZER = auto()
+    UNSPECIFIED = auto()
+
+
+@dataclasses.dataclass(frozen=True)
+class CompressionInfo:
+    """Auxiliary data structure that contains information about the tensor that determines how it is compressed"""
+
+    key: Key  # name or index of the tensor from named parameters, optimizer state dict or i/o structure
+    descriptor: TensorDescriptor  # data structure that defines shape, dtype, layout and device information
+    role: TensorRole = TensorRole.UNSPECIFIED  # which role does the tensor play with respect to the model
+    part_index: int = 0  # if tensor is sliced into parts, this represents the index within one tensor
+    part_size: Optional[int] = None  # if tensor is sliced into parts, this is the _maximum_ number of values per part
+
+    @classmethod
+    def from_tensor(cls, tensor: torch.Tensor, key: Key = None, descriptor: TensorDescriptor = None, **kwargs):
+        return cls(key, descriptor or TensorDescriptor.from_tensor(tensor), **kwargs)
+
+    def get_part(self, part_index: int, part_size: Optional[int]):
+        return CompressionInfo(self.key, self.descriptor, self.role, part_index=part_index, part_size=part_size)
+
+
+class CompressionBase(ABC):
+    """A base class that applies compression algorithm to a pytorch tensor"""
+
+    compression_type: runtime_pb2.CompressionType
+
+    @abstractmethod
+    def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
+        """
+        Applies the compression algorithm to a tensor based on its meta-parameters
+
+        :param tensor: a pytorch tensor to compress; depending on the application, it is a full tensor or a part
+        :param info: meta-information about the tensor; if partitioning is used, this still describes the full tensor
+        :param allow_inplace: if True, compression may (but does not have to) modify the tensor in-place for efficiency
+        :returns: a protobuf message that encodes the tensor
+        """
+        ...
+
+    @abstractmethod
+    def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
+        """Create a pytorch tensor from the serialized outputs of .compress"""
+        ...
+
+    @abstractmethod
+    def estimate_compression_ratio(self, info: CompressionInfo) -> float:
+        """Estimate the compression ratio without doing the actual compression; lower ratio = better compression"""
+        ...
+
+    def __repr__(self):
+        return f"hivemind.{self.__class__.__name__}()"
+
+
+class NoCompression(CompressionBase):
+    """A dummy compression strategy that preserves the original tensor as is."""
+
+    compression_type = runtime_pb2.CompressionType.NONE
+
+    def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
+        array = tensor.numpy()
+        return runtime_pb2.Tensor(
+            compression=self.compression_type,
+            buffer=array.tobytes(),
+            size=array.shape,
+            dtype=array.dtype.name,
+            requires_grad=tensor.requires_grad,
+        )
+
+    def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
+        array = np.frombuffer(serialized_tensor.buffer, dtype=np.dtype(serialized_tensor.dtype))
+        return torch.as_tensor(array).reshape(tuple(serialized_tensor.size))
+
+    def estimate_compression_ratio(self, info: CompressionInfo) -> float:
+        return 1.0
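
A minimal walk-through of the interface above, using the built-in NoCompression strategy (arbitrary tensor values):

    import torch

    from hivemind.compression.base import CompressionInfo, NoCompression, TensorRole

    tensor = torch.arange(6, dtype=torch.float32).reshape(2, 3)
    info = CompressionInfo.from_tensor(tensor, key="demo", role=TensorRole.PARAMETER)

    codec = NoCompression()
    serialized = codec.compress(tensor, info)  # a runtime_pb2.Tensor message
    assert codec.estimate_compression_ratio(info) == 1.0
    assert torch.equal(codec.extract(serialized), tensor)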

+ 92 - 0
hivemind/compression/floating.py

@@ -0,0 +1,92 @@
+import math
+
+import numpy as np
+import torch
+
+from hivemind.compression.base import CompressionBase, CompressionInfo
+from hivemind.proto import runtime_pb2
+
+
+class Float16Compression(CompressionBase):
+    compression_type = runtime_pb2.CompressionType.FLOAT16
+    FP16_MIN, FP16_MAX = torch.finfo(torch.float16).min, torch.finfo(torch.float16).max
+
+    def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
+        dtype_name = tensor.numpy().dtype.name
+        tensor = tensor.detach().cpu().float()
+        tensor = tensor if allow_inplace else tensor.clone()
+        tensor = tensor.clamp_(self.FP16_MIN, self.FP16_MAX).to(torch.float16)
+        return runtime_pb2.Tensor(
+            compression=self.compression_type,
+            buffer=tensor.numpy().tobytes(),
+            size=tensor.shape,
+            dtype=dtype_name,
+            requires_grad=tensor.requires_grad,
+        )
+
+    def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
+        original_dtype = np.dtype(serialized_tensor.dtype)
+        array = np.frombuffer(serialized_tensor.buffer, dtype=np.float16)
+        return torch.as_tensor(np.asarray(array, dtype=original_dtype)).reshape(tuple(serialized_tensor.size))
+
+    def estimate_compression_ratio(self, info: CompressionInfo) -> float:
+        return 16.0 / get_num_bits(info.descriptor.dtype)
+
+
+class ScaledFloat16Compression(Float16Compression):
+    """A compression strategy that applies mean-std scaling over last axis before casting to float16"""
+
+    compression_type = runtime_pb2.CompressionType.MEANSTD_16BIT
+    FP32_BYTES = torch.finfo(torch.float32).bits // 8
+    FP32_EPS = torch.finfo(torch.float32).eps
+
+    def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
+        dtype_name = tensor.numpy().dtype.name
+        tensor = tensor.detach().cpu().float()
+        tensor = tensor if allow_inplace else tensor.clone()
+        means = torch.mean(tensor, dim=-1, keepdim=True)
+        tensor.sub_(means)
+        stds = tensor.norm(dim=-1, keepdim=True) / math.sqrt(tensor.shape[-1])
+        stds.clamp_min_(self.FP32_EPS)
+        tensor.div_(stds)
+        tensor = tensor.clamp_(self.FP16_MIN, self.FP16_MAX).to(torch.float16)
+
+        data = b"".join((tensor.numpy().tobytes(), means.float().numpy().tobytes(), stds.float().numpy().tobytes()))
+
+        return runtime_pb2.Tensor(
+            compression=self.compression_type,
+            buffer=data,
+            size=tensor.shape,
+            dtype=dtype_name,
+            requires_grad=tensor.requires_grad,
+        )
+
+    def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
+        stats_shape = list(serialized_tensor.size)
+        stats_shape[-1] = 1
+        stats_count = np.prod(stats_shape)
+        means_offset = len(serialized_tensor.buffer) - 2 * stats_count * self.FP32_BYTES
+        stds_offset = len(serialized_tensor.buffer) - stats_count * self.FP32_BYTES
+
+        array = np.frombuffer(serialized_tensor.buffer, dtype=np.float16, count=np.prod(serialized_tensor.size))
+        means = np.frombuffer(serialized_tensor.buffer, dtype=np.float32, offset=means_offset, count=stats_count)
+        stds = np.frombuffer(serialized_tensor.buffer, dtype=np.float32, offset=stds_offset, count=stats_count)
+
+        means = torch.as_tensor(means).reshape(stats_shape)
+        stds = torch.as_tensor(stds).reshape(stats_shape)
+        tensor = torch.as_tensor(np.asarray(array, dtype=serialized_tensor.dtype)).reshape(
+            list(serialized_tensor.size)
+        )
+        return tensor.mul_(stds).add_(means)
+
+
+def get_num_bits(dtype: torch.dtype) -> int:
+    if dtype == torch.bool:
+        return 8  # see https://github.com/pytorch/pytorch/issues/41571
+    elif dtype.is_floating_point:
+        return torch.finfo(dtype).bits
+    else:
+        try:
+            return torch.iinfo(dtype).bits
+        except TypeError:
+            raise TypeError(f"Could not infer size for tensor type {dtype}")
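
To see why the mean-std scaled variant exists, compare reconstruction errors on a tensor with a large shared offset, where a raw fp16 cast loses absolute precision (toy data, assuming this patch is installed):

    import torch

    from hivemind.compression.base import CompressionInfo
    from hivemind.compression.floating import Float16Compression, ScaledFloat16Compression

    weights = torch.randn(4, 256) + 300.0  # large offset, small per-row variation
    info = CompressionInfo.from_tensor(weights, key=0)

    plain, scaled = Float16Compression(), ScaledFloat16Compression()
    err_plain = (plain.extract(plain.compress(weights, info)) - weights).abs().max()
    err_scaled = (scaled.extract(scaled.compress(weights, info)) - weights).abs().max()

    # subtracting the per-row mean before the fp16 cast preserves far more precision here
    assert err_scaled < err_plain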

+ 114 - 0
hivemind/compression/quantization.py

@@ -0,0 +1,114 @@
+import math
+import os
+from abc import ABC, abstractmethod
+from concurrent.futures import ThreadPoolExecutor
+from typing import Tuple
+
+import numpy as np
+import torch
+
+from hivemind.compression.base import CompressionBase, CompressionInfo
+from hivemind.proto import runtime_pb2
+
+EXECUTOR = ThreadPoolExecutor(max_workers=int(os.environ.get("QUANTIZATION_THREADS", 128)))
+
+
+class Quantization(CompressionBase, ABC):
+    codebook_dtype, indices_dtype = np.float32, np.uint8
+
+    @abstractmethod
+    def quantize(self, tensor: torch.Tensor, allow_inplace: bool = False) -> Tuple[np.ndarray, np.ndarray]:
+        """Convert tensor into a pair of (indices, codebook)"""
+        ...
+
+    def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
+        quantized, codebook = self.quantize(tensor.detach(), allow_inplace=allow_inplace)
+        return runtime_pb2.Tensor(
+            compression=self.compression_type,
+            buffer=b"".join((np.int64(len(codebook)).tobytes(), codebook.tobytes(), quantized.tobytes())),
+            size=tensor.shape,
+            dtype=tensor.numpy().dtype.name,
+            requires_grad=tensor.requires_grad,
+        )
+
+    def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
+        codebook_size = int(np.frombuffer(serialized_tensor.buffer, count=1, dtype=np.int64))
+        codebook = np.frombuffer(serialized_tensor.buffer, offset=8, count=codebook_size, dtype=self.codebook_dtype)
+        quantized = np.frombuffer(serialized_tensor.buffer, offset=8 + codebook.nbytes, dtype=self.indices_dtype)
+        quantized = torch.as_tensor(quantized, dtype=torch.int64).reshape(tuple(serialized_tensor.size))
+        codebook = torch.as_tensor(np.asarray(codebook, dtype=serialized_tensor.dtype))
+        return codebook[quantized]
+
+    def estimate_compression_ratio(self, info: CompressionInfo) -> float:
+        return self.n_bits / torch.finfo(info.descriptor.dtype).bits
+
+    @property
+    def n_bits(self):
+        return self.indices_dtype(1).itemsize * 8
+
+    @property
+    def n_bins(self):
+        return 2 ** self.n_bits
+
+
+class Uniform8BitQuantization(Quantization):
+    RANGE_IN_SIGMAS: int = 6
+    compression_type = runtime_pb2.UNIFORM_8BIT
+
+    def quantize(self, tensor: torch.Tensor, allow_inplace: bool = False) -> Tuple[np.ndarray, np.ndarray]:
+        offset = self.n_bins // 2
+        shift = tensor.mean()
+        centered_tensor = tensor.sub_(shift) if allow_inplace else tensor - shift
+        std_unbiased = centered_tensor.norm() / math.sqrt(centered_tensor.numel() - 1)
+        scale = self.RANGE_IN_SIGMAS * std_unbiased / self.n_bins
+        quantized = torch.quantize_per_tensor(centered_tensor, scale, offset, torch.quint8).int_repr()
+        lookup = average_buckets(tensor, quantized, self.n_bins)
+        return np.asarray(quantized, dtype=self.indices_dtype), np.asarray(lookup, dtype=self.codebook_dtype)
+
+
+class Quantile8BitQuantization(Quantization):
+    compression_type = runtime_pb2.QUANTILE_8BIT
+
+    def quantize(self, tensor: torch.Tensor, allow_inplace: bool = False) -> Tuple[np.ndarray, np.ndarray]:
+        tensor = tensor.detach().float()
+        borders = torch.as_tensor(quantile_qq_approximation(tensor.numpy(), self.n_bins + 1)[1:-1])
+        quantized = torch.clamp_(torch.bucketize(tensor, borders), 0, self.n_bins - 1)
+        codebook = average_buckets(tensor, quantized, self.n_bins)
+        return quantized.numpy().astype(np.uint8), codebook.numpy()
+
+
+def average_buckets(tensor: torch.Tensor, quant_weight: torch.Tensor, n_bins: int):
+    """Return the average value in each bucket"""
+    bin_sums = torch.zeros(n_bins).scatter_add_(0, quant_weight.flatten().long(), tensor.flatten())
+    bin_counts = torch.clamp_min_(torch.bincount(quant_weight.flatten(), minlength=n_bins), 1)
+    lookup = bin_sums / bin_counts
+    return lookup
+
+
+def get_chunk_size(num_elements: int, min_chunk_size: int) -> int:
+    """Adjust chunk_size to minimize imbalance between chunk sizes"""
+    if min_chunk_size >= num_elements:
+        return min_chunk_size
+    leftover_elements = num_elements % min_chunk_size
+    num_chunks = num_elements // min_chunk_size
+    return min_chunk_size + (leftover_elements - 1) // num_chunks + 1
+
+
+def quantile_qq_approximation(array: np.ndarray, n_quantiles: int, min_chunk_size: int = 10 ** 5) -> np.ndarray:
+    """Estimate uniform quantiles of data using quantile-of-quantiles. Runs in parallel."""
+    if not array.data.c_contiguous and array.data.f_contiguous:
+        array = array.T
+    array = np.ascontiguousarray(array.reshape(-1))
+    quantiles = np.linspace(0.0, 1.0, num=n_quantiles, dtype=array.dtype)
+    chunk_size = get_chunk_size(len(array), min_chunk_size)
+    num_chunks = (len(array) - 1) // chunk_size + 1
+    partition_quantiles = np.empty((num_chunks, len(quantiles)), dtype=array.dtype)
+
+    jobs = []
+    for i in range(num_chunks):
+        chunk = slice(chunk_size * i, chunk_size * (i + 1))
+        jobs.append(EXECUTOR.submit(np.quantile, array[chunk], quantiles, out=partition_quantiles[i]))
+
+    for job in jobs:
+        job.result()
+    return np.quantile(partition_quantiles, quantiles)
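
An illustrative round-trip through the two 8-bit codecs above: with 256 codebook entries per tensor the payload is roughly 4x smaller than float32, at the cost of coarse per-value precision (toy tensor, assuming this patch is installed):

    import torch

    from hivemind.compression.base import CompressionInfo
    from hivemind.compression.quantization import Quantile8BitQuantization, Uniform8BitQuantization

    gradient = torch.randn(10_000)
    info = CompressionInfo.from_tensor(gradient, key="grad")

    for codec in (Uniform8BitQuantization(), Quantile8BitQuantization()):
        restored = codec.extract(codec.compress(gradient, info))
        assert restored.shape == gradient.shape
        # mean absolute error should be small relative to the tensor's std (~1.0 for randn)
        print(type(codec).__name__, float((restored - gradient).abs().mean()))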

+ 4 - 302
hivemind/dht/__init__.py

@@ -4,7 +4,7 @@ Hivemind DHT is based on Kademlia [1] with added support for improved bulk store
 
 The code is organized as follows:
 
- * **class DHT (__init__.py)** - high-level class for model training. Runs DHTNode in a background process.
+ * **class DHT (dht.py)** - high-level class for model training. Runs DHTNode in a background process.
  * **class DHTNode (node.py)** - an asyncio implementation of dht server, stores AND gets keys.
  * **class DHTProtocol (protocol.py)** - an RPC protocol to request data from dht nodes.
  * **async def traverse_dht (traverse.py)** - a search algorithm that crawls DHT peers.
@@ -12,306 +12,8 @@ The code is organized as follows:
 - [1] Maymounkov P., Mazieres D. (2002) Kademlia: A Peer-to-Peer Information System Based on the XOR Metric.
 - [2] https://github.com/bmuller/kademlia , Brian, if you're reading this: THANK YOU! you're awesome :)
 """
-from __future__ import annotations
 
-import asyncio
-import multiprocessing as mp
-import os
-from concurrent.futures import ThreadPoolExecutor
-from functools import partial
-from typing import Awaitable, Callable, Iterable, List, Optional, Sequence, TypeVar, Union
-
-from multiaddr import Multiaddr
-
-from hivemind.dht.node import DHTNode
-from hivemind.dht.routing import DHTID, DHTKey, DHTValue, Subkey
+from hivemind.dht.dht import DHT
+from hivemind.dht.node import DEFAULT_NUM_WORKERS, DHTNode
+from hivemind.dht.routing import DHTID, DHTExpiration, DHTKey, DHTValue, Subkey
 from hivemind.dht.validation import CompositeValidator, RecordValidatorBase
-from hivemind.p2p import P2P, PeerID
-from hivemind.utils import DHTExpiration, MPFuture, ValueWithExpiration, await_cancelled, get_logger, switch_to_uvloop
-
-logger = get_logger(__name__)
-
-ReturnType = TypeVar("ReturnType")
-
-
-class DHT(mp.Process):
-    """
-    A high-level interface to a hivemind DHT that runs a single DHT node in a background process.
-    * hivemind servers periodically announce their experts via declare_experts (dht_handler.py)
-    * trainers find most suitable experts via RemoteMixtureOfExperts (beam_search.py)
-
-    :param initial_peers: multiaddrs of one or more active DHT peers (if you want to join an existing DHT)
-    :param start: if True, automatically starts the background process on creation. Otherwise await manual start
-    :param daemon: if True, the background process is marked as daemon and automatically terminated after main process
-    :param max_workers: declare_experts and get_experts will use up to this many parallel workers
-      (but no more than one per key)
-    :param expiration: experts declared from this node expire after this many seconds (default = 5 minutes)
-    :param record_validators: instances of RecordValidatorBase used for signing and validating stored records.
-      The validators will be combined using the CompositeValidator class. It merges them when possible
-      (according to their `.merge_with()` policies) and orders them according to the `.priority` properties.
-    :param shutdown_timeout: when calling .shutdown, wait for up to this many seconds before terminating
-    :param await_ready: if True, the constructor waits until the DHT process is ready to process incoming requests
-    :param kwargs: any other params will be forwarded to DHTNode and hivemind.p2p.P2P upon creation
-    """
-
-    _node: DHTNode
-
-    def __init__(
-        self,
-        initial_peers: Optional[Sequence[Union[Multiaddr, str]]] = None,
-        *,
-        start: bool,
-        daemon: bool = True,
-        max_workers: Optional[int] = None,
-        record_validators: Iterable[RecordValidatorBase] = (),
-        shutdown_timeout: float = 3,
-        await_ready: bool = True,
-        **kwargs,
-    ):
-        self._parent_pid = os.getpid()
-        super().__init__()
-
-        if not (
-            initial_peers is None
-            or (
-                isinstance(initial_peers, Sequence)
-                and all(isinstance(item, (Multiaddr, str)) for item in initial_peers)
-            )
-        ):
-            raise TypeError("initial_peers should be of type Optional[Sequence[Union[Multiaddr, str]]]")
-        self.initial_peers = initial_peers
-        self.kwargs = kwargs
-        self.max_workers = max_workers
-
-        self._record_validator = CompositeValidator(record_validators)
-        self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)
-        self.shutdown_timeout = shutdown_timeout
-        self.ready = mp.Event()
-        self.daemon = daemon
-
-        # These values will be fetched from the child process when requested
-        self._peer_id = None
-        self._client_mode = None
-        self._p2p_replica = None
-
-        if start:
-            self.run_in_background(await_ready=await_ready)
-
-    def run(self) -> None:
-        """Serve DHT forever. This function will not return until DHT node is shut down"""
-        loop = switch_to_uvloop()
-
-        with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
-
-            async def _run():
-                self._node = await DHTNode.create(
-                    initial_peers=self.initial_peers,
-                    num_workers=self.max_workers or 1,
-                    record_validator=self._record_validator,
-                    **self.kwargs,
-                )
-                self.ready.set()
-
-                while True:
-                    method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
-                    task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
-                    if method == "_shutdown":
-                        await task
-                        break
-
-            coro = _run()
-            loop.run_until_complete(coro)
-
-    def run_in_background(self, await_ready=True, timeout=None):
-        """
-        Starts DHT in a background process. if await_ready, this method will wait until background dht
-        is ready to process incoming requests or for :timeout: seconds max.
-        """
-        self.start()
-        if await_ready and not self.ready.wait(timeout=timeout):
-            raise TimeoutError(f"DHT didn't notify .ready in {timeout} seconds")
-
-    def shutdown(self) -> None:
-        """Shut down a running dht process"""
-        if self.is_alive():
-            self._outer_pipe.send(("_shutdown", [], {}))
-            self.join(self.shutdown_timeout)
-            if self.is_alive():
-                logger.warning("DHT did not shut down within the grace period; terminating it the hard way.")
-                self.terminate()
-
-    async def _shutdown(self):
-        await self._node.shutdown()
-
-    def get(
-        self, key: DHTKey, latest: bool = False, return_future: bool = False, **kwargs
-    ) -> Union[Optional[ValueWithExpiration[DHTValue]], MPFuture]:
-        """
-        Search for a key across DHT and return either first or latest entry (if found).
-        :param key: same key as in node.store(...)
-        :param latest: if True, finds the latest value, otherwise finds any non-expired value (which is much faster)
-        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
-        :param kwargs: parameters forwarded to DHTNode.get_many_by_id
-        :returns: (value, expiration time); if value was not found, returns None
-        """
-        future = MPFuture()
-        self._outer_pipe.send(("_get", [], dict(key=key, latest=latest, future=future, **kwargs)))
-        return future if return_future else future.result()
-
-    async def _get(self, key: DHTKey, latest: bool, future: MPFuture, **kwargs):
-        try:
-            result = await self._node.get(key, latest=latest, **kwargs)
-            if not future.done():
-                future.set_result(result)
-        except BaseException as e:
-            if not future.done():
-                future.set_exception(e)
-            raise
-
-    def store(
-        self,
-        key: DHTKey,
-        value: DHTValue,
-        expiration_time: DHTExpiration,
-        subkey: Optional[Subkey] = None,
-        return_future: bool = False,
-        **kwargs,
-    ) -> Union[bool, MPFuture]:
-        """
-        Find num_replicas best nodes to store (key, value) and store it there until expiration time.
-
-        :param key: msgpack-serializable key to be associated with value until expiration.
-        :param value: msgpack-serializable value to be stored under a given key until expiration.
-        :param expiration_time: absolute time when the entry should expire, based on hivemind.get_dht_time()
-        :param subkey: if specified, add a value under that subkey instead of overwriting key (see DHTNode.store_many)
-        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
-        :returns: True if store succeeds, False if it fails (due to no response or newer value)
-        """
-        future = MPFuture()
-        self._outer_pipe.send(
-            (
-                "_store",
-                [],
-                dict(key=key, value=value, expiration_time=expiration_time, subkey=subkey, future=future, **kwargs),
-            )
-        )
-        return future if return_future else future.result()
-
-    async def _store(
-        self,
-        key: DHTKey,
-        value: DHTValue,
-        expiration_time: DHTExpiration,
-        subkey: Optional[Subkey],
-        future: MPFuture,
-        **kwargs,
-    ):
-        try:
-            result = await self._node.store(key, value, expiration_time, subkey=subkey, **kwargs)
-            if not future.done():
-                future.set_result(result)
-        except BaseException as e:
-            if not future.done():
-                future.set_exception(e)
-            raise
-
-    def run_coroutine(
-        self, coro: Callable[[DHT, DHTNode], Awaitable[ReturnType]], return_future: bool = False
-    ) -> Union[ReturnType, MPFuture[ReturnType]]:
-        """
-        Execute an asynchronous function on a DHT participant and return results. This is meant as an interface
-         for running custom functions DHT for special cases (e.g. declare experts, beam search)
-
-        :param coro: async function to be executed. Receives 2 arguments: this DHT daemon and a running DHTNode
-        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
-        :returns: coroutine outputs or MPFuture for these outputs
-        :note: the coroutine will be executed inside the DHT process. As such, any changes to global variables or
-          DHT fields made by this coroutine will not be accessible from the host process.
-        :note: all time-consuming operations in coro should be asynchronous (e.g. asyncio.sleep instead of time.sleep)
-          or use asyncio.get_event_loop().run_in_executor(...) to prevent coroutine from blocking background DHT tasks
-        :note: when run_coroutine is called with wait=False, MPFuture can be cancelled to interrupt the task.
-        """
-        future = MPFuture()
-        self._outer_pipe.send(("_run_coroutine", [], dict(coro=coro, future=future)))
-        return future if return_future else future.result()
-
-    async def _run_coroutine(
-        self, coro: Callable[[DHT, DHTNode], Awaitable[ReturnType]], future: MPFuture[ReturnType]
-    ):
-        main_task = asyncio.create_task(coro(self, self._node))
-        cancel_task = asyncio.create_task(await_cancelled(future))
-        try:
-            await asyncio.wait({main_task, cancel_task}, return_when=asyncio.FIRST_COMPLETED)
-            if future.cancelled():
-                main_task.cancel()
-            else:
-                future.set_result(await main_task)
-        except BaseException as e:
-            logger.exception(f"Caught an exception when running a coroutine: {e}")
-            if not future.done():
-                future.set_exception(e)
-
-    def add_validators(self, record_validators: Iterable[RecordValidatorBase]) -> None:
-        if not self.ready.is_set():
-            raise RuntimeError(
-                "Can't append new validators before the DHT process has started. "
-                "Consider adding them to the initial list via DHT.__init__(record_validators=...)"
-            )
-
-        self.run_coroutine(partial(DHT._add_validators, record_validators=record_validators))
-
-    @staticmethod
-    async def _add_validators(_dht: DHT, node: DHTNode, record_validators: Iterable[RecordValidatorBase]) -> None:
-        node.protocol.record_validator.extend(record_validators)
-
-    @property
-    def peer_id(self) -> PeerID:
-        if self._peer_id is None:
-            self._peer_id = self.run_coroutine(DHT._get_peer_id)
-        return self._peer_id
-
-    @staticmethod
-    async def _get_peer_id(_dht: DHT, node: DHTNode) -> PeerID:
-        return node.peer_id
-
-    @property
-    def client_mode(self) -> bool:
-        if self._client_mode is None:
-            self._client_mode = self.run_coroutine(DHT._get_client_mode)
-        return self._client_mode
-
-    @staticmethod
-    async def _get_client_mode(_dht: DHT, node: DHTNode) -> bool:
-        return node.protocol.client_mode
-
-    def get_visible_maddrs(self, latest: bool = False) -> List[Multiaddr]:
-        """
-        Get multiaddrs of the current DHT node that should be accessible by other peers.
-
-        :param latest: ask the P2P daemon to refresh the visible multiaddrs
-        """
-
-        return self.run_coroutine(partial(DHT._get_visible_maddrs, latest=latest))
-
-    @staticmethod
-    async def _get_visible_maddrs(_dht: DHT, node: DHTNode, latest: bool = False) -> List[Multiaddr]:
-        return await node.get_visible_maddrs(latest=latest)
-
-    async def replicate_p2p(self) -> P2P:
-        """
-        Get a replica of a P2P instance used in the DHT process internally.
-        The replica uses the same P2P daemon as the DHT and only works while DHT is alive.
-        """
-
-        if self._p2p_replica is None:
-            daemon_listen_maddr = self.run_coroutine(DHT._get_p2p_daemon_listen_maddr)
-            self._p2p_replica = await P2P.replicate(daemon_listen_maddr)
-        return self._p2p_replica
-
-    @staticmethod
-    async def _get_p2p_daemon_listen_maddr(_dht: DHT, node: DHTNode) -> Multiaddr:
-        return node.p2p.daemon_listen_maddr
-
-    def __del__(self):
-        if self._parent_pid == os.getpid() and self.is_alive():
-            self.shutdown()
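
The net effect of this move is that the public import path is unchanged: `hivemind/dht/__init__.py` now only re-exports the implementation that lives in the new `dht.py` (plus `DEFAULT_NUM_WORKERS` and `DHTExpiration`). A minimal sketch:

```python
from hivemind.dht import DEFAULT_NUM_WORKERS, DHT, DHTNode
from hivemind.dht.dht import DHT as _DHT  # new home of the implementation

assert DHT is _DHT  # the re-export and the moved class are the same object
```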

+ 324 - 0
hivemind/dht/dht.py

@@ -0,0 +1,324 @@
+from __future__ import annotations
+
+import asyncio
+import multiprocessing as mp
+import os
+from functools import partial
+from typing import Awaitable, Callable, Iterable, List, Optional, Sequence, TypeVar, Union
+
+from multiaddr import Multiaddr
+
+from hivemind.dht.node import DEFAULT_NUM_WORKERS, DHTNode
+from hivemind.dht.routing import DHTKey, DHTValue, Subkey
+from hivemind.dht.validation import CompositeValidator, RecordValidatorBase
+from hivemind.p2p import P2P, PeerID
+from hivemind.utils import MPFuture, get_logger, switch_to_uvloop
+from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration
+
+logger = get_logger(__name__)
+ReturnType = TypeVar("ReturnType")
+
+
+class DHT(mp.Process):
+    """
+    A high-level interface to a hivemind DHT that runs a single DHT node in a background process.
+    * hivemind servers periodically announce their experts via declare_experts (dht_handler.py)
+    * trainers find most suitable experts via RemoteMixtureOfExperts (beam_search.py)
+
+    :param initial_peers: multiaddrs of one or more active DHT peers (if you want to join an existing DHT)
+    :param start: if True, automatically starts the background process on creation. Otherwise await manual start
+    :param daemon: if True, the background process is marked as daemon and automatically terminated after main process
+    :param num_workers: declare_experts and get_experts will use up to this many parallel workers
+      (but no more than one per key)
+    :param expiration: experts declared from this node expire after this many seconds (default = 5 minutes)
+    :param record_validators: instances of RecordValidatorBase used for signing and validating stored records.
+      The validators will be combined using the CompositeValidator class. It merges them when possible
+      (according to their `.merge_with()` policies) and orders them according to the `.priority` properties.
+    :param shutdown_timeout: when calling .shutdown, wait for up to this many seconds before terminating
+    :param await_ready: if True, the constructor waits until the DHT process is ready to process incoming requests
+    :param kwargs: any other params will be forwarded to DHTNode and hivemind.p2p.P2P upon creation
+    """
+
+    _node: DHTNode
+
+    def __init__(
+        self,
+        initial_peers: Optional[Sequence[Union[Multiaddr, str]]] = None,
+        *,
+        start: bool,
+        p2p: Optional[P2P] = None,
+        daemon: bool = True,
+        num_workers: int = DEFAULT_NUM_WORKERS,
+        record_validators: Iterable[RecordValidatorBase] = (),
+        shutdown_timeout: float = 3,
+        await_ready: bool = True,
+        **kwargs,
+    ):
+        self._parent_pid = os.getpid()
+        super().__init__()
+
+        if not (
+            initial_peers is None
+            or (
+                isinstance(initial_peers, Sequence)
+                and all(isinstance(item, (Multiaddr, str)) for item in initial_peers)
+            )
+        ):
+            raise TypeError("initial_peers should be of type Optional[Sequence[Union[Multiaddr, str]]]")
+        self.initial_peers = initial_peers
+        self.kwargs = kwargs
+        self.num_workers = num_workers
+
+        self._record_validator = CompositeValidator(record_validators)
+        self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)
+        self.shutdown_timeout = shutdown_timeout
+        self._ready = MPFuture()
+        self.daemon = daemon
+
+        # These values will be fetched from the child process when requested
+        self._peer_id = None
+        self._client_mode = None
+        self._p2p_replica = None
+
+        self._daemon_listen_maddr = p2p.daemon_listen_maddr if p2p is not None else None
+
+        if start:
+            self.run_in_background(await_ready=await_ready)
+
+    def run(self) -> None:
+        """Serve DHT forever. This function will not return until DHT node is shut down"""
+
+        loop = switch_to_uvloop()
+        pipe_semaphore = asyncio.Semaphore(value=0)
+        loop.add_reader(self._inner_pipe.fileno(), pipe_semaphore.release)
+
+        async def _run():
+            try:
+                if self._daemon_listen_maddr is not None:
+                    replicated_p2p = await P2P.replicate(self._daemon_listen_maddr)
+                else:
+                    replicated_p2p = None
+
+                self._node = await DHTNode.create(
+                    initial_peers=self.initial_peers,
+                    num_workers=self.num_workers,
+                    record_validator=self._record_validator,
+                    p2p=replicated_p2p,
+                    **self.kwargs,
+                )
+            except Exception as e:
+                # Loglevel is DEBUG since normally the exception is propagated to the caller
+                logger.debug(e, exc_info=True)
+                self._ready.set_exception(e)
+                return
+            self._ready.set_result(None)
+
+            while True:
+                try:
+                    await asyncio.wait_for(pipe_semaphore.acquire(), timeout=self._node.protocol.wait_timeout)
+                except asyncio.TimeoutError:
+                    pass
+                if not self._inner_pipe.poll():
+                    continue
+                try:
+                    method, args, kwargs = self._inner_pipe.recv()
+                except (OSError, ConnectionError, RuntimeError) as e:
+                    logger.exception(e)
+                    await asyncio.sleep(self._node.protocol.wait_timeout)
+                    continue
+                task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
+                if method == "_shutdown":
+                    await task
+                    break
+
+        loop.run_until_complete(_run())
+
+    def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None:
+        """
+        Starts DHT in a background process. if await_ready, this method will wait until background dht
+        is ready to process incoming requests or for :timeout: seconds max.
+        """
+        self.start()
+        if await_ready:
+            self.wait_until_ready(timeout)
+
+    def wait_until_ready(self, timeout: Optional[float] = None) -> None:
+        self._ready.result(timeout=timeout)
+
+    def shutdown(self) -> None:
+        """Shut down a running dht process"""
+        if self.is_alive():
+            self._outer_pipe.send(("_shutdown", [], {}))
+            self.join(self.shutdown_timeout)
+            if self.is_alive():
+                logger.warning("DHT did not shut down within the grace period; terminating it the hard way")
+                self.terminate()
+
+    async def _shutdown(self):
+        await self._node.shutdown()
+
+    def get(
+        self, key: DHTKey, latest: bool = False, return_future: bool = False, **kwargs
+    ) -> Union[Optional[ValueWithExpiration[DHTValue]], MPFuture]:
+        """
+        Search for a key across DHT and return either first or latest entry (if found).
+        :param key: same key as in node.store(...)
+        :param latest: if True, finds the latest value, otherwise finds any non-expired value (which is much faster)
+        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
+        :param kwargs: parameters forwarded to DHTNode.get_many_by_id
+        :returns: (value, expiration time); if value was not found, returns None
+        """
+        future = MPFuture()
+        self._outer_pipe.send(("_get", [], dict(key=key, latest=latest, future=future, **kwargs)))
+        return future if return_future else future.result()
+
+    async def _get(self, key: DHTKey, latest: bool, future: MPFuture, **kwargs):
+        try:
+            result = await self._node.get(key, latest=latest, **kwargs)
+            if not future.done():
+                future.set_result(result)
+        except BaseException as e:
+            if not future.done():
+                future.set_exception(e)
+            raise
+
+    def store(
+        self,
+        key: DHTKey,
+        value: DHTValue,
+        expiration_time: DHTExpiration,
+        subkey: Optional[Subkey] = None,
+        return_future: bool = False,
+        **kwargs,
+    ) -> Union[bool, MPFuture]:
+        """
+        Find num_replicas best nodes to store (key, value) and store it there until expiration time.
+
+        :param key: msgpack-serializable key to be associated with value until expiration.
+        :param value: msgpack-serializable value to be stored under a given key until expiration.
+        :param expiration_time: absolute time when the entry should expire, based on hivemind.get_dht_time()
+        :param subkey: if specified, add a value under that subkey instead of overwriting key (see DHTNode.store_many)
+        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
+        :returns: True if store succeeds, False if it fails (due to no response or newer value)
+        """
+        future = MPFuture()
+        self._outer_pipe.send(
+            (
+                "_store",
+                [],
+                dict(key=key, value=value, expiration_time=expiration_time, subkey=subkey, future=future, **kwargs),
+            )
+        )
+        return future if return_future else future.result()
+
+    async def _store(
+        self,
+        key: DHTKey,
+        value: DHTValue,
+        expiration_time: DHTExpiration,
+        subkey: Optional[Subkey],
+        future: MPFuture,
+        **kwargs,
+    ):
+        try:
+            result = await self._node.store(key, value, expiration_time, subkey=subkey, **kwargs)
+            if not future.done():
+                future.set_result(result)
+        except BaseException as e:
+            if not future.done():
+                future.set_exception(e)
+            raise
+
+    def run_coroutine(
+        self, coro: Callable[[DHT, DHTNode], Awaitable[ReturnType]], return_future: bool = False
+    ) -> Union[ReturnType, MPFuture[ReturnType]]:
+        """
+        Execute an asynchronous function on a DHT participant and return results. This is meant as an interface
+         for running custom functions on the DHT in special cases (e.g. declare experts, beam search)
+
+        :param coro: async function to be executed. Receives 2 arguments: this DHT daemon and a running DHTNode
+        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
+        :returns: coroutine outputs or MPFuture for these outputs
+        :note: the coroutine will be executed inside the DHT process. As such, any changes to global variables or
+          DHT fields made by this coroutine will not be accessible from the host process.
+        :note: all time-consuming operations in coro should be asynchronous (e.g. asyncio.sleep instead of time.sleep)
+          or use asyncio.get_event_loop().run_in_executor(...) to prevent coroutine from blocking background DHT tasks
+        :note: when run_coroutine is called with return_future=True, MPFuture can be cancelled to interrupt the task.
+        """
+        future = MPFuture()
+        self._outer_pipe.send(("_run_coroutine", [], dict(coro=coro, future=future)))
+        return future if return_future else future.result()
+
+    async def _run_coroutine(
+        self, coro: Callable[[DHT, DHTNode], Awaitable[ReturnType]], future: MPFuture[ReturnType]
+    ):
+        try:
+            future.set_result(await coro(self, self._node))
+        except BaseException as e:
+            logger.exception("Caught an exception when running a coroutine:")
+            future.set_exception(e)
+
+    def add_validators(self, record_validators: Iterable[RecordValidatorBase]) -> None:
+        if not self._ready.done():
+            raise RuntimeError(
+                "Can't append new validators before the DHT process has started. "
+                "Consider adding them to the initial list via DHT.__init__(record_validators=...)"
+            )
+
+        self.run_coroutine(partial(DHT._add_validators, record_validators=record_validators))
+
+    @staticmethod
+    async def _add_validators(_dht: DHT, node: DHTNode, record_validators: Iterable[RecordValidatorBase]) -> None:
+        node.protocol.record_validator.extend(record_validators)
+
+    @property
+    def peer_id(self) -> PeerID:
+        if self._peer_id is None:
+            self._peer_id = self.run_coroutine(DHT._get_peer_id)
+        return self._peer_id
+
+    @staticmethod
+    async def _get_peer_id(_dht: DHT, node: DHTNode) -> PeerID:
+        return node.peer_id
+
+    @property
+    def client_mode(self) -> bool:
+        if self._client_mode is None:
+            self._client_mode = self.run_coroutine(DHT._get_client_mode)
+        return self._client_mode
+
+    @staticmethod
+    async def _get_client_mode(_dht: DHT, node: DHTNode) -> bool:
+        return node.protocol.client_mode
+
+    def get_visible_maddrs(self, latest: bool = False) -> List[Multiaddr]:
+        """
+        Get multiaddrs of the current DHT node that should be accessible by other peers.
+
+        :param latest: ask the P2P daemon to refresh the visible multiaddrs
+        """
+
+        return self.run_coroutine(partial(DHT._get_visible_maddrs, latest=latest))
+
+    @staticmethod
+    async def _get_visible_maddrs(_dht: DHT, node: DHTNode, latest: bool = False) -> List[Multiaddr]:
+        return await node.get_visible_maddrs(latest=latest)
+
+    async def replicate_p2p(self) -> P2P:
+        """
+        Get a replica of a P2P instance used in the DHT process internally.
+        The replica uses the same P2P daemon as the DHT and only works while DHT is alive.
+        """
+
+        if self._p2p_replica is None:
+            daemon_listen_maddr = self.run_coroutine(DHT._get_p2p_daemon_listen_maddr)
+            self._p2p_replica = await P2P.replicate(daemon_listen_maddr)
+        return self._p2p_replica
+
+    @staticmethod
+    async def _get_p2p_daemon_listen_maddr(_dht: DHT, node: DHTNode) -> Multiaddr:
+        return node.p2p.daemon_listen_maddr
+
+    def __del__(self):
+        if self._parent_pid == os.getpid() and self.is_alive():
+            self.shutdown()
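
To illustrate the high-level interface defined above, a minimal usage sketch (store/get semantics follow the docstrings in this diff; `hivemind.get_dht_time()` is the wall-clock helper those docstrings refer to):

```python
import hivemind

dht = hivemind.DHT(start=True)  # spawns the DHT node in a background process
print([str(addr) for addr in dht.get_visible_maddrs()])  # share these as initial_peers

assert dht.store("my_key", "my_value", expiration_time=hivemind.get_dht_time() + 60)
value, expiration_time = dht.get("my_key", latest=True)  # returns None if the key is missing/expired
assert value == "my_value"

dht.shutdown()
```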

+ 18 - 7
hivemind/dht/node.py

@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import asyncio
 import dataclasses
+import os
 import random
 from collections import Counter, defaultdict
 from dataclasses import dataclass, field
@@ -38,6 +39,9 @@ from hivemind.utils.timed_storage import DHTExpiration, TimedStorage, ValueWithE
 logger = get_logger(__name__)
 
 
+DEFAULT_NUM_WORKERS = int(os.getenv("HIVEMIND_DHT_NUM_WORKERS", 4))
+
+
 class DHTNode:
     """
     Asyncio-based class that represents one DHT participant. Created via await DHTNode.create(...)
@@ -110,14 +114,14 @@ class DHTNode:
         cache_refresh_before_expiry: float = 5,
         cache_on_store: bool = True,
         reuse_get_requests: bool = True,
-        num_workers: int = 1,
+        num_workers: int = DEFAULT_NUM_WORKERS,
         chunk_size: int = 16,
         blacklist_time: float = 5.0,
         backoff_rate: float = 2.0,
         client_mode: bool = False,
         record_validator: Optional[RecordValidatorBase] = None,
         authorizer: Optional[AuthorizerBase] = None,
-        validate: bool = True,
+        ensure_bootstrap_success: bool = True,
         strict: bool = True,
         **kwargs,
     ) -> DHTNode:
@@ -152,9 +156,10 @@ class DHTNode:
         :param chunk_size: maximum number of concurrent calls in get_many and cache refresh queue
         :param blacklist_time: excludes non-responsive peers from search for this many seconds (set 0 to disable)
         :param backoff_rate: blacklist time will be multiplied by :backoff_rate: for each successive non-response
-        :param validate: if True, use initial peers to validate that this node is accessible and synchronized
+        :param ensure_bootstrap_success: raise an error if node could not connect to initial peers (or vice versa)
+           If False, print a warning instead. It is recommended to keep this flag enabled unless you know what you're doing.
         :param strict: if True, any error encountered in validation will interrupt the creation of DHTNode
-        :param client_mode: if False (default), this node will accept incoming requests as a full DHT "citzen"
+        :param client_mode: if False (default), this node will accept incoming requests as a full DHT "citizen"
           if True, this node will refuse any incoming requests, effectively being only a client
         :param record_validator: instance of RecordValidatorBase used for signing and validating stored records
         :param authorizer: instance of AuthorizerBase used for signing and validating requests and response
@@ -182,6 +187,8 @@ class DHTNode:
         if p2p is None:
             if not kwargs.get("use_ipfs"):
                 kwargs["initial_peers"] = initial_peers
+            if client_mode:
+                kwargs.setdefault("dht_mode", "client")
             p2p = await P2P.create(**kwargs)
             self._should_shutdown_p2p = True
         else:
@@ -216,7 +223,7 @@ class DHTNode:
             bootstrap_timeout = bootstrap_timeout if bootstrap_timeout is not None else wait_timeout
             start_time = get_dht_time()
             ping_tasks = set(
-                asyncio.create_task(self.protocol.call_ping(peer, validate=validate, strict=strict))
+                asyncio.create_task(self.protocol.call_ping(peer, validate=ensure_bootstrap_success, strict=strict))
                 for peer in initial_peers
             )
             finished_pings, unfinished_pings = await asyncio.wait(ping_tasks, return_when=asyncio.FIRST_COMPLETED)
@@ -231,7 +238,11 @@ class DHTNode:
                 finished_pings |= finished_in_time
 
             if not finished_pings or all(ping.result() is None for ping in finished_pings):
-                logger.warning("DHTNode bootstrap failed: none of the initial_peers responded to a ping.")
+                message = "DHTNode bootstrap failed: none of the initial_peers responded to a ping."
+                if ensure_bootstrap_success:
+                    raise RuntimeError(f"{message} (set ensure_bootstrap_success=False to ignore)")
+                else:
+                    logger.warning(message)
 
             if strict:
                 for task in asyncio.as_completed(finished_pings):
@@ -706,7 +717,7 @@ class DHTNode:
         """Add key to a refresh queue, refresh at :refresh_time: or later"""
         if self.cache_refresh_task is None or self.cache_refresh_task.done() or self.cache_refresh_task.cancelled():
             self.cache_refresh_task = asyncio.create_task(self._refresh_stale_cache_entries())
-            logger.debug("Spawned cache refresh task.")
+            logger.debug("Spawned cache refresh task")
         earliest_key, earliest_item = self.cache_refresh_queue.top()
         if earliest_item is None or refresh_time < earliest_item.expiration_time:
             self.cache_refresh_evt.set()  # if the new element is now the earliest, notify the cache queue
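
Two behavioral changes from the hunks above, shown as a hedged sketch: the default worker count can now be overridden via the `HIVEMIND_DHT_NUM_WORKERS` environment variable, and a failed bootstrap raises by default unless `ensure_bootstrap_success=False` (the renamed `validate` flag) is passed.

```python
import asyncio

from hivemind.dht.node import DEFAULT_NUM_WORKERS, DHTNode

print(DEFAULT_NUM_WORKERS)  # 4 unless the HIVEMIND_DHT_NUM_WORKERS env var says otherwise


async def main():
    # ensure_bootstrap_success only matters when initial_peers are given: if none of them
    # answer a ping, DHTNode.create now raises RuntimeError instead of merely warning.
    node = await DHTNode.create(num_workers=DEFAULT_NUM_WORKERS, ensure_bootstrap_success=False)
    print(node.peer_id)
    await node.shutdown()


asyncio.run(main())
```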

+ 1 - 1
hivemind/dht/protocol.py

@@ -81,7 +81,7 @@ class DHTProtocol(ServicerBase):
 
     def __init__(self, *, _initialized_with_create=False):
         """Internal init method. Please use DHTProtocol.create coroutine to spawn new protocol instances"""
-        assert _initialized_with_create, " Please use DHTProtocol.create coroutine to spawn new protocol instances "
+        assert _initialized_with_create, "Please use DHTProtocol.create coroutine to spawn new protocol instances"
         super().__init__()
 
     def get_stub(self, peer: PeerID) -> AuthRPCWrapper:

+ 2 - 2
hivemind/dht/routing.py

@@ -10,7 +10,7 @@ from itertools import chain
 from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
 
 from hivemind.p2p import PeerID
-from hivemind.utils import MSGPackSerializer, get_dht_time
+from hivemind.utils import DHTExpiration, MSGPackSerializer, get_dht_time
 
 DHTKey = Subkey = DHTValue = Any
 BinaryDHTID = BinaryDHTValue = bytes
@@ -217,7 +217,7 @@ class KBucket:
 
     def __delitem__(self, node_id: DHTID):
         if not (node_id in self.nodes_to_peer_id or node_id in self.replacement_nodes):
-            raise KeyError(f"KBucket does not contain node id={node_id}.")
+            raise KeyError(f"KBucket does not contain node id={node_id}")
 
         if node_id in self.replacement_nodes:
             del self.replacement_nodes[node_id]

+ 1 - 1
hivemind/dht/schema.py

@@ -18,7 +18,7 @@ class SchemaValidator(RecordValidatorBase):
     This allows to enforce types, min/max values, require a subkey to contain a public key, etc.
     """
 
-    def __init__(self, schema: pydantic.BaseModel, *, allow_extra_keys: bool = True, prefix: Optional[str] = None):
+    def __init__(self, schema: Type[pydantic.BaseModel], allow_extra_keys: bool = True, prefix: Optional[str] = None):
         """
         :param schema: The Pydantic model (a subclass of pydantic.BaseModel).
 

+ 9 - 6
hivemind/hivemind_cli/run_server.py

@@ -4,12 +4,13 @@ from pathlib import Path
 import configargparse
 import torch
 
-from hivemind.moe.server import Server
+from hivemind.moe import Server
 from hivemind.moe.server.layers import schedule_name_to_scheduler
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.limits import increase_file_limit
-from hivemind.utils.logging import get_logger
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
+use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
 
 
@@ -21,13 +22,14 @@ def main():
                         help="'localhost' for local connections only, '0.0.0.0' for ipv4 '[::]' for ipv6")
     parser.add_argument('--num_experts', type=int, default=None, required=False, help="The number of experts to serve")
     parser.add_argument('--expert_pattern', type=str, default=None, required=False,
-                        help='all expert uids will follow this pattern, e.g. "myexpert.[0:256].[0:1024]" will sample random expert uids'
-                             ' between myexpert.0.0 and myexpert.255.1023 . Use either num_experts and this or expert_uids')
+                        help='all expert uids will follow this pattern, e.g. "myexpert.[0:256].[0:1024]" will'
+                             ' sample random expert uids between myexpert.0.0 and myexpert.255.1023 . Use either'
+                             ' num_experts and this or expert_uids')
     parser.add_argument('--expert_uids', type=str, nargs="*", default=None, required=False,
                         help="specify the exact list of expert uids to create. Use either this or num_experts"
                              " and expert_pattern, not both")
     parser.add_argument('--expert_cls', type=str, default='ffn', required=False,
-                        help="expert type from test_utils.layers, e.g. 'ffn', 'transformer', 'det_dropout' or 'nop'.")
+                        help="expert type from test_utils.layers, e.g. 'ffn', 'transformer', 'det_dropout' or 'nop'")
     parser.add_argument('--hidden_dim', type=int, default=1024, required=False, help='main dimension for expert_cls')
 
     parser.add_argument('--num_handlers', type=int, default=None, required=False,
@@ -42,7 +44,8 @@ def main():
     parser.add_argument('--optimizer', type=str, default='adam', required=False, help='adam, sgd or none')
     parser.add_argument('--scheduler', type=str, choices=schedule_name_to_scheduler.keys(), default='none',
                         help='LR scheduler type to use')
-    parser.add_argument('--num_warmup_steps', type=int, required=False, help='The number of warmup steps for LR schedule')
+    parser.add_argument('--num_warmup_steps', type=int, required=False,
+                        help='The number of warmup steps for LR schedule')
     parser.add_argument('--num_total_steps', type=int, required=False, help='The total number of steps for LR schedule')
     parser.add_argument('--clip_grad_norm', type=float, required=False, help='Maximum gradient norm used for clipping')
 

+ 8 - 1
hivemind/moe/__init__.py

@@ -1,2 +1,9 @@
 from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
-from hivemind.moe.server import ExpertBackend, Server, declare_experts, get_experts, register_expert_class
+from hivemind.moe.server import (
+    ExpertBackend,
+    Server,
+    background_server,
+    declare_experts,
+    get_experts,
+    register_expert_class,
+)

+ 3 - 3
hivemind/moe/client/beam_search.py

@@ -125,7 +125,7 @@ class MoEBeamSearcher:
         cache_expiration: DHTExpiration,
         num_workers: Optional[int] = None,
     ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
-        num_workers = num_workers or dht.max_workers or beam_size
+        num_workers = num_workers or dht.num_workers or beam_size
         beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = []
         unattempted_indices: List[Coordinate] = sorted(
             range(len(scores)), key=scores.__getitem__
@@ -206,7 +206,7 @@ class MoEBeamSearcher:
         num_workers: Optional[int] = None,
     ) -> Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]]:
         grid_size = grid_size or float("inf")
-        num_workers = num_workers or min(len(prefixes), dht.max_workers or len(prefixes))
+        num_workers = num_workers or min(len(prefixes), dht.num_workers or len(prefixes))
         dht_responses = await node.get_many(keys=prefixes, num_workers=num_workers)
         successors: Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]] = {}
         for prefix, found in dht_responses.items():
@@ -270,7 +270,7 @@ class MoEBeamSearcher:
         cache_expiration: DHTExpiration,
         num_workers: Optional[int] = None,
     ) -> List[RemoteExpert]:
-        num_workers = num_workers or min(beam_size, dht.max_workers or beam_size)
+        num_workers = num_workers or min(beam_size, dht.num_workers or beam_size)
 
         # form initial beam from top-k active L1 prefixes, each row is (score, uid prefix, possible suffixes)
         beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = await cls._get_initial_beam(

+ 3 - 4
hivemind/moe/client/expert.py

@@ -1,13 +1,12 @@
-import pickle
 from typing import Any, Dict, Optional, Tuple
 
 import torch
 import torch.nn as nn
 from torch.autograd.function import once_differentiable
 
+from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
-from hivemind.utils import Endpoint, nested_compare, nested_flatten, nested_pack
-from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.utils import Endpoint, MSGPackSerializer, nested_compare, nested_flatten, nested_pack
 from hivemind.utils.grpc import ChannelCache
 
 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert
@@ -60,7 +59,7 @@ class RemoteExpert(nn.Module):
     def info(self):
         if self._info is None:
             outputs = self.stub.info(runtime_pb2.ExpertUID(uid=self.uid))
-            self._info = pickle.loads(outputs.serialized_info)
+            self._info = MSGPackSerializer.loads(outputs.serialized_info)
         return self._info
 
     def extra_repr(self):
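
The expert metadata returned by `info` is now decoded with hivemind's `MSGPackSerializer` rather than `pickle`, which among other things avoids unpickling untrusted data from remote servers. A tiny round-trip sketch of the serializer itself (the payload keys are illustrative, not the actual wire format):

```python
from hivemind.utils import MSGPackSerializer

payload = MSGPackSerializer.dumps({"uid": "expert.0", "hidden_dim": 1024})
assert MSGPackSerializer.loads(payload) == {"uid": "expert.0", "hidden_dim": 1024}
```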

+ 7 - 7
hivemind/moe/client/moe.py

@@ -9,13 +9,13 @@ import torch
 import torch.nn as nn
 from torch.autograd.function import once_differentiable
 
-import hivemind
+from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.dht import DHT
 from hivemind.moe.client.beam_search import MoEBeamSearcher
 from hivemind.moe.client.expert import DUMMY, RemoteExpert, _get_expert_stub
 from hivemind.moe.server.expert_uid import UID_DELIMITER
 from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
 from hivemind.utils import nested_flatten, nested_map, nested_pack
-from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.logging import get_logger
 
 logger = get_logger(__name__)
@@ -48,7 +48,7 @@ class RemoteMixtureOfExperts(nn.Module):
         *,
         in_features,
         grid_size: Tuple[int, ...],
-        dht: hivemind.DHT,
+        dht: DHT,
         uid_prefix: str,
         k_best: int,
         k_min: int = 1,
@@ -238,14 +238,14 @@ class _RemoteCallMany(torch.autograd.Function):
             pending_tasks, num_samples, k_min, forward_timeout, timeout_after_k_min, detect_anomalies
         )
         if len(responded_inds) < k_min:
-            raise TimeoutError(f"Forward pass: less than {k_min} responded within timeout.")
+            raise TimeoutError(f"Forward pass: less than {k_min} responded within timeout")
 
         if not isinstance(info["outputs_schema"], tuple):
             outputs_schema = (info["outputs_schema"],)
         else:
             outputs_schema = info["outputs_schema"]
         outputs = nested_map(
-            lambda descriptor: descriptor.make_empty(num_samples, max_experts, device=flat_inputs[0].device).zero_(),
+            lambda descriptor: descriptor.make_zeros(num_samples, max_experts, device=flat_inputs[0].device),
             outputs_schema,
         )
 
@@ -330,7 +330,7 @@ class _RemoteCallMany(torch.autograd.Function):
             pending_tasks, num_samples, backward_k_min, backward_timeout, timeout_after_k_min, detect_anomalies
         )
         if len(survivor_inds) < backward_k_min:
-            raise TimeoutError(f"Backward pass: less than {backward_k_min} experts responded within timeout.")
+            raise TimeoutError(f"Backward pass: less than {backward_k_min} experts responded within timeout")
 
         # assemble responses
         batch_inds, expert_inds = map(
@@ -341,7 +341,7 @@ class _RemoteCallMany(torch.autograd.Function):
         # torch tensors, i-th tensor is of shape [num_backward_survivors, *flat_inputs_cpu[i].shape]
 
         grad_inputs = nested_map(
-            lambda descr: descr.make_empty(num_samples, device=flat_grad_outputs[0].device).zero_(),
+            lambda descr: descr.make_zeros(num_samples, device=flat_grad_outputs[0].device),
             list(nested_flatten(info["forward_schema"])),
         )
 

+ 3 - 355
hivemind/moe/server/__init__.py

@@ -1,356 +1,4 @@
-from __future__ import annotations
-
-import multiprocessing as mp
-import multiprocessing.synchronize
-import threading
-from contextlib import contextmanager
-from functools import partial
-from pathlib import Path
-from typing import Dict, List, Optional, Tuple
-
-import torch
-from multiaddr import Multiaddr
-
-import hivemind
-from hivemind.dht import DHT
-from hivemind.moe.server.checkpoints import CheckpointSaver, is_directory, load_experts
-from hivemind.moe.server.connection_handler import ConnectionHandler
-from hivemind.moe.server.dht_handler import DHTHandlerThread, declare_experts, get_experts
+from hivemind.moe.server.dht_handler import declare_experts, get_experts
 from hivemind.moe.server.expert_backend import ExpertBackend
-from hivemind.moe.server.expert_uid import UID_DELIMITER, generate_uids_from_pattern
-from hivemind.moe.server.layers import (
-    add_custom_models_from_file,
-    name_to_block,
-    name_to_input,
-    register_expert_class,
-    schedule_name_to_scheduler,
-)
-from hivemind.moe.server.runtime import Runtime
-from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.utils import BatchTensorDescriptor, Endpoint, find_open_port, get_logger, get_port, replace_port
-
-logger = get_logger(__name__)
-
-
-class Server(threading.Thread):
-    """
-    Server allows you to host "experts" - pytorch sub-networks used by Decentralized Mixture of Experts.
-    After creation, a server should be started: see Server.run or Server.run_in_background.
-
-    A working server does 3 things:
-     - processes incoming forward/backward requests via Runtime (created by the server)
-     - publishes updates to expert status every :update_period: seconds
-     - follows orders from HivemindController - if it exists
-
-    :type dht: DHT or None. Server with dht=None will NOT be visible from DHT,
-     but it will still support accessing experts directly with RemoteExpert(uid=UID, endpoint="IPADDR:PORT").
-    :param expert_backends: dict{expert uid (str) : ExpertBackend} for all expert hosted by this server.
-    :param listen_on: server's dht address that determines how it can be accessed. Address and (optional) port
-    :param num_connection_handlers: maximum number of simultaneous requests. Please note that the default value of 1
-        if too small for normal functioning, we recommend 4 handlers per expert backend.
-    :param update_period: how often will server attempt to publish its state (i.e. experts) to the DHT;
-        if dht is None, this parameter is ignored.
-    :param start: if True, the server will immediately start as a background thread and returns control after server
-        is ready (see .ready below)
-    """
-
-    def __init__(
-        self,
-        dht: Optional[DHT],
-        expert_backends: Dict[str, ExpertBackend],
-        listen_on: Endpoint = "0.0.0.0:*",
-        num_connection_handlers: int = 1,
-        update_period: int = 30,
-        start=False,
-        checkpoint_dir=None,
-        **kwargs,
-    ):
-        super().__init__()
-        self.dht, self.experts, self.update_period = dht, expert_backends, update_period
-        if get_port(listen_on) is None:
-            listen_on = replace_port(listen_on, new_port=find_open_port())
-        self.listen_on, self.port = listen_on, get_port(listen_on)
-
-        self.conn_handlers = [ConnectionHandler(listen_on, self.experts) for _ in range(num_connection_handlers)]
-        if checkpoint_dir is not None:
-            self.checkpoint_saver = CheckpointSaver(expert_backends, checkpoint_dir, update_period)
-        else:
-            self.checkpoint_saver = None
-        self.runtime = Runtime(self.experts, **kwargs)
-
-        if self.dht and self.experts:
-            self.dht_handler_thread = DHTHandlerThread(
-                experts=self.experts,
-                dht=self.dht,
-                endpoint=self.listen_on,
-                update_period=self.update_period,
-                daemon=True,
-            )
-
-        if start:
-            self.run_in_background(await_ready=True)
-
-    @classmethod
-    def create(
-        cls,
-        listen_on="0.0.0.0:*",
-        num_experts: int = None,
-        expert_uids: str = None,
-        expert_pattern: str = None,
-        expert_cls="ffn",
-        hidden_dim=1024,
-        optim_cls=torch.optim.Adam,
-        scheduler: str = "none",
-        num_warmup_steps=None,
-        num_total_steps=None,
-        clip_grad_norm=None,
-        num_handlers=None,
-        min_batch_size=1,
-        max_batch_size=4096,
-        device=None,
-        no_dht=False,
-        initial_peers=(),
-        checkpoint_dir: Optional[Path] = None,
-        compression=CompressionType.NONE,
-        stats_report_interval: Optional[int] = None,
-        custom_module_path=None,
-        *,
-        start: bool,
-    ) -> Server:
-        """
-        Instantiate a server with several identical experts. See argparse comments below for details
-        :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
-        :param num_experts: run this many identical experts
-        :param expert_pattern: a string pattern or a list of expert uids,  example: myprefix.[0:32].[0:256]\
-           means "sample random experts between myprefix.0.0 and myprefix.255.255;
-        :param expert_uids: spawn experts with these exact uids, overrides num_experts and expert_pattern
-        :param expert_cls: expert type from hivemind.moe.server.layers, e.g. 'ffn' or 'transformer';
-        :param hidden_dim: main dimension for expert_cls
-        :param num_handlers: server will use this many parallel processes to handle incoming requests
-        :param min_batch_size: total num examples in the same batch will be greater than this value
-        :param max_batch_size: total num examples in the same batch will not exceed this value
-        :param device: all experts will use this device in torch notation; default: cuda if available else cpu
-
-        :param optim_cls: uses this optimizer to train all experts
-        :param scheduler: if not `none`, the name of the expert LR scheduler
-        :param num_warmup_steps: the number of warmup steps for LR schedule
-        :param num_total_steps: the total number of steps for LR schedule
-        :param clip_grad_norm: maximum gradient norm used for clipping
-
-        :param no_dht: if specified, the server will not be attached to a dht
-        :param initial_peers: multiaddrs of one or more active DHT peers (if you want to join an existing DHT)
-
-        :param checkpoint_dir: directory to save and load expert checkpoints
-
-        :param compression: if specified, use this compression to pack all inputs, outputs and gradients by all experts
-            hosted on this server. For a more fine-grained compression, start server in python and specify compression
-            for each BatchTensorProto in ExpertBackend for the respective experts.
-
-        :param start: if True, starts server right away and returns when server is ready for requests
-        :param stats_report_interval: interval between two reports of batch processing performance statistics
-        """
-        if custom_module_path is not None:
-            add_custom_models_from_file(custom_module_path)
-        assert expert_cls in name_to_block
-
-        if no_dht:
-            dht = None
-        else:
-            dht = hivemind.DHT(initial_peers=initial_peers, start=True)
-            visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
-            logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
-
-        assert (expert_pattern is None and num_experts is None and expert_uids is not None) or (
-            num_experts is not None and expert_uids is None
-        ), "Please provide either expert_uids *or* num_experts (possibly with expert_pattern), but not both"
-
-        if expert_uids is None:
-            if checkpoint_dir is not None:
-                assert is_directory(checkpoint_dir)
-                expert_uids = [
-                    child.name for child in checkpoint_dir.iterdir() if (child / "checkpoint_last.pt").exists()
-                ]
-                total_experts_in_checkpoint = len(expert_uids)
-                logger.info(f"Located {total_experts_in_checkpoint} checkpoints for experts {expert_uids}")
-
-                if total_experts_in_checkpoint > num_experts:
-                    raise ValueError(
-                        f"Found {total_experts_in_checkpoint} checkpoints, but num_experts is set to {num_experts}, "
-                        f"which is smaller. Either increase num_experts or remove unneeded checkpoints."
-                    )
-            else:
-                expert_uids = []
-
-            uids_to_generate = num_experts - len(expert_uids)
-            if uids_to_generate > 0:
-                logger.info(f"Generating {uids_to_generate} expert uids from pattern {expert_pattern}")
-                expert_uids.extend(generate_uids_from_pattern(uids_to_generate, expert_pattern, dht))
-
-        num_experts = len(expert_uids)
-        num_handlers = num_handlers if num_handlers is not None else num_experts * 8
-        optim_cls = optim_cls if optim_cls is not None else partial(torch.optim.SGD, lr=0.0)
-        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
-
-        sample_input = name_to_input[expert_cls](3, hidden_dim)
-        if isinstance(sample_input, tuple):
-            args_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in sample_input)
-        else:
-            args_schema = (BatchTensorDescriptor.from_tensor(sample_input, compression),)
-
-        scheduler = schedule_name_to_scheduler[scheduler]
-
-        # initialize experts
-        experts = {}
-        for expert_uid in expert_uids:
-            expert = name_to_block[expert_cls](hidden_dim)
-            experts[expert_uid] = hivemind.ExpertBackend(
-                name=expert_uid,
-                expert=expert,
-                args_schema=args_schema,
-                optimizer=optim_cls(expert.parameters()),
-                scheduler=scheduler,
-                num_warmup_steps=num_warmup_steps,
-                num_total_steps=num_total_steps,
-                clip_grad_norm=clip_grad_norm,
-                min_batch_size=min_batch_size,
-                max_batch_size=max_batch_size,
-            )
-
-        if checkpoint_dir is not None:
-            load_experts(experts, checkpoint_dir)
-
-        return cls(
-            dht,
-            experts,
-            listen_on=listen_on,
-            num_connection_handlers=num_handlers,
-            device=device,
-            checkpoint_dir=checkpoint_dir,
-            stats_report_interval=stats_report_interval,
-            start=start,
-        )
-
-    def run(self):
-        """
-        Starts Server in the current thread. Initializes dht if necessary, starts connection handlers,
-        runs Runtime (self.runtime) to process incoming requests.
-        """
-        logger.info(f"Server started at {self.listen_on}")
-        logger.info(f"Got {len(self.experts)} experts:")
-        for expert_name, backend in self.experts.items():
-            num_parameters = sum(p.numel() for p in backend.expert.parameters() if p.requires_grad)
-            logger.info(f"{expert_name}: {backend.expert.__class__.__name__}, {num_parameters} parameters")
-
-        if self.dht:
-            if not self.dht.is_alive():
-                self.dht.run_in_background(await_ready=True)
-
-            if self.experts:
-                self.dht_handler_thread.start()
-        if self.checkpoint_saver is not None:
-            self.checkpoint_saver.start()
-
-        for process in self.conn_handlers:
-            if not process.is_alive():
-                process.start()
-            process.ready.wait()
-
-        try:
-            self.runtime.run()
-        finally:
-            self.shutdown()
-
-    def run_in_background(self, await_ready=True, timeout=None):
-        """
-        Starts Server in a background thread. if await_ready, this method will wait until background server
-        is ready to process incoming requests or for :timeout: seconds max.
-        """
-        self.start()
-        if await_ready and not self.ready.wait(timeout=timeout):
-            raise TimeoutError("Server didn't notify .ready in {timeout} seconds")
-
-    @property
-    def ready(self) -> mp.synchronize.Event:
-        """
-        An event (multiprocessing.Event) that is set when the server is ready to process requests.
-
-        Example
-        =======
-        >>> server.start()
-        >>> server.ready.wait(timeout=10)
-        >>> print("Server ready" if server.ready.is_set() else "Server didn't start in 10 seconds")
-        """
-        return self.runtime.ready  # mp.Event that is true if self is ready to process batches
-
-    def shutdown(self):
-        """
-        Gracefully terminate the server, process-safe.
-        Please note that terminating server otherwise (e.g. by killing processes) may result in zombie processes.
-        If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
-        """
-        self.ready.clear()
-
-        for process in self.conn_handlers:
-            process.terminate()
-            process.join()
-        logger.debug("Connection handlers terminated")
-
-        if self.dht and self.experts:
-            self.dht_handler_thread.stop.set()
-            self.dht_handler_thread.join()
-
-        if self.checkpoint_saver is not None:
-            self.checkpoint_saver.stop.set()
-            self.checkpoint_saver.join()
-
-        if self.dht is not None:
-            self.dht.shutdown()
-            self.dht.join()
-
-        logger.debug(f"Shutting down runtime")
-
-        self.runtime.shutdown()
-        logger.info("Server shutdown succesfully")
-
-
-@contextmanager
-def background_server(*args, shutdown_timeout=5, **kwargs) -> Tuple[hivemind.Endpoint, List[Multiaddr]]:
-    """A context manager that creates server in a background thread, awaits .ready on entry and shutdowns on exit"""
-    pipe, runners_pipe = mp.Pipe(duplex=True)
-    runner = mp.Process(target=_server_runner, args=(runners_pipe, *args), kwargs=kwargs)
-    try:
-        runner.start()
-        # once the server is ready, runner will send us
-        # either (False, exception) or (True, (server.listen_on, dht_maddrs))
-        start_ok, data = pipe.recv()
-        if start_ok:
-            yield data
-            pipe.send("SHUTDOWN")  # on exit from context, send shutdown signal
-        else:
-            raise RuntimeError(f"Server failed to start: {data}")
-    finally:
-        runner.join(timeout=shutdown_timeout)
-        if runner.is_alive():
-            logger.info("Server failed to shutdown gracefully, terminating it the hard way...")
-            runner.kill()
-            logger.info("Server terminated.")
-
-
-def _server_runner(pipe, *args, **kwargs):
-    try:
-        server = Server.create(*args, start=True, **kwargs)
-    except Exception as e:
-        logger.exception(f"Encountered an exception when starting a server: {e}")
-        pipe.send((False, f"{type(e).__name__} {e}"))
-        return
-
-    try:
-        dht_maddrs = server.dht.get_visible_maddrs() if server.dht is not None else None
-        pipe.send((True, (server.listen_on, dht_maddrs)))
-        pipe.recv()  # wait for shutdown signal
-
-    finally:
-        logger.info("Shutting down server...")
-        server.shutdown()
-        server.join()
-        logger.info("Server shut down.")
+from hivemind.moe.server.layers import register_expert_class
+from hivemind.moe.server.server import Server, background_server
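
The re-exported register_expert_class above is how custom expert layers enter name_to_block / name_to_input, which Server.create (later in this diff) uses to build experts by name. A minimal sketch, assuming the decorator takes a layer name and a sample-input factory; the name "my_ffn" and the module below are illustrative only:

import torch
from hivemind.moe.server.layers import register_expert_class

# sample-input factory: given (batch_size, hidden_dim), return a dummy input tensor
sample_input = lambda batch_size, hidden_dim: torch.empty((batch_size, hidden_dim))

@register_expert_class("my_ffn", sample_input)  # hypothetical layer name
class MyFeedforward(torch.nn.Module):
    def __init__(self, hidden_dim: int):
        super().__init__()
        self.ffn = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim, 4 * hidden_dim),
            torch.nn.GELU(),
            torch.nn.Linear(4 * hidden_dim, hidden_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.ffn(x)

# once registered, Server.create(expert_cls="my_ffn", ...) can instantiate this block by name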

+ 3 - 4
hivemind/moe/server/connection_handler.py

@@ -1,16 +1,15 @@
 import multiprocessing as mp
 import os
-import pickle
 from typing import Dict
 
 import grpc
 import torch
 
+from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.moe.server.expert_backend import ExpertBackend
 from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
-from hivemind.utils import Endpoint, get_logger, nested_flatten
+from hivemind.utils import Endpoint, MSGPackSerializer, get_logger, nested_flatten
 from hivemind.utils.asyncio import switch_to_uvloop
-from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.grpc import GRPC_KEEPALIVE_OPTIONS
 
 logger = get_logger(__name__)
@@ -61,7 +60,7 @@ class ConnectionHandler(mp.context.ForkProcess):
             logger.debug("Caught KeyboardInterrupt, shutting down")
 
     async def info(self, request: runtime_pb2.ExpertUID, context: grpc.ServicerContext):
-        return runtime_pb2.ExpertInfo(serialized_info=pickle.dumps(self.experts[request.uid].get_info()))
+        return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(self.experts[request.uid].get_info()))
 
     async def forward(self, request: runtime_pb2.ExpertRequest, context: grpc.ServicerContext):
         inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
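
Since info() above now serializes expert metadata with MSGPackSerializer instead of pickle, here is a quick round-trip check; the dumps/loads classmethods are assumed from hivemind.utils and the metadata dict is a made-up example:

>>> from hivemind.utils import MSGPackSerializer
>>> info = {"forward_schema": ["float32", [None, 1024]], "outputs_schema": ["float32", [None, 1024]]}
>>> blob = MSGPackSerializer.dumps(info)   # plain bytes, safe to ship over gRPC
>>> MSGPackSerializer.loads(blob) == info  # msgpack round-trips dicts/lists/strings/numbers
True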

+ 2 - 2
hivemind/moe/server/dht_handler.py

@@ -56,7 +56,7 @@ def declare_experts(
 async def _declare_experts(
     dht: DHT, node: DHTNode, uids: List[ExpertUID], endpoint: Endpoint, expiration: DHTExpiration
 ) -> Dict[ExpertUID, bool]:
-    num_workers = len(uids) if dht.max_workers is None else min(len(uids), dht.max_workers)
+    num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
     expiration_time = get_dht_time() + expiration
     data_to_store: Dict[Tuple[ExpertPrefix, Optional[Coordinate]], DHTValue] = {}
     for uid in uids:
@@ -89,7 +89,7 @@ async def _get_experts(
 ) -> List[Optional[RemoteExpert]]:
     if expiration_time is None:
         expiration_time = get_dht_time()
-    num_workers = len(uids) if dht.max_workers is None else min(len(uids), dht.max_workers)
+    num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
     found: Dict[ExpertUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
 
     experts: List[Optional[RemoteExpert]] = [None] * len(uids)

+ 4 - 3
hivemind/moe/server/expert_backend.py

@@ -74,8 +74,8 @@ class ExpertBackend:
 
         if outputs_schema is None:
             # run expert once to get outputs schema
-            dummy_args = tuple(sample.make_empty(DUMMY_BATCH_SIZE) for sample in args_schema)
-            dummy_kwargs = {key: sample.make_empty(DUMMY_BATCH_SIZE) for key, sample in kwargs_schema.items()}
+            dummy_args = tuple(sample.make_zeros(DUMMY_BATCH_SIZE) for sample in args_schema)
+            dummy_kwargs = {key: sample.make_zeros(DUMMY_BATCH_SIZE) for key, sample in kwargs_schema.items()}
             dummy_outputs = self.expert(*dummy_args, **dummy_kwargs)
             outputs_schema = nested_map(BatchTensorDescriptor.from_tensor, dummy_outputs)
 
@@ -187,7 +187,8 @@ class ExpertBackend:
 
     def get_stats(self) -> Dict:
         """
-        Return current expert training statistics (number of updates, number of processed examples after last optimizer step)
+        Return current expert training statistics (number of updates, number of processed examples after
+        last optimizer step)
         """
         return {"updates": self.update_count, "examples_processed": self.examples_processed}
 

+ 2 - 73
hivemind/moe/server/expert_uid.py

@@ -1,12 +1,7 @@
-import random
 import re
-from typing import List, NamedTuple, Optional, Tuple, Union
+from typing import NamedTuple, Tuple, Union
 
-import hivemind
-from hivemind.dht import DHT
-from hivemind.utils import Endpoint, get_logger
-
-logger = get_logger(__name__)
+from hivemind.utils import Endpoint
 
 ExpertUID, ExpertPrefix, Coordinate, Score = str, str, int, float
 UidEndpoint = NamedTuple("UidEndpoint", [("uid", ExpertUID), ("endpoint", Endpoint)])
@@ -32,69 +27,3 @@ def split_uid(uid_or_prefix: Union[ExpertUID, ExpertPrefix]) -> Tuple[ExpertPref
     uid_or_prefix = uid_or_prefix.rstrip(UID_DELIMITER)
     pivot = uid_or_prefix.rindex(UID_DELIMITER) + 1
     return uid_or_prefix[:pivot], int(uid_or_prefix[pivot:])
-
-
-def generate_uids_from_pattern(
-    num_experts: int, expert_pattern: Optional[str], dht: Optional[DHT] = None, attempts_per_expert=10
-) -> List[str]:
-    """
-    Sample experts from a given pattern, remove duplicates.
-    :param num_experts: sample this many unique expert uids
-    :param expert_pattern: a string pattern or a list of expert uids,  example: myprefix.[0:32].[0:256]\
-     means "sample random experts between myprefix.0.0 and myprefix.255.255;
-    :param dht: if specified, uses this DHT to check that expert uids are not yet occupied by other peers
-    :param attempts_per_expert: give up if unable to generate a new expert uid after this many attempts per uid
-    :note: this method is not strictly process-safe. If several servers run it concurrently, they have
-     a small chance of sampling duplicate expert uids.
-    """
-    remaining_attempts = attempts_per_expert * num_experts
-    found_uids, attempted_uids = list(), set()
-
-    def _generate_uid():
-        if expert_pattern is None:
-            return f"expert{UID_DELIMITER}{attempts_per_expert * num_experts - remaining_attempts}"
-
-        uid = []
-        for block in expert_pattern.split(UID_DELIMITER):
-            try:
-                if "[" not in block and "]" not in block:
-                    uid.append(block)
-                elif block.startswith("[") and block.endswith("]") and ":" in block:
-                    slice_start, slice_end = map(int, block[1:-1].split(":"))
-                    uid.append(str(random.randint(slice_start, slice_end - 1)))
-                else:
-                    raise ValueError("Block must be either fixed or a range [from:to]")
-            except KeyboardInterrupt:
-                raise
-            except Exception as e:
-                raise ValueError(f"Expert pattern {expert_pattern} has invalid block {block}, {e}")
-        return UID_DELIMITER.join(uid)
-
-    while remaining_attempts > 0 and len(found_uids) < num_experts:
-
-        # 1. sample new expert uids at random
-        new_uids = []
-        while len(new_uids) + len(found_uids) < num_experts and remaining_attempts > 0:
-            new_uid = _generate_uid()
-            remaining_attempts -= 1
-            if new_uid not in attempted_uids:
-                attempted_uids.add(new_uid)
-                new_uids.append(new_uid)
-
-        # 2. look into DHT (if given) and remove duplicates
-        if dht:
-            existing_expert_uids = {
-                found_expert.uid
-                for found_expert in hivemind.moe.server.get_experts(dht, new_uids)
-                if found_expert is not None
-            }
-            new_uids = [new_uid for new_uid in new_uids if new_uid not in existing_expert_uids]
-
-        found_uids += new_uids
-
-    if len(found_uids) != num_experts:
-        logger.warning(
-            f"Found only {len(found_uids)} out of {num_experts} free expert uids after "
-            f"{attempts_per_expert * num_experts} attempts"
-        )
-    return found_uids
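
For reference, the helpers that remain in this module operate on dot-delimited uids (UID_DELIMITER), matching the myprefix.[0:32].[0:256] patterns mentioned above. A doctest-style illustration of split_uid on hypothetical uids:

>>> from hivemind.moe.server.expert_uid import split_uid
>>> split_uid("myprefix.3.12")   # full uid -> (prefix up to the last level, last coordinate)
('myprefix.3.', 12)
>>> split_uid("myprefix.3.")     # a trailing delimiter is stripped before splitting
('myprefix.', 3)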

+ 419 - 0
hivemind/moe/server/server.py

@@ -0,0 +1,419 @@
+from __future__ import annotations
+
+import multiprocessing as mp
+import random
+import threading
+from contextlib import contextmanager
+from functools import partial
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import torch
+from multiaddr import Multiaddr
+
+from hivemind.dht import DHT
+from hivemind.moe.server.checkpoints import CheckpointSaver, is_directory, load_experts
+from hivemind.moe.server.connection_handler import ConnectionHandler
+from hivemind.moe.server.dht_handler import DHTHandlerThread, get_experts
+from hivemind.moe.server.expert_backend import ExpertBackend
+from hivemind.moe.server.expert_uid import UID_DELIMITER
+from hivemind.moe.server.layers import (
+    add_custom_models_from_file,
+    name_to_block,
+    name_to_input,
+    schedule_name_to_scheduler,
+)
+from hivemind.moe.server.runtime import Runtime
+from hivemind.proto.runtime_pb2 import CompressionType
+from hivemind.utils.logging import get_logger
+from hivemind.utils.networking import Endpoint, get_free_port, get_port, replace_port
+from hivemind.utils.tensor_descr import DUMMY_BATCH_SIZE, BatchTensorDescriptor
+
+logger = get_logger(__name__)
+
+
+class Server(threading.Thread):
+    """
+    Server allows you to host "experts" - pytorch subnetworks used by Decentralized Mixture of Experts.
+    After creation, a server should be started: see Server.run or Server.run_in_background.
+
+    A working server does two things:
+     - processes incoming forward/backward requests via Runtime (created by the server)
+     - publishes updates to expert status every :update_period: seconds
+
+    :type dht: DHT or None. Server with dht=None will NOT be visible from DHT,
+     but it will still support accessing experts directly with RemoteExpert(uid=UID, endpoint="IPADDR:PORT").
+    :param expert_backends: dict{expert uid (str) : ExpertBackend} for all experts hosted by this server.
+    :param listen_on: server's dht address that determines how it can be accessed. Address and (optional) port
+    :param num_connection_handlers: maximum number of simultaneous requests. Please note that the default value of 1
+        is too small for normal functioning; we recommend 4 handlers per expert backend.
+    :param update_period: how often the server will attempt to publish its state (i.e. experts) to the DHT;
+        if dht is None, this parameter is ignored.
+    :param start: if True, the server will immediately start as a background thread and return control after the
+        server is ready (see .ready below)
+    """
+
+    def __init__(
+        self,
+        dht: Optional[DHT],
+        expert_backends: Dict[str, ExpertBackend],
+        listen_on: Endpoint = "0.0.0.0:*",
+        num_connection_handlers: int = 1,
+        update_period: int = 30,
+        start=False,
+        checkpoint_dir=None,
+        **kwargs,
+    ):
+        super().__init__()
+        self.dht, self.experts, self.update_period = dht, expert_backends, update_period
+        if get_port(listen_on) is None:
+            listen_on = replace_port(listen_on, new_port=get_free_port())
+        self.listen_on, self.port = listen_on, get_port(listen_on)
+
+        self.conn_handlers = [ConnectionHandler(listen_on, self.experts) for _ in range(num_connection_handlers)]
+        if checkpoint_dir is not None:
+            self.checkpoint_saver = CheckpointSaver(expert_backends, checkpoint_dir, update_period)
+        else:
+            self.checkpoint_saver = None
+        self.runtime = Runtime(self.experts, **kwargs)
+
+        if self.dht and self.experts:
+            self.dht_handler_thread = DHTHandlerThread(
+                experts=self.experts,
+                dht=self.dht,
+                endpoint=self.listen_on,
+                update_period=self.update_period,
+                daemon=True,
+            )
+
+        if start:
+            self.run_in_background(await_ready=True)
+
+    @classmethod
+    def create(
+        cls,
+        listen_on="0.0.0.0:*",
+        num_experts: int = None,
+        expert_uids: str = None,
+        expert_pattern: str = None,
+        expert_cls="ffn",
+        hidden_dim=1024,
+        optim_cls=torch.optim.Adam,
+        scheduler: str = "none",
+        num_warmup_steps=None,
+        num_total_steps=None,
+        clip_grad_norm=None,
+        num_handlers=None,
+        min_batch_size=1,
+        max_batch_size=4096,
+        device=None,
+        no_dht=False,
+        initial_peers=(),
+        checkpoint_dir: Optional[Path] = None,
+        compression=CompressionType.NONE,
+        stats_report_interval: Optional[int] = None,
+        custom_module_path=None,
+        *,
+        start: bool,
+    ) -> Server:
+        """
+        Instantiate a server with several identical experts. See argparse comments below for details
+        :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
+        :param num_experts: run this many identical experts
+        :param expert_pattern: a string pattern for expert uids, example: myprefix.[0:32].[0:256]
+           means "sample random experts between myprefix.0.0 and myprefix.31.255";
+        :param expert_uids: spawn experts with these exact uids, overrides num_experts and expert_pattern
+        :param expert_cls: expert type from hivemind.moe.server.layers, e.g. 'ffn' or 'transformer';
+        :param hidden_dim: main dimension for expert_cls
+        :param num_handlers: server will use this many parallel processes to handle incoming requests
+        :param min_batch_size: total num examples in the same batch will be greater than this value
+        :param max_batch_size: total num examples in the same batch will not exceed this value
+        :param device: all experts will use this device in torch notation; default: cuda if available else cpu
+
+        :param optim_cls: uses this optimizer to train all experts
+        :param scheduler: if not `none`, the name of the expert LR scheduler
+        :param num_warmup_steps: the number of warmup steps for LR schedule
+        :param num_total_steps: the total number of steps for LR schedule
+        :param clip_grad_norm: maximum gradient norm used for clipping
+
+        :param no_dht: if specified, the server will not be attached to a dht
+        :param initial_peers: multiaddrs of one or more active DHT peers (if you want to join an existing DHT)
+
+        :param checkpoint_dir: directory to save and load expert checkpoints
+
+        :param compression: if specified, use this compression to pack all inputs, outputs and gradients for all
+            experts hosted on this server. For more fine-grained compression, start the server in Python and specify
+            compression for each BatchTensorProto in ExpertBackend for the respective experts.
+
+        :param start: if True, starts server right away and returns when server is ready for requests
+        :param stats_report_interval: interval between two reports of batch processing performance statistics
+        """
+        if custom_module_path is not None:
+            add_custom_models_from_file(custom_module_path)
+        assert expert_cls in name_to_block
+
+        if no_dht:
+            dht = None
+        else:
+            dht = DHT(initial_peers=initial_peers, start=True)
+            visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
+            logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
+
+        assert (expert_pattern is None and num_experts is None and expert_uids is not None) or (
+            num_experts is not None and expert_uids is None
+        ), "Please provide either expert_uids *or* num_experts (possibly with expert_pattern), but not both"
+
+        if expert_uids is None:
+            if checkpoint_dir is not None:
+                assert is_directory(checkpoint_dir)
+                expert_uids = [
+                    child.name for child in checkpoint_dir.iterdir() if (child / "checkpoint_last.pt").exists()
+                ]
+                total_experts_in_checkpoint = len(expert_uids)
+                logger.info(f"Located {total_experts_in_checkpoint} checkpoints for experts {expert_uids}")
+
+                if total_experts_in_checkpoint > num_experts:
+                    raise ValueError(
+                        f"Found {total_experts_in_checkpoint} checkpoints, but num_experts is set to {num_experts}, "
+                        f"which is smaller. Either increase num_experts or remove unneeded checkpoints."
+                    )
+            else:
+                expert_uids = []
+
+            uids_to_generate = num_experts - len(expert_uids)
+            if uids_to_generate > 0:
+                logger.info(f"Generating {uids_to_generate} expert uids from pattern {expert_pattern}")
+                expert_uids.extend(_generate_uids(uids_to_generate, expert_pattern, dht))
+
+        num_experts = len(expert_uids)
+        num_handlers = num_handlers if num_handlers is not None else num_experts * 8
+        optim_cls = optim_cls if optim_cls is not None else partial(torch.optim.SGD, lr=0.0)
+        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+
+        sample_input = name_to_input[expert_cls](DUMMY_BATCH_SIZE, hidden_dim)
+        if isinstance(sample_input, tuple):
+            args_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in sample_input)
+        else:
+            args_schema = (BatchTensorDescriptor.from_tensor(sample_input, compression),)
+
+        scheduler = schedule_name_to_scheduler[scheduler]
+
+        # initialize experts
+        experts = {}
+        for expert_uid in expert_uids:
+            expert = name_to_block[expert_cls](hidden_dim)
+            experts[expert_uid] = ExpertBackend(
+                name=expert_uid,
+                expert=expert,
+                args_schema=args_schema,
+                optimizer=optim_cls(expert.parameters()),
+                scheduler=scheduler,
+                num_warmup_steps=num_warmup_steps,
+                num_total_steps=num_total_steps,
+                clip_grad_norm=clip_grad_norm,
+                min_batch_size=min_batch_size,
+                max_batch_size=max_batch_size,
+            )
+
+        if checkpoint_dir is not None:
+            load_experts(experts, checkpoint_dir)
+
+        return cls(
+            dht,
+            experts,
+            listen_on=listen_on,
+            num_connection_handlers=num_handlers,
+            device=device,
+            checkpoint_dir=checkpoint_dir,
+            stats_report_interval=stats_report_interval,
+            start=start,
+        )
+
+    def run(self):
+        """
+        Starts Server in the current thread. Initializes dht if necessary, starts connection handlers,
+        runs Runtime (self.runtime) to process incoming requests.
+        """
+        logger.info(f"Server started at {self.listen_on}")
+        logger.info(f"Got {len(self.experts)} experts:")
+        for expert_name, backend in self.experts.items():
+            num_parameters = sum(p.numel() for p in backend.expert.parameters() if p.requires_grad)
+            logger.info(f"{expert_name}: {backend.expert.__class__.__name__}, {num_parameters} parameters")
+
+        if self.dht:
+            if not self.dht.is_alive():
+                self.dht.run_in_background(await_ready=True)
+
+            if self.experts:
+                self.dht_handler_thread.start()
+        if self.checkpoint_saver is not None:
+            self.checkpoint_saver.start()
+
+        for process in self.conn_handlers:
+            if not process.is_alive():
+                process.start()
+            process.ready.wait()
+
+        try:
+            self.runtime.run()
+        finally:
+            self.shutdown()
+
+    def run_in_background(self, await_ready=True, timeout=None):
+        """
+        Starts Server in a background thread. If await_ready, this method will wait until the background server
+        is ready to process incoming requests or for :timeout: seconds max.
+        """
+        self.start()
+        if await_ready and not self.ready.wait(timeout=timeout):
+            raise TimeoutError("Server didn't notify .ready in {timeout} seconds")
+
+    @property
+    def ready(self) -> mp.synchronize.Event:
+        """
+        An event (multiprocessing.Event) that is set when the server is ready to process requests.
+
+        Example
+        =======
+        >>> server.start()
+        >>> server.ready.wait(timeout=10)
+        >>> print("Server ready" if server.ready.is_set() else "Server didn't start in 10 seconds")
+        """
+        return self.runtime.ready  # mp.Event that is true if self is ready to process batches
+
+    def shutdown(self):
+        """
+        Gracefully terminate the server, process-safe.
+        Please note that terminating the server otherwise (e.g. by killing its processes) may result in zombie processes.
+        If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
+        """
+        self.ready.clear()
+
+        for process in self.conn_handlers:
+            process.terminate()
+            process.join()
+        logger.debug("Connection handlers terminated")
+
+        if self.dht and self.experts:
+            self.dht_handler_thread.stop.set()
+            self.dht_handler_thread.join()
+
+        if self.checkpoint_saver is not None:
+            self.checkpoint_saver.stop.set()
+            self.checkpoint_saver.join()
+
+        if self.dht is not None:
+            self.dht.shutdown()
+            self.dht.join()
+
+        logger.debug(f"Shutting down runtime")
+
+        self.runtime.shutdown()
+        logger.info("Server shutdown succesfully")
+
+
+@contextmanager
+def background_server(*args, shutdown_timeout=5, **kwargs) -> Tuple[Endpoint, List[Multiaddr]]:
+    """A context manager that creates server in a background process, awaits .ready on entry and shuts down on exit"""
+    pipe, runners_pipe = mp.Pipe(duplex=True)
+    runner = mp.Process(target=_server_runner, args=(runners_pipe, *args), kwargs=kwargs)
+    try:
+        runner.start()
+        # once the server is ready, runner will send us
+        # either (False, exception) or (True, (server.listen_on, dht_maddrs))
+        start_ok, data = pipe.recv()
+        if start_ok:
+            yield data
+            pipe.send("SHUTDOWN")  # on exit from context, send shutdown signal
+        else:
+            raise RuntimeError(f"Server failed to start: {data}")
+    finally:
+        runner.join(timeout=shutdown_timeout)
+        if runner.is_alive():
+            logger.info("Server failed to shutdown gracefully, terminating it the hard way...")
+            runner.kill()
+            logger.info("Server terminated")
+
+
+def _server_runner(pipe, *args, **kwargs):
+    try:
+        server = Server.create(*args, start=True, **kwargs)
+    except Exception as e:
+        logger.exception(f"Encountered an exception when starting a server: {e}")
+        pipe.send((False, f"{type(e).__name__} {e}"))
+        return
+
+    try:
+        dht_maddrs = server.dht.get_visible_maddrs() if server.dht is not None else None
+        pipe.send((True, (server.listen_on, dht_maddrs)))
+        pipe.recv()  # wait for shutdown signal
+
+    finally:
+        logger.info("Shutting down server...")
+        server.shutdown()
+        server.join()
+        logger.info("Server shut down")
+
+
+def _generate_uids(
+    num_experts: int, expert_pattern: Optional[str], dht: Optional[DHT] = None, attempts_per_expert=10
+) -> List[str]:
+    """
+    Sample experts from a given pattern, remove duplicates.
+    :param num_experts: sample this many unique expert uids
+    :param expert_pattern: a string pattern for expert uids, example: myprefix.[0:32].[0:256]
+     means "sample random experts between myprefix.0.0 and myprefix.31.255";
+    :param dht: if specified, uses this DHT to check that expert uids are not yet occupied by other peers
+    :param attempts_per_expert: give up if unable to generate a new expert uid after this many attempts per uid
+    :note: this method is not strictly process-safe. If several servers run it concurrently, they have
+     a small chance of sampling duplicate expert uids.
+    """
+    remaining_attempts = attempts_per_expert * num_experts
+    found_uids, attempted_uids = list(), set()
+
+    def _generate_uid():
+        if expert_pattern is None:
+            return f"expert{UID_DELIMITER}{attempts_per_expert * num_experts - remaining_attempts}"
+
+        uid = []
+        for block in expert_pattern.split(UID_DELIMITER):
+            try:
+                if "[" not in block and "]" not in block:
+                    uid.append(block)
+                elif block.startswith("[") and block.endswith("]") and ":" in block:
+                    slice_start, slice_end = map(int, block[1:-1].split(":"))
+                    uid.append(str(random.randint(slice_start, slice_end - 1)))
+                else:
+                    raise ValueError("Block must be either fixed or a range [from:to]")
+            except KeyboardInterrupt:
+                raise
+            except Exception as e:
+                raise ValueError(f"Expert pattern {expert_pattern} has invalid block {block}, {e}")
+        return UID_DELIMITER.join(uid)
+
+    while remaining_attempts > 0 and len(found_uids) < num_experts:
+
+        # 1. sample new expert uids at random
+        new_uids = []
+        while len(new_uids) + len(found_uids) < num_experts and remaining_attempts > 0:
+            new_uid = _generate_uid()
+            remaining_attempts -= 1
+            if new_uid not in attempted_uids:
+                attempted_uids.add(new_uid)
+                new_uids.append(new_uid)
+
+        # 2. look into DHT (if given) and remove duplicates
+        if dht is not None:
+            existing_expert_uids = {
+                found_expert.uid for found_expert in get_experts(dht, new_uids) if found_expert is not None
+            }
+            new_uids = [new_uid for new_uid in new_uids if new_uid not in existing_expert_uids]
+
+        found_uids += new_uids
+
+    if len(found_uids) != num_experts:
+        logger.warning(
+            f"Found only {len(found_uids)} out of {num_experts} free expert uids after "
+            f"{attempts_per_expert * num_experts} attempts"
+        )
+    return found_uids
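
To tie the new module together, a hedged usage sketch: spin up a short-lived server with a couple of "ffn" experts for local testing. The keyword names follow Server.create above; the uid pattern, hidden_dim and handler count are arbitrary example values:

from hivemind.moe.server import background_server

with background_server(
    num_experts=2,
    expert_pattern="demo.[0:256]",  # sampled uids such as demo.0 .. demo.255
    expert_cls="ffn",
    hidden_dim=512,
    num_handlers=8,
) as (server_endpoint, dht_maddrs):
    # the context yields (server.listen_on, visible DHT multiaddrs) once the server is ready
    print(f"Experts are served at {server_endpoint}")
    print(f"DHT is reachable via {dht_maddrs}")
# leaving the context sends the shutdown signal and joins the background process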

+ 3 - 0
hivemind/optim/__init__.py

@@ -1,4 +1,7 @@
 from hivemind.optim.adaptive import CollaborativeAdaptiveOptimizer
 from hivemind.optim.base import DecentralizedOptimizerBase
 from hivemind.optim.collaborative import CollaborativeOptimizer
+from hivemind.optim.grad_scaler import GradScaler, HivemindGradScaler
+from hivemind.optim.optimizer import Optimizer
 from hivemind.optim.simple import DecentralizedAdam, DecentralizedOptimizer, DecentralizedSGD
+from hivemind.optim.training_averager import TrainingAverager

+ 1 - 1
hivemind/optim/adaptive.py

@@ -2,8 +2,8 @@ from typing import Sequence
 
 import torch.optim
 
-from hivemind import TrainingAverager
 from hivemind.optim.collaborative import CollaborativeOptimizer
+from hivemind.optim.training_averager import TrainingAverager
 
 
 class CollaborativeAdaptiveOptimizer(CollaborativeOptimizer):

+ 8 - 0
hivemind/optim/base.py

@@ -1,3 +1,5 @@
+from warnings import warn
+
 import torch
 
 from hivemind.dht import DHT
@@ -8,6 +10,12 @@ class DecentralizedOptimizerBase(torch.optim.Optimizer):
 
     def __init__(self, opt: torch.optim.Optimizer, dht: DHT):
         self.opt, self.dht = opt, dht
+        warn(
+            "DecentralizedOptimizerBase and its subclasses have been deprecated and will be removed "
+            "in hivemind 1.1.0. Use hivemind.Optimizer instead",
+            FutureWarning,
+            stacklevel=2,
+        )
 
     @property
     def state(self):
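
Since the warning above points existing code at hivemind.Optimizer, here is a hedged migration sketch. The keyword names (run_id, target_batch_size, batch_size_per_step, optimizer) follow the new hivemind/optim/optimizer.py added in this commit; the model, values and synthetic data are placeholders:

import torch
import hivemind

dht = hivemind.DHT(start=True)
model = torch.nn.Linear(512, 10)

# wrap a regular torch optimizer instead of subclassing DecentralizedOptimizerBase
opt = hivemind.Optimizer(
    dht=dht,
    run_id="my_experiment",          # peers with the same run_id train together
    target_batch_size=4096,          # perform a global step once peers accumulate this many samples
    batch_size_per_step=32,          # samples contributed by each local .step() call
    optimizer=torch.optim.Adam(model.parameters()),
    verbose=True,
)

for _ in range(100):                 # placeholder loop with synthetic data
    inputs, targets = torch.randn(32, 512), torch.randint(0, 10, (32,))
    opt.zero_grad()
    loss = torch.nn.functional.cross_entropy(model(inputs), targets)
    loss.backward()
    opt.step()                       # averages with peers once the target batch size is reached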

+ 135 - 61
hivemind/optim/collaborative.py

@@ -9,13 +9,14 @@ import numpy as np
 import torch
 from pydantic import BaseModel, StrictBool, StrictFloat, confloat, conint
 
-from hivemind.averaging.training import TrainingAverager
 from hivemind.dht import DHT
 from hivemind.dht.crypto import RSASignatureValidator
 from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
 from hivemind.optim.base import DecentralizedOptimizerBase
-from hivemind.optim.performance_ema import PerformanceEMA
-from hivemind.utils import Endpoint, get_dht_time, get_logger
+from hivemind.optim.grad_scaler import HivemindGradScaler
+from hivemind.optim.training_averager import TrainingAverager
+from hivemind.utils import get_dht_time, get_logger
+from hivemind.utils.performance_ema import PerformanceEMA
 
 logger = get_logger(__name__)
 LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
@@ -56,11 +57,15 @@ class TrainingProgressSchema(BaseModel):
 
 class CollaborativeOptimizer(DecentralizedOptimizerBase):
     """
-    An optimizer that performs model updates after collaboratively accumulating a target (large) batch size across peers
+    An optimizer that performs model updates after collaboratively accumulating a target (large) batch size across peers.
 
     These optimizers use DHT to track how much progress did the collaboration make towards target batch size.
     Once enough samples were accumulated, optimizers will compute a weighted average of their statistics.
 
+    :note: **For new projects, please use hivemind.Optimizer**. CollaborativeOptimizer is an older version of that.
+      Currently, hivemind.Optimizer supports all the features of CollaborativeOptimizer and many advanced ones.
+      CollaborativeOptimizer will keep working for a while, but it is deprecated and will be removed in v1.1.0.
+
     :note: This optimizer behaves unlike regular pytorch optimizers in two ways:
 
       * calling .step will periodically zero-out gradients w.r.t. model parameters after each step
@@ -85,6 +90,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
     :param averaging_expiration: peer's requests for averaging will be valid for this many seconds
     :param metadata_expiration: peer's metadata (e.g. samples processed) is stored onto DHT for this many seconds
     :param averaging_timeout: if an averaging step hangs for this long, it will be cancelled.
+    :param load_state_timeout: wait for at most this many seconds before giving up on load_state_from_peers
     :param scheduler: if specified, use this scheduler to update optimizer learning rate
     :param reuse_grad_buffers: if True, use model's .grad buffers for gradient accumulation.
       This is more memory efficient, but it requires that the user does *NOT* call model/opt zero_grad at all
@@ -114,6 +120,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         performance_ema_alpha: float = 0.1,
         metadata_expiration: float = 60.0,
         averaging_timeout: Optional[float] = None,
+        load_state_timeout: float = 600.0,
         step_tolerance: int = 1,
         reuse_grad_buffers: bool = False,
         accumulate_grads_on: Optional[torch.device] = None,
@@ -137,19 +144,23 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             default_refresh_period,
         )
         self.expected_drift_peers, self.expected_drift_rate = expected_drift_peers, expected_drift_rate
-        self.averaging_timeout, self.metadata_expiration = averaging_timeout, metadata_expiration
+        self.averaging_timeout = averaging_timeout
+        self.load_state_timeout = load_state_timeout
+        self.metadata_expiration = metadata_expiration
         self._grads, self.reuse_grad_buffers, self.accumulate_grads_on = None, reuse_grad_buffers, accumulate_grads_on
         self.client_mode, self.step_tolerance = client_mode, step_tolerance
         self.status_loglevel = logging.INFO if verbose else logging.DEBUG
         self.averager = self._make_averager(**kwargs)
 
+        self._step_supports_amp_scaling = self.reuse_grad_buffers  # enable custom execution with torch GradScaler
+
         self.training_progress_key = f"{self.prefix}_progress"
         self.local_samples_accumulated = 0  # a number of local samples accumulated since last optimizer update
-        self.local_steps_accumulated = 0  # a number of calls to step() since last optimizer update
+        self.local_updates_accumulated = 0  # a number of calls to step() since last optimizer update
         self.performance_ema = PerformanceEMA(alpha=performance_ema_alpha)
         self.last_step_time = None
 
-        self.collaboration_state = self.fetch_collaboration_state()
+        self.collaboration_state = self._fetch_state()
         self.lock_collaboration_state, self.collaboration_state_updated = Lock(), Event()
         self.lock_local_progress, self.should_report_progress = Lock(), Event()
         self.progress_reporter = Thread(target=self.report_training_progress, daemon=True, name=f"{self}.reporter")
@@ -177,6 +188,10 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
     @property
     def is_synchronized(self) -> bool:
+        return self.local_step >= self.collaboration_state.optimizer_step
+
+    @property
+    def is_within_tolerance(self) -> bool:
         return self.local_step >= self.collaboration_state.optimizer_step - self.step_tolerance
 
     def is_alive(self) -> bool:
@@ -185,18 +200,40 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
     def load_state_from_peers(self, **kwargs):
         """Attempt to fetch the newest collaboration state from other peers"""
         with self.lock_collaboration_state:
-            self.averager.load_state_from_peers(**kwargs)
-            self.local_samples_accumulated = self.local_steps_accumulated = 0
+            while True:
+                try:
+                    self.averager.load_state_from_peers(timeout=self.load_state_timeout, **kwargs)
+                    break
+                except KeyboardInterrupt:
+                    raise
+                except BaseException as e:
+                    logger.exception(f"Failed to load state from peers: {e}, retrying ...")
+                    continue
+
+            self.local_samples_accumulated = self.local_updates_accumulated = 0
             self.reset_accumulated_grads_()
             self.update_scheduler()
 
-    def step(self, batch_size: Optional[int] = None, **kwargs):
+    def state_dict(self) -> dict:
+        state_dict = super().state_dict()
+        state_dict["state"]["collaborative_step"] = self.local_step
+        return state_dict
+
+    def load_state_dict(self, state_dict: dict):
+        if "collaborative_step" in state_dict["state"]:
+            self.averager.local_step = state_dict["state"].pop("collaborative_step")
+        return super().load_state_dict(state_dict)
+
+    def step(self, batch_size: Optional[int] = None, grad_scaler: Optional[HivemindGradScaler] = None, **kwargs):
         """
         Report accumulating gradients w.r.t. batch_size additional samples, optionally update model parameters
 
         :param batch_size: optional override for batch_size_per_step from init
+        :param grad_scaler: if amp is enabled, this **must** be a hivemind-aware gradient scaler
         :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
         """
+        if grad_scaler is not None and not isinstance(grad_scaler, HivemindGradScaler):
+            raise ValueError("CollaborativeOptimizer requires a hivemind-aware gradient scaler (HivemindGradScaler)")
         if self.batch_size_per_step is None:
             if batch_size is None:
                 raise ValueError("Please either set batch_size_per_step parameter at init or when calling .step")
@@ -204,10 +241,20 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             self.batch_size_per_step = batch_size
         batch_size = batch_size if batch_size is not None else self.batch_size_per_step
 
-        if not self.is_synchronized:
-            logger.log(self.status_loglevel, "Peer is out of sync.")
+        if not self.is_synchronized and not self.is_within_tolerance:
+            logger.log(self.status_loglevel, "Peer is out of sync")
             self.load_state_from_peers()
             return
+        elif not self.is_synchronized and self.is_within_tolerance:
+            self.averager.local_step = self.collaboration_state.optimizer_step
+            logger.log(self.status_loglevel, f"Catching up with collaboration step {self.local_step}")
+
+        if grad_scaler is not None and not grad_scaler.are_grads_finite(self):
+            logger.log(self.status_loglevel, "Encountered incorrect value in fp16 grads, resetting local gradients")
+            self.local_samples_accumulated = self.local_steps_accumulated = 0
+            self.reset_accumulated_grads_()
+            self.should_report_progress.set()
+            return
 
         if self.last_step_time is not None and get_dht_time() - self.last_step_time > self.metadata_expiration:
             logger.warning(
@@ -219,50 +266,72 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
         with self.lock_local_progress:
             self.local_samples_accumulated += batch_size
-            self.local_steps_accumulated += 1
-            self.performance_ema.update(num_processed=batch_size)
+            self.local_updates_accumulated += 1
+            self.performance_ema.update(task_size=batch_size)
             self.should_report_progress.set()
 
         if not self.collaboration_state.ready_for_step:
             return
 
-        logger.log(self.status_loglevel, f"Beginning global optimizer step {self.collaboration_state.optimizer_step}")
-        self.collaboration_state = self.fetch_collaboration_state()
-        self.collaboration_state_updated.set()
-
-        if not self.is_synchronized:
-            self.load_state_from_peers()
-            return
-
+        logger.log(self.status_loglevel, f"Beginning global optimizer step #{self.collaboration_state.optimizer_step}")
         with self.performance_ema.pause(), self.lock_collaboration_state:
+            self.collaboration_state = self._fetch_state()
+            self.collaboration_state_updated.set()
+
             # divide accumulators by local steps to recover the true average grad w.r.t. local_samples_accumulated
-            self.apply_accumulated_grads_(scale_by=1.0 / self.local_steps_accumulated)
+            self.apply_accumulated_grads_(scale_by=1.0 / self.local_updates_accumulated)
+            if grad_scaler is not None:
+                with grad_scaler.running_global_step():
+                    assert grad_scaler.unscale_(self)
+
             current_step, group_info = self.averager.local_step, None
 
             if self.collaboration_state.num_peers > 1:
                 mean_samples_per_worker = self.target_batch_size / self.collaboration_state.num_peers
                 weight = self.local_samples_accumulated / mean_samples_per_worker
                 try:
-                    group_info = self.averager.step(weight=weight, timeout=self.averaging_timeout, **kwargs)
+                    group_info = self.averager.step(
+                        weight=weight, gather=current_step, timeout=self.averaging_timeout, **kwargs
+                    )
                     if group_info:
                         logger.log(self.status_loglevel, f"Averaged tensors successfully with {len(group_info)} peers")
+
+                        # update our current step if we averaged with another peer that was at a more recent step
+                        for peer, peer_step in group_info.items():
+                            if isinstance(peer_step, int):
+                                current_step = max(current_step, peer_step)
+                            else:
+                                logger.warning(f"Peer {peer} sent malformed data about current step: {peer_step}")
+
                 except BaseException as e:
-                    logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}.")
+                    logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}")
 
             else:
                 logger.log(
                     self.status_loglevel,
-                    f"Skipped averaging: collaboration consists of " f"{self.collaboration_state.num_peers} peer(s).",
+                    f"Skipped averaging: collaboration consists of " f"{self.collaboration_state.num_peers} peer(s)",
                 )
 
-            self.opt.step()
+            if grad_scaler is not None:
+                with grad_scaler.running_global_step():
+                    assert grad_scaler.step(self)
+            else:
+                self.opt.step()
+
             self.reset_accumulated_grads_()
-            self.local_samples_accumulated = self.local_steps_accumulated = 0
+            self.local_samples_accumulated = self.local_updates_accumulated = 0
             self.collaboration_state.register_step(current_step + 1)
             self.averager.local_step = current_step + 1
             self.collaboration_state_updated.set()
             self.update_scheduler()
 
+            if grad_scaler is not None:
+                with grad_scaler.running_global_step():
+                    assert grad_scaler.update()
+
+            if not self.averager.client_mode:
+                self.averager.state_sharing_priority = self.local_step
+
         logger.log(self.status_loglevel, f"Optimizer step: done!")
 
         return group_info
@@ -277,19 +346,26 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         if not self.collaboration_state.ready_for_step:
             return
 
-        logger.log(self.status_loglevel, f"Beginning global optimizer step {self.collaboration_state.optimizer_step}")
-        self.collaboration_state = self.fetch_collaboration_state()
+        logger.log(self.status_loglevel, f"Beginning global optimizer step #{self.collaboration_state.optimizer_step}")
+        self.collaboration_state = self._fetch_state()
         self.collaboration_state_updated.set()
 
         with self.lock_collaboration_state:
-            # divide accumulators by local steps to recover the true average grad w.r.t. local_samples_accumulated
             current_step, group_info = self.averager.local_step, None
+
             try:
-                group_info = self.averager.step(timeout=self.averaging_timeout, **kwargs)
+                group_info = self.averager.step(timeout=self.averaging_timeout, gather=current_step, **kwargs)
                 if group_info:
                     logger.log(self.status_loglevel, f"Averaged tensors successfully with {len(group_info)} peers")
+
+                    # update our current step if we averaged with another peer that was at a more recent step
+                    for peer, peer_step in group_info.items():
+                        if isinstance(peer_step, int):
+                            current_step = max(current_step, peer_step)
+                        else:
+                            logger.warning(f"Peer {peer} sent malformed data about current step: {peer_step}")
             except BaseException as e:
-                logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}.")
+                logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}")
 
             self.collaboration_state.register_step(current_step + 1)
             self.averager.local_step = current_step + 1
@@ -313,38 +389,36 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         """local gradient accumulators"""
         if self.reuse_grad_buffers:
             yield from self._grad_buffers()
-        elif self._grads is None:
-            with torch.no_grad():
-                self._grads = [
-                    torch.zeros_like(grad, device=self.accumulate_grads_on) for grad in self._grad_buffers()
-                ]
+            return
+
+        if self._grads is None:
+            self._grads = [torch.zeros_like(grad, device=self.accumulate_grads_on) for grad in self._grad_buffers()]
         yield from self._grads
 
     @torch.no_grad()
     def accumulate_grads_(self, batch_size: int):
         """add current gradients to grad accumulators (if any)"""
         if self.reuse_grad_buffers:
-            return  # user is responsible for accumulating gradients in .grad buffers
-        alpha = float(batch_size) / self.batch_size_per_step
-        for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
-            grad_acc.add_(grad_buf.to(grad_acc.device), alpha=alpha)
+            # user is responsible for accumulating gradients in .grad buffers
+            assert batch_size == self.batch_size_per_step, "Custom batch size is not supported if reuse_grad_buffers"
+        else:
+            alpha = float(batch_size) / self.batch_size_per_step
+            for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
+                grad_acc.add_(grad_buf.to(grad_acc.device), alpha=alpha)
 
     @torch.no_grad()
     def apply_accumulated_grads_(self, scale_by: Optional[float] = None):
-        if self.reuse_grad_buffers:
-            return
-        for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
-            grad_buf[...] = grad_acc.to(grad_buf.device)
-            if scale_by is not None:
+        if not self.reuse_grad_buffers:
+            for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
+                grad_buf.copy_(grad_acc.to(grad_buf.device), non_blocking=True)
+        if scale_by is not None:
+            for grad_buf in self._grad_buffers():
                 grad_buf.mul_(scale_by)
 
     @torch.no_grad()
     def reset_accumulated_grads_(self):
-        if self.reuse_grad_buffers:
-            self.opt.zero_grad()
-        else:
-            for grad_buf in self.accumulated_grads():
-                grad_buf.zero_()
+        for grad_buf in self.accumulated_grads():
+            grad_buf.zero_()
 
     def report_training_progress(self):
         """Periodically publish metadata and the current number of samples accumulated towards the next step"""
@@ -381,18 +455,18 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                 continue  # if state was updated externally, reset timer
 
             with self.lock_collaboration_state:
-                self.collaboration_state = self.fetch_collaboration_state()
+                self.collaboration_state = self._fetch_state()
 
-    def fetch_collaboration_state(self) -> CollaborationState:
+    def _fetch_state(self) -> CollaborationState:
         """Read performance statistics reported by peers, estimate progress towards next batch"""
         response, _expiration = self.dht.get(self.training_progress_key, latest=True) or (None, -float("inf"))
         current_time = get_dht_time()
 
         if not isinstance(response, dict) or len(response) == 0:
             logger.log(self.status_loglevel, f"Found no active peers: {response}")
-            local_eta_next_step = (
-                max(0, self.target_batch_size - self.local_steps_accumulated) / self.performance_ema.samples_per_second
-            )
+            samples_left_to_target_batch_size = max(0, self.target_batch_size - self.local_samples_accumulated)
+            local_eta_next_step = samples_left_to_target_batch_size / self.performance_ema.samples_per_second
+
             return CollaborationState(
                 self.local_step,
                 self.local_samples_accumulated,
@@ -441,9 +515,9 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         )
         logger.log(
             self.status_loglevel,
-            f"Collaboration accumulated {total_samples_accumulated} samples from "
-            f"{num_peers} peers; ETA {estimated_time_to_next_step:.2f} seconds "
-            f"(refresh in {time_to_next_fetch:.2f}s.)",
+            f"{self.prefix} accumulated {total_samples_accumulated} samples from "
+            f"{num_peers} peers for step #{global_optimizer_step}. "
+            f"ETA {estimated_time_to_next_step:.2f} sec (refresh in {time_to_next_fetch:.2f} sec)",
         )
         return CollaborationState(
             global_optimizer_step,
@@ -478,7 +552,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             value=None,
             expiration_time=get_dht_time() + self.metadata_expiration,
         )
-        logger.debug(f"{self.__class__.__name__} is shut down.")
+        logger.debug(f"{self.__class__.__name__} is shut down")
 
     def __del__(self):
         self.shutdown()
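
To make the grad_scaler path above concrete: with mixed precision, the torch scaler is replaced by the hivemind-aware one and passed into step(); unscale_, scaler.step and scaler.update for the actual global update are invoked inside CollaborativeOptimizer (see the hunk above). A hedged CUDA sketch; the constructor arguments and synthetic data are illustrative assumptions:

import torch
import hivemind
from hivemind.optim import HivemindGradScaler

dht = hivemind.DHT(start=True)
model = torch.nn.Linear(512, 10).cuda()

opt = hivemind.CollaborativeOptimizer(
    torch.optim.Adam(model.parameters()),
    dht=dht,
    prefix="demo_run",               # used to build the "<prefix>_progress" DHT key
    target_batch_size=4096,
    batch_size_per_step=32,
    reuse_grad_buffers=True,         # accumulate directly in param.grad; do not call zero_grad manually
    verbose=True,
)
scaler = HivemindGradScaler()

for _ in range(100):                 # placeholder loop with synthetic data
    inputs = torch.randn(32, 512, device="cuda")
    targets = torch.randint(0, 10, (32,), device="cuda")
    with torch.cuda.amp.autocast():
        loss = torch.nn.functional.cross_entropy(model(inputs), targets)
    scaler.scale(loss).backward()
    # a plain torch.cuda.amp.GradScaler is rejected here; the hivemind-aware scaler is required
    opt.step(grad_scaler=scaler)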

+ 226 - 0
hivemind/optim/grad_averager.py

@@ -0,0 +1,226 @@
+import contextlib
+from typing import Iterable, Iterator, Optional
+
+import torch
+
+import hivemind
+from hivemind.averaging import DecentralizedAverager
+from hivemind.averaging.control import StepControl
+from hivemind.utils import DHTExpiration, get_dht_time, get_logger
+
+logger = get_logger(__name__)
+
+
+class GradientAverager(DecentralizedAverager):
+    """
+    An auxiliary averaging class that is responsible for accumulating gradients and aggregating them with peers.
+    GradientAverager is meant to be used within hivemind.Optimizer, but it can be used standalone (see example below).
+
+    GradientAverager manages three sets of buffers:
+    (1) model gradients - the gradients associated with local model parameters by PyTorch (param.grad).
+        These tensors are typically stored on device and updated by torch autograd
+    (2) gradient accumulators - an [optional] set of buffers where local gradients are accumulated.
+      - note: if reuse_grad_buffers is True, the averager will use gradients from parameters as local accumulators,
+        which reduces RAM usage but requires the user to avoid calling zero_grad / clip_grad manually
+    (3) averaged gradients - gradient buffers that are aggregated in-place with peers, always in host memory
+
+    :param parameters: pytorch parameters for which to aggregate gradients
+    :param dht: a DHT instance connected to the rest of the swarm. See hivemind.DHT docs
+    :param prefix: a unique DHT key used for matchmaking. E.g. this can be your experiment name with optional suffixes
+    :param reuse_grad_buffers: if True, use model's .grad buffers for accumulating gradients over multiple steps.
+      This is more memory-efficient, but it requires that the user does *not* call zero_grad or clip gradients at all
+    :param accumulate_grads_on: if specified, accumulate gradients on this device. By default, this will use the same
+      device as model parameters. One can specify a different device (e.g. 'cpu' vs 'cuda') to save device memory at
+      the cost of extra time per step. If reuse_grad_buffers is True, this parameter has no effect.
+    :param client_mode: if False, this averager will accept incoming requests from other peers.
+      if True, the averager will only join existing groups where at least one peer has client_mode=False.
+      By default, this flag is copied from DHTNode inside the ``dht`` instance.
+    :param warn: if True, warn when the averager did not reset accumulators after use or did not use averaging results
+    :param kwargs: see DecentralizedAverager keyword arguments for additional parameters
+
+
+    Example:
+
+    >>> model = SuchModelMuchLayers()
+    >>> opt = torch.optim.Adam(model.parameters())
+    >>> grad_averager = GradientAverager(model.parameters(), dht=hivemind.DHT(...))
+    >>> next_step_time = hivemind.get_dht_time() + 60   # runs global steps every 60 seconds
+    >>> next_step_control = None
+    >>> while True:
+    >>>    # accumulate as many gradients as you can before next_step_time
+    >>>    loss = compute_loss(model, batch_size=32)
+    >>>    loss.backward()
+    >>>    grad_averager.accumulate_grads_(batch_size=32)
+    >>>    # [optional] next step in 5 seconds, start looking for peers in advance
+    >>>    if next_step_time - hivemind.get_dht_time() <= 5:
+    >>>        next_step_control = grad_averager.schedule_step(scheduled_time=next_step_time)
+    >>>    # aggregate gradients and perform optimizer step
+    >>>    if hivemind.get_dht_time() >= next_step_time:
+    >>>        grad_averager.step(control=next_step_control)
+    >>>        with grad_averager.use_averaged_gradients():  # this will fill param.grads with aggregated gradients
+    >>>            opt.step()  # update model parameters using averaged gradients
+    >>>        grad_averager.reset_accumulated_grads_()  # prepare for next step
+    >>>        next_step_time = hivemind.get_dht_time() + 60
+    >>>        next_step_control = None
+
+    """
+
+    def __init__(
+        self,
+        parameters: Iterable[torch.nn.Parameter],
+        *,
+        dht: hivemind.DHT,
+        prefix: str,
+        reuse_grad_buffers: bool = False,
+        accumulate_grads_on: Optional[torch.device] = None,
+        client_mode: bool = None,
+        warn: bool = True,
+        **kwargs,
+    ):
+        if reuse_grad_buffers and accumulate_grads_on is not None:
+            logger.warning("Setting 'accumulate_grads_on' has no effect if reuse_grad_buffers=True")
+        client_mode = client_mode if client_mode is not None else dht.client_mode
+        self.parameters = tuple(parameters)
+        self.reuse_grad_buffers = reuse_grad_buffers
+        self.warn = warn
+        self.local_samples_accumulated = 0
+        self.local_times_accumulated = 0
+        self._anchor_batch_size = None
+        self._local_accumulators = None
+        if not reuse_grad_buffers:
+            self._local_accumulators = tuple(
+                torch.zeros_like(grad, device=accumulate_grads_on) for grad in self._grads_from_parameters()
+            )
+        self._accumulators_used_in_step = False
+        self._new_averaged_grads = False
+
+        with torch.no_grad():
+            averaged_grads = tuple(
+                grad.detach().cpu().clone().share_memory_() for grad in self._grads_from_parameters()
+            )
+        super().__init__(averaged_tensors=averaged_grads, dht=dht, prefix=prefix, client_mode=client_mode, **kwargs)
+
+    def _grads_from_parameters(self) -> Iterator[torch.Tensor]:
+        """gradient buffers associated with parameters"""
+        for param in self.parameters:
+            if param.grad is None:
+                param.grad = torch.zeros_like(param)
+            yield param.grad
+
+    @torch.no_grad()
+    def _grad_accumulators(self) -> Iterator[torch.Tensor]:
+        """averager-based gradient accumulators"""
+        assert (self._local_accumulators is None) == self.reuse_grad_buffers
+        yield from self._grads_from_parameters() if self.reuse_grad_buffers else self._local_accumulators
+
+    @torch.no_grad()
+    def accumulate_grads_(self, batch_size: int):
+        """add current gradients to local grad accumulators (if used)"""
+        if self._accumulators_used_in_step and self.warn:
+            logger.warning(
+                "[warn=True] Gradient accumulators were not reset since the last averaging round. Please "
+                "call .reset_accumulated_grads_ after every step or use .step(reset_accumulators=True)"
+            )
+            self._accumulators_used_in_step = False  # warn once per round
+        if self._anchor_batch_size is None:
+            # remember the first batch size to correctly re-scale gradients if subsequent batches have a different size
+            self._anchor_batch_size = batch_size
+        self.local_samples_accumulated += batch_size
+        self.local_times_accumulated += 1
+        if self.reuse_grad_buffers:
+            pass  # user is responsible for accumulating gradients in .grad buffers
+        else:
+            alpha = float(batch_size) / self._anchor_batch_size
+            for grad_buf, grad_acc in zip(self._grads_from_parameters(), self._grad_accumulators()):
+                grad_acc.add_(grad_buf.to(grad_acc.device), alpha=alpha)
+
+    def schedule_step(self, scheduled_time: Optional[DHTExpiration] = None, **kwargs) -> StepControl:
+        """
+        Begin matchmaking: look for a group of peers and prepare for averaging gradients at a specified time.
+
+        :param scheduled_time: expected time when to perform all-reduce. Can be changed using control.scheduled_time
+        :param kwargs: any additional keyword args from DecentralizedAverager.step, such as gather, allow_retries, etc
+        :note: setting weight at this stage is not supported, please leave this parameter as None
+        :returns: step_control - a handle that can be passed into GradientAverager.step to use the pre-scheduled group
+        :note: in the current implementation, each step_control can only be used in one step.
+        """
+        assert kwargs.get("weight") is None, "setting weight in schedule_step is not supported"
+        return super().step(scheduled_time=scheduled_time, wait=False, require_trigger=True, **kwargs)
+
+    def step(
+        self,
+        weight: Optional[float] = None,
+        reset_accumulators: bool = True,
+        control: Optional[StepControl] = None,
+        timeout: Optional[float] = None,
+        wait: bool = True,
+        **kwargs,
+    ):
+        """
+        Average accumulated gradients with peers, optionally load averaged gradients and reset accumulators
+
+        :param weight: overrides the averaging weight; by default, weight equals the number of accumulated samples
+        :param reset_accumulators: by default, set local gradient accumulators to zeros after averaging succeeds
+        :param control: reuse a pre-arranged group of peers (or a matchmaking in progress) from averager.schedule_step
+        :param timeout: if specified, wait for the averaging round for at most this many seconds (if wait=True)
+        :param wait: if True, wait for the step to finish (or fail); otherwise run all-reduce in the background
+        """
+        if control is None:
+            control = self.schedule_step(timeout=timeout, **kwargs)
+        elif len(kwargs) > 0:
+            raise RuntimeError(f"Averaging with a pre-scheduled group; extra parameters {kwargs} would have no effect")
+        assert not control.triggered, f"This {type(control)} instance was already used"
+        if self._new_averaged_grads and self.warn:
+            logger.warning(
+                "[warn=True] Starting new averaging round, but previous round results were not used. "
+                "This may be a sign of incorrect optimizer behavior"
+            )
+
+        self.load_accumulators_into_averager_()
+        self._accumulators_used_in_step = True
+        self._new_averaged_grads = True
+
+        control.weight = self.local_samples_accumulated if weight is None else weight
+        if reset_accumulators:
+            self.reset_accumulated_grads_()
+        control.allow_allreduce()
+
+        return control.result(timeout) if wait else control
+
+    @torch.no_grad()
+    def load_accumulators_into_averager_(self):
+        """load locally accumulated gradients into the averager for aggregation"""
+        # divide locally accumulated gradients by the number of times they were accumulated
+        grad_scale = (1.0 / self.local_times_accumulated) if self.local_times_accumulated != 0 else 0.0
+        with self.get_tensors() as averaged_grads:
+            for grad_acc, averaged_grad in zip(self._grad_accumulators(), averaged_grads):
+                averaged_grad.copy_(grad_acc, non_blocking=True).mul_(grad_scale)
+
+    @torch.no_grad()
+    def reset_accumulated_grads_(self):
+        """reset averager-internal gradient accumulators and the denominator"""
+        self._accumulators_used_in_step = False
+        self.local_samples_accumulated = self.local_times_accumulated = 0
+        self._anchor_batch_size = None
+        for grad_buf in self._grad_accumulators():
+            grad_buf.zero_()
+
+    @contextlib.contextmanager
+    @torch.no_grad()
+    def use_averaged_gradients(self):
+        """Substitute model's main gradients with averaged gradients (does not respect device placement)"""
+        self._new_averaged_grads = False
+        with self.get_tensors() as averaged_grads:
+            assert len(averaged_grads) == len(self.parameters)
+            try:
+                old_grads = [param.grad for param in self.parameters]
+                for param, new_grad in zip(self.parameters, averaged_grads):
+                    param.grad = new_grad
+                yield averaged_grads
+            finally:
+                for param, old_grad in zip(self.parameters, old_grads):
+                    param.grad = old_grad
+
+    def notify_used_averaged_gradients(self):
+        """Notify averager that the results of a previous averaging round are accounted for"""
+        self._new_averaged_grads = False

+ 125 - 0
hivemind/optim/grad_scaler.py

@@ -0,0 +1,125 @@
+import contextlib
+import threading
+from copy import deepcopy
+from typing import Dict, Optional
+
+import torch
+from torch.cuda.amp import GradScaler as TorchGradScaler
+from torch.cuda.amp.grad_scaler import OptState, _refresh_per_optimizer_state
+from torch.optim import Optimizer as TorchOptimizer
+
+import hivemind
+from hivemind.utils.logging import get_logger
+
+logger = get_logger(__name__)
+
+
+class GradScaler(TorchGradScaler):
+    """
+    A wrapper over pytorch GradScaler made specifically for training hivemind.Optimizer with reuse_grad_buffers=True.
+
+    :note: if not using reuse_grad_buffers=True, one can and *should* train normally without this class, e.g. using
+      standard PyTorch AMP or Apex. This custom GradScaler is more memory-efficient, but requires custom training code.
+
+    hivemind.GradScaler makes 3 modifications to the regular PyTorch AMP:
+
+    - bypass .unscale_ and .update calls in order to accumulate gradients over several steps
+    - limit increasing gradient scale to only immediately after global optimizer steps
+    - allow training with some or all master parameters in float16
+
+    :note: The above modifications will be enabled automatically. One can (and should) use hivemind.GradScaler exactly
+      as the regular ``torch.cuda.amp.GradScaler``.
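+
+    A minimal usage sketch (assumptions: ``opt`` is a hivemind.Optimizer created with ``reuse_grad_buffers=True``
+    and ``batch_size_per_step`` set; ``model``, ``dataloader`` and ``compute_loss`` are placeholders for user code):
+
+    >>> scaler = hivemind.GradScaler()
+    >>> for batch in dataloader:
+    >>>     with torch.cuda.amp.autocast():
+    >>>         loss = compute_loss(model, batch)
+    >>>     scaler.scale(loss).backward()
+    >>>     scaler.step(opt)    # hivemind.Optimizer runs unscaling and the actual step internally when needed
+    >>>     scaler.update()     # bypassed until a global optimizer step is performed or an overflow is detected
+    >>>     # note: do not call opt.zero_grad() - with reuse_grad_buffers=True, gradients are reset internally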
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._is_running_global_step = False
+        self._is_ready_to_update = False
+        self._inner_optimizer_states = {}
+        self._optimizer_states_to_reset = set()
+        self._lock = threading.RLock()
+
+    @contextlib.contextmanager
+    def running_global_step(self):
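+        # Marks that a global (collaborative) optimizer step is in progress: hivemind.Optimizer enters this context
+        # around its internal unscale_/step calls. Outside of this context, unscale_ and update are deferred so that
+        # gradients can be accumulated over multiple backward passes without being unscaled prematurely.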
+        with self._lock:
+            was_running, self._is_running_global_step = self._is_running_global_step, True
+            try:
+                yield
+            finally:
+                self._is_running_global_step = was_running
+
+    def unscale_(self, optimizer: TorchOptimizer) -> bool:
+        with self._lock:
+            assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
+            if self._is_running_global_step:
+                super().unscale_(optimizer)
+                self._inner_optimizer_states[id(optimizer.opt)] = deepcopy(self._per_optimizer_states[id(optimizer)])
+                # note: we store unscaled optimizer state in a separate dict and not in _per_optimizer_states in order
+                # to avoid an edge case where full DPU peer encounters overflow in local gradients while averaging
+                # offloaded gradients (i.e. after global unscale but before global step). Due to overflow, next call to
+                # .update on user side would reset *all* optimizer states and cause .step to unscale gradients twice.
+                # Offloaded optimizer is not affected by overflow in on-device gradients and should not be reset.
+                return True
+            else:
+                self._check_inf_per_device(optimizer)
+                self._optimizer_states_to_reset.add(id(optimizer))
+                return False
+
+    def step(self, optimizer: TorchOptimizer, *args, **kwargs) -> bool:
+        if self._is_running_global_step and not isinstance(optimizer, hivemind.Optimizer):
+            # ^-- invoked privately within hivemind optimizer
+            inner_optimizer = optimizer
+            with self._lock:
+                if self._is_ready_to_update:
+                    logger.warning("Please call grad_scaler.update() after each step")
+
+                inner_optimizer_state = self._inner_optimizer_states.pop(id(inner_optimizer), None)
+                if inner_optimizer_state is not None:
+                    self._per_optimizer_states[id(inner_optimizer)] = inner_optimizer_state
+                assert (
+                    self._per_optimizer_states[id(inner_optimizer)]["stage"] == OptState.UNSCALED
+                ), "InternalError: Optimizer should have called .unscale internally before invoking grad_scaler.step"
+                if self.are_grads_finite(inner_optimizer, use_cached=True):
+                    super().step(inner_optimizer, *args, **kwargs)
+                else:
+                    logger.warning("Skipping global step due to gradient over/underflow")
+                self._is_ready_to_update = True
+                return True
+        else:
+            super().step(optimizer)
+            self._optimizer_states_to_reset.add(id(optimizer))
+            return False
+
+    def update(self, new_scale: Optional[float] = None) -> bool:
+        with self._lock:
+            total_infs = 0
+            for optimizer_state in self._per_optimizer_states.values():
+                total_infs += sum(v.item() for v in optimizer_state["found_inf_per_device"].values())
+
+            if self._is_ready_to_update or total_infs != 0:
+                # note: we update either during actual optimizer step or if we need to reduce scale due to NaN
+                super().update(new_scale)
+                self._is_ready_to_update = False
+                return True
+            else:
+                for opt_id in self._optimizer_states_to_reset:
+                    self._per_optimizer_states[opt_id] = _refresh_per_optimizer_state()
+                self._optimizer_states_to_reset.clear()
+                return False
+
+    def _unscale_grads_(
+        self, optimizer: TorchOptimizer, inv_scale: torch.Tensor, found_inf: torch.Tensor, allow_fp16: bool
+    ) -> Dict[torch.device, torch.Tensor]:
+        # note: the code below sets allow_fp16=True to allow training with master weights (partially) in fp16
+        # inspired by: https://github.com/facebookresearch/fairscale/blob/945b9666/fairscale/optim/grad_scaler.py
+        return super()._unscale_grads_(optimizer, inv_scale, found_inf, allow_fp16=True)
+
+    def are_grads_finite(self, optimizer: TorchOptimizer, use_cached: bool = False) -> bool:
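+        # use_cached=True reuses the per-device inf/nan flags recorded by a previous unscale_ call instead of
+        # re-checking gradients; .step above relies on this to decide whether to skip an overflowing global step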
+        opt_dict = self._found_inf_per_device(optimizer) if use_cached else self._check_inf_per_device(optimizer)
+        return not sum(v.item() for v in opt_dict.values())
+
+
+class HivemindGradScaler(GradScaler):
+    def __init__(self, *args, **kwargs):
+        logger.warning("HivemindGradScaler was renamed to hivemind.GradScaler, this reference will be removed in v1.1")
+        super().__init__(*args, **kwargs)

+ 779 - 0
hivemind/optim/optimizer.py

@@ -0,0 +1,779 @@
+from __future__ import annotations
+
+import logging
+import os
+import time
+from functools import partial
+from typing import Callable, Optional, Sequence, Union
+
+import torch
+
+from hivemind.averaging.control import AveragingStage, StepControl
+from hivemind.compression import CompressionBase, NoCompression
+from hivemind.dht import DHT
+from hivemind.optim.grad_averager import GradientAverager
+from hivemind.optim.grad_scaler import GradScaler
+from hivemind.optim.progress_tracker import LocalTrainingProgress, ProgressTracker
+from hivemind.optim.state_averager import (
+    LRSchedulerBase,
+    OptimizerFactory,
+    Parameters,
+    ParamGroups,
+    SchedulerFactory,
+    TorchOptimizer,
+    TrainingStateAverager,
+)
+from hivemind.utils import PerformanceEMA, get_dht_time, get_logger
+
+logger = get_logger(__name__)
+
+
+class Optimizer(torch.optim.Optimizer):
+    """
+    hivemind.Optimizer wraps your regular PyTorch Optimizer for training collaboratively with peers.
+
+    By default, Optimizer is configured to be exactly **equivalent to synchronous training** with target_batch_size.
+    There are advanced options to make training semi-asynchronous (delay_optimizer_step and delay_grad_averaging)
+    or even fully asynchronous (use_local_updates=True).
+
+    :example: The Optimizer can be used as a drop-in replacement for a regular PyTorch Optimizer:
+
+    >>> model = transformers.AutoModel("albert-xxlarge-v2")
+    >>> dht = hivemind.DHT(initial_peers=INITIAL_PEERS, start=True)
+    >>> opt = hivemind.Optimizer(dht=dht, run_id="run_42", batch_size_per_step=4, target_batch_size=4096,
+    >>>                          params=model.parameters(), optimizer=lambda params: torch.optim.Adam(params))
+    >>> while True:
+    >>>     loss = compute_loss_on_batch(model, batch_size=4)
+    >>>     opt.zero_grad()
+    >>>     loss.backward()
+    >>>     opt.step()  # <-- train collaboratively with any peers that use the same prefix (run_42)
+
+    By default, peers will perform the following steps:
+
+     * accumulate a minibatch of gradients towards the (global) target batch size, without updating parameters yet;
+     * after peers collectively accumulate target_batch_size, average gradients with peers and perform optimizer step;
+     * if your peer lags behind the rest of the swarm, it will download parameters and optimizer state from others;
+
+    Unlike regular training, your device may join midway through training, when other peers already made some progress.
+    For this reason, any learning rate schedulers, curriculum and other **time-dependent features should be based on**
+    ``optimizer.local_epoch`` (and not the number of calls to opt.step). Otherwise, peers that joined training late
+    may end up having different learning rates. To do so automatically, specify ``scheduler=...`` parameter below.
+
+    :What is an epoch?: Optimizer uses the term ``epoch`` to describe intervals between synchronizations. One epoch
+      corresponds to processing a certain number of training samples (``target_batch_size``) in total across all peers.
+      Like in PyTorch LR Scheduler, **epoch does not necessarily correspond to a full pass over the training data.**
+      At the end of epoch, peers perform synchronous actions such as averaging gradients for a global optimizer update,
+      updating the learning rate scheduler or simply averaging parameters (if using local updates).
+      The purpose of this is to ensure that changing the number of peers does not require changing hyperparameters.
+      For instance, if the number of peers doubles, they will run all-reduce more frequently to adjust for faster training.
+
+    :Configuration guide: This guide will help you set up your first collaborative training run. It covers the most
+      important basic options, but ignores features that require significant changes to the training code.
+
+    >>> dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=IF_BEHIND_FIREWALL_OR_VERY_UNRELIABLE, start=True)
+    >>> opt = hivemind.Optimizer(
+    >>>    dht=dht, run_id="a_unique_name_that_every_participant_will_see_when_training",
+    >>>    batch_size_per_step=ACTUAL_BATCH_SIZE_OF_THIS_PEER, target_batch_size=LARGE_GLOBAL_BATCH,
+    >>>    # ^--- Each global optimizer step will use gradients from 1x-1.1x of target_batch_size (due to latency);
+    >>>    # It is recommended to train with very large batch sizes to reduce the % of time spent on communication.
+    >>>
+    >>>    params=params, optimizer=lambda params: AnyPyTorchOptimizer(params, **hyperparams_for_target_batch_size),
+    >>>    # tune learning rate for your target_batch_size. Here's a good reference: https://arxiv.org/abs/1904.00962
+    >>>    scheduler=lambda opt: AnyPyTorchScheduler(opt, **hyperparams_for_target_batch_size),
+    >>>    # scheduler.step will be called automatically each time when peers collectively accumulate target_batch_size
+    >>>
+    >>>    offload_optimizer=True,  # saves GPU memory, but increases RAM usage; Generally a good practice to use this.
+    >>>    delay_grad_averaging=OPTIONAL, delay_optimizer_step=OPTIONAL, # train faster, but with 1 round of staleness;
+    >>>    # setting both to True is equivalent to Delayed Parameter Updates (see https://arxiv.org/abs/2101.06840)
+    >>>
+    >>>    grad_compression=hivemind.Float16Compression(),  state_averaging_compression=hivemind.Float16Compression(),
+    >>>    # ^-- it is usually fine to use pure 16-bit or even lower precision during communication with no precaution;
+    >>>    # See hivemind/examples/albert for a working example of mixed 8/16-bit compression.
+    >>>
+    >>>    matchmaking_time=15.0, # 3-5s for small local runs, 10-15s for training over the internet or with many peers
+    >>>    averaging_timeout=60.0,  # around 2x the actual time it takes to run all-reduce
+    >>>    verbose=True  # periodically report the training progress to the console (e.g. "Averaged with N peers")
+    >>> )  # and you're done!
+
+
+    :param dht: a running hivemind.DHT instance connected to other peers.
+    :param run_id: a unique identifier of this training run, used as a common prefix for all DHT keys.
+      **Note:** peers with the same run_id should *generally* train the same model and use compatible configurations.
+      Some options can be safely changed by individual peers: ``batch_size_per_step``, ``client_mode``, ``auxiliary``,
+      ``reuse_grad_buffers``, ``offload_optimizer``, and ``verbose``. In some cases, other options may also be tuned
+      individually by each peer, but they should be changed with caution to avoid deadlocks or convergence issues.
+
+    :param target_batch_size: global batch size that must be accumulated before the swarm transitions to the next epoch.
+      The actual batch may be *slightly* larger due to asynchrony (e.g. peers submit more gradients in the last second).
+    :param batch_size_per_step: you should accumulate gradients over this many samples between calls to optimizer.step.
+
+    :param params: parameters or param groups for the optimizer; required if optimizer is a callable(params).
+    :param optimizer: a callable(parameters) -> pytorch.optim.Optimizer or a pre-initialized PyTorch optimizer.
+      **Note:** some advanced options like offload_optimizer, delay_optimizer_step, or delay_grad_averaging
+      require the callable form and will not work if hivemind.Optimizer is created with a pre-existing PyTorch Optimizer.
+    :param scheduler: callable(optimizer) -> PyTorch LRScheduler or a pre-initialized PyTorch scheduler.
+      The learning rate scheduler will adjust learning rate based on global epoch, not the number of
+      local calls to optimizer.step; this is required to keep different peers synchronized.
+
+    :param matchmaking_time: when looking for group, wait for peers to join for up to this many seconds.
+      Increase if you see "averaged gradients with N peers" where N is below 0.9x the real size on >=25% of epochs.
+      When training with low-latency network, decreasing matchmaking_time allows training with smaller batch sizes.
+    :param averaging_timeout: if an averaging step hangs for this long, it will be cancelled automatically.
+      Increase averaging_timeout if you see "Proceeding with local gradients" at least 25% of the time.
+      Do not set this timeout too high, as it may cause your optimizer to hang after some types of network errors.
+    :param allreduce_timeout: timeout for a single attempt to run all-reduce, default: equal to averaging_timeout.
+    :param load_state_timeout: wait for at most this many seconds before giving up on load_state_from_peers.
+    :param reuse_grad_buffers: if True, use model's .grad buffers for gradient accumulation.
+      This is more memory efficient, but it requires that the user does *NOT* call model/opt zero_grad at all
+
+    :param offload_optimizer: offload the optimizer to host memory, saving GPU memory for parameters and gradients
+    :param delay_optimizer_step: run optimizer in background, apply results in future .step; requires offload_optimizer
+    :param delay_grad_averaging: average gradients in background; requires offload_optimizer and delay_optimizer_step
+
+    :param delay_state_averaging: if enabled (default), average parameters and extra tensors in a background thread;
+      if set to False, average parameters synchronously within the corresponding hivemind.Optimizer.step call.
+
+    :param average_state_every: average state (parameters, chosen opt tensors) with peers every this many **epochs**.
+      Increasing this value reduces the communication overhead, but can cause parameters to diverge if set too large.
+      The maximal average_state_every=num_epochs depends on how often peers diverge from each other. If peers
+      hardly ever skip averaging rounds, they can average state less frequently. In turn, network failures, lossy
+      gradient compression and local_updates cause parameters to diverge faster and require more frequent averaging.
+
+    :param use_local_updates: if enabled, peers will update parameters on each .step using local gradients;
+      if not enabled (default), accumulate gradients to target_batch_size, and then call .step with averaged gradients.
+      Even if use_local_updates=True, learning rate scheduler will still be called once per target_batch_size.
+
+    :param client_mode: if True, this peer will not accept incoming connections (firewall-compatible mode)
+    :param auxiliary: if True, optimizer.step will only assist other peers in averaging (for cpu-only workers)
+
+    :param grad_compression: compression strategy used for averaging gradients, default = no compression
+    :param state_averaging_compression: compression for averaging params and state tensors, default = no compression
+    :param load_state_compression: compression strategy for loading state from peers, default = no compression
+    :param average_opt_statistics: names of optimizer statistics from state dict that should be averaged with peers
+    :param extra_tensors: if specified, these extra tensors will also be averaged and shared in load_state_from_peers.
+
+    :param averager_opts: additional keyword arguments forwarded to both GradientAverager and TrainingStateAverager
+    :param tracker_opts: additional keyword arguments forwarded to ProgressTracker
+    :param performance_ema_alpha: moving average alpha in ProgressTracker, TrainingStateAverager and Optimizer
+    :param verbose: if True, report internal events such as accumulating gradients and running background tasks
+
+    :note: in large-scale training, peers will inevitably fail and you will see error messages. hivemind.Optimizer
+      is designed to recover from such failures, but will sometimes need a minute or two to re-adjust.
+
+    """
+
+    def __init__(
+        self,
+        *,
+        dht: DHT,
+        run_id: str,
+        target_batch_size: int,
+        batch_size_per_step: Optional[int] = None,
+        optimizer: Union[TorchOptimizer, OptimizerFactory],
+        params: Optional[Union[Parameters, ParamGroups]] = None,
+        scheduler: Optional[Union[LRSchedulerBase, SchedulerFactory]] = None,
+        matchmaking_time: Optional[float] = 15.0,
+        averaging_timeout: Optional[float] = 60.0,
+        allreduce_timeout: Optional[float] = None,
+        next_chunk_timeout: Optional[float] = None,
+        load_state_timeout: float = 600.0,
+        reuse_grad_buffers: bool = False,
+        offload_optimizer: Optional[bool] = None,
+        delay_optimizer_step: Optional[bool] = None,
+        delay_grad_averaging: bool = False,
+        delay_state_averaging: bool = True,
+        average_state_every: int = 1,
+        use_local_updates: bool = False,
+        client_mode: bool = None,
+        auxiliary: bool = False,
+        grad_compression: CompressionBase = NoCompression(),
+        state_averaging_compression: CompressionBase = NoCompression(),
+        load_state_compression: CompressionBase = NoCompression(),
+        average_opt_statistics: Sequence[str] = (),
+        extra_tensors: Sequence[torch.Tensor] = (),
+        averager_opts: Optional[dict] = None,
+        tracker_opts: Optional[dict] = None,
+        performance_ema_alpha: float = 0.1,
+        shutdown_timeout: float = 5,
+        verbose: bool = False,
+    ):
+        self._parent_pid = os.getpid()
+
+        client_mode = client_mode if client_mode is not None else dht.client_mode
+        delay_optimizer_step = delay_optimizer_step if delay_optimizer_step is not None else delay_grad_averaging
+        offload_optimizer = offload_optimizer if offload_optimizer is not None else (params is not None)
+        allreduce_timeout = allreduce_timeout if allreduce_timeout is not None else averaging_timeout
+        next_chunk_timeout = next_chunk_timeout if next_chunk_timeout is not None else matchmaking_time
+        assert not delay_grad_averaging or delay_optimizer_step, "delay_grad_averaging requires delay_optimizer_step"
+        assert not (client_mode and auxiliary), "Client-mode peers cannot serve as auxiliaries"
+        assert not auxiliary or batch_size_per_step is None, "Auxiliary peers should not accumulate batches"
+        if callable(optimizer) and params is not None:
+            if scheduler is not None and (not callable(scheduler) or isinstance(scheduler, LRSchedulerBase)):
+                raise ValueError("For this mode, please provide scheduler factory: callable(optimizer) -> scheduler")
+        elif all(hasattr(optimizer, attr) for attr in ("param_groups", "step", "zero_grad")):
+            if offload_optimizer or delay_optimizer_step or delay_grad_averaging:
+                raise ValueError(
+                    "To enable offload_optimizer or delayed updates, please initialize Optimizer as "
+                    "hivemind.Optimizer(..., params=params, optimizer=lambda params: create_opt(params)"
+                )
+        else:
+            raise ValueError(
+                "Please initialize the optimizer in one of the following two ways:\n"
+                "(A) hivemind.Optimizer(..., params=params, optimizer=lambda params: create_opt(params)\n"
+                "(B) hivemind.Optimizer(..., optimizer=pre_initialize_optimizer)"
+            )
+        if use_local_updates:
+            assert not reuse_grad_buffers, "if local_updates is True, gradients will not be accumulated"
+            assert not delay_grad_averaging, "if local_updates is True, gradients will not be averaged"
+
+        self.dht, self.run_id, self.client_mode, self.auxiliary = dht, run_id, client_mode, auxiliary
+        self.batch_size_per_step, self.target_batch_size = batch_size_per_step, target_batch_size
+        self.delay_state_averaging, self.average_state_every = delay_state_averaging, average_state_every
+        self.matchmaking_time, self.offload_optimizer = matchmaking_time, offload_optimizer
+        self.delay_grad_averaging, self.delay_optimizer_step = delay_grad_averaging, delay_optimizer_step
+
+        self.averaging_timeout, self.allreduce_timeout = averaging_timeout, allreduce_timeout
+        self.load_state_timeout, self.shutdown_timeout = load_state_timeout, shutdown_timeout
+        self.next_chunk_timeout = next_chunk_timeout
+
+        self.status_loglevel = logging.INFO if verbose else logging.DEBUG
+        self.scheduled_grads: Optional[StepControl] = None
+        self.scheduled_state: Optional[StepControl] = None
+
+        self.tracker = self._make_progress_tracker(
+            target_batch_size, performance_ema_alpha=performance_ema_alpha, **tracker_opts or {}
+        )
+        self.state_averager = self._make_state_averager(
+            optimizer=optimizer,
+            params=params,
+            scheduler=scheduler,
+            delta_rule_averaging=use_local_updates and self.delay_state_averaging,
+            compression=state_averaging_compression,
+            state_compression=load_state_compression,
+            average_opt_statistics=average_opt_statistics,
+            performance_ema_alpha=performance_ema_alpha,
+            extra_tensors=extra_tensors,
+            **averager_opts or {},
+        )
+        if not use_local_updates:
+            self.grad_averager = self._make_gradient_averager(
+                reuse_grad_buffers=reuse_grad_buffers, compression=grad_compression, **averager_opts or {}
+            )
+        else:
+            self.grad_averager = None
+
+        self._should_check_synchronization_on_update = True  # used in self.should_load_state_from_peers
+        self._schema_hash = self._compute_schema_hash()
+
+        self.delay_before_state_averaging = PerformanceEMA(alpha=performance_ema_alpha)
+        # measures the average time from the beginning of self._update_global_epoch to the call to state_averager
+        # used for pre-scheduling the averaging round in state_averager
+
+        self._step_supports_amp_scaling = reuse_grad_buffers
+        # note: the line above is used by pytorch AMP GradScaler to enable custom behavior needed when reusing gradient
+        # buffers over multiple steps (to avoid repeated unscaling). Without reuse_grad_buffers, this is not needed.
+
+    def _make_state_averager(self, **kwargs) -> TrainingStateAverager:
+        return TrainingStateAverager(
+            dht=self.dht,
+            prefix=f"{self.run_id}_state_averager",
+            min_matchmaking_time=self.matchmaking_time,
+            allreduce_timeout=self.allreduce_timeout,
+            shutdown_timeout=self.shutdown_timeout,
+            offload_optimizer=self.offload_optimizer,
+            custom_gradients=self.offload_optimizer,
+            status_loglevel=self.status_loglevel,
+            next_chunk_timeout=self.next_chunk_timeout,
+            client_mode=self.client_mode,
+            auxiliary=self.auxiliary,
+            start=True,
+            **kwargs,
+        )
+
+    def _make_gradient_averager(self, **kwargs) -> GradientAverager:
+        assert hasattr(self, "state_averager"), "must initialize state averager first"
+        grad_averager = GradientAverager(
+            dht=self.dht,
+            prefix=f"{self.run_id}_grad_averager",
+            parameters=self.state_averager.main_parameters,
+            min_matchmaking_time=self.matchmaking_time,
+            allreduce_timeout=self.allreduce_timeout,
+            shutdown_timeout=self.shutdown_timeout,
+            next_chunk_timeout=self.next_chunk_timeout,
+            client_mode=self.client_mode,
+            auxiliary=self.auxiliary,
+            start=True,
+            **kwargs,
+        )
+        if self.offload_optimizer:
+            optimized_param_groups = self.state_averager.optimizer.param_groups
+            optimized_parameters = [param for group in optimized_param_groups for param in group["params"]]
+            with grad_averager.get_tensors() as averaged_gradients:
+                assert len(averaged_gradients) == len(optimized_parameters)
+                for opt_param, averaged_grad in zip(optimized_parameters, averaged_gradients):
+                    opt_param.grad = averaged_grad
+        return grad_averager
+
+    def _make_progress_tracker(self, target_batch_size: int, **kwargs) -> ProgressTracker:
+        return ProgressTracker(
+            dht=self.dht,
+            prefix=self.run_id,
+            target_batch_size=target_batch_size,
+            client_mode=self.client_mode,
+            status_loglevel=self.status_loglevel,
+            start=True,
+            **kwargs,
+        )
+
+    def _compute_schema_hash(self) -> int:
+        optimized_param_groups = self.state_averager.optimizer.param_groups
+        optimized_parameters = [param for group in optimized_param_groups for param in group["params"]]
+        param_shapes = tuple(tuple(param.shape) for param in optimized_parameters)
+
+        # offloaded optimizer requires that gradient tensors are reused between iterations
+        grad_ids = tuple(id(param.grad) for param in optimized_parameters) if self.offload_optimizer else None
+        return hash((grad_ids, param_shapes))
+
+    def is_alive(self) -> bool:
+        return self.state_averager.is_alive()
+
+    @property
+    def local_epoch(self) -> int:
+        """
+        This worker's current epoch, kept synchronized with peers. If peer's local_epoch lags behind others, it will
+        automatically re-synchronize by downloading state from another peer.
+        An epoch corresponds to accumulating target_batch_size across all active devices.
+        """
+        return self.state_averager.local_epoch
+
+    @property
+    def local_progress(self) -> LocalTrainingProgress:
+        return self.tracker.local_progress
+
+    @property
+    def use_local_updates(self) -> bool:
+        return self.grad_averager is None
+
+    @property
+    def use_gradient_averaging(self) -> bool:
+        return self.grad_averager is not None
+
+    def step(
+        self,
+        closure: Optional[Callable[[], torch.Tensor]] = None,
+        batch_size: Optional[int] = None,
+        grad_scaler: Optional[GradScaler] = None,
+    ):
+        """
+        Update training progress after accumulating another local batch size. Depending on the configuration, this will
+        report progress to peers, run global or local optimizer step, average parameters or schedule background tasks.
+
+        :param closure: A closure that reevaluates the model and returns the loss.
+        :param batch_size: optional override for batch_size_per_step from init.
+        :param grad_scaler: if amp is enabled, this **must** be a hivemind-aware gradient scaler.
+        :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
+        """
+        if grad_scaler is not None and not isinstance(grad_scaler, GradScaler):
+            raise ValueError("hivemind.Optimizer requires a hivemind-aware gradient scaler (hivemind.GradScaler)")
+        if self.batch_size_per_step is None and batch_size is None and not self.auxiliary:
+            raise ValueError("Please either set batch_size_per_step parameter at init or when calling .step")
+        if self.auxiliary and (closure is not None or batch_size is not None or grad_scaler is not None):
+            raise ValueError("Auxiliary peers should not have batch size, run closures, or use grad_scaler")
+        batch_size = batch_size if batch_size is not None else self.batch_size_per_step
+
+        # if delayed updates finished before step, apply these updates; otherwise do nothing
+        self.state_averager.step(apply_delayed_updates=True)
+
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        if not self.auxiliary and self._should_load_state_from_peers():
+            logger.log(self.status_loglevel, "Peer is out of sync")
+            self.load_state_from_peers()
+            return loss  # local gradients were computed with out-of-sync parameters, must start over
+
+        if self.use_gradient_averaging:
+            # accumulate gradients toward target batch size, then aggregate with peers and run optimizer
+            if not self.auxiliary:
+                grads_are_valid = self._check_and_accumulate_gradients(batch_size, grad_scaler)
+                if not grads_are_valid:
+                    return loss  # local gradients were reset due to overflow, must start over
+
+            self._maybe_schedule_gradient_averaging()
+            self._maybe_schedule_state_averaging()
+
+        else:
+            # use_local_updates=True: update parameters on every step independently of other peers
+            if not self.auxiliary:
+                if grad_scaler is not None:
+                    with grad_scaler.running_global_step():
+                        assert grad_scaler.unscale_(self)
+
+                new_samples_accumulated = self.tracker.local_progress.samples_accumulated + batch_size
+                self.tracker.report_local_progress(self.local_epoch, new_samples_accumulated)
+                self._maybe_schedule_state_averaging()
+
+                self.state_averager.step(
+                    increment_epoch=False,
+                    optimizer_step=True,
+                    delay_optimizer_step=self.delay_optimizer_step,
+                    grad_scaler=grad_scaler,
+                )
+
+        if self.tracker.ready_to_update_epoch:
+            self._update_global_epoch(grad_scaler)
+
+        return loss
+
+    def _update_global_epoch(self, grad_scaler: Optional[GradScaler]) -> None:
+        """Depending on the configuration: aggregate gradients and/or parameters, perform global optimizer step"""
+        assert self._schema_hash == self._compute_schema_hash(), "parameters or gradients changed during iteration"
+        _epoch_start_time = time.perf_counter()
+
+        with self.tracker.pause_updates():
+            wait_for_trigger = None
+
+            if self.use_gradient_averaging:
+                logger.log(self.status_loglevel, f"Beginning optimizer step #{self.local_epoch}")
+                if self.delay_optimizer_step:
+                    self.state_averager.step(wait_for_delayed_updates=True)
+
+                began_averaging_gradients = self._begin_averaging_gradients(grad_scaler)
+                if not began_averaging_gradients:
+                    # failed to start gradient averaging due to an internal error
+                    self.grad_averager.load_accumulators_into_averager_()
+                elif self.delay_grad_averaging:
+                    # if using delayed grad averaging, send this to state_averager as a pre-condition for optimizer step
+                    wait_for_trigger = partial(self._average_gradients_and_load_into_optimizer, self.scheduled_grads)
+                else:
+                    # delay_grad_averaging=False, average gradients immediately
+                    self._average_gradients_and_load_into_optimizer(self.scheduled_grads)
+
+            next_epoch = max(self.local_epoch + 1, self.tracker.global_epoch)
+            swarm_not_empty = self.tracker.global_progress.num_peers > 1
+            should_perform_optimizer_step = not self.auxiliary and not self.use_local_updates
+            should_average_state = (
+                swarm_not_empty
+                and next_epoch % self.average_state_every == 0
+                and not self.state_averager.averaging_in_progress
+            )
+
+            if should_average_state and self.scheduled_state is not None:
+                if self.scheduled_state.triggered or self.scheduled_state.done():
+                    logger.log(
+                        self.status_loglevel,
+                        f"Not using pre-scheduled group for state averaging because it"
+                        f"was already used elsewhere: {self.scheduled_state}",
+                    )
+                    self.scheduled_state = None
+                self.delay_before_state_averaging.update(task_size=1, interval=time.perf_counter() - _epoch_start_time)
+
+            self.state_averager.step(
+                increment_epoch=True,
+                wait_for_trigger=wait_for_trigger,
+                optimizer_step=should_perform_optimizer_step,
+                delay_optimizer_step=self.delay_optimizer_step and should_perform_optimizer_step,
+                grad_scaler=grad_scaler,
+                averaging_round=should_average_state,
+                delay_averaging=self.delay_state_averaging and not self.auxiliary,
+                averaging_control=self.scheduled_state if should_average_state else None,
+                averaging_opts=dict(timeout=self.averaging_timeout) if should_average_state else None,
+            )
+
+            if not should_average_state and self.scheduled_state is not None and not self.scheduled_state.done():
+                self.scheduled_state.cancel()
+            self.scheduled_state = None
+
+            self.tracker.update_epoch(new_epoch=self.state_averager.local_epoch)
+            self._should_check_synchronization_on_update = True
+            # the above line ensures that peers check for *strict* synchronization once per epoch
+
+            if not self.client_mode:
+                self.state_averager.state_sharing_priority = self.local_epoch
+
+            if self.use_gradient_averaging and not self.auxiliary:
+                self.grad_averager.reset_accumulated_grads_()
+                if not self.client_mode:
+                    self.grad_averager.state_sharing_priority = self.local_epoch
+
+            logger.log(self.status_loglevel, f"Transitioning to epoch {self.local_epoch}")
+
+    def _begin_averaging_gradients(self, grad_scaler: Optional[GradScaler]) -> bool:
+        """Begin an all-reduce round to average gradients; return True if succeeded, False if failed"""
+        if grad_scaler is not None:
+            with grad_scaler.running_global_step():
+                assert grad_scaler.unscale_(self)
+
+        began_averaging_gradients = False
+        if self.scheduled_grads is not None and (self.scheduled_grads.triggered or self.scheduled_grads.done()):
+            logger.log(
+                self.status_loglevel,
+                f"Not using pre-scheduled group for state averaging because it"
+                f"was already used elsewhere: {self.scheduled_state}",
+            )
+            self.scheduled_grads = None
+
+        elif self.tracker.global_progress.num_peers > 1:
+            try:
+                self.scheduled_grads = self.grad_averager.step(
+                    control=self.scheduled_grads, reset_accumulators=True, wait=False
+                )
+                began_averaging_gradients = True
+            except BaseException as e:
+                logger.exception(e)
+
+        if not began_averaging_gradients and self.scheduled_grads is not None and not self.scheduled_grads.done():
+            if self.tracker.global_progress.num_peers > 1:
+                logger.log(self.status_loglevel, f"Tagging along for a pre-scheduled gradient averaging round")
+                self._tag_along_with_zero_weight(self.scheduled_grads)
+            else:
+                logger.log(self.status_loglevel, f"Skipping pre-scheduled averaging round: there are no other peers")
+                self._load_local_gradients_into_optimizer()
+                self.scheduled_grads.cancel()
+            self.scheduled_grads = None
+        return began_averaging_gradients
+
+    def _check_and_accumulate_gradients(self, batch_size: int, grad_scaler: Optional[GradScaler]) -> bool:
+        """Check if gradients are valid, accumulate and return True; otherwise, reset and return False"""
+        assert not self.use_local_updates and not self.auxiliary
+        if grad_scaler is not None and not grad_scaler.are_grads_finite(self):
+            logger.log(self.status_loglevel, "Encountered incorrect value in fp16 grads, resetting local gradients")
+            self.tracker.report_local_progress(self.local_epoch, samples_accumulated=0)
+            self.grad_averager.reset_accumulated_grads_()
+            return False
+
+        self.grad_averager.accumulate_grads_(batch_size)
+        self.tracker.report_local_progress(self.local_epoch, self.grad_averager.local_samples_accumulated)
+        return True
+
+    def _maybe_schedule_gradient_averaging(self) -> None:
+        """If next epoch is coming soon, schedule the next gradient averaging round at the estimated end of epoch"""
+        assert self.use_gradient_averaging
+        if self.tracker.estimated_next_update_time - get_dht_time() <= self.matchmaking_time:
+            if self.scheduled_grads is None or self.scheduled_grads.triggered or self.scheduled_grads.done():
+                eta_seconds = self.tracker.estimated_next_update_time - get_dht_time()
+                eta_seconds = max(eta_seconds, self.grad_averager.matchmaking_kwargs["min_matchmaking_time"])
+                logger.log(self.status_loglevel, f"Pre-scheduling gradient averaging round in {eta_seconds:.2f} sec")
+                self.scheduled_grads = self.grad_averager.schedule_step(timeout=self.averaging_timeout)
+
+    def _maybe_schedule_state_averaging(self) -> None:
+        """If next epoch is coming soon, schedule the next state averaging at estimated parameter averaging start"""
+        next_epoch = max(self.local_epoch + 1, self.tracker.global_epoch)
+        if next_epoch % self.average_state_every != 0:
+            return  # averaging is not performed at this epoch
+        if self.state_averager.averaging_in_progress:
+            return  # previous run is still in progress
+        if self.delay_before_state_averaging.num_updates == 0:
+            return  # not enough data to accurately pre-schedule
+
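+        # estimate when parameter averaging will actually begin: the expected end of the current epoch plus the
+        # measured delays before this peer reaches state averaging (see delay_before_state_averaging in __init__)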
+        estimated_time = self.tracker.estimated_next_update_time
+        estimated_time += self.delay_before_state_averaging.ema_seconds_per_sample
+        estimated_time += self.state_averager.delay_before_averaging.ema_seconds_per_sample
+        eta_seconds_to_averaging = estimated_time - get_dht_time()
+
+        if eta_seconds_to_averaging <= self.matchmaking_time:
+            if self.scheduled_state is None or self.scheduled_state.triggered or self.scheduled_state.done():
+                min_matchmaking_time = self.state_averager.matchmaking_kwargs["min_matchmaking_time"]
+                actual_seconds = max(eta_seconds_to_averaging, min_matchmaking_time)
+                logger.log(self.status_loglevel, f"Pre-scheduling state averaging round in {actual_seconds:.2f} sec")
+                self.scheduled_state = self.state_averager.schedule_step(
+                    gather=next_epoch, timeout=self.averaging_timeout
+                )
+
+    def _average_gradients_and_load_into_optimizer(self, maybe_step_control: Optional[StepControl]):
+        """Run gradient averaging; on success, feed averaged gradients into optimizer; else, use local gradients"""
+        assert self.use_gradient_averaging and (maybe_step_control is None or maybe_step_control.triggered)
+        averaged_gradients = False
+
+        try:
+            if maybe_step_control is not None:
+                group_info = maybe_step_control.result(self.averaging_timeout)
+                logger.log(self.status_loglevel, f"Averaged gradients with {len(group_info)} peers")
+                self._load_averaged_gradients_into_optimizer_()
+                averaged_gradients = True
+            else:
+                logger.log(self.status_loglevel, f"Skipped averaging: there are no other peers")
+        except BaseException as e:
+            logger.log(self.status_loglevel, f"Averaging gradients failed with {repr(e)}")
+
+        if not averaged_gradients:
+            self._load_local_gradients_into_optimizer()
+
+    def _load_averaged_gradients_into_optimizer_(self):
+        """If required, load averaged gradients into optimizer; otherwise simply notify grad averager"""
+        assert self.use_gradient_averaging
+
+        if self.offload_optimizer:
+            pass  # averaged gradients are already baked into optimizer, see _make_gradient_averager
+        else:
+            # copy averaged gradients into optimizer .grad buffers
+            optimized_param_groups = self.state_averager.optimizer.param_groups
+            optimized_parameters = [param for group in optimized_param_groups for param in group["params"]]
+            with torch.no_grad(), self.grad_averager.get_tensors() as averaged_gradients:
+                assert len(averaged_gradients) == len(optimized_parameters)
+                for opt_param, averaged_grad in zip(optimized_parameters, averaged_gradients):
+                    opt_param.grad.copy_(averaged_grad, non_blocking=True)
+
+        self.grad_averager.notify_used_averaged_gradients()
+
+    def _load_local_gradients_into_optimizer(self):
+        """Fallback to using local gradients in the optimizer (instead of averaged gradients)"""
+        logger.log(self.status_loglevel, f"Proceeding with local gradients")
+        self.grad_averager.load_accumulators_into_averager_()
+        # note: we load gradients into grad_averager even though there is only one peer because of two reasons:
+        # - if offload_optimizer, then we must load gradients onto the CPU gradient buffers used by the optimizer
+        # - if not offload_optimizer, we must un-scale gradients (divide them by the number of accumulation steps)
+        self._load_averaged_gradients_into_optimizer_()
+
+    def zero_grad(self, set_to_none: bool = False):
+        """Reset gradients from model. If reuse_grad_buffers=True, this will raise an error."""
+        if self.use_gradient_averaging and self.grad_averager.reuse_grad_buffers:
+            raise ValueError(
+                f"When running {self.__class__.__name__} with reuse_grad_buffers=True, user should never "
+                f"call zero_grad manually. Gradients will be refreshed internally"
+            )
+        for param_group in self.param_groups:
+            for param in param_group["params"]:
+                if param.grad is None:
+                    pass
+                elif set_to_none:
+                    param.grad = None
+                else:
+                    param.grad.zero_()
+
+    def _should_load_state_from_peers(self) -> bool:
+        """
+        If true, peer will discard local progress and attempt to download state from peers.
+        This method allows peer to continue training in two cases:
+         - peer is on the same epoch as other collaborators - keep training normally
+         - peer was on the same epoch and accumulated some grads, but some collaborators
+             have just transitioned to the next epoch - this peer should also transition.
+
+        :note: The latter case occurs due to the lack of network synchrony: the first peer that
+        detects enough samples will transition to the next step and start counting samples anew.
+        Some other peers may take time before they check with DHT and observe that
+          - the global epoch is technically one epoch ahead of the current one and
+          - the remaining (non-transitioned) peers no longer have target_batch_size between them
+        If this is the case, the peer should transition to the next epoch but does *not* need to re-load state.
+        """
+        if self._should_check_synchronization_on_update and self.tracker.fetched_global_progress_this_epoch.is_set():
+            self._should_check_synchronization_on_update = False
+            return self.local_epoch != self.tracker.global_epoch  # require exact synchronization once per step
+        return self.local_epoch < self.tracker.global_epoch - 1  # catch up if a peer just switched to next epoch
+
+    def is_synchronized_with_peers(self) -> bool:
+        """Checks whether the current peer is up-to-date with others in terms of the epoch (step) number."""
+        return self.local_epoch >= self.tracker.global_epoch - 1
+
+    def load_state_from_peers(self, **kwargs):
+        """
+        Attempt to load the newest collaboration state from other peers within the same run_id.
+
+        If successful, this will update parameters, optimizer state, local epoch and learning rate schedule in-place.
+        """
+        # note: we tag along for the next all-reduce because the run may have already started and cancelling it
+        # will cause peers to restart matchmaking and may stall the entire collaboration for a few seconds.
+        if self.scheduled_grads is not None and not self.scheduled_grads.done():
+            self._tag_along_with_zero_weight(self.scheduled_grads)
+            self.scheduled_grads = None
+        self.state_averager.step(wait_for_delayed_updates=True)
+
+        with self.tracker.pause_updates():
+            while True:
+                try:
+                    self.state_averager.load_state_from_peers(timeout=self.load_state_timeout, **kwargs)
+                    break
+                except KeyboardInterrupt:
+                    raise
+                except BaseException as e:
+                    logger.exception(f"Failed to load state from peers: {e}, retrying ...")
+                    continue
+
+            if self.tracker.global_epoch - 1 <= self.local_epoch < self.tracker.global_epoch:
+                logger.log(self.status_loglevel, f"Catching up with collaboration step {self.tracker.global_epoch}")
+                self.state_averager.local_epoch = self.tracker.global_epoch
+
+            self.tracker.report_local_progress(local_epoch=self.local_epoch, samples_accumulated=0)
+
+            if not self.client_mode:
+                self.state_averager.state_sharing_priority = self.local_epoch
+
+            if self.use_gradient_averaging:
+                self.grad_averager.reset_accumulated_grads_()
+                if not self.client_mode:
+                    self.grad_averager.state_sharing_priority = self.local_epoch
+
+    def state_dict(self) -> dict:
+        state_dict = self.state_averager.optimizer.state_dict()
+        state_dict["state"]["local_epoch"] = self.local_epoch
+        return state_dict
+
+    def load_state_dict(self, state_dict: dict):
+        if "local_epoch" in state_dict["state"]:
+            self.state_averager.local_epoch = state_dict["state"].pop("local_epoch")
+        return self.state_averager.optimizer.load_state_dict(state_dict)
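+
+    # A minimal checkpointing sketch (illustrative only; ``opt`` and ``CHECKPOINT_PATH`` are hypothetical names):
+    #   torch.save(opt.state_dict(), CHECKPOINT_PATH)     # stores inner optimizer state together with local_epoch
+    #   opt.load_state_dict(torch.load(CHECKPOINT_PATH))  # restores both; a lagging peer will still re-sync from peers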
+
+    @property
+    def state(self):
+        return dict(self.state_averager.optimizer.state, local_epoch=self.local_epoch)
+
+    @property
+    def opt(self) -> TorchOptimizer:
+        return self.state_averager.optimizer
+
+    @property
+    def param_groups(self) -> ParamGroups:
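+        # Rebuild param groups around the user-facing (main) parameters rather than the inner optimizer's own
+        # tensors; with offload_optimizer=True the inner optimizer holds offloaded copies, and external code
+        # (e.g. gradient clipping) should see the parameters that the model is actually trained with.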
+        next_index = 0
+        param_groups = tuple(dict(param_group) for param_group in self.state_averager.optimizer.param_groups)
+        for param_group in param_groups:
+            num_params = len(param_group["params"])
+            main_params_for_group = self.state_averager.main_parameters[next_index : next_index + num_params]
+            param_group["params"] = main_params_for_group
+            next_index += num_params
+        assert next_index == len(self.state_averager.main_parameters)
+        return param_groups
+
+    def add_param_group(self, param_group: dict) -> None:
+        raise ValueError(
+            f"{self.__class__.__name__} does not support calling add_param_group after creation. "
+            f"Please provide all parameter groups at init"
+        )
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(prefix={self.run_id}, epoch={self.local_epoch})"
+
+    def _tag_along_with_zero_weight(self, control: StepControl):
+        """Wait for a running averaging round to finish with zero weight."""
+        if not control.triggered:
+            control.weight = 0
+            control.allow_allreduce()
+        if not control.done():
+            try:
+                control.result(self.averaging_timeout)
+            except BaseException as e:
+                logger.exception(e)
+                if not control.done():
+                    control.cancel()
+
+    def shutdown(self):
+        logger.log(self.status_loglevel, "Sending goodbye to peers...")
+        self.tracker.shutdown(self.shutdown_timeout)
+        self.state_averager.step(wait_for_delayed_updates=True)
+        for scheduled_round in self.scheduled_grads, self.scheduled_state:
+            if scheduled_round is not None:
+                if scheduled_round.stage == AveragingStage.LOOKING_FOR_GROUP:
+                    scheduled_round.cancel()
+                else:
+                    self._tag_along_with_zero_weight(scheduled_round)
+
+        logger.log(self.status_loglevel, "Shutting down averagers...")
+        self.state_averager.shutdown()
+        if self.use_gradient_averaging:
+            self.grad_averager.shutdown()
+        logger.log(self.status_loglevel, f"{self.__class__.__name__} is shut down")
+
+    def __del__(self):
+        if self._parent_pid == os.getpid() and self.is_alive():
+            self.shutdown()
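
A minimal, assumed pattern for shutting down cleanly rather than relying on __del__; `opt` is an already constructed optimizer and run_training is a placeholder:

try:
    run_training(opt)   # placeholder training loop using the optimizer
finally:
    opt.shutdown()      # notifies peers, completes or cancels scheduled rounds, and stops the averagers
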

+ 0 - 41
hivemind/optim/performance_ema.py

@@ -1,41 +0,0 @@
-from contextlib import contextmanager
-
-from hivemind.utils import get_dht_time
-
-
-class PerformanceEMA:
-    """
-    A running estimate of performance (operations/sec) using adjusted exponential moving average
-    :param alpha: Smoothing factor in range [0, 1], [default: 0.1].
-    """
-
-    def __init__(self, alpha: float = 0.1, eps: float = 1e-20):
-        self.alpha, self.eps, self.num_updates = alpha, eps, 0
-        self.ema_seconds_per_sample, self.samples_per_second = 0.0, eps
-        self.timestamp = get_dht_time()
-        self.paused = False
-
-    def update(self, num_processed: int) -> float:
-        """
-        :param num_processed: how many items were processed since last call
-        :returns: current estimate of performance (samples per second), but at most
-        """
-        assert not self.paused, "PerformanceEMA is currently paused"
-        assert num_processed > 0, f"Can't register processing {num_processed} samples"
-        self.timestamp, old_timestamp = get_dht_time(), self.timestamp
-        seconds_per_sample = max(0, self.timestamp - old_timestamp) / num_processed
-        self.ema_seconds_per_sample = self.alpha * seconds_per_sample + (1 - self.alpha) * self.ema_seconds_per_sample
-        self.num_updates += 1
-        adjusted_seconds_per_sample = self.ema_seconds_per_sample / (1 - (1 - self.alpha) ** self.num_updates)
-        self.samples_per_second = 1 / max(adjusted_seconds_per_sample, self.eps)
-        return self.samples_per_second
-
-    @contextmanager
-    def pause(self):
-        """While inside this context, EMA will not count the time passed towards the performance estimate"""
-        self.paused, was_paused = True, self.paused
-        try:
-            yield
-        finally:
-            self.timestamp = get_dht_time()
-            self.paused = was_paused
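
The removal above is a relocation rather than a deletion: the new progress_tracker.py below imports PerformanceEMA from hivemind.utils.performance_ema and calls it with a task_size keyword (see report_local_progress). A hedged sketch of the relocated API, inferred only from how it is used in this diff:

from hivemind.utils.performance_ema import PerformanceEMA

ema = PerformanceEMA(alpha=0.1)        # same smoothing factor as before
ema.update(task_size=32)               # 32 samples processed since the last update (old num_processed argument)
throughput = ema.samples_per_second    # adjusted EMA of samples per second
with ema.pause():                      # time spent inside this block is not counted towards the estimate
    pass
ema.reset_timer()                      # used by report_local_progress when progress was reset
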

+ 363 - 0
hivemind/optim/progress_tracker.py

@@ -0,0 +1,363 @@
+import asyncio
+import contextlib
+import logging
+import threading
+from dataclasses import dataclass
+from typing import Dict, Optional
+
+import numpy as np
+from pydantic import BaseModel, StrictBool, StrictFloat, confloat, conint
+
+from hivemind.dht import DHT
+from hivemind.dht.schema import BytesWithPublicKey, RSASignatureValidator, SchemaValidator
+from hivemind.utils import DHTExpiration, ValueWithExpiration, enter_asynchronously, get_dht_time, get_logger
+from hivemind.utils.crypto import RSAPrivateKey
+from hivemind.utils.performance_ema import PerformanceEMA
+
+logger = get_logger(__name__)
+
+
+@dataclass(frozen=False)
+class GlobalTrainingProgress:
+    epoch: int
+    samples_accumulated: int
+    target_batch_size: int
+    num_peers: int
+    num_clients: int
+    eta_next_epoch: float
+    next_fetch_time: float
+
+
+class LocalTrainingProgress(BaseModel):
+    peer_id: bytes
+    epoch: conint(ge=0, strict=True)
+    samples_accumulated: conint(ge=0, strict=True)
+    samples_per_second: confloat(ge=0.0, strict=True)
+    time: StrictFloat
+    client_mode: StrictBool
+
+
+class TrainingProgressSchema(BaseModel):
+    progress: Dict[BytesWithPublicKey, Optional[LocalTrainingProgress]]
+
+
+class ProgressTracker(threading.Thread):
+    """
+    Auxiliary class that keeps track of local & global training progress, measured in epochs.
+    An epoch can be incremented once the collaboration accumulates a predefined number of samples (target_batch_size).
+    Similarly to a PyTorch LR scheduler, an epoch can be incremented on one optimizer update or many local updates.
+
+    :param min_refresh_period: wait for at least this many seconds before fetching new collaboration state
+    :param max_refresh_period: wait for at most this many seconds before fetching new collaboration state
+    :param default_refresh_period: if no peers are detected, attempt to fetch collaboration state this often (seconds)
+    :param expected_drift_peers: assume that this many new peers can join between epochs
+    :param expected_drift_rate: assume that this fraction of the current collaboration can join/leave between epochs
+    :note: The expected collaboration drift parameters are used to adjust the frequency with which this optimizer will
+      refresh the collaboration-wide statistics (to avoid missing the moment when peers transition to the next epoch)
+    :param performance_ema_alpha: smoothing value used to estimate this peer's performance (samples per second)
+    :param metadata_expiration: a peer's metadata (e.g. samples processed) is stored in the DHT for this many seconds
+
+    Example:
+
+    >>> tracker = ProgressTracker(hivemind.DHT(...), prefix="my_experiment_with_several_peers", target_batch_size=100)
+    >>> local_epoch, local_samples = 0, 0
+    >>> while True:
+    >>>     accumulate_gradients(batch_size=32)
+    >>>     local_samples += 32
+    >>>     tracker.report_local_progress(local_epoch, local_samples)
+    >>>     if local_epoch < tracker.global_progress.epoch:
+    >>>         download_state_from_peers()  # if peer is out of sync, synchronize it with the swarm
+    >>>     if tracker.accumulated_enough_samples:
+    >>>         with tracker.pause_updates():
+    >>>             aggregate_gradients_with_peers()
+    >>>             update_model_parameters()
+    >>>             local_epoch = tracker.update_epoch(local_epoch + 1)
+    >>>             local_samples = 0
+    """
+
+    def __init__(
+        self,
+        dht: DHT,
+        prefix: str,
+        target_batch_size: int,
+        *,
+        client_mode: Optional[bool] = None,
+        min_refresh_period: float = 0.5,
+        max_refresh_period: float = 10,
+        default_refresh_period: float = 3,
+        expected_drift_peers: float = 3,
+        expected_drift_rate: float = 0.2,
+        performance_ema_alpha: float = 0.1,
+        metadata_expiration: float = 60.0,
+        status_loglevel: int = logging.DEBUG,
+        private_key: Optional[RSAPrivateKey] = None,
+        daemon: bool = True,
+        start: bool,
+    ):
+        client_mode = client_mode if client_mode is not None else dht.client_mode
+        self.dht, self.prefix, self.client_mode = dht, prefix, client_mode
+        self.training_progress_key = f"{self.prefix}_progress"
+        self.target_batch_size = target_batch_size
+        self.min_refresh_period, self.max_refresh_period = min_refresh_period, max_refresh_period
+        self.default_refresh_period = default_refresh_period
+        self.expected_drift_peers, self.expected_drift_rate = expected_drift_peers, expected_drift_rate
+        self.status_loglevel = status_loglevel
+        self.performance_ema = PerformanceEMA(alpha=performance_ema_alpha)
+        self.metadata_expiration = metadata_expiration
+
+        signature_validator = RSASignatureValidator(private_key)
+        self._local_public_key = signature_validator.local_public_key
+        dht.add_validators([SchemaValidator(TrainingProgressSchema, prefix=prefix), signature_validator])
+
+        # report the collaboration progress periodically or in background
+        self.local_progress = self._get_local_progress(local_epoch=0, samples_accumulated=0)
+        metadata, _expiration = self.dht.get(self.training_progress_key, latest=True) or (None, -float("inf"))
+        self.global_progress = self._parse_swarm_progress_data(metadata)
+        self.lock_global_progress, self.global_state_updated = threading.Lock(), threading.Event()
+        self.should_report_progress, self.fetched_global_progress_this_epoch = threading.Event(), threading.Event()
+        self.shutdown_triggered, self.shutdown_complete = threading.Event(), threading.Event()
+        super().__init__(name=f"{self.__class__.__name__}({self.prefix})", daemon=daemon)
+        if start:
+            self.start()
+
+    @property
+    def global_epoch(self) -> int:
+        return self.global_progress.epoch
+
+    @property
+    def ready_to_update_epoch(self) -> bool:
+        """Whether or not this peer can increment epoch right away."""
+        return (
+            self.global_epoch > self.local_progress.epoch
+            or self.global_progress.samples_accumulated >= self.target_batch_size
+            or get_dht_time() >= self.global_progress.eta_next_epoch
+        )
+
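
A short sketch of how the two properties above might gate a global update; everything except the tracker API is a placeholder:

import time

from hivemind.utils import get_dht_time

# placeholder loop: block until the swarm is ready to advance, then run a collective update
while not tracker.ready_to_update_epoch:
    time.sleep(max(0.0, tracker.estimated_next_update_time - get_dht_time()))
with tracker.pause_updates():
    run_global_optimizer_step()                     # hypothetical all-reduce + optimizer step
    tracker.update_epoch(tracker.global_epoch + 1)
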
+    @property
+    def estimated_next_update_time(self) -> DHTExpiration:
+        """Estimate (absolute) time when this peer should increment epoch"""
+        if self.ready_to_update_epoch:
+            return get_dht_time()
+        return self.global_progress.eta_next_epoch
+
+    def _get_local_progress(self, local_epoch: int, samples_accumulated: int):
+        return LocalTrainingProgress(
+            peer_id=self.dht.peer_id.to_bytes(),
+            epoch=local_epoch,
+            samples_accumulated=samples_accumulated,
+            samples_per_second=self.performance_ema.samples_per_second,
+            time=get_dht_time(),
+            client_mode=self.client_mode,
+        )
+
+    def report_local_progress(self, local_epoch: int, samples_accumulated: int, update_global_samples: bool = True):
+        """Update the number of locally accumulated samples and notify to other peers about this."""
+        extra_samples = samples_accumulated - self.local_progress.samples_accumulated
+        if update_global_samples and local_epoch == self.local_progress.epoch == self.global_progress.epoch:
+            self.global_progress.samples_accumulated += extra_samples
+            # note: the above line can decrease the number of samples, e.g. if forced to reset due to overflow
+
+        if extra_samples > 0:
+            self.performance_ema.update(task_size=extra_samples)
+            logger.debug(f"Updated performance EMA: {self.performance_ema.samples_per_second:.5f}")
+        else:
+            logger.debug("Resetting performance timestamp to current time (progress was reset)")
+            self.performance_ema.reset_timer()
+
+        self.local_progress = self._get_local_progress(local_epoch, samples_accumulated)
+        self.should_report_progress.set()
+
+    @contextlib.contextmanager
+    def pause_updates(self):
+        """Temporarily stop progress tracker from updating global training state"""
+        with self.lock_global_progress, self.performance_ema.pause():
+            yield
+
+    def update_epoch(self, new_epoch: Optional[int] = None) -> int:
+        """Update the local epoch, reset the number of sample accumulated, reset local progress, return new epoch"""
+        assert self.lock_global_progress.locked(), "ProgressTracker must be paused when incrementing epoch"
+        if new_epoch is None:
+            new_epoch = self.local_progress.epoch + 1
+        if new_epoch > self.global_progress.epoch:
+            self.global_progress.epoch = new_epoch
+            self.global_progress.samples_accumulated = 0
+            self.global_progress.eta_next_epoch = float("inf")
+        self.report_local_progress(new_epoch, samples_accumulated=0)
+        self.fetched_global_progress_this_epoch.clear()
+        return new_epoch
+
+    def run(self):
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        loop.run_until_complete(asyncio.gather(self._progress_reporter(), self._progress_fetcher()))
+        self.shutdown_complete.set()
+
+    async def _progress_reporter(self):
+        """Periodically publish metadata and the current number of samples accumulated towards the next epoch"""
+        last_report_time = -float("inf")
+        last_report_epoch = -float("inf")
+        store_task = None
+        try:
+            while not self.shutdown_triggered.is_set():
+                wait_timeout = max(0.0, last_report_time - get_dht_time() + self.metadata_expiration / 2)
+                logger.debug(f"Will report progress again in {wait_timeout} seconds or on user command")
+                await asyncio.get_event_loop().run_in_executor(None, self.should_report_progress.wait, wait_timeout)
+                if self.should_report_progress.is_set():
+                    logger.debug(f"Progress update triggered by report_local_progress")
+                    self.should_report_progress.clear()
+                else:
+                    logger.debug(f"Progress update triggered by metadata_expiration")
+
+                local_progress = self.local_progress
+                last_report_time = get_dht_time()
+                if local_progress.samples_accumulated > 0:
+                    last_report_epoch = self.global_epoch
+
+                if last_report_epoch >= self.global_epoch - 1:
+                    # report progress if peer is synchronized and actively reporting samples. Do not report aux peers.
+                    store_task = asyncio.create_task(
+                        asyncio.wait_for(
+                            self.dht.store(
+                                key=self.training_progress_key,
+                                subkey=self._local_public_key,
+                                value=local_progress.dict(),
+                                expiration_time=last_report_time + self.metadata_expiration,
+                                return_future=True,
+                            ),
+                            timeout=self.metadata_expiration,
+                        )
+                    )
+        finally:
+            logger.log(self.status_loglevel, f"No longer reporting progress for {self.prefix}")
+            if store_task is not None:
+                store_task.cancel()
+
+    async def _progress_fetcher(self):
+        """
+        Periodically check the training progress from all peers. Trigger update after target_batch_size total samples
+        """
+        loop = asyncio.get_event_loop()
+        shutdown_checker = asyncio.create_task(
+            asyncio.wait_for(loop.run_in_executor(None, self.shutdown_triggered.wait), None)
+        )
+
+        async def _fetch_progress_unless_shutdown_triggered():
+            """Fetch progress, avoid deadlocks if DHT was shut down before this get finished."""
+            getter = asyncio.create_task(
+                asyncio.wait_for(self.dht.get(self.training_progress_key, latest=True, return_future=True), None)
+            )
+            await asyncio.wait({getter, shutdown_checker}, return_when=asyncio.FIRST_COMPLETED)
+            if self.shutdown_triggered.is_set():
+                return
+            return await getter
+
+        try:
+            while not self.shutdown_triggered.is_set():
+                time_to_next_update = max(0.0, self.global_progress.next_fetch_time - get_dht_time())
+                state_updated_externally = await loop.run_in_executor(
+                    None, self.global_state_updated.wait, time_to_next_update
+                )
+                if state_updated_externally:
+                    self.global_state_updated.clear()
+                    continue
+
+                async with enter_asynchronously(self.lock_global_progress):
+                    maybe_metadata = await _fetch_progress_unless_shutdown_triggered()
+                    if self.shutdown_triggered.is_set():
+                        break
+                    metadata = maybe_metadata.value if isinstance(maybe_metadata, ValueWithExpiration) else None
+                    self.global_progress = self._parse_swarm_progress_data(metadata)
+                    self.fetched_global_progress_this_epoch.set()
+
+        finally:
+            logger.log(self.status_loglevel, f"No longer fetching {self.training_progress_key}")
+
+    def _parse_swarm_progress_data(self, metadata: TrainingProgressSchema) -> GlobalTrainingProgress:
+        """Read performance statistics reported by peers, estimate progress towards next batch"""
+        current_time = get_dht_time()
+
+        if not isinstance(metadata, dict) or len(metadata) == 0:
+            logger.log(self.status_loglevel, f"Found no active peers: {metadata}")
+            samples_remaining_to_next_epoch = max(0, self.target_batch_size - self.local_progress.samples_accumulated)
+            local_eta_next_epoch = samples_remaining_to_next_epoch / self.performance_ema.samples_per_second
+
+            return GlobalTrainingProgress(
+                self.local_progress.epoch,
+                self.local_progress.samples_accumulated,
+                self.target_batch_size,
+                num_peers=0,
+                num_clients=0,
+                eta_next_epoch=current_time + local_eta_next_epoch,
+                next_fetch_time=current_time + self.default_refresh_period,
+            )
+
+        valid_peer_entries = [
+            LocalTrainingProgress.parse_obj(peer_state.value)
+            for peer_state in metadata.values()
+            if peer_state.value is not None
+        ]
+
+        num_peers = len(valid_peer_entries)
+        num_clients = sum(peer.client_mode for peer in valid_peer_entries)
+
+        global_epoch = self.local_progress.epoch
+        for peer in valid_peer_entries:
+            if not peer.client_mode:
+                global_epoch = max(global_epoch, peer.epoch)
+
+        total_samples_accumulated = estimated_current_samples = 0
+        total_samples_per_second = self.performance_ema.eps
+
+        for peer in valid_peer_entries:
+            total_samples_per_second += peer.samples_per_second
+            if peer.epoch == global_epoch:
+                total_samples_accumulated += peer.samples_accumulated
+                estimated_current_samples += (
+                    peer.samples_accumulated + max(0.0, current_time - peer.time) * peer.samples_per_second
+                )
+            # note: we deliberately count samples_accumulated only for peers at the current global epoch, but count
+            # all peers for performance; the rationale is that outdated peers will synchronize and contribute shortly.
+
+        estimated_samples_remaining = self.target_batch_size - estimated_current_samples
+        estimated_time_to_next_epoch = max(0, estimated_samples_remaining) / total_samples_per_second
+
+        expected_max_peers = max(num_peers + self.expected_drift_peers, num_peers * (1 + self.expected_drift_rate))
+        time_to_next_fetch = float(
+            np.clip(
+                a=estimated_time_to_next_epoch * num_peers / expected_max_peers,
+                a_min=self.min_refresh_period,
+                a_max=self.max_refresh_period,
+            )
+        )
+        logger.log(
+            self.status_loglevel,
+            f"{self.prefix} accumulated {total_samples_accumulated} samples for epoch #{global_epoch} from "
+            f"{num_peers} peers. ETA {estimated_time_to_next_epoch:.2f} sec (refresh in {time_to_next_fetch:.2f} sec)",
+        )
+        return GlobalTrainingProgress(
+            global_epoch,
+            total_samples_accumulated,
+            target_batch_size=self.target_batch_size,
+            num_peers=num_peers,
+            num_clients=num_clients,
+            eta_next_epoch=current_time + estimated_time_to_next_epoch,
+            next_fetch_time=current_time + time_to_next_fetch,
+        )
+
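
To make the refresh-interval formula above concrete, a small worked example with made-up numbers and the default refresh bounds (0.5 s to 10 s):

import numpy as np

# made-up numbers illustrating _parse_swarm_progress_data
target_batch_size, estimated_current_samples = 100, 60
total_samples_per_second, num_peers = 10.0, 4
expected_drift_peers, expected_drift_rate = 3, 0.2
min_refresh_period, max_refresh_period = 0.5, 10

eta = max(0, target_batch_size - estimated_current_samples) / total_samples_per_second             # 4.0 s to next epoch
expected_max_peers = max(num_peers + expected_drift_peers, num_peers * (1 + expected_drift_rate))   # max(7, 4.8) = 7
time_to_next_fetch = float(np.clip(eta * num_peers / expected_max_peers, min_refresh_period, max_refresh_period))
print(eta, time_to_next_fetch)  # 4.0  ~2.29
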
+    def shutdown(self, timeout: Optional[float] = None):
+        """Permanently disable all tracking activity"""
+        self.shutdown_triggered.set()
+        self.should_report_progress.set()
+        self.global_state_updated.set()
+        self.shutdown_complete.wait(timeout)
+        self.dht.store(
+            self.training_progress_key,
+            subkey=self._local_public_key,
+            value=None,
+            expiration_time=get_dht_time() + self.metadata_expiration,
+            return_future=True,
+        )
+
+    def __del__(self):
+        if self.is_alive():
+            self.shutdown()

+ 9 - 5
hivemind/optim/simple.py

@@ -4,9 +4,9 @@ from typing import Optional, Sequence, Tuple
 
 import torch
 
-from hivemind.averaging import TrainingAverager
 from hivemind.dht import DHT
 from hivemind.optim.base import DecentralizedOptimizerBase
+from hivemind.optim.training_averager import TrainingAverager
 from hivemind.utils import get_dht_time, get_logger
 
 logger = get_logger(__name__)
@@ -86,6 +86,10 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
             if self.local_step % self.averaging_step_period == 0:
                 self.update_event.set()
             self.averager.pending_updates_done.wait()
+
+            if not self.averager.client_mode:
+                self.averager.state_sharing_priority = get_dht_time()
+
             return loss
         finally:
             self.lock_parameters.acquire()
@@ -127,16 +131,16 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
                 time.sleep(time_to_nearest_interval)
 
             if verbose:
-                logger.info(f"Starting a new averaging round with current parameters.")
+                logger.info(f"Starting a new averaging round with current parameters")
             try:
                 group_info = averager.step(lock_parameters, **kwargs)
                 if verbose:
                     if group_info is not None:
-                        logger.info(f"Finished averaging round in with {len(group_info)} peers.")
+                        logger.info(f"Finished averaging round in with {len(group_info)} peers")
                     else:
-                        logger.warning(f"Averaging round failed: could not find group.")
+                        logger.warning(f"Averaging round failed: could not find group")
             except Exception as e:
-                logger.error(f"Averaging round failed: caught {e}.")
+                logger.error(f"Averaging round failed: caught {e}")
 
 
 class DecentralizedSGD(DecentralizedOptimizer):

+ 723 - 0
hivemind/optim/state_averager.py

@@ -0,0 +1,723 @@
+""" An extension of averager that supports common optimization use cases. """
+import logging
+import threading
+import time
+from concurrent.futures import ThreadPoolExecutor
+from contextlib import nullcontext
+from itertools import chain
+from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Sequence, Tuple, Union
+
+import torch
+
+import hivemind
+from hivemind.averaging import DecentralizedAverager
+from hivemind.averaging.control import StepControl
+from hivemind.compression import CompressionInfo, TensorRole
+from hivemind.optim.grad_scaler import GradScaler
+from hivemind.utils import DHTExpiration, PerformanceEMA, get_dht_time, get_logger, nested_flatten, nested_pack
+
+logger = get_logger(__name__)
+
+
+Parameters = Iterable[torch.Tensor]
+ParamGroups = Iterable[Dict[str, Any]]
+TorchOptimizer = torch.optim.Optimizer
+LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
+OptimizerFactory = Callable[[Union[Parameters, ParamGroups]], TorchOptimizer]
+SchedulerFactory = Callable[[TorchOptimizer], LRSchedulerBase]
+
+
+class TrainingStateAverager(DecentralizedAverager):
+    """
+    An auxiliary class that holds a peer's training state, including model parameters, optimizer statistics, scheduler
+    and any other variables that define the local training state (e.g. batchnorm moving averages).
+    TrainingStateAverager is intended to keep these parameters weakly synchronized across the swarm.
+
+    The intended use is to call .step(optimizer_step=..., averaging_round=...) periodically, e.g. after every batch.
+    If a peer gets out of sync with the swarm, call state_averager.load_state_from_peers() to re-synchronize.
+
+    Example:
+
+    >>> avgr = TrainingStateAverager(optimizer=torch.optim.Adam, params=model.parameters(), ...)
+    >>> # alternative interface: TrainingStateAverager(optimizer=torch.optim.Adam(model.parameters()), ...)
+    >>> avgr.load_state_from_peers()
+    >>> for i, batch in enumerate(training_dataloader):
+    >>>     loss = compute_loss(model, batch)
+    >>>     loss.backward()
+    >>>     avgr.step(optimizer_step=i % 10 == 0, averaging_round=is_it_time_for_averaging(), delay_averaging=True)
+
+    :note: when using delay_averaging or delay_optimizer_step, calling optimizer directly is not recommended because
+      it may overlap with delayed updates from a background thread with unpredictable results. Instead, please call
+      TrainingStateAverager.step(..., optimizer_step=True)
+
+    :param optimizer: PyTorch Optimizer or a callable that creates an optimizer from param groups
+    :param params: optional, a list/tuple of parameters or structured param groups for the optimizer
+    :param scheduler: optional learning rate scheduler or callable that creates one from optimizer instance
+    :note: if provided, scheduler will be updated based on averager.local_epoch, not the number of step cycles
+    :param initialize_optimizer: if True, run a speculative optimizer step with zero gradients to initialize all
+      state tensors. If False, user must make sure that all tensors are pre-initialized at init.
+      By default, initialize optimizer unless it already has some state tensors to begin with.
+    :param offload_optimizer: if True, create optimizer on top of averaged parameters which may save device memory.
+    :param custom_gradients: if True, do *not* automatically load local gradients into the offloaded optimizer.
+      This assumes that offloaded gradients will be populated externally, e.g. by the user or by hivemind.Optimizer.
+    :param reuse_tensors: if True, reuse parameters and optimizer statistics as averaged_tensors for allreduce.
+      For this to work, all parameters must be on CPU and have the appropriate dtype for use in DecentralizedAverager.
+      Defaults to True if offload_optimizer, False otherwise.
+    :param delta_rule_averaging: if True, averaging will use delta rule to allow running local optimizer steps
+      while averaging. Delta rule: `state_tensor := state_tensor + averaging_result - state_tensor_before_averaging`
+    :param sync_epoch_when_averaging: if True, update local epoch to the latest epoch among averaging peers
+    :param parameter_names: optionally provide parameter names in the same order as in params
+    :param average_opt_statistics: names of optimizer statistics from state dict that should be averaged with peers
+    :param extra_tensors: if specified, these extra tensors will also be averaged and shared in load_state_from_peers.
+    :note: you can use extra_tensors for any tensors not used by the optimizer (e.g. batchnorm statistics)
+    :param kwargs: any additional parameters will be forwarded to DecentralizedAverager
+    """
+
+    def __init__(
+        self,
+        *,
+        dht: hivemind.DHT,
+        optimizer: Union[TorchOptimizer, OptimizerFactory],
+        params: Optional[Union[Parameters, ParamGroups]] = None,
+        scheduler: Optional[Union[LRSchedulerBase, SchedulerFactory]] = None,
+        initialize_optimizer: Optional[bool] = None,
+        offload_optimizer: bool = False,
+        custom_gradients: bool = False,
+        reuse_tensors: Optional[bool] = None,
+        delta_rule_averaging: bool = False,
+        performance_ema_alpha: float = 0.1,
+        sync_epoch_when_averaging: bool = False,
+        parameter_names: Optional[Sequence[str]] = None,
+        average_opt_statistics: Sequence[str] = (),
+        extra_tensors: Sequence[torch.Tensor] = (),
+        status_loglevel: int = logging.DEBUG,
+        **kwargs,
+    ):
+        average_opt_statistics = tuple(average_opt_statistics)
+        assert all(isinstance(key, str) for key in average_opt_statistics)
+        if reuse_tensors is None:
+            reuse_tensors = offload_optimizer and not delta_rule_averaging
+        if custom_gradients and not offload_optimizer:
+            logger.warning("Setting custom_gradients=True has no effect because the optimizer is not offloaded")
+        if reuse_tensors and delta_rule_averaging:
+            raise ValueError("reuse_tensors and delta_rule_averaging are mutually exclusive")
+
+        param_groups, main_parameters, parameter_names = self._check_params(optimizer, params, parameter_names)
+
+        self.status_loglevel = status_loglevel
+        self.offload_optimizer, self.custom_gradients = offload_optimizer, custom_gradients
+        self.reuse_tensors, self.delta_rule_averaging = reuse_tensors, delta_rule_averaging
+        self._old_tensors: Optional[Sequence[torch.Tensor]] = None  # for delta rule
+
+        self.main_parameters, self.parameter_names = main_parameters, parameter_names
+        self._averaged_parameters = self._make_averaged_parameters(main_parameters)
+        self.optimizer, self.scheduler = self._init_components(
+            param_groups, optimizer, scheduler, initialize_optimizer
+        )
+        self.opt_keys_for_averaging, self.extra_tensors = average_opt_statistics, extra_tensors
+        self.sync_epoch_when_averaging = sync_epoch_when_averaging
+        self.local_epoch = 0
+
+        self.delay_before_averaging = PerformanceEMA(alpha=performance_ema_alpha)
+        self.step_executor = ThreadPoolExecutor(max_workers=2 if self.delta_rule_averaging else 1)
+        self.finished_optimizer_step = threading.Event()
+        self.finished_averaging_round = threading.Event()
+        self.lock_optimizer = threading.Lock()
+        self.lock_averaging = threading.Lock()
+        self.pending_updates = set()
+
+        super().__init__(
+            dht=dht, averaged_tensors=self._init_averaged_tensors(), tensor_infos=self._init_tensor_infos(), **kwargs
+        )
+
+    @staticmethod
+    def _check_params(
+        optimizer: Union[TorchOptimizer, OptimizerFactory],
+        param_groups: Optional[Union[Parameters, ParamGroups]],
+        parameter_names: Optional[Sequence[str]],
+    ) -> Tuple[ParamGroups, Sequence[torch.Tensor], Sequence[str]]:
+        """Get and verify parameters, groups and names"""
+        if param_groups is None:
+            assert hasattr(optimizer, "param_groups"), "Must provide param_groups or an optimizer with .param_groups"
+            param_groups = optimizer.param_groups
+        param_groups = tuple(param_groups)
+        if all(isinstance(p, torch.Tensor) for p in param_groups):
+            param_groups = (dict(params=param_groups),)
+        for group in param_groups:
+            assert isinstance(group, dict) and group.get("params") is not None
+            assert all(isinstance(p, torch.Tensor) for p in group["params"])
+        parameters = tuple(chain(*(group["params"] for group in param_groups)))
+        if parameter_names is None:
+            parameter_names = tuple(i for i in range(len(parameters)))
+        parameter_names = tuple(nested_flatten(parameter_names))
+        assert len(parameters) == len(parameter_names), f"Expected {len(parameters)} names, got {len(parameter_names)}"
+        assert len(set(parameters)) == len(parameters), "Found duplicate parameters in param_groups"
+        params_with_grad = sum(p.numel() for p in parameters if p.requires_grad)
+        params_no_grad = sum(p.numel() for p in parameters if not p.requires_grad)
+        if params_no_grad >= params_with_grad:
+            logger.warning(
+                "The majority of parameters have requires_grad=False, but they are still synchronized"
+                " with peers. If these parameters are frozen (not updated), please do not feed them into "
+                "the optimizer at all in order to avoid communication overhead. Proceeding anyway."
+            )
+
+        return param_groups, parameters, parameter_names
+
+    def _make_averaged_parameters(self, main_parameters: Sequence[torch.Tensor]):
+        """Initialize averaged parameters based on the optimizer and averaging mode"""
+        return tuple(self._make_host_tensor(param, force_copy=self.offload_optimizer) for param in main_parameters)
+
+    def _make_host_tensor(self, source_tensor: torch.Tensor, force_copy: bool = False) -> torch.Tensor:
+        """Create a new tensor for averaging or reuse the existing one"""
+        if self.reuse_tensors and not force_copy:
+            if source_tensor.device != torch.device("cpu"):
+                raise ValueError("reuse_tensors is only supported if all averaged tensors are on CPU")
+            if not source_tensor.is_shared():
+                source_tensor.share_memory_()
+            return source_tensor
+        else:
+            averaged_tensor = source_tensor.detach().to(device="cpu", dtype=torch.float32, copy=True)
+            return averaged_tensor.share_memory_().requires_grad_(source_tensor.requires_grad)
+
+    def _init_components(
+        self,
+        param_groups: ParamGroups,
+        optimizer_or_factory: Union[TorchOptimizer, OptimizerFactory],
+        scheduler_or_factory: Optional[Union[LRSchedulerBase, SchedulerFactory]],
+        initialize_optimizer: Optional[bool],
+    ) -> Tuple[TorchOptimizer, Optional[LRSchedulerBase]]:
+        """Get optimizer and scheduler by either instantiating user-provided factory or using pre-instantiated ones"""
+        assert hasattr(self, "_averaged_parameters"), "Internal error: must initialize averaged parameters first"
+        optimizer_is_factory = callable(optimizer_or_factory) and not isinstance(optimizer_or_factory, TorchOptimizer)
+        scheduler_is_factory = callable(scheduler_or_factory) and not isinstance(scheduler_or_factory, LRSchedulerBase)
+        if optimizer_is_factory and not scheduler_is_factory and scheduler_or_factory is not None:
+            raise ValueError("If optimizer is created internally, scheduler must also be initialized internally")
+        if self.offload_optimizer and not optimizer_is_factory:
+            raise ValueError("Using offload_optimizer requires creating optimizer inside hivemind")
+
+        # create optimizer
+        if optimizer_is_factory:
+            if self.offload_optimizer:
+                if self.reuse_tensors:
+                    parameters_for_optimizer = self._averaged_parameters
+                else:
+                    parameters_for_optimizer = tuple(
+                        tensor.detach().clone().requires_grad_(tensor.requires_grad)
+                        for tensor in self._averaged_parameters
+                    )
+
+                next_index = 0
+                param_groups_for_optimizer = []
+                for param_group in param_groups:
+                    num_params = len(param_group["params"])
+                    averaged_params_for_group = parameters_for_optimizer[next_index : next_index + num_params]
+                    param_groups_for_optimizer.append(dict(param_group, params=averaged_params_for_group))
+                    next_index += num_params
+                assert next_index == len(parameters_for_optimizer)
+
+                for param in parameters_for_optimizer:
+                    if param.grad is None:
+                        param.grad = torch.zeros_like(param)
+            else:
+                param_groups_for_optimizer = param_groups
+            optimizer = optimizer_or_factory(param_groups_for_optimizer)
+        else:
+            optimizer = optimizer_or_factory
+
+        # optionally initialize optimizer state dict
+        if initialize_optimizer is None:
+            initialize_optimizer = not any(isinstance(x, torch.Tensor) for x in nested_flatten(optimizer.state_dict()))
+            logger.log(
+                self.status_loglevel,
+                "Initializing optimizer manually since it has no tensors in state dict. "
+                "To override this, provide initialize_optimizer=False",
+            )
+
+        if initialize_optimizer:
+            initialize_optimizer_state_(optimizer)  # note: this will run one optimizer step!
+
+        # create LR scheduler
+        if scheduler_is_factory:
+            assert callable(scheduler_or_factory)
+            scheduler = scheduler_or_factory(optimizer)
+        else:
+            scheduler = scheduler_or_factory
+
+        # verify optimizer and scheduler
+        assert isinstance(optimizer, TorchOptimizer) and len(optimizer.param_groups) == len(list(param_groups))
+        if self.reuse_tensors:
+            for param_group in optimizer.param_groups:
+                for param in param_group["params"]:
+                    assert param.is_shared()
+        assert isinstance(scheduler, (LRSchedulerBase, type(None)))
+        if scheduler is not None:
+            assert scheduler.optimizer == optimizer
+        return optimizer, scheduler
+
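
A hedged construction sketch illustrating the factory path checked above (offload_optimizer requires passing callables rather than ready-made instances); `dht` and `model` are assumed to exist, and the prefix/start keywords are assumed to be forwarded to DecentralizedAverager:

from functools import partial

import torch

state_averager = TrainingStateAverager(
    dht=dht,                                       # an existing hivemind.DHT instance (assumed)
    optimizer=partial(torch.optim.Adam, lr=1e-3),  # factory: the optimizer is created inside hivemind
    params=model.parameters(),
    offload_optimizer=True,                        # only allowed with the factory form, as asserted above
    scheduler=partial(torch.optim.lr_scheduler.LambdaLR, lr_lambda=lambda epoch: 1.0),  # also a factory
    prefix="my_experiment_state_averager",         # assumed DecentralizedAverager keyword
    start=True,                                    # assumed DecentralizedAverager keyword
)
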
+    def _local_tensors(self) -> Iterator[torch.Tensor]:
+        """Iterate local trainer's tensors that should be averaged with peers"""
+        for param_group in self.optimizer.param_groups:
+            yield from param_group["params"]
+        for stats in self.opt_keys_for_averaging:
+            for param_group in self.optimizer.param_groups:
+                for param in param_group["params"]:
+                    yield self.optimizer.state[param][stats]
+        yield from self.extra_tensors
+
+    @torch.no_grad()
+    def _init_averaged_tensors(self) -> Sequence[torch.Tensor]:
+        """Create or reuse a tuple of all averaged tensors, including parameters, optimizer statistics and extras"""
+        assert hasattr(self, "optimizer"), "Optimizer should already be initialized by this point"
+        assert hasattr(self, "_averaged_parameters"), "Should initialize _averaged_parameters first"
+        assert not hasattr(self, "_averaged_tensors"), "Averager is already initialized"
+        assert all(isinstance(key, str) for key in self.opt_keys_for_averaging)
+
+        local_tensors = tuple(self._local_tensors())
+        local_non_parameters = local_tensors[len(self._averaged_parameters) :]
+        averaged_tensors = tuple(map(torch.Tensor.detach, self._averaged_parameters))
+        averaged_non_parameters = tuple(map(self._make_host_tensor, local_non_parameters))
+        averaged_tensors = tuple(chain(averaged_tensors, averaged_non_parameters))
+
+        assert len(averaged_tensors) == len(local_tensors)
+        for local_tensor, averaged_tensor in zip(local_tensors, averaged_tensors):
+            assert local_tensor.shape == averaged_tensor.shape
+            if averaged_tensor.grad is not None:
+                logger.log(self.status_loglevel, "setting gradients for averaged tensor to None")
+
+        return averaged_tensors
+
+    def _init_tensor_infos(self) -> Sequence[CompressionInfo]:
+        """Get CompressionInfo for each state tensor, accounting for its role and specification"""
+        tensor_infos = []
+        for param, param_name in zip(self.main_parameters, self.parameter_names):
+            tensor_infos.append(CompressionInfo.from_tensor(param, key=param_name, role=TensorRole.PARAMETER))
+        for stats_name in self.opt_keys_for_averaging:
+            opt_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
+            assert len(opt_parameters) == len(self.parameter_names)
+            for param, param_name in zip(opt_parameters, self.parameter_names):
+                tensor_infos.append(
+                    CompressionInfo.from_tensor(
+                        self.optimizer.state[param][stats_name],
+                        key=(param_name, stats_name),
+                        role=TensorRole.OPTIMIZER,
+                    )
+                )
+        for i, extra_tensor in enumerate(self.extra_tensors):
+            tensor_infos.append(CompressionInfo.from_tensor(extra_tensor, key=i, role=TensorRole.UNSPECIFIED))
+        return tuple(tensor_infos)
+
+    def schedule_step(self, scheduled_time: Optional[DHTExpiration] = None, **kwargs) -> StepControl:
+        """
+        Begin matchmaking: look for a group of peers and prepare for averaging gradients at a specified time.
+
+        :param scheduled_time: expected time when to perform all-reduce. Can be changed using control.scheduled_time
+        :param kwargs: any additional keyword args from DecentralizedAverager.step, such as gather, allow_retries, etc
+        :note: setting weight at this stage is not supported, please leave this parameter as None
+        :returns: step_control - a handle that can be passed into TrainingStateAverager.step to use pre-scheduled group
+        :note: in the current implementation, each step_control can only be used in one step.
+        """
+        assert kwargs.get("weight") is None, "setting weight in schedule_step is not supported"
+        return super().step(scheduled_time=scheduled_time, wait=False, require_trigger=True, **kwargs)
+
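
A hedged sketch of using a pre-scheduled group together with the step() method defined below; timing values are arbitrary:

from hivemind.utils import get_dht_time

control = state_averager.schedule_step(scheduled_time=get_dht_time() + 30)  # begin matchmaking ~30 s ahead
# ... keep training locally while matchmaking runs in the background ...
state_averager.step(optimizer_step=True, zero_grad=True, averaging_round=True, averaging_control=control)
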
+    def step(
+        self,
+        wait_for_delayed_updates: bool = None,
+        apply_delayed_updates: bool = True,
+        increment_epoch: bool = False,
+        optimizer_step: bool = False,
+        zero_grad: bool = False,
+        delay_optimizer_step: bool = False,
+        averaging_round: bool = False,
+        delay_averaging: Optional[bool] = None,
+        averaging_control: Optional[StepControl] = None,
+        wait_for_trigger: Optional[Callable[[], Any]] = None,
+        grad_scaler: Optional[GradScaler] = None,
+        averaging_opts: Optional[Dict[str, Any]] = None,
+    ):
+        """
+        Perform one or several possible actions, depending on the specified keyword args.
+        The actions will be performed in the same order as specified below:
+
+        :param wait_for_delayed_updates: if there are background averaging rounds, wait for them to finish;
+          by default, wait only if this call also performs an averaging round (or, with delta_rule_averaging,
+          an optimizer step or zero_grad), otherwise do not wait
+        :param apply_delayed_updates: apply any averaging rounds that have finished but were not applied yet
+        :param increment_epoch: increment .local_epoch and update the learning rate scheduler (if present)
+        :note: if specified, it is guaranteed that epoch is incremented immediately regardless of other options
+        :param optimizer_step: perform a single optimizer step and update local parameters (without changing scheduler)
+        :param zero_grad: if True, reset local gradients after performing optimizer step
+        :param delay_optimizer_step: if True, run optimizer step in background and apply results in a future step
+        :param averaging_round: average parameters, chosen optimizer keys and extra tensors with a group of peers
+        :param delay_averaging: if True, perform averaging in background and apply results in a future step
+          by default, delay averaging if the optimizer step is also delayed. Set to true to delay only this phase.
+        :param averaging_control: if specified, use this as a pre-scheduled averaging round. Should require_trigger.
+        :param wait_for_trigger: wait for this (non-asyncio) function to finish before running optimizer step
+        :note: if wait_for_trigger fails with any exception, it will abort optimizer step, zero grad and averaging
+        :param grad_scaler: when using hivemind.GradScaler, one must forward it to step after calling .unscale_
+        :param averaging_opts: a dict of keyword arguments forwarded into averaging round
+        """
+        if delay_averaging is None:
+            delay_averaging = delay_optimizer_step
+        should_wait = averaging_round or optimizer_step or zero_grad if self.delta_rule_averaging else averaging_round
+        if wait_for_delayed_updates is None:
+            wait_for_delayed_updates = should_wait
+        if should_wait and not (wait_for_delayed_updates and apply_delayed_updates):
+            raise ValueError("Should wait for background operation to finish before scheduling new one")
+        assert not delay_optimizer_step or delay_averaging, "Delayed optimizer step requires delayed averaging"
+        if delay_optimizer_step:
+            assert self.offload_optimizer, "Delayed optimizer step is only available with offload_optimizer"
+            assert not averaging_round or delay_averaging, "Averaging after delayed optimizer should also be delayed"
+        if averaging_opts and not averaging_round:
+            logger.warning(f"Averaging parameters not used because averaging_round=False: {averaging_opts}")
+        if averaging_control is not None:
+            assert averaging_round, "averaging_control is unused if averaging_round is not performed"
+        if wait_for_trigger is not None:
+            assert optimizer_step or zero_grad or averaging_round, "trigger is only used for updating parameters"
+            if not (self.reuse_tensors or self.custom_gradients):
+                # averager was asked to wait_for_trigger in background, but it is not clear which version of gradients
+                # should be used for optimizer step (e.g. the gradients that were present during the call to .step or
+                # the possibly different gradients when wait_for_trigger has finished).
+                raise ValueError(
+                    "wait_for_trigger is a low-level option that requires manual gradient manipulation. "
+                    "If you know what you're doing, please refer to the comments in the source code for details"
+                )
+        output = None
+
+        if wait_for_delayed_updates:
+            for pending_update in self.pending_updates:
+                try:
+                    timeout = (averaging_opts or {}).get("averaging_timeout", self._allreduce_timeout)
+                    logger.log(self.status_loglevel, "Waiting for delayed updates to finish...")
+                    output = pending_update.result(timeout)
+                except BaseException:
+                    # exception will be reported below
+                    if not pending_update.done():
+                        pending_update.cancel()
+
+        # remove finished updates, log any exceptions
+        finished_updates = {pending_update for pending_update in self.pending_updates if pending_update.done()}
+        self.pending_updates = {pending_update for pending_update in self.pending_updates if not pending_update.done()}
+        for finished_update in finished_updates:
+            if finished_update.cancelled() or finished_update.exception():
+                logger.log(self.status_loglevel, f"Background update failed: {finished_update}")
+
+        if apply_delayed_updates:
+            if self.finished_averaging_round.is_set():
+                if not self.reuse_tensors:
+                    self._apply_averaging_results_()
+                if self.offload_optimizer and not self.finished_optimizer_step.is_set():
+                    self._apply_optimizer_parameters_()
+                logger.log(self.status_loglevel, "Received parameters from background averaging round")
+                self.finished_averaging_round.clear()
+
+            if self.finished_optimizer_step.is_set():
+                if self.offload_optimizer:
+                    self._apply_optimizer_parameters_()
+                logger.debug("Received parameters from background optimizer step")
+                self.finished_optimizer_step.clear()
+
+        if increment_epoch:
+            self.local_epoch += 1
+
+        if optimizer_step or zero_grad or averaging_round:
+            if self.offload_optimizer and not self.custom_gradients:
+                self._load_local_grads_into_optimizer_()
+
+            pending_update = self.step_executor.submit(
+                self._do,
+                wait_for_trigger,
+                optimizer_step,
+                zero_grad,
+                averaging_round,
+                averaging_control,
+                grad_scaler,
+                **averaging_opts or {},
+            )
+            self.pending_updates.add(pending_update)
+
+            should_await_optimizer = (optimizer_step or zero_grad) and not delay_optimizer_step
+            should_await_averaging = averaging_round and not delay_averaging
+
+            if should_await_optimizer:
+                self.finished_optimizer_step.wait()
+                self.finished_optimizer_step.clear()
+                if self.offload_optimizer and not should_await_averaging:
+                    self._apply_optimizer_parameters_()
+                logger.debug("Finished optimizer step")
+
+            if should_await_averaging:
+                self.finished_averaging_round.wait()
+                self.finished_averaging_round.clear()
+                if not self.reuse_tensors:
+                    self._apply_averaging_results_()
+                if self.offload_optimizer:
+                    self._apply_optimizer_parameters_()
+                logger.log(self.status_loglevel, "Finished averaging round")
+
+            async_averaging = averaging_round and delay_averaging
+            async_optimizer = (optimizer_step or zero_grad) and delay_optimizer_step
+
+            if not (async_averaging or async_optimizer):
+                try:
+                    output = pending_update.result()
+                finally:
+                    self.pending_updates.remove(pending_update)
+
+        return output
+
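
For reference, a minimal assumed calling pattern for the flags documented above, mirroring the class docstring; dataloader, model and compute_loss are placeholders:

for i, batch in enumerate(dataloader):
    loss = compute_loss(model, batch)
    loss.backward()
    state_averager.step(
        optimizer_step=True,
        zero_grad=True,
        averaging_round=(i % 100 == 99),   # average with peers every 100 steps (arbitrary period)
        delay_averaging=True,              # run all-reduce in background, apply results on a later call
    )
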
+    def _do(
+        self,
+        wait_for_trigger: Optional[Callable[[], Any]],
+        optimizer_step: bool,
+        zero_grad: bool,
+        averaging_round: bool,
+        averaging_control: Optional[StepControl],
+        grad_scaler: Optional[GradScaler],
+        timeout: Optional[float] = None,
+        **kwargs,
+    ):
+        """
+        Run the optimizer step, followed by a scheduler step and an averaging round, each stage is optional.
+        This method is meant to be called in the background executor.
+        """
+        if averaging_control is not None and (averaging_control.triggered or averaging_control.done()):
+            logger.log(self.status_loglevel, f"Discarding failed matchmaking results: {averaging_control}")
+            averaging_control = None
+
+        start_time = time.perf_counter()
+        began_running = False
+
+        try:
+            if averaging_round and averaging_control is None:
+                averaging_control = super().step(
+                    gather=self.local_epoch,
+                    require_trigger=True,
+                    timeout=timeout,
+                    wait=False,
+                    **kwargs,
+                )
+
+            if wait_for_trigger is not None:
+                wait_for_trigger()
+            began_running = True
+
+            with self.lock_optimizer:
+                if optimizer_step:
+                    with self.lock_averaged_tensors if self.reuse_tensors else nullcontext():
+                        logger.debug(f"Running optimizer step")
+                        if grad_scaler is None:
+                            self.optimizer.step()
+                        else:
+                            with grad_scaler.running_global_step():
+                                assert grad_scaler.step(self.optimizer)
+
+                if zero_grad:
+                    logger.debug(f"Running zero grad")
+                    self.optimizer.zero_grad()
+                    if self.offload_optimizer:
+                        for parameter in self.main_parameters:
+                            if parameter.grad is not None:
+                                parameter.grad.zero_()
+
+                self._update_scheduler()
+                self.finished_optimizer_step.set()
+
+            if averaging_round:
+                with self.lock_averaging:
+                    if not self.reuse_tensors:
+                        self._load_local_tensors_into_averager_()
+                    if self.delta_rule_averaging:
+                        # remember tensors before averaging, update by (new_averaged_tensors - old_averaged_tensors)
+                        with torch.no_grad(), self.get_tensors() as averaged_tensors:
+                            self._old_tensors = tuple(x.cpu().clone() for x in averaged_tensors)
+
+                    self.delay_before_averaging.update(task_size=1, interval=time.perf_counter() - start_time)
+                    try:
+                        averaging_control.allow_allreduce()
+                        gathered = averaging_control.result(timeout=timeout)
+                        logger.log(self.status_loglevel, f"Averaged parameters with {len(gathered)} peers")
+                    except BaseException as e:
+                        logger.log(self.status_loglevel, f"Averaging failed with {type(e)}")
+                        gathered = {}
+
+                    self.finished_averaging_round.set()
+
+                if self.sync_epoch_when_averaging:
+                    old_epoch = self.local_epoch
+                    for peer_epoch in gathered.values():
+                        self.local_epoch = max(self.local_epoch, peer_epoch)
+                    if self.local_epoch != old_epoch:
+                        logger.log(self.status_loglevel, f"Found peer with newer epoch ({self.local_epoch})")
+                        self._update_scheduler()
+
+        except Exception as e:
+            if not began_running:
+                logger.error(f"Aborted {self.__class__.__name__}.step because wait_for_trigger raised exception")
+            logger.exception(e)
+            if averaging_control is not None and not averaging_control.done():
+                logger.error(f"Cancelled scheduled state averaging round")
+                averaging_control.cancel()
+            self.finished_optimizer_step.set()
+            self.finished_averaging_round.set()
+
+    @torch.no_grad()
+    def _load_local_grads_into_optimizer_(self):
+        """Copy local gradients into the gradient buffers of the offloaded optimizer"""
+        assert self.offload_optimizer, "Loading into offloaded optimizer requires using offloaded optimizer"
+        opt_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
+        for main_param, opt_param in zip(self.main_parameters, opt_parameters):
+            if main_param.grad is not None:
+                opt_param.grad.copy_(main_param.grad, non_blocking=True)
+
+    @torch.no_grad()
+    def _apply_optimizer_parameters_(self):
+        """Copy parameters from offloaded optimizer to the main model"""
+        assert self.offload_optimizer, "Applying offloaded optimizer updates requires offloaded optimizer"
+        offloaded_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
+        assert len(offloaded_parameters) == len(self.main_parameters), "Optimizer parameters changed during training"
+        for main_param, offloaded_param in zip(self.main_parameters, offloaded_parameters):
+            main_param.copy_(offloaded_param, non_blocking=True)
+
+    @torch.no_grad()
+    def _load_local_tensors_into_averager_(self):
+        """Copy local tensors into the averaging buffers"""
+        assert not self.reuse_tensors, "No need to load tensors into averager: both tensors share the same memory"
+        with self.get_tensors() as averaged_tensors:
+            for local_tensor, averaged_tensor in zip(self._local_tensors(), averaged_tensors):
+                averaged_tensor.copy_(local_tensor, non_blocking=True)
+
+    @torch.no_grad()
+    def _apply_averaging_results_(self):
+        """Copy averaged tensors into their respective local tensors"""
+        assert not self.reuse_tensors, "No need to update averaged tensors since they reuse the same memory"
+        if self.delta_rule_averaging and self._old_tensors is None:
+            logger.warning("Using delta_rule_averaging, but old tensors were not found. Averaging may have failed")
+        with self.get_tensors() as averaged_tensors:
+            local_tensors = list(self._local_tensors())
+            assert len(local_tensors) == len(averaged_tensors), "Tensor structure changed during training"
+            if not self.delta_rule_averaging or self._old_tensors is None:
+                for local_tensor, averaged_tensor in zip(local_tensors, averaged_tensors):
+                    local_tensor.copy_(averaged_tensor, non_blocking=True)
+            else:
+                assert len(self._old_tensors) == len(local_tensors)
+                for local_tensor, new_tensor, old_tensor in zip(local_tensors, averaged_tensors, self._old_tensors):
+                    delta = torch.sub(new_tensor, old_tensor, out=old_tensor)  # using old tensors as buffers
+                    local_tensor.add_(delta.to(device=local_tensor.device, dtype=local_tensor.dtype))
+
+    @property
+    def averaging_in_progress(self) -> bool:
+        return self.lock_averaging.locked()
+
+    def get_current_state(self):
+        """
+        Get the current model/optimizer state when requested by a newbie peer. Executed in the host process.
+        :returns: a tuple of (serializable_small_metadata, sequence of torch tensors, sequence of tensor infos)
+        """
+        with torch.no_grad(), self.lock_averaged_tensors:
+            optimized_parameters = tuple(
+                param.detach().cpu() for param_group in self.optimizer.param_groups for param in param_group["params"]
+            )
+            parameter_infos = [
+                CompressionInfo.from_tensor(param, key=key, role=TensorRole.PARAMETER)
+                for param, key in zip(optimized_parameters, self.parameter_names)
+            ]
+            extra_tensors = tuple(tensor.detach().cpu() for tensor in self.extra_tensors)
+            extra_infos = [
+                CompressionInfo.from_tensor(extra_tensor, key=i, role=TensorRole.UNSPECIFIED)
+                for i, extra_tensor in enumerate(extra_tensors)
+            ]
+            optimizer_metadata, optimizer_tensors = dump_optimizer_state(self.optimizer)
+            optimizer_infos = [
+                CompressionInfo.from_tensor(opt_tensor, key=i, role=TensorRole.OPTIMIZER)
+                for i, opt_tensor in enumerate(optimizer_tensors)
+            ]
+
+        metadata = dict(
+            epoch=self.local_epoch, group_bits=self.get_group_bits(), optimizer_metadata=optimizer_metadata
+        )
+        all_tensors = list(chain(optimized_parameters, extra_tensors, optimizer_tensors))
+        all_tensor_infos = list(chain(parameter_infos, extra_infos, optimizer_infos))
+        return metadata, all_tensors, all_tensor_infos
+
+    def load_state_from_peers(self, **kwargs):
+        """
+        Attempt to download the latest optimizer state from peers and update trainer parameters/statistics.
+        :returns: whether or not the averager succeeded in loading parameters
+        """
+        opt_parameters = tuple(param for param_group in self.optimizer.param_groups for param in param_group["params"])
+        main_parameters_and_extras = tuple(chain(opt_parameters, self.extra_tensors))
+        num_parameters_and_extras = len(main_parameters_and_extras)
+
+        loaded_state = super().load_state_from_peers(**kwargs)
+        if loaded_state is None:
+            return
+
+        metadata, flat_tensors = loaded_state
+        if (not isinstance(metadata.get("epoch"), int)) or metadata["epoch"] < self.local_epoch:
+            logger.warning("Cowardly refusing to load state from peer: peer's epoch is behind our local epoch")
+            return
+
+        loaded_parameters_and_extras = flat_tensors[:num_parameters_and_extras]
+        loaded_opt_tensors = flat_tensors[num_parameters_and_extras:]
+        if num_parameters_and_extras != len(loaded_parameters_and_extras):
+            logger.error("Failed to load state from peer, received parameters, extras or metadata")
+            return
+
+        with torch.no_grad(), self.lock_averaged_tensors:
+            try:
+                load_optimizer_state(self.optimizer, metadata["optimizer_metadata"], loaded_opt_tensors)
+            except StopIteration:
+                logger.warning("Failed to load state from peer, received inconsistent number of optimizer statistics")
+                return
+
+            for local_param, loaded_param in zip(main_parameters_and_extras, loaded_parameters_and_extras):
+                local_param.copy_(loaded_param, non_blocking=True)
+
+        if self.offload_optimizer:
+            self._apply_optimizer_parameters_()
+        if not self.reuse_tensors:
+            self._load_local_tensors_into_averager_()
+
+        self.local_epoch = metadata["epoch"]
+        self._update_scheduler()
+
+    def _update_scheduler(self):
+        """Increase the scheduler state until it becomes synchronized with local epoch"""
+        if self.scheduler:
+            while self.scheduler._step_count <= self.local_epoch:
+                self.scheduler.step()
+
+
+def initialize_optimizer_state_(opt: torch.optim.Optimizer):
+    """Initialize optimizer statistics by running a virtual optimizer step with zero gradients"""
+    flat_params = tuple(param for group in opt.param_groups for param in group["params"])
+    old_grads = []
+    for param in flat_params:
+        old_grads.append(param.grad)
+        param.grad = torch.zeros_like(param)
+    opt.step()
+    for param, old_grad in zip(flat_params, old_grads):
+        param.grad = old_grad
+
+
+def dump_optimizer_state(opt: torch.optim.Optimizer):
+    """Convert optimizer state into a format of DecentralizedAverager's get_current_state/load_state_from_peers"""
+    with torch.no_grad():
+        flat_metadata, flat_tensors = [], []
+        for elem in nested_flatten(opt.state_dict()):
+            if isinstance(elem, torch.Tensor):
+                flat_metadata.append(dict(type="tensor", index=len(flat_tensors)))
+                flat_tensors.append(elem.cpu())
+            else:
+                flat_metadata.append(dict(type="value", value=elem))
+        return flat_metadata, flat_tensors
+
+
+def load_optimizer_state(optimizer: torch.optim.Optimizer, flat_metadata: Dict, flat_tensors: Sequence[torch.Tensor]):
+    """Load a state obtained by dump_optimizer_state back into the optimizer"""
+    flat_optimizer_state = []
+    for elem in flat_metadata:
+        if elem.get("type") == "tensor" and isinstance(elem.get("index"), int):
+            flat_optimizer_state.append(flat_tensors[elem["index"]])
+        elif elem.get("type") == "value" and "value" in elem:
+            flat_optimizer_state.append(elem["value"])
+    return optimizer.load_state_dict(nested_pack(flat_optimizer_state, structure=optimizer.state_dict()))
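
The three helpers above are easiest to see end to end on a toy optimizer. A minimal sketch, assuming hivemind is installed and that these module-level helpers stay importable from the state averager module shown in this diff (the exact import path is inferred, not confirmed here):

import torch
from hivemind.optim.state_averager import (
    dump_optimizer_state,
    initialize_optimizer_state_,
    load_optimizer_state,
)

model = torch.nn.Linear(4, 2)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
initialize_optimizer_state_(opt)  # allocates exp_avg / exp_avg_sq by running one zero-gradient step

# flatten: tensors go into flat_tensors (moved to CPU), all other values stay inline in flat_metadata
flat_metadata, flat_tensors = dump_optimizer_state(opt)

# ... send metadata and tensors to another peer ...

# restore: index markers in flat_metadata are replaced by tensors, then nested_pack rebuilds the state dict
load_optimizer_state(opt, flat_metadata, flat_tensors)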

+ 48 - 13
hivemind/averaging/training.py → hivemind/optim/training_averager.py

@@ -8,6 +8,7 @@ from typing import Dict, Iterator, Optional, Sequence
 import torch
 
 from hivemind.averaging import DecentralizedAverager
+from hivemind.compression import CompressionInfo, TensorRole
 from hivemind.utils import get_logger, nested_flatten, nested_pack
 
 logger = get_logger(__name__)
@@ -41,23 +42,28 @@ class TrainingAverager(DecentralizedAverager):
         average_gradients: bool,
         average_opt_statistics: Sequence[str] = (),
         extra_tensors: Sequence[torch.Tensor] = (),
+        parameter_names: Optional[Sequence[str]] = None,
         initialize_optimizer: bool = True,
         **kwargs
     ):
+        if initialize_optimizer:
+            initialize_optimizer_state(opt)  # note: this will run one optimizer step!
+        if parameter_names is None:
+            parameter_names = tuple(i for group in opt.param_groups for i in range(len(group["params"])))
 
         self.opt, self.extra_tensors, self.local_step = opt, tuple(extra_tensors), 0
         self.opt_statistics = tuple(average_opt_statistics)
         self.average_parameters, self.average_gradients = average_parameters, average_gradients
+        self.parameter_names = parameter_names
         self.step_executor = ThreadPoolExecutor(max_workers=1)
         self.lock_averager_step = Lock()
         self.pending_updates_done = Event()
         self.pending_updates_done.set()
-        if initialize_optimizer:
-            initialize_optimizer_state(opt)  # note: this will run one optimizer step!
 
         with torch.no_grad():
             averaged_tensors = [tensor.detach().cpu().float().clone() for tensor in self.local_tensors()]
-        super().__init__(averaged_tensors=averaged_tensors, **kwargs)
+
+        super().__init__(averaged_tensors=averaged_tensors, tensor_infos=list(self.tensor_infos()), **kwargs)
 
     def step(self, data_lock: Optional[Lock] = None, wait: bool = True, **kwargs):
         """
@@ -95,7 +101,7 @@ class TrainingAverager(DecentralizedAverager):
                 self.pending_updates_done.clear()
                 with data_lock, self.get_tensors() as averaged_tensors:
                     if len(averaged_tensors) != len(local_tensors):
-                        raise RuntimeError("The number of optimized parameters should not change.")
+                        raise RuntimeError("The number of optimized parameters should not change")
 
                     if use_old_local_tensors:
                         # since tensors might have changed, we subtract old_local_tensor and add averaged. This prevents
@@ -119,13 +125,8 @@ class TrainingAverager(DecentralizedAverager):
             self.local_step += 1
             return gathered
 
-    def local_tensors(self, replace_none: bool = True) -> Iterator[torch.Tensor]:
-        """
-        Iterate local trainer's tensors that should be averaged with peers
-
-        :param replace_none: if True and average_gradients is True, None grads will be replaced with a zero tensors
-          Otherwise, such gradients will be skipped. (this may cause inconsistencies with averaged_tensors)
-        """
+    def local_tensors(self) -> Iterator[torch.Tensor]:
+        """Iterate local trainer's tensors that should be averaged with peers"""
         if self.average_parameters:
             for param_group in self.opt.param_groups:
                 yield from param_group["params"]
@@ -134,7 +135,7 @@ class TrainingAverager(DecentralizedAverager):
                 for param in param_group["params"]:
                     if param.grad is not None:
                         yield param.grad
-                    elif replace_none:
+                    else:
                         yield torch.zeros_like(param)
         for stats in self.opt_statistics:
             for param_group in self.opt.param_groups:
@@ -142,6 +143,26 @@ class TrainingAverager(DecentralizedAverager):
                     yield self.opt.state[param][stats]
         yield from iter(self.extra_tensors)
 
+    def tensor_infos(self):
+        """Get CompressionInfo for each tensor, accounting for its role and specification"""
+        params = tuple(param for param_group in self.opt.param_groups for param in param_group["params"])
+        assert len(params) == len(self.parameter_names)
+        if self.average_parameters:
+            for param, key in zip(params, self.parameter_names):
+                yield CompressionInfo.from_tensor(param, key=key, role=TensorRole.PARAMETER)
+        if self.average_gradients:
+            for param, key in zip(params, self.parameter_names):
+                grad = param.grad if param.grad is not None else torch.zeros_like(param)
+                yield CompressionInfo.from_tensor(grad, key=key, role=TensorRole.GRADIENT)
+        for stats in self.opt_statistics:
+            for param, key in zip(params, self.parameter_names):
+                yield CompressionInfo.from_tensor(
+                    self.opt.state[param][stats], key=(key, stats), role=TensorRole.OPTIMIZER
+                )
+        for i, extra_tensor in enumerate(self.extra_tensors):
+            yield CompressionInfo.from_tensor(extra_tensor, key=i, role=TensorRole.UNSPECIFIED)
+
     def get_current_state(self):
         """
         Get the current model/optimizer state when requested by a newbie peer. Executed in the host process.
@@ -151,11 +172,25 @@ class TrainingAverager(DecentralizedAverager):
             optimized_parameters = tuple(
                 param.detach().cpu() for param_group in self.opt.param_groups for param in param_group["params"]
             )
+            parameter_infos = [
+                CompressionInfo.from_tensor(param, key=key, role=TensorRole.PARAMETER)
+                for param, key in zip(optimized_parameters, self.parameter_names)
+            ]
             extra_tensors = tuple(tensor.detach().cpu() for tensor in self.extra_tensors)
+            extra_infos = [
+                CompressionInfo.from_tensor(extra_tensor, key=i, role=TensorRole.UNSPECIFIED)
+                for i, extra_tensor in enumerate(extra_tensors)
+            ]
             optimizer_metadata, optimizer_tensors = dump_optimizer_state(self.opt)
+            optimizer_infos = [
+                CompressionInfo.from_tensor(opt_tensor, key=i, role=TensorRole.OPTIMIZER)
+                for i, opt_tensor in enumerate(optimizer_tensors)
+            ]
 
         metadata = dict(step=self.local_step, group_bits=self.get_group_bits(), optimizer_metadata=optimizer_metadata)
-        return metadata, list(chain(optimized_parameters, extra_tensors, optimizer_tensors))
+        all_tensors = list(chain(optimized_parameters, extra_tensors, optimizer_tensors))
+        all_tensor_infos = list(chain(parameter_infos, extra_infos, optimizer_infos))
+        return metadata, all_tensors, all_tensor_infos
 
     def load_state_from_peers(self, **kwargs):
         """

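
For context on the tensor_infos plumbing above: every averaged tensor is tagged with a CompressionInfo recording its role, so adaptive compression can treat parameters, gradients and optimizer statistics differently. A minimal sketch using only the names imported in this diff (the key strings are arbitrary examples):

import torch
from hivemind.compression import CompressionInfo, TensorRole

weight = torch.randn(128, 64)
param_info = CompressionInfo.from_tensor(weight, key="encoder.weight", role=TensorRole.PARAMETER)
grad_info = CompressionInfo.from_tensor(torch.zeros_like(weight), key="encoder.weight", role=TensorRole.GRADIENT)
# one info per averaged tensor, passed in the same order as the tensors themselves (see tensor_infos() above)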
+ 1 - 1
hivemind/p2p/__init__.py

@@ -1,3 +1,3 @@
-from hivemind.p2p.p2p_daemon import P2P, P2PContext, P2PHandlerError
+from hivemind.p2p.p2p_daemon import P2P, P2PContext, P2PDaemonError, P2PHandlerError
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo
 from hivemind.p2p.servicer import ServicerBase, StubBase

+ 227 - 112
hivemind/p2p/p2p_daemon.py

@@ -1,21 +1,25 @@
 import asyncio
+import json
+import logging
 import os
 import secrets
 from collections.abc import AsyncIterable as AsyncIterableABC
 from contextlib import closing, suppress
 from dataclasses import dataclass
+from datetime import datetime
 from importlib.resources import path
-from subprocess import Popen
-from typing import Any, AsyncIterator, Awaitable, Callable, List, Optional, Sequence, Tuple, TypeVar, Union
+from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional, Sequence, Tuple, Type, TypeVar, Union
 
+from google.protobuf.message import Message
 from multiaddr import Multiaddr
 
 import hivemind.hivemind_cli as cli
 import hivemind.p2p.p2p_daemon_bindings.p2pclient as p2pclient
+from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, P2PDaemonError, P2PHandlerError
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
 from hivemind.proto.p2pd_pb2 import RPCError
-from hivemind.utils.asyncio import aiter, asingle
-from hivemind.utils.logging import get_logger
+from hivemind.utils.asyncio import as_aiter, asingle
+from hivemind.utils.logging import get_logger, golog_level_to_python, loglevel, python_level_to_golog
 
 logger = get_logger(__name__)
 
@@ -28,7 +32,6 @@ class P2PContext(object):
     handle_name: str
     local_id: PeerID
     remote_id: PeerID = None
-    remote_maddr: Multiaddr = None
 
 
 class P2P:
@@ -54,9 +57,9 @@ class P2P:
     END_OF_STREAM = RPCError()
 
     DHT_MODE_MAPPING = {
-        "dht": {"dht": 1},
-        "dht_server": {"dhtServer": 1},
-        "dht_client": {"dhtClient": 1},
+        "auto": {"dht": 1},
+        "server": {"dhtServer": 1},
+        "client": {"dhtClient": 1},
     }
     FORCE_REACHABILITY_MAPPING = {
         "public": {"forceReachabilityPublic": 1},
@@ -66,57 +69,63 @@ class P2P:
 
     def __init__(self):
         self.peer_id = None
+        self._client = None
         self._child = None
         self._alive = False
+        self._reader_task = None
         self._listen_task = None
-        self._server_stopped = asyncio.Event()
 
     @classmethod
     async def create(
         cls,
         initial_peers: Optional[Sequence[Union[Multiaddr, str]]] = None,
-        use_ipfs: bool = False,
-        host_maddrs: Optional[Sequence[Union[Multiaddr, str]]] = ("/ip4/127.0.0.1/tcp/0",),
+        *,
         announce_maddrs: Optional[Sequence[Union[Multiaddr, str]]] = None,
-        quic: bool = True,
-        tls: bool = True,
+        auto_nat: bool = True,
         conn_manager: bool = True,
-        dht_mode: str = "dht_server",
+        dht_mode: str = "server",
         force_reachability: Optional[str] = None,
+        host_maddrs: Optional[Sequence[Union[Multiaddr, str]]] = ("/ip4/127.0.0.1/tcp/0",),
+        identity_path: Optional[str] = None,
+        idle_timeout: float = 30,
         nat_port_map: bool = True,
-        auto_nat: bool = True,
+        quic: bool = False,
+        relay_hop_limit: int = 0,
+        startup_timeout: float = 15,
+        tls: bool = True,
+        use_auto_relay: bool = False,
+        use_ipfs: bool = False,
         use_relay: bool = True,
         use_relay_hop: bool = False,
         use_relay_discovery: bool = False,
-        use_auto_relay: bool = False,
-        relay_hop_limit: int = 0,
-        quiet: bool = True,
-        ping_n_attempts: int = 5,
-        ping_delay: float = 0.4,
+        persistent_conn_max_msg_size: int = DEFAULT_MAX_MSG_SIZE,
     ) -> "P2P":
         """
         Start a new p2pd process and connect to it.
         :param initial_peers: List of bootstrap peers
-        :param use_ipfs: Bootstrap to IPFS (incompatible with initial_peers)
-        :param host_maddrs: Multiaddrs to listen for external connections from other p2p instances
+        :param auto_nat: Enables the AutoNAT service
         :param announce_maddrs: Visible multiaddrs that the peer will announce
-          for external connections from other p2p instances
-        :param quic: Enables the QUIC transport
-        :param tls: Enables TLS1.3 channel security protocol
+                                for external connections from other p2p instances
         :param conn_manager: Enables the Connection Manager
-        :param dht_mode: DHT mode (dht_client/dht_server/dht)
+        :param dht_mode: libp2p DHT mode (auto/client/server).
+                         Defaults to "server" to make collaborations work in local networks.
+                         Details: https://pkg.go.dev/github.com/libp2p/go-libp2p-kad-dht#ModeOpt
         :param force_reachability: Force reachability mode (public/private)
+        :param host_maddrs: Multiaddrs to listen for external connections from other p2p instances
+        :param identity_path: Path to a pre-generated private key file. If defined, makes the peer ID deterministic.
+                              May be generated using ``./p2p-keygen`` from ``go-libp2p-daemon``.
+        :param idle_timeout: kill the daemon if the client has been idle for this many
+                             seconds before opening persistent streams
         :param nat_port_map: Enables NAT port mapping
-        :param auto_nat: Enables the AutoNAT service
+        :param quic: Enables the QUIC transport
+        :param relay_hop_limit: sets the hop limit for hop relays
+        :param startup_timeout: raise a P2PDaemonError if the daemon does not start in ``startup_timeout`` seconds
+        :param tls: Enables TLS1.3 channel security protocol
+        :param use_auto_relay: enables autorelay
+        :param use_ipfs: Bootstrap to IPFS (incompatible with initial_peers)
         :param use_relay: enables circuit relay
         :param use_relay_hop: enables hop for relay
         :param use_relay_discovery: enables passive discovery for relay
-        :param use_auto_relay: enables autorelay
-        :param relay_hop_limit: sets the hop limit for hop relays
-        :param quiet: make the daemon process quiet
-        :param ping_n_attempts: try to ping the daemon with this number of attempts after starting it
-        :param ping_delay: wait for ``ping_delay * (2 ** (k - 1))`` seconds before the k-th attempt to ping the daemon
-          (in particular, wait for ``ping_delay`` seconds before the first attempt)
         :return: a wrapper for the p2p daemon
         """
 
@@ -131,6 +140,11 @@ class P2P:
         socket_uid = secrets.token_urlsafe(8)
         self._daemon_listen_maddr = Multiaddr(cls._UNIX_SOCKET_PREFIX + f"p2pd-{socket_uid}.sock")
         self._client_listen_maddr = Multiaddr(cls._UNIX_SOCKET_PREFIX + f"p2pclient-{socket_uid}.sock")
+        if announce_maddrs is not None:
+            for addr in announce_maddrs:
+                addr = Multiaddr(addr)
+                if ("tcp" in addr and addr["tcp"] == "0") or ("udp" in addr and addr["udp"] == "0"):
+                    raise ValueError("Please specify an explicit port in announce_maddrs: port 0 is not supported")
 
         need_bootstrap = bool(initial_peers) or use_ipfs
         process_kwargs = cls.DHT_MODE_MAPPING.get(dht_mode, {"dht": 0})
@@ -142,52 +156,55 @@ class P2P:
         ]:
             if value:
                 process_kwargs[param] = self._maddrs_to_str(value)
+        if identity_path is not None:
+            process_kwargs["id"] = identity_path
 
         proc_args = self._make_process_args(
             str(p2pd_path),
-            listen=self._daemon_listen_maddr,
-            quic=quic,
-            tls=tls,
+            autoRelay=use_auto_relay,
+            autonat=auto_nat,
+            b=need_bootstrap,
             connManager=conn_manager,
+            idleTimeout=f"{idle_timeout}s",
+            listen=self._daemon_listen_maddr,
             natPortMap=nat_port_map,
-            autonat=auto_nat,
+            quic=quic,
             relay=use_relay,
-            relayHop=use_relay_hop,
             relayDiscovery=use_relay_discovery,
-            autoRelay=use_auto_relay,
+            relayHop=use_relay_hop,
             relayHopLimit=relay_hop_limit,
-            b=need_bootstrap,
-            q=quiet,
+            tls=tls,
+            persistentConnMaxMsgSize=persistent_conn_max_msg_size,
             **process_kwargs,
         )
 
-        self._child = Popen(args=proc_args, encoding="utf8")
+        env = os.environ.copy()
+        env.setdefault("GOLOG_LOG_LEVEL", python_level_to_golog(loglevel))
+        env["GOLOG_LOG_FMT"] = "json"
+
+        logger.debug(f"Launching {proc_args}")
+        self._child = await asyncio.subprocess.create_subprocess_exec(
+            *proc_args, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.STDOUT, env=env
+        )
         self._alive = True
-        self._client = p2pclient.Client(self._daemon_listen_maddr, self._client_listen_maddr)
 
-        await self._ping_daemon_with_retries(ping_n_attempts, ping_delay)
+        ready = asyncio.Future()
+        self._reader_task = asyncio.create_task(self._read_outputs(ready))
+        try:
+            await asyncio.wait_for(ready, startup_timeout)
+        except asyncio.TimeoutError:
+            await self.shutdown()
+            raise P2PDaemonError(f"Daemon failed to start in {startup_timeout:.1f} seconds")
+
+        self._client = await p2pclient.Client.create(
+            control_maddr=self._daemon_listen_maddr,
+            listen_maddr=self._client_listen_maddr,
+            persistent_conn_max_msg_size=persistent_conn_max_msg_size,
+        )
 
+        await self._ping_daemon()
         return self
 
-    async def _ping_daemon_with_retries(self, ping_n_attempts: int, ping_delay: float) -> None:
-        for try_number in range(ping_n_attempts):
-            await asyncio.sleep(ping_delay * (2 ** try_number))
-
-            if self._child.poll() is not None:  # Process died
-                break
-
-            try:
-                await self._ping_daemon()
-                break
-            except Exception as e:
-                if try_number == ping_n_attempts - 1:
-                    logger.exception("Failed to ping p2pd that has just started")
-                    await self.shutdown()
-                    raise
-
-        if self._child.returncode is not None:
-            raise RuntimeError(f"The p2p daemon has died with return code {self._child.returncode}")
-
     @classmethod
     async def replicate(cls, daemon_listen_maddr: Multiaddr) -> "P2P":
         """
@@ -206,7 +223,7 @@ class P2P:
         self._daemon_listen_maddr = daemon_listen_maddr
         self._client_listen_maddr = Multiaddr(cls._UNIX_SOCKET_PREFIX + f"p2pclient-{socket_uid}.sock")
 
-        self._client = p2pclient.Client(self._daemon_listen_maddr, self._client_listen_maddr)
+        self._client = await p2pclient.Client.create(self._daemon_listen_maddr, self._client_listen_maddr)
 
         await self._ping_daemon()
         return self
@@ -275,7 +292,7 @@ class P2P:
 
     @staticmethod
     async def receive_protobuf(
-        input_protobuf_type: type, reader: asyncio.StreamReader
+        input_protobuf_type: Type[Message], reader: asyncio.StreamReader
     ) -> Tuple[Optional[TInputProtobuf], Optional[RPCError]]:
         msg_type = await reader.readexactly(1)
         if msg_type == P2P.MESSAGE_MARKER:
@@ -296,7 +313,7 @@ class P2P:
         self,
         name: str,
         handler: Callable[[TInputStream, P2PContext], TOutputStream],
-        input_protobuf_type: type,
+        input_protobuf_type: Type[Message],
         max_prefetch: int = 5,
     ) -> None:
         """
@@ -314,7 +331,6 @@ class P2P:
                 handle_name=name,
                 local_id=self.peer_id,
                 remote_id=stream_info.peer_id,
-                remote_maddr=stream_info.addr,
             )
             requests = asyncio.Queue(max_prefetch)
 
@@ -328,10 +344,18 @@ class P2P:
             async def _process_stream() -> None:
                 try:
                     async for response in handler(_read_stream(), context):
-                        await P2P.send_protobuf(response, writer)
+                        try:
+                            await P2P.send_protobuf(response, writer)
+                        except Exception:
+                            # The connection is unexpectedly closed by the caller or broken.
+                            # The loglevel is DEBUG since the actual error will be reported on the caller
+                            logger.debug("Exception while sending response:", exc_info=True)
+                            break
                 except Exception as e:
-                    logger.warning("Exception while processing stream and sending responses:", exc_info=True)
-                    await P2P.send_protobuf(RPCError(message=str(e)), writer)
+                    logger.warning("Handler failed with the exception:", exc_info=True)
+                    with suppress(Exception):
+                        # Sometimes `e` is a connection error, so it is okay if we fail to report `e` to the caller
+                        await P2P.send_protobuf(RPCError(message=str(e)), writer)
 
             with closing(writer):
                 processing_task = asyncio.create_task(_process_stream())
@@ -358,7 +382,7 @@ class P2P:
         await self.add_binary_stream_handler(name, _handle_stream)
 
     async def _iterate_protobuf_stream_handler(
-        self, peer_id: PeerID, name: str, requests: TInputStream, output_protobuf_type: type
+        self, peer_id: PeerID, name: str, requests: TInputStream, output_protobuf_type: Type[Message]
     ) -> TOutputStream:
         _, reader, writer = await self.call_binary_stream_handler(peer_id, name)
 
@@ -367,22 +391,25 @@ class P2P:
                 await P2P.send_protobuf(request, writer)
             await P2P.send_protobuf(P2P.END_OF_STREAM, writer)
 
-        with closing(writer):
-            writing_task = asyncio.create_task(_write_to_stream())
-            try:
-                while True:
-                    try:
-                        response, err = await P2P.receive_protobuf(output_protobuf_type, reader)
-                    except asyncio.IncompleteReadError:  # Connection is closed
-                        break
+        async def _read_from_stream() -> AsyncIterator[Message]:
+            with closing(writer):
+                try:
+                    while True:
+                        try:
+                            response, err = await P2P.receive_protobuf(output_protobuf_type, reader)
+                        except asyncio.IncompleteReadError:  # Connection is closed
+                            break
+
+                        if err is not None:
+                            raise P2PHandlerError(f"Failed to call handler `{name}` at {peer_id}: {err.message}")
+                        yield response
 
-                    if err is not None:
-                        raise P2PHandlerError(f"Failed to call handler `{name}` at {peer_id}: {err.message}")
-                    yield response
+                    await writing_task
+                finally:
+                    writing_task.cancel()
 
-                await writing_task
-            finally:
-                writing_task.cancel()
+        writing_task = asyncio.create_task(_write_to_stream())
+        return _read_from_stream()
 
     async def add_protobuf_handler(
         self,
@@ -390,15 +417,22 @@ class P2P:
         handler: Callable[
             [Union[TInputProtobuf, TInputStream], P2PContext], Union[Awaitable[TOutputProtobuf], TOutputStream]
         ],
-        input_protobuf_type: type,
+        input_protobuf_type: Type[Message],
         *,
         stream_input: bool = False,
+        stream_output: bool = False,
     ) -> None:
         """
         :param stream_input: If True, assume ``handler`` to take ``TInputStream``
                              (not just ``TInputProtobuf``) as input.
+        :param stream_output: If True, assume ``handler`` to return ``TOutputStream``
+                              (not ``Awaitable[TOutputProtobuf]``).
         """
 
+        if not stream_input and not stream_output:
+            await self._add_protobuf_unary_handler(name, handler, input_protobuf_type)
+            return
+
         async def _stream_handler(requests: P2P.TInputStream, context: P2PContext) -> P2P.TOutputStream:
             input = requests if stream_input else await asingle(requests)
             output = handler(input, context)
@@ -411,44 +445,76 @@ class P2P:
 
         await self._add_protobuf_stream_handler(name, _stream_handler, input_protobuf_type)
 
+    async def _add_protobuf_unary_handler(
+        self,
+        handle_name: str,
+        handler: Callable[[TInputProtobuf, P2PContext], Awaitable[TOutputProtobuf]],
+        input_protobuf_type: Type[Message],
+    ) -> None:
+        """
+        Register a request-response (unary) handler. Unary requests and responses
+        are sent through persistent multiplexed connections to the daemon for the
+        sake of reducing the number of open files.
+        :param handle_name: name of the handler (protocol id)
+        :param handler: function handling the unary requests
+        :param input_protobuf_type: protobuf type of the request
+        """
+
+        async def _unary_handler(request: bytes, remote_id: PeerID) -> bytes:
+            input_serialized = input_protobuf_type.FromString(request)
+            context = P2PContext(
+                handle_name=handle_name,
+                local_id=self.peer_id,
+                remote_id=remote_id,
+            )
+
+            response = await handler(input_serialized, context)
+            return response.SerializeToString()
+
+        await self._client.add_unary_handler(handle_name, _unary_handler)
+
     async def call_protobuf_handler(
         self,
         peer_id: PeerID,
         name: str,
         input: Union[TInputProtobuf, TInputStream],
-        output_protobuf_type: type,
+        output_protobuf_type: Type[Message],
     ) -> Awaitable[TOutputProtobuf]:
-        requests = input if isinstance(input, AsyncIterableABC) else aiter(input)
-        responses = self._iterate_protobuf_stream_handler(peer_id, name, requests, output_protobuf_type)
+
+        if not isinstance(input, AsyncIterableABC):
+            return await self._call_unary_protobuf_handler(peer_id, name, input, output_protobuf_type)
+
+        responses = await self._iterate_protobuf_stream_handler(peer_id, name, input, output_protobuf_type)
         return await asingle(responses)
 
-    def iterate_protobuf_handler(
+    async def _call_unary_protobuf_handler(
+        self,
+        peer_id: PeerID,
+        handle_name: str,
+        input: TInputProtobuf,
+        output_protobuf_type: Type[Message],
+    ) -> Awaitable[TOutputProtobuf]:
+        serialized_input = input.SerializeToString()
+        response = await self._client.call_unary_handler(peer_id, handle_name, serialized_input)
+        return output_protobuf_type.FromString(response)
+
+    async def iterate_protobuf_handler(
         self,
         peer_id: PeerID,
         name: str,
         input: Union[TInputProtobuf, TInputStream],
-        output_protobuf_type: type,
+        output_protobuf_type: Type[Message],
     ) -> TOutputStream:
-        requests = input if isinstance(input, AsyncIterableABC) else aiter(input)
-        return self._iterate_protobuf_stream_handler(peer_id, name, requests, output_protobuf_type)
+        requests = input if isinstance(input, AsyncIterableABC) else as_aiter(input)
+        return await self._iterate_protobuf_stream_handler(peer_id, name, requests, output_protobuf_type)
 
     def _start_listening(self) -> None:
         async def listen() -> None:
             async with self._client.listen():
-                await self._server_stopped.wait()
+                await asyncio.Future()  # Wait until this task will be cancelled in _terminate()
 
         self._listen_task = asyncio.create_task(listen())
 
-    async def _stop_listening(self) -> None:
-        if self._listen_task is not None:
-            self._server_stopped.set()
-            self._listen_task.cancel()
-            try:
-                await self._listen_task
-            except asyncio.CancelledError:
-                self._listen_task = None
-                self._server_stopped.clear()
-
     async def add_binary_stream_handler(self, name: str, handler: p2pclient.StreamHandler) -> None:
         if self._listen_task is None:
             self._start_listening()
@@ -467,14 +533,21 @@ class P2P:
         return self._alive
 
     async def shutdown(self) -> None:
-        await self._stop_listening()
-        await asyncio.get_event_loop().run_in_executor(None, self._terminate)
+        self._terminate()
+        if self._child is not None:
+            await self._child.wait()
 
     def _terminate(self) -> None:
+        if self._client is not None:
+            self._client.close()
+        if self._listen_task is not None:
+            self._listen_task.cancel()
+        if self._reader_task is not None:
+            self._reader_task.cancel()
+
         self._alive = False
-        if self._child is not None and self._child.poll() is None:
+        if self._child is not None and self._child.returncode is None:
             self._child.terminate()
-            self._child.wait()
             logger.debug(f"Terminated p2pd with id = {self.peer_id}")
 
             with suppress(FileNotFoundError):
@@ -502,10 +575,52 @@ class P2P:
     def _maddrs_to_str(maddrs: List[Multiaddr]) -> str:
         return ",".join(str(addr) for addr in maddrs)
 
+    async def _read_outputs(self, ready: asyncio.Future) -> None:
+        last_line = None
+        while True:
+            line = await self._child.stdout.readline()
+            if not line:  # Stream closed
+                break
+            last_line = line.rstrip().decode(errors="ignore")
 
-class P2PInterruptedError(Exception):
-    pass
+            self._log_p2pd_message(last_line)
+            if last_line.startswith("Peer ID:"):
+                ready.set_result(None)
 
+        if not ready.done():
+            ready.set_exception(P2PDaemonError(f"Daemon failed to start: {last_line}"))
 
-class P2PHandlerError(Exception):
-    pass
+    @staticmethod
+    def _log_p2pd_message(line: str) -> None:
+        if '"logger"' not in line:  # User-friendly info from p2pd stdout
+            logger.debug(line, extra={"caller": "p2pd"})
+            return
+
+        try:
+            record = json.loads(line)
+            caller = record["caller"]
+
+            level = golog_level_to_python(record["level"])
+            if level <= logging.WARNING:
+                # Many Go loggers are excessively verbose (e.g. show warnings for unreachable peers),
+                # so we downgrade INFO and WARNING messages to DEBUG.
+                # The Go verbosity can still be controlled via the GOLOG_LOG_LEVEL env variable.
+                # Details: https://github.com/ipfs/go-log#golog_log_level
+                level = logging.DEBUG
+
+            message = record["msg"]
+            if "error" in record:
+                message += f": {record['error']}"
+
+            logger.log(
+                level,
+                message,
+                extra={
+                    "origin_created": datetime.strptime(record["ts"], "%Y-%m-%dT%H:%M:%S.%f%z").timestamp(),
+                    "caller": caller,
+                },
+            )
+        except Exception:
+            # Parsing errors are unlikely, but we don't want to lose these messages anyway
+            logger.warning(line, extra={"caller": "p2pd"})
+            logger.exception("Failed to parse go-log message:")

+ 209 - 3
hivemind/p2p/p2p_daemon_bindings/control.py

@@ -5,8 +5,9 @@ Author: Kevin Mai-Husan Chia
 """
 
 import asyncio
-from contextlib import asynccontextmanager
-from typing import AsyncIterator, Awaitable, Callable, Dict, Iterable, Sequence, Tuple
+from contextlib import asynccontextmanager, closing
+from typing import AsyncIterator, Awaitable, Callable, Dict, Iterable, Optional, Sequence, Tuple
+from uuid import UUID, uuid4
 
 from multiaddr import Multiaddr, protocols
 
@@ -25,6 +26,8 @@ SUPPORT_CONN_PROTOCOLS = (
 SUPPORTED_PROTOS = (protocols.protocol_with_code(proto) for proto in SUPPORT_CONN_PROTOCOLS)
 logger = get_logger(__name__)
 
+DEFAULT_MAX_MSG_SIZE = 4 * 1024 ** 2
+
 
 def parse_conn_protocol(maddr: Multiaddr) -> int:
     proto_codes = set(proto.code for proto in maddr.protocols())
@@ -54,17 +57,84 @@ class DaemonConnector:
         else:
             raise ValueError(f"Protocol not supported: {protocols.protocol_with_code(self.proto_code)}")
 
+    async def open_persistent_connection(self) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
+        """
+        Open connection to daemon and upgrade it to a persistent one
+        """
+        reader, writer = await self.open_connection()
+        req = p2pd_pb.Request(type=p2pd_pb.Request.PERSISTENT_CONN_UPGRADE)
+        await write_pbmsg(writer, req)
+
+        response = p2pd_pb.Response()
+        await read_pbmsg_safe(reader, response)
+
+        if response.type == p2pd_pb.Response.ERROR:
+            raise P2PDaemonError(response.error.msg)
+
+        return reader, writer
+
+
+TUnaryHandler = Callable[[bytes, PeerID], Awaitable[bytes]]
+CallID = UUID
+
 
 class ControlClient:
     DEFAULT_LISTEN_MADDR = "/unix/tmp/p2pclient.sock"
 
     def __init__(
-        self, daemon_connector: DaemonConnector, listen_maddr: Multiaddr = Multiaddr(DEFAULT_LISTEN_MADDR)
+        self,
+        daemon_connector: DaemonConnector,
+        listen_maddr: Multiaddr = Multiaddr(DEFAULT_LISTEN_MADDR),
+        *,
+        _initialized_with_create: bool = False,
+        persistent_conn_max_msg_size: int = DEFAULT_MAX_MSG_SIZE,
     ) -> None:
+        assert _initialized_with_create, "Please use ControlClient.create coroutine to spawn new control instances"
+
+        self.persistent_conn_max_msg_size = persistent_conn_max_msg_size
+
         self.listen_maddr = listen_maddr
         self.daemon_connector = daemon_connector
         self.handlers: Dict[str, StreamHandler] = {}
 
+        self.unary_handlers: Dict[str, TUnaryHandler] = {}
+
+        self._pending_messages: asyncio.Queue[p2pd_pb.PersistentConnectionRequest] = asyncio.Queue()
+        self._pending_calls: Dict[CallID, asyncio.Future[bytes]] = {}
+        self._handler_tasks: Dict[CallID, asyncio.Task] = {}
+
+        self._read_task: Optional[asyncio.Task] = None
+        self._write_task: Optional[asyncio.Task] = None
+
+    @classmethod
+    async def create(
+        cls,
+        daemon_connector: DaemonConnector,
+        listen_maddr: Multiaddr = Multiaddr(DEFAULT_LISTEN_MADDR),
+        use_persistent_conn: bool = True,
+        persistent_conn_max_msg_size: int = DEFAULT_MAX_MSG_SIZE,
+    ) -> "ControlClient":
+        control = cls(
+            daemon_connector,
+            listen_maddr,
+            _initialized_with_create=True,
+            persistent_conn_max_msg_size=persistent_conn_max_msg_size,
+        )
+
+        if use_persistent_conn:
+            await control._ensure_persistent_conn()
+
+        return control
+
+    def close(self) -> None:
+        if self._read_task is not None:
+            self._read_task.cancel()
+        if self._write_task is not None:
+            self._write_task.cancel()
+
+    def __del__(self):
+        self.close()
+
     async def _handler(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
         pb_stream_info = p2pd_pb.StreamInfo()  # type: ignore
         await read_pbmsg_safe(reader, pb_stream_info)
@@ -93,6 +163,130 @@ class ControlClient:
         async with server:
             yield self
 
+    async def _read_from_persistent_conn(self, reader: asyncio.StreamReader):
+        while True:
+            resp = p2pd_pb.PersistentConnectionResponse()
+            try:
+                await read_pbmsg_safe(reader, resp)
+            except asyncio.IncompleteReadError:
+                break
+
+            call_id = UUID(bytes=resp.callId)
+
+            if resp.HasField("callUnaryResponse"):
+                if call_id in self._pending_calls and resp.callUnaryResponse.HasField("response"):
+                    self._pending_calls[call_id].set_result(resp.callUnaryResponse.response)
+                elif call_id in self._pending_calls and resp.callUnaryResponse.HasField("error"):
+                    remote_exc = P2PHandlerError(resp.callUnaryResponse.error.decode(errors="ignore"))
+                    self._pending_calls[call_id].set_exception(remote_exc)
+                else:
+                    logger.debug(f"Received unexpected unary call: {resp}")
+
+            elif resp.HasField("requestHandling"):
+                handler_task = asyncio.create_task(self._handle_persistent_request(call_id, resp.requestHandling))
+                self._handler_tasks[call_id] = handler_task
+
+            elif call_id in self._handler_tasks and resp.HasField("cancel"):
+                self._handler_tasks[call_id].cancel()
+
+            elif call_id in self._pending_calls and resp.HasField("daemonError"):
+                daemon_exc = P2PDaemonError(resp.daemonError.message)
+                self._pending_calls[call_id].set_exception(daemon_exc)
+
+            elif call_id in self._pending_calls:
+                self._pending_calls[call_id].set_result(None)
+
+            else:
+                logger.debug(f"Received unexpected response from daemon: {resp}")
+
+    async def _write_to_persistent_conn(self, writer: asyncio.StreamWriter):
+        with closing(writer):
+            while True:
+                msg = await self._pending_messages.get()
+                await write_pbmsg(writer, msg)
+
+    async def _handle_persistent_request(self, call_id: UUID, request: p2pd_pb.CallUnaryRequest):
+        if request.proto not in self.unary_handlers:
+            logger.warning(f"Protocol {request.proto} not supported")
+            return
+
+        try:
+            remote_id = PeerID(request.peer)
+            response_payload: bytes = await self.unary_handlers[request.proto](request.data, remote_id)
+            response = p2pd_pb.CallUnaryResponse(response=response_payload)
+
+        except Exception as e:
+            response = p2pd_pb.CallUnaryResponse(error=repr(e).encode())
+
+        payload = p2pd_pb.PersistentConnectionRequest(callId=call_id.bytes, unaryResponse=response)
+        if payload.ByteSize() <= self.persistent_conn_max_msg_size:
+            await self._pending_messages.put(payload)
+        else:
+            error_msg = p2pd_pb.PersistentConnectionRequest(
+                callId=call_id.bytes,
+                unaryResponse=p2pd_pb.CallUnaryResponse(
+                    error=b"response size exceeds message size limit",
+                ),
+            )
+            await self._pending_messages.put(error_msg)
+
+        self._handler_tasks.pop(call_id)
+
+    async def _cancel_unary_call(self, call_id: UUID):
+        await self._pending_messages.put(
+            p2pd_pb.PersistentConnectionRequest(
+                callId=call_id.bytes,
+                cancel=p2pd_pb.Cancel(),
+            ),
+        )
+
+    async def _ensure_persistent_conn(self):
+        reader, writer = await self.daemon_connector.open_persistent_connection()
+
+        self._read_task = asyncio.create_task(self._read_from_persistent_conn(reader))
+        self._write_task = asyncio.create_task(self._write_to_persistent_conn(writer))
+
+    async def add_unary_handler(self, proto: str, handler: TUnaryHandler):
+        call_id = uuid4()
+
+        add_unary_handler_req = p2pd_pb.AddUnaryHandlerRequest(proto=proto)
+        req = p2pd_pb.PersistentConnectionRequest(callId=call_id.bytes, addUnaryHandler=add_unary_handler_req)
+
+        if self.unary_handlers.get(proto):
+            raise P2PDaemonError(f"Handler for protocol {proto} already registered")
+        self.unary_handlers[proto] = handler
+
+        self._pending_calls[call_id] = asyncio.Future()
+        await self._pending_messages.put(req)
+        await self._pending_calls[call_id]
+
+    async def call_unary_handler(self, peer_id: PeerID, proto: str, data: bytes) -> bytes:
+        call_id = uuid4()
+        call_unary_req = p2pd_pb.CallUnaryRequest(
+            peer=peer_id.to_bytes(),
+            proto=proto,
+            data=data,
+        )
+        req = p2pd_pb.PersistentConnectionRequest(
+            callId=call_id.bytes,
+            callUnary=call_unary_req,
+        )
+
+        if req.ByteSize() > self.persistent_conn_max_msg_size:
+            raise P2PDaemonError(f"Message size exceeds set limit {self.persistent_conn_max_msg_size}")
+
+        try:
+            self._pending_calls[call_id] = asyncio.Future()
+            await self._pending_messages.put(req)
+            return await self._pending_calls[call_id]
+
+        except asyncio.CancelledError:
+            await self._cancel_unary_call(call_id)
+            raise
+
+        finally:
+            self._pending_calls.pop(call_id, None)
+
     async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
         reader, writer = await self.daemon_connector.open_connection()
         req = p2pd_pb.Request(type=p2pd_pb.Request.IDENTIFY)
@@ -179,3 +373,15 @@ class ControlClient:
 
         # if success, add the handler to the dict
         self.handlers[proto] = handler_cb
+
+
+class P2PHandlerError(Exception):
+    """
+    Raised if the remote peer's handler failed with an exception while processing a request
+    """
+
+
+class P2PDaemonError(Exception):
+    """
+    Raised if the daemon failed to handle a request
+    """

+ 9 - 0
hivemind/p2p/p2p_daemon_bindings/datastructures.py

@@ -74,6 +74,12 @@ class PeerID:
         else:
             return False
 
+    def __lt__(self, other: object) -> bool:
+        if not isinstance(other, PeerID):
+            raise TypeError(f"'<' not supported between instances of 'PeerID' and '{type(other)}'")
+
+        return self.to_base58() < other.to_base58()
+
     def __hash__(self) -> int:
         return hash(self._bytes)
 
@@ -125,6 +131,9 @@ class PeerInfo:
     def __str__(self):
         return f"{self.peer_id.pretty()} {','.join(str(a) for a in self.addrs)}"
 
+    def __repr__(self):
+        return f"PeerInfo(peer_id={repr(self.peer_id)}, addrs={repr(self.addrs)})"
+
 
 class InvalidAddrError(ValueError):
     pass
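
The new __lt__ makes peer IDs sortable by their base58 representation, which gives deterministic orderings (e.g. for stable peer layouts). A small sketch; PeerID is constructed from raw bytes here purely for illustration, the same way ControlClient wraps request.peer above:

from hivemind.p2p import PeerID

ids = [PeerID(bytes([i]) * 8) for i in (3, 1, 2)]  # arbitrary raw IDs, just for the example
print(sorted(ids))  # ordered by base58 string thanks to PeerID.__lt__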

+ 40 - 3
hivemind/p2p/p2p_daemon_bindings/p2pclient.py

@@ -10,16 +10,47 @@ from typing import AsyncIterator, Iterable, Sequence, Tuple
 
 from multiaddr import Multiaddr
 
-from hivemind.p2p.p2p_daemon_bindings.control import ControlClient, DaemonConnector, StreamHandler
+from hivemind.p2p.p2p_daemon_bindings.control import (
+    DEFAULT_MAX_MSG_SIZE,
+    ControlClient,
+    DaemonConnector,
+    StreamHandler,
+    TUnaryHandler,
+)
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
 
 
 class Client:
     control: ControlClient
 
-    def __init__(self, control_maddr: Multiaddr = None, listen_maddr: Multiaddr = None) -> None:
+    def __init__(self, *, _initialized_with_create=False) -> None:
+        assert _initialized_with_create, "Please use Client.create coroutine to spawn new client instances"
+        self.control = None
+
+    @classmethod
+    async def create(
+        cls,
+        control_maddr: Multiaddr = None,
+        listen_maddr: Multiaddr = None,
+        *,
+        persistent_conn_max_msg_size: int = DEFAULT_MAX_MSG_SIZE,
+    ) -> "Client":
+        client = cls(_initialized_with_create=True)
+
         daemon_connector = DaemonConnector(control_maddr=control_maddr)
-        self.control = ControlClient(daemon_connector=daemon_connector, listen_maddr=listen_maddr)
+        client.control = await ControlClient.create(
+            daemon_connector=daemon_connector,
+            listen_maddr=listen_maddr,
+            persistent_conn_max_msg_size=persistent_conn_max_msg_size,
+        )
+
+        return client
+
+    def close(self) -> None:
+        self.control.close()
+
+    def __del__(self):
+        self.close()
 
     @asynccontextmanager
     async def listen(self) -> AsyncIterator["Client"]:
@@ -30,6 +61,12 @@ class Client:
         async with self.control.listen():
             yield self
 
+    async def add_unary_handler(self, proto: str, handler: TUnaryHandler):
+        await self.control.add_unary_handler(proto, handler)
+
+    async def call_unary_handler(self, peer_id: PeerID, proto: str, data: bytes) -> bytes:
+        return await self.control.call_unary_handler(peer_id, proto, data)
+
     async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
         """
         Get current node peer id and list of addresses
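
Putting the unary plumbing together at the Client level: a TUnaryHandler takes the raw request bytes plus the caller's PeerID and returns raw response bytes, and both directions travel over the persistent multiplexed connection to the daemon. A hedged sketch, assuming a daemon is already reachable at control_maddr and that the target peer has registered the same (arbitrary) protocol id:

from multiaddr import Multiaddr
from hivemind.p2p.p2p_daemon_bindings.p2pclient import Client

async def run(control_maddr: Multiaddr, listen_maddr: Multiaddr, remote_peer_id):
    client = await Client.create(control_maddr=control_maddr, listen_maddr=listen_maddr)

    async def echo(request: bytes, remote_id) -> bytes:
        return b"echo: " + request  # bytes in, bytes out

    await client.add_unary_handler("/example/echo/1.0.0", echo)
    reply = await client.call_unary_handler(remote_peer_id, "/example/echo/1.0.0", b"hello")
    return reply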

+ 22 - 33
hivemind/p2p/servicer.py

@@ -86,38 +86,21 @@ class ServicerBase:
     @classmethod
     def _make_rpc_caller(cls, handler: RPCHandler):
         input_type = AsyncIterator[handler.request_type] if handler.stream_input else handler.request_type
+        output_type = AsyncIterator[handler.response_type] if handler.stream_output else handler.response_type
 
         # This method will be added to a new Stub type (a subclass of StubBase)
-        if handler.stream_output:
-
-            def caller(
-                self: StubBase, input: input_type, timeout: None = None
-            ) -> AsyncIterator[handler.response_type]:
-                if timeout is not None:
-                    raise ValueError("Timeouts for handlers returning streams are not supported")
-
-                return self._p2p.iterate_protobuf_handler(
-                    self._peer,
-                    cls._get_handle_name(self._namespace, handler.method_name),
-                    input,
-                    handler.response_type,
-                )
-
-        else:
-
-            async def caller(
-                self: StubBase, input: input_type, timeout: Optional[float] = None
-            ) -> handler.response_type:
+        async def caller(self: StubBase, input: input_type, timeout: Optional[float] = None) -> output_type:
+            handle_name = cls._get_handle_name(self._namespace, handler.method_name)
+            if not handler.stream_output:
                 return await asyncio.wait_for(
-                    self._p2p.call_protobuf_handler(
-                        self._peer,
-                        cls._get_handle_name(self._namespace, handler.method_name),
-                        input,
-                        handler.response_type,
-                    ),
+                    self._p2p.call_protobuf_handler(self._peer, handle_name, input, handler.response_type),
                     timeout=timeout,
                 )
 
+            if timeout is not None:
+                raise ValueError("Timeouts for handlers returning streams are not supported")
+            return await self._p2p.iterate_protobuf_handler(self._peer, handle_name, input, handler.response_type)
+
         caller.__name__ = handler.method_name
         return caller
 
@@ -125,13 +108,19 @@ class ServicerBase:
         self._collect_rpc_handlers()
 
         servicer = self if wrapper is None else wrapper
-        for handler in self._rpc_handlers:
-            await p2p.add_protobuf_handler(
-                self._get_handle_name(namespace, handler.method_name),
-                getattr(servicer, handler.method_name),
-                handler.request_type,
-                stream_input=handler.stream_input,
-            )
+
+        await asyncio.gather(
+            *[
+                p2p.add_protobuf_handler(
+                    self._get_handle_name(namespace, handler.method_name),
+                    getattr(servicer, handler.method_name),
+                    handler.request_type,
+                    stream_input=handler.stream_input,
+                    stream_output=handler.stream_output,
+                )
+                for handler in self._rpc_handlers
+            ]
+        )
 
     @classmethod
     def get_stub(cls, p2p: P2P, peer: PeerID, *, namespace: Optional[str] = None) -> StubBase:
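
For orientation, the unified caller above is what backs the stubs generated from ServicerBase subclasses: request and response types are taken from the rpc_* method annotations, with AsyncIterator annotations marking streaming input or output. A rough sketch of the intended usage; echo_pb2 is a hypothetical compiled protobuf with EchoRequest/EchoResponse messages, and the registration method name (add_p2p_handlers) is taken from the class but not shown by name in this hunk:

from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase

import echo_pb2  # hypothetical: EchoRequest{text}, EchoResponse{text}

class EchoServicer(ServicerBase):
    async def rpc_echo(self, request: echo_pb2.EchoRequest, context: P2PContext) -> echo_pb2.EchoResponse:
        return echo_pb2.EchoResponse(text=request.text)

async def serve_and_call(p2p: P2P, remote: PeerID):
    await EchoServicer().add_p2p_handlers(p2p)  # registers every rpc_* method, now concurrently (see above)
    stub = EchoServicer.get_stub(p2p, remote)
    response = await stub.rpc_echo(echo_pb2.EchoRequest(text="hi"), timeout=5.0)
    return response.text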

+ 1 - 1
hivemind/proto/averaging.proto

@@ -45,7 +45,7 @@ message AveragingData {
   bytes group_id = 2;       // a unique group identifier, same as in MessageFromLeader
   bytes peer_id = 3;        // sender's rpc peer_id, used for coordination
   Tensor tensor_part = 4;   // either peer's local tensor part (rpc input) or group average of this part (rpc output)
-  bytes metadata = 5;       // reserved user-extendable metadata
+  double weight = 5;        // peers will be averaged in proportion to these weights
 }
 
 message DownloadRequest {}

+ 59 - 10
hivemind/proto/p2pd.proto

@@ -8,15 +8,17 @@ package p2pclient.p2pd.pb;
 
 message Request {
   enum Type {
-    IDENTIFY       = 0;
-    CONNECT        = 1;
-    STREAM_OPEN    = 2;
-    STREAM_HANDLER = 3;
-    DHT            = 4;
-    LIST_PEERS     = 5;
-    CONNMANAGER    = 6;
-    DISCONNECT     = 7;
-    PUBSUB         = 8;
+    IDENTIFY                 = 0;
+    CONNECT                  = 1;
+    STREAM_OPEN              = 2;
+    STREAM_HANDLER           = 3;
+    DHT                      = 4;
+    LIST_PEERS               = 5;
+    CONNMANAGER              = 6;
+    DISCONNECT               = 7;      
+    PUBSUB                   = 8;
+
+    PERSISTENT_CONN_UPGRADE  = 9;
   }
 
   required Type type = 1;
@@ -45,6 +47,29 @@ message Response {
   optional PSResponse pubsub = 7;
 }
 
+message PersistentConnectionRequest {
+  required bytes callId = 1;
+
+  oneof message {
+    AddUnaryHandlerRequest addUnaryHandler = 2;
+    CallUnaryRequest  callUnary = 3;
+    CallUnaryResponse unaryResponse = 4;
+    Cancel cancel = 5;
+  }
+}
+
+message PersistentConnectionResponse {
+  required bytes callId = 1;
+
+  oneof message {
+    CallUnaryResponse callUnaryResponse = 2;
+    CallUnaryRequest requestHandling = 3;
+    DaemonError daemonError = 4;
+    Cancel cancel = 5;
+  }
+}
+
+
 message IdentifyResponse {
   required bytes id = 1;
   repeated bytes addrs = 2;
@@ -148,7 +173,7 @@ message PSRequest {
 }
 
 message PSMessage {
-  optional bytes from_id = 1;
+  optional bytes from = 1;
   optional bytes data = 2;
   optional bytes seqno = 3;
   repeated string topicIDs = 4;
@@ -161,6 +186,30 @@ message PSResponse {
   repeated bytes peerIDs = 2;
 }
 
+message CallUnaryRequest {
+  required bytes peer = 1;
+  required string proto = 2;
+  required bytes data = 3;
+}
+
+message CallUnaryResponse {
+  oneof result {
+    bytes response = 1;
+    bytes error = 2;
+  }
+}
+
+message AddUnaryHandlerRequest {
+  required string proto = 1;
+}
+
+message DaemonError {
+  optional string message = 1;
+}
+
+message Cancel {
+}
+
 message RPCError {
   optional string message = 1;
 }
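As a rough illustration of the new persistent-connection messages, the sketch below wraps a unary call into a PersistentConnectionRequest using the generated Python bindings. The module path hivemind.proto.p2pd_pb2, the call id, and the protocol id are assumptions for illustration only.

    import uuid

    from hivemind.proto import p2pd_pb2  # generated from the definitions above (assumed module path)

    request = p2pd_pb2.PersistentConnectionRequest(
        callId=uuid.uuid4().bytes,  # arbitrary unique id used to match responses to requests
        callUnary=p2pd_pb2.CallUnaryRequest(
            peer=b"<peer id bytes>",          # placeholder peer id
            proto="/hivemind/example/1.0.0",  # hypothetical protocol id
            data=b"payload",
        ),
    )
    wire_bytes = request.SerializeToString()  # ready to send to the daemon over the persistent connection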

+ 2 - 2
hivemind/utils/__init__.py

@@ -1,11 +1,11 @@
 from hivemind.utils.asyncio import *
-from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.grpc import *
 from hivemind.utils.limits import increase_file_limit
-from hivemind.utils.logging import get_logger
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from hivemind.utils.mpfuture import *
 from hivemind.utils.nested import *
 from hivemind.utils.networking import *
+from hivemind.utils.performance_ema import PerformanceEMA
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils.tensor_descr import BatchTensorDescriptor, TensorDescriptor
 from hivemind.utils.timed_storage import *

+ 90 - 11
hivemind/utils/asyncio.py

@@ -1,6 +1,8 @@
 import asyncio
+import concurrent.futures
 from concurrent.futures import ThreadPoolExecutor
-from typing import AsyncIterable, AsyncIterator, Awaitable, Callable, Optional, Tuple, TypeVar, Union
+from contextlib import AbstractAsyncContextManager, AbstractContextManager, asynccontextmanager
+from typing import AsyncIterable, AsyncIterator, Awaitable, Callable, ContextManager, Optional, Tuple, TypeVar, Union
 
 import uvloop
 
@@ -27,7 +29,7 @@ async def anext(aiter: AsyncIterator[T]) -> Union[T, StopAsyncIteration]:
     return await aiter.__anext__()
 
 
-async def aiter(*args: T) -> AsyncIterator[T]:
+async def as_aiter(*args: T) -> AsyncIterator[T]:
     """create an asynchronous iterator from a sequence of values"""
     for arg in args:
         yield arg
@@ -59,7 +61,7 @@ async def aenumerate(aiterable: AsyncIterable[T]) -> AsyncIterable[Tuple[int, T]
 
 
 async def asingle(aiter: AsyncIterable[T]) -> T:
-    """If ``aiter`` has exactly one item, returns this item. Otherwise, raises `ValueError`."""
+    """If ``aiter`` has exactly one item, returns this item. Otherwise, raises ``ValueError``."""
     count = 0
     async for item in aiter:
         count += 1
@@ -70,20 +72,41 @@ async def asingle(aiter: AsyncIterable[T]) -> T:
     return item
 
 
+async def afirst(aiter: AsyncIterable[T], default: Optional[T] = None) -> Optional[T]:
+    """Returns the first item of ``aiter`` or ``default`` if ``aiter`` is empty."""
+    async for item in aiter:
+        return item
+    return default
+
+
 async def await_cancelled(awaitable: Awaitable) -> bool:
     try:
         await awaitable
         return False
-    except asyncio.CancelledError:
+    except (asyncio.CancelledError, concurrent.futures.CancelledError):
+        # In Python 3.7, awaiting a cancelled asyncio.Future raises concurrent.futures.CancelledError
+        # instead of asyncio.CancelledError
         return True
     except BaseException:
+        logger.exception(f"Exception in {awaitable}:")
         return False
 
 
+async def cancel_and_wait(awaitable: Awaitable) -> bool:
+    """
+    Cancels ``awaitable`` and waits for its cancellation.
+    In case of ``asyncio.Task``, helps to avoid ``Task was destroyed but it is pending!`` errors.
+    In case of ``asyncio.Future``, equal to ``future.cancel()``.
+    """
+
+    awaitable.cancel()
+    return await await_cancelled(awaitable)
+
+
 async def amap_in_executor(
     func: Callable[..., T],
     *iterables: AsyncIterable,
-    max_prefetch: Optional[int] = None,
+    max_prefetch: int = 1,
     executor: Optional[ThreadPoolExecutor] = None,
 ) -> AsyncIterator[T]:
     """iterate from an async iterable in a background thread, yield results to async iterable"""
@@ -91,9 +114,14 @@ async def amap_in_executor(
     queue: asyncio.Queue[Optional[Awaitable[T]]] = asyncio.Queue(max_prefetch)
 
     async def _put_items():
-        async for args in azip(*iterables):
-            await queue.put(loop.run_in_executor(executor, func, *args))
-        await queue.put(None)
+        try:
+            async for args in azip(*iterables):
+                await queue.put(loop.run_in_executor(executor, func, *args))
+            await queue.put(None)
+        except Exception as e:
+            future = asyncio.Future()
+            future.set_exception(e)
+            await queue.put(future)
 
     task = asyncio.create_task(_put_items())
     try:
@@ -101,7 +129,58 @@ async def amap_in_executor(
         while future is not None:
             yield await future
             future = await queue.get()
-        await task
     finally:
-        if not task.done():
-            task.cancel()
+        awaitables = [task]
+        while not queue.empty():
+            future = queue.get_nowait()
+            if future is not None:
+                awaitables.append(future)
+        for coro in awaitables:
+            coro.cancel()
+            try:
+                await coro
+            except BaseException as e:
+                if isinstance(e, Exception):
+                    logger.debug(f"Caught {e} while iterating over inputs", exc_info=True)
+                # note: we do not reraise here because it is already in the finally clause
+
+
+async def aiter_with_timeout(iterable: AsyncIterable[T], timeout: Optional[float]) -> AsyncIterator[T]:
+    """Iterate over an async iterable, raise TimeoutError if another portion of data does not arrive within timeout"""
+    # based on https://stackoverflow.com/a/50245879
+    iterator = iterable.__aiter__()
+    while True:
+        try:
+            yield await asyncio.wait_for(iterator.__anext__(), timeout=timeout)
+        except StopAsyncIteration:
+            break
+
+
+async def attach_event_on_finished(iterable: AsyncIterable[T], event: asyncio.Event) -> AsyncIterator[T]:
+    """Iterate over an async iterable and set an event when the iteration has stopped, failed or terminated"""
+    try:
+        async for item in iterable:
+            yield item
+    finally:
+        event.set()
+
+
+class _AsyncContextWrapper(AbstractAsyncContextManager):
+    """Wrapper for a non-async context manager that allows entering and exiting it in EventLoop-friendly manner"""
+
+    def __init__(self, context: AbstractContextManager):
+        self._context = context
+
+    async def __aenter__(self):
+        loop = asyncio.get_event_loop()
+        return await loop.run_in_executor(None, self._context.__enter__)
+
+    async def __aexit__(self, exc_type, exc_value, traceback):
+        return self._context.__exit__(exc_type, exc_value, traceback)
+
+
+@asynccontextmanager
+async def enter_asynchronously(context: AbstractContextManager):
+    """Wrap a non-async context so that it can be entered asynchronously"""
+    async with _AsyncContextWrapper(context) as ret_value:
+        yield ret_value
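To make the intended usage of the new asyncio helpers concrete, here is a small sketch combining as_aiter, amap_in_executor, aiter_with_timeout, and enter_asynchronously; the values, the lambda, and the lock are arbitrary.

    import asyncio
    import threading

    from hivemind.utils.asyncio import aiter_with_timeout, amap_in_executor, as_aiter, enter_asynchronously


    async def demo() -> None:
        # run a CPU-bound function in a background thread while consuming results asynchronously
        async for square in amap_in_executor(lambda x: x * x, as_aiter(1, 2, 3), max_prefetch=2):
            print(square)

        # raise TimeoutError if the next item takes longer than 0.5 seconds to arrive
        async for item in aiter_with_timeout(as_aiter("a", "b"), timeout=0.5):
            print(item)

        # enter a regular (blocking) context manager without stalling the event loop
        async with enter_asynchronously(threading.Lock()):
            pass  # the lock is held here and released on exit


    asyncio.run(demo())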

+ 0 - 209
hivemind/utils/compression.py

@@ -1,209 +0,0 @@
-import os
-import warnings
-from concurrent.futures import ThreadPoolExecutor
-from typing import Optional, Sequence, Tuple
-
-import numpy as np
-import torch
-
-from hivemind.proto import runtime_pb2
-from hivemind.proto.runtime_pb2 import CompressionType
-
-FP32_EPS = 1e-06
-NUM_BYTES_FLOAT32 = 4
-NUM_BYTES_FLOAT16 = 2
-NUM_BITS_QUANTILE_COMPRESSION = 8
-NUM_COMPRESSION_QUANTILES = 2 ** NUM_BITS_QUANTILE_COMPRESSION
-UNIFORM_BUCKETS_STD_RANGE = 6
-FP16_MAX = 65_504
-UINT8_RANGE = 256
-
-COMPRESSION_EXECUTOR = ThreadPoolExecutor(max_workers=int(os.environ.get("QUANTILE_COMPRESSION_THREADS", 128)))
-
-warnings.filterwarnings("ignore", message="The given NumPy array is not writeable", category=UserWarning)
-
-
-def _quantile_encode_approx(tensor: torch.Tensor, n_bits: int) -> Tuple[torch.Tensor, torch.Tensor]:
-    n_bins = 2 ** n_bits
-    borders = torch.as_tensor(_quantile_qq_approximation(tensor.numpy(), n_bins + 1)[1:-1])
-    quant_weight = torch.clamp_(torch.bucketize(tensor, borders), 0, n_bins - 1)
-    lookup = average_buckets(tensor, quant_weight, n_bins)
-    return quant_weight, lookup
-
-
-def average_buckets(tensor: torch.Tensor, quant_weight: torch.Tensor, n_bins: int):
-    bin_sums = torch.zeros(n_bins).scatter_add_(0, quant_weight.flatten().long(), tensor.flatten())
-    bin_counts = torch.clamp_min_(torch.bincount(quant_weight.flatten(), minlength=n_bins), 1)
-    lookup = bin_sums / bin_counts
-    return lookup
-
-
-def _quantile_qq_approximation(array: np.ndarray, n_quantiles: int, min_chunk_size: int = 10 ** 5) -> np.ndarray:
-    """Estimate uniform quantiles of data using quantile-of-quantiles. Runs in parallel."""
-    if not array.data.c_contiguous and array.data.f_contiguous:
-        array = array.T
-    array = np.ascontiguousarray(array.reshape(-1))
-    quantiles = np.linspace(0.0, 1.0, num=n_quantiles, dtype=array.dtype)
-    chunk_size = _get_chunk_size(len(array), min_chunk_size)
-    num_chunks = (len(array) - 1) // chunk_size + 1
-    partition_quantiles = np.empty((num_chunks, len(quantiles)), dtype=array.dtype)
-
-    jobs = []
-    for i in range(num_chunks):
-        chunk = slice(chunk_size * i, chunk_size * (i + 1))
-        jobs.append(COMPRESSION_EXECUTOR.submit(np.quantile, array[chunk], quantiles, out=partition_quantiles[i]))
-
-    for job in jobs:
-        job.result()
-    return np.quantile(partition_quantiles, quantiles)
-
-
-def _get_chunk_size(num_elements: int, min_chunk_size: int) -> int:
-    """Adjust chunk_size to minimize imbalance between chunk sizes"""
-    if min_chunk_size >= num_elements:
-        return min_chunk_size
-    leftover_elements = num_elements % min_chunk_size
-    num_chunks = num_elements // min_chunk_size
-    return min_chunk_size + (leftover_elements - 1) // num_chunks + 1
-
-
-def _uint8_uniform_buckets_encode(tensor: torch.Tensor, range_in_sigmas: float):
-    offset = UINT8_RANGE // 2
-    shift = tensor.mean()
-    scale = range_in_sigmas * tensor.std().item() / UINT8_RANGE
-
-    quant_weight = torch.quantize_per_tensor(tensor - shift, scale, offset, torch.quint8).int_repr()
-    lookup = average_buckets(tensor, quant_weight, UINT8_RANGE)
-    return quant_weight, lookup
-
-
-def serialize_torch_tensor(
-    tensor: torch.Tensor, compression_type=CompressionType.NONE, allow_inplace=False
-) -> runtime_pb2.Tensor:
-    assert tensor.device == torch.device("cpu")
-    if compression_type == CompressionType.MEANSTD_16BIT:
-        assert tensor.dtype == torch.float32
-
-        tensor = tensor if allow_inplace else tensor.clone()
-        means = torch.mean(tensor, dim=-1, keepdim=True)
-        tensor.sub_(means)
-
-        stds = torch.square(tensor).sum(dim=-1, keepdim=True).div_(tensor.shape[-1]).sqrt_()
-        stds.clamp_min_(FP32_EPS)
-        tensor.div_(stds)
-        tensor = tensor.clamp_(-FP16_MAX, FP16_MAX).to(torch.float16)
-
-        data = b"".join((tensor.numpy().tobytes(), means.numpy().tobytes(), stds.numpy().tobytes()))
-
-        proto = runtime_pb2.Tensor(
-            compression=compression_type,
-            buffer=data,
-            size=tensor.shape,
-            dtype="compressed_float32",
-            requires_grad=tensor.requires_grad,
-        )
-    elif compression_type == CompressionType.FLOAT16:
-        assert tensor.dtype == torch.float32
-
-        tensor = tensor if allow_inplace else tensor.clone()
-        tensor = tensor.clamp_(-FP16_MAX, FP16_MAX).to(torch.float16)
-
-        data = tensor.numpy().tobytes()
-
-        proto = runtime_pb2.Tensor(
-            compression=compression_type,
-            buffer=data,
-            size=tensor.shape,
-            dtype="clamped_float32",
-            requires_grad=tensor.requires_grad,
-        )
-    elif compression_type == CompressionType.NONE:
-        array = tensor.numpy()
-        proto = runtime_pb2.Tensor(
-            compression=compression_type,
-            buffer=array.tobytes(),
-            size=array.shape,
-            dtype=array.dtype.name,
-            requires_grad=tensor.requires_grad,
-        )
-    elif compression_type in (CompressionType.QUANTILE_8BIT, CompressionType.UNIFORM_8BIT):
-        assert tensor.dtype == torch.float32
-
-        if compression_type == CompressionType.QUANTILE_8BIT:
-            quantized, lookup = _quantile_encode_approx(tensor.detach(), NUM_BITS_QUANTILE_COMPRESSION)
-        elif compression_type == CompressionType.UNIFORM_8BIT:
-            quantized, lookup = _uint8_uniform_buckets_encode(tensor.detach(), UNIFORM_BUCKETS_STD_RANGE)
-        data = b"".join((lookup.numpy().tobytes(), quantized.numpy().astype(np.uint8).tobytes()))
-
-        proto = runtime_pb2.Tensor(
-            compression=compression_type,
-            buffer=data,
-            size=tensor.shape,
-            dtype="compressed_float32",
-            requires_grad=tensor.requires_grad,
-        )
-    else:
-        raise ValueError(f"Unknown compression type: {compression_type}")
-
-    return proto
-
-
-def construct_torch_tensor(array: np.ndarray, size: Sequence, dtype: Optional[torch.dtype] = None):
-    """Helper conversion function that handles edge case with scalar deserialization"""
-    if size:
-        return torch.as_tensor(array, dtype=dtype).view(*size)
-    else:
-        return torch.as_tensor(array, dtype=dtype)
-
-
-def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
-    if serialized_tensor.compression == CompressionType.NONE:
-        array = np.frombuffer(serialized_tensor.buffer, dtype=np.dtype(serialized_tensor.dtype))
-        tensor = construct_torch_tensor(array, serialized_tensor.size)
-
-    elif serialized_tensor.compression == CompressionType.MEANSTD_16BIT:
-        stats_size = list(serialized_tensor.size)
-        stats_size[-1] = 1
-        stats_count = np.prod(stats_size)
-
-        means = serialized_tensor.buffer[-2 * NUM_BYTES_FLOAT32 * stats_count : -NUM_BYTES_FLOAT32 * stats_count]
-        stds = serialized_tensor.buffer[-NUM_BYTES_FLOAT32 * stats_count :]
-        means = construct_torch_tensor(np.frombuffer(means, dtype=np.float32), stats_size)
-        stds = construct_torch_tensor(np.frombuffer(stds, dtype=np.float32), stats_size)
-
-        array = np.frombuffer(serialized_tensor.buffer[: -8 * stats_count], dtype=np.float16)
-        tensor = construct_torch_tensor(array, serialized_tensor.size, torch.float32).mul_(stds).add_(means)
-
-    elif serialized_tensor.compression == CompressionType.FLOAT16:
-        array = np.frombuffer(serialized_tensor.buffer, dtype=np.float16)
-        tensor = construct_torch_tensor(array, serialized_tensor.size, torch.float32)
-
-    elif serialized_tensor.compression in (CompressionType.QUANTILE_8BIT, CompressionType.UNIFORM_8BIT):
-        if serialized_tensor.compression == CompressionType.QUANTILE_8BIT:
-            lookup_size = NUM_COMPRESSION_QUANTILES * NUM_BYTES_FLOAT32
-        else:
-            lookup_size = UINT8_RANGE * NUM_BYTES_FLOAT32
-        lookup = serialized_tensor.buffer[:lookup_size]
-        quantized = serialized_tensor.buffer[lookup_size:]
-        lookup = torch.as_tensor(np.frombuffer(lookup, dtype=np.float32))
-        quantized = np.frombuffer(quantized, dtype=np.uint8)
-        quantized = construct_torch_tensor(quantized, serialized_tensor.size, dtype=torch.int64)
-        tensor = lookup[quantized]
-
-    else:
-        raise ValueError(f"Unknown compression type: {serialized_tensor.compression}")
-
-    tensor.requires_grad_(serialized_tensor.requires_grad)
-    return tensor
-
-
-def get_nbytes_per_value(dtype: torch.dtype, compression: CompressionType) -> int:
-    """returns the number of bytes per value for a given tensor (excluding metadata)"""
-    if compression in (CompressionType.QUANTILE_8BIT, CompressionType.UNIFORM_8BIT):
-        return 1
-    elif compression in (CompressionType.FLOAT16, CompressionType.MEANSTD_16BIT):
-        return 2
-    elif compression == CompressionType.NONE:
-        return torch.finfo(dtype).bits // 8
-    else:
-        raise NotImplementedError(f"Unknown compression type: {CompressionType.Name(compression)}")

+ 185 - 17
hivemind/utils/logging.py

@@ -1,22 +1,190 @@
 import logging
 import os
+import sys
+import threading
+from enum import Enum
+from typing import Optional, Union
 
+logging.addLevelName(logging.WARNING, "WARN")
 
-def get_logger(module_name: str) -> logging.Logger:
-    # trim package name
-    name_without_prefix = ".".join(module_name.split(".")[1:])
-    loglevel = os.getenv("LOGLEVEL", "INFO")
-
-    logging.addLevelName(logging.WARNING, "WARN")
-    formatter = logging.Formatter(
-        fmt="[{asctime}.{msecs:03.0f}][{levelname}][{name}.{funcName}:{lineno}] {message}",
-        style="{",
-        datefmt="%Y/%m/%d %H:%M:%S",
-    )
-    handler = logging.StreamHandler()
-    handler.setFormatter(formatter)
-    logger = logging.getLogger(name_without_prefix)
-    logger.setLevel(loglevel)
-    logger.addHandler(handler)
+loglevel = os.getenv("HIVEMIND_LOGLEVEL", "INFO")
+
+_env_colors = os.getenv("HIVEMIND_COLORS")
+if _env_colors is not None:
+    use_colors = _env_colors.lower() == "true"
+else:
+    use_colors = sys.stderr.isatty()
+
+_env_log_caller = os.getenv("HIVEMIND_ALWAYS_LOG_CALLER")
+always_log_caller = _env_log_caller is not None and _env_log_caller.lower() == "true"
+
+
+class HandlerMode(Enum):
+    NOWHERE = 0
+    IN_HIVEMIND = 1
+    IN_ROOT_LOGGER = 2
+
+
+_init_lock = threading.RLock()
+_current_mode = HandlerMode.IN_HIVEMIND
+_default_handler = None
+
+
+class TextStyle:
+    """
+    ANSI escape codes. Details: https://en.wikipedia.org/wiki/ANSI_escape_code#Colors
+    """
+
+    RESET = "\033[0m"
+    BOLD = "\033[1m"
+    RED = "\033[31m"
+    BLUE = "\033[34m"
+    PURPLE = "\033[35m"
+    ORANGE = "\033[38;5;208m"  # From 8-bit palette
+
+    if not use_colors:
+        # Set the constants above to empty strings
+        _codes = locals()
+        _codes.update({_name: "" for _name in list(_codes) if _name.isupper()})
+
+
+class CustomFormatter(logging.Formatter):
+    """
+    A formatter that allows a log time and caller info to be overridden via
+    ``logger.log(level, message, extra={"origin_created": ..., "caller": ...})``.
+    """
+
+    # Details: https://en.wikipedia.org/wiki/ANSI_escape_code#Colors
+    _LEVEL_TO_COLOR = {
+        logging.DEBUG: TextStyle.PURPLE,
+        logging.INFO: TextStyle.BLUE,
+        logging.WARNING: TextStyle.ORANGE,
+        logging.ERROR: TextStyle.RED,
+        logging.CRITICAL: TextStyle.RED,
+    }
+
+    def format(self, record: logging.LogRecord) -> str:
+        if hasattr(record, "origin_created"):
+            record.created = record.origin_created
+            record.msecs = (record.created - int(record.created)) * 1000
+
+        if record.levelno != logging.INFO or always_log_caller:
+            if not hasattr(record, "caller"):
+                record.caller = f"{record.name}.{record.funcName}:{record.lineno}"
+            record.caller_block = f" [{TextStyle.BOLD}{record.caller}{TextStyle.RESET}]"
+        else:
+            record.caller_block = ""
+
+        # Aliases for the format argument
+        record.levelcolor = self._LEVEL_TO_COLOR[record.levelno]
+        record.bold = TextStyle.BOLD
+        record.reset = TextStyle.RESET
+
+        return super().format(record)
+
+
+def _initialize_if_necessary():
+    global _current_mode, _default_handler
+
+    with _init_lock:
+        if _default_handler is not None:
+            return
+
+        formatter = CustomFormatter(
+            fmt="{asctime}.{msecs:03.0f} [{bold}{levelcolor}{levelname}{reset}]{caller_block} {message}",
+            style="{",
+            datefmt="%b %d %H:%M:%S",
+        )
+        _default_handler = logging.StreamHandler()
+        _default_handler.setFormatter(formatter)
+
+        _enable_default_handler("hivemind")
+
+
+def get_logger(name: Optional[str] = None) -> logging.Logger:
+    """
+    Same as ``logging.getLogger()`` but ensures that the default hivemind log handler is initialized.
+
+    :note: By default, the hivemind log handler (that reads the ``HIVEMIND_LOGLEVEL`` env variable and uses
+           the colored log formatter) is only applied to messages logged inside the hivemind package.
+           If you want to extend this handler to other loggers in your application, call
+           ``use_hivemind_log_handler("in_root_logger")``.
+    """
+
+    _initialize_if_necessary()
+    return logging.getLogger(name)
+
+
+def _enable_default_handler(name: str) -> None:
+    logger = get_logger(name)
+    logger.addHandler(_default_handler)
     logger.propagate = False
-    return logger
+    logger.setLevel(loglevel)
+
+
+def _disable_default_handler(name: str) -> None:
+    logger = get_logger(name)
+    logger.removeHandler(_default_handler)
+    logger.propagate = True
+    logger.setLevel(logging.NOTSET)
+
+
+def use_hivemind_log_handler(where: Union[HandlerMode, str]) -> None:
+    """
+    Choose loggers where the default hivemind log handler is applied. Options for the ``where`` argument are:
+
+    * "in_hivemind" (default): Use the hivemind log handler in the loggers of the ``hivemind`` package.
+                               Don't propagate their messages to the root logger.
+    * "nowhere": Don't use the hivemind log handler anywhere.
+                 Propagate the ``hivemind`` messages to the root logger.
+    * "in_root_logger": Use the hivemind log handler in the root logger
+                        (that is, in all application loggers until they disable propagation to the root logger).
+                        Propagate the ``hivemind`` messages to the root logger.
+
+    The options may be defined as strings (case-insensitive) or values from the HandlerMode enum.
+    """
+
+    global _current_mode
+
+    if isinstance(where, str):
+        # We allow `where` to be a string, so a developer does not have to import the enum for one usage
+        where = HandlerMode[where.upper()]
+
+    _initialize_if_necessary()
+
+    if where == _current_mode:
+        return
+
+    if _current_mode == HandlerMode.IN_HIVEMIND:
+        _disable_default_handler("hivemind")
+    elif _current_mode == HandlerMode.IN_ROOT_LOGGER:
+        _disable_default_handler(None)
+
+    _current_mode = where
+
+    if _current_mode == HandlerMode.IN_HIVEMIND:
+        _enable_default_handler("hivemind")
+    elif _current_mode == HandlerMode.IN_ROOT_LOGGER:
+        _enable_default_handler(None)
+
+
+def golog_level_to_python(level: str) -> int:
+    level = level.upper()
+    if level in ["DPANIC", "PANIC", "FATAL"]:
+        return logging.CRITICAL
+
+    level = logging.getLevelName(level)
+    if not isinstance(level, int):
+        raise ValueError(f"Unknown go-log level: {level}")
+    return level
+
+
+def python_level_to_golog(level: str) -> str:
+    if not isinstance(level, str):
+        raise ValueError("`level` is expected to be a Python log level in the string form")
+
+    if level == "CRITICAL":
+        return "FATAL"
+    if level == "WARNING":
+        return "WARN"
+    return level
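A minimal usage sketch of the new logging API (grounded in the functions added above); the log messages are placeholders.

    from hivemind.utils.logging import get_logger, use_hivemind_log_handler

    use_hivemind_log_handler("in_root_logger")  # extend the colored hivemind handler to all application loggers
    logger = get_logger(__name__)
    logger.info("formatted by the hivemind handler, honoring HIVEMIND_LOGLEVEL")
    logger.warning("non-INFO records also include the bold caller block")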

+ 12 - 6
hivemind/utils/mpfuture.py

@@ -18,6 +18,8 @@ from hivemind.utils.logging import get_logger
 
 logger = get_logger(__name__)
 
+torch.multiprocessing.set_sharing_strategy(os.environ.get("HIVEMIND_MEMORY_SHARING_STRATEGY", "file_system"))
+
 # flavour types
 ResultType = TypeVar("ResultType")
 PID, UID, State, PipeEnd = int, int, str, mp.connection.Connection
@@ -53,7 +55,7 @@ class SharedBytes:
         """Create another shared byte value, represented as a scalar uint8 tensor"""
         with cls._lock:
             if cls._pid != os.getpid() or cls._buffer is None or cls._index >= len(cls._buffer):
-                buffer_size = os.environ.get("HIVEMIND_SHM_BUFFER_SIZE", 4096)
+                buffer_size = int(os.environ.get("HIVEMIND_SHM_BUFFER_SIZE", 16))
                 cls._pid = os.getpid()
                 cls._buffer = torch.empty([buffer_size], dtype=torch.uint8).share_memory_()
                 cls._index = 0
@@ -138,7 +140,9 @@ class MPFuture(base.Future, Generic[ResultType]):
         async def _event_setter():
             self._aio_event.set()
 
-        if self._loop.is_running() and running_loop == self._loop:
+        if self._loop.is_closed():
+            return  # do nothing, the loop is already closed
+        elif self._loop.is_running() and running_loop == self._loop:
             asyncio.create_task(_event_setter())
         elif self._loop.is_running() and running_loop != self._loop:
             asyncio.run_coroutine_threadsafe(_event_setter(), self._loop)
@@ -180,8 +184,10 @@ class MPFuture(base.Future, Generic[ResultType]):
                     future = future_ref()
 
                 if future is None:
-                    logger.debug(f"Ignoring update to future with uid={uid}: the future is already done or destroyed")
-                elif update_type == UpdateType.RESULT:
+                    # The MPFuture instance is already destroyed in this process
+                    # (the caller is not interested in the result)
+                    continue
+                if update_type == UpdateType.RESULT:
                     future.set_result(payload)
                 elif update_type == UpdateType.EXCEPTION:
                     future.set_exception(payload)
@@ -199,8 +205,8 @@ class MPFuture(base.Future, Generic[ResultType]):
         try:
             with MPFuture._update_lock if self._use_lock else nullcontext():
                 self._sender_pipe.send((self._uid, update_type, payload))
-        except (ConnectionError, BrokenPipeError, EOFError) as e:
-            logger.debug(f"No updates were sent: pipe to origin process was broken ({e}).", exc_info=True)
+        except (ConnectionError, BrokenPipeError, EOFError, OSError) as e:
+            logger.debug(f"No updates were sent: pipe to origin process was broken ({e})", exc_info=True)
 
     def set_result(self, result: ResultType):
         if os.getpid() == self._origin_pid:
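Both knobs touched above are controlled through environment variables; a hedged configuration sketch (values are illustrative, and "file_descriptor" is simply the other sharing strategy supported by torch):

    import os

    # The sharing strategy is applied when hivemind.utils.mpfuture is first imported, so set it beforehand;
    # the shared-memory buffer size is read lazily when the next chunk of shared MPFuture state is allocated.
    os.environ["HIVEMIND_MEMORY_SHARING_STRATEGY"] = "file_descriptor"
    os.environ["HIVEMIND_SHM_BUFFER_SIZE"] = "4096"

    from hivemind.utils.mpfuture import MPFuture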

+ 7 - 2
hivemind/utils/networking.py

@@ -30,8 +30,13 @@ def strip_port(endpoint: Endpoint) -> Hostname:
     return endpoint[: endpoint.rindex(":")] if maybe_port.isdigit() or maybe_port == "*" else endpoint
 
 
-def find_open_port(params=(socket.AF_INET, socket.SOCK_STREAM), opt=(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)):
-    """Finds a tcp port that can be occupied with a socket with *params and use *opt options"""
+def get_free_port(params=(socket.AF_INET, socket.SOCK_STREAM), opt=(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)):
+    """
+    Finds a TCP port that can be occupied by a socket created with *params and configured with *opt options.
+
+    :note: Using this function is discouraged since it often leads to a race condition
+           with the "Address is already in use" error if the code is run in parallel.
+    """
     try:
         with closing(socket.socket(*params)) as sock:
             sock.bind(("", 0))

+ 70 - 0
hivemind/utils/performance_ema.py

@@ -0,0 +1,70 @@
+import time
+from contextlib import contextmanager
+from threading import Lock
+from typing import Optional
+
+
+class PerformanceEMA:
+    """
+    A running estimate of performance (operations/sec) using adjusted exponential moving average
+    :param alpha: Smoothing factor in range [0, 1] (default: 0.1).
+    """
+
+    def __init__(self, alpha: float = 0.1, eps: float = 1e-20, paused: bool = False):
+        self.alpha, self.eps, self.num_updates = alpha, eps, 0
+        self.ema_seconds_per_sample, self.samples_per_second = 0, eps
+        self.timestamp = time.perf_counter()
+        self.paused = paused
+        self.lock = Lock()
+
+    def update(self, task_size: float, interval: Optional[float] = None) -> float:
+        """
+        :param task_size: how many items were processed since last call
+        :param interval: optionally provide the time delta it took to process this task
+        :returns: current estimate of performance (samples per second), capped at 1 / eps
+        """
+        assert task_size > 0, f"Can't register processing {task_size} samples"
+        if not self.paused:
+            self.timestamp, old_timestamp = time.perf_counter(), self.timestamp
+            interval = interval if interval is not None else self.timestamp - old_timestamp
+        else:
+            assert interval is not None, "If PerformanceEMA is paused, please specify the time interval"
+        self.ema_seconds_per_sample = (
+            self.alpha * interval / task_size + (1 - self.alpha) * self.ema_seconds_per_sample
+        )
+        self.num_updates += 1
+        adjusted_seconds_per_sample = self.ema_seconds_per_sample / (1 - (1 - self.alpha) ** self.num_updates)
+        self.samples_per_second = 1 / max(adjusted_seconds_per_sample, self.eps)
+        return self.samples_per_second
+
+    def reset_timer(self):
+        """Reset the time since the last update so that the next task performance is counted from current time"""
+        self.timestamp = time.perf_counter()
+
+    @contextmanager
+    def pause(self):
+        """While inside this context, EMA will not count the time passed towards the performance estimate"""
+        self.paused, was_paused = True, self.paused
+        try:
+            yield
+        finally:
+            self.paused = was_paused
+            self.reset_timer()
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(ema={self.samples_per_second:.5f}, num_updates={self.num_updates})"
+
+    @contextmanager
+    def update_threadsafe(self, task_size: float):
+        """
+        Update the EMA throughput of the code that runs inside this context manager; supports multiple concurrent threads.
+
+        :param task_size: how many items were processed since last call
+        """
+        start_timestamp = time.perf_counter()
+        yield
+        with self.lock:
+            self.update(task_size, interval=time.perf_counter() - max(start_timestamp, self.timestamp))
+            # note: we define interval as such to support two distinct scenarios:
+            # (1) if this is the first call to update_threadsafe after a pause, count time from entering this context
+            # (2) if there are concurrent calls to update_threadsafe, respect the timestamp updates from these calls
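A short usage sketch of PerformanceEMA based on the methods above; the sleeps stand in for real work and the batch sizes are arbitrary.

    import time

    from hivemind.utils.performance_ema import PerformanceEMA

    ema = PerformanceEMA(alpha=0.1)
    for batch_size in (32, 32, 64):
        time.sleep(0.01)  # pretend to process a batch
        throughput = ema.update(task_size=batch_size)

    with ema.pause():  # time spent here is excluded from the estimate
        time.sleep(0.1)

    with ema.update_threadsafe(task_size=32):  # thread-safe variant: measures the body of the context
        time.sleep(0.01)

    print(f"{ema.samples_per_second:.1f} samples/s")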

+ 2 - 2
hivemind/utils/serializer.py

@@ -35,7 +35,7 @@ class MSGPackSerializer(SerializerBase):
                 getattr(wrapped_type, "unpackb", None)
             ), f"Every ext_type must have 2 methods: packb(self) -> bytes and classmethod unpackb(cls, bytes)"
             if type_code in cls._ext_type_codes:
-                logger.warning(f"{cls.__name__}: type {type_code} is already registered, overwriting.")
+                logger.warning(f"{cls.__name__}: type {type_code} is already registered, overwriting")
             cls._ext_type_codes[type_code], cls._ext_types[wrapped_type] = wrapped_type, type_code
             return wrapped_type
 
@@ -60,7 +60,7 @@ class MSGPackSerializer(SerializerBase):
         elif type_code == cls._TUPLE_EXT_TYPE_CODE:
             return tuple(msgpack.unpackb(data, ext_hook=cls._decode_ext_types, raw=False))
 
-        logger.warning(f"Unknown ExtType code: {type_code}, leaving it as is.")
+        logger.warning(f"Unknown ExtType code: {type_code}, leaving it as is")
         return data
 
     @classmethod

+ 64 - 8
hivemind/utils/tensor_descr.py

@@ -1,9 +1,14 @@
+from __future__ import annotations
+
 import warnings
 from dataclasses import asdict, dataclass
+from typing import Tuple
 
+import numpy as np
 import torch
 
 from hivemind.proto.runtime_pb2 import CompressionType
+from hivemind.utils.serializer import MSGPackSerializer
 
 DUMMY_BATCH_SIZE = 3  # used for dummy runs only
 
@@ -29,22 +34,37 @@ class TensorDescriptor(DescriptorBase):
     compression: CompressionType = CompressionType.NONE
 
     @property
-    def shape(self):
+    def shape(self) -> Tuple[int, ...]:
         return self.size
 
+    def numel(self) -> int:
+        return int(np.prod(self.size))
+
     @classmethod
-    def from_tensor(cls, tensor: torch.Tensor):
+    def from_tensor(cls, tensor: torch.Tensor) -> TensorDescriptor:
         return cls(
             tensor.shape, tensor.dtype, tensor.layout, tensor.device, tensor.requires_grad, _safe_check_pinned(tensor)
         )
 
-    def make_empty(self, **kwargs):
+    def make_zeros(self, **kwargs):
         properties = asdict(self)
         properties.update(kwargs)
         properties.pop("compression")
-        return torch.empty(**properties)
+        return torch.zeros(**properties)
+
+
+def _str_to_torch_type(name: str, torch_type: type):
+    try:
+        value = getattr(torch, name.split(".")[-1])
+    except AttributeError:
+        raise ValueError(f"Invalid dtype: torch has no attribute {name}")
+    if not isinstance(value, torch_type):
+        raise ValueError(f"Invalid dtype: expected {torch_type}, got: {type(value)}")
 
+    return value
 
+
+@MSGPackSerializer.ext_serializable(0x51)
 @dataclass(repr=True, frozen=True)
 class BatchTensorDescriptor(TensorDescriptor):
     """torch.Tensor with a variable 0-th dimension, used to describe batched data"""
@@ -55,7 +75,7 @@ class BatchTensorDescriptor(TensorDescriptor):
         super().__init__((None, *instance_size), **kwargs)
 
     @classmethod
-    def from_tensor(cls, tensor: torch.Tensor, compression=CompressionType.NONE):
+    def from_tensor(cls, tensor: torch.Tensor, compression=CompressionType.NONE) -> BatchTensorDescriptor:
         return cls(
             *tensor.shape[1:],
             dtype=tensor.dtype,
@@ -63,12 +83,48 @@ class BatchTensorDescriptor(TensorDescriptor):
             device=tensor.device,
             requires_grad=tensor.requires_grad,
             pin_memory=_safe_check_pinned(tensor),
-            compression=compression if tensor.is_floating_point() else CompressionType.NONE
+            compression=compression if tensor.is_floating_point() else CompressionType.NONE,
         )
 
-    def make_empty(self, *batch_size, **kwargs):
+    def make_zeros(self, *batch_size: int, **kwargs) -> torch.Tensor:
         assert self.shape[0] is None, "Make sure 0-th dimension is not specified (set to None)"
-        return super().make_empty(size=(*batch_size, *self.shape[1:]), **kwargs)
+        return super().make_zeros(size=(*batch_size, *self.shape[1:]), **kwargs)
+
+    def packb(self) -> bytes:
+        obj_dict = asdict(self)
+
+        obj_dict["dtype"] = str(self.dtype) if self.dtype is not None else None
+        obj_dict["layout"] = str(self.layout) if self.layout is not None else None
+
+        device = obj_dict.pop("device")
+        device_type, device_index = (device.type, device.index) if device is not None else (None, None)
+        obj_dict.update(
+            device_type=device_type,
+            device_index=device_index,
+        )
+
+        return MSGPackSerializer.dumps(obj_dict)
+
+    @classmethod
+    def unpackb(cls, raw: bytes) -> BatchTensorDescriptor:
+        obj_dict = MSGPackSerializer.loads(raw)
+
+        if obj_dict["dtype"] is not None:
+            obj_dict["dtype"] = _str_to_torch_type(obj_dict["dtype"], torch.dtype)
+
+        if obj_dict["layout"] is not None:
+            obj_dict["layout"] = _str_to_torch_type(obj_dict["layout"], torch.layout)
+
+        if obj_dict["device_type"] is not None:
+            obj_dict["device"] = torch.device(obj_dict["device_type"], obj_dict["device_index"])
+        else:
+            obj_dict["device"] = None
+
+        del obj_dict["device_type"], obj_dict["device_index"]
+
+        size = obj_dict.pop("size")[1:]
+
+        return BatchTensorDescriptor(*size, **obj_dict)
 
 
 def _safe_check_pinned(tensor: torch.Tensor) -> bool:
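Since BatchTensorDescriptor is now registered as MSGPack ext type 0x51, it can round-trip through MSGPackSerializer; a minimal sketch (the tensor shape is arbitrary):

    import torch

    from hivemind.utils.serializer import MSGPackSerializer
    from hivemind.utils.tensor_descr import BatchTensorDescriptor

    descr = BatchTensorDescriptor.from_tensor(torch.randn(8, 3, 224, 224))
    restored = MSGPackSerializer.loads(MSGPackSerializer.dumps(descr))  # packb()/unpackb() via ext type 0x51

    assert restored.shape[1:] == tuple(descr.shape[1:]) and restored.dtype == descr.dtype
    batch = restored.make_zeros(4)  # zero tensor with batch size 4 and the described dtype/device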

+ 2 - 1
requirements-dev.txt

@@ -1,9 +1,10 @@
 pytest
 pytest-forked
-pytest-asyncio
+pytest-asyncio==0.16.0
 pytest-cov
 tqdm
 scikit-learn
+torchvision
 black==21.6b0
 isort
 psutil

+ 4 - 2
requirements-docs.txt

@@ -1,2 +1,4 @@
-recommonmark
-sphinx_rtd_theme
+recommonmark==0.5.0
+sphinx_rtd_theme==0.4.3
+docutils==0.16
+sphinx==4.2.0

+ 5 - 5
setup.py

@@ -14,9 +14,10 @@ from setuptools import find_packages, setup
 from setuptools.command.build_py import build_py
 from setuptools.command.develop import develop
 
-P2PD_VERSION = "v0.3.1"
-P2PD_CHECKSUM = "15292b880c6b31f5b3c36084b3acc17f"
+P2PD_VERSION = "v0.3.6"
+P2PD_CHECKSUM = "627d0c3b475a29331fdfd1667e828f6d"
 LIBP2P_TAR_URL = f"https://github.com/learning-at-home/go-libp2p-daemon/archive/refs/tags/{P2PD_VERSION}.tar.gz"
+P2PD_BINARY_URL = f"https://github.com/learning-at-home/go-libp2p-daemon/releases/download/{P2PD_VERSION}/p2pd"
 
 here = os.path.abspath(os.path.dirname(__file__))
 
@@ -85,11 +86,10 @@ def download_p2p_daemon():
     binary_path = os.path.join(install_path, "p2pd")
     if not os.path.exists(binary_path) or md5(binary_path) != P2PD_CHECKSUM:
         print("Downloading Peer to Peer Daemon")
-        url = f"https://github.com/learning-at-home/go-libp2p-daemon/releases/download/{P2PD_VERSION}/p2pd"
-        urllib.request.urlretrieve(url, binary_path)
+        urllib.request.urlretrieve(P2PD_BINARY_URL, binary_path)
         os.chmod(binary_path, 0o777)
         if md5(binary_path) != P2PD_CHECKSUM:
-            raise RuntimeError(f"Downloaded p2pd binary from {url} does not match with md5 checksum")
+            raise RuntimeError(f"Downloaded p2pd binary from {P2PD_BINARY_URL} does not match with md5 checksum")
 
 
 class BuildPy(build_py):

+ 24 - 3
tests/conftest.py

@@ -1,20 +1,41 @@
+import asyncio
 import gc
-import multiprocessing as mp
 from contextlib import suppress
 
 import psutil
 import pytest
 
-from hivemind.utils.logging import get_logger
-from hivemind.utils.mpfuture import MPFuture, SharedBytes
+from hivemind.utils.crypto import RSAPrivateKey
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from hivemind.utils.mpfuture import MPFuture
 
+use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
 
 
+@pytest.fixture
+def event_loop():
+    """
+    This overrides the ``event_loop`` fixture from pytest-asyncio
+    (e.g. to make it compatible with ``asyncio.subprocess``).
+
+    This fixture is identical to the original one but does not call ``loop.close()`` at the end.
+    By teardown time, the loop is already stopped (i.e., subsequent tests are free to create new loops).
+    However, finalizers of objects created in the current test may reference the current loop and fail if it is closed.
+    For example, this happens while using ``asyncio.subprocess`` (the ``asyncio.subprocess.Process`` finalizer
+    fails if the loop is closed, but works if the loop is only stopped).
+    """
+
+    yield asyncio.get_event_loop()
+
+
 @pytest.fixture(autouse=True, scope="session")
 def cleanup_children():
     yield
 
+    with RSAPrivateKey._process_wide_key_lock:
+        RSAPrivateKey._process_wide_key = None
+
     gc.collect()  # Call .__del__() for removed objects
 
     children = psutil.Process().children(recursive=True)

+ 5 - 5
tests/test_allreduce.py

@@ -6,12 +6,12 @@ from typing import Sequence
 import pytest
 import torch
 
-from hivemind import aenumerate
+from hivemind import Quantile8BitQuantization, aenumerate
 from hivemind.averaging.allreduce import AllReduceRunner, AveragingMode
 from hivemind.averaging.partition import TensorPartContainer, TensorPartReducer
+from hivemind.compression import deserialize_torch_tensor
 from hivemind.p2p import P2P, StubBase
 from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.utils import deserialize_torch_tensor
 
 
 @pytest.mark.forked
@@ -83,7 +83,7 @@ async def test_partitioning_asynchronous():
     tensors = [torch.randn(2048, 2048), torch.randn(1024, 4096), torch.randn(4096, 1024), torch.randn(30_000, 1024)]
     peer_fractions = [0.4, 0.3, 0.2, 0.1]
 
-    partition = TensorPartContainer(tensors, peer_fractions, compression_type=CompressionType.QUANTILE_8BIT)
+    partition = TensorPartContainer(tensors, peer_fractions, compression=Quantile8BitQuantization())
     read_started, read_finished = asyncio.Event(), asyncio.Event()
 
     async def write_tensors():
@@ -187,7 +187,7 @@ async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions,
     group_id = random.getrandbits(160).to_bytes(length=20, byteorder="big")
 
     allreduce_protocols = []
-    for p2p in p2ps:
+    for i, p2p in enumerate(p2ps):
         allreduce_protocol = AllReduceRunner(
             p2p=p2p,
             servicer_type=AllReduceRunner,
@@ -197,7 +197,7 @@ async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions,
             ordered_peer_ids=peers,
             peer_fractions=peer_fractions,
             modes=peer_modes,
-            weights=averaging_weights,
+            weight=averaging_weights[i],
             part_size_bytes=part_size_bytes,
         )
         await allreduce_protocol.add_p2p_handlers(p2p)

+ 213 - 0
tests/test_allreduce_fault_tolerance.py

@@ -0,0 +1,213 @@
+from __future__ import annotations
+
+import asyncio
+from enum import Enum, auto
+from typing import AsyncIterator
+
+import pytest
+import torch
+
+import hivemind
+from hivemind.averaging.allreduce import AllReduceRunner, AveragingMode
+from hivemind.averaging.averager import *
+from hivemind.averaging.group_info import GroupInfo
+from hivemind.averaging.load_balancing import load_balance_peers
+from hivemind.averaging.matchmaking import MatchmakingException
+from hivemind.proto import averaging_pb2
+from hivemind.utils.asyncio import aenumerate, as_aiter, azip, enter_asynchronously
+from hivemind.utils.logging import get_logger
+
+logger = get_logger(__name__)
+
+
+class Fault(Enum):
+    NONE = auto()
+    FAIL_BEFORE = auto()
+    FAIL_SENDING = auto()
+    SLOW_SENDING = auto()
+    FAIL_REDUCING = auto()
+    SLOW_REDUCING = auto()
+    CANCEL = auto()
+
+
+class FaultyAverager(hivemind.DecentralizedAverager):
+    def __init__(self, *args, fault: Fault = Fault.NONE, **kwargs):
+        self.fault = fault
+        super().__init__(*args, **kwargs)
+
+    async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
+        """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
+        try:
+            bandwidths, mode_ids, user_gathered_bytes = zip(*map(self.serializer.loads, group_info.gathered))
+            user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered_bytes)))
+            modes = tuple(map(AveragingMode, mode_ids))
+            download_bandwidths = [
+                thr if mode != AveragingMode.CLIENT else 0.0 for thr, mode in zip(bandwidths, modes)
+            ]
+            peer_fractions = await asyncio.get_event_loop().run_in_executor(
+                None, load_balance_peers, self.total_size, download_bandwidths, min_vector_size
+            )
+
+            if self.fault == Fault.FAIL_BEFORE:
+                raise Exception("Oops, I failed!")
+
+            async with enter_asynchronously(self.get_tensors()) as local_tensors:
+                allreduce = FaultyAllReduceRunner(
+                    p2p=self._p2p,
+                    servicer_type=type(self),
+                    prefix=self.prefix,
+                    group_id=group_info.group_id,
+                    tensors=local_tensors,
+                    ordered_peer_ids=group_info.peer_ids,
+                    peer_fractions=peer_fractions,
+                    gathered=user_gathered,
+                    modes=modes,
+                    fault=self.fault,
+                    **kwargs,
+                )
+
+                with self.register_allreduce_group(group_info.group_id, allreduce):
+                    if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
+                        async for tensor, update in azip(as_aiter(*local_tensors), allreduce):
+                            # all-reduce is performed asynchronously while iterating
+                            tensor.add_(update, alpha=self._averaging_alpha)
+                        self._state_updated.set()
+
+                    else:
+                        async for _ in allreduce:  # trigger all-reduce by iterating
+                            raise ValueError("aux peers should not receive averaged tensors")
+
+                return allreduce.gathered
+        except BaseException as e:
+            logger.exception(e)
+            raise MatchmakingException(f"Unable to run All-Reduce: {e}")
+
+
+class FaultyAllReduceRunner(AllReduceRunner):
+    def __init__(self, *args, fault: Fault, **kwargs):
+        self.fault = fault
+        super().__init__(*args, **kwargs)
+
+    async def rpc_aggregate_part(self, stream, context) -> AsyncIterator[averaging_pb2.AveragingData]:
+        if self.fault in (Fault.FAIL_REDUCING, Fault.SLOW_REDUCING):
+            async for i, message in aenumerate(super().rpc_aggregate_part(stream, context)):
+                yield message
+                if i == 2:
+                    if self.fault == Fault.FAIL_REDUCING:
+                        yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
+                        break
+                    else:
+                        await asyncio.sleep(10)
+
+        elif self.fault == Fault.CANCEL:
+            yield averaging_pb2.AveragingData(code=averaging_pb2.CANCELLED)
+        else:
+            async for message in super().rpc_aggregate_part(stream, context):
+                yield message
+
+    async def _generate_input_for_peer(self, peer_index: int) -> AsyncIterator[averaging_pb2.AveragingData]:
+        parts_aiter = self.tensor_part_container.iterate_input_parts_for(peer_index)
+
+        first_part = await anext(parts_aiter)
+        yield averaging_pb2.AveragingData(
+            code=averaging_pb2.PART_FOR_AVERAGING,
+            group_id=self.group_id,
+            tensor_part=first_part,
+            weight=self.weight,
+        )
+        if self.fault in (Fault.FAIL_SENDING, Fault.SLOW_SENDING):
+            last_reducer_index = self.group_size - 1 - (self.tensor_part_container.num_parts_by_peer[-1] == 0)
+            if peer_index == last_reducer_index:
+                if self.fault == Fault.FAIL_SENDING:
+                    raise Exception("Oops, I failed!")
+                else:
+                    await asyncio.sleep(10)
+        async for part in parts_aiter:
+            yield averaging_pb2.AveragingData(tensor_part=part, weight=self.weight)
+
+
+@pytest.mark.forked
+@pytest.mark.parametrize(
+    "fault0, fault1",
+    [
+        (Fault.NONE, Fault.FAIL_BEFORE),
+        (Fault.FAIL_BEFORE, Fault.FAIL_BEFORE),
+        (Fault.SLOW_SENDING, Fault.FAIL_SENDING),
+        (Fault.FAIL_SENDING, Fault.FAIL_BEFORE),
+        (Fault.SLOW_REDUCING, Fault.FAIL_SENDING),
+        (Fault.FAIL_REDUCING, Fault.FAIL_REDUCING),
+        (Fault.NONE, Fault.CANCEL),
+    ],
+)
+def test_fault_tolerance(fault0: Fault, fault1: Fault):
+    def _make_tensors():
+        return [torch.rand(16, 1024), -torch.rand(3, 8192), 2 * torch.randn(4, 4, 4), torch.randn(1024, 1024)]
+
+    dht = hivemind.DHT(start=True)
+
+    averagers = []
+    for i in range(5):
+        averager = FaultyAverager(
+            _make_tensors(),
+            hivemind.DHT(initial_peers=dht.get_visible_maddrs(), start=True),
+            prefix="test",
+            request_timeout=0.3,
+            min_matchmaking_time=1.0,
+            next_chunk_timeout=0.5,
+            allreduce_timeout=5,
+            part_size_bytes=2 ** 16,
+            client_mode=(i == 1),
+            start=True,
+            fault=fault0 if i == 0 else fault1 if i == 1 else Fault.NONE,
+        )
+        averagers.append(averager)
+
+    ref_numerators = [0, 0, 0, 0]
+    ref_denominator = 0
+
+    for averager in averagers:
+        if averager.fault not in (Fault.FAIL_BEFORE, Fault.CANCEL):
+            with averager.get_tensors() as tensors:
+                for i, tensor in enumerate(tensors):
+                    ref_numerators[i] = ref_numerators[i] + tensor.clone()
+                ref_denominator += 1
+
+    ref_tensors = [ref_numerator / ref_denominator for ref_numerator in ref_numerators]
+    flat_ref = torch.cat(list(map(torch.flatten, ref_tensors)))
+
+    flat_local_tensors = []
+    for averager in averagers:
+        with averager.get_tensors() as tensors:
+            flat_local_tensors.append(torch.cat(list(map(torch.flatten, tensors))))
+
+    futures = [averager.step(timeout=5, wait=False, allow_retries=False) for averager in averagers]
+    for i, averager in enumerate(averagers):
+        if averager.fault == Fault.CANCEL:
+            futures[i].cancel()
+
+    for future in futures[2:]:
+        assert future.result()
+
+    for averager, prev_local_tensors in zip(averagers[2:], flat_local_tensors[2:]):
+        with averager.get_tensors() as tensors:
+            flat_tensors = torch.cat(list(map(torch.flatten, tensors)))
+
+        diff_with_reference = abs(flat_ref - flat_tensors)
+
+        if all(fault in (Fault.FAIL_SENDING, Fault.SLOW_SENDING) for fault in (fault0, fault1)):
+            assert fault0 != Fault.FAIL_REDUCING and fault1 != Fault.FAIL_REDUCING
+            assert diff_with_reference[: len(diff_with_reference) // 2].max() < 1e-5
+        elif all(fault in (Fault.FAIL_REDUCING, Fault.SLOW_REDUCING) for fault in (fault0, fault1)):
+            diff_to_reference = abs(flat_ref - flat_tensors)
+            diff_to_local = abs(prev_local_tensors - flat_tensors)
+            assert (diff_with_reference < 1e-5).numpy().mean() > 0.5
+            assert torch.all(torch.minimum(diff_to_reference, diff_to_local) < 1e-5).item()
+        elif any(fault == Fault.CANCEL for fault in (fault0, fault1)):
+            pass  # late cancel may result in an arbitrary mix of averaging results with and without the cancelled peer
+        elif fault0 == Fault.NONE:  # only peer1 in client mode may have failed
+            assert diff_with_reference.max() < 1e-5
+        else:
+            assert (diff_with_reference < 1e-5).numpy().mean() > 0.5
+
+    for averager in averagers:
+        averager.shutdown()

+ 129 - 61
tests/test_averaging.py

@@ -8,10 +8,11 @@ import torch
 import hivemind
 import hivemind.averaging.averager
 from hivemind.averaging.allreduce import AveragingMode
+from hivemind.averaging.control import AveragingStage
 from hivemind.averaging.key_manager import GroupKeyManager
 from hivemind.averaging.load_balancing import load_balance_peers
+from hivemind.averaging.partition import AllreduceException
 from hivemind.p2p import PeerID
-from hivemind.proto.runtime_pb2 import CompressionType
 
 from test_utils.dht_swarms import launch_dht_instances
 
@@ -168,61 +169,6 @@ def test_allreduce_weighted(n_client_mode_peers: int = 2):
         process.shutdown()
 
 
-@pytest.mark.forked
-def test_allreduce_compression():
-    """this test ensures that compression works correctly when multiple tensors have different compression types"""
-
-    tensors1 = [torch.linspace(0, 500, 1000) ** 0.5, torch.randn(1000)]
-    tensors2 = [torch.linspace(300, 800, 1000) ** 0.5, torch.randn(1000)]
-    results = {}
-
-    FLOAT16, UINT8 = CompressionType.FLOAT16, CompressionType.UNIFORM_8BIT
-
-    for compression_type_pair in [(FLOAT16, FLOAT16), (FLOAT16, UINT8), (UINT8, FLOAT16), (UINT8, UINT8)]:
-        dht_instances = launch_dht_instances(2)
-        averager1 = hivemind.averaging.DecentralizedAverager(
-            [x.clone() for x in tensors1],
-            dht=dht_instances[0],
-            compression_type=compression_type_pair,
-            client_mode=True,
-            target_group_size=2,
-            prefix="mygroup",
-            start=True,
-        )
-        averager2 = hivemind.averaging.DecentralizedAverager(
-            [x.clone() for x in tensors2],
-            dht=dht_instances[1],
-            compression_type=compression_type_pair,
-            target_group_size=2,
-            prefix="mygroup",
-            start=True,
-        )
-
-        for future in averager1.step(wait=False), averager2.step(wait=False):
-            future.result()
-
-        with averager1.get_tensors() as averaged_tensors:
-            results[compression_type_pair] = averaged_tensors
-
-        for instance in [averager1, averager2] + dht_instances:
-            instance.shutdown()
-
-    assert torch.allclose(results[UINT8, FLOAT16][0], results[UINT8, UINT8][0])
-    assert torch.allclose(results[UINT8, FLOAT16][1], results[FLOAT16, FLOAT16][1])
-    assert torch.allclose(results[UINT8, UINT8][1], results[FLOAT16, UINT8][1])
-    assert torch.allclose(results[FLOAT16, UINT8][0], results[FLOAT16, FLOAT16][0])
-
-    assert not torch.allclose(results[UINT8, FLOAT16][1], results[UINT8, UINT8][1])
-    assert not torch.allclose(results[UINT8, FLOAT16][0], results[FLOAT16, FLOAT16][0])
-    assert not torch.allclose(results[UINT8, UINT8][0], results[FLOAT16, UINT8][0])
-    assert not torch.allclose(results[FLOAT16, UINT8][1], results[FLOAT16, FLOAT16][1])
-
-    reference = [(tensors1[i] + tensors2[i]) / 2 for i in range(len(tensors1))]
-    for i in range(2):
-        assert 0 < torch.mean(torch.square(results[FLOAT16, FLOAT16][i] - reference[i])).item() <= 1e-5
-        assert 1e-5 < torch.mean(torch.square(results[UINT8, UINT8][i] - reference[i])).item() <= 1e-2
-
-
 def compute_mean_std(averagers, unbiased=True):
     results = []
     for averager in averagers:
@@ -363,9 +309,11 @@ def test_too_few_peers():
         )
         for i, dht in enumerate(dht_instances)
     ]
-    step_futures = [averager.step(wait=False) for averager in averagers]
+    step_futures = [averager.step(wait=False, timeout=2) for averager in averagers]
+
     for future in step_futures:
-        assert len(future.result()) == 2
+        with pytest.raises(AllreduceException):
+            future.result()
 
     for process in averagers + dht_instances:
         process.shutdown()
@@ -424,7 +372,6 @@ def test_load_state_from_peers():
         target_group_size=2,
     )
 
-    dht_instances[1].get("demo-run.all_averagers")
     averager2 = TestAverager(
         [torch.randn(3), torch.rand(5)],
         dht=dht_instances[1],
@@ -433,6 +380,8 @@ def test_load_state_from_peers():
         target_group_size=2,
     )
 
+    time.sleep(0.5)
+
     assert num_calls == 0
     got_metadata, got_tensors = averager2.load_state_from_peers()
     assert num_calls == 1
@@ -451,7 +400,9 @@ def test_load_state_from_peers():
 
     averager1.allow_state_sharing = False
     assert averager2.load_state_from_peers() is None
+
     averager1.allow_state_sharing = True
+    time.sleep(0.5)
     got_metadata, got_tensors = averager2.load_state_from_peers()
     assert num_calls == 3
     assert got_metadata == super_metadata
@@ -460,6 +411,47 @@ def test_load_state_from_peers():
         instance.shutdown()
 
 
+@pytest.mark.forked
+def test_load_state_priority():
+    dht_instances = launch_dht_instances(4)
+
+    averagers = []
+    for i in range(4):
+        averager = hivemind.DecentralizedAverager(
+            [torch.randn(3), torch.rand(5), torch.tensor([i], dtype=torch.float32)],
+            dht=dht_instances[i],
+            start=True,
+            prefix="demo-run",
+            target_group_size=2,
+            allow_state_sharing=i != 1,
+        )
+        averager.state_sharing_priority = 5 - abs(2 - i)
+        averagers.append(averager)
+
+    time.sleep(0.5)
+    metadata, tensors = averagers[0].load_state_from_peers(timeout=1)
+    assert tensors[-1].item() == 2
+
+    metadata, tensors = averagers[2].load_state_from_peers(timeout=1)
+    assert tensors[-1].item() == 3
+
+    averagers[0].state_sharing_priority = 10
+    time.sleep(0.2)
+
+    metadata, tensors = averagers[2].load_state_from_peers(timeout=1)
+    assert tensors[-1].item() == 0
+
+    averagers[1].allow_state_sharing = False
+    averagers[2].allow_state_sharing = False
+    metadata, tensors = averagers[0].load_state_from_peers(timeout=1)
+    assert tensors[-1].item() == 3
+
+    for averager in averagers:
+        averager.shutdown()
+    for dht in dht_instances:
+        dht.shutdown()
+
+
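For reference, the state-sharing controls exercised above reduce to a short sketch like the one below (two peers on a local DHT; the sleeps, shapes, and priority value are illustrative):

    import time
    import torch
    import hivemind

    dht1 = hivemind.DHT(start=True)
    dht2 = hivemind.DHT(initial_peers=dht1.get_visible_maddrs(), start=True)

    donor = hivemind.DecentralizedAverager(
        [torch.zeros(3)], dht=dht1, prefix="demo-run", target_group_size=2, start=True
    )
    receiver = hivemind.DecentralizedAverager(
        [torch.zeros(3)], dht=dht2, prefix="demo-run", target_group_size=2, start=True
    )

    donor.allow_state_sharing = True      # set to False to opt out of serving state
    donor.state_sharing_priority = 10     # higher-priority peers are preferred as donors
    time.sleep(1.0)                       # let the DHT records propagate

    metadata, tensors = receiver.load_state_from_peers(timeout=5)

    for instance in (donor, receiver, dht1, dht2):
        instance.shutdown()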
 @pytest.mark.forked
 def test_getset_bits():
     dht = hivemind.DHT(start=True)
@@ -474,6 +466,82 @@ def test_getset_bits():
     assert averager.get_group_bits() == "00101011101010"
 
 
+@pytest.mark.forked
+def test_averaging_trigger():
+    averagers = tuple(
+        hivemind.averaging.DecentralizedAverager(
+            averaged_tensors=[torch.randn(3)],
+            dht=dht,
+            min_matchmaking_time=0.5,
+            request_timeout=0.3,
+            prefix="mygroup",
+            initial_group_bits="",
+            start=True,
+        )
+        for dht in launch_dht_instances(4)
+    )
+
+    controls = []
+    for i, averager in enumerate(averagers):
+        controls.append(
+            averager.step(
+                wait=False,
+                scheduled_time=hivemind.get_dht_time() + 0.5,
+                weight=1.0,
+                require_trigger=i in (1, 2),
+            )
+        )
+
+    time.sleep(0.6)
+
+    c0, c1, c2, c3 = controls
+    assert not any(c.done() for c in controls)
+    assert c0.stage == AveragingStage.RUNNING_ALLREDUCE
+    assert c1.stage == AveragingStage.AWAITING_TRIGGER
+    assert c2.stage == AveragingStage.AWAITING_TRIGGER
+    assert c3.stage == AveragingStage.RUNNING_ALLREDUCE
+
+    c1.allow_allreduce()
+    c2.allow_allreduce()
+    time.sleep(0.5)
+    assert all(c.stage == AveragingStage.FINISHED for c in controls)
+    assert all(c.done() for c in controls)
+
+    # check that setting trigger twice does not raise error
+    c0.allow_allreduce()
+
+
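The two-stage step used above can be summarized as follows: scheduling with require_trigger=True lets a peer join matchmaking early and release the actual all-reduce only once its local work is done. A minimal two-peer sketch (timings are illustrative):

    import time
    import torch
    import hivemind

    dht1 = hivemind.DHT(start=True)
    dht2 = hivemind.DHT(initial_peers=dht1.get_visible_maddrs(), start=True)
    averagers = [
        hivemind.DecentralizedAverager(
            [torch.randn(3)], dht=dht, prefix="mygroup", min_matchmaking_time=1.0, start=True
        )
        for dht in (dht1, dht2)
    ]

    # stage 1: schedule the step; matchmaking proceeds in the background
    controls = [
        averager.step(wait=False, scheduled_time=hivemind.get_dht_time() + 1.0, require_trigger=True)
        for averager in averagers
    ]
    time.sleep(1.5)  # by now each control is typically in AveragingStage.AWAITING_TRIGGER

    # stage 2: release the trigger once local work (e.g. gradient accumulation) is finished
    for control in controls:
        control.allow_allreduce()
    for control in controls:
        control.result()  # blocks until the all-reduce completes

    for instance in averagers + [dht1, dht2]:
        instance.shutdown()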
+@pytest.mark.forked
+def test_averaging_cancel():
+    averagers = tuple(
+        hivemind.averaging.DecentralizedAverager(
+            averaged_tensors=[torch.randn(3)],
+            dht=dht,
+            min_matchmaking_time=0.5,
+            request_timeout=0.3,
+            client_mode=(i % 2 == 0),
+            prefix="mygroup",
+            start=True,
+        )
+        for i, dht in enumerate(launch_dht_instances(4))
+    )
+
+    step_controls = [averager.step(wait=False, scheduled_time=hivemind.get_dht_time() + 1) for averager in averagers]
+
+    time.sleep(0.1)
+    step_controls[0].cancel()
+    step_controls[1].cancel()
+
+    for i, control in enumerate(step_controls):
+        if i in (0, 1):
+            assert control.cancelled()
+        else:
+            assert control.result() is not None and len(control.result()) == 2
+
+    for averager in averagers:
+        averager.shutdown()
+
+
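Cancelling works the same way as for regular futures: a pending step control can be cancelled before the all-reduce runs. A single-peer sketch (the step would otherwise just keep waiting for a group):

    import torch
    import hivemind

    dht = hivemind.DHT(start=True)
    averager = hivemind.DecentralizedAverager([torch.randn(3)], dht=dht, prefix="mygroup", start=True)

    control = averager.step(wait=False, scheduled_time=hivemind.get_dht_time() + 10)
    control.cancel()            # abort matchmaking / averaging for this step
    assert control.cancelled()

    averager.shutdown()
    dht.shutdown()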
 @pytest.mark.forked
 def test_training_averager(n_steps: int = 10, n_dims: int = 16):
     torch.manual_seed(42)
@@ -487,7 +555,7 @@ def test_training_averager(n_steps: int = 10, n_dims: int = 16):
 
     x1 = torch.randn(n_dims, requires_grad=True)
     opt1 = torch.optim.Adam([x1], lr=0.05)
-    averager1 = hivemind.averaging.TrainingAverager(
+    averager1 = hivemind.TrainingAverager(
         opt1,
         average_gradients=True,
         average_parameters=True,
@@ -498,7 +566,7 @@ def test_training_averager(n_steps: int = 10, n_dims: int = 16):
 
     x2 = torch.randn(n_dims, requires_grad=True)
     opt2 = torch.optim.Adam([x2], lr=0.05)
-    averager2 = hivemind.averaging.TrainingAverager(
+    averager2 = hivemind.TrainingAverager(
         opt2,
         average_gradients=True,
         average_parameters=True,

tests/test_compression.py  +213 -0

@@ -0,0 +1,213 @@
+import multiprocessing as mp
+from ctypes import c_int32
+
+import pytest
+import torch
+import torch.nn as nn
+
+import hivemind
+from hivemind.compression import (
+    CompressionBase,
+    CompressionInfo,
+    Float16Compression,
+    NoCompression,
+    PerTensorCompression,
+    RoleAdaptiveCompression,
+    SizeAdaptiveCompression,
+    Uniform8BitQuantization,
+    deserialize_torch_tensor,
+    serialize_torch_tensor,
+)
+from hivemind.compression.adaptive import AdaptiveCompressionBase
+from hivemind.proto.runtime_pb2 import CompressionType
+
+from test_utils.dht_swarms import launch_dht_instances
+
+
+@pytest.mark.forked
+def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
+    torch.manual_seed(0)
+    X = torch.randn(*size)
+    assert torch.allclose(deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.NONE)), X)
+    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.MEANSTD_16BIT)) - X
+    assert error.square().mean() < alpha
+    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.FLOAT16)) - X
+    assert error.square().mean() < alpha
+    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.QUANTILE_8BIT)) - X
+    assert error.square().mean() < beta
+    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.UNIFORM_8BIT)) - X
+    assert error.square().mean() < beta
+
+    zeros = torch.zeros(5, 5)
+    for compression_type in CompressionType.values():
+        assert deserialize_torch_tensor(serialize_torch_tensor(zeros, compression_type)).isfinite().all()
+
+
+@pytest.mark.forked
+def test_serialize_tensor():
+    def _check(tensor, compression, rtol=1e-5, atol=1e-8, chunk_size=30 * 1024):
+        serialized_tensor = serialize_torch_tensor(tensor, compression)
+        chunks = list(hivemind.split_for_streaming(serialized_tensor, chunk_size))
+        assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
+        restored = hivemind.combine_from_streaming(chunks)
+        assert torch.allclose(deserialize_torch_tensor(restored), tensor, rtol=rtol, atol=atol)
+
+    tensor = torch.randn(512, 12288)
+    for chunk_size in [1024, 64 * 1024, 64 * 1024 + 1, 10 ** 9]:
+        _check(tensor, CompressionType.NONE, chunk_size=chunk_size)
+
+    _check(tensor, CompressionType.FLOAT16, rtol=0.0, atol=1e-2)
+    _check(torch.randint(0, 100, (512, 1, 1)), CompressionType.NONE)
+    _check(torch.tensor(1.0), CompressionType.NONE)
+    _check(torch.tensor(1.0), CompressionType.FLOAT16)
+
+
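The serialization helpers tested above are usable on their own; a round trip with lossy fp16 compression and chunked streaming looks roughly like this (tensor shape and chunk size are illustrative):

    import torch
    import hivemind
    from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
    from hivemind.proto.runtime_pb2 import CompressionType

    tensor = torch.randn(512, 1024)

    serialized = serialize_torch_tensor(tensor, CompressionType.FLOAT16)
    chunks = list(hivemind.split_for_streaming(serialized, 16 * 1024))     # split for streaming RPCs
    restored = deserialize_torch_tensor(hivemind.combine_from_streaming(chunks))

    assert torch.allclose(restored, tensor, rtol=0, atol=1e-2)             # fp16 is lossy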
+@pytest.mark.forked
+def test_allreduce_compression():
+    """this test ensures that compression works correctly when multiple tensors have different compression types"""
+
+    tensors1 = [torch.linspace(0, 500, 1000) ** 0.5, torch.randn(1000)]
+    tensors2 = [torch.linspace(300, 800, 1000) ** 0.5, torch.randn(1000)]
+    results = {}
+
+    FLOAT16, UINT8 = Float16Compression(), Uniform8BitQuantization()
+
+    for compression_type_pair in [(FLOAT16, FLOAT16), (FLOAT16, UINT8), (UINT8, FLOAT16), (UINT8, UINT8)]:
+        dht_instances = launch_dht_instances(2)
+        averager1 = hivemind.averaging.DecentralizedAverager(
+            [x.clone() for x in tensors1],
+            dht=dht_instances[0],
+            compression=PerTensorCompression(compression_type_pair),
+            client_mode=True,
+            target_group_size=2,
+            prefix="mygroup",
+            start=True,
+        )
+        averager2 = hivemind.averaging.DecentralizedAverager(
+            [x.clone() for x in tensors2],
+            dht=dht_instances[1],
+            compression=PerTensorCompression(compression_type_pair),
+            target_group_size=2,
+            prefix="mygroup",
+            start=True,
+        )
+
+        for future in averager1.step(wait=False), averager2.step(wait=False):
+            future.result()
+
+        with averager1.get_tensors() as averaged_tensors:
+            results[compression_type_pair] = averaged_tensors
+
+        for instance in [averager1, averager2] + dht_instances:
+            instance.shutdown()
+
+    assert torch.allclose(results[UINT8, FLOAT16][0], results[UINT8, UINT8][0])
+    assert torch.allclose(results[UINT8, FLOAT16][1], results[FLOAT16, FLOAT16][1])
+    assert torch.allclose(results[UINT8, UINT8][1], results[FLOAT16, UINT8][1])
+    assert torch.allclose(results[FLOAT16, UINT8][0], results[FLOAT16, FLOAT16][0])
+
+    assert not torch.allclose(results[UINT8, FLOAT16][1], results[UINT8, UINT8][1])
+    assert not torch.allclose(results[UINT8, FLOAT16][0], results[FLOAT16, FLOAT16][0])
+    assert not torch.allclose(results[UINT8, UINT8][0], results[FLOAT16, UINT8][0])
+    assert not torch.allclose(results[FLOAT16, UINT8][1], results[FLOAT16, FLOAT16][1])
+
+    reference = [(tensors1[i] + tensors2[i]) / 2 for i in range(len(tensors1))]
+    for i in range(2):
+        assert 0 < torch.mean(torch.square(results[FLOAT16, FLOAT16][i] - reference[i])).item() <= 1e-5
+        assert 1e-5 < torch.mean(torch.square(results[UINT8, UINT8][i] - reference[i])).item() <= 1e-2
+
+
+class TrackedCompression(AdaptiveCompressionBase):
+    def __init__(self, compression: CompressionBase):
+        self.compression = compression
+        self.mp_counter, self.mp_part_size = mp.Value(c_int32, 0), mp.Value(c_int32, 0)
+        super().__init__()
+
+    def choose_compression(self, info: CompressionInfo) -> CompressionBase:
+        return self.compression
+
+    def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False):
+        self.mp_counter.value += 1
+        if info.part_size is not None:
+            self.mp_part_size.value = max(self.mp_part_size.value, info.part_size)
+        return self.compression.compress(tensor, info=info, allow_inplace=allow_inplace)
+
+
+def make_params():
+    return [
+        nn.Parameter(x)
+        for x in (
+            torch.randn([]),
+            torch.randn(1),
+            torch.randn(100),
+            torch.randn(1_000),
+            torch.randn(5_000),
+            torch.randn(10_000),
+        )
+    ]
+
+
+@pytest.mark.forked
+def test_adaptive_compression():
+    UINT8 = TrackedCompression(Uniform8BitQuantization())
+    FLOAT16 = TrackedCompression(Float16Compression())
+    FLOAT32 = TrackedCompression(NoCompression())
+    STATE_FP16 = TrackedCompression(Float16Compression())
+    STATE_FP32 = TrackedCompression(NoCompression())
+
+    averaging_compression_adaptive = RoleAdaptiveCompression(
+        parameter=FLOAT16,
+        gradient=SizeAdaptiveCompression(threshold=1_000, less=FLOAT16, greater_equal=UINT8),
+        optimizer=FLOAT32,
+        default=FLOAT32,
+    )
+
+    state_compression_adaptive = SizeAdaptiveCompression(
+        threshold=500,
+        less=STATE_FP32,
+        greater_equal=STATE_FP16,
+    )
+
+    averager1 = hivemind.TrainingAverager(
+        opt=torch.optim.Adam(make_params()),
+        average_parameters=True,
+        average_gradients=True,
+        average_opt_statistics=("exp_avg",),
+        compression=averaging_compression_adaptive,
+        state_compression=state_compression_adaptive,
+        prefix="test_avgr",
+        target_group_size=2,
+        part_size_bytes=5_000,
+        start=True,
+        dht=hivemind.DHT(start=True),
+    )
+
+    averager2 = hivemind.TrainingAverager(
+        opt=torch.optim.Adam(make_params()),
+        average_parameters=True,
+        average_gradients=True,
+        average_opt_statistics=("exp_avg",),
+        compression=averaging_compression_adaptive,
+        state_compression=state_compression_adaptive,
+        prefix="test_avgr",
+        target_group_size=2,
+        part_size_bytes=5_000,
+        start=True,
+        dht=hivemind.DHT(initial_peers=averager1.dht.get_visible_maddrs(), start=True),
+    )
+
+    futures = [averager1.step(wait=False), averager2.step(wait=False)]
+
+    for future in futures:
+        future.result()
+
+    assert UINT8.mp_counter.value == 4  # half gradients: 3 tensors, 1 is split
+    assert UINT8.mp_part_size.value == 5_000  # single-byte tensors
+    assert FLOAT16.mp_counter.value == 13  # parameters and half gradients
+    assert FLOAT16.mp_part_size.value == 2_500  # two-byte tensors
+    assert FLOAT32.mp_counter.value == 16  # statistics
+    assert FLOAT32.mp_part_size.value == 1250  # four-byte tensors
+
+    averager1.load_state_from_peers()
+    assert STATE_FP16.mp_counter.value == STATE_FP32.mp_counter.value == 9
+    assert STATE_FP16.mp_part_size.value == STATE_FP32.mp_part_size.value == 0  # not partitioned

tests/test_dht.py  +17 -1

@@ -1,4 +1,5 @@
 import asyncio
+import concurrent.futures
 import random
 import time
 
@@ -6,10 +7,25 @@ import pytest
 from multiaddr import Multiaddr
 
 import hivemind
+from hivemind.utils.networking import get_free_port
 
 from test_utils.dht_swarms import launch_dht_instances
 
 
+@pytest.mark.asyncio
+async def test_startup_error():
+    with pytest.raises(hivemind.p2p.P2PDaemonError, match=r"(?i)Failed to connect to bootstrap peers"):
+        hivemind.DHT(
+            initial_peers=[f"/ip4/127.0.0.1/tcp/{get_free_port()}/p2p/QmdaK4LUeQaKhqSFPRu9N7MvXUEWDxWwtCvPrS444tCgd1"],
+            start=True,
+        )
+
+    dht = hivemind.DHT(start=True, await_ready=False)
+    with pytest.raises(concurrent.futures.TimeoutError):
+        dht.wait_until_ready(timeout=0.01)
+    dht.shutdown()
+
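In other words, a DHT can be started without blocking on readiness and polled explicitly; a minimal sketch of that pattern (the timeout value is illustrative):

    import hivemind

    bootstrap = hivemind.DHT(start=True)
    follower = hivemind.DHT(initial_peers=bootstrap.get_visible_maddrs(), start=True, await_ready=False)
    follower.wait_until_ready(timeout=30)   # raises concurrent.futures.TimeoutError if it takes longer

    print(follower.get_visible_maddrs())
    follower.shutdown()
    bootstrap.shutdown()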
+
 @pytest.mark.forked
 def test_get_store(n_peers=10):
     peers = launch_dht_instances(n_peers)
@@ -102,7 +118,7 @@ async def test_dht_get_visible_maddrs():
 
     dummy_endpoint = Multiaddr("/ip4/123.45.67.89/tcp/31337")
     p2p = await hivemind.p2p.P2P.create(announce_maddrs=[dummy_endpoint])
-    dht = hivemind.DHT(start=True, p2p=await p2p.replicate(p2p.daemon_listen_maddr))
+    dht = hivemind.DHT(start=True, p2p=p2p)
 
     assert dht.get_visible_maddrs() == [dummy_endpoint.encapsulate(f"/p2p/{p2p.peer_id}")]
     dht.shutdown()

tests/test_dht_node.py  +34 -215

@@ -1,200 +1,27 @@
 import asyncio
 import heapq
-import multiprocessing as mp
 import random
-import signal
 from itertools import product
-from typing import List, Sequence, Tuple
 
 import numpy as np
 import pytest
-from multiaddr import Multiaddr
 
 import hivemind
 from hivemind import get_dht_time
 from hivemind.dht.node import DHTID, DHTNode
-from hivemind.dht.protocol import DHTProtocol
-from hivemind.dht.storage import DictionaryDHTValue
-from hivemind.p2p import P2P, PeerID
 from hivemind.utils.logging import get_logger
 
 from test_utils.dht_swarms import launch_star_shaped_swarm, launch_swarm_in_separate_processes
 
 logger = get_logger(__name__)
 
-
-def maddrs_to_peer_ids(maddrs: List[Multiaddr]) -> List[PeerID]:
-    return list({PeerID.from_base58(maddr["p2p"]) for maddr in maddrs})
-
-
-def run_protocol_listener(
-    dhtid: DHTID, maddr_conn: mp.connection.Connection, initial_peers: Sequence[Multiaddr]
-) -> None:
-    loop = asyncio.get_event_loop()
-
-    p2p = loop.run_until_complete(P2P.create(initial_peers=initial_peers))
-    visible_maddrs = loop.run_until_complete(p2p.get_visible_maddrs())
-
-    protocol = loop.run_until_complete(
-        DHTProtocol.create(p2p, dhtid, bucket_size=20, depth_modulo=5, num_replicas=3, wait_timeout=5)
-    )
-
-    logger.info(f"Started peer id={protocol.node_id} visible_maddrs={visible_maddrs}")
-
-    for peer_id in maddrs_to_peer_ids(initial_peers):
-        loop.run_until_complete(protocol.call_ping(peer_id))
-
-    maddr_conn.send((p2p.peer_id, visible_maddrs))
-
-    async def shutdown():
-        await p2p.shutdown()
-        logger.info(f"Finished peer id={protocol.node_id} maddrs={visible_maddrs}")
-        loop.stop()
-
-    loop.add_signal_handler(signal.SIGTERM, lambda: loop.create_task(shutdown()))
-    loop.run_forever()
-
-
-def launch_protocol_listener(
-    initial_peers: Sequence[Multiaddr] = (),
-) -> Tuple[DHTID, mp.Process, PeerID, List[Multiaddr]]:
-    remote_conn, local_conn = mp.Pipe()
-    dht_id = DHTID.generate()
-    process = mp.Process(target=run_protocol_listener, args=(dht_id, remote_conn, initial_peers), daemon=True)
-    process.start()
-    peer_id, visible_maddrs = local_conn.recv()
-
-    return dht_id, process, peer_id, visible_maddrs
-
-
 # note: we run network-related tests in a separate process to re-initialize all global states from scratch
 # this helps us avoid undesirable gRPC side-effects (e.g. segfaults) when running multiple tests in sequence
 
 
 @pytest.mark.forked
-def test_dht_protocol():
-    peer1_node_id, peer1_proc, peer1_id, peer1_maddrs = launch_protocol_listener()
-    peer2_node_id, peer2_proc, peer2_id, _ = launch_protocol_listener(initial_peers=peer1_maddrs)
-
-    loop = asyncio.get_event_loop()
-    for client_mode in [True, False]:  # note: order matters, this test assumes that first run uses client mode
-        peer_id = DHTID.generate()
-        p2p = loop.run_until_complete(P2P.create(initial_peers=peer1_maddrs))
-        protocol = loop.run_until_complete(
-            DHTProtocol.create(
-                p2p, peer_id, bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, client_mode=client_mode
-            )
-        )
-        logger.info(f"Self id={protocol.node_id}")
-
-        assert loop.run_until_complete(protocol.call_ping(peer1_id)) == peer1_node_id
-
-        key, value, expiration = DHTID.generate(), [random.random(), {"ololo": "pyshpysh"}], get_dht_time() + 1e3
-        store_ok = loop.run_until_complete(
-            protocol.call_store(peer1_id, [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
-        )
-        assert all(store_ok), "DHT rejected a trivial store"
-
-        # peer 1 must know about peer 2
-        (recv_value_bytes, recv_expiration), nodes_found = loop.run_until_complete(
-            protocol.call_find(peer1_id, [key])
-        )[key]
-        recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
-        (recv_id, recv_peer_id) = next(iter(nodes_found.items()))
-        assert (
-            recv_id == peer2_node_id and recv_peer_id == peer2_id
-        ), f"expected id={peer2_node_id}, peer={peer2_id} but got {recv_id}, {recv_peer_id}"
-
-        assert recv_value == value and recv_expiration == expiration, (
-            f"call_find_value expected {value} (expires by {expiration}) "
-            f"but got {recv_value} (expires by {recv_expiration})"
-        )
-
-        # peer 2 must know about peer 1, but not have a *random* nonexistent value
-        dummy_key = DHTID.generate()
-        empty_item, nodes_found_2 = loop.run_until_complete(protocol.call_find(peer2_id, [dummy_key]))[dummy_key]
-        assert empty_item is None, "Non-existent keys shouldn't have values"
-        (recv_id, recv_peer_id) = next(iter(nodes_found_2.items()))
-        assert (
-            recv_id == peer1_node_id and recv_peer_id == peer1_id
-        ), f"expected id={peer1_node_id}, peer={peer1_id} but got {recv_id}, {recv_peer_id}"
-
-        # cause a non-response by querying a nonexistent peer
-        assert loop.run_until_complete(protocol.call_find(PeerID.from_base58("fakeid"), [key])) is None
-
-        # store/get a dictionary with sub-keys
-        nested_key, subkey1, subkey2 = DHTID.generate(), "foo", "bar"
-        value1, value2 = [random.random(), {"ololo": "pyshpysh"}], "abacaba"
-        assert loop.run_until_complete(
-            protocol.call_store(
-                peer1_id,
-                keys=[nested_key],
-                values=[hivemind.MSGPackSerializer.dumps(value1)],
-                expiration_time=[expiration],
-                subkeys=[subkey1],
-            )
-        )
-        assert loop.run_until_complete(
-            protocol.call_store(
-                peer1_id,
-                keys=[nested_key],
-                values=[hivemind.MSGPackSerializer.dumps(value2)],
-                expiration_time=[expiration + 5],
-                subkeys=[subkey2],
-            )
-        )
-        (recv_dict, recv_expiration), nodes_found = loop.run_until_complete(
-            protocol.call_find(peer1_id, [nested_key])
-        )[nested_key]
-        assert isinstance(recv_dict, DictionaryDHTValue)
-        assert len(recv_dict.data) == 2 and recv_expiration == expiration + 5
-        assert recv_dict.data[subkey1] == (protocol.serializer.dumps(value1), expiration)
-        assert recv_dict.data[subkey2] == (protocol.serializer.dumps(value2), expiration + 5)
-
-        if not client_mode:
-            loop.run_until_complete(p2p.shutdown())
-
-    peer1_proc.terminate()
-    peer2_proc.terminate()
-
-
-@pytest.mark.forked
-def test_empty_table():
-    """Test RPC methods with empty routing table"""
-    peer_id, peer_proc, peer_peer_id, peer_maddrs = launch_protocol_listener()
-
-    loop = asyncio.get_event_loop()
-    p2p = loop.run_until_complete(P2P.create(initial_peers=peer_maddrs))
-    protocol = loop.run_until_complete(
-        DHTProtocol.create(
-            p2p, DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, client_mode=True
-        )
-    )
-
-    key, value, expiration = DHTID.generate(), [random.random(), {"ololo": "pyshpysh"}], get_dht_time() + 1e3
-
-    empty_item, nodes_found = loop.run_until_complete(protocol.call_find(peer_peer_id, [key]))[key]
-    assert empty_item is None and len(nodes_found) == 0
-    assert all(
-        loop.run_until_complete(
-            protocol.call_store(peer_peer_id, [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
-        )
-    ), "peer rejected store"
-
-    (recv_value_bytes, recv_expiration), nodes_found = loop.run_until_complete(
-        protocol.call_find(peer_peer_id, [key])
-    )[key]
-    recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
-    assert len(nodes_found) == 0
-    assert recv_value == value and recv_expiration == expiration
-
-    assert loop.run_until_complete(protocol.call_ping(peer_peer_id)) == peer_id
-    assert loop.run_until_complete(protocol.call_ping(PeerID.from_base58("fakeid"))) is None
-    peer_proc.terminate()
-
-
-@pytest.mark.forked
-def test_dht_node(
+@pytest.mark.asyncio
+async def test_dht_node(
     n_peers: int = 20, n_sequential_peers: int = 5, parallel_rpc: int = 10, bucket_size: int = 5, num_replicas: int = 3
 ):
     # step A: create a swarm of 50 dht nodes in separate processes
@@ -205,26 +32,23 @@ def test_dht_node(
     )
 
     # step B: run 51-st node in this process
-    loop = asyncio.get_event_loop()
     initial_peers = random.choice(swarm_maddrs)
-    me = loop.run_until_complete(
-        DHTNode.create(
-            initial_peers=initial_peers,
-            parallel_rpc=parallel_rpc,
-            bucket_size=bucket_size,
-            num_replicas=num_replicas,
-            cache_refresh_before_expiry=False,
-        )
+    me = await DHTNode.create(
+        initial_peers=initial_peers,
+        parallel_rpc=parallel_rpc,
+        bucket_size=bucket_size,
+        num_replicas=num_replicas,
+        cache_refresh_before_expiry=False,
     )
 
     # test 1: find self
-    nearest = loop.run_until_complete(me.find_nearest_nodes([me.node_id], k_nearest=1))[me.node_id]
+    nearest = (await me.find_nearest_nodes([me.node_id], k_nearest=1))[me.node_id]
     assert len(nearest) == 1 and nearest[me.node_id] == me.peer_id
 
     # test 2: find others
     for _ in range(10):
         ref_peer_id, query_id = random.choice(list(dht.items()))
-        nearest = loop.run_until_complete(me.find_nearest_nodes([query_id], k_nearest=1))[query_id]
+        nearest = (await me.find_nearest_nodes([query_id], k_nearest=1))[query_id]
         assert len(nearest) == 1
         found_node_id, found_peer_id = next(iter(nearest.items()))
         assert found_node_id == query_id and found_peer_id == ref_peer_id
@@ -238,10 +62,8 @@ def test_dht_node(
         query_id = DHTID.generate()
         k_nearest = random.randint(1, 10)
         exclude_self = random.random() > 0.5
-        nearest = loop.run_until_complete(
-            me.find_nearest_nodes([query_id], k_nearest=k_nearest, exclude_self=exclude_self)
-        )[query_id]
-        nearest_nodes = list(nearest)  # keys from ordered dict
+        find_result = await me.find_nearest_nodes([query_id], k_nearest=k_nearest, exclude_self=exclude_self)
+        nearest_nodes = list(find_result[query_id])  # keys from ordered dict
 
         assert len(nearest_nodes) == k_nearest, "beam search must return exactly k_nearest results"
         assert me.node_id not in nearest_nodes or not exclude_self, "if exclude, results shouldn't contain self"
@@ -268,66 +90,63 @@ def test_dht_node(
 
     # test 4: find all nodes
     dummy = DHTID.generate()
-    nearest = loop.run_until_complete(me.find_nearest_nodes([dummy], k_nearest=len(dht) + 100))[dummy]
+    nearest = (await me.find_nearest_nodes([dummy], k_nearest=len(dht) + 100))[dummy]
     assert len(nearest) == len(dht) + 1
     assert len(set.difference(set(nearest.keys()), set(all_node_ids) | {me.node_id})) == 0
 
     # test 5: node without peers
-    detached_node = loop.run_until_complete(DHTNode.create())
-    nearest = loop.run_until_complete(detached_node.find_nearest_nodes([dummy]))[dummy]
+    detached_node = await DHTNode.create()
+    nearest = (await detached_node.find_nearest_nodes([dummy]))[dummy]
     assert len(nearest) == 1 and nearest[detached_node.node_id] == detached_node.peer_id
-    nearest = loop.run_until_complete(detached_node.find_nearest_nodes([dummy], exclude_self=True))[dummy]
+    nearest = (await detached_node.find_nearest_nodes([dummy], exclude_self=True))[dummy]
     assert len(nearest) == 0
 
     # test 6: store and get value
     true_time = get_dht_time() + 1200
-    assert loop.run_until_complete(me.store("mykey", ["Value", 10], true_time))
+    assert await me.store("mykey", ["Value", 10], true_time)
 
     initial_peers = random.choice(swarm_maddrs)
-    that_guy = loop.run_until_complete(
-        DHTNode.create(
-            initial_peers=initial_peers,
-            parallel_rpc=parallel_rpc,
-            cache_refresh_before_expiry=False,
-            cache_locally=False,
-        )
+    that_guy = await DHTNode.create(
+        initial_peers=initial_peers,
+        parallel_rpc=parallel_rpc,
+        cache_refresh_before_expiry=False,
+        cache_locally=False,
     )
 
     for node in [me, that_guy]:
-        val, expiration_time = loop.run_until_complete(node.get("mykey"))
+        val, expiration_time = await node.get("mykey")
         assert val == ["Value", 10], "Wrong value"
         assert expiration_time == true_time, f"Wrong time"
 
-    assert loop.run_until_complete(detached_node.get("mykey")) is None
+    assert not await detached_node.get("mykey")
 
     # test 7: bulk store and bulk get
     keys = "foo", "bar", "baz", "zzz"
     values = 3, 2, "batman", [1, 2, 3]
-    store_ok = loop.run_until_complete(me.store_many(keys, values, expiration_time=get_dht_time() + 999))
+    store_ok = await me.store_many(keys, values, expiration_time=get_dht_time() + 999)
     assert all(store_ok.values()), "failed to store one or more keys"
-    response = loop.run_until_complete(me.get_many(keys[::-1]))
+    response = await me.get_many(keys[::-1])
     for key, value in zip(keys, values):
         assert key in response and response[key][0] == value
 
     # test 8: store dictionaries as values (with sub-keys)
     upper_key, subkey1, subkey2, subkey3 = "ololo", "k1", "k2", "k3"
     now = get_dht_time()
-    assert loop.run_until_complete(me.store(upper_key, subkey=subkey1, value=123, expiration_time=now + 10))
-    assert loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=456, expiration_time=now + 20))
+    assert await me.store(upper_key, subkey=subkey1, value=123, expiration_time=now + 10)
+    assert await me.store(upper_key, subkey=subkey2, value=456, expiration_time=now + 20)
     for node in [that_guy, me]:
-        value, time = loop.run_until_complete(node.get(upper_key))
+        value, time = await node.get(upper_key)
         assert isinstance(value, dict) and time == now + 20
         assert value[subkey1] == (123, now + 10)
         assert value[subkey2] == (456, now + 20)
         assert len(value) == 2
 
-    assert not loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=345, expiration_time=now + 10))
-    assert loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=567, expiration_time=now + 30))
-    assert loop.run_until_complete(me.store(upper_key, subkey=subkey3, value=890, expiration_time=now + 50))
-    loop.run_until_complete(asyncio.sleep(0.1))  # wait for cache to refresh
+    assert not await me.store(upper_key, subkey=subkey2, value=345, expiration_time=now + 10)
+    assert await me.store(upper_key, subkey=subkey2, value=567, expiration_time=now + 30)
+    assert await me.store(upper_key, subkey=subkey3, value=890, expiration_time=now + 50)
 
     for node in [that_guy, me]:
-        value, time = loop.run_until_complete(node.get(upper_key))
+        value, time = await node.get(upper_key, latest=True)
         assert isinstance(value, dict) and time == now + 50, (value, time)
         assert value[subkey1] == (123, now + 10)
         assert value[subkey2] == (567, now + 30)
@@ -337,7 +156,7 @@ def test_dht_node(
     for proc in processes:
         proc.terminate()
     # The nodes don't own their hivemind.p2p.P2P instances, so we shutdown them separately
-    loop.run_until_complete(asyncio.wait([node.shutdown() for node in [me, detached_node, that_guy]]))
+    await asyncio.gather(me.shutdown(), that_guy.shutdown(), detached_node.shutdown())
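The async DHTNode API used throughout this test can also be driven directly; a single standalone node is enough to show store/get, including dictionary values with sub-keys (key names and expirations are illustrative; pass initial_peers=... to DHTNode.create to join an existing swarm):

    import asyncio
    from hivemind import get_dht_time
    from hivemind.dht.node import DHTNode

    async def main():
        node = await DHTNode.create()

        # plain key-value store with an expiration time
        assert await node.store("mykey", ["Value", 10], expiration_time=get_dht_time() + 600)
        value, expiration = await node.get("mykey")

        # dictionary value addressed by sub-keys; the latest sub-key expiration wins
        now = get_dht_time()
        assert await node.store("metrics", subkey="loss", value=0.5, expiration_time=now + 60)
        assert await node.store("metrics", subkey="step", value=100, expiration_time=now + 90)
        metrics, metrics_expiration = await node.get("metrics", latest=True)

        await node.shutdown()

    asyncio.run(main())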
 
 
 @pytest.mark.forked

tests/test_dht_protocol.py  +163 -0

@@ -0,0 +1,163 @@
+import asyncio
+import multiprocessing as mp
+import random
+import signal
+from typing import List, Sequence, Tuple
+
+import pytest
+from multiaddr import Multiaddr
+
+import hivemind
+from hivemind import P2P, PeerID, get_dht_time, get_logger
+from hivemind.dht import DHTID
+from hivemind.dht.protocol import DHTProtocol
+from hivemind.dht.storage import DictionaryDHTValue
+
+logger = get_logger(__name__)
+
+
+def maddrs_to_peer_ids(maddrs: List[Multiaddr]) -> List[PeerID]:
+    return list({PeerID.from_base58(maddr["p2p"]) for maddr in maddrs})
+
+
+def run_protocol_listener(
+    dhtid: DHTID, maddr_conn: mp.connection.Connection, initial_peers: Sequence[Multiaddr]
+) -> None:
+    loop = asyncio.new_event_loop()
+    asyncio.set_event_loop(loop)
+
+    p2p = loop.run_until_complete(P2P.create(initial_peers=initial_peers))
+    visible_maddrs = loop.run_until_complete(p2p.get_visible_maddrs())
+
+    protocol = loop.run_until_complete(
+        DHTProtocol.create(p2p, dhtid, bucket_size=20, depth_modulo=5, num_replicas=3, wait_timeout=5)
+    )
+
+    logger.info(f"Started peer id={protocol.node_id} visible_maddrs={visible_maddrs}")
+
+    for peer_id in maddrs_to_peer_ids(initial_peers):
+        loop.run_until_complete(protocol.call_ping(peer_id))
+
+    maddr_conn.send((p2p.peer_id, visible_maddrs))
+
+    async def shutdown():
+        await p2p.shutdown()
+        logger.info(f"Finished peer id={protocol.node_id} maddrs={visible_maddrs}")
+        loop.stop()
+
+    loop.add_signal_handler(signal.SIGTERM, lambda: loop.create_task(shutdown()))
+    loop.run_forever()
+
+
+def launch_protocol_listener(
+    initial_peers: Sequence[Multiaddr] = (),
+) -> Tuple[DHTID, mp.Process, PeerID, List[Multiaddr]]:
+    remote_conn, local_conn = mp.Pipe()
+    dht_id = DHTID.generate()
+    process = mp.Process(target=run_protocol_listener, args=(dht_id, remote_conn, initial_peers), daemon=True)
+    process.start()
+    peer_id, visible_maddrs = local_conn.recv()
+
+    return dht_id, process, peer_id, visible_maddrs
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_dht_protocol():
+    peer1_node_id, peer1_proc, peer1_id, peer1_maddrs = launch_protocol_listener()
+    peer2_node_id, peer2_proc, peer2_id, _ = launch_protocol_listener(initial_peers=peer1_maddrs)
+
+    for client_mode in [True, False]:  # note: order matters, this test assumes that first run uses client mode
+        peer_id = DHTID.generate()
+        p2p = await P2P.create(initial_peers=peer1_maddrs)
+        protocol = await DHTProtocol.create(
+            p2p, peer_id, bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, client_mode=client_mode
+        )
+        logger.info(f"Self id={protocol.node_id}")
+
+        assert peer1_node_id == await protocol.call_ping(peer1_id)
+
+        key, value, expiration = DHTID.generate(), [random.random(), {"ololo": "pyshpysh"}], get_dht_time() + 1e3
+        store_ok = await protocol.call_store(peer1_id, [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
+        assert all(store_ok), "DHT rejected a trivial store"
+
+        # peer 1 must know about peer 2
+        (recv_value_bytes, recv_expiration), nodes_found = (await protocol.call_find(peer1_id, [key]))[key]
+        recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
+        (recv_id, recv_peer_id) = next(iter(nodes_found.items()))
+        assert (
+            recv_id == peer2_node_id and recv_peer_id == peer2_id
+        ), f"expected id={peer2_node_id}, peer={peer2_id} but got {recv_id}, {recv_peer_id}"
+
+        assert recv_value == value and recv_expiration == expiration, (
+            f"call_find_value expected {value} (expires by {expiration}) "
+            f"but got {recv_value} (expires by {recv_expiration})"
+        )
+
+        # peer 2 must know about peer 1, but not have a *random* nonexistent value
+        dummy_key = DHTID.generate()
+        empty_item, nodes_found_2 = (await protocol.call_find(peer2_id, [dummy_key]))[dummy_key]
+        assert empty_item is None, "Non-existent keys shouldn't have values"
+        (recv_id, recv_peer_id) = next(iter(nodes_found_2.items()))
+        assert (
+            recv_id == peer1_node_id and recv_peer_id == peer1_id
+        ), f"expected id={peer1_node_id}, peer={peer1_id} but got {recv_id}, {recv_peer_id}"
+
+        # cause a non-response by querying a nonexistent peer
+        assert not await protocol.call_find(PeerID.from_base58("fakeid"), [key])
+
+        # store/get a dictionary with sub-keys
+        nested_key, subkey1, subkey2 = DHTID.generate(), "foo", "bar"
+        value1, value2 = [random.random(), {"ololo": "pyshpysh"}], "abacaba"
+        assert await protocol.call_store(
+            peer1_id,
+            keys=[nested_key],
+            values=[hivemind.MSGPackSerializer.dumps(value1)],
+            expiration_time=[expiration],
+            subkeys=[subkey1],
+        )
+        assert await protocol.call_store(
+            peer1_id,
+            keys=[nested_key],
+            values=[hivemind.MSGPackSerializer.dumps(value2)],
+            expiration_time=[expiration + 5],
+            subkeys=[subkey2],
+        )
+        (recv_dict, recv_expiration), nodes_found = (await protocol.call_find(peer1_id, [nested_key]))[nested_key]
+        assert isinstance(recv_dict, DictionaryDHTValue)
+        assert len(recv_dict.data) == 2 and recv_expiration == expiration + 5
+        assert recv_dict.data[subkey1] == (protocol.serializer.dumps(value1), expiration)
+        assert recv_dict.data[subkey2] == (protocol.serializer.dumps(value2), expiration + 5)
+
+        if not client_mode:
+            await p2p.shutdown()
+
+    peer1_proc.terminate()
+    peer2_proc.terminate()
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_empty_table():
+    """Test RPC methods with empty routing table"""
+    peer_id, peer_proc, peer_peer_id, peer_maddrs = launch_protocol_listener()
+
+    p2p = await P2P.create(initial_peers=peer_maddrs)
+    protocol = await DHTProtocol.create(
+        p2p, DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, client_mode=True
+    )
+
+    key, value, expiration = DHTID.generate(), [random.random(), {"ololo": "pyshpysh"}], get_dht_time() + 1e3
+
+    empty_item, nodes_found = (await protocol.call_find(peer_peer_id, [key]))[key]
+    assert empty_item is None and len(nodes_found) == 0
+    assert all(await protocol.call_store(peer_peer_id, [key], [hivemind.MSGPackSerializer.dumps(value)], expiration))
+
+    (recv_value_bytes, recv_expiration), nodes_found = (await protocol.call_find(peer_peer_id, [key]))[key]
+    recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
+    assert len(nodes_found) == 0
+    assert recv_value == value and recv_expiration == expiration
+
+    assert peer_id == await protocol.call_ping(peer_peer_id)
+    assert not await protocol.call_ping(PeerID.from_base58("fakeid"))
+    peer_proc.terminate()

tests/test_moe.py  +31 -34

@@ -3,9 +3,12 @@ import numpy as np
 import pytest
 import torch
 
-import hivemind
-from hivemind.moe.client.expert import DUMMY
-from hivemind.moe.server import background_server, declare_experts, layers
+from hivemind.dht import DHT
+from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
+from hivemind.moe.client.moe import DUMMY, _RemoteCallMany
+from hivemind.moe.server import ExpertBackend, Server, background_server, declare_experts
+from hivemind.moe.server.layers import name_to_block
+from hivemind.utils.tensor_descr import BatchTensorDescriptor
 
 
 @pytest.mark.forked
@@ -16,11 +19,9 @@ def test_moe():
     with background_server(
         expert_uids=all_expert_uids, device="cpu", expert_cls="ffn", num_handlers=1, hidden_dim=16
     ) as (server_endpoint, dht_maddrs):
-        dht = hivemind.DHT(start=True, initial_peers=dht_maddrs)
+        dht = DHT(start=True, initial_peers=dht_maddrs)
 
-        dmoe = hivemind.RemoteMixtureOfExperts(
-            in_features=16, grid_size=(4, 4, 4), dht=dht, k_best=3, uid_prefix="ffn."
-        )
+        dmoe = RemoteMixtureOfExperts(in_features=16, grid_size=(4, 4, 4), dht=dht, k_best=3, uid_prefix="ffn.")
 
         for i in range(3):
             out = dmoe(torch.randn(10, 16))
@@ -35,9 +36,9 @@ def test_no_experts():
     with background_server(
         expert_uids=all_expert_uids, device="cpu", expert_cls="nop_delay", num_handlers=1, hidden_dim=16
     ) as (server_endpoint, dht_maddrs):
-        dht = hivemind.DHT(start=True, initial_peers=dht_maddrs)
+        dht = DHT(start=True, initial_peers=dht_maddrs)
 
-        dmoe = hivemind.RemoteSwitchMixtureOfExperts(
+        dmoe = RemoteSwitchMixtureOfExperts(
             in_features=16,
             grid_size=(4, 4, 4),
             dht=dht,
@@ -74,10 +75,10 @@ def test_call_many(hidden_dim=16):
     ) as (server_endpoint, _):
         inputs = torch.randn(4, hidden_dim, requires_grad=True)
         inputs_clone = inputs.clone().detach().requires_grad_(True)
-        e0, e1, e2, e3, e4 = [hivemind.RemoteExpert(f"expert.{i}", server_endpoint) for i in range(5)]
-        e5 = hivemind.RemoteExpert(f"thisshouldnotexist", "127.0.0.1:80")
+        e0, e1, e2, e3, e4 = [RemoteExpert(f"expert.{i}", server_endpoint) for i in range(5)]
+        e5 = RemoteExpert(f"thisshouldnotexist", "127.0.0.1:80")
 
-        mask, expert_outputs = hivemind.moe.client.moe._RemoteCallMany.apply(
+        mask, expert_outputs = _RemoteCallMany.apply(
             DUMMY,
             [[e0, e1, e2], [e2, e4], [e1, e5, e3], []],
             k_min,
@@ -130,8 +131,8 @@ def test_remote_module_call(hidden_dim=16):
         optim_cls=None,
         no_dht=True,
     ) as (server_endpoint, _):
-        real_expert = hivemind.RemoteExpert("expert.0", server_endpoint)
-        fake_expert = hivemind.RemoteExpert("oiasfjiasjf", server_endpoint)
+        real_expert = RemoteExpert("expert.0", server_endpoint)
+        fake_expert = RemoteExpert("oiasfjiasjf", server_endpoint)
 
         out1 = real_expert(torch.randn(1, hidden_dim))
         assert out1.shape == (1, hidden_dim)
@@ -152,12 +153,10 @@ def test_remote_module_call(hidden_dim=16):
 @pytest.mark.forked
 def test_beam_search_correctness():
     all_expert_uids = [f"ffn.{5 + i}.{10 + j}.{15 + k}" for i in range(10) for j in range(10) for k in range(10)]
-    dht = hivemind.DHT(start=True)
+    dht = DHT(start=True)
     assert all(declare_experts(dht, all_expert_uids, endpoint="fake-endpoint"))
 
-    dmoe = hivemind.RemoteMixtureOfExperts(
-        in_features=32, grid_size=(32, 32, 32), dht=dht, k_best=4, uid_prefix="ffn."
-    )
+    dmoe = RemoteMixtureOfExperts(in_features=32, grid_size=(32, 32, 32), dht=dht, k_best=4, uid_prefix="ffn.")
 
     for i in range(25):
         input = torch.randn(32)
@@ -174,7 +173,7 @@ def test_beam_search_correctness():
         # reference: independently find :beam_size: best experts with exhaustive search
         all_scores = dmoe.compute_expert_scores(
             [dim_scores.unsqueeze(0) for dim_scores in grid_scores],
-            [[hivemind.RemoteExpert(uid, "") for uid in all_expert_uids]],
+            [[RemoteExpert(uid, "") for uid in all_expert_uids]],
         )[0]
         true_best_scores = sorted(all_scores.cpu().detach().numpy(), reverse=True)[: len(chosen_experts)]
 
@@ -197,7 +196,7 @@ def test_determinism(hidden_dim=16):
         optim_cls=None,
         no_dht=True,
     ) as (server_endpoint, _):
-        expert = hivemind.RemoteExpert(uid=f"expert.0", endpoint=server_endpoint)
+        expert = RemoteExpert(uid=f"expert.0", endpoint=server_endpoint)
 
         out = expert(xx, mask)
         out_rerun = expert(xx, mask)
@@ -212,8 +211,8 @@ def test_determinism(hidden_dim=16):
 @pytest.mark.forked
 def test_compute_expert_scores():
     try:
-        dht = hivemind.DHT(start=True)
-        moe = hivemind.moe.RemoteMixtureOfExperts(
+        dht = DHT(start=True)
+        moe = RemoteMixtureOfExperts(
             dht=dht, in_features=16, grid_size=(40,), k_best=4, k_min=1, timeout_after_k_min=1, uid_prefix="expert."
         )
         gx, gy = torch.randn(4, 5, requires_grad=True), torch.randn(4, 3, requires_grad=True)
@@ -221,13 +220,11 @@ def test_compute_expert_scores():
         jj = [[2, 2, 1], [0, 1, 2, 0, 1], [0], [1, 2]]
         batch_experts = [
             [
-                hivemind.RemoteExpert(
-                    uid=f"expert.{ii[batch_i][expert_i]}.{jj[batch_i][expert_i]}", endpoint="[::]:1337"
-                )
+                RemoteExpert(uid=f"expert.{ii[batch_i][expert_i]}.{jj[batch_i][expert_i]}", endpoint="[::]:1337")
                 for expert_i in range(len(ii[batch_i]))
             ]
             for batch_i in range(len(ii))
-        ]  # note: these experts do not exists on server, we use them only to test moe compute_expert_scores
+        ]  # note: these experts do not exist on the server; we use them only to test compute_expert_scores
         logits = moe.compute_expert_scores([gx, gy], batch_experts)
         torch.softmax(logits, dim=-1).norm(dim=-1).mean().backward()
         assert gx.grad.norm().item() > 0 and gy.grad.norm().item(), "compute_expert_scores didn't backprop"
@@ -247,25 +244,25 @@ def test_client_anomaly_detection():
 
     experts = {}
     for i in range(4):
-        expert = layers.name_to_block["ffn"](HID_DIM)
-        experts[f"expert.{i}"] = hivemind.ExpertBackend(
+        expert = name_to_block["ffn"](HID_DIM)
+        experts[f"expert.{i}"] = ExpertBackend(
             name=f"expert.{i}",
             expert=expert,
             optimizer=torch.optim.Adam(expert.parameters()),
-            args_schema=(hivemind.BatchTensorDescriptor(HID_DIM),),
-            outputs_schema=hivemind.BatchTensorDescriptor(HID_DIM),
+            args_schema=(BatchTensorDescriptor(HID_DIM),),
+            outputs_schema=BatchTensorDescriptor(HID_DIM),
             max_batch_size=16,
         )
 
     experts["expert.3"].expert.ffn.weight.data[0, 0] = float("nan")
 
-    dht = hivemind.DHT(start=True)
-    server = hivemind.moe.Server(dht, experts, num_connection_handlers=1)
+    dht = DHT(start=True)
+    server = Server(dht, experts, num_connection_handlers=1)
     server.start()
     try:
         server.ready.wait()
 
-        dmoe = hivemind.RemoteMixtureOfExperts(
+        dmoe = RemoteMixtureOfExperts(
             in_features=16, grid_size=(3,), dht=dht, k_best=3, uid_prefix="expert.", detect_anomalies=True
         )
 
@@ -282,7 +279,7 @@ def test_client_anomaly_detection():
         with pytest.raises(ValueError):
             inf_loss.backward()
 
-        dmoe = hivemind.RemoteMixtureOfExperts(
+        dmoe = RemoteMixtureOfExperts(
             in_features=16, grid_size=(4,), dht=dht, k_best=4, uid_prefix="expert.", detect_anomalies=True
         )
         output = dmoe(input)
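To recap the client-side API these tests rely on: background_server spins up a server with the requested experts plus a DHT, and RemoteMixtureOfExperts routes each input row to its best experts. A sketch mirroring test_moe above (grid shape and dimensions taken from the test):

    import torch
    from hivemind.dht import DHT
    from hivemind.moe.client import RemoteMixtureOfExperts
    from hivemind.moe.server import background_server

    expert_uids = [f"ffn.{i}.{j}.{k}" for i in range(4) for j in range(4) for k in range(4)]
    with background_server(
        expert_uids=expert_uids, device="cpu", expert_cls="ffn", num_handlers=1, hidden_dim=16
    ) as (server_endpoint, dht_maddrs):
        dht = DHT(start=True, initial_peers=dht_maddrs)
        dmoe = RemoteMixtureOfExperts(in_features=16, grid_size=(4, 4, 4), dht=dht, k_best=3, uid_prefix="ffn.")

        out = dmoe(torch.randn(10, 16))   # each row is dispatched to its top-3 experts and mixed back
        out.sum().backward()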

tests/test_optimizer.py  +385 -0

@@ -0,0 +1,385 @@
+import ctypes
+import multiprocessing as mp
+import time
+from functools import partial
+
+import numpy as np
+import pytest
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import hivemind
+from hivemind.averaging.control import AveragingStage
+from hivemind.optim.grad_averager import GradientAverager
+from hivemind.optim.optimizer import Optimizer
+from hivemind.optim.progress_tracker import ProgressTracker
+from hivemind.optim.state_averager import TrainingStateAverager
+from hivemind.utils.crypto import RSAPrivateKey
+
+
+@pytest.mark.forked
+def test_grad_averager():
+    dht1 = hivemind.DHT(start=True)
+    model1 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(3))})
+    averager1 = GradientAverager(
+        model1.parameters(), dht=dht1, prefix="test", target_group_size=2, reuse_grad_buffers=False, start=True
+    )
+
+    dht2 = hivemind.DHT(start=True, initial_peers=dht1.get_visible_maddrs())
+    model2 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(3))})
+    averager2 = GradientAverager(
+        model2.parameters(), dht=dht2, prefix="test", target_group_size=2, reuse_grad_buffers=True, start=True
+    )
+
+    control1 = averager1.schedule_step(hivemind.get_dht_time() + 5)
+    control2 = averager2.schedule_step(hivemind.get_dht_time() + 5)
+
+    for i in range(10):
+        time.sleep(0.1)
+        if i % 3 == 0:
+            loss1 = F.mse_loss(model1.w, torch.ones(3))
+            loss1.backward()
+            averager1.accumulate_grads_(batch_size=2)  # total: 4 times * 2 samples = 8
+            model1.zero_grad()
+        else:
+            loss2 = F.mse_loss(model2.w, -torch.ones(3))
+            loss2.backward()
+            averager2.accumulate_grads_(batch_size=3)  # total: 6 times * 3 samples = 18
+            # note: we do not call zero grad here because reuse_grad_buffers=True
+
+    assert control1.stage == control2.stage == AveragingStage.AWAITING_TRIGGER
+    peer1_samples, peer1_times, peer2_samples, peer2_times = 8, 4, 18, 6
+    assert averager1.local_samples_accumulated == peer1_samples and averager1.local_times_accumulated == peer1_times
+    ref_grads1 = torch.full((3,), -2 * 1 / 3 * averager1.local_times_accumulated)
+    assert torch.allclose(next(averager1._grad_accumulators()), ref_grads1)
+
+    assert averager2.local_samples_accumulated == peer2_samples and averager2.local_times_accumulated == peer2_times
+    ref_grads2 = torch.full((3,), 2 * 1 / 3 * averager2.local_times_accumulated)
+    assert torch.allclose(next(averager2._grad_accumulators()), ref_grads2)
+
+    averager1.step(control=control1, wait=False)
+    averager2.step(control=control2, wait=False)
+    for step in (control1, control2):
+        step.result()  # wait for all-reduce to finish
+
+    peer1_weight = peer1_samples / (peer1_samples + peer2_samples)
+    peer2_weight = peer2_samples / (peer1_samples + peer2_samples)
+    ref_average = peer1_weight * (ref_grads1 / peer1_times) + peer2_weight * (ref_grads2 / peer2_times)
+    with averager1.use_averaged_gradients():
+        assert torch.allclose(model1.w.grad, ref_average)
+    with averager2.use_averaged_gradients():
+        assert torch.allclose(model2.w.grad, ref_average)
+
+    # after no longer use_averaged_gradients
+    assert not torch.allclose(model1.w.grad, ref_average)
+    assert not torch.allclose(model2.w.grad, ref_average)
+
+
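The accumulate-then-average pattern above is the intended usage of GradientAverager; a compressed two-peer sketch (shapes, batch sizes, and the scheduling offset are illustrative):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import hivemind
    from hivemind.optim.grad_averager import GradientAverager

    dht1 = hivemind.DHT(start=True)
    dht2 = hivemind.DHT(initial_peers=dht1.get_visible_maddrs(), start=True)
    model1, model2 = nn.Linear(3, 1), nn.Linear(3, 1)
    avgr1 = GradientAverager(model1.parameters(), dht=dht1, prefix="demo", target_group_size=2, start=True)
    avgr2 = GradientAverager(model2.parameters(), dht=dht2, prefix="demo", target_group_size=2, start=True)

    # schedule the averaging round up front so matchmaking overlaps with local compute
    control1 = avgr1.schedule_step(hivemind.get_dht_time() + 5)
    control2 = avgr2.schedule_step(hivemind.get_dht_time() + 5)

    for model, avgr in ((model1, avgr1), (model2, avgr2)):
        F.mse_loss(model(torch.randn(4, 3)), torch.zeros(4, 1)).backward()
        avgr.accumulate_grads_(batch_size=4)   # batch_size weights this peer in the average
        model.zero_grad()

    avgr1.step(control=control1, wait=False)
    avgr2.step(control=control2, wait=False)
    for control in (control1, control2):
        control.result()                       # wait for the all-reduce to finish

    with avgr1.use_averaged_gradients():       # temporarily exposes the averaged grads in .grad
        print(model1.weight.grad)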
+@pytest.mark.forked
+@pytest.mark.parametrize(
+    "offload_optimizer, reuse_tensors, sync_epoch_when_averaging",
+    [(False, False, False), (True, True, False), (True, False, False), (False, True, True), (True, False, True)],
+)
+def test_state_averager(offload_optimizer: bool, reuse_tensors: bool, sync_epoch_when_averaging: bool):
+    dht1 = hivemind.DHT(start=True)
+    dht2 = hivemind.DHT(initial_peers=dht1.get_visible_maddrs(), start=True)
+
+    torch.manual_seed(1337)
+    torch.use_deterministic_algorithms(True)
+    # note: use_deterministic_algorithms does not affect further tests because this test is forked
+
+    model1 = nn.Linear(2, 3)
+    model2 = nn.Linear(2, 3)
+
+    extras1 = (torch.randn(2, 2), -torch.rand(1))
+    extras2 = (-torch.randn(2, 2), torch.rand(1))
+
+    common_kwargs = dict(
+        optimizer=partial(torch.optim.Adam, lr=0.1, betas=(0.9, 0.9)),
+        scheduler=partial(torch.optim.lr_scheduler.LambdaLR, lr_lambda=lambda t: 1.0 / max(1, t)),
+        sync_epoch_when_averaging=sync_epoch_when_averaging,
+        average_opt_statistics=("exp_avg_sq",),
+        offload_optimizer=offload_optimizer,
+        reuse_tensors=reuse_tensors,
+        target_group_size=2,
+        prefix="my_exp",
+    )
+
+    avgr1 = TrainingStateAverager(
+        dht=dht1, params=model1.parameters(), extra_tensors=extras1, start=True, **common_kwargs
+    )
+    avgr2 = TrainingStateAverager(
+        dht=dht2, params=model2.parameters(), extra_tensors=extras2, start=True, **common_kwargs
+    )
+
+    x = torch.ones(2)
+
+    for step in range(20):
+        F.mse_loss(model1(x), torch.ones(3)).mul(2).backward()
+        avgr1.step(optimizer_step=True, zero_grad=True, averaging_round=(step == 10), delay_averaging=True)
+
+        F.mse_loss(model2(x), -torch.ones(3)).backward()
+        avgr2.step(optimizer_step=True, zero_grad=True, averaging_round=(step == 10), delay_averaging=False)
+
+    assert torch.all(model1.weight.grad == 0) and torch.all(model2.weight.grad == 0), "zero grad did not trigger"
+    assert model1(x).mean() > 0.5 and model2(x).mean() < -0.5, "models did not train properly"
+    assert torch.allclose(extras1[0], extras2[0]), "first extra tensors were not averaged"
+    assert torch.allclose(extras1[1], extras2[1]), "second extra tensors were not averaged"
+
+    stats1 = avgr1.optimizer.state_dict()["state"][0]["exp_avg_sq"].clone()
+    stats2 = avgr2.optimizer.state_dict()["state"][0]["exp_avg_sq"].clone()
+    assert not torch.allclose(stats1, stats2)
+
+    avgr1.step(increment_epoch=True)
+
+    avgr1.step(increment_epoch=True, averaging_round=True, delay_averaging=True)
+    avgr2.step(increment_epoch=True, averaging_round=True, delay_averaging=True)
+
+    avgr1.step(wait_for_delayed_updates=True)
+    avgr2.step(wait_for_delayed_updates=True)
+
+    assert torch.allclose(model1(x), model2(x)), "model parameters were not averaged correctly"
+    assert torch.allclose(avgr1.optimizer.state_dict()["state"][0]["exp_avg_sq"], (stats1 + stats2) / 2)
+    assert torch.allclose(avgr2.optimizer.state_dict()["state"][0]["exp_avg_sq"], (stats1 + stats2) / 2)
+    assert avgr1.local_epoch == 2
+    assert avgr2.local_epoch == (2 if sync_epoch_when_averaging else 1)
+
+
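A single-peer sketch of TrainingStateAverager, for orientation: it owns the optimizer and scheduler, and one step() call with different flags decides what happens locally versus across the group (hyperparameters are illustrative):

    from functools import partial

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import hivemind
    from hivemind.optim.state_averager import TrainingStateAverager

    dht = hivemind.DHT(start=True)
    model = nn.Linear(2, 3)
    avgr = TrainingStateAverager(
        dht=dht,
        params=model.parameters(),
        optimizer=partial(torch.optim.Adam, lr=0.1),
        scheduler=partial(torch.optim.lr_scheduler.LambdaLR, lr_lambda=lambda t: 1.0 / max(1, t)),
        offload_optimizer=True,   # keep optimizer statistics in host memory
        prefix="my_exp",
        start=True,
    )

    F.mse_loss(model(torch.ones(2)), torch.ones(3)).backward()
    # local step only; add averaging_round=True to also average parameters/statistics with group peers
    avgr.step(optimizer_step=True, zero_grad=True)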
+@pytest.mark.forked
+def test_load_state_from_peers():
+    dht1 = hivemind.DHT(start=True)
+    dht2 = hivemind.DHT(initial_peers=dht1.get_visible_maddrs(), start=True)
+
+    model1 = nn.Linear(2, 3)
+    model2 = nn.Linear(2, 3)
+
+    common_kwargs = dict(
+        optimizer=partial(torch.optim.SGD, lr=0.1),
+        scheduler=partial(torch.optim.lr_scheduler.LambdaLR, lr_lambda=lambda t: 1.0 / max(1, t)),
+        target_group_size=2,
+        prefix="my_exp",
+    )
+
+    avgr1 = TrainingStateAverager(
+        dht=dht1, params=model1.parameters(), allow_state_sharing=False, start=True, **common_kwargs
+    )
+
+    avgr2 = TrainingStateAverager(dht=dht2, params=model2.parameters(), start=True, **common_kwargs)
+
+    avgr2.local_epoch = 1337
+    model2.weight.data[...] = 42
+    time.sleep(0.1)
+
+    avgr1.load_state_from_peers()
+    assert avgr1.local_epoch == 1337
+    assert torch.all(model1.weight == 42).item()
+    assert np.allclose(avgr1.optimizer.param_groups[0]["lr"], 0.1 / 1337)
+
+
+@pytest.mark.forked
+def test_progress_tracker():
+    # note to a curious reader: no, you cannot reduce the timings without compromising realism or stability
+    prefix = "my_exp"
+    target_batch_size = 256
+    dht_root = hivemind.DHT(start=True)
+    barrier = mp.Barrier(parties=5)
+    delayed_start_evt = mp.Event()
+    finished_evt = mp.Event()
+    emas = mp.Array(ctypes.c_double, 5)
+
+    def run_worker(index: int, batch_size: int, period: float, **kwargs):
+        dht = hivemind.DHT(initial_peers=dht_root.get_visible_maddrs(), start=True)
+        tracker = ProgressTracker(
+            dht,
+            prefix,
+            target_batch_size,
+            start=True,
+            min_refresh_period=0.1,
+            default_refresh_period=0.2,
+            max_refresh_period=0.5,
+            private_key=RSAPrivateKey(),
+            **kwargs,
+        )
+
+        barrier.wait()
+        if index == 4:
+            delayed_start_evt.wait()
+
+        local_epoch = 2 if index == 4 else 0
+        samples_accumulated = 0
+
+        while True:
+            time.sleep(period)
+            if finished_evt.is_set():
+                break
+
+            samples_accumulated += batch_size
+            tracker.report_local_progress(local_epoch, samples_accumulated)
+
+            if tracker.ready_to_update_epoch:
+                if index == 4 and local_epoch >= 4:
+                    time.sleep(0.5)
+                    break
+
+                with tracker.pause_updates():
+                    local_epoch = tracker.update_epoch(local_epoch + 1)
+                    samples_accumulated = 0
+
+        emas[index] = tracker.performance_ema.samples_per_second
+        tracker.shutdown()
+        dht.shutdown()
+
+    workers = [
+        mp.Process(target=run_worker, kwargs=dict(index=1, batch_size=12, period=0.6)),
+        mp.Process(target=run_worker, kwargs=dict(index=2, batch_size=16, period=0.5)),
+        mp.Process(target=run_worker, kwargs=dict(index=3, batch_size=24, period=0.4)),
+        mp.Process(target=run_worker, kwargs=dict(index=4, batch_size=64, period=0.4)),
+    ]
+    for worker in workers:
+        worker.start()
+
+    tracker = ProgressTracker(
+        dht_root,
+        prefix,
+        target_batch_size,
+        start=True,
+        min_refresh_period=0.1,
+        default_refresh_period=0.2,
+        max_refresh_period=0.5,
+    )
+    barrier.wait()
+
+    local_epoch = 0
+    last_timestamp = hivemind.get_dht_time()
+    step_time_deltas = []
+
+    while local_epoch < 6:
+        time.sleep(0.1)
+
+        if tracker.ready_to_update_epoch:
+            with tracker.pause_updates():
+                local_epoch = tracker.update_epoch(local_epoch + 1)
+
+            time_delta = hivemind.get_dht_time() - last_timestamp
+            if local_epoch == 2:
+                delayed_start_evt.set()
+
+            last_timestamp = hivemind.get_dht_time()
+            step_time_deltas.append(time_delta)
+
+    finished_evt.set()
+    for worker in workers:
+        worker.join()
+
+    tracker.shutdown()
+    dht_root.shutdown()
+    assert not tracker.is_alive()
+
+    mean_step_time = sum(step_time_deltas) / len(step_time_deltas)
+    for i in (0, 1, 5):  # Without the 4th worker (the fastest one)
+        assert 1.05 * mean_step_time < step_time_deltas[i] < 2.0 * mean_step_time
+    for i in (2, 3, 4):  # With the 4th worker
+        assert 0.5 * mean_step_time < step_time_deltas[i] < 0.95 * mean_step_time
+    assert emas[1] < emas[2] < emas[3] < emas[4]
+    assert tracker.performance_ema.samples_per_second < 1e-9
+
+
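The tracker loop in the test follows the intended usage: report local progress, and advance the epoch once the swarm has accumulated the target batch size. A single-peer sketch (batch size and sleeps are illustrative):

    import time
    import hivemind
    from hivemind.optim.progress_tracker import ProgressTracker

    dht = hivemind.DHT(start=True)
    tracker = ProgressTracker(dht, "my_exp", 256, start=True)   # prefix and target_batch_size

    local_epoch, samples_accumulated = 0, 0
    for _ in range(20):
        time.sleep(0.3)
        samples_accumulated += 32
        tracker.report_local_progress(local_epoch, samples_accumulated)   # gossiped via the DHT

        if tracker.ready_to_update_epoch:   # swarm-wide progress reached target_batch_size
            with tracker.pause_updates():
                local_epoch = tracker.update_epoch(local_epoch + 1)
                samples_accumulated = 0

    tracker.shutdown()
    dht.shutdown()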
+@pytest.mark.forked
+def test_optimizer(
+    num_peers: int = 1,
+    num_clients: int = 0,
+    target_batch_size: int = 32,
+    total_epochs: int = 3,
+    reuse_grad_buffers: bool = True,
+    delay_grad_averaging: bool = True,
+    delay_optimizer_step: bool = True,
+    average_state_every: int = 1,
+):
+    dht = hivemind.DHT(start=True)
+
+    features = torch.randn(100, 5)
+    targets = features @ torch.randn(5, 1)
+    optimizer = None
+    total_samples_accumulated = mp.Value(ctypes.c_int32, 0)
+
+    def run_trainer(batch_size: int, batch_time: float, client_mode: bool):
+        nonlocal optimizer
+        model = nn.Linear(5, 1)
+
+        assert isinstance(model, torch.nn.Module), "model_arch must evaluate to a pytorch module"
+
+        optimizer = Optimizer(
+            run_id="test_run",
+            target_batch_size=target_batch_size,
+            batch_size_per_step=batch_size,
+            params=model.parameters(),
+            optimizer=partial(torch.optim.SGD, lr=0.1),
+            scheduler=partial(torch.optim.lr_scheduler.StepLR, gamma=0.5, step_size=1),
+            dht=hivemind.DHT(initial_peers=dht.get_visible_maddrs(), client_mode=client_mode, start=True),
+            tracker_opts=dict(private_key=RSAPrivateKey(), max_refresh_period=1.0),
+            averager_opts=dict(request_timeout=0.5),
+            matchmaking_time=1.0,
+            averaging_timeout=5.0,
+            reuse_grad_buffers=reuse_grad_buffers,
+            delay_grad_averaging=delay_grad_averaging,
+            delay_optimizer_step=delay_optimizer_step,
+            average_state_every=average_state_every,
+            client_mode=client_mode,
+            verbose=False,
+        )
+        optimizer.load_state_from_peers()
+
+        prev_time = time.perf_counter()
+
+        while optimizer.local_epoch < total_epochs:
+            time.sleep(max(0.0, prev_time + batch_time - time.perf_counter()))
+            batch = torch.randint(0, len(features), (batch_size,))
+
+            loss = F.mse_loss(model(features[batch]), targets[batch])
+            loss.backward()
+
+            optimizer.step()
+
+            total_samples_accumulated.value += batch_size
+
+            if not reuse_grad_buffers:
+                optimizer.zero_grad()
+
+            prev_time = time.perf_counter()
+
+        time.sleep(1.0)
+        optimizer.shutdown()
+        return optimizer
+
+    peers = []
+
+    for index in range(num_peers):
+        peers.append(
+            mp.Process(
+                target=run_trainer,
+                name=f"trainer-{index}",
+                kwargs=dict(
+                    batch_size=4 + index,
+                    batch_time=0.3 + 0.2 * index,
+                    client_mode=(index >= num_peers - num_clients),
+                ),
+            )
+        )
+
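+    # run the first trainer in the main process (Process.run executes the target in-place) so that
+    # the nonlocal optimizer it creates can be inspected by the assertions below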
+    for peer in peers[1:]:
+        peer.start()
+    peers[0].run()
+    for peer in peers[1:]:
+        peer.join()
+
+    assert isinstance(optimizer, Optimizer)
+    assert optimizer.local_epoch == optimizer.tracker.global_epoch == total_epochs
+    expected_samples_accumulated = target_batch_size * total_epochs
+    assert expected_samples_accumulated <= total_samples_accumulated.value <= expected_samples_accumulated * 1.2
+    assert 4 / 0.3 * 0.8 <= optimizer.tracker.performance_ema.samples_per_second <= 4 / 0.3 * 1.2
+
+    assert not optimizer.state_averager.is_alive()
+    assert not optimizer.grad_averager.is_alive()
+    assert not optimizer.tracker.is_alive()
+    assert optimizer.scheduled_grads is None or optimizer.scheduled_grads.done()

+ 41 - 4
tests/test_p2p_daemon.py

@@ -9,8 +9,9 @@ import numpy as np
 import pytest
 from multiaddr import Multiaddr
 
-from hivemind.p2p import P2P, P2PHandlerError
-from hivemind.proto import dht_pb2
+from hivemind.p2p import P2P, P2PDaemonError, P2PHandlerError
+from hivemind.proto import dht_pb2, test_pb2
+from hivemind.utils.networking import get_free_port
 from hivemind.utils.serializer import MSGPackSerializer
 
 
@@ -33,6 +34,17 @@ async def test_daemon_killed_on_del():
     assert not is_process_running(child_pid)
 
 
+@pytest.mark.asyncio
+async def test_startup_error_message():
+    with pytest.raises(P2PDaemonError, match=r"(?i)Failed to connect to bootstrap peers"):
+        await P2P.create(
+            initial_peers=[f"/ip4/127.0.0.1/tcp/{get_free_port()}/p2p/QmdaK4LUeQaKhqSFPRu9N7MvXUEWDxWwtCvPrS444tCgd1"]
+        )
+
+    with pytest.raises(P2PDaemonError, match=r"Daemon failed to start in .+ seconds"):
+        await P2P.create(startup_timeout=0.01)  # Test that startup_timeout works
+
+
 @pytest.mark.parametrize(
     "host_maddrs",
     [
@@ -51,9 +63,9 @@ async def test_transports(host_maddrs: List[Multiaddr]):
     await client.wait_for_at_least_n_peers(1)
 
     peers = await client.list_peers()
-    assert len(peers) == 1
+    assert len({p.peer_id for p in peers}) == 1
     peers = await server.list_peers()
-    assert len(peers) == 1
+    assert len({p.peer_id for p in peers}) == 1
 
 
 @pytest.mark.asyncio
@@ -71,6 +83,31 @@ async def test_daemon_replica_does_not_affect_primary():
     assert not is_process_running(child_pid)
 
 
+@pytest.mark.asyncio
+async def test_unary_handler_edge_cases():
+    p2p = await P2P.create()
+    p2p_replica = await P2P.replicate(p2p.daemon_listen_maddr)
+
+    async def square_handler(data: test_pb2.TestRequest, context):
+        return test_pb2.TestResponse(number=data.number ** 2)
+
+    await p2p.add_protobuf_handler("square", square_handler, test_pb2.TestRequest)
+
+    # try adding a duplicate handler
+    with pytest.raises(P2PDaemonError):
+        await p2p.add_protobuf_handler("square", square_handler, test_pb2.TestRequest)
+
+    # try adding a duplicate handler from replicated p2p
+    with pytest.raises(P2PDaemonError):
+        await p2p_replica.add_protobuf_handler("square", square_handler, test_pb2.TestRequest)
+
+    # try dialing yourself
+    with pytest.raises(P2PDaemonError):
+        await p2p_replica.call_protobuf_handler(
+            p2p.peer_id, "square", test_pb2.TestRequest(number=41), test_pb2.TestResponse
+        )
+
+
 @pytest.mark.parametrize(
     "should_cancel,replicate",
     [

Some files were not shown because too many files have changed in this diff