# run_trainer.py
  1. #!/usr/bin/env python
  2. import logging
  3. import os
  4. from dataclasses import dataclass, field, asdict
  5. from pathlib import Path
  6. from typing import Optional, Dict, Any, List
  7. import hivemind
  8. import torch
  9. import transformers
  10. from datasets import load_from_disk
  11. from torch.utils.data import DataLoader
  12. from transformers import (set_seed, HfArgumentParser, TrainingArguments,
  13. DataCollatorForLanguageModeling, AlbertTokenizerFast, AlbertConfig, AlbertForPreTraining)
  14. from transformers.optimization import get_linear_schedule_with_warmup
  15. from transformers.trainer_utils import is_main_process
  16. from transformers.trainer import Trainer
  17. from torch_optimizer import Lamb
  18. import metrics_utils
  19. logger = logging.getLogger(__name__)
  20. LRSchedulerBase = getattr(torch.optim.lr_scheduler, '_LRScheduler', None)
@dataclass
class CollaborationArguments:
    """Define how peers interact with each other while training."""

    # primary parameters
    initial_peers: List[str]  # one or more peers (comma-separated) that will welcome you into the collaboration
    experiment_prefix: str  # a unique "name" of this experiment, used to store metadata on the DHT
    averaging_expiration: float = 5.0  # averaging group will wait for stragglers for at most this many seconds
    averaging_timeout: float = 30.0  # give up on averaging step after this many seconds
    target_batch_size: int = 4096  # perform optimizer step after all peers collectively accumulate this many samples
    client_mode: bool = False  # if True, runs training without incoming connections, in a firewall-compatible mode

    # optional tweaks
    target_group_size: int = 256  # maximum group size for all-reduce
    metadata_expiration: float = 30  # peer's metadata will be removed if not updated in this many seconds
    statistics_expiration: float = 600  # statistics will be removed if not updated in this many seconds
    dht_listen_on: str = '[::]:*'  # network interface used for incoming DHT communication. Default: all ipv6
    listen_on: str = '[::]:*'  # network interface used for incoming averager communication. Default: all ipv6
    endpoint: Optional[str] = None  # this node's IP for inbound connections, used when running from behind a proxy
    batch_size_lead: int = 0  # optional: begin looking for group in advance, this many samples before target_batch_size
    compression: str = 'FLOAT16'  # use this compression when averaging parameters/gradients
    min_refresh_period: float = 0.5  # wait for at least this many seconds before fetching new collaboration state
    max_refresh_period: float = 30  # wait for at most this many seconds before fetching new collaboration state
    default_refresh_period: float = 3  # attempt to fetch collaboration state every this often until successful
    expected_drift_peers: float = 3  # trainer assumes that this many new peers can join per step
    expected_drift_rate: float = 0.2  # trainer assumes that this fraction of current size can join per step
    bandwidth: float = 100.0  # available network bandwidth, in mbps (used for load balancing in all-reduce)
    performance_ema_alpha: float = 0.1  # uses this alpha for moving average estimate of samples per second
@dataclass
class DatasetArguments:
    """Paths to the pre-tokenized dataset, tokenizer and model config used for training."""

    dataset_path: Optional[str] = field(default='./data/albert_tokenized_wikitext',
                                        metadata={"help": "Path to the tokenized dataset"})
    tokenizer_path: Optional[str] = field(default='./data/tokenizer',
                                          metadata={"help": "Path to the tokenizer"})
    # NOTE(review): config_path defaults to a remote URL; from_pretrained also accepts local paths.
    config_path: Optional[str] = field(
        default='https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-config.json',
        metadata={"help": "Path to the model config"})
    cache_dir: Optional[str] = field(default='./data', metadata={"help": "Path to the cache"})
