arguments.py

from dataclasses import dataclass, field
from typing import List, Optional

import torch
from transformers import TrainingArguments


@dataclass
class HFTrainerArguments(TrainingArguments):
    """Arguments for huggingface/transformers.Trainer"""

    dataloader_num_workers: int = 1
    per_device_train_batch_size: int = 2
    per_device_eval_batch_size: int = 2
    gradient_accumulation_steps: int = 1
    text_seq_length: int = 256

    # DALLE-specific params
    learning_rate: float = 0.003535
    adam_beta1: float = 0.9
    adam_beta2: float = 0.96
    max_grad_norm: float = 4.0
    weight_decay: float = 0.045
    total_steps: int = 31250  # total number of collaborative SGD updates, used for learning rate schedule
    warmup_steps: int = 3125
    adam_epsilon: float = 1e-6
    clamp_value: float = 10000.0

    fp16: bool = False
    fp16_opt_level: str = "O2"
    do_train: bool = True

    logging_steps: int = 100
    max_steps: int = 10 ** 20
    save_steps: int = 10 ** 20
    save_total_limit: int = 2

    output_dir: str = "outputs"

    @property
    def batch_size_per_step(self):
        """Compute the number of training sequences contributed by each .step() from this peer"""
        total_batch_size_per_step = self.per_device_train_batch_size * self.gradient_accumulation_steps
        if torch.cuda.device_count() > 0:
            total_batch_size_per_step *= torch.cuda.device_count()
        return total_batch_size_per_step


@dataclass
class TPUTrainerArguments(HFTrainerArguments):
    num_tpus: int = 8  # the total number of TPU cores in use
    wandb_project: str = "huggingface"

    @property
    def batch_size_per_step(self):
        """Compute the number of training sequences contributed by each .step() from this peer"""
        return self.per_device_train_batch_size * self.gradient_accumulation_steps * self.num_tpus
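

# Illustrative note (not part of the original file): with the defaults above, a TPU peer
# contributes per_device_train_batch_size * gradient_accumulation_steps * num_tpus
# = 2 * 1 * 8 = 16 samples per .step(), whereas a single-GPU peer using HFTrainerArguments
# contributes 2 * 1 * 1 = 2 samples per .step().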


@dataclass
class CollaborativeArguments:
    """Configuration for CollaborativeOptimizer and its internals"""

    target_batch_size: int = field(
        default=16384,
        metadata={"help": "Perform optimizer step after all peers collectively accumulate this many samples"},
    )
    matchmaking_time: float = field(
        default=15.0, metadata={"help": "Averaging group will wait for stragglers for at most this many seconds"}
    )
    averaging_timeout: float = field(
        default=120, metadata={"help": "Give up on averaging step after this many seconds"}
    )
    reuse_grad_buffers: bool = field(
        default=True,
        metadata={
            "help": "Whether or not to use the model's .grad buffers for accumulating gradients across local steps. "
            "This optimization reduces GPU memory consumption but may result in incorrect gradients when using some "
            "advanced techniques (e.g. applying a custom loss scaler)"
        },
    )


@dataclass
class BasePeerArguments:
    """Base arguments that are used both for trainers and for auxiliary peers such as the training monitor"""

    experiment_prefix: str = field(
        default="my-model", metadata={"help": "A unique experiment name, used as a prefix for all DHT keys"}
    )
    tokenizer_path: Optional[str] = field(default="t5-small", metadata={"help": "Path to the tokenizer"})
    cache_dir: Optional[str] = field(default="./cache", metadata={"help": "Path to the cache"})
    authorize: bool = field(default=False, metadata={"help": "Whether or not to use the HF authorizer"})
    client_mode: bool = field(
        default=False,
        metadata={"help": "If True, runs training without incoming connections, in a firewall-compatible mode"},
    )
    initial_peers: List[str] = field(
        default_factory=list,
        metadata={
            "help": "Multiaddrs of the peers that will welcome you into the existing collaboration. "
            "Example: /ip4/203.0.113.1/tcp/31337/p2p/XXXX /ip4/203.0.113.2/udp/7777/quic/p2p/YYYY"
        },
    )
    use_ipfs: bool = field(
        default=False,
        metadata={
            "help": "Use IPFS to find initial_peers. If enabled, you only need to provide the /p2p/XXXX part of the "
            "multiaddrs for the initial_peers (no need to specify a particular IPv4/IPv6 address and port)"
        },
    )
    host_maddrs: List[str] = field(
        default_factory=lambda: ["/ip4/0.0.0.0/tcp/0"],
        metadata={
            "help": "Multiaddrs to listen for external connections from other p2p instances. "
            "Defaults to all IPv4 interfaces with the TCP protocol: /ip4/0.0.0.0/tcp/0"
        },
    )
    announce_maddrs: List[str] = field(
        default_factory=list,
        metadata={"help": "Visible multiaddrs the host announces for external connections from other p2p instances"},
    )
    identity_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "Path to a pre-generated private key file. If defined, makes the peer ID deterministic. "
            "May be generated using ``./p2p-keygen`` from ``go-libp2p-daemon``."
        },
    )


@dataclass
class TrainingPeerArguments(BasePeerArguments):
    statistics_expiration: float = field(
        default=600, metadata={"help": "Statistics will be removed if not updated in this many seconds"}
    )
    backup_every_steps: Optional[int] = field(
        default=None,
        metadata={
            "help": "Update the training state backup on disk once in this many global steps "
            "(default = do not update the local state)"
        },
    )
    state_path: str = field(
        default="state.zip", metadata={"help": "Load this state upon init and when recovering from NaN parameters"}
    )


@dataclass
class AuxiliaryPeerArguments(BasePeerArguments):
    """
    Arguments for run_aux_peer.py, which is responsible for connecting peers to one another, tracking
    learning curves, assisting in all-reduce and uploading checkpoints to the hub
    """

    refresh_period: float = field(
        default=10, metadata={"help": "Period (in seconds) for fetching the keys from the DHT"}
    )
    wandb_project: Optional[str] = field(
        default=None, metadata={"help": "Name of the Weights & Biases project to report the training progress to"}
    )
    save_checkpoint_step_interval: int = field(
        default=5, metadata={"help": "Frequency (in steps) of fetching and saving state from peers"}
    )
    repo_url: Optional[str] = field(
        default=None, metadata={"help": "URL of the Hugging Face Hub repository to upload the model and optimizer states"}
    )
    local_path: Optional[str] = field(
        default="Repo", metadata={"help": "Path to the local repository to store the model and optimizer states"}
    )
    upload_interval: Optional[float] = field(
        default=None, metadata={"help": "Frequency (in seconds) of uploading the model to the Hub"}
    )
    store_checkpoints: bool = field(default=False, metadata={"help": "If True, enables the CheckpointHandler"})
    assist_in_averaging: bool = field(
        default=False, metadata={"help": "If True, this peer will facilitate averaging for other (training) peers"}
    )
    assist_refresh: float = field(default=1.0, metadata={"help": "Period (in seconds) for trying to assist averaging"})
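

# Usage sketch (an assumption for illustration, not part of the original file): argument
# dataclasses like these are commonly parsed with transformers.HfArgumentParser; the
# entry-point guard and the particular combination of dataclasses below are hypothetical.
if __name__ == "__main__":
    from transformers import HfArgumentParser

    parser = HfArgumentParser((TrainingPeerArguments, HFTrainerArguments, CollaborativeArguments))
    peer_args, trainer_args, collab_args = parser.parse_args_into_dataclasses()

    # batch_size_per_step is the number of samples this peer contributes per local .step(),
    # while target_batch_size is the number all peers must accumulate collectively per update.
    print(
        f"This peer contributes {trainer_args.batch_size_per_step} samples per step "
        f"towards the collective target of {collab_args.target_batch_size}"
    )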