|
@@ -84,6 +84,7 @@ class TrainingTask:
|
|
|
shared_ff_ids=shared_layer_ids,
|
|
|
rotary_emb=True,
|
|
|
reversible=True,
|
|
|
+ share_input_output_emb=True,
|
|
|
)
|
|
|
logger.info(f"Trainable parameters: "
|
|
|
f"{sum(param.numel() for param in dalle.parameters() if param.requires_grad)}")
|