@@ -56,6 +56,10 @@ class CoordinatorArguments(BaseTrainingArguments):
         default=None,
         metadata={"help": "Coordinator will upload model once in this many seconds"}
     )
+    store_checkpoints: bool = field(
+        default=False,
+        metadata={"help": "If True, enables CheckpointHandler"}
+    )


 class CheckpointHandler:
@@ -151,8 +155,8 @@ if __name__ == '__main__':
     wandb.init(project=coordinator_args.wandb_project)

     current_step = 0
-
-    checkpoint_handler = CheckpointHandler(coordinator_args, collab_optimizer_args, averager_args, dht)
+    if coordinator_args.store_checkpoints:
+        checkpoint_handler = CheckpointHandler(coordinator_args, collab_optimizer_args, averager_args, dht)

     while True:
         metrics_dict = dht.get(experiment_prefix + '_metrics', latest=True)
@@ -189,10 +193,11 @@ if __name__ == '__main__':
                     "performance": sum_perf,
                     "step": latest_step
                 })
-            if checkpoint_handler.is_time_to_save_state(current_step):
-                checkpoint_handler.save_state(current_step)
-            if checkpoint_handler.is_time_to_upload():
-                checkpoint_handler.upload_checkpoint(current_loss)
-            logger.info(f"Step #{current_step}\tloss = {current_loss:.5f}")
+            if coordinator_args.store_checkpoints:
+                if checkpoint_handler.is_time_to_save_state(current_step):
+                    checkpoint_handler.save_state(current_step)
+                if checkpoint_handler.is_time_to_upload():
+                    checkpoint_handler.upload_checkpoint(current_loss)
+            logger.info(f"Step #{current_step}\tloss = {current_loss:.5f}")
         logger.debug("Peer is still alive...")
         time.sleep(coordinator_args.refresh_period)
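
The new option follows the same dataclasses field(default=..., metadata={"help": ...}) pattern as the neighbouring CoordinatorArguments entries, so it surfaces as a command-line switch through whatever dataclass-aware parser the monitor script uses. The sketch below is illustrative only: it assumes a transformers-style HfArgumentParser, and MonitorArguments, refresh_period, and run_monitor.py are hypothetical stand-ins rather than code from the patched file.

# Minimal sketch: how a boolean dataclass field such as store_checkpoints
# becomes a CLI flag. MonitorArguments is a hypothetical stand-in for the
# real CoordinatorArguments; only the field pattern matches the patch above.
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class MonitorArguments:
    refresh_period: float = field(
        default=30.0, metadata={"help": "Seconds to sleep between monitoring passes"}
    )
    store_checkpoints: bool = field(
        default=False, metadata={"help": "If True, enables CheckpointHandler"}
    )


if __name__ == "__main__":
    parser = HfArgumentParser(MonitorArguments)
    (args,) = parser.parse_args_into_dataclasses()
    # "python run_monitor.py --store_checkpoints" sets the flag to True;
    # omitting it keeps the default of False, in which case the monitor
    # never constructs a CheckpointHandler or enters the save/upload branches.
    print(f"store_checkpoints = {args.store_checkpoints}")

Keeping the default at False means existing invocations behave exactly as before: model state is only loaded, saved, and uploaded when an operator opts in explicitly.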