# test_moe.py — integration tests for hivemind mixture-of-experts client
  1. import asyncio
  2. import grpc
  3. import numpy as np
  4. import pytest
  5. import torch
  6. import hivemind
  7. from hivemind.client.expert import DUMMY
  8. from test_utils.run_server import background_server
  9. def test_moe():
  10. all_expert_uids = [f'ffn.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}'
  11. for _ in range(20)]
  12. with background_server(expert_uids=all_expert_uids, device='cpu', expert_cls='ffn',
  13. num_handlers=1, hidden_dim=16) as (server_endpoint, dht_endpoint):
  14. dht = hivemind.DHT(start=True, expiration=999, initial_peers=[dht_endpoint])
  15. # declare expert uids. Server *should* declare them by itself, but it takes time.
  16. assert all(dht.declare_experts(all_expert_uids, endpoint=server_endpoint))
  17. dmoe = hivemind.RemoteMixtureOfExperts(
  18. in_features=16, grid_size=(32, 32, 32), dht=dht, k_best=3, uid_prefix='ffn')
  19. for i in range(10):
  20. out = dmoe(torch.randn(10, 16))
  21. out.sum().backward()
  22. def test_call_many():
  23. k_min = 1
  24. timeout_after_k_min = None
  25. backward_k_min = 1
  26. forward_timeout = None
  27. backward_timeout = None
  28. rtol = 1e-3
  29. atol = 1e-6
  30. with background_server(num_experts=5, device='cpu', expert_cls='ffn', num_handlers=8, hidden_dim=64,
  31. no_optimizer=True, no_dht=True) as (server_endpoint, dht_endpoint):
  32. inputs = torch.randn(4, 64, requires_grad=True)
  33. inputs_clone = inputs.clone().detach().requires_grad_(True)
  34. e0, e1, e2, e3, e4 = [hivemind.RemoteExpert(f'expert.{i}', server_endpoint) for i in range(5)]
  35. e5 = hivemind.RemoteExpert(f'thisshouldnotexist', '127.0.0.1:80')
  36. mask, expert_outputs = hivemind.client.moe._RemoteCallMany.apply(
  37. DUMMY, [[e0, e1, e2], [e2, e4], [e1, e5, e3], []],
  38. k_min, backward_k_min, timeout_after_k_min, forward_timeout, backward_timeout,
  39. asyncio.new_event_loop(), inputs
  40. )
  41. assert mask.shape == (4, 3)
  42. assert expert_outputs.shape == (4, 3, 64)
  43. assert np.all(mask.data.numpy() == np.array([[True, True, True],
  44. [True, True, False],
  45. [True, False, True],
  46. [False, False, False]])), f"Incorrect mask, {mask}"
  47. reference_outputs = torch.zeros_like(expert_outputs)
  48. reference_outputs[0, 0] = e0(inputs_clone[0:1])
  49. reference_outputs[0, 1] = e1(inputs_clone[0:1])
  50. reference_outputs[0, 2] = e2(inputs_clone[0:1])
  51. reference_outputs[1, 0] = e2(inputs_clone[1:2])
  52. reference_outputs[1, 1] = e4(inputs_clone[1:2])
  53. reference_outputs[2, 0] = e1(inputs_clone[2:3])
  54. reference_outputs[2, 2] = e3(inputs_clone[2:3])
  55. assert torch.allclose(expert_outputs, reference_outputs, rtol, atol)
  56. proj = torch.randn(4, 64)
  57. loss = (expert_outputs[(0, 1, 1, 2), (0, 2, 1, 0)] * proj).sum()
  58. loss.backward()
  59. our_grad = inputs.grad.data.cpu().clone()
  60. reference_loss = (reference_outputs[(0, 1, 1, 2), (0, 2, 1, 0)] * proj).sum()
  61. reference_loss.backward()
  62. reference_grad = inputs_clone.grad.data.cpu().clone()
  63. assert torch.allclose(our_grad, reference_grad, rtol, atol)
  64. def test_remote_module_call():
  65. with background_server(num_experts=1, device='cpu', expert_cls='ffn', num_handlers=1, hidden_dim=1024,
  66. no_optimizer=True, no_dht=True) as (server_endpoint, dht_endpoint):
  67. real_expert = hivemind.RemoteExpert('expert.0', server_endpoint)
  68. fake_expert = hivemind.RemoteExpert('oiasfjiasjf', server_endpoint)
  69. out1 = real_expert(torch.randn(1, 1024))
  70. assert out1.shape == (1, 1024)
  71. dummy_x = torch.randn(3, 1024, requires_grad=True)
  72. out3 = real_expert(dummy_x)
  73. assert out3.shape == (3, 1024)
  74. out3_again = real_expert(dummy_x[1:])
  75. assert torch.allclose(out3_again, out3[1:])
  76. out3_again.norm().backward()
  77. assert dummy_x.grad is not None and dummy_x.grad.norm() > 0
  78. with pytest.raises(grpc.RpcError):
  79. real_expert(torch.randn(3, 11))
  80. with pytest.raises(grpc.RpcError):
  81. fake_expert(dummy_x)
  82. def test_moe_beam_search():
  83. all_expert_uids = [f'ffn.{5 + i}.{10 + j}.{15 + k}' for i in range(10) for j in range(10) for k in range(10)]
  84. dht = hivemind.DHT(start=True, expiration=999)
  85. assert all(dht.declare_experts(all_expert_uids, endpoint='fake-endpoint'))
  86. dmoe = hivemind.RemoteMixtureOfExperts(
  87. in_features=32, grid_size=(32, 32, 32), dht=dht, k_best=4, uid_prefix='ffn')
  88. for i in range(25):
  89. input = torch.randn(32)
  90. grid_scores = dmoe.proj(input).split_with_sizes(dmoe.grid_size, dim=-1)
  91. chosen_experts = dmoe.loop.run_until_complete(dmoe.beam_search(grid_scores, k_best=dmoe.k_best))
  92. chosen_scores = dmoe.compute_expert_scores([dim_scores[None] for dim_scores in grid_scores],
  93. [chosen_experts])[0]
  94. all_scores = dmoe.compute_expert_scores([dim_scores[None] for dim_scores in grid_scores],
  95. [[hivemind.RemoteExpert(uid, '') for uid in all_expert_uids]])[0]
  96. true_best_scores = sorted(all_scores.cpu().detach().numpy(), reverse=True)[:len(chosen_experts)]
  97. our_best_scores = list(chosen_scores.cpu().detach().numpy())
  98. assert np.allclose(true_best_scores, our_best_scores)
  99. def test_determinism():
  100. rtol = 0
  101. atol = 1e-6
  102. xx = torch.randn(32, 1024, requires_grad=True)
  103. mask = torch.randint(0, 1, (32, 1024))
  104. with background_server(num_experts=1, device='cpu', expert_cls='det_dropout', num_handlers=1,
  105. no_optimizer=True, no_dht=True) as (server_endpoint, dht_endpoint):
  106. expert = hivemind.RemoteExpert(uid=f'expert.0', endpoint=server_endpoint)
  107. out = expert(xx, mask)
  108. out_rerun = expert(xx, mask)
  109. grad, = torch.autograd.grad(out.sum(), xx, retain_graph=True)
  110. grad_rerun, = torch.autograd.grad(out_rerun.sum(), xx, retain_graph=True)
  111. assert torch.allclose(out, out_rerun, rtol, atol), "Dropout layer outputs are non-deterministic."
  112. assert torch.allclose(grad, grad_rerun, rtol, atol), "Gradients are non-deterministic."
  113. def test_compute_expert_scores():
  114. try:
  115. dht = hivemind.DHT(start=True)
  116. moe = hivemind.client.moe.RemoteMixtureOfExperts(
  117. dht=dht, in_features=1024, grid_size=(40,), k_best=4, k_min=1, timeout_after_k_min=1,
  118. uid_prefix='expert')
  119. gx, gy = torch.randn(4, 5, requires_grad=True), torch.randn(4, 3, requires_grad=True)
  120. ii = [[4, 0, 2], [3, 1, 1, 1, 3], [0], [3, 2]]
  121. jj = [[2, 2, 1], [0, 1, 2, 0, 1], [0], [1, 2]]
  122. batch_experts = [
  123. [hivemind.RemoteExpert(uid=f'expert.{ii[batch_i][expert_i]}.{jj[batch_i][expert_i]}', endpoint="[::]:1337")
  124. for expert_i in range(len(ii[batch_i]))]
  125. for batch_i in range(len(ii))
  126. ] # note: these experts do not exists on server, we use them only to test moe compute_expert_scores
  127. logits = moe.compute_expert_scores([gx, gy], batch_experts)
  128. torch.softmax(logits, dim=-1).norm(dim=-1).mean().backward()
  129. assert gx.grad.norm().item() > 0 and gy.grad.norm().item(), "compute_expert_scores didn't backprop"
  130. for batch_i in range(len(ii)):
  131. for expert_i in range(len(ii[batch_i])):
  132. assert torch.allclose(logits[batch_i, expert_i],
  133. gx[batch_i, ii[batch_i][expert_i]] + gy[batch_i, jj[batch_i][expert_i]]), \
  134. "compute_expert_scores returned incorrect score"
  135. finally:
  136. dht.shutdown()