crypto.py 3.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. import dataclasses
  2. import re
  3. from typing import Optional
  4. from hivemind.dht.validation import DHTRecord, RecordValidatorBase
  5. from hivemind.utils import MSGPackSerializer, get_logger
  6. from hivemind.utils.crypto import RSAPrivateKey, RSAPublicKey
# Module-level logger from hivemind's get_logger; the validator below reports rejected records through it.
logger = get_logger(__name__)
  8. class RSASignatureValidator(RecordValidatorBase):
  9. """
  10. Introduces a notion of *protected records* whose key/subkey contains substring
  11. "[owner:ssh-rsa ...]" with an RSA public key of the owner.
  12. If this validator is used, changes to such records always must be signed with
  13. the corresponding private key (so only the owner can change them).
  14. """
  15. PUBLIC_KEY_FORMAT = b"[owner:_key_]"
  16. SIGNATURE_FORMAT = b"[signature:_value_]"
  17. PUBLIC_KEY_REGEX = re.escape(PUBLIC_KEY_FORMAT).replace(b"_key_", rb"(.+?)")
  18. _PUBLIC_KEY_RE = re.compile(PUBLIC_KEY_REGEX)
  19. _SIGNATURE_RE = re.compile(re.escape(SIGNATURE_FORMAT).replace(b"_value_", rb"(.+?)"))
  20. _cached_private_key = None
  21. def __init__(self, private_key: Optional[RSAPrivateKey] = None):
  22. if private_key is None:
  23. private_key = RSAPrivateKey.process_wide()
  24. self._private_key = private_key
  25. serialized_public_key = private_key.get_public_key().to_bytes()
  26. self._local_public_key = self.PUBLIC_KEY_FORMAT.replace(b"_key_", serialized_public_key)
  27. @property
  28. def local_public_key(self) -> bytes:
  29. return self._local_public_key
  30. def validate(self, record: DHTRecord) -> bool:
  31. public_keys = self._PUBLIC_KEY_RE.findall(record.key)
  32. if record.subkey is not None:
  33. public_keys += self._PUBLIC_KEY_RE.findall(record.subkey)
  34. if not public_keys:
  35. return True # The record is not protected with a public key
  36. if len(set(public_keys)) > 1:
  37. logger.debug(f"Key and subkey can't contain different public keys in {record}")
  38. return False
  39. public_key = RSAPublicKey.from_bytes(public_keys[0])
  40. signatures = self._SIGNATURE_RE.findall(record.value)
  41. if len(signatures) != 1:
  42. logger.debug(f"Record should have exactly one signature in {record}")
  43. return False
  44. signature = signatures[0]
  45. stripped_record = dataclasses.replace(record, value=self.strip_value(record))
  46. if not public_key.verify(self._serialize_record(stripped_record), signature):
  47. logger.debug(f"Signature is invalid in {record}")
  48. return False
  49. return True
  50. def sign_value(self, record: DHTRecord) -> bytes:
  51. if self._local_public_key not in record.key and self._local_public_key not in record.subkey:
  52. return record.value
  53. signature = self._private_key.sign(self._serialize_record(record))
  54. return record.value + self.SIGNATURE_FORMAT.replace(b"_value_", signature)
  55. def strip_value(self, record: DHTRecord) -> bytes:
  56. return self._SIGNATURE_RE.sub(b"", record.value)
  57. def _serialize_record(self, record: DHTRecord) -> bytes:
  58. return MSGPackSerializer.dumps(dataclasses.astuple(record))
  59. @property
  60. def priority(self) -> int:
  61. # On validation, this validator must be executed before validators
  62. # that deserialize the record
  63. return 10
  64. def merge_with(self, other: RecordValidatorBase) -> bool:
  65. if not isinstance(other, RSASignatureValidator):
  66. return False
  67. # Ignore another RSASignatureValidator instance (it doesn't make sense to have several
  68. # instances of this class) and report successful merge
  69. return True