arguments.py

from dataclasses import dataclass, field
from typing import List, Optional

import torch
from transformers import TrainingArguments


@dataclass
class HFTrainerArguments(TrainingArguments):
    """Arguments for huggingface/transformers.Trainer"""

    dataloader_num_workers: int = 1
    per_device_train_batch_size: int = 2
    per_device_eval_batch_size: int = 2
    gradient_accumulation_steps: int = 1
    text_seq_length: int = 256

    # DALLE-specific params
    learning_rate: float = 0.0025
    adam_beta1: float = 0.9
    adam_beta2: float = 0.96
    max_grad_norm: float = 4.0
    weight_decay: float = 0.045
    total_steps: int = 31250  # total number of collaborative SGD updates, used for learning rate schedule
    warmup_steps: int = 3125
    adam_epsilon: float = 1e-6
    clamp_value: float = 10000.0

    fp16: bool = False
    do_train: bool = True

    logging_steps: int = 100
    max_steps: int = 10 ** 20
    save_steps: int = 10 ** 20
    save_total_limit: int = 2

    output_dir: str = "outputs"

    @property
    def batch_size_per_step(self):
        """Compute the number of training sequences contributed by each .step() from this peer"""
        total_batch_size_per_step = self.per_device_train_batch_size * self.gradient_accumulation_steps
        if torch.cuda.device_count() > 0:
            total_batch_size_per_step *= torch.cuda.device_count()
        return total_batch_size_per_step
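
    # For example, with the defaults above on a machine with 2 CUDA devices, each
    # .step() contributes per_device_train_batch_size * gradient_accumulation_steps
    # * device_count = 2 * 1 * 2 = 4 training sequences.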


@dataclass
class TPUTrainerArguments(HFTrainerArguments):
    num_tpus: int = 8  # the total number of TPU cores in use
    wandb_project: str = "huggingface"

    @property
    def batch_size_per_step(self):
        """Compute the number of training sequences contributed by each .step() from this peer"""
        return self.per_device_train_batch_size * self.gradient_accumulation_steps * self.num_tpus


@dataclass
class CollaborativeArguments:
    """Configuration for CollaborativeOptimizer and its internals"""

    target_batch_size: int = field(
        default=4096,
        metadata={"help": "Perform optimizer step after all peers collectively accumulate this many samples"},
    )
    matchmaking_time: float = field(
        default=15.0, metadata={"help": "Averaging group will wait for stragglers for at most this many seconds"}
    )
    allreduce_timeout: float = field(
        default=60, metadata={"help": "Give up on a given all-reduce round after this many seconds"}
    )
    averaging_timeout: float = field(
        default=180, metadata={"help": "Give up on averaging step after this many seconds"}
    )
    reuse_grad_buffers: bool = field(
        default=True,
        metadata={
            "help": "Whether or not to use the model's .grad buffers for accumulating gradients across local steps. "
            "This optimization reduces GPU memory consumption but may result in incorrect gradients when using "
            "some advanced techniques (e.g. applying a custom loss scaler)"
        },
    )


@dataclass
class BasePeerArguments:
    """Base arguments that are used both for trainers and for auxiliary peers such as the training monitor"""

    experiment_prefix: str = field(
        default="my-model", metadata={"help": "A unique experiment name, used as a prefix for all DHT keys"}
    )
    tokenizer_path: Optional[str] = field(default="t5-small", metadata={"help": "Path to the tokenizer"})
    cache_dir: Optional[str] = field(default="./cache", metadata={"help": "Path to the cache"})
    authorize: bool = field(default=True, metadata={"help": "Whether or not to use the HF authorizer"})
    client_mode: bool = field(
        default=False,
        metadata={"help": "If True, runs training without incoming connections, in a firewall-compatible mode"},
    )
    initial_peers: List[str] = field(
        default_factory=list,
        metadata={
            "help": "Multiaddrs of the peers that will welcome you into the existing collaboration. "
            "Example: /ip4/203.0.113.1/tcp/31337/p2p/XXXX /ip4/203.0.113.2/udp/7777/quic/p2p/YYYY"
        },
    )
    use_ipfs: bool = field(
        default=False,
        metadata={
            "help": "Use IPFS to find initial_peers. If enabled, you only need to provide the /p2p/XXXX part of the "
            "multiaddrs for the initial_peers (no need to specify a particular IPv4/IPv6 address and port)"
        },
    )
    host_maddrs: List[str] = field(
        default_factory=lambda: ["/ip4/0.0.0.0/tcp/0"],
        metadata={
            "help": "Multiaddrs to listen for external connections from other p2p instances. "
            "Defaults to all IPv4 interfaces with the TCP protocol: /ip4/0.0.0.0/tcp/0"
        },
    )
    announce_maddrs: List[str] = field(
        default_factory=list,
        metadata={"help": "Visible multiaddrs the host announces for external connections from other p2p instances"},
    )
    identity_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "Path to a pre-generated private key file. If defined, makes the peer ID deterministic. "
            "May be generated using ``./p2p-keygen`` from ``go-libp2p-daemon``."
        },
    )


@dataclass
class TrainingPeerArguments(BasePeerArguments):
    statistics_expiration: float = field(
        default=600, metadata={"help": "Statistics will be removed if not updated in this many seconds"}
    )
    backup_every_steps: Optional[int] = field(
        default=None,
        metadata={
            "help": "Update training state backup on disk once in this many global steps "
            "(default = do not update local state)"
        },
    )
    state_path: str = field(
        default="state.zip",
        metadata={"help": "Load this state upon init and when recovering from NaN parameters"},
    )


@dataclass
class AuxiliaryPeerArguments(BasePeerArguments):
    """
    Arguments for run_aux_peer.py, which is responsible for connecting peers to one another, tracking
    learning curves, assisting in all-reduce and uploading checkpoints to the hub
    """

    refresh_period: float = field(
        default=10, metadata={"help": "Period (in seconds) for fetching the keys from DHT"}
    )
    wandb_project: Optional[str] = field(
        default=None, metadata={"help": "Name of the Weights & Biases project to report the training progress to"}
    )
    save_checkpoint_step_interval: int = field(
        default=2, metadata={"help": "Frequency (in steps) of fetching and saving state from peers"}
    )
    repo_url: Optional[str] = field(
        default=None,
        metadata={"help": "URL of the Hugging Face Hub repository to upload the model and optimizer states"},
    )
    local_path: Optional[str] = field(
        default="Repo", metadata={"help": "Path to the local repository to store the model and optimizer states"}
    )
    upload_interval: Optional[float] = field(
        default=None, metadata={"help": "Frequency (in seconds) of uploading the model to the Hub"}
    )
    store_checkpoints: bool = field(default=True, metadata={"help": "If True, enables CheckpointHandler"})
    assist_in_averaging: bool = field(
        default=False, metadata={"help": "If True, this peer will facilitate averaging for other (training) peers"}
    )
    assist_refresh: float = field(
        default=1.0, metadata={"help": "Period (in seconds) for trying to assist averaging"}
    )
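

# Usage sketch (an illustration; the entry-point name and exact argument combination
# below are assumptions, not taken from this file). Dataclasses whose fields carry
# metadata={"help": ...}, like the ones above, can be parsed from the command line
# with transformers.HfArgumentParser, e.g. in a run_trainer.py-style script:
#
#     from transformers import HfArgumentParser
#
#     parser = HfArgumentParser((TrainingPeerArguments, HFTrainerArguments, CollaborativeArguments))
#     peer_args, trainer_args, collab_args = parser.parse_args_into_dataclasses()
#     print(trainer_args.batch_size_per_step, collab_args.target_batch_size)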