Bladeren bron

Make checkpointing optional in example (#303)

Michael Diskin 4 jaar geleden
bovenliggende
commit
2e1bb9c1c2
1 gewijzigd bestand met 12 toevoegingen en 7 verwijderingen
  1. 12 7
      examples/albert/run_first_peer.py

+ 12 - 7
examples/albert/run_first_peer.py

@@ -56,6 +56,10 @@ class CoordinatorArguments(BaseTrainingArguments):
         default=None,
         metadata={"help": "Coordinator will upload model once in this many seconds"}
     )
+    store_checkpoins: bool = field(
+        default=False,
+        metadata={"help": "If True, enables CheckpointHandler"}
+    )
 
 
 class CheckpointHandler:
@@ -151,8 +155,8 @@ if __name__ == '__main__':
         wandb.init(project=coordinator_args.wandb_project)
 
     current_step = 0
-
-    checkpoint_handler = CheckpointHandler(coordinator_args, collab_optimizer_args, averager_args, dht)
+    if coordinator_args.store_checkpoins:
+        checkpoint_handler = CheckpointHandler(coordinator_args, collab_optimizer_args, averager_args, dht)
 
     while True:
         metrics_dict = dht.get(experiment_prefix + '_metrics', latest=True)
@@ -189,10 +193,11 @@ if __name__ == '__main__':
                         "performance": sum_perf,
                         "step": latest_step
                     })
-                if checkpoint_handler.is_time_to_save_state(current_step):
-                    checkpoint_handler.save_state(current_step)
-                    if checkpoint_handler.is_time_to_upload():
-                        checkpoint_handler.upload_checkpoint(current_loss)
-                logger.info(f"Step #{current_step}\tloss = {current_loss:.5f}")
+                if coordinator_args.store_checkpoins:
+                    if checkpoint_handler.is_time_to_save_state(current_step):
+                        checkpoint_handler.save_state(current_step)
+                        if checkpoint_handler.is_time_to_upload():
+                            checkpoint_handler.upload_checkpoint(current_loss)
+                    logger.info(f"Step #{current_step}\tloss = {current_loss:.5f}")
         logger.debug("Peer is still alive...")
         time.sleep(coordinator_args.refresh_period)