arguments.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. from typing import Optional, List
  2. from dataclasses import dataclass, field
  3. from transformers import TrainingArguments
  4. @dataclass
  5. class BaseTrainingArguments:
  6. experiment_prefix: str = field(
  7. metadata={"help": "A unique 'name' of this experiment, used to store metadata on the DHT"}
  8. )
  9. initial_peers: List[str] = field(
  10. default_factory=list,
  11. metadata={"help": "One or more peers (comma-separated) that will welcome you into the collaboration"}
  12. )
  13. dht_listen_on: str = field(
  14. default="[::]:*",
  15. metadata={"help": "Network interface used for incoming DHT communication. Default: all ipv6"}
  16. )
  17. @dataclass
  18. class AveragerArguments:
  19. averaging_expiration: float = field(
  20. default=5.0,
  21. metadata={"help": "Averaging group will wait for stragglers for at most this many seconds"}
  22. )
  23. averaging_timeout: float = field(
  24. default=30.0,
  25. metadata={"help": "Give up on averaging step after this many seconds"}
  26. )
  27. listen_on: str = field(
  28. default="[::]:*",
  29. metadata={"help": "Network interface used for incoming averager communication. Default: all ipv6"}
  30. )
  31. min_refresh_period: float = field(
  32. default=0.5,
  33. metadata={"help": "Wait for at least this many seconds before fetching new collaboration state"}
  34. )
  35. max_refresh_period: float = field(
  36. default=30,
  37. metadata={"help": "Wait for at most this many seconds before fetching new collaboration state"}
  38. )
  39. default_refresh_period: float = field(
  40. default=3,
  41. metadata={"help": "Attempt to fetch collaboration state every this often until successful"}
  42. )
  43. expected_drift_peers: float = field(
  44. default=3,
  45. metadata={"help": "Trainer assumes that this many new peers can join per step"}
  46. )
  47. expected_drift_rate: float = field(
  48. default=0.2,
  49. metadata={"help": "Trainer assumes that this fraction of current size can join per step"}
  50. )
  51. performance_ema_alpha: float = field(
  52. default=0.1,
  53. metadata={"help": "Uses this alpha for moving average estimate of samples per second"}
  54. )
  55. target_group_size: int = field(
  56. default=256,
  57. metadata={"help": "Maximum group size for all-reduce"}
  58. )
  59. metadata_expiration: float = field(
  60. default=30,
  61. metadata={"help": "Peer's metadata will be removed if not updated in this many seconds"}
  62. )
  63. @dataclass
  64. class CollaborativeOptimizerArguments:
  65. target_batch_size: int = field(
  66. default=4096,
  67. metadata={"help": "Perform optimizer step after all peers collectively accumulate this many samples"}
  68. )
  69. client_mode: bool = field(
  70. default=False,
  71. metadata={"help": "Of True, runs training without incoming connections, in a firewall-compatible mode"}
  72. )
  73. batch_size_lead: int = field(
  74. default=0,
  75. metadata={"help": "Optional: begin looking for group in advance, this many samples before target_batch_size"}
  76. )
  77. bandwidth: float = field(
  78. default=100.0,
  79. metadata={"help": "Available network bandwidth, in mbps (used for load balancing in all-reduce)"}
  80. )
  81. compression: str = field(
  82. default="NONE",
  83. metadata={"help": "Use this compression when averaging parameters/gradients"}
  84. )
  85. @dataclass
  86. class CollaborationArguments(AveragerArguments, CollaborativeOptimizerArguments, BaseTrainingArguments):
  87. statistics_expiration: float = field(
  88. default=600,
  89. metadata={"help": "Statistics will be removed if not updated in this many seconds"}
  90. )
  91. endpoint: Optional[str] = field(
  92. default=None,
  93. metadata={"help": "This node's IP for inbound connections, used when running from behind a proxy"}
  94. )
  95. @dataclass
  96. class DatasetArguments:
  97. dataset_path: Optional[str] = field(
  98. default='data/albert_tokenized_wikitext',
  99. metadata={"help": "Path to the tokenized dataset"}
  100. )
  101. tokenizer_path: Optional[str] = field(
  102. default='data/tokenizer',
  103. metadata={"help": "Path to the tokenizer"}
  104. )
  105. config_path: Optional[str] = field(
  106. default='https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-config.json',
  107. metadata={"help": "Path to the model config"}
  108. )
  109. cache_dir: Optional[str] = field(
  110. default='data',
  111. metadata={"help": "Path to the cache"}
  112. )
  113. @dataclass
  114. class AlbertTrainingArguments(TrainingArguments):
  115. dataloader_num_workers: int = 4
  116. per_device_train_batch_size: int = 4
  117. per_device_eval_batch_size: int = 4
  118. gradient_accumulation_steps: int = 2
  119. seq_length: int = 512
  120. max_steps: int = 1_000_000 # Albert is actually ready after 125000 steps
  121. learning_rate: float = 0.00176
  122. warmup_steps: int = 5000
  123. adam_epsilon: float = 1e-6
  124. weight_decay: float = 0.01
  125. max_grad_norm: float = 1.0
  126. clamp_value: float = 10000.0
  127. fp16: bool = False
  128. fp16_opt_level: str = 'O2'
  129. do_train: bool = True
  130. logging_steps: int = 100
  131. save_total_limit: int = 2
  132. save_steps: int = 500
  133. output_dir: str = 'outputs'