Selaa lähdekoodia

Add benchmark scripts (#319)

This PR:

- Adds benchmark scripts for inference, forward pass, and full training step (e.g. used for experiments in our paper).
- Fixes bug with dtypes in `petals.DistributedBloomForSequenceClassification`.
- (minor refactor) Moves `DTYPE_MAP` to `petals.constants` as a useful constant.
Alexander Borzunov 2 vuotta sitten
vanhempi
commit
d126ee3053

+ 69 - 0
benchmarks/benchmark_forward.py

@@ -0,0 +1,69 @@
+#!/usr/bin/env python3
+
+import argparse
+import multiprocessing as mp
+from time import perf_counter
+
+import torch
+from hivemind.utils.logging import get_logger
+
+from petals import AutoDistributedModel
+from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
+
+logger = get_logger()
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", type=str, default="bigscience/bloom")
+    parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS)
+    parser.add_argument("--torch_dtype", type=str, default="bfloat16")
+    parser.add_argument("--n_processes", type=str, default=1)
+    parser.add_argument("--seq_len", type=int, default=128)
+    parser.add_argument("--n_steps", type=int, default=100)
+    parser.add_argument("--batch_size", type=int, required=True)
+    parser.add_argument("--warmup_steps", type=int, default=1)
+    args = parser.parse_args()
+
+    if args.n_processes == "n_gpus":
+        args.n_processes = torch.cuda.device_count()
+    else:
+        args.n_processes = int(args.n_processes)
+
+    processes = [mp.Process(target=benchmark_forward, args=(i, args)) for i in range(args.n_processes)]
+    for proc in processes:
+        proc.start()
+    for proc in processes:
+        proc.join()
+
+
+@torch.inference_mode()
+def benchmark_forward(process_idx, args):
+    model = AutoDistributedModel.from_pretrained(
+        args.model,
+        initial_peers=args.initial_peers,
+        torch_dtype=DTYPE_MAP[args.torch_dtype],
+    )
+    logger.info(f"Created model: {process_idx=} {model.device=}")
+
+    torch.manual_seed(42)
+    for step in range(args.n_steps):
+        if step == args.warmup_steps:
+            start_time = perf_counter()
+
+        input_ids = torch.randint(0, model.config.vocab_size, size=(args.batch_size, args.seq_len))
+
+        logger.info(f"{process_idx=} Fwd begin {input_ids.shape=}")
+        h = model(input_ids)
+        # We don't use model.lm_head
+        logger.info(f"{process_idx=} Fwd end")
+
+        if step >= args.warmup_steps:
+            speed = step / (perf_counter() - start_time) * input_ids.numel()
+            logger.info(f"{process_idx=} {step=} {speed=:.3f}")
+
+    logger.info(f"Final result: {process_idx=} {speed=:.3f}")
+
+
+if __name__ == "__main__":
+    main()

+ 64 - 0
benchmarks/benchmark_inference.py

