test_custom_experts.py

import os

import pytest
import torch

from hivemind import RemoteExpert
from hivemind.dht import DHT
from hivemind.moe.client.expert import RemoteExpertInfo, RemoteExpertWorker
from hivemind.moe.server import background_server

CUSTOM_EXPERTS_PATH = os.path.join(os.path.dirname(__file__), "test_utils", "custom_networks.py")

@pytest.mark.forked
def test_custom_expert(hid_dim=16):
    # Launch a local server hosting two copies of the custom "perceptron" expert
    # loaded from CUSTOM_EXPERTS_PATH.
    with background_server(
        expert_cls="perceptron",
        num_experts=2,
        device="cpu",
        hidden_dim=hid_dim,
        num_handlers=2,
        custom_module_path=CUSTOM_EXPERTS_PATH,
    ) as server_peer_info:
        # Connect to the server's DHT and obtain handles to both remote experts.
        dht = DHT(initial_peers=server_peer_info.addrs, start=True)
        expert0, expert1 = RemoteExpertWorker.spawn_experts(
            [
                RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info),
                RemoteExpertInfo(uid="expert.1", peer_info=server_peer_info),
            ],
            dht=dht,
        )

        for batch_size in (1, 4):
            batch = torch.randn(batch_size, hid_dim)

            # Forward and backward through each remote expert to verify that the
            # custom module supports gradient computation end to end.
            output0 = expert0(batch)
            output1 = expert1(batch)

            loss = output0.sum()
            loss.backward()
            loss = output1.sum()
            loss.backward()

@pytest.mark.forked
def test_multihead_expert(hid_dim=16):
    # Same setup as above, but using the custom "multihead" expert, which takes
    # three input tensors of different widths.
    with background_server(
        expert_cls="multihead",
        num_experts=2,
        device="cpu",
        hidden_dim=hid_dim,
        num_handlers=2,
        custom_module_path=CUSTOM_EXPERTS_PATH,
    ) as server_peer_info:
        dht = DHT(initial_peers=server_peer_info.addrs, start=True)
        expert0, expert1 = RemoteExpertWorker.spawn_experts(
            [
                RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info),
                RemoteExpertInfo(uid="expert.1", peer_info=server_peer_info),
            ],
            dht=dht,
        )

        for batch_size in (1, 4):
            # Each sample consists of three tensors of widths hid_dim, 2 * hid_dim
            # and 3 * hid_dim, matching the "multihead" expert's signature.
            batch = (
                torch.randn(batch_size, hid_dim),
                torch.randn(batch_size, 2 * hid_dim),
                torch.randn(batch_size, 3 * hid_dim),
            )

            output0 = expert0(*batch)
            output1 = expert1(*batch)

            loss = output0.sum()
            loss.backward()
            loss = output1.sum()
            loss.backward()
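
Both tests depend on custom_module_path pointing at tests/test_utils/custom_networks.py, which must register expert classes under the names "perceptron" and "multihead" so the server can instantiate them. That file is not shown here; the sketch below is illustrative only, assuming hivemind's register_expert_class decorator and made-up class names and layer sizes. The "multihead" class has to accept three inputs of widths hid_dim, 2 * hid_dim and 3 * hid_dim, matching the batch built in test_multihead_expert.

# test_utils/custom_networks.py -- illustrative sketch, not the actual file
import torch
import torch.nn as nn
import torch.nn.functional as F

from hivemind.moe.server.layers.custom_experts import register_expert_class

# Sample-input factory: tells the server what one batch of inputs looks like.
sample_input = lambda batch_size, hid_dim: torch.empty((batch_size, hid_dim))


@register_expert_class("perceptron", sample_input)
class MultilayerPerceptron(nn.Module):  # hypothetical class name and layer sizes
    def __init__(self, hid_dim, num_classes=10):
        super().__init__()
        self.layer1 = nn.Linear(hid_dim, 4 * hid_dim)
        self.layer2 = nn.Linear(4 * hid_dim, num_classes)

    def forward(self, x):
        return self.layer2(F.relu(self.layer1(x)))


# The "multihead" expert receives three tensors per sample, matching the shapes
# used in test_multihead_expert.
multihead_sample_input = lambda batch_size, hid_dim: (
    torch.empty((batch_size, hid_dim)),
    torch.empty((batch_size, 2 * hid_dim)),
    torch.empty((batch_size, 3 * hid_dim)),
)


@register_expert_class("multihead", multihead_sample_input)
class MultiheadNetwork(nn.Module):  # hypothetical class name and layer sizes
    def __init__(self, hid_dim, num_classes=10):
        super().__init__()
        self.head1 = nn.Linear(hid_dim, num_classes)
        self.head2 = nn.Linear(2 * hid_dim, num_classes)
        self.head3 = nn.Linear(3 * hid_dim, num_classes)

    def forward(self, x1, x2, x3):
        return self.head1(x1) + self.head2(x2) + self.head3(x3)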