
Merge branch 'master' into decentralized_lr_scheduler

justheuristic 4 years ago
parent
commit
d1d1627578
47 changed files with 3525 additions and 492 deletions
  1. +0 -70    .circleci/config.yml
  2. +92 -0    .github/workflows/run-tests.yml
  3. +3 -0     .gitignore
  4. +13 -3    README.md
  5. +23 -13   benchmarks/benchmark_averaging.py
  6. +8 -8     benchmarks/benchmark_dht.py
  7. +5 -1     benchmarks/benchmark_tensor_compression.py
  8. +20 -16   benchmarks/benchmark_throughput.py
  9. +12 -0    codecov.yml
  10. +1 -1    docs/modules/client.rst
  11. +18 -12  examples/albert/README.md
  12. +11 -5   examples/albert/run_first_peer.py
  13. +7 -6    examples/albert/run_trainer.py
  14. +6 -4    examples/albert/tokenize_wikitext103.py
  15. +2 -1    hivemind/__init__.py
  16. +107 -55 hivemind/client/averaging/__init__.py
  17. +166 -189 hivemind/client/averaging/allreduce.py
  18. +1 -0    hivemind/client/averaging/load_balancing.py
  19. +1 -1    hivemind/client/averaging/matchmaking.py
  20. +224 -0  hivemind/client/averaging/partition.py
  21. +37 -3   hivemind/optim/collaborative.py
  22. +1 -0    hivemind/p2p/__init__.py
  23. +377 -0  hivemind/p2p/p2p_daemon.py
  24. +0 -0    hivemind/p2p/p2p_daemon_bindings/__init__.py
  25. +210 -0  hivemind/p2p/p2p_daemon_bindings/control.py
  26. +170 -0  hivemind/p2p/p2p_daemon_bindings/datastructures.py
  27. +85 -0   hivemind/p2p/p2p_daemon_bindings/p2pclient.py
  28. +73 -0   hivemind/p2p/p2p_daemon_bindings/utils.py
  29. +1 -1    hivemind/proto/averaging.proto
  30. +166 -0  hivemind/proto/p2pd.proto
  31. +1 -1    hivemind/server/runtime.py
  32. +49 -1   hivemind/utils/asyncio.py
  33. +14 -1   hivemind/utils/compression.py
  34. +5 -1    hivemind/utils/grpc.py
  35. +1 -1    hivemind/utils/threading.py
  36. +1 -0    requirements-dev.txt
  37. +2 -0    requirements.txt
  38. +80 -12  setup.py
  39. +217 -0  tests/test_allreduce.py
  40. +2 -2    tests/test_auth.py
  41. +76 -77  tests/test_averaging.py
  42. +2 -4    tests/test_dht_schema.py
  43. +1 -2    tests/test_dht_validation.py
  44. +440 -0  tests/test_p2p_daemon.py
  45. +559 -0  tests/test_p2p_daemon_bindings.py
  46. +41 -1   tests/test_util_modules.py
  47. +194 -0  tests/test_utils/__init__.py

+ 0 - 70
.circleci/config.yml

@@ -1,70 +0,0 @@
-version: 2.1
-
-jobs:
-  build-and-test-py37:
-    docker:
-      - image: circleci/python:3.7.10
-    steps:
-      - checkout
-      - restore_cache:
-          keys:
-            - py37-v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
-      - run: pip install -r requirements.txt
-      - run: pip install -r requirements-dev.txt
-      - save_cache:
-          key: py37-v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
-          paths:
-            - '~/.cache/pip'
-      - run:
-          command: pip install -e .
-          name: setup
-      - run:
-          command: pytest ./tests
-          name: tests
-  build-and-test-py38:
-    docker:
-      - image: circleci/python:3.8.1
-    steps:
-      - checkout
-      - restore_cache:
-          keys:
-            - py38-v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
-      - run: pip install -r requirements.txt
-      - run: pip install -r requirements-dev.txt
-      - save_cache:
-          key: py38-v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
-          paths:
-            - '~/.cache/pip'
-      - run:
-          command: pip install -e .
-          name: setup
-      - run:
-          command: pytest ./tests
-          name: tests
-  build-and-test-py39:
-    docker:
-      - image: circleci/python:3.9.1
-    steps:
-      - checkout
-      - restore_cache:
-          keys:
-            - py39-v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
-      - run: pip install -r requirements.txt
-      - run: pip install -r requirements-dev.txt
-      - save_cache:
-          key: py39-v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
-          paths:
-            - '~/.cache/pip'
-      - run:
-          command: pip install -e .
-          name: setup
-      - run:
-          command: pytest ./tests
-          name: tests
-
-workflows:
-  main:
-    jobs:
-      - build-and-test-py37
-      - build-and-test-py38
-      - build-and-test-py39

+ 92 - 0
.github/workflows/run-tests.yml

@@ -0,0 +1,92 @@
+name: Tests
+
+on: [ push ]
+
+
+jobs:
+  run_tests:
+
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: [ 3.7, 3.8, 3.9 ]
+    timeout-minutes: 10
+    steps:
+      - uses: actions/checkout@v2
+      - name: Set up Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Cache dependencies
+        uses: actions/cache@v2
+        with:
+          path: ~/.cache/pip
+          key: Key-v1-${{ matrix.python-version }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('requirements-dev.txt') }}
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -r requirements.txt
+          pip install -r requirements-dev.txt
+      - name: Build hivemind
+        run: |
+          pip install .
+      - name: Test
+        run: |
+          cd tests
+          pytest --durations=0 --durations-min=1.0
+
+  build_and_test_p2pd:
+    runs-on: ubuntu-latest
+    timeout-minutes: 10
+    steps:
+      - uses: actions/checkout@v2
+      - name: Set up Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: '3.8'
+      - name: Cache dependencies
+        uses: actions/cache@v2
+        with:
+          path: ~/.cache/pip
+          key: Key-v1-3.8-${{ hashFiles('requirements.txt') }}-${{ hashFiles('requirements-dev.txt') }}
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -r requirements.txt
+          pip install -r requirements-dev.txt
+      - name: Build hivemind
+        run: |
+          pip install . --global-option=build_py --global-option="--buildgo"
+      - name: Test
+        run: |
+          cd tests
+          pytest -k "p2p" 
+
+  codecov_in_develop_mode:
+
+    runs-on: ubuntu-latest
+    timeout-minutes: 10
+    steps:
+      - uses: actions/checkout@v2
+      - name: Set up Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: '3.8'
+      - name: Cache dependencies
+        uses: actions/cache@v2
+        with:
+          path: ~/.cache/pip
+          key: Key-v1-3.8-${{ hashFiles('requirements.txt') }}-${{ hashFiles('requirements-dev.txt') }}
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -r requirements.txt
+          pip install -r requirements-dev.txt
+      - name: Build hivemind
+        run: |
+          pip install -e .
+      - name: Test
+        run: |
+          pytest --cov=hivemind tests
+      - name: Upload coverage to Codecov
+        uses: codecov/codecov-action@v1

+ 3 - 0
.gitignore

@@ -78,3 +78,6 @@ debian/files
 
 # protobuf stuff
 hivemind/proto/*_pb2*
+
+# libp2p-daemon binary
+hivemind/hivemind_cli/p2pd

+ 13 - 3
README.md

@@ -1,6 +1,6 @@
 ## Hivemind: decentralized deep learning in PyTorch
 
-[![Build status](https://circleci.com/gh/learning-at-home/hivemind.svg?style=shield)](https://circleci.com/gh/learning-at-home/hivemind)
+[![CI status](https://github.com/learning-at-home/hivemind/actions/workflows/run-tests.yml/badge.svg?branch=master)](https://github.com/learning-at-home/hivemind/actions)
 [![Documentation Status](https://readthedocs.org/projects/learning-at-home/badge/?version=latest)](https://learning-at-home.readthedocs.io/en/latest/?badge=latest)
 [![Gitter](https://badges.gitter.im/learning-at-home/hivemind.svg)](https://gitter.im/learning-at-home/hivemind?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge)
 
@@ -76,8 +76,18 @@ our [guide](https://learning-at-home.readthedocs.io/en/latest/user/contributing.
 
 ## Citation
 
-If you found hivemind useful for your experiments, you can cite [the paper](https://arxiv.org/abs/2002.04013) that
-inspired it:
+If you found hivemind or its underlying algorithms useful for your experiments, please cite the following source:
+
+```
+@misc{hivemind,
+  author = {Learning@home team},
+  title = {{H}ivemind: a {L}ibrary for {D}ecentralized {D}eep {L}earning},
+  year = 2020,
+  howpublished = {\url{https://github.com/learning-at-home/hivemind}},
+}
+```
+
+Also, you can cite [the paper](https://arxiv.org/abs/2002.04013) that inspired the creation of this library:
 
 ```
 @inproceedings{ryabinin2020crowdsourced,

+ 23 - 13
benchmarks/benchmark_averaging.py

@@ -6,10 +6,13 @@ import argparse
 import torch
 
 import hivemind
-from hivemind.utils import LOCALHOST, increase_file_limit
+from hivemind.utils import LOCALHOST, increase_file_limit, get_logger
 from hivemind.proto import runtime_pb2
 
 
+logger = get_logger(__name__)
+
+
 def sample_tensors(hid_size, num_layers):
     tensors = []
     for i in range(num_layers):
@@ -38,8 +41,11 @@ def benchmark_averaging(num_peers: int, target_group_size: int, num_rounds: int,
     peer_tensors = [sample_tensors(hid_size, num_layers)
                     for _ in range(num_peers)]
     processes = {dht_root}
+    lock_stats = threading.Lock()
+    successful_steps = total_steps = 0
 
     def run_averager(index):
+        nonlocal successful_steps, total_steps, lock_stats
         dht = hivemind.DHT(listen_on=f'{LOCALHOST}:*',
                            initial_peers=[f"{LOCALHOST}:{dht_root.port}"],
                            start=True)
@@ -50,11 +56,17 @@ def benchmark_averaging(num_peers: int, target_group_size: int, num_rounds: int,
             averaging_expiration=averaging_expiration, request_timeout=request_timeout, start=True)
         processes.update({dht, averager})
 
-        print(end=f'<started {index}>\n', flush=True)
-        for _ in range(num_rounds):
-            success = averager.step(timeout=round_timeout)
-            print(end=('+' if success else '-'), flush=True)
-        print(end=f'<finished {index}>\n', flush=True)
+        logger.info(f'Averager {index}: started on endpoint {averager.endpoint}, group_bits: {averager.get_group_bits()}')
+        for step in range(num_rounds):
+            try:
+                success = averager.step(timeout=round_timeout) is not None
+            except:
+                success = False
+            with lock_stats:
+                successful_steps += int(success)
+                total_steps += 1
+            logger.info(f"Averager {index}: {'finished' if success else 'failed'} step {step}")
+        logger.info(f"Averager {index}: done.")
 
     threads = []
     for i in range(num_peers):
@@ -67,10 +79,8 @@ def benchmark_averaging(num_peers: int, target_group_size: int, num_rounds: int,
     for thread in threads:
         thread.join()
 
-    print(f"\ntest run took {time.time() - t:.3f} seconds")
-
-    for process in processes:
-        process.terminate()
+    logger.info(f"Benchmark finished in {time.time() - t:.3f} seconds.")
+    logger.info(f"Success rate: {successful_steps / total_steps} ({successful_steps} out of {total_steps} attempts)")
 
 
 if __name__ == "__main__":
@@ -80,9 +90,9 @@ if __name__ == "__main__":
     parser.add_argument('--num_rounds', type=int, default=5, required=False)
     parser.add_argument('--hid_size', type=int, default=256, required=False)
     parser.add_argument('--num_layers', type=int, default=3, required=False)
-    parser.add_argument('--averaging_expiration', type=float, default=15, required=False)
-    parser.add_argument('--round_timeout', type=float, default=30, required=False)
-    parser.add_argument('--request_timeout', type=float, default=3, required=False)
+    parser.add_argument('--averaging_expiration', type=float, default=5, required=False)
+    parser.add_argument('--round_timeout', type=float, default=15, required=False)
+    parser.add_argument('--request_timeout', type=float, default=1, required=False)
     parser.add_argument('--spawn_dtime', type=float, default=0.1, required=False)
     parser.add_argument('--increase_file_limit', action="store_true")
     args = vars(parser.parse_args())
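
The benchmark above replaces the `+`/`-` progress markers with a success-rate summary guarded by a `threading.Lock`, since all averager threads update the same counters. A minimal standalone sketch of that counter pattern, with a placeholder in place of the real `averager.step` call:

```python
import threading

lock_stats = threading.Lock()
successful_steps = total_steps = 0

def run_worker(index: int, num_rounds: int):
    global successful_steps, total_steps
    for step in range(num_rounds):
        success = (index + step) % 3 != 0  # placeholder for `averager.step(timeout=...) is not None`
        with lock_stats:  # counters are shared across threads, so guard every update
            successful_steps += int(success)
            total_steps += 1

threads = [threading.Thread(target=run_worker, args=(i, 5)) for i in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(f"Success rate: {successful_steps / total_steps} ({successful_steps} out of {total_steps} attempts)")
```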

+ 8 - 8
benchmarks/benchmark_dht.py

@@ -20,7 +20,7 @@ def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_b
                   wait_after_request: float, wait_before_read: float, wait_timeout: float, expiration: float):
     random.seed(random_seed)
 
-    print("Creating peers...")
+    logger.info("Creating peers...")
     peers = []
     for _ in trange(num_peers):
         neighbors = [f'0.0.0.0:{node.port}' for node in random.sample(peers, min(initial_peers, len(peers)))]
@@ -32,10 +32,10 @@ def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_b
 
     expert_uids = list(set(f"expert.{random.randint(0, 999)}.{random.randint(0, 999)}.{random.randint(0, 999)}"
                            for _ in range(num_experts)))
-    print(f"Sampled {len(expert_uids)} unique ids (after deduplication)")
+    logger.info(f"Sampled {len(expert_uids)} unique ids (after deduplication)")
     random.shuffle(expert_uids)
 
-    print(f"Storing experts to dht in batches of {expert_batch_size}...")
+    logger.info(f"Storing experts to dht in batches of {expert_batch_size}...")
     successful_stores = total_stores = total_store_time = 0
     benchmark_started = time.perf_counter()
     endpoints = []
@@ -52,8 +52,8 @@ def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_b
         successful_stores += sum(successes)
         time.sleep(wait_after_request)
 
-    print(f"Store success rate: {successful_stores / total_stores * 100:.1f}% ({successful_stores} / {total_stores})")
-    print(f"Mean store time: {total_store_time / total_stores:.5}, Total: {total_store_time:.5}")
+    logger.info(f"Store success rate: {successful_stores / total_stores * 100:.1f}% ({successful_stores} / {total_stores})")
+    logger.info(f"Mean store time: {total_store_time / total_stores:.5}, Total: {total_store_time:.5}")
     time.sleep(wait_before_read)
 
     if time.perf_counter() - benchmark_started > expiration:
@@ -74,11 +74,11 @@ def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_b
     if time.perf_counter() - benchmark_started > expiration:
         logger.warning("keys expired midway during get requests. If that isn't desired, increase expiration_time param")
 
-    print(f"Get success rate: {successful_gets / len(expert_uids) * 100:.1f} ({successful_gets} / {len(expert_uids)})")
-    print(f"Mean get time: {total_get_time / len(expert_uids):.5f}, Total: {total_get_time:.5f}")
+    logger.info(f"Get success rate: {successful_gets / len(expert_uids) * 100:.1f} ({successful_gets} / {len(expert_uids)})")
+    logger.info(f"Mean get time: {total_get_time / len(expert_uids):.5f}, Total: {total_get_time:.5f}")
 
     alive_peers = [peer.is_alive() for peer in peers]
-    print(f"Node survival rate: {len(alive_peers) / len(peers) * 100:.3f}%")
+    logger.info(f"Node survival rate: {len(alive_peers) / len(peers) * 100:.3f}%")
 
 
 if __name__ == "__main__":

+ 5 - 1
benchmarks/benchmark_tensor_compression.py

@@ -5,6 +5,10 @@ import torch
 
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.utils.logging import get_logger
+
+
+logger = get_logger(__name__)
 
 
 def benchmark_compression(tensor: torch.Tensor, compression_type: CompressionType) -> float:
@@ -29,4 +33,4 @@ if __name__ == "__main__":
         for i in range(args.num_iters):
             tm += benchmark_compression(X, compression_type)
         tm /= args.num_iters
-        print(f"Compression type: {name}, time: {tm}")
+        logger.info(f"Compression type: {name}, time: {tm}")

+ 20 - 16
benchmarks/benchmark_throughput.py

@@ -10,19 +10,23 @@ import hivemind
 from hivemind import find_open_port
 from hivemind.server import layers
 from hivemind.utils.threading import increase_file_limit
+from hivemind.utils.logging import get_logger
+
+
+logger = get_logger(__name__)
 
 
 def print_device_info(device=None):
     """Prints device stats. Code from https://stackoverflow.com/a/53374933/12891528"""
     device = torch.device(device or ('cuda' if torch.cuda.is_available() else 'cpu'))
-    print('Using device:', device)
+    logger.info(f'Using device: {device}')
 
     # Additional Info when using cuda
     if device.type == 'cuda':
-        print(torch.cuda.get_device_name(0))
-        print('Memory Usage:')
-        print('Allocated:', round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1), 'GB')
-        print('Cached:   ', round(torch.cuda.memory_cached(0) / 1024 ** 3, 1), 'GB')
+        logger.info(torch.cuda.get_device_name(0))
+        logger.info(f'Memory Usage:')
+        logger.info(f'Allocated: {round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1)} GB')
+        logger.info(f'Cached:   {round(torch.cuda.memory_cached(0) / 1024 ** 3, 1)} GB')
 
 
 def client_process(can_start, benchmarking_failed, port, num_experts, batch_size, hid_dim, num_batches, backprop=True):
@@ -111,25 +115,25 @@ def benchmark_throughput(num_experts=16, num_handlers=None, num_clients=128, num
         abs(timestamps[key2] - timestamps[key1]) if (key1 in timestamps and key2 in timestamps) else float('nan')
     total_examples = batch_size * num_clients * num_batches_per_client
 
-    print('\n' * 3)
-    print("Benchmark finished, status:" + ["Success", "Failure"][benchmarking_failed.is_set()])
-    print(f"Server parameters: num_experts={num_experts}, num_handlers={num_handlers}, max_batch_size={max_batch_size},"
+    logger.info("Benchmark finished, status:" + ["Success", "Failure"][benchmarking_failed.is_set()])
+    logger.info(f"Server parameters: num_experts={num_experts}, num_handlers={num_handlers}, max_batch_size={max_batch_size},"
           f" expert_cls={expert_cls}, hid_dim={hid_dim}, device={device}")
-    print(f"Client parameters: num_clients={num_clients}, num_batches_per_client={num_batches_per_client}, "
+    logger.info(f"Client parameters: num_clients={num_clients}, num_batches_per_client={num_batches_per_client}, "
           f"batch_size={batch_size}, backprop={backprop}")
 
-    print("Results: ")
-    print(f"\tServer startup took {time_between('began_launching_server', 'server_ready') :.3f} s. "
+    logger.info("Results: ")
+    logger.info(f"\tServer startup took {time_between('began_launching_server', 'server_ready') :.3f} s. "
           f"({time_between('began_launching_server', 'created_experts') :.3f} s. experts + "
           f"{time_between('created_experts', 'server_ready') :.3f} s. networking)")
-    print(f"\tProcessed {total_examples} examples in {time_between('server_ready', 'clients_finished') :.3f}")
-    print(f"\tThroughput for {'forward + backward' if backprop else 'forward'} passes: "
+    logger.info(f"\tProcessed {total_examples} examples in {time_between('server_ready', 'clients_finished') :.3f}")
+    logger.info(f"\tThroughput for {'forward + backward' if backprop else 'forward'} passes: "
           f"{total_examples / time_between('server_ready', 'clients_finished') :.3f} samples / s.")
-    print(f"\tBenchmarking took {time_between('started', 'server_shutdown_finished') :.3f} s.")
+    logger.info(f"\tBenchmarking took {time_between('started', 'server_shutdown_finished') :.3f} s.")
     if benchmarking_failed.is_set():
-        print("Note: benchmark code failed, timing/memory results only indicate time till failure!")
+        logger.info("Note: benchmark code failed, timing/memory results only indicate time till failure!")
     print_device_info(device)
-    print(flush=True)
+    sys.stdout.flush()
+    sys.stderr.flush()
 
     assert not benchmarking_failed.is_set()
 

+ 12 - 0
codecov.yml

@@ -0,0 +1,12 @@
+comment:
+  layout: "diff, files"
+  behavior: default
+  require_changes: true
+coverage:
+  status:
+    patch:
+      default:
+        informational: true
+    project:
+      default:
+        threshold: 1%

+ 1 - 1
docs/modules/client.rst

@@ -25,4 +25,4 @@
 .. autoclass:: DecentralizedAverager
    :members:
    :member-order: bysource
-   :exclude-members: get_tensors, get_tensors_async, update_tensors, rpc_join_group, rpc_aggregate_part
+   :exclude-members: get_tensors, get_tensors_async, update_tensors, rpc_join_group, rpc_aggregate_part, register_allreduce_group

+ 18 - 12
examples/albert/README.md

@@ -12,21 +12,24 @@ This tutorial will walk you through the steps to set up collaborative training w
 ## Running an experiment
 - Run the first DHT peer to welcome trainers and record training statistics (e.g. loss, performance):
    - In this example, we use [wandb.ai](https://wandb.ai/site) to plot training metrics; If you're unfamiliar with Weights & Biases, here's a [quickstart tutorial](https://docs.wandb.ai/quickstart).
-   - Run `python run_first_peer.py --listen_on '[::]:*' --experiment_prefix NAME_YOUR_EXPERIMENT --wandb_project WANDB_PROJECT_HERE`
+   - Run `python run_first_peer.py --dht_listen_on '[::]:*' --experiment_prefix NAME_YOUR_EXPERIMENT --wandb_project WANDB_PROJECT_HERE`
    - `NAME_YOUR_EXPERIMENT` must be a unique name of this training run, e.g. `my-first-albert`. It cannot contain `.` due to naming conventions.
    - `WANDB_PROJECT_HERE` is a name of wandb project used to track training metrics. Multiple experiments can have the same project name.
    - This peer will run a DHT node on a certain IP/port (`Running DHT root at ...`). You will need this address for next steps
 ```
-+ python ./run_first_peer.py --listen_on '[::]:31209' --experiment_prefix ysda_albert_v10 --wandb_project Demo-run
-[2021/04/19 02:30:06.051][WARN][root.<module>:36] No address specified. Attempting to infer address from DNS.
-[2021/04/19 02:30:06.088][INFO][root.<module>:44] Running DHT root at 18.217.13.97:31209
-wandb: Currently logged in as: ??? (use `wandb login --relogin` to force relogin)
-wandb: Tracking run with wandb version 0.10.26
-wandb: Syncing run wandering-sky-58
-wandb: ⭐ View project at https://wandb.ai/yhn112/Demo-run
-wandb: 🚀 View run at https://wandb.ai/yhn112/Demo-run/runs/38ygvt3n
-wandb: Run data is saved locally in /home/hivemind/examples/albert/wandb/run-20210419_023006-38ygvt3n
++ python run_first_peer.py --dht_listen_on '[::]:*' --experiment_prefix my-albert-v1 --wandb_project Demo-run
+[2021/06/17 16:26:35.931][WARN][root.<module>:140] No address specified. Attempting to infer address from DNS.
+[2021/06/17 16:26:36.083][INFO][root.<module>:149] Running DHT root at 193.106.95.184:38319
+wandb: Currently logged in as: XXX (use `wandb login --relogin` to force relogin)
+wandb: Tracking run with wandb version 0.10.32
+wandb: Syncing run dry-mountain-2
+wandb:  View project at https://wandb.ai/XXX/Demo-run
+wandb:  View run at https://wandb.ai/XXX/Demo-run/runs/YYY
+wandb: Run data is saved locally in /path/to/run/data
 wandb: Run `wandb offline` to turn off syncing.
+[2021/04/19 02:26:41.064][INFO][optim.collaborative.fetch_collaboration_state:323] Found no active peers: None
+[2021/04/19 02:26:44.068][INFO][optim.collaborative.fetch_collaboration_state:323] Found no active peers: None
+...
 [2021/04/19 02:37:37.246][INFO][root.<module>:74] 11.05164
 [2021/04/19 02:39:37.441][INFO][root.<module>:74] 11.03771
 [2021/04/19 02:40:37.541][INFO][root.<module>:74] 11.02886
@@ -37,7 +40,7 @@ wandb: Run `wandb offline` to turn off syncing.
   - if necessary, specify paths: `--dataset_path ./path/to/unpacked/data --tokenizer ./path/to/tokenizer/config` (see [default paths](https://github.com/learning-at-home/hivemind/blob/collaborative_albert_example/examples/albert/run_trainer.py#L63-L69) for reference)
   - run:
 ```shell
- CUDA_VISIBLE_DEVICES=0 HIVEMIND_THREADS=64 python ./hivemind/examples/albert/run_trainer.py \
+HIVEMIND_THREADS=64 python run_trainer.py \
  --experiment_prefix SAME_AS_IN_RUN_FIRST_PEER --initial_peers ONE_OR_MORE_PEERS --seed 42 \
  --logging_first_step --logging_steps 100  --output_dir ./outputs --overwrite_output_dir --logging_dir ./logs
 ```
@@ -45,11 +48,14 @@ Here, `ONE_OR_MORE_PEERS` stands for either your coordinator endpoint (e.g. `123
 
 As the peer begins training, it will periodically report training logs in the following form:
 ```
-{'loss': 4.3577, 'learning_rate': 0.001318944, 'epoch': 0.0}
 [...][INFO][...] Collaboration accumulated 448 samples from 17 peers; ETA 18.88 seconds (refresh in 15.73s.)
 [...][INFO][...] Collaboration accumulated 4096 samples from 16 peers; ETA 0.00 seconds (refresh in 0.50s.)
 [...][INFO][optim.collaborative.step:195] Averaged tensors successfully with 17 peers
 [...][INFO][optim.collaborative.step:211] Optimizer step: done!
+06/17/2021 18:58:23 - INFO - __main__ -   Step 0
+06/17/2021 18:58:23 - INFO - __main__ -   Your current contribution: 892 samples
+06/17/2021 18:58:23 - INFO - __main__ -   Local loss: 11.023
+
 ```
 
 __Sanity check:__ a healthy peer will periodically report `Averaged tensors successfully with [N > 1]` peers.

+ 11 - 5
examples/albert/run_first_peer.py

@@ -17,7 +17,6 @@ import hivemind
 from hivemind.utils.logging import get_logger
 import metrics_utils
 
-
 logger = get_logger(__name__)
 
 
@@ -163,6 +162,10 @@ if __name__ == '__main__':
                        for peer in metrics_dict]
             latest_step = max(item.step for item in metrics)
             if latest_step != current_step:
+                logger.debug(f"Got metrics from {len(metrics)} peers")
+
+                for i, metrics_for_peer in enumerate(metrics):
+                    logger.debug(f"{i} peer {metrics_for_peer}")
                 current_step = latest_step
                 alive_peers = 0
                 num_batches = 0
@@ -176,17 +179,20 @@ if __name__ == '__main__':
                     sum_perf += item.samples_per_second
                     num_samples += item.samples_accumulated
                     sum_mini_steps += item.mini_steps
+                current_loss = sum_loss / sum_mini_steps
+
                 if coordinator_args.wandb_project is not None:
                     wandb.log({
-                        "loss": sum_loss / sum_mini_steps,
+                        "loss": current_loss,
                         "alive peers": alive_peers,
                         "samples": num_samples,
-                        "performance": sum_perf
+                        "performance": sum_perf,
+                        "step": latest_step
                     })
                 if checkpoint_handler.is_time_to_save_state(current_step):
                     checkpoint_handler.save_state(current_step)
                     if checkpoint_handler.is_time_to_upload():
-                        checkpoint_handler.upload_checkpoint(sum_loss / sum_mini_steps)
-                logger.info(f"Step #{current_step}\tloss = {sum_loss / alive_peers:.5f}")
+                        checkpoint_handler.upload_checkpoint(current_loss)
+                logger.info(f"Step #{current_step}\tloss = {current_loss:.5f}")
         logger.debug("Peer is still alive...")
         time.sleep(coordinator_args.refresh_period)
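
The coordinator change above logs `sum_loss / sum_mini_steps` (instead of dividing by the number of alive peers) and reuses that value for both the wandb report and the checkpoint upload. A minimal sketch of that aggregation, using hypothetical metric records (the actual field names come from `metrics_utils`):

```python
from dataclasses import dataclass

@dataclass
class PeerMetrics:  # hypothetical stand-in for the per-peer entries parsed via metrics_utils
    step: int
    loss: float
    mini_steps: int
    samples_accumulated: int
    samples_per_second: float

def aggregate(metrics):
    # each peer reports a loss pre-summed over its mini-steps, so divide by the total mini-step count
    sum_loss = sum(m.loss for m in metrics)
    sum_mini_steps = sum(m.mini_steps for m in metrics)
    current_loss = sum_loss / sum_mini_steps
    return {
        "loss": current_loss,
        "alive peers": len(metrics),
        "samples": sum(m.samples_accumulated for m in metrics),
        "performance": sum(m.samples_per_second for m in metrics),
        "step": max(m.step for m in metrics),
    }

print(aggregate([PeerMetrics(10, 44.1, 4, 512, 90.0), PeerMetrics(10, 43.3, 4, 480, 85.0)]))
```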

+ 7 - 6
examples/albert/run_trainer.py

@@ -112,7 +112,7 @@ class CollaborativeCallback(transformers.TrainerCallback):
 
     def on_train_begin(self, args: TrainingArguments, state: transformers.TrainerState,
                        control: transformers.TrainerControl, **kwargs):
-        logger.warning('Loading state from peers')
+        logger.info('Loading state from peers')
         self.collaborative_optimizer.load_state_from_peers()
 
     def on_step_end(self, args: TrainingArguments, state: transformers.TrainerState,
@@ -139,14 +139,15 @@ class CollaborativeCallback(transformers.TrainerCallback):
                 logger.info(f"Step {self.collaborative_optimizer.local_step}")
                 logger.info(f"Your current contribution: {self.total_samples_processed} samples")
                 if self.steps:
-                    logger.info(f"Loss of your model: {self.loss/self.steps}")
+                    logger.info(f"Local loss: {self.loss / self.steps}")
 
                 self.loss = 0
                 self.steps = 0
-                self.dht.store(key=self.collaborative_optimizer.prefix + "_metrics",
-                               subkey=self.local_public_key, value=statistics.dict(),
-                               expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
-                               return_future=True)
+                if self.collaborative_optimizer.is_synchronized:
+                    self.dht.store(key=self.collaborative_optimizer.prefix + "_metrics",
+                                   subkey=self.local_public_key, value=statistics.dict(),
+                                   expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
+                                   return_future=True)
 
         self.samples = self.collaborative_optimizer.local_samples_accumulated
 

+ 6 - 4
examples/albert/tokenize_wikitext103.py

@@ -1,7 +1,6 @@
 #!/usr/bin/env python
 """ This script builds a pre-tokenized compressed representation of wikitext103 using huggingface/datasets """
 import random
-from collections import defaultdict
 from functools import partial
 from multiprocessing import cpu_count
 
@@ -10,6 +9,9 @@ from datasets import load_dataset
 from transformers import AlbertTokenizerFast
 
 
+COLUMN_NAMES = ('attention_mask', 'input_ids', 'sentence_order_label', 'special_tokens_mask', 'token_type_ids')
+
+
 def create_instances_from_document(tokenizer, document, max_seq_length):
     """Creates `TrainingInstance`s for a single document."""
     # We DON'T just concatenate all of the tokens from a document into a long
@@ -76,14 +78,14 @@ def tokenize_function(tokenizer, examples):
     # Remove empty texts
     texts = (text for text in examples["text"] if len(text) > 0 and not text.isspace())
 
-    new_examples = defaultdict(list)
+    new_examples = {col: [] for col in COLUMN_NAMES}
 
     for text in texts:
         instances = create_instances_from_document(tokenizer, text, max_seq_length=512)
         for instance in instances:
             for key, value in instance.items():
                 new_examples[key].append(value)
-
+    
     return new_examples
 
 
@@ -96,7 +98,7 @@ if __name__ == '__main__':
     tokenized_datasets = wikitext.map(
         partial(tokenize_function, tokenizer),
         batched=True,
-        num_proc=cpu_count(),
+        num_proc=8,
         remove_columns=["text"],
     )
 

+ 2 - 1
hivemind/__init__.py

@@ -1,7 +1,8 @@
 from hivemind.client import *
 from hivemind.dht import *
+from hivemind.p2p import *
 from hivemind.server import *
 from hivemind.utils import *
 from hivemind.optim import *
 
-__version__ = '0.9.8'
+__version__ = "0.9.10"

+ 107 - 55
hivemind/client/averaging/__init__.py

@@ -20,7 +20,8 @@ import torch
 import numpy as np
 
 from hivemind.dht import DHT, DHTID
-from hivemind.client.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID, split_into_parts
+from hivemind.client.averaging.partition import DEFAULT_PART_SIZE_BYTES
+from hivemind.client.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID, AveragingMode
 from hivemind.client.averaging.load_balancing import load_balance_peers
 from hivemind.client.averaging.matchmaking import Matchmaking, MatchmakingException
 from hivemind.client.averaging.group_info import GroupInfo
@@ -34,9 +35,8 @@ from hivemind.utils import Endpoint, Port, MPFuture, get_logger, TensorDescripto
 
 # flavour types
 StreamCallToLeader = grpc.aio.UnaryStreamCall[averaging_pb2.JoinRequest, averaging_pb2.MessageFromLeader]
-DataForGather = Any
+GatheredData = Any
 logger = get_logger(__name__)
-DEFAULT_CHUNK_SIZE_BYTES = 2 ** 16
 
 
 class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragingServicer):
@@ -61,7 +61,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
       towards the (estimated) average by this coefficient. By default, local parameters are set equal to average.
     :param request_timeout: when looking for group, wait for a response from leader for at most this many seconds.
     :note: request_timeout must be smaller than averaging_expiration to avoid potential deadlocks.
-    :param chunk_size_bytes: tensors for AllReduce will be divided into chunks of this size (to improve gRPC throughput)
+    :param part_size_bytes: tensors for AllReduce are processed in parts of up to this size (after compression)
     :param throughput: if specified, this value represents the network bandwidth available to averager.
           By default, the averager is assumed to have the average bandwidth of his group.
           If throughput == 0, averager will rely on its groupmates to do all the averaging.
@@ -71,6 +71,10 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
     :param channel_options: options for grpc.aio.insecure_channel, e.g. [('grpc.enable_retries', 0)]
           see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
     :param kwargs: extra parameters forwarded to grpc.aio.server
+    :param auxiliary: if this flag is specified, averager.step will only assist others without sending
+          local tensors for averaging
+    :param allow_state_sharing: if set to True, other peers can download this peer's state. Can be overwritten
+      with averager.allow_state_sharing = True / False
 
     Example:
 
@@ -90,10 +94,11 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
     def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: DHT, *, start: bool,
                  prefix: str, target_group_size: int, min_group_size: int = 2, initial_group_bits: Optional[str] = None,
-                 averaging_expiration: float = 15, request_timeout: float = 3, chunk_size_bytes: int = 2 ** 16,
-                 allreduce_timeout: Optional[float] = None, averaging_alpha: float = 1.0,
+                 averaging_expiration: float = 15, request_timeout: float = 3, averaging_alpha: float = 1.0,
+                 part_size_bytes: int = DEFAULT_PART_SIZE_BYTES, allreduce_timeout: Optional[float] = None,
                  compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE,
                  throughput: Optional[float] = None, min_vector_size: int = 0,
+                 auxiliary: bool = False, allow_state_sharing: Optional[bool] = None,
                  listen: bool = True, listen_on: Endpoint = '0.0.0.0:*', daemon: bool = True,
                  channel_options: Optional[Sequence[Tuple[str, Any]]] = None, **kwargs):
         assert '.' not in prefix, "group prefix must be a string without trailing '.'"
@@ -102,10 +107,18 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         if not is_power_of_two(target_group_size):
             logger.warning("It is recommended to set target_group_size to a power of 2.")
         assert initial_group_bits is None or all(bit in '01' for bit in initial_group_bits)
+        assert listen or not auxiliary, "auxiliary peers must accept incoming connections"
 
         super().__init__()
         self.dht = dht
         self.listen, self.listen_on, self.kwargs = listen, listen_on, kwargs
+        if not self.listen:
+            self.mode = AveragingMode.CLIENT
+        elif auxiliary:
+            self.mode = AveragingMode.AUX
+        else:
+            self.mode = AveragingMode.NODE
+
         self.channel_options = channel_options
         self.daemon = daemon
 
@@ -122,13 +135,17 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         self.matchmaking_kwargs = dict(
             prefix=prefix, initial_group_bits=initial_group_bits, target_group_size=target_group_size,
             min_group_size=min_group_size, averaging_expiration=averaging_expiration, request_timeout=request_timeout)
-        self.allreduce_kwargs = dict(compression_type=compression_type, chunk_size_bytes=chunk_size_bytes,
+        self.allreduce_kwargs = dict(compression_type=compression_type, part_size_bytes=part_size_bytes,
                                      min_vector_size=min_vector_size)
         self._averaging_alpha, self._allreduce_timeout = averaging_alpha, allreduce_timeout
         self._running_groups: Dict[GroupID, AllReduceRunner] = {}  # one or more assembled groups that run all-reduce
 
         self._pipe, self.pipe = mp.Pipe(duplex=True)  # a control pipe used to communicate with a background process
         self._port = mp.Value(ctypes.c_uint32, 0)  # assigned when averager starts, accessible via self.port
+
+        self._allow_state_sharing = mp.Value(ctypes.c_bool, 0)
+        self.allow_state_sharing = (listen and not auxiliary) if allow_state_sharing is None else allow_state_sharing
+
         self._averager_endpoint: Optional[Endpoint] = None
         if not self.listen:
             self._averager_endpoint = f'client::{uuid.uuid4()}'
@@ -146,6 +163,18 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
     def port(self) -> Optional[Port]:
         return self._port.value if self._port.value != 0 else None
 
+    @property
+    def allow_state_sharing(self) -> bool:
+        """ if set to True, other peers can download this peer's state """
+        return bool(self._allow_state_sharing.value)
+
+    @allow_state_sharing.setter
+    def allow_state_sharing(self, value: bool):
+        if value is True and not self.listen:
+            logger.warning("Cannot allow state sharing: averager in client mode (listen=False) cannot share its state.")
+        else:
+            self._allow_state_sharing.value = value
+
     @property
     def endpoint(self) -> Optional[Endpoint]:
         if self.listen and self._averager_endpoint is None:
@@ -222,8 +251,9 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         if self._parent_pid != os.getpid() or self.is_alive():
             self.shutdown()
 
-    def step(self, gather: Optional[DataForGather] = None, weight: float = 1.0, timeout: Optional[float] = None,
-             allow_retries: bool = True, wait: bool = True) -> Union[Optional[Dict[Endpoint, DataForGather]], MPFuture]:
+    def step(self, gather: Optional[GatheredData] = None, weight: Optional[float] = None,
+             timeout: Optional[float] = None, allow_retries: bool = True, wait: bool = True
+             ) -> Union[Optional[Dict[Endpoint, GatheredData]], MPFuture]:
         """
         Set up the averager to look for a group and run one round of averaging, return True on success, False on failure
 
@@ -236,7 +266,12 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         :param wait: if True (default), return when finished. Otherwise return MPFuture and run in background.
         :returns: on success, update averaged_tensors and return group info; on failure, return None
         """
-        assert isinstance(weight, (int, float)) and weight > 0, f"Expected a positive int/float, got {type(weight)}"
+        if self.mode == AveragingMode.AUX and weight is not None:
+            logger.warning("Averager is running in auxiliary mode, weight is unused.")
+        if weight is None:
+            weight = float(self.mode != AveragingMode.AUX)
+        assert isinstance(weight, (int, float)) and weight >= 0, f"Expected a positive int/float, got {type(weight)}"
+
         future, _future = MPFuture.make_pair()
         gather_binary = self.serializer.dumps(gather)  # serialize here to avoid loading modules in the averager process
         self.pipe.send(('_step', [], dict(future=_future, gather_binary=gather_binary, weight=weight,
@@ -245,28 +280,21 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
     async def _step(self, *, future: MPFuture, gather_binary: bytes, weight: float,
                     allow_retries: bool, timeout: Optional[float]):
-        loop = asyncio.get_event_loop()
         start_time = get_dht_time()
-        group_id = None
 
         try:
             while not future.done():
                 try:
                     self._pending_group_assembled.clear()
-                    data_for_gather = self.serializer.dumps([weight, self._throughput, self.listen, gather_binary])
+                    data_for_gather = self.serializer.dumps([weight, self._throughput, self.mode.value, gather_binary]) 
                     group_info = await self._matchmaking.look_for_group(timeout=timeout,
                                                                         data_for_gather=data_for_gather)
                     if group_info is None:
                         raise AllreduceException("Averaging step failed: could not find a group.")
-                    group_id = group_info.group_id
-                    allreduce_runner = await self._make_allreduce_runner(group_info, **self.allreduce_kwargs)
-                    self._running_groups[group_id] = allreduce_runner
-                    self._pending_group_assembled.set()
-                    await asyncio.wait_for(allreduce_runner.run(), self._allreduce_timeout)
-                    await loop.run_in_executor(None, self.update_tensors, allreduce_runner)
 
-                    # averaging is finished, exit the loop
-                    future.set_result(allreduce_runner.gathered)
+                    future.set_result(await asyncio.wait_for(
+                        self._run_allreduce(group_info, **self.allreduce_kwargs), self._allreduce_timeout))
+                    # averaging is finished, loop will now exit
 
                 except (AllreduceException, MatchmakingException, AssertionError, StopAsyncIteration, InternalError,
                         asyncio.CancelledError, asyncio.InvalidStateError, grpc.RpcError, grpc.aio.AioRpcError) as e:
@@ -277,10 +305,6 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                     else:
                         logger.warning(f"Averager caught {repr(e)}, retrying")
 
-                finally:
-                    _ = self._running_groups.pop(group_id, None)
-                    self._pending_group_assembled.set()
-
         except BaseException as e:
             if not future.done():
                 future.set_exception(e)
@@ -290,35 +314,51 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                 future.set_exception(RuntimeError("Internal sanity check failed: averager.step left future pending."
                                                   " Please report this to hivemind issues."))
 
-    async def _make_allreduce_runner(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> AllReduceRunner:
-        """ Use a group description found by Matchmaking to form AllreduceRunner """
+    async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
+        """ Run All-Reduce in a given group and update tensors in place, return gathered metadata """
         try:
-            weights, throughputs, modes, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
+            weights, throughputs, mode_ids, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
             user_gathered = dict(zip(group_info.endpoints, map(self.serializer.loads, user_gathered)))
+            modes = tuple(map(AveragingMode, mode_ids))
 
-            # compute optimal part sizes from peer throughputs
-            incoming_throughputs = [thr if listen else 0.0 for thr, listen in zip(throughputs, modes)]
-            part_sizes = await asyncio.get_event_loop().run_in_executor(
+            # compute optimal part sizes from peer throughputs; TODO: replace with proper load balancing
+            incoming_throughputs = [thr if mode != AveragingMode.CLIENT else 0.0
+                                    for thr, mode in zip(throughputs, modes)]
+            peer_fractions = await asyncio.get_event_loop().run_in_executor(
                 None, load_balance_peers, self.total_size, incoming_throughputs, min_vector_size)
-            async with self.get_tensors_async() as averaged_tensors:
-                return AllReduceRunner(group_id=group_info.group_id, tensors=averaged_tensors, endpoint=self.endpoint,
-                                       ordered_group_endpoints=group_info.endpoints, part_sizes=part_sizes,
-                                       weights=weights, gathered=user_gathered, return_deltas=True, **kwargs)
-        except Exception as e:
-            raise MatchmakingException(f"Unable to create allreduce runner ({e}), group_info: {group_info}")
 
-    def update_tensors(self, allreduce_group: AllReduceRunner):
-        """
-        a private (extendable) method that applies changes from a finished allreduce to local tensors
-        """
-        assert allreduce_group.return_deltas and allreduce_group.future.done()
-        averaging_deltas = allreduce_group.future.result()
+            async with self.get_tensors_async() as local_tensors:
+                allreduce = AllReduceRunner(
+                    group_id=group_info.group_id, tensors=local_tensors, endpoint=self.endpoint,
+                    ordered_group_endpoints=group_info.endpoints, peer_fractions=peer_fractions, weights=weights,
+                    gathered=user_gathered, modes=modes, **kwargs)
 
-        with torch.no_grad(), self.get_tensors() as local_tensors:
-            assert len(local_tensors) == len(self._averaged_tensors)
-            for tensor, update in zip(local_tensors, averaging_deltas):
-                tensor.add_(update, alpha=self._averaging_alpha)
-        self.last_updated = get_dht_time()
+                with self.register_allreduce_group(group_info.group_id, allreduce):
+
+                    # actually run all-reduce
+                    averaging_outputs = [output async for output in allreduce]
+
+                    if modes[group_info.endpoints.index(self.endpoint)] != AveragingMode.AUX:
+                        assert len(local_tensors) == len(self._averaged_tensors)
+                        for tensor, update in zip(local_tensors, averaging_outputs):
+                            tensor.add_(update, alpha=self._averaging_alpha)
+                        self.last_updated = get_dht_time()
+
+                return allreduce.gathered
+        except BaseException as e:
+            logger.exception(e)
+            raise MatchmakingException(f"Unable to run All-Reduce: {e}")
+
+    @contextlib.contextmanager
+    def register_allreduce_group(self, group_id: GroupID, allreduce: AllReduceRunner):
+        """ registers a given all-reduce runner to listen for incoming connections """
+        try:
+            self._running_groups[group_id] = allreduce
+            self._pending_group_assembled.set()
+            yield
+        finally:
+            self._running_groups.pop(group_id, None)
+            self._pending_group_assembled.set()
 
     @contextlib.contextmanager
     def get_tensors(self) -> Sequence[torch.Tensor]:
@@ -366,10 +406,11 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
     async def _declare_for_download_periodically(self):
         download_key = f'{self._matchmaking.group_key_manager.prefix}.all_averagers'
         while True:
-            asyncio.create_task(asyncio.wait_for(self.dht.store(
-                download_key, subkey=self.endpoint, value=self.last_updated,
-                expiration_time=get_dht_time() + self._matchmaking.averaging_expiration, return_future=True),
-                timeout=self._matchmaking.averaging_expiration))
+            if self.allow_state_sharing:
+                asyncio.create_task(asyncio.wait_for(self.dht.store(
+                    download_key, subkey=self.endpoint, value=self.last_updated,
+                    expiration_time=get_dht_time() + self._matchmaking.averaging_expiration, return_future=True),
+                    timeout=self._matchmaking.averaging_expiration))
             await asyncio.sleep(self._matchmaking.averaging_expiration)
 
     async def rpc_download_state(self, request: averaging_pb2.DownloadRequest, context: grpc.ServicerContext
@@ -381,11 +422,12 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
          - serialized_metadata is a small serialized bytestring meant to store scalars and hyperparameters
          - tensors is a sequence of pytorch tensors that represent model parameters or optimizer statistics
         """
-        chunk_size_bytes = self.matchmaking_kwargs.get('chunk_size_bytes', DEFAULT_CHUNK_SIZE_BYTES)
+        if not self.allow_state_sharing:
+            return  # deny request and direct peer to the next prospective averager
         metadata, tensors = await self._get_current_state_from_host_process()
 
         for tensor in tensors:
-            for part in split_for_streaming(serialize_torch_tensor(tensor), chunk_size_bytes):
+            for part in split_for_streaming(serialize_torch_tensor(tensor)):
                 if metadata is not None:
                     yield averaging_pb2.DownloadData(tensor_part=part, metadata=metadata)
                     metadata = None
@@ -452,6 +494,11 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                             current_tensor_parts.append(message.tensor_part)
                         if current_tensor_parts:
                             tensors.append(deserialize_torch_tensor(combine_from_streaming(current_tensor_parts)))
+
+                        if not metadata:
+                            logger.debug(f"Peer {peer} did not send its state.")
+                            continue
+
                         logger.info(f"Finished downloading state from {peer}")
                         future.set_result((metadata, tensors))
                         self.last_updated = get_dht_time()
@@ -512,7 +559,12 @@ def _background_thread_fetch_current_state(serializer: SerializerBase, pipe: mp.
     :param get_current_state_ref: a WeakMethod wrapped around DecentralizedAverager.get_current_state (instance-bound)
     """
     while True:
-        trigger, future = pipe.recv()
+        try:
+            trigger, future = pipe.recv()
+        except BaseException as e:
+            logger.debug(f"Averager background thread finished: {repr(e)}")
+            break
+            
         if trigger == '_SHUTDOWN':
             break
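
The `DecentralizedAverager` changes above add an auxiliary mode and a runtime-switchable `allow_state_sharing` flag, and make `step(weight=...)` default to 0 for auxiliary peers and 1 otherwise. A minimal usage sketch, assuming a single locally started DHT node purely for illustration:

```python
import torch
import hivemind

dht = hivemind.DHT(listen_on='0.0.0.0:*', start=True)  # one local DHT peer just for the example

averager = hivemind.DecentralizedAverager(
    averaged_tensors=[torch.zeros(1024)], dht=dht, start=True,
    prefix='my_experiment', target_group_size=4,
    auxiliary=False,            # auxiliary peers only assist with reduction and contribute no tensors
    allow_state_sharing=True)   # let other peers download this peer's state via load_state_from_peers()

averager.allow_state_sharing = False  # backed by a shared mp.Value, so it can be toggled at runtime
gathered = averager.step(timeout=60)  # weight defaults to 1.0 here; it would default to 0.0 in auxiliary mode
```

With a single peer this `step` call simply times out and returns `None`; in a real run, several peers with the same `prefix` discover each other through matchmaking and average their tensors.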
 

+ 166 - 189
hivemind/client/averaging/allreduce.py

@@ -1,252 +1,229 @@
 import asyncio
-from typing import Sequence, Set, Dict, Tuple, Iterable, AsyncIterator, Any
+from typing import Sequence, Dict, Tuple, AsyncIterator, Any, Optional
+from enum import Enum
 
 import grpc
 import torch
 
-from hivemind.utils import Endpoint, get_logger, ChannelCache, anext
-from hivemind.utils import split_for_streaming, combine_from_streaming
+from hivemind.client.averaging.partition import TensorPartContainer, TensorPartReducer, AllreduceException
+from hivemind.utils import Endpoint, get_logger, ChannelCache
+from hivemind.utils.asyncio import anext, achain, aiter, aenumerate, amap_in_executor
 from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
-from hivemind.proto import averaging_pb2_grpc, runtime_pb2, averaging_pb2
+from hivemind.proto import averaging_pb2_grpc, averaging_pb2
 
 # flavour types
 GroupID = bytes
 logger = get_logger(__name__)
 
 
-class AllReduceProtocol:
+class AveragingMode(Enum):
+    NODE = 0
+    CLIENT = 1
+    AUX = 2
+
+
+class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
     """
     An internal class that runs butterfly AllReduce in a predefined group of averagers
 
+    :note: this class returns **differences** between averaged and local tensors in order to improve numerical stability
+    :param group_id: unique identifier of this specific all-reduce run
+    :param tensors: local tensors that should be averaged with groupmates
     :param tensors: local tensors that should be averaged with groupmates
     :param endpoint: your endpoint, must be included in ordered_group_endpoints
     :param ordered_group_endpoints: group endpoints ordered s.t. i-th endpoint is responsible for averaging i-th part
-    :param part_sizes: for each peer, a number of vector elements that this peer is responsible for averaging
-    :param return_deltas: if True, returns the element-wise differences (averaged_tensors - original_tensors)
-           default (False) - return averaged_tensors by themselves
+    :param peer_fractions: for each peer, a target fraction of vector elements that this peer should average
+      (the actual number of values by peer will be nearly proportional, but there are no exact guarantees)
+    :param modes: AveragingMode for each peer in ordered_group_endpoints (normal, client-only or auxiliary)
+    :param weights: scaling coefficients for weighted averaging (default = equal weights for all non-aux peers)
+    :param gathered: additional user-defined data collected from this group
+    :param kwargs: additional paramters (e.g. part_size_bytes) will be passed to TensorPartContainer
     """
 
-    def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint,
-                 ordered_group_endpoints: Sequence[Endpoint], part_sizes: Tuple[int, ...], return_deltas: bool = False):
+    def __init__(
+            self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint,
+            ordered_group_endpoints: Sequence[Endpoint], peer_fractions: Tuple[float, ...],
+            weights: Optional[Sequence[float]] = None, modes: Optional[Sequence[AveragingMode]] = None,
+            gathered: Optional[Dict[Endpoint, Any]] = None, **kwargs):
         assert endpoint in ordered_group_endpoints, "endpoint is not a part of the group"
-        self.group_id, self.endpoint = group_id, endpoint
-        self.ordered_group_endpoints, self.part_sizes = ordered_group_endpoints, part_sizes
-        self.client_mode_endpoints = {endpoint for endpoint, part_size in zip(self.ordered_group_endpoints, part_sizes)
-                                      if part_size == 0}
-        self.local_tensor_parts = dict(zip(ordered_group_endpoints, split_into_parts(tensors, part_sizes)))
-        self.tensor_shapes = tuple(tensor.shape for tensor in tensors)
-        self.return_deltas = return_deltas
-
-        self.accumulator = torch.zeros_like(self.local_tensor_parts[self.endpoint])
-        self.denominator = 0.0  # number of peers added to accumulator or sum of their weights
-        self.accumulated_from: Set[Endpoint] = set()  # peers that we have accumulated our part from
-        self.averaged_part: asyncio.Future[torch.Tensor] = asyncio.Future()  # will be set to [accumulator / group size]
-        self.averaged_tensor_parts: Dict[Endpoint, torch.Tensor] = {}  # averaged chunks from all peers will be put here
-        self.future: asyncio.Future[Sequence[torch.Tensor]] = asyncio.Future()  # final result or exception
-        for endpoint in self.client_mode_endpoints:
-            self.averaged_tensor_parts[endpoint] = torch.tensor([])
+        modes = modes or tuple(AveragingMode.CLIENT if frac == 0 else AveragingMode.NODE for frac in peer_fractions)
+        weights = weights or tuple(int(mode != AveragingMode.AUX) for mode in modes)
+        assert len(weights) == len(modes) == len(ordered_group_endpoints), "lists have inconsistent length"
+        assert any(mode != AveragingMode.CLIENT for mode in modes), "cannot run allreduce without reducers"
+        for mode, frac, weight in zip(modes, peer_fractions, weights):
+            assert mode != AveragingMode.CLIENT or frac == 0, "client-mode peer should have zero all-reduce fraction"
+            assert mode != AveragingMode.AUX or weight == 0, "auxiliary peer should have zero averaging weight"
+
+        self.group_id, self.endpoint, self.ordered_group_endpoints = group_id, endpoint, ordered_group_endpoints
+        self.modes, self.peer_fractions, self.gathered = modes, peer_fractions, gathered
+
+        self._future = asyncio.Future()
+
+        self.sender_endpoints, self.sender_weights = [], []
+        for endpoint, weight, mode in zip(self.ordered_group_endpoints, weights, modes):
+            if mode != AveragingMode.AUX:
+                self.sender_endpoints.append(endpoint)
+                self.sender_weights.append(weight)
+
+        endpoint_index = self.ordered_group_endpoints.index(self.endpoint)
+        self.tensor_part_container = TensorPartContainer(tensors, peer_fractions, **kwargs)
+        self.parts_for_local_averaging = self.tensor_part_container.get_raw_input_parts(endpoint_index)
+        self.tensor_part_reducer = TensorPartReducer(tuple(part.shape for part in self.parts_for_local_averaging),
+                                                     len(self.sender_endpoints), self.sender_weights)
 
     def __repr__(self):
         return f"{self.__class__.__name__}({self.endpoint}, group_size={self.group_size})"
 
-    def __await__(self):
-        return self.future.__await__()
+    def __aiter__(self):
+        return self.run()
 
     def __contains__(self, endpoint: Endpoint):
-        return endpoint in self.local_tensor_parts
+        return endpoint in self.ordered_group_endpoints
 
     @property
     def group_size(self):
         return len(self.ordered_group_endpoints)
 
-    async def accumulate_part(self, source: Endpoint, remote_part: torch.Tensor, weight: float = 1.0) -> torch.Tensor:
-        """ Add vector part to accumulator, wait for all other vectors to be added, then return the average part """
-        assert not self.averaged_part.done(), f"already finished averaging part: {self.averaged_part}"
-        assert not self.future.done(), f"already finished allreduce: {self.future}"
-        assert source in self.local_tensor_parts, "unexpected source, not a part of current group"
-        assert source not in self.accumulated_from, "duplicate source, already received that part"
-        assert not self.endpoint in self.client_mode_endpoints, f"{self.endpoint} is in client mode"
-        assert isinstance(weight, (int, float)) and weight > 0, "averaging weights must be a non-negative int/float"
-        logger.debug(f"{self} - accumulating tensor part from {source}")
-
-        self.accumulator.add_(remote_part, alpha=weight)
-        self.denominator += weight
-        self.accumulated_from.add(source)
-
-        assert len(self.accumulated_from) <= self.group_size
-        if len(self.accumulated_from) == len(self.local_tensor_parts):
-            average_result = self.accumulator.div_(self.denominator)
-            self.register_averaged_part(self.endpoint, average_result)
-            self.averaged_part.set_result(average_result)
-
-        return await self.averaged_part
-
-    def register_averaged_part(self, source: Endpoint, averaged_part: torch.Tensor):
-        assert not self.future.done(), f"already finished allreduce: {self.future}"
-        assert source in self.local_tensor_parts, "the provider of averaged part is not from my group"
-        assert source not in self.averaged_tensor_parts, "already registered the average from this peer"
-        assert averaged_part.shape == self.local_tensor_parts[source].shape, "averaged part shape mismatch"
-        assert averaged_part.dtype == self.local_tensor_parts[source].dtype, "averaged part dtype mismatch"
-        logger.debug(f"{self} - receiving averaged tensor part from {source}")
-        self.averaged_tensor_parts[source] = averaged_part
-        if len(self.averaged_tensor_parts) == len(self.local_tensor_parts):
-            ordered_averaged_parts = [self.averaged_tensor_parts[endpoint] for endpoint in self.ordered_group_endpoints]
-            outputs = restore_from_parts(ordered_averaged_parts, self.tensor_shapes)
-
-            if self.return_deltas:
-                local_parts = [self.local_tensor_parts[peer] for peer in self.ordered_group_endpoints]
-                with torch.no_grad():
-                    original_tensors = restore_from_parts(local_parts, self.tensor_shapes)
-                    for averaged_tensor, original_tensor in zip(outputs, original_tensors):
-                        averaged_tensor -= original_tensor
-
-            self.future.set_result(outputs)
-
-    def cancel(self) -> bool:
-        if not self.future.done():
-            logger.debug(f"{self} - cancelled")
-            self.future.cancel()
-            if not self.averaged_part.done():
-                self.averaged_part.cancel()
-            return True
-        else:
-            logger.debug(f"{self} - failed to cancel, allreduce is already finished: {self.future}")
-            return False
-
-    def set_exception(self, exception: Exception) -> bool:
-        if not self.future.done():
-            logger.debug(f"{self} - {exception}")
-            self.future.set_exception(exception)
-            if not self.averaged_part.done():
-                self.averaged_part.cancel()
-            return True
-        else:
-            logger.debug(f"{self} - failed to set {exception}, allreduce already finished: {self.future}")
-            return False
-
-
-class AllReduceRunner(AllReduceProtocol, averaging_pb2_grpc.DecentralizedAveragingServicer):
-    """
-    A class that implements ButterflyAllReduceProtocol on top of a gRPC servicer
-    """
-
-    def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint,
-                 ordered_group_endpoints: Sequence[Endpoint], compression_type: runtime_pb2.CompressionType,
-                 chunk_size_bytes: int, part_sizes: Tuple[int, ...], weights: Tuple[float, ...],
-                 gathered: Dict[Endpoint, Any], return_deltas: bool = False):
-        super().__init__(group_id=group_id, tensors=tensors, endpoint=endpoint, part_sizes=part_sizes,
-                         ordered_group_endpoints=ordered_group_endpoints, return_deltas=return_deltas)
-        self.compression_type, self.chunk_size_bytes, self.gathered = compression_type, chunk_size_bytes, gathered
-        self.peer_weights = dict(zip(self.ordered_group_endpoints, weights))
-
     def _get_peer_stub(self, peer: Endpoint) -> averaging_pb2_grpc.DecentralizedAveragingStub:
         return ChannelCache.get_stub(peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
 
-    async def _communicate_with_peer(self, peer_endpoint: Endpoint, local_part: torch.Tensor) -> torch.Tensor:
-        """ Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors """
-        if peer_endpoint == self.endpoint:
-            return await self.accumulate_part(self.endpoint, local_part, weight=self.peer_weights[self.endpoint])
-        serialized_tensor_part = serialize_torch_tensor(local_part, self.compression_type, allow_inplace=False)
-        chunks = split_for_streaming(serialized_tensor_part, self.chunk_size_bytes)
-
-        stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
-        await stream.write(averaging_pb2.AveragingData(code=averaging_pb2.PART_FOR_AVERAGING, group_id=self.group_id,
-                                                       endpoint=self.endpoint, tensor_part=next(chunks)))
-        for chunk in chunks:
-            await stream.write(averaging_pb2.AveragingData(tensor_part=chunk))
-        await stream.done_writing()
-
-        outputs: Sequence[averaging_pb2.AveragingData] = [message async for message in stream]
-        code = outputs[0].code if outputs else averaging_pb2.INTERNAL_ERROR
-        if code != averaging_pb2.AVERAGED_PART:
-            raise AllreduceException(f"peer {peer_endpoint} returned {averaging_pb2.MessageCode.Name(code)}"
-                                     f" instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)},"
-                                     f" allreduce failed")
-
+    async def run(self) -> AsyncIterator[torch.Tensor]:
+        """ Run all-reduce, return differences between averaged and original tensors as they are computed """
+        pending_tasks = set()
         try:
-            averaged_part = local_part + deserialize_torch_tensor(combine_from_streaming(
-                [message.tensor_part for message in outputs]))
-        except RuntimeError as e:
-            raise AllreduceException(f"Could not deserialize averaged part from {peer_endpoint}: {e}")
+            if len(self.sender_endpoints) == 0:
+                logger.debug(f"{self} - finished all-reduce early: all peers are auxiliaries ({self.modes})")
+                self.finalize()
 
-        self.register_averaged_part(peer_endpoint, averaged_part)
-        return averaged_part
+            elif self.endpoint in self.sender_endpoints:
+                for endpoint, parts in zip(self.ordered_group_endpoints, self.tensor_part_container.num_parts_by_peer):
+                    if parts != 0:
+                        pending_tasks.add(asyncio.create_task(self._communicate_with_peer(endpoint)))
 
-    async def _send_error_to_peer(self, peer_endpoint: Endpoint, code: averaging_pb2.MessageCode):
-        stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
-        await stream.write(averaging_pb2.AveragingData(group_id=self.group_id, endpoint=self.endpoint, code=code))
-        await stream.done_writing()
+                async for averaged_tensor_delta in self.tensor_part_container.iterate_output_tensors():
+                    yield averaged_tensor_delta  # delta = averaged_tensor - original_tensor
+                self.finalize()
+
+            else:  # auxiliary peer
+                await self.tensor_part_reducer.finished.wait()
+                self.finalize()
 
-    async def run(self) -> Sequence[torch.Tensor]:
-        """
-        send allreduce requests to all peers and collect results, return the averaged tensor (or deltas)
-        """
-        try:
-            await asyncio.gather(self, *(self._communicate_with_peer(peer, self.local_tensor_parts[peer])
-                                         for i, peer in enumerate(self.ordered_group_endpoints)
-                                         if peer not in self.client_mode_endpoints))
-            return await self
         except BaseException as e:
+            self.finalize(exception=e)
+            for task in pending_tasks:
+                task.cancel()
             code = averaging_pb2.CANCELLED if isinstance(e, asyncio.CancelledError) else averaging_pb2.INTERNAL_ERROR
             logger.debug(f"{self} - notifying peers about {averaging_pb2.MessageCode.Name(code)}")
-            self.set_exception(e)
-            for peer_endpoint, part_size in zip(self.ordered_group_endpoints, self.part_sizes):
-                if peer_endpoint != self.endpoint and part_size > 0:
+            for peer_endpoint, mode in zip(self.ordered_group_endpoints, self.modes):
+                if peer_endpoint != self.endpoint and mode != AveragingMode.CLIENT:
                     asyncio.create_task(self._send_error_to_peer(peer_endpoint, code))
             raise
 
-    async def accumulate_part_streaming(self, source: Endpoint, stream_messages: Iterable[runtime_pb2.Tensor]
-                                        ) -> Iterable[runtime_pb2.Tensor]:
-        """ accumulate_part using streams of serialized tensors. Used to prevent duplicate work in serialization """
-        try:
-            tensor_part = deserialize_torch_tensor(combine_from_streaming(stream_messages))
-        except RuntimeError as e:
-            raise AllreduceException(f"Could not deserialize tensor part from {source} for streaming {e}")
+    async def _communicate_with_peer(self, peer_endpoint: Endpoint):
+        """ Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors """
+        peer_index = self.ordered_group_endpoints.index(peer_endpoint)
+        if peer_endpoint == self.endpoint:
+            sender_index = self.sender_endpoints.index(peer_endpoint)
+            for part_index, tensor_part in enumerate(self.parts_for_local_averaging):
+                averaged_part = await self.tensor_part_reducer.accumulate_part(sender_index, part_index, tensor_part)
+                self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part - tensor_part)
 
-        averaged_part = await self.accumulate_part(source, tensor_part, weight=self.peer_weights[source])
-        serialized_tensor = serialize_torch_tensor(averaged_part - tensor_part, self.compression_type, allow_inplace=False)
-        stream_chunks = tuple(split_for_streaming(serialized_tensor, self.chunk_size_bytes))
-        return stream_chunks
+        else:
+            loop = asyncio.get_event_loop()
+            stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
+            write_task = asyncio.create_task(self._write_to_peer(stream, peer_index))
+
+            try:
+                code = None
+                async for part_index, msg in aenumerate(stream):
+                    if code is None:
+                        code = msg.code
+                    averaged_part_delta = await loop.run_in_executor(None, deserialize_torch_tensor, msg.tensor_part)
+                    self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part_delta)
+                await write_task
+
+                if code != averaging_pb2.AVERAGED_PART:
+                    raise AllreduceException(f"peer {peer_endpoint} returned {averaging_pb2.MessageCode.Name(code)} "
+                                             f"instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)}"
+                                             f", allreduce failed")
+            finally:
+                if not write_task.done():
+                    write_task.cancel()
+
+    async def _write_to_peer(self, stream: grpc.aio.StreamStreamCall, peer_index: int):
+        parts_aiter = self.tensor_part_container.iterate_input_parts_for(peer_index)
+        first_part = await anext(parts_aiter)
+        await stream.write(averaging_pb2.AveragingData(code=averaging_pb2.PART_FOR_AVERAGING,
+                                                       group_id=self.group_id, endpoint=self.endpoint,
+                                                       tensor_part=first_part))
+        async for part in parts_aiter:
+            await stream.write(averaging_pb2.AveragingData(tensor_part=part))
+
+        await stream.done_writing()
 
     async def rpc_aggregate_part(self, stream: AsyncIterator[averaging_pb2.AveragingData], context: grpc.ServicerContext
                                  ) -> AsyncIterator[averaging_pb2.AveragingData]:
-        """ a groupmate sends us a part of his tensor; we should average it with other peers and return the delta"""
+        """ a peer sends us a part of his tensor; we should average it with other peers and return the difference """
         request: averaging_pb2.AveragingData = await anext(stream)
-
-        if request.group_id != self.group_id:
-            yield averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
+        reason_to_reject = self._check_reasons_to_reject(request)
+        if reason_to_reject:
+            yield reason_to_reject
+            return
 
         elif request.code == averaging_pb2.PART_FOR_AVERAGING:
             try:
-                tensor_chunks = (request.tensor_part, *[msg.tensor_part async for msg in stream])
-                averaged_chunks = iter(await self.accumulate_part_streaming(request.endpoint, tensor_chunks))
-                yield averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=next(averaged_chunks))
-                for averaged_chunk in averaged_chunks:
-                    yield averaging_pb2.AveragingData(tensor_part=averaged_chunk)
+                sender_index = self.sender_endpoints.index(request.endpoint)
+                async for msg in self._accumulate_parts_streaming(achain(aiter(request), stream), sender_index):
+                    yield msg
 
             except Exception as e:
-                self.set_exception(e)
+                self.finalize(exception=e)
                 yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
         else:
             error_code = averaging_pb2.MessageCode.Name(request.code)
             logger.debug(f"{self} - peer {request.endpoint} sent {error_code}, allreduce cannot continue")
-            self.set_exception(AllreduceException(f"peer {request.endpoint} sent {error_code}."))
+            self.finalize(exception=AllreduceException(f"peer {request.endpoint} sent {error_code}."))
             yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
 
+    def _check_reasons_to_reject(self, request: averaging_pb2.AveragingData) -> Optional[averaging_pb2.AveragingData]:
+        if request.group_id != self.group_id:
+            return averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
+        elif self._future.cancelled():
+            return averaging_pb2.AveragingData(code=averaging_pb2.CANCELLED)
+        elif self._future.done():
+            return averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
+
+    async def _accumulate_parts_streaming(self, stream: AsyncIterator[averaging_pb2.AveragingData], sender_index: int):
+        loop = asyncio.get_event_loop()
+        async for part_index, (tensor_part, part_compression) in aenumerate(
+                amap_in_executor(lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg.tensor_part.compression), stream,
+                                 max_prefetch=self.tensor_part_container.prefetch)):
+            averaged_part = await self.tensor_part_reducer.accumulate_part(sender_index, part_index, tensor_part)
+
+            serialized_delta = await loop.run_in_executor(
+                None, lambda: serialize_torch_tensor(averaged_part - tensor_part, part_compression))
+            yield averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=serialized_delta)
 
-def split_into_parts(tensors: Sequence[torch.Tensor], part_sizes: Tuple[int]) -> Tuple[torch.Tensor, ...]:
-    """ combines averaged_tensors into one tensor and splits them into equal chunks of size group_size """
-    flat_tensor = torch.cat(tuple(map(torch.Tensor.flatten, tensors)))
-    return torch.split_with_sizes(flat_tensor, part_sizes, dim=0)
-
-
-def restore_from_parts(chunks: Sequence[torch.Tensor], shapes: Sequence[torch.Size]) -> Tuple[torch.Tensor, ...]:
-    """ restores the original tensor shapes from chunks obtained by split_into_chunks """
-    flat_tensor = torch.cat(tuple(chunks))
-    result_sizes = tuple(map(torch.Size.numel, shapes))
-    flat_original_tensors = torch.split_with_sizes(flat_tensor, result_sizes)
-    return tuple(map(torch.Tensor.reshape, flat_original_tensors, shapes))
-
+    async def _send_error_to_peer(self, peer_endpoint: Endpoint, code: averaging_pb2.MessageCode):
+        stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
+        await stream.write(averaging_pb2.AveragingData(group_id=self.group_id, endpoint=self.endpoint, code=code))
+        await stream.done_writing()
 
-class AllreduceException(Exception):
-    """ A special exception that is raised when allreduce can't continue normally (e.g. disbanded/bad request/etc) """
+    def finalize(self, *, cancel: bool = False, exception: Optional[BaseException] = None):
+        assert not cancel or not exception, "finalize accepts either exception or cancel, but not both"
+        if not self._future.done():
+            if cancel:
+                logger.debug(f"{self} - cancelled")
+                self._future.cancel()
+            elif exception:
+                logger.debug(f"{self} - caught {exception}")
+                self._future.set_exception(exception)
+            else:
+                logger.debug(f"{self} - finished")
+                self._future.set_result(None)
+            self.tensor_part_container.finalize()
+            self.tensor_part_reducer.finalize()
+            return True
+        else:
+            logger.debug(f"{self} - could not finish: allreduce is already finished: {self._future}")
+            return False

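Note on usage: the runner above is now consumed as an async iterator instead of being awaited. Below is a minimal, illustrative sketch of how a caller might apply the yielded deltas; constructing the runner (group_id, tensors, endpoint, peer_fractions, ...) is the job of matchmaking and is not shown, and the helper name is hypothetical.

    import torch

    async def apply_allreduce_inplace(runner, local_tensors):
        # iterating an AllReduceRunner drives run() and yields one delta
        # (averaged_tensor - original_tensor) per local tensor, in the original order
        index = 0
        async for delta in runner:
            with torch.no_grad():
                local_tensors[index].add_(delta)  # original + delta == averaged value
            index += 1
        return local_tensors
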
+ 1 - 0
hivemind/client/averaging/load_balancing.py

@@ -28,6 +28,7 @@ def load_balance_peers(vector_size, throughputs: Sequence[Optional[float]], min_
         assert not all(throughput == 0 for throughput in throughputs), "Must have at least one nonzero throughput"
         scores = np.asarray([1.0 if throughput is None else 0.0 for throughput in throughputs])
 
+    # TODO(jheuristic): we no longer need hagenbach-bishoff with the new AllReduceRunner
     return tuple(hagenbach_bishoff(vector_size, scores))
 
 

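For reference, load_balance_peers still returns integer part sizes that sum to vector_size, while the new AllReduceRunner consumes fractions. One way to bridge the two is to rescale the sizes, as in the illustrative snippet below (throughput values are made up, remaining keyword arguments are assumed to keep their defaults):

    from hivemind.client.averaging.load_balancing import load_balance_peers

    vector_size = 1_000_000
    throughputs = [100.0, 300.0, None]            # None = unknown throughput
    part_sizes = load_balance_peers(vector_size, throughputs)
    peer_fractions = tuple(size / vector_size for size in part_sizes)  # what AllReduceRunner expects
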
+ 1 - 1
hivemind/client/averaging/matchmaking.py

@@ -391,7 +391,7 @@ class PotentialLeaders:
             if maybe_next_leader is None or self.max_assured_time <= entry.expiration_time <= self.search_end_time:
                 self.update_triggered.set()
 
-            if maybe_next_leader is None or entry.expiration_time >= self.declared_expiration_time:
+            if maybe_next_leader is None or (entry.expiration_time, maybe_next_leader) > (self.declared_expiration_time, self.endpoint):
                 await asyncio.wait({self.update_finished.wait(), self.declared_expiration.wait()},
                                    return_when=asyncio.FIRST_COMPLETED)
                 self.declared_expiration.clear()

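The change above replaces a plain expiration-time comparison with a lexicographic (expiration_time, endpoint) comparison, so that peers with identical expiration times are ordered deterministically by endpoint. A toy illustration with made-up values:

    declared = (1234.5, '192.0.2.1:8080')    # our own (declared_expiration_time, endpoint)
    candidate = (1234.5, '192.0.2.2:8080')   # a potential leader with the same expiration time
    # exactly one of the two peers defers to the other, instead of both (or neither) waiting
    assert (candidate > declared) != (declared > candidate)
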
+ 224 - 0
hivemind/client/averaging/partition.py

@@ -0,0 +1,224 @@
+"""
+Auxiliary data structures for AllReduceRunner
+"""
+import asyncio
+from typing import Sequence, AsyncIterable, Tuple, Optional, TypeVar, Union, AsyncIterator
+from collections import deque
+
+import torch
+import numpy as np
+
+from hivemind.proto.runtime_pb2 import CompressionType, Tensor
+from hivemind.utils.compression import serialize_torch_tensor, get_nbytes_per_value
+from hivemind.utils.asyncio import amap_in_executor
+
+
+T = TypeVar('T')
+DEFAULT_PART_SIZE_BYTES = 2 ** 20
+
+
+class TensorPartContainer:
+    """
+    Auxiliary data structure for averaging, responsible for splitting tensors into parts and reassembling them.
+    The class is designed to avoid excessive memory allocation and run all heavy computation in background
+    :param tensors: local tensors to be split and aggregated
+    :param peer_fractions: for each peer, a target fraction of vector elements that this peer should average
+    :param compression_type: optionally compress tensors with this compression algorithm before sending them to peers
+    :param part_size_bytes: greedily split tensors into parts of up to this many bytes (after compression)
+    :param prefetch: when compressing, pre-compute this many compressed tensors in background
+    """
+
+    def __init__(self, tensors: Sequence[torch.Tensor], peer_fractions: Sequence[float],
+                 compression_type: Union[type(CompressionType), Sequence[type(CompressionType)]] = CompressionType.NONE,
+                 part_size_bytes: int = 2 ** 20, prefetch: int = 1):
+        if not isinstance(compression_type, Sequence):
+            compression_type = [compression_type] * len(tensors)
+        assert len(compression_type) == len(tensors), "compression types do not match the number of tensors"
+        self.local_tensors, self.peer_fractions, self.group_size = tensors, peer_fractions, len(peer_fractions)
+        self.compression_type, self.part_size_bytes, self.prefetch = compression_type, part_size_bytes, prefetch
+        self.total_size = sum(tensor.numel() for tensor in tensors)
+        self._input_parts_by_peer = [deque() for _ in range(self.group_size)]
+        self._output_parts_by_peer = [deque() for _ in range(self.group_size)]
+        self._inputs_consumed_by_peer = [False for _ in range(self.group_size)]
+        self._output_part_available = [asyncio.Event() for _ in range(self.group_size)]
+        self._outputs_registered_by_peer = [0 for _ in range(self.group_size)]
+        self._outputs_consumed = False
+        self.finished = asyncio.Event()
+        self.num_parts_by_tensor = []
+
+        # split tensor parts in proportion to target_size_by_peer
+        current_length = 0
+        current_peer_index = 0
+        pivots = (np.cumsum(peer_fractions) / np.sum(peer_fractions) * self.total_size).astype(np.int64)
+        pivots[-1] = self.total_size
+
+        for tensor, tensor_compression in zip(self.local_tensors, compression_type):
+            part_size_values = int(part_size_bytes / get_nbytes_per_value(tensor.dtype, tensor_compression))
+            tensor_parts = tensor.detach().view(-1).split(part_size_values)
+            self.num_parts_by_tensor.append(len(tensor_parts))
+            for part in tensor_parts:
+                if current_length + len(part) > pivots[current_peer_index]:
+                    # switch to next peer; if a part lands between parts of two or
+                    # more peers, assign that part to the peer with highest intersection
+                    prev_peer_index = current_peer_index
+                    peer_intersections = [pivots[current_peer_index] - current_length]
+                    while current_length + len(part) > pivots[current_peer_index]:
+                        current_peer_index += 1
+                        current_peer_part_end = min(current_length + len(part), pivots[current_peer_index])
+                        peer_intersections.append(current_peer_part_end - pivots[current_peer_index - 1])
+                    assigned_peer_index = prev_peer_index + np.argmax(peer_intersections)
+                    self._input_parts_by_peer[assigned_peer_index].append((part, tensor_compression))
+                else:
+                    self._input_parts_by_peer[current_peer_index].append((part, tensor_compression))
+                current_length += len(part)
+
+        assert current_length == self.total_size
+        self.num_parts_by_peer = tuple(len(parts) for parts in self._input_parts_by_peer)
+
+    @torch.no_grad()
+    def get_raw_input_parts(self, peer_index: int) -> Tuple[torch.Tensor, ...]:
+        """ get non-serialized tensor parts for a peer at a given index """
+        assert not self._inputs_consumed_by_peer[peer_index], "input parts of a given peer are already deallocated."
+        self._inputs_consumed_by_peer[peer_index] = True
+        input_parts = tuple(part for part, compression in self._input_parts_by_peer[peer_index])
+        self._input_parts_by_peer[peer_index].clear()
+        return input_parts
+
+    @torch.no_grad()
+    async def iterate_input_parts_for(self, peer_index: int) -> AsyncIterator[Tensor]:
+        """ iterate serialized tensor parts for a peer at a given index. Run serialization in background. """
+        assert not self._inputs_consumed_by_peer[peer_index], "input parts of a given peer are already deallocated."
+        self._inputs_consumed_by_peer[peer_index] = True
+
+        async def _aiterate_parts():
+            for _ in range(self.num_parts_by_peer[peer_index]):
+                yield self._input_parts_by_peer[peer_index].popleft()
+
+        async for serialized_part in amap_in_executor(lambda x_and_compr: serialize_torch_tensor(*x_and_compr),
+                                                      _aiterate_parts(), max_prefetch=self.prefetch):
+            yield serialized_part
+
+    def register_processed_part(self, peer_index: int, part_index: int, part: torch.Tensor):
+        """
+        register next-in-line part of results received from a given peer for use in iterate_output_tensors
+        depending on the algorithm, processed part is an average, difference from average or another aggregation
+        """
+        if part_index != self._outputs_registered_by_peer[peer_index]:
+            raise ValueError(f"Could not register part #{part_index} from peer #{peer_index}, "
+                             f"expected part index: {self._outputs_registered_by_peer[peer_index]}")
+        self._output_parts_by_peer[peer_index].append(part)
+        self._outputs_registered_by_peer[peer_index] += 1
+        self._output_part_available[peer_index].set()
+
+    async def iterate_output_tensors(self) -> AsyncIterable[torch.Tensor]:
+        """ iterate over the outputs of averaging (whether they are average, delta or other aggregation result) """
+        assert not self._outputs_consumed, "output tensors are already iterated and no longer available."
+        self._outputs_consumed = True
+        peer_index = num_parts_processed = 0
+        for tensor_index in range(len(self.local_tensors)):
+            tensor_parts = []
+            while len(tensor_parts) < self.num_parts_by_tensor[tensor_index]:
+                if num_parts_processed >= self.num_parts_by_peer[peer_index]:
+                    num_parts_processed = 0
+                    peer_index += 1
+                    continue
+                if not self._output_parts_by_peer[peer_index]:
+                    self._output_part_available[peer_index].clear()
+                    await self._output_part_available[peer_index].wait()
+                    if self.finished.is_set():
+                        raise AllreduceException("All-reduce was terminated during iteration.")
+
+                tensor_parts.append(self._output_parts_by_peer[peer_index].popleft())
+                num_parts_processed += 1
+            tensor = torch.cat(tensor_parts)
+            del tensor_parts
+            yield tensor.reshape(self.local_tensors[tensor_index].shape)
+
+    def __del__(self):
+        self.finalize()
+
+    def finalize(self):
+        """ terminate all iterators, delete intermediate data """
+        if not self.finished.is_set():
+            for peer_index in range(self.group_size):
+                self._inputs_consumed_by_peer[peer_index] = True
+                self._input_parts_by_peer[peer_index].clear()
+                self._output_parts_by_peer[peer_index].clear()
+                self._output_part_available[peer_index].set()
+            self._outputs_consumed = True
+            self.finished.set()
+
+
+class TensorPartReducer:
+    """
+    Auxiliary data structure responsible for running asynchronous all-reduce
+    :param part_shapes: a sequence of shapes of torch tensors that will be averaged by this reducer
+    :param num_senders: total number of peers in a given all-reduce group that will send gradients
+    :param weights: relative importance of each sender, used for weighted average (default = equal weights)
+    :note: even if local peer is not sending data, local parts will be used for shape information
+    """
+
+    def __init__(self, part_shapes: Sequence[torch.Size], num_senders: int,
+                 weights: Optional[Sequence[float]] = None):
+        self.part_shapes, self.num_senders, self.num_parts = part_shapes, num_senders, len(part_shapes)
+        self.weights = tuple(weights or (1 for _ in range(num_senders)))
+        assert len(self.weights) == self.num_senders, "The number of weights is inconsistent with num_senders"
+        assert all(isinstance(weight, (int, float)) for weight in self.weights)
+        self.current_part_index = -1  # index of the tensor part that is currently being accumulated (-1 = not started)
+        self.current_part_accumulated_from = 0  # number of peers from which the current part was accumulated
+        self.accumulator = None  # this will contain the sum of current tensor part from group peers
+        self.denominator = 0.0  # total weight accumulated from all peers for current part
+        self.current_part_future = asyncio.Future()
+        self.finished = asyncio.Event()
+        self.reset_accumulators()
+
+    def reset_accumulators(self):
+        """ (re)create averaging buffers for the next part in line, prepopulate with local tensor part """
+        assert self.current_part_accumulated_from == self.num_senders or self.current_part_index == -1
+        if self.current_part_index >= self.num_parts - 1:
+            self.finalize()
+            return
+
+        self.current_part_index += 1
+        self.current_part_accumulated_from = 0
+        self.current_part_future = asyncio.Future()
+        self.accumulator = torch.zeros(self.part_shapes[self.current_part_index])
+        self.denominator = 0.0
+
+    async def accumulate_part(self, sender_index: int, part_index: int, tensor_part: torch.Tensor) -> torch.Tensor:
+        """ Add vector part to accumulator, wait for all other vectors to be added, then return the average part """
+        assert 0 <= sender_index < self.num_senders, "invalid sender index"
+        assert 0 <= part_index < self.num_parts, "invalid part index"
+
+        while part_index > self.current_part_index:
+            # wait for previous parts to finish processing ...
+            await asyncio.wait({self.current_part_future, self.finished.wait()}, return_when=asyncio.FIRST_COMPLETED)
+            if self.finished.is_set():
+                raise AllreduceException(f"attempted to aggregate part in a finalized {self.__class__.__name__}")
+        assert part_index == self.current_part_index
+
+        current_part_future = self.current_part_future
+
+        self.accumulator.add_(tensor_part, alpha=self.weights[sender_index])
+        self.denominator += self.weights[sender_index]
+        self.current_part_accumulated_from += 1
+
+        assert self.current_part_accumulated_from <= self.num_senders
+        if self.current_part_accumulated_from == self.num_senders:
+            current_part_future.set_result(self.accumulator.div_(self.denominator))
+            self.reset_accumulators()
+        return await current_part_future
+
+    def finalize(self):
+        if not self.finished.is_set():
+            if hasattr(self, 'current_part_future'):
+                self.current_part_future.cancel()
+                del self.accumulator
+            self.finished.set()
+
+    def __del__(self):
+        self.finalize()
+
+
+class AllreduceException(Exception):
+    """ A special exception that is raised when allreduce can't continue normally (e.g. disconnected/protocol error) """

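To make the division of labor between the two classes concrete, here is a minimal single-process sketch: one local "peer" splits its tensors, reduces each part, and reassembles the deltas. In the real AllReduceRunner the parts travel over gRPC between peers; everything below runs locally and is illustrative only.

    import asyncio
    import torch
    from hivemind.client.averaging.partition import TensorPartContainer, TensorPartReducer

    async def local_roundtrip():
        tensors = [torch.randn(1000), torch.randn(500)]
        container = TensorPartContainer(tensors, peer_fractions=(1.0,))  # a single peer owns everything
        parts = container.get_raw_input_parts(peer_index=0)
        reducer = TensorPartReducer([part.shape for part in parts], num_senders=1)

        for part_index, part in enumerate(parts):
            averaged = await reducer.accumulate_part(0, part_index, part)
            container.register_processed_part(0, part_index, averaged - part)

        deltas = [delta async for delta in container.iterate_output_tensors()]
        # with a single sender, the average equals the input, so every delta is zero
        assert all(torch.equal(delta, torch.zeros_like(delta)) for delta in deltas)

    asyncio.run(local_roundtrip())
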
+ 37 - 3
hivemind/optim/collaborative.py

@@ -191,7 +191,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         with self.lock_local_progress:
             self.local_samples_accumulated += batch_size
             self.local_steps_accumulated += 1
-            self.performance_ema.update(num_processed=self.batch_size_per_step)
+            self.performance_ema.update(num_processed=batch_size)
             self.should_report_progress.set()
 
         if not self.collaboration_state.ready_for_step:
@@ -232,9 +232,43 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             self.collaboration_state_updated.set()
             self.update_scheduler()
 
-            logger.log(self.status_loglevel, f"Optimizer step: done!")
+        logger.log(self.status_loglevel, f"Optimizer step: done!")
 
-            return group_info
+        return group_info
+
+    def step_aux(self, **kwargs):
+        """
+        Find and assist other peers in averaging without sending local gradients.
+
+        :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
+        """
+
+        if not self.collaboration_state.ready_for_step:
+            return
+
+        logger.log(self.status_loglevel,
+                   f"Beginning global optimizer step {self.collaboration_state.optimizer_step}")
+        self.collaboration_state = self.fetch_collaboration_state()
+        self.collaboration_state_updated.set()
+
+        with self.lock_collaboration_state:
+            # unlike .step, auxiliary peers hold no local gradient accumulators, so there is nothing to normalize here
+            current_step, group_info = self.averager.local_step, None
+            try:
+                group_info = self.averager.step(timeout=self.averaging_timeout, **kwargs)
+                if group_info:
+                    logger.log(self.status_loglevel,
+                               f"Averaged tensors successfully with {len(group_info)} peers")
+            except BaseException as e:
+                logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}.")
+
+            self.collaboration_state.register_step(current_step + 1)
+            self.averager.local_step = current_step + 1
+            self.collaboration_state_updated.set()
+
+        logger.log(self.status_loglevel, f"Optimizer step: done!")
+
+        return group_info
 
     def _grad_buffers(self) -> Iterator[torch.Tensor]:
         """ pytorch-internal gradient buffers """

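For completeness, an auxiliary peer would typically drive step_aux in a simple loop. The sketch below is illustrative: constructing the CollaborativeOptimizer (dht, averager arguments, etc.) is unchanged and omitted here, and the polling interval is arbitrary.

    import time

    def run_aux_peer(collaborative_optimizer, poll_interval: float = 1.0):
        # assist other peers in averaging without contributing local gradients
        while True:
            collaborative_optimizer.step_aux()  # returns immediately unless the collaboration is ready_for_step
            time.sleep(poll_interval)
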
+ 1 - 0
hivemind/p2p/__init__.py

@@ -0,0 +1 @@
+from hivemind.p2p.p2p_daemon import P2P

+ 377 - 0
hivemind/p2p/p2p_daemon.py

@@ -0,0 +1,377 @@
+import asyncio
+from copy import deepcopy
+from dataclasses import dataclass
+from importlib.resources import path
+from subprocess import Popen
+from typing import List, Optional
+
+import google.protobuf
+from multiaddr import Multiaddr
+
+import hivemind.hivemind_cli as cli
+import hivemind.p2p.p2p_daemon_bindings.p2pclient as p2pclient
+from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, StreamInfo
+from hivemind.proto import p2pd_pb2
+from hivemind.utils import MSGPackSerializer
+from hivemind.utils.logging import get_logger
+from hivemind.utils.networking import find_open_port
+
+logger = get_logger(__name__)
+
+
+P2PD_FILENAME = 'p2pd'
+NUM_RETRIES = 3
+RETRY_DELAY = 0.4
+
+
+class P2PInterruptedError(Exception):
+    pass
+
+
+@dataclass(frozen=False)
+class P2PContext(object):
+    id: str
+    port: int
+    handle_name: str
+    peer_id: PeerID = None
+    peer_addr: Multiaddr = None
+
+
+class P2P:
+    """
+    Forks a child process and executes the p2pd command with the given arguments.
+    Can be used for peer-to-peer communication and remote procedure calls.
+    Sends SIGKILL to the child process in the destructor.
+    """
+
+    HEADER_LEN = 8
+    BYTEORDER = 'big'
+    PB_HEADER_LEN = 1
+    RESULT_MESSAGE = b'\x00'
+    ERROR_MESSAGE = b'\x01'
+    DHT_MODE_MAPPING = {
+        'dht': {'dht': 1},
+        'dht_server': {'dhtServer': 1},
+        'dht_client': {'dhtClient': 1},
+    }
+    FORCE_REACHABILITY_MAPPING = {
+        'public': {'forceReachabilityPublic': 1},
+        'private': {'forceReachabilityPrivate': 1},
+    }
+
+    def __init__(self):
+        self._child = None
+        self._alive = False
+        self._listen_task = None
+        self._server_stopped = asyncio.Event()
+
+    @classmethod
+    async def create(cls, *args, quic: bool = True, tls: bool = True, conn_manager: bool = True,
+                     dht_mode: str = 'dht_server', force_reachability: Optional[str] = None,
+                     nat_port_map: bool = True, auto_nat: bool = True, bootstrap: bool = False,
+                     bootstrap_peers: Optional[List[str]] = None, use_global_ipfs: bool = False, host_port: int = None,
+                     daemon_listen_port: int = None, use_relay: bool = True, use_relay_hop: bool = False,
+                     use_relay_discovery: bool = False, use_auto_relay: bool = False, relay_hop_limit: int = 0, **kwargs):
+        """
+        Start a new p2pd process and connect to it.
+        :param args:
+        :param quic: Enables the QUIC transport
+        :param tls: Enables TLS1.3 channel security protocol
+        :param conn_manager: Enables the Connection Manager
+        :param dht_mode: DHT mode (dht_client/dht_server/dht)
+        :param force_reachability: Force reachability mode (public/private)
+        :param nat_port_map: Enables NAT port mapping
+        :param auto_nat: Enables the AutoNAT service
+        :param bootstrap: Connects to bootstrap peers and bootstraps the dht if enabled
+        :param bootstrap_peers: List of bootstrap peers; defaults to the IPFS DHT peers
+        :param use_global_ipfs: Bootstrap to global ipfs (works only if bootstrap=True and bootstrap_peers=None)
+        :param host_port: port for p2p network
+        :param daemon_listen_port: port for connection daemon and client binding
+        :param use_relay: enables circuit relay
+        :param use_relay_hop: enables hop for relay
+        :param use_relay_discovery: enables passive discovery for relay
+        :param use_auto_relay: enables autorelay
+        :param relay_hop_limit: sets the hop limit for hop relays
+        :param kwargs:
+        :return: new wrapper for p2p daemon
+        """
+
+        assert not (bootstrap and bootstrap_peers is None and not use_global_ipfs), \
+            'Trying to bootstrap without a list of bootstrap peers. ' \
+            'This is dangerous: p2pd would connect to the global IPFS network, which is unstable. ' \
+            'If you really want this, pass use_global_ipfs=True'
+        assert not (bootstrap_peers is not None and use_global_ipfs), \
+            'Non-empty bootstrap_peers and use_global_ipfs=True are incompatible. ' \
+            'Choose one option: your own peer list (preferable) or the global IPFS network (unstable)'
+
+        self = cls()
+        with path(cli, P2PD_FILENAME) as p:
+            p2pd_path = p
+        bootstrap_peers = cls._make_bootstrap_peers(bootstrap_peers)
+        dht = cls.DHT_MODE_MAPPING.get(dht_mode, {'dht': 0})
+        force_reachability = cls.FORCE_REACHABILITY_MAPPING.get(force_reachability, {})
+        proc_args = self._make_process_args(
+            str(p2pd_path), *args,
+            quic=quic, tls=tls, connManager=conn_manager,
+            natPortMap=nat_port_map, autonat=auto_nat,
+            relay=use_relay, relayHop=use_relay_hop, relayDiscovery=use_relay_discovery,
+            autoRelay=use_auto_relay, relayHopLimit=relay_hop_limit,
+            b=bootstrap, **{**bootstrap_peers, **dht, **force_reachability, **kwargs})
+        self._assign_daemon_ports(host_port, daemon_listen_port)
+
+        for try_count in range(NUM_RETRIES):
+            try:
+                self._initialize(proc_args)
+                await self._wait_for_client(RETRY_DELAY * (2 ** try_count))
+                break
+            except Exception as e:
+                logger.debug(f"Failed to initialize p2p daemon: {e}")
+                self._terminate()
+                if try_count == NUM_RETRIES - 1:
+                    raise
+                self._assign_daemon_ports()
+
+        return self
+
+    @classmethod
+    async def replicate(cls, daemon_listen_port: int, host_port: int):
+        """
+        Connect to existing p2p daemon
+        :param daemon_listen_port: port for connection daemon and client binding
+        :param host_port: port for p2p network
+        :return: new wrapper for existing p2p daemon
+        """
+
+        self = cls()
+        # There is no child under control
+        # Use external already running p2pd
+        self._child = None
+        self._alive = True
+        self._assign_daemon_ports(host_port, daemon_listen_port)
+        self._client_listen_port = find_open_port()
+        self._client = p2pclient.Client(
+            Multiaddr(f'/ip4/127.0.0.1/tcp/{self._daemon_listen_port}'),
+            Multiaddr(f'/ip4/127.0.0.1/tcp/{self._client_listen_port}'))
+        await self._wait_for_client()
+        return self
+
+    async def wait_for_at_least_n_peers(self, n_peers, attempts=3, delay=1):
+        for _ in range(attempts):
+            peers = await self._client.list_peers()
+            if len(peers) >= n_peers:
+                return
+            await asyncio.sleep(delay)
+
+        raise RuntimeError('Not enough peers')
+
+    def _initialize(self, proc_args: List[str]) -> None:
+        proc_args = deepcopy(proc_args)
+        proc_args.extend(self._make_process_args(
+            hostAddrs=f'/ip4/0.0.0.0/tcp/{self._host_port},/ip4/0.0.0.0/udp/{self._host_port}/quic',
+            listen=f'/ip4/127.0.0.1/tcp/{self._daemon_listen_port}'
+        ))
+        self._child = Popen(args=proc_args, encoding="utf8")
+        self._alive = True
+        self._client_listen_port = find_open_port()
+        self._client = p2pclient.Client(
+            Multiaddr(f'/ip4/127.0.0.1/tcp/{self._daemon_listen_port}'),
+            Multiaddr(f'/ip4/127.0.0.1/tcp/{self._client_listen_port}'))
+
+    async def _wait_for_client(self, delay=0):
+        await asyncio.sleep(delay)
+        encoded = await self._client.identify()
+        self.id = encoded[0].to_base58()
+
+    def _assign_daemon_ports(self, host_port=None, daemon_listen_port=None):
+        if host_port is None:
+            host_port = find_open_port()
+        if daemon_listen_port is None:
+            daemon_listen_port = find_open_port()
+            while daemon_listen_port == host_port:
+                daemon_listen_port = find_open_port()
+
+        self._host_port, self._daemon_listen_port = host_port, daemon_listen_port
+
+    @staticmethod
+    async def send_raw_data(byte_str, writer):
+        request = len(byte_str).to_bytes(P2P.HEADER_LEN, P2P.BYTEORDER) + byte_str
+        writer.write(request)
+
+    @staticmethod
+    async def send_msgpack(data, writer):
+        raw_data = MSGPackSerializer.dumps(data)
+        await P2P.send_raw_data(raw_data, writer)
+
+    @staticmethod
+    async def send_protobuf(protobuf, out_proto_type, writer):
+        if type(protobuf) != out_proto_type:
+            raise TypeError('Unary handler returned protobuf of wrong type.')
+        if out_proto_type == p2pd_pb2.RPCError:
+            await P2P.send_raw_data(P2P.ERROR_MESSAGE, writer)
+        else:
+            await P2P.send_raw_data(P2P.RESULT_MESSAGE, writer)
+
+        await P2P.send_raw_data(protobuf.SerializeToString(), writer)
+
+    @staticmethod
+    async def receive_raw_data(reader: asyncio.StreamReader, header_len=HEADER_LEN):
+        header = await reader.readexactly(header_len)
+        content_length = int.from_bytes(header, P2P.BYTEORDER)
+        data = await reader.readexactly(content_length)
+        return data
+
+    @staticmethod
+    async def receive_msgpack(reader):
+        return MSGPackSerializer.loads(await P2P.receive_raw_data(reader))
+
+    @staticmethod
+    async def receive_protobuf(in_proto_type, reader):
+        msg_type = await P2P.receive_raw_data(reader)
+        if msg_type == P2P.RESULT_MESSAGE:
+            protobuf = in_proto_type()
+            protobuf.ParseFromString(await P2P.receive_raw_data(reader))
+            return protobuf, None
+        elif msg_type == P2P.ERROR_MESSAGE:
+            protobuf = p2pd_pb2.RPCError()
+            protobuf.ParseFromString(await P2P.receive_raw_data(reader))
+            return None, protobuf
+        else:
+            raise TypeError('Invalid Protobuf message type')
+
+    @staticmethod
+    def _handle_stream(handle):
+        async def do_handle_stream(stream_info, reader, writer):
+            try:
+                request = await P2P.receive_raw_data(reader)
+            except asyncio.IncompleteReadError:
+                logger.debug("Incomplete read while receiving request from peer")
+                writer.close()
+                return
+            try:
+                result = handle(request)
+                await P2P.send_raw_data(result, writer)
+            finally:
+                writer.close()
+
+        return do_handle_stream
+
+    @staticmethod
+    def _handle_unary_stream(handle, context, in_proto_type, out_proto_type):
+        async def watchdog(reader: asyncio.StreamReader):
+            await reader.read(n=1)
+            raise P2PInterruptedError()
+
+        async def do_handle_unary_stream(
+                stream_info: StreamInfo,
+                reader: asyncio.StreamReader,
+                writer: asyncio.StreamWriter) -> None:
+            try:
+                try:
+                    request = await P2P.receive_protobuf(in_proto_type, reader)
+                except asyncio.IncompleteReadError:
+                    logger.debug("Incomplete read while receiving request from peer")
+                    return
+                except google.protobuf.message.DecodeError as error:
+                    logger.exception(error)
+                    return
+
+                context.peer_id, context.peer_addr = stream_info.peer_id, stream_info.addr
+                done, pending = await asyncio.wait([watchdog(reader), handle(request, context)],
+                                                   return_when=asyncio.FIRST_COMPLETED)
+                try:
+                    result = done.pop().result()
+                    await P2P.send_protobuf(result, out_proto_type, writer)
+                except P2PInterruptedError:
+                    pass
+                except Exception as exc:
+                    error = p2pd_pb2.RPCError(message=str(exc))
+                    await P2P.send_protobuf(error, p2pd_pb2.RPCError, writer)
+                finally:
+                    pending_task = pending.pop()
+                    pending_task.cancel()
+                    try:
+                        await pending_task
+                    except asyncio.CancelledError:
+                        pass
+            finally:
+                writer.close()
+
+        return do_handle_unary_stream
+
+    def start_listening(self):
+        async def listen():
+            async with self._client.listen():
+                await self._server_stopped.wait()
+
+        self._listen_task = asyncio.create_task(listen())
+
+    async def stop_listening(self):
+        if self._listen_task is not None:
+            self._server_stopped.set()
+            self._listen_task.cancel()
+            try:
+                await self._listen_task
+            except asyncio.CancelledError:
+                self._listen_task = None
+                self._server_stopped.clear()
+
+    async def add_stream_handler(self, name, handle):
+        if self._listen_task is None:
+            self.start_listening()
+        await self._client.stream_handler(name, self._handle_stream(handle))
+
+    async def add_unary_handler(self, name, handle, in_proto_type, out_proto_type):
+        if self._listen_task is None:
+            self.start_listening()
+        context = P2PContext(id=self.id, port=self._host_port, handle_name=name)
+        await self._client.stream_handler(
+            name, P2P._handle_unary_stream(handle, context, in_proto_type, out_proto_type))
+
+    async def call_peer_handler(self, peer_id, handler_name, input_data):
+        libp2p_peer_id = PeerID.from_base58(peer_id)
+        stream_info, reader, writer = await self._client.stream_open(libp2p_peer_id, (handler_name,))
+        try:
+            await P2P.send_raw_data(input_data, writer)
+            return await P2P.receive_raw_data(reader)
+        finally:
+            writer.close()
+
+    def __del__(self):
+        self._terminate()
+
+    @property
+    def is_alive(self):
+        return self._alive
+
+    async def shutdown(self):
+        await asyncio.get_event_loop().run_in_executor(None, self._terminate)
+
+    def _terminate(self):
+        self._alive = False
+        if self._child is not None and self._child.poll() is None:
+            self._child.kill()
+            self._child.wait()
+
+    @staticmethod
+    def _make_process_args(*args, **kwargs) -> List[str]:
+        proc_args = []
+        proc_args.extend(
+            str(entry) for entry in args
+        )
+        proc_args.extend(
+            f'-{key}={P2P._convert_process_arg_type(value)}' if value is not None else f'-{key}'
+            for key, value in kwargs.items()
+        )
+        return proc_args
+
+    @staticmethod
+    def _convert_process_arg_type(val):
+        if isinstance(val, bool):
+            return 1 if val else 0
+        return val
+
+    @staticmethod
+    def _make_bootstrap_peers(nodes):
+        if nodes is None:
+            return {}
+        return {'bootstrapPeers': ','.join(nodes)}

+ 0 - 0
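To show the intended byte-level round trip, here is an illustrative echo example built on the wrapper above. Connecting the two daemons via bootstrap_peers, the multiaddr format, and the use of the private _host_port attribute are assumptions of this sketch, not something prescribed by the patch.

    import asyncio
    from hivemind.p2p import P2P

    async def echo_demo():
        server = await P2P.create()
        # assumed libp2p multiaddr format; in practice, take the address from your own configuration
        server_maddr = f'/ip4/127.0.0.1/tcp/{server._host_port}/p2p/{server.id}'
        client = await P2P.create(bootstrap=True, bootstrap_peers=[server_maddr])

        await server.add_stream_handler('echo', lambda request: request)  # raw-bytes handler
        await client.wait_for_at_least_n_peers(1)

        reply = await client.call_peer_handler(server.id, 'echo', b'hello')
        assert reply == b'hello'

        await server.stop_listening()
        await asyncio.gather(server.shutdown(), client.shutdown())

    asyncio.run(echo_demo())
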
hivemind/p2p/p2p_daemon_bindings/__init__.py


+ 210 - 0
hivemind/p2p/p2p_daemon_bindings/control.py

@@ -0,0 +1,210 @@
+"""
+Originally taken from: https://github.com/mhchia/py-libp2p-daemon-bindings
+Licence: MIT
+Author: Kevin Mai-Husan Chia
+"""
+
+import asyncio
+from contextlib import asynccontextmanager
+from typing import (AsyncIterator, Awaitable, Callable, Dict, Iterable,
+                    Sequence, Tuple)
+
+from multiaddr import Multiaddr, protocols
+
+from hivemind.p2p.p2p_daemon_bindings.datastructures import (PeerID, PeerInfo,
+                                                             StreamInfo)
+from hivemind.p2p.p2p_daemon_bindings.utils import (DispatchFailure,
+                                                    raise_if_failed,
+                                                    read_pbmsg_safe,
+                                                    write_pbmsg)
+from hivemind.proto import p2pd_pb2 as p2pd_pb
+from hivemind.utils.logging import get_logger
+
+StreamHandler = Callable[[StreamInfo, asyncio.StreamReader, asyncio.StreamWriter], Awaitable[None]]
+
+SUPPORT_CONN_PROTOCOLS = (
+    protocols.P_IP4,
+    # protocols.P_IP6,
+    protocols.P_UNIX,
+)
+# materialized as a tuple so it can be reused and printed in error messages below
+SUPPORTED_PROTOS = tuple(
+    protocols.protocol_with_code(proto) for proto in SUPPORT_CONN_PROTOCOLS
+)
+logger = get_logger(__name__)
+
+
+def parse_conn_protocol(maddr: Multiaddr) -> int:
+    proto_codes = set(proto.code for proto in maddr.protocols())
+    proto_cand = proto_codes.intersection(SUPPORT_CONN_PROTOCOLS)
+    if len(proto_cand) != 1:
+        raise ValueError(
+            f"connection protocol should be only one protocol out of {SUPPORTED_PROTOS}"
+            f", maddr={maddr}"
+        )
+    return tuple(proto_cand)[0]
+
+
+class DaemonConnector:
+    DEFAULT_CONTROL_MADDR = "/unix/tmp/p2pd.sock"
+
+    def __init__(self, control_maddr: Multiaddr = Multiaddr(DEFAULT_CONTROL_MADDR)) -> None:
+        self.control_maddr = control_maddr
+        self.proto_code = parse_conn_protocol(self.control_maddr)
+
+    async def open_connection(self) -> (asyncio.StreamReader, asyncio.StreamWriter):
+        if self.proto_code == protocols.P_UNIX:
+            control_path = self.control_maddr.value_for_protocol(protocols.P_UNIX)
+            logger.debug(f"DaemonConnector {self} opens connection to {self.control_maddr}")
+            return await asyncio.open_unix_connection(control_path)
+        elif self.proto_code == protocols.P_IP4:
+            host = self.control_maddr.value_for_protocol(protocols.P_IP4)
+            port = int(self.control_maddr.value_for_protocol(protocols.P_TCP))
+            return await asyncio.open_connection(host, port)
+        else:
+            raise ValueError(
+                f"Protocol not supported: {protocols.protocol_with_code(self.proto_code)}"
+            )
+
+
+class ControlClient:
+    DEFAULT_LISTEN_MADDR = "/unix/tmp/p2pclient.sock"
+
+    def __init__(
+            self, daemon_connector: DaemonConnector, listen_maddr: Multiaddr = Multiaddr(DEFAULT_LISTEN_MADDR)
+    ) -> None:
+        self.listen_maddr = listen_maddr
+        self.daemon_connector = daemon_connector
+        self.handlers: Dict[str, StreamHandler] = {}
+
+    async def _handler(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
+        pb_stream_info = p2pd_pb.StreamInfo()  # type: ignore
+        await read_pbmsg_safe(reader, pb_stream_info)
+        stream_info = StreamInfo.from_protobuf(pb_stream_info)
+        logger.debug(f"New incoming stream: {stream_info}")
+        try:
+            handler = self.handlers[stream_info.proto]
+        except KeyError as e:
+            # should never enter here... daemon should reject the stream for us.
+            writer.close()
+            raise DispatchFailure(e)
+        await handler(stream_info, reader, writer)
+
+    @asynccontextmanager
+    async def listen(self) -> AsyncIterator["ControlClient"]:
+        proto_code = parse_conn_protocol(self.listen_maddr)
+        if proto_code == protocols.P_UNIX:
+            listen_path = self.listen_maddr.value_for_protocol(protocols.P_UNIX)
+            server = await asyncio.start_unix_server(self._handler, path=listen_path)
+        elif proto_code == protocols.P_IP4:
+            host = self.listen_maddr.value_for_protocol(protocols.P_IP4)
+            port = int(self.listen_maddr.value_for_protocol(protocols.P_TCP))
+            server = await asyncio.start_server(self._handler, port=port, host=host)
+        else:
+            raise ValueError(
+                f"Protocol not supported: {protocols.protocol_with_code(proto_code)}"
+            )
+
+        async with server:
+            logger.info(f"DaemonConnector {self} starts listening to {self.listen_maddr}")
+            yield self
+
+        logger.info(f"ControlClient {self} stopped listening")
+
+    async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
+        reader, writer = await self.daemon_connector.open_connection()
+        req = p2pd_pb.Request(type=p2pd_pb.Request.IDENTIFY)
+        await write_pbmsg(writer, req)
+
+        resp = p2pd_pb.Response()  # type: ignore
+        await read_pbmsg_safe(reader, resp)
+        writer.close()
+
+        raise_if_failed(resp)
+        peer_id_bytes = resp.identify.id
+        maddrs_bytes = resp.identify.addrs
+
+        maddrs = tuple(Multiaddr(maddr_bytes) for maddr_bytes in maddrs_bytes)
+        peer_id = PeerID(peer_id_bytes)
+
+        return peer_id, maddrs
+
+    async def connect(self, peer_id: PeerID, maddrs: Iterable[Multiaddr]) -> None:
+        reader, writer = await self.daemon_connector.open_connection()
+
+        maddrs_bytes = [i.to_bytes() for i in maddrs]
+        connect_req = p2pd_pb.ConnectRequest(
+            peer=peer_id.to_bytes(), addrs=maddrs_bytes
+        )
+        req = p2pd_pb.Request(type=p2pd_pb.Request.CONNECT, connect=connect_req)
+        await write_pbmsg(writer, req)
+
+        resp = p2pd_pb.Response()  # type: ignore
+        await read_pbmsg_safe(reader, resp)
+        writer.close()
+        raise_if_failed(resp)
+
+    async def list_peers(self) -> Tuple[PeerInfo, ...]:
+        req = p2pd_pb.Request(type=p2pd_pb.Request.LIST_PEERS)
+        reader, writer = await self.daemon_connector.open_connection()
+        await write_pbmsg(writer, req)
+        resp = p2pd_pb.Response()  # type: ignore
+        await read_pbmsg_safe(reader, resp)
+        writer.close()
+        raise_if_failed(resp)
+
+        peers = tuple(PeerInfo.from_protobuf(pinfo) for pinfo in resp.peers)
+        return peers
+
+    async def disconnect(self, peer_id: PeerID) -> None:
+        disconnect_req = p2pd_pb.DisconnectRequest(peer=peer_id.to_bytes())
+        req = p2pd_pb.Request(
+            type=p2pd_pb.Request.DISCONNECT, disconnect=disconnect_req
+        )
+        reader, writer = await self.daemon_connector.open_connection()
+        await write_pbmsg(writer, req)
+        resp = p2pd_pb.Response()  # type: ignore
+        await read_pbmsg_safe(reader, resp)
+        writer.close()
+        raise_if_failed(resp)
+
+    async def stream_open(
+        self, peer_id: PeerID, protocols: Sequence[str]
+    ) -> Tuple[StreamInfo, asyncio.StreamReader, asyncio.StreamWriter]:
+        reader, writer = await self.daemon_connector.open_connection()
+
+        stream_open_req = p2pd_pb.StreamOpenRequest(
+            peer=peer_id.to_bytes(), proto=list(protocols)
+        )
+        req = p2pd_pb.Request(
+            type=p2pd_pb.Request.STREAM_OPEN, streamOpen=stream_open_req
+        )
+        await write_pbmsg(writer, req)
+
+        resp = p2pd_pb.Response()  # type: ignore
+        await read_pbmsg_safe(reader, resp)
+        raise_if_failed(resp)
+
+        pb_stream_info = resp.streamInfo
+        stream_info = StreamInfo.from_protobuf(pb_stream_info)
+
+        return stream_info, reader, writer
+
+    async def stream_handler(self, proto: str, handler_cb: StreamHandler) -> None:
+        reader, writer = await self.daemon_connector.open_connection()
+
+        listen_path_maddr_bytes = self.listen_maddr.to_bytes()
+        stream_handler_req = p2pd_pb.StreamHandlerRequest(
+            addr=listen_path_maddr_bytes, proto=[proto]
+        )
+        req = p2pd_pb.Request(
+            type=p2pd_pb.Request.STREAM_HANDLER, streamHandler=stream_handler_req
+        )
+        await write_pbmsg(writer, req)
+
+        resp = p2pd_pb.Response()  # type: ignore
+        await read_pbmsg_safe(reader, resp)
+        writer.close()
+        raise_if_failed(resp)
+
+        # if success, add the handler to the dict
+        self.handlers[proto] = handler_cb
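
A minimal sketch of the request/response pattern that every ControlClient call above follows: open a control connection, write one varint-delimited Request, read one Response back, then check it with raise_if_failed. It assumes a p2pd daemon is already listening behind the default unix socket; the helper name manual_identify is made up for illustration.

    import asyncio
    from multiaddr import Multiaddr

    from hivemind.p2p.p2p_daemon_bindings.control import DaemonConnector
    from hivemind.p2p.p2p_daemon_bindings.utils import write_pbmsg, read_pbmsg_safe, raise_if_failed
    from hivemind.proto import p2pd_pb2 as p2pd_pb

    async def manual_identify():
        # the same round-trip that ControlClient.identify() performs internally
        connector = DaemonConnector(Multiaddr("/unix/tmp/p2pd.sock"))
        reader, writer = await connector.open_connection()
        await write_pbmsg(writer, p2pd_pb.Request(type=p2pd_pb.Request.IDENTIFY))
        resp = p2pd_pb.Response()
        await read_pbmsg_safe(reader, resp)
        writer.close()
        raise_if_failed(resp)
        return resp.identify.id, list(resp.identify.addrs)

    # asyncio.run(manual_identify())  # requires a running p2pd daemon behind /tmp/p2pd.sock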

+ 170 - 0
hivemind/p2p/p2p_daemon_bindings/datastructures.py

@@ -0,0 +1,170 @@
+"""
+Originally taken from: https://github.com/mhchia/py-libp2p-daemon-bindings
+Licence: MIT
+Author: Kevin Mai-Husan Chia
+"""
+
+import hashlib
+from typing import Any, Sequence, Union
+
+import base58
+import multihash
+from multiaddr import Multiaddr, protocols
+
+from hivemind.proto import p2pd_pb2
+
+# NOTE: On inlining...
+# See: https://github.com/libp2p/specs/issues/138
+# NOTE: inlining is enabled here to stay interoperable with the Go implementation
+ENABLE_INLINING = True
+MAX_INLINE_KEY_LENGTH = 42
+
+IDENTITY_MULTIHASH_CODE = 0x00
+
+if ENABLE_INLINING:
+
+    class IdentityHash:
+        def __init__(self) -> None:
+            self._digest = bytearray()
+
+        def update(self, input: bytes) -> None:
+            self._digest += input
+
+        def digest(self) -> bytes:
+            return self._digest
+
+    multihash.FuncReg.register(
+        IDENTITY_MULTIHASH_CODE, "identity", hash_new=IdentityHash
+    )
+
+
+class PeerID:
+    def __init__(self, peer_id_bytes: bytes) -> None:
+        self._bytes = peer_id_bytes
+        self._xor_id = int(sha256_digest(self._bytes).hex(), 16)
+        self._b58_str = base58.b58encode(self._bytes).decode()
+
+    @property
+    def xor_id(self) -> int:
+        return self._xor_id
+
+    def to_bytes(self) -> bytes:
+        return self._bytes
+
+    def to_base58(self) -> str:
+        return self._b58_str
+
+    def __repr__(self) -> str:
+        return f"<libp2p.peer.id.ID ({self.to_base58()})>"
+
+    def __str__(self):
+        return self.to_base58()
+
+    def pretty(self):
+        return self.to_base58()
+
+    def to_string(self):
+        return self.to_base58()
+
+    def __eq__(self, other: object) -> bool:
+        if isinstance(other, str):
+            return self.to_base58() == other
+        elif isinstance(other, bytes):
+            return self._bytes == other
+        elif isinstance(other, PeerID):
+            return self._bytes == other._bytes
+        else:
+            return False
+
+    def __hash__(self) -> int:
+        return hash(self._bytes)
+
+    @classmethod
+    def from_base58(cls, base58_id: str) -> "PeerID":
+        peer_id_bytes = base58.b58decode(base58_id)
+        return cls(peer_id_bytes)
+
+
+def sha256_digest(data: Union[str, bytes]) -> bytes:
+    if isinstance(data, str):
+        data = data.encode("utf8")
+    return hashlib.sha256(data).digest()
+
+
+class StreamInfo:
+    def __init__(self, peer_id: PeerID, addr: Multiaddr, proto: str) -> None:
+        self.peer_id = peer_id
+        self.addr = addr
+        self.proto = proto
+
+    def __repr__(self) -> str:
+        return (
+            f"<StreamInfo peer_id={self.peer_id} addr={self.addr} proto={self.proto}>"
+        )
+
+    def to_protobuf(self) -> p2pd_pb2.StreamInfo:
+        pb_msg = p2pd_pb2.StreamInfo(
+            peer=self.peer_id.to_bytes(), addr=self.addr.to_bytes(), proto=self.proto
+        )
+        return pb_msg
+
+    @classmethod
+    def from_protobuf(cls, pb_msg: p2pd_pb2.StreamInfo) -> "StreamInfo":
+        stream_info = cls(
+            peer_id=PeerID(pb_msg.peer), addr=Multiaddr(pb_msg.addr), proto=pb_msg.proto
+        )
+        return stream_info
+
+
+class PeerInfo:
+    def __init__(self, peer_id: PeerID, addrs: Sequence[Multiaddr]) -> None:
+        self.peer_id = peer_id
+        self.addrs = list(addrs)
+
+    def __eq__(self, other: Any) -> bool:
+        return (
+            isinstance(other, PeerInfo)
+            and self.peer_id == other.peer_id
+            and self.addrs == other.addrs
+        )
+
+    @classmethod
+    def from_protobuf(cls, peer_info_pb: p2pd_pb2.PeerInfo) -> "PeerInfo":
+        peer_id = PeerID(peer_info_pb.id)
+        addrs = [Multiaddr(addr) for addr in peer_info_pb.addrs]
+        return PeerInfo(peer_id, addrs)
+
+    def __str__(self):
+        return f"{self.peer_id.pretty()} {','.join(str(a) for a in self.addrs)}"
+
+
+class InvalidAddrError(ValueError):
+    pass
+
+
+def info_from_p2p_addr(addr: Multiaddr) -> PeerInfo:
+    if addr is None:
+        raise InvalidAddrError("`addr` should not be `None`")
+
+    parts = addr.split()
+    if not parts:
+        raise InvalidAddrError(
+            f"`parts`={parts} should at least have a protocol `P_P2P`"
+        )
+
+    p2p_part = parts[-1]
+    last_protocol_code = p2p_part.protocols()[0].code
+    if last_protocol_code != protocols.P_P2P:
+        raise InvalidAddrError(
+            f"The last protocol should be `P_P2P` instead of `{last_protocol_code}`"
+        )
+
+    # make sure the /p2p value parses as a peer.ID
+    peer_id_str: str = p2p_part.value_for_protocol(protocols.P_P2P)
+    peer_id = PeerID.from_base58(peer_id_str)
+
+    # we might have received just a /p2p part, which means there's no addr.
+    if len(parts) > 1:
+        addr = Multiaddr.join(*parts[:-1])
+
+    return PeerInfo(peer_id, [addr])
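
A short sketch of how these datastructures fit together: PeerID round-trips between raw multihash bytes and the base58 form used in multiaddrs, and info_from_p2p_addr splits a full /p2p/ multiaddr into a PeerInfo. The peer ID below is synthetic (a sha2-256 multihash over zero bytes), used purely for illustration.

    from multiaddr import Multiaddr

    from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, info_from_p2p_addr

    # 0x12 = sha2-256, 0x20 = 32-byte digest length, followed by a dummy all-zero digest
    peer_id = PeerID(bytes([0x12, 0x20]) + bytes(32))
    assert PeerID.from_base58(peer_id.to_base58()) == peer_id

    # a full p2p multiaddr is a transport address plus a trailing /p2p/<base58 peer id> part
    maddr = Multiaddr(f"/ip4/127.0.0.1/tcp/4001/p2p/{peer_id.to_base58()}")
    info = info_from_p2p_addr(maddr)
    print(info)  # the peer id followed by the /ip4/127.0.0.1/tcp/4001 transport part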

+ 85 - 0
hivemind/p2p/p2p_daemon_bindings/p2pclient.py

@@ -0,0 +1,85 @@
+"""
+Originally taken from: https://github.com/mhchia/py-libp2p-daemon-bindings
+Licence: MIT
+Author: Kevin Mai-Husan Chia
+"""
+
+import asyncio
+from contextlib import asynccontextmanager
+from typing import AsyncIterator, Iterable, Sequence, Tuple
+
+from multiaddr import Multiaddr
+
+from hivemind.p2p.p2p_daemon_bindings.control import (ControlClient,
+                                                      DaemonConnector,
+                                                      StreamHandler)
+from hivemind.p2p.p2p_daemon_bindings.datastructures import (PeerID, PeerInfo,
+                                                             StreamInfo)
+
+
+class Client:
+    control: ControlClient
+
+    def __init__(
+        self, control_maddr: Multiaddr = None, listen_maddr: Multiaddr = None
+    ) -> None:
+        daemon_connector = DaemonConnector(control_maddr=control_maddr)
+        self.control = ControlClient(
+            daemon_connector=daemon_connector, listen_maddr=listen_maddr
+        )
+
+    @asynccontextmanager
+    async def listen(self) -> AsyncIterator["Client"]:
+        """
+        Start listening for incoming streams directed to handlers registered via stream_handler.
+        :return:
+        """
+        async with self.control.listen():
+            yield self
+
+    async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
+        """
+        Get the current node's peer ID and list of addresses.
+        """
+        return await self.control.identify()
+
+    async def connect(self, peer_id: PeerID, maddrs: Iterable[Multiaddr]) -> None:
+        """
+        Connect to a p2p node with the specified peer ID and addresses.
+        :param peer_id: peer ID of the node to connect to
+        :param maddrs: multiaddresses of the node to connect to; at least one of them must be reachable
+        """
+        await self.control.connect(peer_id=peer_id, maddrs=maddrs)
+
+    async def list_peers(self) -> Tuple[PeerInfo, ...]:
+        """
+        Get the list of peers that the node is connected to.
+        """
+        return await self.control.list_peers()
+
+    async def disconnect(self, peer_id: PeerID) -> None:
+        """
+        Disconnect from the node with the specified peer ID.
+        :param peer_id: peer ID of the node to disconnect from
+        """
+        await self.control.disconnect(peer_id=peer_id)
+
+    async def stream_open(
+        self, peer_id: PeerID, protocols: Sequence[str]
+    ) -> Tuple[StreamInfo, asyncio.StreamReader, asyncio.StreamWriter]:
+        """
+        Open a stream to call the handler of another peer (identified by peer_id) for one of the specified protocols.
+        :param peer_id: peer ID of the other peer
+        :param protocols: protocols supported by the other peer's handler
+        :return: tuple of stream info (describing the connection to the other peer) and a reader/writer pair
+        """
+        return await self.control.stream_open(peer_id=peer_id, protocols=protocols)
+
+    async def stream_handler(self, proto: str, handler_cb: StreamHandler) -> None:
+        """
+        Register a stream handler.
+        :param proto: protocol that the handler serves
+        :param handler_cb: handler callback
+        :return:
+        """
+        await self.control.stream_handler(proto=proto, handler_cb=handler_cb)
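
A usage sketch for the Client wrapper above, assuming two separate p2pd daemons are already running behind the control sockets named below (the socket paths and the /hivemind/echo/1.0.0 protocol id are hypothetical):

    import asyncio
    from multiaddr import Multiaddr

    from hivemind.p2p.p2p_daemon_bindings.p2pclient import Client

    async def echo_demo():
        alice = Client(control_maddr=Multiaddr("/unix/tmp/p2pd-alice.sock"),
                       listen_maddr=Multiaddr("/unix/tmp/p2pclient-alice.sock"))
        bob = Client(control_maddr=Multiaddr("/unix/tmp/p2pd-bob.sock"),
                     listen_maddr=Multiaddr("/unix/tmp/p2pclient-bob.sock"))

        async def echo_handler(stream_info, reader, writer):
            writer.write(await reader.read(100))  # send back whatever arrives first
            await writer.drain()
            writer.close()

        async with alice.listen(), bob.listen():
            await bob.stream_handler("/hivemind/echo/1.0.0", echo_handler)
            bob_id, bob_maddrs = await bob.identify()
            await alice.connect(bob_id, bob_maddrs)

            _, reader, writer = await alice.stream_open(bob_id, ("/hivemind/echo/1.0.0",))
            writer.write(b"ping")
            await writer.drain()
            assert await reader.read(100) == b"ping"
            writer.close()

    # asyncio.run(echo_demo())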

+ 73 - 0
hivemind/p2p/p2p_daemon_bindings/utils.py

@@ -0,0 +1,73 @@
+"""
+Originally taken from: https://github.com/mhchia/py-libp2p-daemon-bindings
+Licence: MIT
+Author: Kevin Mai-Husan Chia
+"""
+
+import asyncio
+
+from google.protobuf.message import Message as PBMessage
+
+from hivemind.proto import p2pd_pb2 as p2pd_pb
+
+DEFAULT_MAX_BITS: int = 64
+
+
+class ControlFailure(Exception):
+    pass
+
+
+class DispatchFailure(Exception):
+    pass
+
+
+async def write_unsigned_varint(stream: asyncio.StreamWriter, integer: int, max_bits: int = DEFAULT_MAX_BITS) -> None:
+    max_int = 1 << max_bits
+    if integer < 0:
+        raise ValueError(f"negative integer: {integer}")
+    if integer >= max_int:
+        raise ValueError(f"integer too large: {integer}")
+    while True:
+        value = integer & 0x7F
+        integer >>= 7
+        if integer != 0:
+            value |= 0x80
+        byte = value.to_bytes(1, "big")
+        stream.write(byte)
+        if integer == 0:
+            break
+
+
+async def read_unsigned_varint(stream: asyncio.StreamReader, max_bits: int = DEFAULT_MAX_BITS) -> int:
+    max_int = 1 << max_bits
+    iteration = 0
+    result = 0
+    has_next = True
+    while has_next:
+        data = await stream.readexactly(1)
+        c = data[0]
+        value = c & 0x7F
+        result |= value << (iteration * 7)
+        has_next = (c & 0x80) != 0
+        iteration += 1
+        if result >= max_int:
+            raise ValueError(f"Varint overflowed: {result}")
+    return result
+
+
+def raise_if_failed(response: p2pd_pb.Response) -> None:
+    if response.type == p2pd_pb.Response.ERROR:
+        raise ControlFailure(f"Request failed. msg={response.error.msg}")
+
+
+async def write_pbmsg(stream: asyncio.StreamWriter, pbmsg: PBMessage) -> None:
+    size = pbmsg.ByteSize()
+    await write_unsigned_varint(stream, size)
+    msg_bytes: bytes = pbmsg.SerializeToString()
+    stream.write(msg_bytes)
+
+
+async def read_pbmsg_safe(stream: asyncio.StreamReader, pbmsg: PBMessage) -> None:
+    len_msg_bytes = await read_unsigned_varint(stream)
+    msg_bytes = await stream.readexactly(len_msg_bytes)
+    pbmsg.ParseFromString(msg_bytes)
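
The framing here is a standard unsigned varint length prefix followed by the serialized protobuf. As a worked example, 300 = 0b100101100 encodes to the two bytes 0xAC 0x02 (low 7 bits first, continuation bit set on every byte except the last), and read_pbmsg_safe simply reads a varint length and then exactly that many payload bytes:

    import asyncio

    from hivemind.p2p.p2p_daemon_bindings.utils import read_unsigned_varint, read_pbmsg_safe
    from hivemind.proto import p2pd_pb2 as p2pd_pb

    async def demo():
        # decode the classic varint example: 0xAC 0x02 -> 300
        reader = asyncio.StreamReader()
        reader.feed_data(bytes([0xAC, 0x02]))
        assert await read_unsigned_varint(reader) == 300

        # a framed protobuf message: varint length, then the serialized bytes
        payload = p2pd_pb.Request(type=p2pd_pb.Request.LIST_PEERS).SerializeToString()
        reader = asyncio.StreamReader()
        reader.feed_data(bytes([len(payload)]) + payload)  # short messages fit in a one-byte varint
        parsed = p2pd_pb.Request()
        await read_pbmsg_safe(reader, parsed)
        assert parsed.type == p2pd_pb.Request.LIST_PEERS

    asyncio.run(demo())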

+ 1 - 1
hivemind/proto/averaging.proto

@@ -43,7 +43,7 @@ message MessageFromLeader {
   bytes group_id = 2;        // a unique identifier of this group, only valid until allreduce is finished/failed
   string suggested_leader = 3;  // if peer is already in a group, it'll provide us with an endpoint of its leader
   repeated string ordered_group_endpoints = 4;  // a sequence of peers, each responsible for one shard during averaging
-  repeated bytes gathered = 5;  // metadata (gather) from all groupmates in the same order as their endoints
+  repeated bytes gathered = 5;  // metadata (gather) from all groupmates in the same order as their endpoints
 }
 
 message AveragingData {

+ 166 - 0
hivemind/proto/p2pd.proto

@@ -0,0 +1,166 @@
+//Originally taken from: https://github.com/mhchia/py-libp2p-daemon-bindings
+//Licence: MIT
+//Author: Kevin Mai-Husan Chia
+
+syntax = "proto2";
+
+package p2pclient.p2pd.pb;
+
+message Request {
+  enum Type {
+    IDENTIFY       = 0;
+    CONNECT        = 1;
+    STREAM_OPEN    = 2;
+    STREAM_HANDLER = 3;
+    DHT            = 4;
+    LIST_PEERS     = 5;
+    CONNMANAGER    = 6;
+    DISCONNECT     = 7;
+    PUBSUB         = 8;
+  }
+
+  required Type type = 1;
+
+  optional ConnectRequest connect = 2;
+  optional StreamOpenRequest streamOpen = 3;
+  optional StreamHandlerRequest streamHandler = 4;
+  optional DHTRequest dht = 5;
+  optional ConnManagerRequest connManager = 6;
+  optional DisconnectRequest disconnect = 7;
+  optional PSRequest pubsub = 8;
+}
+
+message Response {
+  enum Type {
+    OK    = 0;
+    ERROR = 1;
+  }
+
+  required Type type = 1;
+  optional ErrorResponse error = 2;
+  optional StreamInfo streamInfo = 3;
+  optional IdentifyResponse identify = 4;
+  optional DHTResponse dht = 5;
+  repeated PeerInfo peers = 6;
+  optional PSResponse pubsub = 7;
+}
+
+message IdentifyResponse {
+  required bytes id = 1;
+  repeated bytes addrs = 2;
+}
+
+message ConnectRequest {
+  required bytes peer = 1;
+  repeated bytes addrs = 2;
+  optional int64 timeout = 3;
+}
+
+message StreamOpenRequest {
+  required bytes peer = 1;
+  repeated string proto = 2;
+  optional int64 timeout = 3;
+}
+
+message StreamHandlerRequest {
+  required bytes addr = 1;
+  repeated string proto = 2;
+}
+
+message ErrorResponse {
+  required string msg = 1;
+}
+
+message StreamInfo {
+  required bytes peer = 1;
+  required bytes addr = 2;
+  required string proto = 3;
+}
+
+message DHTRequest {
+  enum Type {
+    FIND_PEER                    = 0;
+    FIND_PEERS_CONNECTED_TO_PEER = 1;
+    FIND_PROVIDERS               = 2;
+    GET_CLOSEST_PEERS            = 3;
+    GET_PUBLIC_KEY               = 4;
+    GET_VALUE                    = 5;
+    SEARCH_VALUE                 = 6;
+    PUT_VALUE                    = 7;
+    PROVIDE                      = 8;
+  }
+
+  required Type type = 1;
+  optional bytes peer = 2;
+  optional bytes cid = 3;
+  optional bytes key = 4;
+  optional bytes value = 5;
+  optional int32 count = 6;
+  optional int64 timeout = 7;
+}
+
+message DHTResponse {
+  enum Type {
+    BEGIN = 0;
+    VALUE = 1;
+    END   = 2;
+  }
+
+  required Type type = 1;
+  optional PeerInfo peer = 2;
+  optional bytes value = 3;
+}
+
+message PeerInfo {
+  required bytes id = 1;
+  repeated bytes addrs = 2;
+}
+
+message ConnManagerRequest {
+  enum Type {
+    TAG_PEER        = 0;
+    UNTAG_PEER      = 1;
+    TRIM            = 2;
+  }
+
+  required Type type = 1;
+
+  optional bytes peer = 2;
+  optional string tag = 3;
+  optional int64 weight = 4;
+}
+
+message DisconnectRequest {
+  required bytes peer = 1;
+}
+
+message PSRequest {
+  enum Type {
+    GET_TOPICS = 0;
+    LIST_PEERS = 1;
+    PUBLISH    = 2;
+    SUBSCRIBE  = 3;
+  }
+
+  required Type type = 1;
+  optional string topic = 2;
+  optional bytes data = 3;
+}
+
+message PSMessage {
+  optional bytes from_id = 1;
+  optional bytes data = 2;
+  optional bytes seqno = 3;
+  repeated string topicIDs = 4;
+  optional bytes signature = 5;
+  optional bytes key = 6;
+}
+
+message PSResponse {
+  repeated string topics = 1;
+  repeated bytes peerIDs = 2;
+}
+
+message RPCError {
+  required string message = 1;
+}
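
On the Python side these messages are available as hivemind.proto.p2pd_pb2 (generated during the build). For example, a DHT query and an error response can be constructed as follows:

    from hivemind.proto import p2pd_pb2

    # ask the daemon for the peers closest to a key
    request = p2pd_pb2.Request(
        type=p2pd_pb2.Request.DHT,
        dht=p2pd_pb2.DHTRequest(type=p2pd_pb2.DHTRequest.GET_CLOSEST_PEERS, key=b"some key"),
    )

    # responses are either OK with a payload, or ERROR with a message
    response = p2pd_pb2.Response(type=p2pd_pb2.Response.ERROR,
                                 error=p2pd_pb2.ErrorResponse(msg="example failure"))
    assert response.error.msg == "example failure"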

+ 1 - 1
hivemind/server/runtime.py

@@ -118,7 +118,7 @@ class Runtime(threading.Thread):
         with DefaultSelector() as selector:
             for pool in self.pools:
                 selector.register(pool.batch_receiver, EVENT_READ, pool)
-            # selector.register(self.shutdown_recv, EVENT_READ, self.SHUTDOWN_TRIGGER)
+            selector.register(self.shutdown_recv, EVENT_READ, self.SHUTDOWN_TRIGGER)
 
             while True:
                 # wait until at least one batch_receiver becomes available

+ 49 - 1
hivemind/utils/asyncio.py

@@ -1,7 +1,14 @@
-from typing import TypeVar, AsyncIterator, Union, AsyncIterable, Awaitable
+from concurrent.futures import ThreadPoolExecutor
+from typing import TypeVar, AsyncIterator, Union, AsyncIterable, Awaitable, Tuple, Optional, Callable
 import asyncio
+
 import uvloop
+
+from hivemind.utils.logging import get_logger
+
+
 T = TypeVar('T')
+logger = get_logger(__name__)
 
 
 def switch_to_uvloop() -> asyncio.AbstractEventLoop:
@@ -27,6 +34,16 @@ async def aiter(*args: T) -> AsyncIterator[T]:
         yield arg
 
 
+async def azip(*iterables: AsyncIterable[T]) -> AsyncIterator[Tuple[T, ...]]:
+    """ equivalent of zip for asynchronous iterables """
+    iterators = [iterable.__aiter__() for iterable in iterables]
+    while True:
+        try:
+            yield tuple(await asyncio.gather(*(itr.__anext__() for itr in iterators)))
+        except StopAsyncIteration:
+            break
+
+
 async def achain(*async_iters: AsyncIterable[T]) -> AsyncIterator[T]:
     """ equivalent to chain(iter1, iter2, ...) for asynchronous iterators. """
     for aiter in async_iters:
@@ -34,6 +51,14 @@ async def achain(*async_iters: AsyncIterable[T]) -> AsyncIterator[T]:
             yield elem
 
 
+async def aenumerate(aiterable: AsyncIterable[T]) -> AsyncIterable[Tuple[int, T]]:
+    """ equivalent to enumerate(iter) for asynchronous iterators. """
+    index = 0
+    async for elem in aiterable:
+        yield index, elem
+        index += 1
+
+
 async def await_cancelled(awaitable: Awaitable) -> bool:
     try:
         await awaitable
@@ -42,3 +67,26 @@ async def await_cancelled(awaitable: Awaitable) -> bool:
         return True
     except BaseException:
         return False
+
+
+async def amap_in_executor(func: Callable[..., T], *iterables: AsyncIterable, max_prefetch: Optional[int] = None,
+                           executor: Optional[ThreadPoolExecutor] = None) -> AsyncIterator[T]:
+    """ map func over one or several async iterables in a background executor, yielding results in order """
+    loop = asyncio.get_event_loop()
+    queue = asyncio.Queue(max_prefetch)
+
+    async def _put_items():
+        async for args in azip(*iterables):
+            await queue.put(loop.run_in_executor(executor, func, *args))
+        await queue.put(None)
+
+    task = asyncio.create_task(_put_items())
+    try:
+        future = await queue.get()
+        while future is not None:
+            yield await future
+            future = await queue.get()
+        await task
+    finally:
+        if not task.done():
+            task.cancel()
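
A quick illustration of the new helpers (all defined above, together with aiter from the same module): azip and aenumerate mirror their synchronous counterparts, while amap_in_executor applies a function in a background thread pool and prefetches up to max_prefetch results.

    import asyncio

    from hivemind.utils.asyncio import aiter, azip, aenumerate, amap_in_executor

    async def demo():
        async for i, (x, y) in aenumerate(azip(aiter(1, 2, 3), aiter(10, 20, 30))):
            print(i, x + y)   # 0 11, then 1 22, then 2 33

        async for square in amap_in_executor(lambda x: x ** 2, aiter(1, 2, 3), max_prefetch=2):
            print(square)     # 1, 4, 9

    asyncio.run(demo())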

+ 14 - 1
hivemind/utils/compression.py

@@ -8,7 +8,7 @@ from hivemind.proto import runtime_pb2
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.threading import run_in_background
 
-FP16_MAX = 65_504
+FP32_EPS = 1e-06
 NUM_BYTES_FLOAT32 = 4
 NUM_BYTES_FLOAT16 = 2
 NUM_BITS_QUANTILE_COMPRESSION = 8
@@ -86,6 +86,7 @@ def serialize_torch_tensor(tensor: torch.Tensor, compression_type=CompressionTyp
         tensor.sub_(means)
 
         stds = torch.square(tensor).sum(dim=-1, keepdim=True).div_(tensor.shape[-1]).sqrt_()
+        stds.clamp_min_(FP32_EPS)
         tensor.div_(stds)
         tensor = tensor.clamp_(-FP16_MAX, FP16_MAX).to(torch.float16)
 
@@ -187,3 +188,15 @@ def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Ten
 
     tensor.requires_grad_(serialized_tensor.requires_grad)
     return tensor
+
+
+def get_nbytes_per_value(dtype: torch.dtype, compression: CompressionType) -> int:
+    """ returns the number of bytes per value for a given tensor (excluding metadata) """
+    if compression in (CompressionType.QUANTILE_8BIT, CompressionType.UNIFORM_8BIT):
+        return 1
+    elif compression in (CompressionType.FLOAT16, CompressionType.MEANSTD_16BIT):
+        return 2
+    elif compression == CompressionType.NONE:
+        return torch.finfo(dtype).bits // 8
+    else:
+        raise NotImplementedError(f"Unknown compression type: {CompressionType.Name(compression)}")
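
For example, get_nbytes_per_value can be used to estimate the payload size of a tensor under different compression schemes (metadata excluded):

    import torch

    from hivemind.proto.runtime_pb2 import CompressionType
    from hivemind.utils.compression import get_nbytes_per_value

    tensor = torch.randn(1024, 1024)
    for compression in (CompressionType.NONE, CompressionType.FLOAT16, CompressionType.UNIFORM_8BIT):
        nbytes = tensor.numel() * get_nbytes_per_value(tensor.dtype, compression)
        print(CompressionType.Name(compression), nbytes)  # 4 MiB, 2 MiB and 1 MiB of payload, respectively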

+ 5 - 1
hivemind/utils/grpc.py

@@ -158,7 +158,11 @@ class ChannelCache(TimedStorage[ChannelInfo, Tuple[Union[grpc.Channel, grpc.aio.
         raise ValueError(f"Please use {self.__class__.__name__}.get_stub to get or create stubs")
 
 
-def split_for_streaming(serialized_tensor: runtime_pb2.Tensor, chunk_size_bytes: int) -> Iterator[runtime_pb2.Tensor]:
+STREAMING_CHUNK_SIZE_BYTES = 2 ** 16
+
+
+def split_for_streaming(serialized_tensor: runtime_pb2.Tensor, chunk_size_bytes: int = STREAMING_CHUNK_SIZE_BYTES,
+                        ) -> Iterator[runtime_pb2.Tensor]:
     """ Split serialized_tensor into multiple chunks for gRPC streaming """
     buffer = memoryview(serialized_tensor.buffer)
     num_chunks = len(range(0, len(buffer), chunk_size_bytes))

+ 1 - 1
hivemind/utils/threading.py

@@ -12,7 +12,7 @@ def run_in_background(func: callable, *args, **kwargs) -> Future:
     """ run func(*args, **kwargs) in background and return Future for its outputs """
     global EXECUTOR_PID, GLOBAL_EXECUTOR
     if os.getpid() != EXECUTOR_PID:
-        GLOBAL_EXECUTOR = ThreadPoolExecutor(max_workers=float(os.environ.get("HIVEMIND_THREADS", 'inf')))
+        GLOBAL_EXECUTOR = ThreadPoolExecutor(max_workers=int(os.environ.get("HIVEMIND_THREADS", 128)))
         EXECUTOR_PID = os.getpid()
     return GLOBAL_EXECUTOR.submit(func, *args, **kwargs)
 

+ 1 - 0
requirements-dev.txt

@@ -1,6 +1,7 @@
 pytest
 pytest-forked
 pytest-asyncio
+pytest-cov
 codecov
 tqdm
 scikit-learn

+ 2 - 0
requirements.txt

@@ -10,5 +10,7 @@ grpcio>=1.33.2
 grpcio-tools>=1.33.2
 protobuf>=3.12.2
 configargparse>=1.2.3
+multiaddr>=0.0.9
+pymultihash>=0.8.2
 cryptography>=3.4.6
 pydantic>=1.8.1

+ 80 - 12
setup.py

@@ -1,12 +1,32 @@
 import codecs
 import glob
+import hashlib
 import os
 import re
-
-from pkg_resources import parse_requirements
-from setuptools import setup, find_packages
+import shlex
+import subprocess
+import tarfile
+import tempfile
+import urllib.request
+
+from pkg_resources import parse_requirements, parse_version
+from setuptools import find_packages, setup
+from setuptools.command.build_py import build_py
 from setuptools.command.develop import develop
-from setuptools.command.install import install
+
+P2PD_VERSION = 'v0.3.1'
+P2PD_CHECKSUM = '15292b880c6b31f5b3c36084b3acc17f'
+LIBP2P_TAR_URL = f'https://github.com/learning-at-home/go-libp2p-daemon/archive/refs/tags/{P2PD_VERSION}.tar.gz'
+
+here = os.path.abspath(os.path.dirname(__file__))
+
+
+def md5(fname, chunk_size=4096):
+    hash_md5 = hashlib.md5()
+    with open(fname, "rb") as f:
+        for chunk in iter(lambda: f.read(chunk_size), b""):
+            hash_md5.update(chunk)
+    return hash_md5.hexdigest()
 
 
 def proto_compile(output_path):
@@ -28,20 +48,68 @@ def proto_compile(output_path):
             file.truncate()
 
 
-class ProtoCompileInstall(install):
+def build_p2p_daemon():
+    result = subprocess.run("go version", capture_output=True, shell=True).stdout.decode('ascii', 'replace')
+    m = re.search(r'^go version go([\d.]+)', result)
+
+    if m is None:
+        raise FileNotFoundError('Could not find golang installation')
+    version = parse_version(m.group(1))
+    if version < parse_version("1.13"):
+        raise EnvironmentError(f'Newer version of go required: must be >= 1.13, found {version}')
+
+    with tempfile.TemporaryDirectory() as tempdir:
+        dest = os.path.join(tempdir, 'libp2p-daemon.tar.gz')
+        urllib.request.urlretrieve(LIBP2P_TAR_URL, dest)
+
+        with tarfile.open(dest, 'r:gz') as tar:
+            tar.extractall(tempdir)
+
+        result = subprocess.run(f'go build -o {shlex.quote(os.path.join(here, "hivemind", "hivemind_cli", "p2pd"))}',
+                                cwd=os.path.join(tempdir, f'go-libp2p-daemon-{P2PD_VERSION[1:]}', 'p2pd'), shell=True)
+
+        if result.returncode:
+            raise RuntimeError('Failed to build or install libp2p-daemon:'
+                               f' exited with status code: {result.returncode}')
+
+
+def download_p2p_daemon():
+    install_path = os.path.join(here, 'hivemind', 'hivemind_cli')
+    binary_path = os.path.join(install_path, 'p2pd')
+    if not os.path.exists(binary_path) or md5(binary_path) != P2PD_CHECKSUM:
+        print('Downloading Peer to Peer Daemon')
+        url = f'https://github.com/learning-at-home/go-libp2p-daemon/releases/download/{P2PD_VERSION}/p2pd'
+        urllib.request.urlretrieve(url, binary_path)
+        os.chmod(binary_path, 0o777)
+        if md5(binary_path) != P2PD_CHECKSUM:
+            raise RuntimeError(f'Downloaded p2pd binary from {url} does not match the expected md5 checksum')
+
+
+class BuildPy(build_py):
+    user_options = build_py.user_options + [('buildgo', None, "Builds p2pd from source")]
+
+    def initialize_options(self):
+        super().initialize_options()
+        self.buildgo = False
+
     def run(self):
-        proto_compile(os.path.join(self.build_lib, 'hivemind', 'proto'))
+        if self.buildgo:
+            build_p2p_daemon()
+        else:
+            download_p2p_daemon()
+
         super().run()
 
+        proto_compile(os.path.join(self.build_lib, 'hivemind', 'proto'))
+
 
-class ProtoCompileDevelop(develop):
+class Develop(develop):
     def run(self):
-        proto_compile(os.path.join('hivemind', 'proto'))
+        self.reinitialize_command('build_py', build_lib=here)
+        self.run_command('build_py')
         super().run()
 
 
-here = os.path.abspath(os.path.dirname(__file__))
-
 with open('requirements.txt') as requirements_file:
     install_requires = list(map(str, parse_requirements(requirements_file)))
 
@@ -63,7 +131,7 @@ extras['all'] = extras['dev'] + extras['docs']
 setup(
     name='hivemind',
     version=version_string,
-    cmdclass={'install': ProtoCompileInstall, 'develop': ProtoCompileDevelop},
+    cmdclass={'build_py': BuildPy, 'develop': Develop},
     description='Decentralized deep learning in PyTorch',
     long_description='Decentralized deep learning in PyTorch. Built to train giant models on '
                      'thousands of volunteers across the world.',
@@ -71,7 +139,7 @@ setup(
     author_email='mryabinin0@gmail.com',
     url="https://github.com/learning-at-home/hivemind",
     packages=find_packages(exclude=['tests']),
-    package_data={'hivemind': ['proto/*']},
+    package_data={'hivemind': ['proto/*', 'hivemind_cli/*']},
     include_package_data=True,
     license='MIT',
     setup_requires=['grpcio-tools'],

+ 217 - 0
tests/test_allreduce.py

@@ -0,0 +1,217 @@
+import asyncio
+import random
+import time
+from typing import Sequence
+
+import pytest
+import torch
+import grpc
+
+from hivemind import aenumerate, Endpoint
+from hivemind.client.averaging.allreduce import AllReduceRunner, AveragingMode
+from hivemind.client.averaging.partition import TensorPartContainer, TensorPartReducer
+from hivemind.utils import deserialize_torch_tensor, ChannelCache
+from hivemind.proto.runtime_pb2 import CompressionType
+from hivemind.proto import averaging_pb2_grpc
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_partitioning():
+    all_tensors = [
+        torch.randn(30_000, 128), torch.rand(128), torch.ones(1, 1, 1, 1, 1, 1, 8),
+        torch.ones(1, 0), torch.zeros(0), torch.zeros([]), torch.randn(65536),
+        torch.rand(512, 2048), torch.randn(1024, 1024).add(-9), torch.zeros(1020), torch.randn(4096)
+    ]
+
+    # note: this test does _not_ use parameterization to reuse sampled tensors
+    for num_tensors in 1, 3, 5:
+        for part_size_bytes in 31337, 2 ** 20, 10 ** 10:
+            for weights in [(1, 1), (0.333, 0.1667, 0.5003), (1.0, 0.0), [0.0, 0.4, 0.6, 0.0]]:
+                tensors = random.choices(all_tensors, k=num_tensors)
+                partition = TensorPartContainer(tensors, weights, part_size_bytes=part_size_bytes)
+
+                async def write_tensors():
+                    for peer_index in range(partition.group_size):
+                        async for part_index, part in aenumerate(partition.iterate_input_parts_for(peer_index)):
+                            output_tensor = torch.sin(deserialize_torch_tensor(part))
+                            partition.register_processed_part(peer_index, part_index, output_tensor)
+
+                task = asyncio.create_task(write_tensors())
+                tensor_index = 0
+                async for output_tensor in partition.iterate_output_tensors():
+                    assert torch.allclose(output_tensor, torch.sin(tensors[tensor_index]))
+                    tensor_index += 1
+                assert tensor_index == len(tensors)
+                await task
+
+
+@pytest.mark.parametrize("tensors", [[torch.zeros(0)], [torch.zeros(0), torch.zeros(0), torch.zeros(1)],
+                                     [torch.zeros(0), torch.zeros(999), torch.zeros(0), torch.zeros(0)]])
+@pytest.mark.parametrize("peer_fractions", [(0.33, 0.44, 0.23), (0.5, 0.5), (0.1, 0.0, 0.9), (1.0,), (0.1,) * 9])
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_partitioning_edge_cases(tensors: Sequence[torch.Tensor], peer_fractions: Sequence[float]):
+    partition = TensorPartContainer(tensors, peer_fractions, part_size_bytes=16)
+    for peer_index in range(len(peer_fractions)):
+        async for part_index, part in aenumerate(partition.iterate_input_parts_for(peer_index)):
+            partition.register_processed_part(peer_index, part_index, deserialize_torch_tensor(part))
+
+    tensor_index = 0
+    async for output_tensor in partition.iterate_output_tensors():
+        assert torch.allclose(output_tensor, tensors[tensor_index])
+        tensor_index += 1
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_partitioning_asynchronous():
+    """ ensure that tensor partitioning does not interfere with asynchronous code """
+    tensors = [torch.randn(2048, 2048), torch.randn(1024, 4096),
+               torch.randn(4096, 1024), torch.randn(30_000, 1024)]
+    peer_fractions = [0.4, 0.3, 0.2, 0.1]
+
+    partition = TensorPartContainer(tensors, peer_fractions, compression_type=CompressionType.QUANTILE_8BIT)
+    read_started, read_finished = asyncio.Event(), asyncio.Event()
+
+    async def write_tensors():
+        for peer_index in range(partition.group_size):
+            async for part_index, part in aenumerate(partition.iterate_input_parts_for(peer_index)):
+                partition.register_processed_part(peer_index, part_index, deserialize_torch_tensor(part))
+        assert read_started.is_set(), "partitioner should have started reading before it finished writing"
+
+    async def read_tensors():
+        async for _ in partition.iterate_output_tensors():
+            read_started.set()
+        read_finished.set()
+
+    async def wait_synchronously():
+        time_in_waiting = 0.0
+        while not read_finished.is_set():
+            await asyncio.sleep(0.01)
+            time_in_waiting += 0.01
+        return time_in_waiting
+
+    start_time = time.perf_counter()
+    *_, time_in_waiting = await asyncio.gather(write_tensors(), read_tensors(), wait_synchronously())
+    wall_time = time.perf_counter() - start_time
+    # check that event loop had enough time to respond to incoming requests; this is over 50% most of the time
+    # we set 33% threshold to ensure that the test will pass reliably. If we break prefetch, this drops to <10%
+    assert time_in_waiting > wall_time / 3, f"Event loop could only run {time_in_waiting / wall_time :.5f} of the time"
+
+
+@pytest.mark.parametrize("num_senders", [1, 2, 4, 10])
+@pytest.mark.parametrize("num_parts", [0, 1, 100])
+@pytest.mark.parametrize("synchronize_prob", [1.0, 0.1, 0.0])
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_reducer(num_senders: int, num_parts: int, synchronize_prob: float):
+    tensor_part_shapes = [torch.Size([i]) for i in range(num_parts)]
+    reducer = TensorPartReducer(tensor_part_shapes, num_senders)
+
+    local_tensors_by_sender = [[torch.randn(i) for i in range(num_parts)]
+                               for j in range(num_senders)]
+
+    async def send_tensors(sender_index: int):
+        local_tensors = local_tensors_by_sender[sender_index]
+        averaged_parts = []
+        pending_tasks = []
+
+        for part_index in range(num_parts):
+            pending_tasks.append(asyncio.create_task(
+                reducer.accumulate_part(sender_index, part_index, local_tensors[part_index])))
+
+            if random.random() < synchronize_prob or part_index == num_parts - 1:
+                averaged_parts.extend(await asyncio.gather(*pending_tasks))
+                pending_tasks = []
+        return averaged_parts
+
+    averaged_tensors_by_peer = await asyncio.gather(*map(send_tensors, range(num_senders)))
+
+    reference = [sum(local_tensors_by_sender[sender_index][part_index]
+                     for sender_index in range(num_senders)) / num_senders
+                 for part_index in range(num_parts)]
+
+    for averaged_tensors in averaged_tensors_by_peer:
+        assert len(averaged_tensors) == len(reference)
+        for averaging_result, reference_tensor in zip(averaged_tensors, reference):
+            assert torch.allclose(averaging_result, reference_tensor, rtol=1e-3, atol=1e-5)
+
+
+class AllreduceRunnerForTesting(AllReduceRunner):
+    """ a version of AllReduceRunner that was monkey-patched to accept custom endpoint names """
+    def __init__(self, *args, peer_endpoints, **kwargs):
+        self.__peer_endpoints = peer_endpoints
+        super().__init__(*args, **kwargs)
+
+    def _get_peer_stub(self, peer: Endpoint) -> averaging_pb2_grpc.DecentralizedAveragingStub:
+        return ChannelCache.get_stub(
+            self.__peer_endpoints[peer], averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
+
+
+NODE, CLIENT, AUX = AveragingMode.NODE, AveragingMode.CLIENT, AveragingMode.AUX
+
+
+@pytest.mark.parametrize("peer_modes, averaging_weights, peer_fractions", [
+    ((NODE, NODE, NODE, NODE), (1, 1, 1, 1), (1, 1, 1, 1)),
+    ((NODE, NODE, NODE, NODE), (0.1, 0.2, 0.3, 0.4), (1, 1, 1, 1)),
+    ((NODE, NODE, NODE, NODE), (1, 1, 1, 1), (1, 2, 3, 0)),
+    ((NODE, NODE, NODE, CLIENT), (1, 1, 1, 1), (1, 2, 3, 0)),
+    ((NODE, NODE, NODE, AUX), (1, 1, 1, 0), (1, 2, 3, 4)),
+    ((NODE, NODE, NODE, NODE), (0.15, 0.0, 0.35, 0.45), (1, 1, 1, 1)),
+    ((NODE, AUX, NODE, CLIENT), (0.15, 0.0, 0.35, 0.45), (150, 200, 67, 0)),
+    ((AUX, AUX, AUX, AUX), (0.0, 0.0, 0.0, 0.0), (1, 2, 3, 4)),
+])
+@pytest.mark.parametrize("part_size_bytes", [2 ** 20, 256, 19],)
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions, part_size_bytes):
+    """ Run the group allreduce protocol manually and check that the internal logic works as intended """
+
+    peers = "alice", "bob", "carol", "colab"
+
+    tensors_by_peer = {peer: [torch.randn(3, 128), torch.rand(32), torch.tensor(i, dtype=torch.float32)]
+                       for i, peer in enumerate(peers)}
+
+    group_id = random.getrandbits(160).to_bytes(length=20, byteorder='big')
+
+    servers = []
+    allreduce_protocols = []
+    peer_endpoints = {}
+
+    for peer in peers:
+        server = grpc.aio.server()
+        allreduce_protocol = AllreduceRunnerForTesting(
+            group_id=group_id, endpoint=peer, tensors=[x.clone() for x in tensors_by_peer[peer]],
+            ordered_group_endpoints=peers, peer_fractions=peer_fractions, modes=peer_modes,
+            weights=averaging_weights, peer_endpoints=peer_endpoints, part_size_bytes=part_size_bytes
+        )
+        averaging_pb2_grpc.add_DecentralizedAveragingServicer_to_server(allreduce_protocol, server)
+        peer_endpoints[peer] = f"127.0.0.1:{server.add_insecure_port('127.0.0.1:*')}"
+        allreduce_protocols.append(allreduce_protocol)
+        servers.append(server)
+        await server.start()
+
+    async def _run_allreduce_inplace(allreduce: AllReduceRunner):
+        async for tensor_index, tensor_delta in aenumerate(allreduce):
+            allreduce.tensor_part_container.local_tensors[tensor_index].add_(tensor_delta)
+
+    await asyncio.gather(*map(_run_allreduce_inplace, allreduce_protocols))
+
+    reference_tensors = [sum(tensors_by_peer[peer][i] * averaging_weights[peer_index]
+                             for peer_index, peer in enumerate(peers)) / sum(averaging_weights)
+                         for i in range(len(tensors_by_peer[peers[0]]))]
+
+    for peer_index, protocol in enumerate(allreduce_protocols):
+        assert protocol._future.done()
+        if protocol.modes[peer_index] != AveragingMode.AUX:
+            targets_for_peer = reference_tensors
+        else:
+            targets_for_peer = tensors_by_peer[peers[peer_index]]
+        output_tensors = protocol.tensor_part_container.local_tensors
+        assert len(output_tensors) == len(targets_for_peer)
+        assert all(torch.allclose(our, ref, atol=1e-6, rtol=0)
+                   for our, ref in zip(output_tensors, targets_for_peer))
+
+    for server in servers:
+        await server.stop(grace=1)

+ 2 - 2
tests/test_auth.py

@@ -1,12 +1,12 @@
 from datetime import datetime, timedelta
-from typing import Optional, Tuple
+from typing import Optional
 
 import pytest
 
 from hivemind.proto import dht_pb2
 from hivemind.proto.auth_pb2 import AccessToken
 from hivemind.utils.auth import AuthRPCWrapper, AuthRole, TokenAuthorizerBase
-from hivemind.utils.crypto import RSAPrivateKey, RSAPublicKey
+from hivemind.utils.crypto import RSAPrivateKey
 from hivemind.utils.logging import get_logger
 
 

+ 76 - 77
tests/test_averaging.py

@@ -1,4 +1,3 @@
-import asyncio
 import random
 
 import numpy as np
@@ -6,10 +5,10 @@ import torch
 import pytest
 import time
 import hivemind
-from hivemind.client.averaging.allreduce import AllReduceProtocol, split_into_parts, restore_from_parts
+from hivemind.client.averaging.allreduce import AveragingMode
 from hivemind.client.averaging.load_balancing import load_balance_peers
 from hivemind.client.averaging.key_manager import GroupKeyManager
-from hivemind.utils import Endpoint
+from hivemind.proto.runtime_pb2 import CompressionType
 
 
 @pytest.mark.forked
@@ -42,26 +41,26 @@ async def test_key_manager():
     assert len(q5) == 0
 
 
-@pytest.mark.forked
-@pytest.mark.parametrize("n_client_mode_peers", [0, 2])
-def test_allreduce_once(n_client_mode_peers):
+def _test_allreduce_once(n_clients, n_aux):
     dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')
 
     n_peers = 4
-    should_listen = [False] * n_client_mode_peers + [True] * (n_peers - n_client_mode_peers)
-    random.shuffle(should_listen)
+    modes = [AveragingMode.CLIENT] * n_clients + [AveragingMode.AUX] * n_aux + [AveragingMode.NODE] * (n_peers - n_clients - n_aux)
+    random.shuffle(modes)
 
     tensors1 = [torch.randn(123), torch.zeros(3)]
     tensors2 = [torch.rand(123), torch.ones(3)]
     tensors3 = [-torch.rand(123), torch.arange(3).to(torch.float32)]
     tensors4 = [torch.randn(123) ** 3, torch.arange(3).to(torch.float32) / 2]
+    peer_tensors = [tensors1, tensors2, tensors3, tensors4]
 
-    reference = [(tensors1[i] + tensors2[i] + tensors3[i] + tensors4[i]) / 4 for i in range(len(tensors1))]
+    reference = [sum(tensors[i] for tensors, mode in zip(peer_tensors, modes)
+                 if mode != AveragingMode.AUX) / max(1, n_peers - n_aux) for i in range(len(tensors1))]
 
     averagers = [hivemind.DecentralizedAverager(tensors, dht=dht, target_group_size=4, averaging_expiration=15,
-                                                prefix='mygroup', listen=listen, listen_on='127.0.0.1:*',
-                                                start=True)
-                 for tensors, listen in zip([tensors1, tensors2, tensors3, tensors4], should_listen)]
+                                                prefix='mygroup', listen=mode != AveragingMode.CLIENT, listen_on='127.0.0.1:*',
+                                                auxiliary=mode == AveragingMode.AUX, start=True)
+                 for tensors, mode in zip(peer_tensors, modes)]
 
     futures = []
     for averager in averagers:
@@ -72,15 +71,29 @@ def test_allreduce_once(n_client_mode_peers):
             assert averager.endpoint in result
 
     for averager in averagers:
-        with averager.get_tensors() as averaged_tensors:
-            for ref, our in zip(reference, averaged_tensors):
-                assert torch.allclose(ref, our, atol=1e-6)
+        if averager.mode != AveragingMode.AUX:
+            with averager.get_tensors() as averaged_tensors:
+                for ref, our in zip(reference, averaged_tensors):
+                    assert torch.allclose(ref, our, atol=1e-6)
 
     for averager in averagers:
         averager.shutdown()
     dht.shutdown()
 
 
+@pytest.mark.forked
+@pytest.mark.parametrize("n_clients", [0, 1, 2])
+@pytest.mark.parametrize("n_aux", [0, 1, 2])
+def test_allreduce_once(n_clients, n_aux):
+    _test_allreduce_once(n_clients, n_aux)
+
+
+@pytest.mark.forked
+@pytest.mark.parametrize("n_clients, n_aux", [(0, 4), (1, 3), (0, 3)])
+def test_allreduce_once_edge_cases(n_clients, n_aux):
+    _test_allreduce_once(n_clients, n_aux)
+
+
 @pytest.mark.forked
 def test_allreduce_weighted(n_client_mode_peers: int = 2):
     dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')
@@ -117,6 +130,47 @@ def test_allreduce_weighted(n_client_mode_peers: int = 2):
     dht.shutdown()
 
 
+@pytest.mark.forked
+def test_allreduce_compression():
+    """ this test ensures that compression works correctly when multiple tensors have different compression types """
+    dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')
+
+    tensors1 = [torch.linspace(0, 500, 1000) ** 0.5, torch.randn(1000)]
+    tensors2 = [torch.linspace(300, 800, 1000) ** 0.5, torch.randn(1000)]
+    results = {}
+
+    FLOAT16, UINT8 = CompressionType.FLOAT16, CompressionType.UNIFORM_8BIT
+
+    for compression_type_pair in [(FLOAT16, FLOAT16), (FLOAT16, UINT8), (UINT8, FLOAT16), (UINT8, UINT8)]:
+        averager1 = hivemind.DecentralizedAverager([x.clone() for x in tensors1], dht=dht,
+                                                   compression_type=compression_type_pair, listen=False,
+                                                   target_group_size=2, prefix='mygroup', start=True)
+        averager2 = hivemind.DecentralizedAverager([x.clone() for x in tensors2], dht=dht,
+                                                   compression_type=compression_type_pair,
+                                                   target_group_size=2, prefix='mygroup', start=True)
+
+        for future in averager1.step(wait=False), averager2.step(wait=False):
+            future.result()
+
+        with averager1.get_tensors() as averaged_tensors:
+            results[compression_type_pair] = averaged_tensors
+
+    assert torch.allclose(results[UINT8, FLOAT16][0], results[UINT8, UINT8][0])
+    assert torch.allclose(results[UINT8, FLOAT16][1], results[FLOAT16, FLOAT16][1])
+    assert torch.allclose(results[UINT8, UINT8][1], results[FLOAT16, UINT8][1])
+    assert torch.allclose(results[FLOAT16, UINT8][0], results[FLOAT16, FLOAT16][0])
+
+    assert not torch.allclose(results[UINT8, FLOAT16][1], results[UINT8, UINT8][1])
+    assert not torch.allclose(results[UINT8, FLOAT16][0], results[FLOAT16, FLOAT16][0])
+    assert not torch.allclose(results[UINT8, UINT8][0], results[FLOAT16, UINT8][0])
+    assert not torch.allclose(results[FLOAT16, UINT8][1], results[FLOAT16, FLOAT16][1])
+
+    reference = [(tensors1[i] + tensors2[i]) / 2 for i in range(len(tensors1))]
+    for i in range(2):
+        assert 0 < torch.mean(torch.square(results[FLOAT16, FLOAT16][i] - reference[i])).item() <= 1e-5
+        assert 1e-5 < torch.mean(torch.square(results[UINT8, UINT8][i] - reference[i])).item() <= 1e-2
+
+
 def compute_mean_std(averagers, unbiased=True):
     results = []
     for averager in averagers:
@@ -188,68 +242,6 @@ def test_allgather():
     dht.shutdown()
 
 
-@pytest.mark.forked
-@pytest.mark.asyncio
-async def test_allreduce_protocol():
-    """ Run group allreduce protocol manually without grpc, see if the internal logic is working as intended """
-    peers = "alice", "bob", "carol", "colab"
-
-    tensors_by_peer = {peer: [torch.randn(3, 128), torch.rand(32), torch.tensor(i, dtype=torch.float32)]
-                       for i, peer in enumerate(peers)}
-
-    group_id = random.getrandbits(160).to_bytes(length=20, byteorder='big')
-    allreduce_protocols = [AllReduceProtocol(
-        group_id=group_id, endpoint=peer, tensors=tensors_by_peer[peer],
-        ordered_group_endpoints=peers, part_sizes=(150, 200, 67, 0))
-        for peer in peers]
-
-    async def _accumulate(sender: Endpoint, recipient: Endpoint):
-        sender_allreduce = allreduce_protocols[peers.index(sender)]
-        recipient_allreduce = allreduce_protocols[peers.index(recipient)]
-        averaged_part = await recipient_allreduce.accumulate_part(
-            source=sender, remote_part=sender_allreduce.local_tensor_parts[recipient])
-        sender_allreduce.register_averaged_part(source=recipient, averaged_part=averaged_part)
-
-    await asyncio.wait({_accumulate(sender, recipient) for sender in peers for recipient in peers
-                        if recipient != "colab"})
-
-    reference_tensors = [
-        sum(tensors_by_peer[peer][i] for peer in peers) / len(peers)
-        for i in range(len(tensors_by_peer[peers[0]]))
-    ]
-
-    for peer, allreduce in zip(peers, allreduce_protocols):
-        assert allreduce.future.done()
-        averaged_tensors = await allreduce
-        assert len(averaged_tensors) == len(reference_tensors)
-        assert all(torch.allclose(our, ref, atol=1e-6, rtol=0)
-                   for our, ref in zip(averaged_tensors, reference_tensors))
-
-
-@pytest.mark.forked
-def test_partitioning():
-    for _ in range(100):
-        tensors = []
-        for _ in range(random.randint(1, 5)):
-            ndim = random.randint(0, 4)
-            shape = torch.Size([random.randint(0, 16) for _ in range(ndim)])
-            make_tensor = random.choice([torch.rand, torch.randn, torch.zeros, torch.ones])
-            tensors.append(make_tensor(shape))
-
-        total_size = sum(map(torch.Tensor.numel, tensors))
-        if total_size == 0:
-            continue
-        num_chunks = random.randint(1, min(100, sum(x.numel() for x in tensors)))
-        part_sizes = load_balance_peers(total_size, [None] * num_chunks)
-        chunks = split_into_parts(tensors, part_sizes)
-        assert len(chunks) == num_chunks
-        shapes = [tensor.shape for tensor in tensors]
-        restored = restore_from_parts(chunks, shapes)
-        assert len(restored) == len(tensors)
-        assert all(new.shape == old.shape for new, old in zip(restored, tensors))
-        assert all(torch.allclose(new, old) for new, old in zip(restored, tensors))
-
-
 def get_cost(vector_size, partitions, throughputs):
     return max((vector_size - partitions[i] + (len(partitions) - 1) * partitions[i]) / max(throughputs[i], 1e-9)
                for i in range(len(partitions)))
@@ -370,6 +362,13 @@ def test_load_state_from_peers():
     assert got_metadata == super_metadata
     assert all(map(torch.allclose, got_tensors, super_tensors))
 
+    averager1.allow_state_sharing = False
+    assert averager2.load_state_from_peers() is None
+    averager1.allow_state_sharing = True
+    got_metadata, got_tensors = averager2.load_state_from_peers()
+    assert num_calls == 3
+    assert got_metadata == super_metadata
+
 
 @pytest.mark.forked
 def test_getset_bits():

+ 2 - 4
tests/test_dht_schema.py

@@ -1,13 +1,11 @@
-import re
-
 import pytest
-from pydantic import BaseModel, StrictFloat, StrictInt, conint
+from pydantic import BaseModel, StrictInt, conint
 from typing import Dict
 
 import hivemind
 from hivemind.dht import get_dht_time
 from hivemind.dht.node import DHTNode, LOCALHOST
-from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator, conbytes
+from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
 from hivemind.dht.validation import DHTRecord, RecordValidatorBase
 
 

+ 1 - 2
tests/test_dht_validation.py

@@ -1,5 +1,4 @@
 import dataclasses
-from functools import partial
 from typing import Dict
 
 import pytest
@@ -10,7 +9,7 @@ from hivemind.dht.crypto import RSASignatureValidator
 from hivemind.dht.protocol import DHTProtocol
 from hivemind.dht.routing import DHTID
 from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
-from hivemind.dht.validation import DHTRecord, CompositeValidator, RecordValidatorBase
+from hivemind.dht.validation import DHTRecord, CompositeValidator
 
 
 class SchemaA(BaseModel):

+ 440 - 0
tests/test_p2p_daemon.py

@@ -0,0 +1,440 @@
+import asyncio
+import multiprocessing as mp
+import subprocess
+from functools import partial
+from typing import List
+
+import numpy as np
+import pytest
+import torch
+
+from hivemind.p2p import P2P
+from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID
+from hivemind.proto import dht_pb2, runtime_pb2
+from hivemind.utils import MSGPackSerializer
+from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
+
+
+def is_process_running(pid: int) -> bool:
+    return subprocess.run(["ps", "-p", str(pid)], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL).returncode == 0
+
+
+async def replicate_if_needed(p2p: P2P, replicate: bool):
+    return await P2P.replicate(p2p._daemon_listen_port, p2p._host_port) if replicate else p2p
+
+
+def bootstrap_addr(host_port, id_):
+    return f'/ip4/127.0.0.1/tcp/{host_port}/p2p/{id_}'
+
+
+def bootstrap_from(daemons: List[P2P]) -> List[str]:
+    return [bootstrap_addr(d._host_port, d.id) for d in daemons]
+
+
+@pytest.mark.asyncio
+async def test_daemon_killed_on_del():
+    p2p_daemon = await P2P.create()
+
+    child_pid = p2p_daemon._child.pid
+    assert is_process_running(child_pid)
+
+    await p2p_daemon.shutdown()
+    assert not is_process_running(child_pid)
+
+
+@pytest.mark.asyncio
+async def test_server_client_connection():
+    server = await P2P.create()
+    peers = await server._client.list_peers()
+    assert len(peers) == 0
+
+    nodes = bootstrap_from([server])
+    client = await P2P.create(bootstrap=True, bootstrap_peers=nodes)
+    await client.wait_for_at_least_n_peers(1)
+
+    peers = await client._client.list_peers()
+    assert len(peers) == 1
+    peers = await server._client.list_peers()
+    assert len(peers) == 1
+
+
+@pytest.mark.asyncio
+async def test_daemon_replica_does_not_affect_primary():
+    p2p_daemon = await P2P.create()
+    p2p_replica = await P2P.replicate(p2p_daemon._daemon_listen_port, p2p_daemon._host_port)
+
+    child_pid = p2p_daemon._child.pid
+    assert is_process_running(child_pid)
+
+    await p2p_replica.shutdown()
+    assert is_process_running(child_pid)
+
+    await p2p_daemon.shutdown()
+    assert not is_process_running(child_pid)
+
+
+def handle_square(x):
+    x = MSGPackSerializer.loads(x)
+    return MSGPackSerializer.dumps(x ** 2)
+
+
+def handle_add(args):
+    args = MSGPackSerializer.loads(args)
+    result = args[0]
+    for i in range(1, len(args)):
+        result = result + args[i]
+    return MSGPackSerializer.dumps(result)
+
+
+def handle_square_torch(x):
+    tensor = runtime_pb2.Tensor()
+    tensor.ParseFromString(x)
+    tensor = deserialize_torch_tensor(tensor)
+    result = tensor ** 2
+    return serialize_torch_tensor(result).SerializeToString()
+
+
+def handle_add_torch(args):
+    args = MSGPackSerializer.loads(args)
+    tensor = runtime_pb2.Tensor()
+    tensor.ParseFromString(args[0])
+    result = deserialize_torch_tensor(tensor)
+
+    for i in range(1, len(args)):
+        tensor = runtime_pb2.Tensor()
+        tensor.ParseFromString(args[i])
+        result = result + deserialize_torch_tensor(tensor)
+
+    return serialize_torch_tensor(result).SerializeToString()
+
+
+def handle_add_torch_with_exc(args):
+    try:
+        return handle_add_torch(args)
+    except Exception:
+        return b'something went wrong :('
+
+
+@pytest.mark.parametrize(
+    'should_cancel,replicate', [
+        (True, False),
+        (True, True),
+        (False, False),
+        (False, True),
+    ]
+)
+@pytest.mark.asyncio
+async def test_call_unary_handler(should_cancel, replicate, handle_name="handle"):
+    handler_cancelled = False
+
+    async def ping_handler(request, context):
+        try:
+            await asyncio.sleep(2)
+        except asyncio.CancelledError:
+            nonlocal handler_cancelled
+            handler_cancelled = True
+        return dht_pb2.PingResponse(
+            peer=dht_pb2.NodeInfo(
+                node_id=context.id.encode(), rpc_port=context.port),
+            sender_endpoint=context.handle_name, available=True)
+
+    server_primary = await P2P.create()
+    server = await replicate_if_needed(server_primary, replicate)
+    server_pid = server_primary._child.pid
+    await server.add_unary_handler(handle_name, ping_handler, dht_pb2.PingRequest,
+                                   dht_pb2.PingResponse)
+    assert is_process_running(server_pid)
+
+    nodes = bootstrap_from([server])
+    client_primary = await P2P.create(bootstrap=True, bootstrap_peers=nodes)
+    client = await replicate_if_needed(client_primary, replicate)
+    client_pid = client_primary._child.pid
+    assert is_process_running(client_pid)
+
+    ping_request = dht_pb2.PingRequest(
+        peer=dht_pb2.NodeInfo(node_id=client.id.encode(), rpc_port=client._host_port),
+        validate=True)
+    expected_response = dht_pb2.PingResponse(
+        peer=dht_pb2.NodeInfo(node_id=server.id.encode(), rpc_port=server._host_port),
+        sender_endpoint=handle_name, available=True)
+
+    await client.wait_for_at_least_n_peers(1)
+    libp2p_server_id = PeerID.from_base58(server.id)
+    stream_info, reader, writer = await client._client.stream_open(libp2p_server_id, (handle_name,))
+
+    await P2P.send_protobuf(ping_request, dht_pb2.PingRequest, writer)
+
+    if should_cancel:
+        writer.close()
+        await asyncio.sleep(1)
+        assert handler_cancelled
+    else:
+        result, err = await P2P.receive_protobuf(dht_pb2.PingResponse, reader)
+        assert err is None
+        assert result == expected_response
+        assert not handler_cancelled
+
+    await server.stop_listening()
+    await server_primary.shutdown()
+    assert not is_process_running(server_pid)
+
+    await client_primary.shutdown()
+    assert not is_process_running(client_pid)
+
+
+@pytest.mark.asyncio
+async def test_call_unary_handler_error(handle_name="handle"):
+    async def error_handler(request, context):
+        raise ValueError('boom')
+
+    server = await P2P.create()
+    server_pid = server._child.pid
+    await server.add_unary_handler(handle_name, error_handler, dht_pb2.PingRequest, dht_pb2.PingResponse)
+    assert is_process_running(server_pid)
+
+    nodes = bootstrap_from([server])
+    client = await P2P.create(bootstrap=True, bootstrap_peers=nodes)
+    client_pid = client._child.pid
+    assert is_process_running(client_pid)
+    await client.wait_for_at_least_n_peers(1)
+
+    ping_request = dht_pb2.PingRequest(
+        peer=dht_pb2.NodeInfo(node_id=client.id.encode(), rpc_port=client._host_port),
+        validate=True)
+    libp2p_server_id = PeerID.from_base58(server.id)
+    stream_info, reader, writer = await client._client.stream_open(libp2p_server_id, (handle_name,))
+
+    await P2P.send_protobuf(ping_request, dht_pb2.PingRequest, writer)
+    result, err = await P2P.receive_protobuf(dht_pb2.PingResponse, reader)
+    assert result is None
+    assert err.message == 'boom'
+
+    await server.stop_listening()
+    await server.shutdown()
+    await client.shutdown()
+
+
+@pytest.mark.parametrize(
+    "test_input,expected,handle",
+    [
+        pytest.param(10, 100, handle_square, id="square_integer"),
+        pytest.param((1, 2), 3, handle_add, id="add_integers"),
+        pytest.param(([1, 2, 3], [12, 13]), [1, 2, 3, 12, 13], handle_add, id="add_lists"),
+        pytest.param(2, 8, lambda x: MSGPackSerializer.dumps(MSGPackSerializer.loads(x) ** 3), id="lambda")
+    ]
+)
+@pytest.mark.asyncio
+async def test_call_peer_single_process(test_input, expected, handle, handler_name="handle"):
+    server = await P2P.create()
+    server_pid = server._child.pid
+    await server.add_stream_handler(handler_name, handle)
+    assert is_process_running(server_pid)
+
+    nodes = bootstrap_from([server])
+    client = await P2P.create(bootstrap=True, bootstrap_peers=nodes)
+    client_pid = client._child.pid
+    assert is_process_running(client_pid)
+
+    await client.wait_for_at_least_n_peers(1)
+
+    test_input_msgp = MSGPackSerializer.dumps(test_input)
+    result_msgp = await client.call_peer_handler(server.id, handler_name, test_input_msgp)
+    result = MSGPackSerializer.loads(result_msgp)
+    assert result == expected
+
+    await server.stop_listening()
+    await server.shutdown()
+    assert not is_process_running(server_pid)
+
+    await client.shutdown()
+    assert not is_process_running(client_pid)
+
+
+async def run_server(handler_name, server_side, client_side, response_received):
+    server = await P2P.create()
+    server_pid = server._child.pid
+    await server.add_stream_handler(handler_name, handle_square)
+    assert is_process_running(server_pid)
+
+    server_side.send(server.id)
+    server_side.send(server._host_port)
+    while response_received.value == 0:
+        await asyncio.sleep(0.5)
+
+    await server.stop_listening()
+    await server.shutdown()
+    assert not is_process_running(server_pid)
+
+
+def server_target(handler_name, server_side, client_side, response_received):
+    asyncio.run(run_server(handler_name, server_side, client_side, response_received))
+
+
+@pytest.mark.asyncio
+async def test_call_peer_different_processes():
+    handler_name = "square"
+    test_input = 2
+
+    server_side, client_side = mp.Pipe()
+    response_received = mp.Value(np.ctypeslib.as_ctypes_type(np.int32))
+    response_received.value = 0
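+    # shared flag: set to 1 by the client once the response has been received,
+    # signalling the server process that it may shut down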
+
+    proc = mp.Process(target=server_target, args=(handler_name, server_side, client_side, response_received))
+    proc.start()
+
+    peer_id = client_side.recv()
+    peer_port = client_side.recv()
+
+    nodes = [bootstrap_addr(peer_port, peer_id)]
+    client = await P2P.create(bootstrap=True, bootstrap_peers=nodes)
+    client_pid = client._child.pid
+    assert is_process_running(client_pid)
+
+    await client.wait_for_at_least_n_peers(1)
+
+    test_input_msgp = MSGPackSerializer.dumps(2)
+    result_msgp = await client.call_peer_handler(peer_id, handler_name, test_input_msgp)
+    result = MSGPackSerializer.loads(result_msgp)
+    assert np.allclose(result, test_input ** 2)
+    response_received.value = 1
+
+    await client.shutdown()
+    assert not is_process_running(client_pid)
+
+    proc.join()
+
+
+@pytest.mark.parametrize(
+    "test_input,expected",
+    [
+        pytest.param(torch.tensor([2]), torch.tensor(4)),
+        pytest.param(
+            torch.tensor([[1.0, 2.0], [0.5, 0.1]]),
+            torch.tensor([[1.0, 2.0], [0.5, 0.1]]) ** 2),
+    ]
+)
+@pytest.mark.asyncio
+async def test_call_peer_torch_square(test_input, expected, handler_name="handle"):
+    handle = handle_square_torch
+    server = await P2P.create()
+    await server.add_stream_handler(handler_name, handle)
+
+    nodes = bootstrap_from([server])
+    client = await P2P.create(bootstrap=True, bootstrap_peers=nodes)
+
+    await client.wait_for_at_least_n_peers(1)
+
+    inp = serialize_torch_tensor(test_input).SerializeToString()
+    result_pb = await client.call_peer_handler(server.id, handler_name, inp)
+    result = runtime_pb2.Tensor()
+    result.ParseFromString(result_pb)
+    result = deserialize_torch_tensor(result)
+    assert torch.allclose(result, expected)
+
+    await server.stop_listening()
+    await server.shutdown()
+    await client.shutdown()
+
+
+@pytest.mark.parametrize(
+    "test_input,expected",
+    [
+        pytest.param([torch.tensor([1]), torch.tensor([2])], torch.tensor([3])),
+        pytest.param(
+            [torch.tensor([[0.1, 0.2], [0.3, 0.4]]), torch.tensor([[1.1, 1.2], [1.3, 1.4]])],
+            torch.tensor([[1.2, 1.4], [1.6, 1.8]])),
+    ]
+)
+@pytest.mark.asyncio
+async def test_call_peer_torch_add(test_input, expected, handler_name="handle"):
+    handle = handle_add_torch
+    server = await P2P.create()
+    await server.add_stream_handler(handler_name, handle)
+
+    nodes = bootstrap_from([server])
+    client = await P2P.create(bootstrap=True, bootstrap_peers=nodes)
+
+    await client.wait_for_at_least_n_peers(1)
+
+    inp = [serialize_torch_tensor(i).SerializeToString() for i in test_input]
+    inp_msgp = MSGPackSerializer.dumps(inp)
+    result_pb = await client.call_peer_handler(server.id, handler_name, inp_msgp)
+    result = runtime_pb2.Tensor()
+    result.ParseFromString(result_pb)
+    result = deserialize_torch_tensor(result)
+    assert torch.allclose(result, expected)
+
+    await server.stop_listening()
+    await server.shutdown()
+    await client.shutdown()
+
+
+@pytest.mark.parametrize(
+    "replicate",
+    [
+        pytest.param(False, id="primary"),
+        pytest.param(True, id="replica"),
+    ]
+)
+@pytest.mark.asyncio
+async def test_call_peer_error(replicate, handler_name="handle"):
+    server_primary = await P2P.create()
+    server = await replicate_if_needed(server_primary, replicate)
+    await server.add_stream_handler(handler_name, handle_add_torch_with_exc)
+
+    nodes = bootstrap_from([server])
+    client_primary = await P2P.create(bootstrap=True, bootstrap_peers=nodes)
+    client = await replicate_if_needed(client_primary, replicate)
+
+    await client.wait_for_at_least_n_peers(1)
+
+    inp = [serialize_torch_tensor(i).SerializeToString() for i in [torch.zeros((2, 3)), torch.zeros((3, 2))]]
+    inp_msgp = MSGPackSerializer.dumps(inp)
+    result = await client.call_peer_handler(server.id, handler_name, inp_msgp)
+    assert result == b'something went wrong :('
+
+    await server.stop_listening()
+    await server_primary.shutdown()
+    await client_primary.shutdown()
+
+
+@pytest.mark.asyncio
+async def test_handlers_on_different_replicas(handler_name="handle"):
+    def handler(arg, key):
+        return key
+
+    server_primary = await P2P.create(bootstrap=False)
+    server_id = server_primary.id
+    await server_primary.add_stream_handler(handler_name, partial(handler, key=b'primary'))
+
+    server_replica1 = await replicate_if_needed(server_primary, True)
+    await server_replica1.add_stream_handler(handler_name + '1', partial(handler, key=b'replica1'))
+
+    server_replica2 = await replicate_if_needed(server_primary, True)
+    await server_replica2.add_stream_handler(handler_name + '2', partial(handler, key=b'replica2'))
+
+    nodes = bootstrap_from([server_primary])
+    client = await P2P.create(bootstrap=True, bootstrap_peers=nodes)
+    await client.wait_for_at_least_n_peers(1)
+
+    result = await client.call_peer_handler(server_id, handler_name, b'1')
+    assert result == b"primary"
+
+    result = await client.call_peer_handler(server_id, handler_name + '1', b'2')
+    assert result == b"replica1"
+
+    result = await client.call_peer_handler(server_id, handler_name + '2', b'3')
+    assert result == b"replica2"
+
+    await server_replica1.stop_listening()
+    await server_replica2.stop_listening()
+
+    # The primary does not handle the replicas' protocols
+    with pytest.raises(Exception):
+        await client.call_peer_handler(server_id, handler_name + '1', b'')
+    with pytest.raises(Exception):
+        await client.call_peer_handler(server_id, handler_name + '2', b'')
+
+    await server_primary.stop_listening()
+    await server_primary.shutdown()
+    await client.shutdown()

+ 559 - 0
tests/test_p2p_daemon_bindings.py

@@ -0,0 +1,559 @@
+import asyncio
+import io
+from contextlib import AsyncExitStack
+
+import pytest
+from google.protobuf.message import EncodeError
+from multiaddr import Multiaddr, protocols
+
+from hivemind.p2p.p2p_daemon_bindings.control import ControlClient, DaemonConnector, parse_conn_protocol
+from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
+from hivemind.p2p.p2p_daemon_bindings.utils import (ControlFailure, raise_if_failed, read_pbmsg_safe,
+                                                    read_unsigned_varint, write_pbmsg, write_unsigned_varint)
+from hivemind.proto import p2pd_pb2 as p2pd_pb
+from test_utils import make_p2pd_pair_ip4, connect_safe
+
+
+def test_raise_if_failed_raises():
+    resp = p2pd_pb.Response()
+    resp.type = p2pd_pb.Response.ERROR
+    with pytest.raises(ControlFailure):
+        raise_if_failed(resp)
+
+
+def test_raise_if_failed_not_raises():
+    resp = p2pd_pb.Response()
+    resp.type = p2pd_pb.Response.OK
+    raise_if_failed(resp)
+
+
+PAIRS_INT_SERIALIZED_VALID = (
+    (0, b"\x00"),
+    (1, b"\x01"),
+    (128, b"\x80\x01"),
+    (2 ** 32, b"\x80\x80\x80\x80\x10"),
+    (2 ** 64 - 1, b"\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01"),
+)
+
+PAIRS_INT_SERIALIZED_OVERFLOW = (
+    (2 ** 64, b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x02"),
+    (2 ** 64 + 1, b"\x81\x80\x80\x80\x80\x80\x80\x80\x80\x02"),
+    (
+        2 ** 128,
+        b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x04",
+    ),
+)
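+# Note: the expected byte strings above follow the protobuf unsigned varint scheme:
+# the integer is split into 7-bit groups, least significant first, and every byte
+# except the last has its continuation (high) bit set, e.g. 128 -> b"\x80\x01".
+# With the default 64-bit limit, values of 2 ** 64 and above are expected to be rejected.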
+
+PEER_ID_STRING = "QmS5QmciTXXnCUCyxud5eWFenUMAmvAWSDa1c7dvdXRMZ7"
+PEER_ID_BYTES = b'\x12 7\x87F.[\xb5\xb1o\xe5*\xc7\xb9\xbb\x11:"Z|j2\x8ad\x1b\xa6\xe5<Ip\xfe\xb4\xf5v'
+PEER_ID = PeerID(PEER_ID_BYTES)
+MADDR = Multiaddr("/unix/123")
+NUM_P2PDS = 4
+PEER_ID_RANDOM = PeerID.from_base58("QmcgpsyWgH8Y8ajJz1Cu72KnS5uo2Aa2LpzU7kinSupNK1")
+ENABLE_CONTROL = True
+ENABLE_CONNMGR = False
+ENABLE_DHT = False
+ENABLE_PUBSUB = False
+FUNC_MAKE_P2PD_PAIR = make_p2pd_pair_ip4
+
+
+class MockReader(io.BytesIO):
+    async def readexactly(self, n):
+        await asyncio.sleep(0)
+        return self.read(n)
+
+
+class MockWriter(io.BytesIO):
+    pass
+
+
+class MockReaderWriter(MockReader, MockWriter):
+    pass
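+# The mocks above emulate just enough of the asyncio StreamReader/StreamWriter interface
+# (an awaitable ``readexactly`` plus BytesIO ``write``/``getvalue``) to exercise the
+# varint and protobuf helpers against an in-memory buffer.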
+
+
+@pytest.mark.parametrize("integer, serialized_integer", PAIRS_INT_SERIALIZED_VALID)
+@pytest.mark.asyncio
+async def test_write_unsigned_varint(integer, serialized_integer):
+    s = MockWriter()
+    await write_unsigned_varint(s, integer)
+    assert s.getvalue() == serialized_integer
+
+
+@pytest.mark.parametrize("integer", tuple(i[0] for i in PAIRS_INT_SERIALIZED_OVERFLOW))
+@pytest.mark.asyncio
+async def test_write_unsigned_varint_overflow(integer):
+    s = MockWriter()
+    with pytest.raises(ValueError):
+        await write_unsigned_varint(s, integer)
+
+
+@pytest.mark.parametrize("integer", (-1, -(2 ** 32), -(2 ** 64), -(2 ** 128)))
+@pytest.mark.asyncio
+async def test_write_unsigned_varint_negative(integer):
+    s = MockWriter()
+    with pytest.raises(ValueError):
+        await write_unsigned_varint(s, integer)
+
+
+@pytest.mark.parametrize("integer, serialized_integer", PAIRS_INT_SERIALIZED_VALID)
+@pytest.mark.asyncio
+async def test_read_unsigned_varint(integer, serialized_integer):
+    s = MockReader(serialized_integer)
+    result = await read_unsigned_varint(s)
+    assert result == integer
+
+
+@pytest.mark.parametrize("serialized_integer", tuple(i[1] for i in PAIRS_INT_SERIALIZED_OVERFLOW))
+@pytest.mark.asyncio
+async def test_read_unsigned_varint_overflow(serialized_integer):
+    s = MockReader(serialized_integer)
+    with pytest.raises(ValueError):
+        await read_unsigned_varint(s)
+
+
+@pytest.mark.parametrize("max_bits", (2, 31, 32, 63, 64, 127, 128))
+@pytest.mark.asyncio
+async def test_read_write_unsigned_varint_max_bits_edge(max_bits):
+    """
+    Test edge cases with different `max_bits`
+    """
+    for i in range(-3, 0):
+        integer = i + (2 ** max_bits)
+        s = MockReaderWriter()
+        await write_unsigned_varint(s, integer, max_bits=max_bits)
+        s.seek(0, 0)
+        result = await read_unsigned_varint(s, max_bits=max_bits)
+        assert integer == result
+
+
+def test_peer_id():
+    assert PEER_ID.to_bytes() == PEER_ID_BYTES
+    assert PEER_ID.to_string() == PEER_ID_STRING
+
+    peer_id_2 = PeerID.from_base58(PEER_ID_STRING)
+    assert peer_id_2.to_bytes() == PEER_ID_BYTES
+    assert peer_id_2.to_string() == PEER_ID_STRING
+    assert PEER_ID == peer_id_2
+    peer_id_3 = PeerID.from_base58("QmbmfNDEth7Ucvjuxiw3SP3E4PoJzbk7g4Ge6ZDigbCsNp")
+    assert PEER_ID != peer_id_3
+
+
+def test_stream_info():
+    proto = "123"
+    si = StreamInfo(PEER_ID, MADDR, proto)
+    assert si.peer_id == PEER_ID
+    assert si.addr == MADDR
+    assert si.proto == proto
+    pb_si = si.to_protobuf()
+    assert pb_si.peer == PEER_ID.to_bytes()
+    assert pb_si.addr == MADDR.to_bytes()
+    assert pb_si.proto == si.proto
+    si_1 = StreamInfo.from_protobuf(pb_si)
+    assert si_1.peer_id == PEER_ID
+    assert si_1.addr == MADDR
+    assert si_1.proto == proto
+
+
+def test_peer_info():
+    pi = PeerInfo(PEER_ID, [MADDR])
+    assert pi.peer_id == PEER_ID
+    assert pi.addrs == [MADDR]
+    pi_pb = p2pd_pb.PeerInfo(id=PEER_ID.to_bytes(), addrs=[MADDR.to_bytes()])
+    pi_1 = PeerInfo.from_protobuf(pi_pb)
+    assert pi.peer_id == pi_1.peer_id
+    assert pi.addrs == pi_1.addrs
+
+
+@pytest.mark.parametrize(
+    "maddr_str, expected_proto",
+    (("/unix/123", protocols.P_UNIX), ("/ip4/127.0.0.1/tcp/7777", protocols.P_IP4)),
+)
+def test_parse_conn_protocol_valid(maddr_str, expected_proto):
+    assert parse_conn_protocol(Multiaddr(maddr_str)) == expected_proto
+
+
+@pytest.mark.parametrize(
+    "maddr_str",
+    (
+        "/p2p/QmbHVEEepCi7rn7VL7Exxpd2Ci9NNB6ifvqwhsrbRMgQFP",
+        "/onion/timaq4ygg2iegci7:1234",
+    ),
+)
+def test_parse_conn_protocol_invalid(maddr_str):
+    maddr = Multiaddr(maddr_str)
+    with pytest.raises(ValueError):
+        parse_conn_protocol(maddr)
+
+
+@pytest.mark.parametrize("control_maddr_str", ("/unix/123", "/ip4/127.0.0.1/tcp/6666"))
+def test_client_ctor_control_maddr(control_maddr_str):
+    c = DaemonConnector(Multiaddr(control_maddr_str))
+    assert c.control_maddr == Multiaddr(control_maddr_str)
+
+
+def test_client_ctor_default_control_maddr():
+    c = DaemonConnector()
+    assert c.control_maddr == Multiaddr(DaemonConnector.DEFAULT_CONTROL_MADDR)
+
+
+@pytest.mark.parametrize("listen_maddr_str", ("/unix/123", "/ip4/127.0.0.1/tcp/6666"))
+def test_control_client_ctor_listen_maddr(listen_maddr_str):
+    c = ControlClient(
+        daemon_connector=DaemonConnector(), listen_maddr=Multiaddr(listen_maddr_str)
+    )
+    assert c.listen_maddr == Multiaddr(listen_maddr_str)
+
+
+def test_control_client_ctor_default_listen_maddr():
+    c = ControlClient(daemon_connector=DaemonConnector())
+    assert c.listen_maddr == Multiaddr(ControlClient.DEFAULT_LISTEN_MADDR)
+
+
+@pytest.mark.parametrize(
+    "msg_bytes",
+    (
+        p2pd_pb.Response(
+            type=p2pd_pb.Response.Type.OK,
+            identify=p2pd_pb.IdentifyResponse(
+                id=PeerID.from_base58('QmT7WhTne9zBLfAgAJt9aiZ8jZ5BxJGowRubxsHYmnyzUd').to_bytes(),
+                addrs=[Multiaddr('/p2p-circuit').to_bytes(), Multiaddr('/ip4/127.0.0.1/tcp/51126').to_bytes(),
+                       Multiaddr('/ip4/192.168.10.135/tcp/51126').to_bytes(),
+                       Multiaddr('/ip6/::1/tcp/51127').to_bytes()]
+            )).SerializeToString(),
+        p2pd_pb.Response(
+            type=p2pd_pb.Response.Type.OK,
+            identify=p2pd_pb.IdentifyResponse(
+                id=PeerID.from_base58('QmcQFt2MFfCZ9AxzUCNrk4k7TtMdZZvAAteaA6tHpBKdrk').to_bytes(),
+                addrs=[Multiaddr('/p2p-circuit').to_bytes(), Multiaddr('/ip4/127.0.0.1/tcp/51493').to_bytes(),
+                       Multiaddr('/ip4/192.168.10.135/tcp/51493').to_bytes(),
+                       Multiaddr('/ip6/::1/tcp/51494').to_bytes()]
+            )).SerializeToString(),
+        p2pd_pb.Response(
+            type=p2pd_pb.Response.Type.OK,
+            identify=p2pd_pb.IdentifyResponse(
+                id=PeerID.from_base58('QmbWqVVoz7v9LS9ZUQAhyyfdFJY3iU8ZrUY3XQozoTA5cc').to_bytes(),
+                addrs=[Multiaddr('/p2p-circuit').to_bytes(), Multiaddr('/ip4/127.0.0.1/tcp/51552').to_bytes(),
+                       Multiaddr('/ip4/192.168.10.135/tcp/51552').to_bytes(),
+                       Multiaddr('/ip6/::1/tcp/51553').to_bytes()]
+            )).SerializeToString(),
+    ),
+    # give the test cases readable ids so the raw bytes don't clutter the terminal output
+    ids=("pb example Response 0", "pb example Response 1", "pb example Response 2"),
+)
+@pytest.mark.asyncio
+async def test_read_pbmsg_safe_valid(msg_bytes):
+    s = MockReaderWriter()
+    await write_unsigned_varint(s, len(msg_bytes))
+    s.write(msg_bytes)
+    # reset the offset back to the beginning
+    s.seek(0, 0)
+    pb_msg = p2pd_pb.Response()
+    await read_pbmsg_safe(s, pb_msg)
+    assert pb_msg.SerializeToString() == msg_bytes
+
+
+@pytest.mark.parametrize(
+    "pb_type, pb_msg",
+    (
+        (
+            p2pd_pb.Response,
+            p2pd_pb.Response(
+                type=p2pd_pb.Response.Type.OK,
+                dht=p2pd_pb.DHTResponse(
+                    type=p2pd_pb.DHTResponse.Type.VALUE,
+                    peer=p2pd_pb.PeerInfo(
+                        id=PeerID.from_base58('QmNaXUy78W9moQ9APCoKaTtPjLcEJPN9hRBCqErY7o2fQs').to_bytes(),
+                        addrs=[Multiaddr('/p2p-circuit').to_bytes(), Multiaddr('/ip4/127.0.0.1/tcp/56929').to_bytes(),
+                               Multiaddr('/ip4/192.168.10.135/tcp/56929').to_bytes(),
+                               Multiaddr('/ip6/::1/tcp/56930').to_bytes()]
+                    )
+                )
+            ),
+        ),
+        (p2pd_pb.Request, p2pd_pb.Request(type=p2pd_pb.Request.Type.LIST_PEERS)),
+        (
+            p2pd_pb.DHTRequest,
+            p2pd_pb.DHTRequest(type=p2pd_pb.DHTRequest.Type.FIND_PEER,
+                               peer=PeerID.from_base58('QmcgHMuEhqdLHDVeNjiCGU7Ds6E7xK3f4amgiwHNPKKn7R').to_bytes()),
+        ),
+        (
+            p2pd_pb.DHTResponse,
+            p2pd_pb.DHTResponse(
+                type=p2pd_pb.DHTResponse.Type.VALUE,
+                peer=p2pd_pb.PeerInfo(
+                    id=PeerID.from_base58('QmWP32GhEyXVQsLXFvV81eadDC8zQRZxZvJK359rXxLquk').to_bytes(),
+                    addrs=[Multiaddr('/p2p-circuit').to_bytes(), Multiaddr('/ip4/127.0.0.1/tcp/56897').to_bytes(),
+                           Multiaddr('/ip4/192.168.10.135/tcp/56897').to_bytes(),
+                           Multiaddr('/ip6/::1/tcp/56898').to_bytes()]
+                )
+            ),
+        ),
+        (
+            p2pd_pb.StreamInfo,
+            p2pd_pb.StreamInfo(peer=PeerID.from_base58('QmewLxB46MftfxQiunRgJo2W8nW4Lh5NLEkRohkHhJ4wW6').to_bytes(),
+                               addr=Multiaddr('/ip4/127.0.0.1/tcp/57029').to_bytes(),
+                               proto=b'protocol123'),
+        ),
+    ),
+    ids=(
+        "pb example Response",
+        "pb example Request",
+        "pb example DHTRequest",
+        "pb example DHTResponse",
+        "pb example StreamInfo",
+    ),
+)
+@pytest.mark.asyncio
+async def test_write_pbmsg(pb_type, pb_msg):
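+    # note: the chr()-based length prefix below matches the unsigned-varint encoding only
+    # for messages shorter than 128 bytes, which appears to hold for all fixtures above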
+    msg_bytes = bytes(chr(pb_msg.ByteSize()), 'utf-8') + pb_msg.SerializeToString()
+    pb_obj = pb_type()
+
+    s_read = MockReaderWriter(msg_bytes)
+    await read_pbmsg_safe(s_read, pb_obj)
+    s_write = MockReaderWriter()
+    await write_pbmsg(s_write, pb_obj)
+    assert msg_bytes == s_write.getvalue()
+
+
+@pytest.mark.parametrize(
+    "pb_msg",
+    (
+        p2pd_pb.Response(),
+        p2pd_pb.Request(),
+        p2pd_pb.DHTRequest(),
+        p2pd_pb.DHTResponse(),
+        p2pd_pb.StreamInfo(),
+    ),
+)
+@pytest.mark.asyncio
+async def test_write_pbmsg_missing_fields(pb_msg):
+    with pytest.raises(EncodeError):
+        await write_pbmsg(MockReaderWriter(), pb_msg)
+
+
+@pytest.fixture
+async def p2pcs():
+    # TODO: Change back to gather style
+    async with AsyncExitStack() as stack:
+        p2pd_tuples = [
+            await stack.enter_async_context(
+                FUNC_MAKE_P2PD_PAIR(
+                    enable_control=ENABLE_CONTROL,
+                    enable_connmgr=ENABLE_CONNMGR,
+                    enable_dht=ENABLE_DHT,
+                    enable_pubsub=ENABLE_PUBSUB,
+                )
+            )
+            for _ in range(NUM_P2PDS)
+        ]
+        yield tuple(p2pd_tuple.client for p2pd_tuple in p2pd_tuples)
+
+
+@pytest.mark.asyncio
+async def test_client_identify_unix_socket(p2pcs):
+    await p2pcs[0].identify()
+
+
+@pytest.mark.asyncio
+async def test_client_identify(p2pcs):
+    await p2pcs[0].identify()
+
+
+@pytest.mark.asyncio
+async def test_client_connect_success(p2pcs):
+    peer_id_0, maddrs_0 = await p2pcs[0].identify()
+    peer_id_1, maddrs_1 = await p2pcs[1].identify()
+    await p2pcs[0].connect(peer_id_1, maddrs_1)
+    # test case: repeated connections
+    await p2pcs[1].connect(peer_id_0, maddrs_0)
+
+
+@pytest.mark.asyncio
+async def test_client_connect_failure(p2pcs):
+    peer_id_1, maddrs_1 = await p2pcs[1].identify()
+    await p2pcs[0].identify()
+    # test case: `peer_id` mismatches
+    with pytest.raises(ControlFailure):
+        await p2pcs[0].connect(PEER_ID_RANDOM, maddrs_1)
+    # test case: empty maddrs
+    with pytest.raises(ControlFailure):
+        await p2pcs[0].connect(peer_id_1, [])
+    # test case: wrong maddrs
+    with pytest.raises(ControlFailure):
+        await p2pcs[0].connect(peer_id_1, [Multiaddr("/ip4/127.0.0.1/udp/0")])
+
+
+@pytest.mark.asyncio
+async def test_connect_safe(p2pcs):
+    await connect_safe(p2pcs[0], p2pcs[1])
+
+
+@pytest.mark.asyncio
+async def test_client_list_peers(p2pcs):
+    # test case: no peers
+    assert len(await p2pcs[0].list_peers()) == 0
+    # test case: 1 peer
+    await connect_safe(p2pcs[0], p2pcs[1])
+    assert len(await p2pcs[0].list_peers()) == 1
+    assert len(await p2pcs[1].list_peers()) == 1
+    # test case: one more peer
+    await connect_safe(p2pcs[0], p2pcs[2])
+    assert len(await p2pcs[0].list_peers()) == 2
+    assert len(await p2pcs[1].list_peers()) == 1
+    assert len(await p2pcs[2].list_peers()) == 1
+
+
+@pytest.mark.asyncio
+async def test_client_disconnect(p2pcs):
+    # test case: disconnect a peer without connections
+    await p2pcs[1].disconnect(PEER_ID_RANDOM)
+    # test case: disconnect
+    peer_id_0, _ = await p2pcs[0].identify()
+    await connect_safe(p2pcs[0], p2pcs[1])
+    assert len(await p2pcs[0].list_peers()) == 1
+    assert len(await p2pcs[1].list_peers()) == 1
+    await p2pcs[1].disconnect(peer_id_0)
+    assert len(await p2pcs[0].list_peers()) == 0
+    assert len(await p2pcs[1].list_peers()) == 0
+    # test case: disconnect twice
+    await p2pcs[1].disconnect(peer_id_0)
+    assert len(await p2pcs[0].list_peers()) == 0
+    assert len(await p2pcs[1].list_peers()) == 0
+
+
+@pytest.mark.asyncio
+async def test_client_stream_open_success(p2pcs):
+    peer_id_1, maddrs_1 = await p2pcs[1].identify()
+    await connect_safe(p2pcs[0], p2pcs[1])
+
+    proto = "123"
+
+    async def handle_proto(stream_info, reader, writer):
+        await reader.readexactly(1)
+
+    await p2pcs[1].stream_handler(proto, handle_proto)
+
+    # test case: normal
+    stream_info, reader, writer = await p2pcs[0].stream_open(peer_id_1, (proto,))
+    assert stream_info.peer_id == peer_id_1
+    assert stream_info.addr in maddrs_1
+    assert stream_info.proto == "123"
+    writer.close()
+
+    # test case: open with multiple protocols
+    stream_info, reader, writer = await p2pcs[0].stream_open(
+        peer_id_1, (proto, "another_protocol")
+    )
+    assert stream_info.peer_id == peer_id_1
+    assert stream_info.addr in maddrs_1
+    assert stream_info.proto == "123"
+    writer.close()
+
+
+@pytest.mark.asyncio
+async def test_client_stream_open_failure(p2pcs):
+    peer_id_1, _ = await p2pcs[1].identify()
+    await connect_safe(p2pcs[0], p2pcs[1])
+
+    proto = "123"
+
+    # test case: `stream_open` to a peer who didn't register the protocol
+    with pytest.raises(ControlFailure):
+        await p2pcs[0].stream_open(peer_id_1, (proto,))
+
+    # test case: `stream_open` to a peer for a non-registered protocol
+    async def handle_proto(stream_info, reader, writer):
+        pass
+
+    await p2pcs[1].stream_handler(proto, handle_proto)
+    with pytest.raises(ControlFailure):
+        await p2pcs[0].stream_open(peer_id_1, ("another_protocol",))
+
+
+@pytest.mark.asyncio
+async def test_client_stream_handler_success(p2pcs):
+    peer_id_1, _ = await p2pcs[1].identify()
+    await connect_safe(p2pcs[0], p2pcs[1])
+
+    proto = "protocol123"
+    bytes_to_send = b"yoyoyoyoyog"
+    # event used by this test to wait until the handler function has received the incoming data
+    event_handler_finished = asyncio.Event()
+
+    async def handle_proto(stream_info, reader, writer):
+        nonlocal event_handler_finished
+        bytes_received = await reader.readexactly(len(bytes_to_send))
+        assert bytes_received == bytes_to_send
+        event_handler_finished.set()
+
+    await p2pcs[1].stream_handler(proto, handle_proto)
+    assert proto in p2pcs[1].control.handlers
+    assert handle_proto == p2pcs[1].control.handlers[proto]
+
+    # test case: test the stream handler `handle_proto`
+
+    _, reader, writer = await p2pcs[0].stream_open(peer_id_1, (proto,))
+
+    # at this point the handler is still blocked waiting for data, since nothing has been sent yet;
+    # send the payload so that it can finish
+    writer.write(bytes_to_send)
+
+    # close the writer, then wait for the handler to finish
+    writer.close()
+
+    await event_handler_finished.wait()
+
+    # test case: two streams to different handlers respectively
+    another_proto = "another_protocol123"
+    another_bytes_to_send = b"456"
+    event_another_proto = asyncio.Event()
+
+    async def handle_another_proto(stream_info, reader, writer):
+        event_another_proto.set()
+        bytes_received = await reader.readexactly(len(another_bytes_to_send))
+        assert bytes_received == another_bytes_to_send
+
+    await p2pcs[1].stream_handler(another_proto, handle_another_proto)
+    assert another_proto in p2pcs[1].control.handlers
+    assert handle_another_proto == p2pcs[1].control.handlers[another_proto]
+
+    _, reader, writer = await p2pcs[0].stream_open(peer_id_1, (another_proto,))
+    await event_another_proto.wait()
+
+    # at this moment the handler must still be blocked waiting for the data
+
+    writer.write(another_bytes_to_send)
+
+    writer.close()
+
+    # test case: registering twice can override the previous registration
+    event_third = asyncio.Event()
+
+    async def handler_third(stream_info, reader, writer):
+        event_third.set()
+
+    await p2pcs[1].stream_handler(another_proto, handler_third)
+    assert another_proto in p2pcs[1].control.handlers
+    # ensure the previous handler has been overridden
+    assert handler_third == p2pcs[1].control.handlers[another_proto]
+
+    await p2pcs[0].stream_open(peer_id_1, (another_proto,))
+    # ensure the overriding handler is called when a stream is opened for the protocol
+    await event_third.wait()
+
+
+@pytest.mark.asyncio
+async def test_client_stream_handler_failure(p2pcs):
+    peer_id_1, _ = await p2pcs[1].identify()
+    await connect_safe(p2pcs[0], p2pcs[1])
+
+    proto = "123"
+
+    # test case: the handler is registered under a different protocol name than the one being opened
+    async def handle_proto_correct_params(stream_info, stream):
+        pass
+
+    await p2pcs[1].stream_handler("another_protocol", handle_proto_correct_params)
+    with pytest.raises(ControlFailure):
+        await p2pcs[0].stream_open(peer_id_1, (proto,))

+ 41 - 1
tests/test_util_modules.py

@@ -11,6 +11,7 @@ from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
 import hivemind
 from hivemind.utils import MSGPackSerializer
 from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.utils.asyncio import amap_in_executor, aiter, aenumerate, achain, anext, azip
 from hivemind.utils.mpfuture import FutureStateError
 
 
@@ -138,6 +139,11 @@ def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
     error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.UNIFORM_8BIT)) - X
     assert error.square().mean() < beta
 
+    zeros = torch.zeros(5, 5)
+    for compression_type in CompressionType.values():
+        assert deserialize_torch_tensor(serialize_torch_tensor(zeros, compression_type)).isfinite().all()
+
+
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_channel_cache():
@@ -252,7 +258,7 @@ def test_split_parts():
     for combined in combined_incomplete, combined_incomplete2, combined_incomplete3:
         with pytest.raises(RuntimeError):
             deserialize_torch_tensor(combined)
-            # note: we rely on this being RuntimeError in hivemind.client.averager.allreduce.AllreduceProtocol
+            # note: we rely on this being RuntimeError in hivemind.client.averager.allreduce.AllreduceRunner
 
 
 def test_generic_data_classes():
@@ -267,3 +273,37 @@ def test_generic_data_classes():
     sorted_expirations = sorted([DHTExpiration(value) for value in range(1, 1000)])
     sorted_heap_entries = sorted([HeapEntry(DHTExpiration(value), key="any") for value in range(1, 1000)[::-1]])
     assert all([entry.expiration_time == value for entry, value in zip(sorted_heap_entries, sorted_expirations)])
+
+
+@pytest.mark.asyncio
+async def test_asyncio_utils():
+    res = [i async for i, item in aenumerate(aiter('a', 'b', 'c'))]
+    assert res == list(range(len(res)))
+
+    num_steps = 0
+    async for elem in amap_in_executor(lambda x: x ** 2, aiter(*range(100)), max_prefetch=5):
+        assert elem == num_steps ** 2
+        num_steps += 1
+    assert num_steps == 100
+
+    ours = [elem async for elem in amap_in_executor(max, aiter(*range(7)), aiter(*range(-50, 50, 10)), max_prefetch=1)]
+    ref = list(map(max, range(7), range(-50, 50, 10)))
+    assert ours == ref
+
+    ours = [row async for row in azip(aiter('a', 'b', 'c'), aiter(1, 2, 3))]
+    ref = list(zip(['a', 'b', 'c'], [1, 2, 3]))
+    assert ours == ref
+
+    async def _aiterate():
+        yield 'foo'
+        yield 'bar'
+        yield 'baz'
+
+    iterator = _aiterate()
+    assert (await anext(iterator)) == 'foo'
+    tail = [item async for item in iterator]
+    assert tail == ['bar', 'baz']
+    with pytest.raises(StopAsyncIteration):
+        await anext(iterator)
+
+    assert [item async for item in achain(_aiterate(), aiter(*range(5)))] == ['foo', 'bar', 'baz'] + list(range(5))

+ 194 - 0
tests/test_utils/__init__.py

@@ -0,0 +1,194 @@
+import asyncio
+import functools
+import os
+import subprocess
+import time
+import uuid
+from contextlib import asynccontextmanager
+from typing import NamedTuple
+from pkg_resources import resource_filename
+
+from multiaddr import Multiaddr, protocols
+
+from hivemind import find_open_port
+from hivemind.p2p.p2p_daemon_bindings.p2pclient import Client
+
+
+TIMEOUT_DURATION = 30  # seconds
+P2PD_PATH = resource_filename("hivemind", "hivemind_cli/p2pd")
+
+
+async def try_until_success(coro_func, timeout=TIMEOUT_DURATION):
+    """
+    Keep calling ``coro_func`` until it returns a truthy value or the timeout expires.
+    ``coro_func`` must take no arguments, so bind any arguments beforehand (e.g. with ``functools.partial``).
+    """
+    t_start = time.monotonic()
+    while True:
+        result = await coro_func()
+        if result:
+            break
+        if (time.monotonic() - t_start) >= timeout:
+            # timeout
+            assert False, f"{coro_func} still failed after `{timeout}` seconds"
+        await asyncio.sleep(0.01)
+
+
+class Daemon:
+    control_maddr = None
+    proc_daemon = None
+    log_filename = ""
+    f_log = None
+    is_closed = None
+
+    def __init__(
+            self, control_maddr, enable_control, enable_connmgr, enable_dht, enable_pubsub
+    ):
+        self.control_maddr = control_maddr
+        self.enable_control = enable_control
+        self.enable_connmgr = enable_connmgr
+        self.enable_dht = enable_dht
+        self.enable_pubsub = enable_pubsub
+        self.is_closed = False
+        self._start_logging()
+        self._run()
+
+    def _start_logging(self):
+        name_control_maddr = str(self.control_maddr).replace("/", "_").replace(".", "_")
+        self.log_filename = f"/tmp/log_p2pd{name_control_maddr}.txt"
+        self.f_log = open(self.log_filename, "wb")
+
+    def _run(self):
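+        # the flags below are forwarded to the bundled p2pd (libp2p daemon) binary;
+        # only the features enabled for this daemon/client pair are switched on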
+        cmd_list = [P2PD_PATH, f"-listen={str(self.control_maddr)}"]
+        cmd_list += [f"-hostAddrs=/ip4/127.0.0.1/tcp/{find_open_port()}"]
+        if self.enable_connmgr:
+            cmd_list += ["-connManager=true", "-connLo=1", "-connHi=2", "-connGrace=0"]
+        if self.enable_dht:
+            cmd_list += ["-dht=true"]
+        if self.enable_pubsub:
+            cmd_list += ["-pubsub=true", "-pubsubRouter=gossipsub"]
+        self.proc_daemon = subprocess.Popen(
+            cmd_list, stdout=self.f_log, stderr=self.f_log, bufsize=0
+        )
+
+    async def wait_until_ready(self):
+        lines_head_pattern = (b"Control socket:", b"Peer ID:", b"Peer Addrs:")
+        lines_head_occurred = {line: False for line in lines_head_pattern}
+
+        with open(self.log_filename, "rb") as f_log_read:
+
+            async def read_from_daemon_and_check():
+                line = f_log_read.readline()
+                for head_pattern in lines_head_occurred:
+                    if line.startswith(head_pattern):
+                        lines_head_occurred[head_pattern] = True
+                return all(lines_head_occurred.values())
+
+            await try_until_success(read_from_daemon_and_check)
+
+        # sleep for a while in case the daemon is not ready yet even after emitting these lines
+        await asyncio.sleep(0.1)
+
+    def close(self):
+        if self.is_closed:
+            return
+        self.proc_daemon.terminate()
+        self.proc_daemon.wait()
+        self.f_log.close()
+        self.is_closed = True
+
+
+class DaemonTuple(NamedTuple):
+    daemon: Daemon
+    client: Client
+
+
+class ConnectionFailure(Exception):
+    pass
+
+
+@asynccontextmanager
+async def make_p2pd_pair_unix(
+        enable_control, enable_connmgr, enable_dht, enable_pubsub
+):
+    name = str(uuid.uuid4())[:8]
+    control_maddr = Multiaddr(f"/unix/tmp/test_p2pd_control_{name}.sock")
+    listen_maddr = Multiaddr(f"/unix/tmp/test_p2pd_listen_{name}.sock")
+    # Remove the unix socket files if they already exist
+    try:
+        os.unlink(control_maddr.value_for_protocol(protocols.P_UNIX))
+    except FileNotFoundError:
+        pass
+    try:
+        os.unlink(listen_maddr.value_for_protocol(protocols.P_UNIX))
+    except FileNotFoundError:
+        pass
+    async with _make_p2pd_pair(
+            control_maddr=control_maddr,
+            listen_maddr=listen_maddr,
+            enable_control=enable_control,
+            enable_connmgr=enable_connmgr,
+            enable_dht=enable_dht,
+            enable_pubsub=enable_pubsub,
+    ) as pair:
+        yield pair
+
+
+@asynccontextmanager
+async def make_p2pd_pair_ip4(enable_control, enable_connmgr, enable_dht, enable_pubsub):
+    control_maddr = Multiaddr(f"/ip4/127.0.0.1/tcp/{find_open_port()}")
+    listen_maddr = Multiaddr(f"/ip4/127.0.0.1/tcp/{find_open_port()}")
+    async with _make_p2pd_pair(
+            control_maddr=control_maddr,
+            listen_maddr=listen_maddr,
+            enable_control=enable_control,
+            enable_connmgr=enable_connmgr,
+            enable_dht=enable_dht,
+            enable_pubsub=enable_pubsub,
+    ) as pair:
+        yield pair
+
+
+@asynccontextmanager
+async def _make_p2pd_pair(
+        control_maddr,
+        listen_maddr,
+        enable_control,
+        enable_connmgr,
+        enable_dht,
+        enable_pubsub,
+):
+    p2pd = Daemon(
+        control_maddr=control_maddr,
+        enable_control=enable_control,
+        enable_connmgr=enable_connmgr,
+        enable_dht=enable_dht,
+        enable_pubsub=enable_pubsub,
+    )
+    # wait for the daemon to become ready
+    await p2pd.wait_until_ready()
+    client = Client(control_maddr=control_maddr, listen_maddr=listen_maddr)
+    try:
+        async with client.listen():
+            yield DaemonTuple(daemon=p2pd, client=client)
+    finally:
+        if not p2pd.is_closed:
+            p2pd.close()
+
+
+async def _check_connection(p2pd_tuple_0, p2pd_tuple_1):
+    peer_id_0, _ = await p2pd_tuple_0.identify()
+    peer_id_1, _ = await p2pd_tuple_1.identify()
+    peers_0 = [pinfo.peer_id for pinfo in await p2pd_tuple_0.list_peers()]
+    peers_1 = [pinfo.peer_id for pinfo in await p2pd_tuple_1.list_peers()]
+    return (peer_id_0 in peers_1) and (peer_id_1 in peers_0)
+
+
+async def connect_safe(p2pd_tuple_0, p2pd_tuple_1):
+    peer_id_1, maddrs_1 = await p2pd_tuple_1.identify()
+    await p2pd_tuple_0.connect(peer_id_1, maddrs_1)
+    await try_until_success(
+        functools.partial(
+            _check_connection, p2pd_tuple_0=p2pd_tuple_0, p2pd_tuple_1=p2pd_tuple_1
+        )
+    )