justheuristic 4 years ago
Parent
Commit
0aca1d9f8b

+ 3 - 1
.github/workflows/run-tests.yml

@@ -10,7 +10,9 @@ jobs:
     strategy:
       matrix:
         python-version: [ 3.7, 3.8, 3.9 ]
-    timeout-minutes: 10
+        attempt: [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 ]
+      fail-fast: false
+    timeout-minutes: 15
     steps:
       - uses: actions/checkout@v2
       - name: Set up Python
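
Note on this hunk: the new "attempt" matrix dimension makes each Python version run ten times, and fail-fast: false keeps the remaining matrix jobs running when one attempt fails, presumably to surface flaky tests; the job timeout is also raised from 10 to 15 minutes. As a rough sketch, the resulting strategy block would read as follows (reconstructed from this hunk only; the surrounding job definition is not shown in the diff):

    strategy:
      matrix:
        python-version: [ 3.7, 3.8, 3.9 ]
        attempt: [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 ]
      fail-fast: false
    timeout-minutes: 15
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python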

+ 0 - 65
tests/test_custom_experts.py

@@ -1,65 +0,0 @@
-import os
-
-import pytest
-import torch
-
-from hivemind import RemoteExpert
-from hivemind.moe.server import background_server
-
-CUSTOM_EXPERTS_PATH = os.path.join(os.path.dirname(__file__), "test_utils", "custom_networks.py")
-
-
-@pytest.mark.forked
-def test_custom_expert(hid_dim=16):
-    with background_server(
-        expert_cls="perceptron",
-        num_experts=2,
-        device="cpu",
-        hidden_dim=hid_dim,
-        num_handlers=2,
-        no_dht=True,
-        custom_module_path=CUSTOM_EXPERTS_PATH,
-    ) as (server_endpoint, _):
-        expert0 = RemoteExpert("expert.0", server_endpoint)
-        expert1 = RemoteExpert("expert.1", server_endpoint)
-
-        for batch_size in (1, 4):
-            batch = torch.randn(batch_size, hid_dim)
-
-            output0 = expert0(batch)
-            output1 = expert1(batch)
-
-            loss = output0.sum()
-            loss.backward()
-            loss = output1.sum()
-            loss.backward()
-
-
-@pytest.mark.forked
-def test_multihead_expert(hid_dim=16):
-    with background_server(
-        expert_cls="multihead",
-        num_experts=2,
-        device="cpu",
-        hidden_dim=hid_dim,
-        num_handlers=2,
-        no_dht=True,
-        custom_module_path=CUSTOM_EXPERTS_PATH,
-    ) as (server_endpoint, _):
-        expert0 = RemoteExpert("expert.0", server_endpoint)
-        expert1 = RemoteExpert("expert.1", server_endpoint)
-
-        for batch_size in (1, 4):
-            batch = (
-                torch.randn(batch_size, hid_dim),
-                torch.randn(batch_size, 2 * hid_dim),
-                torch.randn(batch_size, 3 * hid_dim),
-            )
-
-            output0 = expert0(*batch)
-            output1 = expert1(*batch)
-
-            loss = output0.sum()
-            loss.backward()
-            loss = output1.sum()
-            loss.backward()

+ 0 - 107
tests/test_dht.py

@@ -1,107 +0,0 @@
-import asyncio
-import random
-import time
-
-import pytest
-from multiaddr import Multiaddr
-
-import hivemind
-from test_utils.dht_swarms import launch_dht_instances
-
-
-@pytest.mark.forked
-def test_get_store(n_peers=10):
-    peers = launch_dht_instances(n_peers)
-
-    node1, node2 = random.sample(peers, 2)
-    assert node1.store("key1", "value1", expiration_time=hivemind.get_dht_time() + 30)
-    assert node1.get("key1").value == "value1"
-    assert node2.get("key1").value == "value1"
-    assert node2.get("key2") is None
-
-    future = node1.get("foo", return_future=True)
-    assert future.result() is None
-
-    future = node1.get("foo", return_future=True)
-    future.cancel()
-
-    assert node2.store("key1", 123, expiration_time=hivemind.get_dht_time() + 31)
-    assert node2.store("key2", 456, expiration_time=hivemind.get_dht_time() + 32)
-    assert node1.get("key1", latest=True).value == 123
-    assert node1.get("key2").value == 456
-
-    assert node1.store("key2", subkey="subkey1", value=789, expiration_time=hivemind.get_dht_time() + 32)
-    assert node2.store("key2", subkey="subkey2", value="pew", expiration_time=hivemind.get_dht_time() + 32)
-    found_dict = node1.get("key2", latest=True).value
-    assert isinstance(found_dict, dict) and len(found_dict) == 2
-    assert found_dict["subkey1"].value == 789 and found_dict["subkey2"].value == "pew"
-
-    for peer in peers:
-        peer.shutdown()
-
-
-async def dummy_dht_coro(self, node):
-    return "pew"
-
-
-async def dummy_dht_coro_error(self, node):
-    raise ValueError("Oops, i did it again...")
-
-
-async def dummy_dht_coro_stateful(self, node):
-    self._x_dummy = getattr(self, "_x_dummy", 123) + 1
-    return self._x_dummy
-
-
-async def dummy_dht_coro_long(self, node):
-    await asyncio.sleep(0.25)
-    return self._x_dummy ** 2
-
-
-async def dummy_dht_coro_for_cancel(self, node):
-    self._x_dummy = -100
-    await asyncio.sleep(0.5)
-    self._x_dummy = 999
-
-
-@pytest.mark.forked
-def test_run_coroutine():
-    dht = hivemind.DHT(start=True)
-    assert dht.run_coroutine(dummy_dht_coro) == "pew"
-
-    with pytest.raises(ValueError):
-        res = dht.run_coroutine(dummy_dht_coro_error)
-
-    bg_task = dht.run_coroutine(dummy_dht_coro_long, return_future=True)
-    assert dht.run_coroutine(dummy_dht_coro_stateful) == 124
-    assert dht.run_coroutine(dummy_dht_coro_stateful) == 125
-    assert dht.run_coroutine(dummy_dht_coro_stateful) == 126
-    assert not hasattr(dht, "_x_dummy")
-    assert bg_task.result() == 126 ** 2
-
-    future = dht.run_coroutine(dummy_dht_coro_for_cancel, return_future=True)
-    time.sleep(0.25)
-    future.cancel()
-    assert dht.run_coroutine(dummy_dht_coro_stateful) == -99
-
-    dht.shutdown()
-
-
-@pytest.mark.forked
-@pytest.mark.asyncio
-async def test_dht_get_visible_maddrs():
-    # test 1: IPv4 localhost multiaddr is visible by default
-
-    dht = hivemind.DHT(start=True)
-
-    assert any(str(maddr).startswith("/ip4/127.0.0.1") for maddr in dht.get_visible_maddrs())
-    dht.shutdown()
-
-    # test 2: announce_maddrs are the single visible multiaddrs if defined
-
-    dummy_endpoint = Multiaddr("/ip4/123.45.67.89/tcp/31337")
-    p2p = await hivemind.p2p.P2P.create(announce_maddrs=[dummy_endpoint])
-    dht = hivemind.DHT(start=True, p2p=await p2p.replicate(p2p.daemon_listen_maddr))
-
-    assert dht.get_visible_maddrs() == [dummy_endpoint.encapsulate(f"/p2p/{p2p.id}")]
-    dht.shutdown()

+ 0 - 136
tests/test_dht_crypto.py

@@ -1,136 +0,0 @@
-import dataclasses
-import pickle
-import multiprocessing as mp
-
-import pytest
-
-import hivemind
-from hivemind.utils.timed_storage import get_dht_time
-from hivemind.dht.crypto import RSASignatureValidator
-from hivemind.dht.node import DHTNode
-from hivemind.dht.validation import DHTRecord
-from hivemind.utils.crypto import RSAPrivateKey
-
-
-def test_rsa_signature_validator():
-    receiver_validator = RSASignatureValidator()
-    sender_validator = RSASignatureValidator(RSAPrivateKey())
-    mallory_validator = RSASignatureValidator(RSAPrivateKey())
-
-    plain_record = DHTRecord(key=b"key", subkey=b"subkey", value=b"value", expiration_time=get_dht_time() + 10)
-    protected_records = [
-        dataclasses.replace(plain_record, key=plain_record.key + sender_validator.local_public_key),
-        dataclasses.replace(plain_record, subkey=plain_record.subkey + sender_validator.local_public_key),
-    ]
-
-    # test 1: Non-protected record (no signature added)
-    assert sender_validator.sign_value(plain_record) == plain_record.value
-    assert receiver_validator.validate(plain_record)
-
-    # test 2: Correct signatures
-    signed_records = [
-        dataclasses.replace(record, value=sender_validator.sign_value(record)) for record in protected_records
-    ]
-    for record in signed_records:
-        assert receiver_validator.validate(record)
-        assert receiver_validator.strip_value(record) == b"value"
-
-    # test 3: Invalid signatures
-    signed_records = protected_records  # Without signature
-    signed_records += [
-        dataclasses.replace(record, value=record.value + b"[signature:INVALID_BYTES]") for record in protected_records
-    ]  # With invalid signature
-    signed_records += [
-        dataclasses.replace(record, value=mallory_validator.sign_value(record)) for record in protected_records
-    ]  # With someone else's signature
-    for record in signed_records:
-        assert not receiver_validator.validate(record)
-
-
-def test_cached_key():
-    first_validator = RSASignatureValidator()
-    second_validator = RSASignatureValidator()
-    assert first_validator.local_public_key == second_validator.local_public_key
-
-    third_validator = RSASignatureValidator(RSAPrivateKey())
-    assert first_validator.local_public_key != third_validator.local_public_key
-
-
-def test_validator_instance_is_picklable():
-    # Needs to be picklable because the validator instance may be sent between processes
-
-    original_validator = RSASignatureValidator()
-    unpickled_validator = pickle.loads(pickle.dumps(original_validator))
-
-    # To check that the private key was pickled and unpickled correctly, we sign a record
-    # with the original public key using the unpickled validator and then validate the signature
-
-    record = DHTRecord(
-        key=b"key",
-        subkey=b"subkey" + original_validator.local_public_key,
-        value=b"value",
-        expiration_time=get_dht_time() + 10,
-    )
-    signed_record = dataclasses.replace(record, value=unpickled_validator.sign_value(record))
-
-    assert b"[signature:" in signed_record.value
-    assert original_validator.validate(signed_record)
-    assert unpickled_validator.validate(signed_record)
-
-
-def get_signed_record(conn: mp.connection.Connection) -> DHTRecord:
-    validator = conn.recv()
-    record = conn.recv()
-
-    record = dataclasses.replace(record, value=validator.sign_value(record))
-
-    conn.send(record)
-    return record
-
-
-def test_signing_in_different_process():
-    parent_conn, child_conn = mp.Pipe()
-    process = mp.Process(target=get_signed_record, args=[child_conn])
-    process.start()
-
-    validator = RSASignatureValidator()
-    parent_conn.send(validator)
-
-    record = DHTRecord(
-        key=b"key", subkey=b"subkey" + validator.local_public_key, value=b"value", expiration_time=get_dht_time() + 10
-    )
-    parent_conn.send(record)
-
-    signed_record = parent_conn.recv()
-    assert b"[signature:" in signed_record.value
-    assert validator.validate(signed_record)
-
-
-@pytest.mark.forked
-@pytest.mark.asyncio
-async def test_dhtnode_signatures():
-    alice = await DHTNode.create(record_validator=RSASignatureValidator())
-    initial_peers = await alice.get_visible_maddrs()
-    bob = await DHTNode.create(record_validator=RSASignatureValidator(RSAPrivateKey()), initial_peers=initial_peers)
-    mallory = await DHTNode.create(
-        record_validator=RSASignatureValidator(RSAPrivateKey()), initial_peers=initial_peers
-    )
-
-    key = b"key"
-    subkey = b"protected_subkey" + bob.protocol.record_validator.local_public_key
-
-    assert await bob.store(key, b"true_value", hivemind.get_dht_time() + 10, subkey=subkey)
-    assert (await alice.get(key, latest=True)).value[subkey].value == b"true_value"
-
-    store_ok = await mallory.store(key, b"fake_value", hivemind.get_dht_time() + 10, subkey=subkey)
-    assert not store_ok
-    assert (await alice.get(key, latest=True)).value[subkey].value == b"true_value"
-
-    assert await bob.store(key, b"updated_true_value", hivemind.get_dht_time() + 10, subkey=subkey)
-    assert (await alice.get(key, latest=True)).value[subkey].value == b"updated_true_value"
-
-    await bob.shutdown()  # Bob has shut down, now Mallory is the single peer of Alice
-
-    store_ok = await mallory.store(key, b"updated_fake_value", hivemind.get_dht_time() + 10, subkey=subkey)
-    assert not store_ok
-    assert (await alice.get(key, latest=True)).value[subkey].value == b"updated_true_value"

+ 0 - 215
tests/test_dht_experts.py

@@ -1,215 +0,0 @@
-import asyncio
-import random
-import time
-
-import numpy as np
-import pytest
-
-import hivemind
-from hivemind.dht import DHTNode
-from hivemind import LOCALHOST
-from hivemind.moe.client.beam_search import MoEBeamSearcher
-from hivemind.moe.server import declare_experts, get_experts
-from hivemind.moe.server.expert_uid import UidEndpoint, is_valid_uid, is_valid_prefix, split_uid
-
-
-@pytest.mark.forked
-def test_store_get_experts(n_peers=10):
-    peers = [hivemind.DHT(start=True)]
-    initial_peers = peers[0].get_visible_maddrs()
-    peers += [hivemind.DHT(initial_peers=initial_peers, start=True) for _ in range(n_peers - 1)]
-
-    first_peer = random.choice(peers)
-    other_peer = random.choice(peers)
-
-    expert_uids = [f"my_expert.{i}" for i in range(50)]
-    batch_size = 10
-    for batch_start in range(0, len(expert_uids), batch_size):
-        declare_experts(first_peer, expert_uids[batch_start : batch_start + batch_size], "localhost:1234")
-
-    found = get_experts(other_peer, random.sample(expert_uids, 5) + ["foo", "bar"])
-    assert all(res is not None for res in found[:-2]), "Could not find some existing experts"
-    assert all(res is None for res in found[-2:]), "Found non-existing experts"
-
-    other_expert, other_port = "my_other_expert.1337", random.randint(1000, 9999)
-    declare_experts(other_peer, [other_expert], f"that_host:{other_port}")
-    first_notfound, first_found = get_experts(first_peer, ["foobar", other_expert])
-    assert isinstance(first_found, hivemind.RemoteExpert)
-    assert first_found.endpoint == f"that_host:{other_port}"
-
-    # test graceful shutdown
-    first_peer.shutdown()
-    other_peer.shutdown()
-    time.sleep(1.0)
-    remaining_peer1 = random.choice([peer for peer in peers if peer.is_alive()])
-    remaining_peer2 = random.choice([peer for peer in peers if peer.is_alive()])
-    assert all(declare_experts(remaining_peer1, ["new_expert.1"], "dummy"))
-    assert get_experts(remaining_peer2, ["new_expert.1"])[0].endpoint == "dummy"
-
-
-@pytest.mark.forked
-def test_beam_search(
-    n_peers=20, total_experts=128, batch_size=32, beam_size=4, parallel_rpc=4, grid_dims=(32, 32, 32)
-):
-    dht = [hivemind.DHT(start=True)]
-    initial_peers = dht[0].get_visible_maddrs()
-    dht += [hivemind.DHT(initial_peers=initial_peers, start=True) for _ in range(n_peers - 1)]
-
-    real_experts = sorted(
-        {"expert." + ".".join([str(random.randint(0, dim - 1)) for dim in grid_dims]) for _ in range(total_experts)}
-    )
-    for batch_start in range(0, len(real_experts), batch_size):
-        declare_experts(
-            random.choice(dht),
-            real_experts[batch_start : batch_start + batch_size],
-            wait=True,
-            endpoint=f"host{batch_start // batch_size}:{random.randint(0, 65536)}",
-        )
-
-    neighbors = sum([peer.get_visible_maddrs() for peer in random.sample(dht, min(3, len(dht)))], [])
-    you = hivemind.DHT(start=True, initial_peers=neighbors, parallel_rpc=parallel_rpc)
-    beam_search = MoEBeamSearcher(you, "expert.", grid_dims)
-
-    for i in range(10):
-        topk_experts = beam_search.find_best_experts([np.random.randn(dim) for dim in grid_dims], beam_size)
-        assert all(isinstance(e, hivemind.RemoteExpert) for e in topk_experts)
-        assert len(topk_experts) == beam_size
-
-    for i in range(10):
-        batch_experts = beam_search.batch_find_best_experts(
-            [np.random.randn(batch_size, dim) for dim in grid_dims], beam_size=beam_size
-        )
-        assert isinstance(batch_experts, list) and len(batch_experts) == batch_size
-        assert all(isinstance(e, hivemind.RemoteExpert) for experts in batch_experts for e in experts)
-        assert all(len(experts) == beam_size for experts in batch_experts)
-
-
-@pytest.mark.forked
-def test_dht_single_node():
-    node = hivemind.DHT(start=True)
-    beam_search = MoEBeamSearcher(node, "expert.", grid_size=(10,))
-
-    assert all(declare_experts(node, ["expert.1", "expert.2", "expert.3"], f"{hivemind.LOCALHOST}:1337").values())
-    assert len(declare_experts(node, ["ffn.1", "ffn.2"], endpoint="that_place")) == 4
-    assert len(declare_experts(node, ["e.1.2.3", "e.1.2.5", "e.2.0"], f"{hivemind.LOCALHOST}:42")) == 7
-
-    for expert in get_experts(node, ["expert.3", "expert.2"]):
-        assert expert.endpoint == f"{hivemind.LOCALHOST}:1337"
-
-    assert all(declare_experts(node, ["expert.5", "expert.2"], f"{hivemind.LOCALHOST}:1337").values())
-    found_experts = beam_search.find_best_experts([(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0)], beam_size=2)
-    assert len(found_experts) == 2 and [expert.uid for expert in found_experts] == ["expert.5", "expert.3"]
-
-    successors = beam_search.get_active_successors(["e.1.2.", "e.2.", "e.4.5."])
-    assert len(successors["e.1.2."]) == 2
-    assert successors["e.1.2."][3] == UidEndpoint("e.1.2.3", f"{LOCALHOST}:42")
-    assert successors["e.1.2."][5] == UidEndpoint("e.1.2.5", f"{LOCALHOST}:42")
-    assert len(successors["e.2."]) == 1 and successors["e.2."][0] == UidEndpoint("e.2.0", f"{LOCALHOST}:42")
-    assert successors["e.4.5."] == {}
-
-    initial_beam = beam_search.get_initial_beam((3, 2, 1, 0, -1, -2, -3), beam_size=3)
-    assert len(initial_beam) == 3
-    assert initial_beam[0][:2] == (2.0, "expert.1.")
-    assert initial_beam[1][:2] == (1.0, "expert.2.")
-    assert initial_beam[2][:2] == (0.0, "expert.3.")
-
-    with pytest.raises(AssertionError):
-        beam_search = MoEBeamSearcher(node, "expert.1.ffn", (2, 2))
-
-    with pytest.raises(AssertionError):
-        beam_search.get_active_successors(["e.1.2.", "e.2", "e.4.5."])
-
-
-def test_uid_patterns():
-    valid_experts = [
-        "expert.1",
-        "expert.0",
-        "expert.0.0.1",
-        "expert.1337",
-        "ffn.12.34.56.78.90",
-        "transformer.3.2.1.0",
-        "transformer_encoder.2",
-        "transformer::encoder.2",
-        "T®@nsf0rmE®🤗.321",
-        "🤗.321",
-        "0.1.2",
-        "00.1.2",
-        "7070.3.2.1.0",
-        "block2.1.23",
-        "LAYER.1.0.1",
-    ]
-    valid_prefixes = ["expert.", "e.1.", "e.2.", "e.1.2.3.", "ololo.123.456.789.10."]
-    valid_prefixes.extend([f"{uid}." for uid in valid_experts])
-    valid_prefixes.extend([split_uid(uid)[0] for uid in valid_experts])
-    for uid in valid_experts:
-        assert is_valid_uid(uid), f"UID {uid} is valid, but was perceived as invalid"
-    for pfx in valid_prefixes:
-        assert is_valid_prefix(pfx), f"Prefix {pfx} is valid, but was perceived as invalid"
-
-    invalid = [
-        "",
-        ".",
-        "expert.-1",
-        "xxx.a",
-        "expert.1x",
-        "expert_ffn.1.abc1",
-        "some.123.01",
-        "expert.123.01",
-        "e1",
-        "e..1",
-        "e",
-        "e.1.2.3..4",
-        "ffn.1..1",
-        ".123",
-        ".1.2.3.",
-        ".expert",
-        "transformer.encoder.2",
-        "T®@nsf0rmE®.🤗.321",
-        "layer::123",
-        "expert.0.1.2.suffix",
-        "0.1.2.suffix",
-        "expert.1 something",
-        "expert.1\n",
-        "expert.1\n2",
-        "expert.1 ",
-        "expert.1\nexpert.2",
-        "'expert.1'",
-        '"expert.1"',
-    ]
-    invalid_experts = invalid + valid_prefixes + ["0", "123456"]
-    invalid_prefixes = invalid + valid_experts + ["expert", ".🤗", ".expert"]
-    for uid in invalid_experts:
-        assert not is_valid_uid(uid), f"UID {uid} is not valid, but was perceived as valid"
-    for pfx in invalid_prefixes:
-        assert not is_valid_prefix(pfx), f"Prefix {pfx} is not valid, but was perceived as valid"
-
-
-@pytest.mark.forked
-@pytest.mark.asyncio
-async def test_negative_caching(n_peers=10):
-    dht_kwargs = {"cache_locally": False}
-
-    peers = [hivemind.DHT(start=True, **dht_kwargs)]
-    initial_peers = peers[0].get_visible_maddrs()
-    peers += [hivemind.DHT(initial_peers=initial_peers, start=True, **dht_kwargs) for _ in range(n_peers - 1)]
-
-    writer_peer = random.choice(peers)
-    assert all(declare_experts(writer_peer, ["ffn.1.2.3", "ffn.3.4.5"], "myaddr:1234").values())
-
-    neighbors = sum([peer.get_visible_maddrs() for peer in random.sample(peers, min(3, len(peers)))], [])
-    neg_caching_peer = hivemind.DHT(initial_peers=neighbors, start=True, **dht_kwargs)
-    beam_search = MoEBeamSearcher(neg_caching_peer, uid_prefix="ffn.", grid_size=(10, 10, 10), negative_caching=True)
-    # get prefixes by the peer with negative caching. Cache "no data" entries for ffn.0.*, ffn.2.*, ffn.4.*, ffn.5.*
-    assert len(beam_search.get_initial_beam(scores=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], beam_size=3)) == 2
-
-    node = await DHTNode.create(initial_peers=neighbors)
-    fetched = await asyncio.gather(*(node.get(f"ffn.{i}.") for i in range(10)))
-    for i in range(6):
-        assert fetched[i] is not None, f"node should have cached ffn.{i}."
-    for i in range(6, len(fetched)):
-        assert fetched[i] is None, f"node shouldn't have cached ffn.{i}."
-
-    await node.shutdown()
-    neg_caching_peer.shutdown()
-    for peer in peers:
-        peer.shutdown()

+ 0 - 471
tests/test_dht_node.py

@@ -1,471 +0,0 @@
-import asyncio
-import heapq
-import multiprocessing as mp
-import random
-import signal
-from itertools import product
-from typing import List, Sequence, Tuple
-
-import numpy as np
-import pytest
-from multiaddr import Multiaddr
-
-import hivemind
-from hivemind import get_dht_time
-from hivemind.dht.node import DHTID, DHTNode
-from hivemind.dht.protocol import DHTProtocol
-from hivemind.dht.storage import DictionaryDHTValue
-from hivemind.p2p import P2P, PeerID
-from hivemind.utils.logging import get_logger
-from test_utils.dht_swarms import launch_swarm_in_separate_processes, launch_star_shaped_swarm
-
-
-logger = get_logger(__name__)
-
-
-def maddrs_to_peer_ids(maddrs: List[Multiaddr]) -> List[PeerID]:
-    return list({PeerID.from_base58(maddr["p2p"]) for maddr in maddrs})
-
-
-def run_protocol_listener(
-    dhtid: DHTID, maddr_conn: mp.connection.Connection, initial_peers: Sequence[Multiaddr]
-) -> None:
-    loop = asyncio.get_event_loop()
-
-    p2p = loop.run_until_complete(P2P.create(initial_peers=initial_peers))
-    visible_maddrs = loop.run_until_complete(p2p.get_visible_maddrs())
-
-    protocol = loop.run_until_complete(
-        DHTProtocol.create(p2p, dhtid, bucket_size=20, depth_modulo=5, num_replicas=3, wait_timeout=5)
-    )
-
-    logger.info(f"Started peer id={protocol.node_id} visible_maddrs={visible_maddrs}")
-
-    for peer_id in maddrs_to_peer_ids(initial_peers):
-        loop.run_until_complete(protocol.call_ping(peer_id))
-
-    maddr_conn.send((p2p.id, visible_maddrs))
-
-    async def shutdown():
-        await p2p.shutdown()
-        logger.info(f"Finished peer id={protocol.node_id} maddrs={visible_maddrs}")
-        loop.stop()
-
-    loop.add_signal_handler(signal.SIGTERM, lambda: loop.create_task(shutdown()))
-    loop.run_forever()
-
-
-def launch_protocol_listener(
-    initial_peers: Sequence[Multiaddr] = (),
-) -> Tuple[DHTID, mp.Process, PeerID, List[Multiaddr]]:
-    remote_conn, local_conn = mp.Pipe()
-    dht_id = DHTID.generate()
-    process = mp.Process(target=run_protocol_listener, args=(dht_id, remote_conn, initial_peers), daemon=True)
-    process.start()
-    peer_id, visible_maddrs = local_conn.recv()
-
-    return dht_id, process, peer_id, visible_maddrs
-
-
-# note: we run network-related tests in a separate process to re-initialize all global states from scratch
-# this helps us avoid undesirable gRPC side-effects (e.g. segfaults) when running multiple tests in sequence
-
-
-@pytest.mark.forked
-def test_dht_protocol():
-    peer1_node_id, peer1_proc, peer1_id, peer1_maddrs = launch_protocol_listener()
-    peer2_node_id, peer2_proc, peer2_id, _ = launch_protocol_listener(initial_peers=peer1_maddrs)
-
-    loop = asyncio.get_event_loop()
-    for client_mode in [True, False]:  # note: order matters, this test assumes that first run uses client mode
-        peer_id = DHTID.generate()
-        p2p = loop.run_until_complete(P2P.create(initial_peers=peer1_maddrs))
-        protocol = loop.run_until_complete(
-            DHTProtocol.create(
-                p2p, peer_id, bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, client_mode=client_mode
-            )
-        )
-        logger.info(f"Self id={protocol.node_id}")
-
-        assert loop.run_until_complete(protocol.call_ping(peer1_id)) == peer1_node_id
-
-        key, value, expiration = DHTID.generate(), [random.random(), {"ololo": "pyshpysh"}], get_dht_time() + 1e3
-        store_ok = loop.run_until_complete(
-            protocol.call_store(peer1_id, [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
-        )
-        assert all(store_ok), "DHT rejected a trivial store"
-
-        # peer 1 must know about peer 2
-        (recv_value_bytes, recv_expiration), nodes_found = loop.run_until_complete(
-            protocol.call_find(peer1_id, [key])
-        )[key]
-        recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
-        (recv_id, recv_peer_id) = next(iter(nodes_found.items()))
-        assert (
-            recv_id == peer2_node_id and recv_peer_id == peer2_id
-        ), f"expected id={peer2_node_id}, peer={peer2_id} but got {recv_id}, {recv_peer_id}"
-
-        assert recv_value == value and recv_expiration == expiration, (
-            f"call_find_value expected {value} (expires by {expiration}) "
-            f"but got {recv_value} (expires by {recv_expiration})"
-        )
-
-        # peer 2 must know about peer 1, but not have a *random* nonexistent value
-        dummy_key = DHTID.generate()
-        empty_item, nodes_found_2 = loop.run_until_complete(protocol.call_find(peer2_id, [dummy_key]))[dummy_key]
-        assert empty_item is None, "Non-existent keys shouldn't have values"
-        (recv_id, recv_peer_id) = next(iter(nodes_found_2.items()))
-        assert (
-            recv_id == peer1_node_id and recv_peer_id == peer1_id
-        ), f"expected id={peer1_node_id}, peer={peer1_id} but got {recv_id}, {recv_peer_id}"
-
-        # cause a non-response by querying a nonexistent peer
-        assert loop.run_until_complete(protocol.call_find(PeerID.from_base58("fakeid"), [key])) is None
-
-        # store/get a dictionary with sub-keys
-        nested_key, subkey1, subkey2 = DHTID.generate(), "foo", "bar"
-        value1, value2 = [random.random(), {"ololo": "pyshpysh"}], "abacaba"
-        assert loop.run_until_complete(
-            protocol.call_store(
-                peer1_id,
-                keys=[nested_key],
-                values=[hivemind.MSGPackSerializer.dumps(value1)],
-                expiration_time=[expiration],
-                subkeys=[subkey1],
-            )
-        )
-        assert loop.run_until_complete(
-            protocol.call_store(
-                peer1_id,
-                keys=[nested_key],
-                values=[hivemind.MSGPackSerializer.dumps(value2)],
-                expiration_time=[expiration + 5],
-                subkeys=[subkey2],
-            )
-        )
-        (recv_dict, recv_expiration), nodes_found = loop.run_until_complete(
-            protocol.call_find(peer1_id, [nested_key])
-        )[nested_key]
-        assert isinstance(recv_dict, DictionaryDHTValue)
-        assert len(recv_dict.data) == 2 and recv_expiration == expiration + 5
-        assert recv_dict.data[subkey1] == (protocol.serializer.dumps(value1), expiration)
-        assert recv_dict.data[subkey2] == (protocol.serializer.dumps(value2), expiration + 5)
-
-        if not client_mode:
-            loop.run_until_complete(p2p.shutdown())
-
-    peer1_proc.terminate()
-    peer2_proc.terminate()
-
-
-@pytest.mark.forked
-def test_empty_table():
-    """Test RPC methods with empty routing table"""
-    peer_id, peer_proc, peer_peer_id, peer_maddrs = launch_protocol_listener()
-
-    loop = asyncio.get_event_loop()
-    p2p = loop.run_until_complete(P2P.create(initial_peers=peer_maddrs))
-    protocol = loop.run_until_complete(
-        DHTProtocol.create(
-            p2p, DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, client_mode=True
-        )
-    )
-
-    key, value, expiration = DHTID.generate(), [random.random(), {"ololo": "pyshpysh"}], get_dht_time() + 1e3
-
-    empty_item, nodes_found = loop.run_until_complete(protocol.call_find(peer_peer_id, [key]))[key]
-    assert empty_item is None and len(nodes_found) == 0
-    assert all(
-        loop.run_until_complete(
-            protocol.call_store(peer_peer_id, [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
-        )
-    ), "peer rejected store"
-
-    (recv_value_bytes, recv_expiration), nodes_found = loop.run_until_complete(
-        protocol.call_find(peer_peer_id, [key])
-    )[key]
-    recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
-    assert len(nodes_found) == 0
-    assert recv_value == value and recv_expiration == expiration
-
-    assert loop.run_until_complete(protocol.call_ping(peer_peer_id)) == peer_id
-    assert loop.run_until_complete(protocol.call_ping(PeerID.from_base58("fakeid"))) is None
-    peer_proc.terminate()
-
-
-@pytest.mark.forked
-def test_dht_node():
-    # step A: create a swarm of 50 dht nodes in separate processes
-    #         (first 5 created sequentially, others created in parallel)
-    processes, dht, swarm_maddrs = launch_swarm_in_separate_processes(n_peers=50, n_sequential_peers=5)
-
-    # step B: run 51-st node in this process
-    loop = asyncio.get_event_loop()
-    initial_peers = random.choice(swarm_maddrs)
-    me = loop.run_until_complete(
-        DHTNode.create(initial_peers=initial_peers, parallel_rpc=10, cache_refresh_before_expiry=False)
-    )
-
-    # test 1: find self
-    nearest = loop.run_until_complete(me.find_nearest_nodes([me.node_id], k_nearest=1))[me.node_id]
-    assert len(nearest) == 1 and nearest[me.node_id] == me.peer_id
-
-    # test 2: find others
-    for _ in range(10):
-        ref_peer_id, query_id = random.choice(list(dht.items()))
-        nearest = loop.run_until_complete(me.find_nearest_nodes([query_id], k_nearest=1))[query_id]
-        assert len(nearest) == 1
-        found_node_id, found_peer_id = next(iter(nearest.items()))
-        assert found_node_id == query_id and found_peer_id == ref_peer_id
-
-    # test 3: find neighbors to random nodes
-    accuracy_numerator = accuracy_denominator = 0  # top-1 nearest neighbor accuracy
-    jaccard_numerator = jaccard_denominator = 0  # jaccard similarity aka intersection over union
-    all_node_ids = list(dht.values())
-
-    for _ in range(10):
-        query_id = DHTID.generate()
-        k_nearest = random.randint(1, 10)
-        exclude_self = random.random() > 0.5
-        nearest = loop.run_until_complete(
-            me.find_nearest_nodes([query_id], k_nearest=k_nearest, exclude_self=exclude_self)
-        )[query_id]
-        nearest_nodes = list(nearest)  # keys from ordered dict
-
-        assert len(nearest_nodes) == k_nearest, "beam search must return exactly k_nearest results"
-        assert me.node_id not in nearest_nodes or not exclude_self, "if exclude, results shouldn't contain self"
-        assert np.all(np.diff(query_id.xor_distance(nearest_nodes)) >= 0), "results must be sorted by distance"
-
-        ref_nearest = heapq.nsmallest(k_nearest + 1, all_node_ids, key=query_id.xor_distance)
-        if exclude_self and me.node_id in ref_nearest:
-            ref_nearest.remove(me.node_id)
-        if len(ref_nearest) > k_nearest:
-            ref_nearest.pop()
-
-        accuracy_numerator += nearest_nodes[0] == ref_nearest[0]
-        accuracy_denominator += 1
-
-        jaccard_numerator += len(set.intersection(set(nearest_nodes), set(ref_nearest)))
-        jaccard_denominator += k_nearest
-
-    accuracy = accuracy_numerator / accuracy_denominator
-    logger.debug(f"Top-1 accuracy: {accuracy}")  # should be 98-100%
-    jaccard_index = jaccard_numerator / jaccard_denominator
-    logger.debug(f"Jaccard index (intersection over union): {jaccard_index}")  # should be 95-100%
-    assert accuracy >= 0.9, f"Top-1 accuracy only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
-    assert jaccard_index >= 0.9, f"Jaccard index only {jaccard_index} ({jaccard_numerator} / {jaccard_denominator})"
-
-    # test 4: find all nodes
-    dummy = DHTID.generate()
-    nearest = loop.run_until_complete(me.find_nearest_nodes([dummy], k_nearest=len(dht) + 100))[dummy]
-    assert len(nearest) == len(dht) + 1
-    assert len(set.difference(set(nearest.keys()), set(all_node_ids) | {me.node_id})) == 0
-
-    # test 5: node without peers
-    detached_node = loop.run_until_complete(DHTNode.create())
-    nearest = loop.run_until_complete(detached_node.find_nearest_nodes([dummy]))[dummy]
-    assert len(nearest) == 1 and nearest[detached_node.node_id] == detached_node.peer_id
-    nearest = loop.run_until_complete(detached_node.find_nearest_nodes([dummy], exclude_self=True))[dummy]
-    assert len(nearest) == 0
-
-    # test 6: store and get value
-    true_time = get_dht_time() + 1200
-    assert loop.run_until_complete(me.store("mykey", ["Value", 10], true_time))
-
-    initial_peers = random.choice(swarm_maddrs)
-    that_guy = loop.run_until_complete(
-        DHTNode.create(
-            initial_peers=initial_peers, parallel_rpc=10, cache_refresh_before_expiry=False, cache_locally=False
-        )
-    )
-
-    for node in [me, that_guy]:
-        val, expiration_time = loop.run_until_complete(node.get("mykey"))
-        assert val == ["Value", 10], "Wrong value"
-        assert expiration_time == true_time, f"Wrong time"
-
-    assert loop.run_until_complete(detached_node.get("mykey")) is None
-
-    # test 7: bulk store and bulk get
-    keys = "foo", "bar", "baz", "zzz"
-    values = 3, 2, "batman", [1, 2, 3]
-    store_ok = loop.run_until_complete(me.store_many(keys, values, expiration_time=get_dht_time() + 999))
-    assert all(store_ok.values()), "failed to store one or more keys"
-    response = loop.run_until_complete(me.get_many(keys[::-1]))
-    for key, value in zip(keys, values):
-        assert key in response and response[key][0] == value
-
-    # test 8: store dictionaries as values (with sub-keys)
-    upper_key, subkey1, subkey2, subkey3 = "ololo", "k1", "k2", "k3"
-    now = get_dht_time()
-    assert loop.run_until_complete(me.store(upper_key, subkey=subkey1, value=123, expiration_time=now + 10))
-    assert loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=456, expiration_time=now + 20))
-    for node in [that_guy, me]:
-        value, time = loop.run_until_complete(node.get(upper_key))
-        assert isinstance(value, dict) and time == now + 20
-        assert value[subkey1] == (123, now + 10)
-        assert value[subkey2] == (456, now + 20)
-        assert len(value) == 2
-
-    assert not loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=345, expiration_time=now + 10))
-    assert loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=567, expiration_time=now + 30))
-    assert loop.run_until_complete(me.store(upper_key, subkey=subkey3, value=890, expiration_time=now + 50))
-    loop.run_until_complete(asyncio.sleep(0.1))  # wait for cache to refresh
-
-    for node in [that_guy, me]:
-        value, time = loop.run_until_complete(node.get(upper_key))
-        assert isinstance(value, dict) and time == now + 50, (value, time)
-        assert value[subkey1] == (123, now + 10)
-        assert value[subkey2] == (567, now + 30)
-        assert value[subkey3] == (890, now + 50)
-        assert len(value) == 3
-
-    for proc in processes:
-        proc.terminate()
-    # The nodes don't own their hivemind.p2p.P2P instances, so we shutdown them separately
-    loop.run_until_complete(asyncio.wait([node.shutdown() for node in [me, detached_node, that_guy]]))
-
-
-@pytest.mark.forked
-@pytest.mark.asyncio
-async def test_dhtnode_replicas():
-    num_replicas = random.randint(1, 20)
-    peers = await launch_star_shaped_swarm(n_peers=20, num_replicas=num_replicas)
-
-    you = random.choice(peers)
-    assert await you.store("key1", "foo", get_dht_time() + 999)
-
-    actual_key1_replicas = sum(len(peer.protocol.storage) for peer in peers)
-    assert num_replicas == actual_key1_replicas
-
-    assert await you.store("key2", "bar", get_dht_time() + 999)
-    total_size = sum(len(peer.protocol.storage) for peer in peers)
-    actual_key2_replicas = total_size - actual_key1_replicas
-    assert num_replicas == actual_key2_replicas
-
-    assert await you.store("key2", "baz", get_dht_time() + 1000)
-    assert sum(len(peer.protocol.storage) for peer in peers) == total_size, "total size should not have changed"
-
-
-@pytest.mark.forked
-@pytest.mark.asyncio
-async def test_dhtnode_caching(T=0.05):
-    node2 = await DHTNode.create(cache_refresh_before_expiry=5 * T, reuse_get_requests=False)
-    node1 = await DHTNode.create(
-        initial_peers=await node2.protocol.p2p.get_visible_maddrs(),
-        cache_refresh_before_expiry=5 * T,
-        client_mode=True,
-        reuse_get_requests=False,
-    )
-    await node2.store("k", [123, "value"], expiration_time=hivemind.get_dht_time() + 7 * T)
-    await node2.store("k2", [654, "value"], expiration_time=hivemind.get_dht_time() + 7 * T)
-    await node2.store("k3", [654, "value"], expiration_time=hivemind.get_dht_time() + 15 * T)
-    await node1.get_many(["k", "k2", "k3", "k4"])
-    assert len(node1.protocol.cache) == 3
-    assert len(node1.cache_refresh_queue) == 0
-
-    await node1.get_many(["k", "k2", "k3", "k4"])
-    assert len(node1.cache_refresh_queue) == 3
-
-    await node2.store("k", [123, "value"], expiration_time=hivemind.get_dht_time() + 12 * T)
-    await asyncio.sleep(4 * T)
-    await node1.get("k")
-    await asyncio.sleep(1 * T)
-
-    assert len(node1.protocol.cache) == 3
-    assert len(node1.cache_refresh_queue) == 2
-    await asyncio.sleep(3 * T)
-
-    assert len(node1.cache_refresh_queue) == 1
-
-    await asyncio.sleep(5 * T)
-    assert len(node1.cache_refresh_queue) == 0
-    await asyncio.sleep(5 * T)
-    assert len(node1.cache_refresh_queue) == 0
-
-    await node2.store("k", [123, "value"], expiration_time=hivemind.get_dht_time() + 10 * T)
-    await node1.get("k")
-    await asyncio.sleep(1 * T)
-    assert len(node1.cache_refresh_queue) == 0
-    await node1.get("k")
-    await asyncio.sleep(1 * T)
-    assert len(node1.cache_refresh_queue) == 1
-
-    await asyncio.sleep(5 * T)
-    assert len(node1.cache_refresh_queue) == 0
-
-    await asyncio.gather(node1.shutdown(), node2.shutdown())
-
-
-@pytest.mark.forked
-@pytest.mark.asyncio
-async def test_dhtnode_reuse_get():
-    peers = await launch_star_shaped_swarm(n_peers=10, parallel_rpc=256)
-
-    await asyncio.gather(
-        random.choice(peers).store("k1", 123, hivemind.get_dht_time() + 999),
-        random.choice(peers).store("k2", 567, hivemind.get_dht_time() + 999),
-    )
-
-    you = random.choice(peers)
-
-    futures1 = await you.get_many(["k1", "k2"], return_futures=True)
-    assert len(you.pending_get_requests[DHTID.generate("k1")]) == 1
-    assert len(you.pending_get_requests[DHTID.generate("k2")]) == 1
-
-    futures2 = await you.get_many(["k2", "k3"], return_futures=True)
-    assert len(you.pending_get_requests[DHTID.generate("k2")]) == 2
-
-    await asyncio.gather(*futures1.values(), *futures2.values())
-    futures3 = await you.get_many(["k3"], return_futures=True)
-    assert len(you.pending_get_requests[DHTID.generate("k1")]) == 0
-    assert len(you.pending_get_requests[DHTID.generate("k2")]) == 0
-    assert len(you.pending_get_requests[DHTID.generate("k3")]) == 1
-
-    assert (await futures1["k1"])[0] == 123
-    assert await futures1["k2"] == await futures2["k2"] and (await futures1["k2"])[0] == 567
-    assert await futures2["k3"] == await futures3["k3"] and (await futures3["k3"]) is None
-
-
-@pytest.mark.forked
-@pytest.mark.asyncio
-async def test_dhtnode_blacklist():
-    node1, node2, node3, node4 = await launch_star_shaped_swarm(n_peers=4, blacklist_time=999)
-
-    assert await node2.store("abc", 123, expiration_time=hivemind.get_dht_time() + 99)
-    assert len(node2.blacklist.ban_counter) == 0
-
-    await asyncio.gather(node3.shutdown(), node4.shutdown())
-
-    assert await node2.store("def", 456, expiration_time=hivemind.get_dht_time() + 99)
-
-    assert set(node2.blacklist.ban_counter.keys()) == {node3.peer_id, node4.peer_id}
-
-    assert await node1.get("abc", latest=True)  # force node1 to crawl dht and discover unresponsive peers
-    assert node3.peer_id in node1.blacklist
-
-    assert await node1.get("abc", latest=True)  # force node1 to crawl dht and discover unresponsive peers
-    assert node2.peer_id not in node1.blacklist
-
-    await asyncio.gather(node1.shutdown(), node2.shutdown())
-
-
-@pytest.mark.forked
-@pytest.mark.asyncio
-async def test_dhtnode_edge_cases():
-    peers = await launch_star_shaped_swarm(n_peers=4, parallel_rpc=4)
-
-    subkeys = [0, "", False, True, "abyrvalg", 4555]
-    keys = subkeys + [()]
-    values = subkeys + [[]]
-    for key, subkey, value in product(keys, subkeys, values):
-        await random.choice(peers).store(
-            key=key, subkey=subkey, value=value, expiration_time=hivemind.get_dht_time() + 999
-        ),
-
-        stored = await random.choice(peers).get(key=key, latest=True)
-        assert stored is not None
-        assert subkey in stored.value
-        assert stored.value[subkey].value == value
-
-    await asyncio.wait([node.shutdown() for node in peers])

+ 0 - 198
tests/test_dht_schema.py

@@ -1,198 +0,0 @@
-import asyncio
-from typing import Dict
-
-import pytest
-from pydantic import BaseModel, StrictInt, conint
-
-import hivemind
-from hivemind.dht.node import DHTNode
-from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
-from hivemind.dht.validation import DHTRecord, RecordValidatorBase
-from hivemind.utils.timed_storage import get_dht_time
-
-
-class SampleSchema(BaseModel):
-    experiment_name: bytes
-    n_batches: Dict[bytes, conint(ge=0, strict=True)]
-    signed_data: Dict[BytesWithPublicKey, bytes]
-
-
-@pytest.fixture
-async def dht_nodes_with_schema():
-    validator = SchemaValidator(SampleSchema)
-
-    alice = await DHTNode.create(record_validator=validator)
-    bob = await DHTNode.create(record_validator=validator, initial_peers=await alice.get_visible_maddrs())
-    yield alice, bob
-
-    await asyncio.gather(alice.shutdown(), bob.shutdown())
-
-
-@pytest.mark.forked
-@pytest.mark.asyncio
-async def test_expecting_regular_value(dht_nodes_with_schema):
-    alice, bob = dht_nodes_with_schema
-
-    # Regular value (bytes) expected
-    assert await bob.store("experiment_name", b"foo_bar", get_dht_time() + 10)
-    assert not await bob.store("experiment_name", 666, get_dht_time() + 10)
-    assert not await bob.store("experiment_name", b"foo_bar", get_dht_time() + 10, subkey=b"subkey")
-
-    # Refuse records despite https://pydantic-docs.helpmanual.io/usage/models/#data-conversion
-    assert not await bob.store("experiment_name", [], get_dht_time() + 10)
-    assert not await bob.store("experiment_name", [1, 2, 3], get_dht_time() + 10)
-
-    for peer in [alice, bob]:
-        assert (await peer.get("experiment_name", latest=True)).value == b"foo_bar"
-
-
-@pytest.mark.forked
-@pytest.mark.asyncio
-async def test_expecting_dictionary(dht_nodes_with_schema):
-    alice, bob = dht_nodes_with_schema
-
-    # Dictionary (bytes -> non-negative int) expected
-    assert await bob.store("n_batches", 777, get_dht_time() + 10, subkey=b"uid1")
-    assert await bob.store("n_batches", 778, get_dht_time() + 10, subkey=b"uid2")
-    assert not await bob.store("n_batches", -666, get_dht_time() + 10, subkey=b"uid3")
-    assert not await bob.store("n_batches", 666, get_dht_time() + 10)
-    assert not await bob.store("n_batches", b"not_integer", get_dht_time() + 10, subkey=b"uid1")
-    assert not await bob.store("n_batches", 666, get_dht_time() + 10, subkey=666)
-
-    # Refuse storing a plain dictionary bypassing the DictionaryDHTValue convention
-    assert not await bob.store("n_batches", {b"uid3": 779}, get_dht_time() + 10)
-
-    # Refuse records despite https://pydantic-docs.helpmanual.io/usage/models/#data-conversion
-    assert not await bob.store("n_batches", 779.5, get_dht_time() + 10, subkey=b"uid3")
-    assert not await bob.store("n_batches", 779.0, get_dht_time() + 10, subkey=b"uid3")
-    assert not await bob.store("n_batches", [], get_dht_time() + 10)
-    assert not await bob.store("n_batches", [(b"uid3", 779)], get_dht_time() + 10)
-
-    # Refuse records despite https://github.com/samuelcolvin/pydantic/issues/1268
-    assert not await bob.store("n_batches", "", get_dht_time() + 10)
-
-    for peer in [alice, bob]:
-        dictionary = (await peer.get("n_batches", latest=True)).value
-        assert len(dictionary) == 2 and dictionary[b"uid1"].value == 777 and dictionary[b"uid2"].value == 778
-
-
-@pytest.mark.forked
-@pytest.mark.asyncio
-async def test_expecting_public_keys(dht_nodes_with_schema):
-    alice, bob = dht_nodes_with_schema
-
-    # Subkeys expected to contain a public key
-    # (so hivemind.dht.crypto.RSASignatureValidator would require a signature)
-    assert await bob.store("signed_data", b"foo_bar", get_dht_time() + 10, subkey=b"uid[owner:public-key]")
-    assert not await bob.store("signed_data", b"foo_bar", get_dht_time() + 10, subkey=b"uid-without-public-key")
-
-    for peer in [alice, bob]:
-        dictionary = (await peer.get("signed_data", latest=True)).value
-        assert len(dictionary) == 1 and dictionary[b"uid[owner:public-key]"].value == b"foo_bar"
-
-
-@pytest.mark.forked
-@pytest.mark.asyncio
-async def test_keys_outside_schema(dht_nodes_with_schema):
-    class Schema(BaseModel):
-        some_field: StrictInt
-
-    class MergedSchema(BaseModel):
-        another_field: StrictInt
-
-    for allow_extra_keys in [False, True]:
-        validator = SchemaValidator(Schema, allow_extra_keys=allow_extra_keys)
-        assert validator.merge_with(SchemaValidator(MergedSchema, allow_extra_keys=False))
-
-        alice = await DHTNode.create(record_validator=validator)
-        bob = await DHTNode.create(record_validator=validator, initial_peers=await alice.get_visible_maddrs())
-
-        store_ok = await bob.store("unknown_key", b"foo_bar", get_dht_time() + 10)
-        assert store_ok == allow_extra_keys
-
-        for peer in [alice, bob]:
-            result = await peer.get("unknown_key", latest=True)
-            if allow_extra_keys:
-                assert result.value == b"foo_bar"
-            else:
-                assert result is None
-
-
-@pytest.mark.forked
-@pytest.mark.asyncio
-async def test_prefix():
-    class Schema(BaseModel):
-        field: StrictInt
-
-    validator = SchemaValidator(Schema, allow_extra_keys=False, prefix="prefix")
-
-    alice = await DHTNode.create(record_validator=validator)
-    bob = await DHTNode.create(record_validator=validator, initial_peers=await alice.get_visible_maddrs())
-
-    assert await bob.store("prefix_field", 777, get_dht_time() + 10)
-    assert not await bob.store("prefix_field", "string_value", get_dht_time() + 10)
-    assert not await bob.store("field", 777, get_dht_time() + 10)
-
-    for peer in [alice, bob]:
-        assert (await peer.get("prefix_field", latest=True)).value == 777
-        assert (await peer.get("field", latest=True)) is None
-
-    await asyncio.gather(alice.shutdown(), bob.shutdown())
-
-
-@pytest.mark.forked
-@pytest.mark.asyncio
-async def test_merging_schema_validators(dht_nodes_with_schema):
-    alice, bob = dht_nodes_with_schema
-
-    class TrivialValidator(RecordValidatorBase):
-        def validate(self, record: DHTRecord) -> bool:
-            return True
-
-    second_validator = TrivialValidator()
-    # Can't merge with the validator of the different type
-    assert not alice.protocol.record_validator.merge_with(second_validator)
-
-    class SecondSchema(BaseModel):
-        some_field: StrictInt
-        another_field: str
-
-    class ThirdSchema(BaseModel):
-        another_field: StrictInt  # Allow it to be a StrictInt as well
-
-    for schema in [SecondSchema, ThirdSchema]:
-        new_validator = SchemaValidator(schema, allow_extra_keys=False)
-        for peer in [alice, bob]:
-            assert peer.protocol.record_validator.merge_with(new_validator)
-
-    assert await bob.store("experiment_name", b"foo_bar", get_dht_time() + 10)
-    assert await bob.store("some_field", 777, get_dht_time() + 10)
-    assert not await bob.store("some_field", "string_value", get_dht_time() + 10)
-    assert await bob.store("another_field", 42, get_dht_time() + 10)
-    assert await bob.store("another_field", "string_value", get_dht_time() + 10)
-
-    # Unknown keys are allowed since the first schema is created with allow_extra_keys=True
-    assert await bob.store("unknown_key", 999, get_dht_time() + 10)
-
-    for peer in [alice, bob]:
-        assert (await peer.get("experiment_name", latest=True)).value == b"foo_bar"
-        assert (await peer.get("some_field", latest=True)).value == 777
-        assert (await peer.get("another_field", latest=True)).value == "string_value"
-
-        assert (await peer.get("unknown_key", latest=True)).value == 999
-
-
-@pytest.mark.forked
-def test_sending_validator_instance_between_processes():
-    alice = hivemind.DHT(start=True)
-    bob = hivemind.DHT(start=True, initial_peers=alice.get_visible_maddrs())
-
-    alice.add_validators([SchemaValidator(SampleSchema)])
-    bob.add_validators([SchemaValidator(SampleSchema)])
-
-    assert bob.store("experiment_name", b"foo_bar", get_dht_time() + 10)
-    assert not bob.store("experiment_name", 777, get_dht_time() + 10)
-    assert alice.get("experiment_name", latest=True).value == b"foo_bar"
-
-    alice.shutdown()
-    bob.shutdown()

+ 0 - 130
tests/test_dht_storage.py

@@ -1,130 +0,0 @@
-import time
-
-from hivemind.utils.timed_storage import get_dht_time
-from hivemind.dht.storage import DHTLocalStorage, DHTID, DictionaryDHTValue
-from hivemind.utils.serializer import MSGPackSerializer
-
-
-def test_store():
-    d = DHTLocalStorage()
-    d.store(DHTID.generate("key"), b"val", get_dht_time() + 0.5)
-    assert d.get(DHTID.generate("key"))[0] == b"val", "Wrong value"
-
-
-def test_get_expired():
-    d = DHTLocalStorage()
-    d.store(DHTID.generate("key"), b"val", get_dht_time() + 0.1)
-    time.sleep(0.5)
-    assert d.get(DHTID.generate("key")) is None, "Expired value must be deleted"
-
-
-def test_get_empty():
-    d = DHTLocalStorage()
-    assert d.get(DHTID.generate(source="key")) is None, "DHTLocalStorage returned non-existent value"
-
-
-def test_change_expiration_time():
-    d = DHTLocalStorage()
-    d.store(DHTID.generate("key"), b"val1", get_dht_time() + 1)
-    assert d.get(DHTID.generate("key"))[0] == b"val1", "Wrong value"
-    d.store(DHTID.generate("key"), b"val2", get_dht_time() + 200)
-    time.sleep(1)
-    assert d.get(DHTID.generate("key"))[0] == b"val2", "Value must be changed, but still kept in table"
-
-
-def test_maxsize_cache():
-    d = DHTLocalStorage(maxsize=2)
-    d.store(DHTID.generate("key1a"), b"val1a", get_dht_time() + 1)
-    d.store(DHTID.generate("key1b"), b"val1b", get_dht_time() + 1)
-    d.store(DHTID.generate("key1a"), b"val1a2", get_dht_time() + 2)
-    d.store(DHTID.generate("key2"), b"val2", get_dht_time() + 200)
-    assert d.get(DHTID.generate("key2"))[0] == b"val2", "Value with bigger exp. time must be kept"
-    assert d.get(DHTID.generate("key1a"))[0] == b"val1a2", "Value with bigger exp. time must be kept"
-    assert d.get(DHTID.generate("key1b")) is None, "Value with less exp time, must be deleted"
-
-
-def test_localstorage_top():
-    d = DHTLocalStorage(maxsize=3)
-    d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 1)
-    d.store(DHTID.generate("key2"), b"val2", get_dht_time() + 2)
-    d.store(DHTID.generate("key3"), b"val3", get_dht_time() + 4)
-    assert d.top()[0] == DHTID.generate("key1") and d.top()[1].value == b"val1"
-
-    d.store(DHTID.generate("key1"), b"val1_new", get_dht_time() + 3)
-    assert d.top()[0] == DHTID.generate("key2") and d.top()[1].value == b"val2"
-
-    del d[DHTID.generate("key2")]
-    assert d.top()[0] == DHTID.generate("key1") and d.top()[1].value == b"val1_new"
-    d.store(DHTID.generate("key2"), b"val2_new", get_dht_time() + 5)
-    d.store(DHTID.generate("key4"), b"val4", get_dht_time() + 6)  # key4 will push out key1 due to maxsize
-
-    assert d.top()[0] == DHTID.generate("key3") and d.top()[1].value == b"val3"
-
-
-def test_localstorage_nested():
-    time = get_dht_time()
-    d1 = DHTLocalStorage()
-    d2 = DictionaryDHTValue()
-    d2.store("subkey1", b"value1", time + 2)
-    d2.store("subkey2", b"value2", time + 3)
-    d2.store("subkey3", b"value3", time + 1)
-
-    assert d2.latest_expiration_time == time + 3
-    for subkey, (subvalue, subexpiration) in d2.items():
-        assert d1.store_subkey(DHTID.generate("foo"), subkey, subvalue, subexpiration)
-    assert d1.store(DHTID.generate("bar"), b"456", time + 2)
-    assert d1.get(DHTID.generate("foo"))[0].data == d2.data
-    assert d1.get(DHTID.generate("foo"))[1] == d2.latest_expiration_time
-    assert d1.get(DHTID.generate("foo"))[0].get("subkey1") == (b"value1", time + 2)
-    assert len(d1.get(DHTID.generate("foo"))[0]) == 3
-    assert d1.store_subkey(DHTID.generate("foo"), "subkey4", b"value4", time + 4)
-    assert len(d1.get(DHTID.generate("foo"))[0]) == 4
-
-    assert (
-        d1.store_subkey(DHTID.generate("bar"), "subkeyA", b"valueA", time + 1) is False
-    )  # prev has better expiration
-    assert d1.store_subkey(DHTID.generate("bar"), "subkeyA", b"valueA", time + 3)  # new value has better expiration
-    assert d1.store_subkey(DHTID.generate("bar"), "subkeyB", b"valueB", time + 4)  # new value has better expiration
-    assert d1.store_subkey(DHTID.generate("bar"), "subkeyA", b"valueA+", time + 5)  # overwrite subkeyA under key bar
-    assert all(subkey in d1.get(DHTID.generate("bar"))[0] for subkey in ("subkeyA", "subkeyB"))
-    assert len(d1.get(DHTID.generate("bar"))[0]) == 2 and d1.get(DHTID.generate("bar"))[1] == time + 5
-
-    assert d1.store(DHTID.generate("foo"), b"nothing", time + 3.5) is False  # previous value has better expiration
-    assert d1.get(DHTID.generate("foo"))[0].get("subkey2") == (b"value2", time + 3)
-    assert d1.store(DHTID.generate("foo"), b"nothing", time + 5) is True  # new value has better expiraiton
-    assert d1.get(DHTID.generate("foo")) == (b"nothing", time + 5)  # value should be replaced
-
-
-def test_localstorage_freeze():
-    d = DHTLocalStorage(maxsize=2)
-
-    with d.freeze():
-        d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 0.01)
-        assert DHTID.generate("key1") in d
-        time.sleep(0.03)
-        assert DHTID.generate("key1") in d
-    assert DHTID.generate("key1") not in d
-
-    with d.freeze():
-        d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 1)
-        d.store(DHTID.generate("key2"), b"val2", get_dht_time() + 2)
-        d.store(DHTID.generate("key3"), b"val3", get_dht_time() + 3)  # key3 will push key1 out due to maxsize
-        assert DHTID.generate("key1") in d
-    assert DHTID.generate("key1") not in d
-
-
-def test_localstorage_serialize():
-    d1 = DictionaryDHTValue()
-    d2 = DictionaryDHTValue()
-
-    now = get_dht_time()
-    d1.store("key1", b"ololo", now + 1)
-    d2.store("key2", b"pysh", now + 1)
-    d2.store("key3", b"pyshpysh", now + 2)
-
-    data = MSGPackSerializer.dumps([d1, d2, 123321])
-    assert isinstance(data, bytes)
-    new_d1, new_d2, new_value = MSGPackSerializer.loads(data)
-    assert isinstance(new_d1, DictionaryDHTValue) and isinstance(new_d2, DictionaryDHTValue) and new_value == 123321
-    assert "key1" in new_d1 and len(new_d1) == 1
-    assert "key1" not in new_d2 and len(new_d2) == 2 and new_d2.get("key3") == (b"pyshpysh", now + 2)

+ 0 - 93
tests/test_dht_validation.py

@@ -1,93 +0,0 @@
-import dataclasses
-from typing import Dict
-
-import pytest
-from pydantic import BaseModel, StrictInt
-
-import hivemind
-from hivemind.dht.crypto import RSASignatureValidator
-from hivemind.dht.protocol import DHTProtocol
-from hivemind.dht.routing import DHTID
-from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
-from hivemind.dht.validation import DHTRecord, CompositeValidator
-
-
-class SchemaA(BaseModel):
-    field_a: bytes
-
-
-class SchemaB(BaseModel):
-    field_b: Dict[BytesWithPublicKey, StrictInt]
-
-
-@pytest.fixture
-def validators_for_app():
-    # Each application may add its own validator set
-    return {
-        "A": [RSASignatureValidator(), SchemaValidator(SchemaA, allow_extra_keys=False)],
-        "B": [SchemaValidator(SchemaB, allow_extra_keys=False), RSASignatureValidator()],
-    }
-
-
-def test_composite_validator(validators_for_app):
-    validator = CompositeValidator(validators_for_app["A"])
-    assert [type(item) for item in validator._validators] == [SchemaValidator, RSASignatureValidator]
-
-    validator.extend(validators_for_app["B"])
-    assert [type(item) for item in validator._validators] == [SchemaValidator, RSASignatureValidator]
-    assert len(validator._validators[0]._schemas) == 2
-
-    local_public_key = validators_for_app["A"][0].local_public_key
-    record = DHTRecord(
-        key=DHTID.generate(source="field_b").to_bytes(),
-        subkey=DHTProtocol.serializer.dumps(local_public_key),
-        value=DHTProtocol.serializer.dumps(777),
-        expiration_time=hivemind.get_dht_time() + 10,
-    )
-
-    signed_record = dataclasses.replace(record, value=validator.sign_value(record))
-    # Expect only one signature since two RSASignatureValidators have been merged
-    assert signed_record.value.count(b"[signature:") == 1
-    # Expect successful validation since the second SchemaValidator has been merged into the first
-    assert validator.validate(signed_record)
-    assert validator.strip_value(signed_record) == record.value
-
-    record = DHTRecord(
-        key=DHTID.generate(source="unknown_key").to_bytes(),
-        subkey=DHTProtocol.IS_REGULAR_VALUE,
-        value=DHTProtocol.serializer.dumps(777),
-        expiration_time=hivemind.get_dht_time() + 10,
-    )
-
-    signed_record = dataclasses.replace(record, value=validator.sign_value(record))
-    assert signed_record.value.count(b"[signature:") == 0
-    # Expect failed validation since `unknown_key` is not a part of any schema
-    assert not validator.validate(signed_record)
-
-
-@pytest.mark.forked
-def test_dht_add_validators(validators_for_app):
-    # One app may create a DHT with its validators
-    dht = hivemind.DHT(start=False, record_validators=validators_for_app["A"])
-
-    # While the DHT process has not started yet, you can't send a command to append new validators
-    with pytest.raises(RuntimeError):
-        dht.add_validators(validators_for_app["B"])
-    dht.run_in_background(await_ready=True)
-
-    # After starting the process, other apps may add new validators to the existing DHT
-    dht.add_validators(validators_for_app["B"])
-
-    assert dht.store("field_a", b"bytes_value", hivemind.get_dht_time() + 10)
-    assert dht.get("field_a", latest=True).value == b"bytes_value"
-
-    assert not dht.store("field_a", 666, hivemind.get_dht_time() + 10)
-    assert dht.get("field_a", latest=True).value == b"bytes_value"
-
-    local_public_key = validators_for_app["A"][0].local_public_key
-    assert dht.store("field_b", 777, hivemind.get_dht_time() + 10, subkey=local_public_key)
-    dictionary = dht.get("field_b", latest=True).value
-    assert len(dictionary) == 1 and dictionary[local_public_key].value == 777
-
-    assert not dht.store("unknown_key", 666, hivemind.get_dht_time() + 10)
-    assert dht.get("unknown_key", latest=True) is None

+ 0 - 111
tests/test_expert_backend.py

@@ -1,111 +0,0 @@
-from pathlib import Path
-from tempfile import TemporaryDirectory
-
-import pytest
-import torch
-from torch.nn import Linear
-
-from hivemind import BatchTensorDescriptor, ExpertBackend
-from hivemind.moe.server.checkpoints import store_experts, load_experts
-from hivemind.moe.server.layers.lr_schedule import get_linear_schedule_with_warmup
-
-EXPERT_WEIGHT_UPDATES = 3
-BACKWARD_PASSES_BEFORE_SAVE = 2
-BACKWARD_PASSES_AFTER_SAVE = 2
-EXPERT_NAME = "test_expert"
-PEAK_LR = 1.0
-
-
-@pytest.fixture
-def example_experts():
-    expert = Linear(1, 1)
-    opt = torch.optim.SGD(expert.parameters(), PEAK_LR)
-
-    args_schema = (BatchTensorDescriptor(1),)
-    expert_backend = ExpertBackend(
-        name=EXPERT_NAME,
-        expert=expert,
-        optimizer=opt,
-        scheduler=get_linear_schedule_with_warmup,
-        num_warmup_steps=BACKWARD_PASSES_BEFORE_SAVE,
-        num_total_steps=BACKWARD_PASSES_BEFORE_SAVE + BACKWARD_PASSES_AFTER_SAVE,
-        args_schema=args_schema,
-        outputs_schema=BatchTensorDescriptor(1),
-        max_batch_size=1,
-    )
-    experts = {EXPERT_NAME: expert_backend}
-    yield experts
-
-
-@pytest.mark.forked
-def test_save_load_checkpoints(example_experts):
-    expert = example_experts[EXPERT_NAME].expert
-
-    with TemporaryDirectory() as tmpdir:
-        tmp_path = Path(tmpdir)
-
-        for i in range(1, EXPERT_WEIGHT_UPDATES + 1):
-            expert.weight.data[0] = i
-            store_experts(example_experts, tmp_path)
-
-        checkpoints_dir = tmp_path / EXPERT_NAME
-
-        assert checkpoints_dir.exists()
-        # the extra file is checkpoint_last.pt
-        assert len(list(checkpoints_dir.iterdir())) == EXPERT_WEIGHT_UPDATES + 1
-
-        expert.weight.data[0] = 0
-
-        load_experts(example_experts, tmp_path)
-        assert expert.weight.data[0] == EXPERT_WEIGHT_UPDATES
-
-
-@pytest.mark.forked
-def test_restore_update_count(example_experts):
-    expert_backend = example_experts[EXPERT_NAME]
-
-    batch = torch.randn(1, 1)
-    loss_grad = torch.randn(1, 1)
-
-    with TemporaryDirectory() as tmpdir:
-        tmp_path = Path(tmpdir)
-
-        for _ in range(BACKWARD_PASSES_BEFORE_SAVE):
-            expert_backend.backward(batch, loss_grad)
-
-        store_experts(example_experts, tmp_path)
-
-        for _ in range(BACKWARD_PASSES_AFTER_SAVE):
-            expert_backend.backward(batch, loss_grad)
-
-        load_experts(example_experts, tmp_path)
-        assert expert_backend.update_count == BACKWARD_PASSES_BEFORE_SAVE
-
-
-@pytest.mark.forked
-def test_lr_schedule(example_experts):
-    expert_backend = example_experts[EXPERT_NAME]
-    optimizer = expert_backend.optimizer
-
-    batch = torch.randn(1, 1)
-    loss_grad = torch.randn(1, 1)
-
-    with TemporaryDirectory() as tmpdir:
-        tmp_path = Path(tmpdir)
-
-        assert optimizer.param_groups[0]["lr"] == 0.0
-
-        for i in range(BACKWARD_PASSES_BEFORE_SAVE):
-            assert optimizer.param_groups[0]["lr"] == PEAK_LR * i / BACKWARD_PASSES_BEFORE_SAVE
-            expert_backend.backward(batch, loss_grad)
-
-        assert optimizer.param_groups[0]["lr"] == PEAK_LR
-        store_experts(example_experts, tmp_path)
-
-        for i in range(BACKWARD_PASSES_AFTER_SAVE):
-            assert optimizer.param_groups[0]["lr"] == PEAK_LR * (1 - (i / BACKWARD_PASSES_AFTER_SAVE))
-            expert_backend.backward(batch, loss_grad)
-
-        assert optimizer.param_groups[0]["lr"] == 0.0
-        load_experts(example_experts, tmp_path)
-        assert optimizer.param_groups[0]["lr"] == PEAK_LR

+ 0 - 293
tests/test_moe.py

@@ -1,293 +0,0 @@
-import grpc
-import numpy as np
-import pytest
-import torch
-
-import hivemind
-from hivemind.moe.server import background_server, declare_experts
-from hivemind.moe.client.expert import DUMMY
-from hivemind.moe.server import layers
-
-
-@pytest.mark.forked
-def test_moe():
-    all_expert_uids = [
-        f"ffn.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}" for _ in range(10)
-    ]
-    with background_server(
-        expert_uids=all_expert_uids, device="cpu", expert_cls="ffn", num_handlers=1, hidden_dim=16
-    ) as (server_endpoint, dht_maddrs):
-        dht = hivemind.DHT(start=True, initial_peers=dht_maddrs)
-
-        dmoe = hivemind.RemoteMixtureOfExperts(
-            in_features=16, grid_size=(4, 4, 4), dht=dht, k_best=3, uid_prefix="ffn."
-        )
-
-        for i in range(3):
-            out = dmoe(torch.randn(10, 16))
-            out.sum().backward()
-
-
-@pytest.mark.forked
-def test_no_experts():
-    all_expert_uids = [
-        f"expert.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}" for _ in range(10)
-    ]
-    with background_server(
-        expert_uids=all_expert_uids, device="cpu", expert_cls="nop_delay", num_handlers=1, hidden_dim=16
-    ) as (server_endpoint, dht_maddrs):
-        dht = hivemind.DHT(start=True, initial_peers=dht_maddrs)
-
-        dmoe = hivemind.RemoteSwitchMixtureOfExperts(
-            in_features=16,
-            grid_size=(4, 4, 4),
-            dht=dht,
-            uid_prefix="expert.",
-            forward_timeout=0.1,
-            backward_timeout=0.1,
-            allow_zero_outputs=True,
-        )
-
-        for i in range(3):
-            out, balancing_loss = dmoe(torch.randn(10, 16))
-            out.sum().backward()
-
-
-@pytest.mark.forked
-def test_call_many(hidden_dim=16):
-    k_min = 1
-    timeout_after_k_min = None
-    backward_k_min = 1
-    forward_timeout = None
-    backward_timeout = None
-    detect_anomalies = False
-    allow_zero_outputs = False
-    atol = 1e-5
-
-    with background_server(
-        num_experts=5,
-        device="cpu",
-        expert_cls="ffn",
-        num_handlers=1,
-        hidden_dim=hidden_dim,
-        optim_cls=None,
-        no_dht=True,
-    ) as (server_endpoint, _):
-        inputs = torch.randn(4, hidden_dim, requires_grad=True)
-        inputs_clone = inputs.clone().detach().requires_grad_(True)
-        e0, e1, e2, e3, e4 = [hivemind.RemoteExpert(f"expert.{i}", server_endpoint) for i in range(5)]
-        e5 = hivemind.RemoteExpert(f"thisshouldnotexist", "127.0.0.1:80")
-
-        mask, expert_outputs = hivemind.moe.client.moe._RemoteCallMany.apply(
-            DUMMY,
-            [[e0, e1, e2], [e2, e4], [e1, e5, e3], []],
-            k_min,
-            backward_k_min,
-            timeout_after_k_min,
-            forward_timeout,
-            backward_timeout,
-            detect_anomalies,
-            allow_zero_outputs,
-            e1.info,
-            inputs,
-        )
-        assert mask.shape == (4, 3)
-        assert expert_outputs.shape == (4, 3, hidden_dim)
-
-        assert np.all(
-            mask.data.numpy()
-            == np.array([[True, True, True], [True, True, False], [True, False, True], [False, False, False]])
-        ), f"Incorrect mask, {mask}"
-
-        reference_outputs = torch.zeros_like(expert_outputs)
-        reference_outputs[0, 0] = e0(inputs_clone[0:1])
-        reference_outputs[0, 1] = e1(inputs_clone[0:1])
-        reference_outputs[0, 2] = e2(inputs_clone[0:1])
-        reference_outputs[1, 0] = e2(inputs_clone[1:2])
-        reference_outputs[1, 1] = e4(inputs_clone[1:2])
-        reference_outputs[2, 0] = e1(inputs_clone[2:3])
-        reference_outputs[2, 2] = e3(inputs_clone[2:3])
-
-        assert torch.allclose(expert_outputs, reference_outputs, atol=atol, rtol=0)
-        proj = torch.randn(4, hidden_dim)
-        loss = (expert_outputs[(0, 1, 1, 2), (0, 2, 1, 0)] * proj).sum()
-        loss.backward()
-        our_grad = inputs.grad.data.cpu().clone()
-
-        reference_loss = (reference_outputs[(0, 1, 1, 2), (0, 2, 1, 0)] * proj).sum()
-        reference_loss.backward()
-        reference_grad = inputs_clone.grad.data.cpu().clone()
-        assert torch.allclose(our_grad, reference_grad, atol=atol, rtol=0)
-
-
-@pytest.mark.forked
-def test_remote_module_call(hidden_dim=16):
-    with background_server(
-        num_experts=1,
-        device="cpu",
-        expert_cls="ffn",
-        num_handlers=1,
-        hidden_dim=hidden_dim,
-        optim_cls=None,
-        no_dht=True,
-    ) as (server_endpoint, _):
-        real_expert = hivemind.RemoteExpert("expert.0", server_endpoint)
-        fake_expert = hivemind.RemoteExpert("oiasfjiasjf", server_endpoint)
-
-        out1 = real_expert(torch.randn(1, hidden_dim))
-        assert out1.shape == (1, hidden_dim)
-        dummy_x = torch.randn(3, hidden_dim, requires_grad=True)
-        out3 = real_expert(dummy_x)
-        assert out3.shape == (3, hidden_dim)
-        out3_again = real_expert(dummy_x[1:])
-        assert torch.allclose(out3_again, out3[1:], atol=1e-5, rtol=0)
-        out3_again.norm().backward()
-        assert dummy_x.grad is not None and dummy_x.grad.norm() > 0
-
-        with pytest.raises(grpc.RpcError):
-            real_expert(torch.randn(3, 11))
-        with pytest.raises(grpc.RpcError):
-            fake_expert(dummy_x)
-
-
-@pytest.mark.forked
-def test_beam_search_correctness():
-    all_expert_uids = [f"ffn.{5 + i}.{10 + j}.{15 + k}" for i in range(10) for j in range(10) for k in range(10)]
-    dht = hivemind.DHT(start=True)
-    assert all(declare_experts(dht, all_expert_uids, endpoint="fake-endpoint"))
-
-    dmoe = hivemind.RemoteMixtureOfExperts(
-        in_features=32, grid_size=(32, 32, 32), dht=dht, k_best=4, uid_prefix="ffn."
-    )
-
-    for i in range(25):
-        input = torch.randn(32)
-        grid_scores = dmoe.proj(input).split_with_sizes(dmoe.beam_search.grid_size, dim=-1)
-
-        chosen_experts = dmoe.beam_search.find_best_experts(
-            [tensor.detach().numpy() for tensor in grid_scores], beam_size=dmoe.k_best
-        )
-        chosen_scores = dmoe.compute_expert_scores([dim_scores[None] for dim_scores in grid_scores], [chosen_experts])[
-            0
-        ]
-        our_best_scores = list(chosen_scores.cpu().detach().numpy())
-
-        # reference: independently find :beam_size: best experts with exhaustive search
-        all_scores = dmoe.compute_expert_scores(
-            [dim_scores.unsqueeze(0) for dim_scores in grid_scores],
-            [[hivemind.RemoteExpert(uid, "") for uid in all_expert_uids]],
-        )[0]
-        true_best_scores = sorted(all_scores.cpu().detach().numpy(), reverse=True)[: len(chosen_experts)]
-
-        assert np.allclose(true_best_scores, our_best_scores)
-
-
-@pytest.mark.forked
-def test_determinism(hidden_dim=16):
-    atol = 1e-5
-
-    xx = torch.randn(32, hidden_dim, requires_grad=True)
-    mask = torch.randint(0, 1, (32, hidden_dim))
-
-    with background_server(
-        num_experts=1,
-        device="cpu",
-        expert_cls="det_dropout",
-        num_handlers=1,
-        hidden_dim=hidden_dim,
-        optim_cls=None,
-        no_dht=True,
-    ) as (server_endpoint, _):
-        expert = hivemind.RemoteExpert(uid=f"expert.0", endpoint=server_endpoint)
-
-        out = expert(xx, mask)
-        out_rerun = expert(xx, mask)
-
-        (grad,) = torch.autograd.grad(out.sum(), xx, retain_graph=True)
-        (grad_rerun,) = torch.autograd.grad(out_rerun.sum(), xx, retain_graph=True)
-
-    assert torch.allclose(out, out_rerun, atol=atol, rtol=0), "Dropout layer outputs are non-deterministic."
-    assert torch.allclose(grad, grad_rerun, atol=atol, rtol=0), "Gradients are non-deterministic."
-
-
-@pytest.mark.forked
-def test_compute_expert_scores():
-    try:
-        dht = hivemind.DHT(start=True)
-        moe = hivemind.moe.RemoteMixtureOfExperts(
-            dht=dht, in_features=16, grid_size=(40,), k_best=4, k_min=1, timeout_after_k_min=1, uid_prefix="expert."
-        )
-        gx, gy = torch.randn(4, 5, requires_grad=True), torch.randn(4, 3, requires_grad=True)
-        ii = [[4, 0, 2], [3, 1, 1, 1, 3], [0], [3, 2]]
-        jj = [[2, 2, 1], [0, 1, 2, 0, 1], [0], [1, 2]]
-        batch_experts = [
-            [
-                hivemind.RemoteExpert(
-                    uid=f"expert.{ii[batch_i][expert_i]}.{jj[batch_i][expert_i]}", endpoint="[::]:1337"
-                )
-                for expert_i in range(len(ii[batch_i]))
-            ]
-            for batch_i in range(len(ii))
-        ]  # note: these experts do not exist on the server; we only use them to test compute_expert_scores
-        logits = moe.compute_expert_scores([gx, gy], batch_experts)
-        torch.softmax(logits, dim=-1).norm(dim=-1).mean().backward()
-        assert gx.grad.norm().item() > 0 and gy.grad.norm().item(), "compute_expert_scores didn't backprop"
-
-        for batch_i in range(len(ii)):
-            for expert_i in range(len(ii[batch_i])):
-                assert torch.allclose(
-                    logits[batch_i, expert_i], gx[batch_i, ii[batch_i][expert_i]] + gy[batch_i, jj[batch_i][expert_i]]
-                ), "compute_expert_scores returned incorrect score"
-    finally:
-        dht.shutdown()
-
-
-@pytest.mark.forked
-def test_client_anomaly_detection():
-    HID_DIM = 16
-
-    experts = {}
-    for i in range(4):
-        expert = layers.name_to_block["ffn"](HID_DIM)
-        experts[f"expert.{i}"] = hivemind.ExpertBackend(
-            name=f"expert.{i}",
-            expert=expert,
-            optimizer=torch.optim.Adam(expert.parameters()),
-            args_schema=(hivemind.BatchTensorDescriptor(HID_DIM),),
-            outputs_schema=hivemind.BatchTensorDescriptor(HID_DIM),
-            max_batch_size=16,
-        )
-
-    experts["expert.3"].expert.ffn.weight.data[0, 0] = float("nan")
-
-    dht = hivemind.DHT(start=True)
-    server = hivemind.moe.Server(dht, experts, num_connection_handlers=1)
-    server.start()
-    try:
-        server.ready.wait()
-
-        dmoe = hivemind.RemoteMixtureOfExperts(
-            in_features=16, grid_size=(3,), dht=dht, k_best=3, uid_prefix="expert.", detect_anomalies=True
-        )
-
-        input = torch.randn(1, 16)
-        input[0, 0] = float("nan")
-
-        with pytest.raises(ValueError):
-            dmoe(input)
-
-        input[0, 0] = 0
-        output = dmoe(input)
-
-        inf_loss = float("inf") * output.sum()
-        with pytest.raises(ValueError):
-            inf_loss.backward()
-
-        dmoe = hivemind.RemoteMixtureOfExperts(
-            in_features=16, grid_size=(4,), dht=dht, k_best=4, uid_prefix="expert.", detect_anomalies=True
-        )
-        output = dmoe(input)
-        assert output.isfinite().all()
-
-    finally:
-        server.shutdown()

+ 0 - 340
tests/test_p2p_daemon.py

@@ -1,340 +0,0 @@
-import asyncio
-import multiprocessing as mp
-import subprocess
-from contextlib import closing
-from functools import partial
-from typing import List
-
-import numpy as np
-import pytest
-from multiaddr import Multiaddr
-
-from hivemind.p2p import P2P, P2PHandlerError
-from hivemind.proto import dht_pb2
-from hivemind.utils.serializer import MSGPackSerializer
-
-
-def is_process_running(pid: int) -> bool:
-    return subprocess.run(["ps", "-p", str(pid)], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL).returncode == 0
-
-
-async def replicate_if_needed(p2p: P2P, replicate: bool) -> P2P:
-    return await P2P.replicate(p2p.daemon_listen_maddr) if replicate else p2p
-
-
-@pytest.mark.asyncio
-async def test_daemon_killed_on_del():
-    p2p_daemon = await P2P.create()
-
-    child_pid = p2p_daemon._child.pid
-    assert is_process_running(child_pid)
-
-    await p2p_daemon.shutdown()
-    assert not is_process_running(child_pid)
-
-
-@pytest.mark.parametrize(
-    "host_maddrs",
-    [
-        [Multiaddr("/ip4/127.0.0.1/tcp/0")],
-        [Multiaddr("/ip4/127.0.0.1/udp/0/quic")],
-        [Multiaddr("/ip4/127.0.0.1/tcp/0"), Multiaddr("/ip4/127.0.0.1/udp/0/quic")],
-    ],
-)
-@pytest.mark.asyncio
-async def test_transports(host_maddrs: List[Multiaddr]):
-    server = await P2P.create(quic=True, host_maddrs=host_maddrs)
-    peers = await server.list_peers()
-    assert len(peers) == 0
-
-    client = await P2P.create(quic=True, host_maddrs=host_maddrs, initial_peers=await server.get_visible_maddrs())
-    await client.wait_for_at_least_n_peers(1)
-
-    peers = await client.list_peers()
-    assert len(peers) == 1
-    peers = await server.list_peers()
-    assert len(peers) == 1
-
-
-@pytest.mark.asyncio
-async def test_daemon_replica_does_not_affect_primary():
-    p2p_daemon = await P2P.create()
-    p2p_replica = await P2P.replicate(p2p_daemon.daemon_listen_maddr)
-
-    child_pid = p2p_daemon._child.pid
-    assert is_process_running(child_pid)
-
-    await p2p_replica.shutdown()
-    assert is_process_running(child_pid)
-
-    await p2p_daemon.shutdown()
-    assert not is_process_running(child_pid)
-
-
-@pytest.mark.parametrize(
-    "should_cancel,replicate",
-    [
-        (True, False),
-        (True, True),
-        (False, False),
-        (False, True),
-    ],
-)
-@pytest.mark.asyncio
-async def test_call_protobuf_handler(should_cancel, replicate, handle_name="handle"):
-    handler_cancelled = False
-    server_primary = await P2P.create()
-    server = await replicate_if_needed(server_primary, replicate)
-
-    async def ping_handler(request, context):
-        try:
-            await asyncio.sleep(2)
-        except asyncio.CancelledError:
-            nonlocal handler_cancelled
-            handler_cancelled = True
-        return dht_pb2.PingResponse(peer=dht_pb2.NodeInfo(node_id=server.id.to_bytes()), available=True)
-
-    server_pid = server_primary._child.pid
-    await server.add_protobuf_handler(handle_name, ping_handler, dht_pb2.PingRequest)
-    assert is_process_running(server_pid)
-
-    client_primary = await P2P.create(initial_peers=await server.get_visible_maddrs())
-    client = await replicate_if_needed(client_primary, replicate)
-    client_pid = client_primary._child.pid
-    assert is_process_running(client_pid)
-    await client.wait_for_at_least_n_peers(1)
-
-    ping_request = dht_pb2.PingRequest(peer=dht_pb2.NodeInfo(node_id=client.id.to_bytes()), validate=True)
-    expected_response = dht_pb2.PingResponse(peer=dht_pb2.NodeInfo(node_id=server.id.to_bytes()), available=True)
-
-    if should_cancel:
-        call_task = asyncio.create_task(
-            client.call_protobuf_handler(server.id, handle_name, ping_request, dht_pb2.PingResponse)
-        )
-        await asyncio.sleep(0.25)
-
-        call_task.cancel()
-
-        await asyncio.sleep(0.25)
-        assert handler_cancelled
-    else:
-        actual_response = await client.call_protobuf_handler(
-            server.id, handle_name, ping_request, dht_pb2.PingResponse
-        )
-        assert actual_response == expected_response
-        assert not handler_cancelled
-
-    await server.shutdown()
-    await server_primary.shutdown()
-    assert not is_process_running(server_pid)
-
-    await client_primary.shutdown()
-    assert not is_process_running(client_pid)
-
-
-@pytest.mark.asyncio
-async def test_call_protobuf_handler_error(handle_name="handle"):
-    async def error_handler(request, context):
-        raise ValueError("boom")
-
-    server = await P2P.create()
-    server_pid = server._child.pid
-    await server.add_protobuf_handler(handle_name, error_handler, dht_pb2.PingRequest)
-    assert is_process_running(server_pid)
-
-    client = await P2P.create(initial_peers=await server.get_visible_maddrs())
-    client_pid = client._child.pid
-    assert is_process_running(client_pid)
-    await client.wait_for_at_least_n_peers(1)
-
-    ping_request = dht_pb2.PingRequest(peer=dht_pb2.NodeInfo(node_id=client.id.to_bytes()), validate=True)
-
-    with pytest.raises(P2PHandlerError) as excinfo:
-        await client.call_protobuf_handler(server.id, handle_name, ping_request, dht_pb2.PingResponse)
-    assert "boom" in str(excinfo.value)
-
-    await server.shutdown()
-    await client.shutdown()
-
-
-async def handle_square_stream(_, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
-    with closing(writer):
-        while True:
-            try:
-                x = MSGPackSerializer.loads(await P2P.receive_raw_data(reader))
-            except asyncio.IncompleteReadError:
-                break
-
-            result = x ** 2
-
-            await P2P.send_raw_data(MSGPackSerializer.dumps(result), writer)
-
-
-async def validate_square_stream(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
-    with closing(writer):
-        for _ in range(10):
-            x = np.random.randint(100)
-
-            await P2P.send_raw_data(MSGPackSerializer.dumps(x), writer)
-            result = MSGPackSerializer.loads(await P2P.receive_raw_data(reader))
-
-            assert result == x ** 2
-
-
-@pytest.mark.asyncio
-async def test_call_peer_single_process():
-    server = await P2P.create()
-    server_pid = server._child.pid
-    assert is_process_running(server_pid)
-
-    handler_name = "square"
-    await server.add_binary_stream_handler(handler_name, handle_square_stream)
-
-    client = await P2P.create(initial_peers=await server.get_visible_maddrs())
-    client_pid = client._child.pid
-    assert is_process_running(client_pid)
-
-    await client.wait_for_at_least_n_peers(1)
-
-    _, reader, writer = await client.call_binary_stream_handler(server.id, handler_name)
-    await validate_square_stream(reader, writer)
-
-    await server.shutdown()
-    assert not is_process_running(server_pid)
-
-    await client.shutdown()
-    assert not is_process_running(client_pid)
-
-
-async def run_server(handler_name, server_side, response_received):
-    server = await P2P.create()
-    server_pid = server._child.pid
-    assert is_process_running(server_pid)
-
-    await server.add_binary_stream_handler(handler_name, handle_square_stream)
-
-    server_side.send(server.id)
-    server_side.send(await server.get_visible_maddrs())
-    while response_received.value == 0:
-        await asyncio.sleep(0.5)
-
-    await server.shutdown()
-    assert not is_process_running(server_pid)
-
-
-def server_target(handler_name, server_side, response_received):
-    asyncio.run(run_server(handler_name, server_side, response_received))
-
-
-@pytest.mark.asyncio
-async def test_call_peer_different_processes():
-    handler_name = "square"
-
-    server_side, client_side = mp.Pipe()
-    response_received = mp.Value(np.ctypeslib.as_ctypes_type(np.int32))
-    response_received.value = 0
-
-    proc = mp.Process(target=server_target, args=(handler_name, server_side, response_received))
-    proc.start()
-
-    peer_id = client_side.recv()
-    peer_maddrs = client_side.recv()
-
-    client = await P2P.create(initial_peers=peer_maddrs)
-    client_pid = client._child.pid
-    assert is_process_running(client_pid)
-
-    await client.wait_for_at_least_n_peers(1)
-
-    _, reader, writer = await client.call_binary_stream_handler(peer_id, handler_name)
-    await validate_square_stream(reader, writer)
-
-    response_received.value = 1
-
-    await client.shutdown()
-    assert not is_process_running(client_pid)
-
-    proc.join()
-    assert proc.exitcode == 0
-
-
-@pytest.mark.asyncio
-async def test_error_closes_connection():
-    async def handle_raising_error(_, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
-        with closing(writer):
-            command = await P2P.receive_raw_data(reader)
-            if command == b"raise_error":
-                raise Exception("The handler has failed")
-            else:
-                await P2P.send_raw_data(b"okay", writer)
-
-    server = await P2P.create()
-    server_pid = server._child.pid
-    assert is_process_running(server_pid)
-
-    handler_name = "handler"
-    await server.add_binary_stream_handler(handler_name, handle_raising_error)
-
-    client = await P2P.create(initial_peers=await server.get_visible_maddrs())
-    client_pid = client._child.pid
-    assert is_process_running(client_pid)
-
-    await client.wait_for_at_least_n_peers(1)
-
-    _, reader, writer = await client.call_binary_stream_handler(server.id, handler_name)
-    with closing(writer):
-        await P2P.send_raw_data(b"raise_error", writer)
-        with pytest.raises(asyncio.IncompleteReadError):  # Means that the connection is closed
-            await P2P.receive_raw_data(reader)
-
-    # Although the handler raised an exception, the server did not crash and is ready for further requests
-    assert is_process_running(server_pid)
-
-    _, reader, writer = await client.call_binary_stream_handler(server.id, handler_name)
-    with closing(writer):
-        await P2P.send_raw_data(b"behave_normally", writer)
-        assert await P2P.receive_raw_data(reader) == b"okay"
-
-    await server.shutdown()
-    assert not is_process_running(server_pid)
-
-    await client.shutdown()
-    assert not is_process_running(client_pid)
-
-
-@pytest.mark.asyncio
-async def test_handlers_on_different_replicas():
-    async def handler(_, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, key: str) -> None:
-        with closing(writer):
-            await P2P.send_raw_data(key, writer)
-
-    server_primary = await P2P.create()
-    server_id = server_primary.id
-    await server_primary.add_binary_stream_handler("handle_primary", partial(handler, key=b"primary"))
-
-    server_replica1 = await replicate_if_needed(server_primary, True)
-    await server_replica1.add_binary_stream_handler("handle1", partial(handler, key=b"replica1"))
-
-    server_replica2 = await replicate_if_needed(server_primary, True)
-    await server_replica2.add_binary_stream_handler("handle2", partial(handler, key=b"replica2"))
-
-    client = await P2P.create(initial_peers=await server_primary.get_visible_maddrs())
-    await client.wait_for_at_least_n_peers(1)
-
-    for name, expected_key in [("handle_primary", b"primary"), ("handle1", b"replica1"), ("handle2", b"replica2")]:
-        _, reader, writer = await client.call_binary_stream_handler(server_id, name)
-        with closing(writer):
-            assert await P2P.receive_raw_data(reader) == expected_key
-
-    await server_replica1.shutdown()
-    await server_replica2.shutdown()
-
-    # The primary does not handle the replicas' protocols after their shutdown
-
-    for name in ["handle1", "handle2"]:
-        _, reader, writer = await client.call_binary_stream_handler(server_id, name)
-        with pytest.raises(asyncio.IncompleteReadError), closing(writer):
-            await P2P.receive_raw_data(reader)
-
-    await server_primary.shutdown()
-    await client.shutdown()

+ 0 - 583
tests/test_p2p_daemon_bindings.py

@@ -1,583 +0,0 @@
-import asyncio
-import io
-from contextlib import AsyncExitStack
-
-import pytest
-from google.protobuf.message import EncodeError
-from multiaddr import Multiaddr, protocols
-
-from hivemind.p2p.p2p_daemon_bindings.control import ControlClient, DaemonConnector, parse_conn_protocol
-from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
-from hivemind.p2p.p2p_daemon_bindings.utils import (
-    ControlFailure,
-    raise_if_failed,
-    read_pbmsg_safe,
-    read_unsigned_varint,
-    write_pbmsg,
-    write_unsigned_varint,
-)
-from hivemind.proto import p2pd_pb2 as p2pd_pb
-from test_utils.p2p_daemon import make_p2pd_pair_ip4, connect_safe
-
-
-def test_raise_if_failed_raises():
-    resp = p2pd_pb.Response()
-    resp.type = p2pd_pb.Response.ERROR
-    with pytest.raises(ControlFailure):
-        raise_if_failed(resp)
-
-
-def test_raise_if_failed_not_raises():
-    resp = p2pd_pb.Response()
-    resp.type = p2pd_pb.Response.OK
-    raise_if_failed(resp)
-
-
-PAIRS_INT_SERIALIZED_VALID = (
-    (0, b"\x00"),
-    (1, b"\x01"),
-    (128, b"\x80\x01"),
-    (2 ** 32, b"\x80\x80\x80\x80\x10"),
-    (2 ** 64 - 1, b"\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01"),
-)
-
-PAIRS_INT_SERIALIZED_OVERFLOW = (
-    (2 ** 64, b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x02"),
-    (2 ** 64 + 1, b"\x81\x80\x80\x80\x80\x80\x80\x80\x80\x02"),
-    (
-        2 ** 128,
-        b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x04",
-    ),
-)
-
-PEER_ID_STRING = "QmS5QmciTXXnCUCyxud5eWFenUMAmvAWSDa1c7dvdXRMZ7"
-PEER_ID_BYTES = b'\x12 7\x87F.[\xb5\xb1o\xe5*\xc7\xb9\xbb\x11:"Z|j2\x8ad\x1b\xa6\xe5<Ip\xfe\xb4\xf5v'
-PEER_ID = PeerID(PEER_ID_BYTES)
-MADDR = Multiaddr("/unix/123")
-NUM_P2PDS = 4
-PEER_ID_RANDOM = PeerID.from_base58("QmcgpsyWgH8Y8ajJz1Cu72KnS5uo2Aa2LpzU7kinSupNK1")
-ENABLE_CONTROL = True
-ENABLE_CONNMGR = False
-ENABLE_DHT = False
-ENABLE_PUBSUB = False
-FUNC_MAKE_P2PD_PAIR = make_p2pd_pair_ip4
-
-
-class MockReader(io.BytesIO):
-    async def readexactly(self, n):
-        await asyncio.sleep(0)
-        return self.read(n)
-
-
-class MockWriter(io.BytesIO):
-    pass
-
-
-class MockReaderWriter(MockReader, MockWriter):
-    pass
-
-
-@pytest.mark.parametrize("integer, serialized_integer", PAIRS_INT_SERIALIZED_VALID)
-@pytest.mark.asyncio
-async def test_write_unsigned_varint(integer, serialized_integer):
-    s = MockWriter()
-    await write_unsigned_varint(s, integer)
-    assert s.getvalue() == serialized_integer
-
-
-@pytest.mark.parametrize("integer", tuple(i[0] for i in PAIRS_INT_SERIALIZED_OVERFLOW))
-@pytest.mark.asyncio
-async def test_write_unsigned_varint_overflow(integer):
-    s = MockWriter()
-    with pytest.raises(ValueError):
-        await write_unsigned_varint(s, integer)
-
-
-@pytest.mark.parametrize("integer", (-1, -(2 ** 32), -(2 ** 64), -(2 ** 128)))
-@pytest.mark.asyncio
-async def test_write_unsigned_varint_negative(integer):
-    s = MockWriter()
-    with pytest.raises(ValueError):
-        await write_unsigned_varint(s, integer)
-
-
-@pytest.mark.parametrize("integer, serialized_integer", PAIRS_INT_SERIALIZED_VALID)
-@pytest.mark.asyncio
-async def test_read_unsigned_varint(integer, serialized_integer):
-    s = MockReader(serialized_integer)
-    result = await read_unsigned_varint(s)
-    assert result == integer
-
-
-@pytest.mark.parametrize("serialized_integer", tuple(i[1] for i in PAIRS_INT_SERIALIZED_OVERFLOW))
-@pytest.mark.asyncio
-async def test_read_unsigned_varint_overflow(serialized_integer):
-    s = MockReader(serialized_integer)
-    with pytest.raises(ValueError):
-        await read_unsigned_varint(s)
-
-
-@pytest.mark.parametrize("max_bits", (2, 31, 32, 63, 64, 127, 128))
-@pytest.mark.asyncio
-async def test_read_write_unsigned_varint_max_bits_edge(max_bits):
-    """
-    Test edge cases with different `max_bits`
-    """
-    for i in range(-3, 0):
-        integer = i + (2 ** max_bits)
-        s = MockReaderWriter()
-        await write_unsigned_varint(s, integer, max_bits=max_bits)
-        s.seek(0, 0)
-        result = await read_unsigned_varint(s, max_bits=max_bits)
-        assert integer == result
-
-
-def test_peer_id():
-    assert PEER_ID.to_bytes() == PEER_ID_BYTES
-    assert PEER_ID.to_string() == PEER_ID_STRING
-
-    peer_id_2 = PeerID.from_base58(PEER_ID_STRING)
-    assert peer_id_2.to_bytes() == PEER_ID_BYTES
-    assert peer_id_2.to_string() == PEER_ID_STRING
-    assert PEER_ID == peer_id_2
-    peer_id_3 = PeerID.from_base58("QmbmfNDEth7Ucvjuxiw3SP3E4PoJzbk7g4Ge6ZDigbCsNp")
-    assert PEER_ID != peer_id_3
-
-
-def test_stream_info():
-    proto = "123"
-    si = StreamInfo(PEER_ID, MADDR, proto)
-    assert si.peer_id == PEER_ID
-    assert si.addr == MADDR
-    assert si.proto == proto
-    pb_si = si.to_protobuf()
-    assert pb_si.peer == PEER_ID.to_bytes()
-    assert pb_si.addr == MADDR.to_bytes()
-    assert pb_si.proto == si.proto
-    si_1 = StreamInfo.from_protobuf(pb_si)
-    assert si_1.peer_id == PEER_ID
-    assert si_1.addr == MADDR
-    assert si_1.proto == proto
-
-
-def test_peer_info():
-    pi = PeerInfo(PEER_ID, [MADDR])
-    assert pi.peer_id == PEER_ID
-    assert pi.addrs == [MADDR]
-    pi_pb = p2pd_pb.PeerInfo(id=PEER_ID.to_bytes(), addrs=[MADDR.to_bytes()])
-    pi_1 = PeerInfo.from_protobuf(pi_pb)
-    assert pi.peer_id == pi_1.peer_id
-    assert pi.addrs == pi_1.addrs
-
-
-@pytest.mark.parametrize(
-    "maddr_str, expected_proto",
-    (("/unix/123", protocols.P_UNIX), ("/ip4/127.0.0.1/tcp/7777", protocols.P_IP4)),
-)
-def test_parse_conn_protocol_valid(maddr_str, expected_proto):
-    assert parse_conn_protocol(Multiaddr(maddr_str)) == expected_proto
-
-
-@pytest.mark.parametrize(
-    "maddr_str",
-    (
-        "/p2p/QmbHVEEepCi7rn7VL7Exxpd2Ci9NNB6ifvqwhsrbRMgQFP",
-        "/onion/timaq4ygg2iegci7:1234",
-    ),
-)
-def test_parse_conn_protocol_invalid(maddr_str):
-    maddr = Multiaddr(maddr_str)
-    with pytest.raises(ValueError):
-        parse_conn_protocol(maddr)
-
-
-@pytest.mark.parametrize("control_maddr_str", ("/unix/123", "/ip4/127.0.0.1/tcp/6666"))
-def test_client_ctor_control_maddr(control_maddr_str):
-    c = DaemonConnector(Multiaddr(control_maddr_str))
-    assert c.control_maddr == Multiaddr(control_maddr_str)
-
-
-def test_client_ctor_default_control_maddr():
-    c = DaemonConnector()
-    assert c.control_maddr == Multiaddr(DaemonConnector.DEFAULT_CONTROL_MADDR)
-
-
-@pytest.mark.parametrize("listen_maddr_str", ("/unix/123", "/ip4/127.0.0.1/tcp/6666"))
-def test_control_client_ctor_listen_maddr(listen_maddr_str):
-    c = ControlClient(daemon_connector=DaemonConnector(), listen_maddr=Multiaddr(listen_maddr_str))
-    assert c.listen_maddr == Multiaddr(listen_maddr_str)
-
-
-def test_control_client_ctor_default_listen_maddr():
-    c = ControlClient(daemon_connector=DaemonConnector())
-    assert c.listen_maddr == Multiaddr(ControlClient.DEFAULT_LISTEN_MADDR)
-
-
-@pytest.mark.parametrize(
-    "msg_bytes",
-    (
-        p2pd_pb.Response(
-            type=p2pd_pb.Response.Type.OK,
-            identify=p2pd_pb.IdentifyResponse(
-                id=PeerID.from_base58("QmT7WhTne9zBLfAgAJt9aiZ8jZ5BxJGowRubxsHYmnyzUd").to_bytes(),
-                addrs=[
-                    Multiaddr("/p2p-circuit").to_bytes(),
-                    Multiaddr("/ip4/127.0.0.1/tcp/51126").to_bytes(),
-                    Multiaddr("/ip4/192.168.10.135/tcp/51126").to_bytes(),
-                    Multiaddr("/ip6/::1/tcp/51127").to_bytes(),
-                ],
-            ),
-        ).SerializeToString(),
-        p2pd_pb.Response(
-            type=p2pd_pb.Response.Type.OK,
-            identify=p2pd_pb.IdentifyResponse(
-                id=PeerID.from_base58("QmcQFt2MFfCZ9AxzUCNrk4k7TtMdZZvAAteaA6tHpBKdrk").to_bytes(),
-                addrs=[
-                    Multiaddr("/p2p-circuit").to_bytes(),
-                    Multiaddr("/ip4/127.0.0.1/tcp/51493").to_bytes(),
-                    Multiaddr("/ip4/192.168.10.135/tcp/51493").to_bytes(),
-                    Multiaddr("/ip6/::1/tcp/51494").to_bytes(),
-                ],
-            ),
-        ).SerializeToString(),
-        p2pd_pb.Response(
-            type=p2pd_pb.Response.Type.OK,
-            identify=p2pd_pb.IdentifyResponse(
-                id=PeerID.from_base58("QmbWqVVoz7v9LS9ZUQAhyyfdFJY3iU8ZrUY3XQozoTA5cc").to_bytes(),
-                addrs=[
-                    Multiaddr("/p2p-circuit").to_bytes(),
-                    Multiaddr("/ip4/127.0.0.1/tcp/51552").to_bytes(),
-                    Multiaddr("/ip4/192.168.10.135/tcp/51552").to_bytes(),
-                    Multiaddr("/ip6/::1/tcp/51553").to_bytes(),
-                ],
-            ),
-        ).SerializeToString(),
-    ),
-    # give test cases IDs to prevent raw bytes from cluttering the terminal output
-    ids=("pb example Response 0", "pb example Response 1", "pb example Response 2"),
-)
-@pytest.mark.asyncio
-async def test_read_pbmsg_safe_valid(msg_bytes):
-    s = MockReaderWriter()
-    await write_unsigned_varint(s, len(msg_bytes))
-    s.write(msg_bytes)
-    # reset the offset back to the beginning
-    s.seek(0, 0)
-    pb_msg = p2pd_pb.Response()
-    await read_pbmsg_safe(s, pb_msg)
-    assert pb_msg.SerializeToString() == msg_bytes
-
-
-@pytest.mark.parametrize(
-    "pb_type, pb_msg",
-    (
-        (
-            p2pd_pb.Response,
-            p2pd_pb.Response(
-                type=p2pd_pb.Response.Type.OK,
-                dht=p2pd_pb.DHTResponse(
-                    type=p2pd_pb.DHTResponse.Type.VALUE,
-                    peer=p2pd_pb.PeerInfo(
-                        id=PeerID.from_base58("QmNaXUy78W9moQ9APCoKaTtPjLcEJPN9hRBCqErY7o2fQs").to_bytes(),
-                        addrs=[
-                            Multiaddr("/p2p-circuit").to_bytes(),
-                            Multiaddr("/ip4/127.0.0.1/tcp/56929").to_bytes(),
-                            Multiaddr("/ip4/192.168.10.135/tcp/56929").to_bytes(),
-                            Multiaddr("/ip6/::1/tcp/56930").to_bytes(),
-                        ],
-                    ),
-                ),
-            ),
-        ),
-        (p2pd_pb.Request, p2pd_pb.Request(type=p2pd_pb.Request.Type.LIST_PEERS)),
-        (
-            p2pd_pb.DHTRequest,
-            p2pd_pb.DHTRequest(
-                type=p2pd_pb.DHTRequest.Type.FIND_PEER,
-                peer=PeerID.from_base58("QmcgHMuEhqdLHDVeNjiCGU7Ds6E7xK3f4amgiwHNPKKn7R").to_bytes(),
-            ),
-        ),
-        (
-            p2pd_pb.DHTResponse,
-            p2pd_pb.DHTResponse(
-                type=p2pd_pb.DHTResponse.Type.VALUE,
-                peer=p2pd_pb.PeerInfo(
-                    id=PeerID.from_base58("QmWP32GhEyXVQsLXFvV81eadDC8zQRZxZvJK359rXxLquk").to_bytes(),
-                    addrs=[
-                        Multiaddr("/p2p-circuit").to_bytes(),
-                        Multiaddr("/ip4/127.0.0.1/tcp/56897").to_bytes(),
-                        Multiaddr("/ip4/192.168.10.135/tcp/56897").to_bytes(),
-                        Multiaddr("/ip6/::1/tcp/56898").to_bytes(),
-                    ],
-                ),
-            ),
-        ),
-        (
-            p2pd_pb.StreamInfo,
-            p2pd_pb.StreamInfo(
-                peer=PeerID.from_base58("QmewLxB46MftfxQiunRgJo2W8nW4Lh5NLEkRohkHhJ4wW6").to_bytes(),
-                addr=Multiaddr("/ip4/127.0.0.1/tcp/57029").to_bytes(),
-                proto=b"protocol123",
-            ),
-        ),
-    ),
-    ids=(
-        "pb example Response",
-        "pb example Request",
-        "pb example DHTRequest",
-        "pb example DHTResponse",
-        "pb example StreamInfo",
-    ),
-)
-@pytest.mark.asyncio
-async def test_write_pbmsg(pb_type, pb_msg):
-    msg_bytes = bytes(chr(pb_msg.ByteSize()), "utf-8") + pb_msg.SerializeToString()
-    pb_obj = pb_type()
-
-    s_read = MockReaderWriter(msg_bytes)
-    await read_pbmsg_safe(s_read, pb_obj)
-    s_write = MockReaderWriter()
-    await write_pbmsg(s_write, pb_obj)
-    assert msg_bytes == s_write.getvalue()
-
-
-@pytest.mark.parametrize(
-    "pb_msg",
-    (
-        p2pd_pb.Response(),
-        p2pd_pb.Request(),
-        p2pd_pb.DHTRequest(),
-        p2pd_pb.DHTResponse(),
-        p2pd_pb.StreamInfo(),
-    ),
-)
-@pytest.mark.asyncio
-async def test_write_pbmsg_missing_fields(pb_msg):
-    with pytest.raises(EncodeError):
-        await write_pbmsg(MockReaderWriter(), pb_msg)
-
-
-@pytest.fixture
-async def p2pcs():
-    # TODO: Change back to gather style
-    async with AsyncExitStack() as stack:
-        p2pd_tuples = [
-            await stack.enter_async_context(
-                FUNC_MAKE_P2PD_PAIR(
-                    enable_control=ENABLE_CONTROL,
-                    enable_connmgr=ENABLE_CONNMGR,
-                    enable_dht=ENABLE_DHT,
-                    enable_pubsub=ENABLE_PUBSUB,
-                )
-            )
-            for _ in range(NUM_P2PDS)
-        ]
-        yield tuple(p2pd_tuple.client for p2pd_tuple in p2pd_tuples)
-
-
-@pytest.mark.asyncio
-async def test_client_identify_unix_socket(p2pcs):
-    await p2pcs[0].identify()
-
-
-@pytest.mark.asyncio
-async def test_client_identify(p2pcs):
-    await p2pcs[0].identify()
-
-
-@pytest.mark.asyncio
-async def test_client_connect_success(p2pcs):
-    peer_id_0, maddrs_0 = await p2pcs[0].identify()
-    peer_id_1, maddrs_1 = await p2pcs[1].identify()
-    await p2pcs[0].connect(peer_id_1, maddrs_1)
-    # test case: repeated connections
-    await p2pcs[1].connect(peer_id_0, maddrs_0)
-
-
-@pytest.mark.asyncio
-async def test_client_connect_failure(p2pcs):
-    peer_id_1, maddrs_1 = await p2pcs[1].identify()
-    await p2pcs[0].identify()
-    # test case: `peer_id` mismatches
-    with pytest.raises(ControlFailure):
-        await p2pcs[0].connect(PEER_ID_RANDOM, maddrs_1)
-    # test case: empty maddrs
-    with pytest.raises(ControlFailure):
-        await p2pcs[0].connect(peer_id_1, [])
-    # test case: wrong maddrs
-    with pytest.raises(ControlFailure):
-        await p2pcs[0].connect(peer_id_1, [Multiaddr("/ip4/127.0.0.1/udp/0")])
-
-
-@pytest.mark.asyncio
-async def test_connect_safe(p2pcs):
-    await connect_safe(p2pcs[0], p2pcs[1])
-
-
-@pytest.mark.asyncio
-async def test_client_list_peers(p2pcs):
-    # test case: no peers
-    assert len(await p2pcs[0].list_peers()) == 0
-    # test case: 1 peer
-    await connect_safe(p2pcs[0], p2pcs[1])
-    assert len(await p2pcs[0].list_peers()) == 1
-    assert len(await p2pcs[1].list_peers()) == 1
-    # test case: one more peer
-    await connect_safe(p2pcs[0], p2pcs[2])
-    assert len(await p2pcs[0].list_peers()) == 2
-    assert len(await p2pcs[1].list_peers()) == 1
-    assert len(await p2pcs[2].list_peers()) == 1
-
-
-@pytest.mark.asyncio
-async def test_client_disconnect(p2pcs):
-    # test case: disconnect a peer without connections
-    await p2pcs[1].disconnect(PEER_ID_RANDOM)
-    # test case: disconnect
-    peer_id_0, _ = await p2pcs[0].identify()
-    await connect_safe(p2pcs[0], p2pcs[1])
-    assert len(await p2pcs[0].list_peers()) == 1
-    assert len(await p2pcs[1].list_peers()) == 1
-    await p2pcs[1].disconnect(peer_id_0)
-    assert len(await p2pcs[0].list_peers()) == 0
-    assert len(await p2pcs[1].list_peers()) == 0
-    # test case: disconnect twice
-    await p2pcs[1].disconnect(peer_id_0)
-    assert len(await p2pcs[0].list_peers()) == 0
-    assert len(await p2pcs[1].list_peers()) == 0
-
-
-@pytest.mark.asyncio
-async def test_client_stream_open_success(p2pcs):
-    peer_id_1, maddrs_1 = await p2pcs[1].identify()
-    await connect_safe(p2pcs[0], p2pcs[1])
-
-    proto = "123"
-
-    async def handle_proto(stream_info, reader, writer):
-        await reader.readexactly(1)
-
-    await p2pcs[1].stream_handler(proto, handle_proto)
-
-    # test case: normal
-    stream_info, reader, writer = await p2pcs[0].stream_open(peer_id_1, (proto,))
-    assert stream_info.peer_id == peer_id_1
-    assert stream_info.addr in maddrs_1
-    assert stream_info.proto == "123"
-    writer.close()
-
-    # test case: open with multiple protocols
-    stream_info, reader, writer = await p2pcs[0].stream_open(peer_id_1, (proto, "another_protocol"))
-    assert stream_info.peer_id == peer_id_1
-    assert stream_info.addr in maddrs_1
-    assert stream_info.proto == "123"
-    writer.close()
-
-
-@pytest.mark.asyncio
-async def test_client_stream_open_failure(p2pcs):
-    peer_id_1, _ = await p2pcs[1].identify()
-    await connect_safe(p2pcs[0], p2pcs[1])
-
-    proto = "123"
-
-    # test case: `stream_open` to a peer who didn't register the protocol
-    with pytest.raises(ControlFailure):
-        await p2pcs[0].stream_open(peer_id_1, (proto,))
-
-    # test case: `stream_open` to a peer for a non-registered protocol
-    async def handle_proto(stream_info, reader, writer):
-        pass
-
-    await p2pcs[1].stream_handler(proto, handle_proto)
-    with pytest.raises(ControlFailure):
-        await p2pcs[0].stream_open(peer_id_1, ("another_protocol",))
-
-
-@pytest.mark.asyncio
-async def test_client_stream_handler_success(p2pcs):
-    peer_id_1, _ = await p2pcs[1].identify()
-    await connect_safe(p2pcs[0], p2pcs[1])
-
-    proto = "protocol123"
-    bytes_to_send = b"yoyoyoyoyog"
-    # event that lets this test function wait until the handler function has received the incoming data
-    event_handler_finished = asyncio.Event()
-
-    async def handle_proto(stream_info, reader, writer):
-        nonlocal event_handler_finished
-        bytes_received = await reader.readexactly(len(bytes_to_send))
-        assert bytes_received == bytes_to_send
-        event_handler_finished.set()
-
-    await p2pcs[1].stream_handler(proto, handle_proto)
-    assert proto in p2pcs[1].control.handlers
-    assert handle_proto == p2pcs[1].control.handlers[proto]
-
-    # test case: test the stream handler `handle_proto`
-
-    _, reader, writer = await p2pcs[0].stream_open(peer_id_1, (proto,))
-
-    # wait until the handler function starts blocking while waiting for the data;
-    # since we haven't sent the data yet, we know the handler must still be blocked waiting.
-    # get the task of the protocol handler
-    writer.write(bytes_to_send)
-
-    # wait for the handler to finish
-    writer.close()
-
-    await event_handler_finished.wait()
-
-    # test case: two streams to different handlers respectively
-    another_proto = "another_protocol123"
-    another_bytes_to_send = b"456"
-    event_another_proto = asyncio.Event()
-
-    async def handle_another_proto(stream_info, reader, writer):
-        event_another_proto.set()
-        bytes_received = await reader.readexactly(len(another_bytes_to_send))
-        assert bytes_received == another_bytes_to_send
-
-    await p2pcs[1].stream_handler(another_proto, handle_another_proto)
-    assert another_proto in p2pcs[1].control.handlers
-    assert handle_another_proto == p2pcs[1].control.handlers[another_proto]
-
-    _, reader, writer = await p2pcs[0].stream_open(peer_id_1, (another_proto,))
-    await event_another_proto.wait()
-
-    # we know that at this moment the handler must still be blocked waiting
-
-    writer.write(another_bytes_to_send)
-
-    writer.close()
-
-    # test case: registering twice can override the previous registration
-    event_third = asyncio.Event()
-
-    async def handler_third(stream_info, reader, writer):
-        event_third.set()
-
-    await p2pcs[1].stream_handler(another_proto, handler_third)
-    assert another_proto in p2pcs[1].control.handlers
-    # ensure the previous handler has been overridden
-    assert handler_third == p2pcs[1].control.handlers[another_proto]
-
-    await p2pcs[0].stream_open(peer_id_1, (another_proto,))
-    # ensure the overriding handler is called when a stream is opened for this protocol
-    await event_third.wait()
-
-
-@pytest.mark.asyncio
-async def test_client_stream_handler_failure(p2pcs):
-    peer_id_1, _ = await p2pcs[1].identify()
-    await connect_safe(p2pcs[0], p2pcs[1])
-
-    proto = "123"
-
-    # test case: registered a wrong protocol name
-    async def handle_proto_correct_params(stream_info, stream):
-        pass
-
-    await p2pcs[1].stream_handler("another_protocol", handle_proto_correct_params)
-    with pytest.raises(ControlFailure):
-        await p2pcs[0].stream_open(peer_id_1, (proto,))

+ 0 - 148
tests/test_p2p_servicer.py

@@ -1,148 +0,0 @@
-import asyncio
-from typing import AsyncIterator
-
-import pytest
-
-from hivemind.p2p import P2P, P2PContext, ServicerBase
-from hivemind.proto import test_pb2
-
-
-@pytest.fixture
-async def server_client():
-    server = await P2P.create()
-    client = await P2P.create(initial_peers=await server.get_visible_maddrs())
-    yield server, client
-
-    await asyncio.gather(server.shutdown(), client.shutdown())
-
-
-@pytest.mark.asyncio
-async def test_unary_unary(server_client):
-    class ExampleServicer(ServicerBase):
-        async def rpc_square(self, request: test_pb2.TestRequest, _: P2PContext) -> test_pb2.TestResponse:
-            return test_pb2.TestResponse(number=request.number ** 2)
-
-    server, client = server_client
-    servicer = ExampleServicer()
-    await servicer.add_p2p_handlers(server)
-    stub = ExampleServicer.get_stub(client, server.id)
-
-    assert await stub.rpc_square(test_pb2.TestRequest(number=10)) == test_pb2.TestResponse(number=100)
-
-
-@pytest.mark.asyncio
-async def test_stream_unary(server_client):
-    class ExampleServicer(ServicerBase):
-        async def rpc_sum(self, numbers: AsyncIterator[test_pb2.TestRequest], _: P2PContext) -> test_pb2.TestResponse:
-            result = 0
-            async for item in numbers:
-                result += item.number
-            return test_pb2.TestResponse(number=result)
-
-    server, client = server_client
-    servicer = ExampleServicer()
-    await servicer.add_p2p_handlers(server)
-    stub = ExampleServicer.get_stub(client, server.id)
-
-    async def generate_requests() -> AsyncIterator[test_pb2.TestRequest]:
-        for i in range(10):
-            yield test_pb2.TestRequest(number=i)
-
-    assert await stub.rpc_sum(generate_requests()) == test_pb2.TestResponse(number=45)
-
-
-@pytest.mark.asyncio
-async def test_unary_stream(server_client):
-    class ExampleServicer(ServicerBase):
-        async def rpc_count(
-            self, request: test_pb2.TestRequest, _: P2PContext
-        ) -> AsyncIterator[test_pb2.TestResponse]:
-            for i in range(request.number):
-                yield test_pb2.TestResponse(number=i)
-
-    server, client = server_client
-    servicer = ExampleServicer()
-    await servicer.add_p2p_handlers(server)
-    stub = ExampleServicer.get_stub(client, server.id)
-
-    i = 0
-    async for item in stub.rpc_count(test_pb2.TestRequest(number=10)):
-        assert item == test_pb2.TestResponse(number=i)
-        i += 1
-    assert i == 10
-
-
-@pytest.mark.asyncio
-async def test_stream_stream(server_client):
-    class ExampleServicer(ServicerBase):
-        async def rpc_powers(
-            self, stream: AsyncIterator[test_pb2.TestRequest], _: P2PContext
-        ) -> AsyncIterator[test_pb2.TestResponse]:
-            async for item in stream:
-                yield test_pb2.TestResponse(number=item.number ** 2)
-                yield test_pb2.TestResponse(number=item.number ** 3)
-
-    server, client = server_client
-    servicer = ExampleServicer()
-    await servicer.add_p2p_handlers(server)
-    stub = ExampleServicer.get_stub(client, server.id)
-
-    async def generate_requests() -> AsyncIterator[test_pb2.TestRequest]:
-        for i in range(10):
-            yield test_pb2.TestRequest(number=i)
-
-    i = 0
-    async for item in stub.rpc_powers(generate_requests()):
-        if i % 2 == 0:
-            assert item == test_pb2.TestResponse(number=(i // 2) ** 2)
-        else:
-            assert item == test_pb2.TestResponse(number=(i // 2) ** 3)
-        i += 1
-
-
-@pytest.mark.parametrize(
-    "cancel_reason",
-    ["close_connection", "close_generator"],
-)
-@pytest.mark.asyncio
-async def test_unary_stream_cancel(server_client, cancel_reason):
-    handler_cancelled = False
-
-    class ExampleServicer(ServicerBase):
-        async def rpc_wait(self, request: test_pb2.TestRequest, _: P2PContext) -> AsyncIterator[test_pb2.TestResponse]:
-            try:
-                yield test_pb2.TestResponse(number=request.number + 1)
-                await asyncio.sleep(2)
-                yield test_pb2.TestResponse(number=request.number + 2)
-            except asyncio.CancelledError:
-                nonlocal handler_cancelled
-                handler_cancelled = True
-                raise
-
-    server, client = server_client
-    servicer = ExampleServicer()
-    await servicer.add_p2p_handlers(server)
-
-    if cancel_reason == "close_connection":
-        _, reader, writer = await client.call_binary_stream_handler(server.id, "ExampleServicer.rpc_wait")
-        await P2P.send_protobuf(test_pb2.TestRequest(number=10), writer)
-        await P2P.send_protobuf(P2P.END_OF_STREAM, writer)
-
-        response, _ = await P2P.receive_protobuf(test_pb2.TestResponse, reader)
-        assert response == test_pb2.TestResponse(number=11)
-        await asyncio.sleep(0.25)
-
-        writer.close()
-    elif cancel_reason == "close_generator":
-        stub = ExampleServicer.get_stub(client, server.id)
-        iter = stub.rpc_wait(test_pb2.TestRequest(number=10)).__aiter__()
-
-        assert await iter.__anext__() == test_pb2.TestResponse(number=11)
-        await asyncio.sleep(0.25)
-
-        await iter.aclose()
-    else:
-        assert False, f"Unknown cancel_reason = `{cancel_reason}`"
-
-    await asyncio.sleep(0.25)
-    assert handler_cancelled

+ 0 - 132
tests/test_routing.py

@@ -1,132 +0,0 @@
-import random
-import heapq
-import operator
-from itertools import chain, zip_longest
-
-from hivemind import LOCALHOST
-from hivemind.dht.routing import RoutingTable, DHTID
-
-
-def test_ids_basic():
-    # basic functionality tests
-    for i in range(100):
-        id1, id2 = DHTID.generate(), DHTID.generate()
-        assert DHTID.MIN <= id1 < DHTID.MAX and DHTID.MIN <= id2 <= DHTID.MAX
-        assert DHTID.xor_distance(id1, id1) == DHTID.xor_distance(id2, id2) == 0
-        assert DHTID.xor_distance(id1, id2) > 0 or (id1 == id2)
-        assert DHTID.from_bytes(bytes(id1)) == id1 and DHTID.from_bytes(id2.to_bytes()) == id2
-
-
-def test_ids_depth():
-    for i in range(100):
-        ids = [random.randint(0, 4096) for i in range(random.randint(1, 256))]
-        ours = DHTID.longest_common_prefix_length(*map(DHTID, ids))
-
-        ids_bitstr = ["".join(bin(bite)[2:].rjust(8, "0") for bite in uid.to_bytes(20, "big")) for uid in ids]
-        reference = len(shared_prefix(*ids_bitstr))
-        assert reference == ours, f"ours {ours} != reference {reference}, ids: {ids}"
-
-
-def test_routing_table_basic():
-    node_id = DHTID.generate()
-    routing_table = RoutingTable(node_id, bucket_size=20, depth_modulo=5)
-    added_nodes = []
-
-    for phony_neighbor_port in random.sample(range(10000), 100):
-        phony_id = DHTID.generate()
-        routing_table.add_or_update_node(phony_id, f"{LOCALHOST}:{phony_neighbor_port}")
-        assert phony_id in routing_table
-        assert f"{LOCALHOST}:{phony_neighbor_port}" in routing_table
-        assert routing_table[phony_id] == f"{LOCALHOST}:{phony_neighbor_port}"
-        assert routing_table[f"{LOCALHOST}:{phony_neighbor_port}"] == phony_id
-        added_nodes.append(phony_id)
-
-    assert routing_table.buckets[0].lower == DHTID.MIN and routing_table.buckets[-1].upper == DHTID.MAX
-    for bucket in routing_table.buckets:
-        assert len(bucket.replacement_nodes) == 0, "There should be no replacement nodes in a table with 100 entries"
-    assert 3 <= len(routing_table.buckets) <= 10, len(routing_table.buckets)
-
-    random_node = random.choice(added_nodes)
-    assert routing_table.get(node_id=random_node) == routing_table[random_node]
-    dummy_node = DHTID.generate()
-    assert (dummy_node not in routing_table) == (routing_table.get(node_id=dummy_node) is None)
-
-    for node in added_nodes:
-        found_bucket_index = routing_table.get_bucket_index(node)
-        for bucket_index, bucket in enumerate(routing_table.buckets):
-            if bucket.lower <= node < bucket.upper:
-                break
-        else:
-            raise ValueError("Naive search could not find bucket. Universe has gone crazy.")
-        assert bucket_index == found_bucket_index
-
-
-def test_routing_table_parameters():
-    for (bucket_size, modulo, min_nbuckets, max_nbuckets) in [
-        (20, 5, 45, 65),
-        (50, 5, 35, 45),
-        (20, 10, 650, 800),
-        (20, 1, 7, 15),
-    ]:
-        node_id = DHTID.generate()
-        routing_table = RoutingTable(node_id, bucket_size=bucket_size, depth_modulo=modulo)
-        for phony_neighbor_port in random.sample(range(1_000_000), 10_000):
-            routing_table.add_or_update_node(DHTID.generate(), f"{LOCALHOST}:{phony_neighbor_port}")
-        for bucket in routing_table.buckets:
-            assert len(bucket.replacement_nodes) == 0 or len(bucket.nodes_to_peer_id) <= bucket.size
-        assert (
-            min_nbuckets <= len(routing_table.buckets) <= max_nbuckets
-        ), f"Unexpected number of buckets: {min_nbuckets} <= {len(routing_table.buckets)} <= {max_nbuckets}"
-
-
-def test_routing_table_search():
-    for table_size, lower_active, upper_active in [(10, 10, 10), (10_000, 800, 1100)]:
-        node_id = DHTID.generate()
-        routing_table = RoutingTable(node_id, bucket_size=20, depth_modulo=5)
-        num_added = 0
-        total_nodes = 0
-
-        for phony_neighbor_port in random.sample(range(1_000_000), table_size):
-            routing_table.add_or_update_node(DHTID.generate(), f"{LOCALHOST}:{phony_neighbor_port}")
-            new_total = sum(len(bucket.nodes_to_peer_id) for bucket in routing_table.buckets)
-            num_added += new_total > total_nodes
-            total_nodes = new_total
-        num_replacements = sum(len(bucket.replacement_nodes) for bucket in routing_table.buckets)
-
-        all_active_neighbors = list(chain(*(bucket.nodes_to_peer_id.keys() for bucket in routing_table.buckets)))
-        assert lower_active <= len(all_active_neighbors) <= upper_active
-        assert len(all_active_neighbors) == num_added
-        assert num_added + num_replacements == table_size
-
-        # random queries
-        for i in range(1000):
-            k = random.randint(1, 100)
-            query_id = DHTID.generate()
-            exclude = query_id if random.random() < 0.5 else None
-            our_knn, our_peer_ids = zip(*routing_table.get_nearest_neighbors(query_id, k=k, exclude=exclude))
-            reference_knn = heapq.nsmallest(k, all_active_neighbors, key=query_id.xor_distance)
-            assert all(our == ref for our, ref in zip_longest(our_knn, reference_knn))
-            assert all(our_peer_id == routing_table[our_node] for our_node, our_peer_id in zip(our_knn, our_peer_ids))
-
-        # queries from table
-        for i in range(1000):
-            k = random.randint(1, 100)
-            query_id = random.choice(all_active_neighbors)
-            our_knn, our_peer_ids = zip(*routing_table.get_nearest_neighbors(query_id, k=k, exclude=query_id))
-
-            reference_knn = heapq.nsmallest(k + 1, all_active_neighbors, key=query_id.xor_distance)
-            if query_id in reference_knn:
-                reference_knn.remove(query_id)
-            assert len(our_knn) == len(reference_knn)
-            assert all(
-                query_id.xor_distance(our) == query_id.xor_distance(ref)
-                for our, ref in zip_longest(our_knn, reference_knn)
-            )
-            assert routing_table.get_nearest_neighbors(query_id, k=k, exclude=None)[0][0] == query_id
-
-
-def shared_prefix(*strings: str):
-    for i in range(min(map(len, strings))):
-        if len(set(map(operator.itemgetter(i), strings))) != 1:
-            return strings[0][:i]
-    return min(strings, key=len)
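For reference, the k-nearest-neighbour check in the removed test_routing_table_search boils down to ranking peers by XOR distance, exactly as heapq.nsmallest(..., key=query_id.xor_distance) does above. A self-contained sketch of the same reference computation, with plain integers standing in for DHTID:

import heapq
import random

def xor_distance(a: int, b: int) -> int:
    return a ^ b

random.seed(0)
node_ids = random.sample(range(2 ** 16), 100)
query = random.randrange(2 ** 16)
k = 5

# reference nearest neighbours: the k peers closest to the query under XOR distance
reference_knn = heapq.nsmallest(k, node_ids, key=lambda nid: xor_distance(query, nid))
assert len(reference_knn) == k
assert all(
    xor_distance(query, a) <= xor_distance(query, b)
    for a, b in zip(reference_knn, reference_knn[1:])
)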

+ 0 - 211
tests/test_training.py

@@ -1,211 +0,0 @@
-import time
-from functools import partial
-
-import pytest
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from sklearn.datasets import load_digits
-
-from hivemind import DHT
-from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
-from hivemind.moe.server import background_server
-from hivemind.optim import DecentralizedSGD, DecentralizedAdam
-
-
-@pytest.mark.forked
-def test_training(max_steps: int = 100, threshold: float = 0.9):
-    dataset = load_digits(n_class=2)
-    X_train, y_train = torch.tensor(dataset["data"], dtype=torch.float), torch.tensor(dataset["target"])
-    SGD = partial(torch.optim.SGD, lr=0.05)
-
-    with background_server(num_experts=2, device="cpu", optim_cls=SGD, hidden_dim=64, num_handlers=1, no_dht=True) as (
-        server_endpoint,
-        _,
-    ):
-        expert1 = RemoteExpert("expert.0", server_endpoint)
-        expert2 = RemoteExpert("expert.1", server_endpoint)
-        model = nn.Sequential(expert2, nn.ReLU(), expert1, nn.Linear(64, 2))
-
-        opt = SGD(model.parameters(), lr=0.05)
-
-        for step in range(max_steps):
-            outputs = model(X_train)
-            loss = F.cross_entropy(outputs, y_train)
-            loss.backward()
-            opt.step()
-            opt.zero_grad()
-
-            accuracy = (outputs.argmax(dim=1) == y_train).float().mean().item()
-            if accuracy >= threshold:
-                break
-
-        assert accuracy >= threshold, f"too small accuracy: {accuracy}"
-
-
-@pytest.mark.forked
-def test_moe_training(max_steps: int = 100, threshold: float = 0.9, num_experts=2):
-    dataset = load_digits(n_class=2)
-    X_train, y_train = torch.tensor(dataset["data"], dtype=torch.float), torch.tensor(dataset["target"])
-    subsample_ix = torch.randint(0, len(X_train), (32,))
-    X_train, y_train = X_train[subsample_ix], y_train[subsample_ix]
-    SGD = partial(torch.optim.SGD, lr=0.05)
-
-    all_expert_uids = [f"expert.{i}" for i in range(num_experts)]
-    with background_server(
-        expert_uids=all_expert_uids, device="cpu", optim_cls=SGD, hidden_dim=64, num_handlers=1
-    ) as (server_endpoint, dht_maddrs):
-        dht = DHT(start=True, initial_peers=dht_maddrs)
-
-        moe = RemoteMixtureOfExperts(in_features=64, grid_size=(num_experts,), dht=dht, uid_prefix="expert.", k_best=2)
-        model = nn.Sequential(moe, nn.Linear(64, 2))
-
-        opt = SGD(model.parameters(), lr=0.05)
-
-        for step in range(max_steps):
-            outputs = model(X_train)
-            loss = F.cross_entropy(outputs, y_train)
-            loss.backward()
-            opt.step()
-            opt.zero_grad()
-
-            accuracy = (outputs.argmax(dim=1) == y_train).float().mean().item()
-            if accuracy >= threshold:
-                break
-
-        assert accuracy >= threshold, f"too small accuracy: {accuracy}"
-
-
-class SwitchNetwork(nn.Module):
-    def __init__(self, dht, in_features, num_classes, num_experts):
-        super().__init__()
-        self.moe = RemoteSwitchMixtureOfExperts(
-            in_features=in_features,
-            grid_size=(num_experts,),
-            dht=dht,
-            jitter_eps=0,
-            uid_prefix="expert.",
-            k_best=1,
-            k_min=1,
-        )
-        self.linear = nn.Linear(in_features, num_classes)
-
-    def forward(self, x):
-        moe_output, balancing_loss = self.moe(x)
-        return self.linear(moe_output), balancing_loss
-
-
-@pytest.mark.forked
-def test_switch_training(max_steps: int = 10, threshold: float = 0.9, num_experts=5):
-    dataset = load_digits(n_class=2)
-    X_train, y_train = torch.tensor(dataset["data"], dtype=torch.float), torch.tensor(dataset["target"])
-    subsample_ix = torch.randint(0, len(X_train), (32,))
-    X_train, y_train = X_train[subsample_ix], y_train[subsample_ix]
-
-    SGD = partial(torch.optim.SGD, lr=0.05)
-
-    all_expert_uids = [f"expert.{i}" for i in range(num_experts)]
-    with background_server(
-        expert_uids=all_expert_uids, device="cpu", optim_cls=SGD, hidden_dim=64, num_handlers=1
-    ) as (server_endpoint, dht_maddrs):
-        dht = DHT(start=True, initial_peers=dht_maddrs)
-
-        model = SwitchNetwork(dht, 64, 2, num_experts)
-        opt = SGD(model.parameters(), lr=0.05)
-
-        for step in range(max_steps):
-            outputs, balancing_loss = model(X_train)
-            loss = F.cross_entropy(outputs, y_train) + 0.01 * balancing_loss
-            loss.backward()
-            opt.step()
-            opt.zero_grad()
-
-            accuracy = (outputs.argmax(dim=1) == y_train).float().mean().item()
-            if accuracy >= threshold:
-                break
-
-        assert model.moe.grid_utilization.min().item() > (1 / num_experts) / 2
-        assert accuracy >= threshold, f"too small accuracy: {accuracy}"
-
-
-@pytest.mark.forked
-def test_decentralized_optimizer_step():
-    dht_root = DHT(start=True)
-    initial_peers = dht_root.get_visible_maddrs()
-
-    param1 = torch.nn.Parameter(torch.zeros(32, 32), requires_grad=True)
-    opt1 = DecentralizedSGD(
-        [param1],
-        lr=0.1,
-        dht=DHT(initial_peers=initial_peers, start=True),
-        prefix="foo",
-        target_group_size=2,
-        verbose=True,
-    )
-
-    param2 = torch.nn.Parameter(torch.ones(32, 32), requires_grad=True)
-    opt2 = DecentralizedSGD(
-        [param2],
-        lr=0.05,
-        dht=DHT(initial_peers=initial_peers, start=True),
-        prefix="foo",
-        target_group_size=2,
-        verbose=True,
-    )
-
-    assert not torch.allclose(param1, param2)
-
-    (param1.sum() + 300 * param2.sum()).backward()
-
-    for i in range(5):
-        time.sleep(0.1)
-        opt1.step()
-        opt2.step()
-        opt1.zero_grad()
-        opt2.zero_grad()
-
-    assert torch.allclose(param1, param2)
-    reference = 0.5 * (0.0 - 0.1 * 1.0) + 0.5 * (1.0 - 0.05 * 300)
-    assert torch.allclose(param1, torch.full_like(param1, reference))
-
-
-@pytest.mark.skip(reason="Skipped until finishing a more stable averager implementation (TODO @justheuristic)")
-@pytest.mark.forked
-def test_decentralized_optimizer_averaging():
-    dht_root = DHT(start=True)
-    initial_peers = dht_root.get_visible_maddrs()
-
-    param1 = torch.nn.Parameter(torch.zeros(32, 32), requires_grad=True)
-    opt1 = DecentralizedAdam(
-        [param1],
-        lr=0.1,
-        averaging_steps_period=1,
-        dht=DHT(initial_peers=initial_peers, start=True),
-        prefix="foo",
-        target_group_size=2,
-        verbose=True,
-    )
-
-    param2 = torch.nn.Parameter(torch.ones(32, 32), requires_grad=True)
-    opt2 = DecentralizedAdam(
-        [param2],
-        lr=0.05,
-        averaging_steps_period=1,
-        dht=DHT(initial_peers=initial_peers, start=True),
-        prefix="foo",
-        target_group_size=2,
-        verbose=True,
-    )
-
-    assert not torch.allclose(param1, param2, atol=1e-3, rtol=0)
-    (param1.sum() + param2.sum()).backward()
-
-    for _ in range(100):
-        time.sleep(0.1)
-        opt1.step()
-        opt2.step()
-        opt1.zero_grad()
-        opt2.zero_grad()
-
-    assert torch.allclose(param1, param2, atol=1e-3, rtol=0)
-    assert torch.allclose(opt1.state[param1]["exp_avg_sq"], opt2.state[param2]["exp_avg_sq"], atol=1e-3, rtol=0)
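For reference, the expected value in the removed test_decentralized_optimizer_step is just one local SGD step per peer followed by parameter averaging. A self-contained sketch of that arithmetic in plain torch (no DHT, a hypothetical local re-derivation rather than hivemind's actual averaging code):

import torch

param1 = torch.zeros(32, 32, requires_grad=True)  # peer 1 starts at 0
param2 = torch.ones(32, 32, requires_grad=True)   # peer 2 starts at 1

# shared loss: d/d(param1) = 1, d/d(param2) = 300
(param1.sum() + 300 * param2.sum()).backward()

with torch.no_grad():
    param1 -= 0.1 * param1.grad    # peer 1: lr = 0.1
    param2 -= 0.05 * param2.grad   # peer 2: lr = 0.05
    averaged = 0.5 * (param1 + param2)

reference = 0.5 * (0.0 - 0.1 * 1.0) + 0.5 * (1.0 - 0.05 * 300)
assert torch.allclose(averaged, torch.full_like(averaged, reference))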