Przeglądaj źródła

Add RSA signature protection for DHT records (#187)

This PR introduces a notion of protected DHT records whose key/subkey contains substring [owner:ssh-rsa ...] (the format can be changed) with an RSA public key of the owner. Changes to such records always must be signed with the corresponding private key (so only the owner can change them). This protects from malicious nodes trying to spoil the DHT contents.
Aleksandr Borzunov 4 lat temu
rodzic
commit
1deab01c71

+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.utils import *
 
-__version__ = '0.9.6'
+__version__ = '0.9.7'

+ 88 - 0
hivemind/dht/crypto.py

@@ -0,0 +1,88 @@
+import base64
+import dataclasses
+import re
+
+from cryptography.exceptions import InvalidSignature
+from cryptography.hazmat.primitives import hashes, serialization
+from cryptography.hazmat.primitives.asymmetric import padding, rsa
+
+from hivemind.dht.validation import DHTRecord, RecordValidatorBase
+from hivemind.utils import MSGPackSerializer, get_logger
+
+
+logger = get_logger(__name__)
+
+
+class RSASignatureValidator(RecordValidatorBase):
+    """
+    Introduces a notion of *protected records* whose key/subkey contains substring
+    "[owner:ssh-rsa ...]" (the format can be changed) with an RSA public key of the owner.
+
+    If this validator is used, changes to such records always must be signed with
+    the corresponding private key (so only the owner can change them).
+    """
+
+    def __init__(self,
+                 marker_format: bytes=b'[owner:_key_]',
+                 signature_format: bytes=b'[signature:_value_]'):
+        self._marker_re = re.compile(re.escape(marker_format).replace(b'_key_', rb'(.+?)'))
+
+        self._signature_format = signature_format
+        self._signature_re = re.compile(re.escape(signature_format).replace(b'_value_', rb'(.+?)'))
+
+        self._private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
+
+        serialized_public_key = self._private_key.public_key().public_bytes(
+            encoding=serialization.Encoding.OpenSSH, format=serialization.PublicFormat.OpenSSH)
+        self._ownership_marker = marker_format.replace(b'_key_', serialized_public_key)
+
+        self._padding = padding.PSS(mgf=padding.MGF1(hashes.SHA256()),
+                                    salt_length=padding.PSS.MAX_LENGTH)
+        self._hash_algorithm = hashes.SHA256()
+
+    @property
+    def ownership_marker(self) -> bytes:
+        return self._ownership_marker
+
+    def validate(self, record: DHTRecord) -> bool:
+        public_keys = self._marker_re.findall(record.key)
+        if record.subkey is not None:
+            public_keys += self._marker_re.findall(record.subkey)
+        if not public_keys:
+            return True  # The record is not protected with a public key
+
+        if len(set(public_keys)) > 1:
+            logger.debug(f"Key and subkey can't contain different public keys in {record}")
+            return False
+        public_key = serialization.load_ssh_public_key(public_keys[0])
+
+        signatures = self._signature_re.findall(record.value)
+        if len(signatures) != 1:
+            logger.debug(f"Record should have exactly one signature in {record}")
+            return False
+        signature = base64.b64decode(signatures[0])
+
+        stripped_record = dataclasses.replace(record, value=self.strip_value(record))
+        try:
+            # verify() returns None iff the signature is correct
+            public_key.verify(signature, self._serialize_record(stripped_record),
+                              self._padding, self._hash_algorithm)
+            return True
+        except InvalidSignature:
+            logger.debug(f'Signature is invalid in {record}')
+            return False
+
+    def sign_value(self, record: DHTRecord) -> bytes:
+        if self._ownership_marker not in record.key and self._ownership_marker not in record.subkey:
+            return record.value
+
+        signature = self._private_key.sign(self._serialize_record(record),
+                                           self._padding, self._hash_algorithm)
+        signature = base64.b64encode(signature)
+        return record.value + self._signature_format.replace(b'_value_', signature)
+
+    def strip_value(self, record: DHTRecord) -> bytes:
+        return self._signature_re.sub(b'', record.value)
+
+    def _serialize_record(self, record: DHTRecord) -> bytes:
+        return MSGPackSerializer.dumps(dataclasses.astuple(record))

+ 37 - 6
hivemind/dht/node.py

@@ -9,6 +9,7 @@ from typing import Optional, Tuple, List, Dict, DefaultDict, Collection, Union,
 
 from sortedcontainers import SortedSet
 
+from hivemind.dht.crypto import DHTRecord, RecordValidatorBase
 from hivemind.dht.protocol import DHTProtocol
 from hivemind.dht.routing import DHTID, DHTExpiration, DHTKey, get_dht_time, DHTValue, BinaryDHTValue, Subkey
 from hivemind.dht.storage import DictionaryDHTValue
@@ -80,6 +81,7 @@ class DHTNode:
             cache_on_store: bool = True, reuse_get_requests: bool = True, num_workers: int = 1, chunk_size: int = 16,
             blacklist_time: float = 5.0, backoff_rate: float = 2.0,
             listen: bool = True, listen_on: Endpoint = "0.0.0.0:*", endpoint: Optional[Endpoint] = None,
+            record_validator: Optional[RecordValidatorBase] = None,
             validate: bool = True, strict: bool = True, **kwargs) -> DHTNode:
         """
         :param node_id: current node's identifier, determines which keys it will store locally, defaults to random id
@@ -136,7 +138,8 @@ class DHTNode:
         self.cache_refresh_task = None
 
         self.protocol = await DHTProtocol.create(self.node_id, bucket_size, depth_modulo, num_replicas, wait_timeout,
-                                                 parallel_rpc, cache_size, listen, listen_on, endpoint, **kwargs)
+                                                 parallel_rpc, cache_size, listen, listen_on, endpoint, record_validator,
+                                                 **kwargs)
         self.port = self.protocol.port
 
         if initial_peers:
@@ -297,7 +300,17 @@ class DHTNode:
             store_candidates = sorted(nearest_nodes + ([] if exclude_self else [self.node_id]),
                                       key=key_id.xor_distance, reverse=True)  # ordered so that .pop() returns nearest
             [original_key, *_], current_subkeys, current_values, current_expirations = zip(*key_id_to_data[key_id])
-            binary_values: List[bytes] = list(map(self.protocol.serializer.dumps, current_values))
+
+            key_bytes = key_id.to_bytes()
+            binary_values = []
+            for subkey, value, expiration_time in zip(
+                    current_subkeys, current_values, current_expirations):
+                value_bytes = self.protocol.serializer.dumps(value)
+                if self.protocol.record_validator is not None:
+                    subkey_bytes = self.protocol.serializer.dumps(subkey)
+                    record = DHTRecord(key_bytes, subkey_bytes, value_bytes, expiration_time)
+                    value_bytes = self.protocol.record_validator.sign_value(record)
+                binary_values.append(value_bytes)
 
             while num_successful_stores < self.num_replicas and (store_candidates or pending_store_tasks):
                 while store_candidates and num_successful_stores + len(pending_store_tasks) < self.num_replicas:
@@ -420,7 +433,9 @@ class DHTNode:
         beam_size = beam_size if beam_size is not None else self.protocol.bucket_size
         num_workers = num_workers if num_workers is not None else self.num_workers
         search_results: Dict[DHTID, _SearchState] = {key_id: _SearchState(
-            key_id, sufficient_expiration_time, serializer=self.protocol.serializer) for key_id in key_ids}
+            key_id, sufficient_expiration_time,
+            serializer=self.protocol.serializer,
+            record_validator=self.protocol.record_validator) for key_id in key_ids}
 
         if not _is_refresh:  # if we're already refreshing cache, there's no need to trigger subsequent refreshes
             for key_id in key_ids:
@@ -609,6 +624,7 @@ class _SearchState:
     source_node_id: Optional[DHTID] = None  # node that gave us the value
     future: asyncio.Future[Optional[ValueWithExpiration[DHTValue]]] = field(default_factory=asyncio.Future)
     serializer: type(SerializerBase) = MSGPackSerializer
+    record_validator: Optional[RecordValidatorBase] = None
 
     def add_candidate(self, candidate: Optional[ValueWithExpiration[Union[BinaryDHTValue, DictionaryDHTValue]]],
                       source_node_id: Optional[DHTID]):
@@ -637,10 +653,25 @@ class _SearchState:
         elif not self.found_something:
             self.future.set_result(None)
         elif isinstance(self.binary_value, BinaryDHTValue):
-            self.future.set_result(ValueWithExpiration(self.serializer.loads(self.binary_value), self.expiration_time))
+            value_bytes = self.binary_value
+            if self.record_validator is not None:
+                record = DHTRecord(self.key_id.to_bytes(), DHTProtocol.IS_REGULAR_VALUE,
+                                   value_bytes, self.expiration_time)
+                value_bytes = self.record_validator.strip_value(record)
+
+            self.future.set_result(
+                ValueWithExpiration(self.serializer.loads(value_bytes), self.expiration_time))
         elif isinstance(self.binary_value, DictionaryDHTValue):
-            dict_with_subkeys = {key: ValueWithExpiration(self.serializer.loads(value), item_expiration_time)
-                                 for key, (value, item_expiration_time) in self.binary_value.items()}
+            dict_with_subkeys = {}
+            for subkey, (value_bytes, item_expiration_time) in self.binary_value.items():
+                if self.record_validator is not None:
+                    subkey_bytes = self.serializer.dumps(subkey)
+                    record = DHTRecord(self.key_id.to_bytes(), subkey_bytes,
+                                       value_bytes, item_expiration_time)
+                    value_bytes = self.record_validator.strip_value(record)
+
+                dict_with_subkeys[subkey] = ValueWithExpiration(
+                    self.serializer.loads(value_bytes), item_expiration_time)
             self.future.set_result(ValueWithExpiration(dict_with_subkeys, self.expiration_time))
         else:
             logger.error(f"Invalid value type: {type(self.binary_value)}")

+ 54 - 8
hivemind/dht/protocol.py

@@ -2,10 +2,12 @@
 from __future__ import annotations
 
 import asyncio
+from itertools import zip_longest
 from typing import Optional, List, Tuple, Dict, Any, Sequence, Union, Collection
 
 import grpc
 
+from hivemind.dht.crypto import DHTRecord, RecordValidatorBase
 from hivemind.dht.routing import RoutingTable, DHTID, BinaryDHTValue, DHTExpiration, Subkey
 from hivemind.dht.storage import DHTLocalStorage, DictionaryDHTValue
 from hivemind.proto import dht_pb2, dht_pb2_grpc as dht_grpc
@@ -20,6 +22,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
     node_id: DHTID; port: int; bucket_size: int; num_replicas: int; wait_timeout: float; node_info: dht_pb2.NodeInfo
     channel_options: Tuple[Tuple[str, Any]]; server: grpc.aio.Server
     storage: DHTLocalStorage; cache: DHTLocalStorage; routing_table: RoutingTable; rpc_semaphore: asyncio.Semaphore
+    record_validator: Optional[RecordValidatorBase]
     # fmt:on
 
     serializer = MSGPackSerializer  # used to pack/unpack DHT Values for transfer over network
@@ -30,6 +33,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
             cls, node_id: DHTID, bucket_size: int, depth_modulo: int, num_replicas: int, wait_timeout: float,
             parallel_rpc: Optional[int] = None, cache_size: Optional[int] = None,
             listen=True, listen_on='0.0.0.0:*', endpoint: Optional[Endpoint] = None,
+            record_validator: Optional[RecordValidatorBase] = None,
             channel_options: Sequence[Tuple[str, Any]] = (), **kwargs) -> DHTProtocol:
         """
         A protocol that allows DHT nodes to request keys/neighbors from other DHT nodes.
