validation.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. import dataclasses
  2. from abc import ABC, abstractmethod
  3. from typing import Iterable
  4. @dataclasses.dataclass(init=True, repr=True, frozen=True)
  5. class DHTRecord:
  6. key: bytes
  7. subkey: bytes
  8. value: bytes
  9. expiration_time: float
  10. class RecordValidatorBase(ABC):
  11. """
  12. Record validators are a generic mechanism for checking the DHT records including:
  13. - Enforcing a data schema (e.g. checking content types)
  14. - Enforcing security requirements (e.g. allowing only the owner to update the record)
  15. """
  16. @abstractmethod
  17. def validate(self, record: DHTRecord) -> bool:
  18. """
  19. Should return whether the `record` is valid.
  20. The valid records should have been extended with sign_value().
  21. validate() is called when another DHT peer:
  22. - Asks us to store the record
  23. - Returns the record by our request
  24. """
  25. pass
  26. def sign_value(self, record: DHTRecord) -> bytes:
  27. """
  28. Should return `record.value` extended with the record's signature.
  29. Note: there's no need to overwrite this method if a validator doesn't use a signature.
  30. sign_value() is called after the application asks the DHT to store the record.
  31. """
  32. return record.value
  33. def strip_value(self, record: DHTRecord) -> bytes:
  34. """
  35. Should return `record.value` stripped of the record's signature.
  36. strip_value() is only called if validate() was successful.
  37. Note: there's no need to overwrite this method if a validator doesn't use a signature.
  38. strip_value() is called before the DHT returns the record by the application's request.
  39. """
  40. return record.value
  41. @property
  42. def priority(self) -> int:
  43. """
  44. Defines the order of applying this validator with respect to other validators.
  45. The validators are applied:
  46. - In order of increasing priority for signing a record
  47. - In order of decreasing priority for validating and stripping a record
  48. """
  49. return 0
  50. def merge_with(self, other: "RecordValidatorBase") -> bool:
  51. """
  52. By default, all validators are applied sequentially (i.e. we require all validate() calls
  53. to return True for a record to be validated successfully).
  54. However, you may want to define another policy for combining your validator classes
  55. (e.g. for schema validators, we want to require only one validate() call to return True
  56. because each validator bears a part of the schema).
  57. This can be achieved with overriding merge_with(). It should:
  58. - Return True if it has successfully merged the `other` validator to `self`,
  59. so that `self` became a validator that combines the old `self` and `other` using
  60. the necessary policy. In this case, `other` should remain unchanged.
  61. - Return False if the merging has not happened. In this case, both `self` and `other`
  62. should remain unchanged. The DHT will try merging `other` to another validator or
  63. add it as a separate validator (to be applied sequentially).
  64. """
  65. return False
  66. class CompositeValidator(RecordValidatorBase):
  67. def __init__(self, validators: Iterable[RecordValidatorBase] = ()):
  68. self._validators = []
  69. self.extend(validators)
  70. def extend(self, validators: Iterable[RecordValidatorBase]) -> None:
  71. for new_validator in validators:
  72. for existing_validator in self._validators:
  73. if existing_validator.merge_with(new_validator):
  74. break
  75. else:
  76. self._validators.append(new_validator)
  77. self._validators.sort(key=lambda item: item.priority)
  78. def validate(self, record: DHTRecord) -> bool:
  79. for i, validator in enumerate(reversed(self._validators)):
  80. if not validator.validate(record):
  81. return False
  82. if i < len(self._validators) - 1:
  83. record = dataclasses.replace(record, value=validator.strip_value(record))
  84. return True
  85. def sign_value(self, record: DHTRecord) -> bytes:
  86. for validator in self._validators:
  87. record = dataclasses.replace(record, value=validator.sign_value(record))
  88. return record.value
  89. def strip_value(self, record: DHTRecord) -> bytes:
  90. for validator in reversed(self._validators):
  91. record = dataclasses.replace(record, value=validator.strip_value(record))
  92. return record.value