# schema.py
  1. import re
  2. from typing import Any, Dict, Optional, Type
  3. import pydantic
  4. from hivemind.dht.crypto import RSASignatureValidator
  5. from hivemind.dht.protocol import DHTProtocol
  6. from hivemind.dht.routing import DHTID
  7. from hivemind.dht.validation import DHTRecord, RecordValidatorBase
  8. from hivemind.utils import get_logger
  9. logger = get_logger(__name__)
  10. class SchemaValidator(RecordValidatorBase):
  11. """
  12. Restricts specified DHT keys to match a Pydantic schema.
  13. This allows to enforce types, min/max values, require a subkey to contain a public key, etc.
  14. """
  15. def __init__(self, schema: pydantic.BaseModel, *,
  16. allow_extra_keys: bool=True, prefix: Optional[str]=None):
  17. """
  18. :param schema: The Pydantic model (a subclass of pydantic.BaseModel).
  19. You must always use strict types for the number fields
  20. (e.g. ``StrictInt`` instead of ``int``,
  21. ``confloat(strict=True, ge=0.0)`` instead of ``confloat(ge=0.0)``, etc.).
  22. See the validate() docstring for details.
  23. The model will be patched to adjust it for the schema validation.
  24. :param allow_extra_keys: Whether to allow keys that are not defined in the schema.
  25. If a SchemaValidator is merged with another SchemaValidator, this option applies to
  26. keys that are not defined in each of the schemas.
  27. :param prefix: (optional) Add ``prefix + '_'`` to the names of all schema fields.
  28. """
  29. self._patch_schema(schema)
  30. self._schemas = [schema]
  31. self._key_id_to_field_name = {}
  32. for field in schema.__fields__.values():
  33. raw_key = f'{prefix}_{field.name}' if prefix is not None else field.name
  34. self._key_id_to_field_name[DHTID.generate(source=raw_key).to_bytes()] = field.name
  35. self._allow_extra_keys = allow_extra_keys
  36. @staticmethod
  37. def _patch_schema(schema: pydantic.BaseModel):
  38. # We set required=False because the validate() interface provides only one key at a time
  39. for field in schema.__fields__.values():
  40. field.required = False
  41. schema.Config.extra = pydantic.Extra.forbid
  42. def validate(self, record: DHTRecord) -> bool:
  43. """
  44. Validates ``record`` in two steps:
  45. 1. Create a Pydantic model and ensure that no exceptions are thrown.
  46. 2. Ensure that Pydantic has not made any type conversions [1]_ while creating the model.
  47. To do this, we check that the value of the model field is equal
  48. (in terms of == operator) to the source value.
  49. This works for the iterable default types like str, list, and dict
  50. (they are equal only if the types match) but does not work for numbers
  51. (they have a special case allowing ``3.0 == 3`` to be true). [2]_
  52. Because of that, you must always use strict types [3]_ for the number fields
  53. (e.g. to avoid ``3.0`` to be validated successfully for the ``field: int``).
  54. .. [1] https://pydantic-docs.helpmanual.io/usage/models/#data-conversion
  55. .. [2] https://stackoverflow.com/a/52557261
  56. .. [3] https://pydantic-docs.helpmanual.io/usage/types/#strict-types
  57. """
  58. if record.key not in self._key_id_to_field_name:
  59. if not self._allow_extra_keys:
  60. logger.debug(f"Record {record} has a key ID that is not defined in any of the "
  61. f"schemas (therefore, the raw key is unknown)")
  62. return self._allow_extra_keys
  63. try:
  64. record = self._deserialize_record(record)
  65. except ValueError as e:
  66. logger.debug(e)
  67. return False
  68. [field_name] = list(record.keys())
  69. n_outside_schema = 0
  70. validation_errors = []
  71. for schema in self._schemas:
  72. try:
  73. parsed_record = schema.parse_obj(record)
  74. except pydantic.ValidationError as e:
  75. if not self._is_failed_due_to_extra_field(e):
  76. validation_errors.append(e)
  77. continue
  78. parsed_value = parsed_record.dict(by_alias=True)[field_name]
  79. if parsed_value != record[field_name]:
  80. validation_errors.append(ValueError(
  81. f"The record {record} needed type conversions to match "
  82. f"the schema: {parsed_value}. Type conversions are not allowed"))
  83. else:
  84. return True
  85. logger.debug(f"Record {record} doesn't match any of the schemas: {validation_errors}")
  86. return False
  87. def _deserialize_record(self, record: DHTRecord) -> Dict[str, Any]:
  88. field_name = self._key_id_to_field_name[record.key]
  89. deserialized_value = DHTProtocol.serializer.loads(record.value)
  90. if record.subkey not in DHTProtocol.RESERVED_SUBKEYS:
  91. deserialized_subkey = DHTProtocol.serializer.loads(record.subkey)
  92. return {field_name: {deserialized_subkey: deserialized_value}}
  93. else:
  94. if isinstance(deserialized_value, dict):
  95. raise ValueError(
  96. f'Record {record} contains an improperly serialized dictionary (you must use '
  97. f'a DictionaryDHTValue of serialized values instead of a `dict` subclass)')
  98. return {field_name: deserialized_value}
  99. @staticmethod
  100. def _is_failed_due_to_extra_field(exc: pydantic.ValidationError):
  101. inner_errors = exc.errors()
  102. return (
  103. len(inner_errors) == 1 and
  104. inner_errors[0]['type'] == 'value_error.extra' and
  105. len(inner_errors[0]['loc']) == 1 # Require the extra field to be on the top level
  106. )
  107. def merge_with(self, other: RecordValidatorBase) -> bool:
  108. if not isinstance(other, SchemaValidator):
  109. return False
  110. self._schemas.extend(other._schemas)
  111. self._key_id_to_field_name.update(other._key_id_to_field_name)
  112. self._allow_extra_keys = self._allow_extra_keys or other._allow_extra_keys
  113. return True
  114. def __setstate__(self, state):
  115. self.__dict__.update(state)
  116. # If unpickling happens in another process, the previous model modifications may be lost
  117. for schema in self._schemas:
  118. self._patch_schema(schema)
  119. def conbytes(*, regex: bytes=None, **kwargs) -> Type[pydantic.BaseModel]:
  120. """
  121. Extend pydantic.conbytes() to support ``regex`` constraints (like pydantic.constr() does).
  122. """
  123. compiled_regex = re.compile(regex) if regex is not None else None
  124. class ConstrainedBytesWithRegex(pydantic.conbytes(**kwargs)):
  125. @classmethod
  126. def __get_validators__(cls):
  127. yield from super().__get_validators__()
  128. yield cls.match_regex
  129. @classmethod
  130. def match_regex(cls, value: bytes) -> bytes:
  131. if compiled_regex is not None and compiled_regex.match(value) is None:
  132. raise ValueError(f"Value `{value}` doesn't match regex `{regex}`")
  133. return value
  134. return ConstrainedBytesWithRegex
# Constrained bytes type that must contain an RSA public key (per
# RSASignatureValidator.PUBLIC_KEY_REGEX) anywhere inside the value —
# presumably used in schemas whose records must embed the author's public key
# (verify against callers).
BytesWithPublicKey = conbytes(regex=b'.*' + RSASignatureValidator.PUBLIC_KEY_REGEX + b'.*')