# test_custom_experts.py
  1. import os
  2. import pytest
  3. import torch
  4. from hivemind import RemoteExpert, background_server
  5. CUSTOM_EXPERTS_PATH = os.path.join(os.path.dirname(__file__), 'test_utils', 'custom_networks.py')
  6. @pytest.mark.forked
  7. def test_custom_expert(hid_dim=16):
  8. with background_server(
  9. expert_cls='perceptron', num_experts=2, device='cpu',
  10. hidden_dim=hid_dim, num_handlers=2, no_dht=True,
  11. custom_module_path=CUSTOM_EXPERTS_PATH) as (server_endpoint, _):
  12. expert0 = RemoteExpert('expert.0', server_endpoint)
  13. expert1 = RemoteExpert('expert.1', server_endpoint)
  14. for batch_size in (1, 4):
  15. batch = torch.randn(batch_size, hid_dim)
  16. output0 = expert0(batch)
  17. output1 = expert1(batch)
  18. loss = output0.sum()
  19. loss.backward()
  20. loss = output1.sum()
  21. loss.backward()
  22. @pytest.mark.forked
  23. def test_multihead_expert(hid_dim=16):
  24. with background_server(
  25. expert_cls='multihead', num_experts=2, device='cpu',
  26. hidden_dim=hid_dim, num_handlers=2, no_dht=True,
  27. custom_module_path=CUSTOM_EXPERTS_PATH) as (server_endpoint, _):
  28. expert0 = RemoteExpert('expert.0', server_endpoint)
  29. expert1 = RemoteExpert('expert.1', server_endpoint)
  30. for batch_size in (1, 4):
  31. batch = (torch.randn(batch_size, hid_dim), torch.randn(batch_size, 2 * hid_dim),
  32. torch.randn(batch_size, 3 * hid_dim))
  33. output0 = expert0(*batch)
  34. output1 = expert1(*batch)
  35. loss = output0.sum()
  36. loss.backward()
  37. loss = output1.sum()
  38. loss.backward()