run_trainer.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336
  1. #!/usr/bin/env python
  2. import logging
  3. import os
  4. import pickle
  5. import threading
  6. import time
  7. from dataclasses import asdict
  8. from pathlib import Path
  9. import psutil
  10. import torch
  11. import transformers
  12. from datasets import load_from_disk
  13. from torch.utils.data import DataLoader
  14. from torch_optimizer import Lamb
  15. from transformers import DataCollatorForLanguageModeling, HfArgumentParser, TrainingArguments, set_seed
  16. from transformers.models.albert import AlbertConfig, AlbertForPreTraining, AlbertTokenizerFast
  17. from transformers.optimization import get_linear_schedule_with_warmup
  18. from transformers.trainer import Trainer
  19. from transformers.trainer_utils import is_main_process
  20. import hivemind
  21. from hivemind.utils.compression import CompressionType
  22. import utils
  23. from arguments import AlbertTrainingArguments, AveragerArguments, CollaborationArguments, DatasetArguments
  24. logger = logging.getLogger(__name__)
  25. LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
  26. def analyze_openfiles_periodically():
  27. while True:
  28. logger.info(f"Scanning open files for process {psutil.Process().pid}")
  29. children = [psutil.Process()] + psutil.Process().children(recursive=True)
  30. for child in children:
  31. open_files = child.open_files()
  32. logger.info(f"proc: '{child.name()}' pid={child.pid} parent={child.parent().pid} files: {len(open_files)}")
  33. for child in children:
  34. open_files = child.open_files()
  35. if len(open_files) > 100:
  36. logger.info(f"proc: {child.name()} has {len(open_files)} open files: {repr(open_files)}")
  37. logger.info("DONE scanning")
  38. time.sleep(30)
  39. analyzer = threading.Thread(target=analyze_openfiles_periodically)
  40. analyzer.start()
  41. def setup_logging(training_args):
  42. logging.basicConfig(
  43. format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
  44. datefmt="%m/%d/%Y %H:%M:%S",
  45. level=logging.INFO if is_main_process(training_args.local_rank) else logging.WARN,
  46. )
  47. # Log on each process the small summary:
  48. logger.warning(
  49. f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
  50. + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
  51. )
  52. # Set the verbosity to info of the Transformers logger (on main process only):
  53. if is_main_process(training_args.local_rank):
  54. transformers.utils.logging.set_verbosity_info()
  55. transformers.utils.logging.enable_default_handler()
  56. transformers.utils.logging.enable_explicit_format()
  57. logger.info("Training/evaluation parameters %s", training_args)
  58. def get_model(training_args, config, tokenizer):
  59. # Find latest checkpoint in output_dir
  60. output_dir = Path(training_args.output_dir)
  61. logger.info(f'Checkpoint dir {output_dir}, contents {list(output_dir.glob("checkpoint*"))}')
  62. latest_checkpoint_dir = max(output_dir.glob("checkpoint*"), default=None, key=os.path.getctime)
  63. if latest_checkpoint_dir is not None:
  64. logger.info(f"Loading model from {latest_checkpoint_dir}")
  65. model = AlbertForPreTraining.from_pretrained(latest_checkpoint_dir)
  66. else:
  67. logger.info(f"Training from scratch")
  68. model = AlbertForPreTraining(config)
  69. model.resize_token_embeddings(len(tokenizer))
  70. return model
  71. def get_optimizer_and_scheduler(training_args, model):
  72. no_decay = ["bias", "LayerNorm.weight"]
  73. optimizer_grouped_parameters = [
  74. {
  75. "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
  76. "weight_decay": training_args.weight_decay,
  77. },
  78. {
  79. "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
  80. "weight_decay": 0.0,
  81. },
  82. ]
  83. opt = Lamb(
  84. optimizer_grouped_parameters,
  85. lr=training_args.learning_rate,
  86. betas=(training_args.adam_beta1, training_args.adam_beta2),
  87. eps=training_args.adam_epsilon,
  88. weight_decay=training_args.weight_decay,
  89. clamp_value=training_args.clamp_value,
  90. debias=True,
  91. )
  92. scheduler = get_linear_schedule_with_warmup(
  93. opt, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.max_steps
  94. )
  95. return opt, scheduler
  96. class CollaborativeCallback(transformers.TrainerCallback):
  97. """
  98. This callback monitors and reports collaborative training progress.
  99. In case of a catastrophic failure, it can also revert training to a backup.
  100. """
  101. def __init__(
  102. self,
  103. dht: hivemind.DHT,
  104. optimizer: hivemind.CollaborativeOptimizer,
  105. model: torch.nn.Module,
  106. local_public_key: bytes,
  107. statistics_expiration: float,
  108. backup_every_steps: int,
  109. ):
  110. super().__init__()
  111. self.model = model
  112. self.dht, self.collaborative_optimizer = dht, optimizer
  113. self.local_public_key = local_public_key
  114. self.statistics_expiration = statistics_expiration
  115. self.last_reported_collaboration_step = -1
  116. self.samples = 0
  117. self.steps = 0
  118. self.loss = 0
  119. self.total_samples_processed = 0
  120. self.backup_every_steps = backup_every_steps
  121. self.latest_backup = self.backup_state()
  122. def on_train_begin(
  123. self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs
  124. ):
  125. logger.info("Loading state from peers")
  126. self.collaborative_optimizer.load_state_from_peers()
  127. def on_step_end(
  128. self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs
  129. ):
  130. control.should_log = True
  131. if not self.params_are_finite():
  132. self.restore_from_backup(self.latest_backup)
  133. return control
  134. if state.log_history:
  135. self.loss += state.log_history[-1]["loss"]
  136. self.steps += 1
  137. if self.collaborative_optimizer.local_step != self.last_reported_collaboration_step:
  138. self.last_reported_collaboration_step = self.collaborative_optimizer.local_step
  139. self.total_samples_processed += self.samples
  140. samples_per_second = self.collaborative_optimizer.performance_ema.samples_per_second
  141. statistics = utils.LocalMetrics(
  142. step=self.collaborative_optimizer.local_step,
  143. samples_per_second=samples_per_second,
  144. samples_accumulated=self.samples,
  145. loss=self.loss,
  146. mini_steps=self.steps,
  147. )
  148. logger.info(f"Step {self.collaborative_optimizer.local_step}")
  149. logger.info(f"Your current contribution: {self.total_samples_processed} samples")
  150. logger.info(f"Performance: {samples_per_second} samples per second.")
  151. if self.steps:
  152. logger.info(f"Local loss: {self.loss / self.steps}")
  153. if self.collaborative_optimizer.local_step % self.backup_every_steps == 0:
  154. self.latest_backup = self.backup_state()
  155. self.loss = 0
  156. self.steps = 0
  157. if self.collaborative_optimizer.is_synchronized:
  158. self.dht.store(
  159. key=self.collaborative_optimizer.prefix + "_metrics",
  160. subkey=self.local_public_key,
  161. value=statistics.dict(),
  162. expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
  163. return_future=True,
  164. )
  165. self.samples = self.collaborative_optimizer.local_samples_accumulated
  166. return control
  167. @torch.no_grad()
  168. def params_are_finite(self):
  169. for param in self.model.parameters():
  170. if not torch.all(torch.isfinite(param)):
  171. return False
  172. return True
  173. @torch.no_grad()
  174. def backup_state(self) -> bytes:
  175. return pickle.dumps(
  176. {"model": self.model.state_dict(), "optimizer": self.collaborative_optimizer.opt.state_dict()}
  177. )
  178. @torch.no_grad()
  179. def restore_from_backup(self, backup: bytes):
  180. state = pickle.loads(backup)
  181. self.model.load_state_dict(state["model"])
  182. self.collaborative_optimizer.opt.load_state_dict(state["optimizer"])
  183. class NoOpScheduler(LRSchedulerBase):
  184. """Dummy scheduler for transformers.Trainer. The real scheduler is defined in CollaborativeOptimizer.scheduler"""
  185. def get_lr(self):
  186. return [group["lr"] for group in self.optimizer.param_groups]
  187. def print_lr(self, *args, **kwargs):
  188. if self.optimizer.scheduler:
  189. return self.optimizer.scheduler.print_lr(*args, **kwargs)
  190. def step(self):
  191. logger.debug("Called NoOpScheduler.step")
  192. self._last_lr = self.get_lr()
  193. def state_dict(self):
  194. return {}
  195. def load_state_dict(self, *args, **kwargs):
  196. logger.debug("Called NoOpScheduler.load_state_dict")
  197. def main():
  198. parser = HfArgumentParser((AlbertTrainingArguments, DatasetArguments, CollaborationArguments, AveragerArguments))
  199. training_args, dataset_args, collaboration_args, averager_args = parser.parse_args_into_dataclasses()
  200. logger.info(f"Found {len(collaboration_args.initial_peers)} initial peers: {collaboration_args.initial_peers}")
  201. if len(collaboration_args.initial_peers) == 0:
  202. raise ValueError("Please specify at least one network endpoint in initial peers.")
  203. setup_logging(training_args)
  204. # Set seed before initializing model.
  205. set_seed(training_args.seed)
  206. config = AlbertConfig.from_pretrained(dataset_args.config_path, cache_dir=dataset_args.cache_dir)
  207. tokenizer = AlbertTokenizerFast.from_pretrained(dataset_args.tokenizer_path, cache_dir=dataset_args.cache_dir)
  208. model = get_model(training_args, config, tokenizer)
  209. model.to(training_args.device)
  210. tokenized_datasets = load_from_disk(Path(dataset_args.dataset_path))
  211. # This data collator will take care of randomly masking the tokens.
  212. data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer)
  213. opt, scheduler = get_optimizer_and_scheduler(training_args, model)
  214. validators, local_public_key = utils.make_validators(collaboration_args.experiment_prefix)
  215. dht = hivemind.DHT(
  216. start=True,
  217. initial_peers=collaboration_args.initial_peers,
  218. client_mode=collaboration_args.client_mode,
  219. record_validators=validators,
  220. use_ipfs=collaboration_args.use_ipfs,
  221. host_maddrs=collaboration_args.host_maddrs,
  222. announce_maddrs=collaboration_args.announce_maddrs,
  223. identity_path=collaboration_args.identity_path,
  224. )
  225. utils.log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=collaboration_args.use_ipfs)
  226. total_batch_size_per_step = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
  227. if torch.cuda.device_count() != 0:
  228. total_batch_size_per_step *= torch.cuda.device_count()
  229. adjusted_target_batch_size = collaboration_args.target_batch_size - collaboration_args.batch_size_lead
  230. collaborative_optimizer = hivemind.CollaborativeOptimizer(
  231. opt=opt,
  232. dht=dht,
  233. scheduler=scheduler,
  234. prefix=collaboration_args.experiment_prefix,
  235. compression_type=CompressionType.Value(collaboration_args.compression),
  236. batch_size_per_step=total_batch_size_per_step,
  237. bandwidth=collaboration_args.bandwidth,
  238. target_batch_size=adjusted_target_batch_size,
  239. client_mode=collaboration_args.client_mode,
  240. verbose=True,
  241. start=True,
  242. **asdict(averager_args),
  243. )
  244. class TrainerWithIndependentShuffling(Trainer):
  245. def get_train_dataloader(self) -> DataLoader:
  246. """Shuffle data independently for each peer to avoid duplicating batches [important for quality]"""
  247. torch.manual_seed(hash(local_public_key))
  248. return super().get_train_dataloader()
  249. trainer = TrainerWithIndependentShuffling(
  250. model=model,
  251. args=training_args,
  252. tokenizer=tokenizer,
  253. data_collator=data_collator,
  254. train_dataset=tokenized_datasets["train"] if training_args.do_train else None,
  255. eval_dataset=tokenized_datasets["validation"] if training_args.do_eval else None,
  256. optimizers=(collaborative_optimizer, NoOpScheduler(collaborative_optimizer)),
  257. callbacks=[
  258. CollaborativeCallback(
  259. dht,
  260. collaborative_optimizer,
  261. model,
  262. local_public_key,
  263. collaboration_args.statistics_expiration,
  264. collaboration_args.backup_every_steps,
  265. )
  266. ],
  267. )
  268. trainer.remove_callback(transformers.trainer_callback.PrinterCallback)
  269. trainer.remove_callback(transformers.trainer_callback.ProgressCallback)
  270. # Training
  271. if training_args.do_train:
  272. latest_checkpoint_dir = max(
  273. Path(training_args.output_dir).glob("checkpoint*"), default=None, key=os.path.getctime
  274. )
  275. trainer.train(model_path=latest_checkpoint_dir)
  276. if __name__ == "__main__":
  277. main()