فهرست منبع

Merge branch 'master' into fix-colab

Alexander Borzunov 4 سال پیش
والد
کامیت
ddbcef8ad7

+ 0 - 2
.github/workflows/run-tests.yml

@@ -34,7 +34,6 @@ jobs:
         run: |
         run: |
           cd tests
           cd tests
           pytest --durations=0 --durations-min=1.0 -v
           pytest --durations=0 --durations-min=1.0 -v
-
   build_and_test_p2pd:
   build_and_test_p2pd:
     runs-on: ubuntu-latest
     runs-on: ubuntu-latest
     timeout-minutes: 10
     timeout-minutes: 10
@@ -61,7 +60,6 @@ jobs:
         run: |
         run: |
           cd tests
           cd tests
           pytest -k "p2p" -v
           pytest -k "p2p" -v
-
   codecov_in_develop_mode:
   codecov_in_develop_mode:
 
 
     runs-on: ubuntu-latest
     runs-on: ubuntu-latest

+ 1 - 0
.gitignore

@@ -54,6 +54,7 @@ coverage.xml
 .project
 .project
 .pydevproject
 .pydevproject
 .idea
 .idea
+.vscode
 .ipynb_checkpoints
 .ipynb_checkpoints
 
 
 # Rope
 # Rope

+ 3 - 3
hivemind/averaging/allreduce.py

@@ -8,7 +8,7 @@ from hivemind.averaging.partition import AllreduceException, TensorPartContainer
 from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase, StubBase
 from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase, StubBase
 from hivemind.proto import averaging_pb2
 from hivemind.proto import averaging_pb2
 from hivemind.utils import get_logger
 from hivemind.utils import get_logger
-from hivemind.utils.asyncio import achain, aenumerate, aiter, amap_in_executor, anext, asingle
+from hivemind.utils.asyncio import achain, aenumerate, afirst, aiter, amap_in_executor, anext
 from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 
 
 # flavour types
 # flavour types
@@ -231,8 +231,8 @@ class AllReduceRunner(ServicerBase):
 
 
     async def _send_error_to_peer(self, peer_id: PeerID, code: averaging_pb2.MessageCode):
     async def _send_error_to_peer(self, peer_id: PeerID, code: averaging_pb2.MessageCode):
         error = averaging_pb2.AveragingData(group_id=self.group_id, code=code)
         error = averaging_pb2.AveragingData(group_id=self.group_id, code=code)
-        # In case of reporting the error, we expect the response stream to contain exactly one item
-        await asingle(self._get_peer_stub(peer_id).rpc_aggregate_part(aiter(error)))
+        # Coroutines are lazy, so we take the first item to start the coroutine's execution
+        await afirst(self._get_peer_stub(peer_id).rpc_aggregate_part(aiter(error)))
 
 
     def finalize(self, *, cancel: bool = False, exception: Optional[BaseException] = None):
     def finalize(self, *, cancel: bool = False, exception: Optional[BaseException] = None):
         """finish or terminate AllReduceRunner, propagate any errors / cancellations to peers."""
         """finish or terminate AllReduceRunner, propagate any errors / cancellations to peers."""

+ 1 - 1
hivemind/averaging/averager.py

@@ -593,7 +593,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                         future.set_result((metadata, tensors))
                         future.set_result((metadata, tensors))
                         self.last_updated = get_dht_time()
                         self.last_updated = get_dht_time()
                         return
                         return
-                    except BaseException as e:
+                    except Exception as e:
                         logger.exception(f"Failed to download state from {peer} - {repr(e)}")
                         logger.exception(f"Failed to download state from {peer} - {repr(e)}")
 
 
         finally:
         finally:

+ 1 - 1
hivemind/averaging/matchmaking.py

@@ -189,7 +189,7 @@ class Matchmaking:
                         gather=self.data_for_gather,
                         gather=self.data_for_gather,
                         group_key=self.group_key_manager.current_key,
                         group_key=self.group_key_manager.current_key,
                     )
                     )
-                ).__aiter__()
+                )
                 message = await asyncio.wait_for(anext(stream), timeout=self.request_timeout)
                 message = await asyncio.wait_for(anext(stream), timeout=self.request_timeout)
 
 
                 if message.code == averaging_pb2.ACCEPTED:
                 if message.code == averaging_pb2.ACCEPTED:

+ 9 - 0
hivemind/dht/__init__.py

@@ -61,6 +61,7 @@ class DHT(mp.Process):
         initial_peers: Optional[Sequence[Union[Multiaddr, str]]] = None,
         initial_peers: Optional[Sequence[Union[Multiaddr, str]]] = None,
         *,
         *,
         start: bool,
         start: bool,
+        p2p: Optional[P2P] = None,
         daemon: bool = True,
         daemon: bool = True,
         num_workers: int = DEFAULT_NUM_WORKERS,
         num_workers: int = DEFAULT_NUM_WORKERS,
         record_validators: Iterable[RecordValidatorBase] = (),
         record_validators: Iterable[RecordValidatorBase] = (),
@@ -94,6 +95,8 @@ class DHT(mp.Process):
         self._client_mode = None
         self._client_mode = None
         self._p2p_replica = None
         self._p2p_replica = None
 
 
+        self._daemon_listen_maddr = p2p.daemon_listen_maddr if p2p is not None else None
+
         if start:
         if start:
             self.run_in_background(await_ready=await_ready)
             self.run_in_background(await_ready=await_ready)
 
 
@@ -105,10 +108,16 @@ class DHT(mp.Process):
 
 
             async def _run():
             async def _run():
                 try:
                 try:
+                    if self._daemon_listen_maddr is not None:
+                        replicated_p2p = await P2P.replicate(self._daemon_listen_maddr)
+                    else:
+                        replicated_p2p = None
+
                     self._node = await DHTNode.create(
                     self._node = await DHTNode.create(
                         initial_peers=self.initial_peers,
                         initial_peers=self.initial_peers,
                         num_workers=self.num_workers,
                         num_workers=self.num_workers,
                         record_validator=self._record_validator,
                         record_validator=self._record_validator,
+                        p2p=replicated_p2p,
                         **self.kwargs,
                         **self.kwargs,
                     )
                     )
                 except Exception as e:
                 except Exception as e:

+ 1 - 1
hivemind/dht/protocol.py

@@ -81,7 +81,7 @@ class DHTProtocol(ServicerBase):
 
 
     def __init__(self, *, _initialized_with_create=False):
     def __init__(self, *, _initialized_with_create=False):
         """Internal init method. Please use DHTProtocol.create coroutine to spawn new protocol instances"""
         """Internal init method. Please use DHTProtocol.create coroutine to spawn new protocol instances"""