@@ -0,0 +1,64 @@
+#!/usr/bin/env python3
+
+import argparse
+import multiprocessing as mp
+from time import perf_counter
+
+import torch
+from hivemind.utils.logging import get_logger
+from transformers import AutoTokenizer
+
+from petals import AutoDistributedModelForCausalLM
+from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
+
+logger = get_logger()
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", type=str, default="bigscience/bloom")
+    parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS)
+    parser.add_argument("--torch_dtype", type=str, default="bfloat16")
+    parser.add_argument("--n_processes", type=str, default=1)
+    parser.add_argument("--seq_len", type=int, default=2048)
+    parser.add_argument("--warmup_steps", type=int, default=1)
+    args = parser.parse_args()
+
+    if args.n_processes == "n_gpus":
+        args.n_processes = torch.cuda.device_count()
+    else:
+        args.n_processes = int(args.n_processes)
+
+    processes = [mp.Process(target=benchmark_inference, args=(i, args)) for i in range(args.n_processes)]
+    for proc in processes:
+        proc.start()
+    for proc in processes:
+        proc.join()
+
+
+@torch.inference_mode()
+def benchmark_inference(process_idx, args):
+    tokenizer = AutoTokenizer.from_pretrained(args.model)
+    model = AutoDistributedModelForCausalLM.from_pretrained(
+        args.model, initial_peers=args.initial_peers, torch_dtype=DTYPE_MAP[args.torch_dtype]
+    )
+    logger.info(f"Created model: {process_idx=} {model.device=} {model.config.torch_dtype=}")
+
+    result = ""
+    with model.transformer.h.inference_session(max_length=args.seq_len) as sess:
+        for step in range(args.seq_len):
+            if step == args.warmup_steps:
+                start_time = perf_counter()
+
+            outputs = model.generate(max_new_tokens=1, session=sess)
+            result += tokenizer.decode(outputs[0])
+
+            if step >= args.warmup_steps:
+                speed = step / (perf_counter() - start_time)
+                logger.info(f"{process_idx=} {step=} {speed=:.3f}")
+
+    logger.info(f"Final result: {process_idx=} {speed=:.3f}")
+
+
+if __name__ == "__main__":
+    main()

+ 101 - 0
benchmarks/benchmark_training.py

@@ -0,0 +1,101 @@
+#!/usr/bin/env python3
+
+import argparse
+import multiprocessing as mp
+from time import perf_counter
+
+import numpy as np
+import torch
+from hivemind.utils.logging import get_logger
+
+from petals import AutoDistributedModelForCausalLM, AutoDistributedModelForSequenceClassification
+from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
+
+logger = get_logger()
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", type=str, default="bigscience/bloom")
+    parser.add_argument("--device", type=str, default="cpu")
+    parser.add_argument("--task", type=str, default="cls")
+    parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS)
+    parser.add_argument("--torch_dtype", type=str, default="bfloat16")
+    parser.add_argument("--n_processes", type=str, default=1)
+    parser.add_argument("--seq_len", type=int, default=128)
+    parser.add_argument("--pre_seq_len", type=int, default=16)
+    parser.add_argument("--n_steps", type=int, default=10)
+    parser.add_argument("--batch_size", type=int, required=True)
+    parser.add_argument("--warmup_steps", type=int, default=1)
+    args = parser.parse_args()
+
+    assert args.task in ["cls", "causal_lm"]
+
+    if args.n_processes == "n_gpus":
+        args.n_processes = torch.cuda.device_count()
+    else:
+        args.n_processes = int(args.n_processes)
+
+    processes = [mp.Process(target=benchmark_training, args=(i, args)) for i in range(args.n_processes)]
+    for proc in processes:
+        proc.start()
+    for proc in processes:
+        proc.join()
+
+
+def benchmark_training(process_idx, args):
+    if args.task == "cls":
+        model = AutoDistributedModelForSequenceClassification.from_pretrained(
+            args.model,
+            initial_peers=args.initial_peers,
+            torch_dtype=DTYPE_MAP[args.torch_dtype],
+            tuning_mode="deep_ptune",
+            pre_seq_len=args.pre_seq_len,
+            num_labels=2,
+        )
+    elif args.task == "causal_lm":
+        model = AutoDistributedModelForCausalLM.from_pretrained(
+            args.model,
+            initial_peers=args.initial_peers,
+            torch_dtype=DTYPE_MAP[args.torch_dtype],
+            tuning_mode="deep_ptune",
+            pre_seq_len=args.pre_seq_len,
+        )
+    model = model.to(args.device)
+    opt = torch.optim.Adam(model.parameters())
+    logger.info(f"Created model: {process_idx=} {model.device=}")
+
+    torch.manual_seed(42)
+    fwd_times = []
+    bwd_times = []
+    for step in range(args.n_steps):
+        input_ids = torch.randint(0, model.config.vocab_size, size=(args.batch_size, args.seq_len), device=args.device)
+        if args.task == "cls":
+            labels = torch.randint(0, 2, size=[args.batch_size], device=args.device)
+        else:
+            labels = input_ids
+
+        logger.info(f"{process_idx=} {step=} Forward")
+        start_time = perf_counter()
+        outputs = model(input_ids, labels=labels)
+        fwd_times.append(perf_counter() - start_time)
+
+        logger.info(f"{process_idx=} {step=} Backward")
+        start_time = perf_counter()
+        outputs.loss.backward()
+        bwd_times.append(perf_counter() - start_time)
+
+        logger.info(f"{process_idx=} {step=} Optimizer step")
+        opt.step()
+        opt.zero_grad()
+
+        if step >= args.warmup_steps:
+            fwd_speed = input_ids.numel() / np.mean(fwd_times[1:])
+            bwd_speed = input_ids.numel() / np.mean(bwd_times[1:])
+            logger.info(f"{process_idx=} Fwd speed: {fwd_speed:.2f} | Bwd speed: {bwd_speed:.2f}")
+
+    logger.info(f"Final result: {process_idx=} {fwd_speed=:.2f} | {bwd_speed=:.2f}")
+
+
+if __name__ == "__main__":
+    main()

