# test_moe.py — integration tests for hivemind mixture-of-experts routing
  1. import grpc
  2. import numpy as np
  3. import pytest
  4. import torch
  5. import hivemind
  6. from hivemind.moe.client.expert import DUMMY
  7. from hivemind.moe.server import background_server, declare_experts, layers
  8. @pytest.mark.forked
  9. def test_moe():
  10. all_expert_uids = [
  11. f"ffn.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}" for _ in range(10)
  12. ]
  13. with background_server(
  14. expert_uids=all_expert_uids, device="cpu", expert_cls="ffn", num_handlers=1, hidden_dim=16
  15. ) as (server_endpoint, dht_maddrs):
  16. dht = hivemind.DHT(start=True, initial_peers=dht_maddrs)
  17. dmoe = hivemind.RemoteMixtureOfExperts(
  18. in_features=16, grid_size=(4, 4, 4), dht=dht, k_best=3, uid_prefix="ffn."
  19. )
  20. for i in range(3):
  21. out = dmoe(torch.randn(10, 16))
  22. out.sum().backward()
  23. @pytest.mark.forked
  24. def test_no_experts():
  25. all_expert_uids = [
  26. f"expert.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}" for _ in range(10)
  27. ]
  28. with background_server(
  29. expert_uids=all_expert_uids, device="cpu", expert_cls="nop_delay", num_handlers=1, hidden_dim=16
  30. ) as (server_endpoint, dht_maddrs):
  31. dht = hivemind.DHT(start=True, initial_peers=dht_maddrs)
  32. dmoe = hivemind.RemoteSwitchMixtureOfExperts(
  33. in_features=16,
  34. grid_size=(4, 4, 4),
  35. dht=dht,
  36. uid_prefix="expert.",
  37. forward_timeout=0.1,
  38. backward_timeout=0.1,
  39. allow_zero_outputs=True,
  40. )
  41. for i in range(3):
  42. out, balancing_loss = dmoe(torch.randn(10, 16))
  43. out.sum().backward()
  44. @pytest.mark.forked
  45. def test_call_many(hidden_dim=16):
  46. k_min = 1
  47. timeout_after_k_min = None
  48. backward_k_min = 1
  49. forward_timeout = None
  50. backward_timeout = None
  51. detect_anomalies = False
  52. allow_zero_outputs = False
  53. atol = 1e-5
  54. with background_server(
  55. num_experts=5,
  56. device="cpu",
  57. expert_cls="ffn",
  58. num_handlers=1,
  59. hidden_dim=hidden_dim,
  60. optim_cls=None,
  61. no_dht=True,
  62. ) as (server_endpoint, _):
  63. inputs = torch.randn(4, hidden_dim, requires_grad=True)
  64. inputs_clone = inputs.clone().detach().requires_grad_(True)
  65. e0, e1, e2, e3, e4 = [hivemind.RemoteExpert(f"expert.{i}", server_endpoint) for i in range(5)]
  66. e5 = hivemind.RemoteExpert(f"thisshouldnotexist", "127.0.0.1:80")
  67. mask, expert_outputs = hivemind.moe.client.moe._RemoteCallMany.apply(
  68. DUMMY,
  69. [[e0, e1, e2], [e2, e4], [e1, e5, e3], []],
  70. k_min,
  71. backward_k_min,
  72. timeout_after_k_min,
  73. forward_timeout,
  74. backward_timeout,
  75. detect_anomalies,
  76. allow_zero_outputs,
  77. e1.info,
  78. inputs,
  79. )
  80. assert mask.shape == (4, 3)
  81. assert expert_outputs.shape == (4, 3, hidden_dim)
  82. assert np.all(
  83. mask.data.numpy()
  84. == np.array([[True, True, True], [True, True, False], [True, False, True], [False, False, False]])
  85. ), f"Incorrect mask, {mask}"
  86. reference_outputs = torch.zeros_like(expert_outputs)
  87. reference_outputs[0, 0] = e0(inputs_clone[0:1])
  88. reference_outputs[0, 1] = e1(inputs_clone[0:1])
  89. reference_outputs[0, 2] = e2(inputs_clone[0:1])
  90. reference_outputs[1, 0] = e2(inputs_clone[1:2])
  91. reference_outputs[1, 1] = e4(inputs_clone[1:2])
  92. reference_outputs[2, 0] = e1(inputs_clone[2:3])
  93. reference_outputs[2, 2] = e3(inputs_clone[2:3])
  94. assert torch.allclose(expert_outputs, reference_outputs, atol=atol, rtol=0)
  95. proj = torch.randn(4, hidden_dim)
  96. loss = (expert_outputs[(0, 1, 1, 2), (0, 2, 1, 0)] * proj).sum()
  97. loss.backward()
  98. our_grad = inputs.grad.data.cpu().clone()
  99. reference_loss = (reference_outputs[(0, 1, 1, 2), (0, 2, 1, 0)] * proj).sum()
  100. reference_loss.backward()
  101. reference_grad = inputs_clone.grad.data.cpu().clone()
  102. assert torch.allclose(our_grad, reference_grad, atol=atol, rtol=0)
  103. @pytest.mark.forked
  104. def test_remote_module_call(hidden_dim=16):
  105. with background_server(
  106. num_experts=1,
  107. device="cpu",
  108. expert_cls="ffn",
  109. num_handlers=1,
  110. hidden_dim=hidden_dim,
  111. optim_cls=None,
  112. no_dht=True,
  113. ) as (server_endpoint, _):
  114. real_expert = hivemind.RemoteExpert("expert.0", server_endpoint)
  115. fake_expert = hivemind.RemoteExpert("oiasfjiasjf", server_endpoint)
  116. out1 = real_expert(torch.randn(1, hidden_dim))
  117. assert out1.shape == (1, hidden_dim)
  118. dummy_x = torch.randn(3, hidden_dim, requires_grad=True)
  119. out3 = real_expert(dummy_x)
  120. assert out3.shape == (3, hidden_dim)
  121. out3_again = real_expert(dummy_x[1:])
  122. assert torch.allclose(out3_again, out3[1:], atol=1e-5, rtol=0)
  123. out3_again.norm().backward()
  124. assert dummy_x.grad is not None and dummy_x.grad.norm() > 0
  125. with pytest.raises(grpc.RpcError):
  126. real_expert(torch.randn(3, 11))
  127. with pytest.raises(grpc.RpcError):
  128. fake_expert(dummy_x)
