Aleksandr Borzunov 2 роки тому
батько
коміт
a41ea38b6a
1 змінених файлів з 5 додано та 5 видалено
  1. 5 5
      src/petals/cli/benchmark_training.py

+ 5 - 5
src/petals/cli/benchmark_training.py

@@ -61,14 +61,14 @@ def benchmark_training(process_idx, args):
     bwd_times = []
     for step in range(args.n_steps):
         input_ids = torch.randint(100, 10000, size=(args.batch_size, args.seq_len))
-        labels = torch.randint(0, 2, size=[args.batch_size])
+        if args.task == "cls":
+            labels = torch.randint(0, 2, size=[args.batch_size])
+        else:
+            labels = input_ids
 
         logger.info(f"{process_idx=} {step=} Forward")
         start_time = perf_counter()
-        if args.task == "cls":
-            outputs = model(input_ids, labels=labels)
-        else:
-            outputs = model(input_ids)
+        outputs = model(input_ids, labels=labels)
         fwd_times.append(perf_counter() - start_time)
 
         logger.info(f"{process_idx=} {step=} Backward")