Aleksandr Borzunov 2 年 前
コミット
a4d0c9e82f
1 ファイル変更1 行追加4 行削除
  1. 1 4
      src/petals/cli/benchmark_forward.py

+ 1 - 4
src/petals/cli/benchmark_forward.py

@@ -47,10 +47,7 @@ def benchmark_forward(process_idx, args):
         input_ids = torch.randint(100, 10000, size=(args.batch_size, args.seq_len))
 
         logger.info(f"{process_idx=} Fwd begin {input_ids.shape=}")
-        embs = model.transformer.word_embeddings(input_ids)
-        embs = model.transformer.word_embeddings_layernorm(embs)
-        h = sess.step(embs)
-        h_last = model.transformer.ln_f(h[:, -1])
+        h = model.transformer(input_ids)
         # We don't use model.lm_head
         logger.info(f"{process_idx=} Fwd end")