Bläddra i källkod

Hardcode more initial peers

Aleksandr Borzunov 2 år sedan
förälder
incheckning
cf87264199
2 ändrade filer med 16 tillägg och 2 borttagningar
  1. 8 1
      src/petals/cli/benchmark_forward.py
  2. 8 1
      src/petals/cli/benchmark_inference.py

+ 8 - 1
src/petals/cli/benchmark_forward.py

@@ -15,13 +15,20 @@ logger = get_logger()
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--model", type=str, default="bigscience/bloom-petals")
-    parser.add_argument("--initial_peers", type=str, nargs='+', default=["/ip4/185.244.175.92/tcp/31337/p2p/QmehSoMKScoMF3HczLwaLVnw2Lgsap4bhAMrULEzGc1fSV"])
+    parser.add_argument("-i", "--initial_peers", type=str, nargs='+', required=True)
     parser.add_argument("-p", "--n_processes", type=int, required=True)
     parser.add_argument("--seq_len", type=int, default=128)
     parser.add_argument("--n_steps", type=int, default=100)
     parser.add_argument("-b", "--batch_size", type=int, required=True)
     args = parser.parse_args()
 
+    if args.initial_peers == ["3090"]:
+        args.initial_peers = ["/ip4/185.244.175.92/tcp/31337/p2p/QmehSoMKScoMF3HczLwaLVnw2Lgsap4bhAMrULEzGc1fSV"]
+    elif args.initial_peers == ["a100"]:
+        args.initial_peers = ["/ip4/127.0.0.1/tcp/38355/p2p/QmU3wFRRW1XUbByqXqk9sbA3wiYQBp1Lpa32doxt1RvKRv"]
+    else:
+        logger.warning(f"Non-standard initial peers: {args.initial_peers}")
+
     processes = [mp.Process(target=benchmark_forward, args=(i, args,)) for i in range(args.n_processes)]
     for proc in processes:
         proc.start()

+ 8 - 1
src/petals/cli/benchmark_inference.py

@@ -15,11 +15,18 @@ logger = get_logger()
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--model", type=str, default="bigscience/bloom-petals")
-    parser.add_argument("--initial_peers", type=str, nargs='+', default=["/ip4/185.244.175.92/tcp/31337/p2p/QmehSoMKScoMF3HczLwaLVnw2Lgsap4bhAMrULEzGc1fSV"])
+    parser.add_argument("-i", "--initial_peers", type=str, nargs='+', required=True)
     parser.add_argument("-p", "--n_processes", type=int, required=True)
     parser.add_argument("-l", "--seq_len", type=int, required=True)
     args = parser.parse_args()
 
+    if args.initial_peers == ["3090"]:
+        args.initial_peers = ["/ip4/185.244.175.92/tcp/31337/p2p/QmehSoMKScoMF3HczLwaLVnw2Lgsap4bhAMrULEzGc1fSV"]
+    elif args.initial_peers == ["a100"]:
+        args.initial_peers = ["/ip4/127.0.0.1/tcp/38355/p2p/QmU3wFRRW1XUbByqXqk9sbA3wiYQBp1Lpa32doxt1RvKRv"]
+    else:
+        logger.warning(f"Non-standard initial peers: {args.initial_peers}")
+
     processes = [mp.Process(target=benchmark_inference, args=(i, args,)) for i in range(args.n_processes)]
     for proc in processes:
         proc.start()