crypto.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. from __future__ import annotations
  2. import base64
  3. import threading
  4. from abc import ABC, abstractmethod
  5. from cryptography import exceptions
  6. from cryptography.hazmat.primitives import hashes, serialization
  7. from cryptography.hazmat.primitives.asymmetric import padding, rsa
  8. class PrivateKey(ABC):
  9. @abstractmethod
  10. def sign(self, data: bytes) -> bytes:
  11. ...
  12. @abstractmethod
  13. def get_public_key(self) -> PublicKey:
  14. ...
  15. class PublicKey(ABC):
  16. @abstractmethod
  17. def verify(self, data: bytes, signature: bytes) -> bool:
  18. ...
  19. @abstractmethod
  20. def to_bytes(self) -> bytes:
  21. ...
  22. @classmethod
  23. @abstractmethod
  24. def from_bytes(cls, key: bytes) -> bytes:
  25. ...
  26. _RSA_PADDING = padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH)
  27. _RSA_HASH_ALGORITHM = hashes.SHA256()
  28. class RSAPrivateKey(PrivateKey):
  29. def __init__(self):
  30. self._private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
  31. _process_wide_key = None
  32. _process_wide_key_lock = threading.RLock()
  33. @classmethod
  34. def process_wide(cls) -> RSAPrivateKey:
  35. if cls._process_wide_key is None:
  36. with cls._process_wide_key_lock:
  37. if cls._process_wide_key is None:
  38. cls._process_wide_key = cls()
  39. return cls._process_wide_key
  40. def sign(self, data: bytes) -> bytes:
  41. signature = self._private_key.sign(data, _RSA_PADDING, _RSA_HASH_ALGORITHM)
  42. return base64.b64encode(signature)
  43. def get_public_key(self) -> RSAPublicKey:
  44. return RSAPublicKey(self._private_key.public_key())
  45. def __getstate__(self):
  46. state = self.__dict__.copy()
  47. # Serializes the private key to make the class instances picklable
  48. state["_private_key"] = self._private_key.private_bytes(
  49. encoding=serialization.Encoding.PEM,
  50. format=serialization.PrivateFormat.OpenSSH,
  51. encryption_algorithm=serialization.NoEncryption(),
  52. )
  53. return state
  54. def __setstate__(self, state):
  55. self.__dict__.update(state)
  56. self._private_key = serialization.load_ssh_private_key(self._private_key, password=None)
  57. class RSAPublicKey(PublicKey):
  58. def __init__(self, public_key: rsa.RSAPublicKey):
  59. self._public_key = public_key
  60. def verify(self, data: bytes, signature: bytes) -> bool:
  61. try:
  62. signature = base64.b64decode(signature)
  63. # Returns None if the signature is correct, raises an exception otherwise
  64. self._public_key.verify(signature, data, _RSA_PADDING, _RSA_HASH_ALGORITHM)
  65. return True
  66. except (ValueError, exceptions.InvalidSignature):
  67. return False
  68. def to_bytes(self) -> bytes:
  69. return self._public_key.public_bytes(
  70. encoding=serialization.Encoding.OpenSSH, format=serialization.PublicFormat.OpenSSH
  71. )
  72. @classmethod
  73. def from_bytes(cls, key: bytes) -> RSAPublicKey:
  74. key = serialization.load_ssh_public_key(key)
  75. if not isinstance(key, rsa.RSAPublicKey):
  76. raise ValueError(f"Expected an RSA public key, got {key}")
  77. return cls(key)