|
@@ -47,10 +47,7 @@ def benchmark_forward(process_idx, args):
|
|
|
input_ids = torch.randint(100, 10000, size=(args.batch_size, args.seq_len))
|
|
|
|
|
|
logger.info(f"{process_idx=} Fwd begin {input_ids.shape=}")
|
|
|
- embs = model.transformer.word_embeddings(input_ids)
|
|
|
- embs = model.transformer.word_embeddings_layernorm(embs)
|
|
|
- h = sess.step(embs)
|
|
|
- h_last = model.transformer.ln_f(h[:, -1])
|
|
|
+ h = model.transformer(input_ids)
|
|
|
# We don't use model.lm_head
|
|
|
logger.info(f"{process_idx=} Fwd end")
|
|
|
|