@@ -12,14 +12,14 @@ from transformers import BloomTokenizerFast
 
 logger = get_logger()
 
-# petals.client.sequential_autograd.MAX_TOKENS_IN_BATCH = 1024
+petals.client.sequential_autograd.MAX_TOKENS_IN_BATCH = 1024
 
 
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--model", type=str, default="bigscience/bloom-petals")
     parser.add_argument("-i", "--initial_peers", type=str, nargs='+', required=True)
-    parser.add_argument("-p", "--n_processes", type=int, required=True)
+    parser.add_argument("-p", "--n_processes", type=str, required=True)
     parser.add_argument("--seq_len", type=int, default=128)
     parser.add_argument("--n_steps", type=int, default=100)
     parser.add_argument("-b", "--batch_size", type=int, required=True)
@@ -32,6 +32,11 @@ def main():
     else:
         logger.warning(f"Non-standard initial peers: {args.initial_peers}")
 
+    if args.n_processes == "n_gpus":
+        args.n_processes = torch.cuda.device_count()
+    else:
+        args.n_processes = int(args.n_processes)
+
     processes = [mp.Process(target=benchmark_forward, args=(i, args,)) for i in range(args.n_processes)]
     for proc in processes:
         proc.start()
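
Context for the first hunk: Petals' client-side sequential autograd uses MAX_TOKENS_IN_BATCH to cap how many tokens travel through the swarm in a single pipeline batch, and the hunk turns the previously commented-out override into an active monkey-patch of 1024. The sketch below is an illustration of that capping idea only, not Petals' actual code; the function name and the (batch, seq_len, hidden) shapes are assumptions made for the example.

import torch

MAX_TOKENS_IN_BATCH = 1024  # same value the diff patches into petals.client.sequential_autograd

def split_into_token_capped_batches(inputs: torch.Tensor) -> list:
    # Hypothetical helper: split (batch, seq_len, hidden) activations into
    # sub-batches that each carry at most MAX_TOKENS_IN_BATCH tokens.
    seq_len = inputs.shape[1]
    per_batch = max(MAX_TOKENS_IN_BATCH // seq_len, 1)
    return list(torch.split(inputs, per_batch, dim=0))

# With the benchmark's default seq_len=128, each sub-batch holds up to 8 sequences.
chunks = split_into_token_capped_batches(torch.randn(32, 128, 1024))
assert sum(chunk.shape[0] for chunk in chunks) == 32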
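The second hunk loosens -p/--n_processes from int to str so the literal "n_gpus" can be passed, in which case the benchmark spawns one process per visible CUDA device via torch.cuda.device_count(); any other value still goes through int(). The diff does this conversion inline after parse_args(), which works fine; an equivalent, slightly more idiomatic argparse variant moves the conversion into a type callable, so invalid values fail with a normal argparse error at parse time. parse_n_processes below is a hypothetical name, not part of the benchmark:

import argparse
import torch

def parse_n_processes(value: str) -> int:
    # Accept either an integer or the sentinel "n_gpus".
    if value == "n_gpus":
        return torch.cuda.device_count()
    return int(value)

parser = argparse.ArgumentParser()
parser.add_argument("-p", "--n_processes", type=parse_n_processes, required=True)
args = parser.parse_args(["-p", "n_gpus"])
print(args.n_processes)  # number of visible CUDA devices (0 on a CPU-only machine)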