123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191 |
- import os
- import time
- from datetime import datetime, timedelta
- from getpass import getpass
- import requests
- from huggingface_hub import HfApi
- from termcolor import colored
- from hivemind.proto.auth_pb2 import AccessToken
- from hivemind.utils.auth import TokenAuthorizerBase
- from hivemind.utils.crypto import RSAPublicKey
- from hivemind.utils.logging import get_logger
# Module-level logger; its name is this module's ``__name__`` prefixed with "root.".
logger = get_logger(f"root.{__name__}")
class NonRetriableError(Exception):
    """Failure that `call_with_retries` re-raises immediately instead of retrying."""
def call_with_retries(func, n_retries=10, initial_delay=1.0):
    """
    Call ``func()`` until it succeeds, sleeping with exponential backoff between failed attempts.

    :param func: zero-argument callable to invoke
    :param n_retries: total number of attempts (the first call included); must be at least 1
    :param initial_delay: seconds to sleep after the first failure; doubles after each further failure
    :returns: the value returned by the first successful ``func()`` call
    :raises ValueError: if ``n_retries`` is less than 1
    :raises NonRetriableError: propagated immediately, without any further attempts
    :raises Exception: whatever the last attempt raised, once all attempts are used up
    """
    if n_retries < 1:
        # With zero attempts the loop below would never call func and would silently return None.
        raise ValueError(f'n_retries must be at least 1, got {n_retries}')

    # functools.partial objects and other callable instances have no __name__;
    # fall back to repr() so that logging a failure can never raise AttributeError.
    func_name = getattr(func, '__name__', repr(func))

    for attempt in range(n_retries):
        try:
            return func()
        except NonRetriableError:
            raise
        except Exception as e:
            if attempt == n_retries - 1:
                # Out of attempts: surface the last error to the caller.
                raise
            delay = initial_delay * (2 ** attempt)
            logger.warning(f'Failed to call `{func_name}` with exception: {e}. Retrying in {delay:.1f} sec')
            time.sleep(delay)
class InvalidCredentialsError(NonRetriableError):
    """Non-retriable error that `authorize_with_huggingface` reports as an invalid user access token."""
class NotInAllowlistError(NonRetriableError):
    """Non-retriable error raised when the authorization server answers a join request with HTTP 401."""
class HuggingFaceAuthorizer(TokenAuthorizerBase):
    """
    Authorizer that exchanges a Hugging Face user access token for a hivemind access token
    issued by the collaborative training authorization server.

    A successful join also stores the server's public key (used by :meth:`is_token_valid`)
    and the ``coordinator_ip`` / ``coordinator_port`` values returned by the server.
    """

    _AUTH_SERVER_URL = 'https://collaborative-training-auth.huggingface.co'

    # A token that expires within this margin from now is reported as needing a refresh.
    _MAX_LATENCY = timedelta(minutes=1)

    def __init__(self, organization_name: str, model_name: str, hf_user_access_token: str):
        """
        :param organization_name: sent to the server as the ``organization_name`` query parameter
        :param model_name: sent to the server as the ``model_name`` query parameter
        :param hf_user_access_token: sent to the server as a ``Bearer`` token in the ``Authorization`` header
        """
        super().__init__()

        self.organization_name = organization_name
        self.model_name = model_name
        self.hf_user_access_token = hf_user_access_token

        # These three stay None until a join request succeeds.
        self._authority_public_key = None
        self.coordinator_ip = None
        self.coordinator_port = None

        self._hf_api = HfApi()

    async def get_token(self) -> AccessToken:
        """
        Hivemind calls this method to refresh the token when necessary.
        """
        # NOTE(review): join_experiment() does blocking HTTP requests and time.sleep() backoff
        # inside this coroutine - confirm that blocking the event loop here is acceptable.
        self.join_experiment()
        return self._local_access_token

    def join_experiment(self) -> None:
        """
        Join the experiment, retrying failed attempts through ``call_with_retries``.

        Errors derived from ``NonRetriableError`` (e.g. ``NotInAllowlistError``) are not retried.
        """
        call_with_retries(self._join_experiment)

    def _join_experiment(self) -> None:
        """
        Send one join request and store the credentials contained in the response.

        :raises NotInAllowlistError: if the server responds with HTTP 401
        :raises requests.exceptions.HTTPError: for any other unsuccessful HTTP status
        """
        url = f'{self._AUTH_SERVER_URL}/api/experiments/join'
        headers = {'Authorization': f'Bearer {self.hf_user_access_token}'}

        # Only the HTTP exchange can raise HTTPError, so only that part is guarded.
        try:
            response = requests.put(
                url,
                headers=headers,
                params={
                    'organization_name': self.organization_name,
                    'model_name': self.model_name,
                },
                json={
                    'experiment_join_input': {
                        'peer_public_key': self.local_public_key.to_bytes().decode(),
                    },
                },
            )
            response.raise_for_status()
        except requests.exceptions.HTTPError as e:
            if e.response.status_code == 401:  # Unauthorized
                raise NotInAllowlistError() from e
            raise

        payload = response.json()

        self._authority_public_key = RSAPublicKey.from_bytes(payload['auth_server_public_key'].encode())
        self.coordinator_ip = payload['coordinator_ip']
        self.coordinator_port = payload['coordinator_port']

        token_dict = payload['hivemind_access']
        access_token = AccessToken()
        access_token.username = token_dict['username']
        access_token.public_key = token_dict['peer_public_key'].encode()
        # Round-trip through datetime so the stored text is in str(datetime) form,
        # which datetime.fromisoformat() in the validation methods below parses back.
        access_token.expiration_time = str(datetime.fromisoformat(token_dict['expiration_time']))
        access_token.signature = token_dict['signature'].encode()
        self._local_access_token = access_token

        logger.info(f'Access for user {access_token.username} '
                    f'has been granted until {access_token.expiration_time} UTC')

    def is_token_valid(self, access_token: AccessToken) -> bool:
        """
        Check the token's signature against the authorization server's public key and
        check that its expiration time is a timezone-less timestamp that has not passed yet.

        :returns: True if every check passes; False (after logging the reason) otherwise
        """
        data = self._token_to_bytes(access_token)
        if not self._authority_public_key.verify(data, access_token.signature):
            # Not inside an exception handler, so logger.exception() would attach a bogus traceback.
            logger.error('Access token has invalid signature')
            return False

        try:
            expiration_time = datetime.fromisoformat(access_token.expiration_time)
        except ValueError:
            logger.exception(
                f'datetime.fromisoformat() failed to parse expiration time: {access_token.expiration_time}')
            return False

        if expiration_time.tzinfo is not None:
            logger.error(f'Expected to have no timezone for expiration time: {access_token.expiration_time}')
            return False

        # The expiration time is naive, so it is compared against the naive UTC "now".
        if expiration_time < datetime.utcnow():
            logger.error('Access token has expired')
            return False

        return True

    def does_token_need_refreshing(self, access_token: AccessToken) -> bool:
        """
        Report whether the token expires less than ``_MAX_LATENCY`` from the current UTC time.

        :raises ValueError: if the expiration time cannot be parsed by ``datetime.fromisoformat()``
        """
        expiration_time = datetime.fromisoformat(access_token.expiration_time)
        return expiration_time < datetime.utcnow() + self._MAX_LATENCY

    @staticmethod
    def _token_to_bytes(access_token: AccessToken) -> bytes:
        """
        Serialize the username, public key and expiration time into the byte string
        whose signature is verified in :meth:`is_token_valid`.
        """
        # NOTE(review): public_key is interpolated as-is, so a bytes value appears as its repr (b'...').
        # Presumably this must match what the server signs - do not "fix" it without confirming.
        return f'{access_token.username} {access_token.public_key} {access_token.expiration_time}'.encode()
