|
@@ -1,17 +1,27 @@
|
|
|
import asyncio
|
|
|
-import contextlib
|
|
|
import copy
|
|
|
from pathlib import Path
|
|
|
import pickle
|
|
|
-import socket
|
|
|
import subprocess
|
|
|
import typing as tp
|
|
|
import warnings
|
|
|
|
|
|
+import google.protobuf
|
|
|
from multiaddr import Multiaddr
|
|
|
import p2pclient
|
|
|
from libp2p.peer.id import ID
|
|
|
|
|
|
+from hivemind.utils.networking import find_open_port
|
|
|
+
|
|
|
+
|
|
|
+class P2PContext(object):
|
|
|
+ def __init__(self, ours_id, ours_port, handle_name):
|
|
|
+ self.peer_id = None
|
|
|
+ self.peer_addr = None
|
|
|
+ self.ours_id = ours_id
|
|
|
+ self.ours_port = ours_port
|
|
|
+ self.handle_name = handle_name
|
|
|
+
|
|
|
|
|
|
class P2P(object):
|
|
|
"""
|
|
@@ -26,11 +36,16 @@ class P2P(object):
|
|
|
HEADER_LEN = 8
|
|
|
BYTEORDER = 'big'
|
|
|
|
|
|
+ class IncompleteRead(Exception):
|
|
|
+ pass
|
|
|
+
|
|
|
+ class InterruptedError(Exception):
|
|
|
+ pass
|
|
|
+
|
|
|
def __init__(self):
|
|
|
self._child = None
|
|
|
self._listen_task = None
|
|
|
self._server_stopped = asyncio.Event()
|
|
|
- self._buffer = bytearray()
|
|
|
|
|
|
@classmethod
|
|
|
async def create(cls, *args, quic=1, tls=1, conn_manager=1, dht_client=1,
|
|
@@ -89,50 +104,108 @@ class P2P(object):
|
|
|
self._daemon_listen_port = find_open_port()
|
|
|
|
|
|
@staticmethod
|
|
|
- async def send_data(data, stream):
|
|
|
- byte_str = pickle.dumps(data)
|
|
|
+ async def send_raw_data(byte_str, stream):
|
|
|
request = len(byte_str).to_bytes(P2P.HEADER_LEN, P2P.BYTEORDER) + byte_str
|
|
|
await stream.send_all(request)
|
|
|
|
|
|
- class IncompleteRead(Exception):
|
|
|
- pass
|
|
|
+ @staticmethod
|
|
|
+ async def send_data(data, stream):
|
|
|
+ await P2P.send_raw_data(pickle.dumps(data), stream)
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ async def send_protobuf(protobuf, out_proto_type, stream):
|
|
|
+ if type(protobuf) != out_proto_type:
|
|
|
+ error = TypeError('Unary handler returned protobuf of wrong type.')
|
|
|
+ await P2P.send_raw_data(pickle.dumps(error), stream)
|
|
|
+ raise error
|
|
|
+ await P2P.send_raw_data(protobuf.SerializeToString(), stream)
|
|
|
|
|
|
- async def _receive_exactly(self, stream, n_bytes, max_bytes=1 << 16):
|
|
|
- while len(self._buffer) < n_bytes:
|
|
|
- data = await stream.receive_some(max_bytes)
|
|
|
+ @staticmethod
|
|
|
+ async def receive_exactly(stream, n_bytes, max_bytes=1 << 16):
|
|
|
+ buffer = bytearray()
|
|
|
+ while len(buffer) < n_bytes:
|
|
|
+ data = await stream.receive_some(min(max_bytes, n_bytes - len(buffer)))
|
|
|
if len(data) == 0:
|
|
|
raise P2P.IncompleteRead()
|
|
|
- self._buffer.extend(data)
|
|
|
-
|
|
|
- result = self._buffer[:n_bytes]
|
|
|
- self._buffer = self._buffer[n_bytes:]
|
|
|
- return bytes(result)
|
|
|
+ buffer.extend(data)
|
|
|
+ return bytes(buffer)
|
|
|
|
|
|
- async def receive_data(self, stream, max_bytes=(1 < 16)):
|
|
|
- header = await self._receive_exactly(stream, P2P.HEADER_LEN)
|
|
|
+ @staticmethod
|
|
|
+ async def receive_raw_data(stream):
|
|
|
+ header = await P2P.receive_exactly(stream, P2P.HEADER_LEN)
|
|
|
content_length = int.from_bytes(header, P2P.BYTEORDER)
|
|
|
- data = await self._receive_exactly(stream, content_length)
|
|
|
- return pickle.loads(data)
|
|
|
+ data = await P2P.receive_exactly(stream, content_length)
|
|
|
+ return data
|
|
|
|
|
|
- def _handle_stream(self, handle):
|
|
|
+ @staticmethod
|
|
|
+ async def receive_data(stream):
|
|
|
+ return pickle.loads(await P2P.receive_raw_data(stream))
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ async def receive_protobuf(in_proto_type, stream):
|
|
|
+ protobuf = in_proto_type()
|
|
|
+ protobuf.ParseFromString(await P2P.receive_raw_data(stream))
|
|
|
+ return protobuf
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def _handle_stream(handle):
|
|
|
async def do_handle_stream(stream_info, stream):
|
|
|
try:
|
|
|
- request = await self.receive_data(stream)
|
|
|
+ request = await P2P.receive_data(stream)
|
|
|
except P2P.IncompleteRead:
|
|
|
warnings.warn("Incomplete read while receiving request from peer", RuntimeWarning)
|
|
|
+ await stream.close()
|
|
|
return
|
|
|
- finally:
|
|
|
- stream.close()
|
|
|
try:
|
|
|
result = handle(request)
|
|
|
- await self.send_data(result, stream)
|
|
|
+ await P2P.send_data(result, stream)
|
|
|
except Exception as exc:
|
|
|
- await self.send_data(exc, stream)
|
|
|
+ await P2P.send_data(exc, stream)
|
|
|
finally:
|
|
|
await stream.close()
|
|
|
|
|
|
return do_handle_stream
|
|
|
|
|
|
+ @staticmethod
|
|
|
+ def _handle_unary_stream(handle, context, in_proto_type, out_proto_type):
|
|
|
+ async def watchdog(stream):
|
|
|
+ await stream.receive_some(max_bytes=1)
|
|
|
+ raise P2P.InterruptedError()
|
|
|
+
|
|
|
+ async def do_handle_unary_stream(stream_info, stream):
|
|
|
+ try:
|
|
|
+ try:
|
|
|
+ request = await P2P.receive_protobuf(in_proto_type, stream)
|
|
|
+ except P2P.IncompleteRead:
|
|
|
+ warnings.warn("Incomplete read while receiving request from peer",
|
|
|
+ RuntimeWarning)
|
|
|
+ return
|
|
|
+ except google.protobuf.message.DecodeError as error:
|
|
|
+ warnings.warn(repr(error), RuntimeWarning)
|
|
|
+ return
|
|
|
+
|
|
|
+ context.peer_id, context.peer_addr = stream_info.peer_id, stream_info.addr
|
|
|
+ done, pending = await asyncio.wait([watchdog(stream), handle(request, context)],
|
|
|
+ return_when=asyncio.FIRST_COMPLETED)
|
|
|
+ try:
|
|
|
+ result = done.pop().result()
|
|
|
+ await P2P.send_protobuf(result, out_proto_type, stream)
|
|
|
+ except P2P.InterruptedError:
|
|
|
+ pass
|
|
|
+ except Exception as exc:
|
|
|
+ await P2P.send_data(exc, stream)
|
|
|
+ finally:
|
|
|
+ pending_task = pending.pop()
|
|
|
+ pending_task.cancel()
|
|
|
+ try:
|
|
|
+ await pending_task
|
|
|
+ except asyncio.CancelledError:
|
|
|
+ pass
|
|
|
+ finally:
|
|
|
+ await stream.close()
|
|
|
+
|
|
|
+ return do_handle_unary_stream
|
|
|
+
|
|
|
def start_listening(self):
|
|
|
async def listen():
|
|
|
async with self._client.listen():
|
|
@@ -153,15 +226,21 @@ class P2P(object):
|
|
|
async def add_stream_handler(self, name, handle):
|
|
|
if self._listen_task is None:
|
|
|
self.start_listening()
|
|
|
+ await self._client.stream_handler(name, P2P._handle_stream(handle))
|
|
|
|
|
|
- await self._client.stream_handler(name, self._handle_stream(handle))
|
|
|
+ async def add_unary_handler(self, name, handle, in_proto_type, out_proto_type):
|
|
|
+ if self._listen_task is None:
|
|
|
+ self.start_listening()
|
|
|
+ context = P2PContext(ours_id=self.id, ours_port=self._host_port, handle_name=name)
|
|
|
+ await self._client.stream_handler(
|
|
|
+ name, P2P._handle_unary_stream(handle, context, in_proto_type, out_proto_type))
|
|
|
|
|
|
async def call_peer_handler(self, peer_id, handler_name, input_data):
|
|
|
libp2p_peer_id = ID.from_base58(peer_id)
|
|
|
stream_info, stream = await self._client.stream_open(libp2p_peer_id, (handler_name,))
|
|
|
try:
|
|
|
- await self.send_data(input_data, stream)
|
|
|
- return await self.receive_data(stream)
|
|
|
+ await P2P.send_data(input_data, stream)
|
|
|
+ return await P2P.receive_data(stream)
|
|
|
finally:
|
|
|
await stream.close()
|
|
|
|
|
@@ -183,15 +262,3 @@ class P2P(object):
|
|
|
for key, value in kwargs.items()
|
|
|
)
|
|
|
return proc_args
|
|
|
-
|
|
|
-
|
|
|
-def find_open_port(params=(socket.AF_INET, socket.SOCK_STREAM),
|
|
|
- opt=(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)):
|
|
|
- """ Finds a tcp port that can be occupied with a socket with *params and use *opt options """
|
|
|
- try:
|
|
|
- with contextlib.closing(socket.socket(*params)) as sock:
|
|
|
- sock.bind(('', 0))
|
|
|
- sock.setsockopt(*opt)
|
|
|
- return sock.getsockname()[1]
|
|
|
- except Exception:
|
|
|
- raise
|