Ver código fonte

Merge remote-tracking branch 'origin/master' into TPU

Michael Diskin 4 anos atrás
pai
commit
a2347ece86

+ 0 - 2
.github/workflows/run-tests.yml

@@ -34,7 +34,6 @@ jobs:
         run: |
           cd tests
           pytest --durations=0 --durations-min=1.0 -v
-
   build_and_test_p2pd:
     runs-on: ubuntu-latest
     timeout-minutes: 10
@@ -61,7 +60,6 @@ jobs:
         run: |
           cd tests
           pytest -k "p2p" -v
-
   codecov_in_develop_mode:
 
     runs-on: ubuntu-latest

+ 1 - 0
.gitignore

@@ -54,6 +54,7 @@ coverage.xml
 .project
 .pydevproject
 .idea
+.vscode
 .ipynb_checkpoints
 
 # Rope

+ 7 - 0
examples/albert/arguments.py

@@ -34,6 +34,13 @@ class BaseTrainingArguments:
         default_factory=list,
         metadata={"help": "Visible multiaddrs the host announces for external connections from other p2p instances"},
     )
+    identity_path: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Path to a pre-generated private key file. If defined, makes the peer ID deterministic. "
+            "May be generated using ``./p2p-keygen`` from ``go-libp2p-daemon``."
+        },
+    )
 
 
 @dataclass

+ 1 - 0
examples/albert/run_trainer.py

@@ -247,6 +247,7 @@ def main():
         use_ipfs=collaboration_args.use_ipfs,
         host_maddrs=collaboration_args.host_maddrs,
         announce_maddrs=collaboration_args.announce_maddrs,
+        identity_path=collaboration_args.identity_path,
     )
     utils.log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=collaboration_args.use_ipfs)
 

+ 1 - 0
examples/albert/run_training_monitor.py

@@ -168,6 +168,7 @@ if __name__ == "__main__":
         use_ipfs=monitor_args.use_ipfs,
         host_maddrs=monitor_args.host_maddrs,
         announce_maddrs=monitor_args.announce_maddrs,
+        identity_path=monitor_args.identity_path,
     )
     utils.log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=monitor_args.use_ipfs)
 

+ 1 - 1
hivemind/__init__.py

@@ -19,4 +19,4 @@ from hivemind.optim import (
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, PeerInfo
 from hivemind.utils import *
 
-__version__ = "1.0.0.dev0"
+__version__ = "1.0.0dev0"

+ 1 - 1
hivemind/averaging/allreduce.py

@@ -8,7 +8,7 @@ from hivemind.averaging.partition import AllreduceException, TensorPartContainer
 from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase, StubBase
 from hivemind.proto import averaging_pb2
 from hivemind.utils import get_logger
-from hivemind.utils.asyncio import achain, aenumerate, aiter, amap_in_executor, anext, asingle
+from hivemind.utils.asyncio import achain, aenumerate, afirst, aiter, amap_in_executor, anext
 from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 
 # flavour types

+ 8 - 3
hivemind/averaging/averager.py

@@ -96,7 +96,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         prefix: str,
         target_group_size: int,
         min_group_size: int = 2,
-        initial_group_bits: Optional[str] = None,
+        initial_group_bits: str = "",
         averaging_expiration: float = 15,
         request_timeout: float = 3,
         averaging_alpha: float = 1.0,
@@ -117,7 +117,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         ), "bandwidth must be a non-negative float32"
         if not is_power_of_two(target_group_size):
             logger.warning("It is recommended to set target_group_size to a power of 2.")
-        assert initial_group_bits is None or all(bit in "01" for bit in initial_group_bits)
+        assert all(bit in "01" for bit in initial_group_bits)
         assert not client_mode or not auxiliary, "auxiliary peers must accept incoming connections"
 
         super().__init__()
@@ -241,7 +241,12 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                 self._ready.set_result(None)
 
                 while True:
-                    method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
+                    try:
+                        method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
+                    except (OSError, ConnectionError) as e:
+                        logger.exception(e)
+                        await asyncio.sleep(self._matchmaking.request_timeout)
+                        continue
                     task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
                     if method == "_shutdown":
                         await task

+ 12 - 93
hivemind/averaging/key_manager.py

@@ -1,4 +1,3 @@
-import asyncio
 import random
 import re
 from typing import List, Optional, Tuple
