|
@@ -11,6 +11,7 @@ from dalle_pytorch.vae import VQGanVAE
|
|
|
from hivemind import SizeAdaptiveCompression, Float16Compression, Uniform8BitQuantization
|
|
|
from transformers import DataCollatorWithPadding, T5TokenizerFast, get_linear_schedule_with_warmup
|
|
|
from torch import nn
|
|
|
+from transformers import training_args
|
|
|
|
|
|
import utils
|
|
|
from arguments import HFTrainerArguments, BasePeerArguments, CollaborativeArguments
|
|
@@ -95,6 +96,12 @@ class TrainingTask:
|
|
|
@property
|
|
|
def dht(self):
|
|
|
if self._dht is None:
|
|
|
+ if self.peer_args.authorize:
|
|
|
+ authorizer = authorize_with_huggingface()
|
|
|
+ self.trainer_args.run_name = authorizer.username # For wandb
|
|
|
+ else:
|
|
|
+ authorizer = None
|
|
|
+
|
|
|
self._dht = hivemind.DHT(
|
|
|
start=True,
|
|
|
initial_peers=self.peer_args.initial_peers,
|
|
@@ -104,7 +111,7 @@ class TrainingTask:
|
|
|
use_ipfs=self.peer_args.use_ipfs,
|
|
|
record_validators=self.validators,
|
|
|
identity_path=self.peer_args.identity_path,
|
|
|
- authorizer=authorize_with_huggingface() if self.peer_args.authorize else None,
|
|
|
+ authorizer=authorizer,
|
|
|
)
|
|
|
if self.peer_args.client_mode:
|
|
|
logger.info(f"Created client mode peer with peer_id={self._dht.peer_id}")
|