-        assert _initialized_with_create, " Please use DHTProtocol.create coroutine to spawn new protocol instances "
+        assert _initialized_with_create, "Please use DHTProtocol.create coroutine to spawn new protocol instances"
         super().__init__()
         super().__init__()
 
 
     def get_stub(self, peer: PeerID) -> AuthRPCWrapper:
     def get_stub(self, peer: PeerID) -> AuthRPCWrapper:

+ 78 - 24
hivemind/p2p/p2p_daemon.py

@@ -5,12 +5,14 @@ from collections.abc import AsyncIterable as AsyncIterableABC
 from contextlib import closing, suppress
 from contextlib import closing, suppress
 from dataclasses import dataclass
 from dataclasses import dataclass
 from importlib.resources import path
 from importlib.resources import path
-from typing import Any, AsyncIterator, Awaitable, Callable, List, Optional, Sequence, Tuple, TypeVar, Union
+from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional, Sequence, Tuple, Type, TypeVar, Union
 
 
+from google.protobuf.message import Message
 from multiaddr import Multiaddr
 from multiaddr import Multiaddr
 
 
 import hivemind.hivemind_cli as cli
 import hivemind.hivemind_cli as cli
 import hivemind.p2p.p2p_daemon_bindings.p2pclient as p2pclient
 import hivemind.p2p.p2p_daemon_bindings.p2pclient as p2pclient
+from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError, P2PHandlerError
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
 from hivemind.proto.p2pd_pb2 import RPCError
 from hivemind.proto.p2pd_pb2 import RPCError
 from hivemind.utils.asyncio import aiter, asingle
 from hivemind.utils.asyncio import aiter, asingle
@@ -27,7 +29,6 @@ class P2PContext(object):
     handle_name: str
     handle_name: str
     local_id: PeerID
     local_id: PeerID
     remote_id: PeerID = None
     remote_id: PeerID = None
-    remote_maddr: Multiaddr = None
 
 
 
 
 class P2P:
 class P2P:
@@ -65,6 +66,7 @@ class P2P:
 
 
     def __init__(self):
     def __init__(self):
         self.peer_id = None
         self.peer_id = None
+        self._client = None
         self._child = None
         self._child = None
         self._alive = False
         self._alive = False
         self._reader_task = None
         self._reader_task = None
@@ -90,6 +92,7 @@ class P2P:
         use_auto_relay: bool = False,
         use_auto_relay: bool = False,
         relay_hop_limit: int = 0,
         relay_hop_limit: int = 0,
         startup_timeout: float = 15,
         startup_timeout: float = 15,
+        idle_timeout: float = 30,
     ) -> "P2P":
     ) -> "P2P":
         """
         """
         Start a new p2pd process and connect to it.
         Start a new p2pd process and connect to it.
@@ -111,6 +114,8 @@ class P2P:
         :param use_auto_relay: enables autorelay
         :param use_auto_relay: enables autorelay
         :param relay_hop_limit: sets the hop limit for hop relays
         :param relay_hop_limit: sets the hop limit for hop relays
         :param startup_timeout: raise a P2PDaemonError if the daemon does not start in ``startup_timeout`` seconds
         :param startup_timeout: raise a P2PDaemonError if the daemon does not start in ``startup_timeout`` seconds
+        :param idle_timeout: kill daemon if client has been idle for a given number of
+                             seconds before opening persistent streams
         :return: a wrapper for the p2p daemon
         :return: a wrapper for the p2p daemon
         """
         """
 
 
@@ -150,6 +155,7 @@ class P2P:
             relayDiscovery=use_relay_discovery,
             relayDiscovery=use_relay_discovery,
             autoRelay=use_auto_relay,
             autoRelay=use_auto_relay,
             relayHopLimit=relay_hop_limit,
             relayHopLimit=relay_hop_limit,
+            idleTimeout=f"{idle_timeout}s",
             b=need_bootstrap,
             b=need_bootstrap,
             **process_kwargs,
             **process_kwargs,
         )
         )
@@ -167,7 +173,7 @@ class P2P:
             await self.shutdown()
             await self.shutdown()
             raise P2PDaemonError(f"Daemon failed to start in {startup_timeout:.1f} seconds")
             raise P2PDaemonError(f"Daemon failed to start in {startup_timeout:.1f} seconds")
 
 
-        self._client = p2pclient.Client(self._daemon_listen_maddr, self._client_listen_maddr)
+        self._client = await p2pclient.Client.create(self._daemon_listen_maddr, self._client_listen_maddr)
         await self._ping_daemon()
         await self._ping_daemon()
         return self
         return self
 
 
@@ -189,7 +195,7 @@ class P2P:
         self._daemon_listen_maddr = daemon_listen_maddr
         self._daemon_listen_maddr = daemon_listen_maddr
         self._client_listen_maddr = Multiaddr(cls._UNIX_SOCKET_PREFIX + f"p2pclient-{socket_uid}.sock")
         self._client_listen_maddr = Multiaddr(cls._UNIX_SOCKET_PREFIX + f"p2pclient-{socket_uid}.sock")
 
 
-        self._client = p2pclient.Client(self._daemon_listen_maddr, self._client_listen_maddr)
+        self._client = await p2pclient.Client.create(self._daemon_listen_maddr, self._client_listen_maddr)
 
 
         await self._ping_daemon()
         await self._ping_daemon()
         return self
         return self
@@ -258,7 +264,7 @@ class P2P:
 
 
     @staticmethod
     @staticmethod
     async def receive_protobuf(
     async def receive_protobuf(
-        input_protobuf_type: type, reader: asyncio.StreamReader
+        input_protobuf_type: Type[Message], reader: asyncio.StreamReader
     ) -> Tuple[Optional[TInputProtobuf], Optional[RPCError]]:
     ) -> Tuple[Optional[TInputProtobuf], Optional[RPCError]]:
         msg_type = await reader.readexactly(1)
         msg_type = await reader.readexactly(1)
         if msg_type == P2P.MESSAGE_MARKER:
         if msg_type == P2P.MESSAGE_MARKER:
@@ -279,7 +285,7 @@ class P2P:
         self,
         self,
         name: str,
         name: str,
         handler: Callable[[TInputStream, P2PContext], TOutputStream],
         handler: Callable[[TInputStream, P2PContext], TOutputStream],
-        input_protobuf_type: type,
+        input_protobuf_type: Type[Message],
         max_prefetch: int = 5,
         max_prefetch: int = 5,
     ) -> None:
     ) -> None:
         """
         """