@@ -25,31 +24,17 @@ class GroupKeyManager:
     Utility class that declares and fetches averaging-related keys using a DHT
     """
 
-    RESERVED_KEY_FOR_NBITS = "::NBITS"
-
     def __init__(
         self,
         dht: DHT,
         prefix: str,
-        initial_group_bits: Optional[str],
+        initial_group_bits: str,
         target_group_size: int,
-        insufficient_size: Optional[int] = None,
-        excessive_size: Optional[int] = None,
-        nbits_expiration: float = 60,
-        nbits_rewrite_grace_period: float = 15,
     ):
-        assert initial_group_bits is None or all(bit in "01" for bit in initial_group_bits)
-        if initial_group_bits is None:
-            search_result = dht.get(f"{prefix}.0b", latest=True)
-            initial_group_nbits = self.get_suggested_nbits(search_result) or 0
-            initial_group_bits = "".join(random.choice("01") for _ in range(initial_group_nbits))
+        assert all(bit in "01" for bit in initial_group_bits)
         self.dht, self.prefix, self.group_bits = dht, prefix, initial_group_bits
-        self.peer_id = dht.peer_id
         self.target_group_size = target_group_size
-        self.insufficient_size = insufficient_size or max(1, target_group_size // 2)
-        self.excessive_size = excessive_size or target_group_size * 3
-        self.nbits_expiration, self.nbits_grace_period = nbits_expiration, nbits_rewrite_grace_period
-        self.suggested_nbits: Optional[int] = None
+        self.peer_id = dht.peer_id
 
     @property
     def current_key(self) -> GroupKey:
@@ -93,51 +78,16 @@ class GroupKeyManager:
         if result is None or not isinstance(result.value, dict):
             logger.debug(f"Allreduce group not found: {group_key}, creating new group.")
             return []
-        averagers = [
-            (PeerID(key), looking_for_group.expiration_time)
-            for key, looking_for_group in result.value.items()
-            if key != self.RESERVED_KEY_FOR_NBITS and (not only_active or looking_for_group.value)
-        ]
-        num_active_averagers = sum(
-            1
-            for key, looking_for_group in result.value.items()
-            if key != self.RESERVED_KEY_FOR_NBITS and looking_for_group.value
-        )
-
-        suggested_nbits = self.get_suggested_nbits(result)
-        if (
-            suggested_nbits is not None
-            and suggested_nbits != len(self.group_bits)
-            and suggested_nbits != self.suggested_nbits
-        ):
-            self.suggested_nbits = suggested_nbits
-            logger.warning(f"{self.peer_id} - another averager suggested {self.suggested_nbits}-bit keys")
-        elif num_active_averagers >= self.excessive_size:
-            self.suggested_nbits = max(suggested_nbits or 0, len(self.group_bits) + 1)
-            logger.warning(f"{self.peer_id} - too many peers in bucket, switching to {self.suggested_nbits}-bit keys")
+        averagers = []
+        for key, looking_for_group in result.value.items():
+            try:
+                if only_active and not looking_for_group.value:
+                    continue
+                averagers.append((PeerID(key), looking_for_group.expiration_time))
+            except Exception as e:
+                logger.warning(f"Could not parse group key {key} ({looking_for_group}, exc={e})")
         return averagers
 
-    async def declare_nbits(self, group_key: GroupKey, nbits: int, expiration_time: DHTExpiration) -> bool:
-        """notify other peers that they can run averaging at this depth"""
-        return await self.dht.store(
-            key=group_key,
-            subkey=self.RESERVED_KEY_FOR_NBITS,
-            value=nbits,
-            expiration_time=expiration_time,
-            return_future=True,
-        )
-
-    @classmethod
-    def get_suggested_nbits(cls, search_result: Optional[ValueWithExpiration]) -> Optional[int]:
-        if (
-            isinstance(search_result, ValueWithExpiration)
-            and cls.RESERVED_KEY_FOR_NBITS in search_result.value
-            and isinstance(search_result.value[cls.RESERVED_KEY_FOR_NBITS].value, int)
-        ):
-            return search_result.value[cls.RESERVED_KEY_FOR_NBITS].value
-        else:
-            return None
-
     async def update_key_on_group_assembled(self, group_info: GroupInfo, is_leader: bool = True):
         """this function is triggered every time an averager finds an allreduce group"""
         rng = random.Random(group_info.group_id)
@@ -148,37 +98,6 @@ class GroupKeyManager:
         self.group_bits = (self.group_bits + new_bits)[-len(self.group_bits) :] if self.group_bits else ""
         logger.debug(f"{self.peer_id} - updated group key to {self.group_bits}")
 
-        if is_leader and self.insufficient_size < group_info.group_size < self.excessive_size:
-            asyncio.create_task(self.notify_stragglers())
-        if self.suggested_nbits is not None and self.suggested_nbits != len(self.group_bits):
-            num_extra_bits = max(0, self.suggested_nbits - len(self.group_bits))
-            self.group_bits = "".join((random.choice("01") for _ in range(num_extra_bits))) + self.group_bits
-            self.group_bits = self.group_bits[-self.suggested_nbits :]
-        self.suggested_nbits = None
-
     async def update_key_on_not_enough_peers(self):
         """this function is triggered whenever averager fails to assemble group within timeout"""
-        new_nbits = self.suggested_nbits if self.suggested_nbits is not None else len(self.group_bits) - 1
-        prev_nbits, self.group_bits = self.group_bits, self.group_bits[-new_nbits:] if new_nbits else ""
-        if self.group_bits != prev_nbits:
-            logger.warning(f"{self.peer_id} - switching to {len(self.group_bits)}-bit keys")
-        self.suggested_nbits = None
-
-    async def notify_stragglers(self):
-        """Find averagers that have fewer nbits and redirect them to your current nbits"""
-        for nbits in reversed(range(1, len(self.group_bits) - 1)):
-            preceding_key = f"{self.prefix}.0b{self.group_bits[-nbits:] if nbits else ''}"
-            preceding_data, _ = await self.dht.get(preceding_key, latest=False, return_future=True) or ({}, None)
-
-            if len(preceding_data) > 0 and self.RESERVED_KEY_FOR_NBITS not in preceding_data:
-                await self.declare_nbits(preceding_key, len(self.group_bits), get_dht_time() + self.nbits_expiration)
-                break
-
-        root_data, _ = await self.dht.get(f"{self.prefix}.0b", latest=False, return_future=True) or ({}, None)
-        if (
-            isinstance(root_data, dict)
-            and root_data.get(self.RESERVED_KEY_FOR_NBITS, (None, -float("inf")))[1]
-            > get_dht_time() + self.nbits_grace_period
-        ):
-            return
-        await self.declare_nbits(f"{self.prefix}.0b", len(self.group_bits), get_dht_time() + self.nbits_expiration)
+        pass  # to be implemented in subclasses

+ 30 - 42
hivemind/averaging/matchmaking.py

@@ -15,7 +15,7 @@ from hivemind.dht import DHT, DHTID, DHTExpiration
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
 from hivemind.proto import averaging_pb2
 from hivemind.utils import TimedStorage, get_dht_time, get_logger, timed_storage
-from hivemind.utils.asyncio import anext
+from hivemind.utils.asyncio import anext, cancel_and_wait
 
 logger = get_logger(__name__)
 
@@ -45,7 +45,7 @@ class Matchmaking:
         min_group_size: int,
         request_timeout: float,
         client_mode: bool,
-        initial_group_bits: Optional[str] = None,
+        initial_group_bits: str = "",
         averaging_expiration: float = 15,
     ):
         assert "." not in prefix, "group prefix must be a string without ."
@@ -127,10 +127,9 @@ class Matchmaking:
                 raise
 
             finally:
-                if not request_leaders_task.done():
-                    request_leaders_task.cancel()
-                if not self.assembled_group.done():
-                    self.assembled_group.cancel()
+                await cancel_and_wait(request_leaders_task)
+                self.assembled_group.cancel()
+
                 while len(self.current_followers) > 0:
                     await self.follower_was_discarded.wait()
                     self.follower_was_discarded.clear()
@@ -229,7 +228,7 @@ class Matchmaking:
             logger.debug(f"{self} - potential leader {leader} did not respond within {self.request_timeout}")
             return None
         except (P2PHandlerError, StopAsyncIteration) as e:
-            logger.error(f"{self} - failed to request potential leader {leader}: {e}")
+            logger.exception(f"{self} - failed to request potential leader {leader}:")
             return None
 
         finally:
@@ -413,10 +412,9 @@ class PotentialLeaders:
             try:
                 yield self
             finally:
-                if not update_queue_task.done():
-                    update_queue_task.cancel()
-                if declare and not declare_averager_task.done():
-                    declare_averager_task.cancel()
+                await cancel_and_wait(update_queue_task)
+                if declare:
+                    await cancel_and_wait(declare_averager_task)
 
                 for field in (
                     self.past_attempts,
@@ -477,37 +475,31 @@ class PotentialLeaders:
         else:
             return min(get_dht_time() + self.averaging_expiration, self.search_end_time)
 
-    async def _update_queue_periodically(self, key_manager: GroupKeyManager):
-        try:
-            DISCREPANCY = timed_storage.MAX_DHT_TIME_DISCREPANCY_SECONDS
-            while get_dht_time() < self.search_end_time:
-                new_peers = await key_manager.get_averagers(key_manager.current_key, only_active=True)
-                self.max_assured_time = max(
-                    self.max_assured_time, get_dht_time() + self.averaging_expiration - DISCREPANCY
-                )
+    async def _update_queue_periodically(self, key_manager: GroupKeyManager) -> None:
+        DISCREPANCY = timed_storage.MAX_DHT_TIME_DISCREPANCY_SECONDS
+        while get_dht_time() < self.search_end_time:
+            new_peers = await key_manager.get_averagers(key_manager.current_key, only_active=True)
+            self.max_assured_time = max(
+                self.max_assured_time, get_dht_time() + self.averaging_expiration - DISCREPANCY
+            )
 
-                self.leader_queue.clear()
-                for peer, peer_expiration_time in new_peers:
-                    if peer == self.peer_id or (peer, peer_expiration_time) in self.past_attempts:
-                        continue
-                    self.leader_queue.store(peer, peer_expiration_time, peer_expiration_time)
-                    self.max_assured_time = max(self.max_assured_time, peer_expiration_time - DISCREPANCY)
+            self.leader_queue.clear()
+            for peer, peer_expiration_time in new_peers:
+                if peer == self.peer_id or (peer, peer_expiration_time) in self.past_attempts:
+                    continue
+                self.leader_queue.store(peer, peer_expiration_time, peer_expiration_time)
+                self.max_assured_time = max(self.max_assured_time, peer_expiration_time - DISCREPANCY)
 
-                self.update_finished.set()
+            self.update_finished.set()
 
-                await asyncio.wait(
-                    {self.running.wait(), self.update_triggered.wait()},
-                    return_when=asyncio.ALL_COMPLETED,
-                    timeout=self.search_end_time - get_dht_time() if isfinite(self.search_end_time) else None,
-                )
-                self.update_triggered.clear()
-        except (concurrent.futures.CancelledError, asyncio.CancelledError):
-            return  # note: this is a compatibility layer for python3.7
-        except Exception as e:
-            logger.error(f"{self.peer_id} - caught {type(e)}: {e}")
-            raise
+            await asyncio.wait(
+                {self.running.wait(), self.update_triggered.wait()},
+                return_when=asyncio.ALL_COMPLETED,
+                timeout=self.search_end_time - get_dht_time() if isfinite(self.search_end_time) else None,
+            )
+            self.update_triggered.clear()
 
-    async def _declare_averager_periodically(self, key_manager: GroupKeyManager):
+    async def _declare_averager_periodically(self, key_manager: GroupKeyManager) -> None:
         async with self.lock_declare:
             try:
                 while True:
@@ -521,10 +513,6 @@ class PotentialLeaders:
                     await asyncio.sleep(self.declared_expiration_time - get_dht_time())
                     if self.running.is_set() and len(self.leader_queue) == 0:
                         await key_manager.update_key_on_not_enough_peers()
-            except (concurrent.futures.CancelledError, asyncio.CancelledError):
-                pass  # note: this is a compatibility layer for python3.7
-            except Exception as e:  # note: we catch exceptions here because otherwise they are never printed
-                logger.error(f"{self.peer_id} - caught {type(e)}: {e}")
             finally:
                 if self.declared_group_key is not None:
                     prev_declared_key, prev_expiration_time = self.declared_group_key, self.declared_expiration_time

+ 19 - 12
hivemind/dht/__init__.py

@@ -27,7 +27,7 @@ from hivemind.dht.node import DEFAULT_NUM_WORKERS, DHTNode
 from hivemind.dht.routing import DHTID, DHTKey, DHTValue, Subkey
 from hivemind.dht.validation import CompositeValidator, RecordValidatorBase
 from hivemind.p2p import P2P, PeerID
-from hivemind.utils import DHTExpiration, MPFuture, ValueWithExpiration, await_cancelled, get_logger, switch_to_uvloop
+from hivemind.utils import DHTExpiration, MPFuture, ValueWithExpiration, get_logger, switch_to_uvloop
 
 logger = get_logger(__name__)
 
@@ -61,6 +61,7 @@ class DHT(mp.Process):
         initial_peers: Optional[Sequence[Union[Multiaddr, str]]] = None,
         *,
         start: bool,
+        p2p: Optional[P2P] = None,
         daemon: bool = True,
         num_workers: int = DEFAULT_NUM_WORKERS,
         record_validators: Iterable[RecordValidatorBase] = (),
@@ -94,6 +95,8 @@ class DHT(mp.Process):
         self._client_mode = None
         self._p2p_replica = None
 
+        self._daemon_listen_maddr = p2p.daemon_listen_maddr if p2p is not None else None
+
         if start:
             self.run_in_background(await_ready=await_ready)
 
@@ -105,10 +108,16 @@ class DHT(mp.Process):
 
             async def _run():
                 try:
+                    if self._daemon_listen_maddr is not None:
+                        replicated_p2p = await P2P.replicate(self._daemon_listen_maddr)
+                    else:
+                        replicated_p2p = None
+
                     self._node = await DHTNode.create(
                         initial_peers=self.initial_peers,
                         num_workers=self.num_workers,
                         record_validator=self._record_validator,
+                        p2p=replicated_p2p,
                         **self.kwargs,
                     )
                 except Exception as e:
@@ -119,7 +128,12 @@ class DHT(mp.Process):
                 self._ready.set_result(None)
 
                 while True:
-                    method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
+                    try:
+                        method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
+                    except (OSError, ConnectionError) as e:
+                        logger.exception(e)
+                        await asyncio.sleep(self._node.protocol.wait_timeout)
+                        continue
                     task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
                     if method == "_shutdown":
                         await task
@@ -247,18 +261,11 @@ class DHT(mp.Process):
     async def _run_coroutine(
         self, coro: Callable[[DHT, DHTNode], Awaitable[ReturnType]], future: MPFuture[ReturnType]
     ):
-        main_task = asyncio.create_task(coro(self, self._node))
-        cancel_task = asyncio.create_task(await_cancelled(future))
         try:
-            await asyncio.wait({main_task, cancel_task}, return_when=asyncio.FIRST_COMPLETED)
-            if future.cancelled():
-                main_task.cancel()
-            else:
-                future.set_result(await main_task)
+            future.set_result(await coro(self, self._node))
         except BaseException as e:
-            logger.exception(f"Caught an exception when running a coroutine: {e}")
-            if not future.done():
-                future.set_exception(e)
+            logger.exception("Caught an exception when running a coroutine:")
+            future.set_exception(e)
 
     def add_validators(self, record_validators: Iterable[RecordValidatorBase]) -> None:
         if not self._ready.done():

+ 1 - 1
hivemind/dht/protocol.py

@@ -81,7 +81,7 @@ class DHTProtocol(ServicerBase):
 
     def __init__(self, *, _initialized_with_create=False):
         """Internal init method. Please use DHTProtocol.create coroutine to spawn new protocol instances"""
-        assert _initialized_with_create, " Please use DHTProtocol.create coroutine to spawn new protocol instances "
+        assert _initialized_with_create, "Please use DHTProtocol.create coroutine to spawn new protocol instances"
         super().__init__()
 
     def get_stub(self, peer: PeerID) -> AuthRPCWrapper:

+ 99 - 45
hivemind/p2p/p2p_daemon.py

@@ -5,12 +5,14 @@ from collections.abc import AsyncIterable as AsyncIterableABC
 from contextlib import closing, suppress
 from dataclasses import dataclass
 from importlib.resources import path
-from typing import Any, AsyncIterator, Awaitable, Callable, List, Optional, Sequence, Tuple, TypeVar, Union
+from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional, Sequence, Tuple, Type, TypeVar, Union
 
+from google.protobuf.message import Message
 from multiaddr import Multiaddr
 
 import hivemind.hivemind_cli as cli
 import hivemind.p2p.p2p_daemon_bindings.p2pclient as p2pclient
+from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError, P2PHandlerError
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
 from hivemind.proto.p2pd_pb2 import RPCError
 from hivemind.utils.asyncio import aiter, asingle
@@ -27,7 +29,6 @@ class P2PContext(object):
     handle_name: str
     local_id: PeerID
     remote_id: PeerID = None
-    remote_maddr: Multiaddr = None
 
 
 class P2P:
@@ -65,6 +66,7 @@ class P2P:
 
     def __init__(self):
         self.peer_id = None
+        self._client = None
         self._child = None
         self._alive = False
         self._reader_task = None
@@ -74,43 +76,50 @@ class P2P:
     async def create(
         cls,
         initial_peers: Optional[Sequence[Union[Multiaddr, str]]] = None,
-        use_ipfs: bool = False,
-        host_maddrs: Optional[Sequence[Union[Multiaddr, str]]] = ("/ip4/127.0.0.1/tcp/0",),
+        *,
         announce_maddrs: Optional[Sequence[Union[Multiaddr, str]]] = None,
-        quic: bool = False,
-        tls: bool = True,
+        auto_nat: bool = True,
         conn_manager: bool = True,
         dht_mode: str = "dht_server",
         force_reachability: Optional[str] = None,
+        host_maddrs: Optional[Sequence[Union[Multiaddr, str]]] = ("/ip4/127.0.0.1/tcp/0",),
+        identity_path: Optional[str] = None,
+        idle_timeout: float = 30,
         nat_port_map: bool = True,
-        auto_nat: bool = True,
+        quic: bool = False,
+        relay_hop_limit: int = 0,
+        startup_timeout: float = 15,
+        tls: bool = True,
+        use_auto_relay: bool = False,
+        use_ipfs: bool = False,
         use_relay: bool = True,
         use_relay_hop: bool = False,
         use_relay_discovery: bool = False,
-        use_auto_relay: bool = False,
-        relay_hop_limit: int = 0,
-        startup_timeout: float = 15,
     ) -> "P2P":
         """
         Start a new p2pd process and connect to it.
         :param initial_peers: List of bootstrap peers
