huggingface_auth.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. import base64
  2. import os
  3. import time
  4. from datetime import datetime, timedelta
  5. from getpass import getpass
  6. import requests
  7. from huggingface_hub import HfApi
  8. from hivemind.proto.auth_pb2 import AccessToken
  9. from hivemind.utils.auth import TokenAuthorizerBase
  10. from hivemind.utils.crypto import RSAPublicKey
  11. from hivemind.utils.logging import get_logger
  # Module-level logger, registered under the "root." namespace.
  logger = get_logger(f"root.{__name__}")
  class NonRetriableError(Exception):
      """Error that ``call_with_retries`` re-raises immediately instead of retrying."""
  def call_with_retries(func, n_retries=10, initial_delay=1.0):
      """Call ``func`` until it succeeds, backing off exponentially between attempts.

      Args:
          func: Zero-argument callable to invoke.
          n_retries: Total number of attempts; must be at least 1.
          initial_delay: Seconds to sleep after the first failure. The delay
              doubles after every subsequent failure.

      Returns:
          The value returned by the first successful ``func()`` call.

      Raises:
          ValueError: If ``n_retries`` is less than 1.
          NonRetriableError: Propagated immediately, without any retry.
          Exception: Whatever the final attempt raised once all attempts failed.
      """
      if n_retries < 1:
          # With zero attempts the loop would never call func and the function
          # would silently return None, hiding the misconfiguration.
          raise ValueError(f'n_retries must be at least 1, got {n_retries}')

      # Not every callable has __name__ (e.g. functools.partial objects), so
      # fall back to repr() rather than failing inside the error handler.
      func_name = getattr(func, '__name__', repr(func))

      for attempt in range(n_retries):
          try:
              return func()
          except NonRetriableError:
              raise
          except Exception as e:
              if attempt == n_retries - 1:
                  # Out of attempts: surface the last error to the caller.
                  raise
              delay = initial_delay * (2 ** attempt)
              logger.warning(f'Failed to call `{func_name}` with exception: {e}. Retrying in {delay:.1f} sec')
              time.sleep(delay)
  class InvalidCredentialsError(NonRetriableError):
      """Raised when the HuggingFace login is rejected with HTTP 401."""
  class NotInAllowlistError(NonRetriableError):
      """Raised when the auth server answers the join request with HTTP 401."""
  class HuggingFaceAuthorizer(TokenAuthorizerBase):
      """Token authorizer backed by the HuggingFace collaborative-training auth server.

      The authorizer logs in to HuggingFace with a username/password pair, sends a
      join request for ``experiment_id`` to the auth server, and keeps the access
      token, the auth server's public key and the coordinator address from the
      server's response.
      """

      _AUTH_SERVER_URL = 'https://collaborative-training-auth.huggingface.co'

      def __init__(self, experiment_id: int, username: str, password: str):
          """Store the experiment ID and credentials for later use by ``join_experiment()``.

          Args:
              experiment_id: Identifier interpolated into the join URL.
              username: HuggingFace username used for login.
              password: HuggingFace password used for login.
          """
          super().__init__()

          self.experiment_id = experiment_id
          self.username = username
          self.password = password

          # The fields below are populated by a successful _join_experiment().
          self._authority_public_key = None
          self.coordinator_ip = None
          self.coordinator_port = None

          self._hf_api = HfApi()

      async def get_token(self) -> AccessToken:
          """
          Hivemind calls this method to refresh the token when necessary.
          """
          # NOTE(review): join_experiment() issues synchronous HTTP requests and
          # call_with_retries() sleeps between attempts, so this coroutine does
          # not yield while it runs - confirm that is acceptable for callers.
          self.join_experiment()
          return self._local_access_token

      def join_experiment(self) -> None:
          """Join the experiment, retrying transient failures via ``call_with_retries``."""
          call_with_retries(self._join_experiment)

      def _join_experiment(self) -> None:
          """Perform a single login + join attempt and store the server's response.

          Raises:
              InvalidCredentialsError: The HuggingFace login returned HTTP 401.
              NotInAllowlistError: The auth server returned HTTP 401 for the join request.
              requests.exceptions.HTTPError: Any other HTTP error status.
          """
          try:
              token = self._hf_api.login(self.username, self.password)
          except requests.exceptions.HTTPError as e:
              if e.response.status_code == 401:  # Unauthorized
                  raise InvalidCredentialsError()
              raise

          try:
              url = f'{self._AUTH_SERVER_URL}/api/experiments/join/{self.experiment_id}/'
              headers = {'Authorization': f'Bearer {token}'}
              response = requests.put(url, headers=headers, json={
                  'experiment_join_input': {
                      'peer_public_key': self.local_public_key.to_bytes().decode(),
                  },
              })

              response.raise_for_status()
              response = response.json()

              self._authority_public_key = RSAPublicKey.from_bytes(response['auth_server_public_key'].encode())
              self.coordinator_ip = response['coordinator_ip']
              self.coordinator_port = response['coordinator_port']

              token_dict = response['hivemind_access']
              access_token = AccessToken()
              access_token.username = token_dict['username']
              access_token.public_key = token_dict['peer_public_key'].encode()
              # Round-trip through datetime so the stored string is in the form
              # produced by str(datetime), which fromisoformat() parses later.
              access_token.expiration_time = str(datetime.fromisoformat(token_dict['expiration_time']))
              access_token.signature = token_dict['signature'].encode()
              self._local_access_token = access_token
              logger.info(f'Access for user {access_token.username} '
                          f'has been granted until {access_token.expiration_time} UTC')
          except requests.exceptions.HTTPError as e:
              if e.response.status_code == 401:  # Unauthorized
                  raise NotInAllowlistError()
              raise
          finally:
              # The HuggingFace token is only needed for the join request itself.
              self._hf_api.logout(token)

      def is_token_valid(self, access_token: AccessToken) -> bool:
          """Check the signature and expiration time of ``access_token``.

          Returns:
              True only if the signature verifies against the auth server's public
              key, the expiration time parses as a timezone-naive ISO timestamp,
              and that timestamp is not earlier than ``datetime.utcnow()``.
          """
          data = self._token_to_bytes(access_token)
          if not self._authority_public_key.verify(data, access_token.signature):
              # No exception is being handled here, so log a plain error rather
              # than logger.exception(), which would attach an empty traceback.
              logger.error('Access token has invalid signature')
              return False

          try:
              expiration_time = datetime.fromisoformat(access_token.expiration_time)
          except ValueError:
              logger.exception(
                  f'datetime.fromisoformat() failed to parse expiration time: {access_token.expiration_time}')
              return False
          if expiration_time.tzinfo is not None:
              logger.error(f'Expected to have no timezone for expiration time: {access_token.expiration_time}')
              return False
          if expiration_time < datetime.utcnow():
              logger.error('Access token has expired')
              return False

          return True

      # Tokens expiring within this margin are treated as needing a refresh.
      _MAX_LATENCY = timedelta(minutes=1)

      def does_token_need_refreshing(self, access_token: AccessToken) -> bool:
          """Return True if ``access_token`` expires within ``_MAX_LATENCY`` from now (UTC)."""
          expiration_time = datetime.fromisoformat(access_token.expiration_time)
          return expiration_time < datetime.utcnow() + self._MAX_LATENCY

      @staticmethod
      def _token_to_bytes(access_token: AccessToken) -> bytes:
          """Serialize the token fields that are covered by the signature check."""
          # NOTE(review): public_key is assigned bytes in _join_experiment(), so the
          # f-string embeds its repr (b'...'); presumably this matches the payload
          # the auth server signs - confirm before changing the format.
          return f'{access_token.username} {access_token.public_key} {access_token.expiration_time}'.encode()
  def authorize_with_huggingface() -> HuggingFaceAuthorizer:
      """Collect credentials, join the experiment, and return the authorizer.

      The experiment ID, username and password are read from the environment
      variables ``HF_EXPERIMENT_ID``, ``HF_USERNAME`` and ``HF_PASSWORD``; any
      value that is not set is requested interactively. Environment-supplied
      username/password are used for the first attempt only: if that attempt
      is rejected, later attempts prompt instead, so rejected environment
      credentials cannot cause an endless loop of identical requests.

      Returns:
          A ``HuggingFaceAuthorizer`` whose ``join_experiment()`` call succeeded.
      """
      use_env_credentials = True
      while True:
          experiment_id = os.getenv('HF_EXPERIMENT_ID')
          if experiment_id is None:
              experiment_id = input('HuggingFace experiment ID: ')

          username = os.getenv('HF_USERNAME') if use_env_credentials else None
          if username is None:
              while True:
                  username = input('HuggingFace username: ')
                  if '@' not in username:
                      break
                  print('Please enter your Huggingface _username_ instead of the email address!')

          password = os.getenv('HF_PASSWORD') if use_env_credentials else None
          if password is None:
              password = getpass('HuggingFace password: ')

          authorizer = HuggingFaceAuthorizer(experiment_id, username, password)
          try:
              authorizer.join_experiment()
          except InvalidCredentialsError:
              print('Invalid username or password, please try again')
          except NotInAllowlistError:
              print('This account is not specified in the allowlist. '
                    'Please ask a moderator to add you to the allowlist and try again')
          else:
              return authorizer

          # The attempt was rejected: stop reusing the environment credentials
          # and ask the user on the next iteration.
          use_env_credentials = False