소스 검색

Implement combining validators (#249)

This PR implements a method to combine validators. Each application may add its own list of validators. Then, validators are executed with respect to their priorities and custom combination policies.
Aleksandr Borzunov 4 년 전
부모
커밋
18add2c04b
7개의 변경된 파일357개의 추가작업 그리고 44개의 파일을 삭제
  1. 1 1
      hivemind/__init__.py
  2. 20 3
      hivemind/dht/__init__.py
  3. 27 0
      hivemind/dht/crypto.py
  4. 74 25
      hivemind/dht/schema.py
  5. 68 0
      hivemind/dht/validation.py
  6. 74 15
      tests/test_dht_schema.py
  7. 93 0
      tests/test_dht_validation.py

+ 1 - 1
hivemind/__init__.py

@@ -4,4 +4,4 @@ from hivemind.server import *
 from hivemind.utils import *
 from hivemind.optim import *
 
-__version__ = '0.9.7'
+__version__ = '0.9.8'

+ 20 - 3
hivemind/dht/__init__.py

@@ -17,12 +17,14 @@ import asyncio
 import ctypes
 import multiprocessing as mp
 from concurrent.futures import ThreadPoolExecutor
-from typing import List, Optional, Sequence, Union, Callable, Awaitable, TypeVar
+from functools import partial
+from typing import Iterable, List, Optional, Sequence, Union, Callable, Awaitable, TypeVar
 
 import hivemind
 from hivemind.client import RemoteExpert
 from hivemind.dht.node import DHTNode, DHTID, DHTExpiration
 from hivemind.dht.routing import DHTValue, DHTKey, Subkey
+from hivemind.dht.validation import CompositeValidator, RecordValidatorBase
 from hivemind.utils.networking import Hostname, Endpoint, strip_port
 from hivemind.utils import MPFuture, get_logger, switch_to_uvloop, ValueWithExpiration, await_cancelled, get_dht_time
 
@@ -49,12 +51,14 @@ class DHT(mp.Process):
 
     def __init__(self, listen_on: Endpoint = "0.0.0.0:*", initial_peers: Sequence[Endpoint] = (), *, start: bool,
                  daemon: bool = True, max_workers: Optional[int] = None, parallel_rpc: Optional[int] = None,
-                 expiration: float = 300, **kwargs):
+                 expiration: float = 300, record_validators: Iterable[RecordValidatorBase] = (),
+                 **kwargs):
         super().__init__()
         assert not isinstance(initial_peers, str), "please specify a list/tuple of initial peers (even if there's one)"
         self.listen_on, self.initial_peers, self.kwargs = listen_on, initial_peers, kwargs
         self.max_workers, self.parallel_rpc = max_workers, parallel_rpc
         self.default_expiration = expiration
+        self._record_validator = CompositeValidator(record_validators)
         self._port = mp.Value(ctypes.c_int32, 0)  # initialized after dht starts
         self._pipe, self.pipe = mp.Pipe(duplex=True)
         self.ready = mp.Event()
@@ -70,7 +74,8 @@ class DHT(mp.Process):
         async def _run():
             node = await DHTNode.create(
                 initial_peers=list(self.initial_peers), listen_on=self.listen_on, parallel_rpc=self.parallel_rpc,
-                num_workers=self.max_workers or 1, **self.kwargs)
+                num_workers=self.max_workers or 1, record_validator=self._record_validator,
+                **self.kwargs)
             if node.port is not None:
                 self._port.value = node.port
             self.ready.set()
@@ -190,6 +195,18 @@ class DHT(mp.Process):
             if not future.done():
                 future.set_exception(e)
 
