# test_custom_experts.py
  1. import os
  2. import pytest
  3. import torch
  4. from hivemind.dht import DHT
  5. from hivemind.moe.client.expert import RemoteExpertInfo, create_remote_experts
  6. from hivemind.moe.server import background_server
  7. CUSTOM_EXPERTS_PATH = os.path.join(os.path.dirname(__file__), "test_utils", "custom_networks.py")
  8. @pytest.mark.forked
  9. def test_custom_expert(hid_dim=16):
  10. with background_server(
  11. expert_cls="perceptron",
  12. num_experts=2,
  13. device="cpu",
  14. hidden_dim=hid_dim,
  15. num_handlers=2,
  16. custom_module_path=CUSTOM_EXPERTS_PATH,
  17. ) as server_peer_info:
  18. dht = DHT(initial_peers=server_peer_info.addrs, start=True)
  19. expert0, expert1 = create_remote_experts(
  20. [
  21. RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info),
  22. RemoteExpertInfo(uid="expert.1", peer_info=server_peer_info),
  23. ],
  24. dht=dht,
  25. )
  26. for batch_size in (1, 4):
  27. batch = torch.randn(batch_size, hid_dim)
  28. output0 = expert0(batch)
  29. output1 = expert1(batch)
  30. loss = output0.sum()
  31. loss.backward()
  32. loss = output1.sum()
  33. loss.backward()
  34. @pytest.mark.forked
  35. def test_multihead_expert(hid_dim=16):
  36. with background_server(
  37. expert_cls="multihead",
  38. num_experts=2,
  39. device="cpu",
  40. hidden_dim=hid_dim,
  41. num_handlers=2,
  42. custom_module_path=CUSTOM_EXPERTS_PATH,
  43. ) as server_peer_info:
  44. dht = DHT(initial_peers=server_peer_info.addrs, start=True)
  45. expert0, expert1 = create_remote_experts(
  46. [
  47. RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info),
  48. RemoteExpertInfo(uid="expert.1", peer_info=server_peer_info),
  49. ],
  50. dht=dht,
  51. )
  52. for batch_size in (1, 4):
  53. batch = (
  54. torch.randn(batch_size, hid_dim),
  55. torch.randn(batch_size, 2 * hid_dim),
  56. torch.randn(batch_size, 3 * hid_dim),
  57. )
  58. output0 = expert0(*batch)
  59. output1 = expert1(*batch)
  60. loss = output0.sum()
  61. loss.backward()
  62. loss = output1.sum()
  63. loss.backward()