@pytest.mark.forked
def test_beam_search_correctness():
    """Beam search over the expert grid must find the same top-k experts as exhaustive scoring.

    Declares a dense 10x10x10 block of fake experts (offset coordinates 5../10../15..)
    in a real DHT, then compares beam-search winners' scores against the scores of
    every declared expert, for 25 random inputs.
    """
    all_expert_uids = [f"ffn.{5 + i}.{10 + j}.{15 + k}" for i in range(10) for j in range(10) for k in range(10)]
    dht = hivemind.DHT(start=True)
    # Experts are never contacted — a fake endpoint is enough for routing metadata.
    assert all(declare_experts(dht, all_expert_uids, endpoint="fake-endpoint"))
    dmoe = hivemind.RemoteMixtureOfExperts(
        in_features=32, grid_size=(32, 32, 32), dht=dht, k_best=4, uid_prefix="ffn."
    )
    for i in range(25):
        input = torch.randn(32)
        # Per-dimension score vectors, split according to the beam search's grid layout.
        grid_scores = dmoe.proj(input).split_with_sizes(dmoe.beam_search.grid_size, dim=-1)
        chosen_experts = dmoe.beam_search.find_best_experts(
            [tensor.detach().numpy() for tensor in grid_scores], beam_size=dmoe.k_best
        )
        # Scores of the experts the beam search picked (batch dimension added via [None]).
        chosen_scores = dmoe.compute_expert_scores([dim_scores[None] for dim_scores in grid_scores], [chosen_experts])[
            0
        ]
        our_best_scores = list(chosen_scores.cpu().detach().numpy())
        # reference: independently find :beam_size: best experts with exhaustive search
        all_scores = dmoe.compute_expert_scores(
            [dim_scores.unsqueeze(0) for dim_scores in grid_scores],
            [[hivemind.RemoteExpert(uid, "") for uid in all_expert_uids]],
        )[0]
        true_best_scores = sorted(all_scores.cpu().detach().numpy(), reverse=True)[: len(chosen_experts)]
        assert np.allclose(true_best_scores, our_best_scores)
  154. @pytest.mark.forked
  155. def test_determinism(hidden_dim=16):
  156. atol = 1e-5
  157. xx = torch.randn(32, hidden_dim, requires_grad=True)
  158. mask = torch.randint(0, 1, (32, hidden_dim))
  159. with background_server(
  160. num_experts=1,
  161. device="cpu",
  162. expert_cls="det_dropout",
  163. num_handlers=1,
  164. hidden_dim=hidden_dim,
  165. optim_cls=None,
  166. no_dht=True,
  167. ) as (server_endpoint, _):
  168. expert = hivemind.RemoteExpert(uid=f"expert.0", endpoint=server_endpoint)
  169. out = expert(xx, mask)
  170. out_rerun = expert(xx, mask)
  171. (grad,) = torch.autograd.grad(out.sum(), xx, retain_graph=True)
  172. (grad_rerun,) = torch.autograd.grad(out_rerun.sum(), xx, retain_graph=True)
  173. assert torch.allclose(out, out_rerun, atol=atol, rtol=0), "Dropout layer outputs are non-deterministic."
  174. assert torch.allclose(grad, grad_rerun, atol=atol, rtol=0), "Gradients are non-deterministic."
  175. @pytest.mark.forked
  176. def test_compute_expert_scores():
  177. try:
  178. dht = hivemind.DHT(start=True)
  179. moe = hivemind.moe.RemoteMixtureOfExperts(
  180. dht=dht, in_features=16, grid_size=(40,), k_best=4, k_min=1, timeout_after_k_min=1, uid_prefix="expert."
  181. )
  182. gx, gy = torch.randn(4, 5, requires_grad=True), torch.randn(4, 3, requires_grad=True)
  183. ii = [[4, 0, 2], [3, 1, 1, 1, 3], [0], [3, 2]]
  184. jj = [[2, 2, 1], [0, 1, 2, 0, 1], [0], [1, 2]]
  185. batch_experts = [
  186. [
  187. hivemind.RemoteExpert(
  188. uid=f"expert.{ii[batch_i][expert_i]}.{jj[batch_i][expert_i]}", endpoint="[::]:1337"
  189. )
  190. for expert_i in range(len(ii[batch_i]))
  191. ]
  192. for batch_i in range(len(ii))
  193. ] # note: these experts do not exists on server, we use them only to test moe compute_expert_scores
  194. logits = moe.compute_expert_scores([gx, gy], batch_experts)
  195. torch.softmax(logits, dim=-1).norm(dim=-1).mean().backward()
  196. assert gx.grad.norm().item() > 0 and gy.grad.norm().item(), "compute_expert_scores didn't backprop"
  197. for batch_i in range(len(ii)):
  198. for expert_i in range(len(ii[batch_i])):
  199. assert torch.allclose(
  200. logits[batch_i, expert_i], gx[batch_i, ii[batch_i][expert_i]] + gy[batch_i, jj[batch_i][expert_i]]
  201. ), "compute_expert_scores returned incorrect score"
  202. finally:
  203. dht.shutdown()
