@@ -0,0 +1,319 @@
+#!/usr/bin/env python
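+# Collaborative ALBERT pre-training with hivemind: each peer trains locally and
+# periodically averages gradients with the rest of the swarm through a DHT.
+# Illustrative launch (script name, peer address and prefix are placeholders):
+#   python run_trainer.py --output_dir ./outputs --experiment_prefix my_albert \
+#       --initial_peers 123.45.67.89:31337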
+
+import logging
+import os
+from dataclasses import dataclass, field, asdict
+from pathlib import Path
+from typing import Optional, Dict, Any, List
+import uuid
+
+from datasets import load_from_disk
+import transformers
+from torch.utils.data import DataLoader
+from transformers import (set_seed, HfArgumentParser, TrainingArguments,
+                          DataCollatorForLanguageModeling, AlbertTokenizerFast, AlbertConfig, AlbertForPreTraining)
+from transformers.optimization import get_linear_schedule_with_warmup
+from transformers.trainer_utils import is_main_process
+from transformers.trainer import Trainer
+from torch_optimizer import Lamb
+import torch
+
+import hivemind
+
+logger = logging.getLogger(__name__)
+LRSchedulerBase = getattr(torch.optim.lr_scheduler, '_LRScheduler', None)
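+# _LRScheduler is private torch API, so fetch it defensively; NoOpScheduler below
+# subclasses it to satisfy transformers.Trainer's scheduler interface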
+
+
+@dataclass
+class CollaborationArguments:
+    """Define how peers interact with each other while training"""
+
+    # primary parameters
+    initial_peers: List[str]  # one or more peers (comma-separated) that will welcome you into the collaboration
+    experiment_prefix: str  # a unique "name" for this experiment, used to store metadata on the DHT
+    averaging_expiration: float = 5.0  # averaging group will wait for stragglers for at most this many seconds
+    averaging_timeout: float = 30.0  # give up on averaging step after this many seconds
+    target_batch_size: int = 4096  # perform optimizer step after all peers collectively accumulate this many samples
+    client_mode: bool = False  # if True, runs training without incoming connections, in a firewall-compatible mode
+    trainer_uuid: str = uuid.uuid4().hex  # this peer's name, used when publishing metadata to the DHT (default: random)
+
+    # optional tweaks
+    target_group_size: int = 64  # maximum group size for all-reduce
+    metadata_expiration: float = 30  # peer's metadata will be removed if not updated in this many seconds
+    statistics_expiration: float = 600  # statistics will be removed if not updated in this many seconds
+    dht_listen_on: str = '[::]:*'  # network interface used for incoming DHT communication. Default: all ipv6
+    listen_on: str = '[::]:*'  # network interface used for incoming averager communication. Default: all ipv6
+    endpoint: Optional[str] = None  # this node's IP for inbound connections, used when running from behind a proxy
+    batch_size_lead: int = 0  # optional: begin looking for a group this many samples before target_batch_size is reached
+    compression: str = 'FLOAT16'  # use this compression when averaging parameters/gradients
+
+    min_refresh_period: float = 0.5  # wait for at least this many seconds before fetching new collaboration state
+    max_refresh_period: float = 30  # wait for at most this many seconds before fetching new collaboration state
+    default_refresh_period: float = 3  # attempt to fetch collaboration state this often until successful
+    expected_drift_peers: float = 3  # trainer assumes that this many new peers can join per step
+    expected_drift_rate: float = 0.2  # trainer assumes that this fraction of the current size can join per step
+
+    bandwidth: float = 100.0  # available network bandwidth, in mbps (used for load balancing in all-reduce)
+    performance_ema_alpha: float = 0.1  # smoothing alpha for the moving-average estimate of samples per second
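+    # illustrative arithmetic (assuming the AlbertTrainingArguments defaults below):
+    # 4 samples/device * 2 gradient accumulation steps = 8 samples per local step,
+    # so target_batch_size=4096 is reached after ~512 such peer-steps swarm-wide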
+
+
+@dataclass
+class DatasetArguments:
+    dataset_path: Optional[str] = field(default='./data/albert_tokenized_wikitext',
+                                        metadata={"help": "Path to the tokenized dataset"})
+    tokenizer_path: Optional[str] = field(default='./data/tokenizer',
+                                          metadata={"help": "Path to the tokenizer"})
+    config_path: Optional[str] = field(
+        default='https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-config.json',
+        metadata={"help": "Path to the model config"})
+    cache_dir: Optional[str] = field(default='./data', metadata={"help": "Path to the cache"})
+
+
+@dataclass
+class AlbertTrainingArguments(TrainingArguments):
+    dataloader_num_workers: int = 4
+    per_device_train_batch_size: int = 4
+    per_device_eval_batch_size: int = 4
+    gradient_accumulation_steps: int = 2
+    seq_length: int = 512
+
+    max_steps: int = 1_000_000  # ALBERT is actually ready after ~125,000 steps
+    learning_rate: float = 0.00176
+    warmup_steps: int = 5000
+    adam_epsilon: float = 1e-6
+    weight_decay: float = 0.01
+    max_grad_norm: float = 1.0
+    clamp_value: float = 10000.0
+
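+    # note: fp16_opt_level follows apex AMP conventions ('O2' keeps fp32 master
+    # weights) and is only consulted when apex is the mixed-precision backend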
+    fp16: bool = True
+    fp16_opt_level: str = 'O2'
+    do_train: bool = True
+
+    logging_steps: int = 100
+    save_total_limit: int = 2
+    save_steps: int = 500
+
+
+def setup_logging(training_args):
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        level=logging.INFO if is_main_process(training_args.local_rank) else logging.WARN,
+    )
+
+    # Log a short summary on each process:
+    logger.warning(
+        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
+        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
+    )
+    # Set the verbosity of the Transformers logger to INFO (on the main process only):
+    if is_main_process(training_args.local_rank):
+        transformers.utils.logging.set_verbosity_info()
+        transformers.utils.logging.enable_default_handler()
+        transformers.utils.logging.enable_explicit_format()
+    logger.info("Training/evaluation parameters %s", training_args)
+
+
+def get_model(training_args, config, tokenizer):
+    # Find latest checkpoint in output_dir
+    output_dir = Path(training_args.output_dir)
+    logger.info(f'Checkpoint dir {output_dir}, contents {list(output_dir.glob("checkpoint*"))}')
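+    # note: picking the newest checkpoint by ctime assumes checkpoints are written
+    # in place; directories restored from a backup may break this ordering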
+    latest_checkpoint_dir = max(output_dir.glob('checkpoint*'), default=None, key=os.path.getctime)
+
+    if latest_checkpoint_dir is not None:
+        logger.info(f'Loading model from {latest_checkpoint_dir}')
+        model = AlbertForPreTraining.from_pretrained(latest_checkpoint_dir)
+    else:
+        logger.info('Training from scratch')
+        model = AlbertForPreTraining(config)
+        model.resize_token_embeddings(len(tokenizer))
+
+    return model
+
+
+def get_optimizer_and_scheduler(training_args, model):
+    no_decay = ["bias", "LayerNorm.weight"]
+    optimizer_grouped_parameters = [
+        {
+            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+            "weight_decay": training_args.weight_decay,
+        },
+        {
+            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
+            "weight_decay": 0.0,
+        },
+    ]
+
+    opt = Lamb(
+        optimizer_grouped_parameters,
+        lr=training_args.learning_rate,
+        betas=(training_args.adam_beta1, training_args.adam_beta2),
+        eps=training_args.adam_epsilon,
+        weight_decay=training_args.weight_decay,
+        clamp_value=training_args.clamp_value,
+        debias=True,
+    )
+
+    scheduler = get_linear_schedule_with_warmup(
+        opt,
+        num_warmup_steps=training_args.warmup_steps,
+        num_training_steps=training_args.max_steps
+    )
+
+    return opt, scheduler
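+
+# A quick sanity check (illustrative, not executed anywhere in this script):
+# biases and LayerNorm weights should land in the zero-weight-decay group, e.g.
+#   opt, _ = get_optimizer_and_scheduler(training_args, model)
+#   assert opt.param_groups[1]['weight_decay'] == 0.0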
+
+
+class CollaborativeCallback(transformers.TrainerCallback):
+    def __init__(self, dht: hivemind.DHT, optimizer: hivemind.CollaborativeOptimizer,
+                 model: torch.nn.Module, trainer_uuid: str, statistics_expiration: float):
+        super().__init__()
+        self.model = model
+        self.dht, self.collaborative_optimizer = dht, optimizer
+        self.trainer_uuid, self.statistics_expiration = trainer_uuid, statistics_expiration
+        self.last_reported_collaboration_step = -1
+        self.previous_state = self.get_current_state()
+        self.samples = 0
+        self.steps = 0
+        self.loss = 0
+
+    def on_step_end(self, args: TrainingArguments, state: transformers.TrainerState,
+                    control: transformers.TrainerControl, **kwargs):
+        control.should_log = True
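+        # if fp16 overflow or a bad update produced non-finite parameters, roll the
+        # model and optimizer back to the last state that was verified to be finite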
+        if not self.params_are_finite():
+            self.load_from_state(self.previous_state)
+            return control
+        self.previous_state = self.get_current_state()
+
+        if state.log_history:
+            self.loss += state.log_history[-1]['loss']
+        if self.collaborative_optimizer.local_step != self.last_reported_collaboration_step:
+            self.last_reported_collaboration_step = self.collaborative_optimizer.local_step
+
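+            # metrics published to the DHT for monitoring peers:
+            # [local step, samples/sec (EMA), accumulated samples, mean loss per step]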
+            statistics = [self.collaborative_optimizer.local_step,
+                          self.collaborative_optimizer.performance_ema.samples_per_second,
+                          self.samples,
+                          self.loss / self.steps if self.steps else 0]
+            self.loss = 0
+
+            self.dht.store(self.collaborative_optimizer.prefix + "_metrics", subkey=self.trainer_uuid,
+                           value=statistics, expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
+                           return_future=True)
+        self.samples = self.collaborative_optimizer.local_samples_accumulated
+        self.steps = self.collaborative_optimizer.local_steps_accumulated
+
+        return control
+
+    @torch.no_grad()
+    def get_current_state(self) -> Dict[str, Any]:
+        return {
+            'model': self.model.state_dict(),
+            'opt': self.collaborative_optimizer.opt.state_dict()
+        }
+
+    @torch.no_grad()
+    def load_from_state(self, state):
+        self.model.load_state_dict(state['model'])
+        self.collaborative_optimizer.opt.load_state_dict(state['opt'])
+
+    @torch.no_grad()
+    def params_are_finite(self):
+        for param in self.model.parameters():
+            if not torch.all(torch.isfinite(param)):
+                return False
+        return True
+
+
+class NoOpScheduler(LRSchedulerBase):
+    """Dummy scheduler for transformers.Trainer. The real scheduler is defined in CollaborativeOptimizer.scheduler"""
+
+    def get_lr(self):
+        return [group['lr'] for group in self.optimizer.param_groups]
+
+    def print_lr(self, *args, **kwargs):
+        if self.optimizer.scheduler:
+            return self.optimizer.scheduler.print_lr(*args, **kwargs)
+
+    def step(self):
+        logger.debug("Called NoOpScheduler.step")
+        self._last_lr = self.get_lr()
+
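+    # nothing to save: the real scheduler state lives inside CollaborativeOptimizer,
+    # transformers.Trainer only needs this stub to checkpoint without errors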
+    def state_dict(self):
+        return {}
+
+    def load_state_dict(self, *args, **kwargs):
+        logger.debug("Called NoOpScheduler.load_state_dict")
+
+
+def main():
+    parser = HfArgumentParser((AlbertTrainingArguments, DatasetArguments, CollaborationArguments))
+    training_args, dataset_args, collaboration_args = parser.parse_args_into_dataclasses()
+
+    logger.info(f"Found {len(collaboration_args.initial_peers)} initial peers: {collaboration_args.initial_peers}")
+    if len(collaboration_args.initial_peers) == 0:
+        raise ValueError("Please specify at least one network endpoint in initial peers.")
+
+    collaboration_args_dict = asdict(collaboration_args)
+    setup_logging(training_args)
+
+    # Set seed before initializing model.
+    set_seed(training_args.seed)
+
+    config = AlbertConfig.from_pretrained(dataset_args.config_path, cache_dir=dataset_args.cache_dir)
+    tokenizer = AlbertTokenizerFast.from_pretrained(dataset_args.tokenizer_path, cache_dir=dataset_args.cache_dir)
+    model = get_model(training_args, config, tokenizer)
+    model.to(training_args.device)
+
+    tokenized_datasets = load_from_disk(Path(dataset_args.dataset_path))
+    # This data collator will take care of randomly masking the tokens.
+    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer)
+
+    opt, scheduler = get_optimizer_and_scheduler(training_args, model)
+
+    dht = hivemind.DHT(
+        initial_peers=collaboration_args_dict.pop('initial_peers'),
+        listen=not collaboration_args_dict['client_mode'], listen_on=collaboration_args_dict.pop('dht_listen_on'),
+        endpoint=collaboration_args_dict.pop('endpoint'), start=True)
+
+    total_batch_size_per_step = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
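+    # e.g. with the defaults above, each local step contributes 4 * 2 = 8 samples
+    # toward the collective target_batch_size (assuming a single-GPU peer)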
+    trainer_uuid = collaboration_args_dict.pop('trainer_uuid')
+    statistics_expiration = collaboration_args_dict.pop('statistics_expiration')
+    adjusted_target_batch_size = collaboration_args_dict.pop('target_batch_size') \
+                                 - collaboration_args_dict.pop('batch_size_lead')
+
+    collaborative_optimizer = hivemind.CollaborativeOptimizer(
+        opt=opt, dht=dht, scheduler=scheduler, prefix=collaboration_args_dict.pop('experiment_prefix'),
+        compression_type=hivemind.utils.CompressionType.Value(collaboration_args_dict.pop('compression')),
+        batch_size_per_step=total_batch_size_per_step, throughput=collaboration_args_dict.pop('bandwidth'),
+        target_batch_size=adjusted_target_batch_size, client_mode=collaboration_args_dict.pop('client_mode'),
+        verbose=True, start=True, **collaboration_args_dict
+    )
+
+    class TrainerWithIndependentShuffling(Trainer):
+        def get_train_dataloader(self) -> DataLoader:
+            """Shuffle data independently for each peer to avoid duplicating batches [important for quality]"""
+            torch.manual_seed(hash(trainer_uuid))
+            return super().get_train_dataloader()
+
+    trainer = TrainerWithIndependentShuffling(
+        model=model, args=training_args, tokenizer=tokenizer, data_collator=data_collator,
+        train_dataset=tokenized_datasets["train"] if training_args.do_train else None,
+        eval_dataset=tokenized_datasets["validation"] if training_args.do_eval else None,
+        optimizers=(collaborative_optimizer, NoOpScheduler(collaborative_optimizer)),
+        callbacks=[CollaborativeCallback(dht, collaborative_optimizer, model, trainer_uuid, statistics_expiration)]
+    )
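+    # drop the Trainer's default console callbacks: CollaborativeOptimizer is
+    # created with verbose=True and already reports training progress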
+    trainer.remove_callback(transformers.trainer_callback.PrinterCallback)
+    trainer.remove_callback(transformers.trainer_callback.ProgressCallback)
+
+    # Training
+    if training_args.do_train:
+        latest_checkpoint_dir = max(
+            Path(training_args.output_dir).glob('checkpoint*'),
+            default=None,
+            key=os.path.getctime
+        )
+
+        trainer.train(model_path=latest_checkpoint_dir)
+
+
+if __name__ == "__main__":
+    main()