Browse Source

Reformat code with Black (#274)

* Reformat code with Black

* Check style with a GitHub action

* Update contributing guide
Max Ryabinin 4 years ago
parent commit 2f07a556e6
98 changed files with 4089 additions and 2528 deletions
  1. .github/workflows/check_style.yml (+13, -0)
  2. .github/workflows/run-tests.yml (+1, -1)
  3. CONTRIBUTING.md (+27, -23)
  4. README.md (+3, -2)
  5. benchmarks/benchmark_averaging.py (+38, -21)
  6. benchmarks/benchmark_dht.py (+53, -29)
  7. benchmarks/benchmark_tensor_compression.py (+3, -3)
  8. benchmarks/benchmark_throughput.py (+130, -63)
  9. docs/conf.py (+55, -54)
  10. examples/albert/arguments.py (+39, -57)
  11. examples/albert/run_trainer.py (+80, -55)
  12. examples/albert/run_training_monitor.py (+62, -48)
  13. examples/albert/tokenize_wikitext103.py (+12, -12)
  14. examples/albert/utils.py (+11, -10)
  15. hivemind/__init__.py (+16, -4)
  16. hivemind/averaging/allreduce.py (+45, -21)
  17. hivemind/averaging/averager.py (+193, -95)
  18. hivemind/averaging/group_info.py (+2, -1)
  19. hivemind/averaging/key_manager.py (+66, -33)
  20. hivemind/averaging/load_balancing.py (+2, -2)
  21. hivemind/averaging/matchmaking.py (+122, -63)
  22. hivemind/averaging/partition.py (+25, -18)
  23. hivemind/averaging/training.py (+42, -27)
  24. hivemind/dht/__init__.py (+73, -34)
  25. hivemind/dht/crypto.py (+9, -9)
  26. hivemind/dht/node.py (+277, -129)
  27. hivemind/dht/protocol.py (+105, -49)
  28. hivemind/dht/routing.py (+31, -22)
  29. hivemind/dht/schema.py (+20, -15)
  30. hivemind/dht/storage.py (+9, -7)
  31. hivemind/dht/traverse.py (+33, -19)
  32. hivemind/dht/validation.py (+2, -2)
  33. hivemind/hivemind_cli/run_server.py (+7, -7)
  34. hivemind/moe/client/beam_search.py (+179, -64)
  35. hivemind/moe/client/expert.py (+25, -16)
  36. hivemind/moe/client/moe.py (+127, -53)
  37. hivemind/moe/client/switch_moe.py (+81, -37)
  38. hivemind/moe/server/__init__.py (+77, -31)
  39. hivemind/moe/server/checkpoints.py (+6, -6)
  40. hivemind/moe/server/connection_handler.py (+18, -11)
  41. hivemind/moe/server/dht_handler.py (+26, -13)
  42. hivemind/moe/server/expert_backend.py (+68, -47)
  43. hivemind/moe/server/expert_uid.py (+22, -16)
  44. hivemind/moe/server/layers/__init__.py (+1, -1)
  45. hivemind/moe/server/layers/common.py (+8, -11)
  46. hivemind/moe/server/layers/custom_experts.py (+3, -2)
  47. hivemind/moe/server/layers/dropout.py (+8, -5)
  48. hivemind/moe/server/runtime.py (+25, -14)
  49. hivemind/moe/server/task_pool.py (+33, -19)
  50. hivemind/optim/adaptive.py (+11, -4)
  51. hivemind/optim/base.py (+6, -3)
  52. hivemind/optim/collaborative.py (+128, -64)
  53. hivemind/optim/performance_ema.py (+2, -1)
  54. hivemind/optim/simple.py (+90, -24)
  55. hivemind/p2p/p2p_daemon.py (+106, -75)
  56. hivemind/p2p/p2p_daemon_bindings/control.py (+14, -38)
  57. hivemind/p2p/p2p_daemon_bindings/datastructures.py (+7, -23)
  58. hivemind/p2p/p2p_daemon_bindings/p2pclient.py (+4, -11)
  59. hivemind/p2p/servicer.py (+28, -18)
  60. hivemind/utils/__init__.py (+2, -2)
  61. hivemind/utils/asyncio.py (+14, -10)
  62. hivemind/utils/auth.py (+24, -18)
  63. hivemind/utils/compression.py (+25, -20)
  64. hivemind/utils/crypto.py (+6, -4)
  65. hivemind/utils/grpc.py (+54, -30)
  66. hivemind/utils/limits.py (+2, -1)
  67. hivemind/utils/logging.py (+8, -5)
  68. hivemind/utils/mpfuture.py (+29, -18)
  69. hivemind/utils/nested.py (+10, -17)
  70. hivemind/utils/networking.py (+13, -13)
  71. hivemind/utils/serializer.py (+3, -2)
  72. hivemind/utils/tensor_descr.py (+16, -10)
  73. hivemind/utils/timed_storage.py (+13, -9)
  74. pyproject.toml (+3, -0)
  75. requirements-dev.txt (+1, -0)
  76. setup.py (+71, -62)
  77. tests/conftest.py (+2, -2)
  78. tests/test_allreduce.py (+70, -40)
  79. tests/test_auth.py (+30, -27)
  80. tests/test_averaging.py (+172, -76)
  81. tests/test_custom_experts.py (+26, -13)
  82. tests/test_dht.py (+21, -22)
  83. tests/test_dht_crypto.py (+37, -34)
  84. tests/test_dht_experts.py (+93, -48)
  85. tests/test_dht_node.py (+121, -86)
  86. tests/test_dht_schema.py (+46, -52)
  87. tests/test_dht_storage.py (+31, -33)
  88. tests/test_dht_validation.py (+34, -33)
  89. tests/test_expert_backend.py (+18, -13)
  90. tests/test_moe.py (+115, -62)
  91. tests/test_p2p_daemon.py (+34, -42)
  92. tests/test_p2p_daemon_bindings.py (+64, -40)
  93. tests/test_routing.py (+21, -26)
  94. tests/test_training.py (+61, -25)
  95. tests/test_util_modules.py (+54, -47)
  96. tests/test_utils/custom_networks.py (+7, -6)
  97. tests/test_utils/dht_swarms.py (+7, -6)
  98. tests/test_utils/p2p_daemon.py (+22, -32)

+ 13 - 0
.github/workflows/check_style.yml

@@ -0,0 +1,13 @@
+name: Check style
+
+on: [ push ]
+
+jobs:
+  black:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v2
+      - uses: psf/black@stable
+        with:
+          options: "--check"
+          version: "21.6b0"

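To reproduce the new CI check locally before pushing, the commands below mirror the action above. This is a minimal sketch based on the workflow's pinned options (`--check`, version `21.6b0`) and the updated contributing guide, not an official project script:

```
pip install black==21.6b0   # same version as pinned in the workflow
black --check .             # report files that would be reformatted, without modifying them
black .                     # reformat in place, as the contributing guide recommends
```
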
+ 1 - 1
.github/workflows/run-tests.yml

@@ -56,7 +56,7 @@ jobs:
           pip install -r requirements-dev.txt
       - name: Build hivemind
         run: |
-          pip install . --global-option=build_py --global-option="--buildgo"
+          pip install . --global-option=build_py --global-option="--buildgo" --no-use-pep517
       - name: Test
         run: |
           cd tests

+ 27 - 23
CONTRIBUTING.md

@@ -4,18 +4,18 @@ This document covers the technical details of making your contributions to the c
 contribute, read the [contributing guide](https://learning-at-home.readthedocs.io/en/latest/user/contributing.html) in
 our documentation.
 
-Before you begin, file a new issue on [GitHub](https://github.com/learning-at-home/hivemind/issues) or announce that you
-are going to work on an existing one to avoid duplicate effort. After you finish, submit a pull request and wait for it
-to be reviewed by the library maintainers (and possibly other community members).
+Before you begin, file a new issue on [GitHub](https://github.com/learning-at-home/hivemind/issues) or announce that
+you are going to work on an existing one to avoid duplicate effort. After you finish, submit a pull request and wait
+for it to be reviewed by the library maintainers (and possibly other community members).
 
 ## Environment setup
 
-First, install hivemind in the development mode, preferably with Python 3.8 on Linux.
+First, install hivemind in the development mode, preferably with Python 3.8+ on Linux.
 
 ```
 git clone https://github.com/learning-at-home/hivemind
 cd hivemind
-pip install -e .
+pip install -e .[dev]
 ``` 
 
 ## Pull Request checklist
@@ -34,16 +34,18 @@ with the following rules:
 
 ## Code style
 
+* We use [black](https://github.com/psf/black) for code formatting. Before submitting a PR, make sure to install and
+  run `black .` in the root of the repository.
 * The code must follow [PEP8](https://www.python.org/dev/peps/pep-0008/) unless absolutely necessary. Also, each line
-  cannot be longer than 120 characters.
+  cannot be longer than 119 characters.
 * We highly encourage the use of [typing](https://docs.python.org/3/library/typing.html) where applicable.
 * Use `get_logger` from `hivemind.utils.logging` to log any information instead of `print`ing directly to standard
   output/error streams.
-* Comments should be used sparingly and never describe the obvious; usually it's best to clean up the code logic instead
-  of describing it, as it might lead to redundant (or worse, stale or incorrect).
-* In general, strive for code readability instead of compactness. In particular, prefer to create a new variable instead
-  of a long one-liner and to break up a long method into several meaningful parts. This rule can be overridden in case
-  of major performance considerations, but only if verified by benchmarks.
+* Comments should be used sparingly and never describe the obvious. Usually it's best to clean up the code logic
+  instead of describing it, as it might lead to redundant (or worse, stale or incorrect) messages.
+* In general, strive for code readability instead of compactness. In particular, prefer to create a new variable
+  instead of a long one-liner and to break up a long method into several meaningful parts. This rule can be overridden
+  in case of major performance considerations, but only if verified by benchmarks.
 * Each user-facing function must have a [correct](#building-documentation) docstring that describes the intended usage,
   the input arguments and the return value. Both in comments and docstrings, please try to follow the capitalization
   rules for all terms and objects and to use proper grammar.
@@ -67,8 +69,9 @@ It is not required to use this format while you are still working on your pull r
 message has to adhere to these guidelines, and it will be easier for the maintainers to accept the PR if you have
 already done most of the necessary formatting work.
 
-For further reading on the commit message format, see this [guide](https://chris.beams.io/posts/git-commit/#seven-rules)
-on good Git commit messages, as well as this [repository](https://github.com/RomuloOliveira/commit-messages-guide).
+For further reading on the commit message format, see
+this [guide](https://chris.beams.io/posts/git-commit/#seven-rules) on good Git commit messages, as well as
+this [repository](https://github.com/RomuloOliveira/commit-messages-guide).
 
 ### Pull requests
 
@@ -77,9 +80,9 @@ merge commit title is the name of the pull request along with the PR number refe
 the pull request description (if it adheres to the format) or a cleaned up compilation of PR branch commit messages.
 
 * As such, the name and the description of your PR should follow the same guidelines as commit messages.
-* Try to make your pull requests more narrow in scope and split significant changes to the code base in separate pieces.
-  This will ensure [faster and better](https://essenceofcode.com/2019/10/29/the-art-of-small-pull-requests/) feedback
-  from the reviewers.
+* Try to make your pull requests more narrow in scope and split significant changes to the code base in separate
+  pieces. This will ensure [faster and better](https://essenceofcode.com/2019/10/29/the-art-of-small-pull-requests/)
+  feedback from the reviewers.
 * In particular, try to separate functional and non-functional code changes, as well as independent functional changes
   if they make the pull request too large to review in a short period of time.
 * In general, when naming a pull request instead of a commit, it's best to highlight the major change in its title
@@ -88,9 +91,10 @@ the pull request description (if it adheres to the format) or a cleaned up compi
   compare `Implement decentralized parameter averaging` with `Add hivemind.client.averaging`.
 
 For more on the philosophy of easy-to-review pull requests, read these
-guides: [1](https://mtlynch.io/code-review-love/) [2](https://www.atlassian.com/blog/git/written-unwritten-guide-pull-requests)
-. If the changelist is not very large (more than a hundred lines) already, we encourage making small improvements to the
-codebase in the files already changed by the PR; however, they should not dilute its major purpose.
+guides: [1](https://mtlynch.io/code-review-love/)
+[2](https://www.atlassian.com/blog/git/written-unwritten-guide-pull-requests). If the changelist is not very large
+(more than a hundred lines) already, we encourage making small improvements to the codebase in the files already
+changed by the PR; however, they should not dilute its major purpose.
 
 ## Running tests
 
@@ -103,8 +107,8 @@ To run tests, you need to install hivemind in development mode with additional d
 You can run all tests with `pytest tests/` or choose a specific subset, e.g., `pytest tests/test_dht.py`.
 
 When investigating test behavior, please note that pytest automatically wraps all hivemind tests with fixtures defined
-in a global configuration file [`tests/conftest.py`](./tests/conftest.py), some of which will run automatically. 
-For more informantion, refer to the [pytest documentation on fixtures](https://docs.pytest.org/en/6.2.x/fixture.html).
+in a global configuration file [`tests/conftest.py`](./tests/conftest.py), some of which will run automatically. For
+more informantion, refer to the [pytest documentation on fixtures](https://docs.pytest.org/en/6.2.x/fixture.html).
 
 ## Building documentation
 
@@ -128,8 +132,8 @@ the maintainers to provide the benchmarking results for your branch and a compar
 
 * `benchmarks/benchmark_averaging.py` measures the performance of decentralized parameter averaging across the DHT.
 * `benchmarks/benchmark_dht.py` measures the performance of core DHT operations.
-* `benchmarks/benchmark_throughput.py` measures the performance of a server hosting several expert layers under heavy load
-  from multiple clients.
+* `benchmarks/benchmark_throughput.py` measures the performance of a server hosting several expert layers under heavy
+  load from multiple clients.
 
 Example benchmark runs are available in
 the [benchmarking](https://learning-at-home.readthedocs.io/en/latest/user/benchmarks.html) page of the documentation.

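The commit also adds three lines to `pyproject.toml` (that hunk is not shown in this excerpt). A plausible sketch of that configuration, assuming it simply encodes the 119-character line limit and the pinned Black release mentioned above — the actual contents may differ:

```
[tool.black]
line-length = 119
required-version = "21.6b0"
```
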
+ 3 - 2
README.md

@@ -3,6 +3,7 @@
 [![CI status](https://github.com/learning-at-home/hivemind/actions/workflows/run-tests.yml/badge.svg?branch=master)](https://github.com/learning-at-home/hivemind/actions)
 [![Documentation Status](https://readthedocs.org/projects/learning-at-home/badge/?version=latest)](https://learning-at-home.readthedocs.io/en/latest/?badge=latest)
 [![Gitter](https://badges.gitter.im/learning-at-home/hivemind.svg)](https://gitter.im/learning-at-home/hivemind?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge)
+[![Black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
 
 Hivemind is a PyTorch library to train large neural networks across the Internet. Its intended usage is training a
 single Transformer model on hundreds of computers from different universities, companies, and volunteers.
@@ -16,8 +17,8 @@ single Transformer model on hundreds of computers from different universities, c
   network.
 * Fault-tolerant backpropagation: forward and backward passes succeed even if some nodes are unresponsive or take too
   long to respond.
-* Decentralized parameter averaging: iteratively aggregate updates from multiple workers without the need to synchronize
-  across the entire network.
+* Decentralized parameter averaging: iteratively aggregate updates from multiple workers without the need to
+  synchronize across the entire network.
 
 To learn more about the ideas behind this library, see https://learning-at-home.github.io or read
 the [NeurIPS 2020 paper](https://arxiv.org/abs/2002.04013).

+ 38 - 21
benchmarks/benchmark_averaging.py

@@ -31,16 +31,23 @@ def sample_tensors(hid_size, num_layers):
     return tuple(tensors)
 
 
-def benchmark_averaging(num_peers: int, target_group_size: int, num_rounds: int,
-                        averaging_expiration: float, request_timeout: float, round_timeout: float,
-                        hid_size: int, num_layers: int, spawn_dtime: float):
+def benchmark_averaging(
+    num_peers: int,
+    target_group_size: int,
+    num_rounds: int,
+    averaging_expiration: float,
+    request_timeout: float,
+    round_timeout: float,
+    hid_size: int,
+    num_layers: int,
+    spawn_dtime: float,
+):
     dht_root = hivemind.DHT(start=True)
     initial_peers = dht_root.get_visible_maddrs()
 
     num_groups = 2 ** int(round(math.log2(num_peers / target_group_size)))
     nbits = int(round(math.log2(num_groups)))
-    peer_tensors = [sample_tensors(hid_size, num_layers)
-                    for _ in range(num_peers)]
+    peer_tensors = [sample_tensors(hid_size, num_layers) for _ in range(num_peers)]
     processes = {dht_root}
     lock_stats = threading.Lock()
     successful_steps = total_steps = 0
@@ -48,14 +55,24 @@ def benchmark_averaging(num_peers: int, target_group_size: int, num_rounds: int,
     def run_averager(index):
         nonlocal successful_steps, total_steps, lock_stats
         dht = hivemind.DHT(initial_peers=initial_peers, start=True)
-        initial_bits = bin(index % num_groups)[2:].rjust(nbits, '0')
+        initial_bits = bin(index % num_groups)[2:].rjust(nbits, "0")
         averager = hivemind.averaging.DecentralizedAverager(
-            peer_tensors[i], dht, prefix='my_tensor', initial_group_bits=initial_bits, listen_on=f"{LOCALHOST}:*",
-            compression_type=runtime_pb2.CompressionType.FLOAT16, target_group_size=target_group_size,
-            averaging_expiration=averaging_expiration, request_timeout=request_timeout, start=True)
+            peer_tensors[i],
+            dht,
+            prefix="my_tensor",
+            initial_group_bits=initial_bits,
+            listen_on=f"{LOCALHOST}:*",
+            compression_type=runtime_pb2.CompressionType.FLOAT16,
+            target_group_size=target_group_size,
+            averaging_expiration=averaging_expiration,
+            request_timeout=request_timeout,
+            start=True,
+        )
         processes.update({dht, averager})
 
-        logger.info(f'Averager {index}: started on endpoint {averager.endpoint}, group_bits: {averager.get_group_bits()}')
+        logger.info(
+            f"Averager {index}: started on endpoint {averager.endpoint}, group_bits: {averager.get_group_bits()}"
+        )
         for step in range(num_rounds):
             try:
                 success = averager.step(timeout=round_timeout) is not None
@@ -84,19 +101,19 @@ def benchmark_averaging(num_peers: int, target_group_size: int, num_rounds: int,
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--num_peers', type=int, default=16, required=False)
-    parser.add_argument('--target_group_size', type=int, default=4, required=False)
-    parser.add_argument('--num_rounds', type=int, default=5, required=False)
-    parser.add_argument('--hid_size', type=int, default=256, required=False)
-    parser.add_argument('--num_layers', type=int, default=3, required=False)
-    parser.add_argument('--averaging_expiration', type=float, default=5, required=False)
-    parser.add_argument('--round_timeout', type=float, default=15, required=False)
-    parser.add_argument('--request_timeout', type=float, default=1, required=False)
-    parser.add_argument('--spawn_dtime', type=float, default=0.1, required=False)
-    parser.add_argument('--increase_file_limit', action="store_true")
+    parser.add_argument("--num_peers", type=int, default=16, required=False)
+    parser.add_argument("--target_group_size", type=int, default=4, required=False)
+    parser.add_argument("--num_rounds", type=int, default=5, required=False)
+    parser.add_argument("--hid_size", type=int, default=256, required=False)
+    parser.add_argument("--num_layers", type=int, default=3, required=False)
+    parser.add_argument("--averaging_expiration", type=float, default=5, required=False)
+    parser.add_argument("--round_timeout", type=float, default=15, required=False)
+    parser.add_argument("--request_timeout", type=float, default=1, required=False)
+    parser.add_argument("--spawn_dtime", type=float, default=0.1, required=False)
+    parser.add_argument("--increase_file_limit", action="store_true")
     args = vars(parser.parse_args())
 
-    if args.pop('increase_file_limit', False):
+    if args.pop("increase_file_limit", False):
         increase_file_limit()
 
     benchmark_averaging(**args)

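The benchmark changes above are typical of the whole diff: Black normalizes string literals to double quotes and, when a call or signature exceeds the line limit, rewrites it with one argument per line plus a trailing comma. A small self-contained illustration of the resulting style, assuming the repository's 119-character limit (the helper below is hypothetical, not part of hivemind):

```
# Hypothetical helper, shown only to illustrate Black's output style.
def make_averager_config(prefix, target_group_size, averaging_expiration, request_timeout, start):
    return dict(
        prefix=prefix,
        target_group_size=target_group_size,
        averaging_expiration=averaging_expiration,
        request_timeout=request_timeout,
        start=start,
    )


# A call that would exceed the line limit is exploded one argument per line;
# the trailing comma keeps this layout stable under future edits.
config = make_averager_config(
    prefix="my_tensor",
    target_group_size=4,
    averaging_expiration=5.0,
    request_timeout=1.0,
    start=True,
)
print(config)
```
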
+ 53 - 29
benchmarks/benchmark_dht.py

@@ -12,26 +12,42 @@ logger = hivemind.get_logger(__name__)
 
 
 def random_endpoint() -> hivemind.Endpoint:
-    return f"{random.randint(0, 256)}.{random.randint(0, 256)}.{random.randint(0, 256)}." \
-           f"{random.randint(0, 256)}:{random.randint(0, 65535)}"
-
-
-def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_batch_size: int, random_seed: int,
-                  wait_after_request: float, wait_before_read: float, wait_timeout: float, expiration: float):
+    return (
+        f"{random.randint(0, 256)}.{random.randint(0, 256)}.{random.randint(0, 256)}."
+        f"{random.randint(0, 256)}:{random.randint(0, 65535)}"
+    )
+
+
+def benchmark_dht(
+    num_peers: int,
+    initial_peers: int,
+    num_experts: int,
+    expert_batch_size: int,
+    random_seed: int,
+    wait_after_request: float,
+    wait_before_read: float,
+    wait_timeout: float,
+    expiration: float,
+):
     random.seed(random_seed)
 
     logger.info("Creating peers...")
     peers = []
     for _ in trange(num_peers):
-        neighbors = sum([peer.get_visible_maddrs()
-                         for peer in random.sample(peers, min(initial_peers, len(peers)))], [])
+        neighbors = sum(
+            [peer.get_visible_maddrs() for peer in random.sample(peers, min(initial_peers, len(peers)))], []
+        )
         peer = hivemind.DHT(initial_peers=neighbors, start=True, wait_timeout=wait_timeout)
         peers.append(peer)
 
     store_peer, get_peer = peers[-2:]
 
-    expert_uids = list(set(f"expert.{random.randint(0, 999)}.{random.randint(0, 999)}.{random.randint(0, 999)}"
-                           for _ in range(num_experts)))
+    expert_uids = list(
+        set(
+            f"expert.{random.randint(0, 999)}.{random.randint(0, 999)}.{random.randint(0, 999)}"
+            for _ in range(num_experts)
+        )
+    )
     logger.info(f"Sampled {len(expert_uids)} unique ids (after deduplication)")
     random.shuffle(expert_uids)
 
@@ -43,8 +59,9 @@ def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_b
     for start in trange(0, num_experts, expert_batch_size):
         store_start = time.perf_counter()
         endpoints.append(random_endpoint())
-        store_ok = declare_experts(store_peer, expert_uids[start: start + expert_batch_size], endpoints[-1],
-                                   expiration=expiration)
+        store_ok = declare_experts(
+            store_peer, expert_uids[start : start + expert_batch_size], endpoints[-1], expiration=expiration
+        )
         successes = store_ok.values()
         total_store_time += time.perf_counter() - store_start
 
@@ -53,7 +70,8 @@ def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_b
         time.sleep(wait_after_request)
 
     logger.info(
-        f"Store success rate: {successful_stores / total_stores * 100:.1f}% ({successful_stores} / {total_stores})")
+        f"Store success rate: {successful_stores / total_stores * 100:.1f}% ({successful_stores} / {total_stores})"
+    )
     logger.info(f"Mean store time: {total_store_time / total_stores:.5}, Total: {total_store_time:.5}")
     time.sleep(wait_before_read)
 
@@ -64,19 +82,25 @@ def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_b
 
     for start in trange(0, len(expert_uids), expert_batch_size):
         get_start = time.perf_counter()
-        get_result = get_experts(get_peer, expert_uids[start: start + expert_batch_size])
+        get_result = get_experts(get_peer, expert_uids[start : start + expert_batch_size])
         total_get_time += time.perf_counter() - get_start
 
         for i, expert in enumerate(get_result):
-            if expert is not None and expert.uid == expert_uids[start + i] \
-                    and expert.endpoint == endpoints[start // expert_batch_size]:
+            if (
+                expert is not None
+                and expert.uid == expert_uids[start + i]
+                and expert.endpoint == endpoints[start // expert_batch_size]
+            ):
                 successful_gets += 1
 
     if time.perf_counter() - benchmark_started > expiration:
-        logger.warning("keys expired midway during get requests. If that isn't desired, increase expiration_time param")
+        logger.warning(
+            "keys expired midway during get requests. If that isn't desired, increase expiration_time param"
+        )
 
     logger.info(
-        f"Get success rate: {successful_gets / len(expert_uids) * 100:.1f} ({successful_gets} / {len(expert_uids)})")
+        f"Get success rate: {successful_gets / len(expert_uids) * 100:.1f} ({successful_gets} / {len(expert_uids)})"
+    )
     logger.info(f"Mean get time: {total_get_time / len(expert_uids):.5f}, Total: {total_get_time:.5f}")
 
     alive_peers = [peer.is_alive() for peer in peers]
@@ -85,19 +109,19 @@ def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_b
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--num_peers', type=int, default=32, required=False)
-    parser.add_argument('--initial_peers', type=int, default=1, required=False)
-    parser.add_argument('--num_experts', type=int, default=256, required=False)
-    parser.add_argument('--expert_batch_size', type=int, default=32, required=False)
-    parser.add_argument('--expiration', type=float, default=300, required=False)
-    parser.add_argument('--wait_after_request', type=float, default=0, required=False)
-    parser.add_argument('--wait_before_read', type=float, default=0, required=False)
-    parser.add_argument('--wait_timeout', type=float, default=5, required=False)
-    parser.add_argument('--random_seed', type=int, default=random.randint(1, 1000))
-    parser.add_argument('--increase_file_limit', action="store_true")
+    parser.add_argument("--num_peers", type=int, default=32, required=False)
+    parser.add_argument("--initial_peers", type=int, default=1, required=False)
+    parser.add_argument("--num_experts", type=int, default=256, required=False)
+    parser.add_argument("--expert_batch_size", type=int, default=32, required=False)
+    parser.add_argument("--expiration", type=float, default=300, required=False)
+    parser.add_argument("--wait_after_request", type=float, default=0, required=False)
+    parser.add_argument("--wait_before_read", type=float, default=0, required=False)
+    parser.add_argument("--wait_timeout", type=float, default=5, required=False)
+    parser.add_argument("--random_seed", type=int, default=random.randint(1, 1000))
+    parser.add_argument("--increase_file_limit", action="store_true")
     args = vars(parser.parse_args())
 
-    if args.pop('increase_file_limit', False):
+    if args.pop("increase_file_limit", False):
         increase_file_limit()
 
     benchmark_dht(**args)

+ 3 - 3
benchmarks/benchmark_tensor_compression.py

@@ -19,9 +19,9 @@ def benchmark_compression(tensor: torch.Tensor, compression_type: CompressionTyp
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--size', type=int, default=10000000, required=False)
-    parser.add_argument('--seed', type=int, default=7348, required=False)
-    parser.add_argument('--num_iters', type=int, default=30, required=False)
+    parser.add_argument("--size", type=int, default=10000000, required=False)
+    parser.add_argument("--seed", type=int, default=7348, required=False)
+    parser.add_argument("--num_iters", type=int, default=30, required=False)
 
     args = parser.parse_args()
 

+ 130 - 63
benchmarks/benchmark_throughput.py

@@ -17,21 +17,23 @@ logger = get_logger(__name__)
 
 def print_device_info(device=None):
     """Prints device stats. Code from https://stackoverflow.com/a/53374933/12891528"""
-    device = torch.device(device or ('cuda' if torch.cuda.is_available() else 'cpu'))
-    logger.info(f'Using device: {device}')
+    device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
+    logger.info(f"Using device: {device}")
 
     # Additional Info when using cuda
-    if device.type == 'cuda':
+    if device.type == "cuda":
         logger.info(torch.cuda.get_device_name(0))
-        logger.info(f'Memory Usage:')
-        logger.info(f'Allocated: {round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1)} GB')
-        logger.info(f'Cached:   {round(torch.cuda.memory_cached(0) / 1024 ** 3, 1)} GB')
+        logger.info(f"Memory Usage:")
+        logger.info(f"Allocated: {round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1)} GB")
+        logger.info(f"Cached:   {round(torch.cuda.memory_cached(0) / 1024 ** 3, 1)} GB")
 
 
 def client_process(can_start, benchmarking_failed, port, num_experts, batch_size, hid_dim, num_batches, backprop=True):
     torch.set_num_threads(1)
     can_start.wait()
-    experts = [hivemind.RemoteExpert(f"expert{i}", endpoint=f"{hivemind.LOCALHOST}:{port}") for i in range(num_experts)]
+    experts = [
+        hivemind.RemoteExpert(f"expert{i}", endpoint=f"{hivemind.LOCALHOST}:{port}") for i in range(num_experts)
+    ]
 
     try:
         dummy_batch = torch.randn(batch_size, hid_dim)
@@ -45,11 +47,24 @@ def client_process(can_start, benchmarking_failed, port, num_experts, batch_size
         raise e
 
 
-def benchmark_throughput(num_experts=16, num_handlers=None, num_clients=128, num_batches_per_client=16,
-                         expert_cls='ffn', hid_dim=1024, batch_size=2048, max_batch_size=None, backprop=True,
-                         device=None, port=None):
-    assert not hasattr(torch.cuda, 'is_initialized') or not torch.cuda.is_initialized() \
-           or torch.device(device) == torch.device('cpu')
+def benchmark_throughput(
+    num_experts=16,
+    num_handlers=None,
+    num_clients=128,
+    num_batches_per_client=16,
+    expert_cls="ffn",
+    hid_dim=1024,
+    batch_size=2048,
+    max_batch_size=None,
+    backprop=True,
+    device=None,
+    port=None,
+):
+    assert (
+        not hasattr(torch.cuda, "is_initialized")
+        or not torch.cuda.is_initialized()
+        or torch.device(device) == torch.device("cpu")
+    )
     assert expert_cls in layers.name_to_block
     port = port or find_open_port()
     max_batch_size = max_batch_size or batch_size * 4
@@ -63,40 +78,57 @@ def benchmark_throughput(num_experts=16, num_handlers=None, num_clients=128, num
         # Note: client processes must be launched BEFORE touching gpu, even torch.cuda.is_available can cause trouble
         clients = [
             mp.Process(
-                target=client_process, name=f'client_process-{i}',
-                args=(can_start, benchmarking_failed, port, num_experts, batch_size,
-                      hid_dim, num_batches_per_client, backprop))
-            for i in range(num_clients)]
+                target=client_process,
+                name=f"client_process-{i}",
+                args=(
+                    can_start,
+                    benchmarking_failed,
+                    port,
+                    num_experts,
+                    batch_size,
+                    hid_dim,
+                    num_batches_per_client,
+                    backprop,
+                ),
+            )
+            for i in range(num_clients)
+        ]
 
         for client in clients:
             client.daemon = True
             client.start()
 
-        timestamps['launched_clients'] = timestamps['began_launching_server'] = time.perf_counter()
+        timestamps["launched_clients"] = timestamps["began_launching_server"] = time.perf_counter()
 
         # start server
-        device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
+        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
         experts = {}
         for i in range(num_experts):
             expert = torch.jit.script(layers.name_to_block[expert_cls](hid_dim))
-            experts[f'expert{i}'] = hivemind.ExpertBackend(name=f'expert{i}',
-                                                           expert=expert,
-                                                           optimizer=torch.optim.Adam(expert.parameters()),
-                                                           args_schema=(hivemind.BatchTensorDescriptor(hid_dim),),
-                                                           outputs_schema=hivemind.BatchTensorDescriptor(hid_dim),
-                                                           max_batch_size=max_batch_size,
-                                                           )
-        timestamps['created_experts'] = time.perf_counter()
-        server = hivemind.moe.Server(None, experts, listen_on=f"{hivemind.LOCALHOST}:{port}",
-                                     num_connection_handlers=num_handlers, device=device)
+            experts[f"expert{i}"] = hivemind.ExpertBackend(
+                name=f"expert{i}",
+                expert=expert,
+                optimizer=torch.optim.Adam(expert.parameters()),
+                args_schema=(hivemind.BatchTensorDescriptor(hid_dim),),
+                outputs_schema=hivemind.BatchTensorDescriptor(hid_dim),
+                max_batch_size=max_batch_size,
+            )
+        timestamps["created_experts"] = time.perf_counter()
+        server = hivemind.moe.Server(
+            None,
+            experts,
+            listen_on=f"{hivemind.LOCALHOST}:{port}",
+            num_connection_handlers=num_handlers,
+            device=device,
+        )
         server.start()
         server.ready.wait()
-        timestamps['server_ready'] = time.perf_counter()
+        timestamps["server_ready"] = time.perf_counter()
         can_start.set()
 
         for client in clients:
             client.join()
-        timestamps['clients_finished'] = time.perf_counter()
+        timestamps["clients_finished"] = time.perf_counter()
     except BaseException as e:
         benchmarking_failed.set()
         raise e
@@ -105,28 +137,39 @@ def benchmark_throughput(num_experts=16, num_handlers=None, num_clients=128, num
             if client.is_alive():
                 client.terminate()
         server.shutdown()
-        timestamps['server_shutdown_finished'] = time.perf_counter()
+        timestamps["server_shutdown_finished"] = time.perf_counter()
         server.join()
 
     sys.stdout.flush()
     sys.stderr.flush()
-    time_between = lambda key1, key2: \
-        abs(timestamps[key2] - timestamps[key1]) if (key1 in timestamps and key2 in timestamps) else float('nan')
+    time_between = (
+        lambda key1, key2: abs(timestamps[key2] - timestamps[key1])
+        if (key1 in timestamps and key2 in timestamps)
+        else float("nan")
+    )
     total_examples = batch_size * num_clients * num_batches_per_client
 
     logger.info("Benchmark finished, status:" + ["Success", "Failure"][benchmarking_failed.is_set()])
-    logger.info(f"Server parameters: num_experts={num_experts}, num_handlers={num_handlers}, "
-                f"max_batch_size={max_batch_size}, expert_cls={expert_cls}, hid_dim={hid_dim}, device={device}")
-    logger.info(f"Client parameters: num_clients={num_clients}, num_batches_per_client={num_batches_per_client}, "
-                f"batch_size={batch_size}, backprop={backprop}")
+    logger.info(
+        f"Server parameters: num_experts={num_experts}, num_handlers={num_handlers}, "
+        f"max_batch_size={max_batch_size}, expert_cls={expert_cls}, hid_dim={hid_dim}, device={device}"
+    )
+    logger.info(
+        f"Client parameters: num_clients={num_clients}, num_batches_per_client={num_batches_per_client}, "
+        f"batch_size={batch_size}, backprop={backprop}"
+    )
 
     logger.info("Results: ")
-    logger.info(f"\tServer startup took {time_between('began_launching_server', 'server_ready') :.3f} s. "
-                f"({time_between('began_launching_server', 'created_experts') :.3f} s. experts + "
-                f"{time_between('created_experts', 'server_ready') :.3f} s. networking)")
+    logger.info(
+        f"\tServer startup took {time_between('began_launching_server', 'server_ready') :.3f} s. "
+        f"({time_between('began_launching_server', 'created_experts') :.3f} s. experts + "
+        f"{time_between('created_experts', 'server_ready') :.3f} s. networking)"
+    )
     logger.info(f"\tProcessed {total_examples} examples in {time_between('server_ready', 'clients_finished') :.3f}")
-    logger.info(f"\tThroughput for {'forward + backward' if backprop else 'forward'} passes: "
-                f"{total_examples / time_between('server_ready', 'clients_finished') :.3f} samples / s.")
+    logger.info(
+        f"\tThroughput for {'forward + backward' if backprop else 'forward'} passes: "
+        f"{total_examples / time_between('server_ready', 'clients_finished') :.3f} samples / s."
+    )
     logger.info(f"\tBenchmarking took {time_between('started', 'server_shutdown_finished') :.3f} s.")
     if benchmarking_failed.is_set():
         logger.info("Note: benchmark code failed, timing/memory results only indicate time till failure!")
@@ -139,31 +182,55 @@ def benchmark_throughput(num_experts=16, num_handlers=None, num_clients=128, num
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--preset', type=str, default='default', required=False)
-    parser.add_argument('--num_batches_per_client', type=int, default=16, required=False)
+    parser.add_argument("--preset", type=str, default="default", required=False)
+    parser.add_argument("--num_batches_per_client", type=int, default=16, required=False)
     args = parser.parse_args()
 
-    if args.preset in ('default', 'ffn_forward_backward'):
+    if args.preset in ("default", "ffn_forward_backward"):
         benchmark_throughput()
-    elif args.preset == 'ffn_forward':
+    elif args.preset == "ffn_forward":
         benchmark_throughput(backprop=False, num_batches_per_client=args.num_batches_per_client)
-    elif args.preset == 'ffn_small_batch':
-        benchmark_throughput(backprop=False, num_experts=4, batch_size=32, max_batch_size=8192,
-                             num_batches_per_client=args.num_batches_per_client)
-    elif args.preset == 'ffn_small_batch_512clients':
-        benchmark_throughput(backprop=True, num_experts=1, batch_size=1, max_batch_size=8192,
-                             num_clients=512, num_batches_per_client=args.num_batches_per_client)
-    elif args.preset == 'ffn_small_batch_512clients_32handlers':
-        benchmark_throughput(backprop=True, num_experts=1, batch_size=1, max_batch_size=8192, num_handlers=32,
-                             num_clients=512, num_batches_per_client=args.num_batches_per_client)
-    elif args.preset == 'ffn_massive':
+    elif args.preset == "ffn_small_batch":
+        benchmark_throughput(
+            backprop=False,
+            num_experts=4,
+            batch_size=32,
+            max_batch_size=8192,
+            num_batches_per_client=args.num_batches_per_client,
+        )
+    elif args.preset == "ffn_small_batch_512clients":
+        benchmark_throughput(
+            backprop=True,
+            num_experts=1,
+            batch_size=1,
+            max_batch_size=8192,
+            num_clients=512,
+            num_batches_per_client=args.num_batches_per_client,
+        )
+    elif args.preset == "ffn_small_batch_512clients_32handlers":
+        benchmark_throughput(
+            backprop=True,
+            num_experts=1,
+            batch_size=1,
+            max_batch_size=8192,
+            num_handlers=32,
+            num_clients=512,
+            num_batches_per_client=args.num_batches_per_client,
+        )
+    elif args.preset == "ffn_massive":
         increase_file_limit()
-        benchmark_throughput(backprop=False, num_clients=512, batch_size=512,
-                             max_batch_size=8192, num_batches_per_client=args.num_batches_per_client)
-    elif args.preset == 'minimalistic':
-        benchmark_throughput(num_experts=1, num_clients=1, num_handlers=1,
-                             num_batches_per_client=args.num_batches_per_client)
-    elif args.preset == 'nop':
-        benchmark_throughput(expert_cls='nop', backprop=False, num_batches_per_client=args.num_batches_per_client)
+        benchmark_throughput(
+            backprop=False,
+            num_clients=512,
+            batch_size=512,
+            max_batch_size=8192,
+            num_batches_per_client=args.num_batches_per_client,
+        )
+    elif args.preset == "minimalistic":
+        benchmark_throughput(
+            num_experts=1, num_clients=1, num_handlers=1, num_batches_per_client=args.num_batches_per_client
+        )
+    elif args.preset == "nop":
+        benchmark_throughput(expert_cls="nop", backprop=False, num_batches_per_client=args.num_batches_per_client)
     else:
         raise ValueError(f"No such benchmark preset: {args.preset}")

+ 55 - 54
docs/conf.py

@@ -22,16 +22,16 @@ from recommonmark.parser import CommonMarkParser
 
 
 # -- Project information -----------------------------------------------------
-src_path = '../hivemind'
-project = 'hivemind'
-copyright = '2020, Learning@home & contributors'
-author = 'Learning@home & contributors'
+src_path = "../hivemind"
+project = "hivemind"
+copyright = "2020, Learning@home & contributors"
+author = "Learning@home & contributors"
 
 # The short X.Y version
-version = ''
+version = ""
 # The full version, including alpha/beta/rc tags
-release = 'latest'
-branch = 'master'
+release = "latest"
+branch = "master"
 
 
 # -- General configuration ---------------------------------------------------
@@ -44,31 +44,30 @@ branch = 'master'
 # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
 # ones.
 extensions = [
-    'sphinx.ext.autodoc',
-    'sphinx.ext.autosummary',
-    'sphinx.ext.doctest',
-    'sphinx.ext.mathjax',
-    'sphinx.ext.linkcode',  # link to github, see linkcode_resolve() below
-    'sphinx.ext.napoleon',  # alternative to numpydoc
+    "sphinx.ext.autodoc",
+    "sphinx.ext.autosummary",
+    "sphinx.ext.doctest",
+    "sphinx.ext.mathjax",
+    "sphinx.ext.linkcode",  # link to github, see linkcode_resolve() below
+    "sphinx.ext.napoleon",  # alternative to numpydoc
 ]
 
 # see http://stackoverflow.com/q/12206334/562769
 numpydoc_show_class_members = False
 
-mathjax_path = ('https://cdn.mathjax.org/mathjax/latest/MathJax.js?'
-                'config=TeX-AMS-MML_HTMLorMML')
+mathjax_path = "https://cdn.mathjax.org/mathjax/latest/MathJax.js?" "config=TeX-AMS-MML_HTMLorMML"
 
 
 # Add any paths that contain templates here, relative to this directory.
-templates_path = ['_templates']
+templates_path = ["_templates"]
 
 # The suffix(es) of source filenames.
 # You can specify multiple suffix as a list of string:
 #
-source_suffix = {'.rst': 'restructuredtext', '.md': 'markdown'}
+source_suffix = {".rst": "restructuredtext", ".md": "markdown"}
 
 # The master toctree document.
-master_doc = 'index'
+master_doc = "index"
 
 # The language for content autogenerated by Sphinx. Refer to documentation
 # for a list of supported languages.
@@ -80,10 +79,10 @@ language = None
 # List of patterns, relative to source directory, that match files and
 # directories to ignore when looking for source files.
 # This pattern also affects html_static_path and html_extra_path.
-exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
+exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
 
 # The name of the Pygments (syntax highlighting) style to use.
-pygments_style = 'sphinx'
+pygments_style = "sphinx"
 
 
 # -- Options for HTML output -------------------------------------------------
@@ -91,22 +90,20 @@ pygments_style = 'sphinx'
 # The theme to use for HTML and HTML Help pages.  See the documentation for
 # a list of builtin themes.
 #
-html_theme = 'sphinx_rtd_theme'
+html_theme = "sphinx_rtd_theme"
 
 # Theme options are theme-specific and customize the look and feel of a theme
 # further.  For a list of options available for each theme, see the
 # documentation.
 #
-html_theme_options = {
-    "collapse_navigation": False
-}
+html_theme_options = {"collapse_navigation": False}
 
-html_favicon = '_static/favicon.png'
+html_favicon = "_static/favicon.png"
 
 # Add any paths that contain custom static files (such as style sheets) here,
 # relative to this directory. They are copied after the builtin static files,
 # so a file named "default.css" will overwrite the builtin "default.css".
-html_static_path = ['_static']
+html_static_path = ["_static"]
 
 # Custom sidebar templates, must be a dictionary that maps document names
 # to template names.
@@ -122,7 +119,7 @@ html_static_path = ['_static']
 # -- Options for HTMLHelp output ---------------------------------------------
 
 # Output file base name for HTML help builder.
-htmlhelp_basename = 'hiveminddoc'
+htmlhelp_basename = "hiveminddoc"
 
 
 # -- Options for LaTeX output ------------------------------------------------
@@ -131,15 +128,12 @@ latex_elements = {
     # The paper size ('letterpaper' or 'a4paper').
     #
     # 'papersize': 'letterpaper',
-
     # The font size ('10pt', '11pt' or '12pt').
     #
     # 'pointsize': '10pt',
-
     # Additional stuff for the LaTeX preamble.
     #
     # 'preamble': '',
-
     # Latex figure (float) alignment
     #
     # 'figure_align': 'htbp',
@@ -149,8 +143,7 @@ latex_elements = {
 # (source start file, target name, title,
 #  author, documentclass [howto, manual, or own class]).
 latex_documents = [
-    (master_doc, 'hivemind.tex', 'hivemind Documentation',
-     'Learning@home \\& contributors', 'manual'),
+    (master_doc, "hivemind.tex", "hivemind Documentation", "Learning@home \\& contributors", "manual"),
 ]
 
 
@@ -158,10 +151,7 @@ latex_documents = [
 
 # One entry per manual page. List of tuples
 # (source start file, name, description, authors, manual section).
-man_pages = [
-    (master_doc, 'hivemind', 'hivemind Documentation',
-     [author], 1)
-]
+man_pages = [(master_doc, "hivemind", "hivemind Documentation", [author], 1)]
 
 
 # -- Options for Texinfo output ----------------------------------------------
@@ -170,9 +160,15 @@ man_pages = [
 # (source start file, target name, title, author,
 #  dir menu entry, description, category)
 texinfo_documents = [
-    (master_doc, 'hivemind', 'hivemind Documentation',
-     author, 'hivemind', 'One line description of project.',
-     'Miscellaneous'),
+    (
+        master_doc,
+        "hivemind",
+        "hivemind Documentation",
+        author,
+        "hivemind",
+        "One line description of project.",
+        "Miscellaneous",
+    ),
 ]
 
 
@@ -191,7 +187,7 @@ epub_title = project
 # epub_uid = ''
 
 # A list of files that should not be packed into the epub file.
-epub_exclude_files = ['search.html']
+epub_exclude_files = ["search.html"]
 
 
 # -- Extension configuration -------------------------------------------------
@@ -199,7 +195,7 @@ epub_exclude_files = ['search.html']
 # -- Options for intersphinx extension ---------------------------------------
 
 # Example configuration for intersphinx: refer to the Python standard library.
-intersphinx_mapping = {'https://docs.python.org/': None}
+intersphinx_mapping = {"https://docs.python.org/": None}
 
 # -- Options for todo extension ----------------------------------------------
 
@@ -209,14 +205,18 @@ todo_include_todos = True
 
 def setup(app):
     app.add_stylesheet("fix_rtd.css")
-    app.add_config_value('recommonmark_config', {
-        'auto_toc_tree_section': 'Contents',
-        'enable_math': True,
-        'enable_inline_math': True,
-        'enable_eval_rst': True,
-    }, True)
+    app.add_config_value(
+        "recommonmark_config",
+        {
+            "auto_toc_tree_section": "Contents",
+            "enable_math": True,
+            "enable_inline_math": True,
+            "enable_eval_rst": True,
+        },
+        True,
+    )
     app.add_transform(AutoStructify)
-    app.add_source_suffix('.md', 'markdown')
+    app.add_source_suffix(".md", "markdown")
     app.add_source_parser(CommonMarkParser)
 
 
@@ -227,22 +227,23 @@ def linkcode_resolve(domain, info):
     def find_source():
         # try to find the file and line number, based on code from numpy:
         # https://github.com/numpy/numpy/blob/master/doc/source/conf.py#L286
-        obj = sys.modules[info['module']]
-        for part in info['fullname'].split('.'):
+        obj = sys.modules[info["module"]]
+        for part in info["fullname"].split("."):
             obj = getattr(obj, part)
         import inspect
         import os
+
         fn = inspect.getsourcefile(obj)
         fn = os.path.relpath(fn, start=os.path.dirname(src_path))
         source, lineno = inspect.getsourcelines(obj)
         return fn, lineno, lineno + len(source) - 1
 
-    if domain != 'py' or not info['module']:
+    if domain != "py" or not info["module"]:
         return None
     try:
-        filename = '%s#L%d-L%d' % find_source()
+        filename = "%s#L%d-L%d" % find_source()
     except Exception:
-        filename = info['module'].replace('.', '/') + '.py'
+        filename = info["module"].replace(".", "/") + ".py"
 
-    relative_filename = filename[filename.rindex('hivemind'):]
+    relative_filename = filename[filename.rindex("hivemind") :]
     return "https://github.com/learning-at-home/hivemind/blob/%s/%s" % (branch, relative_filename)

+ 39 - 57
examples/albert/arguments.py

@@ -11,74 +11,65 @@ class BaseTrainingArguments:
     )
     initial_peers: List[str] = field(
         default_factory=list,
-        metadata={"help":
-            "Multiaddrs of the peers that will welcome you into the existing collaboration. "
-            "Example: /ip4/203.0.113.1/tcp/31337/p2p/XXXX /ip4/203.0.113.2/udp/7777/quic/p2p/YYYY"}
+        metadata={
+            "help": "Multiaddrs of the peers that will welcome you into the existing collaboration. "
+            "Example: /ip4/203.0.113.1/tcp/31337/p2p/XXXX /ip4/203.0.113.2/udp/7777/quic/p2p/YYYY"
+        },
     )
     use_ipfs: bool = field(
         default=False,
-        metadata={"help":
-            "Use IPFS to find initial_peers. If enabled, you only need to provide /p2p/XXXX part of the multiaddrs "
-            "for the initial_peers (no need to specify a particular IPv4/IPv6 host and port)"}
+        metadata={
+            "help": "Use IPFS to find initial_peers. If enabled, you only need to provide /p2p/XXXX part of the multiaddrs "
+            "for the initial_peers (no need to specify a particular IPv4/IPv6 host and port)"
+        },
     )
     host_maddrs: List[str] = field(
-        default_factory=lambda: ['/ip4/0.0.0.0/tcp/0', '/ip4/0.0.0.0/udp/0/quic'],
-        metadata={"help":
-            "Multiaddrs to listen for external connections from other p2p instances. "
+        default_factory=lambda: ["/ip4/0.0.0.0/tcp/0", "/ip4/0.0.0.0/udp/0/quic"],
+        metadata={
+            "help": "Multiaddrs to listen for external connections from other p2p instances. "
             "Defaults to all IPv4 interfaces with TCP and QUIC (over UDP) protocols: "
-            "/ip4/0.0.0.0/tcp/0 /ip4/0.0.0.0/udp/0/quic"}
+            "/ip4/0.0.0.0/tcp/0 /ip4/0.0.0.0/udp/0/quic"
+        },
     )
     announce_maddrs: List[str] = field(
         default_factory=list,
-        metadata={"help": "Visible multiaddrs the host announces for external connections from other p2p instances"}
+        metadata={"help": "Visible multiaddrs the host announces for external connections from other p2p instances"},
     )
 
 
 @dataclass
 class AveragerArguments:
     averaging_expiration: float = field(
-        default=5.0,
-        metadata={"help": "Averaging group will wait for stragglers for at most this many seconds"}
+        default=5.0, metadata={"help": "Averaging group will wait for stragglers for at most this many seconds"}
     )
     averaging_timeout: float = field(
-        default=30.0,
-        metadata={"help": "Give up on averaging step after this many seconds"}
+        default=30.0, metadata={"help": "Give up on averaging step after this many seconds"}
     )
     listen_on: str = field(
         default="[::]:*",
-        metadata={"help": "Network interface used for incoming averager communication. Default: all ipv6"}
+        metadata={"help": "Network interface used for incoming averager communication. Default: all ipv6"},
     )
     min_refresh_period: float = field(
-        default=0.5,
-        metadata={"help": "Wait for at least this many seconds before fetching new collaboration state"}
+        default=0.5, metadata={"help": "Wait for at least this many seconds before fetching new collaboration state"}
     )
     max_refresh_period: float = field(
-        default=30,
-        metadata={"help": "Wait for at most this many seconds before fetching new collaboration state"}
+        default=30, metadata={"help": "Wait for at most this many seconds before fetching new collaboration state"}
     )
     default_refresh_period: float = field(
-        default=3,
-        metadata={"help": "Attempt to fetch collaboration state every this often until successful"}
+        default=3, metadata={"help": "Attempt to fetch collaboration state every this often until successful"}
     )
     expected_drift_peers: float = field(
-        default=3,
-        metadata={"help": "Trainer assumes that this many new peers can join per step"}
+        default=3, metadata={"help": "Trainer assumes that this many new peers can join per step"}
     )
     expected_drift_rate: float = field(
-        default=0.2,
-        metadata={"help": "Trainer assumes that this fraction of current size can join per step"}
+        default=0.2, metadata={"help": "Trainer assumes that this fraction of current size can join per step"}
     )
     performance_ema_alpha: float = field(
-        default=0.1,
-        metadata={"help": "Uses this alpha for moving average estimate of samples per second"}
-    )
-    target_group_size: int = field(
-        default=256,
-        metadata={"help": "Maximum group size for all-reduce"}
+        default=0.1, metadata={"help": "Uses this alpha for moving average estimate of samples per second"}
     )
+    target_group_size: int = field(default=256, metadata={"help": "Maximum group size for all-reduce"})
     metadata_expiration: float = field(
-        default=30,
-        metadata={"help": "Peer's metadata will be removed if not updated in this many seconds"}
+        default=30, metadata={"help": "Peer's metadata will be removed if not updated in this many seconds"}
     )
 
 
@@ -86,52 +77,43 @@ class AveragerArguments:
 class CollaborativeOptimizerArguments:
     target_batch_size: int = field(
         default=4096,
-        metadata={"help": "Perform optimizer step after all peers collectively accumulate this many samples"}
+        metadata={"help": "Perform optimizer step after all peers collectively accumulate this many samples"},
     )
     client_mode: bool = field(
         default=False,
-        metadata={"help": "Of True, runs training without incoming connections, in a firewall-compatible mode"}
+        metadata={"help": "Of True, runs training without incoming connections, in a firewall-compatible mode"},
     )
     batch_size_lead: int = field(
         default=0,
-        metadata={"help": "Optional: begin looking for group in advance, this many samples before target_batch_size"}
+        metadata={"help": "Optional: begin looking for group in advance, this many samples before target_batch_size"},
     )
     bandwidth: float = field(
         default=100.0,
-        metadata={"help": "Available network bandwidth, in mbps (used for load balancing in all-reduce)"}
+        metadata={"help": "Available network bandwidth, in mbps (used for load balancing in all-reduce)"},
     )
     compression: str = field(
-        default="FLOAT16",
-        metadata={"help": "Use this compression when averaging parameters/gradients"}
+        default="FLOAT16", metadata={"help": "Use this compression when averaging parameters/gradients"}
     )
 
 
 @dataclass
 class CollaborationArguments(AveragerArguments, CollaborativeOptimizerArguments, BaseTrainingArguments):
     statistics_expiration: float = field(
-        default=600,
-        metadata={"help": "Statistics will be removed if not updated in this many seconds"}
+        default=600, metadata={"help": "Statistics will be removed if not updated in this many seconds"}
     )
 
 
 @dataclass
 class DatasetArguments:
     dataset_path: Optional[str] = field(
-        default='data/albert_tokenized_wikitext',
-        metadata={"help": "Path to the tokenized dataset"}
-    )
-    tokenizer_path: Optional[str] = field(
-        default='data/tokenizer',
-        metadata={"help": "Path to the tokenizer"}
+        default="data/albert_tokenized_wikitext", metadata={"help": "Path to the tokenized dataset"}
     )
+    tokenizer_path: Optional[str] = field(default="data/tokenizer", metadata={"help": "Path to the tokenizer"})
     config_path: Optional[str] = field(
-        default='https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-config.json',
-        metadata={"help": "Path to the model config"}
-    )
-    cache_dir: Optional[str] = field(
-        default='data',
-        metadata={"help": "Path to the cache"}
+        default="https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-config.json",
+        metadata={"help": "Path to the model config"},
     )
+    cache_dir: Optional[str] = field(default="data", metadata={"help": "Path to the cache"})
 
 
 @dataclass
@@ -142,7 +124,7 @@ class AlbertTrainingArguments(TrainingArguments):
     gradient_accumulation_steps: int = 2
     seq_length: int = 512
 
-    max_steps: int = 125_000 # please note: this affects both number of steps and learning rate schedule
+    max_steps: int = 125_000  # please note: this affects both number of steps and learning rate schedule
     learning_rate: float = 0.00176
     warmup_steps: int = 5000
     adam_epsilon: float = 1e-6
@@ -151,11 +133,11 @@ class AlbertTrainingArguments(TrainingArguments):
     clamp_value: float = 10000.0
 
     fp16: bool = True
-    fp16_opt_level: str = 'O2'
+    fp16_opt_level: str = "O2"
     do_train: bool = True
 
     logging_steps: int = 100
     save_total_limit: int = 2
     save_steps: int = 500
 
-    output_dir: str = 'outputs'
+    output_dir: str = "outputs"

+ 80 - 55
examples/albert/run_trainer.py

@@ -10,8 +10,15 @@ import torch
 import transformers
 from datasets import load_from_disk
 from torch.utils.data import DataLoader
-from transformers import (set_seed, HfArgumentParser, TrainingArguments,
-                          DataCollatorForLanguageModeling, AlbertTokenizerFast, AlbertConfig, AlbertForPreTraining)
+from transformers import (
+    set_seed,
+    HfArgumentParser,
+    TrainingArguments,
+    DataCollatorForLanguageModeling,
+    AlbertTokenizerFast,
+    AlbertConfig,
+    AlbertForPreTraining,
+)
 from transformers.optimization import get_linear_schedule_with_warmup
 from transformers.trainer_utils import is_main_process
 from transformers.trainer import Trainer
@@ -23,7 +30,7 @@ from arguments import CollaborationArguments, DatasetArguments, AlbertTrainingAr
 
 
 logger = logging.getLogger(__name__)
-LRSchedulerBase = getattr(torch.optim.lr_scheduler, '_LRScheduler', None)
+LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
 
 
 def setup_logging(training_args):
@@ -50,13 +57,13 @@ def get_model(training_args, config, tokenizer):
     # Find latest checkpoint in output_dir
     output_dir = Path(training_args.output_dir)
     logger.info(f'Checkpoint dir {output_dir}, contents {list(output_dir.glob("checkpoint*"))}')
-    latest_checkpoint_dir = max(output_dir.glob('checkpoint*'), default=None, key=os.path.getctime)
+    latest_checkpoint_dir = max(output_dir.glob("checkpoint*"), default=None, key=os.path.getctime)
 
     if latest_checkpoint_dir is not None:
-        logger.info(f'Loading model from {latest_checkpoint_dir}')
+        logger.info(f"Loading model from {latest_checkpoint_dir}")
         model = AlbertForPreTraining.from_pretrained(latest_checkpoint_dir)
     else:
-        logger.info(f'Training from scratch')
+        logger.info(f"Training from scratch")
         model = AlbertForPreTraining(config)
         model.resize_token_embeddings(len(tokenizer))
 
@@ -87,17 +94,21 @@ def get_optimizer_and_scheduler(training_args, model):
     )
 
     scheduler = get_linear_schedule_with_warmup(
-        opt,
-        num_warmup_steps=training_args.warmup_steps,
-        num_training_steps=training_args.max_steps
+        opt, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.max_steps
     )
 
     return opt, scheduler
 
 
 class CollaborativeCallback(transformers.TrainerCallback):
-    def __init__(self, dht: hivemind.DHT, optimizer: hivemind.CollaborativeOptimizer,
-                 model: torch.nn.Module, local_public_key: bytes, statistics_expiration: float):
+    def __init__(
+        self,
+        dht: hivemind.DHT,
+        optimizer: hivemind.CollaborativeOptimizer,
+        model: torch.nn.Module,
+        local_public_key: bytes,
+        statistics_expiration: float,
+    ):
         super().__init__()
         self.model = model
         self.dht, self.collaborative_optimizer = dht, optimizer
@@ -110,13 +121,15 @@ class CollaborativeCallback(transformers.TrainerCallback):
         self.loss = 0
         self.total_samples_processed = 0
 
-    def on_train_begin(self, args: TrainingArguments, state: transformers.TrainerState,
-                       control: transformers.TrainerControl, **kwargs):
-        logger.info('Loading state from peers')
+    def on_train_begin(
+        self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs
+    ):
+        logger.info("Loading state from peers")
         self.collaborative_optimizer.load_state_from_peers()
 
-    def on_step_end(self, args: TrainingArguments, state: transformers.TrainerState,
-                    control: transformers.TrainerControl, **kwargs):
+    def on_step_end(
+        self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs
+    ):
         control.should_log = True
         if not self.params_are_finite():
             self.load_from_state(self.previous_state)
@@ -124,7 +137,7 @@ class CollaborativeCallback(transformers.TrainerCallback):
         self.previous_state = self.get_current_state()
 
         if state.log_history:
-            self.loss += state.log_history[-1]['loss']
+            self.loss += state.log_history[-1]["loss"]
             self.steps += 1
             if self.collaborative_optimizer.local_step != self.last_reported_collaboration_step:
                 self.last_reported_collaboration_step = self.collaborative_optimizer.local_step
@@ -135,7 +148,8 @@ class CollaborativeCallback(transformers.TrainerCallback):
                     samples_per_second=samples_per_second,
                     samples_accumulated=self.samples,
                     loss=self.loss,
-                    mini_steps=self.steps)
+                    mini_steps=self.steps,
+                )
                 logger.info(f"Step {self.collaborative_optimizer.local_step}")
                 logger.info(f"Your current contribution: {self.total_samples_processed} samples")
                 if self.steps:
@@ -144,10 +158,13 @@ class CollaborativeCallback(transformers.TrainerCallback):
                 self.loss = 0
                 self.steps = 0
                 if self.collaborative_optimizer.is_synchronized:
-                    self.dht.store(key=self.collaborative_optimizer.prefix + "_metrics",
-                                   subkey=self.local_public_key, value=statistics.dict(),
-                                   expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
-                                   return_future=True)
+                    self.dht.store(
+                        key=self.collaborative_optimizer.prefix + "_metrics",
+                        subkey=self.local_public_key,
+                        value=statistics.dict(),
+                        expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
+                        return_future=True,
+                    )
 
         self.samples = self.collaborative_optimizer.local_samples_accumulated
 
@@ -155,15 +172,12 @@ class CollaborativeCallback(transformers.TrainerCallback):
 
     @torch.no_grad()
     def get_current_state(self) -> Dict[str, Any]:
-        return {
-            'model': self.model.state_dict(),
-            'opt': self.collaborative_optimizer.opt.state_dict()
-        }
+        return {"model": self.model.state_dict(), "opt": self.collaborative_optimizer.opt.state_dict()}
 
     @torch.no_grad()
     def load_from_state(self, state):
-        self.model.load_state_dict(state['model'])
-        self.collaborative_optimizer.opt.load_state_dict(state['opt'])
+        self.model.load_state_dict(state["model"])
+        self.collaborative_optimizer.opt.load_state_dict(state["opt"])
 
     @torch.no_grad()
     def params_are_finite(self):
@@ -174,10 +188,10 @@ class CollaborativeCallback(transformers.TrainerCallback):
 
 
 class NoOpScheduler(LRSchedulerBase):
-    """ Dummy scheduler for transformers.Trainer. The real scheduler is defined in CollaborativeOptimizer.scheduler """
+    """Dummy scheduler for transformers.Trainer. The real scheduler is defined in CollaborativeOptimizer.scheduler"""
 
     def get_lr(self):
-        return [group['lr'] for group in self.optimizer.param_groups]
+        return [group["lr"] for group in self.optimizer.param_groups]
 
     def print_lr(self, *args, **kwargs):
         if self.optimizer.scheduler:
@@ -219,46 +233,59 @@ def main():
 
     opt, scheduler = get_optimizer_and_scheduler(training_args, model)
 
-    validators, local_public_key = utils.make_validators(
-        collaboration_args_dict['experiment_prefix'])
-    dht = hivemind.DHT(start=True,
-                       initial_peers=collaboration_args_dict.pop('initial_peers'),
-                       listen=not collaboration_args_dict['client_mode'],
-                       record_validators=validators,
-                       use_ipfs=collaboration_args_dict['use_ipfs'],
-                       host_maddrs=collaboration_args_dict.pop('host_maddrs'),
-                       announce_maddrs=collaboration_args_dict.pop('announce_maddrs'))
-    utils.log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=collaboration_args_dict.pop('use_ipfs'))
+    validators, local_public_key = utils.make_validators(collaboration_args_dict["experiment_prefix"])
+    dht = hivemind.DHT(
+        start=True,
+        initial_peers=collaboration_args_dict.pop("initial_peers"),
+        listen=not collaboration_args_dict["client_mode"],
+        record_validators=validators,
+        use_ipfs=collaboration_args_dict["use_ipfs"],
+        host_maddrs=collaboration_args_dict.pop("host_maddrs"),
+        announce_maddrs=collaboration_args_dict.pop("announce_maddrs"),
+    )
+    utils.log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=collaboration_args_dict.pop("use_ipfs"))
 
     total_batch_size_per_step = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
     if torch.cuda.device_count() != 0:
         total_batch_size_per_step *= torch.cuda.device_count()
 
-    statistics_expiration = collaboration_args_dict.pop('statistics_expiration')
-    adjusted_target_batch_size = collaboration_args_dict.pop('target_batch_size') \
-                                 - collaboration_args_dict.pop('batch_size_lead')
+    statistics_expiration = collaboration_args_dict.pop("statistics_expiration")
+    adjusted_target_batch_size = collaboration_args_dict.pop("target_batch_size") - collaboration_args_dict.pop(
+        "batch_size_lead"
+    )
 
     collaborative_optimizer = hivemind.CollaborativeOptimizer(
-        opt=opt, dht=dht, scheduler=scheduler, prefix=collaboration_args_dict.pop('experiment_prefix'),
-        compression_type=hivemind.utils.CompressionType.Value(collaboration_args_dict.pop('compression')),
-        batch_size_per_step=total_batch_size_per_step, throughput=collaboration_args_dict.pop('bandwidth'),
-        target_batch_size=adjusted_target_batch_size, client_mode=collaboration_args_dict.pop('client_mode'),
-        verbose=True, start=True, **collaboration_args_dict
+        opt=opt,
+        dht=dht,
+        scheduler=scheduler,
+        prefix=collaboration_args_dict.pop("experiment_prefix"),
+        compression_type=hivemind.utils.CompressionType.Value(collaboration_args_dict.pop("compression")),
+        batch_size_per_step=total_batch_size_per_step,
+        throughput=collaboration_args_dict.pop("bandwidth"),
+        target_batch_size=adjusted_target_batch_size,
+        client_mode=collaboration_args_dict.pop("client_mode"),
+        verbose=True,
+        start=True,
+        **collaboration_args_dict,
     )
 
     class TrainerWithIndependentShuffling(Trainer):
         def get_train_dataloader(self) -> DataLoader:
-            """ Shuffle data independently for each peer to avoid duplicating batches [important for quality] """
+            """Shuffle data independently for each peer to avoid duplicating batches [important for quality]"""
             torch.manual_seed(hash(local_public_key))
             return super().get_train_dataloader()
 
     trainer = TrainerWithIndependentShuffling(
-        model=model, args=training_args, tokenizer=tokenizer, data_collator=data_collator,
+        model=model,
+        args=training_args,
+        tokenizer=tokenizer,
+        data_collator=data_collator,
         train_dataset=tokenized_datasets["train"] if training_args.do_train else None,
         eval_dataset=tokenized_datasets["validation"] if training_args.do_eval else None,
         optimizers=(collaborative_optimizer, NoOpScheduler(collaborative_optimizer)),
-        callbacks=[CollaborativeCallback(
-            dht, collaborative_optimizer, model, local_public_key, statistics_expiration)]
+        callbacks=[
+            CollaborativeCallback(dht, collaborative_optimizer, model, local_public_key, statistics_expiration)
+        ],
     )
     trainer.remove_callback(transformers.trainer_callback.PrinterCallback)
     trainer.remove_callback(transformers.trainer_callback.ProgressCallback)
@@ -266,9 +293,7 @@ def main():
     # Training
     if training_args.do_train:
         latest_checkpoint_dir = max(
-            Path(training_args.output_dir).glob('checkpoint*'),
-            default=None,
-            key=os.path.getctime
+            Path(training_args.output_dir).glob("checkpoint*"), default=None, key=os.path.getctime
         )
 
         trainer.train(model_path=latest_checkpoint_dir)

+ 62 - 48
examples/albert/run_training_monitor.py

@@ -28,48 +28,48 @@ class CoordinatorArguments(BaseTrainingArguments):
     new workers still can join the collaboration via alive initial peers' addresses.
     Specify initial_peers argument for that purpose
     """
+
     use_google_dns: bool = field(
         default=False,
-        metadata={"help":
-            "Use Google DNS to determine the public IP address of this machine (and add it to --announce_maddrs)"}
+        metadata={
+            "help": "Use Google DNS to determine the public IP address of this machine (and add it to --announce_maddrs)"
+        },
     )
     refresh_period: float = field(
-        default=30,
-        metadata={"help": "Coordinator will fetch keys from DHT once in this many seconds"}
-    )
-    wandb_project: Optional[str] = field(
-        default=None,
-        metadata={"help": "Learning curves will be published there"}
+        default=30, metadata={"help": "Coordinator will fetch keys from DHT once in this many seconds"}
     )
+    wandb_project: Optional[str] = field(default=None, metadata={"help": "Learning curves will be published there"})
     save_checkpoint_step_interval: int = field(
-        default=5,
-        metadata={"help": "Coordinator will load and save state from peers once every that many steps"}
+        default=5, metadata={"help": "Coordinator will load and save state from peers once every that many steps"}
     )
     model_config_path: str = field(
-        default='https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-config.json',
-        metadata={"help": "Path to the model config"}
+        default="https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-config.json",
+        metadata={"help": "Path to the model config"},
     )
     repo_path: Optional[str] = field(
         default=None,
-        metadata={"help": "Path to HuggingFace repo in which coordinator will upload the model and optimizer states"}
+        metadata={"help": "Path to HuggingFace repo in which coordinator will upload the model and optimizer states"},
     )
     repo_url: Optional[str] = field(
         default=None,
-        metadata={"help": "URL to Hugging Face repository to which the coordinator will upload the model and optimizer states"}
+        metadata={
+            "help": "URL to Hugging Face repository to which the coordinator will upload the model and optimizer states"
+        },
     )
     upload_interval: Optional[float] = field(
-        default=None,
-        metadata={"help": "Coordinator will upload model once in this many seconds"}
-    )
-    store_checkpoins: bool = field(
-        default=False,
-        metadata={"help": "If True, enables CheckpointHandler"}
+        default=None, metadata={"help": "Coordinator will upload model once in this many seconds"}
     )
+    store_checkpoins: bool = field(default=False, metadata={"help": "If True, enables CheckpointHandler"})
 
 
 class CheckpointHandler:
-    def __init__(self, coordinator_args: CoordinatorArguments, collab_optimizer_args: CollaborativeOptimizerArguments,
-                 averager_args: AveragerArguments, dht: hivemind.DHT):
+    def __init__(
+        self,
+        coordinator_args: CoordinatorArguments,
+        collab_optimizer_args: CollaborativeOptimizerArguments,
+        averager_args: AveragerArguments,
+        dht: hivemind.DHT,
+    ):
         self.save_checkpoint_step_interval = coordinator_args.save_checkpoint_step_interval
         self.repo_path = coordinator_args.repo_path
         self.repo_url = coordinator_args.repo_url
@@ -93,17 +93,25 @@ class CheckpointHandler:
 
         opt = Lamb(
             optimizer_grouped_parameters,
-            lr=0.00176, weight_decay=0.01, clamp_value=10000.0, debias=True,
+            lr=0.00176,
+            weight_decay=0.01,
+            clamp_value=10000.0,
+            debias=True,
         )
 
         adjusted_target_batch_size = collab_optimizer_args.target_batch_size - collab_optimizer_args.batch_size_lead
 
         self.collaborative_optimizer = hivemind.CollaborativeOptimizer(
-            opt=opt, dht=dht, prefix=experiment_prefix,
+            opt=opt,
+            dht=dht,
+            prefix=experiment_prefix,
             compression_type=hivemind.utils.CompressionType.Value(collab_optimizer_args.compression),
             throughput=collab_optimizer_args.bandwidth,
-            target_batch_size=adjusted_target_batch_size, client_mode=collab_optimizer_args.client_mode,
-            verbose=True, start=True, **asdict(averager_args)
+            target_batch_size=adjusted_target_batch_size,
+            client_mode=collab_optimizer_args.client_mode,
+            verbose=True,
+            start=True,
+            **asdict(averager_args),
         )
         self.previous_timestamp = time.time()
 
@@ -132,13 +140,16 @@ class CheckpointHandler:
         logger.info("Saving optimizer")
         torch.save(self.collaborative_optimizer.opt.state_dict(), f"{self.repo_path}/optimizer_state.pt")
         self.previous_timestamp = time.time()
-        logger.info('Started uploading model to Hub')
-        self.model.push_to_hub(repo_name=self.repo_path, repo_url=self.repo_url,
-                               commit_message=f'Step {current_step}, loss {current_loss:.3f}')
-        logger.info('Finished uploading model to Hub')
+        logger.info("Started uploading model to Hub")
+        self.model.push_to_hub(
+            repo_name=self.repo_path,
+            repo_url=self.repo_url,
+            commit_message=f"Step {current_step}, loss {current_loss:.3f}",
+        )
+        logger.info("Finished uploading model to Hub")
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = HfArgumentParser((CoordinatorArguments, CollaborativeOptimizerArguments, AveragerArguments))
     coordinator_args, collab_optimizer_args, averager_args = parser.parse_args_into_dataclasses()
 
@@ -146,16 +157,18 @@ if __name__ == '__main__':
         address = get_ip(GoogleDnsProvider)
         logger.info(f"Received public IP address of this machine from Google DNS: {address}")
         version = ip_address(address).version
-        coordinator_args.announce_maddrs += [f'/ip{version}/{address}/tcp/0', f'/ip{version}/{address}/udp/0/quic']
+        coordinator_args.announce_maddrs += [f"/ip{version}/{address}/tcp/0", f"/ip{version}/{address}/udp/0/quic"]
 
     experiment_prefix = coordinator_args.experiment_prefix
     validators, local_public_key = utils.make_validators(experiment_prefix)
-    dht = hivemind.DHT(start=True,
-                       initial_peers=coordinator_args.initial_peers,
-                       record_validators=validators,
-                       use_ipfs=coordinator_args.use_ipfs,
-                       host_maddrs=coordinator_args.host_maddrs,
-                       announce_maddrs=coordinator_args.announce_maddrs)
+    dht = hivemind.DHT(
+        start=True,
+        initial_peers=coordinator_args.initial_peers,
+        record_validators=validators,
+        use_ipfs=coordinator_args.use_ipfs,
+        host_maddrs=coordinator_args.host_maddrs,
+        announce_maddrs=coordinator_args.announce_maddrs,
+    )
     utils.log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=coordinator_args.use_ipfs)
 
     if coordinator_args.wandb_project is not None:
@@ -166,11 +179,10 @@ if __name__ == '__main__':
         checkpoint_handler = CheckpointHandler(coordinator_args, collab_optimizer_args, averager_args, dht)
 
     while True:
-        metrics_dict = dht.get(experiment_prefix + '_metrics', latest=True)
+        metrics_dict = dht.get(experiment_prefix + "_metrics", latest=True)
         if metrics_dict is not None:
             metrics_dict = metrics_dict.value
-            metrics = [utils.LocalMetrics.parse_obj(metrics_dict[peer].value)
-                       for peer in metrics_dict]
+            metrics = [utils.LocalMetrics.parse_obj(metrics_dict[peer].value) for peer in metrics_dict]
             latest_step = max(item.step for item in metrics)
             if latest_step != current_step:
                 logger.debug(f"Got metrics from {len(metrics)} peers")
@@ -194,13 +206,15 @@ if __name__ == '__main__':
                 logger.info(f"Step #{current_step}\tloss = {current_loss:.5f}")
 
                 if coordinator_args.wandb_project is not None:
-                    wandb.log({
-                        "loss": current_loss,
-                        "alive peers": alive_peers,
-                        "samples": num_samples,
-                        "performance": sum_perf,
-                        "step": latest_step
-                    })
+                    wandb.log(
+                        {
+                            "loss": current_loss,
+                            "alive peers": alive_peers,
+                            "samples": num_samples,
+                            "performance": sum_perf,
+                            "step": latest_step,
+                        }
+                    )
                 if coordinator_args.store_checkpoins:
                     if checkpoint_handler.is_time_to_save_state(current_step):
                         checkpoint_handler.save_state(current_step)

+ 12 - 12
examples/albert/tokenize_wikitext103.py

@@ -9,7 +9,7 @@ from datasets import load_dataset
 from transformers import AlbertTokenizerFast
 
 
-COLUMN_NAMES = ('attention_mask', 'input_ids', 'sentence_order_label', 'special_tokens_mask', 'token_type_ids')
+COLUMN_NAMES = ("attention_mask", "input_ids", "sentence_order_label", "special_tokens_mask", "token_type_ids")
 
 
 def create_instances_from_document(tokenizer, document, max_seq_length):
@@ -56,15 +56,15 @@ def create_instances_from_document(tokenizer, document, max_seq_length):
                 assert len(tokens_b) >= 1
 
                 instance = tokenizer(
-                    ' '.join(tokens_a),
-                    ' '.join(tokens_b),
-                    truncation='longest_first',
+                    " ".join(tokens_a),
+                    " ".join(tokens_b),
+                    truncation="longest_first",
                     max_length=max_seq_length,
                     # We use this option because DataCollatorForLanguageModeling
                     # is more efficient when it receives the `special_tokens_mask`.
                     return_special_tokens_mask=True,
                 )
-                assert len(instance['input_ids']) <= max_seq_length
+                assert len(instance["input_ids"]) <= max_seq_length
                 instance["sentence_order_label"] = 1 if is_random_next else 0
                 instances.append(instance)
 
@@ -85,15 +85,15 @@ def tokenize_function(tokenizer, examples):
         for instance in instances:
             for key, value in instance.items():
                 new_examples[key].append(value)
-    
+
     return new_examples
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     random.seed(0)
-    nltk.download('punkt')
-    tokenizer = AlbertTokenizerFast.from_pretrained('albert-large-v2')
-    wikitext = load_dataset('wikitext', 'wikitext-103-v1', cache_dir='./data/cache')
+    nltk.download("punkt")
+    tokenizer = AlbertTokenizerFast.from_pretrained("albert-large-v2")
+    wikitext = load_dataset("wikitext", "wikitext-103-v1", cache_dir="./data/cache")
 
     tokenized_datasets = wikitext.map(
         partial(tokenize_function, tokenizer),
@@ -102,5 +102,5 @@ if __name__ == '__main__':
         remove_columns=["text"],
     )
 
-    tokenized_datasets.save_to_disk('./data/albert_tokenized_wikitext')
-    tokenizer.save_pretrained('./data/tokenizer')
+    tokenized_datasets.save_to_disk("./data/albert_tokenized_wikitext")
+    tokenizer.save_pretrained("./data/tokenizer")

+ 11 - 10
examples/albert/utils.py

@@ -26,22 +26,23 @@ class MetricSchema(BaseModel):
 
 def make_validators(experiment_prefix: str) -> Tuple[List[RecordValidatorBase], bytes]:
     signature_validator = RSASignatureValidator()
-    validators = [SchemaValidator(MetricSchema, prefix=experiment_prefix),
-                  signature_validator]
+    validators = [SchemaValidator(MetricSchema, prefix=experiment_prefix), signature_validator]
     return validators, signature_validator.local_public_key
 
 
 class TextStyle:
-    BOLD = '\033[1m'
-    BLUE = '\033[34m'
-    RESET = '\033[0m'
+    BOLD = "\033[1m"
+    BLUE = "\033[34m"
+    RESET = "\033[0m"
 
 
 def log_visible_maddrs(visible_maddrs: List[Multiaddr], only_p2p: bool) -> None:
     if only_p2p:
-        unique_addrs = {addr['p2p'] for addr in visible_maddrs}
-        initial_peers_str = ' '.join(f'/p2p/{addr}' for addr in unique_addrs)
+        unique_addrs = {addr["p2p"] for addr in visible_maddrs}
+        initial_peers_str = " ".join(f"/p2p/{addr}" for addr in unique_addrs)
     else:
-        initial_peers_str = ' '.join(str(addr) for addr in visible_maddrs)
-    logger.info(f"Running a DHT peer. To connect other peers to this one, use "
-                f"{TextStyle.BOLD}{TextStyle.BLUE}--initial_peers {initial_peers_str}{TextStyle.RESET}")
+        initial_peers_str = " ".join(str(addr) for addr in visible_maddrs)
+    logger.info(
+        f"Running a DHT peer. To connect other peers to this one, use "
+        f"{TextStyle.BOLD}{TextStyle.BLUE}--initial_peers {initial_peers_str}{TextStyle.RESET}"
+    )

+ 16 - 4
hivemind/__init__.py

@@ -1,9 +1,21 @@
 from hivemind.averaging import DecentralizedAverager, TrainingAverager
 from hivemind.dht import DHT
-from hivemind.moe import ExpertBackend, Server, register_expert_class, RemoteExpert, RemoteMixtureOfExperts, \
-    RemoteSwitchMixtureOfExperts
-from hivemind.optim import CollaborativeAdaptiveOptimizer, DecentralizedOptimizerBase, CollaborativeOptimizer, \
-    DecentralizedOptimizer, DecentralizedSGD, DecentralizedAdam
+from hivemind.moe import (
+    ExpertBackend,
+    Server,
+    register_expert_class,
+    RemoteExpert,
+    RemoteMixtureOfExperts,
+    RemoteSwitchMixtureOfExperts,
+)
+from hivemind.optim import (
+    CollaborativeAdaptiveOptimizer,
+    DecentralizedOptimizerBase,
+    CollaborativeOptimizer,
+    DecentralizedOptimizer,
+    DecentralizedSGD,
+    DecentralizedAdam,
+)
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, PeerInfo
 from hivemind.utils import *
 

+ 45 - 21
hivemind/averaging/allreduce.py

@@ -41,10 +41,18 @@ class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
     """
 
     def __init__(
-            self, *, group_id: GroupID, tensors: Sequence[torch.Tensor], endpoint: Endpoint,
-            ordered_group_endpoints: Sequence[Endpoint], peer_fractions: Tuple[float, ...],
-            weights: Optional[Sequence[float]] = None, modes: Optional[Sequence[AveragingMode]] = None,
-            gathered: Optional[Dict[Endpoint, Any]] = None, **kwargs):
+        self,
+        *,
+        group_id: GroupID,
+        tensors: Sequence[torch.Tensor],
+        endpoint: Endpoint,
+        ordered_group_endpoints: Sequence[Endpoint],
+        peer_fractions: Tuple[float, ...],
+        weights: Optional[Sequence[float]] = None,
+        modes: Optional[Sequence[AveragingMode]] = None,
+        gathered: Optional[Dict[Endpoint, Any]] = None,
+        **kwargs,
+    ):
         assert endpoint in ordered_group_endpoints, "endpoint is not a part of the group"
         modes = modes or tuple(AveragingMode.CLIENT if frac == 0 else AveragingMode.NODE for frac in peer_fractions)
         weights = weights or tuple(int(mode != AveragingMode.AUX) for mode in modes)
@@ -68,8 +76,11 @@ class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
         endpoint_index = self.ordered_group_endpoints.index(self.endpoint)
         self.tensor_part_container = TensorPartContainer(tensors, peer_fractions, **kwargs)
         self.parts_for_local_averaging = self.tensor_part_container.get_raw_input_parts(endpoint_index)
-        self.tensor_part_reducer = TensorPartReducer(tuple(part.shape for part in self.parts_for_local_averaging),
-                                                     len(self.sender_endpoints), self.sender_weights)
+        self.tensor_part_reducer = TensorPartReducer(
+            tuple(part.shape for part in self.parts_for_local_averaging),
+            len(self.sender_endpoints),
+            self.sender_weights,
+        )
 
     def __repr__(self):
         return f"{self.__class__.__name__}({self.endpoint}, group_size={self.group_size})"
@@ -88,7 +99,7 @@ class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
         return ChannelCache.get_stub(peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
 
     async def run(self) -> AsyncIterator[torch.Tensor]:
-        """ Run all-reduce, return differences between averaged and original tensors as they are computed """
+        """Run all-reduce, return differences between averaged and original tensors as they are computed"""
         pending_tasks = set()
         try:
             if len(self.sender_endpoints) == 0:
@@ -115,7 +126,7 @@ class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
             raise
 
     async def _communicate_with_peer(self, peer_endpoint: Endpoint):
-        """ Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors """
+        """Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors"""
         peer_index = self.ordered_group_endpoints.index(peer_endpoint)
         if peer_endpoint == self.endpoint:
             sender_index = self.sender_endpoints.index(peer_endpoint)
@@ -138,9 +149,11 @@ class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
                 await write_task
 
                 if code != averaging_pb2.AVERAGED_PART:
-                    raise AllreduceException(f"peer {peer_endpoint} returned {averaging_pb2.MessageCode.Name(code)} "
-                                             f"instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)}"
-                                             f", allreduce failed")
+                    raise AllreduceException(
+                        f"peer {peer_endpoint} returned {averaging_pb2.MessageCode.Name(code)} "
+                        f"instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)}"
+                        f", allreduce failed"
+                    )
             finally:
                 if not write_task.done():
                     write_task.cancel()
@@ -148,17 +161,23 @@ class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
     async def _write_to_peer(self, stream: grpc.aio.StreamStreamCall, peer_index: int):
         parts_aiter = self.tensor_part_container.iterate_input_parts_for(peer_index)
         first_part = await anext(parts_aiter)
-        await stream.write(averaging_pb2.AveragingData(code=averaging_pb2.PART_FOR_AVERAGING,
-                                                       group_id=self.group_id, endpoint=self.endpoint,
-                                                       tensor_part=first_part))
+        await stream.write(
+            averaging_pb2.AveragingData(
+                code=averaging_pb2.PART_FOR_AVERAGING,
+                group_id=self.group_id,
+                endpoint=self.endpoint,
+                tensor_part=first_part,
+            )
+        )
         async for part in parts_aiter:
             await stream.write(averaging_pb2.AveragingData(tensor_part=part))
 
         await stream.done_writing()
 
-    async def rpc_aggregate_part(self, stream: AsyncIterator[averaging_pb2.AveragingData], context: grpc.ServicerContext
-                                 ) -> AsyncIterator[averaging_pb2.AveragingData]:
-        """ a peer sends us a part of his tensor; we should average it with other peers and return the difference """
+    async def rpc_aggregate_part(
+        self, stream: AsyncIterator[averaging_pb2.AveragingData], context: grpc.ServicerContext
+    ) -> AsyncIterator[averaging_pb2.AveragingData]:
+        """a peer sends us a part of his tensor; we should average it with other peers and return the difference"""
         request: averaging_pb2.AveragingData = await anext(stream)
         reason_to_reject = self._check_reasons_to_reject(request)
         if reason_to_reject:
@@ -191,12 +210,17 @@ class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
     async def _accumulate_parts_streaming(self, stream: AsyncIterator[averaging_pb2.AveragingData], sender_index: int):
         loop = asyncio.get_event_loop()
         async for part_index, (tensor_part, part_compression) in aenumerate(
-                amap_in_executor(lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg.tensor_part.compression), stream,
-                                 max_prefetch=self.tensor_part_container.prefetch)):
+            amap_in_executor(
+                lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg.tensor_part.compression),
+                stream,
+                max_prefetch=self.tensor_part_container.prefetch,
+            )
+        ):
             averaged_part = await self.tensor_part_reducer.accumulate_part(sender_index, part_index, tensor_part)
 
             serialized_delta = await loop.run_in_executor(
-                None, lambda: serialize_torch_tensor(averaged_part - tensor_part, part_compression))
+                None, lambda: serialize_torch_tensor(averaged_part - tensor_part, part_compression)
+            )
             yield averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=serialized_delta)
 
     async def _send_error_to_peer(self, peer_endpoint: Endpoint, code: averaging_pb2.MessageCode):
@@ -205,7 +229,7 @@ class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
         await stream.done_writing()
 
     def finalize(self, *, cancel: bool = False, exception: Optional[BaseException] = None):
-        """ finish or terminate AllReduceRunner, propagate any errors / cancellations to peers. """
+        """finish or terminate AllReduceRunner, propagate any errors / cancellations to peers."""
         assert not cancel or not exception, "finalize accepts either exception or cancel, but not both"
         pending_tasks = set()
         if cancel or exception:

+ 193 - 95
hivemind/averaging/averager.py

@@ -93,28 +93,47 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
     >>> with averager.get_tensors() as tensors_after_averaging:
     >>>     pass # use the averaged tensors
     """
+
     _matchmaking: Matchmaking
     _pending_group_assembled: asyncio.Event
     _server: grpc.aio.Server
     serializer = MSGPackSerializer
 
-    def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: DHT, *, start: bool,
-                 prefix: str, target_group_size: int, min_group_size: int = 2, initial_group_bits: Optional[str] = None,
-                 averaging_expiration: float = 15, request_timeout: float = 3, averaging_alpha: float = 1.0,
-                 part_size_bytes: int = DEFAULT_PART_SIZE_BYTES, allreduce_timeout: Optional[float] = None,
-                 compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE,
-                 throughput: Optional[float] = None, min_vector_size: int = 0,
-                 auxiliary: bool = False, allow_state_sharing: Optional[bool] = None,
-                 listen: bool = True, listen_on: Endpoint = '0.0.0.0:*', daemon: bool = True,
-                 announced_host: Optional[str] = None,
-                 channel_options: Sequence[Tuple[str, Any]] = (),
-                 shutdown_timeout: float = 5, **kwargs):
-        assert '.' not in prefix, "group prefix must be a string without trailing '.'"
-        assert throughput is None or (throughput >= 0 and np.isfinite(np.float32(throughput))), \
-            "throughput must be a non-negative float32"
+    def __init__(
+        self,
+        averaged_tensors: Sequence[torch.Tensor],
+        dht: DHT,
+        *,
+        start: bool,
+        prefix: str,
+        target_group_size: int,
+        min_group_size: int = 2,
+        initial_group_bits: Optional[str] = None,
+        averaging_expiration: float = 15,
+        request_timeout: float = 3,
+        averaging_alpha: float = 1.0,
+        part_size_bytes: int = DEFAULT_PART_SIZE_BYTES,
+        allreduce_timeout: Optional[float] = None,
+        compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE,
+        throughput: Optional[float] = None,
+        min_vector_size: int = 0,
+        auxiliary: bool = False,
+        allow_state_sharing: Optional[bool] = None,
+        listen: bool = True,
+        listen_on: Endpoint = "0.0.0.0:*",
+        daemon: bool = True,
+        announced_host: Optional[str] = None,
+        channel_options: Sequence[Tuple[str, Any]] = (),
+        shutdown_timeout: float = 5,
+        **kwargs,
+    ):
+        assert "." not in prefix, "group prefix must be a string without trailing '.'"
+        assert throughput is None or (
+            throughput >= 0 and np.isfinite(np.float32(throughput))
+        ), "throughput must be a non-negative float32"
         if not is_power_of_two(target_group_size):
             logger.warning("It is recommended to set target_group_size to a power of 2.")
-        assert initial_group_bits is None or all(bit in '01' for bit in initial_group_bits)
+        assert initial_group_bits is None or all(bit in "01" for bit in initial_group_bits)
         assert listen or not auxiliary, "auxiliary peers must accept incoming connections"
 
         super().__init__()
@@ -135,7 +154,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
         self._averaged_tensors = tuple(averaged_tensors)
         self.lock_averaged_tensors = mp.Lock()
-        self.last_updated: DHTExpiration = -float('inf')
+        self.last_updated: DHTExpiration = -float("inf")
         for tensor in self._averaged_tensors:
             assert tensor.grad_fn is None, "averaged_tensors must be either parameters or leaf tensors"
             tensor.share_memory_()
@@ -145,10 +164,16 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         self.throughput = throughput
 
         self.matchmaking_kwargs = dict(
-            prefix=prefix, initial_group_bits=initial_group_bits, target_group_size=target_group_size,
-            min_group_size=min_group_size, averaging_expiration=averaging_expiration, request_timeout=request_timeout)
-        self.allreduce_kwargs = dict(compression_type=compression_type, part_size_bytes=part_size_bytes,
-                                     min_vector_size=min_vector_size)
+            prefix=prefix,
+            initial_group_bits=initial_group_bits,
+            target_group_size=target_group_size,
+            min_group_size=min_group_size,
+            averaging_expiration=averaging_expiration,
+            request_timeout=request_timeout,
+        )
+        self.allreduce_kwargs = dict(
+            compression_type=compression_type, part_size_bytes=part_size_bytes, min_vector_size=min_vector_size
+        )
         self._averaging_alpha, self._allreduce_timeout = averaging_alpha, allreduce_timeout
         self._running_groups: Dict[GroupID, AllReduceRunner] = {}  # one or more assembled groups that run all-reduce
 
@@ -160,26 +185,29 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
         self._averager_endpoint: Optional[Endpoint] = None
         if not self.listen:
-            self._averager_endpoint = f'client::{uuid.uuid4()}'
+            self._averager_endpoint = f"client::{uuid.uuid4()}"
 
         self.ready = mp.Event()  # whether the averager process has started (and ready for incoming requests)
         # note: we create a background thread weakref and with daemon=True to ensure garbage collection
         background_fetcher = threading.Thread(
-            daemon=True, target=_background_thread_fetch_current_state,
-            args=[self.serializer, self._outer_pipe, weakref.WeakMethod(self.get_current_state)])
+            daemon=True,
+            target=_background_thread_fetch_current_state,
+            args=[self.serializer, self._outer_pipe, weakref.WeakMethod(self.get_current_state)],
+        )
         background_fetcher.start()
         if start:
             self.run_in_background(await_ready=True)
 
     def _choose_announced_host(self) -> Hostname:
-        announced_host = strip_port(self.listen_on).strip('[]')  # Stripping square brackets for IPv6
-        if ip_address(announced_host) not in [ip_address('0.0.0.0'), ip_address('::')]:
+        announced_host = strip_port(self.listen_on).strip("[]")  # Stripping square brackets for IPv6
+        if ip_address(announced_host) not in [ip_address("0.0.0.0"), ip_address("::")]:
             return announced_host
 
         maddrs = self.dht.get_visible_maddrs()
         announced_host = choose_ip_address(maddrs)
-        logger.info(f'Choosing IP {announced_host} as endpoint for DecentralizedAverager '
-                    f'from visible multiaddrs {maddrs}')
+        logger.info(
+            f"Choosing IP {announced_host} as endpoint for DecentralizedAverager " f"from visible multiaddrs {maddrs}"
+        )
         return announced_host
 
     @property
@@ -188,13 +216,15 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
     @property
     def allow_state_sharing(self) -> bool:
-        """ if set to True, other peers can download this peer's state """
+        """if set to True, other peers can download this peer's state"""
         return bool(self._allow_state_sharing.value)
 
     @allow_state_sharing.setter
     def allow_state_sharing(self, value: bool):
         if value is True and not self.listen:
-            logger.warning("Cannot allow state sharing: averager in client mode (listen=False) cannot share its state.")
+            logger.warning(
+                "Cannot allow state sharing: averager in client mode (listen=False) cannot share its state."
+            )
         else:
             self._allow_state_sharing.value = value
 
@@ -220,10 +250,11 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         thread.join()
 
     def _run_internal(self):
-        """ Serve DecentralizedAverager forever. This function will not return until the averager is shut down """
+        """Serve DecentralizedAverager forever. This function will not return until the averager is shut down"""
         loop = switch_to_uvloop()
         # initialize asyncio synchronization primitives in this event loop
         with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
+
             async def _run():
                 grpc.aio.init_grpc_aio()
 
@@ -237,8 +268,9 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                 else:
                     logger.debug(f"The averager is running in client mode.")
 
-                self._matchmaking = Matchmaking(self.endpoint, self.schema_hash, self.dht, **self.matchmaking_kwargs,
-                                                client_mode=not self.listen)
+                self._matchmaking = Matchmaking(
+                    self.endpoint, self.schema_hash, self.dht, **self.matchmaking_kwargs, client_mode=not self.listen
+                )
                 if self.listen:
                     asyncio.create_task(self._declare_for_download_periodically())
 
@@ -249,7 +281,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                 while True:
                     method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
                     task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
-                    if method == '_shutdown':
+                    if method == "_shutdown":
                         await task
                         break
 
@@ -265,10 +297,10 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             raise TimeoutError(f"Server didn't notify .ready in {timeout} seconds")
 
     def shutdown(self) -> None:
-        """ Shut down the averager process """
+        """Shut down the averager process"""
         if self.is_alive():
-            self._outer_pipe.send(('_shutdown', [None], {}))  # shut down the daemon process
-            self._inner_pipe.send(('_SHUTDOWN', None))  # shut down background thread in master
+            self._outer_pipe.send(("_shutdown", [None], {}))  # shut down the daemon process
+            self._inner_pipe.send(("_SHUTDOWN", None))  # shut down background thread in master
             self.join(self.shutdown_timeout)
             if self.is_alive():
                 logger.warning("Averager did not shut down within the grace period; terminating it the hard way.")
@@ -288,9 +320,14 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         if self._parent_pid == os.getpid() and self.is_alive():
             self.shutdown()
 
-    def step(self, gather: Optional[GatheredData] = None, weight: Optional[float] = None,
-             timeout: Optional[float] = None, allow_retries: bool = True, wait: bool = True
-             ) -> Union[Optional[Dict[Endpoint, GatheredData]], MPFuture]:
+    def step(
+        self,
+        gather: Optional[GatheredData] = None,
+        weight: Optional[float] = None,
+        timeout: Optional[float] = None,
+        allow_retries: bool = True,
+        wait: bool = True,
+    ) -> Union[Optional[Dict[Endpoint, GatheredData]], MPFuture]:
         """
         Set up the averager to look for a group and run one round of averaging, return True on success, False on failure
 
@@ -310,13 +347,27 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         assert isinstance(weight, (int, float)) and weight >= 0, f"Expected a positive int/float, got {type(weight)}"
 
         future = MPFuture()
-        gather_binary = self.serializer.dumps(gather)  # serialize here to avoid loading modules in the averager process
-        self._outer_pipe.send(('_step', [], dict(future=future, gather_binary=gather_binary, weight=weight,
-                                                 allow_retries=allow_retries, timeout=timeout)))
+        gather_binary = self.serializer.dumps(
+            gather
+        )  # serialize here to avoid loading modules in the averager process
+        self._outer_pipe.send(
+            (
+                "_step",
+                [],
+                dict(
+                    future=future,
+                    gather_binary=gather_binary,
+                    weight=weight,
+                    allow_retries=allow_retries,
+                    timeout=timeout,
+                ),
+            )
+        )
         return future.result() if wait else future
 
-    async def _step(self, *, future: MPFuture, gather_binary: bytes, weight: float,
-                    allow_retries: bool, timeout: Optional[float]):
+    async def _step(
+        self, *, future: MPFuture, gather_binary: bytes, weight: float, allow_retries: bool, timeout: Optional[float]
+    ):
         start_time = get_dht_time()
 
         try:
@@ -324,17 +375,30 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                 try:
                     self._pending_group_assembled.clear()
                     data_for_gather = self.serializer.dumps([weight, self.throughput, self.mode.value, gather_binary])
-                    group_info = await self._matchmaking.look_for_group(timeout=timeout,
-                                                                        data_for_gather=data_for_gather)
+                    group_info = await self._matchmaking.look_for_group(
+                        timeout=timeout, data_for_gather=data_for_gather
+                    )
                     if group_info is None:
                         raise AllreduceException("Averaging step failed: could not find a group.")
 
-                    future.set_result(await asyncio.wait_for(
-                        self._run_allreduce(group_info, **self.allreduce_kwargs), self._allreduce_timeout))
+                    future.set_result(
+                        await asyncio.wait_for(
+                            self._run_allreduce(group_info, **self.allreduce_kwargs), self._allreduce_timeout
+                        )
+                    )
                     # averaging is finished, loop will now exit
 
-                except (AllreduceException, MatchmakingException, AssertionError, StopAsyncIteration, InternalError,
-                        asyncio.CancelledError, asyncio.InvalidStateError, grpc.RpcError, grpc.aio.AioRpcError) as e:
+                except (
+                    AllreduceException,
+                    MatchmakingException,
+                    AssertionError,
+                    StopAsyncIteration,
+                    InternalError,
+                    asyncio.CancelledError,
+                    asyncio.InvalidStateError,
+                    grpc.RpcError,
+                    grpc.aio.AioRpcError,
+                ) as e:
                     time_elapsed = get_dht_time() - start_time
                     if not allow_retries or (timeout is not None and timeout < time_elapsed):
                         logger.exception(f"Averager caught {repr(e)}")
@@ -348,27 +412,40 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             raise
         finally:
             if not future.done():
-                future.set_exception(RuntimeError("Internal sanity check failed: averager.step left future pending."
-                                                  " Please report this to hivemind issues."))
+                future.set_exception(
+                    RuntimeError(
+                        "Internal sanity check failed: averager.step left future pending."
+                        " Please report this to hivemind issues."
+                    )
+                )
 
     async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
-        """ Run All-Reduce in a given group and update tensors in place, return gathered metadata """
+        """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
         try:
             weights, throughputs, mode_ids, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
             user_gathered = dict(zip(group_info.endpoints, map(self.serializer.loads, user_gathered)))
             modes = tuple(map(AveragingMode, mode_ids))
 
             # compute optimal part sizes from peer throughputs; TODO: replace with proper load balancing
-            incoming_throughputs = [thr if mode != AveragingMode.CLIENT else 0.0
-                                    for thr, mode in zip(throughputs, modes)]
+            incoming_throughputs = [
+                thr if mode != AveragingMode.CLIENT else 0.0 for thr, mode in zip(throughputs, modes)
+            ]
             peer_fractions = await asyncio.get_event_loop().run_in_executor(
-                None, load_balance_peers, self.total_size, incoming_throughputs, min_vector_size)
+                None, load_balance_peers, self.total_size, incoming_throughputs, min_vector_size
+            )
 
             async with self.get_tensors_async() as local_tensors:
                 allreduce = AllReduceRunner(
-                    group_id=group_info.group_id, tensors=local_tensors, endpoint=self.endpoint,
-                    ordered_group_endpoints=group_info.endpoints, peer_fractions=peer_fractions, weights=weights,
-                    gathered=user_gathered, modes=modes, **kwargs)
+                    group_id=group_info.group_id,
+                    tensors=local_tensors,
+                    endpoint=self.endpoint,
+                    ordered_group_endpoints=group_info.endpoints,
+                    peer_fractions=peer_fractions,
+                    weights=weights,
+                    gathered=user_gathered,
+                    modes=modes,
+                    **kwargs,
+                )
 
                 with self.register_allreduce_group(group_info.group_id, allreduce):
 
@@ -388,7 +465,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
     @contextlib.contextmanager
     def register_allreduce_group(self, group_id: GroupID, allreduce: AllReduceRunner):
-        """ registers a given all-reduce runner to listen for incoming connections """
+        """registers a given all-reduce runner to listen for incoming connections"""
         try:
             self._running_groups[group_id] = allreduce
             self._pending_group_assembled.set()
@@ -410,22 +487,24 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
     @contextlib.asynccontextmanager
     async def get_tensors_async(self) -> Sequence[torch.Tensor]:
-        """ Like get_tensors, but uses an asynchronous contextmanager """
+        """Like get_tensors, but uses an asynchronous contextmanager"""
         try:
             await asyncio.get_event_loop().run_in_executor(None, self.lock_averaged_tensors.acquire)
             yield self._averaged_tensors
         finally:
             self.lock_averaged_tensors.release()
 
-    async def rpc_join_group(self, request: averaging_pb2.JoinRequest, context: grpc.ServicerContext
-                             ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
-        """ accept or reject a join request from another averager; if accepted, run him through allreduce steps """
+    async def rpc_join_group(
+        self, request: averaging_pb2.JoinRequest, context: grpc.ServicerContext
+    ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
+        """accept or reject a join request from another averager; if accepted, run him through allreduce steps"""
         async for response in self._matchmaking.rpc_join_group(request, context):
             yield response
 
-    async def rpc_aggregate_part(self, stream: AsyncIterator[averaging_pb2.AveragingData], context: grpc.ServicerContext
-                                 ) -> AsyncIterator[averaging_pb2.AveragingData]:
-        """ a groupmate sends us a part of his tensor; we should average it with other peers and return the result """
+    async def rpc_aggregate_part(
+        self, stream: AsyncIterator[averaging_pb2.AveragingData], context: grpc.ServicerContext
+    ) -> AsyncIterator[averaging_pb2.AveragingData]:
+        """a groupmate sends us a part of his tensor; we should average it with other peers and return the result"""
         request = await anext(stream)
         if request.group_id not in self._running_groups:
             # this handles a special case when leader accepted us to group AND began allreduce right away,
@@ -441,17 +520,26 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             yield message
 
     async def _declare_for_download_periodically(self):
-        download_key = f'{self._matchmaking.group_key_manager.prefix}.all_averagers'
+        download_key = f"{self._matchmaking.group_key_manager.prefix}.all_averagers"
         while True:
             if self.allow_state_sharing:
-                asyncio.create_task(asyncio.wait_for(self.dht.store(
-                    download_key, subkey=self.endpoint, value=self.last_updated,
-                    expiration_time=get_dht_time() + self._matchmaking.averaging_expiration, return_future=True),
-                    timeout=self._matchmaking.averaging_expiration))
+                asyncio.create_task(
+                    asyncio.wait_for(
+                        self.dht.store(
+                            download_key,
+                            subkey=self.endpoint,
+                            value=self.last_updated,
+                            expiration_time=get_dht_time() + self._matchmaking.averaging_expiration,
+                            return_future=True,
+                        ),
+                        timeout=self._matchmaking.averaging_expiration,
+                    )
+                )
             await asyncio.sleep(self._matchmaking.averaging_expiration)
 
-    async def rpc_download_state(self, request: averaging_pb2.DownloadRequest, context: grpc.ServicerContext
-                                 ) -> AsyncIterator[averaging_pb2.DownloadData]:
+    async def rpc_download_state(
+        self, request: averaging_pb2.DownloadRequest, context: grpc.ServicerContext
+    ) -> AsyncIterator[averaging_pb2.DownloadData]:
         """
         Get the up-to-date trainer state from a peer.
         The state consists of two parts: (serialized_metadata, tensors)
@@ -481,9 +569,9 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             return dict(group_key=self.get_group_bits()), tensors
 
     async def _get_current_state_from_host_process(self):
-        """ Executed in the averager process inside rpc_download_state """
+        """Executed in the averager process inside rpc_download_state"""
         future = MPFuture()
-        self._inner_pipe.send(('_TRIGGER_GET_CURRENT_STATE', future))
+        self._inner_pipe.send(("_TRIGGER_GET_CURRENT_STATE", future))
         return await future
 
     def load_state_from_peers(self, wait=True) -> Optional[Tuple[Any, Sequence[torch.Tensor]]]:
@@ -497,15 +585,18 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         The exact contents of both metadata and tensors are determined by get_current_state method
         """
         future = MPFuture()
-        self._outer_pipe.send(('_load_state_from_peers', [], dict(future=future)))
+        self._outer_pipe.send(("_load_state_from_peers", [], dict(future=future)))
         return future.result() if wait else future
 
     async def _load_state_from_peers(self, future: MPFuture):
         try:
             key_manager = self._matchmaking.group_key_manager
             peer_priority, _ = self.dht.get(f"{key_manager.prefix}.all_averagers", latest=True) or ({}, None)
-            peer_priority = {peer: float(info.value) for peer, info in peer_priority.items()
-                             if isinstance(info, ValueWithExpiration) and isinstance(info.value, (float, int))}
+            peer_priority = {
+                peer: float(info.value)
+                for peer, info in peer_priority.items()
+                if isinstance(info, ValueWithExpiration) and isinstance(info.value, (float, int))
+            }
 
             if not isinstance(peer_priority, dict) or len(peer_priority) == 0:
                 logger.info(f"Averager could not load state from peers: peer dict empty or corrupted {peer_priority}.")
@@ -518,8 +609,9 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                     logger.info(f"Downloading parameters from peer {peer}")
                     stream = None
                     try:
-                        stub = ChannelCache.get_stub(peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True,
-                                                     options=self.channel_options)
+                        stub = ChannelCache.get_stub(
+                            peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True, options=self.channel_options
+                        )
                         stream = stub.rpc_download_state(averaging_pb2.DownloadRequest())
                         current_tensor_parts, tensors = [], []
                         async for message in stream:
@@ -558,7 +650,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         :returns: averager's current group key bits (without prefix)
         """
         future = MPFuture()
-        self._outer_pipe.send(('_get_group_bits', [], dict(future=future)))
+        self._outer_pipe.send(("_get_group_bits", [], dict(future=future)))
         return future.result() if wait else future
 
     async def _get_group_bits(self, future: MPFuture):
@@ -570,8 +662,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         :param wait: if True, wait until the update is confirmed by the averager. Otherwise return immediately
         """
         future = MPFuture()
-        assert all(bit in '01' for bit in group_bits)
-        self._outer_pipe.send(('_set_group_bits', [], dict(group_bits=group_bits, future=future)))
+        assert all(bit in "01" for bit in group_bits)
+        self._outer_pipe.send(("_set_group_bits", [], dict(group_bits=group_bits, future=future)))
         return future.result() if wait else future
 
     async def _set_group_bits(self, group_bits: str, future: MPFuture):
@@ -584,12 +676,13 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
 
 def is_power_of_two(n):
-    """ Check whether n is a power of 2 """
+    """Check whether n is a power of 2"""
     return (n != 0) and (n & (n - 1) == 0)
 
 
-def _background_thread_fetch_current_state(serializer: SerializerBase, pipe: mp.connection.Connection,
-                                           get_current_state_ref: weakref.WeakMethod):
+def _background_thread_fetch_current_state(
+    serializer: SerializerBase, pipe: mp.connection.Connection, get_current_state_ref: weakref.WeakMethod
+):
     """
     Executed in the host process as a background thread. Fetches the averager state when asked by peers.
     :param serializer: a serializer with which to convert metadata into bytes
@@ -603,10 +696,10 @@ def _background_thread_fetch_current_state(serializer: SerializerBase, pipe: mp.
             logger.debug(f"Averager background thread finished: {repr(e)}")
             break
 
-        if trigger == '_SHUTDOWN':
+        if trigger == "_SHUTDOWN":
             break
 
-        assert trigger == '_TRIGGER_GET_CURRENT_STATE'
+        assert trigger == "_TRIGGER_GET_CURRENT_STATE"
         try:
             get_current_state = get_current_state_ref()
             if get_current_state is None:
@@ -615,8 +708,9 @@ def _background_thread_fetch_current_state(serializer: SerializerBase, pipe: mp.
             del get_current_state
 
             state_metadata = serializer.dumps(state_metadata)
-            state_tensors = tuple(tensor.cpu().detach().requires_grad_(tensor.requires_grad)
-                                  for tensor in state_tensors)
+            state_tensors = tuple(
+                tensor.cpu().detach().requires_grad_(tensor.requires_grad) for tensor in state_tensors
+            )
             # note: we cast tensors to CPU on host side to avoid initializing cuda in the guest process
             future.set_result((state_metadata, state_tensors))
         except BaseException as e:
@@ -626,8 +720,12 @@ def _background_thread_fetch_current_state(serializer: SerializerBase, pipe: mp.
 
 
 def compute_schema_hash(tensors: Sequence[torch.Tensor]) -> bytes:
-    """ A hash that describes follower's tensor shapes, dtypes, devices, but not the actual values """
-    schema_dicts = [{field_name: str(field_value)
-                     for field_name, field_value in asdict(TensorDescriptor.from_tensor(tensor)).items()}
-                    for tensor in tensors]
+    """A hash that describes follower's tensor shapes, dtypes, devices, but not the actual values"""
+    schema_dicts = [
+        {
+            field_name: str(field_value)
+            for field_name, field_value in asdict(TensorDescriptor.from_tensor(tensor)).items()
+        }
+        for tensor in tensors
+    ]
     return DHTID.generate(source=schema_dicts).to_bytes()
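
For intuition, compute_schema_hash fingerprints only tensor metadata: two peers whose tensors share shapes, dtypes and devices produce the same hash regardless of the values. A minimal standalone sketch of the same idea (toy_schema_hash and the hashlib digest are illustrative stand-ins, not hivemind's DHTID-based implementation):

import hashlib

import torch

def toy_schema_hash(tensors):
    # hash shapes/dtypes/devices only; tensor values never enter the digest
    descriptors = [(tuple(t.shape), str(t.dtype), str(t.device)) for t in tensors]
    return hashlib.sha256(repr(descriptors).encode()).hexdigest()

same_schema_a = [torch.zeros(3, 4), torch.ones(5)]
same_schema_b = [torch.randn(3, 4), torch.full((5,), 7.0)]
assert toy_schema_hash(same_schema_a) == toy_schema_hash(same_schema_b)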

+ 2 - 1
hivemind/averaging/group_info.py

@@ -6,7 +6,8 @@ from hivemind.utils import Endpoint
 
 @dataclass(frozen=True)
 class GroupInfo:
-    """ A group of peers assembled through decentralized matchmaking """
+    """A group of peers assembled through decentralized matchmaking"""
+
     group_id: bytes  # random unique bytestring that describes the current group, generated by group leader
     endpoints: Tuple[Endpoint, ...]  # an ordered sequence of endpoints of each groupmate
     gathered: Tuple[bytes, ...]  # binary metadata gathered from all peers by leader, same order as endpoints

+ 66 - 33
hivemind/averaging/key_manager.py

@@ -10,12 +10,12 @@ from hivemind.averaging.group_info import GroupInfo
 from hivemind.utils import get_logger, Endpoint, DHTExpiration, get_dht_time, ValueWithExpiration
 
 GroupKey = str
-GROUP_PATTERN = re.compile('^(([^.])+)[.]0b[01]*$')  # e.g. bert_exp4_averaging.0b01001101
+GROUP_PATTERN = re.compile("^(([^.])+)[.]0b[01]*$")  # e.g. bert_exp4_averaging.0b01001101
 logger = get_logger(__name__)
 
 
 def is_valid_group(maybe_group: str) -> bool:
-    """ A group identifier must contain group type, followed by one or more .-separated indices, and any ?metadata"""
+    """A group identifier must contain group type, followed by one or more .-separated indices, and any ?metadata"""
     return bool(GROUP_PATTERN.fullmatch(maybe_group))
 
 
@@ -23,16 +23,26 @@ class GroupKeyManager:
     """
     Utility class that declares and fetches averaging-related keys using a DHT
     """
-    RESERVED_KEY_FOR_NBITS = '::NBITS'
 
-    def __init__(self, dht: DHT, endpoint: Endpoint, prefix: str, initial_group_bits: Optional[str],
-                 target_group_size: int, insufficient_size: Optional[int] = None, excessive_size: Optional[int] = None,
-                 nbits_expiration: float = 60, nbits_rewrite_grace_period: float = 15):
-        assert initial_group_bits is None or all(bit in '01' for bit in initial_group_bits)
+    RESERVED_KEY_FOR_NBITS = "::NBITS"
+
+    def __init__(
+        self,
+        dht: DHT,
+        endpoint: Endpoint,
+        prefix: str,
+        initial_group_bits: Optional[str],
+        target_group_size: int,
+        insufficient_size: Optional[int] = None,
+        excessive_size: Optional[int] = None,
+        nbits_expiration: float = 60,
+        nbits_rewrite_grace_period: float = 15,
+    ):
+        assert initial_group_bits is None or all(bit in "01" for bit in initial_group_bits)
         if initial_group_bits is None:
             search_result = dht.get(f"{prefix}.0b", latest=True)
             initial_group_nbits = self.get_suggested_nbits(search_result) or 0
-            initial_group_bits = ''.join(random.choice('01') for _ in range(initial_group_nbits))
+            initial_group_bits = "".join(random.choice("01") for _ in range(initial_group_nbits))
         self.dht, self.endpoint, self.prefix, self.group_bits = dht, endpoint, prefix, initial_group_bits
         self.target_group_size = target_group_size
         self.insufficient_size = insufficient_size or max(1, target_group_size // 2)
@@ -44,8 +54,9 @@ class GroupKeyManager:
     def current_key(self) -> GroupKey:
         return f"{self.prefix}.0b{self.group_bits}"
 
-    async def declare_averager(self, group_key: GroupKey, endpoint: Endpoint, expiration_time: float,
-                               looking_for_group: bool = True) -> bool:
+    async def declare_averager(
+        self, group_key: GroupKey, endpoint: Endpoint, expiration_time: float, looking_for_group: bool = True
+    ) -> bool:
         """
         Add (or remove) the averager to a given allreduce bucket
 
@@ -58,9 +69,14 @@ class GroupKeyManager:
         :note: when leaving (i.e. is_active=False), please specify the same expiration_time as when entering the group
         :note: setting is_active=False does *not* guarantee that others will immediately stop to query you.
         """
-        expiration_time = expiration_time if looking_for_group else float(np.nextafter(expiration_time, float('inf')))
-        return await self.dht.store(key=group_key, subkey=endpoint, value=looking_for_group,
-                                    expiration_time=expiration_time, return_future=True)
+        expiration_time = expiration_time if looking_for_group else float(np.nextafter(expiration_time, float("inf")))
+        return await self.dht.store(
+            key=group_key,
+            subkey=endpoint,
+            value=looking_for_group,
+            expiration_time=expiration_time,
+            return_future=True,
+        )
 
     async def get_averagers(self, group_key: GroupKey, only_active: bool) -> List[Tuple[Endpoint, DHTExpiration]]:
         """
@@ -76,13 +92,19 @@ class GroupKeyManager:
         if result is None or not isinstance(result.value, dict):
             logger.debug(f"Allreduce group not found: {group_key}, creating new group.")
             return []
-        averagers = [(key, entry.expiration_time) for key, entry in result.value.items()
-                     if key != self.RESERVED_KEY_FOR_NBITS and (not only_active or entry.value is True)]
+        averagers = [
+            (key, entry.expiration_time)
+            for key, entry in result.value.items()
+            if key != self.RESERVED_KEY_FOR_NBITS and (not only_active or entry.value is True)
+        ]
         num_active_averagers = len([key for key, entry in result.value.items() if entry.value is True])
 
         suggested_nbits = self.get_suggested_nbits(result)
-        if suggested_nbits is not None and suggested_nbits != len(self.group_bits) and \
-                suggested_nbits != self.suggested_nbits:
+        if (
+            suggested_nbits is not None
+            and suggested_nbits != len(self.group_bits)
+            and suggested_nbits != self.suggested_nbits
+        ):
             self.suggested_nbits = suggested_nbits
             logger.warning(f"{self.endpoint} - another averager suggested {self.suggested_nbits}-bit keys")
         elif num_active_averagers >= self.excessive_size:
@@ -91,46 +113,54 @@ class GroupKeyManager:
         return averagers
 
     async def declare_nbits(self, group_key: GroupKey, nbits: int, expiration_time: DHTExpiration) -> bool:
-        """ notify other peers that they can run averaging at this depth """
-        return await self.dht.store(key=group_key, subkey=self.RESERVED_KEY_FOR_NBITS, value=nbits,
-                                    expiration_time=expiration_time, return_future=True)
+        """notify other peers that they can run averaging at this depth"""
+        return await self.dht.store(
+            key=group_key,
+            subkey=self.RESERVED_KEY_FOR_NBITS,
+            value=nbits,
+            expiration_time=expiration_time,
+            return_future=True,
+        )
 
     @classmethod
     def get_suggested_nbits(cls, search_result: Optional[ValueWithExpiration]) -> Optional[int]:
-        if isinstance(search_result, ValueWithExpiration) and cls.RESERVED_KEY_FOR_NBITS in search_result.value \
-                and isinstance(search_result.value[cls.RESERVED_KEY_FOR_NBITS].value, int):
+        if (
+            isinstance(search_result, ValueWithExpiration)
+            and cls.RESERVED_KEY_FOR_NBITS in search_result.value
+            and isinstance(search_result.value[cls.RESERVED_KEY_FOR_NBITS].value, int)
+        ):
             return search_result.value[cls.RESERVED_KEY_FOR_NBITS].value
         else:
             return None
 
     async def update_key_on_group_assembled(self, group_info: GroupInfo, is_leader: bool = True):
-        """ this function is triggered every time an averager finds an allreduce group """
+        """this function is triggered every time an averager finds an allreduce group"""
         rng = random.Random(group_info.group_id)
         index = group_info.endpoints.index(self.endpoint)
         generalized_index = rng.sample(range(self.target_group_size), group_info.group_size)[index]
         nbits = int(np.ceil(np.log2(self.target_group_size)))
-        new_bits = bin(generalized_index)[2:].rjust(nbits, '0')
-        self.group_bits = (self.group_bits + new_bits)[-len(self.group_bits):] if self.group_bits else ''
+        new_bits = bin(generalized_index)[2:].rjust(nbits, "0")
+        self.group_bits = (self.group_bits + new_bits)[-len(self.group_bits) :] if self.group_bits else ""
         logger.debug(f"{self.endpoint} - updated group key to {self.group_bits}")
 
         if is_leader and self.insufficient_size < group_info.group_size < self.excessive_size:
             asyncio.create_task(self.notify_stragglers())
         if self.suggested_nbits is not None and self.suggested_nbits != len(self.group_bits):
             num_extra_bits = max(0, self.suggested_nbits - len(self.group_bits))
-            self.group_bits = ''.join((random.choice('01') for _ in range(num_extra_bits))) + self.group_bits
-            self.group_bits = self.group_bits[-self.suggested_nbits:]
+            self.group_bits = "".join((random.choice("01") for _ in range(num_extra_bits))) + self.group_bits
+            self.group_bits = self.group_bits[-self.suggested_nbits :]
         self.suggested_nbits = None
 
     async def update_key_on_not_enough_peers(self):
-        """ this function is triggered whenever averager fails to assemble group within timeout """
+        """this function is triggered whenever averager fails to assemble group within timeout"""
         new_nbits = self.suggested_nbits if self.suggested_nbits is not None else len(self.group_bits) - 1
-        prev_nbits, self.group_bits = self.group_bits, self.group_bits[-new_nbits:] if new_nbits else ''
+        prev_nbits, self.group_bits = self.group_bits, self.group_bits[-new_nbits:] if new_nbits else ""
         if self.group_bits != prev_nbits:
-            logger.warning(f'{self.endpoint} - switching to {len(self.group_bits)}-bit keys')
+            logger.warning(f"{self.endpoint} - switching to {len(self.group_bits)}-bit keys")
         self.suggested_nbits = None
 
     async def notify_stragglers(self):
-        """ Find averagers that have fewer nbits and redirect them to your current nbits """
+        """Find averagers that have fewer nbits and redirect them to your current nbits"""
         for nbits in reversed(range(1, len(self.group_bits) - 1)):
             preceding_key = f"{self.prefix}.0b{self.group_bits[-nbits:] if nbits else ''}"
             preceding_data, _ = await self.dht.get(preceding_key, latest=False, return_future=True) or ({}, None)
@@ -140,7 +170,10 @@ class GroupKeyManager:
                 break
 
         root_data, _ = await self.dht.get(f"{self.prefix}.0b", latest=False, return_future=True) or ({}, None)
-        if isinstance(root_data, dict) and root_data.get(
-                self.RESERVED_KEY_FOR_NBITS, (None, -float('inf')))[1] > get_dht_time() + self.nbits_grace_period:
+        if (
+            isinstance(root_data, dict)
+            and root_data.get(self.RESERVED_KEY_FOR_NBITS, (None, -float("inf")))[1]
+            > get_dht_time() + self.nbits_grace_period
+        ):
             return
         await self.declare_nbits(f"{self.prefix}.0b", len(self.group_bits), get_dht_time() + self.nbits_expiration)
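
As a quick illustration of the key format GROUP_PATTERN accepts (a dot-free prefix, then ".0b" and zero or more bits), a minimal standalone check with made-up key names:

import re

GROUP_PATTERN = re.compile("^(([^.])+)[.]0b[01]*$")

for key in ("bert_exp4_averaging.0b01001101", "bert_exp4_averaging.0b", "bad.prefix.0b01"):
    print(key, bool(GROUP_PATTERN.fullmatch(key)))
# the first two keys match; the third is rejected because its prefix contains a dot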

+ 2 - 2
hivemind/averaging/load_balancing.py

@@ -28,7 +28,7 @@ def load_balance_peers(vector_size, throughputs: Sequence[Optional[float]], min_
         assert not all(throughput == 0 for throughput in throughputs), "Must have at least one nonzero throughput"
         scores = np.asarray([1.0 if throughput is None else 0.0 for throughput in throughputs])
 
-    #TODO(jheuristic) we no longer need hagenbach-bishoff with new AllReduceRunner
+    # TODO(jheuristic) we no longer need hagenbach-bishoff with new AllReduceRunner
     return tuple(hagenbach_bishoff(vector_size, scores))
 
 
@@ -71,7 +71,7 @@ def optimize_parts_lp(vector_size: int, throughputs: np.ndarray, min_size: int =
 
     A, b = list(map(np.concatenate, zip(nonnegative_weights, weights_sum_to_one, xi_is_maximum, force_max_weights)))
 
-    solution = scipy.optimize.linprog(c, A_ub=A, b_ub=b, method='interior-point')
+    solution = scipy.optimize.linprog(c, A_ub=A, b_ub=b, method="interior-point")
     if solution.success:
         peer_scores = solution.x[:group_size]
         # if some peers have less than min_size elements, transfer their share to other peers (if any)

+ 122 - 63
hivemind/averaging/matchmaking.py

@@ -35,14 +35,26 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
       Hence, instead of accounting for such deadlocks, we simply break them with request_timeout.
     """
 
-    def __init__(self, endpoint: Endpoint, schema_hash: bytes, dht: DHT, *,
-                 prefix: str, target_group_size: int, min_group_size: int,
-                 request_timeout: float, client_mode: bool, initial_group_bits: Optional[str] = None,
-                 averaging_expiration: float = 15):
-        assert '.' not in prefix, "group prefix must be a string without ."
+    def __init__(
+        self,
+        endpoint: Endpoint,
+        schema_hash: bytes,
+        dht: DHT,
+        *,
+        prefix: str,
+        target_group_size: int,
+        min_group_size: int,
+        request_timeout: float,
+        client_mode: bool,
+        initial_group_bits: Optional[str] = None,
+        averaging_expiration: float = 15,
+    ):
+        assert "." not in prefix, "group prefix must be a string without ."
         if request_timeout is None or request_timeout >= averaging_expiration:
-            logger.warning("It is recommended to use request_timeout smaller than averaging_expiration. Otherwise,"
-                           "matchmaking can cause deadlocks in some rare cases. Please see Matchmaking docstring.")
+            logger.warning(
+                "It is recommended to use request_timeout smaller than averaging_expiration. Otherwise,"
+                "matchmaking can cause deadlocks in some rare cases. Please see Matchmaking docstring."
+            )
 
         super().__init__()
         self.endpoint, self.schema_hash = endpoint, schema_hash
@@ -74,8 +86,10 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
             if len(self.current_followers):
                 lfg_status += f" leading {len(self.current_followers)} followers,"
         schema_hash_repr = f"{self.schema_hash[0]}...{self.schema_hash[-8:]}"
-        return f"{self.__class__.__name__}(endpoint={self.endpoint}, schema={schema_hash_repr}, {lfg_status}" \
-               f" current key = {self.group_key_manager.current_key}, client_mode={self.client_mode})"
+        return (
+            f"{self.__class__.__name__}(endpoint={self.endpoint}, schema={schema_hash_repr}, {lfg_status}"
+            f" current key = {self.group_key_manager.current_key}, client_mode={self.client_mode})"
+        )
 
     async def look_for_group(self, *, data_for_gather: bytes, timeout: Optional[float] = None) -> Optional[GroupInfo]:
         """
@@ -85,8 +99,10 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
         Iterate over the averagers from a given group_identifier that have higher leadership priority than yourself.
         """
         if self.is_looking_for_group:
-            logger.info("Another look_for_group is already in progress. The current run will be scheduled after"
-                        " the existing group is either assembled or disbanded.")
+            logger.info(
+                "Another look_for_group is already in progress. The current run will be scheduled after"
+                " the existing group is either assembled or disbanded."
+            )
         async with self.lock_looking_for_group:
             self.data_for_gather = data_for_gather
             request_leaders_task = asyncio.create_task(self._request_join_potential_leaders(timeout))
@@ -117,7 +133,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
                 self.data_for_gather = None
 
     async def _request_join_potential_leaders(self, timeout: Optional[float]) -> GroupInfo:
-        """ Request leaders from queue until we find the first runner. This coroutine is meant to run in background. """
+        """Request leaders from queue until we find the first runner. This coroutine is meant to run in background."""
         async with self.potential_leaders.begin_search(self.group_key_manager, timeout, declare=not self.client_mode):
             while True:
                 try:
@@ -157,9 +173,15 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
         try:
             async with self.lock_request_join_group:
                 leader_stub = ChannelCache.get_stub(leader, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
-                call = leader_stub.rpc_join_group(averaging_pb2.JoinRequest(
-                    endpoint=self.endpoint, schema_hash=self.schema_hash, expiration=expiration_time,
-                    client_mode=self.client_mode, gather=self.data_for_gather))
+                call = leader_stub.rpc_join_group(
+                    averaging_pb2.JoinRequest(
+                        endpoint=self.endpoint,
+                        schema_hash=self.schema_hash,
+                        expiration=expiration_time,
+                        client_mode=self.client_mode,
+                        gather=self.data_for_gather,
+                    )
+                )
                 message = await asyncio.wait_for(call.read(), timeout=self.request_timeout)
 
                 if message.code == averaging_pb2.ACCEPTED:
@@ -209,9 +231,10 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
             if call is not None:
                 await call.code()
 
-    async def rpc_join_group(self, request: averaging_pb2.JoinRequest, context: grpc.ServicerContext
-                             ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
-        """ accept or reject a join request from another averager; if accepted, run him through allreduce steps """
+    async def rpc_join_group(
+        self, request: averaging_pb2.JoinRequest, context: grpc.ServicerContext
+    ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
+        """accept or reject a join request from another averager; if accepted, run him through allreduce steps"""
         try:
             async with self.lock_request_join_group:
                 reason_to_reject = self._check_reasons_to_reject(request)
@@ -228,8 +251,11 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
 
             # wait for the group to be assembled or disbanded
             timeout = max(0.0, self.potential_leaders.declared_expiration_time - get_dht_time())
-            await asyncio.wait({self.assembled_group, self.was_accepted_to_group.wait()},
-                               return_when=asyncio.FIRST_COMPLETED, timeout=timeout)
+            await asyncio.wait(
+                {self.assembled_group, self.was_accepted_to_group.wait()},
+                return_when=asyncio.FIRST_COMPLETED,
+                timeout=timeout,
+            )
             if not self.assembled_group.done() and not self.was_accepted_to_group.is_set():
                 async with self.lock_request_join_group:
                     if self.assembled_group.done():
@@ -240,21 +266,29 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
                     else:
                         await self.leader_disband_group()
 
-            if self.was_accepted_to_group.is_set() or not self.assembled_group.done() \
-                    or self.assembled_group.cancelled() or request.endpoint not in self.assembled_group.result():
+            if (
+                self.was_accepted_to_group.is_set()
+                or not self.assembled_group.done()
+                or self.assembled_group.cancelled()
+                or request.endpoint not in self.assembled_group.result()
+            ):
                 if self.current_leader is not None:
                     # outcome 3: found by a leader with higher priority, send our followers to him
-                    yield averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_DISBANDED,
-                                                          suggested_leader=self.current_leader)
+                    yield averaging_pb2.MessageFromLeader(
+                        code=averaging_pb2.GROUP_DISBANDED, suggested_leader=self.current_leader
+                    )
                     return
                 else:
                     yield averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_DISBANDED)
                     return
 
             group_info = self.assembled_group.result()
-            yield averaging_pb2.MessageFromLeader(code=averaging_pb2.BEGIN_ALLREDUCE, group_id=group_info.group_id,
-                                                  ordered_group_endpoints=group_info.endpoints,
-                                                  gathered=group_info.gathered)
+            yield averaging_pb2.MessageFromLeader(
+                code=averaging_pb2.BEGIN_ALLREDUCE,
+                group_id=group_info.group_id,
+                ordered_group_endpoints=group_info.endpoints,
+                gathered=group_info.gathered,
+            )
         except (concurrent.futures.CancelledError, asyncio.CancelledError):
             return  # note: this is a compatibility layer for python3.7
         except Exception as e:
@@ -265,25 +299,35 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
             self.current_followers.pop(request.endpoint, None)
             self.follower_was_discarded.set()
 
-    def _check_reasons_to_reject(self, request: averaging_pb2.JoinRequest) -> Optional[averaging_pb2.MessageFromLeader]:
-        """ :returns: if accepted, return None, otherwise return a reason for rejection """
+    def _check_reasons_to_reject(
+        self, request: averaging_pb2.JoinRequest
+    ) -> Optional[averaging_pb2.MessageFromLeader]:
+        """:returns: if accepted, return None, otherwise return a reason for rejection"""
         if not self.is_looking_for_group or self.assembled_group.done():
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.NOT_LOOKING_FOR_GROUP)
 
-        if request.ListFields() == 3 and not isinstance(request.schema_hash, bytes) or len(request.schema_hash) == 0 \
-                or not isinstance(request.expiration, DHTExpiration) or not isfinite(request.expiration) \
-                or not isinstance(request.endpoint, Endpoint) or len(request.endpoint) == 0 or self.client_mode:
+        if (
+            request.ListFields() == 3
+            and not isinstance(request.schema_hash, bytes)
+            or len(request.schema_hash) == 0
+            or not isinstance(request.expiration, DHTExpiration)
+            or not isfinite(request.expiration)
+            or not isinstance(request.endpoint, Endpoint)
+            or len(request.endpoint) == 0
+            or self.client_mode
+        ):
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.PROTOCOL_VIOLATION)
 
         elif request.schema_hash != self.schema_hash:
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.BAD_SCHEMA_HASH)
         elif self.potential_leaders.declared_group_key is None:
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.NOT_DECLARED)
-        elif self.potential_leaders.declared_expiration_time > (request.expiration or float('inf')):
+        elif self.potential_leaders.declared_expiration_time > (request.expiration or float("inf")):
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.BAD_EXPIRATION_TIME)
         elif self.current_leader is not None:
-            return averaging_pb2.MessageFromLeader(code=averaging_pb2.NOT_A_LEADER, suggested_leader=self.current_leader
-                                                   )  # note: this suggested leader is currently ignored
+            return averaging_pb2.MessageFromLeader(
+                code=averaging_pb2.NOT_A_LEADER, suggested_leader=self.current_leader
+            )  # note: this suggested leader is currently ignored
         elif request.endpoint == self.endpoint or request.endpoint in self.current_followers:
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.DUPLICATE_ENDPOINT)
         elif len(self.current_followers) + 1 >= self.target_group_size:
@@ -292,7 +336,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
             return None
 
     async def leader_assemble_group(self) -> GroupInfo:
-        """ Form up all current followers into a group and gather metadata """
+        """Form up all current followers into a group and gather metadata"""
         assert self.lock_looking_for_group.locked() and self.lock_request_join_group.locked() and not self.client_mode
         assert not self.assembled_group.done()
         group_id = DHTID.generate().to_bytes()  # note: both groupd_id and the order of endpoints must be random
@@ -300,8 +344,10 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
         ordered_group_endpoints.append(self.endpoint)
         random.shuffle(ordered_group_endpoints)
 
-        gathered = tuple(self.data_for_gather if endpoint == self.endpoint else self.current_followers[endpoint].gather
-                         for endpoint in ordered_group_endpoints)
+        gathered = tuple(
+            self.data_for_gather if endpoint == self.endpoint else self.current_followers[endpoint].gather
+            for endpoint in ordered_group_endpoints
+        )
 
         logger.debug(f"{self.endpoint} - assembled group of {len(ordered_group_endpoints)} peers.")
         group_info = GroupInfo(group_id, tuple(ordered_group_endpoints), gathered)
@@ -310,7 +356,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
         return group_info
 
     async def follower_assemble_group(self, leader: Endpoint, msg: averaging_pb2.MessageFromLeader) -> GroupInfo:
-        """ Form a group from using peers and metadata provided by our leader """
+        """Form a group from using peers and metadata provided by our leader"""
         assert self.lock_looking_for_group.locked() and self.lock_request_join_group.locked()
         assert not self.assembled_group.done()
         assert self.current_leader == leader, f"averager does not follow {leader} (actual: {self.current_leader})"
@@ -326,13 +372,13 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
         return group_info
 
     async def leader_disband_group(self):
-        """ Kick out all followers immediately, optionally direct them to our new leader (if we found one) """
+        """Kick out all followers immediately, optionally direct them to our new leader (if we found one)"""
         assert self.lock_request_join_group.locked() and not self.client_mode
         self.current_followers.clear()  # this will cause rpc_join_group to kick all followers out
 
 
 class PotentialLeaders:
-    """ An utility class that searches for averagers that could become our leaders """
+    """An utility class that searches for averagers that could become our leaders"""
 
     def __init__(self, endpoint: Endpoint, averaging_expiration: DHTExpiration, target_group_size: Optional[int]):
         self.endpoint, self.averaging_expiration = endpoint, averaging_expiration
@@ -341,16 +387,16 @@ class PotentialLeaders:
         self.declared_expiration, self.lock_search, self.lock_declare = asyncio.Event(), asyncio.Lock(), asyncio.Lock()
         self.leader_queue = TimedStorage[Endpoint, DHTExpiration]()
         self.past_attempts: Set[Tuple[Endpoint, DHTExpiration]] = set()
-        self.declared_expiration_time = float('inf')
+        self.declared_expiration_time = float("inf")
         self.declared_group_key: Optional[GroupKey] = None
-        self.max_assured_time = float('-inf')
-        self.search_end_time = float('inf')
+        self.max_assured_time = float("-inf")
+        self.search_end_time = float("inf")
 
     @contextlib.asynccontextmanager
     async def begin_search(self, key_manager: GroupKeyManager, timeout: Optional[float], declare: bool = True):
         async with self.lock_search:
             self.running.set()
-            self.search_end_time = get_dht_time() + timeout if timeout is not None else float('inf')
+            self.search_end_time = get_dht_time() + timeout if timeout is not None else float("inf")
             update_queue_task = asyncio.create_task(self._update_queue_periodically(key_manager))
             if declare:
                 declare_averager_task = asyncio.create_task(self._declare_averager_periodically(key_manager))
@@ -363,11 +409,17 @@ class PotentialLeaders:
                 if declare and not declare_averager_task.done():
                     declare_averager_task.cancel()
 
-                for field in (self.past_attempts, self.leader_queue, self.running,
-                              self.update_finished, self.update_triggered, self.declared_expiration):
+                for field in (
+                    self.past_attempts,
+                    self.leader_queue,
+                    self.running,
+                    self.update_finished,
+                    self.update_triggered,
+                    self.declared_expiration,
+                ):
                     field.clear()
-                self.max_assured_time = float('-inf')
-                self.search_end_time = float('inf')
+                self.max_assured_time = float("-inf")
+                self.search_end_time = float("inf")
 
     @contextlib.asynccontextmanager
     async def pause_search(self):
@@ -382,7 +434,7 @@ class PotentialLeaders:
                 self.running.clear()
 
     async def pop_next_leader(self) -> Endpoint:
-        """ Remove and return the next most suitable leader or throw an exception if reached timeout """
+        """Remove and return the next most suitable leader or throw an exception if reached timeout"""
         assert self.running.is_set(), "Not running search at the moment"
         while True:
             maybe_next_leader, entry = self.leader_queue.top()
@@ -391,9 +443,12 @@ class PotentialLeaders:
                 self.update_triggered.set()
 
             if maybe_next_leader is None or (entry.expiration_time, maybe_next_leader) > (
-                    self.declared_expiration_time, self.endpoint):
-                await asyncio.wait({self.update_finished.wait(), self.declared_expiration.wait()},
-                                   return_when=asyncio.FIRST_COMPLETED)
+                self.declared_expiration_time,
+                self.endpoint,
+            ):
+                await asyncio.wait(
+                    {self.update_finished.wait(), self.declared_expiration.wait()}, return_when=asyncio.FIRST_COMPLETED
+                )
                 self.declared_expiration.clear()
                 if self.update_finished.is_set():
                     self.update_finished.clear()
@@ -407,7 +462,7 @@ class PotentialLeaders:
 
     @property
     def request_expiration_time(self) -> float:
-        """ this averager's current expiration time - used to send join requests to leaders """
+        """this averager's current expiration time - used to send join requests to leaders"""
         if isfinite(self.declared_expiration_time):
             return self.declared_expiration_time
         else:
@@ -418,8 +473,9 @@ class PotentialLeaders:
             DISCREPANCY = timed_storage.MAX_DHT_TIME_DISCREPANCY_SECONDS
             while get_dht_time() < self.search_end_time:
                 new_peers = await key_manager.get_averagers(key_manager.current_key, only_active=True)
-                self.max_assured_time = max(self.max_assured_time,
-                                            get_dht_time() + self.averaging_expiration - DISCREPANCY)
+                self.max_assured_time = max(
+                    self.max_assured_time, get_dht_time() + self.averaging_expiration - DISCREPANCY
+                )
 
                 self.leader_queue.clear()
                 for peer, peer_expiration_time in new_peers:
@@ -431,8 +487,10 @@ class PotentialLeaders:
                 self.update_finished.set()
 
                 await asyncio.wait(
-                    {self.running.wait(), self.update_triggered.wait()}, return_when=asyncio.ALL_COMPLETED,
-                    timeout=self.search_end_time - get_dht_time() if isfinite(self.search_end_time) else None)
+                    {self.running.wait(), self.update_triggered.wait()},
+                    return_when=asyncio.ALL_COMPLETED,
+                    timeout=self.search_end_time - get_dht_time() if isfinite(self.search_end_time) else None,
+                )
                 self.update_triggered.clear()
         except (concurrent.futures.CancelledError, asyncio.CancelledError):
             return  # note: this is a compatibility layer for python3.7
@@ -461,11 +519,12 @@ class PotentialLeaders:
             finally:
                 if self.declared_group_key is not None:
                     prev_declared_key, prev_expiration_time = self.declared_group_key, self.declared_expiration_time
-                    self.declared_group_key, self.declared_expiration_time = None, float('inf')
-                    self.leader_queue, self.max_assured_time = TimedStorage[Endpoint, DHTExpiration](), float('-inf')
-                    await key_manager.declare_averager(prev_declared_key, self.endpoint, prev_expiration_time,
-                                                       looking_for_group=False)
+                    self.declared_group_key, self.declared_expiration_time = None, float("inf")
+                    self.leader_queue, self.max_assured_time = TimedStorage[Endpoint, DHTExpiration](), float("-inf")
+                    await key_manager.declare_averager(
+                        prev_declared_key, self.endpoint, prev_expiration_time, looking_for_group=False
+                    )
 
 
 class MatchmakingException(Exception):
-    """ An internal exception that marks undesired edge cases during averaging """
+    """An internal exception that marks undesired edge cases during averaging"""

+ 25 - 18
hivemind/averaging/partition.py

@@ -13,7 +13,7 @@ from hivemind.utils.compression import serialize_torch_tensor, get_nbytes_per_va
 from hivemind.utils.asyncio import amap_in_executor
 
 
-T = TypeVar('T')
+T = TypeVar("T")
 DEFAULT_PART_SIZE_BYTES = 2 ** 20
 
 
@@ -28,9 +28,14 @@ class TensorPartContainer:
     :param prefetch: when compressing, pre-compute this many compressed tensors in background
     """
 
-    def __init__(self, tensors: Sequence[torch.Tensor], peer_fractions: Sequence[float],
-                 compression_type: Union[type(CompressionType), Sequence[type(CompressionType)]] = CompressionType.NONE,
-                 part_size_bytes: int = 2 ** 20, prefetch: int = 1):
+    def __init__(
+        self,
+        tensors: Sequence[torch.Tensor],
+        peer_fractions: Sequence[float],
+        compression_type: Union[type(CompressionType), Sequence[type(CompressionType)]] = CompressionType.NONE,
+        part_size_bytes: int = 2 ** 20,
+        prefetch: int = 1,
+    ):
         if not isinstance(compression_type, Sequence):
             compression_type = [compression_type] * len(tensors)
         assert len(compression_type) == len(tensors), "compression types do not match the number of tensors"
@@ -77,7 +82,7 @@ class TensorPartContainer:
 
     @torch.no_grad()
     def get_raw_input_parts(self, peer_index: int) -> Tuple[torch.Tensor, ...]:
-        """ get non-serialized tensor parts for a peer at a given index """
+        """get non-serialized tensor parts for a peer at a given index"""
         assert not self._inputs_consumed_by_peer[peer_index], "input parts of a given peer are already deallocated."
         self._inputs_consumed_by_peer[peer_index] = True
         input_parts = tuple(part for part, compression in self._input_parts_by_peer[peer_index])
@@ -86,7 +91,7 @@ class TensorPartContainer:
 
     @torch.no_grad()
     async def iterate_input_parts_for(self, peer_index: int) -> AsyncIterator[Tensor]:
-        """ iterate serialized tensor parts for a peer at a given index. Run serialization in background. """
+        """iterate serialized tensor parts for a peer at a given index. Run serialization in background."""
         assert not self._inputs_consumed_by_peer[peer_index], "input parts of a given peer are already deallocated."
         self._inputs_consumed_by_peer[peer_index] = True
 
@@ -94,8 +99,9 @@ class TensorPartContainer:
             for _ in range(self.num_parts_by_peer[peer_index]):
                 yield self._input_parts_by_peer[peer_index].popleft()
 
-        async for serialized_part in amap_in_executor(lambda x_and_compr: serialize_torch_tensor(*x_and_compr),
-                                                      _aiterate_parts(), max_prefetch=self.prefetch):
+        async for serialized_part in amap_in_executor(
+            lambda x_and_compr: serialize_torch_tensor(*x_and_compr), _aiterate_parts(), max_prefetch=self.prefetch
+        ):
             yield serialized_part
 
     def register_processed_part(self, peer_index: int, part_index: int, part: torch.Tensor):
@@ -104,14 +110,16 @@ class TensorPartContainer:
         depending on the algorithm, processed part is an average, difference from average or another aggregation
         """
         if part_index != self._outputs_registered_by_peer[peer_index]:
-            raise ValueError(f"Could not register part #{part_index} from peer #{peer_index}, "
-                             f" expected part index: {self._outputs_registered_by_peer[peer_index]}")
+            raise ValueError(
+                f"Could not register part #{part_index} from peer #{peer_index}, "
+                f" expected part index: {self._outputs_registered_by_peer[peer_index]}"
+            )
         self._output_parts_by_peer[peer_index].append(part)
         self._outputs_registered_by_peer[peer_index] += 1
         self._output_part_available[peer_index].set()
 
     async def iterate_output_tensors(self) -> AsyncIterable[torch.Tensor]:
-        """ iterate over the outputs of averaging (whether they are average, delta or other aggregation result) """
+        """iterate over the outputs of averaging (whether they are average, delta or other aggregation result)"""
         assert not self._outputs_consumed, "output tensors are already iterated and no longer available."
         self._outputs_consumed = True
         peer_index = num_parts_processed = 0
@@ -138,7 +146,7 @@ class TensorPartContainer:
         self.finalize()
 
     def finalize(self):
-        """ terminate all iterators, delete intermediate data """
+        """terminate all iterators, delete intermediate data"""
         if not self.finished.is_set():
             for peer_index in range(self.group_size):
                 self._inputs_consumed_by_peer[peer_index] = True
@@ -158,8 +166,7 @@ class TensorPartReducer:
     :note: even if local peer is not sending data, local parts will be used for shape information
     """
 
-    def __init__(self, part_shapes: Sequence[torch.Size], num_senders: int,
-                 weights: Optional[Sequence[float]] = None):
+    def __init__(self, part_shapes: Sequence[torch.Size], num_senders: int, weights: Optional[Sequence[float]] = None):
         self.part_shapes, self.num_senders, self.num_parts = part_shapes, num_senders, len(part_shapes)
         self.weights = tuple(weights or (1 for _ in range(num_senders)))
         assert len(self.weights) == self.num_senders, "The number of weights is inconsistent with num_senders"
@@ -173,7 +180,7 @@ class TensorPartReducer:
         self.reset_accumulators()
 
     def reset_accumulators(self):
-        """ (re)create averaging buffers for the next part in line, prepopulate with local tensor part """
+        """(re)create averaging buffers for the next part in line, prepopulate with local tensor part"""
         assert self.current_part_accumulated_from == self.num_senders or self.current_part_index == -1
         if self.current_part_index >= self.num_parts - 1:
             self.finalize()
@@ -186,7 +193,7 @@ class TensorPartReducer:
         self.denominator = 0.0
 
     async def accumulate_part(self, sender_index: int, part_index: int, tensor_part: torch.Tensor) -> torch.Tensor:
-        """ Add vector part to accumulator, wait for all other vectors to be added, then return the average part """
+        """Add vector part to accumulator, wait for all other vectors to be added, then return the average part"""
         assert 0 <= sender_index < self.num_senders, "invalid sender index"
         assert 0 <= part_index < self.num_parts, "invalid part index"
 
@@ -211,7 +218,7 @@ class TensorPartReducer:
 
     def finalize(self):
         if not self.finished.is_set():
-            if hasattr(self, 'current_part_future'):
+            if hasattr(self, "current_part_future"):
                 self.current_part_future.cancel()
                 del self.accumulator
             self.finished.set()
@@ -221,4 +228,4 @@ class TensorPartReducer:
 
 
 class AllreduceException(Exception):
-    """ A special exception that is raised when allreduce can't continue normally (e.g. disconnected/protocol error) """
+    """A special exception that is raised when allreduce can't continue normally (e.g. disconnected/protocol error)"""

+ 42 - 27
hivemind/averaging/training.py

@@ -33,9 +33,17 @@ class TrainingAverager(DecentralizedAverager):
     :param kwargs: any additional parameters will be forwarded to DecentralizedAverager
     """
 
-    def __init__(self, opt: torch.optim.Optimizer, *, average_parameters: bool, average_gradients: bool,
-                 average_opt_statistics: Sequence[str] = (), extra_tensors: Sequence[torch.Tensor] = (),
-                 initialize_optimizer: bool = True, **kwargs):
+    def __init__(
+        self,
+        opt: torch.optim.Optimizer,
+        *,
+        average_parameters: bool,
+        average_gradients: bool,
+        average_opt_statistics: Sequence[str] = (),
+        extra_tensors: Sequence[torch.Tensor] = (),
+        initialize_optimizer: bool = True,
+        **kwargs
+    ):
 
         self.opt, self.extra_tensors, self.local_step = opt, tuple(extra_tensors), 0
         self.opt_statistics = tuple(average_opt_statistics)
@@ -71,7 +79,8 @@ class TrainingAverager(DecentralizedAverager):
                 if use_old_local_tensors:
                     old_local_tensors = tuple(x.cpu().float().clone() for x in local_tensors)
                 assert len(local_tensors) == len(
-                    averaged_tensors), "The number of optimized parameters should not change."
+                    averaged_tensors
+                ), "The number of optimized parameters should not change."
                 for averaged_tensor, local_tensor in zip(averaged_tensors, local_tensors):
                     averaged_tensor[...] = local_tensor.cpu().float()
 
@@ -86,15 +95,20 @@ class TrainingAverager(DecentralizedAverager):
                     if use_old_local_tensors:
                         # since tensors might have changed, we subtract old_local_tensor and add averaged. This prevents
                         # losing local updates that might have occurred during averaging
-                        for averaged_tensor, local_tensor, old_local_tensor in zip(averaged_tensors, local_tensors,
-                                                                                   old_local_tensors):
-                            local_tensor[...] += averaged_tensor.to(dtype=local_tensor.dtype,
-                                                                    device=local_tensor.device) - \
-                                                 old_local_tensor.to(dtype=local_tensor.dtype,
-                                                                     device=local_tensor.device)
+                        for averaged_tensor, local_tensor, old_local_tensor in zip(
+                            averaged_tensors, local_tensors, old_local_tensors
+                        ):
+                            averaged_tensor = averaged_tensor.to(
+                                dtype=local_tensor.dtype, device=local_tensor.device, non_blocking=True
+                            )
+                            old_local_tensor = old_local_tensor.to(
+                                dtype=local_tensor.dtype, device=local_tensor.device, non_blocking=True
+                            )
+
+                            local_tensor.add_(averaged_tensor - old_local_tensor)
                     else:
                         for averaged_tensor, local_tensor in zip(averaged_tensors, local_tensors):
-                            local_tensor[...] = averaged_tensor.to(dtype=local_tensor.dtype, device=local_tensor.device)
+                            local_tensor.copy_(averaged_tensor, non_blocking=True)
 
             self.local_step += 1
             return gathered
@@ -108,17 +122,17 @@ class TrainingAverager(DecentralizedAverager):
         """
         if self.average_parameters:
             for param_group in self.opt.param_groups:
-                yield from param_group['params']
+                yield from param_group["params"]
         if self.average_gradients:
             for param_group in self.opt.param_groups:
-                for param in param_group['params']:
+                for param in param_group["params"]:
                     if param.grad is not None:
                         yield param.grad
                     elif replace_none:
                         yield torch.zeros_like(param)
         for stats in self.opt_statistics:
             for param_group in self.opt.param_groups:
-                for param in param_group['params']:
+                for param in param_group["params"]:
                     yield self.opt.state[param][stats]
         yield from iter(self.extra_tensors)
 
@@ -128,8 +142,9 @@ class TrainingAverager(DecentralizedAverager):
         :returns: a tuple of (serializable_small_metadata, sequence of torch tensors)
         """
         with torch.no_grad():
-            optimized_parameters = tuple(param.detach().cpu() for param_group in self.opt.param_groups
-                                         for param in param_group['params'])
+            optimized_parameters = tuple(
+                param.detach().cpu() for param_group in self.opt.param_groups for param in param_group["params"]
+            )
             extra_tensors = tuple(tensor.detach().cpu() for tensor in self.extra_tensors)
             optimizer_metadata, optimizer_tensors = dump_optimizer_state(self.opt)
 
@@ -141,7 +156,7 @@ class TrainingAverager(DecentralizedAverager):
         Attempt to download the latest optimizer state from peers and update trainer parameters/statistics.
         :returns: whether or the averager succeeded in loading parameters
         """
-        parameters_and_extras = [param for param_group in self.opt.param_groups for param in param_group['params']]
+        parameters_and_extras = [param for param_group in self.opt.param_groups for param in param_group["params"]]
         parameters_and_extras.extend(self.extra_tensors)
         num_local_tensors = len(parameters_and_extras)
 
@@ -155,39 +170,39 @@ class TrainingAverager(DecentralizedAverager):
         with torch.no_grad():
             for local_param, loaded_param in zip(parameters_and_extras, loaded_parameters_and_extras):
                 local_param[...] = loaded_param
-            load_optimizer_state(self.opt, metadata['optimizer_metadata'], loaded_opt_tensors)
+            load_optimizer_state(self.opt, metadata["optimizer_metadata"], loaded_opt_tensors)
 
-        self.local_step = max(self.local_step, metadata['step'])
+        self.local_step = max(self.local_step, metadata["step"])
 
 
 def initialize_optimizer_state(opt: torch.optim.Optimizer):
     for param_group in opt.param_groups:
-        for param in param_group['params']:
+        for param in param_group["params"]:
             if param.grad is None:
                 (0 * param.sum()).backward()
     opt.step()
 
 
 def dump_optimizer_state(opt: torch.optim.Optimizer):
-    """ Convert optimizer state into a format of DecentralizedAverager's get_current_state/load_state_from_peers """
+    """Convert optimizer state into a format of DecentralizedAverager's get_current_state/load_state_from_peers"""
     with torch.no_grad():
         flat_metadata, flat_tensors = [], []
         for elem in nested_flatten(opt.state_dict()):
             if isinstance(elem, torch.Tensor):
-                flat_metadata.append(dict(type='tensor', index=len(flat_tensors)))
+                flat_metadata.append(dict(type="tensor", index=len(flat_tensors)))
                 flat_tensors.append(elem.cpu())
             else:
-                flat_metadata.append(dict(type='value', value=elem))
+                flat_metadata.append(dict(type="value", value=elem))
         return flat_metadata, flat_tensors
 
 
 def load_optimizer_state(optimizer: torch.optim.Optimizer, flat_metadata: Dict, flat_tensors: Sequence[torch.Tensor]):
     flat_optimizer_state = []
     for elem in flat_metadata:
-        if elem.get('type') == 'tensor' and isinstance(elem.get('index'), int):
-            flat_optimizer_state.append(flat_tensors[elem['index']])
-        elif elem.get('type') == 'value' and 'value' in elem:
-            flat_optimizer_state.append(elem['value'])
+        if elem.get("type") == "tensor" and isinstance(elem.get("index"), int):
+            flat_optimizer_state.append(flat_tensors[elem["index"]])
+        elif elem.get("type") == "value" and "value" in elem:
+            flat_optimizer_state.append(elem["value"])
     with torch.no_grad():
         try:
             return optimizer.load_state_dict(nested_pack(flat_optimizer_state, structure=optimizer.state_dict()))
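
Beyond the quote and docstring changes, the dump_optimizer_state / load_optimizer_state pair above is how TrainingAverager ships optimizer state between peers: the state dict is flattened into serializable metadata plus a list of CPU tensors, then re-packed on the receiving side. A minimal round-trip sketch, assuming hivemind is installed and these helpers are importable from hivemind.averaging.training:

import torch
from hivemind.averaging.training import initialize_optimizer_state, dump_optimizer_state, load_optimizer_state

model = torch.nn.Linear(4, 2)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
initialize_optimizer_state(opt)  # zero-gradient step so opt.state gets populated with Adam buffers

metadata, tensors = dump_optimizer_state(opt)  # list of {"type": "tensor"/"value", ...} plus CPU tensors
load_optimizer_state(opt, metadata, tensors)   # re-packs the flat state and calls opt.load_state_dict(...)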

+ 73 - 34
hivemind/dht/__init__.py

@@ -31,7 +31,7 @@ from hivemind.utils import DHTExpiration, MPFuture, ValueWithExpiration, await_c
 
 logger = get_logger(__name__)
 
-ReturnType = TypeVar('ReturnType')
+ReturnType = TypeVar("ReturnType")
 
 
 class DHT(mp.Process):
@@ -55,19 +55,32 @@ class DHT(mp.Process):
     :param shutdown_timeout: when calling .shutdown, wait for up to this many seconds before terminating
     :param kwargs: any other params will be forwarded to DHTNode and hivemind.p2p.P2P upon creation
     """
+
     _node: DHTNode
 
-    def __init__(self, p2p: Optional[P2P] = None,
-                 initial_peers: Optional[Sequence[Union[Multiaddr, str]]] = None,
-                 *, start: bool, daemon: bool = True, max_workers: Optional[int] = None,
-                 record_validators: Iterable[RecordValidatorBase] = (),
-                 shutdown_timeout: float = 3, **kwargs):
+    def __init__(
+        self,
+        p2p: Optional[P2P] = None,
+        initial_peers: Optional[Sequence[Union[Multiaddr, str]]] = None,
+        *,
+        start: bool,
+        daemon: bool = True,
+        max_workers: Optional[int] = None,
+        record_validators: Iterable[RecordValidatorBase] = (),
+        shutdown_timeout: float = 3,
+        **kwargs,
+    ):
         super().__init__()
 
         self.p2p = p2p
-        if not (initial_peers is None or (isinstance(initial_peers, Sequence) and
-                                          all(isinstance(item, (Multiaddr, str)) for item in initial_peers))):
-            raise TypeError('initial_peers should be of type Optional[Sequence[Union[Multiaddr, str]]]')
+        if not (
+            initial_peers is None
+            or (
+                isinstance(initial_peers, Sequence)
+                and all(isinstance(item, (Multiaddr, str)) for item in initial_peers)
+            )
+        ):
+            raise TypeError("initial_peers should be of type Optional[Sequence[Union[Multiaddr, str]]]")
         self.initial_peers = initial_peers
         self.kwargs = kwargs
         self.max_workers = max_workers
@@ -81,21 +94,25 @@ class DHT(mp.Process):
             self.run_in_background(await_ready=True)
 
     def run(self) -> None:
-        """ Serve DHT forever. This function will not return until DHT node is shut down """
+        """Serve DHT forever. This function will not return until DHT node is shut down"""
         loop = switch_to_uvloop()
 
         with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
+
             async def _run():
                 self._node = await DHTNode.create(
-                    p2p=self.p2p, initial_peers=self.initial_peers,
-                    num_workers=self.max_workers or 1, record_validator=self._record_validator,
-                    **self.kwargs)
+                    p2p=self.p2p,
+                    initial_peers=self.initial_peers,
+                    num_workers=self.max_workers or 1,
+                    record_validator=self._record_validator,
+                    **self.kwargs,
+                )
                 self.ready.set()
 
                 while True:
                     method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
                     task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
-                    if method == '_shutdown':
+                    if method == "_shutdown":
                         await task
                         break
 
@@ -112,9 +129,9 @@ class DHT(mp.Process):
             raise TimeoutError(f"DHT didn't notify .ready in {timeout} seconds")
 
     def shutdown(self) -> None:
-        """ Shut down a running dht process """
+        """Shut down a running dht process"""
         if self.is_alive():
-            self._outer_pipe.send(('_shutdown', [], {}))
+            self._outer_pipe.send(("_shutdown", [], {}))
             self.join(self.shutdown_timeout)
             if self.is_alive():
                 logger.warning("DHT did not shut down within the grace period; terminating it the hard way.")
@@ -123,8 +140,9 @@ class DHT(mp.Process):
     async def _shutdown(self):
         await self._node.shutdown()
 
-    def get(self, key: DHTKey, latest: bool = False, return_future: bool = False, **kwargs
-            ) -> Union[Optional[ValueWithExpiration[DHTValue]], MPFuture]:
+    def get(
+        self, key: DHTKey, latest: bool = False, return_future: bool = False, **kwargs
+    ) -> Union[Optional[ValueWithExpiration[DHTValue]], MPFuture]:
         """
         Search for a key across DHT and return either first or latest entry (if found).
         :param key: same key as in node.store(...)
@@ -134,7 +152,7 @@ class DHT(mp.Process):
         :returns: (value, expiration time); if value was not found, returns None
         """
         future = MPFuture()
-        self._outer_pipe.send(('_get', [], dict(key=key, latest=latest, future=future, **kwargs)))
+        self._outer_pipe.send(("_get", [], dict(key=key, latest=latest, future=future, **kwargs)))
         return future if return_future else future.result()
 
     async def _get(self, key: DHTKey, latest: bool, future: MPFuture, **kwargs):
@@ -147,8 +165,15 @@ class DHT(mp.Process):
                 future.set_exception(e)
             raise
 
-    def store(self, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration,
-              subkey: Optional[Subkey] = None, return_future: bool = False, **kwargs) -> Union[bool, MPFuture]:
+    def store(
+        self,
+        key: DHTKey,
+        value: DHTValue,
+        expiration_time: DHTExpiration,
+        subkey: Optional[Subkey] = None,
+        return_future: bool = False,
+        **kwargs,
+    ) -> Union[bool, MPFuture]:
         """
         Find num_replicas best nodes to store (key, value) and store it there until expiration time.
 
@@ -160,12 +185,24 @@ class DHT(mp.Process):
         :returns: True if store succeeds, False if it fails (due to no response or newer value)
         """
         future = MPFuture()
-        self._outer_pipe.send(('_store', [], dict(key=key, value=value, expiration_time=expiration_time, subkey=subkey,
-                                                  future=future, **kwargs)))
+        self._outer_pipe.send(
+            (
+                "_store",
+                [],
+                dict(key=key, value=value, expiration_time=expiration_time, subkey=subkey, future=future, **kwargs),
+            )
+        )
         return future if return_future else future.result()
 
-    async def _store(self, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration,
-                     subkey: Optional[Subkey], future: MPFuture, **kwargs):
+    async def _store(
+        self,
+        key: DHTKey,
+        value: DHTValue,
+        expiration_time: DHTExpiration,
+        subkey: Optional[Subkey],
+        future: MPFuture,
+        **kwargs,
+    ):
         try:
             result = await self._node.store(key, value, expiration_time, subkey=subkey, **kwargs)
             if not future.done():
@@ -175,8 +212,9 @@ class DHT(mp.Process):
                 future.set_exception(e)
             raise
 
-    def run_coroutine(self, coro: Callable[[DHT, DHTNode], Awaitable[ReturnType]],
-                      return_future: bool = False) -> Union[ReturnType, MPFuture[ReturnType]]:
+    def run_coroutine(
+        self, coro: Callable[[DHT, DHTNode], Awaitable[ReturnType]], return_future: bool = False
+    ) -> Union[ReturnType, MPFuture[ReturnType]]:
         """
         Execute an asynchronous function on a DHT participant and return results. This is meant as an interface
          for running custom functions on the DHT for special cases (e.g. declare experts, beam search)
@@ -191,11 +229,12 @@ class DHT(mp.Process):
         :note: when run_coroutine is called with wait=False, MPFuture can be cancelled to interrupt the task.
         """
         future = MPFuture()
-        self._outer_pipe.send(('_run_coroutine', [], dict(coro=coro, future=future)))
+        self._outer_pipe.send(("_run_coroutine", [], dict(coro=coro, future=future)))
         return future if return_future else future.result()
 
-    async def _run_coroutine(self, coro: Callable[[DHT, DHTNode], Awaitable[ReturnType]],
-                             future: MPFuture[ReturnType]):
+    async def _run_coroutine(
+        self, coro: Callable[[DHT, DHTNode], Awaitable[ReturnType]], future: MPFuture[ReturnType]
+    ):
         main_task = asyncio.create_task(coro(self, self._node))
         cancel_task = asyncio.create_task(await_cancelled(future))
         try:
@@ -205,7 +244,7 @@ class DHT(mp.Process):
             else:
                 future.set_result(await main_task)
         except BaseException as e:
-            logger.exception(f'Caught an exception when running a coroutine: {e}')
+            logger.exception(f"Caught an exception when running a coroutine: {e}")
             if not future.done():
                 future.set_exception(e)
 
@@ -213,12 +252,12 @@ class DHT(mp.Process):
         if not self.ready.is_set():
             raise RuntimeError(
                 "Can't append new validators before the DHT process has started. "
-                "Consider adding them to the initial list via DHT.__init__(record_validators=...)")
+                "Consider adding them to the initial list via DHT.__init__(record_validators=...)"
+            )
 
         self.run_coroutine(partial(DHT._add_validators, record_validators=record_validators))
 
-    async def _add_validators(
-            self, node: DHTNode, record_validators: Iterable[RecordValidatorBase]) -> None:
+    async def _add_validators(self, node: DHTNode, record_validators: Iterable[RecordValidatorBase]) -> None:
         node.protocol.record_validator.extend(record_validators)
 
     def get_visible_maddrs(self, latest: bool = False) -> List[Multiaddr]:
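
For context on the API whose formatting changes here: DHT is a multiprocessing wrapper that forwards get / store / run_coroutine calls to the background DHTNode through a pipe and hands results back via MPFuture. A rough usage sketch, assuming a working hivemind install (a single node with no initial_peers simply stores to itself):

from hivemind import DHT
from hivemind.utils.timed_storage import get_dht_time

dht = DHT(start=True)  # spawns the DHT process; pass initial_peers=[...] to join an existing swarm
assert dht.store("demo_key", value=42, expiration_time=get_dht_time() + 60)

found = dht.get("demo_key", latest=True)          # ValueWithExpiration(value=42, expiration_time=...) or None
print(found.value if found is not None else "not found")

future = dht.get("demo_key", return_future=True)  # MPFuture: lets the caller continue and resolve later
print(future.result())
dht.shutdown()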

+ 9 - 9
hivemind/dht/crypto.py

@@ -19,22 +19,22 @@ class RSASignatureValidator(RecordValidatorBase):
     the corresponding private key (so only the owner can change them).
     """
 
-    PUBLIC_KEY_FORMAT = b'[owner:_key_]'
-    SIGNATURE_FORMAT = b'[signature:_value_]'
+    PUBLIC_KEY_FORMAT = b"[owner:_key_]"
+    SIGNATURE_FORMAT = b"[signature:_value_]"
 
-    PUBLIC_KEY_REGEX = re.escape(PUBLIC_KEY_FORMAT).replace(b'_key_', rb'(.+?)')
+    PUBLIC_KEY_REGEX = re.escape(PUBLIC_KEY_FORMAT).replace(b"_key_", rb"(.+?)")
     _PUBLIC_KEY_RE = re.compile(PUBLIC_KEY_REGEX)
-    _SIGNATURE_RE = re.compile(re.escape(SIGNATURE_FORMAT).replace(b'_value_', rb'(.+?)'))
+    _SIGNATURE_RE = re.compile(re.escape(SIGNATURE_FORMAT).replace(b"_value_", rb"(.+?)"))
 
     _cached_private_key = None
 
-    def __init__(self, private_key: Optional[RSAPrivateKey]=None):
+    def __init__(self, private_key: Optional[RSAPrivateKey] = None):
         if private_key is None:
             private_key = RSAPrivateKey.process_wide()
         self._private_key = private_key
 
         serialized_public_key = private_key.get_public_key().to_bytes()
-        self._local_public_key = self.PUBLIC_KEY_FORMAT.replace(b'_key_', serialized_public_key)
+        self._local_public_key = self.PUBLIC_KEY_FORMAT.replace(b"_key_", serialized_public_key)
 
     @property
     def local_public_key(self) -> bytes:
@@ -60,7 +60,7 @@ class RSASignatureValidator(RecordValidatorBase):
 
         stripped_record = dataclasses.replace(record, value=self.strip_value(record))
         if not public_key.verify(self._serialize_record(stripped_record), signature):
-            logger.debug(f'Signature is invalid in {record}')
+            logger.debug(f"Signature is invalid in {record}")
             return False
         return True
 
@@ -69,10 +69,10 @@ class RSASignatureValidator(RecordValidatorBase):
             return record.value
 
         signature = self._private_key.sign(self._serialize_record(record))
-        return record.value + self.SIGNATURE_FORMAT.replace(b'_value_', signature)
+        return record.value + self.SIGNATURE_FORMAT.replace(b"_value_", signature)
 
     def strip_value(self, record: DHTRecord) -> bytes:
-        return self._SIGNATURE_RE.sub(b'', record.value)
+        return self._SIGNATURE_RE.sub(b"", record.value)
 
     def _serialize_record(self, record: DHTRecord) -> bytes:
         return MSGPackSerializer.dumps(dataclasses.astuple(record))
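
The byte strings switched to double quotes above are the markers this validator embeds into DHT records: "[owner:...]" tags a key with the writer's public key and "[signature:...]" is appended to the value. A hypothetical sketch of the sign/verify round trip; it assumes the DHTRecord dataclass is importable from hivemind.dht.validation and takes (key, subkey, value, expiration_time) positionally:

import dataclasses

from hivemind.dht.crypto import RSASignatureValidator
from hivemind.dht.validation import DHTRecord
from hivemind.utils.timed_storage import get_dht_time

validator = RSASignatureValidator()                   # reuses the process-wide RSA key by default
owned_key = b"some_key" + validator.local_public_key  # keys containing "[owner:...]" get signed

record = DHTRecord(owned_key, b"", b"payload", get_dht_time() + 60)
signed = dataclasses.replace(record, value=validator.sign_value(record))
assert validator.validate(signed)                     # signature verifies against the embedded public key
assert validator.strip_value(signed) == b"payload"    # "[signature:...]" suffix removed for consumers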

+ 277 - 129
hivemind/dht/node.py

@@ -6,8 +6,21 @@ import random
 from collections import defaultdict, Counter
 from dataclasses import dataclass, field
 from functools import partial
-from typing import (Any, Awaitable, Callable, Collection, DefaultDict, Dict, List, Optional, Sequence, Set, Tuple,
-                    Type, Union)
+from typing import (
+    Any,
+    Awaitable,
+    Callable,
+    Collection,
+    DefaultDict,
+    Dict,
+    List,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    Type,
+    Union,
+)
 
 from multiaddr import Multiaddr
 from sortedcontainers import SortedSet
@@ -69,6 +82,7 @@ class DHTNode:
       to reuse the result of this GET request for other requests with the same key. Useful for batch-parallel requests.
 
     """
+
     # fmt:off
     node_id: DHTID; is_alive: bool; peer_id: PeerID; num_replicas: int; num_workers: int; protocol: DHTProtocol
     chunk_size: int; refresh_timeout: float; cache_locally: bool; cache_nearest: int; cache_refresh_before_expiry: float
@@ -79,19 +93,34 @@ class DHTNode:
 
     @classmethod
     async def create(
-            cls,
-            p2p: Optional[P2P] = None,
-            node_id: Optional[DHTID] = None,
-            initial_peers: Optional[Sequence[Union[Multiaddr, str]]] = None,
-            bucket_size: int = 20, num_replicas: int = 5, depth_modulo: int = 5, parallel_rpc: int = None,
-            wait_timeout: float = 3, refresh_timeout: Optional[float] = None, bootstrap_timeout: Optional[float] = None,
-            cache_locally: bool = True, cache_nearest: int = 1, cache_size=None, cache_refresh_before_expiry: float = 5,
-            cache_on_store: bool = True, reuse_get_requests: bool = True, num_workers: int = 1, chunk_size: int = 16,
-            blacklist_time: float = 5.0, backoff_rate: float = 2.0,
-            listen: bool = True,
-            record_validator: Optional[RecordValidatorBase] = None,
-            authorizer: Optional[AuthorizerBase] = None,
-            validate: bool = True, strict: bool = True, **kwargs) -> DHTNode:
+        cls,
+        p2p: Optional[P2P] = None,
+        node_id: Optional[DHTID] = None,
+        initial_peers: Optional[Sequence[Union[Multiaddr, str]]] = None,
+        bucket_size: int = 20,
+        num_replicas: int = 5,
+        depth_modulo: int = 5,
+        parallel_rpc: int = None,
+        wait_timeout: float = 3,
+        refresh_timeout: Optional[float] = None,
+        bootstrap_timeout: Optional[float] = None,
+        cache_locally: bool = True,
+        cache_nearest: int = 1,
+        cache_size=None,
+        cache_refresh_before_expiry: float = 5,
+        cache_on_store: bool = True,
+        reuse_get_requests: bool = True,
+        num_workers: int = 1,
+        chunk_size: int = 16,
+        blacklist_time: float = 5.0,
+        backoff_rate: float = 2.0,
+        listen: bool = True,
+        record_validator: Optional[RecordValidatorBase] = None,
+        authorizer: Optional[AuthorizerBase] = None,
+        validate: bool = True,
+        strict: bool = True,
+        **kwargs,
+    ) -> DHTNode:
         """
         :param p2p: instance of hivemind.p2p.P2P that will be used for communication.
           If None, DHTNode will create and manage its own P2P instance with given initial_peers and
@@ -139,7 +168,7 @@ class DHTNode:
         self.is_alive = True  # if set to False, cancels all background jobs such as routing table refresh
 
         self.reuse_get_requests = reuse_get_requests
-        self.pending_get_requests = defaultdict(partial(SortedSet, key=lambda _res: - _res.sufficient_expiration_time))
+        self.pending_get_requests = defaultdict(partial(SortedSet, key=lambda _res: -_res.sufficient_expiration_time))
 
         # caching policy
         self.refresh_timeout = refresh_timeout
@@ -151,38 +180,52 @@ class DHTNode:
         self.cache_refresh_task = None
 
         if p2p is None:
-            if not kwargs.get('use_ipfs'):
-                kwargs['initial_peers'] = initial_peers
+            if not kwargs.get("use_ipfs"):
+                kwargs["initial_peers"] = initial_peers
             p2p = await P2P.create(**kwargs)
             self._should_shutdown_p2p = True
         else:
             if kwargs:
                 raise ValueError(
-                    f'**kwargs in DHTNode.create() should be empty if hivemind.p2p.P2P instance is provided'
-                    f'in the constructor. Got kwargs = {kwargs} instead. '
-                    f'You may have a typo in a DHTNode.create() parameter name')
+                    f"**kwargs in DHTNode.create() should be empty if hivemind.p2p.P2P instance is provided"
+                    f"in the constructor. Got kwargs = {kwargs} instead. "
+                    f"You may have a typo in a DHTNode.create() parameter name"
+                )
             self._should_shutdown_p2p = False
         self.p2p = p2p
 
         self.protocol = await DHTProtocol.create(
-            p2p, self.node_id, bucket_size, depth_modulo, num_replicas, wait_timeout,
-            parallel_rpc, cache_size, listen, record_validator, authorizer)
+            p2p,
+            self.node_id,
+            bucket_size,
+            depth_modulo,
+            num_replicas,
+            wait_timeout,
+            parallel_rpc,
+            cache_size,
+            listen,
+            record_validator,
+            authorizer,
+        )
         self.peer_id = p2p.id
 
         if initial_peers:
-            initial_peers = {PeerID.from_base58(Multiaddr(item)['p2p']) for item in initial_peers}
+            initial_peers = {PeerID.from_base58(Multiaddr(item)["p2p"]) for item in initial_peers}
 
             # stage 1: ping initial_peers, add each other to the routing table
             bootstrap_timeout = bootstrap_timeout if bootstrap_timeout is not None else wait_timeout
             start_time = get_dht_time()
-            ping_tasks = set(asyncio.create_task(self.protocol.call_ping(peer, validate=validate, strict=strict))
-                             for peer in initial_peers)
+            ping_tasks = set(
+                asyncio.create_task(self.protocol.call_ping(peer, validate=validate, strict=strict))
+                for peer in initial_peers
+            )
             finished_pings, unfinished_pings = await asyncio.wait(ping_tasks, return_when=asyncio.FIRST_COMPLETED)
 
             # stage 2: gather remaining peers (those who respond within bootstrap_timeout)
             if unfinished_pings:
                 finished_in_time, stragglers = await asyncio.wait(
-                    unfinished_pings, timeout=bootstrap_timeout - get_dht_time() + start_time)
+                    unfinished_pings, timeout=bootstrap_timeout - get_dht_time() + start_time
+                )
                 for straggler in stragglers:
                     straggler.cancel()
                 finished_pings |= finished_in_time
@@ -197,29 +240,39 @@ class DHTNode:
             # stage 3: traverse dht to find my own nearest neighbors and populate the routing table
             # ... maybe receive some values that we are meant to store (see protocol.update_routing_table)
             # note: using asyncio.wait instead of wait_for because wait_for cancels task on timeout
-            await asyncio.wait([asyncio.create_task(self.find_nearest_nodes([self.node_id])),
-                                asyncio.sleep(bootstrap_timeout - get_dht_time() + start_time)],
-                               return_when=asyncio.FIRST_COMPLETED)
+            await asyncio.wait(
+                [
+                    asyncio.create_task(self.find_nearest_nodes([self.node_id])),
+                    asyncio.sleep(bootstrap_timeout - get_dht_time() + start_time),
+                ],
+                return_when=asyncio.FIRST_COMPLETED,
+            )
 
         if self.refresh_timeout is not None:
             asyncio.create_task(self._refresh_routing_table(period=self.refresh_timeout))
         return self
 
     def __init__(self, *, _initialized_with_create=False):
-        """ Internal init method. Please use DHTNode.create coroutine to spawn new node instances """
+        """Internal init method. Please use DHTNode.create coroutine to spawn new node instances"""
         assert _initialized_with_create, " Please use DHTNode.create coroutine to spawn new node instances "
         super().__init__()
 
     async def shutdown(self):
-        """ Process existing requests, close all connections and stop the server """
+        """Process existing requests, close all connections and stop the server"""
         self.is_alive = False
         if self._should_shutdown_p2p:
             await self.p2p.shutdown()
 
     async def find_nearest_nodes(
-            self, queries: Collection[DHTID], k_nearest: Optional[int] = None, beam_size: Optional[int] = None,
-            num_workers: Optional[int] = None, node_to_peer_id: Optional[Dict[DHTID, PeerID]] = None,
-            exclude_self: bool = False, **kwargs) -> Dict[DHTID, Dict[DHTID, PeerID]]:
+        self,
+        queries: Collection[DHTID],
+        k_nearest: Optional[int] = None,
+        beam_size: Optional[int] = None,
+        num_workers: Optional[int] = None,
+        node_to_peer_id: Optional[Dict[DHTID, PeerID]] = None,
+        exclude_self: bool = False,
+        **kwargs,
+    ) -> Dict[DHTID, Dict[DHTID, PeerID]]:
         """
         :param queries: find k nearest nodes for each of these DHTIDs
         :param k_nearest: return this many nearest nodes for every query (if there are enough nodes)
@@ -254,9 +307,15 @@ class DHTNode:
             return output
 
         nearest_nodes_per_query, visited_nodes = await traverse_dht(
-            queries, initial_nodes=list(node_to_peer_id), beam_size=beam_size, num_workers=num_workers,
-            queries_per_call=int(len(queries) ** 0.5), get_neighbors=get_neighbors,
-            visited_nodes={query: {self.node_id} for query in queries}, **kwargs)
+            queries,
+            initial_nodes=list(node_to_peer_id),
+            beam_size=beam_size,
+            num_workers=num_workers,
+            queries_per_call=int(len(queries) ** 0.5),
+            get_neighbors=get_neighbors,
+            visited_nodes={query: {self.node_id} for query in queries},
+            **kwargs,
+        )
 
         nearest_nodes_with_peer_ids = {}
         for query, nearest_nodes in nearest_nodes_per_query.items():
@@ -266,8 +325,9 @@ class DHTNode:
             nearest_nodes_with_peer_ids[query] = {node: node_to_peer_id[node] for node in nearest_nodes[:k_nearest]}
         return nearest_nodes_with_peer_ids
 
-    async def store(self, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration,
-                    subkey: Optional[Subkey] = None, **kwargs) -> bool:
+    async def store(
+        self, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration, subkey: Optional[Subkey] = None, **kwargs
+    ) -> bool:
         """
         Find num_replicas best nodes to store (key, value) and store it there at least until expiration time.
         :note: store is a simplified interface to store_many, all kwargs are forwarded there
@@ -276,10 +336,16 @@ class DHTNode:
         store_ok = await self.store_many([key], [value], [expiration_time], subkeys=[subkey], **kwargs)
         return store_ok[(key, subkey) if subkey is not None else key]
 
-    async def store_many(self, keys: List[DHTKey], values: List[DHTValue],
-                         expiration_time: Union[DHTExpiration, List[DHTExpiration]],
-                         subkeys: Optional[Union[Subkey, List[Optional[Subkey]]]] = None,
-                         exclude_self: bool = False, await_all_replicas=True, **kwargs) -> Dict[DHTKey, bool]:
+    async def store_many(
+        self,
+        keys: List[DHTKey],
+        values: List[DHTValue],
+        expiration_time: Union[DHTExpiration, List[DHTExpiration]],
+        subkeys: Optional[Union[Subkey, List[Optional[Subkey]]]] = None,
+        exclude_self: bool = False,
+        await_all_replicas=True,
+        **kwargs,
+    ) -> Dict[DHTKey, bool]:
         """
         Traverse DHT to find up to :num_replicas: best nodes to store multiple (key, value, expiration_time) pairs.
 
@@ -299,8 +365,9 @@ class DHTNode:
         if subkeys is None:
             subkeys = [None] * len(keys)
 
-        assert len(keys) == len(subkeys) == len(values) == len(expiration_time), \
-            "Either of keys, values, subkeys or expiration timestamps have different sequence lengths."
+        assert (
+            len(keys) == len(subkeys) == len(values) == len(expiration_time)
+        ), "Either of keys, values, subkeys or expiration timestamps have different sequence lengths."
 
         key_id_to_data: DefaultDict[DHTID, List[Tuple[DHTKey, Subkey, DHTValue, DHTExpiration]]] = defaultdict(list)
         for key, subkey, value, expiration in zip(keys, subkeys, values, expiration_time):
@@ -313,11 +380,14 @@ class DHTNode:
         # pre-populate node_to_peer_id
         node_to_peer_id: Dict[DHTID, PeerID] = dict()
         for key_id in unfinished_key_ids:
-            node_to_peer_id.update(self.protocol.routing_table.get_nearest_neighbors(
-                key_id, self.protocol.bucket_size, exclude=self.node_id))
+            node_to_peer_id.update(
+                self.protocol.routing_table.get_nearest_neighbors(
+                    key_id, self.protocol.bucket_size, exclude=self.node_id
+                )
+            )
 
         async def on_found(key_id: DHTID, nearest_nodes: List[DHTID], visited_nodes: Set[DHTID]) -> None:
-            """ This will be called once per key when find_nearest_nodes is done for a particular node """
+            """This will be called once per key when find_nearest_nodes is done for a particular node"""
             # note: we use callbacks instead of returned values to call store immediately without waiting for stragglers
             assert key_id in unfinished_key_ids, "Internal error: traverse_dht finished the same query twice"
             assert self.node_id not in nearest_nodes
@@ -326,15 +396,15 @@ class DHTNode:
             # ensure k nodes stored the value, optionally include self.node_id as a candidate
             num_successful_stores = 0
             pending_store_tasks = set()
-            store_candidates = sorted(nearest_nodes + ([] if exclude_self else [self.node_id]),
-                                      key=key_id.xor_distance, reverse=True)  # ordered so that .pop() returns nearest
+            store_candidates = sorted(
+                nearest_nodes + ([] if exclude_self else [self.node_id]), key=key_id.xor_distance, reverse=True
+            )  # ordered so that .pop() returns nearest
             [original_key, *_], current_subkeys, current_values, current_expirations = zip(*key_id_to_data[key_id])
 
             key_bytes = key_id.to_bytes()
             binary_values = []
             stored_records = []
-            for subkey, value, expiration_time in zip(
-                    current_subkeys, current_values, current_expirations):
+            for subkey, value, expiration_time in zip(current_subkeys, current_values, current_expirations):
                 subkey_bytes = self.protocol.serializer.dumps(subkey)
                 value_bytes = self.protocol.serializer.dumps(value)
                 record = DHTRecord(key_bytes, subkey_bytes, value_bytes, expiration_time)
@@ -351,23 +421,34 @@ class DHTNode:
                     if node_id == self.node_id:
                         num_successful_stores += 1
                         for subkey, record in zip(current_subkeys, stored_records):
-                            if (self.protocol.record_validator is None or
-                                    self.protocol.record_validator.validate(record)):
+                            if self.protocol.record_validator is None or self.protocol.record_validator.validate(
+                                record
+                            ):
                                 store_ok[original_key, subkey] = self.protocol.storage.store(
-                                    key_id, record.value, record.expiration_time, subkey=subkey)
+                                    key_id, record.value, record.expiration_time, subkey=subkey
+                                )
                             else:
                                 store_ok[original_key, subkey] = False
                             if not await_all_replicas:
                                 store_finished_events[original_key, subkey].set()
                     else:
-                        pending_store_tasks.add(asyncio.create_task(self.protocol.call_store(
-                            node_to_peer_id[node_id], keys=[key_id] * len(current_values), values=binary_values,
-                            expiration_time=current_expirations, subkeys=current_subkeys)))
+                        pending_store_tasks.add(
+                            asyncio.create_task(
+                                self.protocol.call_store(
+                                    node_to_peer_id[node_id],
+                                    keys=[key_id] * len(current_values),
+                                    values=binary_values,
+                                    expiration_time=current_expirations,
+                                    subkeys=current_subkeys,
+                                )
+                            )
+                        )
 
                 # await nearest task. If it fails, dispatch more on the next iteration
                 if pending_store_tasks:
                     finished_store_tasks, pending_store_tasks = await asyncio.wait(
-                        pending_store_tasks, return_when=asyncio.FIRST_COMPLETED)
+                        pending_store_tasks, return_when=asyncio.FIRST_COMPLETED
+                    )
                     for task in finished_store_tasks:
                         if task.result() is not None:
                             num_successful_stores += 1
@@ -377,27 +458,47 @@ class DHTNode:
                                     store_finished_events[original_key, subkey].set()
 
             if self.cache_on_store:
-                self._update_cache_on_store(key_id, current_subkeys, binary_values, current_expirations,
-                                            store_ok=[store_ok[original_key, subkey] for subkey in current_subkeys])
+                self._update_cache_on_store(
+                    key_id,
+                    current_subkeys,
+                    binary_values,
+                    current_expirations,
+                    store_ok=[store_ok[original_key, subkey] for subkey in current_subkeys],
+                )
 
             for subkey, value_bytes, expiration in zip(current_subkeys, binary_values, current_expirations):
                 store_finished_events[original_key, subkey].set()
 
-        store_task = asyncio.create_task(self.find_nearest_nodes(
-            queries=set(unfinished_key_ids), k_nearest=self.num_replicas, node_to_peer_id=node_to_peer_id,
-            found_callback=on_found, exclude_self=exclude_self, **kwargs))
+        store_task = asyncio.create_task(
+            self.find_nearest_nodes(
+                queries=set(unfinished_key_ids),
+                k_nearest=self.num_replicas,
+                node_to_peer_id=node_to_peer_id,
+                found_callback=on_found,
+                exclude_self=exclude_self,
+                **kwargs,
+            )
+        )
         try:
             await asyncio.gather(store_task, *(evt.wait() for evt in store_finished_events.values()))
             assert len(unfinished_key_ids) == 0, "Internal error: traverse_dht didn't finish search"
-            return {(key, subkey) if subkey is not None else key: status or False
-                    for (key, subkey), status in store_ok.items()}
+            return {
+                (key, subkey) if subkey is not None else key: status or False
+                for (key, subkey), status in store_ok.items()
+            }
         except asyncio.CancelledError as e:
             store_task.cancel()
             raise e
 
-    def _update_cache_on_store(self, key_id: DHTID, subkeys: List[Subkey], binary_values: List[bytes],
-                               expirations: List[DHTExpiration], store_ok: List[bool]):
-        """ Update local cache after finishing a store for one key (with perhaps several subkeys) """
+    def _update_cache_on_store(
+        self,
+        key_id: DHTID,
+        subkeys: List[Subkey],
+        binary_values: List[bytes],
+        expirations: List[DHTExpiration],
+        store_ok: List[bool],
+    ):
+        """Update local cache after finishing a store for one key (with perhaps several subkeys)"""
         store_succeeded = any(store_ok)
         is_dictionary = any(subkey is not None for subkey in subkeys)
         if store_succeeded and not is_dictionary:  # stored a new regular value, cache it!
@@ -406,12 +507,14 @@ class DHTNode:
         elif not store_succeeded and not is_dictionary:  # store rejected, check if local cache is also obsolete
             rejected_expiration, rejected_value = max(zip(expirations, binary_values))
             cached_value = self.protocol.cache.get(key_id)
-            if (cached_value is not None and
-                    cached_value.expiration_time <= rejected_expiration):  # cache would be rejected
+            if (
+                cached_value is not None and cached_value.expiration_time <= rejected_expiration
+            ):  # cache would be rejected
                 self._schedule_for_refresh(key_id, refresh_time=get_dht_time())  # fetch new key in background (asap)
         elif is_dictionary and key_id in self.protocol.cache:  # there can be other keys and we should update
             for subkey, stored_value_bytes, expiration_time, accepted in zip(
-                    subkeys, binary_values, expirations, store_ok):
+                subkeys, binary_values, expirations, store_ok
+            ):
                 if accepted:
                     self.protocol.cache.store_subkey(key_id, subkey, stored_value_bytes, expiration_time)
             self._schedule_for_refresh(key_id, refresh_time=get_dht_time())  # fetch new key in background (asap)
@@ -425,13 +528,15 @@ class DHTNode:
         :returns: (value, expiration time); if value was not found, returns None
         """
         if latest:
-            kwargs["sufficient_expiration_time"] = float('inf')
+            kwargs["sufficient_expiration_time"] = float("inf")
         result = await self.get_many([key], **kwargs)
         return result[key]
 
-    async def get_many(self, keys: Collection[DHTKey], sufficient_expiration_time: Optional[DHTExpiration] = None,
-                       **kwargs) -> Dict[DHTKey, Union[Optional[ValueWithExpiration[DHTValue]],
-                                                       Awaitable[Optional[ValueWithExpiration[DHTValue]]]]]:
+    async def get_many(
+        self, keys: Collection[DHTKey], sufficient_expiration_time: Optional[DHTExpiration] = None, **kwargs
+    ) -> Dict[
+        DHTKey, Union[Optional[ValueWithExpiration[DHTValue]], Awaitable[Optional[ValueWithExpiration[DHTValue]]]]
+    ]:
         """
         Traverse DHT to find a list of keys. For each key, return latest (value, expiration) or None if not found.
 
@@ -450,10 +555,16 @@ class DHTNode:
         return {id_to_original_key[key]: result_or_future for key, result_or_future in results_by_id.items()}
 
     async def get_many_by_id(
-            self, key_ids: Collection[DHTID], sufficient_expiration_time: Optional[DHTExpiration] = None,
-            num_workers: Optional[int] = None, beam_size: Optional[int] = None, return_futures: bool = False,
-            _is_refresh=False) -> Dict[DHTID, Union[Optional[ValueWithExpiration[DHTValue]],
-                                                    Awaitable[Optional[ValueWithExpiration[DHTValue]]]]]:
+        self,
+        key_ids: Collection[DHTID],
+        sufficient_expiration_time: Optional[DHTExpiration] = None,
+        num_workers: Optional[int] = None,
+        beam_size: Optional[int] = None,
+        return_futures: bool = False,
+        _is_refresh=False,
+    ) -> Dict[
+        DHTID, Union[Optional[ValueWithExpiration[DHTValue]], Awaitable[Optional[ValueWithExpiration[DHTValue]]]]
+    ]:
         """
         Traverse DHT to find a list of DHTIDs. For each key, return latest (value, expiration) or None if not found.
 
@@ -473,10 +584,15 @@ class DHTNode:
         sufficient_expiration_time = sufficient_expiration_time or get_dht_time()
         beam_size = beam_size if beam_size is not None else self.protocol.bucket_size
         num_workers = num_workers if num_workers is not None else self.num_workers
-        search_results: Dict[DHTID, _SearchState] = {key_id: _SearchState(
-            key_id, sufficient_expiration_time,
-            serializer=self.protocol.serializer,
-            record_validator=self.protocol.record_validator) for key_id in key_ids}
+        search_results: Dict[DHTID, _SearchState] = {
+            key_id: _SearchState(
+                key_id,
+                sufficient_expiration_time,
+                serializer=self.protocol.serializer,
+                record_validator=self.protocol.record_validator,
+            )
+            for key_id in key_ids
+        }
 
         if not _is_refresh:  # if we're already refreshing cache, there's no need to trigger subsequent refreshes
             for key_id in key_ids:
@@ -498,8 +614,11 @@ class DHTNode:
         unfinished_key_ids = [key_id for key_id in key_ids if not search_results[key_id].finished]
         node_to_peer_id: Dict[DHTID, PeerID] = dict()  # global routing table for all keys
         for key_id in unfinished_key_ids:
-            node_to_peer_id.update(self.protocol.routing_table.get_nearest_neighbors(
-                key_id, self.protocol.bucket_size, exclude=self.node_id))
+            node_to_peer_id.update(
+                self.protocol.routing_table.get_nearest_neighbors(
+                    key_id, self.protocol.bucket_size, exclude=self.node_id
+                )
+            )
 
         # V-- this function will be called every time traverse_dht decides to request neighbors from a remote peer
         async def get_neighbors(peer: DHTID, queries: Collection[DHTID]) -> Dict[DHTID, Tuple[Tuple[DHTID], bool]]:
@@ -521,11 +640,19 @@ class DHTNode:
             search_results[key_id].finish_search()  # finish search whether or not we found something
             self._cache_new_result(search_results[key_id], nearest_nodes, node_to_peer_id, _is_refresh=_is_refresh)
 
-        asyncio.create_task(traverse_dht(
-            queries=list(unfinished_key_ids), initial_nodes=list(node_to_peer_id), beam_size=beam_size,
-            num_workers=num_workers, queries_per_call=min(int(len(unfinished_key_ids) ** 0.5), self.chunk_size),
-            get_neighbors=get_neighbors, visited_nodes={key_id: {self.node_id} for key_id in unfinished_key_ids},
-            found_callback=found_callback, await_all_tasks=False))
+        asyncio.create_task(
+            traverse_dht(
+                queries=list(unfinished_key_ids),
+                initial_nodes=list(node_to_peer_id),
+                beam_size=beam_size,
+                num_workers=num_workers,
+                queries_per_call=min(int(len(unfinished_key_ids) ** 0.5), self.chunk_size),
+                get_neighbors=get_neighbors,
+                visited_nodes={key_id: {self.node_id} for key_id in unfinished_key_ids},
+                found_callback=found_callback,
+                await_all_tasks=False,
+            )
+        )
 
         if return_futures:
             return {key_id: search_result.future for key_id, search_result in search_results.items()}
@@ -552,14 +679,16 @@ class DHTNode:
             pending_requests.discard(finished)
 
     async def _call_find_with_blacklist(self, peer_id: PeerID, keys: Collection[DHTID]):
-        """ same as call_find, but skip if :peer_id: is blacklisted; also exclude blacklisted neighbors from result """
+        """same as call_find, but skip if :peer_id: is blacklisted; also exclude blacklisted neighbors from result"""
         if peer_id in self.blacklist:
             return None
         response = await self.protocol.call_find(peer_id, keys)
         if response:
             self.blacklist.register_success(peer_id)
-            return {key: (maybe_value, self._filter_blacklisted(nearest_peers))
-                    for key, (maybe_value, nearest_peers) in response.items()}
+            return {
+                key: (maybe_value, self._filter_blacklisted(nearest_peers))
+                for key, (maybe_value, nearest_peers) in response.items()
+            }
         else:
             self.blacklist.register_failure(peer_id)
             return None
@@ -568,13 +697,13 @@ class DHTNode:
         return {peer: peer_id for peer, peer_id in peer_ids.items() if peer_id not in self.blacklist}
 
     def _trigger_cache_refresh(self, search: _SearchState):
-        """ Called after get request is finished (whether it was found, not found, hit cache, cancelled, or reused) """
+        """Called after get request is finished (whether it was found, not found, hit cache, cancelled, or reused)"""
         if search.found_something and search.source_node_id == self.node_id:
             if self.cache_refresh_before_expiry and search.key_id in self.protocol.cache:
                 self._schedule_for_refresh(search.key_id, search.expiration_time - self.cache_refresh_before_expiry)
 
     def _schedule_for_refresh(self, key_id: DHTID, refresh_time: DHTExpiration):
-        """ Add key to a refresh queue, refresh at :refresh_time: or later """
+        """Add key to a refresh queue, refresh at :refresh_time: or later"""
         if self.cache_refresh_task is None or self.cache_refresh_task.done() or self.cache_refresh_task.cancelled():
             self.cache_refresh_task = asyncio.create_task(self._refresh_stale_cache_entries())
             logger.debug("Spawned cache refresh task.")
@@ -584,7 +713,7 @@ class DHTNode:
         self.cache_refresh_queue.store(key_id, value=refresh_time, expiration_time=refresh_time)
 
     async def _refresh_stale_cache_entries(self):
-        """ periodically refresh keys near-expired keys that were accessed at least once during previous lifetime """
+        """periodically refresh keys near-expired keys that were accessed at least once during previous lifetime"""
         while self.is_alive:
             while len(self.cache_refresh_queue) == 0:
                 await self.cache_refresh_evt.wait()
@@ -619,12 +748,17 @@ class DHTNode:
                 sufficient_expiration_time = max_expiration_time + self.cache_refresh_before_expiry + 1
                 await self.get_many_by_id(keys_to_refresh, sufficient_expiration_time, _is_refresh=True)
 
-    def _cache_new_result(self, search: _SearchState, nearest_nodes: List[DHTID],
-                          node_to_peer_id: Dict[DHTID, PeerID], _is_refresh: bool = False):
-        """ after key_id is found, update cache according to caching policy. used internally in get and get_many """
+    def _cache_new_result(
+        self,
+        search: _SearchState,
+        nearest_nodes: List[DHTID],
+        node_to_peer_id: Dict[DHTID, PeerID],
+        _is_refresh: bool = False,
+    ):
+        """after key_id is found, update cache according to caching policy. used internally in get and get_many"""
         if search.found_something:
-            _, storage_expiration_time = self.protocol.storage.get(search.key_id) or (None, -float('inf'))
-            _, cache_expiration_time = self.protocol.cache.get(search.key_id) or (None, -float('inf'))
+            _, storage_expiration_time = self.protocol.storage.get(search.key_id) or (None, -float("inf"))
+            _, cache_expiration_time = self.protocol.cache.get(search.key_id) or (None, -float("inf"))
 
             if search.expiration_time > max(storage_expiration_time, cache_expiration_time):
                 if self.cache_locally or _is_refresh:
@@ -634,20 +768,27 @@ class DHTNode:
                     for node_id in nearest_nodes:
                         if node_id == search.source_node_id:
                             continue
-                        asyncio.create_task(self.protocol.call_store(
-                            node_to_peer_id[node_id], [search.key_id], [search.binary_value], [search.expiration_time],
-                            in_cache=True))
+                        asyncio.create_task(
+                            self.protocol.call_store(
+                                node_to_peer_id[node_id],
+                                [search.key_id],
+                                [search.binary_value],
+                                [search.expiration_time],
+                                in_cache=True,
+                            )
+                        )
                         num_cached_nodes += 1
                         if num_cached_nodes >= self.cache_nearest:
                             break
 
     async def _refresh_routing_table(self, *, period: Optional[float]) -> None:
-        """ Tries to find new nodes for buckets that were unused for more than self.staleness_timeout """
+        """Tries to find new nodes for buckets that were unused for more than self.staleness_timeout"""
         while self.is_alive and period is not None:  # if None run once, otherwise run forever
             refresh_time = get_dht_time()
             staleness_threshold = refresh_time - period
-            stale_buckets = [bucket for bucket in self.protocol.routing_table.buckets
-                             if bucket.last_updated < staleness_threshold]
+            stale_buckets = [
+                bucket for bucket in self.protocol.routing_table.buckets if bucket.last_updated < staleness_threshold
+            ]
             for bucket in stale_buckets:
                 refresh_id = DHTID(random.randint(bucket.lower, bucket.upper - 1))
                 await self.find_nearest_nodes(refresh_id)
@@ -660,7 +801,8 @@ class DHTNode:
 
 @dataclass(init=True, repr=True, frozen=False, order=False)
 class _SearchState:
-    """ A helper class that stores current-best GET results with metadata """
+    """A helper class that stores current-best GET results with metadata"""
+
     key_id: DHTID
     sufficient_expiration_time: DHTExpiration
     binary_value: Optional[Union[BinaryDHTValue, DictionaryDHTValue]] = None
@@ -670,25 +812,28 @@ class _SearchState:
     serializer: Type[SerializerBase] = MSGPackSerializer
     record_validator: Optional[RecordValidatorBase] = None
 
-    def add_candidate(self, candidate: Optional[ValueWithExpiration[Union[BinaryDHTValue, DictionaryDHTValue]]],
-                      source_node_id: Optional[DHTID]):
+    def add_candidate(
+        self,
+        candidate: Optional[ValueWithExpiration[Union[BinaryDHTValue, DictionaryDHTValue]]],
+        source_node_id: Optional[DHTID],
+    ):
         if self.finished or candidate is None:
             return
         elif isinstance(candidate.value, DictionaryDHTValue) and isinstance(self.binary_value, DictionaryDHTValue):
             self.binary_value.maxsize = max(self.binary_value.maxsize, candidate.value.maxsize)
             for subkey, subentry in candidate.value.items():
                 self.binary_value.store(subkey, subentry.value, subentry.expiration_time)
-        elif candidate.expiration_time > (self.expiration_time or float('-inf')):
+        elif candidate.expiration_time > (self.expiration_time or float("-inf")):
             self.binary_value = candidate.value
 
-        if candidate.expiration_time > (self.expiration_time or float('-inf')):
+        if candidate.expiration_time > (self.expiration_time or float("-inf")):
             self.expiration_time = candidate.expiration_time
             self.source_node_id = source_node_id
             if self.expiration_time >= self.sufficient_expiration_time:
                 self.finish_search()
 
     def add_done_callback(self, callback: Callable[[_SearchState], Any]):
-        """ Add callback that will be called when _SearchState is done (found OR cancelled by user) """
+        """Add callback that will be called when _SearchState is done (found OR cancelled by user)"""
         self.future.add_done_callback(lambda _future: callback(self))
 
     def finish_search(self):
@@ -699,30 +844,30 @@ class _SearchState:
         elif isinstance(self.binary_value, BinaryDHTValue):
             value_bytes = self.binary_value
             if self.record_validator is not None:
-                record = DHTRecord(self.key_id.to_bytes(), DHTProtocol.IS_REGULAR_VALUE,
-                                   value_bytes, self.expiration_time)
+                record = DHTRecord(
+                    self.key_id.to_bytes(), DHTProtocol.IS_REGULAR_VALUE, value_bytes, self.expiration_time
+                )
                 value_bytes = self.record_validator.strip_value(record)
 
-            self.future.set_result(
-                ValueWithExpiration(self.serializer.loads(value_bytes), self.expiration_time))
+            self.future.set_result(ValueWithExpiration(self.serializer.loads(value_bytes), self.expiration_time))
         elif isinstance(self.binary_value, DictionaryDHTValue):
             dict_with_subkeys = {}
             for subkey, (value_bytes, item_expiration_time) in self.binary_value.items():
                 if self.record_validator is not None:
                     subkey_bytes = self.serializer.dumps(subkey)
-                    record = DHTRecord(self.key_id.to_bytes(), subkey_bytes,
-                                       value_bytes, item_expiration_time)
+                    record = DHTRecord(self.key_id.to_bytes(), subkey_bytes, value_bytes, item_expiration_time)
                     value_bytes = self.record_validator.strip_value(record)
 
                 dict_with_subkeys[subkey] = ValueWithExpiration(
-                    self.serializer.loads(value_bytes), item_expiration_time)
+                    self.serializer.loads(value_bytes), item_expiration_time
+                )
             self.future.set_result(ValueWithExpiration(dict_with_subkeys, self.expiration_time))
         else:
             logger.error(f"Invalid value type: {type(self.binary_value)}")
 
     @property
     def found_something(self) -> bool:
-        """ Whether or not we have found at least some value, regardless of its expiration time """
+        """Whether or not we have found at least some value, regardless of its expiration time"""
         return self.expiration_time is not None
 
     @property
@@ -730,7 +875,7 @@ class _SearchState:
         return self.future.done()
 
     def __lt__(self, other: _SearchState):
-        """ _SearchState instances will be sorted by their target expiration time """
+        """_SearchState instances will be sorted by their target expiration time"""
         return self.sufficient_expiration_time < other.sufficient_expiration_time
 
     def __hash__(self):
@@ -750,22 +895,24 @@ class Blacklist:
         self.ban_counter = Counter()
 
     def register_failure(self, peer: PeerID):
-        """ peer failed to respond, add him to blacklist or increase his downtime """
+        """peer failed to respond, add him to blacklist or increase his downtime"""
         if peer not in self.banned_peers and self.base_time > 0:
             ban_duration = self.base_time * self.backoff ** self.ban_counter[peer]
             self.banned_peers.store(peer, self.ban_counter[peer], expiration_time=get_dht_time() + ban_duration)
             self.ban_counter[peer] += 1
 
     def register_success(self, peer):
-        """ peer responded successfully, remove him from blacklist and reset his ban time """
+        """peer responded successfully, remove him from blacklist and reset his ban time"""
         del self.banned_peers[peer], self.ban_counter[peer]
 
     def __contains__(self, peer: PeerID) -> bool:
         return peer in self.banned_peers
 
     def __repr__(self):
-        return f"{self.__class__.__name__}(base_time={self.base_time}, backoff={self.backoff}, " \
-               f"banned_peers={len(self.banned_peers)})"
+        return (
+            f"{self.__class__.__name__}(base_time={self.base_time}, backoff={self.backoff}, "
+            f"banned_peers={len(self.banned_peers)})"
+        )
 
     def clear(self):
         self.banned_peers.clear()
@@ -773,5 +920,6 @@ class Blacklist:
 
 
 class CacheRefreshQueue(TimedStorage[DHTID, DHTExpiration]):
-    """ a queue of keys scheduled for refresh in future, used in DHTNode """
+    """a queue of keys scheduled for refresh in future, used in DHTNode"""
+
     frozen = True
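
One behavioural detail that the reformatting leaves untouched: Blacklist.register_failure bans an unresponsive peer for base_time * backoff ** n seconds after its n-th consecutive failure, and register_success resets the counter. With DHTNode.create's defaults (blacklist_time=5.0, backoff_rate=2.0) the ban durations grow as in this small sketch:

# mirrors Blacklist.register_failure: ban duration after the n-th consecutive failure
base_time, backoff = 5.0, 2.0  # DHTNode.create defaults: blacklist_time, backoff_rate
for n in range(5):
    print(f"failure #{n + 1}: banned for {base_time * backoff ** n:.0f} s")  # 5, 10, 20, 40, 80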

+ 105 - 49
hivemind/dht/protocol.py

@@ -11,8 +11,12 @@ from hivemind.p2p import P2P, P2PContext, PeerID, Servicer
 from hivemind.proto import dht_pb2
 from hivemind.utils import get_logger, MSGPackSerializer
 from hivemind.utils.auth import AuthRole, AuthRPCWrapper, AuthorizerBase
-from hivemind.utils.timed_storage import DHTExpiration, get_dht_time, MAX_DHT_TIME_DISCREPANCY_SECONDS, \
-    ValueWithExpiration
+from hivemind.utils.timed_storage import (
+    DHTExpiration,
+    get_dht_time,
+    MAX_DHT_TIME_DISCREPANCY_SECONDS,
+    ValueWithExpiration,
+)
 
 logger = get_logger(__name__)
 
@@ -26,14 +30,23 @@ class DHTProtocol(Servicer):
     # fmt:on
 
     serializer = MSGPackSerializer  # used to pack/unpack DHT Values for transfer over network
-    RESERVED_SUBKEYS = IS_REGULAR_VALUE, IS_DICTIONARY = serializer.dumps(None), b''
+    RESERVED_SUBKEYS = IS_REGULAR_VALUE, IS_DICTIONARY = serializer.dumps(None), b""
 
     @classmethod
     async def create(
-            cls, p2p: P2P, node_id: DHTID, bucket_size: int, depth_modulo: int, num_replicas: int, wait_timeout: float,
-            parallel_rpc: Optional[int] = None, cache_size: Optional[int] = None,
-            listen=True, record_validator: Optional[RecordValidatorBase] = None,
-            authorizer: Optional[AuthorizerBase] = None) -> DHTProtocol:
+        cls,
+        p2p: P2P,
+        node_id: DHTID,
+        bucket_size: int,
+        depth_modulo: int,
+        num_replicas: int,
+        wait_timeout: float,
+        parallel_rpc: Optional[int] = None,
+        cache_size: Optional[int] = None,
+        listen=True,
+        record_validator: Optional[RecordValidatorBase] = None,
+        authorizer: Optional[AuthorizerBase] = None,
+    ) -> DHTProtocol:
         """
         A protocol that allows DHT nodes to request keys/neighbors from other DHT nodes.
         As a side-effect, DHTProtocol also maintains a routing table as described in
@@ -52,7 +65,7 @@ class DHTProtocol(Servicer):
         self.wait_timeout = wait_timeout
         self.storage, self.cache = DHTLocalStorage(), DHTLocalStorage(maxsize=cache_size)
         self.routing_table = RoutingTable(node_id, bucket_size, depth_modulo)
-        self.rpc_semaphore = asyncio.Semaphore(parallel_rpc if parallel_rpc is not None else float('inf'))
+        self.rpc_semaphore = asyncio.Semaphore(parallel_rpc if parallel_rpc is not None else float("inf"))
         self.listen = listen
         self.record_validator = record_validator
         self.authorizer = authorizer
@@ -67,12 +80,12 @@ class DHTProtocol(Servicer):
         return self
 
     def __init__(self, *, _initialized_with_create=False):
-        """ Internal init method. Please use DHTProtocol.create coroutine to spawn new protocol instances """
+        """Internal init method. Please use DHTProtocol.create coroutine to spawn new protocol instances"""
         assert _initialized_with_create, " Please use DHTProtocol.create coroutine to spawn new protocol instances "
         super().__init__()
 
     def get_stub(self, peer: PeerID) -> AuthRPCWrapper:
-        """ get a stub that sends requests to a given peer """
+        """get a stub that sends requests to a given peer"""
         stub = super().get_stub(self.p2p, peer)
         return AuthRPCWrapper(stub, AuthRole.CLIENT, self.authorizer, service_public_key=None)
 
@@ -100,14 +113,19 @@ class DHTProtocol(Servicer):
         if responded and validate:
             try:
                 if self.listen and not response.available:
-                    raise ValidationError(f"Peer {peer} can't access this node. "
-                                          f"Probably, libp2p has failed to bypass the firewall")
+                    raise ValidationError(
+                        f"Peer {peer} can't access this node. " f"Probably, libp2p has failed to bypass the firewall"
+                    )
 
                 if response.dht_time != dht_pb2.PingResponse.dht_time.DESCRIPTOR.default_value:
-                    if response.dht_time < time_requested - MAX_DHT_TIME_DISCREPANCY_SECONDS or \
-                            response.dht_time > time_responded + MAX_DHT_TIME_DISCREPANCY_SECONDS:
-                        raise ValidationError(f"local time must be within {MAX_DHT_TIME_DISCREPANCY_SECONDS} seconds "
-                                              f" of others(local: {time_requested:.5f}, peer: {response.dht_time:.5f})")
+                    if (
+                        response.dht_time < time_requested - MAX_DHT_TIME_DISCREPANCY_SECONDS
+                        or response.dht_time > time_responded + MAX_DHT_TIME_DISCREPANCY_SECONDS
+                    ):
+                        raise ValidationError(
+                            f"local time must be within {MAX_DHT_TIME_DISCREPANCY_SECONDS} seconds "
+                            f"of others (local: {time_requested:.5f}, peer: {response.dht_time:.5f})"
+                        )
             except ValidationError as e:
                 if strict:
                     raise
@@ -119,10 +137,9 @@ class DHTProtocol(Servicer):
         return peer_id
 
     async def rpc_ping(self, request: dht_pb2.PingRequest, context: P2PContext) -> dht_pb2.PingResponse:
-        """ Some node wants us to add it to our routing table. """
+        """Some node wants us to add it to our routing table."""
 
-        response = dht_pb2.PingResponse(peer=self.node_info,
-                                        dht_time=get_dht_time(), available=False)
+        response = dht_pb2.PingResponse(peer=self.node_info, dht_time=get_dht_time(), available=False)
 
         if request.peer and request.peer.node_id:
             sender_id = DHTID.from_bytes(request.peer.node_id)
@@ -131,16 +148,23 @@ class DHTProtocol(Servicer):
             if request.validate:
                 response.available = await self.call_ping(sender_peer_id, validate=False) == sender_id
 
-            asyncio.create_task(self.update_routing_table(sender_id, sender_peer_id,
-                                                          responded=response.available or not request.validate))
+            asyncio.create_task(
+                self.update_routing_table(
+                    sender_id, sender_peer_id, responded=response.available or not request.validate
+                )
+            )
 
         return response
 
-    async def call_store(self, peer: PeerID, keys: Sequence[DHTID],
-                         values: Sequence[Union[BinaryDHTValue, DictionaryDHTValue]],
-                         expiration_time: Union[DHTExpiration, Sequence[DHTExpiration]],
-                         subkeys: Optional[Union[Subkey, Sequence[Optional[Subkey]]]] = None,
-                         in_cache: Optional[Union[bool, Sequence[bool]]] = None) -> Optional[List[bool]]:
+    async def call_store(
+        self,
+        peer: PeerID,
+        keys: Sequence[DHTID],
+        values: Sequence[Union[BinaryDHTValue, DictionaryDHTValue]],
+        expiration_time: Union[DHTExpiration, Sequence[DHTExpiration]],
+        subkeys: Optional[Union[Subkey, Sequence[Optional[Subkey]]]] = None,
+        in_cache: Optional[Union[bool, Sequence[bool]]] = None,
+    ) -> Optional[List[bool]]:
         """
         Ask a recipient to store several (key, value : expiration_time) items or update their older value
 
@@ -166,19 +190,29 @@ class DHTProtocol(Servicer):
 
         in_cache = in_cache if in_cache is not None else [False] * len(keys)  # default value (None)
         in_cache = [in_cache] * len(keys) if isinstance(in_cache, bool) else in_cache  # single bool
-        keys, subkeys, values, expiration_time, in_cache = map(list, [keys, subkeys, values, expiration_time, in_cache])
+        keys, subkeys, values, expiration_time, in_cache = map(
+            list, [keys, subkeys, values, expiration_time, in_cache]
+        )
         for i in range(len(keys)):
             if subkeys[i] is None:  # add default sub-key if not specified
                 subkeys[i] = self.IS_DICTIONARY if isinstance(values[i], DictionaryDHTValue) else self.IS_REGULAR_VALUE
             else:
                 subkeys[i] = self.serializer.dumps(subkeys[i])
             if isinstance(values[i], DictionaryDHTValue):
-                assert subkeys[i] == self.IS_DICTIONARY, "Please don't specify subkey when storing an entire dictionary"
+                assert (
+                    subkeys[i] == self.IS_DICTIONARY
+                ), "Please don't specify subkey when storing an entire dictionary"
                 values[i] = self.serializer.dumps(values[i])
 
         assert len(keys) == len(values) == len(expiration_time) == len(in_cache), "Data is not aligned"
-        store_request = dht_pb2.StoreRequest(keys=list(map(DHTID.to_bytes, keys)), subkeys=subkeys, values=values,
-                                             expiration_time=expiration_time, in_cache=in_cache, peer=self.node_info)
+        store_request = dht_pb2.StoreRequest(
+            keys=list(map(DHTID.to_bytes, keys)),
+            subkeys=subkeys,
+            values=values,
+            expiration_time=expiration_time,
+            in_cache=in_cache,
+            peer=self.node_info,
+        )
         try:
             async with self.rpc_semaphore:
                 response = await self.get_stub(peer).rpc_store(store_request, timeout=self.wait_timeout)
@@ -192,13 +226,14 @@ class DHTProtocol(Servicer):
             return None
 
     async def rpc_store(self, request: dht_pb2.StoreRequest, context: P2PContext) -> dht_pb2.StoreResponse:
-        """ Some node wants us to store this (key, value) pair """
+        """Some node wants us to store this (key, value) pair"""
         if request.peer:  # if requested, add peer to the routing table
             asyncio.create_task(self.rpc_ping(dht_pb2.PingRequest(peer=request.peer), context))
         assert len(request.keys) == len(request.values) == len(request.expiration_time) == len(request.in_cache)
         response = dht_pb2.StoreResponse(store_ok=[], peer=self.node_info)
         for key, tag, value_bytes, expiration_time, in_cache in zip(
-                request.keys, request.subkeys, request.values, request.expiration_time, request.in_cache):
+            request.keys, request.subkeys, request.values, request.expiration_time, request.in_cache
+        ):
             key_id = DHTID.from_bytes(key)
             storage = self.cache if in_cache else self.storage
 
@@ -209,8 +244,12 @@ class DHTProtocol(Servicer):
                     response.store_ok.append(False)
                     continue
 
-                response.store_ok.append(all(storage.store_subkey(key_id, subkey, item.value, item.expiration_time)
-                                             for subkey, item in value_dictionary.items()))
+                response.store_ok.append(
+                    all(
+                        storage.store_subkey(key_id, subkey, item.value, item.expiration_time)
+                        for subkey, item in value_dictionary.items()
+                    )
+                )
                 continue
 
             if not self._validate_record(key, tag, value_bytes, expiration_time):
@@ -224,8 +263,13 @@ class DHTProtocol(Servicer):
                 response.store_ok.append(storage.store_subkey(key_id, subkey, value_bytes, expiration_time))
         return response
 
-    async def call_find(self, peer: PeerID, keys: Collection[DHTID]) -> Optional[Dict[
-        DHTID, Tuple[Optional[ValueWithExpiration[Union[BinaryDHTValue, DictionaryDHTValue]]], Dict[DHTID, PeerID]]]]:
+    async def call_find(
+        self, peer: PeerID, keys: Collection[DHTID]
+    ) -> Optional[
+        Dict[
+            DHTID, Tuple[Optional[ValueWithExpiration[Union[BinaryDHTValue, DictionaryDHTValue]]], Dict[DHTID, PeerID]]
+        ]
+    ]:
         """
         Request keys from a peer. For each key, look for its (value, expiration time) locally and
          k additional peers that are most likely to have this key (ranked by XOR distance)
@@ -249,14 +293,19 @@ class DHTProtocol(Servicer):
             output = {}  # unpack data depending on its type
             for key, result in zip(keys, response.results):
                 key_bytes = DHTID.to_bytes(key)
-                nearest = dict(zip(map(DHTID.from_bytes, result.nearest_node_ids),
-                                   map(PeerID.from_base58, result.nearest_peer_ids)))
+                nearest = dict(
+                    zip(
+                        map(DHTID.from_bytes, result.nearest_node_ids),
+                        map(PeerID.from_base58, result.nearest_peer_ids),
+                    )
+                )
 
                 if result.type == dht_pb2.NOT_FOUND:
                     output[key] = None, nearest
                 elif result.type == dht_pb2.FOUND_REGULAR:
                     if not self._validate_record(
-                            key_bytes, self.IS_REGULAR_VALUE, result.value, result.expiration_time):
+                        key_bytes, self.IS_REGULAR_VALUE, result.value, result.expiration_time
+                    ):
                         output[key] = None, nearest
                         continue
 
@@ -288,21 +337,27 @@ class DHTProtocol(Servicer):
         for i, key_id in enumerate(map(DHTID.from_bytes, request.keys)):
             maybe_item = self.storage.get(key_id)
             cached_item = self.cache.get(key_id)
-            if cached_item is not None and (maybe_item is None
-                                            or cached_item.expiration_time > maybe_item.expiration_time):
+            if cached_item is not None and (
+                maybe_item is None or cached_item.expiration_time > maybe_item.expiration_time
+            ):
                 maybe_item = cached_item
 
             if maybe_item is None:  # value not found
                 item = dht_pb2.FindResult(type=dht_pb2.NOT_FOUND)
             elif isinstance(maybe_item.value, DictionaryDHTValue):
-                item = dht_pb2.FindResult(type=dht_pb2.FOUND_DICTIONARY, value=self.serializer.dumps(maybe_item.value),
-                                          expiration_time=maybe_item.expiration_time)
+                item = dht_pb2.FindResult(
+                    type=dht_pb2.FOUND_DICTIONARY,
+                    value=self.serializer.dumps(maybe_item.value),
+                    expiration_time=maybe_item.expiration_time,
+                )
             else:  # found regular value
-                item = dht_pb2.FindResult(type=dht_pb2.FOUND_REGULAR, value=maybe_item.value,
-                                          expiration_time=maybe_item.expiration_time)
+                item = dht_pb2.FindResult(
+                    type=dht_pb2.FOUND_REGULAR, value=maybe_item.value, expiration_time=maybe_item.expiration_time
+                )
 
             for node_id, peer_id in self.routing_table.get_nearest_neighbors(
-                    key_id, k=self.bucket_size, exclude=DHTID.from_bytes(request.peer.node_id)):
+                key_id, k=self.bucket_size, exclude=DHTID.from_bytes(request.peer.node_id)
+            ):
                 item.nearest_node_ids.append(node_id.to_bytes())
                 item.nearest_peer_ids.append(peer_id.to_base58())
             response.results.append(item)
@@ -344,8 +399,9 @@ class DHTProtocol(Servicer):
             if node_id is not None and node_id in self.routing_table:
                 del self.routing_table[node_id]
 
-    def _validate_record(self, key_bytes: bytes, subkey_bytes: bytes, value_bytes: bytes,
-                         expiration_time: float) -> bool:
+    def _validate_record(
+        self, key_bytes: bytes, subkey_bytes: bytes, value_bytes: bytes, expiration_time: float
+    ) -> bool:
         if self.record_validator is None:
             return True
 
@@ -366,4 +422,4 @@ class DHTProtocol(Servicer):
 
 
 class ValidationError(Exception):
-    """ This exception is thrown if DHT node didn't pass validation by other nodes. """
+    """This exception is thrown if DHT node didn't pass validation by other nodes."""

+ 31 - 22
hivemind/dht/routing.py

@@ -11,7 +11,8 @@ from typing import Tuple, Optional, List, Dict, Set, Union, Any, Sequence
 from hivemind.p2p import PeerID
 from hivemind.utils import MSGPackSerializer, get_dht_time
 
-DHTKey, Subkey, DHTValue, BinaryDHTID, BinaryDHTValue, = Any, Any, Any, bytes, bytes
+DHTKey = Subkey = DHTValue = Any
+BinaryDHTID = BinaryDHTValue = bytes
 
 
 class RoutingTable:
@@ -32,7 +33,7 @@ class RoutingTable:
         self.uid_to_peer_id: Dict[DHTID, PeerID] = dict()  # all nodes currently in buckets, including replacements
 
     def get_bucket_index(self, node_id: DHTID) -> int:
-        """ Get the index of the bucket that the given node would fall into. """
+        """Get the index of the bucket that the given node would fall into."""
         lower_index, upper_index = 0, len(self.buckets)
         while upper_index - lower_index > 1:
             pivot_index = (lower_index + upper_index + 1) // 2
@@ -72,13 +73,13 @@ class RoutingTable:
             return bucket.request_ping_node()
 
     def split_bucket(self, index: int) -> None:
-        """ Split bucket range in two equal parts and reassign nodes to the appropriate half """
+        """Split bucket range in two equal parts and reassign nodes to the appropriate half"""
         first, second = self.buckets[index].split()
         self.buckets[index] = first
         self.buckets.insert(index + 1, second)
 
     def get(self, *, node_id: Optional[DHTID] = None, peer_id: Optional[PeerID] = None, default=None):
-        """ Find peer_id for a given DHTID or vice versa """
+        """Find peer_id for a given DHTID or vice versa"""
         assert (node_id is None) != (peer_id is None), "Please specify either node_id or peer_id, but not both"
         if node_id is not None:
             return self.uid_to_peer_id.get(node_id, default)
@@ -86,11 +87,13 @@ class RoutingTable:
             return self.peer_id_to_uid.get(peer_id, default)
 
     def __getitem__(self, item: Union[DHTID, PeerID]) -> Union[PeerID, DHTID]:
-        """ Find peer_id for a given DHTID or vice versa """
+        """Find peer_id for a given DHTID or vice versa"""
         return self.uid_to_peer_id[item] if isinstance(item, DHTID) else self.peer_id_to_uid[item]
 
     def __setitem__(self, node_id: DHTID, peer_id: PeerID) -> NotImplementedError:
-        raise NotImplementedError("RoutingTable doesn't support direct item assignment. Use table.try_add_node instead")
+        raise NotImplementedError(
+            "RoutingTable doesn't support direct item assignment. Use table.try_add_node instead"
+        )
 
     def __contains__(self, item: Union[DHTID, PeerID]) -> bool:
         return (item in self.uid_to_peer_id) if isinstance(item, DHTID) else (item in self.peer_id_to_uid)
@@ -102,7 +105,8 @@ class RoutingTable:
             del self.peer_id_to_uid[node_peer_id]
 
     def get_nearest_neighbors(
-            self, query_id: DHTID, k: int, exclude: Optional[DHTID] = None) -> List[Tuple[DHTID, PeerID]]:
+        self, query_id: DHTID, k: int, exclude: Optional[DHTID] = None
+    ) -> List[Tuple[DHTID, PeerID]]:
         """
         Find k nearest neighbors from routing table according to XOR distance, does NOT include self.node_id
 
@@ -134,7 +138,8 @@ class RoutingTable:
                 while right_index < len(self.buckets) and self.buckets[right_index].upper <= current_upper:
                     for node_id, peer_id in self.buckets[right_index].nodes_to_peer_id.items():
                         heapq.heappush(candidates, (query_id.xor_distance(node_id), node_id, peer_id))
-                    right_index += 1  # note: we may need to add more than one bucket if they are on a lower depth level
+                    right_index += 1
+                    # note: we may need to add more than one bucket if they are on a lower depth level
                 assert self.buckets[right_index - 1].upper == current_upper
 
             else:  # split_direction == 1, leaf was split on the right, merge its left peer(s)
@@ -151,8 +156,10 @@ class RoutingTable:
 
     def __repr__(self):
         bucket_info = "\n".join(repr(bucket) for bucket in self.buckets)
-        return f"{self.__class__.__name__}(node_id={self.node_id}, bucket_size={self.bucket_size}," \
-               f" modulo={self.depth_modulo},\nbuckets=[\n{bucket_info})"
+        return (
+            f"{self.__class__.__name__}(node_id={self.node_id}, bucket_size={self.bucket_size},"
+            f" modulo={self.depth_modulo},\nbuckets=[\n{bucket_info})"
+        )
 
 
 class KBucket:
@@ -170,7 +177,7 @@ class KBucket:
         self.last_updated = get_dht_time()
 
     def has_in_range(self, node_id: DHTID):
-        """ Check if node_id is between this bucket's lower and upper bounds """
+        """Check if node_id is between this bucket's lower and upper bounds"""
         return self.lower <= node_id < self.upper
 
     def add_or_update_node(self, node_id: DHTID, peer_id: PeerID) -> bool:
@@ -198,7 +205,7 @@ class KBucket:
         return True
 
     def request_ping_node(self) -> Optional[Tuple[DHTID, PeerID]]:
-        """ :returns: least-recently updated node that isn't already being pinged right now -- if such node exists """
+        """:returns: least-recently updated node that isn't already being pinged right now -- if such node exists"""
         for uid, peer_id in self.nodes_to_peer_id.items():
             if uid not in self.nodes_requested_for_ping:
                 self.nodes_requested_for_ping.add(uid)
@@ -222,7 +229,7 @@ class KBucket:
                 self.nodes_to_peer_id[newnode_id] = newnode
 
     def split(self) -> Tuple[KBucket, KBucket]:
-        """ Split bucket over midpoint, rounded down, assign nodes to according to their id """
+        """Split bucket over midpoint, rounded down, assign nodes to according to their id"""
         midpoint = (self.lower + self.upper) // 2
         assert self.lower < midpoint < self.upper, f"Bucket too small to be split: [{self.lower}: {self.upper})"
         left = KBucket(self.lower, midpoint, self.size, depth=self.depth + 1)
@@ -233,9 +240,11 @@ class KBucket:
         return left, right
 
     def __repr__(self):
-        return f"{self.__class__.__name__}({len(self.nodes_to_peer_id)} nodes" \
-               f" with {len(self.replacement_nodes)} replacements, depth={self.depth}, max size={self.size}" \
-               f" lower={hex(self.lower)}, upper={hex(self.upper)})"
+        return (
+            f"{self.__class__.__name__}({len(self.nodes_to_peer_id)} nodes"
+            f" with {len(self.replacement_nodes)} replacements, depth={self.depth}, max size={self.size}"
+            f" lower={hex(self.lower)}, upper={hex(self.upper)})"
+        )
 
 
 class DHTID(int):
@@ -255,7 +264,7 @@ class DHTID(int):
         :param source: if provided, converts this value to bytes and uses it as input for hashing function;
             by default, generates a random dhtid from :nbits: random bits
         """
-        source = random.getrandbits(nbits).to_bytes(nbits, byteorder='big') if source is None else source
+        source = random.getrandbits(nbits).to_bytes(nbits, byteorder="big") if source is None else source
         source = MSGPackSerializer.dumps(source) if not isinstance(source, bytes) else source
         raw_uid = cls.HASH_FUNC(source).digest()
         return cls(int(raw_uid.hex(), 16))
@@ -272,16 +281,16 @@ class DHTID(int):
 
     @classmethod
     def longest_common_prefix_length(cls, *ids: DHTID) -> int:
-        ids_bits = [bin(uid)[2:].rjust(8 * cls.HASH_NBYTES, '0') for uid in ids]
+        ids_bits = [bin(uid)[2:].rjust(8 * cls.HASH_NBYTES, "0") for uid in ids]
         return len(os.path.commonprefix(ids_bits))
 
-    def to_bytes(self, length=HASH_NBYTES, byteorder='big', *, signed=False) -> bytes:
-        """ A standard way to serialize DHTID into bytes """
+    def to_bytes(self, length=HASH_NBYTES, byteorder="big", *, signed=False) -> bytes:
+        """A standard way to serialize DHTID into bytes"""
         return super().to_bytes(length, byteorder, signed=signed)
 
     @classmethod
-    def from_bytes(cls, raw: bytes, byteorder='big', *, signed=False) -> DHTID:
-        """ reverse of to_bytes """
+    def from_bytes(cls, raw: bytes, byteorder="big", *, signed=False) -> DHTID:
+        """reverse of to_bytes"""
         return DHTID(super().from_bytes(raw, byteorder=byteorder, signed=signed))
 
     def __repr__(self):
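
As a quick illustration of the XOR-distance bookkeeping that RoutingTable relies on, a short sketch using only the helpers visible in this diff (the bucket_size and depth_modulo values are arbitrary):

    from hivemind.dht.routing import DHTID, RoutingTable

    node_id = DHTID.generate()  # random node identifier
    table = RoutingTable(node_id, bucket_size=20, depth_modulo=5)
    query = DHTID.generate(source=b"some key")  # deterministic: derived from the given bytes
    print(query.xor_distance(node_id))   # the metric used to rank neighbors
    print(table.get_bucket_index(query))  # which k-bucket this id falls into
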

+ 20 - 15
hivemind/dht/schema.py

@@ -18,8 +18,7 @@ class SchemaValidator(RecordValidatorBase):
     This allows to enforce types, min/max values, require a subkey to contain a public key, etc.
     """
 
-    def __init__(self, schema: pydantic.BaseModel, *,
-                 allow_extra_keys: bool=True, prefix: Optional[str]=None):
+    def __init__(self, schema: pydantic.BaseModel, *, allow_extra_keys: bool = True, prefix: Optional[str] = None):
         """
         :param schema: The Pydantic model (a subclass of pydantic.BaseModel).
 
@@ -43,7 +42,7 @@ class SchemaValidator(RecordValidatorBase):
 
         self._key_id_to_field_name = {}
         for field in schema.__fields__.values():
-            raw_key = f'{prefix}_{field.name}' if prefix is not None else field.name
+            raw_key = f"{prefix}_{field.name}" if prefix is not None else field.name
             self._key_id_to_field_name[DHTID.generate(source=raw_key).to_bytes()] = field.name
         self._allow_extra_keys = allow_extra_keys
 
@@ -79,8 +78,10 @@ class SchemaValidator(RecordValidatorBase):
 
         if record.key not in self._key_id_to_field_name:
             if not self._allow_extra_keys:
-                logger.debug(f"Record {record} has a key ID that is not defined in any of the "
-                             f"schemas (therefore, the raw key is unknown)")
+                logger.debug(
+                    f"Record {record} has a key ID that is not defined in any of the "
+                    f"schemas (therefore, the raw key is unknown)"
+                )
             return self._allow_extra_keys
 
         try:
@@ -102,9 +103,12 @@ class SchemaValidator(RecordValidatorBase):
 
             parsed_value = parsed_record.dict(by_alias=True)[field_name]
             if parsed_value != record[field_name]:
-                validation_errors.append(ValueError(
-                    f"The record {record} needed type conversions to match "
-                    f"the schema: {parsed_value}. Type conversions are not allowed"))
+                validation_errors.append(
+                    ValueError(
+                        f"The record {record} needed type conversions to match "
+                        f"the schema: {parsed_value}. Type conversions are not allowed"
+                    )
+                )
             else:
                 return True
 
@@ -120,17 +124,18 @@ class SchemaValidator(RecordValidatorBase):
         else:
             if isinstance(deserialized_value, dict):
                 raise ValueError(
-                    f'Record {record} contains an improperly serialized dictionary (you must use '
-                    f'a DictionaryDHTValue of serialized values instead of a `dict` subclass)')
+                    f"Record {record} contains an improperly serialized dictionary (you must use "
+                    f"a DictionaryDHTValue of serialized values instead of a `dict` subclass)"
+                )
             return {field_name: deserialized_value}
 
     @staticmethod
     def _is_failed_due_to_extra_field(exc: pydantic.ValidationError):
         inner_errors = exc.errors()
         return (
-            len(inner_errors) == 1 and
-            inner_errors[0]['type'] == 'value_error.extra' and
-            len(inner_errors[0]['loc']) == 1  # Require the extra field to be on the top level
+            len(inner_errors) == 1
+            and inner_errors[0]["type"] == "value_error.extra"
+            and len(inner_errors[0]["loc"]) == 1  # Require the extra field to be on the top level
         )
 
     def merge_with(self, other: RecordValidatorBase) -> bool:
@@ -150,7 +155,7 @@ class SchemaValidator(RecordValidatorBase):
             self._patch_schema(schema)
 
 
-def conbytes(*, regex: bytes=None, **kwargs) -> Type[pydantic.BaseModel]:
+def conbytes(*, regex: bytes = None, **kwargs) -> Type[pydantic.BaseModel]:
     """
     Extend pydantic.conbytes() to support ``regex`` constraints (like pydantic.constr() does).
     """
@@ -172,4 +177,4 @@ def conbytes(*, regex: bytes=None, **kwargs) -> Type[pydantic.BaseModel]:
     return ConstrainedBytesWithRegex
 
 
-BytesWithPublicKey = conbytes(regex=b'.*' + RSASignatureValidator.PUBLIC_KEY_REGEX + b'.*')
+BytesWithPublicKey = conbytes(regex=b".*" + RSASignatureValidator.PUBLIC_KEY_REGEX + b".*")
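
A minimal sketch of how SchemaValidator and BytesWithPublicKey can be combined; the MetricsSchema model, its field name, and the prefix are hypothetical and only illustrate the pattern:

    from typing import Dict

    import pydantic
    from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator

    class MetricsSchema(pydantic.BaseModel):  # hypothetical schema, not part of this diff
        # subkeys must contain an RSA public key; values must be non-negative strict integers
        metrics: Dict[BytesWithPublicKey, pydantic.conint(ge=0, strict=True)]

    validator = SchemaValidator(MetricsSchema, allow_extra_keys=False, prefix="my_run")
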

+ 9 - 7
hivemind/dht/storage.py

@@ -9,15 +9,16 @@ from hivemind.utils.timed_storage import KeyType, ValueType, TimedStorage, DHTEx
 
 @MSGPackSerializer.ext_serializable(0x50)
 class DictionaryDHTValue(TimedStorage[Subkey, BinaryDHTValue]):
-    """ a dictionary-like DHT value type that maps sub-keys to values with individual expirations """
-    latest_expiration_time = float('-inf')
+    """a dictionary-like DHT value type that maps sub-keys to values with individual expirations"""
+
+    latest_expiration_time = float("-inf")
 
     def store(self, key: KeyType, value: ValueType, expiration_time: DHTExpiration) -> bool:
         self.latest_expiration_time = max(self.latest_expiration_time, expiration_time)
         return super().store(key, value, expiration_time)
 
     def packb(self) -> bytes:
-        """ custom behavior for MSGPackSerializer.dumps """
+        """custom behavior for MSGPackSerializer.dumps"""
         packed_items = [[key, value, expiration_time] for key, (value, expiration_time) in self.items()]
         return MSGPackSerializer.dumps([self.maxsize, self.latest_expiration_time, packed_items])
 
@@ -32,10 +33,11 @@ class DictionaryDHTValue(TimedStorage[Subkey, BinaryDHTValue]):
 
 
 class DHTLocalStorage(TimedStorage[DHTID, Union[BinaryDHTValue, DictionaryDHTValue]]):
-    """ A dictionary-like storage that can store binary values and/or nested dictionaries until expiration """
+    """A dictionary-like storage that can store binary values and/or nested dictionaries until expiration"""
 
-    def store(self, key: DHTID, value: BinaryDHTValue, expiration_time: DHTExpiration,
-              subkey: Optional[Subkey] = None) -> bool:
+    def store(
+        self, key: DHTID, value: BinaryDHTValue, expiration_time: DHTExpiration, subkey: Optional[Subkey] = None
+    ) -> bool:
         """
         Store a (key, value) pair locally at least until expiration_time. See class docstring for details.
         If subkey is not None, adds a subkey-value pair to a dictionary associated with :key: (see store_subkey below)
@@ -54,7 +56,7 @@ class DHTLocalStorage(TimedStorage[DHTID, Union[BinaryDHTValue, DictionaryDHTVal
          3) if self[key] is a normal value with smaller expiration time, overwrite it with a dictionary and add sub-key
         :returns: True if new entry was stored, False it was rejected (current value is newer)
         """
-        previous_value, previous_expiration_time = self.get(key) or (b'', -float('inf'))
+        previous_value, previous_expiration_time = self.get(key) or (b"", -float("inf"))
         if isinstance(previous_value, BinaryDHTValue) and expiration_time > previous_expiration_time:
             new_storage = DictionaryDHTValue()
             new_storage.store(subkey, value, expiration_time)
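
A short sketch of the DHTLocalStorage behaviour described above: storing a plain value and then a subkey under the same key, which upgrades the entry to a DictionaryDHTValue (key and values are made up):

    from hivemind.dht.routing import DHTID
    from hivemind.dht.storage import DHTLocalStorage
    from hivemind.utils import get_dht_time

    storage = DHTLocalStorage()
    key = DHTID.generate(source=b"example")
    storage.store(key, b"plain value", get_dht_time() + 10)                 # regular binary value
    storage.store(key, b"sub value", get_dht_time() + 20, subkey="metric")  # upgraded to DictionaryDHTValue
    print(storage.get(key))  # ValueWithExpiration wrapping a DictionaryDHTValue
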

+ 33 - 19
hivemind/dht/traverse.py

@@ -9,9 +9,13 @@ from hivemind.dht.routing import DHTID
 ROOT = 0  # alias for heap root
 
 
-async def simple_traverse_dht(query_id: DHTID, initial_nodes: Collection[DHTID], beam_size: int,
-                              get_neighbors: Callable[[DHTID], Awaitable[Tuple[Collection[DHTID], bool]]],
-                              visited_nodes: Collection[DHTID] = ()) -> Tuple[Tuple[DHTID], Set[DHTID]]:
+async def simple_traverse_dht(
+    query_id: DHTID,
+    initial_nodes: Collection[DHTID],
+    beam_size: int,
+    get_neighbors: Callable[[DHTID], Awaitable[Tuple[Collection[DHTID], bool]]],
+    visited_nodes: Collection[DHTID] = (),
+) -> Tuple[Tuple[DHTID], Set[DHTID]]:
     """
     Traverse the DHT graph using get_neighbors function, find :beam_size: nearest nodes according to DHTID.xor_distance.
 
@@ -37,7 +41,9 @@ async def simple_traverse_dht(query_id: DHTID, initial_nodes: Collection[DHTID],
     heapq.heapify(unvisited_nodes)  # nearest-first heap of candidates, unlimited size
 
     nearest_nodes = [(-distance, node_id) for distance, node_id in heapq.nsmallest(beam_size, unvisited_nodes)]
-    heapq.heapify(nearest_nodes)  # farthest-first heap of size beam_size, used for early-stopping and to select results
+    heapq.heapify(
+        nearest_nodes
+    )  # farthest-first heap of size beam_size, used for early-stopping and to select results
     while len(nearest_nodes) > beam_size:
         heapq.heappop(nearest_nodes)
 
@@ -63,10 +69,15 @@ async def simple_traverse_dht(query_id: DHTID, initial_nodes: Collection[DHTID],
 
 
 async def traverse_dht(
-        queries: Collection[DHTID], initial_nodes: List[DHTID], beam_size: int, num_workers: int, queries_per_call: int,
-        get_neighbors: Callable[[DHTID, Collection[DHTID]], Awaitable[Dict[DHTID, Tuple[Tuple[DHTID], bool]]]],
-        found_callback: Optional[Callable[[DHTID, List[DHTID], Set[DHTID]], Awaitable[Any]]] = None,
-        await_all_tasks: bool = True, visited_nodes: Optional[Dict[DHTID, Set[DHTID]]] = (),
+    queries: Collection[DHTID],
+    initial_nodes: List[DHTID],
+    beam_size: int,
+    num_workers: int,
+    queries_per_call: int,
+    get_neighbors: Callable[[DHTID, Collection[DHTID]], Awaitable[Dict[DHTID, Tuple[Tuple[DHTID], bool]]]],
+    found_callback: Optional[Callable[[DHTID, List[DHTID], Set[DHTID]], Awaitable[Any]]] = None,
+    await_all_tasks: bool = True,
+    visited_nodes: Optional[Dict[DHTID, Set[DHTID]]] = (),
 ) -> Tuple[Dict[DHTID, List[DHTID]], Dict[DHTID, Set[DHTID]]]:
     """
     Search the DHT for nearest neighbors to :queries: (based on DHTID.xor_distance). Use get_neighbors to request peers.
@@ -133,22 +144,22 @@ async def traverse_dht(
         visited_nodes[query] = set(visited_nodes.get(query, ()))
 
     def heuristic_priority(heap_query: DHTID):
-        """ Workers prioritize expanding nodes that lead to under-explored queries (by other workers) """
+        """Workers prioritize expanding nodes that lead to under-explored queries (by other workers)"""
         if has_candidates(heap_query):
             # prefer candidates in heaps with least number of concurrent workers, break ties by distance to query
             return active_workers[heap_query], candidate_nodes[heap_query][ROOT][0]
-        return float('inf'), float('inf')  # try not to explore vertices with no candidates
+        return float("inf"), float("inf")  # try not to explore vertices with no candidates
 
     def has_candidates(query: DHTID):
-        """ Whether this query's heap contains at least one candidate node that can be explored """
+        """Whether this query's heap contains at least one candidate node that can be explored"""
         return candidate_nodes[query] and candidate_nodes[query][ROOT][0] <= upper_bound(query)
 
     def upper_bound(query: DHTID):
-        """ Any node that is farther from query than upper_bound(query) will not be added to heaps """
-        return -nearest_nodes[query][ROOT][0] if len(nearest_nodes[query]) >= beam_size else float('inf')
+        """Any node that is farther from query than upper_bound(query) will not be added to heaps"""
+        return -nearest_nodes[query][ROOT][0] if len(nearest_nodes[query]) >= beam_size else float("inf")
 
     def finish_search(query):
-        """ Remove query from a list of targets """
+        """Remove query from a list of targets"""
         unfinished_queries.remove(query)
         if len(unfinished_queries) == 0:
             search_finished_event.set()
@@ -181,10 +192,14 @@ async def traverse_dht(
                 continue
 
             # find additional queries to pack in the same request
-            possible_additional_queries = [query for query in unfinished_queries
-                                           if query != chosen_query and chosen_peer not in visited_nodes[query]]
+            possible_additional_queries = [
+                query
+                for query in unfinished_queries
+                if query != chosen_query and chosen_peer not in visited_nodes[query]
+            ]
             queries_to_call = [chosen_query] + heapq.nsmallest(
-                queries_per_call - 1, possible_additional_queries, key=chosen_peer.xor_distance)
+                queries_per_call - 1, possible_additional_queries, key=chosen_peer.xor_distance
+            )
 
             # update priorities for subsequent workers
             active_workers.update(queries_to_call)
@@ -230,8 +245,7 @@ async def traverse_dht(
             await asyncio.gather(*pending_tasks)
 
         nearest_neighbors_per_query = {
-            query: [peer for _, peer in heapq.nlargest(beam_size, nearest_nodes[query])]
-            for query in queries
+            query: [peer for _, peer in heapq.nlargest(beam_size, nearest_nodes[query])] for query in queries
         }
         return nearest_neighbors_per_query, visited_nodes
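
The traversal above keeps the beam_size closest candidates in a farthest-first heap so that the worst element can be evicted in O(log n). A standalone toy illustration of that bookkeeping with plain integers (not hivemind code, just the same heap trick):

    import heapq

    query, candidates, beam_size = 42, [5, 17, 40, 43, 99, 7], 3
    nearest = []  # stores (-distance, node): the heap root is the farthest node currently kept
    for node in candidates:
        heapq.heappush(nearest, (-(query ^ node), node))  # XOR distance, the same metric as DHTID.xor_distance
        if len(nearest) > beam_size:
            heapq.heappop(nearest)  # evict the farthest candidate
    print([node for _, node in sorted(nearest, reverse=True)])  # [43, 40, 7]: the 3 nodes closest to 42
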
 

+ 2 - 2
hivemind/dht/validation.py

@@ -66,7 +66,7 @@ class RecordValidatorBase(ABC):
 
         return 0
 
-    def merge_with(self, other: 'RecordValidatorBase') -> bool:
+    def merge_with(self, other: "RecordValidatorBase") -> bool:
         """
         By default, all validators are applied sequentially (i.e. we require all validate() calls
         to return True for a record to be validated successfully).
@@ -90,7 +90,7 @@ class RecordValidatorBase(ABC):
 
 
 class CompositeValidator(RecordValidatorBase):
-    def __init__(self, validators: Iterable[RecordValidatorBase]=()):
+    def __init__(self, validators: Iterable[RecordValidatorBase] = ()):
         self._validators = []
         self.extend(validators)
 

+ 7 - 7
hivemind/hivemind_cli/run_server.py

@@ -62,18 +62,18 @@ def main():
 
     # fmt:on
     args = vars(parser.parse_args())
-    args.pop('config', None)
-    optimizer = args.pop('optimizer')
-    if optimizer == 'adam':
+    args.pop("config", None)
+    optimizer = args.pop("optimizer")
+    if optimizer == "adam":
         optim_cls = torch.optim.Adam
-    elif optimizer == 'sgd':
+    elif optimizer == "sgd":
         optim_cls = partial(torch.optim.SGD, lr=0.01)
-    elif optimizer == 'none':
+    elif optimizer == "none":
         optim_cls = None
     else:
         raise ValueError("optim_cls must be adam, sgd or none")
 
-    if args.pop('increase_file_limit'):
+    if args.pop("increase_file_limit"):
         increase_file_limit()
 
     compression_type = args.pop("compression")
@@ -89,5 +89,5 @@ def main():
         server.shutdown()
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()

+ 179 - 64
hivemind/moe/client/beam_search.py

@@ -6,8 +6,17 @@ from typing import Sequence, Optional, List, Tuple, Dict, Deque, Union, Set, Ite
 
 from hivemind.dht import DHT, DHTNode, DHTExpiration
 from hivemind.moe.client.expert import RemoteExpert
-from hivemind.moe.server.expert_uid import (ExpertUID, ExpertPrefix, FLAT_EXPERT, UidEndpoint, Score, Coordinate,
-                                            PREFIX_PATTERN, UID_DELIMITER, is_valid_prefix)
+from hivemind.moe.server.expert_uid import (
+    ExpertUID,
+    ExpertPrefix,
+    FLAT_EXPERT,
+    UidEndpoint,
+    Score,
+    Coordinate,
+    PREFIX_PATTERN,
+    UID_DELIMITER,
+    is_valid_prefix,
+)
 from hivemind.utils import get_logger, get_dht_time, MPFuture
 
 logger = get_logger(__name__)
@@ -63,8 +72,16 @@ class MoEBeamSearcher:
          Though, this is a pathological case (e.g. only 90 experts in an oversized 100x100 grid) that should be avoided.
     """
 
-    def __init__(self, dht: DHT, uid_prefix: ExpertPrefix, grid_size: Sequence[int], num_workers: Optional[int] = None,
-                 negative_caching: bool = True, cache_expiration: DHTExpiration = 300, **kwargs):
+    def __init__(
+        self,
+        dht: DHT,
+        uid_prefix: ExpertPrefix,
+        grid_size: Sequence[int],
+        num_workers: Optional[int] = None,
+        negative_caching: bool = True,
+        cache_expiration: DHTExpiration = 300,
+        **kwargs,
+    ):
         if not uid_prefix.endswith(UID_DELIMITER):
             uid_prefix += UID_DELIMITER
             logger.info(f"Prefix must end with '{UID_DELIMITER}'. Changing to {uid_prefix}{UID_DELIMITER}")
@@ -75,27 +92,44 @@ class MoEBeamSearcher:
         self.negative_caching, self.cache_expiration = negative_caching, cache_expiration
         self.num_workers, self.dht_kwargs = num_workers, kwargs
 
-    def get_initial_beam(self, scores: Sequence[float], beam_size: int, return_future: bool = False
-                         ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
+    def get_initial_beam(
+        self, scores: Sequence[float], beam_size: int, return_future: bool = False
+    ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
         """
         :param scores: prefer suffix coordinates that have highest scores
         :param beam_size: select this many active suffixes with highest scores
         :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
         :returns: a list of up to beam_size tuples of (prefix score, prefix itself, dict{suffix: example expert})
         """
-        return self.dht.run_coroutine(partial(self._get_initial_beam, prefix=self.uid_prefix, beam_size=beam_size,
-                                              scores=tuple(scores), negative_caching=self.negative_caching,
-                                              cache_expiration=self.cache_expiration, num_workers=self.num_workers),
-                                      return_future)
+        return self.dht.run_coroutine(
+            partial(
+                self._get_initial_beam,
+                prefix=self.uid_prefix,
+                beam_size=beam_size,
+                scores=tuple(scores),
+                negative_caching=self.negative_caching,
+                cache_expiration=self.cache_expiration,
+                num_workers=self.num_workers,
+            ),
+            return_future,
+        )
 
     @staticmethod
     async def _get_initial_beam(
-            dht: DHT, node: DHTNode, prefix: ExpertPrefix, beam_size: int, scores: Tuple[float, ...],
-            negative_caching: bool, cache_expiration: DHTExpiration, num_workers: Optional[int] = None,
+        dht: DHT,
+        node: DHTNode,
+        prefix: ExpertPrefix,
+        beam_size: int,
+        scores: Tuple[float, ...],
+        negative_caching: bool,
+        cache_expiration: DHTExpiration,
+        num_workers: Optional[int] = None,
     ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
         num_workers = num_workers or dht.max_workers or beam_size
         beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = []
-        unattempted_indices: List[Coordinate] = sorted(range(len(scores)), key=scores.__getitem__)  # from worst to best
+        unattempted_indices: List[Coordinate] = sorted(
+            range(len(scores)), key=scores.__getitem__
+        )  # from worst to best
         pending_tasks: Deque[Tuple[Coordinate, ExpertPrefix, asyncio.Task]] = deque()
 
         while len(beam) < beam_size and (unattempted_indices or pending_tasks):
@@ -110,15 +144,25 @@ class MoEBeamSearcher:
             try:
                 maybe_prefix_data = await pending_task
                 if maybe_prefix_data is not None and isinstance(maybe_prefix_data.value, dict):
-                    successors = {coord: UidEndpoint(*match.value) for coord, match in maybe_prefix_data.value.items()
-                                  if isinstance(coord, Coordinate) and isinstance(getattr(match, 'value', None), list)
-                                  and len(match.value) == 2}
+                    successors = {
+                        coord: UidEndpoint(*match.value)
+                        for coord, match in maybe_prefix_data.value.items()
+                        if isinstance(coord, Coordinate)
+                        and isinstance(getattr(match, "value", None), list)
+                        and len(match.value) == 2
+                    }
                     if successors:
                         beam.append((scores[pending_best_index], pending_best_prefix, successors))
                 elif maybe_prefix_data is None and negative_caching:
                     logger.debug(f"DHT negative caching: storing a 'no prefix' entry for {pending_best_prefix}")
-                    asyncio.create_task(node.store(pending_best_prefix, subkey=-1, value=None,
-                                                   expiration_time=get_dht_time() + cache_expiration))
+                    asyncio.create_task(
+                        node.store(
+                            pending_best_prefix,
+                            subkey=-1,
+                            value=None,
+                            expiration_time=get_dht_time() + cache_expiration,
+                        )
+                    )
 
             except asyncio.CancelledError:
                 for _, pending_task in pending_tasks:
@@ -126,8 +170,9 @@ class MoEBeamSearcher:
                 raise
         return beam
 
-    def get_active_successors(self, prefixes: List[ExpertPrefix], grid_size: Optional[int] = None,
-                              return_future: bool = False) -> Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]]:
+    def get_active_successors(
+        self, prefixes: List[ExpertPrefix], grid_size: Optional[int] = None, return_future: bool = False
+    ) -> Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]]:
         """
         :param prefixes: a list of prefixes for which to find active successor uids
         :param grid_size: if specified, only return successors if they are in range [0, grid_size)
@@ -138,35 +183,54 @@ class MoEBeamSearcher:
         assert not isinstance(prefixes, str), "Please send a list / tuple of expert prefixes."
         for prefix in prefixes:
             assert is_valid_prefix(prefix), f"prefix '{prefix}' is invalid, it must follow {PREFIX_PATTERN.pattern}"
-        return self.dht.run_coroutine(partial(
-            self._get_active_successors, prefixes=list(prefixes), grid_size=grid_size,
-            negative_caching=self.negative_caching, cache_expiration=self.cache_expiration,
-            num_workers=self.num_workers), return_future=return_future)
+        return self.dht.run_coroutine(
+            partial(
+                self._get_active_successors,
+                prefixes=list(prefixes),
+                grid_size=grid_size,
+                negative_caching=self.negative_caching,
+                cache_expiration=self.cache_expiration,
+                num_workers=self.num_workers,
+            ),
+            return_future=return_future,
+        )
 
     @staticmethod
     async def _get_active_successors(
-            dht: DHT, node: DHTNode, prefixes: List[ExpertPrefix], grid_size: Optional[int],
-            negative_caching: bool, cache_expiration: DHTExpiration, num_workers: Optional[int] = None
+        dht: DHT,
+        node: DHTNode,
+        prefixes: List[ExpertPrefix],
+        grid_size: Optional[int],
+        negative_caching: bool,
+        cache_expiration: DHTExpiration,
+        num_workers: Optional[int] = None,
     ) -> Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]]:
-        grid_size = grid_size or float('inf')
+        grid_size = grid_size or float("inf")
         num_workers = num_workers or min(len(prefixes), dht.max_workers or len(prefixes))
         dht_responses = await node.get_many(keys=prefixes, num_workers=num_workers)
         successors: Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]] = {}
         for prefix, found in dht_responses.items():
             if found and isinstance(found.value, dict):
-                successors[prefix] = {coord: UidEndpoint(*match.value) for coord, match in found.value.items()
-                                      if isinstance(coord, Coordinate) and 0 <= coord < grid_size
-                                      and isinstance(getattr(match, 'value', None), list) and len(match.value) == 2}
+                successors[prefix] = {
+                    coord: UidEndpoint(*match.value)
+                    for coord, match in found.value.items()
+                    if isinstance(coord, Coordinate)
+                    and 0 <= coord < grid_size
+                    and isinstance(getattr(match, "value", None), list)
+                    and len(match.value) == 2
+                }
             else:
                 successors[prefix] = {}
                 if found is None and negative_caching:
                     logger.debug(f"DHT negative caching: storing a 'no prefix' entry for {prefix}")
-                    asyncio.create_task(node.store(prefix, subkey=-1, value=None,
-                                                   expiration_time=get_dht_time() + cache_expiration))
+                    asyncio.create_task(
+                        node.store(prefix, subkey=-1, value=None, expiration_time=get_dht_time() + cache_expiration)
+                    )
         return successors
 
-    def find_best_experts(self, grid_scores: Sequence[Sequence[float]], beam_size: int, return_future: bool = False
-                          ) -> Union[List[RemoteExpert], MPFuture[RemoteExpert]]:
+    def find_best_experts(
+        self, grid_scores: Sequence[Sequence[float]], beam_size: int, return_future: bool = False
+    ) -> Union[List[RemoteExpert], MPFuture[RemoteExpert]]:
         """
         Find and return :beam_size: active experts with highest scores, use both local cache and DHT
 
@@ -181,21 +245,37 @@ class MoEBeamSearcher:
         :returns: a list that contains *up to* k_best RemoteExpert instances
         """
         assert len(grid_scores) == len(self.grid_size) and beam_size > 0
-        return self.dht.run_coroutine(partial(
-            self._find_best_experts, prefix=self.uid_prefix, beam_size=beam_size, grid_scores=list(grid_scores),
-            negative_caching=self.negative_caching, cache_expiration=self.cache_expiration,
-            num_workers=self.num_workers), return_future)
+        return self.dht.run_coroutine(
+            partial(
+                self._find_best_experts,
+                prefix=self.uid_prefix,
+                beam_size=beam_size,
+                grid_scores=list(grid_scores),
+                negative_caching=self.negative_caching,
+                cache_expiration=self.cache_expiration,
+                num_workers=self.num_workers,
+            ),
+            return_future,
+        )
 
     @classmethod
     async def _find_best_experts(
-            cls, dht: DHT, node: DHTNode, prefix: str, grid_scores: List[Tuple[float]], beam_size: int,
-            negative_caching: bool, cache_expiration: DHTExpiration, num_workers: Optional[int] = None
+        cls,
+        dht: DHT,
+        node: DHTNode,
+        prefix: str,
+        grid_scores: List[Tuple[float]],
+        beam_size: int,
+        negative_caching: bool,
+        cache_expiration: DHTExpiration,
+        num_workers: Optional[int] = None,
     ) -> List[RemoteExpert]:
         num_workers = num_workers or min(beam_size, dht.max_workers or beam_size)
 
         # form initial beam from top-k active L1 prefixes, each row is (score, uid prefix, possible suffixes)
         beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = await cls._get_initial_beam(
-            dht, node, prefix, beam_size, grid_scores[0], negative_caching, min(beam_size, num_workers))
+            dht, node, prefix, beam_size, grid_scores[0], negative_caching, min(beam_size, num_workers)
+        )
 
         best_experts_heap: List[Tuple[Score, UidEndpoint]] = []  # max-heap of expert uids/endpoints ordered by scores
         unique_experts: Set[ExpertUID] = set()
@@ -209,16 +289,27 @@ class MoEBeamSearcher:
 
             # form new beam using successors from the current beam
             dim_scores = grid_scores[dim_index]
-            best_active_pairs: List[Tuple[Score, ExpertPrefix]] = heapq.nlargest(beam_size, (
-                (prefix_score + dim_scores[next_coord], f"{prefix}{next_coord}{UID_DELIMITER}")
-                for prefix_score, prefix, suffixes in beam for next_coord in suffixes.keys()
-                if isinstance(next_coord, int) and 0 <= next_coord < len(dim_scores)))
+            best_active_pairs: List[Tuple[Score, ExpertPrefix]] = heapq.nlargest(
+                beam_size,
+                (
+                    (prefix_score + dim_scores[next_coord], f"{prefix}{next_coord}{UID_DELIMITER}")
+                    for prefix_score, prefix, suffixes in beam
+                    for next_coord in suffixes.keys()
+                    if isinstance(next_coord, int) and 0 <= next_coord < len(dim_scores)
+                ),
+            )
             _, best_uid_prefixes = zip(*best_active_pairs)
 
             # search DHT for next step suffixes
             successors = await cls._get_active_successors(
-                dht, node, best_uid_prefixes, grid_size=None, negative_caching=negative_caching,
-                cache_expiration=cache_expiration, num_workers=num_workers)
+                dht,
+                node,
+                best_uid_prefixes,
+                grid_size=None,
+                negative_caching=negative_caching,
+                cache_expiration=cache_expiration,
+                num_workers=num_workers,
+            )
             beam = [(score, prefix, successors[prefix]) for score, prefix in best_active_pairs if successors[prefix]]
             if not beam:
                 logger.warning(f"Beam search had to terminate prematurely because of empty beam (dim 0)")
@@ -235,26 +326,32 @@ class MoEBeamSearcher:
         return best_experts
 
     @staticmethod
-    def _iterate_matching_experts(beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]],
-                                  grid_scores: Sequence[Sequence[float]]) -> Iterator[Tuple[Score, UidEndpoint]]:
-        """ iterate over all exemplar experts attached to current beam """
+    def _iterate_matching_experts(
+        beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]], grid_scores: Sequence[Sequence[float]]
+    ) -> Iterator[Tuple[Score, UidEndpoint]]:
+        """iterate over all exemplar experts attached to current beam"""
         for score, prefix, suffixes in beam:
             for next_coord, match in suffixes.items():
                 if len(grid_scores) == 1 and next_coord == FLAT_EXPERT:
                     yield score, match
                 elif isinstance(match.uid, ExpertUID) and match.uid.count(UID_DELIMITER) == len(grid_scores):
                     expert_coords = match.uid.split(UID_DELIMITER)[1:]
-                    if all(coord.isdigit() and 0 <= int(coord) < len(grid_scores[i])
-                           for i, coord in enumerate(expert_coords)):
-                        expert_score = sum(scores[coord] for scores, coord in zip(grid_scores, map(int, expert_coords)))
+                    if all(
+                        coord.isdigit() and 0 <= int(coord) < len(grid_scores[i])
+                        for i, coord in enumerate(expert_coords)
+                    ):
+                        expert_score = sum(
+                            scores[coord] for scores, coord in zip(grid_scores, map(int, expert_coords))
+                        )
                         yield expert_score, match
                     else:
                         logger.warning(f"Found incompatible expert coordinates: {expert_coords}")
                 else:
                     logger.warning(f"Found incompatible expert UID: {match.uid}")
 
-    def batch_find_best_experts(self, batch_grid_scores: Sequence[Sequence[Sequence[float]]], beam_size: int,
-                                return_future: bool = False) -> Union[List[List[RemoteExpert]], MPFuture]:
+    def batch_find_best_experts(
+        self, batch_grid_scores: Sequence[Sequence[Sequence[float]]], beam_size: int, return_future: bool = False
+    ) -> Union[List[List[RemoteExpert]], MPFuture]:
         """
         Find and return :beam_size: active experts with highest scores, use both local cache and DHT
 
@@ -267,16 +364,34 @@ class MoEBeamSearcher:
         :param return_future: if set to True, returns MPFuture that can be awaited to get the actual result
         :returns: a list that contains *up to* k_best RemoteExpert instances
         """
-        return self.dht.run_coroutine(partial(
-            self._batch_find_best_experts, prefix=self.uid_prefix, batch_grid_scores=batch_grid_scores,
-            beam_size=beam_size, negative_caching=self.negative_caching, num_workers=self.num_workers), return_future)
+        return self.dht.run_coroutine(
+            partial(
+                self._batch_find_best_experts,
+                prefix=self.uid_prefix,
+                batch_grid_scores=batch_grid_scores,
+                beam_size=beam_size,
+                negative_caching=self.negative_caching,
+                num_workers=self.num_workers,
+            ),
+            return_future,
+        )
 
     @classmethod
     async def _batch_find_best_experts(
-            cls, dht: DHT, node: DHTNode, prefix: str, batch_grid_scores: Sequence[Sequence[Tuple[float]]],
-            beam_size: int, negative_caching: bool, num_workers: Optional[int]) -> Sequence[Sequence[RemoteExpert]]:
-        batch_grid_scores = [[tuple(grid_score[i]) for grid_score in batch_grid_scores]
-                             for i in range(len(batch_grid_scores[0]))]
-        coros = [cls._find_best_experts(dht, node, prefix, grid_scores, beam_size, negative_caching, num_workers)
-                 for grid_scores in batch_grid_scores]
+        cls,
+        dht: DHT,
+        node: DHTNode,
+        prefix: str,
+        batch_grid_scores: Sequence[Sequence[Tuple[float]]],
+        beam_size: int,
+        negative_caching: bool,
+        num_workers: Optional[int],
+    ) -> Sequence[Sequence[RemoteExpert]]:
+        batch_grid_scores = [
+            [tuple(grid_score[i]) for grid_score in batch_grid_scores] for i in range(len(batch_grid_scores[0]))
+        ]
+        coros = [
+            cls._find_best_experts(dht, node, prefix, grid_scores, beam_size, negative_caching, num_workers)
+            for grid_scores in batch_grid_scores
+        ]
         return await asyncio.gather(*coros)
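
For reference, a minimal usage sketch of the beam-search API reformatted above (the prefix, grid size and scores are placeholders; it assumes a reachable DHT with experts declared under that prefix):

    import numpy as np
    import hivemind
    from hivemind.moe.client.beam_search import MoEBeamSearcher

    dht = hivemind.DHT(start=True)  # assumption: peers serving experts under "expert." are reachable
    beam_search = MoEBeamSearcher(dht, "expert.", (4, 4))  # hypothetical 4 x 4 grid

    # one [batch_size, dim_size] score matrix per grid dimension, batch of 2 samples here
    batch_grid_scores = [np.random.randn(2, 4), np.random.randn(2, 4)]
    best_experts = beam_search.batch_find_best_experts(batch_grid_scores, beam_size=3)
    # best_experts[i] holds up to beam_size RemoteExpert instances for sample i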

+ 25 - 16
hivemind/moe/client/expert.py

@@ -14,8 +14,8 @@ DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autogra
 
 
 def _get_expert_stub(endpoint: Endpoint, *extra_options: Tuple[str, Any]):
-    """ Create a gRPC stub to access remote expert or use previously created stub from a process-wide cache """
-    channel_options = (('grpc.max_send_message_length', -1), ('grpc.max_receive_message_length', -1)) + extra_options
+    """Create a gRPC stub to access remote expert or use previously created stub from a process-wide cache"""
+    channel_options = (("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)) + extra_options
     return ChannelCache.get_stub(endpoint, runtime_grpc.ConnectionHandlerStub, aio=False, options=channel_options)
 
 
@@ -41,20 +41,20 @@ class RemoteExpert(nn.Module):
         return _get_expert_stub(self.endpoint)
 
     def forward(self, *args, **kwargs):
-        """ Call RemoteExpert for the specified inputs and return its output(s). Compatible with pytorch.autograd. """
-        assert len(kwargs) == len(self.info['keyword_names']), f"Keyword args should be {self.info['keyword_names']}"
-        kwargs = {key: kwargs[key] for key in self.info['keyword_names']}
+        """Call RemoteExpert for the specified inputs and return its output(s). Compatible with pytorch.autograd."""
+        assert len(kwargs) == len(self.info["keyword_names"]), f"Keyword args should be {self.info['keyword_names']}"
+        kwargs = {key: kwargs[key] for key in self.info["keyword_names"]}
 
         # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
 
         forward_inputs = (args, kwargs)
 
-        if not nested_compare(forward_inputs, self.info['forward_schema']):
+        if not nested_compare(forward_inputs, self.info["forward_schema"]):
             raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
 
         flat_outputs = _RemoteModuleCall.apply(DUMMY, self.uid, self.stub, self.info, *nested_flatten(forward_inputs))
         # Note: we send DUMMY to prevent torch from excluding expert from backward if no other inputs require grad
-        return nested_pack(flat_outputs, structure=self.info['outputs_schema'])
+        return nested_pack(flat_outputs, structure=self.info["outputs_schema"])
 
     @property
     def info(self):
@@ -68,22 +68,29 @@ class RemoteExpert(nn.Module):
 
 
 class _RemoteModuleCall(torch.autograd.Function):
-    """ Internal autograd-friendly call of a remote module. For applications, use RemoteExpert instead. """
+    """Internal autograd-friendly call of a remote module. For applications, use RemoteExpert instead."""
 
     @staticmethod
-    def forward(ctx, dummy: torch.Tensor, uid: str, stub: runtime_grpc.ConnectionHandlerStub,
-                info: Dict[str, Any], *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
+    def forward(
+        ctx,
+        dummy: torch.Tensor,
+        uid: str,
+        stub: runtime_grpc.ConnectionHandlerStub,
+        info: Dict[str, Any],
+        *inputs: torch.Tensor,
+    ) -> Tuple[torch.Tensor, ...]:
         # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
         # detach to avoid pickling the computation graph
         inputs = tuple(tensor.cpu().detach() for tensor in inputs)
         ctx.uid, ctx.stub, ctx.info = uid, stub, info
         ctx.save_for_backward(*inputs)
 
-        serialized_tensors = [serialize_torch_tensor(inp, proto.compression)
-                              for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))]
+        serialized_tensors = [
+            serialize_torch_tensor(inp, proto.compression)
+            for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))
+        ]
 
-        outputs = stub.forward(
-            runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
+        outputs = stub.forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
 
         deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
 
@@ -95,8 +102,10 @@ class _RemoteModuleCall(torch.autograd.Function):
         grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
         inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs_cpu)))
         backward_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))
-        serialized_tensors = [serialize_torch_tensor(tensor, proto.compression)
-                              for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)]
+        serialized_tensors = [
+            serialize_torch_tensor(tensor, proto.compression)
+            for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
+        ]
 
         grad_inputs = ctx.stub.backward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
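
For context, a short sketch of how the reformatted RemoteExpert is called from client code (the uid, endpoint and input size are placeholders and must match an expert actually served there):

    import torch
    from hivemind.moe.client.expert import RemoteExpert

    expert = RemoteExpert("expert.0.0", "127.0.0.1:1337")  # placeholder uid and endpoint
    x = torch.randn(2, 1024)   # assumes the server-side schema expects hidden_dim=1024
    output = expert(x)         # forward pass goes through _RemoteModuleCall over gRPC
    output.sum().backward()    # backward() dispatches the remote backward pass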
 

+ 127 - 53
hivemind/moe/client/moe.py

@@ -43,10 +43,23 @@ class RemoteMixtureOfExperts(nn.Module):
     :param allow_zero_outputs: whether to return zeros if no experts respond on forward pass
     """
 
-    def __init__(self, *, in_features, grid_size: Tuple[int, ...], dht: hivemind.DHT, uid_prefix: str, k_best: int,
-                 k_min: int = 1, forward_timeout: Optional[float] = None, timeout_after_k_min: Optional[float] = None,
-                 backward_k_min: int = 1, backward_timeout: Optional[float] = None, detect_anomalies: bool = False,
-                 allow_zero_outputs: bool = False, **dht_kwargs):
+    def __init__(
+        self,
+        *,
+        in_features,
+        grid_size: Tuple[int, ...],
+        dht: hivemind.DHT,
+        uid_prefix: str,
+        k_best: int,
+        k_min: int = 1,
+        forward_timeout: Optional[float] = None,
+        timeout_after_k_min: Optional[float] = None,
+        backward_k_min: int = 1,
+        backward_timeout: Optional[float] = None,
+        detect_anomalies: bool = False,
+        allow_zero_outputs: bool = False,
+        **dht_kwargs,
+    ):
         super().__init__()
         self.dht = dht
         self.beam_search = MoEBeamSearcher(dht, uid_prefix, grid_size, **dht_kwargs)
@@ -80,35 +93,49 @@ class RemoteMixtureOfExperts(nn.Module):
         grid_scores = self.proj(input_for_gating).split_with_sizes(self.beam_search.grid_size, dim=-1)
 
         chosen_experts: List[List[RemoteExpert]] = self.beam_search.batch_find_best_experts(
-            [scores.detach().cpu().numpy() for scores in grid_scores], self.k_best)
+            [scores.detach().cpu().numpy() for scores in grid_scores], self.k_best
+        )
 
         if self._expert_info is None:
             try:
                 self._expert_info = next((expert.info for experts_i in chosen_experts for expert in experts_i))
             except StopIteration:
-                raise RuntimeError("No responding experts found during beam search. Check that UID prefixes and "
-                                   "the grid size are consistent with running Server instances.")
+                raise RuntimeError(
+                    "No responding experts found during beam search. Check that UID prefixes and "
+                    "the grid size are consistent with running Server instances."
+                )
             except grpc.RpcError as e:
                 logger.warning(f"Failed to get RemoteMixtureOfExperts.output_shape: {e}")
 
         expert_mask, *expert_outputs = _RemoteCallMany.apply(
-            DUMMY, chosen_experts, self.k_min, self.backward_k_min, self.timeout_after_k_min, self.forward_timeout,
-            self.backward_timeout, self.detect_anomalies, self.allow_zero_outputs, self.info,
-            *nested_flatten(((input, *args), kwargs)))
+            DUMMY,
+            chosen_experts,
+            self.k_min,
+            self.backward_k_min,
+            self.timeout_after_k_min,
+            self.forward_timeout,
+            self.backward_timeout,
+            self.detect_anomalies,
+            self.allow_zero_outputs,
+            self.info,
+            *nested_flatten(((input, *args), kwargs)),
+        )
         # ^-- multiple tensors of shape [batch_size, max_experts, ...output_shape]
 
         expert_logits = self.compute_expert_scores(grid_scores, chosen_experts)
-        masked_logits = torch.full((1,), float('-inf'), device=expert_logits.device, dtype=expert_logits.dtype)
+        masked_logits = torch.full((1,), float("-inf"), device=expert_logits.device, dtype=expert_logits.dtype)
         expert_logits = torch.where(expert_mask, expert_logits, masked_logits)
         expert_weights = torch.softmax(expert_logits, dim=1)
         averaged_outputs_flat = [
             (expert_weights[..., None] * tensor.flatten(start_dim=2)).view(tensor.shape).sum(dim=1)
-            for tensor in expert_outputs]  # ^-- multiply by softmax weights along first 2 axes
+            for tensor in expert_outputs
+        ]  # ^-- multiply by softmax weights along first 2 axes
 
-        return nested_pack(averaged_outputs_flat, self.info['outputs_schema'])
+        return nested_pack(averaged_outputs_flat, self.info["outputs_schema"])
 
     def compute_expert_scores(
-            self, grid_scores: List[torch.Tensor], batch_experts: List[List[RemoteExpert]]) -> torch.Tensor:
+        self, grid_scores: List[torch.Tensor], batch_experts: List[List[RemoteExpert]]
+    ) -> torch.Tensor:
         """
         Compute scores for each expert by adding up grid scores, autograd-friendly
         :param grid_scores: list of torch tensors, i-th tensor contains scores for i-th grid dimension
@@ -131,16 +158,17 @@ class RemoteMixtureOfExperts(nn.Module):
 
         grid_indices = torch.zeros([len(flat_experts), len(grid_scores)], dtype=torch.int64)
         for i, expert in enumerate(flat_experts):
-            expert_indices = expert.uid[len(self.beam_search.uid_prefix):]
+            expert_indices = expert.uid[len(self.beam_search.uid_prefix) :]
             expert_indices = list(map(int, expert_indices.split(UID_DELIMITER)))
             grid_indices[i] = torch.as_tensor(expert_indices, dtype=grid_indices.dtype)
 
         scores_per_dim = [
             dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0, device=device)
-            for dim_scores, dim_indices in zip(grid_scores, grid_indices.T)]
+            for dim_scores, dim_indices in zip(grid_scores, grid_indices.T)
+        ]
         flat_scores = torch.sum(torch.stack(scores_per_dim, dim=0), dim=0)
 
-        scores = torch.full((batch_size, max_num_experts), fill_value=-float('inf'), device=device)
+        scores = torch.full((batch_size, max_num_experts), fill_value=-float("inf"), device=device)
         scores[flat_batch_indices, flat_local_indices] = flat_scores  # backprop-able w.r.t. flat_scores
         return scores
 
@@ -167,10 +195,21 @@ class _RemoteCallMany(torch.autograd.Function):
     """
 
     @classmethod
-    def forward(cls, ctx, dummy, experts_per_sample: List[List[RemoteExpert]], k_min: int, backward_k_min: int,
-                timeout_after_k_min: float, forward_timeout: Optional[float], backward_timeout: Optional[float],
-                detect_anomalies: bool, allow_zero_outputs: bool, info: Dict[str, Any],
-                *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
+    def forward(
+        cls,
+        ctx,
+        dummy,
+        experts_per_sample: List[List[RemoteExpert]],
+        k_min: int,
+        backward_k_min: int,
+        timeout_after_k_min: float,
+        forward_timeout: Optional[float],
+        backward_timeout: Optional[float],
+        detect_anomalies: bool,
+        allow_zero_outputs: bool,
+        info: Dict[str, Any],
+        *flat_inputs: torch.Tensor,
+    ) -> Tuple[torch.Tensor]:
         assert not torch.is_grad_enabled()
         num_samples, max_experts = len(experts_per_sample), max(map(len, experts_per_sample))
 
@@ -187,29 +226,35 @@ class _RemoteCallMany(torch.autograd.Function):
         pending_tasks: Dict[grpc.Future, Tuple[int, int]] = {}
         for i in range(num_samples):
             for j, expert in enumerate(experts_per_sample[i]):
-                input_tensors = [serialize_torch_tensor(tensor, proto.compression) for tensor, proto in zip(
-                    flat_inputs_per_sample[i], nested_flatten(info['forward_schema']))]
+                input_tensors = [
+                    serialize_torch_tensor(tensor, proto.compression)
+                    for tensor, proto in zip(flat_inputs_per_sample[i], nested_flatten(info["forward_schema"]))
+                ]
                 stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint)
                 new_task = stub.forward.future(runtime_pb2.ExpertRequest(uid=expert.uid, tensors=input_tensors))
                 pending_tasks[new_task] = (i, j)
 
         responded_inds, alive_flat_outputs = cls._collect_responses(
-            pending_tasks, num_samples, k_min, forward_timeout, timeout_after_k_min, detect_anomalies)
+            pending_tasks, num_samples, k_min, forward_timeout, timeout_after_k_min, detect_anomalies
+        )
         if len(responded_inds) < k_min:
             raise TimeoutError(f"Forward pass: less than {k_min} responded within timeout.")
 
-        if not isinstance(info['outputs_schema'], tuple):
-            outputs_schema = (info['outputs_schema'],)
+        if not isinstance(info["outputs_schema"], tuple):
+            outputs_schema = (info["outputs_schema"],)
         else:
-            outputs_schema = info['outputs_schema']
+            outputs_schema = info["outputs_schema"]
         outputs = nested_map(
             lambda descriptor: descriptor.make_empty(num_samples, max_experts, device=flat_inputs[0].device).zero_(),
-            outputs_schema)
+            outputs_schema,
+        )
 
         # assemble responses
         if len(responded_inds) > 0 or allow_zero_outputs:
-            batch_inds, expert_inds = map(lambda x: torch.as_tensor(x, device=flat_inputs[0].device, dtype=torch.long),
-                                          list(zip(*responded_inds)) or ([], []))
+            batch_inds, expert_inds = map(
+                lambda x: torch.as_tensor(x, device=flat_inputs[0].device, dtype=torch.long),
+                list(zip(*responded_inds)) or ([], []),
+            )
 
             alive_flat_outputs_stacked = (torch.cat(outputs) for outputs in zip(*alive_flat_outputs))
             # torch tensors, i-th tensor is of shape [num_responded, *expert_outputs[i].shape]
@@ -218,15 +263,21 @@ class _RemoteCallMany(torch.autograd.Function):
                 output[batch_inds, expert_inds] = response_stacked.to(output.device)
 
         else:
-            raise RuntimeError('Forward pass: 0 experts responded within timeout and allow_zero_outputs is False')
+            raise RuntimeError("Forward pass: 0 experts responded within timeout and allow_zero_outputs is False")
 
         mask = torch.zeros([num_samples, max_experts], dtype=torch.bool, device=flat_inputs[0].device)
         mask[batch_inds, expert_inds] = True
 
         # save individual outputs for backward pass
         ctx.save_for_backward(batch_inds, expert_inds, *flat_inputs_cpu)
-        ctx._saved_non_tensors = (info, backward_k_min, backward_timeout, timeout_after_k_min, experts_per_sample,
-                                  detect_anomalies)
+        ctx._saved_non_tensors = (
+            info,
+            backward_k_min,
+            backward_timeout,
+            timeout_after_k_min,
+            experts_per_sample,
+            detect_anomalies,
+        )
 
         return (mask,) + outputs
 
@@ -234,8 +285,14 @@ class _RemoteCallMany(torch.autograd.Function):
     @once_differentiable
     def backward(cls, ctx, *raw_grads: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
         assert not torch.is_grad_enabled()
-        (info, backward_k_min, backward_timeout, timeout_after_k_min, expert_per_sample,
-         detect_anomalies) = ctx._saved_non_tensors
+        (
+            info,
+            backward_k_min,
+            backward_timeout,
+            timeout_after_k_min,
+            expert_per_sample,
+            detect_anomalies,
+        ) = ctx._saved_non_tensors
         alive_ii, alive_jj, *flat_inputs_cpu = ctx.saved_tensors
 
         dummy_grad_mask, *flat_grad_outputs = raw_grads
@@ -249,53 +306,68 @@ class _RemoteCallMany(torch.autograd.Function):
         num_samples, max_experts = dummy_grad_mask.shape
 
         inputs_per_expert = zip(*(tensor[alive_ii].split(1, dim=0) for tensor in flat_inputs_cpu))
-        grad_outputs_per_expert = zip(*(tensor[alive_ii, alive_jj].split(1, dim=0) for tensor in flat_grad_outputs_cpu))
+        grad_outputs_per_expert = zip(
+            *(tensor[alive_ii, alive_jj].split(1, dim=0) for tensor in flat_grad_outputs_cpu)
+        )
         backward_schema = tuple(nested_flatten((info["forward_schema"], info["outputs_schema"])))
 
         # dispatch tasks to all remote experts, collect responses
         pending_tasks = {}
-        for i, j, inputs_ij, grad_outputs_ij in zip(alive_ii.cpu().numpy(), alive_jj.cpu().numpy(),
-                                                    inputs_per_expert, grad_outputs_per_expert):
+        for i, j, inputs_ij, grad_outputs_ij in zip(
+            alive_ii.cpu().numpy(), alive_jj.cpu().numpy(), inputs_per_expert, grad_outputs_per_expert
+        ):
             expert = expert_per_sample[i.item()][j.item()]
             stub = _get_expert_stub(expert.endpoint)
             inputs_and_grad_outputs = tuple(nested_flatten((inputs_ij, grad_outputs_ij)))
-            tensors_serialized = [serialize_torch_tensor(tensor, proto.compression)
-                                  for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)]
+            tensors_serialized = [
+                serialize_torch_tensor(tensor, proto.compression)
+                for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
+            ]
             new_task = stub.backward.future(runtime_pb2.ExpertRequest(uid=expert.uid, tensors=tensors_serialized))
             pending_tasks[new_task] = (i, j)
 
         survivor_inds, survivor_grad_inputs = cls._collect_responses(
-            pending_tasks, num_samples, backward_k_min, backward_timeout, timeout_after_k_min, detect_anomalies)
+            pending_tasks, num_samples, backward_k_min, backward_timeout, timeout_after_k_min, detect_anomalies
+        )
         if len(survivor_inds) < backward_k_min:
             raise TimeoutError(f"Backward pass: less than {backward_k_min} experts responded within timeout.")
 
         # assemble responses
-        batch_inds, expert_inds = map(lambda x: torch.as_tensor(x, dtype=torch.long),
-                                      list(zip(*survivor_inds)) or ([], []))
+        batch_inds, expert_inds = map(
+            lambda x: torch.as_tensor(x, dtype=torch.long), list(zip(*survivor_inds)) or ([], [])
+        )
 
         survivor_grad_inputs_stacked = (torch.cat(grad_inputs) for grad_inputs in zip(*survivor_grad_inputs))
         # torch tensors, i-th tensor is of shape [num_backward_survivors, *flat_inputs_cpu[i].shape]
 
         grad_inputs = nested_map(
             lambda descr: descr.make_empty(num_samples, device=flat_grad_outputs[0].device).zero_(),
-            list(nested_flatten(info['forward_schema'])))
+            list(nested_flatten(info["forward_schema"])),
+        )
 
         for grad_input, survivor_grad_stacked in zip(grad_inputs, survivor_grad_inputs_stacked):
             grad_input_per_expert = torch.zeros(  # gradient tensor with individual contributions from each expert
                 (num_samples, max_experts, *grad_input.shape[1:]),
-                device=survivor_grad_stacked.device, dtype=survivor_grad_stacked.dtype)
+                device=survivor_grad_stacked.device,
+                dtype=survivor_grad_stacked.dtype,
+            )
             grad_input_per_expert[batch_inds, expert_inds] = survivor_grad_stacked
             grad_input.copy_(grad_input_per_expert.to(flat_grad_outputs[0].device).sum(dim=1))
 
         return (DUMMY, None, None, None, None, None, None, None, None, None, *grad_inputs)
 
     @staticmethod
-    def _collect_responses(task_to_indices: Dict[grpc.Future, Tuple[int, int]], num_samples: int, k_min: int,
-                           timeout_total: Optional[float], timeout_after_k_min: Optional[float], detect_anomalies: bool
-                           ) -> Tuple[List[Tuple[int, int]], List[Tuple[torch.Tensor, ...]]]:
-        """ await up to k_min results and any result submitted within timeout_after_k_min, cancel stragglers """
-        timeout_total = float('inf') if timeout_total is None else timeout_total
-        timeout_after_k_min = float('inf') if timeout_after_k_min is None else timeout_after_k_min
+    def _collect_responses(
+        task_to_indices: Dict[grpc.Future, Tuple[int, int]],
+        num_samples: int,
+        k_min: int,
+        timeout_total: Optional[float],
+        timeout_after_k_min: Optional[float],
+        detect_anomalies: bool,
+    ) -> Tuple[List[Tuple[int, int]], List[Tuple[torch.Tensor, ...]]]:
+        """await up to k_min results and any result submitted within timeout_after_k_min, cancel stragglers"""
+        timeout_total = float("inf") if timeout_total is None else timeout_total
+        timeout_after_k_min = float("inf") if timeout_after_k_min is None else timeout_after_k_min
         num_successful_tasks = [0 for _ in range(num_samples)]
         pending_samples = num_samples  # samples for which we have less than k_min results
         finished_indices, finished_outputs = [], []
@@ -309,7 +381,7 @@ class _RemoteCallMany(torch.autograd.Function):
                 task.add_done_callback(finished_tasks.put)
 
             for _ in range(len(task_to_indices)):
-                timeout = max(0.0, t_finish - time.perf_counter()) if t_finish != float('inf') else None
+                timeout = max(0.0, t_finish - time.perf_counter()) if t_finish != float("inf") else None
                 task = finished_tasks.get(timeout=timeout)
                 pending_tasks.discard(task)
 
@@ -323,7 +395,9 @@ class _RemoteCallMany(torch.autograd.Function):
                     num_successful_tasks[sample_index] += 1
                     if num_successful_tasks[sample_index] == k_min:
                         pending_samples -= 1
-                        if pending_samples <= 0:  # all tasks finished, await stragglers for at most timeout_after_k_min
+                        if (
+                            pending_samples <= 0
+                        ):  # all tasks finished, await stragglers for at most timeout_after_k_min
                             t_finish = min(t_finish, time.perf_counter() + timeout_after_k_min)
 
         except Empty:
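
A usage sketch for the reformatted RemoteMixtureOfExperts layer (sizes and the uid prefix are illustrative; a DHT swarm that serves matching experts is assumed):

    import torch
    import hivemind
    from hivemind.moe.client.moe import RemoteMixtureOfExperts

    dht = hivemind.DHT(start=True)  # assumption: connected to peers that run Server instances
    moe = RemoteMixtureOfExperts(
        in_features=512, grid_size=(4, 4), dht=dht, uid_prefix="expert.", k_best=4
    )
    y = moe(torch.randn(8, 512))  # gating projection -> beam search -> _RemoteCallMany -> weighted sum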

+ 81 - 37
hivemind/moe/client/switch_moe.py

@@ -38,16 +38,32 @@ class RemoteSwitchMixtureOfExperts(RemoteMixtureOfExperts):
     :param allow_zero_outputs: whether to return just the input if no experts respond on forward pass
     """
 
-    def __init__(self, *, grid_size: Tuple[int, ...], utilization_alpha: float = 0.9, grid_dropout: float = 1.0,
-                 jitter_eps: float = 1e-2, k_best=1, k_min=0, backward_k_min=0, allow_zero_outputs=True, **kwargs):
-        super().__init__(grid_size=grid_size, k_best=k_best, k_min=k_min, backward_k_min=backward_k_min,
-                         allow_zero_outputs=allow_zero_outputs, **kwargs)
+    def __init__(
+        self,
+        *,
+        grid_size: Tuple[int, ...],
+        utilization_alpha: float = 0.9,
+        grid_dropout: float = 1.0,
+        jitter_eps: float = 1e-2,
+        k_best=1,
+        k_min=0,
+        backward_k_min=0,
+        allow_zero_outputs=True,
+        **kwargs,
+    ):
+        super().__init__(
+            grid_size=grid_size,
+            k_best=k_best,
+            k_min=k_min,
+            backward_k_min=backward_k_min,
+            allow_zero_outputs=allow_zero_outputs,
+            **kwargs,
+        )
 
         initial_utilization = torch.cat(
-            [torch.tensor([1 / dim_size for _ in range(dim_size)], dtype=torch.float)
-             for dim_size in grid_size],
+            [torch.tensor([1 / dim_size for _ in range(dim_size)], dtype=torch.float) for dim_size in grid_size],
         )
-        self.register_buffer('grid_utilization', initial_utilization)
+        self.register_buffer("grid_utilization", initial_utilization)
         self.utilization_alpha = utilization_alpha
         self.grid_dropout = grid_dropout
         self.jitter_eps = jitter_eps
@@ -66,36 +82,56 @@ class RemoteSwitchMixtureOfExperts(RemoteMixtureOfExperts):
         grid_scores = self.proj(input_for_gating).split_with_sizes(self.beam_search.grid_size, dim=-1)
 
         grid_dropout_masks = (
-            (torch.rand(size=(dim_size,), dtype=input_for_gating.dtype, device=input_for_gating.device)
-             < self.grid_dropout) for dim_size in self.beam_search.grid_size
+            (
+                torch.rand(size=(dim_size,), dtype=input_for_gating.dtype, device=input_for_gating.device)
+                < self.grid_dropout
+            )
+            for dim_size in self.beam_search.grid_size
         )
-        grid_scores_dropout = [torch.where(dropout_mask, grid_score,
-                                           torch.full((1,), float('-inf'), device=grid_score.device,
-                                                      dtype=grid_score.dtype))
-                               for grid_score, dropout_mask in zip(grid_scores, grid_dropout_masks)]
+        grid_scores_dropout = [
+            torch.where(
+                dropout_mask,
+                grid_score,
+                torch.full((1,), float("-inf"), device=grid_score.device, dtype=grid_score.dtype),
+            )
+            for grid_score, dropout_mask in zip(grid_scores, grid_dropout_masks)
+        ]
 
         grid_softmax = [torch.softmax(grid_score, dim=-1) for grid_score in grid_scores_dropout]
         chosen_experts: List[List[RemoteExpert]] = self.beam_search.batch_find_best_experts(
-            [scores.detach().cpu() for scores in grid_scores_dropout], self.k_best)
+            [scores.detach().cpu() for scores in grid_scores_dropout], self.k_best
+        )
 
         if self._expert_info is None:
             try:
                 self._expert_info = next((expert.info for experts_i in chosen_experts for expert in experts_i))
             except StopIteration:
-                raise RuntimeError("No responding experts found during beam search. Check that UID prefixes and "
-                                   "the grid size are consistent with running Server instances.")
+                raise RuntimeError(
+                    "No responding experts found during beam search. Check that UID prefixes and "
+                    "the grid size are consistent with running Server instances."
+                )
             except grpc.RpcError as e:
                 logger.warning(f"Failed to get RemoteSwitchMixtureOfExperts.output_shape: {e}")
 
         expert_mask, *expert_outputs = _RemoteCallMany.apply(
-            DUMMY, chosen_experts, self.k_min, self.backward_k_min, self.timeout_after_k_min, self.forward_timeout,
-            self.backward_timeout, self.detect_anomalies, self.allow_zero_outputs, self.info,
-            *nested_flatten(((input, *args), kwargs)))
+            DUMMY,
+            chosen_experts,
+            self.k_min,
+            self.backward_k_min,
+            self.timeout_after_k_min,
+            self.forward_timeout,
+            self.backward_timeout,
+            self.detect_anomalies,
+            self.allow_zero_outputs,
+            self.info,
+            *nested_flatten(((input, *args), kwargs)),
+        )
         # ^-- multiple tensors of shape [batch_size, max_experts, ...output_shape]
 
         batch_utilization = self._compute_batch_utilization(chosen_experts, expert_mask)
-        self.grid_utilization = \
+        self.grid_utilization = (
             self.utilization_alpha * self.grid_utilization + (1 - self.utilization_alpha) * batch_utilization
+        )
 
         # compute expert probabilities as product across grid dimensions
         expert_probs = self.compute_expert_scores(grid_softmax, chosen_experts)
@@ -105,15 +141,21 @@ class RemoteSwitchMixtureOfExperts(RemoteMixtureOfExperts):
         # multiply outputs by expert probabilities
         averaged_outputs_flat = [
             (expert_probs[..., None] * tensor.flatten(start_dim=2)).view(tensor.shape).sum(dim=1)
-            for tensor in expert_outputs]  # ^-- multiply by softmax weights along first 2 axes
+            for tensor in expert_outputs
+        ]  # ^-- multiply by softmax weights along first 2 axes
 
-        packed_outputs = nested_pack(averaged_outputs_flat, self.info['outputs_schema'])
+        packed_outputs = nested_pack(averaged_outputs_flat, self.info["outputs_schema"])
 
         # Load balancing loss: multiply fractions of probability mass and fractions of routed examples
         # for each grid dimension, sum across all indices for a dimension. Optimizing this leads to uniform allocation
-        balancing_loss = torch.stack([torch.mean(dim_softmax.mean(0) * dim_utilization) * (dim_size ** 2)
-                                      for dim_softmax, dim_utilization, dim_size in
-                                      zip(grid_softmax, self.grid_utilization, self.beam_search.grid_size)]).sum()
+        balancing_loss = torch.stack(
+            [
+                torch.mean(dim_softmax.mean(0) * dim_utilization) * (dim_size ** 2)
+                for dim_softmax, dim_utilization, dim_size in zip(
+                    grid_softmax, self.grid_utilization, self.beam_search.grid_size
+                )
+            ]
+        ).sum()
 
         # residual connection
         if isinstance(packed_outputs, torch.Tensor):
@@ -125,26 +167,27 @@ class RemoteSwitchMixtureOfExperts(RemoteMixtureOfExperts):
 
     @torch.no_grad()
     def _compute_batch_utilization(self, batch_experts, expert_mask):
-        batch_utilization = [torch.zeros((dim_size,), dtype=self.grid_utilization.dtype,
-                                         device=self.grid_utilization.device)
-                             for dim_size in self.beam_search.grid_size]
+        batch_utilization = [
+            torch.zeros((dim_size,), dtype=self.grid_utilization.dtype, device=self.grid_utilization.device)
+            for dim_size in self.beam_search.grid_size
+        ]
 
         # out of chosen_experts, select those for which expert_mask is True
         for (sample_idx, expert_idx) in expert_mask.nonzero().cpu().numpy():
             expert = batch_experts[sample_idx][expert_idx]
-            expert_indices = expert.uid[len(self.beam_search.uid_prefix):]
+            expert_indices = expert.uid[len(self.beam_search.uid_prefix) :]
             expert_indices = list(map(int, expert_indices.split(UID_DELIMITER)))
 
             for dim_index, dim_utilization in zip(expert_indices, batch_utilization):
                 dim_utilization[dim_index] += 1
 
-        return torch.cat([
-            torch.nn.functional.normalize(dim_utilization, p=1, dim=0)
-            for dim_utilization in batch_utilization
-        ])
+        return torch.cat(
+            [torch.nn.functional.normalize(dim_utilization, p=1, dim=0) for dim_utilization in batch_utilization]
+        )
 
     def compute_expert_scores(
-            self, grid_probs: List[torch.Tensor], batch_experts: List[List[RemoteExpert]]) -> torch.Tensor:
+        self, grid_probs: List[torch.Tensor], batch_experts: List[List[RemoteExpert]]
+    ) -> torch.Tensor:
         """
         Compute scores for each expert by multiplying grid probabilities, autograd-friendly
         :param grid_probs: list of torch tensors, i-th tensor contains scores for i-th grid dimension
@@ -167,15 +210,16 @@ class RemoteSwitchMixtureOfExperts(RemoteMixtureOfExperts):
 
         grid_indices = torch.zeros([len(flat_experts), len(grid_probs)], dtype=torch.int64)
         for i, expert in enumerate(flat_experts):
-            expert_indices = expert.uid[len(self.beam_search.uid_prefix):]
+            expert_indices = expert.uid[len(self.beam_search.uid_prefix) :]
             expert_indices = list(map(int, expert_indices.split(UID_DELIMITER)))
             grid_indices[i] = torch.as_tensor(expert_indices, dtype=grid_indices.dtype)
 
         scores_per_dim = [
             dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0, device=device)
-            for dim_scores, dim_indices in zip(grid_probs, grid_indices.T)]
+            for dim_scores, dim_indices in zip(grid_probs, grid_indices.T)
+        ]
         flat_scores = torch.prod(torch.stack(scores_per_dim, dim=0), dim=0)
 
-        scores = torch.full((batch_size, max_num_experts), fill_value=-float('inf'), device=device)
+        scores = torch.full((batch_size, max_num_experts), fill_value=-float("inf"), device=device)
         scores[flat_batch_indices, flat_local_indices] = flat_scores  # backprop-able w.r.t. flat_scores
         return scores
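
To make the load-balancing term above easier to follow, here is the same computation written out for one hypothetical grid dimension (plain torch, independent of the diff):

    import torch

    dim_size = 4
    dim_softmax = torch.softmax(torch.randn(8, dim_size), dim=-1)  # router probabilities for a batch of 8
    dim_utilization = torch.tensor([0.5, 0.25, 0.125, 0.125])      # fraction of examples routed to each index

    # probability mass per index * routed fraction per index, averaged and scaled by dim_size ** 2;
    # with perfectly uniform routing both factors equal 1 / dim_size and the term equals exactly 1
    balancing_term = torch.mean(dim_softmax.mean(0) * dim_utilization) * (dim_size ** 2)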

+ 77 - 31
hivemind/moe/server/__init__.py

@@ -50,8 +50,16 @@ class Server(threading.Thread):
     """
 
     def __init__(
-            self, dht: Optional[DHT], expert_backends: Dict[str, ExpertBackend], listen_on: Endpoint = "0.0.0.0:*",
-            num_connection_handlers: int = 1, update_period: int = 30, start=False, checkpoint_dir=None, **kwargs):
+        self,
+        dht: Optional[DHT],
+        expert_backends: Dict[str, ExpertBackend],
+        listen_on: Endpoint = "0.0.0.0:*",
+        num_connection_handlers: int = 1,
+        update_period: int = 30,
+        start=False,
+        checkpoint_dir=None,
+        **kwargs,
+    ):
         super().__init__()
         self.dht, self.experts, self.update_period = dht, expert_backends, update_period
         if get_port(listen_on) is None:
@@ -66,19 +74,44 @@ class Server(threading.Thread):
         self.runtime = Runtime(self.experts, **kwargs)
 
         if self.dht and self.experts:
-            self.dht_handler_thread = DHTHandlerThread(experts=self.experts, dht=self.dht, endpoint=self.listen_on,
-                                                       update_period=self.update_period, daemon=True)
+            self.dht_handler_thread = DHTHandlerThread(
+                experts=self.experts,
+                dht=self.dht,
+                endpoint=self.listen_on,
+                update_period=self.update_period,
+                daemon=True,
+            )
 
         if start:
             self.run_in_background(await_ready=True)
 
     @classmethod
-    def create(cls, listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str = None, expert_pattern: str = None,
-               expert_cls='ffn', hidden_dim=1024, optim_cls=torch.optim.Adam, scheduler: str = 'none',
-               num_warmup_steps=None, num_total_steps=None, clip_grad_norm=None, num_handlers=None, min_batch_size=1,
-               max_batch_size=4096, device=None, no_dht=False, initial_peers=(),
-               checkpoint_dir: Optional[Path] = None, compression=CompressionType.NONE,
-               stats_report_interval: Optional[int] = None, custom_module_path=None, *, start: bool) -> Server:
+    def create(
+        cls,
+        listen_on="0.0.0.0:*",
+        num_experts: int = None,
+        expert_uids: str = None,
+        expert_pattern: str = None,
+        expert_cls="ffn",
+        hidden_dim=1024,
+        optim_cls=torch.optim.Adam,
+        scheduler: str = "none",
+        num_warmup_steps=None,
+        num_total_steps=None,
+        clip_grad_norm=None,
+        num_handlers=None,
+        min_batch_size=1,
+        max_batch_size=4096,
+        device=None,
+        no_dht=False,
+        initial_peers=(),
+        checkpoint_dir: Optional[Path] = None,
+        compression=CompressionType.NONE,
+        stats_report_interval: Optional[int] = None,
+        custom_module_path=None,
+        *,
+        start: bool,
+    ) -> Server:
         """
         Instantiate a server with several identical experts. See argparse comments below for details
         :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
@@ -121,22 +154,24 @@ class Server(threading.Thread):
             dht = hivemind.DHT(initial_peers=initial_peers, start=True)
             logger.info(f"Running DHT node on {dht.get_visible_maddrs()}, initial peers = {initial_peers}")
 
-        assert ((expert_pattern is None and num_experts is None and expert_uids is not None) or
-                (num_experts is not None and expert_uids is None)), \
-            "Please provide either expert_uids *or* num_experts (possibly with expert_pattern), but not both"
+        assert (expert_pattern is None and num_experts is None and expert_uids is not None) or (
+            num_experts is not None and expert_uids is None
+        ), "Please provide either expert_uids *or* num_experts (possibly with expert_pattern), but not both"
 
         if expert_uids is None:
             if checkpoint_dir is not None:
                 assert is_directory(checkpoint_dir)
-                expert_uids = [child.name for child in checkpoint_dir.iterdir() if
-                               (child / 'checkpoint_last.pt').exists()]
+                expert_uids = [
+                    child.name for child in checkpoint_dir.iterdir() if (child / "checkpoint_last.pt").exists()
+                ]
                 total_experts_in_checkpoint = len(expert_uids)
                 logger.info(f"Located {total_experts_in_checkpoint} checkpoints for experts {expert_uids}")
 
                 if total_experts_in_checkpoint > num_experts:
                     raise ValueError(
                         f"Found {total_experts_in_checkpoint} checkpoints, but num_experts is set to {num_experts}, "
-                        f"which is smaller. Either increase num_experts or remove unneeded checkpoints.")
+                        f"which is smaller. Either increase num_experts or remove unneeded checkpoints."
+                    )
             else:
                 expert_uids = []
 
@@ -148,7 +183,7 @@ class Server(threading.Thread):
         num_experts = len(expert_uids)
         num_handlers = num_handlers if num_handlers is not None else num_experts * 8
         optim_cls = optim_cls if optim_cls is not None else partial(torch.optim.SGD, lr=0.0)
-        device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
+        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
 
         sample_input = name_to_input[expert_cls](3, hidden_dim)
         if isinstance(sample_input, tuple):
@@ -162,21 +197,32 @@ class Server(threading.Thread):
         experts = {}
         for expert_uid in expert_uids:
             expert = name_to_block[expert_cls](hidden_dim)
-            experts[expert_uid] = hivemind.ExpertBackend(name=expert_uid, expert=expert,
-                                                         args_schema=args_schema,
-                                                         optimizer=optim_cls(expert.parameters()),
-                                                         scheduler=scheduler,
-                                                         num_warmup_steps=num_warmup_steps,
-                                                         num_total_steps=num_total_steps,
-                                                         clip_grad_norm=clip_grad_norm,
-                                                         min_batch_size=min_batch_size,
-                                                         max_batch_size=max_batch_size)
+            experts[expert_uid] = hivemind.ExpertBackend(
+                name=expert_uid,
+                expert=expert,
+                args_schema=args_schema,
+                optimizer=optim_cls(expert.parameters()),
+                scheduler=scheduler,
+                num_warmup_steps=num_warmup_steps,
+                num_total_steps=num_total_steps,
+                clip_grad_norm=clip_grad_norm,
+                min_batch_size=min_batch_size,
+                max_batch_size=max_batch_size,
+            )
 
         if checkpoint_dir is not None:
             load_experts(experts, checkpoint_dir)
 
-        return cls(dht, experts, listen_on=listen_on, num_connection_handlers=num_handlers, device=device,
-                   checkpoint_dir=checkpoint_dir, stats_report_interval=stats_report_interval, start=start)
+        return cls(
+            dht,
+            experts,
+            listen_on=listen_on,
+            num_connection_handlers=num_handlers,
+            device=device,
+            checkpoint_dir=checkpoint_dir,
+            stats_report_interval=stats_report_interval,
+            start=start,
+        )
 
     def run(self):
         """
@@ -263,7 +309,7 @@ class Server(threading.Thread):
 
 @contextmanager
 def background_server(*args, shutdown_timeout=5, **kwargs) -> Tuple[hivemind.Endpoint, List[Multiaddr]]:
-    """ A context manager that creates server in a background thread, awaits .ready on entry and shutdowns on exit """
+    """A context manager that creates server in a background thread, awaits .ready on entry and shutdowns on exit"""
     pipe, runners_pipe = mp.Pipe(duplex=True)
     runner = mp.Process(target=_server_runner, args=(runners_pipe, *args), kwargs=kwargs)
     try:
@@ -273,7 +319,7 @@ def background_server(*args, shutdown_timeout=5, **kwargs) -> Tuple[hivemind.End
         start_ok, data = pipe.recv()
         if start_ok:
             yield data
-            pipe.send('SHUTDOWN')  # on exit from context, send shutdown signal
+            pipe.send("SHUTDOWN")  # on exit from context, send shutdown signal
         else:
             raise RuntimeError(f"Server failed to start: {data}")
     finally:
@@ -289,7 +335,7 @@ def _server_runner(pipe, *args, **kwargs):
         server = Server.create(*args, start=True, **kwargs)
     except Exception as e:
         logger.exception(f"Encountered an exception when starting a server: {e}")
-        pipe.send((False, f'{type(e).__name__} {e}'))
+        pipe.send((False, f"{type(e).__name__} {e}"))
         return
 
     try:
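
For reference, the background_server helper reformatted above is typically used as a context manager (expert uids and sizes are placeholders):

    import torch
    from hivemind.moe.server import background_server
    from hivemind.moe.client.expert import RemoteExpert

    # spawn a throwaway server with two feed-forward experts in a background process
    with background_server(expert_uids=["expert.0", "expert.1"], expert_cls="ffn", hidden_dim=64) as (endpoint, dht_maddrs):
        expert = RemoteExpert("expert.0", endpoint)
        print(expert(torch.randn(1, 64)).shape)
    # leaving the context sends the "SHUTDOWN" signal and joins the server process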

+ 6 - 6
hivemind/moe/server/checkpoints.py

@@ -51,16 +51,16 @@ class CheckpointSaver(threading.Thread):
 
 
 def store_experts(experts: Dict[str, ExpertBackend], checkpoint_dir: Path):
-    logger.debug(f'Storing experts at {checkpoint_dir.absolute()}')
+    logger.debug(f"Storing experts at {checkpoint_dir.absolute()}")
     assert is_directory(checkpoint_dir)
-    timestamp = datetime.now().isoformat(sep='_')
+    timestamp = datetime.now().isoformat(sep="_")
     with TemporaryDirectory() as tmpdirname:
         for expert_name, expert_backend in experts.items():
             expert_dir = Path(tmpdirname) / expert_name
             expert_dir.mkdir()
-            checkpoint_name = expert_dir / f'checkpoint_{timestamp}.pt'
+            checkpoint_name = expert_dir / f"checkpoint_{timestamp}.pt"
             torch.save(expert_backend.get_full_state(), checkpoint_name)
-            os.symlink(checkpoint_name, expert_dir / 'checkpoint_last.pt')
+            os.symlink(checkpoint_name, expert_dir / "checkpoint_last.pt")
         copy_tree(tmpdirname, str(checkpoint_dir))
 
 
@@ -68,8 +68,8 @@ def load_experts(experts: Dict[str, ExpertBackend], checkpoint_dir: Path):
     assert is_directory(checkpoint_dir)
     for expert_name, expert in experts.items():
         checkpoints_folder = checkpoint_dir / expert_name
-        latest_checkpoint = checkpoints_folder / 'checkpoint_last.pt'
+        latest_checkpoint = checkpoints_folder / "checkpoint_last.pt"
         if latest_checkpoint.exists():
             expert.load_full_state(torch.load(latest_checkpoint))
         else:
-            logger.warning(f'Failed to load checkpoint for expert {expert_name}')
+            logger.warning(f"Failed to load checkpoint for expert {expert_name}")
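
A sketch of the checkpoint helpers above, wired to a single throwaway backend (the name, sizes and directory are placeholders, and BatchTensorDescriptor(64) is assumed to describe a [batch, 64] input as in the library docs):

    import torch
    from pathlib import Path
    import hivemind
    from hivemind.utils.tensor_descr import BatchTensorDescriptor
    from hivemind.moe.server.checkpoints import store_experts, load_experts

    expert = torch.nn.Linear(64, 64)
    backend = hivemind.ExpertBackend(
        name="expert.0",
        expert=expert,
        optimizer=torch.optim.Adam(expert.parameters()),
        args_schema=(BatchTensorDescriptor(64),),
        outputs_schema=BatchTensorDescriptor(64),
        max_batch_size=16,
    )

    checkpoint_dir = Path("checkpoints")
    checkpoint_dir.mkdir(exist_ok=True)
    store_experts({"expert.0": backend}, checkpoint_dir)  # writes checkpoint_<timestamp>.pt + checkpoint_last.pt symlink
    load_experts({"expert.0": backend}, checkpoint_dir)   # restores each expert from its checkpoint_last.pt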

+ 18 - 11
hivemind/moe/server/connection_handler.py

@@ -36,12 +36,15 @@ class ConnectionHandler(mp.context.ForkProcess):
 
         async def _run():
             grpc.aio.init_grpc_aio()
-            logger.debug(f'Starting, pid {os.getpid()}')
-            server = grpc.aio.server(options=GRPC_KEEPALIVE_OPTIONS + (
-                ('grpc.so_reuseport', 1),
-                ('grpc.max_send_message_length', -1),
-                ('grpc.max_receive_message_length', -1)
-            ))
+            logger.debug(f"Starting, pid {os.getpid()}")
+            server = grpc.aio.server(
+                options=GRPC_KEEPALIVE_OPTIONS
+                + (
+                    ("grpc.so_reuseport", 1),
+                    ("grpc.max_send_message_length", -1),
+                    ("grpc.max_receive_message_length", -1),
+                )
+            )
             runtime_grpc.add_ConnectionHandlerServicer_to_server(self, server)
 
             found_port = server.add_insecure_port(self.listen_on)
@@ -55,7 +58,7 @@ class ConnectionHandler(mp.context.ForkProcess):
         try:
             loop.run_until_complete(_run())
         except KeyboardInterrupt:
-            logger.debug('Caught KeyboardInterrupt, shutting down')
+            logger.debug("Caught KeyboardInterrupt, shutting down")
 
     async def info(self, request: runtime_pb2.ExpertUID, context: grpc.ServicerContext):
         return runtime_pb2.ExpertInfo(serialized_info=pickle.dumps(self.experts[request.uid].get_info()))
@@ -63,14 +66,18 @@ class ConnectionHandler(mp.context.ForkProcess):
     async def forward(self, request: runtime_pb2.ExpertRequest, context: grpc.ServicerContext):
         inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         future = self.experts[request.uid].forward_pool.submit_task(*inputs)
-        serialized_response = [serialize_torch_tensor(tensor, proto.compression, allow_inplace=True) for tensor, proto
-                               in zip(await future, nested_flatten(self.experts[request.uid].outputs_schema))]
+        serialized_response = [
+            serialize_torch_tensor(tensor, proto.compression, allow_inplace=True)
+            for tensor, proto in zip(await future, nested_flatten(self.experts[request.uid].outputs_schema))
+        ]
 
         return runtime_pb2.ExpertResponse(tensors=serialized_response)
 
     async def backward(self, request: runtime_pb2.ExpertRequest, context: grpc.ServicerContext):
         inputs_and_grad_outputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         future = self.experts[request.uid].backward_pool.submit_task(*inputs_and_grad_outputs)
-        serialized_response = [serialize_torch_tensor(tensor, proto.compression, allow_inplace=True) for tensor, proto
-                               in zip(await future, nested_flatten(self.experts[request.uid].grad_inputs_schema))]
+        serialized_response = [
+            serialize_torch_tensor(tensor, proto.compression, allow_inplace=True)
+            for tensor, proto in zip(await future, nested_flatten(self.experts[request.uid].grad_inputs_schema))
+        ]
         return runtime_pb2.ExpertResponse(tensors=serialized_response)

+ 26 - 13
hivemind/moe/server/dht_handler.py

@@ -4,8 +4,16 @@ from typing import Sequence, Dict, List, Tuple, Optional
 
 from hivemind.dht import DHT, DHTNode, DHTExpiration, DHTValue
 from hivemind.moe.client.expert import RemoteExpert
-from hivemind.moe.server.expert_uid import (ExpertUID, ExpertPrefix, FLAT_EXPERT, Coordinate,
-                                            UID_DELIMITER, UID_PATTERN, is_valid_uid, split_uid)
+from hivemind.moe.server.expert_uid import (
+    ExpertUID,
+    ExpertPrefix,
+    FLAT_EXPERT,
+    Coordinate,
+    UID_DELIMITER,
+    UID_PATTERN,
+    is_valid_uid,
+    split_uid,
+)
 from hivemind.utils import Endpoint, get_dht_time, get_port
 
 
@@ -25,8 +33,9 @@ class DHTHandlerThread(threading.Thread):
             declare_experts(self.dht, self.experts.keys(), self.endpoint)
 
 
-def declare_experts(dht: DHT, uids: Sequence[ExpertUID], endpoint: Endpoint, expiration: DHTExpiration = 300,
-                    wait: bool = True) -> Dict[ExpertUID, bool]:
+def declare_experts(
+    dht: DHT, uids: Sequence[ExpertUID], endpoint: Endpoint, expiration: DHTExpiration = 300, wait: bool = True
+) -> Dict[ExpertUID, bool]:
     """
     Make experts visible to all DHT peers; update timestamps if declared previously.
 
@@ -39,18 +48,20 @@ def declare_experts(dht: DHT, uids: Sequence[ExpertUID], endpoint: Endpoint, exp
     assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
     for uid in uids:
         assert is_valid_uid(uid), f"{uid} is not a valid expert uid. All uids must follow {UID_PATTERN.pattern}"
-    return dht.run_coroutine(partial(_declare_experts, uids=list(uids), endpoint=endpoint, expiration=expiration),
-                             return_future=not wait)
+    return dht.run_coroutine(
+        partial(_declare_experts, uids=list(uids), endpoint=endpoint, expiration=expiration), return_future=not wait
+    )
 
 
-async def _declare_experts(dht: DHT, node: DHTNode, uids: List[ExpertUID], endpoint: Endpoint,
-                           expiration: DHTExpiration) -> Dict[ExpertUID, bool]:
+async def _declare_experts(
+    dht: DHT, node: DHTNode, uids: List[ExpertUID], endpoint: Endpoint, expiration: DHTExpiration
+) -> Dict[ExpertUID, bool]:
     num_workers = len(uids) if dht.max_workers is None else min(len(uids), dht.max_workers)
     expiration_time = get_dht_time() + expiration
     data_to_store: Dict[Tuple[ExpertPrefix, Optional[Coordinate]], DHTValue] = {}
     for uid in uids:
         data_to_store[uid, None] = endpoint
-        prefix = uid if uid.count(UID_DELIMITER) > 1 else f'{uid}{UID_DELIMITER}{FLAT_EXPERT}'
+        prefix = uid if uid.count(UID_DELIMITER) > 1 else f"{uid}{UID_DELIMITER}{FLAT_EXPERT}"
         for i in range(prefix.count(UID_DELIMITER) - 1):
             prefix, last_coord = split_uid(prefix)
             data_to_store[prefix, last_coord] = [uid, endpoint]
@@ -60,8 +71,9 @@ async def _declare_experts(dht: DHT, node: DHTNode, uids: List[ExpertUID], endpo
     return store_ok
 
 
-def get_experts(dht: DHT, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration] = None,
-                return_future: bool = False) -> List[Optional[RemoteExpert]]:
+def get_experts(
+    dht: DHT, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration] = None, return_future: bool = False
+) -> List[Optional[RemoteExpert]]:
     """
     :param uids: find experts with these ids from across the DHT
     :param expiration_time: if specified, return experts that expire no sooner than this (based on get_dht_time)
@@ -72,8 +84,9 @@ def get_experts(dht: DHT, uids: List[ExpertUID], expiration_time: Optional[DHTEx
     return dht.run_coroutine(partial(_get_experts, uids=list(uids), expiration_time=expiration_time), return_future)
 
 
-async def _get_experts(dht: DHT, node: DHTNode, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration]
-                       ) -> List[Optional[RemoteExpert]]:
+async def _get_experts(
+    dht: DHT, node: DHTNode, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration]
+) -> List[Optional[RemoteExpert]]:
     if expiration_time is None:
         expiration_time = get_dht_time()
     num_workers = len(uids) if dht.max_workers is None else min(len(uids), dht.max_workers)
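
For reference, the reformatted declare_experts / get_experts pair is used roughly as follows (uid and endpoint are placeholders; both calls must run against the same DHT swarm):

    import hivemind
    from hivemind.moe.server.dht_handler import declare_experts, get_experts

    dht = hivemind.DHT(start=True)
    # announce a (placeholder) expert served at 127.0.0.1:1337 for the default 300-second expiration
    declare_experts(dht, ["expert.0.0"], endpoint="127.0.0.1:1337")

    (expert,) = get_experts(dht, ["expert.0.0"])  # RemoteExpert instances, or None for uids not found
    print(expert.uid, expert.endpoint)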

+ 68 - 47
hivemind/moe/server/expert_backend.py

@@ -4,7 +4,7 @@ import torch
 from torch import nn
 
 from hivemind.moe.server.task_pool import TaskPool
-from hivemind.utils import BatchTensorDescriptor, DUMMY_BATCH_SIZE
+from hivemind.utils.tensor_descr import BatchTensorDescriptor, DUMMY_BATCH_SIZE
 from hivemind.utils.logging import get_logger
 from hivemind.utils.nested import nested_flatten, nested_pack, nested_compare, nested_map
 
@@ -40,13 +40,21 @@ class ExpertBackend:
     :param kwargs: extra parameters to be forwarded into TaskPool.__init__
     """
 
-    def __init__(self, name: str, expert: nn.Module, optimizer: torch.optim.Optimizer, *,
-                 scheduler: Callable = None,
-                 args_schema: Tuple[BatchTensorDescriptor, ...] = None,
-                 kwargs_schema: Dict[str, BatchTensorDescriptor] = None,
-                 outputs_schema: Union[BatchTensorDescriptor, Tuple[BatchTensorDescriptor, ...]] = None,
-                 num_warmup_steps: int = None, num_total_steps: int = None, clip_grad_norm: float = None,
-                 **kwargs):
+    def __init__(
+        self,
+        name: str,
+        expert: nn.Module,
+        optimizer: torch.optim.Optimizer,
+        *,
+        scheduler: Callable = None,
+        args_schema: Tuple[BatchTensorDescriptor, ...] = None,
+        kwargs_schema: Dict[str, BatchTensorDescriptor] = None,
+        outputs_schema: Union[BatchTensorDescriptor, Tuple[BatchTensorDescriptor, ...]] = None,
+        num_warmup_steps: int = None,
+        num_total_steps: int = None,
+        clip_grad_norm: float = None,
+        **kwargs,
+    ):
         super().__init__()
         self.expert, self.optimizer, self.name = expert, optimizer, name
 
@@ -59,8 +67,10 @@ class ExpertBackend:
 
         self.args_schema = args_schema = tuple(args_schema or ())
         self.kwargs_schema = kwargs_schema = dict(kwargs_schema or {})
-        assert args_schema or kwargs_schema, "expert must receive at least one positional or keyword input." \
-                                             " Did you forget to provide args_schema/kwargs_schema?"
+        assert args_schema or kwargs_schema, (
+            "expert must receive at least one positional or keyword input."
+            " Did you forget to provide args_schema/kwargs_schema?"
+        )
 
         if outputs_schema is None:
             # run expert once to get outputs schema
@@ -74,8 +84,8 @@ class ExpertBackend:
 
         self.backward_schema = (self.forward_schema, self.outputs_schema)  # inputs to backward
         self.grad_inputs_schema = self.forward_schema  # outputs from backward
-        self.forward_pool = TaskPool(self.forward, name=f'{self.name}_forward', **kwargs)
-        self.backward_pool = TaskPool(self.backward, name=f'{self.name}_backward', **kwargs)
+        self.forward_pool = TaskPool(self.forward, name=f"{self.name}_forward", **kwargs)
+        self.backward_pool = TaskPool(self.backward, name=f"{self.name}_backward", **kwargs)
 
         self.update_count = 0
         self.examples_processed = 0
@@ -125,11 +135,16 @@ class ExpertBackend:
         (args, kwargs), grad_outputs = nested_pack(inputs, structure=self.backward_schema)
 
         with torch.enable_grad():
-            args = [tensor.detach().requires_grad_(True) if tensor.dtype in (torch.half, torch.float, torch.double)
-                    else tensor.detach() for tensor in args]
-            kwargs = {input_key: (tensor.detach().requires_grad_(True)
-                                  if tensor.is_floating_point() else tensor.detach())
-                      for input_key, tensor in kwargs.items()}
+            args = [
+                tensor.detach().requires_grad_(True)
+                if tensor.dtype in (torch.half, torch.float, torch.double)
+                else tensor.detach()
+                for tensor in args
+            ]
+            kwargs = {
+                input_key: (tensor.detach().requires_grad_(True) if tensor.is_floating_point() else tensor.detach())
+                for input_key, tensor in kwargs.items()
+            }
 
             batch_size = args[0].size(0)
 
@@ -138,15 +153,21 @@ class ExpertBackend:
 
             outputs_flat = tuple(nested_flatten(outputs))
 
-            grad_outputs_flat = tuple(map(
-                lambda grad, out: grad.to(device=out.device, dtype=out.dtype, non_blocking=True),
-                nested_flatten(grad_outputs), outputs_flat))
-            torch.autograd.backward(outputs_flat, grad_tensors=grad_outputs_flat,
-                                    create_graph=False, retain_graph=False)
+            grad_outputs_flat = tuple(
+                map(
+                    lambda grad, out: grad.to(device=out.device, dtype=out.dtype, non_blocking=True),
+                    nested_flatten(grad_outputs),
+                    outputs_flat,
+                )
+            )
+            torch.autograd.backward(
+                outputs_flat, grad_tensors=grad_outputs_flat, create_graph=False, retain_graph=False
+            )
             self.apply_gradients(batch_size)
 
-        return tuple(x.grad if isinstance(x.grad, torch.Tensor) else torch.zeros_like(x)
-                     for x in nested_flatten((args, kwargs)))
+        return tuple(
+            x.grad if isinstance(x.grad, torch.Tensor) else torch.zeros_like(x) for x in nested_flatten((args, kwargs))
+        )
 
     def apply_gradients(self, batch_size) -> None:
         """
@@ -168,47 +189,47 @@ class ExpertBackend:
         """
         Return current expert training statistics (number of updates, number of processed examples after last optimizer step)
         """
-        return {
-            'updates': self.update_count,
-            'examples_processed': self.examples_processed
-        }
+        return {"updates": self.update_count, "examples_processed": self.examples_processed}
 
     def get_full_state(self) -> Dict:
         """
         Return the current state of the expert (including batch processing statistics)
         """
         full_state = {
-            'stats': self.get_stats(),
-            'model': self.expert.state_dict(),
-            'optimizer': self.optimizer.state_dict(),
-            'scheduler': {} if self.scheduler is None else self.scheduler.state_dict()
+            "stats": self.get_stats(),
+            "model": self.expert.state_dict(),
+            "optimizer": self.optimizer.state_dict(),
+            "scheduler": {} if self.scheduler is None else self.scheduler.state_dict(),
         }
         return full_state
 
     def load_full_state(self, state_dict: Dict):
-        if 'stats' in state_dict:
-            self.update_count = state_dict['stats']['updates']
-            self.examples_processed = state_dict['stats']['examples_processed']
+        if "stats" in state_dict:
+            self.update_count = state_dict["stats"]["updates"]
+            self.examples_processed = state_dict["stats"]["examples_processed"]
         else:
-            logger.warning(f'Batch processing stats missing for expert {self.name}')
+            logger.warning(f"Batch processing stats missing for expert {self.name}")
 
-        self.expert.load_state_dict(state_dict['model'])
+        self.expert.load_state_dict(state_dict["model"])
 
-        if 'optimizer' in state_dict:
-            self.optimizer.load_state_dict(state_dict['optimizer'])
+        if "optimizer" in state_dict:
+            self.optimizer.load_state_dict(state_dict["optimizer"])
         else:
-            logger.warning(f'Optimizer state missing for expert {self.name}')
+            logger.warning(f"Optimizer state missing for expert {self.name}")
 
-        if self.scheduler is not None and 'scheduler' in state_dict:
-            self.scheduler.load_state_dict(state_dict['scheduler'])
+        if self.scheduler is not None and "scheduler" in state_dict:
+            self.scheduler.load_state_dict(state_dict["scheduler"])
         else:
-            logger.warning(f'Learning rate scheduler state missing for expert {self.name}')
+            logger.warning(f"Learning rate scheduler state missing for expert {self.name}")
 
     def get_info(self) -> Dict[str, Any]:
-        """ Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration. """
-        return dict(forward_schema=self.forward_schema, outputs_schema=self.outputs_schema,
-                    keyword_names=tuple(self.kwargs_schema.keys()))
+        """Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""
+        return dict(
+            forward_schema=self.forward_schema,
+            outputs_schema=self.outputs_schema,
+            keyword_names=tuple(self.kwargs_schema.keys()),
+        )
 
     def get_pools(self) -> Sequence[TaskPool]:
-        """ return all pools that should be processed by ``Runtime`` """
+        """return all pools that should be processed by ``Runtime``"""
         return self.forward_pool, self.backward_pool
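
As a point of reference for the constructor reformatted above, a minimal usage sketch (not part of this commit; the uid, hidden size and max_batch_size are illustrative). `name_to_block["ffn"]` refers to the built-in feedforward expert registered in hivemind/moe/server/layers/common.py.

    import torch
    from hivemind.moe.server.expert_backend import ExpertBackend
    from hivemind.moe.server.layers import name_to_block
    from hivemind.utils.tensor_descr import BatchTensorDescriptor

    hid_dim = 512  # illustrative hidden size
    expert = name_to_block["ffn"](hid_dim)  # built-in FeedforwardBlock
    backend = ExpertBackend(
        name="ffn_expert.0",
        expert=expert,
        optimizer=torch.optim.Adam(expert.parameters()),
        args_schema=(BatchTensorDescriptor(hid_dim),),  # one positional input of shape [batch_size, hid_dim]
        outputs_schema=BatchTensorDescriptor(hid_dim),  # providing this skips the dummy forward pass used to infer outputs
        max_batch_size=4096,  # extra kwargs are forwarded to the forward/backward TaskPool
    )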

+ 22 - 16
hivemind/moe/server/expert_uid.py

@@ -9,33 +9,34 @@ from hivemind.utils import Endpoint, get_logger
 logger = get_logger(__name__)
 
 ExpertUID, ExpertPrefix, Coordinate, Score = str, str, int, float
-UidEndpoint = NamedTuple("UidEndpoint", [('uid', ExpertUID), ('endpoint', Endpoint)])
-UID_DELIMITER = '.'  # when declaring experts, DHT store all prefixes of that expert's uid, split over this prefix
+UidEndpoint = NamedTuple("UidEndpoint", [("uid", ExpertUID), ("endpoint", Endpoint)])
+UID_DELIMITER = "."  # when declaring experts, the DHT stores all prefixes of an expert's uid, split by this delimiter
 FLAT_EXPERT = -1  # grid prefix reserved for storing 1d expert uids. Used to speed up find_best_experts in 1d case.
-UID_PATTERN = re.compile('^(([^.])+)([.](?:[0]|([1-9]([0-9]*))))+$')  # e.g. ffn_expert.98.76.54 - prefix + some dims
-PREFIX_PATTERN = re.compile('^(([^.])+)([.](?:[0]|([1-9]([0-9]*))))*[.]$')  # e.g. expert. or ffn.45. (ends with ".")
+UID_PATTERN = re.compile("^(([^.])+)([.](?:[0]|([1-9]([0-9]*))))+$")  # e.g. ffn_expert.98.76.54 - prefix + some dims
+PREFIX_PATTERN = re.compile("^(([^.])+)([.](?:[0]|([1-9]([0-9]*))))*[.]$")  # e.g. expert. or ffn.45. (ends with ".")
 #  formally, prefixes = {uid.split(UID_DELIMITER)[:length] for length in range(1, uid.count(UID_DELIMITER) + 2)}
 
 
 def is_valid_uid(maybe_uid: str) -> bool:
-    """ An uid must contain a string expert type, followed by one or more .-separated numeric indices """
+    """An uid must contain a string expert type, followed by one or more .-separated numeric indices"""
     return bool(UID_PATTERN.fullmatch(maybe_uid))
 
 
 def is_valid_prefix(maybe_prefix: str) -> bool:
-    """ An uid prefix must contain a string expert type, followed by optional numeric indices and a trailing period """
+    """An uid prefix must contain a string expert type, followed by optional numeric indices and a trailing period"""
     return bool(PREFIX_PATTERN.fullmatch(maybe_prefix))
 
 
 def split_uid(uid_or_prefix: Union[ExpertUID, ExpertPrefix]) -> Tuple[ExpertPrefix, Coordinate]:
-    """ Separate an expert UID or prefix into a new ExpertPrefix and integer for the last coordinate """
+    """Separate an expert UID or prefix into a new ExpertPrefix and integer for the last coordinate"""
     uid_or_prefix = uid_or_prefix.rstrip(UID_DELIMITER)
     pivot = uid_or_prefix.rindex(UID_DELIMITER) + 1
     return uid_or_prefix[:pivot], int(uid_or_prefix[pivot:])
 
 
-def generate_uids_from_pattern(num_experts: int, expert_pattern: Optional[str], dht: Optional[DHT] = None,
-                               attempts_per_expert=10) -> List[str]:
+def generate_uids_from_pattern(
+    num_experts: int, expert_pattern: Optional[str], dht: Optional[DHT] = None, attempts_per_expert=10
+) -> List[str]:
     """
     Sample experts from a given pattern, remove duplicates.
     :param num_experts: sample this many unique expert uids
@@ -56,10 +57,10 @@ def generate_uids_from_pattern(num_experts: int, expert_pattern: Optional[str],
         uid = []
         for block in expert_pattern.split(UID_DELIMITER):
             try:
-                if '[' not in block and ']' not in block:
+                if "[" not in block and "]" not in block:
                     uid.append(block)
-                elif block.startswith('[') and block.endswith(']') and ':' in block:
-                    slice_start, slice_end = map(int, block[1:-1].split(':'))
+                elif block.startswith("[") and block.endswith("]") and ":" in block:
+                    slice_start, slice_end = map(int, block[1:-1].split(":"))
                     uid.append(str(random.randint(slice_start, slice_end - 1)))
                 else:
                     raise ValueError("Block must be either fixed or a range [from:to]")
@@ -82,13 +83,18 @@ def generate_uids_from_pattern(num_experts: int, expert_pattern: Optional[str],
 
         # 2. look into DHT (if given) and remove duplicates
         if dht:
-            existing_expert_uids = {found_expert.uid for found_expert in hivemind.moe.server.get_experts(dht, new_uids)
-                                    if found_expert is not None}
+            existing_expert_uids = {
+                found_expert.uid
+                for found_expert in hivemind.moe.server.get_experts(dht, new_uids)
+                if found_expert is not None
+            }
             new_uids = [new_uid for new_uid in new_uids if new_uid not in existing_expert_uids]
 
         found_uids += new_uids
 
     if len(found_uids) != num_experts:
-        logger.warning(f"Found only {len(found_uids)} out of {num_experts} free expert uids after "
-                       f"{attempts_per_expert * num_experts} attempts")
+        logger.warning(
+            f"Found only {len(found_uids)} out of {num_experts} free expert uids after "
+            f"{attempts_per_expert * num_experts} attempts"
+        )
     return found_uids
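
To illustrate the uid grammar enforced by UID_PATTERN and PREFIX_PATTERN above, a small sketch with made-up uids (not part of this commit):

    from hivemind.moe.server.expert_uid import is_valid_uid, is_valid_prefix, split_uid, generate_uids_from_pattern

    assert is_valid_uid("ffn_expert.98.76.54")  # expert type followed by dot-separated coordinates
    assert not is_valid_uid("ffn_expert.")  # a trailing delimiter makes this a prefix, not a uid
    assert is_valid_prefix("ffn_expert.98.")  # prefixes end with the delimiter
    assert split_uid("ffn_expert.98.76.54") == ("ffn_expert.98.76.", 54)

    # sample three unique uids whose last coordinate is drawn from [0, 256)
    uids = generate_uids_from_pattern(3, "ffn_expert.[0:256]")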

+ 1 - 1
hivemind/moe/server/layers/__init__.py

@@ -6,4 +6,4 @@ import hivemind.moe.server.layers.dropout
 from hivemind.moe.server.layers.custom_experts import add_custom_models_from_file, register_expert_class
 from hivemind.moe.server.layers.lr_schedule import get_linear_schedule_with_warmup
 
-schedule_name_to_scheduler = {'linear': get_linear_schedule_with_warmup, 'none': None}
+schedule_name_to_scheduler = {"linear": get_linear_schedule_with_warmup, "none": None}

+ 8 - 11
hivemind/moe/server/layers/common.py

@@ -15,9 +15,8 @@ def gelu_fast(x):
 ffn_sample_input = lambda batch_size, hid_dim: torch.empty((batch_size, hid_dim))
 
 
-@register_expert_class('ffn', ffn_sample_input)
+@register_expert_class("ffn", ffn_sample_input)
 class FeedforwardBlock(nn.Module):
-
     def __init__(self, hid_dim):
         super().__init__()
         self.ffn = nn.Linear(hid_dim, 4 * hid_dim)
@@ -67,14 +66,14 @@ class TransformerEncoderLayer(nn.Module):
         return src
 
 
-transformer_sample_input = lambda batch_size, hid_dim: \
-    (torch.empty((batch_size, 128, hid_dim)), \
-     torch.empty((batch_size, 128), dtype=torch.bool))
+transformer_sample_input = lambda batch_size, hid_dim: (
+    torch.empty((batch_size, 128, hid_dim)),
+    torch.empty((batch_size, 128), dtype=torch.bool),
+)
 
 
-@register_expert_class('transformer', transformer_sample_input)
+@register_expert_class("transformer", transformer_sample_input)
 class TunedTransformer(TransformerEncoderLayer):
-
     def __init__(self, hid_dim):
         super().__init__(hid_dim, dim_feedforward=4 * hid_dim, nhead=16)
 
@@ -82,9 +81,8 @@ class TunedTransformer(TransformerEncoderLayer):
 nop_sample_input = lambda batch_size, hid_dim: torch.empty((batch_size, hid_dim))
 
 
-@register_expert_class('nop', nop_sample_input)
+@register_expert_class("nop", nop_sample_input)
 class NopExpert(nn.Sequential):
-
     def __init__(self, hid_dim):
         super().__init__()
         self.w = nn.Parameter(torch.zeros(0), requires_grad=True)
@@ -93,9 +91,8 @@ class NopExpert(nn.Sequential):
         return x.clone()
 
 
-@register_expert_class('nop_delay', nop_sample_input)
+@register_expert_class("nop_delay", nop_sample_input)
 class DelayedNopExpert(nn.Sequential):
-
     def __init__(self, hid_dim, delay=0.5):
         super().__init__()
         self.w = nn.Parameter(torch.zeros(0), requires_grad=True)

+ 3 - 2
hivemind/moe/server/layers/custom_experts.py

@@ -9,8 +9,7 @@ from hivemind.moe.server.layers import name_to_block, name_to_input
 
 
 def add_custom_models_from_file(path: str):
-    spec = importlib.util.spec_from_file_location(
-        "custom_module", os.path.abspath(path))
+    spec = importlib.util.spec_from_file_location("custom_module", os.path.abspath(path))
     foo = importlib.util.module_from_spec(spec)
     spec.loader.exec_module(foo)
 
@@ -24,6 +23,7 @@ def register_expert_class(name: str, sample_input: Callable[[int, int], torch.te
         sample of an input in the module
     :returns: the decorated expert class, unchanged
     """
+
     def _register_expert_class(custom_class: Type[nn.Module]):
         if name in name_to_block or name in name_to_input:
             raise RuntimeError("The class might already exist or be added twice")
@@ -31,4 +31,5 @@ def register_expert_class(name: str, sample_input: Callable[[int, int], torch.te
         name_to_input[name] = sample_input
 
         return custom_class
+
     return _register_expert_class
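
The decorator above is how the built-in layers register themselves; a hypothetical custom expert (the "perceptron" name and module below are illustrative, not from this commit) follows the same pattern:

    import torch
    from torch import nn
    from hivemind.moe.server.layers.custom_experts import register_expert_class

    # sample_input builds a dummy input batch of the right shape for this expert type
    perceptron_sample_input = lambda batch_size, hid_dim: torch.empty((batch_size, hid_dim))

    @register_expert_class("perceptron", perceptron_sample_input)
    class Perceptron(nn.Module):
        def __init__(self, hid_dim):
            super().__init__()
            self.layer = nn.Linear(hid_dim, hid_dim)

        def forward(self, x):
            return torch.relu(self.layer(x))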

+ 8 - 5
hivemind/moe/server/layers/dropout.py

@@ -5,7 +5,6 @@ from hivemind.moe.server.layers.custom_experts import register_expert_class
 
 
 class DeterministicDropoutFunction(torch.autograd.Function):
-
     @staticmethod
     def forward(ctx, x, keep_prob, mask):
         ctx.keep_prob = keep_prob
@@ -33,11 +32,15 @@ class DeterministicDropout(nn.Module):
         else:
             return x
 
-dropout_sample_input = lambda batch_size, hid_dim: \
-    (torch.empty((batch_size, hid_dim)), torch.randint(0, 1, (batch_size, hid_dim)))
-@register_expert_class('det_dropout', dropout_sample_input)
-class DeterministicDropoutNetwork(nn.Module):
 
+dropout_sample_input = lambda batch_size, hid_dim: (
+    torch.empty((batch_size, hid_dim)),
+    torch.randint(0, 1, (batch_size, hid_dim)),
+)
+
+
+@register_expert_class("det_dropout", dropout_sample_input)
+class DeterministicDropoutNetwork(nn.Module):
     def __init__(self, hid_dim, dropout_prob=0.2):
         super().__init__()
         self.linear_in = nn.Linear(hid_dim, 2 * hid_dim)

+ 25 - 14
hivemind/moe/server/runtime.py

@@ -41,10 +41,17 @@ class Runtime(threading.Thread):
 
     :param stats_report_interval: interval to collect and log statistics about runtime performance
     """
+
     SHUTDOWN_TRIGGER = "RUNTIME SHUTDOWN TRIGGERED"
 
-    def __init__(self, expert_backends: Dict[str, ExpertBackend], prefetch_batches=64, sender_threads: int = 1,
-                 device: torch.device = None, stats_report_interval: Optional[int] = None):
+    def __init__(
+        self,
+        expert_backends: Dict[str, ExpertBackend],
+        prefetch_batches=64,
+        sender_threads: int = 1,
+        device: torch.device = None,
+        stats_report_interval: Optional[int] = None,
+    ):
         super().__init__()
         self.expert_backends = expert_backends
         self.pools = tuple(chain(*(expert.get_pools() for expert in expert_backends.values())))
@@ -73,7 +80,8 @@ class Runtime(threading.Thread):
                 logger.info("Started")
 
                 for pool, batch_index, batch in BackgroundGenerator(
-                        self.iterate_minibatches_from_pools(), self.prefetch_batches):
+                    self.iterate_minibatches_from_pools(), self.prefetch_batches
+                ):
                     logger.debug(f"Processing batch {batch_index} from pool {pool.name}")
 
                     start = time()
@@ -92,7 +100,7 @@ class Runtime(threading.Thread):
                     self.shutdown()
 
     def shutdown(self):
-        """ Gracefully terminate a running runtime. """
+        """Gracefully terminate a running runtime."""
         logger.info("Shutting down")
         self.ready.clear()
 
@@ -137,7 +145,7 @@ class Runtime(threading.Thread):
                 yield pool, batch_index, batch_tensors
 
 
-BatchStats = NamedTuple('BatchStats', (('batch_size', int), ('processing_time', float)))
+BatchStats = NamedTuple("BatchStats", (("batch_size", int), ("processing_time", float)))
 
 
 class StatsReporter(threading.Thread):
@@ -155,23 +163,26 @@ class StatsReporter(threading.Thread):
                 pool_batch_stats[pool_uid].append(batch_stats)
 
             total_processed_batches = sum(len(pool_stats) for pool_stats in pool_batch_stats.values())
-            logger.info(f'Processed {total_processed_batches} batches in last {self.report_interval} seconds:')
+            logger.info(f"Processed {total_processed_batches} batches in last {self.report_interval} seconds:")
             for pool_uid, pool_stats in pool_batch_stats.items():
                 total_batches = len(pool_stats)
                 total_examples = sum(batch_stats.batch_size for batch_stats in pool_stats)
                 avg_batch_size = mean(batch_stats.batch_size for batch_stats in pool_stats)
                 total_time = sum(batch_stats.processing_time for batch_stats in pool_stats)
                 batches_to_time = total_batches / total_time
-                batch_performance = f'{batches_to_time:.2f} ' + ('batches/s' if batches_to_time > 1 else 's/batch')
+                batch_performance = f"{batches_to_time:.2f} " + ("batches/s" if batches_to_time > 1 else "s/batch")
 
                 examples_to_time = total_examples / total_time
-                example_performance = f'{examples_to_time:.2f} ' + (
-                    'examples/s' if examples_to_time > 1 else 's/example')
-
-                logger.info(f'{pool_uid}: '
-                            f'{total_batches} batches ({batch_performance}), '
-                            f'{total_examples} examples ({example_performance}), '
-                            f'avg batch size {avg_batch_size:.2f}')
+                example_performance = f"{examples_to_time:.2f} " + (
+                    "examples/s" if examples_to_time > 1 else "s/example"
+                )
+
+                logger.info(
+                    f"{pool_uid}: "
+                    f"{total_batches} batches ({batch_performance}), "
+                    f"{total_examples} examples ({example_performance}), "
+                    f"avg batch size {avg_batch_size:.2f}"
+                )
 
     def report_stats(self, pool_uid, batch_size, processing_time):
         batch_stats = BatchStats(batch_size, processing_time)
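
For context, a construction-only sketch of the Runtime reformatted above (assuming `backend` is an existing ExpertBackend instance; not taken from this commit):

    from hivemind.moe.server.runtime import Runtime

    runtime = Runtime(expert_backends={"ffn_expert.0": backend}, prefetch_batches=64, stats_report_interval=60)
    # runtime.run() then processes batches from every backend's forward/backward pool
    # until runtime.shutdown() is called; normally hivemind's Server manages this thread.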

+ 33 - 19
hivemind/moe/server/task_pool.py

@@ -22,7 +22,7 @@ Task = namedtuple("Task", ("future", "args"))
 
 
 class TaskPoolBase(mp.context.ForkProcess, metaclass=ABCMeta):
-    """ A pool that accepts tasks and forms batches for parallel processing, interacts with Runtime """
+    """A pool that accepts tasks and forms batches for parallel processing, interacts with Runtime"""
 
     def __init__(self, process_func: callable, daemon=True, **kwargs):
         super().__init__(daemon=daemon, **kwargs)
@@ -71,8 +71,18 @@ class TaskPool(TaskPoolBase):
     :param start: if True, start automatically at the end of __init__
     """
 
-    def __init__(self, process_func: callable, max_batch_size: int, name: str, min_batch_size=1,
-                 timeout=None, pool_size=None, prefetch_batches=1, daemon=True, start=False):
+    def __init__(
+        self,
+        process_func: callable,
+        max_batch_size: int,
+        name: str,
+        min_batch_size=1,
+        timeout=None,
+        pool_size=None,
+        prefetch_batches=1,
+        daemon=True,
+        start=False,
+    ):
         super().__init__(process_func, daemon=daemon, name=name)
         self.min_batch_size, self.max_batch_size, self.timeout = min_batch_size, max_batch_size, timeout
         self.prefetch_batches = prefetch_batches
@@ -89,7 +99,7 @@ class TaskPool(TaskPoolBase):
             self.start()
 
     def submit_task(self, *args: torch.Tensor) -> Future:
-        """ Add task to this pool's queue, return Future for its output """
+        """Add task to this pool's queue, return Future for its output"""
         task = Task(MPFuture(), args)
         if self.get_task_size(task) > self.max_batch_size:
             exc = ValueError(f"Task size greater than max_batch_size ({self.max_batch_size}), it can't be processed")
@@ -100,7 +110,7 @@ class TaskPool(TaskPoolBase):
         return task.future
 
     def iterate_minibatches(self, *args, **kwargs):
-        """ Form minibatches by grouping one or more tasks together up to self.max_batch_size """
+        """Form minibatches by grouping one or more tasks together up to self.max_batch_size"""
         batch = []
         total_size = 0
 
@@ -132,22 +142,23 @@ class TaskPool(TaskPoolBase):
 
     def run(self, *args, **kwargs):
         torch.set_num_threads(1)
-        logger.info(f'{self.name} starting, pid={os.getpid()}')
+        logger.info(f"{self.name} starting, pid={os.getpid()}")
         pending_batches = {}  # Dict[batch uuid, List[MPFuture]] for each batch currently in runtime
 
-        output_thread = threading.Thread(target=self._pool_output_loop, args=[pending_batches],
-                                         name=f'{self.name}_output', daemon=True)
+        output_thread = threading.Thread(
+            target=self._pool_output_loop, args=[pending_batches], name=f"{self.name}_output", daemon=True
+        )
 
         try:
             output_thread.start()
             self._pool_input_loop(pending_batches, *args, **kwargs)
         except KeyboardInterrupt:
-            logger.debug('Caught KeyboardInterrupt, shutting down')
+            logger.debug("Caught KeyboardInterrupt, shutting down")
         finally:
             output_thread.join()
 
     def _pool_input_loop(self, pending_batches: Dict[Any, List[Task]], *args, **kwargs):
-        """ Infinite loop: aggregate tasks into batches and send them to runtime """
+        """Infinite loop: aggregate tasks into batches and send them to runtime"""
 
         prev_num_tasks = 0  # number of tasks currently in shared buffer
         batch_index = max(pending_batches.keys(), default=0)
@@ -157,7 +168,9 @@ class TaskPool(TaskPoolBase):
             # SIDE-EFFECT - compute pool priority from timestamp of earliest undispatched task
             # assumes that tasks are processed in the same order as they are created
             for skip_i in range(prev_num_tasks):
-                finished_task_timestamp = self.undispatched_task_timestamps.get()  # earlier timestamp = higher priority
+                finished_task_timestamp = (
+                    self.undispatched_task_timestamps.get()
+                )  # earlier timestamp = higher priority
                 if skip_i == prev_num_tasks - 1:
                     self.priority = finished_task_timestamp
 
@@ -168,8 +181,7 @@ class TaskPool(TaskPoolBase):
 
             logger.debug(f"{self.name}, batch  {batch_index}: aggregating inputs")
             # find or create shared arrays for current batch size
-            batch_inputs = [torch.cat([task.args[i] for task in batch_tasks]) for i in
-                            range(len(batch_tasks[0].args))]
+            batch_inputs = [torch.cat([task.args[i] for task in batch_tasks]) for i in range(len(batch_tasks[0].args))]
             batch_inputs = [inp.detach().requires_grad_(inp.requires_grad).share_memory_() for inp in batch_inputs]
 
             logger.debug(f"{self.name}, batch {batch_index}: sending to runtime")
@@ -179,7 +191,7 @@ class TaskPool(TaskPoolBase):
             batch_index += 1
 
     def _pool_output_loop(self, pending_batches: Dict[Any, List[Task]]):
-        """ Infinite loop: receive results from runtime and dispatch them to task Futures """
+        """Infinite loop: receive results from runtime and dispatch them to task Futures"""
 
         while True:
             logger.debug(f"{self.name} waiting for results from runtime")
@@ -204,7 +216,7 @@ class TaskPool(TaskPoolBase):
         return not self.batch_receiver.poll()
 
     def load_batch_to_runtime(self, timeout=None, device=None) -> Tuple[Any, List[torch.Tensor]]:
-        """ receive next batch of numpy arrays """
+        """receive next batch of numpy arrays"""
         if not self.batch_receiver.poll(timeout):
             raise TimeoutError()
 
@@ -213,11 +225,13 @@ class TaskPool(TaskPoolBase):
         return batch_index, batch_inputs
 
     def send_outputs_from_runtime(self, batch_index: int, batch_outputs: List[torch.Tensor]):
-        """ send results for a processed batch, previously loaded through load_batch_to_runtime """
-        batch_outputs = [tensor.to(device='cpu').share_memory_().detach().requires_grad_(tensor.requires_grad)
-                         for tensor in batch_outputs]
+        """send results for a processed batch, previously loaded through load_batch_to_runtime"""
+        batch_outputs = [
+            tensor.to(device="cpu").share_memory_().detach().requires_grad_(tensor.requires_grad)
+            for tensor in batch_outputs
+        ]
         self.outputs_sender.send((batch_index, batch_outputs))
 
     def get_task_size(self, task: Task) -> int:
-        """ compute task processing complexity (used for batching); defaults to batch size """
+        """compute task processing complexity (used for batching); defaults to batch size"""
         return len(task.args[0]) if task.args else 1
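
A short sketch of the client-facing half of TaskPool (not part of this commit): `square_batch` is a stand-in process_func, and in a real server the Runtime consumes batches via load_batch_to_runtime / send_outputs_from_runtime.

    import torch
    from hivemind.moe.server.task_pool import TaskPool

    def square_batch(batch: torch.Tensor):
        return (batch ** 2,)  # process_func returns a sequence of output tensors for the whole batch

    pool = TaskPool(square_batch, max_batch_size=16, name="square_pool")
    future = pool.submit_task(torch.randn(4, 8))  # an MPFuture that resolves once a Runtime processes the batch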

+ 11 - 4
hivemind/optim/adaptive.py

@@ -21,7 +21,14 @@ class CollaborativeAdaptiveOptimizer(CollaborativeOptimizer):
         super().__init__(opt, average_opt_statistics=average_opt_statistics, **kwargs)
 
     def _make_averager(self, average_opt_statistics, **kwargs):
-        return TrainingAverager(self.opt, dht=self.dht, average_parameters=True, average_gradients=False,
-                                average_opt_statistics=average_opt_statistics,
-                                prefix=f"{self.prefix}_averaging", allreduce_timeout=self.averaging_timeout,
-                                listen=not self.client_mode, **kwargs)
+        return TrainingAverager(
+            self.opt,
+            dht=self.dht,
+            average_parameters=True,
+            average_gradients=False,
+            average_opt_statistics=average_opt_statistics,
+            prefix=f"{self.prefix}_averaging",
+            allreduce_timeout=self.averaging_timeout,
+            listen=not self.client_mode,
+            **kwargs,
+        )

+ 6 - 3
hivemind/optim/base.py

@@ -4,7 +4,8 @@ from hivemind.dht import DHT
 
 
 class DecentralizedOptimizerBase(torch.optim.Optimizer):
-    """ A shared interface for all hivemind optimizers. Cooperates with DHT peers to train a shared model """
+    """A shared interface for all hivemind optimizers. Cooperates with DHT peers to train a shared model"""
+
     def __init__(self, opt: torch.optim.Optimizer, dht: DHT):
         self.opt, self.dht = opt, dht
 
@@ -17,8 +18,10 @@ class DecentralizedOptimizerBase(torch.optim.Optimizer):
         return self.opt.param_groups
 
     def add_param_group(self, param_group: dict) -> None:
-        raise ValueError(f"{self.__class__.__name__} does not support calling add_param_group after creation."
-                         f"Please provide all parameter groups at init.")
+        raise ValueError(
+            f"{self.__class__.__name__} does not support calling add_param_group after creation."
+            f"Please provide all parameter groups at init."
+        )
 
     def state_dict(self) -> dict:
         return self.opt.state_dict()

+ 128 - 64
hivemind/optim/collaborative.py

@@ -18,7 +18,7 @@ from hivemind.optim.performance_ema import PerformanceEMA
 from hivemind.utils import Endpoint, get_dht_time, get_logger
 
 logger = get_logger(__name__)
-LRSchedulerBase = getattr(torch.optim.lr_scheduler, '_LRScheduler', None)
+LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
 
 
 @dataclass(frozen=False)
@@ -38,7 +38,7 @@ class CollaborationState:
     def register_step(self, local_step: int):
         self.optimizer_step = max(local_step, self.optimizer_step)
         self.samples_accumulated = 0
-        self.eta_next_step = float('inf')
+        self.eta_next_step = float("inf")
 
 
 class TrainingState(BaseModel):
@@ -97,26 +97,45 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
       explicitly into this class. Otherwise, scheduler may not be synchronized between peers.
     """
 
-    def __init__(self, opt: torch.optim.Optimizer, *, dht: DHT, prefix: str, target_batch_size: int,
-                 batch_size_per_step: Optional[int] = None, scheduler: Optional[LRSchedulerBase] = None,
-                 min_refresh_period: float = 0.5, max_refresh_period: float = 30, default_refresh_period: float = 3,
-                 expected_drift_peers: float = 3, expected_drift_rate: float = 0.2, performance_ema_alpha: float = 0.1,
-                 metadata_expiration: float = 60.0, averaging_timeout: Optional[float] = None, step_tolerance: int = 1,
-                 reuse_grad_buffers: bool = False, accumulate_grads_on: Optional[torch.device] = None,
-                 client_mode: bool = False, verbose: bool = False, **kwargs):
+    def __init__(
+        self,
+        opt: torch.optim.Optimizer,
+        *,
+        dht: DHT,
+        prefix: str,
+        target_batch_size: int,
+        batch_size_per_step: Optional[int] = None,
+        scheduler: Optional[LRSchedulerBase] = None,
+        min_refresh_period: float = 0.5,
+        max_refresh_period: float = 30,
+        default_refresh_period: float = 3,
+        expected_drift_peers: float = 3,
+        expected_drift_rate: float = 0.2,
+        performance_ema_alpha: float = 0.1,
+        metadata_expiration: float = 60.0,
+        averaging_timeout: Optional[float] = None,
+        step_tolerance: int = 1,
+        reuse_grad_buffers: bool = False,
+        accumulate_grads_on: Optional[torch.device] = None,
+        client_mode: bool = False,
+        verbose: bool = False,
+        **kwargs,
+    ):
         super().__init__(opt, dht)
 
         signature_validator = RSASignatureValidator()
         self._local_public_key = signature_validator.local_public_key
-        dht.add_validators([SchemaValidator(TrainingProgressSchema, prefix=prefix),
-                            signature_validator])
+        dht.add_validators([SchemaValidator(TrainingProgressSchema, prefix=prefix), signature_validator])
 
         if reuse_grad_buffers and accumulate_grads_on is not None:
             logger.warning("Setting 'accumulate_grads_on' has no effect if reuse_grad_buffers=True")
         self.prefix, self.scheduler = prefix, scheduler
         self.target_batch_size, self.batch_size_per_step = target_batch_size, batch_size_per_step
-        self.min_refresh_period, self.max_refresh_period, self.default_refresh_period = \
-            min_refresh_period, max_refresh_period, default_refresh_period
+        self.min_refresh_period, self.max_refresh_period, self.default_refresh_period = (
+            min_refresh_period,
+            max_refresh_period,
+            default_refresh_period,
+        )
         self.expected_drift_peers, self.expected_drift_rate = expected_drift_peers, expected_drift_rate
         self.averaging_timeout, self.metadata_expiration = averaging_timeout, metadata_expiration
         self._grads, self.reuse_grad_buffers, self.accumulate_grads_on = None, reuse_grad_buffers, accumulate_grads_on
@@ -135,14 +154,22 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         self.lock_local_progress, self.should_report_progress = Lock(), Event()
         self.progress_reporter = Thread(target=self.report_training_progress, daemon=True, name=f"{self}.reporter")
         self.progress_reporter.start()
-        self.collaboration_state_updater = Thread(target=self.check_collaboration_state_periodically, daemon=True,
-                                                  name=f"{self}.collaboration_state_updater")
+        self.collaboration_state_updater = Thread(
+            target=self.check_collaboration_state_periodically, daemon=True, name=f"{self}.collaboration_state_updater"
+        )
         self.collaboration_state_updater.start()
 
     def _make_averager(self, **kwargs):
-        return TrainingAverager(self.opt, dht=self.dht, average_parameters=True, average_gradients=True,
-                                prefix=f"{self.prefix}_averaging", allreduce_timeout=self.averaging_timeout,
-                                listen=not self.client_mode, **kwargs)
+        return TrainingAverager(
+            self.opt,
+            dht=self.dht,
+            average_parameters=True,
+            average_gradients=True,
+            prefix=f"{self.prefix}_averaging",
+            allreduce_timeout=self.averaging_timeout,
+            listen=not self.client_mode,
+            **kwargs,
+        )
 
     @property
     def local_step(self) -> int:
@@ -156,7 +183,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         return self.averager.is_alive()
 
     def load_state_from_peers(self, **kwargs):
-        """ Attempt to fetch the newest collaboration state from other peers """
+        """Attempt to fetch the newest collaboration state from other peers"""
         with self.lock_collaboration_state:
             self.averager.load_state_from_peers(**kwargs)
             self.local_samples_accumulated = self.local_steps_accumulated = 0
@@ -183,8 +210,10 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             return
 
         if self.last_step_time is not None and get_dht_time() - self.last_step_time > self.metadata_expiration:
-            logger.warning(f"Training step took {get_dht_time() - self.last_step_time}, "
-                           f"but metadata expired in {self.metadata_expiration} s.")
+            logger.warning(
+                f"Training step took {get_dht_time() - self.last_step_time}, "
+                f"but metadata expired in {self.metadata_expiration} s."
+            )
 
         self.accumulate_grads_(batch_size)
 
@@ -207,7 +236,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
         with self.performance_ema.pause(), self.lock_collaboration_state:
             # divide accumulators by local steps to recover the true average grad w.r.t. local_samples_accumulated
-            self.apply_accumulated_grads_(scale_by=1. / self.local_steps_accumulated)
+            self.apply_accumulated_grads_(scale_by=1.0 / self.local_steps_accumulated)
             current_step, group_info = self.averager.local_step, None
 
             if self.collaboration_state.num_peers > 1:
@@ -221,8 +250,10 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                     logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}.")
 
             else:
-                logger.log(self.status_loglevel, f"Skipped averaging: collaboration consists of "
-                                                 f"{self.collaboration_state.num_peers} peer(s).")
+                logger.log(
+                    self.status_loglevel,
+                    f"Skipped averaging: collaboration consists of " f"{self.collaboration_state.num_peers} peer(s).",
+                )
 
             self.opt.step()
             self.reset_accumulated_grads_()
@@ -246,8 +277,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         if not self.collaboration_state.ready_for_step:
             return
 
-        logger.log(self.status_loglevel,
-                   f"Beginning global optimizer step {self.collaboration_state.optimizer_step}")
+        logger.log(self.status_loglevel, f"Beginning global optimizer step {self.collaboration_state.optimizer_step}")
         self.collaboration_state = self.fetch_collaboration_state()
         self.collaboration_state_updated.set()
 
@@ -257,8 +287,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             try:
                 group_info = self.averager.step(timeout=self.averaging_timeout, **kwargs)
                 if group_info:
-                    logger.log(self.status_loglevel,
-                               f"Averaged tensors successfully with {len(group_info)} peers")
+                    logger.log(self.status_loglevel, f"Averaged tensors successfully with {len(group_info)} peers")
             except BaseException as e:
                 logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}.")
 
@@ -271,9 +300,9 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         return group_info
 
     def _grad_buffers(self) -> Iterator[torch.Tensor]:
-        """ pytorch-internal gradient buffers """
+        """pytorch-internal gradient buffers"""
         for param_group in self.opt.param_groups:
-            for param in param_group['params']:
+            for param in param_group["params"]:
                 if param.grad is None:
                     yield torch.zeros_like(param)
                 else:
@@ -281,17 +310,19 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
     @torch.no_grad()
     def accumulated_grads(self) -> Iterator[torch.Tensor]:
-        """ local gradient accumulators """
+        """local gradient accumulators"""
         if self.reuse_grad_buffers:
             yield from self._grad_buffers()
         elif self._grads is None:
             with torch.no_grad():
-                self._grads = [torch.zeros_like(grad, device=self.accumulate_grads_on) for grad in self._grad_buffers()]
+                self._grads = [
+                    torch.zeros_like(grad, device=self.accumulate_grads_on) for grad in self._grad_buffers()
+                ]
         yield from self._grads
 
     @torch.no_grad()
     def accumulate_grads_(self, batch_size: int):
-        """ add current gradients to grad accumulators (if any) """
+        """add current gradients to grad accumulators (if any)"""
         if self.reuse_grad_buffers:
             return  # user is responsible for accumulating gradients in .grad buffers
         alpha = float(batch_size) / self.batch_size_per_step
@@ -316,7 +347,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                 grad_buf.zero_()
 
     def report_training_progress(self):
-        """ Periodically publish metadata and the current number of samples accumulated towards the next step """
+        """Periodically publish metadata and the current number of samples accumulated towards the next step"""
         while self.is_alive():
             self.should_report_progress.wait()
             self.should_report_progress.clear()
@@ -328,12 +359,16 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                     samples_accumulated=self.local_samples_accumulated,
                     samples_per_second=self.performance_ema.samples_per_second,
                     time=current_time,
-                    client_mode=not self.averager.listen)
+                    client_mode=not self.averager.listen,
+                )
 
-            self.dht.store(key=self.training_progress_key, subkey=self._local_public_key,
-                           value=local_state_info.dict(),
-                           expiration_time=current_time + self.metadata_expiration,
-                           return_future=True)
+            self.dht.store(
+                key=self.training_progress_key,
+                subkey=self._local_public_key,
+                value=local_state_info.dict(),
+                expiration_time=current_time + self.metadata_expiration,
+                return_future=True,
+            )
 
     def check_collaboration_state_periodically(self):
         """
@@ -349,21 +384,30 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                 self.collaboration_state = self.fetch_collaboration_state()
 
     def fetch_collaboration_state(self) -> CollaborationState:
-        """ Read performance statistics reported by peers, estimate progress towards next batch """
-        response, _expiration = self.dht.get(self.training_progress_key, latest=True) or (None, -float('inf'))
+        """Read performance statistics reported by peers, estimate progress towards next batch"""
+        response, _expiration = self.dht.get(self.training_progress_key, latest=True) or (None, -float("inf"))
         current_time = get_dht_time()
 
         if not isinstance(response, dict) or len(response) == 0:
             logger.log(self.status_loglevel, f"Found no active peers: {response}")
-            local_eta_next_step = max(0, self.target_batch_size - self.local_steps_accumulated
-                                      ) / self.performance_ema.samples_per_second
-            return CollaborationState(self.local_step, self.local_samples_accumulated, self.target_batch_size,
-                                      num_peers=0, num_clients=0, eta_next_step=current_time + local_eta_next_step,
-                                      next_fetch_time=current_time + self.default_refresh_period)
-
-        valid_peer_states = [TrainingState.parse_obj(peer_state.value)
-                             for peer_state in response.values()
-                             if peer_state.value is not None]
+            local_eta_next_step = (
+                max(0, self.target_batch_size - self.local_steps_accumulated) / self.performance_ema.samples_per_second
+            )
+            return CollaborationState(
+                self.local_step,
+                self.local_samples_accumulated,
+                self.target_batch_size,
+                num_peers=0,
+                num_clients=0,
+                eta_next_step=current_time + local_eta_next_step,
+                next_fetch_time=current_time + self.default_refresh_period,
+            )
+
+        valid_peer_states = [
+            TrainingState.parse_obj(peer_state.value)
+            for peer_state in response.values()
+            if peer_state.value is not None
+        ]
 
         num_peers = len(valid_peer_states)
         num_clients = sum(state.client_mode for state in valid_peer_states)
@@ -378,8 +422,9 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             total_samples_per_second += state.samples_per_second
             if state.step == global_optimizer_step:
                 total_samples_accumulated += state.samples_accumulated
-                estimated_current_samples += (state.samples_accumulated +
-                                              max(0, current_time - state.time) * state.samples_per_second)
+                estimated_current_samples += (
+                    state.samples_accumulated + max(0, current_time - state.time) * state.samples_per_second
+                )
             # note: we deliberately count only valid peers for samples_accumulated, but all peers for performance;
             # the rationale behind this is that outdated peers will synchronize and begin contributing shortly.
 
@@ -387,20 +432,35 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         estimated_time_to_next_step = max(0, estimated_samples_remaining) / total_samples_per_second
 
         expected_max_peers = max(num_peers + self.expected_drift_peers, num_peers * (1 + self.expected_drift_rate))
-        time_to_next_fetch = float(np.clip(a=estimated_time_to_next_step * num_peers / expected_max_peers,
-                                           a_min=self.min_refresh_period, a_max=self.max_refresh_period))
-        logger.log(self.status_loglevel, f"Collaboration accumulated {total_samples_accumulated} samples from "
-                                         f"{num_peers} peers; ETA {estimated_time_to_next_step:.2f} seconds "
-                                         f"(refresh in {time_to_next_fetch:.2f}s.)")
+        time_to_next_fetch = float(
+            np.clip(
+                a=estimated_time_to_next_step * num_peers / expected_max_peers,
+                a_min=self.min_refresh_period,
+                a_max=self.max_refresh_period,
+            )
+        )
+        logger.log(
+            self.status_loglevel,
+            f"Collaboration accumulated {total_samples_accumulated} samples from "
+            f"{num_peers} peers; ETA {estimated_time_to_next_step:.2f} seconds "
+            f"(refresh in {time_to_next_fetch:.2f}s.)",
+        )
         return CollaborationState(
-            global_optimizer_step, total_samples_accumulated, target_batch_size=self.target_batch_size,
-            num_peers=num_peers, num_clients=num_clients, eta_next_step=current_time + estimated_time_to_next_step,
-            next_fetch_time=current_time + time_to_next_fetch)
+            global_optimizer_step,
+            total_samples_accumulated,
+            target_batch_size=self.target_batch_size,
+            num_peers=num_peers,
+            num_clients=num_clients,
+            eta_next_step=current_time + estimated_time_to_next_step,
+            next_fetch_time=current_time + time_to_next_fetch,
+        )
 
     def zero_grad(self, *args, **kwargs):
         if self.reuse_grad_buffers:
-            raise ValueError(f"When running {self.__class__.__name__} with reuse_grad_buffers=True, user should never "
-                             f"call zero_grad manually. Gradients will be refreshed internally.")
+            raise ValueError(
+                f"When running {self.__class__.__name__} with reuse_grad_buffers=True, user should never "
+                f"call zero_grad manually. Gradients will be refreshed internally."
+            )
         return self.opt.zero_grad(*args, **kwargs)
 
     def update_scheduler(self):
@@ -412,8 +472,12 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         logger.debug("Shutting down averager...")
         self.averager.shutdown()
         logger.debug("Sending goodbye to peers...")
-        self.dht.store(self.training_progress_key, subkey=self._local_public_key, value=None,
-                       expiration_time=get_dht_time() + self.metadata_expiration)
+        self.dht.store(
+            self.training_progress_key,
+            subkey=self._local_public_key,
+            value=None,
+            expiration_time=get_dht_time() + self.metadata_expiration,
+        )
         logger.debug(f"{self.__class__.__name__} is shut down.")
 
     def __del__(self):
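
A hedged sketch of wiring up the constructor reformatted above; the model, run name and batch sizes are illustrative, and a real collaboration would also pass initial_peers to the DHT (not part of this commit):

    import torch
    import hivemind
    from hivemind.optim.collaborative import CollaborativeOptimizer

    model = torch.nn.Linear(512, 512)
    dht = hivemind.DHT(start=True)
    opt = CollaborativeOptimizer(
        opt=torch.optim.Adam(model.parameters(), lr=1e-3),
        dht=dht,
        prefix="my_run",  # namespaces this run's progress metadata and averaging group keys in the DHT
        target_batch_size=4096,  # global number of samples per collaborative optimizer step
        batch_size_per_step=32,  # samples contributed by each local .step() call
        verbose=True,
    )

    loss = model(torch.randn(32, 512)).pow(2).mean()
    loss.backward()
    opt.step()  # accumulates gradients locally; averaging runs once the collaboration reaches target_batch_size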

+ 2 - 1
hivemind/optim/performance_ema.py

@@ -8,6 +8,7 @@ class PerformanceEMA:
     A running estimate of performance (operations/sec) using adjusted exponential moving average
     :param alpha: Smoothing factor in range [0, 1], [default: 0.1].
     """
+
     def __init__(self, alpha: float = 0.1, eps: float = 1e-20):
         self.alpha, self.eps, self.num_updates = alpha, eps, 0
         self.ema_seconds_per_sample, self.samples_per_second = 0, eps
@@ -31,7 +32,7 @@ class PerformanceEMA:
 
     @contextmanager
     def pause(self):
-        """ While inside this context, EMA will not count the time passed towards the performance estimate """
+        """While inside this context, EMA will not count the time passed towards the performance estimate"""
         self.paused, was_paused = True, self.paused
         try:
             yield

+ 90 - 24
hivemind/optim/simple.py

@@ -34,25 +34,46 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
     :note: the base optimizer cannot add param groups after the DecentralizedOptimizer is created
     """
 
-    def __init__(self, opt: torch.optim.Optimizer, dht: DHT, *, prefix: str, target_group_size: int,
-                 average_parameters: bool, average_gradients: bool, average_opt_statistics: Sequence[str] = (),
-                 averaging_steps_period: int = 1, averaging_time_period: float = 0,
-                 timeout: Optional[float] = None, verbose: bool = False, **kwargs):
+    def __init__(
+        self,
+        opt: torch.optim.Optimizer,
+        dht: DHT,
+        *,
+        prefix: str,
+        target_group_size: int,
+        average_parameters: bool,
+        average_gradients: bool,
+        average_opt_statistics: Sequence[str] = (),
+        averaging_steps_period: int = 1,
+        averaging_time_period: float = 0,
+        timeout: Optional[float] = None,
+        verbose: bool = False,
+        **kwargs,
+    ):
         super().__init__(opt, dht)
         assert averaging_steps_period > 0 and averaging_time_period >= 0, "Averaging period must be positive."
         self.local_step, self.averaging_step_period = 0, averaging_steps_period
 
-        self.averager = TrainingAverager(opt, average_parameters=average_parameters,
-                                         average_gradients=average_gradients,
-                                         average_opt_statistics=average_opt_statistics,
-                                         dht=dht, start=True, prefix=prefix,
-                                         target_group_size=target_group_size, **kwargs)
+        self.averager = TrainingAverager(
+            opt,
+            average_parameters=average_parameters,
+            average_gradients=average_gradients,
+            average_opt_statistics=average_opt_statistics,
+            dht=dht,
+            start=True,
+            prefix=prefix,
+            target_group_size=target_group_size,
+            **kwargs,
+        )
         self.lock_parameters, self.update_event, self.stop_event = Lock(), Event(), Event()
 
         self.background_averaging_thread = Thread(
-            name=f'{self.__class__.__name__}', daemon=True, target=self._average_parameters_in_background,
+            name=f"{self.__class__.__name__}",
+            daemon=True,
+            target=self._average_parameters_in_background,
             args=[self.lock_parameters, self.update_event, self.stop_event, self.averager],
-            kwargs=dict(averaging_period=averaging_time_period, timeout=timeout, verbose=verbose))
+            kwargs=dict(averaging_period=averaging_time_period, timeout=timeout, verbose=verbose),
+        )
         self.background_averaging_thread.start()
 
     def step(self, *args, **kwargs):
@@ -78,9 +99,15 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
     @staticmethod
     @torch.no_grad()
     def _average_parameters_in_background(
-            lock_parameters: Lock, update_event: Event, stop_event: Event, averager: TrainingAverager,
-            averaging_period: float, verbose: bool, **kwargs):
-        """ Iteratively find groups of peers, average parameters with these peers and update local model parameters. """
+        lock_parameters: Lock,
+        update_event: Event,
+        stop_event: Event,
+        averager: TrainingAverager,
+        averaging_period: float,
+        verbose: bool,
+        **kwargs,
+    ):
+        """Iteratively find groups of peers, average parameters with these peers and update local model parameters."""
         while not stop_event.is_set():
             update_event.wait()
             update_event.clear()
@@ -121,11 +148,30 @@ class DecentralizedSGD(DecentralizedOptimizer):
      https://arxiv.org/abs/2103.03239
     """
 
-    def __init__(self, params, lr: float, *, dht: DHT, prefix: str, target_group_size: int,
-                 momentum: float = 0, dampening: float = 0, weight_decay: float = 0, nesterov: bool = False, **kwargs):
+    def __init__(
+        self,
+        params,
+        lr: float,
+        *,
+        dht: DHT,
+        prefix: str,
+        target_group_size: int,
+        momentum: float = 0,
+        dampening: float = 0,
+        weight_decay: float = 0,
+        nesterov: bool = False,
+        **kwargs,
+    ):
         opt = torch.optim.SGD(params, lr, momentum, dampening, weight_decay, nesterov)
-        super().__init__(opt, dht, prefix=prefix, target_group_size=target_group_size, average_parameters=True,
-                         average_gradients=False, **kwargs)
+        super().__init__(
+            opt,
+            dht,
+            prefix=prefix,
+            target_group_size=target_group_size,
+            average_parameters=True,
+            average_gradients=False,
+            **kwargs,
+        )
 
 
 class DecentralizedAdam(DecentralizedOptimizer):
@@ -142,12 +188,32 @@ class DecentralizedAdam(DecentralizedOptimizer):
     - [2] Toward Communication Efficient Adaptive Gradient Method - https://dl.acm.org/doi/abs/10.1145/3412815.3416891
     """
 
-    def __init__(self, params, lr: float, *, dht: DHT, prefix: str, target_group_size: int, averaging_steps_period: int,
-                 betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0,
-                 amsgrad: bool = False, **kwargs):
+    def __init__(
+        self,
+        params,
+        lr: float,
+        *,
+        dht: DHT,
+        prefix: str,
+        target_group_size: int,
+        averaging_steps_period: int,
+        betas: Tuple[float, float] = (0.9, 0.999),
+        eps: float = 1e-8,
+        weight_decay: float = 0,
+        amsgrad: bool = False,
+        **kwargs,
+    ):
         opt = torch.optim.Adam(params, lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
         opt_statistics = ("max_exp_avg_sq",) if amsgrad else ("exp_avg_sq",)
 
-        super().__init__(opt, dht, prefix=prefix, target_group_size=target_group_size, average_parameters=True,
-                         average_gradients=False, average_opt_statistics=opt_statistics,
-                         averaging_steps_period=averaging_steps_period, **kwargs)
+        super().__init__(
+            opt,
+            dht,
+            prefix=prefix,
+            target_group_size=target_group_size,
+            average_parameters=True,
+            average_gradients=False,
+            average_opt_statistics=opt_statistics,
+            averaging_steps_period=averaging_steps_period,
+            **kwargs,
+        )
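
And a brief sketch of the drop-in usage implied by the signatures above (group size, averaging period and dimensions are illustrative; not part of this commit):

    import torch
    import hivemind
    from hivemind.optim.simple import DecentralizedSGD

    model = torch.nn.Linear(16, 16)
    dht = hivemind.DHT(start=True)
    opt = DecentralizedSGD(
        model.parameters(),
        lr=0.1,
        dht=dht,
        prefix="my_sgd_run",
        target_group_size=4,  # peers average parameters in groups of up to 4
        averaging_steps_period=5,  # request a parameter averaging round every 5 local steps
    )

    loss = model(torch.randn(8, 16)).pow(2).mean()
    loss.backward()
    opt.step()  # behaves like torch.optim.SGD locally; averaging happens in a background thread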

+ 106 - 75
hivemind/p2p/p2p_daemon.py

@@ -20,7 +20,7 @@ from hivemind.utils.logging import get_logger
 logger = get_logger(__name__)
 
 
-P2PD_FILENAME = 'p2pd'
+P2PD_FILENAME = "p2pd"
 
 
 @dataclass(frozen=True)
@@ -48,20 +48,20 @@ class P2P:
     """
 
     HEADER_LEN = 8
-    BYTEORDER = 'big'
+    BYTEORDER = "big"
     PB_HEADER_LEN = 1
-    RESULT_MESSAGE = b'\x00'
-    ERROR_MESSAGE = b'\x01'
+    RESULT_MESSAGE = b"\x00"
+    ERROR_MESSAGE = b"\x01"
     DHT_MODE_MAPPING = {
-        'dht': {'dht': 1},
-        'dht_server': {'dhtServer': 1},
-        'dht_client': {'dhtClient': 1},
+        "dht": {"dht": 1},
+        "dht_server": {"dhtServer": 1},
+        "dht_client": {"dhtClient": 1},
     }
     FORCE_REACHABILITY_MAPPING = {
-        'public': {'forceReachabilityPublic': 1},
-        'private': {'forceReachabilityPrivate': 1},
+        "public": {"forceReachabilityPublic": 1},
+        "private": {"forceReachabilityPrivate": 1},
     }
-    _UNIX_SOCKET_PREFIX = '/unix/tmp/hivemind-'
+    _UNIX_SOCKET_PREFIX = "/unix/tmp/hivemind-"
 
     def __init__(self):
         self.id = None
@@ -71,18 +71,28 @@ class P2P:
         self._server_stopped = asyncio.Event()
 
     @classmethod
-    async def create(cls,
-                     initial_peers: Optional[Sequence[Union[Multiaddr, str]]] = None,
-                     use_ipfs: bool = False,
-                     host_maddrs: Optional[Sequence[Union[Multiaddr, str]]] = ('/ip4/127.0.0.1/tcp/0',),
-                     announce_maddrs: Optional[Sequence[Union[Multiaddr, str]]] = None,
-                     quic: bool = True, tls: bool = True, conn_manager: bool = True,
-                     dht_mode: str = 'dht_server', force_reachability: Optional[str] = None,
-                     nat_port_map: bool = True, auto_nat: bool = True,
-                     use_relay: bool = True, use_relay_hop: bool = False,
-                     use_relay_discovery: bool = False, use_auto_relay: bool = False, relay_hop_limit: int = 0,
-                     quiet: bool = True,
-                     ping_n_attempts: int = 5, ping_delay: float = 0.4) -> 'P2P':
+    async def create(
+        cls,
+        initial_peers: Optional[Sequence[Union[Multiaddr, str]]] = None,
+        use_ipfs: bool = False,
+        host_maddrs: Optional[Sequence[Union[Multiaddr, str]]] = ("/ip4/127.0.0.1/tcp/0",),
+        announce_maddrs: Optional[Sequence[Union[Multiaddr, str]]] = None,
+        quic: bool = True,
+        tls: bool = True,
+        conn_manager: bool = True,
+        dht_mode: str = "dht_server",
+        force_reachability: Optional[str] = None,
+        nat_port_map: bool = True,
+        auto_nat: bool = True,
+        use_relay: bool = True,
+        use_relay_hop: bool = False,
+        use_relay_discovery: bool = False,
+        use_auto_relay: bool = False,
+        relay_hop_limit: int = 0,
+        quiet: bool = True,
+        ping_n_attempts: int = 5,
+        ping_delay: float = 0.4,
+    ) -> "P2P":
         """
         Start a new p2pd process and connect to it.
         :param initial_peers: List of bootstrap peers
@@ -109,34 +119,46 @@ class P2P:
         :return: a wrapper for the p2p daemon
         """
 
-        assert not (initial_peers and use_ipfs), \
-            'User-defined initial_peers and use_ipfs=True are incompatible, please choose one option'
+        assert not (
+            initial_peers and use_ipfs
+        ), "User-defined initial_peers and use_ipfs=True are incompatible, please choose one option"
 
         self = cls()
         with path(cli, P2PD_FILENAME) as p:
             p2pd_path = p
 
         socket_uid = secrets.token_urlsafe(8)
-        self._daemon_listen_maddr = Multiaddr(cls._UNIX_SOCKET_PREFIX + f'p2pd-{socket_uid}.sock')
-        self._client_listen_maddr = Multiaddr(cls._UNIX_SOCKET_PREFIX + f'p2pclient-{socket_uid}.sock')
+        self._daemon_listen_maddr = Multiaddr(cls._UNIX_SOCKET_PREFIX + f"p2pd-{socket_uid}.sock")
+        self._client_listen_maddr = Multiaddr(cls._UNIX_SOCKET_PREFIX + f"p2pclient-{socket_uid}.sock")
 
         need_bootstrap = bool(initial_peers) or use_ipfs
-        process_kwargs = cls.DHT_MODE_MAPPING.get(dht_mode, {'dht': 0})
+        process_kwargs = cls.DHT_MODE_MAPPING.get(dht_mode, {"dht": 0})
         process_kwargs.update(cls.FORCE_REACHABILITY_MAPPING.get(force_reachability, {}))
-        for param, value in [('bootstrapPeers', initial_peers),
-                             ('hostAddrs', host_maddrs),
-                             ('announceAddrs', announce_maddrs)]:
+        for param, value in [
+            ("bootstrapPeers", initial_peers),
+            ("hostAddrs", host_maddrs),
+            ("announceAddrs", announce_maddrs),
+        ]:
             if value:
                 process_kwargs[param] = self._maddrs_to_str(value)
 
         proc_args = self._make_process_args(
             str(p2pd_path),
             listen=self._daemon_listen_maddr,
-            quic=quic, tls=tls, connManager=conn_manager,
-            natPortMap=nat_port_map, autonat=auto_nat,
-            relay=use_relay, relayHop=use_relay_hop, relayDiscovery=use_relay_discovery,
-            autoRelay=use_auto_relay, relayHopLimit=relay_hop_limit,
-            b=need_bootstrap, q=quiet, **process_kwargs)
+            quic=quic,
+            tls=tls,
+            connManager=conn_manager,
+            natPortMap=nat_port_map,
+            autonat=auto_nat,
+            relay=use_relay,
+            relayHop=use_relay_hop,
+            relayDiscovery=use_relay_discovery,
+            autoRelay=use_auto_relay,
+            relayHopLimit=relay_hop_limit,
+            b=need_bootstrap,
+            q=quiet,
+            **process_kwargs,
+        )
 
         self._child = Popen(args=proc_args, encoding="utf8")
         self._alive = True
@@ -158,15 +180,15 @@ class P2P:
                 break
             except Exception as e:
                 if try_number == ping_n_attempts - 1:
-                    logger.exception('Failed to ping p2pd that has just started')
+                    logger.exception("Failed to ping p2pd that has just started")
                     await self.shutdown()
                     raise
 
         if self._child.returncode is not None:
-            raise RuntimeError(f'The p2p daemon has died with return code {self._child.returncode}')
+            raise RuntimeError(f"The p2p daemon has died with return code {self._child.returncode}")
 
     @classmethod
-    async def replicate(cls, daemon_listen_maddr: Multiaddr) -> 'P2P':
+    async def replicate(cls, daemon_listen_maddr: Multiaddr) -> "P2P":
         """
         Connect to existing p2p daemon
         :param daemon_listen_maddr: multiaddr of the existing p2p daemon
@@ -181,7 +203,7 @@ class P2P:
 
         socket_uid = secrets.token_urlsafe(8)
         self._daemon_listen_maddr = daemon_listen_maddr
-        self._client_listen_maddr = Multiaddr(cls._UNIX_SOCKET_PREFIX + f'p2pclient-{socket_uid}.sock')
+        self._client_listen_maddr = Multiaddr(cls._UNIX_SOCKET_PREFIX + f"p2pclient-{socket_uid}.sock")
 
         self._client = p2pclient.Client(self._daemon_listen_maddr, self._client_listen_maddr)
 
@@ -190,7 +212,7 @@ class P2P:
 
     async def _ping_daemon(self) -> None:
         self.id, self._visible_maddrs = await self._client.identify()
-        logger.debug(f'Launched p2pd with id = {self.id}, host multiaddrs = {self._visible_maddrs}')
+        logger.debug(f"Launched p2pd with id = {self.id}, host multiaddrs = {self._visible_maddrs}")
 
     async def get_visible_maddrs(self, latest: bool = False) -> List[Multiaddr]:
         """
@@ -205,7 +227,7 @@ class P2P:
         if not self._visible_maddrs:
             raise ValueError(f"No multiaddrs found for peer {self.id}")
 
-        p2p_maddr = Multiaddr(f'/p2p/{self.id.to_base58()}')
+        p2p_maddr = Multiaddr(f"/p2p/{self.id.to_base58()}")
         return [addr.encapsulate(p2p_maddr) for addr in self._visible_maddrs]
 
     async def list_peers(self) -> List[PeerInfo]:
@@ -218,7 +240,7 @@ class P2P:
                 return
             await asyncio.sleep(delay)
 
-        raise RuntimeError('Not enough peers')
+        raise RuntimeError("Not enough peers")
 
     @property
     def daemon_listen_maddr(self) -> Multiaddr:
@@ -237,7 +259,7 @@ class P2P:
     @staticmethod
     async def send_protobuf(protobuf, out_proto_type: type, writer: asyncio.StreamWriter) -> None:
         if type(protobuf) != out_proto_type:
-            raise TypeError('Unary handler returned protobuf of wrong type.')
+            raise TypeError("Unary handler returned protobuf of wrong type.")
         if out_proto_type == p2pd_pb2.RPCError:
             await P2P.send_raw_data(P2P.ERROR_MESSAGE, writer)
         else:
@@ -257,8 +279,9 @@ class P2P:
         return MSGPackSerializer.loads(await P2P.receive_raw_data(reader))
 
     @staticmethod
-    async def receive_protobuf(in_proto_type: type, reader: asyncio.StreamReader) -> \
-            Tuple[Any, Optional[p2pd_pb2.RPCError]]:
+    async def receive_protobuf(
+        in_proto_type: type, reader: asyncio.StreamReader
+    ) -> Tuple[Any, Optional[p2pd_pb2.RPCError]]:
         msg_type = await P2P.receive_raw_data(reader)
         if msg_type == P2P.RESULT_MESSAGE:
             protobuf = in_proto_type()
@@ -269,12 +292,13 @@ class P2P:
             protobuf.ParseFromString(await P2P.receive_raw_data(reader))
             return None, protobuf
         else:
-            raise TypeError('Invalid Protobuf message type')
+            raise TypeError("Invalid Protobuf message type")
 
     @staticmethod
     def _handle_stream(handle: Callable[[bytes], bytes]):
         async def do_handle_stream(
-                stream_info: StreamInfo, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
+            stream_info: StreamInfo, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
+        ):
             try:
                 request = await P2P.receive_raw_data(reader)
             except asyncio.IncompleteReadError:
@@ -289,31 +313,39 @@ class P2P:
 
         return do_handle_stream
 
-    def _handle_unary_stream(self, handle: Callable[[Any, P2PContext], Any], handle_name: str,
-                             in_proto_type: type, out_proto_type: type):
+    def _handle_unary_stream(
+        self, handle: Callable[[Any, P2PContext], Any], handle_name: str, in_proto_type: type, out_proto_type: type
+    ):
         async def watchdog(reader: asyncio.StreamReader) -> None:
             await reader.read(n=1)
             raise P2PInterruptedError()
 
-        async def do_handle_unary_stream(stream_info: StreamInfo,
-                                         reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
+        async def do_handle_unary_stream(
+            stream_info: StreamInfo, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
+        ) -> None:
             try:
                 try:
                     request, err = await P2P.receive_protobuf(in_proto_type, reader)
                 except asyncio.IncompleteReadError:
-                    logger.debug(f'Incomplete read while receiving request from peer in {handle_name}')
+                    logger.debug(f"Incomplete read while receiving request from peer in {handle_name}")
                     return
                 except google.protobuf.message.DecodeError as error:
-                    logger.debug(f'Failed to decode request protobuf '
-                                 f'of type {in_proto_type} in {handle_name}: {error}')
+                    logger.debug(
+                        f"Failed to decode request protobuf " f"of type {in_proto_type} in {handle_name}: {error}"
+                    )
                     return
                 if err is not None:
-                    logger.debug(f'Got an error instead of a request in {handle_name}: {err}')
-
-                context = P2PContext(handle_name=handle_name, local_id=self.id,
-                                     remote_id=stream_info.peer_id, remote_maddr=stream_info.addr)
-                done, pending = await asyncio.wait([watchdog(reader), handle(request, context)],
-                                                   return_when=asyncio.FIRST_COMPLETED)
+                    logger.debug(f"Got an error instead of a request in {handle_name}: {err}")
+
+                context = P2PContext(
+                    handle_name=handle_name,
+                    local_id=self.id,
+                    remote_id=stream_info.peer_id,
+                    remote_maddr=stream_info.addr,
+                )
+                done, pending = await asyncio.wait(
+                    [watchdog(reader), handle(request, context)], return_when=asyncio.FIRST_COMPLETED
+                )
                 try:
                     result = done.pop().result()
                     await P2P.send_protobuf(result, out_proto_type, writer)
@@ -354,12 +386,12 @@ class P2P:
             self._start_listening()
         await self._client.stream_handler(name, self._handle_stream(handle))
 
-    async def add_unary_handler(self, name: str, handle: Callable[[Any, P2PContext], Any],
-                                in_proto_type: type, out_proto_type: type) -> None:
+    async def add_unary_handler(
+        self, name: str, handle: Callable[[Any, P2PContext], Any], in_proto_type: type, out_proto_type: type
+    ) -> None:
         if self._listen_task is None:
             self._start_listening()
-        await self._client.stream_handler(
-            name, self._handle_unary_stream(handle, name, in_proto_type, out_proto_type))
+        await self._client.stream_handler(name, self._handle_unary_stream(handle, name, in_proto_type, out_proto_type))
 
     async def call_peer_handler(self, peer_id: PeerID, handler_name: str, input_data: bytes) -> bytes:
         stream_info, reader, writer = await self._client.stream_open(peer_id, (handler_name,))
@@ -369,14 +401,15 @@ class P2P:
         finally:
             writer.close()
 
-    async def call_unary_handler(self, peer_id: PeerID, handler_name: str,
-                                 request_protobuf: Any, response_proto_type: type) -> Any:
+    async def call_unary_handler(
+        self, peer_id: PeerID, handler_name: str, request_protobuf: Any, response_proto_type: type
+    ) -> Any:
         stream_info, reader, writer = await self._client.stream_open(peer_id, (handler_name,))
         try:
             await P2P.send_protobuf(request_protobuf, type(request_protobuf), writer)
             result, err = await P2P.receive_protobuf(response_proto_type, reader)
             if err is not None:
-                raise P2PHandlerError(f'Failed to call unary handler {handler_name} at {peer_id}: {err.message}')
+                raise P2PHandlerError(f"Failed to call unary handler {handler_name} at {peer_id}: {err.message}")
 
             return result
         finally:
@@ -398,21 +431,19 @@ class P2P:
         if self._child is not None and self._child.poll() is None:
             self._child.terminate()
             self._child.wait()
-            logger.debug(f'Terminated p2pd with id = {self.id}')
+            logger.debug(f"Terminated p2pd with id = {self.id}")
 
             with suppress(FileNotFoundError):
-                os.remove(self._daemon_listen_maddr['unix'])
+                os.remove(self._daemon_listen_maddr["unix"])
         with suppress(FileNotFoundError):
-            os.remove(self._client_listen_maddr['unix'])
+            os.remove(self._client_listen_maddr["unix"])
 
     @staticmethod
     def _make_process_args(*args, **kwargs) -> List[str]:
         proc_args = []
+        proc_args.extend(str(entry) for entry in args)
         proc_args.extend(
-            str(entry) for entry in args
-        )
-        proc_args.extend(
-            f'-{key}={P2P._convert_process_arg_type(value)}' if value is not None else f'-{key}'
+            f"-{key}={P2P._convert_process_arg_type(value)}" if value is not None else f"-{key}"
             for key, value in kwargs.items()
         )
         return proc_args
@@ -425,7 +456,7 @@ class P2P:
 
     @staticmethod
     def _maddrs_to_str(maddrs: List[Multiaddr]) -> str:
-        return ','.join(str(addr) for addr in maddrs)
+        return ",".join(str(addr) for addr in maddrs)
 
 
 class P2PInterruptedError(Exception):
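
Note (not part of the diff): a minimal sketch of the `P2P` daemon wrapper reformatted above, assuming the class is re-exported from `hivemind.p2p`; addresses shown in comments are placeholders.

```python
import asyncio
from hivemind.p2p import P2P  # assumed re-export; otherwise import from hivemind.p2p.p2p_daemon


async def main():
    # start a daemon listening on a random local TCP port (the default host_maddrs)
    alice = await P2P.create()
    alice_maddrs = await alice.get_visible_maddrs()  # e.g. /ip4/127.0.0.1/tcp/<port>/p2p/<peer id>

    # start a second daemon and bootstrap it from the first one
    bob = await P2P.create(initial_peers=alice_maddrs)
    print(await bob.list_peers())

    await bob.shutdown()
    await alice.shutdown()


asyncio.run(main())
```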

+ 14 - 38
hivemind/p2p/p2p_daemon_bindings/control.py

@@ -6,17 +6,12 @@ Author: Kevin Mai-Husan Chia
 
 import asyncio
 from contextlib import asynccontextmanager
-from typing import (AsyncIterator, Awaitable, Callable, Dict, Iterable,
-                    Sequence, Tuple)
+from typing import AsyncIterator, Awaitable, Callable, Dict, Iterable, Sequence, Tuple
 
 from multiaddr import Multiaddr, protocols
 
-from hivemind.p2p.p2p_daemon_bindings.datastructures import (PeerID, PeerInfo,
-                                                             StreamInfo)
-from hivemind.p2p.p2p_daemon_bindings.utils import (DispatchFailure,
-                                                    raise_if_failed,
-                                                    read_pbmsg_safe,
-                                                    write_pbmsg)
+from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
+from hivemind.p2p.p2p_daemon_bindings.utils import DispatchFailure, raise_if_failed, read_pbmsg_safe, write_pbmsg
 from hivemind.proto import p2pd_pb2 as p2pd_pb
 from hivemind.utils.logging import get_logger
 
@@ -27,9 +22,7 @@ SUPPORT_CONN_PROTOCOLS = (
     # protocols.P_IP6,
     protocols.P_UNIX,
 )
-SUPPORTED_PROTOS = (
-    protocols.protocol_with_code(proto) for proto in SUPPORT_CONN_PROTOCOLS
-)
+SUPPORTED_PROTOS = (protocols.protocol_with_code(proto) for proto in SUPPORT_CONN_PROTOCOLS)
 logger = get_logger(__name__)
 
 
@@ -38,8 +31,7 @@ def parse_conn_protocol(maddr: Multiaddr) -> int:
     proto_cand = proto_codes.intersection(SUPPORT_CONN_PROTOCOLS)
     if len(proto_cand) != 1:
         raise ValueError(
-            f"connection protocol should be only one protocol out of {SUPPORTED_PROTOS}"
-            f", maddr={maddr}"
+            f"connection protocol should be only one protocol out of {SUPPORTED_PROTOS}" f", maddr={maddr}"
         )
     return tuple(proto_cand)[0]
 
@@ -60,16 +52,14 @@ class DaemonConnector:
             port = int(self.control_maddr.value_for_protocol(protocols.P_TCP))
             return await asyncio.open_connection(host, port)
         else:
-            raise ValueError(
-                f"Protocol not supported: {protocols.protocol_with_code(self.proto_code)}"
-            )
+            raise ValueError(f"Protocol not supported: {protocols.protocol_with_code(self.proto_code)}")
 
 
 class ControlClient:
     DEFAULT_LISTEN_MADDR = "/unix/tmp/p2pclient.sock"
 
     def __init__(
-            self, daemon_connector: DaemonConnector, listen_maddr: Multiaddr = Multiaddr(DEFAULT_LISTEN_MADDR)
+        self, daemon_connector: DaemonConnector, listen_maddr: Multiaddr = Multiaddr(DEFAULT_LISTEN_MADDR)
     ) -> None:
         self.listen_maddr = listen_maddr
         self.daemon_connector = daemon_connector
@@ -98,9 +88,7 @@ class ControlClient:
             port = int(self.listen_maddr.value_for_protocol(protocols.P_TCP))
             server = await asyncio.start_server(self._handler, port=port, host=host)
         else:
-            raise ValueError(
-                f"Protocol not supported: {protocols.protocol_with_code(proto_code)}"
-            )
+            raise ValueError(f"Protocol not supported: {protocols.protocol_with_code(proto_code)}")
 
         async with server:
             yield self
@@ -127,9 +115,7 @@ class ControlClient:
         reader, writer = await self.daemon_connector.open_connection()
 
         maddrs_bytes = [i.to_bytes() for i in maddrs]
-        connect_req = p2pd_pb.ConnectRequest(
-            peer=peer_id.to_bytes(), addrs=maddrs_bytes
-        )
+        connect_req = p2pd_pb.ConnectRequest(peer=peer_id.to_bytes(), addrs=maddrs_bytes)
         req = p2pd_pb.Request(type=p2pd_pb.Request.CONNECT, connect=connect_req)
         await write_pbmsg(writer, req)
 
@@ -152,9 +138,7 @@ class ControlClient:
 
     async def disconnect(self, peer_id: PeerID) -> None:
         disconnect_req = p2pd_pb.DisconnectRequest(peer=peer_id.to_bytes())
-        req = p2pd_pb.Request(
-            type=p2pd_pb.Request.DISCONNECT, disconnect=disconnect_req
-        )
+        req = p2pd_pb.Request(type=p2pd_pb.Request.DISCONNECT, disconnect=disconnect_req)
         reader, writer = await self.daemon_connector.open_connection()
         await write_pbmsg(writer, req)
         resp = p2pd_pb.Response()  # type: ignore
@@ -167,12 +151,8 @@ class ControlClient:
     ) -> Tuple[StreamInfo, asyncio.StreamReader, asyncio.StreamWriter]:
         reader, writer = await self.daemon_connector.open_connection()
 
-        stream_open_req = p2pd_pb.StreamOpenRequest(
-            peer=peer_id.to_bytes(), proto=list(protocols)
-        )
-        req = p2pd_pb.Request(
-            type=p2pd_pb.Request.STREAM_OPEN, streamOpen=stream_open_req
-        )
+        stream_open_req = p2pd_pb.StreamOpenRequest(peer=peer_id.to_bytes(), proto=list(protocols))
+        req = p2pd_pb.Request(type=p2pd_pb.Request.STREAM_OPEN, streamOpen=stream_open_req)
         await write_pbmsg(writer, req)
 
         resp = p2pd_pb.Response()  # type: ignore
@@ -188,12 +168,8 @@ class ControlClient:
         reader, writer = await self.daemon_connector.open_connection()
 
         listen_path_maddr_bytes = self.listen_maddr.to_bytes()
-        stream_handler_req = p2pd_pb.StreamHandlerRequest(
-            addr=listen_path_maddr_bytes, proto=[proto]
-        )
-        req = p2pd_pb.Request(
-            type=p2pd_pb.Request.STREAM_HANDLER, streamHandler=stream_handler_req
-        )
+        stream_handler_req = p2pd_pb.StreamHandlerRequest(addr=listen_path_maddr_bytes, proto=[proto])
+        req = p2pd_pb.Request(type=p2pd_pb.Request.STREAM_HANDLER, streamHandler=stream_handler_req)
         await write_pbmsg(writer, req)
 
         resp = p2pd_pb.Response()  # type: ignore

+ 7 - 23
hivemind/p2p/p2p_daemon_bindings/datastructures.py

@@ -33,9 +33,7 @@ if ENABLE_INLINING:
         def digest(self) -> bytes:
             return self._digest
 
-    multihash.FuncReg.register(
-        IDENTITY_MULTIHASH_CODE, "identity", hash_new=IdentityHash
-    )
+    multihash.FuncReg.register(IDENTITY_MULTIHASH_CODE, "identity", hash_new=IdentityHash)
 
 
 class PeerID:
@@ -98,21 +96,15 @@ class StreamInfo:
         self.proto = proto
 
     def __repr__(self) -> str:
-        return (
-            f"<StreamInfo peer_id={self.peer_id} addr={self.addr} proto={self.proto}>"
-        )
+        return f"<StreamInfo peer_id={self.peer_id} addr={self.addr} proto={self.proto}>"
 
     def to_protobuf(self) -> p2pd_pb2.StreamInfo:
-        pb_msg = p2pd_pb2.StreamInfo(
-            peer=self.peer_id.to_bytes(), addr=self.addr.to_bytes(), proto=self.proto
-        )
+        pb_msg = p2pd_pb2.StreamInfo(peer=self.peer_id.to_bytes(), addr=self.addr.to_bytes(), proto=self.proto)
         return pb_msg
 
     @classmethod
     def from_protobuf(cls, pb_msg: p2pd_pb2.StreamInfo) -> "StreamInfo":
-        stream_info = cls(
-            peer_id=PeerID(pb_msg.peer), addr=Multiaddr(pb_msg.addr), proto=pb_msg.proto
-        )
+        stream_info = cls(peer_id=PeerID(pb_msg.peer), addr=Multiaddr(pb_msg.addr), proto=pb_msg.proto)
         return stream_info
 
 
@@ -122,11 +114,7 @@ class PeerInfo:
         self.addrs = list(addrs)
 
     def __eq__(self, other: Any) -> bool:
-        return (
-            isinstance(other, PeerInfo)
-            and self.peer_id == other.peer_id
-            and self.addrs == other.addrs
-        )
+        return isinstance(other, PeerInfo) and self.peer_id == other.peer_id and self.addrs == other.addrs
 
     @classmethod
     def from_protobuf(cls, peer_info_pb: p2pd_pb2.PeerInfo) -> "PeerInfo":
@@ -148,16 +136,12 @@ def info_from_p2p_addr(addr: Multiaddr) -> PeerInfo:
 
     parts = addr.split()
     if not parts:
-        raise InvalidAddrError(
-            f"`parts`={parts} should at least have a protocol `P_P2P`"
-        )
+        raise InvalidAddrError(f"`parts`={parts} should at least have a protocol `P_P2P`")
 
     p2p_part = parts[-1]
     last_protocol_code = p2p_part.protocols()[0].code
     if last_protocol_code != protocols.P_P2P:
-        raise InvalidAddrError(
-            f"The last protocol should be `P_P2P` instead of `{last_protocol_code}`"
-        )
+        raise InvalidAddrError(f"The last protocol should be `P_P2P` instead of `{last_protocol_code}`")
 
     # make sure the /p2p value parses as a peer.ID
     peer_id_str: str = p2p_part.value_for_protocol(protocols.P_P2P)

+ 4 - 11
hivemind/p2p/p2p_daemon_bindings/p2pclient.py

@@ -10,23 +10,16 @@ from typing import AsyncIterator, Iterable, Sequence, Tuple
 
 from multiaddr import Multiaddr
 
-from hivemind.p2p.p2p_daemon_bindings.control import (ControlClient,
-                                                      DaemonConnector,
-                                                      StreamHandler)
-from hivemind.p2p.p2p_daemon_bindings.datastructures import (PeerID, PeerInfo,
-                                                             StreamInfo)
+from hivemind.p2p.p2p_daemon_bindings.control import ControlClient, DaemonConnector, StreamHandler
+from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
 
 
 class Client:
     control: ControlClient
 
-    def __init__(
-        self, control_maddr: Multiaddr = None, listen_maddr: Multiaddr = None
-    ) -> None:
+    def __init__(self, control_maddr: Multiaddr = None, listen_maddr: Multiaddr = None) -> None:
         daemon_connector = DaemonConnector(control_maddr=control_maddr)
-        self.control = ControlClient(
-            daemon_connector=daemon_connector, listen_maddr=listen_maddr
-        )
+        self.control = ControlClient(daemon_connector=daemon_connector, listen_maddr=listen_maddr)
 
     @asynccontextmanager
     async def listen(self) -> AsyncIterator["Client"]:

+ 28 - 18
hivemind/p2p/servicer.py

@@ -46,32 +46,38 @@ class Servicer:
 
         self._rpc_handlers = []
         for method_name, method in self.__class__.__dict__.items():
-            if method_name.startswith('rpc_') and callable(method):
-                handle_name = f'{class_name}.{method_name}'
+            if method_name.startswith("rpc_") and callable(method):
+                handle_name = f"{class_name}.{method_name}"
 
                 hints = method.__annotations__
                 try:
-                    request_type = self._hint_to_type(hints['request'])
-                    response_type = self._hint_to_type(hints['return'])
+                    request_type = self._hint_to_type(hints["request"])
+                    response_type = self._hint_to_type(hints["return"])
                 except (KeyError, ValueError):
-                    raise ValueError(f'{handle_name} is expected to have type annotations like `dht_pb2.FindRequest` '
-                                     f'(a type from the hivemind.proto module) for the `request` parameter '
-                                     f'and the return value')
+                    raise ValueError(
+                        f"{handle_name} is expected to have type annotations like `dht_pb2.FindRequest` "
+                        f"(a type from the hivemind.proto module) for the `request` parameter "
+                        f"and the return value"
+                    )
 
                 self._rpc_handlers.append(RPCHandler(method_name, handle_name, request_type, response_type))
 
-        self._stub_type = type(f'{class_name}Stub', (StubBase,),
-                               {handler.method_name: self._make_rpc_caller(handler)
-                                for handler in self._rpc_handlers})
+        self._stub_type = type(
+            f"{class_name}Stub",
+            (StubBase,),
+            {handler.method_name: self._make_rpc_caller(handler) for handler in self._rpc_handlers},
+        )
 
     @staticmethod
     def _make_rpc_caller(handler: RPCHandler):
         # This method will be added to a new Stub type (a subclass of StubBase)
-        async def caller(self: StubBase, request: handler.request_type,
-                         timeout: Optional[float] = None) -> handler.response_type:
+        async def caller(
+            self: StubBase, request: handler.request_type, timeout: Optional[float] = None
+        ) -> handler.response_type:
             return await asyncio.wait_for(
                 self._p2p.call_unary_handler(self._peer, handler.handle_name, request, handler.response_type),
-                timeout=timeout)
+                timeout=timeout,
+            )
 
         caller.__name__ = handler.method_name
         return caller
@@ -79,8 +85,12 @@ class Servicer:
     async def add_p2p_handlers(self, p2p: P2P, wrapper: Any = None) -> None:
         servicer = self if wrapper is None else wrapper
         for handler in self._rpc_handlers:
-            await p2p.add_unary_handler(handler.handle_name, getattr(servicer, handler.method_name),
-                                        handler.request_type, handler.response_type)
+            await p2p.add_unary_handler(
+                handler.handle_name,
+                getattr(servicer, handler.method_name),
+                handler.request_type,
+                handler.response_type,
+            )
 
     def get_stub(self, p2p: P2P, peer: PeerID) -> StubBase:
         return self._stub_type(p2p, peer)
@@ -90,9 +100,9 @@ class Servicer:
         if isinstance(hint, type):
             return hint
 
-        module_name, proto_name = hint.split('.')
-        module = importlib.import_module('hivemind.proto.' + module_name)
+        module_name, proto_name = hint.split(".")
+        module = importlib.import_module("hivemind.proto." + module_name)
         result = getattr(module, proto_name)
         if not isinstance(result, type):
-            raise ValueError(f'`hivemind.proto.{hint}` is not a type')
+            raise ValueError(f"`hivemind.proto.{hint}` is not a type")
         return result
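
Note (not part of the diff): a sketch of how a `Servicer` subclass is meant to be used after this reformatting. It assumes `P2P`, `P2PContext` and `PeerID` are importable from `hivemind.p2p` and that `hivemind.proto.dht_pb2` defines `PingRequest`/`PingResponse`, as used by the DHT protocol.

```python
from hivemind.p2p import P2P, P2PContext, PeerID  # assumed re-exports
from hivemind.p2p.servicer import Servicer
from hivemind.proto import dht_pb2  # PingRequest / PingResponse assumed, as in the DHT protocol


class PingServicer(Servicer):
    # annotations must reference hivemind.proto types so Servicer can infer request/response types
    async def rpc_ping(self, request: dht_pb2.PingRequest, context: P2PContext) -> dht_pb2.PingResponse:
        return dht_pb2.PingResponse()


async def run(server_p2p: P2P, client_p2p: P2P, server_peer_id: PeerID):
    servicer = PingServicer()
    await servicer.add_p2p_handlers(server_p2p)            # registers PingServicer.rpc_ping as a unary handler

    stub = servicer.get_stub(client_p2p, server_peer_id)   # auto-generated "PingServicerStub"
    response = await stub.rpc_ping(dht_pb2.PingRequest(), timeout=5.0)
```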

+ 2 - 2
hivemind/utils/__init__.py

@@ -6,6 +6,6 @@ from hivemind.utils.logging import get_logger
 from hivemind.utils.mpfuture import *
 from hivemind.utils.nested import *
 from hivemind.utils.networking import *
-from hivemind.utils.serializer import *
-from hivemind.utils.tensor_descr import *
+from hivemind.utils.serializer import SerializerBase, MSGPackSerializer
+from hivemind.utils.tensor_descr import TensorDescriptor, BatchTensorDescriptor
 from hivemind.utils.timed_storage import *

+ 14 - 10
hivemind/utils/asyncio.py

@@ -7,12 +7,12 @@ import uvloop
 from hivemind.utils.logging import get_logger
 
 
-T = TypeVar('T')
+T = TypeVar("T")
 logger = get_logger(__name__)
 
 
 def switch_to_uvloop() -> asyncio.AbstractEventLoop:
-    """ stop any running event loops; install uvloop; then create, set and return a new event loop """
+    """stop any running event loops; install uvloop; then create, set and return a new event loop"""
     try:
         asyncio.get_event_loop().stop()  # if we're in jupyter, get rid of its built-in event loop
     except RuntimeError as error_no_event_loop:
@@ -24,18 +24,18 @@ def switch_to_uvloop() -> asyncio.AbstractEventLoop:
 
 
 async def anext(aiter: AsyncIterator[T]) -> Union[T, StopAsyncIteration]:
-    """ equivalent to next(iter) for asynchronous iterators. Modifies aiter in-place! """
+    """equivalent to next(iter) for asynchronous iterators. Modifies aiter in-place!"""
     return await aiter.__anext__()
 
 
 async def aiter(*args: T) -> AsyncIterator[T]:
-    """ create an asynchronous iterator from a sequence of values """
+    """create an asynchronous iterator from a sequence of values"""
     for arg in args:
         yield arg
 
 
 async def azip(*iterables: AsyncIterable[T]) -> AsyncIterator[Tuple[T, ...]]:
-    """ equivalent of zip for asynchronous iterables """
+    """equivalent of zip for asynchronous iterables"""
     iterators = [iterable.__aiter__() for iterable in iterables]
     while True:
         try:
@@ -45,14 +45,14 @@ async def azip(*iterables: AsyncIterable[T]) -> AsyncIterator[Tuple[T, ...]]:
 
 
 async def achain(*async_iters: AsyncIterable[T]) -> AsyncIterator[T]:
-    """ equivalent to chain(iter1, iter2, ...) for asynchronous iterators. """
+    """equivalent to chain(iter1, iter2, ...) for asynchronous iterators."""
     for aiter in async_iters:
         async for elem in aiter:
             yield elem
 
 
 async def aenumerate(aiterable: AsyncIterable[T]) -> AsyncIterable[Tuple[int, T]]:
-    """ equivalent to enumerate(iter) for asynchronous iterators. """
+    """equivalent to enumerate(iter) for asynchronous iterators."""
     index = 0
     async for elem in aiterable:
         yield index, elem
@@ -69,9 +69,13 @@ async def await_cancelled(awaitable: Awaitable) -> bool:
         return False
 
 
-async def amap_in_executor(func: Callable[..., T], *iterables: AsyncIterable, max_prefetch: Optional[int] = None,
-                           executor: Optional[ThreadPoolExecutor] = None) -> AsyncIterator[T]:
-    """ iterate from an async iterable in a background thread, yield results to async iterable """
+async def amap_in_executor(
+    func: Callable[..., T],
+    *iterables: AsyncIterable,
+    max_prefetch: Optional[int] = None,
+    executor: Optional[ThreadPoolExecutor] = None
+) -> AsyncIterator[T]:
+    """iterate from an async iterable in a background thread, yield results to async iterable"""
     loop = asyncio.get_event_loop()
     queue = asyncio.Queue(max_prefetch)
 

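Note (not part of the diff): a small sketch of the async helpers above; `aiter`, `azip`, `achain` and `amap_in_executor` are the functions from the file being reformatted.

```python
import asyncio
from hivemind.utils.asyncio import achain, aiter, amap_in_executor, azip


async def main():
    async for pair in azip(aiter(1, 2, 3), aiter("a", "b", "c")):
        print(pair)  # (1, 'a'), (2, 'b'), (3, 'c')

    async for value in achain(aiter(1, 2), aiter(3)):
        print(value)  # 1, 2, 3

    # apply a (potentially blocking) function to each element in a background thread
    async for square in amap_in_executor(lambda x: x * x, aiter(1, 2, 3), max_prefetch=2):
        print(square)  # 1, 4, 9


asyncio.run(main())
```
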
+ 24 - 18
hivemind/utils/auth.py

@@ -54,7 +54,7 @@ class TokenAuthorizerBase(AuthorizerBase):
     See https://github.com/learning-at-home/hivemind/issues/253
     """
 
-    def __init__(self, local_private_key: Optional[RSAPrivateKey]=None):
+    def __init__(self, local_private_key: Optional[RSAPrivateKey] = None):
         if local_private_key is None:
             local_private_key = RSAPrivateKey.process_wide()
         self._local_private_key = local_private_key
@@ -99,7 +99,7 @@ class TokenAuthorizerBase(AuthorizerBase):
         auth.time = get_dht_time()
         auth.nonce = secrets.token_bytes(8)
 
-        assert auth.signature == b''
+        assert auth.signature == b""
         auth.signature = self._local_private_key.sign(request.SerializeToString())
 
     _MAX_CLIENT_SERVICER_TIME_DIFF = timedelta(minutes=1)
@@ -109,31 +109,32 @@ class TokenAuthorizerBase(AuthorizerBase):
         auth = request.auth
 
         if not self.is_token_valid(auth.client_access_token):
-            logger.debug('Client failed to prove that it (still) has access to the network')
+            logger.debug("Client failed to prove that it (still) has access to the network")
             return False
 
         client_public_key = RSAPublicKey.from_bytes(auth.client_access_token.public_key)
         signature = auth.signature
-        auth.signature = b''
+        auth.signature = b""
         if not client_public_key.verify(request.SerializeToString(), signature):
-            logger.debug('Request has invalid signature')
+            logger.debug("Request has invalid signature")
             return False
 
         if auth.service_public_key and auth.service_public_key != self._local_public_key.to_bytes():
-            logger.debug('Request is generated for a peer with another public key')
+            logger.debug("Request is generated for a peer with another public key")
             return False
 
         with self._recent_nonces.freeze():
             current_time = get_dht_time()
             if abs(auth.time - current_time) > self._MAX_CLIENT_SERVICER_TIME_DIFF.total_seconds():
-                logger.debug('Clocks are not synchronized or a previous request is replayed again')
+                logger.debug("Clocks are not synchronized or a previous request is replayed again")
                 return False
             if auth.nonce in self._recent_nonces:
-                logger.debug('Previous request is replayed again')
+                logger.debug("Previous request is replayed again")
                 return False
 
-        self._recent_nonces.store(auth.nonce, None,
-                                  current_time + self._MAX_CLIENT_SERVICER_TIME_DIFF.total_seconds() * 3)
+        self._recent_nonces.store(
+            auth.nonce, None, current_time + self._MAX_CLIENT_SERVICER_TIME_DIFF.total_seconds() * 3
+        )
         return True
 
     async def sign_response(self, response: AuthorizedResponseBase, request: AuthorizedRequestBase) -> None:
@@ -143,7 +144,7 @@ class TokenAuthorizerBase(AuthorizerBase):
         auth.service_access_token.CopyFrom(self._local_access_token)
         auth.nonce = request.auth.nonce
 
-        assert auth.signature == b''
+        assert auth.signature == b""
         auth.signature = self._local_private_key.sign(response.SerializeToString())
 
     async def validate_response(self, response: AuthorizedResponseBase, request: AuthorizedRequestBase) -> bool:
@@ -151,18 +152,18 @@ class TokenAuthorizerBase(AuthorizerBase):
         auth = response.auth
 
         if not self.is_token_valid(auth.service_access_token):
-            logger.debug('Service failed to prove that it (still) has access to the network')
+            logger.debug("Service failed to prove that it (still) has access to the network")
             return False
 
         service_public_key = RSAPublicKey.from_bytes(auth.service_access_token.public_key)
         signature = auth.signature
-        auth.signature = b''
+        auth.signature = b""
         if not service_public_key.verify(response.SerializeToString(), signature):
-            logger.debug('Response has invalid signature')
+            logger.debug("Response has invalid signature")
             return False
 
         if auth.nonce != request.auth.nonce:
-            logger.debug('Response is generated for another request')
+            logger.debug("Response is generated for another request")
             return False
 
         return True
@@ -174,15 +175,20 @@ class AuthRole(Enum):
 
 
 class AuthRPCWrapper:
-    def __init__(self, stub, role: AuthRole,
-                 authorizer: Optional[AuthorizerBase], service_public_key: Optional[RSAPublicKey]=None):
+    def __init__(
+        self,
+        stub,
+        role: AuthRole,
+        authorizer: Optional[AuthorizerBase],
+        service_public_key: Optional[RSAPublicKey] = None,
+    ):
         self._stub = stub
         self._role = role
         self._authorizer = authorizer
         self._service_public_key = service_public_key
 
     def __getattribute__(self, name: str):
-        if not name.startswith('rpc_'):
+        if not name.startswith("rpc_"):
             return object.__getattribute__(self, name)
 
         method = getattr(self._stub, name)

+ 25 - 20
hivemind/utils/compression.py

@@ -40,11 +40,11 @@ def average_buckets(tensor: torch.Tensor, quant_weight: torch.Tensor, n_bins: in
 
 
 def _quantile_qq_approximation(array: np.array, n_quantiles: int, min_chunk_size: int = 10 ** 5) -> np.ndarray:
-    """ Estimate uniform quantiles of data using quantile-of-quantiles. Runs in parallel. """
+    """Estimate uniform quantiles of data using quantile-of-quantiles. Runs in parallel."""
     if not array.data.c_contiguous and array.data.f_contiguous:
         array = array.T
     array = np.ascontiguousarray(array.reshape(-1))
-    quantiles = np.linspace(0., 1., num=n_quantiles, dtype=array.dtype)
+    quantiles = np.linspace(0.0, 1.0, num=n_quantiles, dtype=array.dtype)
     chunk_size = _get_chunk_size(len(array), min_chunk_size)
     num_chunks = (len(array) - 1) // chunk_size + 1
     partition_quantiles = np.empty((num_chunks, len(quantiles)), dtype=array.dtype)
@@ -60,7 +60,7 @@ def _quantile_qq_approximation(array: np.array, n_quantiles: int, min_chunk_size
 
 
 def _get_chunk_size(num_elements: int, min_chunk_size: int) -> int:
-    """ Adjust chunk_size to minimize imbalance between chunk sizes """
+    """Adjust chunk_size to minimize imbalance between chunk sizes"""
     if min_chunk_size >= num_elements:
         return min_chunk_size
     leftover_elements = num_elements % min_chunk_size
@@ -78,9 +78,10 @@ def _uint8_uniform_buckets_encode(tensor: torch.Tensor, range_in_sigmas: float):
     return quant_weight, lookup
 
 
-def serialize_torch_tensor(tensor: torch.Tensor, compression_type=CompressionType.NONE,
-                           allow_inplace=False) -> runtime_pb2.Tensor:
-    assert tensor.device == torch.device('cpu')
+def serialize_torch_tensor(
+    tensor: torch.Tensor, compression_type=CompressionType.NONE, allow_inplace=False
+) -> runtime_pb2.Tensor:
+    assert tensor.device == torch.device("cpu")
     if compression_type == CompressionType.MEANSTD_16BIT:
         assert tensor.dtype == torch.float32
 
@@ -93,14 +94,15 @@ def serialize_torch_tensor(tensor: torch.Tensor, compression_type=CompressionTyp
         tensor.div_(stds)
         tensor = tensor.clamp_(-FP16_MAX, FP16_MAX).to(torch.float16)
 
-        data = b''.join((tensor.numpy().tobytes(), means.numpy().tobytes(), stds.numpy().tobytes()))
+        data = b"".join((tensor.numpy().tobytes(), means.numpy().tobytes(), stds.numpy().tobytes()))
 
         proto = runtime_pb2.Tensor(
             compression=compression_type,
             buffer=data,
             size=tensor.shape,
-            dtype='compressed_float32',
-            requires_grad=tensor.requires_grad)
+            dtype="compressed_float32",
+            requires_grad=tensor.requires_grad,
+        )
     elif compression_type == CompressionType.FLOAT16:
         assert tensor.dtype == torch.float32
 
@@ -113,8 +115,9 @@ def serialize_torch_tensor(tensor: torch.Tensor, compression_type=CompressionTyp
             compression=compression_type,
             buffer=data,
             size=tensor.shape,
-            dtype='clamped_float32',
-            requires_grad=tensor.requires_grad)
+            dtype="clamped_float32",
+            requires_grad=tensor.requires_grad,
+        )
     elif compression_type == CompressionType.NONE:
         array = tensor.numpy()
         proto = runtime_pb2.Tensor(
@@ -122,7 +125,8 @@ def serialize_torch_tensor(tensor: torch.Tensor, compression_type=CompressionTyp
             buffer=array.tobytes(),
             size=array.shape,
             dtype=array.dtype.name,
-            requires_grad=tensor.requires_grad)
+            requires_grad=tensor.requires_grad,
+        )
     elif compression_type in (CompressionType.QUANTILE_8BIT, CompressionType.UNIFORM_8BIT):
         assert tensor.dtype == torch.float32
 
@@ -130,14 +134,15 @@ def serialize_torch_tensor(tensor: torch.Tensor, compression_type=CompressionTyp
             quantized, lookup = _quantile_encode_approx(tensor.detach(), NUM_BITS_QUANTILE_COMPRESSION)
         elif compression_type == CompressionType.UNIFORM_8BIT:
             quantized, lookup = _uint8_uniform_buckets_encode(tensor.detach(), UNIFORM_BUCKETS_STD_RANGE)
-        data = b''.join((lookup.numpy().tobytes(), quantized.numpy().astype(np.uint8).tobytes()))
+        data = b"".join((lookup.numpy().tobytes(), quantized.numpy().astype(np.uint8).tobytes()))
 
         proto = runtime_pb2.Tensor(
             compression=compression_type,
             buffer=data,
             size=tensor.shape,
-            dtype='compressed_float32',
-            requires_grad=tensor.requires_grad)
+            dtype="compressed_float32",
+            requires_grad=tensor.requires_grad,
+        )
     else:
         raise ValueError(f"Unknown compression type: {compression_type}")
 
@@ -145,7 +150,7 @@ def serialize_torch_tensor(tensor: torch.Tensor, compression_type=CompressionTyp
 
 
 def construct_torch_tensor(array: np.ndarray, size: Sequence, dtype: Optional[torch.dtype] = None):
-    """ Helper conversion function that handles edge case with scalar deserialization """
+    """Helper conversion function that handles edge case with scalar deserialization"""
     if size:
         return torch.as_tensor(array, dtype=dtype).view(*size)
     else:
@@ -162,12 +167,12 @@ def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Ten
         stats_size[-1] = 1
         stats_count = np.prod(stats_size)
 
-        means = serialized_tensor.buffer[-2 * NUM_BYTES_FLOAT32 * stats_count: -NUM_BYTES_FLOAT32 * stats_count]
-        stds = serialized_tensor.buffer[-NUM_BYTES_FLOAT32 * stats_count:]
+        means = serialized_tensor.buffer[-2 * NUM_BYTES_FLOAT32 * stats_count : -NUM_BYTES_FLOAT32 * stats_count]
+        stds = serialized_tensor.buffer[-NUM_BYTES_FLOAT32 * stats_count :]
         means = construct_torch_tensor(np.frombuffer(means, dtype=np.float32), stats_size)
         stds = construct_torch_tensor(np.frombuffer(stds, dtype=np.float32), stats_size)
 
-        array = np.frombuffer(serialized_tensor.buffer[:-8 * stats_count], dtype=np.float16)
+        array = np.frombuffer(serialized_tensor.buffer[: -8 * stats_count], dtype=np.float16)
         tensor = construct_torch_tensor(array, serialized_tensor.size, torch.float32).mul_(stds).add_(means)
 
     elif serialized_tensor.compression == CompressionType.FLOAT16:
@@ -194,7 +199,7 @@ def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Ten
 
 
 def get_nbytes_per_value(dtype: torch.dtype, compression: CompressionType) -> int:
-    """ returns the number of bytes per value for a given tensor (excluding metadata) """
+    """returns the number of bytes per value for a given tensor (excluding metadata)"""
     if compression in (CompressionType.QUANTILE_8BIT, CompressionType.UNIFORM_8BIT):
         return 1
     elif compression in (CompressionType.FLOAT16, CompressionType.MEANSTD_16BIT):
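
Note (not part of the diff): a round-trip sketch for the (de)serialization helpers above, assuming they are importable from `hivemind.utils.compression` and that `CompressionType` comes from `hivemind.proto.runtime_pb2`.

```python
import torch
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor

tensor = torch.randn(128, 16)  # float32, as required by the lossy compression modes

proto = serialize_torch_tensor(tensor, CompressionType.FLOAT16)  # runtime_pb2.Tensor message
restored = deserialize_torch_tensor(proto)

assert restored.shape == tensor.shape
print((restored - tensor).abs().max())  # small error introduced by the 16-bit encoding
```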

+ 6 - 4
hivemind/utils/crypto.py

@@ -63,10 +63,11 @@ class RSAPrivateKey(PrivateKey):
     def __getstate__(self):
         state = self.__dict__.copy()
         # Serializes the private key to make the class instances picklable
-        state['_private_key'] = self._private_key.private_bytes(
+        state["_private_key"] = self._private_key.private_bytes(
             encoding=serialization.Encoding.PEM,
             format=serialization.PrivateFormat.OpenSSH,
-            encryption_algorithm=serialization.NoEncryption())
+            encryption_algorithm=serialization.NoEncryption(),
+        )
         return state
 
     def __setstate__(self, state):
@@ -91,11 +92,12 @@ class RSAPublicKey(PublicKey):
 
     def to_bytes(self) -> bytes:
         return self._public_key.public_bytes(
-            encoding=serialization.Encoding.OpenSSH, format=serialization.PublicFormat.OpenSSH)
+            encoding=serialization.Encoding.OpenSSH, format=serialization.PublicFormat.OpenSSH
+        )
 
     @classmethod
     def from_bytes(cls, key: bytes) -> RSAPublicKey:
         key = serialization.load_ssh_public_key(key)
         if not isinstance(key, rsa.RSAPublicKey):
-            raise ValueError(f'Expected an RSA public key, got {key}')
+            raise ValueError(f"Expected an RSA public key, got {key}")
         return cls(key)

+ 54 - 30
hivemind/utils/grpc.py

@@ -20,12 +20,12 @@ logger = get_logger(__name__)
 Stub = TypeVar("Stub")
 
 GRPC_KEEPALIVE_OPTIONS = (
-    ('grpc.keepalive_time_ms', 60 * 1000),
-    ('grpc.keepalive_timeout_ms', 60 * 1000),
-    ('grpc.keepalive_permit_without_calls', True),
-    ('grpc.http2.max_pings_without_data', 0),
-    ('grpc.http2.min_time_between_pings_ms', 30 * 1000),
-    ('grpc.http2.min_ping_interval_without_data_ms', 10 * 1000),
+    ("grpc.keepalive_time_ms", 60 * 1000),
+    ("grpc.keepalive_timeout_ms", 60 * 1000),
+    ("grpc.keepalive_permit_without_calls", True),
+    ("grpc.http2.max_pings_without_data", 0),
+    ("grpc.http2.min_time_between_pings_ms", 30 * 1000),
+    ("grpc.http2.min_ping_interval_without_data_ms", 10 * 1000),
 )
 
 
@@ -44,6 +44,7 @@ class ChannelCache(TimedStorage[ChannelInfo, Tuple[Union[grpc.Channel, grpc.aio.
     Unlike TimedStorage, ChannelCache actively evicts stale channels even if the cache is not accessed
     Unlike grpc._simple_stubs.ChannelCache, this implementation supports aio and does not forcibly close active channels
     """
+
     MAXIMUM_CHANNELS = int(os.environ.get("GRPC_PYTHON_MANAGED_CHANNEL_MAXIMUM", 4096))
     EVICTION_PERIOD_SECONDS = float(os.environ.get("GRPC_PYTHON_MANAGED_CHANNEL_EVICTION_SECONDS", 10 * 60))
     logger.debug(f"Eviction period = {EVICTION_PERIOD_SECONDS}s, max channels = {MAXIMUM_CHANNELS}")
@@ -57,13 +58,13 @@ class ChannelCache(TimedStorage[ChannelInfo, Tuple[Union[grpc.Channel, grpc.aio.
         assert _created_as_singleton, f"Please use {self.__class__.__name__}.get_singleton()"
         super().__init__(maxsize=self.MAXIMUM_CHANNELS)
         self._is_active = True
-        self._nearest_expiration_time = float('inf')
+        self._nearest_expiration_time = float("inf")
         self._eviction_thread = threading.Thread(target=self._evict_stale_channels_in_background, daemon=True)
         self._eviction_thread.start()
 
     @classmethod
     def get_singleton(cls):
-        """ Get or create the channel cache for the current process """
+        """Get or create the channel cache for the current process"""
         with cls._lock:
             if cls._singleton is None or cls._singleton_pid != os.getpid():
                 if cls._singleton is not None:
@@ -72,9 +73,16 @@ class ChannelCache(TimedStorage[ChannelInfo, Tuple[Union[grpc.Channel, grpc.aio.
             return cls._singleton
 
     @classmethod
-    def get_stub(cls, target: Endpoint, stub_type: Type[Stub], *, aio: bool, options: Tuple[Tuple[str, Any]] = (),
-                 channel_credentials: Optional[grpc.ChannelCredentials] = None,
-                 compression: Optional[grpc.Compression] = None) -> Stub:
+    def get_stub(
+        cls,
+        target: Endpoint,
+        stub_type: Type[Stub],
+        *,
+        aio: bool,
+        options: Tuple[Tuple[str, Any]] = (),
+        channel_credentials: Optional[grpc.ChannelCredentials] = None,
+        compression: Optional[grpc.Compression] = None,
+    ) -> Stub:
         """
         Create a grpc channel with given options or reuse pre-existing one
 
@@ -112,28 +120,37 @@ class ChannelCache(TimedStorage[ChannelInfo, Tuple[Union[grpc.Channel, grpc.aio.
             return stubs[stub_type]
 
     @classmethod
-    def _create_channel(cls, target: Endpoint, aio: bool, extra_options: Tuple[Tuple[str, Any], ...],
-                        channel_credentials: Optional[grpc.ChannelCredentials],
-                        compression: Optional[grpc.Compression]) -> Union[grpc.Channel, grpc.aio.Channel]:
+    def _create_channel(
+        cls,
+        target: Endpoint,
+        aio: bool,
+        extra_options: Tuple[Tuple[str, Any], ...],
+        channel_credentials: Optional[grpc.ChannelCredentials],
+        compression: Optional[grpc.Compression],
+    ) -> Union[grpc.Channel, grpc.aio.Channel]:
         namespace = grpc.aio if aio else grpc
 
         options = extra_options + GRPC_KEEPALIVE_OPTIONS
 
         if channel_credentials is None:
-            logger.debug(f"Creating insecure {namespace} channel with options '{options}' "
-                         f"and compression '{compression}'")
+            logger.debug(
+                f"Creating insecure {namespace} channel with options '{options}' " f"and compression '{compression}'"
+            )
             return namespace.insecure_channel(target, options=options, compression=compression)
         else:
-            logger.debug(f"Creating secure {namespace} channel with credentials '{channel_credentials}', "
-                         f"options '{options}' and compression '{compression}'")
-            return namespace.secure_channel(target, credentials=channel_credentials,
-                                            options=options, compression=compression)
+            logger.debug(
+                f"Creating secure {namespace} channel with credentials '{channel_credentials}', "
+                f"options '{options}' and compression '{compression}'"
+            )
+            return namespace.secure_channel(
+                target, credentials=channel_credentials, options=options, compression=compression
+            )
 
     def _evict_stale_channels_in_background(self):
         while self._is_active:
             now = get_dht_time()
             time_to_wait = max(0.0, self._nearest_expiration_time - now)
-            interrupted_early = self._update_eviction_evt.wait(time_to_wait if time_to_wait != float('inf') else None)
+            interrupted_early = self._update_eviction_evt.wait(time_to_wait if time_to_wait != float("inf") else None)
             if interrupted_early:
                 self._update_eviction_evt.clear()
                 continue
@@ -141,7 +158,7 @@ class ChannelCache(TimedStorage[ChannelInfo, Tuple[Union[grpc.Channel, grpc.aio.
             with self._lock:
                 self._remove_outdated()
                 _, entry = super().top()
-                self._nearest_expiration_time = entry.expiration_time if entry is not None else float('inf')
+                self._nearest_expiration_time = entry.expiration_time if entry is not None else float("inf")
 
     def _stop_background_thread(self):
         with self._lock:
@@ -161,20 +178,27 @@ class ChannelCache(TimedStorage[ChannelInfo, Tuple[Union[grpc.Channel, grpc.aio.
 STREAMING_CHUNK_SIZE_BYTES = 2 ** 16
 
 
-def split_for_streaming(serialized_tensor: runtime_pb2.Tensor, chunk_size_bytes: int = STREAMING_CHUNK_SIZE_BYTES,
-                        ) -> Iterator[runtime_pb2.Tensor]:
-    """ Split serialized_tensor into multiple chunks for gRPC streaming """
+def split_for_streaming(
+    serialized_tensor: runtime_pb2.Tensor,
+    chunk_size_bytes: int = STREAMING_CHUNK_SIZE_BYTES,
+) -> Iterator[runtime_pb2.Tensor]:
+    """Split serialized_tensor into multiple chunks for gRPC streaming"""
     buffer = memoryview(serialized_tensor.buffer)
     num_chunks = len(range(0, len(buffer), chunk_size_bytes))
     yield runtime_pb2.Tensor(
-        compression=serialized_tensor.compression, buffer=buffer[:chunk_size_bytes].tobytes(), chunks=num_chunks,
-        size=serialized_tensor.size, dtype=serialized_tensor.dtype, requires_grad=serialized_tensor.requires_grad)
+        compression=serialized_tensor.compression,
+        buffer=buffer[:chunk_size_bytes].tobytes(),
+        chunks=num_chunks,
+        size=serialized_tensor.size,
+        dtype=serialized_tensor.dtype,
+        requires_grad=serialized_tensor.requires_grad,
+    )
     for chunk_start in range(chunk_size_bytes, len(buffer), chunk_size_bytes):
-        yield runtime_pb2.Tensor(buffer=buffer[chunk_start: chunk_start + chunk_size_bytes].tobytes())
+        yield runtime_pb2.Tensor(buffer=buffer[chunk_start : chunk_start + chunk_size_bytes].tobytes())
 
 
 def combine_from_streaming(stream: Iterable[runtime_pb2.Tensor]) -> runtime_pb2.Tensor:
-    """ Restore a result of split_into_chunks into a single serialized tensor """
+    """Restore a result of split_into_chunks into a single serialized tensor"""
     stream = iter(stream)
     first_chunk = next(stream)
     serialized_tensor = runtime_pb2.Tensor()
@@ -182,5 +206,5 @@ def combine_from_streaming(stream: Iterable[runtime_pb2.Tensor]) -> runtime_pb2.
     buffer_chunks = [first_chunk.buffer]
     for tensor_part in stream:
         buffer_chunks.append(tensor_part.buffer)
-    serialized_tensor.buffer = b''.join(buffer_chunks)
+    serialized_tensor.buffer = b"".join(buffer_chunks)
     return serialized_tensor
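
Note (not part of the diff): a sketch of splitting a serialized tensor into gRPC-sized chunks and reassembling it, using the helpers above together with `serialize_torch_tensor` from `hivemind.utils.compression`.

```python
import torch
from hivemind.utils.compression import serialize_torch_tensor
from hivemind.utils.grpc import combine_from_streaming, split_for_streaming

serialized = serialize_torch_tensor(torch.randn(1000, 1000))

# only the first chunk carries the tensor metadata; subsequent chunks carry raw buffer slices
chunks = list(split_for_streaming(serialized, chunk_size_bytes=2 ** 16))
restored = combine_from_streaming(chunks)

assert restored.buffer == serialized.buffer
```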

+ 2 - 1
hivemind/utils/limits.py

@@ -4,9 +4,10 @@ logger = get_logger(__name__)
 
 
 def increase_file_limit(new_soft=2 ** 15, new_hard=2 ** 15):
-    """ Increase the maximum number of open files. On Linux, this allows spawning more processes/threads. """
+    """Increase the maximum number of open files. On Linux, this allows spawning more processes/threads."""
     try:
         import resource  # local import to avoid ImportError for Windows users
+
         soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
         new_soft = max(soft, new_soft)
         new_hard = max(hard, new_hard)

+ 8 - 5
hivemind/utils/logging.py

@@ -4,12 +4,15 @@ import os
 
 def get_logger(module_name: str) -> logging.Logger:
     # trim package name
-    name_without_prefix = '.'.join(module_name.split('.')[1:])
-    loglevel = os.getenv('LOGLEVEL', 'INFO')
+    name_without_prefix = ".".join(module_name.split(".")[1:])
+    loglevel = os.getenv("LOGLEVEL", "INFO")
 
-    logging.addLevelName(logging.WARNING, 'WARN')
-    formatter = logging.Formatter(fmt='[{asctime}.{msecs:03.0f}][{levelname}][{name}.{funcName}:{lineno}] {message}',
-                                  style='{', datefmt='%Y/%m/%d %H:%M:%S')
+    logging.addLevelName(logging.WARNING, "WARN")
+    formatter = logging.Formatter(
+        fmt="[{asctime}.{msecs:03.0f}][{levelname}][{name}.{funcName}:{lineno}] {message}",
+        style="{",
+        datefmt="%Y/%m/%d %H:%M:%S",
+    )
     handler = logging.StreamHandler()
     handler.setFormatter(formatter)
     logger = logging.getLogger(name_without_prefix)

+ 29 - 18
hivemind/utils/mpfuture.py

@@ -11,7 +11,7 @@ import uuid
 from enum import Enum, auto
 from typing import Generic, TypeVar, Dict, Optional, Any, Callable
 
-import torch    # used for py3.7-compatible shared memory
+import torch  # used for py3.7-compatible shared memory
 
 from hivemind.utils.logging import get_logger
 
@@ -19,7 +19,7 @@ from hivemind.utils.logging import get_logger
 logger = get_logger(__name__)
 
 # flavour types
-ResultType = TypeVar('ResultType')
+ResultType = TypeVar("ResultType")
 PID, UID, State, PipeEnd = int, int, str, mp.connection.Connection
 ALL_STATES = base.PENDING, base.RUNNING, base.FINISHED, base.CANCELLED, base.CANCELLED_AND_NOTIFIED
 TERMINAL_STATES = {base.FINISHED, base.CANCELLED, base.CANCELLED_AND_NOTIFIED}
@@ -57,6 +57,7 @@ class MPFuture(base.Future, Generic[ResultType]):
        - MPFuture is deterministic if only one process can call set_result/set_exception/set_running_or_notify_cancel
          and only the origin process can call result/exception/cancel.
     """
+
     _initialization_lock = mp.Lock()  # global lock that prevents simultaneous initialization of two processes
     _update_lock = mp.Lock()  # global lock that prevents simultaneous writing to the same pipe
     _global_sender_pipe: Optional[PipeEnd] = None  # a pipe that is used to send results/exceptions to this process
@@ -67,10 +68,11 @@ class MPFuture(base.Future, Generic[ResultType]):
     def __init__(self, use_lock: bool = True, loop: Optional[asyncio.BaseEventLoop] = None):
         self._origin_pid, self._uid = os.getpid(), uuid.uuid4().int
         self._shared_state_code = torch.empty([], dtype=torch.uint8).share_memory_()
-        self._state_cache:  Dict[State, State] = {}  # mapping from global to cached local future used that makes updates immediately
+        self._state_cache: Dict[State, State] = {}
+        # mapping from global to cached local future used that makes updates immediately
         # available on setter side; dictionary-based cache works because future can visit any state at most once
 
-        base.Future.__init__(self)   # parent init is deferred because it uses self._shared_state_code
+        base.Future.__init__(self)  # parent init is deferred because it uses self._shared_state_code
         self._state, self._result, self._exception = base.PENDING, None, None
         self._use_lock = use_lock
 
@@ -122,8 +124,9 @@ class MPFuture(base.Future, Generic[ResultType]):
 
         receiver_pipe, cls._global_sender_pipe = mp.Pipe(duplex=False)
         cls._active_pid, cls._active_futures = pid, {}
-        cls._pipe_waiter_thread = threading.Thread(target=cls._process_updates_in_background, args=[receiver_pipe],
-                                                   name=f'{__name__}.BACKEND', daemon=True)
+        cls._pipe_waiter_thread = threading.Thread(
+            target=cls._process_updates_in_background, args=[receiver_pipe], name=f"{__name__}.BACKEND", daemon=True
+        )
         cls._pipe_waiter_thread.start()
 
     @classmethod
@@ -148,7 +151,7 @@ class MPFuture(base.Future, Generic[ResultType]):
                 logger.exception(f"Could not retrieve update: caught {repr(e)} (pid={pid})")
 
     def _send_update(self, update_type: UpdateType, payload: Any = None):
-        """ This method sends result, exception or cancel to the MPFuture origin. """
+        """This method sends result, exception or cancel to the MPFuture origin."""
         with MPFuture._update_lock if self._use_lock else nullcontext():
             self._sender_pipe.send((self._uid, update_type, payload))
 
@@ -190,7 +193,9 @@ class MPFuture(base.Future, Generic[ResultType]):
         elif self._state == base.CANCELLED:
             return False
         else:
-            raise InvalidStateError(f"Can't set_running_or_notify_cancel when future is in {self._state} ({self._uid})")
+            raise InvalidStateError(
+                f"Can't set_running_or_notify_cancel when future is in {self._state} ({self._uid})"
+            )
 
     def result(self, timeout: Optional[float] = None) -> ResultType:
         if self._state not in TERMINAL_STATES:
@@ -240,22 +245,28 @@ class MPFuture(base.Future, Generic[ResultType]):
             raise asyncio.CancelledError()
 
     def __del__(self):
-        if getattr(self, '_origin_pid', None) == os.getpid():
+        if getattr(self, "_origin_pid", None) == os.getpid():
             MPFuture._active_futures.pop(self._uid, None)
-        if getattr(self, '_aio_event', None):
+        if getattr(self, "_aio_event", None):
             self._aio_event.set()
 
     def __getstate__(self):
-        return dict(_sender_pipe=self._sender_pipe, _shared_state_code=self._shared_state_code,
-                    _origin_pid=self._origin_pid, _uid=self._uid, _use_lock=self._use_lock,
-                    _result=self._result, _exception=self._exception)
+        return dict(
+            _sender_pipe=self._sender_pipe,
+            _shared_state_code=self._shared_state_code,
+            _origin_pid=self._origin_pid,
+            _uid=self._uid,
+            _use_lock=self._use_lock,
+            _result=self._result,
+            _exception=self._exception,
+        )
 
     def __setstate__(self, state):
-        self._sender_pipe = state['_sender_pipe']
-        self._shared_state_code = state['_shared_state_code']
-        self._origin_pid, self._uid = state['_origin_pid'], state['_uid']
-        self._result, self._exception = state['_result'], state['_exception']
-        self._use_lock = state['_use_lock']
+        self._sender_pipe = state["_sender_pipe"]
+        self._shared_state_code = state["_shared_state_code"]
+        self._origin_pid, self._uid = state["_origin_pid"], state["_uid"]
+        self._result, self._exception = state["_result"], state["_exception"]
+        self._use_lock = state["_use_lock"]
 
         self._waiters, self._done_callbacks = [], []
         self._condition = threading.Condition()
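
A sketch of the intended MPFuture flow, assumed from the class docstring above: one non-origin process fulfils the future via the setters, and only the origin process reads the result.

```python
import multiprocessing as mp

from hivemind.utils.mpfuture import MPFuture


def worker(future: MPFuture):
    future.set_result(42)  # delivered to the origin process through the shared pipe


if __name__ == "__main__":
    future = MPFuture()
    mp.Process(target=worker, args=(future,)).start()
    print(future.result(timeout=5))  # -> 42
```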

+ 10 - 17
hivemind/utils/nested.py

@@ -55,20 +55,11 @@ def nested_pack(flat, structure):
 
 def _nested_pack(flat_iter, structure):
     if is_namedtuple(structure):
-        return type(structure)(*[
-            _nested_pack(flat_iter, x)
-            for x in structure]
-                               )
+        return type(structure)(*[_nested_pack(flat_iter, x) for x in structure])
     elif isinstance(structure, (list, tuple)):
-        return type(structure)(
-            _nested_pack(flat_iter, x)
-            for x in structure
-        )
+        return type(structure)(_nested_pack(flat_iter, x) for x in structure)
     elif isinstance(structure, dict):
-        return {
-            k: _nested_pack(flat_iter, v)
-            for k, v in sorted(structure.items())
-        }
+        return {k: _nested_pack(flat_iter, v) for k, v in sorted(structure.items())}
     else:
         return next(flat_iter)
 
@@ -77,19 +68,21 @@ def is_namedtuple(x):
     """Checks if x is a namedtuple instance. Taken from https://stackoverflow.com/a/2166841 ."""
     t = type(x)
     b = t.__bases__
-    if len(b) != 1 or b[0] != tuple: return False
-    f = getattr(t, '_fields', None)
-    if not isinstance(f, tuple): return False
+    if len(b) != 1 or b[0] != tuple:
+        return False
+    f = getattr(t, "_fields", None)
+    if not isinstance(f, tuple):
+        return False
     return all(type(n) == str for n in f)
 
 
 def nested_map(fn, *t):
     # Check arguments.
     if not t:
-        raise ValueError('Expected 2+ arguments, got 1')
+        raise ValueError("Expected 2+ arguments, got 1")
     for i in range(1, len(t)):
         if not nested_compare(t[0], t[i]):
-            msg = 'Nested structure of %r and %r differs'
+            msg = "Nested structure of %r and %r differs"
             raise ValueError(msg % (t[0], t[i]))
 
     # Map.
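
A minimal sketch of these nested utilities (illustrative): nested_map applies a function to every leaf while preserving the dict/list/tuple/namedtuple structure.

```python
from hivemind.utils.nested import nested_map

structure = {"weights": [1, 2, 3], "meta": ("a", "b")}
doubled = nested_map(lambda leaf: leaf * 2, structure)
assert doubled == {"weights": [2, 4, 6], "meta": ("aa", "bb")}
```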

+ 13 - 13
hivemind/utils/networking.py

@@ -8,43 +8,43 @@ from multiaddr import Multiaddr
 
 Hostname, Port = str, int  # flavour types
 Endpoint = str  # e.g. 1.2.3.4:1337 or [2a21:6c8:b192:2105]:8888, https://networkengineering.stackexchange.com/a/9435
-LOCALHOST = '127.0.0.1'
+LOCALHOST = "127.0.0.1"
 
 
 def get_port(endpoint: Endpoint) -> Optional[Port]:
-    """ get port or None if port is undefined """
+    """get port or None if port is undefined"""
     # TODO: find a standard way to get port, make sure it works in malformed ports
     try:
-        return int(endpoint[endpoint.rindex(':') + 1:], base=10)
+        return int(endpoint[endpoint.rindex(":") + 1 :], base=10)
     except ValueError:  # :* or not specified
         return None
 
 
 def replace_port(endpoint: Endpoint, new_port: Port) -> Endpoint:
-    assert endpoint.endswith(':*') or get_port(endpoint) is not None, endpoint
+    assert endpoint.endswith(":*") or get_port(endpoint) is not None, endpoint
     return f"{endpoint[:endpoint.rindex(':')]}:{new_port}"
 
 
 def strip_port(endpoint: Endpoint) -> Hostname:
-    """ Removes port from the end of endpoint. If port is not specified, does nothing """
-    maybe_port = endpoint[endpoint.rindex(':') + 1:]
-    return endpoint[:endpoint.rindex(':')] if maybe_port.isdigit() or maybe_port == '*' else endpoint
+    """Removes port from the end of endpoint. If port is not specified, does nothing"""
+    maybe_port = endpoint[endpoint.rindex(":") + 1 :]
+    return endpoint[: endpoint.rindex(":")] if maybe_port.isdigit() or maybe_port == "*" else endpoint
 
 
 def find_open_port(params=(socket.AF_INET, socket.SOCK_STREAM), opt=(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)):
-    """ Finds a tcp port that can be occupied with a socket with *params and use *opt options """
+    """Finds a tcp port that can be occupied with a socket with *params and use *opt options"""
     try:
         with closing(socket.socket(*params)) as sock:
-            sock.bind(('', 0))
+            sock.bind(("", 0))
             sock.setsockopt(*opt)
             return sock.getsockname()[1]
     except Exception as e:
         raise e
 
 
-def choose_ip_address(maddrs: Sequence[Multiaddr],
-                      prefer_global: bool = True,
-                      protocol_priority: Sequence[str] = ('ip4', 'ip6')) -> Hostname:
+def choose_ip_address(
+    maddrs: Sequence[Multiaddr], prefer_global: bool = True, protocol_priority: Sequence[str] = ("ip4", "ip6")
+) -> Hostname:
     """
     Currently, some components of hivemind are not converted to work over libp2p and use classical networking.
     To allow other peers reach a server when needed, these components announce a machine's IP address.
@@ -69,4 +69,4 @@ def choose_ip_address(maddrs: Sequence[Multiaddr],
                     if ip_address(value_for_protocol).is_global == need_global:
                         return value_for_protocol
 
-    raise ValueError(f'No IP address found among given multiaddrs: {maddrs}')
+    raise ValueError(f"No IP address found among given multiaddrs: {maddrs}")

+ 3 - 2
hivemind/utils/serializer.py

@@ -31,8 +31,9 @@ class MSGPackSerializer(SerializerBase):
         assert isinstance(type_code, int), "Please specify a (unique) int type code"
 
         def wrap(wrapped_type: type):
-            assert callable(getattr(wrapped_type, 'packb', None)) and callable(getattr(wrapped_type, 'unpackb', None)), \
-                f"Every ext_type must have 2 methods: packb(self) -> bytes and classmethod unpackb(cls, bytes)"
+            assert callable(getattr(wrapped_type, "packb", None)) and callable(
+                getattr(wrapped_type, "unpackb", None)
+            ), f"Every ext_type must have 2 methods: packb(self) -> bytes and classmethod unpackb(cls, bytes)"
             if type_code in cls._ext_type_codes:
                 logger.warning(f"{cls.__name__}: type {type_code} is already registered, overwriting.")
             cls._ext_type_codes[type_code], cls._ext_types[wrapped_type] = wrapped_type, type_code
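
A sketch of the ext-type contract enforced by the assert above (illustrative; the decorator name MSGPackSerializer.ext_serializable and the dumps/loads interface are assumed from the surrounding module): a registered type provides packb(self) -> bytes and a classmethod unpackb(cls, data).

```python
from hivemind.utils.serializer import MSGPackSerializer


@MSGPackSerializer.ext_serializable(0x7B)  # assumed decorator; any unused int type code
class Point:
    def __init__(self, x: int, y: int):
        self.x, self.y = x, y

    def packb(self) -> bytes:
        return MSGPackSerializer.dumps((self.x, self.y))

    @classmethod
    def unpackb(cls, data: bytes) -> "Point":
        return cls(*MSGPackSerializer.loads(data))


restored = MSGPackSerializer.loads(MSGPackSerializer.dumps(Point(1, 2)))
assert (restored.x, restored.y) == (1, 2)
```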

+ 16 - 10
hivemind/utils/tensor_descr.py

@@ -34,19 +34,20 @@ class TensorDescriptor(DescriptorBase):
 
     @classmethod
     def from_tensor(cls, tensor: torch.Tensor):
-        return cls(tensor.shape, tensor.dtype, tensor.layout, tensor.device, tensor.requires_grad,
-                   safe_check_pinned(tensor))
+        return cls(
+            tensor.shape, tensor.dtype, tensor.layout, tensor.device, tensor.requires_grad, _safe_check_pinned(tensor)
+        )
 
     def make_empty(self, **kwargs):
         properties = asdict(self)
         properties.update(kwargs)
-        properties.pop('compression')
+        properties.pop("compression")
         return torch.empty(**properties)
 
 
 @dataclass(repr=True, frozen=True)
 class BatchTensorDescriptor(TensorDescriptor):
-    """ torch.Tensor with a variable 0-th dimension, used to describe batched data """
+    """torch.Tensor with a variable 0-th dimension, used to describe batched data"""
 
     def __init__(self, *instance_size, **kwargs):  # compatibility: allow initializing with *size
         if len(instance_size) == 1 and isinstance(instance_size[0], (list, tuple, torch.Size)):
@@ -55,18 +56,23 @@ class BatchTensorDescriptor(TensorDescriptor):
 
     @classmethod
     def from_tensor(cls, tensor: torch.Tensor, compression=CompressionType.NONE):
-        return cls(*tensor.shape[1:], dtype=tensor.dtype, layout=tensor.layout,
-                   device=tensor.device, requires_grad=tensor.requires_grad,
-                   pin_memory=safe_check_pinned(tensor),
-                   compression=compression if tensor.is_floating_point() else CompressionType.NONE)
+        return cls(
+            *tensor.shape[1:],
+            dtype=tensor.dtype,
+            layout=tensor.layout,
+            device=tensor.device,
+            requires_grad=tensor.requires_grad,
+            pin_memory=_safe_check_pinned(tensor),
+            compression=compression if tensor.is_floating_point() else CompressionType.NONE
+        )
 
     def make_empty(self, *batch_size, **kwargs):
         assert self.shape[0] is None, "Make sure 0-th dimension is not specified (set to None)"
         return super().make_empty(size=(*batch_size, *self.shape[1:]), **kwargs)
 
 
-def safe_check_pinned(tensor: torch.Tensor) -> bool:
-    """ Check whether or not a tensor is pinned. If torch cannot initialize cuda, returns False instead of error. """
+def _safe_check_pinned(tensor: torch.Tensor) -> bool:
+    """Check whether or not a tensor is pinned. If torch cannot initialize cuda, returns False instead of error."""
     try:
         return torch.cuda.is_available() and tensor.is_pinned()
     except RuntimeError:
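
A sketch of BatchTensorDescriptor in use (illustrative): describe a batched tensor once with a variable 0-th dimension, then allocate empty buffers for any batch size.

```python
import torch

from hivemind.utils.tensor_descr import BatchTensorDescriptor

sample = torch.randn(16, 512)                      # a batch of 16 feature vectors
descr = BatchTensorDescriptor.from_tensor(sample)  # the batch dimension becomes None
buffer = descr.make_empty(64)                      # uninitialized tensor of shape (64, 512)
assert buffer.shape == (64, 512)
```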

+ 13 - 9
hivemind/utils/timed_storage.py

@@ -6,8 +6,8 @@ from contextlib import contextmanager
 from typing import TypeVar, Generic, Optional, Dict, List, Iterator, Tuple
 from dataclasses import dataclass
 
-KeyType = TypeVar('KeyType')
-ValueType = TypeVar('ValueType')
+KeyType = TypeVar("KeyType")
+ValueType = TypeVar("ValueType")
 get_dht_time = time.time  # a global (weakly synchronized) time
 MAX_DHT_TIME_DISCREPANCY_SECONDS = 3  # max allowed difference between get_dht_time for two DHT nodes
 DHTExpiration = float
@@ -46,7 +46,8 @@ class HeapEntry(Generic[KeyType]):
 
 
 class TimedStorage(Generic[KeyType, ValueType]):
-    """ A dictionary that maintains up to :maxsize: key-value-expiration tuples until their expiration_time """
+    """A dictionary that maintains up to :maxsize: key-value-expiration tuples until their expiration_time"""
+
     frozen = False  # can be set to True. If true, do not remove outdated elements
 
     def __init__(self, maxsize: Optional[int] = None):
@@ -56,8 +57,11 @@ class TimedStorage(Generic[KeyType, ValueType]):
         self.key_to_heap: Dict[KeyType, HeapEntry[KeyType]] = dict()
 
     def _remove_outdated(self):
-        while not self.frozen and self.expiration_heap and (self.expiration_heap[ROOT].expiration_time < get_dht_time()
-                                                            or len(self.data) > self.maxsize):
+        while (
+            not self.frozen
+            and self.expiration_heap
+            and (self.expiration_heap[ROOT].expiration_time < get_dht_time() or len(self.data) > self.maxsize)
+        ):
             heap_entry = heapq.heappop(self.expiration_heap)
             if self.key_to_heap.get(heap_entry.key) == heap_entry:
                 del self.data[heap_entry.key], self.key_to_heap[heap_entry.key]
@@ -81,19 +85,19 @@ class TimedStorage(Generic[KeyType, ValueType]):
         return True
 
     def get(self, key: KeyType) -> Optional[ValueWithExpiration[ValueType]]:
-        """ Get a value corresponding to a key if that (key, value) pair was previously stored under this key. """
+        """Get a value corresponding to a key if that (key, value) pair was previously stored under this key."""
         self._remove_outdated()
         if key in self.data:
             return self.data[key]
         return None
 
     def items(self) -> Iterator[Tuple[KeyType, ValueWithExpiration[ValueType]]]:
-        """ Iterate over (key, value, expiration_time) tuples stored in this storage """
+        """Iterate over (key, value, expiration_time) tuples stored in this storage"""
         self._remove_outdated()
         return ((key, value_and_expiration) for key, value_and_expiration in self.data.items())
 
     def top(self) -> Tuple[Optional[KeyType], Optional[ValueWithExpiration[ValueType]]]:
-        """ Return the entry with earliest expiration or None if there isn't any """
+        """Return the entry with earliest expiration or None if there isn't any"""
         self._remove_outdated()
         if self.data:
             # skip leftover "ghost" entries until first real entry
@@ -129,7 +133,7 @@ class TimedStorage(Generic[KeyType, ValueType]):
 
     @contextmanager
     def freeze(self):
-        """ Temporarily cease to ._remove_outdated() elements inside this context to ensure consistency """
+        """Temporarily cease to ._remove_outdated() elements inside this context to ensure consistency"""
         prev_frozen, self.frozen = self.frozen, True
         try:
             yield self
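
A sketch of TimedStorage in use (illustrative; the store() signature is assumed from this file): entries live until their expiration_time, measured in get_dht_time() units, and are evicted lazily on access.

```python
from hivemind.utils.timed_storage import TimedStorage, get_dht_time

storage: TimedStorage[str, int] = TimedStorage(maxsize=100)
storage.store("batch_size", 256, expiration_time=get_dht_time() + 30)

entry = storage.get("batch_size")
assert entry is not None and entry.value == 256

key, earliest = storage.top()  # entry with the earliest expiration, or (None, None)
assert key == "batch_size"
```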

+ 3 - 0
pyproject.toml

@@ -0,0 +1,3 @@
+[tool.black]
+line-length = 119
+required-version = "21.6b0"
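
With this configuration, running "black ." from the repository root reformats the codebase to the 119-character limit, and the required-version pin keeps local formatting consistent with the 21.6b0 release that the new style check presumably installs.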

+ 1 - 0
requirements-dev.txt

@@ -5,4 +5,5 @@ pytest-cov
 codecov
 tqdm
 scikit-learn
+black==21.6b0
 psutil

+ 71 - 62
setup.py

@@ -14,9 +14,9 @@ from setuptools import find_packages, setup
 from setuptools.command.build_py import build_py
 from setuptools.command.develop import develop
 
-P2PD_VERSION = 'v0.3.1'
-P2PD_CHECKSUM = '15292b880c6b31f5b3c36084b3acc17f'
-LIBP2P_TAR_URL = f'https://github.com/learning-at-home/go-libp2p-daemon/archive/refs/tags/{P2PD_VERSION}.tar.gz'
+P2PD_VERSION = "v0.3.1"
+P2PD_CHECKSUM = "15292b880c6b31f5b3c36084b3acc17f"
+LIBP2P_TAR_URL = f"https://github.com/learning-at-home/go-libp2p-daemon/archive/refs/tags/{P2PD_VERSION}.tar.gz"
 
 here = os.path.abspath(os.path.dirname(__file__))
 
@@ -32,61 +32,68 @@ def md5(fname, chunk_size=4096):
 def proto_compile(output_path):
     import grpc_tools.protoc
 
-    cli_args = ['grpc_tools.protoc',
-                '--proto_path=hivemind/proto', f'--python_out={output_path}',
-                f'--grpc_python_out={output_path}'] + glob.glob('hivemind/proto/*.proto')
+    cli_args = [
+        "grpc_tools.protoc",
+        "--proto_path=hivemind/proto",
+        f"--python_out={output_path}",
+        f"--grpc_python_out={output_path}",
+    ] + glob.glob("hivemind/proto/*.proto")
 
     code = grpc_tools.protoc.main(cli_args)
     if code:  # hint: if you get this error in jupyter, run in console for richer error message
         raise ValueError(f"{' '.join(cli_args)} finished with exit code {code}")
     # Make pb2 imports in generated scripts relative
-    for script in glob.iglob(f'{output_path}/*.py'):
-        with open(script, 'r+') as file:
+    for script in glob.iglob(f"{output_path}/*.py"):
+        with open(script, "r+") as file:
             code = file.read()
             file.seek(0)
-            file.write(re.sub(r'\n(import .+_pb2.*)', 'from . \\1', code))
+            file.write(re.sub(r"\n(import .+_pb2.*)", "from . \\1", code))
             file.truncate()
 
 
 def build_p2p_daemon():
-    result = subprocess.run("go version", capture_output=True, shell=True).stdout.decode('ascii', 'replace')
-    m = re.search(r'^go version go([\d.]+)', result)
+    result = subprocess.run("go version", capture_output=True, shell=True).stdout.decode("ascii", "replace")
+    m = re.search(r"^go version go([\d.]+)", result)
 
     if m is None:
-        raise FileNotFoundError('Could not find golang installation')
+        raise FileNotFoundError("Could not find golang installation")
     version = parse_version(m.group(1))
     if version < parse_version("1.13"):
-        raise EnvironmentError(f'Newer version of go required: must be >= 1.13, found {version}')
+        raise EnvironmentError(f"Newer version of go required: must be >= 1.13, found {version}")
 
     with tempfile.TemporaryDirectory() as tempdir:
-        dest = os.path.join(tempdir, 'libp2p-daemon.tar.gz')
+        dest = os.path.join(tempdir, "libp2p-daemon.tar.gz")
         urllib.request.urlretrieve(LIBP2P_TAR_URL, dest)
 
-        with tarfile.open(dest, 'r:gz') as tar:
+        with tarfile.open(dest, "r:gz") as tar:
             tar.extractall(tempdir)
 
-        result = subprocess.run(f'go build -o {shlex.quote(os.path.join(here, "hivemind", "hivemind_cli", "p2pd"))}',
-                                cwd=os.path.join(tempdir, f'go-libp2p-daemon-{P2PD_VERSION[1:]}', 'p2pd'), shell=True)
+        result = subprocess.run(
+            f'go build -o {shlex.quote(os.path.join(here, "hivemind", "hivemind_cli", "p2pd"))}',
+            cwd=os.path.join(tempdir, f"go-libp2p-daemon-{P2PD_VERSION[1:]}", "p2pd"),
+            shell=True,
+        )
 
         if result.returncode:
-            raise RuntimeError('Failed to build or install libp2p-daemon:'
-                               f' exited with status code: {result.returncode}')
+            raise RuntimeError(
+                "Failed to build or install libp2p-daemon:" f" exited with status code: {result.returncode}"
+            )
 
 
 def download_p2p_daemon():
-    install_path = os.path.join(here, 'hivemind', 'hivemind_cli')
-    binary_path = os.path.join(install_path, 'p2pd')
+    install_path = os.path.join(here, "hivemind", "hivemind_cli")
+    binary_path = os.path.join(install_path, "p2pd")
     if not os.path.exists(binary_path) or md5(binary_path) != P2PD_CHECKSUM:
-        print('Downloading Peer to Peer Daemon')
-        url = f'https://github.com/learning-at-home/go-libp2p-daemon/releases/download/{P2PD_VERSION}/p2pd'
+        print("Downloading Peer to Peer Daemon")
+        url = f"https://github.com/learning-at-home/go-libp2p-daemon/releases/download/{P2PD_VERSION}/p2pd"
         urllib.request.urlretrieve(url, binary_path)
         os.chmod(binary_path, 0o777)
         if md5(binary_path) != P2PD_CHECKSUM:
-            raise RuntimeError(f'Downloaded p2pd binary from {url} does not match with md5 checksum')
+            raise RuntimeError(f"Downloaded p2pd binary from {url} does not match with md5 checksum")
 
 
 class BuildPy(build_py):
-    user_options = build_py.user_options + [('buildgo', None, "Builds p2pd from source")]
+    user_options = build_py.user_options + [("buildgo", None, "Builds p2pd from source")]
 
     def initialize_options(self):
         super().initialize_options()
@@ -100,70 +107,72 @@ class BuildPy(build_py):
 
         super().run()
 
-        proto_compile(os.path.join(self.build_lib, 'hivemind', 'proto'))
+        proto_compile(os.path.join(self.build_lib, "hivemind", "proto"))
 
 
 class Develop(develop):
     def run(self):
-        self.reinitialize_command('build_py', build_lib=here)
-        self.run_command('build_py')
+        self.reinitialize_command("build_py", build_lib=here)
+        self.run_command("build_py")
         super().run()
 
 
-with open('requirements.txt') as requirements_file:
+with open("requirements.txt") as requirements_file:
     install_requires = list(map(str, parse_requirements(requirements_file)))
 
 # loading version from setup.py
-with codecs.open(os.path.join(here, 'hivemind/__init__.py'), encoding='utf-8') as init_file:
+with codecs.open(os.path.join(here, "hivemind/__init__.py"), encoding="utf-8") as init_file:
     version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", init_file.read(), re.M)
     version_string = version_match.group(1)
 
 extras = {}
 
-with open('requirements-dev.txt') as dev_requirements_file:
-    extras['dev'] = list(map(str, parse_requirements(dev_requirements_file)))
+with open("requirements-dev.txt") as dev_requirements_file:
+    extras["dev"] = list(map(str, parse_requirements(dev_requirements_file)))
 
-with open('requirements-docs.txt') as docs_requirements_file:
-    extras['docs'] = list(map(str, parse_requirements(docs_requirements_file)))
+with open("requirements-docs.txt") as docs_requirements_file:
+    extras["docs"] = list(map(str, parse_requirements(docs_requirements_file)))
 
-extras['all'] = extras['dev'] + extras['docs']
+extras["all"] = extras["dev"] + extras["docs"]
 
 setup(
-    name='hivemind',
+    name="hivemind",
     version=version_string,
-    cmdclass={'build_py': BuildPy, 'develop': Develop},
-    description='Decentralized deep learning in PyTorch',
-    long_description='Decentralized deep learning in PyTorch. Built to train giant models on '
-                     'thousands of volunteers across the world.',
-    author='Learning@home & contributors',
-    author_email='mryabinin0@gmail.com',
+    cmdclass={"build_py": BuildPy, "develop": Develop},
+    description="Decentralized deep learning in PyTorch",
+    long_description="Decentralized deep learning in PyTorch. Built to train giant models on "
+    "thousands of volunteers across the world.",
+    author="Learning@home & contributors",
+    author_email="mryabinin0@gmail.com",
     url="https://github.com/learning-at-home/hivemind",
-    packages=find_packages(exclude=['tests']),
-    package_data={'hivemind': ['proto/*', 'hivemind_cli/*']},
+    packages=find_packages(exclude=["tests"]),
+    package_data={"hivemind": ["proto/*", "hivemind_cli/*"]},
     include_package_data=True,
-    license='MIT',
-    setup_requires=['grpcio-tools'],
+    license="MIT",
+    setup_requires=["grpcio-tools"],
     install_requires=install_requires,
     extras_require=extras,
     classifiers=[
-        'Development Status :: 4 - Beta',
-        'Intended Audience :: Developers',
-        'Intended Audience :: Science/Research',
-        'License :: OSI Approved :: MIT License',
-        'Programming Language :: Python :: 3',
-        'Programming Language :: Python :: 3.7',
-        'Programming Language :: Python :: 3.8',
-        'Programming Language :: Python :: 3.9',
-        'Topic :: Scientific/Engineering',
-        'Topic :: Scientific/Engineering :: Mathematics',
-        'Topic :: Scientific/Engineering :: Artificial Intelligence',
-        'Topic :: Software Development',
-        'Topic :: Software Development :: Libraries',
-        'Topic :: Software Development :: Libraries :: Python Modules',
+        "Development Status :: 4 - Beta",
+        "Intended Audience :: Developers",
+        "Intended Audience :: Science/Research",
+        "License :: OSI Approved :: MIT License",
+        "Programming Language :: Python :: 3",
+        "Programming Language :: Python :: 3.7",
+        "Programming Language :: Python :: 3.8",
+        "Programming Language :: Python :: 3.9",
+        "Topic :: Scientific/Engineering",
+        "Topic :: Scientific/Engineering :: Mathematics",
+        "Topic :: Scientific/Engineering :: Artificial Intelligence",
+        "Topic :: Software Development",
+        "Topic :: Software Development :: Libraries",
+        "Topic :: Software Development :: Libraries :: Python Modules",
     ],
     entry_points={
-        'console_scripts': ['hivemind-server = hivemind.hivemind_cli.run_server:main', ]
+        "console_scripts": [
+            "hivemind-server = hivemind.hivemind_cli.run_server:main",
+        ]
     },
     # What does your project relate to?
-    keywords='pytorch, deep learning, machine learning, gpu, distributed computing, volunteer computing, dht',
+    keywords="pytorch, deep learning, machine learning, gpu, distributed computing, volunteer computing, dht",
 )
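
Given the extras defined above, a development install would typically be "pip install -e .[dev]" (or .[all] to also pull the documentation requirements); the build_py/develop commands take care of compiling the protobufs and fetching or building the p2pd binary.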

+ 2 - 2
tests/conftest.py

@@ -10,7 +10,7 @@ from hivemind.utils import get_logger
 logger = get_logger(__name__)
 
 
-@pytest.fixture(autouse=True, scope='session')
+@pytest.fixture(autouse=True, scope="session")
 def cleanup_children():
     yield
 
@@ -18,7 +18,7 @@ def cleanup_children():
 
     children = psutil.Process().children(recursive=True)
     if children:
-        logger.info(f'Cleaning up {len(children)} leftover child processes')
+        logger.info(f"Cleaning up {len(children)} leftover child processes")
         for child in children:
             with suppress(psutil.NoSuchProcess):
                 child.terminate()

+ 70 - 40
tests/test_allreduce.py

@@ -19,9 +19,17 @@ from hivemind.utils import deserialize_torch_tensor, ChannelCache
 @pytest.mark.asyncio
 async def test_partitioning():
     all_tensors = [
-        torch.randn(30_000, 128), torch.rand(128), torch.ones(1, 1, 1, 1, 1, 1, 8),
-        torch.ones(1, 0), torch.zeros(0), torch.zeros([]), torch.randn(65536),
-        torch.rand(512, 2048), torch.randn(1024, 1024).add(-9), torch.zeros(1020), torch.randn(4096)
+        torch.randn(30_000, 128),
+        torch.rand(128),
+        torch.ones(1, 1, 1, 1, 1, 1, 8),
+        torch.ones(1, 0),
+        torch.zeros(0),
+        torch.zeros([]),
+        torch.randn(65536),
+        torch.rand(512, 2048),
+        torch.randn(1024, 1024).add(-9),
+        torch.zeros(1020),
+        torch.randn(4096),
     ]
 
     # note: this test does _not_ use parameterization to reuse sampled tensors
@@ -46,8 +54,14 @@ async def test_partitioning():
                 await task
 
 
-@pytest.mark.parametrize("tensors", [[torch.zeros(0)], [torch.zeros(0), torch.zeros(0), torch.zeros(1)],
-                                     [torch.zeros(0), torch.zeros(999), torch.zeros(0), torch.zeros(0)]])
+@pytest.mark.parametrize(
+    "tensors",
+    [
+        [torch.zeros(0)],
+        [torch.zeros(0), torch.zeros(0), torch.zeros(1)],
+        [torch.zeros(0), torch.zeros(999), torch.zeros(0), torch.zeros(0)],
+    ],
+)
 @pytest.mark.parametrize("peer_fractions", [(0.33, 0.44, 0.23), (0.5, 0.5), (0.1, 0.0, 0.9), (1.0,), (0.1,) * 9])
 @pytest.mark.forked
 @pytest.mark.asyncio
@@ -66,9 +80,8 @@ async def test_partitioning_edge_cases(tensors: Sequence[torch.Tensor], peer_fra
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_partitioning_asynchronous():
-    """ ensure that tensor partitioning does not interfere with asynchronous code """
-    tensors = [torch.randn(2048, 2048), torch.randn(1024, 4096),
-               torch.randn(4096, 1024), torch.randn(30_000, 1024)]
+    """ensure that tensor partitioning does not interfere with asynchronous code"""
+    tensors = [torch.randn(2048, 2048), torch.randn(1024, 4096), torch.randn(4096, 1024), torch.randn(30_000, 1024)]
     peer_fractions = [0.4, 0.3, 0.2, 0.1]
 
     partition = TensorPartContainer(tensors, peer_fractions, compression_type=CompressionType.QUANTILE_8BIT)
@@ -109,8 +122,7 @@ async def test_reducer(num_senders: int, num_parts: int, synchronize_prob: float
     tensor_part_shapes = [torch.Size([i]) for i in range(num_parts)]
     reducer = TensorPartReducer(tensor_part_shapes, num_senders)
 
-    local_tensors_by_sender = [[torch.randn(i) for i in range(num_parts)]
-                               for j in range(num_senders)]
+    local_tensors_by_sender = [[torch.randn(i) for i in range(num_parts)] for j in range(num_senders)]
 
     async def send_tensors(sender_index: int):
         local_tensors = local_tensors_by_sender[sender_index]
@@ -118,8 +130,9 @@ async def test_reducer(num_senders: int, num_parts: int, synchronize_prob: float
         pending_tasks = []
 
         for part_index in range(num_parts):
-            pending_tasks.append(asyncio.create_task(
-                reducer.accumulate_part(sender_index, part_index, local_tensors[part_index])))
+            pending_tasks.append(
+                asyncio.create_task(reducer.accumulate_part(sender_index, part_index, local_tensors[part_index]))
+            )
 
             if random.random() < synchronize_prob or part_index == num_parts - 1:
                 averaged_parts.extend(await asyncio.gather(*pending_tasks))
@@ -128,9 +141,10 @@ async def test_reducer(num_senders: int, num_parts: int, synchronize_prob: float
 
     averaged_tensors_by_peer = await asyncio.gather(*map(send_tensors, range(num_senders)))
 
-    reference = [sum(local_tensors_by_sender[sender_index][part_index]
-                     for sender_index in range(num_senders)) / num_senders
-                 for part_index in range(num_parts)]
+    reference = [
+        sum(local_tensors_by_sender[sender_index][part_index] for sender_index in range(num_senders)) / num_senders
+        for part_index in range(num_parts)
+    ]
 
     for averaged_tensors in averaged_tensors_by_peer:
         assert len(averaged_tensors) == len(reference)
@@ -139,7 +153,7 @@ async def test_reducer(num_senders: int, num_parts: int, synchronize_prob: float
 
 
 class AllreduceRunnerForTesting(AllReduceRunner):
-    """ a version of AllReduceRunner that was monkey-patched to accept custom endpoint names """
+    """a version of AllReduceRunner that was monkey-patched to accept custom endpoint names"""
 
     def __init__(self, *args, peer_endpoints, **kwargs):
         self.__peer_endpoints = peer_endpoints
@@ -147,34 +161,43 @@ class AllreduceRunnerForTesting(AllReduceRunner):
 
     def _get_peer_stub(self, peer: Endpoint) -> averaging_pb2_grpc.DecentralizedAveragingStub:
         return ChannelCache.get_stub(
-            self.__peer_endpoints[peer], averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
+            self.__peer_endpoints[peer], averaging_pb2_grpc.DecentralizedAveragingStub, aio=True
+        )
 
 
 NODE, CLIENT, AUX = AveragingMode.NODE, AveragingMode.CLIENT, AveragingMode.AUX
 
 
-@pytest.mark.parametrize("peer_modes, averaging_weights, peer_fractions", [
-    ((NODE, NODE, NODE, NODE), (1, 1, 1, 1), (1, 1, 1, 1)),
-    ((NODE, NODE, NODE, NODE), (0.1, 0.2, 0.3, 0.4), (1, 1, 1, 1)),
-    ((NODE, NODE, NODE, NODE), (1, 1, 1, 1), (1, 2, 3, 0)),
-    ((NODE, NODE, NODE, CLIENT), (1, 1, 1, 1), (1, 2, 3, 0)),
-    ((NODE, NODE, NODE, AUX), (1, 1, 1, 0), (1, 2, 3, 4)),
-    ((NODE, NODE, NODE, NODE), (0.15, 0.0, 0.35, 0.45), (1, 1, 1, 1)),
-    ((NODE, AUX, NODE, CLIENT), (0.15, 0.0, 0.35, 0.45), (150, 200, 67, 0)),
-    ((AUX, AUX, AUX, AUX), (0.0, 0.0, 0.0, 0.0), (1, 2, 3, 4)),
-])
-@pytest.mark.parametrize("part_size_bytes", [2 ** 20, 256, 19], )
+@pytest.mark.parametrize(
+    "peer_modes, averaging_weights, peer_fractions",
+    [
+        ((NODE, NODE, NODE, NODE), (1, 1, 1, 1), (1, 1, 1, 1)),
+        ((NODE, NODE, NODE, NODE), (0.1, 0.2, 0.3, 0.4), (1, 1, 1, 1)),
+        ((NODE, NODE, NODE, NODE), (1, 1, 1, 1), (1, 2, 3, 0)),
+        ((NODE, NODE, NODE, CLIENT), (1, 1, 1, 1), (1, 2, 3, 0)),
+        ((NODE, NODE, NODE, AUX), (1, 1, 1, 0), (1, 2, 3, 4)),
+        ((NODE, NODE, NODE, NODE), (0.15, 0.0, 0.35, 0.45), (1, 1, 1, 1)),
+        ((NODE, AUX, NODE, CLIENT), (0.15, 0.0, 0.35, 0.45), (150, 200, 67, 0)),
+        ((AUX, AUX, AUX, AUX), (0.0, 0.0, 0.0, 0.0), (1, 2, 3, 4)),
+    ],
+)
+@pytest.mark.parametrize(
+    "part_size_bytes",
+    [2 ** 20, 256, 19],
+)
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions, part_size_bytes):
-    """ Run group allreduce protocol manually without grpc, see if the internal logic is working as intended """
+    """Run group allreduce protocol manually without grpc, see if the internal logic is working as intended"""
 
     peers = "alice", "bob", "carol", "colab"
 
-    tensors_by_peer = {peer: [torch.randn(3, 128), torch.rand(32), torch.tensor(i, dtype=torch.float32)]
-                       for i, peer in enumerate(peers)}
+    tensors_by_peer = {
+        peer: [torch.randn(3, 128), torch.rand(32), torch.tensor(i, dtype=torch.float32)]
+        for i, peer in enumerate(peers)
+    }
 
-    group_id = random.getrandbits(160).to_bytes(length=20, byteorder='big')
+    group_id = random.getrandbits(160).to_bytes(length=20, byteorder="big")
 
     servers = []
     allreduce_protocols = []
@@ -183,9 +206,15 @@ async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions,
     for peer in peers:
         server = grpc.aio.server()
         allreduce_protocol = AllreduceRunnerForTesting(
-            group_id=group_id, endpoint=peer, tensors=[x.clone() for x in tensors_by_peer[peer]],
-            ordered_group_endpoints=peers, peer_fractions=peer_fractions, modes=peer_modes,
-            weights=averaging_weights, peer_endpoints=peer_endpoints, part_size_bytes=part_size_bytes
+            group_id=group_id,
+            endpoint=peer,
+            tensors=[x.clone() for x in tensors_by_peer[peer]],
+            ordered_group_endpoints=peers,
+            peer_fractions=peer_fractions,
+            modes=peer_modes,
+            weights=averaging_weights,
+            peer_endpoints=peer_endpoints,
+            part_size_bytes=part_size_bytes,
         )
         averaging_pb2_grpc.add_DecentralizedAveragingServicer_to_server(allreduce_protocol, server)
         peer_endpoints[peer] = f"127.0.0.1:{server.add_insecure_port('127.0.0.1:*')}"
@@ -199,9 +228,11 @@ async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions,
 
     await asyncio.gather(*map(_run_allreduce_inplace, allreduce_protocols))
 
-    reference_tensors = [sum(tensors_by_peer[peer][i] * averaging_weights[peer_index]
-                             for peer_index, peer in enumerate(peers)) / sum(averaging_weights)
-                         for i in range(len(tensors_by_peer[peers[0]]))]
+    reference_tensors = [
+        sum(tensors_by_peer[peer][i] * averaging_weights[peer_index] for peer_index, peer in enumerate(peers))
+        / sum(averaging_weights)
+        for i in range(len(tensors_by_peer[peers[0]]))
+    ]
 
     for peer_index, protocol in enumerate(allreduce_protocols):
         assert protocol._future.done()
@@ -211,8 +242,7 @@ async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions,
             targets_for_peer = tensors_by_peer[peers[peer_index]]
         output_tensors = protocol.tensor_part_container.local_tensors
         assert len(output_tensors) == len(targets_for_peer)
-        assert all(torch.allclose(our, ref, atol=1e-6, rtol=0)
-                   for our, ref in zip(output_tensors, targets_for_peer))
+        assert all(torch.allclose(our, ref, atol=1e-6, rtol=0) for our, ref in zip(output_tensors, targets_for_peer))
 
     for server in servers:
         await server.stop(grace=1)

+ 30 - 27
tests/test_auth.py

@@ -17,7 +17,7 @@ class MockAuthorizer(TokenAuthorizerBase):
     _authority_private_key = None
     _authority_public_key = None
 
-    def __init__(self, local_private_key: Optional[RSAPrivateKey], username: str='mock'):
+    def __init__(self, local_private_key: Optional[RSAPrivateKey], username: str = "mock"):
         super().__init__(local_private_key)
 
         self._username = username
@@ -29,29 +29,32 @@ class MockAuthorizer(TokenAuthorizerBase):
 
         self._authority_public_key = MockAuthorizer._authority_private_key.get_public_key()
 
-        token = AccessToken(username=self._username,
-                            public_key=self.local_public_key.to_bytes(),
-                            expiration_time=str(datetime.utcnow() + timedelta(minutes=1)))
+        token = AccessToken(
+            username=self._username,
+            public_key=self.local_public_key.to_bytes(),
+            expiration_time=str(datetime.utcnow() + timedelta(minutes=1)),
+        )
         token.signature = MockAuthorizer._authority_private_key.sign(self._token_to_bytes(token))
         return token
 
     def is_token_valid(self, access_token: AccessToken) -> bool:
         data = self._token_to_bytes(access_token)
         if not self._authority_public_key.verify(data, access_token.signature):
-            logger.exception('Access token has invalid signature')
+            logger.exception("Access token has invalid signature")
             return False
 
         try:
             expiration_time = datetime.fromisoformat(access_token.expiration_time)
         except ValueError:
             logger.exception(
-                f'datetime.fromisoformat() failed to parse expiration time: {access_token.expiration_time}')
+                f"datetime.fromisoformat() failed to parse expiration time: {access_token.expiration_time}"
+            )
             return False
         if expiration_time.tzinfo is not None:
-            logger.exception(f'Expected to have no timezone for expiration time: {access_token.expiration_time}')
+            logger.exception(f"Expected to have no timezone for expiration time: {access_token.expiration_time}")
             return False
         if expiration_time < datetime.utcnow():
-            logger.exception('Access token has expired')
+            logger.exception("Access token has expired")
             return False
 
         return True
@@ -64,7 +67,7 @@ class MockAuthorizer(TokenAuthorizerBase):
 
     @staticmethod
     def _token_to_bytes(access_token: AccessToken) -> bytes:
-        return f'{access_token.username} {access_token.public_key} {access_token.expiration_time}'.encode()
+        return f"{access_token.username} {access_token.public_key} {access_token.expiration_time}".encode()
 
 
 @pytest.mark.asyncio
@@ -73,12 +76,12 @@ async def test_valid_request_and_response():
     service_authorizer = MockAuthorizer(RSAPrivateKey())
 
     request = dht_pb2.PingRequest()
-    request.peer.node_id = b'ping'
+    request.peer.node_id = b"ping"
     await client_authorizer.sign_request(request, service_authorizer.local_public_key)
     assert await service_authorizer.validate_request(request)
 
     response = dht_pb2.PingResponse()
-    response.peer.node_id = b'pong'
+    response.peer.node_id = b"pong"
     await service_authorizer.sign_response(response, request)
     assert await client_authorizer.validate_response(response, request)
 
@@ -89,20 +92,20 @@ async def test_invalid_access_token():
     service_authorizer = MockAuthorizer(RSAPrivateKey())
 
     request = dht_pb2.PingRequest()
-    request.peer.node_id = b'ping'
+    request.peer.node_id = b"ping"
     await client_authorizer.sign_request(request, service_authorizer.local_public_key)
 
     # Break the access token signature
-    request.auth.client_access_token.signature = b'broken'
+    request.auth.client_access_token.signature = b"broken"
 
     assert not await service_authorizer.validate_request(request)
 
     response = dht_pb2.PingResponse()
-    response.peer.node_id = b'pong'
+    response.peer.node_id = b"pong"
     await service_authorizer.sign_response(response, request)
 
     # Break the access token signature
-    response.auth.service_access_token.signature = b'broken'
+    response.auth.service_access_token.signature = b"broken"
 
     assert not await client_authorizer.validate_response(response, request)
 
@@ -113,20 +116,20 @@ async def test_invalid_signatures():
     service_authorizer = MockAuthorizer(RSAPrivateKey())
 
     request = dht_pb2.PingRequest()
-    request.peer.node_id = b'true-ping'
+    request.peer.node_id = b"true-ping"
     await client_authorizer.sign_request(request, service_authorizer.local_public_key)
 
     # A man-in-the-middle attacker changes the request content
-    request.peer.node_id = b'fake-ping'
+    request.peer.node_id = b"fake-ping"
 
     assert not await service_authorizer.validate_request(request)
 
     response = dht_pb2.PingResponse()
-    response.peer.node_id = b'true-pong'
+    response.peer.node_id = b"true-pong"
     await service_authorizer.sign_response(response, request)
 
     # A man-in-the-middle attacker changes the response content
-    response.peer.node_id = b'fake-pong'
+    response.peer.node_id = b"fake-pong"
 
     assert not await client_authorizer.validate_response(response, request)
 
@@ -135,11 +138,11 @@ async def test_invalid_signatures():
 async def test_auth_rpc_wrapper():
     class Servicer:
         async def rpc_increment(self, request: dht_pb2.PingRequest) -> dht_pb2.PingResponse:
-            assert request.peer.node_id == b'ping'
-            assert request.auth.client_access_token.username == 'alice'
+            assert request.peer.node_id == b"ping"
+            assert request.auth.client_access_token.username == "alice"
 
             response = dht_pb2.PingResponse()
-            response.peer.node_id = b'pong'
+            response.peer.node_id = b"pong"
             return response
 
     class Client:
@@ -149,13 +152,13 @@ async def test_auth_rpc_wrapper():
         async def rpc_increment(self, request: dht_pb2.PingRequest) -> dht_pb2.PingResponse:
             return await self._servicer.rpc_increment(request)
 
-    servicer = AuthRPCWrapper(Servicer(), AuthRole.SERVICER, MockAuthorizer(RSAPrivateKey(), 'bob'))
-    client = AuthRPCWrapper(Client(servicer), AuthRole.CLIENT, MockAuthorizer(RSAPrivateKey(), 'alice'))
+    servicer = AuthRPCWrapper(Servicer(), AuthRole.SERVICER, MockAuthorizer(RSAPrivateKey(), "bob"))
+    client = AuthRPCWrapper(Client(servicer), AuthRole.CLIENT, MockAuthorizer(RSAPrivateKey(), "alice"))
 
     request = dht_pb2.PingRequest()
-    request.peer.node_id = b'ping'
+    request.peer.node_id = b"ping"
 
     response = await client.rpc_increment(request)
 
-    assert response.peer.node_id == b'pong'
-    assert response.auth.service_access_token.username == 'bob'
+    assert response.peer.node_id == b"pong"
+    assert response.auth.service_access_token.username == "bob"

+ 172 - 76
tests/test_averaging.py

@@ -15,30 +15,34 @@ from hivemind.proto.runtime_pb2 import CompressionType
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_key_manager():
-    key_manager = GroupKeyManager(hivemind.DHT(start=True), endpoint='localhvost',
-                                  prefix='test_averaging', initial_group_bits='10110',
-                                  target_group_size=2)
+    key_manager = GroupKeyManager(
+        hivemind.DHT(start=True),
+        endpoint="localhvost",
+        prefix="test_averaging",
+        initial_group_bits="10110",
+        target_group_size=2,
+    )
 
     t = hivemind.get_dht_time()
     key = key_manager.current_key
-    await key_manager.declare_averager(key, 'localhvost', expiration_time=t + 60)
-    await key_manager.declare_averager(key, 'localhvost2', expiration_time=t + 61)
+    await key_manager.declare_averager(key, "localhvost", expiration_time=t + 60)
+    await key_manager.declare_averager(key, "localhvost2", expiration_time=t + 61)
 
     q1 = await key_manager.get_averagers(key, only_active=True)
 
-    await key_manager.declare_averager(key, 'localhvost', expiration_time=t + 66)
+    await key_manager.declare_averager(key, "localhvost", expiration_time=t + 66)
     q2 = await key_manager.get_averagers(key, only_active=True)
 
-    await key_manager.declare_averager(key, 'localhvost2', expiration_time=t + 61, looking_for_group=False)
+    await key_manager.declare_averager(key, "localhvost2", expiration_time=t + 61, looking_for_group=False)
     q3 = await key_manager.get_averagers(key, only_active=True)
     q4 = await key_manager.get_averagers(key, only_active=False)
 
-    q5 = await key_manager.get_averagers('nonexistent_key.0b0101', only_active=False)
+    q5 = await key_manager.get_averagers("nonexistent_key.0b0101", only_active=False)
 
-    assert len(q1) == 2 and ('localhvost', t + 60) in q1 and ('localhvost2', t + 61) in q1
-    assert len(q2) == 2 and ('localhvost', t + 66) in q2 and ('localhvost2', t + 61) in q2
-    assert len(q3) == 1 and ('localhvost', t + 66) in q3
-    assert len(q4) == 2 and ('localhvost', t + 66) in q4 and ('localhvost2', t + 61) in q2
+    assert len(q1) == 2 and ("localhvost", t + 60) in q1 and ("localhvost2", t + 61) in q1
+    assert len(q2) == 2 and ("localhvost", t + 66) in q2 and ("localhvost2", t + 61) in q2
+    assert len(q3) == 1 and ("localhvost", t + 66) in q3
+    assert len(q4) == 2 and ("localhvost", t + 66) in q4 and ("localhvost2", t + 61) in q2
     assert len(q5) == 0
 
 
@@ -46,8 +50,11 @@ def _test_allreduce_once(n_clients, n_aux):
     dht = hivemind.DHT(start=True)
 
     n_peers = 4
-    modes = [AveragingMode.CLIENT] * n_clients + [AveragingMode.AUX] * n_aux + [AveragingMode.NODE] * (
-            n_peers - n_clients - n_aux)
+    modes = (
+        [AveragingMode.CLIENT] * n_clients
+        + [AveragingMode.AUX] * n_aux
+        + [AveragingMode.NODE] * (n_peers - n_clients - n_aux)
+    )
     random.shuffle(modes)
 
     tensors1 = [torch.randn(123), torch.zeros(3)]
@@ -56,15 +63,26 @@ def _test_allreduce_once(n_clients, n_aux):
     tensors4 = [torch.randn(123) ** 3, torch.arange(3).to(torch.float32) / 2]
     peer_tensors = [tensors1, tensors2, tensors3, tensors4]
 
-    reference = [sum(tensors[i] for tensors, mode in zip(peer_tensors, modes)
-                     if mode != AveragingMode.AUX) / max(1, n_peers - n_aux) for i in range(len(tensors1))]
+    reference = [
+        sum(tensors[i] for tensors, mode in zip(peer_tensors, modes) if mode != AveragingMode.AUX)
+        / max(1, n_peers - n_aux)
+        for i in range(len(tensors1))
+    ]
 
     averagers = [
-        hivemind.averaging.DecentralizedAverager(tensors, dht=dht, target_group_size=4, averaging_expiration=15,
-                                                 prefix='mygroup', listen=mode != AveragingMode.CLIENT,
-                                                 listen_on='127.0.0.1:*',
-                                                 auxiliary=mode == AveragingMode.AUX, start=True)
-        for tensors, mode in zip(peer_tensors, modes)]
+        hivemind.averaging.DecentralizedAverager(
+            tensors,
+            dht=dht,
+            target_group_size=4,
+            averaging_expiration=15,
+            prefix="mygroup",
+            listen=mode != AveragingMode.CLIENT,
+            listen_on="127.0.0.1:*",
+            auxiliary=mode == AveragingMode.AUX,
+            start=True,
+        )
+        for tensors, mode in zip(peer_tensors, modes)
+    ]
 
     futures = []
     for averager in averagers:
@@ -111,12 +129,24 @@ def test_allreduce_weighted(n_client_mode_peers: int = 2):
     tensors3 = [-torch.rand(123), torch.arange(3).to(torch.float32)]
     tensors4 = [torch.randn(123) ** 3, torch.arange(3).to(torch.float32) / 2]
     averagers = [
-        hivemind.averaging.DecentralizedAverager(tensors, dht=dht, target_group_size=4, averaging_expiration=15,
-                                                 prefix='mygroup', listen=listen, listen_on='127.0.0.1:*', start=True)
-        for tensors, listen in zip([tensors1, tensors2, tensors3, tensors4], should_listen)]
+        hivemind.averaging.DecentralizedAverager(
+            tensors,
+            dht=dht,
+            target_group_size=4,
+            averaging_expiration=15,
+            prefix="mygroup",
+            listen=listen,
+            listen_on="127.0.0.1:*",
+            start=True,
+        )
+        for tensors, listen in zip([tensors1, tensors2, tensors3, tensors4], should_listen)
+    ]
     weights = list(map(float, np.random.rand(len(averagers)) * 10 + 0.01))
-    reference = [(tensors1[i] * weights[0] + tensors2[i] * weights[1] + tensors3[i] * weights[2]
-                  + tensors4[i] * weights[3]) / sum(weights) for i in range(len(tensors1))]
+    reference = [
+        (tensors1[i] * weights[0] + tensors2[i] * weights[1] + tensors3[i] * weights[2] + tensors4[i] * weights[3])
+        / sum(weights)
+        for i in range(len(tensors1))
+    ]
 
     futures = []
     for averager, weight in zip(averagers, weights):
@@ -136,7 +166,7 @@ def test_allreduce_weighted(n_client_mode_peers: int = 2):
 
 @pytest.mark.forked
 def test_allreduce_compression():
-    """ this test ensures that compression works correctly when multiple tensors have different compression types """
+    """this test ensures that compression works correctly when multiple tensors have different compression types"""
     dht = hivemind.DHT(start=True)
 
     tensors1 = [torch.linspace(0, 500, 1000) ** 0.5, torch.randn(1000)]
@@ -146,14 +176,24 @@ def test_allreduce_compression():
     FLOAT16, UINT8 = CompressionType.FLOAT16, CompressionType.UNIFORM_8BIT
 
     for compression_type_pair in [(FLOAT16, FLOAT16), (FLOAT16, UINT8), (UINT8, FLOAT16), (UINT8, UINT8)]:
-        averager1 = hivemind.averaging.DecentralizedAverager([x.clone() for x in tensors1], dht=dht,
-                                                             compression_type=compression_type_pair,
-                                                             listen=False, target_group_size=2, prefix='mygroup',
-                                                             start=True)
-        averager2 = hivemind.averaging.DecentralizedAverager([x.clone() for x in tensors2], dht=dht,
-                                                             compression_type=compression_type_pair,
-                                                             target_group_size=2, prefix='mygroup',
-                                                             listen_on='127.0.0.1:*', start=True)
+        averager1 = hivemind.averaging.DecentralizedAverager(
+            [x.clone() for x in tensors1],
+            dht=dht,
+            compression_type=compression_type_pair,
+            listen=False,
+            target_group_size=2,
+            prefix="mygroup",
+            start=True,
+        )
+        averager2 = hivemind.averaging.DecentralizedAverager(
+            [x.clone() for x in tensors2],
+            dht=dht,
+            compression_type=compression_type_pair,
+            target_group_size=2,
+            prefix="mygroup",
+            listen_on="127.0.0.1:*",
+            start=True,
+        )
 
         for future in averager1.step(wait=False), averager2.step(wait=False):
             future.result()
@@ -192,10 +232,18 @@ def compute_mean_std(averagers, unbiased=True):
 @pytest.mark.forked
 def test_allreduce_grid():
     dht = hivemind.DHT(start=True)
-    averagers = [hivemind.averaging.DecentralizedAverager(
-        averaged_tensors=[torch.randn(3)], dht=dht, target_group_size=2,
-        prefix='mygroup', initial_group_bits=bin(i // 2)[2:].rjust(2, '0'), listen_on='127.0.0.1:*', start=True)
-        for i in range(8)]
+    averagers = [
+        hivemind.averaging.DecentralizedAverager(
+            averaged_tensors=[torch.randn(3)],
+            dht=dht,
+            target_group_size=2,
+            prefix="mygroup",
+            initial_group_bits=bin(i // 2)[2:].rjust(2, "0"),
+            listen_on="127.0.0.1:*",
+            start=True,
+        )
+        for i in range(8)
+    ]
 
     [means0], [stds0] = compute_mean_std(averagers)
     assert not torch.allclose(stds0, torch.zeros_like(stds0))
@@ -222,19 +270,29 @@ def test_allreduce_grid():
 @pytest.mark.forked
 def test_allgather():
     dht = hivemind.DHT(start=True)
-    averagers = [hivemind.averaging.DecentralizedAverager([torch.ones(1)], dht=dht, target_group_size=4,
-                                                          averaging_expiration=15, prefix='mygroup',
-                                                          initial_group_bits='000', listen_on='127.0.0.1:*', start=True)
-                 for _ in range(8)]
+    averagers = [
+        hivemind.averaging.DecentralizedAverager(
+            [torch.ones(1)],
+            dht=dht,
+            target_group_size=4,
+            averaging_expiration=15,
+            prefix="mygroup",
+            initial_group_bits="000",
+            listen_on="127.0.0.1:*",
+            start=True,
+        )
+        for _ in range(8)
+    ]
 
     futures = []
     for i, averager in enumerate(averagers):
-        futures.append(averager.step(wait=False, gather=dict(batch_size=123 + i, foo='bar')))
+        futures.append(averager.step(wait=False, gather=dict(batch_size=123 + i, foo="bar")))
 
     assert len(set(repr(sorted(future.result())) for future in futures)) == 2
 
-    reference_metadata = {averager.endpoint: dict(batch_size=123 + i, foo='bar')
-                          for i, averager in enumerate(averagers)}
+    reference_metadata = {
+        averager.endpoint: dict(batch_size=123 + i, foo="bar") for i, averager in enumerate(averagers)
+    }
     for future in futures:
         gathered = future.result()
 
@@ -249,8 +307,10 @@ def test_allgather():
 
 
 def get_cost(vector_size, partitions, throughputs):
-    return max((vector_size - partitions[i] + (len(partitions) - 1) * partitions[i]) / max(throughputs[i], 1e-9)
-               for i in range(len(partitions)))
+    return max(
+        (vector_size - partitions[i] + (len(partitions) - 1) * partitions[i]) / max(throughputs[i], 1e-9)
+        for i in range(len(partitions))
+    )
 
 
 def check_optimality(vector_size, throughputs, ref_partitions):
@@ -292,11 +352,20 @@ def test_load_balancing():
 @pytest.mark.forked
 def test_too_few_peers():
     dht = hivemind.DHT(start=True)
-    averagers = [hivemind.averaging.DecentralizedAverager(
-        averaged_tensors=[torch.randn(3)], dht=dht, target_group_size=2,
-        averaging_expiration=1, request_timeout=0.5,
-        prefix='mygroup', initial_group_bits=bin(i)[2:].rjust(3, '0'), listen_on='127.0.0.1:*', start=True)
-        for i in range(4)]
+    averagers = [
+        hivemind.averaging.DecentralizedAverager(
+            averaged_tensors=[torch.randn(3)],
+            dht=dht,
+            target_group_size=2,
+            averaging_expiration=1,
+            request_timeout=0.5,
+            prefix="mygroup",
+            initial_group_bits=bin(i)[2:].rjust(3, "0"),
+            listen_on="127.0.0.1:*",
+            start=True,
+        )
+        for i in range(4)
+    ]
     step_futures = [averager.step(wait=False) for averager in averagers]
     for future in step_futures:
         assert len(future.result()) == 2
@@ -309,11 +378,20 @@ def test_too_few_peers():
 @pytest.mark.forked
 def test_overcrowded(num_peers=16):
     dht = hivemind.DHT(start=True)
-    averagers = [hivemind.averaging.DecentralizedAverager(
-        averaged_tensors=[torch.randn(3)], dht=dht, target_group_size=2,
-        averaging_expiration=1, request_timeout=0.5,
-        prefix='mygroup', initial_group_bits='', listen_on='127.0.0.1:*', start=True)
-        for _ in range(num_peers)]
+    averagers = [
+        hivemind.averaging.DecentralizedAverager(
+            averaged_tensors=[torch.randn(3)],
+            dht=dht,
+            target_group_size=2,
+            averaging_expiration=1,
+            request_timeout=0.5,
+            prefix="mygroup",
+            initial_group_bits="",
+            listen_on="127.0.0.1:*",
+            start=True,
+        )
+        for _ in range(num_peers)
+    ]
     for t in range(5):
         step_futures = [averager.step(wait=False, timeout=5) for averager in averagers]
         assert sum(len(future.result() or []) == 2 for future in step_futures) >= len(averagers) - 1
@@ -342,15 +420,25 @@ def test_load_state_from_peers():
     dht_root = hivemind.DHT(start=True)
     initial_peers = dht_root.get_visible_maddrs()
     dht1 = hivemind.DHT(initial_peers=initial_peers, start=True)
-    averager1 = TestAverager([torch.randn(3), torch.rand(5)],
-                             dht=dht1, start=True,
-                             prefix='demo-run', target_group_size=2, listen_on='127.0.0.1:*')
+    averager1 = TestAverager(
+        [torch.randn(3), torch.rand(5)],
+        dht=dht1,
+        start=True,
+        prefix="demo-run",
+        target_group_size=2,
+        listen_on="127.0.0.1:*",
+    )
 
     dht2 = hivemind.DHT(initial_peers=initial_peers, start=True)
-    dht2.get('demo-run.all_averagers')
-    averager2 = TestAverager([torch.randn(3), torch.rand(5)],
-                             dht=dht2, start=True,
-                             prefix='demo-run', target_group_size=2, listen_on='127.0.0.1:*')
+    dht2.get("demo-run.all_averagers")
+    averager2 = TestAverager(
+        [torch.randn(3), torch.rand(5)],
+        dht=dht2,
+        start=True,
+        prefix="demo-run",
+        target_group_size=2,
+        listen_on="127.0.0.1:*",
+    )
 
     assert num_calls == 0
     got_metadata, got_tensors = averager2.load_state_from_peers()
@@ -358,7 +446,7 @@ def test_load_state_from_peers():
     assert got_metadata == super_metadata
     assert all(map(torch.allclose, got_tensors, super_tensors))
 
-    super_metadata['y'] = 123
+    super_metadata["y"] = 123
     super_tensors[1][2] = 9
     assert num_calls == 1
     assert got_metadata != super_metadata
@@ -379,10 +467,11 @@ def test_load_state_from_peers():
 @pytest.mark.forked
 def test_getset_bits():
     dht = hivemind.DHT(start=True)
-    averager = hivemind.averaging.DecentralizedAverager([torch.randn(3)], dht=dht, start=True, prefix='test_prefix',
-                                                        target_group_size=2, listen_on='127.0.0.1:*')
-    averager.set_group_bits('00101011101010')
-    assert averager.get_group_bits() == '00101011101010'
+    averager = hivemind.averaging.DecentralizedAverager(
+        [torch.randn(3)], dht=dht, start=True, prefix="test_prefix", target_group_size=2, listen_on="127.0.0.1:*"
+    )
+    averager.set_group_bits("00101011101010")
+    assert averager.get_group_bits() == "00101011101010"
 
 
 @pytest.mark.forked
@@ -390,18 +479,25 @@ def test_training_averager(n_steps: int = 10, n_dims: int = 16):
     torch.manual_seed(42)
 
     dht = hivemind.DHT(start=True)
-    common_kwargs = {'dht': dht, 'start': True, 'listen_on': '127.0.0.1:*',
-                     'prefix': 'demo-run', 'target_group_size': 2}
+    common_kwargs = {
+        "dht": dht,
+        "start": True,
+        "listen_on": "127.0.0.1:*",
+        "prefix": "demo-run",
+        "target_group_size": 2,
+    }
 
     x1 = torch.randn(n_dims, requires_grad=True)
     opt1 = torch.optim.Adam([x1], lr=0.05)
-    averager1 = hivemind.averaging.TrainingAverager(opt1, average_gradients=True, average_parameters=True,
-                                                    average_opt_statistics=["exp_avg_sq"], **common_kwargs)
+    averager1 = hivemind.averaging.TrainingAverager(
+        opt1, average_gradients=True, average_parameters=True, average_opt_statistics=["exp_avg_sq"], **common_kwargs
+    )
 
     x2 = torch.randn(n_dims, requires_grad=True)
     opt2 = torch.optim.Adam([x2], lr=0.05)
-    averager2 = hivemind.averaging.TrainingAverager(opt2, average_gradients=True, average_parameters=True,
-                                                    average_opt_statistics=["exp_avg_sq"], **common_kwargs)
+    averager2 = hivemind.averaging.TrainingAverager(
+        opt2, average_gradients=True, average_parameters=True, average_opt_statistics=["exp_avg_sq"], **common_kwargs
+    )
     a = torch.ones(n_dims)
 
     for i in range(n_steps):

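For quick reference, a minimal sketch of the averaging API exercised by the tests above, reusing the same constructor arguments (the prefix, listen_on address and group size are placeholder values; defaults may differ between hivemind versions):

import torch
import hivemind

dht = hivemind.DHT(start=True)
averagers = [
    hivemind.averaging.DecentralizedAverager(
        averaged_tensors=[torch.randn(3)],
        dht=dht,
        target_group_size=2,
        prefix="mygroup",
        listen_on="127.0.0.1:*",
        start=True,
    )
    for _ in range(2)
]

# step(wait=False) returns a future that resolves once the all-reduce round completes
futures = [averager.step(wait=False) for averager in averagers]
for future in futures:
    future.result()
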
+ 26 - 13
tests/test_custom_experts.py

@@ -6,17 +6,22 @@ import torch
 from hivemind import RemoteExpert
 from hivemind.moe.server import background_server
 
-CUSTOM_EXPERTS_PATH = os.path.join(os.path.dirname(__file__), 'test_utils', 'custom_networks.py')
+CUSTOM_EXPERTS_PATH = os.path.join(os.path.dirname(__file__), "test_utils", "custom_networks.py")
 
 
 @pytest.mark.forked
 def test_custom_expert(hid_dim=16):
     with background_server(
-            expert_cls='perceptron', num_experts=2, device='cpu',
-            hidden_dim=hid_dim, num_handlers=2, no_dht=True,
-            custom_module_path=CUSTOM_EXPERTS_PATH) as (server_endpoint, _):
-        expert0 = RemoteExpert('expert.0', server_endpoint)
-        expert1 = RemoteExpert('expert.1', server_endpoint)
+        expert_cls="perceptron",
+        num_experts=2,
+        device="cpu",
+        hidden_dim=hid_dim,
+        num_handlers=2,
+        no_dht=True,
+        custom_module_path=CUSTOM_EXPERTS_PATH,
+    ) as (server_endpoint, _):
+        expert0 = RemoteExpert("expert.0", server_endpoint)
+        expert1 = RemoteExpert("expert.1", server_endpoint)
 
         for batch_size in (1, 4):
             batch = torch.randn(batch_size, hid_dim)
@@ -33,15 +38,23 @@ def test_custom_expert(hid_dim=16):
 @pytest.mark.forked
 def test_multihead_expert(hid_dim=16):
     with background_server(
-            expert_cls='multihead', num_experts=2, device='cpu',
-            hidden_dim=hid_dim, num_handlers=2, no_dht=True,
-            custom_module_path=CUSTOM_EXPERTS_PATH) as (server_endpoint, _):
-        expert0 = RemoteExpert('expert.0', server_endpoint)
-        expert1 = RemoteExpert('expert.1', server_endpoint)
+        expert_cls="multihead",
+        num_experts=2,
+        device="cpu",
+        hidden_dim=hid_dim,
+        num_handlers=2,
+        no_dht=True,
+        custom_module_path=CUSTOM_EXPERTS_PATH,
+    ) as (server_endpoint, _):
+        expert0 = RemoteExpert("expert.0", server_endpoint)
+        expert1 = RemoteExpert("expert.1", server_endpoint)
 
         for batch_size in (1, 4):
-            batch = (torch.randn(batch_size, hid_dim), torch.randn(batch_size, 2 * hid_dim),
-                     torch.randn(batch_size, 3 * hid_dim))
+            batch = (
+                torch.randn(batch_size, hid_dim),
+                torch.randn(batch_size, 2 * hid_dim),
+                torch.randn(batch_size, 3 * hid_dim),
+            )
 
             output0 = expert0(*batch)
             output1 = expert1(*batch)

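A minimal sketch of serving and calling a custom expert with background_server, mirroring the test above (CUSTOM_EXPERTS_PATH points at the test suite's custom_networks.py, which is assumed to define the "perceptron" expert class):

import os
import torch
from hivemind import RemoteExpert
from hivemind.moe.server import background_server

CUSTOM_EXPERTS_PATH = os.path.join("tests", "test_utils", "custom_networks.py")

with background_server(
    expert_cls="perceptron",
    num_experts=2,
    device="cpu",
    hidden_dim=16,
    num_handlers=2,
    no_dht=True,
    custom_module_path=CUSTOM_EXPERTS_PATH,
) as (server_endpoint, _):
    expert = RemoteExpert("expert.0", server_endpoint)
    output = expert(torch.randn(4, 16))  # forward pass through the remote expert
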
+ 21 - 22
tests/test_dht.py

@@ -8,7 +8,6 @@ from multiaddr import Multiaddr
 import hivemind
 
 
-
 @pytest.mark.forked
 def test_get_store(n_peers=10):
     peers = [hivemind.DHT(start=True)]
@@ -16,34 +15,34 @@ def test_get_store(n_peers=10):
     peers += [hivemind.DHT(initial_peers=initial_peers, start=True) for _ in range(n_peers - 1)]
 
     node1, node2 = random.sample(peers, 2)
-    assert node1.store('key1', 'value1', expiration_time=hivemind.get_dht_time() + 30)
-    assert node1.get('key1').value == 'value1'
-    assert node2.get('key1').value == 'value1'
-    assert node2.get('key2') is None
+    assert node1.store("key1", "value1", expiration_time=hivemind.get_dht_time() + 30)
+    assert node1.get("key1").value == "value1"
+    assert node2.get("key1").value == "value1"
+    assert node2.get("key2") is None
 
-    future = node1.get('foo', return_future=True)
+    future = node1.get("foo", return_future=True)
     assert future.result() is None
 
-    future = node1.get('foo', return_future=True)
+    future = node1.get("foo", return_future=True)
     future.cancel()
 
-    assert node2.store('key1', 123, expiration_time=hivemind.get_dht_time() + 31)
-    assert node2.store('key2', 456, expiration_time=hivemind.get_dht_time() + 32)
-    assert node1.get('key1', latest=True).value == 123
-    assert node1.get('key2').value == 456
+    assert node2.store("key1", 123, expiration_time=hivemind.get_dht_time() + 31)
+    assert node2.store("key2", 456, expiration_time=hivemind.get_dht_time() + 32)
+    assert node1.get("key1", latest=True).value == 123
+    assert node1.get("key2").value == 456
 
-    assert node1.store('key2', subkey='subkey1', value=789, expiration_time=hivemind.get_dht_time() + 32)
-    assert node2.store('key2', subkey='subkey2', value='pew', expiration_time=hivemind.get_dht_time() + 32)
-    found_dict = node1.get('key2', latest=True).value
+    assert node1.store("key2", subkey="subkey1", value=789, expiration_time=hivemind.get_dht_time() + 32)
+    assert node2.store("key2", subkey="subkey2", value="pew", expiration_time=hivemind.get_dht_time() + 32)
+    found_dict = node1.get("key2", latest=True).value
     assert isinstance(found_dict, dict) and len(found_dict) == 2
-    assert found_dict['subkey1'].value == 789 and found_dict['subkey2'].value == 'pew'
+    assert found_dict["subkey1"].value == 789 and found_dict["subkey2"].value == "pew"
 
     for peer in peers:
         peer.shutdown()
 
 
 async def dummy_dht_coro(self, node):
-    return 'pew'
+    return "pew"
 
 
 async def dummy_dht_coro_error(self, node):
@@ -51,7 +50,7 @@ async def dummy_dht_coro_error(self, node):
 
 
 async def dummy_dht_coro_stateful(self, node):
-    self._x_dummy = getattr(self, '_x_dummy', 123) + 1
+    self._x_dummy = getattr(self, "_x_dummy", 123) + 1
     return self._x_dummy
 
 
@@ -69,7 +68,7 @@ async def dummy_dht_coro_for_cancel(self, node):
 @pytest.mark.forked
 def test_run_coroutine():
     dht = hivemind.DHT(start=True)
-    assert dht.run_coroutine(dummy_dht_coro) == 'pew'
+    assert dht.run_coroutine(dummy_dht_coro) == "pew"
 
     with pytest.raises(ValueError):
         res = dht.run_coroutine(dummy_dht_coro_error)
@@ -78,7 +77,7 @@ def test_run_coroutine():
     assert dht.run_coroutine(dummy_dht_coro_stateful) == 124
     assert dht.run_coroutine(dummy_dht_coro_stateful) == 125
     assert dht.run_coroutine(dummy_dht_coro_stateful) == 126
-    assert not hasattr(dht, '_x_dummy')
+    assert not hasattr(dht, "_x_dummy")
     assert bg_task.result() == 126 ** 2
 
     future = dht.run_coroutine(dummy_dht_coro_for_cancel, return_future=True)
@@ -96,14 +95,14 @@ async def test_dht_get_visible_maddrs():
 
     dht = hivemind.DHT(start=True)
 
-    assert any(str(maddr).startswith('/ip4/127.0.0.1') for maddr in dht.get_visible_maddrs())
+    assert any(str(maddr).startswith("/ip4/127.0.0.1") for maddr in dht.get_visible_maddrs())
     dht.shutdown()
 
     # test 2: announce_maddrs are the single visible multiaddrs if defined
 
-    dummy_endpoint = Multiaddr('/ip4/123.45.67.89/tcp/31337')
+    dummy_endpoint = Multiaddr("/ip4/123.45.67.89/tcp/31337")
     p2p = await hivemind.p2p.P2P.create(announce_maddrs=[dummy_endpoint])
     dht = hivemind.DHT(p2p, start=True)
 
-    assert dht.get_visible_maddrs() == [dummy_endpoint.encapsulate(f'/p2p/{p2p.id}')]
+    assert dht.get_visible_maddrs() == [dummy_endpoint.encapsulate(f"/p2p/{p2p.id}")]
     dht.shutdown()

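A minimal sketch of the blocking DHT store/get API these tests exercise (keys, values and expiration offsets are placeholders):

import hivemind

node1 = hivemind.DHT(start=True)
node2 = hivemind.DHT(initial_peers=node1.get_visible_maddrs(), start=True)

assert node1.store("key1", "value1", expiration_time=hivemind.get_dht_time() + 30)
assert node2.get("key1").value == "value1"

# storing under sub-keys yields a dictionary of sub-entries on retrieval
node1.store("key2", subkey="subkey1", value=789, expiration_time=hivemind.get_dht_time() + 30)
assert node2.get("key2", latest=True).value["subkey1"].value == 789

for node in (node1, node2):
    node.shutdown()
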
+ 37 - 34
tests/test_dht_crypto.py

@@ -17,13 +17,10 @@ def test_rsa_signature_validator():
     sender_validator = RSASignatureValidator(RSAPrivateKey())
     mallory_validator = RSASignatureValidator(RSAPrivateKey())
 
-    plain_record = DHTRecord(key=b'key', subkey=b'subkey', value=b'value',
-                             expiration_time=get_dht_time() + 10)
+    plain_record = DHTRecord(key=b"key", subkey=b"subkey", value=b"value", expiration_time=get_dht_time() + 10)
     protected_records = [
-        dataclasses.replace(plain_record,
-                            key=plain_record.key + sender_validator.local_public_key),
-        dataclasses.replace(plain_record,
-                            subkey=plain_record.subkey + sender_validator.local_public_key),
+        dataclasses.replace(plain_record, key=plain_record.key + sender_validator.local_public_key),
+        dataclasses.replace(plain_record, subkey=plain_record.subkey + sender_validator.local_public_key),
     ]
 
     # test 1: Non-protected record (no signature added)
@@ -31,19 +28,21 @@ def test_rsa_signature_validator():
     assert receiver_validator.validate(plain_record)
 
     # test 2: Correct signatures
-    signed_records = [dataclasses.replace(record, value=sender_validator.sign_value(record))
-                      for record in protected_records]
+    signed_records = [
+        dataclasses.replace(record, value=sender_validator.sign_value(record)) for record in protected_records
+    ]
     for record in signed_records:
         assert receiver_validator.validate(record)
-        assert receiver_validator.strip_value(record) == b'value'
+        assert receiver_validator.strip_value(record) == b"value"
 
     # test 3: Invalid signatures
     signed_records = protected_records  # Without signature
-    signed_records += [dataclasses.replace(record,
-                                           value=record.value + b'[signature:INVALID_BYTES]')
-                       for record in protected_records]  # With invalid signature
-    signed_records += [dataclasses.replace(record, value=mallory_validator.sign_value(record))
-                       for record in protected_records]  # With someone else's signature
+    signed_records += [
+        dataclasses.replace(record, value=record.value + b"[signature:INVALID_BYTES]") for record in protected_records
+    ]  # With invalid signature
+    signed_records += [
+        dataclasses.replace(record, value=mallory_validator.sign_value(record)) for record in protected_records
+    ]  # With someone else's signature
     for record in signed_records:
         assert not receiver_validator.validate(record)
 
@@ -66,11 +65,15 @@ def test_validator_instance_is_picklable():
     # To check that the private key was pickled and unpickled correctly, we sign a record
     # with the original public key using the unpickled validator and then validate the signature
 
-    record = DHTRecord(key=b'key', subkey=b'subkey' + original_validator.local_public_key,
-                       value=b'value', expiration_time=get_dht_time() + 10)
+    record = DHTRecord(
+        key=b"key",
+        subkey=b"subkey" + original_validator.local_public_key,
+        value=b"value",
+        expiration_time=get_dht_time() + 10,
+    )
     signed_record = dataclasses.replace(record, value=unpickled_validator.sign_value(record))
 
-    assert b'[signature:' in signed_record.value
+    assert b"[signature:" in signed_record.value
     assert original_validator.validate(signed_record)
     assert unpickled_validator.validate(signed_record)
 
@@ -93,12 +96,13 @@ def test_signing_in_different_process():
     validator = RSASignatureValidator()
     parent_conn.send(validator)
 
-    record = DHTRecord(key=b'key', subkey=b'subkey' + validator.local_public_key,
-                       value=b'value', expiration_time=get_dht_time() + 10)
+    record = DHTRecord(
+        key=b"key", subkey=b"subkey" + validator.local_public_key, value=b"value", expiration_time=get_dht_time() + 10
+    )
     parent_conn.send(record)
 
     signed_record = parent_conn.recv()
-    assert b'[signature:' in signed_record.value
+    assert b"[signature:" in signed_record.value
     assert validator.validate(signed_record)
 
 
@@ -107,27 +111,26 @@ def test_signing_in_different_process():
 async def test_dhtnode_signatures():
     alice = await DHTNode.create(record_validator=RSASignatureValidator())
     initial_peers = await alice.get_visible_maddrs()
-    bob = await DHTNode.create(
-        record_validator=RSASignatureValidator(RSAPrivateKey()), initial_peers=initial_peers)
+    bob = await DHTNode.create(record_validator=RSASignatureValidator(RSAPrivateKey()), initial_peers=initial_peers)
     mallory = await DHTNode.create(
-        record_validator=RSASignatureValidator(RSAPrivateKey()), initial_peers=initial_peers)
+        record_validator=RSASignatureValidator(RSAPrivateKey()), initial_peers=initial_peers
+    )
 
-    key = b'key'
-    subkey = b'protected_subkey' + bob.protocol.record_validator.local_public_key
+    key = b"key"
+    subkey = b"protected_subkey" + bob.protocol.record_validator.local_public_key
 
-    assert await bob.store(key, b'true_value', hivemind.get_dht_time() + 10, subkey=subkey)
-    assert (await alice.get(key, latest=True)).value[subkey].value == b'true_value'
+    assert await bob.store(key, b"true_value", hivemind.get_dht_time() + 10, subkey=subkey)
+    assert (await alice.get(key, latest=True)).value[subkey].value == b"true_value"
 
-    store_ok = await mallory.store(key, b'fake_value', hivemind.get_dht_time() + 10, subkey=subkey)
+    store_ok = await mallory.store(key, b"fake_value", hivemind.get_dht_time() + 10, subkey=subkey)
     assert not store_ok
-    assert (await alice.get(key, latest=True)).value[subkey].value == b'true_value'
+    assert (await alice.get(key, latest=True)).value[subkey].value == b"true_value"
 
-    assert await bob.store(key, b'updated_true_value', hivemind.get_dht_time() + 10, subkey=subkey)
-    assert (await alice.get(key, latest=True)).value[subkey].value == b'updated_true_value'
+    assert await bob.store(key, b"updated_true_value", hivemind.get_dht_time() + 10, subkey=subkey)
+    assert (await alice.get(key, latest=True)).value[subkey].value == b"updated_true_value"
 
     await bob.shutdown()  # Bob has shut down, now Mallory is the single peer of Alice
 
-    store_ok = await mallory.store(key, b'updated_fake_value',
-                                   hivemind.get_dht_time() + 10, subkey=subkey)
+    store_ok = await mallory.store(key, b"updated_fake_value", hivemind.get_dht_time() + 10, subkey=subkey)
     assert not store_ok
-    assert (await alice.get(key, latest=True)).value[subkey].value == b'updated_true_value'
+    assert (await alice.get(key, latest=True)).value[subkey].value == b"updated_true_value"

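A minimal sketch of record signing as exercised above; the import paths are assumptions based on this commit's module layout, since the test's import block lies outside the hunks shown:

import dataclasses

from hivemind import get_dht_time
from hivemind.dht.crypto import RSASignatureValidator
from hivemind.dht.validation import DHTRecord
from hivemind.utils.crypto import RSAPrivateKey

validator = RSASignatureValidator(RSAPrivateKey())
record = DHTRecord(
    key=b"key",
    subkey=b"subkey" + validator.local_public_key,  # appending the public key marks the record as protected
    value=b"value",
    expiration_time=get_dht_time() + 10,
)
signed = dataclasses.replace(record, value=validator.sign_value(record))
assert validator.validate(signed)
assert validator.strip_value(signed) == b"value"
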
+ 93 - 48
tests/test_dht_experts.py

@@ -25,17 +25,17 @@ def test_store_get_experts(n_peers=10):
     expert_uids = [f"my_expert.{i}" for i in range(50)]
     batch_size = 10
     for batch_start in range(0, len(expert_uids), batch_size):
-        declare_experts(first_peer, expert_uids[batch_start: batch_start + batch_size], 'localhost:1234')
+        declare_experts(first_peer, expert_uids[batch_start : batch_start + batch_size], "localhost:1234")
 
-    found = get_experts(other_peer, random.sample(expert_uids, 5) + ['foo', 'bar'])
+    found = get_experts(other_peer, random.sample(expert_uids, 5) + ["foo", "bar"])
     assert all(res is not None for res in found[:-2]), "Could not find some existing experts"
     assert all(res is None for res in found[-2:]), "Found non-existing experts"
 
     other_expert, other_port = "my_other_expert.1337", random.randint(1000, 9999)
-    declare_experts(other_peer, [other_expert], f'that_host:{other_port}')
-    first_notfound, first_found = get_experts(first_peer, ['foobar', other_expert])
+    declare_experts(other_peer, [other_expert], f"that_host:{other_port}")
+    first_notfound, first_found = get_experts(first_peer, ["foobar", other_expert])
     assert isinstance(first_found, hivemind.RemoteExpert)
-    assert first_found.endpoint == f'that_host:{other_port}'
+    assert first_found.endpoint == f"that_host:{other_port}"
 
     # test graceful shutdown
     first_peer.shutdown()
@@ -43,28 +43,32 @@ def test_store_get_experts(n_peers=10):
     time.sleep(1.0)
     remaining_peer1 = random.choice([peer for peer in peers if peer.is_alive()])
     remaining_peer2 = random.choice([peer for peer in peers if peer.is_alive()])
-    assert all(declare_experts(remaining_peer1, ['new_expert.1'], 'dummy'))
-    assert get_experts(remaining_peer2, ['new_expert.1'])[0].endpoint == 'dummy'
+    assert all(declare_experts(remaining_peer1, ["new_expert.1"], "dummy"))
+    assert get_experts(remaining_peer2, ["new_expert.1"])[0].endpoint == "dummy"
 
 
 @pytest.mark.forked
-def test_beam_search(n_peers=20, total_experts=128, batch_size=32, beam_size=4, parallel_rpc=4,
-                     grid_dims=(32, 32, 32)):
+def test_beam_search(
+    n_peers=20, total_experts=128, batch_size=32, beam_size=4, parallel_rpc=4, grid_dims=(32, 32, 32)
+):
     dht = [hivemind.DHT(start=True)]
     initial_peers = dht[0].get_visible_maddrs()
     dht += [hivemind.DHT(initial_peers=initial_peers, start=True) for _ in range(n_peers - 1)]
 
-    real_experts = sorted({
-        'expert.' + '.'.join([str(random.randint(0, dim - 1)) for dim in grid_dims])
-        for _ in range(total_experts)
-    })
+    real_experts = sorted(
+        {"expert." + ".".join([str(random.randint(0, dim - 1)) for dim in grid_dims]) for _ in range(total_experts)}
+    )
     for batch_start in range(0, len(real_experts), batch_size):
-        declare_experts(random.choice(dht), real_experts[batch_start: batch_start + batch_size], wait=True,
-                        endpoint=f"host{batch_start // batch_size}:{random.randint(0, 65536)}")
+        declare_experts(
+            random.choice(dht),
+            real_experts[batch_start : batch_start + batch_size],
+            wait=True,
+            endpoint=f"host{batch_start // batch_size}:{random.randint(0, 65536)}",
+        )
 
     neighbors = sum([peer.get_visible_maddrs() for peer in random.sample(dht, min(3, len(dht)))], [])
     you = hivemind.DHT(start=True, initial_peers=neighbors, parallel_rpc=parallel_rpc)
-    beam_search = MoEBeamSearcher(you, 'expert.', grid_dims)
+    beam_search = MoEBeamSearcher(you, "expert.", grid_dims)
 
     for i in range(10):
         topk_experts = beam_search.find_best_experts([np.random.randn(dim) for dim in grid_dims], beam_size)
@@ -72,8 +76,9 @@ def test_beam_search(n_peers=20, total_experts=128, batch_size=32, beam_size=4,
         assert len(topk_experts) == beam_size
 
     for i in range(10):
-        batch_experts = beam_search.batch_find_best_experts([np.random.randn(batch_size, dim) for dim in grid_dims],
-                                                            beam_size=beam_size)
+        batch_experts = beam_search.batch_find_best_experts(
+            [np.random.randn(batch_size, dim) for dim in grid_dims], beam_size=beam_size
+        )
         assert isinstance(batch_experts, list) and len(batch_experts) == batch_size
         assert all(isinstance(e, hivemind.RemoteExpert) for experts in batch_experts for e in experts)
         assert all(len(experts) == beam_size for experts in batch_experts)
@@ -82,43 +87,57 @@ def test_beam_search(n_peers=20, total_experts=128, batch_size=32, beam_size=4,
 @pytest.mark.forked
 def test_dht_single_node():
     node = hivemind.DHT(start=True)
-    beam_search = MoEBeamSearcher(node, 'expert.', grid_size=(10,))
+    beam_search = MoEBeamSearcher(node, "expert.", grid_size=(10,))
 
-    assert all(declare_experts(node, ['expert.1', 'expert.2', 'expert.3'], f"{hivemind.LOCALHOST}:1337").values())
+    assert all(declare_experts(node, ["expert.1", "expert.2", "expert.3"], f"{hivemind.LOCALHOST}:1337").values())
     assert len(declare_experts(node, ["ffn.1", "ffn.2"], endpoint="that_place")) == 4
-    assert len(declare_experts(node, ['e.1.2.3', 'e.1.2.5', 'e.2.0'], f"{hivemind.LOCALHOST}:42")) == 7
+    assert len(declare_experts(node, ["e.1.2.3", "e.1.2.5", "e.2.0"], f"{hivemind.LOCALHOST}:42")) == 7
 
-    for expert in get_experts(node, ['expert.3', 'expert.2']):
+    for expert in get_experts(node, ["expert.3", "expert.2"]):
         assert expert.endpoint == f"{hivemind.LOCALHOST}:1337"
 
-    assert all(declare_experts(node, ['expert.5', 'expert.2'], f"{hivemind.LOCALHOST}:1337").values())
-    found_experts = beam_search.find_best_experts([(0., 1., 2., 3., 4., 5., 6., 7., 8.)], beam_size=2)
-    assert len(found_experts) == 2 and [expert.uid for expert in found_experts] == ['expert.5', 'expert.3']
+    assert all(declare_experts(node, ["expert.5", "expert.2"], f"{hivemind.LOCALHOST}:1337").values())
+    found_experts = beam_search.find_best_experts([(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0)], beam_size=2)
+    assert len(found_experts) == 2 and [expert.uid for expert in found_experts] == ["expert.5", "expert.3"]
 
-    successors = beam_search.get_active_successors(['e.1.2.', 'e.2.', 'e.4.5.'])
-    assert len(successors['e.1.2.']) == 2
-    assert successors['e.1.2.'][3] == UidEndpoint('e.1.2.3', f'{LOCALHOST}:42')
-    assert successors['e.1.2.'][5] == UidEndpoint('e.1.2.5', f'{LOCALHOST}:42')
-    assert len(successors['e.2.']) == 1 and successors['e.2.'][0] == UidEndpoint('e.2.0', f'{LOCALHOST}:42')
-    assert successors['e.4.5.'] == {}
+    successors = beam_search.get_active_successors(["e.1.2.", "e.2.", "e.4.5."])
+    assert len(successors["e.1.2."]) == 2
+    assert successors["e.1.2."][3] == UidEndpoint("e.1.2.3", f"{LOCALHOST}:42")
+    assert successors["e.1.2."][5] == UidEndpoint("e.1.2.5", f"{LOCALHOST}:42")
+    assert len(successors["e.2."]) == 1 and successors["e.2."][0] == UidEndpoint("e.2.0", f"{LOCALHOST}:42")
+    assert successors["e.4.5."] == {}
 
     initial_beam = beam_search.get_initial_beam((3, 2, 1, 0, -1, -2, -3), beam_size=3)
     assert len(initial_beam) == 3
-    assert initial_beam[0][:2] == (2.0, 'expert.1.')
-    assert initial_beam[1][:2] == (1.0, 'expert.2.')
-    assert initial_beam[2][:2] == (0.0, 'expert.3.')
+    assert initial_beam[0][:2] == (2.0, "expert.1.")
+    assert initial_beam[1][:2] == (1.0, "expert.2.")
+    assert initial_beam[2][:2] == (0.0, "expert.3.")
 
     with pytest.raises(AssertionError):
-        beam_search = MoEBeamSearcher(node, 'expert.1.ffn', (2, 2))
+        beam_search = MoEBeamSearcher(node, "expert.1.ffn", (2, 2))
 
     with pytest.raises(AssertionError):
-        beam_search.get_active_successors(['e.1.2.', 'e.2', 'e.4.5.'])
+        beam_search.get_active_successors(["e.1.2.", "e.2", "e.4.5."])
 
 
 def test_uid_patterns():
-    valid_experts = ["expert.1", "expert.0", "expert.0.0.1", "expert.1337", "ffn.12.34.56.78.90",
-                     "transformer.3.2.1.0", "transformer_encoder.2", "transformer::encoder.2", "T®@nsf0rmE®🤗.321",
-                     "🤗.321", "0.1.2", "00.1.2", "7070.3.2.1.0", "block2.1.23", "LAYER.1.0.1"]
+    valid_experts = [
+        "expert.1",
+        "expert.0",
+        "expert.0.0.1",
+        "expert.1337",
+        "ffn.12.34.56.78.90",
+        "transformer.3.2.1.0",
+        "transformer_encoder.2",
+        "transformer::encoder.2",
+        "T®@nsf0rmE®🤗.321",
+        "🤗.321",
+        "0.1.2",
+        "00.1.2",
+        "7070.3.2.1.0",
+        "block2.1.23",
+        "LAYER.1.0.1",
+    ]
     valid_prefixes = ["expert.", "e.1.", "e.2.", "e.1.2.3.", "ololo.123.456.789.10."]
     valid_prefixes.extend([f"{uid}." for uid in valid_experts])
     valid_prefixes.extend([split_uid(uid)[0] for uid in valid_experts])
@@ -127,10 +146,36 @@ def test_uid_patterns():
     for pfx in valid_prefixes:
         assert is_valid_prefix(pfx), f"Prefix {pfx} is valid, but was perceived as invalid"
 
-    invalid = ["", ".", "expert.-1", "xxx.a", "expert.1x", "expert_ffn.1.abc1", "some.123.01", "expert.123.01",
-               "e1", "e..1", "e", "e.1.2.3..4", "ffn.1..1", ".123", ".1.2.3.", ".expert", "transformer.encoder.2",
-               "T®@nsf0rmE®.🤗.321", "layer::123", "expert.0.1.2.suffix", "0.1.2.suffix", "expert.1 something",
-               "expert.1\n", "expert.1\n2", "expert.1 ", "expert.1\nexpert.2", "'expert.1'", '"expert.1"']
+    invalid = [
+        "",
+        ".",
+        "expert.-1",
+        "xxx.a",
+        "expert.1x",
+        "expert_ffn.1.abc1",
+        "some.123.01",
+        "expert.123.01",
+        "e1",
+        "e..1",
+        "e",
+        "e.1.2.3..4",
+        "ffn.1..1",
+        ".123",
+        ".1.2.3.",
+        ".expert",
+        "transformer.encoder.2",
+        "T®@nsf0rmE®.🤗.321",
+        "layer::123",
+        "expert.0.1.2.suffix",
+        "0.1.2.suffix",
+        "expert.1 something",
+        "expert.1\n",
+        "expert.1\n2",
+        "expert.1 ",
+        "expert.1\nexpert.2",
+        "'expert.1'",
+        '"expert.1"',
+    ]
     invalid_experts = invalid + valid_prefixes + ["0", "123456"]
     invalid_prefixes = invalid + valid_experts + ["expert", ".🤗", ".expert"]
     for uid in invalid_experts:
@@ -142,23 +187,23 @@ def test_uid_patterns():
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_negative_caching(n_peers=10):
-    dht_kwargs = {'cache_locally': False}
+    dht_kwargs = {"cache_locally": False}
 
     peers = [hivemind.DHT(start=True, **dht_kwargs)]
     initial_peers = peers[0].get_visible_maddrs()
     peers += [hivemind.DHT(initial_peers=initial_peers, start=True, **dht_kwargs) for _ in range(n_peers - 1)]
 
     writer_peer = random.choice(peers)
-    assert all(declare_experts(writer_peer, ['ffn.1.2.3', 'ffn.3.4.5'], 'myaddr:1234').values())
+    assert all(declare_experts(writer_peer, ["ffn.1.2.3", "ffn.3.4.5"], "myaddr:1234").values())
 
     neighbors = sum([peer.get_visible_maddrs() for peer in random.sample(peers, min(3, len(peers)))], [])
     neg_caching_peer = hivemind.DHT(initial_peers=neighbors, start=True, **dht_kwargs)
-    beam_search = MoEBeamSearcher(neg_caching_peer, uid_prefix='ffn.', grid_size=(10, 10, 10), negative_caching=True)
+    beam_search = MoEBeamSearcher(neg_caching_peer, uid_prefix="ffn.", grid_size=(10, 10, 10), negative_caching=True)
     # get prefixes by the peer with negative caching. Cache "no data" entries for ffn.0.*, ffn.2.*, ffn.4.*, ffn.5.*
-    assert len(beam_search.get_initial_beam(scores=[.1, .2, .3, .4, .5, .6], beam_size=3)) == 2
+    assert len(beam_search.get_initial_beam(scores=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], beam_size=3)) == 2
 
     node = await DHTNode.create(initial_peers=neighbors)
-    fetched = await asyncio.gather(*(node.get(f'ffn.{i}.') for i in range(10)))
+    fetched = await asyncio.gather(*(node.get(f"ffn.{i}.") for i in range(10)))
     for i in range(6):
         assert fetched[i] is not None, f"node should have cached ffn.{i}."
     for i in range(6, len(fetched)):

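A minimal sketch of declaring experts in the DHT and querying them, with values taken from test_dht_single_node above; the import locations for declare_experts, get_experts and MoEBeamSearcher are assumptions, as the test's import block is not shown in this excerpt:

import hivemind
from hivemind.moe.client.beam_search import MoEBeamSearcher
from hivemind.moe.server.dht_handler import declare_experts, get_experts

node = hivemind.DHT(start=True)
declare_experts(node, ["expert.1", "expert.2", "expert.3"], f"{hivemind.LOCALHOST}:1337")
assert get_experts(node, ["expert.1"])[0].endpoint == f"{hivemind.LOCALHOST}:1337"

beam_search = MoEBeamSearcher(node, "expert.", grid_size=(10,))
best = beam_search.find_best_experts([(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0)], beam_size=2)
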
+ 121 - 86
tests/test_dht_node.py

@@ -24,18 +24,20 @@ logger = get_logger(__name__)
 
 
 def maddrs_to_peer_ids(maddrs: List[Multiaddr]) -> List[PeerID]:
-    return list({PeerID.from_base58(maddr['p2p']) for maddr in maddrs})
+    return list({PeerID.from_base58(maddr["p2p"]) for maddr in maddrs})
 
 
-def run_protocol_listener(dhtid: DHTID, maddr_conn: mp.connection.Connection,
-                          initial_peers: Sequence[Multiaddr]) -> None:
+def run_protocol_listener(
+    dhtid: DHTID, maddr_conn: mp.connection.Connection, initial_peers: Sequence[Multiaddr]
+) -> None:
     loop = asyncio.get_event_loop()
 
     p2p = loop.run_until_complete(P2P.create(initial_peers=initial_peers))
     visible_maddrs = loop.run_until_complete(p2p.get_visible_maddrs())
 
-    protocol = loop.run_until_complete(DHTProtocol.create(
-        p2p, dhtid, bucket_size=20, depth_modulo=5, num_replicas=3, wait_timeout=5))
+    protocol = loop.run_until_complete(
+        DHTProtocol.create(p2p, dhtid, bucket_size=20, depth_modulo=5, num_replicas=3, wait_timeout=5)
+    )
 
     logger.info(f"Started peer id={protocol.node_id} visible_maddrs={visible_maddrs}")
 
@@ -53,8 +55,9 @@ def run_protocol_listener(dhtid: DHTID, maddr_conn: mp.connection.Connection,
     loop.run_forever()
 
 
-def launch_protocol_listener(initial_peers: Sequence[Multiaddr] = ()) -> \
-        Tuple[DHTID, mp.Process, PeerID, List[Multiaddr]]:
+def launch_protocol_listener(
+    initial_peers: Sequence[Multiaddr] = (),
+) -> Tuple[DHTID, mp.Process, PeerID, List[Multiaddr]]:
     remote_conn, local_conn = mp.Pipe()
     dht_id = DHTID.generate()
     process = mp.Process(target=run_protocol_listener, args=(dht_id, remote_conn, initial_peers), daemon=True)
@@ -76,55 +79,72 @@ def test_dht_protocol():
     loop = asyncio.get_event_loop()
     for listen in [False, True]:  # note: order matters, this test assumes that first run uses listen=False
         p2p = loop.run_until_complete(P2P.create(initial_peers=peer1_maddrs))
-        protocol = loop.run_until_complete(DHTProtocol.create(
-            p2p, DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, listen=listen))
+        protocol = loop.run_until_complete(
+            DHTProtocol.create(
+                p2p, DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, listen=listen
+            )
+        )
         logger.info(f"Self id={protocol.node_id}")
 
         assert loop.run_until_complete(protocol.call_ping(peer1_id)) == peer1_node_id
 
-        key, value, expiration = DHTID.generate(), [random.random(), {'ololo': 'pyshpysh'}], get_dht_time() + 1e3
-        store_ok = loop.run_until_complete(protocol.call_store(
-            peer1_id, [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
+        key, value, expiration = DHTID.generate(), [random.random(), {"ololo": "pyshpysh"}], get_dht_time() + 1e3
+        store_ok = loop.run_until_complete(
+            protocol.call_store(peer1_id, [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
         )
         assert all(store_ok), "DHT rejected a trivial store"
 
         # peer 1 must know about peer 2
         (recv_value_bytes, recv_expiration), nodes_found = loop.run_until_complete(
-            protocol.call_find(peer1_id, [key]))[key]
+            protocol.call_find(peer1_id, [key])
+        )[key]
         recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
         (recv_id, recv_peer_id) = next(iter(nodes_found.items()))
-        assert recv_id == peer2_node_id and recv_peer_id == peer2_id, \
-            f"expected id={peer2_node_id}, peer={peer2_id} but got {recv_id}, {recv_peer_id}"
+        assert (
+            recv_id == peer2_node_id and recv_peer_id == peer2_id
+        ), f"expected id={peer2_node_id}, peer={peer2_id} but got {recv_id}, {recv_peer_id}"
 
-        assert recv_value == value and recv_expiration == expiration, \
-            f"call_find_value expected {value} (expires by {expiration}) " \
+        assert recv_value == value and recv_expiration == expiration, (
+            f"call_find_value expected {value} (expires by {expiration}) "
             f"but got {recv_value} (expires by {recv_expiration})"
+        )
 
         # peer 2 must know about peer 1, but not have a *random* nonexistent value
         dummy_key = DHTID.generate()
-        empty_item, nodes_found_2 = loop.run_until_complete(
-            protocol.call_find(peer2_id, [dummy_key]))[dummy_key]
+        empty_item, nodes_found_2 = loop.run_until_complete(protocol.call_find(peer2_id, [dummy_key]))[dummy_key]
         assert empty_item is None, "Non-existent keys shouldn't have values"
         (recv_id, recv_peer_id) = next(iter(nodes_found_2.items()))
-        assert recv_id == peer1_node_id and recv_peer_id == peer1_id, \
-            f"expected id={peer1_node_id}, peer={peer1_id} but got {recv_id}, {recv_peer_id}"
+        assert (
+            recv_id == peer1_node_id and recv_peer_id == peer1_id
+        ), f"expected id={peer1_node_id}, peer={peer1_id} but got {recv_id}, {recv_peer_id}"
 
         # cause a non-response by querying a nonexistent peer
-        assert loop.run_until_complete(protocol.call_find(PeerID.from_base58('fakeid'), [key])) is None
+        assert loop.run_until_complete(protocol.call_find(PeerID.from_base58("fakeid"), [key])) is None
 
         # store/get a dictionary with sub-keys
-        nested_key, subkey1, subkey2 = DHTID.generate(), 'foo', 'bar'
-        value1, value2 = [random.random(), {'ololo': 'pyshpysh'}], 'abacaba'
-        assert loop.run_until_complete(protocol.call_store(
-            peer1_id, keys=[nested_key], values=[hivemind.MSGPackSerializer.dumps(value1)],
-            expiration_time=[expiration], subkeys=[subkey1])
+        nested_key, subkey1, subkey2 = DHTID.generate(), "foo", "bar"
+        value1, value2 = [random.random(), {"ololo": "pyshpysh"}], "abacaba"
+        assert loop.run_until_complete(
+            protocol.call_store(
+                peer1_id,
+                keys=[nested_key],
+                values=[hivemind.MSGPackSerializer.dumps(value1)],
+                expiration_time=[expiration],
+                subkeys=[subkey1],
+            )
         )
-        assert loop.run_until_complete(protocol.call_store(
-            peer1_id, keys=[nested_key], values=[hivemind.MSGPackSerializer.dumps(value2)],
-            expiration_time=[expiration + 5], subkeys=[subkey2])
+        assert loop.run_until_complete(
+            protocol.call_store(
+                peer1_id,
+                keys=[nested_key],
+                values=[hivemind.MSGPackSerializer.dumps(value2)],
+                expiration_time=[expiration + 5],
+                subkeys=[subkey2],
+            )
         )
         (recv_dict, recv_expiration), nodes_found = loop.run_until_complete(
-            protocol.call_find(peer1_id, [nested_key]))[nested_key]
+            protocol.call_find(peer1_id, [nested_key])
+        )[nested_key]
         assert isinstance(recv_dict, DictionaryDHTValue)
         assert len(recv_dict.data) == 2 and recv_expiration == expiration + 5
         assert recv_dict.data[subkey1] == (protocol.serializer.dumps(value1), expiration)
@@ -139,31 +159,36 @@ def test_dht_protocol():
 
 @pytest.mark.forked
 def test_empty_table():
-    """ Test RPC methods with empty routing table """
+    """Test RPC methods with empty routing table"""
     peer_id, peer_proc, peer_peer_id, peer_maddrs = launch_protocol_listener()
 
     loop = asyncio.get_event_loop()
     p2p = loop.run_until_complete(P2P.create(initial_peers=peer_maddrs))
-    protocol = loop.run_until_complete(DHTProtocol.create(
-        p2p, DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, listen=False))
+    protocol = loop.run_until_complete(
+        DHTProtocol.create(
+            p2p, DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, listen=False
+        )
+    )
 
-    key, value, expiration = DHTID.generate(), [random.random(), {'ololo': 'pyshpysh'}], get_dht_time() + 1e3
+    key, value, expiration = DHTID.generate(), [random.random(), {"ololo": "pyshpysh"}], get_dht_time() + 1e3
 
-    empty_item, nodes_found = loop.run_until_complete(
-        protocol.call_find(peer_peer_id, [key]))[key]
+    empty_item, nodes_found = loop.run_until_complete(protocol.call_find(peer_peer_id, [key]))[key]
     assert empty_item is None and len(nodes_found) == 0
-    assert all(loop.run_until_complete(protocol.call_store(
-        peer_peer_id, [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
-    )), "peer rejected store"
+    assert all(
+        loop.run_until_complete(
+            protocol.call_store(peer_peer_id, [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
+        )
+    ), "peer rejected store"
 
     (recv_value_bytes, recv_expiration), nodes_found = loop.run_until_complete(
-        protocol.call_find(peer_peer_id, [key]))[key]
+        protocol.call_find(peer_peer_id, [key])
+    )[key]
     recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
     assert len(nodes_found) == 0
     assert recv_value == value and recv_expiration == expiration
 
     assert loop.run_until_complete(protocol.call_ping(peer_peer_id)) == peer_id
-    assert loop.run_until_complete(protocol.call_ping(PeerID.from_base58('fakeid'))) is None
+    assert loop.run_until_complete(protocol.call_ping(PeerID.from_base58("fakeid"))) is None
     peer_proc.terminate()
 
 
@@ -176,8 +201,9 @@ def test_dht_node():
     # step B: run 51-st node in this process
     loop = asyncio.get_event_loop()
     initial_peers = random.choice(swarm_maddrs)
-    me = loop.run_until_complete(DHTNode.create(initial_peers=initial_peers, parallel_rpc=10,
-                                                cache_refresh_before_expiry=False))
+    me = loop.run_until_complete(
+        DHTNode.create(initial_peers=initial_peers, parallel_rpc=10, cache_refresh_before_expiry=False)
+    )
 
     # test 1: find self
     nearest = loop.run_until_complete(me.find_nearest_nodes([me.node_id], k_nearest=1))[me.node_id]
@@ -201,7 +227,8 @@ def test_dht_node():
         k_nearest = random.randint(1, 10)
         exclude_self = random.random() > 0.5
         nearest = loop.run_until_complete(
-            me.find_nearest_nodes([query_id], k_nearest=k_nearest, exclude_self=exclude_self))[query_id]
+            me.find_nearest_nodes([query_id], k_nearest=k_nearest, exclude_self=exclude_self)
+        )[query_id]
         nearest_nodes = list(nearest)  # keys from ordered dict
 
         assert len(nearest_nodes) == k_nearest, "beam search must return exactly k_nearest results"
@@ -245,8 +272,11 @@ def test_dht_node():
     assert loop.run_until_complete(me.store("mykey", ["Value", 10], true_time))
 
     initial_peers = random.choice(swarm_maddrs)
-    that_guy = loop.run_until_complete(DHTNode.create(initial_peers=initial_peers, parallel_rpc=10,
-                                                      cache_refresh_before_expiry=False, cache_locally=False))
+    that_guy = loop.run_until_complete(
+        DHTNode.create(
+            initial_peers=initial_peers, parallel_rpc=10, cache_refresh_before_expiry=False, cache_locally=False
+        )
+    )
 
     for node in [me, that_guy]:
         val, expiration_time = loop.run_until_complete(node.get("mykey"))
@@ -256,8 +286,8 @@ def test_dht_node():
     assert loop.run_until_complete(detached_node.get("mykey")) is None
 
     # test 7: bulk store and bulk get
-    keys = 'foo', 'bar', 'baz', 'zzz'
-    values = 3, 2, 'batman', [1, 2, 3]
+    keys = "foo", "bar", "baz", "zzz"
+    values = 3, 2, "batman", [1, 2, 3]
     store_ok = loop.run_until_complete(me.store_many(keys, values, expiration_time=get_dht_time() + 999))
     assert all(store_ok.values()), "failed to store one or more keys"
     response = loop.run_until_complete(me.get_many(keys[::-1]))
@@ -265,7 +295,7 @@ def test_dht_node():
         assert key in response and response[key][0] == value
 
     # test 8: store dictionaries as values (with sub-keys)
-    upper_key, subkey1, subkey2, subkey3 = 'ololo', 'k1', 'k2', 'k3'
+    upper_key, subkey1, subkey2, subkey3 = "ololo", "k1", "k2", "k3"
     now = get_dht_time()
     assert loop.run_until_complete(me.store(upper_key, subkey=subkey1, value=123, expiration_time=now + 10))
     assert loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=456, expiration_time=now + 20))
@@ -302,17 +332,17 @@ async def test_dhtnode_replicas():
     peers = await launch_star_shaped_swarm(n_peers=20, num_replicas=num_replicas)
 
     you = random.choice(peers)
-    assert await you.store('key1', 'foo', get_dht_time() + 999)
+    assert await you.store("key1", "foo", get_dht_time() + 999)
 
     actual_key1_replicas = sum(len(peer.protocol.storage) for peer in peers)
     assert num_replicas == actual_key1_replicas
 
-    assert await you.store('key2', 'bar', get_dht_time() + 999)
+    assert await you.store("key2", "bar", get_dht_time() + 999)
     total_size = sum(len(peer.protocol.storage) for peer in peers)
     actual_key2_replicas = total_size - actual_key1_replicas
     assert num_replicas == actual_key2_replicas
 
-    assert await you.store('key2', 'baz', get_dht_time() + 1000)
+    assert await you.store("key2", "baz", get_dht_time() + 1000)
     assert sum(len(peer.protocol.storage) for peer in peers) == total_size, "total size should not have changed"
 
 
@@ -320,21 +350,25 @@ async def test_dhtnode_replicas():
 @pytest.mark.asyncio
 async def test_dhtnode_caching(T=0.05):
     node2 = await DHTNode.create(cache_refresh_before_expiry=5 * T, reuse_get_requests=False)
-    node1 = await DHTNode.create(initial_peers=await node2.protocol.p2p.get_visible_maddrs(),
-                                 cache_refresh_before_expiry=5 * T, listen=False, reuse_get_requests=False)
-    await node2.store('k', [123, 'value'], expiration_time=hivemind.get_dht_time() + 7 * T)
-    await node2.store('k2', [654, 'value'], expiration_time=hivemind.get_dht_time() + 7 * T)
-    await node2.store('k3', [654, 'value'], expiration_time=hivemind.get_dht_time() + 15 * T)
-    await node1.get_many(['k', 'k2', 'k3', 'k4'])
+    node1 = await DHTNode.create(
+        initial_peers=await node2.protocol.p2p.get_visible_maddrs(),
+        cache_refresh_before_expiry=5 * T,
+        listen=False,
+        reuse_get_requests=False,
+    )
+    await node2.store("k", [123, "value"], expiration_time=hivemind.get_dht_time() + 7 * T)
+    await node2.store("k2", [654, "value"], expiration_time=hivemind.get_dht_time() + 7 * T)
+    await node2.store("k3", [654, "value"], expiration_time=hivemind.get_dht_time() + 15 * T)
+    await node1.get_many(["k", "k2", "k3", "k4"])
     assert len(node1.protocol.cache) == 3
     assert len(node1.cache_refresh_queue) == 0
 
-    await node1.get_many(['k', 'k2', 'k3', 'k4'])
+    await node1.get_many(["k", "k2", "k3", "k4"])
     assert len(node1.cache_refresh_queue) == 3
 
-    await node2.store('k', [123, 'value'], expiration_time=hivemind.get_dht_time() + 12 * T)
+    await node2.store("k", [123, "value"], expiration_time=hivemind.get_dht_time() + 12 * T)
     await asyncio.sleep(4 * T)
-    await node1.get('k')
+    await node1.get("k")
     await asyncio.sleep(1 * T)
 
     assert len(node1.protocol.cache) == 3
@@ -348,11 +382,11 @@ async def test_dhtnode_caching(T=0.05):
     await asyncio.sleep(5 * T)
     assert len(node1.cache_refresh_queue) == 0
 
-    await node2.store('k', [123, 'value'], expiration_time=hivemind.get_dht_time() + 10 * T)
-    await node1.get('k')
+    await node2.store("k", [123, "value"], expiration_time=hivemind.get_dht_time() + 10 * T)
+    await node1.get("k")
     await asyncio.sleep(1 * T)
     assert len(node1.cache_refresh_queue) == 0
-    await node1.get('k')
+    await node1.get("k")
     await asyncio.sleep(1 * T)
     assert len(node1.cache_refresh_queue) == 1
 
@@ -368,28 +402,28 @@ async def test_dhtnode_reuse_get():
     peers = await launch_star_shaped_swarm(n_peers=10, parallel_rpc=256)
 
     await asyncio.gather(
-        random.choice(peers).store('k1', 123, hivemind.get_dht_time() + 999),
-        random.choice(peers).store('k2', 567, hivemind.get_dht_time() + 999)
+        random.choice(peers).store("k1", 123, hivemind.get_dht_time() + 999),
+        random.choice(peers).store("k2", 567, hivemind.get_dht_time() + 999),
     )
 
     you = random.choice(peers)
 
-    futures1 = await you.get_many(['k1', 'k2'], return_futures=True)
-    assert len(you.pending_get_requests[DHTID.generate('k1')]) == 1
-    assert len(you.pending_get_requests[DHTID.generate('k2')]) == 1
+    futures1 = await you.get_many(["k1", "k2"], return_futures=True)
+    assert len(you.pending_get_requests[DHTID.generate("k1")]) == 1
+    assert len(you.pending_get_requests[DHTID.generate("k2")]) == 1
 
-    futures2 = await you.get_many(['k2', 'k3'], return_futures=True)
-    assert len(you.pending_get_requests[DHTID.generate('k2')]) == 2
+    futures2 = await you.get_many(["k2", "k3"], return_futures=True)
+    assert len(you.pending_get_requests[DHTID.generate("k2")]) == 2
 
     await asyncio.gather(*futures1.values(), *futures2.values())
-    futures3 = await you.get_many(['k3'], return_futures=True)
-    assert len(you.pending_get_requests[DHTID.generate('k1')]) == 0
-    assert len(you.pending_get_requests[DHTID.generate('k2')]) == 0
-    assert len(you.pending_get_requests[DHTID.generate('k3')]) == 1
+    futures3 = await you.get_many(["k3"], return_futures=True)
+    assert len(you.pending_get_requests[DHTID.generate("k1")]) == 0
+    assert len(you.pending_get_requests[DHTID.generate("k2")]) == 0
+    assert len(you.pending_get_requests[DHTID.generate("k3")]) == 1
 
-    assert (await futures1['k1'])[0] == 123
-    assert await futures1['k2'] == await futures2['k2'] and (await futures1['k2'])[0] == 567
-    assert await futures2['k3'] == await futures3['k3'] and (await futures3['k3']) is None
+    assert (await futures1["k1"])[0] == 123
+    assert await futures1["k2"] == await futures2["k2"] and (await futures1["k2"])[0] == 567
+    assert await futures2["k3"] == await futures3["k3"] and (await futures3["k3"]) is None
 
 
 @pytest.mark.forked
@@ -397,19 +431,19 @@ async def test_dhtnode_reuse_get():
 async def test_dhtnode_blacklist():
     node1, node2, node3, node4 = await launch_star_shaped_swarm(n_peers=4, blacklist_time=999)
 
-    assert await node2.store('abc', 123, expiration_time=hivemind.get_dht_time() + 99)
+    assert await node2.store("abc", 123, expiration_time=hivemind.get_dht_time() + 99)
     assert len(node2.blacklist.ban_counter) == 0
 
     await asyncio.gather(node3.shutdown(), node4.shutdown())
 
-    assert await node2.store('def', 456, expiration_time=hivemind.get_dht_time() + 99)
+    assert await node2.store("def", 456, expiration_time=hivemind.get_dht_time() + 99)
 
     assert set(node2.blacklist.ban_counter.keys()) == {node3.peer_id, node4.peer_id}
 
-    assert await node1.get('abc', latest=True)  # force node1 to crawl dht and discover unresponsive peers
+    assert await node1.get("abc", latest=True)  # force node1 to crawl dht and discover unresponsive peers
     assert node3.peer_id in node1.blacklist
 
-    assert await node1.get('abc', latest=True)  # force node1 to crawl dht and discover unresponsive peers
+    assert await node1.get("abc", latest=True)  # force node1 to crawl dht and discover unresponsive peers
     assert node2.peer_id not in node1.blacklist
 
     await asyncio.gather(node1.shutdown(), node2.shutdown())
@@ -420,12 +454,13 @@ async def test_dhtnode_blacklist():
 async def test_dhtnode_edge_cases():
     peers = await launch_star_shaped_swarm(n_peers=4, parallel_rpc=4)
 
-    subkeys = [0, '', False, True, 'abyrvalg', 4555]
+    subkeys = [0, "", False, True, "abyrvalg", 4555]
     keys = subkeys + [()]
     values = subkeys + [[]]
     for key, subkey, value in product(keys, subkeys, values):
-        await random.choice(peers).store(key=key, subkey=subkey, value=value,
-                                         expiration_time=hivemind.get_dht_time() + 999),
+        await random.choice(peers).store(
+            key=key, subkey=subkey, value=value, expiration_time=hivemind.get_dht_time() + 999
+        ),
 
         stored = await random.choice(peers).get(key=key, latest=True)
         assert stored is not None

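A minimal sketch of the asynchronous DHTNode API used throughout these tests (the DHTNode import path is assumed from this commit's module layout):

import asyncio

import hivemind
from hivemind.dht.node import DHTNode


async def main():
    alice = await DHTNode.create()
    bob = await DHTNode.create(initial_peers=await alice.get_visible_maddrs())

    assert await bob.store("key", "value", expiration_time=hivemind.get_dht_time() + 60)
    value, expiration_time = await alice.get("key", latest=True)
    assert value == "value"

    await asyncio.gather(alice.shutdown(), bob.shutdown())


asyncio.run(main())
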
+ 46 - 52
tests/test_dht_schema.py

@@ -34,17 +34,16 @@ async def test_expecting_regular_value(dht_nodes_with_schema):
     alice, bob = dht_nodes_with_schema
 
     # Regular value (bytes) expected
-    assert await bob.store('experiment_name', b'foo_bar', get_dht_time() + 10)
-    assert not await bob.store('experiment_name', 666, get_dht_time() + 10)
-    assert not await bob.store('experiment_name', b'foo_bar', get_dht_time() + 10,
-                               subkey=b'subkey')
+    assert await bob.store("experiment_name", b"foo_bar", get_dht_time() + 10)
+    assert not await bob.store("experiment_name", 666, get_dht_time() + 10)
+    assert not await bob.store("experiment_name", b"foo_bar", get_dht_time() + 10, subkey=b"subkey")
 
     # Refuse records despite https://pydantic-docs.helpmanual.io/usage/models/#data-conversion
-    assert not await bob.store('experiment_name', [], get_dht_time() + 10)
-    assert not await bob.store('experiment_name', [1, 2, 3], get_dht_time() + 10)
+    assert not await bob.store("experiment_name", [], get_dht_time() + 10)
+    assert not await bob.store("experiment_name", [1, 2, 3], get_dht_time() + 10)
 
     for peer in [alice, bob]:
-        assert (await peer.get('experiment_name', latest=True)).value == b'foo_bar'
+        assert (await peer.get("experiment_name", latest=True)).value == b"foo_bar"
 
 
 @pytest.mark.forked
@@ -53,30 +52,28 @@ async def test_expecting_dictionary(dht_nodes_with_schema):
     alice, bob = dht_nodes_with_schema
 
     # Dictionary (bytes -> non-negative int) expected
-    assert await bob.store('n_batches', 777, get_dht_time() + 10, subkey=b'uid1')
-    assert await bob.store('n_batches', 778, get_dht_time() + 10, subkey=b'uid2')
-    assert not await bob.store('n_batches', -666, get_dht_time() + 10, subkey=b'uid3')
-    assert not await bob.store('n_batches', 666, get_dht_time() + 10)
-    assert not await bob.store('n_batches', b'not_integer', get_dht_time() + 10, subkey=b'uid1')
-    assert not await bob.store('n_batches', 666, get_dht_time() + 10, subkey=666)
+    assert await bob.store("n_batches", 777, get_dht_time() + 10, subkey=b"uid1")
+    assert await bob.store("n_batches", 778, get_dht_time() + 10, subkey=b"uid2")
+    assert not await bob.store("n_batches", -666, get_dht_time() + 10, subkey=b"uid3")
+    assert not await bob.store("n_batches", 666, get_dht_time() + 10)
+    assert not await bob.store("n_batches", b"not_integer", get_dht_time() + 10, subkey=b"uid1")
+    assert not await bob.store("n_batches", 666, get_dht_time() + 10, subkey=666)
 
     # Refuse storing a plain dictionary bypassing the DictionaryDHTValue convention
-    assert not await bob.store('n_batches', {b'uid3': 779}, get_dht_time() + 10)
+    assert not await bob.store("n_batches", {b"uid3": 779}, get_dht_time() + 10)
 
     # Refuse records despite https://pydantic-docs.helpmanual.io/usage/models/#data-conversion
-    assert not await bob.store('n_batches', 779.5, get_dht_time() + 10, subkey=b'uid3')
-    assert not await bob.store('n_batches', 779.0, get_dht_time() + 10, subkey=b'uid3')
-    assert not await bob.store('n_batches', [], get_dht_time() + 10)
-    assert not await bob.store('n_batches', [(b'uid3', 779)], get_dht_time() + 10)
+    assert not await bob.store("n_batches", 779.5, get_dht_time() + 10, subkey=b"uid3")
+    assert not await bob.store("n_batches", 779.0, get_dht_time() + 10, subkey=b"uid3")
+    assert not await bob.store("n_batches", [], get_dht_time() + 10)
+    assert not await bob.store("n_batches", [(b"uid3", 779)], get_dht_time() + 10)
 
     # Refuse records despite https://github.com/samuelcolvin/pydantic/issues/1268
-    assert not await bob.store('n_batches', '', get_dht_time() + 10)
+    assert not await bob.store("n_batches", "", get_dht_time() + 10)
 
     for peer in [alice, bob]:
-        dictionary = (await peer.get('n_batches', latest=True)).value
-        assert (len(dictionary) == 2 and
-                dictionary[b'uid1'].value == 777 and
-                dictionary[b'uid2'].value == 778)
+        dictionary = (await peer.get("n_batches", latest=True)).value
+        assert len(dictionary) == 2 and dictionary[b"uid1"].value == 777 and dictionary[b"uid2"].value == 778
 
 
 @pytest.mark.forked
@@ -86,15 +83,12 @@ async def test_expecting_public_keys(dht_nodes_with_schema):
 
     # Subkeys expected to contain a public key
     # (so hivemind.dht.crypto.RSASignatureValidator would require a signature)
-    assert await bob.store('signed_data', b'foo_bar', get_dht_time() + 10,
-                           subkey=b'uid[owner:public-key]')
-    assert not await bob.store('signed_data', b'foo_bar', get_dht_time() + 10,
-                               subkey=b'uid-without-public-key')
+    assert await bob.store("signed_data", b"foo_bar", get_dht_time() + 10, subkey=b"uid[owner:public-key]")
+    assert not await bob.store("signed_data", b"foo_bar", get_dht_time() + 10, subkey=b"uid-without-public-key")
 
     for peer in [alice, bob]:
-        dictionary = (await peer.get('signed_data', latest=True)).value
-        assert (len(dictionary) == 1 and
-                dictionary[b'uid[owner:public-key]'].value == b'foo_bar')
+        dictionary = (await peer.get("signed_data", latest=True)).value
+        assert len(dictionary) == 1 and dictionary[b"uid[owner:public-key]"].value == b"foo_bar"
 
 
 @pytest.mark.forked
@@ -113,13 +107,13 @@ async def test_keys_outside_schema(dht_nodes_with_schema):
         alice = await DHTNode.create(record_validator=validator)
         bob = await DHTNode.create(record_validator=validator, initial_peers=await alice.get_visible_maddrs())
 
-        store_ok = await bob.store('unknown_key', b'foo_bar', get_dht_time() + 10)
+        store_ok = await bob.store("unknown_key", b"foo_bar", get_dht_time() + 10)
         assert store_ok == allow_extra_keys
 
         for peer in [alice, bob]:
-            result = await peer.get('unknown_key', latest=True)
+            result = await peer.get("unknown_key", latest=True)
             if allow_extra_keys:
-                assert result.value == b'foo_bar'
+                assert result.value == b"foo_bar"
             else:
                 assert result is None
 
@@ -130,18 +124,18 @@ async def test_prefix():
     class Schema(BaseModel):
         field: StrictInt
 
-    validator = SchemaValidator(Schema, allow_extra_keys=False, prefix='prefix')
+    validator = SchemaValidator(Schema, allow_extra_keys=False, prefix="prefix")
 
     alice = await DHTNode.create(record_validator=validator)
     bob = await DHTNode.create(record_validator=validator, initial_peers=await alice.get_visible_maddrs())
 
-    assert await bob.store('prefix_field', 777, get_dht_time() + 10)
-    assert not await bob.store('prefix_field', 'string_value', get_dht_time() + 10)
-    assert not await bob.store('field', 777, get_dht_time() + 10)
+    assert await bob.store("prefix_field", 777, get_dht_time() + 10)
+    assert not await bob.store("prefix_field", "string_value", get_dht_time() + 10)
+    assert not await bob.store("field", 777, get_dht_time() + 10)
 
     for peer in [alice, bob]:
-        assert (await peer.get('prefix_field', latest=True)).value == 777
-        assert (await peer.get('field', latest=True)) is None
+        assert (await peer.get("prefix_field", latest=True)).value == 777
+        assert (await peer.get("field", latest=True)) is None
 
     await asyncio.gather(alice.shutdown(), bob.shutdown())
 
@@ -171,21 +165,21 @@ async def test_merging_schema_validators(dht_nodes_with_schema):
         for peer in [alice, bob]:
             assert peer.protocol.record_validator.merge_with(new_validator)
 
-    assert await bob.store('experiment_name', b'foo_bar', get_dht_time() + 10)
-    assert await bob.store('some_field', 777, get_dht_time() + 10)
-    assert not await bob.store('some_field', 'string_value', get_dht_time() + 10)
-    assert await bob.store('another_field', 42, get_dht_time() + 10)
-    assert await bob.store('another_field', 'string_value', get_dht_time() + 10)
+    assert await bob.store("experiment_name", b"foo_bar", get_dht_time() + 10)
+    assert await bob.store("some_field", 777, get_dht_time() + 10)
+    assert not await bob.store("some_field", "string_value", get_dht_time() + 10)
+    assert await bob.store("another_field", 42, get_dht_time() + 10)
+    assert await bob.store("another_field", "string_value", get_dht_time() + 10)
 
     # Unknown keys are allowed since the first schema is created with allow_extra_keys=True
-    assert await bob.store('unknown_key', 999, get_dht_time() + 10)
+    assert await bob.store("unknown_key", 999, get_dht_time() + 10)
 
     for peer in [alice, bob]:
-        assert (await peer.get('experiment_name', latest=True)).value == b'foo_bar'
-        assert (await peer.get('some_field', latest=True)).value == 777
-        assert (await peer.get('another_field', latest=True)).value == 'string_value'
+        assert (await peer.get("experiment_name", latest=True)).value == b"foo_bar"
+        assert (await peer.get("some_field", latest=True)).value == 777
+        assert (await peer.get("another_field", latest=True)).value == "string_value"
 
-        assert (await peer.get('unknown_key', latest=True)).value == 999
+        assert (await peer.get("unknown_key", latest=True)).value == 999
 
 
 @pytest.mark.forked
@@ -196,9 +190,9 @@ def test_sending_validator_instance_between_processes():
     alice.add_validators([SchemaValidator(SampleSchema)])
     bob.add_validators([SchemaValidator(SampleSchema)])
 
-    assert bob.store('experiment_name', b'foo_bar', get_dht_time() + 10)
-    assert not bob.store('experiment_name', 777, get_dht_time() + 10)
-    assert alice.get('experiment_name', latest=True).value == b'foo_bar'
+    assert bob.store("experiment_name", b"foo_bar", get_dht_time() + 10)
+    assert not bob.store("experiment_name", 777, get_dht_time() + 10)
+    assert alice.get("experiment_name", latest=True).value == b"foo_bar"
 
     alice.shutdown()
     bob.shutdown()

+ 31 - 33
tests/test_dht_storage.py

@@ -9,7 +9,6 @@ def test_store():
     d = DHTLocalStorage()
     d.store(DHTID.generate("key"), b"val", get_dht_time() + 0.5)
     assert d.get(DHTID.generate("key"))[0] == b"val", "Wrong value"
-    print("Test store passed")
 
 
 def test_get_expired():
@@ -17,13 +16,11 @@ def test_get_expired():
     d.store(DHTID.generate("key"), b"val", get_dht_time() + 0.1)
     time.sleep(0.5)
     assert d.get(DHTID.generate("key")) is None, "Expired value must be deleted"
-    print("Test get expired passed")
 
 
 def test_get_empty():
     d = DHTLocalStorage()
     assert d.get(DHTID.generate(source="key")) is None, "DHTLocalStorage returned non-existent value"
-    print("Test get expired passed")
 
 
 def test_change_expiration_time():
@@ -33,7 +30,6 @@ def test_change_expiration_time():
     d.store(DHTID.generate("key"), b"val2", get_dht_time() + 200)
     time.sleep(1)
     assert d.get(DHTID.generate("key"))[0] == b"val2", "Value must be changed, but still kept in table"
-    print("Test change expiration time passed")
 
 
 def test_maxsize_cache():
@@ -57,7 +53,7 @@ def test_localstorage_top():
     d.store(DHTID.generate("key1"), b"val1_new", get_dht_time() + 3)
     assert d.top()[0] == DHTID.generate("key2") and d.top()[1].value == b"val2"
 
-    del d[DHTID.generate('key2')]
+    del d[DHTID.generate("key2")]
     assert d.top()[0] == DHTID.generate("key1") and d.top()[1].value == b"val1_new"
     d.store(DHTID.generate("key2"), b"val2_new", get_dht_time() + 5)
     d.store(DHTID.generate("key4"), b"val4", get_dht_time() + 6)  # key4 will push out key1 due to maxsize
@@ -69,32 +65,34 @@ def test_localstorage_nested():
     time = get_dht_time()
     d1 = DHTLocalStorage()
     d2 = DictionaryDHTValue()
-    d2.store('subkey1', b'value1', time + 2)
-    d2.store('subkey2', b'value2', time + 3)
-    d2.store('subkey3', b'value3', time + 1)
+    d2.store("subkey1", b"value1", time + 2)
+    d2.store("subkey2", b"value2", time + 3)
+    d2.store("subkey3", b"value3", time + 1)
 
     assert d2.latest_expiration_time == time + 3
     for subkey, (subvalue, subexpiration) in d2.items():
-        assert d1.store_subkey(DHTID.generate('foo'), subkey, subvalue, subexpiration)
-    assert d1.store(DHTID.generate('bar'), b'456', time + 2)
-    assert d1.get(DHTID.generate('foo'))[0].data == d2.data
-    assert d1.get(DHTID.generate('foo'))[1] == d2.latest_expiration_time
-    assert d1.get(DHTID.generate('foo'))[0].get('subkey1') == (b'value1', time + 2)
-    assert len(d1.get(DHTID.generate('foo'))[0]) == 3
-    assert d1.store_subkey(DHTID.generate('foo'), 'subkey4', b'value4', time + 4)
-    assert len(d1.get(DHTID.generate('foo'))[0]) == 4
-
-    assert d1.store_subkey(DHTID.generate('bar'), 'subkeyA', b'valueA', time + 1) is False  # prev has better expiration
-    assert d1.store_subkey(DHTID.generate('bar'), 'subkeyA', b'valueA', time + 3)  # new value has better expiration
-    assert d1.store_subkey(DHTID.generate('bar'), 'subkeyB', b'valueB', time + 4)  # new value has better expiration
-    assert d1.store_subkey(DHTID.generate('bar'), 'subkeyA', b'valueA+', time + 5)  # overwrite subkeyA under key bar
-    assert all(subkey in d1.get(DHTID.generate('bar'))[0] for subkey in ('subkeyA', 'subkeyB'))
-    assert len(d1.get(DHTID.generate('bar'))[0]) == 2 and d1.get(DHTID.generate('bar'))[1] == time + 5
-
-    assert d1.store(DHTID.generate('foo'), b'nothing', time + 3.5) is False  # previous value has better expiration
-    assert d1.get(DHTID.generate('foo'))[0].get('subkey2') == (b'value2', time + 3)
-    assert d1.store(DHTID.generate('foo'), b'nothing', time + 5) is True  # new value has better expiration
-    assert d1.get(DHTID.generate('foo')) == (b'nothing', time + 5)  # value should be replaced
+        assert d1.store_subkey(DHTID.generate("foo"), subkey, subvalue, subexpiration)
+    assert d1.store(DHTID.generate("bar"), b"456", time + 2)
+    assert d1.get(DHTID.generate("foo"))[0].data == d2.data
+    assert d1.get(DHTID.generate("foo"))[1] == d2.latest_expiration_time
+    assert d1.get(DHTID.generate("foo"))[0].get("subkey1") == (b"value1", time + 2)
+    assert len(d1.get(DHTID.generate("foo"))[0]) == 3
+    assert d1.store_subkey(DHTID.generate("foo"), "subkey4", b"value4", time + 4)
+    assert len(d1.get(DHTID.generate("foo"))[0]) == 4
+
+    assert (
+        d1.store_subkey(DHTID.generate("bar"), "subkeyA", b"valueA", time + 1) is False
+    )  # prev has better expiration
+    assert d1.store_subkey(DHTID.generate("bar"), "subkeyA", b"valueA", time + 3)  # new value has better expiration
+    assert d1.store_subkey(DHTID.generate("bar"), "subkeyB", b"valueB", time + 4)  # new value has better expiration
+    assert d1.store_subkey(DHTID.generate("bar"), "subkeyA", b"valueA+", time + 5)  # overwrite subkeyA under key bar
+    assert all(subkey in d1.get(DHTID.generate("bar"))[0] for subkey in ("subkeyA", "subkeyB"))
+    assert len(d1.get(DHTID.generate("bar"))[0]) == 2 and d1.get(DHTID.generate("bar"))[1] == time + 5
+
+    assert d1.store(DHTID.generate("foo"), b"nothing", time + 3.5) is False  # previous value has better expiration
+    assert d1.get(DHTID.generate("foo"))[0].get("subkey2") == (b"value2", time + 3)
+    assert d1.store(DHTID.generate("foo"), b"nothing", time + 5) is True  # new value has better expiraiton
+    assert d1.get(DHTID.generate("foo")) == (b"nothing", time + 5)  # value should be replaced
 
 
 def test_localstorage_freeze():
@@ -120,13 +118,13 @@ def test_localstorage_serialize():
     d2 = DictionaryDHTValue()
 
     now = get_dht_time()
-    d1.store('key1', b'ololo', now + 1)
-    d2.store('key2', b'pysh', now + 1)
-    d2.store('key3', b'pyshpysh', now + 2)
+    d1.store("key1", b"ololo", now + 1)
+    d2.store("key2", b"pysh", now + 1)
+    d2.store("key3", b"pyshpysh", now + 2)
 
     data = MSGPackSerializer.dumps([d1, d2, 123321])
     assert isinstance(data, bytes)
     new_d1, new_d2, new_value = MSGPackSerializer.loads(data)
     assert isinstance(new_d1, DictionaryDHTValue) and isinstance(new_d2, DictionaryDHTValue) and new_value == 123321
-    assert 'key1' in new_d1 and len(new_d1) == 1
-    assert 'key1' not in new_d2 and len(new_d2) == 2 and new_d2.get('key3') == (b'pyshpysh', now + 2)
+    assert "key1" in new_d1 and len(new_d1) == 1
+    assert "key1" not in new_d2 and len(new_d2) == 2 and new_d2.get("key3") == (b"pyshpysh", now + 2)

+ 34 - 33
tests/test_dht_validation.py

@@ -24,41 +24,43 @@ class SchemaB(BaseModel):
 def validators_for_app():
     # Each application may add its own validator set
     return {
-        'A': [RSASignatureValidator(), SchemaValidator(SchemaA, allow_extra_keys=False)],
-        'B': [SchemaValidator(SchemaB, allow_extra_keys=False), RSASignatureValidator()],
+        "A": [RSASignatureValidator(), SchemaValidator(SchemaA, allow_extra_keys=False)],
+        "B": [SchemaValidator(SchemaB, allow_extra_keys=False), RSASignatureValidator()],
     }
 
 
 def test_composite_validator(validators_for_app):
-    validator = CompositeValidator(validators_for_app['A'])
-    assert ([type(item) for item in validator._validators] ==
-        [SchemaValidator, RSASignatureValidator])
+    validator = CompositeValidator(validators_for_app["A"])
+    assert [type(item) for item in validator._validators] == [SchemaValidator, RSASignatureValidator]
 
-    validator.extend(validators_for_app['B'])
-    assert ([type(item) for item in validator._validators] ==
-        [SchemaValidator, RSASignatureValidator])
+    validator.extend(validators_for_app["B"])
+    assert [type(item) for item in validator._validators] == [SchemaValidator, RSASignatureValidator]
     assert len(validator._validators[0]._schemas) == 2
 
-    local_public_key = validators_for_app['A'][0].local_public_key
-    record = DHTRecord(key=DHTID.generate(source='field_b').to_bytes(),
-                       subkey=DHTProtocol.serializer.dumps(local_public_key),
-                       value=DHTProtocol.serializer.dumps(777),
-                       expiration_time=hivemind.get_dht_time() + 10)
+    local_public_key = validators_for_app["A"][0].local_public_key
+    record = DHTRecord(
+        key=DHTID.generate(source="field_b").to_bytes(),
+        subkey=DHTProtocol.serializer.dumps(local_public_key),
+        value=DHTProtocol.serializer.dumps(777),
+        expiration_time=hivemind.get_dht_time() + 10,
+    )
 
     signed_record = dataclasses.replace(record, value=validator.sign_value(record))
     # Expect only one signature since two RSASignatureValidators have been merged
-    assert signed_record.value.count(b'[signature:') == 1
+    assert signed_record.value.count(b"[signature:") == 1
     # Expect successful validation since the second SchemaValidator has been merged to the first
     assert validator.validate(signed_record)
     assert validator.strip_value(signed_record) == record.value
 
-    record = DHTRecord(key=DHTID.generate(source='unknown_key').to_bytes(),
-                       subkey=DHTProtocol.IS_REGULAR_VALUE,
-                       value=DHTProtocol.serializer.dumps(777),
-                       expiration_time=hivemind.get_dht_time() + 10)
+    record = DHTRecord(
+        key=DHTID.generate(source="unknown_key").to_bytes(),
+        subkey=DHTProtocol.IS_REGULAR_VALUE,
+        value=DHTProtocol.serializer.dumps(777),
+        expiration_time=hivemind.get_dht_time() + 10,
+    )
 
     signed_record = dataclasses.replace(record, value=validator.sign_value(record))
-    assert signed_record.value.count(b'[signature:') == 0
+    assert signed_record.value.count(b"[signature:") == 0
     # Expect failed validation since `unknown_key` is not a part of any schema
     assert not validator.validate(signed_record)
 
@@ -66,27 +68,26 @@ def test_composite_validator(validators_for_app):
 @pytest.mark.forked
 def test_dht_add_validators(validators_for_app):
     # One app may create a DHT with its validators
-    dht = hivemind.DHT(start=False, record_validators=validators_for_app['A'])
+    dht = hivemind.DHT(start=False, record_validators=validators_for_app["A"])
 
     # While the DHT process is not started, you can't send a command to append new validators
     with pytest.raises(RuntimeError):
-        dht.add_validators(validators_for_app['B'])
+        dht.add_validators(validators_for_app["B"])
     dht.run_in_background(await_ready=True)
 
     # After starting the process, other apps may add new validators to the existing DHT
-    dht.add_validators(validators_for_app['B'])
+    dht.add_validators(validators_for_app["B"])
 
-    assert dht.store('field_a', b'bytes_value', hivemind.get_dht_time() + 10)
-    assert dht.get('field_a', latest=True).value == b'bytes_value'
+    assert dht.store("field_a", b"bytes_value", hivemind.get_dht_time() + 10)
+    assert dht.get("field_a", latest=True).value == b"bytes_value"
 
-    assert not dht.store('field_a', 666, hivemind.get_dht_time() + 10)
-    assert dht.get('field_a', latest=True).value == b'bytes_value'
+    assert not dht.store("field_a", 666, hivemind.get_dht_time() + 10)
+    assert dht.get("field_a", latest=True).value == b"bytes_value"
 
-    local_public_key = validators_for_app['A'][0].local_public_key
-    assert dht.store('field_b', 777, hivemind.get_dht_time() + 10, subkey=local_public_key)
-    dictionary = dht.get('field_b', latest=True).value
-    assert (len(dictionary) == 1 and
-            dictionary[local_public_key].value == 777)
+    local_public_key = validators_for_app["A"][0].local_public_key
+    assert dht.store("field_b", 777, hivemind.get_dht_time() + 10, subkey=local_public_key)
+    dictionary = dht.get("field_b", latest=True).value
+    assert len(dictionary) == 1 and dictionary[local_public_key].value == 777
 
-    assert not dht.store('unknown_key', 666, hivemind.get_dht_time() + 10)
-    assert dht.get('unknown_key', latest=True) is None
+    assert not dht.store("unknown_key", 666, hivemind.get_dht_time() + 10)
+    assert dht.get("unknown_key", latest=True) is None

+ 18 - 13
tests/test_expert_backend.py

@@ -12,7 +12,7 @@ from hivemind.moe.server.layers.lr_schedule import get_linear_schedule_with_warm
 EXPERT_WEIGHT_UPDATES = 3
 BACKWARD_PASSES_BEFORE_SAVE = 2
 BACKWARD_PASSES_AFTER_SAVE = 2
-EXPERT_NAME = 'test_expert'
+EXPERT_NAME = "test_expert"
 PEAK_LR = 1.0
 
 
@@ -22,12 +22,17 @@ def example_experts():
     opt = torch.optim.SGD(expert.parameters(), PEAK_LR)
 
     args_schema = (BatchTensorDescriptor(1),)
-    expert_backend = ExpertBackend(name=EXPERT_NAME, expert=expert, optimizer=opt,
-                                   scheduler=get_linear_schedule_with_warmup,
-                                   num_warmup_steps=BACKWARD_PASSES_BEFORE_SAVE,
-                                   num_total_steps=BACKWARD_PASSES_BEFORE_SAVE + BACKWARD_PASSES_AFTER_SAVE,
-                                   args_schema=args_schema, outputs_schema=BatchTensorDescriptor(1), max_batch_size=1,
-                                   )
+    expert_backend = ExpertBackend(
+        name=EXPERT_NAME,
+        expert=expert,
+        optimizer=opt,
+        scheduler=get_linear_schedule_with_warmup,
+        num_warmup_steps=BACKWARD_PASSES_BEFORE_SAVE,
+        num_total_steps=BACKWARD_PASSES_BEFORE_SAVE + BACKWARD_PASSES_AFTER_SAVE,
+        args_schema=args_schema,
+        outputs_schema=BatchTensorDescriptor(1),
+        max_batch_size=1,
+    )
     experts = {EXPERT_NAME: expert_backend}
     yield experts
 
@@ -88,19 +93,19 @@ def test_lr_schedule(example_experts):
     with TemporaryDirectory() as tmpdir:
         tmp_path = Path(tmpdir)
 
-        assert optimizer.param_groups[0]['lr'] == 0.0
+        assert optimizer.param_groups[0]["lr"] == 0.0
 
         for i in range(BACKWARD_PASSES_BEFORE_SAVE):
-            assert optimizer.param_groups[0]['lr'] == PEAK_LR * i / BACKWARD_PASSES_BEFORE_SAVE
+            assert optimizer.param_groups[0]["lr"] == PEAK_LR * i / BACKWARD_PASSES_BEFORE_SAVE
             expert_backend.backward(batch, loss_grad)
 
-        assert optimizer.param_groups[0]['lr'] == PEAK_LR
+        assert optimizer.param_groups[0]["lr"] == PEAK_LR
         store_experts(example_experts, tmp_path)
 
         for i in range(BACKWARD_PASSES_AFTER_SAVE):
-            assert optimizer.param_groups[0]['lr'] == PEAK_LR * (1 - (i / BACKWARD_PASSES_AFTER_SAVE))
+            assert optimizer.param_groups[0]["lr"] == PEAK_LR * (1 - (i / BACKWARD_PASSES_AFTER_SAVE))
             expert_backend.backward(batch, loss_grad)
 
-        assert optimizer.param_groups[0]['lr'] == 0.0
+        assert optimizer.param_groups[0]["lr"] == 0.0
         load_experts(example_experts, tmp_path)
-        assert optimizer.param_groups[0]['lr'] == PEAK_LR
+        assert optimizer.param_groups[0]["lr"] == PEAK_LR

+ 115 - 62
tests/test_moe.py

@@ -11,14 +11,17 @@ from hivemind.moe.server import layers
 
 @pytest.mark.forked
 def test_moe():
-    all_expert_uids = [f'ffn.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}'
-                       for _ in range(10)]
-    with background_server(expert_uids=all_expert_uids, device='cpu', expert_cls='ffn', num_handlers=1,
-                           hidden_dim=16) as (server_endpoint, dht_maddrs):
+    all_expert_uids = [
+        f"ffn.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}" for _ in range(10)
+    ]
+    with background_server(
+        expert_uids=all_expert_uids, device="cpu", expert_cls="ffn", num_handlers=1, hidden_dim=16
+    ) as (server_endpoint, dht_maddrs):
         dht = hivemind.DHT(start=True, initial_peers=dht_maddrs)
 
         dmoe = hivemind.RemoteMixtureOfExperts(
-            in_features=16, grid_size=(4, 4, 4), dht=dht, k_best=3, uid_prefix='ffn.')
+            in_features=16, grid_size=(4, 4, 4), dht=dht, k_best=3, uid_prefix="ffn."
+        )
 
         for i in range(3):
             out = dmoe(torch.randn(10, 16))
@@ -27,15 +30,23 @@ def test_moe():
 
 @pytest.mark.forked
 def test_no_experts():
-    all_expert_uids = [f'expert.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}'
-                       for _ in range(10)]
-    with background_server(expert_uids=all_expert_uids, device='cpu', expert_cls='nop_delay', num_handlers=1,
-                           hidden_dim=16) as (server_endpoint, dht_maddrs):
+    all_expert_uids = [
+        f"expert.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}" for _ in range(10)
+    ]
+    with background_server(
+        expert_uids=all_expert_uids, device="cpu", expert_cls="nop_delay", num_handlers=1, hidden_dim=16
+    ) as (server_endpoint, dht_maddrs):
         dht = hivemind.DHT(start=True, initial_peers=dht_maddrs)
 
         dmoe = hivemind.RemoteSwitchMixtureOfExperts(
-            in_features=16, grid_size=(4, 4, 4), dht=dht, uid_prefix='expert.', forward_timeout=0.1,
-            backward_timeout=0.1, allow_zero_outputs=True)
+            in_features=16,
+            grid_size=(4, 4, 4),
+            dht=dht,
+            uid_prefix="expert.",
+            forward_timeout=0.1,
+            backward_timeout=0.1,
+            allow_zero_outputs=True,
+        )
 
         for i in range(3):
             out, balancing_loss = dmoe(torch.randn(10, 16))
@@ -53,24 +64,40 @@ def test_call_many(hidden_dim=16):
     allow_zero_outputs = False
     atol = 1e-5
 
-    with background_server(num_experts=5, device='cpu', expert_cls='ffn', num_handlers=1, hidden_dim=hidden_dim,
-                           optim_cls=None, no_dht=True) as (server_endpoint, _):
+    with background_server(
+        num_experts=5,
+        device="cpu",
+        expert_cls="ffn",
+        num_handlers=1,
+        hidden_dim=hidden_dim,
+        optim_cls=None,
+        no_dht=True,
+    ) as (server_endpoint, _):
         inputs = torch.randn(4, hidden_dim, requires_grad=True)
         inputs_clone = inputs.clone().detach().requires_grad_(True)
-        e0, e1, e2, e3, e4 = [hivemind.RemoteExpert(f'expert.{i}', server_endpoint) for i in range(5)]
-        e5 = hivemind.RemoteExpert(f'thisshouldnotexist', '127.0.0.1:80')
+        e0, e1, e2, e3, e4 = [hivemind.RemoteExpert(f"expert.{i}", server_endpoint) for i in range(5)]
+        e5 = hivemind.RemoteExpert(f"thisshouldnotexist", "127.0.0.1:80")
 
         mask, expert_outputs = hivemind.moe.client.moe._RemoteCallMany.apply(
-            DUMMY, [[e0, e1, e2], [e2, e4], [e1, e5, e3], []], k_min, backward_k_min, timeout_after_k_min,
-            forward_timeout, backward_timeout, detect_anomalies, allow_zero_outputs, e1.info, inputs
+            DUMMY,
+            [[e0, e1, e2], [e2, e4], [e1, e5, e3], []],
+            k_min,
+            backward_k_min,
+            timeout_after_k_min,
+            forward_timeout,
+            backward_timeout,
+            detect_anomalies,
+            allow_zero_outputs,
+            e1.info,
+            inputs,
         )
         assert mask.shape == (4, 3)
         assert expert_outputs.shape == (4, 3, hidden_dim)
 
-        assert np.all(mask.data.numpy() == np.array([[True, True, True],
-                                                     [True, True, False],
-                                                     [True, False, True],
-                                                     [False, False, False]])), f"Incorrect mask, {mask}"
+        assert np.all(
+            mask.data.numpy()
+            == np.array([[True, True, True], [True, True, False], [True, False, True], [False, False, False]])
+        ), f"Incorrect mask, {mask}"
 
         reference_outputs = torch.zeros_like(expert_outputs)
         reference_outputs[0, 0] = e0(inputs_clone[0:1])
@@ -95,10 +122,17 @@ def test_call_many(hidden_dim=16):
 
 @pytest.mark.forked
 def test_remote_module_call(hidden_dim=16):
-    with background_server(num_experts=1, device='cpu', expert_cls='ffn', num_handlers=1, hidden_dim=hidden_dim,
-                           optim_cls=None, no_dht=True) as (server_endpoint, _):
-        real_expert = hivemind.RemoteExpert('expert.0', server_endpoint)
-        fake_expert = hivemind.RemoteExpert('oiasfjiasjf', server_endpoint)
+    with background_server(
+        num_experts=1,
+        device="cpu",
+        expert_cls="ffn",
+        num_handlers=1,
+        hidden_dim=hidden_dim,
+        optim_cls=None,
+        no_dht=True,
+    ) as (server_endpoint, _):
+        real_expert = hivemind.RemoteExpert("expert.0", server_endpoint)
+        fake_expert = hivemind.RemoteExpert("oiasfjiasjf", server_endpoint)
 
         out1 = real_expert(torch.randn(1, hidden_dim))
         assert out1.shape == (1, hidden_dim)
@@ -118,27 +152,32 @@ def test_remote_module_call(hidden_dim=16):
 
 @pytest.mark.forked
 def test_beam_search_correctness():
-    all_expert_uids = [f'ffn.{5 + i}.{10 + j}.{15 + k}' for i in range(10) for j in range(10) for k in range(10)]
+    all_expert_uids = [f"ffn.{5 + i}.{10 + j}.{15 + k}" for i in range(10) for j in range(10) for k in range(10)]
     dht = hivemind.DHT(start=True)
-    assert all(declare_experts(dht, all_expert_uids, endpoint='fake-endpoint'))
+    assert all(declare_experts(dht, all_expert_uids, endpoint="fake-endpoint"))
 
     dmoe = hivemind.RemoteMixtureOfExperts(
-        in_features=32, grid_size=(32, 32, 32), dht=dht, k_best=4, uid_prefix='ffn.')
+        in_features=32, grid_size=(32, 32, 32), dht=dht, k_best=4, uid_prefix="ffn."
+    )
 
     for i in range(25):
         input = torch.randn(32)
         grid_scores = dmoe.proj(input).split_with_sizes(dmoe.beam_search.grid_size, dim=-1)
 
-        chosen_experts = dmoe.beam_search.find_best_experts([tensor.detach().numpy() for tensor in grid_scores],
-                                                            beam_size=dmoe.k_best)
-        chosen_scores = dmoe.compute_expert_scores([dim_scores[None] for dim_scores in grid_scores],
-                                                   [chosen_experts])[0]
+        chosen_experts = dmoe.beam_search.find_best_experts(
+            [tensor.detach().numpy() for tensor in grid_scores], beam_size=dmoe.k_best
+        )
+        chosen_scores = dmoe.compute_expert_scores([dim_scores[None] for dim_scores in grid_scores], [chosen_experts])[
+            0
+        ]
         our_best_scores = list(chosen_scores.cpu().detach().numpy())
 
         # reference: independently find :beam_size: best experts with exhaustive search
-        all_scores = dmoe.compute_expert_scores([dim_scores.unsqueeze(0) for dim_scores in grid_scores],
-                                                [[hivemind.RemoteExpert(uid, '') for uid in all_expert_uids]])[0]
-        true_best_scores = sorted(all_scores.cpu().detach().numpy(), reverse=True)[:len(chosen_experts)]
+        all_scores = dmoe.compute_expert_scores(
+            [dim_scores.unsqueeze(0) for dim_scores in grid_scores],
+            [[hivemind.RemoteExpert(uid, "") for uid in all_expert_uids]],
+        )[0]
+        true_best_scores = sorted(all_scores.cpu().detach().numpy(), reverse=True)[: len(chosen_experts)]
 
         assert np.allclose(true_best_scores, our_best_scores)
 
@@ -150,15 +189,22 @@ def test_determinism(hidden_dim=16):
     xx = torch.randn(32, hidden_dim, requires_grad=True)
     mask = torch.randint(0, 1, (32, hidden_dim))
 
-    with background_server(num_experts=1, device='cpu', expert_cls='det_dropout', num_handlers=1, hidden_dim=hidden_dim,
-                           optim_cls=None, no_dht=True) as (server_endpoint, _):
-        expert = hivemind.RemoteExpert(uid=f'expert.0', endpoint=server_endpoint)
+    with background_server(
+        num_experts=1,
+        device="cpu",
+        expert_cls="det_dropout",
+        num_handlers=1,
+        hidden_dim=hidden_dim,
+        optim_cls=None,
+        no_dht=True,
+    ) as (server_endpoint, _):
+        expert = hivemind.RemoteExpert(uid=f"expert.0", endpoint=server_endpoint)
 
         out = expert(xx, mask)
         out_rerun = expert(xx, mask)
 
-        grad, = torch.autograd.grad(out.sum(), xx, retain_graph=True)
-        grad_rerun, = torch.autograd.grad(out_rerun.sum(), xx, retain_graph=True)
+        (grad,) = torch.autograd.grad(out.sum(), xx, retain_graph=True)
+        (grad_rerun,) = torch.autograd.grad(out_rerun.sum(), xx, retain_graph=True)
 
     assert torch.allclose(out, out_rerun, atol=atol, rtol=0), "Dropout layer outputs are non-deterministic."
     assert torch.allclose(grad, grad_rerun, atol=atol, rtol=0), "Gradients are non-deterministic."
@@ -169,14 +215,18 @@ def test_compute_expert_scores():
     try:
         dht = hivemind.DHT(start=True)
         moe = hivemind.moe.RemoteMixtureOfExperts(
-            dht=dht, in_features=16, grid_size=(40,), k_best=4, k_min=1, timeout_after_k_min=1,
-            uid_prefix='expert.')
+            dht=dht, in_features=16, grid_size=(40,), k_best=4, k_min=1, timeout_after_k_min=1, uid_prefix="expert."
+        )
         gx, gy = torch.randn(4, 5, requires_grad=True), torch.randn(4, 3, requires_grad=True)
         ii = [[4, 0, 2], [3, 1, 1, 1, 3], [0], [3, 2]]
         jj = [[2, 2, 1], [0, 1, 2, 0, 1], [0], [1, 2]]
         batch_experts = [
-            [hivemind.RemoteExpert(uid=f'expert.{ii[batch_i][expert_i]}.{jj[batch_i][expert_i]}', endpoint="[::]:1337")
-             for expert_i in range(len(ii[batch_i]))]
+            [
+                hivemind.RemoteExpert(
+                    uid=f"expert.{ii[batch_i][expert_i]}.{jj[batch_i][expert_i]}", endpoint="[::]:1337"
+                )
+                for expert_i in range(len(ii[batch_i]))
+            ]
             for batch_i in range(len(ii))
         ]  # note: these experts do not exist on the server, we use them only to test MoE compute_expert_scores
         logits = moe.compute_expert_scores([gx, gy], batch_experts)
@@ -185,9 +235,9 @@ def test_compute_expert_scores():
 
         for batch_i in range(len(ii)):
             for expert_i in range(len(ii[batch_i])):
-                assert torch.allclose(logits[batch_i, expert_i],
-                                      gx[batch_i, ii[batch_i][expert_i]] + gy[batch_i, jj[batch_i][expert_i]]), \
-                    "compute_expert_scores returned incorrect score"
+                assert torch.allclose(
+                    logits[batch_i, expert_i], gx[batch_i, ii[batch_i][expert_i]] + gy[batch_i, jj[batch_i][expert_i]]
+                ), "compute_expert_scores returned incorrect score"
     finally:
         dht.shutdown()
 
@@ -198,15 +248,17 @@ def test_client_anomaly_detection():
 
     experts = {}
     for i in range(4):
-        expert = layers.name_to_block['ffn'](HID_DIM)
-        experts[f'expert.{i}'] = hivemind.ExpertBackend(name=f'expert.{i}',
-                                                        expert=expert, optimizer=torch.optim.Adam(expert.parameters()),
-                                                        args_schema=(hivemind.BatchTensorDescriptor(HID_DIM),),
-                                                        outputs_schema=hivemind.BatchTensorDescriptor(HID_DIM),
-                                                        max_batch_size=16,
-                                                        )
+        expert = layers.name_to_block["ffn"](HID_DIM)
+        experts[f"expert.{i}"] = hivemind.ExpertBackend(
+            name=f"expert.{i}",
+            expert=expert,
+            optimizer=torch.optim.Adam(expert.parameters()),
+            args_schema=(hivemind.BatchTensorDescriptor(HID_DIM),),
+            outputs_schema=hivemind.BatchTensorDescriptor(HID_DIM),
+            max_batch_size=16,
+        )
 
-    experts['expert.3'].expert.ffn.weight.data[0, 0] = float('nan')
+    experts["expert.3"].expert.ffn.weight.data[0, 0] = float("nan")
 
     dht = hivemind.DHT(start=True)
     server = hivemind.moe.Server(dht, experts, num_connection_handlers=1)
@@ -214,11 +266,12 @@ def test_client_anomaly_detection():
     try:
         server.ready.wait()
 
-        dmoe = hivemind.RemoteMixtureOfExperts(in_features=16, grid_size=(3,), dht=dht, k_best=3, uid_prefix='expert.',
-                                               detect_anomalies=True)
+        dmoe = hivemind.RemoteMixtureOfExperts(
+            in_features=16, grid_size=(3,), dht=dht, k_best=3, uid_prefix="expert.", detect_anomalies=True
+        )
 
         input = torch.randn(1, 16)
-        input[0, 0] = float('nan')
+        input[0, 0] = float("nan")
 
         with pytest.raises(ValueError):
             dmoe(input)
@@ -226,15 +279,15 @@ def test_client_anomaly_detection():
         input[0, 0] = 0
         output = dmoe(input)
 
-        inf_loss = float('inf') * output.sum()
+        inf_loss = float("inf") * output.sum()
         with pytest.raises(ValueError):
             inf_loss.backward()
 
-        dmoe = hivemind.RemoteMixtureOfExperts(in_features=16, grid_size=(4,), dht=dht, k_best=4, uid_prefix='expert.',
-                                               detect_anomalies=True)
+        dmoe = hivemind.RemoteMixtureOfExperts(
+            in_features=16, grid_size=(4,), dht=dht, k_best=4, uid_prefix="expert.", detect_anomalies=True
+        )
         output = dmoe(input)
         assert output.isfinite().all()
 
-
     finally:
         server.shutdown()

+ 34 - 42
tests/test_p2p_daemon.py

@@ -42,11 +42,12 @@ async def test_daemon_killed_on_del():
 
 
 @pytest.mark.parametrize(
-    'host_maddrs', [
-        [Multiaddr('/ip4/127.0.0.1/tcp/0')],
-        [Multiaddr('/ip4/127.0.0.1/udp/0/quic')],
-        [Multiaddr('/ip4/127.0.0.1/tcp/0'), Multiaddr('/ip4/127.0.0.1/udp/0/quic')],
-    ]
+    "host_maddrs",
+    [
+        [Multiaddr("/ip4/127.0.0.1/tcp/0")],
+        [Multiaddr("/ip4/127.0.0.1/udp/0/quic")],
+        [Multiaddr("/ip4/127.0.0.1/tcp/0"), Multiaddr("/ip4/127.0.0.1/udp/0/quic")],
+    ],
 )
 @pytest.mark.asyncio
 async def test_transports(host_maddrs: List[Multiaddr]):
@@ -118,16 +119,17 @@ def handle_add_torch_with_exc(args):
     try:
         return handle_add_torch(args)
     except Exception:
-        return b'something went wrong :('
+        return b"something went wrong :("
 
 
 @pytest.mark.parametrize(
-    'should_cancel,replicate', [
+    "should_cancel,replicate",
+    [
         (True, False),
         (True, True),
         (False, False),
         (False, True),
-    ]
+    ],
 )
 @pytest.mark.asyncio
 async def test_call_unary_handler(should_cancel, replicate, handle_name="handle"):
@@ -141,13 +143,10 @@ async def test_call_unary_handler(should_cancel, replicate, handle_name="handle"
         except asyncio.CancelledError:
             nonlocal handler_cancelled
             handler_cancelled = True
-        return dht_pb2.PingResponse(
-            peer=dht_pb2.NodeInfo(node_id=server.id.to_bytes()),
-            available=True)
+        return dht_pb2.PingResponse(peer=dht_pb2.NodeInfo(node_id=server.id.to_bytes()), available=True)
 
     server_pid = server_primary._child.pid
-    await server.add_unary_handler(handle_name, ping_handler, dht_pb2.PingRequest,
-                                   dht_pb2.PingResponse)
+    await server.add_unary_handler(handle_name, ping_handler, dht_pb2.PingRequest, dht_pb2.PingResponse)
     assert is_process_running(server_pid)
 
     nodes = await bootstrap_from([server])
@@ -157,12 +156,8 @@ async def test_call_unary_handler(should_cancel, replicate, handle_name="handle"
     assert is_process_running(client_pid)
     await client.wait_for_at_least_n_peers(1)
 
-    ping_request = dht_pb2.PingRequest(
-        peer=dht_pb2.NodeInfo(node_id=client.id.to_bytes()),
-        validate=True)
-    expected_response = dht_pb2.PingResponse(
-        peer=dht_pb2.NodeInfo(node_id=server.id.to_bytes()),
-        available=True)
+    ping_request = dht_pb2.PingRequest(peer=dht_pb2.NodeInfo(node_id=client.id.to_bytes()), validate=True)
+    expected_response = dht_pb2.PingResponse(peer=dht_pb2.NodeInfo(node_id=server.id.to_bytes()), available=True)
 
     if should_cancel:
         stream_info, reader, writer = await client._client.stream_open(server.id, (handle_name,))
@@ -187,7 +182,7 @@ async def test_call_unary_handler(should_cancel, replicate, handle_name="handle"
 @pytest.mark.asyncio
 async def test_call_unary_handler_error(handle_name="handle"):
     async def error_handler(request, context):
-        raise ValueError('boom')
+        raise ValueError("boom")
 
     server = await P2P.create()
     server_pid = server._child.pid
@@ -200,13 +195,11 @@ async def test_call_unary_handler_error(handle_name="handle"):
     assert is_process_running(client_pid)
     await client.wait_for_at_least_n_peers(1)
 
-    ping_request = dht_pb2.PingRequest(
-        peer=dht_pb2.NodeInfo(node_id=client.id.to_bytes()),
-        validate=True)
+    ping_request = dht_pb2.PingRequest(peer=dht_pb2.NodeInfo(node_id=client.id.to_bytes()), validate=True)
 
     with pytest.raises(P2PHandlerError) as excinfo:
         await client.call_unary_handler(server.id, handle_name, ping_request, dht_pb2.PingResponse)
-    assert 'boom' in str(excinfo.value)
+    assert "boom" in str(excinfo.value)
 
     await server.shutdown()
     await client.shutdown()
@@ -218,8 +211,8 @@ async def test_call_unary_handler_error(handle_name="handle"):
         pytest.param(10, 100, handle_square, id="square_integer"),
         pytest.param((1, 2), 3, handle_add, id="add_integers"),
         pytest.param(([1, 2, 3], [12, 13]), [1, 2, 3, 12, 13], handle_add, id="add_lists"),
-        pytest.param(2, 8, lambda x: MSGPackSerializer.dumps(MSGPackSerializer.loads(x) ** 3), id="lambda")
-    ]
+        pytest.param(2, 8, lambda x: MSGPackSerializer.dumps(MSGPackSerializer.loads(x) ** 3), id="lambda"),
+    ],
 )
 @pytest.mark.asyncio
 async def test_call_peer_single_process(test_input, expected, handle, handler_name="handle"):
@@ -303,10 +296,8 @@ async def test_call_peer_different_processes():
     "test_input,expected",
     [
         pytest.param(torch.tensor([2]), torch.tensor(4)),
-        pytest.param(
-            torch.tensor([[1.0, 2.0], [0.5, 0.1]]),
-            torch.tensor([[1.0, 2.0], [0.5, 0.1]]) ** 2),
-    ]
+        pytest.param(torch.tensor([[1.0, 2.0], [0.5, 0.1]]), torch.tensor([[1.0, 2.0], [0.5, 0.1]]) ** 2),
+    ],
 )
 @pytest.mark.asyncio
 async def test_call_peer_torch_square(test_input, expected, handler_name="handle"):
@@ -336,8 +327,9 @@ async def test_call_peer_torch_square(test_input, expected, handler_name="handle
         pytest.param([torch.tensor([1]), torch.tensor([2])], torch.tensor([3])),
         pytest.param(
             [torch.tensor([[0.1, 0.2], [0.3, 0.4]]), torch.tensor([[1.1, 1.2], [1.3, 1.4]])],
-            torch.tensor([[1.2, 1.4], [1.6, 1.8]])),
-    ]
+            torch.tensor([[1.2, 1.4], [1.6, 1.8]]),
+        ),
+    ],
 )
 @pytest.mark.asyncio
 async def test_call_peer_torch_add(test_input, expected, handler_name="handle"):
@@ -367,7 +359,7 @@ async def test_call_peer_torch_add(test_input, expected, handler_name="handle"):
     [
         pytest.param(False, id="primary"),
         pytest.param(True, id="replica"),
-    ]
+    ],
 )
 @pytest.mark.asyncio
 async def test_call_peer_error(replicate, handler_name="handle"):
@@ -384,7 +376,7 @@ async def test_call_peer_error(replicate, handler_name="handle"):
     inp = [serialize_torch_tensor(i).SerializeToString() for i in [torch.zeros((2, 3)), torch.zeros((3, 2))]]
     inp_msgp = MSGPackSerializer.dumps(inp)
     result = await client.call_peer_handler(server.id, handler_name, inp_msgp)
-    assert result == b'something went wrong :('
+    assert result == b"something went wrong :("
 
     await server_primary.shutdown()
     await server.shutdown()
@@ -399,25 +391,25 @@ async def test_handlers_on_different_replicas(handler_name="handle"):
 
     server_primary = await P2P.create()
     server_id = server_primary.id
-    await server_primary.add_stream_handler(handler_name, partial(handler, key=b'primary'))
+    await server_primary.add_stream_handler(handler_name, partial(handler, key=b"primary"))
 
     server_replica1 = await replicate_if_needed(server_primary, True)
-    await server_replica1.add_stream_handler(handler_name + '1', partial(handler, key=b'replica1'))
+    await server_replica1.add_stream_handler(handler_name + "1", partial(handler, key=b"replica1"))
 
     server_replica2 = await replicate_if_needed(server_primary, True)
-    await server_replica2.add_stream_handler(handler_name + '2', partial(handler, key=b'replica2'))
+    await server_replica2.add_stream_handler(handler_name + "2", partial(handler, key=b"replica2"))
 
     nodes = await bootstrap_from([server_primary])
     client = await P2P.create(initial_peers=nodes)
     await client.wait_for_at_least_n_peers(1)
 
-    result = await client.call_peer_handler(server_id, handler_name, b'1')
+    result = await client.call_peer_handler(server_id, handler_name, b"1")
     assert result == b"primary"
 
-    result = await client.call_peer_handler(server_id, handler_name + '1', b'2')
+    result = await client.call_peer_handler(server_id, handler_name + "1", b"2")
     assert result == b"replica1"
 
-    result = await client.call_peer_handler(server_id, handler_name + '2', b'3')
+    result = await client.call_peer_handler(server_id, handler_name + "2", b"3")
     assert result == b"replica2"
 
     await server_replica1.shutdown()
@@ -425,9 +417,9 @@ async def test_handlers_on_different_replicas(handler_name="handle"):
 
     # Primary does not handle replicas protocols
     with pytest.raises(Exception):
-        await client.call_peer_handler(server_id, handler_name + '1', b'')
+        await client.call_peer_handler(server_id, handler_name + "1", b"")
     with pytest.raises(Exception):
-        await client.call_peer_handler(server_id, handler_name + '2', b'')
+        await client.call_peer_handler(server_id, handler_name + "2", b"")
 
     await server_primary.shutdown()
     await client.shutdown()

+ 64 - 40
tests/test_p2p_daemon_bindings.py

@@ -8,8 +8,14 @@ from multiaddr import Multiaddr, protocols
 
 from hivemind.p2p.p2p_daemon_bindings.control import ControlClient, DaemonConnector, parse_conn_protocol
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
-from hivemind.p2p.p2p_daemon_bindings.utils import (ControlFailure, raise_if_failed, read_pbmsg_safe,
-                                                    read_unsigned_varint, write_pbmsg, write_unsigned_varint)
+from hivemind.p2p.p2p_daemon_bindings.utils import (
+    ControlFailure,
+    raise_if_failed,
+    read_pbmsg_safe,
+    read_unsigned_varint,
+    write_pbmsg,
+    write_unsigned_varint,
+)
 from hivemind.proto import p2pd_pb2 as p2pd_pb
 from test_utils.p2p_daemon import make_p2pd_pair_ip4, connect_safe
 
@@ -198,9 +204,7 @@ def test_client_ctor_default_control_maddr():
 
 @pytest.mark.parametrize("listen_maddr_str", ("/unix/123", "/ip4/127.0.0.1/tcp/6666"))
 def test_control_client_ctor_listen_maddr(listen_maddr_str):
-    c = ControlClient(
-        daemon_connector=DaemonConnector(), listen_maddr=Multiaddr(listen_maddr_str)
-    )
+    c = ControlClient(daemon_connector=DaemonConnector(), listen_maddr=Multiaddr(listen_maddr_str))
     assert c.listen_maddr == Multiaddr(listen_maddr_str)
 
 
@@ -215,27 +219,39 @@ def test_control_client_ctor_default_listen_maddr():
         p2pd_pb.Response(
             type=p2pd_pb.Response.Type.OK,
             identify=p2pd_pb.IdentifyResponse(
-                id=PeerID.from_base58('QmT7WhTne9zBLfAgAJt9aiZ8jZ5BxJGowRubxsHYmnyzUd').to_bytes(),
-                addrs=[Multiaddr('/p2p-circuit').to_bytes(), Multiaddr('/ip4/127.0.0.1/tcp/51126').to_bytes(),
-                       Multiaddr('/ip4/192.168.10.135/tcp/51126').to_bytes(),
-                       Multiaddr('/ip6/::1/tcp/51127').to_bytes()]
-            )).SerializeToString(),
+                id=PeerID.from_base58("QmT7WhTne9zBLfAgAJt9aiZ8jZ5BxJGowRubxsHYmnyzUd").to_bytes(),
+                addrs=[
+                    Multiaddr("/p2p-circuit").to_bytes(),
+                    Multiaddr("/ip4/127.0.0.1/tcp/51126").to_bytes(),
+                    Multiaddr("/ip4/192.168.10.135/tcp/51126").to_bytes(),
+                    Multiaddr("/ip6/::1/tcp/51127").to_bytes(),
+                ],
+            ),
+        ).SerializeToString(),
         p2pd_pb.Response(
             type=p2pd_pb.Response.Type.OK,
             identify=p2pd_pb.IdentifyResponse(
-                id=PeerID.from_base58('QmcQFt2MFfCZ9AxzUCNrk4k7TtMdZZvAAteaA6tHpBKdrk').to_bytes(),
-                addrs=[Multiaddr('/p2p-circuit').to_bytes(), Multiaddr('/ip4/127.0.0.1/tcp/51493').to_bytes(),
-                       Multiaddr('/ip4/192.168.10.135/tcp/51493').to_bytes(),
-                       Multiaddr('/ip6/::1/tcp/51494').to_bytes()]
-            )).SerializeToString(),
+                id=PeerID.from_base58("QmcQFt2MFfCZ9AxzUCNrk4k7TtMdZZvAAteaA6tHpBKdrk").to_bytes(),
+                addrs=[
+                    Multiaddr("/p2p-circuit").to_bytes(),
+                    Multiaddr("/ip4/127.0.0.1/tcp/51493").to_bytes(),
+                    Multiaddr("/ip4/192.168.10.135/tcp/51493").to_bytes(),
+                    Multiaddr("/ip6/::1/tcp/51494").to_bytes(),
+                ],
+            ),
+        ).SerializeToString(),
         p2pd_pb.Response(
             type=p2pd_pb.Response.Type.OK,
             identify=p2pd_pb.IdentifyResponse(
-                id=PeerID.from_base58('QmbWqVVoz7v9LS9ZUQAhyyfdFJY3iU8ZrUY3XQozoTA5cc').to_bytes(),
-                addrs=[Multiaddr('/p2p-circuit').to_bytes(), Multiaddr('/ip4/127.0.0.1/tcp/51552').to_bytes(),
-                       Multiaddr('/ip4/192.168.10.135/tcp/51552').to_bytes(),
-                       Multiaddr('/ip6/::1/tcp/51553').to_bytes()]
-            )).SerializeToString(),
+                id=PeerID.from_base58("QmbWqVVoz7v9LS9ZUQAhyyfdFJY3iU8ZrUY3XQozoTA5cc").to_bytes(),
+                addrs=[
+                    Multiaddr("/p2p-circuit").to_bytes(),
+                    Multiaddr("/ip4/127.0.0.1/tcp/51552").to_bytes(),
+                    Multiaddr("/ip4/192.168.10.135/tcp/51552").to_bytes(),
+                    Multiaddr("/ip6/::1/tcp/51553").to_bytes(),
+                ],
+            ),
+        ).SerializeToString(),
     ),
     # give test cases ids to prevent bytes from ruining the terminal
     ids=("pb example Response 0", "pb example Response 1", "pb example Response 2"),
@@ -262,37 +278,47 @@ async def test_read_pbmsg_safe_valid(msg_bytes):
                 dht=p2pd_pb.DHTResponse(
                     type=p2pd_pb.DHTResponse.Type.VALUE,
                     peer=p2pd_pb.PeerInfo(
-                        id=PeerID.from_base58('QmNaXUy78W9moQ9APCoKaTtPjLcEJPN9hRBCqErY7o2fQs').to_bytes(),
-                        addrs=[Multiaddr('/p2p-circuit').to_bytes(), Multiaddr('/ip4/127.0.0.1/tcp/56929').to_bytes(),
-                               Multiaddr('/ip4/192.168.10.135/tcp/56929').to_bytes(),
-                               Multiaddr('/ip6/::1/tcp/56930').to_bytes()]
-                    )
-                )
+                        id=PeerID.from_base58("QmNaXUy78W9moQ9APCoKaTtPjLcEJPN9hRBCqErY7o2fQs").to_bytes(),
+                        addrs=[
+                            Multiaddr("/p2p-circuit").to_bytes(),
+                            Multiaddr("/ip4/127.0.0.1/tcp/56929").to_bytes(),
+                            Multiaddr("/ip4/192.168.10.135/tcp/56929").to_bytes(),
+                            Multiaddr("/ip6/::1/tcp/56930").to_bytes(),
+                        ],
+                    ),
+                ),
             ),
         ),
         (p2pd_pb.Request, p2pd_pb.Request(type=p2pd_pb.Request.Type.LIST_PEERS)),
         (
             p2pd_pb.DHTRequest,
-            p2pd_pb.DHTRequest(type=p2pd_pb.DHTRequest.Type.FIND_PEER,
-                               peer=PeerID.from_base58('QmcgHMuEhqdLHDVeNjiCGU7Ds6E7xK3f4amgiwHNPKKn7R').to_bytes()),
+            p2pd_pb.DHTRequest(
+                type=p2pd_pb.DHTRequest.Type.FIND_PEER,
+                peer=PeerID.from_base58("QmcgHMuEhqdLHDVeNjiCGU7Ds6E7xK3f4amgiwHNPKKn7R").to_bytes(),
+            ),
         ),
         (
             p2pd_pb.DHTResponse,
             p2pd_pb.DHTResponse(
                 type=p2pd_pb.DHTResponse.Type.VALUE,
                 peer=p2pd_pb.PeerInfo(
-                    id=PeerID.from_base58('QmWP32GhEyXVQsLXFvV81eadDC8zQRZxZvJK359rXxLquk').to_bytes(),
-                    addrs=[Multiaddr('/p2p-circuit').to_bytes(), Multiaddr('/ip4/127.0.0.1/tcp/56897').to_bytes(),
-                           Multiaddr('/ip4/192.168.10.135/tcp/56897').to_bytes(),
-                           Multiaddr('/ip6/::1/tcp/56898').to_bytes()]
-                )
+                    id=PeerID.from_base58("QmWP32GhEyXVQsLXFvV81eadDC8zQRZxZvJK359rXxLquk").to_bytes(),
+                    addrs=[
+                        Multiaddr("/p2p-circuit").to_bytes(),
+                        Multiaddr("/ip4/127.0.0.1/tcp/56897").to_bytes(),
+                        Multiaddr("/ip4/192.168.10.135/tcp/56897").to_bytes(),
+                        Multiaddr("/ip6/::1/tcp/56898").to_bytes(),
+                    ],
+                ),
             ),
         ),
         (
             p2pd_pb.StreamInfo,
-            p2pd_pb.StreamInfo(peer=PeerID.from_base58('QmewLxB46MftfxQiunRgJo2W8nW4Lh5NLEkRohkHhJ4wW6').to_bytes(),
-                               addr=Multiaddr('/ip4/127.0.0.1/tcp/57029').to_bytes(),
-                               proto=b'protocol123'),
+            p2pd_pb.StreamInfo(
+                peer=PeerID.from_base58("QmewLxB46MftfxQiunRgJo2W8nW4Lh5NLEkRohkHhJ4wW6").to_bytes(),
+                addr=Multiaddr("/ip4/127.0.0.1/tcp/57029").to_bytes(),
+                proto=b"protocol123",
+            ),
         ),
     ),
     ids=(
@@ -305,7 +331,7 @@ async def test_read_pbmsg_safe_valid(msg_bytes):
 )
 @pytest.mark.asyncio
 async def test_write_pbmsg(pb_type, pb_msg):
-    msg_bytes = bytes(chr(pb_msg.ByteSize()), 'utf-8') + pb_msg.SerializeToString()
+    msg_bytes = bytes(chr(pb_msg.ByteSize()), "utf-8") + pb_msg.SerializeToString()
     pb_obj = pb_type()
 
     s_read = MockReaderWriter(msg_bytes)
@@ -441,9 +467,7 @@ async def test_client_stream_open_success(p2pcs):
     writer.close()
 
     # test case: open with multiple protocols
-    stream_info, reader, writer = await p2pcs[0].stream_open(
-        peer_id_1, (proto, "another_protocol")
-    )
+    stream_info, reader, writer = await p2pcs[0].stream_open(peer_id_1, (proto, "another_protocol"))
     assert stream_info.peer_id == peer_id_1
     assert stream_info.addr in maddrs_1
     assert stream_info.proto == "123"

+ 21 - 26
tests/test_routing.py

@@ -22,10 +22,7 @@ def test_ids_depth():
         ids = [random.randint(0, 4096) for i in range(random.randint(1, 256))]
         ours = DHTID.longest_common_prefix_length(*map(DHTID, ids))
 
-        ids_bitstr = [
-            "".join(bin(bite)[2:].rjust(8, '0') for bite in uid.to_bytes(20, 'big'))
-            for uid in ids
-        ]
+        ids_bitstr = ["".join(bin(bite)[2:].rjust(8, "0") for bite in uid.to_bytes(20, "big")) for uid in ids]
         reference = len(shared_prefix(*ids_bitstr))
         assert reference == ours, f"ours {ours} != reference {reference}, ids: {ids}"
 
@@ -37,11 +34,11 @@ def test_routing_table_basic():
 
     for phony_neighbor_port in random.sample(range(10000), 100):
         phony_id = DHTID.generate()
-        routing_table.add_or_update_node(phony_id, f'{LOCALHOST}:{phony_neighbor_port}')
+        routing_table.add_or_update_node(phony_id, f"{LOCALHOST}:{phony_neighbor_port}")
         assert phony_id in routing_table
-        assert f'{LOCALHOST}:{phony_neighbor_port}' in routing_table
-        assert routing_table[phony_id] == f'{LOCALHOST}:{phony_neighbor_port}'
-        assert routing_table[f'{LOCALHOST}:{phony_neighbor_port}'] == phony_id
+        assert f"{LOCALHOST}:{phony_neighbor_port}" in routing_table
+        assert routing_table[phony_id] == f"{LOCALHOST}:{phony_neighbor_port}"
+        assert routing_table[f"{LOCALHOST}:{phony_neighbor_port}"] == phony_id
         added_nodes.append(phony_id)
 
     assert routing_table.buckets[0].lower == DHTID.MIN and routing_table.buckets[-1].upper == DHTID.MAX
@@ -66,40 +63,37 @@ def test_routing_table_basic():
 
 def test_routing_table_parameters():
     for (bucket_size, modulo, min_nbuckets, max_nbuckets) in [
-        (20,          5,      45,           65),
-        (50,          5,      35,           45),
-        (20,          10,     650,          800),
-        (20,          1,      7,            15),
+        (20, 5, 45, 65),
+        (50, 5, 35, 45),
+        (20, 10, 650, 800),
+        (20, 1, 7, 15),
     ]:
         node_id = DHTID.generate()
         routing_table = RoutingTable(node_id, bucket_size=bucket_size, depth_modulo=modulo)
         for phony_neighbor_port in random.sample(range(1_000_000), 10_000):
-            routing_table.add_or_update_node(DHTID.generate(), f'{LOCALHOST}:{phony_neighbor_port}')
+            routing_table.add_or_update_node(DHTID.generate(), f"{LOCALHOST}:{phony_neighbor_port}")
         for bucket in routing_table.buckets:
             assert len(bucket.replacement_nodes) == 0 or len(bucket.nodes_to_peer_id) <= bucket.size
-        assert min_nbuckets <= len(routing_table.buckets) <= max_nbuckets, (
-            f"Unexpected number of buckets: {min_nbuckets} <= {len(routing_table.buckets)} <= {max_nbuckets}")
+        assert (
+            min_nbuckets <= len(routing_table.buckets) <= max_nbuckets
+        ), f"Unexpected number of buckets: {min_nbuckets} <= {len(routing_table.buckets)} <= {max_nbuckets}"
 
 
 def test_routing_table_search():
-    for table_size, lower_active, upper_active in [
-        (10, 10, 10), (10_000, 800, 1100)
-    ]:
+    for table_size, lower_active, upper_active in [(10, 10, 10), (10_000, 800, 1100)]:
         node_id = DHTID.generate()
         routing_table = RoutingTable(node_id, bucket_size=20, depth_modulo=5)
         num_added = 0
         total_nodes = 0
 
         for phony_neighbor_port in random.sample(range(1_000_000), table_size):
-            routing_table.add_or_update_node(DHTID.generate(), f'{LOCALHOST}:{phony_neighbor_port}')
+            routing_table.add_or_update_node(DHTID.generate(), f"{LOCALHOST}:{phony_neighbor_port}")
             new_total = sum(len(bucket.nodes_to_peer_id) for bucket in routing_table.buckets)
             num_added += new_total > total_nodes
             total_nodes = new_total
         num_replacements = sum(len(bucket.replacement_nodes) for bucket in routing_table.buckets)
 
-        all_active_neighbors = list(chain(
-            *(bucket.nodes_to_peer_id.keys() for bucket in routing_table.buckets)
-        ))
+        all_active_neighbors = list(chain(*(bucket.nodes_to_peer_id.keys() for bucket in routing_table.buckets)))
         assert lower_active <= len(all_active_neighbors) <= upper_active
         assert len(all_active_neighbors) == num_added
         assert num_added + num_replacements == table_size
@@ -112,8 +106,7 @@ def test_routing_table_search():
             our_knn, our_peer_ids = zip(*routing_table.get_nearest_neighbors(query_id, k=k, exclude=exclude))
             reference_knn = heapq.nsmallest(k, all_active_neighbors, key=query_id.xor_distance)
             assert all(our == ref for our, ref in zip_longest(our_knn, reference_knn))
-            assert all(our_peer_id == routing_table[our_node]
-                       for our_node, our_peer_id in zip(our_knn, our_peer_ids))
+            assert all(our_peer_id == routing_table[our_node] for our_node, our_peer_id in zip(our_knn, our_peer_ids))
 
         # queries from table
         for i in range(1000):
@@ -125,8 +118,10 @@ def test_routing_table_search():
             if query_id in reference_knn:
                 reference_knn.remove(query_id)
             assert len(our_knn) == len(reference_knn)
-            assert all(query_id.xor_distance(our) == query_id.xor_distance(ref)
-                       for our, ref in zip_longest(our_knn, reference_knn))
+            assert all(
+                query_id.xor_distance(our) == query_id.xor_distance(ref)
+                for our, ref in zip_longest(our_knn, reference_knn)
+            )
             assert routing_table.get_nearest_neighbors(query_id, k=k, exclude=None)[0][0] == query_id
 
 

+ 61 - 25
tests/test_training.py

@@ -16,13 +16,15 @@ from hivemind.optim import DecentralizedSGD, DecentralizedAdam
 @pytest.mark.forked
 def test_training(max_steps: int = 100, threshold: float = 0.9):
     dataset = load_digits(n_class=2)
-    X_train, y_train = torch.tensor(dataset['data'], dtype=torch.float), torch.tensor(dataset['target'])
+    X_train, y_train = torch.tensor(dataset["data"], dtype=torch.float), torch.tensor(dataset["target"])
     SGD = partial(torch.optim.SGD, lr=0.05)
 
-    with background_server(num_experts=2, device='cpu', optim_cls=SGD, hidden_dim=64, num_handlers=1,
-                           no_dht=True) as (server_endpoint, _):
-        expert1 = RemoteExpert('expert.0', server_endpoint)
-        expert2 = RemoteExpert('expert.1', server_endpoint)
+    with background_server(num_experts=2, device="cpu", optim_cls=SGD, hidden_dim=64, num_handlers=1, no_dht=True) as (
+        server_endpoint,
+        _,
+    ):
+        expert1 = RemoteExpert("expert.0", server_endpoint)
+        expert2 = RemoteExpert("expert.1", server_endpoint)
         model = nn.Sequential(expert2, nn.ReLU(), expert1, nn.Linear(64, 2))
 
         opt = SGD(model.parameters(), lr=0.05)
@@ -44,15 +46,16 @@ def test_training(max_steps: int = 100, threshold: float = 0.9):
 @pytest.mark.forked
 def test_moe_training(max_steps: int = 100, threshold: float = 0.9, num_experts=2):
     dataset = load_digits(n_class=2)
-    X_train, y_train = torch.tensor(dataset['data'], dtype=torch.float), torch.tensor(dataset['target'])
+    X_train, y_train = torch.tensor(dataset["data"], dtype=torch.float), torch.tensor(dataset["target"])
     SGD = partial(torch.optim.SGD, lr=0.05)
 
-    all_expert_uids = [f'expert.{i}' for i in range(num_experts)]
-    with background_server(expert_uids=all_expert_uids, device='cpu', optim_cls=SGD, hidden_dim=64, num_handlers=1) \
-            as (server_endpoint, dht_maddrs):
+    all_expert_uids = [f"expert.{i}" for i in range(num_experts)]
+    with background_server(
+        expert_uids=all_expert_uids, device="cpu", optim_cls=SGD, hidden_dim=64, num_handlers=1
+    ) as (server_endpoint, dht_maddrs):
         dht = DHT(start=True, initial_peers=dht_maddrs)
 
-        moe = RemoteMixtureOfExperts(in_features=64, grid_size=(num_experts,), dht=dht, uid_prefix='expert.', k_best=2)
+        moe = RemoteMixtureOfExperts(in_features=64, grid_size=(num_experts,), dht=dht, uid_prefix="expert.", k_best=2)
         model = nn.Sequential(moe, nn.Linear(64, 2))
 
         opt = SGD(model.parameters(), lr=0.05)
@@ -74,9 +77,15 @@ def test_moe_training(max_steps: int = 100, threshold: float = 0.9, num_experts=
 class SwitchNetwork(nn.Module):
     def __init__(self, dht, in_features, num_classes, num_experts):
         super().__init__()
-        self.moe = RemoteSwitchMixtureOfExperts(in_features=in_features, grid_size=(num_experts,), dht=dht,
-                                                jitter_eps=0, uid_prefix='expert.', k_best=1,
-                                                k_min=1)
+        self.moe = RemoteSwitchMixtureOfExperts(
+            in_features=in_features,
+            grid_size=(num_experts,),
+            dht=dht,
+            jitter_eps=0,
+            uid_prefix="expert.",
+            k_best=1,
+            k_min=1,
+        )
         self.linear = nn.Linear(in_features, num_classes)
 
     def forward(self, x):
@@ -87,12 +96,13 @@ class SwitchNetwork(nn.Module):
 @pytest.mark.forked
 def test_switch_training(max_steps: int = 10, threshold: float = 0.9, num_experts=5):
     dataset = load_digits(n_class=2)
-    X_train, y_train = torch.tensor(dataset['data'], dtype=torch.float), torch.tensor(dataset['target'])
+    X_train, y_train = torch.tensor(dataset["data"], dtype=torch.float), torch.tensor(dataset["target"])
     SGD = partial(torch.optim.SGD, lr=0.05)
 
-    all_expert_uids = [f'expert.{i}' for i in range(num_experts)]
-    with background_server(expert_uids=all_expert_uids, device='cpu', optim_cls=SGD, hidden_dim=64,
-                           num_handlers=1) as (server_endpoint, dht_maddrs):
+    all_expert_uids = [f"expert.{i}" for i in range(num_experts)]
+    with background_server(
+        expert_uids=all_expert_uids, device="cpu", optim_cls=SGD, hidden_dim=64, num_handlers=1
+    ) as (server_endpoint, dht_maddrs):
         dht = DHT(start=True, initial_peers=dht_maddrs)
 
         model = SwitchNetwork(dht, 64, 2, num_experts)
@@ -119,12 +129,24 @@ def test_decentralized_optimizer_step():
     initial_peers = dht_root.get_visible_maddrs()
 
     param1 = torch.nn.Parameter(torch.zeros(32, 32), requires_grad=True)
-    opt1 = DecentralizedSGD([param1], lr=0.1, dht=DHT(initial_peers=initial_peers, start=True),
-                            prefix='foo', target_group_size=2, verbose=True)
+    opt1 = DecentralizedSGD(
+        [param1],
+        lr=0.1,
+        dht=DHT(initial_peers=initial_peers, start=True),
+        prefix="foo",
+        target_group_size=2,
+        verbose=True,
+    )
 
     param2 = torch.nn.Parameter(torch.ones(32, 32), requires_grad=True)
-    opt2 = DecentralizedSGD([param2], lr=0.05, dht=DHT(initial_peers=initial_peers, start=True),
-                            prefix='foo', target_group_size=2, verbose=True)
+    opt2 = DecentralizedSGD(
+        [param2],
+        lr=0.05,
+        dht=DHT(initial_peers=initial_peers, start=True),
+        prefix="foo",
+        target_group_size=2,
+        verbose=True,
+    )
 
     assert not torch.allclose(param1, param2)
 
@@ -145,12 +167,26 @@ def test_decentralized_optimizer_averaging():
     initial_peers = dht_root.get_visible_maddrs()
 
     param1 = torch.nn.Parameter(torch.zeros(32, 32), requires_grad=True)
-    opt1 = DecentralizedAdam([param1], lr=0.1, averaging_steps_period=1, dht=DHT(initial_peers=initial_peers, start=True),
-                            prefix='foo', target_group_size=2, verbose=True)
+    opt1 = DecentralizedAdam(
+        [param1],
+        lr=0.1,
+        averaging_steps_period=1,
+        dht=DHT(initial_peers=initial_peers, start=True),
+        prefix="foo",
+        target_group_size=2,
+        verbose=True,
+    )
 
     param2 = torch.nn.Parameter(torch.ones(32, 32), requires_grad=True)
-    opt2 = DecentralizedAdam([param2], lr=0.05, averaging_steps_period=1, dht=DHT(initial_peers=initial_peers, start=True),
-                            prefix='foo', target_group_size=2, verbose=True)
+    opt2 = DecentralizedAdam(
+        [param2],
+        lr=0.05,
+        averaging_steps_period=1,
+        dht=DHT(initial_peers=initial_peers, start=True),
+        prefix="foo",
+        target_group_size=2,
+        verbose=True,
+    )
 
     assert not torch.allclose(param1, param2)
 

+ 54 - 47
tests/test_util_modules.py

@@ -12,7 +12,7 @@ import hivemind
 from hivemind.proto.dht_pb2_grpc import DHTStub
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
-from hivemind.utils import MSGPackSerializer
+from hivemind.utils import MSGPackSerializer, ValueWithExpiration, HeapEntry, DHTExpiration
 from hivemind.utils.asyncio import amap_in_executor, aiter, aenumerate, achain, anext, azip
 from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
 from hivemind.utils.mpfuture import InvalidStateError
@@ -41,8 +41,8 @@ def test_mpfuture_result():
     with pytest.raises(concurrent.futures.TimeoutError):
         future.result(timeout=1e-3)
 
-    future.set_result(['abacaba', 123])
-    assert future.result() == ['abacaba', 123]
+    future.set_result(["abacaba", 123])
+    assert future.result() == ["abacaba", 123]
 
 
 @pytest.mark.forked
@@ -135,12 +135,12 @@ async def test_await_mpfuture():
     async def wait_and_assign_async():
         assert f2.set_running_or_notify_cancel() is True
         await asyncio.sleep(0.1)
-        f1.set_result((123, 'ololo'))
-        f2.set_result((456, 'pyshpysh'))
+        f1.set_result((123, "ololo"))
+        f2.set_result((456, "pyshpysh"))
 
     asyncio.create_task(wait_and_assign_async())
 
-    assert (await asyncio.gather(f1, f2)) == [(123, 'ololo'), (456, 'pyshpysh')]
+    assert (await asyncio.gather(f1, f2)) == [(123, "ololo"), (456, "pyshpysh")]
 
     # await result from separate processes
     f1, f2 = hivemind.MPFuture(), hivemind.MPFuture()
@@ -149,12 +149,12 @@ async def test_await_mpfuture():
         time.sleep(0.1 * random.random())
         future.set_result(value)
 
-    p1 = mp.Process(target=wait_and_assign, args=(f1, 'abc'))
-    p2 = mp.Process(target=wait_and_assign, args=(f2, 'def'))
+    p1 = mp.Process(target=wait_and_assign, args=(f1, "abc"))
+    p2 = mp.Process(target=wait_and_assign, args=(f2, "def"))
     for p in p1, p2:
         p.start()
 
-    assert (await asyncio.gather(f1, f2)) == ['abc', 'def']
+    assert (await asyncio.gather(f1, f2)) == ["abc", "def"]
     for p in p1, p2:
         p.join()
 
@@ -183,7 +183,7 @@ async def test_await_mpfuture():
         time.sleep(0.01)
         f2.set_result(123456)
         time.sleep(0.1)
-        f1.set_exception(ValueError('we messed up'))
+        f1.set_exception(ValueError("we messed up"))
 
     p = mp.Process(target=wait_and_raise)
     p.start()
@@ -202,9 +202,9 @@ def test_mpfuture_bidirectional():
 
     def _future_creator():
         future_from_fork = hivemind.MPFuture()
-        future_from_main.set_result(('abc', future_from_fork))
+        future_from_main.set_result(("abc", future_from_fork))
 
-        if future_from_fork.result() == ['we', 'need', 'to', 'go', 'deeper']:
+        if future_from_fork.result() == ["we", "need", "to", "go", "deeper"]:
             evt.set()
 
     p = mp.Process(target=_future_creator)
@@ -212,7 +212,7 @@ def test_mpfuture_bidirectional():
 
     out = future_from_main.result()
     assert isinstance(out[1], hivemind.MPFuture)
-    out[1].set_result(['we', 'need', 'to', 'go', 'deeper'])
+    out[1].set_result(["we", "need", "to", "go", "deeper"])
 
     p.join()
     assert evt.is_set()
@@ -240,7 +240,9 @@ def test_mpfuture_done_callback():
         future2.cancel()  # trigger future2 callback from the same process
 
         events[0].wait()
-        future1.add_done_callback(lambda future: events[4].set())  # schedule callback after future1 is already finished
+        future1.add_done_callback(
+            lambda future: events[4].set()
+        )  # schedule callback after future1 is already finished
 
     p = mp.Process(target=_future_creator)
     p.start()
@@ -331,31 +333,38 @@ async def test_channel_cache():
     hivemind.ChannelCache.MAXIMUM_CHANNELS = 3
     hivemind.ChannelCache.EVICTION_PERIOD_SECONDS = 0.1
 
-    c1 = hivemind.ChannelCache.get_stub('localhost:1337', DHTStub, aio=False)
-    c2 = hivemind.ChannelCache.get_stub('localhost:1337', DHTStub, aio=True)
-    c3 = hivemind.ChannelCache.get_stub('localhost:1338', DHTStub, aio=False)
-    c3_again = hivemind.ChannelCache.get_stub('localhost:1338', DHTStub, aio=False)
-    c1_again = hivemind.ChannelCache.get_stub('localhost:1337', DHTStub, aio=False)
-    c4 = hivemind.ChannelCache.get_stub('localhost:1339', DHTStub, aio=True)
-    c2_anew = hivemind.ChannelCache.get_stub('localhost:1337', DHTStub, aio=True)
-    c1_yetagain = hivemind.ChannelCache.get_stub('localhost:1337', DHTStub, aio=False)
+    c1 = hivemind.ChannelCache.get_stub("localhost:1337", DHTStub, aio=False)
+    c2 = hivemind.ChannelCache.get_stub("localhost:1337", DHTStub, aio=True)
+    c3 = hivemind.ChannelCache.get_stub("localhost:1338", DHTStub, aio=False)
+    c3_again = hivemind.ChannelCache.get_stub("localhost:1338", DHTStub, aio=False)
+    c1_again = hivemind.ChannelCache.get_stub("localhost:1337", DHTStub, aio=False)
+    c4 = hivemind.ChannelCache.get_stub("localhost:1339", DHTStub, aio=True)
+    c2_anew = hivemind.ChannelCache.get_stub("localhost:1337", DHTStub, aio=True)
+    c1_yetagain = hivemind.ChannelCache.get_stub("localhost:1337", DHTStub, aio=False)
 
     await asyncio.sleep(0.2)
-    c1_anew = hivemind.ChannelCache.get_stub(target='localhost:1337', aio=False, stub_type=DHTStub)
-    c1_anew_again = hivemind.ChannelCache.get_stub(target='localhost:1337', aio=False, stub_type=DHTStub)
-    c1_otherstub = hivemind.ChannelCache.get_stub(target='localhost:1337', aio=False, stub_type=ConnectionHandlerStub)
+    c1_anew = hivemind.ChannelCache.get_stub(target="localhost:1337", aio=False, stub_type=DHTStub)
+    c1_anew_again = hivemind.ChannelCache.get_stub(target="localhost:1337", aio=False, stub_type=DHTStub)
+    c1_otherstub = hivemind.ChannelCache.get_stub(target="localhost:1337", aio=False, stub_type=ConnectionHandlerStub)
     await asyncio.sleep(0.05)
-    c1_otherstub_again = hivemind.ChannelCache.get_stub(target='localhost:1337', aio=False,
-                                                        stub_type=ConnectionHandlerStub)
+    c1_otherstub_again = hivemind.ChannelCache.get_stub(
+        target="localhost:1337", aio=False, stub_type=ConnectionHandlerStub
+    )
     all_channels = [c1, c2, c3, c4, c3_again, c1_again, c2_anew, c1_yetagain, c1_anew, c1_anew_again, c1_otherstub]
 
     assert all(isinstance(c, DHTStub) for c in all_channels[:-1])
     assert isinstance(all_channels[-1], ConnectionHandlerStub)
-    assert 'aio' in repr(c2.rpc_find)
-    assert 'aio' not in repr(c1.rpc_find)
-
-    duplicates = {(c1, c1_again), (c1, c1_yetagain), (c1_again, c1_yetagain), (c3, c3_again),
-                  (c1_anew, c1_anew_again), (c1_otherstub, c1_otherstub_again)}
+    assert "aio" in repr(c2.rpc_find)
+    assert "aio" not in repr(c1.rpc_find)
+
+    duplicates = {
+        (c1, c1_again),
+        (c1, c1_yetagain),
+        (c1_again, c1_yetagain),
+        (c3, c3_again),
+        (c1_anew, c1_anew_again),
+        (c1_otherstub, c1_otherstub_again),
+    }
     for i in range(len(all_channels)):
         for j in range(i + 1, len(all_channels)):
             ci, cj = all_channels[i], all_channels[j]
@@ -386,7 +395,7 @@ def test_serialize_tensor():
     restored = hivemind.combine_from_streaming(chunks)
     assert torch.allclose(deserialize_torch_tensor(restored), tensor)
 
-    scalar = torch.tensor(1.)
+    scalar = torch.tensor(1.0)
     serialized_scalar = serialize_torch_tensor(scalar, CompressionType.NONE)
     assert torch.allclose(deserialize_torch_tensor(serialized_scalar), scalar)
 
@@ -397,9 +406,9 @@ def test_serialize_tensor():
 def test_serialize_tuple():
     test_pairs = (
         ((1, 2, 3), [1, 2, 3]),
-        (('1', False, 0), ['1', False, 0]),
-        (('1', False, 0), ('1', 0, 0)),
-        (('1', b'qq', (2, 5, '0')), ['1', b'qq', (2, 5, '0')]),
+        (("1", False, 0), ["1", False, 0]),
+        (("1", False, 0), ("1", 0, 0)),
+        (("1", b"qq", (2, 5, "0")), ["1", b"qq", (2, 5, "0")]),
     )
 
     for first, second in test_pairs:
@@ -443,8 +452,6 @@ def test_split_parts():
 
 
 def test_generic_data_classes():
-    from hivemind.utils import ValueWithExpiration, HeapEntry, DHTExpiration
-
     value_with_exp = ValueWithExpiration(value="string_value", expiration_time=DHTExpiration(10))
     assert value_with_exp.value == "string_value" and value_with_exp.expiration_time == DHTExpiration(10)
 
@@ -458,7 +465,7 @@ def test_generic_data_classes():
 
 @pytest.mark.asyncio
 async def test_asyncio_utils():
-    res = [i async for i, item in aenumerate(aiter('a', 'b', 'c'))]
+    res = [i async for i, item in aenumerate(aiter("a", "b", "c"))]
     assert res == list(range(len(res)))
 
     num_steps = 0
@@ -471,20 +478,20 @@ async def test_asyncio_utils():
     ref = list(map(max, range(7), range(-50, 50, 10)))
     assert ours == ref
 
-    ours = [row async for row in azip(aiter('a', 'b', 'c'), aiter(1, 2, 3))]
-    ref = list(zip(['a', 'b', 'c'], [1, 2, 3]))
+    ours = [row async for row in azip(aiter("a", "b", "c"), aiter(1, 2, 3))]
+    ref = list(zip(["a", "b", "c"], [1, 2, 3]))
     assert ours == ref
 
     async def _aiterate():
-        yield 'foo'
-        yield 'bar'
-        yield 'baz'
+        yield "foo"
+        yield "bar"
+        yield "baz"
 
     iterator = _aiterate()
-    assert (await anext(iterator)) == 'foo'
+    assert (await anext(iterator)) == "foo"
     tail = [item async for item in iterator]
-    assert tail == ['bar', 'baz']
+    assert tail == ["bar", "baz"]
     with pytest.raises(StopAsyncIteration):
         await anext(iterator)
 
-    assert [item async for item in achain(_aiterate(), aiter(*range(5)))] == ['foo', 'bar', 'baz'] + list(range(5))
+    assert [item async for item in achain(_aiterate(), aiter(*range(5)))] == ["foo", "bar", "baz"] + list(range(5))

+ 7 - 6
tests/test_utils/custom_networks.py

@@ -7,7 +7,7 @@ from hivemind.moe import register_expert_class
 sample_input = lambda batch_size, hidden_dim: torch.empty((batch_size, hidden_dim))
 
 
-@register_expert_class('perceptron', sample_input)
+@register_expert_class("perceptron", sample_input)
 class MultilayerPerceptron(nn.Module):
     def __init__(self, hidden_dim, num_classes=10):
         super().__init__()
@@ -22,13 +22,14 @@ class MultilayerPerceptron(nn.Module):
         return x
 
 
-multihead_sample_input = lambda batch_size, hidden_dim: \
-    (torch.empty((batch_size, hidden_dim)),
-     torch.empty((batch_size, 2 * hidden_dim)),
-     torch.empty((batch_size, 3 * hidden_dim)),)
+multihead_sample_input = lambda batch_size, hidden_dim: (
+    torch.empty((batch_size, hidden_dim)),
+    torch.empty((batch_size, 2 * hidden_dim)),
+    torch.empty((batch_size, 3 * hidden_dim)),
+)
 
 
-@register_expert_class('multihead', multihead_sample_input)
+@register_expert_class("multihead", multihead_sample_input)
 class MultiheadNetwork(nn.Module):
     def __init__(self, hidden_dim, num_classes=10):
         super().__init__()

+ 7 - 6
tests/test_utils/dht_swarms.py

@@ -30,10 +30,12 @@ def run_node(initial_peers: List[Multiaddr], info_queue: mp.Queue):
     loop.run_forever()
 
 
-def launch_swarm_in_separate_processes(n_peers: int, n_sequential_peers: int) -> \
-        Tuple[List[mp.Process], Dict[PeerID, DHTID], List[List[Multiaddr]]]:
-    assert n_sequential_peers < n_peers, \
-        'Parameters imply that first n_sequential_peers of n_peers will be run sequentially'
+def launch_swarm_in_separate_processes(
+    n_peers: int, n_sequential_peers: int
+) -> Tuple[List[mp.Process], Dict[PeerID, DHTID], List[List[Multiaddr]]]:
+    assert (
+        n_sequential_peers < n_peers
+    ), "Parameters imply that first n_sequential_peers of n_peers will be run sequentially"
 
     processes = []
     dht = {}
@@ -82,6 +84,5 @@ def launch_swarm_in_separate_processes(n_peers: int, n_sequential_peers: int) ->
 async def launch_star_shaped_swarm(n_peers: int, **kwargs) -> List[DHTNode]:
     nodes = [await DHTNode.create(**kwargs)]
     initial_peers = await nodes[0].get_visible_maddrs()
-    nodes += await asyncio.gather(*[DHTNode.create(initial_peers=initial_peers, **kwargs)
-                                    for _ in range(n_peers - 1)])
+    nodes += await asyncio.gather(*[DHTNode.create(initial_peers=initial_peers, **kwargs) for _ in range(n_peers - 1)])
     return nodes

+ 22 - 32
tests/test_utils/p2p_daemon.py

@@ -41,9 +41,7 @@ class Daemon:
     f_log = None
     closed = None
 
-    def __init__(
-            self, control_maddr, enable_control, enable_connmgr, enable_dht, enable_pubsub
-    ):
+    def __init__(self, control_maddr, enable_control, enable_connmgr, enable_dht, enable_pubsub):
         self.control_maddr = control_maddr
         self.enable_control = enable_control
         self.enable_connmgr = enable_connmgr
@@ -67,9 +65,7 @@ class Daemon:
             cmd_list += ["-dht=true"]
         if self.enable_pubsub:
             cmd_list += ["-pubsub=true", "-pubsubRouter=gossipsub"]
-        self.proc_daemon = subprocess.Popen(
-            cmd_list, stdout=self.f_log, stderr=self.f_log, bufsize=0
-        )
+        self.proc_daemon = subprocess.Popen(cmd_list, stdout=self.f_log, stderr=self.f_log, bufsize=0)
 
     async def wait_until_ready(self):
         lines_head_pattern = (b"Control socket:", b"Peer ID:", b"Peer Addrs:")
@@ -108,9 +104,7 @@ class ConnectionFailure(Exception):
 
 
 @asynccontextmanager
-async def make_p2pd_pair_unix(
-        enable_control, enable_connmgr, enable_dht, enable_pubsub
-):
+async def make_p2pd_pair_unix(enable_control, enable_connmgr, enable_dht, enable_pubsub):
     name = str(uuid.uuid4())[:8]
     control_maddr = Multiaddr(f"/unix/tmp/test_p2pd_control_{name}.sock")
     listen_maddr = Multiaddr(f"/unix/tmp/test_p2pd_listen_{name}.sock")
@@ -124,12 +118,12 @@ async def make_p2pd_pair_unix(
     except FileNotFoundError:
         pass
     async with _make_p2pd_pair(
-            control_maddr=control_maddr,
-            listen_maddr=listen_maddr,
-            enable_control=enable_control,
-            enable_connmgr=enable_connmgr,
-            enable_dht=enable_dht,
-            enable_pubsub=enable_pubsub,
+        control_maddr=control_maddr,
+        listen_maddr=listen_maddr,
+        enable_control=enable_control,
+        enable_connmgr=enable_connmgr,
+        enable_dht=enable_dht,
+        enable_pubsub=enable_pubsub,
     ) as pair:
         yield pair
 
@@ -139,24 +133,24 @@ async def make_p2pd_pair_ip4(enable_control, enable_connmgr, enable_dht, enable_
     control_maddr = Multiaddr(f"/ip4/127.0.0.1/tcp/{find_open_port()}")
     listen_maddr = Multiaddr(f"/ip4/127.0.0.1/tcp/{find_open_port()}")
     async with _make_p2pd_pair(
-            control_maddr=control_maddr,
-            listen_maddr=listen_maddr,
-            enable_control=enable_control,
-            enable_connmgr=enable_connmgr,
-            enable_dht=enable_dht,
-            enable_pubsub=enable_pubsub,
+        control_maddr=control_maddr,
+        listen_maddr=listen_maddr,
+        enable_control=enable_control,
+        enable_connmgr=enable_connmgr,
+        enable_dht=enable_dht,
+        enable_pubsub=enable_pubsub,
     ) as pair:
         yield pair
 
 
 @asynccontextmanager
 async def _make_p2pd_pair(
-        control_maddr,
-        listen_maddr,
-        enable_control,
-        enable_connmgr,
-        enable_dht,
-        enable_pubsub,
+    control_maddr,
+    listen_maddr,
+    enable_control,
+    enable_connmgr,
+    enable_dht,
+    enable_pubsub,
 ):
     p2pd = Daemon(
         control_maddr=control_maddr,
@@ -187,8 +181,4 @@ async def _check_connection(p2pd_tuple_0, p2pd_tuple_1):
 async def connect_safe(p2pd_tuple_0, p2pd_tuple_1):
     peer_id_1, maddrs_1 = await p2pd_tuple_1.identify()
     await p2pd_tuple_0.connect(peer_id_1, maddrs_1)
-    await try_until_success(
-        functools.partial(
-            _check_connection, p2pd_tuple_0=p2pd_tuple_0, p2pd_tuple_1=p2pd_tuple_1
-        )
-    )
+    await try_until_success(functools.partial(_check_connection, p2pd_tuple_0=p2pd_tuple_0, p2pd_tuple_1=p2pd_tuple_1))