import grpc
import numpy as np
import pytest
import torch

from hivemind.dht import DHT
from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
from hivemind.moe.client.moe import DUMMY, _RemoteCallMany
from hivemind.moe.server import ExpertBackend, Server, background_server, declare_experts
from hivemind.moe.server.layers import name_to_block
from hivemind.utils.tensor_descr import BatchTensorDescriptor
  11. @pytest.mark.forked
  12. def test_moe():
  13. all_expert_uids = [
  14. f"ffn.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}" for _ in range(10)
  15. ]
  16. with background_server(
  17. expert_uids=all_expert_uids, device="cpu", expert_cls="ffn", num_handlers=1, hidden_dim=16
  18. ) as (server_endpoint, dht_maddrs):
  19. dht = DHT(start=True, initial_peers=dht_maddrs)
  20. dmoe = RemoteMixtureOfExperts(in_features=16, grid_size=(4, 4, 4), dht=dht, k_best=3, uid_prefix="ffn.")
  21. for i in range(3):
  22. out = dmoe(torch.randn(10, 16))
  23. out.sum().backward()
  24. @pytest.mark.forked
  25. def test_no_experts():
  26. all_expert_uids = [
  27. f"expert.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}" for _ in range(10)
  28. ]
  29. with background_server(
  30. expert_uids=all_expert_uids, device="cpu", expert_cls="nop_delay", num_handlers=1, hidden_dim=16
  31. ) as (server_endpoint, dht_maddrs):
  32. dht = DHT(start=True, initial_peers=dht_maddrs)
  33. dmoe = RemoteSwitchMixtureOfExperts(
  34. in_features=16,
  35. grid_size=(4, 4, 4),
  36. dht=dht,
  37. uid_prefix="expert.",
  38. forward_timeout=0.1,
  39. backward_timeout=0.1,
  40. allow_zero_outputs=True,
  41. )
  42. for i in range(3):
  43. out, balancing_loss = dmoe(torch.randn(10, 16))
  44. out.sum().backward()
  45. @pytest.mark.forked
  46. def test_call_many(hidden_dim=16):
  47. k_min = 1
  48. timeout_after_k_min = None
  49. backward_k_min = 1
  50. forward_timeout = None
  51. backward_timeout = None
  52. detect_anomalies = False
  53. allow_zero_outputs = False
  54. atol = 1e-5
  55. with background_server(
  56. num_experts=5,
  57. device="cpu",
  58. expert_cls="ffn",
  59. num_handlers=1,
  60. hidden_dim=hidden_dim,
  61. optim_cls=None,
  62. no_dht=True,
  63. ) as (server_endpoint, _):
  64. inputs = torch.randn(4, hidden_dim, requires_grad=True)
  65. inputs_clone = inputs.clone().detach().requires_grad_(True)
  66. e0, e1, e2, e3, e4 = [RemoteExpert(f"expert.{i}", server_endpoint) for i in range(5)]
  67. e5 = RemoteExpert(f"thisshouldnotexist", "127.0.0.1:80")
  68. mask, expert_outputs = _RemoteCallMany.apply(
  69. DUMMY,
  70. [[e0, e1, e2], [e2, e4], [e1, e5, e3], []],
  71. k_min,
  72. backward_k_min,
  73. timeout_after_k_min,
  74. forward_timeout,
  75. backward_timeout,
  76. detect_anomalies,
  77. allow_zero_outputs,
  78. e1.info,
  79. inputs,
  80. )
  81. assert mask.shape == (4, 3)
  82. assert expert_outputs.shape == (4, 3, hidden_dim)
  83. assert np.all(
  84. mask.data.numpy()
  85. == np.array([[True, True, True], [True, True, False], [True, False, True], [False, False, False]])
  86. ), f"Incorrect mask, {mask}"
  87. reference_outputs = torch.zeros_like(expert_outputs)
  88. reference_outputs[0, 0] = e0(inputs_clone[0:1])
  89. reference_outputs[0, 1] = e1(inputs_clone[0:1])
  90. reference_outputs[0, 2] = e2(inputs_clone[0:1])
  91. reference_outputs[1, 0] = e2(inputs_clone[1:2])
  92. reference_outputs[1, 1] = e4(inputs_clone[1:2])
  93. reference_outputs[2, 0] = e1(inputs_clone[2:3])
  94. reference_outputs[2, 2] = e3(inputs_clone[2:3])
  95. assert torch.allclose(expert_outputs, reference_outputs, atol=atol, rtol=0)
  96. proj = torch.randn(4, hidden_dim)
  97. loss = (expert_outputs[(0, 1, 1, 2), (0, 2, 1, 0)] * proj).sum()
  98. loss.backward()
  99. our_grad = inputs.grad.data.cpu().clone()
  100. reference_loss = (reference_outputs[(0, 1, 1, 2), (0, 2, 1, 0)] * proj).sum()
  101. reference_loss.backward()
  102. reference_grad = inputs_clone.grad.data.cpu().clone()
  103. assert torch.allclose(our_grad, reference_grad, atol=atol, rtol=0)
  104. @pytest.mark.forked
  105. def test_remote_module_call(hidden_dim=16):
  106. with background_server(
  107. num_experts=1,
  108. device="cpu",
  109. expert_cls="ffn",
  110. num_handlers=1,
  111. hidden_dim=hidden_dim,
  112. optim_cls=None,
  113. no_dht=True,
  114. ) as (server_endpoint, _):
  115. real_expert = RemoteExpert("expert.0", server_endpoint)
  116. fake_expert = RemoteExpert("oiasfjiasjf", server_endpoint)
  117. out1 = real_expert(torch.randn(1, hidden_dim))
  118. assert out1.shape == (1, hidden_dim)
  119. dummy_x = torch.randn(3, hidden_dim, requires_grad=True)
  120. out3 = real_expert(dummy_x)
  121. assert out3.shape == (3, hidden_dim)
  122. out3_again = real_expert(dummy_x[1:])
  123. assert torch.allclose(out3_again, out3[1:], atol=1e-5, rtol=0)
  124. out3_again.norm().backward()
  125. assert dummy_x.grad is not None and dummy_x.grad.norm() > 0
  126. with pytest.raises(grpc.RpcError):
  127. real_expert(torch.randn(3, 11))
  128. with pytest.raises(grpc.RpcError):
  129. fake_expert(dummy_x)
  130. @pytest.mark.forked
  131. def test_beam_search_correctness():
  132. all_expert_uids = [f"ffn.{5 + i}.{10 + j}.{15 + k}" for i in range(10) for j in range(10) for k in range(10)]
  133. dht = DHT(start=True)
  134. assert all(declare_experts(dht, all_expert_uids, endpoint="fake-endpoint"))
  135. dmoe = RemoteMixtureOfExperts(in_features=32, grid_size=(32, 32, 32), dht=dht, k_best=4, uid_prefix="ffn.")
  136. for i in range(25):
  137. input = torch.randn(32)
  138. grid_scores = dmoe.proj(input).split_with_sizes(dmoe.beam_search.grid_size, dim=-1)
  139. chosen_experts = dmoe.beam_search.find_best_experts(
  140. [tensor.detach().numpy() for tensor in grid_scores], beam_size=dmoe.k_best
  141. )
  142. chosen_scores = dmoe.compute_expert_scores([dim_scores[None] for dim_scores in grid_scores], [chosen_experts])[
  143. 0
  144. ]
  145. our_best_scores = list(chosen_scores.cpu().detach().numpy())
  146. # reference: independently find :beam_size: best experts with exhaustive search
  147. all_scores = dmoe.compute_expert_scores(
  148. [dim_scores.unsqueeze(0) for dim_scores in grid_scores],
  149. [[RemoteExpert(uid, "") for uid in all_expert_uids]],
  150. )[0]
  151. true_best_scores = sorted(all_scores.cpu().detach().numpy(), reverse=True)[: len(chosen_experts)]
  152. assert np.allclose(true_best_scores, our_best_scores)
  153. @pytest.mark.forked
  154. def test_determinism(hidden_dim=16):
  155. atol = 1e-5
  156. xx = torch.randn(32, hidden_dim, requires_grad=True)
  157. mask = torch.randint(0, 1, (32, hidden_dim))
  158. with background_server(
  159. num_experts=1,
  160. device="cpu",
  161. expert_cls="det_dropout",
  162. num_handlers=1,
  163. hidden_dim=hidden_dim,
  164. optim_cls=None,
  165. no_dht=True,
  166. ) as (server_endpoint, _):
  167. expert = RemoteExpert(uid=f"expert.0", endpoint=server_endpoint)
  168. out = expert(xx, mask)
  169. out_rerun = expert(xx, mask)
  170. (grad,) = torch.autograd.grad(out.sum(), xx, retain_graph=True)
  171. (grad_rerun,) = torch.autograd.grad(out_rerun.sum(), xx, retain_graph=True)
  172. assert torch.allclose(out, out_rerun, atol=atol, rtol=0), "Dropout layer outputs are non-deterministic."
  173. assert torch.allclose(grad, grad_rerun, atol=atol, rtol=0), "Gradients are non-deterministic."
  174. @pytest.mark.forked
  175. def test_compute_expert_scores():
  176. try:
  177. dht = DHT(start=True)
  178. moe = RemoteMixtureOfExperts(
  179. dht=dht, in_features=16, grid_size=(40,), k_best=4, k_min=1, timeout_after_k_min=1, uid_prefix="expert."
  180. )
  181. gx, gy = torch.randn(4, 5, requires_grad=True), torch.randn(4, 3, requires_grad=True)
  182. ii = [[4, 0, 2], [3, 1, 1, 1, 3], [0], [3, 2]]
  183. jj = [[2, 2, 1], [0, 1, 2, 0, 1], [0], [1, 2]]
  184. batch_experts = [
  185. [
  186. RemoteExpert(uid=f"expert.{ii[batch_i][expert_i]}.{jj[batch_i][expert_i]}", endpoint="[::]:1337")
  187. for expert_i in range(len(ii[batch_i]))
  188. ]
  189. for batch_i in range(len(ii))
  190. ] # note: these experts do not exist on server, we use them only to test compute_expert_scores
  191. logits = moe.compute_expert_scores([gx, gy], batch_experts)
  192. torch.softmax(logits, dim=-1).norm(dim=-1).mean().backward()
  193. assert gx.grad.norm().item() > 0 and gy.grad.norm().item(), "compute_expert_scores didn't backprop"
  194. for batch_i in range(len(ii)):
  195. for expert_i in range(len(ii[batch_i])):
  196. assert torch.allclose(
  197. logits[batch_i, expert_i], gx[batch_i, ii[batch_i][expert_i]] + gy[batch_i, jj[batch_i][expert_i]]
  198. ), "compute_expert_scores returned incorrect score"
  199. finally:
  200. dht.shutdown()
  201. @pytest.mark.forked
  202. def test_client_anomaly_detection():
  203. HID_DIM = 16
  204. experts = {}
  205. for i in range(4):
  206. expert = name_to_block["ffn"](HID_DIM)
  207. experts[f"expert.{i}"] = ExpertBackend(
  208. name=f"expert.{i}",
  209. expert=expert,
  210. optimizer=torch.optim.Adam(expert.parameters()),
  211. args_schema=(BatchTensorDescriptor(HID_DIM),),
  212. outputs_schema=BatchTensorDescriptor(HID_DIM),
  213. max_batch_size=16,
  214. )
  215. experts["expert.3"].expert.ffn.weight.data[0, 0] = float("nan")
  216. dht = DHT(start=True)
  217. server = Server(dht, experts, num_connection_handlers=1)
  218. server.start()
  219. try:
  220. server.ready.wait()
  221. dmoe = RemoteMixtureOfExperts(
  222. in_features=16, grid_size=(3,), dht=dht, k_best=3, uid_prefix="expert.", detect_anomalies=True
  223. )
  224. input = torch.randn(1, 16)
  225. input[0, 0] = float("nan")
  226. with pytest.raises(ValueError):
  227. dmoe(input)
  228. input[0, 0] = 0
  229. output = dmoe(input)
  230. inf_loss = float("inf") * output.sum()
  231. with pytest.raises(ValueError):
  232. inf_loss.backward()
  233. dmoe = RemoteMixtureOfExperts(
  234. in_features=16, grid_size=(4,), dht=dht, k_best=4, uid_prefix="expert.", detect_anomalies=True
  235. )
  236. output = dmoe(input)
  237. assert output.isfinite().all()
  238. finally:
  239. server.shutdown()