123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198 |
- import asyncio
- from typing import Dict
- import pytest
- from pydantic import BaseModel, StrictInt, conint
- import hivemind
- from hivemind.dht.node import DHTNode
- from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
- from hivemind.dht.validation import DHTRecord, RecordValidatorBase
- from hivemind.utils.timed_storage import get_dht_time
# Strict (no type coercion) integer constrained to be >= 0
_NonNegativeStrictInt = conint(ge=0, strict=True)


class SampleSchema(BaseModel):
    """Schema shared by the tests in this module; each field is used as a DHT key."""

    # A single bytes value
    experiment_name: bytes
    # A dictionary: bytes subkey -> strict non-negative integer
    n_batches: Dict[bytes, _NonNegativeStrictInt]
    # A dictionary whose subkeys are of type BytesWithPublicKey
    signed_data: Dict[BytesWithPublicKey, bytes]
@pytest.fixture
async def dht_nodes_with_schema():
    """Yield two connected DHT nodes that share one ``SchemaValidator(SampleSchema)``.

    The second node is created with the first node's visible multiaddresses as its
    initial peers. Both nodes are shut down concurrently once the test is finished.
    """
    schema_validator = SchemaValidator(SampleSchema)

    first_node = await DHTNode.create(record_validator=schema_validator)
    bootstrap_maddrs = await first_node.get_visible_maddrs()
    second_node = await DHTNode.create(record_validator=schema_validator, initial_peers=bootstrap_maddrs)

    yield first_node, second_node

    await asyncio.gather(first_node.shutdown(), second_node.shutdown())
@pytest.mark.forked
@pytest.mark.asyncio
async def test_expecting_regular_value(dht_nodes_with_schema):
    """Check that ``experiment_name`` can only be stored as a plain bytes value without a subkey."""
    alice, bob = dht_nodes_with_schema
    key = "experiment_name"

    # Regular value (bytes) expected
    assert await bob.store(key, b"foo_bar", get_dht_time() + 10)
    assert not await bob.store(key, 666, get_dht_time() + 10)
    assert not await bob.store(key, b"foo_bar", get_dht_time() + 10, subkey=b"subkey")

    # Refuse records despite https://pydantic-docs.helpmanual.io/usage/models/#data-conversion
    for list_value in ([], [1, 2, 3]):
        assert not await bob.store(key, list_value, get_dht_time() + 10)

    # Both nodes are expected to return the only value that was stored successfully
    for node in (alice, bob):
        found = await node.get(key, latest=True)
        assert found.value == b"foo_bar"
@pytest.mark.forked
@pytest.mark.asyncio
async def test_expecting_dictionary(dht_nodes_with_schema):
    """Check that ``n_batches`` only accepts strict non-negative ints stored under bytes subkeys."""
    alice, bob = dht_nodes_with_schema
    key = "n_batches"

    def expiration():
        # Evaluated anew for every store, exactly when its arguments are built
        return get_dht_time() + 10

    # Dictionary (bytes -> non-negative int) expected
    assert await bob.store(key, 777, expiration(), subkey=b"uid1")
    assert await bob.store(key, 778, expiration(), subkey=b"uid2")
    assert not await bob.store(key, -666, expiration(), subkey=b"uid3")
    assert not await bob.store(key, 666, expiration())
    assert not await bob.store(key, b"not_integer", expiration(), subkey=b"uid1")
    assert not await bob.store(key, 666, expiration(), subkey=666)

    # Refuse storing a plain dictionary bypassing the DictionaryDHTValue convention
    assert not await bob.store(key, {b"uid3": 779}, expiration())

    # Refuse records despite https://pydantic-docs.helpmanual.io/usage/models/#data-conversion
    for float_value in (779.5, 779.0):
        assert not await bob.store(key, float_value, expiration(), subkey=b"uid3")
    for list_value in ([], [(b"uid3", 779)]):
        assert not await bob.store(key, list_value, expiration())

    # Refuse records despite https://github.com/samuelcolvin/pydantic/issues/1268
    assert not await bob.store(key, "", expiration())

    # Only the two successful stores are expected to be visible on either node
    for node in (alice, bob):
        batches = (await node.get(key, latest=True)).value
        assert len(batches) == 2
        assert batches[b"uid1"].value == 777
        assert batches[b"uid2"].value == 778
@pytest.mark.forked
@pytest.mark.asyncio
async def test_expecting_public_keys(dht_nodes_with_schema):
    """Check that ``signed_data`` subkeys are accepted only when they contain a public key marker."""
    alice, bob = dht_nodes_with_schema
    key = "signed_data"
    subkey_with_public_key = b"uid[owner:public-key]"

    # Subkeys expected to contain a public key
    # (so hivemind.dht.crypto.RSASignatureValidator would require a signature)
    assert await bob.store(key, b"foo_bar", get_dht_time() + 10, subkey=subkey_with_public_key)
    assert not await bob.store(key, b"foo_bar", get_dht_time() + 10, subkey=b"uid-without-public-key")

    # Only the record with the accepted subkey is expected on either node
    for node in (alice, bob):
        signed_records = (await node.get(key, latest=True)).value
        assert len(signed_records) == 1
        assert signed_records[subkey_with_public_key].value == b"foo_bar"
