Browse Source

Measure fwd and bwd speeds separately

Aleksandr Borzunov 2 years ago
parent
commit
c7142dba8d
1 changed files with 12 additions and 6 deletions
  1. 12 6
      src/petals/cli/benchmark_training.py

+ 12 - 6
src/petals/cli/benchmark_training.py

@@ -4,6 +4,7 @@ import argparse
 import multiprocessing as mp
 from time import perf_counter
 
+import numpy as np
 import torch
 import petals.client.sequential_autograd
 from hivemind.utils.logging import get_logger
@@ -48,27 +49,32 @@ def benchmark_training(process_idx, args):
     logger.info(f"Created model: {process_idx=} {model.device=}")
 
     torch.manual_seed(42)
+    fwd_times = []
+    bwd_times = []
     for step in range(args.n_steps):
         input_ids = torch.randint(100, 10000, size=(args.batch_size, args.seq_len))
         labels = torch.randint(0, 2, size=[args.batch_size])
 
         logger.info(f"{process_idx=} {step=} Forward")
+        start_time = perf_counter()
         outputs = model(input_ids, labels=labels)
+        fwd_times.append(perf_counter() - start_time)
 
         logger.info(f"{process_idx=} {step=} Backward")
+        start_time = perf_counter()
         outputs.loss.backward()
+        bwd_times.append(perf_counter() - start_time)
 
         logger.info(f"{process_idx=} {step=} Optimizer step")
         opt.step()
         opt.zero_grad()
 
-        if step == 0:
-            start_time = perf_counter()
-        else:
-            speed = step / (perf_counter() - start_time) * input_ids.numel()
-            logger.info(f"{process_idx=} {step=} {speed=:.3f}")
+        if step >= 1:
+            fwd_speed = step / np.mean(fwd_times[1:]) * input_ids.numel()
+            bwd_speed = step / np.mean(bwd_times[1:]) * input_ids.numel()
+            logger.info(f"{process_idx=} Fwd speed: {fwd_speed:.2f} | Bwd speed: {bwd_speed:.2f}")
 
-    logger.info(f"Final result: {process_idx=} {speed=:.3f}")
+    logger.info(f"Final result: {process_idx=} {fwd_speed=:.2f} | {bwd_speed=:.2f}")
 
 
 if __name__ == "__main__":