Browse Source

Merge branch 'master' into server-p2p

Pavel Samygin, 3 years ago
parent commit 2f4b6fe98e
91 changed files with 5699 additions and 1675 deletions
  1. .github/workflows/check-style.yml (+7 -2)
  2. .github/workflows/push-docker-image.yml (+0 -1)
  3. .github/workflows/run-benchmarks.yml (+37 -0)
  4. .github/workflows/run-tests.yml (+9 -4)
  5. README.md (+46 -42)
  6. benchmarks/benchmark_dht.py (+178 -77)
  7. benchmarks/benchmark_optimizer.py (+162 -0)
  8. benchmarks/benchmark_throughput.py (+13 -13)
  9. benchmarks/benchmark_throughput_p2p.py (+1 -1)
  10. docs/_static/dht.odp (BIN)
  11. docs/_static/dht.png (BIN)
  12. docs/conf.py (+1 -1)
  13. docs/index.rst (+3 -3)
  14. docs/modules/optim.rst (+26 -4)
  15. docs/modules/server.rst (+8 -6)
  16. docs/user/quickstart.md (+35 -32)
  17. examples/albert/README.md (+33 -32)
  18. examples/albert/arguments.py (+23 -17)
  19. examples/albert/requirements.txt (+5 -5)
  20. examples/albert/run_trainer.py (+98 -75)
  21. examples/albert/run_training_monitor.py (+23 -25)
  22. examples/albert/tokenize_wikitext103.py (+1 -1)
  23. examples/albert/utils.py (+2 -2)
  24. hivemind/__init__.py (+5 -2)
  25. hivemind/averaging/__init__.py (+0 -1)
  26. hivemind/averaging/allreduce.py (+194 -86)
  27. hivemind/averaging/averager.py (+192 -103)
  28. hivemind/averaging/control.py (+165 -0)
  29. hivemind/averaging/key_manager.py (+21 -4)
  30. hivemind/averaging/load_balancing.py (+2 -2)
  31. hivemind/averaging/matchmaking.py (+75 -55)
  32. hivemind/averaging/partition.py (+71 -26)
  33. hivemind/compression/__init__.py (+1 -44)
  34. hivemind/compression/base.py (+8 -0)
  35. hivemind/compression/quantization.py (+2 -2)
  36. hivemind/compression/serialization.py (+43 -0)
  37. hivemind/dht/__init__.py (+3 -322)
  38. hivemind/dht/dht.py (+324 -0)
  39. hivemind/dht/node.py (+1 -1)
  40. hivemind/dht/routing.py (+2 -2)
  41. hivemind/dht/schema.py (+1 -1)
  42. hivemind/hivemind_cli/run_server.py (+2 -2)
  43. hivemind/moe/__init__.py (+1 -0)
  44. hivemind/moe/client/expert.py (+5 -9)
  45. hivemind/moe/client/moe.py (+6 -6)
  46. hivemind/moe/client/switch_moe.py (+1 -1)
  47. hivemind/moe/server/__init__.py (+4 -348)
  48. hivemind/moe/server/connection_handler.py (+6 -7)
  49. hivemind/moe/server/expert_backend.py (+2 -2)
  50. hivemind/moe/server/expert_uid.py (+2 -73)
  51. hivemind/moe/server/server.py (+412 -0)
  52. hivemind/optim/__init__.py (+3 -0)
  53. hivemind/optim/adaptive.py (+1 -1)
  54. hivemind/optim/base.py (+8 -0)
  55. hivemind/optim/collaborative.py (+75 -32)
  56. hivemind/optim/grad_averager.py (+226 -0)
  57. hivemind/optim/grad_scaler.py (+125 -0)
  58. hivemind/optim/optimizer.py (+779 -0)
  59. hivemind/optim/performance_ema.py (+0 -41)
  60. hivemind/optim/progress_tracker.py (+363 -0)
  61. hivemind/optim/simple.py (+9 -5)
  62. hivemind/optim/state_averager.py (+723 -0)
  63. hivemind/optim/training_averager.py (+1 -1)
  64. hivemind/p2p/p2p_daemon.py (+26 -18)
  65. hivemind/p2p/p2p_daemon_bindings/control.py (+1 -1)
  66. hivemind/p2p/servicer.py (+9 -26)
  67. hivemind/proto/averaging.proto (+1 -1)
  68. hivemind/utils/__init__.py (+1 -0)
  69. hivemind/utils/asyncio.py (+57 -11)
  70. hivemind/utils/grpc.py (+1 -1)
  71. hivemind/utils/limits.py (+1 -1)
  72. hivemind/utils/logging.py (+10 -3)
  73. hivemind/utils/mpfuture.py (+7 -3)
  74. hivemind/utils/performance_ema.py (+70 -0)
  75. hivemind/utils/serializer.py (+2 -2)
  76. hivemind/utils/tensor_descr.py (+54 -5)
  77. pyproject.toml (+1 -1)
  78. requirements-dev.txt (+6 -4)
  79. requirements-docs.txt (+4 -2)
  80. tests/conftest.py (+5 -2)
  81. tests/test_allreduce.py (+11 -11)
  82. tests/test_allreduce_fault_tolerance.py (+213 -0)
  83. tests/test_averaging.py (+126 -4)
  84. tests/test_compression.py (+1 -1)
  85. tests/test_dht.py (+2 -2)
  86. tests/test_moe.py (+31 -34)
  87. tests/test_optimizer.py (+385 -0)
  88. tests/test_p2p_daemon.py (+3 -3)
  89. tests/test_p2p_daemon_bindings.py (+7 -7)
  90. tests/test_p2p_servicer.py (+8 -6)
  91. tests/test_util_modules.py (+81 -4)

+ 7 - 2
.github/workflows/check-style.yml

@@ -1,6 +1,9 @@
 name: Check style
 
-on: [ push ]
+on:
+  push:
+    branches: [ master ]
+  pull_request:
 
 jobs:
   black:
@@ -10,7 +13,7 @@ jobs:
       - uses: psf/black@stable
         with:
           options: "--check --diff"
-          version: "21.6b0"
+          version: "22.1.0"
   isort:
     runs-on: ubuntu-latest
     steps:
@@ -19,3 +22,5 @@ jobs:
         with:
           python-version: 3.8
       - uses: isort/isort-action@master
+        with:
+          isortVersion: "5.10.1"

+ 0 - 1
.github/workflows/push-docker-image.yml

@@ -8,7 +8,6 @@ on:
   pull_request:
     branches: [ master ]
 
-
 jobs:
   build:
     runs-on: ubuntu-latest

+ 37 - 0
.github/workflows/run-benchmarks.yml

@@ -0,0 +1,37 @@
+name: Benchmarks
+
+on:
+  push:
+    branches: [ master ]
+  pull_request:
+
+jobs:
+  run_benchmarks:
+
+    runs-on: ubuntu-latest
+    timeout-minutes: 10
+    steps:
+      - uses: actions/checkout@v2
+      - name: Set up Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: 3.9
+      - name: Cache dependencies
+        uses: actions/cache@v2
+        with:
+          path: ~/.cache/pip
+          key: Key-v1-3.9-${{ hashFiles('requirements.txt') }}-${{ hashFiles('requirements-dev.txt') }}
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -r requirements.txt
+          pip install -r requirements-dev.txt
+      - name: Build hivemind
+        run: |
+          pip install .
+      - name: Benchmark
+        run: |
+          cd benchmarks
+          python benchmark_throughput.py --preset minimalistic
+          python benchmark_tensor_compression.py
+          python benchmark_dht.py

+ 9 - 4
.github/workflows/run-tests.yml

@@ -1,7 +1,9 @@
 name: Tests
 
-on: [ push ]
-
+on:
+  push:
+    branches: [ master ]
+  pull_request:
 
 jobs:
   run_tests:
@@ -33,6 +35,7 @@ jobs:
       - name: Test
         run: |
           cd tests
+          export HIVEMIND_MEMORY_SHARING_STRATEGY=file_descriptor
           pytest --durations=0 --durations-min=1.0 -v
   build_and_test_p2pd:
     runs-on: ubuntu-latest
@@ -59,6 +62,7 @@ jobs:
       - name: Test
         run: |
           cd tests
+          export HIVEMIND_MEMORY_SHARING_STRATEGY=file_descriptor
           pytest -k "p2p" -v
   codecov_in_develop_mode:
 
@@ -82,9 +86,10 @@ jobs:
           pip install -r requirements-dev.txt
       - name: Build hivemind
         run: |
-          pip install -e .
+          pip install -e . --no-use-pep517
       - name: Test
         run: |
-          pytest --cov=hivemind -v tests
+          export HIVEMIND_MEMORY_SHARING_STRATEGY=file_descriptor
+          pytest --cov hivemind -v tests
       - name: Upload coverage to Codecov
         uses: codecov/codecov-action@v1

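The test workflow changes above export `HIVEMIND_MEMORY_SHARING_STRATEGY=file_descriptor` before running pytest. A rough local equivalent, as a sketch (the programmatic `pytest.main` call is my own choice, not part of this commit):

```python
# Reproduce the CI test step locally (a sketch, not part of this diff).
# The environment variable is set before the test session starts so worker
# processes inherit the file_descriptor sharing strategy.
import os
import pytest

os.environ["HIVEMIND_MEMORY_SHARING_STRATEGY"] = "file_descriptor"
raise SystemExit(pytest.main(["--durations=0", "--durations-min=1.0", "-v", "tests"]))
```
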
+ 46 - 42
README.md

@@ -1,7 +1,7 @@
 ## Hivemind: decentralized deep learning in PyTorch
 
 [![Documentation Status](https://readthedocs.org/projects/learning-at-home/badge/?version=latest)](https://learning-at-home.readthedocs.io/en/latest/?badge=latest)
-[![PyPI version](https://img.shields.io/pypi/v/hivemind.svg)](https://pypi.org/project/hivemind/)
+[![PyPI version](https://img.shields.io/pypi/v/hivemind.svg?color=blue)](https://pypi.org/project/hivemind/)
 [![Discord](https://img.shields.io/static/v1?style=default&label=Discord&logo=discord&message=join)](https://discord.gg/uGugx9zYvN)
 [![CI status](https://github.com/learning-at-home/hivemind/actions/workflows/run-tests.yml/badge.svg?branch=master)](https://github.com/learning-at-home/hivemind/actions)
 ![Codecov](https://img.shields.io/codecov/c/github/learning-at-home/hivemind)
@@ -12,6 +12,10 @@ large model on hundreds of computers from different universities, companies, and
 
 ![img](https://i.imgur.com/GPxolxb.gif)
 
+## Live Demo
+
+Check out our NeurIPS 2021 demonstration ["Training Transformers Together"](https://training-transformers-together.github.io/) to see hivemind in action, join an ongoing collaborative experiment, and learn more about the technologies behind it!
+
 ## Key Features
 
 * Distributed training without a master node: Distributed Hash Table allows connecting computers in a decentralized
@@ -23,8 +27,8 @@ large model on hundreds of computers from different universities, companies, and
 * Train neural networks of arbitrary size: parts of their layers are distributed across the participants with the
   Decentralized Mixture-of-Experts ([paper](https://arxiv.org/abs/2002.04013)).
 
-To learn more about the ideas behind this library, see https://learning-at-home.github.io or read
-the [NeurIPS 2020 paper](https://arxiv.org/abs/2002.04013).
+To learn more about the ideas behind this library,
+see the [full list](https://github.com/learning-at-home/hivemind/tree/refer-to-discord-in-docs#citation) of our papers below.
 
 ## Installation
 
@@ -52,7 +56,7 @@ cd hivemind
 pip install .
 ```
 
-If you would like to verify that your installation is working properly, you can install with `pip install -e .[dev]`
+If you would like to verify that your installation is working properly, you can install with `pip install .[dev]`
 instead. Then, you can run the tests with `pytest tests/`.
 
 By default, hivemind uses the precompiled binary of
@@ -65,8 +69,8 @@ of [Go toolchain](https://golang.org/doc/install) (1.15 or higher).
 
 - __Linux__ is the default OS for which hivemind is developed and tested. We recommend Ubuntu 18.04+ (64-bit), but
   other 64-bit distros should work as well. Legacy 32-bit is not recommended.
-- __macOS 10.x__ mostly works but requires building hivemind from source, and some edge cases may fail. To ensure full
-  compatibility, we recommend using [our Docker image](https://hub.docker.com/r/learningathome/hivemind).
+- __macOS 10.x__ can run hivemind using [Docker](https://docs.docker.com/desktop/mac/install/).
+  We recommend using [our Docker image](https://hub.docker.com/r/learningathome/hivemind).
 - __Windows 10+ (experimental)__ can run hivemind
   using [WSL](https://docs.microsoft.com/ru-ru/windows/wsl/install-win10). You can configure WSL to use GPU by
   following sections 1–3 of [this guide](https://docs.nvidia.com/cuda/wsl-user-guide/index.html) by NVIDIA. After
@@ -83,17 +87,17 @@ of [Go toolchain](https://golang.org/doc/install) (1.15 or higher).
 * API reference and additional tutorials are available
   at [learning-at-home.readthedocs.io](https://learning-at-home.readthedocs.io)
 
-If you have any questions about installing and using hivemind, you can ask them in
+If you have any questions about installing and using hivemind, feel free to ask them in
 [our Discord chat](https://discord.gg/uGugx9zYvN) or file an [issue](https://github.com/learning-at-home/hivemind/issues).
 
 ## Contributing
 
 Hivemind is currently at the active development stage, and we welcome all contributions. Everything, from bug fixes and
-documentation improvements to entirely new features, is equally appreciated.
+documentation improvements to entirely new features, is appreciated.
 
 If you want to contribute to hivemind but don't know where to start, take a look at the
 unresolved [issues](https://github.com/learning-at-home/hivemind/issues). Open a new issue or
-join [our chat room](https://discord.gg/xC7ucM8j) in case you want to discuss new functionality or report a possible
+join [our chat room](https://discord.gg/uGugx9zYvN) in case you want to discuss new functionality or report a possible
 bug. Bug fixes are always welcome, but new features should be preferably discussed with maintainers beforehand.
 
 If you want to start contributing to the source code of hivemind, please see
@@ -105,9 +109,9 @@ our [guide](https://learning-at-home.readthedocs.io/en/latest/user/contributing.
 
 If you found hivemind or its underlying algorithms useful for your research, please cite the following source:
 
-```
+```bibtex
 @misc{hivemind,
-  author = {Learning@home team},
+  author = {Learning{@}home team},
   title = {{H}ivemind: a {L}ibrary for {D}ecentralized {D}eep {L}earning},
   year = 2020,
   howpublished = {\url{https://github.com/learning-at-home/hivemind}},
@@ -118,17 +122,17 @@ Also, you can cite [the paper](https://arxiv.org/abs/2002.04013) that inspired t
 (prototype implementation of hivemind available
 at [mryab/learning-at-home](https://github.com/mryab/learning-at-home)):
 
-```
+```bibtex
 @inproceedings{ryabinin2020crowdsourced,
- author = {Ryabinin, Max and Gusev, Anton},
- booktitle = {Advances in Neural Information Processing Systems},
- editor = {H. Larochelle and M. Ranzato and R. Hadsell and M. F. Balcan and H. Lin},
- pages = {3659--3672},
- publisher = {Curran Associates, Inc.},
- title = {Towards Crowdsourced Training of Large Neural Networks using Decentralized Mixture-of-Experts},
- url = {https://proceedings.neurips.cc/paper/2020/file/25ddc0f8c9d3e22e03d3076f98d83cb2-Paper.pdf},
- volume = {33},
- year = {2020}
+  author = {Ryabinin, Max and Gusev, Anton},
+  booktitle = {Advances in Neural Information Processing Systems},
+  editor = {H. Larochelle and M. Ranzato and R. Hadsell and M. F. Balcan and H. Lin},
+  pages = {3659--3672},
+  publisher = {Curran Associates, Inc.},
+  title = {Towards Crowdsourced Training of Large Neural Networks using Decentralized Mixture-of-Experts},
+  url = {https://proceedings.neurips.cc/paper/2020/file/25ddc0f8c9d3e22e03d3076f98d83cb2-Paper.pdf},
+  volume = {33},
+  year = {2020}
 }
 ```
 
@@ -137,40 +141,40 @@ at [mryab/learning-at-home](https://github.com/mryab/learning-at-home)):
 
 ["Moshpit SGD: Communication-Efficient Decentralized Training on Heterogeneous Unreliable Devices"](https://arxiv.org/abs/2103.03239)
 
-```
+```bibtex
 @misc{ryabinin2021moshpit,
-      title={Moshpit SGD: Communication-Efficient Decentralized Training on Heterogeneous Unreliable Devices}, 
-      author={Max Ryabinin and Eduard Gorbunov and Vsevolod Plokhotnyuk and Gennady Pekhimenko},
-      year={2021},
-      eprint={2103.03239},
-      archivePrefix={arXiv},
-      primaryClass={cs.LG}
+  title = {Moshpit SGD: Communication-Efficient Decentralized Training on Heterogeneous Unreliable Devices},
+  author = {Max Ryabinin and Eduard Gorbunov and Vsevolod Plokhotnyuk and Gennady Pekhimenko},
+  year = {2021},
+  eprint = {2103.03239},
+  archivePrefix = {arXiv},
+  primaryClass = {cs.LG}
 }
 ```
 
 ["Distributed Deep Learning in Open Collaborations"](https://arxiv.org/abs/2106.10207)
 
-```
+```bibtex
 @misc{diskin2021distributed,
-      title={Distributed Deep Learning in Open Collaborations}, 
-      author={Michael Diskin and Alexey Bukhtiyarov and Max Ryabinin and Lucile Saulnier and Quentin Lhoest and Anton Sinitsin and Dmitry Popov and Dmitry Pyrkin and Maxim Kashirin and Alexander Borzunov and Albert Villanova del Moral and Denis Mazur and Ilia Kobelev and Yacine Jernite and Thomas Wolf and Gennady Pekhimenko},
-      year={2021},
-      eprint={2106.10207},
-      archivePrefix={arXiv},
-      primaryClass={cs.LG}
+  title = {Distributed Deep Learning in Open Collaborations},
+  author = {Michael Diskin and Alexey Bukhtiyarov and Max Ryabinin and Lucile Saulnier and Quentin Lhoest and Anton Sinitsin and Dmitry Popov and Dmitry Pyrkin and Maxim Kashirin and Alexander Borzunov and Albert Villanova del Moral and Denis Mazur and Ilia Kobelev and Yacine Jernite and Thomas Wolf and Gennady Pekhimenko},
+  year = {2021},
+  eprint = {2106.10207},
+  archivePrefix = {arXiv},
+  primaryClass = {cs.LG}
 }
 ```
 
 ["Secure Distributed Training at Scale"](https://arxiv.org/abs/2106.11257)
 
-```
+```bibtex
 @misc{gorbunov2021secure,
-      title={Secure Distributed Training at Scale}, 
-      author={Eduard Gorbunov and Alexander Borzunov and Michael Diskin and Max Ryabinin},
-      year={2021},
-      eprint={2106.11257},
-      archivePrefix={arXiv},
-      primaryClass={cs.LG}
+  title = {Secure Distributed Training at Scale},
+  author = {Eduard Gorbunov and Alexander Borzunov and Michael Diskin and Max Ryabinin},
+  year = {2021},
+  eprint = {2106.11257},
+  archivePrefix = {arXiv},
+  primaryClass = {cs.LG}
 }
 ```
 

+ 178 - 77
benchmarks/benchmark_dht.py

@@ -1,11 +1,15 @@
 import argparse
+import asyncio
 import random
 import time
+import uuid
+from logging import shutdown
+from typing import Tuple
 
+import numpy as np
 from tqdm import trange
 
 import hivemind
-from hivemind.moe.server import declare_experts, get_experts
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
@@ -13,23 +17,116 @@ use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
 
 
-def random_endpoint() -> hivemind.Endpoint:
-    return (
-        f"{random.randint(0, 256)}.{random.randint(0, 256)}.{random.randint(0, 256)}."
-        f"{random.randint(0, 256)}:{random.randint(0, 65535)}"
-    )
+class NodeKiller:
+    """Auxiliary class that kills dht nodes over a pre-defined schedule"""
+
+    def __init__(self, shutdown_peers: list, shutdown_timestamps: list):
+        self.shutdown_peers = set(shutdown_peers)
+        self.shutdown_timestamps = shutdown_timestamps
+        self.current_iter = 0
+        self.timestamp_iter = 0
+        self.lock = asyncio.Lock()
+
+    async def check_and_kill(self):
+        async with self.lock:
+            if (
+                self.shutdown_timestamps != None
+                and self.timestamp_iter < len(self.shutdown_timestamps)
+                and self.current_iter == self.shutdown_timestamps[self.timestamp_iter]
+            ):
+                shutdown_peer = random.sample(self.shutdown_peers, 1)[0]
+                shutdown_peer.shutdown()
+                self.shutdown_peers.remove(shutdown_peer)
+                self.timestamp_iter += 1
+            self.current_iter += 1
+
+
+async def store_and_get_task(
+    peers: list,
+    total_num_rounds: int,
+    num_store_peers: int,
+    num_get_peers: int,
+    wait_after_iteration: float,
+    delay: float,
+    expiration: float,
+    latest: bool,
+    node_killer: NodeKiller,
+) -> Tuple[list, list, list, list, int, int]:
+    """Iteratively choose random peers to store data onto the dht, then retreive with another random subset of peers"""
+
+    total_stores = total_gets = 0
+    successful_stores = []
+    successful_gets = []
+    store_times = []
+    get_times = []
+
+    for _ in range(total_num_rounds):
+        key = uuid.uuid4().hex
+
+        store_start = time.perf_counter()
+        store_peers = random.sample(peers, min(num_store_peers, len(peers)))
+        store_subkeys = [uuid.uuid4().hex for _ in store_peers]
+        store_values = {subkey: uuid.uuid4().hex for subkey in store_subkeys}
+        store_tasks = [
+            peer.store(
+                key,
+                subkey=subkey,
+                value=store_values[subkey],
+                expiration_time=hivemind.get_dht_time() + expiration,
+                return_future=True,
+            )
+            for peer, subkey in zip(store_peers, store_subkeys)
+        ]
+        store_result = await asyncio.gather(*store_tasks)
+        await node_killer.check_and_kill()
+
+        store_times.append(time.perf_counter() - store_start)
+
+        total_stores += len(store_result)
+        successful_stores_per_iter = sum(store_result)
+        successful_stores.append(successful_stores_per_iter)
+        await asyncio.sleep(delay)
+
+        get_start = time.perf_counter()
+        get_peers = random.sample(peers, min(num_get_peers, len(peers)))
+        get_tasks = [peer.get(key, latest, return_future=True) for peer in get_peers]
+        get_result = await asyncio.gather(*get_tasks)
+        get_times.append(time.perf_counter() - get_start)
+
+        successful_gets_per_iter = 0
+
+        total_gets += len(get_result)
+        for result in get_result:
+            if result != None:
+                attendees, expiration = result
+                if len(attendees.keys()) == successful_stores_per_iter:
+                    get_ok = True
+                    for key in attendees:
+                        if attendees[key][0] != store_values[key]:
+                            get_ok = False
+                            break
+                    successful_gets_per_iter += get_ok
 
+        successful_gets.append(successful_gets_per_iter)
+        await asyncio.sleep(wait_after_iteration)
 
-def benchmark_dht(
+    return store_times, get_times, successful_stores, successful_gets, total_stores, total_gets
+
+
+async def benchmark_dht(
     num_peers: int,
     initial_peers: int,
-    num_experts: int,
-    expert_batch_size: int,
     random_seed: int,
-    wait_after_request: float,
-    wait_before_read: float,
+    num_threads: int,
+    total_num_rounds: int,
+    num_store_peers: int,
+    num_get_peers: int,
+    wait_after_iteration: float,
+    delay: float,
     wait_timeout: float,
     expiration: float,
+    latest: bool,
+    failure_rate: float,
 ):
     random.seed(random_seed)
 
@@ -42,88 +139,92 @@ def benchmark_dht(
         peer = hivemind.DHT(initial_peers=neighbors, start=True, wait_timeout=wait_timeout)
         peers.append(peer)
 
-    store_peer, get_peer = peers[-2:]
-
-    expert_uids = list(
-        set(
-            f"expert.{random.randint(0, 999)}.{random.randint(0, 999)}.{random.randint(0, 999)}"
-            for _ in range(num_experts)
-        )
-    )
-    logger.info(f"Sampled {len(expert_uids)} unique ids (after deduplication)")
-    random.shuffle(expert_uids)
-
-    logger.info(f"Storing experts to dht in batches of {expert_batch_size}...")
-    successful_stores = total_stores = total_store_time = 0
     benchmark_started = time.perf_counter()
-    endpoints = []
-
-    for start in trange(0, num_experts, expert_batch_size):
-        store_start = time.perf_counter()
-        endpoints.append(random_endpoint())
-        store_ok = declare_experts(
-            store_peer, expert_uids[start : start + expert_batch_size], endpoints[-1], expiration=expiration
+    logger.info("Creating store and get tasks...")
+    shutdown_peers = random.sample(peers, min(int(failure_rate * num_peers), num_peers))
+    assert len(shutdown_peers) != len(peers)
+    remaining_peers = list(set(peers) - set(shutdown_peers))
+    shutdown_timestamps = random.sample(
+        range(0, num_threads * total_num_rounds), min(len(shutdown_peers), num_threads * total_num_rounds)
+    )
+    shutdown_timestamps.sort()
+    node_killer = NodeKiller(shutdown_peers, shutdown_timestamps)
+    task_list = [
+        asyncio.create_task(
+            store_and_get_task(
+                remaining_peers,
+                total_num_rounds,
+                num_store_peers,
+                num_get_peers,
+                wait_after_iteration,
+                delay,
+                expiration,
+                latest,
+                node_killer,
+            )
         )
-        successes = store_ok.values()
-        total_store_time += time.perf_counter() - store_start
-
-        total_stores += len(successes)
-        successful_stores += sum(successes)
-        time.sleep(wait_after_request)
+        for _ in trange(num_threads)
+    ]
+
+    store_and_get_result = await asyncio.gather(*task_list)
+    benchmark_total_time = time.perf_counter() - benchmark_started
+    total_store_times = []
+    total_get_times = []
+    total_successful_stores = []
+    total_successful_gets = []
+    total_stores = total_gets = 0
+    for result in store_and_get_result:
+        store_times, get_times, successful_stores, successful_gets, stores, gets = result
+
+        total_store_times.extend(store_times)
+        total_get_times.extend(get_times)
+        total_successful_stores.extend(successful_stores)
+        total_successful_gets.extend(successful_gets)
+        total_stores += stores
+        total_gets += gets
 
+    alive_peers = [peer.is_alive() for peer in peers]
     logger.info(
-        f"Store success rate: {successful_stores / total_stores * 100:.1f}% ({successful_stores} / {total_stores})"
+        f"Store wall time (sec.): mean({np.mean(total_store_times):.3f}) "
+        + f"std({np.std(total_store_times, ddof=1):.3f}) max({np.max(total_store_times):.3f})"
     )
-    logger.info(f"Mean store time: {total_store_time / total_stores:.5}, Total: {total_store_time:.5}")
-    time.sleep(wait_before_read)
-
-    if time.perf_counter() - benchmark_started > expiration:
-        logger.warning("All keys expired before benchmark started getting them. Consider increasing expiration_time")
-
-    successful_gets = total_get_time = 0
-
-    for start in trange(0, len(expert_uids), expert_batch_size):
-        get_start = time.perf_counter()
-        get_result = get_experts(get_peer, expert_uids[start : start + expert_batch_size])
-        total_get_time += time.perf_counter() - get_start
-
-        for i, expert in enumerate(get_result):
-            if (
-                expert is not None
-                and expert.uid == expert_uids[start + i]
-                and expert.endpoint == endpoints[start // expert_batch_size]
-            ):
-                successful_gets += 1
-
-    if time.perf_counter() - benchmark_started > expiration:
-        logger.warning(
-            "keys expired midway during get requests. If that isn't desired, increase expiration_time param"
-        )
-
     logger.info(
-        f"Get success rate: {successful_gets / len(expert_uids) * 100:.1f} ({successful_gets} / {len(expert_uids)})"
+        f"Get wall time (sec.): mean({np.mean(total_get_times):.3f}) "
+        + f"std({np.std(total_get_times, ddof=1):.3f}) max({np.max(total_get_times):.3f})"
+    )
+    logger.info(f"Average store time per worker: {sum(total_store_times) / num_threads:.3f} sec.")
+    logger.info(f"Average get time per worker: {sum(total_get_times) / num_threads:.3f} sec.")
+    logger.info(f"Total benchmark time: {benchmark_total_time:.5f} sec.")
+    logger.info(
+        "Store success rate: "
+        + f"{sum(total_successful_stores) / total_stores * 100:.1f}% ({sum(total_successful_stores)}/{total_stores})"
+    )
+    logger.info(
+        "Get success rate: "
+        + f"{sum(total_successful_gets) / total_gets * 100:.1f}% ({sum(total_successful_gets)}/{total_gets})"
     )
-    logger.info(f"Mean get time: {total_get_time / len(expert_uids):.5f}, Total: {total_get_time:.5f}")
-
-    alive_peers = [peer.is_alive() for peer in peers]
     logger.info(f"Node survival rate: {len(alive_peers) / len(peers) * 100:.3f}%")
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--num_peers", type=int, default=32, required=False)
-    parser.add_argument("--initial_peers", type=int, default=1, required=False)
-    parser.add_argument("--num_experts", type=int, default=256, required=False)
-    parser.add_argument("--expert_batch_size", type=int, default=32, required=False)
-    parser.add_argument("--expiration", type=float, default=300, required=False)
-    parser.add_argument("--wait_after_request", type=float, default=0, required=False)
-    parser.add_argument("--wait_before_read", type=float, default=0, required=False)
+    parser.add_argument("--num_peers", type=int, default=16, required=False)
+    parser.add_argument("--initial_peers", type=int, default=4, required=False)
+    parser.add_argument("--random_seed", type=int, default=30, required=False)
+    parser.add_argument("--num_threads", type=int, default=10, required=False)
+    parser.add_argument("--total_num_rounds", type=int, default=16, required=False)
+    parser.add_argument("--num_store_peers", type=int, default=8, required=False)
+    parser.add_argument("--num_get_peers", type=int, default=8, required=False)
+    parser.add_argument("--wait_after_iteration", type=float, default=0, required=False)
+    parser.add_argument("--delay", type=float, default=0, required=False)
     parser.add_argument("--wait_timeout", type=float, default=5, required=False)
-    parser.add_argument("--random_seed", type=int, default=random.randint(1, 1000))
+    parser.add_argument("--expiration", type=float, default=300, required=False)
+    parser.add_argument("--latest", type=bool, default=True, required=False)
+    parser.add_argument("--failure_rate", type=float, default=0.1, required=False)
     parser.add_argument("--increase_file_limit", action="store_true")
     args = vars(parser.parse_args())
 
     if args.pop("increase_file_limit", False):
         increase_file_limit()
 
-    benchmark_dht(**args)
+    asyncio.run(benchmark_dht(**args))

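The rewritten benchmark is fully asynchronous: random subsets of peers store subkeys under a shared key, another subset reads them back, and a `NodeKiller` shuts peers down on a random schedule to exercise fault tolerance. A minimal sketch of calling the new coroutine directly, using the argparse defaults shown above (the import assumes the script is run from the `benchmarks/` directory):

```python
# Invoke the async DHT benchmark programmatically (sketch based on the signature above).
import asyncio

from benchmark_dht import benchmark_dht  # assumes the current directory is benchmarks/

asyncio.run(
    benchmark_dht(
        num_peers=16, initial_peers=4, random_seed=30, num_threads=10,
        total_num_rounds=16, num_store_peers=8, num_get_peers=8,
        wait_after_iteration=0, delay=0, wait_timeout=5,
        expiration=300, latest=True, failure_rate=0.1,
    )
)
```
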
+ 162 - 0
benchmarks/benchmark_optimizer.py

@@ -0,0 +1,162 @@
+import multiprocessing as mp
+import random
+import time
+from contextlib import nullcontext
+from dataclasses import dataclass
+from functools import partial
+from typing import Callable
+
+import torch
+import torchvision
+from torch import nn as nn
+from torch.nn import functional as F
+from torch.utils.data import Dataset
+
+import hivemind
+from hivemind.optim.optimizer import Optimizer
+from hivemind.utils.crypto import RSAPrivateKey
+
+
+@dataclass(frozen=True)
+class TrainingArguments:
+    seed: int = 42
+    run_id: str = "my_exp"
+
+    num_peers: int = 8
+    num_clients: int = 3
+    target_batch_size: int = 256
+    reuse_grad_buffers: bool = True
+    delay_grad_averaging: bool = True
+    delay_optimizer_step: bool = True
+    average_state_every: int = 1
+    use_amp: bool = False
+
+    lr_base: float = 0.1
+    lr_gamma: int = 0.1
+    lr_step_size: int = 10
+    max_epoch: int = 25
+
+    batch_size_min: int = 2
+    batch_size_max: int = 16
+    batch_time_min: float = 1.0
+    batch_time_max: float = 4.5
+    batch_time_std: float = 0.5
+
+    matchmaking_time: float = 5.0
+    max_refresh_period: float = 5.0
+    averaging_timeout: float = 15.0
+    winddown_time: float = 5.0
+    verbose: bool = True
+
+    device: str = "cpu"
+    make_dataset: Callable[[], Dataset] = lambda: torchvision.datasets.MNIST(train=True, root=".", download=True)
+    make_model: Callable[[int, int], nn.Module] = lambda num_features, num_classes: nn.Sequential(
+        nn.Linear(num_features, 64), nn.ReLU(), nn.Linear(64, num_classes)
+    )
+
+
+def benchmark_optimizer(args: TrainingArguments):
+    random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    torch.set_num_threads(1)
+
+    dht = hivemind.DHT(start=True)
+
+    train_dataset = args.make_dataset()
+    num_features = train_dataset.data[0].numel()
+    num_classes = len(train_dataset.classes)
+    X_train = torch.as_tensor(train_dataset.data, dtype=torch.float32)
+    X_train = X_train.sub_(X_train.mean((0, 1, 2))).div_(X_train.std((0, 1, 2))).reshape((-1, num_features))
+    y_train = torch.as_tensor(train_dataset.targets, dtype=torch.int64)
+    del train_dataset
+
+    def run_trainer(batch_size: int, batch_time: float, client_mode: bool, verbose: bool):
+        model = args.make_model(num_features, num_classes).to(args.device)
+
+        assert isinstance(model, torch.nn.Module), "model_arch must evaluate to a pytorch module"
+
+        optimizer = Optimizer(
+            run_id=args.run_id,
+            target_batch_size=args.target_batch_size,
+            batch_size_per_step=batch_size,
+            params=model.parameters(),
+            optimizer=partial(torch.optim.SGD, lr=args.lr_base),
+            scheduler=partial(torch.optim.lr_scheduler.StepLR, gamma=args.lr_gamma, step_size=args.lr_step_size),
+            dht=hivemind.DHT(initial_peers=dht.get_visible_maddrs(), client_mode=client_mode, start=True),
+            tracker_opts=dict(private_key=RSAPrivateKey(), max_refresh_period=args.max_refresh_period),
+            matchmaking_time=args.matchmaking_time,
+            averaging_timeout=args.averaging_timeout,
+            reuse_grad_buffers=args.reuse_grad_buffers,
+            delay_grad_averaging=args.delay_grad_averaging,
+            delay_optimizer_step=args.delay_optimizer_step,
+            average_state_every=args.average_state_every,
+            client_mode=client_mode,
+            verbose=verbose,
+        )
+
+        if args.use_amp and args.reuse_grad_buffers:
+            grad_scaler = hivemind.GradScaler()
+        else:
+            # check that hivemind.Optimizer supports regular PyTorch grad scaler as well
+            grad_scaler = torch.cuda.amp.GradScaler(enabled=args.use_amp)
+
+        prev_time = time.perf_counter()
+
+        while optimizer.local_epoch < args.max_epoch:
+            time.sleep(max(0.0, prev_time + random.gauss(batch_time, args.batch_time_std) - time.perf_counter()))
+
+            batch = torch.randint(0, len(X_train), (batch_size,))
+
+            with torch.cuda.amp.autocast() if args.use_amp else nullcontext():
+                loss = F.cross_entropy(model(X_train[batch].to(args.device)), y_train[batch].to(args.device))
+                grad_scaler.scale(loss).backward()
+
+            grad_scaler.unscale_(optimizer)
+
+            if args.use_amp:
+                grad_scaler.step(optimizer)
+            else:
+                optimizer.step()
+
+            grad_scaler.update()
+
+            if not args.reuse_grad_buffers:
+                optimizer.zero_grad()
+
+            prev_time = time.perf_counter()
+
+        time.sleep(args.winddown_time)
+        optimizer.shutdown()
+
+    peers = []
+
+    for index in range(args.num_peers):
+        batch_size = random.randint(args.batch_size_min, args.batch_size_max)
+        batch_time = random.uniform(args.batch_time_min, args.batch_time_max)
+        peers.append(
+            mp.Process(
+                target=run_trainer,
+                name=f"trainer-{index}",
+                daemon=False,
+                kwargs=dict(
+                    batch_size=batch_size,
+                    batch_time=batch_time,
+                    client_mode=(index >= args.num_peers - args.num_clients),
+                    verbose=args.verbose and (index == 0),
+                ),
+            )
+        )
+
+    try:
+        for peer in peers[1:]:
+            peer.start()
+        peers[0].run()
+        for peer in peers[1:]:
+            peer.join()
+    finally:
+        for peer in peers[1:]:
+            peer.kill()
+
+
+if __name__ == "__main__":
+    benchmark_optimizer(TrainingArguments())

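Since `TrainingArguments` is a frozen dataclass, alternative configurations can be benchmarked by overriding individual fields at construction time. A small illustrative sketch (these particular values are hypothetical, not defaults from this commit):

```python
# Run a smaller, faster variant of the optimizer benchmark (illustrative values only).
from benchmark_optimizer import TrainingArguments, benchmark_optimizer

args = TrainingArguments(num_peers=4, num_clients=1, max_epoch=5, target_batch_size=64, verbose=True)
benchmark_optimizer(args)
```
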
+ 13 - 13
benchmarks/benchmark_throughput.py

@@ -6,11 +6,13 @@ import time
 
 import torch
 
-import hivemind
-from hivemind import get_free_port
-from hivemind.moe.server import layers
+from hivemind.moe.client import RemoteExpert
+from hivemind.moe.server import ExpertBackend, Server
+from hivemind.moe.server.layers import name_to_block
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from hivemind.utils.networking import LOCALHOST, get_free_port
+from hivemind.utils.tensor_descr import BatchTensorDescriptor
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
@@ -32,9 +34,7 @@ def print_device_info(device=None):
 def client_process(can_start, benchmarking_failed, port, num_experts, batch_size, hid_dim, num_batches, backprop=True):
     torch.set_num_threads(1)
     can_start.wait()
-    experts = [
-        hivemind.RemoteExpert(f"expert{i}", endpoint=f"{hivemind.LOCALHOST}:{port}") for i in range(num_experts)
-    ]
+    experts = [RemoteExpert(f"expert{i}", endpoint=f"{LOCALHOST}:{port}") for i in range(num_experts)]
 
     try:
         dummy_batch = torch.randn(batch_size, hid_dim)
@@ -66,7 +66,7 @@ def benchmark_throughput(
         or not torch.cuda.is_initialized()
         or torch.device(device) == torch.device("cpu")
     )
-    assert expert_cls in layers.name_to_block
+    assert expert_cls in name_to_block
     port = port or get_free_port()
     max_batch_size = max_batch_size or batch_size * 4
     num_handlers = max(1, num_handlers or num_clients // 2)
@@ -105,20 +105,20 @@ def benchmark_throughput(
         device = device or ("cuda" if torch.cuda.is_available() else "cpu")
         experts = {}
         for i in range(num_experts):
-            expert = torch.jit.script(layers.name_to_block[expert_cls](hid_dim))
-            experts[f"expert{i}"] = hivemind.ExpertBackend(
+            expert = torch.jit.script(name_to_block[expert_cls](hid_dim))
+            experts[f"expert{i}"] = ExpertBackend(
                 name=f"expert{i}",
                 expert=expert,
                 optimizer=torch.optim.Adam(expert.parameters()),
-                args_schema=(hivemind.BatchTensorDescriptor(hid_dim),),
-                outputs_schema=hivemind.BatchTensorDescriptor(hid_dim),
+                args_schema=(BatchTensorDescriptor(hid_dim),),
+                outputs_schema=BatchTensorDescriptor(hid_dim),
                 max_batch_size=max_batch_size,
             )
         timestamps["created_experts"] = time.perf_counter()
-        server = hivemind.moe.Server(
+        server = Server(
             None,
             experts,
-            listen_on=f"{hivemind.LOCALHOST}:{port}",
+            listen_on=f"{LOCALHOST}:{port}",
             num_connection_handlers=num_handlers,
             device=device,
         )

+ 1 - 1
benchmarks/benchmark_throughput_p2p.py

@@ -250,7 +250,7 @@ if __name__ == "__main__":
             num_clients=1,
             num_handlers=1,
             num_batches_per_client=args.num_batches_per_client,
-            batch_size=512,
+            batch_size=256,
         )
     elif args.preset == "nop":
         benchmark_throughput(expert_cls="nop", backprop=False, num_batches_per_client=args.num_batches_per_client)

BIN
docs/_static/dht.odp


BIN
docs/_static/dht.png


+ 1 - 1
docs/conf.py

@@ -203,7 +203,7 @@ todo_include_todos = True
 
 
 def setup(app):
-    app.add_stylesheet("fix_rtd.css")
+    app.add_css_file("fix_rtd.css")
     app.add_config_value(
         "recommonmark_config",
         {

+ 3 - 3
docs/index.rst

@@ -9,9 +9,9 @@ of computers, whether you're running a very capable computer or a less reliable
 Learn how to create or join a Hivemind run in the `quickstart tutorial <./user/quickstart.html>`__ or browse the API
 documentation below.
 
-| Hivemind is currently in active development, so expect some adventures. If you encounter any issues, please let us know
-  `on github <https://github.com/learning-at-home/hivemind/issues>`__.
-
+| Hivemind is currently in active development, so expect some adventures. If you have any questions, feel free to ask them
+  in `our Discord chat <https://discord.com/invite/uGugx9zYvN>`_ or
+  file an `issue <https://github.com/learning-at-home/hivemind/issues>`__.
 
 **Table of contents:**
 ~~~~~~~~~~~~~~~~~~~~~~

+ 26 - 4
docs/modules/optim.rst

@@ -1,14 +1,36 @@
 **hivemind.optim**
 ==================
 
-.. automodule:: hivemind.optim
-.. currentmodule:: hivemind.optim
-
 .. raw:: html
 
-  This module contains decentralized optimizers that wrap regular pytorch optimizers to collaboratively train a shared model. Depending on the exact type, optimizer may average model parameters with peers, exchange gradients, or follow a more complicated distributed training strategy.
+  This module contains decentralized optimizers that wrap your regular PyTorch Optimizer to train with peers.
+  Depending on the exact configuration, Optimizer may perform large synchronous updates,
+  or perform asynchronous local updates and average model parameters.
+
   <br><br>
 
+.. automodule:: hivemind.optim.optimizer
+.. currentmodule:: hivemind.optim.optimizer
+
+**hivemind.Optimizer**
+----------------------
+
+.. autoclass:: Optimizer
+   :members: step, local_epoch, zero_grad, load_state_from_peers, param_groups, shutdown
+   :member-order: bysource
+
+.. currentmodule:: hivemind.optim.grad_scaler
+.. autoclass:: GradScaler
+   :member-order: bysource
+
+
+**CollaborativeOptimizer**
+--------------------------
+
+
+.. automodule:: hivemind.optim.collaborative
+.. currentmodule:: hivemind.optim
+
 .. autoclass:: CollaborativeOptimizer
    :members: step
    :member-order: bysource

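The new documentation page centers on `hivemind.Optimizer` and its `step`, `local_epoch`, `zero_grad`, `load_state_from_peers`, `param_groups`, and `shutdown` members. A minimal usage sketch mirroring the quickstart changes further down in this diff (the `run_id`, batch sizes, and model are illustrative):

```python
# Minimal hivemind.Optimizer setup, adapted from the quickstart example in this same commit.
import torch
import hivemind

model = torch.nn.Linear(16, 2)
dht = hivemind.DHT(start=True)
opt = hivemind.Optimizer(
    dht=dht,                 # DHT instance connected with other peers
    run_id="demo_run",       # unique identifier of this collaborative run
    batch_size_per_step=32,  # samples contributed by each local opt.step()
    target_batch_size=1024,  # collective samples per epoch before averaging
    optimizer=torch.optim.SGD(model.parameters(), lr=0.1),
    use_local_updates=True,  # apply local gradients, average parameters in background
    matchmaking_time=3.0,
    averaging_timeout=10.0,
    verbose=True,
)
opt.load_state_from_peers()  # sync with the collaboration before training
```
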
+ 8 - 6
docs/modules/server.rst

@@ -9,9 +9,9 @@ or as a part of **hivemind.moe.client.RemoteMixtureOfExperts** that finds the mo
 The hivemind.moe.server module is organized as follows:
 
 - Server_ is the main class that publishes experts, accepts incoming requests, and passes them to Runtime_ for compute.
-- Runtime_ balances the device (GPU) usage between several ExpertBackend_ instances that each service one expert.
 - ExpertBackend_ is a wrapper for `torch.nn.Module <https://pytorch.org/docs/stable/generated/torch.nn.Module.html>`_ \
   that can be accessed by remote clients. It has two TaskPool_ s for forward and backward requests.
+- Runtime_ balances the device (GPU) usage between several ExpertBackend_ instances that each service one expert.
 - TaskPool_ stores incoming requests for a batch-parallel computation (e.g. forward pass), groups them into batches \
   and offers those batches to Runtime_ for processing.
 
@@ -25,16 +25,18 @@ The hivemind.moe.server module is organized as follows:
    :members:
    :member-order: bysource
 
-.. _Runtime:
-.. autoclass:: Runtime
-    :members:
-    :member-order: bysource
-
 .. _ExpertBackend:
 .. autoclass:: ExpertBackend
     :members: forward, backward, apply_gradients, get_info, get_pools
     :member-order: bysource
 
+.. currentmodule:: hivemind.moe.server.runtime
+
+.. _Runtime:
+.. autoclass:: Runtime
+    :members:
+    :member-order: bysource
+
 .. currentmodule:: hivemind.moe.server.task_pool
 
 .. _TaskPool:

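For context, the reordered description above matches how these classes fit together in code. Below is a short sketch of serving a single expert, adapted from the `benchmarks/benchmark_throughput.py` changes earlier in this diff (the `"ffn"` block name, sizes, and the `run_in_background` call are assumptions for illustration):

```python
# Publish one expert via Server/ExpertBackend (sketch adapted from benchmark_throughput.py above).
import torch

from hivemind.moe.server import ExpertBackend, Server
from hivemind.moe.server.layers import name_to_block
from hivemind.utils.networking import LOCALHOST, get_free_port
from hivemind.utils.tensor_descr import BatchTensorDescriptor

hid_dim = 64
expert = torch.jit.script(name_to_block["ffn"](hid_dim))  # "ffn" is assumed to be a registered block
backend = ExpertBackend(
    name="expert0",
    expert=expert,
    optimizer=torch.optim.Adam(expert.parameters()),
    args_schema=(BatchTensorDescriptor(hid_dim),),
    outputs_schema=BatchTensorDescriptor(hid_dim),
    max_batch_size=256,
)
server = Server(
    None,  # no DHT: the expert is only reachable directly, as in the local benchmark
    {"expert0": backend},
    listen_on=f"{LOCALHOST}:{get_free_port()}",
    num_connection_handlers=1,
    device="cpu",
)
server.run_in_background(await_ready=True)  # assumption: run as a background process, as in tests
```
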
+ 35 - 32
docs/user/quickstart.md

@@ -11,7 +11,7 @@ You can also install the bleeding edge version from GitHub:
 ```
 git clone https://github.com/learning-at-home/hivemind
 cd hivemind
-pip install -e .
+pip install -e . --no-use-pep517
 ```
  
 ## Decentralized Training
@@ -47,26 +47,27 @@ model = nn.Sequential(nn.Conv2d(3, 16, (5, 5)), nn.MaxPool2d(2, 2), nn.ReLU(),
                       nn.Flatten(), nn.Linear(32 * 5 * 5, 10))
 opt = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
 
-
 # Create DHT: a decentralized key-value storage shared between peers
 dht = hivemind.DHT(start=True)
 print("To join the training, use initial_peers =", [str(addr) for addr in dht.get_visible_maddrs()])
 
 # Set up a decentralized optimizer that will average with peers in background
-opt = hivemind.optim.DecentralizedOptimizer(
-    opt,                      # wrap the SGD optimizer defined above
-    dht,                      # use a DHT that is connected with other peers
-    average_parameters=True,  # periodically average model weights in opt.step
-    average_gradients=False,  # do not average accumulated gradients
-    prefix='my_cifar_run',    # unique identifier of this collaborative run
-    target_group_size=16,     # maximum concurrent peers for this run
+opt = hivemind.Optimizer(
+    dht=dht,                  # use a DHT that is connected with other peers
+    run_id='my_cifar_run',    # unique identifier of this collaborative run
+    batch_size_per_step=32,   # each call to opt.step adds this many samples towards the next epoch
+    target_batch_size=10000,  # after peers collectively process this many samples, average weights and begin the next epoch 
+    optimizer=opt,            # wrap the SGD optimizer defined above
+    use_local_updates=True,   # perform optimizer steps with local gradients, average parameters in background
+    matchmaking_time=3.0,     # when averaging parameters, gather peers in background for up to this many seconds
+    averaging_timeout=10.0,   # give up on averaging if not successful in this many seconds
     verbose=True              # print logs incessently
 )
-# Note: if you intend to use GPU, switch to it only after the decentralized optimizer is created
 
+# Note: if you intend to use GPU, switch to it only after the decentralized optimizer is created
 with tqdm() as progressbar:
     while True:
-        for x_batch, y_batch in torch.utils.data.DataLoader(trainset, shuffle=True, batch_size=256):
+        for x_batch, y_batch in torch.utils.data.DataLoader(trainset, shuffle=True, batch_size=32):
             opt.zero_grad()
             loss = F.cross_entropy(model(x_batch), y_batch)
             loss.backward()
@@ -78,7 +79,7 @@ with tqdm() as progressbar:
 
 
 As you can see, this code is regular PyTorch with one notable exception: it wraps your regular optimizer with a
-`DecentralizedOptimizer`. This optimizer uses `DHT` to find other peers and tries to exchange weights them. When you run
+`hivemind.Optimizer`. This optimizer uses `DHT` to find other peers and tries to exchange parameters with them. When you run
 the code (please do so), you will see the following output:
 
 ```shell
@@ -86,7 +87,7 @@ To join the training, use initial_peers = ['/ip4/127.0.0.1/tcp/XXX/p2p/YYY']
 [...] Starting a new averaging round with current parameters.
 ```
 
-This is `DecentralizedOptimizer` telling you that it's looking for peers. Since there are no peers, we'll need to create 
+This is `hivemind.Optimizer` telling you that it's looking for peers. Since there are no peers, we'll need to create 
 them ourselves.
 
 Copy the entire script (or notebook) and modify this line:
@@ -123,26 +124,28 @@ model = nn.Sequential(nn.Conv2d(3, 16, (5, 5)), nn.MaxPool2d(2, 2), nn.ReLU(),
 opt = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
 
 # Create DHT: a decentralized key-value storage shared between peers
-dht = hivemind.DHT(initial_peers=[COPY_FROM_ANOTHER_PEER_OUTPUTS], start=True)
+dht = hivemind.DHT(initial_peers=[COPY_FROM_OTHER_PEERS_OUTPUTS], start=True)
 print("To join the training, use initial_peers =", [str(addr) for addr in dht.get_visible_maddrs()])
 
 # Set up a decentralized optimizer that will average with peers in background
-opt = hivemind.optim.DecentralizedOptimizer(
-    opt,                      # wrap the SGD optimizer defined above
-    dht,                      # use a DHT that is connected with other peers
-    average_parameters=True,  # periodically average model weights in opt.step
-    average_gradients=False,  # do not average accumulated gradients
-    prefix='my_cifar_run',    # unique identifier of this collaborative run
-    target_group_size=16,     # maximum concurrent peers for this run
+opt = hivemind.Optimizer(
+    dht=dht,                  # use a DHT that is connected with other peers
+    run_id='my_cifar_run',    # unique identifier of this collaborative run
+    batch_size_per_step=32,   # each call to opt.step adds this many samples towards the next epoch
+    target_batch_size=10000,  # after peers collectively process this many samples, average weights and begin the next epoch
+    optimizer=opt,            # wrap the SGD optimizer defined above
+    use_local_updates=True,   # perform optimizer steps with local gradients, average parameters in background
+    matchmaking_time=3.0,     # when averaging parameters, gather peers in background for up to this many seconds
+    averaging_timeout=10.0,   # give up on averaging if not successful in this many seconds
     verbose=True              # print logs incessently
 )
 
-opt.averager.load_state_from_peers()
+opt.load_state_from_peers()
 
-# Note: if you intend to use GPU, switch to it only after the decentralized optimizer is created
+# Note: if you intend to use GPU, switch to it only after the optimizer is created
 with tqdm() as progressbar:
     while True:
-        for x_batch, y_batch in torch.utils.data.DataLoader(trainset, shuffle=True, batch_size=256):
+        for x_batch, y_batch in torch.utils.data.DataLoader(trainset, shuffle=True, batch_size=32):
             opt.zero_grad()
             loss = F.cross_entropy(model(x_batch), y_batch)
             loss.backward()
@@ -166,22 +169,22 @@ This message means that the optimizer has averaged model parameters with another
 during one of the calls to `opt.step()`. You can start more peers by replicating the same code as the second peer,
 using either the first or second peer as `initial_peers`.
 
-The only issue with this code is that each new peer starts with a different untrained network blends its un-trained
-parameters with other peers, reseting their progress. You can see this effect as a spike increase in training loss
-immediately after new peer joins training. To avoid this problem, the second peer can download the
-current model/optimizer state from an existing peer right before it begins training on minibatches:
+Each new peer starts with an untrained network and must download the latest training state before it can contribute.
+By default, a peer will automatically detect that it is out of sync and start ``Downloading parameters from peer <...>``.
+To avoid wasting the first optimizer step, one can manually download the latest model/optimizer state right before it begins training on minibatches:
 ```python
-opt.averager.load_state_from_peers()
+opt.load_state_from_peers()
 ```
 
 Congrats, you've just started a pocket-sized experiment with decentralized deep learning!
 
-However, this is just the bare minimum of what hivemind can do. In [this example](https://github.com/learning-at-home/hivemind/tree/master/examples/albert),
+However, this is only the basics of what hivemind can do. In [this example](https://github.com/learning-at-home/hivemind/tree/master/examples/albert),
 we show how to use a more advanced version of DecentralizedOptimizer to collaboratively train a large Transformer over the internet.
 
 If you want to learn more about each individual component,
 - Learn how to use `hivemind.DHT` using this basic [DHT tutorial](https://learning-at-home.readthedocs.io/en/latest/user/dht.html),
-- Learn the underlying math behind DecentralizedOptimizer in
-  [(Li et al. 2020)](https://arxiv.org/abs/2005.00124) and [(Ryabinin et al. 2021)](https://arxiv.org/abs/2103.03239).
+- Read more on how to use `hivemind.Optimizer` in its [documentation page](https://learning-at-home.readthedocs.io/en/latest/modules/optim.html), 
+- Learn the underlying math behind hivemind.Optimizer in [Diskin et al., (2021)](https://arxiv.org/abs/2106.10207), 
+  [Li et al. (2020)](https://arxiv.org/abs/2005.00124) and [Ryabinin et al. (2021)](https://arxiv.org/abs/2103.03239).
 - Read about setting up Mixture-of-Experts training in [this guide](https://learning-at-home.readthedocs.io/en/latest/user/moe.html),
  

+ 33 - 32
examples/albert/README.md

@@ -9,7 +9,7 @@ using `hivemind.CollaborativeOptimizer` to exchange information between peers.
 
 * Install hivemind: `pip install git+https://github.com/learning-at-home/hivemind.git`
 * Dependencies: `pip install -r requirements.txt`
-* Preprocess data: `python tokenize_wikitext103.py`
+* Preprocess data: `./tokenize_wikitext103.py`
 * Upload the data to a publicly available location or ask volunteers to preprocess it locally
 
 ## Running an experiment
@@ -20,16 +20,16 @@ Run the first DHT peer to welcome trainers and record training statistics (e.g.,
 
 - In this example, we use [wandb.ai](https://wandb.ai/site) to plot training metrics. If you're unfamiliar with Weights
   & Biases, here's a [quickstart tutorial](https://docs.wandb.ai/quickstart).
-- Run `python run_training_monitor.py --experiment_prefix NAME_YOUR_EXPERIMENT --wandb_project WANDB_PROJECT_HERE`
-- `NAME_YOUR_EXPERIMENT` must be a unique name of this training run, e.g. `my-first-albert`. It cannot contain `.`
-  due to naming conventions.
-- `WANDB_PROJECT_HERE` is a name of wandb project used to track training metrics. Multiple experiments can have the
-  same project name.
+- Run `./run_training_monitor.py --wandb_project YOUR_WANDB_PROJECT`
+
+  - `YOUR_WANDB_PROJECT` is a name of wandb project used to track training metrics. Multiple experiments can have the
+    same project name.
 
 ```
-$ python run_training_monitor.py --experiment_prefix my-albert-v1 --wandb_project Demo-run
-[2021/06/17 16:26:36.083][INFO][root.log_visible_maddrs:54] Running a DHT peer. To connect other peers to this one over the Internet, 
+$ ./run_training_monitor.py --wandb_project Demo-run
+Oct 14 16:26:36.083 [INFO] Running a DHT peer. To connect other peers to this one over the Internet,
 use --initial_peers /ip4/1.2.3.4/tcp/1337/p2p/XXXX /ip4/1.2.3.4/udp/31337/quic/p2p/XXXX
+Oct 14 16:26:36.083 [INFO] Full list of visible multiaddresses: ...
 wandb: Currently logged in as: XXX (use `wandb login --relogin` to force relogin)
 wandb: Tracking run with wandb version 0.10.32
 wandb: Syncing run dry-mountain-2
@@ -37,12 +37,12 @@ wandb:  View project at https://wandb.ai/XXX/Demo-run
 wandb:  View run at https://wandb.ai/XXX/Demo-run/runs/YYY
 wandb: Run data is saved locally in /path/to/run/data
 wandb: Run `wandb offline` to turn off syncing.
-[2021/04/19 02:26:41.064][INFO][optim.collaborative.fetch_collaboration_state:323] Found no active peers: None
-[2021/04/19 02:26:44.068][INFO][optim.collaborative.fetch_collaboration_state:323] Found no active peers: None
+Oct 14 16:26:41.064 [INFO] Found no active peers: None
+Oct 14 16:26:44.068 [INFO] Found no active peers: None
 ...
-[2021/04/19 02:37:37.246][INFO][__main__.<module>:194] Step #1  loss = 11.05164
-[2021/04/19 02:39:37.441][INFO][__main__.<module>:194] Step #2  loss = 11.03771
-[2021/04/19 02:40:37.541][INFO][__main__.<module>:194] Step #3  loss = 11.02886
+Oct 14 16:37:37.246 [INFO] Step #1  loss = 11.05164
+Oct 14 16:39:37.441 [INFO] Step #2  loss = 11.03771
+Oct 14 16:40:37.541 [INFO] Step #3  loss = 11.02886
 ```
 
 ### GPU trainers
@@ -55,9 +55,7 @@ To join the collaboration with a GPU trainer,
   (see [default paths](./arguments.py#L117-L134) for reference)
 - Run:
   ```bash
-  python run_trainer.py \
-  --experiment_prefix SAME_AS_IN_RUN_TRAINING_MONITOR --initial_peers ONE_OR_MORE_PEERS --seed 42 \
-  --logging_first_step --logging_steps 100  --output_dir ./outputs --overwrite_output_dir --logging_dir ./logs
+  ./run_trainer.py  --initial_peers ONE_OR_MORE_PEERS --per_device_train_batch_size BATCH_SIZE_FOR_YOUR_GPU
   ```
 
   Here, `ONE_OR_MORE_PEERS` stands for multiaddresses of one or multiple existing peers (training monitors or existing
@@ -82,22 +80,26 @@ To join the collaboration with a GPU trainer,
   You may need to change the IP address to a publicly visible one if some of the initial peers are located behind NAT.
   If you have any trouble doing this, consider the ["Using IPFS"](#using-ipfs) section.
 
+  The `BATCH_SIZE_FOR_YOUR_GPU` should be tweaked so that the model fits into your GPU memory.
+  For 1080Ti or 2080Ti gpus, a good initial value is 4. For 8GB GPUs, try batch size 1-2.
+
 See the ["Tips and tricks"](#tips-and-tricks) section for more information on setting up collaborative training.
 
 As the peer begins training, it will periodically report training logs in the following form:
 
 ```
-[...][INFO][...] Collaboration accumulated 448 samples from 17 peers; ETA 18.88 seconds (refresh in 15.73s.)
-[...][INFO][...] Collaboration accumulated 4096 samples from 16 peers; ETA 0.00 seconds (refresh in 0.50s.)
-[...][INFO][optim.collaborative.step:195] Averaged tensors successfully with 17 peers
-[...][INFO][optim.collaborative.step:211] Optimizer step: done!
-06/17/2021 18:58:23 - INFO - __main__ -   Step 0
-06/17/2021 18:58:23 - INFO - __main__ -   Your current contribution: 892 samples
-06/17/2021 18:58:23 - INFO - __main__ -   Local loss: 11.023
-
+Dec 28 00:15:31.482 [INFO] albert accumulated 4056 samples for epoch #0 from 2 peers. ETA 0.75 sec (refresh in 0.50 sec)
+Dec 28 00:15:31.990 [INFO] albert accumulated 4072 samples for epoch #0 from 2 peers. ETA 0.24 sec (refresh in 0.50 sec)
+...
+Dec 28 00:15:32.857 [INFO] Step #1
+Dec 28 00:15:32.857 [INFO] Your current contribution: 2144 samples
+Dec 28 00:15:32.857 [INFO] Performance: 20.924 samples/sec
+Dec 28 00:15:32.857 [INFO] Local loss: 11.06709
+Dec 28 00:15:33.580 [INFO] Averaged gradients with 2 peers
+Dec 28 00:15:38.336 [INFO] Averaged parameters with 2 peers
 ```
 
-__Sanity check:__ a healthy peer will periodically report `Averaged tensors successfully with [N > 1]` peers.
+__Sanity check:__ a healthy peer will periodically report `Averaged gradients/parameters with [N > 1]` peers.
 
 For convenience, you can view (and share!) the learning curves of your collaborative experiments in wandb:
 
@@ -135,7 +137,7 @@ incoming connections (e.g. when in colab or behind a firewall), add `--client_mo
 below). In case of high network latency, you may want to increase `--averaging_expiration` by a few seconds or
 set `--batch_size_lead` to start averaging a bit earlier than the rest of the collaboration. GPU-wise, each peer should
 be able to process one local microbatch each 0.5–1 seconds (see trainer's progress bar). To achieve that, we
-recommend tuning `--per_device_train_batch_size` and `--gradient_accumulation_steps`. 
+recommend tuning `--per_device_train_batch_size` and `--gradient_accumulation_steps`.
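
For example (hypothetical numbers for a single GPU), the two flags multiply into the number of samples a peer
reports per trainer step:

```bash
# 4 samples per microbatch x 2 accumulation steps = 8 samples per trainer step
./run_trainer.py --initial_peers ONE_OR_MORE_PEERS \
    --per_device_train_batch_size 4 --gradient_accumulation_steps 2
```

If a microbatch takes longer than about a second, reduce `--per_device_train_batch_size` and increase
`--gradient_accumulation_steps` to keep the per-step contribution roughly constant.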
 
 The example trainer supports
 multiple GPUs via DataParallel. However, using advanced distributed training strategies (
@@ -155,7 +157,7 @@ collaborative experiment. Here's how to best use them:
 - Most free GPUs are running behind a firewall, which requires you to run trainer with `--client_mode` (see example
   below). Such peers can only exchange gradients if there is at least one non-client-mode peer (GPU server or desktop
   with public IP). We recommend using a few preemptible instances with the cheapest GPU you can find. For example, we
-  tested this code on preemptible 
+  tested this code on preemptible
   [`g4dn.xlarge`](https://aws.amazon.com/blogs/aws/now-available-ec2-instances-g4-with-nvidia-t4-tensor-core-gpus/)
   nodes for around $0.15/h apiece with 8 AWS nodes and up to 61 Colab/Kaggle participants.
 - You can create starter notebooks to make it more convenient for collaborators to join your training
@@ -168,11 +170,10 @@ Here's an example of a full trainer script for Google Colab:
 !pip install transformers datasets sentencepiece torch_optimizer==0.1.0
 !git clone https://github.com/learning-at-home/hivemind && cd hivemind && pip install -e .
 !curl -L YOUR_HOSTED_DATA | tar xzf -
-!ulimit -n 4096 && python ./hivemind/examples/albert/run_trainer.py \
- --client_mode --initial_peers ONE_OR_MORE_PEERS  --averaging_expiration 10 \
- --batch_size_lead 300 --per_device_train_batch_size 4 --gradient_accumulation_steps 1 \
- --logging_first_step --logging_steps 100  --output_dir ./outputs --overwrite_output_dir --logging_dir ./logs \
- --experiment_prefix EXPERIMENT_NAME_HERE --seed 42
+!ulimit -n 4096 && ./hivemind/examples/albert/run_trainer.py \
+    --initial_peers ONE_OR_MORE_PEERS \
+    --logging_dir ./logs --logging_first_step --output_dir ./outputs --overwrite_output_dir \
+    --client_mode --averaging_expiration 10 --batch_size_lead 300 --gradient_accumulation_steps 1
 ```
 
 ### Using IPFS

+ 23 - 17
examples/albert/arguments.py

@@ -6,8 +6,8 @@ from transformers import TrainingArguments
 
 @dataclass
 class BaseTrainingArguments:
-    experiment_prefix: str = field(
-        metadata={"help": "A unique 'name' of this experiment, used to store metadata on the DHT"}
+    run_id: str = field(
+        default="albert", metadata={"help": "A unique 'name' of this experiment, used to store metadata on the DHT"}
     )
     initial_peers: List[str] = field(
         default_factory=list,
@@ -45,12 +45,11 @@ class BaseTrainingArguments:
 
 @dataclass
 class AveragerArguments:
-    averaging_expiration: float = field(
-        default=5.0, metadata={"help": "Averaging group will wait for stragglers for at most this many seconds"}
-    )
-    averaging_timeout: float = field(
-        default=60.0, metadata={"help": "Give up on averaging step after this many seconds"}
-    )
+    target_group_size: int = field(default=256, metadata={"help": "Maximum group size for all-reduce"})
+
+
+@dataclass
+class ProgressTrackerArguments:
     min_refresh_period: float = field(
         default=0.5, metadata={"help": "Wait for at least this many seconds before fetching new collaboration state"}
     )
@@ -66,17 +65,13 @@ class AveragerArguments:
     expected_drift_rate: float = field(
         default=0.2, metadata={"help": "Trainer assumes that this fraction of current size can join per step"}
     )
-    performance_ema_alpha: float = field(
-        default=0.1, metadata={"help": "Uses this alpha for moving average estimate of samples per second"}
-    )
-    target_group_size: int = field(default=256, metadata={"help": "Maximum group size for all-reduce"})
     metadata_expiration: float = field(
         default=120, metadata={"help": "Peer's metadata will be removed if not updated in this many seconds"}
     )
 
 
 @dataclass
-class CollaborativeOptimizerArguments:
+class OptimizerArguments:
     target_batch_size: int = field(
         default=4096,
         metadata={"help": "Perform optimizer step after all peers collectively accumulate this many samples"},
@@ -93,10 +88,16 @@ class CollaborativeOptimizerArguments:
         default=100.0,
         metadata={"help": "Available network bandwidth, in mbps (used for load balancing in all-reduce)"},
     )
+    averaging_timeout: float = field(
+        default=60.0, metadata={"help": "Give up on averaging step after this many seconds"}
+    )
+    matchmaking_time: float = field(
+        default=5.0, metadata={"help": "When looking for group, wait for requests for at least this many seconds"}
+    )
 
 
 @dataclass
-class CollaborationArguments(CollaborativeOptimizerArguments, BaseTrainingArguments):
+class CollaborationArguments(OptimizerArguments, BaseTrainingArguments):
     statistics_expiration: float = field(
         default=600, metadata={"help": "Statistics will be removed if not updated in this many seconds"}
     )
@@ -126,7 +127,7 @@ class AlbertTrainingArguments(TrainingArguments):
     gradient_accumulation_steps: int = 2
     seq_length: int = 512
 
-    max_steps: int = 125_000  # please note: this affects both number of steps and learning rate schedule
+    total_steps: int = 125_000  # please note: this only affects the learning rate schedule
     learning_rate: float = 0.00176
     warmup_steps: int = 5000
     adam_epsilon: float = 1e-6
@@ -137,9 +138,14 @@ class AlbertTrainingArguments(TrainingArguments):
     fp16: bool = True
     fp16_opt_level: str = "O2"
     do_train: bool = True
+    do_eval: bool = False
 
+    logging_dir: str = "logs"
+    output_dir: str = "outputs"
     logging_steps: int = 100
+    logging_first_step: bool = True
+    overwrite_output_dir: bool = True
+
     save_total_limit: int = 2
     save_steps: int = 500
-
-    output_dir: str = "outputs"
+    max_steps: int = 10**30  # meant as "peer should compute gradients forever"
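
Since `HfArgumentParser` exposes every dataclass field above as a command-line flag of the same name, the renamed
and newly added options can be overridden at launch time. A hypothetical invocation (the values shown are simply
the defaults defined above):

```bash
./run_trainer.py --initial_peers ONE_OR_MORE_PEERS \
    --run_id albert --target_batch_size 4096 \
    --matchmaking_time 5.0 --averaging_timeout 60.0
```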

+ 5 - 5
examples/albert/requirements.txt

@@ -1,7 +1,7 @@
-transformers>=4.6.0
-datasets>=1.5.0
-torch_optimizer>=0.1.0
-wandb>=0.10.26
+transformers==4.6.0
+datasets==1.5.0
+torch_optimizer==0.1.0
+wandb==0.10.26
 sentencepiece
 requests
-nltk>=3.6.2
+nltk==3.6.7

+ 98 - 75
examples/albert/run_trainer.py

@@ -1,7 +1,8 @@
-#!/usr/bin/env python
+#!/usr/bin/env python3
 
 import os
 import pickle
+import sys
 from dataclasses import asdict
 from pathlib import Path
 
@@ -16,11 +17,17 @@ from transformers.optimization import get_linear_schedule_with_warmup
 from transformers.trainer import Trainer
 from transformers.trainer_utils import is_main_process
 
-import hivemind
+from hivemind import DHT, Float16Compression, Optimizer, get_dht_time
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
 import utils
-from arguments import AlbertTrainingArguments, AveragerArguments, CollaborationArguments, DatasetArguments
+from arguments import (
+    AlbertTrainingArguments,
+    AveragerArguments,
+    CollaborationArguments,
+    DatasetArguments,
+    ProgressTrackerArguments,
+)
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
@@ -52,36 +59,6 @@ def get_model(training_args, config, tokenizer):
     return model
 
 
-def get_optimizer_and_scheduler(training_args, model):
-    no_decay = ["bias", "LayerNorm.weight"]
-    optimizer_grouped_parameters = [
-        {
-            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
-            "weight_decay": training_args.weight_decay,
-        },
-        {
-            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
-            "weight_decay": 0.0,
-        },
-    ]
-
-    opt = Lamb(
-        optimizer_grouped_parameters,
-        lr=training_args.learning_rate,
-        betas=(training_args.adam_beta1, training_args.adam_beta2),
-        eps=training_args.adam_epsilon,
-        weight_decay=training_args.weight_decay,
-        clamp_value=training_args.clamp_value,
-        debias=True,
-    )
-
-    scheduler = get_linear_schedule_with_warmup(
-        opt, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.max_steps
-    )
-
-    return opt, scheduler
-
-
 class CollaborativeCallback(transformers.TrainerCallback):
     """
     This callback monitors and reports collaborative training progress.
@@ -90,8 +67,8 @@ class CollaborativeCallback(transformers.TrainerCallback):
 
     def __init__(
         self,
-        dht: hivemind.DHT,
-        optimizer: hivemind.CollaborativeOptimizer,
+        dht: DHT,
+        optimizer: Optimizer,
         model: torch.nn.Module,
         local_public_key: bytes,
         statistics_expiration: float,
@@ -99,7 +76,7 @@ class CollaborativeCallback(transformers.TrainerCallback):
     ):
         super().__init__()
         self.model = model
-        self.dht, self.collaborative_optimizer = dht, optimizer
+        self.dht, self.optimizer = dht, optimizer
         self.local_public_key = local_public_key
         self.statistics_expiration = statistics_expiration
         self.last_reported_collaboration_step = -1
@@ -114,7 +91,7 @@ class CollaborativeCallback(transformers.TrainerCallback):
         self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs
     ):
         logger.info("Loading state from peers")
-        self.collaborative_optimizer.load_state_from_peers()
+        self.optimizer.load_state_from_peers()
 
     def on_step_end(
         self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs
@@ -124,40 +101,43 @@ class CollaborativeCallback(transformers.TrainerCallback):
             self.restore_from_backup(self.latest_backup)
             return control
 
+        local_progress = self.optimizer.local_progress
+
         if state.log_history:
             self.loss += state.log_history[-1]["loss"]
             self.steps += 1
-            if self.collaborative_optimizer.local_step != self.last_reported_collaboration_step:
-                self.last_reported_collaboration_step = self.collaborative_optimizer.local_step
+
+            if self.optimizer.local_epoch != self.last_reported_collaboration_step:
+                self.last_reported_collaboration_step = self.optimizer.local_epoch
                 self.total_samples_processed += self.samples
-                samples_per_second = self.collaborative_optimizer.performance_ema.samples_per_second
+                samples_per_second = local_progress.samples_per_second
                 statistics = utils.LocalMetrics(
-                    step=self.collaborative_optimizer.local_step,
+                    step=self.optimizer.local_epoch,
                     samples_per_second=samples_per_second,
                     samples_accumulated=self.samples,
                     loss=self.loss,
                     mini_steps=self.steps,
                 )
-                logger.info(f"Step #{self.collaborative_optimizer.local_step}")
+                logger.info(f"Step #{self.optimizer.local_epoch}")
                 logger.info(f"Your current contribution: {self.total_samples_processed} samples")
-                logger.info(f"Performance: {samples_per_second} samples per second.")
+                logger.info(f"Performance: {samples_per_second:.3f} samples/sec")
                 if self.steps:
-                    logger.info(f"Local loss: {self.loss / self.steps}")
-                if self.collaborative_optimizer.local_step % self.backup_every_steps == 0:
+                    logger.info(f"Local loss: {self.loss / self.steps:.5f}")
+                if self.optimizer.local_epoch % self.backup_every_steps == 0:
                     self.latest_backup = self.backup_state()
 
                 self.loss = 0
                 self.steps = 0
-                if self.collaborative_optimizer.is_synchronized:
+                if self.optimizer.is_synchronized_with_peers():
                     self.dht.store(
-                        key=self.collaborative_optimizer.prefix + "_metrics",
+                        key=self.optimizer.run_id + "_metrics",
                         subkey=self.local_public_key,
                         value=statistics.dict(),
-                        expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
+                        expiration_time=get_dht_time() + self.statistics_expiration,
                         return_future=True,
                     )
 
-        self.samples = self.collaborative_optimizer.local_samples_accumulated
+        self.samples = local_progress.samples_accumulated
 
         return control
 
@@ -170,19 +150,17 @@ class CollaborativeCallback(transformers.TrainerCallback):
 
     @torch.no_grad()
     def backup_state(self) -> bytes:
-        return pickle.dumps(
-            {"model": self.model.state_dict(), "optimizer": self.collaborative_optimizer.opt.state_dict()}
-        )
+        return pickle.dumps({"model": self.model.state_dict(), "optimizer": self.optimizer.state_dict()})
 
     @torch.no_grad()
     def restore_from_backup(self, backup: bytes):
         state = pickle.loads(backup)
         self.model.load_state_dict(state["model"])
-        self.collaborative_optimizer.opt.load_state_dict(state["optimizer"])
+        self.optimizer.load_state_dict(state["optimizer"])
 
 
 class NoOpScheduler(LRSchedulerBase):
-    """Dummy scheduler for transformers.Trainer. The real scheduler is defined in CollaborativeOptimizer.scheduler"""
+    """Dummy scheduler for transformers.Trainer. The real scheduler is defined in Optimizer.scheduler"""
 
     def get_lr(self):
         return [group["lr"] for group in self.optimizer.param_groups]
@@ -202,12 +180,17 @@ class NoOpScheduler(LRSchedulerBase):
 
 
 def main():
-    parser = HfArgumentParser((AlbertTrainingArguments, DatasetArguments, CollaborationArguments, AveragerArguments))
-    training_args, dataset_args, collaboration_args, averager_args = parser.parse_args_into_dataclasses()
-
+    parser = HfArgumentParser(
+        (
+            AlbertTrainingArguments,
+            DatasetArguments,
+            CollaborationArguments,
+            AveragerArguments,
+            ProgressTrackerArguments,
+        )
+    )
+    training_args, dataset_args, collaboration_args, averager_args, tracker_args = parser.parse_args_into_dataclasses()
     logger.info(f"Found {len(collaboration_args.initial_peers)} initial peers: {collaboration_args.initial_peers}")
-    if len(collaboration_args.initial_peers) == 0:
-        raise ValueError("Please specify at least one network endpoint in initial peers.")
 
     setup_transformers_logging(training_args.local_rank)
     logger.info(f"Training/evaluation parameters:\n{training_args}")
@@ -216,7 +199,15 @@ def main():
     set_seed(training_args.seed)
 
     config = AlbertConfig.from_pretrained(dataset_args.config_path, cache_dir=dataset_args.cache_dir)
-    tokenizer = AlbertTokenizerFast.from_pretrained(dataset_args.tokenizer_path, cache_dir=dataset_args.cache_dir)
+    try:
+        tokenizer = AlbertTokenizerFast.from_pretrained(dataset_args.tokenizer_path, cache_dir=dataset_args.cache_dir)
+    except OSError:
+        logger.fatal(
+            f"No tokenizer data found in {dataset_args.tokenizer_path}, "
+            f"please run ./tokenize_wikitext103.py before running this"
+        )
+        sys.exit(1)
+
     model = get_model(training_args, config, tokenizer)
     model.to(training_args.device)
 
@@ -224,11 +215,9 @@ def main():
     # This data collator will take care of randomly masking the tokens.
     data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer)
 
-    opt, scheduler = get_optimizer_and_scheduler(training_args, model)
+    validators, local_public_key = utils.make_validators(collaboration_args.run_id)
 
-    validators, local_public_key = utils.make_validators(collaboration_args.experiment_prefix)
-
-    dht = hivemind.DHT(
+    dht = DHT(
         start=True,
         initial_peers=collaboration_args.initial_peers,
         client_mode=collaboration_args.client_mode,
@@ -246,19 +235,53 @@ def main():
 
     adjusted_target_batch_size = collaboration_args.target_batch_size - collaboration_args.batch_size_lead
 
-    collaborative_optimizer = hivemind.CollaborativeOptimizer(
-        opt=opt,
+    # We need to make such a lambda function instead of just an optimizer instance
+    # to make hivemind.Optimizer(..., offload_optimizer=True) work
+    opt = lambda params: Lamb(
+        params,
+        lr=training_args.learning_rate,
+        betas=(training_args.adam_beta1, training_args.adam_beta2),
+        eps=training_args.adam_epsilon,
+        weight_decay=training_args.weight_decay,
+        clamp_value=training_args.clamp_value,
+        debias=True,
+    )
+
+    no_decay = ["bias", "LayerNorm.weight"]
+    params = [
+        {
+            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+            "weight_decay": training_args.weight_decay,
+        },
+        {
+            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
+            "weight_decay": 0.0,
+        },
+    ]
+
+    scheduler = lambda opt: get_linear_schedule_with_warmup(
+        opt, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.total_steps
+    )
+
+    optimizer = Optimizer(
         dht=dht,
-        scheduler=scheduler,
-        prefix=collaboration_args.experiment_prefix,
-        compression=hivemind.Float16Compression(),
-        batch_size_per_step=total_batch_size_per_step,
-        bandwidth=collaboration_args.bandwidth,
+        run_id=collaboration_args.run_id,
         target_batch_size=adjusted_target_batch_size,
+        batch_size_per_step=total_batch_size_per_step,
+        optimizer=opt,
+        params=params,
+        scheduler=scheduler,
+        matchmaking_time=collaboration_args.matchmaking_time,
+        averaging_timeout=collaboration_args.averaging_timeout,
+        offload_optimizer=True,
+        delay_optimizer_step=True,
+        delay_grad_averaging=True,
         client_mode=collaboration_args.client_mode,
+        grad_compression=Float16Compression(),
+        state_averaging_compression=Float16Compression(),
+        averager_opts={"bandwidth": collaboration_args.bandwidth, **asdict(averager_args)},
+        tracker_opts=asdict(tracker_args),
         verbose=True,
-        start=True,
-        **asdict(averager_args),
     )
 
     class TrainerWithIndependentShuffling(Trainer):
@@ -274,11 +297,11 @@ def main():
         data_collator=data_collator,
         train_dataset=tokenized_datasets["train"] if training_args.do_train else None,
         eval_dataset=tokenized_datasets["validation"] if training_args.do_eval else None,
-        optimizers=(collaborative_optimizer, NoOpScheduler(collaborative_optimizer)),
+        optimizers=(optimizer, NoOpScheduler(optimizer)),
         callbacks=[
             CollaborativeCallback(
                 dht,
-                collaborative_optimizer,
+                optimizer,
                 model,
                 local_public_key,
                 collaboration_args.statistics_expiration,
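
Outside the `transformers.Trainer` integration above, the same `hivemind.Optimizer` wiring can be boiled down to a
standalone sketch. The snippet below is illustrative only (toy model, placeholder peers and hyperparameters); it
mirrors the arguments used in `run_trainer.py`:

```python
import torch
from hivemind import DHT, Float16Compression, Optimizer

model = torch.nn.Linear(512, 2)          # toy stand-in for the ALBERT model
dht = DHT(initial_peers=[], start=True)  # pass real multiaddrs here to join an existing run

opt = Optimizer(
    dht=dht,
    run_id="albert",                     # must be identical across all peers of the run
    target_batch_size=4096,              # collective samples per optimizer epoch
    batch_size_per_step=4,               # samples this peer reports per .step() call
    optimizer=lambda params: torch.optim.Adam(params, lr=1e-3),  # factory, required for offload_optimizer
    params=list(model.parameters()),
    offload_optimizer=True,
    delay_optimizer_step=True,
    delay_grad_averaging=True,
    grad_compression=Float16Compression(),
    state_averaging_compression=Float16Compression(),
    verbose=True,
)

for _ in range(10):
    loss = model(torch.randn(4, 512)).pow(2).mean()
    loss.backward()
    opt.step()      # averages with peers once the collaboration reaches target_batch_size
    opt.zero_grad()
```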

+ 23 - 25
examples/albert/run_training_monitor.py

@@ -1,4 +1,4 @@
-#!/usr/bin/env python
+#!/usr/bin/env python3
 
 import time
 from dataclasses import asdict, dataclass, field
@@ -9,13 +9,14 @@ import requests
 import torch
 import wandb
 from torch_optimizer import Lamb
-from transformers import AlbertConfig, AlbertForPreTraining, HfArgumentParser
+from transformers import AlbertConfig, AlbertForPreTraining, HfArgumentParser, get_linear_schedule_with_warmup
 
 import hivemind
+from hivemind.optim.state_averager import TrainingStateAverager
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
 import utils
-from arguments import AveragerArguments, BaseTrainingArguments, CollaborativeOptimizerArguments
+from arguments import AveragerArguments, BaseTrainingArguments, OptimizerArguments
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
@@ -39,6 +40,7 @@ class TrainingMonitorArguments(BaseTrainingArguments):
     wandb_project: Optional[str] = field(
         default=None, metadata={"help": "Name of Weights & Biases project to report the training progress to"}
     )
+    store_checkpoints: bool = field(default=True, metadata={"help": "If False, disables periodic checkpoint saving"})
     save_checkpoint_step_interval: int = field(
         default=5, metadata={"help": "Frequency (in steps) of fetching and saving state from peers"}
     )
@@ -55,14 +57,13 @@ class TrainingMonitorArguments(BaseTrainingArguments):
     upload_interval: Optional[float] = field(
         default=None, metadata={"help": "Frequency (in seconds) of uploading the model to Hub"}
     )
-    store_checkpoins: bool = field(default=False, metadata={"help": "If True, enables CheckpointHandler"})
 
 
 class CheckpointHandler:
     def __init__(
         self,
         monitor_args: TrainingMonitorArguments,
-        collab_optimizer_args: CollaborativeOptimizerArguments,
+        optimizer_args: OptimizerArguments,
         averager_args: AveragerArguments,
         dht: hivemind.DHT,
     ):
@@ -95,17 +96,14 @@ class CheckpointHandler:
             debias=True,
         )
 
-        adjusted_target_batch_size = collab_optimizer_args.target_batch_size - collab_optimizer_args.batch_size_lead
-
-        self.collaborative_optimizer = hivemind.CollaborativeOptimizer(
-            opt=opt,
+        self.state_averager = TrainingStateAverager(
             dht=dht,
-            prefix=experiment_prefix,
-            compression_type=hivemind.Float16Compression(),
-            bandwidth=collab_optimizer_args.bandwidth,
-            target_batch_size=adjusted_target_batch_size,
-            client_mode=collab_optimizer_args.client_mode,
-            verbose=True,
+            optimizer=opt,
+            scheduler=get_linear_schedule_with_warmup(opt, num_warmup_steps=5000, num_training_steps=125_000),
+            prefix=f"{run_id}_state_averager",
+            state_compression=hivemind.Float16Compression(),
+            bandwidth=optimizer_args.bandwidth,
+            client_mode=optimizer_args.client_mode,
             start=True,
             **asdict(averager_args),
         )
@@ -121,7 +119,7 @@ class CheckpointHandler:
 
     def save_state(self, cur_step):
         logger.info("Saving state from peers")
-        self.collaborative_optimizer.load_state_from_peers()
+        self.state_averager.load_state_from_peers()
         self.previous_step = cur_step
 
     def is_time_to_upload(self):
@@ -134,7 +132,7 @@ class CheckpointHandler:
 
     def upload_checkpoint(self, current_loss):
         logger.info("Saving optimizer")
-        torch.save(self.collaborative_optimizer.opt.state_dict(), f"{self.repo_path}/optimizer_state.pt")
+        torch.save(self.state_averager.optimizer.state_dict(), f"{self.repo_path}/optimizer_state.pt")
         self.previous_timestamp = time.time()
         logger.info("Started uploading to Model Hub")
         self.model.push_to_hub(
@@ -146,8 +144,8 @@ class CheckpointHandler:
 
 
 if __name__ == "__main__":
-    parser = HfArgumentParser((TrainingMonitorArguments, CollaborativeOptimizerArguments, AveragerArguments))
-    monitor_args, collab_optimizer_args, averager_args = parser.parse_args_into_dataclasses()
+    parser = HfArgumentParser((TrainingMonitorArguments, OptimizerArguments, AveragerArguments))
+    monitor_args, optimizer_args, averager_args = parser.parse_args_into_dataclasses()
 
     if monitor_args.use_google_dns:
         request = requests.get("https://api.ipify.org")
@@ -158,8 +156,8 @@ if __name__ == "__main__":
         version = ip_address(address).version
         monitor_args.announce_maddrs += [f"/ip{version}/{address}/tcp/0"]
 
-    experiment_prefix = monitor_args.experiment_prefix
-    validators, local_public_key = utils.make_validators(experiment_prefix)
+    run_id = monitor_args.run_id
+    validators, local_public_key = utils.make_validators(run_id)
 
     dht = hivemind.DHT(
         start=True,
@@ -176,11 +174,11 @@ if __name__ == "__main__":
         wandb.init(project=monitor_args.wandb_project)
 
     current_step = 0
-    if monitor_args.store_checkpoins:
-        checkpoint_handler = CheckpointHandler(monitor_args, collab_optimizer_args, averager_args, dht)
+    if monitor_args.store_checkpoints:
+        checkpoint_handler = CheckpointHandler(monitor_args, optimizer_args, averager_args, dht)
 
     while True:
-        metrics_dict = dht.get(experiment_prefix + "_metrics", latest=True)
+        metrics_dict = dht.get(run_id + "_metrics", latest=True)
         if metrics_dict is not None:
             metrics_dict = metrics_dict.value
             metrics = [utils.LocalMetrics.parse_obj(metrics_dict[peer].value) for peer in metrics_dict]
@@ -219,7 +217,7 @@ if __name__ == "__main__":
                         }
                     )
 
-                if monitor_args.store_checkpoins:
+                if monitor_args.store_checkpoints:
                     if checkpoint_handler.is_time_to_save_state(current_step):
                         checkpoint_handler.save_state(current_step)
                         if checkpoint_handler.is_time_to_upload():

+ 1 - 1
examples/albert/tokenize_wikitext103.py

@@ -1,4 +1,4 @@
-#!/usr/bin/env python
+#!/usr/bin/env python3
 """ This script builds a pre-tokenized compressed representation of WikiText-103 using huggingface/datasets """
 import random
 from functools import partial

+ 2 - 2
examples/albert/utils.py

@@ -24,9 +24,9 @@ class MetricSchema(BaseModel):
     metrics: Dict[BytesWithPublicKey, LocalMetrics]
 
 
-def make_validators(experiment_prefix: str) -> Tuple[List[RecordValidatorBase], bytes]:
+def make_validators(run_id: str) -> Tuple[List[RecordValidatorBase], bytes]:
     signature_validator = RSASignatureValidator()
-    validators = [SchemaValidator(MetricSchema, prefix=experiment_prefix), signature_validator]
+    validators = [SchemaValidator(MetricSchema, prefix=run_id), signature_validator]
     return validators, signature_validator.local_public_key
 
 

+ 5 - 2
hivemind/__init__.py

@@ -1,4 +1,4 @@
-from hivemind.averaging import DecentralizedAverager, TrainingAverager
+from hivemind.averaging import DecentralizedAverager
 from hivemind.compression import *
 from hivemind.dht import DHT
 from hivemind.moe import (
@@ -16,8 +16,11 @@ from hivemind.optim import (
     DecentralizedOptimizer,
     DecentralizedOptimizerBase,
     DecentralizedSGD,
+    GradScaler,
+    Optimizer,
+    TrainingAverager,
 )
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, PeerInfo
 from hivemind.utils import *
 
-__version__ = "1.0.0dev0"
+__version__ = "1.1.0dev0"

+ 0 - 1
hivemind/averaging/__init__.py

@@ -1,2 +1 @@
 from hivemind.averaging.averager import DecentralizedAverager
-from hivemind.averaging.training import TrainingAverager

+ 194 - 86
hivemind/averaging/allreduce.py

@@ -1,15 +1,22 @@
 import asyncio
 from enum import Enum
-from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Type
+from typing import Any, AsyncIterator, Dict, Optional, Sequence, Set, Tuple, Type
 
 import torch
 
-from hivemind.averaging.partition import AllreduceException, TensorPartContainer, TensorPartReducer
+from hivemind.averaging.partition import AllreduceException, BannedException, TensorPartContainer, TensorPartReducer
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase, StubBase
 from hivemind.proto import averaging_pb2
 from hivemind.utils import get_logger
-from hivemind.utils.asyncio import achain, aenumerate, afirst, amap_in_executor, anext, as_aiter
+from hivemind.utils.asyncio import (
+    achain,
+    aiter_with_timeout,
+    amap_in_executor,
+    anext,
+    as_aiter,
+    attach_event_on_finished,
+)
 
 # flavour types
 GroupID = bytes
@@ -37,15 +44,21 @@ class AllReduceRunner(ServicerBase):
     :param prefix: namespace for servicer's RPCs (typically, equal to prefix for group keys)
     :param group_id: unique identifier of this specific all-reduce run
     :param tensors: local tensors that should be averaged with groupmates
-    :param tensors: local tensors that should be averaged with groupmates
+    :param weight: scalar weight of this peer's tensors in the average (doesn't need to sum up to 1)
     :param peer_id: your peer_id, must be included in ordered_peer_ids
     :param ordered_peer_ids: group peer_ids ordered s.t. i-th peer_id is responsible for averaging i-th part
     :param peer_fractions: for each peer, a target fraction of vector elements that this peer should average
       (the actual number of values by peer will be nearly proportional, but there are no exact guarantees)
     :param modes: AveragingMode for each peer in ordered_peer_ids (normal, client-only or auxiliary)
-    :param weights: scaling coefficients for weighted averaging (default = equal weights for all non-aux peers)
     :param gathered: additional user-defined data collected from this group
-    :param kwargs: additional paramters (e.g. part_size_bytes) will be passed to TensorPartContainer
+    :param sender_timeout: during all_reduce, any sender that fails to send tensor chunk within this many seconds from
+      previous chunk will be marked as failed and excluded from averaging. default: equal to next_chunk_timeout
+    :param reducer_timeout: during all_reduce, any reducer that fails to send results chunk within this many seconds
+      from previous chunk will be marked as failed and excluded from averaging. default: 2 x sender_timeout
+    :param kwargs: additional parameters (e.g. part_size_bytes) will be passed to TensorPartContainer
+    :note: Full-mode peers send and receive tensor parts concurrently, assuming a full-duplex TCP stream. In turn,
+      non-averaging peers receive results only after they finish sending, which helps them avoid
+      throughput issues in case of asymmetric high-latency connections (e.g. ACK compression).
     """
 
     def __init__(
@@ -56,16 +69,23 @@ class AllReduceRunner(ServicerBase):
         prefix: Optional[str],
         group_id: GroupID,
         tensors: Sequence[torch.Tensor],
+        weight: Optional[float] = None,
         ordered_peer_ids: Sequence[PeerID],
         peer_fractions: Tuple[float, ...],
-        weights: Optional[Sequence[float]] = None,
         modes: Optional[Sequence[AveragingMode]] = None,
         gathered: Optional[Dict[PeerID, Any]] = None,
+        sender_timeout: Optional[float] = None,
+        reducer_timeout: Optional[float] = None,
         **kwargs,
     ):
         self._p2p = p2p
         self.peer_id = p2p.peer_id
         assert self.peer_id in ordered_peer_ids, "peer_id is not a part of the group"
+        if reducer_timeout is not None and (sender_timeout is None or reducer_timeout <= sender_timeout):
+            raise ValueError(
+                "If reducer_timeout is enabled, sender_timeout must be shorter than reducer_timeout. "
+                "Otherwise, there is a chance that reducers will be banned while they await senders."
+            )
 
         if not issubclass(servicer_type, ServicerBase):
             raise TypeError("`servicer_type` is expected to be a ServicerBase subclass")
@@ -73,31 +93,42 @@ class AllReduceRunner(ServicerBase):
         self._prefix = prefix
 
         modes = modes or tuple(AveragingMode.CLIENT if frac == 0 else AveragingMode.NODE for frac in peer_fractions)
-        weights = weights or tuple(int(mode != AveragingMode.AUX) for mode in modes)
-        assert len(weights) == len(modes) == len(ordered_peer_ids), "lists have inconsistent length"
+        assert len(modes) == len(ordered_peer_ids), "lists have inconsistent length"
         assert any(mode != AveragingMode.CLIENT for mode in modes), "cannot run allreduce without reducers"
-        for mode, frac, weight in zip(modes, peer_fractions, weights):
+        for mode, frac in zip(modes, peer_fractions):
             assert mode != AveragingMode.CLIENT or frac == 0, "client-mode peer should have zero all-reduce fraction"
-            assert mode != AveragingMode.AUX or weight == 0, "auxiliary peer should have zero averaging weight"
 
         self.group_id, self.ordered_peer_ids = group_id, ordered_peer_ids
         self.modes, self.peer_fractions, self.gathered = modes, peer_fractions, gathered
 
+        if weight is None:
+            weight = float(modes[self.ordered_peer_ids.index(self.peer_id)] != AveragingMode.AUX)
+        self.weight = weight
+
         self._future = asyncio.Future()
 
-        self.sender_peer_ids, self.sender_weights = [], []
-        for peer_id, weight, mode in zip(self.ordered_peer_ids, weights, modes):
+        self.sender_peer_ids = []
+        for peer_id, mode in zip(self.ordered_peer_ids, modes):
             if mode != AveragingMode.AUX:
                 self.sender_peer_ids.append(peer_id)
-                self.sender_weights.append(weight)
+
+        self.sender_timeout, self.reducer_timeout = sender_timeout, reducer_timeout
+        self.all_senders_started = asyncio.Event()
+        self.banned_senders: Set[PeerID] = set()  # peers that did not send data by next_chunk_timeout
+        self.banlock = asyncio.Lock()
+
+        self.active_senders: Set[PeerID] = set()  # peers that began sending data via rpc_aggregate_part
+        if self.peer_id in self.sender_peer_ids:
+            self.active_senders.add(self.peer_id)
+        if len(self.active_senders) == len(self.sender_peer_ids):
+            self.all_senders_started.set()
 
         peer_id_index = self.ordered_peer_ids.index(self.peer_id)
-        self.tensor_part_container = TensorPartContainer(tensors, peer_fractions, **kwargs)
+        self.tensor_part_container = TensorPartContainer(tensors, peer_fractions, return_deltas=True, **kwargs)
         self.parts_for_local_averaging = self.tensor_part_container.get_raw_input_parts(peer_id_index)
         self.tensor_part_reducer = TensorPartReducer(
             tuple(part.shape for part in self.parts_for_local_averaging),
             len(self.sender_peer_ids),
-            self.sender_weights,
         )
 
     def __repr__(self):
@@ -116,9 +147,16 @@ class AllReduceRunner(ServicerBase):
     def _get_peer_stub(self, peer: PeerID) -> StubBase:
         return self._servicer_type.get_stub(self._p2p, peer, namespace=self._prefix)
 
+    def should_delay_results(self, peer_id: PeerID) -> bool:
+        return self.peer_fractions[self.ordered_peer_ids.index(peer_id)] == 0
+
     async def run(self) -> AsyncIterator[torch.Tensor]:
         """Run all-reduce, return differences between averaged and original tensors as they are computed"""
         pending_tasks = set()
+
+        if self.tensor_part_container.num_parts_by_peer[self.ordered_peer_ids.index(self.peer_id)] != 0:
+            pending_tasks.add(asyncio.create_task(self._handle_missing_senders()))
+
         try:
             if len(self.sender_peer_ids) == 0:
                 logger.debug(f"{self} - finished all-reduce early: all peers are auxiliaries ({self.modes})")
@@ -131,6 +169,7 @@ class AllReduceRunner(ServicerBase):
 
                 async for averaged_tensor_delta in self.tensor_part_container.iterate_output_tensors():
                     yield averaged_tensor_delta  # delta = averaged_tensor - original_tensor
+
                 self.finalize()
 
             else:  # auxiliary peer
@@ -143,35 +182,69 @@ class AllReduceRunner(ServicerBase):
                 task.cancel()
             raise
 
+        finally:
+            for task in pending_tasks:
+                try:
+                    await task
+                except asyncio.CancelledError:
+                    pass
+                except Exception as inner_exc:
+                    logger.debug(f"Task {task} failed with {inner_exc}", exc_info=True)
+
+    async def _handle_missing_senders(self):
+        """Detect senders that should have sent tensors for averaging, but did not send anything within timeout"""
+        try:
+            await asyncio.wait_for(self.all_senders_started.wait(), self.sender_timeout)
+        except asyncio.TimeoutError:
+            for peer_id in self.sender_peer_ids:
+                if peer_id not in self.active_senders and peer_id not in self.banned_senders:
+                    await self._ban_sender(peer_id)
+
     async def _communicate_with_peer(self, peer_id: PeerID):
         """Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors"""
         peer_index = self.ordered_peer_ids.index(peer_id)
         if peer_id == self.peer_id:
             sender_index = self.sender_peer_ids.index(peer_id)
             for part_index, tensor_part in enumerate(self.parts_for_local_averaging):
-                averaged_part = await self.tensor_part_reducer.accumulate_part(sender_index, part_index, tensor_part)
+                averaged_part = await self.tensor_part_reducer.accumulate_part(
+                    sender_index, part_index, tensor_part, weight=self.weight
+                )
                 self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part - tensor_part)
 
         else:
-            code = None
-            stream = self._get_peer_stub(peer_id).rpc_aggregate_part(self._generate_input_for_peer(peer_index))
-            async for part_index, (averaged_part_delta, msg) in aenumerate(
-                amap_in_executor(
-                    lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg),
-                    stream,
+            try:
+                done_sending = asyncio.Event()
+                inputs_aiter = attach_event_on_finished(self._generate_input_for_peer(peer_index), done_sending)
+                stream = await self._get_peer_stub(peer_id).rpc_aggregate_part(inputs_aiter)
+
+                if self.should_delay_results(self.peer_id):
+                    await done_sending.wait()
+
+                part_index = 0
+
+                def _try_deserialize(msg):
+                    if msg.code != averaging_pb2.AVERAGED_PART:
+                        raise AllreduceException(f"{peer_id} sent {averaging_pb2.MessageCode.Name(msg.code)}")
+                    return deserialize_torch_tensor(msg.tensor_part), msg
+
+                async for delta, msg in amap_in_executor(
+                    _try_deserialize,
+                    aiter_with_timeout(stream, self.reducer_timeout),
                     max_prefetch=self.tensor_part_container.prefetch,
-                )
-            ):
-                if code is None:
-                    code = msg.code
-                self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part_delta)
-
-            if code != averaging_pb2.AVERAGED_PART:
-                raise AllreduceException(
-                    f"peer {peer_id} returned {averaging_pb2.MessageCode.Name(code)} "
-                    f"instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)}"
-                    f", allreduce failed"
-                )
+                ):
+                    self.tensor_part_container.register_processed_part(peer_index, part_index, delta)
+                    part_index += 1
+
+                if part_index != self.tensor_part_container.num_parts_by_peer[peer_index]:
+                    raise AllreduceException(
+                        f"peer {peer_id} sent {part_index} parts, but we expected "
+                        f"{self.tensor_part_container.num_parts_by_peer[peer_index]}"
+                    )
+            except BaseException as e:
+                if isinstance(e, Exception):
+                    logger.debug(f"Caught {repr(e)} when communicating to {peer_id}", exc_info=True)
+                self.tensor_part_container.register_failed_reducer(peer_index)
+                raise
 
     async def _generate_input_for_peer(self, peer_index: int) -> AsyncIterator[averaging_pb2.AveragingData]:
         parts_aiter = self.tensor_part_container.iterate_input_parts_for(peer_index)
@@ -180,79 +253,116 @@ class AllReduceRunner(ServicerBase):
             code=averaging_pb2.PART_FOR_AVERAGING,
             group_id=self.group_id,
             tensor_part=first_part,
+            weight=self.weight,
         )
         async for part in parts_aiter:
-            yield averaging_pb2.AveragingData(tensor_part=part)
+            yield averaging_pb2.AveragingData(tensor_part=part, weight=self.weight)
 
     async def rpc_aggregate_part(
         self, stream: AsyncIterator[averaging_pb2.AveragingData], context: P2PContext
     ) -> AsyncIterator[averaging_pb2.AveragingData]:
         """a peer sends us a part of his tensor; we should average it with other peers and return the difference"""
-        request: averaging_pb2.AveragingData = await anext(stream)
-        reason_to_reject = self._check_reasons_to_reject(request)
-        if reason_to_reject:
-            yield reason_to_reject
-            return
+        sender_index = self.sender_peer_ids.index(context.remote_id)
+        self.active_senders.add(context.remote_id)
+        if len(self.active_senders) == len(self.sender_peer_ids):
+            self.all_senders_started.set()
 
-        elif request.code == averaging_pb2.PART_FOR_AVERAGING:
-            try:
-                sender_index = self.sender_peer_ids.index(context.remote_id)
-                async for msg in self._accumulate_parts_streaming(achain(as_aiter(request), stream), sender_index):
-                    yield msg
+        try:
+            request: averaging_pb2.AveragingData = await asyncio.wait_for(anext(stream), self.sender_timeout)
+            reason_to_reject = self._check_reasons_to_reject(request, context)
+            if reason_to_reject:
+                yield reason_to_reject
+                return
+
+            elif request.code == averaging_pb2.PART_FOR_AVERAGING:
+                stream = aiter_with_timeout(achain(as_aiter(request), stream), self.sender_timeout)
+                if not self.should_delay_results(context.remote_id):
+                    async for msg in self._accumulate_parts_streaming(stream, sender_index):
+                        yield msg
+
+                else:
+                    done_receiving = asyncio.Event()
+                    delayed_results = asyncio.Queue()
+
+                    async def _accumulate_parts():
+                        try:
+                            async for msg in self._accumulate_parts_streaming(
+                                attach_event_on_finished(stream, done_receiving), sender_index
+                            ):
+                                delayed_results.put_nowait(msg)
+                        finally:
+                            delayed_results.put_nowait(None)
+
+                    accumulate_task = asyncio.create_task(_accumulate_parts())
+
+                    await done_receiving.wait()
+
+                    while True:
+                        next_result = await delayed_results.get()
+                        if next_result is None:
+                            break
+                        yield next_result
+                    await accumulate_task
 
-            except Exception as e:
-                self.finalize(exception=e)
+            else:
                 yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
-        else:
-            error_code = averaging_pb2.MessageCode.Name(request.code)
-            logger.debug(f"{self} - peer {context.remote_id} sent {error_code}, allreduce cannot continue")
-            self.finalize(exception=AllreduceException(f"peer {context.remote_id} sent {error_code}."))
-            yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
+                raise AllreduceException(f"{context.remote_id} sent {averaging_pb2.MessageCode.Name(request.code)}")
+
+        except BaseException as e:
+            await self._ban_sender(context.remote_id)
+            if isinstance(e, Exception):
+                logger.debug(f"Caught {repr(e)} when communicating with {context.remote_id}", exc_info=True)
+                yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
+            else:
+                raise  # CancelledError, StopIteration and similar
+
+    async def _ban_sender(self, peer_id: PeerID):
+        async with self.banlock:
+            if peer_id not in self.banned_senders:
+                self.banned_senders.add(peer_id)
+                self.tensor_part_reducer.on_sender_failed(self.sender_peer_ids.index(peer_id))
 
-    def _check_reasons_to_reject(self, request: averaging_pb2.AveragingData) -> Optional[averaging_pb2.AveragingData]:
+    def _check_reasons_to_reject(
+        self, request: averaging_pb2.AveragingData, context: P2PContext
+    ) -> Optional[averaging_pb2.AveragingData]:
         if request.group_id != self.group_id:
             return averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
         elif self._future.cancelled():
             return averaging_pb2.AveragingData(code=averaging_pb2.CANCELLED)
         elif self._future.done():
             return averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
+        elif context.remote_id not in self.sender_peer_ids:
+            return averaging_pb2.AveragingData(code=averaging_pb2.PROTOCOL_VIOLATION)
 
     async def _accumulate_parts_streaming(self, stream: AsyncIterator[averaging_pb2.AveragingData], sender_index: int):
-        loop = asyncio.get_event_loop()
-        async for part_index, (tensor_part, part_compression) in aenumerate(
-            amap_in_executor(
-                lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg.tensor_part.compression),
+        part_index = 0
+        try:
+            loop = asyncio.get_event_loop()
+            async for tensor_part, weight, part_compression in amap_in_executor(
+                lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg.weight, msg.tensor_part.compression),
                 stream,
                 max_prefetch=self.tensor_part_container.prefetch,
-            )
-        ):
-            averaged_part = await self.tensor_part_reducer.accumulate_part(sender_index, part_index, tensor_part)
-
-            serialized_delta = await loop.run_in_executor(
-                None, lambda: serialize_torch_tensor(averaged_part - tensor_part, part_compression)
-            )
-            yield averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=serialized_delta)
-
-    async def _send_error_to_peer(self, peer_id: PeerID, code: averaging_pb2.MessageCode):
-        error = averaging_pb2.AveragingData(group_id=self.group_id, code=code)
-        # Coroutines are lazy, so we take the first item to start the couroutine's execution
-        await afirst(self._get_peer_stub(peer_id).rpc_aggregate_part(as_aiter(error)))
+            ):
+                try:
+                    averaged_part = await self.tensor_part_reducer.accumulate_part(
+                        sender_index, part_index, tensor_part, weight=weight
+                    )
+                    part_index += 1
+                except BannedException:
+                    logger.debug(f"Sender {sender_index} is already banned")
+                    break  # sender was banned, we no longer need to aggregate it
+
+                serialized_delta = await loop.run_in_executor(
+                    None, lambda: serialize_torch_tensor(averaged_part - tensor_part, part_compression)
+                )
+                yield averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=serialized_delta)
+        finally:
+            if part_index != self.tensor_part_reducer.num_parts:
+                await self._ban_sender(self.sender_peer_ids[sender_index])
 
     def finalize(self, *, cancel: bool = False, exception: Optional[BaseException] = None):
         """finish or terminate AllReduceRunner, propagate any errors / cancellations to peers."""
         assert not cancel or not exception, "finalize accepts either exception or cancel, but not both"
-        pending_tasks = set()
-        if cancel or exception:
-            # propagate error to peers
-            if cancel or isinstance(exception, asyncio.CancelledError):
-                code = averaging_pb2.CANCELLED
-            else:
-                code = averaging_pb2.INTERNAL_ERROR
-            logger.debug(f"{self} - notifying peers about {averaging_pb2.MessageCode.Name(code)}")
-            for peer_id, mode in zip(self.ordered_peer_ids, self.modes):
-                if peer_id != self.peer_id and mode != AveragingMode.CLIENT:
-                    pending_tasks.add(asyncio.create_task(self._send_error_to_peer(peer_id, code)))
-
         if not self._future.done():
             if cancel:
                 logger.debug(f"{self} - cancelled")
@@ -265,7 +375,5 @@ class AllReduceRunner(ServicerBase):
                 self._future.set_result(None)
             self.tensor_part_container.finalize()
             self.tensor_part_reducer.finalize()
-            return pending_tasks
         else:
-            logger.debug(f"{self} - could not finish: allreduce is already finished: {self._future}")
-            return pending_tasks
+            logger.debug(f"{self} - attempted to finalize allreduce that is already finished: {self._future}")
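
The failure handling above relies on `aiter_with_timeout` from `hivemind.utils.asyncio`: if a peer stalls between
chunks, iteration times out and the runner bans that sender or reducer. A self-contained sketch of the pattern
(toy async generator and made-up delays, assuming the timeout surfaces as `asyncio.TimeoutError`):

```python
import asyncio

from hivemind.utils.asyncio import aiter_with_timeout


async def main():
    async def chunks():
        # two prompt chunks, then a stall that exceeds the timeout
        for delay, value in ((0.1, "part_0"), (0.1, "part_1"), (2.0, "part_2")):
            await asyncio.sleep(delay)
            yield value

    received = []
    try:
        # 0.5 s per-chunk timeout, analogous to sender_timeout / reducer_timeout
        async for part in aiter_with_timeout(chunks(), 0.5):
            received.append(part)
    except asyncio.TimeoutError:
        print(f"peer stalled after sending {received}; it would be banned and excluded from averaging")


asyncio.run(main())
```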

+ 192 - 103
hivemind/averaging/averager.py

@@ -7,6 +7,7 @@ import contextlib
 import ctypes
 import multiprocessing as mp
 import os
+import random
 import threading
 import weakref
 from dataclasses import asdict
@@ -16,6 +17,7 @@ import numpy as np
 import torch
 
 from hivemind.averaging.allreduce import AllreduceException, AllReduceRunner, AveragingMode, GroupID
+from hivemind.averaging.control import AveragingStage, StepControl
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.load_balancing import load_balance_peers
 from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
@@ -28,10 +30,20 @@ from hivemind.compression import (
     serialize_torch_tensor,
 )
 from hivemind.dht import DHT, DHTID
-from hivemind.p2p import P2PContext, P2PHandlerError, PeerID, ServicerBase
+from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
+from hivemind.p2p.p2p_daemon_bindings.utils import ControlFailure, DispatchFailure
 from hivemind.proto import averaging_pb2
 from hivemind.utils import MPFuture, TensorDescriptor, get_logger
-from hivemind.utils.asyncio import achain, aiter_with_timeout, anext, as_aiter, switch_to_uvloop
+from hivemind.utils.asyncio import (
+    achain,
+    afirst,
+    aiter_with_timeout,
+    anext,
+    as_aiter,
+    azip,
+    enter_asynchronously,
+    switch_to_uvloop,
+)
 from hivemind.utils.grpc import combine_from_streaming, split_for_streaming
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration, get_dht_time
@@ -54,16 +66,14 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     :param prefix: a shared prefix for all group keys
     :param target_group_size: attempts to form groups with up to this many peers (recommended: a power of 2, e.g. 16)
     :param initial_group_bits: a string of bits ('0' and '1') that define the initial group key (bucket index)
-    :param averaging_expiration: attempt to find a group for this many seconds, otherwise try again
-      note - this expiration time only applies to looking for group, passing tensors in allreduce may take more time
+    :param min_matchmaking_time: when looking for group, wait for requests for at least this many seconds
     :param compression: optionally compress tensors with this compression algorithm before running all-reduce
     :param state_compression: a separate compression strategy for load_state_from_peers (default = no compression)
     :param tensor_infos: CompressionInfo for each respective tensor; this determines how the tensor will be compressed
-    :param allreduce_timeout: spend at most this many seconds for allreduce (after group is formed)
     :param averaging_alpha: optional "learning rate" for averaging. If specified, local parameters will be shifted
       towards the (estimated) average by this coefficient. By default, local parameters are set equal to average.
     :param request_timeout: when looking for group, wait for a response from leader for at most this many seconds.
-    :note: request_timeout must be smaller than averaging_expiration to avoid potential deadlocks.
+    :note: request_timeout must be smaller than min_matchmaking_time to avoid potential deadlocks.
     :param part_size_bytes: tensors for AllReduce are processed in parts of up to this size (after compression)
     :param bandwidth: if specified, this value represents the network bandwidth available to averager.
           By default, the averager is assumed to have the average bandwidth of his group.
@@ -75,6 +85,14 @@ class DecentralizedAverager(mp.Process, ServicerBase):
           local tensors for averaging
     :param allow_state_sharing: if set to True, other peers can download this peer's state. Can be overwritten
       with averager.allow_state_sharing = True / False
+    :param declare_state_period: re-declare averager as a donor for load_state_from_peers every this many seconds
+    :param allreduce_timeout: spend at most this many seconds for allreduce (after group is formed)
+    :param next_chunk_timeout: during all-reduce and load_state_from_peers, if peer does not send next data chunk in
+      this number of seconds, consider it failed and proceed with remaining peers. default: no timeout
+    :param sender_timeout: during all_reduce, any sender that fails to send tensor chunk within this many seconds from
+      previous chunk will be marked as failed and excluded from averaging. default: equal to next_chunk_timeout
+    :param reducer_timeout: during all_reduce, any reducer that fails to send results chunk within this many seconds
+      from previous chunk will be marked as failed and excluded from averaging. default: 2 * sender_timeout
     :param shutdown_timeout: when calling .shutdown, wait for up to this many seconds before terminating
 
     Example:
@@ -92,6 +110,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
     _matchmaking: Matchmaking
     _pending_group_assembled: asyncio.Event
+    _state_updated: asyncio.Event
+    _p2p: P2P
     serializer = MSGPackSerializer
 
     def __init__(
@@ -101,14 +121,18 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         *,
         start: bool,
         prefix: str,
-        target_group_size: int,
+        target_group_size: Optional[int] = None,
         min_group_size: int = 2,
         initial_group_bits: str = "",
-        averaging_expiration: float = 15,
-        request_timeout: float = 3,
+        averaging_expiration: Optional[float] = None,
+        min_matchmaking_time: float = 5.0,
+        request_timeout: float = 3.0,
         averaging_alpha: float = 1.0,
         part_size_bytes: int = DEFAULT_PART_SIZE_BYTES,
         allreduce_timeout: Optional[float] = None,
+        next_chunk_timeout: Optional[float] = None,
+        sender_timeout: Optional[float] = None,
+        reducer_timeout: Optional[float] = None,
         compression: CompressionBase = NoCompression(),
         state_compression: CompressionBase = NoCompression(),
         tensor_infos: Optional[Sequence[CompressionInfo]] = None,
@@ -116,6 +140,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         min_vector_size: int = 0,
         auxiliary: bool = False,
         allow_state_sharing: Optional[bool] = None,
+        declare_state_period: float = 30,
         client_mode: Optional[bool] = None,
         daemon: bool = True,
         shutdown_timeout: float = 5,
@@ -124,17 +149,25 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         assert bandwidth is None or (
             bandwidth >= 0 and np.isfinite(np.float32(bandwidth))
         ), "bandwidth must be a non-negative float32"
-        if not is_power_of_two(target_group_size):
-            logger.warning("It is recommended to set target_group_size to a power of 2.")
         assert all(bit in "01" for bit in initial_group_bits)
         assert not client_mode or not auxiliary, "auxiliary peers must accept incoming connections"
 
+        if averaging_expiration is not None:
+            logger.warning("averaging_expiration is deprecated and will be removed soon, use min_matchmaking_time")
+            assert min_matchmaking_time == 5.0, "Can't set both averaging_expiration and min_matchmaking_time"
+            min_matchmaking_time = averaging_expiration
+
         super().__init__()
         self.dht = dht
         self.prefix = prefix
 
         if client_mode is None:
             client_mode = dht.client_mode
+        if sender_timeout is None:
+            sender_timeout = next_chunk_timeout
+        if reducer_timeout is None:
+            reducer_timeout = 2 * sender_timeout if sender_timeout is not None else None
+
         self.client_mode = client_mode
 
         self._parent_pid = os.getpid()
@@ -148,13 +181,13 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
         self._averaged_tensors = tuple(averaged_tensors)
         self.lock_averaged_tensors = mp.Lock()
-        self.last_updated: DHTExpiration = -float("inf")
         for tensor in self._averaged_tensors:
             assert tensor.grad_fn is None, "averaged_tensors must be either parameters or leaf tensors"
             tensor.share_memory_()
         self.total_size = sum(map(torch.Tensor.numel, self._averaged_tensors))
         self.schema_hash = compute_schema_hash(self._averaged_tensors)
         self.shutdown_timeout = shutdown_timeout
+        self.next_chunk_timeout = next_chunk_timeout
         self.bandwidth = bandwidth
 
         self.matchmaking_kwargs = dict(
@@ -163,13 +196,15 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             initial_group_bits=initial_group_bits,
             target_group_size=target_group_size,
             min_group_size=min_group_size,
-            averaging_expiration=averaging_expiration,
             request_timeout=request_timeout,
+            min_matchmaking_time=min_matchmaking_time,
         )
         self.allreduce_kwargs = dict(
             compression=compression,
             part_size_bytes=part_size_bytes,
             min_vector_size=min_vector_size,
+            sender_timeout=sender_timeout,
+            reducer_timeout=reducer_timeout,
         )
         self._averaging_alpha, self._allreduce_timeout = averaging_alpha, allreduce_timeout
         self._running_groups: Dict[GroupID, AllReduceRunner] = {}  # one or more assembled groups that run all-reduce
@@ -177,9 +212,12 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)  # a control pipe used to communicate with daemon
 
         self._allow_state_sharing = mp.Value(ctypes.c_bool, 0)
+        self._state_sharing_priority = mp.Value(ctypes.c_double, 0)
+
         if allow_state_sharing is None:
             allow_state_sharing = not client_mode and not auxiliary
         self.allow_state_sharing = allow_state_sharing
+        self.declare_state_period = declare_state_period
         self.state_compression = state_compression
         self.tensor_infos = tensor_infos
 
@@ -202,9 +240,29 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     @allow_state_sharing.setter
     def allow_state_sharing(self, value: bool):
         if value and self.client_mode:
-            raise ValueError("Cannot allow state sharing: averager in client mode cannot share its state.")
+            raise ValueError("Cannot allow state sharing: averager in client mode cannot share its state")
         else:
-            self._allow_state_sharing.value = value
+            old_value, self._allow_state_sharing.value = self._allow_state_sharing.value, value
+            if value != old_value:
+                self._outer_pipe.send(("_trigger_declare_load_state", [], {}))
+
+    @property
+    def state_sharing_priority(self) -> float:
+        """Others will preferentially downloading state from peers with highest priority."""
+        return float(self._state_sharing_priority.value)
+
+    @state_sharing_priority.setter
+    def state_sharing_priority(self, value: float):
+        if value and self.client_mode:
+            raise ValueError("State sharing priority is unused: averager in client mode cannot share its state")
+        else:
+            old_value, self._state_sharing_priority.value = self._state_sharing_priority.value, value
+            if self.allow_state_sharing and value != old_value:
+                self._outer_pipe.send(("_trigger_declare_load_state", [], {}))
+
+    async def _trigger_declare_load_state(self):
+        # note: previously tried to set mp.Event instead of this. Awaiting it in executor caused degradation in py39
+        self._state_updated.set()
 
     @property
     def peer_id(self) -> PeerID:
@@ -238,7 +296,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                 if not self.client_mode:
                     await self.add_p2p_handlers(self._p2p, namespace=self.prefix)
                 else:
-                    logger.debug(f"The averager is running in client mode.")
+                    logger.debug("The averager is running in client mode")
 
                 self._matchmaking = Matchmaking(
                     self._p2p,
@@ -250,6 +308,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                 if not self.client_mode:
                     asyncio.create_task(self._declare_for_download_periodically())
 
+                self._state_updated = asyncio.Event()
                 self._pending_group_assembled = asyncio.Event()
                 self._pending_group_assembled.set()
             except Exception as e:
@@ -294,20 +353,20 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     def shutdown(self) -> None:
         """Shut down the averager process"""
         if self.is_alive():
-            self._outer_pipe.send(("_shutdown", [None], {}))  # shut down the daemon process
+            self._outer_pipe.send(("_shutdown", [self.shutdown_timeout], {}))  # shut down the daemon process
             self._inner_pipe.send(("_SHUTDOWN", None))  # shut down background thread in master
             self.join(self.shutdown_timeout)
             if self.is_alive():
-                logger.warning("Averager did not shut down within the grace period; terminating it the hard way.")
+                logger.warning("Averager did not shut down within the grace period; terminating it the hard way")
                 self.terminate()
         else:
             logger.exception("Averager shutdown has no effect: the process is already not alive")
 
-    async def _shutdown(self, timeout: Optional[float] = None) -> None:
+    async def _shutdown(self, timeout: Optional[float]) -> None:
         remaining_tasks = set()
         for group in self._running_groups.values():
             remaining_tasks.update(group.finalize(cancel=True))
-        await asyncio.gather(*remaining_tasks)
+        await asyncio.wait_for(asyncio.gather(*remaining_tasks), timeout)
 
     def __del__(self):
         if self._parent_pid == os.getpid() and self.is_alive():
@@ -316,67 +375,96 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     def step(
         self,
         gather: Optional[GatheredData] = None,
+        scheduled_time: Optional[DHTExpiration] = None,
         weight: Optional[float] = None,
         timeout: Optional[float] = None,
         allow_retries: bool = True,
+        require_trigger: bool = False,
         wait: bool = True,
-    ) -> Union[Optional[Dict[PeerID, GatheredData]], MPFuture]:
+    ) -> Union[Optional[Dict[PeerID, GatheredData]], StepControl]:
         """
         Set up the averager to look for a group and run one round of averaging, return True on success, False on failure
 
         :param gather: optionally send this information to all peers in the next group and gather it from every groupmate
           (this operation is known as all-gather). The gathered data will be available as the output of this function.
+        :param scheduled_time: when matchmaking, assume that all-reduce will begin at this moment.
+          By default, schedule all-reduce at the current time plus min_matchmaking_time seconds
         :param weight: averaging weight for this peer, int or float, must be strictly positive
         :param allow_retries: if averager fails to run one round of allreduce, this option will allow it to try again
           within the specified timeout
-        :param timeout: if averager was unable to *find* a group in this many seconds, consider allreduce failedK
-        :param wait: if True (default), return when finished. Otherwise return MPFuture and run in background.
+        :param require_trigger: if True, wait for the user to call .allow_allreduce() before running all-reduce
+        :param timeout: if averager was unable to *find* a group in this many seconds, consider allreduce failed
+        :param wait: if True (default), return when finished. Otherwise return StepControl and run in background.
         :returns: on success, update averaged_tensors and return group info; on failure, return None
         """
         if self.mode == AveragingMode.AUX and weight is not None:
-            logger.warning("Averager is running in auxiliary mode, weight is unused.")
+            logger.warning("Averager is running in auxiliary mode, weight is unused")
+        if scheduled_time is None:
+            scheduled_time = get_dht_time() + self.matchmaking_kwargs["min_matchmaking_time"]
         if weight is None:
             weight = float(self.mode != AveragingMode.AUX)
+        deadline = get_dht_time() + timeout if timeout is not None else float("inf")
         assert isinstance(weight, (int, float)) and weight >= 0, f"Expected a positive int/float, got {type(weight)}"
-
-        future = MPFuture()
-        gather_binary = self.serializer.dumps(
-            gather
-        )  # serialize here to avoid loading modules in the averager process
-        self._outer_pipe.send(
-            (
-                "_step",
-                [],
-                dict(
-                    future=future,
-                    gather_binary=gather_binary,
-                    weight=weight,
-                    allow_retries=allow_retries,
-                    timeout=timeout,
-                ),
-            )
+        assert not (wait and require_trigger), "Non-asynchronous step cannot wait for trigger (use wait=False)"
+        assert scheduled_time < deadline, "Scheduled start time does not fit within timeout"
+
+        user_data_for_gather = self.serializer.dumps(gather)  # serialize here to avoid imports in the averager process
+        data_for_gather = self.serializer.dumps([self.bandwidth, self.mode.value, user_data_for_gather])
+        step = StepControl(
+            scheduled_time=scheduled_time,
+            deadline=deadline,
+            allow_retries=allow_retries,
+            weight=weight,
+            data_for_gather=data_for_gather,
         )
-        return future.result() if wait else future
 
-    async def _step(
-        self, *, future: MPFuture, gather_binary: bytes, weight: float, allow_retries: bool, timeout: Optional[float]
-    ):
-        start_time = get_dht_time()
+        future_for_init = MPFuture()
+        self._outer_pipe.send(("_step", [], dict(step=step, future_for_init=future_for_init)))
+        step.attach(*future_for_init.result())
 
+        if not require_trigger:
+            step.allow_allreduce()
+        return step.result() if wait else step
+
+    async def _step(self, *, step: StepControl, future_for_init: MPFuture):
         try:
-            while not future.done():
+            trigger, cancel = MPFuture(), MPFuture()
+            step.attach(trigger, cancel)
+            future_for_init.set_result((trigger, cancel))
+
+            async def find_peers_or_notify_cancel():
+                group_info = await self._matchmaking.look_for_group(step)
+                if not step.triggered:
+                    step.stage = AveragingStage.AWAITING_TRIGGER
+                    await step.wait_for_trigger()
+                return group_info
+
+            while not step.done():
                 try:
                     self._pending_group_assembled.clear()
-                    data_for_gather = self.serializer.dumps([weight, self.bandwidth, self.mode.value, gather_binary])
-                    group_info = await self._matchmaking.look_for_group(
-                        timeout=timeout, data_for_gather=data_for_gather
-                    )
+                    step.stage = AveragingStage.LOOKING_FOR_GROUP
+                    matchmaking_task = asyncio.create_task(find_peers_or_notify_cancel())
+                    check_cancel_task = asyncio.create_task(step.wait_for_cancel())
+
+                    await asyncio.wait({matchmaking_task, check_cancel_task}, return_when=asyncio.FIRST_COMPLETED)
+                    if step.cancelled():
+                        matchmaking_task.cancel()
+                        raise asyncio.CancelledError()
+                    else:
+                        check_cancel_task.cancel()
+
+                    group_info = await matchmaking_task
+
                     if group_info is None:
-                        raise AllreduceException("Averaging step failed: could not find a group.")
+                        raise AllreduceException("Averaging step failed: could not find a group")
 
-                    future.set_result(
+                    step.stage = AveragingStage.RUNNING_ALLREDUCE
+
+                    step.set_result(
                         await asyncio.wait_for(
-                            self._run_allreduce(group_info, tensor_infos=self.tensor_infos, **self.allreduce_kwargs),
+                            self._run_allreduce(
+                                group_info, tensor_infos=self.tensor_infos, weight=step.weight, **self.allreduce_kwargs
+                            ),
                             timeout=self._allreduce_timeout,
                         )
                     )
@@ -390,21 +478,25 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     asyncio.CancelledError,
                     asyncio.InvalidStateError,
                     P2PHandlerError,
+                    DispatchFailure,
+                    ControlFailure,
                 ) as e:
-                    time_elapsed = get_dht_time() - start_time
-                    if not allow_retries or (timeout is not None and timeout < time_elapsed):
-                        logger.exception(f"Averager caught {repr(e)}")
-                        future.set_exception(e)
+                    if step.done() or not step.allow_retries or get_dht_time() >= step.deadline:
+                        if not step.cancelled():
+                            logger.exception(e)
+                        if not step.done():
+                            step.set_exception(e)
                     else:
-                        logger.warning(f"Averager caught {repr(e)}, retrying")
+                        logger.warning(f"{self.__class__.__name__} caught {repr(e)}, retrying")
 
         except BaseException as e:
-            if not future.done():
-                future.set_exception(e)
+            if not step.done():
+                step.set_exception(e)
             raise
         finally:
-            if not future.done():
-                future.set_exception(
+            step.stage = AveragingStage.FINISHED
+            if not step.done():
+                step.set_exception(
                     RuntimeError(
                         "Internal sanity check failed: averager.step left future pending."
                         " Please report this to hivemind issues."
@@ -414,8 +506,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
         """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
         try:
-            weights, bandwidths, mode_ids, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
-            user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered)))
+            bandwidths, mode_ids, user_gathered_bytes = zip(*map(self.serializer.loads, group_info.gathered))
+            user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered_bytes)))
             modes = tuple(map(AveragingMode, mode_ids))
 
             # compute optimal part sizes from peer bandwidths; TODO: replace with proper load balancing
@@ -426,7 +518,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                 None, load_balance_peers, self.total_size, download_bandwidths, min_vector_size
             )
 
-            async with self.get_tensors_async() as local_tensors:
+            async with enter_asynchronously(self.get_tensors()) as local_tensors:
                 allreduce = AllReduceRunner(
                     p2p=self._p2p,
                     servicer_type=type(self),
@@ -435,26 +527,27 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     tensors=local_tensors,
                     ordered_peer_ids=group_info.peer_ids,
                     peer_fractions=peer_fractions,
-                    weights=weights,
                     gathered=user_gathered,
                     modes=modes,
                     **kwargs,
                 )
 
                 with self.register_allreduce_group(group_info.group_id, allreduce):
-
-                    # actually run all-reduce
-                    averaging_outputs = [output async for output in allreduce]
-
                     if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
-                        assert len(local_tensors) == len(self._averaged_tensors)
-                        for tensor, update in zip(local_tensors, averaging_outputs):
+                        iter_results = allreduce.run()
+                        async for tensor, update in azip(as_aiter(*local_tensors), iter_results):
+                            # all-reduce is performed asynchronously while iterating
                             tensor.add_(update, alpha=self._averaging_alpha)
-                        self.last_updated = get_dht_time()
+                        self._state_updated.set()
+
+                    else:
+                        async for _ in allreduce:  # trigger all-reduce by iterating
+                            raise ValueError("aux peers should not receive averaged tensors")
 
                 return allreduce.gathered
         except BaseException as e:
-            logger.exception(e)
+            if isinstance(e, Exception):
+                logger.exception(e)
             raise MatchmakingException(f"Unable to run All-Reduce: {e}")
 
     @contextlib.contextmanager
@@ -477,16 +570,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         """
         with self.lock_averaged_tensors:
             yield self._averaged_tensors
-        self.last_updated = get_dht_time()
-
-    @contextlib.asynccontextmanager
-    async def get_tensors_async(self) -> Sequence[torch.Tensor]:
-        """Like get_tensors, but uses an asynchronous contextmanager"""
-        try:
-            await asyncio.get_event_loop().run_in_executor(None, self.lock_averaged_tensors.acquire)
-            yield self._averaged_tensors
-        finally:
-            self.lock_averaged_tensors.release()
 
     async def rpc_join_group(
         self, request: averaging_pb2.JoinRequest, context: P2PContext
@@ -515,21 +598,31 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
     async def _declare_for_download_periodically(self):
         download_key = f"{self._matchmaking.group_key_manager.prefix}.all_averagers"
+        sharing_was_allowed = self.allow_state_sharing
         while True:
-            if self.allow_state_sharing:
+            expiration_time = get_dht_time() + self.declare_state_period
+            if self.allow_state_sharing or sharing_was_allowed:
+                # notify either if sharing is allowed or if it was just switched off (to overwrite previous message)
                 asyncio.create_task(
                     asyncio.wait_for(
                         self.dht.store(
                             download_key,
                             subkey=self.peer_id.to_bytes(),
-                            value=self.last_updated,
-                            expiration_time=get_dht_time() + self._matchmaking.averaging_expiration,
+                            value=self.state_sharing_priority if self.allow_state_sharing else None,
+                            expiration_time=expiration_time,
                             return_future=True,
                         ),
-                        timeout=self._matchmaking.averaging_expiration,
+                        timeout=expiration_time - get_dht_time(),
                     )
                 )
-            await asyncio.sleep(self._matchmaking.averaging_expiration)
+                sharing_was_allowed = self.allow_state_sharing
+
+            # report again either in declare_state_period seconds or when the field is changed by the user
+            self._state_updated.clear()
+            try:
+                await asyncio.wait_for(self._state_updated.wait(), timeout=max(0.0, expiration_time - get_dht_time()))
+            except asyncio.TimeoutError:
+                pass
 
     async def rpc_download_state(
         self, _request: averaging_pb2.DownloadRequest, _context: P2PContext
@@ -584,21 +677,23 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         The exact contents of both metadata and tensors are determined by get_current_state method
         """
         future = MPFuture()
-        self._outer_pipe.send(("_load_state_from_peers", [], dict(future=future)))
+        self._outer_pipe.send(("_load_state_from_peers", [], dict(timeout=timeout, future=future)))
         return future.result(timeout=timeout) if wait else future
 
-    async def _load_state_from_peers(self, future: MPFuture):
+    async def _load_state_from_peers(self, future: MPFuture, timeout: Optional[float] = None):
+        if timeout is not None:
+            timeout = self.next_chunk_timeout if self.next_chunk_timeout is not None else self.request_timeout
         try:
             key_manager = self._matchmaking.group_key_manager
             peer_priority, _ = self.dht.get(f"{key_manager.prefix}.all_averagers", latest=True) or ({}, None)
             peer_priority = {
-                PeerID(peer_id): float(info.value)
+                PeerID(peer_id): (float(info.value), random.random())  # using randomness as a tie breaker
                 for peer_id, info in peer_priority.items()
                 if isinstance(info, ValueWithExpiration) and isinstance(info.value, (float, int))
             }
 
             if not isinstance(peer_priority, dict) or len(peer_priority) == 0:
-                logger.info(f"Averager could not load state from peers: peer dict empty or corrupted {peer_priority}.")
+                logger.info(f"Averager could not load state from peers: peer dict empty or corrupted {peer_priority}")
                 future.set_result(None)
                 return
 
@@ -608,10 +703,10 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     logger.info(f"Downloading parameters from peer {peer}")
                     try:
                         stub = self.get_stub(self._p2p, peer, namespace=self.prefix)
-                        stream = stub.rpc_download_state(averaging_pb2.DownloadRequest())
+                        stream = await stub.rpc_download_state(averaging_pb2.DownloadRequest())
                         current_tensor_parts, tensors = [], []
 
-                        async for message in aiter_with_timeout(stream, timeout=self.request_timeout):
+                        async for message in aiter_with_timeout(stream, timeout=timeout):
                             if message.metadata:
                                 metadata = self.serializer.loads(message.metadata)
                             if message.tensor_part.dtype and current_tensor_parts:
@@ -623,12 +718,11 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                             tensors.append(deserialize_torch_tensor(combine_from_streaming(current_tensor_parts)))
 
                         if not metadata:
-                            logger.debug(f"Peer {peer} did not send its state.")
+                            logger.debug(f"Peer {peer} did not send its state")
                             continue
 
                         logger.info(f"Finished downloading state from {peer}")
                         future.set_result((metadata, tensors))
-                        self.last_updated = get_dht_time()
                         return
                     except Exception as e:
                         logger.exception(f"Failed to download state from {peer} - {repr(e)}")
@@ -668,11 +762,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                 future.set_exception(e)
 
 
-def is_power_of_two(n):
-    """Check whether n is a power of 2"""
-    return (n != 0) and (n & (n - 1) == 0)
-
-
 def _background_thread_fetch_current_state(
     serializer: SerializerBase, pipe: mp.connection.Connection, get_current_state_ref: weakref.WeakMethod
 ):

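To make the reworked step() flow above concrete, here is a minimal usage sketch (not part of the diff): the single-tensor setup and the prefix "demo_run" are illustrative assumptions, and an actual round needs at least min_group_size peers sharing the same prefix.

import torch
from hivemind import DHT
from hivemind.averaging import DecentralizedAverager
from hivemind.utils import get_dht_time

# Illustrative setup: one peer with a single 16-element tensor; a real run needs
# several peers that use the same prefix so matchmaking can assemble a group.
dht = DHT(start=True)
averager = DecentralizedAverager([torch.zeros(16)], dht=dht, prefix="demo_run", start=True)

control = averager.step(
    scheduled_time=get_dht_time() + 10,  # aim to begin all-reduce ~10 seconds from now
    require_trigger=True,                # hold all-reduce until allow_allreduce() is called
    wait=False,                          # return a StepControl, matchmake in the background
)
# ... do useful local work (e.g. accumulate gradients) while matchmaking runs ...
control.weight = 1.0       # weight can still be adjusted before all-reduce begins
control.allow_allreduce()  # release the trigger
result = control.result()  # block until this averaging round succeeds or raises

Returning a StepControl instead of a bare future lets the caller pre-schedule matchmaking and only commit to all-reduce once local work is done.
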
+ 165 - 0
hivemind/averaging/control.py

@@ -0,0 +1,165 @@
+import os
+import struct
+from enum import Enum
+from typing import Optional
+
+import numpy as np
+import torch
+
+from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger
+
+logger = get_logger(__name__)
+
+
+class AveragingStage(Enum):
+    IDLE = 0  # still initializing
+    LOOKING_FOR_GROUP = 1  # running decentralized matchmaking, can't run allreduce yet
+    AWAITING_TRIGGER = 2  # waiting for user to set the trigger that allows running allreduce
+    RUNNING_ALLREDUCE = 3  # exchanging tensors with groupmates
+    FINISHED = 4  # either done or failed with exception
+
+
+class StepControl(MPFuture):
+    """
+    An auxiliary data structure that allows the user to control stages and track progress in a single averaging step
+
+    :param scheduled_time: estimated time when averaging should begin. Will be used for scheduling
+    :param deadline: if averaging is still in progress at this time, it should be stopped due to TimeoutError
+    :param allow_retries: if True, allow running matchmaking and all-reduce again if previous attempt fails
+    :param weight: averaging weight, can be changed afterwards
+    :param data_for_gather: send this data to all peers in the next group and gather it from groupmates
+    """
+
+    # indices for the shared buffer
+    _SCHEDULED_TIME, _WEIGHT, _STAGE, _BEGAN_ALLREDUCE = slice(0, 8), slice(8, 16), 16, 17
+
+    def __init__(
+        self,
+        scheduled_time: DHTExpiration,
+        deadline: float,
+        allow_retries: bool,
+        weight: float,
+        data_for_gather: bytes,
+    ):
+        super().__init__()
+        self._data_for_gather, self._deadline, self._allow_retries = data_for_gather, deadline, allow_retries
+        self._trigger: Optional[MPFuture] = None
+        self._cancel: Optional[MPFuture] = None
+
+        # Buffer contents:
+        # scheduled_time (double) | weight (double) | stage (AveragingStage, 1 byte) | began_allreduce: (bool, 1 byte)
+        self._shared_buffer = torch.zeros([18], dtype=torch.uint8).share_memory_()
+        self.stage = AveragingStage.IDLE
+        self.scheduled_time = scheduled_time
+        self.weight = weight
+        self.began_allreduce = False
+
+    def attach(self, trigger: MPFuture, cancel: MPFuture):
+        assert self._trigger is None and self._cancel is None, "Futures are already attached"
+        self._trigger, self._cancel = trigger, cancel
+
+    def allow_allreduce(self):
+        """Allow averager to begin all-reduce when it finds a group. Meant to be triggered by user."""
+        assert self._trigger is not None, "StepControl does not have an attached trigger"
+        if self._trigger.done():
+            logger.warning("Trigger is already set")
+        else:
+            self._trigger.set_result(None)
+
+    async def wait_for_trigger(self):
+        assert self._trigger is not None, "StepControl does not have an attached trigger"
+        await self._trigger
+
+    @property
+    def triggered(self) -> bool:
+        assert self._trigger is not None, "StepControl does not have an attached trigger"
+        return self._trigger.done()
+
+    @property
+    def scheduled_time(self) -> DHTExpiration:
+        return struct.unpack("d", self._shared_buffer[StepControl._SCHEDULED_TIME].numpy().data)[0]
+
+    @scheduled_time.setter
+    def scheduled_time(self, scheduled_time):
+        if self.began_allreduce:
+            logger.warning("Changing scheduled time has no effect after all-reduce has already started")
+        if scheduled_time >= self.deadline:
+            logger.warning("Changing scheduled time to after deadline, averaging will likely fail due to timeout")
+        struct.pack_into("d", self._shared_buffer[StepControl._SCHEDULED_TIME].numpy().data, 0, float(scheduled_time))
+
+    @property
+    def weight(self) -> float:
+        return struct.unpack("d", self._shared_buffer[StepControl._WEIGHT].numpy().data)[0]
+
+    @weight.setter
+    def weight(self, weight: float):
+        assert weight >= 0 and np.isfinite(weight)
+        if self.began_allreduce:
+            logger.warning("Changing weights has no effect after all-reduce has already started")
+        struct.pack_into("d", self._shared_buffer[StepControl._WEIGHT].numpy().data, 0, float(weight))
+
+    @property
+    def stage(self) -> AveragingStage:
+        return AveragingStage(self._shared_buffer[StepControl._STAGE].item())
+
+    @stage.setter
+    def stage(self, stage: AveragingStage):
+        if stage == AveragingStage.RUNNING_ALLREDUCE:
+            self.began_allreduce = True
+        self._shared_buffer[StepControl._STAGE] = stage.value
+
+    @property
+    def began_allreduce(self) -> bool:
+        return bool(self._shared_buffer[StepControl._BEGAN_ALLREDUCE].item())
+
+    @began_allreduce.setter
+    def began_allreduce(self, value: bool):
+        self._shared_buffer[StepControl._BEGAN_ALLREDUCE] = int(value)
+
+    @property
+    def data_for_gather(self) -> bytes:
+        return self._data_for_gather
+
+    @property
+    def deadline(self) -> DHTExpiration:
+        return self._deadline
+
+    def get_timeout(self) -> Optional[DHTExpiration]:
+        return max(0.0, self.deadline - get_dht_time())
+
+    @property
+    def allow_retries(self) -> bool:
+        return self._allow_retries
+
+    def __getstate__(self):
+        return dict(
+            super().__getstate__(),
+            _trigger=self._trigger,
+            _cancel=self._cancel,
+            _shared_buffer=self._shared_buffer,
+            immutable_params=(self._data_for_gather, self._deadline, self._allow_retries),
+        )
+
+    def __setstate__(self, state):
+        super().__setstate__(state)
+        self._trigger, self._cancel, self._shared_buffer = state["_trigger"], state["_cancel"], state["_shared_buffer"]
+        self._data_for_gather, self._deadline, self._allow_retries = state["immutable_params"]
+
+    def __del__(self):
+        if os.getpid() == self._origin_pid and not self.triggered:
+            logger.warning(
+                "Deleted an averaging StepControl, but the step was not triggered. This may cause other "
+                "peers to fail an averaging round via TimeoutError."
+            )
+        super().__del__()
+
+    def cancel(self) -> bool:
+        if self._trigger is not None:
+            self._trigger.cancel()
+        if self._cancel is not None:
+            self._cancel.set_result(None)
+        return super().cancel()
+
+    async def wait_for_cancel(self):
+        """Await for step to be cancelled by the user. Should be called from insider the averager."""
+        await self._cancel

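For readers unfamiliar with the shared-buffer trick above: StepControl packs two doubles and two one-byte fields into an 18-byte tensor in shared memory, so the attributes stay visible to both the user process and the averager process. A standalone sketch of the same packing (illustrative only, mirroring the property implementations above):

import struct

import torch

# Layout used by StepControl._shared_buffer:
#   bytes 0-7   scheduled_time (float64)
#   bytes 8-15  weight (float64)
#   byte  16    stage (AveragingStage value)
#   byte  17    began_allreduce flag
buffer = torch.zeros([18], dtype=torch.uint8).share_memory_()

struct.pack_into("d", buffer[0:8].numpy().data, 0, 1234.5)  # write scheduled_time
struct.pack_into("d", buffer[8:16].numpy().data, 0, 0.25)   # write weight
buffer[16] = 3  # e.g. AveragingStage.RUNNING_ALLREDUCE.value
buffer[17] = 1  # began_allreduce = True

scheduled_time = struct.unpack("d", buffer[0:8].numpy().data)[0]
weight = struct.unpack("d", buffer[8:16].numpy().data)[0]
print(scheduled_time, weight, int(buffer[16]), bool(buffer[17]))  # 1234.5 0.25 3 True
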
+ 21 - 4
hivemind/averaging/key_manager.py

@@ -11,6 +11,7 @@ from hivemind.utils import DHTExpiration, ValueWithExpiration, get_dht_time, get
 
 GroupKey = str
 GROUP_PATTERN = re.compile("^(([^.])+)[.]0b[01]*$")  # e.g. bert_exp4_averaging.0b01001101
+DEFAULT_NUM_BUCKETS = 256
 logger = get_logger(__name__)
 
 
@@ -29,9 +30,12 @@ class GroupKeyManager:
         dht: DHT,
         prefix: str,
         initial_group_bits: str,
-        target_group_size: int,
+        target_group_size: Optional[int],
     ):
         assert all(bit in "01" for bit in initial_group_bits)
+        if target_group_size is not None and not is_power_of_two(target_group_size):
+            logger.warning("It is recommended to set target_group_size to a power of 2")
+
         self.dht, self.prefix, self.group_bits = dht, prefix, initial_group_bits
         self.target_group_size = target_group_size
         self.peer_id = dht.peer_id
@@ -76,7 +80,7 @@ class GroupKeyManager:
         assert is_valid_group(group_key), f"Group key {group_key} is invalid, must follow {GROUP_PATTERN}"
         result = await self.dht.get(group_key, latest=True, return_future=True)
         if result is None or not isinstance(result.value, dict):
-            logger.debug(f"Allreduce group not found: {group_key}, creating new group.")
+            logger.debug(f"Allreduce group not found: {group_key}, creating new group")
             return []
         averagers = []
         for key, looking_for_group in result.value.items():
@@ -92,8 +96,11 @@ class GroupKeyManager:
         """this function is triggered every time an averager finds an allreduce group"""
         rng = random.Random(group_info.group_id)
         index = group_info.peer_ids.index(self.peer_id)
-        generalized_index = rng.sample(range(self.target_group_size), group_info.group_size)[index]
-        nbits = int(np.ceil(np.log2(self.target_group_size)))
+        num_buckets = self.target_group_size
+        if num_buckets is None:
+            num_buckets = next_power_of_two(group_info.group_size)
+        generalized_index = rng.sample(range(num_buckets), group_info.group_size)[index]
+        nbits = int(np.ceil(np.log2(num_buckets)))
         new_bits = bin(generalized_index)[2:].rjust(nbits, "0")
         self.group_bits = (self.group_bits + new_bits)[-len(self.group_bits) :] if self.group_bits else ""
         logger.debug(f"{self.peer_id} - updated group key to {self.group_bits}")
@@ -101,3 +108,13 @@ class GroupKeyManager:
     async def update_key_on_not_enough_peers(self):
         """this function is triggered whenever averager fails to assemble group within timeout"""
         pass  # to be implemented in subclasses
+
+
+def is_power_of_two(n):
+    """Check whether n is a power of 2"""
+    return (n != 0) and (n & (n - 1) == 0)
+
+
+def next_power_of_two(n):
+    """Round n up to the nearest power of 2"""
+    return 1 if n == 0 else 2 ** (n - 1).bit_length()

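A quick standalone check of the two helpers added above, plus the fallback bucket count used when target_group_size is None (numbers are purely illustrative):

import numpy as np

def is_power_of_two(n):
    return (n != 0) and (n & (n - 1) == 0)

def next_power_of_two(n):
    return 1 if n == 0 else 2 ** (n - 1).bit_length()

assert [next_power_of_two(n) for n in (0, 1, 2, 3, 5, 8, 9)] == [1, 1, 2, 4, 8, 8, 16]
assert is_power_of_two(8) and not is_power_of_two(6)

# With target_group_size=None, a group of 5 peers falls back to 8 buckets,
# so each peer appends log2(8) = 3 fresh bits to its group key.
num_buckets = next_power_of_two(5)
nbits = int(np.ceil(np.log2(num_buckets)))
assert (num_buckets, nbits) == (8, 3)
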
+ 2 - 2
hivemind/averaging/load_balancing.py

@@ -65,7 +65,7 @@ def optimize_parts_lp(vector_size: int, bandwidths: np.ndarray, min_size: int =
     # the constraints below are tuples (A, b) such that Ax <= b
     nonnegative_weights = -np.eye(group_size, num_variables, dtype=c.dtype), np.zeros(group_size, c.dtype)
     weights_sum_to_one = c[None, :] - 1.0, np.array([-1.0])
-    coeff_per_variable = (group_size - 2.0) / np.maximum(bandwidths, 10 ** -LOAD_BALANCING_LP_DECIMALS)
+    coeff_per_variable = (group_size - 2.0) / np.maximum(bandwidths, 10**-LOAD_BALANCING_LP_DECIMALS)
     coeff_matrix_minus_xi = np.hstack([np.diag(coeff_per_variable), -np.ones((group_size, 1), c.dtype)])
     xi_is_maximum = coeff_matrix_minus_xi[is_nonzero], -1.0 / bandwidths[is_nonzero]
     force_max_weights = np.eye(group_size, M=num_variables, dtype=c.dtype), is_nonzero.astype(c.dtype)
@@ -80,7 +80,7 @@ def optimize_parts_lp(vector_size: int, bandwidths: np.ndarray, min_size: int =
             peer_scores[peer_scores < min_size / float(vector_size)] = 0.0
         peer_scores = np.round(peer_scores, LOAD_BALANCING_LP_DECIMALS)
     else:
-        logger.error(f"Failed to solve load-balancing for bandwidths {bandwidths}.")
+        logger.error(f"Failed to solve load-balancing for bandwidths {bandwidths}")
         peer_scores = np.ones(group_size, c.dtype)
 
     return peer_scores[np.argsort(permutation)]

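For context, the scores computed here feed load_balance_peers, which the averager calls as load_balance_peers(total_size, download_bandwidths, min_vector_size) (see the averager.py hunk above). A hedged example of that call; the import path is assumed from this file's location:

from hivemind.averaging.load_balancing import load_balance_peers

# Three peers: two regular peers with uneven bandwidth and one client-mode peer,
# which the averager models as zero download bandwidth (it should aggregate nothing).
part_sizes = load_balance_peers(1_000_000, [100.0, 50.0, 0.0], 0)
print(part_sizes)  # one integer part size per peer; the exact split is up to the LP solution
assert sum(part_sizes) == 1_000_000
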
+ 75 - 55
hivemind/averaging/matchmaking.py

@@ -9,10 +9,12 @@ import random
 from math import isfinite
 from typing import AsyncIterator, Dict, Optional, Set, Tuple, Type
 
+from hivemind.averaging.control import StepControl
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.key_manager import GroupKey, GroupKeyManager
 from hivemind.dht import DHT, DHTID, DHTExpiration
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
+from hivemind.p2p.p2p_daemon_bindings.utils import ControlFailure, DispatchFailure
 from hivemind.proto import averaging_pb2
 from hivemind.utils import TimedStorage, get_dht_time, get_logger, timed_storage
 from hivemind.utils.asyncio import anext, cancel_and_wait
@@ -41,18 +43,18 @@ class Matchmaking:
         *,
         servicer_type: Type[ServicerBase],
         prefix: str,
-        target_group_size: int,
+        target_group_size: Optional[int],
         min_group_size: int,
+        min_matchmaking_time: float,
         request_timeout: float,
         client_mode: bool,
         initial_group_bits: str = "",
-        averaging_expiration: float = 15,
     ):
         assert "." not in prefix, "group prefix must be a string without ."
-        if request_timeout is None or request_timeout >= averaging_expiration:
+        if request_timeout is None or request_timeout >= min_matchmaking_time:
             logger.warning(
-                "It is recommended to use request_timeout smaller than averaging_expiration. Otherwise,"
-                "matchmaking can cause deadlocks in some rare cases. Please see Matchmaking docstring."
+                "It is recommended to use request_timeout smaller than min_matchmaking_time. Otherwise,"
+                " matchmaking can cause deadlocks in some rare cases. Please see Matchmaking docstring."
             )
 
         super().__init__()
@@ -67,7 +69,7 @@ class Matchmaking:
         self.schema_hash = schema_hash
         self.group_key_manager = GroupKeyManager(dht, prefix, initial_group_bits, target_group_size)
         self.target_group_size, self.min_group_size = target_group_size, min_group_size
-        self.averaging_expiration, self.request_timeout = averaging_expiration, request_timeout
+        self.min_matchmaking_time, self.request_timeout = min_matchmaking_time, request_timeout
         self.client_mode = client_mode
 
         self.lock_looking_for_group = asyncio.Lock()
@@ -78,8 +80,18 @@ class Matchmaking:
 
         self.current_leader: Optional[PeerID] = None  # iff i am a follower, this is a link to my current leader
         self.current_followers: Dict[PeerID, averaging_pb2.JoinRequest] = {}  # my current followers excluding myself
-        self.potential_leaders = PotentialLeaders(self.peer_id, averaging_expiration, target_group_size)
-        self.data_for_gather: Optional[bytes] = None
+        self.potential_leaders = PotentialLeaders(self.peer_id, min_matchmaking_time, target_group_size)
+        self.step_control: Optional[StepControl] = None
+
+    @contextlib.asynccontextmanager
+    async def looking_for_group(self, step_control: StepControl):
+        async with self.lock_looking_for_group:
+            assert self.step_control is None
+            try:
+                self.step_control = step_control
+                yield
+            finally:
+                self.step_control = None
 
     @property
     def is_looking_for_group(self):
@@ -98,10 +110,9 @@ class Matchmaking:
             f" current key = {self.group_key_manager.current_key}, client_mode={self.client_mode})"
         )
 
-    async def look_for_group(self, *, data_for_gather: bytes, timeout: Optional[float] = None) -> Optional[GroupInfo]:
+    async def look_for_group(self, step: StepControl) -> Optional[GroupInfo]:
         """
-        :param data_for_gather: optionally send this data to all peers in the next group and gather it from groupmates
-        :param timeout: maximum time that may be spent looking for group (does not include allreduce itself)
+        :param step: step parameters and user control structure for the current step
         :returns: an assembled group if successful, None if failed; does NOT perform the actual averaging
         Iterate over the averagers from a given group_identifier that have higher leadership priority than yourself.
         """
@@ -110,11 +121,10 @@ class Matchmaking:
                 "Another look_for_group is already in progress. The current run will be scheduled after"
                 " the existing group is either assembled or disbanded."
             )
-        async with self.lock_looking_for_group:
-            self.data_for_gather = data_for_gather
-            request_leaders_task = asyncio.create_task(self._request_join_potential_leaders(timeout))
+        async with self.looking_for_group(step):
+            request_leaders_task = asyncio.create_task(self._request_join_potential_leaders(step))
             try:
-                return await asyncio.wait_for(self.assembled_group, timeout=timeout)
+                return await asyncio.wait_for(self.assembled_group, timeout=step.get_timeout())
             except asyncio.TimeoutError:
                 return None
 
@@ -136,16 +146,16 @@ class Matchmaking:
                 # note: the code above ensures that we send all followers away before creating new future
                 self.assembled_group = asyncio.Future()
                 self.was_accepted_to_group.clear()
-                self.data_for_gather = None
 
-    async def _request_join_potential_leaders(self, timeout: Optional[float]) -> GroupInfo:
+    async def _request_join_potential_leaders(self, step: StepControl) -> GroupInfo:
         """Request leaders from queue until we find the first runner. This coroutine is meant to run in background."""
-        async with self.potential_leaders.begin_search(self.group_key_manager, timeout, declare=not self.client_mode):
+        assert self.is_looking_for_group
+        async with self.potential_leaders.begin_search(step, self.group_key_manager, declare=not self.client_mode):
             while True:
                 try:
                     next_leader = await self.potential_leaders.pop_next_leader()  # throws TimeoutError on expiration
 
-                    group = await self.request_join_group(next_leader, self.potential_leaders.request_expiration_time)
+                    group = await self._request_join_group(next_leader)
                     if group is not None:
                         return group
 
@@ -166,33 +176,32 @@ class Matchmaking:
                         self.assembled_group.set_exception(e)
                     raise e
 
-    async def request_join_group(self, leader: PeerID, expiration_time: DHTExpiration) -> Optional[GroupInfo]:
+    async def _request_join_group(self, leader: PeerID) -> Optional[GroupInfo]:
         """
         :param leader: request this peer to be your leader for allreduce
-        :param expiration_time: inform leader that we intend to begin averaging before this expiration_time
         :returns: if the leader accepted us and started AllReduce, return that AllReduce. Otherwise, return None
         :note: this function does not guarantee that your group leader is the same as :leader: parameter
           The originally specified leader can disband group and redirect us to a different leader
         """
         assert self.is_looking_for_group and self.current_leader is None
-        stream: AsyncIterator[averaging_pb2.MessageFromLeader] = None
+        stream: Optional[AsyncIterator[averaging_pb2.MessageFromLeader]] = None
         try:
             async with self.lock_request_join_group:
                 leader_stub = self._servicer_type.get_stub(self._p2p, leader, namespace=self._prefix)
-
-                stream = leader_stub.rpc_join_group(
+                request_expiration_time = self.get_request_expiration_time()
+                stream = await leader_stub.rpc_join_group(
                     averaging_pb2.JoinRequest(
                         schema_hash=self.schema_hash,
-                        expiration=expiration_time,
+                        expiration=request_expiration_time,
                         client_mode=self.client_mode,
-                        gather=self.data_for_gather,
+                        gather=self.step_control.data_for_gather,
                         group_key=self.group_key_manager.current_key,
                     )
                 )
                 message = await asyncio.wait_for(anext(stream), timeout=self.request_timeout)
 
                 if message.code == averaging_pb2.ACCEPTED:
-                    logger.debug(f"{self.peer_id} - joining the group of {leader}; waiting for peers")
+                    logger.debug(f"{self.peer_id} - joining the group of {leader}, waiting for peers")
                     self.current_leader = leader
                     self.was_accepted_to_group.set()
                     if len(self.current_followers) > 0:
@@ -204,7 +213,7 @@ class Matchmaking:
                 return None
 
             async with self.potential_leaders.pause_search():
-                time_to_expiration = max(expiration_time - get_dht_time(), 0.0)
+                time_to_expiration = max(0.0, request_expiration_time - get_dht_time())
                 message = await asyncio.wait_for(anext(stream), time_to_expiration + self.request_timeout)
 
                 if message.code == averaging_pb2.BEGIN_ALLREDUCE:
@@ -217,8 +226,11 @@ class Matchmaking:
                     if suggested_leader != self.peer_id:
                         logger.debug(f"{self} - leader disbanded group and redirected us to {suggested_leader}")
                         self.current_leader = None
-                        await stream.aclose()
-                        return await self.request_join_group(suggested_leader, expiration_time)
+                        try:
+                            await stream.aclose()
+                        except RuntimeError as e:
+                            logger.debug(e, exc_info=True)
+                        return await self._request_join_group(suggested_leader)
                 logger.debug(f"{self} - leader disbanded group")
                 return None
 
@@ -227,15 +239,26 @@ class Matchmaking:
         except asyncio.TimeoutError:
             logger.debug(f"{self} - potential leader {leader} did not respond within {self.request_timeout}")
             return None
-        except (P2PHandlerError, StopAsyncIteration) as e:
-            logger.exception(f"{self} - failed to request potential leader {leader}:")
+        except (P2PHandlerError, ControlFailure, DispatchFailure, StopAsyncIteration) as e:
+            logger.debug(f"{self} - failed to request potential leader {leader}:", exc_info=True)
             return None
 
         finally:
             self.was_accepted_to_group.clear()
             self.current_leader = None
             if stream is not None:
-                await stream.aclose()
+                try:
+                    await stream.aclose()
+                except RuntimeError as e:
+                    logger.debug(e, exc_info=True)
+
+    def get_request_expiration_time(self) -> float:
+        """Returns the averager's current expiration time, which is used to send join requests to leaders"""
+        if isfinite(self.potential_leaders.declared_expiration_time):
+            return self.potential_leaders.declared_expiration_time
+        else:
+            scheduled_time = max(self.step_control.scheduled_time, get_dht_time() + self.min_matchmaking_time)
+            return min(scheduled_time, self.potential_leaders.search_end_time)
 
     async def rpc_join_group(
         self, request: averaging_pb2.JoinRequest, context: P2PContext
@@ -251,7 +274,11 @@ class Matchmaking:
                 self.current_followers[context.remote_id] = request
                 yield averaging_pb2.MessageFromLeader(code=averaging_pb2.ACCEPTED)
 
-                if len(self.current_followers) + 1 >= self.target_group_size and not self.assembled_group.done():
+                if (
+                    self.target_group_size is not None
+                    and len(self.current_followers) + 1 >= self.target_group_size
+                    and not self.assembled_group.done()
+                ):
                     # outcome 1: we have assembled a full group and are ready for allreduce
                     await self.leader_assemble_group()
 
@@ -337,7 +364,7 @@ class Matchmaking:
             )
         elif context.remote_id == self.peer_id or context.remote_id in self.current_followers:
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.DUPLICATE_PEER_ID)
-        elif len(self.current_followers) + 1 >= self.target_group_size:
+        elif self.target_group_size is not None and len(self.current_followers) + 1 >= self.target_group_size:
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_IS_FULL)
         else:
             return None
@@ -352,11 +379,11 @@ class Matchmaking:
         random.shuffle(ordered_peer_ids)
 
         gathered = tuple(
-            self.data_for_gather if peer_id == self.peer_id else self.current_followers[peer_id].gather
+            self.step_control.data_for_gather if peer_id == self.peer_id else self.current_followers[peer_id].gather
             for peer_id in ordered_peer_ids
         )
 
-        logger.debug(f"{self.peer_id} - assembled group of {len(ordered_peer_ids)} peers.")
+        logger.debug(f"{self.peer_id} - assembled group of {len(ordered_peer_ids)} peers")
         group_info = GroupInfo(group_id, tuple(ordered_peer_ids), gathered)
         await self.group_key_manager.update_key_on_group_assembled(group_info, is_leader=True)
         self.assembled_group.set_result(group_info)
@@ -373,7 +400,7 @@ class Matchmaking:
         assert self.peer_id in ordered_peer_ids, "Leader sent us group_peer_ids that does not contain us!"
         assert len(ordered_peer_ids) == len(msg.gathered)
 
-        logger.debug(f"{self.peer_id} - follower assembled group with leader {leader}.")
+        logger.debug(f"{self.peer_id} - follower assembled group with leader {leader}")
         group_info = GroupInfo(group_id, tuple(ordered_peer_ids), tuple(msg.gathered))
         await self.group_key_manager.update_key_on_group_assembled(group_info)
         self.assembled_group.set_result(group_info)
@@ -388,8 +415,8 @@ class Matchmaking:
 class PotentialLeaders:
     """An utility class that searches for averagers that could become our leaders"""
 
-    def __init__(self, peer_id: PeerID, averaging_expiration: DHTExpiration, target_group_size: Optional[int]):
-        self.peer_id, self.averaging_expiration = peer_id, averaging_expiration
+    def __init__(self, peer_id: PeerID, min_matchmaking_time: DHTExpiration, target_group_size: Optional[int]):
+        self.peer_id, self.min_matchmaking_time = peer_id, min_matchmaking_time
         self.target_group_size = target_group_size
         self.running, self.update_triggered, self.update_finished = asyncio.Event(), asyncio.Event(), asyncio.Event()
         self.declared_expiration, self.lock_search, self.lock_declare = asyncio.Event(), asyncio.Lock(), asyncio.Lock()
@@ -401,13 +428,13 @@ class PotentialLeaders:
         self.search_end_time = float("inf")
 
     @contextlib.asynccontextmanager
-    async def begin_search(self, key_manager: GroupKeyManager, timeout: Optional[float], declare: bool = True):
+    async def begin_search(self, step: StepControl, key_manager: GroupKeyManager, declare: bool = True):
         async with self.lock_search:
             self.running.set()
-            self.search_end_time = get_dht_time() + timeout if timeout is not None else float("inf")
+            self.search_end_time = step.deadline if step.deadline is not None else float("inf")
             update_queue_task = asyncio.create_task(self._update_queue_periodically(key_manager))
             if declare:
-                declare_averager_task = asyncio.create_task(self._declare_averager_periodically(key_manager))
+                declare_averager_task = asyncio.create_task(self._declare_averager_periodically(step, key_manager))
 
             try:
                 yield self
@@ -467,20 +494,12 @@ class PotentialLeaders:
             self.past_attempts.add((maybe_next_leader, entry.expiration_time))
             return maybe_next_leader
 
-    @property
-    def request_expiration_time(self) -> float:
-        """this averager's current expiration time - used to send join requests to leaders"""
-        if isfinite(self.declared_expiration_time):
-            return self.declared_expiration_time
-        else:
-            return min(get_dht_time() + self.averaging_expiration, self.search_end_time)
-
     async def _update_queue_periodically(self, key_manager: GroupKeyManager) -> None:
         DISCREPANCY = timed_storage.MAX_DHT_TIME_DISCREPANCY_SECONDS
         while get_dht_time() < self.search_end_time:
             new_peers = await key_manager.get_averagers(key_manager.current_key, only_active=True)
             self.max_assured_time = max(
-                self.max_assured_time, get_dht_time() + self.averaging_expiration - DISCREPANCY
+                self.max_assured_time, get_dht_time() + self.min_matchmaking_time - DISCREPANCY
             )
 
             self.leader_queue.clear()
@@ -499,13 +518,14 @@ class PotentialLeaders:
             )
             self.update_triggered.clear()
 
-    async def _declare_averager_periodically(self, key_manager: GroupKeyManager) -> None:
+    async def _declare_averager_periodically(self, step: StepControl, key_manager: GroupKeyManager) -> None:
         async with self.lock_declare:
             try:
                 while True:
                     await self.running.wait()
-
-                    new_expiration_time = min(get_dht_time() + self.averaging_expiration, self.search_end_time)
+                    new_expiration_time = float(
+                        min(max(step.scheduled_time, get_dht_time() + self.min_matchmaking_time), self.search_end_time)
+                    )
                     self.declared_group_key = group_key = key_manager.current_key
                     self.declared_expiration_time = new_expiration_time
                     self.declared_expiration.set()

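The expiration time declared above clamps the step's scheduled_time between now + min_matchmaking_time and the search deadline (the same formula is used in get_request_expiration_time). A small self-contained illustration with made-up numbers:

now = 1_000.0               # stand-in for get_dht_time()
min_matchmaking_time = 5.0
search_end_time = 1_030.0   # the step's deadline

def expiration_for(scheduled_time: float) -> float:
    return min(max(scheduled_time, now + min_matchmaking_time), search_end_time)

assert expiration_for(1_002.0) == 1_005.0  # too soon: pushed to now + min_matchmaking_time
assert expiration_for(1_020.0) == 1_020.0  # within bounds: used as-is
assert expiration_for(1_050.0) == 1_030.0  # past the deadline: clamped to search_end_time
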
+ 71 - 26
hivemind/averaging/partition.py

@@ -10,21 +10,24 @@ import torch
 
 from hivemind.compression import CompressionBase, CompressionInfo, NoCompression
 from hivemind.proto import runtime_pb2
-from hivemind.utils.asyncio import amap_in_executor
+from hivemind.utils import amap_in_executor, as_aiter, get_logger
 
 T = TypeVar("T")
-DEFAULT_PART_SIZE_BYTES = 2 ** 16
+DEFAULT_PART_SIZE_BYTES = 2**19
+logger = get_logger(__name__)
 
 
 class TensorPartContainer:
     """
     Auxiliary data structure for averaging, responsible for splitting tensors into parts and reassembling them.
     The class is designed to avoid excessive memory allocation and run all heavy computation in background
+
     :param tensors: local tensors to be split and aggregated
     :param peer_fractions: for each peer, a target fraction of vector elements that this peer should average
     :param compression: optionally compress tensors with this compression algorithm before sending them to peers
     :param part_size_bytes: greedily split tensors into parts of up to this many bytes (after compression)
     :param tensor_infos: CompressionInfo for each respective tensor; this determines how the tensor will be compressed
+    :param return_deltas: if True, output tensors are differences (aggregated tensor - local tensor)
     :param prefetch: when compressing, pre-compute this many compressed tensors in background
     """
 
@@ -35,7 +38,8 @@ class TensorPartContainer:
         compression: CompressionBase = NoCompression(),
         part_size_bytes: int = DEFAULT_PART_SIZE_BYTES,
         tensor_infos: Optional[Sequence[CompressionInfo]] = None,
-        prefetch: int = 5,
+        return_deltas: bool = True,
+        prefetch: int = 1,
     ):
         if tensor_infos is None:
             tensor_infos = tuple(CompressionInfo.from_tensor(x, key=i) for i, x in enumerate(tensors))
@@ -43,6 +47,8 @@ class TensorPartContainer:
         self.local_tensors, self.peer_fractions, self.group_size = tensors, peer_fractions, len(peer_fractions)
         self.compression, self.part_size_bytes, self.tensor_infos = compression, part_size_bytes, tensor_infos
         self.total_size = sum(tensor.numel() for tensor in tensors)
+        self.failed_size = 0
+        self.return_deltas = return_deltas
         self.prefetch = prefetch
 
         self._input_parts_by_peer = [deque() for _ in range(self.group_size)]
@@ -91,7 +97,6 @@ class TensorPartContainer:
         assert not self._inputs_consumed_by_peer[peer_index], "input parts of a given peer are already deallocated."
         self._inputs_consumed_by_peer[peer_index] = True
         input_parts = tuple(part for part, compression in self._input_parts_by_peer[peer_index])
-        self._input_parts_by_peer[peer_index].clear()
         return input_parts
 
     @torch.no_grad()
@@ -99,13 +104,9 @@ class TensorPartContainer:
         """iterate serialized tensor parts for a peer at a given index. Run serialization in background."""
         assert not self._inputs_consumed_by_peer[peer_index], "input parts of a given peer are already deallocated."
         self._inputs_consumed_by_peer[peer_index] = True
-
-        async def _aiterate_parts():
-            for _ in range(self.num_parts_by_peer[peer_index]):
-                yield self._input_parts_by_peer[peer_index].popleft()
-
+        parts_aiter = as_aiter(*self._input_parts_by_peer[peer_index])
         async for serialized_part in amap_in_executor(
-            lambda x_and_info: self.compression.compress(*x_and_info), _aiterate_parts(), max_prefetch=self.prefetch
+            lambda x_and_info: self.compression.compress(*x_and_info), parts_aiter, max_prefetch=self.prefetch
         ):
             yield serialized_part
 
@@ -123,6 +124,16 @@ class TensorPartContainer:
         self._outputs_registered_by_peer[peer_index] += 1
         self._output_part_available[peer_index].set()
 
+    def register_failed_reducer(self, peer_index: int):
+        """
+        A given peer failed to aggregate a certain part; use our local part instead and keep track of failed parts
+        """
+        for part_index in range(self._outputs_registered_by_peer[peer_index], self.num_parts_by_peer[peer_index]):
+            part_and_info = self._input_parts_by_peer[peer_index][part_index]
+            part_result_or_delta = torch.zeros_like(part_and_info[0]) if self.return_deltas else part_and_info[0]
+            self.register_processed_part(peer_index, part_index, part_result_or_delta)
+            self.failed_size += part_result_or_delta.numel()
+
     async def iterate_output_tensors(self) -> AsyncIterable[torch.Tensor]:
         """iterate over the outputs of averaging (whether they are average, delta or other aggregation result)"""
         assert not self._outputs_consumed, "output tensors are already iterated and no longer available."
@@ -139,7 +150,7 @@ class TensorPartContainer:
                     self._output_part_available[peer_index].clear()
                     await self._output_part_available[peer_index].wait()
                     if self.finished.is_set():
-                        raise AllreduceException("All-reduce was terminated during iteration.")
+                        raise AllreduceException("All-reduce was terminated during iteration")
 
                 tensor_parts.append(self._output_parts_by_peer[peer_index].popleft())
                 num_parts_processed += 1
@@ -155,9 +166,11 @@ class TensorPartContainer:
         if not self.finished.is_set():
             for peer_index in range(self.group_size):
                 self._inputs_consumed_by_peer[peer_index] = True
+                self._output_part_available[peer_index].set()
                 self._input_parts_by_peer[peer_index].clear()
                 self._output_parts_by_peer[peer_index].clear()
-                self._output_part_available[peer_index].set()
+            if self.failed_size != 0:
+                logger.warning(f"Averaging: received {(1. - self.failed_size / self.total_size) * 100:.1f}% of results")
             self._outputs_consumed = True
             self.finished.set()
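As the docstring above says, tensors are split greedily into parts of at most part_size_bytes, which are then distributed among peers according to peer_fractions. A rough sketch of just the size-based splitting, ignoring peer fractions and compression (not the actual implementation):

import torch

def split_greedily(tensor: torch.Tensor, part_size_bytes: int = 2**19):
    # cut the flattened tensor into chunks of at most part_size_bytes each
    elements_per_part = max(1, part_size_bytes // tensor.element_size())
    return torch.split(tensor.flatten(), elements_per_part)

parts = split_greedily(torch.randn(1_000_000))  # float32: 2**19 bytes -> 131072 elements per part
print(len(parts), parts[0].numel())             # 8 parts; the first seven hold 131072 elements each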
 
@@ -167,26 +180,27 @@ class TensorPartReducer:
     Auxiliary data structure responsible for running asynchronous all-reduce
     :param part_shapes: a sequence of shapes of torch tensors that will be averaged by this reducer
     :param num_senders: total number of peers in a given all-reduce group that will send gradients
-    :param weights: relative importance of each sender, used for weighted average (default = equal weights)
     :note: even if local peer is not sending data, local parts will be used for shape information
     """
 
-    def __init__(self, part_shapes: Sequence[torch.Size], num_senders: int, weights: Optional[Sequence[float]] = None):
+    def __init__(self, part_shapes: Sequence[torch.Size], num_senders: int):
         self.part_shapes, self.num_senders, self.num_parts = part_shapes, num_senders, len(part_shapes)
-        self.weights = tuple(weights or (1 for _ in range(num_senders)))
-        assert len(self.weights) == self.num_senders, "The number of weights is inconsistent with num_senders"
-        assert all(isinstance(weight, (int, float)) for weight in self.weights)
         self.current_part_index = -1  # index in local_parts of the part that should be loaded next
         self.current_part_accumulated_from = 0  # number of peers from which the current part was accumulated
         self.accumulator = None  # this will contain the sum of current tensor part from group peers
         self.denominator = 0.0  # total weight accumulated from all peers for current part
         self.current_part_future = asyncio.Future()
         self.finished = asyncio.Event()
+
+        self.num_parts_received = [0 for _ in range(self.num_senders)]
+        self.sender_failed_after = [float("inf") for _ in range(self.num_senders)]
+        self.num_current_senders = self.num_senders
+
         self.reset_accumulators()
 
     def reset_accumulators(self):
         """(re)create averaging buffers for the next part in line, prepopulate with local tensor part"""
-        assert self.current_part_accumulated_from == self.num_senders or self.current_part_index == -1
+        assert self.current_part_accumulated_from == self.num_current_senders or self.current_part_index == -1
         if self.current_part_index >= self.num_parts - 1:
             self.finalize()
             return
@@ -194,32 +208,53 @@ class TensorPartReducer:
         self.current_part_index += 1
         self.current_part_accumulated_from = 0
         self.current_part_future = asyncio.Future()
+        self.num_current_senders = sum(
+            self.current_part_index < failed_index for failed_index in self.sender_failed_after
+        )
         self.accumulator = torch.zeros(self.part_shapes[self.current_part_index])
         self.denominator = 0.0
 
-    async def accumulate_part(self, sender_index: int, part_index: int, tensor_part: torch.Tensor) -> torch.Tensor:
+    async def accumulate_part(
+        self, sender_index: int, part_index: int, tensor_part: torch.Tensor, weight: float = 1.0
+    ) -> torch.Tensor:
         """Add vector part to accumulator, wait for all other vectors to be added, then return the average part"""
         assert 0 <= sender_index < self.num_senders, "invalid sender index"
         assert 0 <= part_index < self.num_parts, "invalid part index"
+        self.num_parts_received[sender_index] += 1
 
         while part_index > self.current_part_index:
             # wait for previous parts to finish processing ...
             await asyncio.wait({self.current_part_future, self.finished.wait()}, return_when=asyncio.FIRST_COMPLETED)
             if self.finished.is_set():
                 raise AllreduceException(f"attempted to aggregate part in a finalized {self.__class__.__name__}")
+
+        if self.sender_failed_after[sender_index] != float("inf"):
+            raise BannedException(f"sender {sender_index} was banned in background")
         assert part_index == self.current_part_index
 
         current_part_future = self.current_part_future
 
-        self.accumulator.add_(tensor_part, alpha=self.weights[sender_index])
-        self.denominator += self.weights[sender_index]
-        self.current_part_accumulated_from += 1
+        if part_index < self.sender_failed_after[sender_index]:
+            self.accumulator.add_(tensor_part, alpha=weight)
+            self.current_part_accumulated_from += 1
+            self.denominator += weight
+            self.check_current_part_finished()
+        return await current_part_future
 
-        assert self.current_part_accumulated_from <= self.num_senders
-        if self.current_part_accumulated_from == self.num_senders:
-            current_part_future.set_result(self.accumulator.div_(self.denominator))
+    def on_sender_failed(self, sender_index: int):
+        """Exclude this sender's data from averaging for any parts that it has not submitted yet."""
+        self.sender_failed_after[sender_index] = self.num_parts_received[sender_index]
+        if self.finished.is_set():
+            return
+        if self.current_part_index == self.num_parts_received[sender_index]:
+            self.num_current_senders -= 1
+            self.check_current_part_finished()
+
+    def check_current_part_finished(self):
+        assert self.current_part_accumulated_from <= self.num_current_senders
+        if self.current_part_accumulated_from == self.num_current_senders:
+            self.current_part_future.set_result(self.accumulator.div_(self.denominator))
             self.reset_accumulators()
-        return await current_part_future
 
     def finalize(self):
         if not self.finished.is_set():
@@ -228,9 +263,19 @@ class TensorPartReducer:
                 del self.accumulator
             self.finished.set()
 
+            if self.num_parts != 0 and self.num_senders != 0:
+                parts_expected = self.num_parts * self.num_senders
+                parts_received = sum(self.num_parts_received)
+                if parts_expected != parts_received:
+                    logger.warning(f"Reducer: received {parts_received / parts_expected * 100:.1f}% of input tensors")
+
     def __del__(self):
         self.finalize()
 
 
 class AllreduceException(Exception):
     """A special exception that is raised when allreduce can't continue normally (e.g. disconnected/protocol error)"""
+
+
+class BannedException(AllreduceException):
+    """An exception that indicates that a given sender was banned and will no longer be aggregated"""
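Per part, the reducer above computes a weighted average over the senders that are still active (failed senders are excluded from the current part onward). The arithmetic in accumulate_part/check_current_part_finished reduces to the following standalone sketch (illustrative weights):

import torch

def reduce_part(parts, weights):
    # weighted average: sum(part * weight) / sum(weight), as accumulated above
    accumulator = torch.zeros_like(parts[0])
    denominator = 0.0
    for tensor_part, weight in zip(parts, weights):
        accumulator.add_(tensor_part, alpha=weight)
        denominator += weight
    return accumulator.div_(denominator)

print(reduce_part([torch.ones(4), 3 * torch.ones(4)], weights=[1.0, 1.0]))  # tensor([2., 2., 2., 2.])
print(reduce_part([torch.ones(4), 3 * torch.ones(4)], weights=[3.0, 1.0]))  # tensor([1.5000, 1.5000, 1.5000, 1.5000])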

+ 1 - 44
hivemind/compression/__init__.py

@@ -2,51 +2,8 @@
 Compression strategies that reduce the network communication in .averaging, .optim and .moe
 """
 
-import warnings
-from typing import Dict, Optional
-
-import torch
-
 from hivemind.compression.adaptive import PerTensorCompression, RoleAdaptiveCompression, SizeAdaptiveCompression
 from hivemind.compression.base import CompressionBase, CompressionInfo, NoCompression, TensorRole
 from hivemind.compression.floating import Float16Compression, ScaledFloat16Compression
 from hivemind.compression.quantization import Quantile8BitQuantization, Uniform8BitQuantization
-from hivemind.proto import runtime_pb2
-
-warnings.filterwarnings("ignore", message="The given NumPy array is not writeable", category=UserWarning)
-
-
-BASE_COMPRESSION_TYPES: Dict[str, CompressionBase] = dict(
-    NONE=NoCompression(),
-    FLOAT16=Float16Compression(),
-    MEANSTD_16BIT=ScaledFloat16Compression(),
-    QUANTILE_8BIT=Quantile8BitQuantization(),
-    UNIFORM_8BIT=Uniform8BitQuantization(),
-)
-
-for key in runtime_pb2.CompressionType.keys():
-    assert key in BASE_COMPRESSION_TYPES, f"Compression type {key} does not have a registered deserializer."
-    actual_compression_type = BASE_COMPRESSION_TYPES[key].compression_type
-    assert (
-        runtime_pb2.CompressionType.Name(actual_compression_type) == key
-    ), f"Compression strategy for {key} has inconsistent type"
-
-
-def serialize_torch_tensor(
-    tensor: torch.Tensor,
-    compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE,
-    info: Optional[CompressionInfo] = None,
-    allow_inplace: bool = False,
-    **kwargs,
-) -> runtime_pb2.Tensor:
-    """Serialize a given tensor into a protobuf message using the specified compression strategy"""
-    assert tensor.device == torch.device("cpu")
-    compression = BASE_COMPRESSION_TYPES[runtime_pb2.CompressionType.Name(compression_type)]
-    info = info or CompressionInfo.from_tensor(tensor, **kwargs)
-    return compression.compress(tensor, info, allow_inplace)
-
-
-def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
-    """Restore a pytorch tensor from a protobuf message"""
-    compression = BASE_COMPRESSION_TYPES[runtime_pb2.CompressionType.Name(serialized_tensor.compression)]
-    return compression.extract(serialized_tensor).requires_grad_(serialized_tensor.requires_grad)
+from hivemind.compression.serialization import deserialize_torch_tensor, serialize_torch_tensor

+ 8 - 0
hivemind/compression/base.py

@@ -1,4 +1,5 @@
 import dataclasses
+import warnings
 from abc import ABC, abstractmethod
 from enum import Enum, auto
 from typing import Any, Optional
@@ -9,6 +10,10 @@ import torch
 from hivemind.proto import runtime_pb2
 from hivemind.utils.tensor_descr import TensorDescriptor
 
+# While converting read-only NumPy arrays into PyTorch tensors, we don't make extra copies for efficiency
+warnings.filterwarnings("ignore", message="The given NumPy array is not writeable", category=UserWarning)
+
+
 Key = Any
 
 
@@ -65,6 +70,9 @@ class CompressionBase(ABC):
         """Estimate the compression ratio without doing the actual compression; lower ratio = better compression"""
         ...
 
+    def __repr__(self):
+        return f"hivemind.{self.__class__.__name__}()"
+
 
 class NoCompression(CompressionBase):
     """A dummy compression strategy that preserves the original tensor as is."""

+ 2 - 2
hivemind/compression/quantization.py

@@ -48,7 +48,7 @@ class Quantization(CompressionBase, ABC):
 
     @property
     def n_bins(self):
-        return 2 ** self.n_bits
+        return 2**self.n_bits
 
 
 class Uniform8BitQuantization(Quantization):
@@ -94,7 +94,7 @@ def get_chunk_size(num_elements: int, min_chunk_size: int) -> int:
     return min_chunk_size + (leftover_elements - 1) // num_chunks + 1
 
 
-def quantile_qq_approximation(array: np.ndarray, n_quantiles: int, min_chunk_size: int = 10 ** 5) -> np.ndarray:
+def quantile_qq_approximation(array: np.ndarray, n_quantiles: int, min_chunk_size: int = 10**5) -> np.ndarray:
     """Estimate uniform quantiles of data using quantile-of-quantiles. Runs in parallel."""
     if not array.data.c_contiguous and array.data.f_contiguous:
         array = array.T
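For intuition, the quantile-of-quantiles estimate described in the docstring above can be sketched serially as follows (the library version picks chunk sizes via get_chunk_size and evaluates chunks in parallel):

import numpy as np

def quantile_of_quantiles(array: np.ndarray, n_quantiles: int, chunk_size: int = 10**5) -> np.ndarray:
    flat = array.reshape(-1)
    grid = np.linspace(0.0, 1.0, n_quantiles)
    chunks = [flat[i:i + chunk_size] for i in range(0, flat.size, chunk_size)]
    per_chunk = np.stack([np.quantile(chunk, grid) for chunk in chunks])  # quantiles within each chunk
    return np.quantile(per_chunk, grid)  # quantiles of the per-chunk estimates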

+ 43 - 0
hivemind/compression/serialization.py

@@ -0,0 +1,43 @@
+from typing import Dict, Optional
+
+import torch
+
+from hivemind.compression.base import CompressionBase, CompressionInfo, NoCompression
+from hivemind.compression.floating import Float16Compression, ScaledFloat16Compression
+from hivemind.compression.quantization import Quantile8BitQuantization, Uniform8BitQuantization
+from hivemind.proto import runtime_pb2
+
+BASE_COMPRESSION_TYPES: Dict[str, CompressionBase] = dict(
+    NONE=NoCompression(),
+    FLOAT16=Float16Compression(),
+    MEANSTD_16BIT=ScaledFloat16Compression(),
+    QUANTILE_8BIT=Quantile8BitQuantization(),
+    UNIFORM_8BIT=Uniform8BitQuantization(),
+)
+
+for key in runtime_pb2.CompressionType.keys():
+    assert key in BASE_COMPRESSION_TYPES, f"Compression type {key} does not have a registered deserializer"
+    actual_compression_type = BASE_COMPRESSION_TYPES[key].compression_type
+    assert (
+        runtime_pb2.CompressionType.Name(actual_compression_type) == key
+    ), f"Compression strategy for {key} has inconsistent type"
+
+
+def serialize_torch_tensor(
+    tensor: torch.Tensor,
+    compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE,
+    info: Optional[CompressionInfo] = None,
+    allow_inplace: bool = False,
+    **kwargs,
+) -> runtime_pb2.Tensor:
+    """Serialize a given tensor into a protobuf message using the specified compression strategy"""
+    assert tensor.device == torch.device("cpu")
+    compression = BASE_COMPRESSION_TYPES[runtime_pb2.CompressionType.Name(compression_type)]
+    info = info or CompressionInfo.from_tensor(tensor, **kwargs)
+    return compression.compress(tensor, info, allow_inplace)
+
+
+def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
+    """Restore a pytorch tensor from a protobuf message"""
+    compression = BASE_COMPRESSION_TYPES[runtime_pb2.CompressionType.Name(serialized_tensor.compression)]
+    return compression.extract(serialized_tensor).requires_grad_(serialized_tensor.requires_grad)
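A typical round-trip through the two helpers above (the restored values match the originals up to the precision of the chosen compression):

import torch
from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.proto import runtime_pb2

tensor = torch.randn(3, 4)
serialized = serialize_torch_tensor(tensor, runtime_pb2.CompressionType.FLOAT16)  # runtime_pb2.Tensor message
restored = deserialize_torch_tensor(serialized)
assert restored.shape == tensor.shape
assert torch.allclose(restored.float(), tensor, atol=1e-2)  # equal up to float16 rounding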

+ 3 - 322
hivemind/dht/__init__.py

@@ -4,7 +4,7 @@ Hivemind DHT is based on Kademlia [1] with added support for improved bulk store
 
 The code is organized as follows:
 
- * **class DHT (__init__.py)** - high-level class for model training. Runs DHTNode in a background process.
+ * **class DHT (dht.py)** - high-level class for model training. Runs DHTNode in a background process.
  * **class DHTNode (node.py)** - an asyncio implementation of dht server, stores AND gets keys.
  * **class DHTProtocol (protocol.py)** - an RPC protocol to request data from dht nodes.
  * **async def traverse_dht (traverse.py)** - a search algorithm that crawls DHT peers.
@@ -12,327 +12,8 @@ The code is organized as follows:
 - [1] Maymounkov P., Mazieres D. (2002) Kademlia: A Peer-to-Peer Information System Based on the XOR Metric.
 - [2] https://github.com/bmuller/kademlia , Brian, if you're reading this: THANK YOU! you're awesome :)
 """
-from __future__ import annotations
-
-import asyncio
-import multiprocessing as mp
-import os
-from functools import partial
-from typing import Awaitable, Callable, Iterable, List, Optional, Sequence, TypeVar, Union
-
-from multiaddr import Multiaddr
 
+from hivemind.dht.dht import DHT
 from hivemind.dht.node import DEFAULT_NUM_WORKERS, DHTNode
-from hivemind.dht.routing import DHTID, DHTKey, DHTValue, Subkey
+from hivemind.dht.routing import DHTID, DHTExpiration, DHTKey, DHTValue, Subkey
 from hivemind.dht.validation import CompositeValidator, RecordValidatorBase
-from hivemind.p2p import P2P, PeerID
-from hivemind.utils import DHTExpiration, MPFuture, ValueWithExpiration, get_logger, switch_to_uvloop
-
-logger = get_logger(__name__)
-
-ReturnType = TypeVar("ReturnType")
-
-
-class DHT(mp.Process):
-    """
-    A high-level interface to a hivemind DHT that runs a single DHT node in a background process.
-    * hivemind servers periodically announce their experts via declare_experts (dht_handler.py)
-    * trainers find most suitable experts via RemoteMixtureOfExperts (beam_search.py)
-
-    :param initial_peers: multiaddrs of one or more active DHT peers (if you want to join an existing DHT)
-    :param start: if True, automatically starts the background process on creation. Otherwise await manual start
-    :param daemon: if True, the background process is marked as daemon and automatically terminated after main process
-    :param num_workers: declare_experts and get_experts will use up to this many parallel workers
-      (but no more than one per key)
-    :param expiration: experts declared from this node expire after this many seconds (default = 5 minutes)
-    :param record_validators: instances of RecordValidatorBase used for signing and validating stored records.
-      The validators will be combined using the CompositeValidator class. It merges them when possible
-      (according to their `.merge_with()` policies) and orders them according to the `.priority` properties.
-    :param shutdown_timeout: when calling .shutdown, wait for up to this many seconds before terminating
-    :param await_ready: if True, the constructor waits until the DHT process is ready to process incoming requests
-    :param kwargs: any other params will be forwarded to DHTNode and hivemind.p2p.P2P upon creation
-    """
-
-    _node: DHTNode
-
-    def __init__(
-        self,
-        initial_peers: Optional[Sequence[Union[Multiaddr, str]]] = None,
-        *,
-        start: bool,
-        p2p: Optional[P2P] = None,
-        daemon: bool = True,
-        num_workers: int = DEFAULT_NUM_WORKERS,
-        record_validators: Iterable[RecordValidatorBase] = (),
-        shutdown_timeout: float = 3,
-        await_ready: bool = True,
-        **kwargs,
-    ):
-        self._parent_pid = os.getpid()
-        super().__init__()
-
-        if not (
-            initial_peers is None
-            or (
-                isinstance(initial_peers, Sequence)
-                and all(isinstance(item, (Multiaddr, str)) for item in initial_peers)
-            )
-        ):
-            raise TypeError("initial_peers should be of type Optional[Sequence[Union[Multiaddr, str]]]")
-        self.initial_peers = initial_peers
-        self.kwargs = kwargs
-        self.num_workers = num_workers
-
-        self._record_validator = CompositeValidator(record_validators)
-        self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)
-        self.shutdown_timeout = shutdown_timeout
-        self._ready = MPFuture()
-        self.daemon = daemon
-
-        # These values will be fetched from the child process when requested
-        self._peer_id = None
-        self._client_mode = None
-        self._p2p_replica = None
-
-        self._daemon_listen_maddr = p2p.daemon_listen_maddr if p2p is not None else None
-
-        if start:
-            self.run_in_background(await_ready=await_ready)
-
-    def run(self) -> None:
-        """Serve DHT forever. This function will not return until DHT node is shut down"""
-
-        loop = switch_to_uvloop()
-        pipe_semaphore = asyncio.Semaphore(value=0)
-        loop.add_reader(self._inner_pipe.fileno(), pipe_semaphore.release)
-
-        async def _run():
-            try:
-                if self._daemon_listen_maddr is not None:
-                    replicated_p2p = await P2P.replicate(self._daemon_listen_maddr)
-                else:
-                    replicated_p2p = None
-
-                self._node = await DHTNode.create(
-                    initial_peers=self.initial_peers,
-                    num_workers=self.num_workers,
-                    record_validator=self._record_validator,
-                    p2p=replicated_p2p,
-                    **self.kwargs,
-                )
-            except Exception as e:
-                # Loglevel is DEBUG since normally the exception is propagated to the caller
-                logger.debug(e, exc_info=True)
-                self._ready.set_exception(e)
-                return
-            self._ready.set_result(None)
-
-            while True:
-                try:
-                    await asyncio.wait_for(pipe_semaphore.acquire(), timeout=self._node.protocol.wait_timeout)
-                except asyncio.TimeoutError:
-                    pass
-                if not self._inner_pipe.poll():
-                    continue
-                try:
-                    method, args, kwargs = self._inner_pipe.recv()
-                except (OSError, ConnectionError, RuntimeError) as e:
-                    logger.exception(e)
-                    await asyncio.sleep(self._node.protocol.wait_timeout)
-                    continue
-                task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
-                if method == "_shutdown":
-                    await task
-                    break
-
-        loop.run_until_complete(_run())
-
-    def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None:
-        """
-        Starts DHT in a background process. if await_ready, this method will wait until background dht
-        is ready to process incoming requests or for :timeout: seconds max.
-        """
-        self.start()
-        if await_ready:
-            self.wait_until_ready(timeout)
-
-    def wait_until_ready(self, timeout: Optional[float] = None) -> None:
-        self._ready.result(timeout=timeout)
-
-    def shutdown(self) -> None:
-        """Shut down a running dht process"""
-        if self.is_alive():
-            self._outer_pipe.send(("_shutdown", [], {}))
-            self.join(self.shutdown_timeout)
-            if self.is_alive():
-                logger.warning("DHT did not shut down within the grace period; terminating it the hard way.")
-                self.terminate()
-
-    async def _shutdown(self):
-        await self._node.shutdown()
-
-    def get(
-        self, key: DHTKey, latest: bool = False, return_future: bool = False, **kwargs
-    ) -> Union[Optional[ValueWithExpiration[DHTValue]], MPFuture]:
-        """
-        Search for a key across DHT and return either first or latest entry (if found).
-        :param key: same key as in node.store(...)
-        :param latest: if True, finds the latest value, otherwise finds any non-expired value (which is much faster)
-        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
-        :param kwargs: parameters forwarded to DHTNode.get_many_by_id
-        :returns: (value, expiration time); if value was not found, returns None
-        """
-        future = MPFuture()
-        self._outer_pipe.send(("_get", [], dict(key=key, latest=latest, future=future, **kwargs)))
-        return future if return_future else future.result()
-
-    async def _get(self, key: DHTKey, latest: bool, future: MPFuture, **kwargs):
-        try:
-            result = await self._node.get(key, latest=latest, **kwargs)
-            if not future.done():
-                future.set_result(result)
-        except BaseException as e:
-            if not future.done():
-                future.set_exception(e)
-            raise
-
-    def store(
-        self,
-        key: DHTKey,
-        value: DHTValue,
-        expiration_time: DHTExpiration,
-        subkey: Optional[Subkey] = None,
-        return_future: bool = False,
-        **kwargs,
-    ) -> Union[bool, MPFuture]:
-        """
-        Find num_replicas best nodes to store (key, value) and store it there until expiration time.
-
-        :param key: msgpack-serializable key to be associated with value until expiration.
-        :param value: msgpack-serializable value to be stored under a given key until expiration.
-        :param expiration_time: absolute time when the entry should expire, based on hivemind.get_dht_time()
-        :param subkey: if specified, add a value under that subkey instead of overwriting key (see DHTNode.store_many)
-        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
-        :returns: True if store succeeds, False if it fails (due to no response or newer value)
-        """
-        future = MPFuture()
-        self._outer_pipe.send(
-            (
-                "_store",
-                [],
-                dict(key=key, value=value, expiration_time=expiration_time, subkey=subkey, future=future, **kwargs),
-            )
-        )
-        return future if return_future else future.result()
-
-    async def _store(
-        self,
-        key: DHTKey,
-        value: DHTValue,
-        expiration_time: DHTExpiration,
-        subkey: Optional[Subkey],
-        future: MPFuture,
-        **kwargs,
-    ):
-        try:
-            result = await self._node.store(key, value, expiration_time, subkey=subkey, **kwargs)
-            if not future.done():
-                future.set_result(result)
-        except BaseException as e:
-            if not future.done():
-                future.set_exception(e)
-            raise
-
-    def run_coroutine(
-        self, coro: Callable[[DHT, DHTNode], Awaitable[ReturnType]], return_future: bool = False
-    ) -> Union[ReturnType, MPFuture[ReturnType]]:
-        """
-        Execute an asynchronous function on a DHT participant and return results. This is meant as an interface
-         for running custom functions DHT for special cases (e.g. declare experts, beam search)
-
-        :param coro: async function to be executed. Receives 2 arguments: this DHT daemon and a running DHTNode
-        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
-        :returns: coroutine outputs or MPFuture for these outputs
-        :note: the coroutine will be executed inside the DHT process. As such, any changes to global variables or
-          DHT fields made by this coroutine will not be accessible from the host process.
-        :note: all time-consuming operations in coro should be asynchronous (e.g. asyncio.sleep instead of time.sleep)
-          or use asyncio.get_event_loop().run_in_executor(...) to prevent coroutine from blocking background DHT tasks
-        :note: when run_coroutine is called with wait=False, MPFuture can be cancelled to interrupt the task.
-        """
-        future = MPFuture()
-        self._outer_pipe.send(("_run_coroutine", [], dict(coro=coro, future=future)))
-        return future if return_future else future.result()
-
-    async def _run_coroutine(
-        self, coro: Callable[[DHT, DHTNode], Awaitable[ReturnType]], future: MPFuture[ReturnType]
-    ):
-        try:
-            future.set_result(await coro(self, self._node))
-        except BaseException as e:
-            logger.exception("Caught an exception when running a coroutine:")
-            future.set_exception(e)
-
-    def add_validators(self, record_validators: Iterable[RecordValidatorBase]) -> None:
-        if not self._ready.done():
-            raise RuntimeError(
-                "Can't append new validators before the DHT process has started. "
-                "Consider adding them to the initial list via DHT.__init__(record_validators=...)"
-            )
-
-        self.run_coroutine(partial(DHT._add_validators, record_validators=record_validators))
-
-    @staticmethod
-    async def _add_validators(_dht: DHT, node: DHTNode, record_validators: Iterable[RecordValidatorBase]) -> None:
-        node.protocol.record_validator.extend(record_validators)
-
-    @property
-    def peer_id(self) -> PeerID:
-        if self._peer_id is None:
-            self._peer_id = self.run_coroutine(DHT._get_peer_id)
-        return self._peer_id
-
-    @staticmethod
-    async def _get_peer_id(_dht: DHT, node: DHTNode) -> PeerID:
-        return node.peer_id
-
-    @property
-    def client_mode(self) -> bool:
-        if self._client_mode is None:
-            self._client_mode = self.run_coroutine(DHT._get_client_mode)
-        return self._client_mode
-
-    @staticmethod
-    async def _get_client_mode(_dht: DHT, node: DHTNode) -> bool:
-        return node.protocol.client_mode
-
-    def get_visible_maddrs(self, latest: bool = False) -> List[Multiaddr]:
-        """
-        Get multiaddrs of the current DHT node that should be accessible by other peers.
-
-        :param latest: ask the P2P daemon to refresh the visible multiaddrs
-        """
-
-        return self.run_coroutine(partial(DHT._get_visible_maddrs, latest=latest))
-
-    @staticmethod
-    async def _get_visible_maddrs(_dht: DHT, node: DHTNode, latest: bool = False) -> List[Multiaddr]:
-        return await node.get_visible_maddrs(latest=latest)
-
-    async def replicate_p2p(self) -> P2P:
-        """
-        Get a replica of a P2P instance used in the DHT process internally.
-        The replica uses the same P2P daemon as the DHT and only works while DHT is alive.
-        """
-
-        if self._p2p_replica is None:
-            daemon_listen_maddr = self.run_coroutine(DHT._get_p2p_daemon_listen_maddr)
-            self._p2p_replica = await P2P.replicate(daemon_listen_maddr)
-        return self._p2p_replica
-
-    @staticmethod
-    async def _get_p2p_daemon_listen_maddr(_dht: DHT, node: DHTNode) -> Multiaddr:
-        return node.p2p.daemon_listen_maddr
-
-    def __del__(self):
-        if self._parent_pid == os.getpid() and self.is_alive():
-            self.shutdown()
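With the implementation moved to dht.py, this module keeps only the re-exports above, so existing import paths continue to work unchanged:

from hivemind.dht import DHT, DHTNode, DHTID, DHTExpiration  # same public names as before the move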

+ 324 - 0
hivemind/dht/dht.py

@@ -0,0 +1,324 @@
+from __future__ import annotations
+
+import asyncio
+import multiprocessing as mp
+import os
+from functools import partial
+from typing import Awaitable, Callable, Iterable, List, Optional, Sequence, TypeVar, Union
+
+from multiaddr import Multiaddr
+
+from hivemind.dht.node import DEFAULT_NUM_WORKERS, DHTNode
+from hivemind.dht.routing import DHTKey, DHTValue, Subkey
+from hivemind.dht.validation import CompositeValidator, RecordValidatorBase
+from hivemind.p2p import P2P, PeerID
+from hivemind.utils import MPFuture, get_logger, switch_to_uvloop
+from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration
+
+logger = get_logger(__name__)
+ReturnType = TypeVar("ReturnType")
+
+
+class DHT(mp.Process):
+    """
+    A high-level interface to a hivemind DHT that runs a single DHT node in a background process.
+    * hivemind servers periodically announce their experts via declare_experts (dht_handler.py)
+    * trainers find most suitable experts via RemoteMixtureOfExperts (beam_search.py)
+
+    :param initial_peers: multiaddrs of one or more active DHT peers (if you want to join an existing DHT)
+    :param start: if True, automatically starts the background process on creation. Otherwise await manual start
+    :param daemon: if True, the background process is marked as daemon and automatically terminated after main process
+    :param num_workers: declare_experts and get_experts will use up to this many parallel workers
+      (but no more than one per key)
+    :param expiration: experts declared from this node expire after this many seconds (default = 5 minutes)
+    :param record_validators: instances of RecordValidatorBase used for signing and validating stored records.
+      The validators will be combined using the CompositeValidator class. It merges them when possible
+      (according to their `.merge_with()` policies) and orders them according to the `.priority` properties.
+    :param shutdown_timeout: when calling .shutdown, wait for up to this many seconds before terminating
+    :param await_ready: if True, the constructor waits until the DHT process is ready to process incoming requests
+    :param kwargs: any other params will be forwarded to DHTNode and hivemind.p2p.P2P upon creation
+    """
+
+    _node: DHTNode
+
+    def __init__(
+        self,
+        initial_peers: Optional[Sequence[Union[Multiaddr, str]]] = None,
+        *,
+        start: bool,
+        p2p: Optional[P2P] = None,
+        daemon: bool = True,
+        num_workers: int = DEFAULT_NUM_WORKERS,
+        record_validators: Iterable[RecordValidatorBase] = (),
+        shutdown_timeout: float = 3,
+        await_ready: bool = True,
+        **kwargs,
+    ):
+        self._parent_pid = os.getpid()
+        super().__init__()
+
+        if not (
+            initial_peers is None
+            or (
+                isinstance(initial_peers, Sequence)
+                and all(isinstance(item, (Multiaddr, str)) for item in initial_peers)
+            )
+        ):
+            raise TypeError("initial_peers should be of type Optional[Sequence[Union[Multiaddr, str]]]")
+        self.initial_peers = initial_peers
+        self.kwargs = kwargs
+        self.num_workers = num_workers
+
+        self._record_validator = CompositeValidator(record_validators)
+        self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)
+        self.shutdown_timeout = shutdown_timeout
+        self._ready = MPFuture()
+        self.daemon = daemon
+
+        # These values will be fetched from the child process when requested
+        self._peer_id = None
+        self._client_mode = None
+        self._p2p_replica = None
+
+        self._daemon_listen_maddr = p2p.daemon_listen_maddr if p2p is not None else None
+
+        if start:
+            self.run_in_background(await_ready=await_ready)
+
+    def run(self) -> None:
+        """Serve DHT forever. This function will not return until DHT node is shut down"""
+
+        loop = switch_to_uvloop()
+        pipe_semaphore = asyncio.Semaphore(value=0)
+        loop.add_reader(self._inner_pipe.fileno(), pipe_semaphore.release)
+
+        async def _run():
+            try:
+                if self._daemon_listen_maddr is not None:
+                    replicated_p2p = await P2P.replicate(self._daemon_listen_maddr)
+                else:
+                    replicated_p2p = None
+
+                self._node = await DHTNode.create(
+                    initial_peers=self.initial_peers,
+                    num_workers=self.num_workers,
+                    record_validator=self._record_validator,
+                    p2p=replicated_p2p,
+                    **self.kwargs,
+                )
+            except Exception as e:
+                # Loglevel is DEBUG since normally the exception is propagated to the caller
+                logger.debug(e, exc_info=True)
+                self._ready.set_exception(e)
+                return
+            self._ready.set_result(None)
+
+            while True:
+                try:
+                    await asyncio.wait_for(pipe_semaphore.acquire(), timeout=self._node.protocol.wait_timeout)
+                except asyncio.TimeoutError:
+                    pass
+                if not self._inner_pipe.poll():
+                    continue
+                try:
+                    method, args, kwargs = self._inner_pipe.recv()
+                except (OSError, ConnectionError, RuntimeError) as e:
+                    logger.exception(e)
+                    await asyncio.sleep(self._node.protocol.wait_timeout)
+                    continue
+                task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
+                if method == "_shutdown":
+                    await task
+                    break
+
+        loop.run_until_complete(_run())
+
+    def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None:
+        """
+        Starts the DHT in a background process. If await_ready is True, this method will wait until the background DHT
+        is ready to process incoming requests, or for at most :timeout: seconds.
+        """
+        self.start()
+        if await_ready:
+            self.wait_until_ready(timeout)
+
+    def wait_until_ready(self, timeout: Optional[float] = None) -> None:
+        self._ready.result(timeout=timeout)
+
+    def shutdown(self) -> None:
+        """Shut down a running dht process"""
+        if self.is_alive():
+            self._outer_pipe.send(("_shutdown", [], {}))
+            self.join(self.shutdown_timeout)
+            if self.is_alive():
+                logger.warning("DHT did not shut down within the grace period; terminating it the hard way")
+                self.terminate()
+
+    async def _shutdown(self):
+        await self._node.shutdown()
+
+    def get(
+        self, key: DHTKey, latest: bool = False, return_future: bool = False, **kwargs
+    ) -> Union[Optional[ValueWithExpiration[DHTValue]], MPFuture]:
+        """
+        Search for a key across DHT and return either first or latest entry (if found).
+        :param key: same key as in node.store(...)
+        :param latest: if True, finds the latest value, otherwise finds any non-expired value (which is much faster)
+        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
+        :param kwargs: parameters forwarded to DHTNode.get_many_by_id
+        :returns: (value, expiration time); if value was not found, returns None
+        """
+        future = MPFuture()
+        self._outer_pipe.send(("_get", [], dict(key=key, latest=latest, future=future, **kwargs)))
+        return future if return_future else future.result()
+
+    async def _get(self, key: DHTKey, latest: bool, future: MPFuture, **kwargs):
+        try:
+            result = await self._node.get(key, latest=latest, **kwargs)
+            if not future.done():
+                future.set_result(result)
+        except BaseException as e:
+            if not future.done():
+                future.set_exception(e)
+            raise
+
+    def store(
+        self,
+        key: DHTKey,
+        value: DHTValue,
+        expiration_time: DHTExpiration,
+        subkey: Optional[Subkey] = None,
+        return_future: bool = False,
+        **kwargs,
+    ) -> Union[bool, MPFuture]:
+        """
+        Find num_replicas best nodes to store (key, value) and store it there until expiration time.
+
+        :param key: msgpack-serializable key to be associated with value until expiration.
+        :param value: msgpack-serializable value to be stored under a given key until expiration.
+        :param expiration_time: absolute time when the entry should expire, based on hivemind.get_dht_time()
+        :param subkey: if specified, add a value under that subkey instead of overwriting key (see DHTNode.store_many)
+        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
+        :returns: True if store succeeds, False if it fails (due to no response or newer value)
+        """
+        future = MPFuture()
+        self._outer_pipe.send(
+            (
+                "_store",
+                [],
+                dict(key=key, value=value, expiration_time=expiration_time, subkey=subkey, future=future, **kwargs),
+            )
+        )
+        return future if return_future else future.result()
+
+    async def _store(
+        self,
+        key: DHTKey,
+        value: DHTValue,
+        expiration_time: DHTExpiration,
+        subkey: Optional[Subkey],
+        future: MPFuture,
+        **kwargs,
+    ):
+        try:
+            result = await self._node.store(key, value, expiration_time, subkey=subkey, **kwargs)
+            if not future.done():
+                future.set_result(result)
+        except BaseException as e:
+            if not future.done():
+                future.set_exception(e)
+            raise
+
+    def run_coroutine(
+        self, coro: Callable[[DHT, DHTNode], Awaitable[ReturnType]], return_future: bool = False
+    ) -> Union[ReturnType, MPFuture[ReturnType]]:
+        """
+        Execute an asynchronous function on a DHT participant and return results. This is meant as an interface
+         for running custom functions in the DHT process for special cases (e.g. declare experts, beam search)
+
+        :param coro: async function to be executed. Receives 2 arguments: this DHT daemon and a running DHTNode
+        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
+        :returns: coroutine outputs or MPFuture for these outputs
+        :note: the coroutine will be executed inside the DHT process. As such, any changes to global variables or
+          DHT fields made by this coroutine will not be accessible from the host process.
+        :note: all time-consuming operations in coro should be asynchronous (e.g. asyncio.sleep instead of time.sleep)
+          or use asyncio.get_event_loop().run_in_executor(...) to prevent coroutine from blocking background DHT tasks
+        :note: when run_coroutine is called with return_future=True, the returned MPFuture can be cancelled to interrupt the task.
+        """
+        future = MPFuture()
+        self._outer_pipe.send(("_run_coroutine", [], dict(coro=coro, future=future)))
+        return future if return_future else future.result()
+
+    async def _run_coroutine(
+        self, coro: Callable[[DHT, DHTNode], Awaitable[ReturnType]], future: MPFuture[ReturnType]
+    ):
+        try:
+            future.set_result(await coro(self, self._node))
+        except BaseException as e:
+            logger.exception("Caught an exception when running a coroutine:")
+            future.set_exception(e)
+
+    def add_validators(self, record_validators: Iterable[RecordValidatorBase]) -> None:
+        if not self._ready.done():
+            raise RuntimeError(
+                "Can't append new validators before the DHT process has started. "
+                "Consider adding them to the initial list via DHT.__init__(record_validators=...)"
+            )
+
+        self.run_coroutine(partial(DHT._add_validators, record_validators=record_validators))
+
+    @staticmethod
+    async def _add_validators(_dht: DHT, node: DHTNode, record_validators: Iterable[RecordValidatorBase]) -> None:
+        node.protocol.record_validator.extend(record_validators)
+
+    @property
+    def peer_id(self) -> PeerID:
+        if self._peer_id is None:
+            self._peer_id = self.run_coroutine(DHT._get_peer_id)
+        return self._peer_id
+
+    @staticmethod
+    async def _get_peer_id(_dht: DHT, node: DHTNode) -> PeerID:
+        return node.peer_id
+
+    @property
+    def client_mode(self) -> bool:
+        if self._client_mode is None:
+            self._client_mode = self.run_coroutine(DHT._get_client_mode)
+        return self._client_mode
+
+    @staticmethod
+    async def _get_client_mode(_dht: DHT, node: DHTNode) -> bool:
+        return node.protocol.client_mode
+
+    def get_visible_maddrs(self, latest: bool = False) -> List[Multiaddr]:
+        """
+        Get multiaddrs of the current DHT node that should be accessible by other peers.
+
+        :param latest: ask the P2P daemon to refresh the visible multiaddrs
+        """
+
+        return self.run_coroutine(partial(DHT._get_visible_maddrs, latest=latest))
+
+    @staticmethod
+    async def _get_visible_maddrs(_dht: DHT, node: DHTNode, latest: bool = False) -> List[Multiaddr]:
+        return await node.get_visible_maddrs(latest=latest)
+
+    async def replicate_p2p(self) -> P2P:
+        """
+        Get a replica of a P2P instance used in the DHT process internally.
+        The replica uses the same P2P daemon as the DHT and only works while DHT is alive.
+        """
+
+        if self._p2p_replica is None:
+            daemon_listen_maddr = self.run_coroutine(DHT._get_p2p_daemon_listen_maddr)
+            self._p2p_replica = await P2P.replicate(daemon_listen_maddr)
+        return self._p2p_replica
+
+    @staticmethod
+    async def _get_p2p_daemon_listen_maddr(_dht: DHT, node: DHTNode) -> Multiaddr:
+        return node.p2p.daemon_listen_maddr
+
+    def __del__(self):
+        if self._parent_pid == os.getpid() and self.is_alive():
+            self.shutdown()
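A minimal usage sketch of the class above: a single local node with an illustrative key and a 60-second expiration (key and value names are made up):

from hivemind.dht import DHT
from hivemind.utils import get_dht_time

dht = DHT(start=True)                       # spawns a DHTNode in a background process
dht.store("my_key", "my_value", expiration_time=get_dht_time() + 60)
found = dht.get("my_key", latest=True)      # ValueWithExpiration or None
print(found.value if found is not None else "not found")
print(dht.get_visible_maddrs())             # multiaddrs that other peers can use as initial_peers
dht.shutdown()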

+ 1 - 1
hivemind/dht/node.py

@@ -717,7 +717,7 @@ class DHTNode:
         """Add key to a refresh queue, refresh at :refresh_time: or later"""
         if self.cache_refresh_task is None or self.cache_refresh_task.done() or self.cache_refresh_task.cancelled():
             self.cache_refresh_task = asyncio.create_task(self._refresh_stale_cache_entries())
-            logger.debug("Spawned cache refresh task.")
+            logger.debug("Spawned cache refresh task")
         earliest_key, earliest_item = self.cache_refresh_queue.top()
         if earliest_item is None or refresh_time < earliest_item.expiration_time:
             self.cache_refresh_evt.set()  # if the new element is now the earliest, notify the cache queue

+ 2 - 2
hivemind/dht/routing.py

@@ -10,7 +10,7 @@ from itertools import chain
 from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
 
 from hivemind.p2p import PeerID
-from hivemind.utils import MSGPackSerializer, get_dht_time
+from hivemind.utils import DHTExpiration, MSGPackSerializer, get_dht_time
 
 DHTKey = Subkey = DHTValue = Any
 BinaryDHTID = BinaryDHTValue = bytes
@@ -217,7 +217,7 @@ class KBucket:
 
     def __delitem__(self, node_id: DHTID):
         if not (node_id in self.nodes_to_peer_id or node_id in self.replacement_nodes):
-            raise KeyError(f"KBucket does not contain node id={node_id}.")
+            raise KeyError(f"KBucket does not contain node id={node_id}")
 
         if node_id in self.replacement_nodes:
             del self.replacement_nodes[node_id]

+ 1 - 1
hivemind/dht/schema.py

@@ -18,7 +18,7 @@ class SchemaValidator(RecordValidatorBase):
     This allows to enforce types, min/max values, require a subkey to contain a public key, etc.
     """
 
-    def __init__(self, schema: pydantic.BaseModel, *, allow_extra_keys: bool = True, prefix: Optional[str] = None):
+    def __init__(self, schema: Type[pydantic.BaseModel], allow_extra_keys: bool = True, prefix: Optional[str] = None):
         """
         :param schema: The Pydantic model (a subclass of pydantic.BaseModel).
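Passing the model class (rather than an instance) matches the new Type[pydantic.BaseModel] annotation above. A hedged sketch of wiring such a validator into a DHT node (the schema fields and prefix below are made up for illustration):

import pydantic
from hivemind.dht import DHT
from hivemind.dht.schema import SchemaValidator

class ExperimentMetrics(pydantic.BaseModel):  # hypothetical record schema
    step: pydantic.StrictInt

validator = SchemaValidator(ExperimentMetrics, allow_extra_keys=True, prefix="my_run")
dht = DHT(start=True, record_validators=[validator])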
 

+ 2 - 2
hivemind/hivemind_cli/run_server.py

@@ -4,7 +4,7 @@ from pathlib import Path
 import configargparse
 import torch
 
-from hivemind.moe.server import Server
+from hivemind.moe import Server
 from hivemind.moe.server.layers import schedule_name_to_scheduler
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.limits import increase_file_limit
@@ -28,7 +28,7 @@ def main():
                         help="specify the exact list of expert uids to create. Use either this or num_experts"
                              " and expert_pattern, not both")
     parser.add_argument('--expert_cls', type=str, default='ffn', required=False,
-                        help="expert type from test_utils.layers, e.g. 'ffn', 'transformer', 'det_dropout' or 'nop'.")
+                        help="expert type from test_utils.layers, e.g. 'ffn', 'transformer', 'det_dropout' or 'nop'")
     parser.add_argument('--hidden_dim', type=int, default=1024, required=False, help='main dimension for expert_cls')
 
     parser.add_argument('--num_handlers', type=int, default=None, required=False,

+ 1 - 0
hivemind/moe/__init__.py

@@ -3,6 +3,7 @@ from hivemind.moe.server import (
     ConnectionHandler,
     ExpertBackend,
     Server,
+    background_server,
     declare_experts,
     get_experts,
     register_expert_class,

+ 5 - 9
hivemind/moe/client/expert.py

@@ -1,8 +1,8 @@
-import pickle
 from concurrent.futures import Future
 from queue import Queue
 from threading import Thread
 from typing import Any, Awaitable, Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -12,7 +12,7 @@ import hivemind
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.p2p import P2P, PeerInfo, StubBase
 from hivemind.proto import runtime_pb2
-from hivemind.utils import asingle, nested_compare, nested_flatten, nested_pack, switch_to_uvloop
+from hivemind.utils import MSGPackSerializer, asingle, nested_compare, nested_flatten, nested_pack, switch_to_uvloop
 
 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert
 
@@ -68,7 +68,7 @@ class RemoteExpert(nn.Module):
     def info(self):
         if self._info is None:
             outputs = _RemoteModuleCall.run_coroutine(self.stub.rpc_info(runtime_pb2.ExpertUID(uid=self.uid)))
-            self._info = pickle.loads(outputs.serialized_info)
+            self._info = MSGPackSerializer.loads(outputs.serialized_info)
         return self._info
 
     def extra_repr(self):
@@ -134,9 +134,7 @@ class _RemoteModuleCall(torch.autograd.Function):
         ]
 
         outputs = cls.run_coroutine(
-            asingle(
-                stub.rpc_forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors)),
-            ),
+            stub.rpc_forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors)),
         )
 
         deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
@@ -155,9 +153,7 @@ class _RemoteModuleCall(torch.autograd.Function):
         ]
 
         grad_inputs = cls.run_coroutine(
-            asingle(
-                ctx.stub.rpc_forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors)),
-            ),
+            ctx.stub.rpc_forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors)),
         )
 
         deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
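Expert metadata is now exchanged via MSGPackSerializer rather than pickle, so deserializing a response cannot execute arbitrary code. A round-trip with an illustrative payload (field names made up):

from hivemind.utils import MSGPackSerializer

payload = {"uid": "expert.0", "hidden_dim": 1024}  # illustrative metadata dict
blob = MSGPackSerializer.dumps(payload)
assert MSGPackSerializer.loads(blob) == payload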

+ 6 - 6
hivemind/moe/client/moe.py

@@ -9,8 +9,8 @@ import torch
 import torch.nn as nn
 from torch.autograd.function import once_differentiable
 
-import hivemind
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.dht import DHT
 from hivemind.moe.client.beam_search import MoEBeamSearcher
 from hivemind.moe.client.expert import DUMMY, RemoteExpert, _get_expert_stub
 from hivemind.moe.server.expert_uid import UID_DELIMITER
@@ -48,7 +48,7 @@ class RemoteMixtureOfExperts(nn.Module):
         *,
         in_features,
         grid_size: Tuple[int, ...],
-        dht: hivemind.DHT,
+        dht: DHT,
         uid_prefix: str,
         k_best: int,
         k_min: int = 1,
@@ -238,14 +238,14 @@ class _RemoteCallMany(torch.autograd.Function):
             pending_tasks, num_samples, k_min, forward_timeout, timeout_after_k_min, detect_anomalies
         )
         if len(responded_inds) < k_min:
-            raise TimeoutError(f"Forward pass: less than {k_min} responded within timeout.")
+            raise TimeoutError(f"Forward pass: less than {k_min} responded within timeout")
 
         if not isinstance(info["outputs_schema"], tuple):
             outputs_schema = (info["outputs_schema"],)
         else:
             outputs_schema = info["outputs_schema"]
         outputs = nested_map(
-            lambda descriptor: descriptor.make_empty(num_samples, max_experts, device=flat_inputs[0].device).zero_(),
+            lambda descriptor: descriptor.make_zeros(num_samples, max_experts, device=flat_inputs[0].device),
             outputs_schema,
         )
 
@@ -330,7 +330,7 @@ class _RemoteCallMany(torch.autograd.Function):
             pending_tasks, num_samples, backward_k_min, backward_timeout, timeout_after_k_min, detect_anomalies
         )
         if len(survivor_inds) < backward_k_min:
-            raise TimeoutError(f"Backward pass: less than {backward_k_min} experts responded within timeout.")
+            raise TimeoutError(f"Backward pass: less than {backward_k_min} experts responded within timeout")
 
         # assemble responses
         batch_inds, expert_inds = map(
@@ -341,7 +341,7 @@ class _RemoteCallMany(torch.autograd.Function):
         # torch tensors, i-th tensor is of shape [num_backward_survivors, *flat_inputs_cpu[i].shape]
 
         grad_inputs = nested_map(
-            lambda descr: descr.make_empty(num_samples, device=flat_grad_outputs[0].device).zero_(),
+            lambda descr: descr.make_zeros(num_samples, device=flat_grad_outputs[0].device),
             list(nested_flatten(info["forward_schema"])),
         )
 

+ 1 - 1
hivemind/moe/client/switch_moe.py

@@ -150,7 +150,7 @@ class RemoteSwitchMixtureOfExperts(RemoteMixtureOfExperts):
         # for each grid dimension, sum across all indices for a dimension. Optimizing this leads to uniform allocation
         balancing_loss = torch.stack(
             [
-                torch.mean(dim_softmax.mean(0) * dim_utilization) * (dim_size ** 2)
+                torch.mean(dim_softmax.mean(0) * dim_utilization) * dim_size**2
                 for dim_softmax, dim_utilization, dim_size in zip(
                     grid_softmax, self.grid_utilization, self.beam_search.grid_size
                 )

+ 4 - 348
hivemind/moe/server/__init__.py

@@ -1,349 +1,5 @@
-from __future__ import annotations
-
-import multiprocessing as mp
-import multiprocessing.synchronize
-import threading
-from contextlib import contextmanager
-from functools import partial
-from pathlib import Path
-from typing import Dict, List, Optional, Tuple
-
-import torch
-from multiaddr import Multiaddr
-
-import hivemind
-from hivemind.dht import DHT
-from hivemind.moe.server.checkpoints import CheckpointSaver, is_directory, load_experts
-from hivemind.moe.server.connection_handler import ConnectionHandler
-from hivemind.moe.server.dht_handler import DHTHandlerThread, declare_experts, get_experts
+from hivemind.moe.server.dht_handler import declare_experts, get_experts
 from hivemind.moe.server.expert_backend import ExpertBackend
-from hivemind.moe.server.expert_uid import UID_DELIMITER, generate_uids_from_pattern
-from hivemind.moe.server.layers import (
-    add_custom_models_from_file,
-    name_to_block,
-    name_to_input,
-    register_expert_class,
-    schedule_name_to_scheduler,
-)
-from hivemind.moe.server.runtime import Runtime
-from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.utils import BatchTensorDescriptor, Endpoint, get_free_port, get_logger, get_port, replace_port
-
-logger = get_logger(__name__)
-
-
-class Server(threading.Thread):
-    """
-    Server allows you to host "experts" - pytorch sub-networks used by Decentralized Mixture of Experts.
-    After creation, a server should be started: see Server.run or Server.run_in_background.
-
-    A working server does 3 things:
-     - processes incoming forward/backward requests via Runtime (created by the server)
-     - publishes updates to expert status every :update_period: seconds
-     - follows orders from HivemindController - if it exists
-
-    :type dht: DHT or None. Server with dht=None will NOT be visible from DHT,
-     but it will still support accessing experts directly with RemoteExpert(uid=UID, endpoint="IPADDR:PORT").
-    :param expert_backends: dict{expert uid (str) : ExpertBackend} for all expert hosted by this server.
-    :param listen_on: server's dht address that determines how it can be accessed. Address and (optional) port
-    :param num_connection_handlers: maximum number of simultaneous requests. Please note that the default value of 1
-        if too small for normal functioning, we recommend 4 handlers per expert backend.
-    :param update_period: how often will server attempt to publish its state (i.e. experts) to the DHT;
-        if dht is None, this parameter is ignored.
-    :param start: if True, the server will immediately start as a background thread and returns control after server
-        is ready (see .ready below)
-    """
-
-    def __init__(
-        self,
-        dht: Optional[DHT],
-        expert_backends: Dict[str, ExpertBackend],
-        num_connection_handlers: int = 1,
-        update_period: int = 30,
-        start=False,
-        checkpoint_dir=None,
-        **kwargs,
-    ):
-        super().__init__()
-        self.dht, self.experts, self.update_period = dht, expert_backends, update_period
-
-        self.conn_handlers = [ConnectionHandler(dht, self.experts) for _ in range(1)]
-        if checkpoint_dir is not None:
-            self.checkpoint_saver = CheckpointSaver(expert_backends, checkpoint_dir, update_period)
-        else:
-            self.checkpoint_saver = None
-        self.runtime = Runtime(self.experts, **kwargs)
-
-        if self.dht and self.experts:
-            self.dht_handler_thread = DHTHandlerThread(
-                experts=self.experts,
-                dht=self.dht,
-                peer_id=self.dht.peer_id,
-                update_period=self.update_period,
-                daemon=True,
-            )
-
-        if start:
-            self.run_in_background(await_ready=True)
-
-    @classmethod
-    def create(
-        cls,
-        num_experts: int = None,
-        expert_uids: str = None,
-        expert_pattern: str = None,
-        expert_cls="ffn",
-        hidden_dim=1024,
-        optim_cls=torch.optim.Adam,
-        scheduler: str = "none",
-        num_warmup_steps=None,
-        num_total_steps=None,
-        clip_grad_norm=None,
-        num_handlers=None,
-        min_batch_size=1,
-        max_batch_size=4096,
-        device=None,
-        no_dht=False,
-        initial_peers=(),
-        checkpoint_dir: Optional[Path] = None,
-        compression=CompressionType.NONE,
-        stats_report_interval: Optional[int] = None,
-        custom_module_path=None,
-        *,
-        start: bool,
-    ) -> Server:
-        """
-        Instantiate a server with several identical experts. See argparse comments below for details
-        :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
-        :param num_experts: run this many identical experts
-        :param expert_pattern: a string pattern or a list of expert uids,  example: myprefix.[0:32].[0:256]\
-           means "sample random experts between myprefix.0.0 and myprefix.255.255;
-        :param expert_uids: spawn experts with these exact uids, overrides num_experts and expert_pattern
-        :param expert_cls: expert type from hivemind.moe.server.layers, e.g. 'ffn' or 'transformer';
-        :param hidden_dim: main dimension for expert_cls
-        :param num_handlers: server will use this many parallel processes to handle incoming requests
-        :param min_batch_size: total num examples in the same batch will be greater than this value
-        :param max_batch_size: total num examples in the same batch will not exceed this value
-        :param device: all experts will use this device in torch notation; default: cuda if available else cpu
-
-        :param optim_cls: uses this optimizer to train all experts
-        :param scheduler: if not `none`, the name of the expert LR scheduler
-        :param num_warmup_steps: the number of warmup steps for LR schedule
-        :param num_total_steps: the total number of steps for LR schedule
-        :param clip_grad_norm: maximum gradient norm used for clipping
-
-        :param no_dht: if specified, the server will not be attached to a dht
-        :param initial_peers: multiaddrs of one or more active DHT peers (if you want to join an existing DHT)
-
-        :param checkpoint_dir: directory to save and load expert checkpoints
-
-        :param compression: if specified, use this compression to pack all inputs, outputs and gradients by all experts
-            hosted on this server. For a more fine-grained compression, start server in python and specify compression
-            for each BatchTensorProto in ExpertBackend for the respective experts.
-
-        :param start: if True, starts server right away and returns when server is ready for requests
-        :param stats_report_interval: interval between two reports of batch processing performance statistics
-        """
-        if custom_module_path is not None:
-            add_custom_models_from_file(custom_module_path)
-        assert expert_cls in name_to_block
-
-        if no_dht:
-            dht = None
-        else:
-            dht = hivemind.DHT(initial_peers=initial_peers, start=True)
-            visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
-            logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
-
-        assert (expert_pattern is None and num_experts is None and expert_uids is not None) or (
-            num_experts is not None and expert_uids is None
-        ), "Please provide either expert_uids *or* num_experts (possibly with expert_pattern), but not both"
-
-        if expert_uids is None:
-            if checkpoint_dir is not None:
-                assert is_directory(checkpoint_dir)
-                expert_uids = [
-                    child.name for child in checkpoint_dir.iterdir() if (child / "checkpoint_last.pt").exists()
-                ]
-                total_experts_in_checkpoint = len(expert_uids)
-                logger.info(f"Located {total_experts_in_checkpoint} checkpoints for experts {expert_uids}")
-
-                if total_experts_in_checkpoint > num_experts:
-                    raise ValueError(
-                        f"Found {total_experts_in_checkpoint} checkpoints, but num_experts is set to {num_experts}, "
-                        f"which is smaller. Either increase num_experts or remove unneeded checkpoints."
-                    )
-            else:
-                expert_uids = []
-
-            uids_to_generate = num_experts - len(expert_uids)
-            if uids_to_generate > 0:
-                logger.info(f"Generating {uids_to_generate} expert uids from pattern {expert_pattern}")
-                expert_uids.extend(generate_uids_from_pattern(uids_to_generate, expert_pattern, dht))
-
-        num_experts = len(expert_uids)
-        num_handlers = num_handlers if num_handlers is not None else num_experts * 8
-        optim_cls = optim_cls if optim_cls is not None else partial(torch.optim.SGD, lr=0.0)
-        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
-
-        sample_input = name_to_input[expert_cls](3, hidden_dim)
-        if isinstance(sample_input, tuple):
-            args_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in sample_input)
-        else:
-            args_schema = (BatchTensorDescriptor.from_tensor(sample_input, compression),)
-
-        scheduler = schedule_name_to_scheduler[scheduler]
-
-        # initialize experts
-        experts = {}
-        for expert_uid in expert_uids:
-            expert = name_to_block[expert_cls](hidden_dim)
-            experts[expert_uid] = hivemind.ExpertBackend(
-                name=expert_uid,
-                expert=expert,
-                args_schema=args_schema,
-                optimizer=optim_cls(expert.parameters()),
-                scheduler=scheduler,
-                num_warmup_steps=num_warmup_steps,
-                num_total_steps=num_total_steps,
-                clip_grad_norm=clip_grad_norm,
-                min_batch_size=min_batch_size,
-                max_batch_size=max_batch_size,
-            )
-
-        if checkpoint_dir is not None:
-            load_experts(experts, checkpoint_dir)
-
-        return cls(
-            dht,
-            experts,
-            num_connection_handlers=num_handlers,
-            device=device,
-            checkpoint_dir=checkpoint_dir,
-            stats_report_interval=stats_report_interval,
-            start=start,
-        )
-
-    def run(self):
-        """
-        Starts Server in the current thread. Initializes dht if necessary, starts connection handlers,
-        runs Runtime (self.runtime) to process incoming requests.
-        """
-        logger.info(f"Server started with {len(self.experts)} experts:")
-        for expert_name, backend in self.experts.items():
-            num_parameters = sum(p.numel() for p in backend.expert.parameters() if p.requires_grad)
-            logger.info(f"{expert_name}: {backend.expert.__class__.__name__}, {num_parameters} parameters")
-
-        if self.dht:
-            if not self.dht.is_alive():
-                self.dht.run_in_background(await_ready=True)
-
-            if self.experts:
-                self.dht_handler_thread.start()
-        if self.checkpoint_saver is not None:
-            self.checkpoint_saver.start()
-
-        for process in self.conn_handlers:
-            if not process.is_alive():
-                process.start()
-            process.ready.result()
-
-        try:
-            self.runtime.run()
-        finally:
-            self.shutdown()
-
-    def run_in_background(self, await_ready=True, timeout=None):
-        """
-        Starts Server in a background thread. if await_ready, this method will wait until background server
-        is ready to process incoming requests or for :timeout: seconds max.
-        """
-        self.start()
-        if await_ready and not self.ready.wait(timeout=timeout):
-            raise TimeoutError("Server didn't notify .ready in {timeout} seconds")
-
-    @property
-    def ready(self) -> mp.synchronize.Event:
-        """
-        An event (multiprocessing.Event) that is set when the server is ready to process requests.
-
-        Example
-        =======
-        >>> server.start()
-        >>> server.ready.wait(timeout=10)
-        >>> print("Server ready" if server.ready.is_set() else "Server didn't start in 10 seconds")
-        """
-        return self.runtime.ready  # mp.Event that is true if self is ready to process batches
-
-    def shutdown(self):
-        """
-        Gracefully terminate the server, process-safe.
-        Please note that terminating server otherwise (e.g. by killing processes) may result in zombie processes.
-        If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
-        """
-        self.ready.clear()
-
-        for process in self.conn_handlers:
-            process.terminate()
-            process.join()
-        logger.debug("Connection handlers terminated")
-
-        if self.dht and self.experts:
-            self.dht_handler_thread.stop.set()
-            self.dht_handler_thread.join()
-
-        if self.checkpoint_saver is not None:
-            self.checkpoint_saver.stop.set()
-            self.checkpoint_saver.join()
-
-        if self.dht is not None:
-            self.dht.shutdown()
-            self.dht.join()
-
-        logger.debug(f"Shutting down runtime")
-
-        self.runtime.shutdown()
-        logger.info("Server shutdown succesfully")
-
-
-@contextmanager
-def background_server(*args, shutdown_timeout=5, **kwargs) -> Tuple[hivemind.Endpoint, List[Multiaddr]]:
-    """A context manager that creates server in a background thread, awaits .ready on entry and shutdowns on exit"""
-    pipe, runners_pipe = mp.Pipe(duplex=True)
-    runner = mp.Process(target=_server_runner, args=(runners_pipe, *args), kwargs=kwargs)
-    try:
-        runner.start()
-        # once the server is ready, runner will send us
-        # either (False, exception) or (True, (server.listen_on, dht_maddrs))
-        start_ok, data = pipe.recv()
-        if start_ok:
-            yield data
-            pipe.send("SHUTDOWN")  # on exit from context, send shutdown signal
-        else:
-            raise RuntimeError(f"Server failed to start: {data}")
-    finally:
-        runner.join(timeout=shutdown_timeout)
-        if runner.is_alive():
-            logger.info("Server failed to shutdown gracefully, terminating it the hard way...")
-            runner.kill()
-            logger.info("Server terminated.")
-
-
-def _server_runner(pipe, *args, **kwargs):
-    try:
-        server = Server.create(*args, start=True, **kwargs)
-    except Exception as e:
-        logger.exception(f"Encountered an exception when starting a server: {e}")
-        pipe.send((False, f"{type(e).__name__} {e}"))
-        return
-
-    try:
-        dht_maddrs = server.dht.get_visible_maddrs() if server.dht is not None else None
-        pipe.send((True, (server.listen_on, dht_maddrs)))
-        pipe.recv()  # wait for shutdown signal
-
-    finally:
-        logger.info("Shutting down server...")
-        server.shutdown()
-        server.join()
-        logger.info("Server shut down.")
+from hivemind.moe.server.layers import register_expert_class
+from hivemind.moe.server.server import Server, background_server
+from hivemind.moe.server.connection_handler import ConnectionHandler
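With the implementation moved to hivemind/moe/server/server.py, this __init__ now only re-exports the public entry points. A minimal usage sketch of the re-exported background_server context manager (parameter values here are illustrative, not taken from this diff):

    from hivemind.moe.server import background_server

    # spins up a throwaway server with one feed-forward expert; the context manager yields
    # the server endpoint and DHT multiaddrs once the background process reports readiness
    with background_server(num_experts=1, expert_cls="ffn", hidden_dim=64) as (endpoint, dht_maddrs):
        print("server endpoint:", endpoint)
        print("DHT multiaddrs:", dht_maddrs)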

+ 6 - 7
hivemind/moe/server/connection_handler.py

@@ -1,6 +1,5 @@
 import asyncio
 import multiprocessing as mp
-import pickle
 from typing import AsyncIterator, Dict
 
 import torch
@@ -10,7 +9,7 @@ from hivemind.dht import DHT
 from hivemind.moe.server.expert_backend import ExpertBackend
 from hivemind.p2p import P2PContext, ServicerBase
 from hivemind.proto import runtime_pb2
-from hivemind.utils import MPFuture, as_aiter, get_logger, nested_flatten
+from hivemind.utils import MSGPackSerializer, MPFuture, as_aiter, get_logger, nested_flatten
 from hivemind.utils.asyncio import switch_to_uvloop
 
 logger = get_logger(__name__)
@@ -54,11 +53,11 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
             logger.debug("Caught KeyboardInterrupt, shutting down")
 
     async def rpc_info(self, request: runtime_pb2.ExpertUID, context: P2PContext) -> runtime_pb2.ExpertInfo:
-        return runtime_pb2.ExpertInfo(serialized_info=pickle.dumps(self.experts[request.uid].get_info()))
+        return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(self.experts[request.uid].get_info()))
 
     async def rpc_forward(
         self, request: runtime_pb2.ExpertRequest, context: P2PContext
-    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
+    ) -> runtime_pb2.ExpertResponse:
         inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
 
         future = self.experts[request.uid].forward_pool.submit_task(*inputs)
@@ -67,15 +66,15 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
             for tensor, proto in zip(await future, nested_flatten(self.experts[request.uid].outputs_schema))
         ]
 
-        yield runtime_pb2.ExpertResponse(tensors=serialized_response)
+        return runtime_pb2.ExpertResponse(tensors=serialized_response)
 
     async def rpc_backward(
         self, request: runtime_pb2.ExpertRequest, context: P2PContext
-    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
+    ) -> runtime_pb2.ExpertResponse:
         inputs_and_grad_outputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         future = self.experts[request.uid].backward_pool.submit_task(*inputs_and_grad_outputs)
         serialized_response = [
             serialize_torch_tensor(tensor, proto.compression, allow_inplace=True)
             for tensor, proto in zip(await future, nested_flatten(self.experts[request.uid].grad_inputs_schema))
         ]
-        yield runtime_pb2.ExpertResponse(tensors=serialized_response)
+        return runtime_pb2.ExpertResponse(tensors=serialized_response)
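For context on the signature change above: the servicer presumably dispatches on whether a handler returns a single protobuf message (unary RPC) or is an async generator (streaming RPC). A rough sketch of the two handler shapes; the bodies are placeholders, not the actual implementation:

    from hivemind.proto import runtime_pb2

    async def rpc_example_unary(request, context) -> runtime_pb2.ExpertResponse:
        # new shape used above: one request in, exactly one response out
        return runtime_pb2.ExpertResponse(tensors=[])

    async def rpc_example_stream(request, context):
        # old shape: an async generator (AsyncIterator[ExpertResponse]) that may yield several responses
        yield runtime_pb2.ExpertResponse(tensors=[])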

+ 2 - 2
hivemind/moe/server/expert_backend.py

@@ -74,8 +74,8 @@ class ExpertBackend:
 
         if outputs_schema is None:
             # run expert once to get outputs schema
-            dummy_args = tuple(sample.make_empty(DUMMY_BATCH_SIZE) for sample in args_schema)
-            dummy_kwargs = {key: sample.make_empty(DUMMY_BATCH_SIZE) for key, sample in kwargs_schema.items()}
+            dummy_args = tuple(sample.make_zeros(DUMMY_BATCH_SIZE) for sample in args_schema)
+            dummy_kwargs = {key: sample.make_zeros(DUMMY_BATCH_SIZE) for key, sample in kwargs_schema.items()}
             dummy_outputs = self.expert(*dummy_args, **dummy_kwargs)
             outputs_schema = nested_map(BatchTensorDescriptor.from_tensor, dummy_outputs)
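The make_empty(...).zero_() pattern is replaced with make_zeros(...) throughout this diff. A small sketch of the assumed descriptor behaviour (the exact BatchTensorDescriptor semantics are not shown here, so the shapes below are an assumption):

    import torch
    from hivemind.utils.tensor_descr import BatchTensorDescriptor

    descr = BatchTensorDescriptor.from_tensor(torch.randn(16, 1024))
    # assumed: make_zeros substitutes the given leading dimension(s) for the batch axis and
    # returns an already zero-initialized tensor, so the extra .zero_() call becomes unnecessary
    single = descr.make_zeros(4)       # ~ torch.zeros(4, 1024)
    nested = descr.make_zeros(4, 8)    # ~ torch.zeros(4, 8, 1024), as used in _RemoteCallMany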
 

+ 2 - 73
hivemind/moe/server/expert_uid.py

@@ -1,12 +1,7 @@
-import random
 import re
-from typing import List, NamedTuple, Optional, Tuple, Union
+from typing import NamedTuple, Tuple, Union
 
-import hivemind
-from hivemind.dht import DHT
-from hivemind.utils import Endpoint, get_logger
-
-logger = get_logger(__name__)
+from hivemind.utils import Endpoint
 
 ExpertUID, ExpertPrefix, Coordinate, Score = str, str, int, float
 UidEndpoint = NamedTuple("UidEndpoint", [("uid", ExpertUID), ("endpoint", Endpoint)])
@@ -32,69 +27,3 @@ def split_uid(uid_or_prefix: Union[ExpertUID, ExpertPrefix]) -> Tuple[ExpertPref
     uid_or_prefix = uid_or_prefix.rstrip(UID_DELIMITER)
     pivot = uid_or_prefix.rindex(UID_DELIMITER) + 1
     return uid_or_prefix[:pivot], int(uid_or_prefix[pivot:])
-
-
-def generate_uids_from_pattern(
-    num_experts: int, expert_pattern: Optional[str], dht: Optional[DHT] = None, attempts_per_expert=10
-) -> List[str]:
-    """
-    Sample experts from a given pattern, remove duplicates.
-    :param num_experts: sample this many unique expert uids
-    :param expert_pattern: a string pattern or a list of expert uids,  example: myprefix.[0:32].[0:256]\
-     means "sample random experts between myprefix.0.0 and myprefix.255.255;
-    :param dht: if specified, uses this DHT to check that expert uids are not yet occupied by other peers
-    :param attempts_per_expert: give up if unable to generate a new expert uid after this many attempts per uid
-    :note: this method is not strictly process-safe. If several servers run it concurrently, they have
-     a small chance of sampling duplicate expert uids.
-    """
-    remaining_attempts = attempts_per_expert * num_experts
-    found_uids, attempted_uids = list(), set()
-
-    def _generate_uid():
-        if expert_pattern is None:
-            return f"expert{UID_DELIMITER}{attempts_per_expert * num_experts - remaining_attempts}"
-
-        uid = []
-        for block in expert_pattern.split(UID_DELIMITER):
-            try:
-                if "[" not in block and "]" not in block:
-                    uid.append(block)
-                elif block.startswith("[") and block.endswith("]") and ":" in block:
-                    slice_start, slice_end = map(int, block[1:-1].split(":"))
-                    uid.append(str(random.randint(slice_start, slice_end - 1)))
-                else:
-                    raise ValueError("Block must be either fixed or a range [from:to]")
-            except KeyboardInterrupt:
-                raise
-            except Exception as e:
-                raise ValueError(f"Expert pattern {expert_pattern} has invalid block {block}, {e}")
-        return UID_DELIMITER.join(uid)
-
-    while remaining_attempts > 0 and len(found_uids) < num_experts:
-
-        # 1. sample new expert uids at random
-        new_uids = []
-        while len(new_uids) + len(found_uids) < num_experts and remaining_attempts > 0:
-            new_uid = _generate_uid()
-            remaining_attempts -= 1
-            if new_uid not in attempted_uids:
-                attempted_uids.add(new_uid)
-                new_uids.append(new_uid)
-
-        # 2. look into DHT (if given) and remove duplicates
-        if dht:
-            existing_expert_uids = {
-                found_expert.uid
-                for found_expert in hivemind.moe.server.get_experts(dht, new_uids)
-                if found_expert is not None
-            }
-            new_uids = [new_uid for new_uid in new_uids if new_uid not in existing_expert_uids]
-
-        found_uids += new_uids
-
-    if len(found_uids) != num_experts:
-        logger.warning(
-            f"Found only {len(found_uids)} out of {num_experts} free expert uids after "
-            f"{attempts_per_expert * num_experts} attempts"
-        )
-    return found_uids
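After the move, expert_uid.py keeps only the uid parsing helpers. For reference, split_uid (shown in the context lines above) behaves roughly as follows, assuming UID_DELIMITER is the dot character:

    from hivemind.moe.server.expert_uid import split_uid

    prefix, index = split_uid("expert.0.3")
    # prefix == "expert.0.", index == 3 - the last coordinate is split off and the delimiter stays on the prefix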

+ 412 - 0
hivemind/moe/server/server.py

@@ -0,0 +1,412 @@
+from __future__ import annotations
+
+import multiprocessing as mp
+import threading
+from contextlib import contextmanager
+from functools import partial
+from pathlib import Path
+import random
+from typing import Dict, List, Optional, Tuple
+
+import torch
+from multiaddr import Multiaddr
+
+from hivemind.dht import DHT
+from hivemind.moe.server.checkpoints import CheckpointSaver, is_directory, load_experts
+from hivemind.moe.server.connection_handler import ConnectionHandler
+from hivemind.moe.server.dht_handler import DHTHandlerThread, get_experts
+from hivemind.moe.server.expert_backend import ExpertBackend
+from hivemind.moe.server.expert_uid import UID_DELIMITER
+from hivemind.moe.server.layers import (
+    add_custom_models_from_file,
+    name_to_block,
+    name_to_input,
+    schedule_name_to_scheduler,
+)
+from hivemind.moe.server.runtime import Runtime
+from hivemind.proto.runtime_pb2 import CompressionType
+from hivemind.utils.logging import get_logger
+from hivemind.utils.tensor_descr import BatchTensorDescriptor
+from hivemind.utils import Endpoint
+
+logger = get_logger(__name__)
+
+
+class Server(threading.Thread):
+    """
+    Server allows you to host "experts" - pytorch sub-networks used by Decentralized Mixture of Experts.
+    After creation, a server should be started: see Server.run or Server.run_in_background.
+
+    A working server does 3 things:
+     - processes incoming forward/backward requests via Runtime (created by the server)
+     - publishes updates to expert status every :update_period: seconds
+     - follows orders from HivemindController - if it exists
+
+    :type dht: DHT or None. Server with dht=None will NOT be visible from DHT,
+     but it will still support accessing experts directly with RemoteExpert(uid=UID, endpoint="IPADDR:PORT").
+    :param expert_backends: dict{expert uid (str) : ExpertBackend} for all experts hosted by this server.
+    :param listen_on: server's dht address that determines how it can be accessed. Address and (optional) port
+    :param num_connection_handlers: maximum number of simultaneous requests. Please note that the default value of 1
+        is too small for normal functioning; we recommend 4 handlers per expert backend.
+    :param update_period: how often the server will attempt to publish its state (i.e. experts) to the DHT;
+        if dht is None, this parameter is ignored.
+    :param start: if True, the server will immediately start as a background thread and return control once the server
+        is ready (see .ready below)
+    """
+
+    def __init__(
+        self,
+        dht: Optional[DHT],
+        expert_backends: Dict[str, ExpertBackend],
+        num_connection_handlers: int = 1,
+        update_period: int = 30,
+        start=False,
+        checkpoint_dir=None,
+        **kwargs,
+    ):
+        super().__init__()
+        self.dht, self.experts, self.update_period = dht, expert_backends, update_period
+
+        self.conn_handlers = [ConnectionHandler(dht, self.experts) for _ in range(num_connection_handlers)]
+        if checkpoint_dir is not None:
+            self.checkpoint_saver = CheckpointSaver(expert_backends, checkpoint_dir, update_period)
+        else:
+            self.checkpoint_saver = None
+        self.runtime = Runtime(self.experts, **kwargs)
+
+        if self.dht and self.experts:
+            self.dht_handler_thread = DHTHandlerThread(
+                experts=self.experts,
+                dht=self.dht,
+                peer_id=self.dht.peer_id,
+                update_period=self.update_period,
+                daemon=True,
+            )
+
+        if start:
+            self.run_in_background(await_ready=True)
+
+    @classmethod
+    def create(
+        cls,
+        num_experts: int = None,
+        expert_uids: str = None,
+        expert_pattern: str = None,
+        expert_cls="ffn",
+        hidden_dim=1024,
+        optim_cls=torch.optim.Adam,
+        scheduler: str = "none",
+        num_warmup_steps=None,
+        num_total_steps=None,
+        clip_grad_norm=None,
+        num_handlers=None,
+        min_batch_size=1,
+        max_batch_size=4096,
+        device=None,
+        no_dht=False,
+        initial_peers=(),
+        checkpoint_dir: Optional[Path] = None,
+        compression=CompressionType.NONE,
+        stats_report_interval: Optional[int] = None,
+        custom_module_path=None,
+        *,
+        start: bool,
+    ) -> Server:
+        """
+        Instantiate a server with several identical experts. See argparse comments below for details
+        :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
+        :param num_experts: run this many identical experts
+        :param expert_pattern: a string pattern or a list of expert uids, example: myprefix.[0:32].[0:256]\
+           means "sample random experts between myprefix.0.0 and myprefix.31.255";
+        :param expert_uids: spawn experts with these exact uids, overrides num_experts and expert_pattern
+        :param expert_cls: expert type from hivemind.moe.server.layers, e.g. 'ffn' or 'transformer';
+        :param hidden_dim: main dimension for expert_cls
+        :param num_handlers: server will use this many parallel processes to handle incoming requests
+        :param min_batch_size: total num examples in the same batch will be no less than this value
+        :param max_batch_size: total num examples in the same batch will not exceed this value
+        :param device: all experts will use this device in torch notation; default: cuda if available else cpu
+
+        :param optim_cls: uses this optimizer to train all experts
+        :param scheduler: if not `none`, the name of the expert LR scheduler
+        :param num_warmup_steps: the number of warmup steps for LR schedule
+        :param num_total_steps: the total number of steps for LR schedule
+        :param clip_grad_norm: maximum gradient norm used for clipping
+
+        :param no_dht: if specified, the server will not be attached to a dht
+        :param initial_peers: multiaddrs of one or more active DHT peers (if you want to join an existing DHT)
+
+        :param checkpoint_dir: directory to save and load expert checkpoints
+
+        :param compression: if specified, use this compression to pack all inputs, outputs and gradients by all experts
+            hosted on this server. For a more fine-grained compression, start server in python and specify compression
+            for each BatchTensorProto in ExpertBackend for the respective experts.
+
+        :param start: if True, starts server right away and returns when server is ready for requests
+        :param stats_report_interval: interval between two reports of batch processing performance statistics
+        """
+        if custom_module_path is not None:
+            add_custom_models_from_file(custom_module_path)
+        assert expert_cls in name_to_block
+
+        if no_dht:
+            dht = None
+        else:
+            dht = DHT(initial_peers=initial_peers, start=True)
+            visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
+            logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
+
+        assert (expert_pattern is None and num_experts is None and expert_uids is not None) or (
+            num_experts is not None and expert_uids is None
+        ), "Please provide either expert_uids *or* num_experts (possibly with expert_pattern), but not both"
+
+        if expert_uids is None:
+            if checkpoint_dir is not None:
+                assert is_directory(checkpoint_dir)
+                expert_uids = [
+                    child.name for child in checkpoint_dir.iterdir() if (child / "checkpoint_last.pt").exists()
+                ]
+                total_experts_in_checkpoint = len(expert_uids)
+                logger.info(f"Located {total_experts_in_checkpoint} checkpoints for experts {expert_uids}")
+
+                if total_experts_in_checkpoint > num_experts:
+                    raise ValueError(
+                        f"Found {total_experts_in_checkpoint} checkpoints, but num_experts is set to {num_experts}, "
+                        f"which is smaller. Either increase num_experts or remove unneeded checkpoints."
+                    )
+            else:
+                expert_uids = []
+
+            uids_to_generate = num_experts - len(expert_uids)
+            if uids_to_generate > 0:
+                logger.info(f"Generating {uids_to_generate} expert uids from pattern {expert_pattern}")
+                expert_uids.extend(_generate_uids(uids_to_generate, expert_pattern, dht))
+
+        num_experts = len(expert_uids)
+        num_handlers = num_handlers if num_handlers is not None else num_experts * 8
+        optim_cls = optim_cls if optim_cls is not None else partial(torch.optim.SGD, lr=0.0)
+        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+
+        sample_input = name_to_input[expert_cls](3, hidden_dim)
+        if isinstance(sample_input, tuple):
+            args_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in sample_input)
+        else:
+            args_schema = (BatchTensorDescriptor.from_tensor(sample_input, compression),)
+
+        scheduler = schedule_name_to_scheduler[scheduler]
+
+        # initialize experts
+        experts = {}
+        for expert_uid in expert_uids:
+            expert = name_to_block[expert_cls](hidden_dim)
+            experts[expert_uid] = ExpertBackend(
+                name=expert_uid,
+                expert=expert,
+                args_schema=args_schema,
+                optimizer=optim_cls(expert.parameters()),
+                scheduler=scheduler,
+                num_warmup_steps=num_warmup_steps,
+                num_total_steps=num_total_steps,
+                clip_grad_norm=clip_grad_norm,
+                min_batch_size=min_batch_size,
+                max_batch_size=max_batch_size,
+            )
+
+        if checkpoint_dir is not None:
+            load_experts(experts, checkpoint_dir)
+
+        return cls(
+            dht,
+            experts,
+            num_connection_handlers=num_handlers,
+            device=device,
+            checkpoint_dir=checkpoint_dir,
+            stats_report_interval=stats_report_interval,
+            start=start,
+        )
+
+    def run(self):
+        """
+        Starts Server in the current thread. Initializes dht if necessary, starts connection handlers,
+        runs Runtime (self.runtime) to process incoming requests.
+        """
+        logger.info(f"Server started with {len(self.experts)} experts:")
+        for expert_name, backend in self.experts.items():
+            num_parameters = sum(p.numel() for p in backend.expert.parameters() if p.requires_grad)
+            logger.info(f"{expert_name}: {backend.expert.__class__.__name__}, {num_parameters} parameters")
+
+        if self.dht:
+            if not self.dht.is_alive():
+                self.dht.run_in_background(await_ready=True)
+
+            if self.experts:
+                self.dht_handler_thread.start()
+        if self.checkpoint_saver is not None:
+            self.checkpoint_saver.start()
+
+        for process in self.conn_handlers:
+            if not process.is_alive():
+                process.start()
+            process.ready.result()
+
+        try:
+            self.runtime.run()
+        finally:
+            self.shutdown()
+
+    def run_in_background(self, await_ready=True, timeout=None):
+        """
+        Starts Server in a background thread. If await_ready, this method will wait until the background server
+        is ready to process incoming requests, or for at most :timeout: seconds.
+        """
+        self.start()
+        if await_ready and not self.ready.wait(timeout=timeout):
+            raise TimeoutError(f"Server didn't notify .ready in {timeout} seconds")
+
+    @property
+    def ready(self) -> mp.synchronize.Event:
+        """
+        An event (multiprocessing.Event) that is set when the server is ready to process requests.
+
+        Example
+        =======
+        >>> server.start()
+        >>> server.ready.wait(timeout=10)
+        >>> print("Server ready" if server.ready.is_set() else "Server didn't start in 10 seconds")
+        """
+        return self.runtime.ready  # mp.Event that is true if self is ready to process batches
+
+    def shutdown(self):
+        """
+        Gracefully terminate the server, process-safe.
+        Please note that terminating the server otherwise (e.g. by killing processes) may result in zombie processes.
+        If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
+        """
+        self.ready.clear()
+
+        for process in self.conn_handlers:
+            process.terminate()
+            process.join()
+        logger.debug("Connection handlers terminated")
+
+        if self.dht and self.experts:
+            self.dht_handler_thread.stop.set()
+            self.dht_handler_thread.join()
+
+        if self.checkpoint_saver is not None:
+            self.checkpoint_saver.stop.set()
+            self.checkpoint_saver.join()
+
+        if self.dht is not None:
+            self.dht.shutdown()
+            self.dht.join()
+
+        logger.debug(f"Shutting down runtime")
+
+        self.runtime.shutdown()
+        logger.info("Server shutdown succesfully")
+
+
+@contextmanager
+def background_server(*args, shutdown_timeout=5, **kwargs) -> Tuple[Endpoint, List[Multiaddr]]:
+    """A context manager that creates server in a background thread, awaits .ready on entry and shutdowns on exit"""
+    pipe, runners_pipe = mp.Pipe(duplex=True)
+    runner = mp.Process(target=_server_runner, args=(runners_pipe, *args), kwargs=kwargs)
+    try:
+        runner.start()
+        # once the server is ready, runner will send us
+        # either (False, exception) or (True, (server.listen_on, dht_maddrs))
+        start_ok, data = pipe.recv()
+        if start_ok:
+            yield data
+            pipe.send("SHUTDOWN")  # on exit from context, send shutdown signal
+        else:
+            raise RuntimeError(f"Server failed to start: {data}")
+    finally:
+        runner.join(timeout=shutdown_timeout)
+        if runner.is_alive():
+            logger.info("Server failed to shutdown gracefully, terminating it the hard way...")
+            runner.kill()
+            logger.info("Server terminated.")
+
+
+def _server_runner(pipe, *args, **kwargs):
+    try:
+        server = Server.create(*args, start=True, **kwargs)
+    except Exception as e:
+        logger.exception(f"Encountered an exception when starting a server: {e}")
+        pipe.send((False, f"{type(e).__name__} {e}"))
+        return
+
+    try:
+        dht_maddrs = server.dht.get_visible_maddrs() if server.dht is not None else None
+        pipe.send((True, (server.listen_on, dht_maddrs)))
+        pipe.recv()  # wait for shutdown signal
+
+    finally:
+        logger.info("Shutting down server...")
+        server.shutdown()
+        server.join()
+        logger.info("Server shut down.")
+
+def _generate_uids(
+    num_experts: int, expert_pattern: Optional[str], dht: Optional[DHT] = None, attempts_per_expert=10
+) -> List[str]:
+    """
+    Sample experts from a given pattern, remove duplicates.
+    :param num_experts: sample this many unique expert uids
+    :param expert_pattern: a string pattern or a list of expert uids, example: myprefix.[0:32].[0:256]\
+     means "sample random experts between myprefix.0.0 and myprefix.31.255";
+    :param dht: if specified, uses this DHT to check that expert uids are not yet occupied by other peers
+    :param attempts_per_expert: give up if unable to generate a new expert uid after this many attempts per uid
+    :note: this method is not strictly process-safe. If several servers run it concurrently, they have
+     a small chance of sampling duplicate expert uids.
+    """
+    remaining_attempts = attempts_per_expert * num_experts
+    found_uids, attempted_uids = list(), set()
+
+    def _generate_uid():
+        if expert_pattern is None:
+            return f"expert{UID_DELIMITER}{attempts_per_expert * num_experts - remaining_attempts}"
+
+        uid = []
+        for block in expert_pattern.split(UID_DELIMITER):
+            try:
+                if "[" not in block and "]" not in block:
+                    uid.append(block)
+                elif block.startswith("[") and block.endswith("]") and ":" in block:
+                    slice_start, slice_end = map(int, block[1:-1].split(":"))
+                    uid.append(str(random.randint(slice_start, slice_end - 1)))
+                else:
+                    raise ValueError("Block must be either fixed or a range [from:to]")
+            except KeyboardInterrupt:
+                raise
+            except Exception as e:
+                raise ValueError(f"Expert pattern {expert_pattern} has invalid block {block}, {e}")
+        return UID_DELIMITER.join(uid)
+
+    while remaining_attempts > 0 and len(found_uids) < num_experts:
+
+        # 1. sample new expert uids at random
+        new_uids = []
+        while len(new_uids) + len(found_uids) < num_experts and remaining_attempts > 0:
+            new_uid = _generate_uid()
+            remaining_attempts -= 1
+            if new_uid not in attempted_uids:
+                attempted_uids.add(new_uid)
+                new_uids.append(new_uid)
+
+        # 2. look into DHT (if given) and remove duplicates
+        if dht is not None:
+            existing_expert_uids = {
+                found_expert.uid for found_expert in get_experts(dht, new_uids) if found_expert is not None
+            }
+            new_uids = [new_uid for new_uid in new_uids if new_uid not in existing_expert_uids]
+
+        found_uids += new_uids
+
+    if len(found_uids) != num_experts:
+        logger.warning(
+            f"Found only {len(found_uids)} out of {num_experts} free expert uids after "
+            f"{attempts_per_expert * num_experts} attempts"
+        )
+    return found_uids
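To make the expert_pattern format concrete, here is roughly what the private _generate_uids helper above produces; the values are sampled at random, so the actual output will differ:

    from hivemind.moe.server.server import _generate_uids

    # each [a:b] block is replaced with a random integer in [a, b); fixed blocks are kept verbatim
    uids = _generate_uids(num_experts=3, expert_pattern="myprefix.[0:4].[0:4]")
    # e.g. ['myprefix.2.0', 'myprefix.0.3', 'myprefix.1.1'] - duplicates are filtered and re-sampled,
    # up to attempts_per_expert tries per uid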

+ 3 - 0
hivemind/optim/__init__.py

@@ -1,4 +1,7 @@
 from hivemind.optim.adaptive import CollaborativeAdaptiveOptimizer
 from hivemind.optim.base import DecentralizedOptimizerBase
 from hivemind.optim.collaborative import CollaborativeOptimizer
+from hivemind.optim.grad_scaler import GradScaler, HivemindGradScaler
+from hivemind.optim.optimizer import Optimizer
 from hivemind.optim.simple import DecentralizedAdam, DecentralizedOptimizer, DecentralizedSGD
+from hivemind.optim.training_averager import TrainingAverager
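These exports make the new unified optimizer importable from hivemind.optim. A hedged usage sketch; the keyword arguments come from the new hivemind/optim/optimizer.py, which is not shown in this section, so treat run_id, target_batch_size and batch_size_per_step as assumptions:

    import torch
    import hivemind
    from hivemind.optim import Optimizer

    dht = hivemind.DHT(start=True)
    model = torch.nn.Linear(16, 16)
    opt = Optimizer(
        dht=dht,
        run_id="my_experiment",          # assumed: a name shared by all peers of one training run
        optimizer=torch.optim.Adam(model.parameters()),
        target_batch_size=4096,          # assumed: global batch size that triggers a collaborative update
        batch_size_per_step=32,          # assumed: samples contributed by each local .step() call
    )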

+ 1 - 1
hivemind/optim/adaptive.py

@@ -2,8 +2,8 @@ from typing import Sequence
 
 import torch.optim
 
-from hivemind import TrainingAverager
 from hivemind.optim.collaborative import CollaborativeOptimizer
+from hivemind.optim.training_averager import TrainingAverager
 
 
 class CollaborativeAdaptiveOptimizer(CollaborativeOptimizer):

+ 8 - 0
hivemind/optim/base.py

@@ -1,3 +1,5 @@
+from warnings import warn
+
 import torch
 
 from hivemind.dht import DHT
@@ -8,6 +10,12 @@ class DecentralizedOptimizerBase(torch.optim.Optimizer):
 
     def __init__(self, opt: torch.optim.Optimizer, dht: DHT):
         self.opt, self.dht = opt, dht
+        warn(
+            "DecentralizedOptimizerBase and its subclasses have been deprecated and will be removed "
+            "in hivemind 1.1.0. Use hivemind.Optimizer instead",
+            FutureWarning,
+            stacklevel=2,
+        )
 
     @property
     def state(self):
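Because the warning above is a FutureWarning raised at construction time, remaining call sites of the legacy optimizers can be located by escalating it during migration, for example:

    import warnings

    # turn the deprecation notice into an exception so that any remaining DecentralizedOptimizerBase
    # subclass (DecentralizedSGD, DecentralizedAdam, CollaborativeOptimizer, ...) fails loudly
    warnings.simplefilter("error", FutureWarning)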

+ 75 - 32
hivemind/optim/collaborative.py

@@ -9,13 +9,14 @@ import numpy as np
 import torch
 from pydantic import BaseModel, StrictBool, StrictFloat, confloat, conint
 
-from hivemind.averaging.training import TrainingAverager
 from hivemind.dht import DHT
 from hivemind.dht.crypto import RSASignatureValidator
 from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
 from hivemind.optim.base import DecentralizedOptimizerBase
-from hivemind.optim.performance_ema import PerformanceEMA
-from hivemind.utils import Endpoint, get_dht_time, get_logger
+from hivemind.optim.grad_scaler import HivemindGradScaler
+from hivemind.optim.training_averager import TrainingAverager
+from hivemind.utils import get_dht_time, get_logger
+from hivemind.utils.performance_ema import PerformanceEMA
 
 logger = get_logger(__name__)
 LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
@@ -56,11 +57,15 @@ class TrainingProgressSchema(BaseModel):
 
 class CollaborativeOptimizer(DecentralizedOptimizerBase):
     """
-    An optimizer that performs model updates after collaboratively accumulating a target (large) batch size across peers
+    An optimizer that performs model updates after collaboratively accumulating a target (large) batch size across peers.
 
     These optimizers use DHT to track how much progress did the collaboration make towards target batch size.
     Once enough samples were accumulated, optimizers will compute a weighted average of their statistics.
 
+    :note: **For new projects, please use hivemind.Optimizer**. CollaborativeOptimizer is an older version of that.
+      Currently, hivemind.Optimizer supports all the features of CollaborativeOptimizer and several additional ones.
+      CollaborativeOptimizer will still be supported for a while, but it is scheduled for removal in v1.1.0.
+
     :note: This optimizer behaves unlike regular pytorch optimizers in two ways:
 
       * calling .step will periodically zero-out gradients w.r.t. model parameters after each step
@@ -147,6 +152,8 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         self.status_loglevel = logging.INFO if verbose else logging.DEBUG
         self.averager = self._make_averager(**kwargs)
 
+        self._step_supports_amp_scaling = self.reuse_grad_buffers  # enable custom execution with torch GradScaler
+
         self.training_progress_key = f"{self.prefix}_progress"
         self.local_samples_accumulated = 0  # a number of local samples accumulated since last optimizer update
         self.local_updates_accumulated = 0  # a number of calls to step() since last optimizer update
@@ -197,6 +204,8 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                 try:
                     self.averager.load_state_from_peers(timeout=self.load_state_timeout, **kwargs)
                     break
+                except KeyboardInterrupt:
+                    raise
                 except BaseException as e:
                     logger.exception(f"Failed to load state from peers: {e}, retrying ...")
                     continue
@@ -205,13 +214,26 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             self.reset_accumulated_grads_()
             self.update_scheduler()
 
-    def step(self, batch_size: Optional[int] = None, **kwargs):
+    def state_dict(self) -> dict:
+        state_dict = super().state_dict()
+        state_dict["state"]["collaborative_step"] = self.local_step
+        return state_dict
+
+    def load_state_dict(self, state_dict: dict):
+        if "collaborative_step" in state_dict["state"]:
+            self.averager.local_step = state_dict["state"].pop("collaborative_step")
+        return super().load_state_dict(state_dict)
+
+    def step(self, batch_size: Optional[int] = None, grad_scaler: Optional[HivemindGradScaler] = None, **kwargs):
         """
         Report accumulating gradients w.r.t. batch_size additional samples, optionally update model parameters
 
         :param batch_size: optional override for batch_size_per_step from init
+        :param grad_scaler: if amp is enabled, this **must** be a hivemind-aware gradient scaler
         :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
         """
+        if grad_scaler is not None and not isinstance(grad_scaler, HivemindGradScaler):
+            raise ValueError("CollaborativeOptimizer requires a hivemind-aware gradient scaler (HivemindGradScaler)")
         if self.batch_size_per_step is None:
             if batch_size is None:
                 raise ValueError("Please either set batch_size_per_step parameter at init or when calling .step")
@@ -220,12 +242,19 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         batch_size = batch_size if batch_size is not None else self.batch_size_per_step
 
         if not self.is_synchronized and not self.is_within_tolerance:
-            logger.log(self.status_loglevel, "Peer is out of sync.")
+            logger.log(self.status_loglevel, "Peer is out of sync")
             self.load_state_from_peers()
             return
         elif not self.is_synchronized and self.is_within_tolerance:
             self.averager.local_step = self.collaboration_state.optimizer_step
-            logger.log(self.status_loglevel, f"Catching up with collaboration step {self.local_step}.")
+            logger.log(self.status_loglevel, f"Catching up with collaboration step {self.local_step}")
+
+        if grad_scaler is not None and not grad_scaler.are_grads_finite(self):
+            logger.log(self.status_loglevel, "Encountered incorrect value in fp16 grads, resetting local gradients")
+            self.local_samples_accumulated = self.local_updates_accumulated = 0
+            self.reset_accumulated_grads_()
+            self.should_report_progress.set()
+            return
 
         if self.last_step_time is not None and get_dht_time() - self.last_step_time > self.metadata_expiration:
             logger.warning(
@@ -238,7 +267,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         with self.lock_local_progress:
             self.local_samples_accumulated += batch_size
             self.local_updates_accumulated += 1
-            self.performance_ema.update(num_processed=batch_size)
+            self.performance_ema.update(task_size=batch_size)
             self.should_report_progress.set()
 
         if not self.collaboration_state.ready_for_step:
@@ -251,6 +280,10 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
             # divide accumulators by local steps to recover the true average grad w.r.t. local_samples_accumulated
             self.apply_accumulated_grads_(scale_by=1.0 / self.local_updates_accumulated)
+            if grad_scaler is not None:
+                with grad_scaler.running_global_step():
+                    assert grad_scaler.unscale_(self)
+
             current_step, group_info = self.averager.local_step, None
 
             if self.collaboration_state.num_peers > 1:
@@ -271,15 +304,20 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                                 logger.warning(f"Peer {peer} sent malformed data about current step: {peer_step}")
 
                 except BaseException as e:
-                    logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}.")
+                    logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}")
 
             else:
                 logger.log(
                     self.status_loglevel,
-                    f"Skipped averaging: collaboration consists of " f"{self.collaboration_state.num_peers} peer(s).",
+                    f"Skipped averaging: collaboration consists of " f"{self.collaboration_state.num_peers} peer(s)",
                 )
 
-            self.opt.step()
+            if grad_scaler is not None:
+                with grad_scaler.running_global_step():
+                    assert grad_scaler.step(self)
+            else:
+                self.opt.step()
+
             self.reset_accumulated_grads_()
             self.local_samples_accumulated = self.local_updates_accumulated = 0
             self.collaboration_state.register_step(current_step + 1)
@@ -287,6 +325,13 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             self.collaboration_state_updated.set()
             self.update_scheduler()
 
+            if grad_scaler is not None:
+                with grad_scaler.running_global_step():
+                    assert grad_scaler.update()
+
+            if not self.averager.client_mode:
+                self.averager.state_sharing_priority = self.local_step
+
         logger.log(self.status_loglevel, f"Optimizer step: done!")
 
         return group_info
@@ -320,7 +365,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                         else:
                             logger.warning(f"Peer {peer} sent malformed data about current step: {peer_step}")
             except BaseException as e:
-                logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}.")
+                logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}")
 
             self.collaboration_state.register_step(current_step + 1)
             self.averager.local_step = current_step + 1
@@ -344,38 +389,36 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         """local gradient accumulators"""
         if self.reuse_grad_buffers:
             yield from self._grad_buffers()
-        elif self._grads is None:
-            with torch.no_grad():
-                self._grads = [
-                    torch.zeros_like(grad, device=self.accumulate_grads_on) for grad in self._grad_buffers()
-                ]
+            return
+
+        if self._grads is None:
+            self._grads = [torch.zeros_like(grad, device=self.accumulate_grads_on) for grad in self._grad_buffers()]
         yield from self._grads
 
     @torch.no_grad()
     def accumulate_grads_(self, batch_size: int):
         """add current gradients to grad accumulators (if any)"""
         if self.reuse_grad_buffers:
-            return  # user is responsible for accumulating gradients in .grad buffers
-        alpha = float(batch_size) / self.batch_size_per_step
-        for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
-            grad_acc.add_(grad_buf.to(grad_acc.device), alpha=alpha)
+            # user is responsible for accumulating gradients in .grad buffers
+            assert batch_size == self.batch_size_per_step, "Custom batch size is not supported if reuse_grad_buffers"
+        else:
+            alpha = float(batch_size) / self.batch_size_per_step
+            for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
+                grad_acc.add_(grad_buf.to(grad_acc.device), alpha=alpha)
 
     @torch.no_grad()
     def apply_accumulated_grads_(self, scale_by: Optional[float] = None):
-        if self.reuse_grad_buffers:
-            return
-        for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
-            grad_buf[...] = grad_acc.to(grad_buf.device)
-            if scale_by is not None:
+        if not self.reuse_grad_buffers:
+            for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
+                grad_buf.copy_(grad_acc.to(grad_buf.device), non_blocking=True)
+        if scale_by is not None:
+            for grad_buf in self._grad_buffers():
                 grad_buf.mul_(scale_by)
 
     @torch.no_grad()
     def reset_accumulated_grads_(self):
-        if self.reuse_grad_buffers:
-            self.opt.zero_grad()
-        else:
-            for grad_buf in self.accumulated_grads():
-                grad_buf.zero_()
+        for grad_buf in self.accumulated_grads():
+            grad_buf.zero_()
 
     def report_training_progress(self):
         """Periodically publish metadata and the current number of samples accumulated towards the next step"""
@@ -509,7 +552,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             value=None,
             expiration_time=get_dht_time() + self.metadata_expiration,
         )
-        logger.debug(f"{self.__class__.__name__} is shut down.")
+        logger.debug(f"{self.__class__.__name__} is shut down")
 
     def __del__(self):
         self.shutdown()
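The grad_scaler plumbing above enables mixed-precision training, with unscale_/step/update driven from inside CollaborativeOptimizer.step. A hedged sketch of the intended loop, assuming the optimizer was created with reuse_grad_buffers=True and that HivemindGradScaler accepts the same defaults as torch.cuda.amp.GradScaler; model, criterion, loader and opt are placeholders:

    import torch
    from hivemind.optim import HivemindGradScaler

    scaler = HivemindGradScaler()
    for inputs, targets in loader:
        with torch.cuda.amp.autocast():
            loss = criterion(model(inputs), targets)
        scaler.scale(loss).backward()
        # unscale_, step and update are invoked internally once enough samples are accumulated
        opt.step(grad_scaler=scaler)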

+ 226 - 0
hivemind/optim/grad_averager.py

@@ -0,0 +1,226 @@
+import contextlib
+from typing import Iterable, Iterator, Optional
+
+import torch
+
+import hivemind
+from hivemind.averaging import DecentralizedAverager
+from hivemind.averaging.control import StepControl
+from hivemind.utils import DHTExpiration, get_dht_time, get_logger
+
+logger = get_logger(__name__)
+
+
+class GradientAverager(DecentralizedAverager):
+    """
+    An auxiliary averaging class that is responsible for accumulating gradients and aggregating them with peers.
+    GradientAverager is meant to be used within hivemind.Optimizer, but it can be used standalone (see example below).
+
+    GradientAverager manages three sets of buffers:
+    (1) model gradients - the gradients associated with local model parameters by PyTorch (param.grad).
+        These tensors are typically stored on device and updated by torch autograd
+    (2) gradient accumulators - an [optional] set of buffers where local gradients are accumulated.
+      - note: if reuse_grad_buffers is True, the averager will use gradients from parameters as local accumulators,
+        which reduces RAM usage but requires the user to avoid calling zero_grad / clip_grad manually
+    (3) averaged gradients - gradient buffers that are aggregated in-place with peers, always in host memory
+
+    :param parameters: pytorch parameters for which to aggregate gradients
+    :param dht: a DHT instance connected to the rest of the swarm. See hivemind.DHT docs
+    :param prefix: a unique DHT key used for matchmaking. E.g. this can be your experiment name with optional suffixes
+    :param reuse_grad_buffers: if True, use model's .grad buffers for accumulating gradients over multiple steps.
+      This is more memory efficient, but it requires that the user does *not* call zero_grad or clip gradients manually
+    :param accumulate_grads_on: if specified, accumulate gradients on this device. By default, this will use the same
+      device as model parameters. One can specify a different device (e.g. 'cpu' vs 'cuda') to save device memory at
+      the cost of extra time per step. If reuse_grad_buffers is True, this parameter has no effect.
+    :param client_mode: if False, this averager will accept incoming requests from other peers.
+      if True, the averager will only join existing groups where at least one peer has client_mode=False.
+      By default, this flag is copied from DHTNode inside the ``dht`` instance.
+    :param warn: if True, warn when the averager did not reset accumulators after use or did not use averaging results
+    :param kwargs: see DecentralizedAverager keyword arguments for additional parameters
+
+
+    Example:
+
+    >>> model = SuchModelMuchLayers()
+    >>> opt = torch.optim.Adam(model.parameters())
+    >>> grad_averager = GradientAverager(model.parameters(), dht=hivemind.DHT(...))
+    >>> next_step_time = hivemind.get_dht_time() + 60   # runs global steps every 60 seconds
+    >>> next_step_control = None
+    >>> while True:
+    >>>    # accumulate as many gradients as you can before next_step_time
+    >>>    loss = compute_loss(model, batch_size=32)
+    >>>    loss.backward()
+    >>>    grad_averager.accumulate_grads_(batch_size=32)
+    >>>    # [optional] next step in 5 seconds, start looking for peers in advance
+    >>>    if next_step_time - hivemind.get_dht_time() <= 5:
+    >>>        next_step_control = grad_averager.schedule_step(scheduled_time=next_step_time)
+    >>>    # aggregate gradients and perform optimizer step
+    >>>    if hivemind.get_dht_time() >= next_step_time:
+    >>>        grad_averager.step(control=next_step_control)
+    >>>        with grad_averager.use_averaged_gradients():  # this will fill param.grads with aggregated gradients
+    >>>            opt.step()  # update model parameters using averaged gradients
+    >>>        grad_averager.reset_accumulated_grads_()  # prepare for next step
+    >>>        next_step_time = hivemind.get_dht_time() + 60
+    >>>        next_step_control = None
+
+    """
+
+    def __init__(
+        self,
+        parameters: Iterable[torch.nn.Parameter],
+        *,
+        dht: hivemind.DHT,
+        prefix: str,
+        reuse_grad_buffers: bool = False,
+        accumulate_grads_on: Optional[torch.device] = None,
+        client_mode: bool = None,
+        warn: bool = True,
+        **kwargs,
+    ):
+        if reuse_grad_buffers and accumulate_grads_on is not None:
+            logger.warning("Setting 'accumulate_grads_on' has no effect if reuse_grad_buffers=True")
+        client_mode = client_mode if client_mode is not None else dht.client_mode
+        self.parameters = tuple(parameters)
+        self.reuse_grad_buffers = reuse_grad_buffers
+        self.warn = warn
+        self.local_samples_accumulated = 0
+        self.local_times_accumulated = 0
+        self._anchor_batch_size = None
+        self._local_accumulators = None
+        if not reuse_grad_buffers:
+            self._local_accumulators = tuple(
+                torch.zeros_like(grad, device=accumulate_grads_on) for grad in self._grads_from_parameters()
+            )
+        self._accumulators_used_in_step = False
+        self._new_averaged_grads = False
+
+        with torch.no_grad():
+            averaged_grads = tuple(
+                grad.detach().cpu().clone().share_memory_() for grad in self._grads_from_parameters()
+            )
+        super().__init__(averaged_tensors=averaged_grads, dht=dht, prefix=prefix, client_mode=client_mode, **kwargs)
+
+    def _grads_from_parameters(self) -> Iterator[torch.Tensor]:
+        """gradient buffers associated with parameters"""
+        for param in self.parameters:
+            if param.grad is None:
+                param.grad = torch.zeros_like(param)
+            yield param.grad
+
+    @torch.no_grad()
+    def _grad_accumulators(self) -> Iterator[torch.Tensor]:
+        """averager-based gradient accumulators"""
+        assert (self._local_accumulators is None) == self.reuse_grad_buffers
+        yield from self._grads_from_parameters() if self.reuse_grad_buffers else self._local_accumulators
+
+    @torch.no_grad()
+    def accumulate_grads_(self, batch_size: int):
+        """add current gradients to local grad accumulators (if used)"""
+        if self._accumulators_used_in_step and self.warn:
+            logger.warning(
+                "[warn=True] Gradient accumulators were not reset since the last averaging round. Please "
+                "call .reset_accumulated_grads_ after every step or use .step(reset_accumulators=True)"
+            )
+            self._accumulators_used_in_step = False  # warn once per round
+        if self._anchor_batch_size is None:
+            # remember the first batch size to correctly re-scale gradients if subsequent batches have a different size
+            self._anchor_batch_size = batch_size
+        self.local_samples_accumulated += batch_size
+        self.local_times_accumulated += 1
+        if self.reuse_grad_buffers:
+            pass  # user is responsible for accumulating gradients in .grad buffers
+        else:
+            alpha = float(batch_size) / self._anchor_batch_size
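+            # each batch contributes in proportion to its size relative to the first (anchor) batch;
+            # the accumulated sum is later divided by local_times_accumulated in load_accumulators_into_averager_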
+            for grad_buf, grad_acc in zip(self._grads_from_parameters(), self._grad_accumulators()):
+                grad_acc.add_(grad_buf.to(grad_acc.device), alpha=alpha)
+
+    def schedule_step(self, scheduled_time: Optional[DHTExpiration] = None, **kwargs) -> StepControl:
+        """
+        Begin matchmaking: look for a group of peers and prepare for averaging gradients at a specified time.
+
+        :param scheduled_time: expected time when to perform all-reduce. Can be changed using control.scheduled_time
+        :param kwargs: any additional keyword args from DecentralizedAverager.step, such as gather, allow_retries, etc
+        :note: setting weight at this stage is not supported, please leave this parameter as None
+        :returns: step_control - a handle that can be passed into GradientAverager.step to use the pre-scheduled group
+        :note: in the current implementation, each step_control can only be used in one step.
+        """
+        assert kwargs.get("weight") is None, "setting weight in schedule_step is not supported"
+        return super().step(scheduled_time=scheduled_time, wait=False, require_trigger=True, **kwargs)
+
+    def step(
+        self,
+        weight: Optional[float] = None,
+        reset_accumulators: bool = True,
+        control: Optional[StepControl] = None,
+        timeout: Optional[float] = None,
+        wait: bool = True,
+        **kwargs,
+    ):
+        """
+        Average accumulated gradients with peers, optionally load averaged gradients and reset accumulators
+
+        :param weight: overrides the averaging weight; by default, weight equals the number of accumulated samples
+        :param reset_accumulators: by default, set local gradient accumulators to zeros after averaging succeeds
+        :param control: reuse a pre-arranged group of peers (or a matchmaking in progress) from averager.schedule_step
+        :param timeout: if specified, await for averaging round for at most this number of seconds (if wait=True)
+        :param wait: if True, await for the step to finish (or fail), otherwise run all-reduce in background
+        """
+        if control is None:
+            control = self.schedule_step(timeout=timeout, **kwargs)
+        elif len(kwargs) > 0:
+            raise RuntimeError(f"Averaging with a pre-scheduled group; parameters {kwargs} will have no effect")
+        assert not control.triggered, f"This {type(control)} instance was already used"
+        if self._new_averaged_grads and self.warn:
+            logger.warning(
+                "[warn=True] Starting new averaging round, but previous round results were not used. "
+                "This may be a sign of incorrect optimizer behavior"
+            )
+
+        self.load_accumulators_into_averager_()
+        self._accumulators_used_in_step = True
+        self._new_averaged_grads = True
+
+        control.weight = self.local_samples_accumulated if weight is None else weight
+        if reset_accumulators:
+            self.reset_accumulated_grads_()
+        control.allow_allreduce()
+
+        return control.result(timeout) if wait else control
+
+    @torch.no_grad()
+    def load_accumulators_into_averager_(self):
+        """load locally accumulated gradients into the averager for aggregation"""
+        # divide locally accumulated gradients by the number of times they were accumulated
+        grad_scale = (1.0 / self.local_times_accumulated) if self.local_times_accumulated != 0 else 0.0
+        with self.get_tensors() as averaged_grads:
+            for grad_acc, averaged_grad in zip(self._grad_accumulators(), averaged_grads):
+                averaged_grad.copy_(grad_acc, non_blocking=True).mul_(grad_scale)
+
+    @torch.no_grad()
+    def reset_accumulated_grads_(self):
+        """reset averager-internal gradient accumulators and the denominator"""
+        self._accumulators_used_in_step = False
+        self.local_samples_accumulated = self.local_times_accumulated = 0
+        self._anchor_batch_size = None
+        for grad_buf in self._grad_accumulators():
+            grad_buf.zero_()
+
+    @contextlib.contextmanager
+    @torch.no_grad()
+    def use_averaged_gradients(self):
+        """Substitute model's main gradients with averaged gradients (does not respect device placement)"""
+        self._new_averaged_grads = False
+        with self.get_tensors() as averaged_grads:
+            assert len(averaged_grads) == len(self.parameters)
+            try:
+                old_grads = [param.grad for param in self.parameters]
+                for param, new_grad in zip(self.parameters, averaged_grads):
+                    param.grad = new_grad
+                yield averaged_grads
+            finally:
+                for param, old_grad in zip(self.parameters, old_grads):
+                    param.grad = old_grad
+
+    def notify_used_averaged_gradients(self):
+        """Notify averager that the results of a previous averaging round are accounted for"""
+        self._new_averaged_grads = False
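
For reference, a minimal sketch of the reuse_grad_buffers=True mode described in the GradientAverager docstring above. The model, prefix, and batch sizes are illustrative assumptions; in this mode, param.grad itself serves as the local accumulator, so the loop never calls zero_grad:

import torch
import hivemind
from hivemind.optim.grad_averager import GradientAverager

model = torch.nn.Linear(512, 10)                        # illustrative model
opt = torch.optim.SGD(model.parameters(), lr=0.1)
dht = hivemind.DHT(start=True)

averager = GradientAverager(
    model.parameters(), dht=dht, prefix="demo_grads",   # prefix is an arbitrary example name
    reuse_grad_buffers=True, start=True,                # param.grad doubles as the local accumulator
)

for _ in range(10):
    for _ in range(4):                                  # accumulate micro-batches directly in param.grad
        x, y = torch.randn(32, 512), torch.randint(0, 10, (32,))
        torch.nn.functional.cross_entropy(model(x), y).backward()
        averager.accumulate_grads_(batch_size=32)       # with reuse_grad_buffers, this only counts samples
    averager.step()                                     # average with peers that share the same prefix
    with averager.use_averaged_gradients():             # temporarily substitute param.grad with the result
        opt.step()
    averager.reset_accumulated_grads_()                 # zeroes param.grad instead of opt.zero_grad()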

+ 125 - 0
hivemind/optim/grad_scaler.py

@@ -0,0 +1,125 @@
+import contextlib
+import threading
+from copy import deepcopy
+from typing import Dict, Optional
+
+import torch
+from torch.cuda.amp import GradScaler as TorchGradScaler
+from torch.cuda.amp.grad_scaler import OptState, _refresh_per_optimizer_state
+from torch.optim import Optimizer as TorchOptimizer
+
+import hivemind
+from hivemind.utils.logging import get_logger
+
+logger = get_logger(__name__)
+
+
+class GradScaler(TorchGradScaler):
+    """
+    A wrapper over pytorch GradScaler made specifically for training hivemind.Optimizer with reuse_grad_buffers=True.
+
+    :note: if not using reuse_grad_buffers=True, one can and *should* train normally without this class, e.g. using
+      standard PyTorch AMP or Apex. This custom GradScaler is more memory-efficient, but requires custom training code.
+
+    hivemind.GradScaler makes 3 modifications to the regular PyTorch AMP:
+
+    - bypass .unscale_ and .update calls in order to accumulate gradients over several steps
+    - limit increasing gradient scale to only immediately after global optimizer steps
+    - allow training with some or master parameters in float16
+
+    :note: The above modifications will be enabled automatically. One can (and should) use hivemind.GradScaler exactly
+      as the regular ``torch.cuda.amp.GradScaler``.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._is_running_global_step = False
+        self._is_ready_to_update = False
+        self._inner_optimizer_states = {}
+        self._optimizer_states_to_reset = set()
+        self._lock = threading.RLock()
+
+    @contextlib.contextmanager
+    def running_global_step(self):
+        with self._lock:
+            was_running, self._is_running_global_step = self._is_running_global_step, True
+            try:
+                yield
+            finally:
+                self._is_running_global_step = was_running
+
+    def unscale_(self, optimizer: TorchOptimizer) -> bool:
+        with self._lock:
+            assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
+            if self._is_running_global_step:
+                super().unscale_(optimizer)
+                self._inner_optimizer_states[id(optimizer.opt)] = deepcopy(self._per_optimizer_states[id(optimizer)])
+                # note: we store unscaled optimizer state in a separate dict and not in _per_optimizer_states in order
+                # to avoid an edge case where full DPU peer encounters overflow in local gradients while averaging
+                # offloaded gradients (i.e. after global unscale but before global step). Due to overflow, next call to
+                # .update on user side would reset *all* optimizer states and cause .step to unscale gradients twice.
+                # Offloaded optimizer is not affected by overflow in on-device gradients and should not be reset.
+                return True
+            else:
+                self._check_inf_per_device(optimizer)
+                self._optimizer_states_to_reset.add(id(optimizer))
+                return False
+
+    def step(self, optimizer: TorchOptimizer, *args, **kwargs) -> bool:
+        if self._is_running_global_step and not isinstance(optimizer, hivemind.Optimizer):
+            # ^-- invoked privately within hivemind optimizer
+            inner_optimizer = optimizer
+            with self._lock:
+                if self._is_ready_to_update:
+                    logger.warning("Please call grad_scaler.update() after each step")
+
+                inner_optimizer_state = self._inner_optimizer_states.pop(id(inner_optimizer), None)
+                if inner_optimizer_state is not None:
+                    self._per_optimizer_states[id(inner_optimizer)] = inner_optimizer_state
+                assert (
+                    self._per_optimizer_states[id(inner_optimizer)]["stage"] == OptState.UNSCALED
+                ), "InternalError: Optimizer should have called .unscale internally before invoking grad_scaler.step"
+                if self.are_grads_finite(inner_optimizer, use_cached=True):
+                    super().step(inner_optimizer, *args, **kwargs)
+                else:
+                    logger.warning("Skipping global step due to gradient over/underflow")
+                self._is_ready_to_update = True
+                return True
+        else:
+            super().step(optimizer)
+            self._optimizer_states_to_reset.add(id(optimizer))
+            return False
+
+    def update(self, new_scale: Optional[float] = None) -> bool:
+        with self._lock:
+            total_infs = 0
+            for optimizer_state in self._per_optimizer_states.values():
+                total_infs += sum(v.item() for v in optimizer_state["found_inf_per_device"].values())
+
+            if self._is_ready_to_update or total_infs != 0:
+                # note: we update either during actual optimizer step or if we need to reduce scale due to NaN
+                super().update(new_scale)
+                self._is_ready_to_update = False
+                return True
+            else:
+                for opt_id in self._optimizer_states_to_reset:
+                    self._per_optimizer_states[opt_id] = _refresh_per_optimizer_state()
+                self._optimizer_states_to_reset.clear()
+                return False
+
+    def _unscale_grads_(
+        self, optimizer: TorchOptimizer, inv_scale: torch.Tensor, found_inf: torch.Tensor, allow_fp16: bool
+    ) -> Dict[torch.device, torch.Tensor]:
+        # note: the code below sets allow_fp16=True to allow training with master weights (partially) in fp16
+        # inspired by: https://github.com/facebookresearch/fairscale/blob/945b9666/fairscale/optim/grad_scaler.py
+        return super()._unscale_grads_(optimizer, inv_scale, found_inf, allow_fp16=True)
+
+    def are_grads_finite(self, optimizer: TorchOptimizer, use_cached: bool = False) -> bool:
+        opt_dict = self._found_inf_per_device(optimizer) if use_cached else self._check_inf_per_device(optimizer)
+        return not sum(v.item() for v in opt_dict.values())
+
+
+class HivemindGradScaler(GradScaler):
+    def __init__(self, *args, **kwargs):
+        logger.warning("HivemindGradScaler was renamed to hivemind.GradScaler, this reference will be removed in v1.1")
+        super().__init__(*args, **kwargs)
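
A minimal sketch of how hivemind.GradScaler can slot into a mixed-precision training loop with hivemind.Optimizer(reuse_grad_buffers=True); the model, run_id, and batch sizes below are illustrative assumptions:

import torch
import hivemind

model = torch.nn.Linear(512, 10).cuda()                  # illustrative model
dht = hivemind.DHT(start=True)
opt = hivemind.Optimizer(
    dht=dht, run_id="demo_amp_run", target_batch_size=4096, batch_size_per_step=32,
    params=model.parameters(), optimizer=lambda params: torch.optim.Adam(params),
    reuse_grad_buffers=True,                             # the mode this GradScaler is designed for
)
scaler = hivemind.GradScaler()

while True:
    x = torch.randn(32, 512, device="cuda")
    y = torch.randint(0, 10, (32,), device="cuda")
    with torch.cuda.amp.autocast():
        loss = torch.nn.functional.cross_entropy(model(x), y)
    scaler.scale(loss).backward()                        # do not call opt.zero_grad() in this mode
    scaler.step(opt)                                     # unscaling is handled inside the global step
    scaler.update()                                      # only raises the scale right after a global step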

+ 779 - 0
hivemind/optim/optimizer.py

@@ -0,0 +1,779 @@
+from __future__ import annotations
+
+import logging
+import os
+import time
+from functools import partial
+from typing import Callable, Optional, Sequence, Union
+
+import torch
+
+from hivemind.averaging.control import AveragingStage, StepControl
+from hivemind.compression import CompressionBase, NoCompression
+from hivemind.dht import DHT
+from hivemind.optim.grad_averager import GradientAverager
+from hivemind.optim.grad_scaler import GradScaler
+from hivemind.optim.progress_tracker import LocalTrainingProgress, ProgressTracker
+from hivemind.optim.state_averager import (
+    LRSchedulerBase,
+    OptimizerFactory,
+    Parameters,
+    ParamGroups,
+    SchedulerFactory,
+    TorchOptimizer,
+    TrainingStateAverager,
+)
+from hivemind.utils import PerformanceEMA, get_dht_time, get_logger
+
+logger = get_logger(__name__)
+
+
+class Optimizer(torch.optim.Optimizer):
+    """
+    hivemind.Optimizer wraps your regular PyTorch Optimizer for training collaboratively with peers.
+
+    By default, Optimizer is configured to be exactly **equivalent to synchronous training** with target_batch_size.
+    There are advanced options that make training semi-asynchronous (delay_optimizer_step and delay_grad_averaging)
+    or even fully asynchronous (use_local_updates=True).
+
+    :example: The Optimizer can be used as a drop-in replacement for a regular PyTorch Optimizer:
+
+    >>> model = transformers.AutoModel("albert-xxlarge-v2")
+    >>> dht = hivemind.DHT(initial_peers=INITIAL_PEERS, start=True)
+    >>> opt = hivemind.Optimizer(dht=dht, run_id="run_42", batch_size_per_step=4, target_batch_size=4096,
+    >>>                          params=model.parameters(), optimizer=lambda params: torch.optim.Adam(params))
+    >>> while True:
+    >>>     loss = compute_loss_on_batch(model, batch_size=4)
+    >>>     opt.zero_grad()
+    >>>     loss.backward()
+    >>>     opt.step()  # <-- train collaboratively with any peers that use the same prefix (run_42)
+
+    By default, peers will perform the following steps:
+
+     * accumulate a minibatch of gradients towards the (global) target batch size, without updating parameters yet;
+     * after peers collectively accumulate target_batch_size, average gradients with peers and perform optimizer step;
+     * if your peer lags behind the rest of the swarm, it will download parameters and optimizer state from others;
+
+    Unlike regular training, your device may join midway through training, when other peers have already made some progress.
+    For this reason, any learning rate schedulers, curricula and other **time-dependent features should be based on**
+    ``optimizer.local_epoch`` (and not the number of calls to opt.step). Otherwise, peers that joined training late
+    may end up having different learning rates. To do so automatically, specify the ``scheduler=...`` parameter below.
+
+    :What is an epoch?: Optimizer uses the term ``epoch`` to describe intervals between synchronizations. One epoch
+      corresponds to processing a certain number of training samples (``target_batch_size``) in total across all peers.
+      Like in PyTorch LR Scheduler, **epoch does not necessarily correspond to a full pass over the training data.**
+      At the end of epoch, peers perform synchronous actions such as averaging gradients for a global optimizer update,
+      updating the learning rate scheduler or simply averaging parameters (if using local updates).
+      The purpose of this is to ensure that changing the number of peers does not require changing hyperparameters.
+      For instance, if the number of peers doubles, they will run all-reduce more frequently to adjust for faster training.
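+      As a concrete (illustrative) example: with target_batch_size=4096 and 8 peers each calling opt.step with
+      batch_size_per_step=32, one epoch completes after roughly 16 local steps per peer (4096 / (8 * 32));
+      with 16 such peers, after roughly 8.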
+
+    :Configuration guide: This guide will help you set up your first collaborative training run. It covers the most
+      important basic options, but ignores features that require significant changes to the training code.
+
+    >>> dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=IF_BEHIND_FIREWALL_OR_VERY_UNRELIABLE, start=True)
+    >>> opt = hivemind.Optimizer(
+    >>>    dht=dht, run_id="a_unique_name_that_every_participant_will_see_when_training",
+    >>>    batch_size_per_step=ACTUAL_BATCH_SIZE_OF_THIS_PEER, target_batch_size=LARGE_GLOBAL_BATCH,
+    >>>    # ^--- Each global optimizer step will use gradients from 1x-1.1x of target_batch_size (due to latency);
+    >>>    # It is recommended to train with very large batch sizes to reduce the % of time spent on communication.
+    >>>
+    >>>    params=params, optimizer=lambda params: AnyPyTorchOptimizer(params, **hyperparams_for_target_batch_size),
+    >>>    # tune learning rate for your target_batch_size. Here's a good reference: https://arxiv.org/abs/1904.00962
+    >>>    scheduler=lambda opt: AnyPyTorchScheduler(opt, **hyperparams_for_target_batch_size),
+    >>>    # scheduler.step will be called automatically each time when peers collectively accumulate target_batch_size
+    >>>
+    >>>    offload_optimizer=True,  # saves GPU memory, but increases RAM usage; Generally a good practice to use this.
+    >>>    delay_grad_averaging=OPTIONAL, delay_optimizer_step=OPTIONAL, # train faster, but with 1 round of staleness;
+    >>>    # setting both to True is equivalent to Delayed Parameter Updates (see https://arxiv.org/abs/2101.06840)
+    >>>
+    >>>    grad_compression=hivemind.Float16Compression(),  state_averaging_compression=hivemind.Float16Compression(),
+    >>>    # ^-- it is usually fine to use pure 16-bit or even lower precision during communication with no precaution;
+    >>>    # See hivemind/examples/albert for a working example of mixed 8/16-bit compression.
+    >>>
+    >>>    matchmaking_time=15.0, # 3-5s for small local runs, 10-15s for training over the internet or with many peers
+    >>>    averaging_timeout=60.0,  # around 2x the actual time it takes to run all-reduce
+    >>>    verbose=True  # periodically report the training progress to the console (e.g. "Averaged with N peers")
+    >>> )  # and you're done!
+
+
+    :param dht: a running hivemind.DHT instance connected to other peers.
+    :param run_id: a unique identifier of this training run, used as a common prefix for all DHT keys.
+      **Note:** peers with the same run_id should *generally* train the same model and use compatible configurations.
+      Some options can be safely changed by individual peers: ``batch_size_per_step``, ``client_mode``, ``auxiliary``,
+      ``reuse_grad_buffers``, ``offload_optimizer``, and ``verbose``. In some cases, other options may also be tuned
+      individually by each peer, but they should be changed with caution to avoid deadlocks or convergence issues.
+
+    :param target_batch_size: global batch size that must be accumulated before the swarm transitions to the next epoch.
+      The actual batch may be *slightly* larger due to asynchrony (e.g. peers submit more gradients in the last second).
+    :param batch_size_per_step: you should accumulate gradients over this many samples between calls to optimizer.step.
+
+    :param params: parameters or param groups for the optimizer; required if optimizer is a callable(params).
+    :param optimizer: a callable(parameters) -> pytorch.optim.Optimizer or a pre-initialized PyTorch optimizer.
+      **Note:** some advanced options like offload_optimizer, delay_optimizer_step, or delay_grad_averaging
+      require the callable form and will not work if hivemind.Optimizer is created with a pre-existing PyTorch Optimizer.
+    :param scheduler: callable(optimizer) -> PyTorch LRScheduler or a pre-initialized PyTorch scheduler.
+      The learning rate scheduler will adjust learning rate based on global epoch, not the number of
+      local calls to optimizer.step; this is required to keep different peers synchronized.
+
+    :param matchmaking_time: when looking for group, wait for peers to join for up to this many seconds.
+      Increase if you see "averaged gradients with N peers" where N is below 0.9x the actual swarm size on >=25% of epochs.
+      When training with low-latency network, decreasing matchmaking_time allows training with smaller batch sizes.
+    :param averaging_timeout: if an averaging step hangs for this long, it will be cancelled automatically.
+      Increase averaging_timeout if you see "Proceeding with local gradients" at least 25% of the time.
+      Do not set this timeout too high, as it may cause your optimizer to hang after some types of network errors.
+    :param allreduce_timeout: timeout for a single attempt to run all-reduce, default: equal to averaging_timeout.
+    :param load_state_timeout: wait for at most this many seconds before giving up on load_state_from_peers.
+    :param reuse_grad_buffers: if True, use model's .grad buffers for gradient accumulation.
+      This is more memory efficient, but it requires that the user does *NOT* call model/opt zero_grad at all
+
+    :param offload_optimizer: offload the optimizer to host memory, saving GPU memory for parameters and gradients
+    :param delay_optimizer_step: run optimizer in background, apply results in future .step; requires offload_optimizer
+    :param delay_grad_averaging: average gradients in background; requires offload_optimizer and delay_optimizer_step
+
+    :param delay_state_averaging: if enabled (default), average parameters and extra tensors in a background thread;
+      if set to False, average parameters synchronously within the corresponding hivemind.Optimizer.step call.
+
+    :param average_state_every: average state (parameters, chosen opt tensors) with peers every this many **epochs**.
+      Increasing this value reduces communication overhead but can cause parameters to diverge if it is too large.
+      The maximal usable average_state_every depends on how often peers diverge from each other. If peers
+      hardly ever skip averaging rounds, they can average state less frequently. In turn, network failures, lossy
+      gradient compression and local_updates cause parameters to diverge faster and require more frequent averaging.
+
+    :param use_local_updates: if enabled, peers will update parameters on each .step using local gradients;
+      if not enabled (default), accumulate gradients to target_batch_size, and then call .step with averaged gradients.
+      Even if use_local_updates=True, learning rate scheduler will still be called once per target_batch_size.
+
+    :param client_mode: if True, this peer will not accept incoming connections (firewall-compatible mode)
+    :param auxiliary: if True, optimizer.step will only assist other peers in averaging (for cpu-only workers)
+
+    :param grad_compression: compression strategy used for averaging gradients, default = no compression
+    :param state_averaging_compression: compression for averaging params and state tensors, default = no compression
+    :param load_state_compression: compression strategy for loading state from peers, default = no compression
+    :param average_opt_statistics: names of optimizer statistics from state dict that should be averaged with peers
+    :param extra_tensors: if specified, these extra tensors will also be averaged and shared in load_state_from_peers.
+
+    :param averager_opts: additional keyword arguments forwarded to both GradientAverager and TrainingStateAverager
+    :param tracker_opts: additional keyword arguments forwarded to ProgressTracker
+    :param performance_ema_alpha: moving average alpha in ProgressTracker, TrainingStateAverager and Optimizer
+    :param verbose: if True, report internal events such as accumulating gradients and running background tasks
+
+    :note: in a large-scale training run, peers will inevitably fail and you will see error messages. hivemind.Optimizer
+      is designed to recover from such failures, but will sometimes need a minute or two to re-adjust.
+
+    """
+
+    def __init__(
+        self,
+        *,
+        dht: DHT,
+        run_id: str,
+        target_batch_size: int,
+        batch_size_per_step: Optional[int] = None,
+        optimizer: Union[TorchOptimizer, OptimizerFactory],
+        params: Optional[Union[Parameters, ParamGroups]] = None,
+        scheduler: Optional[Union[LRSchedulerBase, SchedulerFactory]] = None,
+        matchmaking_time: Optional[float] = 15.0,
+        averaging_timeout: Optional[float] = 60.0,
+        allreduce_timeout: Optional[float] = None,
+        next_chunk_timeout: Optional[float] = None,
+        load_state_timeout: float = 600.0,
+        reuse_grad_buffers: bool = False,
+        offload_optimizer: Optional[bool] = None,
+        delay_optimizer_step: Optional[bool] = None,
+        delay_grad_averaging: bool = False,
+        delay_state_averaging: bool = True,
+        average_state_every: int = 1,
+        use_local_updates: bool = False,
+        client_mode: bool = None,
+        auxiliary: bool = False,
+        grad_compression: CompressionBase = NoCompression(),
+        state_averaging_compression: CompressionBase = NoCompression(),
+        load_state_compression: CompressionBase = NoCompression(),
+        average_opt_statistics: Sequence[str] = (),
+        extra_tensors: Sequence[torch.Tensor] = (),
+        averager_opts: Optional[dict] = None,
+        tracker_opts: Optional[dict] = None,
+        performance_ema_alpha: float = 0.1,
+        shutdown_timeout: float = 5,
+        verbose: bool = False,
+    ):
+        self._parent_pid = os.getpid()
+
+        client_mode = client_mode if client_mode is not None else dht.client_mode
+        delay_optimizer_step = delay_optimizer_step if delay_optimizer_step is not None else delay_grad_averaging
+        offload_optimizer = offload_optimizer if offload_optimizer is not None else (params is not None)
+        allreduce_timeout = allreduce_timeout if allreduce_timeout is not None else averaging_timeout
+        next_chunk_timeout = next_chunk_timeout if next_chunk_timeout is not None else matchmaking_time
+        assert not delay_grad_averaging or delay_optimizer_step, "delay_grad_averaging requires delay_optimizer_step"
+        assert not (client_mode and auxiliary), "Client-mode peers cannot serve as auxiliaries"
+        assert not auxiliary or batch_size_per_step is None, "Auxiliary peers should not accumulate batches"
+        if callable(optimizer) and params is not None:
+            if scheduler is not None and (not callable(scheduler) or isinstance(scheduler, LRSchedulerBase)):
+                raise ValueError("For this mode, please provide scheduler factory: callable(optimizer) -> scheduler")
+        elif all(hasattr(optimizer, attr) for attr in ("param_groups", "step", "zero_grad")):
+            if offload_optimizer or delay_optimizer_step or delay_grad_averaging:
+                raise ValueError(
+                    "To enable offload_optimizer or delayed updates, please initialize Optimizer as "
+                    "hivemind.Optimizer(..., params=params, optimizer=lambda params: create_opt(params)"
+                )
+        else:
+            raise ValueError(
+                "Please initialize the optimizer in one of the following two ways:\n"
+                "(A) hivemind.Optimizer(..., params=params, optimizer=lambda params: create_opt(params)\n"
+                "(B) hivemind.Optimizer(..., optimizer=pre_initialize_optimizer)"
+            )
+        if use_local_updates:
+            assert not reuse_grad_buffers, "if local_updates is True, gradients will not be accumulated"
+            assert not delay_grad_averaging, "if local_updates is True, gradients will not be averaged"
+
+        self.dht, self.run_id, self.client_mode, self.auxiliary = dht, run_id, client_mode, auxiliary
+        self.batch_size_per_step, self.target_batch_size = batch_size_per_step, target_batch_size
+        self.delay_state_averaging, self.average_state_every = delay_state_averaging, average_state_every
+        self.matchmaking_time, self.offload_optimizer = matchmaking_time, offload_optimizer
+        self.delay_grad_averaging, self.delay_optimizer_step = delay_grad_averaging, delay_optimizer_step
+
+        self.averaging_timeout, self.allreduce_timeout = averaging_timeout, allreduce_timeout
+        self.load_state_timeout, self.shutdown_timeout = load_state_timeout, shutdown_timeout
+        self.next_chunk_timeout = next_chunk_timeout
+
+        self.status_loglevel = logging.INFO if verbose else logging.DEBUG
+        self.scheduled_grads: Optional[StepControl] = None
+        self.scheduled_state: Optional[StepControl] = None
+
+        self.tracker = self._make_progress_tracker(
+            target_batch_size, performance_ema_alpha=performance_ema_alpha, **tracker_opts or {}
+        )
+        self.state_averager = self._make_state_averager(
+            optimizer=optimizer,
+            params=params,
+            scheduler=scheduler,
+            delta_rule_averaging=use_local_updates and self.delay_state_averaging,
+            compression=state_averaging_compression,
+            state_compression=load_state_compression,
+            average_opt_statistics=average_opt_statistics,
+            performance_ema_alpha=performance_ema_alpha,
+            extra_tensors=extra_tensors,
+            **averager_opts or {},
+        )
+        if not use_local_updates:
+            self.grad_averager = self._make_gradient_averager(
+                reuse_grad_buffers=reuse_grad_buffers, compression=grad_compression, **averager_opts or {}
+            )
+        else:
+            self.grad_averager = None
+
+        self._should_check_synchronization_on_update = True  # used in self.should_load_state_from_peers
+        self._schema_hash = self._compute_schema_hash()
+
+        self.delay_before_state_averaging = PerformanceEMA(alpha=performance_ema_alpha)
+        # measures the average time from the beginning of self._update_global_epoch to the call to state_averager
+        # used for pre-scheduling the averaging round in state_averager
+
+        self._step_supports_amp_scaling = reuse_grad_buffers
+        # note: the line above is used by pytorch AMP GradScaler to enable custom behavior needed when reusing gradient
+        # buffers over multiple steps (to avoid repeated unscaling). Without reuse_grad_buffers, this is not needed.
+
+    def _make_state_averager(self, **kwargs) -> TrainingStateAverager:
+        return TrainingStateAverager(
+            dht=self.dht,
+            prefix=f"{self.run_id}_state_averager",
+            min_matchmaking_time=self.matchmaking_time,
+            allreduce_timeout=self.allreduce_timeout,
+            shutdown_timeout=self.shutdown_timeout,
+            offload_optimizer=self.offload_optimizer,
+            custom_gradients=self.offload_optimizer,
+            status_loglevel=self.status_loglevel,
+            next_chunk_timeout=self.next_chunk_timeout,
+            client_mode=self.client_mode,
+            auxiliary=self.auxiliary,
+            start=True,
+            **kwargs,
+        )
+
+    def _make_gradient_averager(self, **kwargs) -> GradientAverager:
+        assert hasattr(self, "state_averager"), "must initialize state averager first"
+        grad_averager = GradientAverager(
+            dht=self.dht,
+            prefix=f"{self.run_id}_grad_averager",
+            parameters=self.state_averager.main_parameters,
+            min_matchmaking_time=self.matchmaking_time,
+            allreduce_timeout=self.allreduce_timeout,
+            shutdown_timeout=self.shutdown_timeout,
+            next_chunk_timeout=self.next_chunk_timeout,
+            client_mode=self.client_mode,
+            auxiliary=self.auxiliary,
+            start=True,
+            **kwargs,
+        )
+        if self.offload_optimizer:
+            optimized_param_groups = self.state_averager.optimizer.param_groups
+            optimized_parameters = [param for group in optimized_param_groups for param in group["params"]]
+            with grad_averager.get_tensors() as averaged_gradients:
+                assert len(averaged_gradients) == len(optimized_parameters)
+                for opt_param, averaged_grad in zip(optimized_parameters, averaged_gradients):
+                    opt_param.grad = averaged_grad
+        return grad_averager
+
+    def _make_progress_tracker(self, target_batch_size: int, **kwargs) -> ProgressTracker:
+        return ProgressTracker(
+            dht=self.dht,
+            prefix=self.run_id,
+            target_batch_size=target_batch_size,
+            client_mode=self.client_mode,
+            status_loglevel=self.status_loglevel,
+            start=True,
+            **kwargs,
+        )
+
+    def _compute_schema_hash(self) -> int:
+        optimized_param_groups = self.state_averager.optimizer.param_groups
+        optimized_parameters = [param for group in optimized_param_groups for param in group["params"]]
+        param_shapes = tuple(tuple(param.shape) for param in optimized_parameters)
+
+        # offloaded optimizer requires that gradient tensors are reused between iterations
+        grad_ids = tuple(id(param.grad) for param in optimized_parameters) if self.offload_optimizer else None
+        return hash((grad_ids, param_shapes))
+
+    def is_alive(self) -> bool:
+        return self.state_averager.is_alive()
+
+    @property
+    def local_epoch(self) -> int:
+        """
+        This worker's current epoch, kept synchronized with peers. If peer's local_epoch lags behind others, it will
+        automatically re-synchronize by downloading state from another peer.
+        An epoch corresponds to accumulating target_batch_size across all active devices.
+        """
+        return self.state_averager.local_epoch
+
+    @property
+    def local_progress(self) -> LocalTrainingProgress:
+        return self.tracker.local_progress
+
+    @property
+    def use_local_updates(self) -> bool:
+        return self.grad_averager is None
+
+    @property
+    def use_gradient_averaging(self) -> bool:
+        return self.grad_averager is not None
+
+    def step(
+        self,
+        closure: Optional[Callable[[], torch.Tensor]] = None,
+        batch_size: Optional[int] = None,
+        grad_scaler: Optional[GradScaler] = None,
+    ):
+        """
+        Update training progress after accumulating another local batch size. Depending on the configuration, this will
+        report progress to peers, run global or local optimizer step, average parameters or schedule background tasks.
+
+        :param closure: A closure that reevaluates the model and returns the loss.
+        :param batch_size: optional override for batch_size_per_step from init.
+        :param grad_scaler: if amp is enabled, this **must** be a hivemind-aware gradient scaler.
+        :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
+        """
+        if grad_scaler is not None and not isinstance(grad_scaler, GradScaler):
+            raise ValueError("hivemind.Optimizer requires a hivemind-aware gradient scaler (hivemind.GradScaler)")
+        if self.batch_size_per_step is None and batch_size is None and not self.auxiliary:
+            raise ValueError("Please either set batch_size_per_step parameter at init or when calling .step")
+        if self.auxiliary and (closure is not None or batch_size is not None or grad_scaler is not None):
+            raise ValueError("Auxiliary peers should not have batch size, run closures, or use grad_scaler")
+        batch_size = batch_size if batch_size is not None else self.batch_size_per_step
+
+        # if delayed updates finished before step, apply these updates; otherwise do nothing
+        self.state_averager.step(apply_delayed_updates=True)
+
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        if not self.auxiliary and self._should_load_state_from_peers():
+            logger.log(self.status_loglevel, "Peer is out of sync")
+            self.load_state_from_peers()
+            return loss  # local gradients were computed with out-of-sync parameters, must start over
+
+        if self.use_gradient_averaging:
+            # accumulate gradients toward target batch size, then aggregate with peers and run optimizer
+            if not self.auxiliary:
+                grads_are_valid = self._check_and_accumulate_gradients(batch_size, grad_scaler)
+                if not grads_are_valid:
+                    return loss  # local gradients were reset due to overflow, must start over
+
+            self._maybe_schedule_gradient_averaging()
+            self._maybe_schedule_state_averaging()
+
+        else:
+            # use_local_updates=True: update parameters on every step independently of other peers
+            if not self.auxiliary:
+                if grad_scaler is not None:
+                    with grad_scaler.running_global_step():
+                        assert grad_scaler.unscale_(self)
+
+                new_samples_accumulated = self.tracker.local_progress.samples_accumulated + batch_size
+                self.tracker.report_local_progress(self.local_epoch, new_samples_accumulated)
+                self._maybe_schedule_state_averaging()
+
+                self.state_averager.step(
+                    increment_epoch=False,
+                    optimizer_step=True,
+                    delay_optimizer_step=self.delay_optimizer_step,
+                    grad_scaler=grad_scaler,
+                )
+
+        if self.tracker.ready_to_update_epoch:
+            self._update_global_epoch(grad_scaler)
+
+        return loss
+
+    def _update_global_epoch(self, grad_scaler: Optional[GradScaler]) -> None:
+        """Depending on the configuration: aggregate gradients and/or parameters, perform global optimizer step"""
+        assert self._schema_hash == self._compute_schema_hash(), "parameters or gradients changed during iteration"
+        _epoch_start_time = time.perf_counter()
+
+        with self.tracker.pause_updates():
+            wait_for_trigger = None
+
+            if self.use_gradient_averaging:
+                logger.log(self.status_loglevel, f"Beginning optimizer step #{self.local_epoch}")
+                if self.delay_optimizer_step:
+                    self.state_averager.step(wait_for_delayed_updates=True)
+
+                began_averaging_gradients = self._begin_averaging_gradients(grad_scaler)
+                if not began_averaging_gradients:
+                    # failed to start gradient averaging due to an internal error
+                    self.grad_averager.load_accumulators_into_averager_()
+                elif self.delay_grad_averaging:
+                    # if using delayed grad averaging, send this to state_averager as a pre-condition for optimizer step
+                    wait_for_trigger = partial(self._average_gradients_and_load_into_optimizer, self.scheduled_grads)
+                else:
+                    # delay_grad_averaging=False, average gradients immediately
+                    self._average_gradients_and_load_into_optimizer(self.scheduled_grads)
+
+            next_epoch = max(self.local_epoch + 1, self.tracker.global_epoch)
+            swarm_not_empty = self.tracker.global_progress.num_peers > 1
+            should_perform_optimizer_step = not self.auxiliary and not self.use_local_updates
+            should_average_state = (
+                swarm_not_empty
+                and next_epoch % self.average_state_every == 0
+                and not self.state_averager.averaging_in_progress
+            )
+
+            if should_average_state and self.scheduled_state is not None:
+                if self.scheduled_state.triggered or self.scheduled_state.done():
+                    logger.log(
+                        self.status_loglevel,
+                        f"Not using pre-scheduled group for state averaging because it"
+                        f"was already used elsewhere: {self.scheduled_state}",
+                    )
+                    self.scheduled_state = None
+                self.delay_before_state_averaging.update(task_size=1, interval=time.perf_counter() - _epoch_start_time)
+
+            self.state_averager.step(
+                increment_epoch=True,
+                wait_for_trigger=wait_for_trigger,
+                optimizer_step=should_perform_optimizer_step,
+                delay_optimizer_step=self.delay_optimizer_step and should_perform_optimizer_step,
+                grad_scaler=grad_scaler,
+                averaging_round=should_average_state,
+                delay_averaging=self.delay_state_averaging and not self.auxiliary,
+                averaging_control=self.scheduled_state if should_average_state else None,
+                averaging_opts=dict(timeout=self.averaging_timeout) if should_average_state else None,
+            )
+
+            if not should_average_state and self.scheduled_state is not None and not self.scheduled_state.done():
+                self.scheduled_state.cancel()
+            self.scheduled_state = None
+
+            self.tracker.update_epoch(new_epoch=self.state_averager.local_epoch)
+            self._should_check_synchronization_on_update = True
+            # the above line ensures that peers check for *strict* synchronization once per epoch
+
+            if not self.client_mode:
+                self.state_averager.state_sharing_priority = self.local_epoch
+
+            if self.use_gradient_averaging and not self.auxiliary:
+                self.grad_averager.reset_accumulated_grads_()
+                if not self.client_mode:
+                    self.grad_averager.state_sharing_priority = self.local_epoch
+
+            logger.log(self.status_loglevel, f"Transitioning to epoch {self.local_epoch}")
+
+    def _begin_averaging_gradients(self, grad_scaler: Optional[GradScaler]) -> bool:
+        """Begin an all-reduce round to average gradients; return True if succeeded, False if failed"""
+        if grad_scaler is not None:
+            with grad_scaler.running_global_step():
+                assert grad_scaler.unscale_(self)
+
+        began_averaging_gradients = False
+        if self.scheduled_grads is not None and (self.scheduled_grads.triggered or self.scheduled_grads.done()):
+            logger.log(
+                self.status_loglevel,
+                f"Not using pre-scheduled group for state averaging because it"
+                f"was already used elsewhere: {self.scheduled_state}",
+            )
+            self.scheduled_grads = None
+
+        elif self.tracker.global_progress.num_peers > 1:
+            try:
+                self.scheduled_grads = self.grad_averager.step(
+                    control=self.scheduled_grads, reset_accumulators=True, wait=False
+                )
+                began_averaging_gradients = True
+            except BaseException as e:
+                logger.exception(e)
+
+        if not began_averaging_gradients and self.scheduled_grads is not None and not self.scheduled_grads.done():
+            if self.tracker.global_progress.num_peers > 1:
+                logger.log(self.status_loglevel, f"Tagging along for a pre-scheduled gradient averaging round")
+                self._tag_along_with_zero_weight(self.scheduled_grads)
+            else:
+                logger.log(self.status_loglevel, f"Skipping pre-scheduled averaging round: there are no other peers")
+                self._load_local_gradients_into_optimizer()
+                self.scheduled_grads.cancel()
+            self.scheduled_grads = None
+        return began_averaging_gradients
+
+    def _check_and_accumulate_gradients(self, batch_size: int, grad_scaler: Optional[GradScaler]) -> bool:
+        """Check if gradients are valid, accumulate and return True; otherwise, reset and return False"""
+        assert not self.use_local_updates and not self.auxiliary
+        if grad_scaler is not None and not grad_scaler.are_grads_finite(self):
+            logger.log(self.status_loglevel, "Encountered incorrect value in fp16 grads, resetting local gradients")
+            self.tracker.report_local_progress(self.local_epoch, samples_accumulated=0)
+            self.grad_averager.reset_accumulated_grads_()
+            return False
+
+        self.grad_averager.accumulate_grads_(batch_size)
+        self.tracker.report_local_progress(self.local_epoch, self.grad_averager.local_samples_accumulated)
+        return True
+
+    def _maybe_schedule_gradient_averaging(self) -> None:
+        """If next epoch is coming soon, schedule the next gradient averaging round at the estimated end of epoch"""
+        assert self.use_gradient_averaging
+        if self.tracker.estimated_next_update_time - get_dht_time() <= self.matchmaking_time:
+            if self.scheduled_grads is None or self.scheduled_grads.triggered or self.scheduled_grads.done():
+                eta_seconds = self.tracker.estimated_next_update_time - get_dht_time()
+                eta_seconds = max(eta_seconds, self.grad_averager.matchmaking_kwargs["min_matchmaking_time"])
+                logger.log(self.status_loglevel, f"Pre-scheduling gradient averaging round in {eta_seconds:.2f} sec")
+                self.scheduled_grads = self.grad_averager.schedule_step(timeout=self.averaging_timeout)
+
+    def _maybe_schedule_state_averaging(self) -> None:
+        """If next epoch is coming soon, schedule the next state averaging at estimated parameter averaging start"""
+        next_epoch = max(self.local_epoch + 1, self.tracker.global_epoch)
+        if next_epoch % self.average_state_every != 0:
+            return  # averaging is not performed at this epoch
+        if self.state_averager.averaging_in_progress:
+            return  # previous run is still in progress
+        if self.delay_before_state_averaging.num_updates == 0:
+            return  # not enough data to accurately pre-schedule
+
+        estimated_time = self.tracker.estimated_next_update_time
+        estimated_time += self.delay_before_state_averaging.ema_seconds_per_sample
+        estimated_time += self.state_averager.delay_before_averaging.ema_seconds_per_sample
+        eta_seconds_to_averaging = estimated_time - get_dht_time()
+
+        if eta_seconds_to_averaging <= self.matchmaking_time:
+            if self.scheduled_state is None or self.scheduled_state.triggered or self.scheduled_state.done():
+                min_matchmaking_time = self.state_averager.matchmaking_kwargs["min_matchmaking_time"]
+                actual_seconds = max(eta_seconds_to_averaging, min_matchmaking_time)
+                logger.log(self.status_loglevel, f"Pre-scheduling state averaging round in {actual_seconds:.2f} sec")
+                self.scheduled_state = self.state_averager.schedule_step(
+                    gather=next_epoch, timeout=self.averaging_timeout
+                )
+
+    def _average_gradients_and_load_into_optimizer(self, maybe_step_control: Optional[StepControl]):
+        """Run gradient averaging; on success, feed averaged gradients into optimizer; else, use local gradients"""
+        assert self.use_gradient_averaging and (maybe_step_control is None or maybe_step_control.triggered)
+        averaged_gradients = False
+
+        try:
+            if maybe_step_control is not None:
+                group_info = maybe_step_control.result(self.averaging_timeout)
+                logger.log(self.status_loglevel, f"Averaged gradients with {len(group_info)} peers")
+                self._load_averaged_gradients_into_optimizer_()
+                averaged_gradients = True
+            else:
+                logger.log(self.status_loglevel, f"Skipped averaging: there are no other peers")
+        except BaseException as e:
+            logger.log(self.status_loglevel, f"Averaging gradients failed with {repr(e)}")
+
+        if not averaged_gradients:
+            self._load_local_gradients_into_optimizer()
+
+    def _load_averaged_gradients_into_optimizer_(self):
+        """If required, load averaged gradients into optimizer; otherwise simply notify grad averager"""
+        assert self.use_gradient_averaging
+
+        if self.offload_optimizer:
+            pass  # averaged gradients are already baked into optimizer, see _make_gradient_averager
+        else:
+            # copy averaged gradients into optimizer .grad buffers
+            optimized_param_groups = self.state_averager.optimizer.param_groups
+            optimized_parameters = [param for group in optimized_param_groups for param in group["params"]]
+            with torch.no_grad(), self.grad_averager.get_tensors() as averaged_gradients:
+                assert len(averaged_gradients) == len(optimized_parameters)
+                for opt_param, averaged_grad in zip(optimized_parameters, averaged_gradients):
+                    opt_param.grad.copy_(averaged_grad, non_blocking=True)
+
+        self.grad_averager.notify_used_averaged_gradients()
+
+    def _load_local_gradients_into_optimizer(self):
+        """Fallback to using local gradients in the optimizer (instead of averaged gradients)"""
+        logger.log(self.status_loglevel, f"Proceeding with local gradients")
+        self.grad_averager.load_accumulators_into_averager_()
+        # note: we load gradients into grad_averager even though there is only one peer because of two reasons:
+        # - if offload_optimizer, then we must load gradients onto the CPU gradient buffers used by the optimizer
+        # - if not offload_optimizer, we must un-scale gradients (divide them by the number of accumulation steps)
+        self._load_averaged_gradients_into_optimizer_()
+
+    def zero_grad(self, set_to_none: bool = False):
+        """Reset gradients from model. If reuse_grad_buffers=True, this will raise an error."""
+        if self.use_gradient_averaging and self.grad_averager.reuse_grad_buffers:
+            raise ValueError(
+                f"When running {self.__class__.__name__} with reuse_grad_buffers=True, user should never "
+                f"call zero_grad manually. Gradients will be refreshed internally"
+            )
+        for param_group in self.param_groups:
+            for param in param_group["params"]:
+                if param.grad is None:
+                    pass
+                elif set_to_none:
+                    param.grad = None
+                else:
+                    param.grad.zero_()
+
+    def _should_load_state_from_peers(self) -> bool:
+        """
+        If true, peer will discard local progress and attempt to download state from peers.
+        This method allows the peer to continue training in two cases:
+         - peer is on the same epoch as other collaborators - keep training normally
+         - peer was on the same epoch and accumulated some grads, but some collaborators
+             have just transitioned to the next epoch - this peer should also transition.
+
+        :note: The latter case occurs due to the lack of network synchrony: the first peer that
+        detects enough samples will transition to the next step and start counting samples anew.
+        Some other peers may take time before they check with DHT and observe that
+          - the global epoch is technically one epoch ahead of the current one and
+          - the remaining (non-transitioned) peers no longer have target_batch_size between them
+        If this is the case, the peer should transition to the next epoch and does *not* need to re-load state.
+        """
+        if self._should_check_synchronization_on_update and self.tracker.fetched_global_progress_this_epoch.is_set():
+            self._should_check_synchronization_on_update = False
+            return self.local_epoch != self.tracker.global_epoch  # require exact synchronization once per step
+        return self.local_epoch < self.tracker.global_epoch - 1  # catch up if a peer just switched to next epoch
+
+    def is_synchronized_with_peers(self) -> bool:
+        """Checks whether the current peer is up-to-date with others in terms of the epoch (step) number."""
+        return self.local_epoch >= self.tracker.global_epoch - 1
+
+    def load_state_from_peers(self, **kwargs):
+        """
+        Attempt to load the newest collaboration state from other peers within the same run_id.
+
+        If successful, this will update parameters, optimizer state, local epoch and learning rate schedule in-place.
+        """
+        # note: we tag along for the next all-reduce because the run may have already started and cancelling it
+        # will cause peers to restart matchmaking and may stall the entire collaboration for a few seconds.
+        if self.scheduled_grads is not None and not self.scheduled_grads.done():
+            self._tag_along_with_zero_weight(self.scheduled_grads)
+            self.scheduled_grads = None
+        self.state_averager.step(wait_for_delayed_updates=True)
+
+        with self.tracker.pause_updates():
+            while True:
+                try:
+                    self.state_averager.load_state_from_peers(timeout=self.load_state_timeout, **kwargs)
+                    break
+                except KeyboardInterrupt:
+                    raise
+                except BaseException as e:
+                    logger.exception(f"Failed to load state from peers: {e}, retrying ...")
+                    continue
+
+            if self.tracker.global_epoch - 1 <= self.local_epoch < self.tracker.global_epoch:
+                logger.log(self.status_loglevel, f"Catching up with collaboration step {self.tracker.global_epoch}")
+                self.state_averager.local_epoch = self.tracker.global_epoch
+
+            self.tracker.report_local_progress(local_epoch=self.local_epoch, samples_accumulated=0)
+
+            if not self.client_mode:
+                self.state_averager.state_sharing_priority = self.local_epoch
+
+            if self.use_gradient_averaging:
+                self.grad_averager.reset_accumulated_grads_()
+                if not self.client_mode:
+                    self.grad_averager.state_sharing_priority = self.local_epoch
+
+    def state_dict(self) -> dict:
+        state_dict = self.state_averager.optimizer.state_dict()
+        state_dict["state"]["local_epoch"] = self.local_epoch
+        return state_dict
+
+    def load_state_dict(self, state_dict: dict):
+        if "local_epoch" in state_dict["state"]:
+            self.state_averager.local_epoch = state_dict["state"].pop("local_epoch")
+        return self.state_averager.optimizer.load_state_dict(state_dict)
+
+    @property
+    def state(self):
+        return dict(self.state_averager.optimizer.state, local_epoch=self.local_epoch)
+
+    @property
+    def opt(self) -> TorchOptimizer:
+        return self.state_averager.optimizer
+
+    @property
+    def param_groups(self) -> ParamGroups:
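+        # rebuild param groups so that they reference the original (main) model parameters rather than the
+        # optimizer's internal copies (which may live on CPU when offload_optimizer is used); per-group
+        # options such as lr or betas are carried over unchanged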
+        next_index = 0
+        param_groups = tuple(dict(param_group) for param_group in self.state_averager.optimizer.param_groups)
+        for param_group in param_groups:
+            num_params = len(param_group["params"])
+            main_params_for_group = self.state_averager.main_parameters[next_index : next_index + num_params]
+            param_group["params"] = main_params_for_group
+            next_index += num_params
+        assert next_index == len(self.state_averager.main_parameters)
+        return param_groups
+
+    def add_param_group(self, param_group: dict) -> None:
+        raise ValueError(
+            f"{self.__class__.__name__} does not support calling add_param_group after creation. "
+            f"Please provide all parameter groups at init"
+        )
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(prefix={self.run_id}, epoch={self.local_epoch})"
+
+    def _tag_along_with_zero_weight(self, control: StepControl):
+        """Wait for a running averaging round to finish with zero weight."""
+        if not control.triggered:
+            control.weight = 0
+            control.allow_allreduce()
+        if not control.done():
+            try:
+                control.result(self.averaging_timeout)
+            except BaseException as e:
+                logger.exception(e)
+                if not control.done():
+                    control.cancel()
+
+    def shutdown(self):
+        logger.log(self.status_loglevel, "Sending goodbye to peers...")
+        self.tracker.shutdown(self.shutdown_timeout)
+        self.state_averager.step(wait_for_delayed_updates=True)
+        for scheduled_round in self.scheduled_grads, self.scheduled_state:
+            if scheduled_round is not None:
+                if scheduled_round.stage == AveragingStage.LOOKING_FOR_GROUP:
+                    scheduled_round.cancel()
+                else:
+                    self._tag_along_with_zero_weight(scheduled_round)
+
+        logger.log(self.status_loglevel, "Shutting down averagers...")
+        self.state_averager.shutdown()
+        if self.use_gradient_averaging:
+            self.grad_averager.shutdown()
+        logger.log(self.status_loglevel, f"{self.__class__.__name__} is shut down")
+
+    def __del__(self):
+        if self._parent_pid == os.getpid() and self.is_alive():
+            self.shutdown()
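For orientation, here is a minimal usage sketch that exercises the methods added above (load_state_from_peers, step, zero_grad, state_dict, shutdown). The constructor arguments shown (run_id, target_batch_size, batch_size_per_step, optimizer, params) are assumptions inferred from the surrounding code and docstrings rather than a verbatim copy of the final API:

    import torch
    import hivemind

    dht = hivemind.DHT(start=True)
    model = torch.nn.Linear(16, 2)
    opt = hivemind.Optimizer(                       # hypothetical argument names, see note above
        dht=dht,
        run_id="demo_run",                          # peers with the same run_id train together
        target_batch_size=4096,                     # global samples accumulated per epoch
        batch_size_per_step=32,                     # samples contributed by each local step
        optimizer=lambda params: torch.optim.SGD(params, lr=0.1),
        params=model.parameters(),
    )
    opt.load_state_from_peers()                     # catch up with the swarm before training

    for _ in range(10):
        x, y = torch.randn(32, 16), torch.randint(2, (32,))
        loss = torch.nn.functional.cross_entropy(model(x), y)
        loss.backward()
        opt.step()                                  # reports progress, averages and advances the epoch when ready
        opt.zero_grad()                             # allowed here because reuse_grad_buffers is not enabled

    torch.save(opt.state_dict(), "checkpoint.pt")   # state_dict() includes "local_epoch" (see state_dict above)
    opt.shutdown()
    dht.shutdown()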

+ 0 - 41
hivemind/optim/performance_ema.py

@@ -1,41 +0,0 @@
-from contextlib import contextmanager
-
-from hivemind.utils import get_dht_time
-
-
-class PerformanceEMA:
-    """
-    A running estimate of performance (operations/sec) using adjusted exponential moving average
-    :param alpha: Smoothing factor in range [0, 1], [default: 0.1].
-    """
-
-    def __init__(self, alpha: float = 0.1, eps: float = 1e-20):
-        self.alpha, self.eps, self.num_updates = alpha, eps, 0
-        self.ema_seconds_per_sample, self.samples_per_second = 0, eps
-        self.timestamp = get_dht_time()
-        self.paused = False
-
-    def update(self, num_processed: int) -> float:
-        """
-        :param num_processed: how many items were processed since last call
-        :returns: current estimate of performance (samples per second), but at most
-        """
-        assert not self.paused, "PerformanceEMA is currently paused"
-        assert num_processed > 0, f"Can't register processing {num_processed} samples"
-        self.timestamp, old_timestamp = get_dht_time(), self.timestamp
-        seconds_per_sample = max(0, self.timestamp - old_timestamp) / num_processed
-        self.ema_seconds_per_sample = self.alpha * seconds_per_sample + (1 - self.alpha) * self.ema_seconds_per_sample
-        self.num_updates += 1
-        adjusted_seconds_per_sample = self.ema_seconds_per_sample / (1 - (1 - self.alpha) ** self.num_updates)
-        self.samples_per_second = 1 / max(adjusted_seconds_per_sample, self.eps)
-        return self.samples_per_second
-
-    @contextmanager
-    def pause(self):
-        """While inside this context, EMA will not count the time passed towards the performance estimate"""
-        self.paused, was_paused = True, self.paused
-        try:
-            yield
-        finally:
-            self.timestamp = get_dht_time()
-            self.paused = was_paused

+ 363 - 0
hivemind/optim/progress_tracker.py

@@ -0,0 +1,363 @@
+import asyncio
+import contextlib
+import logging
+import threading
+from dataclasses import dataclass
+from typing import Dict, Optional
+
+import numpy as np
+from pydantic import BaseModel, StrictBool, StrictFloat, confloat, conint
+
+from hivemind.dht import DHT
+from hivemind.dht.schema import BytesWithPublicKey, RSASignatureValidator, SchemaValidator
+from hivemind.utils import DHTExpiration, ValueWithExpiration, enter_asynchronously, get_dht_time, get_logger
+from hivemind.utils.crypto import RSAPrivateKey
+from hivemind.utils.performance_ema import PerformanceEMA
+
+logger = get_logger(__name__)
+
+
+@dataclass(frozen=False)
+class GlobalTrainingProgress:
+    epoch: int
+    samples_accumulated: int
+    target_batch_size: int
+    num_peers: int
+    num_clients: int
+    eta_next_epoch: float
+    next_fetch_time: float
+
+
+class LocalTrainingProgress(BaseModel):
+    peer_id: bytes
+    epoch: conint(ge=0, strict=True)
+    samples_accumulated: conint(ge=0, strict=True)
+    samples_per_second: confloat(ge=0.0, strict=True)
+    time: StrictFloat
+    client_mode: StrictBool
+
+
+class TrainingProgressSchema(BaseModel):
+    progress: Dict[BytesWithPublicKey, Optional[LocalTrainingProgress]]
+
+
+class ProgressTracker(threading.Thread):
+    """
+    Auxiliary class that keeps track of local & global training progress, measured in epochs.
+    An epoch is incremented once the collaboration accumulates a given number of samples (target_batch_size).
+    Similarly to a PyTorch LR scheduler, an epoch may correspond to a single optimizer update or to many local updates.
+
+    :param min_refresh_period: wait for at least this many seconds before fetching new collaboration state
+    :param max_refresh_period: wait for at most this many seconds before fetching new collaboration state
+    :param default_refresh_period: if no peers are detected, attempt to fetch collaboration state this often (seconds)
+    :param expected_drift_peers: assume that this many new peers can join between epochs
+    :param expected_drift_rate: assume that this fraction of the current collaboration can join/leave between epochs
+    :note: The expected collaboration drift parameters are used to adjust the frequency with which this tracker will
+      refresh the collaboration-wide statistics (to avoid missing the moment when peers transition to the next epoch)
+    :param performance_ema_alpha: smoothing value used to estimate this peer's performance (samples per second)
+    :param metadata_expiration: peer's metadata (e.g. samples processed) is stored onto DHT for this many seconds
+
+    Example:
+
+    >>> tracker = ProgressTracker(hivemind.DHT(...), prefix="my_experiment_with_several_peers", target_batch_size=100)
+    >>> local_epoch, local_samples = 0, 0
+    >>> while True:
+    >>>     accumulate_gradients(batch_size=32)
+    >>>     local_samples += 32
+    >>>     tracker.report_local_progress(local_epoch, local_samples)
+    >>>     if local_epoch < tracker.global_progress.epoch:
+    >>>         download_state_from_peers()  # if peer is out of sync, synchronize it with the swarm
+    >>>     if tracker.accumulated_enough_samples:
+    >>>         with tracker.pause_updates():
+    >>>             aggregate_gradients_with_peers()
+    >>>             update_model_parameters()
+    >>>             local_epoch = tracker.update_epoch(local_epoch + 1)
+    >>>             local_samples = 0
+    """
+
+    def __init__(
+        self,
+        dht: DHT,
+        prefix: str,
+        target_batch_size: int,
+        *,
+        client_mode: Optional[bool] = None,
+        min_refresh_period: float = 0.5,
+        max_refresh_period: float = 10,
+        default_refresh_period: float = 3,
+        expected_drift_peers: float = 3,
+        expected_drift_rate: float = 0.2,
+        performance_ema_alpha: float = 0.1,
+        metadata_expiration: float = 60.0,
+        status_loglevel: int = logging.DEBUG,
+        private_key: Optional[RSAPrivateKey] = None,
+        daemon: bool = True,
+        start: bool,
+    ):
+        client_mode = client_mode if client_mode is not None else dht.client_mode
+        self.dht, self.prefix, self.client_mode = dht, prefix, client_mode
+        self.training_progress_key = f"{self.prefix}_progress"
+        self.target_batch_size = target_batch_size
+        self.min_refresh_period, self.max_refresh_period = min_refresh_period, max_refresh_period
+        self.default_refresh_period = default_refresh_period
+        self.expected_drift_peers, self.expected_drift_rate = expected_drift_peers, expected_drift_rate
+        self.status_loglevel = status_loglevel
+        self.performance_ema = PerformanceEMA(alpha=performance_ema_alpha)
+        self.metadata_expiration = metadata_expiration
+
+        signature_validator = RSASignatureValidator(private_key)
+        self._local_public_key = signature_validator.local_public_key
+        dht.add_validators([SchemaValidator(TrainingProgressSchema, prefix=prefix), signature_validator])
+
+        # report the collaboration progress periodically or in background
+        self.local_progress = self._get_local_progress(local_epoch=0, samples_accumulated=0)
+        metadata, _expiration = self.dht.get(self.training_progress_key, latest=True) or (None, -float("inf"))
+        self.global_progress = self._parse_swarm_progress_data(metadata)
+        self.lock_global_progress, self.global_state_updated = threading.Lock(), threading.Event()
+        self.should_report_progress, self.fetched_global_progress_this_epoch = threading.Event(), threading.Event()
+        self.shutdown_triggered, self.shutdown_complete = threading.Event(), threading.Event()
+        super().__init__(name=f"{self.__class__.__name__}({self.prefix})", daemon=daemon)
+        if start:
+            self.start()
+
+    @property
+    def global_epoch(self) -> int:
+        return self.global_progress.epoch
+
+    @property
+    def ready_to_update_epoch(self) -> bool:
+        """Whether or not this peer can increment epoch right away."""
+        return (
+            self.global_epoch > self.local_progress.epoch
+            or self.global_progress.samples_accumulated >= self.target_batch_size
+            or get_dht_time() >= self.global_progress.eta_next_epoch
+        )
+
+    @property
+    def estimated_next_update_time(self) -> DHTExpiration:
+        """Estimate (absolute) time when this peer should increment epoch"""
+        if self.ready_to_update_epoch:
+            return get_dht_time()
+        return self.global_progress.eta_next_epoch
+
+    def _get_local_progress(self, local_epoch: int, samples_accumulated: int):
+        return LocalTrainingProgress(
+            peer_id=self.dht.peer_id.to_bytes(),
+            epoch=local_epoch,
+            samples_accumulated=samples_accumulated,
+            samples_per_second=self.performance_ema.samples_per_second,
+            time=get_dht_time(),
+            client_mode=self.client_mode,
+        )
+
+    def report_local_progress(self, local_epoch: int, samples_accumulated: int, update_global_samples: bool = True):
+        """Update the number of locally accumulated samples and notify to other peers about this."""
+        extra_samples = samples_accumulated - self.local_progress.samples_accumulated
+        if update_global_samples and local_epoch == self.local_progress.epoch == self.global_progress.epoch:
+            self.global_progress.samples_accumulated += extra_samples
+            # note: the above line can decrease the number of samples, e.g. if forced to reset due to overflow
+
+        if extra_samples > 0:
+            self.performance_ema.update(task_size=extra_samples)
+            logger.debug(f"Updated performance EMA: {self.performance_ema.samples_per_second:.5f}")
+        else:
+            logger.debug("Resetting performance timestamp to current time (progress was reset)")
+            self.performance_ema.reset_timer()
+
+        self.local_progress = self._get_local_progress(local_epoch, samples_accumulated)
+        self.should_report_progress.set()
+
+    @contextlib.contextmanager
+    def pause_updates(self):
+        """Temporarily stop progress tracker from updating global training state"""
+        with self.lock_global_progress, self.performance_ema.pause():
+            yield
+
+    def update_epoch(self, new_epoch: Optional[int] = None) -> int:
+        """Update the local epoch, reset the number of sample accumulated, reset local progress, return new epoch"""
+        assert self.lock_global_progress.locked(), "ProgressTracker must be paused when incrementing epoch"
+        if new_epoch is None:
+            new_epoch = self.local_progress.epoch + 1
+        if new_epoch > self.global_progress.epoch:
+            self.global_progress.epoch = new_epoch
+            self.global_progress.samples_accumulated = 0
+            self.global_progress.eta_next_epoch = float("inf")
+        self.report_local_progress(new_epoch, samples_accumulated=0)
+        self.fetched_global_progress_this_epoch.clear()
+        return new_epoch
+
+    def run(self):
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        loop.run_until_complete(asyncio.gather(self._progress_reporter(), self._progress_fetcher()))
+        self.shutdown_complete.set()
+
+    async def _progress_reporter(self):
+        """Periodically publish metadata and the current number of samples accumulated towards the next epoch"""
+        last_report_time = -float("inf")
+        last_report_epoch = -float("inf")
+        store_task = None
+        try:
+            while not self.shutdown_triggered.is_set():
+                wait_timeout = max(0.0, last_report_time - get_dht_time() + self.metadata_expiration / 2)
+                logger.debug(f"Will report progress again in {wait_timeout} seconds or on user command")
+                await asyncio.get_event_loop().run_in_executor(None, self.should_report_progress.wait, wait_timeout)
+                if self.should_report_progress.is_set():
+                    logger.debug(f"Progress update triggered by report_local_progress")
+                    self.should_report_progress.clear()
+                else:
+                    logger.debug(f"Progress update triggered by metadata_expiration")
+
+                local_progress = self.local_progress
+                last_report_time = get_dht_time()
+                if local_progress.samples_accumulated > 0:
+                    last_report_epoch = self.global_epoch
+
+                if last_report_epoch >= self.global_epoch - 1:
+                    # report progress if peer is synchronized and actively reporting samples. Do not report aux peers.
+                    store_task = asyncio.create_task(
+                        asyncio.wait_for(
+                            self.dht.store(
+                                key=self.training_progress_key,
+                                subkey=self._local_public_key,
+                                value=local_progress.dict(),
+                                expiration_time=last_report_time + self.metadata_expiration,
+                                return_future=True,
+                            ),
+                            timeout=self.metadata_expiration,
+                        )
+                    )
+        finally:
+            logger.log(self.status_loglevel, f"No longer reporting progress for {self.prefix}")
+            if store_task is not None:
+                store_task.cancel()
+
+    async def _progress_fetcher(self):
+        """
+        Periodically check the training progress from all peers. Trigger update after target_batch_size total samples
+        """
+        loop = asyncio.get_event_loop()
+        shutdown_checker = asyncio.create_task(
+            asyncio.wait_for(loop.run_in_executor(None, self.shutdown_triggered.wait), None)
+        )
+
+        async def _fetch_progress_unless_shutdown_triggered():
+            """Fetch progress, avoid deadlocks if DHT was shut down before this get finished."""
+            getter = asyncio.create_task(
+                asyncio.wait_for(self.dht.get(self.training_progress_key, latest=True, return_future=True), None)
+            )
+            await asyncio.wait({getter, shutdown_checker}, return_when=asyncio.FIRST_COMPLETED)
+            if self.shutdown_triggered.is_set():
+                return
+            return await getter
+
+        try:
+            while not self.shutdown_triggered.is_set():
+                time_to_next_update = max(0.0, self.global_progress.next_fetch_time - get_dht_time())
+                state_updated_externally = await loop.run_in_executor(
+                    None, self.global_state_updated.wait, time_to_next_update
+                )
+                if state_updated_externally:
+                    self.global_state_updated.clear()
+                    continue
+
+                async with enter_asynchronously(self.lock_global_progress):
+                    maybe_metadata = await _fetch_progress_unless_shutdown_triggered()
+                    if self.shutdown_triggered.is_set():
+                        break
+                    metadata = maybe_metadata.value if isinstance(maybe_metadata, ValueWithExpiration) else None
+                    self.global_progress = self._parse_swarm_progress_data(metadata)
+                    self.fetched_global_progress_this_epoch.set()
+
+        finally:
+            logger.log(self.status_loglevel, f"No longer fetching {self.training_progress_key}")
+
+    def _parse_swarm_progress_data(self, metadata: TrainingProgressSchema) -> GlobalTrainingProgress:
+        """Read performance statistics reported by peers, estimate progress towards next batch"""
+        current_time = get_dht_time()
+
+        if not isinstance(metadata, dict) or len(metadata) == 0:
+            logger.log(self.status_loglevel, f"Found no active peers: {metadata}")
+            samples_remaining_to_next_epoch = max(0, self.target_batch_size - self.local_progress.samples_accumulated)
+            local_eta_next_epoch = samples_remaining_to_next_epoch / self.performance_ema.samples_per_second
+
+            return GlobalTrainingProgress(
+                self.local_progress.epoch,
+                self.local_progress.samples_accumulated,
+                self.target_batch_size,
+                num_peers=0,
+                num_clients=0,
+                eta_next_epoch=current_time + local_eta_next_epoch,
+                next_fetch_time=current_time + self.default_refresh_period,
+            )
+
+        valid_peer_entries = [
+            LocalTrainingProgress.parse_obj(peer_state.value)
+            for peer_state in metadata.values()
+            if peer_state.value is not None
+        ]
+
+        num_peers = len(valid_peer_entries)
+        num_clients = sum(peer.client_mode for peer in valid_peer_entries)
+
+        global_epoch = self.local_progress.epoch
+        for peer in valid_peer_entries:
+            if not peer.client_mode:
+                global_epoch = max(global_epoch, peer.epoch)
+
+        total_samples_accumulated = estimated_current_samples = 0
+        total_samples_per_second = self.performance_ema.eps
+
+        for peer in valid_peer_entries:
+            total_samples_per_second += peer.samples_per_second
+            if peer.epoch == global_epoch:
+                total_samples_accumulated += peer.samples_accumulated
+                estimated_current_samples += (
+                    peer.samples_accumulated + max(0.0, current_time - peer.time) * peer.samples_per_second
+                )
+            # note: we deliberately count only valid peers for samples_accumulated, but all peers for performance;
+            # the rationale behind this is that outdated peers will synchronize and begin contributing shortly.
+
+        estimated_samples_remaining = self.target_batch_size - estimated_current_samples
+        estimated_time_to_next_epoch = max(0, estimated_samples_remaining) / total_samples_per_second
+
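+        # heuristic: re-fetch global progress sooner than the estimated epoch boundary, scaled down by the
+        # ratio of current peers to the worst-case number of peers that could join before the next epoch;
+        # this way the transition is unlikely to be missed even if newcomers accelerate sample accumulation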
+        expected_max_peers = max(num_peers + self.expected_drift_peers, num_peers * (1 + self.expected_drift_rate))
+        time_to_next_fetch = float(
+            np.clip(
+                a=estimated_time_to_next_epoch * num_peers / expected_max_peers,
+                a_min=self.min_refresh_period,
+                a_max=self.max_refresh_period,
+            )
+        )
+        logger.log(
+            self.status_loglevel,
+            f"{self.prefix} accumulated {total_samples_accumulated} samples for epoch #{global_epoch} from "
+            f"{num_peers} peers. ETA {estimated_time_to_next_epoch:.2f} sec (refresh in {time_to_next_fetch:.2f} sec)",
+        )
+        return GlobalTrainingProgress(
+            global_epoch,
+            total_samples_accumulated,
+            target_batch_size=self.target_batch_size,
+            num_peers=num_peers,
+            num_clients=num_clients,
+            eta_next_epoch=current_time + estimated_time_to_next_epoch,
+            next_fetch_time=current_time + time_to_next_fetch,
+        )
+
+    def shutdown(self, timeout: Optional[float] = None):
+        """Permanently disable all tracking activity"""
+        self.shutdown_triggered.set()
+        self.should_report_progress.set()
+        self.global_state_updated.set()
+        self.shutdown_complete.wait(timeout)
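+        # explicitly store None under this peer's subkey so that remaining peers stop counting its progress
+        # right away instead of waiting for the previous entry to expire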
+        self.dht.store(
+            self.training_progress_key,
+            subkey=self._local_public_key,
+            value=None,
+            expiration_time=get_dht_time() + self.metadata_expiration,
+            return_future=True,
+        )
+
+    def __del__(self):
+        if self.is_alive():
+            self.shutdown()
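To make the refresh heuristic in _parse_swarm_progress_data concrete, here is a small self-contained sketch with made-up numbers (illustrative only, not taken from any real run):

    import numpy as np

    target_batch_size = 400
    estimated_current_samples = 100    # samples accumulated so far, extrapolated to the current time
    total_samples_per_second = 40.0    # summed over all reporting peers (plus a small eps)
    num_peers, expected_drift_peers, expected_drift_rate = 10, 3, 0.2
    min_refresh_period, max_refresh_period = 0.5, 10.0

    eta_next_epoch = max(0, target_batch_size - estimated_current_samples) / total_samples_per_second   # 7.5 s
    expected_max_peers = max(num_peers + expected_drift_peers, num_peers * (1 + expected_drift_rate))   # 13
    time_to_next_fetch = float(
        np.clip(eta_next_epoch * num_peers / expected_max_peers, min_refresh_period, max_refresh_period)
    )  # ~5.8 s: the tracker re-checks the DHT well before the expected epoch transition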

+ 9 - 5
hivemind/optim/simple.py

@@ -4,9 +4,9 @@ from typing import Optional, Sequence, Tuple
 
 import torch
 
-from hivemind.averaging import TrainingAverager
 from hivemind.dht import DHT
 from hivemind.optim.base import DecentralizedOptimizerBase
+from hivemind.optim.training_averager import TrainingAverager
 from hivemind.utils import get_dht_time, get_logger
 
 logger = get_logger(__name__)
@@ -86,6 +86,10 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
             if self.local_step % self.averaging_step_period == 0:
                 self.update_event.set()
             self.averager.pending_updates_done.wait()
+
+            if not self.averager.client_mode:
+                self.averager.state_sharing_priority = get_dht_time()
+
             return loss
         finally:
             self.lock_parameters.acquire()
@@ -127,16 +131,16 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
                 time.sleep(time_to_nearest_interval)
 
             if verbose:
-                logger.info(f"Starting a new averaging round with current parameters.")
+                logger.info(f"Starting a new averaging round with current parameters")
             try:
                 group_info = averager.step(lock_parameters, **kwargs)
                 if verbose:
                     if group_info is not None:
-                        logger.info(f"Finished averaging round in with {len(group_info)} peers.")
+                        logger.info(f"Finished averaging round in with {len(group_info)} peers")
                     else:
-                        logger.warning(f"Averaging round failed: could not find group.")
+                        logger.warning(f"Averaging round failed: could not find group")
             except Exception as e:
-                logger.error(f"Averaging round failed: caught {e}.")
+                logger.error(f"Averaging round failed: caught {e}")
 
 
 class DecentralizedSGD(DecentralizedOptimizer):

+ 723 - 0
hivemind/optim/state_averager.py

@@ -0,0 +1,723 @@
+""" An extension of averager that supports common optimization use cases. """
+import logging
+import threading
+import time
+from concurrent.futures import ThreadPoolExecutor
+from contextlib import nullcontext
+from itertools import chain
+from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Sequence, Tuple, Union
+
+import torch
+
+import hivemind
+from hivemind.averaging import DecentralizedAverager
+from hivemind.averaging.control import StepControl
+from hivemind.compression import CompressionInfo, TensorRole
+from hivemind.optim.grad_scaler import GradScaler
+from hivemind.utils import DHTExpiration, PerformanceEMA, get_dht_time, get_logger, nested_flatten, nested_pack
+
+logger = get_logger(__name__)
+
+
+Parameters = Iterable[torch.Tensor]
+ParamGroups = Iterable[Dict[str, Any]]
+TorchOptimizer = torch.optim.Optimizer
+LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
+OptimizerFactory = Callable[[Union[Parameters, ParamGroups]], TorchOptimizer]
+SchedulerFactory = Callable[[TorchOptimizer], LRSchedulerBase]
+
+
+class TrainingStateAverager(DecentralizedAverager):
+    """
+    An auxiliary class that holds peer's training state, including model parameters, optimizer statistics, scheduler
+    and any other variables that define the local training state (e.g. batchnorm moving averages).
+    TrainingStateAverager is intended to keep these parameters weakly synchronized across the swarm.
+
+    The intended use is to call .step(optimizer_step=..., averaging_round=...) periodically, e.g. after every batch.
+    If peer gets out of sync with the swarm, one should call state_averager.load_state_from_peers() to re-synchronize.
+
+    Example:
+
+    >>> avgr = TrainingStateAverager(optimizer=torch.optim.Adam, params=model.parameters(), ...)
+    >>> # alternative interface: TrainingStateAverager(optimizer=torch.optim.Adam(model.parameters()), ...)
+    >>> avgr.load_state_from_peers()
+    >>> for i, batch in enumerate(training_dataloader):
+    >>>     loss = compute_loss(model, batch)
+    >>>     loss.backward()
+    >>>     avgr.step(optimizer_step=i % 10 == 0, averaging_round=is_it_time_for_averaging(), delay_averaging=True)
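+    >>> # hypothetical variant with an offloaded optimizer: the optimizer must then be created inside hivemind
+    >>> # from a factory, e.g. functools.partial (see offload_optimizer below); other arguments are unchanged
+    >>> avgr = TrainingStateAverager(optimizer=partial(torch.optim.Adam, lr=1e-3), params=model.parameters(),
+    >>>                              offload_optimizer=True, ...)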
+
+    :note: when using delay_averaging or delay_optimizer_step, calling optimizer directly is not recommended because
+      it may overlap with delayed updates from a background thread with unpredictable results. Instead, please call
+      TrainingStateAverager.step(..., optimizer_step=True)
+
+    :param optimizer: PyTorch Optimizer or a callable that creates an optimizer from param groups
+    :param params: optional, a list/tuple of parameters or structured param groups for the optimizer
+    :param scheduler: optional learning rate scheduler or callable that creates one from optimizer instance
+    :note: if provided, scheduler will be updated based on averager.local_epoch, not the number of step cycles
+    :param initialize_optimizer: if True, run a speculative optimizer step with zero gradients to initialize all
+      state tensors. If False, user must make sure that all tensors are pre-initialized at init.
+      By default, initialize optimizer unless it already has some state tensors to begin with.
+    :param offload_optimizer: if True, create optimizer on top of averaged parameters which may save device memory.
+    :param custom_gradients: if True, do *not* automatically load local gradients into the offloaded optimizer.
+      This assumes that offloaded gradients will be populated externally, e.g. by the user or by hivemind.Optimizer.
+    :param reuse_tensors: if True, reuse parameters and optimizer statistics as averaged_tensors for allreduce.
+      For this to work, all parameters must be on CPU and have the appropriate dtype for use in DecentralizedAverager.
+      Defaults to True if offload_optimizer is enabled, False otherwise.
+    :param delta_rule_averaging: if True, averaging will use delta rule to allow running local optimizer steps
+      while averaging. Delta rule: `state_tensor := state_tensor + averaging_result - state_tensor_before_averaging`
+    :param sync_epoch_when_averaging: if True, update local epoch to the latest epoch among averaging peers
+    :param parameter_names: optionally provide parameter names in the same order as in params
+    :param average_opt_statistics: names of optimizer statistics from state dict that should be averaged with peers
+    :param extra_tensors: if specified, these extra tensors will also be averaged and shared in load_state_from_peers.
+    :note: you can use extra_tensors for any tensors not used by the optimizer (e.g. batchnorm statistics)
+    :param kwargs: any additional parameters will be forwarded to DecentralizedAverager
+    """
+
+    def __init__(
+        self,
+        *,
+        dht: hivemind.DHT,
+        optimizer: Union[TorchOptimizer, OptimizerFactory],
+        params: Optional[Union[Parameters, ParamGroups]] = None,
+        scheduler: Optional[Union[LRSchedulerBase, SchedulerFactory]] = None,
+        initialize_optimizer: Optional[bool] = None,
+        offload_optimizer: bool = False,
+        custom_gradients: bool = False,
+        reuse_tensors: Optional[bool] = None,
+        delta_rule_averaging: bool = False,
+        performance_ema_alpha: float = 0.1,
+        sync_epoch_when_averaging: bool = False,
+        parameter_names: Optional[Sequence[str]] = None,
+        average_opt_statistics: Sequence[str] = (),
+        extra_tensors: Sequence[torch.Tensor] = (),
+        status_loglevel: int = logging.DEBUG,
+        **kwargs,
+    ):
+        average_opt_statistics = tuple(average_opt_statistics)
+        assert all(isinstance(key, str) for key in average_opt_statistics)
+        if reuse_tensors is None:
+            reuse_tensors = offload_optimizer and not delta_rule_averaging
+        if custom_gradients and not offload_optimizer:
+            logger.warning("Setting custom_gradients=True has no effect because the optimizer is not offloaded")
+        if reuse_tensors and delta_rule_averaging:
+            raise ValueError("reuse_tensors and delta_rule_averaging are mutually exclusive")
+
+        param_groups, main_parameters, parameter_names = self._check_params(optimizer, params, parameter_names)
+
+        self.status_loglevel = status_loglevel
+        self.offload_optimizer, self.custom_gradients = offload_optimizer, custom_gradients
+        self.reuse_tensors, self.delta_rule_averaging = reuse_tensors, delta_rule_averaging
+        self._old_tensors: Optional[Sequence[torch.Tensor]] = None  # for delta rule
+
+        self.main_parameters, self.parameter_names = main_parameters, parameter_names
+        self._averaged_parameters = self._make_averaged_parameters(main_parameters)
+        self.optimizer, self.scheduler = self._init_components(
+            param_groups, optimizer, scheduler, initialize_optimizer
+        )
+        self.opt_keys_for_averaging, self.extra_tensors = average_opt_statistics, extra_tensors
+        self.sync_epoch_when_averaging = sync_epoch_when_averaging
+        self.local_epoch = 0
+
+        self.delay_before_averaging = PerformanceEMA(alpha=performance_ema_alpha)
+        self.step_executor = ThreadPoolExecutor(max_workers=2 if self.delta_rule_averaging else 1)
+        self.finished_optimizer_step = threading.Event()
+        self.finished_averaging_round = threading.Event()
+        self.lock_optimizer = threading.Lock()
+        self.lock_averaging = threading.Lock()
+        self.pending_updates = set()
+
+        super().__init__(
+            dht=dht, averaged_tensors=self._init_averaged_tensors(), tensor_infos=self._init_tensor_infos(), **kwargs
+        )
+
+    @staticmethod
+    def _check_params(
+        optimizer: Union[TorchOptimizer, OptimizerFactory],
+        param_groups: Optional[Union[Parameters, ParamGroups]],
+        parameter_names: Optional[Sequence[str]],
+    ) -> Tuple[ParamGroups, Sequence[torch.Tensor], Sequence[str]]:
+        """Get and verify parameters, groups and names"""
+        if param_groups is None:
+            assert hasattr(optimizer, "param_groups"), "Must provide param_groups or an optimizer with .param_groups"
+            param_groups = optimizer.param_groups
+        param_groups = tuple(param_groups)
+        if all(isinstance(p, torch.Tensor) for p in param_groups):
+            param_groups = (dict(params=param_groups),)
+        for group in param_groups:
+            assert isinstance(group, dict) and group.get("params") is not None
+            assert all(isinstance(p, torch.Tensor) for p in group["params"])
+        parameters = tuple(chain(*(group["params"] for group in param_groups)))
+        if parameter_names is None:
+            parameter_names = tuple(i for i in range(len(parameters)))
+        parameter_names = tuple(nested_flatten(parameter_names))
+        assert len(parameters) == len(parameter_names), f"Expected {len(parameters)} names, got {len(parameter_names)}"
+        assert len(set(parameters)) == len(parameters), "Found duplicate parameters in param_groups"
+        params_with_grad = sum(p.numel() for p in parameters if p.requires_grad)
+        params_no_grad = sum(p.numel() for p in parameters if not p.requires_grad)
+        if params_no_grad >= params_with_grad:
+            logger.warning(
+                "The majority of parameters have requires_grad=False, but they are still synchronized"
+                " with peers. If these parameters are frozen (not updated), please do not feed them into "
+                "the optimizer at all in order to avoid communication overhead. Proceeding anyway."
+            )
+
+        return param_groups, parameters, parameter_names
+
+    def _make_averaged_parameters(self, main_parameters: Sequence[torch.Tensor]):
+        """Initialize averaged parameters based on the optimizer and averaging mode"""
+        return tuple(self._make_host_tensor(param, force_copy=self.offload_optimizer) for param in main_parameters)
+
+    def _make_host_tensor(self, source_tensor: torch.Tensor, force_copy: bool = False) -> torch.Tensor:
+        """Create a new tensor for averaging or reuse the existing one"""
+        if self.reuse_tensors and not force_copy:
+            if source_tensor.device != torch.device("cpu"):
+                raise ValueError("reuse_tensors is only supported if all averaged tensors are on CPU")
+            if not source_tensor.is_shared():
+                source_tensor.share_memory_()
+            return source_tensor
+        else:
+            averaged_tensor = source_tensor.detach().to(device="cpu", dtype=torch.float32, copy=True)
+            return averaged_tensor.share_memory_().requires_grad_(source_tensor.requires_grad)
+
+    def _init_components(
+        self,
+        param_groups: ParamGroups,
+        optimizer_or_factory: Union[TorchOptimizer, OptimizerFactory],
+        scheduler_or_factory: Optional[Union[LRSchedulerBase, SchedulerFactory]],
+        initialize_optimizer: Optional[bool],
+    ) -> Tuple[TorchOptimizer, Optional[LRSchedulerBase]]:
+        """Get optimizer and scheduler by either instantiating user-provided factory or using pre-instantiated ones"""
+        assert hasattr(self, "_averaged_parameters"), "Internal error: must initialize averaged parameters first"
+        optimizer_is_factory = callable(optimizer_or_factory) and not isinstance(optimizer_or_factory, TorchOptimizer)
+        scheduler_is_factory = callable(scheduler_or_factory) and not isinstance(scheduler_or_factory, LRSchedulerBase)
+        if optimizer_is_factory and not scheduler_is_factory and scheduler_or_factory is not None:
+            raise ValueError("If optimizer is created internally, scheduler must also be initialized internally")
+        if self.offload_optimizer and not optimizer_is_factory:
+            raise ValueError("Using offload_optimizer requires creating optimizer inside hivemind")
+
+        # create optimizer
+        if optimizer_is_factory:
+            if self.offload_optimizer:
+                if self.reuse_tensors:
+                    parameters_for_optimizer = self._averaged_parameters
+                else:
+                    parameters_for_optimizer = tuple(
+                        tensor.detach().clone().requires_grad_(tensor.requires_grad)
+                        for tensor in self._averaged_parameters
+                    )
+
+                next_index = 0
+                param_groups_for_optimizer = []
+                for param_group in param_groups:
+                    num_params = len(param_group["params"])
+                    averaged_params_for_group = parameters_for_optimizer[next_index : next_index + num_params]
+                    param_groups_for_optimizer.append(dict(param_group, params=averaged_params_for_group))
+                    next_index += num_params
+                assert next_index == len(parameters_for_optimizer)
+
+                for param in parameters_for_optimizer:
+                    if param.grad is None:
+                        param.grad = torch.zeros_like(param)
+            else:
+                param_groups_for_optimizer = param_groups
+            optimizer = optimizer_or_factory(param_groups_for_optimizer)
+        else:
+            optimizer = optimizer_or_factory
+
+        # optionally initialize optimizer state dict
+        if initialize_optimizer is None:
+            initialize_optimizer = not any(isinstance(x, torch.Tensor) for x in nested_flatten(optimizer.state_dict()))
+            logger.log(
+                self.status_loglevel,
+                "Initializing optimizer manually since it has no tensors in state dict. "
+                "To override this, provide initialize_optimizer=False",
+            )
+
+        if initialize_optimizer:
+            initialize_optimizer_state_(optimizer)  # note: this will run one optimizer step!
+
+        # create LR scheduler
+        if scheduler_is_factory:
+            assert callable(scheduler_or_factory)
+            scheduler = scheduler_or_factory(optimizer)
+        else:
+            scheduler = scheduler_or_factory
+
+        # verify optimizer and scheduler
+        assert isinstance(optimizer, TorchOptimizer) and len(optimizer.param_groups) == len(list(param_groups))
+        if self.reuse_tensors:
+            for param_group in optimizer.param_groups:
+                for param in param_group["params"]:
+                    assert param.is_shared()
+        assert isinstance(scheduler, (LRSchedulerBase, type(None)))
+        if scheduler is not None:
+            assert scheduler.optimizer == optimizer
+        return optimizer, scheduler
+
+    def _local_tensors(self) -> Iterator[torch.Tensor]:
+        """Iterate local trainer's tensors that should be averaged with peers"""
+        for param_group in self.optimizer.param_groups:
+            yield from param_group["params"]
+        for stats in self.opt_keys_for_averaging:
+            for param_group in self.optimizer.param_groups:
+                for param in param_group["params"]:
+                    yield self.optimizer.state[param][stats]
+        yield from self.extra_tensors
+
+    @torch.no_grad()
+    def _init_averaged_tensors(self) -> Sequence[torch.Tensor]:
+        """Create or reuse a tuple of all averaged tensors, including parameters, optimizer statistics and extras"""
+        assert hasattr(self, "optimizer"), "Optimizer should already be initialized by this point"
+        assert hasattr(self, "_averaged_parameters"), "Should initialize _averaged_parameters first"
+        assert not hasattr(self, "_averaged_tensors"), "Averager is already initialized"
+        assert all(isinstance(key, str) for key in self.opt_keys_for_averaging)
+
+        local_tensors = tuple(self._local_tensors())
+        local_non_parameters = local_tensors[len(self._averaged_parameters) :]
+        averaged_tensors = tuple(map(torch.Tensor.detach, self._averaged_parameters))
+        averaged_non_parameters = tuple(map(self._make_host_tensor, local_non_parameters))
+        averaged_tensors = tuple(chain(averaged_tensors, averaged_non_parameters))
+
+        assert len(averaged_tensors) == len(local_tensors)
+        for local_tensor, averaged_tensor in zip(local_tensors, averaged_tensors):
+            assert local_tensor.shape == averaged_tensor.shape
+            if averaged_tensor.grad is not None:
+                logger.log(self.status_loglevel, "setting gradients for averaged tensor to None")
+
+        return averaged_tensors
+
+    def _init_tensor_infos(self) -> Sequence[CompressionInfo]:
+        """Get CompressionInfo for each state tensor, accounting for its role and specification"""
+        tensor_infos = []
+        for param, param_name in zip(self.main_parameters, self.parameter_names):
+            tensor_infos.append(CompressionInfo.from_tensor(param, key=param_name, role=TensorRole.PARAMETER))
+        for stats_name in self.opt_keys_for_averaging:
+            opt_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
+            assert len(opt_parameters) == len(self.parameter_names)
+            for param, param_name in zip(opt_parameters, self.parameter_names):
+                tensor_infos.append(
+                    CompressionInfo.from_tensor(
+                        self.optimizer.state[param][stats_name],
+                        key=(param_name, stats_name),
+                        role=TensorRole.OPTIMIZER,
+                    )
+                )
+        for i, extra_tensor in enumerate(self.extra_tensors):
+            tensor_infos.append(CompressionInfo.from_tensor(extra_tensor, key=i, role=TensorRole.UNSPECIFIED))
+        return tuple(tensor_infos)
+
+    def schedule_step(self, scheduled_time: Optional[DHTExpiration] = None, **kwargs) -> StepControl:
+        """
+        Begin matchmaking: look for a group of peers and prepare for averaging gradients at a specified time.
+
+        :param scheduled_time: expected time when to perform all-reduce. Can be changed using control.scheduled_time
+        :param kwargs: any additional keyword args from DecentralizedAverager.step, such as gather, allow_retries, etc
+        :note: setting weight at this stage is not supported, please leave this parameter as None
+        :returns: step_control - a handle that can be passed into TrainingStateAverager.step to use pre-scheduled group
+        :note: in the current implementation, each step_control can only be used in one step.
+        """
+        assert kwargs.get("weight") is None, "setting weight in schedule_step is not supported"
+        return super().step(scheduled_time=scheduled_time, wait=False, require_trigger=True, **kwargs)
+
+    def step(
+        self,
+        wait_for_delayed_updates: Optional[bool] = None,
+        apply_delayed_updates: bool = True,
+        increment_epoch: bool = False,
+        optimizer_step: bool = False,
+        zero_grad: bool = False,
+        delay_optimizer_step: bool = False,
+        averaging_round: bool = False,
+        delay_averaging: Optional[bool] = None,
+        averaging_control: Optional[StepControl] = None,
+        wait_for_trigger: Optional[Callable[[], Any]] = None,
+        grad_scaler: Optional[GradScaler] = None,
+        averaging_opts: Optional[Dict[str, Any]] = None,
+    ):
+        """
+        Perform one or several possible actions, depending on the specified keyword args.
+        The actions will be performed in the same order as specified below:
+
+        :param wait_for_delayed_updates: if there are background averaging rounds, wait for them to finish;
+          by default, wait only if this call also schedules a new averaging round (or, with delta_rule_averaging,
+          an optimizer step or zero_grad); otherwise do not wait
+        :param apply_delayed_updates: apply any averaging rounds that have finished but were not applied yet
+        :param increment_epoch: increment .local_epoch and update the learning rate scheduler (if present)
+        :note: if specified, it is guaranteed that epoch is incremented immediately regardless of other options
+        :param optimizer_step: perform a single optimizer step and update local parameters (without changing scheduler)
+        :param zero_grad: if True, reset local gradients after performing optimizer step
+        :param delay_optimizer_step: if True, run optimizer step in background and apply results in a future step
+        :param averaging_round: average parameters, chosen optimizer keys and extra tensors with a group of peers
+        :param delay_averaging: if True, perform averaging in background and apply results in a future step
+          by default, delay averaging if the optimizer step is also delayed. Set to true to delay only this phase.
+        :param averaging_control: if specified, use this as a pre-scheduled averaging round. Should require_trigger.
+        :param wait_for_trigger: wait for this (non-asyncio) function to finish before running optimizer step
+        :note: if wait_for_trigger fails with any exception, it will abort optimizer step, zero grad and averaging
+        :param grad_scaler: when using hivemind.GradScaler, one must forward it to step after calling .unscale_
+        :param averaging_opts: a dict of keyword arguments forwarded into averaging round
+        """
+        if delay_averaging is None:
+            delay_averaging = delay_optimizer_step
+        should_wait = averaging_round or optimizer_step or zero_grad if self.delta_rule_averaging else averaging_round
+        if wait_for_delayed_updates is None:
+            wait_for_delayed_updates = should_wait
+        if should_wait and not (wait_for_delayed_updates and apply_delayed_updates):
+            raise ValueError("Should wait for background operation to finish before scheduling new one")
+        assert not delay_optimizer_step or delay_averaging, "Delayed optimizer step requires delayed averaging"
+        if delay_optimizer_step:
+            assert self.offload_optimizer, "Delayed optimizer step is only available with offload_optimizer"
+            assert not averaging_round or delay_averaging, "Averaging after delayed optimizer should also be delayed"
+        if averaging_opts and not averaging_round:
+            logger.warning(f"Averaging parameters not used because averaging_round=False: {averaging_opts}")
+        if averaging_control is not None:
+            assert averaging_round, "averaging_control is unused if averaging_round is not performed"
+        if wait_for_trigger is not None:
+            assert optimizer_step or zero_grad or averaging_round, "trigger is only used for updating parameters"
+            if not (self.reuse_tensors or self.custom_gradients):
+                # averager was asked to wait_for_trigger in background, but it is not clear which version of gradients
+                # should be used for optimizer step (e.g. the gradients that were present during the call to .step or
+                # the possibly different gradients when wait_for_trigger has finished).
+                raise ValueError(
+                    "wait_for_trigger is a low-level option that requires manual gradient manipulation. "
+                    "If you know what you're doing, please refer to the comments in the source code for details"
+                )
+        output = None
+
+        if wait_for_delayed_updates:
+            for pending_update in self.pending_updates:
+                try:
+                    timeout = (averaging_opts or {}).get("averaging_timeout", self._allreduce_timeout)
+                    logger.log(self.status_loglevel, "Waiting for delayed updates to finish...")
+                    output = pending_update.result(timeout)
+                except BaseException:
+                    # exception will be reported below
+                    if not pending_update.done():
+                        pending_update.cancel()
+
+        # remove finished updates, log any exceptions
+        finished_updates = {pending_update for pending_update in self.pending_updates if pending_update.done()}
+        self.pending_updates = {pending_update for pending_update in self.pending_updates if not pending_update.done()}
+        for finished_update in finished_updates:
+            if finished_update.cancelled() or finished_update.exception():
+                logger.log(self.status_loglevel, f"Background update failed: {finished_update}")
+
+        if apply_delayed_updates:
+            if self.finished_averaging_round.is_set():
+                if not self.reuse_tensors:
+                    self._apply_averaging_results_()
+                if self.offload_optimizer and not self.finished_optimizer_step.is_set():
+                    self._apply_optimizer_parameters_()
+                logger.log(self.status_loglevel, "Received parameters from background averaging round")
+                self.finished_averaging_round.clear()
+
+            if self.finished_optimizer_step.is_set():
+                if self.offload_optimizer:
+                    self._apply_optimizer_parameters_()
+                logger.debug("Received parameters from background optimizer step")
+                self.finished_optimizer_step.clear()
+
+        if increment_epoch:
+            self.local_epoch += 1
+
+        if optimizer_step or zero_grad or averaging_round:
+            if self.offload_optimizer and not self.custom_gradients:
+                self._load_local_grads_into_optimizer_()
+
+            pending_update = self.step_executor.submit(
+                self._do,
+                wait_for_trigger,
+                optimizer_step,
+                zero_grad,
+                averaging_round,
+                averaging_control,
+                grad_scaler,
+                **averaging_opts or {},
+            )
+            self.pending_updates.add(pending_update)
+
+            should_await_optimizer = (optimizer_step or zero_grad) and not delay_optimizer_step
+            should_await_averaging = averaging_round and not delay_averaging
+
+            if should_await_optimizer:
+                self.finished_optimizer_step.wait()
+                self.finished_optimizer_step.clear()
+                if self.offload_optimizer and not should_await_averaging:
+                    self._apply_optimizer_parameters_()
+                logger.debug("Finished optimizer step")
+
+            if should_await_averaging:
+                self.finished_averaging_round.wait()
+                self.finished_averaging_round.clear()
+                if not self.reuse_tensors:
+                    self._apply_averaging_results_()
+                if self.offload_optimizer:
+                    self._apply_optimizer_parameters_()
+                logger.log(self.status_loglevel, "Finished averaging round")
+
+            async_averaging = averaging_round and delay_averaging
+            async_optimizer = (optimizer_step or zero_grad) and delay_optimizer_step
+
+            if not (async_averaging or async_optimizer):
+                try:
+                    output = pending_update.result()
+                finally:
+                    self.pending_updates.remove(pending_update)
+
+        return output
+
+    def _do(
+        self,
+        wait_for_trigger: Optional[Callable[[], Any]],
+        optimizer_step: bool,
+        zero_grad: bool,
+        averaging_round: bool,
+        averaging_control: Optional[StepControl],
+        grad_scaler: Optional[GradScaler],
+        timeout: Optional[float] = None,
+        **kwargs,
+    ):
+        """
+        Run the optimizer step, followed by a scheduler step and an averaging round, each stage is optional.
+        This method is meant to be called in the background executor.
+        """
+        if averaging_control is not None and (averaging_control.triggered or averaging_control.done()):
+            logger.log(self.status_loglevel, f"Discarding failed matchmaking results: {averaging_control}")
+            averaging_control = None
+
+        start_time = time.perf_counter()
+        began_running = False
+
+        try:
+            if averaging_round and averaging_control is None:
+                averaging_control = super().step(
+                    gather=self.local_epoch,
+                    require_trigger=True,
+                    timeout=timeout,
+                    wait=False,
+                    **kwargs,
+                )
+
+            if wait_for_trigger is not None:
+                wait_for_trigger()
+            began_running = True
+
+            with self.lock_optimizer:
+                if optimizer_step:
+                    with self.lock_averaged_tensors if self.reuse_tensors else nullcontext():
+                        logger.debug(f"Running optimizer step")
+                        if grad_scaler is None:
+                            self.optimizer.step()
+                        else:
+                            with grad_scaler.running_global_step():
+                                assert grad_scaler.step(self.optimizer)
+
+                if zero_grad:
+                    logger.debug(f"Running zero grad")
+                    self.optimizer.zero_grad()
+                    if self.offload_optimizer:
+                        for parameter in self.main_parameters:
+                            if parameter.grad is not None:
+                                parameter.grad.zero_()
+
+                self._update_scheduler()
+                self.finished_optimizer_step.set()
+
+            if averaging_round:
+                with self.lock_averaging:
+                    if not self.reuse_tensors:
+                        self._load_local_tensors_into_averager_()
+                    if self.delta_rule_averaging:
+                        # remember tensors before averaging, update by (new_averaged_tensors - old_averaged_tensors)
+                        with torch.no_grad(), self.get_tensors() as averaged_tensors:
+                            self._old_tensors = tuple(x.cpu().clone() for x in averaged_tensors)
+
+                    self.delay_before_averaging.update(task_size=1, interval=time.perf_counter() - start_time)
+                    try:
+                        averaging_control.allow_allreduce()
+                        gathered = averaging_control.result(timeout=timeout)
+                        logger.log(self.status_loglevel, f"Averaged parameters with {len(gathered)} peers")
+                    except BaseException as e:
+                        logger.log(self.status_loglevel, f"Averaging failed with {type(e)}")
+                        gathered = {}
+
+                    self.finished_averaging_round.set()
+
+                if self.sync_epoch_when_averaging:
+                    old_epoch = self.local_epoch
+                    for peer_epoch in gathered.values():
+                        self.local_epoch = max(self.local_epoch, peer_epoch)
+                    if self.local_epoch != old_epoch:
+                        logger.log(self.status_loglevel, f"Found peer with newer epoch ({self.local_epoch})")
+                        self._update_scheduler()
+
+        except Exception as e:
+            if not began_running:
+                logger.error(f"Aborted {self.__class__.__name__}.step because wait_for_trigger raised exception")
+            logger.exception(e)
+            if averaging_control is not None and not averaging_control.done():
+                logger.error(f"Cancelled scheduled state averaging round")
+                averaging_control.cancel()
+            self.finished_optimizer_step.set()
+            self.finished_averaging_round.set()
+
+    @torch.no_grad()
+    def _load_local_grads_into_optimizer_(self):
+        """Copy local gradients into the gradient buffers of the offloaded optimizer"""
+        assert self.offload_optimizer, "Loading into offloaded optimizer requires using offloaded optimizer"
+        opt_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
+        for main_param, opt_param in zip(self.main_parameters, opt_parameters):
+            if main_param.grad is not None:
+                opt_param.grad.copy_(main_param.grad, non_blocking=True)
+
+    @torch.no_grad()
+    def _apply_optimizer_parameters_(self):
+        """Copy parameters from offloaded optimizer to the main model"""
+        assert self.offload_optimizer, "Applying offloaded optimizer updates requires offloaded optimizer"
+        offloaded_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
+        assert len(offloaded_parameters) == len(self.main_parameters), "Optimizer parameters changed during training"
+        for main_param, offloaded_param in zip(self.main_parameters, offloaded_parameters):
+            main_param.copy_(offloaded_param, non_blocking=True)
+
+    @torch.no_grad()
+    def _load_local_tensors_into_averager_(self):
+        """Copy local tensors into the averaging buffers"""
+        assert not self.reuse_tensors, "No need to load tensors into averager: both tensors share the same memory"
+        with self.get_tensors() as averaged_tensors:
+            for local_tensor, averaged_tensor in zip(self._local_tensors(), averaged_tensors):
+                averaged_tensor.copy_(local_tensor, non_blocking=True)
+
+    @torch.no_grad()
+    def _apply_averaging_results_(self):
+        """Copy averaged tensors into their respective local tensors"""
+        assert not self.reuse_tensors, "No need to update averaged tensors since they reuse the same memory"
+        if self.delta_rule_averaging and self._old_tensors is None:
+            logger.warning("Using delta_rule_averaging, but old tensors were not found. Averaging may have failed")
+        with self.get_tensors() as averaged_tensors:
+            local_tensors = list(self._local_tensors())
+            assert len(local_tensors) == len(averaged_tensors), "Tensor structure changed during training"
+            if not self.delta_rule_averaging or self._old_tensors is None:
+                for local_tensor, averaged_tensor in zip(local_tensors, averaged_tensors):
+                    local_tensor.copy_(averaged_tensor, non_blocking=True)
+            else:
+                assert len(self._old_tensors) == len(local_tensors)
+                for local_tensor, new_tensor, old_tensor in zip(local_tensors, averaged_tensors, self._old_tensors):
+                    delta = torch.sub(new_tensor, old_tensor, out=old_tensor)  # using old tensors as buffers
+                    local_tensor.add_(delta.to(device=local_tensor.device, dtype=local_tensor.dtype))
+
+    @property
+    def averaging_in_progress(self) -> bool:
+        return self.lock_averaging.locked()
+
+    def get_current_state(self):
+        """
+        Get the current model/optimizer state to be sent to a newbie peer upon request. Executed in the host process.
+        :returns: a tuple of (serializable_small_metadata, sequence of torch tensors)
+        """
+        with torch.no_grad(), self.lock_averaged_tensors:
+            optimized_parameters = tuple(
+                param.detach().cpu() for param_group in self.optimizer.param_groups for param in param_group["params"]
+            )
+            parameter_infos = [
+                CompressionInfo.from_tensor(param, key=key, role=TensorRole.PARAMETER)
+                for param, key in zip(optimized_parameters, self.parameter_names)
+            ]
+            extra_tensors = tuple(tensor.detach().cpu() for tensor in self.extra_tensors)
+            extra_infos = [
+                CompressionInfo.from_tensor(extra_tensor, key=i, role=TensorRole.UNSPECIFIED)
+                for i, extra_tensor in enumerate(extra_tensors)
+            ]
+            optimizer_metadata, optimizer_tensors = dump_optimizer_state(self.optimizer)
+            optimizer_infos = [
+                CompressionInfo.from_tensor(opt_tensor, key=i, role=TensorRole.OPTIMIZER)
+                for i, opt_tensor in enumerate(optimizer_tensors)
+            ]
+
+        metadata = dict(
+            epoch=self.local_epoch, group_bits=self.get_group_bits(), optimizer_metadata=optimizer_metadata
+        )
+        all_tensors = list(chain(optimized_parameters, extra_tensors, optimizer_tensors))
+        all_tensor_infos = list(chain(parameter_infos, extra_infos, optimizer_infos))
+        return metadata, all_tensors, all_tensor_infos
+
+    def load_state_from_peers(self, **kwargs):
+        """
+        Attempt to download the latest optimizer state from peers and update trainer parameters/statistics.
+        :returns: whether or not the averager succeeded in loading parameters
+        """
+        opt_parameters = tuple(param for param_group in self.optimizer.param_groups for param in param_group["params"])
+        main_parameters_and_extras = tuple(chain(opt_parameters, self.extra_tensors))
+        num_parameters_and_extras = len(main_parameters_and_extras)
+
+        loaded_state = super().load_state_from_peers(**kwargs)
+        if loaded_state is None:
+            return
+
+        metadata, flat_tensors = loaded_state
+        if (not isinstance(metadata.get("epoch"), int)) or metadata["epoch"] < self.local_epoch:
+            logger.warning("Cowardly refusing to load state from peer: peer's epoch is behind our local epoch")
+            return
+
+        loaded_parameters_and_extras = flat_tensors[:num_parameters_and_extras]
+        loaded_opt_tensors = flat_tensors[num_parameters_and_extras:]
+        if num_parameters_and_extras != len(loaded_parameters_and_extras):
+            logger.error("Failed to load state from peer, received parameters, extras or metadata")
+            return
+
+        with torch.no_grad(), self.lock_averaged_tensors:
+            try:
+                load_optimizer_state(self.optimizer, metadata["optimizer_metadata"], loaded_opt_tensors)
+            except StopIteration:
+                logger.warning("Failed to load state from peer, received inconsistent number of optimizer statistics")
+                return
+
+            for local_param, loaded_param in zip(main_parameters_and_extras, loaded_parameters_and_extras):
+                local_param.copy_(loaded_param, non_blocking=True)
+
+        if self.offload_optimizer:
+            self._apply_optimizer_parameters_()
+        if not self.reuse_tensors:
+            self._load_local_tensors_into_averager_()
+
+        self.local_epoch = metadata["epoch"]
+        self._update_scheduler()
+
+    def _update_scheduler(self):
+        """Increase the scheduler state until it becomes synchronized with local epoch"""
+        if self.scheduler:
+            while self.scheduler._step_count <= self.local_epoch:
+                self.scheduler.step()
+
+
+def initialize_optimizer_state_(opt: torch.optim.Optimizer):
+    """Initialize optimizer statistics by running a virtual optimizer step with zero gradients"""
+    flat_params = tuple(param for group in opt.param_groups for param in group["params"])
+    old_grads = []
+    for param in flat_params:
+        old_grads.append(param.grad)
+        param.grad = torch.zeros_like(param)
+    opt.step()
+    for param, old_grad in zip(flat_params, old_grads):
+        param.grad = old_grad
+
+
+def dump_optimizer_state(opt: torch.optim.Optimizer):
+    """Convert optimizer state into a format of DecentralizedAverager's get_current_state/load_state_from_peers"""
+    with torch.no_grad():
+        flat_metadata, flat_tensors = [], []
+        for elem in nested_flatten(opt.state_dict()):
+            if isinstance(elem, torch.Tensor):
+                flat_metadata.append(dict(type="tensor", index=len(flat_tensors)))
+                flat_tensors.append(elem.cpu())
+            else:
+                flat_metadata.append(dict(type="value", value=elem))
+        return flat_metadata, flat_tensors
+
+
+def load_optimizer_state(optimizer: torch.optim.Optimizer, flat_metadata: Dict, flat_tensors: Sequence[torch.Tensor]):
+    """Load a state obtained by dump_optimizer_state back into the optimizer"""
+    flat_optimizer_state = []
+    for elem in flat_metadata:
+        if elem.get("type") == "tensor" and isinstance(elem.get("index"), int):
+            flat_optimizer_state.append(flat_tensors[elem["index"]])
+        elif elem.get("type") == "value" and "value" in elem:
+            flat_optimizer_state.append(elem["value"])
+    return optimizer.load_state_dict(nested_pack(flat_optimizer_state, structure=optimizer.state_dict()))
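
For context: dump_optimizer_state flattens an optimizer's state_dict into msgpack-friendly metadata plus a list of CPU tensors, and load_optimizer_state packs the same pair back into the original nested structure. A minimal round-trip sketch, assuming these helpers are imported from the new hivemind.optim.state_averager module (the Linear model and Adam settings below are illustrative only, not taken from this diff):

import torch
from hivemind.optim.state_averager import (
    dump_optimizer_state, initialize_optimizer_state_, load_optimizer_state)

model = torch.nn.Linear(4, 4)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
initialize_optimizer_state_(opt)  # create exp_avg / exp_avg_sq buffers via a zero-gradient step

metadata, tensors = dump_optimizer_state(opt)  # serializable metadata + flat list of CPU tensors
load_optimizer_state(opt, metadata, tensors)   # restores the nested state_dict structure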

+ 1 - 1
hivemind/averaging/training.py → hivemind/optim/training_averager.py

@@ -101,7 +101,7 @@ class TrainingAverager(DecentralizedAverager):
                 self.pending_updates_done.clear()
                 with data_lock, self.get_tensors() as averaged_tensors:
                     if len(averaged_tensors) != len(local_tensors):
-                        raise RuntimeError("The number of optimized parameters should not change.")
+                        raise RuntimeError("The number of optimized parameters should not change")
 
                     if use_old_local_tensors:
                         # since tensors might have changed, we subtract old_local_tensor and add averaged. This prevents

+ 26 - 18
hivemind/p2p/p2p_daemon.py

@@ -140,6 +140,11 @@ class P2P:
         socket_uid = secrets.token_urlsafe(8)
         self._daemon_listen_maddr = Multiaddr(cls._UNIX_SOCKET_PREFIX + f"p2pd-{socket_uid}.sock")
         self._client_listen_maddr = Multiaddr(cls._UNIX_SOCKET_PREFIX + f"p2pclient-{socket_uid}.sock")
+        if announce_maddrs is not None:
+            for addr in announce_maddrs:
+                addr = Multiaddr(addr)
+                if ("tcp" in addr and addr["tcp"] == "0") or ("udp" in addr and addr["udp"] == "0"):
+                    raise ValueError("Please specify an explicit port in announce_maddrs: port 0 is not supported")
 
         need_bootstrap = bool(initial_peers) or use_ipfs
         process_kwargs = cls.DHT_MODE_MAPPING.get(dht_mode, {"dht": 0})
@@ -260,7 +265,7 @@ class P2P:
         return self._daemon_listen_maddr
 
     @staticmethod
-    async def send_raw_data(data: bytes, writer: asyncio.StreamWriter, *, chunk_size: int = 2 ** 16) -> None:
+    async def send_raw_data(data: bytes, writer: asyncio.StreamWriter, *, chunk_size: int = 2**16) -> None:
         writer.write(len(data).to_bytes(P2P.HEADER_LEN, P2P.BYTEORDER))
         data = memoryview(data)
         for offset in range(0, len(data), chunk_size):
@@ -386,22 +391,25 @@ class P2P:
                 await P2P.send_protobuf(request, writer)
             await P2P.send_protobuf(P2P.END_OF_STREAM, writer)
 
-        with closing(writer):
-            writing_task = asyncio.create_task(_write_to_stream())
-            try:
-                while True:
-                    try:
-                        response, err = await P2P.receive_protobuf(output_protobuf_type, reader)
-                    except asyncio.IncompleteReadError:  # Connection is closed
-                        break
+        async def _read_from_stream() -> AsyncIterator[Message]:
+            with closing(writer):
+                try:
+                    while True:
+                        try:
+                            response, err = await P2P.receive_protobuf(output_protobuf_type, reader)
+                        except asyncio.IncompleteReadError:  # Connection is closed
+                            break
 
-                    if err is not None:
-                        raise P2PHandlerError(f"Failed to call handler `{name}` at {peer_id}: {err.message}")
-                    yield response
+                        if err is not None:
+                            raise P2PHandlerError(f"Failed to call handler `{name}` at {peer_id}: {err.message}")
+                        yield response
+
+                    await writing_task
+                finally:
+                    writing_task.cancel()
 
-                await writing_task
-            finally:
-                writing_task.cancel()
+        writing_task = asyncio.create_task(_write_to_stream())
+        return _read_from_stream()
 
     async def add_protobuf_handler(
         self,
@@ -476,7 +484,7 @@ class P2P:
         if not isinstance(input, AsyncIterableABC):
             return await self._call_unary_protobuf_handler(peer_id, name, input, output_protobuf_type)
 
-        responses = self._iterate_protobuf_stream_handler(peer_id, name, input, output_protobuf_type)
+        responses = await self._iterate_protobuf_stream_handler(peer_id, name, input, output_protobuf_type)
         return await asingle(responses)
 
     async def _call_unary_protobuf_handler(
@@ -490,7 +498,7 @@ class P2P:
         response = await self._client.call_unary_handler(peer_id, handle_name, serialized_input)
         return output_protobuf_type.FromString(response)
 
-    def iterate_protobuf_handler(
+    async def iterate_protobuf_handler(
         self,
         peer_id: PeerID,
         name: str,
@@ -498,7 +506,7 @@ class P2P:
         output_protobuf_type: Type[Message],
     ) -> TOutputStream:
         requests = input if isinstance(input, AsyncIterableABC) else as_aiter(input)
-        return self._iterate_protobuf_stream_handler(peer_id, name, requests, output_protobuf_type)
+        return await self._iterate_protobuf_stream_handler(peer_id, name, requests, output_protobuf_type)
 
     def _start_listening(self) -> None:
         async def listen() -> None:

+ 1 - 1
hivemind/p2p/p2p_daemon_bindings/control.py

@@ -26,7 +26,7 @@ SUPPORT_CONN_PROTOCOLS = (
 SUPPORTED_PROTOS = (protocols.protocol_with_code(proto) for proto in SUPPORT_CONN_PROTOCOLS)
 logger = get_logger(__name__)
 
-DEFAULT_MAX_MSG_SIZE = 4 * 1024 ** 2
+DEFAULT_MAX_MSG_SIZE = 4 * 1024**2
 
 
 def parse_conn_protocol(maddr: Multiaddr) -> int:

+ 9 - 26
hivemind/p2p/servicer.py

@@ -86,38 +86,21 @@ class ServicerBase:
     @classmethod
     def _make_rpc_caller(cls, handler: RPCHandler):
         input_type = AsyncIterator[handler.request_type] if handler.stream_input else handler.request_type
+        output_type = AsyncIterator[handler.response_type] if handler.stream_output else handler.response_type
 
         # This method will be added to a new Stub type (a subclass of StubBase)
-        if handler.stream_output:
-
-            def caller(
-                self: StubBase, input: input_type, timeout: None = None
-            ) -> AsyncIterator[handler.response_type]:
-                if timeout is not None:
-                    raise ValueError("Timeouts for handlers returning streams are not supported")
-
-                return self._p2p.iterate_protobuf_handler(
-                    self._peer,
-                    cls._get_handle_name(self._namespace, handler.method_name),
-                    input,
-                    handler.response_type,
-                )
-
-        else:
-
-            async def caller(
-                self: StubBase, input: input_type, timeout: Optional[float] = None
-            ) -> handler.response_type:
+        async def caller(self: StubBase, input: input_type, timeout: Optional[float] = None) -> output_type:
+            handle_name = cls._get_handle_name(self._namespace, handler.method_name)
+            if not handler.stream_output:
                 return await asyncio.wait_for(
-                    self._p2p.call_protobuf_handler(
-                        self._peer,
-                        cls._get_handle_name(self._namespace, handler.method_name),
-                        input,
-                        handler.response_type,
-                    ),
+                    self._p2p.call_protobuf_handler(self._peer, handle_name, input, handler.response_type),
                     timeout=timeout,
                 )
 
+            if timeout is not None:
+                raise ValueError("Timeouts for handlers returning streams are not supported")
+            return await self._p2p.iterate_protobuf_handler(self._peer, handle_name, input, handler.response_type)
+
         caller.__name__ = handler.method_name
         return caller
 

+ 1 - 1
hivemind/proto/averaging.proto

@@ -45,7 +45,7 @@ message AveragingData {
   bytes group_id = 2;       // a unique group identifier, same as in MessageFromLeader
   bytes peer_id = 3;        // sender's rpc peer_id, used for coordination
   Tensor tensor_part = 4;   // either peer's local tensor part (rpc input) or group average of this part (rpc output)
-  bytes metadata = 5;       // reserved user-extendable metadata
+  double weight = 5;        // peers will be averaged in proportion to these weights
 }
 
 message DownloadRequest {}

+ 1 - 0
hivemind/utils/__init__.py

@@ -5,6 +5,7 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from hivemind.utils.mpfuture import *
 from hivemind.utils.nested import *
 from hivemind.utils.networking import *
+from hivemind.utils.performance_ema import PerformanceEMA
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils.tensor_descr import BatchTensorDescriptor, TensorDescriptor
 from hivemind.utils.timed_storage import *

+ 57 - 11
hivemind/utils/asyncio.py

@@ -1,7 +1,8 @@
 import asyncio
 import concurrent.futures
 from concurrent.futures import ThreadPoolExecutor
-from typing import AsyncIterable, AsyncIterator, Awaitable, Callable, Optional, Tuple, TypeVar, Union
+from contextlib import AbstractAsyncContextManager, AbstractContextManager, asynccontextmanager
+from typing import AsyncIterable, AsyncIterator, Awaitable, Callable, ContextManager, Optional, Tuple, TypeVar, Union
 
 import uvloop
 
@@ -105,7 +106,7 @@ async def cancel_and_wait(awaitable: Awaitable) -> bool:
 async def amap_in_executor(
     func: Callable[..., T],
     *iterables: AsyncIterable,
-    max_prefetch: Optional[int] = None,
+    max_prefetch: int = 1,
     executor: Optional[ThreadPoolExecutor] = None,
 ) -> AsyncIterator[T]:
     """iterate from an async iterable in a background thread, yield results to async iterable"""
@@ -113,9 +114,14 @@ async def amap_in_executor(
     queue = asyncio.Queue(max_prefetch)
 
     async def _put_items():
-        async for args in azip(*iterables):
-            await queue.put(loop.run_in_executor(executor, func, *args))
-        await queue.put(None)
+        try:
+            async for args in azip(*iterables):
+                await queue.put(loop.run_in_executor(executor, func, *args))
+            await queue.put(None)
+        except Exception as e:
+            future = asyncio.Future()
+            future.set_exception(e)
+            await queue.put(future)
 
     task = asyncio.create_task(_put_items())
     try:
@@ -123,13 +129,23 @@ async def amap_in_executor(
         while future is not None:
             yield await future
             future = await queue.get()
-        await task
     finally:
-        if not task.done():
-            task.cancel()
-
-
-async def aiter_with_timeout(iterable: AsyncIterable[T], timeout: float) -> AsyncIterator[T]:
+        awaitables = [task]
+        while not queue.empty():
+            future = queue.get_nowait()
+            if future is not None:
+                awaitables.append(future)
+        for coro in awaitables:
+            coro.cancel()
+            try:
+                await coro
+            except BaseException as e:
+                if isinstance(e, Exception):
+                    logger.debug(f"Caught {e} while iterating over inputs", exc_info=True)
+                # note: we do not reraise here because it is already in the finally clause
+
+
+async def aiter_with_timeout(iterable: AsyncIterable[T], timeout: Optional[float]) -> AsyncIterator[T]:
     """Iterate over an async iterable, raise TimeoutError if another portion of data does not arrive within timeout"""
     # based on https://stackoverflow.com/a/50245879
     iterator = iterable.__aiter__()
@@ -138,3 +154,33 @@ async def aiter_with_timeout(iterable: AsyncIterable[T], timeout: float) -> Asyn
             yield await asyncio.wait_for(iterator.__anext__(), timeout=timeout)
         except StopAsyncIteration:
             break
+
+
+async def attach_event_on_finished(iterable: AsyncIterable[T], event: asyncio.Event) -> AsyncIterator[T]:
+    """Iterate over an async iterable and set an event when the iteration has stopped, failed or terminated"""
+    try:
+        async for item in iterable:
+            yield item
+    finally:
+        event.set()
+
+
+class _AsyncContextWrapper(AbstractAsyncContextManager):
+    """Wrapper for a non-async context manager that allows entering and exiting it in EventLoop-friendly manner"""
+
+    def __init__(self, context: AbstractContextManager):
+        self._context = context
+
+    async def __aenter__(self):
+        loop = asyncio.get_event_loop()
+        return await loop.run_in_executor(None, self._context.__enter__)
+
+    async def __aexit__(self, exc_type, exc_value, traceback):
+        return self._context.__exit__(exc_type, exc_value, traceback)
+
+
+@asynccontextmanager
+async def enter_asynchronously(context: AbstractContextManager):
+    """Wrap a non-async context so that it can be entered asynchronously"""
+    async with _AsyncContextWrapper(context) as ret_value:
+        yield ret_value
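
enter_asynchronously exists so that a coroutine can take a blocking context manager (typically a threading.Lock guarding tensors) without stalling the event loop: __enter__ runs in the default executor, while __exit__ runs inline. A short sketch of the intended use, with a made-up lock name that is not part of this diff:

import asyncio
import threading

from hivemind.utils.asyncio import enter_asynchronously

tensor_lock = threading.Lock()  # hypothetical lock shared with non-async code

async def critical_section():
    # acquiring the lock happens in a background thread, so other coroutines keep running meanwhile
    async with enter_asynchronously(tensor_lock):
        await asyncio.sleep(0.1)  # do work while holding the lock

asyncio.run(critical_section())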

+ 1 - 1
hivemind/utils/grpc.py

@@ -175,7 +175,7 @@ class ChannelCache(TimedStorage[ChannelInfo, Tuple[Union[grpc.Channel, grpc.aio.
         raise ValueError(f"Please use {self.__class__.__name__}.get_stub to get or create stubs")
 
 
-STREAMING_CHUNK_SIZE_BYTES = 2 ** 16
+STREAMING_CHUNK_SIZE_BYTES = 2**16
 
 
 def split_for_streaming(

+ 1 - 1
hivemind/utils/limits.py

@@ -3,7 +3,7 @@ from hivemind.utils.logging import get_logger
 logger = get_logger(__name__)
 
 
-def increase_file_limit(new_soft=2 ** 15, new_hard=2 ** 15):
+def increase_file_limit(new_soft=2**15, new_hard=2**15):
     """Increase the maximum number of open files. On Linux, this allows spawning more processes/threads."""
     try:
         import resource  # local import to avoid ImportError for Windows users

+ 10 - 3
hivemind/utils/logging.py

@@ -15,6 +15,9 @@ if _env_colors is not None:
 else:
     use_colors = sys.stderr.isatty()
 
+_env_log_caller = os.getenv("HIVEMIND_ALWAYS_LOG_CALLER")
+always_log_caller = _env_log_caller is not None and _env_log_caller.lower() == "true"
+
 
 class HandlerMode(Enum):
     NOWHERE = 0
@@ -65,8 +68,12 @@ class CustomFormatter(logging.Formatter):
             record.created = record.origin_created
             record.msecs = (record.created - int(record.created)) * 1000
 
-        if not hasattr(record, "caller"):
-            record.caller = f"{record.name}.{record.funcName}:{record.lineno}"
+        if record.levelno != logging.INFO or always_log_caller:
+            if not hasattr(record, "caller"):
+                record.caller = f"{record.name}.{record.funcName}:{record.lineno}"
+            record.caller_block = f" [{TextStyle.BOLD}{record.caller}{TextStyle.RESET}]"
+        else:
+            record.caller_block = ""
 
         # Aliases for the format argument
         record.levelcolor = self._LEVEL_TO_COLOR[record.levelno]
@@ -84,7 +91,7 @@ def _initialize_if_necessary():
             return
 
         formatter = CustomFormatter(
-            fmt="{asctime}.{msecs:03.0f} [{bold}{levelcolor}{levelname}{reset}] [{bold}{caller}{reset}] {message}",
+            fmt="{asctime}.{msecs:03.0f} [{bold}{levelcolor}{levelname}{reset}]{caller_block} {message}",
             style="{",
             datefmt="%b %d %H:%M:%S",
         )

+ 7 - 3
hivemind/utils/mpfuture.py

@@ -18,6 +18,8 @@ from hivemind.utils.logging import get_logger
 
 logger = get_logger(__name__)
 
+torch.multiprocessing.set_sharing_strategy(os.environ.get("HIVEMIND_MEMORY_SHARING_STRATEGY", "file_system"))
+
 # flavour types
 ResultType = TypeVar("ResultType")
 PID, UID, State, PipeEnd = int, int, str, mp.connection.Connection
@@ -138,7 +140,9 @@ class MPFuture(base.Future, Generic[ResultType]):
         async def _event_setter():
             self._aio_event.set()
 
-        if self._loop.is_running() and running_loop == self._loop:
+        if self._loop.is_closed():
+            return  # do nothing, the loop is already closed
+        elif self._loop.is_running() and running_loop == self._loop:
             asyncio.create_task(_event_setter())
         elif self._loop.is_running() and running_loop != self._loop:
             asyncio.run_coroutine_threadsafe(_event_setter(), self._loop)
@@ -201,8 +205,8 @@ class MPFuture(base.Future, Generic[ResultType]):
         try:
             with MPFuture._update_lock if self._use_lock else nullcontext():
                 self._sender_pipe.send((self._uid, update_type, payload))
-        except (ConnectionError, BrokenPipeError, EOFError) as e:
-            logger.debug(f"No updates were sent: pipe to origin process was broken ({e}).", exc_info=True)
+        except (ConnectionError, BrokenPipeError, EOFError, OSError) as e:
+            logger.debug(f"No updates were sent: pipe to origin process was broken ({e})", exc_info=True)
 
     def set_result(self, result: ResultType):
         if os.getpid() == self._origin_pid:

+ 70 - 0
hivemind/utils/performance_ema.py

@@ -0,0 +1,70 @@
+import time
+from contextlib import contextmanager
+from threading import Lock
+from typing import Optional
+
+
+class PerformanceEMA:
+    """
+    A running estimate of performance (operations/sec) using an adjusted exponential moving average
+    :param alpha: Smoothing factor in range [0, 1], [default: 0.1].
+    """
+
+    def __init__(self, alpha: float = 0.1, eps: float = 1e-20, paused: bool = False):
+        self.alpha, self.eps, self.num_updates = alpha, eps, 0
+        self.ema_seconds_per_sample, self.samples_per_second = 0, eps
+        self.timestamp = time.perf_counter()
+        self.paused = paused
+        self.lock = Lock()
+
+    def update(self, task_size: float, interval: Optional[float] = None) -> float:
+        """
+        :param task_size: how many items were processed since last call
+        :param interval: optionally provide the time delta it took to process this task
+        :returns: current estimate of performance (samples per second)
+        """
+        assert task_size > 0, f"Can't register processing {task_size} samples"
+        if not self.paused:
+            self.timestamp, old_timestamp = time.perf_counter(), self.timestamp
+            interval = interval if interval is not None else self.timestamp - old_timestamp
+        else:
+            assert interval is not None, "If PerformanceEMA is paused, please specify the time interval"
+        self.ema_seconds_per_sample = (
+            self.alpha * interval / task_size + (1 - self.alpha) * self.ema_seconds_per_sample
+        )
+        self.num_updates += 1
+        adjusted_seconds_per_sample = self.ema_seconds_per_sample / (1 - (1 - self.alpha) ** self.num_updates)
+        self.samples_per_second = 1 / max(adjusted_seconds_per_sample, self.eps)
+        return self.samples_per_second
+
+    def reset_timer(self):
+        """Reset the time since the last update so that the next task performance is counted from current time"""
+        self.timestamp = time.perf_counter()
+
+    @contextmanager
+    def pause(self):
+        """While inside this context, EMA will not count the time passed towards the performance estimate"""
+        self.paused, was_paused = True, self.paused
+        try:
+            yield
+        finally:
+            self.paused = was_paused
+            self.reset_timer()
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(ema={self.samples_per_second:.5f}, num_updates={self.num_updates})"
+
+    @contextmanager
+    def update_threadsafe(self, task_size: float):
+        """
+        Update the EMA throughput of code that runs inside this context manager; supports multiple concurrent threads.
+
+        :param task_size: how many items were processed since last call
+        """
+        start_timestamp = time.perf_counter()
+        yield
+        with self.lock:
+            self.update(task_size, interval=time.perf_counter() - max(start_timestamp, self.timestamp))
+            # note: we define interval as such to support two distinct scenarios:
+            # (1) if this is the first call to update_threadsafe after a pause, count time from entering this context
+            # (2) if there are concurrent calls to update_threadsafe, respect the timestamp updates from these calls
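
In practice, PerformanceEMA is fed one measurement per processed batch; pause() keeps idle periods (e.g. waiting for peers) from deflating the estimate, and update_threadsafe() measures a block of work from several threads at once. A usage sketch with made-up timings and batch sizes:

import time
from hivemind.utils.performance_ema import PerformanceEMA

ema = PerformanceEMA(alpha=0.1)
for _ in range(5):
    time.sleep(0.05)             # pretend to process a batch of 32 samples
    ema.update(task_size=32)     # returns the smoothed samples-per-second estimate

with ema.pause():                # time spent here is not counted towards throughput
    time.sleep(0.5)

with ema.update_threadsafe(task_size=32):  # measure the enclosed block; safe with concurrent threads
    time.sleep(0.05)
print(ema)                       # e.g. PerformanceEMA(ema=..., num_updates=6)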

+ 2 - 2
hivemind/utils/serializer.py

@@ -35,7 +35,7 @@ class MSGPackSerializer(SerializerBase):
                 getattr(wrapped_type, "unpackb", None)
             ), f"Every ext_type must have 2 methods: packb(self) -> bytes and classmethod unpackb(cls, bytes)"
             if type_code in cls._ext_type_codes:
-                logger.warning(f"{cls.__name__}: type {type_code} is already registered, overwriting.")
+                logger.warning(f"{cls.__name__}: type {type_code} is already registered, overwriting")
             cls._ext_type_codes[type_code], cls._ext_types[wrapped_type] = wrapped_type, type_code
             return wrapped_type
 
@@ -60,7 +60,7 @@ class MSGPackSerializer(SerializerBase):
         elif type_code == cls._TUPLE_EXT_TYPE_CODE:
             return tuple(msgpack.unpackb(data, ext_hook=cls._decode_ext_types, raw=False))
 
-        logger.warning(f"Unknown ExtType code: {type_code}, leaving it as is.")
+        logger.warning(f"Unknown ExtType code: {type_code}, leaving it as is")
         return data
 
     @classmethod

+ 54 - 5
hivemind/utils/tensor_descr.py

@@ -8,6 +8,7 @@ import numpy as np
 import torch
 
 from hivemind.proto.runtime_pb2 import CompressionType
+from hivemind.utils.serializer import MSGPackSerializer
 
 DUMMY_BATCH_SIZE = 3  # used for dummy runs only
 
@@ -45,13 +46,25 @@ class TensorDescriptor(DescriptorBase):
             tensor.shape, tensor.dtype, tensor.layout, tensor.device, tensor.requires_grad, _safe_check_pinned(tensor)
         )
 
-    def make_empty(self, **kwargs):
+    def make_zeros(self, **kwargs):
         properties = asdict(self)
         properties.update(kwargs)
         properties.pop("compression")
-        return torch.empty(**properties)
+        return torch.zeros(**properties)
 
 
+def _str_to_torch_type(name: str, torch_type: type):
+    try:
+        value = getattr(torch, name.split(".")[-1])
+    except AttributeError:
+        raise ValueError(f"Invalid dtype: torch has no attribute {name}")
+    if not isinstance(value, torch_type):
+        raise ValueError(f"Invalid dtype: expected {torch_type}, got: {type(value)}")
+
+    return value
+
+
+@MSGPackSerializer.ext_serializable(0x51)
 @dataclass(repr=True, frozen=True)
 class BatchTensorDescriptor(TensorDescriptor):
     """torch.Tensor with a variable 0-th dimension, used to describe batched data"""
@@ -70,12 +83,48 @@ class BatchTensorDescriptor(TensorDescriptor):
             device=tensor.device,
             requires_grad=tensor.requires_grad,
             pin_memory=_safe_check_pinned(tensor),
-            compression=compression if tensor.is_floating_point() else CompressionType.NONE
+            compression=compression if tensor.is_floating_point() else CompressionType.NONE,
         )
 
-    def make_empty(self, *batch_size: int, **kwargs) -> torch.Tensor:
+    def make_zeros(self, *batch_size: int, **kwargs) -> torch.Tensor:
         assert self.shape[0] is None, "Make sure 0-th dimension is not specified (set to None)"
-        return super().make_empty(size=(*batch_size, *self.shape[1:]), **kwargs)
+        return super().make_zeros(size=(*batch_size, *self.shape[1:]), **kwargs)
+
+    def packb(self) -> bytes:
+        obj_dict = asdict(self)
+
+        obj_dict["dtype"] = str(self.dtype) if self.dtype is not None else None
+        obj_dict["layout"] = str(self.layout) if self.layout is not None else None
+
+        device = obj_dict.pop("device")
+        device_type, device_index = (device.type, device.index) if device is not None else (None, None)
+        obj_dict.update(
+            device_type=device_type,
+            device_index=device_index,
+        )
+
+        return MSGPackSerializer.dumps(obj_dict)
+
+    @classmethod
+    def unpackb(cls, raw: bytes) -> BatchTensorDescriptor:
+        obj_dict = MSGPackSerializer.loads(raw)
+
+        if obj_dict["dtype"] is not None:
+            obj_dict["dtype"] = _str_to_torch_type(obj_dict["dtype"], torch.dtype)
+
+        if obj_dict["layout"] is not None:
+            obj_dict["layout"] = _str_to_torch_type(obj_dict["layout"], torch.layout)
+
+        if obj_dict["device_type"] is not None:
+            obj_dict["device"] = torch.device(obj_dict["device_type"], obj_dict["device_index"])
+        else:
+            obj_dict["device"] = None
+
+        del obj_dict["device_type"], obj_dict["device_index"]
+
+        size = obj_dict.pop("size")[1:]
+
+        return BatchTensorDescriptor(*size, **obj_dict)
 
 
 def _safe_check_pinned(tensor: torch.Tensor) -> bool:
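
Because BatchTensorDescriptor is now registered as MSGPack ext type 0x51, it can be embedded in any MSGPackSerializer payload and reconstructed on the receiving side. A round-trip sketch under that assumption (the tensor shape is chosen arbitrarily):

import torch
from hivemind.utils.serializer import MSGPackSerializer
from hivemind.utils.tensor_descr import BatchTensorDescriptor

descr = BatchTensorDescriptor.from_tensor(torch.randn(3, 16))  # batch (0-th) dimension is stored as None
raw = MSGPackSerializer.dumps(descr)                           # packb() turns dtype/layout/device into strings
restored = MSGPackSerializer.loads(raw)                        # unpackb() rebuilds the descriptor
zeros = restored.make_zeros(4)                                 # a (4, 16) zero tensor with the original dtype/device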

+ 1 - 1
pyproject.toml

@@ -1,6 +1,6 @@
 [tool.black]
 line-length = 119
-required-version = "21.6b0"
+required-version = "22.1.0"
 
 [tool.isort]
 profile = "black"

+ 6 - 4
requirements-dev.txt

@@ -1,9 +1,11 @@
-pytest
+pytest==6.2.5  # see https://github.com/pytest-dev/pytest/issues/9621
 pytest-forked
-pytest-asyncio
+pytest-asyncio==0.16.0
 pytest-cov
+coverage==6.0.2  # see https://github.com/pytest-dev/pytest-cov/issues/520
 tqdm
 scikit-learn
-black==21.6b0
-isort
+torchvision
+black==22.1.0
+isort==5.10.1
 psutil

+ 4 - 2
requirements-docs.txt

@@ -1,2 +1,4 @@
-recommonmark
-sphinx_rtd_theme
+recommonmark==0.5.0
+sphinx_rtd_theme==0.4.3
+docutils==0.16
+sphinx==4.2.0

+ 5 - 2
tests/conftest.py

@@ -1,13 +1,13 @@
 import asyncio
 import gc
-import multiprocessing as mp
 from contextlib import suppress
 
 import psutil
 import pytest
 
+from hivemind.utils.crypto import RSAPrivateKey
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
-from hivemind.utils.mpfuture import MPFuture, SharedBytes
+from hivemind.utils.mpfuture import MPFuture
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
@@ -33,6 +33,9 @@ def event_loop():
 def cleanup_children():
     yield
 
+    with RSAPrivateKey._process_wide_key_lock:
+        RSAPrivateKey._process_wide_key = None
+
     gc.collect()  # Call .__del__() for removed objects
 
     children = psutil.Process().children(recursive=True)

+ 11 - 11
tests/test_allreduce.py

@@ -33,7 +33,7 @@ async def test_partitioning():
 
     # note: this test does _not_ use parameterization to reuse sampled tensors
     for num_tensors in 1, 3, 5:
-        for part_size_bytes in 31337, 2 ** 20, 10 ** 10:
+        for part_size_bytes in 31337, 2**20, 10**10:
             for weights in [(1, 1), (0.333, 0.1667, 0.5003), (1.0, 0.0), [0.0, 0.4, 0.6, 0.0]]:
                 tensors = random.choices(all_tensors, k=num_tensors)
                 partition = TensorPartContainer(tensors, weights, part_size_bytes=part_size_bytes)
@@ -157,16 +157,16 @@ NODE, CLIENT, AUX = AveragingMode.NODE, AveragingMode.CLIENT, AveragingMode.AUX
 @pytest.mark.parametrize(
     "peer_modes, averaging_weights, peer_fractions, part_size_bytes",
     [
-        ((NODE, NODE, NODE, NODE), (1, 1, 1, 1), (1, 1, 1, 1), 2 ** 20),
-        ((NODE, NODE, NODE, NODE), (0.1, 0.2, 0.3, 0.4), (1, 1, 1, 1), 2 ** 20),
-        ((NODE, NODE, NODE, NODE), (1, 1, 1, 1), (1, 2, 3, 0), 2 ** 20),
-        ((NODE, NODE, NODE, CLIENT), (1, 1, 1, 1), (1, 2, 3, 0), 2 ** 20),
-        ((NODE, NODE, NODE, AUX), (1, 1, 1, 0), (1, 2, 3, 4), 2 ** 20),
-        ((NODE, NODE, NODE, NODE), (0.15, 0.0, 0.35, 0.45), (1, 1, 1, 1), 2 ** 20),
-        ((NODE, AUX, NODE, CLIENT), (0.15, 0.0, 0.35, 0.45), (150, 200, 67, 0), 2 ** 20),
+        ((NODE, NODE, NODE, NODE), (1, 1, 1, 1), (1, 1, 1, 1), 2**20),
+        ((NODE, NODE, NODE, NODE), (0.1, 0.2, 0.3, 0.4), (1, 1, 1, 1), 2**20),
+        ((NODE, NODE, NODE, NODE), (1, 1, 1, 1), (1, 2, 3, 0), 2**20),
+        ((NODE, NODE, NODE, CLIENT), (1, 1, 1, 1), (1, 2, 3, 0), 2**20),
+        ((NODE, NODE, NODE, AUX), (1, 1, 1, 0), (1, 2, 3, 4), 2**20),
+        ((NODE, NODE, NODE, NODE), (0.15, 0.0, 0.35, 0.45), (1, 1, 1, 1), 2**20),
+        ((NODE, AUX, NODE, CLIENT), (0.15, 0.0, 0.35, 0.45), (150, 200, 67, 0), 2**20),
         ((NODE, AUX, NODE, CLIENT), (0.15, 0.0, 0.35, 0.45), (150, 200, 67, 0), 256),
         ((NODE, AUX, NODE, CLIENT), (0.15, 0.0, 0.35, 0.45), (150, 200, 67, 0), 19),
-        ((AUX, AUX, AUX, AUX), (0.0, 0.0, 0.0, 0.0), (1, 2, 3, 4), 2 ** 20),
+        ((AUX, AUX, AUX, AUX), (0.0, 0.0, 0.0, 0.0), (1, 2, 3, 4), 2**20),
     ],
 )
 @pytest.mark.forked
@@ -187,7 +187,7 @@ async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions,
     group_id = random.getrandbits(160).to_bytes(length=20, byteorder="big")
 
     allreduce_protocols = []
-    for p2p in p2ps:
+    for i, p2p in enumerate(p2ps):
         allreduce_protocol = AllReduceRunner(
             p2p=p2p,
             servicer_type=AllReduceRunner,
@@ -197,7 +197,7 @@ async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions,
             ordered_peer_ids=peers,
             peer_fractions=peer_fractions,
             modes=peer_modes,
-            weights=averaging_weights,
+            weight=averaging_weights[i],
             part_size_bytes=part_size_bytes,
         )
         await allreduce_protocol.add_p2p_handlers(p2p)

+ 213 - 0
tests/test_allreduce_fault_tolerance.py

@@ -0,0 +1,213 @@
+from __future__ import annotations
+
+import asyncio
+from enum import Enum, auto
+from typing import AsyncIterator
+
+import pytest
+import torch
+
+import hivemind
+from hivemind.averaging.allreduce import AllReduceRunner, AveragingMode
+from hivemind.averaging.averager import *
+from hivemind.averaging.group_info import GroupInfo
+from hivemind.averaging.load_balancing import load_balance_peers
+from hivemind.averaging.matchmaking import MatchmakingException
+from hivemind.proto import averaging_pb2
+from hivemind.utils.asyncio import aenumerate, as_aiter, azip, enter_asynchronously
+from hivemind.utils.logging import get_logger
+
+logger = get_logger(__name__)
+
+
+class Fault(Enum):
+    NONE = auto()
+    FAIL_BEFORE = auto()
+    FAIL_SENDING = auto()
+    SLOW_SENDING = auto()
+    FAIL_REDUCING = auto()
+    SLOW_REDUCING = auto()
+    CANCEL = auto()
+
+
+class FaultyAverager(hivemind.DecentralizedAverager):
+    def __init__(self, *args, fault: Fault = Fault.NONE, **kwargs):
+        self.fault = fault
+        super().__init__(*args, **kwargs)
+
+    async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
+        """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
+        try:
+            bandwidths, mode_ids, user_gathered_bytes = zip(*map(self.serializer.loads, group_info.gathered))
+            user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered_bytes)))
+            modes = tuple(map(AveragingMode, mode_ids))
+            download_bandwidths = [
+                thr if mode != AveragingMode.CLIENT else 0.0 for thr, mode in zip(bandwidths, modes)
+            ]
+            peer_fractions = await asyncio.get_event_loop().run_in_executor(
+                None, load_balance_peers, self.total_size, download_bandwidths, min_vector_size
+            )
+
+            if self.fault == Fault.FAIL_BEFORE:
+                raise Exception("Oops, I failed!")
+
+            async with enter_asynchronously(self.get_tensors()) as local_tensors:
+                allreduce = FaultyAllReduceRunner(
+                    p2p=self._p2p,
+                    servicer_type=type(self),
+                    prefix=self.prefix,
+                    group_id=group_info.group_id,
+                    tensors=local_tensors,
+                    ordered_peer_ids=group_info.peer_ids,
+                    peer_fractions=peer_fractions,
+                    gathered=user_gathered,
+                    modes=modes,
+                    fault=self.fault,
+                    **kwargs,
+                )
+
+                with self.register_allreduce_group(group_info.group_id, allreduce):
+                    if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
+                        async for tensor, update in azip(as_aiter(*local_tensors), allreduce):
+                            # all-reduce is performed asynchronously while iterating
+                            tensor.add_(update, alpha=self._averaging_alpha)
+                        self._state_updated.set()
+
+                    else:
+                        async for _ in allreduce:  # trigger all-reduce by iterating
+                            raise ValueError("aux peers should not receive averaged tensors")
+
+                return allreduce.gathered
+        except BaseException as e:
+            logger.exception(e)
+            raise MatchmakingException(f"Unable to run All-Reduce: {e}")
+
+
+class FaultyAllReduceRunner(AllReduceRunner):
+    def __init__(self, *args, fault: Fault, **kwargs):
+        self.fault = fault
+        super().__init__(*args, **kwargs)
+
+    async def rpc_aggregate_part(self, stream, context) -> AsyncIterator[averaging_pb2.AveragingData]:
+        if self.fault in (Fault.FAIL_REDUCING, Fault.SLOW_REDUCING):
+            async for i, message in aenumerate(super().rpc_aggregate_part(stream, context)):
+                yield message
+                if i == 2:
+                    if self.fault == Fault.FAIL_REDUCING:
+                        yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
+                        break
+                    else:
+                        await asyncio.sleep(10)
+
+        elif self.fault == Fault.CANCEL:
+            yield averaging_pb2.AveragingData(code=averaging_pb2.CANCELLED)
+        else:
+            async for message in super().rpc_aggregate_part(stream, context):
+                yield message
+
+    async def _generate_input_for_peer(self, peer_index: int) -> AsyncIterator[averaging_pb2.AveragingData]:
+        parts_aiter = self.tensor_part_container.iterate_input_parts_for(peer_index)
+
+        first_part = await anext(parts_aiter)
+        yield averaging_pb2.AveragingData(
+            code=averaging_pb2.PART_FOR_AVERAGING,
+            group_id=self.group_id,
+            tensor_part=first_part,
+            weight=self.weight,
+        )
+        if self.fault in (Fault.FAIL_SENDING, Fault.SLOW_SENDING):
+            last_reducer_index = self.group_size - 1 - (self.tensor_part_container.num_parts_by_peer[-1] == 0)
+            if peer_index == last_reducer_index:
+                if self.fault == Fault.FAIL_SENDING:
+                    raise Exception("Oops, I failed!")
+                else:
+                    await asyncio.sleep(10)
+        async for part in parts_aiter:
+            yield averaging_pb2.AveragingData(tensor_part=part, weight=self.weight)
+
+
+@pytest.mark.forked
+@pytest.mark.parametrize(
+    "fault0, fault1",
+    [
+        (Fault.NONE, Fault.FAIL_BEFORE),
+        (Fault.FAIL_BEFORE, Fault.FAIL_BEFORE),
+        (Fault.SLOW_SENDING, Fault.FAIL_SENDING),
+        (Fault.FAIL_SENDING, Fault.FAIL_BEFORE),
+        (Fault.SLOW_REDUCING, Fault.FAIL_SENDING),
+        (Fault.FAIL_REDUCING, Fault.FAIL_REDUCING),
+        (Fault.NONE, Fault.CANCEL),
+    ],
+)
+def test_fault_tolerance(fault0: Fault, fault1: Fault):
+    def _make_tensors():
+        return [torch.rand(16, 1024), -torch.rand(3, 8192), 2 * torch.randn(4, 4, 4), torch.randn(1024, 1024)]
+
+    dht = hivemind.DHT(start=True)
+
+    averagers = []
+    for i in range(5):
+        averager = FaultyAverager(
+            _make_tensors(),
+            hivemind.DHT(initial_peers=dht.get_visible_maddrs(), start=True),
+            prefix="test",
+            request_timeout=0.3,
+            min_matchmaking_time=1.0,
+            next_chunk_timeout=0.5,
+            allreduce_timeout=5,
+            part_size_bytes=2**16,
+            client_mode=(i == 1),
+            start=True,
+            fault=fault0 if i == 0 else fault1 if i == 1 else Fault.NONE,
+        )
+        averagers.append(averager)
+
+    ref_numerators = [0, 0, 0, 0]
+    ref_denominator = 0
+
+    for averager in averagers:
+        if averager.fault not in (Fault.FAIL_BEFORE, Fault.CANCEL):
+            with averager.get_tensors() as tensors:
+                for i, tensor in enumerate(tensors):
+                    ref_numerators[i] = ref_numerators[i] + tensor.clone()
+                ref_denominator += 1
+
+    ref_tensors = [ref_numerator / ref_denominator for ref_numerator in ref_numerators]
+    flat_ref = torch.cat(list(map(torch.flatten, ref_tensors)))
+
+    flat_local_tensors = []
+    for averager in averagers:
+        with averager.get_tensors() as tensors:
+            flat_local_tensors.append(torch.cat(list(map(torch.flatten, tensors))))
+
+    futures = [averager.step(timeout=5, wait=False, allow_retries=False) for averager in averagers]
+    for i, averager in enumerate(averagers):
+        if averager.fault == Fault.CANCEL:
+            futures[i].cancel()
+
+    for future in futures[2:]:
+        assert future.result()
+
+    for averager, prev_local_tensors in zip(averagers[2:], flat_local_tensors[2:]):
+        with averager.get_tensors() as tensors:
+            flat_tensors = torch.cat(list(map(torch.flatten, tensors)))
+
+        diff_with_reference = abs(flat_ref - flat_tensors)
+
+        if all(fault in (Fault.FAIL_SENDING, Fault.SLOW_SENDING) for fault in (fault0, fault1)):
+            assert fault0 != Fault.FAIL_REDUCING and fault1 != Fault.FAIL_REDUCING
+            assert diff_with_reference[: len(diff_with_reference) // 2].max() < 1e-5
+        elif all(fault in (Fault.FAIL_REDUCING, Fault.SLOW_REDUCING) for fault in (fault0, fault1)):
+            diff_to_reference = abs(flat_ref - flat_tensors)
+            diff_to_local = abs(prev_local_tensors - flat_tensors)
+            assert (diff_with_reference < 1e-5).numpy().mean() > 0.5
+            assert torch.all(torch.minimum(diff_to_reference, diff_to_local) < 1e-5).item()
+        elif any(fault == Fault.CANCEL for fault in (fault0, fault1)):
+            pass  # late cancel may result in an arbitrary mix of averaging results with and without the cancelled peer
+        elif fault0 == Fault.NONE:  # only peer1 in client mode may have failed
+            assert diff_with_reference.max() < 1e-5
+        else:
+            assert (diff_with_reference < 1e-5).numpy().mean() > 0.5
+
+    for averager in averagers:
+        averager.shutdown()

+ 126 - 4
tests/test_averaging.py

@@ -1,4 +1,5 @@
 import random
+import time
 
 import numpy as np
 import pytest
@@ -7,6 +8,7 @@ import torch
 import hivemind
 import hivemind.averaging.averager
 from hivemind.averaging.allreduce import AveragingMode
+from hivemind.averaging.control import AveragingStage
 from hivemind.averaging.key_manager import GroupKeyManager
 from hivemind.averaging.load_balancing import load_balance_peers
 from hivemind.averaging.partition import AllreduceException
@@ -281,7 +283,7 @@ def test_load_balancing():
         load_balance_peers(100, (0, 0, 0))
 
     for i in range(10):
-        vector_size = np.random.randint(1, 1024 ** 3)
+        vector_size = np.random.randint(1, 1024**3)
         num_peers = np.random.randint(1, 256)
         scale = 1e-9 + np.random.rand() * 1e5
         bandwidths = np.random.rand(num_peers) * scale + 1e-6
@@ -370,7 +372,6 @@ def test_load_state_from_peers():
         target_group_size=2,
     )
 
-    dht_instances[1].get("demo-run.all_averagers")
     averager2 = TestAverager(
         [torch.randn(3), torch.rand(5)],
         dht=dht_instances[1],
@@ -379,6 +380,8 @@ def test_load_state_from_peers():
         target_group_size=2,
     )
 
+    time.sleep(0.5)
+
     assert num_calls == 0
     got_metadata, got_tensors = averager2.load_state_from_peers()
     assert num_calls == 1
@@ -397,7 +400,9 @@ def test_load_state_from_peers():
 
     averager1.allow_state_sharing = False
     assert averager2.load_state_from_peers() is None
+
     averager1.allow_state_sharing = True
+    time.sleep(0.5)
     got_metadata, got_tensors = averager2.load_state_from_peers()
     assert num_calls == 3
     assert got_metadata == super_metadata
@@ -406,6 +411,47 @@ def test_load_state_from_peers():
         instance.shutdown()
 
 
+@pytest.mark.forked
+def test_load_state_priority():
+    dht_instances = launch_dht_instances(4)
+
+    averagers = []
+    for i in range(4):
+        averager = hivemind.DecentralizedAverager(
+            [torch.randn(3), torch.rand(5), torch.tensor([i], dtype=torch.float32)],
+            dht=dht_instances[i],
+            start=True,
+            prefix="demo-run",
+            target_group_size=2,
+            allow_state_sharing=i != 1,
+        )
+        averager.state_sharing_priority = 5 - abs(2 - i)
+        averagers.append(averager)
+
+    time.sleep(0.5)
+    metadata, tensors = averagers[0].load_state_from_peers(timeout=1)
+    assert tensors[-1].item() == 2
+
+    metadata, tensors = averagers[2].load_state_from_peers(timeout=1)
+    assert tensors[-1].item() == 3
+
+    averagers[0].state_sharing_priority = 10
+    time.sleep(0.2)
+
+    metadata, tensors = averagers[2].load_state_from_peers(timeout=1)
+    assert tensors[-1].item() == 0
+
+    averagers[1].allow_state_sharing = False
+    averagers[2].allow_state_sharing = False
+    metadata, tensors = averagers[0].load_state_from_peers(timeout=1)
+    assert tensors[-1].item() == 3
+
+    for averager in averagers:
+        averager.shutdown()
+    for dht in dht_instances:
+        dht.shutdown()
+
+
 @pytest.mark.forked
 def test_getset_bits():
     dht = hivemind.DHT(start=True)
@@ -420,6 +466,82 @@ def test_getset_bits():
     assert averager.get_group_bits() == "00101011101010"
 
 
+@pytest.mark.forked
+def test_averaging_trigger():
+    averagers = tuple(
+        hivemind.averaging.DecentralizedAverager(
+            averaged_tensors=[torch.randn(3)],
+            dht=dht,
+            min_matchmaking_time=0.5,
+            request_timeout=0.3,
+            prefix="mygroup",
+            initial_group_bits="",
+            start=True,
+        )
+        for dht in launch_dht_instances(4)
+    )
+
+    controls = []
+    for i, averager in enumerate(averagers):
+        controls.append(
+            averager.step(
+                wait=False,
+                scheduled_time=hivemind.get_dht_time() + 0.5,
+                weight=1.0,
+                require_trigger=i in (1, 2),
+            )
+        )
+
+    time.sleep(0.6)
+
+    c0, c1, c2, c3 = controls
+    assert not any(c.done() for c in controls)
+    assert c0.stage == AveragingStage.RUNNING_ALLREDUCE
+    assert c1.stage == AveragingStage.AWAITING_TRIGGER
+    assert c2.stage == AveragingStage.AWAITING_TRIGGER
+    assert c3.stage == AveragingStage.RUNNING_ALLREDUCE
+
+    c1.allow_allreduce()
+    c2.allow_allreduce()
+    time.sleep(0.5)
+    assert all(c.stage == AveragingStage.FINISHED for c in controls)
+    assert all(c.done() for c in controls)
+
+    # check that setting trigger twice does not raise error
+    c0.allow_allreduce()
+
+
+@pytest.mark.forked
+def test_averaging_cancel():
+    averagers = tuple(
+        hivemind.averaging.DecentralizedAverager(
+            averaged_tensors=[torch.randn(3)],
+            dht=dht,
+            min_matchmaking_time=0.5,
+            request_timeout=0.3,
+            client_mode=(i % 2 == 0),
+            prefix="mygroup",
+            start=True,
+        )
+        for i, dht in enumerate(launch_dht_instances(4))
+    )
+
+    step_controls = [averager.step(wait=False, scheduled_time=hivemind.get_dht_time() + 1) for averager in averagers]
+
+    time.sleep(0.1)
+    step_controls[0].cancel()
+    step_controls[1].cancel()
+
+    for i, control in enumerate(step_controls):
+        if i in (0, 1):
+            assert control.cancelled()
+        else:
+            assert control.result() is not None and len(control.result()) == 2
+
+    for averager in averagers:
+        averager.shutdown()
+
+
 @pytest.mark.forked
 def test_training_averager(n_steps: int = 10, n_dims: int = 16):
     torch.manual_seed(42)
@@ -433,7 +555,7 @@ def test_training_averager(n_steps: int = 10, n_dims: int = 16):
 
     x1 = torch.randn(n_dims, requires_grad=True)
     opt1 = torch.optim.Adam([x1], lr=0.05)
-    averager1 = hivemind.averaging.TrainingAverager(
+    averager1 = hivemind.TrainingAverager(
         opt1,
         average_gradients=True,
         average_parameters=True,
@@ -444,7 +566,7 @@ def test_training_averager(n_steps: int = 10, n_dims: int = 16):
 
     x2 = torch.randn(n_dims, requires_grad=True)
     opt2 = torch.optim.Adam([x2], lr=0.05)
-    averager2 = hivemind.averaging.TrainingAverager(
+    averager2 = hivemind.TrainingAverager(
         opt2,
         average_gradients=True,
         average_parameters=True,

+ 1 - 1
tests/test_compression.py

@@ -53,7 +53,7 @@ def test_serialize_tensor():
         assert torch.allclose(deserialize_torch_tensor(restored), tensor, rtol=rtol, atol=atol)
 
     tensor = torch.randn(512, 12288)
-    for chunk_size in [1024, 64 * 1024, 64 * 1024 + 1, 10 ** 9]:
+    for chunk_size in [1024, 64 * 1024, 64 * 1024 + 1, 10**9]:
         _check(tensor, CompressionType.NONE, chunk_size=chunk_size)
 
     _check(tensor, CompressionType.FLOAT16, rtol=0.0, atol=1e-2)

+ 2 - 2
tests/test_dht.py

@@ -72,7 +72,7 @@ async def dummy_dht_coro_stateful(self, node):
 
 async def dummy_dht_coro_long(self, node):
     await asyncio.sleep(0.25)
-    return self._x_dummy ** 2
+    return self._x_dummy**2
 
 
 async def dummy_dht_coro_for_cancel(self, node):
@@ -94,7 +94,7 @@ def test_run_coroutine():
     assert dht.run_coroutine(dummy_dht_coro_stateful) == 125
     assert dht.run_coroutine(dummy_dht_coro_stateful) == 126
     assert not hasattr(dht, "_x_dummy")
-    assert bg_task.result() == 126 ** 2
+    assert bg_task.result() == 126**2
 
     future = dht.run_coroutine(dummy_dht_coro_for_cancel, return_future=True)
     time.sleep(0.25)

+ 31 - 34
tests/test_moe.py

@@ -3,9 +3,12 @@ import numpy as np
 import pytest
 import torch
 
-import hivemind
-from hivemind.moe.client.expert import DUMMY
-from hivemind.moe.server import background_server, declare_experts, layers
+from hivemind.dht import DHT
+from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
+from hivemind.moe.client.moe import DUMMY, _RemoteCallMany
+from hivemind.moe.server import ExpertBackend, Server, background_server, declare_experts
+from hivemind.moe.server.layers import name_to_block
+from hivemind.utils.tensor_descr import BatchTensorDescriptor
 
 
 @pytest.mark.forked
@@ -16,11 +19,9 @@ def test_moe():
     with background_server(
         expert_uids=all_expert_uids, device="cpu", expert_cls="ffn", num_handlers=1, hidden_dim=16
     ) as (server_endpoint, dht_maddrs):
-        dht = hivemind.DHT(start=True, initial_peers=dht_maddrs)
+        dht = DHT(start=True, initial_peers=dht_maddrs)
 
-        dmoe = hivemind.RemoteMixtureOfExperts(
-            in_features=16, grid_size=(4, 4, 4), dht=dht, k_best=3, uid_prefix="ffn."
-        )
+        dmoe = RemoteMixtureOfExperts(in_features=16, grid_size=(4, 4, 4), dht=dht, k_best=3, uid_prefix="ffn.")
 
         for i in range(3):
             out = dmoe(torch.randn(10, 16))
@@ -35,9 +36,9 @@ def test_no_experts():
     with background_server(
         expert_uids=all_expert_uids, device="cpu", expert_cls="nop_delay", num_handlers=1, hidden_dim=16
     ) as (server_endpoint, dht_maddrs):
-        dht = hivemind.DHT(start=True, initial_peers=dht_maddrs)
+        dht = DHT(start=True, initial_peers=dht_maddrs)
 
-        dmoe = hivemind.RemoteSwitchMixtureOfExperts(
+        dmoe = RemoteSwitchMixtureOfExperts(
             in_features=16,
             grid_size=(4, 4, 4),
             dht=dht,
@@ -74,10 +75,10 @@ def test_call_many(hidden_dim=16):
     ) as (server_endpoint, _):
         inputs = torch.randn(4, hidden_dim, requires_grad=True)
         inputs_clone = inputs.clone().detach().requires_grad_(True)
-        e0, e1, e2, e3, e4 = [hivemind.RemoteExpert(f"expert.{i}", server_endpoint) for i in range(5)]
-        e5 = hivemind.RemoteExpert(f"thisshouldnotexist", "127.0.0.1:80")
+        e0, e1, e2, e3, e4 = [RemoteExpert(f"expert.{i}", server_endpoint) for i in range(5)]
+        e5 = RemoteExpert(f"thisshouldnotexist", "127.0.0.1:80")
 
-        mask, expert_outputs = hivemind.moe.client.moe._RemoteCallMany.apply(
+        mask, expert_outputs = _RemoteCallMany.apply(
             DUMMY,
             [[e0, e1, e2], [e2, e4], [e1, e5, e3], []],
             k_min,
@@ -130,8 +131,8 @@ def test_remote_module_call(hidden_dim=16):
         optim_cls=None,
         no_dht=True,
     ) as (server_endpoint, _):
-        real_expert = hivemind.RemoteExpert("expert.0", server_endpoint)
-        fake_expert = hivemind.RemoteExpert("oiasfjiasjf", server_endpoint)
+        real_expert = RemoteExpert("expert.0", server_endpoint)
+        fake_expert = RemoteExpert("oiasfjiasjf", server_endpoint)
 
         out1 = real_expert(torch.randn(1, hidden_dim))
         assert out1.shape == (1, hidden_dim)
@@ -152,12 +153,10 @@ def test_remote_module_call(hidden_dim=16):
 @pytest.mark.forked
 def test_beam_search_correctness():
     all_expert_uids = [f"ffn.{5 + i}.{10 + j}.{15 + k}" for i in range(10) for j in range(10) for k in range(10)]
-    dht = hivemind.DHT(start=True)
+    dht = DHT(start=True)
     assert all(declare_experts(dht, all_expert_uids, endpoint="fake-endpoint"))
 
-    dmoe = hivemind.RemoteMixtureOfExperts(
-        in_features=32, grid_size=(32, 32, 32), dht=dht, k_best=4, uid_prefix="ffn."
-    )
+    dmoe = RemoteMixtureOfExperts(in_features=32, grid_size=(32, 32, 32), dht=dht, k_best=4, uid_prefix="ffn.")
 
     for i in range(25):
         input = torch.randn(32)
@@ -174,7 +173,7 @@ def test_beam_search_correctness():
         # reference: independently find :beam_size: best experts with exhaustive search
         all_scores = dmoe.compute_expert_scores(
             [dim_scores.unsqueeze(0) for dim_scores in grid_scores],
-            [[hivemind.RemoteExpert(uid, "") for uid in all_expert_uids]],
+            [[RemoteExpert(uid, "") for uid in all_expert_uids]],
         )[0]
         true_best_scores = sorted(all_scores.cpu().detach().numpy(), reverse=True)[: len(chosen_experts)]
 
@@ -197,7 +196,7 @@ def test_determinism(hidden_dim=16):
         optim_cls=None,
         no_dht=True,
     ) as (server_endpoint, _):
-        expert = hivemind.RemoteExpert(uid=f"expert.0", endpoint=server_endpoint)
+        expert = RemoteExpert(uid="expert.0", endpoint=server_endpoint)
 
         out = expert(xx, mask)
         out_rerun = expert(xx, mask)
@@ -212,8 +211,8 @@ def test_determinism(hidden_dim=16):
 @pytest.mark.forked
 def test_compute_expert_scores():
     try:
-        dht = hivemind.DHT(start=True)
-        moe = hivemind.moe.RemoteMixtureOfExperts(
+        dht = DHT(start=True)
+        moe = RemoteMixtureOfExperts(
             dht=dht, in_features=16, grid_size=(40,), k_best=4, k_min=1, timeout_after_k_min=1, uid_prefix="expert."
         )
         gx, gy = torch.randn(4, 5, requires_grad=True), torch.randn(4, 3, requires_grad=True)
@@ -221,13 +220,11 @@ def test_compute_expert_scores():
         jj = [[2, 2, 1], [0, 1, 2, 0, 1], [0], [1, 2]]
         batch_experts = [
             [
-                hivemind.RemoteExpert(
-                    uid=f"expert.{ii[batch_i][expert_i]}.{jj[batch_i][expert_i]}", endpoint="[::]:1337"
-                )
+                RemoteExpert(uid=f"expert.{ii[batch_i][expert_i]}.{jj[batch_i][expert_i]}", endpoint="[::]:1337")
                 for expert_i in range(len(ii[batch_i]))
             ]
             for batch_i in range(len(ii))
-        ]  # note: these experts do not exists on server, we use them only to test moe compute_expert_scores
+        ]  # note: these experts do not exist on the server; we use them only to test compute_expert_scores
         logits = moe.compute_expert_scores([gx, gy], batch_experts)
         torch.softmax(logits, dim=-1).norm(dim=-1).mean().backward()
         assert gx.grad.norm().item() > 0 and gy.grad.norm().item(), "compute_expert_scores didn't backprop"
@@ -247,25 +244,25 @@ def test_client_anomaly_detection():
 
     experts = {}
     for i in range(4):
-        expert = layers.name_to_block["ffn"](HID_DIM)
-        experts[f"expert.{i}"] = hivemind.ExpertBackend(
+        expert = name_to_block["ffn"](HID_DIM)
+        experts[f"expert.{i}"] = ExpertBackend(
             name=f"expert.{i}",
             expert=expert,
             optimizer=torch.optim.Adam(expert.parameters()),
-            args_schema=(hivemind.BatchTensorDescriptor(HID_DIM),),
-            outputs_schema=hivemind.BatchTensorDescriptor(HID_DIM),
+            args_schema=(BatchTensorDescriptor(HID_DIM),),
+            outputs_schema=BatchTensorDescriptor(HID_DIM),
             max_batch_size=16,
         )
 
     experts["expert.3"].expert.ffn.weight.data[0, 0] = float("nan")
 
-    dht = hivemind.DHT(start=True)
-    server = hivemind.moe.Server(dht, experts, num_connection_handlers=1)
+    dht = DHT(start=True)
+    server = Server(dht, experts, num_connection_handlers=1)
     server.start()
     try:
         server.ready.wait()
 
-        dmoe = hivemind.RemoteMixtureOfExperts(
+        dmoe = RemoteMixtureOfExperts(
             in_features=16, grid_size=(3,), dht=dht, k_best=3, uid_prefix="expert.", detect_anomalies=True
         )
 
@@ -282,7 +279,7 @@ def test_client_anomaly_detection():
         with pytest.raises(ValueError):
             inf_loss.backward()
 
-        dmoe = hivemind.RemoteMixtureOfExperts(
+        dmoe = RemoteMixtureOfExperts(
             in_features=16, grid_size=(4,), dht=dht, k_best=4, uid_prefix="expert.", detect_anomalies=True
         )
         output = dmoe(input)

+ 385 - 0
tests/test_optimizer.py

@@ -0,0 +1,385 @@
+import ctypes
+import multiprocessing as mp
+import time
+from functools import partial
+
+import numpy as np
+import pytest
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import hivemind
+from hivemind.averaging.control import AveragingStage
+from hivemind.optim.grad_averager import GradientAverager
+from hivemind.optim.optimizer import Optimizer
+from hivemind.optim.progress_tracker import ProgressTracker
+from hivemind.optim.state_averager import TrainingStateAverager
+from hivemind.utils.crypto import RSAPrivateKey
+
+
+@pytest.mark.forked
+def test_grad_averager():
+    dht1 = hivemind.DHT(start=True)
+    model1 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(3))})
+    averager1 = GradientAverager(
+        model1.parameters(), dht=dht1, prefix="test", target_group_size=2, reuse_grad_buffers=False, start=True
+    )
+
+    dht2 = hivemind.DHT(start=True, initial_peers=dht1.get_visible_maddrs())
+    model2 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(3))})
+    averager2 = GradientAverager(
+        model2.parameters(), dht=dht2, prefix="test", target_group_size=2, reuse_grad_buffers=True, start=True
+    )
+
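+    # schedule the all-reduce a few seconds ahead of time; the controls stay in AWAITING_TRIGGER until step() is called with them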
+    control1 = averager1.schedule_step(hivemind.get_dht_time() + 5)
+    control2 = averager2.schedule_step(hivemind.get_dht_time() + 5)
+
+    for i in range(10):
+        time.sleep(0.1)
+        if i % 3 == 0:
+            loss1 = F.mse_loss(model1.w, torch.ones(3))
+            loss1.backward()
+            averager1.accumulate_grads_(batch_size=2)  # total: 4 times * 2 samples = 8
+            model1.zero_grad()
+        else:
+            loss2 = F.mse_loss(model2.w, -torch.ones(3))
+            loss2.backward()
+            averager2.accumulate_grads_(batch_size=3)  # total: 6 times * 3 samples = 18
+            # note: we do not call zero_grad here because reuse_grad_buffers=True accumulates gradients in-place in .grad
+
+    assert control1.stage == control2.stage == AveragingStage.AWAITING_TRIGGER
+    peer1_samples, peer1_times, peer2_samples, peer2_times = 8, 4, 18, 6
+    assert averager1.local_samples_accumulated == peer1_samples and averager1.local_times_accumulated == peer1_times
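+    # at w = 0, each backward pass of MSE(w, target) contributes a per-element gradient of -(2/3) * target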
+    ref_grads1 = torch.full((3,), -2 * 1 / 3 * averager1.local_times_accumulated)
+    assert torch.allclose(next(averager1._grad_accumulators()), ref_grads1)
+
+    assert averager2.local_samples_accumulated == peer2_samples and averager2.local_times_accumulated == peer2_times
+    ref_grads2 = torch.full((3,), 2 * 1 / 3 * averager2.local_times_accumulated)
+    assert torch.allclose(next(averager2._grad_accumulators()), ref_grads2)
+
+    averager1.step(control=control1, wait=False)
+    averager2.step(control=control2, wait=False)
+    for step in (control1, control2):
+        step.result()  # wait for all-reduce to finish
+
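+    # the all-reduce result is each peer's per-step mean gradient, weighted by that peer's share of accumulated samples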
+    peer1_weight = peer1_samples / (peer1_samples + peer2_samples)
+    peer2_weight = peer2_samples / (peer1_samples + peer2_samples)
+    ref_average = peer1_weight * (ref_grads1 / peer1_times) + peer2_weight * (ref_grads2 / peer2_times)
+    with averager1.use_averaged_gradients():
+        assert torch.allclose(model1.w.grad, ref_average)
+    with averager2.use_averaged_gradients():
+        assert torch.allclose(model2.w.grad, ref_average)
+
+    # outside the use_averaged_gradients() context, the original (non-averaged) local gradients are restored
+    assert not torch.allclose(model1.w.grad, ref_average)
+    assert not torch.allclose(model2.w.grad, ref_average)
+
+
+@pytest.mark.forked
+@pytest.mark.parametrize(
+    "offload_optimizer, reuse_tensors, sync_epoch_when_averaging",
+    [(False, False, False), (True, True, False), (True, False, False), (False, True, True), (True, False, True)],
+)
+def test_state_averager(offload_optimizer: bool, reuse_tensors: bool, sync_epoch_when_averaging: bool):
+    dht1 = hivemind.DHT(start=True)
+    dht2 = hivemind.DHT(initial_peers=dht1.get_visible_maddrs(), start=True)
+
+    torch.manual_seed(1337)
+    torch.use_deterministic_algorithms(True)
+    # note: use_deterministic_algorithms does not affect further tests because this test is forked
+
+    model1 = nn.Linear(2, 3)
+    model2 = nn.Linear(2, 3)
+
+    extras1 = (torch.randn(2, 2), -torch.rand(1))
+    extras2 = (-torch.randn(2, 2), torch.rand(1))
+
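+    # extra_tensors are averaged together with the parameters and optimizer statistics (checked below)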
+    common_kwargs = dict(
+        optimizer=partial(torch.optim.Adam, lr=0.1, betas=(0.9, 0.9)),
+        scheduler=partial(torch.optim.lr_scheduler.LambdaLR, lr_lambda=lambda t: 1.0 / max(1, t)),
+        sync_epoch_when_averaging=sync_epoch_when_averaging,
+        average_opt_statistics=("exp_avg_sq",),
+        offload_optimizer=offload_optimizer,
+        reuse_tensors=reuse_tensors,
+        target_group_size=2,
+        prefix="my_exp",
+    )
+
+    avgr1 = TrainingStateAverager(
+        dht=dht1, params=model1.parameters(), extra_tensors=extras1, start=True, **common_kwargs
+    )
+    avgr2 = TrainingStateAverager(
+        dht=dht2, params=model2.parameters(), extra_tensors=extras2, start=True, **common_kwargs
+    )
+
+    x = torch.ones(2)
+
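+    # both peers run 20 local steps; at step 10 each also runs an averaging round (avgr1 delayed, avgr2 inline)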
+    for step in range(20):
+        F.mse_loss(model1(x), torch.ones(3)).mul(2).backward()
+        avgr1.step(optimizer_step=True, zero_grad=True, averaging_round=(step == 10), delay_averaging=True)
+
+        F.mse_loss(model2(x), -torch.ones(3)).backward()
+        avgr2.step(optimizer_step=True, zero_grad=True, averaging_round=(step == 10), delay_averaging=False)
+
+    assert torch.all(model1.weight.grad == 0) and torch.all(model2.weight.grad == 0), "zero grad did not trigger"
+    assert model1(x).mean() > 0.5 and model2(x).mean() < -0.5, "models did not train properly"
+    assert torch.allclose(extras1[0], extras2[0]), "first extra tensors were not averaged"
+    assert torch.allclose(extras1[1], extras2[1]), "second extra tensors were not averaged"
+
+    stats1 = avgr1.optimizer.state_dict()["state"][0]["exp_avg_sq"].clone()
+    stats2 = avgr2.optimizer.state_dict()["state"][0]["exp_avg_sq"].clone()
+    assert not torch.allclose(stats1, stats2)
+
+    avgr1.step(increment_epoch=True)
+
+    avgr1.step(increment_epoch=True, averaging_round=True, delay_averaging=True)
+    avgr2.step(increment_epoch=True, averaging_round=True, delay_averaging=True)
+
+    avgr1.step(wait_for_delayed_updates=True)
+    avgr2.step(wait_for_delayed_updates=True)
+
+    assert torch.allclose(model1(x), model2(x)), "model parameters were not averaged correctly"
+    assert torch.allclose(avgr1.optimizer.state_dict()["state"][0]["exp_avg_sq"], (stats1 + stats2) / 2)
+    assert torch.allclose(avgr2.optimizer.state_dict()["state"][0]["exp_avg_sq"], (stats1 + stats2) / 2)
+    assert avgr1.local_epoch == 2
+    assert avgr2.local_epoch == (2 if sync_epoch_when_averaging else 1)
+
+
+@pytest.mark.forked
+def test_load_state_from_peers():
+    dht1 = hivemind.DHT(start=True)
+    dht2 = hivemind.DHT(initial_peers=dht1.get_visible_maddrs(), start=True)
+
+    model1 = nn.Linear(2, 3)
+    model2 = nn.Linear(2, 3)
+
+    common_kwargs = dict(
+        optimizer=partial(torch.optim.SGD, lr=0.1),
+        scheduler=partial(torch.optim.lr_scheduler.LambdaLR, lr_lambda=lambda t: 1.0 / max(1, t)),
+        target_group_size=2,
+        prefix="my_exp",
+    )
+
+    avgr1 = TrainingStateAverager(
+        dht=dht1, params=model1.parameters(), allow_state_sharing=False, start=True, **common_kwargs
+    )
+
+    avgr2 = TrainingStateAverager(dht=dht2, params=model2.parameters(), start=True, **common_kwargs)
+
+    avgr2.local_epoch = 1337
+    model2.weight.data[...] = 42
+    time.sleep(0.1)
+
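+    # avgr1 never shares its own state (allow_state_sharing=False), so it should fetch avgr2's newer state:
+    # epoch 1337, weights of 42, and the learning rate adjusted by the scheduler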
+    avgr1.load_state_from_peers()
+    assert avgr1.local_epoch == 1337
+    assert torch.all(model1.weight == 42).item()
+    assert np.allclose(avgr1.optimizer.param_groups[0]["lr"], 0.1 / 1337)
+
+
+@pytest.mark.forked
+def test_progress_tracker():
+    # note to a curious reader: no, you cannot reduce the timings without compromising realism or stability
+    prefix = "my_exp"
+    target_batch_size = 256
+    dht_root = hivemind.DHT(start=True)
+    barrier = mp.Barrier(parties=5)
+    delayed_start_evt = mp.Event()
+    finished_evt = mp.Event()
+    emas = mp.Array(ctypes.c_double, 5)
+
+    def run_worker(index: int, batch_size: int, period: float, **kwargs):
+        dht = hivemind.DHT(initial_peers=dht_root.get_visible_maddrs(), start=True)
+        tracker = ProgressTracker(
+            dht,
+            prefix,
+            target_batch_size,
+            start=True,
+            min_refresh_period=0.1,
+            default_refresh_period=0.2,
+            max_refresh_period=0.5,
+            private_key=RSAPrivateKey(),
+            **kwargs,
+        )
+
+        barrier.wait()
+        if index == 4:
+            delayed_start_evt.wait()
+
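+        # worker 4 joins the swarm later (after delayed_start_evt) and starts two epochs ahead of the other peers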
+        local_epoch = 2 if index == 4 else 0
+        samples_accumulated = 0
+
+        while True:
+            time.sleep(period)
+            if finished_evt.is_set():
+                break
+
+            samples_accumulated += batch_size
+            tracker.report_local_progress(local_epoch, samples_accumulated)
+
+            if tracker.ready_to_update_epoch:
+                if index == 4 and local_epoch >= 4:
+                    time.sleep(0.5)
+                    break
+
+                with tracker.pause_updates():
+                    local_epoch = tracker.update_epoch(local_epoch + 1)
+                    samples_accumulated = 0
+
+        emas[index] = tracker.performance_ema.samples_per_second
+        tracker.shutdown()
+        dht.shutdown()
+
+    workers = [
+        mp.Process(target=run_worker, kwargs=dict(index=1, batch_size=12, period=0.6)),
+        mp.Process(target=run_worker, kwargs=dict(index=2, batch_size=16, period=0.5)),
+        mp.Process(target=run_worker, kwargs=dict(index=3, batch_size=24, period=0.4)),
+        mp.Process(target=run_worker, kwargs=dict(index=4, batch_size=64, period=0.4)),
+    ]
+    for worker in workers:
+        worker.start()
+
+    tracker = ProgressTracker(
+        dht_root,
+        prefix,
+        target_batch_size,
+        start=True,
+        min_refresh_period=0.1,
+        default_refresh_period=0.2,
+        max_refresh_period=0.5,
+    )
+    barrier.wait()
+
+    local_epoch = 0
+    last_timestamp = hivemind.get_dht_time()
+    step_time_deltas = []
+
+    while local_epoch < 6:
+        time.sleep(0.1)
+
+        if tracker.ready_to_update_epoch:
+            with tracker.pause_updates():
+                local_epoch = tracker.update_epoch(local_epoch + 1)
+
+            time_delta = hivemind.get_dht_time() - last_timestamp
+            if local_epoch == 2:
+                delayed_start_evt.set()
+
+            last_timestamp = hivemind.get_dht_time()
+            step_time_deltas.append(time_delta)
+
+    finished_evt.set()
+    for worker in workers:
+        worker.join()
+
+    tracker.shutdown()
+    dht_root.shutdown()
+    assert not tracker.is_alive()
+
+    mean_step_time = sum(step_time_deltas) / len(step_time_deltas)
+    for i in (0, 1, 5):  # Without the 4th worker (the fastest one)
+        assert 1.05 * mean_step_time < step_time_deltas[i] < 2.0 * mean_step_time
+    for i in (2, 3, 4):  # With the 4th worker
+        assert 0.5 * mean_step_time < step_time_deltas[i] < 0.95 * mean_step_time
+    assert emas[1] < emas[2] < emas[3] < emas[4]
+    assert tracker.performance_ema.samples_per_second < 1e-9
+
+
+@pytest.mark.forked
+def test_optimizer(
+    num_peers: int = 1,
+    num_clients: int = 0,
+    target_batch_size: int = 32,
+    total_epochs: int = 3,
+    reuse_grad_buffers: bool = True,
+    delay_grad_averaging: bool = True,
+    delay_optimizer_step: bool = True,
+    average_state_every: int = 1,
+):
+    dht = hivemind.DHT(start=True)
+
+    features = torch.randn(100, 5)
+    targets = features @ torch.randn(5, 1)
+    optimizer = None
+    total_samples_accumulated = mp.Value(ctypes.c_int32, 0)
+
+    def run_trainer(batch_size: int, batch_time: float, client_mode: bool):
+        nonlocal optimizer
+        model = nn.Linear(5, 1)
+
+        assert isinstance(model, torch.nn.Module), "model must be a torch.nn.Module"
+
+        optimizer = Optimizer(
+            run_id="test_run",
+            target_batch_size=target_batch_size,
+            batch_size_per_step=batch_size,
+            params=model.parameters(),
+            optimizer=partial(torch.optim.SGD, lr=0.1),
+            scheduler=partial(torch.optim.lr_scheduler.StepLR, gamma=0.5, step_size=1),
+            dht=hivemind.DHT(initial_peers=dht.get_visible_maddrs(), client_mode=client_mode, start=True),
+            tracker_opts=dict(private_key=RSAPrivateKey(), max_refresh_period=1.0),
+            averager_opts=dict(request_timeout=0.5),
+            matchmaking_time=1.0,
+            averaging_timeout=5.0,
+            reuse_grad_buffers=reuse_grad_buffers,
+            delay_grad_averaging=delay_grad_averaging,
+            delay_optimizer_step=delay_optimizer_step,
+            average_state_every=average_state_every,
+            client_mode=client_mode,
+            verbose=False,
+        )
+        optimizer.load_state_from_peers()
+
+        prev_time = time.perf_counter()
+
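+        # pace iterations to roughly batch_time seconds each, emulating peers that process batches at different speeds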
+        while optimizer.local_epoch < total_epochs:
+            time.sleep(max(0.0, prev_time + batch_time - time.perf_counter()))
+            batch = torch.randint(0, len(features), (batch_size,))
+
+            loss = F.mse_loss(model(features[batch]), targets[batch])
+            loss.backward()
+
+            optimizer.step()
+
+            total_samples_accumulated.value += batch_size
+
+            if not reuse_grad_buffers:
+                optimizer.zero_grad()
+
+            prev_time = time.perf_counter()
+
+        time.sleep(1.0)
+        optimizer.shutdown()
+        return optimizer
+
+    peers = []
+
+    for index in range(num_peers):
+        peers.append(
+            mp.Process(
+                target=run_trainer,
+                name=f"trainer-{index}",
+                kwargs=dict(
+                    batch_size=4 + index,
+                    batch_time=0.3 + 0.2 * index,
+                    client_mode=(index >= num_peers - num_clients),
+                ),
+            )
+        )
+
+    for peer in peers[1:]:
+        peer.start()
+    peers[0].run()
+    for peer in peers[1:]:
+        peer.join()
+
+    assert isinstance(optimizer, Optimizer)
+    assert optimizer.local_epoch == optimizer.tracker.global_epoch == total_epochs
+    expected_samples_accumulated = target_batch_size * total_epochs
+    assert expected_samples_accumulated <= total_samples_accumulated.value <= expected_samples_accumulated * 1.2
+    assert 4 / 0.3 * 0.8 <= optimizer.tracker.performance_ema.samples_per_second <= 4 / 0.3 * 1.2
+
+    assert not optimizer.state_averager.is_alive()
+    assert not optimizer.grad_averager.is_alive()
+    assert not optimizer.tracker.is_alive()
+    assert optimizer.scheduled_grads is None or optimizer.scheduled_grads.done()

+ 3 - 3
tests/test_p2p_daemon.py

@@ -89,7 +89,7 @@ async def test_unary_handler_edge_cases():
     p2p_replica = await P2P.replicate(p2p.daemon_listen_maddr)
 
     async def square_handler(data: test_pb2.TestRequest, context):
-        return test_pb2.TestResponse(number=data.number ** 2)
+        return test_pb2.TestResponse(number=data.number**2)
 
     await p2p.add_protobuf_handler("square", square_handler, test_pb2.TestRequest)
 
@@ -202,7 +202,7 @@ async def handle_square_stream(_, reader: asyncio.StreamReader, writer: asyncio.
             except asyncio.IncompleteReadError:
                 break
 
-            result = x ** 2
+            result = x**2
 
             await P2P.send_raw_data(MSGPackSerializer.dumps(result), writer)
 
@@ -215,7 +215,7 @@ async def validate_square_stream(reader: asyncio.StreamReader, writer: asyncio.S
             await P2P.send_raw_data(MSGPackSerializer.dumps(x), writer)
             result = MSGPackSerializer.loads(await P2P.receive_raw_data(reader))
 
-            assert result == x ** 2
+            assert result == x**2
 
 
 @pytest.mark.asyncio

+ 7 - 7
tests/test_p2p_daemon_bindings.py

@@ -38,15 +38,15 @@ PAIRS_INT_SERIALIZED_VALID = (
     (0, b"\x00"),
     (1, b"\x01"),
     (128, b"\x80\x01"),
-    (2 ** 32, b"\x80\x80\x80\x80\x10"),
-    (2 ** 64 - 1, b"\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01"),
+    (2**32, b"\x80\x80\x80\x80\x10"),
+    (2**64 - 1, b"\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01"),
 )
 
 PAIRS_INT_SERIALIZED_OVERFLOW = (
-    (2 ** 64, b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x02"),
-    (2 ** 64 + 1, b"\x81\x80\x80\x80\x80\x80\x80\x80\x80\x02"),
+    (2**64, b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x02"),
+    (2**64 + 1, b"\x81\x80\x80\x80\x80\x80\x80\x80\x80\x02"),
     (
-        2 ** 128,
+        2**128,
         b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x04",
     ),
 )
@@ -94,7 +94,7 @@ async def test_write_unsigned_varint_overflow(integer):
         await write_unsigned_varint(s, integer)
 
 
-@pytest.mark.parametrize("integer", (-1, -(2 ** 32), -(2 ** 64), -(2 ** 128)))
+@pytest.mark.parametrize("integer", (-1, -(2**32), -(2**64), -(2**128)))
 @pytest.mark.asyncio
 async def test_write_unsigned_varint_negative(integer):
     s = MockWriter()
@@ -125,7 +125,7 @@ async def test_read_write_unsigned_varint_max_bits_edge(max_bits):
     Test edge cases with different `max_bits`
     """
     for i in range(-3, 0):
-        integer = i + (2 ** max_bits)
+        integer = i + 2**max_bits
         s = MockReaderWriter()
         await write_unsigned_varint(s, integer, max_bits=max_bits)
         s.seek(0, 0)

+ 8 - 6
tests/test_p2p_servicer.py

@@ -21,7 +21,7 @@ async def server_client():
 async def test_unary_unary(server_client):
     class ExampleServicer(ServicerBase):
         async def rpc_square(self, request: test_pb2.TestRequest, _context: P2PContext) -> test_pb2.TestResponse:
-            return test_pb2.TestResponse(number=request.number ** 2)
+            return test_pb2.TestResponse(number=request.number**2)
 
     server, client = server_client
     servicer = ExampleServicer()
@@ -68,8 +68,9 @@ async def test_unary_stream(server_client):
     await servicer.add_p2p_handlers(server)
     stub = ExampleServicer.get_stub(client, server.peer_id)
 
+    stream = await stub.rpc_count(test_pb2.TestRequest(number=10))
     i = 0
-    async for item in stub.rpc_count(test_pb2.TestRequest(number=10)):
+    async for item in stream:
         assert item == test_pb2.TestResponse(number=i)
         i += 1
     assert i == 10
@@ -82,8 +83,8 @@ async def test_stream_stream(server_client):
             self, stream: AsyncIterator[test_pb2.TestRequest], _context: P2PContext
         ) -> AsyncIterator[test_pb2.TestResponse]:
             async for item in stream:
-                yield test_pb2.TestResponse(number=item.number ** 2)
-                yield test_pb2.TestResponse(number=item.number ** 3)
+                yield test_pb2.TestResponse(number=item.number**2)
+                yield test_pb2.TestResponse(number=item.number**3)
 
     server, client = server_client
     servicer = ExampleServicer()
@@ -94,8 +95,9 @@ async def test_stream_stream(server_client):
         for i in range(10):
             yield test_pb2.TestRequest(number=i)
 
+    stream = await stub.rpc_powers(generate_requests())
     i = 0
-    async for item in stub.rpc_powers(generate_requests()):
+    async for item in stream:
         if i % 2 == 0:
             assert item == test_pb2.TestResponse(number=(i // 2) ** 2)
         else:
@@ -140,7 +142,7 @@ async def test_unary_stream_cancel(server_client, cancel_reason):
         writer.close()
     elif cancel_reason == "close_generator":
         stub = ExampleServicer.get_stub(client, server.peer_id)
-        iter = stub.rpc_wait(test_pb2.TestRequest(number=10))
+        iter = await stub.rpc_wait(test_pb2.TestRequest(number=10))
 
         assert await anext(iter) == test_pb2.TestResponse(number=11)
         await asyncio.sleep(0.25)

+ 81 - 4
tests/test_util_modules.py

@@ -3,6 +3,7 @@ import concurrent.futures
 import multiprocessing as mp
 import random
 import time
+from concurrent.futures import ThreadPoolExecutor
 
 import numpy as np
 import pytest
@@ -13,7 +14,7 @@ from hivemind.compression import deserialize_torch_tensor, serialize_torch_tenso
 from hivemind.proto.dht_pb2_grpc import DHTStub
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
-from hivemind.utils import DHTExpiration, HeapEntry, MSGPackSerializer, ValueWithExpiration
+from hivemind.utils import BatchTensorDescriptor, DHTExpiration, HeapEntry, MSGPackSerializer, ValueWithExpiration
 from hivemind.utils.asyncio import (
     achain,
     aenumerate,
@@ -23,10 +24,13 @@ from hivemind.utils.asyncio import (
     anext,
     as_aiter,
     asingle,
+    attach_event_on_finished,
     azip,
     cancel_and_wait,
+    enter_asynchronously,
 )
 from hivemind.utils.mpfuture import InvalidStateError
+from hivemind.utils.performance_ema import PerformanceEMA
 
 
 @pytest.mark.forked
@@ -309,6 +313,8 @@ def test_many_futures():
     p.start()
 
     some_fork_futures = receiver.recv()
+
+    time.sleep(0.1)  # giving enough time for the futures to be destroyed
     assert len(hivemind.MPFuture._active_futures) == 700
 
     for future in some_fork_futures:
@@ -319,6 +325,7 @@ def test_many_futures():
     evt.set()
     for future in main_futures:
         future.cancel()
+    time.sleep(0.1)  # giving enough time for the futures to be destroyed
     assert len(hivemind.MPFuture._active_futures) == 0
     p.join()
 
@@ -390,7 +397,7 @@ def test_split_parts():
     chunks2 = list(hivemind.utils.split_for_streaming(serialized_tensor_part, 10_000))
     assert len(chunks2) == int(np.ceil(tensor.numel() * tensor.element_size() / 10_000))
 
-    chunks3 = list(hivemind.utils.split_for_streaming(serialized_tensor_part, 10 ** 9))
+    chunks3 = list(hivemind.utils.split_for_streaming(serialized_tensor_part, 10**9))
     assert len(chunks3) == 1
 
     compressed_tensor_part = serialize_torch_tensor(tensor, CompressionType.FLOAT16, allow_inplace=False)
@@ -433,8 +440,8 @@ async def test_asyncio_utils():
     assert res == list(range(len(res)))
 
     num_steps = 0
-    async for elem in amap_in_executor(lambda x: x ** 2, as_aiter(*range(100)), max_prefetch=5):
-        assert elem == num_steps ** 2
+    async for elem in amap_in_executor(lambda x: x**2, as_aiter(*range(100)), max_prefetch=5):
+        assert elem == num_steps**2
         num_steps += 1
     assert num_steps == 100
 
@@ -490,6 +497,18 @@ async def test_asyncio_utils():
 
     assert num_steps == 2
 
+    event = asyncio.Event()
+    async for i in attach_event_on_finished(iterate_with_delays([0, 0, 0, 0, 0]), event):
+        assert not event.is_set()
+    assert event.is_set()
+
+    event = asyncio.Event()
+    sleepy_aiter = iterate_with_delays([0.1, 0.1, 0.3, 0.1, 0.1])
+    with pytest.raises(asyncio.TimeoutError):
+        async for _ in attach_event_on_finished(aiter_with_timeout(sleepy_aiter, timeout=0.2), event):
+            assert not event.is_set()
+    assert event.is_set()
+
 
 @pytest.mark.asyncio
 async def test_cancel_and_wait():
@@ -521,3 +540,61 @@ async def test_cancel_and_wait():
     await asyncio.sleep(0.05)
     assert not await cancel_and_wait(task_with_result)
     assert not await cancel_and_wait(task_with_error)
+
+
+@pytest.mark.asyncio
+async def test_async_context():
+    lock = mp.Lock()
+
+    async def coro1():
+        async with enter_asynchronously(lock):
+            await asyncio.sleep(0.2)
+
+    async def coro2():
+        await asyncio.sleep(0.1)
+        async with enter_asynchronously(lock):
+            await asyncio.sleep(0.1)
+
+    await asyncio.wait_for(asyncio.gather(coro1(), coro2()), timeout=0.5)
+    # running this without enter_asynchronously would deadlock the event loop
+
+
+def test_batch_tensor_descriptor_msgpack():
+    tensor_descr = BatchTensorDescriptor.from_tensor(torch.ones(1, 3, 3, 7))
+    tensor_descr_roundtrip = MSGPackSerializer.loads(MSGPackSerializer.dumps(tensor_descr))
+
+    assert (
+        tensor_descr.size == tensor_descr_roundtrip.size
+        and tensor_descr.dtype == tensor_descr_roundtrip.dtype
+        and tensor_descr.layout == tensor_descr_roundtrip.layout
+        and tensor_descr.device == tensor_descr_roundtrip.device
+        and tensor_descr.requires_grad == tensor_descr_roundtrip.requires_grad
+        and tensor_descr.pin_memory == tensor_descr_roundtrip.pin_memory
+        and tensor_descr.compression == tensor_descr_roundtrip.compression
+    )
+
+
+@pytest.mark.parametrize("max_workers", [1, 2, 10])
+def test_performance_ema_threadsafe(
+    max_workers: int,
+    interval: float = 0.01,
+    num_updates: int = 100,
+    alpha: float = 0.05,
+    bias_power: float = 0.7,
+    tolerance: float = 0.05,
+):
+    def run_task(ema):
+        task_size = random.randint(1, 4)
+        with ema.update_threadsafe(task_size):
+            time.sleep(task_size * interval * (0.9 + 0.2 * random.random()))
+            return task_size
+
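+    # with several workers updating concurrently, the EMA may undercount throughput;
+    # the lower bound below is therefore relaxed by max_workers ** (bias_power - 1)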
+    with ThreadPoolExecutor(max_workers) as pool:
+        ema = PerformanceEMA(alpha=alpha)
+        start_time = time.perf_counter()
+        futures = [pool.submit(run_task, ema) for i in range(num_updates)]
+        total_size = sum(future.result() for future in futures)
+        end_time = time.perf_counter()
+        target = total_size / (end_time - start_time)
+        assert ema.samples_per_second >= (1 - tolerance) * target * max_workers ** (bias_power - 1)
+        assert ema.samples_per_second <= (1 + tolerance) * target