Explorar o código

add a unit test that fails (but shouldn't)

Artem Chumachenko %!s(int64=3) %!d(string=hai) anos
pai
achega
5244d8a96f
Modificáronse 1 ficheiros con 14 adicións e 3 borrados
  1. 14 3
      tests/test_optimizer.py

+ 14 - 3
tests/test_optimizer.py

@@ -147,34 +147,45 @@ def test_state_averager(offload_optimizer: bool, reuse_tensors: bool, sync_epoch
 
 
 @pytest.mark.forked
-def test_load_state_from_peers():
+@pytest.mark.parametrize("dpu", [True, False])
+def test_load_state_from_peers(dpu: bool):
     dht1 = hivemind.DHT(start=True)
     dht2 = hivemind.DHT(initial_peers=dht1.get_visible_maddrs(), start=True)
 
     model1 = nn.Linear(2, 3)
     model2 = nn.Linear(2, 3)
 
+    extras1 = (torch.randn(2, 2), -torch.rand(1))
+    extras2 = (-torch.randn(2, 2), torch.rand(1))
+
     common_kwargs = dict(
         optimizer=partial(torch.optim.SGD, lr=0.1),
         scheduler=partial(torch.optim.lr_scheduler.LambdaLR, lr_lambda=lambda t: 1.0 / max(1, t)),
+        offload_optimizer=dpu,
+        reuse_tensors=dpu,
         target_group_size=2,
         prefix="my_exp",
     )
 
     avgr1 = TrainingStateAverager(
-        dht=dht1, params=model1.parameters(), allow_state_sharing=False, start=True, **common_kwargs
+        dht=dht1, params=model1.parameters(), allow_state_sharing=False, start=True, extra_tensors=extras1,
+        **common_kwargs
     )
 
-    avgr2 = TrainingStateAverager(dht=dht2, params=model2.parameters(), start=True, **common_kwargs)
+    avgr2 = TrainingStateAverager(dht=dht2, params=model2.parameters(), start=True,  extra_tensors=extras2,
+                                  **common_kwargs)
 
     avgr2.local_epoch = 1337
     model2.weight.data[...] = 42
+    extras2[0][:] = 9999
     time.sleep(0.1)
 
     avgr1.load_state_from_peers()
     assert avgr1.local_epoch == 1337
     assert torch.all(model1.weight == 42).item()
     assert np.allclose(avgr1.optimizer.param_groups[0]["lr"], 0.1 / 1337)
+    assert torch.all(extras1[0] == extras2[0]).item() and torch.all(extras1[0] == extras2[0]).item()
+    assert torch.all(extras1[0] == 9999).item()
 
 
 @pytest.mark.forked