Aleksandr Borzunov 2 gadi atpakaļ
vecāks
revīzija
7efa6eb99f
1 mainītis faili ar 2 papildinājumiem un 2 dzēšanām
  1. 2 2
      src/petals/cli/benchmark_training.py

+ 2 - 2
src/petals/cli/benchmark_training.py

@@ -70,8 +70,8 @@ def benchmark_training(process_idx, args):
         opt.zero_grad()
 
         if step >= 1:
-            fwd_speed = step / np.mean(fwd_times[1:]) * input_ids.numel()
-            bwd_speed = step / np.mean(bwd_times[1:]) * input_ids.numel()
+            fwd_speed = input_ids.numel() / np.mean(fwd_times[1:])
+            bwd_speed = input_ids.numel() / np.mean(bwd_times[1:])
             logger.info(f"{process_idx=} Fwd speed: {fwd_speed:.2f} | Bwd speed: {bwd_speed:.2f}")
 
     logger.info(f"Final result: {process_idx=} {fwd_speed=:.2f} | {bwd_speed=:.2f}")