-        :param use_ipfs: Bootstrap to IPFS (incompatible with initial_peers)
-        :param host_maddrs: Multiaddrs to listen for external connections from other p2p instances
+        :param auto_nat: Enables the AutoNAT service
         :param announce_maddrs: Visible multiaddrs that the peer will announce
-          for external connections from other p2p instances
-        :param quic: Enables the QUIC transport
-        :param tls: Enables TLS1.3 channel security protocol
+                                for external connections from other p2p instances
         :param conn_manager: Enables the Connection Manager
         :param dht_mode: DHT mode (dht_client/dht_server/dht)
         :param force_reachability: Force reachability mode (public/private)
+        :param host_maddrs: Multiaddrs to listen for external connections from other p2p instances
+        :param identity_path: Path to a pre-generated private key file. If defined, makes the peer ID deterministic.
+                              May be generated using ``./p2p-keygen`` from ``go-libp2p-daemon``.
+        :param idle_timeout: kill daemon if client has been idle for a given number of
+                             seconds before opening persistent streams
         :param nat_port_map: Enables NAT port mapping
-        :param auto_nat: Enables the AutoNAT service
+        :param quic: Enables the QUIC transport
+        :param relay_hop_limit: sets the hop limit for hop relays
+        :param startup_timeout: raise a P2PDaemonError if the daemon does not start in ``startup_timeout`` seconds
+        :param tls: Enables TLS1.3 channel security protocol
+        :param use_auto_relay: enables autorelay
+        :param use_ipfs: Bootstrap to IPFS (incompatible with initial_peers)
         :param use_relay: enables circuit relay
         :param use_relay_hop: enables hop for relay
         :param use_relay_discovery: enables passive discovery for relay
