Kaynağa Gözat

shutdown dht in any casy

justheuristic 5 yıl önce
ebeveyn
işleme
d9f93cb18d
1 değiştirilmiş dosya ile 20 ekleme ve 20 silme
  1. 20 20
      tests/test_moe.py

+ 20 - 20
tests/test_moe.py

@@ -45,24 +45,24 @@ def test_remote_module_call():
 
 def test_compute_expert_scores():
     with background_server(device='cpu') as (server_addr, server_port, network_port):
-        dht = tesseract.TesseractNetwork(('localhost', network_port), port=tesseract.find_open_port(), start=True)
-        moe = tesseract.client.moe.RemoteMixtureOfExperts(
-            network=dht, in_features=1024, grid_size=[40], k_best=4, k_min=1, timeout_after_k_min=1,
-            uid_prefix='expert')
-        gx, gy = torch.randn(4, 5, requires_grad=True), torch.torch.randn(4, 3, requires_grad=True)
-        ii = [[4, 0, 2], [3, 1, 1, 1, 3], [0], [3, 2]]
-        jj = [[2, 2, 1], [0, 1, 2, 0, 1], [0], [1, 2]]
-        grid_scores = [gx, gy]
-        batch_experts = [
-            [tesseract.RemoteExpert(uid=f'expert.{ii[b][e]}.{jj[b][e]}') for e in range(len(ii[b]))]
-            for b in range(len(ii))
-        ]  # note: these experts do not exists on server, we use them only to test moe compute_expert_scores
-        logits = moe.compute_expert_scores([gx, gy], batch_experts)
-        torch.softmax(logits, dim=-1).norm(dim=-1).mean().backward()
-        assert gx.grad.norm().item() > 0 and gy.grad.norm().item(), "compute_expert_scores needs to backprop wrt grid scores"
+        try:
+            dht = tesseract.TesseractNetwork(('localhost', network_port), port=tesseract.find_open_port(), start=True)
+            moe = tesseract.client.moe.RemoteMixtureOfExperts(
+                network=dht, in_features=1024, grid_size=[40], k_best=4, k_min=1, timeout_after_k_min=1,
+                uid_prefix='expert')
+            gx, gy = torch.randn(4, 5, requires_grad=True), torch.torch.randn(4, 3, requires_grad=True)
+            ii = [[4, 0, 2], [3, 1, 1, 1, 3], [0], [3, 2]]
+            jj = [[2, 2, 1], [0, 1, 2, 0, 1], [0], [1, 2]]
+            batch_experts = [
+                [tesseract.RemoteExpert(uid=f'expert.{ii[b][e]}.{jj[b][e]}') for e in range(len(ii[b]))]
+                for b in range(len(ii))
+            ]  # note: these experts do not exists on server, we use them only to test moe compute_expert_scores
+            logits = moe.compute_expert_scores([gx, gy], batch_experts)
+            torch.softmax(logits, dim=-1).norm(dim=-1).mean().backward()
+            assert gx.grad.norm().item() > 0 and gy.grad.norm().item(), "compute_expert_scores didn't backprop"
 
-        for b in range(len(ii)):
-            for e in range(len(ii[b])):
-                assert torch.allclose(logits[b, e], gx[b, ii[b][e]] + gy[b, jj[b][e]]), "compute_expert_scores returned incorrect score"
-
-        dht.shutdown()
+            for b in range(len(ii)):
+                for e in range(len(ii[b])):
+                    assert torch.allclose(logits[b, e], gx[b, ii[b][e]] + gy[b, jj[b][e]]), "compute_expert_scores returned incorrect score"
+        finally:
+            dht.shutdown()