Browse Source

Merge branch 'master' of github.com:learning-at-home/hivemind into server-p2p

Denis Mazur 4 years ago
parent
commit
4739d33cfc

+ 0 - 2
.github/workflows/run-tests.yml

@@ -34,7 +34,6 @@ jobs:
         run: |
           cd tests
           pytest --durations=0 --durations-min=1.0 -v
-
   build_and_test_p2pd:
     runs-on: ubuntu-latest
     timeout-minutes: 10
@@ -61,7 +60,6 @@ jobs:
         run: |
           cd tests
           pytest -k "p2p" -v
-
   codecov_in_develop_mode:
 
     runs-on: ubuntu-latest

+ 1 - 0
.gitignore

@@ -54,6 +54,7 @@ coverage.xml
 .project
 .pydevproject
 .idea
+.vscode
 .ipynb_checkpoints
 
 # Rope

+ 10 - 4
examples/albert/arguments.py

@@ -13,7 +13,7 @@ class BaseTrainingArguments:
         default_factory=list,
         metadata={
             "help": "Multiaddrs of the peers that will welcome you into the existing collaboration. "
-            "Example: /ip4/203.0.113.1/tcp/31337/p2p/XXXX /ip4/203.0.113.2/udp/7777/quic/p2p/YYYY"
+            "Example: /ip4/203.0.113.1/tcp/31337/p2p/XXXX /ip4/203.0.113.2/tcp/7777/p2p/YYYY"
         },
     )
     use_ipfs: bool = field(
@@ -24,17 +24,23 @@ class BaseTrainingArguments:
         },
     )
     host_maddrs: List[str] = field(
-        default_factory=lambda: ["/ip4/0.0.0.0/tcp/0", "/ip4/0.0.0.0/udp/0/quic"],
+        default_factory=lambda: ["/ip4/0.0.0.0/tcp/0"],
         metadata={
             "help": "Multiaddrs to listen for external connections from other p2p instances. "
-            "Defaults to all IPv4 interfaces with TCP and QUIC (over UDP) protocols: "
-            "/ip4/0.0.0.0/tcp/0 /ip4/0.0.0.0/udp/0/quic"
+            "Defaults to all IPv4 interfaces and the TCP protocol: /ip4/0.0.0.0/tcp/0"
         },
     )
     announce_maddrs: List[str] = field(
         default_factory=list,
         metadata={"help": "Visible multiaddrs the host announces for external connections from other p2p instances"},
     )
+    identity_path: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Path to a pre-generated private key file. If defined, makes the peer ID deterministic. "
+            "May be generated using ``./p2p-keygen`` from ``go-libp2p-daemon``."
+        },
+    )
 
 
 @dataclass

+ 1 - 0
examples/albert/run_trainer.py

@@ -248,6 +248,7 @@ def main():
         use_ipfs=collaboration_args.use_ipfs,
         host_maddrs=collaboration_args.host_maddrs,
         announce_maddrs=collaboration_args.announce_maddrs,
+        identity_path=collaboration_args.identity_path,
     )
     utils.log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=collaboration_args.use_ipfs)
 

+ 2 - 1
examples/albert/run_training_monitor.py

@@ -156,7 +156,7 @@ if __name__ == "__main__":
         address = request.text
         logger.info(f"Received public IP address of this machine: {address}")
         version = ip_address(address).version
-        monitor_args.announce_maddrs += [f"/ip{version}/{address}/tcp/0", f"/ip{version}/{address}/udp/0/quic"]
+        monitor_args.announce_maddrs += [f"/ip{version}/{address}/tcp/0"]
 
     experiment_prefix = monitor_args.experiment_prefix
     validators, local_public_key = utils.make_validators(experiment_prefix)
@@ -168,6 +168,7 @@ if __name__ == "__main__":
         use_ipfs=monitor_args.use_ipfs,
         host_maddrs=monitor_args.host_maddrs,
         announce_maddrs=monitor_args.announce_maddrs,
+        identity_path=monitor_args.identity_path,
     )
     utils.log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=monitor_args.use_ipfs)
 

+ 1 - 1
hivemind/__init__.py

@@ -19,4 +19,4 @@ from hivemind.optim import (
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, PeerInfo
 from hivemind.utils import *
 
-__version__ = "1.0.0.dev0"
+__version__ = "1.0.0dev0"

+ 3 - 3
hivemind/averaging/allreduce.py

@@ -8,7 +8,7 @@ from hivemind.averaging.partition import AllreduceException, TensorPartContainer
 from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase, StubBase
 from hivemind.proto import averaging_pb2
 from hivemind.utils import get_logger
-from hivemind.utils.asyncio import achain, aenumerate, aiter, amap_in_executor, anext, asingle
+from hivemind.utils.asyncio import achain, aenumerate, afirst, aiter, amap_in_executor, anext
 from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 
 # flavour types
@@ -231,8 +231,8 @@ class AllReduceRunner(ServicerBase):
 
     async def _send_error_to_peer(self, peer_id: PeerID, code: averaging_pb2.MessageCode):
         error = averaging_pb2.AveragingData(group_id=self.group_id, code=code)
-        # In case of reporting the error, we expect the response stream to contain exactly one item
-        await asingle(self._get_peer_stub(peer_id).rpc_aggregate_part(aiter(error)))
+        # Coroutines are lazy, so we take the first item to start the couroutine's execution
+        await afirst(self._get_peer_stub(peer_id).rpc_aggregate_part(aiter(error)))
 
     def finalize(self, *, cancel: bool = False, exception: Optional[BaseException] = None):
         """finish or terminate AllReduceRunner, propagate any errors / cancellations to peers."""

+ 9 - 4
hivemind/averaging/averager.py

@@ -96,7 +96,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         prefix: str,
         target_group_size: int,
         min_group_size: int = 2,
-        initial_group_bits: Optional[str] = None,
+        initial_group_bits: str = "",
         averaging_expiration: float = 15,
         request_timeout: float = 3,
         averaging_alpha: float = 1.0,
@@ -117,7 +117,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         ), "bandwidth must be a non-negative float32"
         if not is_power_of_two(target_group_size):
             logger.warning("It is recommended to set target_group_size to a power of 2.")
-        assert initial_group_bits is None or all(bit in "01" for bit in initial_group_bits)
+        assert all(bit in "01" for bit in initial_group_bits)
         assert not client_mode or not auxiliary, "auxiliary peers must accept incoming connections"
 
         super().__init__()
@@ -241,7 +241,12 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                 self._ready.set_result(None)
 
                 while True:
-                    method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
+                    try:
+                        method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
+                    except (OSError, ConnectionError) as e:
+                        logger.exception(e)
+                        await asyncio.sleep(self._matchmaking.request_timeout)
+                        continue
                     task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
                     if method == "_shutdown":
                         await task
@@ -593,7 +598,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                         future.set_result((metadata, tensors))
                         self.last_updated = get_dht_time()
                         return
-                    except BaseException as e:
+                    except Exception as e:
                         logger.exception(f"Failed to download state from {peer} - {repr(e)}")
 
         finally:

+ 12 - 93
hivemind/averaging/key_manager.py

@@ -1,4 +1,3 @@
-import asyncio
 import random
 import re
 from typing import List, Optional, Tuple