@@ -297,7 +303,6 @@ class P2P:
                 handle_name=name,
                 handle_name=name,
                 local_id=self.peer_id,
                 local_id=self.peer_id,
                 remote_id=stream_info.peer_id,
                 remote_id=stream_info.peer_id,
-                remote_maddr=stream_info.addr,
             )
             )
             requests = asyncio.Queue(max_prefetch)
             requests = asyncio.Queue(max_prefetch)
 
 
@@ -311,11 +316,17 @@ class P2P:
             async def _process_stream() -> None:
             async def _process_stream() -> None:
                 try:
                 try:
                     async for response in handler(_read_stream(), context):
                     async for response in handler(_read_stream(), context):
-                        await P2P.send_protobuf(response, writer)
+                        try:
+                            await P2P.send_protobuf(response, writer)
+                        except Exception:
+                            # The connection is unexpectedly closed by the caller or broken.
+                            # The loglevel is DEBUG since the actual error will be reported on the caller
+                            logger.debug("Exception while sending response:", exc_info=True)
+                            break
                 except Exception as e:
                 except Exception as e:
-                    logger.warning("Exception while processing stream and sending responses:", exc_info=True)
-                    # Sometimes `e` is a connection error, so we won't be able to report the error to the caller
+                    logger.warning("Handler failed with the exception:", exc_info=True)
                     with suppress(Exception):
                     with suppress(Exception):
+                        # Sometimes `e` is a connection error, so it is okay if we fail to report `e` to the caller
                         await P2P.send_protobuf(RPCError(message=str(e)), writer)
                         await P2P.send_protobuf(RPCError(message=str(e)), writer)
 
 
             with closing(writer):
             with closing(writer):
@@ -343,7 +354,7 @@ class P2P:
         await self.add_binary_stream_handler(name, _handle_stream)
         await self.add_binary_stream_handler(name, _handle_stream)
 
 
     async def _iterate_protobuf_stream_handler(
     async def _iterate_protobuf_stream_handler(
-        self, peer_id: PeerID, name: str, requests: TInputStream, output_protobuf_type: type
+        self, peer_id: PeerID, name: str, requests: TInputStream, output_protobuf_type: Type[Message]
     ) -> TOutputStream:
     ) -> TOutputStream:
         _, reader, writer = await self.call_binary_stream_handler(peer_id, name)
         _, reader, writer = await self.call_binary_stream_handler(peer_id, name)
 
 
@@ -375,15 +386,22 @@ class P2P:
         handler: Callable[
         handler: Callable[
             [Union[TInputProtobuf, TInputStream], P2PContext], Union[Awaitable[TOutputProtobuf], TOutputStream]
             [Union[TInputProtobuf, TInputStream], P2PContext], Union[Awaitable[TOutputProtobuf], TOutputStream]
         ],
         ],
-        input_protobuf_type: type,
+        input_protobuf_type: Type[Message],
         *,
         *,
         stream_input: bool = False,
         stream_input: bool = False,
+        stream_output: bool = False,
     ) -> None:
     ) -> None:
         """
         """
         :param stream_input: If True, assume ``handler`` to take ``TInputStream``
         :param stream_input: If True, assume ``handler`` to take ``TInputStream``
                              (not just ``TInputProtobuf``) as input.
                              (not just ``TInputProtobuf``) as input.
+        :param stream_output: If True, assume ``handler`` to return ``TOutputStream``
+                              (not ``Awaitable[TOutputProtobuf]``).
         """
         """
 
 
+        if not stream_input and not stream_output:
+            await self._add_protobuf_unary_handler(name, handler, input_protobuf_type)
+            return
+
         async def _stream_handler(requests: P2P.TInputStream, context: P2PContext) -> P2P.TOutputStream:
         async def _stream_handler(requests: P2P.TInputStream, context: P2PContext) -> P2P.TOutputStream:
             input = requests if stream_input else await asingle(requests)
             input = requests if stream_input else await asingle(requests)
             output = handler(input, context)
             output = handler(input, context)
@@ -396,23 +414,65 @@ class P2P:
 
 
         await self._add_protobuf_stream_handler(name, _stream_handler, input_protobuf_type)
         await self._add_protobuf_stream_handler(name, _stream_handler, input_protobuf_type)
 
 
+    async def _add_protobuf_unary_handler(
+        self,
+        handle_name: str,
+        handler: Callable[[TInputProtobuf, P2PContext], Awaitable[TOutputProtobuf]],
+        input_protobuf_type: Type[Message],
+    ) -> None:
+        """
+        Register a request-response (unary) handler. Unary requests and responses
+        are sent through persistent multiplexed connections to the daemon for the
+        sake of reducing the number of open files.
+        :param handle_name: name of the handler (protocol id)
+        :param handler: function handling the unary requests
+        :param input_protobuf_type: protobuf type of the request
+        """
+
+        async def _unary_handler(request: bytes, remote_id: PeerID) -> bytes:
+            input_serialized = input_protobuf_type.FromString(request)
+            context = P2PContext(
+                handle_name=handle_name,
+                local_id=self.peer_id,
+                remote_id=remote_id,
+            )
+
+            response = await handler(input_serialized, context)
+            return response.SerializeToString()
+
+        await self._client.add_unary_handler(handle_name, _unary_handler)
+
     async def call_protobuf_handler(
     async def call_protobuf_handler(
         self,
         self,
         peer_id: PeerID,
         peer_id: PeerID,
         name: str,
         name: str,
         input: Union[TInputProtobuf, TInputStream],
         input: Union[TInputProtobuf, TInputStream],
-        output_protobuf_type: type,
+        output_protobuf_type: Type[Message],
     ) -> Awaitable[TOutputProtobuf]:
     ) -> Awaitable[TOutputProtobuf]:
-        requests = input if isinstance(input, AsyncIterableABC) else aiter(input)
-        responses = self._iterate_protobuf_stream_handler(peer_id, name, requests, output_protobuf_type)
+
+        if not isinstance(input, AsyncIterableABC):
+            return await self._call_unary_protobuf_handler(peer_id, name, input, output_protobuf_type)
+
+        responses = self._iterate_protobuf_stream_handler(peer_id, name, input, output_protobuf_type)
         return await asingle(responses)
         return await asingle(responses)
 
 
