benchmark_optimizer.py

import multiprocessing as mp
import random
import time
from dataclasses import dataclass
from functools import partial
from typing import Callable

import numpy as np
import torch
import torchvision
from torch import nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset

import hivemind
from hivemind.optim.experimental.optimizer import Optimizer
from hivemind.utils.crypto import RSAPrivateKey
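
# Hyperparameters for the emulated swarm: each peer draws a local batch size from
# [batch_size_min, batch_size_max] and a mean per-batch latency from
# [batch_time_min, batch_time_max], so fast and slow peers must cooperate to reach
# target_batch_size between global optimizer steps.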
@dataclass(frozen=True)
class TrainingArguments:
    seed: int = 42
    prefix: str = "my_exp"
    num_peers: int = 8
    num_clients: int = 3
    target_batch_size: int = 128
    reuse_grad_buffers: bool = True
    lr_base: float = 0.1
    lr_gamma: float = 0.1
    lr_step_size: int = 10
    max_epoch: int = 25

    batch_size_min: int = 2
    batch_size_max: int = 16
    batch_time_min: float = 1.0
    batch_time_max: float = 4.5
    batch_time_std: float = 0.5
    matchmaking_time: float = 5.0
    max_refresh_period: float = 5.0
    averaging_timeout: float = 15.0
    winddown_time: float = 5.0
    verbose: bool = True

    device: str = "cuda:0" if torch.cuda.is_available() else "cpu"
    make_dataset: Callable[[], Dataset] = lambda: torchvision.datasets.MNIST(train=True, root=".", download=True)
    make_model: Callable[[int, int], nn.Module] = lambda num_features, num_classes: nn.Sequential(
        nn.Linear(num_features, 64), nn.ReLU(), nn.Linear(64, num_classes)
    )
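
# Launches the benchmark: one bootstrap DHT plus num_peers trainer processes that
# jointly accumulate gradients via hivemind.Optimizer; the last num_clients peers
# run in client mode (they participate without accepting inbound connections).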
def _run_training_with_swarm(args: TrainingArguments):
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.set_num_threads(1)

    dht = hivemind.DHT(start=True)  # root DHT node used to bootstrap all peers

    # load MNIST once in the parent process; trainer processes inherit the tensors
    train_dataset = args.make_dataset()
    num_features = int(np.prod(train_dataset.data[0].shape))
    num_classes = len(train_dataset.classes)
    X_train = torch.as_tensor(train_dataset.data, dtype=torch.float32)
    X_train = X_train.sub_(X_train.mean((0, 1, 2))).div_(X_train.std((0, 1, 2))).reshape((-1, num_features))
    y_train = torch.as_tensor(train_dataset.targets, dtype=torch.int64)
    del train_dataset
    def run_trainer(batch_size: int, batch_time: float, client_mode: bool, verbose: bool):
        model = args.make_model(num_features, num_classes).to(args.device)
        assert isinstance(model, torch.nn.Module), "make_model must evaluate to a pytorch module"

        optimizer = Optimizer(
            prefix=args.prefix,
            target_batch_size=args.target_batch_size,
            params=model.parameters(),
            optimizer=partial(torch.optim.SGD, lr=args.lr_base),
            scheduler=partial(torch.optim.lr_scheduler.StepLR, gamma=args.lr_gamma, step_size=args.lr_step_size),
            # each trainer runs its own DHT peer, bootstrapped from the root node above
            dht=hivemind.DHT(initial_peers=dht.get_visible_maddrs(), client_mode=client_mode, start=True),
            tracker_opts=dict(private_key=RSAPrivateKey(), max_refresh_period=args.max_refresh_period),
            matchmaking_time=args.matchmaking_time,
            averaging_timeout=args.averaging_timeout,
            reuse_grad_buffers=args.reuse_grad_buffers,
            client_mode=client_mode,
            verbose=verbose,
        )

        prev_time = time.perf_counter()
        # local_epoch advances after each global optimizer step, i.e. whenever the
        # swarm collectively accumulates target_batch_size samples
        while optimizer.local_epoch < args.max_epoch:
            # emulate this peer's compute latency before the next batch is ready
            time.sleep(max(0.0, prev_time + random.gauss(batch_time, args.batch_time_std) - time.perf_counter()))

            batch = torch.randint(0, len(X_train), (batch_size,))
            loss = F.cross_entropy(model(X_train[batch].to(args.device)), y_train[batch].to(args.device))
            loss.backward()

            optimizer.step(batch_size=batch_size)
            if not args.reuse_grad_buffers:
                optimizer.zero_grad()
            prev_time = time.perf_counter()

        time.sleep(args.winddown_time)
        optimizer.shutdown()
    # NB: run_trainer is a closure over X_train/y_train, so this relies on the
    # "fork" start method (the default on Linux); closures cannot be pickled for "spawn"
    peers = []
    for index in range(args.num_peers):
        batch_size = random.randint(args.batch_size_min, args.batch_size_max)
        batch_time = random.uniform(args.batch_time_min, args.batch_time_max)
        peers.append(
            mp.Process(
                target=run_trainer,
                name=f"trainer-{index}",
                kwargs=dict(
                    batch_size=batch_size,
                    batch_time=batch_time,
                    # the last num_clients peers participate in client mode
                    client_mode=(index >= args.num_peers - args.num_clients),
                    # only the first peer logs verbosely to keep output readable
                    verbose=args.verbose and (index == 0),
                ),
            )
        )

    for peer in peers[1:]:
        peer.start()
    peers[0].run()  # run the first peer in the current process so its output is visible
    for peer in peers[1:]:
        peer.join()
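
if __name__ == "__main__":
    # example entry point: launch the benchmark with its default arguments
    _run_training_with_swarm(TrainingArguments())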