Sfoglia il codice sorgente

Upgrade to the new auth version (#3)

Co-authored-by: SaulLu <lucilesaul.com@gmail.com>
Alexander Borzunov 3 anni fa
parent
commit
27139b8a28
2 ha cambiato i file con 26 aggiunte e 13 eliminazioni
  1. 1 1
      arguments.py
  2. 25 12
      huggingface_auth.py

+ 1 - 1
arguments.py

@@ -85,7 +85,7 @@ class BasePeerArguments:
     tokenizer_path: Optional[str] = field(default="t5-small", metadata={"help": "Path to the tokenizer"})
     cache_dir: Optional[str] = field(default="./cache", metadata={"help": "Path to the cache"})
 
-    authorize: bool = field(default=False, metadata={"help": "Whether or not to use HF authorizer"})
+    authorize: bool = field(default=True, metadata={"help": "Whether or not to use HF authorizer"})
     client_mode: bool = field(
         default=False,
         metadata={"help": "Of True, runs training without incoming connections, in a firewall-compatible mode"},

+ 25 - 12
huggingface_auth.py

@@ -1,4 +1,3 @@
-import base64
 import os
 import time
 from datetime import datetime, timedelta
@@ -46,10 +45,11 @@ class NotInAllowlistError(NonRetriableError):
 class HuggingFaceAuthorizer(TokenAuthorizerBase):
     _AUTH_SERVER_URL = 'https://collaborative-training-auth.huggingface.co'
 
-    def __init__(self, experiment_id: int, username: str, password: str):
+    def __init__(self, organization_name: str, model_name: str, username: str, password: str):
         super().__init__()
 
-        self.experiment_id = experiment_id
+        self.organization_name = organization_name
+        self.model_name = model_name
         self.username = username
         self.password = password
 
@@ -79,13 +79,22 @@ class HuggingFaceAuthorizer(TokenAuthorizerBase):
             raise
 
         try:
-            url = f'{self._AUTH_SERVER_URL}/api/experiments/join/{self.experiment_id}/'
+            url = f'{self._AUTH_SERVER_URL}/api/experiments/join'
             headers = {'Authorization': f'Bearer {token}'}
-            response = requests.put(url, headers=headers, json={
-                'experiment_join_input': {
-                    'peer_public_key': self.local_public_key.to_bytes().decode(),
+            response = requests.put(
+                url,
+                headers=headers,
+                params={
+                    'organization_name': self.organization_name,
+                    'model_name': self.model_name,
                 },
-            })
+                json={
+                    'experiment_join_input': {
+                        'peer_public_key': self.local_public_key.to_bytes().decode(),
+                    },
+                },
+                verify=False,  # FIXME: Update the expired API certificate
+            )
 
             response.raise_for_status()
             response = response.json()
@@ -144,9 +153,13 @@ class HuggingFaceAuthorizer(TokenAuthorizerBase):
 
 def authorize_with_huggingface() -> HuggingFaceAuthorizer:
     while True:
-        experiment_id = os.getenv('HF_EXPERIMENT_ID')
-        if experiment_id is None:
-            experiment_id = input('HuggingFace experiment ID: ')
+        organization_name = os.getenv('HF_ORGANIZATION_NAME')
+        if organization_name is None:
+            organization_name = input('HuggingFace organization name: ')
+
+        model_name = os.getenv('HF_MODEL_NAME')
+        if model_name is None:
+            model_name = input('HuggingFace model name: ')
 
         username = os.getenv('HF_USERNAME')
         if username is None:
@@ -160,7 +173,7 @@ def authorize_with_huggingface() -> HuggingFaceAuthorizer:
         if password is None:
             password = getpass('HuggingFace password: ')
 
-        authorizer = HuggingFaceAuthorizer(experiment_id, username, password)
+        authorizer = HuggingFaceAuthorizer(organization_name, model_name, username, password)
         try:
             authorizer.join_experiment()
             return authorizer