|
@@ -1,13 +1,16 @@
|
|
-import grpc
|
|
|
|
|
|
+import time
|
|
|
|
+
|
|
import numpy as np
|
|
import numpy as np
|
|
import pytest
|
|
import pytest
|
|
import torch
|
|
import torch
|
|
|
|
|
|
from hivemind.dht import DHT
|
|
from hivemind.dht import DHT
|
|
-from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
|
|
|
|
-from hivemind.moe.client.moe import DUMMY, _RemoteCallMany
|
|
|
|
|
|
+from hivemind.moe.client.expert import RemoteExpert, RemoteExpertInfo, RemoteExpertWorker
|
|
|
|
+from hivemind.moe.client.moe import DUMMY, RemoteMixtureOfExperts, _RemoteCallMany
|
|
|
|
+from hivemind.moe.client.switch_moe import RemoteSwitchMixtureOfExperts
|
|
from hivemind.moe.server import ExpertBackend, Server, background_server, declare_experts
|
|
from hivemind.moe.server import ExpertBackend, Server, background_server, declare_experts
|
|
from hivemind.moe.server.layers import name_to_block
|
|
from hivemind.moe.server.layers import name_to_block
|
|
|
|
+from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError
|
|
from hivemind.utils.tensor_descr import BatchTensorDescriptor
|
|
from hivemind.utils.tensor_descr import BatchTensorDescriptor
|
|
|
|
|
|
|
|
|
|
@@ -18,8 +21,8 @@ def test_moe():
|
|
]
|
|
]
|
|
with background_server(
|
|
with background_server(
|
|
expert_uids=all_expert_uids, device="cpu", expert_cls="ffn", num_handlers=1, hidden_dim=16
|
|
expert_uids=all_expert_uids, device="cpu", expert_cls="ffn", num_handlers=1, hidden_dim=16
|
|
- ) as (server_endpoint, dht_maddrs):
|
|
|
|
- dht = DHT(start=True, initial_peers=dht_maddrs)
|
|
|
|
|
|
+ ) as server_peer_info:
|
|
|
|
+ dht = DHT(start=True, initial_peers=server_peer_info.addrs)
|
|
|
|
|
|
dmoe = RemoteMixtureOfExperts(in_features=16, grid_size=(4, 4, 4), dht=dht, k_best=3, uid_prefix="ffn.")
|
|
dmoe = RemoteMixtureOfExperts(in_features=16, grid_size=(4, 4, 4), dht=dht, k_best=3, uid_prefix="ffn.")
|
|
|
|
|
|
@@ -35,9 +38,8 @@ def test_no_experts():
|
|
]
|
|
]
|
|
with background_server(
|
|
with background_server(
|
|
expert_uids=all_expert_uids, device="cpu", expert_cls="nop_delay", num_handlers=1, hidden_dim=16
|
|
expert_uids=all_expert_uids, device="cpu", expert_cls="nop_delay", num_handlers=1, hidden_dim=16
|
|
- ) as (server_endpoint, dht_maddrs):
|
|
|
|
- dht = DHT(start=True, initial_peers=dht_maddrs)
|
|
|
|
-
|
|
|
|
|
|
+ ) as server_peer_info:
|
|
|
|
+ dht = DHT(start=True, initial_peers=server_peer_info.addrs)
|
|
dmoe = RemoteSwitchMixtureOfExperts(
|
|
dmoe = RemoteSwitchMixtureOfExperts(
|
|
in_features=16,
|
|
in_features=16,
|
|
grid_size=(4, 4, 4),
|
|
grid_size=(4, 4, 4),
|
|
@@ -71,12 +73,16 @@ def test_call_many(hidden_dim=16):
|
|
num_handlers=1,
|
|
num_handlers=1,
|
|
hidden_dim=hidden_dim,
|
|
hidden_dim=hidden_dim,
|
|
optim_cls=None,
|
|
optim_cls=None,
|
|
- no_dht=True,
|
|
|
|
- ) as (server_endpoint, _):
|
|
|
|
|
|
+ ) as server_peer_info:
|
|
inputs = torch.randn(4, hidden_dim, requires_grad=True)
|
|
inputs = torch.randn(4, hidden_dim, requires_grad=True)
|
|
inputs_clone = inputs.clone().detach().requires_grad_(True)
|
|
inputs_clone = inputs.clone().detach().requires_grad_(True)
|
|
- e0, e1, e2, e3, e4 = [RemoteExpert(f"expert.{i}", server_endpoint) for i in range(5)]
|
|
|
|
- e5 = RemoteExpert(f"thisshouldnotexist", "127.0.0.1:80")
|
|
|
|
|
|
+
|
|
|
|
+ dht = DHT(initial_peers=server_peer_info.addrs, start=True)
|
|
|
|
+ e0, e1, e2, e3, e4 = RemoteExpertWorker.spawn_experts(
|
|
|
|
+ [RemoteExpertInfo(uid=f"expert.{i}", peer_info=server_peer_info) for i in range(5)],
|
|
|
|
+ dht,
|
|
|
|
+ )
|
|
|
|
+ e5 = RemoteExpert(RemoteExpertInfo(f"thisshouldnotexist", server_peer_info), None)
|
|
|
|
|
|
mask, expert_outputs = _RemoteCallMany.apply(
|
|
mask, expert_outputs = _RemoteCallMany.apply(
|
|
DUMMY,
|
|
DUMMY,
|
|
@@ -129,11 +135,15 @@ def test_remote_module_call(hidden_dim=16):
|
|
num_handlers=1,
|
|
num_handlers=1,
|
|
hidden_dim=hidden_dim,
|
|
hidden_dim=hidden_dim,
|
|
optim_cls=None,
|
|
optim_cls=None,
|
|
- no_dht=True,
|
|
|
|
- ) as (server_endpoint, _):
|
|
|
|
- real_expert = RemoteExpert("expert.0", server_endpoint)
|
|
|
|
- fake_expert = RemoteExpert("oiasfjiasjf", server_endpoint)
|
|
|
|
-
|
|
|
|
|
|
+ ) as server_peer_info:
|
|
|
|
+ dht = DHT(initial_peers=server_peer_info.addrs, start=True)
|
|
|
|
+ real_expert, fake_expert = RemoteExpertWorker.spawn_experts(
|
|
|
|
+ [
|
|
|
|
+ RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info),
|
|
|
|
+ RemoteExpertInfo(uid="oiasfjiasjf", peer_info=server_peer_info),
|
|
|
|
+ ],
|
|
|
|
+ dht=dht,
|
|
|
|
+ )
|
|
out1 = real_expert(torch.randn(1, hidden_dim))
|
|
out1 = real_expert(torch.randn(1, hidden_dim))
|
|
assert out1.shape == (1, hidden_dim)
|
|
assert out1.shape == (1, hidden_dim)
|
|
dummy_x = torch.randn(3, hidden_dim, requires_grad=True)
|
|
dummy_x = torch.randn(3, hidden_dim, requires_grad=True)
|
|
@@ -144,9 +154,9 @@ def test_remote_module_call(hidden_dim=16):
|
|
out3_again.norm().backward()
|
|
out3_again.norm().backward()
|
|
assert dummy_x.grad is not None and dummy_x.grad.norm() > 0
|
|
assert dummy_x.grad is not None and dummy_x.grad.norm() > 0
|
|
|
|
|
|
- with pytest.raises(grpc.RpcError):
|
|
|
|
|
|
+ with pytest.raises(P2PDaemonError):
|
|
real_expert(torch.randn(3, 11))
|
|
real_expert(torch.randn(3, 11))
|
|
- with pytest.raises(grpc.RpcError):
|
|
|
|
|
|
+ with pytest.raises(P2PDaemonError):
|
|
fake_expert(dummy_x)
|
|
fake_expert(dummy_x)
|
|
|
|
|
|
|
|
|
|
@@ -154,11 +164,11 @@ def test_remote_module_call(hidden_dim=16):
|
|
def test_beam_search_correctness():
|
|
def test_beam_search_correctness():
|
|
all_expert_uids = [f"ffn.{5 + i}.{10 + j}.{15 + k}" for i in range(10) for j in range(10) for k in range(10)]
|
|
all_expert_uids = [f"ffn.{5 + i}.{10 + j}.{15 + k}" for i in range(10) for j in range(10) for k in range(10)]
|
|
dht = DHT(start=True)
|
|
dht = DHT(start=True)
|
|
- assert all(declare_experts(dht, all_expert_uids, endpoint="fake-endpoint"))
|
|
|
|
|
|
+ assert all(declare_experts(dht, all_expert_uids, dht.peer_id))
|
|
|
|
|
|
dmoe = RemoteMixtureOfExperts(in_features=32, grid_size=(32, 32, 32), dht=dht, k_best=4, uid_prefix="ffn.")
|
|
dmoe = RemoteMixtureOfExperts(in_features=32, grid_size=(32, 32, 32), dht=dht, k_best=4, uid_prefix="ffn.")
|
|
|
|
|
|
- for i in range(25):
|
|
|
|
|
|
+ for _ in range(25):
|
|
input = torch.randn(32)
|
|
input = torch.randn(32)
|
|
grid_scores = dmoe.proj(input).split_with_sizes(dmoe.beam_search.grid_size, dim=-1)
|
|
grid_scores = dmoe.proj(input).split_with_sizes(dmoe.beam_search.grid_size, dim=-1)
|
|
|
|
|
|
@@ -173,7 +183,7 @@ def test_beam_search_correctness():
|
|
# reference: independently find :beam_size: best experts with exhaustive search
|
|
# reference: independently find :beam_size: best experts with exhaustive search
|
|
all_scores = dmoe.compute_expert_scores(
|
|
all_scores = dmoe.compute_expert_scores(
|
|
[dim_scores.unsqueeze(0) for dim_scores in grid_scores],
|
|
[dim_scores.unsqueeze(0) for dim_scores in grid_scores],
|
|
- [[RemoteExpert(uid, "") for uid in all_expert_uids]],
|
|
|
|
|
|
+ [[RemoteExpert(RemoteExpertInfo(uid, None), None) for uid in all_expert_uids]],
|
|
)[0]
|
|
)[0]
|
|
true_best_scores = sorted(all_scores.cpu().detach().numpy(), reverse=True)[: len(chosen_experts)]
|
|
true_best_scores = sorted(all_scores.cpu().detach().numpy(), reverse=True)[: len(chosen_experts)]
|
|
|
|
|
|
@@ -194,9 +204,12 @@ def test_determinism(hidden_dim=16):
|
|
num_handlers=1,
|
|
num_handlers=1,
|
|
hidden_dim=hidden_dim,
|
|
hidden_dim=hidden_dim,
|
|
optim_cls=None,
|
|
optim_cls=None,
|
|
- no_dht=True,
|
|
|
|
- ) as (server_endpoint, _):
|
|
|
|
- expert = RemoteExpert(uid=f"expert.0", endpoint=server_endpoint)
|
|
|
|
|
|
+ ) as server_peer_info:
|
|
|
|
+ dht = DHT(initial_peers=server_peer_info.addrs, start=True)
|
|
|
|
+ expert = RemoteExpertWorker.spawn_experts(
|
|
|
|
+ [RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info)],
|
|
|
|
+ dht=dht,
|
|
|
|
+ )[0]
|
|
|
|
|
|
out = expert(xx, mask)
|
|
out = expert(xx, mask)
|
|
out_rerun = expert(xx, mask)
|
|
out_rerun = expert(xx, mask)
|
|
@@ -220,7 +233,7 @@ def test_compute_expert_scores():
|
|
jj = [[2, 2, 1], [0, 1, 2, 0, 1], [0], [1, 2]]
|
|
jj = [[2, 2, 1], [0, 1, 2, 0, 1], [0], [1, 2]]
|
|
batch_experts = [
|
|
batch_experts = [
|
|
[
|
|
[
|
|
- RemoteExpert(uid=f"expert.{ii[batch_i][expert_i]}.{jj[batch_i][expert_i]}", endpoint="[::]:1337")
|
|
|
|
|
|
+ RemoteExpert(RemoteExpertInfo(f"expert.{ii[batch_i][expert_i]}.{jj[batch_i][expert_i]}", None), None)
|
|
for expert_i in range(len(ii[batch_i]))
|
|
for expert_i in range(len(ii[batch_i]))
|
|
]
|
|
]
|
|
for batch_i in range(len(ii))
|
|
for batch_i in range(len(ii))
|
|
@@ -261,9 +274,10 @@ def test_client_anomaly_detection():
|
|
server.start()
|
|
server.start()
|
|
try:
|
|
try:
|
|
server.ready.wait()
|
|
server.ready.wait()
|
|
|
|
+ dht_experts = DHT(initial_peers=dht.get_visible_maddrs(), start=True)
|
|
|
|
|
|
dmoe = RemoteMixtureOfExperts(
|
|
dmoe = RemoteMixtureOfExperts(
|
|
- in_features=16, grid_size=(3,), dht=dht, k_best=3, uid_prefix="expert.", detect_anomalies=True
|
|
|
|
|
|
+ in_features=16, grid_size=(3,), dht=dht_experts, k_best=3, uid_prefix="expert.", detect_anomalies=True
|
|
)
|
|
)
|
|
|
|
|
|
input = torch.randn(1, 16)
|
|
input = torch.randn(1, 16)
|
|
@@ -280,7 +294,7 @@ def test_client_anomaly_detection():
|
|
inf_loss.backward()
|
|
inf_loss.backward()
|
|
|
|
|
|
dmoe = RemoteMixtureOfExperts(
|
|
dmoe = RemoteMixtureOfExperts(
|
|
- in_features=16, grid_size=(4,), dht=dht, k_best=4, uid_prefix="expert.", detect_anomalies=True
|
|
|
|
|
|
+ in_features=16, grid_size=(4,), dht=dht_experts, k_best=4, uid_prefix="expert.", detect_anomalies=True
|
|
)
|
|
)
|
|
output = dmoe(input)
|
|
output = dmoe(input)
|
|
assert output.isfinite().all()
|
|
assert output.isfinite().all()
|