test_dust_payment.py

from __future__ import annotations

from typing import AsyncIterator, Optional

import pytest
import torch
from hivemind.compression import deserialize_tensor_stream, deserialize_torch_tensor, serialize_torch_tensor
from hivemind.proto.runtime_pb2 import ExpertRequest
from hivemind.utils import MSGPackSerializer, amap_in_executor, iter_as_aiter, split_for_streaming

from src.client.dust_bank import DustBankBase
from src.client.dusty_block import DustyRemoteBlock


class DustBankTest(DustBankBase):
    """Bank that charges a fixed dust price per RPC method name (-1 for unknown methods)."""

    def __init__(self):
        self._p = {
            "rpc_single": 1,
            "rpc_stream": 2,
        }

    def get_dust(self, request: ExpertRequest, method_name: str) -> float:
        return self._p.get(method_name, -1)


class HandlerStubTest:
    """Echo stub standing in for a real server connection: every handler returns its input unchanged."""

    async def rpc_single(self, input: ExpertRequest, timeout: Optional[float] = None):
        return input

    async def rpc_stream(self, input: AsyncIterator[ExpertRequest], timeout: Optional[float] = None):
        async for i in input:
            yield i

    async def rpc_info(self, input: str, timeout: Optional[float] = None):
        return input


class RemoteBlockTest(DustyRemoteBlock):
    """DustyRemoteBlock wired to the echo stub instead of a real peer handler."""

    @property
    def _stub(self):
        return HandlerStubTest()


@pytest.mark.asyncio
async def test_single():
    # A unary call should carry the bank's price for "rpc_single" in its MSGPack metadata
    # while leaving the payload tensor intact.
    remote = RemoteBlockTest(DustBankTest(), None, None)
    stub = remote.stub
    input = torch.randn(1, 2)
    request = ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(input)])
    out: ExpertRequest = await stub.rpc_single(request)
    assert out.metadata != b""
    assert len(out.tensors) == 1
    assert torch.allclose(input, deserialize_torch_tensor(out.tensors[0]))
    meta = MSGPackSerializer.loads(out.metadata)
    assert isinstance(meta, dict)
    assert "__dust" in meta
    assert meta["__dust"] == 1


@pytest.mark.asyncio
async def test_stream():
    # A streaming call should attach the dust metadata to the first chunk only,
    # and the echoed chunks should reassemble into the original tensor.
    remote = RemoteBlockTest(DustBankTest(), None, None)
    stub = remote.stub
    input = torch.randn(2**21, 2)
    split = (p for t in [serialize_torch_tensor(input)] for p in split_for_streaming(t, chunk_size_bytes=2**16))

    output_generator = await stub.rpc_stream(
        amap_in_executor(
            lambda tensor_part: ExpertRequest(uid="expert2", tensors=[tensor_part]),
            iter_as_aiter(split),
        ),
    )
    outputs_list = [part async for part in output_generator]
    # 2**21 x 2 float32 values (2**24 bytes) split into 2**16-byte chunks -> 256 parts.
    assert len(outputs_list) == 2**5 * 8
    assert outputs_list[0].metadata != b""
    for i in range(1, len(outputs_list)):
        assert outputs_list[i].metadata == b""
    meta = MSGPackSerializer.loads(outputs_list[0].metadata)
    assert isinstance(meta, dict)
    assert "__dust" in meta
    assert meta["__dust"] == 2

    results = await deserialize_tensor_stream(amap_in_executor(lambda r: r.tensors, iter_as_aiter(outputs_list)))
    assert len(results) == 1
    assert torch.allclose(results[0], input)