
Fix warmup steps and minor issues in benchmarks (#334)

The previous code was incorrect when `warmup_steps != 1` (this mode was never used so far, but may be used in the future).
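Why the previous loop was wrong: it set `start_time` once at `step == warmup_steps` and then reported `speed = step / elapsed`, but only `step - warmup_steps + 1` steps actually fall inside the timed window, so the numerator is correct only when `warmup_steps == 1`. A minimal sketch of the corrected pattern that all three benchmarks now follow (function and variable names here are illustrative, not from the repo):

```python
import time

import numpy as np


def measure_throughput(run_step, work_per_step, warmup_steps, n_steps):
    """Time each step individually and average only the post-warmup steps."""
    step_times = []
    for step in range(warmup_steps + n_steps):
        start_time = time.perf_counter()
        run_step()
        if step >= warmup_steps:  # warmup iterations are executed but not recorded
            step_times.append(time.perf_counter() - start_time)
    return work_per_step / np.mean(step_times)  # units of work per second


# Toy usage: ~10 ms per step, 100 tokens per step -> roughly 10,000 tokens/s
print(f"{measure_throughput(lambda: time.sleep(0.01), 100, 2, 5):.0f} tokens/s")
```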
Alexander Borzunov, 2 years ago
parent
commit 10c72acdf4
3 changed files with 26 additions and 18 deletions
  1. benchmarks/benchmark_forward.py (+8 -6)
  2. benchmarks/benchmark_inference.py (+11 -7)
  3. benchmarks/benchmark_training.py (+7 -5)

+ 8 - 6
benchmarks/benchmark_forward.py

@@ -4,6 +4,7 @@ import argparse
 import multiprocessing as mp
 from time import perf_counter
 
+import numpy as np
 import torch
 from hivemind.utils.logging import get_logger
 
@@ -47,9 +48,9 @@ def benchmark_forward(process_idx, args):
     logger.info(f"Created model: {process_idx=} {model.device=}")
 
     torch.manual_seed(42)
-    for step in range(args.n_steps):
-        if step == args.warmup_steps:
-            start_time = perf_counter()
+    step_times = []
+    for step in range(args.warmup_steps + args.n_steps):
+        start_time = perf_counter()
 
         input_ids = torch.randint(0, model.config.vocab_size, size=(args.batch_size, args.seq_len))
 
@@ -59,10 +60,11 @@ def benchmark_forward(process_idx, args):
         logger.info(f"{process_idx=} Fwd end")
 
         if step >= args.warmup_steps:
-            speed = step / (perf_counter() - start_time) * input_ids.numel()
-            logger.info(f"{process_idx=} {step=} {speed=:.3f}")
+            step_times.append(perf_counter() - start_time)
+            speed = input_ids.numel() / np.mean(step_times)
+            logger.info(f"{process_idx=} {step=} {speed=:.2f}")
 
-    logger.info(f"Final result: {process_idx=} {speed=:.3f}")
+    logger.info(f"Final result: {process_idx=} {speed=:.2f}")
 
 
 if __name__ == "__main__":
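The reported `speed` is aggregate tokens per second for the whole batch: `input_ids.numel()` equals `batch_size * seq_len`, and `np.mean(step_times)` is a running mean over all post-warmup steps, so the logged value stabilizes as steps accumulate. A worked example with made-up timings:

```python
import numpy as np

tokens_per_step = 8 * 128        # hypothetical batch_size=8, seq_len=128
step_times = [0.52, 0.49, 0.51]  # made-up post-warmup step times, seconds
speed = tokens_per_step / np.mean(step_times)
print(f"{speed=:.2f}")           # speed=2021.05 tokens/s
```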

+ 11 - 7
benchmarks/benchmark_inference.py

@@ -4,6 +4,7 @@ import argparse
 import multiprocessing as mp
 from time import perf_counter
 
+import numpy as np
 import torch
 from hivemind.utils.logging import get_logger
 from transformers import AutoTokenizer
@@ -38,26 +39,29 @@ def main():
 
 @torch.inference_mode()
 def benchmark_inference(process_idx, args):
-    tokenizer = AutoTokenizer.from_pretrained(args.model)
+    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
+    # Using use_fast=False since LlamaTokenizerFast takes a long time to start, and we decode 1 token at a time anyway
+
     model = AutoDistributedModelForCausalLM.from_pretrained(
         args.model, initial_peers=args.initial_peers, torch_dtype=DTYPE_MAP[args.torch_dtype]
     )
-    logger.info(f"Created model: {process_idx=} {model.device=} {model.config.torch_dtype=}")
+    logger.info(f"Created model: {process_idx=} {model.device=}")
 
     result = ""
+    step_times = []
     with model.transformer.h.inference_session(max_length=args.seq_len) as sess:
         for step in range(args.seq_len):
-            if step == args.warmup_steps:
-                start_time = perf_counter()
+            start_time = perf_counter()
 
             outputs = model.generate(max_new_tokens=1, session=sess)
             result += tokenizer.decode(outputs[0])
 
             if step >= args.warmup_steps:
-                speed = step / (perf_counter() - start_time)
-                logger.info(f"{process_idx=} {step=} {speed=:.3f}")
+                step_times.append(perf_counter() - start_time)
+                speed = 1 / np.mean(step_times)
+                logger.info(f"{process_idx=} {step=} {speed=:.2f}")
 
-    logger.info(f"Final result: {process_idx=} {speed=:.3f}")
+    logger.info(f"Final result: {process_idx=} {speed=:.2f}")
 
 
 if __name__ == "__main__":
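Since `model.generate(max_new_tokens=1, session=sess)` emits exactly one token per loop iteration, the throughput here reduces to the reciprocal of the mean per-token latency. A made-up illustration:

```python
import numpy as np

step_times = [0.110, 0.095, 0.102]  # hypothetical per-token latencies, seconds
speed = 1 / np.mean(step_times)     # tokens per second for this client
print(f"{speed=:.2f}")              # speed=9.77
```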

+ 7 - 5
benchmarks/benchmark_training.py

@@ -68,7 +68,7 @@ def benchmark_training(process_idx, args):
     torch.manual_seed(42)
     fwd_times = []
     bwd_times = []
-    for step in range(args.n_steps):
+    for step in range(args.warmup_steps + args.n_steps):
         input_ids = torch.randint(0, model.config.vocab_size, size=(args.batch_size, args.seq_len), device=args.device)
         if args.task == "cls":
             labels = torch.randint(0, 2, size=[args.batch_size], device=args.device)
@@ -78,20 +78,22 @@ def benchmark_training(process_idx, args):
         logger.info(f"{process_idx=} {step=} Forward")
         start_time = perf_counter()
         outputs = model(input_ids, labels=labels)
-        fwd_times.append(perf_counter() - start_time)
+        if step >= args.warmup_steps:
+            fwd_times.append(perf_counter() - start_time)
 
         logger.info(f"{process_idx=} {step=} Backward")
         start_time = perf_counter()
         outputs.loss.backward()
-        bwd_times.append(perf_counter() - start_time)
+        if step >= args.warmup_steps:
+            bwd_times.append(perf_counter() - start_time)
 
         logger.info(f"{process_idx=} {step=} Optimizer step")
         opt.step()
         opt.zero_grad()
 
         if step >= args.warmup_steps:
-            fwd_speed = input_ids.numel() / np.mean(fwd_times[1:])
-            bwd_speed = input_ids.numel() / np.mean(bwd_times[1:])
+            fwd_speed = input_ids.numel() / np.mean(fwd_times)
+            bwd_speed = input_ids.numel() / np.mean(bwd_times)
             logger.info(f"{process_idx=} Fwd speed: {fwd_speed:.2f} | Bwd speed: {bwd_speed:.2f}")
 
     logger.info(f"Final result: {process_idx=} {fwd_speed=:.2f} | {bwd_speed=:.2f}")