@@ -20,9 +20,10 @@ def main():
     parser.add_argument("--model", type=str, default="bigscience/bloom-petals")
     parser.add_argument("-i", "--initial_peers", type=str, nargs='+',
                         default=["/dns/bench.petals.ml/tcp/31337/p2p/QmehSoMKScoMF3HczLwaLVnw2Lgsap4bhAMrULEzGc1fSV"])
-    parser.add_argument("-p", "--n_processes", type=str, default="1")
+    parser.add_argument("--n_processes", type=str, default="1")
     parser.add_argument("--seq_len", type=int, default=128)
-    parser.add_argument("--n_steps", type=int, default=100)
+    parser.add_argument("--pre_seq_len", type=int, default=16)
+    parser.add_argument("--n_steps", type=int, default=10)
     parser.add_argument("-b", "--batch_size", type=int, required=True)
     args = parser.parse_args()

@@ -41,8 +42,9 @@ def main():
 def benchmark_training(process_idx, args):
     tokenizer = BloomTokenizerFast.from_pretrained(args.model)
     model = DistributedBloomForSequenceClassification.from_pretrained(
-        args.model, initial_peers=args.initial_peers, tuning_mode="deep_ptune", pre_seq_len=16, num_labels=2)
-    optimizer = torch.optim.Adam(model.parameters())
+        args.model, initial_peers=args.initial_peers, tuning_mode="deep_ptune",
+        pre_seq_len=args.pre_seq_len, num_labels=2)
+    opt = torch.optim.Adam(model.parameters())
     logger.info(f"Created model: {process_idx=} {model.device=}")

     torch.manual_seed(42)
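For reference, a minimal sketch of invoking the updated benchmark, assuming the hunks above belong to a script named benchmark_training.py (the file name is an assumption; the flags and defaults come straight from the diff, and only --batch_size is required):

    python benchmark_training.py --model bigscience/bloom-petals \
        --n_processes 2 --seq_len 128 --pre_seq_len 16 --n_steps 10 --batch_size 8

Note that --pre_seq_len is now a command-line flag instead of the hardcoded pre_seq_len=16, so the number of trainable prefix tokens for deep p-tuning can be varied between runs without editing the script.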