schema.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. import re
  2. from typing import Any, Dict, Optional, Type
  3. import pydantic
  4. from hivemind.dht.crypto import RSASignatureValidator
  5. from hivemind.dht.protocol import DHTProtocol
  6. from hivemind.dht.routing import DHTID
  7. from hivemind.dht.validation import DHTRecord, RecordValidatorBase
  8. from hivemind.utils import get_logger
# Module-level logger, named after this module via ``__name__``.
logger = get_logger(__name__)
  10. class SchemaValidator(RecordValidatorBase):
  11. """
  12. Restricts specified DHT keys to match a Pydantic schema.
  13. This allows to enforce types, min/max values, require a subkey to contain a public key, etc.
  14. """
  15. def __init__(self, schema: Type[pydantic.BaseModel], allow_extra_keys: bool = True, prefix: Optional[str] = None):
  16. """
  17. :param schema: The Pydantic model (a subclass of pydantic.BaseModel).
  18. You must always use strict types for the number fields
  19. (e.g. ``StrictInt`` instead of ``int``,
  20. ``confloat(strict=True, ge=0.0)`` instead of ``confloat(ge=0.0)``, etc.).
  21. See the validate() docstring for details.
  22. The model will be patched to adjust it for the schema validation.
  23. :param allow_extra_keys: Whether to allow keys that are not defined in the schema.
  24. If a SchemaValidator is merged with another SchemaValidator, this option applies to
  25. keys that are not defined in each of the schemas.
  26. :param prefix: (optional) Add ``prefix + '_'`` to the names of all schema fields.
  27. """
  28. self._patch_schema(schema)
  29. self._schemas = [schema]
  30. self._key_id_to_field_name = {}
  31. for field in schema.__fields__.values():
  32. raw_key = f"{prefix}_{field.name}" if prefix is not None else field.name
  33. self._key_id_to_field_name[DHTID.generate(source=raw_key).to_bytes()] = field.name
  34. self._allow_extra_keys = allow_extra_keys
  35. @staticmethod
  36. def _patch_schema(schema: pydantic.BaseModel):
  37. # We set required=False because the validate() interface provides only one key at a time
  38. for field in schema.__fields__.values():
  39. field.required = False
  40. schema.Config.extra = pydantic.Extra.forbid
  41. def validate(self, record: DHTRecord) -> bool:
  42. """
  43. Validates ``record`` in two steps:
  44. 1. Create a Pydantic model and ensure that no exceptions are thrown.
  45. 2. Ensure that Pydantic has not made any type conversions [1]_ while creating the model.
  46. To do this, we check that the value of the model field is equal
  47. (in terms of == operator) to the source value.
  48. This works for the iterable default types like str, list, and dict
  49. (they are equal only if the types match) but does not work for numbers
  50. (they have a special case allowing ``3.0 == 3`` to be true). [2]_
  51. Because of that, you must always use strict types [3]_ for the number fields
  52. (e.g. to avoid ``3.0`` to be validated successfully for the ``field: int``).
  53. .. [1] https://pydantic-docs.helpmanual.io/usage/models/#data-conversion
  54. .. [2] https://stackoverflow.com/a/52557261
  55. .. [3] https://pydantic-docs.helpmanual.io/usage/types/#strict-types
  56. """
  57. if record.key not in self._key_id_to_field_name:
  58. if not self._allow_extra_keys:
  59. logger.debug(
  60. f"Record {record} has a key ID that is not defined in any of the "
  61. f"schemas (therefore, the raw key is unknown)"
  62. )
  63. return self._allow_extra_keys
  64. try:
  65. record = self._deserialize_record(record)
  66. except ValueError as e:
  67. logger.debug(e)
  68. return False
  69. [field_name] = list(record.keys())
  70. n_outside_schema = 0
  71. validation_errors = []
  72. for schema in self._schemas:
  73. try:
  74. parsed_record = schema.parse_obj(record)
  75. except pydantic.ValidationError as e:
  76. if not self._is_failed_due_to_extra_field(e):
  77. validation_errors.append(e)
  78. continue
  79. parsed_value = parsed_record.dict(by_alias=True)[field_name]
  80. if parsed_value != record[field_name]:
  81. validation_errors.append(
  82. ValueError(
  83. f"The record {record} needed type conversions to match "
  84. f"the schema: {parsed_value}. Type conversions are not allowed"
  85. )
  86. )
  87. else:
  88. return True
  89. logger.debug(f"Record {record} doesn't match any of the schemas: {validation_errors}")
  90. return False
  91. def _deserialize_record(self, record: DHTRecord) -> Dict[str, Any]:
  92. field_name = self._key_id_to_field_name[record.key]
  93. deserialized_value = DHTProtocol.serializer.loads(record.value)
  94. if record.subkey not in DHTProtocol.RESERVED_SUBKEYS:
  95. deserialized_subkey = DHTProtocol.serializer.loads(record.subkey)
  96. return {field_name: {deserialized_subkey: deserialized_value}}
  97. else:
  98. if isinstance(deserialized_value, dict):
  99. raise ValueError(
  100. f"Record {record} contains an improperly serialized dictionary (you must use "
  101. f"a DictionaryDHTValue of serialized values instead of a `dict` subclass)"
  102. )
  103. return {field_name: deserialized_value}
  104. @staticmethod
  105. def _is_failed_due_to_extra_field(exc: pydantic.ValidationError):
  106. inner_errors = exc.errors()
  107. return (
  108. len(inner_errors) == 1
  109. and inner_errors[0]["type"] == "value_error.extra"
  110. and len(inner_errors[0]["loc"]) == 1 # Require the extra field to be on the top level
  111. )
  112. def merge_with(self, other: RecordValidatorBase) -> bool:
  113. if not isinstance(other, SchemaValidator):
  114. return False
  115. self._schemas.extend(other._schemas)
  116. self._key_id_to_field_name.update(other._key_id_to_field_name)
  117. self._allow_extra_keys = self._allow_extra_keys or other._allow_extra_keys
  118. return True
  119. def __setstate__(self, state):
  120. self.__dict__.update(state)
  121. # If unpickling happens in another process, the previous model modifications may be lost
  122. for schema in self._schemas:
  123. self._patch_schema(schema)
  124. def conbytes(*, regex: bytes = None, **kwargs) -> Type[pydantic.BaseModel]:
  125. """
  126. Extend pydantic.conbytes() to support ``regex`` constraints (like pydantic.constr() does).
  127. """
  128. compiled_regex = re.compile(regex) if regex is not None else None
  129. class ConstrainedBytesWithRegex(pydantic.conbytes(**kwargs)):
  130. @classmethod
  131. def __get_validators__(cls):
  132. yield from super().__get_validators__()
  133. yield cls.match_regex
  134. @classmethod
  135. def match_regex(cls, value: bytes) -> bytes:
  136. if compiled_regex is not None and compiled_regex.match(value) is None:
  137. raise ValueError(f"Value `{value}` doesn't match regex `{regex}`")
  138. return value
  139. return ConstrainedBytesWithRegex
  140. BytesWithPublicKey = conbytes(regex=b".*" + RSASignatureValidator.PUBLIC_KEY_REGEX + b".*")