@@ -85,6 +85,8 @@ class TrainingTask:
rotary_emb=True,
reversible=True,
)
+ logger.info(f"Trainable parameters: "
+ f"{sum(param.numel() for param in dalle.parameters() if param.requires_grad)}")
self.model = ModelWrapper(dalle)
else:
logger.info(f"Loading model from {latest_checkpoint_dir}")