test_moe.py

import asyncio

import grpc
import numpy as np
import pytest
import torch

import hivemind
from hivemind.client.expert import DUMMY
from hivemind import background_server
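

# Sanity check: route random batches through a RemoteMixtureOfExperts backed by a local
# expert server and a DHT node, and make sure forward and backward passes complete.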
def test_moe():
    all_expert_uids = [f'ffn.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}'
                       for _ in range(20)]
    with background_server(expert_uids=all_expert_uids, device='cpu', expert_cls='ffn',
                           num_handlers=1, hidden_dim=16) as (server_endpoint, dht_endpoint):
        dht = hivemind.DHT(start=True, expiration=999, initial_peers=[dht_endpoint])
        # declare expert uids. Server *should* declare them by itself, but it takes time.
        assert all(dht.declare_experts(all_expert_uids, endpoint=server_endpoint))

        dmoe = hivemind.RemoteMixtureOfExperts(
            in_features=16, grid_size=(32, 32, 32), dht=dht, k_best=3, uid_prefix='ffn')

        for _ in range(10):
            out = dmoe(torch.randn(10, 16))
            out.sum().backward()
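

# _RemoteCallMany dispatches each sample in the batch to its own list of experts and returns
# a boolean success mask plus the stacked expert outputs. An unreachable expert (e5) or an
# empty expert list should simply show up as False entries in the mask, while the surviving
# outputs and their gradients must match calling each RemoteExpert directly.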
def test_call_many():
    k_min = 1
    timeout_after_k_min = None
    backward_k_min = 1
    forward_timeout = None
    backward_timeout = None
    rtol = 1e-3
    atol = 1e-6

    with background_server(num_experts=5, device='cpu', expert_cls='ffn', num_handlers=8, hidden_dim=64,
                           Optimizer=None, no_dht=True) as (server_endpoint, dht_endpoint):
        inputs = torch.randn(4, 64, requires_grad=True)
        inputs_clone = inputs.clone().detach().requires_grad_(True)
        e0, e1, e2, e3, e4 = [hivemind.RemoteExpert(f'expert.{i}', server_endpoint) for i in range(5)]
        e5 = hivemind.RemoteExpert('thisshouldnotexist', '127.0.0.1:80')

        mask, expert_outputs = hivemind.client.moe._RemoteCallMany.apply(
            DUMMY, [[e0, e1, e2], [e2, e4], [e1, e5, e3], []],
            k_min, backward_k_min, timeout_after_k_min, forward_timeout, backward_timeout,
            asyncio.new_event_loop(), inputs
        )
        assert mask.shape == (4, 3)
        assert expert_outputs.shape == (4, 3, 64)

        assert np.all(mask.data.numpy() == np.array([[True, True, True],
                                                     [True, True, False],
                                                     [True, False, True],
                                                     [False, False, False]])), f"Incorrect mask, {mask}"

        reference_outputs = torch.zeros_like(expert_outputs)
        reference_outputs[0, 0] = e0(inputs_clone[0:1])
        reference_outputs[0, 1] = e1(inputs_clone[0:1])
        reference_outputs[0, 2] = e2(inputs_clone[0:1])
        reference_outputs[1, 0] = e2(inputs_clone[1:2])
        reference_outputs[1, 1] = e4(inputs_clone[1:2])
        reference_outputs[2, 0] = e1(inputs_clone[2:3])
        reference_outputs[2, 2] = e3(inputs_clone[2:3])

        assert torch.allclose(expert_outputs, reference_outputs, rtol, atol)

        proj = torch.randn(4, 64)
        loss = (expert_outputs[(0, 1, 1, 2), (0, 2, 1, 0)] * proj).sum()
        loss.backward()
        our_grad = inputs.grad.data.cpu().clone()

        reference_loss = (reference_outputs[(0, 1, 1, 2), (0, 2, 1, 0)] * proj).sum()
        reference_loss.backward()
        reference_grad = inputs_clone.grad.data.cpu().clone()
        assert torch.allclose(our_grad, reference_grad, rtol, atol)
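

# RemoteExpert should behave like a regular autograd-aware module over gRPC: correct output
# shapes, usable gradients, and grpc.RpcError on a shape mismatch or an unknown expert uid.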
def test_remote_module_call():
    with background_server(num_experts=1, device='cpu', expert_cls='ffn', num_handlers=1, hidden_dim=1024,
                           Optimizer=None, no_dht=True) as (server_endpoint, dht_endpoint):
        real_expert = hivemind.RemoteExpert('expert.0', server_endpoint)
        fake_expert = hivemind.RemoteExpert('oiasfjiasjf', server_endpoint)

        out1 = real_expert(torch.randn(1, 1024))
        assert out1.shape == (1, 1024)

        dummy_x = torch.randn(3, 1024, requires_grad=True)
        out3 = real_expert(dummy_x)
        assert out3.shape == (3, 1024)

        out3_again = real_expert(dummy_x[1:])
        assert torch.allclose(out3_again, out3[1:])
        out3_again.norm().backward()
        assert dummy_x.grad is not None and dummy_x.grad.norm() > 0

        with pytest.raises(grpc.RpcError):
            real_expert(torch.randn(3, 11))
        with pytest.raises(grpc.RpcError):
            fake_expert(dummy_x)
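

# Beam search over the per-dimension grid scores must recover the same top-k scores that
# exhaustively scoring every declared expert would produce.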
def test_moe_beam_search():
    all_expert_uids = [f'ffn.{5 + i}.{10 + j}.{15 + k}' for i in range(10) for j in range(10) for k in range(10)]
    dht = hivemind.DHT(start=True, expiration=999)
    assert all(dht.declare_experts(all_expert_uids, endpoint='fake-endpoint'))

    dmoe = hivemind.RemoteMixtureOfExperts(
        in_features=32, grid_size=(32, 32, 32), dht=dht, k_best=4, uid_prefix='ffn')

    for _ in range(25):
        input = torch.randn(32)
        grid_scores = dmoe.proj(input).split_with_sizes(dmoe.grid_size, dim=-1)
        chosen_experts = dmoe.loop.run_until_complete(dmoe.beam_search(grid_scores, k_best=dmoe.k_best))

        chosen_scores = dmoe.compute_expert_scores([dim_scores[None] for dim_scores in grid_scores],
                                                   [chosen_experts])[0]
        all_scores = dmoe.compute_expert_scores([dim_scores[None] for dim_scores in grid_scores],
                                                [[hivemind.RemoteExpert(uid, '') for uid in all_expert_uids]])[0]
        true_best_scores = sorted(all_scores.cpu().detach().numpy(), reverse=True)[:len(chosen_experts)]
        our_best_scores = list(chosen_scores.cpu().detach().numpy())
        assert np.allclose(true_best_scores, our_best_scores)
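

# A deterministic-dropout expert must return identical outputs and gradients when called
# twice with the same inputs and mask.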
def test_determinism():
    rtol = 0
    atol = 1e-6

    xx = torch.randn(32, 1024, requires_grad=True)
    mask = torch.randint(0, 1, (32, 1024))

    with background_server(num_experts=1, device='cpu', expert_cls='det_dropout', num_handlers=1,
                           Optimizer=None, no_dht=True) as (server_endpoint, dht_endpoint):
        expert = hivemind.RemoteExpert(uid='expert.0', endpoint=server_endpoint)

        out = expert(xx, mask)
        out_rerun = expert(xx, mask)

        grad, = torch.autograd.grad(out.sum(), xx, retain_graph=True)
        grad_rerun, = torch.autograd.grad(out_rerun.sum(), xx, retain_graph=True)

        assert torch.allclose(out, out_rerun, rtol, atol), "Dropout layer outputs are non-deterministic."
        assert torch.allclose(grad, grad_rerun, rtol, atol), "Gradients are non-deterministic."
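

# compute_expert_scores should sum the per-dimension grid scores selected by each expert's
# uid and backpropagate into those grid scores; the experts here never exist on any server.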
def test_compute_expert_scores():
    # create the DHT before the try block so that dht.shutdown() in `finally` is always defined
    dht = hivemind.DHT(start=True)
    try:
        moe = hivemind.client.moe.RemoteMixtureOfExperts(
            dht=dht, in_features=1024, grid_size=(40,), k_best=4, k_min=1, timeout_after_k_min=1,
            uid_prefix='expert')
        gx, gy = torch.randn(4, 5, requires_grad=True), torch.randn(4, 3, requires_grad=True)
        ii = [[4, 0, 2], [3, 1, 1, 1, 3], [0], [3, 2]]
        jj = [[2, 2, 1], [0, 1, 2, 0, 1], [0], [1, 2]]
        batch_experts = [
            [hivemind.RemoteExpert(uid=f'expert.{ii[batch_i][expert_i]}.{jj[batch_i][expert_i]}', endpoint="[::]:1337")
             for expert_i in range(len(ii[batch_i]))]
            for batch_i in range(len(ii))
        ]  # note: these experts do not exist on the server, we only use them to test compute_expert_scores

        logits = moe.compute_expert_scores([gx, gy], batch_experts)
        torch.softmax(logits, dim=-1).norm(dim=-1).mean().backward()
        assert gx.grad.norm().item() > 0 and gy.grad.norm().item() > 0, "compute_expert_scores didn't backprop"

        for batch_i in range(len(ii)):
            for expert_i in range(len(ii[batch_i])):
                assert torch.allclose(logits[batch_i, expert_i],
                                      gx[batch_i, ii[batch_i][expert_i]] + gy[batch_i, jj[batch_i][expert_i]]), \
                    "compute_expert_scores returned incorrect score"
    finally:
        dht.shutdown()
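

# To run this suite locally (assuming it lives at tests/test_moe.py in a hivemind checkout
# with pytest installed): `pytest tests/test_moe.py -v`. Each test starts its own expert
# server and/or DHT node, so the suite is slower than ordinary unit tests.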