Browse Source

Log number of params

Aleksandr Borzunov 3 năm trước cách đây
mục cha
commit
df54ab6da5
1 tập tin đã thay đổi với 2 bổ sung0 xóa
  1. 2 0
      task.py

+ 2 - 0
task.py

@@ -85,6 +85,8 @@ class TrainingTask:
                 rotary_emb=True,
                 rotary_emb=True,
                 reversible=True,
                 reversible=True,
             )
             )
+            logger.info(f"Trainable parameters: "
+                        f"{sum(param.numel() for param in dalle.parameters() if param.requires_grad)}")
             self.model = ModelWrapper(dalle)
             self.model = ModelWrapper(dalle)
         else:
         else:
             logger.info(f"Loading model from {latest_checkpoint_dir}")
             logger.info(f"Loading model from {latest_checkpoint_dir}")