@@ -25,31 +24,17 @@ class GroupKeyManager:
     Utility class that declares and fetches averaging-related keys using a DHT
     """
 
-    RESERVED_KEY_FOR_NBITS = "::NBITS"
-
     def __init__(
         self,
         dht: DHT,
         prefix: str,
-        initial_group_bits: Optional[str],
+        initial_group_bits: str,
         target_group_size: int,
-        insufficient_size: Optional[int] = None,
-        excessive_size: Optional[int] = None,
-        nbits_expiration: float = 60,
-        nbits_rewrite_grace_period: float = 15,
     ):
-        assert initial_group_bits is None or all(bit in "01" for bit in initial_group_bits)
-        if initial_group_bits is None:
-            search_result = dht.get(f"{prefix}.0b", latest=True)
-            initial_group_nbits = self.get_suggested_nbits(search_result) or 0
-            initial_group_bits = "".join(random.choice("01") for _ in range(initial_group_nbits))
+        assert all(bit in "01" for bit in initial_group_bits)
         self.dht, self.prefix, self.group_bits = dht, prefix, initial_group_bits
-        self.peer_id = dht.peer_id
         self.target_group_size = target_group_size
-        self.insufficient_size = insufficient_size or max(1, target_group_size // 2)
-        self.excessive_size = excessive_size or target_group_size * 3
-        self.nbits_expiration, self.nbits_grace_period = nbits_expiration, nbits_rewrite_grace_period
-        self.suggested_nbits: Optional[int] = None
+        self.peer_id = dht.peer_id
 
     @property
     def current_key(self) -> GroupKey:
@@ -93,51 +78,16 @@ class GroupKeyManager:
         if result is None or not isinstance(result.value, dict):
             logger.debug(f"Allreduce group not found: {group_key}, creating new group.")
             return []
-        averagers = [
-            (PeerID(key), looking_for_group.expiration_time)
-            for key, looking_for_group in result.value.items()
-            if key != self.RESERVED_KEY_FOR_NBITS and (not only_active or looking_for_group.value)
-        ]
-        num_active_averagers = sum(
-            1
-            for key, looking_for_group in result.value.items()
-            if key != self.RESERVED_KEY_FOR_NBITS and looking_for_group.value
-        )
-
-        suggested_nbits = self.get_suggested_nbits(result)
-        if (
-            suggested_nbits is not None
-            and suggested_nbits != len(self.group_bits)
-            and suggested_nbits != self.suggested_nbits
-        ):
-            self.suggested_nbits = suggested_nbits
-            logger.warning(f"{self.peer_id} - another averager suggested {self.suggested_nbits}-bit keys")
-        elif num_active_averagers >= self.excessive_size:
-            self.suggested_nbits = max(suggested_nbits or 0, len(self.group_bits) + 1)
-            logger.warning(f"{self.peer_id} - too many peers in bucket, switching to {self.suggested_nbits}-bit keys")
+        averagers = []
+        for key, looking_for_group in result.value.items():
+            try:
+                if only_active and not looking_for_group.value:
+                    continue
+                averagers.append((PeerID(key), looking_for_group.expiration_time))
+            except Exception as e:
+                logger.warning(f"Could not parse group key {key} ({looking_for_group}, exc={e})")
         return averagers
 
-    async def declare_nbits(self, group_key: GroupKey, nbits: int, expiration_time: DHTExpiration) -> bool:
-        """notify other peers that they can run averaging at this depth"""
-        return await self.dht.store(
-            key=group_key,
-            subkey=self.RESERVED_KEY_FOR_NBITS,
-            value=nbits,
-            expiration_time=expiration_time,
-            return_future=True,
-        )
-
-    @classmethod
-    def get_suggested_nbits(cls, search_result: Optional[ValueWithExpiration]) -> Optional[int]:
-        if (
-            isinstance(search_result, ValueWithExpiration)
-            and cls.RESERVED_KEY_FOR_NBITS in search_result.value
-            and isinstance(search_result.value[cls.RESERVED_KEY_FOR_NBITS].value, int)
-        ):
-            return search_result.value[cls.RESERVED_KEY_FOR_NBITS].value
-        else:
-            return None
-
     async def update_key_on_group_assembled(self, group_info: GroupInfo, is_leader: bool = True):
         """this function is triggered every time an averager finds an allreduce group"""
         rng = random.Random(group_info.group_id)
@@ -148,37 +98,6 @@ class GroupKeyManager:
         self.group_bits = (self.group_bits + new_bits)[-len(self.group_bits) :] if self.group_bits else ""
         logger.debug(f"{self.peer_id} - updated group key to {self.group_bits}")
 
-        if is_leader and self.insufficient_size < group_info.group_size < self.excessive_size:
-            asyncio.create_task(self.notify_stragglers())
-        if self.suggested_nbits is not None and self.suggested_nbits != len(self.group_bits):
-            num_extra_bits = max(0, self.suggested_nbits - len(self.group_bits))
-            self.group_bits = "".join((random.choice("01") for _ in range(num_extra_bits))) + self.group_bits
-            self.group_bits = self.group_bits[-self.suggested_nbits :]
-        self.suggested_nbits = None
-
     async def update_key_on_not_enough_peers(self):
         """this function is triggered whenever averager fails to assemble group within timeout"""
-        new_nbits = self.suggested_nbits if self.suggested_nbits is not None else len(self.group_bits) - 1
-        prev_nbits, self.group_bits = self.group_bits, self.group_bits[-new_nbits:] if new_nbits else ""
-        if self.group_bits != prev_nbits:
-            logger.warning(f"{self.peer_id} - switching to {len(self.group_bits)}-bit keys")
-        self.suggested_nbits = None
-
-    async def notify_stragglers(self):
-        """Find averagers that have fewer nbits and redirect them to your current nbits"""
-        for nbits in reversed(range(1, len(self.group_bits) - 1)):
-            preceding_key = f"{self.prefix}.0b{self.group_bits[-nbits:] if nbits else ''}"
-            preceding_data, _ = await self.dht.get(preceding_key, latest=False, return_future=True) or ({}, None)
-
-            if len(preceding_data) > 0 and self.RESERVED_KEY_FOR_NBITS not in preceding_data:
-                await self.declare_nbits(preceding_key, len(self.group_bits), get_dht_time() + self.nbits_expiration)
-                break
-
-        root_data, _ = await self.dht.get(f"{self.prefix}.0b", latest=False, return_future=True) or ({}, None)
-        if (
-            isinstance(root_data, dict)
-            and root_data.get(self.RESERVED_KEY_FOR_NBITS, (None, -float("inf")))[1]
-            > get_dht_time() + self.nbits_grace_period
-        ):
-            return
-        await self.declare_nbits(f"{self.prefix}.0b", len(self.group_bits), get_dht_time() + self.nbits_expiration)
+        pass  # to be implemented in subclasses

+ 31 - 43
hivemind/averaging/matchmaking.py

@@ -15,7 +15,7 @@ from hivemind.dht import DHT, DHTID, DHTExpiration
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
 from hivemind.proto import averaging_pb2
 from hivemind.utils import TimedStorage, get_dht_time, get_logger, timed_storage
-from hivemind.utils.asyncio import anext
+from hivemind.utils.asyncio import anext, cancel_and_wait
 
 logger = get_logger(__name__)
 
@@ -45,7 +45,7 @@ class Matchmaking:
         min_group_size: int,
         request_timeout: float,
         client_mode: bool,
-        initial_group_bits: Optional[str] = None,
+        initial_group_bits: str = "",
         averaging_expiration: float = 15,
     ):
         assert "." not in prefix, "group prefix must be a string without ."
@@ -127,10 +127,9 @@ class Matchmaking:
                 raise
 
             finally:
-                if not request_leaders_task.done():
-                    request_leaders_task.cancel()
-                if not self.assembled_group.done():
-                    self.assembled_group.cancel()
+                await cancel_and_wait(request_leaders_task)
+                self.assembled_group.cancel()
+
                 while len(self.current_followers) > 0:
                     await self.follower_was_discarded.wait()
                     self.follower_was_discarded.clear()
@@ -189,7 +188,7 @@ class Matchmaking:
                         gather=self.data_for_gather,
                         group_key=self.group_key_manager.current_key,
                     )
-                ).__aiter__()
+                )
                 message = await asyncio.wait_for(anext(stream), timeout=self.request_timeout)
 
                 if message.code == averaging_pb2.ACCEPTED:
@@ -229,7 +228,7 @@ class Matchmaking:
             logger.debug(f"{self} - potential leader {leader} did not respond within {self.request_timeout}")
             return None
         except (P2PHandlerError, StopAsyncIteration) as e:
-            logger.error(f"{self} - failed to request potential leader {leader}: {e}")
+            logger.exception(f"{self} - failed to request potential leader {leader}:")
             return None
 
         finally:
@@ -413,10 +412,9 @@ class PotentialLeaders:
             try:
                 yield self
             finally:
-                if not update_queue_task.done():
-                    update_queue_task.cancel()
-                if declare and not declare_averager_task.done():
-                    declare_averager_task.cancel()
+                await cancel_and_wait(update_queue_task)
+                if declare:
+                    await cancel_and_wait(declare_averager_task)
 
                 for field in (
                     self.past_attempts,
@@ -477,37 +475,31 @@ class PotentialLeaders:
         else:
             return min(get_dht_time() + self.averaging_expiration, self.search_end_time)
 
-    async def _update_queue_periodically(self, key_manager: GroupKeyManager):
-        try:
-            DISCREPANCY = timed_storage.MAX_DHT_TIME_DISCREPANCY_SECONDS
-            while get_dht_time() < self.search_end_time:
-                new_peers = await key_manager.get_averagers(key_manager.current_key, only_active=True)
-                self.max_assured_time = max(
-                    self.max_assured_time, get_dht_time() + self.averaging_expiration - DISCREPANCY
-                )
+    async def _update_queue_periodically(self, key_manager: GroupKeyManager) -> None:
+        DISCREPANCY = timed_storage.MAX_DHT_TIME_DISCREPANCY_SECONDS
+        while get_dht_time() < self.search_end_time:
+            new_peers = await key_manager.get_averagers(key_manager.current_key, only_active=True)
+            self.max_assured_time = max(
+                self.max_assured_time, get_dht_time() + self.averaging_expiration - DISCREPANCY
+            )
 
-                self.leader_queue.clear()
-                for peer, peer_expiration_time in new_peers:
-                    if peer == self.peer_id or (peer, peer_expiration_time) in self.past_attempts:
-                        continue
-                    self.leader_queue.store(peer, peer_expiration_time, peer_expiration_time)
-                    self.max_assured_time = max(self.max_assured_time, peer_expiration_time - DISCREPANCY)
+            self.leader_queue.clear()
+            for peer, peer_expiration_time in new_peers:
+                if peer == self.peer_id or (peer, peer_expiration_time) in self.past_attempts:
+                    continue
+                self.leader_queue.store(peer, peer_expiration_time, peer_expiration_time)
+                self.max_assured_time = max(self.max_assured_time, peer_expiration_time - DISCREPANCY)
 
-                self.update_finished.set()
+            self.update_finished.set()
 
-                await asyncio.wait(
-                    {self.running.wait(), self.update_triggered.wait()},
-                    return_when=asyncio.ALL_COMPLETED,
-                    timeout=self.search_end_time - get_dht_time() if isfinite(self.search_end_time) else None,
-                )
-                self.update_triggered.clear()
-        except (concurrent.futures.CancelledError, asyncio.CancelledError):
-            return  # note: this is a compatibility layer for python3.7
-        except Exception as e:
-            logger.error(f"{self.peer_id} - caught {type(e)}: {e}")
-            raise
+            await asyncio.wait(
+                {self.running.wait(), self.update_triggered.wait()},
+                return_when=asyncio.ALL_COMPLETED,
+                timeout=self.search_end_time - get_dht_time() if isfinite(self.search_end_time) else None,
+            )
+            self.update_triggered.clear()
 
-    async def _declare_averager_periodically(self, key_manager: GroupKeyManager):
+    async def _declare_averager_periodically(self, key_manager: GroupKeyManager) -> None:
         async with self.lock_declare:
             try:
                 while True:
@@ -521,10 +513,6 @@ class PotentialLeaders:
                     await asyncio.sleep(self.declared_expiration_time - get_dht_time())
                     if self.running.is_set() and len(self.leader_queue) == 0:
                         await key_manager.update_key_on_not_enough_peers()
-            except (concurrent.futures.CancelledError, asyncio.CancelledError):
-                pass  # note: this is a compatibility layer for python3.7
-            except Exception as e:  # note: we catch exceptions here because otherwise they are never printed
-                logger.error(f"{self.peer_id} - caught {type(e)}: {e}")
             finally:
                 if self.declared_group_key is not None:
                     prev_declared_key, prev_expiration_time = self.declared_group_key, self.declared_expiration_time

+ 19 - 12
hivemind/dht/__init__.py

@@ -27,7 +27,7 @@ from hivemind.dht.node import DEFAULT_NUM_WORKERS, DHTNode
 from hivemind.dht.routing import DHTID, DHTKey, DHTValue, Subkey
 from hivemind.dht.validation import CompositeValidator, RecordValidatorBase
 from hivemind.p2p import P2P, PeerID
-from hivemind.utils import DHTExpiration, MPFuture, ValueWithExpiration, await_cancelled, get_logger, switch_to_uvloop
+from hivemind.utils import DHTExpiration, MPFuture, ValueWithExpiration, get_logger, switch_to_uvloop
 
 logger = get_logger(__name__)
 
@@ -61,6 +61,7 @@ class DHT(mp.Process):
         initial_peers: Optional[Sequence[Union[Multiaddr, str]]] = None,
         *,
         start: bool,
+        p2p: Optional[P2P] = None,
         daemon: bool = True,
         num_workers: int = DEFAULT_NUM_WORKERS,
         record_validators: Iterable[RecordValidatorBase] = (),
@@ -94,6 +95,8 @@ class DHT(mp.Process):
         self._client_mode = None
         self._p2p_replica = None
 
+        self._daemon_listen_maddr = p2p.daemon_listen_maddr if p2p is not None else None
+
         if start:
             self.run_in_background(await_ready=await_ready)
 
@@ -105,10 +108,16 @@ class DHT(mp.Process):
 
             async def _run():
                 try:
+                    if self._daemon_listen_maddr is not None:
+                        replicated_p2p = await P2P.replicate(self._daemon_listen_maddr)
+                    else:
+                        replicated_p2p = None
+
                     self._node = await DHTNode.create(
                         initial_peers=self.initial_peers,
                         num_workers=self.num_workers,
                         record_validator=self._record_validator,
+                        p2p=replicated_p2p,
                         **self.kwargs,
                     )
                 except Exception as e:
@@ -119,7 +128,12 @@ class DHT(mp.Process):
                 self._ready.set_result(None)
 
                 while True:
-                    method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
+                    try:
+                        method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
+                    except (OSError, ConnectionError) as e:
+                        logger.exception(e)
+                        await asyncio.sleep(self._node.protocol.wait_timeout)
+                        continue
                     task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
                     if method == "_shutdown":
                         await task
@@ -247,18 +261,11 @@ class DHT(mp.Process):
     async def _run_coroutine(
         self, coro: Callable[[DHT, DHTNode], Awaitable[ReturnType]], future: MPFuture[ReturnType]
     ):
-        main_task = asyncio.create_task(coro(self, self._node))
-        cancel_task = asyncio.create_task(await_cancelled(future))
         try:
-            await asyncio.wait({main_task, cancel_task}, return_when=asyncio.FIRST_COMPLETED)
-            if future.cancelled():
-                main_task.cancel()
-            else:
-                future.set_result(await main_task)
+            future.set_result(await coro(self, self._node))
         except BaseException as e:
-            logger.exception(f"Caught an exception when running a coroutine: {e}")
-            if not future.done():
-                future.set_exception(e)
+            logger.exception("Caught an exception when running a coroutine:")
+            future.set_exception(e)
 
     def add_validators(self, record_validators: Iterable[RecordValidatorBase]) -> None:
         if not self._ready.done():

+ 9 - 4
hivemind/dht/node.py

@@ -121,7 +121,7 @@ class DHTNode:
         client_mode: bool = False,
         record_validator: Optional[RecordValidatorBase] = None,
         authorizer: Optional[AuthorizerBase] = None,
-        validate: bool = True,
+        ensure_bootstrap_success: bool = True,
         strict: bool = True,
         **kwargs,
     ) -> DHTNode:
@@ -156,7 +156,8 @@ class DHTNode:
         :param chunk_size: maximum number of concurrent calls in get_many and cache refresh queue
         :param blacklist_time: excludes non-responsive peers from search for this many seconds (set 0 to disable)
         :param backoff_rate: blacklist time will be multiplied by :backoff_rate: for each successive non-response
-        :param validate: if True, use initial peers to validate that this node is accessible and synchronized
+        :param ensure_bootstrap_success: raise an error if node could not connect to initial peers (or vice versa)
+           If False, print a warning instead. It is recommended to keep this flag unless you know what you're doing.
         :param strict: if True, any error encountered in validation will interrupt the creation of DHTNode
         :param client_mode: if False (default), this node will accept incoming requests as a full DHT "citizen"
           if True, this node will refuse any incoming requests, effectively being only a client
@@ -220,7 +221,7 @@ class DHTNode:
             bootstrap_timeout = bootstrap_timeout if bootstrap_timeout is not None else wait_timeout
             start_time = get_dht_time()
             ping_tasks = set(
-                asyncio.create_task(self.protocol.call_ping(peer, validate=validate, strict=strict))
+                asyncio.create_task(self.protocol.call_ping(peer, validate=ensure_bootstrap_success, strict=strict))
                 for peer in initial_peers
             )
             finished_pings, unfinished_pings = await asyncio.wait(ping_tasks, return_when=asyncio.FIRST_COMPLETED)
@@ -235,7 +236,11 @@ class DHTNode:
                 finished_pings |= finished_in_time
 
             if not finished_pings or all(ping.result() is None for ping in finished_pings):
-                logger.warning("DHTNode bootstrap failed: none of the initial_peers responded to a ping.")
+                message = "DHTNode bootstrap failed: none of the initial_peers responded to a ping."
+                if ensure_bootstrap_success:
+                    raise RuntimeError(f"{message} (set ensure_bootstrap_success=False to ignore)")
+                else:
+                    logger.warning(message)
 
             if strict:
                 for task in asyncio.as_completed(finished_pings):

+ 1 - 1
hivemind/dht/protocol.py

@@ -81,7 +81,7 @@ class DHTProtocol(ServicerBase):
 
     def __init__(self, *, _initialized_with_create=False):
         """Internal init method. Please use DHTProtocol.create coroutine to spawn new protocol instances"""
-        assert _initialized_with_create, " Please use DHTProtocol.create coroutine to spawn new protocol instances "
+        assert _initialized_with_create, "Please use DHTProtocol.create coroutine to spawn new protocol instances"
         super().__init__()
 
     def get_stub(self, peer: PeerID) -> AuthRPCWrapper:

+ 108 - 48
hivemind/p2p/p2p_daemon.py

@@ -5,12 +5,14 @@ from collections.abc import AsyncIterable as AsyncIterableABC
 from contextlib import closing, suppress
 from dataclasses import dataclass
 from importlib.resources import path
-from typing import Any, AsyncIterator, Awaitable, Callable, List, Optional, Sequence, Tuple, TypeVar, Union
+from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional, Sequence, Tuple, Type, TypeVar, Union
 
+from google.protobuf.message import Message
 from multiaddr import Multiaddr
 
 import hivemind.hivemind_cli as cli
 import hivemind.p2p.p2p_daemon_bindings.p2pclient as p2pclient
+from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError, P2PHandlerError
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
 from hivemind.proto.p2pd_pb2 import RPCError
 from hivemind.utils.asyncio import aiter, asingle
@@ -27,7 +29,6 @@ class P2PContext(object):
     handle_name: str
     local_id: PeerID
     remote_id: PeerID = None
-    remote_maddr: Multiaddr = None
 
 
 class P2P:
@@ -65,6 +66,7 @@ class P2P:
 
     def __init__(self):
         self.peer_id = None
+        self._client = None
         self._child = None
         self._alive = False
         self._reader_task = None
@@ -74,43 +76,50 @@ class P2P:
     async def create(
         cls,
         initial_peers: Optional[Sequence[Union[Multiaddr, str]]] = None,
-        use_ipfs: bool = False,
-        host_maddrs: Optional[Sequence[Union[Multiaddr, str]]] = ("/ip4/127.0.0.1/tcp/0",),
+        *,
         announce_maddrs: Optional[Sequence[Union[Multiaddr, str]]] = None,
-        quic: bool = True,
-        tls: bool = True,
+        auto_nat: bool = True,
         conn_manager: bool = True,
         dht_mode: str = "dht_server",
         force_reachability: Optional[str] = None,
+        host_maddrs: Optional[Sequence[Union[Multiaddr, str]]] = ("/ip4/127.0.0.1/tcp/0",),
+        identity_path: Optional[str] = None,
+        idle_timeout: float = 30,
         nat_port_map: bool = True,
-        auto_nat: bool = True,
+        quic: bool = False,
+        relay_hop_limit: int = 0,
+        startup_timeout: float = 15,
+        tls: bool = True,
+        use_auto_relay: bool = False,
+        use_ipfs: bool = False,
         use_relay: bool = True,
         use_relay_hop: bool = False,
         use_relay_discovery: bool = False,
-        use_auto_relay: bool = False,
-        relay_hop_limit: int = 0,
-        startup_timeout: float = 15,
     ) -> "P2P":
         """
         Start a new p2pd process and connect to it.
         :param initial_peers: List of bootstrap peers
