|
@@ -1,4 +1,3 @@
|
|
|
-import base64
|
|
|
import os
|
|
|
import time
|
|
|
from datetime import datetime, timedelta
|
|
@@ -46,10 +45,11 @@ class NotInAllowlistError(NonRetriableError):
|
|
|
class HuggingFaceAuthorizer(TokenAuthorizerBase):
|
|
|
_AUTH_SERVER_URL = 'https://collaborative-training-auth.huggingface.co'
|
|
|
|
|
|
- def __init__(self, experiment_id: int, username: str, password: str):
|
|
|
+ def __init__(self, organization_name: str, model_name: str, username: str, password: str):
|
|
|
super().__init__()
|
|
|
|
|
|
- self.experiment_id = experiment_id
|
|
|
+ self.organization_name = organization_name
|
|
|
+ self.model_name = model_name
|
|
|
self.username = username
|
|
|
self.password = password
|
|
|
|
|
@@ -79,13 +79,22 @@ class HuggingFaceAuthorizer(TokenAuthorizerBase):
|
|
|
raise
|
|
|
|
|
|
try:
|
|
|
- url = f'{self._AUTH_SERVER_URL}/api/experiments/join/{self.experiment_id}/'
|
|
|
+ url = f'{self._AUTH_SERVER_URL}/api/experiments/join'
|
|
|
headers = {'Authorization': f'Bearer {token}'}
|
|
|
- response = requests.put(url, headers=headers, json={
|
|
|
- 'experiment_join_input': {
|
|
|
- 'peer_public_key': self.local_public_key.to_bytes().decode(),
|
|
|
+ response = requests.put(
|
|
|
+ url,
|
|
|
+ headers=headers,
|
|
|
+ params={
|
|
|
+ 'organization_name': self.organization_name,
|
|
|
+ 'model_name': self.model_name,
|
|
|
},
|
|
|
- })
|
|
|
+ json={
|
|
|
+ 'experiment_join_input': {
|
|
|
+ 'peer_public_key': self.local_public_key.to_bytes().decode(),
|
|
|
+ },
|
|
|
+ },
|
|
|
+ verify=False, # FIXME: Update the expired API certificate
|
|
|
+ )
|
|
|
|
|
|
response.raise_for_status()
|
|
|
response = response.json()
|
|
@@ -144,9 +153,13 @@ class HuggingFaceAuthorizer(TokenAuthorizerBase):
|
|
|
|
|
|
def authorize_with_huggingface() -> HuggingFaceAuthorizer:
|
|
|
while True:
|
|
|
- experiment_id = os.getenv('HF_EXPERIMENT_ID')
|
|
|
- if experiment_id is None:
|
|
|
- experiment_id = input('HuggingFace experiment ID: ')
|
|
|
+ organization_name = os.getenv('HF_ORGANIZATION_NAME')
|
|
|
+ if organization_name is None:
|
|
|
+ organization_name = input('HuggingFace organization name: ')
|
|
|
+
|
|
|
+ model_name = os.getenv('HF_MODEL_NAME')
|
|
|
+ if model_name is None:
|
|
|
+ model_name = input('HuggingFace model name: ')
|
|
|
|
|
|
username = os.getenv('HF_USERNAME')
|
|
|
if username is None:
|
|
@@ -160,7 +173,7 @@ def authorize_with_huggingface() -> HuggingFaceAuthorizer:
|
|
|
if password is None:
|
|
|
password = getpass('HuggingFace password: ')
|
|
|
|
|
|
- authorizer = HuggingFaceAuthorizer(experiment_id, username, password)
|
|
|
+ authorizer = HuggingFaceAuthorizer(organization_name, model_name, username, password)
|
|
|
try:
|
|
|
authorizer.join_experiment()
|
|
|
return authorizer
|