test_optimizer.py

import ctypes
import multiprocessing as mp
import time
from functools import partial

import numpy as np
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F

import hivemind
from hivemind.averaging.control import AveragingStage
from hivemind.optim.grad_averager import GradientAverager, GradientAveragerFactory
from hivemind.optim.optimizer import Optimizer
from hivemind.optim.power_sgd_averager import PowerSGDGradientAverager
from hivemind.optim.progress_tracker import ProgressTracker
from hivemind.optim.state_averager import ZERO_GRAD_SET_TO_NONE_DEFAULT, TrainingStateAverager
from hivemind.utils.crypto import RSAPrivateKey
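

# Two peers accumulate gradients toward opposite targets (one with reuse_grad_buffers=False, one
# with reuse_grad_buffers=True), then run a single scheduled all-reduce. The result is checked
# against the sample-weighted average of each peer's accumulated gradients.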
@pytest.mark.forked
@pytest.mark.parametrize(
    "grad_averager_factory",
    [GradientAverager, partial(PowerSGDGradientAverager, averager_rank=1)],
)
def test_grad_averager(grad_averager_factory: GradientAveragerFactory):
    parameter_shape = (5, 5)

    dht1 = hivemind.DHT(start=True)
    model1 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(parameter_shape))})
    averager1 = grad_averager_factory(
        model1.parameters(), dht=dht1, prefix="test", target_group_size=2, reuse_grad_buffers=False, start=True
    )

    dht2 = hivemind.DHT(start=True, initial_peers=dht1.get_visible_maddrs())
    model2 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(parameter_shape))})
    averager2 = grad_averager_factory(
        model2.parameters(), dht=dht2, prefix="test", target_group_size=2, reuse_grad_buffers=True, start=True
    )

    control1 = averager1.schedule_step(hivemind.get_dht_time() + 5)
    control2 = averager2.schedule_step(hivemind.get_dht_time() + 5)

    for i in range(10):
        time.sleep(0.1)
        if i % 3 == 0:
            loss1 = F.mse_loss(model1.w, torch.ones(parameter_shape))
            loss1.backward()
            averager1.accumulate_grads_(batch_size=2)  # total: 4 times * 2 samples = 8
            model1.zero_grad()
        else:
            loss2 = F.mse_loss(model2.w, -torch.ones(parameter_shape))
            loss2.backward()
            averager2.accumulate_grads_(batch_size=3)  # total: 6 times * 3 samples = 18
            # note: we do not call zero grad here because reuse_grad_buffers=True

    assert control1.stage == control2.stage == AveragingStage.AWAITING_TRIGGER
    peer1_samples, peer1_times, peer2_samples, peer2_times = 8, 4, 18, 6
    assert averager1.local_samples_accumulated == peer1_samples and averager1.local_times_accumulated == peer1_times
    ref_grads1 = torch.full(parameter_shape, -2 / np.prod(parameter_shape) * averager1.local_times_accumulated)
    assert torch.allclose(next(averager1._grad_accumulators()), ref_grads1)

    assert averager2.local_samples_accumulated == peer2_samples and averager2.local_times_accumulated == peer2_times
    ref_grads2 = torch.full(parameter_shape, 2 / np.prod(parameter_shape) * averager2.local_times_accumulated)
    assert torch.allclose(next(averager2._grad_accumulators()), ref_grads2)

    averager1.step(control=control1, wait=False)
    averager2.step(control=control2, wait=False)
    for step in (control1, control2):
        step.result()  # wait for all-reduce to finish

    peer1_weight = peer1_samples / (peer1_samples + peer2_samples)
    peer2_weight = peer2_samples / (peer1_samples + peer2_samples)
    ref_average = peer1_weight * (ref_grads1 / peer1_times) + peer2_weight * (ref_grads2 / peer2_times)
    with averager1.use_averaged_gradients():
        assert torch.allclose(model1.w.grad, ref_average)
    with averager2.use_averaged_gradients():
        assert torch.allclose(model2.w.grad, ref_average)

    # once the use_averaged_gradients() contexts are exited, averaged gradients should no longer be exposed
    if ZERO_GRAD_SET_TO_NONE_DEFAULT:  # averager1 has reuse_grad_buffers=False
        assert model1.w.grad is None
    else:
        assert not torch.allclose(model1.w.grad, ref_average)
    assert not torch.allclose(model2.w.grad, ref_average)  # averager2 has reuse_grad_buffers=True
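

# Constructing a gradient averager whose averaged_grads shapes do not match the parameters
# should raise a ValueError.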
@pytest.mark.forked
@pytest.mark.parametrize(
    "grad_averager_factory",
    [GradientAverager, partial(PowerSGDGradientAverager, averager_rank=1)],
)
def test_grad_averager_wrong_shape(grad_averager_factory: GradientAveragerFactory):
    parameter_shape = (5, 5)
    model = nn.ParameterDict({"w": nn.Parameter(torch.zeros(parameter_shape))})
    dht = hivemind.DHT(start=True)

    with pytest.raises(ValueError):
        grad_averager_factory(
            model.parameters(),
            dht=dht,
            prefix="test_fail",
            target_group_size=2,
            reuse_grad_buffers=False,
            start=True,
            averaged_grads=[torch.zeros(parameter_shape + (1,))],
        )
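

# Two TrainingStateAverager peers train toward opposite targets; an averaging round must average
# the parameters, the selected optimizer statistic ("exp_avg_sq"), and the extra tensors, and
# (optionally) synchronize local epochs when sync_epoch_when_averaging is enabled.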
@pytest.mark.forked
@pytest.mark.parametrize(
    "offload_optimizer, reuse_tensors, sync_epoch_when_averaging",
    [(False, False, False), (True, True, False), (True, False, False), (False, True, True), (True, False, True)],
)
def test_state_averager(offload_optimizer: bool, reuse_tensors: bool, sync_epoch_when_averaging: bool):
    dht1 = hivemind.DHT(start=True)
    dht2 = hivemind.DHT(initial_peers=dht1.get_visible_maddrs(), start=True)

    torch.manual_seed(1337)
    torch.use_deterministic_algorithms(True)
    # note: use_deterministic_algorithms does not affect further tests because this test is forked

    model1 = nn.Linear(2, 3)
    model2 = nn.Linear(2, 3)
    extras1 = (torch.randn(2, 2), -torch.rand(1))
    extras2 = (-torch.randn(2, 2), torch.rand(1))

    common_kwargs = dict(
        optimizer=partial(torch.optim.Adam, lr=0.1, betas=(0.9, 0.9)),
        scheduler=partial(torch.optim.lr_scheduler.LambdaLR, lr_lambda=lambda t: 1.0 / max(1, t)),
        sync_epoch_when_averaging=sync_epoch_when_averaging,
        average_opt_statistics=("exp_avg_sq",),
        offload_optimizer=offload_optimizer,
        reuse_tensors=reuse_tensors,
        target_group_size=2,
        prefix="my_exp",
    )

    avgr1 = TrainingStateAverager(
        dht=dht1, params=model1.parameters(), extra_tensors=extras1, start=True, **common_kwargs
    )
    avgr2 = TrainingStateAverager(
        dht=dht2, params=model2.parameters(), extra_tensors=extras2, start=True, **common_kwargs
    )

    x = torch.ones(2)

    for step in range(20):
        F.mse_loss(model1(x), torch.ones(3)).mul(2).backward()
        avgr1.step(optimizer_step=True, zero_grad=True, averaging_round=(step == 10), delay_averaging=True)

        F.mse_loss(model2(x), -torch.ones(3)).backward()
        avgr2.step(optimizer_step=True, zero_grad=True, averaging_round=(step == 10), delay_averaging=False)

    if ZERO_GRAD_SET_TO_NONE_DEFAULT:
        assert model1.weight.grad is None and model2.weight.grad is None, ".zero_grad() wasn't called"
    else:
        assert torch.all(model1.weight.grad == 0) and torch.all(model2.weight.grad == 0), ".zero_grad() wasn't called"

    assert model1(x).mean() > 0.5 and model2(x).mean() < -0.5, "models did not train properly"
    assert torch.allclose(extras1[0], extras2[0]), "first extra tensors were not averaged"
    assert torch.allclose(extras1[1], extras2[1]), "second extra tensors were not averaged"

    stats1 = avgr1.optimizer.state_dict()["state"][0]["exp_avg_sq"].clone()
    stats2 = avgr2.optimizer.state_dict()["state"][0]["exp_avg_sq"].clone()
    assert not torch.allclose(stats1, stats2)

    avgr1.step(increment_epoch=True)

    avgr1.step(increment_epoch=True, averaging_round=True, delay_averaging=True)
    avgr2.step(increment_epoch=True, averaging_round=True, delay_averaging=True)

    avgr1.step(wait_for_delayed_updates=True)
    avgr2.step(wait_for_delayed_updates=True)

    assert torch.allclose(model1(x), model2(x)), "model parameters were not averaged correctly"
    assert torch.allclose(avgr1.optimizer.state_dict()["state"][0]["exp_avg_sq"], (stats1 + stats2) / 2)
    assert torch.allclose(avgr2.optimizer.state_dict()["state"][0]["exp_avg_sq"], (stats1 + stats2) / 2)
    assert avgr1.local_epoch == 2
    assert avgr2.local_epoch == (2 if sync_epoch_when_averaging else 1)
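

# A freshly created averager should be able to download the donor peer's parameters, local epoch,
# and scheduler-adjusted learning rate via load_state_from_peers().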
@pytest.mark.forked
def test_load_state_from_peers():
    dht1 = hivemind.DHT(start=True)
    dht2 = hivemind.DHT(initial_peers=dht1.get_visible_maddrs(), start=True)

    model1 = nn.Linear(2, 3)
    model2 = nn.Linear(2, 3)

    common_kwargs = dict(
        optimizer=partial(torch.optim.SGD, lr=0.1),
        scheduler=partial(torch.optim.lr_scheduler.LambdaLR, lr_lambda=lambda t: 1.0 / max(1, t)),
        target_group_size=2,
        prefix="my_exp",
    )

    avgr1 = TrainingStateAverager(
        dht=dht1,
        params=model1.parameters(),
        allow_state_sharing=False,
        start=True,
        **common_kwargs,
    )
    avgr2 = TrainingStateAverager(dht=dht2, params=model2.parameters(), start=True, **common_kwargs)

    avgr2.local_epoch = 1337
    model2.weight.data[...] = 42
    time.sleep(0.1)

    avgr1.load_state_from_peers()
    assert avgr1.local_epoch == 1337
    assert torch.all(model1.weight == 42).item()
    assert np.allclose(avgr1.optimizer.param_groups[0]["lr"], 0.1 / 1337)
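

# Several workers report local progress at different rates; collaboration-wide epochs should
# complete noticeably faster while the fastest worker (index 4) is participating, and the
# per-worker performance EMAs should be ordered by true throughput.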
@pytest.mark.forked
def test_progress_tracker():
    # note to a curious reader: no, you cannot reduce the timings without compromising realism or stability
    prefix = "my_exp"
    target_batch_size = 256
    dht_root = hivemind.DHT(start=True)
    barrier = mp.Barrier(parties=5)
    delayed_start_evt = mp.Event()
    finished_evt = mp.Event()
    emas = mp.Array(ctypes.c_double, 5)
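
    # Each worker runs its own DHT node and ProgressTracker, reports `batch_size` samples every
    # `period` seconds, and advances its local epoch once the collaboration-wide target is reached.
    # Worker 4 joins only after delayed_start_evt is set and drops out a couple of epochs later.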
    def run_worker(index: int, batch_size: int, period: float, **kwargs):
        dht = hivemind.DHT(initial_peers=dht_root.get_visible_maddrs(), start=True)
        tracker = ProgressTracker(
            dht,
            prefix,
            target_batch_size,
            start=True,
            min_refresh_period=0.1,
            default_refresh_period=0.2,
            max_refresh_period=0.5,
            private_key=RSAPrivateKey(),
            **kwargs,
        )

        barrier.wait()
        if index == 4:
            delayed_start_evt.wait()

        local_epoch = 2 if index == 4 else 0
        samples_accumulated = 0

        while True:
            time.sleep(period)
            if finished_evt.is_set():
                break

            samples_accumulated += batch_size
            tracker.report_local_progress(local_epoch, samples_accumulated)

            if tracker.ready_to_update_epoch:
                if index == 4 and local_epoch >= 4:
                    time.sleep(0.5)
                    break

                with tracker.pause_updates():
                    local_epoch = tracker.update_epoch(local_epoch + 1)
                    samples_accumulated = 0

        emas[index] = tracker.performance_ema.samples_per_second
        tracker.shutdown()
        dht.shutdown()
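
    # The main process runs its own tracker alongside four progressively faster workers and records
    # the wall-clock duration of each epoch transition; worker 4 is released once epoch 2 is reached.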
    workers = [
        mp.Process(target=run_worker, kwargs=dict(index=1, batch_size=12, period=0.6)),
        mp.Process(target=run_worker, kwargs=dict(index=2, batch_size=16, period=0.5)),
        mp.Process(target=run_worker, kwargs=dict(index=3, batch_size=24, period=0.4)),
        mp.Process(target=run_worker, kwargs=dict(index=4, batch_size=64, period=0.4)),
    ]
    for worker in workers:
        worker.start()

    tracker = ProgressTracker(
        dht_root,
        prefix,
        target_batch_size,
        start=True,
        min_refresh_period=0.1,
        default_refresh_period=0.2,
        max_refresh_period=0.5,
    )
    barrier.wait()

    local_epoch = 0
    last_timestamp = hivemind.get_dht_time()
    step_time_deltas = []

    while local_epoch < 6:
        time.sleep(0.1)

        if tracker.ready_to_update_epoch:
            with tracker.pause_updates():
                local_epoch = tracker.update_epoch(local_epoch + 1)

            time_delta = hivemind.get_dht_time() - last_timestamp
            if local_epoch == 2:
                delayed_start_evt.set()

            last_timestamp = hivemind.get_dht_time()
            step_time_deltas.append(time_delta)

    finished_evt.set()
    for worker in workers:
        worker.join()

    tracker.shutdown()
    dht_root.shutdown()
    assert not tracker.is_alive()

    mean_step_time = sum(step_time_deltas) / len(step_time_deltas)
    for i in (0, 1, 5):  # Without the 4th worker (the fastest one)
        assert 1.05 * mean_step_time < step_time_deltas[i] < 2.0 * mean_step_time
    for i in (2, 3, 4):  # With the 4th worker
        assert 0.5 * mean_step_time < step_time_deltas[i] < 0.95 * mean_step_time
    assert emas[1] < emas[2] < emas[3] < emas[4]
    assert tracker.performance_ema.samples_per_second < 1e-9
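

# End-to-end test of hivemind.Optimizer under different combinations of delayed state averaging,
# delayed optimizer steps, delayed gradient averaging, gradient buffer reuse, and local updates.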
@pytest.mark.forked
@pytest.mark.parametrize(
    "use_local_updates, delay_state_averaging, delay_optimizer_step, delay_grad_averaging, reuse_grad_buffers",
    # fmt: off
    [
        (False, False, False, False, False),
        (False, True, False, False, False),
        (False, True, True, True, False),
        (False, False, False, False, True),
        (False, True, True, True, True),
        (False, True, True, False, True),
        (True, False, False, False, False),
        (True, True, False, False, False),
    ],
    # fmt: on
)
def test_optimizer(
    use_local_updates: bool,
    delay_state_averaging: bool,
    delay_optimizer_step: bool,
    delay_grad_averaging: bool,
    reuse_grad_buffers: bool,
):
    _test_optimizer(
        use_local_updates=use_local_updates,
        delay_state_averaging=delay_state_averaging,
        delay_grad_averaging=delay_grad_averaging,
        delay_optimizer_step=delay_optimizer_step,
        reuse_grad_buffers=reuse_grad_buffers,
    )
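

# The actual scenario: `num_peers` trainers fit a small linear regression on shared synthetic data.
# The first trainer runs in the main process so that the resulting Optimizer can be inspected after
# training; the rest (if any) run as subprocesses, optionally in client mode.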
def _test_optimizer(
    num_peers: int = 1,
    num_clients: int = 0,
    target_batch_size: int = 32,
    total_epochs: int = 3,
    use_local_updates: bool = False,
    reuse_grad_buffers: bool = True,
    delay_state_averaging: bool = True,
    delay_grad_averaging: bool = True,
    delay_optimizer_step: bool = True,
    average_state_every: int = 1,
):
    dht = hivemind.DHT(start=True)

    features = torch.randn(100, 5)
    targets = features @ torch.randn(5, 1)
    optimizer = None
    total_samples_accumulated = mp.Value(ctypes.c_int32, 0)
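
    # Each trainer builds its own DHT node and hivemind.Optimizer and steps through timed minibatches
    # until the collaboration reaches total_epochs; `nonlocal optimizer` exposes the main-process
    # instance to the assertions at the end of the test.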
    def run_trainer(batch_size: int, batch_time: float, client_mode: bool):
        nonlocal optimizer
        model = nn.Linear(5, 1)

        assert isinstance(model, torch.nn.Module), "model_arch must evaluate to a pytorch module"

        optimizer = Optimizer(
            run_id="test_run",
            target_batch_size=target_batch_size,
            batch_size_per_step=batch_size,
            params=model.parameters(),
            optimizer=partial(torch.optim.SGD, lr=0.1),
            scheduler=partial(torch.optim.lr_scheduler.StepLR, gamma=0.5, step_size=1),
            dht=hivemind.DHT(initial_peers=dht.get_visible_maddrs(), client_mode=client_mode, start=True),
            tracker_opts=dict(private_key=RSAPrivateKey(), max_refresh_period=1.0),
            averager_opts=dict(request_timeout=0.5),
            use_local_updates=use_local_updates,
            matchmaking_time=1.0,
            averaging_timeout=5.0,
            reuse_grad_buffers=reuse_grad_buffers,
            delay_state_averaging=delay_state_averaging,
            delay_grad_averaging=delay_grad_averaging,
            delay_optimizer_step=delay_optimizer_step,
            average_state_every=average_state_every,
            client_mode=client_mode,
            verbose=False,
        )
        optimizer.load_state_from_peers()

        prev_time = time.perf_counter()
        while optimizer.local_epoch < total_epochs:
            time.sleep(max(0.0, prev_time + batch_time - time.perf_counter()))
            batch = torch.randint(0, len(features), (batch_size,))
            loss = F.mse_loss(model(features[batch]), targets[batch])
            loss.backward()
            optimizer.step()

            total_samples_accumulated.value += batch_size

            if not reuse_grad_buffers:
                optimizer.zero_grad()

            prev_time = time.perf_counter()

        time.sleep(1.0)
        optimizer.shutdown()
        return optimizer
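
    # Spawn the extra peers as subprocesses, run the first trainer in the main process, then verify
    # epoch/sample accounting, the throughput estimate, and that all background components shut down.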
    peers = []

    for index in range(num_peers):
        peers.append(
            mp.Process(
                target=run_trainer,
                name=f"trainer-{index}",
                kwargs=dict(
                    batch_size=4 + index,
                    batch_time=0.3 + 0.2 * index,
                    client_mode=(index >= num_peers - num_clients),
                ),
            )
        )

    for peer in peers[1:]:
        peer.start()
    peers[0].run()
    for peer in peers[1:]:
        peer.join()

    assert isinstance(optimizer, Optimizer)
    assert optimizer.local_epoch == optimizer.tracker.global_epoch == total_epochs

    expected_samples_accumulated = target_batch_size * total_epochs
    assert expected_samples_accumulated <= total_samples_accumulated.value <= expected_samples_accumulated * 1.2
    assert 4 / 0.3 * 0.8 <= optimizer.tracker.performance_ema.samples_per_second <= 4 / 0.3 * 1.2

    assert not optimizer.state_averager.is_alive()
    assert not optimizer.tracker.is_alive()
    if not use_local_updates:
        assert not optimizer.grad_averager.is_alive()
    else:
        assert optimizer.grad_averager is None

    assert optimizer.scheduled_grads is None or optimizer.scheduled_grads.done()