-        :param use_ipfs: Bootstrap to IPFS (incompatible with initial_peers)
-        :param host_maddrs: Multiaddrs to listen for external connections from other p2p instances
+        :param auto_nat: Enables the AutoNAT service
         :param announce_maddrs: Visible multiaddrs that the peer will announce
-          for external connections from other p2p instances
-        :param quic: Enables the QUIC transport
-        :param tls: Enables TLS1.3 channel security protocol
+                                for external connections from other p2p instances
         :param conn_manager: Enables the Connection Manager
         :param dht_mode: DHT mode (dht_client/dht_server/dht)
         :param force_reachability: Force reachability mode (public/private)
+        :param host_maddrs: Multiaddrs to listen for external connections from other p2p instances
+        :param identity_path: Path to a pre-generated private key file. If defined, makes the peer ID deterministic.
+                              May be generated using ``./p2p-keygen`` from ``go-libp2p-daemon``.
+        :param idle_timeout: kill daemon if client has been idle for a given number of
+                             seconds before opening persistent streams
         :param nat_port_map: Enables NAT port mapping
-        :param auto_nat: Enables the AutoNAT service
+        :param quic: Enables the QUIC transport
+        :param relay_hop_limit: sets the hop limit for hop relays
+        :param startup_timeout: raise a P2PDaemonError if the daemon does not start in ``startup_timeout`` seconds
+        :param tls: Enables TLS1.3 channel security protocol
+        :param use_auto_relay: enables autorelay
+        :param use_ipfs: Bootstrap to IPFS (incompatible with initial_peers)
         :param use_relay: enables circuit relay
         :param use_relay_hop: enables hop for relay
         :param use_relay_discovery: enables passive discovery for relay
-        :param use_auto_relay: enables autorelay
-        :param relay_hop_limit: sets the hop limit for hop relays
-        :param startup_timeout: raise a P2PDaemonError if the daemon does not start in ``startup_timeout`` seconds
         :return: a wrapper for the p2p daemon
         """
 
@@ -136,21 +145,24 @@ class P2P:
         ]:
             if value:
                 process_kwargs[param] = self._maddrs_to_str(value)
+        if identity_path is not None:
+            process_kwargs["id"] = identity_path
 
         proc_args = self._make_process_args(
             str(p2pd_path),
-            listen=self._daemon_listen_maddr,
-            quic=quic,
-            tls=tls,
+            autoRelay=use_auto_relay,
+            autonat=auto_nat,
+            b=need_bootstrap,
             connManager=conn_manager,
+            idleTimeout=f"{idle_timeout}s",
+            listen=self._daemon_listen_maddr,
             natPortMap=nat_port_map,
-            autonat=auto_nat,
+            quic=quic,
             relay=use_relay,
-            relayHop=use_relay_hop,
             relayDiscovery=use_relay_discovery,
-            autoRelay=use_auto_relay,
+            relayHop=use_relay_hop,
             relayHopLimit=relay_hop_limit,
-            b=need_bootstrap,
+            tls=tls,
             **process_kwargs,
         )
 
@@ -167,7 +179,7 @@ class P2P:
             await self.shutdown()
             raise P2PDaemonError(f"Daemon failed to start in {startup_timeout:.1f} seconds")
 
-        self._client = p2pclient.Client(self._daemon_listen_maddr, self._client_listen_maddr)
+        self._client = await p2pclient.Client.create(self._daemon_listen_maddr, self._client_listen_maddr)
         await self._ping_daemon()
         return self
 
@@ -189,7 +201,7 @@ class P2P:
         self._daemon_listen_maddr = daemon_listen_maddr
         self._client_listen_maddr = Multiaddr(cls._UNIX_SOCKET_PREFIX + f"p2pclient-{socket_uid}.sock")
 
-        self._client = p2pclient.Client(self._daemon_listen_maddr, self._client_listen_maddr)
+        self._client = await p2pclient.Client.create(self._daemon_listen_maddr, self._client_listen_maddr)
 
         await self._ping_daemon()
         return self
@@ -258,7 +270,7 @@ class P2P:
 
     @staticmethod
     async def receive_protobuf(
-        input_protobuf_type: type, reader: asyncio.StreamReader
+        input_protobuf_type: Type[Message], reader: asyncio.StreamReader
     ) -> Tuple[Optional[TInputProtobuf], Optional[RPCError]]:
         msg_type = await reader.readexactly(1)
         if msg_type == P2P.MESSAGE_MARKER:
@@ -279,7 +291,7 @@ class P2P:
         self,
         name: str,
         handler: Callable[[TInputStream, P2PContext], TOutputStream],
-        input_protobuf_type: type,
+        input_protobuf_type: Type[Message],
         max_prefetch: int = 5,
     ) -> None:
         """
@@ -297,7 +309,6 @@ class P2P:
                 handle_name=name,
                 local_id=self.peer_id,
                 remote_id=stream_info.peer_id,
-                remote_maddr=stream_info.addr,
             )
             requests = asyncio.Queue(max_prefetch)
 
@@ -311,11 +322,17 @@ class P2P:
             async def _process_stream() -> None:
                 try:
                     async for response in handler(_read_stream(), context):
-                        await P2P.send_protobuf(response, writer)
+                        try:
+                            await P2P.send_protobuf(response, writer)
+                        except Exception:
+                            # The connection is unexpectedly closed by the caller or broken.
+                            # The loglevel is DEBUG since the actual error will be reported on the caller
+                            logger.debug("Exception while sending response:", exc_info=True)
+                            break
                 except Exception as e:
-                    logger.warning("Exception while processing stream and sending responses:", exc_info=True)
-                    # Sometimes `e` is a connection error, so we won't be able to report the error to the caller
+                    logger.warning("Handler failed with the exception:", exc_info=True)
                     with suppress(Exception):
+                        # Sometimes `e` is a connection error, so it is okay if we fail to report `e` to the caller
                         await P2P.send_protobuf(RPCError(message=str(e)), writer)
 
             with closing(writer):
@@ -343,7 +360,7 @@ class P2P:
         await self.add_binary_stream_handler(name, _handle_stream)
 
     async def _iterate_protobuf_stream_handler(
-        self, peer_id: PeerID, name: str, requests: TInputStream, output_protobuf_type: type
+        self, peer_id: PeerID, name: str, requests: TInputStream, output_protobuf_type: Type[Message]
     ) -> TOutputStream:
         _, reader, writer = await self.call_binary_stream_handler(peer_id, name)
 
@@ -375,15 +392,22 @@ class P2P:
         handler: Callable[
             [Union[TInputProtobuf, TInputStream], P2PContext], Union[Awaitable[TOutputProtobuf], TOutputStream]
         ],
-        input_protobuf_type: type,
+        input_protobuf_type: Type[Message],
         *,
         stream_input: bool = False,
+        stream_output: bool = False,
     ) -> None:
         """
         :param stream_input: If True, assume ``handler`` to take ``TInputStream``
                              (not just ``TInputProtobuf``) as input.
+        :param stream_output: If True, assume ``handler`` to return ``TOutputStream``
+                              (not ``Awaitable[TOutputProtobuf]``).
         """
 
+        if not stream_input and not stream_output:
+            await self._add_protobuf_unary_handler(name, handler, input_protobuf_type)
+            return
+
         async def _stream_handler(requests: P2P.TInputStream, context: P2PContext) -> P2P.TOutputStream:
             input = requests if stream_input else await asingle(requests)
             output = handler(input, context)
@@ -396,23 +420,65 @@ class P2P:
 
         await self._add_protobuf_stream_handler(name, _stream_handler, input_protobuf_type)
 
