# test_custom_expert.py — tests for serving custom expert modules via hivemind
import os

import pytest
import torch

from hivemind import RemoteExpert, background_server
  5. @pytest.mark.forked
  6. def test_custom_expert(hid_dim=16):
  7. with background_server(
  8. expert_cls='perceptron', num_experts=2, device='cpu',
  9. hidden_dim=hid_dim, num_handlers=2, no_dht=True,
  10. custom_module_path=os.path.join(os.path.dirname(__file__), 'custom_networks.py')) as (server_endpoint, _):
  11. expert0 = RemoteExpert('expert.0', server_endpoint)
  12. expert1 = RemoteExpert('expert.1', server_endpoint)
  13. for batch_size in (1, 4):
  14. batch = torch.randn(batch_size, hid_dim)
  15. output0 = expert0(batch)
  16. output1 = expert1(batch)
  17. loss = output0.sum()
  18. loss.backward()
  19. loss = output1.sum()
  20. loss.backward()
  21. @pytest.mark.forked
  22. def test_multihead_expert(hid_dim=16):
  23. with background_server(
  24. expert_cls='multihead', num_experts=2, device='cpu',
  25. hidden_dim=hid_dim, num_handlers=2, no_dht=True,
  26. custom_module_path=os.path.join(os.path.dirname(__file__), 'custom_networks.py')) as (server_endpoint, _):
  27. expert0 = RemoteExpert('expert.0', server_endpoint)
  28. expert1 = RemoteExpert('expert.1', server_endpoint)
  29. for batch_size in (1, 4):
  30. batch = (torch.randn(batch_size, hid_dim), torch.randn(batch_size, 2 * hid_dim),
  31. torch.randn(batch_size, 3 * hid_dim))
  32. output0 = expert0(*batch)
  33. output1 = expert1(*batch)
  34. loss = output0.sum()
  35. loss.backward()
  36. loss = output1.sum()
  37. loss.backward()