+    async def _call_unary_protobuf_handler(
+        self,
+        peer_id: PeerID,
+        handle_name: str,
+        input: TInputProtobuf,
+        output_protobuf_type: Type[Message],
+    ) -> Awaitable[TOutputProtobuf]:
+        serialized_input = input.SerializeToString()
+        response = await self._client.call_unary_handler(peer_id, handle_name, serialized_input)
+        return output_protobuf_type.FromString(response)
+
     def iterate_protobuf_handler(
     def iterate_protobuf_handler(
         self,
         self,
         peer_id: PeerID,
         peer_id: PeerID,
         name: str,
         name: str,
         input: Union[TInputProtobuf, TInputStream],
         input: Union[TInputProtobuf, TInputStream],
-        output_protobuf_type: type,
+        output_protobuf_type: Type[Message],
     ) -> TOutputStream:
     ) -> TOutputStream:
         requests = input if isinstance(input, AsyncIterableABC) else aiter(input)
         requests = input if isinstance(input, AsyncIterableABC) else aiter(input)
         return self._iterate_protobuf_stream_handler(peer_id, name, requests, output_protobuf_type)
         return self._iterate_protobuf_stream_handler(peer_id, name, requests, output_protobuf_type)
@@ -447,6 +507,8 @@ class P2P:
             await self._child.wait()
             await self._child.wait()
 
 
     def _terminate(self) -> None:
     def _terminate(self) -> None:
+        if self._client is not None:
+            self._client.close()
         if self._listen_task is not None:
         if self._listen_task is not None:
             self._listen_task.cancel()
             self._listen_task.cancel()
         if self._reader_task is not None:
         if self._reader_task is not None:
@@ -495,11 +557,3 @@ class P2P:
 
 
         if not ready.done():
         if not ready.done():
             ready.set_exception(P2PDaemonError(f"Daemon failed to start: {last_line}"))
             ready.set_exception(P2PDaemonError(f"Daemon failed to start: {last_line}"))
-
-
-class P2PDaemonError(RuntimeError):
-    pass
-
-
-class P2PHandlerError(Exception):
-    pass

+ 189 - 3
hivemind/p2p/p2p_daemon_bindings/control.py

@@ -5,8 +5,9 @@ Author: Kevin Mai-Husan Chia
 """
 """
 
 
 import asyncio
 import asyncio
-from contextlib import asynccontextmanager
-from typing import AsyncIterator, Awaitable, Callable, Dict, Iterable, Sequence, Tuple
+from contextlib import asynccontextmanager, closing
+from typing import AsyncIterator, Awaitable, Callable, Dict, Iterable, Optional, Sequence, Tuple
+from uuid import UUID, uuid4
 
 
 from multiaddr import Multiaddr, protocols
 from multiaddr import Multiaddr, protocols
 
 
@@ -54,17 +55,75 @@ class DaemonConnector:
         else:
         else:
             raise ValueError(f"Protocol not supported: {protocols.protocol_with_code(self.proto_code)}")
             raise ValueError(f"Protocol not supported: {protocols.protocol_with_code(self.proto_code)}")
 
 
+    async def open_persistent_connection(self) -> (asyncio.StreamReader, asyncio.StreamWriter):
+        """
+        Open connection to daemon and upgrade it to a persistent one
+        """
+        reader, writer = await self.open_connection()
+        req = p2pd_pb.Request(type=p2pd_pb.Request.PERSISTENT_CONN_UPGRADE)
+        await write_pbmsg(writer, req)
+
+        response = p2pd_pb.Response()
+        await read_pbmsg_safe(reader, response)
+
+        if response.type == "ERROR":
+            raise P2PDaemonError(response.error.msg)
+
+        return reader, writer
+
+
+TUnaryHandler = Callable[[bytes, PeerID], Awaitable[bytes]]
+CallID = UUID
+
 
 
 class ControlClient:
 class ControlClient:
     DEFAULT_LISTEN_MADDR = "/unix/tmp/p2pclient.sock"
     DEFAULT_LISTEN_MADDR = "/unix/tmp/p2pclient.sock"
 
 
     def __init__(
     def __init__(
-        self, daemon_connector: DaemonConnector, listen_maddr: Multiaddr = Multiaddr(DEFAULT_LISTEN_MADDR)
+        self,
+        daemon_connector: DaemonConnector,
+        listen_maddr: Multiaddr = Multiaddr(DEFAULT_LISTEN_MADDR),
+        *,
+        _initialized_with_create=False,
     ) -> None:
     ) -> None:
+        assert _initialized_with_create, "Please use ControlClient.create coroutine to spawn new control instances"
+
         self.listen_maddr = listen_maddr
         self.listen_maddr = listen_maddr
         self.daemon_connector = daemon_connector
         self.daemon_connector = daemon_connector
         self.handlers: Dict[str, StreamHandler] = {}
         self.handlers: Dict[str, StreamHandler] = {}
 
 