@dataclass
class AlbertTrainingArguments(TrainingArguments):
    """TrainingArguments with defaults tuned for collaborative ALBERT pre-training.

    Every field below overrides the corresponding transformers.TrainingArguments
    default; `seq_length` and `clamp_value` are extra fields consumed by this script.
    """

    dataloader_num_workers: int = 4
    per_device_train_batch_size: int = 4
    per_device_eval_batch_size: int = 4
    gradient_accumulation_steps: int = 2
    seq_length: int = 512  # maximum sequence length (extra field, not in base TrainingArguments)
    max_steps: int = 1_000_000  # Albert is actually ready after 125000 steps
    learning_rate: float = 0.00176
    warmup_steps: int = 5000
    adam_epsilon: float = 1e-6
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0
    clamp_value: float = 10000.0  # passed to LAMB to clamp the trust ratio (extra field)
    fp16: bool = True
    fp16_opt_level: str = 'O2'
    do_train: bool = True
    logging_steps: int = 100
    save_total_limit: int = 2  # keep at most this many checkpoints
    save_steps: int = 500
  77. def setup_logging(training_args):
  78. logging.basicConfig(
  79. format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
  80. datefmt="%m/%d/%Y %H:%M:%S",
  81. level=logging.INFO if is_main_process(training_args.local_rank) else logging.WARN,
  82. )
  83. # Log on each process the small summary:
  84. logger.warning(
  85. f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
  86. + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
  87. )
  88. # Set the verbosity to info of the Transformers logger (on main process only):
  89. if is_main_process(training_args.local_rank):
  90. transformers.utils.logging.set_verbosity_info()
  91. transformers.utils.logging.enable_default_handler()
  92. transformers.utils.logging.enable_explicit_format()
  93. logger.info("Training/evaluation parameters %s", training_args)
  94. def get_model(training_args, config, tokenizer):
  95. # Find latest checkpoint in output_dir
  96. output_dir = Path(training_args.output_dir)
  97. logger.info(f'Checkpoint dir {output_dir}, contents {list(output_dir.glob("checkpoint*"))}')
  98. latest_checkpoint_dir = max(output_dir.glob('checkpoint*'), default=None, key=os.path.getctime)
  99. if latest_checkpoint_dir is not None:
  100. logger.info(f'Loading model from {latest_checkpoint_dir}')
  101. model = AlbertForPreTraining.from_pretrained(latest_checkpoint_dir)
  102. else:
  103. logger.info(f'Training from scratch')
  104. model = AlbertForPreTraining(config)
  105. model.resize_token_embeddings(len(tokenizer))
  106. return model
  107. def get_optimizer_and_scheduler(training_args, model):
  108. no_decay = ["bias", "LayerNorm.weight"]
  109. optimizer_grouped_parameters = [
  110. {
  111. "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
  112. "weight_decay": training_args.weight_decay,
  113. },
  114. {
  115. "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
  116. "weight_decay": 0.0,
  117. },
  118. ]
  119. opt = Lamb(
  120. optimizer_grouped_parameters,
  121. lr=training_args.learning_rate,
  122. betas=(training_args.adam_beta1, training_args.adam_beta2),
  123. eps=training_args.adam_epsilon,
  124. weight_decay=training_args.weight_decay,
  125. clamp_value=training_args.clamp_value,
  126. debias=True,
  127. )
  128. scheduler = get_linear_schedule_with_warmup(
  129. opt,
  130. num_warmup_steps=training_args.warmup_steps,
  131. num_training_steps=training_args.max_steps
  132. )
  133. return opt, scheduler
