Browse source code

Merge branch 'master' into unary-handlers

Denis Mazur 4 years ago
parent
commit
b058e6e6f8
84 changed files with 1043 additions and 985 deletions
  1. .github/workflows/check-style.yml (+21 -0)
  2. .github/workflows/check_style.yml (+0 -13)
  3. .github/workflows/run-tests.yml (+3 -3)
  4. CONTRIBUTING.md (+3 -2)
  5. README.md (+32 -27)
  6. benchmarks/benchmark_averaging.py (+2 -3)
  7. benchmarks/benchmark_tensor_compression.py (+1 -2)
  8. benchmarks/benchmark_throughput.py (+2 -2)
  9. docs/conf.py (+1 -2)
  10. docs/user/dht.md (+1 -1)
  11. examples/albert/arguments.py (+1 -5)
  12. examples/albert/run_trainer.py (+3 -3)
  13. examples/albert/run_training_monitor.py (+2 -2)
  14. examples/albert/utils.py (+0 -1)
  15. hivemind/__init__.py (+5 -5)
  16. hivemind/averaging/allreduce.py (+88 -82)
  17. hivemind/averaging/averager.py (+46 -99)
  18. hivemind/averaging/group_info.py (+6 -6)
  19. hivemind/averaging/key_manager.py (+25 -20)
  20. hivemind/averaging/load_balancing.py (+2 -1)
  21. hivemind/averaging/matchmaking.py (+91 -82)
  22. hivemind/averaging/partition.py (+6 -7)
  23. hivemind/averaging/training.py (+3 -3)
  24. hivemind/dht/__init__.py (+55 -9)
  25. hivemind/dht/crypto.py (+0 -1)
  26. hivemind/dht/node.py (+10 -6)
  27. hivemind/dht/protocol.py (+8 -8)
  28. hivemind/dht/routing.py (+2 -1)
  29. hivemind/dht/storage.py (+1 -1)
  30. hivemind/dht/traverse.py (+1 -1)
  31. hivemind/hivemind_cli/run_server.py (+2 -2)
  32. hivemind/moe/__init__.py (+1 -1)
  33. hivemind/moe/client/beam_search.py (+11 -11)
  34. hivemind/moe/client/expert.py (+3 -3)
  35. hivemind/moe/client/moe.py (+6 -6)
  36. hivemind/moe/client/switch_moe.py (+3 -3)
  37. hivemind/moe/server/__init__.py (+12 -7)
  38. hivemind/moe/server/connection_handler.py (+3 -3)
  39. hivemind/moe/server/dht_handler.py (+7 -7)
  40. hivemind/moe/server/expert_backend.py (+3 -3)
  41. hivemind/moe/server/expert_uid.py (+1 -1)
  42. hivemind/moe/server/layers/custom_experts.py (+1 -1)
  43. hivemind/moe/server/runtime.py (+1 -1)
  44. hivemind/moe/server/task_pool.py (+3 -3)
  45. hivemind/optim/__init__.py (+1 -1)
  46. hivemind/optim/adaptive.py (+1 -1)
  47. hivemind/optim/collaborative.py (+4 -4)
  48. hivemind/optim/simple.py (+3 -3)
  49. hivemind/p2p/__init__.py (+1 -1)
  50. hivemind/p2p/p2p_daemon.py (+59 -88)
  51. hivemind/p2p/servicer.py (+54 -27)
  52. hivemind/proto/averaging.proto (+8 -14)
  53. hivemind/proto/dht.proto (+2 -2)
  54. hivemind/utils/__init__.py (+3 -3)
  55. hivemind/utils/asyncio.py (+15 -4)
  56. hivemind/utils/compression.py (+2 -3)
  57. hivemind/utils/grpc.py (+2 -2)
  58. hivemind/utils/mpfuture.py (+137 -161)
  59. hivemind/utils/networking.py (+1 -2)
  60. hivemind/utils/serializer.py (+1 -1)
  61. hivemind/utils/tensor_descr.py (+1 -1)
  62. hivemind/utils/timed_storage.py (+2 -1)
  63. pyproject.toml (+7 -0)
  64. requirements-dev.txt (+1 -1)
  65. tests/conftest.py (+22 -2)
  66. tests/test_allreduce.py (+28 -47)
  67. tests/test_auth.py (+1 -2)
  68. tests/test_averaging.py (+89 -82)
  69. tests/test_dht.py (+4 -4)
  70. tests/test_dht_crypto.py (+2 -2)
  71. tests/test_dht_experts.py (+2 -2)
  72. tests/test_dht_node.py (+24 -11)
  73. tests/test_dht_storage.py (+2 -2)
  74. tests/test_dht_validation.py (+1 -1)
  75. tests/test_expert_backend.py (+1 -1)
  76. tests/test_moe.py (+1 -2)
  77. tests/test_p2p_daemon.py (+25 -13)
  78. tests/test_p2p_daemon_bindings.py (+2 -1)
  79. tests/test_p2p_servicer.py (+17 -13)
  80. tests/test_routing.py (+2 -2)
  81. tests/test_training.py (+3 -2)
  82. tests/test_util_modules.py (+8 -11)
  83. tests/test_utils/dht_swarms.py (+21 -9)
  84. tests/test_utils/p2p_daemon.py (+5 -6)

+ 21 - 0
.github/workflows/check-style.yml

@@ -0,0 +1,21 @@
+name: Check style
+
+on: [ push ]
+
+jobs:
+  black:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v2
+      - uses: psf/black@stable
+        with:
+          options: "--check --diff"
+          version: "21.6b0"
+  isort:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v2
+      - uses: actions/setup-python@v2
+        with:
+          python-version: 3.8
+      - uses: isort/isort-action@master

+ 0 - 13
.github/workflows/check_style.yml

@@ -1,13 +0,0 @@
-name: Check style
-
-on: [ push ]
-
-jobs:
-  black:
-    runs-on: ubuntu-latest
-    steps:
-      - uses: actions/checkout@v2
-      - uses: psf/black@stable
-        with:
-          options: "--check"
-          version: "21.6b0"

+ 3 - 3
.github/workflows/run-tests.yml

@@ -33,7 +33,7 @@ jobs:
      - name: Test
        run: |
          cd tests
-          pytest --durations=0 --durations-min=1.0
+          pytest --durations=0 --durations-min=1.0 -v

  build_and_test_p2pd:
    runs-on: ubuntu-latest
@@ -60,7 +60,7 @@ jobs:
      - name: Test
        run: |
          cd tests
-          pytest -k "p2p"
+          pytest -k "p2p" -v

  codecov_in_develop_mode:

@@ -87,6 +87,6 @@ jobs:
          pip install -e .
      - name: Test
        run: |
-          pytest --cov=hivemind tests
+          pytest --cov=hivemind -v tests
      - name: Upload coverage to Codecov
        uses: codecov/codecov-action@v1

+ 3 - 2
CONTRIBUTING.md

@@ -34,10 +34,11 @@ with the following rules:

## Code style

-* We use [black](https://github.com/psf/black) for code formatting. Before submitting a PR, make sure to install and
-  run `black .` in the root of the repository.
* The code must follow [PEP8](https://www.python.org/dev/peps/pep-0008/) unless absolutely necessary. Also, each line
  cannot be longer than 119 characters.
+* We use [black](https://github.com/psf/black) for code formatting and [isort](https://github.com/PyCQA/isort) for
+  import sorting. Before submitting a PR, make sure to install and run `black .` and `isort .` in the root of the
+  repository.
* We highly encourage the use of [typing](https://docs.python.org/3/library/typing.html) where applicable.
* Use `get_logger` from `hivemind.utils.logging` to log any information instead of `print`ing directly to standard
  output/error streams.

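To make the logging rule above concrete, here is a minimal sketch; the function and message are made up for illustration, only `get_logger` itself comes from hivemind:

```python
from hivemind.utils.logging import get_logger

# one module-level logger per file, as in the rest of the hivemind codebase
logger = get_logger(__name__)


def run_example_step(step: int) -> None:
    # prefer logger calls over print() so output respects the configured log level
    logger.info(f"finished step {step}")
```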
+ 32 - 27
README.md

@@ -2,7 +2,7 @@

[![Documentation Status](https://readthedocs.org/projects/learning-at-home/badge/?version=latest)](https://learning-at-home.readthedocs.io/en/latest/?badge=latest)
[![PyPI version](https://img.shields.io/pypi/v/hivemind.svg)](https://pypi.org/project/hivemind/)
-[![Discord](https://img.shields.io/static/v1?style=default&label=Discord&logo=discord&message=join)](https://discord.gg/xC7ucM8j)
+[![Discord](https://img.shields.io/static/v1?style=default&label=Discord&logo=discord&message=join)](https://discord.gg/uGugx9zYvN)
[![CI status](https://github.com/learning-at-home/hivemind/actions/workflows/run-tests.yml/badge.svg?branch=master)](https://github.com/learning-at-home/hivemind/actions)
![Codecov](https://img.shields.io/codecov/c/github/learning-at-home/hivemind)
[![Black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
@@ -18,21 +18,21 @@ large model on hundreds of computers from different universities, companies, and
  network.
* Fault-tolerant backpropagation: forward and backward passes succeed even if some nodes are unresponsive or take too
  long to respond.
-* Decentralized parameter averaging: iteratively aggregate updates from multiple
-  workers without the need to synchronize across the entire network ([paper](https://arxiv.org/abs/2103.03239)).
+* Decentralized parameter averaging: iteratively aggregate updates from multiple workers without the need to
+  synchronize across the entire network ([paper](https://arxiv.org/abs/2103.03239)).
* Train neural networks of arbitrary size: parts of their layers are distributed across the participants with the
-  decentralized mixture-of-experts ([paper](https://arxiv.org/abs/2002.04013)).
+  Decentralized Mixture-of-Experts ([paper](https://arxiv.org/abs/2002.04013)).

To learn more about the ideas behind this library, see https://learning-at-home.github.io or read
the [NeurIPS 2020 paper](https://arxiv.org/abs/2002.04013).

## Installation

-Before installing, make sure that your environment has Python 3.7+
-and [PyTorch](https://pytorch.org/get-started/locally/#start-locally) 1.6.0 or newer.
-You can install them either natively or with [Anaconda](https://www.anaconda.com/products/individual).
+Before installing, make sure that your environment has Python 3.7+
+and [PyTorch](https://pytorch.org/get-started/locally/#start-locally) 1.6.0 or newer. They can be installed either
+natively or with [Anaconda](https://www.anaconda.com/products/individual).

-You can install [the latest release](https://pypi.org/project/hivemind) with pip or build hivemind from source.
+You can get [the latest release](https://pypi.org/project/hivemind) with pip or build hivemind from source.

### With pip

@@ -62,24 +62,29 @@ Before running the compilation, please ensure that your machine has a recent ver
of [Go toolchain](https://golang.org/doc/install) (1.15 or higher).

### System requirements
-- __Linux__ is the default OS for which hivemind is developed and tested. We recommend Ubuntu 18.04+ (64-bit),
-  but other 64-bit distros should work as well. Legacy 32-bit is not recommended.
-- __macOS 10.x__ mostly works but requires building hivemind from source, and some edge cases may fail.
-  To ensure full compatibility, we recommend using [our Docker image](https://hub.docker.com/r/learningathome/hivemind).
-- __Windows 10+ (experimental)__ can run hivemind using [WSL](https://docs.microsoft.com/ru-ru/windows/wsl/install-win10).
-  You can configure WSL to use GPU following [this guide](https://docs.nvidia.com/cuda/wsl-user-guide/index.html) by NVIDIA.
-  After the CUDA toolkit is installed you can simply follow the instructions above to install with pip or from source.
+
+- __Linux__ is the default OS for which hivemind is developed and tested. We recommend Ubuntu 18.04+ (64-bit), but
+  other 64-bit distros should work as well. Legacy 32-bit is not recommended.
+- __macOS 10.x__ mostly works but requires building hivemind from source, and some edge cases may fail. To ensure full
+  compatibility, we recommend using [our Docker image](https://hub.docker.com/r/learningathome/hivemind).
+- __Windows 10+ (experimental)__ can run hivemind
+  using [WSL](https://docs.microsoft.com/ru-ru/windows/wsl/install-win10). You can configure WSL to use GPU by
+  following sections 1–3 of [this guide](https://docs.nvidia.com/cuda/wsl-user-guide/index.html) by NVIDIA. After
+  that, you can simply follow the instructions above to install with pip or from source.

## Documentation

-* The [quickstart tutorial](https://learning-at-home.readthedocs.io/en/latest/user/quickstart.html) walks through installation
-  and a training a simple neural network with several peers.
+* The [quickstart tutorial](https://learning-at-home.readthedocs.io/en/latest/user/quickstart.html) walks through
+  installation and a training a simple neural network with several peers.
* [examples/albert](https://github.com/learning-at-home/hivemind/tree/master/examples/albert) contains the starter kit
  and instructions for training a Transformer masked language model collaboratively.
-* API reference and additional tutorials are available at [learning-at-home.readthedocs.io](https://learning-at-home.readthedocs.io)
+* The [Mixture-of-Experts tutorial](https://learning-at-home.readthedocs.io/en/latest/user/moe.html)
+  covers the usage of Decentralized Mixture-of-Experts layers.
+* API reference and additional tutorials are available
+  at [learning-at-home.readthedocs.io](https://learning-at-home.readthedocs.io)

-If you have any questions about installing and using hivemind, you can ask them in
-[our Discord chat](https://discord.gg/xC7ucM8j) or file an [issue](https://github.com/learning-at-home/hivemind/issues).
+If you have any questions about installing and using hivemind, you can ask them in
+[our Discord chat](https://discord.gg/uGugx9zYvN) or file an [issue](https://github.com/learning-at-home/hivemind/issues).

## Contributing

@@ -88,9 +93,8 @@ documentation improvements to entirely new features, is equally appreciated.

If you want to contribute to hivemind but don't know where to start, take a look at the
unresolved [issues](https://github.com/learning-at-home/hivemind/issues). Open a new issue or
-join [our chat room](https://discord.gg/xC7ucM8j) in case you want to discuss new functionality or
-report a possible bug. Bug fixes are always welcome, but new features should be preferably discussed with maintainers
-beforehand.
+join [our chat room](https://discord.gg/xC7ucM8j) in case you want to discuss new functionality or report a possible
+bug. Bug fixes are always welcome, but new features should be preferably discussed with maintainers beforehand.

If you want to start contributing to the source code of hivemind, please see
the [contributing guidelines](https://github.com/learning-at-home/hivemind/blob/master/CONTRIBUTING.md) first. To learn
@@ -99,7 +103,7 @@ our [guide](https://learning-at-home.readthedocs.io/en/latest/user/contributing.

## Citation

-If you found hivemind or its underlying algorithms useful for your research, please cite the relevant papers:
+If you found hivemind or its underlying algorithms useful for your research, please cite the following source:

```
@misc{hivemind,
@@ -111,7 +115,8 @@ If you found hivemind or its underlying algorithms useful for your research, ple
```

Also, you can cite [the paper](https://arxiv.org/abs/2002.04013) that inspired the creation of this library
-(prototype implementation of hivemind available at [mryab/learning-at-home](https://github.com/mryab/learning-at-home)):
+(prototype implementation of hivemind available
+at [mryab/learning-at-home](https://github.com/mryab/learning-at-home)):

```
@inproceedings{ryabinin2020crowdsourced,
@@ -171,5 +176,5 @@ Also, you can cite [the paper](https://arxiv.org/abs/2002.04013) that inspired t

</details>

-We also maintain a list of [related projects and
-acknowledgements](https://learning-at-home.readthedocs.io/en/latest/user/acknowledgements.html).
+We also maintain a list
+of [related projects and acknowledgements](https://learning-at-home.readthedocs.io/en/latest/user/acknowledgements.html).

+ 2 - 3
benchmarks/benchmark_averaging.py

@@ -57,11 +57,10 @@ def benchmark_averaging(
        dht = hivemind.DHT(initial_peers=initial_peers, start=True)
        initial_bits = bin(index % num_groups)[2:].rjust(nbits, "0")
        averager = hivemind.averaging.DecentralizedAverager(
-            peer_tensors[i],
+            peer_tensors[index],
            dht,
            prefix="my_tensor",
            initial_group_bits=initial_bits,
-            listen_on=f"{LOCALHOST}:*",
            compression_type=runtime_pb2.CompressionType.FLOAT16,
            target_group_size=target_group_size,
            averaging_expiration=averaging_expiration,
@@ -71,7 +70,7 @@ def benchmark_averaging(
        processes.update({dht, averager})

        logger.info(
-            f"Averager {index}: started on endpoint {averager.endpoint}, group_bits: {averager.get_group_bits()}"
+            f"Averager {index}: started with peer id {averager.peer_id}, group_bits: {averager.get_group_bits()}"
        )
        for step in range(num_rounds):
            try:

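For context on the benchmark change above: `listen_on` is gone and peers are now identified by `peer_id`. A rough sketch of the updated constructor call with illustrative values (tensor shapes, group size, and the exact set of required keyword arguments are assumptions, not benchmark defaults):

```python
import torch
import hivemind

dht = hivemind.DHT(start=True)

# networking now goes through the DHT's libp2p backend, so no listen_on argument is passed
averager = hivemind.averaging.DecentralizedAverager(
    [torch.zeros(16)],        # local tensors that will be averaged with groupmates
    dht,
    prefix="my_tensor",       # group key prefix shared by all participating averagers
    target_group_size=4,
    start=True,
)
print(f"started with peer id {averager.peer_id}, group_bits: {averager.get_group_bits()}")
```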
+ 1 - 2
benchmarks/benchmark_tensor_compression.py

@@ -4,10 +4,9 @@ import time
import torch

from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.utils.logging import get_logger

-
logger = get_logger(__name__)



+ 2 - 2
benchmarks/benchmark_throughput.py

@@ -7,7 +7,7 @@ import time
import torch

import hivemind
-from hivemind import find_open_port
+from hivemind import get_free_port
from hivemind.moe.server import layers
from hivemind.utils.limits import increase_file_limit
from hivemind.utils.logging import get_logger
@@ -66,7 +66,7 @@ def benchmark_throughput(
        or torch.device(device) == torch.device("cpu")
    )
    assert expert_cls in layers.name_to_block
-    port = port or find_open_port()
+    port = port or get_free_port()
    max_batch_size = max_batch_size or batch_size * 4
    num_handlers = max(1, num_handlers or num_clients // 2)
    benchmarking_failed = mp.Event()

+ 1 - 2
docs/conf.py

@@ -17,9 +17,8 @@
# sys.path.insert(0, os.path.abspath('.'))
import sys

-from recommonmark.transform import AutoStructify
from recommonmark.parser import CommonMarkParser
-
+from recommonmark.transform import AutoStructify

# -- Project information -----------------------------------------------------
src_path = "../hivemind"

+ 1 - 1
docs/user/dht.md

@@ -18,7 +18,7 @@ dht2 = DHT(initial_peers=dht.get_visible_maddrs(), start=True)
```

Note that `initial_peers` contains the address of the first DHT node.
-This implies that the resulting node will have shared key-value with the first node, __as well as any other
+This implies that the new node will share the key-value data with the first node, __as well as any other
nodes connected to it.__ When the two nodes are connected, subsequent peers can use any one of them (or both)
as `initial_peers` to connect to the shared "dictionary".


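The dht.md change above describes how new peers join a shared DHT via `initial_peers`. A minimal sketch of that flow, assuming a local test run; the `store`/`get` calls and the 60-second expiration are illustrative, not part of this diff:

```python
import hivemind
from hivemind.utils.timed_storage import get_dht_time

# the first node starts its own swarm; the second one bootstraps from its visible multiaddrs
dht1 = hivemind.DHT(start=True)
dht2 = hivemind.DHT(initial_peers=dht1.get_visible_maddrs(), start=True)

# later peers can use either node (or both) as initial_peers
dht3 = hivemind.DHT(initial_peers=dht1.get_visible_maddrs() + dht2.get_visible_maddrs(), start=True)

# key-value data stored through one peer becomes visible to the others
dht1.store("key", "value", expiration_time=get_dht_time() + 60)
print(dht3.get("key"))  # expected to return the stored value with its expiration time
```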
+ 1 - 5
examples/albert/arguments.py

@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
-from typing import Optional, List
+from typing import List, Optional

from transformers import TrainingArguments

@@ -45,10 +45,6 @@ class AveragerArguments:
    averaging_timeout: float = field(
        default=30.0, metadata={"help": "Give up on averaging step after this many seconds"}
    )
-    listen_on: str = field(
-        default="[::]:*",
-        metadata={"help": "Network interface used for incoming averager communication. Default: all ipv6"},
-    )
    min_refresh_period: float = field(
        default=0.5, metadata={"help": "Wait for at least this many seconds before fetching new collaboration state"}
    )

+ 3 - 3
examples/albert/run_trainer.py

@@ -11,8 +11,8 @@ import transformers
from datasets import load_from_disk
from torch.utils.data import DataLoader
from torch_optimizer import Lamb
-from transformers import set_seed, HfArgumentParser, TrainingArguments, DataCollatorForLanguageModeling
-from transformers.models.albert import AlbertTokenizerFast, AlbertConfig, AlbertForPreTraining
+from transformers import DataCollatorForLanguageModeling, HfArgumentParser, TrainingArguments, set_seed
+from transformers.models.albert import AlbertConfig, AlbertForPreTraining, AlbertTokenizerFast
from transformers.optimization import get_linear_schedule_with_warmup
from transformers.trainer import Trainer
from transformers.trainer_utils import is_main_process
@@ -21,7 +21,7 @@ import hivemind
from hivemind.utils.compression import CompressionType

import utils
-from arguments import CollaborationArguments, DatasetArguments, AlbertTrainingArguments, AveragerArguments
+from arguments import AlbertTrainingArguments, AveragerArguments, CollaborationArguments, DatasetArguments

logger = logging.getLogger(__name__)
LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)

+ 2 - 2
examples/albert/run_training_monitor.py

@@ -10,13 +10,13 @@ import requests
import torch
import wandb
from torch_optimizer import Lamb
-from transformers import AlbertForPreTraining, AlbertConfig, HfArgumentParser
+from transformers import AlbertConfig, AlbertForPreTraining, HfArgumentParser

import hivemind
from hivemind.utils.compression import CompressionType

import utils
-from arguments import BaseTrainingArguments, CollaborativeOptimizerArguments, AveragerArguments
+from arguments import AveragerArguments, BaseTrainingArguments, CollaborativeOptimizerArguments

logger = logging.getLogger(__name__)


+ 0 - 1
examples/albert/utils.py

@@ -9,7 +9,6 @@ from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
from hivemind.dht.validation import RecordValidatorBase
from hivemind.utils.logging import get_logger

-
logger = get_logger(__name__)



+ 5 - 5
hivemind/__init__.py

@@ -2,21 +2,21 @@ from hivemind.averaging import DecentralizedAverager, TrainingAverager
from hivemind.dht import DHT
from hivemind.moe import (
    ExpertBackend,
-    Server,
-    register_expert_class,
    RemoteExpert,
    RemoteMixtureOfExperts,
    RemoteSwitchMixtureOfExperts,
+    Server,
+    register_expert_class,
)
from hivemind.optim import (
    CollaborativeAdaptiveOptimizer,
-    DecentralizedOptimizerBase,
    CollaborativeOptimizer,
+    DecentralizedAdam,
    DecentralizedOptimizer,
+    DecentralizedOptimizerBase,
    DecentralizedSGD,
-    DecentralizedAdam,
)
from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, PeerInfo
from hivemind.utils import *

-__version__ = "0.9.10"
+__version__ = "1.0.0.dev0"

+ 88 - 82
hivemind/averaging/allreduce.py

@@ -1,15 +1,15 @@
import asyncio
-from typing import Sequence, Dict, Tuple, AsyncIterator, Any, Optional
from enum import Enum
+from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Type

-import grpc
import torch

-from hivemind.averaging.partition import TensorPartContainer, TensorPartReducer, AllreduceException
-from hivemind.utils import Endpoint, get_logger, ChannelCache
-from hivemind.utils.asyncio import anext, achain, aiter, aenumerate, amap_in_executor
-from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
-from hivemind.proto import averaging_pb2_grpc, averaging_pb2
+from hivemind.averaging.partition import AllreduceException, TensorPartContainer, TensorPartReducer
+from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase, StubBase
+from hivemind.proto import averaging_pb2
+from hivemind.utils import get_logger
+from hivemind.utils.asyncio import achain, aenumerate, aiter, amap_in_executor, anext, asingle
+from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor

# flavour types
GroupID = bytes
@@ -22,19 +22,27 @@ class AveragingMode(Enum):
    AUX = 2


-class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
+class AllReduceRunner(ServicerBase):
    """
-    An internal class that runs butterfly AllReduce in a predefined group of averagers
+    An internal class that runs butterfly AllReduce in a predefined group of averagers.
+
+    This class inherits hivemind.p2p.ServicerBase, so it can be used as an RPCServicer for testing purposes without
+    creating a full DecentralizedAverager.

    :note: this class returns **differences** between averaged and local tensors in order to improve numerical stability
+    :param p2p: a hivemind.p2p.P2P instance used for communication with other peers
+    :param servicer_type: a hivemind.p2p.ServicerBase subclass whose RPC signatures are used
+      when requesting other peers. Typically, it is DecentralizedAverager, its derivative,
+      or AllReduceRunner itself (for testing purposes).
+    :param prefix: namespace for servicer's RPCs (typically, equal to prefix for group keys)
    :param group_id: unique identifier of this specific all-reduce run
    :param tensors: local tensors that should be averaged with groupmates
    :param tensors: local tensors that should be averaged with groupmates
-    :param endpoint: your endpoint, must be included in ordered_group_endpoints
-    :param ordered_group_endpoints: group endpoints ordered s.t. i-th endpoint is responsible for averaging i-th part
+    :param peer_id: your peer_id, must be included in ordered_peer_ids
+    :param ordered_peer_ids: group peer_ids ordered s.t. i-th peer_id is responsible for averaging i-th part
    :param peer_fractions: for each peer, a target fraction of vector elements that this peer should average
      (the actual number of values by peer will be nearly proportional, but there are no exact guarantees)
-    :param modes: AveragingMode for each peer in ordered_group_endpoints (normal, client-only or auxiliary)
+    :param modes: AveragingMode for each peer in ordered_peer_ids (normal, client-only or auxiliary)
    :param weights: scaling coefficients for weighted averaging (default = equal weights for all non-aux peers)
    :param gathered: additional user-defined data collected from this group
    :param kwargs: additional paramters (e.g. part_size_bytes) will be passed to TensorPartContainer
@@ -43,73 +51,83 @@ class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
    def __init__(
        self,
        *,
+        p2p: P2P,
+        servicer_type: Type[ServicerBase],
+        prefix: Optional[str],
        group_id: GroupID,
        tensors: Sequence[torch.Tensor],
-        endpoint: Endpoint,
-        ordered_group_endpoints: Sequence[Endpoint],
+        ordered_peer_ids: Sequence[PeerID],
        peer_fractions: Tuple[float, ...],
        weights: Optional[Sequence[float]] = None,
        modes: Optional[Sequence[AveragingMode]] = None,
-        gathered: Optional[Dict[Endpoint, Any]] = None,
+        gathered: Optional[Dict[PeerID, Any]] = None,
        **kwargs,
    ):
-        assert endpoint in ordered_group_endpoints, "endpoint is not a part of the group"
+        self._p2p = p2p
+        self.peer_id = p2p.peer_id
+        assert self.peer_id in ordered_peer_ids, "peer_id is not a part of the group"
+
+        if not issubclass(servicer_type, ServicerBase):
+            raise TypeError("`servicer_type` is expected to be a ServicerBase subclass")
+        self._servicer_type = servicer_type
+        self._prefix = prefix
+
        modes = modes or tuple(AveragingMode.CLIENT if frac == 0 else AveragingMode.NODE for frac in peer_fractions)
        weights = weights or tuple(int(mode != AveragingMode.AUX) for mode in modes)
-        assert len(weights) == len(modes) == len(ordered_group_endpoints), "lists have inconsistent length"
+        assert len(weights) == len(modes) == len(ordered_peer_ids), "lists have inconsistent length"
        assert any(mode != AveragingMode.CLIENT for mode in modes), "cannot run allreduce without reducers"
        for mode, frac, weight in zip(modes, peer_fractions, weights):
            assert mode != AveragingMode.CLIENT or frac == 0, "client-mode peer should have zero all-reduce fraction"
            assert mode != AveragingMode.AUX or weight == 0, "auxiliary peer should have zero averaging weight"

-        self.group_id, self.endpoint, self.ordered_group_endpoints = group_id, endpoint, ordered_group_endpoints
+        self.group_id, self.ordered_peer_ids = group_id, ordered_peer_ids
        self.modes, self.peer_fractions, self.gathered = modes, peer_fractions, gathered

        self._future = asyncio.Future()

-        self.sender_endpoints, self.sender_weights = [], []
-        for endpoint, weight, mode in zip(self.ordered_group_endpoints, weights, modes):
+        self.sender_peer_ids, self.sender_weights = [], []
+        for peer_id, weight, mode in zip(self.ordered_peer_ids, weights, modes):
            if mode != AveragingMode.AUX:
-                self.sender_endpoints.append(endpoint)
+                self.sender_peer_ids.append(peer_id)
                self.sender_weights.append(weight)

-        endpoint_index = self.ordered_group_endpoints.index(self.endpoint)
+        peer_id_index = self.ordered_peer_ids.index(self.peer_id)
        self.tensor_part_container = TensorPartContainer(tensors, peer_fractions, **kwargs)
-        self.parts_for_local_averaging = self.tensor_part_container.get_raw_input_parts(endpoint_index)
+        self.parts_for_local_averaging = self.tensor_part_container.get_raw_input_parts(peer_id_index)
        self.tensor_part_reducer = TensorPartReducer(
            tuple(part.shape for part in self.parts_for_local_averaging),
-            len(self.sender_endpoints),
+            len(self.sender_peer_ids),
            self.sender_weights,
        )
 
 
    def __repr__(self):
-        return f"{self.__class__.__name__}({self.endpoint}, group_size={self.group_size})"
+        return f"{self.__class__.__name__}({self.peer_id}, group_size={self.group_size})"

    def __aiter__(self):
        return self.run()

-    def __contains__(self, endpoint: Endpoint):
-        return endpoint in self.ordered_group_endpoints
+    def __contains__(self, peer_id: PeerID):
+        return peer_id in self.ordered_peer_ids

    @property
    def group_size(self):
-        return len(self.ordered_group_endpoints)
+        return len(self.ordered_peer_ids)

-    def _get_peer_stub(self, peer: Endpoint) -> averaging_pb2_grpc.DecentralizedAveragingStub:
-        return ChannelCache.get_stub(peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
+    def _get_peer_stub(self, peer: PeerID) -> StubBase:
+        return self._servicer_type.get_stub(self._p2p, peer, namespace=self._prefix)

    async def run(self) -> AsyncIterator[torch.Tensor]:
        """Run all-reduce, return differences between averaged and original tensors as they are computed"""
        pending_tasks = set()
        try:
-            if len(self.sender_endpoints) == 0:
+            if len(self.sender_peer_ids) == 0:
                logger.debug(f"{self} - finished all-reduce early: all peers are auxiliaries ({self.modes})")
                self.finalize()

-            elif self.endpoint in self.sender_endpoints:
-                for endpoint, parts in zip(self.ordered_group_endpoints, self.tensor_part_container.num_parts_by_peer):
+            elif self.peer_id in self.sender_peer_ids:
+                for peer_id, parts in zip(self.ordered_peer_ids, self.tensor_part_container.num_parts_by_peer):
                    if parts != 0:
-                        pending_tasks.add(asyncio.create_task(self._communicate_with_peer(endpoint)))
+                        pending_tasks.add(asyncio.create_task(self._communicate_with_peer(peer_id)))

                async for averaged_tensor_delta in self.tensor_part_container.iterate_output_tensors():
                    yield averaged_tensor_delta  # delta = averaged_tensor - original_tensor
@@ -125,57 +143,45 @@ class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
                task.cancel()
            raise

-    async def _communicate_with_peer(self, peer_endpoint: Endpoint):
+    async def _communicate_with_peer(self, peer_id: PeerID):
        """Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors"""
-        peer_index = self.ordered_group_endpoints.index(peer_endpoint)
-        if peer_endpoint == self.endpoint:
-            sender_index = self.sender_endpoints.index(peer_endpoint)
+        peer_index = self.ordered_peer_ids.index(peer_id)
+        if peer_id == self.peer_id:
+            sender_index = self.sender_peer_ids.index(peer_id)
            for part_index, tensor_part in enumerate(self.parts_for_local_averaging):
                averaged_part = await self.tensor_part_reducer.accumulate_part(sender_index, part_index, tensor_part)
                self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part - tensor_part)

        else:
            loop = asyncio.get_event_loop()
-            stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
-            write_task = asyncio.create_task(self._write_to_peer(stream, peer_index))
-
-            try:
-                code = None
-                async for part_index, msg in aenumerate(stream):
-                    if code is None:
-                        code = msg.code
-                    averaged_part_delta = await loop.run_in_executor(None, deserialize_torch_tensor, msg.tensor_part)
-                    self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part_delta)
-                await write_task
-
-                if code != averaging_pb2.AVERAGED_PART:
-                    raise AllreduceException(
-                        f"peer {peer_endpoint} returned {averaging_pb2.MessageCode.Name(code)} "
-                        f"instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)}"
-                        f", allreduce failed"
-                    )
-            finally:
-                if not write_task.done():
-                    write_task.cancel()
-
-    async def _write_to_peer(self, stream: grpc.aio.StreamStreamCall, peer_index: int):
+            code = None
+            stream = self._get_peer_stub(peer_id).rpc_aggregate_part(self._generate_input_for_peer(peer_index))
+            async for part_index, msg in aenumerate(stream):
+                if code is None:
+                    code = msg.code
+                averaged_part_delta = await loop.run_in_executor(None, deserialize_torch_tensor, msg.tensor_part)
+                self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part_delta)
+
+            if code != averaging_pb2.AVERAGED_PART:
+                raise AllreduceException(
+                    f"peer {peer_id} returned {averaging_pb2.MessageCode.Name(code)} "
+                    f"instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)}"
+                    f", allreduce failed"
+                )
+
+    async def _generate_input_for_peer(self, peer_index: int) -> AsyncIterator[averaging_pb2.AveragingData]:
        parts_aiter = self.tensor_part_container.iterate_input_parts_for(peer_index)
        first_part = await anext(parts_aiter)
-        await stream.write(
-            averaging_pb2.AveragingData(
-                code=averaging_pb2.PART_FOR_AVERAGING,
-                group_id=self.group_id,
-                endpoint=self.endpoint,
-                tensor_part=first_part,
-            )
+        yield averaging_pb2.AveragingData(
+            code=averaging_pb2.PART_FOR_AVERAGING,
+            group_id=self.group_id,
+            tensor_part=first_part,
        )
        async for part in parts_aiter:
-            await stream.write(averaging_pb2.AveragingData(tensor_part=part))
-
-        await stream.done_writing()
+            yield averaging_pb2.AveragingData(tensor_part=part)

    async def rpc_aggregate_part(
-        self, stream: AsyncIterator[averaging_pb2.AveragingData], context: grpc.ServicerContext
+        self, stream: AsyncIterator[averaging_pb2.AveragingData], context: P2PContext
    ) -> AsyncIterator[averaging_pb2.AveragingData]:
        """a peer sends us a part of his tensor; we should average it with other peers and return the difference"""
        request: averaging_pb2.AveragingData = await anext(stream)
@@ -186,7 +192,7 @@ class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):

        elif request.code == averaging_pb2.PART_FOR_AVERAGING:
            try:
-                sender_index = self.sender_endpoints.index(request.endpoint)
+                sender_index = self.sender_peer_ids.index(context.remote_id)
                async for msg in self._accumulate_parts_streaming(achain(aiter(request), stream), sender_index):
                    yield msg

@@ -195,8 +201,8 @@ class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
                yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
        else:
            error_code = averaging_pb2.MessageCode.Name(request.code)
-            logger.debug(f"{self} - peer {request.endpoint} sent {error_code}, allreduce cannot continue")
-            self.finalize(exception=AllreduceException(f"peer {request.endpoint} sent {error_code}."))
+            logger.debug(f"{self} - peer {context.remote_id} sent {error_code}, allreduce cannot continue")
+            self.finalize(exception=AllreduceException(f"peer {context.remote_id} sent {error_code}."))
            yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)

    def _check_reasons_to_reject(self, request: averaging_pb2.AveragingData) -> Optional[averaging_pb2.AveragingData]:
@@ -223,10 +229,10 @@ class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
            )
            yield averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=serialized_delta)

-    async def _send_error_to_peer(self, peer_endpoint: Endpoint, code: averaging_pb2.MessageCode):
-        stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
-        await stream.write(averaging_pb2.AveragingData(group_id=self.group_id, endpoint=self.endpoint, code=code))
-        await stream.done_writing()
+    async def _send_error_to_peer(self, peer_id: PeerID, code: averaging_pb2.MessageCode):
+        error = averaging_pb2.AveragingData(group_id=self.group_id, code=code)
+        # In case of reporting the error, we expect the response stream to contain exactly one item
+        await asingle(self._get_peer_stub(peer_id).rpc_aggregate_part(aiter(error)))

    def finalize(self, *, cancel: bool = False, exception: Optional[BaseException] = None):
        """finish or terminate AllReduceRunner, propagate any errors / cancellations to peers."""
@@ -239,9 +245,9 @@ class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
            else:
                code = averaging_pb2.INTERNAL_ERROR
            logger.debug(f"{self} - notifying peers about {averaging_pb2.MessageCode.Name(code)}")
-            for peer_endpoint, mode in zip(self.ordered_group_endpoints, self.modes):
-                if peer_endpoint != self.endpoint and mode != AveragingMode.CLIENT:
-                    pending_tasks.add(asyncio.create_task(self._send_error_to_peer(peer_endpoint, code)))
+            for peer_id, mode in zip(self.ordered_peer_ids, self.modes):
+                if peer_id != self.peer_id and mode != AveragingMode.CLIENT:
+                    pending_tasks.add(asyncio.create_task(self._send_error_to_peer(peer_id, code)))

        if not self._future.done():
            if cancel:

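The allreduce.py changes above replace grpc stream objects with ServicerBase stubs: the caller passes an async generator of requests to the stub method and consumes an async iterator of responses, while the handler identifies the sender via `context.remote_id` instead of an `endpoint` field. A stripped-down sketch of that calling pattern; the function name and the single-message request stream are illustrative, only the stub call, the message fields, and `_get_peer_stub` come from the diff:

```python
from typing import AsyncIterator

from hivemind.proto import averaging_pb2


async def exchange_parts(allreduce, peer_id) -> None:
    # the request stream is an async generator passed directly to the stub call
    async def requests() -> AsyncIterator[averaging_pb2.AveragingData]:
        yield averaging_pb2.AveragingData(code=averaging_pb2.PART_FOR_AVERAGING, group_id=allreduce.group_id)

    # _get_peer_stub resolves the peer through the shared P2P instance and the servicer namespace
    stub = allreduce._get_peer_stub(peer_id)

    # the response is consumed as an async iterator instead of a grpc StreamStreamCall
    async for message in stub.rpc_aggregate_part(requests()):
        print("received", averaging_pb2.MessageCode.Name(message.code))
```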
+ 46 - 99
hivemind/averaging/averager.py

@@ -8,42 +8,36 @@ import ctypes
import multiprocessing as mp
import os
import threading
-import uuid
import weakref
from concurrent.futures.thread import ThreadPoolExecutor
from dataclasses import asdict
-from ipaddress import ip_address
-from typing import Sequence, Optional, Tuple, Any, Union, Dict, AsyncIterator
+from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Union

-import grpc
import numpy as np
import torch
-from grpc._cython.cygrpc import InternalError

-from hivemind.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID, AveragingMode
+from hivemind.averaging.allreduce import AllreduceException, AllReduceRunner, AveragingMode, GroupID
from hivemind.averaging.group_info import GroupInfo
from hivemind.averaging.load_balancing import load_balance_peers
from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
from hivemind.averaging.partition import DEFAULT_PART_SIZE_BYTES
from hivemind.dht import DHT, DHTID
-from hivemind.proto import averaging_pb2, averaging_pb2_grpc, runtime_pb2
-from hivemind.utils import Endpoint, Port, MPFuture, get_logger, TensorDescriptor
-from hivemind.utils.asyncio import anext, achain, aiter, switch_to_uvloop
-from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
-from hivemind.utils.grpc import ChannelCache, GRPC_KEEPALIVE_OPTIONS, split_for_streaming, combine_from_streaming
-from hivemind.utils.networking import choose_ip_address, strip_port, Hostname
+from hivemind.p2p import P2PContext, P2PHandlerError, PeerID, ServicerBase
+from hivemind.proto import averaging_pb2, runtime_pb2
+from hivemind.utils import MPFuture, TensorDescriptor, get_logger
+from hivemind.utils.asyncio import achain, aiter, anext, switch_to_uvloop
+from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.utils.grpc import combine_from_streaming, split_for_streaming
from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
-from hivemind.utils.timed_storage import get_dht_time, ValueWithExpiration, DHTExpiration
+from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration, get_dht_time

# flavour types
-StreamCallToLeader = grpc.aio.UnaryStreamCall[averaging_pb2.JoinRequest, averaging_pb2.MessageFromLeader]
GatheredData = Any
logger = get_logger(__name__)


-class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragingServicer):
+class DecentralizedAverager(mp.Process, ServicerBase):
    """
-
    Parameter averaging service. A trainer can run this service in background to periodically average his parameters
    with other trainers. The averaging pattern is chosen so that (1) you only need to average with a small
    group of peers at a time, but (2) all trainers will converge to global average in a logarithmic number of steps.
@@ -67,14 +61,9 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
    :param bandwidth: if specified, this value represents the network bandwidth available to averager.
          By default, the averager is assumed to have the average bandwidth of his group.
          If bandwidth == 0, averager will rely on its groupmates to do all the averaging.
-    :param client_mode: if False (default), this averager will accept incoming requests from other peers
-            if True, the averager will only join existing groups where at least one peer has client_mode=False
-    :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
-    :param announced_host: visible IP address the averager will announce for external connections from other peers.
-          If None, the address will be chosen from p2p.get_visible_maddrs() (global IPv4 addresses are preferred)
-    :param channel_options: options for grpc.aio.insecure_channel, e.g. [('grpc.enable_retries', 0)]
-          see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
-    :param kwargs: extra parameters forwarded to grpc.aio.server
+    :param client_mode: if False, this averager will accept incoming requests from other peers.
+          if True, the averager will only join existing groups where at least one peer has client_mode=False.
+          By default, this flag is copied from DHTNode inside the ``dht`` instance.
    :param auxiliary: if this flag is specified, averager.step will only assist others without sending
          local tensors for averaging
    :param allow_state_sharing: if set to True, other peers can download this peer's state. Can be overwritten
@@ -96,7 +85,6 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin

    _matchmaking: Matchmaking
    _pending_group_assembled: asyncio.Event
-    _server: grpc.aio.Server
    serializer = MSGPackSerializer

    def __init__(
@@ -119,13 +107,9 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
        min_vector_size: int = 0,
        auxiliary: bool = False,
        allow_state_sharing: Optional[bool] = None,
-        client_mode: bool = False,
-        listen_on: Endpoint = "0.0.0.0:*",
+        client_mode: Optional[bool] = None,
        daemon: bool = True,
-        announced_host: Optional[str] = None,
-        channel_options: Sequence[Tuple[str, Any]] = (),
        shutdown_timeout: float = 5,
-        **kwargs,
    ):
        assert "." not in prefix, "group prefix must be a string without trailing '.'"
        assert bandwidth is None or (
@@ -138,7 +122,12 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
 
         super().__init__()
         super().__init__()
         self.dht = dht
         self.dht = dht
-        self.client_mode, self.listen_on, self.kwargs = client_mode, listen_on, kwargs
+        self.prefix = prefix
+
+        if client_mode is None:
+            client_mode = dht.client_mode
+        self.client_mode = client_mode
+
         self._parent_pid = os.getpid()
         self._parent_pid = os.getpid()
         if self.client_mode:
         if self.client_mode:
             self.mode = AveragingMode.CLIENT
             self.mode = AveragingMode.CLIENT
@@ -146,11 +135,6 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             self.mode = AveragingMode.AUX
             self.mode = AveragingMode.AUX
         else:
         else:
             self.mode = AveragingMode.NODE
             self.mode = AveragingMode.NODE
-
-        if announced_host is None:
-            announced_host = self._choose_announced_host()
-        self.announced_host = announced_host
-        self.channel_options = channel_options
         self.daemon = daemon
         self.daemon = daemon
 
 
         self._averaged_tensors = tuple(averaged_tensors)
         self._averaged_tensors = tuple(averaged_tensors)
@@ -165,6 +149,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         self.bandwidth = bandwidth
         self.bandwidth = bandwidth
 
 
         self.matchmaking_kwargs = dict(
         self.matchmaking_kwargs = dict(
+            servicer_type=type(self),
             prefix=prefix,
             prefix=prefix,
             initial_group_bits=initial_group_bits,
             initial_group_bits=initial_group_bits,
             target_group_size=target_group_size,
             target_group_size=target_group_size,
@@ -179,17 +164,12 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         self._running_groups: Dict[GroupID, AllReduceRunner] = {}  # one or more assembled groups that run all-reduce
         self._running_groups: Dict[GroupID, AllReduceRunner] = {}  # one or more assembled groups that run all-reduce
 
 
         self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)  # a control pipe used to communicate with daemon
         self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)  # a control pipe used to communicate with daemon
-        self._port = mp.Value(ctypes.c_uint32, 0)  # assigned when averager starts, accessible via self.port
 
 
         self._allow_state_sharing = mp.Value(ctypes.c_bool, 0)
         self._allow_state_sharing = mp.Value(ctypes.c_bool, 0)
         if allow_state_sharing is None:
         if allow_state_sharing is None:
             allow_state_sharing = not client_mode and not auxiliary
             allow_state_sharing = not client_mode and not auxiliary
         self.allow_state_sharing = allow_state_sharing
         self.allow_state_sharing = allow_state_sharing
 
 
-        self._averager_endpoint: Optional[Endpoint] = None
-        if self.client_mode:
-            self._averager_endpoint = f"client::{uuid.uuid4()}"
-
         self.ready = mp.Event()  # whether the averager process has started (and ready for incoming requests)
         # note: we create a background thread weakref and with daemon=True to ensure garbage collection
         background_fetcher = threading.Thread(
@@ -201,22 +181,6 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         if start:
         if start:
             self.run_in_background(await_ready=True)
             self.run_in_background(await_ready=True)
 
 
-    def _choose_announced_host(self) -> Hostname:
-        announced_host = strip_port(self.listen_on).strip("[]")  # Stripping square brackets for IPv6
-        if ip_address(announced_host) not in [ip_address("0.0.0.0"), ip_address("::")]:
-            return announced_host
-
-        maddrs = self.dht.get_visible_maddrs()
-        announced_host = choose_ip_address(maddrs)
-        logger.info(
-            f"Choosing IP {announced_host} as endpoint for DecentralizedAverager " f"from visible multiaddrs {maddrs}"
-        )
-        return announced_host
-
-    @property
-    def port(self) -> Optional[Port]:
-        return self._port.value if self._port.value != 0 else None
-
     @property
     def allow_state_sharing(self) -> bool:
         """if set to True, other peers can download this peer's state"""
@@ -230,15 +194,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             self._allow_state_sharing.value = value
             self._allow_state_sharing.value = value
 
 
     @property
     @property
-    def endpoint(self) -> Optional[Endpoint]:
-        if self._averager_endpoint is None and not self.client_mode:
-            assert self.port is not None, "Averager is not running yet"
-            self._averager_endpoint = f"{self.announced_host}:{self.port}"
-            logger.debug(f"Assuming averager endpoint to be {self._averager_endpoint}")
-        return self._averager_endpoint
-
-    def __repr__(self):
-        return f"{self.__class__.__name__}({self.endpoint})"
+    def peer_id(self) -> PeerID:
+        return self.dht.peer_id
 
 
     def run(self):
     def run(self):
         """
         """
@@ -257,20 +214,18 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
         with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
 
 
             async def _run():
             async def _run():
-                grpc.aio.init_grpc_aio()
-
+                self._p2p = await self.dht.replicate_p2p()
                 if not self.client_mode:
                 if not self.client_mode:
-                    self._server = grpc.aio.server(**self.kwargs, options=GRPC_KEEPALIVE_OPTIONS)
-                    averaging_pb2_grpc.add_DecentralizedAveragingServicer_to_server(self, self._server)
-                    found_port = self._server.add_insecure_port(self.listen_on)
-                    assert found_port != 0, f"Failed to listen to {self.listen_on}"
-                    self._port.value = found_port
-                    await self._server.start()
+                    await self.add_p2p_handlers(self._p2p, namespace=self.prefix)
                 else:
                 else:
                     logger.debug(f"The averager is running in client mode.")
                     logger.debug(f"The averager is running in client mode.")
 
 
                 self._matchmaking = Matchmaking(
                 self._matchmaking = Matchmaking(
-                    self.endpoint, self.schema_hash, self.dht, **self.matchmaking_kwargs, client_mode=self.client_mode
+                    self._p2p,
+                    self.schema_hash,
+                    self.dht,
+                    client_mode=self.client_mode,
+                    **self.matchmaking_kwargs,
                 )
                 )
                 if not self.client_mode:
                 if not self.client_mode:
                     asyncio.create_task(self._declare_for_download_periodically())
                     asyncio.create_task(self._declare_for_download_periodically())
@@ -313,8 +268,6 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         remaining_tasks = set()
         remaining_tasks = set()
         for group in self._running_groups.values():
         for group in self._running_groups.values():
             remaining_tasks.update(group.finalize(cancel=True))
             remaining_tasks.update(group.finalize(cancel=True))
-        if not self.client_mode:
-            remaining_tasks.add(self._server.stop(timeout))
         await asyncio.gather(*remaining_tasks)
         await asyncio.gather(*remaining_tasks)
 
 
     def __del__(self):
     def __del__(self):
@@ -328,7 +281,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         timeout: Optional[float] = None,
         timeout: Optional[float] = None,
         allow_retries: bool = True,
         allow_retries: bool = True,
         wait: bool = True,
         wait: bool = True,
-    ) -> Union[Optional[Dict[Endpoint, GatheredData]], MPFuture]:
+    ) -> Union[Optional[Dict[PeerID, GatheredData]], MPFuture]:
         """
         Set up the averager to look for a group and run one round of averaging, return True on success, False on failure

@@ -394,11 +347,9 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                     MatchmakingException,
                     MatchmakingException,
                     AssertionError,
                     AssertionError,
                     StopAsyncIteration,
                     StopAsyncIteration,
-                    InternalError,
                     asyncio.CancelledError,
                     asyncio.CancelledError,
                     asyncio.InvalidStateError,
                     asyncio.InvalidStateError,
-                    grpc.RpcError,
-                    grpc.aio.AioRpcError,
+                    P2PHandlerError,
                 ) as e:
                 ) as e:
                     time_elapsed = get_dht_time() - start_time
                     time_elapsed = get_dht_time() - start_time
                     if not allow_retries or (timeout is not None and timeout < time_elapsed):
                     if not allow_retries or (timeout is not None and timeout < time_elapsed):
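For illustration, a hedged sketch of what a caller sees after the changes in this hunk and the one above: step() now returns metadata keyed by PeerID and retries on P2PHandlerError rather than gRPC errors. The `gather` argument and the pre-existing `averager` instance are assumptions, not shown here:

    # hypothetical usage; assumes an already running DecentralizedAverager instance
    gathered = averager.step(gather={"step": 1234}, wait=True)
    if gathered is not None:
        for peer_id, metadata in gathered.items():  # keys are hivemind.p2p.PeerID objects
            print(peer_id, metadata)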
@@ -424,7 +375,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
         try:
         try:
             weights, bandwidths, mode_ids, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
             weights, bandwidths, mode_ids, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
-            user_gathered = dict(zip(group_info.endpoints, map(self.serializer.loads, user_gathered)))
+            user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered)))
             modes = tuple(map(AveragingMode, mode_ids))
             modes = tuple(map(AveragingMode, mode_ids))
 
 
             # compute optimal part sizes from peer bandwidths; TODO: replace with proper load balancing
@@ -437,10 +388,12 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
 
             async with self.get_tensors_async() as local_tensors:
             async with self.get_tensors_async() as local_tensors:
                 allreduce = AllReduceRunner(
                 allreduce = AllReduceRunner(
+                    p2p=self._p2p,
+                    servicer_type=type(self),
+                    prefix=self.prefix,
                     group_id=group_info.group_id,
                     group_id=group_info.group_id,
                     tensors=local_tensors,
                     tensors=local_tensors,
-                    endpoint=self.endpoint,
-                    ordered_group_endpoints=group_info.endpoints,
+                    ordered_peer_ids=group_info.peer_ids,
                     peer_fractions=peer_fractions,
                     peer_fractions=peer_fractions,
                     weights=weights,
                     weights=weights,
                     gathered=user_gathered,
                     gathered=user_gathered,
@@ -453,7 +406,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                     # actually run all-reduce
                     averaging_outputs = [output async for output in allreduce]
                     averaging_outputs = [output async for output in allreduce]
 
 
-                    if modes[group_info.endpoints.index(self.endpoint)] != AveragingMode.AUX:
+                    if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
                         assert len(local_tensors) == len(self._averaged_tensors)
                         assert len(local_tensors) == len(self._averaged_tensors)
                         for tensor, update in zip(local_tensors, averaging_outputs):
                         for tensor, update in zip(local_tensors, averaging_outputs):
                             tensor.add_(update, alpha=self._averaging_alpha)
                             tensor.add_(update, alpha=self._averaging_alpha)
@@ -496,14 +449,14 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             self.lock_averaged_tensors.release()
             self.lock_averaged_tensors.release()
 
 
     async def rpc_join_group(
     async def rpc_join_group(
-        self, request: averaging_pb2.JoinRequest, context: grpc.ServicerContext
+        self, request: averaging_pb2.JoinRequest, context: P2PContext
     ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
         """accept or reject a join request from another averager; if accepted, run him through allreduce steps"""
         async for response in self._matchmaking.rpc_join_group(request, context):
         async for response in self._matchmaking.rpc_join_group(request, context):
             yield response
             yield response
 
 
     async def rpc_aggregate_part(
     async def rpc_aggregate_part(
-        self, stream: AsyncIterator[averaging_pb2.AveragingData], context: grpc.ServicerContext
+        self, stream: AsyncIterator[averaging_pb2.AveragingData], context: P2PContext
     ) -> AsyncIterator[averaging_pb2.AveragingData]:
         """a groupmate sends us a part of his tensor; we should average it with other peers and return the result"""
         request = await anext(stream)
         request = await anext(stream)
@@ -528,7 +481,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                     asyncio.wait_for(
                     asyncio.wait_for(
                         self.dht.store(
                         self.dht.store(
                             download_key,
                             download_key,
-                            subkey=self.endpoint,
+                            subkey=self.peer_id.to_bytes(),
                             value=self.last_updated,
                             value=self.last_updated,
                             expiration_time=get_dht_time() + self._matchmaking.averaging_expiration,
                             expiration_time=get_dht_time() + self._matchmaking.averaging_expiration,
                             return_future=True,
                             return_future=True,
@@ -539,7 +492,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             await asyncio.sleep(self._matchmaking.averaging_expiration)
             await asyncio.sleep(self._matchmaking.averaging_expiration)
 
 
     async def rpc_download_state(
     async def rpc_download_state(
-        self, request: averaging_pb2.DownloadRequest, context: grpc.ServicerContext
+        self, _request: averaging_pb2.DownloadRequest, _context: P2PContext
     ) -> AsyncIterator[averaging_pb2.DownloadData]:
         """
         Get the up-to-date trainer state from a peer.
@@ -594,8 +547,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             key_manager = self._matchmaking.group_key_manager
             key_manager = self._matchmaking.group_key_manager
             peer_priority, _ = self.dht.get(f"{key_manager.prefix}.all_averagers", latest=True) or ({}, None)
             peer_priority, _ = self.dht.get(f"{key_manager.prefix}.all_averagers", latest=True) or ({}, None)
             peer_priority = {
             peer_priority = {
-                peer: float(info.value)
-                for peer, info in peer_priority.items()
+                PeerID(peer_id): float(info.value)
+                for peer_id, info in peer_priority.items()
                 if isinstance(info, ValueWithExpiration) and isinstance(info.value, (float, int))
                 if isinstance(info, ValueWithExpiration) and isinstance(info.value, (float, int))
             }
             }
 
 
@@ -606,13 +559,10 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
 
             metadata = None
             metadata = None
             for peer in sorted(peer_priority.keys(), key=peer_priority.get, reverse=True):
             for peer in sorted(peer_priority.keys(), key=peer_priority.get, reverse=True):
-                if peer != self.endpoint:
+                if peer != self.peer_id:
                     logger.info(f"Downloading parameters from peer {peer}")
                     logger.info(f"Downloading parameters from peer {peer}")
-                    stream = None
                     try:
                     try:
-                        stub = ChannelCache.get_stub(
-                            peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True, options=self.channel_options
-                        )
+                        stub = self.get_stub(self._p2p, peer, namespace=self.prefix)
                         stream = stub.rpc_download_state(averaging_pb2.DownloadRequest())
                         stream = stub.rpc_download_state(averaging_pb2.DownloadRequest())
                         current_tensor_parts, tensors = [], []
                         current_tensor_parts, tensors = [], []
                         async for message in stream:
                         async for message in stream:
@@ -636,9 +586,6 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                         return
                         return
                     except BaseException as e:
                     except BaseException as e:
                         logger.exception(f"Failed to download state from {peer} - {repr(e)}")
                         logger.exception(f"Failed to download state from {peer} - {repr(e)}")
-                    finally:
-                        if stream is not None:
-                            await stream.code()
 
 
         finally:
         finally:
             if not future.done():
             if not future.done():
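A hedged usage sketch of the refactored averager above (argument names follow this diff; the remaining constructor arguments, defaults, and import paths are assumptions and may differ):

    import torch
    from hivemind import DHT
    from hivemind.averaging.averager import DecentralizedAverager

    dht = DHT(start=True)                        # the DHT owns client_mode and the libp2p daemon
    averager = DecentralizedAverager(
        averaged_tensors=[torch.zeros(3)],
        dht=dht,
        prefix="my_experiment",                  # also used as the p2p handler namespace
        target_group_size=4,
        start=True,
    )
    # client_mode was not passed, so it is copied from the DHT; peers now reach this
    # averager by its libp2p peer_id instead of an ip:port endpoint
    print(averager.peer_id)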

+ 6 - 6
hivemind/averaging/group_info.py

@@ -1,7 +1,7 @@
 from dataclasses import dataclass
 from dataclasses import dataclass
 from typing import Tuple
 from typing import Tuple
 
 
-from hivemind.utils import Endpoint
+from hivemind.p2p import PeerID
 
 
 
 
 @dataclass(frozen=True)
 @dataclass(frozen=True)
@@ -9,12 +9,12 @@ class GroupInfo:
     """A group of peers assembled through decentralized matchmaking"""

     group_id: bytes  # random unique bytestring that describes the current group, generated by group leader
-    endpoints: Tuple[Endpoint, ...]  # an ordered sequence of endpoints of each groupmate
-    gathered: Tuple[bytes, ...]  # binary metadata gathered from all peers by leader, same order as endpoints
+    peer_ids: Tuple[PeerID, ...]  # an ordered sequence of peer_ids of each groupmate
+    gathered: Tuple[bytes, ...]  # binary metadata gathered from all peers by leader, same order as peer_ids
 
 
     @property
     @property
     def group_size(self):
     def group_size(self):
-        return len(self.endpoints)
+        return len(self.peer_ids)
 
 
-    def __contains__(self, endpoint: Endpoint):
-        return endpoint in self.endpoints
+    def __contains__(self, peer_id: PeerID):
+        return peer_id in self.peer_ids
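A minimal sketch of the new GroupInfo contract; the PeerID byte values below are fabricated for illustration:

    from hivemind.averaging.group_info import GroupInfo
    from hivemind.p2p import PeerID

    alice, bob = PeerID(b"alice"), PeerID(b"bob")    # made-up ids
    info = GroupInfo(group_id=b"random-group-id", peer_ids=(alice, bob), gathered=(b"", b""))
    assert info.group_size == 2
    assert alice in info                             # membership is now checked against peer_ids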

+ 25 - 20
hivemind/averaging/key_manager.py

@@ -1,13 +1,14 @@
 import asyncio
 import asyncio
-import re
 import random
 import random
-from typing import Optional, List, Tuple
+import re
+from typing import List, Optional, Tuple
 
 
 import numpy as np
 import numpy as np
 
 
-from hivemind.dht import DHT
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.group_info import GroupInfo
-from hivemind.utils import get_logger, Endpoint, DHTExpiration, get_dht_time, ValueWithExpiration
+from hivemind.dht import DHT
+from hivemind.p2p import PeerID
+from hivemind.utils import DHTExpiration, ValueWithExpiration, get_dht_time, get_logger
 
 
 GroupKey = str
 GroupKey = str
 GROUP_PATTERN = re.compile("^(([^.])+)[.]0b[01]*$")  # e.g. bert_exp4_averaging.0b01001101
 GROUP_PATTERN = re.compile("^(([^.])+)[.]0b[01]*$")  # e.g. bert_exp4_averaging.0b01001101
@@ -29,7 +30,6 @@ class GroupKeyManager:
     def __init__(
     def __init__(
         self,
         self,
         dht: DHT,
         dht: DHT,
-        endpoint: Endpoint,
         prefix: str,
         prefix: str,
         initial_group_bits: Optional[str],
         initial_group_bits: Optional[str],
         target_group_size: int,
         target_group_size: int,
@@ -43,7 +43,8 @@ class GroupKeyManager:
             search_result = dht.get(f"{prefix}.0b", latest=True)
             search_result = dht.get(f"{prefix}.0b", latest=True)
             initial_group_nbits = self.get_suggested_nbits(search_result) or 0
             initial_group_nbits = self.get_suggested_nbits(search_result) or 0
             initial_group_bits = "".join(random.choice("01") for _ in range(initial_group_nbits))
             initial_group_bits = "".join(random.choice("01") for _ in range(initial_group_nbits))
-        self.dht, self.endpoint, self.prefix, self.group_bits = dht, endpoint, prefix, initial_group_bits
+        self.dht, self.prefix, self.group_bits = dht, prefix, initial_group_bits
+        self.peer_id = dht.peer_id
         self.target_group_size = target_group_size
         self.target_group_size = target_group_size
         self.insufficient_size = insufficient_size or max(1, target_group_size // 2)
         self.insufficient_size = insufficient_size or max(1, target_group_size // 2)
         self.excessive_size = excessive_size or target_group_size * 3
         self.excessive_size = excessive_size or target_group_size * 3
@@ -55,13 +56,13 @@ class GroupKeyManager:
         return f"{self.prefix}.0b{self.group_bits}"
         return f"{self.prefix}.0b{self.group_bits}"
 
 
     async def declare_averager(
     async def declare_averager(
-        self, group_key: GroupKey, endpoint: Endpoint, expiration_time: float, looking_for_group: bool = True
+        self, group_key: GroupKey, peer_id: PeerID, expiration_time: float, looking_for_group: bool = True
     ) -> bool:
         """
         Add (or remove) the averager to a given allreduce bucket

         :param group_key: allreduce group key, e.g. my_averager.0b011011101
-        :param endpoint: averager public endpoint for incoming requests
+        :param peer_id: averager public peer_id for incoming requests
         :param expiration_time: intent to run allreduce before this timestamp
         :param looking_for_group: by default (True), declare the averager as "looking for group" in a given group;
           If False, this will instead mark that the averager as no longer looking for group, (e.g. it already finished)
@@ -72,20 +73,20 @@ class GroupKeyManager:
         expiration_time = expiration_time if looking_for_group else float(np.nextafter(expiration_time, float("inf")))
         expiration_time = expiration_time if looking_for_group else float(np.nextafter(expiration_time, float("inf")))
         return await self.dht.store(
         return await self.dht.store(
             key=group_key,
             key=group_key,
-            subkey=endpoint,
+            subkey=peer_id.to_bytes(),
             value=looking_for_group,
             value=looking_for_group,
             expiration_time=expiration_time,
             expiration_time=expiration_time,
             return_future=True,
             return_future=True,
         )
         )
 
 
-    async def get_averagers(self, group_key: GroupKey, only_active: bool) -> List[Tuple[Endpoint, DHTExpiration]]:
+    async def get_averagers(self, group_key: GroupKey, only_active: bool) -> List[Tuple[PeerID, DHTExpiration]]:
         """
         Find and return averagers that were declared with a given all-reduce key

         :param group_key: finds averagers that have the this group key, e.g. my_averager.0b011011101
         :param only_active: if True, return only active averagers that are looking for group (i.e. with value = True)
             if False, return all averagers under a given group_key regardless of value
-        :return: endpoints and expirations of every matching averager
+        :return: peer_ids and expirations of every matching averager
         """
         """
         assert is_valid_group(group_key), f"Group key {group_key} is invalid, must follow {GROUP_PATTERN}"
         assert is_valid_group(group_key), f"Group key {group_key} is invalid, must follow {GROUP_PATTERN}"
         result = await self.dht.get(group_key, latest=True, return_future=True)
         result = await self.dht.get(group_key, latest=True, return_future=True)
@@ -93,11 +94,15 @@ class GroupKeyManager:
             logger.debug(f"Allreduce group not found: {group_key}, creating new group.")
             logger.debug(f"Allreduce group not found: {group_key}, creating new group.")
             return []
             return []
         averagers = [
         averagers = [
-            (key, entry.expiration_time)
-            for key, entry in result.value.items()
-            if key != self.RESERVED_KEY_FOR_NBITS and (not only_active or entry.value is True)
+            (PeerID(key), looking_for_group.expiration_time)
+            for key, looking_for_group in result.value.items()
+            if key != self.RESERVED_KEY_FOR_NBITS and (not only_active or looking_for_group.value)
         ]
         ]
-        num_active_averagers = len([key for key, entry in result.value.items() if entry.value is True])
+        num_active_averagers = sum(
+            1
+            for key, looking_for_group in result.value.items()
+            if key != self.RESERVED_KEY_FOR_NBITS and looking_for_group.value
+        )
 
 
         suggested_nbits = self.get_suggested_nbits(result)
         suggested_nbits = self.get_suggested_nbits(result)
         if (
         if (
@@ -106,10 +111,10 @@ class GroupKeyManager:
             and suggested_nbits != self.suggested_nbits
             and suggested_nbits != self.suggested_nbits
         ):
         ):
             self.suggested_nbits = suggested_nbits
             self.suggested_nbits = suggested_nbits
-            logger.warning(f"{self.endpoint} - another averager suggested {self.suggested_nbits}-bit keys")
+            logger.warning(f"{self.peer_id} - another averager suggested {self.suggested_nbits}-bit keys")
         elif num_active_averagers >= self.excessive_size:
         elif num_active_averagers >= self.excessive_size:
             self.suggested_nbits = max(suggested_nbits or 0, len(self.group_bits) + 1)
             self.suggested_nbits = max(suggested_nbits or 0, len(self.group_bits) + 1)
-            logger.warning(f"{self.endpoint} - too many peers in bucket, switching to {self.suggested_nbits}-bit keys")
+            logger.warning(f"{self.peer_id} - too many peers in bucket, switching to {self.suggested_nbits}-bit keys")
         return averagers
         return averagers
 
 
     async def declare_nbits(self, group_key: GroupKey, nbits: int, expiration_time: DHTExpiration) -> bool:
     async def declare_nbits(self, group_key: GroupKey, nbits: int, expiration_time: DHTExpiration) -> bool:
@@ -136,12 +141,12 @@ class GroupKeyManager:
     async def update_key_on_group_assembled(self, group_info: GroupInfo, is_leader: bool = True):
         """this function is triggered every time an averager finds an allreduce group"""
         rng = random.Random(group_info.group_id)
         rng = random.Random(group_info.group_id)
-        index = group_info.endpoints.index(self.endpoint)
+        index = group_info.peer_ids.index(self.peer_id)
         generalized_index = rng.sample(range(self.target_group_size), group_info.group_size)[index]
         generalized_index = rng.sample(range(self.target_group_size), group_info.group_size)[index]
         nbits = int(np.ceil(np.log2(self.target_group_size)))
         nbits = int(np.ceil(np.log2(self.target_group_size)))
         new_bits = bin(generalized_index)[2:].rjust(nbits, "0")
         new_bits = bin(generalized_index)[2:].rjust(nbits, "0")
         self.group_bits = (self.group_bits + new_bits)[-len(self.group_bits) :] if self.group_bits else ""
         self.group_bits = (self.group_bits + new_bits)[-len(self.group_bits) :] if self.group_bits else ""
-        logger.debug(f"{self.endpoint} - updated group key to {self.group_bits}")
+        logger.debug(f"{self.peer_id} - updated group key to {self.group_bits}")
 
 
         if is_leader and self.insufficient_size < group_info.group_size < self.excessive_size:
         if is_leader and self.insufficient_size < group_info.group_size < self.excessive_size:
             asyncio.create_task(self.notify_stragglers())
             asyncio.create_task(self.notify_stragglers())
@@ -156,7 +161,7 @@ class GroupKeyManager:
         new_nbits = self.suggested_nbits if self.suggested_nbits is not None else len(self.group_bits) - 1
         new_nbits = self.suggested_nbits if self.suggested_nbits is not None else len(self.group_bits) - 1
         prev_nbits, self.group_bits = self.group_bits, self.group_bits[-new_nbits:] if new_nbits else ""
         prev_nbits, self.group_bits = self.group_bits, self.group_bits[-new_nbits:] if new_nbits else ""
         if self.group_bits != prev_nbits:
         if self.group_bits != prev_nbits:
-            logger.warning(f"{self.endpoint} - switching to {len(self.group_bits)}-bit keys")
+            logger.warning(f"{self.peer_id} - switching to {len(self.group_bits)}-bit keys")
         self.suggested_nbits = None
         self.suggested_nbits = None
 
 
     async def notify_stragglers(self):
     async def notify_stragglers(self):
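A rough sketch of the updated GroupKeyManager API, assuming an existing `dht` handle and execution inside a coroutine; constructor defaults not shown in this hunk are assumptions:

    from hivemind.averaging.key_manager import GroupKeyManager
    from hivemind.utils import get_dht_time

    key_manager = GroupKeyManager(dht, prefix="my_experiment",
                                  initial_group_bits="01", target_group_size=4)
    # declare ourselves under our PeerID (stored as raw bytes in the DHT subkey)
    await key_manager.declare_averager(
        group_key=key_manager.current_key,           # e.g. "my_experiment.0b01"
        peer_id=dht.peer_id,
        expiration_time=get_dht_time() + 30,
    )
    # returns a list of (PeerID, expiration_time) pairs
    peers = await key_manager.get_averagers(key_manager.current_key, only_active=True)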

+ 2 - 1
hivemind/averaging/load_balancing.py

@@ -1,4 +1,5 @@
-from typing import Sequence, Optional, Tuple
+from typing import Optional, Sequence, Tuple
+
 import numpy as np
 import numpy as np
 import scipy.optimize
 import scipy.optimize
 
 

+ 91 - 82
hivemind/averaging/matchmaking.py

@@ -2,27 +2,25 @@
 
 
 from __future__ import annotations
 from __future__ import annotations
 
 
+import asyncio
+import concurrent.futures
 import contextlib
 import contextlib
 import random
 import random
 from math import isfinite
 from math import isfinite
-from typing import Optional, AsyncIterator, Set, Tuple, Dict
-import concurrent.futures
-import asyncio
-
-import grpc
-import grpc._cython.cygrpc
+from typing import AsyncIterator, Dict, Optional, Set, Tuple, Type
 
 
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.group_info import GroupInfo
-from hivemind.averaging.key_manager import GroupKeyManager, GroupKey
+from hivemind.averaging.key_manager import GroupKey, GroupKeyManager
 from hivemind.dht import DHT, DHTID, DHTExpiration
 from hivemind.dht import DHT, DHTID, DHTExpiration
-from hivemind.utils import get_logger, Endpoint, timed_storage, TimedStorage, get_dht_time
-from hivemind.proto import averaging_pb2, averaging_pb2_grpc
-from hivemind.utils.grpc import ChannelCache
+from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
+from hivemind.proto import averaging_pb2
+from hivemind.utils import TimedStorage, get_dht_time, get_logger, timed_storage
+from hivemind.utils.asyncio import anext
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 
 
 
-class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
+class Matchmaking:
     f"""
     An internal class that is used to form groups of averages for running allreduce
     See DecentralizedAverager docstring for the detailed description of all parameters
@@ -37,10 +35,11 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
 
 
     def __init__(
     def __init__(
         self,
         self,
-        endpoint: Endpoint,
+        p2p: P2P,
         schema_hash: bytes,
         schema_hash: bytes,
         dht: DHT,
         dht: DHT,
         *,
         *,
+        servicer_type: Type[ServicerBase],
         prefix: str,
         prefix: str,
         target_group_size: int,
         target_group_size: int,
         min_group_size: int,
         min_group_size: int,
@@ -57,8 +56,16 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
             )
             )
 
 
         super().__init__()
         super().__init__()
-        self.endpoint, self.schema_hash = endpoint, schema_hash
-        self.group_key_manager = GroupKeyManager(dht, endpoint, prefix, initial_group_bits, target_group_size)
+        self._p2p = p2p
+
+        if not issubclass(servicer_type, ServicerBase):
+            raise TypeError("`servicer_type` is expected to be a ServicerBase subclass")
+        self._servicer_type = servicer_type
+        self._prefix = prefix
+
+        self.peer_id = p2p.peer_id
+        self.schema_hash = schema_hash
+        self.group_key_manager = GroupKeyManager(dht, prefix, initial_group_bits, target_group_size)
         self.target_group_size, self.min_group_size = target_group_size, min_group_size
         self.target_group_size, self.min_group_size = target_group_size, min_group_size
         self.averaging_expiration, self.request_timeout = averaging_expiration, request_timeout
         self.averaging_expiration, self.request_timeout = averaging_expiration, request_timeout
         self.client_mode = client_mode
         self.client_mode = client_mode
@@ -69,9 +76,9 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
         self.was_accepted_to_group = asyncio.Event()
         self.was_accepted_to_group = asyncio.Event()
         self.assembled_group = asyncio.Future()
         self.assembled_group = asyncio.Future()
 
 
-        self.current_leader: Optional[Endpoint] = None  # iff i am a follower, this is a link to my current leader
-        self.current_followers: Dict[Endpoint, averaging_pb2.JoinRequest] = {}  # my current followers excluding myself
-        self.potential_leaders = PotentialLeaders(endpoint, averaging_expiration, target_group_size)
+        self.current_leader: Optional[PeerID] = None  # iff i am a follower, this is a link to my current leader
+        self.current_followers: Dict[PeerID, averaging_pb2.JoinRequest] = {}  # my current followers excluding myself
+        self.potential_leaders = PotentialLeaders(self.peer_id, averaging_expiration, target_group_size)
         self.data_for_gather: Optional[bytes] = None
         self.data_for_gather: Optional[bytes] = None
 
 
     @property
     @property
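For reference, a hedged sketch of how the refactored Matchmaking is now wired up, mirroring the averager's _run() earlier in this diff; `dht`, `schema_hash`, and the keyword values below are placeholders/assumptions:

    p2p = await dht.replicate_p2p()
    matchmaking = Matchmaking(
        p2p, schema_hash, dht,
        servicer_type=DecentralizedAverager,     # a ServicerBase subclass, used to build stubs
        prefix="my_experiment",
        initial_group_bits="",
        target_group_size=4,
        min_group_size=2,
        averaging_expiration=15,
        request_timeout=3,
        client_mode=False,
    )
    # leader stubs are created through the servicer type, e.g.:
    #   stub = DecentralizedAverager.get_stub(p2p, leader_peer_id, namespace="my_experiment")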
@@ -87,7 +94,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
                 lfg_status += f" leading {len(self.current_followers)} followers,"
                 lfg_status += f" leading {len(self.current_followers)} followers,"
         schema_hash_repr = f"{self.schema_hash[0]}...{self.schema_hash[-8:]}"
         schema_hash_repr = f"{self.schema_hash[0]}...{self.schema_hash[-8:]}"
         return (
         return (
-            f"{self.__class__.__name__}(endpoint={self.endpoint}, schema={schema_hash_repr}, {lfg_status}"
+            f"{self.__class__.__name__}(peer_id={self.peer_id}, schema={schema_hash_repr}, {lfg_status}"
             f" current key = {self.group_key_manager.current_key}, client_mode={self.client_mode})"
             f" current key = {self.group_key_manager.current_key}, client_mode={self.client_mode})"
         )
         )
 
 
@@ -160,7 +167,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
                         self.assembled_group.set_exception(e)
                         self.assembled_group.set_exception(e)
                     raise e
                     raise e
 
 
-    async def request_join_group(self, leader: Endpoint, expiration_time: DHTExpiration) -> Optional[GroupInfo]:
+    async def request_join_group(self, leader: PeerID, expiration_time: DHTExpiration) -> Optional[GroupInfo]:
         """
         :param leader: request this peer to be your leader for allreduce
         :param expiration_time: inform leader that we intend to begin averaging before this expiration_time
@@ -169,23 +176,24 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
           The originally specified leader can disband group and redirect us to a different leader
         """
         assert self.is_looking_for_group and self.current_leader is None
-        call: Optional[grpc.aio.UnaryStreamCall] = None
+        stream: AsyncIterator[averaging_pb2.MessageFromLeader] = None
         try:
         try:
             async with self.lock_request_join_group:
             async with self.lock_request_join_group:
-                leader_stub = ChannelCache.get_stub(leader, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
-                call = leader_stub.rpc_join_group(
+                leader_stub = self._servicer_type.get_stub(self._p2p, leader, namespace=self._prefix)
+
+                stream = leader_stub.rpc_join_group(
                     averaging_pb2.JoinRequest(
                     averaging_pb2.JoinRequest(
-                        endpoint=self.endpoint,
                         schema_hash=self.schema_hash,
                         schema_hash=self.schema_hash,
                         expiration=expiration_time,
                         expiration=expiration_time,
                         client_mode=self.client_mode,
                         client_mode=self.client_mode,
                         gather=self.data_for_gather,
                         gather=self.data_for_gather,
+                        group_key=self.group_key_manager.current_key,
                     )
                     )
-                )
-                message = await asyncio.wait_for(call.read(), timeout=self.request_timeout)
+                ).__aiter__()
+                message = await asyncio.wait_for(anext(stream), timeout=self.request_timeout)
 
 
                 if message.code == averaging_pb2.ACCEPTED:
                 if message.code == averaging_pb2.ACCEPTED:
-                    logger.debug(f"{self.endpoint} - joining the group of {leader}; waiting for peers")
+                    logger.debug(f"{self.peer_id} - joining the group of {leader}; waiting for peers")
                     self.current_leader = leader
                     self.current_leader = leader
                     self.was_accepted_to_group.set()
                     self.was_accepted_to_group.set()
                     if len(self.current_followers) > 0:
                     if len(self.current_followers) > 0:
@@ -193,56 +201,55 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
 
 
             if message.code != averaging_pb2.ACCEPTED:
             if message.code != averaging_pb2.ACCEPTED:
                 code = averaging_pb2.MessageCode.Name(message.code)
                 code = averaging_pb2.MessageCode.Name(message.code)
-                logger.debug(f"{self.endpoint} - requested {leader} to be my leader, but got rejected with {code}")
+                logger.debug(f"{self.peer_id} - requested {leader} to be my leader, but got rejected with {code}")
                 return None
                 return None
 
 
             async with self.potential_leaders.pause_search():
             async with self.potential_leaders.pause_search():
                 time_to_expiration = max(expiration_time - get_dht_time(), 0.0)
                 time_to_expiration = max(expiration_time - get_dht_time(), 0.0)
-                message = await asyncio.wait_for(call.read(), time_to_expiration + self.request_timeout)
+                message = await asyncio.wait_for(anext(stream), time_to_expiration + self.request_timeout)
 
 
                 if message.code == averaging_pb2.BEGIN_ALLREDUCE:
                 if message.code == averaging_pb2.BEGIN_ALLREDUCE:
                     async with self.lock_request_join_group:
                     async with self.lock_request_join_group:
                         return await self.follower_assemble_group(leader, message)
                         return await self.follower_assemble_group(leader, message)
 
 
             if message.code in (averaging_pb2.GROUP_DISBANDED, averaging_pb2.CANCELLED):
             if message.code in (averaging_pb2.GROUP_DISBANDED, averaging_pb2.CANCELLED):
-                if message.suggested_leader and message.suggested_leader != self.endpoint:
-                    logger.debug(f"{self} - leader disbanded group and redirected us to {message.suggested_leader}")
-                    self.current_leader = None
-                    call.cancel()
-                    return await self.request_join_group(message.suggested_leader, expiration_time)
-                else:
-                    logger.debug(f"{self} - leader disbanded group")
-                    return None
+                if message.suggested_leader:
+                    suggested_leader = PeerID(message.suggested_leader)
+                    if suggested_leader != self.peer_id:
+                        logger.debug(f"{self} - leader disbanded group and redirected us to {suggested_leader}")
+                        self.current_leader = None
+                        await stream.aclose()
+                        return await self.request_join_group(suggested_leader, expiration_time)
+                logger.debug(f"{self} - leader disbanded group")
+                return None
 
 
             logger.debug(f"{self} - unexpected message from leader: {averaging_pb2.MessageCode.Name(message.code)}")
             logger.debug(f"{self} - unexpected message from leader: {averaging_pb2.MessageCode.Name(message.code)}")
             return None
             return None
         except asyncio.TimeoutError:
         except asyncio.TimeoutError:
             logger.debug(f"{self} - potential leader {leader} did not respond within {self.request_timeout}")
             logger.debug(f"{self} - potential leader {leader} did not respond within {self.request_timeout}")
-            if call is not None:
-                call.cancel()
             return None
             return None
-        except (grpc.RpcError, grpc.aio.AioRpcError, grpc._cython.cygrpc.InternalError, StopAsyncIteration) as e:
+        except (P2PHandlerError, StopAsyncIteration) as e:
             logger.error(f"{self} - failed to request potential leader {leader}: {e}")
             logger.error(f"{self} - failed to request potential leader {leader}: {e}")
             return None
             return None
 
 
         finally:
         finally:
             self.was_accepted_to_group.clear()
             self.was_accepted_to_group.clear()
             self.current_leader = None
             self.current_leader = None
-            if call is not None:
-                await call.code()
+            if stream is not None:
+                await stream.aclose()
 
 
     async def rpc_join_group(
     async def rpc_join_group(
-        self, request: averaging_pb2.JoinRequest, context: grpc.ServicerContext
+        self, request: averaging_pb2.JoinRequest, context: P2PContext
     ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
         """accept or reject a join request from another averager; if accepted, run him through allreduce steps"""
         try:
         try:
             async with self.lock_request_join_group:
             async with self.lock_request_join_group:
-                reason_to_reject = self._check_reasons_to_reject(request)
+                reason_to_reject = self._check_reasons_to_reject(request, context)
                 if reason_to_reject is not None:
                 if reason_to_reject is not None:
                     yield reason_to_reject
                     yield reason_to_reject
                     return
                     return
 
 
-                self.current_followers[request.endpoint] = request
+                self.current_followers[context.remote_id] = request
                 yield averaging_pb2.MessageFromLeader(code=averaging_pb2.ACCEPTED)
                 yield averaging_pb2.MessageFromLeader(code=averaging_pb2.ACCEPTED)
 
 
                 if len(self.current_followers) + 1 >= self.target_group_size and not self.assembled_group.done():
                 if len(self.current_followers) + 1 >= self.target_group_size and not self.assembled_group.done():
@@ -270,12 +277,12 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
                 self.was_accepted_to_group.is_set()
                 self.was_accepted_to_group.is_set()
                 or not self.assembled_group.done()
                 or not self.assembled_group.done()
                 or self.assembled_group.cancelled()
                 or self.assembled_group.cancelled()
-                or request.endpoint not in self.assembled_group.result()
+                or context.remote_id not in self.assembled_group.result()
             ):
             ):
                 if self.current_leader is not None:
                 if self.current_leader is not None:
                     # outcome 3: found by a leader with higher priority, send our followers to him
                     yield averaging_pb2.MessageFromLeader(
                     yield averaging_pb2.MessageFromLeader(
-                        code=averaging_pb2.GROUP_DISBANDED, suggested_leader=self.current_leader
+                        code=averaging_pb2.GROUP_DISBANDED, suggested_leader=self.current_leader.to_bytes()
                     )
                     )
                     return
                     return
                 else:
                 else:
@@ -286,7 +293,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
             yield averaging_pb2.MessageFromLeader(
             yield averaging_pb2.MessageFromLeader(
                 code=averaging_pb2.BEGIN_ALLREDUCE,
                 code=averaging_pb2.BEGIN_ALLREDUCE,
                 group_id=group_info.group_id,
                 group_id=group_info.group_id,
-                ordered_group_endpoints=group_info.endpoints,
+                ordered_peer_ids=[item.to_bytes() for item in group_info.peer_ids],
                 gathered=group_info.gathered,
                 gathered=group_info.gathered,
             )
             )
         except (concurrent.futures.CancelledError, asyncio.CancelledError):
         except (concurrent.futures.CancelledError, asyncio.CancelledError):
@@ -296,11 +303,11 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
             yield averaging_pb2.MessageFromLeader(code=averaging_pb2.INTERNAL_ERROR)
             yield averaging_pb2.MessageFromLeader(code=averaging_pb2.INTERNAL_ERROR)
 
 
         finally:  # note: this code is guaranteed to run even if the coroutine is destroyed prematurely
-            self.current_followers.pop(request.endpoint, None)
+            self.current_followers.pop(context.remote_id, None)
             self.follower_was_discarded.set()
             self.follower_was_discarded.set()
 
 
     def _check_reasons_to_reject(
     def _check_reasons_to_reject(
-        self, request: averaging_pb2.JoinRequest
+        self, request: averaging_pb2.JoinRequest, context: P2PContext
     ) -> Optional[averaging_pb2.MessageFromLeader]:
         """:returns: if accepted, return None, otherwise return a reason for rejection"""
         if not self.is_looking_for_group or self.assembled_group.done():
         if not self.is_looking_for_group or self.assembled_group.done():
@@ -312,24 +319,25 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
             or len(request.schema_hash) == 0
             or len(request.schema_hash) == 0
             or not isinstance(request.expiration, DHTExpiration)
             or not isinstance(request.expiration, DHTExpiration)
             or not isfinite(request.expiration)
             or not isfinite(request.expiration)
-            or not isinstance(request.endpoint, Endpoint)
-            or len(request.endpoint) == 0
             or self.client_mode
             or self.client_mode
+            or not isinstance(request.group_key, GroupKey)
         ):
         ):
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.PROTOCOL_VIOLATION)
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.PROTOCOL_VIOLATION)
 
 
         elif request.schema_hash != self.schema_hash:
         elif request.schema_hash != self.schema_hash:
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.BAD_SCHEMA_HASH)
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.BAD_SCHEMA_HASH)
+        elif request.group_key != self.group_key_manager.current_key:
+            return averaging_pb2.MessageFromLeader(code=averaging_pb2.BAD_GROUP_KEY)
         elif self.potential_leaders.declared_group_key is None:
         elif self.potential_leaders.declared_group_key is None:
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.NOT_DECLARED)
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.NOT_DECLARED)
         elif self.potential_leaders.declared_expiration_time > (request.expiration or float("inf")):
         elif self.potential_leaders.declared_expiration_time > (request.expiration or float("inf")):
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.BAD_EXPIRATION_TIME)
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.BAD_EXPIRATION_TIME)
         elif self.current_leader is not None:
         elif self.current_leader is not None:
             return averaging_pb2.MessageFromLeader(
             return averaging_pb2.MessageFromLeader(
-                code=averaging_pb2.NOT_A_LEADER, suggested_leader=self.current_leader
-            )  # note: this suggested leader is currently ignored
-        elif request.endpoint == self.endpoint or request.endpoint in self.current_followers:
-            return averaging_pb2.MessageFromLeader(code=averaging_pb2.DUPLICATE_ENDPOINT)
+                code=averaging_pb2.NOT_A_LEADER, suggested_leader=self.current_leader.to_bytes()
+            )
+        elif context.remote_id == self.peer_id or context.remote_id in self.current_followers:
+            return averaging_pb2.MessageFromLeader(code=averaging_pb2.DUPLICATE_PEER_ID)
         elif len(self.current_followers) + 1 >= self.target_group_size:
         elif len(self.current_followers) + 1 >= self.target_group_size:
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_IS_FULL)
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_IS_FULL)
         else:
         else:
@@ -339,34 +347,35 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
         """Form up all current followers into a group and gather metadata"""
         assert self.lock_looking_for_group.locked() and self.lock_request_join_group.locked() and not self.client_mode
         assert self.lock_looking_for_group.locked() and self.lock_request_join_group.locked() and not self.client_mode
         assert not self.assembled_group.done()
         assert not self.assembled_group.done()
-        group_id = DHTID.generate().to_bytes()  # note: both groupd_id and the order of endpoints must be random
-        ordered_group_endpoints = list(self.current_followers)
-        ordered_group_endpoints.append(self.endpoint)
-        random.shuffle(ordered_group_endpoints)
+        group_id = DHTID.generate().to_bytes()  # note: both groupd_id and the order of peer_ids must be random
+        ordered_peer_ids = list(self.current_followers)
+        ordered_peer_ids.append(self.peer_id)
+        random.shuffle(ordered_peer_ids)
 
 
         gathered = tuple(
         gathered = tuple(
-            self.data_for_gather if endpoint == self.endpoint else self.current_followers[endpoint].gather
-            for endpoint in ordered_group_endpoints
+            self.data_for_gather if peer_id == self.peer_id else self.current_followers[peer_id].gather
+            for peer_id in ordered_peer_ids
         )

-        logger.debug(f"{self.endpoint} - assembled group of {len(ordered_group_endpoints)} peers.")
-        group_info = GroupInfo(group_id, tuple(ordered_group_endpoints), gathered)
+        logger.debug(f"{self.peer_id} - assembled group of {len(ordered_peer_ids)} peers.")
+        group_info = GroupInfo(group_id, tuple(ordered_peer_ids), gathered)
         await self.group_key_manager.update_key_on_group_assembled(group_info, is_leader=True)
         self.assembled_group.set_result(group_info)
         return group_info

-    async def follower_assemble_group(self, leader: Endpoint, msg: averaging_pb2.MessageFromLeader) -> GroupInfo:
+    async def follower_assemble_group(self, leader: PeerID, msg: averaging_pb2.MessageFromLeader) -> GroupInfo:
         """Form a group using peers and metadata provided by our leader"""
         assert self.lock_looking_for_group.locked() and self.lock_request_join_group.locked()
         assert not self.assembled_group.done()
         assert self.current_leader == leader, f"averager does not follow {leader} (actual: {self.current_leader})"

-        group_id, ordered_group_endpoints = msg.group_id, msg.ordered_group_endpoints
-        assert self.endpoint in ordered_group_endpoints, "Leader sent us group_endpoints that does not contain us!"
-        assert len(ordered_group_endpoints) == len(msg.gathered)
+        group_id = msg.group_id
+        ordered_peer_ids = [PeerID(item) for item in msg.ordered_peer_ids]
+        assert self.peer_id in ordered_peer_ids, "Leader sent us group_peer_ids that does not contain us!"
+        assert len(ordered_peer_ids) == len(msg.gathered)
 
 
-        logger.debug(f"{self.endpoint} - follower assembled group with leader {leader}.")
-        group_info = GroupInfo(group_id, tuple(ordered_group_endpoints), tuple(msg.gathered))
+        logger.debug(f"{self.peer_id} - follower assembled group with leader {leader}.")
+        group_info = GroupInfo(group_id, tuple(ordered_peer_ids), tuple(msg.gathered))
         await self.group_key_manager.update_key_on_group_assembled(group_info)
         self.assembled_group.set_result(group_info)
         return group_info
@@ -380,13 +389,13 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
 class PotentialLeaders:
     """A utility class that searches for averagers that could become our leaders"""

-    def __init__(self, endpoint: Endpoint, averaging_expiration: DHTExpiration, target_group_size: Optional[int]):
-        self.endpoint, self.averaging_expiration = endpoint, averaging_expiration
+    def __init__(self, peer_id: PeerID, averaging_expiration: DHTExpiration, target_group_size: Optional[int]):
+        self.peer_id, self.averaging_expiration = peer_id, averaging_expiration
         self.target_group_size = target_group_size
         self.running, self.update_triggered, self.update_finished = asyncio.Event(), asyncio.Event(), asyncio.Event()
         self.declared_expiration, self.lock_search, self.lock_declare = asyncio.Event(), asyncio.Lock(), asyncio.Lock()
-        self.leader_queue = TimedStorage[Endpoint, DHTExpiration]()
-        self.past_attempts: Set[Tuple[Endpoint, DHTExpiration]] = set()
+        self.leader_queue = TimedStorage[PeerID, DHTExpiration]()
+        self.past_attempts: Set[Tuple[PeerID, DHTExpiration]] = set()
         self.declared_expiration_time = float("inf")
         self.declared_group_key: Optional[GroupKey] = None
         self.max_assured_time = float("-inf")
@@ -433,7 +442,7 @@ class PotentialLeaders:
             else:
                 self.running.clear()

-    async def pop_next_leader(self) -> Endpoint:
+    async def pop_next_leader(self) -> PeerID:
         """Remove and return the next most suitable leader or throw an exception if the timeout is reached"""
         assert self.running.is_set(), "Not running search at the moment"
         while True:
@@ -442,9 +451,9 @@ class PotentialLeaders:
             if maybe_next_leader is None or self.max_assured_time <= entry.expiration_time <= self.search_end_time:
                 self.update_triggered.set()

-            if maybe_next_leader is None or (entry.expiration_time, maybe_next_leader) > (
+            if maybe_next_leader is None or (entry.expiration_time, maybe_next_leader.to_bytes()) > (
                 self.declared_expiration_time,
-                self.endpoint,
+                self.peer_id.to_bytes(),
             ):
                 await asyncio.wait(
                     {self.update_finished.wait(), self.declared_expiration.wait()}, return_when=asyncio.FIRST_COMPLETED
@@ -479,7 +488,7 @@ class PotentialLeaders:
 
 
                 self.leader_queue.clear()
                 for peer, peer_expiration_time in new_peers:
-                    if peer == self.endpoint or (peer, peer_expiration_time) in self.past_attempts:
+                    if peer == self.peer_id or (peer, peer_expiration_time) in self.past_attempts:
                         continue
                     self.leader_queue.store(peer, peer_expiration_time, peer_expiration_time)
                     self.max_assured_time = max(self.max_assured_time, peer_expiration_time - DISCREPANCY)
@@ -495,7 +504,7 @@ class PotentialLeaders:
         except (concurrent.futures.CancelledError, asyncio.CancelledError):
             return  # note: this is a compatibility layer for python3.7
         except Exception as e:
-            logger.error(f"{self.endpoint} - caught {type(e)}: {e}")
+            logger.error(f"{self.peer_id} - caught {type(e)}: {e}")
             raise

     async def _declare_averager_periodically(self, key_manager: GroupKeyManager):
@@ -508,21 +517,21 @@ class PotentialLeaders:
                     self.declared_group_key = group_key = key_manager.current_key
                     self.declared_expiration_time = new_expiration_time
                     self.declared_expiration.set()
-                    await key_manager.declare_averager(group_key, self.endpoint, expiration_time=new_expiration_time)
+                    await key_manager.declare_averager(group_key, self.peer_id, expiration_time=new_expiration_time)
                     await asyncio.sleep(self.declared_expiration_time - get_dht_time())
                     if self.running.is_set() and len(self.leader_queue) == 0:
                         await key_manager.update_key_on_not_enough_peers()
             except (concurrent.futures.CancelledError, asyncio.CancelledError):
                 pass  # note: this is a compatibility layer for python3.7
             except Exception as e:  # note: we catch exceptions here because otherwise they are never printed
-                logger.error(f"{self.endpoint} - caught {type(e)}: {e}")
+                logger.error(f"{self.peer_id} - caught {type(e)}: {e}")
             finally:
                 if self.declared_group_key is not None:
                     prev_declared_key, prev_expiration_time = self.declared_group_key, self.declared_expiration_time
                     self.declared_group_key, self.declared_expiration_time = None, float("inf")
-                    self.leader_queue, self.max_assured_time = TimedStorage[Endpoint, DHTExpiration](), float("-inf")
+                    self.leader_queue, self.max_assured_time = TimedStorage[PeerID, DHTExpiration](), float("-inf")
                     await key_manager.declare_averager(
-                        prev_declared_key, self.endpoint, prev_expiration_time, looking_for_group=False
+                        prev_declared_key, self.peer_id, prev_expiration_time, looking_for_group=False
                     )
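
# Illustrative sketch (not part of this commit): pop_next_leader above now compares
# (expiration_time, raw peer-id bytes) tuples, because PeerID objects themselves are not
# orderable. The concrete values below are made up for the example.
declared = (1234.5, b"QmOurPeer")      # (self.declared_expiration_time, self.peer_id.to_bytes())
candidate = (1230.0, b"QmOtherPeer")   # (entry.expiration_time, maybe_next_leader.to_bytes())

if candidate > declared:
    pass  # keep waiting: the candidate would expire later than our own declaration
else:
    pass  # the candidate sorts first, so we try to join it as a follower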
 
 
 
 

+ 6 - 7
hivemind/averaging/partition.py

@@ -2,19 +2,18 @@
 Auxiliary data structures for AllReduceRunner
 """
 import asyncio
-from typing import Sequence, AsyncIterable, Tuple, Optional, TypeVar, Union, AsyncIterator
 from collections import deque
+from typing import AsyncIterable, AsyncIterator, Optional, Sequence, Tuple, TypeVar, Union
 
 
-import torch
 import numpy as np
+import torch
 
 
 from hivemind.proto.runtime_pb2 import CompressionType, Tensor
-from hivemind.utils.compression import serialize_torch_tensor, get_nbytes_per_value
 from hivemind.utils.asyncio import amap_in_executor
-
+from hivemind.utils.compression import get_nbytes_per_value, serialize_torch_tensor
 
 
 T = TypeVar("T")
-DEFAULT_PART_SIZE_BYTES = 2 ** 20
+DEFAULT_PART_SIZE_BYTES = 2 ** 19
 
 
 
 
 class TensorPartContainer:
@@ -32,8 +31,8 @@ class TensorPartContainer:
         self,
         tensors: Sequence[torch.Tensor],
         peer_fractions: Sequence[float],
-        compression_type: Union[type(CompressionType), Sequence[type(CompressionType)]] = CompressionType.NONE,
-        part_size_bytes: int = 2 ** 20,
+        compression_type: Union["CompressionType", Sequence["CompressionType"]] = CompressionType.NONE,
+        part_size_bytes: int = DEFAULT_PART_SIZE_BYTES,
         prefetch: int = 1,
     ):
         if not isinstance(compression_type, Sequence):
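
# Back-of-the-envelope sketch (not from the diff): effect of halving the default part size.
# A hypothetical float32 tensor with 10M elements (~40 MB) is now split into roughly 77 parts
# of 2 ** 19 bytes (512 KiB) each, versus ~39 parts at the previous 2 ** 20-byte default.
import math

tensor_bytes = 10_000_000 * 4                     # made-up tensor size in bytes (float32)
part_size = 2 ** 19                               # DEFAULT_PART_SIZE_BYTES after this change
num_parts = math.ceil(tensor_bytes / part_size)   # == 77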

+ 3 - 3
hivemind/averaging/training.py

@@ -2,13 +2,13 @@
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import nullcontext
 from itertools import chain
-from threading import Lock, Event
-from typing import Sequence, Dict, Iterator, Optional
+from threading import Event, Lock
+from typing import Dict, Iterator, Optional, Sequence
 
 
 import torch

 from hivemind.averaging import DecentralizedAverager
-from hivemind.utils import nested_flatten, nested_pack, get_logger
+from hivemind.utils import get_logger, nested_flatten, nested_pack
 
 
 logger = get_logger(__name__)
 
 

+ 55 - 9
hivemind/dht/__init__.py

@@ -23,9 +23,10 @@ from typing import Awaitable, Callable, Iterable, List, Optional, Sequence, Type
 
 
 from multiaddr import Multiaddr

-from hivemind.dht.node import DHTID, DHTNode
-from hivemind.dht.routing import DHTKey, DHTValue, Subkey
+from hivemind.dht.node import DEFAULT_NUM_WORKERS, DHTNode
+from hivemind.dht.routing import DHTID, DHTKey, DHTValue, Subkey
 from hivemind.dht.validation import CompositeValidator, RecordValidatorBase
+from hivemind.p2p import P2P, PeerID
 from hivemind.utils import DHTExpiration, MPFuture, ValueWithExpiration, await_cancelled, get_logger, switch_to_uvloop

 logger = get_logger(__name__)
@@ -42,13 +43,14 @@ class DHT(mp.Process):
     :param initial_peers: multiaddrs of one or more active DHT peers (if you want to join an existing DHT)
     :param start: if True, automatically starts the background process on creation. Otherwise await manual start
     :param daemon: if True, the background process is marked as daemon and automatically terminated after main process
-    :param max_workers: declare_experts and get_experts will use up to this many parallel workers
+    :param num_workers: declare_experts and get_experts will use up to this many parallel workers
       (but no more than one per key)
     :param expiration: experts declared from this node expire after this many seconds (default = 5 minutes)
     :param record_validators: instances of RecordValidatorBase used for signing and validating stored records.
       The validators will be combined using the CompositeValidator class. It merges them when possible
       (according to their `.merge_with()` policies) and orders them according to the `.priority` properties.
     :param shutdown_timeout: when calling .shutdown, wait for up to this many seconds before terminating
+    :param await_ready: if True, the constructor waits until the DHT process is ready to process incoming requests
     :param kwargs: any other params will be forwarded to DHTNode and hivemind.p2p.P2P upon creation
     """

@@ -60,9 +62,10 @@ class DHT(mp.Process):
         *,
         start: bool,
         daemon: bool = True,
-        max_workers: Optional[int] = None,
+        num_workers: int = DEFAULT_NUM_WORKERS,
         record_validators: Iterable[RecordValidatorBase] = (),
         shutdown_timeout: float = 3,
+        await_ready: bool = True,
         **kwargs,
     ):
         self._parent_pid = os.getpid()
@@ -78,15 +81,21 @@ class DHT(mp.Process):
             raise TypeError("initial_peers should be of type Optional[Sequence[Union[Multiaddr, str]]]")
         self.initial_peers = initial_peers
         self.kwargs = kwargs
-        self.max_workers = max_workers
+        self.num_workers = num_workers
 
 
         self._record_validator = CompositeValidator(record_validators)
         self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)
         self.shutdown_timeout = shutdown_timeout
         self.ready = mp.Event()
         self.daemon = daemon
+
+        # These values will be fetched from the child process when requested
+        self._peer_id = None
+        self._client_mode = None
+        self._p2p_replica = None
+
         if start:
-            self.run_in_background(await_ready=True)
+            self.run_in_background(await_ready=await_ready)
 
 
     def run(self) -> None:
         """Serve DHT forever. This function will not return until DHT node is shut down"""
@@ -97,7 +106,7 @@ class DHT(mp.Process):
             async def _run():
                 self._node = await DHTNode.create(
                     initial_peers=self.initial_peers,
-                    num_workers=self.max_workers or 1,
+                    num_workers=self.num_workers,
                     record_validator=self._record_validator,
                     **self.kwargs,
                 )
@@ -251,9 +260,30 @@ class DHT(mp.Process):
 
 
         self.run_coroutine(partial(DHT._add_validators, record_validators=record_validators))

-    async def _add_validators(self, node: DHTNode, record_validators: Iterable[RecordValidatorBase]) -> None:
+    @staticmethod
+    async def _add_validators(_dht: DHT, node: DHTNode, record_validators: Iterable[RecordValidatorBase]) -> None:
         node.protocol.record_validator.extend(record_validators)

+    @property
+    def peer_id(self) -> PeerID:
+        if self._peer_id is None:
+            self._peer_id = self.run_coroutine(DHT._get_peer_id)
+        return self._peer_id
+
+    @staticmethod
+    async def _get_peer_id(_dht: DHT, node: DHTNode) -> PeerID:
+        return node.peer_id
+
+    @property
+    def client_mode(self) -> bool:
+        if self._client_mode is None:
+            self._client_mode = self.run_coroutine(DHT._get_client_mode)
+        return self._client_mode
+
+    @staticmethod
+    async def _get_client_mode(_dht: DHT, node: DHTNode) -> bool:
+        return node.protocol.client_mode
+
     def get_visible_maddrs(self, latest: bool = False) -> List[Multiaddr]:
         """
         Get multiaddrs of the current DHT node that should be accessible by other peers.
@@ -263,9 +293,25 @@ class DHT(mp.Process):
 
 
         return self.run_coroutine(partial(DHT._get_visible_maddrs, latest=latest))

-    async def _get_visible_maddrs(self, node: DHTNode, latest: bool = False) -> List[Multiaddr]:
+    @staticmethod
+    async def _get_visible_maddrs(_dht: DHT, node: DHTNode, latest: bool = False) -> List[Multiaddr]:
         return await node.get_visible_maddrs(latest=latest)

+    async def replicate_p2p(self) -> P2P:
+        """
+        Get a replica of a P2P instance used in the DHT process internally.
+        The replica uses the same P2P daemon as the DHT and only works while DHT is alive.
+        """
+
+        if self._p2p_replica is None:
+            daemon_listen_maddr = self.run_coroutine(DHT._get_p2p_daemon_listen_maddr)
+            self._p2p_replica = await P2P.replicate(daemon_listen_maddr)
+        return self._p2p_replica
+
+    @staticmethod
+    async def _get_p2p_daemon_listen_maddr(_dht: DHT, node: DHTNode) -> Multiaddr:
+        return node.p2p.daemon_listen_maddr
+
     def __del__(self):
         if self._parent_pid == os.getpid() and self.is_alive():
             self.shutdown()
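
# Minimal usage sketch (illustration only, not part of the diff) of the DHT API after this change:
# `num_workers` replaces `max_workers`, `await_ready` controls whether the constructor blocks,
# and `peer_id` / `client_mode` / `replicate_p2p()` expose details of the node in the child process.
import asyncio
import hivemind

async def main():
    dht = hivemind.DHT(start=True, num_workers=4, await_ready=True)
    print(dht.peer_id)                 # PeerID of the node running inside the DHT process
    print(dht.client_mode)             # whether the underlying DHTNode refuses incoming requests
    p2p = await dht.replicate_p2p()    # P2P replica attached to the same p2pd daemon
    dht.shutdown()

asyncio.run(main())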

+ 0 - 1
hivemind/dht/crypto.py

@@ -6,7 +6,6 @@ from hivemind.dht.validation import DHTRecord, RecordValidatorBase
 from hivemind.utils import MSGPackSerializer, get_logger
 from hivemind.utils.crypto import RSAPrivateKey, RSAPublicKey

-
 logger = get_logger(__name__)



+ 10 - 6
hivemind/dht/node.py

@@ -2,8 +2,9 @@ from __future__ import annotations
 
 
 import asyncio
 import dataclasses
+import os
 import random
-from collections import defaultdict, Counter
+from collections import Counter, defaultdict
 from dataclasses import dataclass, field
 from functools import partial
 from typing import (
@@ -27,17 +28,20 @@ from sortedcontainers import SortedSet
 
 
 from hivemind.dht.crypto import DHTRecord, RecordValidatorBase
 from hivemind.dht.protocol import DHTProtocol
-from hivemind.dht.routing import DHTID, DHTKey, get_dht_time, DHTValue, BinaryDHTValue, Subkey
+from hivemind.dht.routing import DHTID, BinaryDHTValue, DHTKey, DHTValue, Subkey, get_dht_time
 from hivemind.dht.storage import DictionaryDHTValue
 from hivemind.dht.traverse import traverse_dht
 from hivemind.p2p import P2P, PeerID
-from hivemind.utils import MSGPackSerializer, get_logger, SerializerBase
+from hivemind.utils import MSGPackSerializer, SerializerBase, get_logger
 from hivemind.utils.auth import AuthorizerBase
 from hivemind.utils.timed_storage import DHTExpiration, TimedStorage, ValueWithExpiration

 logger = get_logger(__name__)


+DEFAULT_NUM_WORKERS = int(os.getenv("HIVEMIND_DHT_NUM_WORKERS", 4))
+
+
 class DHTNode:
     """
     Asyncio-based class that represents one DHT participant. Created via await DHTNode.create(...)
@@ -110,7 +114,7 @@ class DHTNode:
         cache_refresh_before_expiry: float = 5,
         cache_on_store: bool = True,
         reuse_get_requests: bool = True,
-        num_workers: int = 1,
+        num_workers: int = DEFAULT_NUM_WORKERS,
         chunk_size: int = 16,
         blacklist_time: float = 5.0,
         backoff_rate: float = 2.0,
@@ -154,7 +158,7 @@ class DHTNode:
         :param backoff_rate: blacklist time will be multiplied by :backoff_rate: for each successive non-response
         :param validate: if True, use initial peers to validate that this node is accessible and synchronized
         :param strict: if True, any error encountered in validation will interrupt the creation of DHTNode
-        :param client_mode: if False (default), this node will accept incoming requests as a full DHT "citzen"
+        :param client_mode: if False (default), this node will accept incoming requests as a full DHT "citizen"
           if True, this node will refuse any incoming requests, effectively being only a client
         :param record_validator: instance of RecordValidatorBase used for signing and validating stored records
         :param authorizer: instance of AuthorizerBase used for signing and validating requests and response
@@ -207,7 +211,7 @@ class DHTNode:
             record_validator,
             authorizer,
         )
-        self.peer_id = p2p.id
+        self.peer_id = p2p.peer_id
 
 
         if initial_peers:
             initial_peers = {PeerID.from_base58(Multiaddr(item)["p2p"]) for item in initial_peers}
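
# Sketch (not part of the diff): DEFAULT_NUM_WORKERS is now read from the environment,
# so DHT parallelism can be tuned without code changes. The value must be set before
# hivemind.dht.node is imported, since the default is evaluated at import time.
import os

os.environ["HIVEMIND_DHT_NUM_WORKERS"] = "8"

from hivemind.dht.node import DEFAULT_NUM_WORKERS  # int(os.getenv("HIVEMIND_DHT_NUM_WORKERS", 4))
assert DEFAULT_NUM_WORKERS == 8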

+ 8 - 8
hivemind/dht/protocol.py

@@ -2,20 +2,20 @@
 from __future__ import annotations

 import asyncio
-from typing import Optional, List, Tuple, Dict, Sequence, Union, Collection
+from typing import Collection, Dict, List, Optional, Sequence, Tuple, Union
 
 
 from hivemind.dht.crypto import DHTRecord, RecordValidatorBase
-from hivemind.dht.routing import RoutingTable, DHTID, BinaryDHTValue, Subkey
+from hivemind.dht.routing import DHTID, BinaryDHTValue, RoutingTable, Subkey
 from hivemind.dht.storage import DHTLocalStorage, DictionaryDHTValue
 from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase
 from hivemind.proto import dht_pb2
-from hivemind.utils import get_logger, MSGPackSerializer
-from hivemind.utils.auth import AuthRole, AuthRPCWrapper, AuthorizerBase
+from hivemind.utils import MSGPackSerializer, get_logger
+from hivemind.utils.auth import AuthorizerBase, AuthRole, AuthRPCWrapper
 from hivemind.utils.timed_storage import (
-    DHTExpiration,
-    get_dht_time,
     MAX_DHT_TIME_DISCREPANCY_SECONDS,
+    DHTExpiration,
     ValueWithExpiration,
+    get_dht_time,
 )

 logger = get_logger(__name__)
@@ -296,7 +296,7 @@ class DHTProtocol(ServicerBase):
                 nearest = dict(
                     zip(
                         map(DHTID.from_bytes, result.nearest_node_ids),
-                        map(PeerID.from_base58, result.nearest_peer_ids),
+                        map(PeerID, result.nearest_peer_ids),
                     )
                 )

@@ -359,7 +359,7 @@ class DHTProtocol(ServicerBase):
                 key_id, k=self.bucket_size, exclude=DHTID.from_bytes(request.peer.node_id)
             ):
                 item.nearest_node_ids.append(node_id.to_bytes())
-                item.nearest_peer_ids.append(peer_id.to_base58())
+                item.nearest_peer_ids.append(peer_id.to_bytes())
             response.results.append(item)
         return response
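
# Illustration (not from the commit): nearest_peer_ids are now serialized as raw bytes,
# so the requesting side reconstructs them with PeerID(raw) instead of PeerID.from_base58(str).
from hivemind.p2p import PeerID

original = PeerID(b"\x12\x20" + bytes(32))   # made-up peer id, purely for the example
wire_value = original.to_bytes()             # what rpc_find now appends to nearest_peer_ids
restored = PeerID(wire_value)                # what the caller of find() now does with the reply
assert restored == original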
 
 

+ 2 - 1
hivemind/dht/routing.py

@@ -7,7 +7,8 @@ import os
 import random
 from collections.abc import Iterable
 from itertools import chain
-from typing import Tuple, Optional, List, Dict, Set, Union, Any, Sequence
+from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
+
 from hivemind.p2p import PeerID
 from hivemind.utils import MSGPackSerializer, get_dht_time


+ 1 - 1
hivemind/dht/storage.py

@@ -4,7 +4,7 @@ from typing import Optional, Union
 
 
 from hivemind.dht.routing import DHTID, BinaryDHTValue, Subkey
 from hivemind.utils.serializer import MSGPackSerializer
-from hivemind.utils.timed_storage import KeyType, ValueType, TimedStorage, DHTExpiration
+from hivemind.utils.timed_storage import DHTExpiration, KeyType, TimedStorage, ValueType
 
 
 
 
 @MSGPackSerializer.ext_serializable(0x50)

+ 1 - 1
hivemind/dht/traverse.py

@@ -2,7 +2,7 @@
 import asyncio
 import heapq
 from collections import Counter
-from typing import Dict, Awaitable, Callable, Any, Tuple, List, Set, Collection, Optional
+from typing import Any, Awaitable, Callable, Collection, Dict, List, Optional, Set, Tuple
 
 
 from hivemind.dht.routing import DHTID


+ 2 - 2
hivemind/hivemind_cli/run_server.py

@@ -4,11 +4,11 @@ from pathlib import Path
 import configargparse
 import torch

-from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.moe.server import Server
+from hivemind.moe.server.layers import schedule_name_to_scheduler
+from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger
-from hivemind.moe.server.layers import schedule_name_to_scheduler
 
 
 logger = get_logger(__name__)


+ 1 - 1
hivemind/moe/__init__.py

@@ -1,2 +1,2 @@
 from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
-from hivemind.moe.server import ExpertBackend, Server, register_expert_class, get_experts, declare_experts
+from hivemind.moe.server import ExpertBackend, Server, declare_experts, get_experts, register_expert_class

+ 11 - 11
hivemind/moe/client/beam_search.py

@@ -2,22 +2,22 @@ import asyncio
 import heapq
 from collections import deque
 from functools import partial
-from typing import Sequence, Optional, List, Tuple, Dict, Deque, Union, Set, Iterator
+from typing import Deque, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
 
 
-from hivemind.dht import DHT, DHTNode, DHTExpiration
+from hivemind.dht import DHT, DHTExpiration, DHTNode
 from hivemind.moe.client.expert import RemoteExpert
 from hivemind.moe.server.expert_uid import (
-    ExpertUID,
-    ExpertPrefix,
     FLAT_EXPERT,
-    UidEndpoint,
-    Score,
-    Coordinate,
     PREFIX_PATTERN,
     UID_DELIMITER,
+    Coordinate,
+    ExpertPrefix,
+    ExpertUID,
+    Score,
+    UidEndpoint,
     is_valid_prefix,
 )
-from hivemind.utils import get_logger, get_dht_time, MPFuture
+from hivemind.utils import MPFuture, get_dht_time, get_logger
 
 
 logger = get_logger(__name__)

@@ -125,7 +125,7 @@ class MoEBeamSearcher:
         cache_expiration: DHTExpiration,
         num_workers: Optional[int] = None,
     ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
-        num_workers = num_workers or dht.max_workers or beam_size
+        num_workers = num_workers or dht.num_workers or beam_size
         beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = []
         unattempted_indices: List[Coordinate] = sorted(
             range(len(scores)), key=scores.__getitem__
@@ -206,7 +206,7 @@ class MoEBeamSearcher:
         num_workers: Optional[int] = None,
     ) -> Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]]:
         grid_size = grid_size or float("inf")
-        num_workers = num_workers or min(len(prefixes), dht.max_workers or len(prefixes))
+        num_workers = num_workers or min(len(prefixes), dht.num_workers or len(prefixes))
         dht_responses = await node.get_many(keys=prefixes, num_workers=num_workers)
         successors: Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]] = {}
         for prefix, found in dht_responses.items():
@@ -270,7 +270,7 @@ class MoEBeamSearcher:
         cache_expiration: DHTExpiration,
         num_workers: Optional[int] = None,
     ) -> List[RemoteExpert]:
-        num_workers = num_workers or min(beam_size, dht.max_workers or beam_size)
+        num_workers = num_workers or min(beam_size, dht.num_workers or beam_size)
 
 
         # form initial beam from top-k active L1 prefixes, each row is (score, uid prefix, possible suffixes)
         beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = await cls._get_initial_beam(

+ 3 - 3
hivemind/moe/client/expert.py

@@ -1,13 +1,13 @@
 import pickle
-from typing import Tuple, Optional, Any, Dict
+from typing import Any, Dict, Optional, Tuple
 
 
 import torch
 import torch.nn as nn
 from torch.autograd.function import once_differentiable

 from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
-from hivemind.utils import nested_flatten, nested_pack, nested_compare, Endpoint
-from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.utils import Endpoint, nested_compare, nested_flatten, nested_pack
+from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.grpc import ChannelCache

 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert

+ 6 - 6
hivemind/moe/client/moe.py

@@ -1,8 +1,8 @@
 from __future__ import annotations

 import time
-from queue import Queue, Empty
-from typing import Tuple, List, Optional, Dict, Any
+from queue import Empty, Queue
+from typing import Any, Dict, List, Optional, Tuple
 
 
 import grpc
 import torch
@@ -11,11 +11,11 @@ from torch.autograd.function import once_differentiable
 
 
 import hivemind
 from hivemind.moe.client.beam_search import MoEBeamSearcher
-from hivemind.moe.client.expert import RemoteExpert, DUMMY, _get_expert_stub
-from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
+from hivemind.moe.client.expert import DUMMY, RemoteExpert, _get_expert_stub
 from hivemind.moe.server.expert_uid import UID_DELIMITER
-from hivemind.utils import nested_pack, nested_flatten, nested_map
-from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
+from hivemind.utils import nested_flatten, nested_map, nested_pack
+from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.logging import get_logger

 logger = get_logger(__name__)

+ 3 - 3
hivemind/moe/client/switch_moe.py

@@ -1,14 +1,14 @@
 from __future__ import annotations

-from typing import Tuple, List
+from typing import List, Tuple
 
 
 import grpc
 import torch

-from hivemind.moe.client.expert import RemoteExpert, DUMMY
+from hivemind.moe.client.expert import DUMMY, RemoteExpert
 from hivemind.moe.client.moe import RemoteMixtureOfExperts, _RemoteCallMany
 from hivemind.moe.server.expert_uid import UID_DELIMITER
-from hivemind.utils import nested_pack, nested_flatten
+from hivemind.utils import nested_flatten, nested_pack
 from hivemind.utils.logging import get_logger

 logger = get_logger(__name__)

+ 12 - 7
hivemind/moe/server/__init__.py

@@ -5,24 +5,29 @@ import multiprocessing.synchronize
 import threading
 from contextlib import contextmanager
 from functools import partial
-from typing import Dict, List, Optional, Tuple
 from pathlib import Path
+from typing import Dict, List, Optional, Tuple
 
 
 import torch
 from multiaddr import Multiaddr

 import hivemind
 from hivemind.dht import DHT
-from hivemind.moe.server.expert_uid import UID_DELIMITER, generate_uids_from_pattern
-from hivemind.moe.server.checkpoints import CheckpointSaver, load_experts, is_directory
+from hivemind.moe.server.checkpoints import CheckpointSaver, is_directory, load_experts
 from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.moe.server.dht_handler import DHTHandlerThread, declare_experts, get_experts
 from hivemind.moe.server.expert_backend import ExpertBackend
-from hivemind.moe.server.layers import name_to_block, name_to_input, register_expert_class
-from hivemind.moe.server.layers import add_custom_models_from_file, schedule_name_to_scheduler
+from hivemind.moe.server.expert_uid import UID_DELIMITER, generate_uids_from_pattern
+from hivemind.moe.server.layers import (
+    add_custom_models_from_file,
+    name_to_block,
+    name_to_input,
+    register_expert_class,
+    schedule_name_to_scheduler,
+)
 from hivemind.moe.server.runtime import Runtime
-from hivemind.utils import Endpoint, get_port, replace_port, find_open_port, get_logger, BatchTensorDescriptor
 from hivemind.proto.runtime_pb2 import CompressionType
+from hivemind.utils import BatchTensorDescriptor, Endpoint, get_free_port, get_logger, get_port, replace_port
 
 
 logger = get_logger(__name__)

@@ -63,7 +68,7 @@ class Server(threading.Thread):
         super().__init__()
         self.dht, self.experts, self.update_period = dht, expert_backends, update_period
         if get_port(listen_on) is None:
-            listen_on = replace_port(listen_on, new_port=find_open_port())
+            listen_on = replace_port(listen_on, new_port=get_free_port())
         self.listen_on, self.port = listen_on, get_port(listen_on)

         self.conn_handlers = [ConnectionHandler(listen_on, self.experts) for _ in range(num_connection_handlers)]

+ 3 - 3
hivemind/moe/server/connection_handler.py

@@ -6,11 +6,11 @@ from typing import Dict
 import grpc
 import torch

-from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
 from hivemind.moe.server.expert_backend import ExpertBackend
-from hivemind.utils import get_logger, Endpoint, nested_flatten
+from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
+from hivemind.utils import Endpoint, get_logger, nested_flatten
 from hivemind.utils.asyncio import switch_to_uvloop
-from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.grpc import GRPC_KEEPALIVE_OPTIONS

 logger = get_logger(__name__)

+ 7 - 7
hivemind/moe/server/dht_handler.py

@@ -1,16 +1,16 @@
 import threading
 from functools import partial
-from typing import Sequence, Dict, List, Tuple, Optional
+from typing import Dict, List, Optional, Sequence, Tuple
 
 
-from hivemind.dht import DHT, DHTNode, DHTExpiration, DHTValue
+from hivemind.dht import DHT, DHTExpiration, DHTNode, DHTValue
 from hivemind.moe.client.expert import RemoteExpert
 from hivemind.moe.server.expert_uid import (
-    ExpertUID,
-    ExpertPrefix,
     FLAT_EXPERT,
-    Coordinate,
     UID_DELIMITER,
     UID_PATTERN,
+    Coordinate,
+    ExpertPrefix,
+    ExpertUID,
     is_valid_uid,
     split_uid,
 )
@@ -56,7 +56,7 @@ def declare_experts(
 async def _declare_experts(
     dht: DHT, node: DHTNode, uids: List[ExpertUID], endpoint: Endpoint, expiration: DHTExpiration
 ) -> Dict[ExpertUID, bool]:
-    num_workers = len(uids) if dht.max_workers is None else min(len(uids), dht.max_workers)
+    num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
     expiration_time = get_dht_time() + expiration
     data_to_store: Dict[Tuple[ExpertPrefix, Optional[Coordinate]], DHTValue] = {}
     for uid in uids:
@@ -89,7 +89,7 @@ async def _get_experts(
 ) -> List[Optional[RemoteExpert]]:
     if expiration_time is None:
         expiration_time = get_dht_time()
-    num_workers = len(uids) if dht.max_workers is None else min(len(uids), dht.max_workers)
+    num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
     found: Dict[ExpertUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)

     experts: List[Optional[RemoteExpert]] = [None] * len(uids)

+ 3 - 3
hivemind/moe/server/expert_backend.py

@@ -1,12 +1,12 @@
-from typing import Dict, Sequence, Any, Tuple, Union, Callable
+from typing import Any, Callable, Dict, Sequence, Tuple, Union
 
 
 import torch
 from torch import nn

 from hivemind.moe.server.task_pool import TaskPool
-from hivemind.utils.tensor_descr import BatchTensorDescriptor, DUMMY_BATCH_SIZE
 from hivemind.utils.logging import get_logger
-from hivemind.utils.nested import nested_flatten, nested_pack, nested_compare, nested_map
+from hivemind.utils.nested import nested_compare, nested_flatten, nested_map, nested_pack
+from hivemind.utils.tensor_descr import DUMMY_BATCH_SIZE, BatchTensorDescriptor
 
 
 logger = get_logger(__name__)


+ 1 - 1
hivemind/moe/server/expert_uid.py

@@ -1,6 +1,6 @@
 import random
 import re
-from typing import NamedTuple, Union, Tuple, Optional, List
+from typing import List, NamedTuple, Optional, Tuple, Union
 
 
 import hivemind
 from hivemind.dht import DHT

+ 1 - 1
hivemind/moe/server/layers/custom_experts.py

@@ -1,5 +1,5 @@
-import os
 import importlib
+import os
 from typing import Callable, Type

 import torch

+ 1 - 1
hivemind/moe/server/runtime.py

@@ -4,7 +4,7 @@ import threading
 from collections import defaultdict
 from itertools import chain
 from queue import SimpleQueue
-from selectors import DefaultSelector, EVENT_READ
+from selectors import EVENT_READ, DefaultSelector
 from statistics import mean
 from time import time
 from typing import Dict, NamedTuple, Optional

+ 3 - 3
hivemind/moe/server/task_pool.py

@@ -10,12 +10,12 @@ from abc import ABCMeta, abstractmethod
 from collections import namedtuple
 from concurrent.futures import Future
 from queue import Empty
-from typing import List, Tuple, Dict, Any, Generator
+from typing import Any, Dict, Generator, List, Tuple
 
 
 import torch

 from hivemind.utils import get_logger
-from hivemind.utils.mpfuture import MPFuture, InvalidStateError
+from hivemind.utils.mpfuture import InvalidStateError, MPFuture
 
 
 logger = get_logger(__name__)
 Task = namedtuple("Task", ("future", "args"))
@@ -100,7 +100,7 @@ class TaskPool(TaskPoolBase):
 
 
     def submit_task(self, *args: torch.Tensor) -> Future:
         """Add task to this pool's queue, return Future for its output"""
-        task = Task(MPFuture(synchronize=False), args)
+        task = Task(MPFuture(), args)
         if self.get_task_size(task) > self.max_batch_size:
             exc = ValueError(f"Task size greater than max_batch_size ({self.max_batch_size}), it can't be processed")
             task.future.set_exception(exc)

+ 1 - 1
hivemind/optim/__init__.py

@@ -1,4 +1,4 @@
 from hivemind.optim.adaptive import CollaborativeAdaptiveOptimizer
 from hivemind.optim.base import DecentralizedOptimizerBase
 from hivemind.optim.collaborative import CollaborativeOptimizer
-from hivemind.optim.simple import DecentralizedOptimizer, DecentralizedSGD, DecentralizedAdam
+from hivemind.optim.simple import DecentralizedAdam, DecentralizedOptimizer, DecentralizedSGD

+ 1 - 1
hivemind/optim/adaptive.py

@@ -2,8 +2,8 @@ from typing import Sequence
 
 
 import torch.optim

-from hivemind.optim.collaborative import CollaborativeOptimizer
 from hivemind import TrainingAverager
+from hivemind.optim.collaborative import CollaborativeOptimizer
 
 
 
 
 class CollaborativeAdaptiveOptimizer(CollaborativeOptimizer):

+ 4 - 4
hivemind/optim/collaborative.py

@@ -2,8 +2,8 @@ from __future__ import annotations
 
 
 import logging
 from dataclasses import dataclass
-from threading import Thread, Lock, Event
-from typing import Dict, Optional, Iterator
+from threading import Event, Lock, Thread
+from typing import Dict, Iterator, Optional
 
 
 import numpy as np
 import torch
@@ -42,7 +42,7 @@ class CollaborationState:
 
 
 
 
 class TrainingState(BaseModel):
-    endpoint: Endpoint
+    peer_id: bytes
     step: conint(ge=0, strict=True)
     samples_accumulated: conint(ge=0, strict=True)
     samples_per_second: confloat(ge=0.0, strict=True)
@@ -354,7 +354,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             with self.lock_local_progress:
                 current_time = get_dht_time()
                 local_state_info = TrainingState(
-                    endpoint=self.averager.endpoint,
+                    peer_id=self.averager.peer_id.to_bytes(),
                     step=self.local_step,
                     samples_accumulated=self.local_samples_accumulated,
                     samples_per_second=self.performance_ema.samples_per_second,

+ 3 - 3
hivemind/optim/simple.py

@@ -1,13 +1,13 @@
 import time
-from threading import Thread, Lock, Event
+from threading import Event, Lock, Thread
 from typing import Optional, Sequence, Tuple

 import torch

-from hivemind.dht import DHT
 from hivemind.averaging import TrainingAverager
+from hivemind.dht import DHT
 from hivemind.optim.base import DecentralizedOptimizerBase
-from hivemind.utils import get_logger, get_dht_time
+from hivemind.utils import get_dht_time, get_logger
 
 
 logger = get_logger(__name__)


+ 1 - 1
hivemind/p2p/__init__.py

@@ -1,3 +1,3 @@
-from hivemind.p2p.p2p_daemon import P2P, P2PContext, P2PHandlerError
+from hivemind.p2p.p2p_daemon import P2P, P2PContext, P2PDaemonError, P2PHandlerError
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo
 from hivemind.p2p.servicer import ServicerBase, StubBase
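
# Usage sketch (illustration only, not part of the diff): P2PDaemonError is now exported from
# hivemind.p2p, and P2P.create (changed in p2p_daemon.py below) raises it when the daemon does
# not come up within `startup_timeout` seconds.
import asyncio
from hivemind.p2p import P2P, P2PDaemonError

async def main():
    try:
        p2p = await P2P.create(startup_timeout=15)
    except P2PDaemonError:
        raise RuntimeError("p2pd did not start within 15 seconds")
    print(p2p.peer_id)    # the daemon's PeerID, stored on the wrapper after this change
    await p2p.shutdown()

asyncio.run(main())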

+ 59 - 88
hivemind/p2p/p2p_daemon.py

@@ -5,7 +5,6 @@ from collections.abc import AsyncIterable as AsyncIterableABC
 from contextlib import closing, suppress
 from dataclasses import dataclass
 from importlib.resources import path
-from subprocess import Popen
 from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union

 from google.protobuf.message import Message
@@ -16,7 +15,7 @@ import hivemind.p2p.p2p_daemon_bindings.p2pclient as p2pclient
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
 from hivemind.p2p.p2p_daemon_bindings.control import P2PHandlerError
 from hivemind.p2p.p2p_daemon_bindings.control import P2PHandlerError
 from hivemind.proto.p2pd_pb2 import RPCError
 from hivemind.proto.p2pd_pb2 import RPCError
-from hivemind.utils.asyncio import aiter
+from hivemind.utils.asyncio import aiter, asingle
 from hivemind.utils.logging import get_logger
 from hivemind.utils.logging import get_logger
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)
@@ -45,7 +44,7 @@ class P2P:
       - `P2P.add_binary_stream_handler` transfers raw data using bi-directional streaming interface
       - `P2P.add_binary_stream_handler` transfers raw data using bi-directional streaming interface
 
 
     To access these handlers, a P2P instance can use `P2P.call_protobuf_handler`/`P2P.call_binary_stream_handler`,
     To access these handlers, a P2P instance can use `P2P.call_protobuf_handler`/`P2P.call_binary_stream_handler`,
-    using the recipient's unique `P2P.id` and the name of the corresponding handler.
+    using the recipient's unique `P2P.peer_id` and the name of the corresponding handler.
     """
     """
 
 
     HEADER_LEN = 8
     HEADER_LEN = 8
@@ -66,11 +65,11 @@ class P2P:
     _UNIX_SOCKET_PREFIX = "/unix/tmp/hivemind-"
     _UNIX_SOCKET_PREFIX = "/unix/tmp/hivemind-"
 
 
     def __init__(self):
     def __init__(self):
-        self.id = None
+        self.peer_id = None
         self._child = None
         self._child = None
         self._alive = False
         self._alive = False
+        self._reader_task = None
         self._listen_task = None
         self._listen_task = None
-        self._server_stopped = asyncio.Event()
 
 
     @classmethod
     @classmethod
     async def create(
     async def create(
@@ -91,9 +90,7 @@ class P2P:
         use_relay_discovery: bool = False,
         use_relay_discovery: bool = False,
         use_auto_relay: bool = False,
         use_auto_relay: bool = False,
         relay_hop_limit: int = 0,
         relay_hop_limit: int = 0,
-        quiet: bool = True,
-        ping_n_attempts: int = 5,
-        ping_delay: float = 0.4,
+        startup_timeout: float = 15,
     ) -> "P2P":
     ) -> "P2P":
         """
         """
         Start a new p2pd process and connect to it.
         Start a new p2pd process and connect to it.
@@ -114,10 +111,7 @@ class P2P:
         :param use_relay_discovery: enables passive discovery for relay
         :param use_relay_discovery: enables passive discovery for relay
         :param use_auto_relay: enables autorelay
         :param use_auto_relay: enables autorelay
         :param relay_hop_limit: sets the hop limit for hop relays
         :param relay_hop_limit: sets the hop limit for hop relays
-        :param quiet: make the daemon process quiet
-        :param ping_n_attempts: try to ping the daemon with this number of attempts after starting it
-        :param ping_delay: wait for ``ping_delay * (2 ** (k - 1))`` seconds before the k-th attempt to ping the daemon
-          (in particular, wait for ``ping_delay`` seconds before the first attempt)
+        :param startup_timeout: raise a P2PDaemonError if the daemon does not start in ``startup_timeout`` seconds
         :return: a wrapper for the p2p daemon
         :return: a wrapper for the p2p daemon
         """
         """
 
 
@@ -158,37 +152,26 @@ class P2P:
             autoRelay=use_auto_relay,
             autoRelay=use_auto_relay,
             relayHopLimit=relay_hop_limit,
             relayHopLimit=relay_hop_limit,
             b=need_bootstrap,
             b=need_bootstrap,
-            q=quiet,
             **process_kwargs,
             **process_kwargs,
         )
         )
 
 
-        self._child = Popen(args=proc_args, encoding="utf8")
+        self._child = await asyncio.subprocess.create_subprocess_exec(
+            *proc_args, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.STDOUT
+        )
         self._alive = True
         self._alive = True
-        self._client = p2pclient.Client(self._daemon_listen_maddr, self._client_listen_maddr)
 
 
-        await self._ping_daemon_with_retries(ping_n_attempts, ping_delay)
+        ready = asyncio.Future()
+        self._reader_task = asyncio.create_task(self._read_outputs(ready))
+        try:
+            await asyncio.wait_for(ready, startup_timeout)
+        except asyncio.TimeoutError:
+            await self.shutdown()
+            raise P2PDaemonError(f"Daemon failed to start in {startup_timeout:.1f} seconds")
 
 
+        self._client = p2pclient.Client(self._daemon_listen_maddr, self._client_listen_maddr)
+        await self._ping_daemon()
         return self
         return self
 
 
-    async def _ping_daemon_with_retries(self, ping_n_attempts: int, ping_delay: float) -> None:
-        for try_number in range(ping_n_attempts):
-            await asyncio.sleep(ping_delay * (2 ** try_number))
-
-            if self._child.poll() is not None:  # Process died
-                break
-
-            try:
-                await self._ping_daemon()
-                break
-            except Exception as e:
-                if try_number == ping_n_attempts - 1:
-                    logger.exception("Failed to ping p2pd that has just started")
-                    await self.shutdown()
-                    raise
-
-        if self._child.returncode is not None:
-            raise RuntimeError(f"The p2p daemon has died with return code {self._child.returncode}")
-
     @classmethod
     @classmethod
     async def replicate(cls, daemon_listen_maddr: Multiaddr) -> "P2P":
     async def replicate(cls, daemon_listen_maddr: Multiaddr) -> "P2P":
         """
         """
@@ -213,8 +196,8 @@ class P2P:
         return self
         return self
 
 
     async def _ping_daemon(self) -> None:
     async def _ping_daemon(self) -> None:
-        self.id, self._visible_maddrs = await self._client.identify()
-        logger.debug(f"Launched p2pd with id = {self.id}, host multiaddrs = {self._visible_maddrs}")
+        self.peer_id, self._visible_maddrs = await self._client.identify()
+        logger.debug(f"Launched p2pd with peer id = {self.peer_id}, host multiaddrs = {self._visible_maddrs}")
 
 
     async def get_visible_maddrs(self, latest: bool = False) -> List[Multiaddr]:
     async def get_visible_maddrs(self, latest: bool = False) -> List[Multiaddr]:
         """
         """
@@ -227,9 +210,9 @@ class P2P:
             _, self._visible_maddrs = await self._client.identify()
             _, self._visible_maddrs = await self._client.identify()
 
 
         if not self._visible_maddrs:
         if not self._visible_maddrs:
-            raise ValueError(f"No multiaddrs found for peer {self.id}")
+            raise ValueError(f"No multiaddrs found for peer {self.peer_id}")
 
 
-        p2p_maddr = Multiaddr(f"/p2p/{self.id.to_base58()}")
+        p2p_maddr = Multiaddr(f"/p2p/{self.peer_id.to_base58()}")
         return [addr.encapsulate(p2p_maddr) for addr in self._visible_maddrs]
         return [addr.encapsulate(p2p_maddr) for addr in self._visible_maddrs]
 
 
     async def list_peers(self) -> List[PeerInfo]:
     async def list_peers(self) -> List[PeerInfo]:
@@ -308,15 +291,12 @@ class P2P:
           they will not be received while the prefetch buffer is full.
           they will not be received while the prefetch buffer is full.
         """
         """
 
 
-        if self._listen_task is None:
-            self._start_listening()
-
         async def _handle_stream(
         async def _handle_stream(
             stream_info: StreamInfo, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
             stream_info: StreamInfo, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
         ) -> None:
         ) -> None:
             context = P2PContext(
             context = P2PContext(
                 handle_name=name,
                 handle_name=name,
-                local_id=self.id,
+                local_id=self.peer_id,
                 remote_id=stream_info.peer_id,
                 remote_id=stream_info.peer_id,
             )
             )
             requests = asyncio.Queue(max_prefetch)
             requests = asyncio.Queue(max_prefetch)
@@ -334,7 +314,9 @@ class P2P:
                         await P2P.send_protobuf(response, writer)
                         await P2P.send_protobuf(response, writer)
                 except Exception as e:
                 except Exception as e:
                     logger.warning("Exception while processing stream and sending responses:", exc_info=True)
                     logger.warning("Exception while processing stream and sending responses:", exc_info=True)
-                    await P2P.send_protobuf(RPCError(message=str(e)), writer)
+                    # Sometimes `e` is a connection error, so we won't be able to report the error to the caller
+                    with suppress(Exception):
+                        await P2P.send_protobuf(RPCError(message=str(e)), writer)
 
 
             with closing(writer):
             with closing(writer):
                 processing_task = asyncio.create_task(_process_stream())
                 processing_task = asyncio.create_task(_process_stream())
@@ -358,12 +340,12 @@ class P2P:
                 finally:
                 finally:
                     processing_task.cancel()
                     processing_task.cancel()
 
 
-        await self._client.stream_handler(name, _handle_stream)
+        await self.add_binary_stream_handler(name, _handle_stream)
 
 
     async def _iterate_protobuf_stream_handler(
     async def _iterate_protobuf_stream_handler(
         self, peer_id: PeerID, name: str, requests: TInputStream, output_protobuf_type: Message
         self, peer_id: PeerID, name: str, requests: TInputStream, output_protobuf_type: Message
     ) -> TOutputStream:
     ) -> TOutputStream:
-        _, reader, writer = await self._client.stream_open(peer_id, (name,))
+        _, reader, writer = await self.call_binary_stream_handler(peer_id, name)
 
 
         async def _write_to_stream() -> None:
         async def _write_to_stream() -> None:
             async for request in requests:
             async for request in requests:
@@ -409,15 +391,7 @@ class P2P:
             return
             return
 
 
         async def _stream_handler(requests: P2P.TInputStream, context: P2PContext) -> P2P.TOutputStream:
         async def _stream_handler(requests: P2P.TInputStream, context: P2PContext) -> P2P.TOutputStream:
-            if stream_input:
-                input = requests
-            else:
-                count = 0
-                async for input in requests:
-                    count += 1
-                if count != 1:
-                    raise ValueError(f"Got {count} requests for handler {name} instead of one")
-
+            input = requests if stream_input else await asingle(requests)
             output = handler(input, context)
             output = handler(input, context)
 
 
             if isinstance(output, AsyncIterableABC):
             if isinstance(output, AsyncIterableABC):
@@ -448,7 +422,7 @@ class P2P:
             input_serialized = input_protobuf_type.FromString(request)
             input_serialized = input_protobuf_type.FromString(request)
             context = P2PContext(
             context = P2PContext(
                 handle_name=handle_name,
                 handle_name=handle_name,
-                local_id=self.id,
+                local_id=self.peer_id,
                 remote_id=remote_id,
                 remote_id=remote_id,
             )
             )
 
 
@@ -468,14 +442,9 @@ class P2P:
         if not isinstance(input, AsyncIterableABC):
         if not isinstance(input, AsyncIterableABC):
             return await self._call_unary_protobuf_handler(peer_id, name, input, output_protobuf_type)
             return await self._call_unary_protobuf_handler(peer_id, name, input, output_protobuf_type)
 
 
-        responses = self._iterate_protobuf_stream_handler(peer_id, name, input, output_protobuf_type)
-
-        count = 0
-        async for response in responses:
-            count += 1
-        if count != 1:
-            raise ValueError(f"Got {count} responses from handler {name} instead of one")
-        return response
+        requests = input if isinstance(input, AsyncIterableABC) else aiter(input)
+        responses = self._iterate_protobuf_stream_handler(peer_id, name, requests, output_protobuf_type)
+        return await asingle(responses)
 
 
     async def _call_unary_protobuf_handler(
     async def _call_unary_protobuf_handler(
         self,
         self,
@@ -501,20 +470,10 @@ class P2P:
     def _start_listening(self) -> None:
     def _start_listening(self) -> None:
         async def listen() -> None:
         async def listen() -> None:
             async with self._client.listen():
             async with self._client.listen():
-                await self._server_stopped.wait()
+                await asyncio.Future()  # Wait until this task will be cancelled in _terminate()
 
 
         self._listen_task = asyncio.create_task(listen())
         self._listen_task = asyncio.create_task(listen())
 
 
-    async def _stop_listening(self) -> None:
-        if self._listen_task is not None:
-            self._server_stopped.set()
-            self._listen_task.cancel()
-            try:
-                await self._listen_task
-            except asyncio.CancelledError:
-                self._listen_task = None
-                self._server_stopped.clear()
-
     async def add_binary_stream_handler(self, name: str, handler: p2pclient.StreamHandler) -> None:
     async def add_binary_stream_handler(self, name: str, handler: p2pclient.StreamHandler) -> None:
         if self._listen_task is None:
         if self._listen_task is None:
             self._start_listening()
             self._start_listening()
@@ -533,22 +492,20 @@ class P2P:
         return self._alive
         return self._alive
 
 
     async def shutdown(self) -> None:
     async def shutdown(self) -> None:
-        await self._stop_listening()
-        await asyncio.get_event_loop().run_in_executor(None, self._terminate)
+        self._terminate()
+        if self._child is not None:
+            await self._child.wait()
 
 
     def _terminate(self) -> None:
     def _terminate(self) -> None:
-        self._alive = False
-
-        if self._client.control._write_task is not None:
-            self._client.control._write_task.cancel()
-
-        if self._client.control._read_task is not None:
-            self._client.control._read_task.cancel()
+        if self._listen_task is not None:
+            self._listen_task.cancel()
+        if self._reader_task is not None:
+            self._reader_task.cancel()
 
 
-        if self._child is not None and self._child.poll() is None:
+        self._alive = False
+        if self._child is not None and self._child.returncode is None:
             self._child.terminate()
             self._child.terminate()
-            self._child.wait()
-            logger.debug(f"Terminated p2pd with id = {self.id}")
+            logger.debug(f"Terminated p2pd with id = {self.peer_id}")
 
 
             with suppress(FileNotFoundError):
             with suppress(FileNotFoundError):
                 os.remove(self._daemon_listen_maddr["unix"])
                 os.remove(self._daemon_listen_maddr["unix"])
@@ -575,6 +532,20 @@ class P2P:
     def _maddrs_to_str(maddrs: List[Multiaddr]) -> str:
     def _maddrs_to_str(maddrs: List[Multiaddr]) -> str:
         return ",".join(str(addr) for addr in maddrs)
         return ",".join(str(addr) for addr in maddrs)
 
 
+    async def _read_outputs(self, ready: asyncio.Future) -> None:
+        last_line = None
+        while True:
+            line = await self._child.stdout.readline()
+            if not line:  # Stream closed
+                break
+            last_line = line.rstrip().decode(errors="ignore")
+
+            if last_line.startswith("Peer ID:"):
+                ready.set_result(None)
+
+        if not ready.done():
+            ready.set_exception(P2PDaemonError(f"Daemon failed to start: {last_line}"))
+
 
 
-class P2PInterruptedError(Exception):
+class P2PDaemonError(RuntimeError):
     pass
     pass
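The startup path above replaces the ping-with-retries loop: create() now spawns p2pd with a piped stdout, waits for the "Peer ID:" line in _read_outputs, and raises P2PDaemonError if the daemon times out or exits early. A minimal usage sketch under those assumptions (the timeout value below is illustrative, not a library default):

import asyncio
from hivemind.p2p import P2P, P2PDaemonError

async def main():
    try:
        p2p = await P2P.create(startup_timeout=30)  # raises P2PDaemonError if p2pd does not report a Peer ID in time
    except P2PDaemonError as e:
        print(f"p2pd did not start: {e}")
        return
    try:
        print("peer id:", p2p.peer_id)                         # formerly P2P.id, renamed in this change
        print("visible maddrs:", await p2p.get_visible_maddrs())
    finally:
        await p2p.shutdown()                                   # terminates the daemon and awaits the subprocess

asyncio.run(main())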

+ 54 - 27
hivemind/p2p/servicer.py

@@ -1,6 +1,7 @@
 import asyncio
 import asyncio
+import inspect
 from dataclasses import dataclass
 from dataclasses import dataclass
-from typing import Any, AsyncIterator, Optional, Tuple, get_type_hints
+from typing import Any, AsyncIterator, List, Optional, Tuple, Type, get_type_hints
 
 
 from hivemind.p2p.p2p_daemon import P2P
 from hivemind.p2p.p2p_daemon import P2P
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID
@@ -9,7 +10,6 @@ from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID
 @dataclass
 @dataclass
 class RPCHandler:
 class RPCHandler:
     method_name: str
     method_name: str
-    handle_name: str
     request_type: type
     request_type: type
     response_type: type
     response_type: type
     stream_input: bool
     stream_input: bool
@@ -24,9 +24,10 @@ class StubBase:
     adding the necessary rpc_* methods. Calls to these methods are translated to calls to the remote peer.
     adding the necessary rpc_* methods. Calls to these methods are translated to calls to the remote peer.
     """
     """
 
 
-    def __init__(self, p2p: P2P, peer: PeerID):
+    def __init__(self, p2p: P2P, peer: PeerID, namespace: Optional[str]):
         self._p2p = p2p
         self._p2p = p2p
         self._peer = peer
         self._peer = peer
+        self._namespace = namespace
 
 
 
 
 class ServicerBase:
 class ServicerBase:
@@ -41,39 +42,49 @@ class ServicerBase:
       to calls to the remote peer.
       to calls to the remote peer.
     """
     """
 
 
-    def __init__(self):
-        class_name = self.__class__.__name__
+    _rpc_handlers: Optional[List[RPCHandler]] = None
+    _stub_type: Optional[Type[StubBase]] = None
 
 
-        self._rpc_handlers = []
-        for method_name, method in self.__class__.__dict__.items():
-            if method_name.startswith("rpc_") and callable(method):
-                handle_name = f"{class_name}.{method_name}"
+    @classmethod
+    def _collect_rpc_handlers(cls) -> None:
+        if cls._rpc_handlers is not None:
+            return
 
 
+        cls._rpc_handlers = []
+        for method_name, method in inspect.getmembers(cls, predicate=inspect.isfunction):
+            if method_name.startswith("rpc_"):
+                spec = inspect.getfullargspec(method)
+                if len(spec.args) < 3:
+                    raise ValueError(
+                        f"{method_name} is expected to have at least three positional arguments "
+                        f"(self, request: TInputProtobuf | AsyncIterator[TInputProtobuf], context: P2PContext)"
+                    )
+                request_arg = spec.args[1]
                 hints = get_type_hints(method)
                 hints = get_type_hints(method)
                 try:
                 try:
-                    request_type = hints["request"]
+                    request_type = hints[request_arg]
                     response_type = hints["return"]
                     response_type = hints["return"]
                 except KeyError:
                 except KeyError:
                     raise ValueError(
                     raise ValueError(
-                        f"{handle_name} is expected to have type annotations "
+                        f"{method_name} is expected to have type annotations "
                         f"like `dht_pb2.FindRequest` or `AsyncIterator[dht_pb2.FindRequest]` "
                         f"like `dht_pb2.FindRequest` or `AsyncIterator[dht_pb2.FindRequest]` "
-                        f"for the `request` parameter and the return value"
+                        f"for the `{request_arg}` parameter and the return value"
                     )
                     )
-                request_type, stream_input = self._strip_iterator_hint(request_type)
-                response_type, stream_output = self._strip_iterator_hint(response_type)
+                request_type, stream_input = cls._strip_iterator_hint(request_type)
+                response_type, stream_output = cls._strip_iterator_hint(response_type)
 
 
-                self._rpc_handlers.append(
-                    RPCHandler(method_name, handle_name, request_type, response_type, stream_input, stream_output)
+                cls._rpc_handlers.append(
+                    RPCHandler(method_name, request_type, response_type, stream_input, stream_output)
                 )
                 )
 
 
-        self._stub_type = type(
-            f"{class_name}Stub",
+        cls._stub_type = type(
+            f"{cls.__name__}Stub",
             (StubBase,),
             (StubBase,),
-            {handler.method_name: self._make_rpc_caller(handler) for handler in self._rpc_handlers},
+            {handler.method_name: cls._make_rpc_caller(handler) for handler in cls._rpc_handlers},
         )
         )
 
 
-    @staticmethod
-    def _make_rpc_caller(handler: RPCHandler):
+    @classmethod
+    def _make_rpc_caller(cls, handler: RPCHandler):
         input_type = AsyncIterator[handler.request_type] if handler.stream_input else handler.request_type
         input_type = AsyncIterator[handler.request_type] if handler.stream_input else handler.request_type
 
 
         # This method will be added to a new Stub type (a subclass of StubBase)
         # This method will be added to a new Stub type (a subclass of StubBase)
@@ -87,7 +98,7 @@ class ServicerBase:
 
 
                 return self._p2p.iterate_protobuf_handler(
                 return self._p2p.iterate_protobuf_handler(
                     self._peer,
                     self._peer,
-                    handler.handle_name,
+                    cls._get_handle_name(self._namespace, handler.method_name),
                     input,
                     input,
                     handler.response_type,
                     handler.response_type,
                 )
                 )
@@ -98,26 +109,42 @@ class ServicerBase:
                 self: StubBase, input: input_type, timeout: Optional[float] = None
                 self: StubBase, input: input_type, timeout: Optional[float] = None
             ) -> handler.response_type:
             ) -> handler.response_type:
                 return await asyncio.wait_for(
                 return await asyncio.wait_for(
-                    self._p2p.call_protobuf_handler(self._peer, handler.handle_name, input, handler.response_type),
+                    self._p2p.call_protobuf_handler(
+                        self._peer,
+                        cls._get_handle_name(self._namespace, handler.method_name),
+                        input,
+                        handler.response_type,
+                    ),
                     timeout=timeout,
                     timeout=timeout,
                 )
                 )
 
 
         caller.__name__ = handler.method_name
         caller.__name__ = handler.method_name
         return caller
         return caller
 
 
-    async def add_p2p_handlers(self, p2p: P2P, wrapper: Any = None) -> None:
+    async def add_p2p_handlers(self, p2p: P2P, wrapper: Any = None, *, namespace: Optional[str] = None) -> None:
+        self._collect_rpc_handlers()
+
         servicer = self if wrapper is None else wrapper
         servicer = self if wrapper is None else wrapper
         for handler in self._rpc_handlers:
         for handler in self._rpc_handlers:
             await p2p.add_protobuf_handler(
             await p2p.add_protobuf_handler(
-                handler.handle_name,
+                self._get_handle_name(namespace, handler.method_name),
                 getattr(servicer, handler.method_name),
                 getattr(servicer, handler.method_name),
                 handler.request_type,
                 handler.request_type,
                 stream_input=handler.stream_input,
                 stream_input=handler.stream_input,
                 stream_output=handler.stream_output,
                 stream_output=handler.stream_output,
             )
             )
 
 
-    def get_stub(self, p2p: P2P, peer: PeerID) -> StubBase:
-        return self._stub_type(p2p, peer)
+    @classmethod
+    def get_stub(cls, p2p: P2P, peer: PeerID, *, namespace: Optional[str] = None) -> StubBase:
+        cls._collect_rpc_handlers()
+        return cls._stub_type(p2p, peer, namespace)
+
+    @classmethod
+    def _get_handle_name(cls, namespace: Optional[str], method_name: str) -> str:
+        handle_name = f"{cls.__name__}.{method_name}"
+        if namespace is not None:
+            handle_name = f"{namespace}::{handle_name}"
+        return handle_name
 
 
     @staticmethod
     @staticmethod
     def _strip_iterator_hint(hint: type) -> Tuple[type, bool]:
     def _strip_iterator_hint(hint: type) -> Tuple[type, bool]:
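Handler names are now derived per class and optionally namespaced ("{namespace}::{ClassName}.{rpc_method}"), and get_stub() is a classmethod, so a stub can be built without instantiating the servicer. A rough sketch of the resulting workflow; the servicer below and its use of dht_pb2.FindRequest/FindResponse are purely illustrative:

from hivemind.p2p import P2P, P2PContext, ServicerBase
from hivemind.proto import dht_pb2

class ExampleServicer(ServicerBase):
    # the rpc_ prefix, the annotated request argument and the return annotation are what
    # _collect_rpc_handlers() inspects to choose between unary and streaming handlers
    async def rpc_find(self, request: dht_pb2.FindRequest, context: P2PContext) -> dht_pb2.FindResponse:
        return dht_pb2.FindResponse()

async def run(server: P2P, client: P2P) -> None:
    await ExampleServicer().add_p2p_handlers(server, namespace="demo")  # registers "demo::ExampleServicer.rpc_find"
    stub = ExampleServicer.get_stub(client, server.peer_id, namespace="demo")
    response = await stub.rpc_find(dht_pb2.FindRequest(), timeout=5.0)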

+ 8 - 14
hivemind/proto/averaging.proto

@@ -2,13 +2,6 @@ syntax = "proto3";
 import "runtime.proto";
 import "runtime.proto";
 
 
 
 
-// Runs alongside each trainer to perform gating function averaging every now and then. Read more: client/averaging.py
-service DecentralizedAveraging {
-  rpc rpc_join_group(JoinRequest) returns (stream MessageFromLeader);  // assemble a group for allreduce
-  rpc rpc_aggregate_part(stream AveragingData) returns (stream AveragingData);  // send local part => get average part
-  rpc rpc_download_state(DownloadRequest) returns (stream DownloadData);
-}
-
 enum MessageCode {
 enum MessageCode {
   NO_CODE = 0;               // Default value that should not be used explicitly
   NO_CODE = 0;               // Default value that should not be used explicitly
   REQUEST_JOIN = 1;          // "Dear maybe leader, will you have me in your group as a follower?"
   REQUEST_JOIN = 1;          // "Dear maybe leader, will you have me in your group as a follower?"
@@ -21,35 +14,36 @@ enum MessageCode {
   BAD_EXPIRATION_TIME = 8;   // "I will not accept you. I cannot guarantee that we begin before you expire."
   BAD_EXPIRATION_TIME = 8;   // "I will not accept you. I cannot guarantee that we begin before you expire."
   BAD_SCHEMA_HASH = 9;       // "I will not accept you. I am not averaging the same type of tensors as you."
   BAD_SCHEMA_HASH = 9;       // "I will not accept you. I am not averaging the same type of tensors as you."
   BAD_GROUP_ID = 10;         // "I will not accept your request, your group id does not match with any groups i'm in."
   BAD_GROUP_ID = 10;         // "I will not accept your request, your group id does not match with any groups i'm in."
-  DUPLICATE_ENDPOINT = 11;   // "I will not accept you, i already have exactly the same endpoint in my current group."
+  DUPLICATE_PEER_ID = 11;    // "I will not accept you, i already have exactly the same peer id in my current group."
   GROUP_IS_FULL = 12;        // "I will not accept you, my group already contains too many peers."
   GROUP_IS_FULL = 12;        // "I will not accept you, my group already contains too many peers."
   NOT_LOOKING_FOR_GROUP = 13;// "I'm not available at the moment. Please, get lost."
   NOT_LOOKING_FOR_GROUP = 13;// "I'm not available at the moment. Please, get lost."
   PROTOCOL_VIOLATION = 14;   // "You did something so unspeakable that i don't have a special code for that."
   PROTOCOL_VIOLATION = 14;   // "You did something so unspeakable that i don't have a special code for that."
   INTERNAL_ERROR = 15;       // "I messed up, we will have to stop allreduce because of that."
   INTERNAL_ERROR = 15;       // "I messed up, we will have to stop allreduce because of that."
   CANCELLED = 16;            // "[from peer during allreduce] I no longer want to participate in AllReduce."
   CANCELLED = 16;            // "[from peer during allreduce] I no longer want to participate in AllReduce."
   GROUP_DISBANDED = 17;      // "[from leader] The group is closed. Go find another group."
   GROUP_DISBANDED = 17;      // "[from leader] The group is closed. Go find another group."
+  BAD_GROUP_KEY = 18;        // "I will not accept you. My current group key differs (maybe you used my older key)."
 }
 }
 
 
 message JoinRequest {
 message JoinRequest {
-  string endpoint = 1;          // A follower accepts incoming allreduce requests at this address
   bytes schema_hash = 2;        // A hash that describes follower's tensors (shapes, num tensors, etc)
   bytes schema_hash = 2;        // A hash that describes follower's tensors (shapes, num tensors, etc)
   double expiration = 3;        // Follower would like to **begin** all_reduce by this point in time
   double expiration = 3;        // Follower would like to **begin** all_reduce by this point in time
   bytes gather = 4;             // optional metadata that is gathered from all peers (e.g. batch size or current loss)
   bytes gather = 4;             // optional metadata that is gathered from all peers (e.g. batch size or current loss)
   bool client_mode = 5;         // if True, the incoming averager is a client with no capacity for averaging
   bool client_mode = 5;         // if True, the incoming averager is a client with no capacity for averaging
+  string group_key = 6;         // group key identifying an All-Reduce bucket, e.g my_averager.0b011011101
 }
 }
 
 
 message MessageFromLeader {
 message MessageFromLeader {
   MessageCode code = 1;
   MessageCode code = 1;
-  bytes group_id = 2;        // a unique identifier of this group, only valid until allreduce is finished/failed
-  string suggested_leader = 3;  // if peer is already in a group, it'll provide us with an endpoint of its leader
-  repeated string ordered_group_endpoints = 4;  // a sequence of peers, each responsible for one shard during averaging
-  repeated bytes gathered = 5;  // metadata (gather) from all groupmates in the same order as their endpoints
+  bytes group_id = 2;           // a unique identifier of this group, only valid until allreduce is finished/failed
+  bytes suggested_leader = 3;   // if peer is already in a group, it'll provide us with a peer id of its leader
+  repeated bytes ordered_peer_ids = 4;  // a sequence of peers, each responsible for one shard during averaging
+  repeated bytes gathered = 5;  // metadata (gather) from all groupmates in the same order as their peer ids
 }
 }
 
 
 message AveragingData {
 message AveragingData {
   MessageCode code = 1;     // in case of a protocol violation, this will be the error message
   MessageCode code = 1;     // in case of a protocol violation, this will be the error message
   bytes group_id = 2;       // a unique group identifier, same as in MessageFromLeader
   bytes group_id = 2;       // a unique group identifier, same as in MessageFromLeader
-  string endpoint = 3;      // sender's rpc endpoint, used for coordination
+  bytes peer_id = 3;        // sender's rpc peer_id, used for coordination
   Tensor tensor_part = 4;   // either peer's local tensor part (rpc input) or group average of this part (rpc output)
   Tensor tensor_part = 4;   // either peer's local tensor part (rpc input) or group average of this part (rpc output)
   bytes metadata = 5;       // reserved user-extendable metadata
   bytes metadata = 5;       // reserved user-extendable metadata
 }
 }

+ 2 - 2
hivemind/proto/dht.proto

@@ -65,8 +65,8 @@ message FindResult {
   double expiration_time = 3;          // n/a  | expiration time  | DictionaryDHTValue.latest_expiration_time
   double expiration_time = 3;          // n/a  | expiration time  | DictionaryDHTValue.latest_expiration_time
 
   // two aligned arrays: DHTIDs and PeerIDs for nearest peers (sorted by XOR distance)
   // two aligned arrays: DHTIDs and PeerIDs for nearest peers (sorted by XOR distance)
-  repeated string nearest_peer_ids = 5;     // Base58-serialized libp2p PeerIDs of the nearest peers
+  repeated bytes nearest_node_ids = 4;  // DHTIDs of the nearest peers serialized with node_id.to_bytes()
+  repeated bytes nearest_peer_ids = 5;  // libp2p PeerIDs of the nearest peers
 }
 }
 
 message FindResponse {
 message FindResponse {
+ 3 - 3
hivemind/utils/__init__.py

@@ -1,11 +1,11 @@
 from hivemind.utils.asyncio import *
-from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.grpc import *
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger
 from hivemind.utils.mpfuture import *
 from hivemind.utils.nested import *
 from hivemind.utils.networking import *
-from hivemind.utils.serializer import SerializerBase, MSGPackSerializer
-from hivemind.utils.tensor_descr import TensorDescriptor, BatchTensorDescriptor
+from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
+from hivemind.utils.tensor_descr import BatchTensorDescriptor, TensorDescriptor
 from hivemind.utils.timed_storage import *

+ 15 - 4
hivemind/utils/asyncio.py

@@ -1,12 +1,11 @@
-from concurrent.futures import ThreadPoolExecutor
-from typing import TypeVar, AsyncIterator, Union, AsyncIterable, Awaitable, Tuple, Optional, Callable
 import asyncio
+from concurrent.futures import ThreadPoolExecutor
+from typing import AsyncIterable, AsyncIterator, Awaitable, Callable, Optional, Tuple, TypeVar, Union
 
 
 import uvloop
 
 
 from hivemind.utils.logging import get_logger
 
 
-
 T = TypeVar("T")
 logger = get_logger(__name__)
 
 
@@ -59,6 +58,18 @@ async def aenumerate(aiterable: AsyncIterable[T]) -> AsyncIterable[Tuple[int, T]
         index += 1
 
 
 
 
+async def asingle(aiter: AsyncIterable[T]) -> T:
+    """If ``aiter`` has exactly one item, returns this item. Otherwise, raises `ValueError`."""
+    count = 0
+    async for item in aiter:
+        count += 1
+        if count == 2:
+            raise ValueError("asingle() expected an iterable with exactly one item, but got two or more items")
+    if count == 0:
+        raise ValueError("asingle() expected an iterable with exactly one item, but got an empty iterable")
+    return item
+
+
 async def await_cancelled(awaitable: Awaitable) -> bool:
     try:
         await awaitable
@@ -73,7 +84,7 @@ async def amap_in_executor(
     func: Callable[..., T],
     *iterables: AsyncIterable,
     max_prefetch: Optional[int] = None,
-    executor: Optional[ThreadPoolExecutor] = None
+    executor: Optional[ThreadPoolExecutor] = None,
 ) -> AsyncIterator[T]:
     """iterate from an async iterable in a background thread, yield results to async iterable"""
     loop = asyncio.get_event_loop()
     loop = asyncio.get_event_loop()
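asingle() is the helper now backing both unary code paths in p2p_daemon.py: it drains an async iterable and insists on exactly one item. A small self-contained sketch of its contract (the agen helper is defined here only for the demo):

import asyncio
from hivemind.utils.asyncio import asingle

async def agen(*items):
    for item in items:  # demo-only helper that exposes its arguments as an async stream
        yield item

async def main():
    assert await asingle(agen("only")) == "only"   # exactly one item: returned as-is
    for bad in (agen(), agen(1, 2)):
        try:
            await asingle(bad)
        except ValueError as e:
            print(e)                               # empty or multi-item streams raise ValueError

asyncio.run(main())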

+ 2 - 3
hivemind/utils/compression.py

@@ -1,15 +1,14 @@
 import os
+import warnings
 from concurrent.futures import ThreadPoolExecutor
-from typing import Tuple, Sequence, Optional
+from typing import Optional, Sequence, Tuple
 
 
 import numpy as np
 import torch
-import warnings
 
 
 from hivemind.proto import runtime_pb2
 from hivemind.proto.runtime_pb2 import CompressionType
 
 
-
 FP32_EPS = 1e-06
 NUM_BYTES_FLOAT32 = 4
 NUM_BYTES_FLOAT16 = 2

+ 2 - 2
hivemind/utils/grpc.py

@@ -6,14 +6,14 @@ from __future__ import annotations
 
 
 import os
 import threading
-from typing import NamedTuple, Tuple, Optional, Union, Any, Dict, TypeVar, Type, Iterator, Iterable
+from typing import Any, Dict, Iterable, Iterator, NamedTuple, Optional, Tuple, Type, TypeVar, Union
 
 
 import grpc
 
 
 from hivemind.proto import runtime_pb2
 from hivemind.utils.logging import get_logger
 from hivemind.utils.networking import Endpoint
-from hivemind.utils.timed_storage import TimedStorage, get_dht_time, ValueWithExpiration
+from hivemind.utils.timed_storage import TimedStorage, ValueWithExpiration, get_dht_time
 
 
 logger = get_logger(__name__)
 
 

+ 137 - 161
hivemind/utils/mpfuture.py
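The rewrite below removes the pipe-based state synchronization: each MPFuture now stores its state code in a shared uint8 tensor handed out by SharedBytes, so every process reads the same byte, while results and exceptions still travel over the one-way update pipe. A stripped-down sketch of that shared-state idea (plain torch.multiprocessing, not the hivemind API):

import torch
import torch.multiprocessing as mp  # keeps tensors in shared memory when they cross process boundaries

ALL_STATES = ["PENDING", "RUNNING", "CANCELLED", "FINISHED"]  # stand-ins for concurrent.futures states

def worker(state_code: torch.Tensor) -> None:
    state_code[...] = ALL_STATES.index("FINISHED")  # the child writes the new state into shared memory

if __name__ == "__main__":
    state_code = torch.zeros([1], dtype=torch.uint8).share_memory_()
    proc = mp.Process(target=worker, args=(state_code,))
    proc.start()
    proc.join()
    print(ALL_STATES[state_code.item()])  # the parent observes FINISHED without any pipe round-trip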

@@ -2,18 +2,19 @@ from __future__ import annotations
 
 
 import asyncio
 import asyncio
 import concurrent.futures._base as base
 import concurrent.futures._base as base
-from contextlib import nullcontext, suppress
 import multiprocessing as mp
 import multiprocessing as mp
 import multiprocessing.connection
 import multiprocessing.connection
 import os
 import os
 import threading
 import threading
 import uuid
 import uuid
-from weakref import ref
+from contextlib import nullcontext
 from enum import Enum, auto
 from enum import Enum, auto
-from typing import Generic, TypeVar, Dict, Optional, Any, Callable, Type, Tuple
+from typing import Any, Callable, Dict, Generic, Optional, Type, TypeVar
+from weakref import ref
 
 
-from hivemind.utils.logging import get_logger
+import torch  # used for py3.7-compatible shared memory
 
 
+from hivemind.utils.logging import get_logger
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 
@@ -33,13 +34,38 @@ except ImportError:
         """Raised when attempting to change state of a future in a terminal state (e.g. finished)"""
         """Raised when attempting to change state of a future in a terminal state (e.g. finished)"""
 
 
 
 
-class MessageType(Enum):
+class SharedBytes:
+    """
+    A process-wide object that allocates large chunks of shared memory and partitions it into individual bytes.
+
+    Note: this class is only responsible for bulk allocation; it does not manage or free individual bytes.
+    A chunk is deallocated by the garbage collector once no process uses any of its bytes anymore.
+    """
+
+    _lock = mp.Lock()
+    _pid: Optional[PID] = None
+    _buffer: Optional[torch.Tensor] = None
+    _index: int = 0
+
+    @classmethod
+    def next(cls) -> torch.Tensor:
+        """Create another shared byte value, represented as a scalar uint8 tensor"""
+        with cls._lock:
+            if cls._pid != os.getpid() or cls._buffer is None or cls._index >= len(cls._buffer):
+                buffer_size = int(os.environ.get("HIVEMIND_SHM_BUFFER_SIZE", 4096))  # env override arrives as a string
+                cls._pid = os.getpid()
+                cls._buffer = torch.empty([buffer_size], dtype=torch.uint8).share_memory_()
+                cls._index = 0
+
+            cls._index += 1
+            return cls._buffer[cls._index - 1]
+
+
+class UpdateType(Enum):
     RESULT = auto()
     RESULT = auto()
     EXCEPTION = auto()
     EXCEPTION = auto()
-    RUNNING = auto()
     CANCEL = auto()
     CANCEL = auto()
-    STATE_REQUEST = auto()
-    STATE_RESPONSE = auto()
 
 
 
 
 class MPFuture(base.Future, Generic[ResultType]):
 class MPFuture(base.Future, Generic[ResultType]):
@@ -48,12 +74,9 @@ class MPFuture(base.Future, Generic[ResultType]):
     Any process can access future status and set the result / exception and check for state.
     Any process can access future status and set the result / exception and check for state.
     However, only the original process (i.e. the process that created the future) can await the result or exception.
     However, only the original process (i.e. the process that created the future) can await the result or exception.
 
 
-    :param synchronize: if True (default), future will request state from origin, otherwise it will only use local state
-      Setting synchronize=False results in slightly better performance of done or set_running_or_notify_cancel
     :param use_lock: if True, operations with MPFuture use a global lock to prevent concurrent writes to the same pipe;
     :param use_lock: if True, operations with MPFuture use a global lock to prevent concurrent writes to the same pipe;
       If set to False, writing to this future ignores global lock, slightly improving performance, but making user
       If set to False, writing to this future ignores global lock, slightly improving performance, but making user
       responsible for avoiding concurrent set_result / set_exception calls to futures with the same process of origin.
       responsible for avoiding concurrent set_result / set_exception calls to futures with the same process of origin.
-    :param loop: if specified, overrides default asyncio event loop for the purpose of awaiting MPFuture
 
 
     :note: This is an internal primitive that is not guaranteed to work outside of hivemind applications.
     :note: This is an internal primitive that is not guaranteed to work outside of hivemind applications.
      More specifically, there are two known limitations:
      More specifically, there are two known limitations:
@@ -64,26 +87,30 @@ class MPFuture(base.Future, Generic[ResultType]):
 
 
     _initialization_lock = mp.Lock()  # global lock that prevents simultaneous initialization of two processes
     _initialization_lock = mp.Lock()  # global lock that prevents simultaneous initialization of two processes
     _update_lock = mp.Lock()  # global lock that prevents simultaneous writing to the same pipe
     _update_lock = mp.Lock()  # global lock that prevents simultaneous writing to the same pipe
-    _process_wide_pipe: Optional[PipeEnd] = None  # a pipe that is used to send results/exceptions to this process
+    _global_sender_pipe: Optional[PipeEnd] = None  # a pipe that is used to send results/exceptions to this process
     _pipe_waiter_thread: Optional[threading.Thread] = None  # process-specific thread that receives results/exceptions
     _pipe_waiter_thread: Optional[threading.Thread] = None  # process-specific thread that receives results/exceptions
-    _active_futures: Optional[Dict[UID, Type[ref][MPFuture]]] = None  # non-done futures originated from this process
-    _status_requests: Optional[Dict[UID, Tuple[MPFuture, threading.Event]]] = None  # futures to be updated by origin
+    _active_futures: Optional[Dict[UID, "ref[MPFuture]"]] = None  # non-done futures originated from this process
     _active_pid: Optional[PID] = None  # pid of currently active process; used to handle forks natively
     _active_pid: Optional[PID] = None  # pid of currently active process; used to handle forks natively
 
 
-    SOFT_UPDATE_TIMEOUT = 0.5  # seconds spent awaiting status update before warning is printed
-    HARD_UPDATE_TIMEOUT = 10.0  # seconds spent awaiting status update before future is automatically cancelled
-
-    def __init__(self, *, synchronize: bool = True, use_lock: bool = True):
-        super().__init__()
-        self.synchronize = synchronize
+    def __init__(self, *, use_lock: bool = True):
         self._origin_pid, self._uid = os.getpid(), uuid.uuid4().int
         self._origin_pid, self._uid = os.getpid(), uuid.uuid4().int
+        self._shared_state_code = SharedBytes.next()
+        self._state_cache: Dict[State, State] = {}
+        # maps the state in shared memory to a locally cached override, so that updates made on the setter side
+        # become visible there immediately; a dict-based cache works because a future visits each state at most once
+
+        base.Future.__init__(self)  # parent init is deferred because it uses self._shared_state_code
         self._state, self._result, self._exception = base.PENDING, None, None
         self._state, self._result, self._exception = base.PENDING, None, None
         self._use_lock = use_lock
         self._use_lock = use_lock
 
 
-        self._initialize_backend_if_necessary()
+        if self._origin_pid != MPFuture._active_pid:
+            with MPFuture._initialization_lock:
+                if self._origin_pid != MPFuture._active_pid:
+                    # note: the second if is intentional, see https://en.wikipedia.org/wiki/Double-checked_locking
+                    self._initialize_mpfuture_backend()
         assert self._uid not in MPFuture._active_futures
         assert self._uid not in MPFuture._active_futures
         MPFuture._active_futures[self._uid] = ref(self)
         MPFuture._active_futures[self._uid] = ref(self)
-        self._sender_pipe = MPFuture._process_wide_pipe
+        self._sender_pipe = MPFuture._global_sender_pipe
 
 
         try:
         try:
             self._loop = asyncio.get_event_loop()
             self._loop = asyncio.get_event_loop()
@@ -91,97 +118,83 @@ class MPFuture(base.Future, Generic[ResultType]):
         except RuntimeError:
         except RuntimeError:
             self._loop, self._aio_event = None, None
             self._loop, self._aio_event = None, None
 
 
-    def _set_event_if_necessary(self):
-        if self._aio_event is None or self._aio_event.is_set():
-            return
+    @property
+    def _state(self) -> State:
+        shared_state = ALL_STATES[self._shared_state_code.item()]
+        return self._state_cache.get(shared_state, shared_state)
+
+    @_state.setter
+    def _state(self, new_state: State):
+        self._shared_state_code[...] = ALL_STATES.index(new_state)
+        if self._state in TERMINAL_STATES and self._loop is not None and not self._aio_event.is_set():
+            self._set_event_threadsafe()
+
+    def _set_event_threadsafe(self):
         try:
         try:
-            loop = asyncio.get_running_loop()
+            running_loop = asyncio.get_running_loop()
         except RuntimeError:
         except RuntimeError:
-            loop = None
+            running_loop = None
 
 
         async def _event_setter():
         async def _event_setter():
             self._aio_event.set()
             self._aio_event.set()
 
 
-        if self._loop.is_running() and loop == self.get_loop():
+        if self._loop.is_running() and running_loop == self._loop:
             asyncio.create_task(_event_setter())
             asyncio.create_task(_event_setter())
-        elif self._loop.is_running() and loop != self.get_loop():
+        elif self._loop.is_running() and running_loop != self._loop:
             asyncio.run_coroutine_threadsafe(_event_setter(), self._loop)
             asyncio.run_coroutine_threadsafe(_event_setter(), self._loop)
         else:
         else:
             self._loop.run_until_complete(_event_setter())
             self._loop.run_until_complete(_event_setter())
 
 
     @classmethod
     @classmethod
-    def _initialize_backend_if_necessary(cls):
+    def _initialize_mpfuture_backend(cls):
         pid = os.getpid()
         pid = os.getpid()
-        if MPFuture._active_pid != pid:
-            with MPFuture._initialization_lock:
-                if MPFuture._active_pid != pid:
-                    # note: the second if is intentional, see https://en.wikipedia.org/wiki/Double-checked_locking
-                    logger.debug(f"Initializing MPFuture backend for pid {pid}")
-                    receiver_pipe, cls._process_wide_pipe = mp.Pipe(duplex=False)
-                    cls._active_pid, cls._active_futures, cls._status_requests = pid, {}, {}
-                    cls._pipe_waiter_thread = threading.Thread(
-                        target=cls._process_updates_in_background,
-                        args=[receiver_pipe],
-                        name=f"{__name__}.BACKEND",
-                        daemon=True,
-                    )
-                    cls._pipe_waiter_thread.start()
+        logger.debug(f"Initializing MPFuture backend for pid {pid}")
+
+        receiver_pipe, cls._global_sender_pipe = mp.Pipe(duplex=False)
+        cls._active_pid, cls._active_futures = pid, {}
+        cls._pipe_waiter_thread = threading.Thread(
+            target=cls._process_updates_in_background, args=[receiver_pipe], name=f"{__name__}.BACKEND", daemon=True
+        )
+        cls._pipe_waiter_thread.start()
+
+    @staticmethod
+    def reset_backend():
+        """Last-resort function to reset internals of MPFuture. All current MPFuture instances will be broken"""
+        MPFuture._active_pid = None
+        MPFuture._initialization_lock = mp.Lock()
+        MPFuture._update_lock = mp.Lock()
+        SharedBytes._lock = mp.Lock()
 
 
     @classmethod
     @classmethod
     def _process_updates_in_background(cls, receiver_pipe: mp.connection.Connection):
     def _process_updates_in_background(cls, receiver_pipe: mp.connection.Connection):
         pid = os.getpid()
         pid = os.getpid()
         while True:
         while True:
             try:
             try:
-                uid, msg_type, payload = receiver_pipe.recv()
+                if cls._pipe_waiter_thread is not threading.current_thread():
+                    break  # backend was reset, a new background thread has started
+
+                uid, update_type, payload = receiver_pipe.recv()
                 future = None
                 future = None
-                future_ref = cls._active_futures.get(uid)
+                future_ref = cls._active_futures.pop(uid, None)
                 if future_ref is not None:
                 if future_ref is not None:
                     future = future_ref()
                     future = future_ref()
 
 
-                if msg_type == MessageType.STATE_REQUEST:
-                    future_state = None if future is None else future.__getstate__()
-                    use_lock, return_pipe = payload
-                    with MPFuture._update_lock if use_lock else nullcontext():
-                        return_pipe.send((uid, MessageType.STATE_RESPONSE, future_state))
-
-                elif msg_type == MessageType.STATE_RESPONSE:
-                    future, state_updated_event = cls._status_requests.get(uid, (None, None))
-                    if future is None:
-                        logger.debug("Received a state update for a future that does not await status update.")
-                    else:
-                        if payload is not None:
-                            future.__setstate__(payload)
-                        else:
-                            base.Future.cancel(future)
-                        state_updated_event.set()
-
-                elif future is None:
-                    logger.debug(
-                        f"Received {msg_type} for MPFuture uid={uid}, but future is already done or destroyed"
-                    )
-                elif msg_type == MessageType.RESULT:
+                if future is None:
+                    logger.debug(f"Ignoring update to future with uid={uid}: the future is already done or destroyed")
+                elif update_type == UpdateType.RESULT:
                     future.set_result(payload)
                     future.set_result(payload)
-                elif msg_type == MessageType.EXCEPTION:
+                elif update_type == UpdateType.EXCEPTION:
                     future.set_exception(payload)
                     future.set_exception(payload)
-                elif msg_type == MessageType.RUNNING:
-                    try:
-                        future.set_running_or_notify_cancel()
-                    except (InvalidStateError, RuntimeError) as e:
-                        logger.debug(f"Could not set MPFuture (uid={uid}) to running due to {e}", exc_info=True)
-                elif msg_type == MessageType.CANCEL:
+                elif update_type == UpdateType.CANCEL:
                     future.cancel()
                     future.cancel()
                 else:
                 else:
-                    raise RuntimeError(f"Received unexpected update type {msg_type}")
-
-                if future is None or future.done():
-                    cls._active_futures.pop(uid, None)
-
+                    raise RuntimeError(f"Received unexpected update type {update_type}")
             except (BrokenPipeError, EOFError, ConnectionError):
             except (BrokenPipeError, EOFError, ConnectionError):
                 logger.debug(f"Update pipe was was shut down unexpectedly (pid={pid})")
                 logger.debug(f"Update pipe was shut down unexpectedly (pid={pid})")
             except Exception as e:
             except Exception as e:
                 logger.exception(f"Could not retrieve update: caught {repr(e)} (pid={pid})")
                 logger.exception(f"Could not retrieve update: caught {repr(e)} (pid={pid})")
 
 
-    def _send_update(self, update_type: MessageType, payload: Any = None):
+    def _send_update(self, update_type: UpdateType, payload: Any = None):
         """This method sends result, exception or cancel to the MPFuture origin."""
         """This method sends result, exception or cancel to the MPFuture origin."""
         try:
             with MPFuture._update_lock if self._use_lock else nullcontext():
@@ -189,110 +202,76 @@ class MPFuture(base.Future, Generic[ResultType]):
         except (ConnectionError, BrokenPipeError, EOFError) as e:
             logger.debug(f"No updates were sent: pipe to origin process was broken ({e}).", exc_info=True)

-    def _synchronize_if_necessary(self):
-        if not self.synchronize or os.getpid() == self._origin_pid or self._state in TERMINAL_STATES:
-            return
-
-        self._initialize_backend_if_necessary()
-
-        status_updated = threading.Event()
-        _, existing_status_event = self._status_requests.setdefault(self._uid, (self, status_updated))
-        # this line checks if another thread is synchronizing concurrently, assuming that setdefault to be atomic
-
-        if existing_status_event != status_updated:
-            existing_status_event.wait(MPFuture.HARD_UPDATE_TIMEOUT)
-            return
-
-        # otherwise create a new request for synchronization
-
-        try:
-            with MPFuture._update_lock if self._use_lock else nullcontext():
-                payload = (self._use_lock, self._process_wide_pipe)
-                self._sender_pipe.send((self._uid, MessageType.STATE_REQUEST, payload))
-            status_updated.wait(MPFuture.SOFT_UPDATE_TIMEOUT)
-            if not status_updated.is_set():
-                logger.warning(f"Status update took over {MPFuture.SOFT_UPDATE_TIMEOUT}, expect performance issues")
-                status_updated.wait(MPFuture.HARD_UPDATE_TIMEOUT - MPFuture.SOFT_UPDATE_TIMEOUT)
-                if not status_updated.is_set() and not self.cancel():
-                    with suppress(InvalidStateError, RuntimeError):
-                        self.set_exception(
-                            TimeoutError(
-                                f"Status update took over {MPFuture.HARD_UPDATE_TIMEOUT} seconds, "
-                                f"MPFuture is cancelled"
-                            )
-                        )
-                    status_updated.set()  # this triggers any concurrent _synchronize_if_necessary calls to finish
-        except (ConnectionError, BrokenPipeError, EOFError) as e:
-            logger.error(f"MPFuture was cancelled because sender pipe is broken. Origin process is probably down.")
-            if not self.cancel():
-                with suppress(InvalidStateError, RuntimeError):
-                    self.set_exception(e)
-        finally:
-            self._status_requests.pop(self._uid, None)
-
     def set_result(self, result: ResultType):
-        if self._state in TERMINAL_STATES:
-            raise InvalidStateError(f"Can't set_result to a future that is {self._state} ({self._uid})")
-        elif os.getpid() == self._origin_pid:
+        if os.getpid() == self._origin_pid:
+            super().set_result(result)
             MPFuture._active_futures.pop(self._uid, None)
-            self._set_event_if_necessary()
+        elif self._state in TERMINAL_STATES:
+            raise InvalidStateError(f"Can't set_result to a future that is {self._state} ({self._uid})")
         else:
-            self._send_update(MessageType.RESULT, result)
-        super().set_result(result)
+            self._state_cache[self._state], self._result = base.FINISHED, result
+            self._send_update(UpdateType.RESULT, result)

     def set_exception(self, exception: Optional[BaseException]):
-        if self._state in TERMINAL_STATES:
-            raise InvalidStateError(f"Can't set_exception to a future that is {self._state} ({self._uid})")
-        elif os.getpid() == self._origin_pid:
+        if os.getpid() == self._origin_pid:
+            super().set_exception(exception)
             MPFuture._active_futures.pop(self._uid, None)
-            self._set_event_if_necessary()
+        elif self._state in TERMINAL_STATES:
+            raise InvalidStateError(f"Can't set_exception to a future that is {self._state} ({self._uid})")
         else:
-            self._send_update(MessageType.EXCEPTION, exception)
-        super().set_exception(exception)
+            self._state_cache[self._state], self._exception = base.FINISHED, exception
+            self._send_update(UpdateType.EXCEPTION, exception)

     def cancel(self) -> bool:
-        if self._state in [base.RUNNING, base.FINISHED]:
-            return False
-        elif os.getpid() == self._origin_pid:
+        if os.getpid() == self._origin_pid:
             MPFuture._active_futures.pop(self._uid, None)
-            self._set_event_if_necessary()
+            return super().cancel()
+        elif self._state in [base.RUNNING, base.FINISHED]:
+            return False
         else:
-            self._send_update(MessageType.CANCEL)
-        return super().cancel()
+            self._state_cache[self._state] = base.CANCELLED
+            self._send_update(UpdateType.CANCEL)
+            return True

     def set_running_or_notify_cancel(self):
-        """if synchronize is set to False, this future will ignore any state changes from origin"""
-        self._synchronize_if_necessary()
-        try:
-            is_running = super().set_running_or_notify_cancel()
-            if is_running and os.getpid() != self._origin_pid:
-                self._send_update(MessageType.RUNNING)
-            return is_running
-        except RuntimeError as e:
-            raise InvalidStateError(str(e))
+        if self._state == base.PENDING:
+            self._state = base.RUNNING
+            return True
+        elif self._state == base.CANCELLED:
+            return False
+        else:
+            raise InvalidStateError(
+                f"Can't set_running_or_notify_cancel when future is in {self._state} ({self._uid})"
+            )

     def result(self, timeout: Optional[float] = None) -> ResultType:
         if self._state not in TERMINAL_STATES:
             if os.getpid() != self._origin_pid:
                 raise RuntimeError("Only the process that created MPFuture can await result")
-        return super().result(timeout)
+            return super().result(timeout)
+        elif self._state == base.CANCELLED:
+            raise base.CancelledError()
+        elif self._exception:
+            raise self._exception
+        else:
+            return self._result

     def exception(self, timeout: Optional[float] = None) -> Optional[BaseException]:
         if self._state not in TERMINAL_STATES:
             if os.getpid() != self._origin_pid:
                 raise RuntimeError("Only the process that created MPFuture can await exception")
-        return super().exception(timeout)
+            return super().exception(timeout)
+        elif self._state == base.CANCELLED:
+            raise base.CancelledError()
+        return self._exception

     def done(self) -> bool:
-        self._synchronize_if_necessary()
         return self._state in TERMINAL_STATES

     def running(self):
-        self._synchronize_if_necessary()
         return self._state == base.RUNNING

     def cancelled(self):
-        self._synchronize_if_necessary()
         return self._state == base.CANCELLED

     def add_done_callback(self, callback: Callable[[MPFuture], None]):
@@ -300,9 +279,6 @@ class MPFuture(base.Future, Generic[ResultType]):
             raise RuntimeError("Only the process that created MPFuture can set callbacks")
         return super().add_done_callback(callback)

-    def get_loop(self) -> Optional[asyncio.BaseEventLoop]:
-        return self._loop
-
     def __await__(self):
         if not self._aio_event:
             raise RuntimeError("Can't await: MPFuture was created with no event loop")
@@ -320,9 +296,8 @@ class MPFuture(base.Future, Generic[ResultType]):

     def __getstate__(self):
         return dict(
-            synchronize=self.synchronize,
             _sender_pipe=self._sender_pipe,
-            _state=self._state,
+            _shared_state_code=self._shared_state_code,
             _origin_pid=self._origin_pid,
             _uid=self._uid,
             _use_lock=self._use_lock,
@@ -331,12 +306,13 @@ class MPFuture(base.Future, Generic[ResultType]):
         )

     def __setstate__(self, state):
-        self.synchronize = state["synchronize"]
         self._sender_pipe = state["_sender_pipe"]
-        self._state, self._origin_pid, self._uid = state["_state"], state["_origin_pid"], state["_uid"]
+        self._shared_state_code = state["_shared_state_code"]
+        self._origin_pid, self._uid = state["_origin_pid"], state["_uid"]
         self._result, self._exception = state["_result"], state["_exception"]
         self._use_lock = state["_use_lock"]

         self._waiters, self._done_callbacks = [], []
         self._condition = threading.Condition()
         self._aio_event, self._loop = None, None
+        self._state_cache = {}
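
Note on the reworked MPFuture above: a result, exception or cancellation set outside the origin process is now cached locally (via _state_cache) and forwarded to the origin over the shared pipe, instead of being synchronized on demand. A minimal usage sketch under that assumption (the exact constructor options may differ):

import multiprocessing as mp

from hivemind.utils.mpfuture import MPFuture


def _worker(future: MPFuture):
    # runs in a child process: the state change is cached here and sent to the origin over the pipe
    future.set_result(42)


if __name__ == "__main__":
    future = MPFuture()  # must be created in the origin process
    mp.Process(target=_worker, args=(future,)).start()
    print(future.result(timeout=5))  # only the origin process may await the result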

+ 1 - 2
hivemind/utils/networking.py

@@ -5,7 +5,6 @@ from typing import Optional, Sequence

 from multiaddr import Multiaddr

-
 Hostname, Port = str, int  # flavour types
 Endpoint = str  # e.g. 1.2.3.4:1337 or [2a21:6с8:b192:2105]:8888, https://networkengineering.stackexchange.com/a/9435
 LOCALHOST = "127.0.0.1"
@@ -31,7 +30,7 @@ def strip_port(endpoint: Endpoint) -> Hostname:
     return endpoint[: endpoint.rindex(":")] if maybe_port.isdigit() or maybe_port == "*" else endpoint


-def find_open_port(params=(socket.AF_INET, socket.SOCK_STREAM), opt=(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)):
+def get_free_port(params=(socket.AF_INET, socket.SOCK_STREAM), opt=(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)):
     """Finds a tcp port that can be occupied with a socket with *params and use *opt options"""
     try:
         with closing(socket.socket(*params)) as sock:
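
The renamed helper keeps its behaviour: it presumably lets the OS pick a free port via a throwaway socket and returns that port number. A short usage sketch (note the usual race: another process may grab the port between the call and your own bind):

import socket

from hivemind.utils.networking import get_free_port

port = get_free_port()
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
    sock.bind(("127.0.0.1", port))  # may still fail if something else took the port first
    sock.listen()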

+ 1 - 1
hivemind/utils/serializer.py

@@ -1,6 +1,6 @@
 """ A unified interface for several common serialization methods """
-from typing import Dict, Any
 from abc import ABC, abstractmethod
+from typing import Any, Dict

 import msgpack


+ 1 - 1
hivemind/utils/tensor_descr.py

@@ -1,5 +1,5 @@
 import warnings
-from dataclasses import dataclass, asdict
+from dataclasses import asdict, dataclass

 import torch


+ 2 - 1
hivemind/utils/timed_storage.py

@@ -1,10 +1,11 @@
 """ A dictionary-like storage that stores items until a specified expiration time or up to a limited size """
 from __future__ import annotations
+
 import heapq
 import time
 from contextlib import contextmanager
-from typing import TypeVar, Generic, Optional, Dict, List, Iterator, Tuple
 from dataclasses import dataclass
+from typing import Dict, Generic, Iterator, List, Optional, Tuple, TypeVar

 KeyType = TypeVar("KeyType")
 ValueType = TypeVar("ValueType")
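
For context, this module provides TimedStorage, a dict-like container whose entries expire at a given DHT timestamp. A rough usage sketch, assuming the store/get interface used elsewhere in hivemind (get returns a (value, expiration_time) pair or None once expired):

from hivemind.utils.timed_storage import TimedStorage, get_dht_time

storage: TimedStorage[str, int] = TimedStorage()
storage.store("answer", 42, expiration_time=get_dht_time() + 10)

entry = storage.get("answer")  # ValueWithExpiration or None if the entry has already expired
if entry is not None:
    print(entry.value, entry.expiration_time)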

+ 7 - 0
pyproject.toml

@@ -1,3 +1,10 @@
 [tool.black]
 line-length = 119
 required-version = "21.6b0"
+
+[tool.isort]
+profile = "black"
+line_length = 119
+combine_as_imports = true
+combine_star = true
+known_local_folder = ["arguments", "test_utils", "tests", "utils"]
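
With this profile, isort arranges imports the way the reformatted files below do: standard library first, then third-party packages, then first-party hivemind modules, then the local folders listed in known_local_folder, each group alphabetized with black-compatible spacing. A hypothetical illustration of the resulting order:

# after running `isort .` with the profile above
import asyncio
import multiprocessing as mp

import pytest
import torch

from hivemind.utils.logging import get_logger

from test_utils.dht_swarms import launch_dht_instances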

+ 1 - 1
requirements-dev.txt

@@ -2,8 +2,8 @@ pytest
 pytest-forked
 pytest-asyncio
 pytest-cov
-codecov
 tqdm
 scikit-learn
 black==21.6b0
+isort
 psutil

+ 22 - 2
tests/conftest.py

@@ -1,15 +1,33 @@
+import asyncio
 import gc
+import multiprocessing as mp
 from contextlib import suppress

 import psutil
 import pytest

-from hivemind.utils import get_logger
-
+from hivemind.utils.logging import get_logger
+from hivemind.utils.mpfuture import MPFuture, SharedBytes

 logger = get_logger(__name__)


+@pytest.fixture
+def event_loop():
+    """
+    This overrides the ``event_loop`` fixture from pytest-asyncio
+    (e.g. to make it compatible with ``asyncio.subprocess``).
+
+    This fixture is identical to the original one but does not call ``loop.close()`` in the end.
+    Indeed, at this point, the loop is already stopped (i.e. next tests are free to create new loops).
+    However, finalizers of objects created in the current test may reference the current loop and fail if it is closed.
+    For example, this happens while using ``asyncio.subprocess`` (the ``asyncio.subprocess.Process`` finalizer
+    fails if the loop is closed, but works if the loop is only stopped).
+    """
+
+    yield asyncio.get_event_loop()
+
+
 @pytest.fixture(autouse=True, scope="session")
 def cleanup_children():
     yield
@@ -26,3 +44,5 @@ def cleanup_children():
         for child in children:
             with suppress(psutil.NoSuchProcess):
                 child.kill()
+
+    MPFuture.reset_backend()
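
A sketch of the kind of test the overridden fixture protects (hypothetical, not part of this diff): with the stock pytest-asyncio fixture, the asyncio.subprocess.Process finalizer may run after loop.close() and raise, whereas with the override above the loop is only stopped, so the finalizer still works.

import asyncio

import pytest


@pytest.mark.asyncio
async def test_subprocess_finalizer_survives():
    # the Process object created here is finalized after the test; that requires a loop that is not closed
    proc = await asyncio.create_subprocess_exec("true")
    assert await proc.wait() == 0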

+ 28 - 47
tests/test_allreduce.py

@@ -3,16 +3,15 @@ import random
 import time
 from typing import Sequence

-import grpc
 import pytest
 import torch

-from hivemind import aenumerate, Endpoint
+from hivemind import aenumerate
 from hivemind.averaging.allreduce import AllReduceRunner, AveragingMode
 from hivemind.averaging.partition import TensorPartContainer, TensorPartReducer
-from hivemind.proto import averaging_pb2_grpc
+from hivemind.p2p import P2P, StubBase
 from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.utils import deserialize_torch_tensor, ChannelCache
+from hivemind.utils import deserialize_torch_tensor


 @pytest.mark.forked
@@ -152,46 +151,34 @@ async def test_reducer(num_senders: int, num_parts: int, synchronize_prob: float
             assert torch.allclose(averaging_result, reference_tensor, rtol=1e-3, atol=1e-5)


-class AllreduceRunnerForTesting(AllReduceRunner):
-    """a version of AllReduceRunner that was monkey-patched to accept custom endpoint names"""
-
-    def __init__(self, *args, peer_endpoints, **kwargs):
-        self.__peer_endpoints = peer_endpoints
-        super().__init__(*args, **kwargs)
-
-    def _get_peer_stub(self, peer: Endpoint) -> averaging_pb2_grpc.DecentralizedAveragingStub:
-        return ChannelCache.get_stub(
-            self.__peer_endpoints[peer], averaging_pb2_grpc.DecentralizedAveragingStub, aio=True
-        )
-
-
 NODE, CLIENT, AUX = AveragingMode.NODE, AveragingMode.CLIENT, AveragingMode.AUX


 @pytest.mark.parametrize(
-    "peer_modes, averaging_weights, peer_fractions",
+    "peer_modes, averaging_weights, peer_fractions, part_size_bytes",
     [
-        ((NODE, NODE, NODE, NODE), (1, 1, 1, 1), (1, 1, 1, 1)),
-        ((NODE, NODE, NODE, NODE), (0.1, 0.2, 0.3, 0.4), (1, 1, 1, 1)),
-        ((NODE, NODE, NODE, NODE), (1, 1, 1, 1), (1, 2, 3, 0)),
-        ((NODE, NODE, NODE, CLIENT), (1, 1, 1, 1), (1, 2, 3, 0)),
-        ((NODE, NODE, NODE, AUX), (1, 1, 1, 0), (1, 2, 3, 4)),
-        ((NODE, NODE, NODE, NODE), (0.15, 0.0, 0.35, 0.45), (1, 1, 1, 1)),
-        ((NODE, AUX, NODE, CLIENT), (0.15, 0.0, 0.35, 0.45), (150, 200, 67, 0)),
-        ((AUX, AUX, AUX, AUX), (0.0, 0.0, 0.0, 0.0), (1, 2, 3, 4)),
+        ((NODE, NODE, NODE, NODE), (1, 1, 1, 1), (1, 1, 1, 1), 2 ** 20),
+        ((NODE, NODE, NODE, NODE), (0.1, 0.2, 0.3, 0.4), (1, 1, 1, 1), 2 ** 20),
+        ((NODE, NODE, NODE, NODE), (1, 1, 1, 1), (1, 2, 3, 0), 2 ** 20),
+        ((NODE, NODE, NODE, CLIENT), (1, 1, 1, 1), (1, 2, 3, 0), 2 ** 20),
+        ((NODE, NODE, NODE, AUX), (1, 1, 1, 0), (1, 2, 3, 4), 2 ** 20),
+        ((NODE, NODE, NODE, NODE), (0.15, 0.0, 0.35, 0.45), (1, 1, 1, 1), 2 ** 20),
+        ((NODE, AUX, NODE, CLIENT), (0.15, 0.0, 0.35, 0.45), (150, 200, 67, 0), 2 ** 20),
+        ((NODE, AUX, NODE, CLIENT), (0.15, 0.0, 0.35, 0.45), (150, 200, 67, 0), 256),
+        ((NODE, AUX, NODE, CLIENT), (0.15, 0.0, 0.35, 0.45), (150, 200, 67, 0), 19),
+        ((AUX, AUX, AUX, AUX), (0.0, 0.0, 0.0, 0.0), (1, 2, 3, 4), 2 ** 20),
     ],
 )
-@pytest.mark.parametrize(
-    "part_size_bytes",
-    [2 ** 20, 256, 19],
-)
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions, part_size_bytes):
     """Run group allreduce protocol manually without grpc, see if the internal logic is working as intended"""

-    peers = "alice", "bob", "carol", "colab"
+    p2ps = [await P2P.create()]
+    visible_maddrs = await p2ps[0].get_visible_maddrs()
+    p2ps += await asyncio.gather(*[P2P.create(initial_peers=visible_maddrs) for _ in range(3)])

+    peers = [instance.peer_id for instance in p2ps]
     tensors_by_peer = {
         peer: [torch.randn(3, 128), torch.rand(32), torch.tensor(i, dtype=torch.float32)]
         for i, peer in enumerate(peers)
@@ -199,28 +186,22 @@ async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions,

     group_id = random.getrandbits(160).to_bytes(length=20, byteorder="big")

-    servers = []
     allreduce_protocols = []
-    peer_endpoints = {}
-
-    for peer in peers:
-        server = grpc.aio.server()
-        allreduce_protocol = AllreduceRunnerForTesting(
+    for p2p in p2ps:
+        allreduce_protocol = AllReduceRunner(
+            p2p=p2p,
+            servicer_type=AllReduceRunner,
+            prefix=None,
             group_id=group_id,
-            endpoint=peer,
-            tensors=[x.clone() for x in tensors_by_peer[peer]],
-            ordered_group_endpoints=peers,
+            tensors=[x.clone() for x in tensors_by_peer[p2p.peer_id]],
+            ordered_peer_ids=peers,
             peer_fractions=peer_fractions,
             modes=peer_modes,
             weights=averaging_weights,
-            peer_endpoints=peer_endpoints,
             part_size_bytes=part_size_bytes,
         )
-        averaging_pb2_grpc.add_DecentralizedAveragingServicer_to_server(allreduce_protocol, server)
-        peer_endpoints[peer] = f"127.0.0.1:{server.add_insecure_port('127.0.0.1:*')}"
+        await allreduce_protocol.add_p2p_handlers(p2p)
         allreduce_protocols.append(allreduce_protocol)
-        servers.append(server)
-        await server.start()

     async def _run_allreduce_inplace(allreduce: AllReduceRunner):
         async for tensor_index, tensor_delta in aenumerate(allreduce):
@@ -244,5 +225,5 @@ async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions,
         assert len(output_tensors) == len(targets_for_peer)
         assert all(torch.allclose(our, ref, atol=1e-6, rtol=0) for our, ref in zip(output_tensors, targets_for_peer))

-    for server in servers:
-        await server.stop(grace=1)
+    for instance in p2ps:
+        await instance.shutdown()
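
The pattern above (one bootstrap daemon plus peers joining through its visible multiaddrs) is how the rewritten tests build a local libp2p swarm. A condensed sketch of just that setup and teardown, using only calls that appear in the diff:

import asyncio

from hivemind.p2p import P2P


async def spawn_local_swarm(n_peers: int = 4) -> list:
    # hypothetical helper mirroring the test above: one bootstrap node, the rest join via its maddrs
    bootstrap = await P2P.create()
    maddrs = await bootstrap.get_visible_maddrs()
    others = await asyncio.gather(*[P2P.create(initial_peers=maddrs) for _ in range(n_peers - 1)])
    return [bootstrap, *others]


async def shutdown_swarm(p2ps) -> None:
    for instance in p2ps:
        await instance.shutdown()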

+ 1 - 2
tests/test_auth.py

@@ -5,11 +5,10 @@ import pytest

 from hivemind.proto import dht_pb2
 from hivemind.proto.auth_pb2 import AccessToken
-from hivemind.utils.auth import AuthRPCWrapper, AuthRole, TokenAuthorizerBase
+from hivemind.utils.auth import AuthRole, AuthRPCWrapper, TokenAuthorizerBase
 from hivemind.utils.crypto import RSAPrivateKey
 from hivemind.utils.logging import get_logger

-
 logger = get_logger(__name__)


+ 89 - 82
tests/test_averaging.py

@@ -1,4 +1,5 @@
 import random
 import random
+import time
 
 
 import numpy as np
 import numpy as np
 import pytest
 import pytest
@@ -9,46 +10,51 @@ import hivemind.averaging.averager
 from hivemind.averaging.allreduce import AveragingMode
 from hivemind.averaging.allreduce import AveragingMode
 from hivemind.averaging.key_manager import GroupKeyManager
 from hivemind.averaging.key_manager import GroupKeyManager
 from hivemind.averaging.load_balancing import load_balance_peers
 from hivemind.averaging.load_balancing import load_balance_peers
+from hivemind.p2p import PeerID
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.proto.runtime_pb2 import CompressionType
 
 
+from test_utils.dht_swarms import launch_dht_instances
+
 
 
 @pytest.mark.forked
 @pytest.mark.forked
 @pytest.mark.asyncio
 @pytest.mark.asyncio
 async def test_key_manager():
 async def test_key_manager():
+    dht = hivemind.DHT(start=True)
     key_manager = GroupKeyManager(
     key_manager = GroupKeyManager(
-        hivemind.DHT(start=True),
-        endpoint="localhvost",
+        dht,
         prefix="test_averaging",
         prefix="test_averaging",
         initial_group_bits="10110",
         initial_group_bits="10110",
         target_group_size=2,
         target_group_size=2,
     )
     )
+    alice = dht.peer_id
+    bob = PeerID(b"bob")
 
 
     t = hivemind.get_dht_time()
     t = hivemind.get_dht_time()
     key = key_manager.current_key
     key = key_manager.current_key
-    await key_manager.declare_averager(key, "localhvost", expiration_time=t + 60)
-    await key_manager.declare_averager(key, "localhvost2", expiration_time=t + 61)
+    await key_manager.declare_averager(key, alice, expiration_time=t + 60)
+    await key_manager.declare_averager(key, bob, expiration_time=t + 61)
 
 
     q1 = await key_manager.get_averagers(key, only_active=True)
     q1 = await key_manager.get_averagers(key, only_active=True)
 
 
-    await key_manager.declare_averager(key, "localhvost", expiration_time=t + 66)
+    await key_manager.declare_averager(key, alice, expiration_time=t + 66)
     q2 = await key_manager.get_averagers(key, only_active=True)
     q2 = await key_manager.get_averagers(key, only_active=True)
 
 
-    await key_manager.declare_averager(key, "localhvost2", expiration_time=t + 61, looking_for_group=False)
+    await key_manager.declare_averager(key, bob, expiration_time=t + 61, looking_for_group=False)
     q3 = await key_manager.get_averagers(key, only_active=True)
     q3 = await key_manager.get_averagers(key, only_active=True)
     q4 = await key_manager.get_averagers(key, only_active=False)
     q4 = await key_manager.get_averagers(key, only_active=False)
 
 
     q5 = await key_manager.get_averagers("nonexistent_key.0b0101", only_active=False)
     q5 = await key_manager.get_averagers("nonexistent_key.0b0101", only_active=False)
 
 
-    assert len(q1) == 2 and ("localhvost", t + 60) in q1 and ("localhvost2", t + 61) in q1
-    assert len(q2) == 2 and ("localhvost", t + 66) in q2 and ("localhvost2", t + 61) in q2
-    assert len(q3) == 1 and ("localhvost", t + 66) in q3
-    assert len(q4) == 2 and ("localhvost", t + 66) in q4 and ("localhvost2", t + 61) in q2
+    assert len(q1) == 2 and (alice, t + 60) in q1 and (bob, t + 61) in q1
+    assert len(q2) == 2 and (alice, t + 66) in q2 and (bob, t + 61) in q2
+    assert len(q3) == 1 and (alice, t + 66) in q3
+    assert len(q4) == 2 and (alice, t + 66) in q4 and (bob, t + 61) in q2
     assert len(q5) == 0
     assert len(q5) == 0
 
 
+    dht.shutdown()
 
 
-def _test_allreduce_once(n_clients, n_aux):
-    dht = hivemind.DHT(start=True)
 
 
+def _test_allreduce_once(n_clients, n_aux):
     n_peers = 4
     n_peers = 4
     modes = (
     modes = (
         [AveragingMode.CLIENT] * n_clients
         [AveragingMode.CLIENT] * n_clients
@@ -69,6 +75,7 @@ def _test_allreduce_once(n_clients, n_aux):
         for i in range(len(tensors1))
         for i in range(len(tensors1))
     ]
     ]
 
 
+    dht_instances = launch_dht_instances(len(peer_tensors))
     averagers = [
     averagers = [
         hivemind.averaging.DecentralizedAverager(
         hivemind.averaging.DecentralizedAverager(
             tensors,
             tensors,
@@ -77,11 +84,10 @@ def _test_allreduce_once(n_clients, n_aux):
             averaging_expiration=15,
             averaging_expiration=15,
             prefix="mygroup",
             prefix="mygroup",
             client_mode=mode == AveragingMode.CLIENT,
             client_mode=mode == AveragingMode.CLIENT,
-            listen_on="127.0.0.1:*",
             auxiliary=mode == AveragingMode.AUX,
             auxiliary=mode == AveragingMode.AUX,
             start=True,
             start=True,
         )
         )
-        for tensors, mode in zip(peer_tensors, modes)
+        for tensors, dht, mode in zip(peer_tensors, dht_instances, modes)
     ]
     ]
 
 
     futures = []
     futures = []
@@ -90,7 +96,7 @@ def _test_allreduce_once(n_clients, n_aux):
     for future in futures:
     for future in futures:
         result = future.result()
         result = future.result()
         for averager in averagers:
         for averager in averagers:
-            assert averager.endpoint in result
+            assert averager.peer_id in result
 
 
     for averager in averagers:
     for averager in averagers:
         if averager.mode != AveragingMode.AUX:
         if averager.mode != AveragingMode.AUX:
@@ -98,9 +104,8 @@ def _test_allreduce_once(n_clients, n_aux):
                 for ref, our in zip(reference, averaged_tensors):
                 for ref, our in zip(reference, averaged_tensors):
                     assert torch.allclose(ref, our, atol=1e-6)
                     assert torch.allclose(ref, our, atol=1e-6)
 
 
-    for averager in averagers:
-        averager.shutdown()
-    dht.shutdown()
+    for process in averagers + dht_instances:
+        process.shutdown()
 
 
 
 
 @pytest.mark.forked
 @pytest.mark.forked
@@ -118,8 +123,6 @@ def test_allreduce_once_edge_cases(n_clients, n_aux):
 
 
 @pytest.mark.forked
 @pytest.mark.forked
 def test_allreduce_weighted(n_client_mode_peers: int = 2):
 def test_allreduce_weighted(n_client_mode_peers: int = 2):
-    dht = hivemind.DHT(start=True)
-
     n_peers = 4
     n_peers = 4
     client_modes = [True] * n_client_mode_peers + [False] * (n_peers - n_client_mode_peers)
     client_modes = [True] * n_client_mode_peers + [False] * (n_peers - n_client_mode_peers)
     random.shuffle(client_modes)
     random.shuffle(client_modes)
@@ -128,6 +131,8 @@ def test_allreduce_weighted(n_client_mode_peers: int = 2):
     tensors2 = [torch.rand(123), torch.ones(3)]
     tensors2 = [torch.rand(123), torch.ones(3)]
     tensors3 = [-torch.rand(123), torch.arange(3).to(torch.float32)]
     tensors3 = [-torch.rand(123), torch.arange(3).to(torch.float32)]
     tensors4 = [torch.randn(123) ** 3, torch.arange(3).to(torch.float32) / 2]
     tensors4 = [torch.randn(123) ** 3, torch.arange(3).to(torch.float32) / 2]
+
+    dht_instances = launch_dht_instances(4)
     averagers = [
     averagers = [
         hivemind.averaging.DecentralizedAverager(
         hivemind.averaging.DecentralizedAverager(
             tensors,
             tensors,
@@ -136,11 +141,11 @@ def test_allreduce_weighted(n_client_mode_peers: int = 2):
             averaging_expiration=15,
             averaging_expiration=15,
             prefix="mygroup",
             prefix="mygroup",
             client_mode=client_mode,
             client_mode=client_mode,
-            listen_on="127.0.0.1:*",
             start=True,
             start=True,
         )
         )
-        for tensors, client_mode in zip([tensors1, tensors2, tensors3, tensors4], client_modes)
+        for tensors, dht, client_mode in zip([tensors1, tensors2, tensors3, tensors4], dht_instances, client_modes)
     ]
     ]
+
     weights = list(map(float, np.random.rand(len(averagers)) * 10 + 0.01))
     weights = list(map(float, np.random.rand(len(averagers)) * 10 + 0.01))
     reference = [
     reference = [
         (tensors1[i] * weights[0] + tensors2[i] * weights[1] + tensors3[i] * weights[2] + tensors4[i] * weights[3])
         (tensors1[i] * weights[0] + tensors2[i] * weights[1] + tensors3[i] * weights[2] + tensors4[i] * weights[3])
@@ -159,15 +164,13 @@ def test_allreduce_weighted(n_client_mode_peers: int = 2):
             for ref, our in zip(reference, averaged_tensors):
             for ref, our in zip(reference, averaged_tensors):
                 assert torch.allclose(ref, our, atol=1e-6)
                 assert torch.allclose(ref, our, atol=1e-6)
 
 
-    for averager in averagers:
-        averager.shutdown()
-    dht.shutdown()
+    for process in averagers + dht_instances:
+        process.shutdown()
 
 
 
 
 @pytest.mark.forked
 @pytest.mark.forked
 def test_allreduce_compression():
 def test_allreduce_compression():
     """this test ensures that compression works correctly when multiple tensors have different compression types"""
     """this test ensures that compression works correctly when multiple tensors have different compression types"""
-    dht = hivemind.DHT(start=True)
 
 
     tensors1 = [torch.linspace(0, 500, 1000) ** 0.5, torch.randn(1000)]
     tensors1 = [torch.linspace(0, 500, 1000) ** 0.5, torch.randn(1000)]
     tensors2 = [torch.linspace(300, 800, 1000) ** 0.5, torch.randn(1000)]
     tensors2 = [torch.linspace(300, 800, 1000) ** 0.5, torch.randn(1000)]
@@ -176,9 +179,10 @@ def test_allreduce_compression():
     FLOAT16, UINT8 = CompressionType.FLOAT16, CompressionType.UNIFORM_8BIT
     FLOAT16, UINT8 = CompressionType.FLOAT16, CompressionType.UNIFORM_8BIT
 
 
     for compression_type_pair in [(FLOAT16, FLOAT16), (FLOAT16, UINT8), (UINT8, FLOAT16), (UINT8, UINT8)]:
     for compression_type_pair in [(FLOAT16, FLOAT16), (FLOAT16, UINT8), (UINT8, FLOAT16), (UINT8, UINT8)]:
+        dht_instances = launch_dht_instances(2)
         averager1 = hivemind.averaging.DecentralizedAverager(
         averager1 = hivemind.averaging.DecentralizedAverager(
             [x.clone() for x in tensors1],
             [x.clone() for x in tensors1],
-            dht=dht,
+            dht=dht_instances[0],
             compression_type=compression_type_pair,
             compression_type=compression_type_pair,
             client_mode=True,
             client_mode=True,
             target_group_size=2,
             target_group_size=2,
@@ -187,11 +191,10 @@ def test_allreduce_compression():
         )
         )
         averager2 = hivemind.averaging.DecentralizedAverager(
         averager2 = hivemind.averaging.DecentralizedAverager(
             [x.clone() for x in tensors2],
             [x.clone() for x in tensors2],
-            dht=dht,
+            dht=dht_instances[1],
             compression_type=compression_type_pair,
             compression_type=compression_type_pair,
             target_group_size=2,
             target_group_size=2,
             prefix="mygroup",
             prefix="mygroup",
-            listen_on="127.0.0.1:*",
             start=True,
             start=True,
         )
         )
 
 
@@ -201,6 +204,9 @@ def test_allreduce_compression():
         with averager1.get_tensors() as averaged_tensors:
         with averager1.get_tensors() as averaged_tensors:
             results[compression_type_pair] = averaged_tensors
             results[compression_type_pair] = averaged_tensors
 
 
+        for instance in [averager1, averager2] + dht_instances:
+            instance.shutdown()
+
     assert torch.allclose(results[UINT8, FLOAT16][0], results[UINT8, UINT8][0])
     assert torch.allclose(results[UINT8, FLOAT16][0], results[UINT8, UINT8][0])
     assert torch.allclose(results[UINT8, FLOAT16][1], results[FLOAT16, FLOAT16][1])
     assert torch.allclose(results[UINT8, FLOAT16][1], results[FLOAT16, FLOAT16][1])
     assert torch.allclose(results[UINT8, UINT8][1], results[FLOAT16, UINT8][1])
     assert torch.allclose(results[UINT8, UINT8][1], results[FLOAT16, UINT8][1])
@@ -231,7 +237,7 @@ def compute_mean_std(averagers, unbiased=True):
 
 
 @pytest.mark.forked
 @pytest.mark.forked
 def test_allreduce_grid():
 def test_allreduce_grid():
-    dht = hivemind.DHT(start=True)
+    dht_instances = launch_dht_instances(8)
     averagers = [
     averagers = [
         hivemind.averaging.DecentralizedAverager(
         hivemind.averaging.DecentralizedAverager(
             averaged_tensors=[torch.randn(3)],
             averaged_tensors=[torch.randn(3)],
@@ -239,10 +245,9 @@ def test_allreduce_grid():
             target_group_size=2,
             target_group_size=2,
             prefix="mygroup",
             prefix="mygroup",
             initial_group_bits=bin(i // 2)[2:].rjust(2, "0"),
             initial_group_bits=bin(i // 2)[2:].rjust(2, "0"),
-            listen_on="127.0.0.1:*",
             start=True,
             start=True,
         )
         )
-        for i in range(8)
+        for i, dht in enumerate(dht_instances)
     ]
     ]
 
 
     [means0], [stds0] = compute_mean_std(averagers)
     [means0], [stds0] = compute_mean_std(averagers)
@@ -262,48 +267,41 @@ def test_allreduce_grid():
         else:
         else:
             assert torch.allclose(stds, torch.zeros_like(stds), atol=1e-6, rtol=0)
             assert torch.allclose(stds, torch.zeros_like(stds), atol=1e-6, rtol=0)
 
 
-    for averager in averagers:
-        averager.shutdown()
-    dht.shutdown()
+    for process in averagers + dht_instances:
+        process.shutdown()
 
 
 
 
 @pytest.mark.forked
 @pytest.mark.forked
-def test_allgather():
-    dht = hivemind.DHT(start=True)
+def test_allgather(n_averagers=8, target_group_size=4):
+    dht_instances = launch_dht_instances(n_averagers)
     averagers = [
     averagers = [
         hivemind.averaging.DecentralizedAverager(
         hivemind.averaging.DecentralizedAverager(
             [torch.ones(1)],
             [torch.ones(1)],
             dht=dht,
             dht=dht,
-            target_group_size=4,
+            target_group_size=target_group_size,
             averaging_expiration=15,
             averaging_expiration=15,
             prefix="mygroup",
             prefix="mygroup",
             initial_group_bits="000",
             initial_group_bits="000",
-            listen_on="127.0.0.1:*",
             start=True,
             start=True,
         )
         )
-        for _ in range(8)
+        for dht in dht_instances
     ]
     ]
 
 
     futures = []
     futures = []
     for i, averager in enumerate(averagers):
     for i, averager in enumerate(averagers):
         futures.append(averager.step(wait=False, gather=dict(batch_size=123 + i, foo="bar")))
         futures.append(averager.step(wait=False, gather=dict(batch_size=123 + i, foo="bar")))
 
 
-    assert len(set(repr(sorted(future.result())) for future in futures)) == 2
-
     reference_metadata = {
     reference_metadata = {
-        averager.endpoint: dict(batch_size=123 + i, foo="bar") for i, averager in enumerate(averagers)
+        averager.peer_id: dict(batch_size=123 + i, foo="bar") for i, averager in enumerate(averagers)
     }
     }
     for future in futures:
     for future in futures:
         gathered = future.result()
         gathered = future.result()
+        assert len(gathered) == target_group_size
+        for peer_id in gathered:
+            assert gathered[peer_id] == reference_metadata[peer_id]
 
 
-        assert len(gathered) == 4
-
-        for endpoint in gathered:
-            assert gathered[endpoint] == reference_metadata[endpoint]
-
-    for averager in averagers:
-        averager.shutdown()
-    dht.shutdown()
+    for process in averagers + dht_instances:
+        process.shutdown()
 
 
 
 
 def get_cost(vector_size, partitions, bandwidths):
 def get_cost(vector_size, partitions, bandwidths):
@@ -351,7 +349,7 @@ def test_load_balancing():
 
 
 @pytest.mark.forked
 @pytest.mark.forked
 def test_too_few_peers():
 def test_too_few_peers():
-    dht = hivemind.DHT(start=True)
+    dht_instances = launch_dht_instances(4)
     averagers = [
     averagers = [
         hivemind.averaging.DecentralizedAverager(
         hivemind.averaging.DecentralizedAverager(
             averaged_tensors=[torch.randn(3)],
             averaged_tensors=[torch.randn(3)],
@@ -361,23 +359,25 @@ def test_too_few_peers():
             request_timeout=0.5,
             request_timeout=0.5,
             prefix="mygroup",
             prefix="mygroup",
             initial_group_bits=bin(i)[2:].rjust(3, "0"),
             initial_group_bits=bin(i)[2:].rjust(3, "0"),
-            listen_on="127.0.0.1:*",
             start=True,
             start=True,
         )
         )
-        for i in range(4)
+        for i, dht in enumerate(dht_instances)
     ]
     ]
     step_futures = [averager.step(wait=False) for averager in averagers]
     step_futures = [averager.step(wait=False) for averager in averagers]
     for future in step_futures:
     for future in step_futures:
         assert len(future.result()) == 2
         assert len(future.result()) == 2
 
 
-    for averager in averagers:
-        averager.shutdown()
-    dht.shutdown()
+    for process in averagers + dht_instances:
+        process.shutdown()
 
 
 
 
+@pytest.mark.skip(
+    reason="The current implementation of elasticity (multi-stage averaging when num_peers > ~3 * target_group_size) "
+    "is incorrect (TODO @justheuristic)"
+)
 @pytest.mark.forked
 @pytest.mark.forked
 def test_overcrowded(num_peers=16):
 def test_overcrowded(num_peers=16):
-    dht = hivemind.DHT(start=True)
+    dht_instances = launch_dht_instances(num_peers)
     averagers = [
     averagers = [
         hivemind.averaging.DecentralizedAverager(
         hivemind.averaging.DecentralizedAverager(
             averaged_tensors=[torch.randn(3)],
             averaged_tensors=[torch.randn(3)],
@@ -387,18 +387,16 @@ def test_overcrowded(num_peers=16):
             request_timeout=0.5,
             request_timeout=0.5,
             prefix="mygroup",
             prefix="mygroup",
             initial_group_bits="",
             initial_group_bits="",
-            listen_on="127.0.0.1:*",
             start=True,
             start=True,
         )
         )
-        for _ in range(num_peers)
+        for dht in dht_instances
     ]
     ]
-    for t in range(5):
+    for _ in range(5):
         step_futures = [averager.step(wait=False, timeout=5) for averager in averagers]
         step_futures = [averager.step(wait=False, timeout=5) for averager in averagers]
         assert sum(len(future.result() or []) == 2 for future in step_futures) >= len(averagers) - 1
         assert sum(len(future.result() or []) == 2 for future in step_futures) >= len(averagers) - 1
 
 
-    for averager in averagers:
-        averager.shutdown()
-    dht.shutdown()
+    for process in averagers + dht_instances:
+        process.shutdown()
 
 
 
 
 @pytest.mark.forked
 @pytest.mark.forked
@@ -417,27 +415,22 @@ def test_load_state_from_peers():
             num_calls += 1
             num_calls += 1
             return super_metadata, super_tensors
             return super_metadata, super_tensors
 
 
-    dht_root = hivemind.DHT(start=True)
-    initial_peers = dht_root.get_visible_maddrs()
-    dht1 = hivemind.DHT(initial_peers=initial_peers, start=True)
+    dht_instances = launch_dht_instances(2)
     averager1 = TestAverager(
     averager1 = TestAverager(
         [torch.randn(3), torch.rand(5)],
         [torch.randn(3), torch.rand(5)],
-        dht=dht1,
+        dht=dht_instances[0],
         start=True,
         start=True,
         prefix="demo-run",
         prefix="demo-run",
         target_group_size=2,
         target_group_size=2,
-        listen_on="127.0.0.1:*",
     )
     )
 
 
-    dht2 = hivemind.DHT(initial_peers=initial_peers, start=True)
-    dht2.get("demo-run.all_averagers")
+    dht_instances[1].get("demo-run.all_averagers")
     averager2 = TestAverager(
     averager2 = TestAverager(
         [torch.randn(3), torch.rand(5)],
         [torch.randn(3), torch.rand(5)],
-        dht=dht2,
+        dht=dht_instances[1],
         start=True,
         start=True,
         prefix="demo-run",
         prefix="demo-run",
         target_group_size=2,
         target_group_size=2,
-        listen_on="127.0.0.1:*",
     )
     )
 
 
     assert num_calls == 0
     assert num_calls == 0
@@ -463,12 +456,19 @@ def test_load_state_from_peers():
     assert num_calls == 3
     assert num_calls == 3
     assert got_metadata == super_metadata
     assert got_metadata == super_metadata
 
 
+    for instance in [averager1, averager2] + dht_instances:
+        instance.shutdown()
+
 
 
 @pytest.mark.forked
 @pytest.mark.forked
 def test_getset_bits():
 def test_getset_bits():
     dht = hivemind.DHT(start=True)
     dht = hivemind.DHT(start=True)
     averager = hivemind.averaging.DecentralizedAverager(
     averager = hivemind.averaging.DecentralizedAverager(
-        [torch.randn(3)], dht=dht, start=True, prefix="test_prefix", target_group_size=2, listen_on="127.0.0.1:*"
+        [torch.randn(3)],
+        dht=dht,
+        start=True,
+        prefix="test_prefix",
+        target_group_size=2,
     )
     )
     averager.set_group_bits("00101011101010")
     averager.set_group_bits("00101011101010")
     assert averager.get_group_bits() == "00101011101010"
     assert averager.get_group_bits() == "00101011101010"
@@ -478,11 +478,9 @@ def test_getset_bits():
 def test_training_averager(n_steps: int = 10, n_dims: int = 16):
 def test_training_averager(n_steps: int = 10, n_dims: int = 16):
     torch.manual_seed(42)
     torch.manual_seed(42)
 
 
-    dht = hivemind.DHT(start=True)
+    dht_instances = launch_dht_instances(2)
     common_kwargs = {
     common_kwargs = {
-        "dht": dht,
         "start": True,
         "start": True,
-        "listen_on": "127.0.0.1:*",
         "prefix": "demo-run",
         "prefix": "demo-run",
         "target_group_size": 2,
         "target_group_size": 2,
     }
     }
@@ -490,13 +488,23 @@ def test_training_averager(n_steps: int = 10, n_dims: int = 16):
     x1 = torch.randn(n_dims, requires_grad=True)
     x1 = torch.randn(n_dims, requires_grad=True)
     opt1 = torch.optim.Adam([x1], lr=0.05)
     opt1 = torch.optim.Adam([x1], lr=0.05)
     averager1 = hivemind.averaging.TrainingAverager(
     averager1 = hivemind.averaging.TrainingAverager(
-        opt1, average_gradients=True, average_parameters=True, average_opt_statistics=["exp_avg_sq"], **common_kwargs
+        opt1,
+        average_gradients=True,
+        average_parameters=True,
+        average_opt_statistics=["exp_avg_sq"],
+        dht=dht_instances[0],
+        **common_kwargs
     )
     )
 
 
     x2 = torch.randn(n_dims, requires_grad=True)
     x2 = torch.randn(n_dims, requires_grad=True)
     opt2 = torch.optim.Adam([x2], lr=0.05)
     opt2 = torch.optim.Adam([x2], lr=0.05)
     averager2 = hivemind.averaging.TrainingAverager(
     averager2 = hivemind.averaging.TrainingAverager(
-        opt2, average_gradients=True, average_parameters=True, average_opt_statistics=["exp_avg_sq"], **common_kwargs
+        opt2,
+        average_gradients=True,
+        average_parameters=True,
+        average_opt_statistics=["exp_avg_sq"],
+        dht=dht_instances[1],
+        **common_kwargs
     )
     )
     a = torch.ones(n_dims)
     a = torch.ones(n_dims)
 
 
@@ -526,6 +534,5 @@ def test_training_averager(n_steps: int = 10, n_dims: int = 16):
         assert torch.allclose(opt1.state[x1]["exp_avg_sq"], stats_avg)
         assert torch.allclose(opt1.state[x1]["exp_avg_sq"], stats_avg)
         assert torch.allclose(opt2.state[x2]["exp_avg_sq"], stats_avg)
         assert torch.allclose(opt2.state[x2]["exp_avg_sq"], stats_avg)
 
 
-    averager1.shutdown()
-    averager2.shutdown()
-    dht.shutdown()
+    for instance in [averager1, averager2] + dht_instances:
+        instance.shutdown()

+ 4 - 4
tests/test_dht.py

@@ -7,12 +7,12 @@ from multiaddr import Multiaddr

 import hivemind

+from test_utils.dht_swarms import launch_dht_instances
+

 @pytest.mark.forked
 def test_get_store(n_peers=10):
-    peers = [hivemind.DHT(start=True)]
-    initial_peers = peers[0].get_visible_maddrs()
-    peers += [hivemind.DHT(initial_peers=initial_peers, start=True) for _ in range(n_peers - 1)]
+    peers = launch_dht_instances(n_peers)

     node1, node2 = random.sample(peers, 2)
     assert node1.store("key1", "value1", expiration_time=hivemind.get_dht_time() + 30)
@@ -104,5 +104,5 @@ async def test_dht_get_visible_maddrs():
     p2p = await hivemind.p2p.P2P.create(announce_maddrs=[dummy_endpoint])
     dht = hivemind.DHT(start=True, p2p=await p2p.replicate(p2p.daemon_listen_maddr))

-    assert dht.get_visible_maddrs() == [dummy_endpoint.encapsulate(f"/p2p/{p2p.id}")]
+    assert dht.get_visible_maddrs() == [dummy_endpoint.encapsulate(f"/p2p/{p2p.peer_id}")]
     dht.shutdown()
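
launch_dht_instances itself lives in tests/test_utils/dht_swarms.py and is not part of this section; judging by the code it replaces above, it is roughly equivalent to the following sketch (signature assumed):

from typing import List

import hivemind


def launch_dht_instances(n_peers: int, **kwargs) -> List[hivemind.DHT]:
    # the first node bootstraps the swarm, the rest join via its visible multiaddrs
    dhts = [hivemind.DHT(start=True, **kwargs)]
    initial_peers = dhts[0].get_visible_maddrs()
    dhts += [hivemind.DHT(initial_peers=initial_peers, start=True, **kwargs) for _ in range(n_peers - 1)]
    return dhts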

+ 2 - 2
tests/test_dht_crypto.py

@@ -1,15 +1,15 @@
 import dataclasses
-import pickle
 import multiprocessing as mp
+import pickle

 import pytest

 import hivemind
-from hivemind.utils.timed_storage import get_dht_time
 from hivemind.dht.crypto import RSASignatureValidator
 from hivemind.dht.node import DHTNode
 from hivemind.dht.validation import DHTRecord
 from hivemind.utils.crypto import RSAPrivateKey
+from hivemind.utils.timed_storage import get_dht_time


 def test_rsa_signature_validator():

+ 2 - 2
tests/test_dht_experts.py

@@ -6,11 +6,11 @@ import numpy as np
 import pytest

 import hivemind
-from hivemind.dht import DHTNode
 from hivemind import LOCALHOST
+from hivemind.dht import DHTNode
 from hivemind.moe.client.beam_search import MoEBeamSearcher
 from hivemind.moe.server import declare_experts, get_experts
-from hivemind.moe.server.expert_uid import UidEndpoint, is_valid_uid, is_valid_prefix, split_uid
+from hivemind.moe.server.expert_uid import UidEndpoint, is_valid_prefix, is_valid_uid, split_uid


 @pytest.mark.forked

+ 24 - 11
tests/test_dht_node.py

@@ -17,8 +17,8 @@ from hivemind.dht.protocol import DHTProtocol
 from hivemind.dht.storage import DictionaryDHTValue
 from hivemind.dht.storage import DictionaryDHTValue
 from hivemind.p2p import P2P, PeerID
 from hivemind.p2p import P2P, PeerID
 from hivemind.utils.logging import get_logger
 from hivemind.utils.logging import get_logger
-from test_utils.dht_swarms import launch_swarm_in_separate_processes, launch_star_shaped_swarm
 
 
+from test_utils.dht_swarms import launch_star_shaped_swarm, launch_swarm_in_separate_processes
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 
@@ -44,7 +44,7 @@ def run_protocol_listener(
     for peer_id in maddrs_to_peer_ids(initial_peers):
     for peer_id in maddrs_to_peer_ids(initial_peers):
         loop.run_until_complete(protocol.call_ping(peer_id))
         loop.run_until_complete(protocol.call_ping(peer_id))
 
 
-    maddr_conn.send((p2p.id, visible_maddrs))
+    maddr_conn.send((p2p.peer_id, visible_maddrs))
 
 
     async def shutdown():
     async def shutdown():
         await p2p.shutdown()
         await p2p.shutdown()
@@ -194,16 +194,27 @@ def test_empty_table():
 
 
 
 
 @pytest.mark.forked
 @pytest.mark.forked
-def test_dht_node():
+def test_dht_node(
+    n_peers: int = 20, n_sequential_peers: int = 5, parallel_rpc: int = 10, bucket_size: int = 5, num_replicas: int = 3
+):
     # step A: create a swarm of 50 dht nodes in separate processes
     # step A: create a swarm of 50 dht nodes in separate processes
     #         (first 5 created sequentially, others created in parallel)
     #         (first 5 created sequentially, others created in parallel)
-    processes, dht, swarm_maddrs = launch_swarm_in_separate_processes(n_peers=50, n_sequential_peers=5)
+
+    processes, dht, swarm_maddrs = launch_swarm_in_separate_processes(
+        n_peers=n_peers, n_sequential_peers=n_sequential_peers, bucket_size=bucket_size, num_replicas=num_replicas
+    )
 
 
     # step B: run 51-st node in this process
     # step B: run 51-st node in this process
     loop = asyncio.get_event_loop()
     loop = asyncio.get_event_loop()
     initial_peers = random.choice(swarm_maddrs)
     initial_peers = random.choice(swarm_maddrs)
     me = loop.run_until_complete(
     me = loop.run_until_complete(
-        DHTNode.create(initial_peers=initial_peers, parallel_rpc=10, cache_refresh_before_expiry=False)
+        DHTNode.create(
+            initial_peers=initial_peers,
+            parallel_rpc=parallel_rpc,
+            bucket_size=bucket_size,
+            num_replicas=num_replicas,
+            cache_refresh_before_expiry=False,
+        )
     )
     )
 
 
     # test 1: find self
     # test 1: find self
@@ -223,7 +234,7 @@ def test_dht_node():
     jaccard_numerator = jaccard_denominator = 0  # jaccard similarity aka intersection over union
     jaccard_numerator = jaccard_denominator = 0  # jaccard similarity aka intersection over union
     all_node_ids = list(dht.values())
     all_node_ids = list(dht.values())
 
 
-    for _ in range(10):
+    for _ in range(20):
         query_id = DHTID.generate()
         query_id = DHTID.generate()
         k_nearest = random.randint(1, 10)
         k_nearest = random.randint(1, 10)
         exclude_self = random.random() > 0.5
         exclude_self = random.random() > 0.5
@@ -249,10 +260,10 @@ def test_dht_node():
         jaccard_denominator += k_nearest
         jaccard_denominator += k_nearest
 
 
     accuracy = accuracy_numerator / accuracy_denominator
     accuracy = accuracy_numerator / accuracy_denominator
-    logger.debug(f"Top-1 accuracy: {accuracy}")  # should be 98-100%
+    logger.debug(f"Top-1 accuracy: {accuracy}")  # should be 90-100%
     jaccard_index = jaccard_numerator / jaccard_denominator
     jaccard_index = jaccard_numerator / jaccard_denominator
     logger.debug(f"Jaccard index (intersection over union): {jaccard_index}")  # should be 95-100%
     logger.debug(f"Jaccard index (intersection over union): {jaccard_index}")  # should be 95-100%
-    assert accuracy >= 0.9, f"Top-1 accuracy only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
+    assert accuracy >= 0.8, f"Top-1 accuracy only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
     assert jaccard_index >= 0.9, f"Jaccard index only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
     assert jaccard_index >= 0.9, f"Jaccard index only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
 
 
     # test 4: find all nodes
     # test 4: find all nodes
@@ -275,7 +286,10 @@ def test_dht_node():
     initial_peers = random.choice(swarm_maddrs)
     initial_peers = random.choice(swarm_maddrs)
     that_guy = loop.run_until_complete(
     that_guy = loop.run_until_complete(
         DHTNode.create(
         DHTNode.create(
-            initial_peers=initial_peers, parallel_rpc=10, cache_refresh_before_expiry=False, cache_locally=False
+            initial_peers=initial_peers,
+            parallel_rpc=parallel_rpc,
+            cache_refresh_before_expiry=False,
+            cache_locally=False,
         )
         )
     )
     )
 
 
@@ -310,10 +324,9 @@ def test_dht_node():
     assert not loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=345, expiration_time=now + 10))
     assert not loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=345, expiration_time=now + 10))
     assert loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=567, expiration_time=now + 30))
     assert loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=567, expiration_time=now + 30))
     assert loop.run_until_complete(me.store(upper_key, subkey=subkey3, value=890, expiration_time=now + 50))
     assert loop.run_until_complete(me.store(upper_key, subkey=subkey3, value=890, expiration_time=now + 50))
-    loop.run_until_complete(asyncio.sleep(0.1))  # wait for cache to refresh
 
 
     for node in [that_guy, me]:
     for node in [that_guy, me]:
-        value, time = loop.run_until_complete(node.get(upper_key))
+        value, time = loop.run_until_complete(node.get(upper_key, latest=True))
         assert isinstance(value, dict) and time == now + 50, (value, time)
         assert isinstance(value, dict) and time == now + 50, (value, time)
         assert value[subkey1] == (123, now + 10)
         assert value[subkey1] == (123, now + 10)
         assert value[subkey2] == (567, now + 30)
         assert value[subkey2] == (567, now + 30)

+ 2 - 2
tests/test_dht_storage.py

@@ -1,8 +1,8 @@
 import time

-from hivemind.utils.timed_storage import get_dht_time
-from hivemind.dht.storage import DHTLocalStorage, DHTID, DictionaryDHTValue
+from hivemind.dht.storage import DHTID, DHTLocalStorage, DictionaryDHTValue
 from hivemind.utils.serializer import MSGPackSerializer
+from hivemind.utils.timed_storage import get_dht_time


 def test_store():

+ 1 - 1
tests/test_dht_validation.py

@@ -9,7 +9,7 @@ from hivemind.dht.crypto import RSASignatureValidator
 from hivemind.dht.protocol import DHTProtocol
 from hivemind.dht.routing import DHTID
 from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
-from hivemind.dht.validation import DHTRecord, CompositeValidator
+from hivemind.dht.validation import CompositeValidator, DHTRecord


 class SchemaA(BaseModel):

+ 1 - 1
tests/test_expert_backend.py

@@ -6,7 +6,7 @@ import torch
 from torch.nn import Linear
 from torch.nn import Linear
 
 
 from hivemind import BatchTensorDescriptor, ExpertBackend
 from hivemind import BatchTensorDescriptor, ExpertBackend
-from hivemind.moe.server.checkpoints import store_experts, load_experts
+from hivemind.moe.server.checkpoints import load_experts, store_experts
 from hivemind.moe.server.layers.lr_schedule import get_linear_schedule_with_warmup

 EXPERT_WEIGHT_UPDATES = 3

+ 1 - 2
tests/test_moe.py

@@ -4,9 +4,8 @@ import pytest
 import torch

 import hivemind
-from hivemind.moe.server import background_server, declare_experts
 from hivemind.moe.client.expert import DUMMY
-from hivemind.moe.server import layers
+from hivemind.moe.server import background_server, declare_experts, layers


 @pytest.mark.forked

+ 25 - 13
tests/test_p2p_daemon.py

@@ -9,8 +9,9 @@ import numpy as np
 import pytest
 from multiaddr import Multiaddr

-from hivemind.p2p import P2P, P2PHandlerError
+from hivemind.p2p import P2P, P2PDaemonError, P2PHandlerError
 from hivemind.proto import dht_pb2
+from hivemind.utils.networking import get_free_port
 from hivemind.utils.serializer import MSGPackSerializer


@@ -33,6 +34,17 @@ async def test_daemon_killed_on_del():
     assert not is_process_running(child_pid)


+@pytest.mark.asyncio
+async def test_startup_error_message():
+    with pytest.raises(P2PDaemonError, match=r"failed to connect to bootstrap peers"):
+        await P2P.create(
+            initial_peers=[f"/ip4/127.0.0.1/tcp/{get_free_port()}/p2p/QmdaK4LUeQaKhqSFPRu9N7MvXUEWDxWwtCvPrS444tCgd1"]
+        )
+
+    with pytest.raises(P2PDaemonError, match=r"Daemon failed to start in .+ seconds"):
+        await P2P.create(startup_timeout=0.1)  # Test that startup_timeout works
+
+
 @pytest.mark.parametrize(
     "host_maddrs",
     [
@@ -92,7 +104,7 @@ async def test_call_protobuf_handler(should_cancel, replicate, handle_name="hand
         except asyncio.CancelledError:
             nonlocal handler_cancelled
             handler_cancelled = True
-        return dht_pb2.PingResponse(peer=dht_pb2.NodeInfo(node_id=server.id.to_bytes()), available=True)
+        return dht_pb2.PingResponse(peer=dht_pb2.NodeInfo(node_id=server.peer_id.to_bytes()), available=True)

     server_pid = server_primary._child.pid
     await server.add_protobuf_handler(handle_name, ping_handler, dht_pb2.PingRequest)
@@ -104,12 +116,12 @@ async def test_call_protobuf_handler(should_cancel, replicate, handle_name="hand
     assert is_process_running(client_pid)
     await client.wait_for_at_least_n_peers(1)

-    ping_request = dht_pb2.PingRequest(peer=dht_pb2.NodeInfo(node_id=client.id.to_bytes()), validate=True)
-    expected_response = dht_pb2.PingResponse(peer=dht_pb2.NodeInfo(node_id=server.id.to_bytes()), available=True)
+    ping_request = dht_pb2.PingRequest(peer=dht_pb2.NodeInfo(node_id=client.peer_id.to_bytes()), validate=True)
+    expected_response = dht_pb2.PingResponse(peer=dht_pb2.NodeInfo(node_id=server.peer_id.to_bytes()), available=True)

     if should_cancel:
         call_task = asyncio.create_task(
-            client.call_protobuf_handler(server.id, handle_name, ping_request, dht_pb2.PingResponse)
+            client.call_protobuf_handler(server.peer_id, handle_name, ping_request, dht_pb2.PingResponse)
         )
         await asyncio.sleep(0.25)

@@ -119,7 +131,7 @@ async def test_call_protobuf_handler(should_cancel, replicate, handle_name="hand
         assert handler_cancelled
     else:
         actual_response = await client.call_protobuf_handler(
-            server.id, handle_name, ping_request, dht_pb2.PingResponse
+            server.peer_id, handle_name, ping_request, dht_pb2.PingResponse
         )
         assert actual_response == expected_response
         assert not handler_cancelled
@@ -147,10 +159,10 @@ async def test_call_protobuf_handler_error(handle_name="handle"):
     assert is_process_running(client_pid)
     await client.wait_for_at_least_n_peers(1)

-    ping_request = dht_pb2.PingRequest(peer=dht_pb2.NodeInfo(node_id=client.id.to_bytes()), validate=True)
+    ping_request = dht_pb2.PingRequest(peer=dht_pb2.NodeInfo(node_id=client.peer_id.to_bytes()), validate=True)

     with pytest.raises(P2PHandlerError) as excinfo:
-        await client.call_protobuf_handler(server.id, handle_name, ping_request, dht_pb2.PingResponse)
+        await client.call_protobuf_handler(server.peer_id, handle_name, ping_request, dht_pb2.PingResponse)
     assert "boom" in str(excinfo.value)
     assert "boom" in str(excinfo.value)
 
 
     await server.shutdown()
     await server.shutdown()
@@ -196,7 +208,7 @@ async def test_call_peer_single_process():

     await client.wait_for_at_least_n_peers(1)

-    _, reader, writer = await client.call_binary_stream_handler(server.id, handler_name)
+    _, reader, writer = await client.call_binary_stream_handler(server.peer_id, handler_name)
     await validate_square_stream(reader, writer)

     await server.shutdown()
@@ -213,7 +225,7 @@ async def run_server(handler_name, server_side, response_received):

     await server.add_binary_stream_handler(handler_name, handle_square_stream)

-    server_side.send(server.id)
+    server_side.send(server.peer_id)
     server_side.send(await server.get_visible_maddrs())
     while response_received.value == 0:
         await asyncio.sleep(0.5)
@@ -281,7 +293,7 @@ async def test_error_closes_connection():

     await client.wait_for_at_least_n_peers(1)

-    _, reader, writer = await client.call_binary_stream_handler(server.id, handler_name)
+    _, reader, writer = await client.call_binary_stream_handler(server.peer_id, handler_name)
     with closing(writer):
         await P2P.send_raw_data(b"raise_error", writer)
         with pytest.raises(asyncio.IncompleteReadError):  # Means that the connection is closed
@@ -290,7 +302,7 @@ async def test_error_closes_connection():
     # Although the handler raised an exception, the server did not crash and is ready for the next requests
     assert is_process_running(server_pid)

-    _, reader, writer = await client.call_binary_stream_handler(server.id, handler_name)
+    _, reader, writer = await client.call_binary_stream_handler(server.peer_id, handler_name)
     with closing(writer):
         await P2P.send_raw_data(b"behave_normally", writer)
         assert await P2P.receive_raw_data(reader) == b"okay"
@@ -309,7 +321,7 @@ async def test_handlers_on_different_replicas():
             await P2P.send_raw_data(key, writer)

     server_primary = await P2P.create()
-    server_id = server_primary.id
+    server_id = server_primary.peer_id
     await server_primary.add_binary_stream_handler("handle_primary", partial(handler, key=b"primary"))

     server_replica1 = await replicate_if_needed(server_primary, True)
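
The change running through this file is the rename of the P2P accessor from id to peer_id. A minimal sketch of the resulting call pattern, reusing only names that already appear in the hunks above (client, server, and dht_pb2 come from the test setup; handle_name is the tests' default handler-name argument):

    # sketch: calling a protobuf handler via the renamed peer_id attribute
    ping_request = dht_pb2.PingRequest(peer=dht_pb2.NodeInfo(node_id=client.peer_id.to_bytes()), validate=True)
    response = await client.call_protobuf_handler(server.peer_id, handle_name, ping_request, dht_pb2.PingResponse)
    assert response.available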

+ 2 - 1
tests/test_p2p_daemon_bindings.py

@@ -17,7 +17,8 @@ from hivemind.p2p.p2p_daemon_bindings.utils import (
     write_unsigned_varint,
 )
 from hivemind.proto import p2pd_pb2 as p2pd_pb
-from test_utils.p2p_daemon import make_p2pd_pair_ip4, connect_safe
+
+from test_utils.p2p_daemon import connect_safe, make_p2pd_pair_ip4


 def test_raise_if_failed_raises():

+ 17 - 13
tests/test_p2p_servicer.py

@@ -19,13 +19,13 @@ async def server_client():
 @pytest.mark.asyncio
 async def test_unary_unary(server_client):
     class ExampleServicer(ServicerBase):
-        async def rpc_square(self, request: test_pb2.TestRequest, _: P2PContext) -> test_pb2.TestResponse:
+        async def rpc_square(self, request: test_pb2.TestRequest, _context: P2PContext) -> test_pb2.TestResponse:
             return test_pb2.TestResponse(number=request.number ** 2)

     server, client = server_client
     servicer = ExampleServicer()
     await servicer.add_p2p_handlers(server)
-    stub = servicer.get_stub(client, server.id)
+    stub = ExampleServicer.get_stub(client, server.peer_id)

     assert await stub.rpc_square(test_pb2.TestRequest(number=10)) == test_pb2.TestResponse(number=100)

@@ -33,16 +33,18 @@ async def test_unary_unary(server_client):
 @pytest.mark.asyncio
 async def test_stream_unary(server_client):
     class ExampleServicer(ServicerBase):
-        async def rpc_sum(self, request: AsyncIterator[test_pb2.TestRequest], _: P2PContext) -> test_pb2.TestResponse:
+        async def rpc_sum(
+            self, stream: AsyncIterator[test_pb2.TestRequest], _context: P2PContext
+        ) -> test_pb2.TestResponse:
             result = 0
-            async for item in request:
+            async for item in stream:
                 result += item.number
             return test_pb2.TestResponse(number=result)

     server, client = server_client
     servicer = ExampleServicer()
     await servicer.add_p2p_handlers(server)
-    stub = servicer.get_stub(client, server.id)
+    stub = ExampleServicer.get_stub(client, server.peer_id)

     async def generate_requests() -> AsyncIterator[test_pb2.TestRequest]:
         for i in range(10):
@@ -55,7 +57,7 @@ async def test_stream_unary(server_client):
 async def test_unary_stream(server_client):
     class ExampleServicer(ServicerBase):
         async def rpc_count(
-            self, request: test_pb2.TestRequest, _: P2PContext
+            self, request: test_pb2.TestRequest, _context: P2PContext
         ) -> AsyncIterator[test_pb2.TestResponse]:
             for i in range(request.number):
                 yield test_pb2.TestResponse(number=i)
@@ -63,7 +65,7 @@ async def test_unary_stream(server_client):
     server, client = server_client
     servicer = ExampleServicer()
     await servicer.add_p2p_handlers(server)
-    stub = servicer.get_stub(client, server.id)
+    stub = ExampleServicer.get_stub(client, server.peer_id)

     i = 0
     async for item in stub.rpc_count(test_pb2.TestRequest(number=10)):
@@ -76,16 +78,16 @@ async def test_unary_stream(server_client):
 async def test_stream_stream(server_client):
     class ExampleServicer(ServicerBase):
         async def rpc_powers(
-            self, request: AsyncIterator[test_pb2.TestRequest], _: P2PContext
+            self, stream: AsyncIterator[test_pb2.TestRequest], _context: P2PContext
         ) -> AsyncIterator[test_pb2.TestResponse]:
-            async for item in request:
+            async for item in stream:
                 yield test_pb2.TestResponse(number=item.number ** 2)
                 yield test_pb2.TestResponse(number=item.number ** 3)

     server, client = server_client
     servicer = ExampleServicer()
     await servicer.add_p2p_handlers(server)
-    stub = servicer.get_stub(client, server.id)
+    stub = ExampleServicer.get_stub(client, server.peer_id)

     async def generate_requests() -> AsyncIterator[test_pb2.TestRequest]:
         for i in range(10):
@@ -109,7 +111,9 @@ async def test_unary_stream_cancel(server_client, cancel_reason):
     handler_cancelled = False

     class ExampleServicer(ServicerBase):
-        async def rpc_wait(self, request: test_pb2.TestRequest, _: P2PContext) -> AsyncIterator[test_pb2.TestResponse]:
+        async def rpc_wait(
+            self, request: test_pb2.TestRequest, _context: P2PContext
+        ) -> AsyncIterator[test_pb2.TestResponse]:
             try:
                 yield test_pb2.TestResponse(number=request.number + 1)
                 await asyncio.sleep(2)
@@ -124,7 +128,7 @@ async def test_unary_stream_cancel(server_client, cancel_reason):
     await servicer.add_p2p_handlers(server)

     if cancel_reason == "close_connection":
-        _, reader, writer = await client.call_binary_stream_handler(server.id, "ExampleServicer.rpc_wait")
+        _, reader, writer = await client.call_binary_stream_handler(server.peer_id, "ExampleServicer.rpc_wait")
         await P2P.send_protobuf(test_pb2.TestRequest(number=10), writer)
         await P2P.send_protobuf(P2P.END_OF_STREAM, writer)

@@ -134,7 +138,7 @@ async def test_unary_stream_cancel(server_client, cancel_reason):

         writer.close()
     elif cancel_reason == "close_generator":
-        stub = servicer.get_stub(client, server.id)
+        stub = ExampleServicer.get_stub(client, server.peer_id)
         iter = stub.rpc_wait(test_pb2.TestRequest(number=10)).__aiter__()

         assert await iter.__anext__() == test_pb2.TestResponse(number=11)
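
Two things change throughout these servicer tests: the stub is built from the servicer class rather than an instance, and handler signatures spell out their unused context argument as _context. The practical upshot, sketched here under the assumption that get_stub is callable on the class itself (as the hunks above suggest), is that the client side needs no servicer instance at all:

    # sketch: client-side stub creation after the change
    stub = ExampleServicer.get_stub(client, server.peer_id)
    response = await stub.rpc_square(test_pb2.TestRequest(number=10))  # expected: test_pb2.TestResponse(number=100)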

+ 2 - 2
tests/test_routing.py

@@ -1,10 +1,10 @@
-import random
 import heapq
 import operator
+import random
 from itertools import chain, zip_longest

 from hivemind import LOCALHOST
-from hivemind.dht.routing import RoutingTable, DHTID
+from hivemind.dht.routing import DHTID, RoutingTable


 def test_ids_basic():

+ 3 - 2
tests/test_training.py

@@ -10,7 +10,7 @@ from sklearn.datasets import load_digits
 from hivemind import DHT
 from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
 from hivemind.moe.server import background_server
-from hivemind.optim import DecentralizedSGD, DecentralizedAdam
+from hivemind.optim import DecentralizedAdam, DecentralizedSGD


 @pytest.mark.forked
@@ -169,6 +169,7 @@ def test_decentralized_optimizer_step():
     assert torch.allclose(param1, torch.full_like(param1, reference))


+@pytest.mark.skip(reason="Skipped until a more stable averager implementation is ready (TODO @justheuristic)")
 @pytest.mark.forked
 def test_decentralized_optimizer_averaging():
     dht_root = DHT(start=True)
@@ -200,7 +201,7 @@ def test_decentralized_optimizer_averaging():
     (param1.sum() + param2.sum()).backward()

     for _ in range(100):
-        time.sleep(0.01)
+        time.sleep(0.1)
         opt1.step()
         opt2.step()
         opt1.zero_grad()

+ 8 - 11
tests/test_util_modules.py

@@ -12,9 +12,9 @@ import hivemind
 from hivemind.proto.dht_pb2_grpc import DHTStub
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
-from hivemind.utils import MSGPackSerializer, ValueWithExpiration, HeapEntry, DHTExpiration
-from hivemind.utils.asyncio import amap_in_executor, aiter, aenumerate, achain, anext, azip
-from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.utils import DHTExpiration, HeapEntry, MSGPackSerializer, ValueWithExpiration
+from hivemind.utils.asyncio import achain, aenumerate, aiter, amap_in_executor, anext, azip
+from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.mpfuture import InvalidStateError


@@ -256,8 +256,8 @@ def test_mpfuture_done_callback():

     assert future1.done() and not future1.cancelled()
     assert future2.done() and future2.cancelled()
-    events[0].wait(1)
-    events[1].wait(1)
+    for i in 0, 1, 4:
+        events[i].wait(1)
     assert events[0].is_set() and events[1].is_set() and events[2].is_set() and events[4].is_set()
     assert not events[3].is_set()

@@ -266,15 +266,14 @@ def test_mpfuture_done_callback():


 @pytest.mark.forked
-@pytest.mark.parametrize("synchronize", [True, False])
-def test_many_futures(synchronize: bool):
+def test_many_futures():
     evt = mp.Event()
     receiver, sender = mp.Pipe()
-    main_futures = [hivemind.MPFuture(synchronize=synchronize) for _ in range(1000)]
+    main_futures = [hivemind.MPFuture() for _ in range(1000)]
     assert len(hivemind.MPFuture._active_futures) == 1000

     def _run_peer():
-        fork_futures = [hivemind.MPFuture(synchronize=synchronize) for _ in range(500)]
+        fork_futures = [hivemind.MPFuture() for _ in range(500)]
         assert len(hivemind.MPFuture._active_futures) == 500

         for i, future in enumerate(random.sample(main_futures, 300)):
@@ -299,8 +298,6 @@ def test_many_futures(synchronize: bool):
     p.start()

     some_fork_futures = receiver.recv()
-
-    time.sleep(0.5)  # wait for active futures to synchronize
     assert len(hivemind.MPFuture._active_futures) == 700

     for future in some_fork_futures:

+ 21 - 9
tests/test_utils/dht_swarms.py

@@ -7,17 +7,18 @@ from typing import Dict, List, Tuple

 from multiaddr import Multiaddr

+from hivemind.dht import DHT
 from hivemind.dht.node import DHTID, DHTNode
 from hivemind.p2p import PeerID


-def run_node(initial_peers: List[Multiaddr], info_queue: mp.Queue):
+def run_node(initial_peers: List[Multiaddr], info_queue: mp.Queue, **kwargs):
     if asyncio.get_event_loop().is_running():
         asyncio.get_event_loop().stop()  # if we're in jupyter, get rid of its built-in event loop
         asyncio.set_event_loop(asyncio.new_event_loop())
     loop = asyncio.get_event_loop()

-    node = loop.run_until_complete(DHTNode.create(initial_peers=initial_peers, ping_n_attempts=10))
+    node = loop.run_until_complete(DHTNode.create(initial_peers=initial_peers, **kwargs))
     maddrs = loop.run_until_complete(node.get_visible_maddrs())

     info_queue.put((node.node_id, node.peer_id, maddrs))
@@ -31,7 +32,7 @@ def run_node(initial_peers: List[Multiaddr], info_queue: mp.Queue):


 def launch_swarm_in_separate_processes(
-    n_peers: int, n_sequential_peers: int
+    n_peers: int, n_sequential_peers: int, **kwargs
 ) -> Tuple[List[mp.Process], Dict[PeerID, DHTID], List[List[Multiaddr]]]:
     assert (
         n_sequential_peers < n_peers
@@ -47,19 +48,19 @@ def launch_swarm_in_separate_processes(
     for _ in range(n_sequential_peers):
         initial_peers = random.choice(swarm_maddrs) if swarm_maddrs else []

-        proc = mp.Process(target=run_node, args=(initial_peers, info_queue), daemon=True)
+        proc = mp.Process(target=run_node, args=(initial_peers, info_queue), kwargs=kwargs, daemon=True)
         proc.start()
         processes.append(proc)

-        node_id, peer_endpoint, peer_maddrs = info_queue.get()
-        dht[peer_endpoint] = node_id
+        node_id, peer_id, peer_maddrs = info_queue.get()
+        dht[peer_id] = node_id
         swarm_maddrs.append(peer_maddrs)

     def collect_info():
         while True:
-            node_id, peer_endpoint, peer_maddrs = info_queue.get()
+            node_id, peer_id, peer_maddrs = info_queue.get()
             with info_lock:
-                dht[peer_endpoint] = node_id
+                dht[peer_id] = node_id
                 swarm_maddrs.append(peer_maddrs)

                 if len(dht) == n_peers:
@@ -72,7 +73,7 @@ def launch_swarm_in_separate_processes(
         with info_lock:
             initial_peers = random.choice(swarm_maddrs)

-        proc = mp.Process(target=run_node, args=(initial_peers, info_queue), daemon=True)
+        proc = mp.Process(target=run_node, args=(initial_peers, info_queue), kwargs=kwargs, daemon=True)
         proc.start()
         processes.append(proc)

@@ -86,3 +87,14 @@ async def launch_star_shaped_swarm(n_peers: int, **kwargs) -> List[DHTNode]:
     initial_peers = await nodes[0].get_visible_maddrs()
     nodes += await asyncio.gather(*[DHTNode.create(initial_peers=initial_peers, **kwargs) for _ in range(n_peers - 1)])
     return nodes
+
+
+def launch_dht_instances(n_peers: int, **kwargs) -> List[DHT]:
+    dhts = [DHT(start=True, **kwargs)]
+    initial_peers = dhts[0].get_visible_maddrs()
+
+    dhts.extend(DHT(initial_peers=initial_peers, start=True, await_ready=False, **kwargs) for _ in range(n_peers - 1))
+    for instance in dhts[1:]:
+        instance.ready.wait()
+
+    return dhts
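
The new launch_dht_instances helper boots one hivemind.DHT as a bootstrap peer and joins the remaining peers to it before waiting for readiness. A short usage sketch; the cleanup step is an assumption rather than part of this diff, relying on DHT being a process-like object with is_alive() and shutdown():

    # sketch: spinning up and tearing down a small DHT swarm in a test
    dhts = launch_dht_instances(4)
    try:
        assert all(dht.is_alive() for dht in dhts)
    finally:
        for dht in dhts:
            dht.shutdown()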

+ 5 - 6
tests/test_utils/p2p_daemon.py

@@ -6,14 +6,13 @@ import time
 import uuid
 from contextlib import asynccontextmanager
 from typing import NamedTuple
-from pkg_resources import resource_filename

 from multiaddr import Multiaddr, protocols
+from pkg_resources import resource_filename
-from hivemind import find_open_port
+from hivemind import get_free_port
 from hivemind.p2p.p2p_daemon_bindings.p2pclient import Client

-
 TIMEOUT_DURATION = 30  # seconds
 P2PD_PATH = resource_filename("hivemind", "hivemind_cli/p2pd")

@@ -58,7 +57,7 @@ class Daemon:

     def _run(self):
         cmd_list = [P2PD_PATH, f"-listen={str(self.control_maddr)}"]
-        cmd_list += [f"-hostAddrs=/ip4/127.0.0.1/tcp/{find_open_port()}"]
+        cmd_list += [f"-hostAddrs=/ip4/127.0.0.1/tcp/{get_free_port()}"]
         if self.enable_connmgr:
             cmd_list += ["-connManager=true", "-connLo=1", "-connHi=2", "-connGrace=0"]
         if self.enable_dht:
@@ -130,8 +129,8 @@ async def make_p2pd_pair_unix(enable_control, enable_connmgr, enable_dht, enable

 @asynccontextmanager
 async def make_p2pd_pair_ip4(enable_control, enable_connmgr, enable_dht, enable_pubsub):
-    control_maddr = Multiaddr(f"/ip4/127.0.0.1/tcp/{find_open_port()}")
-    listen_maddr = Multiaddr(f"/ip4/127.0.0.1/tcp/{find_open_port()}")
+    control_maddr = Multiaddr(f"/ip4/127.0.0.1/tcp/{get_free_port()}")
+    listen_maddr = Multiaddr(f"/ip4/127.0.0.1/tcp/{get_free_port()}")
     async with _make_p2pd_pair(
         control_maddr=control_maddr,
         listen_maddr=listen_maddr,