@@ -6,7 +6,9 @@ from typing import Any, Dict
 import pytest
 import torch
 
-import hivemind
+import logging
+
+from hivemind.dht import DHT
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.moe.server.expert_backend import ExpertBackend
@@ -17,43 +19,101 @@ from hivemind.utils.asyncio import amap_in_executor, iter_as_aiter
 from hivemind.utils.serializer import MSGPackSerializer
 from hivemind.utils.streaming import combine_and_deserialize_from_streaming, split_for_streaming
 from hivemind.utils.tensor_descr import BatchTensorDescriptor
+from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
 
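+# large enough that one serialized tensor spans several DEFAULT_MAX_MSG_SIZE messages and must be streamed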
+LONG_INPUT_SIZE = 2**21
 
-@pytest.mark.forked
-@pytest.mark.asyncio
-async def test_connection_handler():
-    handler = ConnectionHandler(
-        hivemind.DHT(start=True),
-        dict(expert1=DummyExpertBackend("expert1", k=1), expert2=DummyExpertBackend("expert2", k=2)),
-    )
+
+class DummyPool(TaskPool):
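+    """Minimal TaskPool stand-in: returns inputs[0] * k after a short async delay, rejecting last dims != 2."""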
+    def __init__(self, k: float):
+        self.k = k
+
+    async def submit_task(self, *inputs: torch.Tensor):
+        await asyncio.sleep(0.01)
+        if inputs[0].shape[-1] != 2:
+            raise ValueError("wrong input shape")
+        return [inputs[0] * self.k]
+
+
+class DummyExpertBackend(ExpertBackend):
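+    """ExpertBackend stub whose forward and backward pools scale tensors by k instead of running a real module."""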
+    def __init__(self, name: str, k: float):
+        self.name = name
+        self.outputs_schema = [BatchTensorDescriptor.from_tensor(torch.randn(1, 2))]
+        self.grad_inputs_schema = [BatchTensorDescriptor.from_tensor(torch.randn(1, 2))]
+        self.forward_pool = DummyPool(k)
+        self.backward_pool = DummyPool(k)
+
+    def get_info(self) -> Dict[str, Any]:
+        """Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""
+        return dict(name=self.name)
+
+
+@pytest.fixture()
+async def stub():
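+    """Start a ConnectionHandler over two dummy experts, yield a client stub connected to it, then tear down."""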
+    server_dht = DHT(start=True)
+    experts = {
+        "expert1": DummyExpertBackend("expert1", k=1),
+        "expert2": DummyExpertBackend("expert2", k=2),
+    }
+    handler = ConnectionHandler(server_dht, experts)
     handler.start()
 
-    client_dht = hivemind.DHT(start=True, client_mode=True, initial_peers=handler.dht.get_visible_maddrs())
-    client_stub = ConnectionHandler.get_stub(await client_dht.replicate_p2p(), handler.dht.peer_id)
+    client_dht = DHT(start=True, client_mode=True, initial_peers=server_dht.get_visible_maddrs())
+    p2p = await client_dht.replicate_p2p()
+    client_stub = ConnectionHandler.get_stub(p2p, server_dht.peer_id)
+    yield client_stub
+
+    handler.terminate()
+    handler.join()
 
-    inputs = torch.randn(1, 2)
-    inputs_long = torch.randn(2**21, 2)
 
-    # forward unary
-    response = await client_stub.rpc_forward(
-        runtime_pb2.ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(inputs)])
+@pytest.fixture
+def small_input():
+    return torch.randn(1, 2)
+
+
+@pytest.fixture
+def long_input():
+    input = torch.randn(LONG_INPUT_SIZE, 2)
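+    # expected number of stream chunks: ceil(raw tensor bytes / DEFAULT_MAX_MSG_SIZE)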
+    n_chunks = (input.nelement() * input.element_size() + DEFAULT_MAX_MSG_SIZE - 1) // DEFAULT_MAX_MSG_SIZE
+    return input, n_chunks
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_forward_unary(stub, small_input):
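+    # expert1 (k=1) returns the input unchanged, expert2 (k=2) doubles it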
+    response = await stub.rpc_forward(
+        runtime_pb2.ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(small_input)])
     )
     outputs = deserialize_torch_tensor(response.tensors[0])
     assert len(response.tensors) == 1
-    assert torch.allclose(outputs, inputs * 1)
+    assert torch.allclose(outputs, small_input * 1)
 
-    response = await client_stub.rpc_forward(
-        runtime_pb2.ExpertRequest(uid="expert2", tensors=[serialize_torch_tensor(inputs)])
+    response = await stub.rpc_forward(
+        runtime_pb2.ExpertRequest(uid="expert2", tensors=[serialize_torch_tensor(small_input)])
     )
     outputs = deserialize_torch_tensor(response.tensors[0])
     assert len(response.tensors) == 1
-    assert torch.allclose(outputs, inputs * 2)
+    assert torch.allclose(outputs, small_input * 2)
+
 
-    # forward streaming
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_forward_streaming(stub, long_input):
+    input, n_chunks = long_input
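+    # the request is split client-side into DEFAULT_MAX_MSG_SIZE chunks; the response should arrive in n_chunks parts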
     split = (
-        p for t in [serialize_torch_tensor(inputs_long)] for p in split_for_streaming(t, chunk_size_bytes=2**16)
+        p
+        for t in [serialize_torch_tensor(input)]
+        for p in split_for_streaming(t, chunk_size_bytes=DEFAULT_MAX_MSG_SIZE)
     )
-    output_generator = await client_stub.rpc_forward_stream(
+    output_generator = await stub.rpc_forward_stream(
         amap_in_executor(
             lambda tensor_part: runtime_pb2.ExpertRequest(uid="expert2", tensors=[tensor_part]),
             iter_as_aiter(split),
@@ -61,44 +121,55 @@ async def test_connection_handler():
     )
     outputs_list = [part async for part in output_generator]
     del output_generator
-    assert len(outputs_list) == 8  # message size divided by DEFAULT_MAX_MSG_SIZE
+    assert len(outputs_list) == n_chunks
 
     results = await combine_and_deserialize_from_streaming(
         amap_in_executor(lambda r: r.tensors, iter_as_aiter(outputs_list)), deserialize_torch_tensor
     )
     assert len(results) == 1
-    assert torch.allclose(results[0], inputs_long * 2)
+    assert torch.allclose(results[0], input * 2)
 
-    # forward errors
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_forward_errors(stub, small_input):
+    # no such expert: fails with P2PHandlerError KeyError('expert3')
     with pytest.raises(P2PHandlerError):
-        # no such expert: fails with P2PHandlerError KeyError('expert3')
-        await client_stub.rpc_forward(
-            runtime_pb2.ExpertRequest(uid="expert3", tensors=[serialize_torch_tensor(inputs)])
+        await stub.rpc_forward(
+            runtime_pb2.ExpertRequest(uid="expert3", tensors=[serialize_torch_tensor(small_input)])
         )
 
+    # bad input shape: P2PHandlerError wrapping the ValueError raised by DummyPool.submit_task
     with pytest.raises(P2PHandlerError):
-        # bad input shape: P2PHandlerError("AssertionError") raised by DummyPool.submit_task
-        await client_stub.rpc_forward(
+        await stub.rpc_forward(
             runtime_pb2.ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(torch.arange(5))])
         )
 
-    # backward unary
-    response = await client_stub.rpc_backward(
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_backward_unary(stub, small_input):
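+    # DummyPool returns its first input scaled by k, so expert2 (k=2) maps small_input * -1 to small_input * -2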
+    response = await stub.rpc_backward(
         runtime_pb2.ExpertRequest(
-            uid="expert2", tensors=[serialize_torch_tensor(inputs * -1), serialize_torch_tensor(inputs)]
+            uid="expert2", tensors=[serialize_torch_tensor(small_input * -1), serialize_torch_tensor(small_input)]
         )
     )
     outputs = deserialize_torch_tensor(response.tensors[0])
     assert len(response.tensors) == 1
-    assert torch.allclose(outputs, inputs * -2)
+    assert torch.allclose(outputs, small_input * -2)
 
-    # backward streaming
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_backward_streaming(stub, long_input):
+    input, _ = long_input
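+    # stream two tensors (scaled grads and zeros); expert1 (k=1) should return the first one unchanged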
     split = (
         p
-        for t in [serialize_torch_tensor(inputs_long * 3), serialize_torch_tensor(inputs_long * 0)]
-        for p in split_for_streaming(t, chunk_size_bytes=2**16)
+        for t in [serialize_torch_tensor(input * 3), serialize_torch_tensor(input * 0)]
+        for p in split_for_streaming(t, chunk_size_bytes=DEFAULT_MAX_MSG_SIZE)
     )
-    output_generator = await client_stub.rpc_backward_stream(
+    output_generator = await stub.rpc_backward_stream(
         amap_in_executor(
             lambda tensor_part: runtime_pb2.ExpertRequest(uid="expert1", tensors=[tensor_part]),
             iter_as_aiter(split),
@@ -108,16 +179,20 @@ async def test_connection_handler():
         amap_in_executor(lambda r: r.tensors, output_generator), deserialize_torch_tensor
     )
     assert len(results) == 1
-    assert torch.allclose(results[0], inputs_long * 3)
+    assert torch.allclose(results[0], input * 3)
 
-    # backward errors
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_backward_errors(stub, small_input, long_input):
+    long, _ = long_input
+    # bad input schema: fails with P2PHandlerError IndexError('tuple index out of range')
     with pytest.raises(P2PHandlerError):
-        # bad input schema: fails with P2PHandlerError IndexError('tuple index out of range')
-        await client_stub.rpc_backward(runtime_pb2.ExpertRequest(uid="expert2", tensors=[]))
+        await stub.rpc_backward(runtime_pb2.ExpertRequest(uid="expert2", tensors=[]))
 
+    # backward fails: empty stream
     with pytest.raises(P2PHandlerError):
-        # backward fails: empty stream
-        output_generator = await client_stub.rpc_backward_stream(
+        output_generator = await stub.rpc_backward_stream(
             amap_in_executor(
                 lambda tensor_part: runtime_pb2.ExpertRequest(uid="expert2", tensors=[tensor_part]),
                 iter_as_aiter([]),
@@ -127,44 +202,20 @@ async def test_connection_handler():
             amap_in_executor(lambda r: r.tensors, output_generator), deserialize_torch_tensor
         )
         assert len(results) == 1
-        assert torch.allclose(results[0], inputs_long * 3)
+        assert torch.allclose(results[0], long * 3)
 
     # check that handler did not crash after failed request
-    await client_stub.rpc_forward(runtime_pb2.ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(inputs)]))
+    await stub.rpc_forward(runtime_pb2.ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(small_input)]))
+
 
-    # info
-    response = await client_stub.rpc_info(runtime_pb2.ExpertUID(uid="expert1"))
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_info(stub):
+    response = await stub.rpc_info(runtime_pb2.ExpertUID(uid="expert1"))
     assert MSGPackSerializer.loads(response.serialized_info) == dict(name="expert1")
 
-    response = await client_stub.rpc_info(runtime_pb2.ExpertUID(uid="expert2"))
+    response = await stub.rpc_info(runtime_pb2.ExpertUID(uid="expert2"))
     assert MSGPackSerializer.loads(response.serialized_info) == dict(name="expert2")
 
     with pytest.raises(P2PHandlerError):
-        await client_stub.rpc_info(runtime_pb2.ExpertUID(uid="expert999"))
-
-    handler.terminate()
-    handler.join()
-
-
-class DummyPool(TaskPool):
-    def __init__(self, k: float):
-        self.k = k
-
-    async def submit_task(self, *inputs: torch.Tensor):
-        await asyncio.sleep(0.01)
-        print(type(inputs), inputs[0].shape)
-        assert inputs[0].shape[-1] == 2
-        return [inputs[0] * self.k]
-
-
-class DummyExpertBackend(ExpertBackend):
-    def __init__(self, name: str, k: float):
-        self.name = name
-        self.outputs_schema = [BatchTensorDescriptor.from_tensor(torch.randn(1, 2))]
-        self.grad_inputs_schema = [BatchTensorDescriptor.from_tensor(torch.randn(1, 2))]
-        self.forward_pool = DummyPool(k)
-        self.backward_pool = DummyPool(k)
-
-    def get_info(self) -> Dict[str, Any]:
-        """Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""
-        return dict(name=self.name)
+        await stub.rpc_info(runtime_pb2.ExpertUID(uid="expert999"))