浏览代码

fix suggesstions

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
Denis Mazur 4 年之前
父节点
当前提交
2292d62119
共有 3 个文件被更改,包括 12 次插入13 次删除
  1. 1 2
      hivemind/p2p/p2p_daemon.py
  2. 7 7
      hivemind/p2p/p2p_daemon_bindings/control.py
  3. 4 4
      tests/test_p2p_daemon_bindings.py

+ 1 - 2
hivemind/p2p/p2p_daemon.py

@@ -446,8 +446,7 @@ class P2P:
         if not isinstance(input, AsyncIterableABC):
         if not isinstance(input, AsyncIterableABC):
             return await self._call_unary_protobuf_handler(peer_id, name, input, output_protobuf_type)
             return await self._call_unary_protobuf_handler(peer_id, name, input, output_protobuf_type)
 
 
-        requests = input if isinstance(input, AsyncIterableABC) else aiter(input)
-        responses = self._iterate_protobuf_stream_handler(peer_id, name, requests, output_protobuf_type)
+        responses = self._iterate_protobuf_stream_handler(peer_id, name, input, output_protobuf_type)
         return await asingle(responses)
         return await asingle(responses)
 
 
     async def _call_unary_protobuf_handler(
     async def _call_unary_protobuf_handler(

+ 7 - 7
hivemind/p2p/p2p_daemon_bindings/control.py

@@ -5,7 +5,7 @@ Author: Kevin Mai-Husan Chia
 """
 """
 
 
 import asyncio
 import asyncio
-import uuid
+from uuid import UUID, uuid4
 from contextlib import asynccontextmanager, closing
 from contextlib import asynccontextmanager, closing
 from typing import AsyncIterator, Awaitable, Callable, Dict, Iterable, Optional, Sequence, Tuple
 from typing import AsyncIterator, Awaitable, Callable, Dict, Iterable, Optional, Sequence, Tuple
 
 
@@ -67,7 +67,7 @@ class DaemonConnector:
 
 
 
 
 TUnaryHandler = Callable[[bytes, PeerID], Awaitable[bytes]]
 TUnaryHandler = Callable[[bytes, PeerID], Awaitable[bytes]]
-CallID = uuid.UUID
+CallID = UUID
 
 
 
 
 class ControlClient:
 class ControlClient:
@@ -149,7 +149,7 @@ class ControlClient:
             resp = p2pd_pb.PersistentConnectionResponse()
             resp = p2pd_pb.PersistentConnectionResponse()
             await read_pbmsg_safe(reader, resp)
             await read_pbmsg_safe(reader, resp)
 
 
-            call_id = uuid.UUID(bytes=resp.callId)
+            call_id = UUID(bytes=resp.callId)
 
 
             if resp.HasField("callUnaryResponse"):
             if resp.HasField("callUnaryResponse"):
                 if call_id in self._pending_calls and resp.callUnaryResponse.HasField("response"):
                 if call_id in self._pending_calls and resp.callUnaryResponse.HasField("response"):
@@ -158,7 +158,7 @@ class ControlClient:
                     remote_exc = P2PHandlerError(resp.callUnaryResponse.error.decode(errors="ignore"))
                     remote_exc = P2PHandlerError(resp.callUnaryResponse.error.decode(errors="ignore"))
                     self._pending_calls[call_id].set_exception(remote_exc)
                     self._pending_calls[call_id].set_exception(remote_exc)
                 else:
                 else:
-                    logger.debug("received unexpected unary call")
+                    logger.debug("received unexpected unary call:", resp)
 
 
             elif resp.HasField("requestHandling"):
             elif resp.HasField("requestHandling"):
                 handler_task = asyncio.create_task(self._handle_persistent_request(call_id, resp.requestHandling))
                 handler_task = asyncio.create_task(self._handle_persistent_request(call_id, resp.requestHandling))
@@ -173,7 +173,7 @@ class ControlClient:
                 msg = await self._pending_messages.get()
                 msg = await self._pending_messages.get()
                 await write_pbmsg(writer, msg)
                 await write_pbmsg(writer, msg)
 
 
-    async def _handle_persistent_request(self, call_id: uuid.UUID, request: p2pd_pb.CallUnaryRequest):
+    async def _handle_persistent_request(self, call_id: UUID, request: p2pd_pb.CallUnaryRequest):
         if request.proto not in self.unary_handlers:
         if request.proto not in self.unary_handlers:
             logger.warning(f"Protocol {request.proto} not supported")
             logger.warning(f"Protocol {request.proto} not supported")
             return
             return
@@ -194,7 +194,7 @@ class ControlClient:
         )
         )
         self._handler_tasks.pop(call_id)
         self._handler_tasks.pop(call_id)
 
 
