Эх сурвалжийг харах

add test_compute_expert_scores

justheuristic 5 жил өмнө
parent
commit
2600d3c956
1 өөрчлөгдсөн 24 нэмэгдсэн , 2 устгасан
  1. 24 2
      tests/test_moe.py

+ 24 - 2
tests/test_moe.py

@@ -43,5 +43,27 @@ def test_remote_module_call():
     assert torch.allclose(grad_logits_moe, grad_logits_manual, rtol, atol), "incorrect gradient w.r.t. logits"
 
 
-if __name__ == '__main__':
-    test_remote_module_call()
+def test_compute_expert_scores():
+    with background_server() as (server_addr, server_port, network_port):
+        dht = tesseract.TesseractNetwork(('localhost', network_port), port=tesseract.find_open_port(), start=True)
+        moe = tesseract.client.moe.RemoteMixtureOfExperts(
+            network=dht, in_features=1024, grid_size=[40], k_best=4, k_min=1, timeout_after_k_min=1,
+            uid_prefix='expert')
+        gx, gy = torch.randn(4, 5, requires_grad=True), torch.randn(4, 3, requires_grad=True)
+        ii = [[4, 0, 2], [3, 1, 1, 1, 3], [0], [3, 2]]
+        jj = [[2, 2, 1], [0, 1, 2, 0, 1], [0], [1, 2]]
+        grid_scores = [gx, gy]
+        batch_experts = [
+            [tesseract.RemoteExpert(uid=f'expert.{ii[b][e]}.{jj[b][e]}') for e in range(len(ii[b]))]
+            for b in range(len(ii))
+        ]  # note: these experts do not exist on the server; we use them only to test moe.compute_expert_scores
+        logits = moe.compute_expert_scores([gx, gy], batch_experts)
+        torch.softmax(logits, dim=-1).norm(dim=-1).mean().backward()
+        assert gx.grad.norm().item() > 0 and gy.grad.norm().item() > 0, "compute_expert_scores needs to backprop w.r.t. grid scores"
+
+        for b in range(len(ii)):
+            for e in range(len(ii[b])):
+                print(end='.')
+                assert torch.allclose(logits[b, e], gx[b, ii[b][e]] + gy[b, jj[b][e]]), "compute_expert_scores returned incorrect score"
+
+        dht.shutdown()