Эх сурвалжийг харах

Use HF username as wandb run name

Aleksandr Borzunov 3 жил өмнө
parent
commit
2fed6ba68e
2 өөрчлөгдсөн 13 нэмэгдсэн , 4 устгасан
  1. 5 3
      huggingface_auth.py
  2. 8 1
      task.py

+ 5 - 3
huggingface_auth.py

@@ -67,6 +67,10 @@ class HuggingFaceAuthorizer(TokenAuthorizerBase):
         self.join_experiment()
         return self._local_access_token
 
+    @property
+    def username(self):
+        return self._local_access_token.username
+
     def join_experiment(self) -> None:
         call_with_retries(self._join_experiment)
 
@@ -169,9 +173,7 @@ def authorize_with_huggingface() -> HuggingFaceAuthorizer:
 
         try:
             authorizer.join_experiment()
-
-            username = authorizer._local_access_token.username
-            print(f"🚀 You will contribute to the collaborative training under the username {username}")
+            print(f"🚀 You will contribute to the collaborative training under the username {authorizer.username}")
             return authorizer
         except InvalidCredentialsError:
             print('Invalid user access token, please try again')

+ 8 - 1
task.py

@@ -11,6 +11,7 @@ from dalle_pytorch.vae import VQGanVAE
 from hivemind import SizeAdaptiveCompression, Float16Compression, Uniform8BitQuantization
 from transformers import DataCollatorWithPadding, T5TokenizerFast, get_linear_schedule_with_warmup
 from torch import nn
+from transformers import training_args
 
 import utils
 from arguments import HFTrainerArguments, BasePeerArguments, CollaborativeArguments
@@ -95,6 +96,12 @@ class TrainingTask:
     @property
     def dht(self):
         if self._dht is None:
+            if self.peer_args.authorize:
+                authorizer = authorize_with_huggingface()
+                self.trainer_args.run_name = authorizer.username  # For wandb
+            else:
+                authorizer = None
+
             self._dht = hivemind.DHT(
                 start=True,
                 initial_peers=self.peer_args.initial_peers,
@@ -104,7 +111,7 @@ class TrainingTask:
                 use_ipfs=self.peer_args.use_ipfs,
                 record_validators=self.validators,
                 identity_path=self.peer_args.identity_path,
-                authorizer=authorize_with_huggingface() if self.peer_args.authorize else None,
+                authorizer=authorizer,
             )
             if self.peer_args.client_mode:
                 logger.info(f"Created client mode peer with peer_id={self._dht.peer_id}")