test_dht_validation.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. import dataclasses
  2. from functools import partial
  3. from typing import Dict
  4. import pytest
  5. from pydantic import BaseModel, StrictInt
  6. import hivemind
  7. from hivemind.dht.crypto import RSASignatureValidator
  8. from hivemind.dht.protocol import DHTProtocol
  9. from hivemind.dht.routing import DHTID
  10. from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
  11. from hivemind.dht.validation import DHTRecord, CompositeValidator, RecordValidatorBase
  12. class SchemaA(BaseModel):
  13. field_a: bytes
  14. class SchemaB(BaseModel):
  15. field_b: Dict[BytesWithPublicKey, StrictInt]
  16. @pytest.fixture
  17. def validators_for_app():
  18. # Each application may add its own validator set
  19. return {
  20. 'A': [RSASignatureValidator(), SchemaValidator(SchemaA, allow_extra_keys=False)],
  21. 'B': [SchemaValidator(SchemaB, allow_extra_keys=False), RSASignatureValidator()],
  22. }
  23. def test_composite_validator(validators_for_app):
  24. validator = CompositeValidator(validators_for_app['A'])
  25. assert ([type(item) for item in validator._validators] ==
  26. [SchemaValidator, RSASignatureValidator])
  27. validator.extend(validators_for_app['B'])
  28. assert ([type(item) for item in validator._validators] ==
  29. [SchemaValidator, RSASignatureValidator])
  30. assert len(validator._validators[0]._schemas) == 2
  31. local_public_key = validators_for_app['A'][0].local_public_key
  32. record = DHTRecord(key=DHTID.generate(source='field_b').to_bytes(),
  33. subkey=DHTProtocol.serializer.dumps(local_public_key),
  34. value=DHTProtocol.serializer.dumps(777),
  35. expiration_time=hivemind.get_dht_time() + 10)
  36. signed_record = dataclasses.replace(record, value=validator.sign_value(record))
  37. # Expect only one signature since two RSASignatureValidatos have been merged
  38. assert signed_record.value.count(b'[signature:') == 1
  39. # Expect successful validation since the second SchemaValidator has been merged to the first
  40. assert validator.validate(signed_record)
  41. assert validator.strip_value(signed_record) == record.value
  42. record = DHTRecord(key=DHTID.generate(source='unknown_key').to_bytes(),
  43. subkey=DHTProtocol.IS_REGULAR_VALUE,
  44. value=DHTProtocol.serializer.dumps(777),
  45. expiration_time=hivemind.get_dht_time() + 10)
  46. signed_record = dataclasses.replace(record, value=validator.sign_value(record))
  47. assert signed_record.value.count(b'[signature:') == 0
  48. # Expect failed validation since `unknown_key` is not a part of any schema
  49. assert not validator.validate(signed_record)
  50. @pytest.mark.forked
  51. def test_dht_add_validators(validators_for_app):
  52. # One app may create a DHT with its validators
  53. dht = hivemind.DHT(start=False, record_validators=validators_for_app['A'])
  54. # While the DHT process is not started, you can't send a command to append new validators
  55. with pytest.raises(RuntimeError):
  56. dht.add_validators(validators_for_app['B'])
  57. dht.run_in_background(await_ready=True)
  58. # After starting the process, other apps may add new validators to the existing DHT
  59. dht.add_validators(validators_for_app['B'])
  60. assert dht.store('field_a', b'bytes_value', hivemind.get_dht_time() + 10)
  61. assert dht.get('field_a', latest=True).value == b'bytes_value'
  62. assert not dht.store('field_a', 666, hivemind.get_dht_time() + 10)
  63. assert dht.get('field_a', latest=True).value == b'bytes_value'
  64. local_public_key = validators_for_app['A'][0].local_public_key
  65. assert dht.store('field_b', 777, hivemind.get_dht_time() + 10, subkey=local_public_key)
  66. dictionary = dht.get('field_b', latest=True).value
  67. assert (len(dictionary) == 1 and
  68. dictionary[local_public_key].value == 777)
  69. assert not dht.store('unknown_key', 666, hivemind.get_dht_time() + 10)
  70. assert dht.get('unknown_key', latest=True) is None