@@ -49,6 +53,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
         self.storage, self.cache = DHTLocalStorage(), DHTLocalStorage(maxsize=cache_size)
         self.routing_table = RoutingTable(node_id, bucket_size, depth_modulo)
         self.rpc_semaphore = asyncio.Semaphore(parallel_rpc if parallel_rpc is not None else float('inf'))
+        self.record_validator = record_validator
 
         if listen:  # set up server to process incoming rpc requests
             grpc.aio.init_grpc_aio()
@@ -221,17 +226,28 @@ class DHTProtocol(dht_grpc.DHTServicer):
             asyncio.create_task(self.rpc_ping(dht_pb2.PingRequest(peer=request.peer), context))
         assert len(request.keys) == len(request.values) == len(request.expiration_time) == len(request.in_cache)
         response = dht_pb2.StoreResponse(store_ok=[], peer=self.node_info)
-        keys = map(DHTID.from_bytes, request.keys)
-        for key_id, tag, value_bytes, expiration_time, in_cache in zip(
-                keys, request.subkeys, request.values, request.expiration_time, request.in_cache):
+        for key, tag, value_bytes, expiration_time, in_cache in zip(
+                request.keys, request.subkeys, request.values, request.expiration_time, request.in_cache):
+            key_id = DHTID.from_bytes(key)
             storage = self.cache if in_cache else self.storage