-        :param use_auto_relay: enables autorelay
-        :param relay_hop_limit: sets the hop limit for hop relays
-        :param startup_timeout: raise a P2PDaemonError if the daemon does not start in ``startup_timeout`` seconds
         :return: a wrapper for the p2p daemon
         """
 
@@ -136,21 +145,24 @@ class P2P:
         ]:
             if value:
                 process_kwargs[param] = self._maddrs_to_str(value)
+        if identity_path is not None:
+            process_kwargs["id"] = identity_path
 
         proc_args = self._make_process_args(
             str(p2pd_path),
-            listen=self._daemon_listen_maddr,
-            quic=quic,
-            tls=tls,
+            autoRelay=use_auto_relay,
+            autonat=auto_nat,
+            b=need_bootstrap,
             connManager=conn_manager,
+            idleTimeout=f"{idle_timeout}s",
+            listen=self._daemon_listen_maddr,
             natPortMap=nat_port_map,
-            autonat=auto_nat,
+            quic=quic,
             relay=use_relay,
-            relayHop=use_relay_hop,
             relayDiscovery=use_relay_discovery,
-            autoRelay=use_auto_relay,
+            relayHop=use_relay_hop,
             relayHopLimit=relay_hop_limit,
-            b=need_bootstrap,
+            tls=tls,
             **process_kwargs,
         )
 
@@ -167,7 +179,7 @@ class P2P:
             await self.shutdown()
             raise P2PDaemonError(f"Daemon failed to start in {startup_timeout:.1f} seconds")
 
-        self._client = p2pclient.Client(self._daemon_listen_maddr, self._client_listen_maddr)
+        self._client = await p2pclient.Client.create(self._daemon_listen_maddr, self._client_listen_maddr)
         await self._ping_daemon()
         return self
 
@@ -189,7 +201,7 @@ class P2P:
         self._daemon_listen_maddr = daemon_listen_maddr
         self._client_listen_maddr = Multiaddr(cls._UNIX_SOCKET_PREFIX + f"p2pclient-{socket_uid}.sock")
 
-        self._client = p2pclient.Client(self._daemon_listen_maddr, self._client_listen_maddr)
+        self._client = await p2pclient.Client.create(self._daemon_listen_maddr, self._client_listen_maddr)
 
         await self._ping_daemon()
         return self
@@ -258,7 +270,7 @@ class P2P:
 
     @staticmethod
     async def receive_protobuf(
-        input_protobuf_type: type, reader: asyncio.StreamReader
+        input_protobuf_type: Type[Message], reader: asyncio.StreamReader
     ) -> Tuple[Optional[TInputProtobuf], Optional[RPCError]]:
         msg_type = await reader.readexactly(1)
         if msg_type == P2P.MESSAGE_MARKER:
@@ -279,7 +291,7 @@ class P2P:
         self,
         name: str,
         handler: Callable[[TInputStream, P2PContext], TOutputStream],
-        input_protobuf_type: type,
+        input_protobuf_type: Type[Message],
         max_prefetch: int = 5,
     ) -> None:
         """
@@ -297,7 +309,6 @@ class P2P:
                 handle_name=name,
                 local_id=self.peer_id,
                 remote_id=stream_info.peer_id,
-                remote_maddr=stream_info.addr,
             )
             requests = asyncio.Queue(max_prefetch)
 
@@ -349,7 +360,7 @@ class P2P:
         await self.add_binary_stream_handler(name, _handle_stream)
 
     async def _iterate_protobuf_stream_handler(
-        self, peer_id: PeerID, name: str, requests: TInputStream, output_protobuf_type: type
+        self, peer_id: PeerID, name: str, requests: TInputStream, output_protobuf_type: Type[Message]
     ) -> TOutputStream:
         _, reader, writer = await self.call_binary_stream_handler(peer_id, name)
 
@@ -381,15 +392,22 @@ class P2P:
         handler: Callable[
             [Union[TInputProtobuf, TInputStream], P2PContext], Union[Awaitable[TOutputProtobuf], TOutputStream]
         ],
-        input_protobuf_type: type,
+        input_protobuf_type: Type[Message],
         *,
         stream_input: bool = False,
+        stream_output: bool = False,
     ) -> None:
         """
         :param stream_input: If True, assume ``handler`` to take ``TInputStream``
                              (not just ``TInputProtobuf``) as input.
+        :param stream_output: If True, assume ``handler`` to return ``TOutputStream``
+                              (not ``Awaitable[TOutputProtobuf]``).
         """
 
+        if not stream_input and not stream_output:
+            await self._add_protobuf_unary_handler(name, handler, input_protobuf_type)
+            return
+
         async def _stream_handler(requests: P2P.TInputStream, context: P2PContext) -> P2P.TOutputStream:
             input = requests if stream_input else await asingle(requests)
             output = handler(input, context)
@@ -402,23 +420,65 @@ class P2P:
 
         await self._add_protobuf_stream_handler(name, _stream_handler, input_protobuf_type)
 
