Ver código fonte

Merge branch 'master' into unary-handlers

Denis Mazur 4 anos atrás
pai
commit
b058e6e6f8
84 arquivos alterados com 1043 adições e 985 exclusões
  1. 21 0
      .github/workflows/check-style.yml
  2. 0 13
      .github/workflows/check_style.yml
  3. 3 3
      .github/workflows/run-tests.yml
  4. 3 2
      CONTRIBUTING.md
  5. 32 27
      README.md
  6. 2 3
      benchmarks/benchmark_averaging.py
  7. 1 2
      benchmarks/benchmark_tensor_compression.py
  8. 2 2
      benchmarks/benchmark_throughput.py
  9. 1 2
      docs/conf.py
  10. 1 1
      docs/user/dht.md
  11. 1 5
      examples/albert/arguments.py
  12. 3 3
      examples/albert/run_trainer.py
  13. 2 2
      examples/albert/run_training_monitor.py
  14. 0 1
      examples/albert/utils.py
  15. 5 5
      hivemind/__init__.py
  16. 88 82
      hivemind/averaging/allreduce.py
  17. 46 99
      hivemind/averaging/averager.py
  18. 6 6
      hivemind/averaging/group_info.py
  19. 25 20
      hivemind/averaging/key_manager.py
  20. 2 1
      hivemind/averaging/load_balancing.py
  21. 91 82
      hivemind/averaging/matchmaking.py
  22. 6 7
      hivemind/averaging/partition.py
  23. 3 3
      hivemind/averaging/training.py
  24. 55 9
      hivemind/dht/__init__.py
  25. 0 1
      hivemind/dht/crypto.py
  26. 10 6
      hivemind/dht/node.py
  27. 8 8
      hivemind/dht/protocol.py
  28. 2 1
      hivemind/dht/routing.py
  29. 1 1
      hivemind/dht/storage.py
  30. 1 1
      hivemind/dht/traverse.py
  31. 2 2
      hivemind/hivemind_cli/run_server.py
  32. 1 1
      hivemind/moe/__init__.py
  33. 11 11
      hivemind/moe/client/beam_search.py
  34. 3 3
      hivemind/moe/client/expert.py
  35. 6 6
      hivemind/moe/client/moe.py
  36. 3 3
      hivemind/moe/client/switch_moe.py
  37. 12 7
      hivemind/moe/server/__init__.py
  38. 3 3
      hivemind/moe/server/connection_handler.py
  39. 7 7
      hivemind/moe/server/dht_handler.py
  40. 3 3
      hivemind/moe/server/expert_backend.py
  41. 1 1
      hivemind/moe/server/expert_uid.py
  42. 1 1
      hivemind/moe/server/layers/custom_experts.py
  43. 1 1
      hivemind/moe/server/runtime.py
  44. 3 3
      hivemind/moe/server/task_pool.py
  45. 1 1
      hivemind/optim/__init__.py
  46. 1 1
      hivemind/optim/adaptive.py
  47. 4 4
      hivemind/optim/collaborative.py
  48. 3 3
      hivemind/optim/simple.py
  49. 1 1
      hivemind/p2p/__init__.py
  50. 59 88
      hivemind/p2p/p2p_daemon.py
  51. 54 27
      hivemind/p2p/servicer.py
  52. 8 14
      hivemind/proto/averaging.proto
  53. 2 2
      hivemind/proto/dht.proto
  54. 3 3
      hivemind/utils/__init__.py
  55. 15 4
      hivemind/utils/asyncio.py
  56. 2 3
      hivemind/utils/compression.py
  57. 2 2
      hivemind/utils/grpc.py
  58. 137 161
      hivemind/utils/mpfuture.py
  59. 1 2
      hivemind/utils/networking.py
  60. 1 1
      hivemind/utils/serializer.py
  61. 1 1
      hivemind/utils/tensor_descr.py
  62. 2 1
      hivemind/utils/timed_storage.py
  63. 7 0
      pyproject.toml
  64. 1 1
      requirements-dev.txt
  65. 22 2
      tests/conftest.py
  66. 28 47
      tests/test_allreduce.py
  67. 1 2
      tests/test_auth.py
  68. 89 82
      tests/test_averaging.py
  69. 4 4
      tests/test_dht.py
  70. 2 2
      tests/test_dht_crypto.py
  71. 2 2
      tests/test_dht_experts.py
  72. 24 11
      tests/test_dht_node.py
  73. 2 2
      tests/test_dht_storage.py
  74. 1 1
      tests/test_dht_validation.py
  75. 1 1
      tests/test_expert_backend.py
  76. 1 2
      tests/test_moe.py
  77. 25 13
      tests/test_p2p_daemon.py
  78. 2 1
      tests/test_p2p_daemon_bindings.py
  79. 17 13
      tests/test_p2p_servicer.py
  80. 2 2
      tests/test_routing.py
  81. 3 2
      tests/test_training.py
  82. 8 11
      tests/test_util_modules.py
  83. 21 9
      tests/test_utils/dht_swarms.py
  84. 5 6
      tests/test_utils/p2p_daemon.py

+ 21 - 0
.github/workflows/check-style.yml

@@ -0,0 +1,21 @@
+name: Check style
+
+on: [ push ]
+
+jobs:
+  black:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v2
+      - uses: psf/black@stable
+        with:
+          options: "--check --diff"
+          version: "21.6b0"
+  isort:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v2
+      - uses: actions/setup-python@v2
+        with:
+          python-version: 3.8
+      - uses: isort/isort-action@master

+ 0 - 13
.github/workflows/check_style.yml

@@ -1,13 +0,0 @@
-name: Check style
-
-on: [ push ]
-
-jobs:
-  black:
-    runs-on: ubuntu-latest
-    steps:
-      - uses: actions/checkout@v2
-      - uses: psf/black@stable
-        with:
-          options: "--check"
-          version: "21.6b0"

+ 3 - 3
.github/workflows/run-tests.yml

@@ -33,7 +33,7 @@ jobs:
       - name: Test
         run: |
           cd tests
-          pytest --durations=0 --durations-min=1.0
+          pytest --durations=0 --durations-min=1.0 -v
 
   build_and_test_p2pd:
     runs-on: ubuntu-latest
@@ -60,7 +60,7 @@ jobs:
       - name: Test
         run: |
           cd tests
-          pytest -k "p2p" 
+          pytest -k "p2p" -v
 
   codecov_in_develop_mode:
 
@@ -87,6 +87,6 @@ jobs:
           pip install -e .
       - name: Test
         run: |
-          pytest --cov=hivemind tests
+          pytest --cov=hivemind -v tests
       - name: Upload coverage to Codecov
         uses: codecov/codecov-action@v1

+ 3 - 2
CONTRIBUTING.md

@@ -34,10 +34,11 @@ with the following rules:
 
 ## Code style
 
