@@ -4,6 +4,7 @@ import argparse
 import multiprocessing as mp
 from time import perf_counter
 
+import numpy as np
 import torch
 import petals.client.sequential_autograd
 from hivemind.utils.logging import get_logger
@@ -48,27 +49,35 @@ def benchmark_training(process_idx, args):
     logger.info(f"Created model: {process_idx=} {model.device=}")
 
     torch.manual_seed(42)
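+    # Per-step forward/backward wall-clock times; step 0 serves as warm-up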
+    fwd_times = []
+    bwd_times = []
     for step in range(args.n_steps):
         input_ids = torch.randint(100, 10000, size=(args.batch_size, args.seq_len))
         labels = torch.randint(0, 2, size=[args.batch_size])
 
         logger.info(f"{process_idx=} {step=} Forward")
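+        # Time the forward and backward passes separately (the optimizer step is not timed)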
+        start_time = perf_counter()
         outputs = model(input_ids, labels=labels)
+        fwd_times.append(perf_counter() - start_time)
 
         logger.info(f"{process_idx=} {step=} Backward")
+        start_time = perf_counter()
         outputs.loss.backward()
+        bwd_times.append(perf_counter() - start_time)
 
         logger.info(f"{process_idx=} {step=} Optimizer step")
         opt.step()
         opt.zero_grad()
 
-        if step == 0:
-            start_time = perf_counter()
-        else:
-            speed = step / (perf_counter() - start_time) * input_ids.numel()
-            logger.info(f"{process_idx=} {step=} {speed=:.3f}")
+        if step >= 1:
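+            # Tokens/s, averaged over all steps after the warm-up step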
+            fwd_speed = input_ids.numel() / np.mean(fwd_times[1:])
+            bwd_speed = input_ids.numel() / np.mean(bwd_times[1:])
+            logger.info(f"{process_idx=} Fwd speed: {fwd_speed:.2f} | Bwd speed: {bwd_speed:.2f}")
 
-    logger.info(f"Final result: {process_idx=} {speed=:.3f}")
+    logger.info(f"Final result: {process_idx=} {fwd_speed=:.2f} | {bwd_speed=:.2f}")
 
 
 if __name__ == "__main__":
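
For context, the throughput computation this patch introduces, as a minimal standalone sketch (the tokens_per_second helper and the sample latencies are illustrative, not part of the benchmark):

import numpy as np

def tokens_per_second(step_times, tokens_per_step):
    # Exclude step 0 as warm-up: it pays one-time costs (e.g. establishing
    # connections to remote servers), which would skew the mean latency.
    return tokens_per_step / np.mean(step_times[1:])

# Example: batch_size=8, seq_len=128 -> 1024 tokens per step
fwd_times = [2.5, 0.9, 1.1, 1.0]  # step 0 is the slow warm-up step
print(f"{tokens_per_second(fwd_times, 8 * 128):.2f} tokens/s")  # 1024.00 tokens/s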