+    async def _add_protobuf_unary_handler(
+        self,
+        handle_name: str,
+        handler: Callable[[TInputProtobuf, P2PContext], Awaitable[TOutputProtobuf]],
+        input_protobuf_type: Type[Message],
+    ) -> None:
+        """
+        Register a request-response (unary) handler. Unary requests and responses
+        are sent through persistent multiplexed connections to the daemon for the
+        sake of reducing the number of open files.
+        :param handle_name: name of the handler (protocol id)
+        :param handler: function handling the unary requests
+        :param input_protobuf_type: protobuf type of the request
+        """
+
+        async def _unary_handler(request: bytes, remote_id: PeerID) -> bytes:
+            input_serialized = input_protobuf_type.FromString(request)
+            context = P2PContext(
+                handle_name=handle_name,
+                local_id=self.peer_id,
+                remote_id=remote_id,
+            )
+
+            response = await handler(input_serialized, context)
+            return response.SerializeToString()
+
+        await self._client.add_unary_handler(handle_name, _unary_handler)
+
     async def call_protobuf_handler(
         self,
         peer_id: PeerID,
         name: str,
         input: Union[TInputProtobuf, TInputStream],
-        output_protobuf_type: type,
+        output_protobuf_type: Type[Message],
     ) -> Awaitable[TOutputProtobuf]:
-        requests = input if isinstance(input, AsyncIterableABC) else aiter(input)
-        responses = self._iterate_protobuf_stream_handler(peer_id, name, requests, output_protobuf_type)
+
+        if not isinstance(input, AsyncIterableABC):
+            return await self._call_unary_protobuf_handler(peer_id, name, input, output_protobuf_type)
+
+        responses = self._iterate_protobuf_stream_handler(peer_id, name, input, output_protobuf_type)
         return await asingle(responses)
 
+    async def _call_unary_protobuf_handler(
+        self,
+        peer_id: PeerID,
+        handle_name: str,
+        input: TInputProtobuf,
+        output_protobuf_type: Type[Message],
+    ) -> Awaitable[TOutputProtobuf]:
+        serialized_input = input.SerializeToString()
+        response = await self._client.call_unary_handler(peer_id, handle_name, serialized_input)
+        return output_protobuf_type.FromString(response)
+
     def iterate_protobuf_handler(
         self,
         peer_id: PeerID,
         name: str,
         input: Union[TInputProtobuf, TInputStream],
-        output_protobuf_type: type,
+        output_protobuf_type: Type[Message],
     ) -> TOutputStream:
         requests = input if isinstance(input, AsyncIterableABC) else aiter(input)
         return self._iterate_protobuf_stream_handler(peer_id, name, requests, output_protobuf_type)
@@ -447,6 +513,8 @@ class P2P:
             await self._child.wait()
 
     def _terminate(self) -> None:
+        if self._client is not None:
+            self._client.close()
         if self._listen_task is not None:
             self._listen_task.cancel()
         if self._reader_task is not None:
@@ -495,11 +563,3 @@ class P2P:
 
         if not ready.done():
             ready.set_exception(P2PDaemonError(f"Daemon failed to start: {last_line}"))
-
-
-class P2PDaemonError(RuntimeError):
-    pass
-
-
-class P2PHandlerError(Exception):
-    pass

+ 189 - 3
hivemind/p2p/p2p_daemon_bindings/control.py

@@ -5,8 +5,9 @@ Author: Kevin Mai-Husan Chia
 """
 
 import asyncio
-from contextlib import asynccontextmanager
-from typing import AsyncIterator, Awaitable, Callable, Dict, Iterable, Sequence, Tuple
+from contextlib import asynccontextmanager, closing
+from typing import AsyncIterator, Awaitable, Callable, Dict, Iterable, Optional, Sequence, Tuple
+from uuid import UUID, uuid4
 
 from multiaddr import Multiaddr, protocols
 
@@ -54,17 +55,75 @@ class DaemonConnector:
         else:
             raise ValueError(f"Protocol not supported: {protocols.protocol_with_code(self.proto_code)}")
 
+    async def open_persistent_connection(self) -> (asyncio.StreamReader, asyncio.StreamWriter):
+        """
+        Open connection to daemon and upgrade it to a persistent one
+        """
+        reader, writer = await self.open_connection()
+        req = p2pd_pb.Request(type=p2pd_pb.Request.PERSISTENT_CONN_UPGRADE)
+        await write_pbmsg(writer, req)
+
+        response = p2pd_pb.Response()
+        await read_pbmsg_safe(reader, response)
+
+        if response.type == "ERROR":
+            raise P2PDaemonError(response.error.msg)
+
+        return reader, writer
+
+
+TUnaryHandler = Callable[[bytes, PeerID], Awaitable[bytes]]
+CallID = UUID
+
 
 class ControlClient:
     DEFAULT_LISTEN_MADDR = "/unix/tmp/p2pclient.sock"
 
     def __init__(
-        self, daemon_connector: DaemonConnector, listen_maddr: Multiaddr = Multiaddr(DEFAULT_LISTEN_MADDR)
+        self,
+        daemon_connector: DaemonConnector,
+        listen_maddr: Multiaddr = Multiaddr(DEFAULT_LISTEN_MADDR),
+        *,
+        _initialized_with_create=False,
     ) -> None:
+        assert _initialized_with_create, "Please use ControlClient.create coroutine to spawn new control instances"
+
         self.listen_maddr = listen_maddr
         self.daemon_connector = daemon_connector
         self.handlers: Dict[str, StreamHandler] = {}
 
+        self.unary_handlers: Dict[str, TUnaryHandler] = {}
+
+        self._pending_messages: asyncio.Queue[p2pd_pb.PersistentConnectionRequest] = asyncio.Queue()
+        self._pending_calls: Dict[CallID, asyncio.Future[bytes]] = {}
+        self._handler_tasks: Dict[CallID, asyncio.Task] = {}
+
+        self._read_task: Optional[asyncio.Task] = None
+        self._write_task: Optional[asyncio.Task] = None
+
+    @classmethod
+    async def create(
+        cls,
+        daemon_connector: DaemonConnector,
+        listen_maddr: Multiaddr = Multiaddr(DEFAULT_LISTEN_MADDR),
+        use_persistent_conn: bool = True,
+    ) -> "ControlClient":
+        control = cls(daemon_connector, listen_maddr, _initialized_with_create=True)
+
+        if use_persistent_conn:
+            await control._ensure_persistent_conn()
+
+        return control
+
+    def close(self) -> None:
+        if self._read_task is not None:
+            self._read_task.cancel()
+        if self._write_task is not None:
+            self._write_task.cancel()
+
+    def __del__(self):
+        self.close()
+
     async def _handler(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
         pb_stream_info = p2pd_pb.StreamInfo()  # type: ignore
         await read_pbmsg_safe(reader, pb_stream_info)
@@ -93,6 +152,121 @@ class ControlClient:
         async with server:
             yield self
 
+    async def _read_from_persistent_conn(self, reader: asyncio.StreamReader):
+        while True:
+            resp = p2pd_pb.PersistentConnectionResponse()
+            try:
+                await read_pbmsg_safe(reader, resp)
+            except asyncio.IncompleteReadError:
+                break
+
+            call_id = UUID(bytes=resp.callId)
+
+            if resp.HasField("callUnaryResponse"):
+                if call_id in self._pending_calls and resp.callUnaryResponse.HasField("response"):
+                    self._pending_calls[call_id].set_result(resp.callUnaryResponse.response)
+                elif call_id in self._pending_calls and resp.callUnaryResponse.HasField("error"):
+                    remote_exc = P2PHandlerError(resp.callUnaryResponse.error.decode(errors="ignore"))
+                    self._pending_calls[call_id].set_exception(remote_exc)
+                else:
+                    logger.debug(f"Received unexpected unary call: {resp}")
+
+            elif resp.HasField("requestHandling"):
+                handler_task = asyncio.create_task(self._handle_persistent_request(call_id, resp.requestHandling))
+                self._handler_tasks[call_id] = handler_task
+
+            elif call_id in self._handler_tasks and resp.HasField("cancel"):
+                self._handler_tasks[call_id].cancel()
+
+            elif call_id in self._pending_calls and resp.HasField("daemonError"):
+                daemon_exc = P2PDaemonError(resp.daemonError.message)
+                self._pending_calls[call_id].set_exception(daemon_exc)
+
+            elif call_id in self._pending_calls:
+                self._pending_calls[call_id].set_result(None)
+
+            else:
+                logger.debug(f"Received unexpected response from daemon: {resp}")
+
+    async def _write_to_persistent_conn(self, writer: asyncio.StreamWriter):
+        with closing(writer):
+            while True:
+                msg = await self._pending_messages.get()
+                await write_pbmsg(writer, msg)
+
+    async def _handle_persistent_request(self, call_id: UUID, request: p2pd_pb.CallUnaryRequest):
+        if request.proto not in self.unary_handlers:
+            logger.warning(f"Protocol {request.proto} not supported")
+            return
+
+        try:
+            remote_id = PeerID(request.peer)
+            response_payload: bytes = await self.unary_handlers[request.proto](request.data, remote_id)
+            response = p2pd_pb.CallUnaryResponse(response=response_payload)
+
+        except Exception as e:
+            response = p2pd_pb.CallUnaryResponse(error=repr(e).encode())
+
+        await self._pending_messages.put(
+            p2pd_pb.PersistentConnectionRequest(
+                callId=call_id.bytes,
+                unaryResponse=response,
+            )
+        )
+        self._handler_tasks.pop(call_id)
+
+    async def _cancel_unary_call(self, call_id: UUID):
+        await self._pending_messages.put(
+            p2pd_pb.PersistentConnectionRequest(
+                callId=call_id.bytes,
+                cancel=p2pd_pb.Cancel(),
+            ),
+        )
+
+    async def _ensure_persistent_conn(self):
+        reader, writer = await self.daemon_connector.open_persistent_connection()
+
+        self._read_task = asyncio.create_task(self._read_from_persistent_conn(reader))
+        self._write_task = asyncio.create_task(self._write_to_persistent_conn(writer))
+
+    async def add_unary_handler(self, proto: str, handler: TUnaryHandler):
+        call_id = uuid4()
+
+        add_unary_handler_req = p2pd_pb.AddUnaryHandlerRequest(proto=proto)
+        req = p2pd_pb.PersistentConnectionRequest(callId=call_id.bytes, addUnaryHandler=add_unary_handler_req)
+
+        if self.unary_handlers.get(proto):
+            raise P2PDaemonError(f"Handler for protocol {proto} already registered")
+        self.unary_handlers[proto] = handler
+
+        self._pending_calls[call_id] = asyncio.Future()
+        await self._pending_messages.put(req)
+        await self._pending_calls[call_id]
+
+    async def call_unary_handler(self, peer_id: PeerID, proto: str, data: bytes) -> bytes:
+        call_id = uuid4()
+        call_unary_req = p2pd_pb.CallUnaryRequest(
+            peer=peer_id.to_bytes(),
+            proto=proto,
+            data=data,
+        )
+        req = p2pd_pb.PersistentConnectionRequest(
+            callId=call_id.bytes,
+            callUnary=call_unary_req,
+        )
+
+        try:
+            self._pending_calls[call_id] = asyncio.Future()
+            await self._pending_messages.put(req)
+            return await self._pending_calls[call_id]
+
+        except asyncio.CancelledError:
+            await self._cancel_unary_call(call_id)
+            raise
+
+        finally:
+            self._pending_calls.pop(call_id, None)
+
     async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
         reader, writer = await self.daemon_connector.open_connection()
         req = p2pd_pb.Request(type=p2pd_pb.Request.IDENTIFY)
@@ -179,3 +353,15 @@ class ControlClient:
 
         # if success, add the handler to the dict
         self.handlers[proto] = handler_cb
+
+
+class P2PHandlerError(Exception):
+    """
+    Raised if remote handled a request with an exception
+    """
+
+
+class P2PDaemonError(Exception):
+    """
+    Raised if daemon failed to handle request
+    """

+ 9 - 0
hivemind/p2p/p2p_daemon_bindings/datastructures.py

@@ -74,6 +74,12 @@ class PeerID:
         else:
             return False
 
+    def __lt__(self, other: object) -> bool:
+        if not isinstance(other, PeerID):
+            raise TypeError(f"'<' not supported between instances of 'PeerID' and '{type(other)}'")
+
+        return self.to_base58() < other.to_base58()
+
     def __hash__(self) -> int:
         return hash(self._bytes)
 
@@ -125,6 +131,9 @@ class PeerInfo:
     def __str__(self):
         return f"{self.peer_id.pretty()} {','.join(str(a) for a in self.addrs)}"
 
+    def __repr__(self):
+        return f"PeerInfo(peer_id={repr(self.peer_id)}, addrs={repr(self.addrs)})"
+
 
 class InvalidAddrError(ValueError):
     pass

+ 24 - 3
hivemind/p2p/p2p_daemon_bindings/p2pclient.py

@@ -10,16 +10,31 @@ from typing import AsyncIterator, Iterable, Sequence, Tuple
 
 from multiaddr import Multiaddr
 
-from hivemind.p2p.p2p_daemon_bindings.control import ControlClient, DaemonConnector, StreamHandler
+from hivemind.p2p.p2p_daemon_bindings.control import ControlClient, DaemonConnector, StreamHandler, TUnaryHandler
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
 
 
 class Client:
     control: ControlClient
 
-    def __init__(self, control_maddr: Multiaddr = None, listen_maddr: Multiaddr = None) -> None:
+    def __init__(self, *, _initialized_with_create=False) -> None:
+        assert _initialized_with_create, "Please use Client.create coroutine to spawn new client instances"
+        self.control = None
+
+    @classmethod
+    async def create(cls, control_maddr: Multiaddr = None, listen_maddr: Multiaddr = None) -> "Client":
+        client = cls(_initialized_with_create=True)
+
         daemon_connector = DaemonConnector(control_maddr=control_maddr)
-        self.control = ControlClient(daemon_connector=daemon_connector, listen_maddr=listen_maddr)
+        client.control = await ControlClient.create(daemon_connector=daemon_connector, listen_maddr=listen_maddr)
+
+        return client
+
+    def close(self) -> None:
+        self.control.close()
+
+    def __del__(self):
+        self.close()
 
     @asynccontextmanager
     async def listen(self) -> AsyncIterator["Client"]:
@@ -30,6 +45,12 @@ class Client:
         async with self.control.listen():
             yield self
 
+    async def add_unary_handler(self, proto: str, handler: TUnaryHandler):
+        await self.control.add_unary_handler(proto, handler)
+
+    async def call_unary_handler(self, peer_id: PeerID, proto: str, data: bytes) -> bytes:
+        return await self.control.call_unary_handler(peer_id, proto, data)
+
     async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
         """
         Get current node peer id and list of addresses

