Aleksandr Borzunov 2 tahun lalu
induk
melakukan
78dedd5e58
1 mengubah file dengan 6 tambahan dan 4 penghapusan
  1. 6 4
      src/petals/cli/benchmark_training.py

+ 6 - 4
src/petals/cli/benchmark_training.py

@@ -20,9 +20,10 @@ def main():
     parser.add_argument("--model", type=str, default="bigscience/bloom-petals")
     parser.add_argument("-i", "--initial_peers", type=str, nargs='+',
         default=["/dns/bench.petals.ml/tcp/31337/p2p/QmehSoMKScoMF3HczLwaLVnw2Lgsap4bhAMrULEzGc1fSV"])
-    parser.add_argument("-p", "--n_processes", type=str, default="1")
+    parser.add_argument("--n_processes", type=str, default="1")
     parser.add_argument("--seq_len", type=int, default=128)
-    parser.add_argument("--n_steps", type=int, default=100)
+    parser.add_argument("--pre_seq_len", type=int, default=16)
+    parser.add_argument("--n_steps", type=int, default=10)
     parser.add_argument("-b", "--batch_size", type=int, required=True)
     args = parser.parse_args()
 
@@ -41,8 +42,9 @@ def main():
 def benchmark_training(process_idx, args):
     tokenizer = BloomTokenizerFast.from_pretrained(args.model)
     model = DistributedBloomForSequenceClassification.from_pretrained(
-        args.model, initial_peers=args.initial_peers, tuning_mode="deep_ptune", pre_seq_len=16, num_labels=2)
-    optimizer = torch.optim.Adam(model.parameters())
+        args.model, initial_peers=args.initial_peers, tuning_mode="deep_ptune",
+        pre_seq_len=args.pre_seq_len, num_labels=2)
+    opt = torch.optim.Adam(model.parameters())
     logger.info(f"Created model: {process_idx=} {model.device=}")
 
     torch.manual_seed(42)