|
@@ -45,9 +45,14 @@ def benchmark_forward(process_idx, args):
|
|
|
torch.manual_seed(42)
|
|
|
for step in range(args.n_steps):
|
|
|
input_ids = torch.randint(100, 10000, size=(args.batch_size, args.seq_len))
|
|
|
- logger.info(f"Fwd begin {input_ids.shape=}")
|
|
|
- outputs = model.forward(input_ids)
|
|
|
- logger.info("Fwd end")
|
|
|
+
|
|
|
+ logger.info(f"{process_idx=} Fwd begin {input_ids.shape=}")
|
|
|
+ embs = model.transformer.word_embeddings(token_ids)
|
|
|
+ embs = model.transformer.word_embeddings_layernorm(embs)
|
|
|
+ h = sess.step(embs)
|
|
|
+ h_last = model.transformer.ln_f(h[:, -1])
|
|
|
+ # We don't use model.lm_head
|
|
|
+ logger.info(f"{process_idx=} Fwd end")
|
|
|
|
|
|
if step == 0:
|
|
|
start_time = perf_counter()
|