+        self.unary_handlers: Dict[str, TUnaryHandler] = {}
+
+        self._pending_messages: asyncio.Queue[p2pd_pb.PersistentConnectionRequest] = asyncio.Queue()
+        self._pending_calls: Dict[CallID, asyncio.Future[bytes]] = {}
+        self._handler_tasks: Dict[CallID, asyncio.Task] = {}
+
+        self._read_task: Optional[asyncio.Task] = None
+        self._write_task: Optional[asyncio.Task] = None
+
+    @classmethod
+    async def create(
+        cls,
+        daemon_connector: DaemonConnector,
+        listen_maddr: Multiaddr = Multiaddr(DEFAULT_LISTEN_MADDR),
+        use_persistent_conn: bool = True,
+    ) -> "ControlClient":
+        control = cls(daemon_connector, listen_maddr, _initialized_with_create=True)
+
+        if use_persistent_conn:
+            await control._ensure_persistent_conn()
+
+        return control
+
+    def close(self) -> None:
+        if self._read_task is not None:
+            self._read_task.cancel()
+        if self._write_task is not None:
+            self._write_task.cancel()
+
+    def __del__(self):
+        self.close()
+
     async def _handler(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
     async def _handler(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
         pb_stream_info = p2pd_pb.StreamInfo()  # type: ignore
         pb_stream_info = p2pd_pb.StreamInfo()  # type: ignore
         await read_pbmsg_safe(reader, pb_stream_info)
         await read_pbmsg_safe(reader, pb_stream_info)
@@ -93,6 +152,121 @@ class ControlClient:
         async with server:
         async with server:
             yield self
             yield self
 
 
+    async def _read_from_persistent_conn(self, reader: asyncio.StreamReader):
+        while True:
+            resp = p2pd_pb.PersistentConnectionResponse()
+            try:
+                await read_pbmsg_safe(reader, resp)
+            except asyncio.IncompleteReadError:
+                break
+
+            call_id = UUID(bytes=resp.callId)
+
+            if resp.HasField("callUnaryResponse"):
+                if call_id in self._pending_calls and resp.callUnaryResponse.HasField("response"):
+                    self._pending_calls[call_id].set_result(resp.callUnaryResponse.response)
+                elif call_id in self._pending_calls and resp.callUnaryResponse.HasField("error"):
+                    remote_exc = P2PHandlerError(resp.callUnaryResponse.error.decode(errors="ignore"))
+                    self._pending_calls[call_id].set_exception(remote_exc)
+                else:
+                    logger.debug(f"Received unexpected unary call: {resp}")
+
+            elif resp.HasField("requestHandling"):
+                handler_task = asyncio.create_task(self._handle_persistent_request(call_id, resp.requestHandling))
+                self._handler_tasks[call_id] = handler_task
+
+            elif call_id in self._handler_tasks and resp.HasField("cancel"):
+                self._handler_tasks[call_id].cancel()
+
+            elif call_id in self._pending_calls and resp.HasField("daemonError"):
+                daemon_exc = P2PDaemonError(resp.daemonError.message)
+                self._pending_calls[call_id].set_exception(daemon_exc)
+
+            elif call_id in self._pending_calls:
+                self._pending_calls[call_id].set_result(None)
+
+            else:
+                logger.warning(f"Received unexpected response from daemon: {resp}")
+
+    async def _write_to_persistent_conn(self, writer: asyncio.StreamWriter):
+        with closing(writer):
+            while True:
+                msg = await self._pending_messages.get()
+                await write_pbmsg(writer, msg)
+
+    async def _handle_persistent_request(self, call_id: UUID, request: p2pd_pb.CallUnaryRequest):
+        if request.proto not in self.unary_handlers:
+            logger.warning(f"Protocol {request.proto} not supported")
+            return
+
+        try:
+            remote_id = PeerID(request.peer)
+            response_payload: bytes = await self.unary_handlers[request.proto](request.data, remote_id)
+            response = p2pd_pb.CallUnaryResponse(response=response_payload)
+
+        except Exception as e:
+            response = p2pd_pb.CallUnaryResponse(error=repr(e).encode())
+
+        await self._pending_messages.put(
+            p2pd_pb.PersistentConnectionRequest(
+                callId=call_id.bytes,
+                unaryResponse=response,
+            )
+        )
+        self._handler_tasks.pop(call_id)
+
+    async def _cancel_unary_call(self, call_id: UUID):
+        await self._pending_messages.put(
+            p2pd_pb.PersistentConnectionRequest(
+                callId=call_id.bytes,
+                cancel=p2pd_pb.Cancel(),
+            ),
+        )
+
+    async def _ensure_persistent_conn(self):
+        reader, writer = await self.daemon_connector.open_persistent_connection()
+
+        self._read_task = asyncio.create_task(self._read_from_persistent_conn(reader))
+        self._write_task = asyncio.create_task(self._write_to_persistent_conn(writer))
+
+    async def add_unary_handler(self, proto: str, handler: TUnaryHandler):
+        call_id = uuid4()
+
+        add_unary_handler_req = p2pd_pb.AddUnaryHandlerRequest(proto=proto)
+        req = p2pd_pb.PersistentConnectionRequest(callId=call_id.bytes, addUnaryHandler=add_unary_handler_req)
+
+        if self.unary_handlers.get(proto):
+            raise P2PDaemonError(f"Handler for protocol {proto} already registered")
+        self.unary_handlers[proto] = handler
+
+        self._pending_calls[call_id] = asyncio.Future()
+        await self._pending_messages.put(req)
+        await self._pending_calls[call_id]
+
+    async def call_unary_handler(self, peer_id: PeerID, proto: str, data: bytes) -> bytes:
+        call_id = uuid4()
+        call_unary_req = p2pd_pb.CallUnaryRequest(
+            peer=peer_id.to_bytes(),
+            proto=proto,
+            data=data,
+        )
+        req = p2pd_pb.PersistentConnectionRequest(
+            callId=call_id.bytes,
+            callUnary=call_unary_req,
+        )
+
+        try:
+            self._pending_calls[call_id] = asyncio.Future()
+            await self._pending_messages.put(req)
+            return await self._pending_calls[call_id]
+
+        except asyncio.CancelledError:
+            await self._cancel_unary_call(call_id)
+            raise
+
+        finally:
+            self._pending_calls.pop(call_id, None)
+
     async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
     async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
         reader, writer = await self.daemon_connector.open_connection()
         reader, writer = await self.daemon_connector.open_connection()
         req = p2pd_pb.Request(type=p2pd_pb.Request.IDENTIFY)
         req = p2pd_pb.Request(type=p2pd_pb.Request.IDENTIFY)
@@ -179,3 +353,15 @@ class ControlClient:
 
 
         # if success, add the handler to the dict
         # if success, add the handler to the dict
         self.handlers[proto] = handler_cb
         self.handlers[proto] = handler_cb
+
+
+class P2PHandlerError(Exception):
+    """
+    Raised if the remote peer handled a request with an exception
+    """
+
+
+class P2PDaemonError(Exception):
+    """
+    Raised if the daemon failed to handle a request
+    """

+ 3 - 0
hivemind/p2p/p2p_daemon_bindings/datastructures.py

@@ -131,6 +131,9 @@ class PeerInfo:
     def __str__(self):
     def __str__(self):
         return f"{self.peer_id.pretty()} {','.join(str(a) for a in self.addrs)}"
         return f"{self.peer_id.pretty()} {','.join(str(a) for a in self.addrs)}"
 
 
+    def __repr__(self):
+        return f"PeerInfo(peer_id={repr(self.peer_id)}, addrs={repr(self.addrs)})"
+
 
 
 class InvalidAddrError(ValueError):
 class InvalidAddrError(ValueError):
     pass
     pass

+ 24 - 3
hivemind/p2p/p2p_daemon_bindings/p2pclient.py

@@ -10,16 +10,31 @@ from typing import AsyncIterator, Iterable, Sequence, Tuple
 
 
 from multiaddr import Multiaddr
 from multiaddr import Multiaddr
 
 
-from hivemind.p2p.p2p_daemon_bindings.control import ControlClient, DaemonConnector, StreamHandler
+from hivemind.p2p.p2p_daemon_bindings.control import ControlClient, DaemonConnector, StreamHandler, TUnaryHandler
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
 
 
 
 
 class Client:
 class Client:
     control: ControlClient
     control: ControlClient
 
 
-    def __init__(self, control_maddr: Multiaddr = None, listen_maddr: Multiaddr = None) -> None:
+    def __init__(self, *, _initialized_with_create=False) -> None:
+        assert _initialized_with_create, "Please use Client.create coroutine to spawn new client instances"
+        self.control = None
+
+    @classmethod
+    async def create(cls, control_maddr: Multiaddr = None, listen_maddr: Multiaddr = None) -> "Client":
+        client = cls(_initialized_with_create=True)
+
         daemon_connector = DaemonConnector(control_maddr=control_maddr)
         daemon_connector = DaemonConnector(control_maddr=control_maddr)
-        self.control = ControlClient(daemon_connector=daemon_connector, listen_maddr=listen_maddr)
+        client.control = await ControlClient.create(daemon_connector=daemon_connector, listen_maddr=listen_maddr)
+
+        return client
+
+    def close(self) -> None:
+        self.control.close()
+
+    def __del__(self):
+        self.close()
 
 
     @asynccontextmanager
     @asynccontextmanager
     async def listen(self) -> AsyncIterator["Client"]:
     async def listen(self) -> AsyncIterator["Client"]:
@@ -30,6 +45,12 @@ class Client:
         async with self.control.listen():
         async with self.control.listen():
             yield self
             yield self
 
 
+    async def add_unary_handler(self, proto: str, handler: TUnaryHandler):
+        await self.control.add_unary_handler(proto, handler)
+
+    async def call_unary_handler(self, peer_id: PeerID, proto: str, data: bytes) -> bytes:
+        return await self.control.call_unary_handler(peer_id, proto, data)
+
     async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
     async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
         """
         """
         Get current node peer id and list of addresses
         Get current node peer id and list of addresses

+ 13 - 7
hivemind/p2p/servicer.py

@@ -125,13 +125,19 @@ class ServicerBase:
         self._collect_rpc_handlers()
         self._collect_rpc_handlers()
 
 
         servicer = self if wrapper is None else wrapper
         servicer = self if wrapper is None else wrapper
-        for handler in self._rpc_handlers:
-            await p2p.add_protobuf_handler(
-                self._get_handle_name(namespace, handler.method_name),
-                getattr(servicer, handler.method_name),
-                handler.request_type,
-                stream_input=handler.stream_input,
-            )
+
+        await asyncio.gather(
+            *[
+                p2p.add_protobuf_handler(
+                    self._get_handle_name(namespace, handler.method_name),
+                    getattr(servicer, handler.method_name),
+                    handler.request_type,
+                    stream_input=handler.stream_input,
+                    stream_output=handler.stream_output,
+                )
+                for handler in self._rpc_handlers
+            ]
+        )
 
 
     @classmethod
     @classmethod
     def get_stub(cls, p2p: P2P, peer: PeerID, *, namespace: Optional[str] = None) -> StubBase:
     def get_stub(cls, p2p: P2P, peer: PeerID, *, namespace: Optional[str] = None) -> StubBase:

+ 59 - 10
hivemind/proto/p2pd.proto

@@ -8,15 +8,17 @@ package p2pclient.p2pd.pb;
 
 
 message Request {
 message Request {
   enum Type {
   enum Type {
-    IDENTIFY       = 0;
-    CONNECT        = 1;
-    STREAM_OPEN    = 2;
-    STREAM_HANDLER = 3;
-    DHT            = 4;
-    LIST_PEERS     = 5;
-    CONNMANAGER    = 6;
-    DISCONNECT     = 7;
-    PUBSUB         = 8;
+    IDENTIFY                 = 0;
+    CONNECT                  = 1;
+    STREAM_OPEN              = 2;
+    STREAM_HANDLER           = 3;
+    DHT                      = 4;
+    LIST_PEERS               = 5;
+    CONNMANAGER              = 6;
+    DISCONNECT               = 7;
+    PUBSUB                   = 8;
+
+    PERSISTENT_CONN_UPGRADE  = 9;
   }
   }
 
 
   required Type type = 1;
   required Type type = 1;
@@ -45,6 +47,29 @@ message Response {
   optional PSResponse pubsub = 7;
   optional PSResponse pubsub = 7;
 }
 }
 
 
+message PersistentConnectionRequest {
+  required bytes callId = 1;
+
+  oneof message {
+    AddUnaryHandlerRequest addUnaryHandler = 2;
+    CallUnaryRequest  callUnary = 3;
+    CallUnaryResponse unaryResponse = 4;
+    Cancel cancel = 5;
+  }
+}
+
+message PersistentConnectionResponse {
+  required bytes callId = 1;
+
+  oneof message {
+    CallUnaryResponse callUnaryResponse = 2;
+    CallUnaryRequest requestHandling = 3;
+    DaemonError daemonError = 4;
+    Cancel cancel = 5;
+  }
+}
+
+
 message IdentifyResponse {
 message IdentifyResponse {
   required bytes id = 1;
   required bytes id = 1;
   repeated bytes addrs = 2;
   repeated bytes addrs = 2;
@@ -148,7 +173,7 @@ message PSRequest {
 }
 }
 
 
 message PSMessage {
 message PSMessage {
-  optional bytes from_id = 1;
+  optional bytes from = 1;
   optional bytes data = 2;
   optional bytes data = 2;
   optional bytes seqno = 3;
   optional bytes seqno = 3;
   repeated string topicIDs = 4;
   repeated string topicIDs = 4;
@@ -161,6 +186,30 @@ message PSResponse {
   repeated bytes peerIDs = 2;
   repeated bytes peerIDs = 2;
 }
 }
 
 
+message CallUnaryRequest {
+  required bytes peer = 1;
+  required string proto = 2;
+  required bytes data = 3;
+}
+
+message CallUnaryResponse {
+  oneof result {
+    bytes response = 1;
+    bytes error = 2;
+  }
+}
+
+message AddUnaryHandlerRequest {
+  required string proto = 1;
+}
+
+message DaemonError {
+  optional string message = 1;
+}
+
+message Cancel {
+}
+
 message RPCError {
 message RPCError {
   optional string message = 1;
   optional string message = 1;
 }
 }

+ 8 - 1
hivemind/utils/asyncio.py

@@ -59,7 +59,7 @@ async def aenumerate(aiterable: AsyncIterable[T]) -> AsyncIterable[Tuple[int, T]
 
 
 
 
 async def asingle(aiter: AsyncIterable[T]) -> T:
 async def asingle(aiter: AsyncIterable[T]) -> T:
-    """If ``aiter`` has exactly one item, returns this item. Otherwise, raises `ValueError`."""
+    """If ``aiter`` has exactly one item, returns this item. Otherwise, raises ``ValueError``."""
     count = 0
     count = 0
     async for item in aiter:
     async for item in aiter:
         count += 1
         count += 1
@@ -70,6 +70,13 @@ async def asingle(aiter: AsyncIterable[T]) -> T:
     return item
     return item
 
 
 
 
+async def afirst(aiter: AsyncIterable[T], default: Optional[T] = None) -> Optional[T]:
+    """Returns the first item of ``aiter`` or ``default`` if ``aiter`` is empty."""
+    async for item in aiter:
+        return item
+    return default
+
+
 async def await_cancelled(awaitable: Awaitable) -> bool:
 async def await_cancelled(awaitable: Awaitable) -> bool:
     try:
     try:
         await awaitable
         await awaitable

+ 2 - 2
setup.py

@@ -14,8 +14,8 @@ from setuptools import find_packages, setup
 from setuptools.command.build_py import build_py
 from setuptools.command.build_py import build_py
 from setuptools.command.develop import develop
 from setuptools.command.develop import develop
 
 
-P2PD_VERSION = "v0.3.1"
-P2PD_CHECKSUM = "15292b880c6b31f5b3c36084b3acc17f"
+P2PD_VERSION = "v0.3.4"
+P2PD_CHECKSUM = "194dca06116fdd36bc4b681d18f3b9cb"
 LIBP2P_TAR_URL = f"https://github.com/learning-at-home/go-libp2p-daemon/archive/refs/tags/{P2PD_VERSION}.tar.gz"
 LIBP2P_TAR_URL = f"https://github.com/learning-at-home/go-libp2p-daemon/archive/refs/tags/{P2PD_VERSION}.tar.gz"
 
 
 here = os.path.abspath(os.path.dirname(__file__))
 here = os.path.abspath(os.path.dirname(__file__))

+ 2 - 2
tests/test_dht.py

@@ -14,7 +14,7 @@ from test_utils.dht_swarms import launch_dht_instances
 
 
 @pytest.mark.asyncio
 @pytest.mark.asyncio
 async def test_startup_error():
 async def test_startup_error():
-    with pytest.raises(hivemind.p2p.P2PDaemonError, match=r"Failed to connect to bootstrap peers"):
+    with pytest.raises(hivemind.p2p.P2PDaemonError, match=r"(?i)Failed to connect to bootstrap peers"):
         hivemind.DHT(
         hivemind.DHT(
             initial_peers=[f"/ip4/127.0.0.1/tcp/{get_free_port()}/p2p/QmdaK4LUeQaKhqSFPRu9N7MvXUEWDxWwtCvPrS444tCgd1"],
             initial_peers=[f"/ip4/127.0.0.1/tcp/{get_free_port()}/p2p/QmdaK4LUeQaKhqSFPRu9N7MvXUEWDxWwtCvPrS444tCgd1"],
             start=True,
             start=True,
@@ -118,7 +118,7 @@ async def test_dht_get_visible_maddrs():
 
 
     dummy_endpoint = Multiaddr("/ip4/123.45.67.89/tcp/31337")
     dummy_endpoint = Multiaddr("/ip4/123.45.67.89/tcp/31337")
     p2p = await hivemind.p2p.P2P.create(announce_maddrs=[dummy_endpoint])
     p2p = await hivemind.p2p.P2P.create(announce_maddrs=[dummy_endpoint])
-    dht = hivemind.DHT(start=True, p2p=await p2p.replicate(p2p.daemon_listen_maddr))
+    dht = hivemind.DHT(start=True, p2p=p2p)
 
 
     assert dht.get_visible_maddrs() == [dummy_endpoint.encapsulate(f"/p2p/{p2p.peer_id}")]
     assert dht.get_visible_maddrs() == [dummy_endpoint.encapsulate(f"/p2p/{p2p.peer_id}")]
     dht.shutdown()
     dht.shutdown()

+ 29 - 4
tests/test_p2p_daemon.py

@@ -10,7 +10,7 @@ import pytest
 from multiaddr import Multiaddr
 from multiaddr import Multiaddr
 
 
 from hivemind.p2p import P2P, P2PDaemonError, P2PHandlerError
 from hivemind.p2p import P2P, P2PDaemonError, P2PHandlerError
-from hivemind.proto import dht_pb2
+from hivemind.proto import dht_pb2, test_pb2
 from hivemind.utils.networking import get_free_port
 from hivemind.utils.networking import get_free_port
 from hivemind.utils.serializer import MSGPackSerializer
 from hivemind.utils.serializer import MSGPackSerializer
 
 
@@ -36,7 +36,7 @@ async def test_daemon_killed_on_del():
 
 
 @pytest.mark.asyncio
 @pytest.mark.asyncio
 async def test_startup_error_message():
 async def test_startup_error_message():
-    with pytest.raises(P2PDaemonError, match=r"Failed to connect to bootstrap peers"):
+    with pytest.raises(P2PDaemonError, match=r"(?i)Failed to connect to bootstrap peers"):
         await P2P.create(
         await P2P.create(
             initial_peers=[f"/ip4/127.0.0.1/tcp/{get_free_port()}/p2p/QmdaK4LUeQaKhqSFPRu9N7MvXUEWDxWwtCvPrS444tCgd1"]
             initial_peers=[f"/ip4/127.0.0.1/tcp/{get_free_port()}/p2p/QmdaK4LUeQaKhqSFPRu9N7MvXUEWDxWwtCvPrS444tCgd1"]
         )
         )
@@ -63,9 +63,9 @@ async def test_transports(host_maddrs: List[Multiaddr]):
     await client.wait_for_at_least_n_peers(1)
     await client.wait_for_at_least_n_peers(1)
 
 
     peers = await client.list_peers()
     peers = await client.list_peers()
-    assert len(peers) == 1
+    assert len({p.peer_id for p in peers}) == 1
     peers = await server.list_peers()
     peers = await server.list_peers()
-    assert len(peers) == 1
+    assert len({p.peer_id for p in peers}) == 1
 
 
 
 
 @pytest.mark.asyncio
 @pytest.mark.asyncio
@@ -83,6 +83,31 @@ async def test_daemon_replica_does_not_affect_primary():
     assert not is_process_running(child_pid)
     assert not is_process_running(child_pid)
 
 
 
 
+@pytest.mark.asyncio
+async def test_unary_handler_edge_cases():
+    p2p = await P2P.create()
+    p2p_replica = await P2P.replicate(p2p.daemon_listen_maddr)
+
+    async def square_handler(data: test_pb2.TestRequest, context):
+        return test_pb2.TestResponse(number=data.number ** 2)
+
+    await p2p.add_protobuf_handler("square", square_handler, test_pb2.TestRequest)
+
+    # try adding a duplicate handler
+    with pytest.raises(P2PDaemonError):
+        await p2p.add_protobuf_handler("square", square_handler, test_pb2.TestRequest)
+
+    # try adding a duplicate handler from replicated p2p
+    with pytest.raises(P2PDaemonError):
+        await p2p_replica.add_protobuf_handler("square", square_handler, test_pb2.TestRequest)
+
+    # try dialing yourself
+    with pytest.raises(P2PDaemonError):
+        await p2p_replica.call_protobuf_handler(
+            p2p.peer_id, "square", test_pb2.TestRequest(number=41), test_pb2.TestResponse
+        )
+
+
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(
     "should_cancel,replicate",
     "should_cancel,replicate",
     [
     [

+ 13 - 6
tests/test_p2p_daemon_bindings.py

@@ -199,24 +199,31 @@ def test_parse_conn_protocol_invalid(maddr_str):
 
 
 
 
 @pytest.mark.parametrize("control_maddr_str", ("/unix/123", "/ip4/127.0.0.1/tcp/6666"))
 @pytest.mark.parametrize("control_maddr_str", ("/unix/123", "/ip4/127.0.0.1/tcp/6666"))
-def test_client_ctor_control_maddr(control_maddr_str):
+@pytest.mark.asyncio
+async def test_client_create_control_maddr(control_maddr_str):
     c = DaemonConnector(Multiaddr(control_maddr_str))
     c = DaemonConnector(Multiaddr(control_maddr_str))
     assert c.control_maddr == Multiaddr(control_maddr_str)
     assert c.control_maddr == Multiaddr(control_maddr_str)
 
 
 
 
-def test_client_ctor_default_control_maddr():
+def test_client_create_default_control_maddr():
     c = DaemonConnector()
     c = DaemonConnector()
     assert c.control_maddr == Multiaddr(DaemonConnector.DEFAULT_CONTROL_MADDR)
     assert c.control_maddr == Multiaddr(DaemonConnector.DEFAULT_CONTROL_MADDR)
 
 
 
 
 @pytest.mark.parametrize("listen_maddr_str", ("/unix/123", "/ip4/127.0.0.1/tcp/6666"))
 @pytest.mark.parametrize("listen_maddr_str", ("/unix/123", "/ip4/127.0.0.1/tcp/6666"))
-def test_control_client_ctor_listen_maddr(listen_maddr_str):
-    c = ControlClient(daemon_connector=DaemonConnector(), listen_maddr=Multiaddr(listen_maddr_str))
+@pytest.mark.asyncio
+async def test_control_client_create_listen_maddr(listen_maddr_str):
+    c = await ControlClient.create(
+        daemon_connector=DaemonConnector(),
+        listen_maddr=Multiaddr(listen_maddr_str),
+        use_persistent_conn=False,
+    )
     assert c.listen_maddr == Multiaddr(listen_maddr_str)
     assert c.listen_maddr == Multiaddr(listen_maddr_str)
 
 
 
 
-def test_control_client_ctor_default_listen_maddr():
-    c = ControlClient(daemon_connector=DaemonConnector())
+@pytest.mark.asyncio
+async def test_control_client_create_default_listen_maddr():
+    c = await ControlClient.create(daemon_connector=DaemonConnector(), use_persistent_conn=False)
     assert c.listen_maddr == Multiaddr(ControlClient.DEFAULT_LISTEN_MADDR)
     assert c.listen_maddr == Multiaddr(ControlClient.DEFAULT_LISTEN_MADDR)
 
 
 
 

+ 3 - 2
tests/test_p2p_servicer.py

@@ -5,6 +5,7 @@ import pytest
 
 
 from hivemind.p2p import P2P, P2PContext, ServicerBase
 from hivemind.p2p import P2P, P2PContext, ServicerBase
 from hivemind.proto import test_pb2
 from hivemind.proto import test_pb2
+from hivemind.utils.asyncio import anext
 
 
 
 
 @pytest.fixture
 @pytest.fixture
@@ -139,9 +140,9 @@ async def test_unary_stream_cancel(server_client, cancel_reason):
         writer.close()
         writer.close()
     elif cancel_reason == "close_generator":
     elif cancel_reason == "close_generator":
         stub = ExampleServicer.get_stub(client, server.peer_id)
         stub = ExampleServicer.get_stub(client, server.peer_id)
-        iter = stub.rpc_wait(test_pb2.TestRequest(number=10)).__aiter__()
+        iter = stub.rpc_wait(test_pb2.TestRequest(number=10))
 
 
-        assert await iter.__anext__() == test_pb2.TestResponse(number=11)
+        assert await anext(iter) == test_pb2.TestResponse(number=11)
         await asyncio.sleep(0.25)
         await asyncio.sleep(0.25)
 
 
         await iter.aclose()
         await iter.aclose()

+ 12 - 1
tests/test_util_modules.py

@@ -13,7 +13,7 @@ from hivemind.proto.dht_pb2_grpc import DHTStub
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
 from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
 from hivemind.utils import DHTExpiration, HeapEntry, MSGPackSerializer, ValueWithExpiration
 from hivemind.utils import DHTExpiration, HeapEntry, MSGPackSerializer, ValueWithExpiration
-from hivemind.utils.asyncio import achain, aenumerate, aiter, amap_in_executor, anext, azip
+from hivemind.utils.asyncio import achain, aenumerate, afirst, aiter, amap_in_executor, anext, asingle, azip
 from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.mpfuture import InvalidStateError
 from hivemind.utils.mpfuture import InvalidStateError
 
 
@@ -498,3 +498,14 @@ async def test_asyncio_utils():
         await anext(iterator)
         await anext(iterator)
 
 
     assert [item async for item in achain(_aiterate(), aiter(*range(5)))] == ["foo", "bar", "baz"] + list(range(5))
     assert [item async for item in achain(_aiterate(), aiter(*range(5)))] == ["foo", "bar", "baz"] + list(range(5))
+
+    assert await asingle(aiter(1)) == 1
+    with pytest.raises(ValueError):
+        await asingle(aiter())
+    with pytest.raises(ValueError):
+        await asingle(aiter(1, 2, 3))
+
+    assert await afirst(aiter(1)) == 1
+    assert await afirst(aiter()) is None
+    assert await afirst(aiter(), -1) == -1
+    assert await afirst(aiter(1, 2, 3)) == 1

+ 1 - 1
tests/test_utils/p2p_daemon.py

@@ -157,7 +157,7 @@ async def _make_p2pd_pair(
     )
     )
     # wait for daemon ready
     # wait for daemon ready
     await p2pd.wait_until_ready()
     await p2pd.wait_until_ready()
-    client = Client(control_maddr=control_maddr, listen_maddr=listen_maddr)
+    client = await Client.create(control_maddr=control_maddr, listen_maddr=listen_maddr)
     try:
     try:
         async with client.listen():
         async with client.listen():
             yield DaemonTuple(daemon=p2pd, client=client)
             yield DaemonTuple(daemon=p2pd, client=client)