Denis Mazur пре 4 година
родитељ
комит
7d8eb40716
2 измењених фајлова са 9 додато и 9 уклоњено
  1. 0 0
      hivemind/moe/server/connection_handler_p2p.py
  2. 9 9
      hivemind/p2p/p2p_daemon.py

+ 0 - 0
hivemind/moe/server/connection_handler_p2p.py


+ 9 - 9
hivemind/p2p/p2p_daemon.py

@@ -5,7 +5,7 @@ from collections.abc import AsyncIterable as AsyncIterableABC
 from contextlib import closing, suppress
 from contextlib import closing, suppress
 from dataclasses import dataclass
 from dataclasses import dataclass
 from importlib.resources import path
 from importlib.resources import path
-from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
+from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional, Sequence, Tuple, Type, TypeVar, Union
 
 
 from google.protobuf.message import Message
 from google.protobuf.message import Message
 from multiaddr import Multiaddr
 from multiaddr import Multiaddr
@@ -263,7 +263,7 @@ class P2P:
 
 
     @staticmethod
     @staticmethod
     async def receive_protobuf(
     async def receive_protobuf(
-        input_protobuf_type: Message, reader: asyncio.StreamReader
+        input_protobuf_type: Type[Message], reader: asyncio.StreamReader
     ) -> Tuple[Optional[TInputProtobuf], Optional[RPCError]]:
     ) -> Tuple[Optional[TInputProtobuf], Optional[RPCError]]:
         msg_type = await reader.readexactly(1)
         msg_type = await reader.readexactly(1)
         if msg_type == P2P.MESSAGE_MARKER:
         if msg_type == P2P.MESSAGE_MARKER:
@@ -284,7 +284,7 @@ class P2P:
         self,
         self,
         name: str,
         name: str,
         handler: Callable[[TInputStream, P2PContext], TOutputStream],
         handler: Callable[[TInputStream, P2PContext], TOutputStream],
-        input_protobuf_type: Message,
+        input_protobuf_type: Type[Message],
         max_prefetch: int = 5,
         max_prefetch: int = 5,
     ) -> None:
     ) -> None:
         """
         """
@@ -347,7 +347,7 @@ class P2P:
         await self.add_binary_stream_handler(name, _handle_stream)
         await self.add_binary_stream_handler(name, _handle_stream)
 
 
     async def _iterate_protobuf_stream_handler(
     async def _iterate_protobuf_stream_handler(
-        self, peer_id: PeerID, name: str, requests: TInputStream, output_protobuf_type: Message
+        self, peer_id: PeerID, name: str, requests: TInputStream, output_protobuf_type: Type[Message]
     ) -> TOutputStream:
     ) -> TOutputStream:
         _, reader, writer = await self.call_binary_stream_handler(peer_id, name)
         _, reader, writer = await self.call_binary_stream_handler(peer_id, name)
 
 
@@ -379,7 +379,7 @@ class P2P:
         handler: Callable[
         handler: Callable[
             [Union[TInputProtobuf, TInputStream], P2PContext], Union[Awaitable[TOutputProtobuf], TOutputStream]
             [Union[TInputProtobuf, TInputStream], P2PContext], Union[Awaitable[TOutputProtobuf], TOutputStream]
         ],
         ],
-        input_protobuf_type: Message,
+        input_protobuf_type: Type[Message],
         *,
         *,
         stream_input: bool = False,
         stream_input: bool = False,
         stream_output: bool = False,
         stream_output: bool = False,
@@ -411,7 +411,7 @@ class P2P:
         self,
         self,
         handle_name: str,
         handle_name: str,
         handler: Callable[[TInputProtobuf, P2PContext], Awaitable[TOutputProtobuf]],
         handler: Callable[[TInputProtobuf, P2PContext], Awaitable[TOutputProtobuf]],
-        input_protobuf_type: Message,
+        input_protobuf_type: Type[Message],
     ) -> None:
     ) -> None:
         """
         """
         Register a request-response (unary) handler. Unary requests and responses
         Register a request-response (unary) handler. Unary requests and responses
@@ -440,7 +440,7 @@ class P2P:
         peer_id: PeerID,
         peer_id: PeerID,
         name: str,
         name: str,
         input: Union[TInputProtobuf, TInputStream],
         input: Union[TInputProtobuf, TInputStream],
-        output_protobuf_type: Message,
+        output_protobuf_type: Type[Message],
     ) -> Awaitable[TOutputProtobuf]:
     ) -> Awaitable[TOutputProtobuf]:
 
 
         if not isinstance(input, AsyncIterableABC):
         if not isinstance(input, AsyncIterableABC):
@@ -454,7 +454,7 @@ class P2P:
         peer_id: PeerID,
         peer_id: PeerID,
         handle_name: str,
         handle_name: str,
         input: TInputProtobuf,
         input: TInputProtobuf,
-        output_protobuf_type: Message,
+        output_protobuf_type: Type[Message],
     ) -> Awaitable[TOutputProtobuf]:
     ) -> Awaitable[TOutputProtobuf]:
         serialized_input = input.SerializeToString()
         serialized_input = input.SerializeToString()
         response = await self._client.call_unary_handler(peer_id, handle_name, serialized_input)
         response = await self._client.call_unary_handler(peer_id, handle_name, serialized_input)
@@ -465,7 +465,7 @@ class P2P:
         peer_id: PeerID,
         peer_id: PeerID,
         name: str,
         name: str,
         input: Union[TInputProtobuf, TInputStream],
         input: Union[TInputProtobuf, TInputStream],
-        output_protobuf_type: Message,
+        output_protobuf_type: Type[Message],
     ) -> TOutputStream:
     ) -> TOutputStream:
         requests = input if isinstance(input, AsyncIterableABC) else aiter(input)
         requests = input if isinstance(input, AsyncIterableABC) else aiter(input)
         return self._iterate_protobuf_stream_handler(peer_id, name, requests, output_protobuf_type)
         return self._iterate_protobuf_stream_handler(peer_id, name, requests, output_protobuf_type)