+    def add_validators(self, record_validators: Iterable[RecordValidatorBase]) -> None:
+        if not self.ready.is_set():
+            raise RuntimeError(
+                "Can't append new validators before the DHT process has started. "
+                "Consider adding them to the initial list via DHT.__init__(record_validators=...)")
+
+        self.run_coroutine(partial(DHT._add_validators, record_validators=record_validators))
+
+    async def _add_validators(
+            self, node: DHTNode, record_validators: Iterable[RecordValidatorBase]) -> None:
+        node.protocol.record_validator.extend(record_validators)
+
     def get_visible_address(self, num_peers: Optional[int] = None, peers: Sequence[Endpoint] = ()) -> Hostname:
         """
         Get this machine's visible address by requesting other peers or using pre-specified network addresses.

+ 27 - 0
hivemind/dht/crypto.py

@@ -86,3 +86,30 @@ class RSASignatureValidator(RecordValidatorBase):
 
     def _serialize_record(self, record: DHTRecord) -> bytes:
         return MSGPackSerializer.dumps(dataclasses.astuple(record))
+
+    @property
+    def priority(self) -> int:
+        # On validation, this validator must be executed before validators
+        # that deserialize the record
+        return 10
+
+    def merge_with(self, other: RecordValidatorBase) -> bool:
+        if not isinstance(other, RSASignatureValidator):
+            return False
+
+        # Ignore another RSASignatureValidator instance (it doesn't make sense to have several
+        # instances of this class) and report successful merge
+        return True
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        # Serializes the private key to make the class instances picklable
+        state['_private_key'] = self._private_key.private_bytes(
+            encoding=serialization.Encoding.PEM,
+            format=serialization.PrivateFormat.OpenSSH,
+            encryption_algorithm=serialization.NoEncryption())
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        self._private_key = serialization.load_ssh_private_key(self._private_key, password=None)

+ 74 - 25
hivemind/dht/schema.py

@@ -1,6 +1,6 @@
 import binascii
 import re
-from typing import Type
+from typing import Any, Dict, Type
 
 import pydantic
 
@@ -19,7 +19,7 @@ class SchemaValidator(RecordValidatorBase):
     This allows to enforce types, min/max values, require a subkey to contain a public key, etc.
     """
 
-    def __init__(self, schema: pydantic.BaseModel):
+    def __init__(self, schema: pydantic.BaseModel, *, allow_extra_keys: bool=True):
         """
         :param schema: The Pydantic model (a subclass of pydantic.BaseModel).
 
@@ -27,18 +27,25 @@ class SchemaValidator(RecordValidatorBase):
             (e.g. ``StrictInt`` instead of ``int``,
             ``confloat(strict=True, ge=0.0)`` instead of ``confloat(ge=0.0)``, etc.).
             See the validate() docstring for details.
+
+        :param allow_extra_keys: Whether to allow keys that are not defined in the schema.
+
+            If a SchemaValidator is merged with another SchemaValidator, this option applies to
+            keys that are not defined in each of the schemas.
         """
 
         self._alias_to_name = {}
+
         for field in schema.__fields__.values():
             field.alias = self._key_id_to_str(DHTID.generate(source=field.name.encode()).to_bytes())
             self._alias_to_name[field.alias] = field.name
 
             # Because validate() interface provides one key at a time
             field.required = False
+        schema.Config.extra = pydantic.Extra.forbid
 
-        schema.Config.extra = pydantic.Extra.allow
-        self._schema = schema
+        self._schemas = [schema]
+        self._allow_extra_keys = allow_extra_keys
 
     def validate(self, record: DHTRecord) -> bool:
         """
@@ -62,34 +69,58 @@ class SchemaValidator(RecordValidatorBase):
            .. [3] https://pydantic-docs.helpmanual.io/usage/types/#strict-types
         """
 
-        key_alias = self._key_id_to_str(record.key)
+        try:
+            record = self._deserialize_record(record)
+        except ValueError as e:
+            logger.warning(e)
+            return False
+        [key_alias] = list(record.keys())
+
+        n_outside_schema = 0
+        validation_errors = []
+        for schema in self._schemas:
+            try:
+                parsed_record = schema.parse_obj(record)
+            except pydantic.ValidationError as e:
+                if self._is_failed_due_to_extra_field(e):
+                    n_outside_schema += 1
+                else:
+                    validation_errors.append(e)
+                continue
+
+            parsed_value = parsed_record.dict(by_alias=True)[key_alias]
+            if parsed_value != record[key_alias]:
+                validation_errors.append(ValueError(
+                    f"Value {record[key_alias]} needed type conversions to match "
+                    f"the schema: {parsed_value}. Type conversions are not allowed"))
+            else:
+                return True
+
+        readable_record = {self._alias_to_name.get(key_alias, key_alias): record[key_alias]}
+
+        if n_outside_schema == len(self._schemas):
+            if not self._allow_extra_keys:
+                logger.warning(f"Record {readable_record} contains a field that "
+                               f"is not defined in each of the schemas")
+            return self._allow_extra_keys
+
+        logger.warning(
+            f"Record {readable_record} doesn't match any of the schemas: {validation_errors}")
+        return False
+
+    @staticmethod
+    def _deserialize_record(record: DHTRecord) -> Dict[str, Any]:
+        key_alias = SchemaValidator._key_id_to_str(record.key)
         deserialized_value = DHTProtocol.serializer.loads(record.value)
         if record.subkey not in DHTProtocol.RESERVED_SUBKEYS:
             deserialized_subkey = DHTProtocol.serializer.loads(record.subkey)
