arguments.py

from dataclasses import dataclass, field
from typing import List, Optional

from transformers import TrainingArguments


@dataclass
class BaseTrainingArguments:
    experiment_prefix: str = field(
        metadata={"help": "A unique 'name' of this experiment, used to store metadata on the DHT"}
    )
    initial_peers: List[str] = field(
        default_factory=list,
        metadata={
            "help": "Multiaddrs of the peers that will welcome you into the existing collaboration. "
            "Example: /ip4/203.0.113.1/tcp/31337/p2p/XXXX /ip4/203.0.113.2/udp/7777/quic/p2p/YYYY"
        },
    )
    use_ipfs: bool = field(
        default=False,
        metadata={
            "help": "Use IPFS to find initial_peers. If enabled, you only need to provide the /p2p/XXXX part of the "
            "multiaddrs for the initial_peers (no need to specify a particular IPv4/IPv6 host and port)"
        },
    )
    host_maddrs: List[str] = field(
        default_factory=lambda: ["/ip4/0.0.0.0/tcp/0", "/ip4/0.0.0.0/udp/0/quic"],
        metadata={
            "help": "Multiaddrs to listen for external connections from other p2p instances. "
            "Defaults to all IPv4 interfaces with TCP and QUIC (over UDP) protocols: "
            "/ip4/0.0.0.0/tcp/0 /ip4/0.0.0.0/udp/0/quic"
        },
    )
    announce_maddrs: List[str] = field(
        default_factory=list,
        metadata={"help": "Visible multiaddrs the host announces for external connections from other p2p instances"},
    )
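

# Hypothetical helper (not in the original script) sketching how the peer-discovery
# fields above are typically consumed. hivemind.DHT itself is real, but forwarding
# host_maddrs/announce_maddrs/use_ipfs through its constructor is an assumption
# about the hivemind version in use.
def make_dht(args: BaseTrainingArguments):
    import hivemind  # assumed to be installed alongside this script

    return hivemind.DHT(
        start=True,
        initial_peers=args.initial_peers,
        host_maddrs=args.host_maddrs,  # assumption: forwarded to the p2p backend
        announce_maddrs=args.announce_maddrs,
        use_ipfs=args.use_ipfs,
    )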


@dataclass
class AveragerArguments:
    averaging_expiration: float = field(
        default=5.0, metadata={"help": "Averaging group will wait for stragglers for at most this many seconds"}
    )
    averaging_timeout: float = field(
        default=30.0, metadata={"help": "Give up on averaging step after this many seconds"}
    )
    listen_on: str = field(
        default="[::]:*",
        metadata={"help": "Network interface used for incoming averager communication. Default: all ipv6"},
    )
    min_refresh_period: float = field(
        default=0.5, metadata={"help": "Wait for at least this many seconds before fetching new collaboration state"}
    )
    max_refresh_period: float = field(
        default=30, metadata={"help": "Wait for at most this many seconds before fetching new collaboration state"}
    )
    default_refresh_period: float = field(
        default=3, metadata={"help": "Attempt to fetch collaboration state every this many seconds until successful"}
    )
    expected_drift_peers: float = field(
        default=3, metadata={"help": "Trainer assumes that this many new peers can join per step"}
    )
    expected_drift_rate: float = field(
        default=0.2, metadata={"help": "Trainer assumes that this fraction of the current size can join per step"}
    )
    performance_ema_alpha: float = field(
        default=0.1, metadata={"help": "Alpha for the exponential moving average of samples per second"}
    )
    target_group_size: int = field(default=256, metadata={"help": "Maximum group size for all-reduce"})
    metadata_expiration: float = field(
        default=120, metadata={"help": "Peer's metadata will be removed if not updated in this many seconds"}
    )
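

# The three refresh-period fields above bound how often a peer polls the shared
# collaboration state (at least min_refresh_period seconds apart, at most
# max_refresh_period, attempting every default_refresh_period until the first
# success). A minimal sketch of the common consumption pattern: splat the whole
# dataclass into a downstream constructor. The consumer named in the comment is
# an assumption, not pinned by this file.
def averager_kwargs(args: AveragerArguments) -> dict:
    from dataclasses import asdict

    # e.g. SomeCollaborativeOptimizer(**averager_kwargs(args), ...) downstream
    return asdict(args)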


@dataclass
class CollaborativeOptimizerArguments:
    target_batch_size: int = field(
        default=4096,
        metadata={"help": "Perform optimizer step after all peers collectively accumulate this many samples"},
    )
    client_mode: bool = field(
        default=False,
        metadata={"help": "If True, runs training without incoming connections, in a firewall-compatible mode"},
    )
    batch_size_lead: int = field(
        default=0,
        metadata={"help": "Optional: begin looking for an averaging group this many samples before target_batch_size"},
    )
    bandwidth: float = field(
        default=100.0,
        metadata={"help": "Available network bandwidth, in Mbit/s (used for load balancing in all-reduce)"},
    )
    compression: str = field(
        default="FLOAT16", metadata={"help": "Use this compression when averaging parameters/gradients"}
    )
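

# `compression` is kept as a plain string so it can be set from the CLI; downstream
# code must resolve it to hivemind's compression enum. A minimal sketch: both the
# import path and the attribute-style lookup are assumptions (protobuf-backed enums
# also support CompressionType.Value(name)).
def resolve_compression(name: str):
    from hivemind.utils import CompressionType  # assumed import path

    return getattr(CompressionType, name)  # e.g. "FLOAT16" -> CompressionType.FLOAT16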


@dataclass
class CollaborationArguments(CollaborativeOptimizerArguments, BaseTrainingArguments):
    statistics_expiration: float = field(
        default=600, metadata={"help": "Statistics will be removed if not updated in this many seconds"}
    )
    backup_every_steps: int = field(
        default=10,
        metadata={"help": "Save a backup (to restore from if NaN values are encountered) every this many steps"},
    )
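

# Note on inheritance: dataclass fields are collected in reverse MRO order, so the
# required experiment_prefix from BaseTrainingArguments comes first in __init__,
# followed by the defaulted optimizer fields. For example (hypothetical values):
#
#     args = CollaborationArguments(experiment_prefix="albert-demo")
#     assert args.target_batch_size == 4096  # inherited default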


@dataclass
class DatasetArguments:
    dataset_path: Optional[str] = field(
        default="data/albert_tokenized_wikitext", metadata={"help": "Path to the tokenized dataset"}
    )
    tokenizer_path: Optional[str] = field(default="data/tokenizer", metadata={"help": "Path to the tokenizer"})
    config_path: Optional[str] = field(
        default="https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-config.json",
        metadata={"help": "Path to the model config"},
    )
    cache_dir: Optional[str] = field(default="data", metadata={"help": "Path to the cache"})
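

# Hypothetical helper (not in the original file) showing one way DatasetArguments
# is consumed. Using AlbertConfig follows from the ALBERT defaults above; whether
# from_pretrained accepts the raw S3 URL depends on the transformers version.
def load_model_config(args: DatasetArguments):
    from transformers import AlbertConfig

    return AlbertConfig.from_pretrained(args.config_path, cache_dir=args.cache_dir)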


@dataclass
class AlbertTrainingArguments(TrainingArguments):
    dataloader_num_workers: int = 4
    per_device_train_batch_size: int = 4
    per_device_eval_batch_size: int = 4
    gradient_accumulation_steps: int = 2
    seq_length: int = 512

    max_steps: int = 125_000  # please note: this affects both the number of steps and the learning rate schedule
    learning_rate: float = 0.00176
    warmup_steps: int = 5000
    adam_epsilon: float = 1e-6
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0
    clamp_value: float = 10000.0

    fp16: bool = True
    fp16_opt_level: str = "O2"
    do_train: bool = True

    logging_steps: int = 100
    save_total_limit: int = 2
    save_steps: int = 500

    output_dir: str = "outputs"
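

# Minimal usage sketch (assumed; in the original project the training entry point,
# not this module, does the parsing): transformers' HfArgumentParser reads all
# three argument groups from the command line in one pass.
if __name__ == "__main__":
    from transformers import HfArgumentParser

    parser = HfArgumentParser((AlbertTrainingArguments, DatasetArguments, CollaborationArguments))
    training_args, dataset_args, collaboration_args = parser.parse_args_into_dataclasses()
    print(f"Starting experiment {collaboration_args.experiment_prefix!r}")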