test_p2p_servicer.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. import asyncio
  2. from typing import AsyncIterator
  3. import pytest
  4. from hivemind.p2p import P2P, P2PContext, ServicerBase
  5. from hivemind.proto import test_pb2
  6. @pytest.fixture
  7. async def server_client():
  8. server = await P2P.create()
  9. client = await P2P.create(initial_peers=await server.get_visible_maddrs())
  10. yield server, client
  11. await asyncio.gather(server.shutdown(), client.shutdown())
  12. @pytest.mark.asyncio
  13. async def test_unary_unary(server_client):
  14. class ExampleServicer(ServicerBase):
  15. async def rpc_square(self, request: test_pb2.TestRequest, _: P2PContext) -> test_pb2.TestResponse:
  16. return test_pb2.TestResponse(number=request.number ** 2)
  17. server, client = server_client
  18. servicer = ExampleServicer()
  19. await servicer.add_p2p_handlers(server)
  20. stub = servicer.get_stub(client, server.id)
  21. assert await stub.rpc_square(test_pb2.TestRequest(number=10)) == test_pb2.TestResponse(number=100)
  22. @pytest.mark.asyncio
  23. async def test_stream_unary(server_client):
  24. class ExampleServicer(ServicerBase):
  25. async def rpc_sum(self, request: AsyncIterator[test_pb2.TestRequest], _: P2PContext) -> test_pb2.TestResponse:
  26. result = 0
  27. async for item in request:
  28. result += item.number
  29. return test_pb2.TestResponse(number=result)
  30. server, client = server_client
  31. servicer = ExampleServicer()
  32. await servicer.add_p2p_handlers(server)
  33. stub = servicer.get_stub(client, server.id)
  34. async def generate_requests() -> AsyncIterator[test_pb2.TestRequest]:
  35. for i in range(10):
  36. yield test_pb2.TestRequest(number=i)
  37. assert await stub.rpc_sum(generate_requests()) == test_pb2.TestResponse(number=45)
  38. @pytest.mark.asyncio
  39. async def test_unary_stream(server_client):
  40. class ExampleServicer(ServicerBase):
  41. async def rpc_count(
  42. self, request: test_pb2.TestRequest, _: P2PContext
  43. ) -> AsyncIterator[test_pb2.TestResponse]:
  44. for i in range(request.number):
  45. yield test_pb2.TestResponse(number=i)
  46. server, client = server_client
  47. servicer = ExampleServicer()
  48. await servicer.add_p2p_handlers(server)
  49. stub = servicer.get_stub(client, server.id)
  50. i = 0
  51. async for item in stub.rpc_count(test_pb2.TestRequest(number=10)):
  52. assert item == test_pb2.TestResponse(number=i)
  53. i += 1
  54. assert i == 10
  55. @pytest.mark.asyncio
  56. async def test_stream_stream(server_client):
  57. class ExampleServicer(ServicerBase):
  58. async def rpc_powers(
  59. self, request: AsyncIterator[test_pb2.TestRequest], _: P2PContext
  60. ) -> AsyncIterator[test_pb2.TestResponse]:
  61. async for item in request:
  62. yield test_pb2.TestResponse(number=item.number ** 2)
  63. yield test_pb2.TestResponse(number=item.number ** 3)
  64. server, client = server_client
  65. servicer = ExampleServicer()
  66. await servicer.add_p2p_handlers(server)
  67. stub = servicer.get_stub(client, server.id)
  68. async def generate_requests() -> AsyncIterator[test_pb2.TestRequest]:
  69. for i in range(10):
  70. yield test_pb2.TestRequest(number=i)
  71. i = 0
  72. async for item in stub.rpc_powers(generate_requests()):
  73. if i % 2 == 0:
  74. assert item == test_pb2.TestResponse(number=(i // 2) ** 2)
  75. else:
  76. assert item == test_pb2.TestResponse(number=(i // 2) ** 3)
  77. i += 1
  78. @pytest.mark.parametrize(
  79. "cancel_reason",
  80. ["close_connection", "close_generator"],
  81. )
  82. @pytest.mark.asyncio
  83. async def test_unary_stream_cancel(server_client, cancel_reason):
  84. handler_cancelled = False
  85. class ExampleServicer(ServicerBase):
  86. async def rpc_wait(self, request: test_pb2.TestRequest, _: P2PContext) -> AsyncIterator[test_pb2.TestResponse]:
  87. try:
  88. yield test_pb2.TestResponse(number=request.number + 1)
  89. await asyncio.sleep(2)
  90. yield test_pb2.TestResponse(number=request.number + 2)
  91. except asyncio.CancelledError:
  92. nonlocal handler_cancelled
  93. handler_cancelled = True
  94. raise
  95. server, client = server_client
  96. servicer = ExampleServicer()
  97. await servicer.add_p2p_handlers(server)
  98. if cancel_reason == "close_connection":
  99. _, reader, writer = await client.call_binary_stream_handler(server.id, "ExampleServicer.rpc_wait")
  100. await P2P.send_protobuf(test_pb2.TestRequest(number=10), writer)
  101. await P2P.send_protobuf(P2P.END_OF_STREAM, writer)
  102. response, _ = await P2P.receive_protobuf(test_pb2.TestResponse, reader)
  103. assert response == test_pb2.TestResponse(number=11)
  104. await asyncio.sleep(0.25)
  105. writer.close()
  106. elif cancel_reason == "close_generator":
  107. stub = servicer.get_stub(client, server.id)
  108. iter = stub.rpc_wait(test_pb2.TestRequest(number=10)).__aiter__()
  109. assert await iter.__anext__() == test_pb2.TestResponse(number=11)
  110. await asyncio.sleep(0.25)
  111. await iter.aclose()
  112. else:
  113. assert False, f"Unknown cancel_reason = `{cancel_reason}`"
  114. await asyncio.sleep(0.25)
  115. assert handler_cancelled