-            deserialized_record = {key_alias: {deserialized_subkey: deserialized_value}}
+            return {key_alias: {deserialized_subkey: deserialized_value}}
         else:
             if isinstance(deserialized_value, dict):
-                logger.warning(
+                raise ValueError(
                     f'Record {record} contains an improperly serialized dictionary (you must use '
                     f'a DictionaryDHTValue of serialized values instead of a `dict` subclass)')
-                return False
-            deserialized_record = {key_alias: deserialized_value}
-
-        try:
-            parsed_record = self._schema.parse_obj(deserialized_record)
-        except pydantic.ValidationError as e:
-            readable_record = {self._alias_to_name.get(key_alias, key_alias):
-                               deserialized_record[key_alias]}
-            logger.warning(f"Record {readable_record} doesn't match the schema: {e}")
-            return False
-
-        parsed_value = parsed_record.dict(by_alias=True)[key_alias]
-        if parsed_value != deserialized_record[key_alias]:
-            logger.warning(
-                f"Value {deserialized_record[key_alias]} needed type conversions to match "
-                f" the schema: {parsed_value}. Type conversions are not allowed")
-            return False
-        return True
+            return {key_alias: deserialized_value}
 
     @staticmethod
     def _key_id_to_str(key_id: bytes) -> str:
@@ -100,6 +131,24 @@ class SchemaValidator(RecordValidatorBase):
 
         return binascii.hexlify(key_id).decode()
 
+    @staticmethod
+    def _is_failed_due_to_extra_field(exc: pydantic.ValidationError):
+        inner_errors = exc.errors()
+        return (
+            len(inner_errors) == 1 and
+            inner_errors[0]['type'] == 'value_error.extra' and
+            len(inner_errors[0]['loc']) == 1  # Require the extra field to be on the top level
+        )
+
+    def merge_with(self, other: RecordValidatorBase) -> bool:
+        if not isinstance(other, SchemaValidator):
+            return False
+
+        self._alias_to_name.update(other._alias_to_name)
+        self._schemas.extend(other._schemas)
+        self._allow_extra_keys = self._allow_extra_keys or other._allow_extra_keys
+        return True
+
 
 def conbytes(*, regex: bytes=None, **kwargs) -> Type[pydantic.BaseModel]:
     """

+ 68 - 0
hivemind/dht/validation.py

@@ -1,5 +1,6 @@
 import dataclasses
 from abc import ABC, abstractmethod
+from typing import Iterable
 
 
 @dataclasses.dataclass(init=True, repr=True, frozen=True)
@@ -52,3 +53,70 @@ class RecordValidatorBase(ABC):
         """
 
         return record.value
+
+    @property
+    def priority(self) -> int:
+        """
+        Defines the order of applying this validator with respect to other validators.
+
+        The validators are applied:
+          - In order of increasing priority for signing a record
+          - In order of decreasing priority for validating and stripping a record
+        """
+
+        return 0
+
+    def merge_with(self, other: 'RecordValidatorBase') -> bool:
+        """
+        By default, all validators are applied sequentially (i.e. we require all validate() calls
+        to return True for a record to be validated successfully).
+
+        However, you may want to define another policy for combining your validator classes
+        (e.g. for schema validators, we want to require only one validate() call to return True
+        because each validator bears a part of the schema).
+
+        This can be achieved with overriding merge_with(). It should:
+
+          - Return True if it has successfully merged the `other` validator to `self`,
+            so that `self` became a validator that combines the old `self` and `other` using
+            the necessary policy. In this case, `other` should remain unchanged.
+
+          - Return False if the merging has not happened. In this case, both `self` and `other`
+            should remain unchanged. The DHT will try merging `other` to another validator or
+            add it as a separate validator (to be applied sequentially).
+        """
+
+        return False
+
+
+class CompositeValidator(RecordValidatorBase):
+    def __init__(self, validators: Iterable[RecordValidatorBase]=()):
+        self._validators = []
+        self.extend(validators)
+
+    def extend(self, validators: Iterable[RecordValidatorBase]) -> None:
+        for new_validator in validators:
+            for existing_validator in self._validators:
+                if existing_validator.merge_with(new_validator):
+                    break
+            else:
+                self._validators.append(new_validator)
+        self._validators.sort(key=lambda item: item.priority)
+
+    def validate(self, record: DHTRecord) -> bool:
+        for i, validator in enumerate(reversed(self._validators)):
+            if not validator.validate(record):
+                return False
+            if i < len(self._validators) - 1:
+                record = dataclasses.replace(record, value=validator.strip_value(record))
+        return True
+
+    def sign_value(self, record: DHTRecord) -> bytes:
+        for validator in self._validators:
+            record = dataclasses.replace(record, value=validator.sign_value(record))
+        return record.value
+
+    def strip_value(self, record: DHTRecord) -> bytes:
+        for validator in reversed(self._validators):
+            record = dataclasses.replace(record, value=validator.strip_value(record))
+        return record.value

+ 74 - 15
tests/test_dht_schema.py

@@ -1,18 +1,18 @@
 import re
 
-import pydantic
 import pytest
-from pydantic import conint
-from typing import Dict
+from pydantic import BaseModel, StrictFloat, StrictInt, conint
+from typing import Dict, List
 
 from hivemind.dht import get_dht_time
 from hivemind.dht.node import DHTNode, LOCALHOST
 from hivemind.dht.schema import SchemaValidator, conbytes
+from hivemind.dht.validation import DHTRecord, RecordValidatorBase
 
 
 @pytest.fixture
 async def dht_nodes_with_schema():
-    class Schema(pydantic.BaseModel):
+    class Schema(BaseModel):
         experiment_name: bytes
         n_batches: Dict[bytes, conint(ge=0, strict=True)]
         signed_data: Dict[conbytes(regex=rb'.*\[owner:.+\]'), bytes]
@@ -25,17 +25,6 @@ async def dht_nodes_with_schema():
     return alice, bob
 
 
-@pytest.mark.forked
-@pytest.mark.asyncio
-async def test_keys_outside_schema(dht_nodes_with_schema):
-    alice, bob = dht_nodes_with_schema
-
-    assert await bob.store(b'unknown_key', b'foo_bar', get_dht_time() + 10)
-
-    for peer in [alice, bob]:
-        assert (await peer.get(b'unknown_key', latest=True)).value == b'foo_bar'
-
-
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_expecting_regular_value(dht_nodes_with_schema):
@@ -103,3 +92,73 @@ async def test_expecting_public_keys(dht_nodes_with_schema):
         dictionary = (await peer.get(b'signed_data', latest=True)).value
         assert (len(dictionary) == 1 and
                 dictionary[b'uid[owner:public-key]'].value == b'foo_bar')
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_keys_outside_schema(dht_nodes_with_schema):
+    class Schema(BaseModel):
+        some_field: StrictInt
+
+    class MergedSchema(BaseModel):
+        another_field: StrictInt
+
+    for allow_extra_keys in [False, True]:
+        validator = SchemaValidator(Schema, allow_extra_keys=allow_extra_keys)
+        assert validator.merge_with(SchemaValidator(MergedSchema, allow_extra_keys=False))
+
+        alice = await DHTNode.create(record_validator=validator)
+        bob = await DHTNode.create(
+            record_validator=validator, initial_peers=[f"{LOCALHOST}:{alice.port}"])
+
+        store_ok = await bob.store(b'unknown_key', b'foo_bar', get_dht_time() + 10)
+        assert store_ok == allow_extra_keys
+
+        for peer in [alice, bob]:
+            result = await peer.get(b'unknown_key', latest=True)
+            if allow_extra_keys:
+                assert result.value == b'foo_bar'
+            else:
+                assert result is None
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_merging_schema_validators(dht_nodes_with_schema):
+    alice, bob = dht_nodes_with_schema
+
+    class TrivialValidator(RecordValidatorBase):
+        def validate(self, record: DHTRecord) -> bool:
+            return True
+
+    second_validator = TrivialValidator()
+    # Can't merge with the validator of the different type
+    assert not alice.protocol.record_validator.merge_with(second_validator)
+
+    class SecondSchema(BaseModel):
+        some_field: StrictInt
+        another_field: str
+
+    class ThirdSchema(BaseModel):
+        another_field: StrictInt  # Allow it to be a StrictInt as well
+
+    for schema in [SecondSchema, ThirdSchema]:
+        new_validator = SchemaValidator(schema, allow_extra_keys=False)
+        for peer in [alice, bob]:
+            assert peer.protocol.record_validator.merge_with(new_validator)
+
+    assert await bob.store(b'experiment_name', b'foo_bar', get_dht_time() + 10)
+    assert await bob.store(b'some_field', 777, get_dht_time() + 10)
+    assert not await bob.store(b'some_field', 'string_value', get_dht_time() + 10)
+    assert await bob.store(b'another_field', 42, get_dht_time() + 10)
+    assert await bob.store(b'another_field', 'string_value', get_dht_time() + 10)
+
+    # Unkown keys are allowed since the first schema is created with allow_extra_keys=True
+    assert await bob.store(b'unknown_key', 999, get_dht_time() + 10)
+
+    for peer in [alice, bob]:
+        assert (await peer.get(b'experiment_name', latest=True)).value == b'foo_bar'
+        assert (await peer.get(b'some_field', latest=True)).value == 777
+        assert (await peer.get(b'another_field', latest=True)).value == 'string_value'
+
+        assert (await peer.get(b'unknown_key', latest=True)).value == 999

+ 93 - 0
tests/test_dht_validation.py

@@ -0,0 +1,93 @@
+import dataclasses
+from functools import partial
+from typing import Dict
+
+import pytest
+from pydantic import BaseModel, StrictInt
+
+import hivemind
+from hivemind.dht.crypto import RSASignatureValidator
+from hivemind.dht.protocol import DHTProtocol
+from hivemind.dht.routing import DHTID
+from hivemind.dht.schema import SchemaValidator
+from hivemind.dht.validation import DHTRecord, CompositeValidator, RecordValidatorBase
+
+
+class SchemaA(BaseModel):
+    field_a: bytes
+
+
+class SchemaB(BaseModel):
+    field_b: Dict[bytes, StrictInt]
+
+
+@pytest.fixture
+def validators_for_app():
+    # Each application may add its own validator set
+    return {
+        'A': [RSASignatureValidator(), SchemaValidator(SchemaA, allow_extra_keys=False)],
+        'B': [SchemaValidator(SchemaB, allow_extra_keys=False), RSASignatureValidator()],
+    }
+
+
+def test_composite_validator(validators_for_app):
+    validator = CompositeValidator(validators_for_app['A'])
+    assert ([type(item) for item in validator._validators] ==
+        [SchemaValidator, RSASignatureValidator])
+
+    validator.extend(validators_for_app['B'])
+    assert ([type(item) for item in validator._validators] ==
+        [SchemaValidator, RSASignatureValidator])
+    assert len(validator._validators[0]._schemas) == 2
+
+    public_key = validators_for_app['A'][0].ownership_marker
+    record = DHTRecord(key=DHTID.generate(source=b'field_b').to_bytes(),
+                       subkey=DHTProtocol.serializer.dumps(public_key),
+                       value=DHTProtocol.serializer.dumps(777),
+                       expiration_time=hivemind.get_dht_time() + 10)
+
+    signed_record = dataclasses.replace(record, value=validator.sign_value(record))
+    # Expect only one signature since two RSASignatureValidatos have been merged
+    assert signed_record.value.count(b'[signature:') == 1
+    # Expect successful validation since the second SchemaValidator has been merged to the first
+    assert validator.validate(signed_record)
+    assert validator.strip_value(signed_record) == record.value
+
+    record = DHTRecord(key=DHTID.generate(source=b'unknown_key').to_bytes(),
+                       subkey=DHTProtocol.IS_REGULAR_VALUE,
+                       value=DHTProtocol.serializer.dumps(777),
+                       expiration_time=hivemind.get_dht_time() + 10)
+
+    signed_record = dataclasses.replace(record, value=validator.sign_value(record))
+    assert signed_record.value.count(b'[signature:') == 0
+    # Expect failed validation since `unknown_key` is not a part of any schema
+    assert not validator.validate(signed_record)
+
+
+@pytest.mark.forked
+def test_dht_add_validators(validators_for_app):
+    # One app may create a DHT with its validators
+    dht = hivemind.DHT(start=False, record_validators=validators_for_app['A'])
+
+    # While the DHT process is not started, you can't send a command to append new validators
+    with pytest.raises(RuntimeError):
+        dht.add_validators(validators_for_app['B'])
+    dht.run_in_background(await_ready=True)
+
+    # After starting the process, other apps may add new validators to the existing DHT
+    dht.add_validators(validators_for_app['B'])
+
+    assert dht.store(b'field_a', b'bytes_value', hivemind.get_dht_time() + 10)
+    assert dht.get(b'field_a', latest=True).value == b'bytes_value'
+
+    assert not dht.store(b'field_a', 666, hivemind.get_dht_time() + 10)
+    assert dht.get(b'field_a', latest=True).value == b'bytes_value'
+
+    public_key = validators_for_app['A'][0].ownership_marker
+    assert dht.store(b'field_b', 777, hivemind.get_dht_time() + 10, subkey=public_key)
+    dictionary = dht.get(b'field_b', latest=True).value
+    assert (len(dictionary) == 1 and
+            dictionary[public_key].value == 777)
+
+    assert not dht.store(b'unknown_key', 666, hivemind.get_dht_time() + 10)
+    assert dht.get(b'unknown_key', latest=True) is None