@pytest.mark.forked
@pytest.mark.asyncio
async def test_keys_outside_schema(dht_nodes_with_schema):
    """Check how keys missing from the schema are treated for both ``allow_extra_keys`` settings.

    For each setting, a fresh pair of nodes is created with a validator built from ``Schema``
    and merged with a ``MergedSchema`` validator that has ``allow_extra_keys=False``. Storing
    ``unknown_key`` is expected to succeed exactly when the first validator allows extra keys.

    NOTE: ``dht_nodes_with_schema`` is requested but the nodes it provides are not used here;
    the parameter is kept so that the test signature stays unchanged.
    """

    class Schema(BaseModel):
        some_field: StrictInt

    class MergedSchema(BaseModel):
        another_field: StrictInt

    for allow_extra_keys in [False, True]:
        validator = SchemaValidator(Schema, allow_extra_keys=allow_extra_keys)
        assert validator.merge_with(SchemaValidator(MergedSchema, allow_extra_keys=False))

        # Track every node that was actually created, so that all of them are shut down
        # even if node creation or one of the assertions below fails
        nodes = []
        try:
            alice = await DHTNode.create(record_validator=validator)
            nodes.append(alice)
            bob = await DHTNode.create(record_validator=validator, initial_peers=await alice.get_visible_maddrs())
            nodes.append(bob)

            store_ok = await bob.store("unknown_key", b"foo_bar", get_dht_time() + 10)
            assert store_ok == allow_extra_keys

            for peer in [alice, bob]:
                result = await peer.get("unknown_key", latest=True)
                if allow_extra_keys:
                    assert result.value == b"foo_bar"
                else:
                    assert result is None
        finally:
            await asyncio.gather(*(node.shutdown() for node in nodes))
@pytest.mark.forked
@pytest.mark.asyncio
async def test_prefix():
    """Check that a validator created with ``prefix="prefix"`` expects keys of the form ``prefix_<field>``.

    With ``allow_extra_keys=False``, storing under the bare ``field`` key is expected to fail,
    while ``prefix_field`` is expected to accept only a strict integer.
    """

    class Schema(BaseModel):
        field: StrictInt

    validator = SchemaValidator(Schema, allow_extra_keys=False, prefix="prefix")

    # Track every node that was actually created, so that all of them are shut down
    # even if node creation or one of the assertions below fails
    nodes = []
    try:
        alice = await DHTNode.create(record_validator=validator)
        nodes.append(alice)
        bob = await DHTNode.create(record_validator=validator, initial_peers=await alice.get_visible_maddrs())
        nodes.append(bob)

        assert await bob.store("prefix_field", 777, get_dht_time() + 10)
        assert not await bob.store("prefix_field", "string_value", get_dht_time() + 10)
        assert not await bob.store("field", 777, get_dht_time() + 10)

        for peer in [alice, bob]:
            assert (await peer.get("prefix_field", latest=True)).value == 777
            assert (await peer.get("field", latest=True)) is None
    finally:
        await asyncio.gather(*(node.shutdown() for node in nodes))
@pytest.mark.forked
@pytest.mark.asyncio
async def test_merging_schema_validators(dht_nodes_with_schema):
    """Check merging of extra schema validators into the validators of already running nodes.

    Merging with a validator of another type is expected to be refused, while merging with
    other ``SchemaValidator`` instances is expected to succeed and to extend the set of
    accepted keys and value types.
    """
    alice, bob = dht_nodes_with_schema
    nodes = (alice, bob)

    class TrivialValidator(RecordValidatorBase):
        def validate(self, record: DHTRecord) -> bool:
            return True

    # Can't merge with the validator of the different type
    other_type_validator = TrivialValidator()
    assert not alice.protocol.record_validator.merge_with(other_type_validator)

    class SecondSchema(BaseModel):
        some_field: StrictInt
        another_field: str

    class ThirdSchema(BaseModel):
        another_field: StrictInt  # Allow it to be a StrictInt as well

    for extra_schema in (SecondSchema, ThirdSchema):
        extra_validator = SchemaValidator(extra_schema, allow_extra_keys=False)
        for node in nodes:
            assert node.protocol.record_validator.merge_with(extra_validator)

    assert await bob.store("experiment_name", b"foo_bar", get_dht_time() + 10)
    assert await bob.store("some_field", 777, get_dht_time() + 10)
    assert not await bob.store("some_field", "string_value", get_dht_time() + 10)
    assert await bob.store("another_field", 42, get_dht_time() + 10)
    assert await bob.store("another_field", "string_value", get_dht_time() + 10)

    # Unknown keys are allowed since the first schema is created with allow_extra_keys=True
    assert await bob.store("unknown_key", 999, get_dht_time() + 10)

    # Latest value expected under each key, listed in the order the keys are queried
    expected_latest = {
        "experiment_name": b"foo_bar",
        "some_field": 777,
        "another_field": "string_value",
        "unknown_key": 999,
    }
    for node in nodes:
        for key, expected_value in expected_latest.items():
            assert (await node.get(key, latest=True)).value == expected_value
@pytest.mark.forked
def test_sending_validator_instance_between_processes():
    """Check that validators added via ``hivemind.DHT.add_validators`` take effect on stores.

    Two ``hivemind.DHT`` instances each get a ``SchemaValidator(SampleSchema)``; after that,
    a bytes ``experiment_name`` is expected to be stored and an integer one to be refused.
    """
    # Track every DHT instance that was actually started, so that all of them are shut down
    # even if starting the second one or one of the assertions below fails
    started = []
    try:
        alice = hivemind.DHT(start=True)
        started.append(alice)
        bob = hivemind.DHT(start=True, initial_peers=alice.get_visible_maddrs())
        started.append(bob)

        alice.add_validators([SchemaValidator(SampleSchema)])
        bob.add_validators([SchemaValidator(SampleSchema)])

        assert bob.store("experiment_name", b"foo_bar", get_dht_time() + 10)
        assert not bob.store("experiment_name", 777, get_dht_time() + 10)
        assert alice.get("experiment_name", latest=True).value == b"foo_bar"
    finally:
        # Same order as before: alice first, then bob
        for dht in started:
            dht.shutdown()
|