Selaa lähdekoodia

test connection handler

justheuristic 3 vuotta sitten
vanhempi
commit
e1d9daef7b
1 muutettua tiedostoa jossa 170 lisäystä ja 0 poistoa
  1. 170 0
      tests/test_connection_handler.py

+ 170 - 0
tests/test_connection_handler.py

@@ -0,0 +1,170 @@
+from __future__ import annotations
+
+import asyncio
+from typing import Any, Dict
+
+import pytest
+import torch
+
+import hivemind
+from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.moe.server.connection_handler import ConnectionHandler
+from hivemind.moe.server.expert_backend import ExpertBackend
+from hivemind.moe.server.task_pool import TaskPool
+from hivemind.p2p.p2p_daemon_bindings.control import P2PHandlerError
+from hivemind.proto import runtime_pb2
+from hivemind.utils.asyncio import amap_in_executor, iter_as_aiter
+from hivemind.utils.serializer import MSGPackSerializer
+from hivemind.utils.streaming import combine_and_deserialize_from_streaming, split_for_streaming
+from hivemind.utils.tensor_descr import BatchTensorDescriptor
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_connection_handler():
+    handler = ConnectionHandler(
+        hivemind.DHT(start=True),
+        dict(expert1=DummyExpertBackend("expert1", k=1), expert2=DummyExpertBackend("expert2", k=2)),
+    )
+    handler.start()
+
+    client_dht = hivemind.DHT(start=True, client_mode=True, initial_peers=handler.dht.get_visible_maddrs())
+    client_stub = ConnectionHandler.get_stub(await client_dht.replicate_p2p(), handler.dht.peer_id)
+
+    inputs = torch.randn(1, 2)
+    inputs_long = torch.randn(2**21, 2)
+
+    # forward unary
+    response = await client_stub.rpc_forward(
+        runtime_pb2.ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(inputs)])
+    )
+    outputs = deserialize_torch_tensor(response.tensors[0])
+    assert len(response.tensors) == 1
+    assert torch.allclose(outputs, inputs * 1)
+
+    response = await client_stub.rpc_forward(
+        runtime_pb2.ExpertRequest(uid="expert2", tensors=[serialize_torch_tensor(inputs)])
+    )
+    outputs = deserialize_torch_tensor(response.tensors[0])
+    assert len(response.tensors) == 1
+    assert torch.allclose(outputs, inputs * 2)
+
+    # forward streaming
+    split = (
+        p for t in [serialize_torch_tensor(inputs_long)] for p in split_for_streaming(t, chunk_size_bytes=2**16)
+    )
+    output_generator = await client_stub.rpc_forward_stream(
+        amap_in_executor(
+            lambda tensor_part: runtime_pb2.ExpertRequest(uid="expert2", tensors=[tensor_part]),
+            iter_as_aiter(split),
+        ),
+    )
+    outputs_list = [part async for part in output_generator]
+    del output_generator
+    assert len(outputs_list) == 8  # message size divided by DEFAULT_MAX_MSG_SIZE
+
+    results = await combine_and_deserialize_from_streaming(
+        amap_in_executor(lambda r: r.tensors, iter_as_aiter(outputs_list)), deserialize_torch_tensor
+    )
+    assert len(results) == 1
+    assert torch.allclose(results[0], inputs_long * 2)
+
+    # forward errors
+    with pytest.raises(P2PHandlerError):
+        # no such expert: fails with P2PHandlerError KeyError('expert3')
+        await client_stub.rpc_forward(
+            runtime_pb2.ExpertRequest(uid="expert3", tensors=[serialize_torch_tensor(inputs)])
+        )
+
+    with pytest.raises(P2PHandlerError):
+        # bad input shape: P2PHandlerError("AssertionError") raised by DummyPool.submit_task
+        await client_stub.rpc_forward(
+            runtime_pb2.ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(torch.arange(5))])
+        )
+
+    # backward unary
+    response = await client_stub.rpc_backward(
+        runtime_pb2.ExpertRequest(
+            uid="expert2", tensors=[serialize_torch_tensor(inputs * -1), serialize_torch_tensor(inputs)]
+        )
+    )
+    outputs = deserialize_torch_tensor(response.tensors[0])
+    assert len(response.tensors) == 1
+    assert torch.allclose(outputs, inputs * -2)
+
+    # backward streaming
+    split = (
+        p
+        for t in [serialize_torch_tensor(inputs_long * 3), serialize_torch_tensor(inputs_long * 0)]
+        for p in split_for_streaming(t, chunk_size_bytes=2**16)
+    )
+    output_generator = await client_stub.rpc_backward_stream(
+        amap_in_executor(
+            lambda tensor_part: runtime_pb2.ExpertRequest(uid="expert1", tensors=[tensor_part]),
+            iter_as_aiter(split),
+        ),
+    )
+    results = await combine_and_deserialize_from_streaming(
+        amap_in_executor(lambda r: r.tensors, output_generator), deserialize_torch_tensor
+    )
+    assert len(results) == 1
+    assert torch.allclose(results[0], inputs_long * 3)
+
+    # backward errors
+    with pytest.raises(P2PHandlerError):
+        # bad input schema: fails with P2PHandlerError IndexError('tuple index out of range')
+        await client_stub.rpc_backward(runtime_pb2.ExpertRequest(uid="expert2", tensors=[]))
+
+    with pytest.raises(P2PHandlerError):
+        # backward fails: empty stream
+        output_generator = await client_stub.rpc_backward_stream(
+            amap_in_executor(
+                lambda tensor_part: runtime_pb2.ExpertRequest(uid="expert2", tensors=[tensor_part]),
+                iter_as_aiter([]),
+            ),
+        )
+        results = await combine_and_deserialize_from_streaming(
+            amap_in_executor(lambda r: r.tensors, output_generator), deserialize_torch_tensor
+        )
+        assert len(results) == 1
+        assert torch.allclose(results[0], inputs_long * 3)
+
+    # check that handler did not crash after failed request
+    await client_stub.rpc_forward(runtime_pb2.ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(inputs)]))
+
+    # info
+    response = await client_stub.rpc_info(runtime_pb2.ExpertUID(uid="expert1"))
+    assert MSGPackSerializer.loads(response.serialized_info) == dict(name="expert1")
+
+    response = await client_stub.rpc_info(runtime_pb2.ExpertUID(uid="expert2"))
+    assert MSGPackSerializer.loads(response.serialized_info) == dict(name="expert2")
+
+    with pytest.raises(P2PHandlerError):
+        await client_stub.rpc_info(runtime_pb2.ExpertUID(uid="expert999"))
+
+    handler.terminate()
+    handler.join()
+
+
+class DummyPool(TaskPool):
+    def __init__(self, k: float):
+        self.k = k
+
+    async def submit_task(self, *inputs: torch.Tensor):
+        await asyncio.sleep(0.01)
+        print(type(inputs), inputs[0].shape)
+        assert inputs[0].shape[-1] == 2
+        return [inputs[0] * self.k]
+
+
+class DummyExpertBackend(ExpertBackend):
+    def __init__(self, name: str, k: float):
+        self.name = name
+        self.outputs_schema = [BatchTensorDescriptor.from_tensor(torch.randn(1, 2))]
+        self.grad_inputs_schema = [BatchTensorDescriptor.from_tensor(torch.randn(1, 2))]
+        self.forward_pool = DummyPool(k)
+        self.backward_pool = DummyPool(k)
+
+    def get_info(self) -> Dict[str, Any]:
+        """Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""
+        return dict(name=self.name)