|
@@ -9,7 +9,7 @@ import requests
|
|
|
import torch
|
|
|
import wandb
|
|
|
from torch_optimizer import Lamb
|
|
|
-from transformers import AlbertConfig, AlbertForPreTraining, HfArgumentParser
|
|
|
+from transformers import AlbertConfig, AlbertForPreTraining, HfArgumentParser, get_linear_schedule_with_warmup
|
|
|
|
|
|
import hivemind
|
|
|
from hivemind.optim.state_averager import TrainingStateAverager
|
|
@@ -40,6 +40,7 @@ class TrainingMonitorArguments(BaseTrainingArguments):
|
|
|
wandb_project: Optional[str] = field(
|
|
|
default=None, metadata={"help": "Name of Weights & Biases project to report the training progress to"}
|
|
|
)
|
|
|
+ store_checkpoints: bool = field(default=True, metadata={"help": "If False, disables periodic checkpoint saving"})
|
|
|
save_checkpoint_step_interval: int = field(
|
|
|
default=5, metadata={"help": "Frequency (in steps) of fetching and saving state from peers"}
|
|
|
)
|
|
@@ -56,7 +57,6 @@ class TrainingMonitorArguments(BaseTrainingArguments):
|
|
|
upload_interval: Optional[float] = field(
|
|
|
default=None, metadata={"help": "Frequency (in seconds) of uploading the model to Hub"}
|
|
|
)
|
|
|
- store_checkpoints: bool = field(default=False, metadata={"help": "If True, enables CheckpointHandler"})
|
|
|
|
|
|
|
|
|
class CheckpointHandler:
|
|
@@ -99,7 +99,8 @@ class CheckpointHandler:
|
|
|
self.state_averager = TrainingStateAverager(
|
|
|
dht=dht,
|
|
|
optimizer=opt,
|
|
|
- prefix=experiment_prefix,
|
|
|
+ scheduler=get_linear_schedule_with_warmup(opt, num_warmup_steps=5000, num_training_steps=125_000),
|
|
|
+ prefix=f"{run_id}_state_averager",
|
|
|
state_compression=hivemind.Float16Compression(),
|
|
|
bandwidth=optimizer_args.bandwidth,
|
|
|
client_mode=optimizer_args.client_mode,
|
|
@@ -155,8 +156,8 @@ if __name__ == "__main__":
|
|
|
version = ip_address(address).version
|
|
|
monitor_args.announce_maddrs += [f"/ip{version}/{address}/tcp/0"]
|
|
|
|
|
|
- experiment_prefix = monitor_args.experiment_prefix
|
|
|
- validators, local_public_key = utils.make_validators(experiment_prefix)
|
|
|
+ run_id = monitor_args.run_id
|
|
|
+ validators, local_public_key = utils.make_validators(run_id)
|
|
|
|
|
|
dht = hivemind.DHT(
|
|
|
start=True,
|
|
@@ -177,7 +178,7 @@ if __name__ == "__main__":
|
|
|
checkpoint_handler = CheckpointHandler(monitor_args, optimizer_args, averager_args, dht)
|
|
|
|
|
|
while True:
|
|
|
- metrics_dict = dht.get(experiment_prefix + "_metrics", latest=True)
|
|
|
+ metrics_dict = dht.get(run_id + "_metrics", latest=True)
|
|
|
if metrics_dict is not None:
|
|
|
metrics_dict = metrics_dict.value
|
|
|
metrics = [utils.LocalMetrics.parse_obj(metrics_dict[peer].value) for peer in metrics_dict]
|