|
@@ -61,14 +61,14 @@ def benchmark_training(process_idx, args):
|
|
|
bwd_times = []
|
|
|
for step in range(args.n_steps):
|
|
|
input_ids = torch.randint(100, 10000, size=(args.batch_size, args.seq_len))
|
|
|
- labels = torch.randint(0, 2, size=[args.batch_size])
|
|
|
+ if args.task == "cls":
|
|
|
+ labels = torch.randint(0, 2, size=[args.batch_size])
|
|
|
+ else:
|
|
|
+ labels = input_ids
|
|
|
|
|
|
logger.info(f"{process_idx=} {step=} Forward")
|
|
|
start_time = perf_counter()
|
|
|
- if args.task == "cls":
|
|
|
- outputs = model(input_ids, labels=labels)
|
|
|
- else:
|
|
|
- outputs = model(input_ids)
|
|
|
+ outputs = model(input_ids, labels=labels)
|
|
|
fwd_times.append(perf_counter() - start_time)
|
|
|
|
|
|
logger.info(f"{process_idx=} {step=} Backward")
|