
Fix loading state from checkpoint

Aleksandr Borzunov, 3 years ago
parent commit 156dc236a1
1 changed file with 29 additions and 31 deletions:
    task.py (+29 −31)

task.py (+29 −31)

@@ -57,41 +57,39 @@ class TrainingTask:
         self.tokenizer = T5TokenizerFast.from_pretrained(peer_args.tokenizer_path)
         self.tokenizer.pad_token = self.tokenizer.eos_token
 
+        logger.info(f"Creating model")
+        depth = 64
+        attn_types = list(islice(cycle(['axial_row', 'axial_col', 'axial_row', 'axial_row']), depth - 1))
+        attn_types.append('conv_like')
+        shared_layer_ids = list(islice(cycle(range(4)), depth - 1))
+        shared_layer_ids.append('w_conv')
+        dalle = DALLE(
+            vae=VQGanParams(),
+            num_text_tokens=self.tokenizer.vocab_size,
+            text_seq_len=trainer_args.text_seq_length,
+            dim=1024,
+            depth=depth,
+            heads=16,
+            dim_head=64,
+            attn_types=attn_types,
+            ff_dropout=0,
+            attn_dropout=0,
+            shared_attn_ids=shared_layer_ids,
+            shared_ff_ids=shared_layer_ids,
+            rotary_emb=True,
+            reversible=True,
+            share_input_output_emb=True,
+        )
+        logger.info(f"Trainable parameters: "
+                    f"{sum(param.numel() for param in dalle.parameters() if param.requires_grad)}")
+        self.model = ModelWrapper(dalle)
+
         output_dir = Path(trainer_args.output_dir)
         logger.info(f'Checkpoint dir {output_dir}, contents {list(output_dir.glob("checkpoint*"))}')
         latest_checkpoint_dir = max(output_dir.glob("checkpoint*"), default=None, key=os.path.getctime)
-
-        if latest_checkpoint_dir is None:
-            logger.info(f"Creating model")
-
-            depth = 64
-            attn_types = list(islice(cycle(['axial_row', 'axial_col', 'axial_row', 'axial_row']), depth - 1))
-            attn_types.append('conv_like')
-            shared_layer_ids = list(islice(cycle(range(4)), depth - 1))
-            shared_layer_ids.append('w_conv')
-            dalle = DALLE(
-                vae=VQGanParams(),
-                num_text_tokens=self.tokenizer.vocab_size,
-                text_seq_len=trainer_args.text_seq_length,
-                dim=1024,
-                depth=depth,
-                heads=16,
-                dim_head=64,
-                attn_types=attn_types,
-                ff_dropout=0,
-                attn_dropout=0,
-                shared_attn_ids=shared_layer_ids,
-                shared_ff_ids=shared_layer_ids,
-                rotary_emb=True,
-                reversible=True,
-                share_input_output_emb=True,
-            )
-            logger.info(f"Trainable parameters: "
-                        f"{sum(param.numel() for param in dalle.parameters() if param.requires_grad)}")
-            self.model = ModelWrapper(dalle)
-        else:
+        if latest_checkpoint_dir is not None:
             logger.info(f"Loading model from {latest_checkpoint_dir}")
-            self.task.model.load_state_dict(torch.load(f"{latest_checkpoint_dir}/model_state.pt"))
+            self.model.load_state_dict(torch.load(f"{latest_checkpoint_dir}/model_state.pt"))
 
     @property
     def dht(self):