@@ -15,6 +15,7 @@ logger = get_logger()
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--model", type=str, default="bigscience/bloom-petals")
+    parser.add_argument("--initial_peers", type=str, nargs='+', default=["/ip4/185.244.175.92/tcp/31337/p2p/QmehSoMKScoMF3HczLwaLVnw2Lgsap4bhAMrULEzGc1fSV"])
     parser.add_argument("-p", "--n_processes", type=int, required=True)
     parser.add_argument("-l", "--seq_len", type=int, required=True)
     args = parser.parse_args()
@@ -29,7 +30,7 @@ def main():
 @torch.inference_mode()
 def benchmark_inference(process_idx, args):
     tokenizer = BloomTokenizerFast.from_pretrained(args.model)
-    model = DistributedBloomForCausalLM.from_pretrained(args.model)
+    model = DistributedBloomForCausalLM.from_pretrained(args.model, initial_peers=args.initial_peers)
     logger.info(f"Created model: {process_idx=} {model.device=}")
 
     result = ""
@@ -41,10 +42,10 @@ def benchmark_inference(process_idx, args):
         if step == 0:
             start_time = perf_counter()
         else:
-            average_time = (perf_counter() - start_time) / step
-            logger.info(f"{process_idx=} {step=} {average_time=:.3f}")
+            speed = step / (perf_counter() - start_time)
+            logger.info(f"{process_idx=} {step=} {speed=:.3f}")
 
-    logger.info(f"Final result: {process_idx=} {average_time=:.3f}")
+    logger.info(f"Final result: {process_idx=} {speed=:.3f}")
 
 
 if __name__ == "__main__":
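
A minimal usage sketch, not part of the patch: the petals import path and the script name are assumptions, while the from_pretrained(..., initial_peers=...) call and the default multiaddr mirror the change above.

    # Sketch only: connect to a swarm through custom bootstrap peers.
    # NOTE: the petals import path is an assumption; adjust to your install.
    from transformers import BloomTokenizerFast
    from petals import DistributedBloomForCausalLM

    INITIAL_PEERS = [
        "/ip4/185.244.175.92/tcp/31337/p2p/QmehSoMKScoMF3HczLwaLVnw2Lgsap4bhAMrULEzGc1fSV",
    ]

    tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-petals")
    model = DistributedBloomForCausalLM.from_pretrained("bigscience/bloom-petals", initial_peers=INITIAL_PEERS)

    # With the patched script (file name assumed), several peers can be passed
    # space-separated thanks to nargs='+':
    #   python benchmark_inference.py --initial_peers <addr1> <addr2> -p 1 -l 128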