Aleksandr Borzunov 2 лет назад
Родитель
Сommit
205eb2f2d8
1 измененных файлов с 5 добавлено и 4 удалено
  1. 5 4
      src/petals/cli/benchmark_inference.py

+ 5 - 4
src/petals/cli/benchmark_inference.py

@@ -15,6 +15,7 @@ logger = get_logger()
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--model", type=str, default="bigscience/bloom-petals")
+    parser.add_argument("--initial_peers", type=str, nargs='+', default=["/ip4/185.244.175.92/tcp/31337/p2p/QmehSoMKScoMF3HczLwaLVnw2Lgsap4bhAMrULEzGc1fSV"])
     parser.add_argument("-p", "--n_processes", type=int, required=True)
     parser.add_argument("-l", "--seq_len", type=int, required=True)
     args = parser.parse_args()
@@ -29,7 +30,7 @@ def main():
 @torch.inference_mode()
 def benchmark_inference(process_idx, args):
     tokenizer = BloomTokenizerFast.from_pretrained(args.model)
-    model = DistributedBloomForCausalLM.from_pretrained(args.model)
+    model = DistributedBloomForCausalLM.from_pretrained(args.model, initial_peers=args.initial_peers)
     logger.info(f"Created model: {process_idx=} {model.device=}")
 
     result = ""
@@ -41,10 +42,10 @@ def benchmark_inference(process_idx, args):
             if step == 0:
                 start_time = perf_counter()
             else:
-                average_time = (perf_counter() - start_time) / step
-                logger.info(f"{process_idx=} {step=} {average_time=:.3f}")
+                speed = step / (perf_counter() - start_time)
+                logger.info(f"{process_idx=} {step=} {speed=:.3f}")
 
-    logger.info(f"Final result: {process_idx=} {average_time=:.3f}")
+    logger.info(f"Final result: {process_idx=} {speed=:.3f}")
 
 
 if __name__ == "__main__":