# test_optimizer.py
import ctypes
import multiprocessing as mp
import time
from functools import partial

import numpy as np
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F

import hivemind
from hivemind.averaging.control import AveragingStage
from hivemind.optim.experimental.grad_averager import GradientAverager
from hivemind.optim.experimental.progress_tracker import ProgressTracker
from hivemind.optim.experimental.state_averager import TrainingStateAverager
from hivemind.utils.crypto import RSAPrivateKey


@pytest.mark.forked
def test_grad_averager():
    dht1 = hivemind.DHT(start=True)
    model1 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(3))})
    averager1 = GradientAverager(
        model1.parameters(), dht=dht1, prefix="test", target_group_size=2, reuse_grad_buffers=False, start=True
    )

    dht2 = hivemind.DHT(start=True, initial_peers=dht1.get_visible_maddrs())
    model2 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(3))})
    averager2 = GradientAverager(
        model2.parameters(), dht=dht2, prefix="test", target_group_size=2, reuse_grad_buffers=True, start=True
    )

    control1 = averager1.schedule_step(hivemind.get_dht_time() + 5)
    control2 = averager2.schedule_step(hivemind.get_dht_time() + 5)

    for i in range(10):
        time.sleep(0.1)
        if i % 3 == 0:
            loss1 = F.mse_loss(model1.w, torch.ones(3))
            loss1.backward()
            averager1.accumulate_grads_(batch_size=2)  # total: 4 times * 2 samples = 8
            model1.zero_grad()
        else:
            loss2 = F.mse_loss(model2.w, -torch.ones(3))
            loss2.backward()
            averager2.accumulate_grads_(batch_size=3)  # total: 6 times * 3 samples = 18
            # note: we do not call zero grad here because reuse_grad_buffers=True

    assert control1.stage == control2.stage == AveragingStage.AWAITING_TRIGGER
    peer1_samples, peer1_times, peer2_samples, peer2_times = 8, 4, 18, 6
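    # sanity check on the expected accumulator contents: the parameters stay at w == 0 here,
    # so each backward pass of F.mse_loss(w, target) over 3 elements yields 2 * (0 - target) / 3
    # per element (i.e. -2/3 for target=+1 and +2/3 for target=-1), and the accumulator simply
    # sums one such gradient per accumulate_grads_ call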
    assert averager1.local_samples_accumulated == peer1_samples and averager1.local_times_accumulated == peer1_times
    ref_grads1 = torch.full((3,), -2 * 1 / 3 * averager1.local_times_accumulated)
    assert torch.allclose(next(averager1._grad_accumulators()), ref_grads1)

    assert averager2.local_samples_accumulated == peer2_samples and averager2.local_times_accumulated == peer2_times
    ref_grads2 = torch.full((3,), 2 * 1 / 3 * averager2.local_times_accumulated)
    assert torch.allclose(next(averager2._grad_accumulators()), ref_grads2)
    averager1.step(control=control1, wait=False)
    averager2.step(control=control2, wait=False)
    for step in (control1, control2):
        step.result()  # wait for all-reduce to finish

    peer1_weight = peer1_samples / (peer1_samples + peer2_samples)
    peer2_weight = peer2_samples / (peer1_samples + peer2_samples)
    ref_average = peer1_weight * (ref_grads1 / peer1_times) + peer2_weight * (ref_grads2 / peer2_times)
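    # with these numbers: peer1_weight = 8/26, peer2_weight = 18/26, and the per-step mean
    # gradients are -2/3 and +2/3 respectively, so ref_average ≈ 0.2564 in every coordinate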
    with averager1.use_averaged_gradients():
        assert torch.allclose(model1.w.grad, ref_average)
    with averager2.use_averaged_gradients():
        assert torch.allclose(model2.w.grad, ref_average)

    # outside the use_averaged_gradients context, the local (non-averaged) gradients are restored
    assert not torch.allclose(model1.w.grad, ref_average)
    assert not torch.allclose(model2.w.grad, ref_average)
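

# A minimal usage sketch, not a test; it assumes only the GradientAverager calls exercised
# above (accumulate_grads_, schedule_step, step, use_averaged_gradients) and hypothetical
# `model`/`compute_loss` arguments: accumulate local gradients, run one scheduled all-reduce,
# then read the averaged gradients inside the context manager.
def _grad_averaging_round_sketch(model, averager, compute_loss, batch_size):
    compute_loss(model).backward()
    averager.accumulate_grads_(batch_size=batch_size)  # peer weight in the all-reduce ~ samples
    control = averager.schedule_step(hivemind.get_dht_time() + 5)  # matchmaking deadline
    averager.step(control=control, wait=False)
    control.result()  # wait for the all-reduce to finish
    with averager.use_averaged_gradients():
        pass  # param.grad now holds the averaged gradients; local grads are restored on exit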


@pytest.mark.forked
@pytest.mark.parametrize(
    "offload_optimizer, reuse_tensors, sync_epoch_when_averaging",
    [(False, False, False), (True, False, False), (False, True, True), (True, False, True)],
)
def test_state_averager(offload_optimizer: bool, reuse_tensors: bool, sync_epoch_when_averaging: bool):
    dht1 = hivemind.DHT(start=True)
    dht2 = hivemind.DHT(initial_peers=dht1.get_visible_maddrs(), start=True)
    torch.manual_seed(1337)
    torch.use_deterministic_algorithms(True)
    # note: use_deterministic_algorithms does not affect further tests because this test is forked

    model1 = nn.Linear(2, 3)
    model2 = nn.Linear(2, 3)

    extras1 = (torch.randn(2, 2), -torch.rand(1))
    extras2 = (-torch.randn(2, 2), torch.rand(1))

    common_kwargs = dict(
        optimizer=partial(torch.optim.Adam, lr=0.1, betas=(0.9, 0.9)),
        scheduler=partial(torch.optim.lr_scheduler.LambdaLR, lr_lambda=lambda t: 1.0 / max(1, t)),
        sync_epoch_when_averaging=sync_epoch_when_averaging,
        average_opt_statistics=("exp_avg_sq",),
        offload_optimizer=offload_optimizer,
        reuse_tensors=reuse_tensors,
        target_group_size=2,
        prefix="my_exp",
    )

    avgr1 = TrainingStateAverager(
        dht=dht1, params=model1.parameters(), extra_tensors=extras1, start=True, **common_kwargs
    )
    avgr2 = TrainingStateAverager(
        dht=dht2, params=model2.parameters(), extra_tensors=extras2, start=True, **common_kwargs
    )

    x = torch.ones(2)
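    # 20 local Adam steps per peer, pushing the two models toward opposite targets; on step 10
    # each peer also runs one averaging round (avgr1 in the background via delay_averaging=True,
    # avgr2 within the same step call)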
    for step in range(20):
        F.mse_loss(model1(x), torch.ones(3)).mul(2).backward()
        avgr1.step(optimizer_step=True, zero_grad=True, averaging_round=(step == 10), delay_averaging=True)

        F.mse_loss(model2(x), -torch.ones(3)).backward()
        avgr2.step(optimizer_step=True, zero_grad=True, averaging_round=(step == 10), delay_averaging=False)

    assert torch.all(model1.weight.grad == 0) and torch.all(model2.weight.grad == 0), "zero grad did not trigger"
    assert model1(x).mean() > 0.5 and model2(x).mean() < -0.5, "models did not train properly"
    assert torch.allclose(extras1[0], extras2[0]), "first extra tensors were not averaged"
    assert torch.allclose(extras1[1], extras2[1]), "second extra tensors were not averaged"

    stats1 = avgr1.optimizer.state_dict()["state"][0]["exp_avg_sq"].clone()
    stats2 = avgr2.optimizer.state_dict()["state"][0]["exp_avg_sq"].clone()
    assert not torch.allclose(stats1, stats2)

    avgr1.step(increment_epoch=True)

    avgr1.step(increment_epoch=True, averaging_round=True, delay_averaging=True)
    avgr2.step(increment_epoch=True, averaging_round=True, delay_averaging=True)

    avgr1.step(wait_for_delayed_update=True)
    avgr2.step(wait_for_delayed_update=True)

    assert torch.allclose(model1(x), model2(x)), "model parameters were not averaged correctly"
    assert torch.allclose(avgr1.optimizer.state_dict()["state"][0]["exp_avg_sq"], (stats1 + stats2) / 2)
    assert torch.allclose(avgr2.optimizer.state_dict()["state"][0]["exp_avg_sq"], (stats1 + stats2) / 2)
    assert avgr1.local_epoch == 2
    assert avgr2.local_epoch == (2 if sync_epoch_when_averaging else 1)
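

# Another non-test sketch; it relies only on the TrainingStateAverager.step flags used above:
# a local optimizer step combined with a background averaging round, then an epoch bump.
def _state_averager_step_sketch(avgr, loss):
    loss.backward()
    avgr.step(optimizer_step=True, zero_grad=True, averaging_round=True, delay_averaging=True)
    avgr.step(wait_for_delayed_update=True)  # join the delayed (background) averaging round
    avgr.step(increment_epoch=True)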


@pytest.mark.forked
def test_load_state_from_peers():
    dht1 = hivemind.DHT(start=True)
    dht2 = hivemind.DHT(initial_peers=dht1.get_visible_maddrs(), start=True)

    model1 = nn.Linear(2, 3)
    model2 = nn.Linear(2, 3)

    common_kwargs = dict(
        optimizer=partial(torch.optim.SGD, lr=0.1),
        scheduler=partial(torch.optim.lr_scheduler.LambdaLR, lr_lambda=lambda t: 1.0 / max(1, t)),
        target_group_size=2,
        prefix="my_exp",
    )

    avgr1 = TrainingStateAverager(
        dht=dht1, params=model1.parameters(), allow_state_sharing=False, start=True, **common_kwargs
    )
    avgr2 = TrainingStateAverager(dht=dht2, params=model2.parameters(), start=True, **common_kwargs)

    avgr2.local_epoch = 1337
    model2.weight.data[...] = 42
    time.sleep(0.1)

    avgr1.load_state_from_peers()
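    # the downloaded state should carry avgr2's epoch and parameters, with the scheduler
    # fast-forwarded to that epoch (hence lr == 0.1 / 1337 below)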
    assert avgr1.local_epoch == 1337
    assert torch.all(model1.weight == 42).item()
    assert np.allclose(avgr1.optimizer.param_groups[0]["lr"], 0.1 / 1337)


@pytest.mark.forked
def test_progress_tracker():
    # note to a curious reader: no, you cannot reduce the timings without compromising realism or stability
    prefix = "my_exp"
    target_batch_size = 256

    dht_root = hivemind.DHT(start=True)
    barrier = mp.Barrier(parties=5)
    delayed_start_evt = mp.Event()
    finished_evt = mp.Event()
    emas = mp.Array(ctypes.c_double, 5)
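    # the barrier synchronizes the 4 worker processes with the main tracker below; emas collects
    # each worker's final samples-per-second estimate; worker index 4 is held back by delayed_start_evt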

    def run_worker(index: int, batch_size: int, period: float, **kwargs):
        dht = hivemind.DHT(initial_peers=dht_root.get_visible_maddrs(), start=True)
        tracker = ProgressTracker(
            dht,
            prefix,
            target_batch_size,
            start=True,
            min_refresh_period=0.1,
            default_refresh_period=0.2,
            max_refresh_period=0.5,
            private_key=RSAPrivateKey(),
            **kwargs,
        )

        barrier.wait()
        if index == 4:
            delayed_start_evt.wait()

        local_epoch = 2 if index == 4 else 0
        samples_accumulated = 0

        while True:
            time.sleep(period)
            if finished_evt.is_set():
                break

            samples_accumulated += batch_size
            tracker.report_local_progress(local_epoch, samples_accumulated)

            if tracker.ready_to_update_epoch:
                with tracker.pause_updates():
                    local_epoch = tracker.update_epoch(local_epoch + 1)
                    samples_accumulated = 0

            if index == 4 and local_epoch >= 5:
                time.sleep(0.5)
                break

        emas[index] = tracker.performance_ema.samples_per_second
        tracker.shutdown()
        dht.shutdown()
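
    # each worker reports `batch_size` samples every `period` seconds, so its long-run throughput
    # is roughly batch_size / period; worker index 4 (the fastest) joins late, at local epoch 2,
    # and leaves once it reaches epoch 5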
    workers = [
        mp.Process(target=run_worker, kwargs=dict(index=1, batch_size=12, period=0.6)),
        mp.Process(target=run_worker, kwargs=dict(index=2, batch_size=16, period=0.5)),
        mp.Process(target=run_worker, kwargs=dict(index=3, batch_size=24, period=0.4)),
        mp.Process(target=run_worker, kwargs=dict(index=4, batch_size=64, period=0.4)),
    ]
    for worker in workers:
        worker.start()

    tracker = ProgressTracker(
        dht_root,
        prefix,
        target_batch_size,
        start=True,
        min_refresh_period=0.1,
        default_refresh_period=0.2,
        max_refresh_period=0.5,
    )
    barrier.wait()

    current_step = 0
    last_timestamp = hivemind.get_dht_time()
    step_time_deltas = []

    while current_step < 6:
        time.sleep(0.1)

        if tracker.global_progress.epoch > current_step:
            time_delta = hivemind.get_dht_time() - last_timestamp
            current_step = tracker.global_progress.epoch
            if current_step == 2:
                delayed_start_evt.set()

            last_timestamp = hivemind.get_dht_time()
            step_time_deltas.append(time_delta)

    finished_evt.set()
    for worker in workers:
        worker.join()

    tracker.shutdown()
    dht_root.shutdown()
    assert not tracker.is_alive()

    mean_step_time = sum(step_time_deltas) / len(step_time_deltas)
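    # the fast worker only participates between global epochs 2 and 5, so the epochs measured at
    # indices 2-4 should complete noticeably faster than those at indices 0, 1, and 5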
    for i in (0, 1, 5):  # without the 4th worker (the fastest one)
        assert 1.05 * mean_step_time < step_time_deltas[i] < 2.0 * mean_step_time
    for i in (2, 3, 4):  # with the 4th worker
        assert 0.5 * mean_step_time < step_time_deltas[i] < 0.95 * mean_step_time
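    # throughput EMAs should be ordered by batch_size / period (20 < 32 < 60 < 160 samples/s);
    # the root tracker never reports local progress, so its own EMA stays near zero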
    assert emas[1] < emas[2] < emas[3] < emas[4]
    assert tracker.performance_ema.samples_per_second < 1e-9