-* We use [black](https://github.com/psf/black) for code formatting. Before submitting a PR, make sure to install and
-  run `black .` in the root of the repository.
 * The code must follow [PEP8](https://www.python.org/dev/peps/pep-0008/) unless absolutely necessary. Also, each line
   cannot be longer than 119 characters.
+* We use [black](https://github.com/psf/black) for code formatting and [isort](https://github.com/PyCQA/isort) for 
+  import sorting. Before submitting a PR, make sure to install and run `black .` and `isort .` in the root of the
+  repository.
 * We highly encourage the use of [typing](https://docs.python.org/3/library/typing.html) where applicable.
 * Use `get_logger` from `hivemind.utils.logging` to log any information instead of `print`ing directly to standard
   output/error streams.

+ 32 - 27
README.md

@@ -2,7 +2,7 @@
 
 [![Documentation Status](https://readthedocs.org/projects/learning-at-home/badge/?version=latest)](https://learning-at-home.readthedocs.io/en/latest/?badge=latest)
 [![PyPI version](https://img.shields.io/pypi/v/hivemind.svg)](https://pypi.org/project/hivemind/)
-[![Discord](https://img.shields.io/static/v1?style=default&label=Discord&logo=discord&message=join)](https://discord.gg/xC7ucM8j)
+[![Discord](https://img.shields.io/static/v1?style=default&label=Discord&logo=discord&message=join)](https://discord.gg/uGugx9zYvN)
 [![CI status](https://github.com/learning-at-home/hivemind/actions/workflows/run-tests.yml/badge.svg?branch=master)](https://github.com/learning-at-home/hivemind/actions)
 ![Codecov](https://img.shields.io/codecov/c/github/learning-at-home/hivemind)
 [![Black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
@@ -18,21 +18,21 @@ large model on hundreds of computers from different universities, companies, and
   network.
 * Fault-tolerant backpropagation: forward and backward passes succeed even if some nodes are unresponsive or take too
   long to respond.
-* Decentralized parameter averaging: iteratively aggregate updates from multiple
-  workers without the need to synchronize across the entire network ([paper](https://arxiv.org/abs/2103.03239)).
+* Decentralized parameter averaging: iteratively aggregate updates from multiple workers without the need to
+  synchronize across the entire network ([paper](https://arxiv.org/abs/2103.03239)).
 * Train neural networks of arbitrary size: parts of their layers are distributed across the participants with the
-  decentralized mixture-of-experts ([paper](https://arxiv.org/abs/2002.04013)).
+  Decentralized Mixture-of-Experts ([paper](https://arxiv.org/abs/2002.04013)).
 
 To learn more about the ideas behind this library, see https://learning-at-home.github.io or read
 the [NeurIPS 2020 paper](https://arxiv.org/abs/2002.04013).
 
 ## Installation
 
-Before installing, make sure that your environment has Python 3.7+ 
-and [PyTorch](https://pytorch.org/get-started/locally/#start-locally) 1.6.0 or newer.
-You can install them either natively or with [Anaconda](https://www.anaconda.com/products/individual).
+Before installing, make sure that your environment has Python 3.7+
+and [PyTorch](https://pytorch.org/get-started/locally/#start-locally) 1.6.0 or newer. They can be installed either
+natively or with [Anaconda](https://www.anaconda.com/products/individual).
 
-You can install [the latest release](https://pypi.org/project/hivemind) with pip or build hivemind from source.
+You can get [the latest release](https://pypi.org/project/hivemind) with pip or build hivemind from source.
 
 ### With pip
 
@@ -62,24 +62,29 @@ Before running the compilation, please ensure that your machine has a recent ver
 of [Go toolchain](https://golang.org/doc/install) (1.15 or higher).
 
 ### System requirements
-- __Linux__ is the default OS for which hivemind is developed and tested. We recommend Ubuntu 18.04+ (64-bit),
-  but other 64-bit distros should work as well. Legacy 32-bit is not recommended.
-- __macOS 10.x__ mostly works but requires building hivemind from source, and some edge cases may fail.
-  To ensure full compatibility, we recommend using [our Docker image](https://hub.docker.com/r/learningathome/hivemind).
-- __Windows 10+ (experimental)__ can run hivemind using [WSL](https://docs.microsoft.com/ru-ru/windows/wsl/install-win10).
-  You can configure WSL to use GPU following [this guide](https://docs.nvidia.com/cuda/wsl-user-guide/index.html) by NVIDIA.
-  After the CUDA toolkit is installed you can simply follow the instructions above to install with pip or from source.
+
+- __Linux__ is the default OS for which hivemind is developed and tested. We recommend Ubuntu 18.04+ (64-bit), but
+  other 64-bit distros should work as well. Legacy 32-bit is not recommended.
+- __macOS 10.x__ mostly works but requires building hivemind from source, and some edge cases may fail. To ensure full
+  compatibility, we recommend using [our Docker image](https://hub.docker.com/r/learningathome/hivemind).
+- __Windows 10+ (experimental)__ can run hivemind
+  using [WSL](https://docs.microsoft.com/ru-ru/windows/wsl/install-win10). You can configure WSL to use GPU by
+  following sections 1–3 of [this guide](https://docs.nvidia.com/cuda/wsl-user-guide/index.html) by NVIDIA. After
+  that, you can simply follow the instructions above to install with pip or from source.
 
 ## Documentation
 
-* The [quickstart tutorial](https://learning-at-home.readthedocs.io/en/latest/user/quickstart.html) walks through installation
-  and a training a simple neural network with several peers.  
+* The [quickstart tutorial](https://learning-at-home.readthedocs.io/en/latest/user/quickstart.html) walks through
+  installation and a training a simple neural network with several peers.
 * [examples/albert](https://github.com/learning-at-home/hivemind/tree/master/examples/albert) contains the starter kit
   and instructions for training a Transformer masked language model collaboratively.
-* API reference and additional tutorials are available at [learning-at-home.readthedocs.io](https://learning-at-home.readthedocs.io)
+* The [Mixture-of-Experts tutorial](https://learning-at-home.readthedocs.io/en/latest/user/moe.html)
+  covers the usage of Decentralized Mixture-of-Experts layers.
+* API reference and additional tutorials are available
+  at [learning-at-home.readthedocs.io](https://learning-at-home.readthedocs.io)
 
-If you have any questions about installing and using hivemind, you can ask them in 
-[our Discord chat](https://discord.gg/xC7ucM8j) or file an [issue](https://github.com/learning-at-home/hivemind/issues).
+If you have any questions about installing and using hivemind, you can ask them in
+[our Discord chat](https://discord.gg/uGugx9zYvN) or file an [issue](https://github.com/learning-at-home/hivemind/issues).
 
 ## Contributing
 
@@ -88,9 +93,8 @@ documentation improvements to entirely new features, is equally appreciated.
 
 If you want to contribute to hivemind but don't know where to start, take a look at the
 unresolved [issues](https://github.com/learning-at-home/hivemind/issues). Open a new issue or
-join [our chat room](https://discord.gg/xC7ucM8j) in case you want to discuss new functionality or
-report a possible bug. Bug fixes are always welcome, but new features should be preferably discussed with maintainers
-beforehand.
+join [our chat room](https://discord.gg/xC7ucM8j) in case you want to discuss new functionality or report a possible
+bug. Bug fixes are always welcome, but new features should be preferably discussed with maintainers beforehand.
 
 If you want to start contributing to the source code of hivemind, please see
 the [contributing guidelines](https://github.com/learning-at-home/hivemind/blob/master/CONTRIBUTING.md) first. To learn
@@ -99,7 +103,7 @@ our [guide](https://learning-at-home.readthedocs.io/en/latest/user/contributing.
 
 ## Citation
 
-If you found hivemind or its underlying algorithms useful for your research, please cite the relevant papers:
+If you found hivemind or its underlying algorithms useful for your research, please cite the following source:
 
 ```
 @misc{hivemind,
@@ -111,7 +115,8 @@ If you found hivemind or its underlying algorithms useful for your research, ple
 ```
 
 Also, you can cite [the paper](https://arxiv.org/abs/2002.04013) that inspired the creation of this library
-(prototype implementation of hivemind available at [mryab/learning-at-home](https://github.com/mryab/learning-at-home)):
+(prototype implementation of hivemind available
+at [mryab/learning-at-home](https://github.com/mryab/learning-at-home)):
 
 ```
 @inproceedings{ryabinin2020crowdsourced,
@@ -171,5 +176,5 @@ Also, you can cite [the paper](https://arxiv.org/abs/2002.04013) that inspired t
 
 </details>
 
-We also maintain a list of [related projects and
-acknowledgements](https://learning-at-home.readthedocs.io/en/latest/user/acknowledgements.html).
+We also maintain a list
+of [related projects and acknowledgements](https://learning-at-home.readthedocs.io/en/latest/user/acknowledgements.html).

+ 2 - 3
benchmarks/benchmark_averaging.py

@@ -57,11 +57,10 @@ def benchmark_averaging(
         dht = hivemind.DHT(initial_peers=initial_peers, start=True)
         initial_bits = bin(index % num_groups)[2:].rjust(nbits, "0")
         averager = hivemind.averaging.DecentralizedAverager(
-            peer_tensors[i],
+            peer_tensors[index],
             dht,
             prefix="my_tensor",
             initial_group_bits=initial_bits,
-            listen_on=f"{LOCALHOST}:*",
             compression_type=runtime_pb2.CompressionType.FLOAT16,
             target_group_size=target_group_size,
             averaging_expiration=averaging_expiration,
@@ -71,7 +70,7 @@ def benchmark_averaging(
         processes.update({dht, averager})
 
         logger.info(
-            f"Averager {index}: started on endpoint {averager.endpoint}, group_bits: {averager.get_group_bits()}"
+            f"Averager {index}: started with peer id {averager.peer_id}, group_bits: {averager.get_group_bits()}"
         )
         for step in range(num_rounds):
             try:

+ 1 - 2
benchmarks/benchmark_tensor_compression.py

@@ -4,10 +4,9 @@ import time
 import torch
 
 from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.logging import get_logger
 
-
 logger = get_logger(__name__)
 
 

+ 2 - 2
benchmarks/benchmark_throughput.py

@@ -7,7 +7,7 @@ import time
 import torch
 
 import hivemind
-from hivemind import find_open_port
+from hivemind import get_free_port
 from hivemind.moe.server import layers
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger
@@ -66,7 +66,7 @@ def benchmark_throughput(
         or torch.device(device) == torch.device("cpu")
     )
     assert expert_cls in layers.name_to_block
-    port = port or find_open_port()
+    port = port or get_free_port()
     max_batch_size = max_batch_size or batch_size * 4
     num_handlers = max(1, num_handlers or num_clients // 2)
     benchmarking_failed = mp.Event()

+ 1 - 2
docs/conf.py

@@ -17,9 +17,8 @@
 # sys.path.insert(0, os.path.abspath('.'))
 import sys
 
-from recommonmark.transform import AutoStructify
 from recommonmark.parser import CommonMarkParser
-
+from recommonmark.transform import AutoStructify
 
 # -- Project information -----------------------------------------------------
 src_path = "../hivemind"

+ 1 - 1
docs/user/dht.md

@@ -18,7 +18,7 @@ dht2 = DHT(initial_peers=dht.get_visible_maddrs(), start=True)
 ```
 
 Note that `initial_peers` contains the address of the first DHT node.
-This implies that the resulting node will have shared key-value with the first node, __as well as any other
+This implies that the new node will share the key-value data with the first node, __as well as any other
 nodes connected to it.__ When the two nodes are connected, subsequent peers can use any one of them (or both)
 as `initial_peers` to connect to the shared "dictionary".
 

+ 1 - 5
examples/albert/arguments.py

@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-from typing import Optional, List
+from typing import List, Optional
 
 from transformers import TrainingArguments
 
@@ -45,10 +45,6 @@ class AveragerArguments:
     averaging_timeout: float = field(
         default=30.0, metadata={"help": "Give up on averaging step after this many seconds"}
     )
-    listen_on: str = field(
-        default="[::]:*",
-        metadata={"help": "Network interface used for incoming averager communication. Default: all ipv6"},
-    )
     min_refresh_period: float = field(
         default=0.5, metadata={"help": "Wait for at least this many seconds before fetching new collaboration state"}
     )

+ 3 - 3
examples/albert/run_trainer.py

@@ -11,8 +11,8 @@ import transformers
 from datasets import load_from_disk
 from torch.utils.data import DataLoader
 from torch_optimizer import Lamb
-from transformers import set_seed, HfArgumentParser, TrainingArguments, DataCollatorForLanguageModeling
-from transformers.models.albert import AlbertTokenizerFast, AlbertConfig, AlbertForPreTraining
+from transformers import DataCollatorForLanguageModeling, HfArgumentParser, TrainingArguments, set_seed
+from transformers.models.albert import AlbertConfig, AlbertForPreTraining, AlbertTokenizerFast
 from transformers.optimization import get_linear_schedule_with_warmup
 from transformers.trainer import Trainer
 from transformers.trainer_utils import is_main_process
@@ -21,7 +21,7 @@ import hivemind
 from hivemind.utils.compression import CompressionType
 
 import utils
-from arguments import CollaborationArguments, DatasetArguments, AlbertTrainingArguments, AveragerArguments
+from arguments import AlbertTrainingArguments, AveragerArguments, CollaborationArguments, DatasetArguments
 
 logger = logging.getLogger(__name__)
 LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)

+ 2 - 2
examples/albert/run_training_monitor.py

@@ -10,13 +10,13 @@ import requests
 import torch
 import wandb
 from torch_optimizer import Lamb
-from transformers import AlbertForPreTraining, AlbertConfig, HfArgumentParser
+from transformers import AlbertConfig, AlbertForPreTraining, HfArgumentParser
 
 import hivemind
 from hivemind.utils.compression import CompressionType
 
 import utils
-from arguments import BaseTrainingArguments, CollaborativeOptimizerArguments, AveragerArguments
+from arguments import AveragerArguments, BaseTrainingArguments, CollaborativeOptimizerArguments
 
 logger = logging.getLogger(__name__)
 

+ 0 - 1
examples/albert/utils.py

@@ -9,7 +9,6 @@ from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
 from hivemind.dht.validation import RecordValidatorBase
 from hivemind.utils.logging import get_logger
 
-
 logger = get_logger(__name__)
 
 

+ 5 - 5
hivemind/__init__.py

@@ -2,21 +2,21 @@ from hivemind.averaging import DecentralizedAverager, TrainingAverager
 from hivemind.dht import DHT
 from hivemind.moe import (
     ExpertBackend,
-    Server,
-    register_expert_class,
     RemoteExpert,
     RemoteMixtureOfExperts,
     RemoteSwitchMixtureOfExperts,
+    Server,
+    register_expert_class,
 )
 from hivemind.optim import (
     CollaborativeAdaptiveOptimizer,
-    DecentralizedOptimizerBase,
     CollaborativeOptimizer,
+    DecentralizedAdam,
     DecentralizedOptimizer,
+    DecentralizedOptimizerBase,
     DecentralizedSGD,
-    DecentralizedAdam,
 )
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, PeerInfo
 from hivemind.utils import *
 
-__version__ = "0.9.10"
+__version__ = "1.0.0.dev0"

+ 88 - 82
hivemind/averaging/allreduce.py

@@ -1,15 +1,15 @@
 import asyncio
-from typing import Sequence, Dict, Tuple, AsyncIterator, Any, Optional
 from enum import Enum
+from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Type
 
-import grpc
 import torch
 
-from hivemind.averaging.partition import TensorPartContainer, TensorPartReducer, AllreduceException
-from hivemind.utils import Endpoint, get_logger, ChannelCache
-from hivemind.utils.asyncio import anext, achain, aiter, aenumerate, amap_in_executor
-from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
-from hivemind.proto import averaging_pb2_grpc, averaging_pb2
+from hivemind.averaging.partition import AllreduceException, TensorPartContainer, TensorPartReducer
+from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase, StubBase
+from hivemind.proto import averaging_pb2
+from hivemind.utils import get_logger
+from hivemind.utils.asyncio import achain, aenumerate, aiter, amap_in_executor, anext, asingle
+from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 
 # flavour types
 GroupID = bytes
@@ -22,19 +22,27 @@ class AveragingMode(Enum):
     AUX = 2
 
 
-class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
+class AllReduceRunner(ServicerBase):
     """
-    An internal class that runs butterfly AllReduce in a predefined group of averagers
+    An internal class that runs butterfly AllReduce in a predefined group of averagers.
+
+    This class inherits hivemind.p2p.ServicerBase, so it can be used as an RPCServicer for testing purposes without
+    creating a full DecentralizedAverager.
 
     :note: this class returns **differences** between averaged and local tensors in order to improve numerical stability
+    :param p2p: a hivemind.p2p.P2P instance used for communication with other peers
+    :param servicer_type: a hivemind.p2p.ServicerBase subclass whose RPC signatures are used
+      when requesting other peers. Typically, it is DecentralizedAverager, its derivative,
+      or AllReduceRunner itself (for testing purposes).
+    :param prefix: namespace for servicer's RPCs (typically, equal to prefix for group keys)
     :param group_id: unique identifier of this specific all-reduce run
     :param tensors: local tensors that should be averaged with groupmates
     :param tensors: local tensors that should be averaged with groupmates
-    :param endpoint: your endpoint, must be included in ordered_group_endpoints
-    :param ordered_group_endpoints: group endpoints ordered s.t. i-th endpoint is responsible for averaging i-th part
+    :param peer_id: your peer_id, must be included in ordered_peer_ids
+    :param ordered_peer_ids: group peer_ids ordered s.t. i-th peer_id is responsible for averaging i-th part
     :param peer_fractions: for each peer, a target fraction of vector elements that this peer should average
       (the actual number of values by peer will be nearly proportional, but there are no exact guarantees)
-    :param modes: AveragingMode for each peer in ordered_group_endpoints (normal, client-only or auxiliary)
+    :param modes: AveragingMode for each peer in ordered_peer_ids (normal, client-only or auxiliary)
     :param weights: scaling coefficients for weighted averaging (default = equal weights for all non-aux peers)
     :param gathered: additional user-defined data collected from this group
     :param kwargs: additional paramters (e.g. part_size_bytes) will be passed to TensorPartContainer
@@ -43,73 +51,83 @@ class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
     def __init__(
         self,
         *,
+        p2p: P2P,
+        servicer_type: Type[ServicerBase],
+        prefix: Optional[str],
         group_id: GroupID,
         tensors: Sequence[torch.Tensor],
-        endpoint: Endpoint,
-        ordered_group_endpoints: Sequence[Endpoint],
+        ordered_peer_ids: Sequence[PeerID],
         peer_fractions: Tuple[float, ...],
         weights: Optional[Sequence[float]] = None,
         modes: Optional[Sequence[AveragingMode]] = None,
-        gathered: Optional[Dict[Endpoint, Any]] = None,
+        gathered: Optional[Dict[PeerID, Any]] = None,
         **kwargs,
     ):
-        assert endpoint in ordered_group_endpoints, "endpoint is not a part of the group"
+        self._p2p = p2p
+        self.peer_id = p2p.peer_id
+        assert self.peer_id in ordered_peer_ids, "peer_id is not a part of the group"
+
+        if not issubclass(servicer_type, ServicerBase):
+            raise TypeError("`servicer_type` is expected to be a ServicerBase subclass")
+        self._servicer_type = servicer_type
+        self._prefix = prefix
+
         modes = modes or tuple(AveragingMode.CLIENT if frac == 0 else AveragingMode.NODE for frac in peer_fractions)
         weights = weights or tuple(int(mode != AveragingMode.AUX) for mode in modes)
-        assert len(weights) == len(modes) == len(ordered_group_endpoints), "lists have inconsistent length"
+        assert len(weights) == len(modes) == len(ordered_peer_ids), "lists have inconsistent length"
         assert any(mode != AveragingMode.CLIENT for mode in modes), "cannot run allreduce without reducers"
         for mode, frac, weight in zip(modes, peer_fractions, weights):
             assert mode != AveragingMode.CLIENT or frac == 0, "client-mode peer should have zero all-reduce fraction"
             assert mode != AveragingMode.AUX or weight == 0, "auxiliary peer should have zero averaging weight"
 
-        self.group_id, self.endpoint, self.ordered_group_endpoints = group_id, endpoint, ordered_group_endpoints
+        self.group_id, self.ordered_peer_ids = group_id, ordered_peer_ids
         self.modes, self.peer_fractions, self.gathered = modes, peer_fractions, gathered
 
         self._future = asyncio.Future()
 
-        self.sender_endpoints, self.sender_weights = [], []
-        for endpoint, weight, mode in zip(self.ordered_group_endpoints, weights, modes):
+        self.sender_peer_ids, self.sender_weights = [], []
+        for peer_id, weight, mode in zip(self.ordered_peer_ids, weights, modes):
             if mode != AveragingMode.AUX:
-                self.sender_endpoints.append(endpoint)
+                self.sender_peer_ids.append(peer_id)
                 self.sender_weights.append(weight)
 
-        endpoint_index = self.ordered_group_endpoints.index(self.endpoint)
+        peer_id_index = self.ordered_peer_ids.index(self.peer_id)
         self.tensor_part_container = TensorPartContainer(tensors, peer_fractions, **kwargs)
-        self.parts_for_local_averaging = self.tensor_part_container.get_raw_input_parts(endpoint_index)
+        self.parts_for_local_averaging = self.tensor_part_container.get_raw_input_parts(peer_id_index)
         self.tensor_part_reducer = TensorPartReducer(
             tuple(part.shape for part in self.parts_for_local_averaging),
-            len(self.sender_endpoints),
+            len(self.sender_peer_ids),
             self.sender_weights,
         )
 
     def __repr__(self):
-        return f"{self.__class__.__name__}({self.endpoint}, group_size={self.group_size})"
+        return f"{self.__class__.__name__}({self.peer_id}, group_size={self.group_size})"
 
     def __aiter__(self):
         return self.run()
 
-    def __contains__(self, endpoint: Endpoint):
-        return endpoint in self.ordered_group_endpoints
+    def __contains__(self, peer_id: PeerID):
+        return peer_id in self.ordered_peer_ids
 
     @property
     def group_size(self):
-        return len(self.ordered_group_endpoints)
+        return len(self.ordered_peer_ids)
 
-    def _get_peer_stub(self, peer: Endpoint) -> averaging_pb2_grpc.DecentralizedAveragingStub:
-        return ChannelCache.get_stub(peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
+    def _get_peer_stub(self, peer: PeerID) -> StubBase:
+        return self._servicer_type.get_stub(self._p2p, peer, namespace=self._prefix)
 
     async def run(self) -> AsyncIterator[torch.Tensor]:
         """Run all-reduce, return differences between averaged and original tensors as they are computed"""
         pending_tasks = set()
         try:
-            if len(self.sender_endpoints) == 0:
+            if len(self.sender_peer_ids) == 0:
                 logger.debug(f"{self} - finished all-reduce early: all peers are auxiliaries ({self.modes})")
                 self.finalize()
 
-            elif self.endpoint in self.sender_endpoints:
-                for endpoint, parts in zip(self.ordered_group_endpoints, self.tensor_part_container.num_parts_by_peer):
+            elif self.peer_id in self.sender_peer_ids:
+                for peer_id, parts in zip(self.ordered_peer_ids, self.tensor_part_container.num_parts_by_peer):
                     if parts != 0:
-                        pending_tasks.add(asyncio.create_task(self._communicate_with_peer(endpoint)))
+                        pending_tasks.add(asyncio.create_task(self._communicate_with_peer(peer_id)))
 
                 async for averaged_tensor_delta in self.tensor_part_container.iterate_output_tensors():
                     yield averaged_tensor_delta  # delta = averaged_tensor - original_tensor
@@ -125,57 +143,45 @@ class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
                 task.cancel()
             raise
 
-    async def _communicate_with_peer(self, peer_endpoint: Endpoint):
+    async def _communicate_with_peer(self, peer_id: PeerID):
         """Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors"""
-        peer_index = self.ordered_group_endpoints.index(peer_endpoint)
-        if peer_endpoint == self.endpoint:
-            sender_index = self.sender_endpoints.index(peer_endpoint)
+        peer_index = self.ordered_peer_ids.index(peer_id)
+        if peer_id == self.peer_id:
+            sender_index = self.sender_peer_ids.index(peer_id)
             for part_index, tensor_part in enumerate(self.parts_for_local_averaging):
                 averaged_part = await self.tensor_part_reducer.accumulate_part(sender_index, part_index, tensor_part)
                 self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part - tensor_part)
 
         else:
             loop = asyncio.get_event_loop()
-            stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
-            write_task = asyncio.create_task(self._write_to_peer(stream, peer_index))
-
-            try:
-                code = None
-                async for part_index, msg in aenumerate(stream):
-                    if code is None:
-                        code = msg.code
-                    averaged_part_delta = await loop.run_in_executor(None, deserialize_torch_tensor, msg.tensor_part)
-                    self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part_delta)
-                await write_task
-
-                if code != averaging_pb2.AVERAGED_PART:
-                    raise AllreduceException(
-                        f"peer {peer_endpoint} returned {averaging_pb2.MessageCode.Name(code)} "
-                        f"instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)}"
-                        f", allreduce failed"
-                    )
-            finally:
-                if not write_task.done():
-                    write_task.cancel()
-
-    async def _write_to_peer(self, stream: grpc.aio.StreamStreamCall, peer_index: int):
+            code = None
+            stream = self._get_peer_stub(peer_id).rpc_aggregate_part(self._generate_input_for_peer(peer_index))
+            async for part_index, msg in aenumerate(stream):
+                if code is None:
+                    code = msg.code
+                averaged_part_delta = await loop.run_in_executor(None, deserialize_torch_tensor, msg.tensor_part)
+                self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part_delta)
+
+            if code != averaging_pb2.AVERAGED_PART:
+                raise AllreduceException(
+                    f"peer {peer_id} returned {averaging_pb2.MessageCode.Name(code)} "
+                    f"instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)}"
+                    f", allreduce failed"
+                )
+
+    async def _generate_input_for_peer(self, peer_index: int) -> AsyncIterator[averaging_pb2.AveragingData]:
         parts_aiter = self.tensor_part_container.iterate_input_parts_for(peer_index)
         first_part = await anext(parts_aiter)
-        await stream.write(
-            averaging_pb2.AveragingData(
-                code=averaging_pb2.PART_FOR_AVERAGING,
-                group_id=self.group_id,
-                endpoint=self.endpoint,
-                tensor_part=first_part,
-            )
+        yield averaging_pb2.AveragingData(
+            code=averaging_pb2.PART_FOR_AVERAGING,
+            group_id=self.group_id,
+            tensor_part=first_part,
         )
         async for part in parts_aiter:
-            await stream.write(averaging_pb2.AveragingData(tensor_part=part))
-
-        await stream.done_writing()
+            yield averaging_pb2.AveragingData(tensor_part=part)
 
     async def rpc_aggregate_part(
-        self, stream: AsyncIterator[averaging_pb2.AveragingData], context: grpc.ServicerContext
+        self, stream: AsyncIterator[averaging_pb2.AveragingData], context: P2PContext
     ) -> AsyncIterator[averaging_pb2.AveragingData]:
         """a peer sends us a part of his tensor; we should average it with other peers and return the difference"""
         request: averaging_pb2.AveragingData = await anext(stream)
@@ -186,7 +192,7 @@ class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
 
         elif request.code == averaging_pb2.PART_FOR_AVERAGING:
             try:
-                sender_index = self.sender_endpoints.index(request.endpoint)
+                sender_index = self.sender_peer_ids.index(context.remote_id)
                 async for msg in self._accumulate_parts_streaming(achain(aiter(request), stream), sender_index):
                     yield msg
 
@@ -195,8 +201,8 @@ class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
                 yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
         else:
             error_code = averaging_pb2.MessageCode.Name(request.code)
-            logger.debug(f"{self} - peer {request.endpoint} sent {error_code}, allreduce cannot continue")
-            self.finalize(exception=AllreduceException(f"peer {request.endpoint} sent {error_code}."))
+            logger.debug(f"{self} - peer {context.remote_id} sent {error_code}, allreduce cannot continue")
+            self.finalize(exception=AllreduceException(f"peer {context.remote_id} sent {error_code}."))
             yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
 
     def _check_reasons_to_reject(self, request: averaging_pb2.AveragingData) -> Optional[averaging_pb2.AveragingData]:
@@ -223,10 +229,10 @@ class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
             )
             yield averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=serialized_delta)
 
-    async def _send_error_to_peer(self, peer_endpoint: Endpoint, code: averaging_pb2.MessageCode):
-        stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
-        await stream.write(averaging_pb2.AveragingData(group_id=self.group_id, endpoint=self.endpoint, code=code))
-        await stream.done_writing()
+    async def _send_error_to_peer(self, peer_id: PeerID, code: averaging_pb2.MessageCode):
+        error = averaging_pb2.AveragingData(group_id=self.group_id, code=code)
+        # In case of reporting the error, we expect the response stream to contain exactly one item
+        await asingle(self._get_peer_stub(peer_id).rpc_aggregate_part(aiter(error)))
 
     def finalize(self, *, cancel: bool = False, exception: Optional[BaseException] = None):
         """finish or terminate AllReduceRunner, propagate any errors / cancellations to peers."""
@@ -239,9 +245,9 @@ class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
             else:
                 code = averaging_pb2.INTERNAL_ERROR
             logger.debug(f"{self} - notifying peers about {averaging_pb2.MessageCode.Name(code)}")
-            for peer_endpoint, mode in zip(self.ordered_group_endpoints, self.modes):
-                if peer_endpoint != self.endpoint and mode != AveragingMode.CLIENT:
-                    pending_tasks.add(asyncio.create_task(self._send_error_to_peer(peer_endpoint, code)))
+            for peer_id, mode in zip(self.ordered_peer_ids, self.modes):
+                if peer_id != self.peer_id and mode != AveragingMode.CLIENT:
+                    pending_tasks.add(asyncio.create_task(self._send_error_to_peer(peer_id, code)))
 
         if not self._future.done():
             if cancel:

+ 46 - 99
hivemind/averaging/averager.py

@@ -8,42 +8,36 @@ import ctypes
 import multiprocessing as mp
 import os
 import threading
-import uuid
 import weakref
 from concurrent.futures.thread import ThreadPoolExecutor
 from dataclasses import asdict
-from ipaddress import ip_address
-from typing import Sequence, Optional, Tuple, Any, Union, Dict, AsyncIterator
+from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Union
 
-import grpc
 import numpy as np
 import torch
-from grpc._cython.cygrpc import InternalError
 
-from hivemind.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID, AveragingMode
+from hivemind.averaging.allreduce import AllreduceException, AllReduceRunner, AveragingMode, GroupID
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.load_balancing import load_balance_peers
 from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
 from hivemind.averaging.partition import DEFAULT_PART_SIZE_BYTES
 from hivemind.dht import DHT, DHTID
-from hivemind.proto import averaging_pb2, averaging_pb2_grpc, runtime_pb2
-from hivemind.utils import Endpoint, Port, MPFuture, get_logger, TensorDescriptor
-from hivemind.utils.asyncio import anext, achain, aiter, switch_to_uvloop
-from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
-from hivemind.utils.grpc import ChannelCache, GRPC_KEEPALIVE_OPTIONS, split_for_streaming, combine_from_streaming
-from hivemind.utils.networking import choose_ip_address, strip_port, Hostname
+from hivemind.p2p import P2PContext, P2PHandlerError, PeerID, ServicerBase
+from hivemind.proto import averaging_pb2, runtime_pb2
+from hivemind.utils import MPFuture, TensorDescriptor, get_logger
+from hivemind.utils.asyncio import achain, aiter, anext, switch_to_uvloop
+from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.utils.grpc import combine_from_streaming, split_for_streaming
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
-from hivemind.utils.timed_storage import get_dht_time, ValueWithExpiration, DHTExpiration
+from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration, get_dht_time
 
 # flavour types
-StreamCallToLeader = grpc.aio.UnaryStreamCall[averaging_pb2.JoinRequest, averaging_pb2.MessageFromLeader]
 GatheredData = Any
 logger = get_logger(__name__)
 
 
-class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragingServicer):
+class DecentralizedAverager(mp.Process, ServicerBase):
     """
-
     Parameter averaging service. A trainer can run this service in background to periodically average his parameters
     with other trainers. The averaging pattern is chosen so that (1) you only need to average with a small
     group of peers at a time, but (2) all trainers will converge to global average in a logarithmic number of steps.
@@ -67,14 +61,9 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
     :param bandwidth: if specified, this value represents the network bandwidth available to averager.
           By default, the averager is assumed to have the average bandwidth of his group.
           If bandwidth == 0, averager will rely on its groupmates to do all the averaging.
-    :param client_mode: if False (default), this averager will accept incoming requests from other peers
-            if True, the averager will only join existing groups where at least one peer has client_mode=False
-    :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
-    :param announced_host: visible IP address the averager will announce for external connections from other peers.
-          If None, the address will be chosen from p2p.get_visible_maddrs() (global IPv4 addresses are preferred)
-    :param channel_options: options for grpc.aio.insecure_channel, e.g. [('grpc.enable_retries', 0)]
-          see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
-    :param kwargs: extra parameters forwarded to grpc.aio.server
+    :param client_mode: if False, this averager will accept incoming requests from other peers.
+          if True, the averager will only join existing groups where at least one peer has client_mode=False.
+          By default, this flag is copied from DHTNode inside the ``dht`` instance.
     :param auxiliary: if this flag is specified, averager.step will only assist others without sending
           local tensors for averaging
     :param allow_state_sharing: if set to True, other peers can download this peer's state. Can be overwritten
@@ -96,7 +85,6 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
     _matchmaking: Matchmaking
     _pending_group_assembled: asyncio.Event
-    _server: grpc.aio.Server
     serializer = MSGPackSerializer
 
     def __init__(
@@ -119,13 +107,9 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         min_vector_size: int = 0,
         auxiliary: bool = False,
         allow_state_sharing: Optional[bool] = None,
-        client_mode: bool = False,
-        listen_on: Endpoint = "0.0.0.0:*",
+        client_mode: Optional[bool] = None,
         daemon: bool = True,
-        announced_host: Optional[str] = None,
-        channel_options: Sequence[Tuple[str, Any]] = (),
         shutdown_timeout: float = 5,
-        **kwargs,
     ):
         assert "." not in prefix, "group prefix must be a string without trailing '.'"
         assert bandwidth is None or (
@@ -138,7 +122,12 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
         super().__init__()
         self.dht = dht
-        self.client_mode, self.listen_on, self.kwargs = client_mode, listen_on, kwargs
+        self.prefix = prefix
+
+        if client_mode is None:
+            client_mode = dht.client_mode
+        self.client_mode = client_mode
+
         self._parent_pid = os.getpid()
         if self.client_mode:
             self.mode = AveragingMode.CLIENT
@@ -146,11 +135,6 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             self.mode = AveragingMode.AUX
         else:
             self.mode = AveragingMode.NODE
-
-        if announced_host is None:
-            announced_host = self._choose_announced_host()
-        self.announced_host = announced_host
-        self.channel_options = channel_options
         self.daemon = daemon
 
         self._averaged_tensors = tuple(averaged_tensors)
@@ -165,6 +149,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         self.bandwidth = bandwidth
 
         self.matchmaking_kwargs = dict(
+            servicer_type=type(self),
             prefix=prefix,
             initial_group_bits=initial_group_bits,
             target_group_size=target_group_size,
@@ -179,17 +164,12 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         self._running_groups: Dict[GroupID, AllReduceRunner] = {}  # one or more assembled groups that run all-reduce
 
         self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)  # a control pipe used to communicate with daemon
-        self._port = mp.Value(ctypes.c_uint32, 0)  # assigned when averager starts, accessible via self.port
 
         self._allow_state_sharing = mp.Value(ctypes.c_bool, 0)
         if allow_state_sharing is None:
             allow_state_sharing = not client_mode and not auxiliary
         self.allow_state_sharing = allow_state_sharing
 
-        self._averager_endpoint: Optional[Endpoint] = None
-        if self.client_mode:
-            self._averager_endpoint = f"client::{uuid.uuid4()}"
-
         self.ready = mp.Event()  # whether the averager process has started (and ready for incoming requests)
         # note: we create a background thread weakref and with daemon=True to ensure garbage collection
         background_fetcher = threading.Thread(
@@ -201,22 +181,6 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         if start:
             self.run_in_background(await_ready=True)
 
-    def _choose_announced_host(self) -> Hostname:
-        announced_host = strip_port(self.listen_on).strip("[]")  # Stripping square brackets for IPv6
-        if ip_address(announced_host) not in [ip_address("0.0.0.0"), ip_address("::")]:
-            return announced_host
-
-        maddrs = self.dht.get_visible_maddrs()
-        announced_host = choose_ip_address(maddrs)
-        logger.info(
-            f"Choosing IP {announced_host} as endpoint for DecentralizedAverager " f"from visible multiaddrs {maddrs}"
-        )
-        return announced_host
-
-    @property
-    def port(self) -> Optional[Port]:
-        return self._port.value if self._port.value != 0 else None
-
     @property
     def allow_state_sharing(self) -> bool:
         """if set to True, other peers can download this peer's state"""
@@ -230,15 +194,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             self._allow_state_sharing.value = value
 
     @property
-    def endpoint(self) -> Optional[Endpoint]:
-        if self._averager_endpoint is None and not self.client_mode:
-            assert self.port is not None, "Averager is not running yet"
-            self._averager_endpoint = f"{self.announced_host}:{self.port}"
-            logger.debug(f"Assuming averager endpoint to be {self._averager_endpoint}")
-        return self._averager_endpoint
-
-    def __repr__(self):
-        return f"{self.__class__.__name__}({self.endpoint})"
+    def peer_id(self) -> PeerID:
+        return self.dht.peer_id
 
     def run(self):
         """
@@ -257,20 +214,18 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
 
             async def _run():
-                grpc.aio.init_grpc_aio()
-
+                self._p2p = await self.dht.replicate_p2p()
                 if not self.client_mode:
-                    self._server = grpc.aio.server(**self.kwargs, options=GRPC_KEEPALIVE_OPTIONS)
-                    averaging_pb2_grpc.add_DecentralizedAveragingServicer_to_server(self, self._server)
-                    found_port = self._server.add_insecure_port(self.listen_on)
-                    assert found_port != 0, f"Failed to listen to {self.listen_on}"
-                    self._port.value = found_port
-                    await self._server.start()
+                    await self.add_p2p_handlers(self._p2p, namespace=self.prefix)
                 else:
                     logger.debug(f"The averager is running in client mode.")
 
                 self._matchmaking = Matchmaking(
-                    self.endpoint, self.schema_hash, self.dht, **self.matchmaking_kwargs, client_mode=self.client_mode
+                    self._p2p,
+                    self.schema_hash,
+                    self.dht,
+                    client_mode=self.client_mode,
+                    **self.matchmaking_kwargs,
                 )
                 if not self.client_mode:
                     asyncio.create_task(self._declare_for_download_periodically())
@@ -313,8 +268,6 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         remaining_tasks = set()
         for group in self._running_groups.values():
             remaining_tasks.update(group.finalize(cancel=True))
-        if not self.client_mode:
-            remaining_tasks.add(self._server.stop(timeout))
         await asyncio.gather(*remaining_tasks)
 
     def __del__(self):
@@ -328,7 +281,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         timeout: Optional[float] = None,
         allow_retries: bool = True,
         wait: bool = True,
-    ) -> Union[Optional[Dict[Endpoint, GatheredData]], MPFuture]:
+    ) -> Union[Optional[Dict[PeerID, GatheredData]], MPFuture]:
         """
         Set up the averager to look for a group and run one round of averaging, return True on success, False on failure
 
@@ -394,11 +347,9 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                     MatchmakingException,
                     AssertionError,
                     StopAsyncIteration,
-                    InternalError,
                     asyncio.CancelledError,
                     asyncio.InvalidStateError,
-                    grpc.RpcError,
-                    grpc.aio.AioRpcError,
+                    P2PHandlerError,
                 ) as e:
                     time_elapsed = get_dht_time() - start_time
                     if not allow_retries or (timeout is not None and timeout < time_elapsed):
@@ -424,7 +375,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
         try:
             weights, bandwidths, mode_ids, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
-            user_gathered = dict(zip(group_info.endpoints, map(self.serializer.loads, user_gathered)))
+            user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered)))
             modes = tuple(map(AveragingMode, mode_ids))
 
             # compute optimal part sizes from peer bandwidths; TODO: replace with proper load balancing
@@ -437,10 +388,12 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
             async with self.get_tensors_async() as local_tensors:
                 allreduce = AllReduceRunner(
+                    p2p=self._p2p,
+                    servicer_type=type(self),
+                    prefix=self.prefix,
                     group_id=group_info.group_id,
                     tensors=local_tensors,
-                    endpoint=self.endpoint,
-                    ordered_group_endpoints=group_info.endpoints,
+                    ordered_peer_ids=group_info.peer_ids,
                     peer_fractions=peer_fractions,
                     weights=weights,
                     gathered=user_gathered,
@@ -453,7 +406,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                     # actually run all-reduce
                     averaging_outputs = [output async for output in allreduce]
 
-                    if modes[group_info.endpoints.index(self.endpoint)] != AveragingMode.AUX:
+                    if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
                         assert len(local_tensors) == len(self._averaged_tensors)
                         for tensor, update in zip(local_tensors, averaging_outputs):
                             tensor.add_(update, alpha=self._averaging_alpha)
@@ -496,14 +449,14 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             self.lock_averaged_tensors.release()
 
     async def rpc_join_group(
-        self, request: averaging_pb2.JoinRequest, context: grpc.ServicerContext
+        self, request: averaging_pb2.JoinRequest, context: P2PContext
     ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
         """accept or reject a join request from another averager; if accepted, run him through allreduce steps"""
         async for response in self._matchmaking.rpc_join_group(request, context):
             yield response
 
     async def rpc_aggregate_part(
-        self, stream: AsyncIterator[averaging_pb2.AveragingData], context: grpc.ServicerContext
+        self, stream: AsyncIterator[averaging_pb2.AveragingData], context: P2PContext
     ) -> AsyncIterator[averaging_pb2.AveragingData]:
         """a groupmate sends us a part of his tensor; we should average it with other peers and return the result"""
         request = await anext(stream)
@@ -528,7 +481,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                     asyncio.wait_for(
                         self.dht.store(
                             download_key,
-                            subkey=self.endpoint,
+                            subkey=self.peer_id.to_bytes(),
                             value=self.last_updated,
                             expiration_time=get_dht_time() + self._matchmaking.averaging_expiration,
                             return_future=True,
@@ -539,7 +492,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             await asyncio.sleep(self._matchmaking.averaging_expiration)
 
     async def rpc_download_state(
-        self, request: averaging_pb2.DownloadRequest, context: grpc.ServicerContext
+        self, _request: averaging_pb2.DownloadRequest, _context: P2PContext
     ) -> AsyncIterator[averaging_pb2.DownloadData]:
         """
         Get the up-to-date trainer state from a peer.
@@ -594,8 +547,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             key_manager = self._matchmaking.group_key_manager
             peer_priority, _ = self.dht.get(f"{key_manager.prefix}.all_averagers", latest=True) or ({}, None)
             peer_priority = {
-                peer: float(info.value)
-                for peer, info in peer_priority.items()
+                PeerID(peer_id): float(info.value)
+                for peer_id, info in peer_priority.items()
                 if isinstance(info, ValueWithExpiration) and isinstance(info.value, (float, int))
             }
 
@@ -606,13 +559,10 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
             metadata = None
             for peer in sorted(peer_priority.keys(), key=peer_priority.get, reverse=True):
-                if peer != self.endpoint:
+                if peer != self.peer_id:
                     logger.info(f"Downloading parameters from peer {peer}")
-                    stream = None
                     try:
-                        stub = ChannelCache.get_stub(
-                            peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True, options=self.channel_options
-                        )
+                        stub = self.get_stub(self._p2p, peer, namespace=self.prefix)
                         stream = stub.rpc_download_state(averaging_pb2.DownloadRequest())
                         current_tensor_parts, tensors = [], []
                         async for message in stream:
@@ -636,9 +586,6 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                         return
                     except BaseException as e:
                         logger.exception(f"Failed to download state from {peer} - {repr(e)}")
-                    finally:
-                        if stream is not None:
-                            await stream.code()
 
         finally:
             if not future.done():

+ 6 - 6
hivemind/averaging/group_info.py

@@ -1,7 +1,7 @@
 from dataclasses import dataclass
 from typing import Tuple
 
-from hivemind.utils import Endpoint
+from hivemind.p2p import PeerID
 
 
 @dataclass(frozen=True)
@@ -9,12 +9,12 @@ class GroupInfo:
     """A group of peers assembled through decentralized matchmaking"""
 
     group_id: bytes  # random unique bytestring that describes the current group, generated by group leader
-    endpoints: Tuple[Endpoint, ...]  # an ordered sequence of endpoints of each groupmate
-    gathered: Tuple[bytes, ...]  # binary metadata gathered from all peers by leader, same order as endpoints
+    peer_ids: Tuple[PeerID, ...]  # an ordered sequence of peer_ids of each groupmate
+    gathered: Tuple[bytes, ...]  # binary metadata gathered from all peers by leader, same order as peer_ids
 
     @property
     def group_size(self):
-        return len(self.endpoints)
+        return len(self.peer_ids)
 
-    def __contains__(self, endpoint: Endpoint):
-        return endpoint in self.endpoints
+    def __contains__(self, peer_id: PeerID):
+        return peer_id in self.peer_ids

+ 25 - 20
hivemind/averaging/key_manager.py

@@ -1,13 +1,14 @@
 import asyncio
-import re
 import random
-from typing import Optional, List, Tuple
+import re
+from typing import List, Optional, Tuple
 
 import numpy as np
 
-from hivemind.dht import DHT
 from hivemind.averaging.group_info import GroupInfo
-from hivemind.utils import get_logger, Endpoint, DHTExpiration, get_dht_time, ValueWithExpiration
+from hivemind.dht import DHT
+from hivemind.p2p import PeerID
+from hivemind.utils import DHTExpiration, ValueWithExpiration, get_dht_time, get_logger
 
 GroupKey = str
 GROUP_PATTERN = re.compile("^(([^.])+)[.]0b[01]*$")  # e.g. bert_exp4_averaging.0b01001101
@@ -29,7 +30,6 @@ class GroupKeyManager:
     def __init__(
         self,
         dht: DHT,
-        endpoint: Endpoint,
         prefix: str,
         initial_group_bits: Optional[str],
         target_group_size: int,
@@ -43,7 +43,8 @@ class GroupKeyManager:
             search_result = dht.get(f"{prefix}.0b", latest=True)
             initial_group_nbits = self.get_suggested_nbits(search_result) or 0
             initial_group_bits = "".join(random.choice("01") for _ in range(initial_group_nbits))
-        self.dht, self.endpoint, self.prefix, self.group_bits = dht, endpoint, prefix, initial_group_bits
+        self.dht, self.prefix, self.group_bits = dht, prefix, initial_group_bits
+        self.peer_id = dht.peer_id
         self.target_group_size = target_group_size
         self.insufficient_size = insufficient_size or max(1, target_group_size // 2)
         self.excessive_size = excessive_size or target_group_size * 3
@@ -55,13 +56,13 @@ class GroupKeyManager:
         return f"{self.prefix}.0b{self.group_bits}"
 
     async def declare_averager(
-        self, group_key: GroupKey, endpoint: Endpoint, expiration_time: float, looking_for_group: bool = True
+        self, group_key: GroupKey, peer_id: PeerID, expiration_time: float, looking_for_group: bool = True
     ) -> bool:
         """
         Add (or remove) the averager to a given allreduce bucket
 
         :param group_key: allreduce group key, e.g. my_averager.0b011011101
-        :param endpoint: averager public endpoint for incoming requests
+        :param peer_id: averager public peer_id for incoming requests
         :param expiration_time: intent to run allreduce before this timestamp
         :param looking_for_group: by default (True), declare the averager as "looking for group" in a given group;
           If False, this will instead mark that the averager as no longer looking for group, (e.g. it already finished)
@@ -72,20 +73,20 @@ class GroupKeyManager:
         expiration_time = expiration_time if looking_for_group else float(np.nextafter(expiration_time, float("inf")))
         return await self.dht.store(
             key=group_key,
-            subkey=endpoint,
+            subkey=peer_id.to_bytes(),
             value=looking_for_group,
             expiration_time=expiration_time,
             return_future=True,
         )
 
-    async def get_averagers(self, group_key: GroupKey, only_active: bool) -> List[Tuple[Endpoint, DHTExpiration]]:
+    async def get_averagers(self, group_key: GroupKey, only_active: bool) -> List[Tuple[PeerID, DHTExpiration]]:
         """
         Find and return averagers that were declared with a given all-reduce key
 
         :param group_key: finds averagers that have the this group key, e.g. my_averager.0b011011101
         :param only_active: if True, return only active averagers that are looking for group (i.e. with value = True)
             if False, return all averagers under a given group_key regardless of value
-        :return: endpoints and expirations of every matching averager
+        :return: peer_ids and expirations of every matching averager
         """
         assert is_valid_group(group_key), f"Group key {group_key} is invalid, must follow {GROUP_PATTERN}"
         result = await self.dht.get(group_key, latest=True, return_future=True)
@@ -93,11 +94,15 @@ class GroupKeyManager:
             logger.debug(f"Allreduce group not found: {group_key}, creating new group.")
             return []
         averagers = [
-            (key, entry.expiration_time)
-            for key, entry in result.value.items()
-            if key != self.RESERVED_KEY_FOR_NBITS and (not only_active or entry.value is True)
+            (PeerID(key), looking_for_group.expiration_time)
+            for key, looking_for_group in result.value.items()
+            if key != self.RESERVED_KEY_FOR_NBITS and (not only_active or looking_for_group.value)
         ]
-        num_active_averagers = len([key for key, entry in result.value.items() if entry.value is True])
+        num_active_averagers = sum(
+            1
+            for key, looking_for_group in result.value.items()
+            if key != self.RESERVED_KEY_FOR_NBITS and looking_for_group.value
+        )
 
         suggested_nbits = self.get_suggested_nbits(result)
         if (
@@ -106,10 +111,10 @@ class GroupKeyManager:
             and suggested_nbits != self.suggested_nbits
         ):
             self.suggested_nbits = suggested_nbits
-            logger.warning(f"{self.endpoint} - another averager suggested {self.suggested_nbits}-bit keys")
+            logger.warning(f"{self.peer_id} - another averager suggested {self.suggested_nbits}-bit keys")
         elif num_active_averagers >= self.excessive_size:
             self.suggested_nbits = max(suggested_nbits or 0, len(self.group_bits) + 1)
-            logger.warning(f"{self.endpoint} - too many peers in bucket, switching to {self.suggested_nbits}-bit keys")
+            logger.warning(f"{self.peer_id} - too many peers in bucket, switching to {self.suggested_nbits}-bit keys")
         return averagers
 
     async def declare_nbits(self, group_key: GroupKey, nbits: int, expiration_time: DHTExpiration) -> bool:
@@ -136,12 +141,12 @@ class GroupKeyManager:
     async def update_key_on_group_assembled(self, group_info: GroupInfo, is_leader: bool = True):
         """this function is triggered every time an averager finds an allreduce group"""
         rng = random.Random(group_info.group_id)
-        index = group_info.endpoints.index(self.endpoint)
+        index = group_info.peer_ids.index(self.peer_id)
         generalized_index = rng.sample(range(self.target_group_size), group_info.group_size)[index]
         nbits = int(np.ceil(np.log2(self.target_group_size)))
         new_bits = bin(generalized_index)[2:].rjust(nbits, "0")
         self.group_bits = (self.group_bits + new_bits)[-len(self.group_bits) :] if self.group_bits else ""
-        logger.debug(f"{self.endpoint} - updated group key to {self.group_bits}")
+        logger.debug(f"{self.peer_id} - updated group key to {self.group_bits}")
 
         if is_leader and self.insufficient_size < group_info.group_size < self.excessive_size:
             asyncio.create_task(self.notify_stragglers())
@@ -156,7 +161,7 @@ class GroupKeyManager:
         new_nbits = self.suggested_nbits if self.suggested_nbits is not None else len(self.group_bits) - 1
         prev_nbits, self.group_bits = self.group_bits, self.group_bits[-new_nbits:] if new_nbits else ""
         if self.group_bits != prev_nbits:
-            logger.warning(f"{self.endpoint} - switching to {len(self.group_bits)}-bit keys")
+            logger.warning(f"{self.peer_id} - switching to {len(self.group_bits)}-bit keys")
         self.suggested_nbits = None
 
     async def notify_stragglers(self):

+ 2 - 1
hivemind/averaging/load_balancing.py

@@ -1,4 +1,5 @@
-from typing import Sequence, Optional, Tuple
+from typing import Optional, Sequence, Tuple
+
 import numpy as np
 import scipy.optimize
 

+ 91 - 82
hivemind/averaging/matchmaking.py

@@ -2,27 +2,25 @@
 
 from __future__ import annotations
 
+import asyncio
+import concurrent.futures
 import contextlib
 import random
 from math import isfinite
-from typing import Optional, AsyncIterator, Set, Tuple, Dict
-import concurrent.futures
-import asyncio
-
-import grpc
-import grpc._cython.cygrpc
+from typing import AsyncIterator, Dict, Optional, Set, Tuple, Type
 
 from hivemind.averaging.group_info import GroupInfo
-from hivemind.averaging.key_manager import GroupKeyManager, GroupKey
+from hivemind.averaging.key_manager import GroupKey, GroupKeyManager
 from hivemind.dht import DHT, DHTID, DHTExpiration
-from hivemind.utils import get_logger, Endpoint, timed_storage, TimedStorage, get_dht_time
-from hivemind.proto import averaging_pb2, averaging_pb2_grpc
-from hivemind.utils.grpc import ChannelCache
+from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
+from hivemind.proto import averaging_pb2
+from hivemind.utils import TimedStorage, get_dht_time, get_logger, timed_storage
+from hivemind.utils.asyncio import anext
 
 logger = get_logger(__name__)
 
 
-class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
+class Matchmaking:
     f"""
     An internal class that is used to form groups of averages for running allreduce
     See DecentralizedAverager docstring for the detailed description of all parameters
@@ -37,10 +35,11 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
 
     def __init__(
         self,
-        endpoint: Endpoint,
+        p2p: P2P,
         schema_hash: bytes,
         dht: DHT,
         *,
+        servicer_type: Type[ServicerBase],
         prefix: str,
         target_group_size: int,
         min_group_size: int,
@@ -57,8 +56,16 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
             )
 
         super().__init__()
-        self.endpoint, self.schema_hash = endpoint, schema_hash
-        self.group_key_manager = GroupKeyManager(dht, endpoint, prefix, initial_group_bits, target_group_size)
+        self._p2p = p2p
+
+        if not issubclass(servicer_type, ServicerBase):
+            raise TypeError("`servicer_type` is expected to be a ServicerBase subclass")
+        self._servicer_type = servicer_type
+        self._prefix = prefix
+
+        self.peer_id = p2p.peer_id
+        self.schema_hash = schema_hash
+        self.group_key_manager = GroupKeyManager(dht, prefix, initial_group_bits, target_group_size)
         self.target_group_size, self.min_group_size = target_group_size, min_group_size
         self.averaging_expiration, self.request_timeout = averaging_expiration, request_timeout
         self.client_mode = client_mode
@@ -69,9 +76,9 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
         self.was_accepted_to_group = asyncio.Event()
         self.assembled_group = asyncio.Future()
 
-        self.current_leader: Optional[Endpoint] = None  # iff i am a follower, this is a link to my current leader
-        self.current_followers: Dict[Endpoint, averaging_pb2.JoinRequest] = {}  # my current followers excluding myself
-        self.potential_leaders = PotentialLeaders(endpoint, averaging_expiration, target_group_size)
+        self.current_leader: Optional[PeerID] = None  # iff i am a follower, this is a link to my current leader
+        self.current_followers: Dict[PeerID, averaging_pb2.JoinRequest] = {}  # my current followers excluding myself
+        self.potential_leaders = PotentialLeaders(self.peer_id, averaging_expiration, target_group_size)
         self.data_for_gather: Optional[bytes] = None
 
     @property
@@ -87,7 +94,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
                 lfg_status += f" leading {len(self.current_followers)} followers,"
         schema_hash_repr = f"{self.schema_hash[0]}...{self.schema_hash[-8:]}"
         return (
-            f"{self.__class__.__name__}(endpoint={self.endpoint}, schema={schema_hash_repr}, {lfg_status}"
+            f"{self.__class__.__name__}(peer_id={self.peer_id}, schema={schema_hash_repr}, {lfg_status}"
             f" current key = {self.group_key_manager.current_key}, client_mode={self.client_mode})"
         )
 
@@ -160,7 +167,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
                         self.assembled_group.set_exception(e)
                     raise e
 
-    async def request_join_group(self, leader: Endpoint, expiration_time: DHTExpiration) -> Optional[GroupInfo]:
+    async def request_join_group(self, leader: PeerID, expiration_time: DHTExpiration) -> Optional[GroupInfo]:
         """
         :param leader: request this peer to be your leader for allreduce
         :param expiration_time: inform leader that we intend to begin averaging before this expiration_time
@@ -169,23 +176,24 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
           The originally specified leader can disband group and redirect us to a different leader
         """
         assert self.is_looking_for_group and self.current_leader is None
-        call: Optional[grpc.aio.UnaryStreamCall] = None
+        stream: AsyncIterator[averaging_pb2.MessageFromLeader] = None
         try:
             async with self.lock_request_join_group:
-                leader_stub = ChannelCache.get_stub(leader, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
-                call = leader_stub.rpc_join_group(
+                leader_stub = self._servicer_type.get_stub(self._p2p, leader, namespace=self._prefix)
+
+                stream = leader_stub.rpc_join_group(
                     averaging_pb2.JoinRequest(
-                        endpoint=self.endpoint,
                         schema_hash=self.schema_hash,
                         expiration=expiration_time,
                         client_mode=self.client_mode,
                         gather=self.data_for_gather,
+                        group_key=self.group_key_manager.current_key,
                     )
-                )
-                message = await asyncio.wait_for(call.read(), timeout=self.request_timeout)
+                ).__aiter__()
+                message = await asyncio.wait_for(anext(stream), timeout=self.request_timeout)
 
                 if message.code == averaging_pb2.ACCEPTED:
-                    logger.debug(f"{self.endpoint} - joining the group of {leader}; waiting for peers")
+                    logger.debug(f"{self.peer_id} - joining the group of {leader}; waiting for peers")
                     self.current_leader = leader
                     self.was_accepted_to_group.set()
                     if len(self.current_followers) > 0:
@@ -193,56 +201,55 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
 
             if message.code != averaging_pb2.ACCEPTED:
                 code = averaging_pb2.MessageCode.Name(message.code)
-                logger.debug(f"{self.endpoint} - requested {leader} to be my leader, but got rejected with {code}")
+                logger.debug(f"{self.peer_id} - requested {leader} to be my leader, but got rejected with {code}")
                 return None
 
             async with self.potential_leaders.pause_search():
                 time_to_expiration = max(expiration_time - get_dht_time(), 0.0)
-                message = await asyncio.wait_for(call.read(), time_to_expiration + self.request_timeout)
+                message = await asyncio.wait_for(anext(stream), time_to_expiration + self.request_timeout)
 
                 if message.code == averaging_pb2.BEGIN_ALLREDUCE:
                     async with self.lock_request_join_group:
                         return await self.follower_assemble_group(leader, message)
 
             if message.code in (averaging_pb2.GROUP_DISBANDED, averaging_pb2.CANCELLED):
-                if message.suggested_leader and message.suggested_leader != self.endpoint:
-                    logger.debug(f"{self} - leader disbanded group and redirected us to {message.suggested_leader}")
-                    self.current_leader = None
-                    call.cancel()
-                    return await self.request_join_group(message.suggested_leader, expiration_time)
-                else:
-                    logger.debug(f"{self} - leader disbanded group")
-                    return None
+                if message.suggested_leader:
+                    suggested_leader = PeerID(message.suggested_leader)
+                    if suggested_leader != self.peer_id:
+                        logger.debug(f"{self} - leader disbanded group and redirected us to {suggested_leader}")
+                        self.current_leader = None
+                        await stream.aclose()
+                        return await self.request_join_group(suggested_leader, expiration_time)
+                logger.debug(f"{self} - leader disbanded group")
+                return None
 
             logger.debug(f"{self} - unexpected message from leader: {averaging_pb2.MessageCode.Name(message.code)}")
             return None
         except asyncio.TimeoutError:
             logger.debug(f"{self} - potential leader {leader} did not respond within {self.request_timeout}")
-            if call is not None:
-                call.cancel()
             return None
-        except (grpc.RpcError, grpc.aio.AioRpcError, grpc._cython.cygrpc.InternalError, StopAsyncIteration) as e:
+        except (P2PHandlerError, StopAsyncIteration) as e:
             logger.error(f"{self} - failed to request potential leader {leader}: {e}")
             return None
 
         finally:
             self.was_accepted_to_group.clear()
             self.current_leader = None
-            if call is not None:
-                await call.code()
+            if stream is not None:
+                await stream.aclose()
 
     async def rpc_join_group(
-        self, request: averaging_pb2.JoinRequest, context: grpc.ServicerContext
+        self, request: averaging_pb2.JoinRequest, context: P2PContext
     ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
         """accept or reject a join request from another averager; if accepted, run him through allreduce steps"""
         try:
             async with self.lock_request_join_group:
-                reason_to_reject = self._check_reasons_to_reject(request)
+                reason_to_reject = self._check_reasons_to_reject(request, context)
                 if reason_to_reject is not None:
                     yield reason_to_reject
                     return
 
-                self.current_followers[request.endpoint] = request
+                self.current_followers[context.remote_id] = request
                 yield averaging_pb2.MessageFromLeader(code=averaging_pb2.ACCEPTED)
 
                 if len(self.current_followers) + 1 >= self.target_group_size and not self.assembled_group.done():
@@ -270,12 +277,12 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
                 self.was_accepted_to_group.is_set()
                 or not self.assembled_group.done()
                 or self.assembled_group.cancelled()
-                or request.endpoint not in self.assembled_group.result()
+                or context.remote_id not in self.assembled_group.result()
             ):
                 if self.current_leader is not None:
                     # outcome 3: found by a leader with higher priority, send our followers to him
                     yield averaging_pb2.MessageFromLeader(
-                        code=averaging_pb2.GROUP_DISBANDED, suggested_leader=self.current_leader
+                        code=averaging_pb2.GROUP_DISBANDED, suggested_leader=self.current_leader.to_bytes()
                     )
                     return
                 else:
@@ -286,7 +293,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
             yield averaging_pb2.MessageFromLeader(
                 code=averaging_pb2.BEGIN_ALLREDUCE,
                 group_id=group_info.group_id,
-                ordered_group_endpoints=group_info.endpoints,
+                ordered_peer_ids=[item.to_bytes() for item in group_info.peer_ids],
                 gathered=group_info.gathered,
             )
         except (concurrent.futures.CancelledError, asyncio.CancelledError):
@@ -296,11 +303,11 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
             yield averaging_pb2.MessageFromLeader(code=averaging_pb2.INTERNAL_ERROR)
 
         finally:  # note: this code is guaranteed to run even if the coroutine is destroyed prematurely
-            self.current_followers.pop(request.endpoint, None)
+            self.current_followers.pop(context.remote_id, None)
             self.follower_was_discarded.set()
 
     def _check_reasons_to_reject(
-        self, request: averaging_pb2.JoinRequest
+        self, request: averaging_pb2.JoinRequest, context: P2PContext
     ) -> Optional[averaging_pb2.MessageFromLeader]:
         """:returns: if accepted, return None, otherwise return a reason for rejection"""
         if not self.is_looking_for_group or self.assembled_group.done():
@@ -312,24 +319,25 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
             or len(request.schema_hash) == 0
             or not isinstance(request.expiration, DHTExpiration)
             or not isfinite(request.expiration)
-            or not isinstance(request.endpoint, Endpoint)
-            or len(request.endpoint) == 0
             or self.client_mode
+            or not isinstance(request.group_key, GroupKey)
         ):
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.PROTOCOL_VIOLATION)
 
         elif request.schema_hash != self.schema_hash:
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.BAD_SCHEMA_HASH)
+        elif request.group_key != self.group_key_manager.current_key:
+            return averaging_pb2.MessageFromLeader(code=averaging_pb2.BAD_GROUP_KEY)
         elif self.potential_leaders.declared_group_key is None:
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.NOT_DECLARED)
         elif self.potential_leaders.declared_expiration_time > (request.expiration or float("inf")):
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.BAD_EXPIRATION_TIME)
         elif self.current_leader is not None:
             return averaging_pb2.MessageFromLeader(
-                code=averaging_pb2.NOT_A_LEADER, suggested_leader=self.current_leader
-            )  # note: this suggested leader is currently ignored
-        elif request.endpoint == self.endpoint or request.endpoint in self.current_followers:
-            return averaging_pb2.MessageFromLeader(code=averaging_pb2.DUPLICATE_ENDPOINT)
+                code=averaging_pb2.NOT_A_LEADER, suggested_leader=self.current_leader.to_bytes()
+            )
+        elif context.remote_id == self.peer_id or context.remote_id in self.current_followers:
+            return averaging_pb2.MessageFromLeader(code=averaging_pb2.DUPLICATE_PEER_ID)
         elif len(self.current_followers) + 1 >= self.target_group_size:
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_IS_FULL)
         else:
@@ -339,34 +347,35 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
         """Form up all current followers into a group and gather metadata"""
         assert self.lock_looking_for_group.locked() and self.lock_request_join_group.locked() and not self.client_mode
         assert not self.assembled_group.done()
-        group_id = DHTID.generate().to_bytes()  # note: both groupd_id and the order of endpoints must be random
-        ordered_group_endpoints = list(self.current_followers)
-        ordered_group_endpoints.append(self.endpoint)
-        random.shuffle(ordered_group_endpoints)
+        group_id = DHTID.generate().to_bytes()  # note: both groupd_id and the order of peer_ids must be random
+        ordered_peer_ids = list(self.current_followers)
+        ordered_peer_ids.append(self.peer_id)
+        random.shuffle(ordered_peer_ids)
 
         gathered = tuple(
-            self.data_for_gather if endpoint == self.endpoint else self.current_followers[endpoint].gather
-            for endpoint in ordered_group_endpoints
+            self.data_for_gather if peer_id == self.peer_id else self.current_followers[peer_id].gather
+            for peer_id in ordered_peer_ids
         )
 
-        logger.debug(f"{self.endpoint} - assembled group of {len(ordered_group_endpoints)} peers.")
-        group_info = GroupInfo(group_id, tuple(ordered_group_endpoints), gathered)
+        logger.debug(f"{self.peer_id} - assembled group of {len(ordered_peer_ids)} peers.")
+        group_info = GroupInfo(group_id, tuple(ordered_peer_ids), gathered)
         await self.group_key_manager.update_key_on_group_assembled(group_info, is_leader=True)
         self.assembled_group.set_result(group_info)
         return group_info
 
-    async def follower_assemble_group(self, leader: Endpoint, msg: averaging_pb2.MessageFromLeader) -> GroupInfo:
+    async def follower_assemble_group(self, leader: PeerID, msg: averaging_pb2.MessageFromLeader) -> GroupInfo:
         """Form a group from using peers and metadata provided by our leader"""
         assert self.lock_looking_for_group.locked() and self.lock_request_join_group.locked()
         assert not self.assembled_group.done()
         assert self.current_leader == leader, f"averager does not follow {leader} (actual: {self.current_leader})"
 
-        group_id, ordered_group_endpoints = msg.group_id, msg.ordered_group_endpoints
-        assert self.endpoint in ordered_group_endpoints, "Leader sent us group_endpoints that does not contain us!"
-        assert len(ordered_group_endpoints) == len(msg.gathered)
+        group_id = msg.group_id
+        ordered_peer_ids = [PeerID(item) for item in msg.ordered_peer_ids]
+        assert self.peer_id in ordered_peer_ids, "Leader sent us group_peer_ids that does not contain us!"
+        assert len(ordered_peer_ids) == len(msg.gathered)
 
-        logger.debug(f"{self.endpoint} - follower assembled group with leader {leader}.")
-        group_info = GroupInfo(group_id, tuple(ordered_group_endpoints), tuple(msg.gathered))
+        logger.debug(f"{self.peer_id} - follower assembled group with leader {leader}.")
+        group_info = GroupInfo(group_id, tuple(ordered_peer_ids), tuple(msg.gathered))
         await self.group_key_manager.update_key_on_group_assembled(group_info)
         self.assembled_group.set_result(group_info)
         return group_info
@@ -380,13 +389,13 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
 class PotentialLeaders:
     """An utility class that searches for averagers that could become our leaders"""
 
-    def __init__(self, endpoint: Endpoint, averaging_expiration: DHTExpiration, target_group_size: Optional[int]):
-        self.endpoint, self.averaging_expiration = endpoint, averaging_expiration
+    def __init__(self, peer_id: PeerID, averaging_expiration: DHTExpiration, target_group_size: Optional[int]):
+        self.peer_id, self.averaging_expiration = peer_id, averaging_expiration
         self.target_group_size = target_group_size
         self.running, self.update_triggered, self.update_finished = asyncio.Event(), asyncio.Event(), asyncio.Event()
         self.declared_expiration, self.lock_search, self.lock_declare = asyncio.Event(), asyncio.Lock(), asyncio.Lock()
-        self.leader_queue = TimedStorage[Endpoint, DHTExpiration]()
-        self.past_attempts: Set[Tuple[Endpoint, DHTExpiration]] = set()
+        self.leader_queue = TimedStorage[PeerID, DHTExpiration]()
+        self.past_attempts: Set[Tuple[PeerID, DHTExpiration]] = set()
         self.declared_expiration_time = float("inf")
         self.declared_group_key: Optional[GroupKey] = None
         self.max_assured_time = float("-inf")
@@ -433,7 +442,7 @@ class PotentialLeaders:
             else:
                 self.running.clear()
 
-    async def pop_next_leader(self) -> Endpoint:
+    async def pop_next_leader(self) -> PeerID:
         """Remove and return the next most suitable leader or throw an exception if reached timeout"""
         assert self.running.is_set(), "Not running search at the moment"
         while True:
@@ -442,9 +451,9 @@ class PotentialLeaders:
             if maybe_next_leader is None or self.max_assured_time <= entry.expiration_time <= self.search_end_time:
                 self.update_triggered.set()
 
-            if maybe_next_leader is None or (entry.expiration_time, maybe_next_leader) > (
+            if maybe_next_leader is None or (entry.expiration_time, maybe_next_leader.to_bytes()) > (
                 self.declared_expiration_time,
-                self.endpoint,
+                self.peer_id.to_bytes(),
             ):
                 await asyncio.wait(
                     {self.update_finished.wait(), self.declared_expiration.wait()}, return_when=asyncio.FIRST_COMPLETED
@@ -479,7 +488,7 @@ class PotentialLeaders:
 
                 self.leader_queue.clear()
                 for peer, peer_expiration_time in new_peers:
-                    if peer == self.endpoint or (peer, peer_expiration_time) in self.past_attempts:
+                    if peer == self.peer_id or (peer, peer_expiration_time) in self.past_attempts:
                         continue
                     self.leader_queue.store(peer, peer_expiration_time, peer_expiration_time)
                     self.max_assured_time = max(self.max_assured_time, peer_expiration_time - DISCREPANCY)
@@ -495,7 +504,7 @@ class PotentialLeaders:
         except (concurrent.futures.CancelledError, asyncio.CancelledError):
             return  # note: this is a compatibility layer for python3.7
         except Exception as e:
-            logger.error(f"{self.endpoint} - caught {type(e)}: {e}")
+            logger.error(f"{self.peer_id} - caught {type(e)}: {e}")
             raise
 
     async def _declare_averager_periodically(self, key_manager: GroupKeyManager):
@@ -508,21 +517,21 @@ class PotentialLeaders:
                     self.declared_group_key = group_key = key_manager.current_key
                     self.declared_expiration_time = new_expiration_time
                     self.declared_expiration.set()
-                    await key_manager.declare_averager(group_key, self.endpoint, expiration_time=new_expiration_time)
+                    await key_manager.declare_averager(group_key, self.peer_id, expiration_time=new_expiration_time)
                     await asyncio.sleep(self.declared_expiration_time - get_dht_time())
                     if self.running.is_set() and len(self.leader_queue) == 0:
                         await key_manager.update_key_on_not_enough_peers()
             except (concurrent.futures.CancelledError, asyncio.CancelledError):
                 pass  # note: this is a compatibility layer for python3.7
             except Exception as e:  # note: we catch exceptions here because otherwise they are never printed
-                logger.error(f"{self.endpoint} - caught {type(e)}: {e}")
+                logger.error(f"{self.peer_id} - caught {type(e)}: {e}")
             finally:
                 if self.declared_group_key is not None:
                     prev_declared_key, prev_expiration_time = self.declared_group_key, self.declared_expiration_time
                     self.declared_group_key, self.declared_expiration_time = None, float("inf")
-                    self.leader_queue, self.max_assured_time = TimedStorage[Endpoint, DHTExpiration](), float("-inf")
+                    self.leader_queue, self.max_assured_time = TimedStorage[PeerID, DHTExpiration](), float("-inf")
                     await key_manager.declare_averager(
-                        prev_declared_key, self.endpoint, prev_expiration_time, looking_for_group=False
+                        prev_declared_key, self.peer_id, prev_expiration_time, looking_for_group=False
                     )
 
 

+ 6 - 7
hivemind/averaging/partition.py

@@ -2,19 +2,18 @@
 Auxiliary data structures for AllReduceRunner
 """
 import asyncio
-from typing import Sequence, AsyncIterable, Tuple, Optional, TypeVar, Union, AsyncIterator
 from collections import deque
+from typing import AsyncIterable, AsyncIterator, Optional, Sequence, Tuple, TypeVar, Union
 
-import torch
 import numpy as np
+import torch
 
 from hivemind.proto.runtime_pb2 import CompressionType, Tensor
-from hivemind.utils.compression import serialize_torch_tensor, get_nbytes_per_value
 from hivemind.utils.asyncio import amap_in_executor
-
+from hivemind.utils.compression import get_nbytes_per_value, serialize_torch_tensor
 
 T = TypeVar("T")
-DEFAULT_PART_SIZE_BYTES = 2 ** 20
+DEFAULT_PART_SIZE_BYTES = 2 ** 19
 
 
 class TensorPartContainer:
@@ -32,8 +31,8 @@ class TensorPartContainer:
         self,
         tensors: Sequence[torch.Tensor],
         peer_fractions: Sequence[float],
-        compression_type: Union[type(CompressionType), Sequence[type(CompressionType)]] = CompressionType.NONE,
-        part_size_bytes: int = 2 ** 20,
+        compression_type: Union["CompressionType", Sequence["CompressionType"]] = CompressionType.NONE,
+        part_size_bytes: int = DEFAULT_PART_SIZE_BYTES,
         prefetch: int = 1,
     ):
         if not isinstance(compression_type, Sequence):

+ 3 - 3
hivemind/averaging/training.py

@@ -2,13 +2,13 @@
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import nullcontext
 from itertools import chain
-from threading import Lock, Event
-from typing import Sequence, Dict, Iterator, Optional
+from threading import Event, Lock
+from typing import Dict, Iterator, Optional, Sequence
 
 import torch
 
 from hivemind.averaging import DecentralizedAverager
-from hivemind.utils import nested_flatten, nested_pack, get_logger
+from hivemind.utils import get_logger, nested_flatten, nested_pack
 
 logger = get_logger(__name__)
 

+ 55 - 9
hivemind/dht/__init__.py

@@ -23,9 +23,10 @@ from typing import Awaitable, Callable, Iterable, List, Optional, Sequence, Type
 
 from multiaddr import Multiaddr
 
-from hivemind.dht.node import DHTID, DHTNode
-from hivemind.dht.routing import DHTKey, DHTValue, Subkey
+from hivemind.dht.node import DEFAULT_NUM_WORKERS, DHTNode
+from hivemind.dht.routing import DHTID, DHTKey, DHTValue, Subkey
 from hivemind.dht.validation import CompositeValidator, RecordValidatorBase
+from hivemind.p2p import P2P, PeerID
 from hivemind.utils import DHTExpiration, MPFuture, ValueWithExpiration, await_cancelled, get_logger, switch_to_uvloop
 
 logger = get_logger(__name__)
@@ -42,13 +43,14 @@ class DHT(mp.Process):
     :param initial_peers: multiaddrs of one or more active DHT peers (if you want to join an existing DHT)
     :param start: if True, automatically starts the background process on creation. Otherwise await manual start
     :param daemon: if True, the background process is marked as daemon and automatically terminated after main process
-    :param max_workers: declare_experts and get_experts will use up to this many parallel workers
+    :param num_workers: declare_experts and get_experts will use up to this many parallel workers
       (but no more than one per key)
     :param expiration: experts declared from this node expire after this many seconds (default = 5 minutes)
     :param record_validators: instances of RecordValidatorBase used for signing and validating stored records.
       The validators will be combined using the CompositeValidator class. It merges them when possible
       (according to their `.merge_with()` policies) and orders them according to the `.priority` properties.
     :param shutdown_timeout: when calling .shutdown, wait for up to this many seconds before terminating
+    :param await_ready: if True, the constructor waits until the DHT process is ready to process incoming requests
     :param kwargs: any other params will be forwarded to DHTNode and hivemind.p2p.P2P upon creation
     """
 
@@ -60,9 +62,10 @@ class DHT(mp.Process):
         *,
         start: bool,
         daemon: bool = True,
-        max_workers: Optional[int] = None,
+        num_workers: int = DEFAULT_NUM_WORKERS,
         record_validators: Iterable[RecordValidatorBase] = (),
         shutdown_timeout: float = 3,
+        await_ready: bool = True,
         **kwargs,
     ):
         self._parent_pid = os.getpid()
@@ -78,15 +81,21 @@ class DHT(mp.Process):
             raise TypeError("initial_peers should be of type Optional[Sequence[Union[Multiaddr, str]]]")
         self.initial_peers = initial_peers
         self.kwargs = kwargs
-        self.max_workers = max_workers
+        self.num_workers = num_workers
 
         self._record_validator = CompositeValidator(record_validators)
         self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)
         self.shutdown_timeout = shutdown_timeout
         self.ready = mp.Event()
         self.daemon = daemon
+
+        # These values will be fetched from the child process when requested
+        self._peer_id = None
+        self._client_mode = None
+        self._p2p_replica = None
+
         if start:
-            self.run_in_background(await_ready=True)
+            self.run_in_background(await_ready=await_ready)
 
     def run(self) -> None:
         """Serve DHT forever. This function will not return until DHT node is shut down"""
@@ -97,7 +106,7 @@ class DHT(mp.Process):
             async def _run():
                 self._node = await DHTNode.create(
                     initial_peers=self.initial_peers,
-                    num_workers=self.max_workers or 1,
+                    num_workers=self.num_workers,
                     record_validator=self._record_validator,
                     **self.kwargs,
                 )
@@ -251,9 +260,30 @@ class DHT(mp.Process):
 
         self.run_coroutine(partial(DHT._add_validators, record_validators=record_validators))
 
-    async def _add_validators(self, node: DHTNode, record_validators: Iterable[RecordValidatorBase]) -> None:
+    @staticmethod
+    async def _add_validators(_dht: DHT, node: DHTNode, record_validators: Iterable[RecordValidatorBase]) -> None:
         node.protocol.record_validator.extend(record_validators)
 
+    @property
+    def peer_id(self) -> PeerID:
+        if self._peer_id is None:
+            self._peer_id = self.run_coroutine(DHT._get_peer_id)
+        return self._peer_id
+
+    @staticmethod
+    async def _get_peer_id(_dht: DHT, node: DHTNode) -> PeerID:
+        return node.peer_id
+
+    @property
+    def client_mode(self) -> bool:
+        if self._client_mode is None:
+            self._client_mode = self.run_coroutine(DHT._get_client_mode)
+        return self._client_mode
+
+    @staticmethod
+    async def _get_client_mode(_dht: DHT, node: DHTNode) -> bool:
+        return node.protocol.client_mode
+
     def get_visible_maddrs(self, latest: bool = False) -> List[Multiaddr]:
         """
         Get multiaddrs of the current DHT node that should be accessible by other peers.
@@ -263,9 +293,25 @@ class DHT(mp.Process):
 
         return self.run_coroutine(partial(DHT._get_visible_maddrs, latest=latest))
 
-    async def _get_visible_maddrs(self, node: DHTNode, latest: bool = False) -> List[Multiaddr]:
+    @staticmethod
+    async def _get_visible_maddrs(_dht: DHT, node: DHTNode, latest: bool = False) -> List[Multiaddr]:
         return await node.get_visible_maddrs(latest=latest)
 
+    async def replicate_p2p(self) -> P2P:
+        """
+        Get a replica of a P2P instance used in the DHT process internally.
+        The replica uses the same P2P daemon as the DHT and only works while DHT is alive.
+        """
+
+        if self._p2p_replica is None:
+            daemon_listen_maddr = self.run_coroutine(DHT._get_p2p_daemon_listen_maddr)
+            self._p2p_replica = await P2P.replicate(daemon_listen_maddr)
+        return self._p2p_replica
+
+    @staticmethod
+    async def _get_p2p_daemon_listen_maddr(_dht: DHT, node: DHTNode) -> Multiaddr:
+        return node.p2p.daemon_listen_maddr
+
     def __del__(self):
         if self._parent_pid == os.getpid() and self.is_alive():
             self.shutdown()

+ 0 - 1
hivemind/dht/crypto.py

@@ -6,7 +6,6 @@ from hivemind.dht.validation import DHTRecord, RecordValidatorBase
 from hivemind.utils import MSGPackSerializer, get_logger
 from hivemind.utils.crypto import RSAPrivateKey, RSAPublicKey
 
-
 logger = get_logger(__name__)
 
 

+ 10 - 6
hivemind/dht/node.py

@@ -2,8 +2,9 @@ from __future__ import annotations
 
 import asyncio
 import dataclasses
+import os
 import random
-from collections import defaultdict, Counter
+from collections import Counter, defaultdict
 from dataclasses import dataclass, field
 from functools import partial
 from typing import (
@@ -27,17 +28,20 @@ from sortedcontainers import SortedSet
 
 from hivemind.dht.crypto import DHTRecord, RecordValidatorBase
 from hivemind.dht.protocol import DHTProtocol
-from hivemind.dht.routing import DHTID, DHTKey, get_dht_time, DHTValue, BinaryDHTValue, Subkey
+from hivemind.dht.routing import DHTID, BinaryDHTValue, DHTKey, DHTValue, Subkey, get_dht_time
 from hivemind.dht.storage import DictionaryDHTValue
 from hivemind.dht.traverse import traverse_dht
 from hivemind.p2p import P2P, PeerID
-from hivemind.utils import MSGPackSerializer, get_logger, SerializerBase
+from hivemind.utils import MSGPackSerializer, SerializerBase, get_logger
 from hivemind.utils.auth import AuthorizerBase
 from hivemind.utils.timed_storage import DHTExpiration, TimedStorage, ValueWithExpiration
 
 logger = get_logger(__name__)
 
 
+DEFAULT_NUM_WORKERS = int(os.getenv("HIVEMIND_DHT_NUM_WORKERS", 4))
+
+
 class DHTNode:
     """
     Asyncio-based class that represents one DHT participant. Created via await DHTNode.create(...)
@@ -110,7 +114,7 @@ class DHTNode:
         cache_refresh_before_expiry: float = 5,
         cache_on_store: bool = True,
         reuse_get_requests: bool = True,
-        num_workers: int = 1,
+        num_workers: int = DEFAULT_NUM_WORKERS,
         chunk_size: int = 16,
         blacklist_time: float = 5.0,
         backoff_rate: float = 2.0,
@@ -154,7 +158,7 @@ class DHTNode:
         :param backoff_rate: blacklist time will be multiplied by :backoff_rate: for each successive non-response
         :param validate: if True, use initial peers to validate that this node is accessible and synchronized
         :param strict: if True, any error encountered in validation will interrupt the creation of DHTNode
-        :param client_mode: if False (default), this node will accept incoming requests as a full DHT "citzen"
+        :param client_mode: if False (default), this node will accept incoming requests as a full DHT "citizen"
           if True, this node will refuse any incoming requests, effectively being only a client
         :param record_validator: instance of RecordValidatorBase used for signing and validating stored records
         :param authorizer: instance of AuthorizerBase used for signing and validating requests and response
@@ -207,7 +211,7 @@ class DHTNode:
             record_validator,
             authorizer,
         )
-        self.peer_id = p2p.id
+        self.peer_id = p2p.peer_id
 
         if initial_peers:
             initial_peers = {PeerID.from_base58(Multiaddr(item)["p2p"]) for item in initial_peers}

+ 8 - 8
hivemind/dht/protocol.py

@@ -2,20 +2,20 @@
 from __future__ import annotations
 
 import asyncio
-from typing import Optional, List, Tuple, Dict, Sequence, Union, Collection
+from typing import Collection, Dict, List, Optional, Sequence, Tuple, Union
 
 from hivemind.dht.crypto import DHTRecord, RecordValidatorBase
-from hivemind.dht.routing import RoutingTable, DHTID, BinaryDHTValue, Subkey
+from hivemind.dht.routing import DHTID, BinaryDHTValue, RoutingTable, Subkey
 from hivemind.dht.storage import DHTLocalStorage, DictionaryDHTValue
 from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase
 from hivemind.proto import dht_pb2
-from hivemind.utils import get_logger, MSGPackSerializer
-from hivemind.utils.auth import AuthRole, AuthRPCWrapper, AuthorizerBase
+from hivemind.utils import MSGPackSerializer, get_logger
+from hivemind.utils.auth import AuthorizerBase, AuthRole, AuthRPCWrapper
 from hivemind.utils.timed_storage import (
-    DHTExpiration,
-    get_dht_time,
     MAX_DHT_TIME_DISCREPANCY_SECONDS,
+    DHTExpiration,
     ValueWithExpiration,
+    get_dht_time,
 )
 
 logger = get_logger(__name__)
@@ -296,7 +296,7 @@ class DHTProtocol(ServicerBase):
                 nearest = dict(
                     zip(
                         map(DHTID.from_bytes, result.nearest_node_ids),
-                        map(PeerID.from_base58, result.nearest_peer_ids),
+                        map(PeerID, result.nearest_peer_ids),
                     )
                 )
 
@@ -359,7 +359,7 @@ class DHTProtocol(ServicerBase):
                 key_id, k=self.bucket_size, exclude=DHTID.from_bytes(request.peer.node_id)
             ):
                 item.nearest_node_ids.append(node_id.to_bytes())
-                item.nearest_peer_ids.append(peer_id.to_base58())
+                item.nearest_peer_ids.append(peer_id.to_bytes())
             response.results.append(item)
         return response
 

+ 2 - 1
hivemind/dht/routing.py

@@ -7,7 +7,8 @@ import os
 import random
 from collections.abc import Iterable
 from itertools import chain
-from typing import Tuple, Optional, List, Dict, Set, Union, Any, Sequence
+from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
+
 from hivemind.p2p import PeerID
 from hivemind.utils import MSGPackSerializer, get_dht_time
 

+ 1 - 1
hivemind/dht/storage.py

@@ -4,7 +4,7 @@ from typing import Optional, Union
 
 from hivemind.dht.routing import DHTID, BinaryDHTValue, Subkey
 from hivemind.utils.serializer import MSGPackSerializer
-from hivemind.utils.timed_storage import KeyType, ValueType, TimedStorage, DHTExpiration
+from hivemind.utils.timed_storage import DHTExpiration, KeyType, TimedStorage, ValueType
 
 
 @MSGPackSerializer.ext_serializable(0x50)

+ 1 - 1
hivemind/dht/traverse.py

@@ -2,7 +2,7 @@
 import asyncio
 import heapq
 from collections import Counter
-from typing import Dict, Awaitable, Callable, Any, Tuple, List, Set, Collection, Optional
+from typing import Any, Awaitable, Callable, Collection, Dict, List, Optional, Set, Tuple
 
 from hivemind.dht.routing import DHTID
 

+ 2 - 2
hivemind/hivemind_cli/run_server.py

@@ -4,11 +4,11 @@ from pathlib import Path
 import configargparse
 import torch
 
-from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.moe.server import Server
+from hivemind.moe.server.layers import schedule_name_to_scheduler
+from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger
-from hivemind.moe.server.layers import schedule_name_to_scheduler
 
 logger = get_logger(__name__)
 

+ 1 - 1
hivemind/moe/__init__.py

@@ -1,2 +1,2 @@
 from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
-from hivemind.moe.server import ExpertBackend, Server, register_expert_class, get_experts, declare_experts
+from hivemind.moe.server import ExpertBackend, Server, declare_experts, get_experts, register_expert_class

+ 11 - 11
hivemind/moe/client/beam_search.py

@@ -2,22 +2,22 @@ import asyncio
 import heapq
 from collections import deque
 from functools import partial
-from typing import Sequence, Optional, List, Tuple, Dict, Deque, Union, Set, Iterator
+from typing import Deque, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
 
-from hivemind.dht import DHT, DHTNode, DHTExpiration
+from hivemind.dht import DHT, DHTExpiration, DHTNode
 from hivemind.moe.client.expert import RemoteExpert
 from hivemind.moe.server.expert_uid import (
-    ExpertUID,
-    ExpertPrefix,
     FLAT_EXPERT,
-    UidEndpoint,
-    Score,
-    Coordinate,
     PREFIX_PATTERN,
     UID_DELIMITER,
+    Coordinate,
+    ExpertPrefix,
+    ExpertUID,
+    Score,
+    UidEndpoint,
     is_valid_prefix,
 )
-from hivemind.utils import get_logger, get_dht_time, MPFuture
+from hivemind.utils import MPFuture, get_dht_time, get_logger
 
 logger = get_logger(__name__)
 
@@ -125,7 +125,7 @@ class MoEBeamSearcher:
         cache_expiration: DHTExpiration,
         num_workers: Optional[int] = None,
     ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
-        num_workers = num_workers or dht.max_workers or beam_size
+        num_workers = num_workers or dht.num_workers or beam_size
         beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = []
         unattempted_indices: List[Coordinate] = sorted(
             range(len(scores)), key=scores.__getitem__
@@ -206,7 +206,7 @@ class MoEBeamSearcher:
         num_workers: Optional[int] = None,
     ) -> Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]]:
         grid_size = grid_size or float("inf")
-        num_workers = num_workers or min(len(prefixes), dht.max_workers or len(prefixes))
+        num_workers = num_workers or min(len(prefixes), dht.num_workers or len(prefixes))
         dht_responses = await node.get_many(keys=prefixes, num_workers=num_workers)
         successors: Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]] = {}
         for prefix, found in dht_responses.items():
@@ -270,7 +270,7 @@ class MoEBeamSearcher:
         cache_expiration: DHTExpiration,
         num_workers: Optional[int] = None,
     ) -> List[RemoteExpert]:
-        num_workers = num_workers or min(beam_size, dht.max_workers or beam_size)
+        num_workers = num_workers or min(beam_size, dht.num_workers or beam_size)
 
         # form initial beam from top-k active L1 prefixes, each row is (score, uid prefix, possible suffixes)
         beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = await cls._get_initial_beam(

+ 3 - 3
hivemind/moe/client/expert.py

@@ -1,13 +1,13 @@
 import pickle
-from typing import Tuple, Optional, Any, Dict
+from typing import Any, Dict, Optional, Tuple
 
 import torch
 import torch.nn as nn
 from torch.autograd.function import once_differentiable
 
 from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
-from hivemind.utils import nested_flatten, nested_pack, nested_compare, Endpoint
-from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.utils import Endpoint, nested_compare, nested_flatten, nested_pack
+from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.grpc import ChannelCache
 
 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert

+ 6 - 6
hivemind/moe/client/moe.py

@@ -1,8 +1,8 @@
 from __future__ import annotations
 
 import time
-from queue import Queue, Empty
-from typing import Tuple, List, Optional, Dict, Any
+from queue import Empty, Queue
+from typing import Any, Dict, List, Optional, Tuple
 
 import grpc
 import torch
@@ -11,11 +11,11 @@ from torch.autograd.function import once_differentiable
 
 import hivemind
 from hivemind.moe.client.beam_search import MoEBeamSearcher
-from hivemind.moe.client.expert import RemoteExpert, DUMMY, _get_expert_stub
-from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
+from hivemind.moe.client.expert import DUMMY, RemoteExpert, _get_expert_stub
 from hivemind.moe.server.expert_uid import UID_DELIMITER
-from hivemind.utils import nested_pack, nested_flatten, nested_map
-from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
+from hivemind.utils import nested_flatten, nested_map, nested_pack
+from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.logging import get_logger
 
 logger = get_logger(__name__)

+ 3 - 3
hivemind/moe/client/switch_moe.py

@@ -1,14 +1,14 @@
 from __future__ import annotations
 
-from typing import Tuple, List
+from typing import List, Tuple
 
 import grpc
 import torch
 
-from hivemind.moe.client.expert import RemoteExpert, DUMMY
+from hivemind.moe.client.expert import DUMMY, RemoteExpert
 from hivemind.moe.client.moe import RemoteMixtureOfExperts, _RemoteCallMany
 from hivemind.moe.server.expert_uid import UID_DELIMITER
-from hivemind.utils import nested_pack, nested_flatten
+from hivemind.utils import nested_flatten, nested_pack
 from hivemind.utils.logging import get_logger
 
 logger = get_logger(__name__)

+ 12 - 7
hivemind/moe/server/__init__.py

@@ -5,24 +5,29 @@ import multiprocessing.synchronize
 import threading
 from contextlib import contextmanager
 from functools import partial
-from typing import Dict, List, Optional, Tuple
 from pathlib import Path
+from typing import Dict, List, Optional, Tuple
 
 import torch
 from multiaddr import Multiaddr
 
 import hivemind
 from hivemind.dht import DHT
-from hivemind.moe.server.expert_uid import UID_DELIMITER, generate_uids_from_pattern
-from hivemind.moe.server.checkpoints import CheckpointSaver, load_experts, is_directory
+from hivemind.moe.server.checkpoints import CheckpointSaver, is_directory, load_experts
 from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.moe.server.dht_handler import DHTHandlerThread, declare_experts, get_experts
 from hivemind.moe.server.expert_backend import ExpertBackend
-from hivemind.moe.server.layers import name_to_block, name_to_input, register_expert_class
-from hivemind.moe.server.layers import add_custom_models_from_file, schedule_name_to_scheduler
+from hivemind.moe.server.expert_uid import UID_DELIMITER, generate_uids_from_pattern
+from hivemind.moe.server.layers import (
+    add_custom_models_from_file,
+    name_to_block,
+    name_to_input,
+    register_expert_class,
+    schedule_name_to_scheduler,
+)
 from hivemind.moe.server.runtime import Runtime
-from hivemind.utils import Endpoint, get_port, replace_port, find_open_port, get_logger, BatchTensorDescriptor
 from hivemind.proto.runtime_pb2 import CompressionType
+from hivemind.utils import BatchTensorDescriptor, Endpoint, get_free_port, get_logger, get_port, replace_port
 
 logger = get_logger(__name__)
 
@@ -63,7 +68,7 @@ class Server(threading.Thread):
         super().__init__()
         self.dht, self.experts, self.update_period = dht, expert_backends, update_period
         if get_port(listen_on) is None:
-            listen_on = replace_port(listen_on, new_port=find_open_port())
+            listen_on = replace_port(listen_on, new_port=get_free_port())
         self.listen_on, self.port = listen_on, get_port(listen_on)
 
         self.conn_handlers = [ConnectionHandler(listen_on, self.experts) for _ in range(num_connection_handlers)]

+ 3 - 3
hivemind/moe/server/connection_handler.py

@@ -6,11 +6,11 @@ from typing import Dict
 import grpc
 import torch
 
-from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
 from hivemind.moe.server.expert_backend import ExpertBackend
-from hivemind.utils import get_logger, Endpoint, nested_flatten
+from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
+from hivemind.utils import Endpoint, get_logger, nested_flatten
 from hivemind.utils.asyncio import switch_to_uvloop
-from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.grpc import GRPC_KEEPALIVE_OPTIONS
 
 logger = get_logger(__name__)

+ 7 - 7
hivemind/moe/server/dht_handler.py

@@ -1,16 +1,16 @@
 import threading
 from functools import partial
-from typing import Sequence, Dict, List, Tuple, Optional
+from typing import Dict, List, Optional, Sequence, Tuple
 
-from hivemind.dht import DHT, DHTNode, DHTExpiration, DHTValue
+from hivemind.dht import DHT, DHTExpiration, DHTNode, DHTValue
 from hivemind.moe.client.expert import RemoteExpert
 from hivemind.moe.server.expert_uid import (
-    ExpertUID,
-    ExpertPrefix,
     FLAT_EXPERT,
-    Coordinate,
     UID_DELIMITER,
     UID_PATTERN,
+    Coordinate,
+    ExpertPrefix,
+    ExpertUID,
     is_valid_uid,
     split_uid,
 )
@@ -56,7 +56,7 @@ def declare_experts(
 async def _declare_experts(
     dht: DHT, node: DHTNode, uids: List[ExpertUID], endpoint: Endpoint, expiration: DHTExpiration
 ) -> Dict[ExpertUID, bool]:
-    num_workers = len(uids) if dht.max_workers is None else min(len(uids), dht.max_workers)
+    num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
     expiration_time = get_dht_time() + expiration
     data_to_store: Dict[Tuple[ExpertPrefix, Optional[Coordinate]], DHTValue] = {}
     for uid in uids:
@@ -89,7 +89,7 @@ async def _get_experts(
 ) -> List[Optional[RemoteExpert]]:
     if expiration_time is None:
         expiration_time = get_dht_time()
-    num_workers = len(uids) if dht.max_workers is None else min(len(uids), dht.max_workers)
+    num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
     found: Dict[ExpertUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
 
     experts: List[Optional[RemoteExpert]] = [None] * len(uids)

+ 3 - 3
hivemind/moe/server/expert_backend.py

@@ -1,12 +1,12 @@
-from typing import Dict, Sequence, Any, Tuple, Union, Callable
+from typing import Any, Callable, Dict, Sequence, Tuple, Union
 
 import torch
 from torch import nn
 
 from hivemind.moe.server.task_pool import TaskPool
-from hivemind.utils.tensor_descr import BatchTensorDescriptor, DUMMY_BATCH_SIZE
 from hivemind.utils.logging import get_logger
-from hivemind.utils.nested import nested_flatten, nested_pack, nested_compare, nested_map
+from hivemind.utils.nested import nested_compare, nested_flatten, nested_map, nested_pack
+from hivemind.utils.tensor_descr import DUMMY_BATCH_SIZE, BatchTensorDescriptor
 
 logger = get_logger(__name__)
 

+ 1 - 1
hivemind/moe/server/expert_uid.py

@@ -1,6 +1,6 @@
 import random
 import re
-from typing import NamedTuple, Union, Tuple, Optional, List
+from typing import List, NamedTuple, Optional, Tuple, Union
 
 import hivemind
 from hivemind.dht import DHT

+ 1 - 1
hivemind/moe/server/layers/custom_experts.py

@@ -1,5 +1,5 @@
-import os
 import importlib
+import os
 from typing import Callable, Type
 
 import torch

+ 1 - 1
hivemind/moe/server/runtime.py

@@ -4,7 +4,7 @@ import threading
 from collections import defaultdict
 from itertools import chain
 from queue import SimpleQueue
-from selectors import DefaultSelector, EVENT_READ
+from selectors import EVENT_READ, DefaultSelector
 from statistics import mean
 from time import time
 from typing import Dict, NamedTuple, Optional

+ 3 - 3
hivemind/moe/server/task_pool.py

@@ -10,12 +10,12 @@ from abc import ABCMeta, abstractmethod
 from collections import namedtuple
 from concurrent.futures import Future
 from queue import Empty
-from typing import List, Tuple, Dict, Any, Generator
+from typing import Any, Dict, Generator, List, Tuple
 
 import torch
 
 from hivemind.utils import get_logger
-from hivemind.utils.mpfuture import MPFuture, InvalidStateError
+from hivemind.utils.mpfuture import InvalidStateError, MPFuture
 
 logger = get_logger(__name__)
 Task = namedtuple("Task", ("future", "args"))
@@ -100,7 +100,7 @@ class TaskPool(TaskPoolBase):
 
     def submit_task(self, *args: torch.Tensor) -> Future:
         """Add task to this pool's queue, return Future for its output"""
-        task = Task(MPFuture(synchronize=False), args)
+        task = Task(MPFuture(), args)
         if self.get_task_size(task) > self.max_batch_size:
             exc = ValueError(f"Task size greater than max_batch_size ({self.max_batch_size}), it can't be processed")
             task.future.set_exception(exc)

+ 1 - 1
hivemind/optim/__init__.py

@@ -1,4 +1,4 @@
 from hivemind.optim.adaptive import CollaborativeAdaptiveOptimizer
 from hivemind.optim.base import DecentralizedOptimizerBase
 from hivemind.optim.collaborative import CollaborativeOptimizer
-from hivemind.optim.simple import DecentralizedOptimizer, DecentralizedSGD, DecentralizedAdam
+from hivemind.optim.simple import DecentralizedAdam, DecentralizedOptimizer, DecentralizedSGD

+ 1 - 1
hivemind/optim/adaptive.py

@@ -2,8 +2,8 @@ from typing import Sequence
 
 import torch.optim
 
-from hivemind.optim.collaborative import CollaborativeOptimizer
 from hivemind import TrainingAverager
+from hivemind.optim.collaborative import CollaborativeOptimizer
 
 
 class CollaborativeAdaptiveOptimizer(CollaborativeOptimizer):

+ 4 - 4
hivemind/optim/collaborative.py

@@ -2,8 +2,8 @@ from __future__ import annotations
 
 import logging
 from dataclasses import dataclass
-from threading import Thread, Lock, Event
-from typing import Dict, Optional, Iterator
+from threading import Event, Lock, Thread
+from typing import Dict, Iterator, Optional
 
 import numpy as np
 import torch
@@ -42,7 +42,7 @@ class CollaborationState:
 
 
 class TrainingState(BaseModel):
-    endpoint: Endpoint
+    peer_id: bytes
     step: conint(ge=0, strict=True)
     samples_accumulated: conint(ge=0, strict=True)
     samples_per_second: confloat(ge=0.0, strict=True)
@@ -354,7 +354,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             with self.lock_local_progress:
                 current_time = get_dht_time()
                 local_state_info = TrainingState(
-                    endpoint=self.averager.endpoint,
+                    peer_id=self.averager.peer_id.to_bytes(),
                     step=self.local_step,
                     samples_accumulated=self.local_samples_accumulated,
                     samples_per_second=self.performance_ema.samples_per_second,

+ 3 - 3
hivemind/optim/simple.py

@@ -1,13 +1,13 @@
 import time
-from threading import Thread, Lock, Event
+from threading import Event, Lock, Thread
 from typing import Optional, Sequence, Tuple
 
 import torch
 
-from hivemind.dht import DHT
 from hivemind.averaging import TrainingAverager
+from hivemind.dht import DHT
 from hivemind.optim.base import DecentralizedOptimizerBase
-from hivemind.utils import get_logger, get_dht_time
+from hivemind.utils import get_dht_time, get_logger
 
 logger = get_logger(__name__)
 

+ 1 - 1
hivemind/p2p/__init__.py

@@ -1,3 +1,3 @@
-from hivemind.p2p.p2p_daemon import P2P, P2PContext, P2PHandlerError
+from hivemind.p2p.p2p_daemon import P2P, P2PContext, P2PDaemonError, P2PHandlerError
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo
 from hivemind.p2p.servicer import ServicerBase, StubBase

+ 59 - 88
hivemind/p2p/p2p_daemon.py

@@ -5,7 +5,6 @@ from collections.abc import AsyncIterable as AsyncIterableABC
 from contextlib import closing, suppress
 from dataclasses import dataclass
 from importlib.resources import path
-from subprocess import Popen
 from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
 
 from google.protobuf.message import Message
@@ -16,7 +15,7 @@ import hivemind.p2p.p2p_daemon_bindings.p2pclient as p2pclient
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
 from hivemind.p2p.p2p_daemon_bindings.control import P2PHandlerError
 from hivemind.proto.p2pd_pb2 import RPCError
-from hivemind.utils.asyncio import aiter
+from hivemind.utils.asyncio import aiter, asingle
 from hivemind.utils.logging import get_logger
 
 logger = get_logger(__name__)
@@ -45,7 +44,7 @@ class P2P:
       - `P2P.add_binary_stream_handler` transfers raw data using bi-directional streaming interface
 
     To access these handlers, a P2P instance can `P2P.call_protobuf_handler`/`P2P.call_binary_stream_handler`,
-    using the recipient's unique `P2P.id` and the name of the corresponding handler.
+    using the recipient's unique `P2P.peer_id` and the name of the corresponding handler.
     """
 
     HEADER_LEN = 8
@@ -66,11 +65,11 @@ class P2P:
     _UNIX_SOCKET_PREFIX = "/unix/tmp/hivemind-"
 
     def __init__(self):
-        self.id = None
+        self.peer_id = None
         self._child = None
         self._alive = False
+        self._reader_task = None
         self._listen_task = None
-        self._server_stopped = asyncio.Event()
 
     @classmethod
     async def create(
@@ -91,9 +90,7 @@ class P2P:
         use_relay_discovery: bool = False,
         use_auto_relay: bool = False,
         relay_hop_limit: int = 0,
-        quiet: bool = True,
-        ping_n_attempts: int = 5,
-        ping_delay: float = 0.4,
+        startup_timeout: float = 15,
     ) -> "P2P":
         """
         Start a new p2pd process and connect to it.
@@ -114,10 +111,7 @@ class P2P:
         :param use_relay_discovery: enables passive discovery for relay
         :param use_auto_relay: enables autorelay
         :param relay_hop_limit: sets the hop limit for hop relays
-        :param quiet: make the daemon process quiet
-        :param ping_n_attempts: try to ping the daemon with this number of attempts after starting it
-        :param ping_delay: wait for ``ping_delay * (2 ** (k - 1))`` seconds before the k-th attempt to ping the daemon
-          (in particular, wait for ``ping_delay`` seconds before the first attempt)
+        :param startup_timeout: raise a P2PDaemonError if the daemon does not start in ``startup_timeout`` seconds
         :return: a wrapper for the p2p daemon
         """
 
@@ -158,37 +152,26 @@ class P2P:
             autoRelay=use_auto_relay,
             relayHopLimit=relay_hop_limit,
             b=need_bootstrap,
-            q=quiet,
             **process_kwargs,
         )
 
-        self._child = Popen(args=proc_args, encoding="utf8")
+        self._child = await asyncio.subprocess.create_subprocess_exec(
+            *proc_args, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.STDOUT
+        )
         self._alive = True
-        self._client = p2pclient.Client(self._daemon_listen_maddr, self._client_listen_maddr)
 
-        await self._ping_daemon_with_retries(ping_n_attempts, ping_delay)
+        ready = asyncio.Future()
+        self._reader_task = asyncio.create_task(self._read_outputs(ready))
+        try:
+            await asyncio.wait_for(ready, startup_timeout)
+        except asyncio.TimeoutError:
+            await self.shutdown()
+            raise P2PDaemonError(f"Daemon failed to start in {startup_timeout:.1f} seconds")
 
+        self._client = p2pclient.Client(self._daemon_listen_maddr, self._client_listen_maddr)
+        await self._ping_daemon()
         return self
 
-    async def _ping_daemon_with_retries(self, ping_n_attempts: int, ping_delay: float) -> None:
-        for try_number in range(ping_n_attempts):
-            await asyncio.sleep(ping_delay * (2 ** try_number))
-
-            if self._child.poll() is not None:  # Process died
-                break
-
-            try:
-                await self._ping_daemon()
-                break
-            except Exception as e:
-                if try_number == ping_n_attempts - 1:
-                    logger.exception("Failed to ping p2pd that has just started")
-                    await self.shutdown()
-                    raise
-
-        if self._child.returncode is not None:
-            raise RuntimeError(f"The p2p daemon has died with return code {self._child.returncode}")
-
     @classmethod
     async def replicate(cls, daemon_listen_maddr: Multiaddr) -> "P2P":
         """
@@ -213,8 +196,8 @@ class P2P:
         return self
 
     async def _ping_daemon(self) -> None:
-        self.id, self._visible_maddrs = await self._client.identify()
-        logger.debug(f"Launched p2pd with id = {self.id}, host multiaddrs = {self._visible_maddrs}")
+        self.peer_id, self._visible_maddrs = await self._client.identify()
+        logger.debug(f"Launched p2pd with peer id = {self.peer_id}, host multiaddrs = {self._visible_maddrs}")
 
     async def get_visible_maddrs(self, latest: bool = False) -> List[Multiaddr]:
         """
@@ -227,9 +210,9 @@ class P2P:
             _, self._visible_maddrs = await self._client.identify()
 
         if not self._visible_maddrs:
-            raise ValueError(f"No multiaddrs found for peer {self.id}")
+            raise ValueError(f"No multiaddrs found for peer {self.peer_id}")
 
-        p2p_maddr = Multiaddr(f"/p2p/{self.id.to_base58()}")
+        p2p_maddr = Multiaddr(f"/p2p/{self.peer_id.to_base58()}")
         return [addr.encapsulate(p2p_maddr) for addr in self._visible_maddrs]
 
     async def list_peers(self) -> List[PeerInfo]:
@@ -308,15 +291,12 @@ class P2P:
           they will not be received while the prefetch buffer is full.
         """
 
-        if self._listen_task is None:
-            self._start_listening()
-
         async def _handle_stream(
             stream_info: StreamInfo, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
         ) -> None:
             context = P2PContext(
                 handle_name=name,
-                local_id=self.id,
+                local_id=self.peer_id,
                 remote_id=stream_info.peer_id,
             )
             requests = asyncio.Queue(max_prefetch)
@@ -334,7 +314,9 @@ class P2P:
                         await P2P.send_protobuf(response, writer)
                 except Exception as e:
                     logger.warning("Exception while processing stream and sending responses:", exc_info=True)
-                    await P2P.send_protobuf(RPCError(message=str(e)), writer)
+                    # Sometimes `e` is a connection error, so we won't be able to report the error to the caller
+                    with suppress(Exception):
+                        await P2P.send_protobuf(RPCError(message=str(e)), writer)
 
             with closing(writer):
                 processing_task = asyncio.create_task(_process_stream())
@@ -358,12 +340,12 @@ class P2P:
                 finally:
                     processing_task.cancel()
 
-        await self._client.stream_handler(name, _handle_stream)
+        await self.add_binary_stream_handler(name, _handle_stream)
 
     async def _iterate_protobuf_stream_handler(
         self, peer_id: PeerID, name: str, requests: TInputStream, output_protobuf_type: Message
     ) -> TOutputStream:
-        _, reader, writer = await self._client.stream_open(peer_id, (name,))
+        _, reader, writer = await self.call_binary_stream_handler(peer_id, name)
 
         async def _write_to_stream() -> None:
             async for request in requests:
@@ -409,15 +391,7 @@ class P2P:
             return
 
         async def _stream_handler(requests: P2P.TInputStream, context: P2PContext) -> P2P.TOutputStream:
-            if stream_input:
-                input = requests
-            else:
-                count = 0
-                async for input in requests:
-                    count += 1
-                if count != 1:
-                    raise ValueError(f"Got {count} requests for handler {name} instead of one")
-
+            input = requests if stream_input else await asingle(requests)
             output = handler(input, context)
 
             if isinstance(output, AsyncIterableABC):
@@ -448,7 +422,7 @@ class P2P:
             input_serialized = input_protobuf_type.FromString(request)
             context = P2PContext(
                 handle_name=handle_name,
-                local_id=self.id,
+                local_id=self.peer_id,
                 remote_id=remote_id,
             )
 
@@ -468,14 +442,9 @@ class P2P:
         if not isinstance(input, AsyncIterableABC):
             return await self._call_unary_protobuf_handler(peer_id, name, input, output_protobuf_type)
 
-        responses = self._iterate_protobuf_stream_handler(peer_id, name, input, output_protobuf_type)
-
-        count = 0
-        async for response in responses:
-            count += 1
-        if count != 1:
-            raise ValueError(f"Got {count} responses from handler {name} instead of one")
-        return response
+        requests = input if isinstance(input, AsyncIterableABC) else aiter(input)
+        responses = self._iterate_protobuf_stream_handler(peer_id, name, requests, output_protobuf_type)
+        return await asingle(responses)
 
     async def _call_unary_protobuf_handler(
         self,
@@ -501,20 +470,10 @@ class P2P:
     def _start_listening(self) -> None:
         async def listen() -> None:
             async with self._client.listen():
-                await self._server_stopped.wait()
+                await asyncio.Future()  # Wait until this task will be cancelled in _terminate()
 
         self._listen_task = asyncio.create_task(listen())
 
-    async def _stop_listening(self) -> None:
-        if self._listen_task is not None:
-            self._server_stopped.set()
-            self._listen_task.cancel()
-            try:
-                await self._listen_task
-            except asyncio.CancelledError:
-                self._listen_task = None
-                self._server_stopped.clear()
-
     async def add_binary_stream_handler(self, name: str, handler: p2pclient.StreamHandler) -> None:
         if self._listen_task is None:
             self._start_listening()
@@ -533,22 +492,20 @@ class P2P:
         return self._alive
 
     async def shutdown(self) -> None:
-        await self._stop_listening()
-        await asyncio.get_event_loop().run_in_executor(None, self._terminate)
+        self._terminate()
+        if self._child is not None:
+            await self._child.wait()
 
     def _terminate(self) -> None:
-        self._alive = False
-
-        if self._client.control._write_task is not None:
-            self._client.control._write_task.cancel()
-
-        if self._client.control._read_task is not None:
-            self._client.control._read_task.cancel()
+        if self._listen_task is not None:
+            self._listen_task.cancel()
+        if self._reader_task is not None:
+            self._reader_task.cancel()
 
-        if self._child is not None and self._child.poll() is None:
+        self._alive = False
+        if self._child is not None and self._child.returncode is None:
             self._child.terminate()
-            self._child.wait()
-            logger.debug(f"Terminated p2pd with id = {self.id}")
+            logger.debug(f"Terminated p2pd with id = {self.peer_id}")
 
             with suppress(FileNotFoundError):
                 os.remove(self._daemon_listen_maddr["unix"])
@@ -575,6 +532,20 @@ class P2P:
     def _maddrs_to_str(maddrs: List[Multiaddr]) -> str:
         return ",".join(str(addr) for addr in maddrs)
 
+    async def _read_outputs(self, ready: asyncio.Future) -> None:
+        last_line = None
+        while True:
+            line = await self._child.stdout.readline()
+            if not line:  # Stream closed
+                break
+            last_line = line.rstrip().decode(errors="ignore")
+
+            if last_line.startswith("Peer ID:"):
+                ready.set_result(None)
+
+        if not ready.done():
+            ready.set_exception(P2PDaemonError(f"Daemon failed to start: {last_line}"))
+
 
-class P2PInterruptedError(Exception):
+class P2PDaemonError(RuntimeError):
     pass

+ 54 - 27
hivemind/p2p/servicer.py

@@ -1,6 +1,7 @@
 import asyncio
+import inspect
 from dataclasses import dataclass
-from typing import Any, AsyncIterator, Optional, Tuple, get_type_hints
+from typing import Any, AsyncIterator, List, Optional, Tuple, Type, get_type_hints
 
 from hivemind.p2p.p2p_daemon import P2P
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID
@@ -9,7 +10,6 @@ from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID
 @dataclass
 class RPCHandler:
     method_name: str
-    handle_name: str
     request_type: type
     response_type: type
     stream_input: bool
@@ -24,9 +24,10 @@ class StubBase:
     adding the necessary rpc_* methods. Calls to these methods are translated to calls to the remote peer.
     """
 
-    def __init__(self, p2p: P2P, peer: PeerID):
+    def __init__(self, p2p: P2P, peer: PeerID, namespace: Optional[str]):
         self._p2p = p2p
         self._peer = peer
+        self._namespace = namespace
 
 
 class ServicerBase:
@@ -41,39 +42,49 @@ class ServicerBase:
       to calls to the remote peer.
     """
 
-    def __init__(self):
-        class_name = self.__class__.__name__
+    _rpc_handlers: Optional[List[RPCHandler]] = None
+    _stub_type: Optional[Type[StubBase]] = None
 
-        self._rpc_handlers = []
-        for method_name, method in self.__class__.__dict__.items():
-            if method_name.startswith("rpc_") and callable(method):
-                handle_name = f"{class_name}.{method_name}"
+    @classmethod
+    def _collect_rpc_handlers(cls) -> None:
+        if cls._rpc_handlers is not None:
+            return
 
+        cls._rpc_handlers = []
+        for method_name, method in inspect.getmembers(cls, predicate=inspect.isfunction):
+            if method_name.startswith("rpc_"):
+                spec = inspect.getfullargspec(method)
+                if len(spec.args) < 3:
+                    raise ValueError(
+                        f"{method_name} is expected to at least three positional arguments "
+                        f"(self, request: TInputProtobuf | AsyncIterator[TInputProtobuf], context: P2PContext)"
+                    )
+                request_arg = spec.args[1]
                 hints = get_type_hints(method)
                 try:
-                    request_type = hints["request"]
+                    request_type = hints[request_arg]
                     response_type = hints["return"]
                 except KeyError:
                     raise ValueError(
-                        f"{handle_name} is expected to have type annotations "
+                        f"{method_name} is expected to have type annotations "
                         f"like `dht_pb2.FindRequest` or `AsyncIterator[dht_pb2.FindRequest]` "
-                        f"for the `request` parameter and the return value"
+                        f"for the `{request_arg}` parameter and the return value"
                     )
-                request_type, stream_input = self._strip_iterator_hint(request_type)
-                response_type, stream_output = self._strip_iterator_hint(response_type)
+                request_type, stream_input = cls._strip_iterator_hint(request_type)
+                response_type, stream_output = cls._strip_iterator_hint(response_type)
 
-                self._rpc_handlers.append(
-                    RPCHandler(method_name, handle_name, request_type, response_type, stream_input, stream_output)
+                cls._rpc_handlers.append(
+                    RPCHandler(method_name, request_type, response_type, stream_input, stream_output)
                 )
 
-        self._stub_type = type(
-            f"{class_name}Stub",
+        cls._stub_type = type(
+            f"{cls.__name__}Stub",
             (StubBase,),
-            {handler.method_name: self._make_rpc_caller(handler) for handler in self._rpc_handlers},
+            {handler.method_name: cls._make_rpc_caller(handler) for handler in cls._rpc_handlers},
         )
 
-    @staticmethod
-    def _make_rpc_caller(handler: RPCHandler):
+    @classmethod
+    def _make_rpc_caller(cls, handler: RPCHandler):
         input_type = AsyncIterator[handler.request_type] if handler.stream_input else handler.request_type
 
         # This method will be added to a new Stub type (a subclass of StubBase)
@@ -87,7 +98,7 @@ class ServicerBase:
 
                 return self._p2p.iterate_protobuf_handler(
                     self._peer,
-                    handler.handle_name,
+                    cls._get_handle_name(self._namespace, handler.method_name),
                     input,
                     handler.response_type,
                 )
@@ -98,26 +109,42 @@ class ServicerBase:
                 self: StubBase, input: input_type, timeout: Optional[float] = None
             ) -> handler.response_type:
                 return await asyncio.wait_for(
-                    self._p2p.call_protobuf_handler(self._peer, handler.handle_name, input, handler.response_type),
+                    self._p2p.call_protobuf_handler(
+                        self._peer,
+                        cls._get_handle_name(self._namespace, handler.method_name),
+                        input,
+                        handler.response_type,
+                    ),
                     timeout=timeout,
                 )
 
         caller.__name__ = handler.method_name
         return caller
 
-    async def add_p2p_handlers(self, p2p: P2P, wrapper: Any = None) -> None:
+    async def add_p2p_handlers(self, p2p: P2P, wrapper: Any = None, *, namespace: Optional[str] = None) -> None:
+        self._collect_rpc_handlers()
+
         servicer = self if wrapper is None else wrapper
         for handler in self._rpc_handlers:
             await p2p.add_protobuf_handler(
-                handler.handle_name,
+                self._get_handle_name(namespace, handler.method_name),
                 getattr(servicer, handler.method_name),
                 handler.request_type,
                 stream_input=handler.stream_input,
                 stream_output=handler.stream_output,
             )
 
-    def get_stub(self, p2p: P2P, peer: PeerID) -> StubBase:
-        return self._stub_type(p2p, peer)
+    @classmethod
+    def get_stub(cls, p2p: P2P, peer: PeerID, *, namespace: Optional[str] = None) -> StubBase:
+        cls._collect_rpc_handlers()
+        return cls._stub_type(p2p, peer, namespace)
+
+    @classmethod
+    def _get_handle_name(cls, namespace: Optional[str], method_name: str) -> str:
+        handle_name = f"{cls.__name__}.{method_name}"
+        if namespace is not None:
+            handle_name = f"{namespace}::{handle_name}"
+        return handle_name
 
     @staticmethod
     def _strip_iterator_hint(hint: type) -> Tuple[type, bool]:

+ 8 - 14
hivemind/proto/averaging.proto

@@ -2,13 +2,6 @@ syntax = "proto3";
 import "runtime.proto";
 
 
-// Runs alongside each trainer to perform gating function averaging every now and then. Read more: client/averaging.py
-service DecentralizedAveraging {
-  rpc rpc_join_group(JoinRequest) returns (stream MessageFromLeader);  // assemble a group for allreduce
-  rpc rpc_aggregate_part(stream AveragingData) returns (stream AveragingData);  // send local part => get average part
-  rpc rpc_download_state(DownloadRequest) returns (stream DownloadData);
-}
-
 enum MessageCode {
   NO_CODE = 0;               // Default value that should not be used explicitly
   REQUEST_JOIN = 1;          // "Dear maybe leader, will you have me in your group as a follower?"
@@ -21,35 +14,36 @@ enum MessageCode {
   BAD_EXPIRATION_TIME = 8;   // "I will not accept you. I cannot guarantee that we begin before you expire."
   BAD_SCHEMA_HASH = 9;       // "I will not accept you. I am not averaging the samy type of tensors as you."
   BAD_GROUP_ID = 10;         // "I will not accept your request, your group id does not match with any groups i'm in."
-  DUPLICATE_ENDPOINT = 11;   // "I will not accept you, i already have exactly the same endpoint in my current group."
+  DUPLICATE_PEER_ID = 11;    // "I will not accept you, i already have exactly the same peer id in my current group."
   GROUP_IS_FULL = 12;        // "I will not accept you, my group already contains too many peers."
   NOT_LOOKING_FOR_GROUP = 13;// "I'm not available at the moment. Please, get lost."
   PROTOCOL_VIOLATION = 14;   // "You did something so unspeakable that i don't have a special code for that."
   INTERNAL_ERROR = 15;       // "I messed up, we will have to stop allreduce because of that."
   CANCELLED = 16;            // "[from peer during allreduce] I no longer want to participate in AllReduce."
   GROUP_DISBANDED = 17;      // "[from leader] The group is closed. Go find another group."
+  BAD_GROUP_KEY = 18;        // "I will not accept you. My current group key differs (maybe you used my older key)."
 }
 
 message JoinRequest {
-  string endpoint = 1;          // A follower accepts incoming allreduce requests at this address
   bytes schema_hash = 2;        // A hash that describes follower's tensors (shapes, num tensors, etc)
   double expiration = 3;        // Follower would like to **begin** all_reduce by this point in time
   bytes gather = 4;             // optional metadata that is gathered from all peers (e.g. batch size or current loss)
   bool client_mode = 5;         // if True, the incoming averager is a client with no capacity for averaging
+  string group_key = 6;         // group key identifying an All-Reduce bucket, e.g my_averager.0b011011101
 }
 
 message MessageFromLeader {
   MessageCode code = 1;
-  bytes group_id = 2;        // a unique identifier of this group, only valid until allreduce is finished/failed
-  string suggested_leader = 3;  // if peer is already in a group, it'll provide us with an endpoint of its leader
-  repeated string ordered_group_endpoints = 4;  // a sequence of peers, each responsible for one shard during averaging
-  repeated bytes gathered = 5;  // metadata (gather) from all groupmates in the same order as their endpoints
+  bytes group_id = 2;           // a unique identifier of this group, only valid until allreduce is finished/failed
+  bytes suggested_leader = 3;   // if peer is already in a group, it'll provide us with a peer id of its leader
+  repeated bytes ordered_peer_ids = 4;  // a sequence of peers, each responsible for one shard during averaging
+  repeated bytes gathered = 5;  // metadata (gather) from all groupmates in the same order as their peer ids
 }
 
 message AveragingData {
   MessageCode code = 1;     // in case of a protocol violation, this will be the error message
   bytes group_id = 2;       // a unique group identifier, same as in MessageFromLeader
-  string endpoint = 3;      // sender's rpc endpoint, used for coordination
+  bytes peer_id = 3;        // sender's rpc peer_id, used for coordination
   Tensor tensor_part = 4;   // either peer's local tensor part (rpc input) or group average of this part (rpc output)
   bytes metadata = 5;       // reserved user-extendable metadata
 }

+ 2 - 2
hivemind/proto/dht.proto

@@ -65,8 +65,8 @@ message FindResult {
   double expiration_time = 3;          // n/a  | expiration time  | DictionaryDHTValue.latest_expiration_time
 
   // two aligned arrays: DHTIDs and PeerIDs for nearest peers (sorted by XOR distance)
-  repeated bytes nearest_node_ids = 4;      // DHTIDs of the nearest peers serialized with node_id.to_bytes()
-  repeated string nearest_peer_ids = 5;     // Base58-serialized libp2p PeerIDs of the nearest peers
+  repeated bytes nearest_node_ids = 4;  // DHTIDs of the nearest peers serialized with node_id.to_bytes()
+  repeated bytes nearest_peer_ids = 5;  // libp2p PeerIDs of the nearest peers
 }
 
 message FindResponse {

+ 3 - 3
hivemind/utils/__init__.py

@@ -1,11 +1,11 @@
 from hivemind.utils.asyncio import *
-from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.grpc import *
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger
 from hivemind.utils.mpfuture import *
 from hivemind.utils.nested import *
 from hivemind.utils.networking import *
-from hivemind.utils.serializer import SerializerBase, MSGPackSerializer
-from hivemind.utils.tensor_descr import TensorDescriptor, BatchTensorDescriptor
+from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
+from hivemind.utils.tensor_descr import BatchTensorDescriptor, TensorDescriptor
 from hivemind.utils.timed_storage import *

+ 15 - 4
hivemind/utils/asyncio.py

@@ -1,12 +1,11 @@
-from concurrent.futures import ThreadPoolExecutor
-from typing import TypeVar, AsyncIterator, Union, AsyncIterable, Awaitable, Tuple, Optional, Callable
 import asyncio
+from concurrent.futures import ThreadPoolExecutor
+from typing import AsyncIterable, AsyncIterator, Awaitable, Callable, Optional, Tuple, TypeVar, Union
 
 import uvloop
 
 from hivemind.utils.logging import get_logger
 
-
 T = TypeVar("T")
 logger = get_logger(__name__)
 
@@ -59,6 +58,18 @@ async def aenumerate(aiterable: AsyncIterable[T]) -> AsyncIterable[Tuple[int, T]
         index += 1
 
 
+async def asingle(aiter: AsyncIterable[T]) -> T:
+    """If ``aiter`` has exactly one item, returns this item. Otherwise, raises `ValueError`."""
+    count = 0
+    async for item in aiter:
+        count += 1
+        if count == 2:
+            raise ValueError("asingle() expected an iterable with exactly one item, but got two or more items")
+    if count == 0:
+        raise ValueError("asingle() expected an iterable with exactly one item, but got an empty iterable")
+    return item
+
+
 async def await_cancelled(awaitable: Awaitable) -> bool:
     try:
         await awaitable
@@ -73,7 +84,7 @@ async def amap_in_executor(
     func: Callable[..., T],
     *iterables: AsyncIterable,
     max_prefetch: Optional[int] = None,
-    executor: Optional[ThreadPoolExecutor] = None
+    executor: Optional[ThreadPoolExecutor] = None,
 ) -> AsyncIterator[T]:
     """iterate from an async iterable in a background thread, yield results to async iterable"""
     loop = asyncio.get_event_loop()

+ 2 - 3
hivemind/utils/compression.py

@@ -1,15 +1,14 @@
 import os
+import warnings
 from concurrent.futures import ThreadPoolExecutor
-from typing import Tuple, Sequence, Optional
+from typing import Optional, Sequence, Tuple
 
 import numpy as np
 import torch
-import warnings
 
 from hivemind.proto import runtime_pb2
 from hivemind.proto.runtime_pb2 import CompressionType
 
-
 FP32_EPS = 1e-06
 NUM_BYTES_FLOAT32 = 4
 NUM_BYTES_FLOAT16 = 2

+ 2 - 2
hivemind/utils/grpc.py

@@ -6,14 +6,14 @@ from __future__ import annotations
 
 import os
 import threading
-from typing import NamedTuple, Tuple, Optional, Union, Any, Dict, TypeVar, Type, Iterator, Iterable
+from typing import Any, Dict, Iterable, Iterator, NamedTuple, Optional, Tuple, Type, TypeVar, Union
 
 import grpc
 
 from hivemind.proto import runtime_pb2
 from hivemind.utils.logging import get_logger
 from hivemind.utils.networking import Endpoint
-from hivemind.utils.timed_storage import TimedStorage, get_dht_time, ValueWithExpiration
+from hivemind.utils.timed_storage import TimedStorage, ValueWithExpiration, get_dht_time
 
 logger = get_logger(__name__)
 

+ 137 - 161
hivemind/utils/mpfuture.py

@@ -2,18 +2,19 @@ from __future__ import annotations
 
 import asyncio
 import concurrent.futures._base as base
-from contextlib import nullcontext, suppress
 import multiprocessing as mp
 import multiprocessing.connection
 import os
 import threading
 import uuid
-from weakref import ref
+from contextlib import nullcontext
 from enum import Enum, auto
-from typing import Generic, TypeVar, Dict, Optional, Any, Callable, Type, Tuple
+from typing import Any, Callable, Dict, Generic, Optional, Type, TypeVar
+from weakref import ref
 
-from hivemind.utils.logging import get_logger
+import torch  # used for py3.7-compatible shared memory
 
+from hivemind.utils.logging import get_logger
 
 logger = get_logger(__name__)
 
@@ -33,13 +34,38 @@ except ImportError:
         """Raised when attempting to change state of a future in a terminal state (e.g. finished)"""
 
 
-class MessageType(Enum):
+class SharedBytes:
+    """
+    A process-wide object that allocates large chunks of shared memory and partitions it into individual bytes.
+
+    Note: this process is only responsible for bulk allocation, it does not manage/free unused bytes.
+    The chunks are deallocated by the garbage collector,
+    when it detects that all processes no longer use any bytes from this chunk.
+    """
+
+    _lock = mp.Lock()
+    _pid: Optional[PID] = None
+    _buffer: Optional[torch.Tensor] = None
+    _index: int = 0
+
+    @classmethod
+    def next(cls) -> torch.Tensor:
+        """Create another shared byte value, represented as a scalar uint8 tensor"""
+        with cls._lock:
+            if cls._pid != os.getpid() or cls._buffer is None or cls._index >= len(cls._buffer):
+                buffer_size = os.environ.get("HIVEMIND_SHM_BUFFER_SIZE", 4096)
+                cls._pid = os.getpid()
+                cls._buffer = torch.empty([buffer_size], dtype=torch.uint8).share_memory_()
+                cls._index = 0
+
+            cls._index += 1
+            return cls._buffer[cls._index - 1]
+
+
+class UpdateType(Enum):
     RESULT = auto()
     EXCEPTION = auto()
-    RUNNING = auto()
     CANCEL = auto()
-    STATE_REQUEST = auto()
-    STATE_RESPONSE = auto()
 
 
 class MPFuture(base.Future, Generic[ResultType]):
@@ -48,12 +74,9 @@ class MPFuture(base.Future, Generic[ResultType]):
     Any process can access future status and set the result / exception and check for state.
     However, only the original process (i.e. the process that created the future) can await the result or exception.
 
-    :param synchronize: if True (default), future will request state from origin, otherwise it will only use local state
-      Setting synchronize=False results in slightly better performance of done or set_running_or_notify_cancel
     :param use_lock: if True, operations with MPFuture use a global lock to prevent concurrent writes to the same pipe;
       If set to False, writing to this future ignores global lock, slightly improving performance, but making user
       responsible for avoiding concurrent set_result / set_exception calls to futures with the same process of origin.
-    :param loop: if specified, overrides default asyncio event loop for the purpose of awaiting MPFuture
 
     :note: This is an internal primitive that is not guaranteed to work outside of hivemind applications.
      More specifically, there are two known limitations:
@@ -64,26 +87,30 @@ class MPFuture(base.Future, Generic[ResultType]):
 
     _initialization_lock = mp.Lock()  # global lock that prevents simultaneous initialization of two processes
     _update_lock = mp.Lock()  # global lock that prevents simultaneous writing to the same pipe
-    _process_wide_pipe: Optional[PipeEnd] = None  # a pipe that is used to send results/exceptions to this process
+    _global_sender_pipe: Optional[PipeEnd] = None  # a pipe that is used to send results/exceptions to this process
     _pipe_waiter_thread: Optional[threading.Thread] = None  # process-specific thread that receives results/exceptions
-    _active_futures: Optional[Dict[UID, Type[ref][MPFuture]]] = None  # non-done futures originated from this process
-    _status_requests: Optional[Dict[UID, Tuple[MPFuture, threading.Event]]] = None  # futures to be updated by origin
+    _active_futures: Optional[Dict[UID, "ref[MPFuture]"]] = None  # non-done futures originated from this process
     _active_pid: Optional[PID] = None  # pid of currently active process; used to handle forks natively
 
-    SOFT_UPDATE_TIMEOUT = 0.5  # seconds spent awaiting status update before warning is printed
-    HARD_UPDATE_TIMEOUT = 10.0  # seconds spent awaiting status update before future is automatically cancelled
-
-    def __init__(self, *, synchronize: bool = True, use_lock: bool = True):
-        super().__init__()
-        self.synchronize = synchronize
+    def __init__(self, *, use_lock: bool = True):
         self._origin_pid, self._uid = os.getpid(), uuid.uuid4().int
+        self._shared_state_code = SharedBytes.next()
+        self._state_cache: Dict[State, State] = {}
+        # mapping from global to cached local future used that makes updates immediately
+        # available on setter side; dictionary-based cache works because future can visit any state at most once
+
+        base.Future.__init__(self)  # parent init is deferred because it uses self._shared_state_code
         self._state, self._result, self._exception = base.PENDING, None, None
         self._use_lock = use_lock
 
-        self._initialize_backend_if_necessary()
+        if self._origin_pid != MPFuture._active_pid:
+            with MPFuture._initialization_lock:
+                if self._origin_pid != MPFuture._active_pid:
+                    # note: the second if is intentional, see https://en.wikipedia.org/wiki/Double-checked_locking
+                    self._initialize_mpfuture_backend()
         assert self._uid not in MPFuture._active_futures
         MPFuture._active_futures[self._uid] = ref(self)
-        self._sender_pipe = MPFuture._process_wide_pipe
+        self._sender_pipe = MPFuture._global_sender_pipe
 
         try:
             self._loop = asyncio.get_event_loop()
@@ -91,97 +118,83 @@ class MPFuture(base.Future, Generic[ResultType]):
         except RuntimeError:
             self._loop, self._aio_event = None, None
 
-    def _set_event_if_necessary(self):
-        if self._aio_event is None or self._aio_event.is_set():
-            return
+    @property
+    def _state(self) -> State:
+        shared_state = ALL_STATES[self._shared_state_code.item()]
+        return self._state_cache.get(shared_state, shared_state)
+
+    @_state.setter
+    def _state(self, new_state: State):
+        self._shared_state_code[...] = ALL_STATES.index(new_state)
+        if self._state in TERMINAL_STATES and self._loop is not None and not self._aio_event.is_set():
+            self._set_event_threadsafe()
+
+    def _set_event_threadsafe(self):
         try:
-            loop = asyncio.get_running_loop()
+            running_loop = asyncio.get_running_loop()
         except RuntimeError:
-            loop = None
+            running_loop = None
 
         async def _event_setter():
             self._aio_event.set()
 
-        if self._loop.is_running() and loop == self.get_loop():
+        if self._loop.is_running() and running_loop == self._loop:
             asyncio.create_task(_event_setter())
-        elif self._loop.is_running() and loop != self.get_loop():
+        elif self._loop.is_running() and running_loop != self._loop:
             asyncio.run_coroutine_threadsafe(_event_setter(), self._loop)
         else:
             self._loop.run_until_complete(_event_setter())
 
     @classmethod
-    def _initialize_backend_if_necessary(cls):
+    def _initialize_mpfuture_backend(cls):
         pid = os.getpid()
-        if MPFuture._active_pid != pid:
-            with MPFuture._initialization_lock:
-                if MPFuture._active_pid != pid:
-                    # note: the second if is intentional, see https://en.wikipedia.org/wiki/Double-checked_locking
-                    logger.debug(f"Initializing MPFuture backend for pid {pid}")
-                    receiver_pipe, cls._process_wide_pipe = mp.Pipe(duplex=False)
-                    cls._active_pid, cls._active_futures, cls._status_requests = pid, {}, {}
-                    cls._pipe_waiter_thread = threading.Thread(
-                        target=cls._process_updates_in_background,
-                        args=[receiver_pipe],
-                        name=f"{__name__}.BACKEND",
-                        daemon=True,
-                    )
-                    cls._pipe_waiter_thread.start()
+        logger.debug(f"Initializing MPFuture backend for pid {pid}")
+
+        receiver_pipe, cls._global_sender_pipe = mp.Pipe(duplex=False)
+        cls._active_pid, cls._active_futures = pid, {}
+        cls._pipe_waiter_thread = threading.Thread(
+            target=cls._process_updates_in_background, args=[receiver_pipe], name=f"{__name__}.BACKEND", daemon=True
+        )
+        cls._pipe_waiter_thread.start()
+
+    @staticmethod
+    def reset_backend():
+        """Last-resort function to reset internals of MPFuture. All current MPFuture instances will be broken"""
+        MPFuture._active_pid = None
+        MPFuture._initialization_lock = mp.Lock()
+        MPFuture._update_lock = mp.Lock()
+        SharedBytes._lock = mp.Lock()
 
     @classmethod
     def _process_updates_in_background(cls, receiver_pipe: mp.connection.Connection):
         pid = os.getpid()
         while True:
             try:
-                uid, msg_type, payload = receiver_pipe.recv()
+                if cls._pipe_waiter_thread is not threading.current_thread():
+                    break  # backend was reset, a new background thread has started
+
+                uid, update_type, payload = receiver_pipe.recv()
                 future = None
-                future_ref = cls._active_futures.get(uid)
+                future_ref = cls._active_futures.pop(uid, None)
                 if future_ref is not None:
                     future = future_ref()
 
-                if msg_type == MessageType.STATE_REQUEST:
-                    future_state = None if future is None else future.__getstate__()
-                    use_lock, return_pipe = payload
-                    with MPFuture._update_lock if use_lock else nullcontext():
-                        return_pipe.send((uid, MessageType.STATE_RESPONSE, future_state))
-
-                elif msg_type == MessageType.STATE_RESPONSE:
-                    future, state_updated_event = cls._status_requests.get(uid, (None, None))
-                    if future is None:
-                        logger.debug("Received a state update for a future that does not await status update.")
-                    else:
-                        if payload is not None:
-                            future.__setstate__(payload)
-                        else:
-                            base.Future.cancel(future)
-                        state_updated_event.set()
-
-                elif future is None:
-                    logger.debug(
-                        f"Received {msg_type} for MPFuture uid={uid}, but future is already done or destroyed"
-                    )
-                elif msg_type == MessageType.RESULT:
+                if future is None:
+                    logger.debug(f"Ignoring update to future with uid={uid}: the future is already done or destroyed")
+                elif update_type == UpdateType.RESULT:
                     future.set_result(payload)
-                elif msg_type == MessageType.EXCEPTION:
+                elif update_type == UpdateType.EXCEPTION:
                     future.set_exception(payload)
-                elif msg_type == MessageType.RUNNING:
-                    try:
-                        future.set_running_or_notify_cancel()
-                    except (InvalidStateError, RuntimeError) as e:
-                        logger.debug(f"Could not set MPFuture (uid={uid}) to running due to {e}", exc_info=True)
-                elif msg_type == MessageType.CANCEL:
+                elif update_type == UpdateType.CANCEL:
                     future.cancel()
                 else:
-                    raise RuntimeError(f"Received unexpected update type {msg_type}")
-
-                if future is None or future.done():
-                    cls._active_futures.pop(uid, None)
-
+                    raise RuntimeError(f"Received unexpected update type {update_type}")
             except (BrokenPipeError, EOFError, ConnectionError):
                 logger.debug(f"Update pipe was was shut down unexpectedly (pid={pid})")
             except Exception as e:
                 logger.exception(f"Could not retrieve update: caught {repr(e)} (pid={pid})")
 
-    def _send_update(self, update_type: MessageType, payload: Any = None):
+    def _send_update(self, update_type: UpdateType, payload: Any = None):
         """This method sends result, exception or cancel to the MPFuture origin."""
         try:
             with MPFuture._update_lock if self._use_lock else nullcontext():
@@ -189,110 +202,76 @@ class MPFuture(base.Future, Generic[ResultType]):
         except (ConnectionError, BrokenPipeError, EOFError) as e:
             logger.debug(f"No updates were sent: pipe to origin process was broken ({e}).", exc_info=True)
 
-    def _synchronize_if_necessary(self):
-        if not self.synchronize or os.getpid() == self._origin_pid or self._state in TERMINAL_STATES:
-            return
-
-        self._initialize_backend_if_necessary()
-
-        status_updated = threading.Event()
-        _, existing_status_event = self._status_requests.setdefault(self._uid, (self, status_updated))
-        # this line checks if another thread is synchronizing concurrently, assuming that setdefault to be atomic
-
-        if existing_status_event != status_updated:
-            existing_status_event.wait(MPFuture.HARD_UPDATE_TIMEOUT)
-            return
-
-        # otherwise create a new request for synchronization
-
-        try:
-            with MPFuture._update_lock if self._use_lock else nullcontext():
-                payload = (self._use_lock, self._process_wide_pipe)
-                self._sender_pipe.send((self._uid, MessageType.STATE_REQUEST, payload))
-            status_updated.wait(MPFuture.SOFT_UPDATE_TIMEOUT)
-            if not status_updated.is_set():
-                logger.warning(f"Status update took over {MPFuture.SOFT_UPDATE_TIMEOUT}, expect performance issues")
-                status_updated.wait(MPFuture.HARD_UPDATE_TIMEOUT - MPFuture.SOFT_UPDATE_TIMEOUT)
-                if not status_updated.is_set() and not self.cancel():
-                    with suppress(InvalidStateError, RuntimeError):
-                        self.set_exception(
-                            TimeoutError(
-                                f"Status update took over {MPFuture.HARD_UPDATE_TIMEOUT} seconds, "
-                                f"MPFuture is cancelled"
-                            )
-                        )
-                    status_updated.set()  # this triggers any concurrent _synchronize_if_necessary calls to finish
-        except (ConnectionError, BrokenPipeError, EOFError) as e:
-            logger.error(f"MPFuture was cancelled because sender pipe is broken. Origin process is probably down.")
-            if not self.cancel():
-                with suppress(InvalidStateError, RuntimeError):
-                    self.set_exception(e)
-        finally:
-            self._status_requests.pop(self._uid, None)
-
     def set_result(self, result: ResultType):
-        if self._state in TERMINAL_STATES:
-            raise InvalidStateError(f"Can't set_result to a future that is {self._state} ({self._uid})")
-        elif os.getpid() == self._origin_pid:
+        if os.getpid() == self._origin_pid:
+            super().set_result(result)
             MPFuture._active_futures.pop(self._uid, None)
-            self._set_event_if_necessary()
+        elif self._state in TERMINAL_STATES:
+            raise InvalidStateError(f"Can't set_result to a future that is {self._state} ({self._uid})")
         else:
-            self._send_update(MessageType.RESULT, result)
-        super().set_result(result)
+            self._state_cache[self._state], self._result = base.FINISHED, result
+            self._send_update(UpdateType.RESULT, result)
 
     def set_exception(self, exception: Optional[BaseException]):
-        if self._state in TERMINAL_STATES:
-            raise InvalidStateError(f"Can't set_exception to a future that is {self._state} ({self._uid})")
-        elif os.getpid() == self._origin_pid:
+        if os.getpid() == self._origin_pid:
+            super().set_exception(exception)
             MPFuture._active_futures.pop(self._uid, None)
-            self._set_event_if_necessary()
+        elif self._state in TERMINAL_STATES:
+            raise InvalidStateError(f"Can't set_exception to a future that is {self._state} ({self._uid})")
         else:
-            self._send_update(MessageType.EXCEPTION, exception)
-        super().set_exception(exception)
+            self._state_cache[self._state], self._exception = base.FINISHED, exception
+            self._send_update(UpdateType.EXCEPTION, exception)
 
     def cancel(self) -> bool:
-        if self._state in [base.RUNNING, base.FINISHED]:
-            return False
-        elif os.getpid() == self._origin_pid:
+        if os.getpid() == self._origin_pid:
             MPFuture._active_futures.pop(self._uid, None)
-            self._set_event_if_necessary()
+            return super().cancel()
+        elif self._state in [base.RUNNING, base.FINISHED]:
+            return False
         else:
-            self._send_update(MessageType.CANCEL)
-        return super().cancel()
+            self._state_cache[self._state] = base.CANCELLED
+            self._send_update(UpdateType.CANCEL)
+            return True
 
     def set_running_or_notify_cancel(self):
-        """if synchronize is set to False, this future will ignore any state changes from origin"""
-        self._synchronize_if_necessary()
-        try:
-            is_running = super().set_running_or_notify_cancel()
-            if is_running and os.getpid() != self._origin_pid:
-                self._send_update(MessageType.RUNNING)
-            return is_running
-        except RuntimeError as e:
-            raise InvalidStateError(str(e))
+        if self._state == base.PENDING:
+            self._state = base.RUNNING
+            return True
+        elif self._state == base.CANCELLED:
+            return False
+        else:
+            raise InvalidStateError(
+                f"Can't set_running_or_notify_cancel when future is in {self._state} ({self._uid})"
+            )
 
     def result(self, timeout: Optional[float] = None) -> ResultType:
         if self._state not in TERMINAL_STATES:
             if os.getpid() != self._origin_pid:
                 raise RuntimeError("Only the process that created MPFuture can await result")
-        return super().result(timeout)
+            return super().result(timeout)
+        elif self._state == base.CANCELLED:
+            raise base.CancelledError()
+        elif self._exception:
+            raise self._exception
+        else:
+            return self._result
 
     def exception(self, timeout: Optional[float] = None) -> Optional[BaseException]:
         if self._state not in TERMINAL_STATES:
             if os.getpid() != self._origin_pid:
                 raise RuntimeError("Only the process that created MPFuture can await exception")
-        return super().exception(timeout)
+            return super().exception(timeout)
+        elif self._state == base.CANCELLED:
+            raise base.CancelledError()
+        return self._exception
 
     def done(self) -> bool:
-        self._synchronize_if_necessary()
         return self._state in TERMINAL_STATES
 
     def running(self):
-        self._synchronize_if_necessary()
         return self._state == base.RUNNING
 
     def cancelled(self):
-        self._synchronize_if_necessary()
         return self._state == base.CANCELLED
 
     def add_done_callback(self, callback: Callable[[MPFuture], None]):
@@ -300,9 +279,6 @@ class MPFuture(base.Future, Generic[ResultType]):
             raise RuntimeError("Only the process that created MPFuture can set callbacks")
         return super().add_done_callback(callback)
 
-    def get_loop(self) -> Optional[asyncio.BaseEventLoop]:
-        return self._loop
-
     def __await__(self):
         if not self._aio_event:
             raise RuntimeError("Can't await: MPFuture was created with no event loop")
@@ -320,9 +296,8 @@ class MPFuture(base.Future, Generic[ResultType]):
 
     def __getstate__(self):
         return dict(
-            synchronize=self.synchronize,
             _sender_pipe=self._sender_pipe,
-            _state=self._state,
+            _shared_state_code=self._shared_state_code,
             _origin_pid=self._origin_pid,
             _uid=self._uid,
             _use_lock=self._use_lock,
@@ -331,12 +306,13 @@ class MPFuture(base.Future, Generic[ResultType]):
         )
 
     def __setstate__(self, state):
-        self.synchronize = state["synchronize"]
         self._sender_pipe = state["_sender_pipe"]
-        self._state, self._origin_pid, self._uid = state["_state"], state["_origin_pid"], state["_uid"]
+        self._shared_state_code = state["_shared_state_code"]
+        self._origin_pid, self._uid = state["_origin_pid"], state["_uid"]
         self._result, self._exception = state["_result"], state["_exception"]
         self._use_lock = state["_use_lock"]
 
         self._waiters, self._done_callbacks = [], []
         self._condition = threading.Condition()
         self._aio_event, self._loop = None, None
+        self._state_cache = {}

+ 1 - 2
hivemind/utils/networking.py

@@ -5,7 +5,6 @@ from typing import Optional, Sequence
 
 from multiaddr import Multiaddr
 
-
 Hostname, Port = str, int  # flavour types
 Endpoint = str  # e.g. 1.2.3.4:1337 or [2a21:6с8:b192:2105]:8888, https://networkengineering.stackexchange.com/a/9435
 LOCALHOST = "127.0.0.1"
@@ -31,7 +30,7 @@ def strip_port(endpoint: Endpoint) -> Hostname:
     return endpoint[: endpoint.rindex(":")] if maybe_port.isdigit() or maybe_port == "*" else endpoint
 
 
-def find_open_port(params=(socket.AF_INET, socket.SOCK_STREAM), opt=(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)):
+def get_free_port(params=(socket.AF_INET, socket.SOCK_STREAM), opt=(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)):
     """Finds a tcp port that can be occupied with a socket with *params and use *opt options"""
     try:
         with closing(socket.socket(*params)) as sock:

+ 1 - 1
hivemind/utils/serializer.py

@@ -1,6 +1,6 @@
 """ A unified interface for several common serialization methods """
-from typing import Dict, Any
 from abc import ABC, abstractmethod
+from typing import Any, Dict
 
 import msgpack
 

+ 1 - 1
hivemind/utils/tensor_descr.py

@@ -1,5 +1,5 @@
 import warnings
-from dataclasses import dataclass, asdict
+from dataclasses import asdict, dataclass
 
 import torch
 

+ 2 - 1
hivemind/utils/timed_storage.py

@@ -1,10 +1,11 @@
 """ A dictionary-like storage that stores items until a specified expiration time or up to a limited size """
 from __future__ import annotations
+
 import heapq
 import time
 from contextlib import contextmanager
-from typing import TypeVar, Generic, Optional, Dict, List, Iterator, Tuple
 from dataclasses import dataclass
+from typing import Dict, Generic, Iterator, List, Optional, Tuple, TypeVar
 
 KeyType = TypeVar("KeyType")
 ValueType = TypeVar("ValueType")

+ 7 - 0
pyproject.toml

@@ -1,3 +1,10 @@
 [tool.black]
 line-length = 119
 required-version = "21.6b0"
+
+[tool.isort]
+profile = "black"
+line_length = 119
+combine_as_imports = true
+combine_star = true
+known_local_folder = ["arguments", "test_utils", "tests", "utils"]

+ 1 - 1
requirements-dev.txt

@@ -2,8 +2,8 @@ pytest
 pytest-forked
 pytest-asyncio
 pytest-cov
-codecov
 tqdm
 scikit-learn
 black==21.6b0
+isort
 psutil

+ 22 - 2
tests/conftest.py

@@ -1,15 +1,33 @@
+import asyncio
 import gc
+import multiprocessing as mp
 from contextlib import suppress
 
 import psutil
 import pytest
 
-from hivemind.utils import get_logger
-
+from hivemind.utils.logging import get_logger
+from hivemind.utils.mpfuture import MPFuture, SharedBytes
 
 logger = get_logger(__name__)
 
 
+@pytest.fixture
+def event_loop():
+    """
+    This overrides the ``event_loop`` fixture from pytest-asyncio
+    (e.g. to make it compatible with ``asyncio.subprocess``).
+
+    This fixture is identical to the original one but does not call ``loop.close()`` in the end.
+    Indeed, at this point, the loop is already stopped (i.e. next tests are free to create new loops).
+    However, finalizers of objects created in the current test may reference the current loop and fail if it is closed.
+    For example, this happens while using ``asyncio.subprocess`` (the ``asyncio.subprocess.Process`` finalizer
+    fails if the loop is closed, but works if the loop is only stopped).
+    """
+
+    yield asyncio.get_event_loop()
+
+
 @pytest.fixture(autouse=True, scope="session")
 def cleanup_children():
     yield
@@ -26,3 +44,5 @@ def cleanup_children():
         for child in children:
             with suppress(psutil.NoSuchProcess):
                 child.kill()
+
+    MPFuture.reset_backend()

+ 28 - 47
tests/test_allreduce.py

@@ -3,16 +3,15 @@ import random
 import time
 from typing import Sequence
 
-import grpc
 import pytest
 import torch
 
-from hivemind import aenumerate, Endpoint
+from hivemind import aenumerate
 from hivemind.averaging.allreduce import AllReduceRunner, AveragingMode
 from hivemind.averaging.partition import TensorPartContainer, TensorPartReducer
-from hivemind.proto import averaging_pb2_grpc
+from hivemind.p2p import P2P, StubBase
 from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.utils import deserialize_torch_tensor, ChannelCache
+from hivemind.utils import deserialize_torch_tensor
 
 
 @pytest.mark.forked
@@ -152,46 +151,34 @@ async def test_reducer(num_senders: int, num_parts: int, synchronize_prob: float
             assert torch.allclose(averaging_result, reference_tensor, rtol=1e-3, atol=1e-5)
 
 
-class AllreduceRunnerForTesting(AllReduceRunner):
-    """a version of AllReduceRunner that was monkey-patched to accept custom endpoint names"""
-
-    def __init__(self, *args, peer_endpoints, **kwargs):
-        self.__peer_endpoints = peer_endpoints
-        super().__init__(*args, **kwargs)
-
-    def _get_peer_stub(self, peer: Endpoint) -> averaging_pb2_grpc.DecentralizedAveragingStub:
-        return ChannelCache.get_stub(
-            self.__peer_endpoints[peer], averaging_pb2_grpc.DecentralizedAveragingStub, aio=True
-        )
-
-
 NODE, CLIENT, AUX = AveragingMode.NODE, AveragingMode.CLIENT, AveragingMode.AUX
 
 
 @pytest.mark.parametrize(
-    "peer_modes, averaging_weights, peer_fractions",
+    "peer_modes, averaging_weights, peer_fractions, part_size_bytes",
     [
-        ((NODE, NODE, NODE, NODE), (1, 1, 1, 1), (1, 1, 1, 1)),
-        ((NODE, NODE, NODE, NODE), (0.1, 0.2, 0.3, 0.4), (1, 1, 1, 1)),
-        ((NODE, NODE, NODE, NODE), (1, 1, 1, 1), (1, 2, 3, 0)),
-        ((NODE, NODE, NODE, CLIENT), (1, 1, 1, 1), (1, 2, 3, 0)),
-        ((NODE, NODE, NODE, AUX), (1, 1, 1, 0), (1, 2, 3, 4)),
-        ((NODE, NODE, NODE, NODE), (0.15, 0.0, 0.35, 0.45), (1, 1, 1, 1)),
-        ((NODE, AUX, NODE, CLIENT), (0.15, 0.0, 0.35, 0.45), (150, 200, 67, 0)),
-        ((AUX, AUX, AUX, AUX), (0.0, 0.0, 0.0, 0.0), (1, 2, 3, 4)),
+        ((NODE, NODE, NODE, NODE), (1, 1, 1, 1), (1, 1, 1, 1), 2 ** 20),
+        ((NODE, NODE, NODE, NODE), (0.1, 0.2, 0.3, 0.4), (1, 1, 1, 1), 2 ** 20),
+        ((NODE, NODE, NODE, NODE), (1, 1, 1, 1), (1, 2, 3, 0), 2 ** 20),
+        ((NODE, NODE, NODE, CLIENT), (1, 1, 1, 1), (1, 2, 3, 0), 2 ** 20),
+        ((NODE, NODE, NODE, AUX), (1, 1, 1, 0), (1, 2, 3, 4), 2 ** 20),
+        ((NODE, NODE, NODE, NODE), (0.15, 0.0, 0.35, 0.45), (1, 1, 1, 1), 2 ** 20),
+        ((NODE, AUX, NODE, CLIENT), (0.15, 0.0, 0.35, 0.45), (150, 200, 67, 0), 2 ** 20),
+        ((NODE, AUX, NODE, CLIENT), (0.15, 0.0, 0.35, 0.45), (150, 200, 67, 0), 256),
+        ((NODE, AUX, NODE, CLIENT), (0.15, 0.0, 0.35, 0.45), (150, 200, 67, 0), 19),
+        ((AUX, AUX, AUX, AUX), (0.0, 0.0, 0.0, 0.0), (1, 2, 3, 4), 2 ** 20),
     ],
 )
-@pytest.mark.parametrize(
-    "part_size_bytes",
-    [2 ** 20, 256, 19],
-)
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions, part_size_bytes):
     """Run group allreduce protocol manually without grpc, see if the internal logic is working as intended"""
 
-    peers = "alice", "bob", "carol", "colab"
+    p2ps = [await P2P.create()]
+    visible_maddrs = await p2ps[0].get_visible_maddrs()
+    p2ps += await asyncio.gather(*[P2P.create(initial_peers=visible_maddrs) for _ in range(3)])
 
+    peers = [instance.peer_id for instance in p2ps]
     tensors_by_peer = {
         peer: [torch.randn(3, 128), torch.rand(32), torch.tensor(i, dtype=torch.float32)]
         for i, peer in enumerate(peers)
@@ -199,28 +186,22 @@ async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions,
 
     group_id = random.getrandbits(160).to_bytes(length=20, byteorder="big")
 
-    servers = []
     allreduce_protocols = []
-    peer_endpoints = {}
-
-    for peer in peers:
-        server = grpc.aio.server()
-        allreduce_protocol = AllreduceRunnerForTesting(
+    for p2p in p2ps:
+        allreduce_protocol = AllReduceRunner(
+            p2p=p2p,
+            servicer_type=AllReduceRunner,
+            prefix=None,
             group_id=group_id,
-            endpoint=peer,
-            tensors=[x.clone() for x in tensors_by_peer[peer]],
-            ordered_group_endpoints=peers,
+            tensors=[x.clone() for x in tensors_by_peer[p2p.peer_id]],
+            ordered_peer_ids=peers,
             peer_fractions=peer_fractions,
             modes=peer_modes,
             weights=averaging_weights,
-            peer_endpoints=peer_endpoints,
             part_size_bytes=part_size_bytes,
         )
-        averaging_pb2_grpc.add_DecentralizedAveragingServicer_to_server(allreduce_protocol, server)
-        peer_endpoints[peer] = f"127.0.0.1:{server.add_insecure_port('127.0.0.1:*')}"
+        await allreduce_protocol.add_p2p_handlers(p2p)
         allreduce_protocols.append(allreduce_protocol)
-        servers.append(server)
-        await server.start()
 
     async def _run_allreduce_inplace(allreduce: AllReduceRunner):
         async for tensor_index, tensor_delta in aenumerate(allreduce):
@@ -244,5 +225,5 @@ async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions,
         assert len(output_tensors) == len(targets_for_peer)
         assert all(torch.allclose(our, ref, atol=1e-6, rtol=0) for our, ref in zip(output_tensors, targets_for_peer))
 
-    for server in servers:
-        await server.stop(grace=1)
+    for instance in p2ps:
+        await instance.shutdown()

+ 1 - 2
tests/test_auth.py

@@ -5,11 +5,10 @@ import pytest
 
 from hivemind.proto import dht_pb2
 from hivemind.proto.auth_pb2 import AccessToken
-from hivemind.utils.auth import AuthRPCWrapper, AuthRole, TokenAuthorizerBase
+from hivemind.utils.auth import AuthRole, AuthRPCWrapper, TokenAuthorizerBase
 from hivemind.utils.crypto import RSAPrivateKey
 from hivemind.utils.logging import get_logger
 
-
 logger = get_logger(__name__)
 
 

+ 89 - 82
tests/test_averaging.py

@@ -1,4 +1,5 @@
 import random
+import time
 
 import numpy as np
 import pytest
@@ -9,46 +10,51 @@ import hivemind.averaging.averager
 from hivemind.averaging.allreduce import AveragingMode
 from hivemind.averaging.key_manager import GroupKeyManager
 from hivemind.averaging.load_balancing import load_balance_peers
+from hivemind.p2p import PeerID
 from hivemind.proto.runtime_pb2 import CompressionType
 
+from test_utils.dht_swarms import launch_dht_instances
+
 
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_key_manager():
+    dht = hivemind.DHT(start=True)
     key_manager = GroupKeyManager(
-        hivemind.DHT(start=True),
-        endpoint="localhvost",
+        dht,
         prefix="test_averaging",
         initial_group_bits="10110",
         target_group_size=2,
     )
+    alice = dht.peer_id
+    bob = PeerID(b"bob")
 
     t = hivemind.get_dht_time()
     key = key_manager.current_key
-    await key_manager.declare_averager(key, "localhvost", expiration_time=t + 60)
-    await key_manager.declare_averager(key, "localhvost2", expiration_time=t + 61)
+    await key_manager.declare_averager(key, alice, expiration_time=t + 60)
+    await key_manager.declare_averager(key, bob, expiration_time=t + 61)
 
     q1 = await key_manager.get_averagers(key, only_active=True)
 
-    await key_manager.declare_averager(key, "localhvost", expiration_time=t + 66)
+    await key_manager.declare_averager(key, alice, expiration_time=t + 66)
     q2 = await key_manager.get_averagers(key, only_active=True)
 
-    await key_manager.declare_averager(key, "localhvost2", expiration_time=t + 61, looking_for_group=False)
+    await key_manager.declare_averager(key, bob, expiration_time=t + 61, looking_for_group=False)
     q3 = await key_manager.get_averagers(key, only_active=True)
     q4 = await key_manager.get_averagers(key, only_active=False)
 
     q5 = await key_manager.get_averagers("nonexistent_key.0b0101", only_active=False)
 
-    assert len(q1) == 2 and ("localhvost", t + 60) in q1 and ("localhvost2", t + 61) in q1
-    assert len(q2) == 2 and ("localhvost", t + 66) in q2 and ("localhvost2", t + 61) in q2
-    assert len(q3) == 1 and ("localhvost", t + 66) in q3
-    assert len(q4) == 2 and ("localhvost", t + 66) in q4 and ("localhvost2", t + 61) in q2
+    assert len(q1) == 2 and (alice, t + 60) in q1 and (bob, t + 61) in q1
+    assert len(q2) == 2 and (alice, t + 66) in q2 and (bob, t + 61) in q2
+    assert len(q3) == 1 and (alice, t + 66) in q3
+    assert len(q4) == 2 and (alice, t + 66) in q4 and (bob, t + 61) in q2
     assert len(q5) == 0
 
+    dht.shutdown()
 
-def _test_allreduce_once(n_clients, n_aux):
-    dht = hivemind.DHT(start=True)
 
+def _test_allreduce_once(n_clients, n_aux):
     n_peers = 4
     modes = (
         [AveragingMode.CLIENT] * n_clients
@@ -69,6 +75,7 @@ def _test_allreduce_once(n_clients, n_aux):
         for i in range(len(tensors1))
     ]
 
+    dht_instances = launch_dht_instances(len(peer_tensors))
     averagers = [
         hivemind.averaging.DecentralizedAverager(
             tensors,
@@ -77,11 +84,10 @@ def _test_allreduce_once(n_clients, n_aux):
             averaging_expiration=15,
             prefix="mygroup",
             client_mode=mode == AveragingMode.CLIENT,
-            listen_on="127.0.0.1:*",
             auxiliary=mode == AveragingMode.AUX,
             start=True,
         )
-        for tensors, mode in zip(peer_tensors, modes)
+        for tensors, dht, mode in zip(peer_tensors, dht_instances, modes)
     ]
 
     futures = []
@@ -90,7 +96,7 @@ def _test_allreduce_once(n_clients, n_aux):
     for future in futures:
         result = future.result()
         for averager in averagers:
-            assert averager.endpoint in result
+            assert averager.peer_id in result
 
     for averager in averagers:
         if averager.mode != AveragingMode.AUX:
@@ -98,9 +104,8 @@ def _test_allreduce_once(n_clients, n_aux):
                 for ref, our in zip(reference, averaged_tensors):
                     assert torch.allclose(ref, our, atol=1e-6)
 
-    for averager in averagers:
-        averager.shutdown()
-    dht.shutdown()
+    for process in averagers + dht_instances:
+        process.shutdown()
 
 
 @pytest.mark.forked
@@ -118,8 +123,6 @@ def test_allreduce_once_edge_cases(n_clients, n_aux):
 
 @pytest.mark.forked
 def test_allreduce_weighted(n_client_mode_peers: int = 2):
-    dht = hivemind.DHT(start=True)
-
     n_peers = 4
     client_modes = [True] * n_client_mode_peers + [False] * (n_peers - n_client_mode_peers)
     random.shuffle(client_modes)
@@ -128,6 +131,8 @@ def test_allreduce_weighted(n_client_mode_peers: int = 2):
     tensors2 = [torch.rand(123), torch.ones(3)]
     tensors3 = [-torch.rand(123), torch.arange(3).to(torch.float32)]
     tensors4 = [torch.randn(123) ** 3, torch.arange(3).to(torch.float32) / 2]
+
+    dht_instances = launch_dht_instances(4)
     averagers = [
         hivemind.averaging.DecentralizedAverager(
             tensors,
@@ -136,11 +141,11 @@ def test_allreduce_weighted(n_client_mode_peers: int = 2):
             averaging_expiration=15,
             prefix="mygroup",
             client_mode=client_mode,
-            listen_on="127.0.0.1:*",
             start=True,
         )
-        for tensors, client_mode in zip([tensors1, tensors2, tensors3, tensors4], client_modes)
+        for tensors, dht, client_mode in zip([tensors1, tensors2, tensors3, tensors4], dht_instances, client_modes)
     ]
+
     weights = list(map(float, np.random.rand(len(averagers)) * 10 + 0.01))
     reference = [
         (tensors1[i] * weights[0] + tensors2[i] * weights[1] + tensors3[i] * weights[2] + tensors4[i] * weights[3])
@@ -159,15 +164,13 @@ def test_allreduce_weighted(n_client_mode_peers: int = 2):
             for ref, our in zip(reference, averaged_tensors):
                 assert torch.allclose(ref, our, atol=1e-6)
 
-    for averager in averagers:
-        averager.shutdown()
-    dht.shutdown()
+    for process in averagers + dht_instances:
+        process.shutdown()
 
 
 @pytest.mark.forked
 def test_allreduce_compression():
     """this test ensures that compression works correctly when multiple tensors have different compression types"""
-    dht = hivemind.DHT(start=True)
 
     tensors1 = [torch.linspace(0, 500, 1000) ** 0.5, torch.randn(1000)]
     tensors2 = [torch.linspace(300, 800, 1000) ** 0.5, torch.randn(1000)]
@@ -176,9 +179,10 @@ def test_allreduce_compression():
     FLOAT16, UINT8 = CompressionType.FLOAT16, CompressionType.UNIFORM_8BIT
 
     for compression_type_pair in [(FLOAT16, FLOAT16), (FLOAT16, UINT8), (UINT8, FLOAT16), (UINT8, UINT8)]:
+        dht_instances = launch_dht_instances(2)
         averager1 = hivemind.averaging.DecentralizedAverager(
             [x.clone() for x in tensors1],
-            dht=dht,
+            dht=dht_instances[0],
             compression_type=compression_type_pair,
             client_mode=True,
             target_group_size=2,
@@ -187,11 +191,10 @@ def test_allreduce_compression():
         )
         averager2 = hivemind.averaging.DecentralizedAverager(
             [x.clone() for x in tensors2],
-            dht=dht,
+            dht=dht_instances[1],
             compression_type=compression_type_pair,
             target_group_size=2,
             prefix="mygroup",
-            listen_on="127.0.0.1:*",
             start=True,
         )
 
@@ -201,6 +204,9 @@ def test_allreduce_compression():
         with averager1.get_tensors() as averaged_tensors:
             results[compression_type_pair] = averaged_tensors
 
+        for instance in [averager1, averager2] + dht_instances:
+            instance.shutdown()
+
     assert torch.allclose(results[UINT8, FLOAT16][0], results[UINT8, UINT8][0])
     assert torch.allclose(results[UINT8, FLOAT16][1], results[FLOAT16, FLOAT16][1])
     assert torch.allclose(results[UINT8, UINT8][1], results[FLOAT16, UINT8][1])
@@ -231,7 +237,7 @@ def compute_mean_std(averagers, unbiased=True):
 
 @pytest.mark.forked
 def test_allreduce_grid():
-    dht = hivemind.DHT(start=True)
+    dht_instances = launch_dht_instances(8)
     averagers = [
         hivemind.averaging.DecentralizedAverager(
             averaged_tensors=[torch.randn(3)],
@@ -239,10 +245,9 @@ def test_allreduce_grid():
             target_group_size=2,
             prefix="mygroup",
             initial_group_bits=bin(i // 2)[2:].rjust(2, "0"),
-            listen_on="127.0.0.1:*",
             start=True,
         )
-        for i in range(8)
+        for i, dht in enumerate(dht_instances)
     ]
 
     [means0], [stds0] = compute_mean_std(averagers)
@@ -262,48 +267,41 @@ def test_allreduce_grid():
         else:
             assert torch.allclose(stds, torch.zeros_like(stds), atol=1e-6, rtol=0)
 
-    for averager in averagers:
-        averager.shutdown()
-    dht.shutdown()
+    for process in averagers + dht_instances:
+        process.shutdown()
 
 
 @pytest.mark.forked
-def test_allgather():
-    dht = hivemind.DHT(start=True)
+def test_allgather(n_averagers=8, target_group_size=4):
+    dht_instances = launch_dht_instances(n_averagers)
     averagers = [
         hivemind.averaging.DecentralizedAverager(
             [torch.ones(1)],
             dht=dht,
-            target_group_size=4,
+            target_group_size=target_group_size,
             averaging_expiration=15,
             prefix="mygroup",
             initial_group_bits="000",
-            listen_on="127.0.0.1:*",
             start=True,
         )
-        for _ in range(8)
+        for dht in dht_instances
     ]
 
     futures = []
     for i, averager in enumerate(averagers):
         futures.append(averager.step(wait=False, gather=dict(batch_size=123 + i, foo="bar")))
 
-    assert len(set(repr(sorted(future.result())) for future in futures)) == 2
-
     reference_metadata = {
-        averager.endpoint: dict(batch_size=123 + i, foo="bar") for i, averager in enumerate(averagers)
+        averager.peer_id: dict(batch_size=123 + i, foo="bar") for i, averager in enumerate(averagers)
     }
     for future in futures:
         gathered = future.result()
+        assert len(gathered) == target_group_size
+        for peer_id in gathered:
+            assert gathered[peer_id] == reference_metadata[peer_id]
 
-        assert len(gathered) == 4
-
-        for endpoint in gathered:
-            assert gathered[endpoint] == reference_metadata[endpoint]
-
-    for averager in averagers:
-        averager.shutdown()
-    dht.shutdown()
+    for process in averagers + dht_instances:
+        process.shutdown()
 
 
 def get_cost(vector_size, partitions, bandwidths):
@@ -351,7 +349,7 @@ def test_load_balancing():
 
 @pytest.mark.forked
 def test_too_few_peers():
-    dht = hivemind.DHT(start=True)
+    dht_instances = launch_dht_instances(4)
     averagers = [
         hivemind.averaging.DecentralizedAverager(
             averaged_tensors=[torch.randn(3)],
@@ -361,23 +359,25 @@ def test_too_few_peers():
             request_timeout=0.5,
             prefix="mygroup",
             initial_group_bits=bin(i)[2:].rjust(3, "0"),
-            listen_on="127.0.0.1:*",
             start=True,
         )
-        for i in range(4)
+        for i, dht in enumerate(dht_instances)
     ]
     step_futures = [averager.step(wait=False) for averager in averagers]
     for future in step_futures:
         assert len(future.result()) == 2
 
-    for averager in averagers:
-        averager.shutdown()
-    dht.shutdown()
+    for process in averagers + dht_instances:
+        process.shutdown()
 
 
+@pytest.mark.skip(
+    reason="The current implementation of elasticity (multi-stage averaging when num_peers > ~3 * target_group_size) "
+    "is incorrect (TODO @justheuristic)"
+)
 @pytest.mark.forked
 def test_overcrowded(num_peers=16):
-    dht = hivemind.DHT(start=True)
+    dht_instances = launch_dht_instances(num_peers)
     averagers = [
         hivemind.averaging.DecentralizedAverager(
             averaged_tensors=[torch.randn(3)],
@@ -387,18 +387,16 @@ def test_overcrowded(num_peers=16):
             request_timeout=0.5,
             prefix="mygroup",
             initial_group_bits="",
-            listen_on="127.0.0.1:*",
             start=True,
         )
-        for _ in range(num_peers)
+        for dht in dht_instances
     ]
-    for t in range(5):
+    for _ in range(5):
         step_futures = [averager.step(wait=False, timeout=5) for averager in averagers]
         assert sum(len(future.result() or []) == 2 for future in step_futures) >= len(averagers) - 1
 
-    for averager in averagers:
-        averager.shutdown()
-    dht.shutdown()
+    for process in averagers + dht_instances:
+        process.shutdown()
 
 
 @pytest.mark.forked
@@ -417,27 +415,22 @@ def test_load_state_from_peers():
             num_calls += 1
             return super_metadata, super_tensors
 
-    dht_root = hivemind.DHT(start=True)
-    initial_peers = dht_root.get_visible_maddrs()
-    dht1 = hivemind.DHT(initial_peers=initial_peers, start=True)
+    dht_instances = launch_dht_instances(2)
     averager1 = TestAverager(
         [torch.randn(3), torch.rand(5)],
-        dht=dht1,
+        dht=dht_instances[0],
         start=True,
         prefix="demo-run",
         target_group_size=2,
-        listen_on="127.0.0.1:*",
     )
 
-    dht2 = hivemind.DHT(initial_peers=initial_peers, start=True)
-    dht2.get("demo-run.all_averagers")
+    dht_instances[1].get("demo-run.all_averagers")
     averager2 = TestAverager(
         [torch.randn(3), torch.rand(5)],
-        dht=dht2,
+        dht=dht_instances[1],
         start=True,
         prefix="demo-run",
         target_group_size=2,
-        listen_on="127.0.0.1:*",
     )
 
     assert num_calls == 0
@@ -463,12 +456,19 @@ def test_load_state_from_peers():
     assert num_calls == 3
     assert got_metadata == super_metadata
 
+    for instance in [averager1, averager2] + dht_instances:
+        instance.shutdown()
+
 
 @pytest.mark.forked
 def test_getset_bits():
     dht = hivemind.DHT(start=True)
     averager = hivemind.averaging.DecentralizedAverager(
-        [torch.randn(3)], dht=dht, start=True, prefix="test_prefix", target_group_size=2, listen_on="127.0.0.1:*"
+        [torch.randn(3)],
+        dht=dht,
+        start=True,
+        prefix="test_prefix",
+        target_group_size=2,
     )
     averager.set_group_bits("00101011101010")
     assert averager.get_group_bits() == "00101011101010"
@@ -478,11 +478,9 @@ def test_getset_bits():
 def test_training_averager(n_steps: int = 10, n_dims: int = 16):
     torch.manual_seed(42)
 
-    dht = hivemind.DHT(start=True)
+    dht_instances = launch_dht_instances(2)
     common_kwargs = {
-        "dht": dht,
         "start": True,
-        "listen_on": "127.0.0.1:*",
         "prefix": "demo-run",
         "target_group_size": 2,
     }
@@ -490,13 +488,23 @@ def test_training_averager(n_steps: int = 10, n_dims: int = 16):
     x1 = torch.randn(n_dims, requires_grad=True)
     opt1 = torch.optim.Adam([x1], lr=0.05)
     averager1 = hivemind.averaging.TrainingAverager(
-        opt1, average_gradients=True, average_parameters=True, average_opt_statistics=["exp_avg_sq"], **common_kwargs
+        opt1,
+        average_gradients=True,
+        average_parameters=True,
+        average_opt_statistics=["exp_avg_sq"],
+        dht=dht_instances[0],
+        **common_kwargs
     )
 
     x2 = torch.randn(n_dims, requires_grad=True)
     opt2 = torch.optim.Adam([x2], lr=0.05)
     averager2 = hivemind.averaging.TrainingAverager(
-        opt2, average_gradients=True, average_parameters=True, average_opt_statistics=["exp_avg_sq"], **common_kwargs
+        opt2,
+        average_gradients=True,
+        average_parameters=True,
+        average_opt_statistics=["exp_avg_sq"],
+        dht=dht_instances[1],
+        **common_kwargs
     )
     a = torch.ones(n_dims)
 
@@ -526,6 +534,5 @@ def test_training_averager(n_steps: int = 10, n_dims: int = 16):
         assert torch.allclose(opt1.state[x1]["exp_avg_sq"], stats_avg)
         assert torch.allclose(opt2.state[x2]["exp_avg_sq"], stats_avg)
 
-    averager1.shutdown()
-    averager2.shutdown()
-    dht.shutdown()
+    for instance in [averager1, averager2] + dht_instances:
+        instance.shutdown()

+ 4 - 4
tests/test_dht.py

@@ -7,12 +7,12 @@ from multiaddr import Multiaddr
 
 import hivemind
 
+from test_utils.dht_swarms import launch_dht_instances
+
 
 @pytest.mark.forked
 def test_get_store(n_peers=10):
-    peers = [hivemind.DHT(start=True)]
-    initial_peers = peers[0].get_visible_maddrs()
-    peers += [hivemind.DHT(initial_peers=initial_peers, start=True) for _ in range(n_peers - 1)]
+    peers = launch_dht_instances(n_peers)
 
     node1, node2 = random.sample(peers, 2)
     assert node1.store("key1", "value1", expiration_time=hivemind.get_dht_time() + 30)
@@ -104,5 +104,5 @@ async def test_dht_get_visible_maddrs():
     p2p = await hivemind.p2p.P2P.create(announce_maddrs=[dummy_endpoint])
     dht = hivemind.DHT(start=True, p2p=await p2p.replicate(p2p.daemon_listen_maddr))
 
-    assert dht.get_visible_maddrs() == [dummy_endpoint.encapsulate(f"/p2p/{p2p.id}")]
+    assert dht.get_visible_maddrs() == [dummy_endpoint.encapsulate(f"/p2p/{p2p.peer_id}")]
     dht.shutdown()

+ 2 - 2
tests/test_dht_crypto.py

@@ -1,15 +1,15 @@
 import dataclasses
-import pickle
 import multiprocessing as mp
+import pickle
 
 import pytest
 
 import hivemind
-from hivemind.utils.timed_storage import get_dht_time
 from hivemind.dht.crypto import RSASignatureValidator
 from hivemind.dht.node import DHTNode
 from hivemind.dht.validation import DHTRecord
 from hivemind.utils.crypto import RSAPrivateKey
+from hivemind.utils.timed_storage import get_dht_time
 
 
 def test_rsa_signature_validator():

+ 2 - 2
tests/test_dht_experts.py

@@ -6,11 +6,11 @@ import numpy as np
 import pytest
 
 import hivemind
-from hivemind.dht import DHTNode
 from hivemind import LOCALHOST
+from hivemind.dht import DHTNode
 from hivemind.moe.client.beam_search import MoEBeamSearcher
 from hivemind.moe.server import declare_experts, get_experts
-from hivemind.moe.server.expert_uid import UidEndpoint, is_valid_uid, is_valid_prefix, split_uid
+from hivemind.moe.server.expert_uid import UidEndpoint, is_valid_prefix, is_valid_uid, split_uid
 
 
 @pytest.mark.forked

+ 24 - 11
tests/test_dht_node.py

@@ -17,8 +17,8 @@ from hivemind.dht.protocol import DHTProtocol
 from hivemind.dht.storage import DictionaryDHTValue
 from hivemind.p2p import P2P, PeerID
 from hivemind.utils.logging import get_logger
-from test_utils.dht_swarms import launch_swarm_in_separate_processes, launch_star_shaped_swarm
 
+from test_utils.dht_swarms import launch_star_shaped_swarm, launch_swarm_in_separate_processes
 
 logger = get_logger(__name__)
 
@@ -44,7 +44,7 @@ def run_protocol_listener(
     for peer_id in maddrs_to_peer_ids(initial_peers):
         loop.run_until_complete(protocol.call_ping(peer_id))
 
-    maddr_conn.send((p2p.id, visible_maddrs))
+    maddr_conn.send((p2p.peer_id, visible_maddrs))
 
     async def shutdown():
         await p2p.shutdown()
@@ -194,16 +194,27 @@ def test_empty_table():
 
 
 @pytest.mark.forked
-def test_dht_node():
+def test_dht_node(
+    n_peers: int = 20, n_sequential_peers: int = 5, parallel_rpc: int = 10, bucket_size: int = 5, num_replicas: int = 3
+):
     # step A: create a swarm of 50 dht nodes in separate processes
     #         (first 5 created sequentially, others created in parallel)
-    processes, dht, swarm_maddrs = launch_swarm_in_separate_processes(n_peers=50, n_sequential_peers=5)
+
+    processes, dht, swarm_maddrs = launch_swarm_in_separate_processes(
+        n_peers=n_peers, n_sequential_peers=n_sequential_peers, bucket_size=bucket_size, num_replicas=num_replicas
+    )
 
     # step B: run 51-st node in this process
     loop = asyncio.get_event_loop()
     initial_peers = random.choice(swarm_maddrs)
     me = loop.run_until_complete(
-        DHTNode.create(initial_peers=initial_peers, parallel_rpc=10, cache_refresh_before_expiry=False)
+        DHTNode.create(
+            initial_peers=initial_peers,
+            parallel_rpc=parallel_rpc,
+            bucket_size=bucket_size,
+            num_replicas=num_replicas,
+            cache_refresh_before_expiry=False,
+        )
     )
 
     # test 1: find self
@@ -223,7 +234,7 @@ def test_dht_node():
     jaccard_numerator = jaccard_denominator = 0  # jaccard similarity aka intersection over union
     all_node_ids = list(dht.values())
 
-    for _ in range(10):
+    for _ in range(20):
         query_id = DHTID.generate()
         k_nearest = random.randint(1, 10)
         exclude_self = random.random() > 0.5
@@ -249,10 +260,10 @@ def test_dht_node():
         jaccard_denominator += k_nearest
 
     accuracy = accuracy_numerator / accuracy_denominator
-    logger.debug(f"Top-1 accuracy: {accuracy}")  # should be 98-100%
+    logger.debug(f"Top-1 accuracy: {accuracy}")  # should be 90-100%
     jaccard_index = jaccard_numerator / jaccard_denominator
     logger.debug(f"Jaccard index (intersection over union): {jaccard_index}")  # should be 95-100%
-    assert accuracy >= 0.9, f"Top-1 accuracy only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
+    assert accuracy >= 0.8, f"Top-1 accuracy only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
     assert jaccard_index >= 0.9, f"Jaccard index only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
 
     # test 4: find all nodes
@@ -275,7 +286,10 @@ def test_dht_node():
     initial_peers = random.choice(swarm_maddrs)
     that_guy = loop.run_until_complete(
         DHTNode.create(
-            initial_peers=initial_peers, parallel_rpc=10, cache_refresh_before_expiry=False, cache_locally=False
+            initial_peers=initial_peers,
+            parallel_rpc=parallel_rpc,
+            cache_refresh_before_expiry=False,
+            cache_locally=False,
         )
     )
 
@@ -310,10 +324,9 @@ def test_dht_node():
     assert not loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=345, expiration_time=now + 10))
     assert loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=567, expiration_time=now + 30))
     assert loop.run_until_complete(me.store(upper_key, subkey=subkey3, value=890, expiration_time=now + 50))
-    loop.run_until_complete(asyncio.sleep(0.1))  # wait for cache to refresh
 
     for node in [that_guy, me]:
-        value, time = loop.run_until_complete(node.get(upper_key))
+        value, time = loop.run_until_complete(node.get(upper_key, latest=True))
         assert isinstance(value, dict) and time == now + 50, (value, time)
         assert value[subkey1] == (123, now + 10)
         assert value[subkey2] == (567, now + 30)

+ 2 - 2
tests/test_dht_storage.py

@@ -1,8 +1,8 @@
 import time
 
-from hivemind.utils.timed_storage import get_dht_time
-from hivemind.dht.storage import DHTLocalStorage, DHTID, DictionaryDHTValue
+from hivemind.dht.storage import DHTID, DHTLocalStorage, DictionaryDHTValue
 from hivemind.utils.serializer import MSGPackSerializer
+from hivemind.utils.timed_storage import get_dht_time
 
 
 def test_store():

+ 1 - 1
tests/test_dht_validation.py

@@ -9,7 +9,7 @@ from hivemind.dht.crypto import RSASignatureValidator
 from hivemind.dht.protocol import DHTProtocol
 from hivemind.dht.routing import DHTID
 from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
-from hivemind.dht.validation import DHTRecord, CompositeValidator
+from hivemind.dht.validation import CompositeValidator, DHTRecord
 
 
 class SchemaA(BaseModel):

+ 1 - 1
tests/test_expert_backend.py

@@ -6,7 +6,7 @@ import torch
 from torch.nn import Linear
 
 from hivemind import BatchTensorDescriptor, ExpertBackend
-from hivemind.moe.server.checkpoints import store_experts, load_experts
+from hivemind.moe.server.checkpoints import load_experts, store_experts
 from hivemind.moe.server.layers.lr_schedule import get_linear_schedule_with_warmup
 
 EXPERT_WEIGHT_UPDATES = 3

+ 1 - 2
tests/test_moe.py

@@ -4,9 +4,8 @@ import pytest
 import torch
 
 import hivemind
-from hivemind.moe.server import background_server, declare_experts
 from hivemind.moe.client.expert import DUMMY
-from hivemind.moe.server import layers
+from hivemind.moe.server import background_server, declare_experts, layers
 
 
 @pytest.mark.forked

+ 25 - 13
tests/test_p2p_daemon.py

@@ -9,8 +9,9 @@ import numpy as np
 import pytest
 from multiaddr import Multiaddr
 
-from hivemind.p2p import P2P, P2PHandlerError
+from hivemind.p2p import P2P, P2PDaemonError, P2PHandlerError
 from hivemind.proto import dht_pb2
+from hivemind.utils.networking import get_free_port
 from hivemind.utils.serializer import MSGPackSerializer
 
 
@@ -33,6 +34,17 @@ async def test_daemon_killed_on_del():
     assert not is_process_running(child_pid)
 
 
+@pytest.mark.asyncio
+async def test_startup_error_message():
+    with pytest.raises(P2PDaemonError, match=r"failed to connect to bootstrap peers"):
+        await P2P.create(
+            initial_peers=[f"/ip4/127.0.0.1/tcp/{get_free_port()}/p2p/QmdaK4LUeQaKhqSFPRu9N7MvXUEWDxWwtCvPrS444tCgd1"]
+        )
+
+    with pytest.raises(P2PDaemonError, match=r"Daemon failed to start in .+ seconds"):
+        await P2P.create(startup_timeout=0.1)  # Test that startup_timeout works
+
+
 @pytest.mark.parametrize(
     "host_maddrs",
     [
@@ -92,7 +104,7 @@ async def test_call_protobuf_handler(should_cancel, replicate, handle_name="hand
         except asyncio.CancelledError:
             nonlocal handler_cancelled
             handler_cancelled = True
-        return dht_pb2.PingResponse(peer=dht_pb2.NodeInfo(node_id=server.id.to_bytes()), available=True)
+        return dht_pb2.PingResponse(peer=dht_pb2.NodeInfo(node_id=server.peer_id.to_bytes()), available=True)
 
     server_pid = server_primary._child.pid
     await server.add_protobuf_handler(handle_name, ping_handler, dht_pb2.PingRequest)
@@ -104,12 +116,12 @@ async def test_call_protobuf_handler(should_cancel, replicate, handle_name="hand
     assert is_process_running(client_pid)
     await client.wait_for_at_least_n_peers(1)
 
-    ping_request = dht_pb2.PingRequest(peer=dht_pb2.NodeInfo(node_id=client.id.to_bytes()), validate=True)
-    expected_response = dht_pb2.PingResponse(peer=dht_pb2.NodeInfo(node_id=server.id.to_bytes()), available=True)
+    ping_request = dht_pb2.PingRequest(peer=dht_pb2.NodeInfo(node_id=client.peer_id.to_bytes()), validate=True)
+    expected_response = dht_pb2.PingResponse(peer=dht_pb2.NodeInfo(node_id=server.peer_id.to_bytes()), available=True)
 
     if should_cancel:
         call_task = asyncio.create_task(
-            client.call_protobuf_handler(server.id, handle_name, ping_request, dht_pb2.PingResponse)
+            client.call_protobuf_handler(server.peer_id, handle_name, ping_request, dht_pb2.PingResponse)
         )
         await asyncio.sleep(0.25)
 
@@ -119,7 +131,7 @@ async def test_call_protobuf_handler(should_cancel, replicate, handle_name="hand
         assert handler_cancelled
     else:
         actual_response = await client.call_protobuf_handler(
-            server.id, handle_name, ping_request, dht_pb2.PingResponse
+            server.peer_id, handle_name, ping_request, dht_pb2.PingResponse
         )
         assert actual_response == expected_response
         assert not handler_cancelled
@@ -147,10 +159,10 @@ async def test_call_protobuf_handler_error(handle_name="handle"):
     assert is_process_running(client_pid)
     await client.wait_for_at_least_n_peers(1)
 
-    ping_request = dht_pb2.PingRequest(peer=dht_pb2.NodeInfo(node_id=client.id.to_bytes()), validate=True)
+    ping_request = dht_pb2.PingRequest(peer=dht_pb2.NodeInfo(node_id=client.peer_id.to_bytes()), validate=True)
 
     with pytest.raises(P2PHandlerError) as excinfo:
-        await client.call_protobuf_handler(server.id, handle_name, ping_request, dht_pb2.PingResponse)
+        await client.call_protobuf_handler(server.peer_id, handle_name, ping_request, dht_pb2.PingResponse)
     assert "boom" in str(excinfo.value)
 
     await server.shutdown()
@@ -196,7 +208,7 @@ async def test_call_peer_single_process():
 
     await client.wait_for_at_least_n_peers(1)
 
-    _, reader, writer = await client.call_binary_stream_handler(server.id, handler_name)
+    _, reader, writer = await client.call_binary_stream_handler(server.peer_id, handler_name)
     await validate_square_stream(reader, writer)
 
     await server.shutdown()
@@ -213,7 +225,7 @@ async def run_server(handler_name, server_side, response_received):
 
     await server.add_binary_stream_handler(handler_name, handle_square_stream)
 
-    server_side.send(server.id)
+    server_side.send(server.peer_id)
     server_side.send(await server.get_visible_maddrs())
     while response_received.value == 0:
         await asyncio.sleep(0.5)
@@ -281,7 +293,7 @@ async def test_error_closes_connection():
 
     await client.wait_for_at_least_n_peers(1)
 
-    _, reader, writer = await client.call_binary_stream_handler(server.id, handler_name)
+    _, reader, writer = await client.call_binary_stream_handler(server.peer_id, handler_name)
     with closing(writer):
         await P2P.send_raw_data(b"raise_error", writer)
         with pytest.raises(asyncio.IncompleteReadError):  # Means that the connection is closed
@@ -290,7 +302,7 @@ async def test_error_closes_connection():
     # Despite the handler raised an exception, the server did not crash and ready for next requests
     assert is_process_running(server_pid)
 
-    _, reader, writer = await client.call_binary_stream_handler(server.id, handler_name)
+    _, reader, writer = await client.call_binary_stream_handler(server.peer_id, handler_name)
     with closing(writer):
         await P2P.send_raw_data(b"behave_normally", writer)
         assert await P2P.receive_raw_data(reader) == b"okay"
@@ -309,7 +321,7 @@ async def test_handlers_on_different_replicas():
             await P2P.send_raw_data(key, writer)
 
     server_primary = await P2P.create()
-    server_id = server_primary.id
+    server_id = server_primary.peer_id
     await server_primary.add_binary_stream_handler("handle_primary", partial(handler, key=b"primary"))
 
     server_replica1 = await replicate_if_needed(server_primary, True)

+ 2 - 1
tests/test_p2p_daemon_bindings.py

@@ -17,7 +17,8 @@ from hivemind.p2p.p2p_daemon_bindings.utils import (
     write_unsigned_varint,
 )
 from hivemind.proto import p2pd_pb2 as p2pd_pb
-from test_utils.p2p_daemon import make_p2pd_pair_ip4, connect_safe
+
+from test_utils.p2p_daemon import connect_safe, make_p2pd_pair_ip4
 
 
 def test_raise_if_failed_raises():

+ 17 - 13
tests/test_p2p_servicer.py

@@ -19,13 +19,13 @@ async def server_client():
 @pytest.mark.asyncio
 async def test_unary_unary(server_client):
     class ExampleServicer(ServicerBase):
-        async def rpc_square(self, request: test_pb2.TestRequest, _: P2PContext) -> test_pb2.TestResponse:
+        async def rpc_square(self, request: test_pb2.TestRequest, _context: P2PContext) -> test_pb2.TestResponse:
             return test_pb2.TestResponse(number=request.number ** 2)
 
     server, client = server_client
     servicer = ExampleServicer()
     await servicer.add_p2p_handlers(server)
-    stub = servicer.get_stub(client, server.id)
+    stub = ExampleServicer.get_stub(client, server.peer_id)
 
     assert await stub.rpc_square(test_pb2.TestRequest(number=10)) == test_pb2.TestResponse(number=100)
 
@@ -33,16 +33,18 @@ async def test_unary_unary(server_client):
 @pytest.mark.asyncio
 async def test_stream_unary(server_client):
     class ExampleServicer(ServicerBase):
-        async def rpc_sum(self, request: AsyncIterator[test_pb2.TestRequest], _: P2PContext) -> test_pb2.TestResponse:
+        async def rpc_sum(
+            self, stream: AsyncIterator[test_pb2.TestRequest], _context: P2PContext
+        ) -> test_pb2.TestResponse:
             result = 0
-            async for item in request:
+            async for item in stream:
                 result += item.number
             return test_pb2.TestResponse(number=result)
 
     server, client = server_client
     servicer = ExampleServicer()
     await servicer.add_p2p_handlers(server)
-    stub = servicer.get_stub(client, server.id)
+    stub = ExampleServicer.get_stub(client, server.peer_id)
 
     async def generate_requests() -> AsyncIterator[test_pb2.TestRequest]:
         for i in range(10):
@@ -55,7 +57,7 @@ async def test_stream_unary(server_client):
 async def test_unary_stream(server_client):
     class ExampleServicer(ServicerBase):
         async def rpc_count(
-            self, request: test_pb2.TestRequest, _: P2PContext
+            self, request: test_pb2.TestRequest, _context: P2PContext
         ) -> AsyncIterator[test_pb2.TestResponse]:
             for i in range(request.number):
                 yield test_pb2.TestResponse(number=i)
@@ -63,7 +65,7 @@ async def test_unary_stream(server_client):
     server, client = server_client
     servicer = ExampleServicer()
     await servicer.add_p2p_handlers(server)
-    stub = servicer.get_stub(client, server.id)
+    stub = ExampleServicer.get_stub(client, server.peer_id)
 
     i = 0
     async for item in stub.rpc_count(test_pb2.TestRequest(number=10)):
@@ -76,16 +78,16 @@ async def test_unary_stream(server_client):
 async def test_stream_stream(server_client):
     class ExampleServicer(ServicerBase):
         async def rpc_powers(
-            self, request: AsyncIterator[test_pb2.TestRequest], _: P2PContext
+            self, stream: AsyncIterator[test_pb2.TestRequest], _context: P2PContext
         ) -> AsyncIterator[test_pb2.TestResponse]:
-            async for item in request:
+            async for item in stream:
                 yield test_pb2.TestResponse(number=item.number ** 2)
                 yield test_pb2.TestResponse(number=item.number ** 3)
 
     server, client = server_client
     servicer = ExampleServicer()
     await servicer.add_p2p_handlers(server)
-    stub = servicer.get_stub(client, server.id)
+    stub = ExampleServicer.get_stub(client, server.peer_id)
 
     async def generate_requests() -> AsyncIterator[test_pb2.TestRequest]:
         for i in range(10):
@@ -109,7 +111,9 @@ async def test_unary_stream_cancel(server_client, cancel_reason):
     handler_cancelled = False
 
     class ExampleServicer(ServicerBase):
-        async def rpc_wait(self, request: test_pb2.TestRequest, _: P2PContext) -> AsyncIterator[test_pb2.TestResponse]:
+        async def rpc_wait(
+            self, request: test_pb2.TestRequest, _context: P2PContext
+        ) -> AsyncIterator[test_pb2.TestResponse]:
             try:
                 yield test_pb2.TestResponse(number=request.number + 1)
                 await asyncio.sleep(2)
@@ -124,7 +128,7 @@ async def test_unary_stream_cancel(server_client, cancel_reason):
     await servicer.add_p2p_handlers(server)
 
     if cancel_reason == "close_connection":
-        _, reader, writer = await client.call_binary_stream_handler(server.id, "ExampleServicer.rpc_wait")
+        _, reader, writer = await client.call_binary_stream_handler(server.peer_id, "ExampleServicer.rpc_wait")
         await P2P.send_protobuf(test_pb2.TestRequest(number=10), writer)
         await P2P.send_protobuf(P2P.END_OF_STREAM, writer)
 
@@ -134,7 +138,7 @@ async def test_unary_stream_cancel(server_client, cancel_reason):
 
         writer.close()
     elif cancel_reason == "close_generator":
-        stub = servicer.get_stub(client, server.id)
+        stub = ExampleServicer.get_stub(client, server.peer_id)
         iter = stub.rpc_wait(test_pb2.TestRequest(number=10)).__aiter__()
 
         assert await iter.__anext__() == test_pb2.TestResponse(number=11)

+ 2 - 2
tests/test_routing.py

@@ -1,10 +1,10 @@
-import random
 import heapq
 import operator
+import random
 from itertools import chain, zip_longest
 
 from hivemind import LOCALHOST
-from hivemind.dht.routing import RoutingTable, DHTID
+from hivemind.dht.routing import DHTID, RoutingTable
 
 
 def test_ids_basic():

+ 3 - 2
tests/test_training.py

@@ -10,7 +10,7 @@ from sklearn.datasets import load_digits
 from hivemind import DHT
 from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
 from hivemind.moe.server import background_server
-from hivemind.optim import DecentralizedSGD, DecentralizedAdam
+from hivemind.optim import DecentralizedAdam, DecentralizedSGD
 
 
 @pytest.mark.forked
@@ -169,6 +169,7 @@ def test_decentralized_optimizer_step():
     assert torch.allclose(param1, torch.full_like(param1, reference))
 
 
+@pytest.mark.skip(reason="Skipped until a more stable averager implementation is ready (TODO @justheuristic)")
 @pytest.mark.forked
 def test_decentralized_optimizer_averaging():
     dht_root = DHT(start=True)
@@ -200,7 +201,7 @@ def test_decentralized_optimizer_averaging():
     (param1.sum() + param2.sum()).backward()
 
     for _ in range(100):
-        time.sleep(0.01)
+        time.sleep(0.1)
         opt1.step()
         opt2.step()
         opt1.zero_grad()

+ 8 - 11
tests/test_util_modules.py

@@ -12,9 +12,9 @@ import hivemind
 from hivemind.proto.dht_pb2_grpc import DHTStub
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
-from hivemind.utils import MSGPackSerializer, ValueWithExpiration, HeapEntry, DHTExpiration
-from hivemind.utils.asyncio import amap_in_executor, aiter, aenumerate, achain, anext, azip
-from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.utils import DHTExpiration, HeapEntry, MSGPackSerializer, ValueWithExpiration
+from hivemind.utils.asyncio import achain, aenumerate, aiter, amap_in_executor, anext, azip
+from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.mpfuture import InvalidStateError
 
 
@@ -256,8 +256,8 @@ def test_mpfuture_done_callback():
 
     assert future1.done() and not future1.cancelled()
     assert future2.done() and future2.cancelled()
-    events[0].wait(1)
-    events[1].wait(1)
+    for i in 0, 1, 4:
+        events[i].wait(1)
     assert events[0].is_set() and events[1].is_set() and events[2].is_set() and events[4].is_set()
     assert not events[3].is_set()
 
@@ -266,15 +266,14 @@ def test_mpfuture_done_callback():
 
 
 @pytest.mark.forked
-@pytest.mark.parametrize("synchronize", [True, False])
-def test_many_futures(synchronize: bool):
+def test_many_futures():
     evt = mp.Event()
     receiver, sender = mp.Pipe()
-    main_futures = [hivemind.MPFuture(synchronize=synchronize) for _ in range(1000)]
+    main_futures = [hivemind.MPFuture() for _ in range(1000)]
     assert len(hivemind.MPFuture._active_futures) == 1000
 
     def _run_peer():
-        fork_futures = [hivemind.MPFuture(synchronize=synchronize) for _ in range(500)]
+        fork_futures = [hivemind.MPFuture() for _ in range(500)]
         assert len(hivemind.MPFuture._active_futures) == 500
 
         for i, future in enumerate(random.sample(main_futures, 300)):
@@ -299,8 +298,6 @@ def test_many_futures(synchronize: bool):
     p.start()
 
     some_fork_futures = receiver.recv()
-
-    time.sleep(0.5)  # wait for active futures to synchronize
     assert len(hivemind.MPFuture._active_futures) == 700
 
     for future in some_fork_futures:

+ 21 - 9
tests/test_utils/dht_swarms.py

@@ -7,17 +7,18 @@ from typing import Dict, List, Tuple
 
 from multiaddr import Multiaddr
 
+from hivemind.dht import DHT
 from hivemind.dht.node import DHTID, DHTNode
 from hivemind.p2p import PeerID
 
 
-def run_node(initial_peers: List[Multiaddr], info_queue: mp.Queue):
+def run_node(initial_peers: List[Multiaddr], info_queue: mp.Queue, **kwargs):
     if asyncio.get_event_loop().is_running():
         asyncio.get_event_loop().stop()  # if we're in jupyter, get rid of its built-in event loop
         asyncio.set_event_loop(asyncio.new_event_loop())
     loop = asyncio.get_event_loop()
 
-    node = loop.run_until_complete(DHTNode.create(initial_peers=initial_peers, ping_n_attempts=10))
+    node = loop.run_until_complete(DHTNode.create(initial_peers=initial_peers, **kwargs))
     maddrs = loop.run_until_complete(node.get_visible_maddrs())
 
     info_queue.put((node.node_id, node.peer_id, maddrs))
@@ -31,7 +32,7 @@ def run_node(initial_peers: List[Multiaddr], info_queue: mp.Queue):
 
 
 def launch_swarm_in_separate_processes(
-    n_peers: int, n_sequential_peers: int
+    n_peers: int, n_sequential_peers: int, **kwargs
 ) -> Tuple[List[mp.Process], Dict[PeerID, DHTID], List[List[Multiaddr]]]:
     assert (
         n_sequential_peers < n_peers
@@ -47,19 +48,19 @@ def launch_swarm_in_separate_processes(
     for _ in range(n_sequential_peers):
         initial_peers = random.choice(swarm_maddrs) if swarm_maddrs else []
 
-        proc = mp.Process(target=run_node, args=(initial_peers, info_queue), daemon=True)
+        proc = mp.Process(target=run_node, args=(initial_peers, info_queue), kwargs=kwargs, daemon=True)
         proc.start()
         processes.append(proc)
 
-        node_id, peer_endpoint, peer_maddrs = info_queue.get()
-        dht[peer_endpoint] = node_id
+        node_id, peer_id, peer_maddrs = info_queue.get()
+        dht[peer_id] = node_id
         swarm_maddrs.append(peer_maddrs)
 
     def collect_info():
         while True:
-            node_id, peer_endpoint, peer_maddrs = info_queue.get()
+            node_id, peer_id, peer_maddrs = info_queue.get()
             with info_lock:
-                dht[peer_endpoint] = node_id
+                dht[peer_id] = node_id
                 swarm_maddrs.append(peer_maddrs)
 
                 if len(dht) == n_peers:
@@ -72,7 +73,7 @@ def launch_swarm_in_separate_processes(
         with info_lock:
             initial_peers = random.choice(swarm_maddrs)
 
-        proc = mp.Process(target=run_node, args=(initial_peers, info_queue), daemon=True)
+        proc = mp.Process(target=run_node, args=(initial_peers, info_queue), kwargs=kwargs, daemon=True)
         proc.start()
         processes.append(proc)
 
@@ -86,3 +87,14 @@ async def launch_star_shaped_swarm(n_peers: int, **kwargs) -> List[DHTNode]:
     initial_peers = await nodes[0].get_visible_maddrs()
     nodes += await asyncio.gather(*[DHTNode.create(initial_peers=initial_peers, **kwargs) for _ in range(n_peers - 1)])
     return nodes
+
+
+def launch_dht_instances(n_peers: int, **kwargs) -> List[DHT]:
+    dhts = [DHT(start=True, **kwargs)]
+    initial_peers = dhts[0].get_visible_maddrs()
+
+    dhts.extend(DHT(initial_peers=initial_peers, start=True, await_ready=False, **kwargs) for _ in range(n_peers - 1))
+    for instance in dhts[1:]:
+        instance.ready.wait()
+
+    return dhts

+ 5 - 6
tests/test_utils/p2p_daemon.py

@@ -6,14 +6,13 @@ import time
 import uuid
 from contextlib import asynccontextmanager
 from typing import NamedTuple
-from pkg_resources import resource_filename
 
 from multiaddr import Multiaddr, protocols
+from pkg_resources import resource_filename
 
-from hivemind import find_open_port
+from hivemind import get_free_port
 from hivemind.p2p.p2p_daemon_bindings.p2pclient import Client
 
-
 TIMEOUT_DURATION = 30  # seconds
 P2PD_PATH = resource_filename("hivemind", "hivemind_cli/p2pd")
 
@@ -58,7 +57,7 @@ class Daemon:
 
     def _run(self):
         cmd_list = [P2PD_PATH, f"-listen={str(self.control_maddr)}"]
-        cmd_list += [f"-hostAddrs=/ip4/127.0.0.1/tcp/{find_open_port()}"]
+        cmd_list += [f"-hostAddrs=/ip4/127.0.0.1/tcp/{get_free_port()}"]
         if self.enable_connmgr:
             cmd_list += ["-connManager=true", "-connLo=1", "-connHi=2", "-connGrace=0"]
         if self.enable_dht:
@@ -130,8 +129,8 @@ async def make_p2pd_pair_unix(enable_control, enable_connmgr, enable_dht, enable
 
 @asynccontextmanager
 async def make_p2pd_pair_ip4(enable_control, enable_connmgr, enable_dht, enable_pubsub):
-    control_maddr = Multiaddr(f"/ip4/127.0.0.1/tcp/{find_open_port()}")
-    listen_maddr = Multiaddr(f"/ip4/127.0.0.1/tcp/{find_open_port()}")
+    control_maddr = Multiaddr(f"/ip4/127.0.0.1/tcp/{get_free_port()}")
+    listen_maddr = Multiaddr(f"/ip4/127.0.0.1/tcp/{get_free_port()}")
     async with _make_p2pd_pair(
         control_maddr=control_maddr,
         listen_maddr=listen_maddr,