huggingface_auth.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. import os
  2. import time
  3. from datetime import datetime, timedelta
  4. from getpass import getpass
  5. import requests
  6. from huggingface_hub import HfApi
  7. from termcolor import colored
  8. from hivemind.proto.auth_pb2 import AccessToken
  9. from hivemind.utils.auth import TokenAuthorizerBase
  10. from hivemind.utils.crypto import RSAPublicKey
  11. from hivemind.utils.logging import get_logger
  # Module-level logger; its name is this module's name prefixed with "root.".
  logger = get_logger(f"root.{__name__}")
  class NonRetriableError(Exception):
      """Base class for failures that `call_with_retries` re-raises at once instead of retrying."""
  def call_with_retries(func, n_retries=10, initial_delay=1.0):
      """Call ``func()`` and return its result, retrying with exponential backoff on failure.

      Args:
          func: Zero-argument callable to invoke.
          n_retries: Maximum number of attempts, including the first one. Must be at least 1.
          initial_delay: Seconds to sleep after the first failed attempt; the delay doubles
              after every further failure.

      Returns:
          Whatever ``func()`` returns on the first successful attempt.

      Raises:
          ValueError: If ``n_retries`` is less than 1 (``func`` is never called in that case).
          NonRetriableError: Re-raised immediately, without any retry.
          Exception: The exception of the last attempt once every attempt has failed.
      """
      if n_retries < 1:
          # Without this check the loop body never runs and the function would
          # silently return None without having called `func` at all.
          raise ValueError(f'n_retries must be at least 1, got {n_retries}')

      # Not every callable has `__name__` (e.g. functools.partial objects). Resolve a
      # printable name up front so the warning below cannot itself raise and hide
      # the real failure.
      func_name = getattr(func, '__name__', None) or repr(func)

      for attempt in range(n_retries):
          try:
              return func()
          except NonRetriableError:
              raise
          except Exception as e:
              if attempt == n_retries - 1:
                  # Out of attempts: surface the last error to the caller.
                  raise
              delay = initial_delay * (2 ** attempt)
              logger.warning(f'Failed to call `{func_name}` with exception: {e}. Retrying in {delay:.1f} sec')
              time.sleep(delay)
  class InvalidCredentialsError(NonRetriableError):
      """Non-retriable error that `authorize_with_huggingface` reports as an invalid user access token."""
  class NotInAllowlistError(NonRetriableError):
      """Non-retriable error raised when the auth server answers a join request with HTTP 401."""
  class HuggingFaceAuthorizer(TokenAuthorizerBase):
      """Token authorizer backed by the Hugging Face collaborative-training auth server.

      Joining an experiment sends this peer's public key to the auth server and stores
      what the server returns: the server's public key (used to verify tokens), the
      coordinator address, and a signed access token for this peer.
      """

      _AUTH_SERVER_URL = 'https://collaborative-training-auth.huggingface.co'

      # A token that expires within this margin is already considered due for refresh.
      _MAX_LATENCY = timedelta(minutes=1)

      # Seconds to wait for the auth server before giving up on a single request.
      # Without a timeout `requests` may block indefinitely; a timeout surfaces as a
      # regular exception, so `call_with_retries` can retry it.
      _REQUEST_TIMEOUT = 30.0

      def __init__(self, organization_name: str, model_name: str, hf_user_access_token: str):
          """Store the experiment identity and the user's Hugging Face access token.

          Args:
              organization_name: Sent to the auth server as the ``organization_name`` query parameter.
              model_name: Sent to the auth server as the ``model_name`` query parameter.
              hf_user_access_token: Sent to the auth server as a bearer token.
          """
          super().__init__()

          self.organization_name = organization_name
          self.model_name = model_name
          self.hf_user_access_token = hf_user_access_token

          # Filled in by a successful `_join_experiment()` call.
          self._authority_public_key = None
          self.coordinator_ip = None
          self.coordinator_port = None

          self._hf_api = HfApi()

      async def get_token(self) -> AccessToken:
          """
          Hivemind calls this method to refresh the token when necessary.
          """
          # NOTE(review): join_experiment() is synchronous (HTTP request and, on failures,
          # time.sleep inside call_with_retries), so this coroutine does not yield while it runs.
          self.join_experiment()
          return self._local_access_token

      def join_experiment(self) -> None:
          """Join the experiment, retrying transient failures via `call_with_retries`."""
          call_with_retries(self._join_experiment)

      def _join_experiment(self) -> None:
          """Perform one join request and store the server's response on this instance.

          Sets ``_authority_public_key``, ``coordinator_ip``, ``coordinator_port`` and
          ``_local_access_token``.

          Raises:
              NotInAllowlistError: If the auth server responds with HTTP 401.
              requests.exceptions.RequestException: For any other HTTP or transport failure.
          """
          url = f'{self._AUTH_SERVER_URL}/api/experiments/join'
          headers = {'Authorization': f'Bearer {self.hf_user_access_token}'}

          try:
              response = requests.put(
                  url,
                  headers=headers,
                  params={
                      'organization_name': self.organization_name,
                      'model_name': self.model_name,
                  },
                  json={
                      'experiment_join_input': {
                          'peer_public_key': self.local_public_key.to_bytes().decode(),
                      },
                  },
                  timeout=self._REQUEST_TIMEOUT,
              )
              response.raise_for_status()
          except requests.exceptions.HTTPError as e:
              if e.response.status_code == 401:  # Unauthorized
                  # Chain the HTTP error so the original response stays visible in tracebacks.
                  raise NotInAllowlistError() from e
              raise

          payload = response.json()

          self._authority_public_key = RSAPublicKey.from_bytes(payload['auth_server_public_key'].encode())
          self.coordinator_ip = payload['coordinator_ip']
          self.coordinator_port = payload['coordinator_port']

          token_dict = payload['hivemind_access']
          access_token = AccessToken()
          access_token.username = token_dict['username']
          access_token.public_key = token_dict['peer_public_key'].encode()
          # Round-trip through datetime so the stored string is in str(datetime) form.
          access_token.expiration_time = str(datetime.fromisoformat(token_dict['expiration_time']))
          access_token.signature = token_dict['signature'].encode()
          self._local_access_token = access_token

          logger.info(f'Access for user {access_token.username} '
                      f'has been granted until {access_token.expiration_time} UTC')

      def is_token_valid(self, access_token: AccessToken) -> bool:
          """Return True if the token is signed by the auth server and has not expired.

          Every rejection is logged and yields False; this method does not raise for an
          invalid token.
          """
          if self._authority_public_key is None:
              # No successful join yet, so there is no key to verify against: fail closed.
              logger.error('Cannot validate access token: the auth server public key is not known yet')
              return False

          data = self._token_to_bytes(access_token)
          if not self._authority_public_key.verify(data, access_token.signature):
              # logger.error, not logger.exception: there is no active exception here.
              logger.error('Access token has invalid signature')
              return False

          try:
              expiration_time = datetime.fromisoformat(access_token.expiration_time)
          except ValueError:
              logger.exception(
                  f'datetime.fromisoformat() failed to parse expiration time: {access_token.expiration_time}')
              return False
          if expiration_time.tzinfo is not None:
              logger.error(f'Expected to have no timezone for expiration time: {access_token.expiration_time}')
              return False
          # Naive-vs-naive comparison: the expiration time is required to be naive above.
          if expiration_time < datetime.utcnow():
              logger.error('Access token has expired')
              return False

          return True

      def does_token_need_refreshing(self, access_token: AccessToken) -> bool:
          """Return True if the token expires within ``_MAX_LATENCY`` from now (or already has)."""
          expiration_time = datetime.fromisoformat(access_token.expiration_time)
          return expiration_time < datetime.utcnow() + self._MAX_LATENCY

      @staticmethod
      def _token_to_bytes(access_token: AccessToken) -> bytes:
          """Serialize the token fields covered by the signature into bytes."""
          # NOTE(review): the fields are interpolated as-is; presumably this must match the
          # format the auth server signs, so do not change it without confirming.
          return f'{access_token.username} {access_token.public_key} {access_token.expiration_time}'.encode()
  def authorize_with_huggingface() -> HuggingFaceAuthorizer:
      """Collect Hugging Face credentials, join the experiment, and return the authorizer.

      The organization name, model name and user access token are read from the
      ``HF_ORGANIZATION_NAME``, ``HF_MODEL_NAME`` and ``HF_USER_ACCESS_TOKEN`` environment
      variables; any value that is not set is requested interactively. After an
      authentication failure the user is asked again.

      Returns:
          A `HuggingFaceAuthorizer` that has successfully joined the experiment.

      Raises:
          InvalidCredentialsError, NotInAllowlistError: If authentication fails and none
              of the credentials was entered interactively. Retrying would resend
              exactly the same values, so the loop would never terminate.
          Exception: Any other error propagated from `HuggingFaceAuthorizer.join_experiment`.
      """
      while True:
          # Tracks whether the user typed at least one value during this attempt.
          prompted = False

          organization_name = os.getenv('HF_ORGANIZATION_NAME')
          if organization_name is None:
              organization_name = input('HuggingFace organization name: ')
              prompted = True

          model_name = os.getenv('HF_MODEL_NAME')
          if model_name is None:
              model_name = input('HuggingFace model name: ')
              prompted = True

          hf_user_access_token = os.getenv('HF_USER_ACCESS_TOKEN')
          if hf_user_access_token is None:
              print(
                  "\nCopy a token from 🤗 Hugging Face settings page at "
                  f"{colored('https://huggingface.co/settings/token', attrs=['bold'])} "
                  "and paste it here.\n\n"
                  f"💡 {colored('Tip:', attrs=['bold'])} "
                  "If you don't already have one, you can create a dedicated user access token.\n"
                  f"Go to {colored('https://huggingface.co/settings/token', attrs=['bold'])}, "
                  f"click the {colored('New token', attrs=['bold'])} button, "
                  f"and choose the {colored('read', attrs=['bold'])} role.\n"
              )
              hf_user_access_token = getpass('🤗 Hugging Face user access token (characters will be hidden): ')
              prompted = True

          authorizer = HuggingFaceAuthorizer(organization_name, model_name, hf_user_access_token)

          try:
              authorizer.join_experiment()
          except InvalidCredentialsError:
              print('Invalid user access token, please try again')
              if not prompted:
                  # Everything came from the environment: another pass cannot succeed.
                  raise
          except NotInAllowlistError:
              print(
                  '\n😥 Authentication has failed.\n\n'
                  'This error may be due to the fact:\n'
                  " 1. Your user access token is not valid. You can try to delete the previous token and "
                  "recreate one. Be careful, organization tokens can't be used to join a collaborative "
                  "training.\n"
                  f" 2. You have not yet joined the {organization_name} organization. You can request to "
                  "join this organization by clicking on the 'request to join this org' button at "
                  f"https://huggingface.co/{organization_name}.\n"
                  f" 3. The {organization_name} organization doesn't exist at https://huggingface.co/{organization_name}.\n"
                  f" 4. No {organization_name}'s admin has created a collaborative training for the {organization_name} "
                  f"organization and the {model_name} model."
              )
              if not prompted:
                  # Everything came from the environment: another pass cannot succeed.
                  raise
          else:
              username = authorizer._local_access_token.username
              print(f"🚀 You will contribute to the collaborative training under the username {username}")
              return authorizer