-    async def _cancel_unary_call(self, call_id: uuid.UUID):
+    async def _cancel_unary_call(self, call_id: UUID):
         await self._pending_messages.put(
         await self._pending_messages.put(
             p2pd_pb.PersistentConnectionRequest(
             p2pd_pb.PersistentConnectionRequest(
                 callId=call_id.bytes,
                 callId=call_id.bytes,
@@ -211,7 +211,7 @@ class ControlClient:
         self._is_persistent_conn_open = True
         self._is_persistent_conn_open = True
 
 
     async def add_unary_handler(self, proto: str, handler: TUnaryHandler):
     async def add_unary_handler(self, proto: str, handler: TUnaryHandler):
-        call_id = uuid.uuid4()
+        call_id = uuid4()
 
 
         add_unary_handler_req = p2pd_pb.AddUnaryHandlerRequest(proto=proto)
         add_unary_handler_req = p2pd_pb.AddUnaryHandlerRequest(proto=proto)
         req = p2pd_pb.PersistentConnectionRequest(callId=call_id.bytes, addUnaryHandler=add_unary_handler_req)
         req = p2pd_pb.PersistentConnectionRequest(callId=call_id.bytes, addUnaryHandler=add_unary_handler_req)

+ 4 - 4
tests/test_p2p_daemon_bindings.py

@@ -194,19 +194,19 @@ def test_parse_conn_protocol_invalid(maddr_str):
 
 
 @pytest.mark.parametrize("control_maddr_str", ("/unix/123", "/ip4/127.0.0.1/tcp/6666"))
 @pytest.mark.parametrize("control_maddr_str", ("/unix/123", "/ip4/127.0.0.1/tcp/6666"))
 @pytest.mark.asyncio
 @pytest.mark.asyncio
-async def test_client_ctor_control_maddr(control_maddr_str):
+async def test_client_create_control_maddr(control_maddr_str):
     c = DaemonConnector(Multiaddr(control_maddr_str))
     c = DaemonConnector(Multiaddr(control_maddr_str))
     assert c.control_maddr == Multiaddr(control_maddr_str)
     assert c.control_maddr == Multiaddr(control_maddr_str)
 
 
 
 
-def test_client_ctor_default_control_maddr():
+def test_client_create_default_control_maddr():
     c = DaemonConnector()
     c = DaemonConnector()
     assert c.control_maddr == Multiaddr(DaemonConnector.DEFAULT_CONTROL_MADDR)
     assert c.control_maddr == Multiaddr(DaemonConnector.DEFAULT_CONTROL_MADDR)
 
 
 
 
 @pytest.mark.parametrize("listen_maddr_str", ("/unix/123", "/ip4/127.0.0.1/tcp/6666"))
 @pytest.mark.parametrize("listen_maddr_str", ("/unix/123", "/ip4/127.0.0.1/tcp/6666"))
 @pytest.mark.asyncio
 @pytest.mark.asyncio
-async def test_control_client_ctor_listen_maddr(listen_maddr_str):
+async def test_control_client_create_listen_maddr(listen_maddr_str):
     c = await ControlClient.create(
     c = await ControlClient.create(
         daemon_connector=DaemonConnector(),
         daemon_connector=DaemonConnector(),
         listen_maddr=Multiaddr(listen_maddr_str),
         listen_maddr=Multiaddr(listen_maddr_str),
@@ -216,7 +216,7 @@ async def test_control_client_ctor_listen_maddr(listen_maddr_str):
 
 
 
 
 @pytest.mark.asyncio
 @pytest.mark.asyncio
-async def test_control_client_ctor_default_listen_maddr():
+async def test_control_client_create_default_listen_maddr():
     c = await ControlClient.create(daemon_connector=DaemonConnector(), use_persistent_conn=False)
     c = await ControlClient.create(daemon_connector=DaemonConnector(), use_persistent_conn=False)
     assert c.listen_maddr == Multiaddr(ControlClient.DEFAULT_LISTEN_MADDR)
     assert c.listen_maddr == Multiaddr(ControlClient.DEFAULT_LISTEN_MADDR)