import time
from functools import partial

import numpy as np
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F

import hivemind
from hivemind.averaging.control import AveragingStage
from hivemind.optim.experimental.grad_averager import GradientAverager
from hivemind.optim.experimental.state_averager import TrainingStateAverager
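

# test_grad_averager: two peers accumulate gradients on opposite MSE targets
# (one with reuse_grad_buffers, one without) and then run a joint all-reduce.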
@pytest.mark.forked
def test_grad_averager():
    dht1 = hivemind.DHT(start=True)
    model1 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(3))})
    averager1 = GradientAverager(
        model1.parameters(), dht=dht1, prefix="test", target_group_size=2, reuse_grad_buffers=False, start=True
    )

    dht2 = hivemind.DHT(start=True, initial_peers=dht1.get_visible_maddrs())
    model2 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(3))})
    averager2 = GradientAverager(
        model2.parameters(), dht=dht2, prefix="test", target_group_size=2, reuse_grad_buffers=True, start=True
    )

    control1 = averager1.schedule_step(hivemind.get_dht_time() + 5)
    control2 = averager2.schedule_step(hivemind.get_dht_time() + 5)
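
    # accumulate gradients locally; the all-reduce scheduled above is only triggered later via step(control=...)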
    for i in range(10):
        time.sleep(0.1)
        if i % 3 == 0:
            loss1 = F.mse_loss(model1.w, torch.ones(3))
            loss1.backward()
            averager1.accumulate_grads_(batch_size=2)  # total: 4 times * 2 samples = 8
            model1.zero_grad()
        else:
            loss2 = F.mse_loss(model2.w, -torch.ones(3))
            loss2.backward()
            averager2.accumulate_grads_(batch_size=3)  # total: 6 times * 3 samples = 18
            # note: we do not call zero grad here because reuse_grad_buffers=True

    assert control1.stage == control2.stage == AveragingStage.AWAITING_TRIGGER
    peer1_samples, peer1_times, peer2_samples, peer2_times = 8, 4, 18, 6
    assert averager1.local_samples_accumulated == peer1_samples and averager1.local_times_accumulated == peer1_times
    ref_grads1 = torch.full((3,), -2 * 1 / 3 * averager1.local_times_accumulated)
    assert torch.allclose(next(averager1._grad_accumulators()), ref_grads1)

    assert averager2.local_samples_accumulated == peer2_samples and averager2.local_times_accumulated == peer2_times
    ref_grads2 = torch.full((3,), 2 * 1 / 3 * averager2.local_times_accumulated)
    assert torch.allclose(next(averager2._grad_accumulators()), ref_grads2)

    averager1.step(control=control1, wait=False)
    averager2.step(control=control2, wait=False)
    for step in (control1, control2):
        step.result()  # wait for all-reduce to finish
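
    # expected result: each peer's per-accumulation mean gradient, weighted by that peer's share of total samples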
    peer1_weight = peer1_samples / (peer1_samples + peer2_samples)
    peer2_weight = peer2_samples / (peer1_samples + peer2_samples)
    ref_average = peer1_weight * (ref_grads1 / peer1_times) + peer2_weight * (ref_grads2 / peer2_times)
    with averager1.use_averaged_gradients():
        assert torch.allclose(model1.w.grad, ref_average)
    with averager2.use_averaged_gradients():
        assert torch.allclose(model2.w.grad, ref_average)

    # outside the use_averaged_gradients() contexts, the local (unaveraged) gradients are restored
    assert not torch.allclose(model1.w.grad, ref_average)
    assert not torch.allclose(model2.w.grad, ref_average)
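

# test_state_averager: two peers train toward opposite targets and periodically average
# their parameters, optimizer statistics (exp_avg_sq), and extra tensors.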
@pytest.mark.forked
@pytest.mark.parametrize(
    "offload_optimizer, reuse_tensors, sync_epoch_when_averaging",
    [(False, False, False), (True, False, False), (False, True, True), (True, False, True)],
)
def test_state_averager(offload_optimizer: bool, reuse_tensors: bool, sync_epoch_when_averaging: bool):
    dht1 = hivemind.DHT(start=True)
    dht2 = hivemind.DHT(initial_peers=dht1.get_visible_maddrs(), start=True)

    torch.manual_seed(1337)
    torch.use_deterministic_algorithms(True)
    # note: use_deterministic_algorithms does not affect further tests because this test is forked

    model1 = nn.Linear(2, 3)
    model2 = nn.Linear(2, 3)
    extras1 = (torch.randn(2, 2), -torch.rand(1))
    extras2 = (-torch.randn(2, 2), torch.rand(1))

    common_kwargs = dict(
        optimizer=partial(torch.optim.Adam, lr=0.1, betas=(0.9, 0.9)),
        scheduler=partial(torch.optim.lr_scheduler.LambdaLR, lr_lambda=lambda t: 1.0 / max(1, t)),
        sync_epoch_when_averaging=sync_epoch_when_averaging,
        average_opt_statistics=("exp_avg_sq",),
        offload_optimizer=offload_optimizer,
        reuse_tensors=reuse_tensors,
        target_group_size=2,
        prefix="my_exp",
    )

    avgr1 = TrainingStateAverager(
        dht=dht1, param_groups=model1.parameters(), extra_tensors=extras1, start=True, **common_kwargs
    )
    avgr2 = TrainingStateAverager(
        dht=dht2, param_groups=model2.parameters(), extra_tensors=extras2, start=True, **common_kwargs
    )
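
    # train model1 toward +1 and model2 toward -1; both peers run one averaging round at step 10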
    x = torch.ones(2)
    for step in range(20):
        F.mse_loss(model1(x), torch.ones(3)).mul(2).backward()
        avgr1.step(optimizer_step=True, zero_grad=True, averaging_round=(step == 10), delay_averaging=True)

        F.mse_loss(model2(x), -torch.ones(3)).backward()
        avgr2.step(optimizer_step=True, zero_grad=True, averaging_round=(step == 10), delay_averaging=False)

    assert torch.all(model1.weight.grad == 0) and torch.all(model2.weight.grad == 0), "zero grad did not trigger"
    assert model1(x).mean() > 0.5 and model2(x).mean() < -0.5, "models did not train properly"
    assert torch.allclose(extras1[0], extras2[0]), "first extra tensors were not averaged"
    assert torch.allclose(extras1[1], extras2[1]), "second extra tensors were not averaged"

    stats1 = avgr1.optimizer.state_dict()["state"][0]["exp_avg_sq"].clone()
    stats2 = avgr2.optimizer.state_dict()["state"][0]["exp_avg_sq"].clone()
    assert not torch.allclose(stats1, stats2)
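
    # one more averaging round: avgr1 increments its epoch twice, avgr2 only once,
    # so their epochs match afterwards only if sync_epoch_when_averaging is enabled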
    avgr1.step(increment_epoch=True)

    avgr1.step(increment_epoch=True, averaging_round=True, delay_averaging=True)
    avgr2.step(increment_epoch=True, averaging_round=True, delay_averaging=True)

    avgr1.step(wait_for_delayed_update=True)
    avgr2.step(wait_for_delayed_update=True)

    assert torch.allclose(model1(x), model2(x)), "model parameters were not averaged correctly"
    assert torch.allclose(avgr1.optimizer.state_dict()["state"][0]["exp_avg_sq"], (stats1 + stats2) / 2)
    assert torch.allclose(avgr2.optimizer.state_dict()["state"][0]["exp_avg_sq"], (stats1 + stats2) / 2)
    assert avgr1.local_epoch == 2
    assert avgr2.local_epoch == (2 if sync_epoch_when_averaging else 1)
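

# test_load_state_from_peers: a fresh peer downloads parameters, optimizer state,
# and local_epoch from a peer that shares its state over the DHT.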
@pytest.mark.forked
def test_load_state_from_peers():
    dht1 = hivemind.DHT(start=True)
    dht2 = hivemind.DHT(initial_peers=dht1.get_visible_maddrs(), start=True)

    model1 = nn.Linear(2, 3)
    model2 = nn.Linear(2, 3)

    common_kwargs = dict(
        optimizer=partial(torch.optim.SGD, lr=0.1),
        scheduler=partial(torch.optim.lr_scheduler.LambdaLR, lr_lambda=lambda t: 1.0 / max(1, t)),
        target_group_size=2,
        prefix="my_exp",
    )

    avgr1 = TrainingStateAverager(
        dht=dht1, param_groups=model1.parameters(), allow_state_sharing=False, start=True, **common_kwargs
    )
    avgr2 = TrainingStateAverager(dht=dht2, param_groups=model2.parameters(), start=True, **common_kwargs)
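
    # give avgr2 a distinctive state; avgr1 (created with allow_state_sharing=False) must download it,
    # including the epoch-dependent learning rate produced by the LambdaLR scheduler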
    avgr2.local_epoch = 1337
    model2.weight.data[...] = 42
    time.sleep(0.1)

    avgr1.load_state_from_peers()
    assert avgr1.local_epoch == 1337
    assert torch.all(model1.weight == 42).item()
    assert np.allclose(avgr1.optimizer.param_groups[0]["lr"], 0.1 / 1337)