# test_moe.py

import time

import numpy as np
import pytest
import torch

from hivemind.dht import DHT
from hivemind.moe.client.expert import RemoteExpert, RemoteExpertInfo, RemoteExpertWorker
from hivemind.moe.client.moe import DUMMY, RemoteMixtureOfExperts, _RemoteCallMany
from hivemind.moe.client.switch_moe import RemoteSwitchMixtureOfExperts
from hivemind.moe.server import ExpertBackend, Server, background_server, declare_experts
from hivemind.moe.server.layers import name_to_block
from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError
from hivemind.utils.tensor_descr import BatchTensorDescriptor
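

# End-to-end smoke test: serve a handful of randomly named "ffn" experts in a background
# process, connect a RemoteMixtureOfExperts client through the DHT, and check that forward
# and backward passes complete.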
@pytest.mark.forked
def test_moe():
    all_expert_uids = [
        f"ffn.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}" for _ in range(10)
    ]
    with background_server(
        expert_uids=all_expert_uids, device="cpu", expert_cls="ffn", num_handlers=1, hidden_dim=16
    ) as server_peer_info:
        dht = DHT(start=True, initial_peers=server_peer_info.addrs)
        dmoe = RemoteMixtureOfExperts(in_features=16, grid_size=(4, 4, 4), dht=dht, k_best=3, uid_prefix="ffn.")

        for i in range(3):
            out = dmoe(torch.randn(10, 16))
            out.sum().backward()
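

# RemoteSwitchMixtureOfExperts against slow "nop_delay" experts with 0.1 s timeouts: with
# allow_zero_outputs=True, the forward/backward pass should still succeed (returning outputs
# and a balancing loss) even if no expert responds in time.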
@pytest.mark.forked
def test_no_experts():
    all_expert_uids = [
        f"expert.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}" for _ in range(10)
    ]
    with background_server(
        expert_uids=all_expert_uids, device="cpu", expert_cls="nop_delay", num_handlers=1, hidden_dim=16
    ) as server_peer_info:
        dht = DHT(start=True, initial_peers=server_peer_info.addrs)
        dmoe = RemoteSwitchMixtureOfExperts(
            in_features=16,
            grid_size=(4, 4, 4),
            dht=dht,
            uid_prefix="expert.",
            forward_timeout=0.1,
            backward_timeout=0.1,
            allow_zero_outputs=True,
        )

        for i in range(3):
            out, balancing_loss = dmoe(torch.randn(10, 16))
            out.sum().backward()
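

# _RemoteCallMany with an uneven batch of experts, including one nonexistent expert and one
# empty row: the returned mask must reflect which calls succeeded, and outputs and gradients
# must match per-expert reference calls within atol.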
@pytest.mark.forked
def test_call_many(hidden_dim=16):
    k_min = 1
    timeout_after_k_min = None
    backward_k_min = 1
    forward_timeout = None
    backward_timeout = None
    detect_anomalies = False
    allow_zero_outputs = False
    atol = 1e-5

    with background_server(
        num_experts=5,
        device="cpu",
        expert_cls="ffn",
        num_handlers=1,
        hidden_dim=hidden_dim,
        optim_cls=None,
    ) as server_peer_info:
        inputs = torch.randn(4, hidden_dim, requires_grad=True)
        inputs_clone = inputs.clone().detach().requires_grad_(True)

        dht = DHT(initial_peers=server_peer_info.addrs, start=True)
        e0, e1, e2, e3, e4 = RemoteExpertWorker.spawn_experts(
            [RemoteExpertInfo(uid=f"expert.{i}", peer_info=server_peer_info) for i in range(5)],
            dht,
        )
        e5 = RemoteExpert(RemoteExpertInfo("thisshouldnotexist", server_peer_info), None)

        mask, expert_outputs = _RemoteCallMany.apply(
            DUMMY,
            [[e0, e1, e2], [e2, e4], [e1, e5, e3], []],
            k_min,
            backward_k_min,
            timeout_after_k_min,
            forward_timeout,
            backward_timeout,
            detect_anomalies,
            allow_zero_outputs,
            e1.info,
            inputs,
        )
        assert mask.shape == (4, 3)
        assert expert_outputs.shape == (4, 3, hidden_dim)

        assert np.all(
            mask.data.numpy()
            == np.array([[True, True, True], [True, True, False], [True, False, True], [False, False, False]])
        ), f"Incorrect mask, {mask}"

        reference_outputs = torch.zeros_like(expert_outputs)
        reference_outputs[0, 0] = e0(inputs_clone[0:1])
        reference_outputs[0, 1] = e1(inputs_clone[0:1])
        reference_outputs[0, 2] = e2(inputs_clone[0:1])
        reference_outputs[1, 0] = e2(inputs_clone[1:2])
        reference_outputs[1, 1] = e4(inputs_clone[1:2])
        reference_outputs[2, 0] = e1(inputs_clone[2:3])
        reference_outputs[2, 2] = e3(inputs_clone[2:3])

        assert torch.allclose(expert_outputs, reference_outputs, atol=atol, rtol=0)

        proj = torch.randn(4, hidden_dim)
        loss = (expert_outputs[(0, 1, 1, 2), (0, 2, 1, 0)] * proj).sum()
        loss.backward()
        our_grad = inputs.grad.data.cpu().clone()

        reference_loss = (reference_outputs[(0, 1, 1, 2), (0, 2, 1, 0)] * proj).sum()
        reference_loss.backward()
        reference_grad = inputs_clone.grad.data.cpu().clone()
        assert torch.allclose(our_grad, reference_grad, atol=atol, rtol=0)
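

# Direct RemoteExpert calls: correct output shapes, consistent results on a sub-batch, nonzero
# input gradients, and P2PDaemonError for a mismatched input shape or a nonexistent expert UID.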
@pytest.mark.forked
def test_remote_module_call(hidden_dim=16):
    with background_server(
        num_experts=1,
        device="cpu",
        expert_cls="ffn",
        num_handlers=1,
        hidden_dim=hidden_dim,
        optim_cls=None,
    ) as server_peer_info:
        dht = DHT(initial_peers=server_peer_info.addrs, start=True)
        real_expert, fake_expert = RemoteExpertWorker.spawn_experts(
            [
                RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info),
                RemoteExpertInfo(uid="oiasfjiasjf", peer_info=server_peer_info),
            ],
            dht=dht,
        )

        out1 = real_expert(torch.randn(1, hidden_dim))
        assert out1.shape == (1, hidden_dim)

        dummy_x = torch.randn(3, hidden_dim, requires_grad=True)
        out3 = real_expert(dummy_x)
        assert out3.shape == (3, hidden_dim)
        out3_again = real_expert(dummy_x[1:])
        assert torch.allclose(out3_again, out3[1:], atol=1e-5, rtol=0)
        out3_again.norm().backward()
        assert dummy_x.grad is not None and dummy_x.grad.norm() > 0

        with pytest.raises(P2PDaemonError):
            real_expert(torch.randn(3, 11))
        with pytest.raises(P2PDaemonError):
            fake_expert(dummy_x)
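

# Beam search over the expert grid must return the same top-k expert scores as an exhaustive
# search over all declared experts.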
@pytest.mark.forked
def test_beam_search_correctness():
    all_expert_uids = [f"ffn.{5 + i}.{10 + j}.{15 + k}" for i in range(10) for j in range(10) for k in range(10)]
    dht = DHT(start=True)
    assert all(declare_experts(dht, all_expert_uids, dht.peer_id))

    dmoe = RemoteMixtureOfExperts(in_features=32, grid_size=(32, 32, 32), dht=dht, k_best=4, uid_prefix="ffn.")

    for _ in range(25):
        input = torch.randn(32)
        grid_scores = dmoe.proj(input).split_with_sizes(dmoe.beam_search.grid_size, dim=-1)

        chosen_experts = dmoe.beam_search.find_best_experts(
            [tensor.detach().numpy() for tensor in grid_scores], beam_size=dmoe.k_best
        )
        chosen_scores = dmoe.compute_expert_scores(
            [dim_scores[None] for dim_scores in grid_scores], [chosen_experts]
        )[0]
        our_best_scores = list(chosen_scores.cpu().detach().numpy())

        # reference: independently find :beam_size: best experts with exhaustive search
        all_scores = dmoe.compute_expert_scores(
            [dim_scores.unsqueeze(0) for dim_scores in grid_scores],
            [[RemoteExpert(RemoteExpertInfo(uid, None), None) for uid in all_expert_uids]],
        )[0]
        true_best_scores = sorted(all_scores.cpu().detach().numpy(), reverse=True)[: len(chosen_experts)]

        assert np.allclose(true_best_scores, our_best_scores)
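

# A remote "det_dropout" expert must produce identical outputs and gradients when called twice
# with the same inputs and mask.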
@pytest.mark.forked
def test_determinism(hidden_dim=16):
    atol = 1e-5

    xx = torch.randn(32, hidden_dim, requires_grad=True)
    mask = torch.randint(0, 1, (32, hidden_dim))

    with background_server(
        num_experts=1,
        device="cpu",
        expert_cls="det_dropout",
        num_handlers=1,
        hidden_dim=hidden_dim,
        optim_cls=None,
    ) as server_peer_info:
        dht = DHT(initial_peers=server_peer_info.addrs, start=True)
        expert = RemoteExpertWorker.spawn_experts(
            [RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info)],
            dht=dht,
        )[0]

        out = expert(xx, mask)
        out_rerun = expert(xx, mask)

        (grad,) = torch.autograd.grad(out.sum(), xx, retain_graph=True)
        (grad_rerun,) = torch.autograd.grad(out_rerun.sum(), xx, retain_graph=True)

        assert torch.allclose(out, out_rerun, atol=atol, rtol=0), "Dropout layer outputs are non-deterministic."
        assert torch.allclose(grad, grad_rerun, atol=atol, rtol=0), "Gradients are non-deterministic."
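

# compute_expert_scores must equal the sum of the per-dimension grid scores encoded in each
# expert's UID ("expert.i.j" -> gx[:, i] + gy[:, j]) and must backpropagate to the score tensors.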
@pytest.mark.forked
def test_compute_expert_scores():
    try:
        dht = DHT(start=True)
        moe = RemoteMixtureOfExperts(
            dht=dht, in_features=16, grid_size=(40,), k_best=4, k_min=1, timeout_after_k_min=1, uid_prefix="expert."
        )
        gx, gy = torch.randn(4, 5, requires_grad=True), torch.randn(4, 3, requires_grad=True)
        ii = [[4, 0, 2], [3, 1, 1, 1, 3], [0], [3, 2]]
        jj = [[2, 2, 1], [0, 1, 2, 0, 1], [0], [1, 2]]
        batch_experts = [
            [
                RemoteExpert(RemoteExpertInfo(f"expert.{ii[batch_i][expert_i]}.{jj[batch_i][expert_i]}", None), None)
                for expert_i in range(len(ii[batch_i]))
            ]
            for batch_i in range(len(ii))
        ]  # note: these experts do not exist on server, we use them only to test compute_expert_scores

        logits = moe.compute_expert_scores([gx, gy], batch_experts)
        torch.softmax(logits, dim=-1).norm(dim=-1).mean().backward()
        assert gx.grad.norm().item() > 0 and gy.grad.norm().item() > 0, "compute_expert_scores didn't backprop"

        for batch_i in range(len(ii)):
            for expert_i in range(len(ii[batch_i])):
                assert torch.allclose(
                    logits[batch_i, expert_i], gx[batch_i, ii[batch_i][expert_i]] + gy[batch_i, jj[batch_i][expert_i]]
                ), "compute_expert_scores returned incorrect score"
    finally:
        dht.shutdown()
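

# With detect_anomalies=True, the client must raise ValueError on NaN inputs and on infinite
# gradients; when the expert whose weights were corrupted with NaN is selectable, its output
# should be discarded so that the returned mixture stays finite.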
@pytest.mark.forked
def test_client_anomaly_detection():
    HID_DIM = 16

    experts = {}
    for i in range(4):
        expert = name_to_block["ffn"](HID_DIM)
        experts[f"expert.{i}"] = ExpertBackend(
            name=f"expert.{i}",
            expert=expert,
            optimizer=torch.optim.Adam(expert.parameters()),
            args_schema=(BatchTensorDescriptor(HID_DIM),),
            outputs_schema=BatchTensorDescriptor(HID_DIM),
            max_batch_size=16,
        )

    experts["expert.3"].expert.ffn.weight.data[0, 0] = float("nan")

    dht = DHT(start=True)
    server = Server(dht, experts, num_connection_handlers=1)
    server.start()
    try:
        server.ready.wait()

        dht_experts = DHT(initial_peers=dht.get_visible_maddrs(), start=True)
        dmoe = RemoteMixtureOfExperts(
            in_features=16, grid_size=(3,), dht=dht_experts, k_best=3, uid_prefix="expert.", detect_anomalies=True
        )

        input = torch.randn(1, 16)
        input[0, 0] = float("nan")

        with pytest.raises(ValueError):
            dmoe(input)

        input[0, 0] = 0
        output = dmoe(input)

        inf_loss = float("inf") * output.sum()
        with pytest.raises(ValueError):
            inf_loss.backward()

        dmoe = RemoteMixtureOfExperts(
            in_features=16, grid_size=(4,), dht=dht_experts, k_best=4, uid_prefix="expert.", detect_anomalies=True
        )
        output = dmoe(input)
        assert output.isfinite().all()
    finally:
        server.shutdown()