benchmark_optimizer.py

import multiprocessing as mp
import random
import time
from contextlib import nullcontext
from dataclasses import dataclass
from functools import partial
from typing import Callable

import numpy as np
import torch
import torchvision
from torch import nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset

import hivemind
from hivemind.optim.experimental.optimizer import Optimizer
from hivemind.utils.crypto import RSAPrivateKey

@dataclass(frozen=True)
class TrainingArguments:
    seed: int = 42
    prefix: str = "my_exp"

    num_peers: int = 8
    num_clients: int = 3
    target_batch_size: int = 128
    reuse_grad_buffers: bool = True
    delay_optimizer_step: bool = False
    use_amp: bool = False  # mixed precision requires a CUDA device; keep disabled for the default device="cpu"

    lr_base: float = 0.1
    lr_gamma: float = 0.1
    lr_step_size: int = 10
    max_epoch: int = 25

    batch_size_min: int = 2
    batch_size_max: int = 16
    batch_time_min: float = 1.0
    batch_time_max: float = 4.5
    batch_time_std: float = 0.5

    matchmaking_time: float = 5.0
    max_refresh_period: float = 5.0
    averaging_timeout: float = 15.0
    winddown_time: float = 5.0
    verbose: bool = True

    device: str = "cpu"
    make_dataset: Callable[[], Dataset] = lambda: torchvision.datasets.MNIST(train=True, root=".", download=True)
    make_model: Callable[[int, int], nn.Module] = lambda num_features, num_classes: nn.Sequential(
        nn.Linear(num_features, 64), nn.ReLU(), nn.Linear(64, num_classes)
    )

def benchmark_optimizer(args: TrainingArguments):
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.set_num_threads(1)

    # a standalone DHT node that serves as the initial peer for every trainer process
    dht = hivemind.DHT(start=True)

    train_dataset = args.make_dataset()
    num_features = np.prod(train_dataset.data[0].shape)
    num_classes = len(train_dataset.classes)

    # standardize pixel intensities and flatten each image into a feature vector
    X_train = torch.as_tensor(train_dataset.data, dtype=torch.float32)
    X_train = X_train.sub_(X_train.mean((0, 1, 2))).div_(X_train.std((0, 1, 2))).reshape((-1, num_features))
    y_train = torch.as_tensor(train_dataset.targets, dtype=torch.int64)
    del train_dataset

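    # Each simulated peer runs this function in its own process: it builds a fresh copy of the model,
    # connects a new DHT node to the root node above, and wraps local SGD in hivemind's collaborative
    # Optimizer so that gradients are accumulated and averaged across the whole swarm.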
    def run_trainer(batch_size: int, batch_time: float, client_mode: bool, verbose: bool):
        model = args.make_model(num_features, num_classes).to(args.device)
        assert isinstance(model, torch.nn.Module), "make_model must evaluate to a pytorch module"

        optimizer = Optimizer(
            prefix=args.prefix,
            target_batch_size=args.target_batch_size,
            batch_size_per_step=batch_size,
            params=model.parameters(),
            optimizer=partial(torch.optim.SGD, lr=args.lr_base),
            scheduler=partial(torch.optim.lr_scheduler.StepLR, gamma=args.lr_gamma, step_size=args.lr_step_size),
            dht=hivemind.DHT(initial_peers=dht.get_visible_maddrs(), client_mode=client_mode, start=True),
            tracker_opts=dict(private_key=RSAPrivateKey(), max_refresh_period=args.max_refresh_period),
            matchmaking_time=args.matchmaking_time,
            averaging_timeout=args.averaging_timeout,
            reuse_grad_buffers=args.reuse_grad_buffers,
            delay_optimizer_step=args.delay_optimizer_step,
            client_mode=client_mode,
            verbose=verbose,
        )
        # use hivemind's AMP-aware GradScaler when mixed precision is combined with reused grad buffers;
        # otherwise fall back to a plain torch scaler (which is a no-op when AMP is disabled)
        if args.use_amp and args.reuse_grad_buffers:
            grad_scaler = hivemind.GradScaler()
        else:
            grad_scaler = torch.cuda.amp.GradScaler(enabled=args.use_amp)
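        # Simulated training loop: sleep to emulate this peer's per-batch compute time, then run a
        # forward/backward pass on a random minibatch. The collaborative optimizer only advances
        # local_epoch once the swarm has jointly accumulated target_batch_size samples.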
        prev_time = time.perf_counter()

        while optimizer.local_epoch < args.max_epoch:
            time.sleep(max(0.0, prev_time + random.gauss(batch_time, args.batch_time_std) - time.perf_counter()))

            batch = torch.randint(0, len(X_train), (batch_size,))

            with torch.cuda.amp.autocast() if args.use_amp else nullcontext():
                loss = F.cross_entropy(model(X_train[batch].to(args.device)), y_train[batch].to(args.device))
                grad_scaler.scale(loss).backward()

            grad_scaler.unscale_(optimizer)

            if args.use_amp:
                grad_scaler.step(optimizer)
            else:
                optimizer.step()

            grad_scaler.update()

            if not args.reuse_grad_buffers:
                optimizer.zero_grad()

            prev_time = time.perf_counter()

        time.sleep(args.winddown_time)
        optimizer.shutdown()

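    # Launch num_peers trainer processes with randomized batch sizes and per-batch latencies;
    # the last num_clients of them run in client mode (they contribute gradients but do not
    # accept incoming connections), and only the first peer logs verbosely.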
    peers = []

    for index in range(args.num_peers):
        batch_size = random.randint(args.batch_size_min, args.batch_size_max)
        batch_time = random.uniform(args.batch_time_min, args.batch_time_max)
        peers.append(
            mp.Process(
                target=run_trainer,
                name=f"trainer-{index}",
                daemon=False,
                kwargs=dict(
                    batch_size=batch_size,
                    batch_time=batch_time,
                    client_mode=(index >= args.num_peers - args.num_clients),
                    verbose=args.verbose and (index == 0),
                ),
            )
        )
    try:
        for peer in peers[1:]:
            peer.start()
        peers[0].run()
        for peer in peers[1:]:
            peer.join()
    finally:
        for peer in peers[1:]:
            peer.kill()
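
# A minimal entry point, assuming this benchmark is meant to be launched directly as a script
# (in the hivemind repository it may instead be invoked from a test harness); it simply runs
# the benchmark with the default TrainingArguments.
if __name__ == "__main__":
    benchmark_optimizer(TrainingArguments())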