import os

import pytest
import torch

from hivemind import RemoteExpert
from hivemind.moe.server import background_server

# Path to the module that defines the custom expert classes ('perceptron',
# 'multihead') loaded by the server via custom_module_path.
CUSTOM_EXPERTS_PATH = os.path.join(os.path.dirname(__file__), 'test_utils', 'custom_networks.py')
@pytest.mark.forked
def test_custom_expert(hid_dim=16):
    """Forward/backward through a custom single-input expert served by background_server.

    Spins up a server hosting two 'perceptron' experts (loaded from
    CUSTOM_EXPERTS_PATH), then checks that remote forward and backward
    passes complete for batch sizes 1 and 4.
    """
    with background_server(
            expert_cls='perceptron', num_experts=2, device='cpu',
            hidden_dim=hid_dim, num_handlers=2, no_dht=True,
            custom_module_path=CUSTOM_EXPERTS_PATH) as (server_endpoint, _):
        expert0 = RemoteExpert('expert.0', server_endpoint)
        expert1 = RemoteExpert('expert.1', server_endpoint)

        for batch_size in (1, 4):
            batch = torch.randn(batch_size, hid_dim)

            output0 = expert0(batch)
            output1 = expert1(batch)

            # Backprop through each remote expert independently to exercise
            # the server's backward handler.
            loss = output0.sum()
            loss.backward()
            loss = output1.sum()
            loss.backward()
@pytest.mark.forked
def test_multihead_expert(hid_dim=16):
    """Forward/backward through a custom multi-input expert served by background_server.

    Spins up a server hosting two 'multihead' experts (loaded from
    CUSTOM_EXPERTS_PATH); each expert takes three inputs of widths
    hid_dim, 2*hid_dim, and 3*hid_dim. Checks that remote forward and
    backward passes complete for batch sizes 1 and 4.
    """
    with background_server(
            expert_cls='multihead', num_experts=2, device='cpu',
            hidden_dim=hid_dim, num_handlers=2, no_dht=True,
            custom_module_path=CUSTOM_EXPERTS_PATH) as (server_endpoint, _):
        expert0 = RemoteExpert('expert.0', server_endpoint)
        expert1 = RemoteExpert('expert.1', server_endpoint)

        for batch_size in (1, 4):
            # Three heads of increasing width: hid_dim, 2*hid_dim, 3*hid_dim.
            batch = (torch.randn(batch_size, hid_dim), torch.randn(batch_size, 2 * hid_dim),
                     torch.randn(batch_size, 3 * hid_dim))

            output0 = expert0(*batch)
            output1 = expert1(*batch)

            # Backprop through each remote expert independently to exercise
            # the server's backward handler.
            loss = output0.sum()
            loss.backward()
            loss = output1.sum()
            loss.backward()