-            if tag == self.IS_REGULAR_VALUE:  # store normal value without subkeys
-                response.store_ok.append(storage.store(key_id, value_bytes, expiration_time))
-            elif tag == self.IS_DICTIONARY:  # store an entire dictionary with several subkeys
+
+            if tag == self.IS_DICTIONARY:  # store an entire dictionary with several subkeys
                 value_dictionary = self.serializer.loads(value_bytes)
                 assert isinstance(value_dictionary, DictionaryDHTValue)
+                if not self._validate_dictionary(key, value_dictionary):
+                    response.store_ok.append(False)
+                    continue
+
                 response.store_ok.append(all(storage.store_subkey(key_id, subkey, item.value, item.expiration_time)
                                              for subkey, item in value_dictionary.items()))
+                continue
+
+            if not self._validate_record(key, tag, value_bytes, expiration_time):
+                response.store_ok.append(False)
+                continue
+
+            if tag == self.IS_REGULAR_VALUE:  # store normal value without subkeys
+                response.store_ok.append(storage.store(key_id, value_bytes, expiration_time))
             else:  # add a new entry into an existing dictionary value or create a new dictionary with one sub-key
                 subkey = self.serializer.loads(tag)
                 response.store_ok.append(storage.store_subkey(key_id, subkey, value_bytes, expiration_time))
