瀏覽代碼

Support -p n_gpus arg

Aleksandr Borzunov 2 年之前
父節點
當前提交
4669a9cd91
共有 2 個文件被更改,包括 13 次插入、3 次刪除
  1. src/petals/cli/benchmark_forward.py (+7 −2)
  2. src/petals/cli/benchmark_inference.py (+6 −1)

+ 7 - 2
src/petals/cli/benchmark_forward.py

@@ -12,14 +12,14 @@ from transformers import BloomTokenizerFast
 
 logger = get_logger()
 
-# petals.client.sequential_autograd.MAX_TOKENS_IN_BATCH = 1024
+petals.client.sequential_autograd.MAX_TOKENS_IN_BATCH = 1024
 
 
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--model", type=str, default="bigscience/bloom-petals")
     parser.add_argument("-i", "--initial_peers", type=str, nargs='+', required=True)
-    parser.add_argument("-p", "--n_processes", type=int, required=True)
+    parser.add_argument("-p", "--n_processes", type=str, required=True)
     parser.add_argument("--seq_len", type=int, default=128)
     parser.add_argument("--n_steps", type=int, default=100)
     parser.add_argument("-b", "--batch_size", type=int, required=True)
@@ -32,6 +32,11 @@ def main():
     else:
         logger.warning(f"Non-standard initial peers: {args.initial_peers}")
 
+    if args.n_processes == "n_gpus":
+        args.n_processes = torch.cuda.device_count()
+    else:
+        args.n_processes = int(args.n_processes)
+
     processes = [mp.Process(target=benchmark_forward, args=(i, args,)) for i in range(args.n_processes)]
     for proc in processes:
         proc.start()

+ 6 - 1
src/petals/cli/benchmark_inference.py

@@ -16,7 +16,7 @@ def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--model", type=str, default="bigscience/bloom-petals")
     parser.add_argument("-i", "--initial_peers", type=str, nargs='+', required=True)
-    parser.add_argument("-p", "--n_processes", type=int, required=True)
+    parser.add_argument("-p", "--n_processes", type=str, required=True)
     parser.add_argument("-l", "--seq_len", type=int, required=True)
     args = parser.parse_args()
 
@@ -27,6 +27,11 @@ def main():
     else:
         logger.warning(f"Non-standard initial peers: {args.initial_peers}")
 
+    if args.n_processes == "n_gpus":
+        args.n_processes = torch.cuda.device_count()
+    else:
+        args.n_processes = int(args.n_processes)
+
     processes = [mp.Process(target=benchmark_inference, args=(i, args,)) for i in range(args.n_processes)]
     for proc in processes:
         proc.start()