auth.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. import asyncio
  2. import functools
  3. import secrets
  4. import threading
  5. import time
  6. from abc import ABC, abstractmethod
  7. from enum import Enum
  8. from datetime import timedelta
  9. from typing import Optional, Tuple
  10. import grpc
  11. from hivemind.proto.auth_pb2 import AccessToken, RequestAuthInfo, ResponseAuthInfo
  12. from hivemind.utils.crypto import RSAPrivateKey, RSAPublicKey
  13. from hivemind.utils.logging import get_logger
  14. from hivemind.utils.timed_storage import TimedStorage, get_dht_time
# Module-level logger; the authorizers in this module report why a request/response was rejected at debug level.
logger = get_logger(__name__)
  16. class AuthorizedRequestBase:
  17. """
  18. Interface for protobufs with the ``RequestAuthInfo auth`` field. Used for type annotations only.
  19. """
  20. auth: RequestAuthInfo
  21. class AuthorizedResponseBase:
  22. """
  23. Interface for protobufs with the ``ResponseAuthInfo auth`` field. Used for type annotations only.
  24. """
  25. auth: ResponseAuthInfo
  26. class AuthorizerBase(ABC):
  27. @abstractmethod
  28. async def sign_request(self, request: AuthorizedRequestBase, service_public_key: Optional[RSAPublicKey]) -> None:
  29. ...
  30. @abstractmethod
  31. async def validate_request(self, request: AuthorizedRequestBase) -> bool:
  32. ...
  33. @abstractmethod
  34. async def sign_response(self, response: AuthorizedResponseBase, request: AuthorizedRequestBase) -> None:
  35. ...
  36. @abstractmethod
  37. async def validate_response(self, response: AuthorizedResponseBase, request: AuthorizedRequestBase) -> bool:
  38. ...
  39. class TokenAuthorizerBase(AuthorizerBase):
  40. """
  41. Implements the authorization protocol for a moderated Hivemind network.
  42. See https://github.com/learning-at-home/hivemind/issues/253
  43. """
  44. def __init__(self, local_private_key: Optional[RSAPrivateKey]=None):
  45. if local_private_key is None:
  46. local_private_key = RSAPrivateKey.process_wide()
  47. self._local_private_key = local_private_key
  48. self._local_public_key = local_private_key.get_public_key()
  49. self._local_access_token = None
  50. self._refresh_lock = asyncio.Lock()
  51. self._recent_nonces = TimedStorage()
  52. @abstractmethod
  53. async def get_token(self) -> AccessToken:
  54. ...
  55. @abstractmethod
  56. def is_token_valid(self, access_token: AccessToken) -> bool:
  57. ...
  58. @abstractmethod
  59. def does_token_need_refreshing(self, access_token: AccessToken) -> bool:
  60. ...
  61. async def refresh_token_if_needed(self) -> None:
  62. if self._local_access_token is None or self.does_token_need_refreshing(self._local_access_token):
  63. async with self._refresh_lock:
  64. if self._local_access_token is None or self.does_token_need_refreshing(self._local_access_token):
  65. self._local_access_token = await self.get_token()
  66. assert self.is_token_valid(self._local_access_token)
  67. @property
  68. def local_public_key(self) -> RSAPublicKey:
  69. return self._local_public_key
  70. async def sign_request(self, request: AuthorizedRequestBase, service_public_key: Optional[RSAPublicKey]) -> None:
  71. await self.refresh_token_if_needed()
  72. auth = request.auth
  73. auth.client_access_token.CopyFrom(self._local_access_token)
  74. if service_public_key is not None:
  75. auth.service_public_key = service_public_key.to_bytes()
  76. auth.time = get_dht_time()
  77. auth.nonce = secrets.token_bytes(8)
  78. assert auth.signature == b''
  79. auth.signature = self._local_private_key.sign(request.SerializeToString())
  80. _MAX_CLIENT_SERVICER_TIME_DIFF = timedelta(minutes=1)
  81. async def validate_request(self, request: AuthorizedRequestBase) -> bool:
  82. await self.refresh_token_if_needed()
  83. auth = request.auth
  84. if not self.is_token_valid(auth.client_access_token):
  85. logger.debug('Client failed to prove that it (still) has access to the network')
  86. return False
  87. client_public_key = RSAPublicKey.from_bytes(auth.client_access_token.public_key)
  88. signature = auth.signature
  89. auth.signature = b''
  90. if not client_public_key.verify(request.SerializeToString(), signature):
  91. logger.debug('Request has invalid signature')
  92. return False
  93. if auth.service_public_key and auth.service_public_key != self._local_public_key.to_bytes():
  94. logger.debug('Request is generated for a peer with another public key')
  95. return False
  96. with self._recent_nonces.freeze():
  97. current_time = get_dht_time()
  98. if abs(auth.time - current_time) > self._MAX_CLIENT_SERVICER_TIME_DIFF.total_seconds():
  99. logger.debug('Clocks are not synchronized or a previous request is replayed again')
  100. return False
  101. if auth.nonce in self._recent_nonces:
  102. logger.debug('Previous request is replayed again')
  103. return False
  104. self._recent_nonces.store(auth.nonce, None,
  105. current_time + self._MAX_CLIENT_SERVICER_TIME_DIFF.total_seconds() * 3)
  106. return True
  107. async def sign_response(self, response: AuthorizedResponseBase, request: AuthorizedRequestBase) -> None:
  108. await self.refresh_token_if_needed()
  109. auth = response.auth
  110. auth.service_access_token.CopyFrom(self._local_access_token)
  111. auth.nonce = request.auth.nonce
  112. assert auth.signature == b''
  113. auth.signature = self._local_private_key.sign(response.SerializeToString())
  114. async def validate_response(self, response: AuthorizedResponseBase, request: AuthorizedRequestBase) -> bool:
  115. await self.refresh_token_if_needed()
  116. auth = response.auth
  117. if not self.is_token_valid(auth.service_access_token):
  118. logger.debug('Service failed to prove that it (still) has access to the network')
  119. return False
  120. service_public_key = RSAPublicKey.from_bytes(auth.service_access_token.public_key)
  121. signature = auth.signature
  122. auth.signature = b''
  123. if not service_public_key.verify(response.SerializeToString(), signature):
  124. logger.debug('Response has invalid signature')
  125. return False
  126. if auth.nonce != request.auth.nonce:
  127. logger.debug('Response is generated for another request')
  128. return False
  129. return True
  130. class AuthRole(Enum):
  131. CLIENT = 0
  132. SERVICER = 1
  133. class AuthRPCWrapper:
  134. def __init__(self, stub, role: AuthRole,
  135. authorizer: Optional[AuthorizerBase], service_public_key: Optional[RSAPublicKey]=None):
  136. self._stub = stub
  137. self._role = role
  138. self._authorizer = authorizer
  139. self._service_public_key = service_public_key
  140. def __getattribute__(self, name: str):
  141. if not name.startswith('rpc_'):
  142. return object.__getattribute__(self, name)
  143. method = getattr(self._stub, name)
  144. @functools.wraps(method)
  145. async def wrapped_rpc(request: AuthorizedRequestBase, *args, **kwargs):
  146. if self._authorizer is not None:
  147. if self._role == AuthRole.CLIENT:
  148. await self._authorizer.sign_request(request, self._service_public_key)
  149. elif self._role == AuthRole.SERVICER:
  150. if not await self._authorizer.validate_request(request):
  151. return None
  152. response = await method(request, *args, **kwargs)
  153. if self._authorizer is not None:
  154. if self._role == AuthRole.SERVICER:
  155. await self._authorizer.sign_response(response, request)
  156. elif self._role == AuthRole.CLIENT:
  157. if not await self._authorizer.validate_response(response, request):
  158. return None
  159. return response
  160. return wrapped_rpc