test_auth.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. from datetime import datetime, timedelta
  2. from typing import Optional
  3. import pytest
  4. from hivemind.proto import dht_pb2
  5. from hivemind.proto.auth_pb2 import AccessToken
  6. from hivemind.utils.auth import AuthRPCWrapper, AuthRole, TokenAuthorizerBase
  7. from hivemind.utils.crypto import RSAPrivateKey
  8. from hivemind.utils.logging import get_logger
# Module-level logger used by MockAuthorizer to report why a token was rejected.
logger = get_logger(__name__)
  10. class MockAuthorizer(TokenAuthorizerBase):
  11. _authority_private_key = None
  12. _authority_public_key = None
  13. def __init__(self, local_private_key: Optional[RSAPrivateKey], username: str='mock'):
  14. super().__init__(local_private_key)
  15. self._username = username
  16. self._authority_public_key = None
  17. async def get_token(self) -> AccessToken:
  18. if MockAuthorizer._authority_private_key is None:
  19. MockAuthorizer._authority_private_key = RSAPrivateKey()
  20. self._authority_public_key = MockAuthorizer._authority_private_key.get_public_key()
  21. token = AccessToken(username=self._username,
  22. public_key=self.local_public_key.to_bytes(),
  23. expiration_time=str(datetime.utcnow() + timedelta(minutes=1)))
  24. token.signature = MockAuthorizer._authority_private_key.sign(self._token_to_bytes(token))
  25. return token
  26. def is_token_valid(self, access_token: AccessToken) -> bool:
  27. data = self._token_to_bytes(access_token)
  28. if not self._authority_public_key.verify(data, access_token.signature):
  29. logger.exception('Access token has invalid signature')
  30. return False
  31. try:
  32. expiration_time = datetime.fromisoformat(access_token.expiration_time)
  33. except ValueError:
  34. logger.exception(
  35. f'datetime.fromisoformat() failed to parse expiration time: {access_token.expiration_time}')
  36. return False
  37. if expiration_time.tzinfo is not None:
  38. logger.exception(f'Expected to have no timezone for expiration time: {access_token.expiration_time}')
  39. return False
  40. if expiration_time < datetime.utcnow():
  41. logger.exception('Access token has expired')
  42. return False
  43. return True
  44. _MAX_LATENCY = timedelta(minutes=1)
  45. def does_token_need_refreshing(self, access_token: AccessToken) -> bool:
  46. expiration_time = datetime.fromisoformat(access_token.expiration_time)
  47. return expiration_time < datetime.utcnow() + self._MAX_LATENCY
  48. @staticmethod
  49. def _token_to_bytes(access_token: AccessToken) -> bytes:
  50. return f'{access_token.username} {access_token.public_key} {access_token.expiration_time}'.encode()
  51. @pytest.mark.asyncio
  52. async def test_valid_request_and_response():
  53. client_authorizer = MockAuthorizer(RSAPrivateKey())
  54. service_authorizer = MockAuthorizer(RSAPrivateKey())
  55. request = dht_pb2.PingRequest()
  56. request.peer.node_id = b'ping'
  57. await client_authorizer.sign_request(request, service_authorizer.local_public_key)
  58. assert await service_authorizer.validate_request(request)
  59. response = dht_pb2.PingResponse()
  60. response.peer.node_id = b'pong'
  61. await service_authorizer.sign_response(response, request)
  62. assert await client_authorizer.validate_response(response, request)
  63. @pytest.mark.asyncio
  64. async def test_invalid_access_token():
  65. client_authorizer = MockAuthorizer(RSAPrivateKey())
  66. service_authorizer = MockAuthorizer(RSAPrivateKey())
  67. request = dht_pb2.PingRequest()
  68. request.peer.node_id = b'ping'
  69. await client_authorizer.sign_request(request, service_authorizer.local_public_key)
  70. # Break the access token signature
  71. request.auth.client_access_token.signature = b'broken'
  72. assert not await service_authorizer.validate_request(request)
  73. response = dht_pb2.PingResponse()
  74. response.peer.node_id = b'pong'
  75. await service_authorizer.sign_response(response, request)
  76. # Break the access token signature
  77. response.auth.service_access_token.signature = b'broken'
  78. assert not await client_authorizer.validate_response(response, request)
  79. @pytest.mark.asyncio
  80. async def test_invalid_signatures():
  81. client_authorizer = MockAuthorizer(RSAPrivateKey())
  82. service_authorizer = MockAuthorizer(RSAPrivateKey())
  83. request = dht_pb2.PingRequest()
  84. request.peer.node_id = b'true-ping'
  85. await client_authorizer.sign_request(request, service_authorizer.local_public_key)
  86. # A man-in-the-middle attacker changes the request content
  87. request.peer.node_id = b'fake-ping'
  88. assert not await service_authorizer.validate_request(request)
  89. response = dht_pb2.PingResponse()
  90. response.peer.node_id = b'true-pong'
  91. await service_authorizer.sign_response(response, request)
  92. # A man-in-the-middle attacker changes the response content
  93. response.peer.node_id = b'fake-pong'
  94. assert not await client_authorizer.validate_response(response, request)
  95. @pytest.mark.asyncio
  96. async def test_auth_rpc_wrapper():
  97. class Servicer:
  98. async def rpc_increment(self, request: dht_pb2.PingRequest) -> dht_pb2.PingResponse:
  99. assert request.peer.node_id == b'ping'
  100. assert request.auth.client_access_token.username == 'alice'
  101. response = dht_pb2.PingResponse()
  102. response.peer.node_id = b'pong'
  103. return response
  104. class Client:
  105. def __init__(self, servicer: Servicer):
  106. self._servicer = servicer
  107. async def rpc_increment(self, request: dht_pb2.PingRequest) -> dht_pb2.PingResponse:
  108. return await self._servicer.rpc_increment(request)
  109. servicer = AuthRPCWrapper(Servicer(), AuthRole.SERVICER, MockAuthorizer(RSAPrivateKey(), 'bob'))
  110. client = AuthRPCWrapper(Client(servicer), AuthRole.CLIENT, MockAuthorizer(RSAPrivateKey(), 'alice'))
  111. request = dht_pb2.PingRequest()
  112. request.peer.node_id = b'ping'
  113. response = await client.rpc_increment(request)
  114. assert response.peer.node_id == b'pong'
  115. assert response.auth.service_access_token.username == 'bob'