|
@@ -1,6 +1,6 @@
|
|
|
import binascii
|
|
|
import re
|
|
|
-from typing import Type
|
|
|
+from typing import Any, Dict, Type
|
|
|
|
|
|
import pydantic
|
|
|
|
|
@@ -19,7 +19,7 @@ class SchemaValidator(RecordValidatorBase):
|
|
|
This allows to enforce types, min/max values, require a subkey to contain a public key, etc.
|
|
|
"""
|
|
|
|
|
|
- def __init__(self, schema: pydantic.BaseModel):
|
|
|
+ def __init__(self, schema: pydantic.BaseModel, *, allow_extra_keys: bool=True):
|
|
|
"""
|
|
|
:param schema: The Pydantic model (a subclass of pydantic.BaseModel).
|
|
|
|
|
@@ -27,18 +27,25 @@ class SchemaValidator(RecordValidatorBase):
|
|
|
(e.g. ``StrictInt`` instead of ``int``,
|
|
|
``confloat(strict=True, ge=0.0)`` instead of ``confloat(ge=0.0)``, etc.).
|
|
|
See the validate() docstring for details.
|
|
|
+
|
|
|
+ :param allow_extra_keys: Whether to allow keys that are not defined in the schema.
|
|
|
+
|
|
|
+ If a SchemaValidator is merged with another SchemaValidator, this option applies to
|
|
|
+ keys that are not defined in each of the schemas.
|
|
|
"""
|
|
|
|
|
|
self._alias_to_name = {}
|
|
|
+
|
|
|
for field in schema.__fields__.values():
|
|
|
field.alias = self._key_id_to_str(DHTID.generate(source=field.name.encode()).to_bytes())
|
|
|
self._alias_to_name[field.alias] = field.name
|
|
|
|
|
|
# Because validate() interface provides one key at a time
|
|
|
field.required = False
|
|
|
+ schema.Config.extra = pydantic.Extra.forbid
|
|
|
|
|
|
- schema.Config.extra = pydantic.Extra.allow
|
|
|
- self._schema = schema
|
|
|
+ self._schemas = [schema]
|
|
|
+ self._allow_extra_keys = allow_extra_keys
|
|
|
|
|
|
def validate(self, record: DHTRecord) -> bool:
|
|
|
"""
|
|
@@ -62,34 +69,58 @@ class SchemaValidator(RecordValidatorBase):
|
|
|
.. [3] https://pydantic-docs.helpmanual.io/usage/types/#strict-types
|
|
|
"""
|
|
|
|
|
|
- key_alias = self._key_id_to_str(record.key)
|
|
|
+ try:
|
|
|
+ record = self._deserialize_record(record)
|
|
|
+ except ValueError as e:
|
|
|
+ logger.warning(e)
|
|
|
+ return False
|
|
|
+ [key_alias] = list(record.keys())
|
|
|
+
|
|
|
+ n_outside_schema = 0
|
|
|
+ validation_errors = []
|
|
|
+ for schema in self._schemas:
|
|
|
+ try:
|
|
|
+ parsed_record = schema.parse_obj(record)
|
|
|
+ except pydantic.ValidationError as e:
|
|
|
+ if self._is_failed_due_to_extra_field(e):
|
|
|
+ n_outside_schema += 1
|
|
|
+ else:
|
|
|
+ validation_errors.append(e)
|
|
|
+ continue
|
|
|
+
|
|
|
+ parsed_value = parsed_record.dict(by_alias=True)[key_alias]
|
|
|
+ if parsed_value != record[key_alias]:
|
|
|
+ validation_errors.append(ValueError(
|
|
|
+ f"Value {record[key_alias]} needed type conversions to match "
|
|
|
+ f"the schema: {parsed_value}. Type conversions are not allowed"))
|
|
|
+ else:
|
|
|
+ return True
|
|
|
+
|
|
|
+ readable_record = {self._alias_to_name.get(key_alias, key_alias): record[key_alias]}
|
|
|
+
|
|
|
+ if n_outside_schema == len(self._schemas):
|
|
|
+ if not self._allow_extra_keys:
|
|
|
+ logger.warning(f"Record {readable_record} contains a field that "
|
|
|
+ f"is not defined in each of the schemas")
|
|
|
+ return self._allow_extra_keys
|
|
|
+
|
|
|
+ logger.warning(
|
|
|
+ f"Record {readable_record} doesn't match any of the schemas: {validation_errors}")
|
|
|
+ return False
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def _deserialize_record(record: DHTRecord) -> Dict[str, Any]:
|
|
|
+ key_alias = SchemaValidator._key_id_to_str(record.key)
|
|
|
deserialized_value = DHTProtocol.serializer.loads(record.value)
|
|
|
if record.subkey not in DHTProtocol.RESERVED_SUBKEYS:
|
|
|
deserialized_subkey = DHTProtocol.serializer.loads(record.subkey)
|
|
|
- deserialized_record = {key_alias: {deserialized_subkey: deserialized_value}}
|
|
|
+ return {key_alias: {deserialized_subkey: deserialized_value}}
|
|
|
else:
|
|
|
if isinstance(deserialized_value, dict):
|
|
|
- logger.warning(
|
|
|
+ raise ValueError(
|
|
|
f'Record {record} contains an improperly serialized dictionary (you must use '
|
|
|
f'a DictionaryDHTValue of serialized values instead of a `dict` subclass)')
|
|
|
- return False
|
|
|
- deserialized_record = {key_alias: deserialized_value}
|
|
|
-
|
|
|
- try:
|
|
|
- parsed_record = self._schema.parse_obj(deserialized_record)
|
|
|
- except pydantic.ValidationError as e:
|
|
|
- readable_record = {self._alias_to_name.get(key_alias, key_alias):
|
|
|
- deserialized_record[key_alias]}
|
|
|
- logger.warning(f"Record {readable_record} doesn't match the schema: {e}")
|
|
|
- return False
|
|
|
-
|
|
|
- parsed_value = parsed_record.dict(by_alias=True)[key_alias]
|
|
|
- if parsed_value != deserialized_record[key_alias]:
|
|
|
- logger.warning(
|
|
|
- f"Value {deserialized_record[key_alias]} needed type conversions to match "
|
|
|
- f" the schema: {parsed_value}. Type conversions are not allowed")
|
|
|
- return False
|
|
|
- return True
|
|
|
+ return {key_alias: deserialized_value}
|
|
|
|
|
|
@staticmethod
|
|
|
def _key_id_to_str(key_id: bytes) -> str:
|
|
@@ -100,6 +131,24 @@ class SchemaValidator(RecordValidatorBase):
|
|
|
|
|
|
return binascii.hexlify(key_id).decode()
|
|
|
|
|
|
+ @staticmethod
|
|
|
+ def _is_failed_due_to_extra_field(exc: pydantic.ValidationError):
|
|
|
+ inner_errors = exc.errors()
|
|
|
+ return (
|
|
|
+ len(inner_errors) == 1 and
|
|
|
+ inner_errors[0]['type'] == 'value_error.extra' and
|
|
|
+ len(inner_errors[0]['loc']) == 1 # Require the extra field to be on the top level
|
|
|
+ )
|
|
|
+
|
|
|
+ def merge_with(self, other: RecordValidatorBase) -> bool:
|
|
|
+ if not isinstance(other, SchemaValidator):
|
|
|
+ return False
|
|
|
+
|
|
|
+ self._alias_to_name.update(other._alias_to_name)
|
|
|
+ self._schemas.extend(other._schemas)
|
|
|
+ self._allow_extra_keys = self._allow_extra_keys or other._allow_extra_keys
|
|
|
+ return True
|
|
|
+
|
|
|
|
|
|
def conbytes(*, regex: bytes=None, **kwargs) -> Type[pydantic.BaseModel]:
|
|
|
"""
|