test_dht_validation.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. import dataclasses
  2. from typing import Dict
  3. import pytest
  4. from pydantic import BaseModel, StrictInt
  5. import hivemind
  6. from hivemind.dht.crypto import RSASignatureValidator
  7. from hivemind.dht.protocol import DHTProtocol
  8. from hivemind.dht.routing import DHTID
  9. from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
  10. from hivemind.dht.validation import DHTRecord, CompositeValidator
# Schema for application "A": the DHT key "field_a" must hold a `bytes` value.
# Documented with comments rather than a docstring, because pydantic copies a
# class docstring into the model's schema description.
class SchemaA(BaseModel):
    field_a: bytes
# Schema for application "B": the DHT key "field_b" is a dictionary whose subkeys
# are `BytesWithPublicKey` and whose values are pydantic `StrictInt`.
# NOTE(review): BytesWithPublicKey presumably marks subkeys that embed a signer's
# public key (see hivemind.dht.schema) — confirm.
# Documented with comments rather than a docstring, because pydantic copies a
# class docstring into the model's schema description.
class SchemaB(BaseModel):
    field_b: Dict[BytesWithPublicKey, StrictInt]
  15. @pytest.fixture
  16. def validators_for_app():
  17. # Each application may add its own validator set
  18. return {
  19. "A": [RSASignatureValidator(), SchemaValidator(SchemaA, allow_extra_keys=False)],
  20. "B": [SchemaValidator(SchemaB, allow_extra_keys=False), RSASignatureValidator()],
  21. }
  22. def test_composite_validator(validators_for_app):
  23. validator = CompositeValidator(validators_for_app["A"])
  24. assert [type(item) for item in validator._validators] == [SchemaValidator, RSASignatureValidator]
  25. validator.extend(validators_for_app["B"])
  26. assert [type(item) for item in validator._validators] == [SchemaValidator, RSASignatureValidator]
  27. assert len(validator._validators[0]._schemas) == 2
  28. local_public_key = validators_for_app["A"][0].local_public_key
  29. record = DHTRecord(
  30. key=DHTID.generate(source="field_b").to_bytes(),
  31. subkey=DHTProtocol.serializer.dumps(local_public_key),
  32. value=DHTProtocol.serializer.dumps(777),
  33. expiration_time=hivemind.get_dht_time() + 10,
  34. )
  35. signed_record = dataclasses.replace(record, value=validator.sign_value(record))
  36. # Expect only one signature since two RSASignatureValidatos have been merged
  37. assert signed_record.value.count(b"[signature:") == 1
  38. # Expect successful validation since the second SchemaValidator has been merged to the first
  39. assert validator.validate(signed_record)
  40. assert validator.strip_value(signed_record) == record.value
  41. record = DHTRecord(
  42. key=DHTID.generate(source="unknown_key").to_bytes(),
  43. subkey=DHTProtocol.IS_REGULAR_VALUE,
  44. value=DHTProtocol.serializer.dumps(777),
  45. expiration_time=hivemind.get_dht_time() + 10,
  46. )
  47. signed_record = dataclasses.replace(record, value=validator.sign_value(record))
  48. assert signed_record.value.count(b"[signature:") == 0
  49. # Expect failed validation since `unknown_key` is not a part of any schema
  50. assert not validator.validate(signed_record)
  51. @pytest.mark.forked
  52. def test_dht_add_validators(validators_for_app):
  53. # One app may create a DHT with its validators
  54. dht = hivemind.DHT(start=False, record_validators=validators_for_app["A"])
  55. # While the DHT process is not started, you can't send a command to append new validators
  56. with pytest.raises(RuntimeError):
  57. dht.add_validators(validators_for_app["B"])
  58. dht.run_in_background(await_ready=True)
  59. # After starting the process, other apps may add new validators to the existing DHT
  60. dht.add_validators(validators_for_app["B"])
  61. assert dht.store("field_a", b"bytes_value", hivemind.get_dht_time() + 10)
  62. assert dht.get("field_a", latest=True).value == b"bytes_value"
  63. assert not dht.store("field_a", 666, hivemind.get_dht_time() + 10)
  64. assert dht.get("field_a", latest=True).value == b"bytes_value"
  65. local_public_key = validators_for_app["A"][0].local_public_key
  66. assert dht.store("field_b", 777, hivemind.get_dht_time() + 10, subkey=local_public_key)
  67. dictionary = dht.get("field_b", latest=True).value
  68. assert len(dictionary) == 1 and dictionary[local_public_key].value == 777
  69. assert not dht.store("unknown_key", 666, hivemind.get_dht_time() + 10)
  70. assert dht.get("unknown_key", latest=True) is None