run_trainer.py

#!/usr/bin/env python
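"""
Entry point for collaborative ALBERT pre-training: each participant runs a local
transformers.Trainer, while gradient accumulation and parameter averaging are
coordinated across peers through a hivemind DHT (no fixed parameter server).
"""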

import logging
import os
from dataclasses import dataclass, field, asdict
from pathlib import Path
from typing import Optional, Dict, Any, List
import uuid

from datasets import load_from_disk
import transformers
from torch.utils.data import DataLoader
from transformers import (set_seed, HfArgumentParser, TrainingArguments,
                          DataCollatorForLanguageModeling, AlbertTokenizerFast, AlbertConfig, AlbertForPreTraining)
from transformers.optimization import get_linear_schedule_with_warmup
from transformers.trainer_utils import is_main_process
from transformers.trainer import Trainer
from torch_optimizer import Lamb
import torch

import hivemind

logger = logging.getLogger(__name__)
LRSchedulerBase = getattr(torch.optim.lr_scheduler, '_LRScheduler', None)


@dataclass
class CollaborationArguments:
    """Define how peers interact with each other while training."""

    # primary parameters
    initial_peers: List[str]  # one or more peers (comma-separated) that will welcome you into the collaboration
    experiment_prefix: str  # a unique "name" for this experiment, used to store metadata on the DHT
    averaging_expiration: float = 5.0  # averaging group will wait for stragglers for at most this many seconds
    averaging_timeout: float = 30.0  # give up on an averaging step after this many seconds
    target_batch_size: int = 4096  # perform optimizer step after all peers collectively accumulate this many samples
    client_mode: bool = False  # if True, runs training without incoming connections, in a firewall-compatible mode
    trainer_uuid: str = uuid.uuid4().hex  # this peer's name, used when publishing metadata to the DHT (default: random)

    # optional tweaks
    target_group_size: int = 64  # maximum group size for all-reduce
    metadata_expiration: float = 30  # peer's metadata will be removed if not updated in this many seconds
    statistics_expiration: float = 600  # statistics will be removed if not updated in this many seconds
    dht_listen_on: str = '[::]:*'  # network interface used for incoming DHT communication (default: all ipv6)
    listen_on: str = '[::]:*'  # network interface used for incoming averager communication (default: all ipv6)
    endpoint: Optional[str] = None  # this node's IP for inbound connections, used when running behind a proxy
    batch_size_lead: int = 0  # optional: begin looking for a group this many samples before target_batch_size is reached
    compression: str = 'FLOAT16'  # use this compression when averaging parameters/gradients
    min_refresh_period: float = 0.5  # wait at least this many seconds before fetching new collaboration state
    max_refresh_period: float = 30  # wait at most this many seconds before fetching new collaboration state
    default_refresh_period: float = 3  # attempt to fetch collaboration state this often until successful
    expected_drift_peers: float = 3  # trainer assumes that this many new peers can join per step
    expected_drift_rate: float = 0.2  # trainer assumes that this fraction of the current size can join per step
    bandwidth: float = 100.0  # available network bandwidth, in mbps (used for load balancing in all-reduce)
    performance_ema_alpha: float = 0.1  # alpha for the moving-average estimate of samples per second
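

# Note: HfArgumentParser (used in main below) turns each dataclass field above into a
# command-line flag of the same name, e.g. --initial_peers and --target_batch_size.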


@dataclass
class DatasetArguments:
    dataset_path: Optional[str] = field(default='./data/albert_tokenized_wikitext',
                                        metadata={"help": "Path to the tokenized dataset"})
    tokenizer_path: Optional[str] = field(default='./data/tokenizer',
                                          metadata={"help": "Path to the tokenizer"})
    config_path: Optional[str] = field(
        default='https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-config.json',
        metadata={"help": "Path to the model config"})
    cache_dir: Optional[str] = field(default='./data', metadata={"help": "Path to the cache"})


@dataclass
class AlbertTrainingArguments(TrainingArguments):
    dataloader_num_workers: int = 4
    per_device_train_batch_size: int = 4
    per_device_eval_batch_size: int = 4
    gradient_accumulation_steps: int = 2
    seq_length: int = 512

    max_steps: int = 1_000_000  # ALBERT is actually ready after ~125,000 steps
    learning_rate: float = 0.00176
    warmup_steps: int = 5000
    adam_epsilon: float = 1e-6
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0
    clamp_value: float = 10000.0

    fp16: bool = True
    fp16_opt_level: str = 'O2'
    do_train: bool = True

    logging_steps: int = 100
    save_total_limit: int = 2
    save_steps: int = 500


def setup_logging(training_args):
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if is_main_process(training_args.local_rank) else logging.WARN,
    )

    # Log a short summary on each process:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    # Set the verbosity of the Transformers logger to info (on the main process only):
    if is_main_process(training_args.local_rank):
        transformers.utils.logging.set_verbosity_info()
        transformers.utils.logging.enable_default_handler()
        transformers.utils.logging.enable_explicit_format()
    logger.info("Training/evaluation parameters %s", training_args)


def get_model(training_args, config, tokenizer):
    # Find the latest checkpoint in output_dir
    output_dir = Path(training_args.output_dir)
    logger.info(f'Checkpoint dir {output_dir}, contents {list(output_dir.glob("checkpoint*"))}')
    latest_checkpoint_dir = max(output_dir.glob('checkpoint*'), default=None, key=os.path.getctime)

    if latest_checkpoint_dir is not None:
        logger.info(f'Loading model from {latest_checkpoint_dir}')
        model = AlbertForPreTraining.from_pretrained(latest_checkpoint_dir)
    else:
        logger.info('Training from scratch')
        model = AlbertForPreTraining(config)
        model.resize_token_embeddings(len(tokenizer))

    return model
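

# LAMB keeps training stable at the very large effective batch sizes produced by
# collaborative accumulation; as is standard for BERT-family models, biases and
# LayerNorm weights are excluded from weight decay.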
def get_optimizer_and_scheduler(training_args, model):
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": training_args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]

    opt = Lamb(
        optimizer_grouped_parameters,
        lr=training_args.learning_rate,
        betas=(training_args.adam_beta1, training_args.adam_beta2),
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
        clamp_value=training_args.clamp_value,
        debias=True,
    )

    scheduler = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=training_args.warmup_steps,
        num_training_steps=training_args.max_steps
    )

    return opt, scheduler
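

# This callback periodically publishes the peer's progress (collaboration step, samples/sec,
# accumulated samples, mean loss) to the DHT under <experiment_prefix>_metrics, and rolls the
# model and optimizer back to the last known-good state if any parameter becomes non-finite
# (e.g. after an fp16 overflow), so a single diverging peer does not poison the averaged model.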
class CollaborativeCallback(transformers.TrainerCallback):
    def __init__(self, dht: hivemind.DHT, optimizer: hivemind.CollaborativeOptimizer,
                 model: torch.nn.Module, trainer_uuid: str, statistics_expiration: float):
        super().__init__()
        self.model = model
        self.dht, self.collaborative_optimizer = dht, optimizer
        self.trainer_uuid, self.statistics_expiration = trainer_uuid, statistics_expiration
        self.last_reported_collaboration_step = -1
        self.previous_state = self.get_current_state()
        self.samples = 0
        self.steps = 0
        self.loss = 0

    def on_step_end(self, args: TrainingArguments, state: transformers.TrainerState,
                    control: transformers.TrainerControl, **kwargs):
        control.should_log = True
        if not self.params_are_finite():
            self.load_from_state(self.previous_state)
            return control
        self.previous_state = self.get_current_state()

        if state.log_history:
            self.loss += state.log_history[-1]['loss']
            if self.collaborative_optimizer.local_step != self.last_reported_collaboration_step:
                self.last_reported_collaboration_step = self.collaborative_optimizer.local_step
                statistics = [self.collaborative_optimizer.local_step,
                              self.collaborative_optimizer.performance_ema.samples_per_second,
                              self.samples,
                              self.loss / self.steps if self.steps else 0]
                self.loss = 0
                self.dht.store(self.collaborative_optimizer.prefix + "_metrics", subkey=self.trainer_uuid,
                               value=statistics, expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
                               return_future=True)
        self.samples = self.collaborative_optimizer.local_samples_accumulated
        self.steps = self.collaborative_optimizer.local_steps_accumulated

        return control

    @torch.no_grad()
    def get_current_state(self) -> Dict[str, Any]:
        return {
            'model': self.model.state_dict(),
            'opt': self.collaborative_optimizer.opt.state_dict()
        }

    @torch.no_grad()
    def load_from_state(self, state):
        self.model.load_state_dict(state['model'])
        self.collaborative_optimizer.opt.load_state_dict(state['opt'])

    @torch.no_grad()
    def params_are_finite(self):
        for param in self.model.parameters():
            if not torch.all(torch.isfinite(param)):
                return False
        return True


class NoOpScheduler(LRSchedulerBase):
    """Dummy scheduler for transformers.Trainer; the real scheduler is defined in CollaborativeOptimizer.scheduler."""

    def get_lr(self):
        return [group['lr'] for group in self.optimizer.param_groups]

    def print_lr(self, *args, **kwargs):
        if self.optimizer.scheduler:
            return self.optimizer.scheduler.print_lr(*args, **kwargs)

    def step(self):
        logger.debug("Called NoOpScheduler.step")
        self._last_lr = self.get_lr()

    def state_dict(self):
        return {}

    def load_state_dict(self, *args, **kwargs):
        logger.debug("Called NoOpScheduler.load_state_dict")
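

# main() wires everything together: parse the three argument groups, build the model,
# tokenizer and masked-LM data collator, join the DHT, wrap the local LAMB optimizer in
# hivemind.CollaborativeOptimizer, and hand the result to a (lightly subclassed) Trainer.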
def main():
    parser = HfArgumentParser((AlbertTrainingArguments, DatasetArguments, CollaborationArguments))
    training_args, dataset_args, collaboration_args = parser.parse_args_into_dataclasses()

    logger.info(f"Found {len(collaboration_args.initial_peers)} initial peers: {collaboration_args.initial_peers}")
    if len(collaboration_args.initial_peers) == 0:
        raise ValueError("Please specify at least one network endpoint in initial peers.")

    collaboration_args_dict = asdict(collaboration_args)
    setup_logging(training_args)

    # Set seed before initializing the model
    set_seed(training_args.seed)

    config = AlbertConfig.from_pretrained(dataset_args.config_path, cache_dir=dataset_args.cache_dir)
    tokenizer = AlbertTokenizerFast.from_pretrained(dataset_args.tokenizer_path, cache_dir=dataset_args.cache_dir)
    model = get_model(training_args, config, tokenizer)
    model.to(training_args.device)

    tokenized_datasets = load_from_disk(Path(dataset_args.dataset_path))
    # This data collator will take care of randomly masking the tokens
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer)

    opt, scheduler = get_optimizer_and_scheduler(training_args, model)
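
    # The DHT is the peer-to-peer key-value store used for peer discovery, progress
    # tracking and group formation; in client_mode this peer makes only outbound connections.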
    dht = hivemind.DHT(
        initial_peers=collaboration_args_dict.pop('initial_peers'),
        listen=not collaboration_args_dict['client_mode'],
        listen_on=collaboration_args_dict.pop('dht_listen_on'),
        endpoint=collaboration_args_dict.pop('endpoint'),
        start=True,
    )

    total_batch_size_per_step = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
    trainer_uuid = collaboration_args_dict.pop('trainer_uuid')
    statistics_expiration = collaboration_args_dict.pop('statistics_expiration')
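
    # batch_size_lead lets peers start forming an averaging group slightly before the full
    # target batch is accumulated, hiding matchmaking latency behind the last few batches.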
    adjusted_target_batch_size = (collaboration_args_dict.pop('target_batch_size')
                                  - collaboration_args_dict.pop('batch_size_lead'))

    collaborative_optimizer = hivemind.CollaborativeOptimizer(
        opt=opt, dht=dht, scheduler=scheduler, prefix=collaboration_args_dict.pop('experiment_prefix'),
        compression_type=hivemind.utils.CompressionType.Value(collaboration_args_dict.pop('compression')),
        batch_size_per_step=total_batch_size_per_step, throughput=collaboration_args_dict.pop('bandwidth'),
        target_batch_size=adjusted_target_batch_size, client_mode=collaboration_args_dict.pop('client_mode'),
        verbose=True, start=True, **collaboration_args_dict
    )

    class TrainerWithIndependentShuffling(Trainer):
        def get_train_dataloader(self) -> DataLoader:
            """Shuffle data independently for each peer to avoid duplicating batches [important for quality]."""
            torch.manual_seed(hash(trainer_uuid))
            return super().get_train_dataloader()

    trainer = TrainerWithIndependentShuffling(
        model=model, args=training_args, tokenizer=tokenizer, data_collator=data_collator,
        train_dataset=tokenized_datasets["train"] if training_args.do_train else None,
        eval_dataset=tokenized_datasets["validation"] if training_args.do_eval else None,
        optimizers=(collaborative_optimizer, NoOpScheduler(collaborative_optimizer)),
        callbacks=[CollaborativeCallback(dht, collaborative_optimizer, model, trainer_uuid, statistics_expiration)],
    )
    trainer.remove_callback(transformers.trainer_callback.PrinterCallback)
    trainer.remove_callback(transformers.trainer_callback.ProgressCallback)

    # Training
    if training_args.do_train:
        latest_checkpoint_dir = max(
            Path(training_args.output_dir).glob('checkpoint*'),
            default=None,
            key=os.path.getctime,
        )
        trainer.train(model_path=latest_checkpoint_dir)


if __name__ == "__main__":
    main()
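

# Example invocation (hypothetical peer endpoint; adjust paths and flags to your setup):
#   python run_trainer.py --experiment_prefix my_albert_experiment \
#       --initial_peers 203.0.113.5:31337 --output_dir ./outputs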