test_p2p_servicer.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. import asyncio
  2. from typing import AsyncIterator
  3. import pytest
  4. from hivemind.p2p import P2P, P2PContext, P2PDaemonError, ServicerBase
  5. from hivemind.proto import test_pb2
  6. from hivemind.utils.asyncio import anext
  @pytest.fixture
  async def server_client():
      """Yield a ``(server, client)`` pair of P2P instances for one test.

      The client is created with the addresses returned by
      ``server.get_visible_maddrs()`` as its initial peers. After the test
      finishes with the pair, both instances are shut down concurrently.
      """
      server = await P2P.create()
      bootstrap_maddrs = await server.get_visible_maddrs()
      client = await P2P.create(initial_peers=bootstrap_maddrs)

      yield server, client

      # Teardown: stop both instances at the same time.
      await asyncio.gather(server.shutdown(), client.shutdown())
  class UnaryUnaryServicer(ServicerBase):
      """Servicer exposing one RPC that takes a single request and returns a single response."""

      async def rpc_square(self, request: test_pb2.TestRequest, _context: P2PContext) -> test_pb2.TestResponse:
          """Reply with ``request.number`` raised to the second power."""
          squared = request.number**2
          return test_pb2.TestResponse(number=squared)
  @pytest.mark.asyncio
  async def test_unary_unary(server_client):
      """Call ``rpc_square`` through a client-side stub and expect 10 -> 100."""
      server, client = server_client

      # Register the handlers on the server side, then build a stub that targets it.
      square_servicer = UnaryUnaryServicer()
      await square_servicer.add_p2p_handlers(server)
      stub = UnaryUnaryServicer.get_stub(client, server.peer_id)

      reply = await stub.rpc_square(test_pb2.TestRequest(number=10))
      expected = test_pb2.TestResponse(number=100)
      assert reply == expected
  class StreamUnaryServicer(ServicerBase):
      """Servicer exposing one RPC that takes a stream of requests and returns a single response."""

      async def rpc_sum(
          self, stream: AsyncIterator[test_pb2.TestRequest], _context: P2PContext
      ) -> test_pb2.TestResponse:
          """Consume the whole request stream and reply with the sum of the numbers it carried."""
          total = sum([message.number async for message in stream])
          return test_pb2.TestResponse(number=total)
  @pytest.mark.asyncio
  async def test_stream_unary(server_client):
      """Stream the numbers 0..9 to ``rpc_sum`` and expect their sum (45) back."""
      server, client = server_client
      sum_servicer = StreamUnaryServicer()
      await sum_servicer.add_p2p_handlers(server)
      stub = StreamUnaryServicer.get_stub(client, server.peer_id)

      async def request_stream() -> AsyncIterator[test_pb2.TestRequest]:
          """Produce ten requests carrying 0, 1, ..., 9."""
          number = 0
          while number < 10:
              yield test_pb2.TestRequest(number=number)
              number += 1

      reply = await stub.rpc_sum(request_stream())
      assert reply == test_pb2.TestResponse(number=45)
  class UnaryStreamServicer(ServicerBase):
      """Servicer exposing one RPC that takes a single request and returns a stream of responses."""

      async def rpc_count(
          self, request: test_pb2.TestRequest, _context: P2PContext
      ) -> AsyncIterator[test_pb2.TestResponse]:
          """Yield one response for each integer in ``range(request.number)``, in order."""
          upper_bound = request.number
          for value in range(upper_bound):
              yield test_pb2.TestResponse(number=value)
  @pytest.mark.asyncio
  async def test_unary_stream(server_client):
      """Ask ``rpc_count`` for 10 items and expect the responses 0..9 in order."""
      server, client = server_client
      count_servicer = UnaryStreamServicer()
      await count_servicer.add_p2p_handlers(server)
      stub = UnaryStreamServicer.get_stub(client, server.peer_id)

      responses = await stub.rpc_count(test_pb2.TestRequest(number=10))

      # Drain the response stream, keeping only the payload numbers.
      received = []
      async for response in responses:
          received.append(response.number)

      assert received == list(range(10))
  class StreamStreamServicer(ServicerBase):
      """Servicer exposing one RPC that takes a stream of requests and returns a stream of responses."""

      async def rpc_powers(
          self, stream: AsyncIterator[test_pb2.TestRequest], _context: P2PContext
      ) -> AsyncIterator[test_pb2.TestResponse]:
          """For every incoming request, yield its number squared and then its number cubed."""
          async for message in stream:
              for exponent in (2, 3):
                  yield test_pb2.TestResponse(number=message.number**exponent)
  @pytest.mark.asyncio
  async def test_stream_stream(server_client):
      """Stream 0..9 to ``rpc_powers`` and expect each square followed by its cube.

      The complete response sequence is collected and compared against the
      expected one, so a response stream that ends early (or yields nothing
      at all) fails the test instead of passing without any assertion running.
      """
      server, client = server_client
      servicer = StreamStreamServicer()
      await servicer.add_p2p_handlers(server)
      stub = StreamStreamServicer.get_stub(client, server.peer_id)

      async def generate_requests() -> AsyncIterator[test_pb2.TestRequest]:
          """Produce ten requests carrying 0, 1, ..., 9."""
          for i in range(10):
              yield test_pb2.TestRequest(number=i)

      stream = await stub.rpc_powers(generate_requests())
      received = [item async for item in stream]

      # Two responses per request: number**2 first, then number**3.
      expected = [
          test_pb2.TestResponse(number=base**exponent)
          for base in range(10)
          for exponent in (2, 3)
      ]
      assert received == expected
  @pytest.mark.parametrize(
      "cancel_reason",
      ["close_connection", "close_generator"],
  )
  @pytest.mark.asyncio
  async def test_unary_stream_cancel(server_client, cancel_reason):
      """Check that a streaming handler is cancelled once the client stops consuming it.

      The handler yields one response, sleeps for 2 seconds and would then yield
      a second one. The client reads only the first response and then goes away
      in one of two ways, selected by ``cancel_reason``:

      * ``"close_connection"``: the call is made over a raw binary stream and the
        writer is closed;
      * ``"close_generator"``: the call is made through a stub and the returned
        async iterator is closed with ``aclose()``.

      The handler records whether it received ``asyncio.CancelledError``; the
      test asserts that it did, well before the 2-second sleep could finish.
      """
      handler_cancelled = False

      class ExampleServicer(ServicerBase):
          async def rpc_wait(
              self, request: test_pb2.TestRequest, _context: P2PContext
          ) -> AsyncIterator[test_pb2.TestResponse]:
              try:
                  yield test_pb2.TestResponse(number=request.number + 1)
                  await asyncio.sleep(2)
                  yield test_pb2.TestResponse(number=request.number + 2)
              except asyncio.CancelledError:
                  # Record the cancellation for the enclosing test, then let it propagate.
                  nonlocal handler_cancelled
                  handler_cancelled = True
                  raise

      server, client = server_client
      servicer = ExampleServicer()
      await servicer.add_p2p_handlers(server)

      if cancel_reason == "close_connection":
          # Talk to the handler over a raw stream so the connection itself can be closed.
          _, reader, writer = await client.call_binary_stream_handler(server.peer_id, "ExampleServicer.rpc_wait")
          await P2P.send_protobuf(test_pb2.TestRequest(number=10), writer)
          await P2P.send_protobuf(P2P.END_OF_STREAM, writer)

          response, _ = await P2P.receive_protobuf(test_pb2.TestResponse, reader)
          assert response == test_pb2.TestResponse(number=11)
          await asyncio.sleep(0.25)

          writer.close()
      elif cancel_reason == "close_generator":
          stub = ExampleServicer.get_stub(client, server.peer_id)
          # Named `response_stream` rather than `iter` to avoid shadowing the builtin.
          response_stream = await stub.rpc_wait(test_pb2.TestRequest(number=10))

          assert await anext(response_stream) == test_pb2.TestResponse(number=11)
          await asyncio.sleep(0.25)

          await response_stream.aclose()
      else:
          # Explicit failure instead of `assert False`, which is stripped under `python -O`.
          pytest.fail(f"Unknown cancel_reason = `{cancel_reason}`")

      # Give the cancellation a moment to reach the handler before checking the flag.
      await asyncio.sleep(0.25)
      assert handler_cancelled
  @pytest.mark.asyncio
  async def test_removing_unary_handlers(server_client):
      """Check that a unary RPC works while its handlers are registered and fails after removal.

      Handlers are added and removed in turn on ``server1``, on ``server2`` and
      on ``server1`` again, while the stub always targets ``server1.peer_id``.
      ``server2`` comes from ``P2P.replicate(server1.daemon_listen_maddr)``
      (presumably a second instance attached to the same daemon; confirm
      against the ``P2P.replicate`` documentation).

      ``server2`` is shut down in a ``finally`` block so that a failing
      assertion does not leave it running.
      """
      server1, client = server_client
      server2 = await P2P.replicate(server1.daemon_listen_maddr)
      try:
          servicer = UnaryUnaryServicer()
          stub = UnaryUnaryServicer.get_stub(client, server1.peer_id)

          for server in [server1, server2, server1]:
              await servicer.add_p2p_handlers(server)
              assert await stub.rpc_square(test_pb2.TestRequest(number=10)) == test_pb2.TestResponse(number=100)

              # Once the handlers are removed, the same call must fail.
              await servicer.remove_p2p_handlers(server)
              with pytest.raises((P2PDaemonError, ConnectionError)):
                  await stub.rpc_square(test_pb2.TestRequest(number=10))
      finally:
          await server2.shutdown()
  @pytest.mark.asyncio
  async def test_removing_stream_handlers(server_client):
      """Check that a streaming RPC works while its handlers are registered and fails after removal.

      Handlers are added and removed in turn on ``server1``, on ``server2`` and
      on ``server1`` again, while the stub always targets ``server1.peer_id``.
      ``server2`` comes from ``P2P.replicate(server1.daemon_listen_maddr)``
      (presumably a second instance attached to the same daemon; confirm
      against the ``P2P.replicate`` documentation).

      ``server2`` is shut down in a ``finally`` block so that a failing
      assertion does not leave it running.
      """
      server1, client = server_client
      server2 = await P2P.replicate(server1.daemon_listen_maddr)
      try:
          servicer = UnaryStreamServicer()
          stub = UnaryStreamServicer.get_stub(client, server1.peer_id)

          for server in [server1, server2, server1]:
              await servicer.add_p2p_handlers(server)
              stream = await stub.rpc_count(test_pb2.TestRequest(number=10))
              assert [item.number async for item in stream] == list(range(10))

              await servicer.remove_p2p_handlers(server)
              with pytest.raises((P2PDaemonError, ConnectionError)):
                  stream = await stub.rpc_count(test_pb2.TestRequest(number=10))
                  outputs = [item.number async for item in stream]
                  # An empty stream with no exception is treated as a failed call too.
                  if not outputs:
                      raise P2PDaemonError("Daemon has reset the connection")
      finally:
          await server2.shutdown()
  148. await asyncio.gather(server2.shutdown())