|
@@ -15,8 +15,8 @@ logger = get_logger()
|
|
|
def main():
|
|
|
parser = argparse.ArgumentParser()
|
|
|
parser.add_argument("--model", type=str, default="bigscience/bloom-petals")
|
|
|
- parser.add_argument("-p", "--n_processes", type=int)
|
|
|
- parser.add_argument("-l", "--seq_len", type=int)
|
|
|
+ parser.add_argument("-p", "--n_processes", type=int, required=True)
|
|
|
+ parser.add_argument("-l", "--seq_len", type=int, required=True)
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
processes = [mp.Process(target=benchmark_inference, args=(i, args,)) for i in range(args.n_processes)]
|