Jelajahi Sumber

Use dtype float32

Aleksandr Borzunov 2 tahun lalu
induk
melakukan
52c1149751
1 mengubah file dengan 3 tambahan dan 3 penghapusan
  1. 3 3
      src/petals/cli/benchmark_forward.py

+ 3 - 3
src/petals/cli/benchmark_forward.py

@@ -17,8 +17,8 @@ def main():
     parser.add_argument("--model", type=str, default="bigscience/bloom-petals")
     parser.add_argument("--initial_peers", type=str, nargs='+', default=["/ip4/185.244.175.92/tcp/31337/p2p/QmehSoMKScoMF3HczLwaLVnw2Lgsap4bhAMrULEzGc1fSV"])
     parser.add_argument("-p", "--n_processes", type=int, required=True)
-    parser.add_argument("-l", "--seq_len", type=int, default=128)
-    parser.add_argument("-s", "--n_steps", type=int, default=100)
+    parser.add_argument("--seq_len", type=int, default=128)
+    parser.add_argument("--n_steps", type=int, default=100)
     parser.add_argument("-b", "--batch_size", type=int, required=True)
     args = parser.parse_args()
 
@@ -32,7 +32,7 @@ def main():
 @torch.inference_mode()
 def benchmark_forward(process_idx, args):
     tokenizer = BloomTokenizerFast.from_pretrained(args.model)
-    model = DistributedBloomForCausalLM.from_pretrained(args.model, initial_peers=args.initial_peers)
+    model = DistributedBloomForCausalLM.from_pretrained(args.model, initial_peers=args.initial_peers, torch_dtype=torch.float32)
     logger.info(f"Created model: {process_idx=} {model.device=}")
 
     torch.manual_seed(42)