auth.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. import asyncio
  2. import functools
  3. import secrets
  4. from abc import ABC, abstractmethod
  5. from datetime import timedelta
  6. from enum import Enum
  7. from typing import Optional
  8. from hivemind.proto.auth_pb2 import AccessToken, RequestAuthInfo, ResponseAuthInfo
  9. from hivemind.utils.crypto import RSAPrivateKey, RSAPublicKey
  10. from hivemind.utils.logging import get_logger
  11. from hivemind.utils.timed_storage import TimedStorage, get_dht_time
# Module-level logger named after this module; the authorizer below reports rejected requests/responses via logger.debug
logger = get_logger(__name__)
  13. class AuthorizedRequestBase:
  14. """
  15. Interface for protobufs with the ``RequestAuthInfo auth`` field. Used for type annotations only.
  16. """
  17. auth: RequestAuthInfo
  18. class AuthorizedResponseBase:
  19. """
  20. Interface for protobufs with the ``ResponseAuthInfo auth`` field. Used for type annotations only.
  21. """
  22. auth: ResponseAuthInfo
  23. class AuthorizerBase(ABC):
  24. @abstractmethod
  25. async def sign_request(self, request: AuthorizedRequestBase, service_public_key: Optional[RSAPublicKey]) -> None:
  26. ...
  27. @abstractmethod
  28. async def validate_request(self, request: AuthorizedRequestBase) -> bool:
  29. ...
  30. @abstractmethod
  31. async def sign_response(self, response: AuthorizedResponseBase, request: AuthorizedRequestBase) -> None:
  32. ...
  33. @abstractmethod
  34. async def validate_response(self, response: AuthorizedResponseBase, request: AuthorizedRequestBase) -> bool:
  35. ...
  36. class TokenAuthorizerBase(AuthorizerBase):
  37. """
  38. Implements the authorization protocol for a moderated Hivemind network.
  39. See https://github.com/learning-at-home/hivemind/issues/253
  40. """
  41. def __init__(self, local_private_key: Optional[RSAPrivateKey] = None):
  42. if local_private_key is None:
  43. local_private_key = RSAPrivateKey.process_wide()
  44. self._local_private_key = local_private_key
  45. self._local_public_key = local_private_key.get_public_key()
  46. self._local_access_token = None
  47. self._refresh_lock = asyncio.Lock()
  48. self._recent_nonces = TimedStorage()
  49. @abstractmethod
  50. async def get_token(self) -> AccessToken:
  51. ...
  52. @abstractmethod
  53. def is_token_valid(self, access_token: AccessToken) -> bool:
  54. ...
  55. @abstractmethod
  56. def does_token_need_refreshing(self, access_token: AccessToken) -> bool:
  57. ...
  58. async def refresh_token_if_needed(self) -> None:
  59. if self._local_access_token is None or self.does_token_need_refreshing(self._local_access_token):
  60. async with self._refresh_lock:
  61. if self._local_access_token is None or self.does_token_need_refreshing(self._local_access_token):
  62. self._local_access_token = await self.get_token()
  63. assert self.is_token_valid(self._local_access_token)
  64. @property
  65. def local_public_key(self) -> RSAPublicKey:
  66. return self._local_public_key
  67. async def sign_request(self, request: AuthorizedRequestBase, service_public_key: Optional[RSAPublicKey]) -> None:
  68. await self.refresh_token_if_needed()
  69. auth = request.auth
  70. auth.client_access_token.CopyFrom(self._local_access_token)
  71. if service_public_key is not None:
  72. auth.service_public_key = service_public_key.to_bytes()
  73. auth.time = get_dht_time()
  74. auth.nonce = secrets.token_bytes(8)
  75. assert auth.signature == b""
  76. auth.signature = self._local_private_key.sign(request.SerializeToString())
  77. _MAX_CLIENT_SERVICER_TIME_DIFF = timedelta(minutes=1)
  78. async def validate_request(self, request: AuthorizedRequestBase) -> bool:
  79. await self.refresh_token_if_needed()
  80. auth = request.auth
  81. if not self.is_token_valid(auth.client_access_token):
  82. logger.debug("Client failed to prove that it (still) has access to the network")
  83. return False
  84. client_public_key = RSAPublicKey.from_bytes(auth.client_access_token.public_key)
  85. signature = auth.signature
  86. auth.signature = b""
  87. if not client_public_key.verify(request.SerializeToString(), signature):
  88. logger.debug("Request has invalid signature")
  89. return False
  90. if auth.service_public_key and auth.service_public_key != self._local_public_key.to_bytes():
  91. logger.debug("Request is generated for a peer with another public key")
  92. return False
  93. with self._recent_nonces.freeze():
  94. current_time = get_dht_time()
  95. if abs(auth.time - current_time) > self._MAX_CLIENT_SERVICER_TIME_DIFF.total_seconds():
  96. logger.debug("Clocks are not synchronized or a previous request is replayed again")
  97. return False
  98. if auth.nonce in self._recent_nonces:
  99. logger.debug("Previous request is replayed again")
  100. return False
  101. self._recent_nonces.store(
  102. auth.nonce, None, current_time + self._MAX_CLIENT_SERVICER_TIME_DIFF.total_seconds() * 3
  103. )
  104. return True
  105. async def sign_response(self, response: AuthorizedResponseBase, request: AuthorizedRequestBase) -> None:
  106. await self.refresh_token_if_needed()
  107. auth = response.auth
  108. auth.service_access_token.CopyFrom(self._local_access_token)
  109. auth.nonce = request.auth.nonce
  110. assert auth.signature == b""
  111. auth.signature = self._local_private_key.sign(response.SerializeToString())
  112. async def validate_response(self, response: AuthorizedResponseBase, request: AuthorizedRequestBase) -> bool:
  113. await self.refresh_token_if_needed()
  114. auth = response.auth
  115. if not self.is_token_valid(auth.service_access_token):
  116. logger.debug("Service failed to prove that it (still) has access to the network")
  117. return False
  118. service_public_key = RSAPublicKey.from_bytes(auth.service_access_token.public_key)
  119. signature = auth.signature
  120. auth.signature = b""
  121. if not service_public_key.verify(response.SerializeToString(), signature):
  122. logger.debug("Response has invalid signature")
  123. return False
  124. if auth.nonce != request.auth.nonce:
  125. logger.debug("Response is generated for another request")
  126. return False
  127. return True
  128. class AuthRole(Enum):
  129. CLIENT = 0
  130. SERVICER = 1
  131. class AuthRPCWrapper:
  132. def __init__(
  133. self,
  134. stub,
  135. role: AuthRole,
  136. authorizer: Optional[AuthorizerBase],
  137. service_public_key: Optional[RSAPublicKey] = None,
  138. ):
  139. self._stub = stub
  140. self._role = role
  141. self._authorizer = authorizer
  142. self._service_public_key = service_public_key
  143. def __getattribute__(self, name: str):
  144. if not name.startswith("rpc_"):
  145. return object.__getattribute__(self, name)
  146. method = getattr(self._stub, name)
  147. @functools.wraps(method)
  148. async def wrapped_rpc(request: AuthorizedRequestBase, *args, **kwargs):
  149. if self._authorizer is not None:
  150. if self._role == AuthRole.CLIENT:
  151. await self._authorizer.sign_request(request, self._service_public_key)
  152. elif self._role == AuthRole.SERVICER:
  153. if not await self._authorizer.validate_request(request):
  154. return None
  155. response = await method(request, *args, **kwargs)
  156. if self._authorizer is not None:
  157. if self._role == AuthRole.SERVICER:
  158. await self._authorizer.sign_response(response, request)
  159. elif self._role == AuthRole.CLIENT:
  160. if not await self._authorizer.validate_response(response, request):
  161. return None
  162. return response
  163. return wrapped_rpc