Denis Mazur 4 жил өмнө
parent
commit
7d8eb40716

+ 0 - 0
hivemind/moe/server/connection_handler_p2p.py


+ 9 - 9
hivemind/p2p/p2p_daemon.py

@@ -5,7 +5,7 @@ from collections.abc import AsyncIterable as AsyncIterableABC
 from contextlib import closing, suppress
 from dataclasses import dataclass
 from importlib.resources import path
-from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
+from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional, Sequence, Tuple, Type, TypeVar, Union
 
 from google.protobuf.message import Message
 from multiaddr import Multiaddr
@@ -263,7 +263,7 @@ class P2P:
 
     @staticmethod
     async def receive_protobuf(
-        input_protobuf_type: Message, reader: asyncio.StreamReader
+        input_protobuf_type: Type[Message], reader: asyncio.StreamReader
     ) -> Tuple[Optional[TInputProtobuf], Optional[RPCError]]:
         msg_type = await reader.readexactly(1)
         if msg_type == P2P.MESSAGE_MARKER:
@@ -284,7 +284,7 @@ class P2P:
         self,
         name: str,
         handler: Callable[[TInputStream, P2PContext], TOutputStream],
-        input_protobuf_type: Message,
+        input_protobuf_type: Type[Message],
         max_prefetch: int = 5,
     ) -> None:
         """
@@ -347,7 +347,7 @@ class P2P:
         await self.add_binary_stream_handler(name, _handle_stream)
 
     async def _iterate_protobuf_stream_handler(
-        self, peer_id: PeerID, name: str, requests: TInputStream, output_protobuf_type: Message
+        self, peer_id: PeerID, name: str, requests: TInputStream, output_protobuf_type: Type[Message]
     ) -> TOutputStream:
         _, reader, writer = await self.call_binary_stream_handler(peer_id, name)
 
@@ -379,7 +379,7 @@ class P2P:
         handler: Callable[
             [Union[TInputProtobuf, TInputStream], P2PContext], Union[Awaitable[TOutputProtobuf], TOutputStream]
         ],
-        input_protobuf_type: Message,
+        input_protobuf_type: Type[Message],
         *,
         stream_input: bool = False,
         stream_output: bool = False,
@@ -411,7 +411,7 @@ class P2P:
         self,
         handle_name: str,
         handler: Callable[[TInputProtobuf, P2PContext], Awaitable[TOutputProtobuf]],
-        input_protobuf_type: Message,
+        input_protobuf_type: Type[Message],
     ) -> None:
         """
         Register a request-response (unary) handler. Unary requests and responses
@@ -440,7 +440,7 @@ class P2P:
         peer_id: PeerID,
         name: str,
         input: Union[TInputProtobuf, TInputStream],
-        output_protobuf_type: Message,
+        output_protobuf_type: Type[Message],
     ) -> Awaitable[TOutputProtobuf]:
 
         if not isinstance(input, AsyncIterableABC):
@@ -454,7 +454,7 @@ class P2P:
         peer_id: PeerID,
         handle_name: str,
         input: TInputProtobuf,
-        output_protobuf_type: Message,
+        output_protobuf_type: Type[Message],
     ) -> Awaitable[TOutputProtobuf]:
         serialized_input = input.SerializeToString()
         response = await self._client.call_unary_handler(peer_id, handle_name, serialized_input)
@@ -465,7 +465,7 @@ class P2P:
         peer_id: PeerID,
         name: str,
         input: Union[TInputProtobuf, TInputStream],
-        output_protobuf_type: Message,
+        output_protobuf_type: Type[Message],
     ) -> TOutputStream:
         requests = input if isinstance(input, AsyncIterableABC) else aiter(input)
         return self._iterate_protobuf_stream_handler(peer_id, name, requests, output_protobuf_type)