huggingface_auth.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. import os
  2. import time
  3. from datetime import datetime, timedelta
  4. from getpass import getpass
  5. import requests
  6. from huggingface_hub import HfApi
  7. from hivemind.proto.auth_pb2 import AccessToken
  8. from hivemind.utils.auth import TokenAuthorizerBase
  9. from hivemind.utils.crypto import RSAPublicKey
  10. from hivemind.utils.logging import get_logger
  11. logger = get_logger("root." + __name__)
  12. class NonRetriableError(Exception):
  13. pass
  14. def call_with_retries(func, n_retries=10, initial_delay=1.0):
  15. for i in range(n_retries):
  16. try:
  17. return func()
  18. except NonRetriableError:
  19. raise
  20. except Exception as e:
  21. if i == n_retries - 1:
  22. raise
  23. delay = initial_delay * (2 ** i)
  24. logger.warning(f'Failed to call `{func.__name__}` with exception: {e}. Retrying in {delay:.1f} sec')
  25. time.sleep(delay)
  26. class InvalidCredentialsError(NonRetriableError):
  27. pass
  28. class NotInAllowlistError(NonRetriableError):
  29. pass
  30. class HuggingFaceAuthorizer(TokenAuthorizerBase):
  31. _AUTH_SERVER_URL = 'https://collaborative-training-auth.huggingface.co'
  32. def __init__(self, organization_name: str, model_name: str, username: str, password: str):
  33. super().__init__()
  34. self.organization_name = organization_name
  35. self.model_name = model_name
  36. self.username = username
  37. self.password = password
  38. self._authority_public_key = None
  39. self.coordinator_ip = None
  40. self.coordinator_port = None
  41. self._hf_api = HfApi()
  42. async def get_token(self) -> AccessToken:
  43. """
  44. Hivemind calls this method to refresh the token when necessary.
  45. """
  46. self.join_experiment()
  47. return self._local_access_token
  48. def join_experiment(self) -> None:
  49. call_with_retries(self._join_experiment)
  50. def _join_experiment(self) -> None:
  51. try:
  52. token = self._hf_api.login(self.username, self.password)
  53. except requests.exceptions.HTTPError as e:
  54. if e.response.status_code == 401: # Unauthorized
  55. raise InvalidCredentialsError()
  56. raise
  57. try:
  58. url = f'{self._AUTH_SERVER_URL}/api/experiments/join'
  59. headers = {'Authorization': f'Bearer {token}'}
  60. response = requests.put(
  61. url,
  62. headers=headers,
  63. params={
  64. 'organization_name': self.organization_name,
  65. 'model_name': self.model_name,
  66. },
  67. json={
  68. 'experiment_join_input': {
  69. 'peer_public_key': self.local_public_key.to_bytes().decode(),
  70. },
  71. },
  72. verify=False, # FIXME: Update the expired API certificate
  73. )
  74. response.raise_for_status()
  75. response = response.json()
  76. self._authority_public_key = RSAPublicKey.from_bytes(response['auth_server_public_key'].encode())
  77. self.coordinator_ip = response['coordinator_ip']
  78. self.coordinator_port = response['coordinator_port']
  79. token_dict = response['hivemind_access']
  80. access_token = AccessToken()
  81. access_token.username = token_dict['username']
  82. access_token.public_key = token_dict['peer_public_key'].encode()
  83. access_token.expiration_time = str(datetime.fromisoformat(token_dict['expiration_time']))
  84. access_token.signature = token_dict['signature'].encode()
  85. self._local_access_token = access_token
  86. logger.info(f'Access for user {access_token.username} '
  87. f'has been granted until {access_token.expiration_time} UTC')
  88. except requests.exceptions.HTTPError as e:
  89. if e.response.status_code == 401: # Unauthorized
  90. raise NotInAllowlistError()
  91. raise
  92. finally:
  93. self._hf_api.logout(token)
  94. def is_token_valid(self, access_token: AccessToken) -> bool:
  95. data = self._token_to_bytes(access_token)
  96. if not self._authority_public_key.verify(data, access_token.signature):
  97. logger.exception('Access token has invalid signature')
  98. return False
  99. try:
  100. expiration_time = datetime.fromisoformat(access_token.expiration_time)
  101. except ValueError:
  102. logger.exception(
  103. f'datetime.fromisoformat() failed to parse expiration time: {access_token.expiration_time}')
  104. return False
  105. if expiration_time.tzinfo is not None:
  106. logger.exception(f'Expected to have no timezone for expiration time: {access_token.expiration_time}')
  107. return False
  108. if expiration_time < datetime.utcnow():
  109. logger.exception('Access token has expired')
  110. return False
  111. return True
  112. _MAX_LATENCY = timedelta(minutes=1)
  113. def does_token_need_refreshing(self, access_token: AccessToken) -> bool:
  114. expiration_time = datetime.fromisoformat(access_token.expiration_time)
  115. return expiration_time < datetime.utcnow() + self._MAX_LATENCY
  116. @staticmethod
  117. def _token_to_bytes(access_token: AccessToken) -> bytes:
  118. return f'{access_token.username} {access_token.public_key} {access_token.expiration_time}'.encode()
  119. def authorize_with_huggingface() -> HuggingFaceAuthorizer:
  120. while True:
  121. organization_name = os.getenv('HF_ORGANIZATION_NAME')
  122. if organization_name is None:
  123. organization_name = input('HuggingFace organization name: ')
  124. model_name = os.getenv('HF_MODEL_NAME')
  125. if model_name is None:
  126. model_name = input('HuggingFace model name: ')
  127. username = os.getenv('HF_USERNAME')
  128. if username is None:
  129. while True:
  130. username = input('HuggingFace username: ')
  131. if '@' not in username:
  132. break
  133. print('Please enter your Huggingface _username_ instead of the email address!')
  134. password = os.getenv('HF_PASSWORD')
  135. if password is None:
  136. password = getpass('HuggingFace password: ')
  137. authorizer = HuggingFaceAuthorizer(organization_name, model_name, username, password)
  138. try:
  139. authorizer.join_experiment()
  140. return authorizer
  141. except InvalidCredentialsError:
  142. print('Invalid username or password, please try again')
  143. except NotInAllowlistError:
  144. print('This account is not specified in the allowlist. '
  145. 'Please ask a moderator to add you to the allowlist and try again')