@@ -0,0 +1,177 @@
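+"""Test hivemind.moe.server.ConnectionHandler: unary and streaming forward/backward RPCs, error handling, and rpc_info."""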
+from __future__ import annotations
+
+import asyncio
+from typing import Any, Dict
+
+import pytest
+import torch
+
+import hivemind
+from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.moe.server.connection_handler import ConnectionHandler
+from hivemind.moe.server.expert_backend import ExpertBackend
+from hivemind.moe.server.task_pool import TaskPool
+from hivemind.p2p.p2p_daemon_bindings.control import P2PHandlerError
+from hivemind.proto import runtime_pb2
+from hivemind.utils.asyncio import amap_in_executor, iter_as_aiter
+from hivemind.utils.serializer import MSGPackSerializer
+from hivemind.utils.streaming import combine_and_deserialize_from_streaming, split_for_streaming
+from hivemind.utils.tensor_descr import BatchTensorDescriptor
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_connection_handler():
+    handler = ConnectionHandler(
+        hivemind.DHT(start=True),
+        dict(expert1=DummyExpertBackend("expert1", k=1), expert2=DummyExpertBackend("expert2", k=2)),
+    )
+    handler.start()
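+
+    # connect from a separate client-mode DHT peer and obtain an RPC stub for the handler's peer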
+    client_dht = hivemind.DHT(start=True, client_mode=True, initial_peers=handler.dht.get_visible_maddrs())
+    client_stub = ConnectionHandler.get_stub(await client_dht.replicate_p2p(), handler.dht.peer_id)
+
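+    # inputs fits in a single message; inputs_long (2**21 x 2 float32, 16 MiB) is large enough to require streaming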
+    inputs = torch.randn(1, 2)
+    inputs_long = torch.randn(2**21, 2)
+
+    # forward unary
+    response = await client_stub.rpc_forward(
+        runtime_pb2.ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(inputs)])
+    )
+    outputs = deserialize_torch_tensor(response.tensors[0])
+    assert len(response.tensors) == 1
+    assert torch.allclose(outputs, inputs * 1)
+
+    response = await client_stub.rpc_forward(
+        runtime_pb2.ExpertRequest(uid="expert2", tensors=[serialize_torch_tensor(inputs)])
+    )
+    outputs = deserialize_torch_tensor(response.tensors[0])
+    assert len(response.tensors) == 1
+    assert torch.allclose(outputs, inputs * 2)
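+
+    # streaming RPCs carry one protobuf message per tensor chunk: the client splits its request, the server splits its reply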
+    # forward streaming
+    split = (
+        p for t in [serialize_torch_tensor(inputs_long)] for p in split_for_streaming(t, chunk_size_bytes=2**16)
+    )
+    output_generator = await client_stub.rpc_forward_stream(
+        amap_in_executor(
+            lambda tensor_part: runtime_pb2.ExpertRequest(uid="expert2", tensors=[tensor_part]),
+            iter_as_aiter(split),
+        ),
+    )
+    outputs_list = [part async for part in output_generator]
+    del output_generator
+    assert len(outputs_list) == 8  # 16 MiB of outputs, re-chunked server-side at DEFAULT_MAX_MSG_SIZE (16 MiB / 8 = 2 MiB each)
+
+    results = await combine_and_deserialize_from_streaming(
+        amap_in_executor(lambda r: r.tensors, iter_as_aiter(outputs_list)), deserialize_torch_tensor
+    )
+    assert len(results) == 1
+    assert torch.allclose(results[0], inputs_long * 2)
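+
+    # server-side failures are not silently dropped: they surface on the client as P2PHandlerError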
+    # forward errors
+    with pytest.raises(P2PHandlerError):
+        # no such expert: fails with KeyError('expert3') wrapped in P2PHandlerError
+        await client_stub.rpc_forward(
+            runtime_pb2.ExpertRequest(uid="expert3", tensors=[serialize_torch_tensor(inputs)])
+        )
+
+    with pytest.raises(P2PHandlerError):
+        # bad input shape: P2PHandlerError("AssertionError") raised by DummyPool.submit_task
+        await client_stub.rpc_forward(
+            runtime_pb2.ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(torch.arange(5))])
+        )
+
+    # backward unary
+    response = await client_stub.rpc_backward(
+        runtime_pb2.ExpertRequest(
+            uid="expert2", tensors=[serialize_torch_tensor(inputs * -1), serialize_torch_tensor(inputs)]
+        )
+    )
+    outputs = deserialize_torch_tensor(response.tensors[0])
+    assert len(response.tensors) == 1
+    assert torch.allclose(outputs, inputs * -2)
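+    # the request above packs two tensors; DummyPool multiplies only the first one (inputs * -1) by k=2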
+
+    # backward streaming
+    split = (
+        p
+        for t in [serialize_torch_tensor(inputs_long * 3), serialize_torch_tensor(inputs_long * 0)]
+        for p in split_for_streaming(t, chunk_size_bytes=2**16)
+    )
+    output_generator = await client_stub.rpc_backward_stream(
+        amap_in_executor(
+            lambda tensor_part: runtime_pb2.ExpertRequest(uid="expert1", tensors=[tensor_part]),
+            iter_as_aiter(split),
+        ),
+    )
+    results = await combine_and_deserialize_from_streaming(
+        amap_in_executor(lambda r: r.tensors, output_generator), deserialize_torch_tensor
+    )
+    assert len(results) == 1
+    assert torch.allclose(results[0], inputs_long * 3)
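+    # the second streamed tensor (inputs_long * 0) is ignored: DummyPool only consumes inputs[0], and expert1 has k=1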
+
+    # backward errors
+    with pytest.raises(P2PHandlerError):
+        # bad input schema: fails with IndexError('tuple index out of range') wrapped in P2PHandlerError
+        await client_stub.rpc_backward(runtime_pb2.ExpertRequest(uid="expert2", tensors=[]))
+
+    with pytest.raises(P2PHandlerError):
+        # empty stream: a backward request with no tensor chunks is rejected
+        output_generator = await client_stub.rpc_backward_stream(
+            amap_in_executor(
+                lambda tensor_part: runtime_pb2.ExpertRequest(uid="expert2", tensors=[tensor_part]),
+                iter_as_aiter([]),
+            ),
+        )
+        await combine_and_deserialize_from_streaming(
+            amap_in_executor(lambda r: r.tensors, output_generator), deserialize_torch_tensor
+        )
+
+    # check that the handler did not crash after the failed requests above
+    await client_stub.rpc_forward(runtime_pb2.ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(inputs)]))
+
+    # info: rpc_info returns MessagePack-serialized metadata from ExpertBackend.get_info()
+    response = await client_stub.rpc_info(runtime_pb2.ExpertUID(uid="expert1"))
+    assert MSGPackSerializer.loads(response.serialized_info) == dict(name="expert1")
+
+    response = await client_stub.rpc_info(runtime_pb2.ExpertUID(uid="expert2"))
+    assert MSGPackSerializer.loads(response.serialized_info) == dict(name="expert2")
+
+    with pytest.raises(P2PHandlerError):
+        # unknown expert uid
+        await client_stub.rpc_info(runtime_pb2.ExpertUID(uid="expert999"))
+
+    handler.terminate()
+    handler.join()
+
+
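+# Test doubles: minimal stand-ins for TaskPool and ExpertBackend, just enough to drive ConnectionHandler.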
+class DummyPool(TaskPool):
+    def __init__(self, k: float):
+        # intentionally skips TaskPool.__init__: this test only needs submit_task
+        self.k = k
+
+    async def submit_task(self, *inputs: torch.Tensor):
+        await asyncio.sleep(0.01)  # simulate a bit of processing latency
+        assert inputs[0].shape[-1] == 2
+        return [inputs[0] * self.k]
+
+
+class DummyExpertBackend(ExpertBackend):
+    def __init__(self, name: str, k: float):
+        self.name = name
+        self.outputs_schema = [BatchTensorDescriptor.from_tensor(torch.randn(1, 2))]
+        self.grad_inputs_schema = [BatchTensorDescriptor.from_tensor(torch.randn(1, 2))]
+        self.forward_pool = DummyPool(k)
+        self.backward_pool = DummyPool(k)
+
+    def get_info(self) -> Dict[str, Any]:
+        """Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""
+        return dict(name=self.name)