@@ -8,7 +8,7 @@ import numpy as np
 import torch
 import petals.client.sequential_autograd
 from hivemind.utils.logging import get_logger
-from petals import DistributedBloomForSequenceClassification
+from petals import DistributedBloomForSequenceClassification, DistributedBloomForCausalLM
 from transformers import BloomTokenizerFast
 
 logger = get_logger()
@@ -19,6 +19,7 @@ petals.client.sequential_autograd.MAX_TOKENS_IN_BATCH = 1024
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--model", type=str, default="bigscience/bloom-petals")
+    parser.add_argument("--task", type=str, default="cls")
     parser.add_argument("-i", "--initial_peers", type=str, nargs='+',
             default=["/dns/bench.petals.ml/tcp/31337/p2p/QmehSoMKScoMF3HczLwaLVnw2Lgsap4bhAMrULEzGc1fSV"])
     parser.add_argument("--n_processes", type=str, default="1")
@@ -28,6 +29,8 @@ def main():
     parser.add_argument("-b", "--batch_size", type=int, required=True)
     args = parser.parse_args()
 
+    assert args.task in ["cls", "causal_lm"]
+
     if args.n_processes == "n_gpus":
         args.n_processes = torch.cuda.device_count()
     else:
@@ -42,9 +45,14 @@ def main():
 
 def benchmark_training(process_idx, args):
     tokenizer = BloomTokenizerFast.from_pretrained(args.model)
-    model = DistributedBloomForSequenceClassification.from_pretrained(
-        args.model, initial_peers=args.initial_peers, tuning_mode="deep_ptune",
-        pre_seq_len=args.pre_seq_len, num_labels=2)
+    if args.task == "cls":
+        model = DistributedBloomForSequenceClassification.from_pretrained(
+            args.model, initial_peers=args.initial_peers, tuning_mode="deep_ptune",
+            pre_seq_len=args.pre_seq_len, num_labels=2)
+    elif args.task == "causal_lm":
+        model = DistributedBloomForCausalLM.from_pretrained(
+            args.model, initial_peers=args.initial_peers, tuning_mode="deep_ptune",
+            pre_seq_len=args.pre_seq_len)
     opt = torch.optim.Adam(model.parameters())
     logger.info(f"Created model: {process_idx=} {model.device=}")
 
@@ -57,7 +65,10 @@ def benchmark_training(process_idx, args):
 
         logger.info(f"{process_idx=} {step=} Forward")
         start_time = perf_counter()
-        outputs = model(input_ids, labels=labels)
+        if args.task == "cls":
+            outputs = model(input_ids, labels=labels)
+        else:
+            outputs = model(input_ids)
         fwd_times.append(perf_counter() - start_time)
 
         logger.info(f"{process_idx=} {step=} Backward")