
Add new task

Aleksandr Borzunov · 2 years ago · commit 337f139662
1 changed file with 16 additions and 5 deletions:

src/petals/cli/benchmark_training.py  (+16 −5)
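
The change adds a --task argument to the training benchmark: "cls" (the default) keeps the existing DistributedBloomForSequenceClassification setup, while "causal_lm" benchmarks DistributedBloomForCausalLM instead, calling its forward pass without labels.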

@@ -8,7 +8,7 @@ import numpy as np
 import torch
 import petals.client.sequential_autograd
 from hivemind.utils.logging import get_logger
-from petals import DistributedBloomForSequenceClassification
+from petals import DistributedBloomForSequenceClassification, DistributedBloomForCausalLM
 from transformers import BloomTokenizerFast
 
 logger = get_logger()
@@ -19,6 +19,7 @@ petals.client.sequential_autograd.MAX_TOKENS_IN_BATCH = 1024
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--model", type=str, default="bigscience/bloom-petals")
+    parser.add_argument("--task", type=str, default="cls")
     parser.add_argument("-i", "--initial_peers", type=str, nargs='+',
         default=["/dns/bench.petals.ml/tcp/31337/p2p/QmehSoMKScoMF3HczLwaLVnw2Lgsap4bhAMrULEzGc1fSV"])
     parser.add_argument("--n_processes", type=str, default="1")
@@ -28,6 +29,8 @@ def main():
     parser.add_argument("-b", "--batch_size", type=int, required=True)
     args = parser.parse_args()
 
+    assert args.task in ["cls", "causal_lm"]
+
     if args.n_processes == "n_gpus":
         args.n_processes = torch.cuda.device_count()
     else:
@@ -42,9 +45,14 @@ def main():
 
 def benchmark_training(process_idx, args):
     tokenizer = BloomTokenizerFast.from_pretrained(args.model)
-    model = DistributedBloomForSequenceClassification.from_pretrained(
-        args.model, initial_peers=args.initial_peers, tuning_mode="deep_ptune",
-        pre_seq_len=args.pre_seq_len, num_labels=2)
+    if args.task == "cls":
+        model = DistributedBloomForSequenceClassification.from_pretrained(
+            args.model, initial_peers=args.initial_peers, tuning_mode="deep_ptune",
+            pre_seq_len=args.pre_seq_len, num_labels=2)
+    elif args.task == "causal_lm":
+        model = DistributedBloomForCausalLM.from_pretrained(
+            args.model, initial_peers=args.initial_peers, tuning_mode="deep_ptune",
+            pre_seq_len=args.pre_seq_len)
     opt = torch.optim.Adam(model.parameters())
     logger.info(f"Created model: {process_idx=} {model.device=}")
 
@@ -57,7 +65,10 @@ def benchmark_training(process_idx, args):
 
         logger.info(f"{process_idx=} {step=} Forward")
         start_time = perf_counter()
-        outputs = model(input_ids, labels=labels)
+        if args.task == "cls":
+            outputs = model(input_ids, labels=labels)
+        else:
+            outputs = model(input_ids)
         fwd_times.append(perf_counter() - start_time)
 
         logger.info(f"{process_idx=} {step=} Backward")
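
Note that in the causal_lm branch the forward pass runs without labels, so the model returns no loss; the backward step (outside this hunk) has to derive a scalar from the logits. Below is a minimal sketch of one way a training step could handle both tasks, assuming Hugging Face-style model outputs with .loss and .logits; training_step is a hypothetical helper, not code from this commit:

    import torch.nn.functional as F

    def training_step(model, opt, input_ids, labels, task):
        # Hypothetical helper (not part of the commit): one forward/backward
        # pass returning the scalar loss for either benchmark task.
        if task == "cls":
            # Sequence classification: passing labels (one class id per
            # sequence) makes the model compute cross-entropy internally.
            outputs = model(input_ids, labels=labels)
            loss = outputs.loss
        else:
            # Causal LM: no labels were passed above, so compute next-token
            # cross-entropy against the inputs shifted by one position.
            outputs = model(input_ids)
            logits = outputs.logits[:, :-1].contiguous()
            targets = input_ids[:, 1:].contiguous()
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        loss.backward()
        opt.step()
        opt.zero_grad()
        return loss.item()

With the new flag, the causal-LM variant could be launched with something like the following (hypothetical invocation; apart from --task, these flags already exist in the script):

    python -m petals.cli.benchmark_training --model bigscience/bloom-petals --task causal_lm -b 16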