
add basic moe correctness test

justheuristic, 5 years ago
commit be3119b12e
1 changed file with 38 additions and 0 deletions

+ 38 - 0
tests/test_moe.py

@@ -0,0 +1,38 @@
+import torch
+import tesseract
+
+
+def test_remote_moe_call():
+    """ Check that _RemoteMoECall returns correct outputs and gradients when called directly """
+    xx = torch.randn(32, 1024, requires_grad=True)
+    logits = torch.randn(3, requires_grad=True)
+    random_proj = torch.randn_like(xx)
+    # TODO somehow start server on some port
+    experts = [tesseract.RemoteExpert(uid=f'expert.{i}', port=6565) for i in range(8)]
+
+    k_min = 1
+    timeout_after_k_min = None
+    backward_k_min = 1
+    timeout_total = None
+    backward_timeout = None
+    moe_output, = tesseract.client.moe._RemoteMoECall.apply(
+        logits, experts[:len(logits)], k_min, timeout_after_k_min, backward_k_min, timeout_total, backward_timeout,
+        [(None,), {}], xx)
+
+    grad_xx_moe, = torch.autograd.grad(torch.sum(random_proj * moe_output), xx, retain_graph=True)
+    grad_logits_moe, = torch.autograd.grad(torch.sum(random_proj * moe_output), logits, retain_graph=True)
+
+    # reference outputs: call the first len(logits) experts manually and average their outputs with softmax probabilities
+    probs = torch.softmax(logits, 0)
+    outs = [expert(xx) for expert in experts[:len(logits)]]
+    manual_output = sum(p * x for p, x in zip(probs, outs))
+    grad_xx_manual, = torch.autograd.grad(torch.sum(random_proj * manual_output), xx, retain_graph=True)
+    grad_xx_manual_rerun, = torch.autograd.grad(torch.sum(random_proj * manual_output), xx, retain_graph=True)
+    grad_logits_manual, = torch.autograd.grad(torch.sum(random_proj * manual_output), logits, retain_graph=True)
+
+    assert torch.allclose(moe_output, manual_output), "_RemoteMoECall returned incorrect output"
+    assert torch.allclose(grad_xx_manual, grad_xx_manual_rerun), "Experts are non-deterministic. This test is only " \
+                                                                 "valid for deterministic experts"
+    assert torch.allclose(grad_xx_moe, grad_xx_manual, rtol=1e-3, atol=1e-6), "incorrect gradient w.r.t. input"
+    assert torch.allclose(grad_logits_moe, grad_logits_manual, rtol=1e-3, atol=1e-6), "incorrect gradient w.r.t. logits"
+
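As a follow-up illustration (not part of the commit): the same softmax-mixture and random-projection gradient check can be exercised locally, without a running server. The sketch below uses deterministic nn.Linear modules as stand-ins for RemoteExpert; the module type and shapes are assumptions chosen for illustration only.

import torch
import torch.nn as nn

torch.manual_seed(0)
# deterministic local stand-ins for RemoteExpert (illustrative, not tesseract API)
experts = [nn.Linear(1024, 1024) for _ in range(3)]

xx = torch.randn(32, 1024, requires_grad=True)
logits = torch.randn(3, requires_grad=True)
random_proj = torch.randn_like(xx)

# softmax-weighted mixture of expert outputs, mirroring the reference branch of the test
probs = torch.softmax(logits, 0)
mixture = sum(p * expert(xx) for p, expert in zip(probs, experts))

# random-projection trick: differentiating <random_proj, output> w.r.t. the inputs
# compares Jacobian-vector products, a stronger check than comparing outputs alone
grad_xx, = torch.autograd.grad(torch.sum(random_proj * mixture), xx, retain_graph=True)
grad_logits, = torch.autograd.grad(torch.sum(random_proj * mixture), logits)
assert grad_xx.shape == xx.shape and grad_logits.shape == logits.shape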
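The TODO in the test (starting a server on some port) could eventually be handled by a pytest fixture. A rough sketch under stated assumptions: start_expert_server below is a hypothetical helper, not an existing tesseract API; whatever actually launches the server would need to host experts 'expert.0'..'expert.7' on port 6565 to match the RemoteExpert constructors in the test.

import pytest

@pytest.fixture(scope="module", autouse=True)
def expert_server():
    # HYPOTHETICAL: start_expert_server is not a real tesseract function; substitute
    # the actual server entry point once it exists. It should host experts
    # 'expert.0'..'expert.7' on port 6565 and return a handle that can be stopped.
    server = start_expert_server(expert_uids=[f'expert.{i}' for i in range(8)], port=6565)
    yield server
    server.shutdown()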