Browse source

Don't use .lm_head() in benchmark_forward.py

Aleksandr Borzunov 2 years ago
parent
commit
dcf5183b69
1 changed file with 8 additions and 3 deletions
  1. src/petals/cli/benchmark_forward.py (+8, −3)

+ 8 - 3
src/petals/cli/benchmark_forward.py

@@ -45,9 +45,14 @@ def benchmark_forward(process_idx, args):
     torch.manual_seed(42)
     for step in range(args.n_steps):
         input_ids = torch.randint(100, 10000, size=(args.batch_size, args.seq_len))
-        logger.info(f"Fwd begin {input_ids.shape=}")
-        outputs = model.forward(input_ids)
-        logger.info("Fwd end")
+
+        logger.info(f"{process_idx=} Fwd begin {input_ids.shape=}")
+        embs = model.transformer.word_embeddings(input_ids)
+        embs = model.transformer.word_embeddings_layernorm(embs)
+        h = sess.step(embs)
+        h_last = model.transformer.ln_f(h[:, -1])
+        # We don't use model.lm_head
+        logger.info(f"{process_idx=} Fwd end")
 
         if step == 0:
             start_time = perf_counter()
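
The new loop runs the forward pass by hand: local word embeddings plus the embedding layernorm, the remote transformer blocks via sess.step(...), and the final ln_f layernorm on the last position, deliberately skipping the lm_head projection so the benchmark times only the distributed forward pass. Below is a minimal self-contained sketch of that loop; the inference_session call, the checkpoint name, and the default arguments are assumptions based on typical petals usage, not part of this commit.

import torch
from time import perf_counter

from petals import DistributedBloomForCausalLM  # assumed entry point for a BLOOM-style petals model


def run_forward_benchmark(model, n_steps=10, batch_size=1, seq_len=128):
    """Return forward steps/s, timing everything except the lm_head projection."""
    torch.manual_seed(42)
    with torch.inference_mode():
        # Assumption: `sess` in the diff is a remote inference session opened
        # earlier in the function, roughly like this.
        with model.transformer.h.inference_session(max_length=n_steps * seq_len) as sess:
            start_time = None
            for step in range(n_steps):
                input_ids = torch.randint(100, 10000, size=(batch_size, seq_len))

                # Local embedding + layernorm, remote transformer blocks,
                # final layernorm; no lm_head, mirroring the diff above.
                embs = model.transformer.word_embeddings(input_ids)
                embs = model.transformer.word_embeddings_layernorm(embs)
                h = sess.step(embs)
                h_last = model.transformer.ln_f(h[:, -1])

                if step == 0:
                    start_time = perf_counter()  # exclude the first (warm-up) step
            return (n_steps - 1) / (perf_counter() - start_time)


# Example usage (checkpoint name is an assumption):
# model = DistributedBloomForCausalLM.from_pretrained("bigscience/bloom-petals")
# print(f"{run_forward_benchmark(model):.2f} forward steps/s")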