class CollaborativeCallback(transformers.TrainerCallback):
    """Trainer callback that couples the HF Trainer with hivemind's CollaborativeOptimizer.

    After each local step it (1) rolls model/optimizer back to the last healthy
    snapshot if any parameter became non-finite, (2) accumulates local
    loss/step counters, and (3) publishes this peer's statistics to the DHT
    once per *collaborative* step.
    """

    def __init__(self, dht: hivemind.DHT, optimizer: hivemind.CollaborativeOptimizer,
                 model: torch.nn.Module, local_public_key: bytes, statistics_expiration: float):
        super().__init__()
        self.model = model
        self.dht, self.collaborative_optimizer = dht, optimizer
        # identifies this peer's entries in the shared metrics table
        self.local_public_key = local_public_key
        self.statistics_expiration = statistics_expiration
        self.last_reported_collaboration_step = -1
        # snapshot used to recover when fp16 training produces NaN/inf parameters
        self.previous_state = self.get_current_state()
        self.samples = 0
        self.steps = 0
        self.loss = 0

    def on_step_end(self, args: TrainingArguments, state: transformers.TrainerState,
                    control: transformers.TrainerControl, **kwargs):
        control.should_log = True
        # Non-finite parameters: restore the last good snapshot and skip reporting this step.
        if not self.params_are_finite():
            self.load_from_state(self.previous_state)
            return control
        self.previous_state = self.get_current_state()

        if state.log_history:
            self.loss += state.log_history[-1]['loss']
            self.steps += 1
            # Report only once per collaborative step; many local mini-steps may pass in between.
            if self.collaborative_optimizer.local_step != self.last_reported_collaboration_step:
                self.last_reported_collaboration_step = self.collaborative_optimizer.local_step
                samples_per_second = self.collaborative_optimizer.performance_ema.samples_per_second
                statistics = metrics_utils.LocalMetrics(
                    step=self.collaborative_optimizer.local_step,
                    samples_per_second=samples_per_second,
                    samples_accumulated=self.samples,
                    loss=self.loss,
                    mini_steps=self.steps)
                # reset local accumulators for the next collaborative step
                self.loss = 0
                self.steps = 0
                # return_future=True makes the store fire-and-forget so training is not blocked on DHT I/O
                self.dht.store(key=self.collaborative_optimizer.prefix + "_metrics",
                               subkey=self.local_public_key, value=statistics.dict(),
                               expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
                               return_future=True)

        self.samples = self.collaborative_optimizer.local_samples_accumulated
        return control

    @torch.no_grad()
    def get_current_state(self) -> Dict[str, Any]:
        """Snapshot model and (inner) optimizer state for possible rollback."""
        return {
            'model': self.model.state_dict(),
            'opt': self.collaborative_optimizer.opt.state_dict()
        }

    @torch.no_grad()
    def load_from_state(self, state):
        """Restore model and inner optimizer from a snapshot made by get_current_state."""
        self.model.load_state_dict(state['model'])
        self.collaborative_optimizer.opt.load_state_dict(state['opt'])

    @torch.no_grad()
    def params_are_finite(self):
        """Return False if any model parameter contains NaN or inf."""
        for param in self.model.parameters():
            if not torch.all(torch.isfinite(param)):
                return False
        return True
class NoOpScheduler(LRSchedulerBase):
    """ Dummy scheduler for transformers.Trainer. The real scheduler is defined in CollaborativeOptimizer.scheduler """
    # NOTE(review): LRSchedulerBase is getattr(..., '_LRScheduler', None); if that attribute
    # were ever missing, this class definition would fail at import time — verify against
    # the supported torch versions.

    def get_lr(self):
        # report the optimizer's current lrs; the wrapped scheduler is what actually changes them
        return [group['lr'] for group in self.optimizer.param_groups]

    def print_lr(self, *args, **kwargs):
        # delegate to the real scheduler inside the CollaborativeOptimizer, when present
        if self.optimizer.scheduler:
            return self.optimizer.scheduler.print_lr(*args, **kwargs)

    def step(self):
        # intentionally a no-op: stepping is handled by CollaborativeOptimizer
        logger.debug("Called NoOpScheduler.step")
        self._last_lr = self.get_lr()

    def state_dict(self):
        # nothing to save — the real scheduler state lives in CollaborativeOptimizer
        return {}

    def load_state_dict(self, *args, **kwargs):
        logger.debug("Called NoOpScheduler.load_state_dict")
