浏览代码

pass remote id to handler context

Denis Mazur 4 年之前
父节点
当前提交
864cb3e3ac

+ 7 - 6
hivemind/p2p/p2p_daemon.py

@@ -4,6 +4,7 @@ import secrets
 from collections.abc import AsyncIterable as AsyncIterableABC
 from collections.abc import AsyncIterable as AsyncIterableABC
 from contextlib import closing, suppress
 from contextlib import closing, suppress
 from dataclasses import dataclass
 from dataclasses import dataclass
+from google.protobuf.message import Message
 from importlib.resources import path
 from importlib.resources import path
 from subprocess import Popen
 from subprocess import Popen
 from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
 from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
@@ -171,8 +172,8 @@ class P2P:
     async def add_unary_handler(self, handle_name: str, handler: p2pclient.TUnaryHandler):
     async def add_unary_handler(self, handle_name: str, handler: p2pclient.TUnaryHandler):
         return await self._client.add_unary_handler(handle_name, handler)
         return await self._client.add_unary_handler(handle_name, handler)
 
 
-    async def unary_call(self, peer_id: PeerID, handle_name: str, data: bytes) -> bytes:
-        return await self._client.unary_call(peer_id, handle_name, data)
+    async def call_unary_handler(self, peer_id: PeerID, handle_name: str, data: bytes) -> bytes:
+        return await self._client.call_unary_handler(peer_id, handle_name, data)
 
 
     async def _ping_daemon_with_retries(self, ping_n_attempts: int, ping_delay: float) -> None:
     async def _ping_daemon_with_retries(self, ping_n_attempts: int, ping_delay: float) -> None:
         for try_number in range(ping_n_attempts):
         for try_number in range(ping_n_attempts):
@@ -429,12 +430,12 @@ class P2P:
         handler: Callable[[TInputProtobuf, P2PContext], Awaitable[TOutputProtobuf]],
         handler: Callable[[TInputProtobuf, P2PContext], Awaitable[TOutputProtobuf]],
         input_protobuf_type: type,
         input_protobuf_type: type,
     ) -> None:
     ) -> None:
-        async def _unary_handler(request: bytes) -> bytes:
-            input_serialized = input_protobuf_type().FromString(request)
+        async def _unary_handler(request: bytes, remote_id: PeerID) -> bytes:
+            input_serialized = input_protobuf_type.FromString(request)
             context = P2PContext(
             context = P2PContext(
                 handle_name=handle_name,
                 handle_name=handle_name,
                 local_id=self.id,
                 local_id=self.id,
-                # TODO: add remote id
+                remote_id=remote_id,
             )
             )
 
 
             response = await handler(input_serialized, context)
             response = await handler(input_serialized, context)
@@ -470,7 +471,7 @@ class P2P:
         output_protobuf_type: type,
         output_protobuf_type: type,
     ) -> Awaitable[TOutputProtobuf]:
     ) -> Awaitable[TOutputProtobuf]:
         serialized_input = input.SerializeToString()
         serialized_input = input.SerializeToString()
-        response = await self.unary_call(peer_id, handle_name, serialized_input)
+        response = await self.call_unary_handler(peer_id, handle_name, serialized_input)
         return output_protobuf_type().FromString(response)
         return output_protobuf_type().FromString(response)
 
 
     def iterate_protobuf_handler(
     def iterate_protobuf_handler(

+ 4 - 3
hivemind/p2p/p2p_daemon_bindings/control.py

@@ -66,7 +66,7 @@ class DaemonConnector:
         return reader, writer
         return reader, writer
 
 
 
 
-TUnaryHandler = Callable[[bytes], Awaitable[bytes]]
+TUnaryHandler = Callable[[bytes, PeerID], Awaitable[bytes]]
 CallID = uuid.UUID
 CallID = uuid.UUID
 
 
 
 
@@ -132,7 +132,8 @@ class ControlClient:
         assert request.proto in self.unary_handlers
         assert request.proto in self.unary_handlers
 
 
         try:
         try:
-            response_payload: bytes = self.unary_handlers[request.proto](request.data)
+            remote_id = PeerID(request.peer)
+            response_payload: bytes = self.unary_handlers[request.proto](request.data, remote_id)
             response = p2pd_pb.CallUnaryResponse(callId=request.callId, result=response_payload)
             response = p2pd_pb.CallUnaryResponse(callId=request.callId, result=response_payload)
         except Exception as e:
         except Exception as e:
             response = p2pd_pb.CallUnaryResponse(callId=request.callId, error=repr(e))
             response = p2pd_pb.CallUnaryResponse(callId=request.callId, error=repr(e))
@@ -174,7 +175,7 @@ class ControlClient:
             raise ValueError(f"Handler for protocol {proto} already assigned")
             raise ValueError(f"Handler for protocol {proto} already assigned")
         self.unary_handlers[proto] = handler
         self.unary_handlers[proto] = handler
 
 
-    async def unary_call(self, peer_id: PeerID, proto: str, data: bytes) -> bytes:
+    async def call_unary_handler(self, peer_id: PeerID, proto: str, data: bytes) -> bytes:
         call_id = uuid.uuid4()
         call_id = uuid.uuid4()
         call_unary_req = p2pd_pb.CallUnaryRequest(
         call_unary_req = p2pd_pb.CallUnaryRequest(
             peer=peer_id.to_bytes(),
             peer=peer_id.to_bytes(),

+ 2 - 2
hivemind/p2p/p2p_daemon_bindings/p2pclient.py

@@ -33,8 +33,8 @@ class Client:
     async def add_unary_handler(self, proto: str, handler: TUnaryHandler):
     async def add_unary_handler(self, proto: str, handler: TUnaryHandler):
         await self.control.add_unary_handler(proto, handler)
         await self.control.add_unary_handler(proto, handler)
 
 
-    async def unary_call(self, peer_id: PeerID, proto: str, data: bytes) -> bytes:
-        return await self.control.unary_call(peer_id, proto, data)
+    async def call_unary_handler(self, peer_id: PeerID, proto: str, data: bytes) -> bytes:
+        return await self.control.call_unary_handler(peer_id, proto, data)
 
 
     async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
     async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
         """
         """

+ 1 - 1
hivemind/proto/p2pd.proto

@@ -161,7 +161,7 @@ message PSRequest {
 
 
 
 
 message PSMessage {
 message PSMessage {
-  optional bytes from = 1;
+  optional bytes from_id = 1;
   optional bytes data = 2;
   optional bytes data = 2;
   optional bytes seqno = 3;
   optional bytes seqno = 3;
   repeated string topicIDs = 4;
   repeated string topicIDs = 4;