test_optimizer.py

import ctypes
import multiprocessing as mp
import time
from functools import partial
from typing import Optional

import numpy as np
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F

import hivemind
from hivemind.averaging.control import AveragingStage
from hivemind.optim.grad_averager import GradientAverager, GradientAveragerFactory
from hivemind.optim.optimizer import Optimizer
from hivemind.optim.power_ef_averager import PowerEFGradientAverager
from hivemind.optim.progress_tracker import ProgressTracker
from hivemind.optim.state_averager import TrainingStateAverager
from hivemind.utils.crypto import RSAPrivateKey


@pytest.mark.forked
def test_grad_averager():
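    """
    Train two GradientAverager peers on opposite targets, accumulate gradients locally
    (one peer with reuse_grad_buffers=True), then run a single all-reduce round and
    check that both peers end up with the expected sample-weighted average gradient.
    """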
    dht1 = hivemind.DHT(start=True)
    model1 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(3))})
    averager1 = GradientAverager(
        model1.parameters(), dht=dht1, prefix="test", target_group_size=2, reuse_grad_buffers=False, start=True
    )

    dht2 = hivemind.DHT(start=True, initial_peers=dht1.get_visible_maddrs())
    model2 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(3))})
    averager2 = GradientAverager(
        model2.parameters(), dht=dht2, prefix="test", target_group_size=2, reuse_grad_buffers=True, start=True
    )

    control1 = averager1.schedule_step(hivemind.get_dht_time() + 5)
    control2 = averager2.schedule_step(hivemind.get_dht_time() + 5)

    for i in range(10):
        time.sleep(0.1)
        if i % 3 == 0:
            loss1 = F.mse_loss(model1.w, torch.ones(3))
            loss1.backward()
            averager1.accumulate_grads_(batch_size=2)  # total: 4 times * 2 samples = 8
            model1.zero_grad()
        else:
            loss2 = F.mse_loss(model2.w, -torch.ones(3))
            loss2.backward()
            averager2.accumulate_grads_(batch_size=3)  # total: 6 times * 3 samples = 18
            # note: we do not call zero grad here because reuse_grad_buffers=True

    assert control1.stage == control2.stage == AveragingStage.AWAITING_TRIGGER
    peer1_samples, peer1_times, peer2_samples, peer2_times = 8, 4, 18, 6
    assert averager1.local_samples_accumulated == peer1_samples and averager1.local_times_accumulated == peer1_times
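    # with w == 0, d/dw mean((w - target)^2) over 3 elements is 2 * (w - target) / 3 = -/+ 2/3 per element;
    # one such gradient is accumulated per backward pass, hence the local_times_accumulated factor below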
    ref_grads1 = torch.full((3,), -2 * 1 / 3 * averager1.local_times_accumulated)
    assert torch.allclose(next(averager1._grad_accumulators()), ref_grads1)

    assert averager2.local_samples_accumulated == peer2_samples and averager2.local_times_accumulated == peer2_times
    ref_grads2 = torch.full((3,), 2 * 1 / 3 * averager2.local_times_accumulated)
    assert torch.allclose(next(averager2._grad_accumulators()), ref_grads2)

    averager1.step(control=control1, wait=False)
    averager2.step(control=control2, wait=False)
    for step in (control1, control2):
        step.result()  # wait for all-reduce to finish

    peer1_weight = peer1_samples / (peer1_samples + peer2_samples)
    peer2_weight = peer2_samples / (peer1_samples + peer2_samples)
    ref_average = peer1_weight * (ref_grads1 / peer1_times) + peer2_weight * (ref_grads2 / peer2_times)
    with averager1.use_averaged_gradients():
        assert torch.allclose(model1.w.grad, ref_average)
    with averager2.use_averaged_gradients():
        assert torch.allclose(model2.w.grad, ref_average)
    # once use_averaged_gradients exits, the local (non-averaged) gradients are restored
    assert not torch.allclose(model1.w.grad, ref_average)
    assert not torch.allclose(model2.w.grad, ref_average)


@pytest.mark.forked
@pytest.mark.parametrize(
    "offload_optimizer, reuse_tensors, sync_epoch_when_averaging",
    [(False, False, False), (True, True, False), (True, False, False), (False, True, True), (True, False, True)],
)
def test_state_averager(offload_optimizer: bool, reuse_tensors: bool, sync_epoch_when_averaging: bool):
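    """
    Train two TrainingStateAverager peers on opposite targets and check that parameters,
    extra tensors and the selected optimizer statistic (exp_avg_sq) are averaged, and that
    local_epoch is synchronized only when sync_epoch_when_averaging is enabled.
    """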
    dht1 = hivemind.DHT(start=True)
    dht2 = hivemind.DHT(initial_peers=dht1.get_visible_maddrs(), start=True)

    torch.manual_seed(1337)
    torch.use_deterministic_algorithms(True)
    # note: use_deterministic_algorithms does not affect further tests because this test is forked

    model1 = nn.Linear(2, 3)
    model2 = nn.Linear(2, 3)
    extras1 = (torch.randn(2, 2), -torch.rand(1))
    extras2 = (-torch.randn(2, 2), torch.rand(1))

    common_kwargs = dict(
        optimizer=partial(torch.optim.Adam, lr=0.1, betas=(0.9, 0.9)),
        scheduler=partial(torch.optim.lr_scheduler.LambdaLR, lr_lambda=lambda t: 1.0 / max(1, t)),
        sync_epoch_when_averaging=sync_epoch_when_averaging,
        average_opt_statistics=("exp_avg_sq",),
        offload_optimizer=offload_optimizer,
        reuse_tensors=reuse_tensors,
        target_group_size=2,
        prefix="my_exp",
    )

    avgr1 = TrainingStateAverager(
        dht=dht1, params=model1.parameters(), extra_tensors=extras1, start=True, **common_kwargs
    )
    avgr2 = TrainingStateAverager(
        dht=dht2, params=model2.parameters(), extra_tensors=extras2, start=True, **common_kwargs
    )

    x = torch.ones(2)
    for step in range(20):
        F.mse_loss(model1(x), torch.ones(3)).mul(2).backward()
        avgr1.step(optimizer_step=True, zero_grad=True, averaging_round=(step == 10), delay_averaging=True)

        F.mse_loss(model2(x), -torch.ones(3)).backward()
        avgr2.step(optimizer_step=True, zero_grad=True, averaging_round=(step == 10), delay_averaging=False)

    assert torch.all(model1.weight.grad == 0) and torch.all(model2.weight.grad == 0), "zero grad did not trigger"
    assert model1(x).mean() > 0.5 and model2(x).mean() < -0.5, "models did not train properly"
    assert torch.allclose(extras1[0], extras2[0]), "first extra tensors were not averaged"
    assert torch.allclose(extras1[1], extras2[1]), "second extra tensors were not averaged"

    stats1 = avgr1.optimizer.state_dict()["state"][0]["exp_avg_sq"].clone()
    stats2 = avgr2.optimizer.state_dict()["state"][0]["exp_avg_sq"].clone()
    assert not torch.allclose(stats1, stats2)
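    # run one more (delayed) averaging round: avgr1 is one epoch ahead of avgr2,
    # so avgr2 catches up only if sync_epoch_when_averaging is enabled (asserted below)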
    avgr1.step(increment_epoch=True)

    avgr1.step(increment_epoch=True, averaging_round=True, delay_averaging=True)
    avgr2.step(increment_epoch=True, averaging_round=True, delay_averaging=True)

    avgr1.step(wait_for_delayed_updates=True)
    avgr2.step(wait_for_delayed_updates=True)

    assert torch.allclose(model1(x), model2(x)), "model parameters were not averaged correctly"
    assert torch.allclose(avgr1.optimizer.state_dict()["state"][0]["exp_avg_sq"], (stats1 + stats2) / 2)
    assert torch.allclose(avgr2.optimizer.state_dict()["state"][0]["exp_avg_sq"], (stats1 + stats2) / 2)
    assert avgr1.local_epoch == 2
    assert avgr2.local_epoch == (2 if sync_epoch_when_averaging else 1)


@pytest.mark.forked
@pytest.mark.parametrize("dpu", [True, False])
def test_load_state_from_peers(dpu: bool):
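    """
    Check that a peer with allow_state_sharing=False can download the epoch, parameters,
    extra tensors and scheduler state from a peer that shares its state.
    """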
    dht1 = hivemind.DHT(start=True)
    dht2 = hivemind.DHT(initial_peers=dht1.get_visible_maddrs(), start=True)

    model1 = nn.Linear(2, 3)
    model2 = nn.Linear(2, 3)
    extras1 = (torch.randn(2, 2), -torch.rand(1))
    extras2 = (-torch.randn(2, 2), torch.rand(1))

    common_kwargs = dict(
        optimizer=partial(torch.optim.SGD, lr=0.1),
        scheduler=partial(torch.optim.lr_scheduler.LambdaLR, lr_lambda=lambda t: 1.0 / max(1, t)),
        offload_optimizer=dpu,
        reuse_tensors=dpu,
        target_group_size=2,
        prefix="my_exp",
    )

    avgr1 = TrainingStateAverager(
        dht=dht1,
        params=model1.parameters(),
        allow_state_sharing=False,
        start=True,
        extra_tensors=extras1,
        **common_kwargs,
    )
    avgr2 = TrainingStateAverager(
        dht=dht2, params=model2.parameters(), start=True, extra_tensors=extras2, **common_kwargs
    )

    avgr2.local_epoch = 1337
    model2.weight.data[...] = 42
    extras2[0][:] = 9999
    time.sleep(0.1)

    avgr1.load_state_from_peers()
    assert avgr1.local_epoch == 1337
    assert torch.all(model1.weight == 42).item()
    assert np.allclose(avgr1.optimizer.param_groups[0]["lr"], 0.1 / 1337)
    assert torch.all(extras1[0] == extras2[0]).item() and torch.all(extras1[1] == extras2[1]).item()
    assert torch.all(extras1[0] == 9999).item()


@pytest.mark.forked
def test_progress_tracker():
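    """
    Simulate four workers with different throughputs plus a passive monitor in the main process,
    and verify that ProgressTracker advances epochs at the expected collective rate and that
    per-worker performance EMAs are ordered by throughput.
    """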
    # note to a curious reader: no, you cannot reduce the timings without compromising realism or stability
    prefix = "my_exp"
    target_batch_size = 256
    dht_root = hivemind.DHT(start=True)
    barrier = mp.Barrier(parties=5)
    delayed_start_evt = mp.Event()
    finished_evt = mp.Event()
    emas = mp.Array(ctypes.c_double, 5)

    def run_worker(index: int, batch_size: int, period: float, **kwargs):
        dht = hivemind.DHT(initial_peers=dht_root.get_visible_maddrs(), start=True)
        tracker = ProgressTracker(
            dht,
            prefix,
            target_batch_size,
            start=True,
            min_refresh_period=0.1,
            default_refresh_period=0.2,
            max_refresh_period=0.5,
            private_key=RSAPrivateKey(),
            **kwargs,
        )

        barrier.wait()
        if index == 4:
            delayed_start_evt.wait()

        local_epoch = 2 if index == 4 else 0
        samples_accumulated = 0

        while True:
            time.sleep(period)
            if finished_evt.is_set():
                break

            samples_accumulated += batch_size
            tracker.report_local_progress(local_epoch, samples_accumulated)

            if tracker.ready_to_update_epoch:
                if index == 4 and local_epoch >= 4:
                    time.sleep(0.5)
                    break

                with tracker.pause_updates():
                    local_epoch = tracker.update_epoch(local_epoch + 1)
                    samples_accumulated = 0

        emas[index] = tracker.performance_ema.samples_per_second
        tracker.shutdown()
        dht.shutdown()

    workers = [
        mp.Process(target=run_worker, kwargs=dict(index=1, batch_size=12, period=0.6)),
        mp.Process(target=run_worker, kwargs=dict(index=2, batch_size=16, period=0.5)),
        mp.Process(target=run_worker, kwargs=dict(index=3, batch_size=24, period=0.4)),
        mp.Process(target=run_worker, kwargs=dict(index=4, batch_size=64, period=0.4)),
    ]
    for worker in workers:
        worker.start()

    tracker = ProgressTracker(
        dht_root,
        prefix,
        target_batch_size,
        start=True,
        min_refresh_period=0.1,
        default_refresh_period=0.2,
        max_refresh_period=0.5,
    )
    barrier.wait()

    local_epoch = 0
    last_timestamp = hivemind.get_dht_time()
    step_time_deltas = []

    while local_epoch < 6:
        time.sleep(0.1)

        if tracker.ready_to_update_epoch:
            with tracker.pause_updates():
                local_epoch = tracker.update_epoch(local_epoch + 1)

            time_delta = hivemind.get_dht_time() - last_timestamp
            if local_epoch == 2:
                delayed_start_evt.set()

            last_timestamp = hivemind.get_dht_time()
            step_time_deltas.append(time_delta)

    finished_evt.set()
    for worker in workers:
        worker.join()

    tracker.shutdown()
    dht_root.shutdown()
    assert not tracker.is_alive()

    mean_step_time = sum(step_time_deltas) / len(step_time_deltas)
    for i in (0, 1, 5):  # without the 4th worker (the fastest one)
        assert 1.05 * mean_step_time < step_time_deltas[i] < 2.0 * mean_step_time
    for i in (2, 3, 4):  # with the 4th worker
        assert 0.5 * mean_step_time < step_time_deltas[i] < 0.95 * mean_step_time
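    # workers with a higher batch_size / period ratio should report a higher samples-per-second EMA;
    # the monitor tracker never reports local samples of its own, so its EMA stays essentially zero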
    assert emas[1] < emas[2] < emas[3] < emas[4]
    assert tracker.performance_ema.samples_per_second < 1e-9


@pytest.mark.forked
@pytest.mark.parametrize(
    "grad_averager",
    [GradientAverager.get_factory(), PowerEFGradientAverager.get_factory(averager_rank=1)],
)
def test_optimizer(
    grad_averager: GradientAveragerFactory,
    num_peers: int = 1,
    num_clients: int = 0,
    target_batch_size: int = 32,
    total_epochs: int = 3,
    reuse_grad_buffers: bool = True,
    delay_grad_averaging: bool = True,
    delay_optimizer_step: bool = True,
    average_state_every: int = 1,
):
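    """
    End-to-end test of hivemind.Optimizer: train a small regression model for several
    collaborative epochs, then check sample accounting, the throughput EMA, and clean
    shutdown of the state averager, gradient averager and progress tracker.
    """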
    dht = hivemind.DHT(start=True)

    features = torch.randn(100, 5)
    targets = features @ torch.randn(5, 1)

    optimizer = None
    total_samples_accumulated = mp.Value(ctypes.c_int32, 0)

    def run_trainer(batch_size: int, batch_time: float, client_mode: bool):
        nonlocal optimizer
        model = nn.Sequential(
            nn.Linear(5, 5),
            nn.ReLU(),
            nn.Linear(5, 1),
        )
        assert isinstance(model, torch.nn.Module), "model_arch must evaluate to a pytorch module"

        optimizer = Optimizer(
            run_id="test_run",
            target_batch_size=target_batch_size,
            batch_size_per_step=batch_size,
            params=model.parameters(),
            optimizer=partial(torch.optim.SGD, lr=0.1),
            scheduler=partial(torch.optim.lr_scheduler.StepLR, gamma=0.5, step_size=1),
            dht=hivemind.DHT(initial_peers=dht.get_visible_maddrs(), client_mode=client_mode, start=True),
            tracker_opts=dict(private_key=RSAPrivateKey(), max_refresh_period=1.0),
            averager_opts=dict(request_timeout=0.5),
            matchmaking_time=1.0,
            averaging_timeout=5.0,
            reuse_grad_buffers=reuse_grad_buffers,
            delay_grad_averaging=delay_grad_averaging,
            delay_optimizer_step=delay_optimizer_step,
            average_state_every=average_state_every,
            client_mode=client_mode,
            grad_averager=grad_averager,
            verbose=False,
        )
        optimizer.load_state_from_peers()

        prev_time = time.perf_counter()
        while optimizer.local_epoch < total_epochs:
            time.sleep(max(0.0, prev_time + batch_time - time.perf_counter()))
            batch = torch.randint(0, len(features), (batch_size,))
            loss = F.mse_loss(model(features[batch]), targets[batch])
            loss.backward()
            optimizer.step()

            total_samples_accumulated.value += batch_size

            if not reuse_grad_buffers:
                optimizer.zero_grad()

            prev_time = time.perf_counter()

        time.sleep(1.0)
        optimizer.shutdown()
        return optimizer

    peers = []
    for index in range(num_peers):
        peers.append(
            mp.Process(
                target=run_trainer,
                name=f"trainer-{index}",
                kwargs=dict(
                    batch_size=4 + index,
                    batch_time=0.3 + 0.2 * index,
                    client_mode=(index >= num_peers - num_clients),
                ),
            )
        )

    for peer in peers[1:]:
        peer.start()
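    # the first peer runs in the main process (Process.run() executes the target inline),
    # so the `nonlocal optimizer` assigned inside run_trainer stays visible for the asserts below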
    peers[0].run()
    for peer in peers[1:]:
        peer.join()

    assert isinstance(optimizer, Optimizer)
    assert optimizer.local_epoch == optimizer.tracker.global_epoch == total_epochs

    expected_samples_accumulated = target_batch_size * total_epochs
    assert expected_samples_accumulated <= total_samples_accumulated.value <= expected_samples_accumulated * 1.2
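    # the first (main-process) peer processes batch_size=4 every ~0.3s,
    # so its throughput EMA should settle near 4 / 0.3 ≈ 13.3 samples per second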
    assert 4 / 0.3 * 0.8 <= optimizer.tracker.performance_ema.samples_per_second <= 4 / 0.3 * 1.2

    assert not optimizer.state_averager.is_alive()
    assert not optimizer.grad_averager.is_alive()
    assert not optimizer.tracker.is_alive()
    assert optimizer.scheduled_grads is None or optimizer.scheduled_grads.done()