test_auth.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
from datetime import datetime, timedelta, timezone
from typing import Optional

import pytest

from hivemind.proto import dht_pb2
from hivemind.proto.auth_pb2 import AccessToken
from hivemind.utils.auth import AuthRole, AuthRPCWrapper, TokenAuthorizerBase
from hivemind.utils.crypto import RSAPrivateKey
from hivemind.utils.logging import get_logger
  9. logger = get_logger(__name__)
  10. class MockAuthorizer(TokenAuthorizerBase):
  11. _authority_private_key = None
  12. _authority_public_key = None
  13. def __init__(self, local_private_key: Optional[RSAPrivateKey], username: str = "mock"):
  14. super().__init__(local_private_key)
  15. self._username = username
  16. self._authority_public_key = None
  17. async def get_token(self) -> AccessToken:
  18. if MockAuthorizer._authority_private_key is None:
  19. MockAuthorizer._authority_private_key = RSAPrivateKey()
  20. self._authority_public_key = MockAuthorizer._authority_private_key.get_public_key()
  21. token = AccessToken(
  22. username=self._username,
  23. public_key=self.local_public_key.to_bytes(),
  24. expiration_time=str(datetime.utcnow() + timedelta(minutes=1)),
  25. )
  26. token.signature = MockAuthorizer._authority_private_key.sign(self._token_to_bytes(token))
  27. return token
  28. def is_token_valid(self, access_token: AccessToken) -> bool:
  29. data = self._token_to_bytes(access_token)
  30. if not self._authority_public_key.verify(data, access_token.signature):
  31. logger.exception("Access token has invalid signature")
  32. return False
  33. try:
  34. expiration_time = datetime.fromisoformat(access_token.expiration_time)
  35. except ValueError:
  36. logger.exception(
  37. f"datetime.fromisoformat() failed to parse expiration time: {access_token.expiration_time}"
  38. )
  39. return False
  40. if expiration_time.tzinfo is not None:
  41. logger.exception(f"Expected to have no timezone for expiration time: {access_token.expiration_time}")
  42. return False
  43. if expiration_time < datetime.utcnow():
  44. logger.exception("Access token has expired")
  45. return False
  46. return True
  47. _MAX_LATENCY = timedelta(minutes=1)
  48. def does_token_need_refreshing(self, access_token: AccessToken) -> bool:
  49. expiration_time = datetime.fromisoformat(access_token.expiration_time)
  50. return expiration_time < datetime.utcnow() + self._MAX_LATENCY
  51. @staticmethod
  52. def _token_to_bytes(access_token: AccessToken) -> bytes:
  53. return f"{access_token.username} {access_token.public_key} {access_token.expiration_time}".encode()
  54. @pytest.mark.asyncio
  55. async def test_valid_request_and_response():
  56. client_authorizer = MockAuthorizer(RSAPrivateKey())
  57. service_authorizer = MockAuthorizer(RSAPrivateKey())
  58. request = dht_pb2.PingRequest()
  59. request.peer.node_id = b"ping"
  60. await client_authorizer.sign_request(request, service_authorizer.local_public_key)
  61. assert await service_authorizer.validate_request(request)
  62. response = dht_pb2.PingResponse()
  63. response.peer.node_id = b"pong"
  64. await service_authorizer.sign_response(response, request)
  65. assert await client_authorizer.validate_response(response, request)
  66. @pytest.mark.asyncio
  67. async def test_invalid_access_token():
  68. client_authorizer = MockAuthorizer(RSAPrivateKey())
  69. service_authorizer = MockAuthorizer(RSAPrivateKey())
  70. request = dht_pb2.PingRequest()
  71. request.peer.node_id = b"ping"
  72. await client_authorizer.sign_request(request, service_authorizer.local_public_key)
  73. # Break the access token signature
  74. request.auth.client_access_token.signature = b"broken"
  75. assert not await service_authorizer.validate_request(request)
  76. response = dht_pb2.PingResponse()
  77. response.peer.node_id = b"pong"
  78. await service_authorizer.sign_response(response, request)
  79. # Break the access token signature
  80. response.auth.service_access_token.signature = b"broken"
  81. assert not await client_authorizer.validate_response(response, request)
  82. @pytest.mark.asyncio
  83. async def test_invalid_signatures():
  84. client_authorizer = MockAuthorizer(RSAPrivateKey())
  85. service_authorizer = MockAuthorizer(RSAPrivateKey())
  86. request = dht_pb2.PingRequest()
  87. request.peer.node_id = b"true-ping"
  88. await client_authorizer.sign_request(request, service_authorizer.local_public_key)
  89. # A man-in-the-middle attacker changes the request content
  90. request.peer.node_id = b"fake-ping"
  91. assert not await service_authorizer.validate_request(request)
  92. response = dht_pb2.PingResponse()
  93. response.peer.node_id = b"true-pong"
  94. await service_authorizer.sign_response(response, request)
  95. # A man-in-the-middle attacker changes the response content
  96. response.peer.node_id = b"fake-pong"
  97. assert not await client_authorizer.validate_response(response, request)
  98. @pytest.mark.asyncio
  99. async def test_auth_rpc_wrapper():
  100. class Servicer:
  101. async def rpc_increment(self, request: dht_pb2.PingRequest) -> dht_pb2.PingResponse:
  102. assert request.peer.node_id == b"ping"
  103. assert request.auth.client_access_token.username == "alice"
  104. response = dht_pb2.PingResponse()
  105. response.peer.node_id = b"pong"
  106. return response
  107. class Client:
  108. def __init__(self, servicer: Servicer):
  109. self._servicer = servicer
  110. async def rpc_increment(self, request: dht_pb2.PingRequest) -> dht_pb2.PingResponse:
  111. return await self._servicer.rpc_increment(request)
  112. servicer = AuthRPCWrapper(Servicer(), AuthRole.SERVICER, MockAuthorizer(RSAPrivateKey(), "bob"))
  113. client = AuthRPCWrapper(Client(servicer), AuthRole.CLIENT, MockAuthorizer(RSAPrivateKey(), "alice"))
  114. request = dht_pb2.PingRequest()
  115. request.peer.node_id = b"ping"
  116. response = await client.rpc_increment(request)
  117. assert response.peer.node_id == b"pong"
  118. assert response.auth.service_access_token.username == "bob"