def main():
    """Entry point: parse arguments, build model/optimizer/DHT and run collaborative training."""
    parser = HfArgumentParser((AlbertTrainingArguments, DatasetArguments, CollaborationArguments))
    training_args, dataset_args, collaboration_args = parser.parse_args_into_dataclasses()

    logger.info(f"Found {len(collaboration_args.initial_peers)} initial peers: {collaboration_args.initial_peers}")
    if len(collaboration_args.initial_peers) == 0:
        raise ValueError("Please specify at least one network endpoint in initial peers.")

    # dict form is consumed below via pop(); whatever remains is forwarded to CollaborativeOptimizer
    collaboration_args_dict = asdict(collaboration_args)
    setup_logging(training_args)

    # Set seed before initializing model.
    set_seed(training_args.seed)

    config = AlbertConfig.from_pretrained(dataset_args.config_path, cache_dir=dataset_args.cache_dir)
    tokenizer = AlbertTokenizerFast.from_pretrained(dataset_args.tokenizer_path, cache_dir=dataset_args.cache_dir)
    model = get_model(training_args, config, tokenizer)
    model.to(training_args.device)

    tokenized_datasets = load_from_disk(Path(dataset_args.dataset_path))
    # This data collator will take care of randomly masking the tokens.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer)

    opt, scheduler = get_optimizer_and_scheduler(training_args, model)

    validators, local_public_key = metrics_utils.make_validators(
        collaboration_args_dict['experiment_prefix'])
    dht = hivemind.DHT(
        start=True, initial_peers=collaboration_args_dict.pop('initial_peers'),
        listen=not collaboration_args_dict['client_mode'],
        listen_on=collaboration_args_dict.pop('dht_listen_on'),
        endpoint=collaboration_args_dict.pop('endpoint'), record_validators=validators)

    # samples contributed per local optimizer step on this peer
    total_batch_size_per_step = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
    statistics_expiration = collaboration_args_dict.pop('statistics_expiration')
    # batch_size_lead lets peers start forming an averaging group slightly early
    adjusted_target_batch_size = collaboration_args_dict.pop('target_batch_size') \
                                 - collaboration_args_dict.pop('batch_size_lead')

    collaborative_optimizer = hivemind.CollaborativeOptimizer(
        opt=opt, dht=dht, scheduler=scheduler, prefix=collaboration_args_dict.pop('experiment_prefix'),
        compression_type=hivemind.utils.CompressionType.Value(collaboration_args_dict.pop('compression')),
        batch_size_per_step=total_batch_size_per_step, throughput=collaboration_args_dict.pop('bandwidth'),
        target_batch_size=adjusted_target_batch_size, client_mode=collaboration_args_dict.pop('client_mode'),
        verbose=True, start=True, **collaboration_args_dict
    )

    class TrainerWithIndependentShuffling(Trainer):
        def get_train_dataloader(self) -> DataLoader:
            """ Shuffle data independently for each peer to avoid duplicating batches [important for quality] """
            torch.manual_seed(hash(local_public_key))
            return super().get_train_dataloader()

    trainer = TrainerWithIndependentShuffling(
        model=model, args=training_args, tokenizer=tokenizer, data_collator=data_collator,
        train_dataset=tokenized_datasets["train"] if training_args.do_train else None,
        eval_dataset=tokenized_datasets["validation"] if training_args.do_eval else None,
        optimizers=(collaborative_optimizer, NoOpScheduler(collaborative_optimizer)),
        callbacks=[CollaborativeCallback(
            dht, collaborative_optimizer, model, local_public_key, statistics_expiration)]
    )
    # the callback above handles logging; suppress the Trainer's default console output
    trainer.remove_callback(transformers.trainer_callback.PrinterCallback)
    trainer.remove_callback(transformers.trainer_callback.ProgressCallback)

    # Training
    if training_args.do_train:
        # resume from the most recently created checkpoint, if any
        latest_checkpoint_dir = max(
            Path(training_args.output_dir).glob('checkpoint*'),
            default=None,
            key=os.path.getctime
        )
        trainer.train(model_path=latest_checkpoint_dir)
# standard script entry point
if __name__ == "__main__":
    main()