arguments.py

from dataclasses import dataclass, field
from typing import List, Optional

import torch
from transformers import TrainingArguments


@dataclass
class HFTrainerArguments(TrainingArguments):
    """Arguments for huggingface/transformers.Trainer"""

    dataloader_num_workers: int = 1
    per_device_train_batch_size: int = 1
    per_device_eval_batch_size: int = 1
    gradient_accumulation_steps: int = 1

    seq_length: int = 512
    pad_to_multiple_of: int = 8

    learning_rate: float = 0.0025
    total_steps: int = 31250  # total number of collaborative SGD updates, used for the learning rate schedule
    warmup_steps: int = 3125
    adam_epsilon: float = 1e-6
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0
    clamp_value: float = 10000.0

    fp16: bool = False
    fp16_opt_level: str = "O2"
    do_train: bool = True

    logging_steps: int = 100
    # intentionally huge: stopping and checkpointing are governed by the collaboration, not by transformers.Trainer
    max_steps: int = 10 ** 20
    save_steps: int = 10 ** 20
    save_total_limit: int = 2

    output_dir: str = "outputs"

    @property
    def batch_size_per_step(self):
        """Compute the number of training sequences contributed by each .step() from this peer"""
        total_batch_size_per_step = self.per_device_train_batch_size * self.gradient_accumulation_steps
        if torch.cuda.device_count() > 0:
            total_batch_size_per_step *= torch.cuda.device_count()
        return total_batch_size_per_step
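
# Example for batch_size_per_step (hypothetical values): with
# per_device_train_batch_size=4, gradient_accumulation_steps=8 and 2 visible
# CUDA devices, one .step() from this peer contributes 4 * 8 * 2 = 64 sequences.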


@dataclass
class TPUTrainerArguments(HFTrainerArguments):
    num_tpus: int = 8  # the total number of TPU cores in use
    wandb_project: str = "huggingface"

    @property
    def batch_size_per_step(self):
        """Compute the number of training sequences contributed by each .step() from this peer"""
        return self.per_device_train_batch_size * self.gradient_accumulation_steps * self.num_tpus
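
# Note: on TPU hosts torch.cuda.device_count() is typically 0, so the override
# above multiplies by num_tpus instead. E.g. (hypothetical values) with
# per_device_train_batch_size=2, gradient_accumulation_steps=4 and num_tpus=8,
# one .step() contributes 2 * 4 * 8 = 64 training sequences.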


@dataclass
class CollaborativeArguments:
    """Configuration for CollaborativeOptimizer and its internals"""

    target_batch_size: int = field(
        default=16384,
        metadata={"help": "Perform an optimizer step after all peers collectively accumulate this many samples"},
    )
    target_group_size: int = field(default=256, metadata={"help": "Maximum group size for all-reduce"})
    bandwidth: float = field(
        default=100.0,
        metadata={"help": "Available network bandwidth, in Mbps (used for load balancing in all-reduce)"},
    )
    averaging_expiration: float = field(
        default=15.0, metadata={"help": "Averaging group will wait for stragglers for at most this many seconds"}
    )
    averaging_timeout: float = field(
        default=120.0, metadata={"help": "Give up on an averaging step after this many seconds"}
    )
    min_refresh_period: float = field(
        default=0.5, metadata={"help": "Wait for at least this many seconds before fetching new collaboration state"}
    )
    max_refresh_period: float = field(
        default=30, metadata={"help": "Wait for at most this many seconds before fetching new collaboration state"}
    )
    default_refresh_period: float = field(
        default=3, metadata={"help": "Retry fetching collaboration state at this interval (in seconds) until successful"}
    )
    expected_drift_peers: float = field(
        default=3, metadata={"help": "Trainer assumes that this many new peers can join per step"}
    )
    expected_drift_rate: float = field(
        default=0.2, metadata={"help": "Trainer assumes that this fraction of the current size can join per step"}
    )
    performance_ema_alpha: float = field(
        default=0.1, metadata={"help": "Alpha for the exponential moving average estimate of samples per second"}
    )
    metadata_expiration: float = field(
        default=120, metadata={"help": "Peer's metadata will be removed if not updated in this many seconds"}
    )
    reuse_grad_buffers: bool = field(
        default=True,
        metadata={
            "help": "Whether or not to reuse the model's .grad buffers for accumulating gradients across local steps. "
            "This optimization reduces GPU memory consumption but may result in incorrect gradients when using some "
            "advanced techniques (e.g. applying a custom loss scaler)"
        },
    )


@dataclass
class BasePeerArguments:
    """Base arguments that are used both for trainers and for auxiliary peers such as the training monitor"""

    experiment_prefix: str = field(
        default="my-model", metadata={"help": "A unique experiment name, used as a prefix for all DHT keys"}
    )
    model_config_path: Optional[str] = field(default="./model.json", metadata={"help": "Path to the model config"})
    tokenizer_path: Optional[str] = field(default="./tokenizer", metadata={"help": "Path to the tokenizer"})
    cache_dir: Optional[str] = field(default="./cache", metadata={"help": "Path to the cache"})
    authorize: bool = field(default=False, metadata={"help": "Whether or not to use the HF authorizer"})
    client_mode: bool = field(
        default=False,
        metadata={"help": "If True, runs training without incoming connections, in a firewall-compatible mode"},
    )
    initial_peers: List[str] = field(
        default_factory=list,
        metadata={
            "help": "Multiaddrs of the peers that will welcome you into the existing collaboration. "
            "Example: /ip4/203.0.113.1/tcp/31337/p2p/XXXX /ip4/203.0.113.2/udp/7777/quic/p2p/YYYY"
        },
    )
    use_ipfs: bool = field(
        default=False,
        metadata={
            "help": "Use IPFS to find initial_peers. If enabled, you only need to provide the /p2p/XXXX part of the "
            "multiaddrs for initial_peers (no need to specify a particular IPv4/IPv6 address and port)"
        },
    )
    host_maddrs: List[str] = field(
        default_factory=lambda: ["/ip4/0.0.0.0/tcp/0"],
        metadata={
            "help": "Multiaddrs to listen for external connections from other p2p instances. "
            "Defaults to all IPv4 interfaces with the TCP protocol: /ip4/0.0.0.0/tcp/0"
        },
    )
    announce_maddrs: List[str] = field(
        default_factory=list,
        metadata={"help": "Visible multiaddrs the host announces for external connections from other p2p instances"},
    )
    identity_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "Path to a pre-generated private key file. If defined, makes the peer ID deterministic. "
            "May be generated using ``./p2p-keygen`` from ``go-libp2p-daemon``."
        },
    )
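
# Sketch (assumes the hivemind library; `peer_args` is a parsed BasePeerArguments
# instance, not defined in this file): the networking fields above map onto
# hivemind.DHT construction roughly as follows:
#
#   import hivemind
#
#   dht = hivemind.DHT(
#       start=True,
#       initial_peers=peer_args.initial_peers,
#       client_mode=peer_args.client_mode,
#       host_maddrs=peer_args.host_maddrs,
#       announce_maddrs=peer_args.announce_maddrs,
#       use_ipfs=peer_args.use_ipfs,
#       identity_path=peer_args.identity_path,
#   )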


@dataclass
class TrainingPeerArguments(BasePeerArguments):
    statistics_expiration: float = field(
        default=600, metadata={"help": "Statistics will be removed if not updated in this many seconds"}
    )
    backup_every_steps: Optional[int] = field(
        default=None,
        metadata={
            "help": "Update the on-disk training state backup once every this many global steps "
            "(default = do not update the local backup)"
        },
    )
    state_path: str = field(
        default="state.zip", metadata={"help": "Load this state upon init and when recovering from NaN parameters"}
    )


@dataclass
class AuxiliaryPeerArguments(BasePeerArguments):
    """
    Arguments for run_aux_peer.py, which is responsible for connecting peers to one another, tracking
    learning curves, assisting in all-reduce and uploading checkpoints to the hub
    """

    refresh_period: float = field(default=10, metadata={"help": "Period (in seconds) for fetching the keys from DHT"})
    wandb_project: Optional[str] = field(
        default=None, metadata={"help": "Name of the Weights & Biases project to report the training progress to"}
    )
    save_checkpoint_step_interval: int = field(
        default=5, metadata={"help": "Frequency (in steps) of fetching and saving state from peers"}
    )
    repo_url: Optional[str] = field(
        default=None, metadata={"help": "URL of the Hugging Face Hub repository to upload the model and optimizer states"}
    )
    local_path: Optional[str] = field(
        default="Repo", metadata={"help": "Path to the local repository to store the model and optimizer states"}
    )
    upload_interval: Optional[float] = field(
        default=None, metadata={"help": "Frequency (in seconds) of uploading the model to the Hub"}
    )
    store_checkpoints: bool = field(default=False, metadata={"help": "If True, enables CheckpointHandler"})
    assist_in_averaging: bool = field(
        default=False, metadata={"help": "If True, this peer will facilitate averaging for other (training) peers"}
    )
    assist_refresh: float = field(default=1.0, metadata={"help": "Period (in seconds) for trying to assist averaging"})
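
# Example invocation (illustrative; HfArgumentParser derives the flag names from
# the field names above):
#
#   python run_aux_peer.py --experiment_prefix my-model \
#       --initial_peers /ip4/203.0.113.1/tcp/31337/p2p/XXXX \
#       --wandb_project my-project --store_checkpoints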