# test_p2p_servicer.py
  1. import asyncio
  2. from typing import AsyncIterator
  3. import pytest
  4. from hivemind.p2p import P2P, P2PContext, ServicerBase
  5. from hivemind.proto import test_pb2
  6. from hivemind.utils.asyncio import anext
  @pytest.fixture
  async def server_client():
      """Yield a ``(server, client)`` pair of P2P instances, the client bootstrapped from the server.

      The client is created with the server's visible multiaddrs as ``initial_peers``. On teardown every
      instance that was successfully created is shut down -- including the server when creating the
      client itself fails.
      """
      server = await P2P.create()
      client = None
      try:
          client = await P2P.create(initial_peers=await server.get_visible_maddrs())
          yield server, client
      finally:
          # Shut down only what exists: if the client failed to start, the server must still be released
          # instead of leaking for the rest of the test session.
          pending_shutdowns = [server.shutdown()]
          if client is not None:
              pending_shutdowns.append(client.shutdown())
          await asyncio.gather(*pending_shutdowns)
  @pytest.mark.asyncio
  async def test_unary_unary(server_client):
      """Unary request -> unary response: the servicer replies with the square of the request number."""

      class ExampleServicer(ServicerBase):
          async def rpc_square(self, request: test_pb2.TestRequest, _context: P2PContext) -> test_pb2.TestResponse:
              return test_pb2.TestResponse(number=request.number**2)

      server, client = server_client
      servicer_instance = ExampleServicer()
      await servicer_instance.add_p2p_handlers(server)

      square_stub = ExampleServicer.get_stub(client, server.peer_id)
      request = test_pb2.TestRequest(number=10)
      expected = test_pb2.TestResponse(number=100)
      assert await square_stub.rpc_square(request) == expected
  @pytest.mark.asyncio
  async def test_stream_unary(server_client):
      """Streaming request -> unary response: the servicer sums every number it receives."""

      class ExampleServicer(ServicerBase):
          async def rpc_sum(
              self, stream: AsyncIterator[test_pb2.TestRequest], _context: P2PContext
          ) -> test_pb2.TestResponse:
              running_total = 0
              async for incoming in stream:
                  running_total += incoming.number
              return test_pb2.TestResponse(number=running_total)

      server, client = server_client
      servicer_instance = ExampleServicer()
      await servicer_instance.add_p2p_handlers(server)
      sum_stub = ExampleServicer.get_stub(client, server.peer_id)

      async def request_stream() -> AsyncIterator[test_pb2.TestRequest]:
          for value in range(10):
              yield test_pb2.TestRequest(number=value)

      # 0 + 1 + ... + 9 == 45
      total = await sum_stub.rpc_sum(request_stream())
      assert total == test_pb2.TestResponse(number=45)
  @pytest.mark.asyncio
  async def test_unary_stream(server_client):
      """Unary request -> streaming response: for request ``n`` the servicer yields ``0 .. n-1``."""

      class ExampleServicer(ServicerBase):
          async def rpc_count(
              self, request: test_pb2.TestRequest, _context: P2PContext
          ) -> AsyncIterator[test_pb2.TestResponse]:
              for counter in range(request.number):
                  yield test_pb2.TestResponse(number=counter)

      server, client = server_client
      servicer_instance = ExampleServicer()
      await servicer_instance.add_p2p_handlers(server)
      count_stub = ExampleServicer.get_stub(client, server.peer_id)

      response_stream = await count_stub.rpc_count(test_pb2.TestRequest(number=10))
      received = [response async for response in response_stream]
      # Comparing whole lists checks both the values, in order, and that exactly 10 responses arrived.
      assert received == [test_pb2.TestResponse(number=expected) for expected in range(10)]
  @pytest.mark.asyncio
  async def test_stream_stream(server_client):
      """Streaming request -> streaming response.

      For every incoming number ``x`` the servicer yields ``x**2`` followed by ``x**3``, so the client must
      receive exactly two responses per request, in order.
      """

      class ExampleServicer(ServicerBase):
          async def rpc_powers(
              self, stream: AsyncIterator[test_pb2.TestRequest], _context: P2PContext
          ) -> AsyncIterator[test_pb2.TestResponse]:
              async for item in stream:
                  yield test_pb2.TestResponse(number=item.number**2)
                  yield test_pb2.TestResponse(number=item.number**3)

      server, client = server_client
      servicer = ExampleServicer()
      await servicer.add_p2p_handlers(server)
      stub = ExampleServicer.get_stub(client, server.peer_id)

      num_requests = 10

      async def generate_requests() -> AsyncIterator[test_pb2.TestRequest]:
          for i in range(num_requests):
              yield test_pb2.TestRequest(number=i)

      stream = await stub.rpc_powers(generate_requests())
      num_responses = 0
      async for item in stream:
          # Even positions carry the square, odd positions the cube, of request number (position // 2).
          base = num_responses // 2
          if num_responses % 2 == 0:
              assert item == test_pb2.TestResponse(number=base**2)
          else:
              assert item == test_pb2.TestResponse(number=base**3)
          num_responses += 1

      # Without this check an empty or truncated response stream would make the test pass vacuously.
      assert num_responses == 2 * num_requests
  @pytest.mark.parametrize(
      "cancel_reason",
      ["close_connection", "close_generator"],
  )
  @pytest.mark.asyncio
  async def test_unary_stream_cancel(server_client, cancel_reason):
      """Check that a streaming handler is cancelled when the client stops reading mid-stream.

      The handler yields one response, then sleeps for 2 s before its second one. After reading the first
      response the client either closes the raw stream writer (``close_connection``) or closes the stub's
      response iterator (``close_generator``). In both cases the handler is expected to observe
      ``asyncio.CancelledError``.
      """
      handler_cancelled = False

      class ExampleServicer(ServicerBase):
          async def rpc_wait(
              self, request: test_pb2.TestRequest, _context: P2PContext
          ) -> AsyncIterator[test_pb2.TestResponse]:
              nonlocal handler_cancelled
              try:
                  yield test_pb2.TestResponse(number=request.number + 1)
                  await asyncio.sleep(2)
                  yield test_pb2.TestResponse(number=request.number + 2)
              except asyncio.CancelledError:
                  handler_cancelled = True
                  raise

      server, client = server_client
      servicer = ExampleServicer()
      await servicer.add_p2p_handlers(server)

      if cancel_reason == "close_connection":
          # Drive the handler over a raw binary stream: send the request and an end-of-stream marker,
          # read the first response, then close the writer.
          _, reader, writer = await client.call_binary_stream_handler(server.peer_id, "ExampleServicer.rpc_wait")
          await P2P.send_protobuf(test_pb2.TestRequest(number=10), writer)
          await P2P.send_protobuf(P2P.END_OF_STREAM, writer)
          response, _ = await P2P.receive_protobuf(test_pb2.TestResponse, reader)
          assert response == test_pb2.TestResponse(number=11)
          await asyncio.sleep(0.25)
          writer.close()
      elif cancel_reason == "close_generator":
          stub = ExampleServicer.get_stub(client, server.peer_id)
          # Named ``response_stream`` rather than ``iter`` to avoid shadowing the builtin.
          response_stream = await stub.rpc_wait(test_pb2.TestRequest(number=10))
          assert await anext(response_stream) == test_pb2.TestResponse(number=11)
          await asyncio.sleep(0.25)
          await response_stream.aclose()
      else:
          # Unlike ``assert False``, pytest.fail is not stripped when Python runs with -O.
          pytest.fail(f"Unknown cancel_reason = `{cancel_reason}`")

      # Allow time for the cancellation to reach the handler before checking (timing-based).
      await asyncio.sleep(0.25)
      assert handler_cancelled