huggingface_auth.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. import os
  2. import time
  3. from datetime import datetime, timedelta
  4. from getpass import getpass
  5. import requests
  6. from huggingface_hub import HfApi
  7. from termcolor import colored
  8. from hivemind.proto.auth_pb2 import AccessToken
  9. from hivemind.utils.auth import TokenAuthorizerBase
  10. from hivemind.utils.crypto import RSAPublicKey
  11. from hivemind.utils.logging import get_logger
# Module-level logger: hivemind's get_logger receives this module's name
# prefixed with "root.".
logger = get_logger("root." + __name__)
  13. class NonRetriableError(Exception):
  14. pass
  15. def call_with_retries(func, n_retries=10, initial_delay=1.0):
  16. for i in range(n_retries):
  17. try:
  18. return func()
  19. except NonRetriableError:
  20. raise
  21. except Exception as e:
  22. if i == n_retries - 1:
  23. raise
  24. delay = initial_delay * (2 ** i)
  25. logger.warning(f'Failed to call `{func.__name__}` with exception: {e}. Retrying in {delay:.1f} sec')
  26. time.sleep(delay)
  27. class InvalidCredentialsError(NonRetriableError):
  28. pass
  29. class NotInAllowlistError(NonRetriableError):
  30. pass
  31. class HuggingFaceAuthorizer(TokenAuthorizerBase):
  32. _AUTH_SERVER_URL = 'https://collaborative-training-auth.huggingface.co'
  33. def __init__(self, organization_name: str, model_name: str, hf_user_access_token: str):
  34. super().__init__()
  35. self.organization_name = organization_name
  36. self.model_name = model_name
  37. self.hf_user_access_token = hf_user_access_token
  38. self._authority_public_key = None
  39. self.coordinator_ip = None
  40. self.coordinator_port = None
  41. self._hf_api = HfApi()
  42. async def get_token(self) -> AccessToken:
  43. """
  44. Hivemind calls this method to refresh the token when necessary.
  45. """
  46. self.join_experiment()
  47. return self._local_access_token
  48. def join_experiment(self) -> None:
  49. call_with_retries(self._join_experiment)
  50. def _join_experiment(self) -> None:
  51. try:
  52. url = f'{self._AUTH_SERVER_URL}/api/experiments/join'
  53. headers = {'Authorization': f'Bearer {self.hf_user_access_token}'}
  54. response = requests.put(
  55. url,
  56. headers=headers,
  57. params={
  58. 'organization_name': self.organization_name,
  59. 'model_name': self.model_name,
  60. },
  61. json={
  62. 'experiment_join_input': {
  63. 'peer_public_key': self.local_public_key.to_bytes().decode(),
  64. },
  65. },
  66. )
  67. response.raise_for_status()
  68. response = response.json()
  69. self._authority_public_key = RSAPublicKey.from_bytes(response['auth_server_public_key'].encode())
  70. self.coordinator_ip = response['coordinator_ip']
  71. self.coordinator_port = response['coordinator_port']
  72. token_dict = response['hivemind_access']
  73. access_token = AccessToken()
  74. access_token.username = token_dict['username']
  75. access_token.public_key = token_dict['peer_public_key'].encode()
  76. access_token.expiration_time = str(datetime.fromisoformat(token_dict['expiration_time']))
  77. access_token.signature = token_dict['signature'].encode()
  78. self._local_access_token = access_token
  79. logger.info(f'Access for user {access_token.username} '
  80. f'has been granted until {access_token.expiration_time} UTC')
  81. except requests.exceptions.HTTPError as e:
  82. if e.response.status_code == 401: # Unauthorized
  83. raise NotInAllowlistError()
  84. raise
  85. def is_token_valid(self, access_token: AccessToken) -> bool:
  86. data = self._token_to_bytes(access_token)
  87. if not self._authority_public_key.verify(data, access_token.signature):
  88. logger.exception('Access token has invalid signature')
  89. return False
  90. try:
  91. expiration_time = datetime.fromisoformat(access_token.expiration_time)
  92. except ValueError:
  93. logger.exception(
  94. f'datetime.fromisoformat() failed to parse expiration time: {access_token.expiration_time}')
  95. return False
  96. if expiration_time.tzinfo is not None:
  97. logger.exception(f'Expected to have no timezone for expiration time: {access_token.expiration_time}')
  98. return False
  99. if expiration_time < datetime.utcnow():
  100. logger.exception('Access token has expired')
  101. return False
  102. return True
  103. _MAX_LATENCY = timedelta(minutes=1)
  104. def does_token_need_refreshing(self, access_token: AccessToken) -> bool:
  105. expiration_time = datetime.fromisoformat(access_token.expiration_time)
  106. return expiration_time < datetime.utcnow() + self._MAX_LATENCY
  107. @staticmethod
  108. def _token_to_bytes(access_token: AccessToken) -> bytes:
  109. return f'{access_token.username} {access_token.public_key} {access_token.expiration_time}'.encode()
  110. def authorize_with_huggingface() -> HuggingFaceAuthorizer:
  111. while True:
  112. organization_name = os.getenv('HF_ORGANIZATION_NAME')
  113. if organization_name is None:
  114. organization_name = input('HuggingFace organization name: ')
  115. model_name = os.getenv('HF_MODEL_NAME')
  116. if model_name is None:
  117. model_name = input('HuggingFace model name: ')
  118. hf_user_access_token = os.getenv('HF_USER_ACCESS_TOKEN')
  119. if hf_user_access_token is None:
  120. msg = [
  121. "Copy a token from your Hugging Face tokens page at ",
  122. colored("https://huggingface.co/settings/token", attrs=['bold']),
  123. "and paste it.\n💡",
  124. colored("Pro Tip:", attrs=['bold']),
  125. "If you don't already have one, you can create a dedicated user access token. Go to "
  126. "https://huggingface.co/settings/token and click on the `new token` button. ",
  127. "You just need to give a ",
  128. colored("'read'", attrs=['bold']),
  129. "role to this access token." ,
  130. "Don't forget to give an explicit name to this access token like",
  131. colored(f"'{organization_name}-{model_name}-collaborative-training'", attrs=['bold'])
  132. ]
  133. print(*msg)
  134. hf_user_access_token = getpass('HF user access token : ')
  135. authorizer = HuggingFaceAuthorizer(organization_name, model_name, hf_user_access_token)
  136. try:
  137. authorizer.join_experiment()
  138. username = authorizer._local_access_token.username
  139. print(f"🚀 You will contribute to the collaborative training under the username {username}.")
  140. return authorizer
  141. except InvalidCredentialsError:
  142. print('Invalid user access token, please try again')
  143. except NotInAllowlistError:
  144. print(
  145. '😥 Authentication has failed. '
  146. 'This error may be due to the fact:\n',
  147. " 1. your user access token is not valid. You can try to delete the previous token and"
  148. " recreate one. Be careful, organization tokens can't be used to join a collaborative "
  149. "training.\n"
  150. f" 2. you have not yet joined the {organization_name} organization. You can request to"
  151. " join this organization by clicking on the 'request to join this org' button at "
  152. f"https://huggingface.co/{organization_name}.\n",
  153. f" 3. the {organization_name} organization doesn't exist at https://huggingface.co/{organization_name}.\n",
  154. f" 4. no {organization_name}'s admin has created a collaborative training for the {organization_name}"
  155. f"organization and the {model_name} model.",
  156. )