1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192 |
- import dataclasses
- from typing import Dict
- import pytest
- from pydantic import BaseModel, StrictInt
- import hivemind
- from hivemind.dht.crypto import RSASignatureValidator
- from hivemind.dht.protocol import DHTProtocol
- from hivemind.dht.routing import DHTID
- from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
- from hivemind.dht.validation import DHTRecord, CompositeValidator
class SchemaA(BaseModel):
    """Schema owned by application 'A': declares one DHT key, ``field_a``, holding raw bytes."""

    field_a: bytes
class SchemaB(BaseModel):
    """Schema owned by application 'B': declares one DHT key, ``field_b``.

    The value is a dictionary mapping ``BytesWithPublicKey`` subkeys to
    strictly-typed integers.
    """

    field_b: Dict[BytesWithPublicKey, StrictInt]
@pytest.fixture
def validators_for_app():
    """Provide a per-application mapping of record validator lists.

    Each application may bring its own validator set; the two lists are
    intentionally ordered differently (signature validator first for 'A',
    schema validator first for 'B').
    """
    app_a_validators = [
        RSASignatureValidator(),
        SchemaValidator(SchemaA, allow_extra_keys=False),
    ]
    app_b_validators = [
        SchemaValidator(SchemaB, allow_extra_keys=False),
        RSASignatureValidator(),
    ]
    return {'A': app_a_validators, 'B': app_b_validators}
def test_composite_validator(validators_for_app):
    """Check that CompositeValidator merges validators and signs/validates records."""
    # Regardless of the order in which validators were supplied, the composite
    # is expected to hold exactly one schema validator followed by one
    # signature validator.
    expected_types = [SchemaValidator, RSASignatureValidator]

    composite = CompositeValidator(validators_for_app['A'])
    assert [type(member) for member in composite._validators] == expected_types

    # Adding the validators of app 'B' must merge them into the existing ones
    # instead of appending duplicates.
    composite.extend(validators_for_app['B'])
    assert [type(member) for member in composite._validators] == expected_types
    assert len(composite._validators[0]._schemas) == 2

    public_key = validators_for_app['A'][0].local_public_key

    # --- A record that matches SchemaB and is owned by the local public key ---
    owned_record = DHTRecord(
        key=DHTID.generate(source='field_b').to_bytes(),
        subkey=DHTProtocol.serializer.dumps(public_key),
        value=DHTProtocol.serializer.dumps(777),
        expiration_time=hivemind.get_dht_time() + 10,
    )
    signed_owned = dataclasses.replace(
        owned_record, value=composite.sign_value(owned_record))

    # Only one signature is expected, since the two RSASignatureValidators have been merged
    assert signed_owned.value.count(b'[signature:') == 1
    # Validation is expected to succeed, since the second SchemaValidator was merged into the first
    assert composite.validate(signed_owned)
    assert composite.strip_value(signed_owned) == owned_record.value

    # --- A record whose key is not declared by any schema ---
    stray_record = DHTRecord(
        key=DHTID.generate(source='unknown_key').to_bytes(),
        subkey=DHTProtocol.IS_REGULAR_VALUE,
        value=DHTProtocol.serializer.dumps(777),
        expiration_time=hivemind.get_dht_time() + 10,
    )
    signed_stray = dataclasses.replace(
        stray_record, value=composite.sign_value(stray_record))

    # No signature is expected here (presumably because this record carries
    # no public key in its key or subkey)
    assert signed_stray.value.count(b'[signature:') == 0
    # Validation is expected to fail, since `unknown_key` is not a part of any schema
    assert not composite.validate(signed_stray)
@pytest.mark.forked
def test_dht_add_validators(validators_for_app):
    """Check that validators can be added to a DHT only after it has been started.

    Once the validators of both applications are installed, records matching
    either schema are accepted, while records that violate a schema or use a
    key unknown to every schema are rejected.
    """
    # One app may create a DHT with its validators
    dht = hivemind.DHT(start=False, record_validators=validators_for_app['A'])

    try:
        # While the DHT process is not started, you can't send a command to append new validators
        with pytest.raises(RuntimeError):
            dht.add_validators(validators_for_app['B'])
        dht.run_in_background(await_ready=True)

        # After starting the process, other apps may add new validators to the existing DHT
        dht.add_validators(validators_for_app['B'])

        # A bytes value satisfies SchemaA and is stored ...
        assert dht.store('field_a', b'bytes_value', hivemind.get_dht_time() + 10)
        assert dht.get('field_a', latest=True).value == b'bytes_value'

        # ... while an integer violates it, so the previous value must stay intact
        assert not dht.store('field_a', 666, hivemind.get_dht_time() + 10)
        assert dht.get('field_a', latest=True).value == b'bytes_value'

        # SchemaB accepts an integer stored under the local public key as a subkey
        local_public_key = validators_for_app['A'][0].local_public_key
        assert dht.store('field_b', 777, hivemind.get_dht_time() + 10, subkey=local_public_key)
        dictionary = dht.get('field_b', latest=True).value
        assert (len(dictionary) == 1 and
                dictionary[local_public_key].value == 777)

        # Keys that are not declared by any schema are rejected altogether
        assert not dht.store('unknown_key', 666, hivemind.get_dht_time() + 10)
        assert dht.get('unknown_key', latest=True) is None
    finally:
        # Always stop the background DHT process, even when an assertion fails,
        # so that the test does not leak a running process. The guard covers
        # the case where the failure happened before the process was started.
        if dht.is_alive():
            dht.shutdown()
|