test_custom_experts.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. import os
  2. import pytest
  3. import torch
  4. from hivemind import RemoteExpert
  5. from hivemind.moe.server import background_server
  6. CUSTOM_EXPERTS_PATH = os.path.join(os.path.dirname(__file__), 'test_utils', 'custom_networks.py')
  7. @pytest.mark.forked
  8. def test_custom_expert(hid_dim=16):
  9. with background_server(
  10. expert_cls='perceptron', num_experts=2, device='cpu',
  11. hidden_dim=hid_dim, num_handlers=2, no_dht=True,
  12. custom_module_path=CUSTOM_EXPERTS_PATH) as (server_endpoint, _):
  13. expert0 = RemoteExpert('expert.0', server_endpoint)
  14. expert1 = RemoteExpert('expert.1', server_endpoint)
  15. for batch_size in (1, 4):
  16. batch = torch.randn(batch_size, hid_dim)
  17. output0 = expert0(batch)
  18. output1 = expert1(batch)
  19. loss = output0.sum()
  20. loss.backward()
  21. loss = output1.sum()
  22. loss.backward()
  23. @pytest.mark.forked
  24. def test_multihead_expert(hid_dim=16):
  25. with background_server(
  26. expert_cls='multihead', num_experts=2, device='cpu',
  27. hidden_dim=hid_dim, num_handlers=2, no_dht=True,
  28. custom_module_path=CUSTOM_EXPERTS_PATH) as (server_endpoint, _):
  29. expert0 = RemoteExpert('expert.0', server_endpoint)
  30. expert1 = RemoteExpert('expert.1', server_endpoint)
  31. for batch_size in (1, 4):
  32. batch = (torch.randn(batch_size, hid_dim), torch.randn(batch_size, 2 * hid_dim),
  33. torch.randn(batch_size, 3 * hid_dim))
  34. output0 = expert0(*batch)
  35. output1 = expert1(*batch)
  36. loss = output0.sum()
  37. loss.backward()
  38. loss = output1.sum()
  39. loss.backward()