@@ -57,41 +57,39 @@ class TrainingTask:
         self.tokenizer = T5TokenizerFast.from_pretrained(peer_args.tokenizer_path)
         self.tokenizer.pad_token = self.tokenizer.eos_token

+        logger.info(f"Creating model")
+        depth = 64
+        attn_types = list(islice(cycle(['axial_row', 'axial_col', 'axial_row', 'axial_row']), depth - 1))
+        attn_types.append('conv_like')
+        shared_layer_ids = list(islice(cycle(range(4)), depth - 1))
+        shared_layer_ids.append('w_conv')
+        dalle = DALLE(
+            vae=VQGanParams(),
+            num_text_tokens=self.tokenizer.vocab_size,
+            text_seq_len=trainer_args.text_seq_length,
+            dim=1024,
+            depth=depth,
+            heads=16,
+            dim_head=64,
+            attn_types=attn_types,
+            ff_dropout=0,
+            attn_dropout=0,
+            shared_attn_ids=shared_layer_ids,
+            shared_ff_ids=shared_layer_ids,
+            rotary_emb=True,
+            reversible=True,
+            share_input_output_emb=True,
+        )
+        logger.info(f"Trainable parameters: "
+                    f"{sum(param.numel() for param in dalle.parameters() if param.requires_grad)}")
+        self.model = ModelWrapper(dalle)
+
         output_dir = Path(trainer_args.output_dir)
         logger.info(f'Checkpoint dir {output_dir}, contents {list(output_dir.glob("checkpoint*"))}')
         latest_checkpoint_dir = max(output_dir.glob("checkpoint*"), default=None, key=os.path.getctime)
-
-        if latest_checkpoint_dir is None:
-            logger.info(f"Creating model")
-
-            depth = 64
-            attn_types = list(islice(cycle(['axial_row', 'axial_col', 'axial_row', 'axial_row']), depth - 1))
-            attn_types.append('conv_like')
-            shared_layer_ids = list(islice(cycle(range(4)), depth - 1))
-            shared_layer_ids.append('w_conv')
-            dalle = DALLE(
-                vae=VQGanParams(),
-                num_text_tokens=self.tokenizer.vocab_size,
-                text_seq_len=trainer_args.text_seq_length,
-                dim=1024,
-                depth=depth,
-                heads=16,
-                dim_head=64,
-                attn_types=attn_types,
-                ff_dropout=0,
-                attn_dropout=0,
-                shared_attn_ids=shared_layer_ids,
-                shared_ff_ids=shared_layer_ids,
-                rotary_emb=True,
-                reversible=True,
-                share_input_output_emb=True,
-            )
-            logger.info(f"Trainable parameters: "
-                        f"{sum(param.numel() for param in dalle.parameters() if param.requires_grad)}")
-            self.model = ModelWrapper(dalle)
-        else:
+        if latest_checkpoint_dir is not None:
             logger.info(f"Loading model from {latest_checkpoint_dir}")
-            self.task.model.load_state_dict(torch.load(f"{latest_checkpoint_dir}/model_state.pt"))
+            self.model.load_state_dict(torch.load(f"{latest_checkpoint_dir}/model_state.pt"))

     @property
     def dht(self):
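For context on the layer schedule the diff builds: `shared_attn_ids` and `shared_ff_ids` are DALLE-pytorch's weight-sharing mechanism, so layers assigned the same ID reuse the same attention/feed-forward weights. A minimal standalone sketch of the ID construction, reusing only names from the diff (the prints are illustrative):

    from itertools import cycle, islice

    depth = 64
    # First depth-1 layers cycle through shared parameter groups 0..3;
    # the final layer gets its own dedicated 'w_conv' group.
    shared_layer_ids = list(islice(cycle(range(4)), depth - 1))
    shared_layer_ids.append('w_conv')
    print(shared_layer_ids[:8])        # [0, 1, 2, 3, 0, 1, 2, 3]
    print(len(set(shared_layer_ids)))  # 5 distinct groups across 64 layers

So the 64-layer stack stores only five distinct attention and feed-forward blocks, which, together with `reversible=True` (reversible layers avoid storing most intermediate activations), keeps the memory footprint of a deep model manageable.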