def authorize_with_huggingface() -> HuggingFaceAuthorizer:
    """
    Collect the organization name, model name and user access token, then join the experiment.

    Each value is read from its environment variable (``HF_ORGANIZATION_NAME``, ``HF_MODEL_NAME``,
    ``HF_USER_ACCESS_TOKEN``) when set; otherwise the user is prompted for it. After a rejected
    attempt the missing values are prompted for again.

    :returns: an authorizer that has successfully joined the experiment
    :raises InvalidCredentialsError: if the credentials are rejected and all of them came from
        the environment, so asking again could not change the outcome
    :raises NotInAllowlistError: same condition, for an HTTP 401 answer from the server
    """
    while True:
        # Tracks whether at least one value was typed in; if none was, a retry would
        # resend exactly the same environment-provided values and loop forever.
        prompted = False

        organization_name = os.getenv('HF_ORGANIZATION_NAME')
        if organization_name is None:
            organization_name = input('HuggingFace organization name: ')
            prompted = True

        model_name = os.getenv('HF_MODEL_NAME')
        if model_name is None:
            model_name = input('HuggingFace model name: ')
            prompted = True

        hf_user_access_token = os.getenv('HF_USER_ACCESS_TOKEN')
        if hf_user_access_token is None:
            print(
                "\nCopy a token from 🤗 Hugging Face settings page at "
                f"{colored('https://huggingface.co/settings/token', attrs=['bold'])} "
                "and paste it here.\n\n"
                f"💡 {colored('Tip:', attrs=['bold'])} "
                "If you don't already have one, you can create a dedicated user access token.\n"
                f"Go to {colored('https://huggingface.co/settings/token', attrs=['bold'])}, "
                f"click the {colored('New token', attrs=['bold'])} button, "
                f"and choose the {colored('read', attrs=['bold'])} role.\n"
            )
            hf_user_access_token = getpass('🤗 Hugging Face user access token (characters will be hidden): ')
            prompted = True

        authorizer = HuggingFaceAuthorizer(organization_name, model_name, hf_user_access_token)

        try:
            authorizer.join_experiment()
        except InvalidCredentialsError:
            print('Invalid user access token, please try again')
            if not prompted:
                raise
        except NotInAllowlistError:
            print(
                '\n😥 Authentication has failed.\n\n'
                'This error may be due to the fact:\n'
                " 1. Your user access token is not valid. You can try to delete the previous token and "
                "recreate one. Be careful, organization tokens can't be used to join a collaborative "
                "training.\n"
                f" 2. You have not yet joined the {organization_name} organization. You can request to "
                "join this organization by clicking on the 'request to join this org' button at "
                f"https://huggingface.co/{organization_name}.\n"
                f" 3. The {organization_name} organization doesn't exist at https://huggingface.co/{organization_name}.\n"
                f" 4. No {organization_name}'s admin has created a collaborative training for the {organization_name} "
                f"organization and the {model_name} model."
            )
            if not prompted:
                raise
        else:
            username = authorizer._local_access_token.username
            print(f"🚀 You will contribute to the collaborative training under the username {username}")
            return authorizer
|