- from typing import AsyncIterator, Optional
- import pytest
- import torch
- from hivemind.compression import deserialize_tensor_stream, deserialize_torch_tensor, serialize_torch_tensor
- from hivemind.proto.runtime_pb2 import ExpertRequest
- from hivemind.utils import MSGPackSerializer, amap_in_executor, iter_as_aiter, split_for_streaming
- from src.client.priority_block import DustyRemoteBlock
- from src.client.spending_policy import SpendingPolicyBase
class SpendingPolicyTest(SpendingPolicyBase):
    """Fixed-price spending policy used by these tests.

    Charges 1 point per ``rpc_single`` call and 2 per ``rpc_stream`` call;
    any other method is priced at -1 (unknown).
    """

    def __init__(self):
        # Price table: RPC method name -> points charged per call.
        self._p = {"rpc_single": 1, "rpc_stream": 2}

    def get_points(self, request: ExpertRequest, method_name: str) -> float:
        """Return the point cost of *method_name*; -1 for unpriced methods."""
        price = self._p.get(method_name)
        return price if price is not None else -1
class HandlerStubTest:
    """Echo stub standing in for a real RPC handler: every method returns its input."""

    async def rpc_single(self, input: ExpertRequest, timeout: Optional[float] = None):
        # Unary call: echo the request back unchanged.
        return input

    async def rpc_stream(self, input: AsyncIterator[ExpertRequest], timeout: Optional[float] = None):
        # Streaming call: hand the request iterator back untouched.
        return input

    async def rpc_info(self, input: str, timeout: Optional[float] = None):
        # Method with no price entry in SpendingPolicyTest: plain echo.
        return input
class RemoteBlockTest(DustyRemoteBlock):
    """DustyRemoteBlock wired to the in-process echo stub instead of a real channel."""

    @property
    def _stub(self):
        # A fresh echo stub per access is sufficient for these tests.
        stub = HandlerStubTest()
        return stub
@pytest.mark.forked
@pytest.mark.asyncio
async def test_single():
    """A unary call through the wrapper must attach dust metadata and echo the tensor.

    Fixes: the local was named ``input`` (shadows the builtin) and a leftover
    debug ``print(stub)`` was removed.
    """
    remote = RemoteBlockTest(SpendingPolicyTest(), None, None)
    stub = remote.stub
    tensor = torch.randn(1, 2)
    request = ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(tensor)])
    out: ExpertRequest = await stub.rpc_single(request)

    # The wrapper must have attached spending metadata to the otherwise-echoed request.
    assert out.metadata != b""
    assert len(out.tensors) == 1
    assert torch.allclose(tensor, deserialize_torch_tensor(out.tensors[0]))

    meta = MSGPackSerializer.loads(out.metadata)
    assert isinstance(meta, dict)
    assert "__dust" in meta
    # rpc_single is priced at 1 point by SpendingPolicyTest.
    assert meta["__dust"] == 1
@pytest.mark.forked
@pytest.mark.asyncio
async def test_stream():
    """A streaming call must attach dust metadata to the first chunk only.

    Fixes: the local was named ``input`` (shadows the builtin), and the opaque
    chunk count ``2**5 * 8`` is written as the equal ``2**8`` with a derivation.
    """
    remote = RemoteBlockTest(SpendingPolicyTest(), None, None)
    stub = remote.stub
    tensor = torch.randn(2**21, 2)

    # Serialize and split into 2**16-byte chunks: 2**21 * 2 float32 values
    # = 2**24 bytes -> 2**8 == 256 parts.
    split = (
        part
        for serialized in [serialize_torch_tensor(tensor)]
        for part in split_for_streaming(serialized, chunk_size_bytes=2**16)
    )
    output_generator = await stub.rpc_stream(
        amap_in_executor(
            lambda tensor_part: ExpertRequest(uid="expert2", tensors=[tensor_part]),
            iter_as_aiter(split),
        ),
    )
    outputs_list = [part async for part in output_generator]
    assert len(outputs_list) == 2**8  # 256 chunks, see derivation above

    # Only the first chunk carries the spending metadata; the rest stay empty.
    assert outputs_list[0].metadata != b""
    for i in range(1, len(outputs_list)):
        assert outputs_list[i].metadata == b""

    meta = MSGPackSerializer.loads(outputs_list[0].metadata)
    assert isinstance(meta, dict)
    assert "__dust" in meta
    # rpc_stream is priced at 2 points by SpendingPolicyTest.
    assert meta["__dust"] == 2

    # Reassembling the streamed parts must reproduce the original tensor.
    results = await deserialize_tensor_stream(amap_in_executor(lambda r: r.tensors, iter_as_aiter(outputs_list)))
    assert len(results) == 1
    assert torch.allclose(results[0], tensor)
@pytest.mark.forked
@pytest.mark.asyncio
async def test_no_wrapper():
    """A method with no price entry (rpc_info) passes through without modification."""
    remote = RemoteBlockTest(SpendingPolicyTest(), None, None)
    stub = remote.stub
    echoed = await stub.rpc_info("Test")
    assert echoed == "Test"