test_moe.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738
  1. import torch
  2. import tesseract
  3. def test_remote_module_call():
  4. """ Check that remote_module_call returns correct outputs and gradients if called directly """
  5. xx = torch.randn(32, 1024, requires_grad=True)
  6. logits = torch.randn(3, requires_grad=True)
  7. random_proj = torch.randn_like(xx)
  8. # TODO somehow start server on some port
  9. experts = [tesseract.RemoteExpert(uid=f'expert.{i}', port=6565) for i in range(8)]
  10. k_min = 1
  11. timeout_after_k_min = None
  12. backward_k_min = 1
  13. timeout_total = None
  14. backward_timeout = None
  15. moe_output, = tesseract.client.moe._RemoteMoECall.apply(
  16. logits, experts[:len(logits)], k_min, timeout_after_k_min, backward_k_min, timeout_total, backward_timeout,
  17. [(None,), {}], xx)
  18. grad_xx_moe, = torch.autograd.grad(torch.sum(random_proj * moe_output), xx, retain_graph=True)
  19. grad_logits_moe, = torch.autograd.grad(torch.sum(random_proj * moe_output), xx, retain_graph=True)
  20. # reference outputs: call all experts manually and average their outputs with softmax probabilities
  21. probs = torch.softmax(logits, 0)
  22. outs = [expert(xx) for expert in experts[:3]]
  23. manual_output = sum(p * x for p, x in zip(probs, outs))
  24. grad_xx_manual, = torch.autograd.grad(torch.sum(random_proj * manual_output), xx, retain_graph=True)
  25. grad_xx_manual_rerun, = torch.autograd.grad(torch.sum(random_proj * manual_output), xx, retain_graph=True)
  26. grad_logits_manual, = torch.autograd.grad(torch.sum(random_proj * manual_output), logits, retain_graph=True)
  27. assert torch.allclose(moe_output, manual_output), "_RemoteMoECall returned incorrect output"
  28. assert torch.allclose(grad_xx_manual, grad_xx_manual_rerun), "Experts are non-deterministic. This test is only " \
  29. "valid for deterministic experts"
  30. assert torch.allclose(grad_xx_moe, grad_xx_manual, rtol=1e-3, atol=1e-6), "incorrect gradient w.r.t. input"
  31. assert torch.allclose(grad_logits_moe, grad_logits_manual, rtol=1e-3, atol=1e-6), "incorrect gradient w.r.t. logits"