# test_custom_experts.py
  1. import os
  2. import pytest
  3. import torch
  4. from hivemind import RemoteExpert
  5. from hivemind.moe.server import background_server
  6. CUSTOM_EXPERTS_PATH = os.path.join(os.path.dirname(__file__), "test_utils", "custom_networks.py")
  7. @pytest.mark.forked
  8. def test_custom_expert(hid_dim=16):
  9. with background_server(
  10. expert_cls="perceptron",
  11. num_experts=2,
  12. device="cpu",
  13. hidden_dim=hid_dim,
  14. num_handlers=2,
  15. no_dht=True,
  16. custom_module_path=CUSTOM_EXPERTS_PATH,
  17. ) as (server_endpoint, _):
  18. expert0 = RemoteExpert("expert.0", server_endpoint)
  19. expert1 = RemoteExpert("expert.1", server_endpoint)
  20. for batch_size in (1, 4):
  21. batch = torch.randn(batch_size, hid_dim)
  22. output0 = expert0(batch)
  23. output1 = expert1(batch)
  24. loss = output0.sum()
  25. loss.backward()
  26. loss = output1.sum()
  27. loss.backward()
  28. @pytest.mark.forked
  29. def test_multihead_expert(hid_dim=16):
  30. with background_server(
  31. expert_cls="multihead",
  32. num_experts=2,
  33. device="cpu",
  34. hidden_dim=hid_dim,
  35. num_handlers=2,
  36. no_dht=True,
  37. custom_module_path=CUSTOM_EXPERTS_PATH,
  38. ) as (server_endpoint, _):
  39. expert0 = RemoteExpert("expert.0", server_endpoint)
  40. expert1 = RemoteExpert("expert.1", server_endpoint)
  41. for batch_size in (1, 4):
  42. batch = (
  43. torch.randn(batch_size, hid_dim),
  44. torch.randn(batch_size, 2 * hid_dim),
  45. torch.randn(batch_size, 3 * hid_dim),
  46. )
  47. output0 = expert0(*batch)
  48. output1 = expert1(*batch)
  49. loss = output0.sum()
  50. loss.backward()
  51. loss = output1.sum()
  52. loss.backward()