+    async def _add_protobuf_unary_handler(
+        self,
+        handle_name: str,
+        handler: Callable[[TInputProtobuf, P2PContext], Awaitable[TOutputProtobuf]],
+        input_protobuf_type: Type[Message],
+    ) -> None:
+        """
+        Register a request-response (unary) handler. Unary requests and responses
+        are sent through persistent multiplexed connections to the daemon for the
+        sake of reducing the number of open files.
+        :param handle_name: name of the handler (protocol id)
+        :param handler: function handling the unary requests
+        :param input_protobuf_type: protobuf type of the request
+        """
+
+        async def _unary_handler(request: bytes, remote_id: PeerID) -> bytes:
+            input_serialized = input_protobuf_type.FromString(request)
+            context = P2PContext(
+                handle_name=handle_name,
+                local_id=self.peer_id,
+                remote_id=remote_id,
+            )
+
+            response = await handler(input_serialized, context)
+            return response.SerializeToString()
+
+        await self._client.add_unary_handler(handle_name, _unary_handler)
+
     async def call_protobuf_handler(
         self,
         peer_id: PeerID,
         name: str,
         input: Union[TInputProtobuf, TInputStream],
-        output_protobuf_type: type,
+        output_protobuf_type: Type[Message],
     ) -> Awaitable[TOutputProtobuf]:
-        requests = input if isinstance(input, AsyncIterableABC) else aiter(input)
-        responses = self._iterate_protobuf_stream_handler(peer_id, name, requests, output_protobuf_type)
+
+        if not isinstance(input, AsyncIterableABC):
+            return await self._call_unary_protobuf_handler(peer_id, name, input, output_protobuf_type)
+
+        responses = self._iterate_protobuf_stream_handler(peer_id, name, input, output_protobuf_type)
         return await asingle(responses)
 
+    async def _call_unary_protobuf_handler(
+        self,
+        peer_id: PeerID,
+        handle_name: str,
+        input: TInputProtobuf,
+        output_protobuf_type: Type[Message],
+    ) -> Awaitable[TOutputProtobuf]:
+        serialized_input = input.SerializeToString()
+        response = await self._client.call_unary_handler(peer_id, handle_name, serialized_input)
+        return output_protobuf_type.FromString(response)
+
     def iterate_protobuf_handler(
         self,
         peer_id: PeerID,
         name: str,
         input: Union[TInputProtobuf, TInputStream],
-        output_protobuf_type: type,
+        output_protobuf_type: Type[Message],
     ) -> TOutputStream:
         requests = input if isinstance(input, AsyncIterableABC) else aiter(input)
         return self._iterate_protobuf_stream_handler(peer_id, name, requests, output_protobuf_type)
@@ -453,6 +513,8 @@ class P2P:
             await self._child.wait()
 
     def _terminate(self) -> None:
+        if self._client is not None:
+            self._client.close()
         if self._listen_task is not None:
             self._listen_task.cancel()
         if self._reader_task is not None:
@@ -501,11 +563,3 @@ class P2P:
 
         if not ready.done():
             ready.set_exception(P2PDaemonError(f"Daemon failed to start: {last_line}"))
-
-
-class P2PDaemonError(RuntimeError):
-    pass
-
-
-class P2PHandlerError(Exception):
-    pass

+ 189 - 3
hivemind/p2p/p2p_daemon_bindings/control.py

