|
@@ -28,7 +28,6 @@ class P2PContext(object):
|
|
|
handle_name: str
|
|
|
local_id: PeerID
|
|
|
remote_id: PeerID = None
|
|
|
- remote_maddr: Multiaddr = None
|
|
|
|
|
|
|
|
|
class P2P:
|
|
@@ -169,11 +168,11 @@ class P2P:
|
|
|
|
|
|
return self
|
|
|
|
|
|
- async def add_unary_handler(self, proto: str, handler: p2pclient.TUnaryHandler):
|
|
|
- return await self._client.add_unary_handler(proto, handler)
|
|
|
+ async def add_unary_handler(self, handle_name: str, handler: p2pclient.TUnaryHandler):
|
|
|
+ return await self._client.add_unary_handler(handle_name, handler)
|
|
|
|
|
|
- async def unary_call(self, peer_id: PeerID, proto: str, data: bytes) -> bytes:
|
|
|
- return await self._client.unary_call(peer_id, proto, data)
|
|
|
+ async def unary_call(self, peer_id: PeerID, handle_name: str, data: bytes) -> bytes:
|
|
|
+ return await self._client.unary_call(peer_id, handle_name, data)
|
|
|
|
|
|
async def _ping_daemon_with_retries(self, ping_n_attempts: int, ping_delay: float) -> None:
|
|
|
for try_number in range(ping_n_attempts):
|
|
@@ -323,7 +322,6 @@ class P2P:
|
|
|
handle_name=name,
|
|
|
local_id=self.id,
|
|
|
remote_id=stream_info.peer_id,
|
|
|
- remote_maddr=stream_info.addr,
|
|
|
)
|
|
|
requests = asyncio.Queue(max_prefetch)
|
|
|
|
|
@@ -408,16 +406,12 @@ class P2P:
|
|
|
(not just ``TInputProtobuf``) as input.
|
|
|
"""
|
|
|
|
|
|
- async def _stream_handler(requests: P2P.TInputStream, context: P2PContext) -> P2P.TOutputStream:
|
|
|
- if stream_input:
|
|
|
- input = requests
|
|
|
- else:
|
|
|
- count = 0
|
|
|
- async for input in requests:
|
|
|
- count += 1
|
|
|
- if count != 1:
|
|
|
- raise ValueError(f"Got {count} requests for handler {name} instead of one")
|
|
|
+ if not stream_input:
|
|
|
+ await self._add_protobuf_unary_handler(name, handler, input_protobuf_type)
|
|
|
+ return
|
|
|
|
|
|
+ async def _stream_handler(requests: P2P.TInputStream, context: P2PContext) -> P2P.TOutputStream:
|
|
|
+ input = requests
|
|
|
output = handler(input, context)
|
|
|
|
|
|
if isinstance(output, AsyncIterableABC):
|
|
@@ -428,6 +422,27 @@ class P2P:
|
|
|
|
|
|
await self._add_protobuf_stream_handler(name, _stream_handler, input_protobuf_type)
|
|
|
|
|
|
+ # only registers request-response handlers
|
|
|
+ async def _add_protobuf_unary_handler(
|
|
|
+ self,
|
|
|
+ handle_name: str,
|
|
|
+ handler: Callable[[TInputProtobuf, P2PContext], Awaitable[TOutputProtobuf]],
|
|
|
+ input_protobuf_type: type,
|
|
|
+ ) -> None:
|
|
|
+
|
|
|
+ async def _unary_handler(request: bytes) -> bytes:
|
|
|
+ input_serialized = input_protobuf_type().FromString(request)
|
|
|
+ context = P2PContext(
|
|
|
+ handle_name=handle_name,
|
|
|
+ local_id=self.id,
|
|
|
+ # TODO: add remote id
|
|
|
+ )
|
|
|
+
|
|
|
+ response = await handler(input_serialized, context)
|
|
|
+ return response.SerializeToString()
|
|
|
+
|
|
|
+ await self.add_unary_handler(handle_name, _unary_handler)
|
|
|
+
|
|
|
async def call_protobuf_handler(
|
|
|
self,
|
|
|
peer_id: PeerID,
|
|
@@ -435,8 +450,11 @@ class P2P:
|
|
|
input: Union[TInputProtobuf, TInputStream],
|
|
|
output_protobuf_type: type,
|
|
|
) -> Awaitable[TOutputProtobuf]:
|
|
|
- requests = input if isinstance(input, AsyncIterableABC) else aiter(input)
|
|
|
- responses = self._iterate_protobuf_stream_handler(peer_id, name, requests, output_protobuf_type)
|
|
|
+
|
|
|
+ if not isinstance(input, AsyncIterableABC):
|
|
|
+ return self._call_unary_protobuf_handler(peer_id, name, input, output_protobuf_type)
|
|
|
+
|
|
|
+ responses = self._iterate_protobuf_stream_handler(peer_id, name, input, output_protobuf_type)
|
|
|
|
|
|
count = 0
|
|
|
async for response in responses:
|
|
@@ -445,6 +463,18 @@ class P2P:
|
|
|
raise ValueError(f"Got {count} responses from handler {name} instead of one")
|
|
|
return response
|
|
|
|
|
|
+ async def _call_unary_protobuf_handler(
|
|
|
+ self,
|
|
|
+ peer_id: PeerID,
|
|
|
+ handle_name: str,
|
|
|
+ input: TInputProtobuf,
|
|
|
+ output_protobuf_type: type,
|
|
|
+ ) -> Awaitable[TOutputProtobuf]:
|
|
|
+ serialized_input = input.SerializeToString()
|
|
|
+ response = await self.unary_call(peer_id, handle_name, serialized_input)
|
|
|
+ return output_protobuf_type().FromString(response)
|
|
|
+
|
|
|
+
|
|
|
def iterate_protobuf_handler(
|
|
|
self,
|
|
|
peer_id: PeerID,
|