# test_moe.py
  1. import torch
  2. import tesseract
  3. from test_utils.run_server import background_server
  4. def test_remote_module_call():
  5. """ Check that remote_module_call returns correct outputs and gradients if called directly """
  6. num_experts = 8
  7. k_min = 1
  8. timeout_after_k_min = None
  9. backward_k_min = 1
  10. timeout_total = None
  11. backward_timeout = None
  12. xx = torch.randn(32, 1024, requires_grad=True)
  13. logits = torch.randn(3, requires_grad=True)
  14. random_proj = torch.randn_like(xx)
  15. with background_server(num_experts=num_experts, no_optimizer=True, no_network=True) as (localhost, server_port):
  16. experts = [tesseract.RemoteExpert(uid=f'expert.{i}', port=server_port) for i in range(num_experts)]
  17. moe_output, = tesseract.client.moe._RemoteMoECall.apply(
  18. logits, experts[:len(logits)], k_min, timeout_after_k_min, backward_k_min, timeout_total, backward_timeout,
  19. [(None,), {}], xx)
  20. grad_xx_moe, = torch.autograd.grad(torch.sum(random_proj * moe_output), xx, retain_graph=True)
  21. grad_logits_moe, = torch.autograd.grad(torch.sum(random_proj * moe_output), logits, retain_graph=True)
  22. # reference outputs: call all experts manually and average their outputs with softmax probabilities
  23. probs = torch.softmax(logits, 0)
  24. outs = [expert(xx) for expert in experts[:3]]
  25. manual_output = sum(p * x for p, x in zip(probs, outs))
  26. grad_xx_manual, = torch.autograd.grad(torch.sum(random_proj * manual_output), xx, retain_graph=True)
  27. grad_xx_manual_rerun, = torch.autograd.grad(torch.sum(random_proj * manual_output), xx, retain_graph=True)
  28. grad_logits_manual, = torch.autograd.grad(torch.sum(random_proj * manual_output), logits, retain_graph=True)
  29. assert torch.allclose(grad_xx_manual, grad_xx_manual_rerun), "Experts are non-deterministic. This test is only " \
  30. "valid for deterministic experts"
  31. assert torch.allclose(moe_output, manual_output, rtol=1e-3, atol=1e-6), "_RemoteMoECall returned incorrect output"
  32. assert torch.allclose(grad_xx_moe, grad_xx_manual, rtol=1e-3, atol=1e-6), "incorrect gradient w.r.t. input"
  33. assert torch.allclose(grad_logits_moe, grad_logits_manual, rtol=1e-3, atol=1e-6), "incorrect gradient w.r.t. logits"
  34. if __name__ == '__main__':
  35. test_remote_module_call()