+ 2 - 2
src/petals/cli/run_server.py

@@ -6,8 +6,8 @@ from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger
 from humanfriendly import parse_size
 
-from petals.constants import PUBLIC_INITIAL_PEERS
-from petals.server.server import DTYPE_MAP, Server
+from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
+from petals.server.server import Server
 from petals.utils.version import validate_version
 
 logger = get_logger(__name__)

+ 4 - 0
src/petals/constants.py

@@ -1,3 +1,5 @@
+import torch
+
 PUBLIC_INITIAL_PEERS = [
     "/dns/bootstrap1.petals.ml/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY",
     "/dns6/bootstrap1.petals.ml/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY",
@@ -7,3 +9,5 @@ PUBLIC_INITIAL_PEERS = [
 
 # The reachability API is currently used only when connecting to the public swarm
 REACHABILITY_API_URL = "http://health.petals.ml"
+
+DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")

+ 1 - 1
src/petals/models/bloom/model.py

@@ -128,7 +128,7 @@ class DistributedBloomForSequenceClassification(FromPretrainedMixin, BloomForSeq
         self.num_labels = config.num_labels
 
         self.transformer = DistributedBloomModel(config)
-        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False).to(config.torch_dtype)
+        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
 
         # Initialize weights and apply final processing
         self.post_init()

+ 1 - 3
src/petals/server/from_pretrained.py

@@ -19,6 +19,7 @@ from huggingface_hub import get_hf_file_metadata, hf_hub_url
 from transformers import PretrainedConfig
 from transformers.utils import get_file_from_repo
 
+from petals.constants import DTYPE_MAP
 from petals.server.block_utils import resolve_block_dtype
 from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
@@ -170,6 +171,3 @@ def _load_state_dict_from_file(
         except Exception as e:
             logger.warning(f"Failed to load file {filename} from HF Hub (retry in {delay:.0f} sec)", exc_info=True)
             time.sleep(delay)
-
-
-DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")

+ 2 - 2
src/petals/server/server.py

@@ -16,13 +16,13 @@ from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.logging import get_logger
 from transformers import PretrainedConfig
 
-from petals.constants import PUBLIC_INITIAL_PEERS
+from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
 from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
 from petals.dht_utils import declare_active_modules, get_remote_module_infos
 from petals.server import block_selection
 from petals.server.backend import TransformerBackend, merge_inference_pools_inplace
 from petals.server.block_utils import get_block_size, resolve_block_dtype
-from petals.server.from_pretrained import DTYPE_MAP, load_pretrained_block
+from petals.server.from_pretrained import load_pretrained_block
 from petals.server.handler import TransformerConnectionHandler
 from petals.server.memory_cache import MemoryCache
 from petals.server.reachability import ReachabilityProtocol, check_direct_reachability, validate_reachability