import os

import pytest
import torch

from hivemind import RemoteExpert
from hivemind.moe.server import background_server

# Path to the module that defines the custom expert classes ("perceptron",
# "multihead") which the test server loads via its custom_module_path option.
CUSTOM_EXPERTS_PATH = os.path.join(os.path.dirname(__file__), "test_utils", "custom_networks.py")
@pytest.mark.forked
def test_custom_expert(hid_dim=16):
    """Run forward and backward passes through two remote "perceptron" experts.

    Starts an in-process server (no DHT) that loads the custom expert class
    from CUSTOM_EXPERTS_PATH, then checks that single-tensor inputs of several
    batch sizes flow through both experts and that gradients propagate.
    """
    with background_server(
        expert_cls="perceptron",
        num_experts=2,
        device="cpu",
        hidden_dim=hid_dim,
        num_handlers=2,
        no_dht=True,
        custom_module_path=CUSTOM_EXPERTS_PATH,
    ) as (server_endpoint, _):
        expert0 = RemoteExpert("expert.0", server_endpoint)
        expert1 = RemoteExpert("expert.1", server_endpoint)

        for batch_size in (1, 4):
            batch = torch.randn(batch_size, hid_dim)

            output0 = expert0(batch)
            output1 = expert1(batch)

            # Backprop through each expert's output to exercise the remote
            # backward pass as well as the forward pass.
            loss = output0.sum()
            loss.backward()
            loss = output1.sum()
            loss.backward()
@pytest.mark.forked
def test_multihead_expert(hid_dim=16):
    """Run forward and backward passes through two remote "multihead" experts.

    Like test_custom_expert, but each expert call takes three input tensors of
    widths hid_dim, 2*hid_dim and 3*hid_dim, exercising multi-argument
    serialization between client and server.
    """
    with background_server(
        expert_cls="multihead",
        num_experts=2,
        device="cpu",
        hidden_dim=hid_dim,
        num_handlers=2,
        no_dht=True,
        custom_module_path=CUSTOM_EXPERTS_PATH,
    ) as (server_endpoint, _):
        expert0 = RemoteExpert("expert.0", server_endpoint)
        expert1 = RemoteExpert("expert.1", server_endpoint)

        for batch_size in (1, 4):
            batch = (
                torch.randn(batch_size, hid_dim),
                torch.randn(batch_size, 2 * hid_dim),
                torch.randn(batch_size, 3 * hid_dim),
            )

            output0 = expert0(*batch)
            output1 = expert1(*batch)

            # Backprop through each expert's output to exercise the remote
            # backward pass as well as the forward pass.
            loss = output0.sum()
            loss.backward()
            loss = output1.sum()
            loss.backward()