test_moe.py

import torch
import hivemind
from test_utils.run_server import background_server

def test_remote_module_call():
    """ Check that remote_module_call returns correct outputs and gradients if called directly """
    num_experts = 8
    k_min = 1
    timeout_after_k_min = None
    backward_k_min = 1
    timeout_total = None
    backward_timeout = None
    rtol = 1e-3
    atol = 1e-6

    xx = torch.randn(32, 1024, requires_grad=True)
    logits = torch.randn(3, requires_grad=True)
    random_proj = torch.randn_like(xx)

    with background_server(num_experts=num_experts, device='cpu',
                           no_optimizer=True, no_dht=True) as (localhost, server_port, dht_port):
        experts = [hivemind.RemoteExpert(uid=f'expert.{i}', port=server_port) for i in range(num_experts)]
        moe_output, = hivemind.client.moe._RemoteMoECall.apply(
            logits, experts[:len(logits)], k_min, timeout_after_k_min, backward_k_min, timeout_total, backward_timeout,
            [(None,), {}], xx)

        grad_xx_moe, = torch.autograd.grad(torch.sum(random_proj * moe_output), xx, retain_graph=True)
        grad_logits_moe, = torch.autograd.grad(torch.sum(random_proj * moe_output), logits, retain_graph=True)

        # reference outputs: call all experts manually and average their outputs with softmax probabilities
        probs = torch.softmax(logits, 0)
        outs = [expert(xx) for expert in experts[:3]]
        manual_output = sum(p * x for p, x in zip(probs, outs))
        grad_xx_manual, = torch.autograd.grad(torch.sum(random_proj * manual_output), xx, retain_graph=True)
        grad_xx_manual_rerun, = torch.autograd.grad(torch.sum(random_proj * manual_output), xx, retain_graph=True)
        grad_logits_manual, = torch.autograd.grad(torch.sum(random_proj * manual_output), logits, retain_graph=True)

        assert torch.allclose(grad_xx_manual, grad_xx_manual_rerun, rtol, atol), \
            "Experts are non-deterministic. The test is only valid for deterministic experts"
        assert torch.allclose(moe_output, manual_output, rtol, atol), "_RemoteMoECall returned incorrect output"
        assert torch.allclose(grad_xx_moe, grad_xx_manual, rtol, atol), "incorrect gradient w.r.t. input"
        assert torch.allclose(grad_logits_moe, grad_logits_manual, rtol, atol), "incorrect gradient w.r.t. logits"
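

# Illustration (not part of the original test): the reference computation above boils down to a
# softmax-weighted mixture, output = sum_i softmax(logits)_i * expert_i(x). The sketch below
# checks the same gradient wiring with plain local nn.Linear "experts", so it runs without a
# server. The helper name _local_moe_reference is ours, not part of hivemind.
def _local_moe_reference():
    import torch.nn as nn
    torch.manual_seed(0)
    experts = [nn.Linear(16, 16) for _ in range(3)]
    x = torch.randn(4, 16, requires_grad=True)
    logits = torch.randn(3, requires_grad=True)
    probs = torch.softmax(logits, 0)
    mixture = sum(p * expert(x) for p, expert in zip(probs, experts))
    # gradients flow both into the input and into the gating logits through the softmax
    grad_x, grad_logits = torch.autograd.grad(mixture.sum(), (x, logits))
    assert grad_x.shape == x.shape and grad_logits.shape == logits.shape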


def test_determinism():
    rtol = 0
    atol = 1e-6

    xx = torch.randn(32, 1024, requires_grad=True)
    mask = torch.randint(0, 2, (32, 1024))  # random binary mask; randint(0, 1) would always yield zeros

    with background_server(num_experts=1, device='cpu', expert_cls='det_dropout',
                           no_optimizer=True, no_dht=True) as (interface, server_port, dht_port):
        expert = hivemind.RemoteExpert(uid='expert.0', port=server_port)

        out = expert(xx, mask)
        out_rerun = expert(xx, mask)

        grad, = torch.autograd.grad(out.sum(), xx, retain_graph=True)
        grad_rerun, = torch.autograd.grad(out_rerun.sum(), xx, retain_graph=True)

        assert torch.allclose(out, out_rerun, rtol, atol), "Dropout layer outputs are non-deterministic."
        assert torch.allclose(grad, grad_rerun, rtol, atol), "Gradients are non-deterministic."
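

# Illustration (an assumption, not hivemind's actual det_dropout implementation): a deterministic
# dropout layer takes the binary mask as an explicit input instead of sampling it internally, so
# repeated calls with the same (x, mask) pair produce identical outputs and gradients - the
# property the test above relies on. Class name and drop_prob default are ours.
class _DeterministicDropoutSketch(torch.nn.Module):
    def __init__(self, drop_prob: float = 0.5):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x, mask):
        # scale kept units by 1 / (1 - p), as in standard inverted dropout
        return x * mask / (1 - self.drop_prob)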


def test_compute_expert_scores():
    try:
        dht = hivemind.DHTNode(port=hivemind.find_open_port(), start=True)
        moe = hivemind.client.moe.RemoteMixtureOfExperts(
            dht=dht, in_features=1024, grid_size=(40,), k_best=4, k_min=1, timeout_after_k_min=1,
            uid_prefix='expert')
        gx, gy = torch.randn(4, 5, requires_grad=True), torch.randn(4, 3, requires_grad=True)
        ii = [[4, 0, 2], [3, 1, 1, 1, 3], [0], [3, 2]]
        jj = [[2, 2, 1], [0, 1, 2, 0, 1], [0], [1, 2]]
        batch_experts = [
            [hivemind.RemoteExpert(uid=f'expert.{ii[batch_i][expert_i]}.{jj[batch_i][expert_i]}')
             for expert_i in range(len(ii[batch_i]))]
            for batch_i in range(len(ii))
        ]  # note: these experts do not exist on the server; we use them only to test compute_expert_scores
        logits = moe.compute_expert_scores([gx, gy], batch_experts)
        torch.softmax(logits, dim=-1).norm(dim=-1).mean().backward()
        assert gx.grad.norm().item() > 0 and gy.grad.norm().item() > 0, "compute_expert_scores didn't backprop"

        for batch_i in range(len(ii)):
            for expert_i in range(len(ii[batch_i])):
                assert torch.allclose(logits[batch_i, expert_i],
                                      gx[batch_i, ii[batch_i][expert_i]] + gy[batch_i, jj[batch_i][expert_i]]), \
                    "compute_expert_scores returned incorrect score"
    finally:
        dht.shutdown()
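

# Illustration (inferred from the assertions above, not hivemind's actual code): with a 2D expert
# grid, each uid encodes the expert's coordinates ('expert.<i>.<j>'), and its score is the sum of
# the per-dimension grid scores. A minimal per-sample reference; the helper name is ours.
def _grid_score_reference(gx_row, gy_row, ii_row, jj_row):
    # one score per chosen (i, j) coordinate pair for a single sample
    return torch.stack([gx_row[i] + gy_row[j] for i, j in zip(ii_row, jj_row)])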