import ctypes
import multiprocessing as mp
import time
from functools import partial

import numpy as np
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F

import hivemind
from hivemind.averaging.control import AveragingStage
from hivemind.optim.experimental.grad_averager import GradientAverager
from hivemind.optim.experimental.optimizer import Optimizer
from hivemind.optim.experimental.progress_tracker import ProgressTracker
from hivemind.optim.experimental.state_averager import TrainingStateAverager
from hivemind.utils.crypto import RSAPrivateKey


@pytest.mark.forked
def test_grad_averager():
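    """
    Two GradientAverager peers accumulate local gradients at different rates, then all-reduce them into a
    common average weighted by the number of samples each peer contributed.
    """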
    dht1 = hivemind.DHT(start=True)
    model1 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(3))})
    averager1 = GradientAverager(
        model1.parameters(), dht=dht1, prefix="test", target_group_size=2, reuse_grad_buffers=False, start=True
    )

    dht2 = hivemind.DHT(start=True, initial_peers=dht1.get_visible_maddrs())
    model2 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(3))})
    averager2 = GradientAverager(
        model2.parameters(), dht=dht2, prefix="test", target_group_size=2, reuse_grad_buffers=True, start=True
    )
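
    # pre-schedule an averaging round a few seconds ahead so both peers can form a group before it is triggered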
    control1 = averager1.schedule_step(hivemind.get_dht_time() + 5)
    control2 = averager2.schedule_step(hivemind.get_dht_time() + 5)

    for i in range(10):
        time.sleep(0.1)
        if i % 3 == 0:
            loss1 = F.mse_loss(model1.w, torch.ones(3))
            loss1.backward()
            averager1.accumulate_grads_(batch_size=2)  # total: 4 times * 2 samples = 8
            model1.zero_grad()
        else:
            loss2 = F.mse_loss(model2.w, -torch.ones(3))
            loss2.backward()
            averager2.accumulate_grads_(batch_size=3)  # total: 6 times * 3 samples = 18
            # note: we do not call zero grad here because reuse_grad_buffers=True

    assert control1.stage == control2.stage == AveragingStage.AWAITING_TRIGGER

    peer1_samples, peer1_times, peer2_samples, peer2_times = 8, 4, 18, 6
    assert averager1.local_samples_accumulated == peer1_samples and averager1.local_times_accumulated == peer1_times
    ref_grads1 = torch.full((3,), -2 * 1 / 3 * averager1.local_times_accumulated)
    assert torch.allclose(next(averager1._grad_accumulators()), ref_grads1)

    assert averager2.local_samples_accumulated == peer2_samples and averager2.local_times_accumulated == peer2_times
    ref_grads2 = torch.full((3,), 2 * 1 / 3 * averager2.local_times_accumulated)
    assert torch.allclose(next(averager2._grad_accumulators()), ref_grads2)

    averager1.step(control=control1, wait=False)
    averager2.step(control=control2, wait=False)
    for step in (control1, control2):
        step.result()  # wait for all-reduce to finish

    peer1_weight = peer1_samples / (peer1_samples + peer2_samples)
    peer2_weight = peer2_samples / (peer1_samples + peer2_samples)
    ref_average = peer1_weight * (ref_grads1 / peer1_times) + peer2_weight * (ref_grads2 / peer2_times)
    with averager1.use_averaged_gradients():
        assert torch.allclose(model1.w.grad, ref_average)
    with averager2.use_averaged_gradients():
        assert torch.allclose(model2.w.grad, ref_average)

    # outside the use_averaged_gradients context, the original (non-averaged) gradients are restored
    assert not torch.allclose(model1.w.grad, ref_average)
    assert not torch.allclose(model2.w.grad, ref_average)


@pytest.mark.forked
@pytest.mark.parametrize(
    "offload_optimizer, reuse_tensors, sync_epoch_when_averaging",
    [(False, False, False), (True, True, False), (True, False, False), (False, True, True), (True, False, True)],
)
def test_state_averager(offload_optimizer: bool, reuse_tensors: bool, sync_epoch_when_averaging: bool):
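    """
    Two TrainingStateAverager peers run local Adam steps, average parameters, selected optimizer statistics
    and extra tensors once mid-training and once at the end, and must end up with matching replicas.
    """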
    dht1 = hivemind.DHT(start=True)
    dht2 = hivemind.DHT(initial_peers=dht1.get_visible_maddrs(), start=True)

    torch.manual_seed(1337)
    torch.use_deterministic_algorithms(True)
    # note: use_deterministic_algorithms does not affect further tests because this test is forked

    model1 = nn.Linear(2, 3)
    model2 = nn.Linear(2, 3)
    extras1 = (torch.randn(2, 2), -torch.rand(1))
    extras2 = (-torch.randn(2, 2), torch.rand(1))

    common_kwargs = dict(
        optimizer=partial(torch.optim.Adam, lr=0.1, betas=(0.9, 0.9)),
        scheduler=partial(torch.optim.lr_scheduler.LambdaLR, lr_lambda=lambda t: 1.0 / max(1, t)),
        sync_epoch_when_averaging=sync_epoch_when_averaging,
        average_opt_statistics=("exp_avg_sq",),
        offload_optimizer=offload_optimizer,
        reuse_tensors=reuse_tensors,
        target_group_size=2,
        prefix="my_exp",
    )

    avgr1 = TrainingStateAverager(
        dht=dht1, params=model1.parameters(), extra_tensors=extras1, start=True, **common_kwargs
    )
    avgr2 = TrainingStateAverager(
        dht=dht2, params=model2.parameters(), extra_tensors=extras2, start=True, **common_kwargs
    )
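
    # in the loop below, peer 1 runs its averaging round in the background (delay_averaging=True),
    # while peer 2 averages in the foreground (delay_averaging=False)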
    x = torch.ones(2)

    for step in range(20):
        F.mse_loss(model1(x), torch.ones(3)).mul(2).backward()
        avgr1.step(optimizer_step=True, zero_grad=True, averaging_round=(step == 10), delay_averaging=True)

        F.mse_loss(model2(x), -torch.ones(3)).backward()
        avgr2.step(optimizer_step=True, zero_grad=True, averaging_round=(step == 10), delay_averaging=False)

    assert torch.all(model1.weight.grad == 0) and torch.all(model2.weight.grad == 0), "zero grad did not trigger"
    assert model1(x).mean() > 0.5 and model2(x).mean() < -0.5, "models did not train properly"
    assert torch.allclose(extras1[0], extras2[0]), "first extra tensors were not averaged"
    assert torch.allclose(extras1[1], extras2[1]), "second extra tensors were not averaged"

    stats1 = avgr1.optimizer.state_dict()["state"][0]["exp_avg_sq"].clone()
    stats2 = avgr2.optimizer.state_dict()["state"][0]["exp_avg_sq"].clone()
    assert not torch.allclose(stats1, stats2)

    avgr1.step(increment_epoch=True)

    avgr1.step(increment_epoch=True, averaging_round=True, delay_averaging=True)
    avgr2.step(increment_epoch=True, averaging_round=True, delay_averaging=True)

    avgr1.step(wait_for_delayed_updates=True)
    avgr2.step(wait_for_delayed_updates=True)

    assert torch.allclose(model1(x), model2(x)), "model parameters were not averaged correctly"
    assert torch.allclose(avgr1.optimizer.state_dict()["state"][0]["exp_avg_sq"], (stats1 + stats2) / 2)
    assert torch.allclose(avgr2.optimizer.state_dict()["state"][0]["exp_avg_sq"], (stats1 + stats2) / 2)
    assert avgr1.local_epoch == 2
    assert avgr2.local_epoch == (2 if sync_epoch_when_averaging else 1)


@pytest.mark.forked
def test_load_state_from_peers():
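    """
    Check that a peer with allow_state_sharing=False can download the current epoch, model parameters and
    scheduler position from another peer via load_state_from_peers.
    """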
    dht1 = hivemind.DHT(start=True)
    dht2 = hivemind.DHT(initial_peers=dht1.get_visible_maddrs(), start=True)

    model1 = nn.Linear(2, 3)
    model2 = nn.Linear(2, 3)

    common_kwargs = dict(
        optimizer=partial(torch.optim.SGD, lr=0.1),
        scheduler=partial(torch.optim.lr_scheduler.LambdaLR, lr_lambda=lambda t: 1.0 / max(1, t)),
        target_group_size=2,
        prefix="my_exp",
    )

    avgr1 = TrainingStateAverager(
        dht=dht1, params=model1.parameters(), allow_state_sharing=False, start=True, **common_kwargs
    )
    avgr2 = TrainingStateAverager(dht=dht2, params=model2.parameters(), start=True, **common_kwargs)

    avgr2.local_epoch = 1337
    model2.weight.data[...] = 42
    time.sleep(0.1)

    avgr1.load_state_from_peers()
    assert avgr1.local_epoch == 1337
    assert torch.all(model1.weight == 42).item()
    assert np.allclose(avgr1.optimizer.param_groups[0]["lr"], 0.1 / 1337)


@pytest.mark.forked
def test_progress_tracker():
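    """
    Several workers report local progress at different rates; ProgressTracker must aggregate them into global
    epochs, with epochs completing faster while the fastest worker participates and slower workers reporting
    lower samples-per-second EMAs.
    """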
    # note to a curious reader: no, you cannot reduce the timings without compromising realism or stability
    prefix = "my_exp"
    target_batch_size = 256
    dht_root = hivemind.DHT(start=True)
    barrier = mp.Barrier(parties=5)
    delayed_start_evt = mp.Event()
    finished_evt = mp.Event()
    emas = mp.Array(ctypes.c_double, 5)
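
    # each worker reports its local progress and advances its epoch when the tracker signals that the
    # collaboration reached target_batch_size; worker 4 joins late and leaves early to vary the global throughput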
    def run_worker(index: int, batch_size: int, period: float, **kwargs):
        dht = hivemind.DHT(initial_peers=dht_root.get_visible_maddrs(), start=True)
        tracker = ProgressTracker(
            dht,
            prefix,
            target_batch_size,
            start=True,
            min_refresh_period=0.1,
            default_refresh_period=0.2,
            max_refresh_period=0.5,
            private_key=RSAPrivateKey(),
            **kwargs,
        )

        barrier.wait()
        if index == 4:
            delayed_start_evt.wait()

        local_epoch = 2 if index == 4 else 0
        samples_accumulated = 0

        while True:
            time.sleep(period)
            if finished_evt.is_set():
                break

            samples_accumulated += batch_size
            tracker.report_local_progress(local_epoch, samples_accumulated)

            if tracker.ready_to_update_epoch:
                if index == 4 and local_epoch >= 4:
                    time.sleep(0.5)
                    break

                with tracker.pause_updates():
                    local_epoch = tracker.update_epoch(local_epoch + 1)
                    samples_accumulated = 0

        emas[index] = tracker.performance_ema.samples_per_second
        tracker.shutdown()
        dht.shutdown()

    workers = [
        mp.Process(target=run_worker, kwargs=dict(index=1, batch_size=12, period=0.6)),
        mp.Process(target=run_worker, kwargs=dict(index=2, batch_size=16, period=0.5)),
        mp.Process(target=run_worker, kwargs=dict(index=3, batch_size=24, period=0.4)),
        mp.Process(target=run_worker, kwargs=dict(index=4, batch_size=64, period=0.4)),
    ]
    for worker in workers:
        worker.start()

    tracker = ProgressTracker(
        dht_root,
        prefix,
        target_batch_size,
        start=True,
        min_refresh_period=0.1,
        default_refresh_period=0.2,
        max_refresh_period=0.5,
    )
    barrier.wait()

    local_epoch = 0
    last_timestamp = hivemind.get_dht_time()
    step_time_deltas = []

    while local_epoch < 6:
        time.sleep(0.1)

        if tracker.ready_to_update_epoch:
            with tracker.pause_updates():
                local_epoch = tracker.update_epoch(local_epoch + 1)

            time_delta = hivemind.get_dht_time() - last_timestamp
            if local_epoch == 2:
                delayed_start_evt.set()

            last_timestamp = hivemind.get_dht_time()
            step_time_deltas.append(time_delta)

    finished_evt.set()
    for worker in workers:
        worker.join()

    tracker.shutdown()
    dht_root.shutdown()
    assert not tracker.is_alive()

    mean_step_time = sum(step_time_deltas) / len(step_time_deltas)
    for i in (0, 1, 5):  # without the 4th worker (the fastest one)
        assert 1.05 * mean_step_time < step_time_deltas[i] < 2.0 * mean_step_time
    for i in (2, 3, 4):  # with the 4th worker
        assert 0.5 * mean_step_time < step_time_deltas[i] < 0.95 * mean_step_time
    assert emas[1] < emas[2] < emas[3] < emas[4]
    assert tracker.performance_ema.samples_per_second < 1e-9


@pytest.mark.forked
def test_optimizer(
    num_peers: int = 1,
    num_clients: int = 0,
    target_batch_size: int = 32,
    total_epochs: int = 3,
    reuse_grad_buffers: bool = True,
    delay_grad_averaging: bool = True,
    delay_optimizer_step: bool = True,
    average_state_every: int = 1,
):
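    """
    End-to-end test for hivemind.Optimizer: one or more trainers fit a small linear regression model,
    accumulating gradients towards target_batch_size, then verify that the global epoch, sample throughput,
    and background components are in the expected state after shutdown.
    """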
    dht = hivemind.DHT(start=True)

    features = torch.randn(100, 5)
    targets = features @ torch.randn(5, 1)
    optimizer = None
    total_samples_accumulated = mp.Value(ctypes.c_int32, 0)

    def run_trainer(batch_size: int, batch_time: float, client_mode: bool):
        nonlocal optimizer
        model = nn.Linear(5, 1)

        assert isinstance(model, torch.nn.Module), "model_arch must evaluate to a pytorch module"

        optimizer = Optimizer(
            run_id="test_run",
            target_batch_size=target_batch_size,
            batch_size_per_step=batch_size,
            params=model.parameters(),
            optimizer=partial(torch.optim.SGD, lr=0.1),
            scheduler=partial(torch.optim.lr_scheduler.StepLR, gamma=0.5, step_size=1),
            dht=hivemind.DHT(initial_peers=dht.get_visible_maddrs(), client_mode=client_mode, start=True),
            tracker_opts=dict(private_key=RSAPrivateKey(), max_refresh_period=1.0),
            averager_opts=dict(min_matchmaking_time=1.0, request_timeout=0.5),
            matchmaking_time=1.0,
            averaging_timeout=5.0,
            reuse_grad_buffers=reuse_grad_buffers,
            delay_grad_averaging=delay_grad_averaging,
            delay_optimizer_step=delay_optimizer_step,
            average_state_every=average_state_every,
            client_mode=client_mode,
            verbose=False,
        )
        optimizer.load_state_from_peers()

        prev_time = time.perf_counter()

        while optimizer.local_epoch < total_epochs:
            time.sleep(max(0.0, prev_time + batch_time - time.perf_counter()))

            batch = torch.randint(0, len(features), (batch_size,))
            loss = F.mse_loss(model(features[batch]), targets[batch])
            loss.backward()

            optimizer.step()
            total_samples_accumulated.value += batch_size

            if not reuse_grad_buffers:
                optimizer.zero_grad()

            prev_time = time.perf_counter()

        time.sleep(1.0)
        optimizer.shutdown()
        return optimizer

    peers = []
    for index in range(num_peers):
        peers.append(
            mp.Process(
                target=run_trainer,
                name=f"trainer-{index}",
                kwargs=dict(
                    batch_size=4 + index,
                    batch_time=0.3 + 0.2 * index,
                    client_mode=(index >= num_peers - num_clients),
                ),
            )
        )
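
    # run the remaining peers as subprocesses, but execute the first trainer in the current process
    # so that the `optimizer` assigned via `nonlocal` is available for the assertions below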
    for peer in peers[1:]:
        peer.start()
    peers[0].run()
    for peer in peers[1:]:
        peer.join()

    assert isinstance(optimizer, Optimizer)
    assert optimizer.local_epoch == optimizer.tracker.global_epoch == total_epochs

    expected_samples_accumulated = target_batch_size * total_epochs
    assert expected_samples_accumulated <= total_samples_accumulated.value <= expected_samples_accumulated * 1.2
    assert 4 / 0.3 * 0.8 <= optimizer.tracker.performance_ema.samples_per_second <= 4 / 0.3 * 1.2

    assert not optimizer.state_averager.is_alive()
    assert not optimizer.grad_averager.is_alive()
    assert not optimizer.tracker.is_alive()
    assert optimizer.scheduled_grads is None or optimizer.scheduled_grads.done()
|