import time

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F

import hivemind
from hivemind.averaging.control import AveragingStage
from hivemind.optim.experimental.grad_averager import GradientAverager


@pytest.mark.forked
def test_grad_averager():
    # peer 1: keeps gradients in separate accumulators (reuse_grad_buffers=False)
    dht1 = hivemind.DHT(start=True)
    model1 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(3))})
    averager1 = GradientAverager(
        model1.parameters(), dht=dht1, prefix="test", target_group_size=2, reuse_grad_buffers=False, start=True
    )

    # peer 2: accumulates directly in the parameters' .grad buffers (reuse_grad_buffers=True)
    dht2 = hivemind.DHT(start=True, initial_peers=dht1.get_visible_maddrs())
    model2 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(3))})
    averager2 = GradientAverager(
        model2.parameters(), dht=dht2, prefix="test", target_group_size=2, reuse_grad_buffers=True, start=True
    )

    # pre-schedule one averaging round; both peers should end up in the same group
    control1 = averager1.schedule_step(hivemind.get_dht_time() + 5)
    control2 = averager2.schedule_step(hivemind.get_dht_time() + 5)

    for i in range(10):
        time.sleep(0.1)
        if i % 3 == 0:
            loss1 = F.mse_loss(model1.w, torch.ones(3))
            loss1.backward()
            averager1.accumulate_grads_(batch_size=2)  # total: 4 times * 2 samples = 8
            model1.zero_grad()
        else:
            loss2 = F.mse_loss(model2.w, -torch.ones(3))
            loss2.backward()
            averager2.accumulate_grads_(batch_size=3)  # total: 6 times * 3 samples = 18
            # note: we do not call zero_grad here because reuse_grad_buffers=True

    assert control1.stage == control2.stage == AveragingStage.AWAITING_TRIGGER
    peer1_samples, peer1_times, peer2_samples, peer2_times = 8, 4, 18, 6

    assert averager1.local_samples_accumulated == peer1_samples and averager1.local_times_accumulated == peer1_times
    ref_grads1 = torch.full((3,), -2 * 1 / 3 * averager1.local_times_accumulated)
    assert torch.allclose(next(averager1._grad_accumulators()), ref_grads1)

    assert averager2.local_samples_accumulated == peer2_samples and averager2.local_times_accumulated == peer2_times
    ref_grads2 = torch.full((3,), 2 * 1 / 3 * averager2.local_times_accumulated)
    assert torch.allclose(next(averager2._grad_accumulators()), ref_grads2)

    averager1.step(control=control1, wait=False)
    averager2.step(control=control2, wait=False)
    for step in (control1, control2):
        step.result()  # wait for all-reduce to finish

    # expected result: sample-weighted average of each peer's gradients, normalized by its number of accumulations
    peer1_weight = peer1_samples / (peer1_samples + peer2_samples)
    peer2_weight = peer2_samples / (peer1_samples + peer2_samples)
    ref_average = peer1_weight * (ref_grads1 / peer1_times) + peer2_weight * (ref_grads2 / peer2_times)
    with averager1.use_averaged_gradients():
        assert torch.allclose(model1.w.grad, ref_average)
    with averager2.use_averaged_gradients():
        assert torch.allclose(model2.w.grad, ref_average)

    # outside the use_averaged_gradients context, the original local gradients are restored
    assert not torch.allclose(model1.w.grad, ref_average)
    assert not torch.allclose(model2.w.grad, ref_average)