# arguments.py

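"""Argument dataclasses for a collaborative (decentralized) ALBERT training run.

Each dataclass below groups related hyperparameters; every field's `help` metadata
doubles as its command-line description when the classes are parsed with an
argument parser such as transformers.HfArgumentParser (see the usage sketch at
the end of this file).
"""
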
from dataclasses import dataclass, field
from typing import Optional, List

from transformers import TrainingArguments


@dataclass
class BaseTrainingArguments:
    experiment_prefix: str = field(
        metadata={"help": "A unique 'name' of this experiment, used to store metadata on the DHT"}
    )
    initial_peers: List[str] = field(
        default_factory=list,
        metadata={"help":
            "Multiaddrs of the peers that will welcome you into the existing collaboration. "
            "Example: /ip4/203.0.113.1/tcp/31337/p2p/XXXX /ip4/203.0.113.2/udp/7777/quic/p2p/YYYY"}
    )
    use_ipfs: bool = field(
        default=False,
        metadata={"help":
            "Use IPFS to find initial_peers. If enabled, you only need to provide the /p2p/XXXX part of the "
            "multiaddrs for the initial_peers (no need to specify a particular IPv4/IPv6 host and port)"}
    )
    host_maddrs: List[str] = field(
        default_factory=lambda: ['/ip4/0.0.0.0/tcp/0', '/ip4/0.0.0.0/udp/0/quic'],
        metadata={"help":
            "Multiaddrs to listen on for external connections from other p2p instances. "
            "Defaults to all IPv4 interfaces with TCP and QUIC (over UDP) protocols: "
            "/ip4/0.0.0.0/tcp/0 /ip4/0.0.0.0/udp/0/quic"}
    )
    announce_maddrs: List[str] = field(
        default_factory=list,
        metadata={"help": "Visible multiaddrs the host announces for external connections from other p2p instances"}
    )


@dataclass
class AveragerArguments:
    averaging_expiration: float = field(
        default=5.0,
        metadata={"help": "Averaging group will wait for stragglers for at most this many seconds"}
    )
    averaging_timeout: float = field(
        default=30.0,
        metadata={"help": "Give up on averaging step after this many seconds"}
    )
    listen_on: str = field(
        default="[::]:*",
        metadata={"help": "Network interface used for incoming averager communication. Default: all IPv6"}
    )
    min_refresh_period: float = field(
        default=0.5,
        metadata={"help": "Wait for at least this many seconds before fetching new collaboration state"}
    )
    max_refresh_period: float = field(
        default=30,
        metadata={"help": "Wait for at most this many seconds before fetching new collaboration state"}
    )
    default_refresh_period: float = field(
        default=3,
        metadata={"help": "Attempt to fetch collaboration state this often (in seconds) until successful"}
    )
    expected_drift_peers: float = field(
        default=3,
        metadata={"help": "Trainer assumes that this many new peers can join per step"}
    )
    expected_drift_rate: float = field(
        default=0.2,
        metadata={"help": "Trainer assumes that this fraction of current size can join per step"}
    )
    performance_ema_alpha: float = field(
        default=0.1,
        metadata={"help": "Alpha coefficient for the exponential moving average estimate of samples per second"}
    )
    target_group_size: int = field(
        default=256,
        metadata={"help": "Maximum group size for all-reduce"}
    )
    metadata_expiration: float = field(
        default=30,
        metadata={"help": "Peer's metadata will be removed if not updated in this many seconds"}
    )


@dataclass
class CollaborativeOptimizerArguments:
    target_batch_size: int = field(
        default=4096,
        metadata={"help": "Perform optimizer step after all peers collectively accumulate this many samples"}
    )
    client_mode: bool = field(
        default=False,
        metadata={"help": "If True, runs training without incoming connections, in a firewall-compatible mode"}
    )
    batch_size_lead: int = field(
        default=0,
        metadata={"help": "Optional: begin looking for a group this many samples before reaching target_batch_size"}
    )
    bandwidth: float = field(
        default=100.0,
        metadata={"help": "Available network bandwidth, in Mbps (used for load balancing in all-reduce)"}
    )
    compression: str = field(
        default="FLOAT16",
        metadata={"help": "Use this compression when averaging parameters/gradients"}
    )


@dataclass
class CollaborationArguments(AveragerArguments, CollaborativeOptimizerArguments, BaseTrainingArguments):
    statistics_expiration: float = field(
        default=600,
        metadata={"help": "Statistics will be removed if not updated in this many seconds"}
    )
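
# CollaborationArguments combines the three argument groups above via multiple
# inheritance. Dataclass fields are collected in reverse MRO order, so the only
# required field, BaseTrainingArguments.experiment_prefix, ends up first, followed
# by all defaulted fields; this ordering is what keeps the combined dataclass valid.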


@dataclass
class DatasetArguments:
    dataset_path: Optional[str] = field(
        default='data/albert_tokenized_wikitext',
        metadata={"help": "Path to the tokenized dataset"}
    )
    tokenizer_path: Optional[str] = field(
        default='data/tokenizer',
        metadata={"help": "Path to the tokenizer"}
    )
    config_path: Optional[str] = field(
        default='https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-config.json',
        metadata={"help": "Path to the model config"}
    )
    cache_dir: Optional[str] = field(
        default='data',
        metadata={"help": "Path to the cache"}
    )


@dataclass
class AlbertTrainingArguments(TrainingArguments):
    dataloader_num_workers: int = 4
    per_device_train_batch_size: int = 4
    per_device_eval_batch_size: int = 4
    gradient_accumulation_steps: int = 2
    seq_length: int = 512
    max_steps: int = 125_000  # please note: this affects both number of steps and learning rate schedule
    learning_rate: float = 0.00176
    warmup_steps: int = 5000
    adam_epsilon: float = 1e-6
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0
    clamp_value: float = 10000.0
    fp16: bool = True
    fp16_opt_level: str = 'O2'
    do_train: bool = True
    logging_steps: int = 100
    save_total_limit: int = 2
    save_steps: int = 500
    output_dir: str = 'outputs'
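

# --- Usage sketch (illustrative; not part of the original file) ---
# These dataclasses follow the transformers convention of exposing each field's
# `help` metadata as a command-line flag description. A training script would
# typically combine them with HfArgumentParser, roughly like this:
if __name__ == "__main__":
    from transformers import HfArgumentParser

    # Each dataclass becomes one group of CLI flags; required fields (such as
    # --experiment_prefix) must be supplied on the command line.
    parser = HfArgumentParser((AlbertTrainingArguments, DatasetArguments, CollaborationArguments))
    training_args, dataset_args, collaboration_args = parser.parse_args_into_dataclasses()
    # e.g. `python arguments.py --experiment_prefix my-albert-run` prints the prefix
    print(collaboration_args.experiment_prefix)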