test_p2p_servicer.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. import asyncio
  2. from typing import AsyncIterator
  3. import pytest
  4. from hivemind.p2p import P2P, P2PContext, ServicerBase
  5. from hivemind.proto import test_pb2
  @pytest.fixture
  async def server_client():
      """Yield a connected ``(server, client)`` pair of P2P instances.

      The client is created with the server's visible multiaddrs as its
      initial peers. After the requesting test has finished, both instances
      are shut down concurrently.
      """
      server = await P2P.create()
      try:
          client = await P2P.create(initial_peers=await server.get_visible_maddrs())
      except BaseException:
          # Creating the client failed (or was cancelled): the teardown after the
          # ``yield`` below would never run, so shut the already-created server
          # down here instead of leaking it.
          await server.shutdown()
          raise

      yield server, client

      await asyncio.gather(server.shutdown(), client.shutdown())
  @pytest.mark.asyncio
  async def test_unary_unary(server_client):
      """Check a unary-request / unary-response RPC: squaring 10 must return 100."""
      server, client = server_client

      class ExampleServicer(ServicerBase):
          async def rpc_square(self, request: test_pb2.TestRequest, _context: P2PContext) -> test_pb2.TestResponse:
              squared = request.number ** 2
              return test_pb2.TestResponse(number=squared)

      # Register the handlers on the server, then call them through a client-side stub.
      servicer_instance = ExampleServicer()
      await servicer_instance.add_p2p_handlers(server)
      stub = ExampleServicer.get_stub(client, server.peer_id)

      reply = await stub.rpc_square(test_pb2.TestRequest(number=10))
      assert reply == test_pb2.TestResponse(number=100)
  @pytest.mark.asyncio
  async def test_stream_unary(server_client):
      """Check a streaming-request / unary-response RPC: the numbers 0..9 must sum to 45."""
      server, client = server_client

      class ExampleServicer(ServicerBase):
          async def rpc_sum(
              self, stream: AsyncIterator[test_pb2.TestRequest], _context: P2PContext
          ) -> test_pb2.TestResponse:
              # Drain the whole request stream, then answer once with the total.
              numbers = [message.number async for message in stream]
              return test_pb2.TestResponse(number=sum(numbers))

      async def generate_requests() -> AsyncIterator[test_pb2.TestRequest]:
          for value in range(10):
              yield test_pb2.TestRequest(number=value)

      servicer_instance = ExampleServicer()
      await servicer_instance.add_p2p_handlers(server)
      stub = ExampleServicer.get_stub(client, server.peer_id)

      total = await stub.rpc_sum(generate_requests())
      assert total == test_pb2.TestResponse(number=45)
  @pytest.mark.asyncio
  async def test_unary_stream(server_client):
      """Check a unary-request / streaming-response RPC: asking for 10 must stream 0..9 in order."""
      server, client = server_client

      class ExampleServicer(ServicerBase):
          async def rpc_count(
              self, request: test_pb2.TestRequest, _context: P2PContext
          ) -> AsyncIterator[test_pb2.TestResponse]:
              for value in range(request.number):
                  yield test_pb2.TestResponse(number=value)

      servicer_instance = ExampleServicer()
      await servicer_instance.add_p2p_handlers(server)
      stub = ExampleServicer.get_stub(client, server.peer_id)

      # Collect the complete response stream; comparing against the expected list
      # verifies both the content/order of the replies and that exactly 10 arrived.
      replies = [reply async for reply in stub.rpc_count(test_pb2.TestRequest(number=10))]
      assert replies == [test_pb2.TestResponse(number=value) for value in range(10)]
  @pytest.mark.asyncio
  async def test_stream_stream(server_client):
      """Check a streaming-request / streaming-response RPC.

      For every request ``n`` the handler answers with ``n ** 2`` followed by
      ``n ** 3``. The test sends the numbers 0..9 and verifies each response in
      order, as well as the total number of responses received.
      """

      class ExampleServicer(ServicerBase):
          async def rpc_powers(
              self, stream: AsyncIterator[test_pb2.TestRequest], _context: P2PContext
          ) -> AsyncIterator[test_pb2.TestResponse]:
              async for item in stream:
                  yield test_pb2.TestResponse(number=item.number ** 2)
                  yield test_pb2.TestResponse(number=item.number ** 3)

      server, client = server_client
      servicer = ExampleServicer()
      await servicer.add_p2p_handlers(server)
      stub = ExampleServicer.get_stub(client, server.peer_id)

      num_requests = 10

      async def generate_requests() -> AsyncIterator[test_pb2.TestRequest]:
          for i in range(num_requests):
              yield test_pb2.TestRequest(number=i)

      i = 0
      async for item in stub.rpc_powers(generate_requests()):
          # Responses alternate: even positions carry the square and odd positions
          # the cube of request number ``i // 2``.
          if i % 2 == 0:
              assert item == test_pb2.TestResponse(number=(i // 2) ** 2)
          else:
              assert item == test_pb2.TestResponse(number=(i // 2) ** 3)
          i += 1

      # Two responses per request. Without this check the loop above would pass
      # vacuously if the stream came back empty or was cut short.
      assert i == 2 * num_requests
  @pytest.mark.parametrize(
      "cancel_reason",
      ["close_connection", "close_generator"],
  )
  @pytest.mark.asyncio
  async def test_unary_stream_cancel(server_client, cancel_reason):
      """Check that a streaming handler is cancelled when the caller stops listening.

      The handler yields one response, sleeps for 2 seconds and would then yield
      a second one. The caller reads only the first response and then goes away
      while the handler is still sleeping, either by closing the raw stream
      (``close_connection``) or by closing the stub's async generator
      (``close_generator``). The handler must observe ``asyncio.CancelledError``.
      """
      handler_cancelled = False

      class ExampleServicer(ServicerBase):
          async def rpc_wait(
              self, request: test_pb2.TestRequest, _context: P2PContext
          ) -> AsyncIterator[test_pb2.TestResponse]:
              nonlocal handler_cancelled
              try:
                  yield test_pb2.TestResponse(number=request.number + 1)
                  await asyncio.sleep(2)
                  yield test_pb2.TestResponse(number=request.number + 2)
              except asyncio.CancelledError:
                  # Record the cancellation for the test, then let it propagate.
                  handler_cancelled = True
                  raise

      server, client = server_client
      servicer = ExampleServicer()
      await servicer.add_p2p_handlers(server)

      if cancel_reason == "close_connection":
          # Talk to the handler over the raw binary stream so that the stream itself
          # can be closed in the middle of the response.
          _, reader, writer = await client.call_binary_stream_handler(server.peer_id, "ExampleServicer.rpc_wait")
          await P2P.send_protobuf(test_pb2.TestRequest(number=10), writer)
          await P2P.send_protobuf(P2P.END_OF_STREAM, writer)

          response, _ = await P2P.receive_protobuf(test_pb2.TestResponse, reader)
          assert response == test_pb2.TestResponse(number=11)
          await asyncio.sleep(0.25)

          writer.close()
      elif cancel_reason == "close_generator":
          stub = ExampleServicer.get_stub(client, server.peer_id)
          # Named ``response_iter`` rather than ``iter`` to avoid shadowing the builtin.
          response_iter = stub.rpc_wait(test_pb2.TestRequest(number=10)).__aiter__()

          assert await response_iter.__anext__() == test_pb2.TestResponse(number=11)
          await asyncio.sleep(0.25)

          await response_iter.aclose()
      else:
          # ``pytest.fail`` instead of ``assert False``: it is not stripped under ``-O``.
          pytest.fail(f"Unknown cancel_reason = `{cancel_reason}`")

      # Presumably gives the server side time to notice that the caller is gone;
      # this is well short of the handler's 2-second sleep.
      await asyncio.sleep(0.25)
      assert handler_cancelled