+ 12 - 8
hivemind/p2p/servicer.py

@@ -125,14 +125,18 @@ class ServicerBase:
         self._collect_rpc_handlers()
 
         servicer = self if wrapper is None else wrapper
-        for handler in self._rpc_handlers:
-            print("handler:", self._get_handle_name(namespace, handler.method_name))
-            await p2p.add_protobuf_handler(
-                self._get_handle_name(namespace, handler.method_name),
-                getattr(servicer, handler.method_name),
-                handler.request_type,
-                stream_input=handler.stream_input,
-            )
+        await asyncio.gather(
+            *[
+                p2p.add_protobuf_handler(
+                    self._get_handle_name(namespace, handler.method_name),
+                    getattr(servicer, handler.method_name),
+                    handler.request_type,
+                    stream_input=handler.stream_input,
+                    stream_output=handler.stream_output,
+                )
+                for handler in self._rpc_handlers
+            ]
+        )
 
     @classmethod
     def get_stub(cls, p2p: P2P, peer: PeerID, *, namespace: Optional[str] = None) -> StubBase:

+ 59 - 10
hivemind/proto/p2pd.proto

@@ -8,15 +8,17 @@ package p2pclient.p2pd.pb;
 
 message Request {
   enum Type {
-    IDENTIFY       = 0;
-    CONNECT        = 1;
-    STREAM_OPEN    = 2;
-    STREAM_HANDLER = 3;
-    DHT            = 4;
-    LIST_PEERS     = 5;
-    CONNMANAGER    = 6;
-    DISCONNECT     = 7;
-    PUBSUB         = 8;
+    IDENTIFY                 = 0;
+    CONNECT                  = 1;
+    STREAM_OPEN              = 2;
+    STREAM_HANDLER           = 3;
+    DHT                      = 4;
+    LIST_PEERS               = 5;
+    CONNMANAGER              = 6;
+    DISCONNECT               = 7;      
+    PUBSUB                   = 8;
+
+    PERSISTENT_CONN_UPGRADE  = 9;
   }
 
   required Type type = 1;
@@ -45,6 +47,29 @@ message Response {
   optional PSResponse pubsub = 7;
 }
 
+message PersistentConnectionRequest {
+  required bytes callId = 1;
+
+  oneof message {
+    AddUnaryHandlerRequest addUnaryHandler = 2;
+    CallUnaryRequest  callUnary = 3;
+    CallUnaryResponse unaryResponse = 4;
+    Cancel cancel = 5;
+  }
+}
+
+message PersistentConnectionResponse {
+  required bytes callId = 1;
+
+  oneof message {
+    CallUnaryResponse callUnaryResponse = 2;
+    CallUnaryRequest requestHandling = 3;
+    DaemonError daemonError = 4;
+    Cancel cancel = 5;
+  }
+}
+
+
 message IdentifyResponse {
   required bytes id = 1;
   repeated bytes addrs = 2;
@@ -148,7 +173,7 @@ message PSRequest {
 }
 
 message PSMessage {
-  optional bytes from_id = 1;
+  optional bytes from = 1;
   optional bytes data = 2;
   optional bytes seqno = 3;
   repeated string topicIDs = 4;
@@ -161,6 +186,30 @@ message PSResponse {
   repeated bytes peerIDs = 2;
 }
 
+message CallUnaryRequest {
+  required bytes peer = 1;
+  required string proto = 2;
+  required bytes data = 3;
+}
+
+message CallUnaryResponse {
+  oneof result {
+    bytes response = 1;
+    bytes error = 2;
+  }
+}
+
+message AddUnaryHandlerRequest {
+  required string proto = 1;
+}
+
+message DaemonError {
+  optional string message = 1;
+}
+
+message Cancel {
+}
+
 message RPCError {
   optional string message = 1;
 }

+ 24 - 2
hivemind/utils/asyncio.py

@@ -1,4 +1,5 @@
 import asyncio
+import concurrent.futures
 from concurrent.futures import ThreadPoolExecutor
 from typing import AsyncIterable, AsyncIterator, Awaitable, Callable, Optional, Tuple, TypeVar, Union
 
@@ -59,7 +60,7 @@ async def aenumerate(aiterable: AsyncIterable[T]) -> AsyncIterable[Tuple[int, T]
 
 
 async def asingle(aiter: AsyncIterable[T]) -> T:
-    """If ``aiter`` has exactly one item, returns this item. Otherwise, raises `ValueError`."""
+    """If ``aiter`` has exactly one item, returns this item. Otherwise, raises ``ValueError``."""
     count = 0
     async for item in aiter:
         count += 1
@@ -70,16 +71,37 @@ async def asingle(aiter: AsyncIterable[T]) -> T:
     return item
 
 
+async def afirst(aiter: AsyncIterable[T], default: Optional[T] = None) -> Optional[T]:
+    """Returns the first item of ``aiter`` or ``default`` if ``aiter`` is empty."""
+    async for item in aiter:
+        return item
+    return default
+
+
 async def await_cancelled(awaitable: Awaitable) -> bool:
     try:
         await awaitable
         return False
-    except asyncio.CancelledError:
+    except (asyncio.CancelledError, concurrent.futures.CancelledError):
+        # In Python 3.7, awaiting a cancelled asyncio.Future raises concurrent.futures.CancelledError
+        # instead of asyncio.CancelledError
         return True
     except BaseException:
+        logger.exception(f"Exception in {awaitable}:")
         return False
 
 
+async def cancel_and_wait(awaitable: Awaitable) -> bool:
+    """
+    Cancels ``awaitable`` and waits for its cancellation.
+    In case of ``asyncio.Task``, helps to avoid ``Task was destroyed but it is pending!`` errors.
+    In case of ``asyncio.Future``, equal to ``future.cancel()``.
+    """
+
+    awaitable.cancel()
+    return await await_cancelled(awaitable)
+
+
 async def amap_in_executor(
     func: Callable[..., T],
     *iterables: AsyncIterable,

+ 1 - 1
hivemind/utils/mpfuture.py

@@ -53,7 +53,7 @@ class SharedBytes:
         """Create another shared byte value, represented as a scalar uint8 tensor"""
         with cls._lock:
             if cls._pid != os.getpid() or cls._buffer is None or cls._index >= len(cls._buffer):
-                buffer_size = os.environ.get("HIVEMIND_SHM_BUFFER_SIZE", 4096)
+                buffer_size = int(os.environ.get("HIVEMIND_SHM_BUFFER_SIZE", 16))
                 cls._pid = os.getpid()
                 cls._buffer = torch.empty([buffer_size], dtype=torch.uint8).share_memory_()
                 cls._index = 0

+ 6 - 1
hivemind/utils/networking.py

@@ -31,7 +31,12 @@ def strip_port(endpoint: Endpoint) -> Hostname:
 
 
 def get_free_port(params=(socket.AF_INET, socket.SOCK_STREAM), opt=(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)):
-    """Finds a tcp port that can be occupied with a socket with *params and use *opt options"""
+    """
+    Finds a tcp port that can be occupied with a socket with *params and use *opt options.
+
+    :note: Using this function is discouraged since it often leads to a race condition
+           with the "Address is already in use" error if the code is run in parallel.
+    """
     try:
         with closing(socket.socket(*params)) as sock:
             sock.bind(("", 0))

+ 2 - 2
setup.py

@@ -14,8 +14,8 @@ from setuptools import find_packages, setup
 from setuptools.command.build_py import build_py
 from setuptools.command.develop import develop
 
-P2PD_VERSION = "v0.3.1"
-P2PD_CHECKSUM = "15292b880c6b31f5b3c36084b3acc17f"
+P2PD_VERSION = "v0.3.4"
+P2PD_CHECKSUM = "194dca06116fdd36bc4b681d18f3b9cb"
 LIBP2P_TAR_URL = f"https://github.com/learning-at-home/go-libp2p-daemon/archive/refs/tags/{P2PD_VERSION}.tar.gz"
 
 here = os.path.abspath(os.path.dirname(__file__))

+ 5 - 2
tests/test_averaging.py

@@ -10,6 +10,7 @@ import hivemind.averaging.averager
 from hivemind.averaging.allreduce import AveragingMode
 from hivemind.averaging.key_manager import GroupKeyManager
 from hivemind.averaging.load_balancing import load_balance_peers
+from hivemind.averaging.partition import AllreduceException
 from hivemind.p2p import PeerID
 from hivemind.proto.runtime_pb2 import CompressionType
 
@@ -363,9 +364,11 @@ def test_too_few_peers():
         )
         for i, dht in enumerate(dht_instances)
     ]
-    step_futures = [averager.step(wait=False) for averager in averagers]
+    step_futures = [averager.step(wait=False, timeout=2) for averager in averagers]
+
     for future in step_futures:
-        assert len(future.result()) == 2
+        with pytest.raises(AllreduceException):
+            future.result()
 
     for process in averagers + dht_instances:
         process.shutdown()

+ 3 - 3
tests/test_dht.py

@@ -14,7 +14,7 @@ from test_utils.dht_swarms import launch_dht_instances
 
 @pytest.mark.asyncio
 async def test_startup_error():
-    with pytest.raises(hivemind.p2p.P2PDaemonError, match=r"Failed to connect to bootstrap peers"):
+    with pytest.raises(hivemind.p2p.P2PDaemonError, match=r"(?i)Failed to connect to bootstrap peers"):
         hivemind.DHT(
             initial_peers=[f"/ip4/127.0.0.1/tcp/{get_free_port()}/p2p/QmdaK4LUeQaKhqSFPRu9N7MvXUEWDxWwtCvPrS444tCgd1"],
             start=True,
@@ -22,7 +22,7 @@ async def test_startup_error():
 
     dht = hivemind.DHT(start=True, await_ready=False)
     with pytest.raises(concurrent.futures.TimeoutError):
-        dht.wait_until_ready(timeout=0.1)
+        dht.wait_until_ready(timeout=0.01)
     dht.shutdown()
 
 
@@ -118,7 +118,7 @@ async def test_dht_get_visible_maddrs():
 
     dummy_endpoint = Multiaddr("/ip4/123.45.67.89/tcp/31337")
     p2p = await hivemind.p2p.P2P.create(announce_maddrs=[dummy_endpoint])
-    dht = hivemind.DHT(start=True, p2p=await p2p.replicate(p2p.daemon_listen_maddr))
+    dht = hivemind.DHT(start=True, p2p=p2p)
 
     assert dht.get_visible_maddrs() == [dummy_endpoint.encapsulate(f"/p2p/{p2p.peer_id}")]
     dht.shutdown()

+ 34 - 214
tests/test_dht_node.py

@@ -1,200 +1,27 @@
 import asyncio
 import heapq
-import multiprocessing as mp
 import random
-import signal
 from itertools import product
-from typing import List, Sequence, Tuple
 
 import numpy as np
 import pytest
-from multiaddr import Multiaddr
 
 import hivemind
 from hivemind import get_dht_time
 from hivemind.dht.node import DHTID, DHTNode
-from hivemind.dht.protocol import DHTProtocol
-from hivemind.dht.storage import DictionaryDHTValue
-from hivemind.p2p import P2P, PeerID
 from hivemind.utils.logging import get_logger
 
 from test_utils.dht_swarms import launch_star_shaped_swarm, launch_swarm_in_separate_processes
 
 logger = get_logger(__name__)
 
-
-def maddrs_to_peer_ids(maddrs: List[Multiaddr]) -> List[PeerID]:
-    return list({PeerID.from_base58(maddr["p2p"]) for maddr in maddrs})
-
-
-def run_protocol_listener(
-    dhtid: DHTID, maddr_conn: mp.connection.Connection, initial_peers: Sequence[Multiaddr]
-) -> None:
-    loop = asyncio.get_event_loop()
-
-    p2p = loop.run_until_complete(P2P.create(initial_peers=initial_peers))
-    visible_maddrs = loop.run_until_complete(p2p.get_visible_maddrs())
-
-    protocol = loop.run_until_complete(
-        DHTProtocol.create(p2p, dhtid, bucket_size=20, depth_modulo=5, num_replicas=3, wait_timeout=5)
-    )
-
-    logger.info(f"Started peer id={protocol.node_id} visible_maddrs={visible_maddrs}")
-
-    for peer_id in maddrs_to_peer_ids(initial_peers):
-        loop.run_until_complete(protocol.call_ping(peer_id))
-
-    maddr_conn.send((p2p.peer_id, visible_maddrs))
-
-    async def shutdown():
-        await p2p.shutdown()
-        logger.info(f"Finished peer id={protocol.node_id} maddrs={visible_maddrs}")
-        loop.stop()
-
-    loop.add_signal_handler(signal.SIGTERM, lambda: loop.create_task(shutdown()))
-    loop.run_forever()
-
-
-def launch_protocol_listener(
-    initial_peers: Sequence[Multiaddr] = (),
-) -> Tuple[DHTID, mp.Process, PeerID, List[Multiaddr]]:
-    remote_conn, local_conn = mp.Pipe()
-    dht_id = DHTID.generate()
-    process = mp.Process(target=run_protocol_listener, args=(dht_id, remote_conn, initial_peers), daemon=True)
-    process.start()
-    peer_id, visible_maddrs = local_conn.recv()
-
-    return dht_id, process, peer_id, visible_maddrs
-
-
 # note: we run network-related tests in a separate process to re-initialize all global states from scratch
 # this helps us avoid undesirable gRPC side-effects (e.g. segfaults) when running multiple tests in sequence
 
 
 @pytest.mark.forked
-def test_dht_protocol():
-    peer1_node_id, peer1_proc, peer1_id, peer1_maddrs = launch_protocol_listener()
-    peer2_node_id, peer2_proc, peer2_id, _ = launch_protocol_listener(initial_peers=peer1_maddrs)
-
-    loop = asyncio.get_event_loop()
-    for client_mode in [True, False]:  # note: order matters, this test assumes that first run uses client mode
-        peer_id = DHTID.generate()
-        p2p = loop.run_until_complete(P2P.create(initial_peers=peer1_maddrs))
-        protocol = loop.run_until_complete(
-            DHTProtocol.create(
-                p2p, peer_id, bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, client_mode=client_mode
-            )
-        )
-        logger.info(f"Self id={protocol.node_id}")
-
-        assert loop.run_until_complete(protocol.call_ping(peer1_id)) == peer1_node_id
-
-        key, value, expiration = DHTID.generate(), [random.random(), {"ololo": "pyshpysh"}], get_dht_time() + 1e3
-        store_ok = loop.run_until_complete(
-            protocol.call_store(peer1_id, [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
-        )
-        assert all(store_ok), "DHT rejected a trivial store"
-
-        # peer 1 must know about peer 2
-        (recv_value_bytes, recv_expiration), nodes_found = loop.run_until_complete(
-            protocol.call_find(peer1_id, [key])
-        )[key]
-        recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
-        (recv_id, recv_peer_id) = next(iter(nodes_found.items()))
-        assert (
-            recv_id == peer2_node_id and recv_peer_id == peer2_id
-        ), f"expected id={peer2_node_id}, peer={peer2_id} but got {recv_id}, {recv_peer_id}"
-
-        assert recv_value == value and recv_expiration == expiration, (
-            f"call_find_value expected {value} (expires by {expiration}) "
-            f"but got {recv_value} (expires by {recv_expiration})"
-        )
-
-        # peer 2 must know about peer 1, but not have a *random* nonexistent value
-        dummy_key = DHTID.generate()
-        empty_item, nodes_found_2 = loop.run_until_complete(protocol.call_find(peer2_id, [dummy_key]))[dummy_key]
-        assert empty_item is None, "Non-existent keys shouldn't have values"
-        (recv_id, recv_peer_id) = next(iter(nodes_found_2.items()))
-        assert (
-            recv_id == peer1_node_id and recv_peer_id == peer1_id
-        ), f"expected id={peer1_node_id}, peer={peer1_id} but got {recv_id}, {recv_peer_id}"
-
-        # cause a non-response by querying a nonexistent peer
-        assert loop.run_until_complete(protocol.call_find(PeerID.from_base58("fakeid"), [key])) is None
-
-        # store/get a dictionary with sub-keys
-        nested_key, subkey1, subkey2 = DHTID.generate(), "foo", "bar"
-        value1, value2 = [random.random(), {"ololo": "pyshpysh"}], "abacaba"
-        assert loop.run_until_complete(
-            protocol.call_store(
-                peer1_id,
-                keys=[nested_key],
-                values=[hivemind.MSGPackSerializer.dumps(value1)],
-                expiration_time=[expiration],
-                subkeys=[subkey1],
-            )
-        )
-        assert loop.run_until_complete(
-            protocol.call_store(
-                peer1_id,
-                keys=[nested_key],
-                values=[hivemind.MSGPackSerializer.dumps(value2)],
-                expiration_time=[expiration + 5],
-                subkeys=[subkey2],
-            )
-        )
-        (recv_dict, recv_expiration), nodes_found = loop.run_until_complete(
-            protocol.call_find(peer1_id, [nested_key])
-        )[nested_key]
-        assert isinstance(recv_dict, DictionaryDHTValue)
-        assert len(recv_dict.data) == 2 and recv_expiration == expiration + 5
-        assert recv_dict.data[subkey1] == (protocol.serializer.dumps(value1), expiration)
-        assert recv_dict.data[subkey2] == (protocol.serializer.dumps(value2), expiration + 5)
-
-        if not client_mode:
-            loop.run_until_complete(p2p.shutdown())
-
-    peer1_proc.terminate()
-    peer2_proc.terminate()
-
-
-@pytest.mark.forked
-def test_empty_table():
-    """Test RPC methods with empty routing table"""
-    peer_id, peer_proc, peer_peer_id, peer_maddrs = launch_protocol_listener()
-
-    loop = asyncio.get_event_loop()
-    p2p = loop.run_until_complete(P2P.create(initial_peers=peer_maddrs))
-    protocol = loop.run_until_complete(
-        DHTProtocol.create(
-            p2p, DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, client_mode=True
-        )
-    )
-
-    key, value, expiration = DHTID.generate(), [random.random(), {"ololo": "pyshpysh"}], get_dht_time() + 1e3
-
-    empty_item, nodes_found = loop.run_until_complete(protocol.call_find(peer_peer_id, [key]))[key]
-    assert empty_item is None and len(nodes_found) == 0
-    assert all(
-        loop.run_until_complete(
-            protocol.call_store(peer_peer_id, [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
-        )
-    ), "peer rejected store"
-
-    (recv_value_bytes, recv_expiration), nodes_found = loop.run_until_complete(
-        protocol.call_find(peer_peer_id, [key])
-    )[key]
-    recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
-    assert len(nodes_found) == 0
-    assert recv_value == value and recv_expiration == expiration
-
-    assert loop.run_until_complete(protocol.call_ping(peer_peer_id)) == peer_id
-    assert loop.run_until_complete(protocol.call_ping(PeerID.from_base58("fakeid"))) is None
-    peer_proc.terminate()
-
-
-@pytest.mark.forked
-def test_dht_node(
+@pytest.mark.asyncio
+async def test_dht_node(
     n_peers: int = 20, n_sequential_peers: int = 5, parallel_rpc: int = 10, bucket_size: int = 5, num_replicas: int = 3
 ):
     # step A: create a swarm of 50 dht nodes in separate processes
@@ -205,26 +32,23 @@ def test_dht_node(
     )
 
     # step B: run 51-st node in this process
-    loop = asyncio.get_event_loop()
     initial_peers = random.choice(swarm_maddrs)
-    me = loop.run_until_complete(
-        DHTNode.create(
-            initial_peers=initial_peers,
-            parallel_rpc=parallel_rpc,
-            bucket_size=bucket_size,
-            num_replicas=num_replicas,
-            cache_refresh_before_expiry=False,
-        )
+    me = await DHTNode.create(
+        initial_peers=initial_peers,
+        parallel_rpc=parallel_rpc,
+        bucket_size=bucket_size,
+        num_replicas=num_replicas,
+        cache_refresh_before_expiry=False,
     )
 
     # test 1: find self
-    nearest = loop.run_until_complete(me.find_nearest_nodes([me.node_id], k_nearest=1))[me.node_id]
+    nearest = (await me.find_nearest_nodes([me.node_id], k_nearest=1))[me.node_id]
     assert len(nearest) == 1 and nearest[me.node_id] == me.peer_id
 
     # test 2: find others
     for _ in range(10):
         ref_peer_id, query_id = random.choice(list(dht.items()))
-        nearest = loop.run_until_complete(me.find_nearest_nodes([query_id], k_nearest=1))[query_id]
+        nearest = (await me.find_nearest_nodes([query_id], k_nearest=1))[query_id]
         assert len(nearest) == 1
         found_node_id, found_peer_id = next(iter(nearest.items()))
         assert found_node_id == query_id and found_peer_id == ref_peer_id
@@ -238,10 +62,8 @@ def test_dht_node(
         query_id = DHTID.generate()
         k_nearest = random.randint(1, 10)
         exclude_self = random.random() > 0.5
-        nearest = loop.run_until_complete(
-            me.find_nearest_nodes([query_id], k_nearest=k_nearest, exclude_self=exclude_self)
-        )[query_id]
-        nearest_nodes = list(nearest)  # keys from ordered dict
+        find_result = await me.find_nearest_nodes([query_id], k_nearest=k_nearest, exclude_self=exclude_self)
+        nearest_nodes = list(find_result[query_id])  # keys from ordered dict
 
         assert len(nearest_nodes) == k_nearest, "beam search must return exactly k_nearest results"
         assert me.node_id not in nearest_nodes or not exclude_self, "if exclude, results shouldn't contain self"
@@ -268,65 +90,63 @@ def test_dht_node(
 
     # test 4: find all nodes
     dummy = DHTID.generate()
-    nearest = loop.run_until_complete(me.find_nearest_nodes([dummy], k_nearest=len(dht) + 100))[dummy]
+    nearest = (await me.find_nearest_nodes([dummy], k_nearest=len(dht) + 100))[dummy]
     assert len(nearest) == len(dht) + 1
     assert len(set.difference(set(nearest.keys()), set(all_node_ids) | {me.node_id})) == 0
 
     # test 5: node without peers
-    detached_node = loop.run_until_complete(DHTNode.create())
-    nearest = loop.run_until_complete(detached_node.find_nearest_nodes([dummy]))[dummy]
+    detached_node = await DHTNode.create()
+    nearest = (await detached_node.find_nearest_nodes([dummy]))[dummy]
     assert len(nearest) == 1 and nearest[detached_node.node_id] == detached_node.peer_id
-    nearest = loop.run_until_complete(detached_node.find_nearest_nodes([dummy], exclude_self=True))[dummy]
+    nearest = (await detached_node.find_nearest_nodes([dummy], exclude_self=True))[dummy]
     assert len(nearest) == 0
 
     # test 6: store and get value
     true_time = get_dht_time() + 1200
-    assert loop.run_until_complete(me.store("mykey", ["Value", 10], true_time))
+    assert await me.store("mykey", ["Value", 10], true_time)
 
     initial_peers = random.choice(swarm_maddrs)
-    that_guy = loop.run_until_complete(
-        DHTNode.create(
-            initial_peers=initial_peers,
-            parallel_rpc=parallel_rpc,
-            cache_refresh_before_expiry=False,
-            cache_locally=False,
-        )
+    that_guy = await DHTNode.create(
+        initial_peers=initial_peers,
+        parallel_rpc=parallel_rpc,
+        cache_refresh_before_expiry=False,
+        cache_locally=False,
     )
 
     for node in [me, that_guy]:
-        val, expiration_time = loop.run_until_complete(node.get("mykey"))
+        val, expiration_time = await node.get("mykey")
         assert val == ["Value", 10], "Wrong value"
         assert expiration_time == true_time, f"Wrong time"
 
-    assert loop.run_until_complete(detached_node.get("mykey")) is None
+    assert not await detached_node.get("mykey")
 
     # test 7: bulk store and bulk get
     keys = "foo", "bar", "baz", "zzz"
     values = 3, 2, "batman", [1, 2, 3]
-    store_ok = loop.run_until_complete(me.store_many(keys, values, expiration_time=get_dht_time() + 999))
+    store_ok = await me.store_many(keys, values, expiration_time=get_dht_time() + 999)
     assert all(store_ok.values()), "failed to store one or more keys"
-    response = loop.run_until_complete(me.get_many(keys[::-1]))
+    response = await me.get_many(keys[::-1])
     for key, value in zip(keys, values):
         assert key in response and response[key][0] == value
 
     # test 8: store dictionaries as values (with sub-keys)
     upper_key, subkey1, subkey2, subkey3 = "ololo", "k1", "k2", "k3"
     now = get_dht_time()
-    assert loop.run_until_complete(me.store(upper_key, subkey=subkey1, value=123, expiration_time=now + 10))
-    assert loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=456, expiration_time=now + 20))
+    assert await me.store(upper_key, subkey=subkey1, value=123, expiration_time=now + 10)
+    assert await me.store(upper_key, subkey=subkey2, value=456, expiration_time=now + 20)
     for node in [that_guy, me]:
-        value, time = loop.run_until_complete(node.get(upper_key))
+        value, time = await node.get(upper_key)
         assert isinstance(value, dict) and time == now + 20
         assert value[subkey1] == (123, now + 10)
         assert value[subkey2] == (456, now + 20)
         assert len(value) == 2
 
-    assert not loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=345, expiration_time=now + 10))
-    assert loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=567, expiration_time=now + 30))
-    assert loop.run_until_complete(me.store(upper_key, subkey=subkey3, value=890, expiration_time=now + 50))
+    assert not await me.store(upper_key, subkey=subkey2, value=345, expiration_time=now + 10)
+    assert await me.store(upper_key, subkey=subkey2, value=567, expiration_time=now + 30)
+    assert await me.store(upper_key, subkey=subkey3, value=890, expiration_time=now + 50)
 
     for node in [that_guy, me]:
-        value, time = loop.run_until_complete(node.get(upper_key, latest=True))
+        value, time = await node.get(upper_key, latest=True)
         assert isinstance(value, dict) and time == now + 50, (value, time)
         assert value[subkey1] == (123, now + 10)
         assert value[subkey2] == (567, now + 30)
@@ -336,7 +156,7 @@ def test_dht_node(
     for proc in processes:
         proc.terminate()
     # The nodes don't own their hivemind.p2p.P2P instances, so we shutdown them separately
-    loop.run_until_complete(asyncio.wait([node.shutdown() for node in [me, detached_node, that_guy]]))
+    await asyncio.gather(me.shutdown(), that_guy.shutdown(), detached_node.shutdown())
 
 
 @pytest.mark.forked

+ 163 - 0
tests/test_dht_protocol.py

@@ -0,0 +1,163 @@
+import asyncio
+import multiprocessing as mp
+import random
+import signal
+from typing import List, Sequence, Tuple
+
+import pytest
+from multiaddr import Multiaddr
+
+import hivemind
+from hivemind import P2P, PeerID, get_dht_time, get_logger
+from hivemind.dht import DHTID
+from hivemind.dht.protocol import DHTProtocol
+from hivemind.dht.storage import DictionaryDHTValue
+
+logger = get_logger(__name__)
+
+
+def maddrs_to_peer_ids(maddrs: List[Multiaddr]) -> List[PeerID]:
+    return list({PeerID.from_base58(maddr["p2p"]) for maddr in maddrs})
+
+
+def run_protocol_listener(
+    dhtid: DHTID, maddr_conn: mp.connection.Connection, initial_peers: Sequence[Multiaddr]
+) -> None:
+    loop = asyncio.new_event_loop()
+    asyncio.set_event_loop(loop)
+
+    p2p = loop.run_until_complete(P2P.create(initial_peers=initial_peers))
+    visible_maddrs = loop.run_until_complete(p2p.get_visible_maddrs())
+
+    protocol = loop.run_until_complete(
+        DHTProtocol.create(p2p, dhtid, bucket_size=20, depth_modulo=5, num_replicas=3, wait_timeout=5)
+    )
+
+    logger.info(f"Started peer id={protocol.node_id} visible_maddrs={visible_maddrs}")
+
+    for peer_id in maddrs_to_peer_ids(initial_peers):
+        loop.run_until_complete(protocol.call_ping(peer_id))
+
+    maddr_conn.send((p2p.peer_id, visible_maddrs))
+
+    async def shutdown():
+        await p2p.shutdown()
+        logger.info(f"Finished peer id={protocol.node_id} maddrs={visible_maddrs}")
+        loop.stop()
+
+    loop.add_signal_handler(signal.SIGTERM, lambda: loop.create_task(shutdown()))
+    loop.run_forever()
+
+
+def launch_protocol_listener(
+    initial_peers: Sequence[Multiaddr] = (),
+) -> Tuple[DHTID, mp.Process, PeerID, List[Multiaddr]]:
+    remote_conn, local_conn = mp.Pipe()
+    dht_id = DHTID.generate()
+    process = mp.Process(target=run_protocol_listener, args=(dht_id, remote_conn, initial_peers), daemon=True)
+    process.start()
+    peer_id, visible_maddrs = local_conn.recv()
+
+    return dht_id, process, peer_id, visible_maddrs
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_dht_protocol():
+    peer1_node_id, peer1_proc, peer1_id, peer1_maddrs = launch_protocol_listener()
+    peer2_node_id, peer2_proc, peer2_id, _ = launch_protocol_listener(initial_peers=peer1_maddrs)
+
+    for client_mode in [True, False]:  # note: order matters, this test assumes that first run uses client mode
+        peer_id = DHTID.generate()
+        p2p = await P2P.create(initial_peers=peer1_maddrs)
+        protocol = await DHTProtocol.create(
+            p2p, peer_id, bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, client_mode=client_mode
+        )
+        logger.info(f"Self id={protocol.node_id}")
+
+        assert peer1_node_id == await protocol.call_ping(peer1_id)
+
+        key, value, expiration = DHTID.generate(), [random.random(), {"ololo": "pyshpysh"}], get_dht_time() + 1e3
+        store_ok = await protocol.call_store(peer1_id, [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
+        assert all(store_ok), "DHT rejected a trivial store"
+
+        # peer 1 must know about peer 2
+        (recv_value_bytes, recv_expiration), nodes_found = (await protocol.call_find(peer1_id, [key]))[key]
+        recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
+        (recv_id, recv_peer_id) = next(iter(nodes_found.items()))
+        assert (
+            recv_id == peer2_node_id and recv_peer_id == peer2_id
+        ), f"expected id={peer2_node_id}, peer={peer2_id} but got {recv_id}, {recv_peer_id}"
+
+        assert recv_value == value and recv_expiration == expiration, (
+            f"call_find_value expected {value} (expires by {expiration}) "
+            f"but got {recv_value} (expires by {recv_expiration})"
+        )
+
+        # peer 2 must know about peer 1, but not have a *random* nonexistent value
+        dummy_key = DHTID.generate()
+        empty_item, nodes_found_2 = (await protocol.call_find(peer2_id, [dummy_key]))[dummy_key]
+        assert empty_item is None, "Non-existent keys shouldn't have values"
+        (recv_id, recv_peer_id) = next(iter(nodes_found_2.items()))
+        assert (
+            recv_id == peer1_node_id and recv_peer_id == peer1_id
+        ), f"expected id={peer1_node_id}, peer={peer1_id} but got {recv_id}, {recv_peer_id}"
+
+        # cause a non-response by querying a nonexistent peer
+        assert not await protocol.call_find(PeerID.from_base58("fakeid"), [key])
+
+        # store/get a dictionary with sub-keys
+        nested_key, subkey1, subkey2 = DHTID.generate(), "foo", "bar"
+        value1, value2 = [random.random(), {"ololo": "pyshpysh"}], "abacaba"
+        assert await protocol.call_store(
+            peer1_id,
+            keys=[nested_key],
+            values=[hivemind.MSGPackSerializer.dumps(value1)],
+            expiration_time=[expiration],
+            subkeys=[subkey1],
+        )
+        assert await protocol.call_store(
+            peer1_id,
+            keys=[nested_key],
+            values=[hivemind.MSGPackSerializer.dumps(value2)],
+            expiration_time=[expiration + 5],
+            subkeys=[subkey2],
+        )
+        (recv_dict, recv_expiration), nodes_found = (await protocol.call_find(peer1_id, [nested_key]))[nested_key]
+        assert isinstance(recv_dict, DictionaryDHTValue)
+        assert len(recv_dict.data) == 2 and recv_expiration == expiration + 5
+        assert recv_dict.data[subkey1] == (protocol.serializer.dumps(value1), expiration)
+        assert recv_dict.data[subkey2] == (protocol.serializer.dumps(value2), expiration + 5)
+
+        if not client_mode:
+            await p2p.shutdown()
+
+    peer1_proc.terminate()
+    peer2_proc.terminate()
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_empty_table():
+    """Test RPC methods with empty routing table"""
+    peer_id, peer_proc, peer_peer_id, peer_maddrs = launch_protocol_listener()
+
+    p2p = await P2P.create(initial_peers=peer_maddrs)
+    protocol = await DHTProtocol.create(
+        p2p, DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, client_mode=True
+    )
+
+    key, value, expiration = DHTID.generate(), [random.random(), {"ololo": "pyshpysh"}], get_dht_time() + 1e3
+
+    empty_item, nodes_found = (await protocol.call_find(peer_peer_id, [key]))[key]
+    assert empty_item is None and len(nodes_found) == 0
+    assert all(await protocol.call_store(peer_peer_id, [key], [hivemind.MSGPackSerializer.dumps(value)], expiration))
+
+    (recv_value_bytes, recv_expiration), nodes_found = (await protocol.call_find(peer_peer_id, [key]))[key]
+    recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
+    assert len(nodes_found) == 0
+    assert recv_value == value and recv_expiration == expiration
+
+    assert peer_id == await protocol.call_ping(peer_peer_id)
+    assert not await protocol.call_ping(PeerID.from_base58("fakeid"))
+    peer_proc.terminate()

+ 30 - 5
tests/test_p2p_daemon.py

@@ -10,7 +10,7 @@ import pytest
 from multiaddr import Multiaddr
 
 from hivemind.p2p import P2P, P2PDaemonError, P2PHandlerError
-from hivemind.proto import dht_pb2
+from hivemind.proto import dht_pb2, test_pb2
 from hivemind.utils.networking import get_free_port
 from hivemind.utils.serializer import MSGPackSerializer
 
@@ -36,13 +36,13 @@ async def test_daemon_killed_on_del():
 
 @pytest.mark.asyncio
 async def test_startup_error_message():
-    with pytest.raises(P2PDaemonError, match=r"Failed to connect to bootstrap peers"):
+    with pytest.raises(P2PDaemonError, match=r"(?i)Failed to connect to bootstrap peers"):
         await P2P.create(
             initial_peers=[f"/ip4/127.0.0.1/tcp/{get_free_port()}/p2p/QmdaK4LUeQaKhqSFPRu9N7MvXUEWDxWwtCvPrS444tCgd1"]
         )
 
     with pytest.raises(P2PDaemonError, match=r"Daemon failed to start in .+ seconds"):
-        await P2P.create(startup_timeout=0.1)  # Test that startup_timeout works
+        await P2P.create(startup_timeout=0.01)  # Test that startup_timeout works
 
 
 @pytest.mark.parametrize(
@@ -63,9 +63,9 @@ async def test_transports(host_maddrs: List[Multiaddr]):
     await client.wait_for_at_least_n_peers(1)
 
     peers = await client.list_peers()
-    assert len(peers) == 1
+    assert len({p.peer_id for p in peers}) == 1
     peers = await server.list_peers()
-    assert len(peers) == 1
+    assert len({p.peer_id for p in peers}) == 1
 
 
 @pytest.mark.asyncio
@@ -83,6 +83,31 @@ async def test_daemon_replica_does_not_affect_primary():
     assert not is_process_running(child_pid)
 
 
+@pytest.mark.asyncio
+async def test_unary_handler_edge_cases():
+    p2p = await P2P.create()
+    p2p_replica = await P2P.replicate(p2p.daemon_listen_maddr)
+
+    async def square_handler(data: test_pb2.TestRequest, context):
+        return test_pb2.TestResponse(number=data.number ** 2)
+
+    await p2p.add_protobuf_handler("square", square_handler, test_pb2.TestRequest)
+
+    # try adding a duplicate handler
+    with pytest.raises(P2PDaemonError):
+        await p2p.add_protobuf_handler("square", square_handler, test_pb2.TestRequest)
+
+    # try adding a duplicate handler from replicated p2p
+    with pytest.raises(P2PDaemonError):
+        await p2p_replica.add_protobuf_handler("square", square_handler, test_pb2.TestRequest)
+
+    # try dialing yourself
+    with pytest.raises(P2PDaemonError):
+        await p2p_replica.call_protobuf_handler(
+            p2p.peer_id, "square", test_pb2.TestRequest(number=41), test_pb2.TestResponse
+        )
+
+
 @pytest.mark.parametrize(
     "should_cancel,replicate",
     [

+ 21 - 13
tests/test_p2p_daemon_bindings.py

@@ -18,7 +18,7 @@ from hivemind.p2p.p2p_daemon_bindings.utils import (
 )
 from hivemind.proto import p2pd_pb2 as p2pd_pb
 
-from test_utils.p2p_daemon import connect_safe, make_p2pd_pair_ip4
+from test_utils.p2p_daemon import connect_safe, make_p2pd_pair_unix
 
 
 def test_raise_if_failed_raises():
@@ -61,7 +61,7 @@ ENABLE_CONTROL = True
 ENABLE_CONNMGR = False
 ENABLE_DHT = False
 ENABLE_PUBSUB = False
-FUNC_MAKE_P2PD_PAIR = make_p2pd_pair_ip4
+FUNC_MAKE_P2PD_PAIR = make_p2pd_pair_unix
 
 
 class MockReader(io.BytesIO):
@@ -144,6 +144,12 @@ def test_peer_id():
     peer_id_3 = PeerID.from_base58("QmbmfNDEth7Ucvjuxiw3SP3E4PoJzbk7g4Ge6ZDigbCsNp")
     assert PEER_ID != peer_id_3
 
+    a = PeerID.from_base58("bob")
+    b = PeerID.from_base58("eve")
+    assert a < b and b > a and not (b < a) and not (a > b)
+    with pytest.raises(TypeError):
+        assert a < object()
+
 
 def test_stream_info():
     proto = "123"
@@ -193,24 +199,31 @@ def test_parse_conn_protocol_invalid(maddr_str):
 
 
 @pytest.mark.parametrize("control_maddr_str", ("/unix/123", "/ip4/127.0.0.1/tcp/6666"))
-def test_client_ctor_control_maddr(control_maddr_str):
+@pytest.mark.asyncio
+async def test_client_create_control_maddr(control_maddr_str):
     c = DaemonConnector(Multiaddr(control_maddr_str))
     assert c.control_maddr == Multiaddr(control_maddr_str)
 
 
-def test_client_ctor_default_control_maddr():
+def test_client_create_default_control_maddr():
     c = DaemonConnector()
     assert c.control_maddr == Multiaddr(DaemonConnector.DEFAULT_CONTROL_MADDR)
 
 
 @pytest.mark.parametrize("listen_maddr_str", ("/unix/123", "/ip4/127.0.0.1/tcp/6666"))
-def test_control_client_ctor_listen_maddr(listen_maddr_str):
-    c = ControlClient(daemon_connector=DaemonConnector(), listen_maddr=Multiaddr(listen_maddr_str))
+@pytest.mark.asyncio
+async def test_control_client_create_listen_maddr(listen_maddr_str):
+    c = await ControlClient.create(
+        daemon_connector=DaemonConnector(),
+        listen_maddr=Multiaddr(listen_maddr_str),
+        use_persistent_conn=False,
+    )
     assert c.listen_maddr == Multiaddr(listen_maddr_str)
 
 
-def test_control_client_ctor_default_listen_maddr():
-    c = ControlClient(daemon_connector=DaemonConnector())
+@pytest.mark.asyncio
+async def test_control_client_create_default_listen_maddr():
+    c = await ControlClient.create(daemon_connector=DaemonConnector(), use_persistent_conn=False)
     assert c.listen_maddr == Multiaddr(ControlClient.DEFAULT_LISTEN_MADDR)
 
 
@@ -376,11 +389,6 @@ async def p2pcs():
         yield tuple(p2pd_tuple.client for p2pd_tuple in p2pd_tuples)
 
 
-@pytest.mark.asyncio
-async def test_client_identify_unix_socket(p2pcs):
-    await p2pcs[0].identify()
-
-
 @pytest.mark.asyncio
 async def test_client_identify(p2pcs):
     await p2pcs[0].identify()

+ 3 - 2
tests/test_p2p_servicer.py

@@ -5,6 +5,7 @@ import pytest
 
 from hivemind.p2p import P2P, P2PContext, ServicerBase
 from hivemind.proto import test_pb2
+from hivemind.utils.asyncio import anext
 
 
 @pytest.fixture
@@ -139,9 +140,9 @@ async def test_unary_stream_cancel(server_client, cancel_reason):
         writer.close()
     elif cancel_reason == "close_generator":
         stub = ExampleServicer.get_stub(client, server.peer_id)
-        iter = stub.rpc_wait(test_pb2.TestRequest(number=10)).__aiter__()
+        iter = stub.rpc_wait(test_pb2.TestRequest(number=10))
 
-        assert await iter.__anext__() == test_pb2.TestResponse(number=11)
+        assert await anext(iter) == test_pb2.TestResponse(number=11)
         await asyncio.sleep(0.25)
 
         await iter.aclose()

+ 54 - 1
tests/test_util_modules.py

@@ -13,7 +13,17 @@ from hivemind.proto.dht_pb2_grpc import DHTStub
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
 from hivemind.utils import DHTExpiration, HeapEntry, MSGPackSerializer, ValueWithExpiration
-from hivemind.utils.asyncio import achain, aenumerate, aiter, amap_in_executor, anext, azip
+from hivemind.utils.asyncio import (
+    achain,
+    aenumerate,
+    afirst,
+    aiter,
+    amap_in_executor,
+    anext,
+    asingle,
+    azip,
+    cancel_and_wait,
+)
 from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.mpfuture import InvalidStateError
 
@@ -498,3 +508,46 @@ async def test_asyncio_utils():
         await anext(iterator)
 
     assert [item async for item in achain(_aiterate(), aiter(*range(5)))] == ["foo", "bar", "baz"] + list(range(5))
+
+    assert await asingle(aiter(1)) == 1
+    with pytest.raises(ValueError):
+        await asingle(aiter())
+    with pytest.raises(ValueError):
+        await asingle(aiter(1, 2, 3))
+
+    assert await afirst(aiter(1)) == 1
+    assert await afirst(aiter()) is None
+    assert await afirst(aiter(), -1) == -1
+    assert await afirst(aiter(1, 2, 3)) == 1
+
+
+@pytest.mark.asyncio
+async def test_cancel_and_wait():
+    finished_gracefully = False
+
+    async def coro_with_finalizer():
+        nonlocal finished_gracefully
+
+        try:
+            await asyncio.Event().wait()
+        except asyncio.CancelledError:
+            await asyncio.sleep(0.05)
+            finished_gracefully = True
+            raise
+
+    task = asyncio.create_task(coro_with_finalizer())
+    await asyncio.sleep(0.05)
+    assert await cancel_and_wait(task)
+    assert finished_gracefully
+
+    async def coro_with_result():
+        return 777
+
+    async def coro_with_error():
+        raise ValueError("error")
+
+    task_with_result = asyncio.create_task(coro_with_result())
+    task_with_error = asyncio.create_task(coro_with_error())
+    await asyncio.sleep(0.05)
+    assert not await cancel_and_wait(task_with_result)
+    assert not await cancel_and_wait(task_with_error)

+ 17 - 20
tests/test_utils/p2p_daemon.py

@@ -4,7 +4,7 @@ import os
 import subprocess
 import time
 import uuid
-from contextlib import asynccontextmanager
+from contextlib import asynccontextmanager, suppress
 from typing import NamedTuple
 
 from multiaddr import Multiaddr, protocols
@@ -57,7 +57,7 @@ class Daemon:
 
     def _run(self):
         cmd_list = [P2PD_PATH, f"-listen={str(self.control_maddr)}"]
-        cmd_list += [f"-hostAddrs=/ip4/127.0.0.1/tcp/{get_free_port()}"]
+        cmd_list += ["-hostAddrs=/ip4/127.0.0.1/tcp/0"]
         if self.enable_connmgr:
             cmd_list += ["-connManager=true", "-connLo=1", "-connHi=2", "-connGrace=0"]
         if self.enable_dht:
@@ -107,24 +107,21 @@ async def make_p2pd_pair_unix(enable_control, enable_connmgr, enable_dht, enable
     name = str(uuid.uuid4())[:8]
     control_maddr = Multiaddr(f"/unix/tmp/test_p2pd_control_{name}.sock")
     listen_maddr = Multiaddr(f"/unix/tmp/test_p2pd_listen_{name}.sock")
-    # Remove the existing unix socket files if they are existing
     try:
-        os.unlink(control_maddr.value_for_protocol(protocols.P_UNIX))
-    except FileNotFoundError:
-        pass
-    try:
-        os.unlink(listen_maddr.value_for_protocol(protocols.P_UNIX))
-    except FileNotFoundError:
-        pass
-    async with _make_p2pd_pair(
-        control_maddr=control_maddr,
-        listen_maddr=listen_maddr,
-        enable_control=enable_control,
-        enable_connmgr=enable_connmgr,
-        enable_dht=enable_dht,
-        enable_pubsub=enable_pubsub,
-    ) as pair:
-        yield pair
+        async with _make_p2pd_pair(
+            control_maddr=control_maddr,
+            listen_maddr=listen_maddr,
+            enable_control=enable_control,
+            enable_connmgr=enable_connmgr,
+            enable_dht=enable_dht,
+            enable_pubsub=enable_pubsub,
+        ) as pair:
+            yield pair
+    finally:
+        with suppress(FileNotFoundError):
+            os.unlink(control_maddr.value_for_protocol(protocols.P_UNIX))
+        with suppress(FileNotFoundError):
+            os.unlink(listen_maddr.value_for_protocol(protocols.P_UNIX))
 
 
 @asynccontextmanager
@@ -160,7 +157,7 @@ async def _make_p2pd_pair(
     )
     # wait for daemon ready
     await p2pd.wait_until_ready()
-    client = Client(control_maddr=control_maddr, listen_maddr=listen_maddr)
+    client = await Client.create(control_maddr=control_maddr, listen_maddr=listen_maddr)
     try:
         async with client.listen():
             yield DaemonTuple(daemon=p2pd, client=client)