浏览代码

Add benchmark_training.py, set default initial peer

Aleksandr Borzunov 2 年之前
父节点
当前提交
3d0669662b
共有 3 个文件被更改,包括 78 次插入、16 次删除
  1. 2 8
      src/petals/cli/benchmark_forward.py
  2. 2 8
      src/petals/cli/benchmark_inference.py
  3. 74 0
      src/petals/cli/benchmark_training.py

+ 2 - 8
src/petals/cli/benchmark_forward.py

@@ -18,20 +18,14 @@ petals.client.sequential_autograd.MAX_TOKENS_IN_BATCH = 1024
 def main():
 def main():
     parser = argparse.ArgumentParser()
     parser = argparse.ArgumentParser()
     parser.add_argument("--model", type=str, default="bigscience/bloom-petals")
     parser.add_argument("--model", type=str, default="bigscience/bloom-petals")
-    parser.add_argument("-i", "--initial_peers", type=str, nargs='+', required=True)
+    parser.add_argument("-i", "--initial_peers", type=str, nargs='+',
+        default=["/dns/bench.petals.ml/tcp/31337/p2p/QmehSoMKScoMF3HczLwaLVnw2Lgsap4bhAMrULEzGc1fSV"])
     parser.add_argument("-p", "--n_processes", type=str, required=True)
     parser.add_argument("-p", "--n_processes", type=str, required=True)
     parser.add_argument("--seq_len", type=int, default=128)
     parser.add_argument("--seq_len", type=int, default=128)
     parser.add_argument("--n_steps", type=int, default=100)
     parser.add_argument("--n_steps", type=int, default=100)
     parser.add_argument("-b", "--batch_size", type=int, required=True)
     parser.add_argument("-b", "--batch_size", type=int, required=True)
     args = parser.parse_args()
     args = parser.parse_args()
 
 
-    if args.initial_peers == ["3090"]:
-        args.initial_peers = ["/dns/bench.petals.ml/tcp/31337/p2p/QmehSoMKScoMF3HczLwaLVnw2Lgsap4bhAMrULEzGc1fSV"]
-    elif args.initial_peers == ["a100"]:
-        args.initial_peers = ["/ip4/127.0.0.1/tcp/38355/p2p/QmU3wFRRW1XUbByqXqk9sbA3wiYQBp1Lpa32doxt1RvKRv"]
-    else:
-        logger.warning(f"Non-standard initial peers: {args.initial_peers}")
-
     if args.n_processes == "n_gpus":
     if args.n_processes == "n_gpus":
         args.n_processes = torch.cuda.device_count()
         args.n_processes = torch.cuda.device_count()
     else:
     else:

+ 2 - 8
src/petals/cli/benchmark_inference.py

@@ -15,18 +15,12 @@ logger = get_logger()
 def main():
 def main():
     parser = argparse.ArgumentParser()
     parser = argparse.ArgumentParser()
     parser.add_argument("--model", type=str, default="bigscience/bloom-petals")
     parser.add_argument("--model", type=str, default="bigscience/bloom-petals")
-    parser.add_argument("-i", "--initial_peers", type=str, nargs='+', required=True)
+    parser.add_argument("-i", "--initial_peers", type=str, nargs='+',
+        default=["/dns/bench.petals.ml/tcp/31337/p2p/QmehSoMKScoMF3HczLwaLVnw2Lgsap4bhAMrULEzGc1fSV"])
     parser.add_argument("-p", "--n_processes", type=str, required=True)
     parser.add_argument("-p", "--n_processes", type=str, required=True)
     parser.add_argument("-l", "--seq_len", type=int, required=True)
     parser.add_argument("-l", "--seq_len", type=int, required=True)
     args = parser.parse_args()
     args = parser.parse_args()
 
 
-    if args.initial_peers == ["3090"]:
-        args.initial_peers = ["/dns/bench.petals.ml/tcp/31337/p2p/QmehSoMKScoMF3HczLwaLVnw2Lgsap4bhAMrULEzGc1fSV"]
-    elif args.initial_peers == ["a100"]:
-        args.initial_peers = ["/ip4/127.0.0.1/tcp/38355/p2p/QmU3wFRRW1XUbByqXqk9sbA3wiYQBp1Lpa32doxt1RvKRv"]
-    else:
-        logger.warning(f"Non-standard initial peers: {args.initial_peers}")
-
     if args.n_processes == "n_gpus":
     if args.n_processes == "n_gpus":
         args.n_processes = torch.cuda.device_count()
         args.n_processes = torch.cuda.device_count()
     else:
     else:

