
Implement combining validators (#249)

This PR implements a mechanism for combining record validators. Each application may add its own list of validators; the combined validators are then applied according to their priorities and custom combination policies.
Aleksandr Borzunov 4 years ago
parent
commit
18add2c04b
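
For illustration, here is a minimal usage sketch of the API added in this PR, modeled on tests/test_dht_validation.py below. The MetricsSchema model and the `metrics` key are hypothetical and not part of the PR:

from typing import Dict

from pydantic import BaseModel, StrictInt

import hivemind
from hivemind.dht.crypto import RSASignatureValidator
from hivemind.dht.schema import SchemaValidator


class MetricsSchema(BaseModel):
    # Hypothetical schema: one DHT key whose subkeys map peer ids to integers
    metrics: Dict[bytes, StrictInt]


# One application starts the DHT process with its own validators
dht = hivemind.DHT(start=True,
                   record_validators=[RSASignatureValidator(),
                                      SchemaValidator(MetricsSchema, allow_extra_keys=False)])

# Another application may later append its validators to the running DHT;
# they are merged and applied according to their priorities and merge policies
dht.add_validators([RSASignatureValidator()])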

+ 1 - 1
hivemind/__init__.py

@@ -4,4 +4,4 @@ from hivemind.server import *
 from hivemind.utils import *
 from hivemind.optim import *
 
-__version__ = '0.9.7'
+__version__ = '0.9.8'

+ 20 - 3
hivemind/dht/__init__.py

@@ -17,12 +17,14 @@ import asyncio
 import ctypes
 import multiprocessing as mp
 from concurrent.futures import ThreadPoolExecutor
-from typing import List, Optional, Sequence, Union, Callable, Awaitable, TypeVar
+from functools import partial
+from typing import Iterable, List, Optional, Sequence, Union, Callable, Awaitable, TypeVar
 
 import hivemind
 from hivemind.client import RemoteExpert
 from hivemind.dht.node import DHTNode, DHTID, DHTExpiration
 from hivemind.dht.routing import DHTValue, DHTKey, Subkey
+from hivemind.dht.validation import CompositeValidator, RecordValidatorBase
 from hivemind.utils.networking import Hostname, Endpoint, strip_port
 from hivemind.utils import MPFuture, get_logger, switch_to_uvloop, ValueWithExpiration, await_cancelled, get_dht_time
 
@@ -49,12 +51,14 @@ class DHT(mp.Process):
 
     def __init__(self, listen_on: Endpoint = "0.0.0.0:*", initial_peers: Sequence[Endpoint] = (), *, start: bool,
                  daemon: bool = True, max_workers: Optional[int] = None, parallel_rpc: Optional[int] = None,
-                 expiration: float = 300, **kwargs):
+                 expiration: float = 300, record_validators: Iterable[RecordValidatorBase] = (),
+                 **kwargs):
         super().__init__()
         assert not isinstance(initial_peers, str), "please specify a list/tuple of initial peers (even if there's one)"
         self.listen_on, self.initial_peers, self.kwargs = listen_on, initial_peers, kwargs
         self.max_workers, self.parallel_rpc = max_workers, parallel_rpc
         self.default_expiration = expiration
+        self._record_validator = CompositeValidator(record_validators)
         self._port = mp.Value(ctypes.c_int32, 0)  # initialized after dht starts
         self._pipe, self.pipe = mp.Pipe(duplex=True)
         self.ready = mp.Event()
@@ -70,7 +74,8 @@ class DHT(mp.Process):
         async def _run():
             node = await DHTNode.create(
                 initial_peers=list(self.initial_peers), listen_on=self.listen_on, parallel_rpc=self.parallel_rpc,
-                num_workers=self.max_workers or 1, **self.kwargs)
+                num_workers=self.max_workers or 1, record_validator=self._record_validator,
+                **self.kwargs)
             if node.port is not None:
                 self._port.value = node.port
             self.ready.set()
@@ -190,6 +195,18 @@ class DHT(mp.Process):
             if not future.done():
                 future.set_exception(e)
 
+    def add_validators(self, record_validators: Iterable[RecordValidatorBase]) -> None:
+        if not self.ready.is_set():
+            raise RuntimeError(
+                "Can't append new validators before the DHT process has started. "
+                "Consider adding them to the initial list via DHT.__init__(record_validators=...)")
+
+        self.run_coroutine(partial(DHT._add_validators, record_validators=record_validators))
+
+    async def _add_validators(
+            self, node: DHTNode, record_validators: Iterable[RecordValidatorBase]) -> None:
+        node.protocol.record_validator.extend(record_validators)
+
     def get_visible_address(self, num_peers: Optional[int] = None, peers: Sequence[Endpoint] = ()) -> Hostname:
         """
         Get this machine's visible address by requesting other peers or using pre-specified network addresses.

+ 27 - 0
hivemind/dht/crypto.py

@@ -86,3 +86,30 @@ class RSASignatureValidator(RecordValidatorBase):
 
     def _serialize_record(self, record: DHTRecord) -> bytes:
         return MSGPackSerializer.dumps(dataclasses.astuple(record))
+
+    @property
+    def priority(self) -> int:
+        # On validation, this validator must be executed before validators
+        # that deserialize the record
+        return 10
+
+    def merge_with(self, other: RecordValidatorBase) -> bool:
+        if not isinstance(other, RSASignatureValidator):
+            return False
+
+        # Ignore another RSASignatureValidator instance (it doesn't make sense to have several
+        # instances of this class) and report successful merge
+        return True
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        # Serializes the private key to make the class instances picklable
+        state['_private_key'] = self._private_key.private_bytes(
+            encoding=serialization.Encoding.PEM,
+            format=serialization.PrivateFormat.OpenSSH,
+            encryption_algorithm=serialization.NoEncryption())
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        self._private_key = serialization.load_ssh_private_key(self._private_key, password=None)

+ 74 - 25
hivemind/dht/schema.py

@@ -1,6 +1,6 @@
 import binascii
 import re
-from typing import Type
+from typing import Any, Dict, Type
 
 import pydantic
 
@@ -19,7 +19,7 @@ class SchemaValidator(RecordValidatorBase):
     This allows to enforce types, min/max values, require a subkey to contain a public key, etc.
     """
 
-    def __init__(self, schema: pydantic.BaseModel):
+    def __init__(self, schema: pydantic.BaseModel, *, allow_extra_keys: bool=True):
         """
         :param schema: The Pydantic model (a subclass of pydantic.BaseModel).
 
@@ -27,18 +27,25 @@
             (e.g. ``StrictInt`` instead of ``int``,
             ``confloat(strict=True, ge=0.0)`` instead of ``confloat(ge=0.0)``, etc.).
             See the validate() docstring for details.
+
+        :param allow_extra_keys: Whether to allow keys that are not defined in the schema.
+
+            If a SchemaValidator is merged with another SchemaValidator, this option applies to
+            keys that are not defined in each of the schemas.
         """
 
         self._alias_to_name = {}
+
         for field in schema.__fields__.values():
             field.alias = self._key_id_to_str(DHTID.generate(source=field.name.encode()).to_bytes())
             self._alias_to_name[field.alias] = field.name
 
             # Because validate() interface provides one key at a time
             field.required = False
+        schema.Config.extra = pydantic.Extra.forbid
 
-        schema.Config.extra = pydantic.Extra.allow
-        self._schema = schema
+        self._schemas = [schema]
+        self._allow_extra_keys = allow_extra_keys
 
     def validate(self, record: DHTRecord) -> bool:
         """
@@ -62,34 +69,58 @@
           .. [3] https://pydantic-docs.helpmanual.io/usage/types/#strict-types
        """
 
-        key_alias = self._key_id_to_str(record.key)
+        try:
+            record = self._deserialize_record(record)
+        except ValueError as e:
+            logger.warning(e)
+            return False
+        [key_alias] = list(record.keys())
+
+        n_outside_schema = 0
+        validation_errors = []
+        for schema in self._schemas:
+            try:
+                parsed_record = schema.parse_obj(record)
+            except pydantic.ValidationError as e:
+                if self._is_failed_due_to_extra_field(e):
+                    n_outside_schema += 1
+                else:
+                    validation_errors.append(e)
+                continue
+
+            parsed_value = parsed_record.dict(by_alias=True)[key_alias]
+            if parsed_value != record[key_alias]:
+                validation_errors.append(ValueError(
+                    f"Value {record[key_alias]} needed type conversions to match "
+                    f"the schema: {parsed_value}. Type conversions are not allowed"))
+            else:
+                return True
+
+        readable_record = {self._alias_to_name.get(key_alias, key_alias): record[key_alias]}
+
+        if n_outside_schema == len(self._schemas):
+            if not self._allow_extra_keys:
+                logger.warning(f"Record {readable_record} contains a field that "
+                               f"is not defined in each of the schemas")
+            return self._allow_extra_keys
+
+        logger.warning(
+            f"Record {readable_record} doesn't match any of the schemas: {validation_errors}")
+        return False
+
+    @staticmethod
+    def _deserialize_record(record: DHTRecord) -> Dict[str, Any]:
+        key_alias = SchemaValidator._key_id_to_str(record.key)
         deserialized_value = DHTProtocol.serializer.loads(record.value)
         if record.subkey not in DHTProtocol.RESERVED_SUBKEYS:
             deserialized_subkey = DHTProtocol.serializer.loads(record.subkey)
-            deserialized_record = {key_alias: {deserialized_subkey: deserialized_value}}
+            return {key_alias: {deserialized_subkey: deserialized_value}}
         else:
             if isinstance(deserialized_value, dict):
-                logger.warning(
+                raise ValueError(
                     f'Record {record} contains an improperly serialized dictionary (you must use '
                     f'a DictionaryDHTValue of serialized values instead of a `dict` subclass)')
-                return False
-            deserialized_record = {key_alias: deserialized_value}
-
-        try:
-            parsed_record = self._schema.parse_obj(deserialized_record)
-        except pydantic.ValidationError as e:
-            readable_record = {self._alias_to_name.get(key_alias, key_alias):
-                               deserialized_record[key_alias]}
-            logger.warning(f"Record {readable_record} doesn't match the schema: {e}")
-            return False
-
-        parsed_value = parsed_record.dict(by_alias=True)[key_alias]
-        if parsed_value != deserialized_record[key_alias]:
-            logger.warning(
-                f"Value {deserialized_record[key_alias]} needed type conversions to match "
-                f" the schema: {parsed_value}. Type conversions are not allowed")
-            return False
-        return True
+            return {key_alias: deserialized_value}
 
     @staticmethod
     def _key_id_to_str(key_id: bytes) -> str:
@@ -100,6 +131,24 @@ class SchemaValidator(RecordValidatorBase):
 
         return binascii.hexlify(key_id).decode()
 
+    @staticmethod
+    def _is_failed_due_to_extra_field(exc: pydantic.ValidationError):
+        inner_errors = exc.errors()
+        return (
+            len(inner_errors) == 1 and
+            inner_errors[0]['type'] == 'value_error.extra' and
+            len(inner_errors[0]['loc']) == 1  # Require the extra field to be on the top level
+        )
+
+    def merge_with(self, other: RecordValidatorBase) -> bool:
+        if not isinstance(other, SchemaValidator):
+            return False
+
+        self._alias_to_name.update(other._alias_to_name)
+        self._schemas.extend(other._schemas)
+        self._allow_extra_keys = self._allow_extra_keys or other._allow_extra_keys
+        return True
+
 
 def conbytes(*, regex: bytes=None, **kwargs) -> Type[pydantic.BaseModel]:
     """

+ 68 - 0
hivemind/dht/validation.py

@@ -1,5 +1,6 @@
 import dataclasses
 from abc import ABC, abstractmethod
+from typing import Iterable
 
 
 @dataclasses.dataclass(init=True, repr=True, frozen=True)
@@ -52,3 +53,70 @@ class RecordValidatorBase(ABC):
         """
 
         return record.value
+
+    @property
+    def priority(self) -> int:
+        """
+        Defines the order of applying this validator with respect to other validators.
+
+        The validators are applied:
+          - In order of increasing priority for signing a record
+          - In order of decreasing priority for validating and stripping a record
+        """
+
+        return 0
+
+    def merge_with(self, other: 'RecordValidatorBase') -> bool:
+        """
+        By default, all validators are applied sequentially (i.e. we require all validate() calls
+        to return True for a record to be validated successfully).
+
+        However, you may want to define another policy for combining your validator classes
+        (e.g. for schema validators, we want to require only one validate() call to return True
+        because each validator bears a part of the schema).
+
+        This can be achieved with overriding merge_with(). It should:
+
+          - Return True if it has successfully merged the `other` validator to `self`,
+            so that `self` became a validator that combines the old `self` and `other` using
+            the necessary policy. In this case, `other` should remain unchanged.
+
+          - Return False if the merging has not happened. In this case, both `self` and `other`
+            should remain unchanged. The DHT will try merging `other` to another validator or
+            add it as a separate validator (to be applied sequentially).
+        """
+
+        return False
+
+
+class CompositeValidator(RecordValidatorBase):
+    def __init__(self, validators: Iterable[RecordValidatorBase]=()):
+        self._validators = []
+        self.extend(validators)
+
+    def extend(self, validators: Iterable[RecordValidatorBase]) -> None:
+        for new_validator in validators:
+            for existing_validator in self._validators:
+                if existing_validator.merge_with(new_validator):
+                    break
+            else:
+                self._validators.append(new_validator)
+        self._validators.sort(key=lambda item: item.priority)
+
+    def validate(self, record: DHTRecord) -> bool:
+        for i, validator in enumerate(reversed(self._validators)):
+            if not validator.validate(record):
+                return False
+            if i < len(self._validators) - 1:
+                record = dataclasses.replace(record, value=validator.strip_value(record))
+        return True
+
+    def sign_value(self, record: DHTRecord) -> bytes:
+        for validator in self._validators:
+            record = dataclasses.replace(record, value=validator.sign_value(record))
+        return record.value
+
+    def strip_value(self, record: DHTRecord) -> bytes:
+        for validator in reversed(self._validators):
+            record = dataclasses.replace(record, value=validator.strip_value(record))
+        return record.value
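
As a quick illustration of the priority ordering implemented by CompositeValidator above (a toy sketch, not part of this PR; TagValidator is hypothetical): sign_value applies validators in ascending priority order, while validate and strip_value run in descending order, so a higher-priority validator wraps and later unwraps the output of lower-priority ones.

import dataclasses

from hivemind.dht.validation import CompositeValidator, DHTRecord, RecordValidatorBase


class TagValidator(RecordValidatorBase):
    # Hypothetical validator: appends a tag when signing, checks and strips it when validating
    def __init__(self, tag: bytes, tag_priority: int):
        self._tag, self._priority = tag, tag_priority

    @property
    def priority(self) -> int:
        return self._priority

    def validate(self, record: DHTRecord) -> bool:
        return record.value.endswith(self._tag)

    def sign_value(self, record: DHTRecord) -> bytes:
        return record.value + self._tag

    def strip_value(self, record: DHTRecord) -> bytes:
        return record.value[:-len(self._tag)]


validator = CompositeValidator([TagValidator(b'[outer]', tag_priority=10),
                                TagValidator(b'[inner]', tag_priority=0)])
record = DHTRecord(key=b'key', subkey=b'subkey', value=b'value', expiration_time=0.0)

signed = dataclasses.replace(record, value=validator.sign_value(record))
assert signed.value == b'value[inner][outer]'   # lower priority signs first, higher priority signs last
assert validator.validate(signed)               # validation unwraps in the reverse (descending) order
assert validator.strip_value(signed) == b'value'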

+ 74 - 15
tests/test_dht_schema.py

@@ -1,18 +1,18 @@
 import re
 
-import pydantic
 import pytest
-from pydantic import conint
-from typing import Dict
+from pydantic import BaseModel, StrictFloat, StrictInt, conint
+from typing import Dict, List
 
 from hivemind.dht import get_dht_time
 from hivemind.dht.node import DHTNode, LOCALHOST
 from hivemind.dht.schema import SchemaValidator, conbytes
+from hivemind.dht.validation import DHTRecord, RecordValidatorBase
 
 
 @pytest.fixture
 async def dht_nodes_with_schema():
-    class Schema(pydantic.BaseModel):
+    class Schema(BaseModel):
         experiment_name: bytes
         n_batches: Dict[bytes, conint(ge=0, strict=True)]
         signed_data: Dict[conbytes(regex=rb'.*\[owner:.+\]'), bytes]
@@ -25,17 +25,6 @@ async def dht_nodes_with_schema():
     return alice, bob
 
 
-@pytest.mark.forked
-@pytest.mark.asyncio
-async def test_keys_outside_schema(dht_nodes_with_schema):
-    alice, bob = dht_nodes_with_schema
-
-    assert await bob.store(b'unknown_key', b'foo_bar', get_dht_time() + 10)
-
-    for peer in [alice, bob]:
-        assert (await peer.get(b'unknown_key', latest=True)).value == b'foo_bar'
-
-
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_expecting_regular_value(dht_nodes_with_schema):
@@ -103,3 +92,73 @@ async def test_expecting_public_keys(dht_nodes_with_schema):
         dictionary = (await peer.get(b'signed_data', latest=True)).value
         assert (len(dictionary) == 1 and
                 dictionary[b'uid[owner:public-key]'].value == b'foo_bar')
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_keys_outside_schema(dht_nodes_with_schema):
+    class Schema(BaseModel):
+        some_field: StrictInt
+
+    class MergedSchema(BaseModel):
+        another_field: StrictInt
+
+    for allow_extra_keys in [False, True]:
+        validator = SchemaValidator(Schema, allow_extra_keys=allow_extra_keys)
+        assert validator.merge_with(SchemaValidator(MergedSchema, allow_extra_keys=False))
+
+        alice = await DHTNode.create(record_validator=validator)
+        bob = await DHTNode.create(
+            record_validator=validator, initial_peers=[f"{LOCALHOST}:{alice.port}"])
+
+        store_ok = await bob.store(b'unknown_key', b'foo_bar', get_dht_time() + 10)
+        assert store_ok == allow_extra_keys
+
+        for peer in [alice, bob]:
+            result = await peer.get(b'unknown_key', latest=True)
+            if allow_extra_keys:
+                assert result.value == b'foo_bar'
+            else:
+                assert result is None
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_merging_schema_validators(dht_nodes_with_schema):
+    alice, bob = dht_nodes_with_schema
+
+    class TrivialValidator(RecordValidatorBase):
+        def validate(self, record: DHTRecord) -> bool:
+            return True
+
+    second_validator = TrivialValidator()
+    # Can't merge with the validator of the different type
+    assert not alice.protocol.record_validator.merge_with(second_validator)
+
+    class SecondSchema(BaseModel):
+        some_field: StrictInt
+        another_field: str
+
+    class ThirdSchema(BaseModel):
+        another_field: StrictInt  # Allow it to be a StrictInt as well
+
+    for schema in [SecondSchema, ThirdSchema]:
+        new_validator = SchemaValidator(schema, allow_extra_keys=False)
+        for peer in [alice, bob]:
+            assert peer.protocol.record_validator.merge_with(new_validator)
+
+    assert await bob.store(b'experiment_name', b'foo_bar', get_dht_time() + 10)
+    assert await bob.store(b'some_field', 777, get_dht_time() + 10)
+    assert not await bob.store(b'some_field', 'string_value', get_dht_time() + 10)
+    assert await bob.store(b'another_field', 42, get_dht_time() + 10)
+    assert await bob.store(b'another_field', 'string_value', get_dht_time() + 10)
+
+    # Unknown keys are allowed since the first schema is created with allow_extra_keys=True
+    assert await bob.store(b'unknown_key', 999, get_dht_time() + 10)
+
+    for peer in [alice, bob]:
+        assert (await peer.get(b'experiment_name', latest=True)).value == b'foo_bar'
+        assert (await peer.get(b'some_field', latest=True)).value == 777
+        assert (await peer.get(b'another_field', latest=True)).value == 'string_value'
+
+        assert (await peer.get(b'unknown_key', latest=True)).value == 999

+ 93 - 0
tests/test_dht_validation.py

@@ -0,0 +1,93 @@
+import dataclasses
+from functools import partial
+from typing import Dict
+
+import pytest
+from pydantic import BaseModel, StrictInt
+
+import hivemind
+from hivemind.dht.crypto import RSASignatureValidator
+from hivemind.dht.protocol import DHTProtocol
+from hivemind.dht.routing import DHTID
+from hivemind.dht.schema import SchemaValidator
+from hivemind.dht.validation import DHTRecord, CompositeValidator, RecordValidatorBase
+
+
+class SchemaA(BaseModel):
+    field_a: bytes
+
+
+class SchemaB(BaseModel):
+    field_b: Dict[bytes, StrictInt]
+
+
+@pytest.fixture
+def validators_for_app():
+    # Each application may add its own validator set
+    return {
+        'A': [RSASignatureValidator(), SchemaValidator(SchemaA, allow_extra_keys=False)],
+        'B': [SchemaValidator(SchemaB, allow_extra_keys=False), RSASignatureValidator()],
+    }
+
+
+def test_composite_validator(validators_for_app):
+    validator = CompositeValidator(validators_for_app['A'])
+    assert ([type(item) for item in validator._validators] ==
+        [SchemaValidator, RSASignatureValidator])
+
+    validator.extend(validators_for_app['B'])
+    assert ([type(item) for item in validator._validators] ==
+        [SchemaValidator, RSASignatureValidator])
+    assert len(validator._validators[0]._schemas) == 2
+
+    public_key = validators_for_app['A'][0].ownership_marker
+    record = DHTRecord(key=DHTID.generate(source=b'field_b').to_bytes(),
+                       subkey=DHTProtocol.serializer.dumps(public_key),
+                       value=DHTProtocol.serializer.dumps(777),
+                       expiration_time=hivemind.get_dht_time() + 10)
+
+    signed_record = dataclasses.replace(record, value=validator.sign_value(record))
+    # Expect only one signature since two RSASignatureValidators have been merged
+    assert signed_record.value.count(b'[signature:') == 1
+    # Expect successful validation since the second SchemaValidator has been merged to the first
+    assert validator.validate(signed_record)
+    assert validator.strip_value(signed_record) == record.value
+
+    record = DHTRecord(key=DHTID.generate(source=b'unknown_key').to_bytes(),
+                       subkey=DHTProtocol.IS_REGULAR_VALUE,
+                       value=DHTProtocol.serializer.dumps(777),
+                       expiration_time=hivemind.get_dht_time() + 10)
+
+    signed_record = dataclasses.replace(record, value=validator.sign_value(record))
+    assert signed_record.value.count(b'[signature:') == 0
+    # Expect failed validation since `unknown_key` is not a part of any schema
+    assert not validator.validate(signed_record)
+
+
+@pytest.mark.forked
+def test_dht_add_validators(validators_for_app):
+    # One app may create a DHT with its validators
+    dht = hivemind.DHT(start=False, record_validators=validators_for_app['A'])
+
+    # While the DHT process is not started, you can't send a command to append new validators
+    with pytest.raises(RuntimeError):
+        dht.add_validators(validators_for_app['B'])
+    dht.run_in_background(await_ready=True)
+
+    # After starting the process, other apps may add new validators to the existing DHT
+    dht.add_validators(validators_for_app['B'])
+
+    assert dht.store(b'field_a', b'bytes_value', hivemind.get_dht_time() + 10)
+    assert dht.get(b'field_a', latest=True).value == b'bytes_value'
+
+    assert not dht.store(b'field_a', 666, hivemind.get_dht_time() + 10)
+    assert dht.get(b'field_a', latest=True).value == b'bytes_value'
+
+    public_key = validators_for_app['A'][0].ownership_marker
+    assert dht.store(b'field_b', 777, hivemind.get_dht_time() + 10, subkey=public_key)
+    dictionary = dht.get(b'field_b', latest=True).value
+    assert (len(dictionary) == 1 and
+            dictionary[public_key].value == 777)
+
+    assert not dht.store(b'unknown_key', 666, hivemind.get_dht_time() + 10)
+    assert dht.get(b'unknown_key', latest=True) is None