# test_point_system.py
  1. from typing import AsyncIterator, Optional
  2. import pytest
  3. import torch
  4. from hivemind.compression import deserialize_tensor_stream, deserialize_torch_tensor, serialize_torch_tensor
  5. from hivemind.proto.runtime_pb2 import ExpertRequest
  6. from hivemind.utils import MSGPackSerializer, amap_in_executor, iter_as_aiter, split_for_streaming
  7. from src.client.priority_block import DustyRemoteBlock
  8. from src.client.spending_policy import SpendingPolicyBase
  9. class SpendingPolicyTest(SpendingPolicyBase):
  10. def __init__(self):
  11. self._p = {
  12. "rpc_single": 1,
  13. "rpc_stream": 2,
  14. }
  15. def get_points(self, request: ExpertRequest, method_name: str) -> float:
  16. return self._p.get(method_name, -1)
  17. class HandlerStubTest:
  18. async def rpc_single(self, input: ExpertRequest, timeout: Optional[float] = None):
  19. return input
  20. async def rpc_stream(self, input: AsyncIterator[ExpertRequest], timeout: Optional[float] = None):
  21. return input
  22. async def rpc_info(self, input: str, timeout: Optional[float] = None):
  23. return input
  24. class RemoteBlockTest(DustyRemoteBlock):
  25. @property
  26. def _stub(self):
  27. return HandlerStubTest()
  28. @pytest.mark.forked
  29. @pytest.mark.asyncio
  30. async def test_single():
  31. remote = RemoteBlockTest(SpendingPolicyTest(), None, None)
  32. stub = remote.stub
  33. input = torch.randn(1, 2)
  34. request = ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(input)])
  35. print(stub)
  36. out: ExpertRequest = await stub.rpc_single(request)
  37. assert out.metadata != b""
  38. assert len(out.tensors) == 1
  39. assert torch.allclose(input, deserialize_torch_tensor(out.tensors[0]))
  40. meta = MSGPackSerializer.loads(out.metadata)
  41. assert isinstance(meta, dict)
  42. assert "__dust" in meta
  43. assert meta["__dust"] == 1
  44. @pytest.mark.forked
  45. @pytest.mark.asyncio
  46. async def test_stream():
  47. remote = RemoteBlockTest(SpendingPolicyTest(), None, None)
  48. stub = remote.stub
  49. input = torch.randn(2**21, 2)
  50. split = (p for t in [serialize_torch_tensor(input)] for p in split_for_streaming(t, chunk_size_bytes=2**16))
  51. output_generator = await stub.rpc_stream(
  52. amap_in_executor(
  53. lambda tensor_part: ExpertRequest(uid="expert2", tensors=[tensor_part]),
  54. iter_as_aiter(split),
  55. ),
  56. )
  57. outputs_list = [part async for part in output_generator]
  58. assert len(outputs_list) == 2**5 * 8
  59. assert outputs_list[0].metadata != b""
  60. for i in range(1, len(outputs_list)):
  61. assert outputs_list[i].metadata == b""
  62. meta = MSGPackSerializer.loads(outputs_list[0].metadata)
  63. assert isinstance(meta, dict)
  64. assert "__dust" in meta
  65. assert meta["__dust"] == 2
  66. results = await deserialize_tensor_stream(amap_in_executor(lambda r: r.tensors, iter_as_aiter(outputs_list)))
  67. assert len(results) == 1
  68. assert torch.allclose(results[0], input)
  69. @pytest.mark.forked
  70. @pytest.mark.asyncio
  71. async def test_no_wrapper():
  72. remote = RemoteBlockTest(SpendingPolicyTest(), None, None)
  73. stub = remote.stub
  74. test = await stub.rpc_info("Test")
  75. assert test == "Test"