@@ -261,15 +277,25 @@ class DHTProtocol(dht_grpc.DHTServicer):
 
             output = {}  # unpack data depending on its type
             for key, result in zip(keys, response.results):
+                key_bytes = DHTID.to_bytes(key)
                 nearest = dict(zip(map(DHTID.from_bytes, result.nearest_node_ids), result.nearest_endpoints))
 
                 if result.type == dht_pb2.NOT_FOUND:
                     output[key] = None, nearest
                 elif result.type == dht_pb2.FOUND_REGULAR:
+                    if not self._validate_record(
+                            key_bytes, self.IS_REGULAR_VALUE, result.value, result.expiration_time):
+                        output[key] = None, nearest
+                        continue
+
                     output[key] = ValueWithExpiration(result.value, result.expiration_time), nearest
                 elif result.type == dht_pb2.FOUND_DICTIONARY:
-                    deserialized_dictionary = self.serializer.loads(result.value)
-                    output[key] = ValueWithExpiration(deserialized_dictionary, result.expiration_time), nearest
+                    value_dictionary = self.serializer.loads(result.value)
+                    if not self._validate_dictionary(key_bytes, value_dictionary):
+                        output[key] = None, nearest
+                        continue
+
+                    output[key] = ValueWithExpiration(value_dictionary, result.expiration_time), nearest
                 else:
                     logger.error(f"Unknown result type: {result.type}")
 
@@ -346,6 +372,26 @@ class DHTProtocol(dht_grpc.DHTServicer):
             if node_id is not None and node_id in self.routing_table:
                 del self.routing_table[node_id]
 
+    def _validate_record(self, key_bytes: bytes, subkey_bytes: bytes, value_bytes: bytes,
+                         expiration_time: float) -> bool:
+        if self.record_validator is None:
+            return True
+
+        record = DHTRecord(key_bytes, subkey_bytes, value_bytes, expiration_time)
+        return self.record_validator.validate(record)
+
+    def _validate_dictionary(self, key_bytes: bytes, dictionary: DictionaryDHTValue) -> bool:
+        if self.record_validator is None:
+            return True
+
+        with dictionary.freeze():
+            for subkey, (value_bytes, expiration_time) in dictionary.items():
+                subkey_bytes = self.serializer.dumps(subkey)
+                record = DHTRecord(key_bytes, subkey_bytes, value_bytes, expiration_time)
+                if not self.record_validator.validate(record):
+                    return False
+        return True
+
 
 class ValidationError(Exception):
     """ This exception is thrown if DHT node didn't pass validation by other nodes. """

+ 54 - 0
hivemind/dht/validation.py

@@ -0,0 +1,54 @@
+import dataclasses
+from abc import ABC, abstractmethod
+
+
+@dataclasses.dataclass(init=True, repr=True, frozen=True)
+class DHTRecord:
+    key: bytes
+    subkey: bytes
+    value: bytes
+    expiration_time: float
+
+
+class RecordValidatorBase(ABC):
+    """
+    Record validators are a generic mechanism for checking the DHT records including:
+      - Enforcing a data schema (e.g. checking content types)
+      - Enforcing security requirements (e.g. allowing only the owner to update the record)
+    """
+
+    @abstractmethod
+    def validate(self, record: DHTRecord) -> bool:
+        """
+        Should return whether the `record` is valid.
+        The valid records should have been extended with sign_value().
+
+        validate() is called when another DHT peer:
+          - Asks us to store the record
+          - Returns the record by our request
+        """
+
+        pass
+
+    def sign_value(self, record: DHTRecord) -> bytes:
+        """
+        Should return `record.value` extended with the record's signature.
+
+        Note: there's no need to overwrite this method if a validator doesn't use a signature.
+
+        sign_value() is called after the application asks the DHT to store the record.
+        """
+
+        return record.value
+
+    def strip_value(self, record: DHTRecord) -> bytes:
+        """
+        Should return `record.value` stripped of the record's signature.
+        strip_value() is only called if validate() was successful.
+
+        Note: there's no need to overwrite this method if a validator doesn't use a signature.
+
+        strip_value() is called before the DHT returns the record by the application's request.
+        """
+
+        return record.value

