|
@@ -82,7 +82,7 @@ class TrainingTask:
|
|
attn_dropout=0,
|
|
attn_dropout=0,
|
|
shared_attn_ids=shared_layer_ids,
|
|
shared_attn_ids=shared_layer_ids,
|
|
shared_ff_ids=shared_layer_ids,
|
|
shared_ff_ids=shared_layer_ids,
|
|
- rotary_emb=False, # FIXME: Fix RuntimeError when True
|
|
|
|
|
|
+ rotary_emb=True,
|
|
reversible=True,
|
|
reversible=True,
|
|
)
|
|
)
|
|
self.model = ModelWrapper(dalle)
|
|
self.model = ModelWrapper(dalle)
|