@@ -5,8 +5,9 @@ Author: Kevin Mai-Husan Chia
 """
 
 import asyncio
-from contextlib import asynccontextmanager
-from typing import AsyncIterator, Awaitable, Callable, Dict, Iterable, Sequence, Tuple
+from contextlib import asynccontextmanager, closing
+from typing import AsyncIterator, Awaitable, Callable, Dict, Iterable, Optional, Sequence, Tuple
+from uuid import UUID, uuid4
 
 from multiaddr import Multiaddr, protocols
 
@@ -54,17 +55,75 @@ class DaemonConnector:
         else:
             raise ValueError(f"Protocol not supported: {protocols.protocol_with_code(self.proto_code)}")
 
+    async def open_persistent_connection(self) -> (asyncio.StreamReader, asyncio.StreamWriter):
+        """
+        Open connection to daemon and upgrade it to a persistent one
+        """
+        reader, writer = await self.open_connection()
+        req = p2pd_pb.Request(type=p2pd_pb.Request.PERSISTENT_CONN_UPGRADE)
+        await write_pbmsg(writer, req)
+
+        response = p2pd_pb.Response()
+        await read_pbmsg_safe(reader, response)
+
+        if response.type == "ERROR":
+            raise P2PDaemonError(response.error.msg)
+
+        return reader, writer
+
+
+TUnaryHandler = Callable[[bytes, PeerID], Awaitable[bytes]]
+CallID = UUID
+
 
 class ControlClient:
     DEFAULT_LISTEN_MADDR = "/unix/tmp/p2pclient.sock"
 
     def __init__(
-        self, daemon_connector: DaemonConnector, listen_maddr: Multiaddr = Multiaddr(DEFAULT_LISTEN_MADDR)
+        self,
+        daemon_connector: DaemonConnector,
+        listen_maddr: Multiaddr = Multiaddr(DEFAULT_LISTEN_MADDR),
+        *,
+        _initialized_with_create=False,
     ) -> None:
+        assert _initialized_with_create, "Please use ControlClient.create coroutine to spawn new control instances"
+
         self.listen_maddr = listen_maddr
         self.daemon_connector = daemon_connector
         self.handlers: Dict[str, StreamHandler] = {}
 
+        self.unary_handlers: Dict[str, TUnaryHandler] = {}
+
+        self._pending_messages: asyncio.Queue[p2pd_pb.PersistentConnectionRequest] = asyncio.Queue()
+        self._pending_calls: Dict[CallID, asyncio.Future[bytes]] = {}
+        self._handler_tasks: Dict[CallID, asyncio.Task] = {}
+
+        self._read_task: Optional[asyncio.Task] = None
+        self._write_task: Optional[asyncio.Task] = None
+
+    @classmethod
+    async def create(
+        cls,
+        daemon_connector: DaemonConnector,
+        listen_maddr: Multiaddr = Multiaddr(DEFAULT_LISTEN_MADDR),
+        use_persistent_conn: bool = True,
+    ) -> "ControlClient":
+        control = cls(daemon_connector, listen_maddr, _initialized_with_create=True)
+
+        if use_persistent_conn:
+            await control._ensure_persistent_conn()
+
+        return control
+
+    def close(self) -> None:
+        if self._read_task is not None:
+            self._read_task.cancel()
+        if self._write_task is not None:
+            self._write_task.cancel()
+
+    def __del__(self):
+        self.close()
+
     async def _handler(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
         pb_stream_info = p2pd_pb.StreamInfo()  # type: ignore
         await read_pbmsg_safe(reader, pb_stream_info)
@@ -93,6 +152,121 @@ class ControlClient:
         async with server:
             yield self
 
+    async def _read_from_persistent_conn(self, reader: asyncio.StreamReader):
+        while True:
+            resp = p2pd_pb.PersistentConnectionResponse()
+            try:
+                await read_pbmsg_safe(reader, resp)
+            except asyncio.IncompleteReadError:
+                break
+
+            call_id = UUID(bytes=resp.callId)
+
+            if resp.HasField("callUnaryResponse"):
+                if call_id in self._pending_calls and resp.callUnaryResponse.HasField("response"):
+                    self._pending_calls[call_id].set_result(resp.callUnaryResponse.response)
+                elif call_id in self._pending_calls and resp.callUnaryResponse.HasField("error"):
+                    remote_exc = P2PHandlerError(resp.callUnaryResponse.error.decode(errors="ignore"))
+                    self._pending_calls[call_id].set_exception(remote_exc)
+                else:
+                    logger.debug(f"Received unexpected unary call: {resp}")
+
+            elif resp.HasField("requestHandling"):
+                handler_task = asyncio.create_task(self._handle_persistent_request(call_id, resp.requestHandling))
+                self._handler_tasks[call_id] = handler_task
+
+            elif call_id in self._handler_tasks and resp.HasField("cancel"):
+                self._handler_tasks[call_id].cancel()
+
+            elif call_id in self._pending_calls and resp.HasField("daemonError"):
+                daemon_exc = P2PDaemonError(resp.daemonError.message)
+                self._pending_calls[call_id].set_exception(daemon_exc)
+
+            elif call_id in self._pending_calls:
+                self._pending_calls[call_id].set_result(None)
+
+            else:
+                logger.debug(f"Received unexpected response from daemon: {resp}")
+
+    async def _write_to_persistent_conn(self, writer: asyncio.StreamWriter):
+        with closing(writer):
+            while True:
+                msg = await self._pending_messages.get()
+                await write_pbmsg(writer, msg)
+
+    async def _handle_persistent_request(self, call_id: UUID, request: p2pd_pb.CallUnaryRequest):
+        if request.proto not in self.unary_handlers:
+            logger.warning(f"Protocol {request.proto} not supported")
+            return
+
+        try:
+            remote_id = PeerID(request.peer)
+            response_payload: bytes = await self.unary_handlers[request.proto](request.data, remote_id)
+            response = p2pd_pb.CallUnaryResponse(response=response_payload)
+
+        except Exception as e:
+            response = p2pd_pb.CallUnaryResponse(error=repr(e).encode())
+
+        await self._pending_messages.put(
+            p2pd_pb.PersistentConnectionRequest(
+                callId=call_id.bytes,
+                unaryResponse=response,
+            )
+        )
+        self._handler_tasks.pop(call_id)
+
+    async def _cancel_unary_call(self, call_id: UUID):
+        await self._pending_messages.put(
+            p2pd_pb.PersistentConnectionRequest(
+                callId=call_id.bytes,
+                cancel=p2pd_pb.Cancel(),
+            ),
+        )
+
+    async def _ensure_persistent_conn(self):
+        reader, writer = await self.daemon_connector.open_persistent_connection()
+
+        self._read_task = asyncio.create_task(self._read_from_persistent_conn(reader))
+        self._write_task = asyncio.create_task(self._write_to_persistent_conn(writer))
+
+    async def add_unary_handler(self, proto: str, handler: TUnaryHandler):
+        call_id = uuid4()
+
+        add_unary_handler_req = p2pd_pb.AddUnaryHandlerRequest(proto=proto)
+        req = p2pd_pb.PersistentConnectionRequest(callId=call_id.bytes, addUnaryHandler=add_unary_handler_req)
+
+        if self.unary_handlers.get(proto):
+            raise P2PDaemonError(f"Handler for protocol {proto} already registered")
+        self.unary_handlers[proto] = handler
+
+        self._pending_calls[call_id] = asyncio.Future()
+        await self._pending_messages.put(req)
+        await self._pending_calls[call_id]
+
+    async def call_unary_handler(self, peer_id: PeerID, proto: str, data: bytes) -> bytes:
+        call_id = uuid4()
+        call_unary_req = p2pd_pb.CallUnaryRequest(
+            peer=peer_id.to_bytes(),
+            proto=proto,
+            data=data,
+        )
+        req = p2pd_pb.PersistentConnectionRequest(
+            callId=call_id.bytes,
+            callUnary=call_unary_req,
+        )
+
+        try:
+            self._pending_calls[call_id] = asyncio.Future()
+            await self._pending_messages.put(req)
+            return await self._pending_calls[call_id]
+
+        except asyncio.CancelledError:
+            await self._cancel_unary_call(call_id)
+            raise
+
+        finally:
+            self._pending_calls.pop(call_id, None)
+
     async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
         reader, writer = await self.daemon_connector.open_connection()
         req = p2pd_pb.Request(type=p2pd_pb.Request.IDENTIFY)
@@ -179,3 +353,15 @@ class ControlClient:
 
         # if success, add the handler to the dict
         self.handlers[proto] = handler_cb
+
+
+class P2PHandlerError(Exception):
+    """
+    Raised if remote handled a request with an exception
+    """
+
+
+class P2PDaemonError(Exception):
+    """
+    Raised if daemon failed to handle request
+    """

+ 3 - 0
hivemind/p2p/p2p_daemon_bindings/datastructures.py

@@ -131,6 +131,9 @@ class PeerInfo:
     def __str__(self):
         return f"{self.peer_id.pretty()} {','.join(str(a) for a in self.addrs)}"
 
+    def __repr__(self):
+        return f"PeerInfo(peer_id={repr(self.peer_id)}, addrs={repr(self.addrs)})"
+
 
 class InvalidAddrError(ValueError):
     pass

+ 24 - 3
hivemind/p2p/p2p_daemon_bindings/p2pclient.py

@@ -10,16 +10,31 @@ from typing import AsyncIterator, Iterable, Sequence, Tuple
 
 from multiaddr import Multiaddr
 
-from hivemind.p2p.p2p_daemon_bindings.control import ControlClient, DaemonConnector, StreamHandler
+from hivemind.p2p.p2p_daemon_bindings.control import ControlClient, DaemonConnector, StreamHandler, TUnaryHandler
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
 
 
 class Client:
     control: ControlClient
 
-    def __init__(self, control_maddr: Multiaddr = None, listen_maddr: Multiaddr = None) -> None:
+    def __init__(self, *, _initialized_with_create=False) -> None:
+        assert _initialized_with_create, "Please use Client.create coroutine to spawn new client instances"
+        self.control = None
+
+    @classmethod
+    async def create(cls, control_maddr: Multiaddr = None, listen_maddr: Multiaddr = None) -> "Client":
+        client = cls(_initialized_with_create=True)
+
         daemon_connector = DaemonConnector(control_maddr=control_maddr)
-        self.control = ControlClient(daemon_connector=daemon_connector, listen_maddr=listen_maddr)
+        client.control = await ControlClient.create(daemon_connector=daemon_connector, listen_maddr=listen_maddr)
+
+        return client
+
+    def close(self) -> None:
+        self.control.close()
+
+    def __del__(self):
+        self.close()
 
     @asynccontextmanager
     async def listen(self) -> AsyncIterator["Client"]:
@@ -30,6 +45,12 @@ class Client:
         async with self.control.listen():
             yield self
 
+    async def add_unary_handler(self, proto: str, handler: TUnaryHandler):
+        await self.control.add_unary_handler(proto, handler)
+
+    async def call_unary_handler(self, peer_id: PeerID, proto: str, data: bytes) -> bytes:
+        return await self.control.call_unary_handler(peer_id, proto, data)
+
     async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
         """
         Get current node peer id and list of addresses

+ 13 - 7
hivemind/p2p/servicer.py

@@ -125,13 +125,19 @@ class ServicerBase:
         self._collect_rpc_handlers()
 
         servicer = self if wrapper is None else wrapper
-        for handler in self._rpc_handlers:
-            await p2p.add_protobuf_handler(
-                self._get_handle_name(namespace, handler.method_name),
-                getattr(servicer, handler.method_name),
-                handler.request_type,
-                stream_input=handler.stream_input,
-            )
+
+        await asyncio.gather(
+            *[
+                p2p.add_protobuf_handler(
+                    self._get_handle_name(namespace, handler.method_name),
+                    getattr(servicer, handler.method_name),
+                    handler.request_type,
+                    stream_input=handler.stream_input,
+                    stream_output=handler.stream_output,
+                )
+                for handler in self._rpc_handlers
+            ]
+        )
 
     @classmethod
     def get_stub(cls, p2p: P2P, peer: PeerID, *, namespace: Optional[str] = None) -> StubBase:

