@@ -149,21 +149,16 @@ def test_state_averager(offload_optimizer: bool, reuse_tensors: bool, sync_epoch
 
 
 @pytest.mark.forked
-def test_load_state_from_peers(dpu: bool = False):
+def test_load_state_from_peers():
     dht1 = hivemind.DHT(start=True)
     dht2 = hivemind.DHT(initial_peers=dht1.get_visible_maddrs(), start=True)
 
     model1 = nn.Linear(2, 3)
     model2 = nn.Linear(2, 3)
 
-    extras1 = (torch.randn(2, 2), -torch.rand(1))
-    extras2 = (-torch.randn(2, 2), torch.rand(1))
-
     common_kwargs = dict(
         optimizer=partial(torch.optim.SGD, lr=0.1),
         scheduler=partial(torch.optim.lr_scheduler.LambdaLR, lr_lambda=lambda t: 1.0 / max(1, t)),
-        offload_optimizer=dpu,
-        reuse_tensors=dpu,
         target_group_size=2,
         prefix="my_exp",
     )
@@ -173,25 +168,19 @@ def test_load_state_from_peers(dpu: bool = False):
         params=model1.parameters(),
         allow_state_sharing=False,
         start=True,
-        extra_tensors=extras1,
         **common_kwargs,
     )
 
-    avgr2 = TrainingStateAverager(
-        dht=dht2, params=model2.parameters(), start=True, extra_tensors=extras2, **common_kwargs
-    )
+    avgr2 = TrainingStateAverager(dht=dht2, params=model2.parameters(), start=True, **common_kwargs)
 
     avgr2.local_epoch = 1337
     model2.weight.data[...] = 42
-    extras2[0][:] = 9999
     time.sleep(0.1)
 
     avgr1.load_state_from_peers()
     assert avgr1.local_epoch == 1337
     assert torch.all(model1.weight == 42).item()
     assert np.allclose(avgr1.optimizer.param_groups[0]["lr"], 0.1 / 1337)
-    assert torch.all(extras1[0] == extras2[0]).item() and torch.all(extras1[0] == extras2[0]).item()
-    assert torch.all(extras1[0] == 9999).item()
 
 
 @pytest.mark.forked