test_moe.py

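"""Tests for hivemind's Mixture-of-Experts stack: remote expert calls, routing with
RemoteMixtureOfExperts / RemoteSwitchMixtureOfExperts, beam search over the expert grid,
client-side anomaly detection, and the shared RemoteExpertWorker event loop."""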
import asyncio
import ctypes
import multiprocessing as mp
import threading
import time

import numpy as np
import pytest
import torch

from hivemind.dht import DHT
from hivemind.moe.client.expert import RemoteExpert, create_remote_experts
from hivemind.moe.client.moe import DUMMY, RemoteMixtureOfExperts, _RemoteCallMany
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.moe.client.switch_moe import RemoteSwitchMixtureOfExperts
from hivemind.moe.expert_uid import ExpertInfo
from hivemind.moe.server import ModuleBackend, Server, background_server, declare_experts
from hivemind.moe.server.layers import name_to_block
from hivemind.p2p.p2p_daemon_bindings.control import P2PHandlerError
from hivemind.utils import BatchTensorDescriptor, MPFuture, get_dht_time
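

# End-to-end smoke test: a background server hosts random "ffn.x.y.z" experts, and a
# RemoteMixtureOfExperts client routes batches through them with a full backward pass.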
@pytest.mark.forked
def test_moe():
    all_expert_uids = [
        f"ffn.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}" for _ in range(10)
    ]
    with background_server(
        expert_uids=all_expert_uids, device="cpu", expert_cls="ffn", num_handlers=1, hidden_dim=16
    ) as server_peer_info:
        dht = DHT(start=True, initial_peers=server_peer_info.addrs)
        dmoe = RemoteMixtureOfExperts(in_features=16, grid_size=(4, 4, 4), dht=dht, k_best=3, uid_prefix="ffn.")

        for i in range(3):
            out = dmoe(torch.randn(10, 16))
            out.sum().backward()
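

# Failure-path test: "nop_delay" experts with very short timeouts simulate experts that
# may not answer in time; RemoteSwitchMixtureOfExperts must still return outputs and
# stay differentiable thanks to allow_zero_outputs=True.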
@pytest.mark.forked
def test_no_experts():
    all_expert_uids = [
        f"expert.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}" for _ in range(10)
    ]
    with background_server(
        expert_uids=all_expert_uids, device="cpu", expert_cls="nop_delay", num_handlers=1, hidden_dim=16
    ) as server_peer_info:
        dht = DHT(start=True, initial_peers=server_peer_info.addrs)
        dmoe = RemoteSwitchMixtureOfExperts(
            in_features=16,
            grid_size=(4, 4, 4),
            dht=dht,
            uid_prefix="expert.",
            forward_timeout=0.1,
            backward_timeout=0.1,
            allow_zero_outputs=True,
        )

        for i in range(3):
            out, balancing_loss = dmoe(torch.randn(10, 16))
            out.sum().backward()
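

# _RemoteCallMany is the autograd function behind MoE inference. Call it directly with
# an uneven grid of experts (including the nonexistent e5 and an empty row) and verify
# the success mask, the outputs, and the input gradients against per-expert reference calls.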
@pytest.mark.forked
def test_call_many(hidden_dim=16):
    k_min = 1
    timeout_after_k_min = None
    backward_k_min = 1
    forward_timeout = None
    backward_timeout = None
    detect_anomalies = False
    allow_zero_outputs = False
    atol = 1e-5

    with background_server(
        num_experts=5,
        device="cpu",
        expert_cls="ffn",
        num_handlers=1,
        hidden_dim=hidden_dim,
        optim_cls=None,
    ) as server_peer_info:
        inputs = torch.randn(4, hidden_dim, requires_grad=True)
        inputs_clone = inputs.clone().detach().requires_grad_(True)

        dht = DHT(initial_peers=server_peer_info.addrs, start=True)
        e0, e1, e2, e3, e4 = create_remote_experts(
            [ExpertInfo(uid=f"expert.{i}", peer_id=server_peer_info.peer_id) for i in range(5)],
            dht,
        )
        e5 = RemoteExpert(ExpertInfo("thisshouldnotexist", server_peer_info), None)

        mask, expert_outputs = _RemoteCallMany.apply(
            DUMMY,
            [[e0, e1, e2], [e2, e4], [e1, e5, e3], []],
            k_min,
            backward_k_min,
            timeout_after_k_min,
            forward_timeout,
            backward_timeout,
            detect_anomalies,
            allow_zero_outputs,
            e1.info,
            inputs,
        )
        assert mask.shape == (4, 3)
        assert expert_outputs.shape == (4, 3, hidden_dim)

        assert np.all(
            mask.data.numpy()
            == np.array([[True, True, True], [True, True, False], [True, False, True], [False, False, False]])
        ), f"Incorrect mask, {mask}"

        reference_outputs = torch.zeros_like(expert_outputs)
        reference_outputs[0, 0] = e0(inputs_clone[0:1])
        reference_outputs[0, 1] = e1(inputs_clone[0:1])
        reference_outputs[0, 2] = e2(inputs_clone[0:1])
        reference_outputs[1, 0] = e2(inputs_clone[1:2])
        reference_outputs[1, 1] = e4(inputs_clone[1:2])
        reference_outputs[2, 0] = e1(inputs_clone[2:3])
        reference_outputs[2, 2] = e3(inputs_clone[2:3])

        assert torch.allclose(expert_outputs, reference_outputs, atol=atol, rtol=0)

        proj = torch.randn(4, hidden_dim)
        loss = (expert_outputs[(0, 1, 1, 2), (0, 2, 1, 0)] * proj).sum()
        loss.backward()
        our_grad = inputs.grad.data.cpu().clone()

        reference_loss = (reference_outputs[(0, 1, 1, 2), (0, 2, 1, 0)] * proj).sum()
        reference_loss.backward()
        reference_grad = inputs_clone.grad.data.cpu().clone()
        assert torch.allclose(our_grad, reference_grad, atol=atol, rtol=0)
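

# Direct RemoteExpert calls: forward/backward must match local expectations, malformed
# requests (wrong input dimension, unknown UID) must raise P2PHandlerError, and the
# server must keep serving afterwards.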
@pytest.mark.forked
def test_remote_module_call(hidden_dim=16):
    with background_server(
        num_experts=1,
        device="cpu",
        expert_cls="ffn",
        num_handlers=1,
        hidden_dim=hidden_dim,
        optim_cls=None,
    ) as server_peer_info:
        dht = DHT(initial_peers=server_peer_info.addrs, start=True)
        real_expert, fake_expert = create_remote_experts(
            [
                ExpertInfo(uid="expert.0", peer_id=server_peer_info.peer_id),
                ExpertInfo(uid="oiasfjiasjf", peer_id=server_peer_info.peer_id),
            ],
            dht=dht,
        )
        out1 = real_expert(torch.randn(1, hidden_dim))
        assert out1.shape == (1, hidden_dim)

        dummy_x = torch.randn(3, hidden_dim, requires_grad=True)
        out3 = real_expert(dummy_x)
        assert out3.shape == (3, hidden_dim)
        out3_again = real_expert(dummy_x[1:])
        assert torch.allclose(out3_again, out3[1:], atol=1e-5, rtol=0)
        out3_again.norm().backward()
        assert dummy_x.grad is not None and dummy_x.grad.norm() > 0

        # a request with the wrong input dimension must fail, not pass silently
        with pytest.raises(P2PHandlerError) as excinfo:
            real_expert(torch.randn(3, 11))
        assert str(11) in repr(excinfo.value), "Exception must relay the remote server error (i.e. incorrect dimensions)"

        with pytest.raises(P2PHandlerError):
            fake_expert(dummy_x)

        # check that the server is still alive after processing a malformed request
        out3_yet_again = real_expert(dummy_x[1:])
        assert torch.allclose(out3_yet_again, out3[1:], atol=1e-5, rtol=0)
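

# Beam search over the DHT-declared expert grid must select exactly the same top-k
# experts as exhaustive scoring of all 1000 declared UIDs.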
@pytest.mark.forked
def test_beam_search_correctness():
    all_expert_uids = [f"ffn.{5 + i}.{10 + j}.{15 + k}" for i in range(10) for j in range(10) for k in range(10)]
    dht = DHT(start=True)
    assert all(declare_experts(dht, all_expert_uids, expiration_time=get_dht_time() + 30))

    dmoe = RemoteMixtureOfExperts(in_features=32, grid_size=(32, 32, 32), dht=dht, k_best=4, uid_prefix="ffn.")

    for _ in range(25):
        input = torch.randn(32)
        grid_scores = dmoe.proj(input).split_with_sizes(dmoe.beam_search.grid_size, dim=-1)

        chosen_experts = dmoe.beam_search.find_best_experts(
            [tensor.detach().numpy() for tensor in grid_scores], beam_size=dmoe.k_best
        )
        chosen_scores = dmoe.compute_expert_scores(
            [dim_scores[None] for dim_scores in grid_scores], [chosen_experts]
        )[0]
        our_best_scores = list(chosen_scores.cpu().detach().numpy())

        # reference: independently find :beam_size: best experts with exhaustive search
        all_scores = dmoe.compute_expert_scores(
            [dim_scores.unsqueeze(0) for dim_scores in grid_scores],
            [[RemoteExpert(ExpertInfo(uid, None), None) for uid in all_expert_uids]],
        )[0]
        true_best_scores = sorted(all_scores.cpu().detach().numpy(), reverse=True)[: len(chosen_experts)]

        assert np.allclose(true_best_scores, our_best_scores)
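

# Two identical calls to a "det_dropout" (deterministic dropout) expert must yield
# identical outputs and identical input gradients.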
@pytest.mark.forked
def test_determinism(hidden_dim=16):
    atol = 1e-5

    xx = torch.randn(32, hidden_dim, requires_grad=True)
    mask = torch.randint(0, 1, (32, hidden_dim))

    with background_server(
        num_experts=1,
        device="cpu",
        expert_cls="det_dropout",
        num_handlers=1,
        hidden_dim=hidden_dim,
        optim_cls=None,
    ) as server_peer_info:
        dht = DHT(initial_peers=server_peer_info.addrs, start=True)
        expert = create_remote_experts(
            [ExpertInfo(uid="expert.0", peer_id=server_peer_info.peer_id)],
            dht=dht,
        )[0]

        out = expert(xx, mask)
        out_rerun = expert(xx, mask)

        (grad,) = torch.autograd.grad(out.sum(), xx, retain_graph=True)
        (grad_rerun,) = torch.autograd.grad(out_rerun.sum(), xx, retain_graph=True)

        assert torch.allclose(out, out_rerun, atol=atol, rtol=0), "Dropout layer outputs are non-deterministic."
        assert torch.allclose(grad, grad_rerun, atol=atol, rtol=0), "Gradients are non-deterministic."
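

# compute_expert_scores: the logit of each expert must equal the sum of its
# per-grid-dimension scores (gx + gy here), and gradients must flow back into both
# score tensors.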
@pytest.mark.forked
def test_compute_expert_scores():
    try:
        dht = DHT(start=True)
        moe = RemoteMixtureOfExperts(
            dht=dht, in_features=16, grid_size=(40,), k_best=4, k_min=1, timeout_after_k_min=1, uid_prefix="expert."
        )
        gx, gy = torch.randn(4, 5, requires_grad=True), torch.randn(4, 3, requires_grad=True)
        ii = [[4, 0, 2], [3, 1, 1, 1, 3], [0], [3, 2]]
        jj = [[2, 2, 1], [0, 1, 2, 0, 1], [0], [1, 2]]
        batch_experts = [
            [
                RemoteExpert(ExpertInfo(f"expert.{ii[batch_i][expert_i]}.{jj[batch_i][expert_i]}", None), None)
                for expert_i in range(len(ii[batch_i]))
            ]
            for batch_i in range(len(ii))
        ]  # note: these experts do not exist on server, we use them only to test compute_expert_scores
        logits = moe.compute_expert_scores([gx, gy], batch_experts)
        torch.softmax(logits, dim=-1).norm(dim=-1).mean().backward()
        assert gx.grad.norm().item() > 0 and gy.grad.norm().item() > 0, "compute_expert_scores didn't backprop"

        for batch_i in range(len(ii)):
            for expert_i in range(len(ii[batch_i])):
                assert torch.allclose(
                    logits[batch_i, expert_i], gx[batch_i, ii[batch_i][expert_i]] + gy[batch_i, jj[batch_i][expert_i]]
                ), "compute_expert_scores returned incorrect score"
    finally:
        dht.shutdown()
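

# With detect_anomalies=True, the client must reject NaN inputs and infinite gradients
# with a ValueError; when the grid includes an expert whose weights contain NaN
# ("expert.3"), its response is expected to be dropped so the mixture output stays finite.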
@pytest.mark.forked
def test_client_anomaly_detection():
    HID_DIM = 16

    experts = {}
    for i in range(4):
        expert = name_to_block["ffn"](HID_DIM)
        experts[f"expert.{i}"] = ModuleBackend(
            name=f"expert.{i}",
            module=expert,
            optimizer=torch.optim.Adam(expert.parameters()),
            args_schema=(BatchTensorDescriptor(HID_DIM),),
            outputs_schema=BatchTensorDescriptor(HID_DIM),
            max_batch_size=16,
        )

    experts["expert.3"].module.ffn.weight.data[0, 0] = float("nan")

    dht = DHT(start=True)
    server = Server(dht, experts, num_connection_handlers=1)
    server.start()
    try:
        server.ready.wait()
        client_side_dht = DHT(initial_peers=dht.get_visible_maddrs(), start=True)

        dmoe = RemoteMixtureOfExperts(
            in_features=16, grid_size=(3,), dht=client_side_dht, k_best=3, uid_prefix="expert.", detect_anomalies=True
        )

        input = torch.randn(1, 16)
        input[0, 0] = float("nan")

        with pytest.raises(ValueError):
            dmoe(input)

        input[0, 0] = 0
        output = dmoe(input)

        inf_loss = float("inf") * output.sum()
        with pytest.raises(ValueError):
            inf_loss.backward()

        dmoe = RemoteMixtureOfExperts(
            in_features=16, grid_size=(4,), dht=client_side_dht, k_best=4, uid_prefix="expert.", detect_anomalies=True
        )
        output = dmoe(input)
        assert output.isfinite().all()
    finally:
        server.shutdown()
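

# Helper for the concurrency test below: schedules n_coros coroutines on the shared
# RemoteExpertWorker event loop and reports the total wall-clock time via an MPFuture.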
def _measure_coro_running_time(n_coros, elapsed_fut, counter):
    async def coro():
        await asyncio.sleep(0.1)
        counter.value += 1

    try:
        start_time = time.perf_counter()
        futures = [
            RemoteExpertWorker.run_coroutine(coro(), return_future=True) for _ in range(n_coros - 1)
        ]  # Non-blocking calls
        RemoteExpertWorker.run_coroutine(coro(), return_future=False)  # A blocking call
        for fut in futures:
            fut.result()
        elapsed_fut.set_result(time.perf_counter() - start_time)
    except Exception as e:
        elapsed_fut.set_exception(e)
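

# RemoteExpertWorker must run coroutines from several threads and processes concurrently:
# 10 coroutines sleeping 0.1 s each must finish in under 0.2 s, not the ~1 s a
# sequential run would take.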
@pytest.mark.forked
def test_remote_expert_worker_runs_coros_concurrently(n_processes=4, n_coros=10):
    processes = []
    counter = mp.Value(ctypes.c_int64)
    for i in range(n_processes):
        elapsed_fut = MPFuture()
        factory = threading.Thread if i % 2 == 0 else mp.Process  # Test both threads and processes
        proc = factory(target=_measure_coro_running_time, args=(n_coros, elapsed_fut, counter))
        proc.start()
        processes.append((proc, elapsed_fut))

    for proc, elapsed_fut in processes:
        # Ensure that the coroutines were run concurrently, not sequentially
        assert elapsed_fut.result() < 0.2
        proc.join()

    assert counter.value == n_processes * n_coros  # Ensure all coroutines have finished