|
@@ -2,22 +2,24 @@ import re
|
|
|
|
|
|
import pytest
|
|
import pytest
|
|
from pydantic import BaseModel, StrictFloat, StrictInt, conint
|
|
from pydantic import BaseModel, StrictFloat, StrictInt, conint
|
|
-from typing import Dict, List
|
|
|
|
|
|
+from typing import Dict
|
|
|
|
|
|
|
|
+import hivemind
|
|
from hivemind.dht import get_dht_time
|
|
from hivemind.dht import get_dht_time
|
|
from hivemind.dht.node import DHTNode, LOCALHOST
|
|
from hivemind.dht.node import DHTNode, LOCALHOST
|
|
-from hivemind.dht.schema import SchemaValidator, conbytes
|
|
|
|
|
|
+from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator, conbytes
|
|
from hivemind.dht.validation import DHTRecord, RecordValidatorBase
|
|
from hivemind.dht.validation import DHTRecord, RecordValidatorBase
|
|
|
|
|
|
|
|
|
|
|
|
+class SampleSchema(BaseModel):
|
|
|
|
+ experiment_name: bytes
|
|
|
|
+ n_batches: Dict[bytes, conint(ge=0, strict=True)]
|
|
|
|
+ signed_data: Dict[BytesWithPublicKey, bytes]
|
|
|
|
+
|
|
|
|
+
|
|
@pytest.fixture
|
|
@pytest.fixture
|
|
async def dht_nodes_with_schema():
|
|
async def dht_nodes_with_schema():
|
|
- class Schema(BaseModel):
|
|
|
|
- experiment_name: bytes
|
|
|
|
- n_batches: Dict[bytes, conint(ge=0, strict=True)]
|
|
|
|
- signed_data: Dict[conbytes(regex=rb'.*\[owner:.+\]'), bytes]
|
|
|
|
-
|
|
|
|
- validator = SchemaValidator(Schema)
|
|
|
|
|
|
+ validator = SchemaValidator(SampleSchema)
|
|
|
|
|
|
alice = await DHTNode.create(record_validator=validator)
|
|
alice = await DHTNode.create(record_validator=validator)
|
|
bob = await DHTNode.create(
|
|
bob = await DHTNode.create(
|
|
@@ -31,17 +33,17 @@ async def test_expecting_regular_value(dht_nodes_with_schema):
|
|
alice, bob = dht_nodes_with_schema
|
|
alice, bob = dht_nodes_with_schema
|
|
|
|
|
|
# Regular value (bytes) expected
|
|
# Regular value (bytes) expected
|
|
- assert await bob.store(b'experiment_name', b'foo_bar', get_dht_time() + 10)
|
|
|
|
- assert not await bob.store(b'experiment_name', 666, get_dht_time() + 10)
|
|
|
|
- assert not await bob.store(b'experiment_name', b'foo_bar', get_dht_time() + 10,
|
|
|
|
|
|
+ assert await bob.store('experiment_name', b'foo_bar', get_dht_time() + 10)
|
|
|
|
+ assert not await bob.store('experiment_name', 666, get_dht_time() + 10)
|
|
|
|
+ assert not await bob.store('experiment_name', b'foo_bar', get_dht_time() + 10,
|
|
subkey=b'subkey')
|
|
subkey=b'subkey')
|
|
|
|
|
|
# Refuse records despite https://pydantic-docs.helpmanual.io/usage/models/#data-conversion
|
|
# Refuse records despite https://pydantic-docs.helpmanual.io/usage/models/#data-conversion
|
|
- assert not await bob.store(b'experiment_name', [], get_dht_time() + 10)
|
|
|
|
- assert not await bob.store(b'experiment_name', [1, 2, 3], get_dht_time() + 10)
|
|
|
|
|
|
+ assert not await bob.store('experiment_name', [], get_dht_time() + 10)
|
|
|
|
+ assert not await bob.store('experiment_name', [1, 2, 3], get_dht_time() + 10)
|
|
|
|
|
|
for peer in [alice, bob]:
|
|
for peer in [alice, bob]:
|
|
- assert (await peer.get(b'experiment_name', latest=True)).value == b'foo_bar'
|
|
|
|
|
|
+ assert (await peer.get('experiment_name', latest=True)).value == b'foo_bar'
|
|
|
|
|
|
|
|
|
|
@pytest.mark.forked
|
|
@pytest.mark.forked
|
|
@@ -50,27 +52,27 @@ async def test_expecting_dictionary(dht_nodes_with_schema):
|
|
alice, bob = dht_nodes_with_schema
|
|
alice, bob = dht_nodes_with_schema
|
|
|
|
|
|
# Dictionary (bytes -> non-negative int) expected
|
|
# Dictionary (bytes -> non-negative int) expected
|
|
- assert await bob.store(b'n_batches', 777, get_dht_time() + 10, subkey=b'uid1')
|
|
|
|
- assert await bob.store(b'n_batches', 778, get_dht_time() + 10, subkey=b'uid2')
|
|
|
|
- assert not await bob.store(b'n_batches', -666, get_dht_time() + 10, subkey=b'uid3')
|
|
|
|
- assert not await bob.store(b'n_batches', 666, get_dht_time() + 10)
|
|
|
|
- assert not await bob.store(b'n_batches', b'not_integer', get_dht_time() + 10, subkey=b'uid1')
|
|
|
|
- assert not await bob.store(b'n_batches', 666, get_dht_time() + 10, subkey=666)
|
|
|
|
|
|
+ assert await bob.store('n_batches', 777, get_dht_time() + 10, subkey=b'uid1')
|
|
|
|
+ assert await bob.store('n_batches', 778, get_dht_time() + 10, subkey=b'uid2')
|
|
|
|
+ assert not await bob.store('n_batches', -666, get_dht_time() + 10, subkey=b'uid3')
|
|
|
|
+ assert not await bob.store('n_batches', 666, get_dht_time() + 10)
|
|
|
|
+ assert not await bob.store('n_batches', b'not_integer', get_dht_time() + 10, subkey=b'uid1')
|
|
|
|
+ assert not await bob.store('n_batches', 666, get_dht_time() + 10, subkey=666)
|
|
|
|
|
|
# Refuse storing a plain dictionary bypassing the DictionaryDHTValue convention
|
|
# Refuse storing a plain dictionary bypassing the DictionaryDHTValue convention
|
|
- assert not await bob.store(b'n_batches', {b'uid3': 779}, get_dht_time() + 10)
|
|
|
|
|
|
+ assert not await bob.store('n_batches', {b'uid3': 779}, get_dht_time() + 10)
|
|
|
|
|
|
# Refuse records despite https://pydantic-docs.helpmanual.io/usage/models/#data-conversion
|
|
# Refuse records despite https://pydantic-docs.helpmanual.io/usage/models/#data-conversion
|
|
- assert not await bob.store(b'n_batches', 779.5, get_dht_time() + 10, subkey=b'uid3')
|
|
|
|
- assert not await bob.store(b'n_batches', 779.0, get_dht_time() + 10, subkey=b'uid3')
|
|
|
|
- assert not await bob.store(b'n_batches', [], get_dht_time() + 10)
|
|
|
|
- assert not await bob.store(b'n_batches', [(b'uid3', 779)], get_dht_time() + 10)
|
|
|
|
|
|
+ assert not await bob.store('n_batches', 779.5, get_dht_time() + 10, subkey=b'uid3')
|
|
|
|
+ assert not await bob.store('n_batches', 779.0, get_dht_time() + 10, subkey=b'uid3')
|
|
|
|
+ assert not await bob.store('n_batches', [], get_dht_time() + 10)
|
|
|
|
+ assert not await bob.store('n_batches', [(b'uid3', 779)], get_dht_time() + 10)
|
|
|
|
|
|
# Refuse records despite https://github.com/samuelcolvin/pydantic/issues/1268
|
|
# Refuse records despite https://github.com/samuelcolvin/pydantic/issues/1268
|
|
- assert not await bob.store(b'n_batches', '', get_dht_time() + 10)
|
|
|
|
|
|
+ assert not await bob.store('n_batches', '', get_dht_time() + 10)
|
|
|
|
|
|
for peer in [alice, bob]:
|
|
for peer in [alice, bob]:
|
|
- dictionary = (await peer.get(b'n_batches', latest=True)).value
|
|
|
|
|
|
+ dictionary = (await peer.get('n_batches', latest=True)).value
|
|
assert (len(dictionary) == 2 and
|
|
assert (len(dictionary) == 2 and
|
|
dictionary[b'uid1'].value == 777 and
|
|
dictionary[b'uid1'].value == 777 and
|
|
dictionary[b'uid2'].value == 778)
|
|
dictionary[b'uid2'].value == 778)
|
|
@@ -83,13 +85,13 @@ async def test_expecting_public_keys(dht_nodes_with_schema):
|
|
|
|
|
|
# Subkeys expected to contain a public key
|
|
# Subkeys expected to contain a public key
|
|
# (so hivemind.dht.crypto.RSASignatureValidator would require a signature)
|
|
# (so hivemind.dht.crypto.RSASignatureValidator would require a signature)
|
|
- assert await bob.store(b'signed_data', b'foo_bar', get_dht_time() + 10,
|
|
|
|
|
|
+ assert await bob.store('signed_data', b'foo_bar', get_dht_time() + 10,
|
|
subkey=b'uid[owner:public-key]')
|
|
subkey=b'uid[owner:public-key]')
|
|
- assert not await bob.store(b'signed_data', b'foo_bar', get_dht_time() + 10,
|
|
|
|
|
|
+ assert not await bob.store('signed_data', b'foo_bar', get_dht_time() + 10,
|
|
subkey=b'uid-without-public-key')
|
|
subkey=b'uid-without-public-key')
|
|
|
|
|
|
for peer in [alice, bob]:
|
|
for peer in [alice, bob]:
|
|
- dictionary = (await peer.get(b'signed_data', latest=True)).value
|
|
|
|
|
|
+ dictionary = (await peer.get('signed_data', latest=True)).value
|
|
assert (len(dictionary) == 1 and
|
|
assert (len(dictionary) == 1 and
|
|
dictionary[b'uid[owner:public-key]'].value == b'foo_bar')
|
|
dictionary[b'uid[owner:public-key]'].value == b'foo_bar')
|
|
|
|
|
|
@@ -111,17 +113,38 @@ async def test_keys_outside_schema(dht_nodes_with_schema):
|
|
bob = await DHTNode.create(
|
|
bob = await DHTNode.create(
|
|
record_validator=validator, initial_peers=[f"{LOCALHOST}:{alice.port}"])
|
|
record_validator=validator, initial_peers=[f"{LOCALHOST}:{alice.port}"])
|
|
|
|
|
|
- store_ok = await bob.store(b'unknown_key', b'foo_bar', get_dht_time() + 10)
|
|
|
|
|
|
+ store_ok = await bob.store('unknown_key', b'foo_bar', get_dht_time() + 10)
|
|
assert store_ok == allow_extra_keys
|
|
assert store_ok == allow_extra_keys
|
|
|
|
|
|
for peer in [alice, bob]:
|
|
for peer in [alice, bob]:
|
|
- result = await peer.get(b'unknown_key', latest=True)
|
|
|
|
|
|
+ result = await peer.get('unknown_key', latest=True)
|
|
if allow_extra_keys:
|
|
if allow_extra_keys:
|
|
assert result.value == b'foo_bar'
|
|
assert result.value == b'foo_bar'
|
|
else:
|
|
else:
|
|
assert result is None
|
|
assert result is None
|
|
|
|
|
|
|
|
|
|
|
|
+@pytest.mark.forked
|
|
|
|
+@pytest.mark.asyncio
|
|
|
|
+async def test_prefix():
|
|
|
|
+ class Schema(BaseModel):
|
|
|
|
+ field: StrictInt
|
|
|
|
+
|
|
|
|
+ validator = SchemaValidator(Schema, allow_extra_keys=False, prefix='prefix')
|
|
|
|
+
|
|
|
|
+ alice = await DHTNode.create(record_validator=validator)
|
|
|
|
+ bob = await DHTNode.create(
|
|
|
|
+ record_validator=validator, initial_peers=[f"{LOCALHOST}:{alice.port}"])
|
|
|
|
+
|
|
|
|
+ assert await bob.store('prefix_field', 777, get_dht_time() + 10)
|
|
|
|
+ assert not await bob.store('prefix_field', 'string_value', get_dht_time() + 10)
|
|
|
|
+ assert not await bob.store('field', 777, get_dht_time() + 10)
|
|
|
|
+
|
|
|
|
+ for peer in [alice, bob]:
|
|
|
|
+ assert (await peer.get('prefix_field', latest=True)).value == 777
|
|
|
|
+ assert (await peer.get('field', latest=True)) is None
|
|
|
|
+
|
|
|
|
+
|
|
@pytest.mark.forked
|
|
@pytest.mark.forked
|
|
@pytest.mark.asyncio
|
|
@pytest.mark.asyncio
|
|
async def test_merging_schema_validators(dht_nodes_with_schema):
|
|
async def test_merging_schema_validators(dht_nodes_with_schema):
|
|
@@ -147,18 +170,31 @@ async def test_merging_schema_validators(dht_nodes_with_schema):
|
|
for peer in [alice, bob]:
|
|
for peer in [alice, bob]:
|
|
assert peer.protocol.record_validator.merge_with(new_validator)
|
|
assert peer.protocol.record_validator.merge_with(new_validator)
|
|
|
|
|
|
- assert await bob.store(b'experiment_name', b'foo_bar', get_dht_time() + 10)
|
|
|
|
- assert await bob.store(b'some_field', 777, get_dht_time() + 10)
|
|
|
|
- assert not await bob.store(b'some_field', 'string_value', get_dht_time() + 10)
|
|
|
|
- assert await bob.store(b'another_field', 42, get_dht_time() + 10)
|
|
|
|
- assert await bob.store(b'another_field', 'string_value', get_dht_time() + 10)
|
|
|
|
|
|
+ assert await bob.store('experiment_name', b'foo_bar', get_dht_time() + 10)
|
|
|
|
+ assert await bob.store('some_field', 777, get_dht_time() + 10)
|
|
|
|
+ assert not await bob.store('some_field', 'string_value', get_dht_time() + 10)
|
|
|
|
+ assert await bob.store('another_field', 42, get_dht_time() + 10)
|
|
|
|
+ assert await bob.store('another_field', 'string_value', get_dht_time() + 10)
|
|
|
|
|
|
- # Unkown keys are allowed since the first schema is created with allow_extra_keys=True
|
|
|
|
- assert await bob.store(b'unknown_key', 999, get_dht_time() + 10)
|
|
|
|
|
|
+ # Unknown keys are allowed since the first schema is created with allow_extra_keys=True
|
|
|
|
+ assert await bob.store('unknown_key', 999, get_dht_time() + 10)
|
|
|
|
|
|
for peer in [alice, bob]:
|
|
for peer in [alice, bob]:
|
|
- assert (await peer.get(b'experiment_name', latest=True)).value == b'foo_bar'
|
|
|
|
- assert (await peer.get(b'some_field', latest=True)).value == 777
|
|
|
|
- assert (await peer.get(b'another_field', latest=True)).value == 'string_value'
|
|
|
|
|
|
+ assert (await peer.get('experiment_name', latest=True)).value == b'foo_bar'
|
|
|
|
+ assert (await peer.get('some_field', latest=True)).value == 777
|
|
|
|
+ assert (await peer.get('another_field', latest=True)).value == 'string_value'
|
|
|
|
+
|
|
|
|
+ assert (await peer.get('unknown_key', latest=True)).value == 999
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+@pytest.mark.forked
|
|
|
|
+def test_sending_validator_instance_between_processes():
|
|
|
|
+ alice = hivemind.DHT(start=True)
|
|
|
|
+ bob = hivemind.DHT(start=True, initial_peers=[f"{LOCALHOST}:{alice.port}"])
|
|
|
|
+
|
|
|
|
+ alice.add_validators([SchemaValidator(SampleSchema)])
|
|
|
|
+ bob.add_validators([SchemaValidator(SampleSchema)])
|
|
|
|
|
|
- assert (await peer.get(b'unknown_key', latest=True)).value == 999
|
|
|
|
|
|
+ assert bob.store('experiment_name', b'foo_bar', get_dht_time() + 10)
|
|
|
|
+ assert not bob.store('experiment_name', 777, get_dht_time() + 10)
|
|
|
|
+ assert alice.get('experiment_name', latest=True).value == b'foo_bar'
|