+ 74 - 0
src/petals/cli/benchmark_training.py

@@ -0,0 +1,74 @@
+#!/usr/bin/env python3
+
+import argparse
+import multiprocessing as mp
+from time import perf_counter
+
+import torch
+import petals.client.sequential_autograd
+from hivemind.utils.logging import get_logger
+from petals import DistributedBloomForSequenceClassification
+from transformers import BloomTokenizerFast
+
+logger = get_logger()
+
+petals.client.sequential_autograd.MAX_TOKENS_IN_BATCH = 1024
+
+
def main():
    """Parse benchmark CLI options and fan out one training worker per process.

    Spawns `--n_processes` subprocesses (or one per CUDA device when the
    magic value "n_gpus" is given), each running `benchmark_training`, and
    waits for all of them to finish.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="bigscience/bloom-petals")
    parser.add_argument("-i", "--initial_peers", type=str, nargs='+',
        default=["/dns/bench.petals.ml/tcp/31337/p2p/QmehSoMKScoMF3HczLwaLVnw2Lgsap4bhAMrULEzGc1fSV"])
    parser.add_argument("-p", "--n_processes", type=str, default="1")
    parser.add_argument("--seq_len", type=int, default=128)
    parser.add_argument("--n_steps", type=int, default=100)
    parser.add_argument("-b", "--batch_size", type=int, required=True)
    args = parser.parse_args()

    # "n_gpus" is a sentinel meaning "one worker per visible CUDA device";
    # anything else must parse as an integer worker count.
    if args.n_processes == "n_gpus":
        args.n_processes = torch.cuda.device_count()
    else:
        args.n_processes = int(args.n_processes)

    workers = []
    for rank in range(args.n_processes):
        worker = mp.Process(target=benchmark_training, args=(rank, args))
        worker.start()
        workers.append(worker)
    for worker in workers:
        worker.join()
+
+
def benchmark_training(process_idx, args):
    """Run one benchmark worker: train a distributed BLOOM classifier head and log throughput.

    Args:
        process_idx: integer rank of this worker, used only for log labelling.
        args: parsed CLI namespace; reads `model`, `initial_peers`, `n_steps`,
            `batch_size`, and `seq_len`.

    Logs tokens/sec after every step past the first and a final summary line.
    """
    # The tokenizer is not used below; presumably kept so its download/cache
    # cost is included in worker setup — TODO confirm, otherwise remove.
    tokenizer = BloomTokenizerFast.from_pretrained(args.model)
    model = DistributedBloomForSequenceClassification.from_pretrained(
        args.model, initial_peers=args.initial_peers, tuning_mode="deep_ptune", pre_seq_len=16, num_labels=2)
    optimizer = torch.optim.Adam(model.parameters())
    logger.info(f"Created model: {process_idx=} {model.device=}")

    torch.manual_seed(42)
    speed = None  # tokens/sec; stays None until at least two steps have run
    for step in range(args.n_steps):
        input_ids = torch.randint(100, 10000, size=(args.batch_size, args.seq_len))
        labels = torch.randint(0, 2, size=[args.batch_size])

        logger.info(f"{process_idx=} {step=} Forward")
        outputs = model(input_ids, labels=labels)
        logger.info(f"{process_idx=} {step=} Loss: {outputs.loss=:.2f}")

        logger.info(f"{process_idx=} {step=} Backward")
        outputs.loss.backward()

        logger.info(f"{process_idx=} {step=} Optimizer step")
        # Bug fix: the original called `opt.step()` / `opt.zero_grad()`, but the
        # optimizer is bound to `optimizer` — that raised NameError on step 0.
        optimizer.step()
        optimizer.zero_grad()

        if step == 0:
            # Exclude the first (warm-up) step from the throughput measurement.
            start_time = perf_counter()
        else:
            speed = step / (perf_counter() - start_time) * input_ids.numel()
            logger.info(f"{process_idx=} {step=} {speed=:.3f}")

    # Bug fix: `speed` was unbound when n_steps < 2, crashing the final log.
    if speed is not None:
        logger.info(f"Final result: {process_idx=} {speed=:.3f}")
+
+
# Script entry point: run the multi-process benchmark driver.
if __name__ == "__main__":
    main()