test_optimizer.py

import ctypes
import multiprocessing as mp
import random
import time
from dataclasses import dataclass
from functools import partial
from typing import Callable

import numpy as np
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import Dataset

import hivemind
from hivemind.averaging.control import AveragingStage
from hivemind.optim.experimental.grad_averager import GradientAverager
from hivemind.optim.experimental.optimizer import Optimizer
from hivemind.optim.experimental.progress_tracker import ProgressTracker
from hivemind.optim.experimental.state_averager import TrainingStateAverager
from hivemind.utils.crypto import RSAPrivateKey


@pytest.mark.forked
def test_grad_averager():
    dht1 = hivemind.DHT(start=True)
    model1 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(3))})
    averager1 = GradientAverager(
        model1.parameters(), dht=dht1, prefix="test", target_group_size=2, reuse_grad_buffers=False, start=True
    )

    dht2 = hivemind.DHT(start=True, initial_peers=dht1.get_visible_maddrs())
    model2 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(3))})
    averager2 = GradientAverager(
        model2.parameters(), dht=dht2, prefix="test", target_group_size=2, reuse_grad_buffers=True, start=True
    )

    control1 = averager1.schedule_step(hivemind.get_dht_time() + 5)
    control2 = averager2.schedule_step(hivemind.get_dht_time() + 5)

    for i in range(10):
        time.sleep(0.1)
        if i % 3 == 0:
            loss1 = F.mse_loss(model1.w, torch.ones(3))
            loss1.backward()
            averager1.accumulate_grads_(batch_size=2)  # total: 4 times * 2 samples = 8
            model1.zero_grad()
        else:
            loss2 = F.mse_loss(model2.w, -torch.ones(3))
            loss2.backward()
            averager2.accumulate_grads_(batch_size=3)  # total: 6 times * 3 samples = 18
            # note: we do not zero grads here because reuse_grad_buffers=True

    assert control1.stage == control2.stage == AveragingStage.AWAITING_TRIGGER

    peer1_samples, peer1_times, peer2_samples, peer2_times = 8, 4, 18, 6
    assert averager1.local_samples_accumulated == peer1_samples and averager1.local_times_accumulated == peer1_times
    ref_grads1 = torch.full((3,), -2 * 1 / 3 * averager1.local_times_accumulated)
    assert torch.allclose(next(averager1._grad_accumulators()), ref_grads1)

    assert averager2.local_samples_accumulated == peer2_samples and averager2.local_times_accumulated == peer2_times
    ref_grads2 = torch.full((3,), 2 * 1 / 3 * averager2.local_times_accumulated)
    assert torch.allclose(next(averager2._grad_accumulators()), ref_grads2)

    averager1.step(control=control1, wait=False)
    averager2.step(control=control2, wait=False)
    for step in (control1, control2):
        step.result()  # wait for all-reduce to finish

    peer1_weight = peer1_samples / (peer1_samples + peer2_samples)
    peer2_weight = peer2_samples / (peer1_samples + peer2_samples)
    ref_average = peer1_weight * (ref_grads1 / peer1_times) + peer2_weight * (ref_grads2 / peer2_times)
    with averager1.use_averaged_gradients():
        assert torch.allclose(model1.w.grad, ref_average)
    with averager2.use_averaged_gradients():
        assert torch.allclose(model2.w.grad, ref_average)

    # once the use_averaged_gradients context exits, the original local gradients are restored
    assert not torch.allclose(model1.w.grad, ref_average)
    assert not torch.allclose(model2.w.grad, ref_average)
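

# Worked arithmetic behind the reference values above (an illustrative sketch, not used by the test): with w == 0,
# each backward on model1 yields d/dw mean((w - 1) ** 2) = 2 * (w - 1) / 3 = -2/3 per element, so four accumulations
# give -8/3 (ref_grads1); model2 targets -ones(3), so each backward adds +2/3 and six accumulations give +4
# (ref_grads2). The hypothetical helper below merely restates the sample-weighted average that ref_average encodes.
def _sample_weighted_average(
    samples1: int = 8, grad1: float = -2 / 3, samples2: int = 18, grad2: float = 2 / 3
) -> float:
    """Return (8 / 26) * (-2 / 3) + (18 / 26) * (2 / 3) = 20 / 78, the per-element value expected in ref_average."""
    total = samples1 + samples2
    return (samples1 / total) * grad1 + (samples2 / total) * grad2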


@pytest.mark.forked
@pytest.mark.parametrize(
    "offload_optimizer, reuse_tensors, sync_epoch_when_averaging",
    [(False, False, False), (True, False, False), (False, True, True), (True, False, True)],
)
def test_state_averager(offload_optimizer: bool, reuse_tensors: bool, sync_epoch_when_averaging: bool):
    dht1 = hivemind.DHT(start=True)
    dht2 = hivemind.DHT(initial_peers=dht1.get_visible_maddrs(), start=True)

    torch.manual_seed(1337)
    torch.use_deterministic_algorithms(True)
    # note: use_deterministic_algorithms does not affect further tests because this test is forked

    model1 = nn.Linear(2, 3)
    model2 = nn.Linear(2, 3)
    extras1 = (torch.randn(2, 2), -torch.rand(1))
    extras2 = (-torch.randn(2, 2), torch.rand(1))

    common_kwargs = dict(
        optimizer=partial(torch.optim.Adam, lr=0.1, betas=(0.9, 0.9)),
        scheduler=partial(torch.optim.lr_scheduler.LambdaLR, lr_lambda=lambda t: 1.0 / max(1, t)),
        sync_epoch_when_averaging=sync_epoch_when_averaging,
        average_opt_statistics=("exp_avg_sq",),
        offload_optimizer=offload_optimizer,
        reuse_tensors=reuse_tensors,
        target_group_size=2,
        prefix="my_exp",
    )

    avgr1 = TrainingStateAverager(
        dht=dht1, param_groups=model1.parameters(), extra_tensors=extras1, start=True, **common_kwargs
    )
    avgr2 = TrainingStateAverager(
        dht=dht2, param_groups=model2.parameters(), extra_tensors=extras2, start=True, **common_kwargs
    )

    x = torch.ones(2)
    for step in range(20):
        F.mse_loss(model1(x), torch.ones(3)).mul(2).backward()
        avgr1.step(optimizer_step=True, zero_grad=True, averaging_round=(step == 10), delay_averaging=True)

        F.mse_loss(model2(x), -torch.ones(3)).backward()
        avgr2.step(optimizer_step=True, zero_grad=True, averaging_round=(step == 10), delay_averaging=False)

    assert torch.all(model1.weight.grad == 0) and torch.all(model2.weight.grad == 0), "zero grad did not trigger"
    assert model1(x).mean() > 0.5 and model2(x).mean() < -0.5, "models did not train properly"
    assert torch.allclose(extras1[0], extras2[0]), "first extra tensors were not averaged"
    assert torch.allclose(extras1[1], extras2[1]), "second extra tensors were not averaged"

    stats1 = avgr1.optimizer.state_dict()["state"][0]["exp_avg_sq"].clone()
    stats2 = avgr2.optimizer.state_dict()["state"][0]["exp_avg_sq"].clone()
    assert not torch.allclose(stats1, stats2)

    avgr1.step(increment_epoch=True)

    avgr1.step(increment_epoch=True, averaging_round=True, delay_averaging=True)
    avgr2.step(increment_epoch=True, averaging_round=True, delay_averaging=True)

    avgr1.step(wait_for_delayed_update=True)
    avgr2.step(wait_for_delayed_update=True)

    assert torch.allclose(model1(x), model2(x)), "model parameters were not averaged correctly"
    assert torch.allclose(avgr1.optimizer.state_dict()["state"][0]["exp_avg_sq"], (stats1 + stats2) / 2)
    assert torch.allclose(avgr2.optimizer.state_dict()["state"][0]["exp_avg_sq"], (stats1 + stats2) / 2)
    assert avgr1.local_epoch == 2
    assert avgr2.local_epoch == (2 if sync_epoch_when_averaging else 1)
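

# Reader's note on the final assertions above: avgr1 increments its epoch twice while avgr2 increments it only once,
# so avgr1 ends at local_epoch == 2 in every case. With sync_epoch_when_averaging=True, the shared averaging round
# is expected to pull avgr2 up to the greatest epoch in the group (2); otherwise it keeps its own count (1).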


@pytest.mark.forked
def test_load_state_from_peers():
    dht1 = hivemind.DHT(start=True)
    dht2 = hivemind.DHT(initial_peers=dht1.get_visible_maddrs(), start=True)

    model1 = nn.Linear(2, 3)
    model2 = nn.Linear(2, 3)

    common_kwargs = dict(
        optimizer=partial(torch.optim.SGD, lr=0.1),
        scheduler=partial(torch.optim.lr_scheduler.LambdaLR, lr_lambda=lambda t: 1.0 / max(1, t)),
        target_group_size=2,
        prefix="my_exp",
    )

    avgr1 = TrainingStateAverager(
        dht=dht1, param_groups=model1.parameters(), allow_state_sharing=False, start=True, **common_kwargs
    )
    avgr2 = TrainingStateAverager(dht=dht2, param_groups=model2.parameters(), start=True, **common_kwargs)

    avgr2.local_epoch = 1337
    model2.weight.data[...] = 42
    time.sleep(0.1)

    avgr1.load_state_from_peers()
    assert avgr1.local_epoch == 1337
    assert torch.all(model1.weight == 42).item()
    assert np.allclose(avgr1.optimizer.param_groups[0]["lr"], 0.1 / 1337)
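

# The learning-rate assertion follows from the LambdaLR schedule defined above: after avgr1 adopts epoch 1337 from
# its peer, the scheduler scales the base lr of 0.1 by 1 / max(1, 1337), hence the expected value 0.1 / 1337.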


@pytest.mark.forked
def test_progress_tracker():
    # note to a curious reader: no, you cannot reduce the timings without compromising realism or stability
    prefix = "my_exp"
    target_batch_size = 256
    dht_root = hivemind.DHT(start=True)
    barrier = mp.Barrier(parties=5)
    delayed_start_evt = mp.Event()
    finished_evt = mp.Event()
    emas = mp.Array(ctypes.c_double, 5)

    def run_worker(index: int, batch_size: int, period: float, **kwargs):
        dht = hivemind.DHT(initial_peers=dht_root.get_visible_maddrs(), start=True)
        tracker = ProgressTracker(
            dht,
            prefix,
            target_batch_size,
            start=True,
            min_refresh_period=0.1,
            default_refresh_period=0.2,
            max_refresh_period=0.5,
            private_key=RSAPrivateKey(),
            **kwargs,
        )

        barrier.wait()
        if index == 4:
            delayed_start_evt.wait()

        local_epoch = 2 if index == 4 else 0
        samples_accumulated = 0

        while True:
            time.sleep(period)
            if finished_evt.is_set():
                break

            samples_accumulated += batch_size
            tracker.report_local_progress(local_epoch, samples_accumulated)

            if tracker.ready_to_update_epoch:
                with tracker.pause_updates():
                    local_epoch = tracker.update_epoch(local_epoch + 1)
                    samples_accumulated = 0

                if index == 4 and local_epoch >= 5:
                    time.sleep(0.5)
                    break

        emas[index] = tracker.performance_ema.samples_per_second
        tracker.shutdown()
        dht.shutdown()

    workers = [
        mp.Process(target=run_worker, kwargs=dict(index=1, batch_size=12, period=0.6)),
        mp.Process(target=run_worker, kwargs=dict(index=2, batch_size=16, period=0.5)),
        mp.Process(target=run_worker, kwargs=dict(index=3, batch_size=24, period=0.4)),
        mp.Process(target=run_worker, kwargs=dict(index=4, batch_size=64, period=0.4)),
    ]
    for worker in workers:
        worker.start()

    tracker = ProgressTracker(
        dht_root,
        prefix,
        target_batch_size,
        start=True,
        min_refresh_period=0.1,
        default_refresh_period=0.2,
        max_refresh_period=0.5,
    )
    barrier.wait()

    current_step = 0
    last_timestamp = hivemind.get_dht_time()
    step_time_deltas = []

    while current_step < 6:
        time.sleep(0.1)

        if tracker.global_progress.epoch > current_step:
            time_delta = hivemind.get_dht_time() - last_timestamp
            current_step = tracker.global_progress.epoch
            if current_step == 2:
                delayed_start_evt.set()

            last_timestamp = hivemind.get_dht_time()
            step_time_deltas.append(time_delta)

    finished_evt.set()
    for worker in workers:
        worker.join()

    tracker.shutdown()
    dht_root.shutdown()
    assert not tracker.is_alive()

    mean_step_time = sum(step_time_deltas) / len(step_time_deltas)
    for i in (0, 1, 5):  # without the 4th worker (the fastest one)
        assert 1.05 * mean_step_time < step_time_deltas[i] < 2.0 * mean_step_time
    for i in (2, 3, 4):  # with the 4th worker
        assert 0.5 * mean_step_time < step_time_deltas[i] < 0.95 * mean_step_time
    assert emas[1] < emas[2] < emas[3] < emas[4]
    assert tracker.performance_ema.samples_per_second < 1e-9
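

# Reader's note on the timing assertions above: the barrier admits the four worker processes plus this main process
# (parties=5). The global epoch advances once the reported samples reach target_batch_size; worker 4 only joins
# after delayed_start_evt is set at epoch 2 and leaves once it reaches epoch 5, so step_time_deltas[2:5] should be
# noticeably shorter than deltas 0, 1 and 5. Workers with smaller batches and longer periods report lower
# samples-per-second EMAs, and the main process never reports local progress, so its own EMA stays near zero.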


@dataclass(frozen=True)
class TrainingArguments:
    seed: int = 42
    prefix: str = "my_exp"

    num_peers: int = 8
    num_clients: int = 3
    target_batch_size: int = 128
    reuse_grad_buffers: bool = True

    lr_base: float = 0.1
    lr_gamma: float = 0.1
    lr_step_size: int = 10
    max_epoch: int = 25

    batch_size_min: int = 2
    batch_size_max: int = 16
    batch_time_min: float = 1.0
    batch_time_max: float = 4.5
    batch_time_std: float = 0.5

    matchmaking_time: float = 5.0
    max_refresh_period: float = 5.0
    averaging_timeout: float = 15.0
    winddown_time: float = 5.0
    verbose: bool = True

    device: str = "cuda:0" if torch.cuda.is_available() else "cpu"
    make_dataset: Callable[[], Dataset] = lambda: torchvision.datasets.MNIST(train=True, root=".", download=True)
    make_model: Callable[[int, int], nn.Module] = lambda num_features, num_classes: nn.Sequential(
        nn.Linear(num_features, 64), nn.ReLU(), nn.Linear(64, num_classes)
    )
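

# Reader's note: with these defaults a swarm consists of num_peers=8 trainers, and _run_training_with_swarm below
# starts the last num_clients=3 of them with client_mode=True (in hivemind, client-mode peers participate in
# training but do not accept incoming connections).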


def _run_training_with_swarm(args: TrainingArguments):
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.set_num_threads(1)

    dht = hivemind.DHT(start=True)

    train_dataset = args.make_dataset()
    num_features = np.prod(train_dataset.data[0].shape)
    num_classes = len(train_dataset.classes)
    X_train = torch.as_tensor(train_dataset.data, dtype=torch.float32)
    X_train = X_train.sub_(X_train.mean((0, 1, 2))).div_(X_train.std((0, 1, 2))).reshape((-1, num_features))
    y_train = torch.as_tensor(train_dataset.targets, dtype=torch.int64)
    del train_dataset

    def run_trainer(batch_size: int, batch_time: float, client_mode: bool, verbose: bool):
        model = args.make_model(num_features, num_classes).to(args.device)
        assert isinstance(model, torch.nn.Module), "make_model must evaluate to a pytorch module"

        optimizer = Optimizer(
            prefix=args.prefix,
            target_batch_size=args.target_batch_size,
            param_groups=model.parameters(),
            optimizer=partial(torch.optim.SGD, lr=args.lr_base),
            scheduler=partial(torch.optim.lr_scheduler.StepLR, gamma=args.lr_gamma, step_size=args.lr_step_size),
            dht=hivemind.DHT(initial_peers=dht.get_visible_maddrs(), client_mode=client_mode, start=True),
            tracker_opts=dict(private_key=RSAPrivateKey(), max_refresh_period=args.max_refresh_period),
            matchmaking_time=args.matchmaking_time,
            averaging_timeout=args.averaging_timeout,
            reuse_grad_buffers=args.reuse_grad_buffers,
            client_mode=client_mode,
            verbose=verbose,
        )

        prev_time = time.perf_counter()
        while optimizer.local_epoch < args.max_epoch:
            time.sleep(max(0.0, prev_time + random.gauss(batch_time, args.batch_time_std) - time.perf_counter()))

            batch = torch.randint(0, len(X_train), (batch_size,))
            loss = F.cross_entropy(model(X_train[batch]), y_train[batch])
            loss.backward()

            optimizer.step(batch_size=batch_size)
            if not args.reuse_grad_buffers:
                optimizer.zero_grad()

            prev_time = time.perf_counter()

        time.sleep(args.winddown_time)
        optimizer.shutdown()

    peers = []
    for index in range(args.num_peers):
        batch_size = random.randint(args.batch_size_min, args.batch_size_max)
        batch_time = random.uniform(args.batch_time_min, args.batch_time_max)
        peers.append(
            mp.Process(
                target=run_trainer,
                name=f"trainer-{index}",
                kwargs=dict(
                    batch_size=batch_size,
                    batch_time=batch_time,
                    client_mode=(index >= args.num_peers - args.num_clients),
                    verbose=args.verbose and (index == 0),
                ),
            )
        )

    for peer in peers[1:]:
        peer.start()
    peers[0].run()  # run the first (verbose) trainer in the current process; the remaining peers run as child processes
    for peer in peers[1:]:
        peer.join()
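

# Reader's note: optimizer.local_epoch above counts collaborative steps, i.e. hivemind.Optimizer applies a global
# update once the swarm as a whole has accumulated target_batch_size samples, so each trainer exits its loop after
# the swarm completes args.max_epoch such steps and then sleeps for winddown_time before shutting down.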


@pytest.mark.forked
def test_training_short():
    _run_training_with_swarm(
        TrainingArguments(
            num_peers=3,
            num_clients=1,
            target_batch_size=32,
            max_epoch=3,
            lr_step_size=1,
        )
    )