|
@@ -70,8 +70,8 @@ def benchmark_training(process_idx, args):
|
|
|
opt.zero_grad()
|
|
|
|
|
|
if step >= 1:
|
|
|
- fwd_speed = step / np.mean(fwd_times[1:]) * input_ids.numel()
|
|
|
- bwd_speed = step / np.mean(bwd_times[1:]) * input_ids.numel()
|
|
|
+ fwd_speed = input_ids.numel() / np.mean(fwd_times[1:])
|
|
|
+ bwd_speed = input_ids.numel() / np.mean(bwd_times[1:])
|
|
|
logger.info(f"{process_idx=} Fwd speed: {fwd_speed:.2f} | Bwd speed: {bwd_speed:.2f}")
|
|
|
|
|
|
logger.info(f"Final result: {process_idx=} {fwd_speed=:.2f} | {bwd_speed=:.2f}")
|