test_custom_expert.py

import os
from typing import Optional

import pytest
import torch

from hivemind import RemoteExpert, background_server
@pytest.mark.forked
def test_custom_expert(port: Optional[int] = None, hid_dim=16):
    # Spawn a local server hosting two instances of the custom 'perceptron'
    # expert class defined in custom_networks.py, without a DHT.
    with background_server(
            expert_cls='perceptron', num_experts=2, device='cpu',
            hidden_dim=hid_dim, num_handlers=2, no_dht=True,
            custom_module_path=os.path.join(os.path.dirname(__file__), 'custom_networks.py')) as (server_endpoint, _):
        expert0 = RemoteExpert('expert.0', server_endpoint)
        expert1 = RemoteExpert('expert.1', server_endpoint)

        for batch_size in (1, 4):
            batch = torch.randn(batch_size, hid_dim)

            output0 = expert0(batch)
            output1 = expert1(batch)

            # Check that gradients propagate back through the remote experts.
            loss = output0.sum()
            loss.backward()
            loss = output1.sum()
            loss.backward()
@pytest.mark.forked
def test_multihead_expert(port: Optional[int] = None, hid_dim=16):
    # Same setup, but the 'multihead' expert class takes three input tensors
    # of different widths.
    with background_server(
            expert_cls='multihead', num_experts=2, device='cpu',
            hidden_dim=hid_dim, num_handlers=2, no_dht=True,
            custom_module_path=os.path.join(os.path.dirname(__file__), 'custom_networks.py')) as (server_endpoint, _):
        expert0 = RemoteExpert('expert.0', server_endpoint)
        expert1 = RemoteExpert('expert.1', server_endpoint)

        for batch_size in (1, 4):
            batch = (torch.randn(batch_size, hid_dim),
                     torch.randn(batch_size, 2 * hid_dim),
                     torch.randn(batch_size, 3 * hid_dim))

            output0 = expert0(*batch)
            output1 = expert1(*batch)

            # Check that gradients propagate back through the remote experts.
            loss = output0.sum()
            loss.backward()
            loss = output1.sum()
            loss.backward()
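
Both tests depend on a sibling custom_networks.py that defines and registers the 'perceptron' and 'multihead' expert classes. A minimal sketch of that module follows, assuming hivemind's register_expert_class(name, sample_input) decorator; the import path, class names, and layer sizes below are assumptions for illustration and may differ from the actual module or hivemind version.

custom_networks.py (sketch)

import torch
import torch.nn as nn
import torch.nn.functional as F

# Assumed import path; it varies between hivemind versions.
from hivemind.moe.server.layers.custom_experts import register_expert_class

# sample_input tells the server what input shapes the expert expects,
# given (batch_size, hidden_dim).
sample_input = lambda batch_size, hid_dim: torch.empty((batch_size, hid_dim))


@register_expert_class('perceptron', sample_input)
class MultilayerPerceptron(nn.Module):
    def __init__(self, hid_dim):
        super().__init__()
        self.layer1 = nn.Linear(hid_dim, 4 * hid_dim)
        self.layer2 = nn.Linear(4 * hid_dim, hid_dim)

    def forward(self, x):
        return self.layer2(F.relu(self.layer1(x)))


# Three inputs of widths hid_dim, 2 * hid_dim, 3 * hid_dim,
# matching the batch tuple built in test_multihead_expert.
multihead_sample_input = lambda batch_size, hid_dim: (
    torch.empty((batch_size, hid_dim)),
    torch.empty((batch_size, 2 * hid_dim)),
    torch.empty((batch_size, 3 * hid_dim)),
)


@register_expert_class('multihead', multihead_sample_input)
class MultiheadNetwork(nn.Module):
    def __init__(self, hid_dim):
        super().__init__()
        self.head1 = nn.Linear(hid_dim, hid_dim)
        self.head2 = nn.Linear(2 * hid_dim, hid_dim)
        self.head3 = nn.Linear(3 * hid_dim, hid_dim)

    def forward(self, x1, x2, x3):
        # Project each input to hid_dim and sum, so the output is a
        # single (batch_size, hid_dim) tensor.
        return self.head1(x1) + self.head2(x2) + self.head3(x3)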