+ 2 - 1
requirements.txt

@@ -9,4 +9,5 @@ uvloop>=0.14.0
 grpcio>=1.33.2
 grpcio-tools>=1.33.2
 protobuf>=3.12.2
-configargparse>=1.2.3
+configargparse>=1.2.3
+cryptography>=3.4.6

+ 43 - 0
tests/test_dht_crypto.py

@@ -0,0 +1,43 @@
+import dataclasses
+
+import pytest
+
+from hivemind.dht import get_dht_time
+from hivemind.dht.crypto import RSASignatureValidator
+from hivemind.dht.validation import DHTRecord
+
+
+def test_rsa_signature_validator():
+    receiver_validator = RSASignatureValidator()
+    sender_validator = RSASignatureValidator()
+    mallory_validator = RSASignatureValidator()
+
+    plain_record = DHTRecord(key=b'key', subkey=b'subkey', value=b'value',
+                             expiration_time=get_dht_time() + 10)
+    protected_records = [
+        dataclasses.replace(plain_record,
+                            key=plain_record.key + sender_validator.ownership_marker),
+        dataclasses.replace(plain_record,
+                            subkey=plain_record.subkey + sender_validator.ownership_marker),
+    ]
+
+    # test 1: Non-protected record (no signature added)
+    assert sender_validator.sign_value(plain_record) == plain_record.value
+    assert receiver_validator.validate(plain_record)
+
+    # test 2: Correct signatures
+    signed_records = [dataclasses.replace(record, value=sender_validator.sign_value(record))
+                      for record in protected_records]
+    for record in signed_records:
+        assert receiver_validator.validate(record)
+        assert receiver_validator.strip_value(record) == b'value'
+
+    # test 3: Invalid signatures
+    signed_records = protected_records  # Without signature
+    signed_records += [dataclasses.replace(record,
+                                           value=record.value + b'[signature:INVALID_BYTES]')
+                       for record in protected_records]  # With invalid signature
+    signed_records += [dataclasses.replace(record, value=mallory_validator.sign_value(record))
+                       for record in protected_records]  # With someone else's signature
+    for record in signed_records:
+        assert not receiver_validator.validate(record)

+ 31 - 0
tests/test_dht_node.py

@@ -10,6 +10,7 @@ import pytest
 
 import hivemind
 from hivemind import get_dht_time, replace_port
+from hivemind.dht.crypto import RSASignatureValidator
 from hivemind.dht.node import DHTID, Endpoint, DHTNode, LOCALHOST
 from hivemind.dht.protocol import DHTProtocol, ValidationError
 from hivemind.dht.storage import DictionaryDHTValue
@@ -453,3 +454,33 @@ async def test_dhtnode_edge_cases():
         assert stored is not None
         assert subkey in stored.value
         assert stored.value[subkey].value == value
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_dhtnode_signatures():
+    alice = await hivemind.DHTNode.create(record_validator=RSASignatureValidator())
+    bob = await hivemind.DHTNode.create(
+        record_validator=RSASignatureValidator(), initial_peers=[f"{LOCALHOST}:{alice.port}"])
+    mallory = await hivemind.DHTNode.create(
+        record_validator=RSASignatureValidator(), initial_peers=[f"{LOCALHOST}:{alice.port}"])
+
+    key = b'key'
+    subkey = b'protected_subkey' + bob.protocol.record_validator.ownership_marker
+
+    assert await bob.store(key, b'true_value', hivemind.get_dht_time() + 10, subkey=subkey)
+    assert (await alice.get(key, latest=True)).value[subkey].value == b'true_value'
+
+    store_ok = await mallory.store(key, b'fake_value', hivemind.get_dht_time() + 10, subkey=subkey)
+    assert not store_ok
+    assert (await alice.get(key, latest=True)).value[subkey].value == b'true_value'
+
+    assert await bob.store(key, b'updated_true_value', hivemind.get_dht_time() + 10, subkey=subkey)
+    assert (await alice.get(key, latest=True)).value[subkey].value == b'updated_true_value'
+
+    await bob.shutdown()  # Bob has shut down, now Mallory is the single peer of Alice
+
+    store_ok = await mallory.store(key, b'updated_fake_value',
+                                   hivemind.get_dht_time() + 10, subkey=subkey)
+    assert not store_ok
+    assert (await alice.get(key, latest=True)).value[subkey].value == b'updated_true_value'