
Merge branch 'master' into decentralized_lr_scheduler

justheuristic 4 years ago
parent
commit
d1d1627578
47 changed files with 3525 additions and 492 deletions
  1. +0 -70  .circleci/config.yml
  2. +92 -0  .github/workflows/run-tests.yml
  3. +3 -0  .gitignore
  4. +13 -3  README.md
  5. +23 -13  benchmarks/benchmark_averaging.py
  6. +8 -8  benchmarks/benchmark_dht.py
  7. +5 -1  benchmarks/benchmark_tensor_compression.py
  8. +20 -16  benchmarks/benchmark_throughput.py
  9. +12 -0  codecov.yml
  10. +1 -1  docs/modules/client.rst
  11. +18 -12  examples/albert/README.md
  12. +11 -5  examples/albert/run_first_peer.py
  13. +7 -6  examples/albert/run_trainer.py
  14. +6 -4  examples/albert/tokenize_wikitext103.py
  15. +2 -1  hivemind/__init__.py
  16. +107 -55  hivemind/client/averaging/__init__.py
  17. +166 -189  hivemind/client/averaging/allreduce.py
  18. +1 -0  hivemind/client/averaging/load_balancing.py
  19. +1 -1  hivemind/client/averaging/matchmaking.py
  20. +224 -0  hivemind/client/averaging/partition.py
  21. +37 -3  hivemind/optim/collaborative.py
  22. +1 -0  hivemind/p2p/__init__.py
  23. +377 -0  hivemind/p2p/p2p_daemon.py
  24. +0 -0  hivemind/p2p/p2p_daemon_bindings/__init__.py
  25. +210 -0  hivemind/p2p/p2p_daemon_bindings/control.py
  26. +170 -0  hivemind/p2p/p2p_daemon_bindings/datastructures.py
  27. +85 -0  hivemind/p2p/p2p_daemon_bindings/p2pclient.py
  28. +73 -0  hivemind/p2p/p2p_daemon_bindings/utils.py
  29. +1 -1  hivemind/proto/averaging.proto
  30. +166 -0  hivemind/proto/p2pd.proto
  31. +1 -1  hivemind/server/runtime.py
  32. +49 -1  hivemind/utils/asyncio.py
  33. +14 -1  hivemind/utils/compression.py
  34. +5 -1  hivemind/utils/grpc.py
  35. +1 -1  hivemind/utils/threading.py
  36. +1 -0  requirements-dev.txt
  37. +2 -0  requirements.txt
  38. +80 -12  setup.py
  39. +217 -0  tests/test_allreduce.py
  40. +2 -2  tests/test_auth.py
  41. +76 -77  tests/test_averaging.py
  42. +2 -4  tests/test_dht_schema.py
  43. +1 -2  tests/test_dht_validation.py
  44. +440 -0  tests/test_p2p_daemon.py
  45. +559 -0  tests/test_p2p_daemon_bindings.py
  46. +41 -1  tests/test_util_modules.py
  47. +194 -0  tests/test_utils/__init__.py

+ 0 - 70
.circleci/config.yml

@@ -1,70 +0,0 @@
-version: 2.1
-
-jobs:
-  build-and-test-py37:
-    docker:
-      - image: circleci/python:3.7.10
-    steps:
-      - checkout
-      - restore_cache:
-          keys:
-            - py37-v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
-      - run: pip install -r requirements.txt
-      - run: pip install -r requirements-dev.txt
-      - save_cache:
-          key: py37-v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
-          paths:
-            - '~/.cache/pip'
-      - run:
-          command: pip install -e .
-          name: setup
-      - run:
-          command: pytest ./tests
-          name: tests
-  build-and-test-py38:
-    docker:
-      - image: circleci/python:3.8.1
-    steps:
-      - checkout
-      - restore_cache:
-          keys:
-            - py38-v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
-      - run: pip install -r requirements.txt
-      - run: pip install -r requirements-dev.txt
-      - save_cache:
-          key: py38-v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
-          paths:
-            - '~/.cache/pip'
-      - run:
-          command: pip install -e .
-          name: setup
-      - run:
-          command: pytest ./tests
-          name: tests
-  build-and-test-py39:
-    docker:
-      - image: circleci/python:3.9.1
-    steps:
-      - checkout
-      - restore_cache:
-          keys:
-            - py39-v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
-      - run: pip install -r requirements.txt
-      - run: pip install -r requirements-dev.txt
-      - save_cache:
-          key: py39-v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
-          paths:
-            - '~/.cache/pip'
-      - run:
-          command: pip install -e .
-          name: setup
-      - run:
-          command: pytest ./tests
-          name: tests
-
-workflows:
-  main:
-    jobs:
-      - build-and-test-py37
-      - build-and-test-py38
-      - build-and-test-py39

+ 92 - 0
.github/workflows/run-tests.yml

@@ -0,0 +1,92 @@
+name: Tests
+
+on: [ push ]
+
+
+jobs:
+  run_tests:
+
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: [ 3.7, 3.8, 3.9 ]
+    timeout-minutes: 10
+    steps:
+      - uses: actions/checkout@v2
+      - name: Set up Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Cache dependencies
+        uses: actions/cache@v2
+        with:
+          path: ~/.cache/pip
+          key: Key-v1-${{ matrix.python-version }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('requirements-dev.txt') }}
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -r requirements.txt
+          pip install -r requirements-dev.txt
+      - name: Build hivemind
+        run: |
+          pip install .
+      - name: Test
+        run: |
+          cd tests
+          pytest --durations=0 --durations-min=1.0
+
+  build_and_test_p2pd:
+    runs-on: ubuntu-latest
+    timeout-minutes: 10
+    steps:
+      - uses: actions/checkout@v2
+      - name: Set up Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: '3.8'
+      - name: Cache dependencies
+        uses: actions/cache@v2
+        with:
+          path: ~/.cache/pip
+          key: Key-v1-3.8-${{ hashFiles('requirements.txt') }}-${{ hashFiles('requirements-dev.txt') }}
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -r requirements.txt
+          pip install -r requirements-dev.txt
+      - name: Build hivemind
+        run: |
+          pip install . --global-option=build_py --global-option="--buildgo"
+      - name: Test
+        run: |
+          cd tests
+          pytest -k "p2p" 
+
+  codecov_in_develop_mode:
+
+    runs-on: ubuntu-latest
+    timeout-minutes: 10
+    steps:
+      - uses: actions/checkout@v2
+      - name: Set up Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: '3.8'
+      - name: Cache dependencies
+        uses: actions/cache@v2
+        with:
+          path: ~/.cache/pip
+          key: Key-v1-3.8-${{ hashFiles('requirements.txt') }}-${{ hashFiles('requirements-dev.txt') }}
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -r requirements.txt
+          pip install -r requirements-dev.txt
+      - name: Build hivemind
+        run: |
+          pip install -e .
+      - name: Test
+        run: |
+          pytest --cov=hivemind tests
+      - name: Upload coverage to Codecov
+        uses: codecov/codecov-action@v1

+ 3 - 0
.gitignore

@@ -78,3 +78,6 @@ debian/files
 
 # protobuf stuff
 hivemind/proto/*_pb2*
+
+# libp2p-daemon binary
+hivemind/hivemind_cli/p2pd

+ 13 - 3
README.md

@@ -1,6 +1,6 @@
 ## Hivemind: decentralized deep learning in PyTorch
 
-[![Build status](https://circleci.com/gh/learning-at-home/hivemind.svg?style=shield)](https://circleci.com/gh/learning-at-home/hivemind)
+[![CI status](https://github.com/learning-at-home/hivemind/actions/workflows/run-tests.yml/badge.svg?branch=master)](https://github.com/learning-at-home/hivemind/actions)
 [![Documentation Status](https://readthedocs.org/projects/learning-at-home/badge/?version=latest)](https://learning-at-home.readthedocs.io/en/latest/?badge=latest)
 [![Gitter](https://badges.gitter.im/learning-at-home/hivemind.svg)](https://gitter.im/learning-at-home/hivemind?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge)
 
@@ -76,8 +76,18 @@ our [guide](https://learning-at-home.readthedocs.io/en/latest/user/contributing.
 
 ## Citation
 
-If you found hivemind useful for your experiments, you can cite [the paper](https://arxiv.org/abs/2002.04013) that
-inspired it:
+If you found hivemind or its underlying algorithms useful for your experiments, please cite the following source:
+
+```
+@misc{hivemind,
+  author = {Learning@home team},
+  title = {{H}ivemind: a {L}ibrary for {D}ecentralized {D}eep {L}earning},
+  year = 2020,
+  howpublished = {\url{https://github.com/learning-at-home/hivemind}},
+}
+```
+
+Also, you can cite [the paper](https://arxiv.org/abs/2002.04013) that inspired the creation of this library:
 
 ```
 @inproceedings{ryabinin2020crowdsourced,

+ 23 - 13
benchmarks/benchmark_averaging.py

@@ -6,10 +6,13 @@ import argparse
 import torch
 
 import hivemind
-from hivemind.utils import LOCALHOST, increase_file_limit
+from hivemind.utils import LOCALHOST, increase_file_limit, get_logger
 from hivemind.proto import runtime_pb2
 
 
+logger = get_logger(__name__)
+
+
 def sample_tensors(hid_size, num_layers):
     tensors = []
     for i in range(num_layers):
@@ -38,8 +41,11 @@ def benchmark_averaging(num_peers: int, target_group_size: int, num_rounds: int,
     peer_tensors = [sample_tensors(hid_size, num_layers)
                     for _ in range(num_peers)]
     processes = {dht_root}
+    lock_stats = threading.Lock()
+    successful_steps = total_steps = 0
 
     def run_averager(index):
+        nonlocal successful_steps, total_steps, lock_stats
         dht = hivemind.DHT(listen_on=f'{LOCALHOST}:*',
                            initial_peers=[f"{LOCALHOST}:{dht_root.port}"],
                            start=True)
@@ -50,11 +56,17 @@ def benchmark_averaging(num_peers: int, target_group_size: int, num_rounds: int,
             averaging_expiration=averaging_expiration, request_timeout=request_timeout, start=True)
         processes.update({dht, averager})
 
-        print(end=f'<started {index}>\n', flush=True)
-        for _ in range(num_rounds):
-            success = averager.step(timeout=round_timeout)
-            print(end=('+' if success else '-'), flush=True)
-        print(end=f'<finished {index}>\n', flush=True)
+        logger.info(f'Averager {index}: started on endpoint {averager.endpoint}, group_bits: {averager.get_group_bits()}')
+        for step in range(num_rounds):
+            try:
+                success = averager.step(timeout=round_timeout) is not None
+            except:
+                success = False
+            with lock_stats:
+                successful_steps += int(success)
+                total_steps += 1
+            logger.info(f"Averager {index}: {'finished' if success else 'failed'} step {step}")
+        logger.info(f"Averager {index}: done.")
 
     threads = []
     for i in range(num_peers):
@@ -67,10 +79,8 @@ def benchmark_averaging(num_peers: int, target_group_size: int, num_rounds: int,
     for thread in threads:
         thread.join()
 
-    print(f"\ntest run took {time.time() - t:.3f} seconds")
-
-    for process in processes:
-        process.terminate()
+    logger.info(f"Benchmark finished in {time.time() - t:.3f} seconds.")
+    logger.info(f"Success rate: {successful_steps / total_steps} ({successful_steps} out of {total_steps} attempts)")
 
 
 if __name__ == "__main__":
@@ -80,9 +90,9 @@ if __name__ == "__main__":
     parser.add_argument('--num_rounds', type=int, default=5, required=False)
     parser.add_argument('--hid_size', type=int, default=256, required=False)
     parser.add_argument('--num_layers', type=int, default=3, required=False)
-    parser.add_argument('--averaging_expiration', type=float, default=15, required=False)
-    parser.add_argument('--round_timeout', type=float, default=30, required=False)
-    parser.add_argument('--request_timeout', type=float, default=3, required=False)
+    parser.add_argument('--averaging_expiration', type=float, default=5, required=False)
+    parser.add_argument('--round_timeout', type=float, default=15, required=False)
+    parser.add_argument('--request_timeout', type=float, default=1, required=False)
     parser.add_argument('--spawn_dtime', type=float, default=0.1, required=False)
    parser.add_argument('--increase_file_limit', action="store_true")
     args = vars(parser.parse_args())
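For reference, a minimal standalone sketch of the thread-safe success-rate accounting pattern introduced in this file, with the averager call replaced by a stub (all names here are illustrative, not part of this commit):

```python
import threading

def run_benchmark(num_workers: int = 4, num_rounds: int = 5):
    lock_stats = threading.Lock()
    successful_steps = total_steps = 0

    def run_worker(index: int):
        nonlocal successful_steps, total_steps
        for step in range(num_rounds):
            success = (index + step) % 3 != 0  # stand-in for averager.step(timeout=...)
            with lock_stats:  # counters are shared across worker threads
                successful_steps += int(success)
                total_steps += 1

    threads = [threading.Thread(target=run_worker, args=(i,)) for i in range(num_workers)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    print(f"Success rate: {successful_steps / total_steps} ({successful_steps} out of {total_steps} attempts)")

run_benchmark()
```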

+ 8 - 8
benchmarks/benchmark_dht.py

@@ -20,7 +20,7 @@ def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_b
                   wait_after_request: float, wait_before_read: float, wait_timeout: float, expiration: float):
     random.seed(random_seed)
 
-    print("Creating peers...")
+    logger.info("Creating peers...")
     peers = []
     for _ in trange(num_peers):
         neighbors = [f'0.0.0.0:{node.port}' for node in random.sample(peers, min(initial_peers, len(peers)))]
@@ -32,10 +32,10 @@ def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_b
 
     expert_uids = list(set(f"expert.{random.randint(0, 999)}.{random.randint(0, 999)}.{random.randint(0, 999)}"
                            for _ in range(num_experts)))
-    print(f"Sampled {len(expert_uids)} unique ids (after deduplication)")
+    logger.info(f"Sampled {len(expert_uids)} unique ids (after deduplication)")
     random.shuffle(expert_uids)
 
-    print(f"Storing experts to dht in batches of {expert_batch_size}...")
+    logger.info(f"Storing experts to dht in batches of {expert_batch_size}...")
     successful_stores = total_stores = total_store_time = 0
     benchmark_started = time.perf_counter()
     endpoints = []
@@ -52,8 +52,8 @@ def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_b
         successful_stores += sum(successes)
         time.sleep(wait_after_request)
 
-    print(f"Store success rate: {successful_stores / total_stores * 100:.1f}% ({successful_stores} / {total_stores})")
-    print(f"Mean store time: {total_store_time / total_stores:.5}, Total: {total_store_time:.5}")
+    logger.info(f"Store success rate: {successful_stores / total_stores * 100:.1f}% ({successful_stores} / {total_stores})")
+    logger.info(f"Mean store time: {total_store_time / total_stores:.5}, Total: {total_store_time:.5}")
     time.sleep(wait_before_read)
 
     if time.perf_counter() - benchmark_started > expiration:
@@ -74,11 +74,11 @@ def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_b
     if time.perf_counter() - benchmark_started > expiration:
         logger.warning("keys expired midway during get requests. If that isn't desired, increase expiration_time param")
 
-    print(f"Get success rate: {successful_gets / len(expert_uids) * 100:.1f} ({successful_gets} / {len(expert_uids)})")
-    print(f"Mean get time: {total_get_time / len(expert_uids):.5f}, Total: {total_get_time:.5f}")
+    logger.info(f"Get success rate: {successful_gets / len(expert_uids) * 100:.1f} ({successful_gets} / {len(expert_uids)})")
+    logger.info(f"Mean get time: {total_get_time / len(expert_uids):.5f}, Total: {total_get_time:.5f}")
 
     alive_peers = [peer.is_alive() for peer in peers]
-    print(f"Node survival rate: {len(alive_peers) / len(peers) * 100:.3f}%")
+    logger.info(f"Node survival rate: {len(alive_peers) / len(peers) * 100:.3f}%")
 
 
 if __name__ == "__main__":

+ 5 - 1
benchmarks/benchmark_tensor_compression.py

@@ -5,6 +5,10 @@ import torch
 
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.utils.logging import get_logger
+
+
+logger = get_logger(__name__)
 
 
 def benchmark_compression(tensor: torch.Tensor, compression_type: CompressionType) -> float:
@@ -29,4 +33,4 @@ if __name__ == "__main__":
         for i in range(args.num_iters):
             tm += benchmark_compression(X, compression_type)
         tm /= args.num_iters
-        print(f"Compression type: {name}, time: {tm}")
+        logger.info(f"Compression type: {name}, time: {tm}")

+ 20 - 16
benchmarks/benchmark_throughput.py

@@ -10,19 +10,23 @@ import hivemind
 from hivemind import find_open_port
 from hivemind.server import layers
 from hivemind.utils.threading import increase_file_limit
+from hivemind.utils.logging import get_logger
+
+
+logger = get_logger(__name__)
 
 
 def print_device_info(device=None):
     """Prints device stats. Code from https://stackoverflow.com/a/53374933/12891528"""
     device = torch.device(device or ('cuda' if torch.cuda.is_available() else 'cpu'))
-    print('Using device:', device)
+    logger.info(f'Using device: {device}')
 
     # Additional Info when using cuda
     if device.type == 'cuda':
-        print(torch.cuda.get_device_name(0))
-        print('Memory Usage:')
-        print('Allocated:', round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1), 'GB')
-        print('Cached:   ', round(torch.cuda.memory_cached(0) / 1024 ** 3, 1), 'GB')
+        logger.info(torch.cuda.get_device_name(0))
+        logger.info(f'Memory Usage:')
+        logger.info(f'Allocated: {round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1)} GB')
+        logger.info(f'Cached:   {round(torch.cuda.memory_cached(0) / 1024 ** 3, 1)} GB')
 
 
 def client_process(can_start, benchmarking_failed, port, num_experts, batch_size, hid_dim, num_batches, backprop=True):
@@ -111,25 +115,25 @@ def benchmark_throughput(num_experts=16, num_handlers=None, num_clients=128, num
         abs(timestamps[key2] - timestamps[key1]) if (key1 in timestamps and key2 in timestamps) else float('nan')
     total_examples = batch_size * num_clients * num_batches_per_client
 
-    print('\n' * 3)
-    print("Benchmark finished, status:" + ["Success", "Failure"][benchmarking_failed.is_set()])
-    print(f"Server parameters: num_experts={num_experts}, num_handlers={num_handlers}, max_batch_size={max_batch_size},"
+    logger.info("Benchmark finished, status:" + ["Success", "Failure"][benchmarking_failed.is_set()])
+    logger.info(f"Server parameters: num_experts={num_experts}, num_handlers={num_handlers}, max_batch_size={max_batch_size},"
           f" expert_cls={expert_cls}, hid_dim={hid_dim}, device={device}")
-    print(f"Client parameters: num_clients={num_clients}, num_batches_per_client={num_batches_per_client}, "
+    logger.info(f"Client parameters: num_clients={num_clients}, num_batches_per_client={num_batches_per_client}, "
           f"batch_size={batch_size}, backprop={backprop}")
 
-    print("Results: ")
-    print(f"\tServer startup took {time_between('began_launching_server', 'server_ready') :.3f} s. "
+    logger.info("Results: ")
+    logger.info(f"\tServer startup took {time_between('began_launching_server', 'server_ready') :.3f} s. "
          f"({time_between('began_launching_server', 'created_experts') :.3f} s. experts + "
          f"{time_between('created_experts', 'server_ready') :.3f} s. networking)")
-    print(f"\tProcessed {total_examples} examples in {time_between('server_ready', 'clients_finished') :.3f}")
-    print(f"\tThroughput for {'forward + backward' if backprop else 'forward'} passes: "
+    logger.info(f"\tProcessed {total_examples} examples in {time_between('server_ready', 'clients_finished') :.3f}")
+    logger.info(f"\tThroughput for {'forward + backward' if backprop else 'forward'} passes: "
          f"{total_examples / time_between('server_ready', 'clients_finished') :.3f} samples / s.")
-    print(f"\tBenchmarking took {time_between('started', 'server_shutdown_finished') :.3f} s.")
+    logger.info(f"\tBenchmarking took {time_between('started', 'server_shutdown_finished') :.3f} s.")
     if benchmarking_failed.is_set():
-        print("Note: benchmark code failed, timing/memory results only indicate time till failure!")
+        logger.info("Note: benchmark code failed, timing/memory results only indicate time till failure!")
     print_device_info(device)
-    print(flush=True)
+    sys.stdout.flush()
+    sys.stderr.flush()
 
     assert not benchmarking_failed.is_set()
 

+ 12 - 0
codecov.yml

@@ -0,0 +1,12 @@
+comment:
+  layout: "diff, files"
+  behavior: default
+  require_changes: true
+coverage:
+  status:
+    patch:
+      default:
+        informational: true
+    project:
+      default:
+        threshold: 1%

+ 1 - 1
docs/modules/client.rst

@@ -25,4 +25,4 @@
 .. autoclass:: DecentralizedAverager
    :members:
    :member-order: bysource
-   :exclude-members: get_tensors, get_tensors_async, update_tensors, rpc_join_group, rpc_aggregate_part
+   :exclude-members: get_tensors, get_tensors_async, update_tensors, rpc_join_group, rpc_aggregate_part, register_allreduce_group

+ 18 - 12
examples/albert/README.md

@@ -12,21 +12,24 @@ This tutorial will walk you through the steps to set up collaborative training w
 ## Running an experiment
 - Run the first DHT peer to welcome trainers and record training statistics (e.g. loss, performance):
    - In this example, we use [wandb.ai](https://wandb.ai/site) to plot training metrics; If you're unfamiliar with Weights & Biases, here's a [quickstart tutorial](https://docs.wandb.ai/quickstart).
-   - Run `python run_first_peer.py --listen_on '[::]:*' --experiment_prefix NAME_YOUR_EXPERIMENT --wandb_project WANDB_PROJECT_HERE`
+   - Run `python run_first_peer.py --dht_listen_on '[::]:*' --experiment_prefix NAME_YOUR_EXPERIMENT --wandb_project WANDB_PROJECT_HERE`
    - `NAME_YOUR_EXPERIMENT` must be a unique name of this training run, e.g. `my-first-albert`. It cannot contain `.` due to naming conventions.
    - `WANDB_PROJECT_HERE` is a name of wandb project used to track training metrics. Multiple experiments can have the same project name.
    - This peer will run a DHT node on a certain IP/port (`Running DHT root at ...`). You will need this address for next steps
 ```
-+ python ./run_first_peer.py --listen_on '[::]:31209' --experiment_prefix ysda_albert_v10 --wandb_project Demo-run
-[2021/04/19 02:30:06.051][WARN][root.<module>:36] No address specified. Attempting to infer address from DNS.
-[2021/04/19 02:30:06.088][INFO][root.<module>:44] Running DHT root at 18.217.13.97:31209
-wandb: Currently logged in as: ??? (use `wandb login --relogin` to force relogin)
-wandb: Tracking run with wandb version 0.10.26
-wandb: Syncing run wandering-sky-58
-wandb: ⭐ View project at https://wandb.ai/yhn112/Demo-run
-wandb: 🚀 View run at https://wandb.ai/yhn112/Demo-run/runs/38ygvt3n
-wandb: Run data is saved locally in /home/hivemind/examples/albert/wandb/run-20210419_023006-38ygvt3n
++ python run_first_peer.py --dht_listen_on '[::]:*' --experiment_prefix my-albert-v1 --wandb_project Demo-run
+[2021/06/17 16:26:35.931][WARN][root.<module>:140] No address specified. Attempting to infer address from DNS.
+[2021/06/17 16:26:36.083][INFO][root.<module>:149] Running DHT root at 193.106.95.184:38319
+wandb: Currently logged in as: XXX (use `wandb login --relogin` to force relogin)
+wandb: Tracking run with wandb version 0.10.32
+wandb: Syncing run dry-mountain-2
+wandb:  View project at https://wandb.ai/XXX/Demo-run
+wandb:  View run at https://wandb.ai/XXX/Demo-run/runs/YYY
+wandb: Run data is saved locally in /path/to/run/data
 wandb: Run `wandb offline` to turn off syncing.
+[2021/04/19 02:26:41.064][INFO][optim.collaborative.fetch_collaboration_state:323] Found no active peers: None
+[2021/04/19 02:26:44.068][INFO][optim.collaborative.fetch_collaboration_state:323] Found no active peers: None
+...
 [2021/04/19 02:37:37.246][INFO][root.<module>:74] 11.05164
 [2021/04/19 02:39:37.441][INFO][root.<module>:74] 11.03771
 [2021/04/19 02:40:37.541][INFO][root.<module>:74] 11.02886
@@ -37,7 +40,7 @@ wandb: Run `wandb offline` to turn off syncing.
   - if necessary, specify paths: `--dataset_path ./path/to/unpacked/data --tokenizer ./path/to/tokenizer/config` (see [default paths](https://github.com/learning-at-home/hivemind/blob/collaborative_albert_example/examples/albert/run_trainer.py#L63-L69) for reference)
   - run:
 ```shell
- CUDA_VISIBLE_DEVICES=0 HIVEMIND_THREADS=64 python ./hivemind/examples/albert/run_trainer.py \
+HIVEMIND_THREADS=64 python run_trainer.py \
  --experiment_prefix SAME_AS_IN_RUN_FIRST_PEER --initial_peers ONE_OR_MORE_PEERS --seed 42 \
  --logging_first_step --logging_steps 100  --output_dir ./outputs --overwrite_output_dir --logging_dir ./logs
 ```
@@ -45,11 +48,14 @@ Here, `ONE_OR_MORE_PEERS` stands for either your coordinator endpoint (e.g. `123
 
 As the peer begins training, it will periodically report training logs in the following form:
 ```
-{'loss': 4.3577, 'learning_rate': 0.001318944, 'epoch': 0.0}
 [...][INFO][...] Collaboration accumulated 448 samples from 17 peers; ETA 18.88 seconds (refresh in 15.73s.)
 [...][INFO][...] Collaboration accumulated 4096 samples from 16 peers; ETA 0.00 seconds (refresh in 0.50s.)
 [...][INFO][optim.collaborative.step:195] Averaged tensors successfully with 17 peers
 [...][INFO][optim.collaborative.step:211] Optimizer step: done!
+06/17/2021 18:58:23 - INFO - __main__ -   Step 0
+06/17/2021 18:58:23 - INFO - __main__ -   Your current contribution: 892 samples
+06/17/2021 18:58:23 - INFO - __main__ -   Local loss: 11.023
+
 ```
 
 __Sanity check:__ a healthy peer will periodically report `Averaged tensors successfully with [N > 1]` peers.

+ 11 - 5
examples/albert/run_first_peer.py

@@ -17,7 +17,6 @@ import hivemind
 from hivemind.utils.logging import get_logger
 import metrics_utils
 
-
 logger = get_logger(__name__)
 
 
@@ -163,6 +162,10 @@ if __name__ == '__main__':
                        for peer in metrics_dict]
             latest_step = max(item.step for item in metrics)
             if latest_step != current_step:
+                logger.debug(f"Got metrics from {len(metrics)} peers")
+
+                for i, metrics_for_peer in enumerate(metrics):
+                    logger.debug(f"{i} peer {metrics_for_peer}")
                 current_step = latest_step
                 alive_peers = 0
                 num_batches = 0
@@ -176,17 +179,20 @@ if __name__ == '__main__':
                     sum_perf += item.samples_per_second
                     num_samples += item.samples_accumulated
                     sum_mini_steps += item.mini_steps
+                current_loss = sum_loss / sum_mini_steps
+
                 if coordinator_args.wandb_project is not None:
                     wandb.log({
-                        "loss": sum_loss / sum_mini_steps,
+                        "loss": current_loss,
                         "alive peers": alive_peers,
                         "samples": num_samples,
-                        "performance": sum_perf
+                        "performance": sum_perf,
+                        "step": latest_step
                     })
                 if checkpoint_handler.is_time_to_save_state(current_step):
                     checkpoint_handler.save_state(current_step)
                     if checkpoint_handler.is_time_to_upload():
-                        checkpoint_handler.upload_checkpoint(sum_loss / sum_mini_steps)
-                logger.info(f"Step #{current_step}\tloss = {sum_loss / alive_peers:.5f}")
+                        checkpoint_handler.upload_checkpoint(current_loss)
+                logger.info(f"Step #{current_step}\tloss = {current_loss:.5f}")
         logger.debug("Peer is still alive...")
         time.sleep(coordinator_args.refresh_period)

+ 7 - 6
examples/albert/run_trainer.py

@@ -112,7 +112,7 @@ class CollaborativeCallback(transformers.TrainerCallback):
 
     def on_train_begin(self, args: TrainingArguments, state: transformers.TrainerState,
                        control: transformers.TrainerControl, **kwargs):
-        logger.warning('Loading state from peers')
+        logger.info('Loading state from peers')
         self.collaborative_optimizer.load_state_from_peers()
 
     def on_step_end(self, args: TrainingArguments, state: transformers.TrainerState,
@@ -139,14 +139,15 @@ class CollaborativeCallback(transformers.TrainerCallback):
                 logger.info(f"Step {self.collaborative_optimizer.local_step}")
                 logger.info(f"Your current contribution: {self.total_samples_processed} samples")
                 if self.steps:
-                    logger.info(f"Loss of your model: {self.loss/self.steps}")
+                    logger.info(f"Local loss: {self.loss / self.steps}")
 
                 self.loss = 0
                 self.steps = 0
-                self.dht.store(key=self.collaborative_optimizer.prefix + "_metrics",
-                               subkey=self.local_public_key, value=statistics.dict(),
-                               expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
-                               return_future=True)
+                if self.collaborative_optimizer.is_synchronized:
+                    self.dht.store(key=self.collaborative_optimizer.prefix + "_metrics",
+                                   subkey=self.local_public_key, value=statistics.dict(),
+                                   expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
+                                   return_future=True)
 
         self.samples = self.collaborative_optimizer.local_samples_accumulated
 

+ 6 - 4
examples/albert/tokenize_wikitext103.py

@@ -1,7 +1,6 @@
 #!/usr/bin/env python
 """ This script builds a pre-tokenized compressed representation of wikitext103 using huggingface/datasets """
 import random
-from collections import defaultdict
 from functools import partial
 from multiprocessing import cpu_count
 
@@ -10,6 +9,9 @@ from datasets import load_dataset
 from transformers import AlbertTokenizerFast
 
 
+COLUMN_NAMES = ('attention_mask', 'input_ids', 'sentence_order_label', 'special_tokens_mask', 'token_type_ids')
+
+
 def create_instances_from_document(tokenizer, document, max_seq_length):
     """Creates `TrainingInstance`s for a single document."""
     # We DON'T just concatenate all of the tokens from a document into a long
@@ -76,14 +78,14 @@ def tokenize_function(tokenizer, examples):
     # Remove empty texts
     texts = (text for text in examples["text"] if len(text) > 0 and not text.isspace())
 
-    new_examples = defaultdict(list)
+    new_examples = {col: [] for col in COLUMN_NAMES}
 
     for text in texts:
         instances = create_instances_from_document(tokenizer, text, max_seq_length=512)
         for instance in instances:
             for key, value in instance.items():
                 new_examples[key].append(value)
-
+    
     return new_examples
 
 
@@ -96,7 +98,7 @@ if __name__ == '__main__':
     tokenized_datasets = wikitext.map(
         partial(tokenize_function, tokenizer),
         batched=True,
-        num_proc=cpu_count(),
+        num_proc=8,
         remove_columns=["text"],
     )
 

+ 2 - 1
hivemind/__init__.py

@@ -1,7 +1,8 @@
 from hivemind.client import *
 from hivemind.dht import *
+from hivemind.p2p import *
 from hivemind.server import *
 from hivemind.utils import *
 from hivemind.optim import *
 
-__version__ = '0.9.8'
+__version__ = "0.9.10"

+ 107 - 55
hivemind/client/averaging/__init__.py

@@ -20,7 +20,8 @@ import torch
 import numpy as np
 
 from hivemind.dht import DHT, DHTID
-from hivemind.client.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID, split_into_parts
+from hivemind.client.averaging.partition import DEFAULT_PART_SIZE_BYTES
+from hivemind.client.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID, AveragingMode
 from hivemind.client.averaging.load_balancing import load_balance_peers
 from hivemind.client.averaging.matchmaking import Matchmaking, MatchmakingException
 from hivemind.client.averaging.group_info import GroupInfo
@@ -34,9 +35,8 @@ from hivemind.utils import Endpoint, Port, MPFuture, get_logger, TensorDescripto
 
 # flavour types
 StreamCallToLeader = grpc.aio.UnaryStreamCall[averaging_pb2.JoinRequest, averaging_pb2.MessageFromLeader]
-DataForGather = Any
+GatheredData = Any
 logger = get_logger(__name__)
-DEFAULT_CHUNK_SIZE_BYTES = 2 ** 16
 
 
 class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragingServicer):
@@ -61,7 +61,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
       towards the (estimated) average by this coefficient. By default, local parameters are set equal to average.
     :param request_timeout: when looking for group, wait for a response from leader for at most this many seconds.
     :note: request_timeout must be smaller than averaging_expiration to avoid potential deadlocks.
-    :param chunk_size_bytes: tensors for AllReduce will be divided into chunks of this size (to improve gRPC throughput)
+    :param part_size_bytes: tensors for AllReduce are processed in parts of up to this size (after compression)
     :param throughput: if specified, this value represents the network bandwidth available to averager.
           By default, the averager is assumed to have the average bandwidth of his group.
           If throughput == 0, averager will rely on its groupmates to do all the averaging.
@@ -71,6 +71,10 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
     :param channel_options: options for grpc.aio.insecure_channel, e.g. [('grpc.enable_retries', 0)]
           see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
     :param kwargs: extra parameters forwarded to grpc.aio.server
+    :param auxiliary: if this flag is specified, averager.step will only assist others without sending
+          local tensors for averaging
+    :param allow_state_sharing: if set to True, other peers can download this peer's state. Can be overwritten
+      with averager.allow_state_sharing = True / False
 
     Example:
 
@@ -90,10 +94,11 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
     def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: DHT, *, start: bool,
                  prefix: str, target_group_size: int, min_group_size: int = 2, initial_group_bits: Optional[str] = None,
-                 averaging_expiration: float = 15, request_timeout: float = 3, chunk_size_bytes: int = 2 ** 16,
-                 allreduce_timeout: Optional[float] = None, averaging_alpha: float = 1.0,
+                 averaging_expiration: float = 15, request_timeout: float = 3, averaging_alpha: float = 1.0,
+                 part_size_bytes: int = DEFAULT_PART_SIZE_BYTES, allreduce_timeout: Optional[float] = None,
                  compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE,
                  throughput: Optional[float] = None, min_vector_size: int = 0,
+                 auxiliary: bool = False, allow_state_sharing: Optional[bool] = None,
                  listen: bool = True, listen_on: Endpoint = '0.0.0.0:*', daemon: bool = True,
                  channel_options: Optional[Sequence[Tuple[str, Any]]] = None, **kwargs):
         assert '.' not in prefix, "group prefix must be a string without trailing '.'"
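A hedged usage sketch of the updated constructor, based only on the signature shown in this hunk (the DHT setup and argument values are illustrative, not part of the commit):

```python
import torch
import hivemind
from hivemind.client.averaging import DecentralizedAverager

dht = hivemind.DHT(listen_on='0.0.0.0:*', start=True)

# regular peer: uses the new part_size_bytes knob (replacing chunk_size_bytes) and shares its state
averager = DecentralizedAverager(
    averaged_tensors=[torch.zeros(1024)], dht=dht, start=True,
    prefix='my_experiment', target_group_size=4,
    part_size_bytes=2 ** 19, allow_state_sharing=True)

# auxiliary peer: assists the group without contributing local tensors; must keep listen=True
aux_averager = DecentralizedAverager(
    averaged_tensors=[torch.zeros(1024)], dht=dht, start=True,
    prefix='my_experiment', target_group_size=4, auxiliary=True)
```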
@@ -102,10 +107,18 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         if not is_power_of_two(target_group_size):
             logger.warning("It is recommended to set target_group_size to a power of 2.")
         assert initial_group_bits is None or all(bit in '01' for bit in initial_group_bits)
+        assert listen or not auxiliary, "auxiliary peers must accept incoming connections"
 
         super().__init__()
         self.dht = dht
         self.listen, self.listen_on, self.kwargs = listen, listen_on, kwargs
+        if not self.listen:
+            self.mode = AveragingMode.CLIENT
+        elif auxiliary:
+            self.mode = AveragingMode.AUX
+        else:
+            self.mode = AveragingMode.NODE
+
         self.channel_options = channel_options
         self.daemon = daemon
 
@@ -122,13 +135,17 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         self.matchmaking_kwargs = dict(
             prefix=prefix, initial_group_bits=initial_group_bits, target_group_size=target_group_size,
             min_group_size=min_group_size, averaging_expiration=averaging_expiration, request_timeout=request_timeout)
-        self.allreduce_kwargs = dict(compression_type=compression_type, chunk_size_bytes=chunk_size_bytes,
+        self.allreduce_kwargs = dict(compression_type=compression_type, part_size_bytes=part_size_bytes,
                                      min_vector_size=min_vector_size)
         self._averaging_alpha, self._allreduce_timeout = averaging_alpha, allreduce_timeout
         self._running_groups: Dict[GroupID, AllReduceRunner] = {}  # one or more assembled groups that run all-reduce
 
         self._pipe, self.pipe = mp.Pipe(duplex=True)  # a control pipe used to communicate with a background process
         self._port = mp.Value(ctypes.c_uint32, 0)  # assigned when averager starts, accessible via self.port
+
+        self._allow_state_sharing = mp.Value(ctypes.c_bool, 0)
+        self.allow_state_sharing = (listen and not auxiliary) if allow_state_sharing is None else allow_state_sharing
+
         self._averager_endpoint: Optional[Endpoint] = None
         if not self.listen:
             self._averager_endpoint = f'client::{uuid.uuid4()}'
@@ -146,6 +163,18 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
     def port(self) -> Optional[Port]:
         return self._port.value if self._port.value != 0 else None
 
+    @property
+    def allow_state_sharing(self) -> bool:
+        """ if set to True, other peers can download this peer's state """
+        return bool(self._allow_state_sharing.value)
+
+    @allow_state_sharing.setter
+    def allow_state_sharing(self, value: bool):
+        if value is True and not self.listen:
+            logger.warning("Cannot allow state sharing: averager in client mode (listen=False) cannot share its state.")
+        else:
+            self._allow_state_sharing.value = value
+
     @property
     def endpoint(self) -> Optional[Endpoint]:
         if self.listen and self._averager_endpoint is None:
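A short sketch of toggling the new allow_state_sharing flag at runtime (continuing the hypothetical averager from the sketch above):

```python
# temporarily stop serving state downloads, e.g. while applying a local update
averager.allow_state_sharing = False

# ... mutate local tensors here ...

# re-enable sharing once the local state is consistent again
averager.allow_state_sharing = True

# note: on a client-mode peer (listen=False) the setter only logs a warning and keeps sharing disabled
```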
@@ -222,8 +251,9 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         if self._parent_pid != os.getpid() or self.is_alive():
             self.shutdown()
 
-    def step(self, gather: Optional[DataForGather] = None, weight: float = 1.0, timeout: Optional[float] = None,
-             allow_retries: bool = True, wait: bool = True) -> Union[Optional[Dict[Endpoint, DataForGather]], MPFuture]:
+    def step(self, gather: Optional[GatheredData] = None, weight: Optional[float] = None,
+             timeout: Optional[float] = None, allow_retries: bool = True, wait: bool = True
+             ) -> Union[Optional[Dict[Endpoint, GatheredData]], MPFuture]:
         """
         Set up the averager to look for a group and run one round of averaging, return True on success, False on failure
 
@@ -236,7 +266,12 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         :param wait: if True (default), return when finished. Otherwise return MPFuture and run in background.
         :returns: on success, update averaged_tensors and return group info; on failure, return None
         """
-        assert isinstance(weight, (int, float)) and weight > 0, f"Expected a positive int/float, got {type(weight)}"
+        if self.mode == AveragingMode.AUX and weight is not None:
+            logger.warning("Averager is running in auxiliary mode, weight is unused.")
+        if weight is None:
+            weight = float(self.mode != AveragingMode.AUX)
+        assert isinstance(weight, (int, float)) and weight >= 0, f"Expected a positive int/float, got {type(weight)}"
+
         future, _future = MPFuture.make_pair()
         gather_binary = self.serializer.dumps(gather)  # serialize here to avoid loading modules in the averager process
         self.pipe.send(('_step', [], dict(future=_future, gather_binary=gather_binary, weight=weight,
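A hedged sketch of calling step() under the new weight semantics (the gathered metadata and timeout values are illustrative):

```python
# regular peer: weight defaults to float(mode != AveragingMode.AUX), i.e. 1.0
gathered = averager.step(gather={'samples': 256}, timeout=60)
if gathered is not None:
    print(f"Averaged with {len(gathered)} peers")

# auxiliary peer: weight is ignored (passing one only triggers a warning), so just call step()
aux_averager.step(timeout=60)
```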
@@ -245,28 +280,21 @@
 
     async def _step(self, *, future: MPFuture, gather_binary: bytes, weight: float,
                     allow_retries: bool, timeout: Optional[float]):
-        loop = asyncio.get_event_loop()
         start_time = get_dht_time()
-        group_id = None
 
         try:
             while not future.done():
                 try:
                     self._pending_group_assembled.clear()
-                    data_for_gather = self.serializer.dumps([weight, self._throughput, self.listen, gather_binary])
+                    data_for_gather = self.serializer.dumps([weight, self._throughput, self.mode.value, gather_binary]) 
                     group_info = await self._matchmaking.look_for_group(timeout=timeout,
                                                                         data_for_gather=data_for_gather)
                     if group_info is None:
                         raise AllreduceException("Averaging step failed: could not find a group.")
-                    group_id = group_info.group_id
-                    allreduce_runner = await self._make_allreduce_runner(group_info, **self.allreduce_kwargs)
-                    self._running_groups[group_id] = allreduce_runner
-                    self._pending_group_assembled.set()
-                    await asyncio.wait_for(allreduce_runner.run(), self._allreduce_timeout)
-                    await loop.run_in_executor(None, self.update_tensors, allreduce_runner)
 
-                    # averaging is finished, exit the loop
-                    future.set_result(allreduce_runner.gathered)
+                    future.set_result(await asyncio.wait_for(
+                        self._run_allreduce(group_info, **self.allreduce_kwargs), self._allreduce_timeout))
+                    # averaging is finished, loop will now exit
 
                 except (AllreduceException, MatchmakingException, AssertionError, StopAsyncIteration, InternalError,
                         asyncio.CancelledError, asyncio.InvalidStateError, grpc.RpcError, grpc.aio.AioRpcError) as e:
@@ -277,10 +305,6 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                     else:
                         logger.warning(f"Averager caught {repr(e)}, retrying")

-                finally:
-                    _ = self._running_groups.pop(group_id, None)
-                    self._pending_group_assembled.set()
-
         except BaseException as e:
             if not future.done():
                 future.set_exception(e)
@@ -290,35 +314,51 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                 future.set_exception(RuntimeError("Internal sanity check failed: averager.step left future pending."
                                                   " Please report this to hivemind issues."))

-    async def _make_allreduce_runner(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> AllReduceRunner:
-        """ Use a group description found by Matchmaking to form AllreduceRunner """
+    async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
+        """ Run All-Reduce in a given group and update tensors in place, return gathered metadata """
         try:
-            weights, throughputs, modes, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
+            weights, throughputs, mode_ids, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
             user_gathered = dict(zip(group_info.endpoints, map(self.serializer.loads, user_gathered)))
+            modes = tuple(map(AveragingMode, mode_ids))

-            # compute optimal part sizes from peer throughputs
-            incoming_throughputs = [thr if listen else 0.0 for thr, listen in zip(throughputs, modes)]
-            part_sizes = await asyncio.get_event_loop().run_in_executor(
+            # compute optimal part sizes from peer throughputs; TODO: replace with proper load balancing
+            incoming_throughputs = [thr if mode != AveragingMode.CLIENT else 0.0
+                                    for thr, mode in zip(throughputs, modes)]
+            peer_fractions = await asyncio.get_event_loop().run_in_executor(
                 None, load_balance_peers, self.total_size, incoming_throughputs, min_vector_size)
-            async with self.get_tensors_async() as averaged_tensors:
-                return AllReduceRunner(group_id=group_info.group_id, tensors=averaged_tensors, endpoint=self.endpoint,
-                                       ordered_group_endpoints=group_info.endpoints, part_sizes=part_sizes,
-                                       weights=weights, gathered=user_gathered, return_deltas=True, **kwargs)
-        except Exception as e:
-            raise MatchmakingException(f"Unable to create allreduce runner ({e}), group_info: {group_info}")

-    def update_tensors(self, allreduce_group: AllReduceRunner):
-        """
-        a private (extendable) method that applies changes from a finished allreduce to local tensors
-        """
-        assert allreduce_group.return_deltas and allreduce_group.future.done()
-        averaging_deltas = allreduce_group.future.result()
+            async with self.get_tensors_async() as local_tensors:
+                allreduce = AllReduceRunner(
+                    group_id=group_info.group_id, tensors=local_tensors, endpoint=self.endpoint,
+                    ordered_group_endpoints=group_info.endpoints, peer_fractions=peer_fractions, weights=weights,
+                    gathered=user_gathered, modes=modes, **kwargs)

-        with torch.no_grad(), self.get_tensors() as local_tensors:
-            assert len(local_tensors) == len(self._averaged_tensors)
-            for tensor, update in zip(local_tensors, averaging_deltas):
-                tensor.add_(update, alpha=self._averaging_alpha)
-        self.last_updated = get_dht_time()
+                with self.register_allreduce_group(group_info.group_id, allreduce):
+
+                    # actually run all-reduce
+                    averaging_outputs = [output async for output in allreduce]
+
+                    if modes[group_info.endpoints.index(self.endpoint)] != AveragingMode.AUX:
+                        assert len(local_tensors) == len(self._averaged_tensors)
+                        for tensor, update in zip(local_tensors, averaging_outputs):
+                            tensor.add_(update, alpha=self._averaging_alpha)
+                        self.last_updated = get_dht_time()
+
+                return allreduce.gathered
+        except BaseException as e:
+            logger.exception(e)
+            raise MatchmakingException(f"Unable to run All-Reduce: {e}")
+
+    @contextlib.contextmanager
+    def register_allreduce_group(self, group_id: GroupID, allreduce: AllReduceRunner):
+        """ registers a given all-reduce runner to listen for incoming connections """
+        try:
+            self._running_groups[group_id] = allreduce
+            self._pending_group_assembled.set()
+            yield
+        finally:
+            self._running_groups.pop(group_id, None)
+            self._pending_group_assembled.set()

     @contextlib.contextmanager
     def get_tensors(self) -> Sequence[torch.Tensor]:
@@ -366,10 +406,11 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
     async def _declare_for_download_periodically(self):
         download_key = f'{self._matchmaking.group_key_manager.prefix}.all_averagers'
         while True:
-            asyncio.create_task(asyncio.wait_for(self.dht.store(
-                download_key, subkey=self.endpoint, value=self.last_updated,
-                expiration_time=get_dht_time() + self._matchmaking.averaging_expiration, return_future=True),
-                timeout=self._matchmaking.averaging_expiration))
+            if self.allow_state_sharing:
+                asyncio.create_task(asyncio.wait_for(self.dht.store(
+                    download_key, subkey=self.endpoint, value=self.last_updated,
+                    expiration_time=get_dht_time() + self._matchmaking.averaging_expiration, return_future=True),
+                    timeout=self._matchmaking.averaging_expiration))
             await asyncio.sleep(self._matchmaking.averaging_expiration)

     async def rpc_download_state(self, request: averaging_pb2.DownloadRequest, context: grpc.ServicerContext
@@ -381,11 +422,12 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
          - serialized_metadata is a small serialized bytestring meant to store scalars and hyperparameters
          - tensors is a sequence of pytorch tensors that represent model parameters or optimizer statistics
         """
-        chunk_size_bytes = self.matchmaking_kwargs.get('chunk_size_bytes', DEFAULT_CHUNK_SIZE_BYTES)
+        if not self.allow_state_sharing:
+            return  # deny request and direct peer to the next prospective averager
         metadata, tensors = await self._get_current_state_from_host_process()

         for tensor in tensors:
-            for part in split_for_streaming(serialize_torch_tensor(tensor), chunk_size_bytes):
+            for part in split_for_streaming(serialize_torch_tensor(tensor)):
                 if metadata is not None:
                     yield averaging_pb2.DownloadData(tensor_part=part, metadata=metadata)
                     metadata = None
@@ -452,6 +494,11 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                             current_tensor_parts.append(message.tensor_part)
                         if current_tensor_parts:
                             tensors.append(deserialize_torch_tensor(combine_from_streaming(current_tensor_parts)))
+
+                        if not metadata:
+                            logger.debug(f"Peer {peer} did not send its state.")
+                            continue
+
                         logger.info(f"Finished downloading state from {peer}")
                         future.set_result((metadata, tensors))
                         self.last_updated = get_dht_time()
@@ -512,7 +559,12 @@ def _background_thread_fetch_current_state(serializer: SerializerBase, pipe: mp.
     :param get_current_state_ref: a WeakMethod wrapped around DecentralizedAverager.get_current_state (instance-bound)
     """
     while True:
-        trigger, future = pipe.recv()
+        try:
+            trigger, future = pipe.recv()
+        except BaseException as e:
+            logger.debug(f"Averager background thread finished: {repr(e)}")
+            break
+            
         if trigger == '_SHUTDOWN':
             break


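For illustration, a minimal sketch (not part of this patch) of how the reworked weight handling in DecentralizedAverager.step is meant to be used; the constructor arguments shown are abbreviated and assumed:

    averager = DecentralizedAverager(averaged_tensors, dht, prefix='my_experiment', start=True)  # hypothetical setup
    averager.step()            # weight defaults to 1.0 for a regular (non-auxiliary) peer
    averager.step(weight=0.5)  # down-weight this peer's contribution to the group average
    # a peer running in AveragingMode.AUX contributes weight 0.0; any user-supplied weight is ignored with a warning
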
+ 166 - 189
hivemind/client/averaging/allreduce.py

@@ -1,252 +1,229 @@
 import asyncio
-from typing import Sequence, Set, Dict, Tuple, Iterable, AsyncIterator, Any
+from typing import Sequence, Dict, Tuple, AsyncIterator, Any, Optional
+from enum import Enum

 import grpc
 import torch

-from hivemind.utils import Endpoint, get_logger, ChannelCache, anext
-from hivemind.utils import split_for_streaming, combine_from_streaming
+from hivemind.client.averaging.partition import TensorPartContainer, TensorPartReducer, AllreduceException
+from hivemind.utils import Endpoint, get_logger, ChannelCache
+from hivemind.utils.asyncio import anext, achain, aiter, aenumerate, amap_in_executor
 from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
-from hivemind.proto import averaging_pb2_grpc, runtime_pb2, averaging_pb2
+from hivemind.proto import averaging_pb2_grpc, averaging_pb2

 # flavour types
 GroupID = bytes
 logger = get_logger(__name__)


-class AllReduceProtocol:
+class AveragingMode(Enum):
+    NODE = 0
+    CLIENT = 1
+    AUX = 2
+
+
+class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
     """
     An internal class that runs butterfly AllReduce in a predefined group of averagers

+    :note: this class returns **differences** between averaged and local tensors in order to improve numerical stability
+    :param group_id: unique identifier of this specific all-reduce run
+    :param tensors: local tensors that should be averaged with groupmates
     :param tensors: local tensors that should be averaged with groupmates
     :param endpoint: your endpoint, must be included in ordered_group_endpoints
     :param ordered_group_endpoints: group endpoints ordered s.t. i-th endpoint is responsible for averaging i-th part
-    :param part_sizes: for each peer, a number of vector elements that this peer is responsible for averaging
-    :param return_deltas: if True, returns the element-wise differences (averaged_tensors - original_tensors)
-           default (False) - return averaged_tensors by themselves
+    :param peer_fractions: for each peer, a target fraction of vector elements that this peer should average
+      (the actual number of values by peer will be nearly proportional, but there are no exact guarantees)
+    :param modes: AveragingMode for each peer in ordered_group_endpoints (normal, client-only or auxiliary)
+    :param weights: scaling coefficients for weighted averaging (default = equal weights for all non-aux peers)
+    :param gathered: additional user-defined data collected from this group
+    :param kwargs: additional parameters (e.g. part_size_bytes) will be passed to TensorPartContainer
     """

-    def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint,
-                 ordered_group_endpoints: Sequence[Endpoint], part_sizes: Tuple[int, ...], return_deltas: bool = False):
+    def __init__(
+            self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint,
+            ordered_group_endpoints: Sequence[Endpoint], peer_fractions: Tuple[float, ...],
+            weights: Optional[Sequence[float]] = None, modes: Optional[Sequence[AveragingMode]] = None,
+            gathered: Optional[Dict[Endpoint, Any]] = None, **kwargs):
         assert endpoint in ordered_group_endpoints, "endpoint is not a part of the group"
-        self.group_id, self.endpoint = group_id, endpoint
-        self.ordered_group_endpoints, self.part_sizes = ordered_group_endpoints, part_sizes
-        self.client_mode_endpoints = {endpoint for endpoint, part_size in zip(self.ordered_group_endpoints, part_sizes)
-                                      if part_size == 0}
-        self.local_tensor_parts = dict(zip(ordered_group_endpoints, split_into_parts(tensors, part_sizes)))
-        self.tensor_shapes = tuple(tensor.shape for tensor in tensors)
-        self.return_deltas = return_deltas
-
-        self.accumulator = torch.zeros_like(self.local_tensor_parts[self.endpoint])
-        self.denominator = 0.0  # number of peers added to accumulator or sum of their weights
-        self.accumulated_from: Set[Endpoint] = set()  # peers that we have accumulated our part from
-        self.averaged_part: asyncio.Future[torch.Tensor] = asyncio.Future()  # will be set to [accumulator / group size]
-        self.averaged_tensor_parts: Dict[Endpoint, torch.Tensor] = {}  # averaged chunks from all peers will be put here
-        self.future: asyncio.Future[Sequence[torch.Tensor]] = asyncio.Future()  # final result or exception
-        for endpoint in self.client_mode_endpoints:
-            self.averaged_tensor_parts[endpoint] = torch.tensor([])
+        modes = modes or tuple(AveragingMode.CLIENT if frac == 0 else AveragingMode.NODE for frac in peer_fractions)
+        weights = weights or tuple(int(mode != AveragingMode.AUX) for mode in modes)
+        assert len(weights) == len(modes) == len(ordered_group_endpoints), "lists have inconsistent length"
+        assert any(mode != AveragingMode.CLIENT for mode in modes), "cannot run allreduce without reducers"
+        for mode, frac, weight in zip(modes, peer_fractions, weights):
+            assert mode != AveragingMode.CLIENT or frac == 0, "client-mode peer should have zero all-reduce fraction"
+            assert mode != AveragingMode.AUX or weight == 0, "auxiliary peer should have zero averaging weight"
+
+        self.group_id, self.endpoint, self.ordered_group_endpoints = group_id, endpoint, ordered_group_endpoints
+        self.modes, self.peer_fractions, self.gathered = modes, peer_fractions, gathered
+
+        self._future = asyncio.Future()
+
+        self.sender_endpoints, self.sender_weights = [], []
+        for endpoint, weight, mode in zip(self.ordered_group_endpoints, weights, modes):
+            if mode != AveragingMode.AUX:
+                self.sender_endpoints.append(endpoint)
+                self.sender_weights.append(weight)
+
+        endpoint_index = self.ordered_group_endpoints.index(self.endpoint)
+        self.tensor_part_container = TensorPartContainer(tensors, peer_fractions, **kwargs)
+        self.parts_for_local_averaging = self.tensor_part_container.get_raw_input_parts(endpoint_index)
+        self.tensor_part_reducer = TensorPartReducer(tuple(part.shape for part in self.parts_for_local_averaging),
+                                                     len(self.sender_endpoints), self.sender_weights)

     def __repr__(self):
         return f"{self.__class__.__name__}({self.endpoint}, group_size={self.group_size})"

-    def __await__(self):
-        return self.future.__await__()
+    def __aiter__(self):
+        return self.run()

     def __contains__(self, endpoint: Endpoint):
-        return endpoint in self.local_tensor_parts
+        return endpoint in self.ordered_group_endpoints

     @property
     def group_size(self):
         return len(self.ordered_group_endpoints)

-    async def accumulate_part(self, source: Endpoint, remote_part: torch.Tensor, weight: float = 1.0) -> torch.Tensor:
-        """ Add vector part to accumulator, wait for all other vectors to be added, then return the average part """
-        assert not self.averaged_part.done(), f"already finished averaging part: {self.averaged_part}"
-        assert not self.future.done(), f"already finished allreduce: {self.future}"
-        assert source in self.local_tensor_parts, "unexpected source, not a part of current group"
-        assert source not in self.accumulated_from, "duplicate source, already received that part"
-        assert not self.endpoint in self.client_mode_endpoints, f"{self.endpoint} is in client mode"
-        assert isinstance(weight, (int, float)) and weight > 0, "averaging weights must be a non-negative int/float"
-        logger.debug(f"{self} - accumulating tensor part from {source}")
-
-        self.accumulator.add_(remote_part, alpha=weight)
-        self.denominator += weight
-        self.accumulated_from.add(source)
-
-        assert len(self.accumulated_from) <= self.group_size
-        if len(self.accumulated_from) == len(self.local_tensor_parts):
-            average_result = self.accumulator.div_(self.denominator)
-            self.register_averaged_part(self.endpoint, average_result)
-            self.averaged_part.set_result(average_result)
-
-        return await self.averaged_part
-
-    def register_averaged_part(self, source: Endpoint, averaged_part: torch.Tensor):
-        assert not self.future.done(), f"already finished allreduce: {self.future}"
-        assert source in self.local_tensor_parts, "the provider of averaged part is not from my group"
-        assert source not in self.averaged_tensor_parts, "already registered the average from this peer"
-        assert averaged_part.shape == self.local_tensor_parts[source].shape, "averaged part shape mismatch"
-        assert averaged_part.dtype == self.local_tensor_parts[source].dtype, "averaged part dtype mismatch"
-        logger.debug(f"{self} - receiving averaged tensor part from {source}")
-        self.averaged_tensor_parts[source] = averaged_part
-        if len(self.averaged_tensor_parts) == len(self.local_tensor_parts):
-            ordered_averaged_parts = [self.averaged_tensor_parts[endpoint] for endpoint in self.ordered_group_endpoints]
-            outputs = restore_from_parts(ordered_averaged_parts, self.tensor_shapes)
-
-            if self.return_deltas:
-                local_parts = [self.local_tensor_parts[peer] for peer in self.ordered_group_endpoints]
-                with torch.no_grad():
-                    original_tensors = restore_from_parts(local_parts, self.tensor_shapes)
-                    for averaged_tensor, original_tensor in zip(outputs, original_tensors):
-                        averaged_tensor -= original_tensor
-
-            self.future.set_result(outputs)
-
-    def cancel(self) -> bool:
-        if not self.future.done():
-            logger.debug(f"{self} - cancelled")
-            self.future.cancel()
-            if not self.averaged_part.done():
-                self.averaged_part.cancel()
-            return True
-        else:
-            logger.debug(f"{self} - failed to cancel, allreduce is already finished: {self.future}")
-            return False
-
-    def set_exception(self, exception: Exception) -> bool:
-        if not self.future.done():
-            logger.debug(f"{self} - {exception}")
-            self.future.set_exception(exception)
-            if not self.averaged_part.done():
-                self.averaged_part.cancel()
-            return True
-        else:
-            logger.debug(f"{self} - failed to set {exception}, allreduce already finished: {self.future}")
-            return False
-
-
-class AllReduceRunner(AllReduceProtocol, averaging_pb2_grpc.DecentralizedAveragingServicer):
-    """
-    A class that implements ButterflyAllReduceProtocol on top of a gRPC servicer
-    """
-
-    def __init__(self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint,
-                 ordered_group_endpoints: Sequence[Endpoint], compression_type: runtime_pb2.CompressionType,
-                 chunk_size_bytes: int, part_sizes: Tuple[int, ...], weights: Tuple[float, ...],
-                 gathered: Dict[Endpoint, Any], return_deltas: bool = False):
-        super().__init__(group_id=group_id, tensors=tensors, endpoint=endpoint, part_sizes=part_sizes,
-                         ordered_group_endpoints=ordered_group_endpoints, return_deltas=return_deltas)
-        self.compression_type, self.chunk_size_bytes, self.gathered = compression_type, chunk_size_bytes, gathered
-        self.peer_weights = dict(zip(self.ordered_group_endpoints, weights))
-
     def _get_peer_stub(self, peer: Endpoint) -> averaging_pb2_grpc.DecentralizedAveragingStub:
     def _get_peer_stub(self, peer: Endpoint) -> averaging_pb2_grpc.DecentralizedAveragingStub:
         return ChannelCache.get_stub(peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
         return ChannelCache.get_stub(peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
 
 
-    async def _communicate_with_peer(self, peer_endpoint: Endpoint, local_part: torch.Tensor) -> torch.Tensor:
-        """ Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors """
-        if peer_endpoint == self.endpoint:
-            return await self.accumulate_part(self.endpoint, local_part, weight=self.peer_weights[self.endpoint])
-        serialized_tensor_part = serialize_torch_tensor(local_part, self.compression_type, allow_inplace=False)
-        chunks = split_for_streaming(serialized_tensor_part, self.chunk_size_bytes)
-
-        stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
-        await stream.write(averaging_pb2.AveragingData(code=averaging_pb2.PART_FOR_AVERAGING, group_id=self.group_id,
-                                                       endpoint=self.endpoint, tensor_part=next(chunks)))
-        for chunk in chunks:
-            await stream.write(averaging_pb2.AveragingData(tensor_part=chunk))
-        await stream.done_writing()
-
-        outputs: Sequence[averaging_pb2.AveragingData] = [message async for message in stream]
-        code = outputs[0].code if outputs else averaging_pb2.INTERNAL_ERROR
-        if code != averaging_pb2.AVERAGED_PART:
-            raise AllreduceException(f"peer {peer_endpoint} returned {averaging_pb2.MessageCode.Name(code)}"
-                                     f" instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)},"
-                                     f" allreduce failed")
-
+    async def run(self) -> AsyncIterator[torch.Tensor]:
+        """ Run all-reduce, return differences between averaged and original tensors as they are computed """
+        pending_tasks = set()
         try:
-            averaged_part = local_part + deserialize_torch_tensor(combine_from_streaming(
-                [message.tensor_part for message in outputs]))
-        except RuntimeError as e:
-            raise AllreduceException(f"Could not deserialize averaged part from {peer_endpoint}: {e}")
+            if len(self.sender_endpoints) == 0:
+                logger.debug(f"{self} - finished all-reduce early: all peers are auxiliaries ({self.modes})")
+                self.finalize()

-        self.register_averaged_part(peer_endpoint, averaged_part)
-        return averaged_part
+            elif self.endpoint in self.sender_endpoints:
+                for endpoint, parts in zip(self.ordered_group_endpoints, self.tensor_part_container.num_parts_by_peer):
+                    if parts != 0:
+                        pending_tasks.add(asyncio.create_task(self._communicate_with_peer(endpoint)))

-    async def _send_error_to_peer(self, peer_endpoint: Endpoint, code: averaging_pb2.MessageCode):
-        stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
-        await stream.write(averaging_pb2.AveragingData(group_id=self.group_id, endpoint=self.endpoint, code=code))
-        await stream.done_writing()
+                async for averaged_tensor_delta in self.tensor_part_container.iterate_output_tensors():
+                    yield averaged_tensor_delta  # delta = averaged_tensor - original_tensor
+                self.finalize()
+
+            else:  # auxiliary peer
+                await self.tensor_part_reducer.finished.wait()
+                self.finalize()

-    async def run(self) -> Sequence[torch.Tensor]:
-        """
-        send allreduce requests to all peers and collect results, return the averaged tensor (or deltas)
-        """
-        try:
-            await asyncio.gather(self, *(self._communicate_with_peer(peer, self.local_tensor_parts[peer])
-                                         for i, peer in enumerate(self.ordered_group_endpoints)
-                                         if peer not in self.client_mode_endpoints))
-            return await self
         except BaseException as e:
+            self.finalize(exception=e)
+            for task in pending_tasks:
+                task.cancel()
             code = averaging_pb2.CANCELLED if isinstance(e, asyncio.CancelledError) else averaging_pb2.INTERNAL_ERROR
             logger.debug(f"{self} - notifying peers about {averaging_pb2.MessageCode.Name(code)}")
-            self.set_exception(e)
-            for peer_endpoint, part_size in zip(self.ordered_group_endpoints, self.part_sizes):
-                if peer_endpoint != self.endpoint and part_size > 0:
+            for peer_endpoint, mode in zip(self.ordered_group_endpoints, self.modes):
+                if peer_endpoint != self.endpoint and mode != AveragingMode.CLIENT:
                     asyncio.create_task(self._send_error_to_peer(peer_endpoint, code))
             raise

-    async def accumulate_part_streaming(self, source: Endpoint, stream_messages: Iterable[runtime_pb2.Tensor]
-                                        ) -> Iterable[runtime_pb2.Tensor]:
-        """ accumulate_part using streams of serialized tensors. Used to prevent duplicate work in serialization """
-        try:
-            tensor_part = deserialize_torch_tensor(combine_from_streaming(stream_messages))
-        except RuntimeError as e:
-            raise AllreduceException(f"Could not deserialize tensor part from {source} for streaming {e}")
+    async def _communicate_with_peer(self, peer_endpoint: Endpoint):
+        """ Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors """
+        peer_index = self.ordered_group_endpoints.index(peer_endpoint)
+        if peer_endpoint == self.endpoint:
+            sender_index = self.sender_endpoints.index(peer_endpoint)
+            for part_index, tensor_part in enumerate(self.parts_for_local_averaging):
+                averaged_part = await self.tensor_part_reducer.accumulate_part(sender_index, part_index, tensor_part)
+                self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part - tensor_part)

-        averaged_part = await self.accumulate_part(source, tensor_part, weight=self.peer_weights[source])
-        serialized_tensor = serialize_torch_tensor(averaged_part - tensor_part, self.compression_type, allow_inplace=False)
-        stream_chunks = tuple(split_for_streaming(serialized_tensor, self.chunk_size_bytes))
-        return stream_chunks
+        else:
+            loop = asyncio.get_event_loop()
+            stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
+            write_task = asyncio.create_task(self._write_to_peer(stream, peer_index))
+
+            try:
+                code = None
+                async for part_index, msg in aenumerate(stream):
+                    if code is None:
+                        code = msg.code
+                    averaged_part_delta = await loop.run_in_executor(None, deserialize_torch_tensor, msg.tensor_part)
+                    self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part_delta)
+                await write_task
+
+                if code != averaging_pb2.AVERAGED_PART:
+                    raise AllreduceException(f"peer {peer_endpoint} returned {averaging_pb2.MessageCode.Name(code)} "
+                                             f"instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)}"
+                                             f", allreduce failed")
+            finally:
+                if not write_task.done():
+                    write_task.cancel()
+
+    async def _write_to_peer(self, stream: grpc.aio.StreamStreamCall, peer_index: int):
+        parts_aiter = self.tensor_part_container.iterate_input_parts_for(peer_index)
+        first_part = await anext(parts_aiter)
+        await stream.write(averaging_pb2.AveragingData(code=averaging_pb2.PART_FOR_AVERAGING,
+                                                       group_id=self.group_id, endpoint=self.endpoint,
+                                                       tensor_part=first_part))
+        async for part in parts_aiter:
+            await stream.write(averaging_pb2.AveragingData(tensor_part=part))
+
+        await stream.done_writing()

     async def rpc_aggregate_part(self, stream: AsyncIterator[averaging_pb2.AveragingData], context: grpc.ServicerContext
                                  ) -> AsyncIterator[averaging_pb2.AveragingData]:
-        """ a groupmate sends us a part of his tensor; we should average it with other peers and return the delta"""
+        """ a peer sends us a part of his tensor; we should average it with other peers and return the difference """
         request: averaging_pb2.AveragingData = await anext(stream)
-
-        if request.group_id != self.group_id:
-            yield averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
+        reason_to_reject = self._check_reasons_to_reject(request)
+        if reason_to_reject:
+            yield reason_to_reject
+            return

         elif request.code == averaging_pb2.PART_FOR_AVERAGING:
             try:
-                tensor_chunks = (request.tensor_part, *[msg.tensor_part async for msg in stream])
-                averaged_chunks = iter(await self.accumulate_part_streaming(request.endpoint, tensor_chunks))
-                yield averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=next(averaged_chunks))
-                for averaged_chunk in averaged_chunks:
-                    yield averaging_pb2.AveragingData(tensor_part=averaged_chunk)
+                sender_index = self.sender_endpoints.index(request.endpoint)
+                async for msg in self._accumulate_parts_streaming(achain(aiter(request), stream), sender_index):
+                    yield msg

             except Exception as e:
-                self.set_exception(e)
+                self.finalize(exception=e)
                 yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
         else:
             error_code = averaging_pb2.MessageCode.Name(request.code)
             logger.debug(f"{self} - peer {request.endpoint} sent {error_code}, allreduce cannot continue")
-            self.set_exception(AllreduceException(f"peer {request.endpoint} sent {error_code}."))
+            self.finalize(exception=AllreduceException(f"peer {request.endpoint} sent {error_code}."))
             yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)

+    def _check_reasons_to_reject(self, request: averaging_pb2.AveragingData) -> Optional[averaging_pb2.AveragingData]:
+        if request.group_id != self.group_id:
+            return averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
+        elif self._future.cancelled():
+            return averaging_pb2.AveragingData(code=averaging_pb2.CANCELLED)
+        elif self._future.done():
+            return averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
+
+    async def _accumulate_parts_streaming(self, stream: AsyncIterator[averaging_pb2.AveragingData], sender_index: int):
+        loop = asyncio.get_event_loop()
+        async for part_index, (tensor_part, part_compression) in aenumerate(
+                amap_in_executor(lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg.tensor_part.compression), stream,
+                                 max_prefetch=self.tensor_part_container.prefetch)):
+            averaged_part = await self.tensor_part_reducer.accumulate_part(sender_index, part_index, tensor_part)
+
+            serialized_delta = await loop.run_in_executor(
+                None, lambda: serialize_torch_tensor(averaged_part - tensor_part, part_compression))
+            yield averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=serialized_delta)

-def split_into_parts(tensors: Sequence[torch.Tensor], part_sizes: Tuple[int]) -> Tuple[torch.Tensor, ...]:
-    """ combines averaged_tensors into one tensor and splits them into equal chunks of size group_size """
-    flat_tensor = torch.cat(tuple(map(torch.Tensor.flatten, tensors)))
-    return torch.split_with_sizes(flat_tensor, part_sizes, dim=0)
-
-
-def restore_from_parts(chunks: Sequence[torch.Tensor], shapes: Sequence[torch.Size]) -> Tuple[torch.Tensor, ...]:
-    """ restores the original tensor shapes from chunks obtained by split_into_chunks """
-    flat_tensor = torch.cat(tuple(chunks))
-    result_sizes = tuple(map(torch.Size.numel, shapes))
-    flat_original_tensors = torch.split_with_sizes(flat_tensor, result_sizes)
-    return tuple(map(torch.Tensor.reshape, flat_original_tensors, shapes))
-
+    async def _send_error_to_peer(self, peer_endpoint: Endpoint, code: averaging_pb2.MessageCode):
+        stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
+        await stream.write(averaging_pb2.AveragingData(group_id=self.group_id, endpoint=self.endpoint, code=code))
+        await stream.done_writing()

-class AllreduceException(Exception):
-    """ A special exception that is raised when allreduce can't continue normally (e.g. disbanded/bad request/etc) """
+    def finalize(self, *, cancel: bool = False, exception: Optional[BaseException] = None):
+        assert not cancel or not exception, "finalize accepts either exception or cancel, but not both"
+        if not self._future.done():
+            if cancel:
+                logger.debug(f"{self} - cancelled")
+                self._future.cancel()
+            elif exception:
+                logger.debug(f"{self} - caught {exception}")
+                self._future.set_exception(exception)
+            else:
+                logger.debug(f"{self} - finished")
+                self._future.set_result(None)
+            self.tensor_part_container.finalize()
+            self.tensor_part_reducer.finalize()
+            return True
+        else:
+            logger.debug(f"{self} - could not finish: allreduce is already finished: {self._future}")
+            return False
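
To make the new control flow concrete, here is a rough sketch (not part of this patch) of how a caller such as DecentralizedAverager._run_allreduce drives AllReduceRunner; the group metadata, tensors and peer fractions are assumed to come from matchmaking and load balancing:

    import torch
    from hivemind.client.averaging.allreduce import AllReduceRunner

    async def run_one_round(group_info, local_tensors, my_endpoint, peer_fractions, modes, weights, user_gathered):
        # the runner must also be registered to serve rpc_aggregate_part, as register_allreduce_group does above
        allreduce = AllReduceRunner(
            group_id=group_info.group_id, tensors=local_tensors, endpoint=my_endpoint,
            ordered_group_endpoints=group_info.endpoints, peer_fractions=peer_fractions,
            modes=modes, weights=weights, gathered=user_gathered)

        deltas = [delta async for delta in allreduce]   # delta_i = averaged_tensor_i - local_tensor_i
        with torch.no_grad():
            for tensor, delta in zip(local_tensors, deltas):
                tensor.add_(delta)                      # apply the group average in place
        return allreduce.gathered                       # user-defined metadata collected from groupmates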

+ 1 - 0
hivemind/client/averaging/load_balancing.py

@@ -28,6 +28,7 @@ def load_balance_peers(vector_size, throughputs: Sequence[Optional[float]], min_
         assert not all(throughput == 0 for throughput in throughputs), "Must have at least one nonzero throughput"
         scores = np.asarray([1.0 if throughput is None else 0.0 for throughput in throughputs])

+    # TODO(jheuristic): we no longer need Hagenbach-Bischoff with the new AllReduceRunner
     return tuple(hagenbach_bishoff(vector_size, scores))



+ 1 - 1
hivemind/client/averaging/matchmaking.py

@@ -391,7 +391,7 @@ class PotentialLeaders:
             if maybe_next_leader is None or self.max_assured_time <= entry.expiration_time <= self.search_end_time:
                 self.update_triggered.set()

-            if maybe_next_leader is None or entry.expiration_time >= self.declared_expiration_time:
+            if maybe_next_leader is None or (entry.expiration_time, maybe_next_leader) > (self.declared_expiration_time, self.endpoint):
                 await asyncio.wait({self.update_finished.wait(), self.declared_expiration.wait()},
                                    return_when=asyncio.FIRST_COMPLETED)
                 self.declared_expiration.clear()
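
The one-line change above replaces a plain expiration-time comparison with a lexicographic tuple comparison, so that two peers declaring the same expiration time are ordered deterministically by endpoint. A quick illustration of the tie-breaking (values are made up):

    >>> declared = (1625000000.0, 'node_a:1337')    # (self.declared_expiration_time, self.endpoint)
    >>> candidate = (1625000000.0, 'node_b:1337')   # (entry.expiration_time, maybe_next_leader)
    >>> candidate > declared                        # equal expiration times, the endpoint breaks the tie
    True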

+ 224 - 0
hivemind/client/averaging/partition.py

@@ -0,0 +1,224 @@
+"""
+Auxiliary data structures for AllReduceRunner
+"""
+import asyncio
+from typing import Sequence, AsyncIterable, Tuple, Optional, TypeVar, Union, AsyncIterator
+from collections import deque
+
+import torch
+import numpy as np
+
+from hivemind.proto.runtime_pb2 import CompressionType, Tensor
+from hivemind.utils.compression import serialize_torch_tensor, get_nbytes_per_value
+from hivemind.utils.asyncio import amap_in_executor
+
+
+T = TypeVar('T')
+DEFAULT_PART_SIZE_BYTES = 2 ** 20
+
+
+class TensorPartContainer:
+    """
+    Auxiliary data structure for averaging, responsible for splitting tensors into parts and reassembling them.
+    The class is designed to avoid excessive memory allocation and run all heavy computation in background
+    :param tensors: local tensors to be split and aggregated
+    :param peer_fractions: for each peer, a target fraction of vector elements that this peer should average
+    :param compression_type: optionally compress tensors with this compression algorithm before sending them to peers
+    :param part_size_bytes: greedily split tensors into parts of up to this many bytes (after compression)
+    :param prefetch: when compressing, pre-compute this many compressed tensors in background
+    """
+
+    def __init__(self, tensors: Sequence[torch.Tensor], peer_fractions: Sequence[float],
+                 compression_type: Union[type(CompressionType), Sequence[type(CompressionType)]] = CompressionType.NONE,
+                 part_size_bytes: int = 2 ** 20, prefetch: int = 1):
+                 part_size_bytes: int = DEFAULT_PART_SIZE_BYTES, prefetch: int = 1):
+            compression_type = [compression_type] * len(tensors)
+        assert len(compression_type) == len(tensors), "compression types do not match the number of tensors"
+        self.local_tensors, self.peer_fractions, self.group_size = tensors, peer_fractions, len(peer_fractions)
+        self.compression_type, self.part_size_bytes, self.prefetch = compression_type, part_size_bytes, prefetch
+        self.total_size = sum(tensor.numel() for tensor in tensors)
+        self._input_parts_by_peer = [deque() for _ in range(self.group_size)]
+        self._output_parts_by_peer = [deque() for _ in range(self.group_size)]
+        self._inputs_consumed_by_peer = [False for _ in range(self.group_size)]
+        self._output_part_available = [asyncio.Event() for _ in range(self.group_size)]
+        self._outputs_registered_by_peer = [0 for _ in range(self.group_size)]
+        self._outputs_consumed = False
+        self.finished = asyncio.Event()
+        self.num_parts_by_tensor = []
+
+        # split tensor parts in proportion to target_size_by_peer
+        current_length = 0
+        current_peer_index = 0
+        pivots = (np.cumsum(peer_fractions) / np.sum(peer_fractions) * self.total_size).astype(np.int64)
+        pivots[-1] = self.total_size
+
+        for tensor, tensor_compression in zip(self.local_tensors, compression_type):
+            part_size_values = int(part_size_bytes / get_nbytes_per_value(tensor.dtype, tensor_compression))
+            tensor_parts = tensor.detach().view(-1).split(part_size_values)
+            self.num_parts_by_tensor.append(len(tensor_parts))
+            for part in tensor_parts:
+                if current_length + len(part) > pivots[current_peer_index]:
+                    # switch to next peer; if a part lands between parts of two or
+                    # more peers, assign that part to the peer with highest intersection
+                    prev_peer_index = current_peer_index
+                    peer_intersections = [pivots[current_peer_index] - current_length]
+                    while current_length + len(part) > pivots[current_peer_index]:
+                        current_peer_index += 1
+                        current_peer_part_end = min(current_length + len(part), pivots[current_peer_index])
+                        peer_intersections.append(current_peer_part_end - pivots[current_peer_index - 1])
+                    assigned_peer_index = prev_peer_index + np.argmax(peer_intersections)
+                    self._input_parts_by_peer[assigned_peer_index].append((part, tensor_compression))
+                else:
+                    self._input_parts_by_peer[current_peer_index].append((part, tensor_compression))
+                current_length += len(part)
+
+        assert current_length == self.total_size
+        self.num_parts_by_peer = tuple(len(parts) for parts in self._input_parts_by_peer)
+
+    @torch.no_grad()
+    def get_raw_input_parts(self, peer_index: int) -> Tuple[torch.Tensor, ...]:
+        """ get non-serialized tensor parts for a peer at a given index """
+        assert not self._inputs_consumed_by_peer[peer_index], "input parts of a given peer are already deallocated."
+        self._inputs_consumed_by_peer[peer_index] = True
+        input_parts = tuple(part for part, compression in self._input_parts_by_peer[peer_index])
+        self._input_parts_by_peer[peer_index].clear()
+        return input_parts
+
+    @torch.no_grad()
+    async def iterate_input_parts_for(self, peer_index: int) -> AsyncIterator[Tensor]:
+        """ iterate serialized tensor parts for a peer at a given index. Run serialization in background. """
+        assert not self._inputs_consumed_by_peer[peer_index], "input parts of a given peer are already deallocated."
+        self._inputs_consumed_by_peer[peer_index] = True
+
+        async def _aiterate_parts():
+            for _ in range(self.num_parts_by_peer[peer_index]):
+                yield self._input_parts_by_peer[peer_index].popleft()
+
+        async for serialized_part in amap_in_executor(lambda x_and_compr: serialize_torch_tensor(*x_and_compr),
+                                                      _aiterate_parts(), max_prefetch=self.prefetch):
+            yield serialized_part
+
+    def register_processed_part(self, peer_index: int, part_index: int, part: torch.Tensor):
+        """
+        register next-in-line part of results received from a given peer for use in iterate_output_tensors
+        depending on the algorithm, processed part is an average, difference from average or another aggregation
+        """
+        if part_index != self._outputs_registered_by_peer[peer_index]:
+            raise ValueError(f"Could not register part #{part_index} from peer #{peer_index}, "
+                             f" expected part index: {self._outputs_registered_by_peer[peer_index]}")
+        self._output_parts_by_peer[peer_index].append(part)
+        self._outputs_registered_by_peer[peer_index] += 1
+        self._output_part_available[peer_index].set()
+
+    async def iterate_output_tensors(self) -> AsyncIterable[torch.Tensor]:
+        """ iterate over the outputs of averaging (whether they are average, delta or other aggregation result) """
+        assert not self._outputs_consumed, "output tensors are already iterated and no longer available."
+        self._outputs_consumed = True
+        peer_index = num_parts_processed = 0
+        for tensor_index in range(len(self.local_tensors)):
+            tensor_parts = []
+            while len(tensor_parts) < self.num_parts_by_tensor[tensor_index]:
+                if num_parts_processed >= self.num_parts_by_peer[peer_index]:
+                    num_parts_processed = 0
+                    peer_index += 1
+                    continue
+                if not self._output_parts_by_peer[peer_index]:
+                    self._output_part_available[peer_index].clear()
+                    await self._output_part_available[peer_index].wait()
+                    if self.finished.is_set():
+                        raise AllreduceException("All-reduce was terminated during iteration.")
+
+                tensor_parts.append(self._output_parts_by_peer[peer_index].popleft())
+                num_parts_processed += 1
+            tensor = torch.cat(tensor_parts)
+            del tensor_parts
+            yield tensor.reshape(self.local_tensors[tensor_index].shape)
+
+    def __del__(self):
+        self.finalize()
+
+    def finalize(self):
+        """ terminate all iterators, delete intermediate data """
+        if not self.finished.is_set():
+            for peer_index in range(self.group_size):
+                self._inputs_consumed_by_peer[peer_index] = True
+                self._input_parts_by_peer[peer_index].clear()
+                self._output_parts_by_peer[peer_index].clear()
+                self._output_part_available[peer_index].set()
+            self._outputs_consumed = True
+            self.finished.set()
+
+
+class TensorPartReducer:
+    """
+    Auxiliary data structure responsible for running asynchronous all-reduce
+    :param part_shapes: a sequence of shapes of torch tensors that will be averaged by this reducer
+    :param num_senders: total number of peers in a given all-reduce group that will send gradients
+    :param weights: relative importance of each sender, used for weighted average (default = equal weights)
+    :note: even if local peer is not sending data, local parts will be used for shape information
+    """
+
+    def __init__(self, part_shapes: Sequence[torch.Size], num_senders: int,
+                 weights: Optional[Sequence[float]] = None):
+        self.part_shapes, self.num_senders, self.num_parts = part_shapes, num_senders, len(part_shapes)
+        self.weights = tuple(weights or (1 for _ in range(num_senders)))
+        assert len(self.weights) == self.num_senders, "The number of weights is inconsistent with num_senders"
+        assert all(isinstance(weight, (int, float)) for weight in self.weights)
+        self.current_part_index = -1  # index in local_parts of the part that should be loaded next
+        self.current_part_accumulated_from = 0  # number of peers from which the current part was accumulated
+        self.accumulator = None  # this will contain the sum of current tensor part from group peers
+        self.denominator = 0.0  # total weight accumulated from all peers for current part
+        self.current_part_future = asyncio.Future()
+        self.finished = asyncio.Event()
+        self.reset_accumulators()
+
+    def reset_accumulators(self):
+        """ (re)create averaging buffers for the next part in line, prepopulate with local tensor part """
+        assert self.current_part_accumulated_from == self.num_senders or self.current_part_index == -1
+        if self.current_part_index >= self.num_parts - 1:
+            self.finalize()
+            return
+
+        self.current_part_index += 1
+        self.current_part_accumulated_from = 0
+        self.current_part_future = asyncio.Future()
+        self.accumulator = torch.zeros(self.part_shapes[self.current_part_index])
+        self.denominator = 0.0
+
+    async def accumulate_part(self, sender_index: int, part_index: int, tensor_part: torch.Tensor) -> torch.Tensor:
+        """ Add vector part to accumulator, wait for all other vectors to be added, then return the average part """
+        assert 0 <= sender_index < self.num_senders, "invalid sender index"
+        assert 0 <= part_index < self.num_parts, "invalid part index"
+
+        while part_index > self.current_part_index:
+            # wait for previous parts to finish processing ...
+            await asyncio.wait({self.current_part_future, self.finished.wait()}, return_when=asyncio.FIRST_COMPLETED)
+            if self.finished.is_set():
+                raise AllreduceException(f"attempted to aggregate part in a finalized {self.__class__.__name__}")
+        assert part_index == self.current_part_index
+
+        current_part_future = self.current_part_future
+
+        self.accumulator.add_(tensor_part, alpha=self.weights[sender_index])
+        self.denominator += self.weights[sender_index]
+        self.current_part_accumulated_from += 1
+
+        assert self.current_part_accumulated_from <= self.num_senders
+        if self.current_part_accumulated_from == self.num_senders:
+            current_part_future.set_result(self.accumulator.div_(self.denominator))
+            self.reset_accumulators()
+        return await current_part_future
+
+    def finalize(self):
+        if not self.finished.is_set():
+            if hasattr(self, 'current_part_future'):
+                self.current_part_future.cancel()
+                del self.accumulator
+            self.finished.set()
+
+    def __del__(self):
+        self.finalize()
+
+
+class AllreduceException(Exception):
+    """ A special exception that is raised when allreduce can't continue normally (e.g. disconnected/protocol error) """
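
Since partition.py is entirely new, a self-contained sketch (not part of the patch) of the intended round trip may help; it exercises a single-peer group locally, mirroring how AllReduceRunner combines the two classes. All names below are illustrative:

    import asyncio
    import torch
    from hivemind.client.averaging.partition import TensorPartContainer, TensorPartReducer

    async def local_round_trip():
        tensors = [torch.randn(16, 16), torch.randn(100)]
        container = TensorPartContainer(tensors, peer_fractions=[1.0])      # a single peer owns every part
        parts = container.get_raw_input_parts(0)                            # raw (non-serialized) parts for peer 0
        reducer = TensorPartReducer([part.shape for part in parts], num_senders=1)

        for part_index, part in enumerate(parts):
            averaged = await reducer.accumulate_part(0, part_index, part)   # with one sender, the average is the part itself
            container.register_processed_part(0, part_index, averaged - part)  # register the delta, as AllReduceRunner does

        async for delta in container.iterate_output_tensors():              # deltas come back reshaped to the original tensors
            print(delta.shape, float(delta.abs().max()))                    # ~0 for a single-peer group

    asyncio.run(local_round_trip())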

+ 37 - 3
hivemind/optim/collaborative.py

@@ -191,7 +191,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         with self.lock_local_progress:
             self.local_samples_accumulated += batch_size
             self.local_steps_accumulated += 1
-            self.performance_ema.update(num_processed=self.batch_size_per_step)
+            self.performance_ema.update(num_processed=batch_size)
             self.should_report_progress.set()

         if not self.collaboration_state.ready_for_step:
@@ -232,9 +232,43 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             self.collaboration_state_updated.set()
             self.update_scheduler()

-            logger.log(self.status_loglevel, f"Optimizer step: done!")
+        logger.log(self.status_loglevel, f"Optimizer step: done!")

-            return group_info
+        return group_info
+
+    def step_aux(self, **kwargs):
+        """
+        Find and assist other peers in averaging without sending local gradients.
+
+        :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
+        """
+
+        if not self.collaboration_state.ready_for_step:
+            return
+
+        logger.log(self.status_loglevel,
+                   f"Beginning global optimizer step {self.collaboration_state.optimizer_step}")
+        self.collaboration_state = self.fetch_collaboration_state()
+        self.collaboration_state_updated.set()
+
+        with self.lock_collaboration_state:
+            # note: auxiliary peers keep no local gradient accumulators; they only help run the averaging round
+            current_step, group_info = self.averager.local_step, None
+            try:
+                group_info = self.averager.step(timeout=self.averaging_timeout, **kwargs)
+                if group_info:
+                    logger.log(self.status_loglevel,
+                               f"Averaged tensors successfully with {len(group_info)} peers")
+            except BaseException as e:
+                logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}.")
+
+            self.collaboration_state.register_step(current_step + 1)
+            self.averager.local_step = current_step + 1
+            self.collaboration_state_updated.set()
+
+        logger.log(self.status_loglevel, f"Optimizer step: done!")
+
+        return group_info
 
     def _grad_buffers(self) -> Iterator[torch.Tensor]:
         """ pytorch-internal gradient buffers """

+ 1 - 0
hivemind/p2p/__init__.py

@@ -0,0 +1 @@
+from hivemind.p2p.p2p_daemon import P2P

+ 377 - 0
hivemind/p2p/p2p_daemon.py

@@ -0,0 +1,377 @@
+import asyncio
+from copy import deepcopy
+from dataclasses import dataclass
+from importlib.resources import path
+from subprocess import Popen
+from typing import List, Optional
+
+import google.protobuf
+from multiaddr import Multiaddr
+
+import hivemind.hivemind_cli as cli
+import hivemind.p2p.p2p_daemon_bindings.p2pclient as p2pclient
+from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, StreamInfo
+from hivemind.proto import p2pd_pb2
+from hivemind.utils import MSGPackSerializer
+from hivemind.utils.logging import get_logger
+from hivemind.utils.networking import find_open_port
+
+logger = get_logger(__name__)
+
+
+P2PD_FILENAME = 'p2pd'
+NUM_RETRIES = 3
+RETRY_DELAY = 0.4
+
+
+class P2PInterruptedError(Exception):
+    pass
+
+
+@dataclass(frozen=False)
+class P2PContext(object):
+    id: str
+    port: int
+    handle_name: str
+    peer_id: PeerID = None
+    peer_addr: Multiaddr = None
+
+
+class P2P:
+    """
+    Forks a child process and executes p2pd command with given arguments.
+    Can be used for peer to peer communication and procedure calls.
+    Sends SIGKILL to the child in destructor.
+    """
+
+    HEADER_LEN = 8
+    BYTEORDER = 'big'
+    PB_HEADER_LEN = 1
+    RESULT_MESSAGE = b'\x00'
+    ERROR_MESSAGE = b'\x01'
+    DHT_MODE_MAPPING = {
+        'dht': {'dht': 1},
+        'dht_server': {'dhtServer': 1},
+        'dht_client': {'dhtClient': 1},
+    }
+    FORCE_REACHABILITY_MAPPING = {
+        'public': {'forceReachabilityPublic': 1},
+        'private': {'forceReachabilityPrivate': 1},
+    }
+
+    def __init__(self):
+        self._child = None
+        self._alive = False
+        self._listen_task = None
+        self._server_stopped = asyncio.Event()
+
+    @classmethod
+    async def create(cls, *args, quic: bool = True, tls: bool = True, conn_manager: bool = True,
+                     dht_mode: str = 'dht_server', force_reachability: Optional[str] = None,
+                     nat_port_map: bool = True, auto_nat: bool = True, bootstrap: bool = False,
+                     bootstrap_peers: Optional[List[str]] = None, use_global_ipfs: bool = False, host_port: int = None,
+                     daemon_listen_port: int = None, use_relay: bool = True, use_relay_hop: bool = False,
+                     use_relay_discovery: bool = False, use_auto_relay: bool = False, relay_hop_limit: int = 0, **kwargs):
+        """
+        Start a new p2pd process and connect to it.
+        :param args:
+        :param quic: Enables the QUIC transport
+        :param tls: Enables TLS1.3 channel security protocol
+        :param conn_manager: Enables the Connection Manager
+        :param dht_mode: DHT mode (dht_client/dht_server/dht)
+        :param force_reachability: Force reachability mode (public/private)
+        :param nat_port_map: Enables NAT port mapping
+        :param auto_nat: Enables the AutoNAT service
+        :param bootstrap: Connects to bootstrap peers and bootstraps the dht if enabled
+        :param bootstrap_peers: List of bootstrap peers; defaults to the IPFS DHT peers
+        :param use_global_ipfs: Bootstrap to global ipfs (works only if bootstrap=True and bootstrap_peers=None)
+        :param host_port: port for p2p network
+        :param daemon_listen_port: port for connection daemon and client binding
+        :param use_relay: enables circuit relay
+        :param use_relay_hop: enables hop for relay
+        :param use_relay_discovery: enables passive discovery for relay
+        :param use_auto_relay: enables autorelay
+        :param relay_hop_limit: sets the hop limit for hop relays
+        :param kwargs:
+        :return: new wrapper for p2p daemon
+        """
+
+        assert not (bootstrap and bootstrap_peers is None and not use_global_ipfs), \
+            'Trying to bootstrap without a list of bootstrap peers. ' \
+            'This is dangerous: p2pd would connect to the global IPFS network, which is very unstable. ' \
+            'If you really want this, pass use_global_ipfs=True'
+        assert not (bootstrap_peers is not None and use_global_ipfs), \
+            'Non-empty bootstrap_peers and use_global_ipfs=True are incompatible. ' \
+            'Choose one option: your own peer list (preferable) or global IPFS (very unstable)'
+
+        self = cls()
+        with path(cli, P2PD_FILENAME) as p:
+            p2pd_path = p
+        bootstrap_peers = cls._make_bootstrap_peers(bootstrap_peers)
+        dht = cls.DHT_MODE_MAPPING.get(dht_mode, {'dht': 0})
+        force_reachability = cls.FORCE_REACHABILITY_MAPPING.get(force_reachability, {})
+        proc_args = self._make_process_args(
+            str(p2pd_path), *args,
+            quic=quic, tls=tls, connManager=conn_manager,
+            natPortMap=nat_port_map, autonat=auto_nat,
+            relay=use_relay, relayHop=use_relay_hop, relayDiscovery=use_relay_discovery,
+            autoRelay=use_auto_relay, relayHopLimit=relay_hop_limit,
+            b=bootstrap, **{**bootstrap_peers, **dht, **force_reachability, **kwargs})
+        self._assign_daemon_ports(host_port, daemon_listen_port)
+
+        for try_count in range(NUM_RETRIES):
+            try:
+                self._initialize(proc_args)
+                await self._wait_for_client(RETRY_DELAY * (2 ** try_count))
+                break
+            except Exception as e:
+                logger.debug(f"Failed to initialize p2p daemon: {e}")
+                self._terminate()
+                if try_count == NUM_RETRIES - 1:
+                    raise
+                self._assign_daemon_ports()
+
+        return self
+
+    @classmethod
+    async def replicate(cls, daemon_listen_port: int, host_port: int):
+        """
+        Connect to existing p2p daemon
+        :param daemon_listen_port: port for connection daemon and client binding
+        :param host_port: port for p2p network
+        :return: new wrapper for existing p2p daemon
+        """
+
+        self = cls()
+        # There is no child under control
+        # Use external already running p2pd
+        self._child = None
+        self._alive = True
+        self._assign_daemon_ports(host_port, daemon_listen_port)
+        self._client_listen_port = find_open_port()
+        self._client = p2pclient.Client(
+            Multiaddr(f'/ip4/127.0.0.1/tcp/{self._daemon_listen_port}'),
+            Multiaddr(f'/ip4/127.0.0.1/tcp/{self._client_listen_port}'))
+        await self._wait_for_client()
+        return self
+
+    async def wait_for_at_least_n_peers(self, n_peers, attempts=3, delay=1):
+        for _ in range(attempts):
+            peers = await self._client.list_peers()
+            if len(peers) >= n_peers:
+                return
+            await asyncio.sleep(delay)
+
+        raise RuntimeError('Not enough peers')
+
+    def _initialize(self, proc_args: List[str]) -> None:
+        proc_args = deepcopy(proc_args)
+        proc_args.extend(self._make_process_args(
+            hostAddrs=f'/ip4/0.0.0.0/tcp/{self._host_port},/ip4/0.0.0.0/udp/{self._host_port}/quic',
+            listen=f'/ip4/127.0.0.1/tcp/{self._daemon_listen_port}'
+        ))
+        self._child = Popen(args=proc_args, encoding="utf8")
+        self._alive = True
+        self._client_listen_port = find_open_port()
+        self._client = p2pclient.Client(
+            Multiaddr(f'/ip4/127.0.0.1/tcp/{self._daemon_listen_port}'),
+            Multiaddr(f'/ip4/127.0.0.1/tcp/{self._client_listen_port}'))
+
+    async def _wait_for_client(self, delay=0):
+        await asyncio.sleep(delay)
+        encoded = await self._client.identify()
+        self.id = encoded[0].to_base58()
+
+    def _assign_daemon_ports(self, host_port=None, daemon_listen_port=None):
+        if host_port is None:
+            host_port = find_open_port()
+        if daemon_listen_port is None:
+            daemon_listen_port = find_open_port()
+            while daemon_listen_port == host_port:
+                daemon_listen_port = find_open_port()
+
+        self._host_port, self._daemon_listen_port = host_port, daemon_listen_port
+
+    @staticmethod
+    async def send_raw_data(byte_str, writer):
+        request = len(byte_str).to_bytes(P2P.HEADER_LEN, P2P.BYTEORDER) + byte_str
+        writer.write(request)
+
+    @staticmethod
+    async def send_msgpack(data, writer):
+        raw_data = MSGPackSerializer.dumps(data)
+        await P2P.send_raw_data(raw_data, writer)
+
+    @staticmethod
+    async def send_protobuf(protobuf, out_proto_type, writer):
+        if type(protobuf) != out_proto_type:
+            raise TypeError('Unary handler returned protobuf of wrong type.')
+        if out_proto_type == p2pd_pb2.RPCError:
+            await P2P.send_raw_data(P2P.ERROR_MESSAGE, writer)
+        else:
+            await P2P.send_raw_data(P2P.RESULT_MESSAGE, writer)
+
+        await P2P.send_raw_data(protobuf.SerializeToString(), writer)
+
+    @staticmethod
+    async def receive_raw_data(reader: asyncio.StreamReader, header_len=HEADER_LEN):
+        header = await reader.readexactly(header_len)
+        content_length = int.from_bytes(header, P2P.BYTEORDER)
+        data = await reader.readexactly(content_length)
+        return data
+
+    @staticmethod
+    async def receive_msgpack(reader):
+        return MSGPackSerializer.loads(await P2P.receive_raw_data(reader))
+
+    @staticmethod
+    async def receive_protobuf(in_proto_type, reader):
+        msg_type = await P2P.receive_raw_data(reader)
+        if msg_type == P2P.RESULT_MESSAGE:
+            protobuf = in_proto_type()
+            protobuf.ParseFromString(await P2P.receive_raw_data(reader))
+            return protobuf, None
+        elif msg_type == P2P.ERROR_MESSAGE:
+            protobuf = p2pd_pb2.RPCError()
+            protobuf.ParseFromString(await P2P.receive_raw_data(reader))
+            return None, protobuf
+        else:
+            raise TypeError('Invalid Protobuf message type')
+
+    @staticmethod
+    def _handle_stream(handle):
+        async def do_handle_stream(stream_info, reader, writer):
+            try:
+                request = await P2P.receive_raw_data(reader)
+            except asyncio.IncompleteReadError:
+                logger.debug("Incomplete read while receiving request from peer")
+                writer.close()
+                return
+            try:
+                result = handle(request)
+                await P2P.send_raw_data(result, writer)
+            finally:
+                writer.close()
+
+        return do_handle_stream
+
+    @staticmethod
+    def _handle_unary_stream(handle, context, in_proto_type, out_proto_type):
+        async def watchdog(reader: asyncio.StreamReader):
+            await reader.read(n=1)
+            raise P2PInterruptedError()
+
+        async def do_handle_unary_stream(
+                stream_info: StreamInfo,
+                reader: asyncio.StreamReader,
+                writer: asyncio.StreamWriter) -> None:
+            try:
+                try:
+                    request = await P2P.receive_protobuf(in_proto_type, reader)
+                except asyncio.IncompleteReadError:
+                    logger.debug("Incomplete read while receiving request from peer")
+                    return
+                except google.protobuf.message.DecodeError as error:
+                    logger.exception(error)
+                    return
+
+                context.peer_id, context.peer_addr = stream_info.peer_id, stream_info.addr
+                done, pending = await asyncio.wait([watchdog(reader), handle(request, context)],
+                                                   return_when=asyncio.FIRST_COMPLETED)
+                try:
+                    result = done.pop().result()
+                    await P2P.send_protobuf(result, out_proto_type, writer)
+                except P2PInterruptedError:
+                    pass
+                except Exception as exc:
+                    error = p2pd_pb2.RPCError(message=str(exc))
+                    await P2P.send_protobuf(error, p2pd_pb2.RPCError, writer)
+                finally:
+                    pending_task = pending.pop()
+                    pending_task.cancel()
+                    try:
+                        await pending_task
+                    except asyncio.CancelledError:
+                        pass
+            finally:
+                writer.close()
+
+        return do_handle_unary_stream
+
+    def start_listening(self):
+        async def listen():
+            async with self._client.listen():
+                await self._server_stopped.wait()
+
+        self._listen_task = asyncio.create_task(listen())
+
+    async def stop_listening(self):
+        if self._listen_task is not None:
+            self._server_stopped.set()
+            self._listen_task.cancel()
+            try:
+                await self._listen_task
+            except asyncio.CancelledError:
+                self._listen_task = None
+                self._server_stopped.clear()
+
+    async def add_stream_handler(self, name, handle):
+        if self._listen_task is None:
+            self.start_listening()
+        await self._client.stream_handler(name, self._handle_stream(handle))
+
+    async def add_unary_handler(self, name, handle, in_proto_type, out_proto_type):
+        if self._listen_task is None:
+            self.start_listening()
+        context = P2PContext(id=self.id, port=self._host_port, handle_name=name)
+        await self._client.stream_handler(
+            name, P2P._handle_unary_stream(handle, context, in_proto_type, out_proto_type))
+
+    async def call_peer_handler(self, peer_id, handler_name, input_data):
+        libp2p_peer_id = PeerID.from_base58(peer_id)
+        stream_info, reader, writer = await self._client.stream_open(libp2p_peer_id, (handler_name,))
+        try:
+            await P2P.send_raw_data(input_data, writer)
+            return await P2P.receive_raw_data(reader)
+        finally:
+            writer.close()
+
+    def __del__(self):
+        self._terminate()
+
+    @property
+    def is_alive(self):
+        return self._alive
+
+    async def shutdown(self):
+        await asyncio.get_event_loop().run_in_executor(None, self._terminate)
+
+    def _terminate(self):
+        self._alive = False
+        if self._child is not None and self._child.poll() is None:
+            self._child.kill()
+            self._child.wait()
+
+    @staticmethod
+    def _make_process_args(*args, **kwargs) -> List[str]:
+        proc_args = []
+        proc_args.extend(
+            str(entry) for entry in args
+        )
+        proc_args.extend(
+            f'-{key}={P2P._convert_process_arg_type(value)}' if value is not None else f'-{key}'
+            for key, value in kwargs.items()
+        )
+        return proc_args
+
+    @staticmethod
+    def _convert_process_arg_type(val):
+        if isinstance(val, bool):
+            return 1 if val else 0
+        return val
+
+    @staticmethod
+    def _make_bootstrap_peers(nodes):
+        if nodes is None:
+            return {}
+        return {'bootstrapPeers': ','.join(nodes)}
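The helpers above frame every message with an 8-byte big-endian length header. A small round-trip sketch (the _Buffer class is a hypothetical stand-in for an asyncio StreamWriter; only its write() method is used by send_raw_data):

import asyncio
from hivemind.p2p import P2P

class _Buffer:
    """ minimal stand-in exposing the write() method used by P2P.send_raw_data """
    def __init__(self):
        self.data = b''

    def write(self, chunk: bytes):
        self.data += chunk

async def roundtrip(payload: bytes) -> bytes:
    buffer = _Buffer()
    await P2P.send_raw_data(payload, buffer)   # writes the 8-byte length header, then the payload
    reader = asyncio.StreamReader()
    reader.feed_data(buffer.data)
    reader.feed_eof()
    return await P2P.receive_raw_data(reader)  # reads the header, then exactly that many bytes

assert asyncio.run(roundtrip(b'hello')) == b'hello'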

+ 0 - 0
hivemind/p2p/p2p_daemon_bindings/__init__.py


+ 210 - 0
hivemind/p2p/p2p_daemon_bindings/control.py

@@ -0,0 +1,210 @@
+"""
+Originally taken from: https://github.com/mhchia/py-libp2p-daemon-bindings
+Licence: MIT
+Author: Kevin Mai-Husan Chia
+"""
+
+import asyncio
+from contextlib import asynccontextmanager
+from typing import (AsyncIterator, Awaitable, Callable, Dict, Iterable,
+                    Sequence, Tuple)
+
+from multiaddr import Multiaddr, protocols
+
+from hivemind.p2p.p2p_daemon_bindings.datastructures import (PeerID, PeerInfo,
+                                                             StreamInfo)
+from hivemind.p2p.p2p_daemon_bindings.utils import (DispatchFailure,
+                                                    raise_if_failed,
+                                                    read_pbmsg_safe,
+                                                    write_pbmsg)
+from hivemind.proto import p2pd_pb2 as p2pd_pb
+from hivemind.utils.logging import get_logger
+
+StreamHandler = Callable[[StreamInfo, asyncio.StreamReader, asyncio.StreamWriter], Awaitable[None]]
+
+SUPPORT_CONN_PROTOCOLS = (
+    protocols.P_IP4,
+    # protocols.P_IP6,
+    protocols.P_UNIX,
+)
+SUPPORTED_PROTOS = (
+    protocols.protocol_with_code(proto) for proto in SUPPORT_CONN_PROTOCOLS
+)
+logger = get_logger(__name__)
+
+
+def parse_conn_protocol(maddr: Multiaddr) -> int:
+    proto_codes = set(proto.code for proto in maddr.protocols())
+    proto_cand = proto_codes.intersection(SUPPORT_CONN_PROTOCOLS)
+    if len(proto_cand) != 1:
+        raise ValueError(
+            f"connection protocol should be only one protocol out of {SUPPORTED_PROTOS}"
+            f", maddr={maddr}"
+        )
+    return tuple(proto_cand)[0]
+
+
+class DaemonConnector:
+    DEFAULT_CONTROL_MADDR = "/unix/tmp/p2pd.sock"
+
+    def __init__(self, control_maddr: Multiaddr = Multiaddr(DEFAULT_CONTROL_MADDR)) -> None:
+        self.control_maddr = control_maddr
+        self.proto_code = parse_conn_protocol(self.control_maddr)
+
+    async def open_connection(self) -> (asyncio.StreamReader, asyncio.StreamWriter):
+        if self.proto_code == protocols.P_UNIX:
+            control_path = self.control_maddr.value_for_protocol(protocols.P_UNIX)
+            logger.debug(f"DaemonConnector {self} opens connection to {self.control_maddr}")
+            return await asyncio.open_unix_connection(control_path)
+        elif self.proto_code == protocols.P_IP4:
+            host = self.control_maddr.value_for_protocol(protocols.P_IP4)
+            port = int(self.control_maddr.value_for_protocol(protocols.P_TCP))
+            return await asyncio.open_connection(host, port)
+        else:
+            raise ValueError(
+                f"Protocol not supported: {protocols.protocol_with_code(self.proto_code)}"
+            )
+
+
+class ControlClient:
+    DEFAULT_LISTEN_MADDR = "/unix/tmp/p2pclient.sock"
+
+    def __init__(
+            self, daemon_connector: DaemonConnector, listen_maddr: Multiaddr = Multiaddr(DEFAULT_LISTEN_MADDR)
+    ) -> None:
+        self.listen_maddr = listen_maddr
+        self.daemon_connector = daemon_connector
+        self.handlers: Dict[str, StreamHandler] = {}
+
+    async def _handler(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
+        pb_stream_info = p2pd_pb.StreamInfo()  # type: ignore
+        await read_pbmsg_safe(reader, pb_stream_info)
+        stream_info = StreamInfo.from_protobuf(pb_stream_info)
+        logger.debug(f"New incoming stream: {stream_info}")
+        try:
+            handler = self.handlers[stream_info.proto]
+        except KeyError as e:
+            # should never enter here... daemon should reject the stream for us.
+            writer.close()
+            raise DispatchFailure(e)
+        await handler(stream_info, reader, writer)
+
+    @asynccontextmanager
+    async def listen(self) -> AsyncIterator["ControlClient"]:
+        proto_code = parse_conn_protocol(self.listen_maddr)
+        if proto_code == protocols.P_UNIX:
+            listen_path = self.listen_maddr.value_for_protocol(protocols.P_UNIX)
+            server = await asyncio.start_unix_server(self._handler, path=listen_path)
+        elif proto_code == protocols.P_IP4:
+            host = self.listen_maddr.value_for_protocol(protocols.P_IP4)
+            port = int(self.listen_maddr.value_for_protocol(protocols.P_TCP))
+            server = await asyncio.start_server(self._handler, port=port, host=host)
+        else:
+            raise ValueError(
+                f"Protocol not supported: {protocols.protocol_with_code(proto_code)}"
+            )
+
+        async with server:
+            logger.info(f"DaemonConnector {self} starts listening to {self.listen_maddr}")
+            yield self
+
+        logger.info(f"DaemonConnector {self} closed")
+
+    async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
+        reader, writer = await self.daemon_connector.open_connection()
+        req = p2pd_pb.Request(type=p2pd_pb.Request.IDENTIFY)
+        await write_pbmsg(writer, req)
+
+        resp = p2pd_pb.Response()  # type: ignore
+        await read_pbmsg_safe(reader, resp)
+        writer.close()
+
+        raise_if_failed(resp)
+        peer_id_bytes = resp.identify.id
+        maddrs_bytes = resp.identify.addrs
+
+        maddrs = tuple(Multiaddr(maddr_bytes) for maddr_bytes in maddrs_bytes)
+        peer_id = PeerID(peer_id_bytes)
+
+        return peer_id, maddrs
+
+    async def connect(self, peer_id: PeerID, maddrs: Iterable[Multiaddr]) -> None:
+        reader, writer = await self.daemon_connector.open_connection()
+
+        maddrs_bytes = [i.to_bytes() for i in maddrs]
+        connect_req = p2pd_pb.ConnectRequest(
+            peer=peer_id.to_bytes(), addrs=maddrs_bytes
+        )
+        req = p2pd_pb.Request(type=p2pd_pb.Request.CONNECT, connect=connect_req)
+        await write_pbmsg(writer, req)
+
+        resp = p2pd_pb.Response()  # type: ignore
+        await read_pbmsg_safe(reader, resp)
+        writer.close()
+        raise_if_failed(resp)
+
+    async def list_peers(self) -> Tuple[PeerInfo, ...]:
+        req = p2pd_pb.Request(type=p2pd_pb.Request.LIST_PEERS)
+        reader, writer = await self.daemon_connector.open_connection()
+        await write_pbmsg(writer, req)
+        resp = p2pd_pb.Response()  # type: ignore
+        await read_pbmsg_safe(reader, resp)
+        writer.close()
+        raise_if_failed(resp)
+
+        peers = tuple(PeerInfo.from_protobuf(pinfo) for pinfo in resp.peers)
+        return peers
+
+    async def disconnect(self, peer_id: PeerID) -> None:
+        disconnect_req = p2pd_pb.DisconnectRequest(peer=peer_id.to_bytes())
+        req = p2pd_pb.Request(
+            type=p2pd_pb.Request.DISCONNECT, disconnect=disconnect_req
+        )
+        reader, writer = await self.daemon_connector.open_connection()
+        await write_pbmsg(writer, req)
+        resp = p2pd_pb.Response()  # type: ignore
+        await read_pbmsg_safe(reader, resp)
+        writer.close()
+        raise_if_failed(resp)
+
+    async def stream_open(
+        self, peer_id: PeerID, protocols: Sequence[str]
+    ) -> Tuple[StreamInfo, asyncio.StreamReader, asyncio.StreamWriter]:
+        reader, writer = await self.daemon_connector.open_connection()
+
+        stream_open_req = p2pd_pb.StreamOpenRequest(
+            peer=peer_id.to_bytes(), proto=list(protocols)
+        )
+        req = p2pd_pb.Request(
+            type=p2pd_pb.Request.STREAM_OPEN, streamOpen=stream_open_req
+        )
+        await write_pbmsg(writer, req)
+
+        resp = p2pd_pb.Response()  # type: ignore
+        await read_pbmsg_safe(reader, resp)
+        raise_if_failed(resp)
+
+        pb_stream_info = resp.streamInfo
+        stream_info = StreamInfo.from_protobuf(pb_stream_info)
+
+        return stream_info, reader, writer
+
+    async def stream_handler(self, proto: str, handler_cb: StreamHandler) -> None:
+        reader, writer = await self.daemon_connector.open_connection()
+
+        listen_path_maddr_bytes = self.listen_maddr.to_bytes()
+        stream_handler_req = p2pd_pb.StreamHandlerRequest(
+            addr=listen_path_maddr_bytes, proto=[proto]
+        )
+        req = p2pd_pb.Request(
+            type=p2pd_pb.Request.STREAM_HANDLER, streamHandler=stream_handler_req
+        )
+        await write_pbmsg(writer, req)
+
+        resp = p2pd_pb.Response()  # type: ignore
+        await read_pbmsg_safe(reader, resp)
+        writer.close()
+        raise_if_failed(resp)
+
+        # if success, add the handler to the dict
+        self.handlers[proto] = handler_cb

+ 170 - 0
hivemind/p2p/p2p_daemon_bindings/datastructures.py

@@ -0,0 +1,170 @@
+"""
+Originally taken from: https://github.com/mhchia/py-libp2p-daemon-bindings
+Licence: MIT
+Author: Kevin Mai-Husan Chia
+"""
+
+import hashlib
+from typing import Any, Sequence, Union
+
+import base58
+import multihash
+from multiaddr import Multiaddr, protocols
+
+from hivemind.proto import p2pd_pb2
+
+# NOTE: On inlining...
+# See: https://github.com/libp2p/specs/issues/138
+# NOTE: enabling to be interoperable w/ the Go implementation
+ENABLE_INLINING = True
+MAX_INLINE_KEY_LENGTH = 42
+
+IDENTITY_MULTIHASH_CODE = 0x00
+
+if ENABLE_INLINING:
+
+    class IdentityHash:
+        def __init__(self) -> None:
+            self._digest = bytearray()
+
+        def update(self, input: bytes) -> None:
+            self._digest += input
+
+        def digest(self) -> bytes:
+            return self._digest
+
+    multihash.FuncReg.register(
+        IDENTITY_MULTIHASH_CODE, "identity", hash_new=IdentityHash
+    )
+
+
+class PeerID:
+    def __init__(self, peer_id_bytes: bytes) -> None:
+        self._bytes = peer_id_bytes
+        self._xor_id = int(sha256_digest(self._bytes).hex(), 16)
+        self._b58_str = base58.b58encode(self._bytes).decode()
+
+    @property
+    def xor_id(self) -> int:
+        return self._xor_id
+
+    def to_bytes(self) -> bytes:
+        return self._bytes
+
+    def to_base58(self) -> str:
+        return self._b58_str
+
+    def __repr__(self) -> str:
+        return f"<libp2p.peer.id.ID ({self.to_base58()})>"
+
+    def __str__(self):
+        return self.to_base58()
+
+    def pretty(self):
+        return self.to_base58()
+
+    def to_string(self):
+        return self.to_base58()
+
+    def __eq__(self, other: object) -> bool:
+        if isinstance(other, str):
+            return self.to_base58() == other
+        elif isinstance(other, bytes):
+            return self._bytes == other
+        elif isinstance(other, PeerID):
+            return self._bytes == other._bytes
+        else:
+            return False
+
+    def __hash__(self) -> int:
+        return hash(self._bytes)
+
+    @classmethod
+    def from_base58(cls, base58_id: str) -> "PeerID":
+        peer_id_bytes = base58.b58decode(base58_id)
+        return cls(peer_id_bytes)
+
+
+def sha256_digest(data: Union[str, bytes]) -> bytes:
+    if isinstance(data, str):
+        data = data.encode("utf8")
+    return hashlib.sha256(data).digest()
+
+
+class StreamInfo:
+    def __init__(self, peer_id: PeerID, addr: Multiaddr, proto: str) -> None:
+        self.peer_id = peer_id
+        self.addr = addr
+        self.proto = proto
+
+    def __repr__(self) -> str:
+        return (
+            f"<StreamInfo peer_id={self.peer_id} addr={self.addr} proto={self.proto}>"
+        )
+
+    def to_protobuf(self) -> p2pd_pb2.StreamInfo:
+        pb_msg = p2pd_pb2.StreamInfo(
+            peer=self.peer_id.to_bytes(), addr=self.addr.to_bytes(), proto=self.proto
+        )
+        return pb_msg
+
+    @classmethod
+    def from_protobuf(cls, pb_msg: p2pd_pb2.StreamInfo) -> "StreamInfo":
+        stream_info = cls(
+            peer_id=PeerID(pb_msg.peer), addr=Multiaddr(pb_msg.addr), proto=pb_msg.proto
+        )
+        return stream_info
+
+
+class PeerInfo:
+    def __init__(self, peer_id: PeerID, addrs: Sequence[Multiaddr]) -> None:
+        self.peer_id = peer_id
+        self.addrs = list(addrs)
+
+    def __eq__(self, other: Any) -> bool:
+        return (
+            isinstance(other, PeerInfo)
+            and self.peer_id == other.peer_id
+            and self.addrs == other.addrs
+        )
+
+    @classmethod
+    def from_protobuf(cls, peer_info_pb: p2pd_pb2.PeerInfo) -> "PeerInfo":
+        peer_id = PeerID(peer_info_pb.id)
+        addrs = [Multiaddr(addr) for addr in peer_info_pb.addrs]
+        return PeerInfo(peer_id, addrs)
+
+    def __str__(self):
+        return f"{self.peer_id.pretty()} {','.join(str(a) for a in self.addrs)}"
+
+
+class InvalidAddrError(ValueError):
+    pass
+
+
+def info_from_p2p_addr(addr: Multiaddr) -> PeerInfo:
+    if addr is None:
+        raise InvalidAddrError("`addr` should not be `None`")
+
+    parts = addr.split()
+    if not parts:
+        raise InvalidAddrError(
+            f"`parts`={parts} should at least have a protocol `P_P2P`"
+        )
+
+    p2p_part = parts[-1]
+    last_protocol_code = p2p_part.protocols()[0].code
+    if last_protocol_code != protocols.P_P2P:
+        raise InvalidAddrError(
+            f"The last protocol should be `P_P2P` instead of `{last_protocol_code}`"
+        )
+
+    # make sure the /p2p value parses as a peer.ID
+    peer_id_str: str = p2p_part.value_for_protocol(protocols.P_P2P)
+    peer_id = PeerID.from_base58(peer_id_str)
+
+    # we might have received just a /p2p part, which means there's no addr.
+    if len(parts) > 1:
+        addr = Multiaddr.join(*parts[:-1])
+
+    return PeerInfo(peer_id, [addr])
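A quick round-trip check for PeerID as defined above; the raw bytes are an arbitrary illustrative value:

from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID

raw = bytes(range(8))                          # arbitrary id bytes, for illustration only
pid = PeerID(raw)
assert PeerID.from_base58(pid.to_base58()) == pid
assert pid == raw                              # __eq__ also accepts the raw bytes
print(pid)                                     # prints the base58-encoded form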

+ 85 - 0
hivemind/p2p/p2p_daemon_bindings/p2pclient.py

@@ -0,0 +1,85 @@
+"""
+Originally taken from: https://github.com/mhchia/py-libp2p-daemon-bindings
+Licence: MIT
+Author: Kevin Mai-Husan Chia
+"""
+
+import asyncio
+from contextlib import asynccontextmanager
+from typing import AsyncIterator, Iterable, Sequence, Tuple
+
+from multiaddr import Multiaddr
+
+from hivemind.p2p.p2p_daemon_bindings.control import (ControlClient,
+                                                      DaemonConnector,
+                                                      StreamHandler)
+from hivemind.p2p.p2p_daemon_bindings.datastructures import (PeerID, PeerInfo,
+                                                             StreamInfo)
+
+
+class Client:
+    control: ControlClient
+
+    def __init__(
+        self, control_maddr: Multiaddr = None, listen_maddr: Multiaddr = None
+    ) -> None:
+        daemon_connector = DaemonConnector(control_maddr=control_maddr)
+        self.control = ControlClient(
+            daemon_connector=daemon_connector, listen_maddr=listen_maddr
+        )
+
+    @asynccontextmanager
+    async def listen(self) -> AsyncIterator["Client"]:
+        """
+        Start listening for incoming connections and dispatch them to the handlers registered via stream_handler.
+        :return:
+        """
+        async with self.control.listen():
+            yield self
+
+    async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
+        """
+        Get current node peer id and list of addresses
+        """
+        return await self.control.identify()
+
+    async def connect(self, peer_id: PeerID, maddrs: Iterable[Multiaddr]) -> None:
+        """
+        Connect to a p2p node with the specified addresses and peer id.
+        :peer_id: peer id of the node you want to connect to
+        :maddrs: multiaddresses of the node you want to connect to; they must be reachable
+        """
+        await self.control.connect(peer_id=peer_id, maddrs=maddrs)
+
+    async def list_peers(self) -> Tuple[PeerInfo, ...]:
+        """
+        Get the list of peers that the node is connected to
+        """
+        return await self.control.list_peers()
+
+    async def disconnect(self, peer_id: PeerID) -> None:
+        """
+        Disconnect from the node with the specified peer id
+        :peer_id: peer id of the node you want to disconnect from
+        """
+        await self.control.disconnect(peer_id=peer_id)
+
+    async def stream_open(
+        self, peer_id: PeerID, protocols: Sequence[str]
+    ) -> Tuple[StreamInfo, asyncio.StreamReader, asyncio.StreamWriter]:
+        """
+        Open a stream to call another peer's handler for one of the specified protocols
+        :peer_id: peer id of the other peer
+        :protocols: list of protocols that the other peer's handler may serve
+        :return: tuple of stream info (describing the connection to the other peer) and a reader/writer pair
+        """
+        return await self.control.stream_open(peer_id=peer_id, protocols=protocols)
+
+    async def stream_handler(self, proto: str, handler_cb: StreamHandler) -> None:
+        """
+        Register a stream handler
+        :param proto: protocol that the handler serves
+        :param handler_cb: handler callback
+        :return:
+        """
+        await self.control.stream_handler(proto=proto, handler_cb=handler_cb)
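A hedged sketch of serving a stream handler through the high-level Client; it assumes a p2pd daemon on the default control socket, and the protocol name and payload size are illustrative:

import asyncio
from multiaddr import Multiaddr
from hivemind.p2p.p2p_daemon_bindings.p2pclient import Client

async def serve_echo():
    client = Client(control_maddr=Multiaddr('/unix/tmp/p2pd.sock'),
                    listen_maddr=Multiaddr('/unix/tmp/p2pclient.sock'))

    async def echo_handler(stream_info, reader, writer):
        writer.write(await reader.read(100))  # echo up to 100 bytes back to the caller
        writer.close()

    async with client.listen():
        await client.stream_handler('/echo/1.0.0', echo_handler)
        await asyncio.sleep(60)  # keep serving incoming streams for a while

asyncio.run(serve_echo())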

+ 73 - 0
hivemind/p2p/p2p_daemon_bindings/utils.py

@@ -0,0 +1,73 @@
+"""
+Originally taken from: https://github.com/mhchia/py-libp2p-daemon-bindings
+Licence: MIT
+Author: Kevin Mai-Husan Chia
+"""
+
+import asyncio
+
+from google.protobuf.message import Message as PBMessage
+
+from hivemind.proto import p2pd_pb2 as p2pd_pb
+
+DEFAULT_MAX_BITS: int = 64
+
+
+class ControlFailure(Exception):
+    pass
+
+
+class DispatchFailure(Exception):
+    pass
+
+
+async def write_unsigned_varint(stream: asyncio.StreamWriter, integer: int, max_bits: int = DEFAULT_MAX_BITS) -> None:
+    max_int = 1 << max_bits
+    if integer < 0:
+        raise ValueError(f"negative integer: {integer}")
+    if integer >= max_int:
+        raise ValueError(f"integer too large: {integer}")
+    while True:
+        value = integer & 0x7F
+        integer >>= 7
+        if integer != 0:
+            value |= 0x80
+        byte = value.to_bytes(1, "big")
+        stream.write(byte)
+        if integer == 0:
+            break
+
+
+async def read_unsigned_varint(stream: asyncio.StreamReader, max_bits: int = DEFAULT_MAX_BITS) -> int:
+    max_int = 1 << max_bits
+    iteration = 0
+    result = 0
+    has_next = True
+    while has_next:
+        data = await stream.readexactly(1)
+        c = data[0]
+        value = c & 0x7F
+        result |= value << (iteration * 7)
+        has_next = (c & 0x80) != 0
+        iteration += 1
+        if result >= max_int:
+            raise ValueError(f"Varint overflowed: {result}")
+    return result
+
+
+def raise_if_failed(response: p2pd_pb.Response) -> None:
+    if response.type == p2pd_pb.Response.ERROR:
+        raise ControlFailure(f"Connect failed. msg={response.error.msg}")
+
+
+async def write_pbmsg(stream: asyncio.StreamWriter, pbmsg: PBMessage) -> None:
+    size = pbmsg.ByteSize()
+    await write_unsigned_varint(stream, size)
+    msg_bytes: bytes = pbmsg.SerializeToString()
+    stream.write(msg_bytes)
+
+
+async def read_pbmsg_safe(stream: asyncio.StreamReader, pbmsg: PBMessage) -> None:
+    len_msg_bytes = await read_unsigned_varint(stream)
+    msg_bytes = await stream.readexactly(len_msg_bytes)
+    pbmsg.ParseFromString(msg_bytes)
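The varint framing above can be checked in isolation; a small round trip, where _ListWriter is a hypothetical duck-typed stand-in for a StreamWriter:

import asyncio
from hivemind.p2p.p2p_daemon_bindings.utils import read_unsigned_varint, write_unsigned_varint

class _ListWriter:
    """ minimal object exposing the write() method used by write_unsigned_varint """
    def __init__(self):
        self.chunks = []

    def write(self, data: bytes):
        self.chunks.append(data)

async def roundtrip(value: int) -> int:
    writer = _ListWriter()
    await write_unsigned_varint(writer, value)
    reader = asyncio.StreamReader()
    reader.feed_data(b''.join(writer.chunks))
    reader.feed_eof()
    return await read_unsigned_varint(reader)

assert asyncio.run(roundtrip(300)) == 300  # 300 encodes to two bytes: 0xAC 0x02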

+ 1 - 1
hivemind/proto/averaging.proto

@@ -43,7 +43,7 @@ message MessageFromLeader {
   bytes group_id = 2;        // a unique identifier of this group, only valid until allreduce is finished/failed
   string suggested_leader = 3;  // if peer is already in a group, it'll provide us with an endpoint of its leader
   repeated string ordered_group_endpoints = 4;  // a sequence of peers, each responsible for one shard during averaging
-  repeated bytes gathered = 5;  // metadata (gather) from all groupmates in the same order as their endoints
+  repeated bytes gathered = 5;  // metadata (gather) from all groupmates in the same order as their endpoints
 }
 
 message AveragingData {

+ 166 - 0
hivemind/proto/p2pd.proto

@@ -0,0 +1,166 @@
+//Originally taken from: https://github.com/mhchia/py-libp2p-daemon-bindings
+//Licence: MIT
+//Author: Kevin Mai-Husan Chia
+
+syntax = "proto2";
+
+package p2pclient.p2pd.pb;
+
+message Request {
+  enum Type {
+    IDENTIFY       = 0;
+    CONNECT        = 1;
+    STREAM_OPEN    = 2;
+    STREAM_HANDLER = 3;
+    DHT            = 4;
+    LIST_PEERS     = 5;
+    CONNMANAGER    = 6;
+    DISCONNECT     = 7;
+    PUBSUB         = 8;
+  }
+
+  required Type type = 1;
+
+  optional ConnectRequest connect = 2;
+  optional StreamOpenRequest streamOpen = 3;
+  optional StreamHandlerRequest streamHandler = 4;
+  optional DHTRequest dht = 5;
+  optional ConnManagerRequest connManager = 6;
+  optional DisconnectRequest disconnect = 7;
+  optional PSRequest pubsub = 8;
+}
+
+message Response {
+  enum Type {
+    OK    = 0;
+    ERROR = 1;
+  }
+
+  required Type type = 1;
+  optional ErrorResponse error = 2;
+  optional StreamInfo streamInfo = 3;
+  optional IdentifyResponse identify = 4;
+  optional DHTResponse dht = 5;
+  repeated PeerInfo peers = 6;
+  optional PSResponse pubsub = 7;
+}
+
+message IdentifyResponse {
+  required bytes id = 1;
+  repeated bytes addrs = 2;
+}
+
+message ConnectRequest {
+  required bytes peer = 1;
+  repeated bytes addrs = 2;
+  optional int64 timeout = 3;
+}
+
+message StreamOpenRequest {
+  required bytes peer = 1;
+  repeated string proto = 2;
+  optional int64 timeout = 3;
+}
+
+message StreamHandlerRequest {
+  required bytes addr = 1;
+  repeated string proto = 2;
+}
+
+message ErrorResponse {
+  required string msg = 1;
+}
+
+message StreamInfo {
+  required bytes peer = 1;
+  required bytes addr = 2;
+  required string proto = 3;
+}
+
+message DHTRequest {
+  enum Type {
+    FIND_PEER                    = 0;
+    FIND_PEERS_CONNECTED_TO_PEER = 1;
+    FIND_PROVIDERS               = 2;
+    GET_CLOSEST_PEERS            = 3;
+    GET_PUBLIC_KEY               = 4;
+    GET_VALUE                    = 5;
+    SEARCH_VALUE                 = 6;
+    PUT_VALUE                    = 7;
+    PROVIDE                      = 8;
+  }
+
+  required Type type = 1;
+  optional bytes peer = 2;
+  optional bytes cid = 3;
+  optional bytes key = 4;
+  optional bytes value = 5;
+  optional int32 count = 6;
+  optional int64 timeout = 7;
+}
+
+message DHTResponse {
+  enum Type {
+    BEGIN = 0;
+    VALUE = 1;
+    END   = 2;
+  }
+
+  required Type type = 1;
+  optional PeerInfo peer = 2;
+  optional bytes value = 3;
+}
+
+message PeerInfo {
+  required bytes id = 1;
+  repeated bytes addrs = 2;
+}
+
+message ConnManagerRequest {
+  enum Type {
+    TAG_PEER        = 0;
+    UNTAG_PEER      = 1;
+    TRIM            = 2;
+  }
+
+  required Type type = 1;
+
+  optional bytes peer = 2;
+  optional string tag = 3;
+  optional int64 weight = 4;
+}
+
+message DisconnectRequest {
+  required bytes peer = 1;
+}
+
+message PSRequest {
+  enum Type {
+    GET_TOPICS = 0;
+    LIST_PEERS = 1;
+    PUBLISH    = 2;
+    SUBSCRIBE  = 3;
+  }
+
+  required Type type = 1;
+  optional string topic = 2;
+  optional bytes data = 3;
+}
+
+message PSMessage {
+  optional bytes from_id = 1;
+  optional bytes data = 2;
+  optional bytes seqno = 3;
+  repeated string topicIDs = 4;
+  optional bytes signature = 5;
+  optional bytes key = 6;
+}
+
+message PSResponse {
+  repeated string topics = 1;
+  repeated bytes peerIDs = 2;
+}
+
+message RPCError {
+  required string message = 1;
+}

+ 1 - 1
hivemind/server/runtime.py

@@ -118,7 +118,7 @@ class Runtime(threading.Thread):
         with DefaultSelector() as selector:
             for pool in self.pools:
                 selector.register(pool.batch_receiver, EVENT_READ, pool)
-            # selector.register(self.shutdown_recv, EVENT_READ, self.SHUTDOWN_TRIGGER)
+            selector.register(self.shutdown_recv, EVENT_READ, self.SHUTDOWN_TRIGGER)
 
             while True:
                 # wait until at least one batch_receiver becomes available

+ 49 - 1
hivemind/utils/asyncio.py

@@ -1,7 +1,14 @@
-from typing import TypeVar, AsyncIterator, Union, AsyncIterable, Awaitable
+from concurrent.futures import ThreadPoolExecutor
+from typing import TypeVar, AsyncIterator, Union, AsyncIterable, Awaitable, Tuple, Optional, Callable
 import asyncio
+
 import uvloop
+
+from hivemind.utils.logging import get_logger
+
+
 T = TypeVar('T')
+logger = get_logger(__name__)
 
 
 def switch_to_uvloop() -> asyncio.AbstractEventLoop:
@@ -27,6 +34,16 @@ async def aiter(*args: T) -> AsyncIterator[T]:
         yield arg
 
 
+async def azip(*iterables: AsyncIterable[T]) -> AsyncIterator[Tuple[T, ...]]:
+    """ equivalent of zip for asynchronous iterables """
+    iterators = [iterable.__aiter__() for iterable in iterables]
+    while True:
+        try:
+            yield tuple(await asyncio.gather(*(itr.__anext__() for itr in iterators)))
+        except StopAsyncIteration:
+            break
+
+
 async def achain(*async_iters: AsyncIterable[T]) -> AsyncIterator[T]:
     """ equivalent to chain(iter1, iter2, ...) for asynchronous iterators. """
     for aiter in async_iters:
@@ -34,6 +51,14 @@ async def achain(*async_iters: AsyncIterable[T]) -> AsyncIterator[T]:
             yield elem
 
 
+async def aenumerate(aiterable: AsyncIterable[T]) -> AsyncIterable[Tuple[int, T]]:
+    """ equivalent to enumerate(iter) for asynchronous iterators. """
+    index = 0
+    async for elem in aiterable:
+        yield index, elem
+        index += 1
+
+
 async def await_cancelled(awaitable: Awaitable) -> bool:
     try:
         await awaitable
@@ -42,3 +67,26 @@ async def await_cancelled(awaitable: Awaitable) -> bool:
         return True
     except BaseException:
         return False
+
+
+async def amap_in_executor(func: Callable[..., T], *iterables: AsyncIterable, max_prefetch: Optional[int] = None,
+                           executor: Optional[ThreadPoolExecutor] = None) -> AsyncIterator[T]:
+    """ iterate from an async iterable in a background thread, yield results to async iterable """
+    loop = asyncio.get_event_loop()
+    queue = asyncio.Queue(max_prefetch)
+
+    async def _put_items():
+        async for args in azip(*iterables):
+            await queue.put(loop.run_in_executor(executor, func, *args))
+        await queue.put(None)
+
+    task = asyncio.create_task(_put_items())
+    try:
+        future = await queue.get()
+        while future is not None:
+            yield await future
+            future = await queue.get()
+        await task
+    finally:
+        if not task.done():
+            task.cancel()
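Usage sketch for the new helpers; aiter is the existing helper from the same module that yields its arguments one by one:

import asyncio
from hivemind.utils.asyncio import aenumerate, aiter, amap_in_executor, azip

async def demo():
    async for index, pair in aenumerate(azip(aiter('a', 'b'), aiter(1, 2))):
        print(index, pair)                 # 0 ('a', 1), then 1 ('b', 2)
    async for word in amap_in_executor(str.upper, aiter('x', 'y'), max_prefetch=1):
        print(word)                        # 'X', then 'Y' (computed in a background thread)

asyncio.run(demo())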

+ 14 - 1
hivemind/utils/compression.py

@@ -8,7 +8,7 @@ from hivemind.proto import runtime_pb2
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.threading import run_in_background
 
-FP16_MAX = 65_504
+FP32_EPS = 1e-06
 NUM_BYTES_FLOAT32 = 4
 NUM_BYTES_FLOAT16 = 2
 NUM_BITS_QUANTILE_COMPRESSION = 8
@@ -86,6 +86,7 @@ def serialize_torch_tensor(tensor: torch.Tensor, compression_type=CompressionTyp
         tensor.sub_(means)
 
         stds = torch.square(tensor).sum(dim=-1, keepdim=True).div_(tensor.shape[-1]).sqrt_()
+        stds.clamp_min_(FP32_EPS)
         tensor.div_(stds)
         tensor = tensor.clamp_(-FP16_MAX, FP16_MAX).to(torch.float16)
 
@@ -187,3 +188,15 @@ def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Ten
 
     tensor.requires_grad_(serialized_tensor.requires_grad)
     return tensor
+
+
+def get_nbytes_per_value(dtype: torch.dtype, compression: CompressionType) -> int:
+    """ returns the number of bytes per value for a given tensor (excluding metadata) """
+    if compression in (CompressionType.QUANTILE_8BIT, CompressionType.UNIFORM_8BIT):
+        return 1
+    elif compression in (CompressionType.FLOAT16, CompressionType.MEANSTD_16BIT):
+        return 2
+    elif compression == CompressionType.NONE:
+        return torch.finfo(dtype).bits // 8
+    else:
+        raise NotImplementedError(f"Unknown compression type: {CompressionType.Name(compression)}")
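The byte counts follow directly from the mapping above, for example:

import torch
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.compression import get_nbytes_per_value

assert get_nbytes_per_value(torch.float32, CompressionType.NONE) == 4
assert get_nbytes_per_value(torch.float32, CompressionType.FLOAT16) == 2
assert get_nbytes_per_value(torch.float32, CompressionType.UNIFORM_8BIT) == 1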

+ 5 - 1
hivemind/utils/grpc.py

@@ -158,7 +158,11 @@ class ChannelCache(TimedStorage[ChannelInfo, Tuple[Union[grpc.Channel, grpc.aio.
         raise ValueError(f"Please use {self.__class__.__name__}.get_stub to get or create stubs")
         raise ValueError(f"Please use {self.__class__.__name__}.get_stub to get or create stubs")
 
 
 
 
-def split_for_streaming(serialized_tensor: runtime_pb2.Tensor, chunk_size_bytes: int) -> Iterator[runtime_pb2.Tensor]:
+STREAMING_CHUNK_SIZE_BYTES = 2 ** 16
+
+
+def split_for_streaming(serialized_tensor: runtime_pb2.Tensor, chunk_size_bytes: int = STREAMING_CHUNK_SIZE_BYTES,
+                        ) -> Iterator[runtime_pb2.Tensor]:
     """ Split serialized_tensor into multiple chunks for gRPC streaming """
     """ Split serialized_tensor into multiple chunks for gRPC streaming """
     buffer = memoryview(serialized_tensor.buffer)
     buffer = memoryview(serialized_tensor.buffer)
     num_chunks = len(range(0, len(buffer), chunk_size_bytes))
     num_chunks = len(range(0, len(buffer), chunk_size_bytes))
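With the new default, a serialized tensor can be chunked for streaming without picking a size by hand; a short sketch:

import torch
from hivemind.utils.compression import serialize_torch_tensor
from hivemind.utils.grpc import STREAMING_CHUNK_SIZE_BYTES, split_for_streaming

serialized = serialize_torch_tensor(torch.randn(1024, 1024))
chunks = list(split_for_streaming(serialized))  # uses STREAMING_CHUNK_SIZE_BYTES = 2 ** 16 by default
print(len(chunks), 'chunks of at most', STREAMING_CHUNK_SIZE_BYTES, 'bytes each')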

+ 1 - 1
hivemind/utils/threading.py

@@ -12,7 +12,7 @@ def run_in_background(func: callable, *args, **kwargs) -> Future:
     """ run func(*args, **kwargs) in background and return Future for its outputs """
     """ run func(*args, **kwargs) in background and return Future for its outputs """
     global EXECUTOR_PID, GLOBAL_EXECUTOR
     global EXECUTOR_PID, GLOBAL_EXECUTOR
     if os.getpid() != EXECUTOR_PID:
     if os.getpid() != EXECUTOR_PID:
-        GLOBAL_EXECUTOR = ThreadPoolExecutor(max_workers=float(os.environ.get("HIVEMIND_THREADS", 'inf')))
+        GLOBAL_EXECUTOR = ThreadPoolExecutor(max_workers=int(os.environ.get("HIVEMIND_THREADS", 128)))
         EXECUTOR_PID = os.getpid()
         EXECUTOR_PID = os.getpid()
     return GLOBAL_EXECUTOR.submit(func, *args, **kwargs)
     return GLOBAL_EXECUTOR.submit(func, *args, **kwargs)
 
 

+ 1 - 0
requirements-dev.txt

@@ -1,6 +1,7 @@
 pytest
 pytest-forked
 pytest-asyncio
+pytest-cov
 codecov
 tqdm
 scikit-learn

+ 2 - 0
requirements.txt

@@ -10,5 +10,7 @@ grpcio>=1.33.2
 grpcio-tools>=1.33.2
 protobuf>=3.12.2
 configargparse>=1.2.3
+multiaddr>=0.0.9
+pymultihash>=0.8.2
 cryptography>=3.4.6
 pydantic>=1.8.1

+ 80 - 12
setup.py

@@ -1,12 +1,32 @@
 import codecs
 import glob
+import hashlib
 import os
 import re
-
-from pkg_resources import parse_requirements
-from setuptools import setup, find_packages
+import shlex
+import subprocess
+import tarfile
+import tempfile
+import urllib.request
+
+from pkg_resources import parse_requirements, parse_version
+from setuptools import find_packages, setup
+from setuptools.command.build_py import build_py
 from setuptools.command.develop import develop
-from setuptools.command.install import install
+
+P2PD_VERSION = 'v0.3.1'
+P2PD_CHECKSUM = '15292b880c6b31f5b3c36084b3acc17f'
+LIBP2P_TAR_URL = f'https://github.com/learning-at-home/go-libp2p-daemon/archive/refs/tags/{P2PD_VERSION}.tar.gz'
+
+here = os.path.abspath(os.path.dirname(__file__))
+
+
+def md5(fname, chunk_size=4096):
+    hash_md5 = hashlib.md5()
+    with open(fname, "rb") as f:
+        for chunk in iter(lambda: f.read(chunk_size), b""):
+            hash_md5.update(chunk)
+    return hash_md5.hexdigest()
 
 
 def proto_compile(output_path):
@@ -28,20 +48,68 @@ def proto_compile(output_path):
             file.truncate()
 
 
-class ProtoCompileInstall(install):
+def build_p2p_daemon():
+    result = subprocess.run("go version", capture_output=True, shell=True).stdout.decode('ascii', 'replace')
+    m = re.search(r'^go version go([\d.]+)', result)
+
+    if m is None:
+        raise FileNotFoundError('Could not find golang installation')
+    version = parse_version(m.group(1))
+    if version < parse_version("1.13"):
+        raise EnvironmentError(f'Newer version of go required: must be >= 1.13, found {version}')
+
+    with tempfile.TemporaryDirectory() as tempdir:
+        dest = os.path.join(tempdir, 'libp2p-daemon.tar.gz')
+        urllib.request.urlretrieve(LIBP2P_TAR_URL, dest)
+
+        with tarfile.open(dest, 'r:gz') as tar:
+            tar.extractall(tempdir)
+
+        result = subprocess.run(f'go build -o {shlex.quote(os.path.join(here, "hivemind", "hivemind_cli", "p2pd"))}',
+                                cwd=os.path.join(tempdir, f'go-libp2p-daemon-{P2PD_VERSION[1:]}', 'p2pd'), shell=True)
+
+        if result.returncode:
+            raise RuntimeError('Failed to build or install libp2p-daemon:'
+                               f' exited with status code: {result.returncode}')
+
+
+def download_p2p_daemon():
+    install_path = os.path.join(here, 'hivemind', 'hivemind_cli')
+    binary_path = os.path.join(install_path, 'p2pd')
+    if not os.path.exists(binary_path) or md5(binary_path) != P2PD_CHECKSUM:
+        print('Downloading Peer to Peer Daemon')
+        url = f'https://github.com/learning-at-home/go-libp2p-daemon/releases/download/{P2PD_VERSION}/p2pd'
+        urllib.request.urlretrieve(url, binary_path)
+        os.chmod(binary_path, 0o777)
+        if md5(binary_path) != P2PD_CHECKSUM:
+            raise RuntimeError(f'Downloaded p2pd binary from {url} does not match with md5 checksum')
+
+
+class BuildPy(build_py):
+    user_options = build_py.user_options + [('buildgo', None, "Builds p2pd from source")]
+
+    def initialize_options(self):
+        super().initialize_options()
+        self.buildgo = False
+
     def run(self):
-        proto_compile(os.path.join(self.build_lib, 'hivemind', 'proto'))
+        if self.buildgo:
+            build_p2p_daemon()
+        else:
+            download_p2p_daemon()
+
         super().run()
 
+        proto_compile(os.path.join(self.build_lib, 'hivemind', 'proto'))
+
 
-class ProtoCompileDevelop(develop):
+class Develop(develop):
     def run(self):
-        proto_compile(os.path.join('hivemind', 'proto'))
+        self.reinitialize_command('build_py', build_lib=here)
+        self.run_command('build_py')
         super().run()
 
 
-here = os.path.abspath(os.path.dirname(__file__))
-
 with open('requirements.txt') as requirements_file:
     install_requires = list(map(str, parse_requirements(requirements_file)))
 
@@ -63,7 +131,7 @@ extras['all'] = extras['dev'] + extras['docs']
 setup(
     name='hivemind',
     version=version_string,
-    cmdclass={'install': ProtoCompileInstall, 'develop': ProtoCompileDevelop},
+    cmdclass={'build_py': BuildPy, 'develop': Develop},
     description='Decentralized deep learning in PyTorch',
     long_description='Decentralized deep learning in PyTorch. Built to train giant models on '
                      'thousands of volunteers across the world.',
@@ -71,7 +139,7 @@ setup(
     author_email='mryabinin0@gmail.com',
     url="https://github.com/learning-at-home/hivemind",
     packages=find_packages(exclude=['tests']),
-    package_data={'hivemind': ['proto/*']},
+    package_data={'hivemind': ['proto/*', 'hivemind_cli/*']},
     include_package_data=True,
     license='MIT',
     setup_requires=['grpcio-tools'],

+ 217 - 0
tests/test_allreduce.py

@@ -0,0 +1,217 @@
+import asyncio
+import random
+import time
+from typing import Sequence
+
+import pytest
+import torch
+import grpc
+
+from hivemind import aenumerate, Endpoint
+from hivemind.client.averaging.allreduce import AllReduceRunner, AveragingMode
+from hivemind.client.averaging.partition import TensorPartContainer, TensorPartReducer
+from hivemind.utils import deserialize_torch_tensor, ChannelCache
+from hivemind.proto.runtime_pb2 import CompressionType
+from hivemind.proto import averaging_pb2_grpc
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_partitioning():
+    all_tensors = [
+        torch.randn(30_000, 128), torch.rand(128), torch.ones(1, 1, 1, 1, 1, 1, 8),
+        torch.ones(1, 0), torch.zeros(0), torch.zeros([]), torch.randn(65536),
+        torch.rand(512, 2048), torch.randn(1024, 1024).add(-9), torch.zeros(1020), torch.randn(4096)
+    ]
+
+    # note: this test does _not_ use parameterization to reuse sampled tensors
+    for num_tensors in 1, 3, 5:
+        for part_size_bytes in 31337, 2 ** 20, 10 ** 10:
+            for weights in [(1, 1), (0.333, 0.1667, 0.5003), (1.0, 0.0), [0.0, 0.4, 0.6, 0.0]]:
+                tensors = random.choices(all_tensors, k=num_tensors)
+                partition = TensorPartContainer(tensors, weights, part_size_bytes=part_size_bytes)
+
+                async def write_tensors():
+                    for peer_index in range(partition.group_size):
+                        async for part_index, part in aenumerate(partition.iterate_input_parts_for(peer_index)):
+                            output_tensor = torch.sin(deserialize_torch_tensor(part))
+                            partition.register_processed_part(peer_index, part_index, output_tensor)
+
+                task = asyncio.create_task(write_tensors())
+                tensor_index = 0
+                async for output_tensor in partition.iterate_output_tensors():
+                    assert torch.allclose(output_tensor, torch.sin(tensors[tensor_index]))
+                    tensor_index += 1
+                assert tensor_index == len(tensors)
+                await task
+
+
+@pytest.mark.parametrize("tensors", [[torch.zeros(0)], [torch.zeros(0), torch.zeros(0), torch.zeros(1)],
+                                     [torch.zeros(0), torch.zeros(999), torch.zeros(0), torch.zeros(0)]])
+@pytest.mark.parametrize("peer_fractions", [(0.33, 0.44, 0.23), (0.5, 0.5), (0.1, 0.0, 0.9), (1.0,), (0.1,) * 9])
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_partitioning_edge_cases(tensors: Sequence[torch.Tensor], peer_fractions: Sequence[float]):
+    partition = TensorPartContainer(tensors, peer_fractions, part_size_bytes=16)
+    for peer_index in range(len(peer_fractions)):
+        async for part_index, part in aenumerate(partition.iterate_input_parts_for(peer_index)):
+            partition.register_processed_part(peer_index, part_index, deserialize_torch_tensor(part))
+
+    tensor_index = 0
+    async for output_tensor in partition.iterate_output_tensors():
+        assert torch.allclose(output_tensor, tensors[tensor_index])
+        tensor_index += 1
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_partitioning_asynchronous():
+    """ ensure that tensor partitioning does not interfere with asynchronous code """
+    tensors = [torch.randn(2048, 2048), torch.randn(1024, 4096),
+               torch.randn(4096, 1024), torch.randn(30_000, 1024)]
+    peer_fractions = [0.4, 0.3, 0.2, 0.1]
+
+    partition = TensorPartContainer(tensors, peer_fractions, compression_type=CompressionType.QUANTILE_8BIT)
+    read_started, read_finished = asyncio.Event(), asyncio.Event()
+
+    async def write_tensors():
+        for peer_index in range(partition.group_size):
+            async for part_index, part in aenumerate(partition.iterate_input_parts_for(peer_index)):
+                partition.register_processed_part(peer_index, part_index, deserialize_torch_tensor(part))
+        assert read_started.is_set(), "partitioner should have started reading before it finished writing"
+
+    async def read_tensors():
+        async for _ in partition.iterate_output_tensors():
+            read_started.set()
+        read_finished.set()
+
+    async def wait_synchronously():
+        time_in_waiting = 0.0
+        while not read_finished.is_set():
+            await asyncio.sleep(0.01)
+            time_in_waiting += 0.01
+        return time_in_waiting
+
+    start_time = time.perf_counter()
+    *_, time_in_waiting = await asyncio.gather(write_tensors(), read_tensors(), wait_synchronously())
+    wall_time = time.perf_counter() - start_time
+    # check that the event loop had enough time to respond to incoming requests; this is over 50% most of the time
+    # we use a 33% threshold so that the test passes reliably. If prefetch is broken, this drops below 10%
+    assert time_in_waiting > wall_time / 3, f"Event loop could only run {time_in_waiting / wall_time :.5f} of the time"
+
+
+@pytest.mark.parametrize("num_senders", [1, 2, 4, 10])
+@pytest.mark.parametrize("num_parts", [0, 1, 100])
+@pytest.mark.parametrize("synchronize_prob", [1.0, 0.1, 0.0])
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_reducer(num_senders: int, num_parts: int, synchronize_prob: float):
+    tensor_part_shapes = [torch.Size([i]) for i in range(num_parts)]
+    reducer = TensorPartReducer(tensor_part_shapes, num_senders)
+
+    local_tensors_by_sender = [[torch.randn(i) for i in range(num_parts)]
+                               for j in range(num_senders)]
+
+    async def send_tensors(sender_index: int):
+        local_tensors = local_tensors_by_sender[sender_index]
+        averaged_parts = []
+        pending_tasks = []
+
+        for part_index in range(num_parts):
+            pending_tasks.append(asyncio.create_task(
+                reducer.accumulate_part(sender_index, part_index, local_tensors[part_index])))
+
+            if random.random() < synchronize_prob or part_index == num_parts - 1:
+                averaged_parts.extend(await asyncio.gather(*pending_tasks))
+                pending_tasks = []
+        return averaged_parts
+
+    averaged_tensors_by_peer = await asyncio.gather(*map(send_tensors, range(num_senders)))
+
+    reference = [sum(local_tensors_by_sender[sender_index][part_index]
+                     for sender_index in range(num_senders)) / num_senders
+                 for part_index in range(num_parts)]
+
+    for averaged_tensors in averaged_tensors_by_peer:
+        assert len(averaged_tensors) == len(reference)
+        for averaging_result, reference_tensor in zip(averaged_tensors, reference):
+            assert torch.allclose(averaging_result, reference_tensor, rtol=1e-3, atol=1e-5)
+
+
+class AllreduceRunnerForTesting(AllReduceRunner):
+    """ a version of AllReduceRunner that was monkey-patched to accept custom endpoint names """
+    def __init__(self, *args, peer_endpoints, **kwargs):
+        self.__peer_endpoints = peer_endpoints
+        super().__init__(*args, **kwargs)
+
+    def _get_peer_stub(self, peer: Endpoint) -> averaging_pb2_grpc.DecentralizedAveragingStub:
+        return ChannelCache.get_stub(
+            self.__peer_endpoints[peer], averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
+
+
+NODE, CLIENT, AUX = AveragingMode.NODE, AveragingMode.CLIENT, AveragingMode.AUX
+
+
+@pytest.mark.parametrize("peer_modes, averaging_weights, peer_fractions", [
+    ((NODE, NODE, NODE, NODE), (1, 1, 1, 1), (1, 1, 1, 1)),
+    ((NODE, NODE, NODE, NODE), (0.1, 0.2, 0.3, 0.4), (1, 1, 1, 1)),
+    ((NODE, NODE, NODE, NODE), (1, 1, 1, 1), (1, 2, 3, 0)),
+    ((NODE, NODE, NODE, CLIENT), (1, 1, 1, 1), (1, 2, 3, 0)),
+    ((NODE, NODE, NODE, AUX), (1, 1, 1, 0), (1, 2, 3, 4)),
+    ((NODE, NODE, NODE, NODE), (0.15, 0.0, 0.35, 0.45), (1, 1, 1, 1)),
+    ((NODE, AUX, NODE, CLIENT), (0.15, 0.0, 0.35, 0.45), (150, 200, 67, 0)),
+    ((AUX, AUX, AUX, AUX), (0.0, 0.0, 0.0, 0.0), (1, 2, 3, 4)),
+])
+@pytest.mark.parametrize("part_size_bytes", [2 ** 20, 256, 19],)
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions, part_size_bytes):
+    """ Run group allreduce protocol manually without grpc, see if the internal logic is working as intended """
+
+    peers = "alice", "bob", "carol", "colab"
+
+    tensors_by_peer = {peer: [torch.randn(3, 128), torch.rand(32), torch.tensor(i, dtype=torch.float32)]
+                       for i, peer in enumerate(peers)}
+
+    group_id = random.getrandbits(160).to_bytes(length=20, byteorder='big')
+
+    servers = []
+    allreduce_protocols = []
+    peer_endpoints = {}
+
+    for peer in peers:
+        server = grpc.aio.server()
+        allreduce_protocol = AllreduceRunnerForTesting(
+            group_id=group_id, endpoint=peer, tensors=[x.clone() for x in tensors_by_peer[peer]],
+            ordered_group_endpoints=peers, peer_fractions=peer_fractions, modes=peer_modes,
+            weights=averaging_weights, peer_endpoints=peer_endpoints, part_size_bytes=part_size_bytes
+        )
+        averaging_pb2_grpc.add_DecentralizedAveragingServicer_to_server(allreduce_protocol, server)
+        peer_endpoints[peer] = f"127.0.0.1:{server.add_insecure_port('127.0.0.1:*')}"
+        allreduce_protocols.append(allreduce_protocol)
+        servers.append(server)
+        await server.start()
+
+    async def _run_allreduce_inplace(allreduce: AllReduceRunner):
+        async for tensor_index, tensor_delta in aenumerate(allreduce):
+            allreduce.tensor_part_container.local_tensors[tensor_index].add_(tensor_delta)
+
+    await asyncio.gather(*map(_run_allreduce_inplace, allreduce_protocols))
+
+    reference_tensors = [sum(tensors_by_peer[peer][i] * averaging_weights[peer_index]
+                             for peer_index, peer in enumerate(peers)) / sum(averaging_weights)
+                         for i in range(len(tensors_by_peer[peers[0]]))]
+
+    for peer_index, protocol in enumerate(allreduce_protocols):
+        assert protocol._future.done()
+        if protocol.modes[peer_index] != AveragingMode.AUX:
+            targets_for_peer = reference_tensors
+        else:
+            targets_for_peer = tensors_by_peer[peers[peer_index]]
+        output_tensors = protocol.tensor_part_container.local_tensors
+        assert len(output_tensors) == len(targets_for_peer)
+        assert all(torch.allclose(our, ref, atol=1e-6, rtol=0)
+                   for our, ref in zip(output_tensors, targets_for_peer))
+
+    for server in servers:
+        await server.stop(grace=1)
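
The reference values asserted in test_allreduce_protocol reduce to a per-tensor weighted average of the peers' inputs, with auxiliary peers contributing weight zero. A minimal standalone sketch of that computation (plain PyTorch; the helper name expected_weighted_average is illustrative and not part of hivemind):

import torch

def expected_weighted_average(tensors_by_peer, weights):
    # average each tensor position across peers, weighting every peer by its averaging weight
    total_weight = sum(weights)
    num_tensors = len(next(iter(tensors_by_peer.values())))
    return [sum(tensors[i] * w for tensors, w in zip(tensors_by_peer.values(), weights)) / total_weight
            for i in range(num_tensors)]

peers = ("alice", "bob", "carol")
tensors_by_peer = {peer: [torch.randn(3, 128), torch.rand(32)] for peer in peers}
reference = expected_weighted_average(tensors_by_peer, weights=(0.2, 0.3, 0.5))
assert len(reference) == 2 and reference[0].shape == (3, 128)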

+ 2 - 2
tests/test_auth.py

@@ -1,12 +1,12 @@
 from datetime import datetime, timedelta
-from typing import Optional, Tuple
+from typing import Optional

 import pytest

 from hivemind.proto import dht_pb2
 from hivemind.proto.auth_pb2 import AccessToken
 from hivemind.utils.auth import AuthRPCWrapper, AuthRole, TokenAuthorizerBase
-from hivemind.utils.crypto import RSAPrivateKey, RSAPublicKey
+from hivemind.utils.crypto import RSAPrivateKey
 from hivemind.utils.logging import get_logger



+ 76 - 77
tests/test_averaging.py

@@ -1,4 +1,3 @@
-import asyncio
 import random

 import numpy as np
@@ -6,10 +5,10 @@ import torch
 import pytest
 import time
 import hivemind
-from hivemind.client.averaging.allreduce import AllReduceProtocol, split_into_parts, restore_from_parts
+from hivemind.client.averaging.allreduce import AveragingMode
 from hivemind.client.averaging.load_balancing import load_balance_peers
 from hivemind.client.averaging.key_manager import GroupKeyManager
-from hivemind.utils import Endpoint
+from hivemind.proto.runtime_pb2 import CompressionType


 @pytest.mark.forked
@@ -42,26 +41,26 @@ async def test_key_manager():
     assert len(q5) == 0


-@pytest.mark.forked
-@pytest.mark.parametrize("n_client_mode_peers", [0, 2])
-def test_allreduce_once(n_client_mode_peers):
+def _test_allreduce_once(n_clients, n_aux):
     dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')

     n_peers = 4
-    should_listen = [False] * n_client_mode_peers + [True] * (n_peers - n_client_mode_peers)
-    random.shuffle(should_listen)
+    modes = [AveragingMode.CLIENT] * n_clients + [AveragingMode.AUX] * n_aux + [AveragingMode.NODE] * (n_peers - n_clients - n_aux)
+    random.shuffle(modes)

     tensors1 = [torch.randn(123), torch.zeros(3)]
     tensors2 = [torch.rand(123), torch.ones(3)]
     tensors3 = [-torch.rand(123), torch.arange(3).to(torch.float32)]
     tensors4 = [torch.randn(123) ** 3, torch.arange(3).to(torch.float32) / 2]
+    peer_tensors = [tensors1, tensors2, tensors3, tensors4]

-    reference = [(tensors1[i] + tensors2[i] + tensors3[i] + tensors4[i]) / 4 for i in range(len(tensors1))]
+    reference = [sum(tensors[i] for tensors, mode in zip(peer_tensors, modes)
+                 if mode != AveragingMode.AUX) / max(1, n_peers - n_aux) for i in range(len(tensors1))]

     averagers = [hivemind.DecentralizedAverager(tensors, dht=dht, target_group_size=4, averaging_expiration=15,
-                                                prefix='mygroup', listen=listen, listen_on='127.0.0.1:*',
-                                                start=True)
-                 for tensors, listen in zip([tensors1, tensors2, tensors3, tensors4], should_listen)]
+                                                prefix='mygroup', listen=mode != AveragingMode.CLIENT, listen_on='127.0.0.1:*',
+                                                auxiliary=mode == AveragingMode.AUX, start=True)
+                 for tensors, mode in zip(peer_tensors, modes)]

     futures = []
     for averager in averagers:
@@ -72,15 +71,29 @@ def test_allreduce_once(n_client_mode_peers):
             assert averager.endpoint in result

     for averager in averagers:
-        with averager.get_tensors() as averaged_tensors:
-            for ref, our in zip(reference, averaged_tensors):
-                assert torch.allclose(ref, our, atol=1e-6)
+        if averager.mode != AveragingMode.AUX:
+            with averager.get_tensors() as averaged_tensors:
+                for ref, our in zip(reference, averaged_tensors):
+                    assert torch.allclose(ref, our, atol=1e-6)

     for averager in averagers:
         averager.shutdown()
     dht.shutdown()


+@pytest.mark.forked
+@pytest.mark.parametrize("n_clients", [0, 1, 2])
+@pytest.mark.parametrize("n_aux", [0, 1, 2])
+def test_allreduce_once(n_clients, n_aux):
+    _test_allreduce_once(n_clients, n_aux)
+
+
+@pytest.mark.forked
+@pytest.mark.parametrize("n_clients, n_aux", [(0, 4), (1, 3), (0, 3)])
+def test_allreduce_once_edge_cases(n_clients, n_aux):
+    _test_allreduce_once(n_clients, n_aux)
+
+
 @pytest.mark.forked
 def test_allreduce_weighted(n_client_mode_peers: int = 2):
     dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')
@@ -117,6 +130,47 @@ def test_allreduce_weighted(n_client_mode_peers: int = 2):
     dht.shutdown()


+@pytest.mark.forked
+def test_allreduce_compression():
+    """ this test ensures that compression works correctly when multiple tensors have different compression types """
+    dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')
+
+    tensors1 = [torch.linspace(0, 500, 1000) ** 0.5, torch.randn(1000)]
+    tensors2 = [torch.linspace(300, 800, 1000) ** 0.5, torch.randn(1000)]
+    results = {}
+
+    FLOAT16, UINT8 = CompressionType.FLOAT16, CompressionType.UNIFORM_8BIT
+
+    for compression_type_pair in [(FLOAT16, FLOAT16), (FLOAT16, UINT8), (UINT8, FLOAT16), (UINT8, UINT8)]:
+        averager1 = hivemind.DecentralizedAverager([x.clone() for x in tensors1], dht=dht,
+                                                   compression_type=compression_type_pair, listen=False,
+                                                   target_group_size=2, prefix='mygroup', start=True)
+        averager2 = hivemind.DecentralizedAverager([x.clone() for x in tensors2], dht=dht,
+                                                   compression_type=compression_type_pair,
+                                                   target_group_size=2, prefix='mygroup', start=True)
+
+        for future in averager1.step(wait=False), averager2.step(wait=False):
+            future.result()
+
+        with averager1.get_tensors() as averaged_tensors:
+            results[compression_type_pair] = averaged_tensors
+
+    assert torch.allclose(results[UINT8, FLOAT16][0], results[UINT8, UINT8][0])
+    assert torch.allclose(results[UINT8, FLOAT16][1], results[FLOAT16, FLOAT16][1])
+    assert torch.allclose(results[UINT8, UINT8][1], results[FLOAT16, UINT8][1])
+    assert torch.allclose(results[FLOAT16, UINT8][0], results[FLOAT16, FLOAT16][0])
+
+    assert not torch.allclose(results[UINT8, FLOAT16][1], results[UINT8, UINT8][1])
+    assert not torch.allclose(results[UINT8, FLOAT16][0], results[FLOAT16, FLOAT16][0])
+    assert not torch.allclose(results[UINT8, UINT8][0], results[FLOAT16, UINT8][0])
+    assert not torch.allclose(results[FLOAT16, UINT8][1], results[FLOAT16, FLOAT16][1])
+
+    reference = [(tensors1[i] + tensors2[i]) / 2 for i in range(len(tensors1))]
+    for i in range(2):
+        assert 0 < torch.mean(torch.square(results[FLOAT16, FLOAT16][i] - reference[i])).item() <= 1e-5
+        assert 1e-5 < torch.mean(torch.square(results[UINT8, UINT8][i] - reference[i])).item() <= 1e-2
+
+
 def compute_mean_std(averagers, unbiased=True):
     results = []
     for averager in averagers:
@@ -188,68 +242,6 @@ def test_allgather():
     dht.shutdown()


-@pytest.mark.forked
-@pytest.mark.asyncio
-async def test_allreduce_protocol():
-    """ Run group allreduce protocol manually without grpc, see if the internal logic is working as intended """
-    peers = "alice", "bob", "carol", "colab"
-
-    tensors_by_peer = {peer: [torch.randn(3, 128), torch.rand(32), torch.tensor(i, dtype=torch.float32)]
-                       for i, peer in enumerate(peers)}
-
-    group_id = random.getrandbits(160).to_bytes(length=20, byteorder='big')
-    allreduce_protocols = [AllReduceProtocol(
-        group_id=group_id, endpoint=peer, tensors=tensors_by_peer[peer],
-        ordered_group_endpoints=peers, part_sizes=(150, 200, 67, 0))
-        for peer in peers]
-
-    async def _accumulate(sender: Endpoint, recipient: Endpoint):
-        sender_allreduce = allreduce_protocols[peers.index(sender)]
-        recipient_allreduce = allreduce_protocols[peers.index(recipient)]
-        averaged_part = await recipient_allreduce.accumulate_part(
-            source=sender, remote_part=sender_allreduce.local_tensor_parts[recipient])
-        sender_allreduce.register_averaged_part(source=recipient, averaged_part=averaged_part)
-
-    await asyncio.wait({_accumulate(sender, recipient) for sender in peers for recipient in peers
-                        if recipient != "colab"})
-
-    reference_tensors = [
-        sum(tensors_by_peer[peer][i] for peer in peers) / len(peers)
-        for i in range(len(tensors_by_peer[peers[0]]))
-    ]
-
-    for peer, allreduce in zip(peers, allreduce_protocols):
-        assert allreduce.future.done()
-        averaged_tensors = await allreduce
-        assert len(averaged_tensors) == len(reference_tensors)
-        assert all(torch.allclose(our, ref, atol=1e-6, rtol=0)
-                   for our, ref in zip(averaged_tensors, reference_tensors))
-
-
-@pytest.mark.forked
-def test_partitioning():
-    for _ in range(100):
-        tensors = []
-        for _ in range(random.randint(1, 5)):
-            ndim = random.randint(0, 4)
-            shape = torch.Size([random.randint(0, 16) for _ in range(ndim)])
-            make_tensor = random.choice([torch.rand, torch.randn, torch.zeros, torch.ones])
-            tensors.append(make_tensor(shape))
-
-        total_size = sum(map(torch.Tensor.numel, tensors))
-        if total_size == 0:
-            continue
-        num_chunks = random.randint(1, min(100, sum(x.numel() for x in tensors)))
-        part_sizes = load_balance_peers(total_size, [None] * num_chunks)
-        chunks = split_into_parts(tensors, part_sizes)
-        assert len(chunks) == num_chunks
-        shapes = [tensor.shape for tensor in tensors]
-        restored = restore_from_parts(chunks, shapes)
-        assert len(restored) == len(tensors)
-        assert all(new.shape == old.shape for new, old in zip(restored, tensors))
-        assert all(torch.allclose(new, old) for new, old in zip(restored, tensors))
-
-
 def get_cost(vector_size, partitions, throughputs):
     return max((vector_size - partitions[i] + (len(partitions) - 1) * partitions[i]) / max(throughputs[i], 1e-9)
                for i in range(len(partitions)))
@@ -370,6 +362,13 @@ def test_load_state_from_peers():
     assert got_metadata == super_metadata
     assert all(map(torch.allclose, got_tensors, super_tensors))

+    averager1.allow_state_sharing = False
+    assert averager2.load_state_from_peers() is None
+    averager1.allow_state_sharing = True
+    got_metadata, got_tensors = averager2.load_state_from_peers()
+    assert num_calls == 3
+    assert got_metadata == super_metadata
+

 @pytest.mark.forked
 def test_getset_bits():

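
The error bounds asserted in test_allreduce_compression can be sanity-checked by round-tripping a tensor through the same compression utilities used above; the sketch below only rearranges imports already present in this diff, and the printed errors are indicative rather than exact since they depend on the input:

import torch
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor

x = torch.linspace(0, 500, 1000) ** 0.5
for compression_type in (CompressionType.FLOAT16, CompressionType.UNIFORM_8BIT):
    restored = deserialize_torch_tensor(serialize_torch_tensor(x, compression_type))
    mse = torch.mean(torch.square(restored - x)).item()
    print(compression_type, mse)  # FLOAT16 round-trips with a much smaller error than UNIFORM_8BIT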
+ 2 - 4
tests/test_dht_schema.py

@@ -1,13 +1,11 @@
-import re
-
 import pytest
-from pydantic import BaseModel, StrictFloat, StrictInt, conint
+from pydantic import BaseModel, StrictInt, conint
 from typing import Dict

 import hivemind
 from hivemind.dht import get_dht_time
 from hivemind.dht.node import DHTNode, LOCALHOST
-from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator, conbytes
+from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
 from hivemind.dht.validation import DHTRecord, RecordValidatorBase



+ 1 - 2
tests/test_dht_validation.py

@@ -1,5 +1,4 @@
 import dataclasses
-from functools import partial
 from typing import Dict

 import pytest
@@ -10,7 +9,7 @@ from hivemind.dht.crypto import RSASignatureValidator
 from hivemind.dht.protocol import DHTProtocol
 from hivemind.dht.routing import DHTID
 from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
-from hivemind.dht.validation import DHTRecord, CompositeValidator, RecordValidatorBase
+from hivemind.dht.validation import DHTRecord, CompositeValidator


 class SchemaA(BaseModel):

+ 440 - 0
tests/test_p2p_daemon.py

@@ -0,0 +1,440 @@
+import asyncio
+import multiprocessing as mp
+import subprocess
+from functools import partial
+from typing import List
+
+import numpy as np
+import pytest
+import torch
+
+from hivemind.p2p import P2P
+from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID
+from hivemind.proto import dht_pb2, runtime_pb2
+from hivemind.utils import MSGPackSerializer
+from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
+
+
+def is_process_running(pid: int) -> bool:
+    return subprocess.run(["ps", "-p", str(pid)], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL).returncode == 0
+
+
+async def replicate_if_needed(p2p: P2P, replicate: bool):
+    return await P2P.replicate(p2p._daemon_listen_port, p2p._host_port) if replicate else p2p
+
+
+def bootstrap_addr(host_port, id_):
+    return f'/ip4/127.0.0.1/tcp/{host_port}/p2p/{id_}'
+
+
+def bootstrap_from(daemons: List[P2P]) -> List[str]:
+    return [bootstrap_addr(d._host_port, d.id) for d in daemons]
+
+
+@pytest.mark.asyncio
+async def test_daemon_killed_on_del():
+    p2p_daemon = await P2P.create()
+
+    child_pid = p2p_daemon._child.pid
+    assert is_process_running(child_pid)
+
+    await p2p_daemon.shutdown()
+    assert not is_process_running(child_pid)
+
+
+@pytest.mark.asyncio
+async def test_server_client_connection():
+    server = await P2P.create()
+    peers = await server._client.list_peers()
+    assert len(peers) == 0
+
+    nodes = bootstrap_from([server])
+    client = await P2P.create(bootstrap=True, bootstrap_peers=nodes)
+    await client.wait_for_at_least_n_peers(1)
+
+    peers = await client._client.list_peers()
+    assert len(peers) == 1
+    peers = await server._client.list_peers()
+    assert len(peers) == 1
+
+
+@pytest.mark.asyncio
+async def test_daemon_replica_does_not_affect_primary():
+    p2p_daemon = await P2P.create()
+    p2p_replica = await P2P.replicate(p2p_daemon._daemon_listen_port, p2p_daemon._host_port)
+
+    child_pid = p2p_daemon._child.pid
+    assert is_process_running(child_pid)
+
+    await p2p_replica.shutdown()
+    assert is_process_running(child_pid)
+
+    await p2p_daemon.shutdown()
+    assert not is_process_running(child_pid)
+
+
+def handle_square(x):
+    x = MSGPackSerializer.loads(x)
+    return MSGPackSerializer.dumps(x ** 2)
+
+
+def handle_add(args):
+    args = MSGPackSerializer.loads(args)
+    result = args[0]
+    for i in range(1, len(args)):
+        result = result + args[i]
+    return MSGPackSerializer.dumps(result)
+
+
+def handle_square_torch(x):
+    tensor = runtime_pb2.Tensor()
+    tensor.ParseFromString(x)
+    tensor = deserialize_torch_tensor(tensor)
+    result = tensor ** 2
+    return serialize_torch_tensor(result).SerializeToString()
+
+
+def handle_add_torch(args):
+    args = MSGPackSerializer.loads(args)
+    tensor = runtime_pb2.Tensor()
+    tensor.ParseFromString(args[0])
+    result = deserialize_torch_tensor(tensor)
+
+    for i in range(1, len(args)):
+        tensor = runtime_pb2.Tensor()
+        tensor.ParseFromString(args[i])
+        result = result + deserialize_torch_tensor(tensor)
+
+    return serialize_torch_tensor(result).SerializeToString()
+
+
+def handle_add_torch_with_exc(args):
+    try:
+        return handle_add_torch(args)
+    except Exception:
+        return b'something went wrong :('
+
+
+@pytest.mark.parametrize(
+    'should_cancel,replicate', [
+        (True, False),
+        (True, True),
+        (False, False),
+        (False, True),
+    ]
+)
+@pytest.mark.asyncio
+async def test_call_unary_handler(should_cancel, replicate, handle_name="handle"):
+    handler_cancelled = False
+
+    async def ping_handler(request, context):
+        try:
+            await asyncio.sleep(2)
+        except asyncio.CancelledError:
+            nonlocal handler_cancelled
+            handler_cancelled = True
+        return dht_pb2.PingResponse(
+            peer=dht_pb2.NodeInfo(
+                node_id=context.id.encode(), rpc_port=context.port),
+            sender_endpoint=context.handle_name, available=True)
+
+    server_primary = await P2P.create()
+    server = await replicate_if_needed(server_primary, replicate)
+    server_pid = server_primary._child.pid
+    await server.add_unary_handler(handle_name, ping_handler, dht_pb2.PingRequest,
+                                   dht_pb2.PingResponse)
+    assert is_process_running(server_pid)
+
+    nodes = bootstrap_from([server])
+    client_primary = await P2P.create(bootstrap=True, bootstrap_peers=nodes)
+    client = await replicate_if_needed(client_primary, replicate)
+    client_pid = client_primary._child.pid
+    assert is_process_running(client_pid)
+
+    ping_request = dht_pb2.PingRequest(
+        peer=dht_pb2.NodeInfo(node_id=client.id.encode(), rpc_port=client._host_port),
+        validate=True)
+    expected_response = dht_pb2.PingResponse(
+        peer=dht_pb2.NodeInfo(node_id=server.id.encode(), rpc_port=server._host_port),
+        sender_endpoint=handle_name, available=True)
+
+    await client.wait_for_at_least_n_peers(1)
+    libp2p_server_id = PeerID.from_base58(server.id)
+    stream_info, reader, writer = await client._client.stream_open(libp2p_server_id, (handle_name,))
+
+    await P2P.send_protobuf(ping_request, dht_pb2.PingRequest, writer)
+
+    if should_cancel:
+        writer.close()
+        await asyncio.sleep(1)
+        assert handler_cancelled
+    else:
+        result, err = await P2P.receive_protobuf(dht_pb2.PingResponse, reader)
+        assert err is None
+        assert result == expected_response
+        assert not handler_cancelled
+
+    await server.stop_listening()
+    await server_primary.shutdown()
+    assert not is_process_running(server_pid)
+
+    await client_primary.shutdown()
+    assert not is_process_running(client_pid)
+
+
+@pytest.mark.asyncio
+async def test_call_unary_handler_error(handle_name="handle"):
+    async def error_handler(request, context):
+        raise ValueError('boom')
+
+    server = await P2P.create()
+    server_pid = server._child.pid
+    await server.add_unary_handler(handle_name, error_handler, dht_pb2.PingRequest, dht_pb2.PingResponse)
+    assert is_process_running(server_pid)
+
+    nodes = bootstrap_from([server])
+    client = await P2P.create(bootstrap=True, bootstrap_peers=nodes)
+    client_pid = client._child.pid
+    assert is_process_running(client_pid)
+    await client.wait_for_at_least_n_peers(1)
+
+    ping_request = dht_pb2.PingRequest(
+        peer=dht_pb2.NodeInfo(node_id=client.id.encode(), rpc_port=client._host_port),
+        validate=True)
+    libp2p_server_id = PeerID.from_base58(server.id)
+    stream_info, reader, writer = await client._client.stream_open(libp2p_server_id, (handle_name,))
+
+    await P2P.send_protobuf(ping_request, dht_pb2.PingRequest, writer)
+    result, err = await P2P.receive_protobuf(dht_pb2.PingResponse, reader)
+    assert result is None
+    assert err.message == 'boom'
+
+    await server.stop_listening()
+    await server.shutdown()
+    await client.shutdown()
+
+
+@pytest.mark.parametrize(
+    "test_input,expected,handle",
+    [
+        pytest.param(10, 100, handle_square, id="square_integer"),
+        pytest.param((1, 2), 3, handle_add, id="add_integers"),
+        pytest.param(([1, 2, 3], [12, 13]), [1, 2, 3, 12, 13], handle_add, id="add_lists"),
+        pytest.param(2, 8, lambda x: MSGPackSerializer.dumps(MSGPackSerializer.loads(x) ** 3), id="lambda")
+    ]
+)
+@pytest.mark.asyncio
+async def test_call_peer_single_process(test_input, expected, handle, handler_name="handle"):
+    server = await P2P.create()
+    server_pid = server._child.pid
+    await server.add_stream_handler(handler_name, handle)
+    assert is_process_running(server_pid)
+
+    nodes = bootstrap_from([server])
+    client = await P2P.create(bootstrap=True, bootstrap_peers=nodes)
+    client_pid = client._child.pid
+    assert is_process_running(client_pid)
+
+    await client.wait_for_at_least_n_peers(1)
+
+    test_input_msgp = MSGPackSerializer.dumps(test_input)
+    result_msgp = await client.call_peer_handler(server.id, handler_name, test_input_msgp)
+    result = MSGPackSerializer.loads(result_msgp)
+    assert result == expected
+
+    await server.stop_listening()
+    await server.shutdown()
+    assert not is_process_running(server_pid)
+
+    await client.shutdown()
+    assert not is_process_running(client_pid)
+
+
+async def run_server(handler_name, server_side, client_side, response_received):
+    server = await P2P.create()
+    server_pid = server._child.pid
+    await server.add_stream_handler(handler_name, handle_square)
+    assert is_process_running(server_pid)
+
+    server_side.send(server.id)
+    server_side.send(server._host_port)
+    while response_received.value == 0:
+        await asyncio.sleep(0.5)
+
+    await server.stop_listening()
+    await server.shutdown()
+    assert not is_process_running(server_pid)
+
+
+def server_target(handler_name, server_side, client_side, response_received):
+    asyncio.run(run_server(handler_name, server_side, client_side, response_received))
+
+
+@pytest.mark.asyncio
+async def test_call_peer_different_processes():
+    handler_name = "square"
+    test_input = 2
+
+    server_side, client_side = mp.Pipe()
+    response_received = mp.Value(np.ctypeslib.as_ctypes_type(np.int32))
+    response_received.value = 0
+
+    proc = mp.Process(target=server_target, args=(handler_name, server_side, client_side, response_received))
+    proc.start()
+
+    peer_id = client_side.recv()
+    peer_port = client_side.recv()
+
+    nodes = [bootstrap_addr(peer_port, peer_id)]
+    client = await P2P.create(bootstrap=True, bootstrap_peers=nodes)
+    client_pid = client._child.pid
+    assert is_process_running(client_pid)
+
+    await client.wait_for_at_least_n_peers(1)
+
+    test_input_msgp = MSGPackSerializer.dumps(2)
+    result_msgp = await client.call_peer_handler(peer_id, handler_name, test_input_msgp)
+    result = MSGPackSerializer.loads(result_msgp)
+    assert np.allclose(result, test_input ** 2)
+    response_received.value = 1
+
+    await client.shutdown()
+    assert not is_process_running(client_pid)
+
+    proc.join()
+
+
+@pytest.mark.parametrize(
+    "test_input,expected",
+    [
+        pytest.param(torch.tensor([2]), torch.tensor(4)),
+        pytest.param(
+            torch.tensor([[1.0, 2.0], [0.5, 0.1]]),
+            torch.tensor([[1.0, 2.0], [0.5, 0.1]]) ** 2),
+    ]
+)
+@pytest.mark.asyncio
+async def test_call_peer_torch_square(test_input, expected, handler_name="handle"):
+    handle = handle_square_torch
+    server = await P2P.create()
+    await server.add_stream_handler(handler_name, handle)
+
+    nodes = bootstrap_from([server])
+    client = await P2P.create(bootstrap=True, bootstrap_peers=nodes)
+
+    await client.wait_for_at_least_n_peers(1)
+
+    inp = serialize_torch_tensor(test_input).SerializeToString()
+    result_pb = await client.call_peer_handler(server.id, handler_name, inp)
+    result = runtime_pb2.Tensor()
+    result.ParseFromString(result_pb)
+    result = deserialize_torch_tensor(result)
+    assert torch.allclose(result, expected)
+
+    await server.stop_listening()
+    await server.shutdown()
+    await client.shutdown()
+
+
+@pytest.mark.parametrize(
+    "test_input,expected",
+    [
+        pytest.param([torch.tensor([1]), torch.tensor([2])], torch.tensor([3])),
+        pytest.param(
+            [torch.tensor([[0.1, 0.2], [0.3, 0.4]]), torch.tensor([[1.1, 1.2], [1.3, 1.4]])],
+            torch.tensor([[1.2, 1.4], [1.6, 1.8]])),
+    ]
+)
+@pytest.mark.asyncio
+async def test_call_peer_torch_add(test_input, expected, handler_name="handle"):
+    handle = handle_add_torch
+    server = await P2P.create()
+    await server.add_stream_handler(handler_name, handle)
+
+    nodes = bootstrap_from([server])
+    client = await P2P.create(bootstrap=True, bootstrap_peers=nodes)
+
+    await client.wait_for_at_least_n_peers(1)
+
+    inp = [serialize_torch_tensor(i).SerializeToString() for i in test_input]
+    inp_msgp = MSGPackSerializer.dumps(inp)
+    result_pb = await client.call_peer_handler(server.id, handler_name, inp_msgp)
+    result = runtime_pb2.Tensor()
+    result.ParseFromString(result_pb)
+    result = deserialize_torch_tensor(result)
+    assert torch.allclose(result, expected)
+
+    await server.stop_listening()
+    await server.shutdown()
+    await client.shutdown()
+
+
+@pytest.mark.parametrize(
+    "replicate",
+    [
+        pytest.param(False, id="primary"),
+        pytest.param(True, id="replica"),
+    ]
+)
+@pytest.mark.asyncio
+async def test_call_peer_error(replicate, handler_name="handle"):
+    server_primary = await P2P.create()
+    server = await replicate_if_needed(server_primary, replicate)
+    await server.add_stream_handler(handler_name, handle_add_torch_with_exc)
+
+    nodes = bootstrap_from([server])
+    client_primary = await P2P.create(bootstrap=True, bootstrap_peers=nodes)
+    client = await replicate_if_needed(client_primary, replicate)
+
+    await client.wait_for_at_least_n_peers(1)
+
+    inp = [serialize_torch_tensor(i).SerializeToString() for i in [torch.zeros((2, 3)), torch.zeros((3, 2))]]
+    inp_msgp = MSGPackSerializer.dumps(inp)
+    result = await client.call_peer_handler(server.id, handler_name, inp_msgp)
+    assert result == b'something went wrong :('
+
+    await server.stop_listening()
+    await server_primary.shutdown()
+    await client_primary.shutdown()
+
+
+@pytest.mark.asyncio
+async def test_handlers_on_different_replicas(handler_name="handle"):
+    def handler(arg, key):
+        return key
+
+    server_primary = await P2P.create(bootstrap=False)
+    server_id = server_primary.id
+    await server_primary.add_stream_handler(handler_name, partial(handler, key=b'primary'))
+
+    server_replica1 = await replicate_if_needed(server_primary, True)
+    await server_replica1.add_stream_handler(handler_name + '1', partial(handler, key=b'replica1'))
+
+    server_replica2 = await replicate_if_needed(server_primary, True)
+    await server_replica2.add_stream_handler(handler_name + '2', partial(handler, key=b'replica2'))
+
+    nodes = bootstrap_from([server_primary])
+    client = await P2P.create(bootstrap=True, bootstrap_peers=nodes)
+    await client.wait_for_at_least_n_peers(1)
+
+    result = await client.call_peer_handler(server_id, handler_name, b'1')
+    assert result == b"primary"
+
+    result = await client.call_peer_handler(server_id, handler_name + '1', b'2')
+    assert result == b"replica1"
+
+    result = await client.call_peer_handler(server_id, handler_name + '2', b'3')
+    assert result == b"replica2"
+
+    await server_replica1.stop_listening()
+    await server_replica2.stop_listening()
+
+    # The primary does not handle the replicas' protocols
+    with pytest.raises(Exception):
+        await client.call_peer_handler(server_id, handler_name + '1', b'')
+    with pytest.raises(Exception):
+        await client.call_peer_handler(server_id, handler_name + '2', b'')
+
+    await server_primary.stop_listening()
+    await server_primary.shutdown()
+    await client.shutdown()

+ 559 - 0
tests/test_p2p_daemon_bindings.py

@@ -0,0 +1,559 @@
+import asyncio
+import io
+from contextlib import AsyncExitStack
+
+import pytest
+from google.protobuf.message import EncodeError
+from multiaddr import Multiaddr, protocols
+
+from hivemind.p2p.p2p_daemon_bindings.control import ControlClient, DaemonConnector, parse_conn_protocol
+from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
+from hivemind.p2p.p2p_daemon_bindings.utils import (ControlFailure, raise_if_failed, read_pbmsg_safe,
+                                                    read_unsigned_varint, write_pbmsg, write_unsigned_varint)
+from hivemind.proto import p2pd_pb2 as p2pd_pb
+from test_utils import make_p2pd_pair_ip4, connect_safe
+
+
+def test_raise_if_failed_raises():
+    resp = p2pd_pb.Response()
+    resp.type = p2pd_pb.Response.ERROR
+    with pytest.raises(ControlFailure):
+        raise_if_failed(resp)
+
+
+def test_raise_if_failed_not_raises():
+    resp = p2pd_pb.Response()
+    resp.type = p2pd_pb.Response.OK
+    raise_if_failed(resp)
+
+
+PAIRS_INT_SERIALIZED_VALID = (
+    (0, b"\x00"),
+    (1, b"\x01"),
+    (128, b"\x80\x01"),
+    (2 ** 32, b"\x80\x80\x80\x80\x10"),
+    (2 ** 64 - 1, b"\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01"),
+)
+
+PAIRS_INT_SERIALIZED_OVERFLOW = (
+    (2 ** 64, b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x02"),
+    (2 ** 64 + 1, b"\x81\x80\x80\x80\x80\x80\x80\x80\x80\x02"),
+    (
+        2 ** 128,
+        b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x04",
+    ),
+)
+
+PEER_ID_STRING = "QmS5QmciTXXnCUCyxud5eWFenUMAmvAWSDa1c7dvdXRMZ7"
+PEER_ID_BYTES = b'\x12 7\x87F.[\xb5\xb1o\xe5*\xc7\xb9\xbb\x11:"Z|j2\x8ad\x1b\xa6\xe5<Ip\xfe\xb4\xf5v'
+PEER_ID = PeerID(PEER_ID_BYTES)
+MADDR = Multiaddr("/unix/123")
+NUM_P2PDS = 4
+PEER_ID_RANDOM = PeerID.from_base58("QmcgpsyWgH8Y8ajJz1Cu72KnS5uo2Aa2LpzU7kinSupNK1")
+ENABLE_CONTROL = True
+ENABLE_CONNMGR = False
+ENABLE_DHT = False
+ENABLE_PUBSUB = False
+FUNC_MAKE_P2PD_PAIR = make_p2pd_pair_ip4
+
+
+class MockReader(io.BytesIO):
+    async def readexactly(self, n):
+        await asyncio.sleep(0)
+        return self.read(n)
+
+
+class MockWriter(io.BytesIO):
+    pass
+
+
+class MockReaderWriter(MockReader, MockWriter):
+    pass
+
+
+@pytest.mark.parametrize("integer, serialized_integer", PAIRS_INT_SERIALIZED_VALID)
+@pytest.mark.asyncio
+async def test_write_unsigned_varint(integer, serialized_integer):
+    s = MockWriter()
+    await write_unsigned_varint(s, integer)
+    assert s.getvalue() == serialized_integer
+
+
+@pytest.mark.parametrize("integer", tuple(i[0] for i in PAIRS_INT_SERIALIZED_OVERFLOW))
+@pytest.mark.asyncio
+async def test_write_unsigned_varint_overflow(integer):
+    s = MockWriter()
+    with pytest.raises(ValueError):
+        await write_unsigned_varint(s, integer)
+
+
+@pytest.mark.parametrize("integer", (-1, -(2 ** 32), -(2 ** 64), -(2 ** 128)))
+@pytest.mark.asyncio
+async def test_write_unsigned_varint_negative(integer):
+    s = MockWriter()
+    with pytest.raises(ValueError):
+        await write_unsigned_varint(s, integer)
+
+
+@pytest.mark.parametrize("integer, serialized_integer", PAIRS_INT_SERIALIZED_VALID)
+@pytest.mark.asyncio
+async def test_read_unsigned_varint(integer, serialized_integer):
+    s = MockReader(serialized_integer)
+    result = await read_unsigned_varint(s)
+    assert result == integer
+
+
+@pytest.mark.parametrize("serialized_integer", tuple(i[1] for i in PAIRS_INT_SERIALIZED_OVERFLOW))
+@pytest.mark.asyncio
+async def test_read_unsigned_varint_overflow(serialized_integer):
+    s = MockReader(serialized_integer)
+    with pytest.raises(ValueError):
+        await read_unsigned_varint(s)
+
+
+@pytest.mark.parametrize("max_bits", (2, 31, 32, 63, 64, 127, 128))
+@pytest.mark.asyncio
+async def test_read_write_unsigned_varint_max_bits_edge(max_bits):
+    """
+    Test edge cases with different `max_bits`
+    """
+    for i in range(-3, 0):
+        integer = i + (2 ** max_bits)
+        s = MockReaderWriter()
+        await write_unsigned_varint(s, integer, max_bits=max_bits)
+        s.seek(0, 0)
+        result = await read_unsigned_varint(s, max_bits=max_bits)
+        assert integer == result
+
+
+def test_peer_id():
+    assert PEER_ID.to_bytes() == PEER_ID_BYTES
+    assert PEER_ID.to_string() == PEER_ID_STRING
+
+    peer_id_2 = PeerID.from_base58(PEER_ID_STRING)
+    assert peer_id_2.to_bytes() == PEER_ID_BYTES
+    assert peer_id_2.to_string() == PEER_ID_STRING
+    assert PEER_ID == peer_id_2
+    peer_id_3 = PeerID.from_base58("QmbmfNDEth7Ucvjuxiw3SP3E4PoJzbk7g4Ge6ZDigbCsNp")
+    assert PEER_ID != peer_id_3
+
+
+def test_stream_info():
+    proto = "123"
+    si = StreamInfo(PEER_ID, MADDR, proto)
+    assert si.peer_id == PEER_ID
+    assert si.addr == MADDR
+    assert si.proto == proto
+    pb_si = si.to_protobuf()
+    assert pb_si.peer == PEER_ID.to_bytes()
+    assert pb_si.addr == MADDR.to_bytes()
+    assert pb_si.proto == si.proto
+    si_1 = StreamInfo.from_protobuf(pb_si)
+    assert si_1.peer_id == PEER_ID
+    assert si_1.addr == MADDR
+    assert si_1.proto == proto
+
+
+def test_peer_info():
+    pi = PeerInfo(PEER_ID, [MADDR])
+    assert pi.peer_id == PEER_ID
+    assert pi.addrs == [MADDR]
+    pi_pb = p2pd_pb.PeerInfo(id=PEER_ID.to_bytes(), addrs=[MADDR.to_bytes()])
+    pi_1 = PeerInfo.from_protobuf(pi_pb)
+    assert pi.peer_id == pi_1.peer_id
+    assert pi.addrs == pi_1.addrs
+
+
+@pytest.mark.parametrize(
+    "maddr_str, expected_proto",
+    (("/unix/123", protocols.P_UNIX), ("/ip4/127.0.0.1/tcp/7777", protocols.P_IP4)),
+)
+def test_parse_conn_protocol_valid(maddr_str, expected_proto):
+    assert parse_conn_protocol(Multiaddr(maddr_str)) == expected_proto
+
+
+@pytest.mark.parametrize(
+    "maddr_str",
+    (
+        "/p2p/QmbHVEEepCi7rn7VL7Exxpd2Ci9NNB6ifvqwhsrbRMgQFP",
+        "/onion/timaq4ygg2iegci7:1234",
+    ),
+)
+def test_parse_conn_protocol_invalid(maddr_str):
+    maddr = Multiaddr(maddr_str)
+    with pytest.raises(ValueError):
+        parse_conn_protocol(maddr)
+
+
+@pytest.mark.parametrize("control_maddr_str", ("/unix/123", "/ip4/127.0.0.1/tcp/6666"))
+def test_client_ctor_control_maddr(control_maddr_str):
+    c = DaemonConnector(Multiaddr(control_maddr_str))
+    assert c.control_maddr == Multiaddr(control_maddr_str)
+
+
+def test_client_ctor_default_control_maddr():
+    c = DaemonConnector()
+    assert c.control_maddr == Multiaddr(DaemonConnector.DEFAULT_CONTROL_MADDR)
+
+
+@pytest.mark.parametrize("listen_maddr_str", ("/unix/123", "/ip4/127.0.0.1/tcp/6666"))
+def test_control_client_ctor_listen_maddr(listen_maddr_str):
+    c = ControlClient(
+        daemon_connector=DaemonConnector(), listen_maddr=Multiaddr(listen_maddr_str)
+    )
+    assert c.listen_maddr == Multiaddr(listen_maddr_str)
+
+
+def test_control_client_ctor_default_listen_maddr():
+    c = ControlClient(daemon_connector=DaemonConnector())
+    assert c.listen_maddr == Multiaddr(ControlClient.DEFAULT_LISTEN_MADDR)
+
+
+@pytest.mark.parametrize(
+    "msg_bytes",
+    (
+        p2pd_pb.Response(
+            type=p2pd_pb.Response.Type.OK,
+            identify=p2pd_pb.IdentifyResponse(
+                id=PeerID.from_base58('QmT7WhTne9zBLfAgAJt9aiZ8jZ5BxJGowRubxsHYmnyzUd').to_bytes(),
+                addrs=[Multiaddr('/p2p-circuit').to_bytes(), Multiaddr('/ip4/127.0.0.1/tcp/51126').to_bytes(),
+                       Multiaddr('/ip4/192.168.10.135/tcp/51126').to_bytes(),
+                       Multiaddr('/ip6/::1/tcp/51127').to_bytes()]
+            )).SerializeToString(),
+        p2pd_pb.Response(
+            type=p2pd_pb.Response.Type.OK,
+            identify=p2pd_pb.IdentifyResponse(
+                id=PeerID.from_base58('QmcQFt2MFfCZ9AxzUCNrk4k7TtMdZZvAAteaA6tHpBKdrk').to_bytes(),
+                addrs=[Multiaddr('/p2p-circuit').to_bytes(), Multiaddr('/ip4/127.0.0.1/tcp/51493').to_bytes(),
+                       Multiaddr('/ip4/192.168.10.135/tcp/51493').to_bytes(),
+                       Multiaddr('/ip6/::1/tcp/51494').to_bytes()]
+            )).SerializeToString(),
+        p2pd_pb.Response(
+            type=p2pd_pb.Response.Type.OK,
+            identify=p2pd_pb.IdentifyResponse(
+                id=PeerID.from_base58('QmbWqVVoz7v9LS9ZUQAhyyfdFJY3iU8ZrUY3XQozoTA5cc').to_bytes(),
+                addrs=[Multiaddr('/p2p-circuit').to_bytes(), Multiaddr('/ip4/127.0.0.1/tcp/51552').to_bytes(),
+                       Multiaddr('/ip4/192.168.10.135/tcp/51552').to_bytes(),
+                       Multiaddr('/ip6/::1/tcp/51553').to_bytes()]
+            )).SerializeToString(),
+    ),
+    # give test cases ids to prevent bytes from ruining the terminal
+    ids=("pb example Response 0", "pb example Response 1", "pb example Response 2"),
+)
+@pytest.mark.asyncio
+async def test_read_pbmsg_safe_valid(msg_bytes):
+    s = MockReaderWriter()
+    await write_unsigned_varint(s, len(msg_bytes))
+    s.write(msg_bytes)
+    # reset the offset back to the beginning
+    s.seek(0, 0)
+    pb_msg = p2pd_pb.Response()
+    await read_pbmsg_safe(s, pb_msg)
+    assert pb_msg.SerializeToString() == msg_bytes
+
+
+@pytest.mark.parametrize(
+    "pb_type, pb_msg",
+    (
+        (
+            p2pd_pb.Response,
+            p2pd_pb.Response(
+                type=p2pd_pb.Response.Type.OK,
+                dht=p2pd_pb.DHTResponse(
+                    type=p2pd_pb.DHTResponse.Type.VALUE,
+                    peer=p2pd_pb.PeerInfo(
+                        id=PeerID.from_base58('QmNaXUy78W9moQ9APCoKaTtPjLcEJPN9hRBCqErY7o2fQs').to_bytes(),
+                        addrs=[Multiaddr('/p2p-circuit').to_bytes(), Multiaddr('/ip4/127.0.0.1/tcp/56929').to_bytes(),
+                               Multiaddr('/ip4/192.168.10.135/tcp/56929').to_bytes(),
+                               Multiaddr('/ip6/::1/tcp/56930').to_bytes()]
+                    )
+                )
+            ),
+        ),
+        (p2pd_pb.Request, p2pd_pb.Request(type=p2pd_pb.Request.Type.LIST_PEERS)),
+        (
+            p2pd_pb.DHTRequest,
+            p2pd_pb.DHTRequest(type=p2pd_pb.DHTRequest.Type.FIND_PEER,
+                               peer=PeerID.from_base58('QmcgHMuEhqdLHDVeNjiCGU7Ds6E7xK3f4amgiwHNPKKn7R').to_bytes()),
+        ),
+        (
+            p2pd_pb.DHTResponse,
+            p2pd_pb.DHTResponse(
+                type=p2pd_pb.DHTResponse.Type.VALUE,
+                peer=p2pd_pb.PeerInfo(
+                    id=PeerID.from_base58('QmWP32GhEyXVQsLXFvV81eadDC8zQRZxZvJK359rXxLquk').to_bytes(),
+                    addrs=[Multiaddr('/p2p-circuit').to_bytes(), Multiaddr('/ip4/127.0.0.1/tcp/56897').to_bytes(),
+                           Multiaddr('/ip4/192.168.10.135/tcp/56897').to_bytes(),
+                           Multiaddr('/ip6/::1/tcp/56898').to_bytes()]
+                )
+            ),
+        ),
+        (
+            p2pd_pb.StreamInfo,
+            p2pd_pb.StreamInfo(peer=PeerID.from_base58('QmewLxB46MftfxQiunRgJo2W8nW4Lh5NLEkRohkHhJ4wW6').to_bytes(),
+                               addr=Multiaddr('/ip4/127.0.0.1/tcp/57029').to_bytes(),
+                               proto=b'protocol123'),
+        ),
+    ),
+    ids=(
+        "pb example Response",
+        "pb example Request",
+        "pb example DHTRequest",
+        "pb example DHTResponse",
+        "pb example StreamInfo",
+    ),
+)
+@pytest.mark.asyncio
+async def test_write_pbmsg(pb_type, pb_msg):
+    msg_bytes = bytes(chr(pb_msg.ByteSize()), 'utf-8') + pb_msg.SerializeToString()
+    pb_obj = pb_type()
+
+    s_read = MockReaderWriter(msg_bytes)
+    await read_pbmsg_safe(s_read, pb_obj)
+    s_write = MockReaderWriter()
+    await write_pbmsg(s_write, pb_obj)
+    assert msg_bytes == s_write.getvalue()
+
+
+@pytest.mark.parametrize(
+    "pb_msg",
+    (
+        p2pd_pb.Response(),
+        p2pd_pb.Request(),
+        p2pd_pb.DHTRequest(),
+        p2pd_pb.DHTResponse(),
+        p2pd_pb.StreamInfo(),
+    ),
+)
+@pytest.mark.asyncio
+async def test_write_pbmsg_missing_fields(pb_msg):
+    with pytest.raises(EncodeError):
+        await write_pbmsg(MockReaderWriter(), pb_msg)
+
+
+@pytest.fixture
+async def p2pcs():
+    # TODO: Change back to gather style
+    async with AsyncExitStack() as stack:
+        p2pd_tuples = [
+            await stack.enter_async_context(
+                FUNC_MAKE_P2PD_PAIR(
+                    enable_control=ENABLE_CONTROL,
+                    enable_connmgr=ENABLE_CONNMGR,
+                    enable_dht=ENABLE_DHT,
+                    enable_pubsub=ENABLE_PUBSUB,
+                )
+            )
+            for _ in range(NUM_P2PDS)
+        ]
+        yield tuple(p2pd_tuple.client for p2pd_tuple in p2pd_tuples)
+
+
+@pytest.mark.asyncio
+async def test_client_identify_unix_socket(p2pcs):
+    await p2pcs[0].identify()
+
+
+@pytest.mark.asyncio
+async def test_client_identify(p2pcs):
+    await p2pcs[0].identify()
+
+
+@pytest.mark.asyncio
+async def test_client_connect_success(p2pcs):
+    peer_id_0, maddrs_0 = await p2pcs[0].identify()
+    peer_id_1, maddrs_1 = await p2pcs[1].identify()
+    await p2pcs[0].connect(peer_id_1, maddrs_1)
+    # test case: repeated connections
+    await p2pcs[1].connect(peer_id_0, maddrs_0)
+
+
+@pytest.mark.asyncio
+async def test_client_connect_failure(p2pcs):
+    peer_id_1, maddrs_1 = await p2pcs[1].identify()
+    await p2pcs[0].identify()
+    # test case: `peer_id` mismatches
+    with pytest.raises(ControlFailure):
+        await p2pcs[0].connect(PEER_ID_RANDOM, maddrs_1)
+    # test case: empty maddrs
+    with pytest.raises(ControlFailure):
+        await p2pcs[0].connect(peer_id_1, [])
+    # test case: wrong maddrs
+    with pytest.raises(ControlFailure):
+        await p2pcs[0].connect(peer_id_1, [Multiaddr("/ip4/127.0.0.1/udp/0")])
+
+
+@pytest.mark.asyncio
+async def test_connect_safe(p2pcs):
+    await connect_safe(p2pcs[0], p2pcs[1])
+
+
+@pytest.mark.asyncio
+async def test_client_list_peers(p2pcs):
+    # test case: no peers
+    assert len(await p2pcs[0].list_peers()) == 0
+    # test case: 1 peer
+    await connect_safe(p2pcs[0], p2pcs[1])
+    assert len(await p2pcs[0].list_peers()) == 1
+    assert len(await p2pcs[1].list_peers()) == 1
+    # test case: one more peer
+    await connect_safe(p2pcs[0], p2pcs[2])
+    assert len(await p2pcs[0].list_peers()) == 2
+    assert len(await p2pcs[1].list_peers()) == 1
+    assert len(await p2pcs[2].list_peers()) == 1
+
+
+@pytest.mark.asyncio
+async def test_client_disconnect(p2pcs):
+    # test case: disconnect a peer without connections
+    await p2pcs[1].disconnect(PEER_ID_RANDOM)
+    # test case: disconnect
+    peer_id_0, _ = await p2pcs[0].identify()
+    await connect_safe(p2pcs[0], p2pcs[1])
+    assert len(await p2pcs[0].list_peers()) == 1
+    assert len(await p2pcs[1].list_peers()) == 1
+    await p2pcs[1].disconnect(peer_id_0)
+    assert len(await p2pcs[0].list_peers()) == 0
+    assert len(await p2pcs[1].list_peers()) == 0
+    # test case: disconnect twice
+    await p2pcs[1].disconnect(peer_id_0)
+    assert len(await p2pcs[0].list_peers()) == 0
+    assert len(await p2pcs[1].list_peers()) == 0
+
+
+@pytest.mark.asyncio
+async def test_client_stream_open_success(p2pcs):
+    peer_id_1, maddrs_1 = await p2pcs[1].identify()
+    await connect_safe(p2pcs[0], p2pcs[1])
+
+    proto = "123"
+
+    async def handle_proto(stream_info, reader, writer):
+        await reader.readexactly(1)
+
+    await p2pcs[1].stream_handler(proto, handle_proto)
+
+    # test case: normal
+    stream_info, reader, writer = await p2pcs[0].stream_open(peer_id_1, (proto,))
+    assert stream_info.peer_id == peer_id_1
+    assert stream_info.addr in maddrs_1
+    assert stream_info.proto == "123"
+    writer.close()
+
+    # test case: open with multiple protocols
+    stream_info, reader, writer = await p2pcs[0].stream_open(
+        peer_id_1, (proto, "another_protocol")
+    )
+    assert stream_info.peer_id == peer_id_1
+    assert stream_info.addr in maddrs_1
+    assert stream_info.proto == "123"
+    writer.close()
+
+
+@pytest.mark.asyncio
+async def test_client_stream_open_failure(p2pcs):
+    peer_id_1, _ = await p2pcs[1].identify()
+    await connect_safe(p2pcs[0], p2pcs[1])
+
+    proto = "123"
+
+    # test case: `stream_open` to a peer who didn't register the protocol
+    with pytest.raises(ControlFailure):
+        await p2pcs[0].stream_open(peer_id_1, (proto,))
+
+    # test case: `stream_open` to a peer for a non-registered protocol
+    async def handle_proto(stream_info, reader, writer):
+        pass
+
+    await p2pcs[1].stream_handler(proto, handle_proto)
+    with pytest.raises(ControlFailure):
+        await p2pcs[0].stream_open(peer_id_1, ("another_protocol",))
+
+
+@pytest.mark.asyncio
+async def test_client_stream_handler_success(p2pcs):
+    peer_id_1, _ = await p2pcs[1].identify()
+    await connect_safe(p2pcs[0], p2pcs[1])
+
+    proto = "protocol123"
+    bytes_to_send = b"yoyoyoyoyog"
+    # event used by this test to wait until the handler function has received the incoming data
+    event_handler_finished = asyncio.Event()
+
+    async def handle_proto(stream_info, reader, writer):
+        nonlocal event_handler_finished
+        bytes_received = await reader.readexactly(len(bytes_to_send))
+        assert bytes_received == bytes_to_send
+        event_handler_finished.set()
+
+    await p2pcs[1].stream_handler(proto, handle_proto)
+    assert proto in p2pcs[1].control.handlers
+    assert handle_proto == p2pcs[1].control.handlers[proto]
+
+    # test case: test the stream handler `handle_proto`
+
+    _, reader, writer = await p2pcs[0].stream_open(peer_id_1, (proto,))
+
+    # until we send the data below, the handler is still blocked in readexactly waiting for it
+    writer.write(bytes_to_send)
+
+    # wait for the handler to finish
+    writer.close()
+
+    await event_handler_finished.wait()
+
+    # test case: two streams to different handlers respectively
+    another_proto = "another_protocol123"
+    another_bytes_to_send = b"456"
+    event_another_proto = asyncio.Event()
+
+    async def handle_another_proto(stream_info, reader, writer):
+        event_another_proto.set()
+        bytes_received = await reader.readexactly(len(another_bytes_to_send))
+        assert bytes_received == another_bytes_to_send
+
+    await p2pcs[1].stream_handler(another_proto, handle_another_proto)
+    assert another_proto in p2pcs[1].control.handlers
+    assert handle_another_proto == p2pcs[1].control.handlers[another_proto]
+
+    _, reader, writer = await p2pcs[0].stream_open(peer_id_1, (another_proto,))
+    await event_another_proto.wait()
+
+    # at this point the handler is still blocked waiting for the data
+
+    writer.write(another_bytes_to_send)
+
+    writer.close()
+
+    # test case: registering twice can override the previous registration
+    event_third = asyncio.Event()
+
+    async def handler_third(stream_info, reader, writer):
+        event_third.set()
+
+    await p2pcs[1].stream_handler(another_proto, handler_third)
+    assert another_proto in p2pcs[1].control.handlers
+    # ensure the handler has been overridden
+    assert handler_third == p2pcs[1].control.handlers[another_proto]
+
+    await p2pcs[0].stream_open(peer_id_1, (another_proto,))
+    # ensure the overriding handler is called when a stream is opened for that protocol
+    await event_third.wait()
+
+
+@pytest.mark.asyncio
+async def test_client_stream_handler_failure(p2pcs):
+    peer_id_1, _ = await p2pcs[1].identify()
+    await connect_safe(p2pcs[0], p2pcs[1])
+
+    proto = "123"
+
+    # test case: registered a wrong protocol name
+    async def handle_proto_correct_params(stream_info, stream):
+        pass
+
+    await p2pcs[1].stream_handler("another_protocol", handle_proto_correct_params)
+    with pytest.raises(ControlFailure):
+        await p2pcs[0].stream_open(peer_id_1, (proto,))

+ 41 - 1
tests/test_util_modules.py

@@ -11,6 +11,7 @@ from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
 import hivemind
 from hivemind.utils import MSGPackSerializer
 from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.utils.asyncio import amap_in_executor, aiter, aenumerate, achain, anext, azip
 from hivemind.utils.mpfuture import FutureStateError


@@ -138,6 +139,11 @@ def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
     error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.UNIFORM_8BIT)) - X
     assert error.square().mean() < beta
 
+    zeros = torch.zeros(5,5)
+    for compression_type in CompressionType.values():
+        assert deserialize_torch_tensor(serialize_torch_tensor(zeros, compression_type)).isfinite().all()
+
+
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_channel_cache():
@@ -252,7 +258,7 @@ def test_split_parts():
     for combined in combined_incomplete, combined_incomplete2, combined_incomplete3:
         with pytest.raises(RuntimeError):
             deserialize_torch_tensor(combined)
-            # note: we rely on this being RuntimeError in hivemind.client.averager.allreduce.AllreduceProtocol
+            # note: we rely on this being RuntimeError in hivemind.client.averager.allreduce.AllreduceRunner
 
 
 def test_generic_data_classes():
@@ -267,3 +273,37 @@ def test_generic_data_classes():
     sorted_expirations = sorted([DHTExpiration(value) for value in range(1, 1000)])
     sorted_heap_entries = sorted([HeapEntry(DHTExpiration(value), key="any") for value in range(1, 1000)[::-1]])
     assert all([entry.expiration_time == value for entry, value in zip(sorted_heap_entries, sorted_expirations)])
+
+
+@pytest.mark.asyncio
+async def test_asyncio_utils():
+    res = [i async for i, item in aenumerate(aiter('a', 'b', 'c'))]
+    assert res == list(range(len(res)))
+
+    num_steps = 0
+    async for elem in amap_in_executor(lambda x: x ** 2, aiter(*range(100)), max_prefetch=5):
+        assert elem == num_steps ** 2
+        num_steps += 1
+    assert num_steps == 100
+
+    ours = [elem async for elem in amap_in_executor(max, aiter(*range(7)), aiter(*range(-50, 50, 10)), max_prefetch=1)]
+    ref = list(map(max, range(7), range(-50, 50, 10)))
+    assert ours == ref
+
+    ours = [row async for row in azip(aiter('a', 'b', 'c'), aiter(1, 2, 3))]
+    ref = list(zip(['a', 'b', 'c'], [1, 2, 3]))
+    assert ours == ref
+
+    async def _aiterate():
+        yield 'foo'
+        yield 'bar'
+        yield 'baz'
+
+    iterator = _aiterate()
+    assert (await anext(iterator)) == 'foo'
+    tail = [item async for item in iterator]
+    assert tail == ['bar', 'baz']
+    with pytest.raises(StopAsyncIteration):
+        await anext(iterator)
+
+    assert [item async for item in achain(_aiterate(), aiter(*range(5)))] == ['foo', 'bar', 'baz'] + list(range(5))
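
For reference, a minimal usage sketch of the helpers imported above, assuming only the call signatures implied by this test (`aiter(*items)`, `azip(*iterators)`, `amap_in_executor(func, *iterators, max_prefetch=...)`); it is an illustration, not part of the commit:

import asyncio
from hivemind.utils.asyncio import aiter, azip, amap_in_executor

async def main():
    # run the (potentially CPU-bound) mapping function in a background executor,
    # prefetching up to 2 results ahead of the consumer
    async for square in amap_in_executor(lambda x: x ** 2, aiter(1, 2, 3), max_prefetch=2):
        print(square)          # 1, 4, 9

    # iterate over several async iterators in lockstep
    async for letter, number in azip(aiter('a', 'b'), aiter(1, 2)):
        print(letter, number)  # a 1, b 2

asyncio.run(main())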

+ 194 - 0
tests/test_utils/__init__.py

@@ -0,0 +1,194 @@
+import asyncio
+import functools
+import os
+import subprocess
+import time
+import uuid
+from contextlib import asynccontextmanager
+from typing import NamedTuple
+from pkg_resources import resource_filename
+
+from multiaddr import Multiaddr, protocols
+
+from hivemind import find_open_port
+from hivemind.p2p.p2p_daemon_bindings.p2pclient import Client
+
+
+TIMEOUT_DURATION = 30  # seconds
+P2PD_PATH = resource_filename("hivemind", "hivemind_cli/p2pd")
+
+
+async def try_until_success(coro_func, timeout=TIMEOUT_DURATION):
+    """
+    Keep calling ``coro_func`` until it returns a truthy value or the timeout expires.
+    ``coro_func`` must be callable without arguments, i.e. all of its arguments should already be bound.
+    """
+    t_start = time.monotonic()
+    while True:
+        result = await coro_func()
+        if result:
+            break
+        if (time.monotonic() - t_start) >= timeout:
+            # timeout
+            assert False, f"{coro_func} still failed after `{timeout}` seconds"
+        await asyncio.sleep(0.01)
+
+
+class Daemon:
+    control_maddr = None
+    proc_daemon = None
+    log_filename = ""
+    f_log = None
+    closed = None
+
+    def __init__(
+            self, control_maddr, enable_control, enable_connmgr, enable_dht, enable_pubsub
+    ):
+        self.control_maddr = control_maddr
+        self.enable_control = enable_control
+        self.enable_connmgr = enable_connmgr
+        self.enable_dht = enable_dht
+        self.enable_pubsub = enable_pubsub
+        self.is_closed = False
+        self._start_logging()
+        self._run()
+
+    def _start_logging(self):
+        name_control_maddr = str(self.control_maddr).replace("/", "_").replace(".", "_")
+        self.log_filename = f"/tmp/log_p2pd{name_control_maddr}.txt"
+        self.f_log = open(self.log_filename, "wb")
+
+    def _run(self):
+        cmd_list = [P2PD_PATH, f"-listen={str(self.control_maddr)}"]
+        cmd_list += [f"-hostAddrs=/ip4/127.0.0.1/tcp/{find_open_port()}"]
+        if self.enable_connmgr:
+            cmd_list += ["-connManager=true", "-connLo=1", "-connHi=2", "-connGrace=0"]
+        if self.enable_dht:
+            cmd_list += ["-dht=true"]
+        if self.enable_pubsub:
+            cmd_list += ["-pubsub=true", "-pubsubRouter=gossipsub"]
+        self.proc_daemon = subprocess.Popen(
+            cmd_list, stdout=self.f_log, stderr=self.f_log, bufsize=0
+        )
+
+    async def wait_until_ready(self):
+        lines_head_pattern = (b"Control socket:", b"Peer ID:", b"Peer Addrs:")
+        lines_head_occurred = {line: False for line in lines_head_pattern}
+
+        with open(self.log_filename, "rb") as f_log_read:
+
+            async def read_from_daemon_and_check():
+                line = f_log_read.readline()
+                for head_pattern in lines_head_occurred:
+                    if line.startswith(head_pattern):
+                        lines_head_occurred[head_pattern] = True
+                return all([value for _, value in lines_head_occurred.items()])
+
+            await try_until_success(read_from_daemon_and_check)
+
+        # sleep for a while in case the daemon is not ready yet, even after emitting these lines
+        await asyncio.sleep(0.1)
+
+    def close(self):
+        if self.is_closed:
+            return
+        self.proc_daemon.terminate()
+        self.proc_daemon.wait()
+        self.f_log.close()
+        self.is_closed = True
+
+
+class DaemonTuple(NamedTuple):
+    daemon: Daemon
+    client: Client
+
+
+class ConnectionFailure(Exception):
+    pass
+
+
+@asynccontextmanager
+async def make_p2pd_pair_unix(
+        enable_control, enable_connmgr, enable_dht, enable_pubsub
+):
+    name = str(uuid.uuid4())[:8]
+    control_maddr = Multiaddr(f"/unix/tmp/test_p2pd_control_{name}.sock")
+    listen_maddr = Multiaddr(f"/unix/tmp/test_p2pd_listen_{name}.sock")
+    # Remove the unix socket files if they already exist
+    try:
+        os.unlink(control_maddr.value_for_protocol(protocols.P_UNIX))
+    except FileNotFoundError:
+        pass
+    try:
+        os.unlink(listen_maddr.value_for_protocol(protocols.P_UNIX))
+    except FileNotFoundError:
+        pass
+    async with _make_p2pd_pair(
+            control_maddr=control_maddr,
+            listen_maddr=listen_maddr,
+            enable_control=enable_control,
+            enable_connmgr=enable_connmgr,
+            enable_dht=enable_dht,
+            enable_pubsub=enable_pubsub,
+    ) as pair:
+        yield pair
+
+
+@asynccontextmanager
+async def make_p2pd_pair_ip4(enable_control, enable_connmgr, enable_dht, enable_pubsub):
+    control_maddr = Multiaddr(f"/ip4/127.0.0.1/tcp/{find_open_port()}")
+    listen_maddr = Multiaddr(f"/ip4/127.0.0.1/tcp/{find_open_port()}")
+    async with _make_p2pd_pair(
+            control_maddr=control_maddr,
+            listen_maddr=listen_maddr,
+            enable_control=enable_control,
+            enable_connmgr=enable_connmgr,
+            enable_dht=enable_dht,
+            enable_pubsub=enable_pubsub,
+    ) as pair:
+        yield pair
+
+
+@asynccontextmanager
+async def _make_p2pd_pair(
+        control_maddr,
+        listen_maddr,
+        enable_control,
+        enable_connmgr,
+        enable_dht,
+        enable_pubsub,
+):
+    p2pd = Daemon(
+        control_maddr=control_maddr,
+        enable_control=enable_control,
+        enable_connmgr=enable_connmgr,
+        enable_dht=enable_dht,
+        enable_pubsub=enable_pubsub,
+    )
+    # wait for daemon ready
+    await p2pd.wait_until_ready()
+    client = Client(control_maddr=control_maddr, listen_maddr=listen_maddr)
+    try:
+        async with client.listen():
+            yield DaemonTuple(daemon=p2pd, client=client)
+    finally:
+        if not p2pd.is_closed:
+            p2pd.close()
+
+
+async def _check_connection(p2pd_tuple_0, p2pd_tuple_1):
+    peer_id_0, _ = await p2pd_tuple_0.identify()
+    peer_id_1, _ = await p2pd_tuple_1.identify()
+    peers_0 = [pinfo.peer_id for pinfo in await p2pd_tuple_0.list_peers()]
+    peers_1 = [pinfo.peer_id for pinfo in await p2pd_tuple_1.list_peers()]
+    return (peer_id_0 in peers_1) and (peer_id_1 in peers_0)
+
+
+async def connect_safe(p2pd_tuple_0, p2pd_tuple_1):
+    peer_id_1, maddrs_1 = await p2pd_tuple_1.identify()
+    await p2pd_tuple_0.connect(peer_id_1, maddrs_1)
+    await try_until_success(
+        functools.partial(
+            _check_connection, p2pd_tuple_0=p2pd_tuple_0, p2pd_tuple_1=p2pd_tuple_1
+        )
+    )
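
These helpers back the `p2pcs` fixture used by tests/test_p2p_daemon_bindings.py. A minimal sketch of how they could be combined into such a fixture (the fixture name, import path, and enabled daemon features are illustrative, and async fixtures assume pytest-asyncio is configured); the actual fixture in the test module may differ:

import pytest
from contextlib import AsyncExitStack

from test_utils import make_p2pd_pair_ip4, connect_safe  # import path depends on the test layout

@pytest.fixture
async def two_connected_clients():
    async with AsyncExitStack() as stack:
        # start two p2pd daemons with control enabled and obtain a client for each
        pairs = [
            await stack.enter_async_context(make_p2pd_pair_ip4(
                enable_control=True, enable_connmgr=False, enable_dht=False, enable_pubsub=False))
            for _ in range(2)
        ]
        clients = [pair.client for pair in pairs]
        # connect the two daemons and wait until each sees the other as a peer
        await connect_safe(clients[0], clients[1])
        yield clients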