
Convert DHT to libp2p backend (#296)

This PR changes the DHT to operate over the p2p daemon (instead of gRPC), using libp2p PeerIDs and Multiaddrs (instead of raw IP:port endpoints).
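In practice, bootstrapping now looks roughly like the following (a minimal sketch mirroring the updated `benchmarks/benchmark_averaging.py`; the multiaddr value in the comment is illustrative only):

```python
import hivemind

# Start the first DHT peer: it listens on libp2p host multiaddrs instead of an IP:port pair.
dht_root = hivemind.DHT(start=True)

# Multiaddrs that other peers can dial, e.g. /ip4/127.0.0.1/tcp/31337/p2p/QmXXXX (illustrative).
initial_peers = dht_root.get_visible_maddrs()

# Any further peer joins the existing DHT by passing those multiaddrs as initial_peers.
dht = hivemind.DHT(initial_peers=initial_peers, start=True)
```

The same pattern replaces the old `listen_on=...` / `dht.port` endpoints throughout the examples and benchmarks below.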

Co-authored-by: Ilya Kobelev <ilya.kobellev@gmail.com>
Alexander Borzunov 4 years ago
parent
commit
0be1512c74

+ 4 - 4
benchmarks/benchmark_averaging.py

@@ -34,7 +34,9 @@ def sample_tensors(hid_size, num_layers):
 def benchmark_averaging(num_peers: int, target_group_size: int, num_rounds: int,
                         averaging_expiration: float, request_timeout: float, round_timeout: float,
                         hid_size: int, num_layers: int, spawn_dtime: float):
-    dht_root = hivemind.DHT(listen_on=f'{LOCALHOST}:*', start=True)
+    dht_root = hivemind.DHT(start=True)
+    initial_peers = dht_root.get_visible_maddrs()
+
     num_groups = 2 ** int(round(math.log2(num_peers / target_group_size)))
     nbits = int(round(math.log2(num_groups)))
     peer_tensors = [sample_tensors(hid_size, num_layers)
@@ -45,9 +47,7 @@ def benchmark_averaging(num_peers: int, target_group_size: int, num_rounds: int,
 
     def run_averager(index):
         nonlocal successful_steps, total_steps, lock_stats
-        dht = hivemind.DHT(listen_on=f'{LOCALHOST}:*',
-                           initial_peers=[f"{LOCALHOST}:{dht_root.port}"],
-                           start=True)
+        dht = hivemind.DHT(initial_peers=initial_peers, start=True)
         initial_bits = bin(index % num_groups)[2:].rjust(nbits, '0')
         averager = hivemind.averaging.DecentralizedAverager(
             peer_tensors[i], dht, prefix='my_tensor', initial_group_bits=initial_bits, listen_on=f"{LOCALHOST}:*",

+ 3 - 3
benchmarks/benchmark_dht.py

@@ -23,9 +23,9 @@ def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_b
     logger.info("Creating peers...")
     peers = []
     for _ in trange(num_peers):
-        neighbors = [f'0.0.0.0:{node.port}' for node in random.sample(peers, min(initial_peers, len(peers)))]
-        peer = hivemind.DHT(initial_peers=neighbors, start=True, wait_timeout=wait_timeout,
-                            listen_on=f'0.0.0.0:*')
+        neighbors = sum([peer.get_visible_maddrs()
+                         for peer in random.sample(peers, min(initial_peers, len(peers)))], [])
+        peer = hivemind.DHT(initial_peers=neighbors, start=True, wait_timeout=wait_timeout)
         peers.append(peer)
 
     store_peer, get_peer = peers[-2:]

+ 40 - 18
examples/albert/README.md

@@ -12,14 +12,12 @@ This tutorial will walk you through the steps to set up collaborative training w
 ## Running an experiment
 - Run the first DHT peer to welcome trainers and record training statistics (e.g. loss, performance):
    - In this example, we use [wandb.ai](https://wandb.ai/site) to plot training metrics; If you're unfamiliar with Weights & Biases, here's a [quickstart tutorial](https://docs.wandb.ai/quickstart).
-   - Run `python run_training_monitor.py --dht_listen_on '[::]:*' --experiment_prefix NAME_YOUR_EXPERIMENT --wandb_project WANDB_PROJECT_HERE`
+   - Run `python run_training_monitor.py --experiment_prefix NAME_YOUR_EXPERIMENT --wandb_project WANDB_PROJECT_HERE`
    - `NAME_YOUR_EXPERIMENT` must be a unique name of this training run, e.g. `my-first-albert`. It cannot contain `.` due to naming conventions.
    - `WANDB_PROJECT_HERE` is a name of wandb project used to track training metrics. Multiple experiments can have the same project name.
-   - This peer will run a DHT node on a certain IP/port (`Running DHT root at ...`). You will need this address for next steps
 ```
-+ python run_training_monitor.py --dht_listen_on '[::]:*' --experiment_prefix my-albert-v1 --wandb_project Demo-run
-[2021/06/17 16:26:35.931][WARN][root.<module>:140] No address specified. Attempting to infer address from DNS.
-[2021/06/17 16:26:36.083][INFO][root.<module>:149] Running DHT root at 193.106.95.184:38319
+$ python run_training_monitor.py --experiment_prefix my-albert-v1 --wandb_project Demo-run
+[2021/06/17 16:26:36.083][INFO][root.log_visible_maddrs:42] Running a DHT peer. To connect other peers to this one, use --initial_peers /ip4/8.8.8.8/tcp/1337/p2p/XXXX /ip4/8.8.8.8/udp/31337/quic/p2p/XXXX
 wandb: Currently logged in as: XXX (use `wandb login --relogin` to force relogin)
 wandb: Tracking run with wandb version 0.10.32
 wandb: Syncing run dry-mountain-2
@@ -30,21 +28,37 @@ wandb: Run `wandb offline` to turn off syncing.
 [2021/04/19 02:26:41.064][INFO][optim.collaborative.fetch_collaboration_state:323] Found no active peers: None
 [2021/04/19 02:26:44.068][INFO][optim.collaborative.fetch_collaboration_state:323] Found no active peers: None
 ...
-[2021/04/19 02:37:37.246][INFO][root.<module>:74] 11.05164
-[2021/04/19 02:39:37.441][INFO][root.<module>:74] 11.03771
-[2021/04/19 02:40:37.541][INFO][root.<module>:74] 11.02886
+[2021/04/19 02:37:37.246][INFO][__main__.<module>:194] Step #1  loss = 11.05164
+[2021/04/19 02:39:37.441][INFO][__main__.<module>:194] Step #2  loss = 11.03771
+[2021/04/19 02:40:37.541][INFO][__main__.<module>:194] Step #3  loss = 11.02886
 ```
 
-- To join a collaboration with a GPU trainer, 
+- To join a collaboration with a GPU trainer,
   - install the same dependencies (minus the `wandb` and `whatsmyip`), download the data and unpack it to the experiment folder,
   - if necessary, specify paths: `--dataset_path ./path/to/unpacked/data --tokenizer ./path/to/tokenizer/config` (see [default paths](https://github.com/learning-at-home/hivemind/blob/collaborative_albert_example/examples/albert/run_trainer.py#L63-L69) for reference)
   - run:
-```shell
-python run_trainer.py \
- --experiment_prefix SAME_AS_IN_RUN_TRAINING_MONITOR --initial_peers ONE_OR_MORE_PEERS --seed 42 \
- --logging_first_step --logging_steps 100  --output_dir ./outputs --overwrite_output_dir --logging_dir ./logs
-```
-Here, `ONE_OR_MORE_PEERS` stands for either your coordinator endpoint (e.g. `123.123.123.123:1337`), an endpoint of any pre-existing trainer or multiple endpoints for stability. See tips & tricks section below for more information on setting up collaborative training.
+    ```bash
+    python run_trainer.py \
+    --experiment_prefix SAME_AS_IN_RUN_TRAINING_MONITOR --initial_peers ONE_OR_MORE_PEERS --seed 42 \
+    --logging_first_step --logging_steps 100  --output_dir ./outputs --overwrite_output_dir --logging_dir ./logs
+    ```
+
+    Here, `ONE_OR_MORE_PEERS` stands for multiaddresses of one or multiple existing peers (training monitors or existing trainers)
+    collected from the first lines of their terminal output. For the example above, the multiaddresses would be:
+    ```
+    --initial_peers /ip4/8.8.8.8/tcp/1337/p2p/XXXX /ip4/8.8.8.8/udp/31337/quic/p2p/XXXX
+    ```
+
+    __Note:__ a [multiaddress](https://docs.libp2p.io/concepts/addressing/) is a format for encoding multiple layers of addressing information
+    that supports a number of different protocols. In hivemind, we typically operate with multiaddresses
+    that contain a [libp2p](https://libp2p.io/) peer ID (e.g. `/p2p/XXXX`) together with the information about how to reach it
+    (e.g. the IPv4 address and TCP port `/ip4/8.8.8.8/tcp/31337` or
+    the information about a relay used for [NAT traversal](https://docs.libp2p.io/concepts/nat/)).
+
+    You may need to change the IP address to a publicly visible one if some of the initial peers are located behind NAT.
+    If you have any trouble doing this, consider the ["Using IPFS"](#using-ipfs) section.
+
+See the ["Tips and tricks"](#tips-and-tricks) section for more information on setting up collaborative training.
 
 As the peer begins training, it will periodically report training logs in the following form:
 ```
@@ -61,7 +75,10 @@ As the peer begins training, it will periodically report training logs in the fo
 __Sanity check:__ a healthy peer will periodically report `Averaged tensors successfully with [N > 1]` peers.
 
 For convenience, you can view (and share!) the learning curves of your collaborative experiments in wandb:
-![image](https://user-images.githubusercontent.com/3491902/115177859-bed5e100-a0d8-11eb-82bc-55d1b12d335d.png)
+
+<p align="center">
+  <img src="https://user-images.githubusercontent.com/3491902/115177859-bed5e100-a0d8-11eb-82bc-55d1b12d335d.png">
+</p>
 
 
 ## Tips and tricks
@@ -70,7 +87,7 @@ Finally, we provide best practices for running collaborative experiments of diff
 
 ### Hosting the data
 For small experiments (3-16 peers, <1GB data), you can use a free-tier file hosting that has a convenient way to [download with curl/wget](https://superuser.com/questions/470664/how-to-download-dropbox-files-using-wget-command). However, these services are not meant for high load and could ban you for generating too much traffic. If you want to scale up, you could either use an S3-like storage from [any](https://aws.amazon.com/s3/) [cloud](https://cloud.google.com/storage) [provider](https://cloud.google.com/storage) or host the data [yourself]((https://gist.github.com/willurd/5720255)). Large data files (>5GB) will take long to download; we recommend splitting them into chunks and implementing a custom dataloader that can load chunks on the fly. Finally, the most _comme il faut_ solution to sharing large datasets is to use [academic torrents](https://academictorrents.com/).
- 
+
 ### run_training_monitor.py
 This peer exists solely to welcome other peers onto the DHT and track learning progress. It requires neither GPU nor high bandwidth, the only prerequisite is that coordinator should have high uptime. If no high uptime server is available, one can also run multiple coordinators on different servers and list all of them as `--initial_peers`. The system will stay up as long as at least one coordinator is available. For short- to mid-term experiments you can host coordinator on a [free-tier VM](https://www.quora.com/Are-there-any-free-online-virtual-machines).
 
@@ -84,7 +101,7 @@ There are awesome services like [Google Colab](https://colab.research.google.com
   - you can create starter notebooks to make it more convenient for collaborators to join your training run ([example](https://colab.research.google.com/gist/yhn112/e858cb841c73879d8ef98a84e03b43e7/collaborative-training-v0-10.ipynb)). Ideally, joining collaboration should take at most a couple of clicks.
 
 Here's an example of a full trainer script for Google Colab:
-```
+```bash
 !pip install transformers datasets sentencepiece torch_optimizer==0.1.0
 !git clone https://github.com/learning-at-home/hivemind && cd hivemind && pip install -e .
 !curl -L YOUR_HOSTED_DATA | tar xzf -     # example: https://hivemind-data.s3.us-east-2.amazonaws.com/wikitext103.tar.gz
@@ -94,3 +111,8 @@ Here's an example of a full trainer script for Google Colab:
  --logging_first_step --logging_steps 100  --output_dir ./outputs --overwrite_output_dir --logging_dir ./logs \
  --experiment_prefix EXPERIMENT_NAME_HERE --seed 42
 ```
+
+### Using IPFS
+If the initial peers for your experiment are located behind NAT and/or you have any trouble with figuring out their public IP addresses and ports, you can set up hivemind to use the [IPFS](https://ipfs.io) network to find the route to your peers automatically. To do this, you should specify the `--use_ipfs` option on all peers (both training monitors and trainers) you are starting.
+
+After that, it is enough to provide only a [libp2p](https://libp2p.io/) peer ID (e.g. `/p2p/XXXX`) for each initial peer. No other information (like IP addresses or TCP/UDP ports) is required.

+ 20 - 9
examples/albert/arguments.py

@@ -1,5 +1,5 @@
-from typing import Optional, List
 from dataclasses import dataclass, field
+from typing import Optional, List
 
 from transformers import TrainingArguments
 
@@ -11,11 +11,26 @@ class BaseTrainingArguments:
     )
     initial_peers: List[str] = field(
         default_factory=list,
-        metadata={"help": "One or more peers (comma-separated) that will welcome you into the collaboration"}
+        metadata={"help":
+            "Multiaddrs of the peers that will welcome you into the existing collaboration. "
+            "Example: /ip4/203.0.113.1/tcp/31337/p2p/XXXX /ip4/203.0.113.2/udp/7777/quic/p2p/YYYY"}
     )
-    dht_listen_on: str = field(
-        default="[::]:*",
-        metadata={"help": "Network interface used for incoming DHT communication. Default: all ipv6"}
+    use_ipfs: bool = field(
+        default=False,
+        metadata={"help":
+            "Use IPFS to find initial_peers. If enabled, you only need to provide /p2p/XXXX part of the multiaddrs "
+            "for the initial_peers (no need to specify a particular IPv4/IPv6 host and port)"}
+    )
+    host_maddrs: List[str] = field(
+        default_factory=lambda: ['/ip4/0.0.0.0/tcp/0', '/ip4/0.0.0.0/udp/0/quic'],
+        metadata={"help":
+            "Multiaddrs to listen for external connections from other p2p instances. "
+            "Defaults to all IPv4 interfaces with TCP and QUIC (over UDP) protocols: "
+            "/ip4/0.0.0.0/tcp/0 /ip4/0.0.0.0/udp/0/quic"}
+    )
+    announce_maddrs: List[str] = field(
+        default_factory=list,
+        metadata={"help": "Visible multiaddrs the host announces for external connections from other p2p instances"}
     )
 
 
@@ -97,10 +112,6 @@ class CollaborationArguments(AveragerArguments, CollaborativeOptimizerArguments,
         default=600,
         metadata={"help": "Statistics will be removed if not updated in this many seconds"}
     )
-    endpoint: Optional[str] = field(
-        default=None,
-        metadata={"help": "This node's IP for inbound connections, used when running from behind a proxy"}
-    )
 
 
 @dataclass

+ 11 - 8
examples/albert/run_trainer.py

@@ -18,8 +18,8 @@ from transformers.trainer import Trainer
 from torch_optimizer import Lamb
 
 import hivemind
+import utils
 from arguments import CollaborationArguments, DatasetArguments, AlbertTrainingArguments
-import metrics_utils
 
 
 logger = logging.getLogger(__name__)
@@ -130,7 +130,7 @@ class CollaborativeCallback(transformers.TrainerCallback):
                 self.last_reported_collaboration_step = self.collaborative_optimizer.local_step
                 self.total_samples_processed += self.samples
                 samples_per_second = self.collaborative_optimizer.performance_ema.samples_per_second
-                statistics = metrics_utils.LocalMetrics(
+                statistics = utils.LocalMetrics(
                     step=self.collaborative_optimizer.local_step,
                     samples_per_second=samples_per_second,
                     samples_accumulated=self.samples,
@@ -219,13 +219,16 @@ def main():
 
     opt, scheduler = get_optimizer_and_scheduler(training_args, model)
 
-    validators, local_public_key = metrics_utils.make_validators(
+    validators, local_public_key = utils.make_validators(
         collaboration_args_dict['experiment_prefix'])
-    dht = hivemind.DHT(
-        start=True, initial_peers=collaboration_args_dict.pop('initial_peers'),
-        listen=not collaboration_args_dict['client_mode'],
-        listen_on=collaboration_args_dict.pop('dht_listen_on'),
-        endpoint=collaboration_args_dict.pop('endpoint'), record_validators=validators)
+    dht = hivemind.DHT(start=True,
+                       initial_peers=collaboration_args_dict.pop('initial_peers'),
+                       listen=not collaboration_args_dict['client_mode'],
+                       record_validators=validators,
+                       use_ipfs=collaboration_args_dict['use_ipfs'],
+                       host_maddrs=collaboration_args_dict.pop('host_maddrs'),
+                       announce_maddrs=collaboration_args_dict.pop('announce_maddrs'))
+    utils.log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=collaboration_args_dict.pop('use_ipfs'))
 
     total_batch_size_per_step = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
     if torch.cuda.device_count() != 0:

+ 28 - 23
examples/albert/run_training_monitor.py

@@ -1,23 +1,24 @@
 #!/usr/bin/env python
 
-from dataclasses import dataclass, field, asdict
-import subprocess
+import logging
 import time
+from dataclasses import asdict, dataclass, field
+from ipaddress import ip_address
 from typing import Optional
 
 import torch
+import wandb
 from torch_optimizer import Lamb
 from transformers import AlbertForPreTraining, AlbertConfig, HfArgumentParser
-import wandb
-from whatsmyip.providers import GoogleDnsProvider
 from whatsmyip.ip import get_ip
+from whatsmyip.providers import GoogleDnsProvider
 
-from arguments import BaseTrainingArguments, CollaborativeOptimizerArguments, AveragerArguments
 import hivemind
-from hivemind.utils.logging import get_logger
-import metrics_utils
+import utils
+from arguments import BaseTrainingArguments, CollaborativeOptimizerArguments, AveragerArguments
+
 
-logger = get_logger(__name__)
+logger = logging.getLogger(__name__)
 
 
 @dataclass
@@ -27,10 +28,10 @@ class CoordinatorArguments(BaseTrainingArguments):
     new workers still can join the collaboration via alive initial peers' addresses.
     Specify initial_peers argument for that purpose
     """
-    address: Optional[str] = field(
-        default=None,
-        metadata={"help": "This machine's network address. Use public IP for global experiments, "
-                          "local address for private runs"}
+    use_google_dns: bool = field(
+        default=False,
+        metadata={"help":
+            "Use Google DNS to determine the public IP address of this machine (and add it to --announce_maddrs)"}
     )
     refresh_period: float = field(
         default=30,
@@ -141,17 +142,21 @@ if __name__ == '__main__':
     parser = HfArgumentParser((CoordinatorArguments, CollaborativeOptimizerArguments, AveragerArguments))
     coordinator_args, collab_optimizer_args, averager_args = parser.parse_args_into_dataclasses()
 
-    if coordinator_args.address is None:
-        logger.warning("No address specified. Attempting to infer address from DNS.")
-        coordinator_args.address = get_ip(GoogleDnsProvider)
+    if coordinator_args.use_google_dns:
+        address = get_ip(GoogleDnsProvider)
+        logger.info(f"Received public IP address of this machine from Google DNS: {address}")
+        version = ip_address(address).version
+        coordinator_args.announce_maddrs += [f'/ip{version}/{address}/tcp/0', f'/ip{version}/{address}/udp/0/quic']
 
     experiment_prefix = coordinator_args.experiment_prefix
-    validators, local_public_key = metrics_utils.make_validators(experiment_prefix)
-    dht = hivemind.DHT(start=True, listen_on=coordinator_args.dht_listen_on,
-                       endpoint=f"{coordinator_args.address}:*", initial_peers=coordinator_args.initial_peers,
-                       record_validators=validators)
-
-    logger.info(f"Running DHT root at {coordinator_args.address}:{dht.port}")
+    validators, local_public_key = utils.make_validators(experiment_prefix)
+    dht = hivemind.DHT(start=True,
+                       initial_peers=coordinator_args.initial_peers,
+                       record_validators=validators,
+                       use_ipfs=coordinator_args.use_ipfs,
+                       host_maddrs=coordinator_args.host_maddrs,
+                       announce_maddrs=coordinator_args.announce_maddrs)
+    utils.log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=coordinator_args.use_ipfs)
 
     if coordinator_args.wandb_project is not None:
         wandb.init(project=coordinator_args.wandb_project)
@@ -164,7 +169,7 @@ if __name__ == '__main__':
         metrics_dict = dht.get(experiment_prefix + '_metrics', latest=True)
         if metrics_dict is not None:
             metrics_dict = metrics_dict.value
-            metrics = [metrics_utils.LocalMetrics.parse_obj(metrics_dict[peer].value)
+            metrics = [utils.LocalMetrics.parse_obj(metrics_dict[peer].value)
                        for peer in metrics_dict]
             latest_step = max(item.step for item in metrics)
             if latest_step != current_step:
@@ -186,6 +191,7 @@ if __name__ == '__main__':
                     num_samples += item.samples_accumulated
                     sum_mini_steps += item.mini_steps
                 current_loss = sum_loss / sum_mini_steps
+                logger.info(f"Step #{current_step}\tloss = {current_loss:.5f}")
 
                 if coordinator_args.wandb_project is not None:
                     wandb.log({
@@ -200,6 +206,5 @@ if __name__ == '__main__':
                         checkpoint_handler.save_state(current_step)
                         if checkpoint_handler.is_time_to_upload():
                             checkpoint_handler.upload_checkpoint(current_loss)
-                    logger.info(f"Step #{current_step}\tloss = {current_loss:.5f}")
         logger.debug("Peer is still alive...")
         time.sleep(coordinator_args.refresh_period)

+ 23 - 1
examples/albert/metrics_utils.py → examples/albert/utils.py

@@ -1,9 +1,15 @@
 from typing import Dict, List, Tuple
 
+from multiaddr import Multiaddr
+from pydantic import BaseModel, StrictFloat, confloat, conint
+
 from hivemind.dht.crypto import RSASignatureValidator
 from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
 from hivemind.dht.validation import RecordValidatorBase
-from pydantic import BaseModel, StrictFloat, confloat, conint
+from hivemind.utils.logging import get_logger
+
+
+logger = get_logger(__name__)
 
 
 class LocalMetrics(BaseModel):
@@ -23,3 +29,19 @@ def make_validators(experiment_prefix: str) -> Tuple[List[RecordValidatorBase],
     validators = [SchemaValidator(MetricSchema, prefix=experiment_prefix),
                   signature_validator]
     return validators, signature_validator.local_public_key
+
+
+class TextStyle:
+    BOLD = '\033[1m'
+    BLUE = '\033[34m'
+    RESET = '\033[0m'
+
+
+def log_visible_maddrs(visible_maddrs: List[Multiaddr], only_p2p: bool) -> None:
+    if only_p2p:
+        unique_addrs = {addr['p2p'] for addr in visible_maddrs}
+        initial_peers_str = ' '.join(f'/p2p/{addr}' for addr in unique_addrs)
+    else:
+        initial_peers_str = ' '.join(str(addr) for addr in visible_maddrs)
+    logger.info(f"Running a DHT peer. To connect other peers to this one, use "
+                f"{TextStyle.BOLD}{TextStyle.BLUE}--initial_peers {initial_peers_str}{TextStyle.RESET}")

+ 23 - 3
hivemind/averaging/averager.py

@@ -12,6 +12,7 @@ import uuid
 import weakref
 from concurrent.futures.thread import ThreadPoolExecutor
 from dataclasses import asdict
+from ipaddress import ip_address
 from typing import Sequence, Optional, Tuple, Any, Union, Dict, AsyncIterator
 
 import grpc
@@ -30,6 +31,7 @@ from hivemind.utils import Endpoint, Port, MPFuture, get_logger, TensorDescripto
 from hivemind.utils.asyncio import anext, achain, aiter, switch_to_uvloop
 from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
 from hivemind.utils.grpc import ChannelCache, GRPC_KEEPALIVE_OPTIONS, split_for_streaming, combine_from_streaming
+from hivemind.utils.networking import choose_ip_address, strip_port
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils.timed_storage import get_dht_time, ValueWithExpiration, DHTExpiration
 
@@ -68,6 +70,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
     :param listen: if True (default), this averager will accept incoming requests from other peers and perform allreduce
             if False, the averager will register as a freeloader and attempt to fetch vectors from other averagers
     :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
+    :param announced_host: visible IP address the averager will announce for external connections from other peers.
+          If None, the address will be chosen from p2p.get_visible_maddrs() (global IPv4 addresses are preferred)
     :param channel_options: options for grpc.aio.insecure_channel, e.g. [('grpc.enable_retries', 0)]
           see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
     :param kwargs: extra parameters forwarded to grpc.aio.server
@@ -102,7 +106,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                  throughput: Optional[float] = None, min_vector_size: int = 0,
                  auxiliary: bool = False, allow_state_sharing: Optional[bool] = None,
                  listen: bool = True, listen_on: Endpoint = '0.0.0.0:*', daemon: bool = True,
-                 channel_options: Optional[Sequence[Tuple[str, Any]]] = None,
+                 announced_host: Optional[str] = None,
+                 channel_options: Sequence[Tuple[str, Any]] = (),
                  shutdown_timeout: float = 5, **kwargs):
         assert '.' not in prefix, "group prefix must be a string without trailing '.'"
         assert throughput is None or (throughput >= 0 and np.isfinite(np.float32(throughput))), \
@@ -122,6 +127,9 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         else:
             self.mode = AveragingMode.NODE
 
+        if announced_host is None:
+            announced_host = self._choose_announced_host()
+        self.announced_host = announced_host
         self.channel_options = channel_options
         self.daemon = daemon
 
@@ -163,6 +171,17 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         if start:
             self.run_in_background(await_ready=True)
 
+    def _choose_announced_host(self) -> Hostname:
+        announced_host = strip_port(self.listen_on).strip('[]')  # Stripping square brackets for IPv6
+        if ip_address(announced_host) not in [ip_address('0.0.0.0'), ip_address('::')]:
+            return announced_host
+
+        maddrs = self.dht.get_visible_maddrs()
+        announced_host = choose_ip_address(maddrs)
+        logger.info(f'Choosing IP {announced_host} as endpoint for DecentralizedAverager '
+                    f'from visible multiaddrs {maddrs}')
+        return announced_host
+
     @property
     def port(self) -> Optional[Port]:
         return self._port.value if self._port.value != 0 else None
@@ -183,7 +202,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
     def endpoint(self) -> Optional[Endpoint]:
         if self.listen and self._averager_endpoint is None:
             assert self.port is not None, "Averager is not running yet"
-            self._averager_endpoint = f"{self.dht.get_visible_address()}:{self.port}"
+            self._averager_endpoint = f"{self.announced_host}:{self.port}"
             logger.debug(f"Assuming averager endpoint to be {self._averager_endpoint}")
         return self._averager_endpoint
 
@@ -499,7 +518,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                     logger.info(f"Downloading parameters from peer {peer}")
                     stream = None
                     try:
-                        stub = ChannelCache.get_stub(peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
+                        stub = ChannelCache.get_stub(peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True,
+                                                     options=self.channel_options)
                         stream = stub.rpc_download_state(averaging_pb2.DownloadRequest())
                         current_tensor_parts, tensors = [], []
                         async for message in stream:

+ 39 - 75
hivemind/dht/__init__.py

@@ -15,18 +15,19 @@ The code is organized as follows:
 from __future__ import annotations
 
 import asyncio
-import ctypes
 import multiprocessing as mp
 import os
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
-from typing import Iterable, Optional, Sequence, Union, Callable, Awaitable, TypeVar
+from typing import Awaitable, Callable, Iterable, List, Optional, Sequence, TypeVar, Union
 
-from hivemind.dht.node import DHTNode, DHTID
-from hivemind.dht.routing import DHTValue, DHTKey, Subkey
+from multiaddr import Multiaddr
+
+from hivemind.dht.node import DHTID, DHTNode
+from hivemind.dht.routing import DHTKey, DHTValue, Subkey
 from hivemind.dht.validation import CompositeValidator, RecordValidatorBase
-from hivemind.utils import MPFuture, get_logger, switch_to_uvloop, ValueWithExpiration, await_cancelled, DHTExpiration
-from hivemind.utils.networking import Hostname, Endpoint, strip_port
+from hivemind.p2p import P2P
+from hivemind.utils import DHTExpiration, MPFuture, ValueWithExpiration, await_cancelled, get_logger, switch_to_uvloop
 
 logger = get_logger(__name__)
 
@@ -39,27 +40,39 @@ class DHT(mp.Process):
     * hivemind servers periodically announce their experts via declare_experts (dht_handler.py)
     * trainers find most suitable experts via RemoteMixtureOfExperts (beam_search.py)
 
-    :param initial_peers: one or multiple endpoints pointing to active DHT peers. Similar format to listen_on.
-    :param listen_on: an interface for incoming connections, e.g. "127.0.0.1:*", "0.0.0.0:1234" or "ipv6:[::]:*"
+    :param p2p: instance of hivemind.p2p.P2P that will be used for communication.
+      If None, DHTNode will create and manage its own P2P instance with given initial_peers and
+      parameters from ``kwargs``
+    :param initial_peers: multiaddrs of one or more active DHT peers (if you want to join an existing DHT)
     :param start: if True, automatically starts the background process on creation. Otherwise await manual start
     :param daemon: if True, the background process is marked as daemon and automatically terminated after main process
     :param max_workers: declare_experts and get_experts will use up to this many parallel workers
-        (but no more than one per key)
+      (but no more than one per key)
     :param expiration: experts declared from this node expire after this many seconds (default = 5 minutes)
+    :param record_validators: instances of RecordValidatorBase used for signing and validating stored records.
+      The validators will be combined using the CompositeValidator class. It merges them when possible
+      (according to their `.merge_with()` policies) and orders them according to the `.priority` properties.
     :param shutdown_timeout: when calling .shutdown, wait for up to this many seconds before terminating
-    :param kwargs: any other params will be forwarded to DHTNode upon creation
+    :param kwargs: any other params will be forwarded to DHTNode and hivemind.p2p.P2P upon creation
     """
     _node: DHTNode
 
-    def __init__(self, listen_on: Endpoint = "0.0.0.0:*", initial_peers: Sequence[Endpoint] = (), *, start: bool,
-                 daemon: bool = True, max_workers: Optional[int] = None, parallel_rpc: Optional[int] = None,
-                 record_validators: Iterable[RecordValidatorBase] = (), shutdown_timeout: float = 3, **kwargs):
+    def __init__(self, p2p: Optional[P2P] = None,
+                 initial_peers: Optional[Sequence[Union[Multiaddr, str]]] = None,
+                 *, start: bool, daemon: bool = True, max_workers: Optional[int] = None,
+                 record_validators: Iterable[RecordValidatorBase] = (),
+                 shutdown_timeout: float = 3, **kwargs):
         super().__init__()
-        assert not isinstance(initial_peers, str), "please specify a list/tuple of initial peers (even if there's one)"
-        self.listen_on, self.initial_peers, self.kwargs = listen_on, initial_peers, kwargs
-        self.max_workers, self.parallel_rpc = max_workers, parallel_rpc
+
+        self.p2p = p2p
+        if not (initial_peers is None or (isinstance(initial_peers, Sequence) and
+                                          all(isinstance(item, (Multiaddr, str)) for item in initial_peers))):
+            raise TypeError('initial_peers should be of type Optional[Sequence[Union[Multiaddr, str]]]')
+        self.initial_peers = initial_peers
+        self.kwargs = kwargs
+        self.max_workers = max_workers
+
         self._record_validator = CompositeValidator(record_validators)
-        self._port = mp.Value(ctypes.c_int32, 0)  # initialized after dht starts
         self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)
         self.shutdown_timeout = shutdown_timeout
         self.ready = mp.Event()
@@ -74,11 +87,9 @@ class DHT(mp.Process):
         with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
             async def _run():
                 self._node = await DHTNode.create(
-                    initial_peers=list(self.initial_peers), listen_on=self.listen_on, parallel_rpc=self.parallel_rpc,
+                    p2p=self.p2p, initial_peers=self.initial_peers,
                     num_workers=self.max_workers or 1, record_validator=self._record_validator,
                     **self.kwargs)
-                if self._node.port is not None:
-                    self._port.value = self._node.port
                 self.ready.set()
 
                 while True:
@@ -108,16 +119,10 @@ class DHT(mp.Process):
             if self.is_alive():
                 logger.warning("DHT did not shut down within the grace period; terminating it the hard way.")
                 self.terminate()
-        else:
-            logger.warning("DHT shutdown has no effect: dht process is already not alive")
 
     async def _shutdown(self):
         await self._node.shutdown()
 
-    @property
-    def port(self) -> Optional[int]:
-        return self._port.value if self._port.value != 0 else None
-
     def get(self, key: DHTKey, latest: bool = False, return_future: bool = False, **kwargs
             ) -> Union[Optional[ValueWithExpiration[DHTValue]], MPFuture]:
         """
@@ -216,58 +221,17 @@ class DHT(mp.Process):
             self, node: DHTNode, record_validators: Iterable[RecordValidatorBase]) -> None:
         node.protocol.record_validator.extend(record_validators)
 
-    def get_visible_address(self, num_peers: Optional[int] = None, peers: Sequence[Endpoint] = ()) -> Hostname:
+    def get_visible_maddrs(self, latest: bool = False) -> List[Multiaddr]:
         """
-        Get this machine's visible address by requesting other peers or using pre-specified network addresses.
-        If no parameters are specified, this function will check for manual endpoint; if unavailable, ask 1 random peer.
+        Get multiaddrs of the current DHT node that should be accessible by other peers.
 
-        :param num_peers: if specified, ask multiple peers and check that they perceive the same endpoint
-        :param peers: if specified, ask these exact peers instead of choosing random known peers
-        :note: if this node has no known peers in routing table, one must specify :peers: manually
+        :param latest: ask the P2P daemon to refresh the visible multiaddrs
         """
-        assert num_peers is None or peers == (), "please specify either a num_peers or the list of peers, not both"
-        assert not isinstance(peers, str) and isinstance(peers, Sequence), "Please send a list / tuple of endpoints"
-        future = MPFuture()
-        self._outer_pipe.send(('_get_visible_address', [], dict(num_peers=num_peers, peers=peers, future=future)))
-        return future.result()
-
-    async def _get_visible_address(self, num_peers: Optional[int], peers: Sequence[Endpoint],
-                                   future: Optional[MPFuture]):
-        if not peers and (num_peers or not self._node.protocol.node_info.endpoint):
-            # if we can't resolve the endpoint locally, ask one random peer
-            peers_and_endpoints = self._node.protocol.routing_table.get_nearest_neighbors(
-                DHTID.generate(), num_peers or 1, exclude=self._node.node_id)
-            peers = tuple(endpoint for node_id, endpoint in peers_and_endpoints)
-
-        chosen_address = None
-        if peers:
-            possible_endpoints: Sequence[Optional[Endpoint]] = await asyncio.gather(*(
-                self._node.protocol.get_outgoing_request_endpoint(peer) for peer in peers))
-
-            for endpoint in possible_endpoints:
-                if endpoint is None:
-                    continue
-                address = strip_port(endpoint)
-                if chosen_address is not None and address != chosen_address:
-                    logger.warning("At least two peers returned different visible addresses for this node:"
-                                   f"{address} and {chosen_address} (keeping the former one)")
-                else:
-                    chosen_address = address
-
-            if chosen_address is None:
-                logger.warning(f"None of the selected peers responded with an address ({peers})")
-
-        if self._node.protocol.node_info.endpoint:
-            address = strip_port(self._node.protocol.node_info.endpoint)
-            if chosen_address is not None and address != chosen_address:
-                logger.warning(f"Node was manually given endpoint {address} , but other peers report {chosen_address}")
-            chosen_address = chosen_address or address
-
-        if chosen_address:
-            future.set_result(chosen_address)
-        else:
-            future.set_exception(ValueError(f"Can't get address: DHT node has no peers and no public endpoint."
-                                            f" Please ensure the node is connected or specify peers=... manually."))
+
+        return self.run_coroutine(partial(DHT._get_visible_maddrs, latest=latest))
+
+    async def _get_visible_maddrs(self, node: DHTNode, latest: bool = False) -> List[Multiaddr]:
+        return await node.get_visible_maddrs(latest=latest)
 
     def __del__(self):
         if self._parent_pid == os.getpid() and self.is_alive():

+ 53 - 22
hivemind/dht/node.py

@@ -6,8 +6,10 @@ import random
 from collections import defaultdict, Counter
 from dataclasses import dataclass, field
 from functools import partial
-from typing import Optional, Tuple, List, Dict, DefaultDict, Collection, Union, Set, Awaitable, Callable, Any
+from typing import (Any, Awaitable, Callable, Collection, DefaultDict, Dict, List, Optional, Sequence, Set, Tuple,
+                    Type, Union)
 
+from multiaddr import Multiaddr
 from sortedcontainers import SortedSet
 
 from hivemind.dht.crypto import DHTRecord, RecordValidatorBase
@@ -15,8 +17,10 @@ from hivemind.dht.protocol import DHTProtocol
 from hivemind.dht.routing import DHTID, DHTKey, get_dht_time, DHTValue, BinaryDHTValue, Subkey
 from hivemind.dht.storage import DictionaryDHTValue
 from hivemind.dht.traverse import traverse_dht
-from hivemind.utils import Endpoint, LOCALHOST, MSGPackSerializer, get_logger, SerializerBase, DHTExpiration
-from hivemind.utils.timed_storage import TimedStorage, ValueWithExpiration
+from hivemind.p2p import P2P, PeerID as Endpoint
+from hivemind.utils import MSGPackSerializer, get_logger, SerializerBase
+from hivemind.utils.auth import AuthorizerBase
+from hivemind.utils.timed_storage import DHTExpiration, TimedStorage, ValueWithExpiration
 
 logger = get_logger(__name__)
 
@@ -66,7 +70,7 @@ class DHTNode:
 
     """
     # fmt:off
-    node_id: DHTID; is_alive: bool; port: int; num_replicas: int; num_workers: int; protocol: DHTProtocol
+    node_id: DHTID; is_alive: bool; endpoint: Endpoint; num_replicas: int; num_workers: int; protocol: DHTProtocol
     chunk_size: int; refresh_timeout: float; cache_locally: bool; cache_nearest: int; cache_refresh_before_expiry: float
     cache_on_store: bool; reuse_get_requests: bool; pending_get_requests: DefaultDict[DHTID, SortedSet[_SearchState]]
     cache_refresh_task: Optional[asyncio.Task]; cache_refresh_evt: asyncio.Event; cache_refresh_queue: CacheRefreshQueue
@@ -75,18 +79,26 @@ class DHTNode:
 
     @classmethod
     async def create(
-            cls, node_id: Optional[DHTID] = None, initial_peers: List[Endpoint] = (),
+            cls,
+            p2p: Optional[P2P] = None,
+            node_id: Optional[DHTID] = None,
+            initial_peers: Optional[Sequence[Union[Multiaddr, str]]] = None,
             bucket_size: int = 20, num_replicas: int = 5, depth_modulo: int = 5, parallel_rpc: int = None,
             wait_timeout: float = 3, refresh_timeout: Optional[float] = None, bootstrap_timeout: Optional[float] = None,
             cache_locally: bool = True, cache_nearest: int = 1, cache_size=None, cache_refresh_before_expiry: float = 5,
             cache_on_store: bool = True, reuse_get_requests: bool = True, num_workers: int = 1, chunk_size: int = 16,
             blacklist_time: float = 5.0, backoff_rate: float = 2.0,
-            listen: bool = True, listen_on: Endpoint = "0.0.0.0:*", endpoint: Optional[Endpoint] = None,
+            listen: bool = True,
             record_validator: Optional[RecordValidatorBase] = None,
+            authorizer: Optional[AuthorizerBase] = None,
             validate: bool = True, strict: bool = True, **kwargs) -> DHTNode:
         """
-        :param node_id: current node's identifier, determines which keys it will store locally, defaults to random id
-        :param initial_peers: connects to these peers to populate routing table, defaults to no peers
+        :param p2p: instance of hivemind.p2p.P2P that will be used for communication.
+          If None, DHTNode will create and manage its own P2P instance with given initial_peers and
+          parameters from ``kwargs``
+        :param node_id: current node's DHTID for hivemind.dht, determines which keys it will store locally,
+          defaults to random id
+        :param initial_peers: multiaddrs of one or more active DHT peers (if you want to join an existing DHT)
         :param bucket_size: max number of nodes in one k-bucket (k). Trying to add {k+1}st node will cause a bucket to
           either split in two buckets along the midpoint or reject the new node (but still save it as a replacement)
           Recommended value: k is chosen s.t. any given k nodes are very unlikely to all fail after staleness_timeout
@@ -115,11 +127,11 @@ class DHTNode:
         :param strict: if True, any error encountered in validation will interrupt the creation of DHTNode
         :param listen: if True (default), this node will accept incoming request and otherwise be a DHT "citzen"
           if False, this node will refuse any incoming request, effectively being only a "client"
-        :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
-        :param endpoint: if specified, this is peer's preferred public endpoint. Otherwise let peers infer endpoint
-        :param channel_options: options for grpc.aio.insecure_channel, e.g. [('grpc.enable_retries', 0)]
-          see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
-        :param kwargs: extra parameters used in grpc.aio.server
+        :param record_validator: instance of RecordValidatorBase used for signing and validating stored records
+        :param authorizer: instance of AuthorizerBase used for signing and validating requests and response
+          for a given authorization protocol
+        :param kwargs: extra parameters for an internally created instance of hivemind.p2p.P2P.
+          Should be empty if the P2P instance is provided in the constructor
         """
         self = cls(_initialized_with_create=True)
         self.node_id = node_id if node_id is not None else DHTID.generate()
@@ -138,12 +150,28 @@ class DHTNode:
         self.cache_refresh_evt = asyncio.Event()
         self.cache_refresh_task = None
 
-        self.protocol = await DHTProtocol.create(self.node_id, bucket_size, depth_modulo, num_replicas, wait_timeout,
-                                                 parallel_rpc, cache_size, listen, listen_on, endpoint,
-                                                 record_validator, **kwargs)
-        self.port = self.protocol.port
+        if p2p is None:
+            if not kwargs.get('use_ipfs'):
+                kwargs['initial_peers'] = initial_peers
+            p2p = await P2P.create(**kwargs)
+            self._should_shutdown_p2p = True
+        else:
+            if kwargs:
+                raise ValueError(
+                    f'**kwargs in DHTNode.create() should be empty if hivemind.p2p.P2P instance is provided'
+                    f'in the constructor. Got kwargs = {kwargs} instead. '
+                    f'You may have a typo in a DHTNode.create() parameter name')
+            self._should_shutdown_p2p = False
+        self.p2p = p2p
+
+        self.protocol = await DHTProtocol.create(
+            p2p, self.node_id, bucket_size, depth_modulo, num_replicas, wait_timeout,
+            parallel_rpc, cache_size, listen, record_validator, authorizer)
+        self.endpoint = p2p.id
 
         if initial_peers:
+            initial_peers = {Endpoint.from_base58(Multiaddr(item)['p2p']) for item in initial_peers}
+
             # stage 1: ping initial_peers, add each other to the routing table
             bootstrap_timeout = bootstrap_timeout if bootstrap_timeout is not None else wait_timeout
             start_time = get_dht_time()
@@ -182,11 +210,11 @@ class DHTNode:
         assert _initialized_with_create, " Please use DHTNode.create coroutine to spawn new node instances "
         super().__init__()
 
-    async def shutdown(self, timeout=None):
+    async def shutdown(self):
         """ Process existing requests, close all connections and stop the server """
         self.is_alive = False
-        if self.protocol.server:
-            await self.protocol.shutdown(timeout)
+        if self._should_shutdown_p2p:
+            await self.p2p.shutdown()
 
     async def find_nearest_nodes(
             self, queries: Collection[DHTID], k_nearest: Optional[int] = None, beam_size: Optional[int] = None,
@@ -234,7 +262,7 @@ class DHTNode:
         for query, nearest_nodes in nearest_nodes_per_query.items():
             if not exclude_self:
                 nearest_nodes = sorted(nearest_nodes + [self.node_id], key=query.xor_distance)
-                node_to_endpoint[self.node_id] = f"{LOCALHOST}:{self.port}"
+                node_to_endpoint[self.node_id] = self.endpoint
             nearest_nodes_with_endpoints[query] = {node: node_to_endpoint[node] for node in nearest_nodes[:k_nearest]}
         return nearest_nodes_with_endpoints
 
@@ -626,6 +654,9 @@ class DHTNode:
 
             await asyncio.sleep(max(0.0, period - (get_dht_time() - refresh_time)))
 
+    async def get_visible_maddrs(self, latest: bool = False) -> List[Multiaddr]:
+        return await self.protocol.p2p.get_visible_maddrs(latest=latest)
+
 
 @dataclass(init=True, repr=True, frozen=False, order=False)
 class _SearchState:
@@ -636,7 +667,7 @@ class _SearchState:
     expiration_time: Optional[DHTExpiration] = None  # best expiration time so far
     source_node_id: Optional[DHTID] = None  # node that gave us the value
     future: asyncio.Future[Optional[ValueWithExpiration[DHTValue]]] = field(default_factory=asyncio.Future)
-    serializer: type(SerializerBase) = MSGPackSerializer
+    serializer: Type[SerializerBase] = MSGPackSerializer
     record_validator: Optional[RecordValidatorBase] = None
 
     def add_candidate(self, candidate: Optional[ValueWithExpiration[Union[BinaryDHTValue, DictionaryDHTValue]]],

+ 46 - 77
hivemind/dht/protocol.py

@@ -1,4 +1,4 @@
-""" RPC protocol that provides nodes a way to communicate with each other. Based on gRPC.AIO. """
+""" RPC protocol that provides nodes a way to communicate with each other """
 from __future__ import annotations
 
 import asyncio
@@ -9,8 +9,9 @@ import grpc
 from hivemind.dht.crypto import DHTRecord, RecordValidatorBase
 from hivemind.dht.routing import RoutingTable, DHTID, BinaryDHTValue, Subkey
 from hivemind.dht.storage import DHTLocalStorage, DictionaryDHTValue
-from hivemind.proto import dht_pb2, dht_pb2_grpc as dht_grpc
-from hivemind.utils import Endpoint, get_logger, replace_port, MSGPackSerializer, ChannelCache, GRPC_KEEPALIVE_OPTIONS
+from hivemind.p2p import P2P, P2PContext, PeerID as Endpoint, Servicer
+from hivemind.proto import dht_pb2
+from hivemind.utils import get_logger, MSGPackSerializer
 from hivemind.utils.auth import AuthRole, AuthRPCWrapper, AuthorizerBase
 from hivemind.utils.timed_storage import DHTExpiration, get_dht_time, MAX_DHT_TIME_DISCREPANCY_SECONDS, \
     ValueWithExpiration
@@ -18,10 +19,10 @@ from hivemind.utils.timed_storage import DHTExpiration, get_dht_time, MAX_DHT_TI
 logger = get_logger(__name__)
 
 
-class DHTProtocol(dht_grpc.DHTServicer):
+class DHTProtocol(Servicer):
     # fmt:off
-    node_id: DHTID; port: int; bucket_size: int; num_replicas: int; wait_timeout: float; node_info: dht_pb2.NodeInfo
-    channel_options: Tuple[Tuple[str, Any]]; server: grpc.aio.Server
+    p2p: P2P
+    node_id: DHTID; bucket_size: int; num_replicas: int; wait_timeout: float; node_info: dht_pb2.NodeInfo
     storage: DHTLocalStorage; cache: DHTLocalStorage; routing_table: RoutingTable; rpc_semaphore: asyncio.Semaphore
     record_validator: Optional[RecordValidatorBase]
     # fmt:on
@@ -31,12 +32,10 @@ class DHTProtocol(dht_grpc.DHTServicer):
 
     @classmethod
     async def create(
-            cls, node_id: DHTID, bucket_size: int, depth_modulo: int, num_replicas: int, wait_timeout: float,
+            cls, p2p: P2P, node_id: DHTID, bucket_size: int, depth_modulo: int, num_replicas: int, wait_timeout: float,
             parallel_rpc: Optional[int] = None, cache_size: Optional[int] = None,
-            listen=True, listen_on='0.0.0.0:*', endpoint: Optional[Endpoint] = None,
-            record_validator: Optional[RecordValidatorBase] = None,
-            authorizer: Optional[AuthorizerBase] = None,
-            channel_options: Sequence[Tuple[str, Any]] = (), **kwargs) -> DHTProtocol:
+            listen=True, record_validator: Optional[RecordValidatorBase] = None,
+            authorizer: Optional[AuthorizerBase] = None) -> DHTProtocol:
         """
         A protocol that allows DHT nodes to request keys/neighbors from other DHT nodes.
         As a side-effect, DHTProtocol also maintains a routing table as described in
@@ -50,33 +49,23 @@ class DHTProtocol(dht_grpc.DHTServicer):
          Read more: https://github.com/bmuller/rpcudp/tree/master/rpcudp
         """
         self = cls(_initialized_with_create=True)
+        self.p2p = p2p
         self.node_id, self.bucket_size, self.num_replicas = node_id, bucket_size, num_replicas
-        self.wait_timeout, self.channel_options = wait_timeout, tuple(channel_options)
+        self.wait_timeout = wait_timeout
         self.storage, self.cache = DHTLocalStorage(), DHTLocalStorage(maxsize=cache_size)
         self.routing_table = RoutingTable(node_id, bucket_size, depth_modulo)
         self.rpc_semaphore = asyncio.Semaphore(parallel_rpc if parallel_rpc is not None else float('inf'))
+        self.listen = listen
         self.record_validator = record_validator
         self.authorizer = authorizer
 
-        if listen:  # set up server to process incoming rpc requests
-            grpc.aio.init_grpc_aio()
-            self.server = grpc.aio.server(**kwargs, options=GRPC_KEEPALIVE_OPTIONS)
-            servicer = AuthRPCWrapper(self, AuthRole.SERVICER, self.authorizer)
-            dht_grpc.add_DHTServicer_to_server(servicer, self.server)
-
-            self.port = self.server.add_insecure_port(listen_on)
-            assert self.port != 0, f"Failed to listen to {listen_on}"
-            if endpoint is not None and endpoint.endswith('*'):
-                endpoint = replace_port(endpoint, self.port)
-            self.node_info = dht_pb2.NodeInfo(node_id=node_id.to_bytes(), rpc_port=self.port,
-                                              endpoint=endpoint or dht_pb2.NodeInfo.endpoint.DESCRIPTOR.default_value)
-            await self.server.start()
-        else:  # not listening to incoming requests, client-only mode
+        if listen:
+            await self.add_p2p_handlers(self.p2p, AuthRPCWrapper(self, AuthRole.SERVICER, self.authorizer))
+
+            self.node_info = dht_pb2.NodeInfo(node_id=node_id.to_bytes())
+        else:  # client-only mode
             # note: use empty node_info so peers won't add you to their routing tables
-            self.node_info, self.server, self.port = dht_pb2.NodeInfo(), None, None
-            if listen_on != '0.0.0.0:*' or len(kwargs) != 0:
-                logger.warning(f"DHTProtocol has no server (due to listen=False), listen_on"
-                               f"and kwargs have no effect (unused kwargs: {kwargs})")
+            self.node_info = dht_pb2.NodeInfo()
         return self
 
     def __init__(self, *, _initialized_with_create=False):
@@ -84,22 +73,15 @@ class DHTProtocol(dht_grpc.DHTServicer):
         assert _initialized_with_create, " Please use DHTProtocol.create coroutine to spawn new protocol instances "
         super().__init__()
 
-    async def shutdown(self, timeout=None):
-        """ Process existing requests, close all connections and stop the server """
-        if self.server:
-            await self.server.stop(timeout)
-        else:
-            logger.warning("DHTProtocol has no server (due to listen=False), it doesn't need to be shut down")
-
-    def _get_dht_stub(self, peer: Endpoint) -> dht_grpc.DHTStub:
-        """ get a DHTStub that sends requests to a given peer """
-        stub = ChannelCache.get_stub(peer, dht_grpc.DHTStub, aio=True, options=self.channel_options)
+    def get_stub(self, peer: Endpoint) -> AuthRPCWrapper:
+        """ get a stub that sends requests to a given peer """
+        stub = super().get_stub(self.p2p, peer)
         return AuthRPCWrapper(stub, AuthRole.CLIENT, self.authorizer, service_public_key=None)
 
     async def call_ping(self, peer: Endpoint, validate: bool = False, strict: bool = True) -> Optional[DHTID]:
         """
         Get peer's node id and add him to the routing table. If peer doesn't respond, return None
-        :param peer: string network address, e.g. 123.123.123.123:1337 or [2a21:6с8:b192:2105]:8888
+        :param peer: peer ID to ping
         :param validate: if True, validates that node's endpoint is available
         :param strict: if strict=True, validation will raise exception on fail, otherwise it will only warn
         :note: if DHTProtocol was created with listen=True, also request peer to add you to his routing table
@@ -110,18 +92,18 @@ class DHTProtocol(dht_grpc.DHTServicer):
             async with self.rpc_semaphore:
                 ping_request = dht_pb2.PingRequest(peer=self.node_info, validate=validate)
                 time_requested = get_dht_time()
-                response = await self._get_dht_stub(peer).rpc_ping(ping_request, timeout=self.wait_timeout)
+                response = await self.get_stub(peer).rpc_ping(ping_request, timeout=self.wait_timeout)
                 time_responded = get_dht_time()
-        except grpc.aio.AioRpcError as error:
-            logger.debug(f"DHTProtocol failed to ping {peer}: {error.code()}")
+        except Exception as e:
+            logger.debug(f"DHTProtocol failed to ping {peer}", exc_info=True)
             response = None
         responded = bool(response and response.peer and response.peer.node_id)
 
         if responded and validate:
             try:
-                if self.server is not None and not response.available:
-                    raise ValidationError(f"Peer {peer} couldn't access this node at {response.sender_endpoint} . "
-                                          f"Make sure that this port is open for incoming requests.")
+                if self.listen and not response.available:
+                    raise ValidationError(f"Peer {peer} can't access this node. "
+                                          f"Probably, libp2p has failed to bypass the firewall")
 
                 if response.dht_time != dht_pb2.PingResponse.dht_time.DESCRIPTOR.default_value:
                     if response.dht_time < time_requested - MAX_DHT_TIME_DISCREPANCY_SECONDS or \
@@ -138,32 +120,18 @@ class DHTProtocol(dht_grpc.DHTServicer):
         asyncio.create_task(self.update_routing_table(peer_id, peer, responded=responded))
         return peer_id
 
-    async def get_outgoing_request_endpoint(self, peer: Endpoint) -> Optional[Endpoint]:
-        """ ask this peer how it perceives this node's outgoing request address """
-        try:
-            async with self.rpc_semaphore:
-                ping_request = dht_pb2.PingRequest(peer=None, validate=False)
-                response = await self._get_dht_stub(peer).rpc_ping(ping_request, timeout=self.wait_timeout)
-                if response.sender_endpoint != dht_pb2.PingResponse.sender_endpoint.DESCRIPTOR.default_value:
-                    return response.sender_endpoint
-        except grpc.aio.AioRpcError as error:
-            logger.debug(f"DHTProtocol failed to ping {peer}: {error.code()}")
-
-    async def rpc_ping(self, request: dht_pb2.PingRequest, context: grpc.ServicerContext):
+    async def rpc_ping(self, request: dht_pb2.PingRequest, context: P2PContext) -> dht_pb2.PingResponse:
         """ Some node wants us to add it to our routing table. """
-        response = dht_pb2.PingResponse(peer=self.node_info, sender_endpoint=context.peer(),
+
+        response = dht_pb2.PingResponse(peer=self.node_info,
                                         dht_time=get_dht_time(), available=False)
 
-        if request.peer and request.peer.node_id and request.peer.rpc_port:
+        if request.peer and request.peer.node_id:
             sender_id = DHTID.from_bytes(request.peer.node_id)
-            if request.peer.endpoint != dht_pb2.NodeInfo.endpoint.DESCRIPTOR.default_value:
-                sender_endpoint = request.peer.endpoint  # if peer has preferred endpoint, use it
-            else:
-                sender_endpoint = replace_port(context.peer(), new_port=request.peer.rpc_port)
+            sender_endpoint = context.remote_id
 
-            response.sender_endpoint = sender_endpoint
             if request.validate:
-                response.available = await self.call_ping(response.sender_endpoint, validate=False) == sender_id
+                response.available = await self.call_ping(sender_endpoint, validate=False) == sender_id
 
             asyncio.create_task(self.update_routing_table(sender_id, sender_endpoint,
                                                           responded=response.available or not request.validate))
@@ -215,17 +183,17 @@ class DHTProtocol(dht_grpc.DHTServicer):
                                              expiration_time=expiration_time, in_cache=in_cache, peer=self.node_info)
         try:
             async with self.rpc_semaphore:
-                response = await self._get_dht_stub(peer).rpc_store(store_request, timeout=self.wait_timeout)
+                response = await self.get_stub(peer).rpc_store(store_request, timeout=self.wait_timeout)
             if response.peer and response.peer.node_id:
                 peer_id = DHTID.from_bytes(response.peer.node_id)
                 asyncio.create_task(self.update_routing_table(peer_id, peer, responded=True))
             return response.store_ok
-        except grpc.aio.AioRpcError as error:
-            logger.debug(f"DHTProtocol failed to store at {peer}: {error.code()}")
+        except Exception as e:
+            logger.debug(f"DHTProtocol failed to store at {peer}", exc_info=True)
             asyncio.create_task(self.update_routing_table(self.routing_table.get(endpoint=peer), peer, responded=False))
             return None
 
-    async def rpc_store(self, request: dht_pb2.StoreRequest, context: grpc.ServicerContext) -> dht_pb2.StoreResponse:
+    async def rpc_store(self, request: dht_pb2.StoreRequest, context: P2PContext) -> dht_pb2.StoreResponse:
         """ Some node wants us to store this (key, value) pair """
         if request.peer:  # if requested, add peer to the routing table
             asyncio.create_task(self.rpc_ping(dht_pb2.PingRequest(peer=request.peer), context))
@@ -274,7 +242,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
         find_request = dht_pb2.FindRequest(keys=list(map(DHTID.to_bytes, keys)), peer=self.node_info)
         try:
             async with self.rpc_semaphore:
-                response = await self._get_dht_stub(peer).rpc_find(find_request, timeout=self.wait_timeout)
+                response = await self.get_stub(peer).rpc_find(find_request, timeout=self.wait_timeout)
             if response.peer and response.peer.node_id:
                 peer_id = DHTID.from_bytes(response.peer.node_id)
                 asyncio.create_task(self.update_routing_table(peer_id, peer, responded=True))
@@ -283,7 +251,8 @@ class DHTProtocol(dht_grpc.DHTServicer):
             output = {}  # unpack data depending on its type
             for key, result in zip(keys, response.results):
                 key_bytes = DHTID.to_bytes(key)
-                nearest = dict(zip(map(DHTID.from_bytes, result.nearest_node_ids), result.nearest_endpoints))
+                nearest = dict(zip(map(DHTID.from_bytes, result.nearest_node_ids),
+                                   map(Endpoint.from_base58, result.nearest_endpoints)))
 
                 if result.type == dht_pb2.NOT_FOUND:
                     output[key] = None, nearest
@@ -305,11 +274,11 @@ class DHTProtocol(dht_grpc.DHTServicer):
                     logger.error(f"Unknown result type: {result.type}")
 
             return output
-        except grpc.aio.AioRpcError as error:
-            logger.debug(f"DHTProtocol failed to find at {peer}: {error.code()}")
+        except Exception as e:
+            logger.debug(f"DHTProtocol failed to find at {peer}", exc_info=True)
             asyncio.create_task(self.update_routing_table(self.routing_table.get(endpoint=peer), peer, responded=False))
 
-    async def rpc_find(self, request: dht_pb2.FindRequest, context: grpc.ServicerContext) -> dht_pb2.FindResponse:
+    async def rpc_find(self, request: dht_pb2.FindRequest, context: P2PContext) -> dht_pb2.FindResponse:
         """
         Someone wants to find keys in the DHT. For all keys that we have locally, return value and expiration
         Also return :bucket_size: nearest neighbors from our routing table for each key (whether or not we found value)
@@ -337,7 +306,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
             for node_id, endpoint in self.routing_table.get_nearest_neighbors(
                     key_id, k=self.bucket_size, exclude=DHTID.from_bytes(request.peer.node_id)):
                 item.nearest_node_ids.append(node_id.to_bytes())
-                item.nearest_endpoints.append(endpoint)
+                item.nearest_endpoints.append(endpoint.to_base58())
             response.results.append(item)
         return response
 

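For context, here is a minimal sketch of how the converted DHTProtocol can be driven over libp2p PeerIDs instead of IP:port endpoints. It is not part of this PR; the constructor arguments mirror the values used in tests/test_dht_node.py below, and `peer_maddrs` is assumed to be the visible multiaddrs of an already running peer.

```python
from hivemind.dht.node import DHTID
from hivemind.dht.protocol import DHTProtocol
from hivemind.p2p import P2P, PeerID


async def ping_existing_peer(peer_maddrs):
    # Connect a fresh libp2p daemon to the existing swarm
    p2p = await P2P.create(initial_peers=peer_maddrs)

    # DHTProtocol now receives a P2P instance instead of a listen_on endpoint
    protocol = await DHTProtocol.create(
        p2p, DHTID.generate(), bucket_size=20, depth_modulo=5, num_replicas=3, wait_timeout=5)

    # Peers are addressed by the base58 PeerID encoded in their multiaddrs
    peer_id = PeerID.from_base58(peer_maddrs[0]['p2p'])
    return await protocol.call_ping(peer_id)  # the remote DHTID, or None if the peer did not respond
```
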
+ 2 - 2
hivemind/dht/routing.py

@@ -8,8 +8,8 @@ import random
 from collections.abc import Iterable
 from itertools import chain
 from typing import Tuple, Optional, List, Dict, Set, Union, Any, Sequence
-
-from hivemind.utils import Endpoint, MSGPackSerializer, get_dht_time
+from hivemind.p2p import PeerID as Endpoint
+from hivemind.utils import MSGPackSerializer, get_dht_time
 
 DHTKey, Subkey, DHTValue, BinaryDHTID, BinaryDHTValue, = Any, Any, Any, bytes, bytes
 

+ 1 - 2
hivemind/hivemind_cli/run_server.py

@@ -48,8 +48,7 @@ def main():
 
     parser.add_argument('--no_dht', action='store_true', help='if specified, the server will not be attached to a dht')
     parser.add_argument('--initial_peers', type=str, nargs='*', required=False, default=[],
-                        help='one or more peers that can welcome you to the dht, e.g. 1.2.3.4:1337 192.132.231.4:4321')
-    parser.add_argument('--dht_port', type=int, default=None, required=False, help='DHT node will listen on this port')
+                        help='multiaddrs of one or more active DHT peers (if you want to join an existing DHT)')
     parser.add_argument('--increase_file_limit', action='store_true',
                         help='On *nix, this will increase the max number of processes '
                              'a server can spawn before hitting "Too many open files"; Use at your own risk.')

+ 11 - 17
hivemind/moe/server/__init__.py

@@ -5,10 +5,11 @@ import multiprocessing.synchronize
 import threading
 from contextlib import contextmanager
 from functools import partial
-from typing import Dict, Optional, Tuple
+from typing import Dict, List, Optional, Tuple
 from pathlib import Path
 
 import torch
+from multiaddr import Multiaddr
 
 import hivemind
 from hivemind.dht import DHT
@@ -75,7 +76,7 @@ class Server(threading.Thread):
     def create(cls, listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str = None, expert_pattern: str = None,
                expert_cls='ffn', hidden_dim=1024, optim_cls=torch.optim.Adam, scheduler: str = 'none',
                num_warmup_steps=None, num_total_steps=None, clip_grad_norm=None, num_handlers=None, min_batch_size=1,
-               max_batch_size=4096, device=None, no_dht=False, initial_peers=(), dht_port=None,
+               max_batch_size=4096, device=None, no_dht=False, initial_peers=(),
                checkpoint_dir: Optional[Path] = None, compression=CompressionType.NONE,
                stats_report_interval: Optional[int] = None, custom_module_path=None, *, start: bool) -> Server:
         """
@@ -99,11 +100,7 @@ class Server(threading.Thread):
         :param clip_grad_norm: maximum gradient norm used for clipping
 
         :param no_dht: if specified, the server will not be attached to a dht
-        :param initial_peers: a list of peers that will introduce this node to the dht,\
-           e.g. ('123.11.22.33:1337', '[fe80::abe2:db1c:be7d:5a85]:4567'), default = no peers
-
-        :param dht_port:  DHT node will listen on this port, default = find open port
-           You can then use this node as initial peer for subsequent servers.
+        :param initial_peers: multiaddrs of one or more active DHT peers (if you want to join an existing DHT)
 
         :param checkpoint_dir: directory to save and load expert checkpoints
 
@@ -121,9 +118,8 @@ class Server(threading.Thread):
         if no_dht:
             dht = None
         else:
-            dht_endpoint = replace_port(listen_on, dht_port or hivemind.find_open_port())
-            dht = hivemind.DHT(initial_peers=initial_peers, start=True, listen_on=dht_endpoint)
-            logger.info(f"Running DHT node on port {dht.port}, initial peers = {initial_peers}")
+            dht = hivemind.DHT(initial_peers=initial_peers, start=True)
+            logger.info(f"Running DHT node on {dht.get_visible_maddrs()}, initial peers = {initial_peers}")
 
         assert ((expert_pattern is None and num_experts is None and expert_uids is not None) or
                 (num_experts is not None and expert_uids is None)), \
@@ -266,13 +262,14 @@ class Server(threading.Thread):
 
 
 @contextmanager
-def background_server(*args, shutdown_timeout=5, **kwargs) -> Tuple[hivemind.Endpoint, hivemind.Endpoint]:
+def background_server(*args, shutdown_timeout=5, **kwargs) -> Tuple[hivemind.Endpoint, List[Multiaddr]]:
     """ A context manager that creates server in a background thread, awaits .ready on entry and shutdowns on exit """
     pipe, runners_pipe = mp.Pipe(duplex=True)
     runner = mp.Process(target=_server_runner, args=(runners_pipe, *args), kwargs=kwargs)
     try:
         runner.start()
-        # once the server is ready, runner will send us either (False, exception) or (True, (server_port, dht_port))
+        # once the server is ready, runner will send us
+        # either (False, exception) or (True, (server.listen_on, dht_maddrs))
         start_ok, data = pipe.recv()
         if start_ok:
             yield data
@@ -296,11 +293,8 @@ def _server_runner(pipe, *args, **kwargs):
         return
 
     try:
-        if server.dht is not None:
-            dht_listen_on = hivemind.replace_port(server.dht.listen_on, server.dht.port)
-        else:
-            dht_listen_on = None
-        pipe.send((True, (server.listen_on, dht_listen_on)))
+        dht_maddrs = server.dht.get_visible_maddrs() if server.dht is not None else None
+        pipe.send((True, (server.listen_on, dht_maddrs)))
         pipe.recv()  # wait for shutdown signal
 
     finally:

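For illustration, a hedged sketch of the updated `background_server` contract: the context manager now yields DHT multiaddrs instead of a DHT IP:port. The expert configuration kwargs below are placeholders for this sketch, not values prescribed by the PR.

```python
import hivemind
from hivemind.moe.server import background_server

with background_server(expert_uids=['expert.0'], expert_cls='ffn', hidden_dim=64,
                       num_handlers=1) as (server_endpoint, dht_maddrs):
    # dht_maddrs is a List[Multiaddr]; any other peer can join the same DHT with it
    client_dht = hivemind.DHT(initial_peers=dht_maddrs, start=True)
    # ... use server_endpoint / client_dht here ...
    client_dht.shutdown()
```
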
+ 1 - 0
hivemind/p2p/__init__.py

@@ -1,2 +1,3 @@
 from hivemind.p2p.p2p_daemon import P2P, P2PContext, P2PHandlerError
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo
+from hivemind.p2p.servicer import Servicer

+ 78 - 58
hivemind/p2p/p2p_daemon.py

@@ -1,9 +1,11 @@
 import asyncio
+import os
 import secrets
+from contextlib import suppress
 from dataclasses import dataclass
 from importlib.resources import path
 from subprocess import Popen
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 
 import google.protobuf
 from multiaddr import Multiaddr
@@ -32,9 +34,10 @@ class P2PContext(object):
 class P2P:
     """
     This class is responsible for establishing peer-to-peer connections through NAT and/or firewalls.
-    It creates and manages a libp2p daemon in a background process, then terminates it when P2P is shut down.
-    In order to communicate, a P2P instance should either use one or more bootstrap_peers that will connect it
-    to the rest of the swarm or use the public IPFS network (https://ipfs.io).
+    It creates and manages a libp2p daemon (https://libp2p.io) in a background process,
+    then terminates it when P2P is shut down. In order to communicate, a P2P instance should
+    either use one or more initial_peers that will connect it to the rest of the swarm or
+    use the public IPFS network (https://ipfs.io).
 
     For incoming connections, P2P instances add RPC handlers that may be accessed by other peers:
       - `P2P.add_unary_handler` accepts a protobuf message and returns another protobuf
@@ -58,6 +61,7 @@ class P2P:
         'public': {'forceReachabilityPublic': 1},
         'private': {'forceReachabilityPrivate': 1},
     }
+    _UNIX_SOCKET_PREFIX = '/unix/tmp/hivemind-'
 
     def __init__(self):
         self.id = None
@@ -67,17 +71,25 @@ class P2P:
         self._server_stopped = asyncio.Event()
 
     @classmethod
-    async def create(cls, *args, quic: bool = False, tls: bool = True, conn_manager: bool = True,
+    async def create(cls,
+                     initial_peers: Optional[Sequence[Union[Multiaddr, str]]] = None,
+                     use_ipfs: bool = False,
+                     host_maddrs: Optional[Sequence[Union[Multiaddr, str]]] = ('/ip4/127.0.0.1/tcp/0',),
+                     announce_maddrs: Optional[Sequence[Union[Multiaddr, str]]] = None,
+                     quic: bool = True, tls: bool = True, conn_manager: bool = True,
                      dht_mode: str = 'dht_server', force_reachability: Optional[str] = None,
                      nat_port_map: bool = True, auto_nat: bool = True,
-                     bootstrap_peers: Optional[List[Multiaddr]] = None, use_ipfs: bool = False,
-                     host_maddrs: Optional[List[Multiaddr]] = None,
                      use_relay: bool = True, use_relay_hop: bool = False,
                      use_relay_discovery: bool = False, use_auto_relay: bool = False, relay_hop_limit: int = 0,
                      quiet: bool = True,
-                     ping_n_retries: int = 3, ping_retry_delay: float = 0.4, **kwargs) -> 'P2P':
+                     ping_n_attempts: int = 5, ping_delay: float = 0.4) -> 'P2P':
         """
         Start a new p2pd process and connect to it.
+        :param initial_peers: List of bootstrap peers
+        :param use_ipfs: Bootstrap to IPFS (incompatible with initial_peers)
+        :param host_maddrs: Multiaddrs to listen for external connections from other p2p instances
+        :param announce_maddrs: Visible multiaddrs that the peer will announce
+          for external connections from other p2p instances
         :param quic: Enables the QUIC transport
         :param tls: Enables TLS1.3 channel security protocol
         :param conn_manager: Enables the Connection Manager
@@ -85,59 +97,58 @@ class P2P:
         :param force_reachability: Force reachability mode (public/private)
         :param nat_port_map: Enables NAT port mapping
         :param auto_nat: Enables the AutoNAT service
-        :param bootstrap: Connects to bootstrap peers and bootstraps the dht if enabled
-        :param bootstrap_peers: List of bootstrap peers
-        :param use_ipfs: Bootstrap to IPFS (works only if bootstrap=True and bootstrap_peers=None)
-        :param host_maddrs: multiaddresses for external connections from other p2p instances
         :param use_relay: enables circuit relay
         :param use_relay_hop: enables hop for relay
         :param use_relay_discovery: enables passive discovery for relay
         :param use_auto_relay: enables autorelay
         :param relay_hop_limit: sets the hop limit for hop relays
         :param quiet: make the daemon process quiet
-        :param args: positional CLI arguments for the p2p daemon
-        :param kwargs: keyword CLI arguments for the p2p daemon
+        :param ping_n_attempts: try to ping the daemon with this number of attempts after starting it
+        :param ping_delay: wait for ``ping_delay * (2 ** (k - 1))`` seconds before the k-th attempt to ping the daemon
+          (in particular, wait for ``ping_delay`` seconds before the first attempt)
         :return: a wrapper for the p2p daemon
         """
 
-        assert not (bootstrap_peers and use_ipfs), \
-            'User-defined bootstrap_peers and use_ipfs=True are incompatible, please choose one option'
+        assert not (initial_peers and use_ipfs), \
+            'User-defined initial_peers and use_ipfs=True are incompatible, please choose one option'
 
         self = cls()
         with path(cli, P2PD_FILENAME) as p:
             p2pd_path = p
 
         socket_uid = secrets.token_urlsafe(8)
-        self._daemon_listen_maddr = Multiaddr(f'/unix/tmp/hivemind-p2pd-{socket_uid}.sock')
-        self._client_listen_maddr = Multiaddr(f'/unix/tmp/hivemind-p2pclient-{socket_uid}.sock')
-
-        need_bootstrap = bool(bootstrap_peers) or use_ipfs
-        bootstrap_peers = cls._make_bootstrap_peers(bootstrap_peers)
-        dht = cls.DHT_MODE_MAPPING.get(dht_mode, {'dht': 0})
-        force_reachability = cls.FORCE_REACHABILITY_MAPPING.get(force_reachability, {})
-        host_maddrs = {'hostAddrs': ','.join(str(maddr) for maddr in host_maddrs)} if host_maddrs else {}
+        self._daemon_listen_maddr = Multiaddr(cls._UNIX_SOCKET_PREFIX + f'p2pd-{socket_uid}.sock')
+        self._client_listen_maddr = Multiaddr(cls._UNIX_SOCKET_PREFIX + f'p2pclient-{socket_uid}.sock')
+
+        need_bootstrap = bool(initial_peers) or use_ipfs
+        process_kwargs = cls.DHT_MODE_MAPPING.get(dht_mode, {'dht': 0})
+        process_kwargs.update(cls.FORCE_REACHABILITY_MAPPING.get(force_reachability, {}))
+        for param, value in [('bootstrapPeers', initial_peers),
+                             ('hostAddrs', host_maddrs),
+                             ('announceAddrs', announce_maddrs)]:
+            if value:
+                process_kwargs[param] = self._maddrs_to_str(value)
+
         proc_args = self._make_process_args(
-            str(p2pd_path), *args,
+            str(p2pd_path),
             listen=self._daemon_listen_maddr,
             quic=quic, tls=tls, connManager=conn_manager,
             natPortMap=nat_port_map, autonat=auto_nat,
             relay=use_relay, relayHop=use_relay_hop, relayDiscovery=use_relay_discovery,
             autoRelay=use_auto_relay, relayHopLimit=relay_hop_limit,
-            b=need_bootstrap, q=quiet, **{**bootstrap_peers, **dht, **force_reachability, **host_maddrs, **kwargs})
-
-        self._initialize(proc_args)
-        await self._ping_daemon_with_retries(ping_n_retries, ping_retry_delay)
-
-        return self
+            b=need_bootstrap, q=quiet, **process_kwargs)
 
-    def _initialize(self, proc_args: List[str]) -> None:
         self._child = Popen(args=proc_args, encoding="utf8")
         self._alive = True
         self._client = p2pclient.Client(self._daemon_listen_maddr, self._client_listen_maddr)
 
-    async def _ping_daemon_with_retries(self, ping_n_retries: int, ping_retry_delay: float) -> None:
-        for try_number in range(ping_n_retries):
-            await asyncio.sleep(ping_retry_delay * (2 ** try_number))
+        await self._ping_daemon_with_retries(ping_n_attempts, ping_delay)
+
+        return self
+
+    async def _ping_daemon_with_retries(self, ping_n_attempts: int, ping_delay: float) -> None:
+        for try_number in range(ping_n_attempts):
+            await asyncio.sleep(ping_delay * (2 ** try_number))
 
             if self._child.poll() is not None:  # Process died
                 break
@@ -146,8 +157,8 @@ class P2P:
                 await self._ping_daemon()
                 break
             except Exception as e:
-                if try_number == ping_n_retries - 1:
-                    logger.error(f'Failed to ping p2pd: {e}')
+                if try_number == ping_n_attempts - 1:
+                    logger.exception('Failed to ping p2pd that has just started')
                     await self.shutdown()
                     raise
 
@@ -170,7 +181,7 @@ class P2P:
 
         socket_uid = secrets.token_urlsafe(8)
         self._daemon_listen_maddr = daemon_listen_maddr
-        self._client_listen_maddr = Multiaddr(f'/unix/tmp/hivemind-p2pclient-{socket_uid}.sock')
+        self._client_listen_maddr = Multiaddr(cls._UNIX_SOCKET_PREFIX + f'p2pclient-{socket_uid}.sock')
 
         self._client = p2pclient.Client(self._daemon_listen_maddr, self._client_listen_maddr)
 
@@ -178,16 +189,24 @@ class P2P:
         return self
 
     async def _ping_daemon(self) -> None:
-        self.id, maddrs = await self._client.identify()
-        logger.debug(f'Launched p2pd with id = {self.id}, host multiaddrs = {maddrs}')
+        self.id, self._visible_maddrs = await self._client.identify()
+        logger.debug(f'Launched p2pd with id = {self.id}, host multiaddrs = {self._visible_maddrs}')
 
-    async def identify_maddrs(self) -> List[Multiaddr]:
-        _, maddrs = await self._client.identify()
-        if not maddrs:
+    async def get_visible_maddrs(self, latest: bool = False) -> List[Multiaddr]:
+        """
+        Get multiaddrs of the current peer that should be accessible by other peers.
+
+        :param latest: ask the P2P daemon to refresh the visible multiaddrs
+        """
+
+        if latest:
+            _, self._visible_maddrs = await self._client.identify()
+
+        if not self._visible_maddrs:
             raise ValueError(f"No multiaddrs found for peer {self.id}")
 
         p2p_maddr = Multiaddr(f'/p2p/{self.id.to_base58()}')
-        return [addr.encapsulate(p2p_maddr) for addr in maddrs]
+        return [addr.encapsulate(p2p_maddr) for addr in self._visible_maddrs]
 
     async def list_peers(self) -> List[PeerInfo]:
         return list(await self._client.list_peers())
@@ -282,13 +301,14 @@ class P2P:
                 try:
                     request, err = await P2P.receive_protobuf(in_proto_type, reader)
                 except asyncio.IncompleteReadError:
-                    logger.debug('Incomplete read while receiving request from peer')
+                    logger.debug(f'Incomplete read while receiving request from peer in {handle_name}')
                     return
                 except google.protobuf.message.DecodeError as error:
-                    logger.debug(f'Failed to decode request protobuf: {error}')
+                    logger.debug(f'Failed to decode request protobuf '
+                                 f'of type {in_proto_type} in {handle_name}: {error}')
                     return
                 if err is not None:
-                    logger.debug(f'Got an error instead of a request: {err}')
+                    logger.debug(f'Got an error instead of a request in {handle_name}: {err}')
 
                 context = P2PContext(handle_name=handle_name, local_id=self.id,
                                      remote_id=stream_info.peer_id, remote_maddr=stream_info.addr)
@@ -303,12 +323,10 @@ class P2P:
                     error = p2pd_pb2.RPCError(message=str(exc))
                     await P2P.send_protobuf(error, p2pd_pb2.RPCError, writer)
                 finally:
-                    pending_task = pending.pop()
-                    pending_task.cancel()
-                    try:
-                        await pending_task
-                    except asyncio.CancelledError:
-                        pass
+                    if pending:
+                        for task in pending:
+                            task.cancel()
+                        await asyncio.wait(pending)
             finally:
                 writer.close()
 
@@ -382,6 +400,11 @@ class P2P:
             self._child.wait()
             logger.debug(f'Terminated p2pd with id = {self.id}')
 
+            with suppress(FileNotFoundError):
+                os.remove(self._daemon_listen_maddr['unix'])
+        with suppress(FileNotFoundError):
+            os.remove(self._client_listen_maddr['unix'])
+
     @staticmethod
     def _make_process_args(*args, **kwargs) -> List[str]:
         proc_args = []
@@ -401,11 +424,8 @@ class P2P:
         return val
 
     @staticmethod
-    def _make_bootstrap_peers(maddrs: Optional[List[Multiaddr]]) -> Dict[str, str]:
-        if maddrs is None:
-            return {}
-
-        return {'bootstrapPeers': ','.join(str(addr) for addr in maddrs)}
+    def _maddrs_to_str(maddrs: List[Multiaddr]) -> str:
+        return ','.join(str(addr) for addr in maddrs)
 
 
 class P2PInterruptedError(Exception):

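A minimal sketch of the new bootstrap flow (not part of this PR): the first daemon listens on the default localhost TCP multiaddr, and a second daemon joins it through the visible multiaddrs instead of an IP:port list.

```python
import asyncio

from hivemind.p2p import P2P


async def main():
    p2p1 = await P2P.create()  # listens on /ip4/127.0.0.1/tcp/0 by default
    visible_maddrs = await p2p1.get_visible_maddrs()  # e.g. [/ip4/127.0.0.1/tcp/<port>/p2p/<PeerID>]

    p2p2 = await P2P.create(initial_peers=visible_maddrs)  # bootstrap off the first daemon
    print(await p2p2.list_peers())

    await p2p2.shutdown()
    await p2p1.shutdown()

asyncio.run(main())
```
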
+ 0 - 5
hivemind/p2p/p2p_daemon_bindings/control.py

@@ -54,7 +54,6 @@ class DaemonConnector:
     async def open_connection(self) -> (asyncio.StreamReader, asyncio.StreamWriter):
         if self.proto_code == protocols.P_UNIX:
             control_path = self.control_maddr.value_for_protocol(protocols.P_UNIX)
-            logger.debug(f"DaemonConnector {self} opens connection to {self.control_maddr}")
             return await asyncio.open_unix_connection(control_path)
         elif self.proto_code == protocols.P_IP4:
             host = self.control_maddr.value_for_protocol(protocols.P_IP4)
@@ -80,7 +79,6 @@ class ControlClient:
         pb_stream_info = p2pd_pb.StreamInfo()  # type: ignore
         await read_pbmsg_safe(reader, pb_stream_info)
         stream_info = StreamInfo.from_protobuf(pb_stream_info)
-        logger.debug(f"New incoming stream: {stream_info}")
         try:
             handler = self.handlers[stream_info.proto]
         except KeyError as e:
@@ -105,11 +103,8 @@ class ControlClient:
             )
 
         async with server:
-            logger.debug(f"DaemonConnector {self} starts listening to {self.listen_maddr}")
             yield self
 
-        logger.debug(f"DaemonConnector {self} closed")
-
     async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
         reader, writer = await self.daemon_connector.open_connection()
         req = p2pd_pb.Request(type=p2pd_pb.Request.IDENTIFY)

+ 98 - 0
hivemind/p2p/servicer.py

@@ -0,0 +1,98 @@
+import asyncio
+import importlib
+from dataclasses import dataclass
+from functools import partial
+from typing import Any, Callable, Optional, Union
+
+from hivemind.p2p.p2p_daemon import P2P, P2PContext
+from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID
+
+
+@dataclass
+class RPCHandler:
+    method_name: str
+    handle_name: str
+    request_type: type
+    response_type: type
+
+
+class StubBase:
+    """
+    Base class for P2P RPC stubs. The interface mimics gRPC stubs.
+
+    Servicer derives stub classes for particular services (e.g. DHT, averager, etc.) from StubBase,
+    adding the necessary rpc_* methods. Calls to these methods are translated to calls to the remote peer.
+    """
+
+    def __init__(self, p2p: P2P, peer: PeerID):
+        self._p2p = p2p
+        self._peer = peer
+
+
+class Servicer:
+    """
+    Base class for P2P RPC servicers (e.g. DHT, averager, MoE server). The interface mimics gRPC servicers.
+
+    - ``add_p2p_handlers(self, p2p)`` registers all rpc_* methods of the derived class as P2P unary handlers, allowing
+      other peers to call them. It uses the type annotations of the ``request`` parameter and the return value
+      to infer the protobuf types that the methods operate with.
+
+    - ``get_stub(self, p2p, peer)`` creates a stub with all rpc_* methods. Calls to the stub methods are translated
+      to calls to the remote peer.
+    """
+
+    def __init__(self):
+        class_name = self.__class__.__name__
+
+        self._rpc_handlers = []
+        for method_name, method in self.__class__.__dict__.items():
+            if method_name.startswith('rpc_') and callable(method):
+                handle_name = f'{class_name}.{method_name}'
+
+                hints = method.__annotations__
+                try:
+                    request_type = self._hint_to_type(hints['request'])
+                    response_type = self._hint_to_type(hints['return'])
+                except (KeyError, ValueError):
+                    raise ValueError(f'{handle_name} is expected to have type annotations like `dht_pb2.FindRequest` '
+                                     f'(a type from the hivemind.proto module) for the `request` parameter '
+                                     f'and the return value')
+
+                self._rpc_handlers.append(RPCHandler(method_name, handle_name, request_type, response_type))
+
+        self._stub_type = type(f'{class_name}Stub', (StubBase,),
+                               {handler.method_name: self._make_rpc_caller(handler)
+                                for handler in self._rpc_handlers})
+
+    @staticmethod
+    def _make_rpc_caller(handler: RPCHandler):
+        # This method will be added to a new Stub type (a subclass of StubBase)
+        async def caller(self: StubBase, request: handler.request_type,
+                         timeout: Optional[float] = None) -> handler.response_type:
+            return await asyncio.wait_for(
+                self._p2p.call_unary_handler(self._peer, handler.handle_name, request, handler.response_type),
+                timeout=timeout)
+
+        caller.__name__ = handler.method_name
+        return caller
+
+    async def add_p2p_handlers(self, p2p: P2P, wrapper: Any = None) -> None:
+        servicer = self if wrapper is None else wrapper
+        for handler in self._rpc_handlers:
+            await p2p.add_unary_handler(handler.handle_name, getattr(servicer, handler.method_name),
+                                        handler.request_type, handler.response_type)
+
+    def get_stub(self, p2p: P2P, peer: PeerID) -> StubBase:
+        return self._stub_type(p2p, peer)
+
+    @staticmethod
+    def _hint_to_type(hint: Union[type, str]) -> type:
+        if isinstance(hint, type):
+            return hint
+
+        module_name, proto_name = hint.split('.')
+        module = importlib.import_module('hivemind.proto.' + module_name)
+        result = getattr(module, proto_name)
+        if not isinstance(result, type):
+            raise ValueError(f'`hivemind.proto.{hint}` is not a type')
+        return result

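To make the new Servicer/StubBase contract concrete, here is a hedged sketch (not part of this PR) of a toy servicer whose single rpc_* handler is registered on one peer and called from another. `PingServicer` and the two pre-created P2P instances are assumptions for illustration.

```python
from hivemind.p2p import P2P, P2PContext, Servicer
from hivemind.proto import dht_pb2


class PingServicer(Servicer):
    # The annotations of `request` and the return value tell Servicer which protobufs this handler uses
    async def rpc_ping(self, request: dht_pb2.PingRequest, context: P2PContext) -> dht_pb2.PingResponse:
        return dht_pb2.PingResponse(available=True)


async def example(server_p2p: P2P, client_p2p: P2P) -> None:
    servicer = PingServicer()
    await servicer.add_p2p_handlers(server_p2p)           # registers 'PingServicer.rpc_ping' on the server peer
    stub = servicer.get_stub(client_p2p, server_p2p.id)   # stub bound to the server's PeerID
    response = await stub.rpc_ping(dht_pb2.PingRequest(), timeout=5.0)
    assert response.available
```
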
+ 2 - 5
hivemind/proto/dht.proto

@@ -19,8 +19,6 @@ message NodeInfo {
   // note: both node_id and port are optional: if specified, ask peer to add you to its routing table;
   // if either node_id or port is absent, simply request recipient info (for client-only mode)
   bytes node_id = 1;                   // sender's own node id serialized with DHTID.to_bytes()
-  int32 rpc_port = 2;                  // port to which sender listens for DHT RPCs
-  string endpoint = 3;                 // (optional) node's preferred return address
 }
 
 message PingRequest {
@@ -32,7 +30,6 @@ message PingRequest {
 message PingResponse {
   ResponseAuthInfo auth = 1;
   NodeInfo peer = 2;                   // respondent's node id, for you to update routing table
-  string sender_endpoint = 3;          // echo sender's visible endpoint - used to infer his ip address
   double dht_time = 4;                 // recipient's local DHT time - used to soft-synchronize peers
   bool available = 5;                  // if validate = True, this flag asserts that the sender is available for ping
 }
@@ -68,8 +65,8 @@ message FindResult {
   double expiration_time = 3;          // n/a  | expiration time  | DictionaryDHTValue.latest_expiration_time
 
   // two aligned arrays: DHTIDs and Endpoints for nearest peers (sorted by XOR distance)
-  repeated bytes nearest_node_ids = 4;      // DHTIDs serialized with node_id.to_bytes()
-  repeated string nearest_endpoints = 5;    // e.g. 123.123.123.123:1337 or [2a21:6с8:b192:2105]:8888
+  repeated bytes nearest_node_ids = 4;      // DHTIDs of the nearest peers serialized with node_id.to_bytes()
+  repeated string nearest_endpoints = 5;    // Base58-serialized libp2p PeerIDs of the nearest peers
 }
 
 message FindResponse {

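For reference, a hedged sketch (not part of this PR) of how the two aligned arrays in FindResult are packed and unpacked on the Python side, mirroring rpc_find and call_find in hivemind/dht/protocol.py above; `neighbors` is an assumed mapping from DHTID to libp2p PeerID.

```python
from typing import Dict

from hivemind.dht.routing import DHTID
from hivemind.p2p import PeerID
from hivemind.proto import dht_pb2


def pack_nearest(item: dht_pb2.FindResult, neighbors: Dict[DHTID, PeerID]) -> None:
    # nearest_endpoints now hold base58-serialized PeerIDs instead of 'ip:port' strings
    for node_id, peer_id in neighbors.items():
        item.nearest_node_ids.append(node_id.to_bytes())
        item.nearest_endpoints.append(peer_id.to_base58())


def unpack_nearest(item: dht_pb2.FindResult) -> Dict[DHTID, PeerID]:
    return dict(zip(map(DHTID.from_bytes, item.nearest_node_ids),
                    map(PeerID.from_base58, item.nearest_endpoints)))
```
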
+ 35 - 1
hivemind/utils/networking.py

@@ -1,6 +1,10 @@
 import socket
 from contextlib import closing
-from typing import Optional
+from ipaddress import ip_address
+from typing import Optional, Sequence
+
+from multiaddr import Multiaddr
+
 
 Hostname, Port = str, int  # flavour types
 Endpoint = str  # e.g. 1.2.3.4:1337 or [2a21:6с8:b192:2105]:8888, https://networkengineering.stackexchange.com/a/9435
@@ -36,3 +40,33 @@ def find_open_port(params=(socket.AF_INET, socket.SOCK_STREAM), opt=(socket.SOL_
             return sock.getsockname()[1]
     except Exception as e:
         raise e
+
+
+def choose_ip_address(maddrs: Sequence[Multiaddr],
+                      prefer_global: bool = True,
+                      protocol_priority: Sequence[str] = ('ip4', 'ip6')) -> Hostname:
+    """
+    Currently, some components of hivemind are not converted to work over libp2p and use classical networking.
+    To allow other peers to reach a server when needed, these components announce a machine's IP address.
+
+    This function automatically selects the best IP address to announce among publicly visible multiaddrs
+    of this machine identified by libp2p (typically, using the ``P2P.get_visible_maddrs()`` method),
+    so a user does not need to define this address manually (unless the user wants to).
+
+    The best IP address is chosen using the following logic:
+      - Prefer IP addresses from global address blocks
+        (in terms of https://docs.python.org/3/library/ipaddress.html#ipaddress.IPv4Address.is_global)
+      - Among the IP addresses of the same globality status, prefer IPv4 addresses over IPv6
+
+    If the default logic does not suit you, it is recommended to set the announced IP address manually.
+    """
+
+    for need_global in [prefer_global, not prefer_global]:
+        for protocol in protocol_priority:
+            for addr in maddrs:
+                if protocol in addr.protocols():
+                    value_for_protocol = addr[protocol]
+                    if ip_address(value_for_protocol).is_global == need_global:
+                        return value_for_protocol
+
+    raise ValueError(f'No IP address found among given multiaddrs: {maddrs}')

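A quick hedged example of the selection logic in choose_ip_address (not part of this PR): given a loopback IPv4, a global IPv4, and a loopback IPv6 multiaddr, the global IPv4 address wins.

```python
from multiaddr import Multiaddr

from hivemind.utils.networking import choose_ip_address

maddrs = [Multiaddr('/ip4/127.0.0.1/tcp/31337'),
          Multiaddr('/ip4/8.8.8.8/tcp/31337'),
          Multiaddr('/ip6/::1/tcp/31337')]

# Global addresses are preferred over private/loopback ones, and IPv4 over IPv6
assert choose_ip_address(maddrs) == '8.8.8.8'
```
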
+ 22 - 3
tests/conftest.py

@@ -1,9 +1,28 @@
-import pytest
+import gc
+from contextlib import suppress
+
 import psutil
+import pytest
+
+from hivemind.utils import get_logger
+
+
+logger = get_logger(__name__)
 
 
 @pytest.fixture(autouse=True, scope='session')
 def cleanup_children():
     yield
-    for child in psutil.Process().children(recursive=True):
-        child.terminate()
+
+    gc.collect()  # Call .__del__() for removed objects
+
+    children = psutil.Process().children(recursive=True)
+    if children:
+        logger.info(f'Cleaning up {len(children)} leftover child processes')
+        for child in children:
+            with suppress(psutil.NoSuchProcess):
+                child.terminate()
+        psutil.wait_procs(children, timeout=1)
+        for child in children:
+            with suppress(psutil.NoSuchProcess):
+                child.kill()

+ 12 - 12
tests/test_auth.py

@@ -73,12 +73,12 @@ async def test_valid_request_and_response():
     service_authorizer = MockAuthorizer(RSAPrivateKey())
 
     request = dht_pb2.PingRequest()
-    request.peer.endpoint = '127.0.0.1:7777'
+    request.peer.node_id = b'ping'
     await client_authorizer.sign_request(request, service_authorizer.local_public_key)
     assert await service_authorizer.validate_request(request)
 
     response = dht_pb2.PingResponse()
-    response.sender_endpoint = '127.0.0.1:31337'
+    response.peer.node_id = b'pong'
     await service_authorizer.sign_response(response, request)
     assert await client_authorizer.validate_response(response, request)
 
@@ -89,7 +89,7 @@ async def test_invalid_access_token():
     service_authorizer = MockAuthorizer(RSAPrivateKey())
 
     request = dht_pb2.PingRequest()
-    request.peer.endpoint = '127.0.0.1:7777'
+    request.peer.node_id = b'ping'
     await client_authorizer.sign_request(request, service_authorizer.local_public_key)
 
     # Break the access token signature
@@ -98,7 +98,7 @@ async def test_invalid_access_token():
     assert not await service_authorizer.validate_request(request)
 
     response = dht_pb2.PingResponse()
-    response.sender_endpoint = '127.0.0.1:31337'
+    response.peer.node_id = b'pong'
     await service_authorizer.sign_response(response, request)
 
     # Break the access token signature
@@ -113,20 +113,20 @@ async def test_invalid_signatures():
     service_authorizer = MockAuthorizer(RSAPrivateKey())
 
     request = dht_pb2.PingRequest()
-    request.peer.endpoint = '127.0.0.1:7777'
+    request.peer.node_id = b'true-ping'
     await client_authorizer.sign_request(request, service_authorizer.local_public_key)
 
     # A man-in-the-middle attacker changes the request content
-    request.peer.endpoint = '127.0.0.2:7777'
+    request.peer.node_id = b'fake-ping'
 
     assert not await service_authorizer.validate_request(request)
 
     response = dht_pb2.PingResponse()
-    response.sender_endpoint = '127.0.0.1:31337'
+    response.peer.node_id = b'true-pong'
     await service_authorizer.sign_response(response, request)
 
     # A man-in-the-middle attacker changes the response content
-    response.sender_endpoint = '127.0.0.2:31337'
+    response.peer.node_id = b'fake-pong'
 
     assert not await client_authorizer.validate_response(response, request)
 
@@ -135,11 +135,11 @@ async def test_invalid_signatures():
 async def test_auth_rpc_wrapper():
     class Servicer:
         async def rpc_increment(self, request: dht_pb2.PingRequest) -> dht_pb2.PingResponse:
-            assert request.peer.endpoint == '127.0.0.1:1111'
+            assert request.peer.node_id == b'ping'
             assert request.auth.client_access_token.username == 'alice'
 
             response = dht_pb2.PingResponse()
-            response.sender_endpoint = '127.0.0.1:2222'
+            response.peer.node_id = b'pong'
             return response
 
     class Client:
@@ -153,9 +153,9 @@ async def test_auth_rpc_wrapper():
     client = AuthRPCWrapper(Client(servicer), AuthRole.CLIENT, MockAuthorizer(RSAPrivateKey(), 'alice'))
 
     request = dht_pb2.PingRequest()
-    request.peer.endpoint = '127.0.0.1:1111'
+    request.peer.node_id = b'ping'
 
     response = await client.rpc_increment(request)
 
-    assert response.sender_endpoint == '127.0.0.1:2222'
+    assert response.peer.node_id == b'pong'
     assert response.auth.service_access_token.username == 'bob'

+ 19 - 18
tests/test_averaging.py

@@ -43,7 +43,7 @@ async def test_key_manager():
 
 
 def _test_allreduce_once(n_clients, n_aux):
-    dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')
+    dht = hivemind.DHT(start=True)
 
     n_peers = 4
     modes = [AveragingMode.CLIENT] * n_clients + [AveragingMode.AUX] * n_aux + [AveragingMode.NODE] * (
@@ -100,7 +100,7 @@ def test_allreduce_once_edge_cases(n_clients, n_aux):
 
 @pytest.mark.forked
 def test_allreduce_weighted(n_client_mode_peers: int = 2):
-    dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')
+    dht = hivemind.DHT(start=True)
 
     n_peers = 4
     should_listen = [False] * n_client_mode_peers + [True] * (n_peers - n_client_mode_peers)
@@ -137,7 +137,7 @@ def test_allreduce_weighted(n_client_mode_peers: int = 2):
 @pytest.mark.forked
 def test_allreduce_compression():
     """ this test ensures that compression works correctly when multiple tensors have different compression types """
-    dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')
+    dht = hivemind.DHT(start=True)
 
     tensors1 = [torch.linspace(0, 500, 1000) ** 0.5, torch.randn(1000)]
     tensors2 = [torch.linspace(300, 800, 1000) ** 0.5, torch.randn(1000)]
@@ -152,7 +152,8 @@ def test_allreduce_compression():
                                                              start=True)
         averager2 = hivemind.averaging.DecentralizedAverager([x.clone() for x in tensors2], dht=dht,
                                                              compression_type=compression_type_pair,
-                                                             target_group_size=2, prefix='mygroup', start=True)
+                                                             target_group_size=2, prefix='mygroup',
+                                                             listen_on='127.0.0.1:*', start=True)
 
         for future in averager1.step(wait=False), averager2.step(wait=False):
             future.result()
@@ -190,10 +191,10 @@ def compute_mean_std(averagers, unbiased=True):
 
 @pytest.mark.forked
 def test_allreduce_grid():
-    dht = hivemind.DHT(start=True, endpoint='127.0.0.1:*')
+    dht = hivemind.DHT(start=True)
     averagers = [hivemind.averaging.DecentralizedAverager(
         averaged_tensors=[torch.randn(3)], dht=dht, target_group_size=2,
-        prefix='mygroup', initial_group_bits=bin(i // 2)[2:].rjust(2, '0'), start=True)
+        prefix='mygroup', initial_group_bits=bin(i // 2)[2:].rjust(2, '0'), listen_on='127.0.0.1:*', start=True)
         for i in range(8)]
 
     [means0], [stds0] = compute_mean_std(averagers)
@@ -220,7 +221,7 @@ def test_allreduce_grid():
 
 @pytest.mark.forked
 def test_allgather():
-    dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')
+    dht = hivemind.DHT(start=True)
     averagers = [hivemind.averaging.DecentralizedAverager([torch.ones(1)], dht=dht, target_group_size=4,
                                                           averaging_expiration=15, prefix='mygroup',
                                                           initial_group_bits='000', listen_on='127.0.0.1:*', start=True)
@@ -290,11 +291,11 @@ def test_load_balancing():
 
 @pytest.mark.forked
 def test_too_few_peers():
-    dht = hivemind.DHT(start=True, endpoint='127.0.0.1:*')
+    dht = hivemind.DHT(start=True)
     averagers = [hivemind.averaging.DecentralizedAverager(
         averaged_tensors=[torch.randn(3)], dht=dht, target_group_size=2,
         averaging_expiration=1, request_timeout=0.5,
-        prefix='mygroup', initial_group_bits=bin(i)[2:].rjust(3, '0'), start=True)
+        prefix='mygroup', initial_group_bits=bin(i)[2:].rjust(3, '0'), listen_on='127.0.0.1:*', start=True)
         for i in range(4)]
     step_futures = [averager.step(wait=False) for averager in averagers]
     for future in step_futures:
@@ -307,11 +308,11 @@ def test_too_few_peers():
 
 @pytest.mark.forked
 def test_overcrowded(num_peers=16):
-    dht = hivemind.DHT(start=True, endpoint='127.0.0.1:*')
+    dht = hivemind.DHT(start=True)
     averagers = [hivemind.averaging.DecentralizedAverager(
         averaged_tensors=[torch.randn(3)], dht=dht, target_group_size=2,
         averaging_expiration=1, request_timeout=0.5,
-        prefix='mygroup', initial_group_bits='', start=True)
+        prefix='mygroup', initial_group_bits='', listen_on='127.0.0.1:*', start=True)
         for _ in range(num_peers)]
     for t in range(5):
         step_futures = [averager.step(wait=False, timeout=5) for averager in averagers]
@@ -339,17 +340,17 @@ def test_load_state_from_peers():
             return super_metadata, super_tensors
 
     dht_root = hivemind.DHT(start=True)
-    initial_peers = [f'{hivemind.LOCALHOST}:{dht_root.port}']
+    initial_peers = dht_root.get_visible_maddrs()
     dht1 = hivemind.DHT(initial_peers=initial_peers, start=True)
     averager1 = TestAverager([torch.randn(3), torch.rand(5)],
                              dht=dht1, start=True,
-                             prefix='demo-run', target_group_size=2)
+                             prefix='demo-run', target_group_size=2, listen_on='127.0.0.1:*')
 
     dht2 = hivemind.DHT(initial_peers=initial_peers, start=True)
     dht2.get('demo-run.all_averagers')
     averager2 = TestAverager([torch.randn(3), torch.rand(5)],
                              dht=dht2, start=True,
-                             prefix='demo-run', target_group_size=2)
+                             prefix='demo-run', target_group_size=2, listen_on='127.0.0.1:*')
 
     assert num_calls == 0
     got_metadata, got_tensors = averager2.load_state_from_peers()
@@ -377,9 +378,9 @@ def test_load_state_from_peers():
 
 @pytest.mark.forked
 def test_getset_bits():
-    dht = hivemind.DHT(start=True, endpoint='127.0.0.1:*')
-    averager = hivemind.averaging.DecentralizedAverager([torch.randn(3)], dht=dht, start=True,
-                                                        prefix='test_prefix', target_group_size=2)
+    dht = hivemind.DHT(start=True)
+    averager = hivemind.averaging.DecentralizedAverager([torch.randn(3)], dht=dht, start=True, prefix='test_prefix',
+                                                        target_group_size=2, listen_on='127.0.0.1:*')
     averager.set_group_bits('00101011101010')
     assert averager.get_group_bits() == '00101011101010'
 
@@ -388,7 +389,7 @@ def test_getset_bits():
 def test_training_averager(n_steps: int = 10, n_dims: int = 16):
     torch.manual_seed(42)
 
-    dht = hivemind.DHT(start=True, endpoint='127.0.0.1:*')
+    dht = hivemind.DHT(start=True)
     common_kwargs = {'dht': dht, 'start': True, 'listen_on': '127.0.0.1:*',
                      'prefix': 'demo-run', 'target_group_size': 2}
 

+ 22 - 17
tests/test_dht.py

@@ -3,18 +3,17 @@ import random
 import time
 
 import pytest
+from multiaddr import Multiaddr
 
 import hivemind
-from hivemind import LOCALHOST, strip_port
 
 
 
 @pytest.mark.forked
-def test_get_store():
-    peers = []
-    for i in range(10):
-        neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
-        peers.append(hivemind.DHT(initial_peers=neighbors_i, start=True))
+def test_get_store(n_peers=10):
+    peers = [hivemind.DHT(start=True)]
+    initial_peers = peers[0].get_visible_maddrs()
+    peers += [hivemind.DHT(initial_peers=initial_peers, start=True) for _ in range(n_peers - 1)]
 
     node1, node2 = random.sample(peers, 2)
     assert node1.store('key1', 'value1', expiration_time=hivemind.get_dht_time() + 30)
@@ -87,18 +86,24 @@ def test_run_coroutine():
     future.cancel()
     assert dht.run_coroutine(dummy_dht_coro_stateful) == -99
 
+    dht.shutdown()
+
 
 @pytest.mark.forked
-def test_dht_get_address(addr=LOCALHOST, dummy_endpoint='123.45.67.89:*'):
-    node1 = hivemind.DHT(start=True, listen_on=f"0.0.0.0:*")
-    node2 = hivemind.DHT(start=True, listen_on=f"0.0.0.0:*", initial_peers=[f"{addr}:{node1.port}"])
-    node3 = hivemind.DHT(start=True, listen_on=f"0.0.0.0:*", initial_peers=[f"{addr}:{node2.port}"])
-    assert addr in node3.get_visible_address(num_peers=2)
+@pytest.mark.asyncio
+async def test_dht_get_visible_maddrs():
+    # test 1: IPv4 localhost multiaddr is visible by default
 
-    node4 = hivemind.DHT(start=True, listen_on=f"0.0.0.0:*")
-    with pytest.raises(ValueError):
-        node4.get_visible_address()
-    assert node4.get_visible_address(peers=[f'{addr}:{node1.port}']).endswith(addr)
+    dht = hivemind.DHT(start=True)
+
+    assert any(str(maddr).startswith('/ip4/127.0.0.1') for maddr in dht.get_visible_maddrs())
+    dht.shutdown()
+
+    # test 2: announce_maddrs are the single visible multiaddrs if defined
+
+    dummy_endpoint = Multiaddr('/ip4/123.45.67.89/tcp/31337')
+    p2p = await hivemind.p2p.P2P.create(announce_maddrs=[dummy_endpoint])
+    dht = hivemind.DHT(p2p, start=True)
 
-    node5 = hivemind.DHT(start=True, listen_on=f"0.0.0.0:*", endpoint=f"{dummy_endpoint}")
-    assert node5.get_visible_address() == strip_port(dummy_endpoint)
+    assert dht.get_visible_maddrs() == [dummy_endpoint.encapsulate(f'/p2p/{p2p.id}')]
+    dht.shutdown()

+ 4 - 5
tests/test_dht_crypto.py

@@ -7,7 +7,7 @@ import pytest
 import hivemind
 from hivemind.utils.timed_storage import get_dht_time
 from hivemind.dht.crypto import RSASignatureValidator
-from hivemind.dht.node import LOCALHOST, DHTNode
+from hivemind.dht.node import DHTNode
 from hivemind.dht.validation import DHTRecord
 from hivemind.utils.crypto import RSAPrivateKey
 
@@ -106,12 +106,11 @@ def test_signing_in_different_process():
 @pytest.mark.asyncio
 async def test_dhtnode_signatures():
     alice = await DHTNode.create(record_validator=RSASignatureValidator())
+    initial_peers = await alice.get_visible_maddrs()
     bob = await DHTNode.create(
-        record_validator=RSASignatureValidator(RSAPrivateKey()),
-        initial_peers=[f"{LOCALHOST}:{alice.port}"])
+        record_validator=RSASignatureValidator(RSAPrivateKey()), initial_peers=initial_peers)
     mallory = await DHTNode.create(
-        record_validator=RSASignatureValidator(RSAPrivateKey()),
-        initial_peers=[f"{LOCALHOST}:{alice.port}"])
+        record_validator=RSASignatureValidator(RSAPrivateKey()), initial_peers=initial_peers)
 
     key = b'key'
     subkey = b'protected_subkey' + bob.protocol.record_validator.local_public_key

+ 18 - 19
tests/test_dht_experts.py

@@ -14,11 +14,10 @@ from hivemind.moe.server.expert_uid import UidEndpoint, is_valid_uid, is_valid_p
 
 
 @pytest.mark.forked
-def test_store_get_experts():
+def test_store_get_experts(n_peers=10):
     peers = [hivemind.DHT(start=True)]
-    for i in range(10):
-        neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
-        peers.append(hivemind.DHT(initial_peers=neighbors_i, start=True))
+    initial_peers = peers[0].get_visible_maddrs()
+    peers += [hivemind.DHT(initial_peers=initial_peers, start=True) for _ in range(n_peers - 1)]
 
     first_peer = random.choice(peers)
     other_peer = random.choice(peers)
@@ -49,12 +48,11 @@ def test_store_get_experts():
 
 
 @pytest.mark.forked
-def test_beam_search(dht_size=20, total_experts=128, batch_size=32, initial_peers=3, beam_size=4, parallel_rpc=4,
+def test_beam_search(n_peers=20, total_experts=128, batch_size=32, beam_size=4, parallel_rpc=4,
                      grid_dims=(32, 32, 32)):
-    dht = []
-    for i in range(dht_size):
-        neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(dht, min(initial_peers, len(dht)))]
-        dht.append(hivemind.DHT(start=True, initial_peers=neighbors_i, parallel_rpc=parallel_rpc))
+    dht = [hivemind.DHT(start=True)]
+    initial_peers = dht[0].get_visible_maddrs()
+    dht += [hivemind.DHT(initial_peers=initial_peers, start=True) for _ in range(n_peers - 1)]
 
     real_experts = sorted({
         'expert.' + '.'.join([str(random.randint(0, dim - 1)) for dim in grid_dims])
@@ -64,8 +62,8 @@ def test_beam_search(dht_size=20, total_experts=128, batch_size=32, initial_peer
         declare_experts(random.choice(dht), real_experts[batch_start: batch_start + batch_size], wait=True,
                         endpoint=f"host{batch_start // batch_size}:{random.randint(0, 65536)}")
 
-    neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(dht, min(initial_peers, len(dht)))]
-    you = hivemind.DHT(start=True, initial_peers=neighbors_i, parallel_rpc=parallel_rpc)
+    neighbors = sum([peer.get_visible_maddrs() for peer in random.sample(dht, min(3, len(dht)))], [])
+    you = hivemind.DHT(start=True, initial_peers=neighbors, parallel_rpc=parallel_rpc)
     beam_search = MoEBeamSearcher(you, 'expert.', grid_dims)
 
     for i in range(10):
@@ -143,22 +141,23 @@ def test_uid_patterns():
 
 @pytest.mark.forked
 @pytest.mark.asyncio
-async def test_negative_caching():
-    peers = []
-    for i in range(10):
-        neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
-        peers.append(hivemind.DHT(initial_peers=neighbors_i, cache_locally=False, start=True))
+async def test_negative_caching(n_peers=10):
+    dht_kwargs = {'cache_locally': False}
+
+    peers = [hivemind.DHT(start=True, **dht_kwargs)]
+    initial_peers = peers[0].get_visible_maddrs()
+    peers += [hivemind.DHT(initial_peers=initial_peers, start=True, **dht_kwargs) for _ in range(n_peers - 1)]
 
     writer_peer = random.choice(peers)
     assert all(declare_experts(writer_peer, ['ffn.1.2.3', 'ffn.3.4.5'], 'myaddr:1234').values())
 
-    neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
-    neg_caching_peer = hivemind.DHT(initial_peers=neighbors_i, cache_locally=False, start=True)
+    neighbors = sum([peer.get_visible_maddrs() for peer in random.sample(peers, min(3, len(peers)))], [])
+    neg_caching_peer = hivemind.DHT(initial_peers=neighbors, start=True, **dht_kwargs)
     beam_search = MoEBeamSearcher(neg_caching_peer, uid_prefix='ffn.', grid_size=(10, 10, 10), negative_caching=True)
     # get prefixes by the peer with negative caching. Cache "no data" entries for ffn.0.*, ffn.2.*, ffn.4.*, ffn.5.*
     assert len(beam_search.get_initial_beam(scores=[.1, .2, .3, .4, .5, .6], beam_size=3)) == 2
 
-    node = await DHTNode.create(initial_peers=neighbors_i)
+    node = await DHTNode.create(initial_peers=neighbors)
     fetched = await asyncio.gather(*(node.get(f'ffn.{i}.') for i in range(10)))
     for i in range(6):
         assert fetched[i] is not None, f"node should have cached ffn.{i}."

+ 107 - 127
tests/test_dht_node.py

@@ -2,72 +2,99 @@ import asyncio
 import heapq
 import multiprocessing as mp
 import random
+import signal
 from itertools import product
-from typing import Optional, List, Dict
+from typing import List, Sequence, Tuple
 
 import numpy as np
 import pytest
+from multiaddr import Multiaddr
 
 import hivemind
-from hivemind import get_dht_time, replace_port
-from hivemind.dht.node import DHTID, Endpoint, DHTNode, LOCALHOST
-from hivemind.dht.protocol import DHTProtocol, ValidationError
+from hivemind import get_dht_time
+from hivemind.dht.node import DHTID, DHTNode
+from hivemind.dht.protocol import DHTProtocol
 from hivemind.dht.storage import DictionaryDHTValue
+from hivemind.p2p import P2P, PeerID
+from hivemind.utils.logging import get_logger
+from test_utils.dht_swarms import launch_swarm_in_separate_processes, launch_star_shaped_swarm
 
 
-def run_protocol_listener(port: int, dhtid: DHTID, started: mp.synchronize.Event, ping: Optional[Endpoint] = None):
+logger = get_logger(__name__)
+
+
+def maddrs_to_peer_ids(maddrs: List[Multiaddr]) -> List[PeerID]:
+    return list({PeerID.from_base58(maddr['p2p']) for maddr in maddrs})
+
+
+def run_protocol_listener(dhtid: DHTID, maddr_conn: mp.connection.Connection,
+                          initial_peers: Sequence[Multiaddr]) -> None:
     loop = asyncio.get_event_loop()
+
+    p2p = loop.run_until_complete(P2P.create(initial_peers=initial_peers))
+    visible_maddrs = loop.run_until_complete(p2p.get_visible_maddrs())
+
     protocol = loop.run_until_complete(DHTProtocol.create(
-        dhtid, bucket_size=20, depth_modulo=5, num_replicas=3, wait_timeout=5, listen_on=f"{LOCALHOST}:{port}"))
+        p2p, dhtid, bucket_size=20, depth_modulo=5, num_replicas=3, wait_timeout=5))
+
+    logger.info(f"Started peer id={protocol.node_id} visible_maddrs={visible_maddrs}")
 
-    assert protocol.port == port
-    print(f"Started peer id={protocol.node_id} port={port}", flush=True)
+    for endpoint in maddrs_to_peer_ids(initial_peers):
+        loop.run_until_complete(protocol.call_ping(endpoint))
 
-    if ping is not None:
-        loop.run_until_complete(protocol.call_ping(ping))
-    started.set()
-    loop.run_until_complete(protocol.server.wait_for_termination())
-    print(f"Finished peer id={protocol.node_id} port={port}", flush=True)
+    maddr_conn.send((p2p.id, visible_maddrs))
 
+    async def shutdown():
+        await p2p.shutdown()
+        logger.info(f"Finished peer id={protocol.node_id} maddrs={visible_maddrs}")
+        loop.stop()
 
-# note: we run grpc-related tests in a separate process to re-initialize all global states from scratch
-# this helps us avoid undesirable side-effects (e.g. segfaults) when running multiple tests in sequence
+    loop.add_signal_handler(signal.SIGTERM, lambda: loop.create_task(shutdown()))
+    loop.run_forever()
+
+
+def launch_protocol_listener(initial_peers: Sequence[Multiaddr] = ()) -> \
+        Tuple[DHTID, mp.Process, PeerID, List[Multiaddr]]:
+    remote_conn, local_conn = mp.Pipe()
+    dht_id = DHTID.generate()
+    process = mp.Process(target=run_protocol_listener, args=(dht_id, remote_conn, initial_peers), daemon=True)
+    process.start()
+    peer_id, visible_maddrs = local_conn.recv()
+
+    return dht_id, process, peer_id, visible_maddrs
+
+
+# note: we run network-related tests in a separate process to re-initialize all global states from scratch
+# this helps us avoid undesirable gRPC side-effects (e.g. segfaults) when running multiple tests in sequence
 
 
 @pytest.mark.forked
 def test_dht_protocol():
-    # create the first peer
-    peer1_port, peer1_id, peer1_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
-    peer1_proc = mp.Process(target=run_protocol_listener, args=(peer1_port, peer1_id, peer1_started), daemon=True)
-    peer1_proc.start(), peer1_started.wait()
-
-    # create another peer that connects to the first peer
-    peer2_port, peer2_id, peer2_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
-    peer2_proc = mp.Process(target=run_protocol_listener, args=(peer2_port, peer2_id, peer2_started),
-                            kwargs={'ping': f'{LOCALHOST}:{peer1_port}'}, daemon=True)
-    peer2_proc.start(), peer2_started.wait()
+    peer1_id, peer1_proc, peer1_endpoint, peer1_maddrs = launch_protocol_listener()
+    peer2_id, peer2_proc, peer2_endpoint, _ = launch_protocol_listener(initial_peers=peer1_maddrs)
 
     loop = asyncio.get_event_loop()
     for listen in [False, True]:  # note: order matters, this test assumes that first run uses listen=False
+        p2p = loop.run_until_complete(P2P.create(initial_peers=peer1_maddrs))
         protocol = loop.run_until_complete(DHTProtocol.create(
-            DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, listen=listen))
-        print(f"Self id={protocol.node_id}", flush=True)
+            p2p, DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, listen=listen))
+        logger.info(f"Self id={protocol.node_id}")
 
-        assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{peer1_port}')) == peer1_id
+        assert loop.run_until_complete(protocol.call_ping(peer1_endpoint)) == peer1_id
 
         key, value, expiration = DHTID.generate(), [random.random(), {'ololo': 'pyshpysh'}], get_dht_time() + 1e3
         store_ok = loop.run_until_complete(protocol.call_store(
-            f'{LOCALHOST}:{peer1_port}', [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
+            peer1_endpoint, [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
         )
         assert all(store_ok), "DHT rejected a trivial store"
 
         # peer 1 must know about peer 2
         (recv_value_bytes, recv_expiration), nodes_found = loop.run_until_complete(
-            protocol.call_find(f'{LOCALHOST}:{peer1_port}', [key]))[key]
+            protocol.call_find(peer1_endpoint, [key]))[key]
         recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
         (recv_id, recv_endpoint) = next(iter(nodes_found.items()))
-        assert recv_id == peer2_id and ':'.join(recv_endpoint.split(':')[-2:]) == f"{LOCALHOST}:{peer2_port}", \
-            f"expected id={peer2_id}, peer={LOCALHOST}:{peer2_port} but got {recv_id}, {recv_endpoint}"
+        assert recv_id == peer2_id and recv_endpoint == peer2_endpoint, \
+            f"expected id={peer2_id}, peer={peer2_endpoint} but got {recv_id}, {recv_endpoint}"
 
         assert recv_value == value and recv_expiration == expiration, \
             f"call_find_value expected {value} (expires by {expiration}) " \
@@ -76,38 +103,35 @@ def test_dht_protocol():
         # peer 2 must know about peer 1, but not have a *random* nonexistent value
         dummy_key = DHTID.generate()
         empty_item, nodes_found_2 = loop.run_until_complete(
-            protocol.call_find(f'{LOCALHOST}:{peer2_port}', [dummy_key]))[dummy_key]
+            protocol.call_find(peer2_endpoint, [dummy_key]))[dummy_key]
         assert empty_item is None, "Non-existent keys shouldn't have values"
         (recv_id, recv_endpoint) = next(iter(nodes_found_2.items()))
-        assert recv_id == peer1_id and recv_endpoint == f"{LOCALHOST}:{peer1_port}", \
-            f"expected id={peer1_id}, peer={LOCALHOST}:{peer1_port} but got {recv_id}, {recv_endpoint}"
+        assert recv_id == peer1_id and recv_endpoint == peer1_endpoint, \
+            f"expected id={peer1_id}, peer={peer1_endpoint} but got {recv_id}, {recv_endpoint}"
 
         # cause a non-response by querying a nonexistent peer
-        dummy_port = hivemind.find_open_port()
-        assert loop.run_until_complete(protocol.call_find(f"{LOCALHOST}:{dummy_port}", [key])) is None
+        assert loop.run_until_complete(protocol.call_find(PeerID.from_base58('fakeid'), [key])) is None
 
         # store/get a dictionary with sub-keys
         nested_key, subkey1, subkey2 = DHTID.generate(), 'foo', 'bar'
         value1, value2 = [random.random(), {'ololo': 'pyshpysh'}], 'abacaba'
         assert loop.run_until_complete(protocol.call_store(
-            f'{LOCALHOST}:{peer1_port}', keys=[nested_key], values=[hivemind.MSGPackSerializer.dumps(value1)],
+            peer1_endpoint, keys=[nested_key], values=[hivemind.MSGPackSerializer.dumps(value1)],
             expiration_time=[expiration], subkeys=[subkey1])
         )
         assert loop.run_until_complete(protocol.call_store(
-            f'{LOCALHOST}:{peer1_port}', keys=[nested_key], values=[hivemind.MSGPackSerializer.dumps(value2)],
+            peer1_endpoint, keys=[nested_key], values=[hivemind.MSGPackSerializer.dumps(value2)],
             expiration_time=[expiration + 5], subkeys=[subkey2])
         )
         (recv_dict, recv_expiration), nodes_found = loop.run_until_complete(
-            protocol.call_find(f'{LOCALHOST}:{peer1_port}', [nested_key]))[nested_key]
+            protocol.call_find(peer1_endpoint, [nested_key]))[nested_key]
         assert isinstance(recv_dict, DictionaryDHTValue)
         assert len(recv_dict.data) == 2 and recv_expiration == expiration + 5
         assert recv_dict.data[subkey1] == (protocol.serializer.dumps(value1), expiration)
         assert recv_dict.data[subkey2] == (protocol.serializer.dumps(value2), expiration + 5)
 
-        assert LOCALHOST in loop.run_until_complete(protocol.get_outgoing_request_endpoint(f'{LOCALHOST}:{peer1_port}'))
-
         if listen:
-            loop.run_until_complete(protocol.shutdown())
+            loop.run_until_complete(p2p.shutdown())
 
     peer1_proc.terminate()
     peer2_proc.terminate()
@@ -116,83 +140,63 @@ def test_dht_protocol():
 @pytest.mark.forked
 def test_empty_table():
     """ Test RPC methods with empty routing table """
-    peer_port, peer_id, peer_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
-    peer_proc = mp.Process(target=run_protocol_listener, args=(peer_port, peer_id, peer_started), daemon=True)
-    peer_proc.start(), peer_started.wait()
+    peer_id, peer_proc, peer_endpoint, peer_maddrs = launch_protocol_listener()
 
     loop = asyncio.get_event_loop()
+    p2p = loop.run_until_complete(P2P.create(initial_peers=peer_maddrs))
     protocol = loop.run_until_complete(DHTProtocol.create(
-        DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, listen=False))
+        p2p, DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, listen=False))
 
     key, value, expiration = DHTID.generate(), [random.random(), {'ololo': 'pyshpysh'}], get_dht_time() + 1e3
 
     empty_item, nodes_found = loop.run_until_complete(
-        protocol.call_find(f'{LOCALHOST}:{peer_port}', [key]))[key]
+        protocol.call_find(peer_endpoint, [key]))[key]
     assert empty_item is None and len(nodes_found) == 0
     assert all(loop.run_until_complete(protocol.call_store(
-        f'{LOCALHOST}:{peer_port}', [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
+        peer_endpoint, [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
     )), "peer rejected store"
 
     (recv_value_bytes, recv_expiration), nodes_found = loop.run_until_complete(
-        protocol.call_find(f'{LOCALHOST}:{peer_port}', [key]))[key]
+        protocol.call_find(peer_endpoint, [key]))[key]
     recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
     assert len(nodes_found) == 0
     assert recv_value == value and recv_expiration == expiration
 
-    assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{peer_port}')) == peer_id
-    assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{hivemind.find_open_port()}')) is None
+    assert loop.run_until_complete(protocol.call_ping(peer_endpoint)) == peer_id
+    assert loop.run_until_complete(protocol.call_ping(PeerID.from_base58('fakeid'))) is None
     peer_proc.terminate()
 
 
-def run_node(node_id, peers, status_pipe: mp.Pipe):
-    if asyncio.get_event_loop().is_running():
-        asyncio.get_event_loop().stop()  # if we're in jupyter, get rid of its built-in event loop
-        asyncio.set_event_loop(asyncio.new_event_loop())
-    loop = asyncio.get_event_loop()
-    node = loop.run_until_complete(DHTNode.create(node_id, initial_peers=peers))
-    status_pipe.send(node.port)
-    while True:
-        loop.run_forever()
-
-
 @pytest.mark.forked
 def test_dht_node():
-    # create dht with 50 nodes + your 51-st node
-    dht: Dict[Endpoint, DHTID] = {}
-    processes: List[mp.Process] = []
-
-    for i in range(50):
-        node_id = DHTID.generate()
-        peers = random.sample(dht.keys(), min(len(dht), 5))
-        pipe_recv, pipe_send = mp.Pipe(duplex=False)
-        proc = mp.Process(target=run_node, args=(node_id, peers, pipe_send), daemon=True)
-        proc.start()
-        port = pipe_recv.recv()
-        processes.append(proc)
-        dht[f"{LOCALHOST}:{port}"] = node_id
+    # step A: create a swarm of 50 dht nodes in separate processes
+    #         (first 5 created sequentially, others created in parallel)
+    processes, dht, swarm_maddrs = launch_swarm_in_separate_processes(n_peers=50, n_sequential_peers=5)
 
+    # step B: run 51-st node in this process
     loop = asyncio.get_event_loop()
-    me = loop.run_until_complete(DHTNode.create(initial_peers=random.sample(dht.keys(), 5), parallel_rpc=10,
+    initial_peers = random.choice(swarm_maddrs)
+    me = loop.run_until_complete(DHTNode.create(initial_peers=initial_peers, parallel_rpc=10,
                                                 cache_refresh_before_expiry=False))
 
     # test 1: find self
     nearest = loop.run_until_complete(me.find_nearest_nodes([me.node_id], k_nearest=1))[me.node_id]
-    assert len(nearest) == 1 and ':'.join(nearest[me.node_id].split(':')[-2:]) == f"{LOCALHOST}:{me.port}"
+    assert len(nearest) == 1 and nearest[me.node_id] == me.endpoint
 
     # test 2: find others
-    for i in range(10):
+    for _ in range(10):
         ref_endpoint, query_id = random.choice(list(dht.items()))
         nearest = loop.run_until_complete(me.find_nearest_nodes([query_id], k_nearest=1))[query_id]
         assert len(nearest) == 1
         found_node_id, found_endpoint = next(iter(nearest.items()))
-        assert found_node_id == query_id and ':'.join(found_endpoint.split(':')[-2:]) == ref_endpoint
+        assert found_node_id == query_id and found_endpoint == ref_endpoint
 
     # test 3: find neighbors to random nodes
     accuracy_numerator = accuracy_denominator = 0  # top-1 nearest neighbor accuracy
     jaccard_numerator = jaccard_denominator = 0  # jaccard similarity aka intersection over union
     all_node_ids = list(dht.values())
 
-    for i in range(10):
+    for _ in range(10):
         query_id = DHTID.generate()
         k_nearest = random.randint(1, 10)
         exclude_self = random.random() > 0.5
@@ -217,9 +221,9 @@ def test_dht_node():
         jaccard_denominator += k_nearest
 
     accuracy = accuracy_numerator / accuracy_denominator
-    print("Top-1 accuracy:", accuracy)  # should be 98-100%
+    logger.debug(f"Top-1 accuracy: {accuracy}")  # should be 98-100%
     jaccard_index = jaccard_numerator / jaccard_denominator
-    print("Jaccard index (intersection over union):", jaccard_index)  # should be 95-100%
+    logger.debug(f"Jaccard index (intersection over union): {jaccard_index}")  # should be 95-100%
     assert accuracy >= 0.9, f"Top-1 accuracy only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
     assert jaccard_index >= 0.9, f"Jaccard index only {jaccard_index} ({jaccard_numerator} / {jaccard_denominator})"
 
@@ -232,14 +236,16 @@ def test_dht_node():
     # test 5: node without peers
     detached_node = loop.run_until_complete(DHTNode.create())
     nearest = loop.run_until_complete(detached_node.find_nearest_nodes([dummy]))[dummy]
-    assert len(nearest) == 1 and nearest[detached_node.node_id] == f"{LOCALHOST}:{detached_node.port}"
+    assert len(nearest) == 1 and nearest[detached_node.node_id] == detached_node.endpoint
     nearest = loop.run_until_complete(detached_node.find_nearest_nodes([dummy], exclude_self=True))[dummy]
     assert len(nearest) == 0
 
-    # test 6 store and get value
+    # test 6: store and get value
     true_time = get_dht_time() + 1200
     assert loop.run_until_complete(me.store("mykey", ["Value", 10], true_time))
-    that_guy = loop.run_until_complete(DHTNode.create(initial_peers=random.sample(dht.keys(), 3), parallel_rpc=10,
+
+    initial_peers = random.choice(swarm_maddrs)
+    that_guy = loop.run_until_complete(DHTNode.create(initial_peers=initial_peers, parallel_rpc=10,
                                                       cache_refresh_before_expiry=False, cache_locally=False))
 
     for node in [me, that_guy]:
@@ -285,19 +291,15 @@ def test_dht_node():
 
     for proc in processes:
         proc.terminate()
+    # The nodes don't own their hivemind.p2p.P2P instances, so we shut them down separately
+    loop.run_until_complete(asyncio.wait([node.shutdown() for node in [me, detached_node, that_guy]]))
 
 
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_dhtnode_replicas():
-    dht_size = 20
-    initial_peers = 3
     num_replicas = random.randint(1, 20)
-
-    peers = []
-    for i in range(dht_size):
-        neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(initial_peers, len(peers)))]
-        peers.append(await DHTNode.create(initial_peers=neighbors_i, num_replicas=num_replicas))
+    peers = await launch_star_shaped_swarm(n_peers=20, num_replicas=num_replicas)
 
     you = random.choice(peers)
     assert await you.store('key1', 'foo', get_dht_time() + 999)
@@ -318,8 +320,8 @@ async def test_dhtnode_replicas():
 @pytest.mark.asyncio
 async def test_dhtnode_caching(T=0.05):
     node2 = await DHTNode.create(cache_refresh_before_expiry=5 * T, reuse_get_requests=False)
-    node1 = await DHTNode.create(initial_peers=[f'localhost:{node2.port}'],
-                                          cache_refresh_before_expiry=5 * T, listen=False, reuse_get_requests=False)
+    node1 = await DHTNode.create(initial_peers=await node2.protocol.p2p.get_visible_maddrs(),
+                                 cache_refresh_before_expiry=5 * T, listen=False, reuse_get_requests=False)
     await node2.store('k', [123, 'value'], expiration_time=hivemind.get_dht_time() + 7 * T)
     await node2.store('k2', [654, 'value'], expiration_time=hivemind.get_dht_time() + 7 * T)
     await node2.store('k3', [654, 'value'], expiration_time=hivemind.get_dht_time() + 15 * T)
@@ -363,10 +365,7 @@ async def test_dhtnode_caching(T=0.05):
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_dhtnode_reuse_get():
-    peers = []
-    for i in range(10):
-        neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
-        peers.append(await DHTNode.create(initial_peers=neighbors_i, parallel_rpc=256))
+    peers = await launch_star_shaped_swarm(n_peers=10, parallel_rpc=256)
 
     await asyncio.gather(
         random.choice(peers).store('k1', 123, hivemind.get_dht_time() + 999),
@@ -396,51 +395,30 @@ async def test_dhtnode_reuse_get():
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_dhtnode_blacklist():
-    node1 = await DHTNode.create(blacklist_time=999)
-    node2 = await DHTNode.create(blacklist_time=999, initial_peers=[f"{LOCALHOST}:{node1.port}"])
-    node3 = await DHTNode.create(blacklist_time=999, initial_peers=[f"{LOCALHOST}:{node1.port}"])
-    node4 = await DHTNode.create(blacklist_time=999, initial_peers=[f"{LOCALHOST}:{node1.port}"])
+    node1, node2, node3, node4 = await launch_star_shaped_swarm(n_peers=4, blacklist_time=999)
 
     assert await node2.store('abc', 123, expiration_time=hivemind.get_dht_time() + 99)
     assert len(node2.blacklist.ban_counter) == 0
 
-    await node3.shutdown()
-    await node4.shutdown()
+    await asyncio.gather(node3.shutdown(), node4.shutdown())
 
     assert await node2.store('def', 456, expiration_time=hivemind.get_dht_time() + 99)
 
-    assert len(node2.blacklist.ban_counter) == 2
+    assert set(node2.blacklist.ban_counter.keys()) == {node3.endpoint, node4.endpoint}
 
-    for banned_peer in node2.blacklist.ban_counter:
-        assert any(banned_peer.endswith(str(port)) for port in [node3.port, node4.port])
-
-    node3_endpoint = await node3.protocol.get_outgoing_request_endpoint(f"{hivemind.LOCALHOST}:{node1.port}")
-    node3_endpoint = replace_port(node3_endpoint, node3.port)
     assert await node1.get('abc', latest=True)  # force node1 to crawl dht and discover unresponsive peers
-    assert node3_endpoint in node1.blacklist
+    assert node3.endpoint in node1.blacklist
 
-    node2_endpoint = await node2.protocol.get_outgoing_request_endpoint(f"{hivemind.LOCALHOST}:{node1.port}")
-    node2_endpoint = replace_port(node2_endpoint, node2.port)
     assert await node1.get('abc', latest=True)  # force node1 to crawl dht and discover unresponsive peers
-    assert node2_endpoint not in node1.blacklist
-
+    assert node2.endpoint not in node1.blacklist
 
-@pytest.mark.forked
-@pytest.mark.asyncio
-async def test_dhtnode_validate(fake_endpoint='127.0.0.721:*'):
-    node1 = await DHTNode.create(blacklist_time=999)
-    with pytest.raises(ValidationError):
-        node2 = await DHTNode.create(blacklist_time=999, initial_peers=[f"{LOCALHOST}:{node1.port}"],
-                                              endpoint=fake_endpoint)
+    await asyncio.gather(node1.shutdown(), node2.shutdown())
 
 
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_dhtnode_edge_cases():
-    peers = []
-    for i in range(5):
-        neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
-        peers.append(await DHTNode.create(initial_peers=neighbors_i, parallel_rpc=4))
+    peers = await launch_star_shaped_swarm(n_peers=4, parallel_rpc=4)
 
     subkeys = [0, '', False, True, 'abyrvalg', 4555]
     keys = subkeys + [()]
@@ -453,3 +431,5 @@ async def test_dhtnode_edge_cases():
         assert stored is not None
         assert subkey in stored.value
         assert stored.value[subkey].value == value
+
+    await asyncio.wait([node.shutdown() for node in peers])

+ 17 - 11
tests/test_dht_schema.py

@@ -1,12 +1,14 @@
+import asyncio
+from typing import Dict
+
 import pytest
 from pydantic import BaseModel, StrictInt, conint
-from typing import Dict
 
 import hivemind
-from hivemind.utils.timed_storage import get_dht_time
-from hivemind.dht.node import DHTNode, LOCALHOST
+from hivemind.dht.node import DHTNode
 from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
 from hivemind.dht.validation import DHTRecord, RecordValidatorBase
+from hivemind.utils.timed_storage import get_dht_time
 
 
 class SampleSchema(BaseModel):
@@ -20,9 +22,10 @@ async def dht_nodes_with_schema():
     validator = SchemaValidator(SampleSchema)
 
     alice = await DHTNode.create(record_validator=validator)
-    bob = await DHTNode.create(
-        record_validator=validator, initial_peers=[f"{LOCALHOST}:{alice.port}"])
-    return alice, bob
+    bob = await DHTNode.create(record_validator=validator, initial_peers=await alice.get_visible_maddrs())
+    yield alice, bob
+
+    await asyncio.gather(alice.shutdown(), bob.shutdown())
 
 
 @pytest.mark.forked
@@ -108,8 +111,7 @@ async def test_keys_outside_schema(dht_nodes_with_schema):
         assert validator.merge_with(SchemaValidator(MergedSchema, allow_extra_keys=False))
 
         alice = await DHTNode.create(record_validator=validator)
-        bob = await DHTNode.create(
-            record_validator=validator, initial_peers=[f"{LOCALHOST}:{alice.port}"])
+        bob = await DHTNode.create(record_validator=validator, initial_peers=await alice.get_visible_maddrs())
 
         store_ok = await bob.store('unknown_key', b'foo_bar', get_dht_time() + 10)
         assert store_ok == allow_extra_keys
@@ -131,8 +133,7 @@ async def test_prefix():
     validator = SchemaValidator(Schema, allow_extra_keys=False, prefix='prefix')
 
     alice = await DHTNode.create(record_validator=validator)
-    bob = await DHTNode.create(
-        record_validator=validator, initial_peers=[f"{LOCALHOST}:{alice.port}"])
+    bob = await DHTNode.create(record_validator=validator, initial_peers=await alice.get_visible_maddrs())
 
     assert await bob.store('prefix_field', 777, get_dht_time() + 10)
     assert not await bob.store('prefix_field', 'string_value', get_dht_time() + 10)
@@ -142,6 +143,8 @@ async def test_prefix():
         assert (await peer.get('prefix_field', latest=True)).value == 777
         assert (await peer.get('field', latest=True)) is None
 
+    await asyncio.gather(alice.shutdown(), bob.shutdown())
+
 
 @pytest.mark.forked
 @pytest.mark.asyncio
@@ -188,7 +191,7 @@ async def test_merging_schema_validators(dht_nodes_with_schema):
 @pytest.mark.forked
 def test_sending_validator_instance_between_processes():
     alice = hivemind.DHT(start=True)
-    bob = hivemind.DHT(start=True, initial_peers=[f"{LOCALHOST}:{alice.port}"])
+    bob = hivemind.DHT(start=True, initial_peers=alice.get_visible_maddrs())
 
     alice.add_validators([SchemaValidator(SampleSchema)])
     bob.add_validators([SchemaValidator(SampleSchema)])
@@ -196,3 +199,6 @@ def test_sending_validator_instance_between_processes():
     assert bob.store('experiment_name', b'foo_bar', get_dht_time() + 10)
     assert not bob.store('experiment_name', 777, get_dht_time() + 10)
     assert alice.get('experiment_name', latest=True).value == b'foo_bar'
+
+    alice.shutdown()
+    bob.shutdown()

+ 7 - 7
tests/test_moe.py

@@ -14,8 +14,8 @@ def test_moe():
     all_expert_uids = [f'ffn.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}'
                        for _ in range(10)]
     with background_server(expert_uids=all_expert_uids, device='cpu', expert_cls='ffn', num_handlers=1,
-                           hidden_dim=16) as (server_endpoint, dht_endpoint):
-        dht = hivemind.DHT(start=True, initial_peers=[dht_endpoint])
+                           hidden_dim=16) as (server_endpoint, dht_maddrs):
+        dht = hivemind.DHT(start=True, initial_peers=dht_maddrs)
 
         dmoe = hivemind.RemoteMixtureOfExperts(
             in_features=16, grid_size=(4, 4, 4), dht=dht, k_best=3, uid_prefix='ffn.')
@@ -30,8 +30,8 @@ def test_no_experts():
     all_expert_uids = [f'expert.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}'
                        for _ in range(10)]
     with background_server(expert_uids=all_expert_uids, device='cpu', expert_cls='nop_delay', num_handlers=1,
-                           hidden_dim=16) as (server_endpoint, dht_endpoint):
-        dht = hivemind.DHT(start=True, initial_peers=[dht_endpoint])
+                           hidden_dim=16) as (server_endpoint, dht_maddrs):
+        dht = hivemind.DHT(start=True, initial_peers=dht_maddrs)
 
         dmoe = hivemind.RemoteSwitchMixtureOfExperts(
             in_features=16, grid_size=(4, 4, 4), dht=dht, uid_prefix='expert.', forward_timeout=0.1,
@@ -54,7 +54,7 @@ def test_call_many(hidden_dim=16):
     atol = 1e-5
 
     with background_server(num_experts=5, device='cpu', expert_cls='ffn', num_handlers=1, hidden_dim=hidden_dim,
-                           optim_cls=None, no_dht=True) as (server_endpoint, dht_endpoint):
+                           optim_cls=None, no_dht=True) as (server_endpoint, _):
         inputs = torch.randn(4, hidden_dim, requires_grad=True)
         inputs_clone = inputs.clone().detach().requires_grad_(True)
         e0, e1, e2, e3, e4 = [hivemind.RemoteExpert(f'expert.{i}', server_endpoint) for i in range(5)]
@@ -96,7 +96,7 @@ def test_call_many(hidden_dim=16):
 @pytest.mark.forked
 def test_remote_module_call(hidden_dim=16):
     with background_server(num_experts=1, device='cpu', expert_cls='ffn', num_handlers=1, hidden_dim=hidden_dim,
-                           optim_cls=None, no_dht=True) as (server_endpoint, dht_endpoint):
+                           optim_cls=None, no_dht=True) as (server_endpoint, _):
         real_expert = hivemind.RemoteExpert('expert.0', server_endpoint)
         fake_expert = hivemind.RemoteExpert('oiasfjiasjf', server_endpoint)
 
@@ -151,7 +151,7 @@ def test_determinism(hidden_dim=16):
     mask = torch.randint(0, 1, (32, hidden_dim))
 
     with background_server(num_experts=1, device='cpu', expert_cls='det_dropout', num_handlers=1, hidden_dim=hidden_dim,
-                           optim_cls=None, no_dht=True) as (server_endpoint, dht_endpoint):
+                           optim_cls=None, no_dht=True) as (server_endpoint, _):
         expert = hivemind.RemoteExpert(uid=f'expert.0', endpoint=server_endpoint)
 
         out = expert(xx, mask)

+ 22 - 42
tests/test_p2p_daemon.py

@@ -1,6 +1,5 @@
 import asyncio
 import multiprocessing as mp
-import socket
 import subprocess
 from functools import partial
 from typing import List
@@ -14,7 +13,6 @@ from hivemind.p2p import P2P, P2PHandlerError
 from hivemind.proto import dht_pb2, runtime_pb2
 from hivemind.utils import MSGPackSerializer
 from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
-from hivemind.utils.networking import find_open_port
 
 
 def is_process_running(pid: int) -> bool:
@@ -28,7 +26,7 @@ async def replicate_if_needed(p2p: P2P, replicate: bool) -> P2P:
 async def bootstrap_from(daemons: List[P2P]) -> List[Multiaddr]:
     maddrs = []
     for d in daemons:
-        maddrs += await d.identify_maddrs()
+        maddrs += await d.get_visible_maddrs()
     return maddrs
 
 
@@ -43,39 +41,21 @@ async def test_daemon_killed_on_del():
     assert not is_process_running(child_pid)
 
 
+@pytest.mark.parametrize(
+    'host_maddrs', [
+        [Multiaddr('/ip4/127.0.0.1/tcp/0')],
+        [Multiaddr('/ip4/127.0.0.1/udp/0/quic')],
+        [Multiaddr('/ip4/127.0.0.1/tcp/0'), Multiaddr('/ip4/127.0.0.1/udp/0/quic')],
+    ]
+)
 @pytest.mark.asyncio
-async def test_error_for_wrong_daemon_arguments():
-    with pytest.raises(RuntimeError):
-        await P2P.create(unknown_argument=True)
-
-
-@pytest.mark.asyncio
-async def test_server_client_connection():
-    server = await P2P.create()
-    peers = await server.list_peers()
-    assert len(peers) == 0
-
-    nodes = await bootstrap_from([server])
-    client = await P2P.create(bootstrap_peers=nodes)
-    await client.wait_for_at_least_n_peers(1)
-
-    peers = await client.list_peers()
-    assert len(peers) == 1
-    peers = await server.list_peers()
-    assert len(peers) == 1
-
-
-@pytest.mark.asyncio
-async def test_quic_transport():
-    server_port = find_open_port((socket.AF_INET, socket.SOCK_DGRAM))
-    server = await P2P.create(quic=True, host_maddrs=[Multiaddr(f'/ip4/127.0.0.1/udp/{server_port}/quic')])
+async def test_transports(host_maddrs: List[Multiaddr]):
+    server = await P2P.create(quic=True, host_maddrs=host_maddrs)
     peers = await server.list_peers()
     assert len(peers) == 0
 
     nodes = await bootstrap_from([server])
-    client_port = find_open_port((socket.AF_INET, socket.SOCK_DGRAM))
-    client = await P2P.create(quic=True, host_maddrs=[Multiaddr(f'/ip4/127.0.0.1/udp/{client_port}/quic')],
-                              bootstrap_peers=nodes)
+    client = await P2P.create(quic=True, host_maddrs=host_maddrs, initial_peers=nodes)
     await client.wait_for_at_least_n_peers(1)
 
     peers = await client.list_peers()
@@ -163,7 +143,7 @@ async def test_call_unary_handler(should_cancel, replicate, handle_name="handle"
             handler_cancelled = True
         return dht_pb2.PingResponse(
             peer=dht_pb2.NodeInfo(node_id=server.id.to_bytes()),
-            sender_endpoint=context.handle_name, available=True)
+            available=True)
 
     server_pid = server_primary._child.pid
     await server.add_unary_handler(handle_name, ping_handler, dht_pb2.PingRequest,
@@ -171,7 +151,7 @@ async def test_call_unary_handler(should_cancel, replicate, handle_name="handle"
     assert is_process_running(server_pid)
 
     nodes = await bootstrap_from([server])
-    client_primary = await P2P.create(bootstrap_peers=nodes)
+    client_primary = await P2P.create(initial_peers=nodes)
     client = await replicate_if_needed(client_primary, replicate)
     client_pid = client_primary._child.pid
     assert is_process_running(client_pid)
@@ -182,7 +162,7 @@ async def test_call_unary_handler(should_cancel, replicate, handle_name="handle"
         validate=True)
     expected_response = dht_pb2.PingResponse(
         peer=dht_pb2.NodeInfo(node_id=server.id.to_bytes()),
-        sender_endpoint=handle_name, available=True)
+        available=True)
 
     if should_cancel:
         stream_info, reader, writer = await client._client.stream_open(server.id, (handle_name,))
@@ -215,7 +195,7 @@ async def test_call_unary_handler_error(handle_name="handle"):
     assert is_process_running(server_pid)
 
     nodes = await bootstrap_from([server])
-    client = await P2P.create(bootstrap_peers=nodes)
+    client = await P2P.create(initial_peers=nodes)
     client_pid = client._child.pid
     assert is_process_running(client_pid)
     await client.wait_for_at_least_n_peers(1)
@@ -249,7 +229,7 @@ async def test_call_peer_single_process(test_input, expected, handle, handler_na
     assert is_process_running(server_pid)
 
     nodes = await bootstrap_from([server])
-    client = await P2P.create(bootstrap_peers=nodes)
+    client = await P2P.create(initial_peers=nodes)
     client_pid = client._child.pid
     assert is_process_running(client_pid)
 
@@ -274,7 +254,7 @@ async def run_server(handler_name, server_side, client_side, response_received):
     assert is_process_running(server_pid)
 
     server_side.send(server.id)
-    server_side.send(await server.identify_maddrs())
+    server_side.send(await server.get_visible_maddrs())
     while response_received.value == 0:
         await asyncio.sleep(0.5)
 
@@ -301,7 +281,7 @@ async def test_call_peer_different_processes():
     peer_id = client_side.recv()
     peer_maddrs = client_side.recv()
 
-    client = await P2P.create(bootstrap_peers=peer_maddrs)
+    client = await P2P.create(initial_peers=peer_maddrs)
     client_pid = client._child.pid
     assert is_process_running(client_pid)
 
@@ -335,7 +315,7 @@ async def test_call_peer_torch_square(test_input, expected, handler_name="handle
     await server.add_stream_handler(handler_name, handle)
 
     nodes = await bootstrap_from([server])
-    client = await P2P.create(bootstrap_peers=nodes)
+    client = await P2P.create(initial_peers=nodes)
 
     await client.wait_for_at_least_n_peers(1)
 
@@ -366,7 +346,7 @@ async def test_call_peer_torch_add(test_input, expected, handler_name="handle"):
     await server.add_stream_handler(handler_name, handle)
 
     nodes = await bootstrap_from([server])
-    client = await P2P.create(bootstrap_peers=nodes)
+    client = await P2P.create(initial_peers=nodes)
 
     await client.wait_for_at_least_n_peers(1)
 
@@ -396,7 +376,7 @@ async def test_call_peer_error(replicate, handler_name="handle"):
     await server.add_stream_handler(handler_name, handle_add_torch_with_exc)
 
     nodes = await bootstrap_from([server])
-    client_primary = await P2P.create(bootstrap_peers=nodes)
+    client_primary = await P2P.create(initial_peers=nodes)
     client = await replicate_if_needed(client_primary, replicate)
 
     await client.wait_for_at_least_n_peers(1)
@@ -428,7 +408,7 @@ async def test_handlers_on_different_replicas(handler_name="handle"):
     await server_replica2.add_stream_handler(handler_name + '2', partial(handler, key=b'replica2'))
 
     nodes = await bootstrap_from([server_primary])
-    client = await P2P.create(bootstrap_peers=nodes)
+    client = await P2P.create(initial_peers=nodes)
     await client.wait_for_at_least_n_peers(1)
 
     result = await client.call_peer_handler(server_id, handler_name, b'1')
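
At the `hivemind.p2p.P2P` level the bootstrap pattern is the same: a client daemon connects to the server's visible multiaddrs via `initial_peers` (renamed from `bootstrap_peers` in this diff). A minimal sketch assembled from the calls exercised in this file:

```python
import asyncio

from hivemind.p2p import P2P


async def main():
    server = await P2P.create()

    # The server's visible multiaddrs act as bootstrap addresses for other daemons.
    client = await P2P.create(initial_peers=await server.get_visible_maddrs())
    await client.wait_for_at_least_n_peers(1)

    assert len(await client.list_peers()) >= 1

    await client.shutdown()
    await server.shutdown()


asyncio.run(main())
```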

+ 1 - 1
tests/test_p2p_daemon_bindings.py

@@ -11,7 +11,7 @@ from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, St
 from hivemind.p2p.p2p_daemon_bindings.utils import (ControlFailure, raise_if_failed, read_pbmsg_safe,
                                                     read_unsigned_varint, write_pbmsg, write_unsigned_varint)
 from hivemind.proto import p2pd_pb2 as p2pd_pb
-from test_utils import make_p2pd_pair_ip4, connect_safe
+from test_utils.p2p_daemon import make_p2pd_pair_ip4, connect_safe
 
 
 def test_raise_if_failed_raises():

+ 7 - 7
tests/test_training.py

@@ -20,7 +20,7 @@ def test_training(max_steps: int = 100, threshold: float = 0.9):
     SGD = partial(torch.optim.SGD, lr=0.05)
 
     with background_server(num_experts=2, device='cpu', optim_cls=SGD, hidden_dim=64, num_handlers=1,
-                           no_dht=True) as (server_endpoint, dht_endpoint):
+                           no_dht=True) as (server_endpoint, _):
         expert1 = RemoteExpert('expert.0', server_endpoint)
         expert2 = RemoteExpert('expert.1', server_endpoint)
         model = nn.Sequential(expert2, nn.ReLU(), expert1, nn.Linear(64, 2))
@@ -49,8 +49,8 @@ def test_moe_training(max_steps: int = 100, threshold: float = 0.9, num_experts=
 
     all_expert_uids = [f'expert.{i}' for i in range(num_experts)]
     with background_server(expert_uids=all_expert_uids, device='cpu', optim_cls=SGD, hidden_dim=64, num_handlers=1) \
-            as (server_endpoint, dht_endpoint):
-        dht = DHT(start=True, initial_peers=[dht_endpoint])
+            as (server_endpoint, dht_maddrs):
+        dht = DHT(start=True, initial_peers=dht_maddrs)
 
         moe = RemoteMixtureOfExperts(in_features=64, grid_size=(num_experts,), dht=dht, uid_prefix='expert.', k_best=2)
         model = nn.Sequential(moe, nn.Linear(64, 2))
@@ -92,8 +92,8 @@ def test_switch_training(max_steps: int = 10, threshold: float = 0.9, num_expert
 
     all_expert_uids = [f'expert.{i}' for i in range(num_experts)]
     with background_server(expert_uids=all_expert_uids, device='cpu', optim_cls=SGD, hidden_dim=64,
-                           num_handlers=1) as (server_endpoint, dht_endpoint):
-        dht = DHT(start=True, initial_peers=[dht_endpoint])
+                           num_handlers=1) as (server_endpoint, dht_maddrs):
+        dht = DHT(start=True, initial_peers=dht_maddrs)
 
         model = SwitchNetwork(dht, 64, 2, num_experts)
         opt = SGD(model.parameters(), lr=0.05)
@@ -116,7 +116,7 @@ def test_switch_training(max_steps: int = 10, threshold: float = 0.9, num_expert
 @pytest.mark.forked
 def test_decentralized_optimizer_step():
     dht_root = DHT(start=True)
-    initial_peers = [f"127.0.0.1:{dht_root.port}"]
+    initial_peers = dht_root.get_visible_maddrs()
 
     param1 = torch.nn.Parameter(torch.zeros(32, 32), requires_grad=True)
     opt1 = DecentralizedSGD([param1], lr=0.1, dht=DHT(initial_peers=initial_peers, start=True),
@@ -142,7 +142,7 @@ def test_decentralized_optimizer_step():
 @pytest.mark.forked
 def test_decentralized_optimizer_averaging():
     dht_root = DHT(start=True)
-    initial_peers = [f"127.0.0.1:{dht_root.port}"]
+    initial_peers = dht_root.get_visible_maddrs()
 
     param1 = torch.nn.Parameter(torch.zeros(32, 32), requires_grad=True)
     opt1 = DecentralizedAdam([param1], lr=0.1, averaging_steps_period=1, dht=DHT(initial_peers=initial_peers, start=True),

+ 0 - 194
tests/test_utils/__init__.py

@@ -1,194 +0,0 @@
-import asyncio
-import functools
-import os
-import subprocess
-import time
-import uuid
-from contextlib import asynccontextmanager
-from typing import NamedTuple
-from pkg_resources import resource_filename
-
-from multiaddr import Multiaddr, protocols
-
-from hivemind import find_open_port
-from hivemind.p2p.p2p_daemon_bindings.p2pclient import Client
-
-
-TIMEOUT_DURATION = 30  # seconds
-P2PD_PATH = resource_filename("hivemind", "hivemind_cli/p2pd")
-
-
-async def try_until_success(coro_func, timeout=TIMEOUT_DURATION):
-    """
-    Keep running ``coro_func`` until the time is out.
-    All arguments of ``coro_func`` should be filled, i.e. it should be called without arguments.
-    """
-    t_start = time.monotonic()
-    while True:
-        result = await coro_func()
-        if result:
-            break
-        if (time.monotonic() - t_start) >= timeout:
-            # timeout
-            assert False, f"{coro_func} still failed after `{timeout}` seconds"
-        await asyncio.sleep(0.01)
-
-
-class Daemon:
-    control_maddr = None
-    proc_daemon = None
-    log_filename = ""
-    f_log = None
-    closed = None
-
-    def __init__(
-            self, control_maddr, enable_control, enable_connmgr, enable_dht, enable_pubsub
-    ):
-        self.control_maddr = control_maddr
-        self.enable_control = enable_control
-        self.enable_connmgr = enable_connmgr
-        self.enable_dht = enable_dht
-        self.enable_pubsub = enable_pubsub
-        self.is_closed = False
-        self._start_logging()
-        self._run()
-
-    def _start_logging(self):
-        name_control_maddr = str(self.control_maddr).replace("/", "_").replace(".", "_")
-        self.log_filename = f"/tmp/log_p2pd{name_control_maddr}.txt"
-        self.f_log = open(self.log_filename, "wb")
-
-    def _run(self):
-        cmd_list = [P2PD_PATH, f"-listen={str(self.control_maddr)}"]
-        cmd_list += [f"-hostAddrs=/ip4/127.0.0.1/tcp/{find_open_port()}"]
-        if self.enable_connmgr:
-            cmd_list += ["-connManager=true", "-connLo=1", "-connHi=2", "-connGrace=0"]
-        if self.enable_dht:
-            cmd_list += ["-dht=true"]
-        if self.enable_pubsub:
-            cmd_list += ["-pubsub=true", "-pubsubRouter=gossipsub"]
-        self.proc_daemon = subprocess.Popen(
-            cmd_list, stdout=self.f_log, stderr=self.f_log, bufsize=0
-        )
-
-    async def wait_until_ready(self):
-        lines_head_pattern = (b"Control socket:", b"Peer ID:", b"Peer Addrs:")
-        lines_head_occurred = {line: False for line in lines_head_pattern}
-
-        with open(self.log_filename, "rb") as f_log_read:
-
-            async def read_from_daemon_and_check():
-                line = f_log_read.readline()
-                for head_pattern in lines_head_occurred:
-                    if line.startswith(head_pattern):
-                        lines_head_occurred[head_pattern] = True
-                return all([value for _, value in lines_head_occurred.items()])
-
-            await try_until_success(read_from_daemon_and_check)
-
-        # sleep for a while in case that the daemon haven't been ready after emitting these lines
-        await asyncio.sleep(0.1)
-
-    def close(self):
-        if self.is_closed:
-            return
-        self.proc_daemon.terminate()
-        self.proc_daemon.wait()
-        self.f_log.close()
-        self.is_closed = True
-
-
-class DaemonTuple(NamedTuple):
-    daemon: Daemon
-    client: Client
-
-
-class ConnectionFailure(Exception):
-    pass
-
-
-@asynccontextmanager
-async def make_p2pd_pair_unix(
-        enable_control, enable_connmgr, enable_dht, enable_pubsub
-):
-    name = str(uuid.uuid4())[:8]
-    control_maddr = Multiaddr(f"/unix/tmp/test_p2pd_control_{name}.sock")
-    listen_maddr = Multiaddr(f"/unix/tmp/test_p2pd_listen_{name}.sock")
-    # Remove the existing unix socket files if they are existing
-    try:
-        os.unlink(control_maddr.value_for_protocol(protocols.P_UNIX))
-    except FileNotFoundError:
-        pass
-    try:
-        os.unlink(listen_maddr.value_for_protocol(protocols.P_UNIX))
-    except FileNotFoundError:
-        pass
-    async with _make_p2pd_pair(
-            control_maddr=control_maddr,
-            listen_maddr=listen_maddr,
-            enable_control=enable_control,
-            enable_connmgr=enable_connmgr,
-            enable_dht=enable_dht,
-            enable_pubsub=enable_pubsub,
-    ) as pair:
-        yield pair
-
-
-@asynccontextmanager
-async def make_p2pd_pair_ip4(enable_control, enable_connmgr, enable_dht, enable_pubsub):
-    control_maddr = Multiaddr(f"/ip4/127.0.0.1/tcp/{find_open_port()}")
-    listen_maddr = Multiaddr(f"/ip4/127.0.0.1/tcp/{find_open_port()}")
-    async with _make_p2pd_pair(
-            control_maddr=control_maddr,
-            listen_maddr=listen_maddr,
-            enable_control=enable_control,
-            enable_connmgr=enable_connmgr,
-            enable_dht=enable_dht,
-            enable_pubsub=enable_pubsub,
-    ) as pair:
-        yield pair
-
-
-@asynccontextmanager
-async def _make_p2pd_pair(
-        control_maddr,
-        listen_maddr,
-        enable_control,
-        enable_connmgr,
-        enable_dht,
-        enable_pubsub,
-):
-    p2pd = Daemon(
-        control_maddr=control_maddr,
-        enable_control=enable_control,
-        enable_connmgr=enable_connmgr,
-        enable_dht=enable_dht,
-        enable_pubsub=enable_pubsub,
-    )
-    # wait for daemon ready
-    await p2pd.wait_until_ready()
-    client = Client(control_maddr=control_maddr, listen_maddr=listen_maddr)
-    try:
-        async with client.listen():
-            yield DaemonTuple(daemon=p2pd, client=client)
-    finally:
-        if not p2pd.is_closed:
-            p2pd.close()
-
-
-async def _check_connection(p2pd_tuple_0, p2pd_tuple_1):
-    peer_id_0, _ = await p2pd_tuple_0.identify()
-    peer_id_1, _ = await p2pd_tuple_1.identify()
-    peers_0 = [pinfo.peer_id for pinfo in await p2pd_tuple_0.list_peers()]
-    peers_1 = [pinfo.peer_id for pinfo in await p2pd_tuple_1.list_peers()]
-    return (peer_id_0 in peers_1) and (peer_id_1 in peers_0)
-
-
-async def connect_safe(p2pd_tuple_0, p2pd_tuple_1):
-    peer_id_1, maddrs_1 = await p2pd_tuple_1.identify()
-    await p2pd_tuple_0.connect(peer_id_1, maddrs_1)
-    await try_until_success(
-        functools.partial(
-            _check_connection, p2pd_tuple_0=p2pd_tuple_0, p2pd_tuple_1=p2pd_tuple_1
-        )
-    )

+ 86 - 0
tests/test_utils/dht_swarms.py

@@ -0,0 +1,86 @@
+import asyncio
+import multiprocessing as mp
+import random
+import signal
+import threading
+from typing import Dict, List, Tuple
+
+from multiaddr import Multiaddr
+
+from hivemind.dht.node import DHTID, Endpoint, DHTNode
+
+
+def run_node(initial_peers: List[Multiaddr], info_queue: mp.Queue):
+    if asyncio.get_event_loop().is_running():
+        asyncio.get_event_loop().stop()  # if we're in jupyter, get rid of its built-in event loop
+        asyncio.set_event_loop(asyncio.new_event_loop())
+    loop = asyncio.get_event_loop()
+
+    node = loop.run_until_complete(DHTNode.create(initial_peers=initial_peers, ping_n_attempts=10))
+    maddrs = loop.run_until_complete(node.get_visible_maddrs())
+
+    info_queue.put((node.node_id, node.endpoint, maddrs))
+
+    async def shutdown():
+        await node.shutdown()
+        loop.stop()
+
+    loop.add_signal_handler(signal.SIGTERM, lambda: loop.create_task(shutdown()))
+    loop.run_forever()
+
+
+def launch_swarm_in_separate_processes(n_peers: int, n_sequential_peers: int) -> \
+        Tuple[List[mp.Process], Dict[Endpoint, DHTID], List[List[Multiaddr]]]:
+    assert n_sequential_peers < n_peers, \
+        'n_sequential_peers must be less than n_peers: only the first peers are started sequentially, the rest in parallel'
+
+    processes = []
+    dht = {}
+    swarm_maddrs = []
+
+    info_queue = mp.Queue()
+    info_lock = mp.RLock()
+
+    for _ in range(n_sequential_peers):
+        initial_peers = random.choice(swarm_maddrs) if swarm_maddrs else []
+
+        proc = mp.Process(target=run_node, args=(initial_peers, info_queue), daemon=True)
+        proc.start()
+        processes.append(proc)
+
+        node_id, peer_endpoint, peer_maddrs = info_queue.get()
+        dht[peer_endpoint] = node_id
+        swarm_maddrs.append(peer_maddrs)
+
+    def collect_info():
+        while True:
+            node_id, peer_endpoint, peer_maddrs = info_queue.get()
+            with info_lock:
+                dht[peer_endpoint] = node_id
+                swarm_maddrs.append(peer_maddrs)
+
+                if len(dht) == n_peers:
+                    break
+
+    collect_thread = threading.Thread(target=collect_info)
+    collect_thread.start()
+
+    for _ in range(n_peers - n_sequential_peers):
+        with info_lock:
+            initial_peers = random.choice(swarm_maddrs)
+
+        proc = mp.Process(target=run_node, args=(initial_peers, info_queue), daemon=True)
+        proc.start()
+        processes.append(proc)
+
+    collect_thread.join()
+
+    return processes, dht, swarm_maddrs
+
+
+async def launch_star_shaped_swarm(n_peers: int, **kwargs) -> List[DHTNode]:
+    nodes = [await DHTNode.create(**kwargs)]
+    initial_peers = await nodes[0].get_visible_maddrs()
+    nodes += await asyncio.gather(*[DHTNode.create(initial_peers=initial_peers, **kwargs)
+                                    for _ in range(n_peers - 1)])
+    return nodes
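
A hypothetical test showing how `launch_star_shaped_swarm` is meant to be used (the test name, key, and value are made up; the calls mirror those in tests/test_dht_node.py):

```python
import asyncio

import pytest

import hivemind
from test_utils.dht_swarms import launch_star_shaped_swarm


@pytest.mark.forked
@pytest.mark.asyncio
async def test_star_swarm_store_get():
    # All peers bootstrap from the first node's visible multiaddrs.
    peers = await launch_star_shaped_swarm(n_peers=5, parallel_rpc=16)

    writer, reader = peers[0], peers[-1]
    assert await writer.store('demo_key', 'demo_value', hivemind.get_dht_time() + 60)

    found = await reader.get('demo_key', latest=True)
    assert found is not None and found.value == 'demo_value'

    await asyncio.wait([node.shutdown() for node in peers])
```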

+ 194 - 0
tests/test_utils/p2p_daemon.py

@@ -0,0 +1,194 @@
+import asyncio
+import functools
+import os
+import subprocess
+import time
+import uuid
+from contextlib import asynccontextmanager
+from typing import NamedTuple
+from pkg_resources import resource_filename
+
+from multiaddr import Multiaddr, protocols
+
+from hivemind import find_open_port
+from hivemind.p2p.p2p_daemon_bindings.p2pclient import Client
+
+
+TIMEOUT_DURATION = 30  # seconds
+P2PD_PATH = resource_filename("hivemind", "hivemind_cli/p2pd")
+
+
+async def try_until_success(coro_func, timeout=TIMEOUT_DURATION):
+    """
+    Keep calling ``coro_func`` until it returns a truthy value or ``timeout`` seconds elapse (then fail).
+    All arguments of ``coro_func`` should be bound in advance, i.e. it is called without arguments.
+    """
+    t_start = time.monotonic()
+    while True:
+        result = await coro_func()
+        if result:
+            break
+        if (time.monotonic() - t_start) >= timeout:
+            # timeout
+            assert False, f"{coro_func} still failed after `{timeout}` seconds"
+        await asyncio.sleep(0.01)
+
+
+class Daemon:
+    control_maddr = None
+    proc_daemon = None
+    log_filename = ""
+    f_log = None
+    closed = None
+
+    def __init__(
+            self, control_maddr, enable_control, enable_connmgr, enable_dht, enable_pubsub
+    ):
+        self.control_maddr = control_maddr
+        self.enable_control = enable_control
+        self.enable_connmgr = enable_connmgr
+        self.enable_dht = enable_dht
+        self.enable_pubsub = enable_pubsub
+        self.is_closed = False
+        self._start_logging()
+        self._run()
+
+    def _start_logging(self):
+        name_control_maddr = str(self.control_maddr).replace("/", "_").replace(".", "_")
+        self.log_filename = f"/tmp/log_p2pd{name_control_maddr}.txt"
+        self.f_log = open(self.log_filename, "wb")
+
+    def _run(self):
+        cmd_list = [P2PD_PATH, f"-listen={str(self.control_maddr)}"]
+        cmd_list += [f"-hostAddrs=/ip4/127.0.0.1/tcp/{find_open_port()}"]
+        if self.enable_connmgr:
+            cmd_list += ["-connManager=true", "-connLo=1", "-connHi=2", "-connGrace=0"]
+        if self.enable_dht:
+            cmd_list += ["-dht=true"]
+        if self.enable_pubsub:
+            cmd_list += ["-pubsub=true", "-pubsubRouter=gossipsub"]
+        self.proc_daemon = subprocess.Popen(
+            cmd_list, stdout=self.f_log, stderr=self.f_log, bufsize=0
+        )
+
+    async def wait_until_ready(self):
+        lines_head_pattern = (b"Control socket:", b"Peer ID:", b"Peer Addrs:")
+        lines_head_occurred = {line: False for line in lines_head_pattern}
+
+        with open(self.log_filename, "rb") as f_log_read:
+
+            async def read_from_daemon_and_check():
+                line = f_log_read.readline()
+                for head_pattern in lines_head_occurred:
+                    if line.startswith(head_pattern):
+                        lines_head_occurred[head_pattern] = True
+                return all([value for _, value in lines_head_occurred.items()])
+
+            await try_until_success(read_from_daemon_and_check)
+
+        # sleep for a while in case the daemon isn't fully ready right after emitting these lines
+        await asyncio.sleep(0.1)
+
+    def close(self):
+        if self.is_closed:
+            return
+        self.proc_daemon.terminate()
+        self.proc_daemon.wait()
+        self.f_log.close()
+        self.is_closed = True
+
+
+class DaemonTuple(NamedTuple):
+    daemon: Daemon
+    client: Client
+
+
+class ConnectionFailure(Exception):
+    pass
+
+
+@asynccontextmanager
+async def make_p2pd_pair_unix(
+        enable_control, enable_connmgr, enable_dht, enable_pubsub
+):
+    name = str(uuid.uuid4())[:8]
+    control_maddr = Multiaddr(f"/unix/tmp/test_p2pd_control_{name}.sock")
+    listen_maddr = Multiaddr(f"/unix/tmp/test_p2pd_listen_{name}.sock")
+    # Remove the unix socket files if they already exist
+    try:
+        os.unlink(control_maddr.value_for_protocol(protocols.P_UNIX))
+    except FileNotFoundError:
+        pass
+    try:
+        os.unlink(listen_maddr.value_for_protocol(protocols.P_UNIX))
+    except FileNotFoundError:
+        pass
+    async with _make_p2pd_pair(
+            control_maddr=control_maddr,
+            listen_maddr=listen_maddr,
+            enable_control=enable_control,
+            enable_connmgr=enable_connmgr,
+            enable_dht=enable_dht,
+            enable_pubsub=enable_pubsub,
+    ) as pair:
+        yield pair
+
+
+@asynccontextmanager
+async def make_p2pd_pair_ip4(enable_control, enable_connmgr, enable_dht, enable_pubsub):
+    control_maddr = Multiaddr(f"/ip4/127.0.0.1/tcp/{find_open_port()}")
+    listen_maddr = Multiaddr(f"/ip4/127.0.0.1/tcp/{find_open_port()}")
+    async with _make_p2pd_pair(
+            control_maddr=control_maddr,
+            listen_maddr=listen_maddr,
+            enable_control=enable_control,
+            enable_connmgr=enable_connmgr,
+            enable_dht=enable_dht,
+            enable_pubsub=enable_pubsub,
+    ) as pair:
+        yield pair
+
+
+@asynccontextmanager
+async def _make_p2pd_pair(
+        control_maddr,
+        listen_maddr,
+        enable_control,
+        enable_connmgr,
+        enable_dht,
+        enable_pubsub,
+):
+    p2pd = Daemon(
+        control_maddr=control_maddr,
+        enable_control=enable_control,
+        enable_connmgr=enable_connmgr,
+        enable_dht=enable_dht,
+        enable_pubsub=enable_pubsub,
+    )
+    # wait for daemon ready
+    await p2pd.wait_until_ready()
+    client = Client(control_maddr=control_maddr, listen_maddr=listen_maddr)
+    try:
+        async with client.listen():
+            yield DaemonTuple(daemon=p2pd, client=client)
+    finally:
+        if not p2pd.is_closed:
+            p2pd.close()
+
+
+async def _check_connection(p2pd_tuple_0, p2pd_tuple_1):
+    peer_id_0, _ = await p2pd_tuple_0.identify()
+    peer_id_1, _ = await p2pd_tuple_1.identify()
+    peers_0 = [pinfo.peer_id for pinfo in await p2pd_tuple_0.list_peers()]
+    peers_1 = [pinfo.peer_id for pinfo in await p2pd_tuple_1.list_peers()]
+    return (peer_id_0 in peers_1) and (peer_id_1 in peers_0)
+
+
+async def connect_safe(p2pd_tuple_0, p2pd_tuple_1):
+    peer_id_1, maddrs_1 = await p2pd_tuple_1.identify()
+    await p2pd_tuple_0.connect(peer_id_1, maddrs_1)
+    await try_until_success(
+        functools.partial(
+            _check_connection, p2pd_tuple_0=p2pd_tuple_0, p2pd_tuple_1=p2pd_tuple_1
+        )
+    )