from typing import AsyncIterator, Optional import pytest import torch from hivemind.compression import deserialize_tensor_stream, deserialize_torch_tensor, serialize_torch_tensor from hivemind.proto.runtime_pb2 import ExpertRequest from hivemind.utils import MSGPackSerializer, amap_in_executor, iter_as_aiter, split_for_streaming from src.client.priority_block import DustyRemoteBlock from src.client.spending_policy import SpendingPolicyBase class SpendingPolicyTest(SpendingPolicyBase): def __init__(self): self._p = { "rpc_single": 1, "rpc_stream": 2, } def get_points(self, request: ExpertRequest, method_name: str) -> float: return self._p.get(method_name, -1) class HandlerStubTest: async def rpc_single(self, input: ExpertRequest, timeout: Optional[float] = None): return input async def rpc_stream(self, input: AsyncIterator[ExpertRequest], timeout: Optional[float] = None): return input async def rpc_info(self, input: str, timeout: Optional[float] = None): return input class RemoteBlockTest(DustyRemoteBlock): @property def _stub(self): return HandlerStubTest() @pytest.mark.forked @pytest.mark.asyncio async def test_single(): remote = RemoteBlockTest(SpendingPolicyTest(), None, None) stub = remote.stub input = torch.randn(1, 2) request = ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(input)]) print(stub) out: ExpertRequest = await stub.rpc_single(request) assert out.metadata != b"" assert len(out.tensors) == 1 assert torch.allclose(input, deserialize_torch_tensor(out.tensors[0])) meta = MSGPackSerializer.loads(out.metadata) assert isinstance(meta, dict) assert "__dust" in meta assert meta["__dust"] == 1 @pytest.mark.forked @pytest.mark.asyncio async def test_stream(): remote = RemoteBlockTest(SpendingPolicyTest(), None, None) stub = remote.stub input = torch.randn(2**21, 2) split = (p for t in [serialize_torch_tensor(input)] for p in split_for_streaming(t, chunk_size_bytes=2**16)) output_generator = await stub.rpc_stream( amap_in_executor( lambda tensor_part: ExpertRequest(uid="expert2", tensors=[tensor_part]), iter_as_aiter(split), ), ) outputs_list = [part async for part in output_generator] assert len(outputs_list) == 2**5 * 8 assert outputs_list[0].metadata != b"" for i in range(1, len(outputs_list)): assert outputs_list[i].metadata == b"" meta = MSGPackSerializer.loads(outputs_list[0].metadata) assert isinstance(meta, dict) assert "__dust" in meta assert meta["__dust"] == 2 results = await deserialize_tensor_stream(amap_in_executor(lambda r: r.tensors, iter_as_aiter(outputs_list))) assert len(results) == 1 assert torch.allclose(results[0], input) @pytest.mark.forked @pytest.mark.asyncio async def test_no_wrapper(): remote = RemoteBlockTest(SpendingPolicyTest(), None, None) stub = remote.stub test = await stub.rpc_info("Test") assert test == "Test"