# test_moe.py
import grpc
import numpy as np
import pytest
import torch

import hivemind
from hivemind import background_server
from hivemind.client.expert import DUMMY
from hivemind.server import layers
  9. @pytest.mark.forked
  10. def test_moe():
  11. all_expert_uids = [f'ffn.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}'
  12. for _ in range(10)]
  13. with background_server(expert_uids=all_expert_uids, device='cpu', expert_cls='ffn', num_handlers=1,
  14. hidden_dim=16) as (server_endpoint, dht_endpoint):
  15. dht = hivemind.DHT(start=True, initial_peers=[dht_endpoint])
  16. dmoe = hivemind.RemoteMixtureOfExperts(
  17. in_features=16, grid_size=(4, 4, 4), dht=dht, k_best=3, uid_prefix='ffn.')
  18. for i in range(3):
  19. out = dmoe(torch.randn(10, 16))
  20. out.sum().backward()
  21. @pytest.mark.forked
  22. def test_no_experts():
  23. all_expert_uids = [f'expert.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}'
  24. for _ in range(10)]
  25. with background_server(expert_uids=all_expert_uids, device='cpu', expert_cls='nop_delay', num_handlers=1,
  26. hidden_dim=16) as (server_endpoint, dht_endpoint):
  27. dht = hivemind.DHT(start=True, initial_peers=[dht_endpoint])
  28. dmoe = hivemind.RemoteSwitchMixtureOfExperts(
  29. in_features=16, grid_size=(4, 4, 4), dht=dht, uid_prefix='expert.', forward_timeout=0.1,
  30. backward_timeout=0.1, allow_zero_outputs=True)
  31. for i in range(3):
  32. out, balancing_loss = dmoe(torch.randn(10, 16))
  33. out.sum().backward()
  34. @pytest.mark.forked
  35. def test_call_many(hidden_dim=16):
  36. k_min = 1
  37. timeout_after_k_min = None
  38. backward_k_min = 1
  39. forward_timeout = None
  40. backward_timeout = None
  41. detect_anomalies = False
  42. allow_zero_outputs = False
  43. atol = 1e-5
  44. with background_server(num_experts=5, device='cpu', expert_cls='ffn', num_handlers=1, hidden_dim=hidden_dim,
  45. optim_cls=None, no_dht=True) as (server_endpoint, dht_endpoint):
  46. inputs = torch.randn(4, hidden_dim, requires_grad=True)
  47. inputs_clone = inputs.clone().detach().requires_grad_(True)
  48. e0, e1, e2, e3, e4 = [hivemind.RemoteExpert(f'expert.{i}', server_endpoint) for i in range(5)]
  49. e5 = hivemind.RemoteExpert(f'thisshouldnotexist', '127.0.0.1:80')
  50. mask, expert_outputs = hivemind.client.moe._RemoteCallMany.apply(
  51. DUMMY, [[e0, e1, e2], [e2, e4], [e1, e5, e3], []], k_min, backward_k_min, timeout_after_k_min,
  52. forward_timeout, backward_timeout, detect_anomalies, allow_zero_outputs, e1.info, inputs
  53. )
  54. assert mask.shape == (4, 3)
  55. assert expert_outputs.shape == (4, 3, hidden_dim)
  56. assert np.all(mask.data.numpy() == np.array([[True, True, True],
  57. [True, True, False],
  58. [True, False, True],
  59. [False, False, False]])), f"Incorrect mask, {mask}"
  60. reference_outputs = torch.zeros_like(expert_outputs)
  61. reference_outputs[0, 0] = e0(inputs_clone[0:1])
  62. reference_outputs[0, 1] = e1(inputs_clone[0:1])
  63. reference_outputs[0, 2] = e2(inputs_clone[0:1])
  64. reference_outputs[1, 0] = e2(inputs_clone[1:2])
  65. reference_outputs[1, 1] = e4(inputs_clone[1:2])
  66. reference_outputs[2, 0] = e1(inputs_clone[2:3])
  67. reference_outputs[2, 2] = e3(inputs_clone[2:3])
  68. assert torch.allclose(expert_outputs, reference_outputs, atol=atol, rtol=0)
  69. proj = torch.randn(4, hidden_dim)
  70. loss = (expert_outputs[(0, 1, 1, 2), (0, 2, 1, 0)] * proj).sum()
  71. loss.backward()
  72. our_grad = inputs.grad.data.cpu().clone()
  73. reference_loss = (reference_outputs[(0, 1, 1, 2), (0, 2, 1, 0)] * proj).sum()
  74. reference_loss.backward()
  75. reference_grad = inputs_clone.grad.data.cpu().clone()
  76. assert torch.allclose(our_grad, reference_grad, atol=atol, rtol=0)
  77. @pytest.mark.forked
  78. def test_remote_module_call(hidden_dim=16):
  79. with background_server(num_experts=1, device='cpu', expert_cls='ffn', num_handlers=1, hidden_dim=hidden_dim,
  80. optim_cls=None, no_dht=True) as (server_endpoint, dht_endpoint):
  81. real_expert = hivemind.RemoteExpert('expert.0', server_endpoint)
  82. fake_expert = hivemind.RemoteExpert('oiasfjiasjf', server_endpoint)
  83. out1 = real_expert(torch.randn(1, hidden_dim))
  84. assert out1.shape == (1, hidden_dim)
  85. dummy_x = torch.randn(3, hidden_dim, requires_grad=True)
  86. out3 = real_expert(dummy_x)
  87. assert out3.shape == (3, hidden_dim)
  88. out3_again = real_expert(dummy_x[1:])
  89. assert torch.allclose(out3_again, out3[1:], atol=1e-5, rtol=0)
  90. out3_again.norm().backward()
  91. assert dummy_x.grad is not None and dummy_x.grad.norm() > 0
  92. with pytest.raises(grpc.RpcError):
  93. real_expert(torch.randn(3, 11))
  94. with pytest.raises(grpc.RpcError):
  95. fake_expert(dummy_x)
  96. @pytest.mark.forked
  97. def test_beam_search_correctness():
  98. all_expert_uids = [f'ffn.{5 + i}.{10 + j}.{15 + k}' for i in range(10) for j in range(10) for k in range(10)]
  99. dht = hivemind.DHT(start=True)
  100. assert all(hivemind.declare_experts(dht, all_expert_uids, endpoint='fake-endpoint'))
  101. dmoe = hivemind.RemoteMixtureOfExperts(
  102. in_features=32, grid_size=(32, 32, 32), dht=dht, k_best=4, uid_prefix='ffn.')
  103. for i in range(25):
  104. input = torch.randn(32)
  105. grid_scores = dmoe.proj(input).split_with_sizes(dmoe.beam_search.grid_size, dim=-1)
  106. chosen_experts = dmoe.beam_search.find_best_experts([tensor.detach().numpy() for tensor in grid_scores],
  107. beam_size=dmoe.k_best)
  108. chosen_scores = dmoe.compute_expert_scores([dim_scores[None] for dim_scores in grid_scores],
  109. [chosen_experts])[0]
  110. our_best_scores = list(chosen_scores.cpu().detach().numpy())
  111. # reference: independently find :beam_size: best experts with exhaustive search
  112. all_scores = dmoe.compute_expert_scores([dim_scores.unsqueeze(0) for dim_scores in grid_scores],
  113. [[hivemind.RemoteExpert(uid, '') for uid in all_expert_uids]])[0]
  114. true_best_scores = sorted(all_scores.cpu().detach().numpy(), reverse=True)[:len(chosen_experts)]
  115. assert np.allclose(true_best_scores, our_best_scores)
  116. @pytest.mark.forked
  117. def test_determinism(hidden_dim=16):
  118. atol = 1e-5
  119. xx = torch.randn(32, hidden_dim, requires_grad=True)
  120. mask = torch.randint(0, 1, (32, hidden_dim))
  121. with background_server(num_experts=1, device='cpu', expert_cls='det_dropout', num_handlers=1, hidden_dim=hidden_dim,
  122. optim_cls=None, no_dht=True) as (server_endpoint, dht_endpoint):
  123. expert = hivemind.RemoteExpert(uid=f'expert.0', endpoint=server_endpoint)
  124. out = expert(xx, mask)
  125. out_rerun = expert(xx, mask)
  126. grad, = torch.autograd.grad(out.sum(), xx, retain_graph=True)
  127. grad_rerun, = torch.autograd.grad(out_rerun.sum(), xx, retain_graph=True)
  128. assert torch.allclose(out, out_rerun, atol=atol, rtol=0), "Dropout layer outputs are non-deterministic."
  129. assert torch.allclose(grad, grad_rerun, atol=atol, rtol=0), "Gradients are non-deterministic."
  130. @pytest.mark.forked
  131. def test_compute_expert_scores():
  132. try:
  133. dht = hivemind.DHT(start=True)
  134. moe = hivemind.client.moe.RemoteMixtureOfExperts(
  135. dht=dht, in_features=16, grid_size=(40,), k_best=4, k_min=1, timeout_after_k_min=1,
  136. uid_prefix='expert.')
  137. gx, gy = torch.randn(4, 5, requires_grad=True), torch.randn(4, 3, requires_grad=True)
  138. ii = [[4, 0, 2], [3, 1, 1, 1, 3], [0], [3, 2]]
  139. jj = [[2, 2, 1], [0, 1, 2, 0, 1], [0], [1, 2]]
  140. batch_experts = [
  141. [hivemind.RemoteExpert(uid=f'expert.{ii[batch_i][expert_i]}.{jj[batch_i][expert_i]}', endpoint="[::]:1337")
  142. for expert_i in range(len(ii[batch_i]))]
  143. for batch_i in range(len(ii))
  144. ] # note: these experts do not exists on server, we use them only to test moe compute_expert_scores
  145. logits = moe.compute_expert_scores([gx, gy], batch_experts)
  146. torch.softmax(logits, dim=-1).norm(dim=-1).mean().backward()
  147. assert gx.grad.norm().item() > 0 and gy.grad.norm().item(), "compute_expert_scores didn't backprop"
  148. for batch_i in range(len(ii)):
  149. for expert_i in range(len(ii[batch_i])):
  150. assert torch.allclose(logits[batch_i, expert_i],
  151. gx[batch_i, ii[batch_i][expert_i]] + gy[batch_i, jj[batch_i][expert_i]]), \
  152. "compute_expert_scores returned incorrect score"
  153. finally:
  154. dht.shutdown()
  155. @pytest.mark.forked
  156. def test_client_anomaly_detection():
  157. HID_DIM = 16
  158. experts = {}
  159. for i in range(4):
  160. expert = layers.name_to_block['ffn'](HID_DIM)
  161. experts[f'expert.{i}'] = hivemind.ExpertBackend(name=f'expert.{i}',
  162. expert=expert, optimizer=torch.optim.Adam(expert.parameters()),
  163. args_schema=(hivemind.BatchTensorDescriptor(HID_DIM),),
  164. outputs_schema=hivemind.BatchTensorDescriptor(HID_DIM),
  165. max_batch_size=16,
  166. )
  167. experts['expert.3'].expert.ffn.weight.data[0, 0] = float('nan')
  168. dht = hivemind.DHT(start=True)
  169. server = hivemind.Server(dht, experts, num_connection_handlers=1)
  170. server.start()
  171. try:
  172. server.ready.wait()
  173. dmoe = hivemind.RemoteMixtureOfExperts(in_features=16, grid_size=(3,), dht=dht, k_best=3, uid_prefix='expert.',
  174. detect_anomalies=True)
  175. input = torch.randn(1, 16)
  176. input[0, 0] = float('nan')
  177. with pytest.raises(ValueError):
  178. dmoe(input)
  179. input[0, 0] = 0
  180. output = dmoe(input)
  181. inf_loss = float('inf') * output.sum()
  182. with pytest.raises(ValueError):
  183. inf_loss.backward()
  184. dmoe = hivemind.RemoteMixtureOfExperts(in_features=16, grid_size=(4,), dht=dht, k_best=4, uid_prefix='expert.',
  185. detect_anomalies=True)
  186. output = dmoe(input)
  187. assert output.isfinite().all()
  188. finally:
  189. server.shutdown()