|
@@ -5,13 +5,12 @@ from pathlib import Path
|
|
|
|
|
|
import hivemind
|
|
|
import torch
|
|
|
+import torch.nn as nn
|
|
|
import transformers
|
|
|
from dalle_pytorch import DALLE
|
|
|
from dalle_pytorch.vae import VQGanVAE
|
|
|
from hivemind import SizeAdaptiveCompression, Float16Compression, Uniform8BitQuantization
|
|
|
from transformers import DataCollatorWithPadding, T5TokenizerFast, get_linear_schedule_with_warmup
|
|
|
-from torch import nn
|
|
|
-from transformers import training_args
|
|
|
|
|
|
import utils
|
|
|
from arguments import HFTrainerArguments, BasePeerArguments, CollaborativeArguments
|
|
@@ -45,12 +44,14 @@ class ModelWrapper(nn.Module):
|
|
|
|
|
|
class TrainingTask:
|
|
|
"""A container that defines the training config, model, tokenizer, optimizer and other local training utilities"""
|
|
|
- _dht = _collaborative_optimizer = _training_dataset = None
|
|
|
+ _authorizer = _dht = _collaborative_optimizer = _training_dataset = None
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
self, peer_args: BasePeerArguments, trainer_args: HFTrainerArguments, collab_args: CollaborativeArguments):
|
|
|
self.peer_args, self.trainer_args, self.collab_args = peer_args, trainer_args, collab_args
|
|
|
+ self.trainer_args.run_name = self.authorizer.username # For wandb
|
|
|
+
|
|
|
self.validators, self.local_public_key = utils.make_validators(self.peer_args.experiment_prefix)
|
|
|
transformers.set_seed(trainer_args.seed) # seed used for initialization
|
|
|
|
|
@@ -91,15 +92,15 @@ class TrainingTask:
|
|
|
logger.info(f"Loading model from {latest_checkpoint_dir}")
|
|
|
self.model.load_state_dict(torch.load(f"{latest_checkpoint_dir}/model_state.pt"))
|
|
|
|
|
|
+ @property
|
|
|
+ def authorizer(self):
|
|
|
+ if self._authorizer is None and self.peer_args.authorize:
|
|
|
+ self._authorizer = authorize_with_huggingface()
|
|
|
+ return self._authorizer
|
|
|
+
|
|
|
@property
|
|
|
def dht(self):
|
|
|
if self._dht is None:
|
|
|
- if self.peer_args.authorize:
|
|
|
- authorizer = authorize_with_huggingface()
|
|
|
- self.trainer_args.run_name = authorizer.username # For wandb
|
|
|
- else:
|
|
|
- authorizer = None
|
|
|
-
|
|
|
self._dht = hivemind.DHT(
|
|
|
start=True,
|
|
|
initial_peers=self.peer_args.initial_peers,
|
|
@@ -109,7 +110,7 @@ class TrainingTask:
|
|
|
use_ipfs=self.peer_args.use_ipfs,
|
|
|
record_validators=self.validators,
|
|
|
identity_path=self.peer_args.identity_path,
|
|
|
- authorizer=authorizer,
|
|
|
+ authorizer=self.authorizer,
|
|
|
)
|
|
|
if self.peer_args.client_mode:
|
|
|
logger.info(f"Created client mode peer with peer_id={self._dht.peer_id}")
|