|
@@ -3,7 +3,7 @@ from typing import AsyncIterator
|
|
|
|
|
|
import pytest
|
|
|
|
|
|
-from hivemind.p2p import P2P, P2PContext, ServicerBase
|
|
|
+from hivemind.p2p import P2P, P2PContext, P2PDaemonError, ServicerBase
|
|
|
from hivemind.proto import test_pb2
|
|
|
from hivemind.utils.asyncio import anext
|
|
|
|
|
@@ -17,35 +17,37 @@ async def server_client():
|
|
|
await asyncio.gather(server.shutdown(), client.shutdown())
|
|
|
|
|
|
|
|
|
+class UnaryUnaryServicer(ServicerBase):
|
|
|
+ async def rpc_square(self, request: test_pb2.TestRequest, _context: P2PContext) -> test_pb2.TestResponse:
|
|
|
+ return test_pb2.TestResponse(number=request.number**2)
|
|
|
+
|
|
|
+
|
|
|
@pytest.mark.asyncio
|
|
|
async def test_unary_unary(server_client):
|
|
|
- class ExampleServicer(ServicerBase):
|
|
|
- async def rpc_square(self, request: test_pb2.TestRequest, _context: P2PContext) -> test_pb2.TestResponse:
|
|
|
- return test_pb2.TestResponse(number=request.number**2)
|
|
|
-
|
|
|
server, client = server_client
|
|
|
- servicer = ExampleServicer()
|
|
|
+ servicer = UnaryUnaryServicer()
|
|
|
await servicer.add_p2p_handlers(server)
|
|
|
- stub = ExampleServicer.get_stub(client, server.peer_id)
|
|
|
+ stub = UnaryUnaryServicer.get_stub(client, server.peer_id)
|
|
|
|
|
|
assert await stub.rpc_square(test_pb2.TestRequest(number=10)) == test_pb2.TestResponse(number=100)
|
|
|
|
|
|
|
|
|
+class StreamUnaryServicer(ServicerBase):
|
|
|
+ async def rpc_sum(
|
|
|
+ self, stream: AsyncIterator[test_pb2.TestRequest], _context: P2PContext
|
|
|
+ ) -> test_pb2.TestResponse:
|
|
|
+ result = 0
|
|
|
+ async for item in stream:
|
|
|
+ result += item.number
|
|
|
+ return test_pb2.TestResponse(number=result)
|
|
|
+
|
|
|
+
|
|
|
@pytest.mark.asyncio
|
|
|
async def test_stream_unary(server_client):
|
|
|
- class ExampleServicer(ServicerBase):
|
|
|
- async def rpc_sum(
|
|
|
- self, stream: AsyncIterator[test_pb2.TestRequest], _context: P2PContext
|
|
|
- ) -> test_pb2.TestResponse:
|
|
|
- result = 0
|
|
|
- async for item in stream:
|
|
|
- result += item.number
|
|
|
- return test_pb2.TestResponse(number=result)
|
|
|
-
|
|
|
server, client = server_client
|
|
|
- servicer = ExampleServicer()
|
|
|
+ servicer = StreamUnaryServicer()
|
|
|
await servicer.add_p2p_handlers(server)
|
|
|
- stub = ExampleServicer.get_stub(client, server.peer_id)
|
|
|
+ stub = StreamUnaryServicer.get_stub(client, server.peer_id)
|
|
|
|
|
|
async def generate_requests() -> AsyncIterator[test_pb2.TestRequest]:
|
|
|
for i in range(10):
|
|
@@ -54,42 +56,40 @@ async def test_stream_unary(server_client):
|
|
|
assert await stub.rpc_sum(generate_requests()) == test_pb2.TestResponse(number=45)
|
|
|
|
|
|
|
|
|
+class UnaryStreamServicer(ServicerBase):
|
|
|
+ async def rpc_count(
|
|
|
+ self, request: test_pb2.TestRequest, _context: P2PContext
|
|
|
+ ) -> AsyncIterator[test_pb2.TestResponse]:
|
|
|
+ for i in range(request.number):
|
|
|
+ yield test_pb2.TestResponse(number=i)
|
|
|
+
|
|
|
+
|
|
|
@pytest.mark.asyncio
|
|
|
async def test_unary_stream(server_client):
|
|
|
- class ExampleServicer(ServicerBase):
|
|
|
- async def rpc_count(
|
|
|
- self, request: test_pb2.TestRequest, _context: P2PContext
|
|
|
- ) -> AsyncIterator[test_pb2.TestResponse]:
|
|
|
- for i in range(request.number):
|
|
|
- yield test_pb2.TestResponse(number=i)
|
|
|
-
|
|
|
server, client = server_client
|
|
|
- servicer = ExampleServicer()
|
|
|
+ servicer = UnaryStreamServicer()
|
|
|
await servicer.add_p2p_handlers(server)
|
|
|
- stub = ExampleServicer.get_stub(client, server.peer_id)
|
|
|
+ stub = UnaryStreamServicer.get_stub(client, server.peer_id)
|
|
|
|
|
|
stream = await stub.rpc_count(test_pb2.TestRequest(number=10))
|
|
|
- i = 0
|
|
|
- async for item in stream:
|
|
|
- assert item == test_pb2.TestResponse(number=i)
|
|
|
- i += 1
|
|
|
- assert i == 10
|
|
|
+ assert [item.number async for item in stream] == list(range(10))
|
|
|
+
|
|
|
+
|
|
|
+class StreamStreamServicer(ServicerBase):
|
|
|
+ async def rpc_powers(
|
|
|
+ self, stream: AsyncIterator[test_pb2.TestRequest], _context: P2PContext
|
|
|
+ ) -> AsyncIterator[test_pb2.TestResponse]:
|
|
|
+ async for item in stream:
|
|
|
+ yield test_pb2.TestResponse(number=item.number**2)
|
|
|
+ yield test_pb2.TestResponse(number=item.number**3)
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
|
async def test_stream_stream(server_client):
|
|
|
- class ExampleServicer(ServicerBase):
|
|
|
- async def rpc_powers(
|
|
|
- self, stream: AsyncIterator[test_pb2.TestRequest], _context: P2PContext
|
|
|
- ) -> AsyncIterator[test_pb2.TestResponse]:
|
|
|
- async for item in stream:
|
|
|
- yield test_pb2.TestResponse(number=item.number**2)
|
|
|
- yield test_pb2.TestResponse(number=item.number**3)
|
|
|
-
|
|
|
server, client = server_client
|
|
|
- servicer = ExampleServicer()
|
|
|
+ servicer = StreamStreamServicer()
|
|
|
await servicer.add_p2p_handlers(server)
|
|
|
- stub = ExampleServicer.get_stub(client, server.peer_id)
|
|
|
+ stub = StreamStreamServicer.get_stub(client, server.peer_id)
|
|
|
|
|
|
async def generate_requests() -> AsyncIterator[test_pb2.TestRequest]:
|
|
|
for i in range(10):
|
|
@@ -153,3 +153,43 @@ async def test_unary_stream_cancel(server_client, cancel_reason):
|
|
|
|
|
|
await asyncio.sleep(0.25)
|
|
|
assert handler_cancelled
|
|
|
+
|
|
|
+
|
|
|
+@pytest.mark.asyncio
|
|
|
+async def test_removing_unary_handlers(server_client):
|
|
|
+ server1, client = server_client
|
|
|
+ server2 = await P2P.replicate(server1.daemon_listen_maddr)
|
|
|
+ servicer = UnaryUnaryServicer()
|
|
|
+ stub = UnaryUnaryServicer.get_stub(client, server1.peer_id)
|
|
|
+
|
|
|
+ for server in [server1, server2, server1]:
|
|
|
+ await servicer.add_p2p_handlers(server)
|
|
|
+ assert await stub.rpc_square(test_pb2.TestRequest(number=10)) == test_pb2.TestResponse(number=100)
|
|
|
+
|
|
|
+ await servicer.remove_p2p_handlers(server)
|
|
|
+ with pytest.raises((P2PDaemonError, ConnectionError)):
|
|
|
+ await stub.rpc_square(test_pb2.TestRequest(number=10))
|
|
|
+
|
|
|
+ await asyncio.gather(server2.shutdown())
|
|
|
+
|
|
|
+
|
|
|
+@pytest.mark.asyncio
|
|
|
+async def test_removing_stream_handlers(server_client):
|
|
|
+ server1, client = server_client
|
|
|
+ server2 = await P2P.replicate(server1.daemon_listen_maddr)
|
|
|
+ servicer = UnaryStreamServicer()
|
|
|
+ stub = UnaryStreamServicer.get_stub(client, server1.peer_id)
|
|
|
+
|
|
|
+ for server in [server1, server2, server1]:
|
|
|
+ await servicer.add_p2p_handlers(server)
|
|
|
+ stream = await stub.rpc_count(test_pb2.TestRequest(number=10))
|
|
|
+ assert [item.number async for item in stream] == list(range(10))
|
|
|
+
|
|
|
+ await servicer.remove_p2p_handlers(server)
|
|
|
+ with pytest.raises((P2PDaemonError, ConnectionError)):
|
|
|
+ stream = await stub.rpc_count(test_pb2.TestRequest(number=10))
|
|
|
+ outputs = [item.number async for item in stream]
|
|
|
+ if not outputs:
|
|
|
+ raise P2PDaemonError("Daemon has reset the connection")
|
|
|
+
|
|
|
+ await asyncio.gather(server2.shutdown())
|