@@ -16,6 +16,8 @@ from transformers.models.albert import AlbertConfig, AlbertForPreTraining, AlbertTokenizerFast
 from transformers.optimization import get_linear_schedule_with_warmup
 from transformers.trainer import Trainer
 from transformers.trainer_utils import is_main_process
+from transformers.optimization import Adafactor, AdafactorSchedule
+

 import hivemind
 from hivemind.utils.compression import CompressionType

@@ -77,18 +79,13 @@ def get_optimizer_and_scheduler(training_args, model):
         },
     ]

-    opt = Lamb(
+    opt = Adafactor(
         optimizer_grouped_parameters,
-        lr=training_args.learning_rate,
-        betas=(training_args.adam_beta1, training_args.adam_beta2),
-        eps=training_args.adam_epsilon,
-        weight_decay=training_args.weight_decay,
-        clamp_value=training_args.clamp_value,
-        debias=True,
+        scale_parameter=True, relative_step=True, warmup_init=True, lr=None
     )

-    scheduler = get_linear_schedule_with_warmup(
-        opt, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.max_steps
+    scheduler = AdafactorSchedule(
+        opt
     )

     return opt, scheduler
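
For reference, the Adafactor configuration above delegates the learning-rate
schedule to the optimizer itself: with relative_step=True and lr=None,
Adafactor computes its own time-dependent step size (starting near zero when
warmup_init=True), and AdafactorSchedule is a proxy scheduler that simply
reports the rate Adafactor chose, so callers expecting an (optimizer,
scheduler) pair keep working. A minimal runnable sketch, using a toy linear
model as a stand-in for ALBERT (not part of the patch):

# Sketch only: toy model instead of the real ALBERT, same optimizer kwargs
# as in the hunk above.
import torch
from transformers.optimization import Adafactor, AdafactorSchedule

toy_model = torch.nn.Linear(8, 2)
opt = Adafactor(
    toy_model.parameters(),
    scale_parameter=True, relative_step=True, warmup_init=True, lr=None,
)
scheduler = AdafactorSchedule(opt)  # proxies the optimizer's internal lr

loss = toy_model(torch.randn(4, 8)).pow(2).mean()
loss.backward()
opt.step()
scheduler.step()
print(scheduler.get_lr())  # the learning rate(s) Adafactor derived internally
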
@@ -219,8 +216,6 @@ def main():
     training_args, dataset_args, collaboration_args, averager_args = parser.parse_args_into_dataclasses()

     logger.info(f"Found {len(collaboration_args.initial_peers)} initial peers: {collaboration_args.initial_peers}")
-    if len(collaboration_args.initial_peers) == 0:
-        raise ValueError("Please specify at least one network endpoint in initial peers.")

     setup_logging(training_args)

@@ -231,6 +226,7 @@ def main():
     tokenizer = AlbertTokenizerFast.from_pretrained(dataset_args.tokenizer_path, cache_dir=dataset_args.cache_dir)
     model = get_model(training_args, config, tokenizer)
     model.to(training_args.device)
+    model.tie_weights()

     tokenized_datasets = load_from_disk(Path(dataset_args.dataset_path))
     # This data collator will take care of randomly masking the tokens.
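
A note on the model.tie_weights() call added above: moving a model to a
device can, on some backends (XLA/TPU in particular), re-allocate parameters
and silently untie the shared input/output embeddings, so re-tying after
.to(device) is a cheap safeguard. A quick check of the tie, assuming a small
hypothetical ALBERT config rather than the training one:

# Sketch only: tiny hypothetical config so the model builds quickly; the
# assertion shows that after tie_weights() both embedding matrices share
# one underlying storage.
from transformers import AlbertConfig, AlbertForPreTraining

config = AlbertConfig(hidden_size=128, num_hidden_layers=2,
                      num_attention_heads=4, intermediate_size=256)
model = AlbertForPreTraining(config)
model.tie_weights()
assert (model.get_input_embeddings().weight.data_ptr()
        == model.get_output_embeddings().weight.data_ptr())
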
@@ -310,5 +306,10 @@ def main():
     trainer.train(model_path=latest_checkpoint_dir)


+def _mp_fn(index):
+    # For xla_spawn (TPUs)
+    main()
+
+
 if __name__ == "__main__":
     main()
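
The _mp_fn hook added above is the entry point expected by the xla_spawn.py
launcher shipped with the transformers examples, which imports the training
script and starts one process per TPU core via torch_xla. A sketch of the
consumer side (the module name run_trainer is an assumption):

# Sketch only: assumes torch_xla is installed and this script is importable
# as run_trainer.
import torch_xla.distributed.xla_multiprocessing as xmp
import run_trainer

if __name__ == "__main__":
    # Calls run_trainer._mp_fn(index) once in each of the spawned processes.
    xmp.spawn(run_trainer._mp_fn, args=(), nprocs=8)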