test_optimizer.py

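# Integration test for hivemind's GradientAverager: two peers accumulate gradients
# locally, then run an all-reduce whose result is weighted by the number of samples
# each peer contributed. Assumes the pytest-forked plugin is installed, since the
# test is marked with @pytest.mark.forked.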
import time

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F

import hivemind
from hivemind.averaging.control import AveragingStage
from hivemind.optim.experimental.grad_averager import GradientAverager


@pytest.mark.forked
def test_grad_averager():
    dht1 = hivemind.DHT(start=True)
    model1 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(3))})
    averager1 = GradientAverager(
        model1.parameters(), dht=dht1, prefix="test", target_group_size=2, reuse_grad_buffers=False, start=True
    )
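
    # the second peer joins the DHT via the first peer's visible multiaddrs and,
    # unlike peer 1, reuses its .grad buffers as gradient accumulators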
    dht2 = hivemind.DHT(start=True, initial_peers=dht1.get_visible_maddrs())
    model2 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(3))})
    averager2 = GradientAverager(
        model2.parameters(), dht=dht2, prefix="test", target_group_size=2, reuse_grad_buffers=True, start=True
    )
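
    # pre-schedule an averaging round roughly 5 seconds ahead (in DHT time)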
    control1 = averager1.schedule_step(hivemind.get_dht_time() + 5)
    control2 = averager2.schedule_step(hivemind.get_dht_time() + 5)
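
    # accumulate for 10 iterations: peer 1 backprops on every third iteration
    # (4 times, batch_size=2), peer 2 on the remaining ones (6 times, batch_size=3)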
    for i in range(10):
        time.sleep(0.1)
        if i % 3 == 0:
            loss1 = F.mse_loss(model1.w, torch.ones(3))
            loss1.backward()
            averager1.accumulate_grads_(batch_size=2)  # total: 4 times * 2 samples = 8
            model1.zero_grad()
        else:
            loss2 = F.mse_loss(model2.w, -torch.ones(3))
            loss2.backward()
            averager2.accumulate_grads_(batch_size=3)  # total: 6 times * 3 samples = 18
            # note: we do not call zero_grad() here because reuse_grad_buffers=True
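
    # neither peer has triggered its scheduled step yet, so both controls still await the trigger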
    assert control1.stage == control2.stage == AveragingStage.AWAITING_TRIGGER

    peer1_samples, peer1_times, peer2_samples, peer2_times = 8, 4, 18, 6
    assert averager1.local_samples_accumulated == peer1_samples and averager1.local_times_accumulated == peer1_times
    # w stays at zero, so each MSE backward pass toward +1 (peer 1) or -1 (peer 2) adds a gradient of -2/3 or +2/3 per coordinate
    ref_grads1 = torch.full((3,), -2 * 1 / 3 * averager1.local_times_accumulated)
    assert torch.allclose(next(averager1._grad_accumulators()), ref_grads1)

    assert averager2.local_samples_accumulated == peer2_samples and averager2.local_times_accumulated == peer2_times
    ref_grads2 = torch.full((3,), 2 * 1 / 3 * averager2.local_times_accumulated)
    assert torch.allclose(next(averager2._grad_accumulators()), ref_grads2)
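
    # trigger the pre-scheduled all-reduce on both peers without blocking, then wait for completion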
    averager1.step(control=control1, wait=False)
    averager2.step(control=control2, wait=False)
    for step in (control1, control2):
        step.result()  # wait for all-reduce to finish
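
    # the expected result weights each peer's mean accumulated gradient by its share of total samples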
    peer1_weight = peer1_samples / (peer1_samples + peer2_samples)
    peer2_weight = peer2_samples / (peer1_samples + peer2_samples)
    ref_average = peer1_weight * (ref_grads1 / peer1_times) + peer2_weight * (ref_grads2 / peer2_times)
    with averager1.use_averaged_gradients():
        assert torch.allclose(model1.w.grad, ref_average)
    with averager2.use_averaged_gradients():
        assert torch.allclose(model2.w.grad, ref_average)

    # after exiting use_averaged_gradients(), .grad no longer holds the averaged values
    assert not torch.allclose(model1.w.grad, ref_average)
    assert not torch.allclose(model2.w.grad, ref_average)