+ 59 - 10
hivemind/proto/p2pd.proto

@@ -8,15 +8,17 @@ package p2pclient.p2pd.pb;
 
 message Request {
   enum Type {
-    IDENTIFY       = 0;
-    CONNECT        = 1;
-    STREAM_OPEN    = 2;
-    STREAM_HANDLER = 3;
-    DHT            = 4;
-    LIST_PEERS     = 5;
-    CONNMANAGER    = 6;
-    DISCONNECT     = 7;
-    PUBSUB         = 8;
+    IDENTIFY                 = 0;
+    CONNECT                  = 1;
+    STREAM_OPEN              = 2;
+    STREAM_HANDLER           = 3;
+    DHT                      = 4;
+    LIST_PEERS               = 5;
+    CONNMANAGER              = 6;
+    DISCONNECT               = 7;      
+    PUBSUB                   = 8;
+
+    PERSISTENT_CONN_UPGRADE  = 9;
   }
 
   required Type type = 1;
@@ -45,6 +47,29 @@ message Response {
   optional PSResponse pubsub = 7;
 }
 
+message PersistentConnectionRequest {
+  required bytes callId = 1;
+
+  oneof message {
+    AddUnaryHandlerRequest addUnaryHandler = 2;
+    CallUnaryRequest  callUnary = 3;
+    CallUnaryResponse unaryResponse = 4;
+    Cancel cancel = 5;
+  }
+}
+
+message PersistentConnectionResponse {
+  required bytes callId = 1;
+
+  oneof message {
+    CallUnaryResponse callUnaryResponse = 2;
+    CallUnaryRequest requestHandling = 3;
+    DaemonError daemonError = 4;
+    Cancel cancel = 5;
+  }
+}
+
+
 message IdentifyResponse {
   required bytes id = 1;
   repeated bytes addrs = 2;
@@ -148,7 +173,7 @@ message PSRequest {
 }
 
 message PSMessage {
-  optional bytes from_id = 1;
+  optional bytes from = 1;
   optional bytes data = 2;
   optional bytes seqno = 3;
   repeated string topicIDs = 4;
@@ -161,6 +186,30 @@ message PSResponse {
   repeated bytes peerIDs = 2;
 }
 
+message CallUnaryRequest {
+  required bytes peer = 1;
+  required string proto = 2;
+  required bytes data = 3;
+}
+
+message CallUnaryResponse {
+  oneof result {
+    bytes response = 1;
+    bytes error = 2;
+  }
+}
+
+message AddUnaryHandlerRequest {
+  required string proto = 1;
+}
+
+message DaemonError {
+  optional string message = 1;
+}
+
+message Cancel {
+}
+
 message RPCError {
   optional string message = 1;
 }

+ 16 - 1
hivemind/utils/asyncio.py

@@ -1,4 +1,5 @@
 import asyncio
+import concurrent.futures
 from concurrent.futures import ThreadPoolExecutor
 from typing import AsyncIterable, AsyncIterator, Awaitable, Callable, Optional, Tuple, TypeVar, Union
 
@@ -81,12 +82,26 @@ async def await_cancelled(awaitable: Awaitable) -> bool:
     try:
         await awaitable
         return False
-    except asyncio.CancelledError:
+    except (asyncio.CancelledError, concurrent.futures.CancelledError):
+        # In Python 3.7, awaiting a cancelled asyncio.Future raises concurrent.futures.CancelledError
+        # instead of asyncio.CancelledError
         return True
     except BaseException:
+        logger.exception(f"Exception in {awaitable}:")
         return False
 
 
+async def cancel_and_wait(awaitable: Awaitable) -> bool:
+    """
+    Cancels ``awaitable`` and waits for its cancellation.
+    In case of ``asyncio.Task``, helps to avoid ``Task was destroyed but it is pending!`` errors.
+    In case of ``asyncio.Future``, equal to ``future.cancel()``.
+    """
+
+    awaitable.cancel()
+    return await await_cancelled(awaitable)
+
+
 async def amap_in_executor(
     func: Callable[..., T],
     *iterables: AsyncIterable,

+ 1 - 1
hivemind/utils/mpfuture.py

@@ -53,7 +53,7 @@ class SharedBytes:
         """Create another shared byte value, represented as a scalar uint8 tensor"""
         with cls._lock:
             if cls._pid != os.getpid() or cls._buffer is None or cls._index >= len(cls._buffer):
-                buffer_size = os.environ.get("HIVEMIND_SHM_BUFFER_SIZE", 4096)
+                buffer_size = int(os.environ.get("HIVEMIND_SHM_BUFFER_SIZE", 16))
                 cls._pid = os.getpid()
                 cls._buffer = torch.empty([buffer_size], dtype=torch.uint8).share_memory_()
                 cls._index = 0

+ 2 - 2
setup.py

@@ -14,8 +14,8 @@ from setuptools import find_packages, setup
 from setuptools.command.build_py import build_py
 from setuptools.command.develop import develop
 
-P2PD_VERSION = "v0.3.1"
-P2PD_CHECKSUM = "15292b880c6b31f5b3c36084b3acc17f"
+P2PD_VERSION = "v0.3.4"
+P2PD_CHECKSUM = "194dca06116fdd36bc4b681d18f3b9cb"
 LIBP2P_TAR_URL = f"https://github.com/learning-at-home/go-libp2p-daemon/archive/refs/tags/{P2PD_VERSION}.tar.gz"
 
 here = os.path.abspath(os.path.dirname(__file__))

+ 5 - 2
tests/test_averaging.py

@@ -10,6 +10,7 @@ import hivemind.averaging.averager
 from hivemind.averaging.allreduce import AveragingMode
 from hivemind.averaging.key_manager import GroupKeyManager
 from hivemind.averaging.load_balancing import load_balance_peers
+from hivemind.averaging.partition import AllreduceException
 from hivemind.p2p import PeerID
 from hivemind.proto.runtime_pb2 import CompressionType
 
@@ -363,9 +364,11 @@ def test_too_few_peers():
         )
         for i, dht in enumerate(dht_instances)
     ]
-    step_futures = [averager.step(wait=False) for averager in averagers]
+    step_futures = [averager.step(wait=False, timeout=2) for averager in averagers]
+
     for future in step_futures:
-        assert len(future.result()) == 2
+        with pytest.raises(AllreduceException):
+            future.result()
 
     for process in averagers + dht_instances:
         process.shutdown()

+ 2 - 2
tests/test_dht.py

@@ -14,7 +14,7 @@ from test_utils.dht_swarms import launch_dht_instances
 
 @pytest.mark.asyncio
 async def test_startup_error():
-    with pytest.raises(hivemind.p2p.P2PDaemonError, match=r"Failed to connect to bootstrap peers"):
+    with pytest.raises(hivemind.p2p.P2PDaemonError, match=r"(?i)Failed to connect to bootstrap peers"):
         hivemind.DHT(
             initial_peers=[f"/ip4/127.0.0.1/tcp/{get_free_port()}/p2p/QmdaK4LUeQaKhqSFPRu9N7MvXUEWDxWwtCvPrS444tCgd1"],
             start=True,
@@ -118,7 +118,7 @@ async def test_dht_get_visible_maddrs():
 
     dummy_endpoint = Multiaddr("/ip4/123.45.67.89/tcp/31337")
     p2p = await hivemind.p2p.P2P.create(announce_maddrs=[dummy_endpoint])
-    dht = hivemind.DHT(start=True, p2p=await p2p.replicate(p2p.daemon_listen_maddr))
+    dht = hivemind.DHT(start=True, p2p=p2p)
 
     assert dht.get_visible_maddrs() == [dummy_endpoint.encapsulate(f"/p2p/{p2p.peer_id}")]
     dht.shutdown()

+ 29 - 4
tests/test_p2p_daemon.py

@@ -10,7 +10,7 @@ import pytest
 from multiaddr import Multiaddr
 
 from hivemind.p2p import P2P, P2PDaemonError, P2PHandlerError
-from hivemind.proto import dht_pb2
+from hivemind.proto import dht_pb2, test_pb2
 from hivemind.utils.networking import get_free_port
 from hivemind.utils.serializer import MSGPackSerializer
 
@@ -36,7 +36,7 @@ async def test_daemon_killed_on_del():
 
 @pytest.mark.asyncio
 async def test_startup_error_message():
-    with pytest.raises(P2PDaemonError, match=r"Failed to connect to bootstrap peers"):
+    with pytest.raises(P2PDaemonError, match=r"(?i)Failed to connect to bootstrap peers"):
         await P2P.create(
             initial_peers=[f"/ip4/127.0.0.1/tcp/{get_free_port()}/p2p/QmdaK4LUeQaKhqSFPRu9N7MvXUEWDxWwtCvPrS444tCgd1"]
         )
@@ -63,9 +63,9 @@ async def test_transports(host_maddrs: List[Multiaddr]):
     await client.wait_for_at_least_n_peers(1)
 
     peers = await client.list_peers()
-    assert len(peers) == 1
+    assert len({p.peer_id for p in peers}) == 1
     peers = await server.list_peers()
-    assert len(peers) == 1
+    assert len({p.peer_id for p in peers}) == 1
 
 
 @pytest.mark.asyncio
@@ -83,6 +83,31 @@ async def test_daemon_replica_does_not_affect_primary():
     assert not is_process_running(child_pid)
 
 
+@pytest.mark.asyncio
+async def test_unary_handler_edge_cases():
+    p2p = await P2P.create()
+    p2p_replica = await P2P.replicate(p2p.daemon_listen_maddr)
+
+    async def square_handler(data: test_pb2.TestRequest, context):
+        return test_pb2.TestResponse(number=data.number ** 2)
+
+    await p2p.add_protobuf_handler("square", square_handler, test_pb2.TestRequest)
+
+    # try adding a duplicate handler
+    with pytest.raises(P2PDaemonError):
+        await p2p.add_protobuf_handler("square", square_handler, test_pb2.TestRequest)
+
+    # try adding a duplicate handler from replicated p2p
+    with pytest.raises(P2PDaemonError):
+        await p2p_replica.add_protobuf_handler("square", square_handler, test_pb2.TestRequest)
+
+    # try dialing yourself
+    with pytest.raises(P2PDaemonError):
+        await p2p_replica.call_protobuf_handler(
+            p2p.peer_id, "square", test_pb2.TestRequest(number=41), test_pb2.TestResponse
+        )
+
+
 @pytest.mark.parametrize(
     "should_cancel,replicate",
     [

+ 13 - 6
tests/test_p2p_daemon_bindings.py

@@ -199,24 +199,31 @@ def test_parse_conn_protocol_invalid(maddr_str):
 
 
 @pytest.mark.parametrize("control_maddr_str", ("/unix/123", "/ip4/127.0.0.1/tcp/6666"))
-def test_client_ctor_control_maddr(control_maddr_str):
+@pytest.mark.asyncio
+async def test_client_create_control_maddr(control_maddr_str):
     c = DaemonConnector(Multiaddr(control_maddr_str))
     assert c.control_maddr == Multiaddr(control_maddr_str)
 
 
-def test_client_ctor_default_control_maddr():
+def test_client_create_default_control_maddr():
     c = DaemonConnector()
     assert c.control_maddr == Multiaddr(DaemonConnector.DEFAULT_CONTROL_MADDR)
 
 
 @pytest.mark.parametrize("listen_maddr_str", ("/unix/123", "/ip4/127.0.0.1/tcp/6666"))
-def test_control_client_ctor_listen_maddr(listen_maddr_str):
-    c = ControlClient(daemon_connector=DaemonConnector(), listen_maddr=Multiaddr(listen_maddr_str))
+@pytest.mark.asyncio
+async def test_control_client_create_listen_maddr(listen_maddr_str):
+    c = await ControlClient.create(
+        daemon_connector=DaemonConnector(),
+        listen_maddr=Multiaddr(listen_maddr_str),
+        use_persistent_conn=False,
+    )
     assert c.listen_maddr == Multiaddr(listen_maddr_str)
 
 
-def test_control_client_ctor_default_listen_maddr():
-    c = ControlClient(daemon_connector=DaemonConnector())
+@pytest.mark.asyncio
+async def test_control_client_create_default_listen_maddr():
+    c = await ControlClient.create(daemon_connector=DaemonConnector(), use_persistent_conn=False)
     assert c.listen_maddr == Multiaddr(ControlClient.DEFAULT_LISTEN_MADDR)
 
 

+ 43 - 1
tests/test_util_modules.py

@@ -13,7 +13,17 @@ from hivemind.proto.dht_pb2_grpc import DHTStub
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
 from hivemind.utils import DHTExpiration, HeapEntry, MSGPackSerializer, ValueWithExpiration
-from hivemind.utils.asyncio import achain, aenumerate, afirst, aiter, amap_in_executor, anext, asingle, azip
+from hivemind.utils.asyncio import (
+    achain,
+    aenumerate,
+    afirst,
+    aiter,
+    amap_in_executor,
+    anext,
+    asingle,
+    azip,
+    cancel_and_wait,
+)
 from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.mpfuture import InvalidStateError
 
@@ -509,3 +519,35 @@ async def test_asyncio_utils():
     assert await afirst(aiter()) is None
     assert await afirst(aiter(), -1) == -1
     assert await afirst(aiter(1, 2, 3)) == 1
+
+
+@pytest.mark.asyncio
+async def test_cancel_and_wait():
+    finished_gracefully = False
+
+    async def coro_with_finalizer():
+        nonlocal finished_gracefully
+
+        try:
+            await asyncio.Event().wait()
+        except asyncio.CancelledError:
+            await asyncio.sleep(0.05)
+            finished_gracefully = True
+            raise
+
+    task = asyncio.create_task(coro_with_finalizer())
+    await asyncio.sleep(0.05)
+    assert await cancel_and_wait(task)
+    assert finished_gracefully
+
+    async def coro_with_result():
+        return 777
+
+    async def coro_with_error():
+        raise ValueError("error")
+
+    task_with_result = asyncio.create_task(coro_with_result())
+    task_with_error = asyncio.create_task(coro_with_error())
+    await asyncio.sleep(0.05)
+    assert not await cancel_and_wait(task_with_result)
+    assert not await cancel_and_wait(task_with_error)

+ 1 - 1
tests/test_utils/p2p_daemon.py

@@ -157,7 +157,7 @@ async def _make_p2pd_pair(
     )
     # wait for daemon ready
     await p2pd.wait_until_ready()
-    client = Client(control_maddr=control_maddr, listen_maddr=listen_maddr)
+    client = await Client.create(control_maddr=control_maddr, listen_maddr=listen_maddr)
     try:
         async with client.listen():
             yield DaemonTuple(daemon=p2pd, client=client)