Aleksandr Borzunov 2 年之前
父節點
當前提交
a4d0c9e82f
共有 1 個文件被更改,包括 1 次插入4 次删除
  1. 1 4
      src/petals/cli/benchmark_forward.py

+ 1 - 4
src/petals/cli/benchmark_forward.py

@@ -47,10 +47,7 @@ def benchmark_forward(process_idx, args):
         input_ids = torch.randint(100, 10000, size=(args.batch_size, args.seq_len))
 
         logger.info(f"{process_idx=} Fwd begin {input_ids.shape=}")
-        embs = model.transformer.word_embeddings(input_ids)
-        embs = model.transformer.word_embeddings_layernorm(embs)
-        h = sess.step(embs)
-        h_last = model.transformer.ln_f(h[:, -1])
+        h = model.transformer(input_ids)
         # We don't use model.lm_head
         logger.info(f"{process_idx=} Fwd end")