|
@@ -85,6 +85,8 @@ class TrainingTask:
|
|
rotary_emb=True,
|
|
rotary_emb=True,
|
|
reversible=True,
|
|
reversible=True,
|
|
)
|
|
)
|
|
|
|
+ logger.info(f"Trainable parameters: "
|
|
|
|
+ f"{sum(param.numel() for param in dalle.parameters() if param.requires_grad)}")
|
|
self.model = ModelWrapper(dalle)
|
|
self.model = ModelWrapper(dalle)
|
|
else:
|
|
else:
|
|
logger.info(f"Loading model from {latest_checkpoint_dir}")
|
|
logger.info(f"Loading model from {latest_checkpoint_dir}")
|