
Add DHT schema validator (#227)

This PR introduces a DHT schema validator that lets an application restrict the format of DHT content using the pydantic module (e.g., enforce types, impose min/max values, or require a subkey to contain a public key).
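
For a quick sense of how the validator is wired in, here is a minimal usage sketch mirroring the fixture in tests/test_dht_schema.py below (the schema fields and the single-node setup are illustrative, not something this PR prescribes beyond what the tests show):

```python
import asyncio
from typing import Dict

import pydantic
from pydantic import conint

from hivemind.dht import get_dht_time
from hivemind.dht.node import DHTNode
from hivemind.dht.schema import SchemaValidator, conbytes


# Each top-level field corresponds to a DHT key; Dict fields describe subkey -> value mappings.
class ExperimentSchema(pydantic.BaseModel):
    experiment_name: bytes
    n_batches: Dict[bytes, conint(ge=0, strict=True)]             # non-negative integers only
    signed_data: Dict[conbytes(regex=rb'.*\[owner:.+\]'), bytes]  # subkeys must embed a public key


async def main():
    validator = SchemaValidator(ExperimentSchema)
    node = await DHTNode.create(record_validator=validator)

    # Conforming records are stored; records violating the schema are rejected.
    assert await node.store(b'experiment_name', b'foo_bar', get_dht_time() + 10)
    assert not await node.store(b'experiment_name', 666, get_dht_time() + 10)


asyncio.run(main())
```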
Aleksandr Borzunov, 4 years ago
parent commit a3feafa907
4 changed files with 251 additions and 10 deletions
  1. hivemind/dht/node.py (+22, -10)
  2. hivemind/dht/schema.py (+123, -0)
  3. requirements.txt (+1, -0)
  4. tests/test_dht_schema.py (+105, -0)

hivemind/dht/node.py (+22, -10)

@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import asyncio
+import dataclasses
 import random
 from collections import defaultdict, Counter
 from dataclasses import dataclass, field
@@ -303,14 +304,17 @@ class DHTNode:
 
             key_bytes = key_id.to_bytes()
             binary_values = []
+            stored_records = []
             for subkey, value, expiration_time in zip(
                     current_subkeys, current_values, current_expirations):
+                subkey_bytes = self.protocol.serializer.dumps(subkey)
                 value_bytes = self.protocol.serializer.dumps(value)
+                record = DHTRecord(key_bytes, subkey_bytes, value_bytes, expiration_time)
                 if self.protocol.record_validator is not None:
-                    subkey_bytes = self.protocol.serializer.dumps(subkey)
-                    record = DHTRecord(key_bytes, subkey_bytes, value_bytes, expiration_time)
                     value_bytes = self.protocol.record_validator.sign_value(record)
+                    record = dataclasses.replace(record, value=value_bytes)
                 binary_values.append(value_bytes)
+                stored_records.append(record)
 
             while num_successful_stores < self.num_replicas and (store_candidates or pending_store_tasks):
                 while store_candidates and num_successful_stores + len(pending_store_tasks) < self.num_replicas:
@@ -318,9 +322,13 @@ class DHTNode:
 
                     if node_id == self.node_id:
                         num_successful_stores += 1
-                        for subkey, value, expiration_time in zip(current_subkeys, binary_values, current_expirations):
-                            store_ok[original_key, subkey] = self.protocol.storage.store(
-                                key_id, value, expiration_time, subkey=subkey)
+                        for subkey, record in zip(current_subkeys, stored_records):
+                            if (self.protocol.record_validator is None or
+                                    self.protocol.record_validator.validate(record)):
+                                store_ok[original_key, subkey] = self.protocol.storage.store(
+                                    key_id, record.value, record.expiration_time, subkey=subkey)
+                            else:
+                                store_ok[original_key, subkey] = False
                             if not await_all_replicas:
                                 store_finished_events[original_key, subkey].set()
                     else:
@@ -364,15 +372,19 @@ class DHTNode:
         store_succeeded = any(store_ok)
         is_dictionary = any(subkey is not None for subkey in subkeys)
         if store_succeeded and not is_dictionary:  # stored a new regular value, cache it!
-            stored_value_bytes, stored_expiration = max(zip(binary_values, expirations), key=lambda p: p[1])
+            stored_expiration, stored_value_bytes = max(zip(expirations, binary_values))
             self.protocol.cache.store(key_id, stored_value_bytes, stored_expiration)
         elif not store_succeeded and not is_dictionary:  # store rejected, check if local cache is also obsolete
-            rejected_value, rejected_expiration = max(zip(binary_values, expirations), key=lambda p: p[1])
-            if (self.protocol.cache.get(key_id)[1] or float("inf")) <= rejected_expiration:  # cache would be rejected
+            rejected_expiration, rejected_value = max(zip(expirations, binary_values))
+            cached_value = self.protocol.cache.get(key_id)
+            if (cached_value is not None and
+                    cached_value.expiration_time <= rejected_expiration):  # cache would be rejected
                 self._schedule_for_refresh(key_id, refresh_time=get_dht_time())  # fetch new key in background (asap)
         elif is_dictionary and key_id in self.protocol.cache:  # there can be other keys and we should update
-            for subkey, stored_value_bytes, expiration_time in zip(subkeys, binary_values, expirations):
-                self.protocol.cache.store_subkey(key_id, subkey, stored_value_bytes, expiration_time)
+            for subkey, stored_value_bytes, expiration_time, accepted in zip(
+                    subkeys, binary_values, expirations, store_ok):
+                if accepted:
+                    self.protocol.cache.store_subkey(key_id, subkey, stored_value_bytes, expiration_time)
             self._schedule_for_refresh(key_id, refresh_time=get_dht_time())  # fetch new key in background (asap)
 
     async def get(self, key: DHTKey, latest=False, **kwargs) -> Optional[ValueWithExpiration[DHTValue]]:

hivemind/dht/schema.py (+123, -0)

@@ -0,0 +1,123 @@
+import binascii
+import re
+from typing import Type
+
+import pydantic
+
+from hivemind.dht.protocol import DHTProtocol
+from hivemind.dht.routing import DHTID, DHTKey
+from hivemind.dht.validation import DHTRecord, RecordValidatorBase
+from hivemind.utils import get_logger
+
+
+logger = get_logger(__name__)
+
+
+class SchemaValidator(RecordValidatorBase):
+    """
+    Restricts specified DHT keys to match a Pydantic schema.
+    This allows to enforce types, min/max values, require a subkey to contain a public key, etc.
+    """
+
+    def __init__(self, schema: pydantic.BaseModel):
+        """
+        :param schema: The Pydantic model (a subclass of pydantic.BaseModel).
+
+            You must always use strict types for the number fields
+            (e.g. ``StrictInt`` instead of ``int``,
+            ``confloat(strict=True, ge=0.0)`` instead of ``confloat(ge=0.0)``, etc.).
+            See the validate() docstring for details.
+        """
+
+        self._alias_to_name = {}
+        for field in schema.__fields__.values():
+            field.alias = self._key_id_to_str(DHTID.generate(source=field.name.encode()).to_bytes())
+            self._alias_to_name[field.alias] = field.name
+
+            # Because validate() interface provides one key at a time
+            field.required = False
+
+        schema.Config.extra = pydantic.Extra.allow
+        self._schema = schema
+
+    def validate(self, record: DHTRecord) -> bool:
+        """
+        Validates ``record`` in two steps:
+
+        1. Create a Pydantic model and ensure that no exceptions are thrown.
+
+        2. Ensure that Pydantic has not made any type conversions [1]_ while creating the model.
+           To do this, we check that the value of the model field is equal
+           (in terms of == operator) to the source value.
+
+           This works for the iterable default types like str, list, and dict
+           (they are equal only if the types match) but does not work for numbers
+           (they have a special case allowing ``3.0 == 3`` to be true). [2]_
+
+           Because of that, you must always use strict types [3]_ for the number fields
+           (e.g. to avoid ``3.0`` to be validated successfully for the ``field: int``).
+
+           .. [1] https://pydantic-docs.helpmanual.io/usage/models/#data-conversion
+           .. [2] https://stackoverflow.com/a/52557261
+           .. [3] https://pydantic-docs.helpmanual.io/usage/types/#strict-types
+        """
+
+        key_alias = self._key_id_to_str(record.key)
+        deserialized_value = DHTProtocol.serializer.loads(record.value)
+        if record.subkey not in DHTProtocol.RESERVED_SUBKEYS:
+            deserialized_subkey = DHTProtocol.serializer.loads(record.subkey)
+            deserialized_record = {key_alias: {deserialized_subkey: deserialized_value}}
+        else:
+            if isinstance(deserialized_value, dict):
+                logger.warning(
+                    f'Record {record} contains an improperly serialized dictionary (you must use '
+                    f'a DictionaryDHTValue of serialized values instead of a `dict` subclass)')
+                return False
+            deserialized_record = {key_alias: deserialized_value}
+
+        try:
+            parsed_record = self._schema.parse_obj(deserialized_record)
+        except pydantic.ValidationError as e:
+            readable_record = {self._alias_to_name.get(key_alias, key_alias):
+                               deserialized_record[key_alias]}
+            logger.warning(f"Record {readable_record} doesn't match the schema: {e}")
+            return False
+
+        parsed_value = parsed_record.dict(by_alias=True)[key_alias]
+        if parsed_value != deserialized_record[key_alias]:
+            logger.warning(
+                f"Value {deserialized_record[key_alias]} needed type conversions to match "
+                f" the schema: {parsed_value}. Type conversions are not allowed")
+            return False
+        return True
+
+    @staticmethod
+    def _key_id_to_str(key_id: bytes) -> str:
+        """
+        Represent ``key_id`` as a ``str`` since Pydantic does not support field aliases
+        of type ``bytes``.
+        """
+
+        return binascii.hexlify(key_id).decode()
+
+
+def conbytes(*, regex: bytes=None, **kwargs) -> Type[pydantic.BaseModel]:
+    """
+    Extend pydantic.conbytes() to support ``regex`` constraints (like pydantic.constr() does).
+    """
+
+    compiled_regex = re.compile(regex) if regex is not None else None
+
+    class ConstrainedBytesWithRegex(pydantic.conbytes(**kwargs)):
+        @classmethod
+        def __get_validators__(cls):
+            yield from super().__get_validators__()
+            yield cls.match_regex
+
+        @classmethod
+        def match_regex(cls, value: bytes) -> bytes:
+            if compiled_regex is not None and compiled_regex.match(value) is None:
+                raise ValueError(f"Value `{value}` doesn't match regex `{regex}`")
+            return value
+
+    return ConstrainedBytesWithRegex
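
A side note on the strict-type requirement spelled out in the validate() docstring above: plain pydantic number fields coerce values silently, and the post-parse equality check cannot catch this for numbers because 3 == 3.0 in Python. A minimal standalone sketch of that behavior (assumes pydantic 1.x; the model names are illustrative):

```python
import pydantic
from pydantic import StrictInt


class Loose(pydantic.BaseModel):
    n: int          # pydantic coerces 3.0 to 3 here


class Strict(pydantic.BaseModel):
    n: StrictInt    # rejects anything that is not already an int


assert Loose.parse_obj({'n': 3.0}).n == 3    # accepted; 3 == 3.0, so an equality check cannot detect the coercion
try:
    Strict.parse_obj({'n': 3.0})
except pydantic.ValidationError:
    print('rejected as expected')            # hence SchemaValidator requires strict number types
```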

requirements.txt (+1, -0)

@@ -11,3 +11,4 @@ grpcio-tools>=1.33.2
 protobuf>=3.12.2
 configargparse>=1.2.3
 cryptography>=3.4.6
+pydantic>=1.8.1

tests/test_dht_schema.py (+105, -0)

@@ -0,0 +1,105 @@
+import re
+
+import pydantic
+import pytest
+from pydantic import conint
+from typing import Dict
+
+from hivemind.dht import get_dht_time
+from hivemind.dht.node import DHTNode, LOCALHOST
+from hivemind.dht.schema import SchemaValidator, conbytes
+
+
+@pytest.fixture
+async def dht_nodes_with_schema():
+    class Schema(pydantic.BaseModel):
+        experiment_name: bytes
+        n_batches: Dict[bytes, conint(ge=0, strict=True)]
+        signed_data: Dict[conbytes(regex=rb'.*\[owner:.+\]'), bytes]
+
+    validator = SchemaValidator(Schema)
+
+    alice = await DHTNode.create(record_validator=validator)
+    bob = await DHTNode.create(
+        record_validator=validator, initial_peers=[f"{LOCALHOST}:{alice.port}"])
+    return alice, bob
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_keys_outside_schema(dht_nodes_with_schema):
+    alice, bob = dht_nodes_with_schema
+
+    assert await bob.store(b'unknown_key', b'foo_bar', get_dht_time() + 10)
+
+    for peer in [alice, bob]:
+        assert (await peer.get(b'unknown_key', latest=True)).value == b'foo_bar'
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_expecting_regular_value(dht_nodes_with_schema):
+    alice, bob = dht_nodes_with_schema
+
+    # Regular value (bytes) expected
+    assert await bob.store(b'experiment_name', b'foo_bar', get_dht_time() + 10)
+    assert not await bob.store(b'experiment_name', 666, get_dht_time() + 10)
+    assert not await bob.store(b'experiment_name', b'foo_bar', get_dht_time() + 10,
+                               subkey=b'subkey')
+
+    # Refuse records despite https://pydantic-docs.helpmanual.io/usage/models/#data-conversion
+    assert not await bob.store(b'experiment_name', [], get_dht_time() + 10)
+    assert not await bob.store(b'experiment_name', [1, 2, 3], get_dht_time() + 10)
+
+    for peer in [alice, bob]:
+        assert (await peer.get(b'experiment_name', latest=True)).value == b'foo_bar'
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_expecting_dictionary(dht_nodes_with_schema):
+    alice, bob = dht_nodes_with_schema
+
+    # Dictionary (bytes -> non-negative int) expected
+    assert await bob.store(b'n_batches', 777, get_dht_time() + 10, subkey=b'uid1')
+    assert await bob.store(b'n_batches', 778, get_dht_time() + 10, subkey=b'uid2')
+    assert not await bob.store(b'n_batches', -666, get_dht_time() + 10, subkey=b'uid3')
+    assert not await bob.store(b'n_batches', 666, get_dht_time() + 10)
+    assert not await bob.store(b'n_batches', b'not_integer', get_dht_time() + 10, subkey=b'uid1')
+    assert not await bob.store(b'n_batches', 666, get_dht_time() + 10, subkey=666)
+
+    # Refuse storing a plain dictionary bypassing the DictionaryDHTValue convention
+    assert not await bob.store(b'n_batches', {b'uid3': 779}, get_dht_time() + 10)
+
+    # Refuse records despite https://pydantic-docs.helpmanual.io/usage/models/#data-conversion
+    assert not await bob.store(b'n_batches', 779.5, get_dht_time() + 10, subkey=b'uid3')
+    assert not await bob.store(b'n_batches', 779.0, get_dht_time() + 10, subkey=b'uid3')
+    assert not await bob.store(b'n_batches', [], get_dht_time() + 10)
+    assert not await bob.store(b'n_batches', [(b'uid3', 779)], get_dht_time() + 10)
+
+    # Refuse records despite https://github.com/samuelcolvin/pydantic/issues/1268
+    assert not await bob.store(b'n_batches', '', get_dht_time() + 10)
+
+    for peer in [alice, bob]:
+        dictionary = (await peer.get(b'n_batches', latest=True)).value
+        assert (len(dictionary) == 2 and
+                dictionary[b'uid1'].value == 777 and
+                dictionary[b'uid2'].value == 778)
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_expecting_public_keys(dht_nodes_with_schema):
+    alice, bob = dht_nodes_with_schema
+
+    # Subkeys expected to contain a public key
+    # (so hivemind.dht.crypto.RSASignatureValidator would require a signature)
+    assert await bob.store(b'signed_data', b'foo_bar', get_dht_time() + 10,
+                           subkey=b'uid[owner:public-key]')
+    assert not await bob.store(b'signed_data', b'foo_bar', get_dht_time() + 10,
+                               subkey=b'uid-without-public-key')
+
+    for peer in [alice, bob]:
+        dictionary = (await peer.get(b'signed_data', latest=True)).value
+        assert (len(dictionary) == 1 and
+                dictionary[b'uid[owner:public-key]'].value == b'foo_bar')