@pytest.mark.forked
def test_client_anomaly_detection():
    """With detect_anomalies=True, the MoE client must reject NaN inputs, reject non-finite
    losses on backward, and still produce a finite output when one expert emits NaNs."""
    HID_DIM = 16
    # Build four ffn experts served from one local server.
    experts = {}
    for i in range(4):
        expert = layers.name_to_block["ffn"](HID_DIM)
        experts[f"expert.{i}"] = hivemind.ExpertBackend(
            name=f"expert.{i}",
            expert=expert,
            optimizer=torch.optim.Adam(expert.parameters()),
            args_schema=(hivemind.BatchTensorDescriptor(HID_DIM),),
            outputs_schema=hivemind.BatchTensorDescriptor(HID_DIM),
            max_batch_size=16,
        )
    # Poison expert.3 so its outputs contain NaN.
    experts["expert.3"].expert.ffn.weight.data[0, 0] = float("nan")

    dht = hivemind.DHT(start=True)
    server = hivemind.moe.Server(dht, experts, num_connection_handlers=1)
    server.start()
    try:
        server.ready.wait()
        # grid_size=(3,): only experts 0-2 (the healthy ones) are addressable here.
        dmoe = hivemind.RemoteMixtureOfExperts(
            in_features=16, grid_size=(3,), dht=dht, k_best=3, uid_prefix="expert.", detect_anomalies=True
        )
        # A NaN in the input must be rejected with ValueError.
        input = torch.randn(1, 16)
        input[0, 0] = float("nan")
        with pytest.raises(ValueError):
            dmoe(input)
        # Clean input passes; a non-finite loss must be rejected on backward.
        input[0, 0] = 0
        output = dmoe(input)
        inf_loss = float("inf") * output.sum()
        with pytest.raises(ValueError):
            inf_loss.backward()
        # grid_size=(4,) now includes the poisoned expert; its NaN response should be
        # discarded by anomaly detection, leaving a finite mixture output.
        dmoe = hivemind.RemoteMixtureOfExperts(
            in_features=16, grid_size=(4,), dht=dht, k_best=4, uid_prefix="expert.", detect_anomalies=True
        )
        output = dmoe(input)
        assert output.isfinite().all()
    finally:
        server.shutdown()