# test_moe.py — tests for the hivemind mixture-of-experts client and server
  1. import numpy as np
  2. import pytest
  3. import torch
  4. from hivemind.dht import DHT
  5. from hivemind.moe.client.expert import RemoteExpert, create_remote_experts
  6. from hivemind.moe.client.moe import DUMMY, RemoteMixtureOfExperts, _RemoteCallMany
  7. from hivemind.moe.client.switch_moe import RemoteSwitchMixtureOfExperts
  8. from hivemind.moe.expert_uid import ExpertInfo
  9. from hivemind.moe.server import ModuleBackend, Server, background_server, declare_experts
  10. from hivemind.moe.server.layers import name_to_block
  11. from hivemind.p2p.p2p_daemon_bindings.control import P2PHandlerError
  12. from hivemind.utils import BatchTensorDescriptor, get_dht_time
  13. @pytest.mark.forked
  14. def test_moe():
  15. all_expert_uids = [
  16. f"ffn.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}" for _ in range(10)
  17. ]
  18. with background_server(
  19. expert_uids=all_expert_uids, device="cpu", expert_cls="ffn", num_handlers=1, hidden_dim=16
  20. ) as server_peer_info:
  21. dht = DHT(start=True, initial_peers=server_peer_info.addrs)
  22. dmoe = RemoteMixtureOfExperts(in_features=16, grid_size=(4, 4, 4), dht=dht, k_best=3, uid_prefix="ffn.")
  23. for i in range(3):
  24. out = dmoe(torch.randn(10, 16))
  25. out.sum().backward()
  26. @pytest.mark.forked
  27. def test_no_experts():
  28. all_expert_uids = [
  29. f"expert.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}" for _ in range(10)
  30. ]
  31. with background_server(
  32. expert_uids=all_expert_uids, device="cpu", expert_cls="nop_delay", num_handlers=1, hidden_dim=16
  33. ) as server_peer_info:
  34. dht = DHT(start=True, initial_peers=server_peer_info.addrs)
  35. dmoe = RemoteSwitchMixtureOfExperts(
  36. in_features=16,
  37. grid_size=(4, 4, 4),
  38. dht=dht,
  39. uid_prefix="expert.",
  40. forward_timeout=0.1,
  41. backward_timeout=0.1,
  42. allow_zero_outputs=True,
  43. )
  44. for i in range(3):
  45. out, balancing_loss = dmoe(torch.randn(10, 16))
  46. out.sum().backward()
  47. @pytest.mark.forked
  48. def test_call_many(hidden_dim=16):
  49. k_min = 1
  50. timeout_after_k_min = None
  51. backward_k_min = 1
  52. forward_timeout = None
  53. backward_timeout = None
  54. detect_anomalies = False
  55. allow_zero_outputs = False
  56. atol = 1e-5
  57. with background_server(
  58. num_experts=5,
  59. device="cpu",
  60. expert_cls="ffn",
  61. num_handlers=1,
  62. hidden_dim=hidden_dim,
  63. optim_cls=None,
  64. ) as server_peer_info:
  65. inputs = torch.randn(4, hidden_dim, requires_grad=True)
  66. inputs_clone = inputs.clone().detach().requires_grad_(True)
  67. dht = DHT(initial_peers=server_peer_info.addrs, start=True)
  68. e0, e1, e2, e3, e4 = create_remote_experts(
  69. [ExpertInfo(uid=f"expert.{i}", peer_id=server_peer_info.peer_id) for i in range(5)],
  70. dht,
  71. )
  72. e5 = RemoteExpert(ExpertInfo(f"thisshouldnotexist", server_peer_info), None)
  73. mask, expert_outputs = _RemoteCallMany.apply(
  74. DUMMY,
  75. [[e0, e1, e2], [e2, e4], [e1, e5, e3], []],
  76. k_min,
  77. backward_k_min,
  78. timeout_after_k_min,
  79. forward_timeout,
  80. backward_timeout,
  81. detect_anomalies,
  82. allow_zero_outputs,
  83. e1.info,
  84. inputs,
  85. )
  86. assert mask.shape == (4, 3)
  87. assert expert_outputs.shape == (4, 3, hidden_dim)
  88. assert np.all(
  89. mask.data.numpy()
  90. == np.array([[True, True, True], [True, True, False], [True, False, True], [False, False, False]])
  91. ), f"Incorrect mask, {mask}"
  92. reference_outputs = torch.zeros_like(expert_outputs)
  93. reference_outputs[0, 0] = e0(inputs_clone[0:1])
  94. reference_outputs[0, 1] = e1(inputs_clone[0:1])
  95. reference_outputs[0, 2] = e2(inputs_clone[0:1])
  96. reference_outputs[1, 0] = e2(inputs_clone[1:2])
  97. reference_outputs[1, 1] = e4(inputs_clone[1:2])
  98. reference_outputs[2, 0] = e1(inputs_clone[2:3])
  99. reference_outputs[2, 2] = e3(inputs_clone[2:3])
  100. assert torch.allclose(expert_outputs, reference_outputs, atol=atol, rtol=0)
  101. proj = torch.randn(4, hidden_dim)
  102. loss = (expert_outputs[(0, 1, 1, 2), (0, 2, 1, 0)] * proj).sum()
  103. loss.backward()
  104. our_grad = inputs.grad.data.cpu().clone()
  105. reference_loss = (reference_outputs[(0, 1, 1, 2), (0, 2, 1, 0)] * proj).sum()
  106. reference_loss.backward()
  107. reference_grad = inputs_clone.grad.data.cpu().clone()
  108. assert torch.allclose(our_grad, reference_grad, atol=atol, rtol=0)
  109. @pytest.mark.forked
  110. def test_remote_module_call(hidden_dim=16):
  111. with background_server(
  112. num_experts=1,
  113. device="cpu",
  114. expert_cls="ffn",
  115. num_handlers=1,
  116. hidden_dim=hidden_dim,
  117. optim_cls=None,
  118. ) as server_peer_info:
  119. dht = DHT(initial_peers=server_peer_info.addrs, start=True)
  120. real_expert, fake_expert = create_remote_experts(
  121. [
  122. ExpertInfo(uid="expert.0", peer_id=server_peer_info.peer_id),
  123. ExpertInfo(uid="oiasfjiasjf", peer_id=server_peer_info.peer_id),
  124. ],
  125. dht=dht,
  126. )
  127. out1 = real_expert(torch.randn(1, hidden_dim))
  128. assert out1.shape == (1, hidden_dim)
  129. dummy_x = torch.randn(3, hidden_dim, requires_grad=True)
  130. out3 = real_expert(dummy_x)
  131. assert out3.shape == (3, hidden_dim)
  132. out3_again = real_expert(dummy_x[1:])
  133. assert torch.allclose(out3_again, out3[1:], atol=1e-5, rtol=0)
  134. out3_again.norm().backward()
  135. assert dummy_x.grad is not None and dummy_x.grad.norm() > 0
  136. try:
  137. real_expert(torch.randn(3, 11))
  138. except P2PHandlerError as e:
  139. assert str(11) in repr(e), "Exception must relay the remote server error (i.e. incorrect dimensions)"
  140. with pytest.raises(P2PHandlerError):
  141. fake_expert(dummy_x)
  142. # check that the server is still alive after processing a malformed request
  143. out3_yet_again = real_expert(dummy_x[1:])
  144. assert torch.allclose(out3_yet_again, out3[1:], atol=1e-5, rtol=0)
  145. @pytest.mark.forked
  146. def test_beam_search_correctness():
  147. all_expert_uids = [f"ffn.{5 + i}.{10 + j}.{15 + k}" for i in range(10) for j in range(10) for k in range(10)]
  148. dht = DHT(start=True)
  149. assert all(declare_experts(dht, all_expert_uids, expiration_time=get_dht_time() + 30))
  150. dmoe = RemoteMixtureOfExperts(in_features=32, grid_size=(32, 32, 32), dht=dht, k_best=4, uid_prefix="ffn.")
  151. for _ in range(25):
  152. input = torch.randn(32)
  153. grid_scores = dmoe.proj(input).split_with_sizes(dmoe.beam_search.grid_size, dim=-1)
  154. chosen_experts = dmoe.beam_search.find_best_experts(
  155. [tensor.detach().numpy() for tensor in grid_scores], beam_size=dmoe.k_best
  156. )
  157. chosen_scores = dmoe.compute_expert_scores([dim_scores[None] for dim_scores in grid_scores], [chosen_experts])[
  158. 0
  159. ]
  160. our_best_scores = list(chosen_scores.cpu().detach().numpy())
  161. # reference: independently find :beam_size: best experts with exhaustive search
  162. all_scores = dmoe.compute_expert_scores(
  163. [dim_scores.unsqueeze(0) for dim_scores in grid_scores],
  164. [[RemoteExpert(ExpertInfo(uid, None), None) for uid in all_expert_uids]],
  165. )[0]
  166. true_best_scores = sorted(all_scores.cpu().detach().numpy(), reverse=True)[: len(chosen_experts)]
  167. assert np.allclose(true_best_scores, our_best_scores)
  168. @pytest.mark.forked
  169. def test_determinism(hidden_dim=16):
  170. atol = 1e-5
  171. xx = torch.randn(32, hidden_dim, requires_grad=True)
  172. mask = torch.randint(0, 1, (32, hidden_dim))
  173. with background_server(
  174. num_experts=1,
  175. device="cpu",
  176. expert_cls="det_dropout",
  177. num_handlers=1,
  178. hidden_dim=hidden_dim,
  179. optim_cls=None,
  180. ) as server_peer_info:
  181. dht = DHT(initial_peers=server_peer_info.addrs, start=True)
  182. expert = create_remote_experts(
  183. [ExpertInfo(uid="expert.0", peer_id=server_peer_info.peer_id)],
  184. dht=dht,
  185. )[0]
  186. out = expert(xx, mask)
  187. out_rerun = expert(xx, mask)
  188. (grad,) = torch.autograd.grad(out.sum(), xx, retain_graph=True)
  189. (grad_rerun,) = torch.autograd.grad(out_rerun.sum(), xx, retain_graph=True)
  190. assert torch.allclose(out, out_rerun, atol=atol, rtol=0), "Dropout layer outputs are non-deterministic."
  191. assert torch.allclose(grad, grad_rerun, atol=atol, rtol=0), "Gradients are non-deterministic."
  192. @pytest.mark.forked
  193. def test_compute_expert_scores():
  194. try:
  195. dht = DHT(start=True)
  196. moe = RemoteMixtureOfExperts(
  197. dht=dht, in_features=16, grid_size=(40,), k_best=4, k_min=1, timeout_after_k_min=1, uid_prefix="expert."
  198. )
  199. gx, gy = torch.randn(4, 5, requires_grad=True), torch.randn(4, 3, requires_grad=True)
  200. ii = [[4, 0, 2], [3, 1, 1, 1, 3], [0], [3, 2]]
  201. jj = [[2, 2, 1], [0, 1, 2, 0, 1], [0], [1, 2]]
  202. batch_experts = [
  203. [
  204. RemoteExpert(ExpertInfo(f"expert.{ii[batch_i][expert_i]}.{jj[batch_i][expert_i]}", None), None)
  205. for expert_i in range(len(ii[batch_i]))
  206. ]
  207. for batch_i in range(len(ii))
  208. ] # note: these experts do not exist on server, we use them only to test compute_expert_scores
  209. logits = moe.compute_expert_scores([gx, gy], batch_experts)
  210. torch.softmax(logits, dim=-1).norm(dim=-1).mean().backward()
  211. assert gx.grad.norm().item() > 0 and gy.grad.norm().item(), "compute_expert_scores didn't backprop"
  212. for batch_i in range(len(ii)):
  213. for expert_i in range(len(ii[batch_i])):
  214. assert torch.allclose(
  215. logits[batch_i, expert_i], gx[batch_i, ii[batch_i][expert_i]] + gy[batch_i, jj[batch_i][expert_i]]
  216. ), "compute_expert_scores returned incorrect score"
  217. finally:
  218. dht.shutdown()
  219. @pytest.mark.forked
  220. def test_client_anomaly_detection():
  221. HID_DIM = 16
  222. experts = {}
  223. for i in range(4):
  224. expert = name_to_block["ffn"](HID_DIM)
  225. experts[f"expert.{i}"] = ModuleBackend(
  226. name=f"expert.{i}",
  227. module=expert,
  228. optimizer=torch.optim.Adam(expert.parameters()),
  229. args_schema=(BatchTensorDescriptor(HID_DIM),),
  230. outputs_schema=BatchTensorDescriptor(HID_DIM),
  231. max_batch_size=16,
  232. )
  233. experts["expert.3"].module.ffn.weight.data[0, 0] = float("nan")
  234. dht = DHT(start=True)
  235. server = Server(dht, experts, num_connection_handlers=1)
  236. server.start()
  237. try:
  238. server.ready.wait()
  239. client_side_dht = DHT(initial_peers=dht.get_visible_maddrs(), start=True)
  240. dmoe = RemoteMixtureOfExperts(
  241. in_features=16, grid_size=(3,), dht=client_side_dht, k_best=3, uid_prefix="expert.", detect_anomalies=True
  242. )
  243. input = torch.randn(1, 16)
  244. input[0, 0] = float("nan")
  245. with pytest.raises(ValueError):
  246. dmoe(input)
  247. input[0, 0] = 0
  248. output = dmoe(input)
  249. inf_loss = float("inf") * output.sum()
  250. with pytest.raises(ValueError):
  251. inf_loss.backward()
  252. dmoe = RemoteMixtureOfExperts(
  253. in_features=16, grid_size=(4,), dht=client_side_dht, k_best=4, uid_prefix="expert.", detect_anomalies=True
  254. )
  255. output = dmoe(input)
  256. assert output.isfinite().all()
  257. finally:
  258. server.shutdown()