@@ -82,7 +82,7 @@ class TrainingTask:
attn_dropout=0,
shared_attn_ids=shared_layer_ids,
shared_ff_ids=shared_layer_ids,
- rotary_emb=False, # FIXME: Fix RuntimeError when True
+ rotary_emb=True,
reversible=True,
)
self.model = ModelWrapper(dalle)