12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091 |
- import dataclasses
- import re
- from typing import Optional
- from hivemind.dht.validation import DHTRecord, RecordValidatorBase
- from hivemind.utils import MSGPackSerializer, get_logger
- from hivemind.utils.crypto import RSAPrivateKey, RSAPublicKey
# Module-level logger; used below to report (at debug level) why a record failed validation.
logger = get_logger(__name__)
class RSASignatureValidator(RecordValidatorBase):
    """
    Introduces a notion of *protected records* whose key/subkey contains substring
    "[owner:ssh-rsa ...]" with an RSA public key of the owner.

    If this validator is used, changes to such records always must be signed with
    the corresponding private key (so only the owner can change them).

    The owner marker is ``PUBLIC_KEY_FORMAT`` with ``_key_`` replaced by the serialized
    public key. The signature is appended to the record value as ``SIGNATURE_FORMAT``
    with ``_value_`` replaced by the signature bytes.
    """

    PUBLIC_KEY_FORMAT = b"[owner:_key_]"
    SIGNATURE_FORMAT = b"[signature:_value_]"

    # Bytes regex whose single group captures the serialized public key inside an
    # owner marker; the non-greedy group ends at the next closing "]".
    PUBLIC_KEY_REGEX = re.escape(PUBLIC_KEY_FORMAT).replace(b"_key_", rb"(.+?)")
    _PUBLIC_KEY_RE = re.compile(PUBLIC_KEY_REGEX)
    # Same idea for signature markers embedded in a record value.
    _SIGNATURE_RE = re.compile(re.escape(SIGNATURE_FORMAT).replace(b"_value_", rb"(.+?)"))

    # Not referenced by this class's methods; kept so any external references keep working.
    _cached_private_key = None

    def __init__(self, private_key: Optional[RSAPrivateKey] = None):
        """
        :param private_key: key used to sign records owned by this peer; when omitted,
            ``RSAPrivateKey.process_wide()`` is used instead
            (presumably a key shared across the process — confirm in RSAPrivateKey)
        """
        if private_key is None:
            private_key = RSAPrivateKey.process_wide()
        self._private_key = private_key

        serialized_public_key = private_key.get_public_key().to_bytes()
        # The owner marker a key/subkey must contain for this peer to sign the record.
        self._local_public_key = self.PUBLIC_KEY_FORMAT.replace(b"_key_", serialized_public_key)

    @property
    def local_public_key(self) -> bytes:
        """Owner marker (``[owner:<serialized public key>]``) for this validator's key."""
        return self._local_public_key

    def validate(self, record: DHTRecord) -> bool:
        """
        Check that a protected record carries a valid signature of its owner.

        Records whose key and subkey contain no owner marker are accepted as-is.
        Otherwise every marker must name the same public key, the value must contain
        exactly one signature marker, and that signature must verify against the record
        serialized with all signature markers stripped from its value.

        :returns: True if the record is acceptable, False otherwise (the reason is
            logged at debug level)
        """
        public_keys = self._PUBLIC_KEY_RE.findall(record.key)
        if record.subkey is not None:
            public_keys += self._PUBLIC_KEY_RE.findall(record.subkey)
        if not public_keys:
            return True  # The record is not protected with a public key

        if len(set(public_keys)) > 1:
            logger.debug(f"Key and subkey can't contain different public keys in {record}")
            return False
        # NOTE(review): the key bytes come from the record itself; if
        # RSAPublicKey.from_bytes / verify raise on malformed input, that exception
        # propagates out of validate() instead of yielding False — confirm whether
        # callers handle it.
        public_key = RSAPublicKey.from_bytes(public_keys[0])

        signatures = self._SIGNATURE_RE.findall(record.value)
        if len(signatures) != 1:
            logger.debug(f"Record should have exactly one signature in {record}")
            return False
        signature = signatures[0]

        # The signature was computed over the record *before* the marker was appended,
        # so verify against the value with signature markers removed.
        stripped_record = dataclasses.replace(record, value=self.strip_value(record))
        if not public_key.verify(self._serialize_record(stripped_record), signature):
            logger.debug(f"Signature is invalid in {record}")
            return False
        return True

    def sign_value(self, record: DHTRecord) -> bytes:
        """
        Return the value to store for ``record``, signed if this peer owns the record.

        If the local owner marker appears in the record's key or subkey, the record
        (as given) is serialized and signed with the local private key, and a signature
        marker is appended to its value. Otherwise the value is returned unchanged.
        """
        # subkey may be None (validate() handles that case too); `bytes in None`
        # would raise TypeError, so only search the subkey when it is present.
        owned_locally = self._local_public_key in record.key or (
            record.subkey is not None and self._local_public_key in record.subkey
        )
        if not owned_locally:
            return record.value

        signature = self._private_key.sign(self._serialize_record(record))
        return record.value + self.SIGNATURE_FORMAT.replace(b"_value_", signature)

    def strip_value(self, record: DHTRecord) -> bytes:
        """Return the record's value with every signature marker removed."""
        return self._SIGNATURE_RE.sub(b"", record.value)

    def _serialize_record(self, record: DHTRecord) -> bytes:
        """
        Serialize all dataclass fields of ``record`` (as a tuple) with MSGPack.

        This is the byte string that is signed in sign_value() and verified in
        validate(), so the signature covers every field, not only the value.
        """
        return MSGPackSerializer.dumps(dataclasses.astuple(record))

    @property
    def priority(self) -> int:
        # On validation, this validator must be executed before validators
        # that deserialize the record
        return 10

    def merge_with(self, other: RecordValidatorBase) -> bool:
        """Absorb another RSASignatureValidator; refuse any other validator type."""
        if not isinstance(other, RSASignatureValidator):
            return False

        # Ignore another RSASignatureValidator instance (it doesn't make sense to have several
        # instances of this class) and report successful merge
        return True
|