|
@@ -38,6 +38,9 @@ class ModelWrapper(nn.Module):
|
|
|
super().__init__()
|
|
|
self.model = model
|
|
|
|
|
|
+ def tie_weights(self):
|
|
|
+ pass
|
|
|
+
|
|
|
def forward(self, input_ids, attention_mask, image):
|
|
|
loss = self.model.forward(text=input_ids, image=image, mask=attention_mask, return_loss=True)
|
|
|
return {'loss': loss}
|
|
@@ -64,7 +67,7 @@ class TrainingTask:
|
|
|
if latest_checkpoint_dir is None:
|
|
|
logger.info(f"Creating model")
|
|
|
|
|
|
- depth = 64
|
|
|
+ depth = 16#TODO
|
|
|
attn_types = list(islice(cycle(['axial_row', 'axial_col', 'axial_row', 'axial_row']), depth - 1))
|
|
|
attn_types.append('conv_like')
|
|
|
shared_layer_ids = list(islice(cycle(range(4)), depth - 1))
|