Browse Source

pass remote id to handler context

Denis Mazur 4 năm trước cách đây
mục cha
commit
864cb3e3ac

+ 7 - 6
hivemind/p2p/p2p_daemon.py

@@ -4,6 +4,7 @@ import secrets
 from collections.abc import AsyncIterable as AsyncIterableABC
 from contextlib import closing, suppress
 from dataclasses import dataclass
+from google.protobuf.message import Message
 from importlib.resources import path
 from subprocess import Popen
 from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
@@ -171,8 +172,8 @@ class P2P:
     async def add_unary_handler(self, handle_name: str, handler: p2pclient.TUnaryHandler):
         return await self._client.add_unary_handler(handle_name, handler)
 
-    async def unary_call(self, peer_id: PeerID, handle_name: str, data: bytes) -> bytes:
-        return await self._client.unary_call(peer_id, handle_name, data)
+    async def call_unary_handler(self, peer_id: PeerID, handle_name: str, data: bytes) -> bytes:
+        return await self._client.call_unary_handler(peer_id, handle_name, data)
 
     async def _ping_daemon_with_retries(self, ping_n_attempts: int, ping_delay: float) -> None:
         for try_number in range(ping_n_attempts):
@@ -429,12 +430,12 @@ class P2P:
         handler: Callable[[TInputProtobuf, P2PContext], Awaitable[TOutputProtobuf]],
         input_protobuf_type: type,
     ) -> None:
-        async def _unary_handler(request: bytes) -> bytes:
-            input_serialized = input_protobuf_type().FromString(request)
+        async def _unary_handler(request: bytes, remote_id: PeerID) -> bytes:
+            input_serialized = input_protobuf_type.FromString(request)
             context = P2PContext(
                 handle_name=handle_name,
                 local_id=self.id,
-                # TODO: add remote id
+                remote_id=remote_id,
             )
 
             response = await handler(input_serialized, context)
@@ -470,7 +471,7 @@ class P2P:
         output_protobuf_type: type,
     ) -> Awaitable[TOutputProtobuf]:
         serialized_input = input.SerializeToString()
-        response = await self.unary_call(peer_id, handle_name, serialized_input)
+        response = await self.call_unary_handler(peer_id, handle_name, serialized_input)
         return output_protobuf_type().FromString(response)
 
     def iterate_protobuf_handler(

+ 4 - 3
hivemind/p2p/p2p_daemon_bindings/control.py

@@ -66,7 +66,7 @@ class DaemonConnector:
         return reader, writer
 
 
-TUnaryHandler = Callable[[bytes], Awaitable[bytes]]
+TUnaryHandler = Callable[[bytes, PeerID], Awaitable[bytes]]
 CallID = uuid.UUID
 
 
@@ -132,7 +132,8 @@ class ControlClient:
         assert request.proto in self.unary_handlers
 
         try:
-            response_payload: bytes = self.unary_handlers[request.proto](request.data)
+            remote_id = PeerID(request.peer)
+            response_payload: bytes = self.unary_handlers[request.proto](request.data, remote_id)
             response = p2pd_pb.CallUnaryResponse(callId=request.callId, result=response_payload)
         except Exception as e:
             response = p2pd_pb.CallUnaryResponse(callId=request.callId, error=repr(e))
@@ -174,7 +175,7 @@ class ControlClient:
             raise ValueError(f"Handler for protocol {proto} already assigned")
         self.unary_handlers[proto] = handler
 
-    async def unary_call(self, peer_id: PeerID, proto: str, data: bytes) -> bytes:
+    async def call_unary_handler(self, peer_id: PeerID, proto: str, data: bytes) -> bytes:
         call_id = uuid.uuid4()
         call_unary_req = p2pd_pb.CallUnaryRequest(
             peer=peer_id.to_bytes(),

+ 2 - 2
hivemind/p2p/p2p_daemon_bindings/p2pclient.py

@@ -33,8 +33,8 @@ class Client:
     async def add_unary_handler(self, proto: str, handler: TUnaryHandler):
         await self.control.add_unary_handler(proto, handler)
 
-    async def unary_call(self, peer_id: PeerID, proto: str, data: bytes) -> bytes:
-        return await self.control.unary_call(peer_id, proto, data)
+    async def call_unary_handler(self, peer_id: PeerID, proto: str, data: bytes) -> bytes:
+        return await self.control.call_unary_handler(peer_id, proto, data)
 
     async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
         """

+ 1 - 1
hivemind/proto/p2pd.proto

@@ -161,7 +161,7 @@ message PSRequest {
 
 
 message PSMessage {
-  optional bytes from = 1;
+  optional bytes from_id = 1;
   optional bytes data = 2;
   optional bytes seqno = 3;
   repeated string topicIDs = 4;