# test_custom_experts.py
  1. import os
  2. import pytest
  3. import torch
  4. from hivemind.dht import DHT
  5. from hivemind.moe.client.expert import create_remote_experts
  6. from hivemind.moe.expert_uid import ExpertInfo
  7. from hivemind.moe.server import background_server
# Path to the test-only module that defines the custom expert classes
# ("perceptron", "multihead") loaded by the server via custom_module_path.
CUSTOM_EXPERTS_PATH = os.path.join(os.path.dirname(__file__), "test_utils", "custom_networks.py")
  9. @pytest.mark.forked
  10. def test_custom_expert(hid_dim=16):
  11. with background_server(
  12. expert_cls="perceptron",
  13. num_experts=2,
  14. device="cpu",
  15. hidden_dim=hid_dim,
  16. num_handlers=2,
  17. custom_module_path=CUSTOM_EXPERTS_PATH,
  18. ) as server_peer_info:
  19. dht = DHT(initial_peers=server_peer_info.addrs, start=True)
  20. expert0, expert1 = create_remote_experts(
  21. [
  22. ExpertInfo(uid="expert.0", peer_id=server_peer_info.peer_id),
  23. ExpertInfo(uid="expert.1", peer_id=server_peer_info.peer_id),
  24. ],
  25. dht=dht,
  26. )
  27. for batch_size in (1, 4):
  28. batch = torch.randn(batch_size, hid_dim)
  29. output0 = expert0(batch)
  30. output1 = expert1(batch)
  31. loss = output0.sum()
  32. loss.backward()
  33. loss = output1.sum()
  34. loss.backward()
  35. @pytest.mark.forked
  36. def test_multihead_expert(hid_dim=16):
  37. with background_server(
  38. expert_cls="multihead",
  39. num_experts=2,
  40. device="cpu",
  41. hidden_dim=hid_dim,
  42. num_handlers=2,
  43. custom_module_path=CUSTOM_EXPERTS_PATH,
  44. ) as server_peer_info:
  45. dht = DHT(initial_peers=server_peer_info.addrs, start=True)
  46. expert0, expert1 = create_remote_experts(
  47. [
  48. ExpertInfo(uid="expert.0", peer_id=server_peer_info.peer_id),
  49. ExpertInfo(uid="expert.1", peer_id=server_peer_info.peer_id),
  50. ],
  51. dht=dht,
  52. )
  53. for batch_size in (1, 4):
  54. batch = (
  55. torch.randn(batch_size, hid_dim),
  56. torch.randn(batch_size, 2 * hid_dim),
  57. torch.randn(batch_size, 3 * hid_dim),
  58. )
  59. output0 = expert0(*batch)
  60. output1 = expert1(*batch)
  61. loss = output0.sum()
  62. loss.backward()
  63. loss = output1.sum()
  64. loss.backward()