huggingface_auth.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. import os
  2. import time
  3. from datetime import datetime, timedelta
  4. from getpass import getpass
  5. import requests
  6. from huggingface_hub import HfApi
  7. from termcolor import colored
  8. from hivemind.proto.auth_pb2 import AccessToken
  9. from hivemind.utils.auth import TokenAuthorizerBase
  10. from hivemind.utils.crypto import RSAPublicKey
  11. from hivemind.utils.logging import get_logger
  12. logger = get_logger("root." + __name__)
  13. class NonRetriableError(Exception):
  14. pass
  15. def call_with_retries(func, n_retries=10, initial_delay=1.0):
  16. for i in range(n_retries):
  17. try:
  18. return func()
  19. except NonRetriableError:
  20. raise
  21. except Exception as e:
  22. if i == n_retries - 1:
  23. raise
  24. delay = initial_delay * (2 ** i)
  25. logger.warning(f'Failed to call `{func.__name__}` with exception: {e}. Retrying in {delay:.1f} sec')
  26. time.sleep(delay)
  27. class InvalidCredentialsError(NonRetriableError):
  28. pass
  29. class NotInAllowlistError(NonRetriableError):
  30. pass
  31. class HuggingFaceAuthorizer(TokenAuthorizerBase):
  32. _AUTH_SERVER_URL = 'https://collaborative-training-auth.huggingface.co'
  33. def __init__(self, organization_name: str, model_name: str, hf_user_access_token: str):
  34. super().__init__()
  35. self.organization_name = organization_name
  36. self.model_name = model_name
  37. self.hf_user_access_token = hf_user_access_token
  38. self._authority_public_key = None
  39. self.coordinator_ip = None
  40. self.coordinator_port = None
  41. self._hf_api = HfApi()
  42. async def get_token(self) -> AccessToken:
  43. """
  44. Hivemind calls this method to refresh the token when necessary.
  45. """
  46. self.join_experiment()
  47. return self._local_access_token
  48. @property
  49. def username(self):
  50. return self._local_access_token.username
  51. def join_experiment(self) -> None:
  52. call_with_retries(self._join_experiment)
  53. def _join_experiment(self) -> None:
  54. try:
  55. url = f'{self._AUTH_SERVER_URL}/api/experiments/join'
  56. headers = {'Authorization': f'Bearer {self.hf_user_access_token}'}
  57. response = requests.put(
  58. url,
  59. headers=headers,
  60. params={
  61. 'organization_name': self.organization_name,
  62. 'model_name': self.model_name,
  63. },
  64. json={
  65. 'experiment_join_input': {
  66. 'peer_public_key': self.local_public_key.to_bytes().decode(),
  67. },
  68. },
  69. )
  70. response.raise_for_status()
  71. response = response.json()
  72. self._authority_public_key = RSAPublicKey.from_bytes(response['auth_server_public_key'].encode())
  73. self.coordinator_ip = response['coordinator_ip']
  74. self.coordinator_port = response['coordinator_port']
  75. token_dict = response['hivemind_access']
  76. access_token = AccessToken()
  77. access_token.username = token_dict['username']
  78. access_token.public_key = token_dict['peer_public_key'].encode()
  79. access_token.expiration_time = str(datetime.fromisoformat(token_dict['expiration_time']))
  80. access_token.signature = token_dict['signature'].encode()
  81. self._local_access_token = access_token
  82. logger.info(f'Access for user {access_token.username} '
  83. f'has been granted until {access_token.expiration_time} UTC')
  84. except requests.exceptions.HTTPError as e:
  85. if e.response.status_code == 401: # Unauthorized
  86. raise NotInAllowlistError()
  87. raise
  88. def is_token_valid(self, access_token: AccessToken) -> bool:
  89. data = self._token_to_bytes(access_token)
  90. if not self._authority_public_key.verify(data, access_token.signature):
  91. logger.exception('Access token has invalid signature')
  92. return False
  93. try:
  94. expiration_time = datetime.fromisoformat(access_token.expiration_time)
  95. except ValueError:
  96. logger.exception(
  97. f'datetime.fromisoformat() failed to parse expiration time: {access_token.expiration_time}')
  98. return False
  99. if expiration_time.tzinfo is not None:
  100. logger.exception(f'Expected to have no timezone for expiration time: {access_token.expiration_time}')
  101. return False
  102. if expiration_time < datetime.utcnow():
  103. logger.exception('Access token has expired')
  104. return False
  105. return True
  106. _MAX_LATENCY = timedelta(minutes=1)
  107. def does_token_need_refreshing(self, access_token: AccessToken) -> bool:
  108. expiration_time = datetime.fromisoformat(access_token.expiration_time)
  109. return expiration_time < datetime.utcnow() + self._MAX_LATENCY
  110. @staticmethod
  111. def _token_to_bytes(access_token: AccessToken) -> bytes:
  112. return f'{access_token.username} {access_token.public_key} {access_token.expiration_time}'.encode()
  113. def authorize_with_huggingface() -> HuggingFaceAuthorizer:
  114. while True:
  115. organization_name = os.getenv('HF_ORGANIZATION_NAME')
  116. if organization_name is None:
  117. organization_name = input('HuggingFace organization name: ')
  118. model_name = os.getenv('HF_MODEL_NAME')
  119. if model_name is None:
  120. model_name = input('HuggingFace model name: ')
  121. hf_user_access_token = os.getenv('HF_USER_ACCESS_TOKEN')
  122. if hf_user_access_token is None:
  123. print(
  124. "\nCopy a token from 🤗 Hugging Face settings page at "
  125. f"{colored('https://huggingface.co/settings/token', attrs=['bold'])} "
  126. "and paste it here.\n\n"
  127. f"💡 {colored('Tip:', attrs=['bold'])} "
  128. "If you don't already have one, you can create a dedicated user access token.\n"
  129. f"Go to {colored('https://huggingface.co/settings/token', attrs=['bold'])}, "
  130. f"click the {colored('New token', attrs=['bold'])} button, "
  131. f"and choose the {colored('read', attrs=['bold'])} role.\n"
  132. )
  133. hf_user_access_token = getpass('🤗 Hugging Face user access token (characters will be hidden): ')
  134. authorizer = HuggingFaceAuthorizer(organization_name, model_name, hf_user_access_token)
  135. try:
  136. authorizer.join_experiment()
  137. print(f"🚀 You will contribute to the collaborative training under the username {authorizer.username}")
  138. return authorizer
  139. except InvalidCredentialsError:
  140. print('Invalid user access token, please try again')
  141. except NotInAllowlistError:
  142. print(
  143. '\n😥 Authentication has failed.\n\n'
  144. 'This error may be due to the fact:\n'
  145. " 1. Your user access token is not valid. You can try to delete the previous token and "
  146. "recreate one. Be careful, organization tokens can't be used to join a collaborative "
  147. "training.\n"
  148. f" 2. You have not yet joined the {organization_name} organization. You can request to "
  149. "join this organization by clicking on the 'request to join this org' button at "
  150. f"https://huggingface.co/{organization_name}.\n"
  151. f" 3. The {organization_name} organization doesn't exist at https://huggingface.co/{organization_name}.\n"
  152. f" 4. No {organization_name}'s admin has created a collaborative training for the {organization_name} "
  153. f"organization and the {model_name} model."
  154. )