"""Smoke tests for experts whose classes are loaded through ``custom_module_path``."""
import os

import pytest
import torch

from hivemind import RemoteExpert
from hivemind.moe.server import background_server

# Module file handed to the server via ``custom_module_path``; it lives in
# the ``test_utils`` directory next to this test file.
CUSTOM_EXPERTS_PATH = os.path.join(os.path.dirname(__file__), "test_utils", "custom_networks.py")


@pytest.mark.forked
def test_custom_expert(hid_dim=16):
    """Run forward and backward passes through two remote 'perceptron' experts.

    A background server is started with ``expert_cls='perceptron'`` and the
    custom module path. For batch sizes 1 and 4, a random
    ``(batch_size, hid_dim)`` tensor is sent to both experts, then each
    summed output is back-propagated. The test makes no value assertions;
    it passes as long as none of these calls raises.

    Args:
        hid_dim: value passed to the server as ``hidden_dim`` and used as the
            width of the random input batch.
    """
    server_options = dict(
        expert_cls="perceptron",
        num_experts=2,
        device="cpu",
        hidden_dim=hid_dim,
        num_handlers=2,
        no_dht=True,
        custom_module_path=CUSTOM_EXPERTS_PATH,
    )

    with background_server(**server_options) as (server_endpoint, _):
        remote_experts = [RemoteExpert(uid, server_endpoint) for uid in ("expert.0", "expert.1")]

        for batch_size in (1, 4):
            inputs = torch.randn(batch_size, hid_dim)

            # Both forward passes are issued before any backward pass.
            outputs = [expert(inputs) for expert in remote_experts]

            for output in outputs:
                output.sum().backward()


@pytest.mark.forked
def test_multihead_expert(hid_dim=16):
    """Run forward and backward passes through two remote 'multihead' experts.

    A background server is started with ``expert_cls='multihead'`` and the
    custom module path. For batch sizes 1 and 4, three random tensors of
    widths ``hid_dim``, ``2 * hid_dim`` and ``3 * hid_dim`` are passed to
    both experts as separate positional inputs, then each summed output is
    back-propagated. The test makes no value assertions; it passes as long
    as none of these calls raises.

    Args:
        hid_dim: value passed to the server as ``hidden_dim`` and used as the
            base width of the random inputs.
    """
    server_options = dict(
        expert_cls="multihead",
        num_experts=2,
        device="cpu",
        hidden_dim=hid_dim,
        num_handlers=2,
        no_dht=True,
        custom_module_path=CUSTOM_EXPERTS_PATH,
    )

    with background_server(**server_options) as (server_endpoint, _):
        remote_experts = [RemoteExpert(uid, server_endpoint) for uid in ("expert.0", "expert.1")]

        for batch_size in (1, 4):
            # Three inputs, generated in order with widths 1x, 2x and 3x hid_dim.
            inputs = tuple(torch.randn(batch_size, scale * hid_dim) for scale in (1, 2, 3))

            # Both forward passes are issued before any backward pass.
            outputs = [expert(*inputs) for expert in remote_experts]

            for output in outputs:
                output.sum().backward()