|
@@ -15,7 +15,7 @@ class HFTrainerArguments(TrainingArguments):
|
|
|
text_seq_length: int = 256
|
|
|
|
|
|
# DALLE-specific params
|
|
|
- learning_rate: float = 0.003535
|
|
|
+ learning_rate: float = 0.0025
|
|
|
adam_beta1: float = 0.9
|
|
|
adam_beta2: float = 0.96
|
|
|
max_grad_norm: float = 4.0
|
|
@@ -60,13 +60,13 @@ class TPUTrainerArguments(HFTrainerArguments):
|
|
|
class CollaborativeArguments:
|
|
|
"""Configuration for CollaborativeOptimzier and its internals"""
|
|
|
target_batch_size: int = field(
|
|
|
- default=16384,
|
|
|
+ default=4096,
|
|
|
metadata={"help": "Perform optimizer step after all peers collectively accumulate this many samples"},
|
|
|
)
|
|
|
matchmaking_time: float = field(
|
|
|
default=30.0, metadata={"help": "Averaging group will wait for stragglers for at most this many seconds"}
|
|
|
)
|
|
|
- allreduce_timeout_timeout: float = field(
|
|
|
+ allreduce_timeout: float = field(
|
|
|
default=60, metadata={"help": "Give up on a given all-reduce round after this many seconds"}
|
|
|
)
|
|
|
averaging_timeout: float = field(
|