
Merge branch 'master' of github.com:learning-at-home/hivemind into server-p2p

Denis Mazur, 4 years ago
commit 4739d33cfc

+ 0 - 2
.github/workflows/run-tests.yml

@@ -34,7 +34,6 @@ jobs:
        run: |
          cd tests
          pytest --durations=0 --durations-min=1.0 -v
-
  build_and_test_p2pd:
    runs-on: ubuntu-latest
    timeout-minutes: 10
@@ -61,7 +60,6 @@ jobs:
        run: |
          cd tests
          pytest -k "p2p" -v
-
  codecov_in_develop_mode:

    runs-on: ubuntu-latest

+ 1 - 0
.gitignore

@@ -54,6 +54,7 @@ coverage.xml
 .project
 .pydevproject
 .idea
+.vscode
 .ipynb_checkpoints

 # Rope

+ 10 - 4
examples/albert/arguments.py

@@ -13,7 +13,7 @@ class BaseTrainingArguments:
         default_factory=list,
         metadata={
             "help": "Multiaddrs of the peers that will welcome you into the existing collaboration. "
-            "Example: /ip4/203.0.113.1/tcp/31337/p2p/XXXX /ip4/203.0.113.2/udp/7777/quic/p2p/YYYY"
+            "Example: /ip4/203.0.113.1/tcp/31337/p2p/XXXX /ip4/203.0.113.2/tcp/7777/p2p/YYYY"
         },
     )
     use_ipfs: bool = field(
@@ -24,17 +24,23 @@ class BaseTrainingArguments:
         },
     )
     host_maddrs: List[str] = field(
-        default_factory=lambda: ["/ip4/0.0.0.0/tcp/0", "/ip4/0.0.0.0/udp/0/quic"],
+        default_factory=lambda: ["/ip4/0.0.0.0/tcp/0"],
         metadata={
             "help": "Multiaddrs to listen for external connections from other p2p instances. "
-            "Defaults to all IPv4 interfaces with TCP and QUIC (over UDP) protocols: "
-            "/ip4/0.0.0.0/tcp/0 /ip4/0.0.0.0/udp/0/quic"
+            "Defaults to all IPv4 interfaces and the TCP protocol: /ip4/0.0.0.0/tcp/0"
         },
     )
     announce_maddrs: List[str] = field(
         default_factory=list,
         metadata={"help": "Visible multiaddrs the host announces for external connections from other p2p instances"},
     )
+    identity_path: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Path to a pre-generated private key file. If defined, makes the peer ID deterministic. "
+            "May be generated using ``./p2p-keygen`` from ``go-libp2p-daemon``."
+        },
+    )


 @dataclass

+ 1 - 0
examples/albert/run_trainer.py

@@ -248,6 +248,7 @@ def main():
         use_ipfs=collaboration_args.use_ipfs,
         host_maddrs=collaboration_args.host_maddrs,
         announce_maddrs=collaboration_args.announce_maddrs,
+        identity_path=collaboration_args.identity_path,
     )
     utils.log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=collaboration_args.use_ipfs)


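A minimal usage sketch of the new identity_path argument (the key file name below is illustrative): reusing one pre-generated key keeps the trainer's peer ID stable across restarts.

    import hivemind

    # Reusing a key generated beforehand (e.g. with p2p-keygen from go-libp2p-daemon)
    # makes the resulting peer ID deterministic across runs.
    dht = hivemind.DHT(
        host_maddrs=["/ip4/0.0.0.0/tcp/0"],
        identity_path="p2p.key",  # illustrative path to the private key file
        start=True,
    )
    print(dht.peer_id)  # stays the same for as long as p2p.key is reused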
+ 2 - 1
examples/albert/run_training_monitor.py

@@ -156,7 +156,7 @@ if __name__ == "__main__":
         address = request.text
         logger.info(f"Received public IP address of this machine: {address}")
         version = ip_address(address).version
-        monitor_args.announce_maddrs += [f"/ip{version}/{address}/tcp/0", f"/ip{version}/{address}/udp/0/quic"]
+        monitor_args.announce_maddrs += [f"/ip{version}/{address}/tcp/0"]

     experiment_prefix = monitor_args.experiment_prefix
     validators, local_public_key = utils.make_validators(experiment_prefix)
@@ -168,6 +168,7 @@ if __name__ == "__main__":
         use_ipfs=monitor_args.use_ipfs,
         host_maddrs=monitor_args.host_maddrs,
         announce_maddrs=monitor_args.announce_maddrs,
+        identity_path=monitor_args.identity_path,
     )
     utils.log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=monitor_args.use_ipfs)


+ 1 - 1
hivemind/__init__.py

@@ -19,4 +19,4 @@ from hivemind.optim import (
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, PeerInfo
 from hivemind.utils import *

-__version__ = "1.0.0.dev0"
+__version__ = "1.0.0dev0"

+ 3 - 3
hivemind/averaging/allreduce.py

@@ -8,7 +8,7 @@ from hivemind.averaging.partition import AllreduceException, TensorPartContainer
 from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase, StubBase
 from hivemind.proto import averaging_pb2
 from hivemind.utils import get_logger
-from hivemind.utils.asyncio import achain, aenumerate, aiter, amap_in_executor, anext, asingle
+from hivemind.utils.asyncio import achain, aenumerate, afirst, aiter, amap_in_executor, anext
 from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor

 # flavour types
@@ -231,8 +231,8 @@ class AllReduceRunner(ServicerBase):

     async def _send_error_to_peer(self, peer_id: PeerID, code: averaging_pb2.MessageCode):
         error = averaging_pb2.AveragingData(group_id=self.group_id, code=code)
-        # In case of reporting the error, we expect the response stream to contain exactly one item
-        await asingle(self._get_peer_stub(peer_id).rpc_aggregate_part(aiter(error)))
+        # Coroutines are lazy, so we take the first item to start the coroutine's execution
+        await afirst(self._get_peer_stub(peer_id).rpc_aggregate_part(aiter(error)))

     def finalize(self, *, cancel: bool = False, exception: Optional[BaseException] = None):
         """finish or terminate AllReduceRunner, propagate any errors / cancellations to peers."""

+ 9 - 4
hivemind/averaging/averager.py

@@ -96,7 +96,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         prefix: str,
         target_group_size: int,
         min_group_size: int = 2,
-        initial_group_bits: Optional[str] = None,
+        initial_group_bits: str = "",
         averaging_expiration: float = 15,
         request_timeout: float = 3,
         averaging_alpha: float = 1.0,
@@ -117,7 +117,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         ), "bandwidth must be a non-negative float32"
         if not is_power_of_two(target_group_size):
             logger.warning("It is recommended to set target_group_size to a power of 2.")
-        assert initial_group_bits is None or all(bit in "01" for bit in initial_group_bits)
+        assert all(bit in "01" for bit in initial_group_bits)
         assert not client_mode or not auxiliary, "auxiliary peers must accept incoming connections"

         super().__init__()
@@ -241,7 +241,12 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                 self._ready.set_result(None)

                 while True:
-                    method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
+                    try:
+                        method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
+                    except (OSError, ConnectionError) as e:
+                        logger.exception(e)
+                        await asyncio.sleep(self._matchmaking.request_timeout)
+                        continue
                     task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
                     if method == "_shutdown":
                         await task
@@ -593,7 +598,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                         future.set_result((metadata, tensors))
                         self.last_updated = get_dht_time()
                         return
-                    except BaseException as e:
+                    except Exception as e:
                         logger.exception(f"Failed to download state from {peer} - {repr(e)}")

         finally:

+ 12 - 93
hivemind/averaging/key_manager.py

@@ -1,4 +1,3 @@
-import asyncio
 import random
 import re
 from typing import List, Optional, Tuple
@@ -25,31 +24,17 @@ class GroupKeyManager:
     Utility class that declares and fetches averaging-related keys using a DHT
     """

-    RESERVED_KEY_FOR_NBITS = "::NBITS"
-
     def __init__(
         self,
         dht: DHT,
         prefix: str,
-        initial_group_bits: Optional[str],
+        initial_group_bits: str,
         target_group_size: int,
-        insufficient_size: Optional[int] = None,
-        excessive_size: Optional[int] = None,
-        nbits_expiration: float = 60,
-        nbits_rewrite_grace_period: float = 15,
     ):
-        assert initial_group_bits is None or all(bit in "01" for bit in initial_group_bits)
-        if initial_group_bits is None:
-            search_result = dht.get(f"{prefix}.0b", latest=True)
-            initial_group_nbits = self.get_suggested_nbits(search_result) or 0
-            initial_group_bits = "".join(random.choice("01") for _ in range(initial_group_nbits))
+        assert all(bit in "01" for bit in initial_group_bits)
         self.dht, self.prefix, self.group_bits = dht, prefix, initial_group_bits
-        self.peer_id = dht.peer_id
         self.target_group_size = target_group_size
-        self.insufficient_size = insufficient_size or max(1, target_group_size // 2)
-        self.excessive_size = excessive_size or target_group_size * 3
-        self.nbits_expiration, self.nbits_grace_period = nbits_expiration, nbits_rewrite_grace_period
-        self.suggested_nbits: Optional[int] = None
+        self.peer_id = dht.peer_id

     @property
     def current_key(self) -> GroupKey:
@@ -93,51 +78,16 @@ class GroupKeyManager:
         if result is None or not isinstance(result.value, dict):
             logger.debug(f"Allreduce group not found: {group_key}, creating new group.")
             return []
-        averagers = [
-            (PeerID(key), looking_for_group.expiration_time)
-            for key, looking_for_group in result.value.items()
-            if key != self.RESERVED_KEY_FOR_NBITS and (not only_active or looking_for_group.value)
-        ]
-        num_active_averagers = sum(
-            1
-            for key, looking_for_group in result.value.items()
-            if key != self.RESERVED_KEY_FOR_NBITS and looking_for_group.value
-        )
-
-        suggested_nbits = self.get_suggested_nbits(result)
-        if (
-            suggested_nbits is not None
-            and suggested_nbits != len(self.group_bits)
-            and suggested_nbits != self.suggested_nbits
-        ):
-            self.suggested_nbits = suggested_nbits
-            logger.warning(f"{self.peer_id} - another averager suggested {self.suggested_nbits}-bit keys")
-        elif num_active_averagers >= self.excessive_size:
-            self.suggested_nbits = max(suggested_nbits or 0, len(self.group_bits) + 1)
-            logger.warning(f"{self.peer_id} - too many peers in bucket, switching to {self.suggested_nbits}-bit keys")
+        averagers = []
+        for key, looking_for_group in result.value.items():
+            try:
+                if only_active and not looking_for_group.value:
+                    continue
+                averagers.append((PeerID(key), looking_for_group.expiration_time))
+            except Exception as e:
+                logger.warning(f"Could not parse group key {key} ({looking_for_group}, exc={e})")
         return averagers

-    async def declare_nbits(self, group_key: GroupKey, nbits: int, expiration_time: DHTExpiration) -> bool:
-        """notify other peers that they can run averaging at this depth"""
-        return await self.dht.store(
-            key=group_key,
-            subkey=self.RESERVED_KEY_FOR_NBITS,
-            value=nbits,
-            expiration_time=expiration_time,
-            return_future=True,
-        )
-
-    @classmethod
-    def get_suggested_nbits(cls, search_result: Optional[ValueWithExpiration]) -> Optional[int]:
-        if (
-            isinstance(search_result, ValueWithExpiration)
-            and cls.RESERVED_KEY_FOR_NBITS in search_result.value
-            and isinstance(search_result.value[cls.RESERVED_KEY_FOR_NBITS].value, int)
-        ):
-            return search_result.value[cls.RESERVED_KEY_FOR_NBITS].value
-        else:
-            return None
-
     async def update_key_on_group_assembled(self, group_info: GroupInfo, is_leader: bool = True):
         """this function is triggered every time an averager finds an allreduce group"""
         rng = random.Random(group_info.group_id)
@@ -148,37 +98,6 @@ class GroupKeyManager:
         self.group_bits = (self.group_bits + new_bits)[-len(self.group_bits) :] if self.group_bits else ""
         logger.debug(f"{self.peer_id} - updated group key to {self.group_bits}")

-        if is_leader and self.insufficient_size < group_info.group_size < self.excessive_size:
-            asyncio.create_task(self.notify_stragglers())
-        if self.suggested_nbits is not None and self.suggested_nbits != len(self.group_bits):
-            num_extra_bits = max(0, self.suggested_nbits - len(self.group_bits))
-            self.group_bits = "".join((random.choice("01") for _ in range(num_extra_bits))) + self.group_bits
-            self.group_bits = self.group_bits[-self.suggested_nbits :]
-        self.suggested_nbits = None
-
     async def update_key_on_not_enough_peers(self):
         """this function is triggered whenever averager fails to assemble group within timeout"""
-        new_nbits = self.suggested_nbits if self.suggested_nbits is not None else len(self.group_bits) - 1
-        prev_nbits, self.group_bits = self.group_bits, self.group_bits[-new_nbits:] if new_nbits else ""
-        if self.group_bits != prev_nbits:
-            logger.warning(f"{self.peer_id} - switching to {len(self.group_bits)}-bit keys")
-        self.suggested_nbits = None
-
-    async def notify_stragglers(self):
-        """Find averagers that have fewer nbits and redirect them to your current nbits"""
-        for nbits in reversed(range(1, len(self.group_bits) - 1)):
-            preceding_key = f"{self.prefix}.0b{self.group_bits[-nbits:] if nbits else ''}"
-            preceding_data, _ = await self.dht.get(preceding_key, latest=False, return_future=True) or ({}, None)
-
-            if len(preceding_data) > 0 and self.RESERVED_KEY_FOR_NBITS not in preceding_data:
-                await self.declare_nbits(preceding_key, len(self.group_bits), get_dht_time() + self.nbits_expiration)
-                break
-
-        root_data, _ = await self.dht.get(f"{self.prefix}.0b", latest=False, return_future=True) or ({}, None)
-        if (
-            isinstance(root_data, dict)
-            and root_data.get(self.RESERVED_KEY_FOR_NBITS, (None, -float("inf")))[1]
-            > get_dht_time() + self.nbits_grace_period
-        ):
-            return
-        await self.declare_nbits(f"{self.prefix}.0b", len(self.group_bits), get_dht_time() + self.nbits_expiration)
+        pass  # to be implemented in subclasses

+ 31 - 43
hivemind/averaging/matchmaking.py

@@ -15,7 +15,7 @@ from hivemind.dht import DHT, DHTID, DHTExpiration
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
 from hivemind.proto import averaging_pb2
 from hivemind.utils import TimedStorage, get_dht_time, get_logger, timed_storage
-from hivemind.utils.asyncio import anext
+from hivemind.utils.asyncio import anext, cancel_and_wait

 logger = get_logger(__name__)

@@ -45,7 +45,7 @@ class Matchmaking:
         min_group_size: int,
         request_timeout: float,
         client_mode: bool,
-        initial_group_bits: Optional[str] = None,
+        initial_group_bits: str = "",
         averaging_expiration: float = 15,
     ):
         assert "." not in prefix, "group prefix must be a string without ."
@@ -127,10 +127,9 @@ class Matchmaking:
                 raise

             finally:
-                if not request_leaders_task.done():
-                    request_leaders_task.cancel()
-                if not self.assembled_group.done():
-                    self.assembled_group.cancel()
+                await cancel_and_wait(request_leaders_task)
+                self.assembled_group.cancel()
+
                 while len(self.current_followers) > 0:
                     await self.follower_was_discarded.wait()
                     self.follower_was_discarded.clear()
@@ -189,7 +188,7 @@ class Matchmaking:
                         gather=self.data_for_gather,
                         group_key=self.group_key_manager.current_key,
                     )
-                ).__aiter__()
+                )
                 message = await asyncio.wait_for(anext(stream), timeout=self.request_timeout)

                 if message.code == averaging_pb2.ACCEPTED:
@@ -229,7 +228,7 @@ class Matchmaking:
             logger.debug(f"{self} - potential leader {leader} did not respond within {self.request_timeout}")
             return None
         except (P2PHandlerError, StopAsyncIteration) as e:
-            logger.error(f"{self} - failed to request potential leader {leader}: {e}")
+            logger.exception(f"{self} - failed to request potential leader {leader}:")
             return None

         finally:
@@ -413,10 +412,9 @@ class PotentialLeaders:
             try:
                 yield self
             finally:
-                if not update_queue_task.done():
-                    update_queue_task.cancel()
-                if declare and not declare_averager_task.done():
-                    declare_averager_task.cancel()
+                await cancel_and_wait(update_queue_task)
+                if declare:
+                    await cancel_and_wait(declare_averager_task)

                 for field in (
                     self.past_attempts,
@@ -477,37 +475,31 @@ class PotentialLeaders:
         else:
             return min(get_dht_time() + self.averaging_expiration, self.search_end_time)

-    async def _update_queue_periodically(self, key_manager: GroupKeyManager):
-        try:
-            DISCREPANCY = timed_storage.MAX_DHT_TIME_DISCREPANCY_SECONDS
-            while get_dht_time() < self.search_end_time:
-                new_peers = await key_manager.get_averagers(key_manager.current_key, only_active=True)
-                self.max_assured_time = max(
-                    self.max_assured_time, get_dht_time() + self.averaging_expiration - DISCREPANCY
-                )
+    async def _update_queue_periodically(self, key_manager: GroupKeyManager) -> None:
+        DISCREPANCY = timed_storage.MAX_DHT_TIME_DISCREPANCY_SECONDS
+        while get_dht_time() < self.search_end_time:
+            new_peers = await key_manager.get_averagers(key_manager.current_key, only_active=True)
+            self.max_assured_time = max(
+                self.max_assured_time, get_dht_time() + self.averaging_expiration - DISCREPANCY
+            )

-                self.leader_queue.clear()
-                for peer, peer_expiration_time in new_peers:
-                    if peer == self.peer_id or (peer, peer_expiration_time) in self.past_attempts:
-                        continue
-                    self.leader_queue.store(peer, peer_expiration_time, peer_expiration_time)
-                    self.max_assured_time = max(self.max_assured_time, peer_expiration_time - DISCREPANCY)
+            self.leader_queue.clear()
+            for peer, peer_expiration_time in new_peers:
+                if peer == self.peer_id or (peer, peer_expiration_time) in self.past_attempts:
+                    continue
+                self.leader_queue.store(peer, peer_expiration_time, peer_expiration_time)
+                self.max_assured_time = max(self.max_assured_time, peer_expiration_time - DISCREPANCY)

-                self.update_finished.set()
+            self.update_finished.set()

-                await asyncio.wait(
-                    {self.running.wait(), self.update_triggered.wait()},
-                    return_when=asyncio.ALL_COMPLETED,
-                    timeout=self.search_end_time - get_dht_time() if isfinite(self.search_end_time) else None,
-                )
-                self.update_triggered.clear()
-        except (concurrent.futures.CancelledError, asyncio.CancelledError):
-            return  # note: this is a compatibility layer for python3.7
-        except Exception as e:
-            logger.error(f"{self.peer_id} - caught {type(e)}: {e}")
-            raise
+            await asyncio.wait(
+                {self.running.wait(), self.update_triggered.wait()},
+                return_when=asyncio.ALL_COMPLETED,
+                timeout=self.search_end_time - get_dht_time() if isfinite(self.search_end_time) else None,
+            )
+            self.update_triggered.clear()

-    async def _declare_averager_periodically(self, key_manager: GroupKeyManager):
+    async def _declare_averager_periodically(self, key_manager: GroupKeyManager) -> None:
         async with self.lock_declare:
             try:
                 while True:
@@ -521,10 +513,6 @@ class PotentialLeaders:
                     await asyncio.sleep(self.declared_expiration_time - get_dht_time())
                     if self.running.is_set() and len(self.leader_queue) == 0:
                         await key_manager.update_key_on_not_enough_peers()
-            except (concurrent.futures.CancelledError, asyncio.CancelledError):
-                pass  # note: this is a compatibility layer for python3.7
-            except Exception as e:  # note: we catch exceptions here because otherwise they are never printed
-                logger.error(f"{self.peer_id} - caught {type(e)}: {e}")
             finally:
                 if self.declared_group_key is not None:
                     prev_declared_key, prev_expiration_time = self.declared_group_key, self.declared_expiration_time

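The cancel_and_wait helper imported above replaces the manual done()/cancel() checks: it cancels a task and then waits for the cancellation to take effect, so the surrounding finally blocks do not exit while the task is still unwinding. A rough sketch of that behavior (not the actual hivemind.utils.asyncio implementation):

    import asyncio

    async def cancel_and_wait(task: asyncio.Task) -> bool:
        # Request cancellation, then wait until the task has actually finished,
        # swallowing the resulting CancelledError; returns True if it was cancelled.
        task.cancel()
        try:
            await task
            return False
        except asyncio.CancelledError:
            return True
        except Exception:
            return False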
+ 19 - 12
hivemind/dht/__init__.py

@@ -27,7 +27,7 @@ from hivemind.dht.node import DEFAULT_NUM_WORKERS, DHTNode
 from hivemind.dht.routing import DHTID, DHTKey, DHTValue, Subkey
 from hivemind.dht.validation import CompositeValidator, RecordValidatorBase
 from hivemind.p2p import P2P, PeerID
-from hivemind.utils import DHTExpiration, MPFuture, ValueWithExpiration, await_cancelled, get_logger, switch_to_uvloop
+from hivemind.utils import DHTExpiration, MPFuture, ValueWithExpiration, get_logger, switch_to_uvloop

 logger = get_logger(__name__)

@@ -61,6 +61,7 @@ class DHT(mp.Process):
         initial_peers: Optional[Sequence[Union[Multiaddr, str]]] = None,
         *,
         start: bool,
+        p2p: Optional[P2P] = None,
         daemon: bool = True,
         num_workers: int = DEFAULT_NUM_WORKERS,
         record_validators: Iterable[RecordValidatorBase] = (),
@@ -94,6 +95,8 @@ class DHT(mp.Process):
         self._client_mode = None
         self._p2p_replica = None

+        self._daemon_listen_maddr = p2p.daemon_listen_maddr if p2p is not None else None
+
         if start:
             self.run_in_background(await_ready=await_ready)

@@ -105,10 +108,16 @@ class DHT(mp.Process):

             async def _run():
                 try:
+                    if self._daemon_listen_maddr is not None:
+                        replicated_p2p = await P2P.replicate(self._daemon_listen_maddr)
+                    else:
+                        replicated_p2p = None
+
                     self._node = await DHTNode.create(
                         initial_peers=self.initial_peers,
                         num_workers=self.num_workers,
                         record_validator=self._record_validator,
+                        p2p=replicated_p2p,
                         **self.kwargs,
                     )
                 except Exception as e:
@@ -119,7 +128,12 @@ class DHT(mp.Process):
                 self._ready.set_result(None)

                 while True:
-                    method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
+                    try:
+                        method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
+                    except (OSError, ConnectionError) as e:
+                        logger.exception(e)
+                        await asyncio.sleep(self._node.protocol.wait_timeout)
+                        continue
                     task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
                     if method == "_shutdown":
                         await task
@@ -247,18 +261,11 @@ class DHT(mp.Process):
     async def _run_coroutine(
         self, coro: Callable[[DHT, DHTNode], Awaitable[ReturnType]], future: MPFuture[ReturnType]
     ):
-        main_task = asyncio.create_task(coro(self, self._node))
-        cancel_task = asyncio.create_task(await_cancelled(future))
         try:
-            await asyncio.wait({main_task, cancel_task}, return_when=asyncio.FIRST_COMPLETED)
-            if future.cancelled():
-                main_task.cancel()
-            else:
-                future.set_result(await main_task)
+            future.set_result(await coro(self, self._node))
         except BaseException as e:
-            logger.exception(f"Caught an exception when running a coroutine: {e}")
-            if not future.done():
-                future.set_exception(e)
+            logger.exception("Caught an exception when running a coroutine:")
+            future.set_exception(e)

     def add_validators(self, record_validators: Iterable[RecordValidatorBase]) -> None:
         if not self._ready.done():

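A minimal sketch of the new p2p argument in use (addresses are illustrative): start one p2pd daemon explicitly and let the DHT replicate it instead of spawning a second daemon.

    import asyncio
    import hivemind
    from hivemind.p2p import P2P

    async def main():
        p2p = await P2P.create(host_maddrs=["/ip4/127.0.0.1/tcp/0"])
        # The DHT process calls P2P.replicate() on the same daemon instead of
        # launching its own p2pd instance.
        dht = hivemind.DHT(p2p=p2p, start=True)
        print(dht.get_visible_maddrs())
        dht.shutdown()
        await p2p.shutdown()

    asyncio.run(main())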
+ 9 - 4
hivemind/dht/node.py

@@ -121,7 +121,7 @@ class DHTNode:
         client_mode: bool = False,
         record_validator: Optional[RecordValidatorBase] = None,
         authorizer: Optional[AuthorizerBase] = None,
-        validate: bool = True,
+        ensure_bootstrap_success: bool = True,
         strict: bool = True,
         **kwargs,
     ) -> DHTNode:
@@ -156,7 +156,8 @@ class DHTNode:
         :param chunk_size: maximum number of concurrent calls in get_many and cache refresh queue
         :param blacklist_time: excludes non-responsive peers from search for this many seconds (set 0 to disable)
         :param backoff_rate: blacklist time will be multiplied by :backoff_rate: for each successive non-response
-        :param validate: if True, use initial peers to validate that this node is accessible and synchronized
+        :param ensure_bootstrap_success: raise an error if node could not connect to initial peers (or vice versa)
+           If False, print a warning instead. It is recommended to keep this flag unless you know what you're doing.
         :param strict: if True, any error encountered in validation will interrupt the creation of DHTNode
         :param client_mode: if False (default), this node will accept incoming requests as a full DHT "citizen"
           if True, this node will refuse any incoming requests, effectively being only a client
@@ -220,7 +221,7 @@ class DHTNode:
             bootstrap_timeout = bootstrap_timeout if bootstrap_timeout is not None else wait_timeout
             start_time = get_dht_time()
             ping_tasks = set(
-                asyncio.create_task(self.protocol.call_ping(peer, validate=validate, strict=strict))
+                asyncio.create_task(self.protocol.call_ping(peer, validate=ensure_bootstrap_success, strict=strict))
                 for peer in initial_peers
             )
             finished_pings, unfinished_pings = await asyncio.wait(ping_tasks, return_when=asyncio.FIRST_COMPLETED)
@@ -235,7 +236,11 @@ class DHTNode:
                 finished_pings |= finished_in_time

             if not finished_pings or all(ping.result() is None for ping in finished_pings):
-                logger.warning("DHTNode bootstrap failed: none of the initial_peers responded to a ping.")
+                message = "DHTNode bootstrap failed: none of the initial_peers responded to a ping."
+                if ensure_bootstrap_success:
+                    raise RuntimeError(f"{message} (set ensure_bootstrap_success=False to ignore)")
+                else:
+                    logger.warning(message)

             if strict:
                 for task in asyncio.as_completed(finished_pings):

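A short sketch of the renamed flag (the bootstrap multiaddr is a placeholder): leaving ensure_bootstrap_success at its default now raises if no initial peer answers, while passing False restores the old warning-only behavior.

    import asyncio
    from hivemind.dht.node import DHTNode

    async def main():
        node = await DHTNode.create(
            initial_peers=["/ip4/203.0.113.1/tcp/31337/p2p/XXXX"],  # placeholder peer
            ensure_bootstrap_success=False,  # log a warning instead of raising RuntimeError
        )
        await node.shutdown()

    asyncio.run(main())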
+ 1 - 1
hivemind/dht/protocol.py

@@ -81,7 +81,7 @@ class DHTProtocol(ServicerBase):

     def __init__(self, *, _initialized_with_create=False):
         """Internal init method. Please use DHTProtocol.create coroutine to spawn new protocol instances"""
-        assert _initialized_with_create, " Please use DHTProtocol.create coroutine to spawn new protocol instances "
+        assert _initialized_with_create, "Please use DHTProtocol.create coroutine to spawn new protocol instances"
         super().__init__()

     def get_stub(self, peer: PeerID) -> AuthRPCWrapper:

+ 108 - 48
hivemind/p2p/p2p_daemon.py

@@ -5,12 +5,14 @@ from collections.abc import AsyncIterable as AsyncIterableABC
 from contextlib import closing, suppress
 from dataclasses import dataclass
 from importlib.resources import path
-from typing import Any, AsyncIterator, Awaitable, Callable, List, Optional, Sequence, Tuple, TypeVar, Union
+from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional, Sequence, Tuple, Type, TypeVar, Union

+from google.protobuf.message import Message
 from multiaddr import Multiaddr

 import hivemind.hivemind_cli as cli
 import hivemind.p2p.p2p_daemon_bindings.p2pclient as p2pclient
+from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError, P2PHandlerError
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
 from hivemind.proto.p2pd_pb2 import RPCError
 from hivemind.utils.asyncio import aiter, asingle
@@ -27,7 +29,6 @@ class P2PContext(object):
     handle_name: str
     local_id: PeerID
     remote_id: PeerID = None
-    remote_maddr: Multiaddr = None


 class P2P:
@@ -65,6 +66,7 @@ class P2P:

     def __init__(self):
         self.peer_id = None
+        self._client = None
         self._child = None
         self._alive = False
         self._reader_task = None
@@ -74,43 +76,50 @@ class P2P:
     async def create(
         cls,
         initial_peers: Optional[Sequence[Union[Multiaddr, str]]] = None,
-        use_ipfs: bool = False,
-        host_maddrs: Optional[Sequence[Union[Multiaddr, str]]] = ("/ip4/127.0.0.1/tcp/0",),
+        *,
         announce_maddrs: Optional[Sequence[Union[Multiaddr, str]]] = None,
-        quic: bool = True,
-        tls: bool = True,
+        auto_nat: bool = True,
         conn_manager: bool = True,
         dht_mode: str = "dht_server",
         force_reachability: Optional[str] = None,
+        host_maddrs: Optional[Sequence[Union[Multiaddr, str]]] = ("/ip4/127.0.0.1/tcp/0",),
+        identity_path: Optional[str] = None,
+        idle_timeout: float = 30,
         nat_port_map: bool = True,
-        auto_nat: bool = True,
+        quic: bool = False,
+        relay_hop_limit: int = 0,
+        startup_timeout: float = 15,
+        tls: bool = True,
+        use_auto_relay: bool = False,
+        use_ipfs: bool = False,
         use_relay: bool = True,
         use_relay_hop: bool = False,
         use_relay_discovery: bool = False,
-        use_auto_relay: bool = False,
-        relay_hop_limit: int = 0,
-        startup_timeout: float = 15,
     ) -> "P2P":
         """
         Start a new p2pd process and connect to it.
         :param initial_peers: List of bootstrap peers
-        :param use_ipfs: Bootstrap to IPFS (incompatible with initial_peers)
-        :param host_maddrs: Multiaddrs to listen for external connections from other p2p instances
+        :param auto_nat: Enables the AutoNAT service
         :param announce_maddrs: Visible multiaddrs that the peer will announce
-          for external connections from other p2p instances
-        :param quic: Enables the QUIC transport
-        :param tls: Enables TLS1.3 channel security protocol
+                                for external connections from other p2p instances
         :param conn_manager: Enables the Connection Manager
         :param dht_mode: DHT mode (dht_client/dht_server/dht)
         :param force_reachability: Force reachability mode (public/private)
+        :param host_maddrs: Multiaddrs to listen for external connections from other p2p instances
+        :param identity_path: Path to a pre-generated private key file. If defined, makes the peer ID deterministic.
+                              May be generated using ``./p2p-keygen`` from ``go-libp2p-daemon``.
+        :param idle_timeout: kill daemon if client has been idle for a given number of
+                             seconds before opening persistent streams
         :param nat_port_map: Enables NAT port mapping
-        :param auto_nat: Enables the AutoNAT service
+        :param quic: Enables the QUIC transport
+        :param relay_hop_limit: sets the hop limit for hop relays
+        :param startup_timeout: raise a P2PDaemonError if the daemon does not start in ``startup_timeout`` seconds
+        :param tls: Enables TLS1.3 channel security protocol
+        :param use_auto_relay: enables autorelay
+        :param use_ipfs: Bootstrap to IPFS (incompatible with initial_peers)
         :param use_relay: enables circuit relay
         :param use_relay_hop: enables hop for relay
         :param use_relay_discovery: enables passive discovery for relay
-        :param use_auto_relay: enables autorelay
-        :param relay_hop_limit: sets the hop limit for hop relays
-        :param startup_timeout: raise a P2PDaemonError if the daemon does not start in ``startup_timeout`` seconds
         :return: a wrapper for the p2p daemon
         """

@@ -136,21 +145,24 @@ class P2P:
         ]:
             if value:
                 process_kwargs[param] = self._maddrs_to_str(value)
+        if identity_path is not None:
+            process_kwargs["id"] = identity_path

         proc_args = self._make_process_args(
             str(p2pd_path),
-            listen=self._daemon_listen_maddr,
-            quic=quic,
-            tls=tls,
+            autoRelay=use_auto_relay,
+            autonat=auto_nat,
+            b=need_bootstrap,
             connManager=conn_manager,
+            idleTimeout=f"{idle_timeout}s",
+            listen=self._daemon_listen_maddr,
             natPortMap=nat_port_map,
-            autonat=auto_nat,
+            quic=quic,
             relay=use_relay,
-            relayHop=use_relay_hop,
             relayDiscovery=use_relay_discovery,
-            autoRelay=use_auto_relay,
+            relayHop=use_relay_hop,
             relayHopLimit=relay_hop_limit,
-            b=need_bootstrap,
+            tls=tls,
             **process_kwargs,
         )

@@ -167,7 +179,7 @@ class P2P:
             await self.shutdown()
             raise P2PDaemonError(f"Daemon failed to start in {startup_timeout:.1f} seconds")

-        self._client = p2pclient.Client(self._daemon_listen_maddr, self._client_listen_maddr)
+        self._client = await p2pclient.Client.create(self._daemon_listen_maddr, self._client_listen_maddr)
         await self._ping_daemon()
         return self

@@ -189,7 +201,7 @@
         self._daemon_listen_maddr = daemon_listen_maddr
         self._client_listen_maddr = Multiaddr(cls._UNIX_SOCKET_PREFIX + f"p2pclient-{socket_uid}.sock")

-        self._client = p2pclient.Client(self._daemon_listen_maddr, self._client_listen_maddr)
+        self._client = await p2pclient.Client.create(self._daemon_listen_maddr, self._client_listen_maddr)

         await self._ping_daemon()
         return self
@@ -258,7 +270,7 @@

     @staticmethod
     async def receive_protobuf(
-        input_protobuf_type: type, reader: asyncio.StreamReader
+        input_protobuf_type: Type[Message], reader: asyncio.StreamReader
     ) -> Tuple[Optional[TInputProtobuf], Optional[RPCError]]:
         msg_type = await reader.readexactly(1)
         if msg_type == P2P.MESSAGE_MARKER:
@@ -279,7 +291,7 @@ class P2P:
         self,
         name: str,
         handler: Callable[[TInputStream, P2PContext], TOutputStream],
-        input_protobuf_type: type,
+        input_protobuf_type: Type[Message],
         max_prefetch: int = 5,
     ) -> None:
         """
@@ -297,7 +309,6 @@
                 handle_name=name,
                 local_id=self.peer_id,
                 remote_id=stream_info.peer_id,
-                remote_maddr=stream_info.addr,
             )
             requests = asyncio.Queue(max_prefetch)

@@ -311,11 +322,17 @@
             async def _process_stream() -> None:
                 try:
                     async for response in handler(_read_stream(), context):
-                        await P2P.send_protobuf(response, writer)
+                        try:
+                            await P2P.send_protobuf(response, writer)
+                        except Exception:
+                            # The connection is unexpectedly closed by the caller or broken.
+                            # The loglevel is DEBUG since the actual error will be reported on the caller
+                            logger.debug("Exception while sending response:", exc_info=True)
+                            break
                 except Exception as e:
-                    logger.warning("Exception while processing stream and sending responses:", exc_info=True)
-                    # Sometimes `e` is a connection error, so we won't be able to report the error to the caller
+                    logger.warning("Handler failed with the exception:", exc_info=True)
                     with suppress(Exception):
+                        # Sometimes `e` is a connection error, so it is okay if we fail to report `e` to the caller
                         await P2P.send_protobuf(RPCError(message=str(e)), writer)

             with closing(writer):
@@ -343,7 +360,7 @@ class P2P:
         await self.add_binary_stream_handler(name, _handle_stream)

     async def _iterate_protobuf_stream_handler(
-        self, peer_id: PeerID, name: str, requests: TInputStream, output_protobuf_type: type
+        self, peer_id: PeerID, name: str, requests: TInputStream, output_protobuf_type: Type[Message]
     ) -> TOutputStream:
         _, reader, writer = await self.call_binary_stream_handler(peer_id, name)

@@ -375,15 +392,22 @@
         handler: Callable[
             [Union[TInputProtobuf, TInputStream], P2PContext], Union[Awaitable[TOutputProtobuf], TOutputStream]
         ],
-        input_protobuf_type: type,
+        input_protobuf_type: Type[Message],
         *,
         stream_input: bool = False,
+        stream_output: bool = False,
     ) -> None:
         """
         :param stream_input: If True, assume ``handler`` to take ``TInputStream``
                              (not just ``TInputProtobuf``) as input.
+        :param stream_output: If True, assume ``handler`` to return ``TOutputStream``
+                              (not ``Awaitable[TOutputProtobuf]``).
         """

+        if not stream_input and not stream_output:
+            await self._add_protobuf_unary_handler(name, handler, input_protobuf_type)
+            return
+
         async def _stream_handler(requests: P2P.TInputStream, context: P2PContext) -> P2P.TOutputStream:
             input = requests if stream_input else await asingle(requests)
             output = handler(input, context)
@@ -396,23 +420,65 @@ class P2P:
 
 
         await self._add_protobuf_stream_handler(name, _stream_handler, input_protobuf_type)
         await self._add_protobuf_stream_handler(name, _stream_handler, input_protobuf_type)
 
 
+    async def _add_protobuf_unary_handler(
+        self,
+        handle_name: str,
+        handler: Callable[[TInputProtobuf, P2PContext], Awaitable[TOutputProtobuf]],
+        input_protobuf_type: Type[Message],
+    ) -> None:
+        """
+        Register a request-response (unary) handler. Unary requests and responses
+        are sent through persistent multiplexed connections to the daemon for the
+        sake of reducing the number of open files.
+        :param handle_name: name of the handler (protocol id)
+        :param handler: function handling the unary requests
+        :param input_protobuf_type: protobuf type of the request
+        """
+
+        async def _unary_handler(request: bytes, remote_id: PeerID) -> bytes:
+            input_serialized = input_protobuf_type.FromString(request)
+            context = P2PContext(
+                handle_name=handle_name,
+                local_id=self.peer_id,
+                remote_id=remote_id,
+            )
+
+            response = await handler(input_serialized, context)
+            return response.SerializeToString()
+
+        await self._client.add_unary_handler(handle_name, _unary_handler)
+
     async def call_protobuf_handler(
     async def call_protobuf_handler(
         self,
         self,
         peer_id: PeerID,
         peer_id: PeerID,
         name: str,
         name: str,
         input: Union[TInputProtobuf, TInputStream],
         input: Union[TInputProtobuf, TInputStream],
-        output_protobuf_type: type,
+        output_protobuf_type: Type[Message],
     ) -> Awaitable[TOutputProtobuf]:
     ) -> Awaitable[TOutputProtobuf]:
-        requests = input if isinstance(input, AsyncIterableABC) else aiter(input)
-        responses = self._iterate_protobuf_stream_handler(peer_id, name, requests, output_protobuf_type)
+
+        if not isinstance(input, AsyncIterableABC):
+            return await self._call_unary_protobuf_handler(peer_id, name, input, output_protobuf_type)
+
+        responses = self._iterate_protobuf_stream_handler(peer_id, name, input, output_protobuf_type)
         return await asingle(responses)
         return await asingle(responses)
 
 
+    async def _call_unary_protobuf_handler(
+        self,
+        peer_id: PeerID,
+        handle_name: str,
+        input: TInputProtobuf,
+        output_protobuf_type: Type[Message],
+    ) -> Awaitable[TOutputProtobuf]:
+        serialized_input = input.SerializeToString()
+        response = await self._client.call_unary_handler(peer_id, handle_name, serialized_input)
+        return output_protobuf_type.FromString(response)
+
     def iterate_protobuf_handler(
     def iterate_protobuf_handler(
         self,
         self,
         peer_id: PeerID,
         peer_id: PeerID,
         name: str,
         name: str,
         input: Union[TInputProtobuf, TInputStream],
         input: Union[TInputProtobuf, TInputStream],
-        output_protobuf_type: type,
+        output_protobuf_type: Type[Message],
     ) -> TOutputStream:
     ) -> TOutputStream:
         requests = input if isinstance(input, AsyncIterableABC) else aiter(input)
         requests = input if isinstance(input, AsyncIterableABC) else aiter(input)
         return self._iterate_protobuf_stream_handler(peer_id, name, requests, output_protobuf_type)
         return self._iterate_protobuf_stream_handler(peer_id, name, requests, output_protobuf_type)
@@ -447,6 +513,8 @@ class P2P:
             await self._child.wait()
             await self._child.wait()
 
 
     def _terminate(self) -> None:
     def _terminate(self) -> None:
+        if self._client is not None:
+            self._client.close()
         if self._listen_task is not None:
         if self._listen_task is not None:
             self._listen_task.cancel()
             self._listen_task.cancel()
         if self._reader_task is not None:
         if self._reader_task is not None:
@@ -495,11 +563,3 @@ class P2P:
 
 
         if not ready.done():
         if not ready.done():
             ready.set_exception(P2PDaemonError(f"Daemon failed to start: {last_line}"))
             ready.set_exception(P2PDaemonError(f"Daemon failed to start: {last_line}"))
-
-
-class P2PDaemonError(RuntimeError):
-    pass
-
-
-class P2PHandlerError(Exception):
-    pass
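
For reference, the unary path added above can be exercised end-to-end as in the sketch below, which mirrors the unary test added to tests/test_p2p_daemon.py later in this diff; test_pb2.TestRequest / test_pb2.TestResponse (helper messages with a `number` field) are the ones used there. Passing a single protobuf message, rather than an async iterable, to call_protobuf_handler now routes the call through _call_unary_protobuf_handler and the persistent daemon connection.

    import asyncio

    from hivemind.p2p import P2P
    from hivemind.proto import test_pb2


    async def main():
        server = await P2P.create()
        client = await P2P.create(initial_peers=await server.get_visible_maddrs())

        async def square(request: test_pb2.TestRequest, context) -> test_pb2.TestResponse:
            return test_pb2.TestResponse(number=request.number ** 2)

        # a plain (non-streaming) handler: single-message calls take the unary path
        await server.add_protobuf_handler("square", square, test_pb2.TestRequest)

        response = await client.call_protobuf_handler(
            server.peer_id, "square", test_pb2.TestRequest(number=12), test_pb2.TestResponse
        )
        assert response.number == 144

        await client.shutdown()
        await server.shutdown()


    asyncio.run(main())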

+ 189 - 3
hivemind/p2p/p2p_daemon_bindings/control.py

@@ -5,8 +5,9 @@ Author: Kevin Mai-Husan Chia
 """
 """
 
 
 import asyncio
 import asyncio
-from contextlib import asynccontextmanager
-from typing import AsyncIterator, Awaitable, Callable, Dict, Iterable, Sequence, Tuple
+from contextlib import asynccontextmanager, closing
+from typing import AsyncIterator, Awaitable, Callable, Dict, Iterable, Optional, Sequence, Tuple
+from uuid import UUID, uuid4
 
 
 from multiaddr import Multiaddr, protocols
 from multiaddr import Multiaddr, protocols
 
 
@@ -54,17 +55,75 @@ class DaemonConnector:
         else:
         else:
             raise ValueError(f"Protocol not supported: {protocols.protocol_with_code(self.proto_code)}")
             raise ValueError(f"Protocol not supported: {protocols.protocol_with_code(self.proto_code)}")
 
 
+    async def open_persistent_connection(self) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
+        """
+        Open a connection to the daemon and upgrade it to a persistent one
+        """
+        reader, writer = await self.open_connection()
+        req = p2pd_pb.Request(type=p2pd_pb.Request.PERSISTENT_CONN_UPGRADE)
+        await write_pbmsg(writer, req)
+
+        response = p2pd_pb.Response()
+        await read_pbmsg_safe(reader, response)
+
+        if response.type == p2pd_pb.Response.ERROR:
+            raise P2PDaemonError(response.error.msg)
+
+        return reader, writer
+
+
+TUnaryHandler = Callable[[bytes, PeerID], Awaitable[bytes]]
+CallID = UUID
+
 
 
 class ControlClient:
 class ControlClient:
     DEFAULT_LISTEN_MADDR = "/unix/tmp/p2pclient.sock"
     DEFAULT_LISTEN_MADDR = "/unix/tmp/p2pclient.sock"
 
 
     def __init__(
     def __init__(
-        self, daemon_connector: DaemonConnector, listen_maddr: Multiaddr = Multiaddr(DEFAULT_LISTEN_MADDR)
+        self,
+        daemon_connector: DaemonConnector,
+        listen_maddr: Multiaddr = Multiaddr(DEFAULT_LISTEN_MADDR),
+        *,
+        _initialized_with_create=False,
     ) -> None:
     ) -> None:
+        assert _initialized_with_create, "Please use ControlClient.create coroutine to spawn new control instances"
+
         self.listen_maddr = listen_maddr
         self.listen_maddr = listen_maddr
         self.daemon_connector = daemon_connector
         self.daemon_connector = daemon_connector
         self.handlers: Dict[str, StreamHandler] = {}
         self.handlers: Dict[str, StreamHandler] = {}
 
 
+        self.unary_handlers: Dict[str, TUnaryHandler] = {}
+
+        self._pending_messages: asyncio.Queue[p2pd_pb.PersistentConnectionRequest] = asyncio.Queue()
+        self._pending_calls: Dict[CallID, asyncio.Future[bytes]] = {}
+        self._handler_tasks: Dict[CallID, asyncio.Task] = {}
+
+        self._read_task: Optional[asyncio.Task] = None
+        self._write_task: Optional[asyncio.Task] = None
+
+    @classmethod
+    async def create(
+        cls,
+        daemon_connector: DaemonConnector,
+        listen_maddr: Multiaddr = Multiaddr(DEFAULT_LISTEN_MADDR),
+        use_persistent_conn: bool = True,
+    ) -> "ControlClient":
+        control = cls(daemon_connector, listen_maddr, _initialized_with_create=True)
+
+        if use_persistent_conn:
+            await control._ensure_persistent_conn()
+
+        return control
+
+    def close(self) -> None:
+        if self._read_task is not None:
+            self._read_task.cancel()
+        if self._write_task is not None:
+            self._write_task.cancel()
+
+    def __del__(self):
+        self.close()
+
     async def _handler(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
     async def _handler(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
         pb_stream_info = p2pd_pb.StreamInfo()  # type: ignore
         pb_stream_info = p2pd_pb.StreamInfo()  # type: ignore
         await read_pbmsg_safe(reader, pb_stream_info)
         await read_pbmsg_safe(reader, pb_stream_info)
@@ -93,6 +152,121 @@ class ControlClient:
         async with server:
         async with server:
             yield self
             yield self
 
 
+    async def _read_from_persistent_conn(self, reader: asyncio.StreamReader):
+        while True:
+            resp = p2pd_pb.PersistentConnectionResponse()
+            try:
+                await read_pbmsg_safe(reader, resp)
+            except asyncio.IncompleteReadError:
+                break
+
+            call_id = UUID(bytes=resp.callId)
+
+            if resp.HasField("callUnaryResponse"):
+                if call_id in self._pending_calls and resp.callUnaryResponse.HasField("response"):
+                    self._pending_calls[call_id].set_result(resp.callUnaryResponse.response)
+                elif call_id in self._pending_calls and resp.callUnaryResponse.HasField("error"):
+                    remote_exc = P2PHandlerError(resp.callUnaryResponse.error.decode(errors="ignore"))
+                    self._pending_calls[call_id].set_exception(remote_exc)
+                else:
+                    logger.debug(f"Received unexpected unary call: {resp}")
+
+            elif resp.HasField("requestHandling"):
+                handler_task = asyncio.create_task(self._handle_persistent_request(call_id, resp.requestHandling))
+                self._handler_tasks[call_id] = handler_task
+
+            elif call_id in self._handler_tasks and resp.HasField("cancel"):
+                self._handler_tasks[call_id].cancel()
+
+            elif call_id in self._pending_calls and resp.HasField("daemonError"):
+                daemon_exc = P2PDaemonError(resp.daemonError.message)
+                self._pending_calls[call_id].set_exception(daemon_exc)
+
+            elif call_id in self._pending_calls:
+                self._pending_calls[call_id].set_result(None)
+
+            else:
+                logger.debug(f"Received unexpected response from daemon: {resp}")
+
+    async def _write_to_persistent_conn(self, writer: asyncio.StreamWriter):
+        with closing(writer):
+            while True:
+                msg = await self._pending_messages.get()
+                await write_pbmsg(writer, msg)
+
+    async def _handle_persistent_request(self, call_id: UUID, request: p2pd_pb.CallUnaryRequest):
+        if request.proto not in self.unary_handlers:
+            logger.warning(f"Protocol {request.proto} not supported")
+            return
+
+        try:
+            remote_id = PeerID(request.peer)
+            response_payload: bytes = await self.unary_handlers[request.proto](request.data, remote_id)
+            response = p2pd_pb.CallUnaryResponse(response=response_payload)
+
+        except Exception as e:
+            response = p2pd_pb.CallUnaryResponse(error=repr(e).encode())
+
+        await self._pending_messages.put(
+            p2pd_pb.PersistentConnectionRequest(
+                callId=call_id.bytes,
+                unaryResponse=response,
+            )
+        )
+        self._handler_tasks.pop(call_id)
+
+    async def _cancel_unary_call(self, call_id: UUID):
+        await self._pending_messages.put(
+            p2pd_pb.PersistentConnectionRequest(
+                callId=call_id.bytes,
+                cancel=p2pd_pb.Cancel(),
+            ),
+        )
+
+    async def _ensure_persistent_conn(self):
+        reader, writer = await self.daemon_connector.open_persistent_connection()
+
+        self._read_task = asyncio.create_task(self._read_from_persistent_conn(reader))
+        self._write_task = asyncio.create_task(self._write_to_persistent_conn(writer))
+
+    async def add_unary_handler(self, proto: str, handler: TUnaryHandler):
+        call_id = uuid4()
+
+        add_unary_handler_req = p2pd_pb.AddUnaryHandlerRequest(proto=proto)
+        req = p2pd_pb.PersistentConnectionRequest(callId=call_id.bytes, addUnaryHandler=add_unary_handler_req)
+
+        if self.unary_handlers.get(proto):
+            raise P2PDaemonError(f"Handler for protocol {proto} already registered")
+        self.unary_handlers[proto] = handler
+
+        self._pending_calls[call_id] = asyncio.Future()
+        await self._pending_messages.put(req)
+        await self._pending_calls[call_id]
+
+    async def call_unary_handler(self, peer_id: PeerID, proto: str, data: bytes) -> bytes:
+        call_id = uuid4()
+        call_unary_req = p2pd_pb.CallUnaryRequest(
+            peer=peer_id.to_bytes(),
+            proto=proto,
+            data=data,
+        )
+        req = p2pd_pb.PersistentConnectionRequest(
+            callId=call_id.bytes,
+            callUnary=call_unary_req,
+        )
+
+        try:
+            self._pending_calls[call_id] = asyncio.Future()
+            await self._pending_messages.put(req)
+            return await self._pending_calls[call_id]
+
+        except asyncio.CancelledError:
+            await self._cancel_unary_call(call_id)
+            raise
+
+        finally:
+            self._pending_calls.pop(call_id, None)
+
     async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
     async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
         reader, writer = await self.daemon_connector.open_connection()
         reader, writer = await self.daemon_connector.open_connection()
         req = p2pd_pb.Request(type=p2pd_pb.Request.IDENTIFY)
         req = p2pd_pb.Request(type=p2pd_pb.Request.IDENTIFY)
@@ -179,3 +353,15 @@ class ControlClient:
 
 
         # if success, add the handler to the dict
         # if success, add the handler to the dict
         self.handlers[proto] = handler_cb
         self.handlers[proto] = handler_cb
+
+
+class P2PHandlerError(Exception):
+    """
+    Raised if a remote peer handles a request with an exception
+    """
+
+
+class P2PDaemonError(Exception):
+    """
+    Raised if the daemon fails to handle a request
+    """

+ 9 - 0
hivemind/p2p/p2p_daemon_bindings/datastructures.py

@@ -74,6 +74,12 @@ class PeerID:
         else:
         else:
             return False
             return False
 
 
+    def __lt__(self, other: object) -> bool:
+        if not isinstance(other, PeerID):
+            raise TypeError(f"'<' not supported between instances of 'PeerID' and '{type(other)}'")
+
+        return self.to_base58() < other.to_base58()
+
     def __hash__(self) -> int:
     def __hash__(self) -> int:
         return hash(self._bytes)
         return hash(self._bytes)
 
 
@@ -125,6 +131,9 @@ class PeerInfo:
     def __str__(self):
     def __str__(self):
         return f"{self.peer_id.pretty()} {','.join(str(a) for a in self.addrs)}"
         return f"{self.peer_id.pretty()} {','.join(str(a) for a in self.addrs)}"
 
 
+    def __repr__(self):
+        return f"PeerInfo(peer_id={repr(self.peer_id)}, addrs={repr(self.addrs)})"
+
 
 
 class InvalidAddrError(ValueError):
 class InvalidAddrError(ValueError):
     pass
     pass
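
Defining __lt__ on PeerID (ordering by the base58 representation) makes peer ids sortable, e.g. for picking a deterministic leader among a set of peers. A small sketch; the base58 strings below are arbitrary examples:

    from hivemind.p2p import PeerID

    peers = [PeerID.from_base58(s) for s in ("eve", "bob", "dave")]
    leader = min(peers)                     # the lexicographically smallest base58 string
    assert leader == PeerID.from_base58("bob")
    assert sorted(peers)[0] == leader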

+ 24 - 3
hivemind/p2p/p2p_daemon_bindings/p2pclient.py

@@ -10,16 +10,31 @@ from typing import AsyncIterator, Iterable, Sequence, Tuple
 
 
 from multiaddr import Multiaddr
 from multiaddr import Multiaddr
 
 
-from hivemind.p2p.p2p_daemon_bindings.control import ControlClient, DaemonConnector, StreamHandler
+from hivemind.p2p.p2p_daemon_bindings.control import ControlClient, DaemonConnector, StreamHandler, TUnaryHandler
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
 
 
 
 
 class Client:
 class Client:
     control: ControlClient
     control: ControlClient
 
 
-    def __init__(self, control_maddr: Multiaddr = None, listen_maddr: Multiaddr = None) -> None:
+    def __init__(self, *, _initialized_with_create=False) -> None:
+        assert _initialized_with_create, "Please use Client.create coroutine to spawn new client instances"
+        self.control = None
+
+    @classmethod
+    async def create(cls, control_maddr: Multiaddr = None, listen_maddr: Multiaddr = None) -> "Client":
+        client = cls(_initialized_with_create=True)
+
         daemon_connector = DaemonConnector(control_maddr=control_maddr)
         daemon_connector = DaemonConnector(control_maddr=control_maddr)
-        self.control = ControlClient(daemon_connector=daemon_connector, listen_maddr=listen_maddr)
+        client.control = await ControlClient.create(daemon_connector=daemon_connector, listen_maddr=listen_maddr)
+
+        return client
+
+    def close(self) -> None:
+        self.control.close()
+
+    def __del__(self):
+        self.close()
 
 
     @asynccontextmanager
     @asynccontextmanager
     async def listen(self) -> AsyncIterator["Client"]:
     async def listen(self) -> AsyncIterator["Client"]:
@@ -30,6 +45,12 @@ class Client:
         async with self.control.listen():
         async with self.control.listen():
             yield self
             yield self
 
 
+    async def add_unary_handler(self, proto: str, handler: TUnaryHandler):
+        await self.control.add_unary_handler(proto, handler)
+
+    async def call_unary_handler(self, peer_id: PeerID, proto: str, data: bytes) -> bytes:
+        return await self.control.call_unary_handler(peer_id, proto, data)
+
     async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
     async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
         """
         """
         Get current node peer id and list of addresses
         Get current node peer id and list of addresses
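
Client, like ControlClient above, now follows the async-factory pattern: __init__ cannot await the persistent-connection setup, so it is guarded by the _initialized_with_create flag and the awaiting happens inside the create() classmethod. A generic sketch of the pattern with hypothetical names:

    import asyncio


    class Connection:
        def __init__(self, *, _initialized_with_create=False):
            assert _initialized_with_create, "Please use Connection.create(...) instead"
            self.reader = self.writer = None

        @classmethod
        async def create(cls, host: str, port: int) -> "Connection":
            conn = cls(_initialized_with_create=True)
            # the part of the setup that requires awaiting lives here, not in __init__
            conn.reader, conn.writer = await asyncio.open_connection(host, port)
            return conn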

+ 12 - 8
hivemind/p2p/servicer.py

@@ -125,14 +125,18 @@ class ServicerBase:
         self._collect_rpc_handlers()
         self._collect_rpc_handlers()
 
 
         servicer = self if wrapper is None else wrapper
         servicer = self if wrapper is None else wrapper
-        for handler in self._rpc_handlers:
-            print("handler:", self._get_handle_name(namespace, handler.method_name))
-            await p2p.add_protobuf_handler(
-                self._get_handle_name(namespace, handler.method_name),
-                getattr(servicer, handler.method_name),
-                handler.request_type,
-                stream_input=handler.stream_input,
-            )
+        await asyncio.gather(
+            *[
+                p2p.add_protobuf_handler(
+                    self._get_handle_name(namespace, handler.method_name),
+                    getattr(servicer, handler.method_name),
+                    handler.request_type,
+                    stream_input=handler.stream_input,
+                    stream_output=handler.stream_output,
+                )
+                for handler in self._rpc_handlers
+            ]
+        )
 
 
     @classmethod
     @classmethod
     def get_stub(cls, p2p: P2P, peer: PeerID, *, namespace: Optional[str] = None) -> StubBase:
     def get_stub(cls, p2p: P2P, peer: PeerID, *, namespace: Optional[str] = None) -> StubBase:

+ 59 - 10
hivemind/proto/p2pd.proto

@@ -8,15 +8,17 @@ package p2pclient.p2pd.pb;
 
 
 message Request {
 message Request {
   enum Type {
   enum Type {
-    IDENTIFY       = 0;
-    CONNECT        = 1;
-    STREAM_OPEN    = 2;
-    STREAM_HANDLER = 3;
-    DHT            = 4;
-    LIST_PEERS     = 5;
-    CONNMANAGER    = 6;
-    DISCONNECT     = 7;
-    PUBSUB         = 8;
+    IDENTIFY                 = 0;
+    CONNECT                  = 1;
+    STREAM_OPEN              = 2;
+    STREAM_HANDLER           = 3;
+    DHT                      = 4;
+    LIST_PEERS               = 5;
+    CONNMANAGER              = 6;
+    DISCONNECT               = 7;
+    PUBSUB                   = 8;
+
+    PERSISTENT_CONN_UPGRADE  = 9;
   }
   }
 
 
   required Type type = 1;
   required Type type = 1;
@@ -45,6 +47,29 @@ message Response {
   optional PSResponse pubsub = 7;
   optional PSResponse pubsub = 7;
 }
 }
 
 
+message PersistentConnectionRequest {
+  required bytes callId = 1;
+
+  oneof message {
+    AddUnaryHandlerRequest addUnaryHandler = 2;
+    CallUnaryRequest  callUnary = 3;
+    CallUnaryResponse unaryResponse = 4;
+    Cancel cancel = 5;
+  }
+}
+
+message PersistentConnectionResponse {
+  required bytes callId = 1;
+
+  oneof message {
+    CallUnaryResponse callUnaryResponse = 2;
+    CallUnaryRequest requestHandling = 3;
+    DaemonError daemonError = 4;
+    Cancel cancel = 5;
+  }
+}
+
+
 message IdentifyResponse {
 message IdentifyResponse {
   required bytes id = 1;
   required bytes id = 1;
   repeated bytes addrs = 2;
   repeated bytes addrs = 2;
@@ -148,7 +173,7 @@ message PSRequest {
 }
 }
 
 
 message PSMessage {
 message PSMessage {
-  optional bytes from_id = 1;
+  optional bytes from = 1;
   optional bytes data = 2;
   optional bytes data = 2;
   optional bytes seqno = 3;
   optional bytes seqno = 3;
   repeated string topicIDs = 4;
   repeated string topicIDs = 4;
@@ -161,6 +186,30 @@ message PSResponse {
   repeated bytes peerIDs = 2;
   repeated bytes peerIDs = 2;
 }
 }
 
 
+message CallUnaryRequest {
+  required bytes peer = 1;
+  required string proto = 2;
+  required bytes data = 3;
+}
+
+message CallUnaryResponse {
+  oneof result {
+    bytes response = 1;
+    bytes error = 2;
+  }
+}
+
+message AddUnaryHandlerRequest {
+  required string proto = 1;
+}
+
+message DaemonError {
+  optional string message = 1;
+}
+
+message Cancel {
+}
+
 message RPCError {
 message RPCError {
   optional string message = 1;
   optional string message = 1;
 }
 }
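
Both persistent-connection messages carry a required callId plus a oneof payload, which is exactly what the Python reader loop inspects via HasField. A small sketch of how the generated bindings expose this (placeholder peer and payload bytes):

    from uuid import uuid4

    from hivemind.proto import p2pd_pb2

    request = p2pd_pb2.PersistentConnectionRequest(
        callId=uuid4().bytes,
        callUnary=p2pd_pb2.CallUnaryRequest(peer=b"peer-id-bytes", proto="square", data=b"payload"),
    )
    assert request.WhichOneof("message") == "callUnary"
    assert request.HasField("callUnary") and not request.HasField("cancel")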

+ 24 - 2
hivemind/utils/asyncio.py

@@ -1,4 +1,5 @@
 import asyncio
 import asyncio
+import concurrent.futures
 from concurrent.futures import ThreadPoolExecutor
 from concurrent.futures import ThreadPoolExecutor
 from typing import AsyncIterable, AsyncIterator, Awaitable, Callable, Optional, Tuple, TypeVar, Union
 from typing import AsyncIterable, AsyncIterator, Awaitable, Callable, Optional, Tuple, TypeVar, Union
 
 
@@ -59,7 +60,7 @@ async def aenumerate(aiterable: AsyncIterable[T]) -> AsyncIterable[Tuple[int, T]
 
 
 
 
 async def asingle(aiter: AsyncIterable[T]) -> T:
 async def asingle(aiter: AsyncIterable[T]) -> T:
-    """If ``aiter`` has exactly one item, returns this item. Otherwise, raises `ValueError`."""
+    """If ``aiter`` has exactly one item, returns this item. Otherwise, raises ``ValueError``."""
     count = 0
     count = 0
     async for item in aiter:
     async for item in aiter:
         count += 1
         count += 1
@@ -70,16 +71,37 @@ async def asingle(aiter: AsyncIterable[T]) -> T:
     return item
     return item
 
 
 
 
+async def afirst(aiter: AsyncIterable[T], default: Optional[T] = None) -> Optional[T]:
+    """Returns the first item of ``aiter`` or ``default`` if ``aiter`` is empty."""
+    async for item in aiter:
+        return item
+    return default
+
+
 async def await_cancelled(awaitable: Awaitable) -> bool:
 async def await_cancelled(awaitable: Awaitable) -> bool:
     try:
     try:
         await awaitable
         await awaitable
         return False
         return False
-    except asyncio.CancelledError:
+    except (asyncio.CancelledError, concurrent.futures.CancelledError):
+        # In Python 3.7, awaiting a cancelled asyncio.Future raises concurrent.futures.CancelledError
+        # instead of asyncio.CancelledError
         return True
         return True
     except BaseException:
     except BaseException:
+        logger.exception(f"Exception in {awaitable}:")
         return False
         return False
 
 
 
 
+async def cancel_and_wait(awaitable: Awaitable) -> bool:
+    """
+    Cancels ``awaitable`` and waits for its cancellation.
+    In case of ``asyncio.Task``, helps to avoid ``Task was destroyed but it is pending!`` errors.
+    In case of ``asyncio.Future``, equivalent to ``future.cancel()``.
+    """
+
+    awaitable.cancel()
+    return await await_cancelled(awaitable)
+
+
 async def amap_in_executor(
 async def amap_in_executor(
     func: Callable[..., T],
     func: Callable[..., T],
     *iterables: AsyncIterable,
     *iterables: AsyncIterable,
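
The new helpers are small but easy to misuse, so a usage sketch: cancel_and_wait both cancels and awaits the task (avoiding "Task was destroyed but it is pending!" warnings), and afirst returns the first item of an async iterable or the default if it is empty.

    import asyncio

    from hivemind.utils.asyncio import afirst, cancel_and_wait


    async def main():
        task = asyncio.create_task(asyncio.sleep(3600))
        assert await cancel_and_wait(task)  # True: the task ended with CancelledError

        async def numbers():
            yield 1
            yield 2

        async def empty():
            if False:
                yield  # an async generator that yields nothing

        assert await afirst(numbers()) == 1
        assert await afirst(empty()) is None


    asyncio.run(main())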

+ 1 - 1
hivemind/utils/mpfuture.py

@@ -53,7 +53,7 @@ class SharedBytes:
         """Create another shared byte value, represented as a scalar uint8 tensor"""
         """Create another shared byte value, represented as a scalar uint8 tensor"""
         with cls._lock:
         with cls._lock:
             if cls._pid != os.getpid() or cls._buffer is None or cls._index >= len(cls._buffer):
             if cls._pid != os.getpid() or cls._buffer is None or cls._index >= len(cls._buffer):
-                buffer_size = os.environ.get("HIVEMIND_SHM_BUFFER_SIZE", 4096)
+                buffer_size = int(os.environ.get("HIVEMIND_SHM_BUFFER_SIZE", 16))
                 cls._pid = os.getpid()
                 cls._pid = os.getpid()
                 cls._buffer = torch.empty([buffer_size], dtype=torch.uint8).share_memory_()
                 cls._buffer = torch.empty([buffer_size], dtype=torch.uint8).share_memory_()
                 cls._index = 0
                 cls._index = 0
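
The int(...) cast matters because environment variables are always strings: with the old code, setting HIVEMIND_SHM_BUFFER_SIZE made buffer_size a str, which torch.empty([buffer_size]) rejects. A quick illustration:

    import os

    os.environ["HIVEMIND_SHM_BUFFER_SIZE"] = "4096"
    raw = os.environ.get("HIVEMIND_SHM_BUFFER_SIZE", 16)
    assert isinstance(raw, str)  # "4096", not 4096

    buffer_size = int(os.environ.get("HIVEMIND_SHM_BUFFER_SIZE", 16))
    assert buffer_size == 4096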

+ 6 - 1
hivemind/utils/networking.py

@@ -31,7 +31,12 @@ def strip_port(endpoint: Endpoint) -> Hostname:
 
 
 
 
 def get_free_port(params=(socket.AF_INET, socket.SOCK_STREAM), opt=(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)):
 def get_free_port(params=(socket.AF_INET, socket.SOCK_STREAM), opt=(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)):
-    """Finds a tcp port that can be occupied with a socket with *params and use *opt options"""
+    """
+    Finds a TCP port that can be occupied by a socket created with *params and configured with *opt options.
+
+    :note: Using this function is discouraged since it often leads to a race condition
+           with the "Address is already in use" error if the code is run in parallel.
+    """
     try:
     try:
         with closing(socket.socket(*params)) as sock:
         with closing(socket.socket(*params)) as sock:
             sock.bind(("", 0))
             sock.bind(("", 0))
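
The race the new note warns about is a classic check-then-use gap: the probed port is free when get_free_port returns, but nothing stops another process from taking it before the caller binds. A sketch of the failure mode, followed by the usual alternative (bind to port 0 and keep the socket open):

    import socket

    from hivemind.utils.networking import get_free_port

    port = get_free_port()        # step 1: probe; the probing socket is closed again
    # ... another process may bind to `port` right here ...
    server = socket.socket()
    server.bind(("", port))       # step 2: may fail with "Address already in use"

    # safer: let the OS pick the port and keep the bound socket alive
    sock = socket.socket()
    sock.bind(("", 0))
    reserved_port = sock.getsockname()[1]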

+ 2 - 2
setup.py

@@ -14,8 +14,8 @@ from setuptools import find_packages, setup
 from setuptools.command.build_py import build_py
 from setuptools.command.build_py import build_py
 from setuptools.command.develop import develop
 from setuptools.command.develop import develop
 
 
-P2PD_VERSION = "v0.3.1"
-P2PD_CHECKSUM = "15292b880c6b31f5b3c36084b3acc17f"
+P2PD_VERSION = "v0.3.4"
+P2PD_CHECKSUM = "194dca06116fdd36bc4b681d18f3b9cb"
 LIBP2P_TAR_URL = f"https://github.com/learning-at-home/go-libp2p-daemon/archive/refs/tags/{P2PD_VERSION}.tar.gz"
 LIBP2P_TAR_URL = f"https://github.com/learning-at-home/go-libp2p-daemon/archive/refs/tags/{P2PD_VERSION}.tar.gz"
 
 
 here = os.path.abspath(os.path.dirname(__file__))
 here = os.path.abspath(os.path.dirname(__file__))
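
The checksum is bumped together with the daemon version. Assuming the 32-hex-character value is an MD5 digest of the downloaded artifact (the exact file it covers is defined elsewhere in setup.py), verification amounts to:

    import hashlib

    P2PD_CHECKSUM = "194dca06116fdd36bc4b681d18f3b9cb"

    def md5_matches(path: str, expected: str = P2PD_CHECKSUM) -> bool:
        with open(path, "rb") as f:
            return hashlib.md5(f.read()).hexdigest() == expected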

+ 5 - 2
tests/test_averaging.py

@@ -10,6 +10,7 @@ import hivemind.averaging.averager
 from hivemind.averaging.allreduce import AveragingMode
 from hivemind.averaging.allreduce import AveragingMode
 from hivemind.averaging.key_manager import GroupKeyManager
 from hivemind.averaging.key_manager import GroupKeyManager
 from hivemind.averaging.load_balancing import load_balance_peers
 from hivemind.averaging.load_balancing import load_balance_peers
+from hivemind.averaging.partition import AllreduceException
 from hivemind.p2p import PeerID
 from hivemind.p2p import PeerID
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.proto.runtime_pb2 import CompressionType
 
 
@@ -363,9 +364,11 @@ def test_too_few_peers():
         )
         )
         for i, dht in enumerate(dht_instances)
         for i, dht in enumerate(dht_instances)
     ]
     ]
-    step_futures = [averager.step(wait=False) for averager in averagers]
+    step_futures = [averager.step(wait=False, timeout=2) for averager in averagers]
+
     for future in step_futures:
     for future in step_futures:
-        assert len(future.result()) == 2
+        with pytest.raises(AllreduceException):
+            future.result()
 
 
     for process in averagers + dht_instances:
     for process in averagers + dht_instances:
         process.shutdown()
         process.shutdown()

+ 3 - 3
tests/test_dht.py

@@ -14,7 +14,7 @@ from test_utils.dht_swarms import launch_dht_instances
 
 
 @pytest.mark.asyncio
 @pytest.mark.asyncio
 async def test_startup_error():
 async def test_startup_error():
-    with pytest.raises(hivemind.p2p.P2PDaemonError, match=r"Failed to connect to bootstrap peers"):
+    with pytest.raises(hivemind.p2p.P2PDaemonError, match=r"(?i)Failed to connect to bootstrap peers"):
         hivemind.DHT(
         hivemind.DHT(
             initial_peers=[f"/ip4/127.0.0.1/tcp/{get_free_port()}/p2p/QmdaK4LUeQaKhqSFPRu9N7MvXUEWDxWwtCvPrS444tCgd1"],
             initial_peers=[f"/ip4/127.0.0.1/tcp/{get_free_port()}/p2p/QmdaK4LUeQaKhqSFPRu9N7MvXUEWDxWwtCvPrS444tCgd1"],
             start=True,
             start=True,
@@ -22,7 +22,7 @@ async def test_startup_error():
 
 
     dht = hivemind.DHT(start=True, await_ready=False)
     dht = hivemind.DHT(start=True, await_ready=False)
     with pytest.raises(concurrent.futures.TimeoutError):
     with pytest.raises(concurrent.futures.TimeoutError):
-        dht.wait_until_ready(timeout=0.1)
+        dht.wait_until_ready(timeout=0.01)
     dht.shutdown()
     dht.shutdown()
 
 
 
 
@@ -118,7 +118,7 @@ async def test_dht_get_visible_maddrs():
 
 
     dummy_endpoint = Multiaddr("/ip4/123.45.67.89/tcp/31337")
     dummy_endpoint = Multiaddr("/ip4/123.45.67.89/tcp/31337")
     p2p = await hivemind.p2p.P2P.create(announce_maddrs=[dummy_endpoint])
     p2p = await hivemind.p2p.P2P.create(announce_maddrs=[dummy_endpoint])
-    dht = hivemind.DHT(start=True, p2p=await p2p.replicate(p2p.daemon_listen_maddr))
+    dht = hivemind.DHT(start=True, p2p=p2p)
 
 
     assert dht.get_visible_maddrs() == [dummy_endpoint.encapsulate(f"/p2p/{p2p.peer_id}")]
     assert dht.get_visible_maddrs() == [dummy_endpoint.encapsulate(f"/p2p/{p2p.peer_id}")]
     dht.shutdown()
     dht.shutdown()
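
The match patterns gain an inline (?i) flag, making them case-insensitive; presumably the casing of the daemon's error text differs between versions. For reference:

    import re

    pattern = r"(?i)Failed to connect to bootstrap peers"
    assert re.search(pattern, "failed to connect to bootstrap peers")
    assert re.search(pattern, "Failed to connect to bootstrap peers")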

+ 34 - 214
tests/test_dht_node.py

@@ -1,200 +1,27 @@
 import asyncio
 import asyncio
 import heapq
 import heapq
-import multiprocessing as mp
 import random
 import random
-import signal
 from itertools import product
 from itertools import product
-from typing import List, Sequence, Tuple
 
 
 import numpy as np
 import numpy as np
 import pytest
 import pytest
-from multiaddr import Multiaddr
 
 
 import hivemind
 import hivemind
 from hivemind import get_dht_time
 from hivemind import get_dht_time
 from hivemind.dht.node import DHTID, DHTNode
 from hivemind.dht.node import DHTID, DHTNode
-from hivemind.dht.protocol import DHTProtocol
-from hivemind.dht.storage import DictionaryDHTValue
-from hivemind.p2p import P2P, PeerID
 from hivemind.utils.logging import get_logger
 from hivemind.utils.logging import get_logger
 
 
 from test_utils.dht_swarms import launch_star_shaped_swarm, launch_swarm_in_separate_processes
 from test_utils.dht_swarms import launch_star_shaped_swarm, launch_swarm_in_separate_processes
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 
-
-def maddrs_to_peer_ids(maddrs: List[Multiaddr]) -> List[PeerID]:
-    return list({PeerID.from_base58(maddr["p2p"]) for maddr in maddrs})
-
-
-def run_protocol_listener(
-    dhtid: DHTID, maddr_conn: mp.connection.Connection, initial_peers: Sequence[Multiaddr]
-) -> None:
-    loop = asyncio.get_event_loop()
-
-    p2p = loop.run_until_complete(P2P.create(initial_peers=initial_peers))
-    visible_maddrs = loop.run_until_complete(p2p.get_visible_maddrs())
-
-    protocol = loop.run_until_complete(
-        DHTProtocol.create(p2p, dhtid, bucket_size=20, depth_modulo=5, num_replicas=3, wait_timeout=5)
-    )
-
-    logger.info(f"Started peer id={protocol.node_id} visible_maddrs={visible_maddrs}")
-
-    for peer_id in maddrs_to_peer_ids(initial_peers):
-        loop.run_until_complete(protocol.call_ping(peer_id))
-
-    maddr_conn.send((p2p.peer_id, visible_maddrs))
-
-    async def shutdown():
-        await p2p.shutdown()
-        logger.info(f"Finished peer id={protocol.node_id} maddrs={visible_maddrs}")
-        loop.stop()
-
-    loop.add_signal_handler(signal.SIGTERM, lambda: loop.create_task(shutdown()))
-    loop.run_forever()
-
-
-def launch_protocol_listener(
-    initial_peers: Sequence[Multiaddr] = (),
-) -> Tuple[DHTID, mp.Process, PeerID, List[Multiaddr]]:
-    remote_conn, local_conn = mp.Pipe()
-    dht_id = DHTID.generate()
-    process = mp.Process(target=run_protocol_listener, args=(dht_id, remote_conn, initial_peers), daemon=True)
-    process.start()
-    peer_id, visible_maddrs = local_conn.recv()
-
-    return dht_id, process, peer_id, visible_maddrs
-
-
 # note: we run network-related tests in a separate process to re-initialize all global states from scratch
 # note: we run network-related tests in a separate process to re-initialize all global states from scratch
 # this helps us avoid undesirable gRPC side-effects (e.g. segfaults) when running multiple tests in sequence
 # this helps us avoid undesirable gRPC side-effects (e.g. segfaults) when running multiple tests in sequence
 
 
 
 
 @pytest.mark.forked
 @pytest.mark.forked
-def test_dht_protocol():
-    peer1_node_id, peer1_proc, peer1_id, peer1_maddrs = launch_protocol_listener()
-    peer2_node_id, peer2_proc, peer2_id, _ = launch_protocol_listener(initial_peers=peer1_maddrs)
-
-    loop = asyncio.get_event_loop()
-    for client_mode in [True, False]:  # note: order matters, this test assumes that first run uses client mode
-        peer_id = DHTID.generate()
-        p2p = loop.run_until_complete(P2P.create(initial_peers=peer1_maddrs))
-        protocol = loop.run_until_complete(
-            DHTProtocol.create(
-                p2p, peer_id, bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, client_mode=client_mode
-            )
-        )
-        logger.info(f"Self id={protocol.node_id}")
-
-        assert loop.run_until_complete(protocol.call_ping(peer1_id)) == peer1_node_id
-
-        key, value, expiration = DHTID.generate(), [random.random(), {"ololo": "pyshpysh"}], get_dht_time() + 1e3
-        store_ok = loop.run_until_complete(
-            protocol.call_store(peer1_id, [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
-        )
-        assert all(store_ok), "DHT rejected a trivial store"
-
-        # peer 1 must know about peer 2
-        (recv_value_bytes, recv_expiration), nodes_found = loop.run_until_complete(
-            protocol.call_find(peer1_id, [key])
-        )[key]
-        recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
-        (recv_id, recv_peer_id) = next(iter(nodes_found.items()))
-        assert (
-            recv_id == peer2_node_id and recv_peer_id == peer2_id
-        ), f"expected id={peer2_node_id}, peer={peer2_id} but got {recv_id}, {recv_peer_id}"
-
-        assert recv_value == value and recv_expiration == expiration, (
-            f"call_find_value expected {value} (expires by {expiration}) "
-            f"but got {recv_value} (expires by {recv_expiration})"
-        )
-
-        # peer 2 must know about peer 1, but not have a *random* nonexistent value
-        dummy_key = DHTID.generate()
-        empty_item, nodes_found_2 = loop.run_until_complete(protocol.call_find(peer2_id, [dummy_key]))[dummy_key]
-        assert empty_item is None, "Non-existent keys shouldn't have values"
-        (recv_id, recv_peer_id) = next(iter(nodes_found_2.items()))
-        assert (
-            recv_id == peer1_node_id and recv_peer_id == peer1_id
-        ), f"expected id={peer1_node_id}, peer={peer1_id} but got {recv_id}, {recv_peer_id}"
-
-        # cause a non-response by querying a nonexistent peer
-        assert loop.run_until_complete(protocol.call_find(PeerID.from_base58("fakeid"), [key])) is None
-
-        # store/get a dictionary with sub-keys
-        nested_key, subkey1, subkey2 = DHTID.generate(), "foo", "bar"
-        value1, value2 = [random.random(), {"ololo": "pyshpysh"}], "abacaba"
-        assert loop.run_until_complete(
-            protocol.call_store(
-                peer1_id,
-                keys=[nested_key],
-                values=[hivemind.MSGPackSerializer.dumps(value1)],
-                expiration_time=[expiration],
-                subkeys=[subkey1],
-            )
-        )
-        assert loop.run_until_complete(
-            protocol.call_store(
-                peer1_id,
-                keys=[nested_key],
-                values=[hivemind.MSGPackSerializer.dumps(value2)],
-                expiration_time=[expiration + 5],
-                subkeys=[subkey2],
-            )
-        )
-        (recv_dict, recv_expiration), nodes_found = loop.run_until_complete(
-            protocol.call_find(peer1_id, [nested_key])
-        )[nested_key]
-        assert isinstance(recv_dict, DictionaryDHTValue)
-        assert len(recv_dict.data) == 2 and recv_expiration == expiration + 5
-        assert recv_dict.data[subkey1] == (protocol.serializer.dumps(value1), expiration)
-        assert recv_dict.data[subkey2] == (protocol.serializer.dumps(value2), expiration + 5)
-
-        if not client_mode:
-            loop.run_until_complete(p2p.shutdown())
-
-    peer1_proc.terminate()
-    peer2_proc.terminate()
-
-
-@pytest.mark.forked
-def test_empty_table():
-    """Test RPC methods with empty routing table"""
-    peer_id, peer_proc, peer_peer_id, peer_maddrs = launch_protocol_listener()
-
-    loop = asyncio.get_event_loop()
-    p2p = loop.run_until_complete(P2P.create(initial_peers=peer_maddrs))
-    protocol = loop.run_until_complete(
-        DHTProtocol.create(
-            p2p, DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, client_mode=True
-        )
-    )
-
-    key, value, expiration = DHTID.generate(), [random.random(), {"ololo": "pyshpysh"}], get_dht_time() + 1e3
-
-    empty_item, nodes_found = loop.run_until_complete(protocol.call_find(peer_peer_id, [key]))[key]
-    assert empty_item is None and len(nodes_found) == 0
-    assert all(
-        loop.run_until_complete(
-            protocol.call_store(peer_peer_id, [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
-        )
-    ), "peer rejected store"
-
-    (recv_value_bytes, recv_expiration), nodes_found = loop.run_until_complete(
-        protocol.call_find(peer_peer_id, [key])
-    )[key]
-    recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
-    assert len(nodes_found) == 0
-    assert recv_value == value and recv_expiration == expiration
-
-    assert loop.run_until_complete(protocol.call_ping(peer_peer_id)) == peer_id
-    assert loop.run_until_complete(protocol.call_ping(PeerID.from_base58("fakeid"))) is None
-    peer_proc.terminate()
-
-
-@pytest.mark.forked
-def test_dht_node(
+@pytest.mark.asyncio
+async def test_dht_node(
     n_peers: int = 20, n_sequential_peers: int = 5, parallel_rpc: int = 10, bucket_size: int = 5, num_replicas: int = 3
     n_peers: int = 20, n_sequential_peers: int = 5, parallel_rpc: int = 10, bucket_size: int = 5, num_replicas: int = 3
 ):
 ):
     # step A: create a swarm of 50 dht nodes in separate processes
     # step A: create a swarm of 50 dht nodes in separate processes
@@ -205,26 +32,23 @@ def test_dht_node(
     )
     )
 
 
     # step B: run 51-st node in this process
     # step B: run 51-st node in this process
-    loop = asyncio.get_event_loop()
     initial_peers = random.choice(swarm_maddrs)
     initial_peers = random.choice(swarm_maddrs)
-    me = loop.run_until_complete(
-        DHTNode.create(
-            initial_peers=initial_peers,
-            parallel_rpc=parallel_rpc,
-            bucket_size=bucket_size,
-            num_replicas=num_replicas,
-            cache_refresh_before_expiry=False,
-        )
+    me = await DHTNode.create(
+        initial_peers=initial_peers,
+        parallel_rpc=parallel_rpc,
+        bucket_size=bucket_size,
+        num_replicas=num_replicas,
+        cache_refresh_before_expiry=False,
     )
     )
 
 
     # test 1: find self
     # test 1: find self
-    nearest = loop.run_until_complete(me.find_nearest_nodes([me.node_id], k_nearest=1))[me.node_id]
+    nearest = (await me.find_nearest_nodes([me.node_id], k_nearest=1))[me.node_id]
     assert len(nearest) == 1 and nearest[me.node_id] == me.peer_id
     assert len(nearest) == 1 and nearest[me.node_id] == me.peer_id
 
 
     # test 2: find others
     # test 2: find others
     for _ in range(10):
     for _ in range(10):
         ref_peer_id, query_id = random.choice(list(dht.items()))
         ref_peer_id, query_id = random.choice(list(dht.items()))
-        nearest = loop.run_until_complete(me.find_nearest_nodes([query_id], k_nearest=1))[query_id]
+        nearest = (await me.find_nearest_nodes([query_id], k_nearest=1))[query_id]
         assert len(nearest) == 1
         assert len(nearest) == 1
         found_node_id, found_peer_id = next(iter(nearest.items()))
         found_node_id, found_peer_id = next(iter(nearest.items()))
         assert found_node_id == query_id and found_peer_id == ref_peer_id
         assert found_node_id == query_id and found_peer_id == ref_peer_id
@@ -238,10 +62,8 @@ def test_dht_node(
         query_id = DHTID.generate()
         query_id = DHTID.generate()
         k_nearest = random.randint(1, 10)
         k_nearest = random.randint(1, 10)
         exclude_self = random.random() > 0.5
         exclude_self = random.random() > 0.5
-        nearest = loop.run_until_complete(
-            me.find_nearest_nodes([query_id], k_nearest=k_nearest, exclude_self=exclude_self)
-        )[query_id]
-        nearest_nodes = list(nearest)  # keys from ordered dict
+        find_result = await me.find_nearest_nodes([query_id], k_nearest=k_nearest, exclude_self=exclude_self)
+        nearest_nodes = list(find_result[query_id])  # keys from ordered dict
 
 
         assert len(nearest_nodes) == k_nearest, "beam search must return exactly k_nearest results"
         assert len(nearest_nodes) == k_nearest, "beam search must return exactly k_nearest results"
         assert me.node_id not in nearest_nodes or not exclude_self, "if exclude, results shouldn't contain self"
         assert me.node_id not in nearest_nodes or not exclude_self, "if exclude, results shouldn't contain self"
@@ -268,65 +90,63 @@ def test_dht_node(
 
 
     # test 4: find all nodes
     # test 4: find all nodes
     dummy = DHTID.generate()
     dummy = DHTID.generate()
-    nearest = loop.run_until_complete(me.find_nearest_nodes([dummy], k_nearest=len(dht) + 100))[dummy]
+    nearest = (await me.find_nearest_nodes([dummy], k_nearest=len(dht) + 100))[dummy]
     assert len(nearest) == len(dht) + 1
     assert len(nearest) == len(dht) + 1
     assert len(set.difference(set(nearest.keys()), set(all_node_ids) | {me.node_id})) == 0
     assert len(set.difference(set(nearest.keys()), set(all_node_ids) | {me.node_id})) == 0
 
 
     # test 5: node without peers
     # test 5: node without peers
-    detached_node = loop.run_until_complete(DHTNode.create())
-    nearest = loop.run_until_complete(detached_node.find_nearest_nodes([dummy]))[dummy]
+    detached_node = await DHTNode.create()
+    nearest = (await detached_node.find_nearest_nodes([dummy]))[dummy]
     assert len(nearest) == 1 and nearest[detached_node.node_id] == detached_node.peer_id
     assert len(nearest) == 1 and nearest[detached_node.node_id] == detached_node.peer_id
-    nearest = loop.run_until_complete(detached_node.find_nearest_nodes([dummy], exclude_self=True))[dummy]
+    nearest = (await detached_node.find_nearest_nodes([dummy], exclude_self=True))[dummy]
     assert len(nearest) == 0
     assert len(nearest) == 0
 
 
     # test 6: store and get value
     # test 6: store and get value
     true_time = get_dht_time() + 1200
     true_time = get_dht_time() + 1200
-    assert loop.run_until_complete(me.store("mykey", ["Value", 10], true_time))
+    assert await me.store("mykey", ["Value", 10], true_time)
 
 
     initial_peers = random.choice(swarm_maddrs)
     initial_peers = random.choice(swarm_maddrs)
-    that_guy = loop.run_until_complete(
-        DHTNode.create(
-            initial_peers=initial_peers,
-            parallel_rpc=parallel_rpc,
-            cache_refresh_before_expiry=False,
-            cache_locally=False,
-        )
+    that_guy = await DHTNode.create(
+        initial_peers=initial_peers,
+        parallel_rpc=parallel_rpc,
+        cache_refresh_before_expiry=False,
+        cache_locally=False,
     )
     )
 
 
     for node in [me, that_guy]:
     for node in [me, that_guy]:
-        val, expiration_time = loop.run_until_complete(node.get("mykey"))
+        val, expiration_time = await node.get("mykey")
         assert val == ["Value", 10], "Wrong value"
         assert val == ["Value", 10], "Wrong value"
         assert expiration_time == true_time, f"Wrong time"
         assert expiration_time == true_time, f"Wrong time"
 
 
-    assert loop.run_until_complete(detached_node.get("mykey")) is None
+    assert not await detached_node.get("mykey")
 
 
     # test 7: bulk store and bulk get
     # test 7: bulk store and bulk get
     keys = "foo", "bar", "baz", "zzz"
     keys = "foo", "bar", "baz", "zzz"
     values = 3, 2, "batman", [1, 2, 3]
     values = 3, 2, "batman", [1, 2, 3]
-    store_ok = loop.run_until_complete(me.store_many(keys, values, expiration_time=get_dht_time() + 999))
+    store_ok = await me.store_many(keys, values, expiration_time=get_dht_time() + 999)
     assert all(store_ok.values()), "failed to store one or more keys"
     assert all(store_ok.values()), "failed to store one or more keys"
-    response = loop.run_until_complete(me.get_many(keys[::-1]))
+    response = await me.get_many(keys[::-1])
     for key, value in zip(keys, values):
     for key, value in zip(keys, values):
         assert key in response and response[key][0] == value
         assert key in response and response[key][0] == value
 
 
     # test 8: store dictionaries as values (with sub-keys)
     # test 8: store dictionaries as values (with sub-keys)
     upper_key, subkey1, subkey2, subkey3 = "ololo", "k1", "k2", "k3"
     upper_key, subkey1, subkey2, subkey3 = "ololo", "k1", "k2", "k3"
     now = get_dht_time()
     now = get_dht_time()
-    assert loop.run_until_complete(me.store(upper_key, subkey=subkey1, value=123, expiration_time=now + 10))
-    assert loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=456, expiration_time=now + 20))
+    assert await me.store(upper_key, subkey=subkey1, value=123, expiration_time=now + 10)
+    assert await me.store(upper_key, subkey=subkey2, value=456, expiration_time=now + 20)
     for node in [that_guy, me]:
     for node in [that_guy, me]:
-        value, time = loop.run_until_complete(node.get(upper_key))
+        value, time = await node.get(upper_key)
         assert isinstance(value, dict) and time == now + 20
         assert isinstance(value, dict) and time == now + 20
         assert value[subkey1] == (123, now + 10)
         assert value[subkey1] == (123, now + 10)
         assert value[subkey2] == (456, now + 20)
         assert value[subkey2] == (456, now + 20)
         assert len(value) == 2
         assert len(value) == 2
 
 
-    assert not loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=345, expiration_time=now + 10))
-    assert loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=567, expiration_time=now + 30))
-    assert loop.run_until_complete(me.store(upper_key, subkey=subkey3, value=890, expiration_time=now + 50))
+    assert not await me.store(upper_key, subkey=subkey2, value=345, expiration_time=now + 10)
+    assert await me.store(upper_key, subkey=subkey2, value=567, expiration_time=now + 30)
+    assert await me.store(upper_key, subkey=subkey3, value=890, expiration_time=now + 50)
 
 
     for node in [that_guy, me]:
     for node in [that_guy, me]:
-        value, time = loop.run_until_complete(node.get(upper_key, latest=True))
+        value, time = await node.get(upper_key, latest=True)
         assert isinstance(value, dict) and time == now + 50, (value, time)
         assert isinstance(value, dict) and time == now + 50, (value, time)
         assert value[subkey1] == (123, now + 10)
         assert value[subkey1] == (123, now + 10)
         assert value[subkey2] == (567, now + 30)
         assert value[subkey2] == (567, now + 30)
@@ -336,7 +156,7 @@ def test_dht_node(
     for proc in processes:
     for proc in processes:
         proc.terminate()
         proc.terminate()
     # The nodes don't own their hivemind.p2p.P2P instances, so we shutdown them separately
     # The nodes don't own their hivemind.p2p.P2P instances, so we shutdown them separately
-    loop.run_until_complete(asyncio.wait([node.shutdown() for node in [me, detached_node, that_guy]]))
+    await asyncio.gather(me.shutdown(), that_guy.shutdown(), detached_node.shutdown())
 
 
 
 
 @pytest.mark.forked
 @pytest.mark.forked
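
The test body is converted from explicit loop.run_until_complete calls to a coroutine driven by pytest-asyncio; the same pattern in miniature:

    import asyncio

    import pytest


    @pytest.mark.asyncio
    async def test_example():
        # previously: loop.run_until_complete(asyncio.sleep(0))
        await asyncio.sleep(0)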

+ 163 - 0
tests/test_dht_protocol.py

@@ -0,0 +1,163 @@
+import asyncio
+import multiprocessing as mp
+import random
+import signal
+from typing import List, Sequence, Tuple
+
+import pytest
+from multiaddr import Multiaddr
+
+import hivemind
+from hivemind import P2P, PeerID, get_dht_time, get_logger
+from hivemind.dht import DHTID
+from hivemind.dht.protocol import DHTProtocol
+from hivemind.dht.storage import DictionaryDHTValue
+
+logger = get_logger(__name__)
+
+
+def maddrs_to_peer_ids(maddrs: List[Multiaddr]) -> List[PeerID]:
+    return list({PeerID.from_base58(maddr["p2p"]) for maddr in maddrs})
+
+
+def run_protocol_listener(
+    dhtid: DHTID, maddr_conn: mp.connection.Connection, initial_peers: Sequence[Multiaddr]
+) -> None:
+    loop = asyncio.new_event_loop()
+    asyncio.set_event_loop(loop)
+
+    p2p = loop.run_until_complete(P2P.create(initial_peers=initial_peers))
+    visible_maddrs = loop.run_until_complete(p2p.get_visible_maddrs())
+
+    protocol = loop.run_until_complete(
+        DHTProtocol.create(p2p, dhtid, bucket_size=20, depth_modulo=5, num_replicas=3, wait_timeout=5)
+    )
+
+    logger.info(f"Started peer id={protocol.node_id} visible_maddrs={visible_maddrs}")
+
+    for peer_id in maddrs_to_peer_ids(initial_peers):
+        loop.run_until_complete(protocol.call_ping(peer_id))
+
+    maddr_conn.send((p2p.peer_id, visible_maddrs))
+
+    async def shutdown():
+        await p2p.shutdown()
+        logger.info(f"Finished peer id={protocol.node_id} maddrs={visible_maddrs}")
+        loop.stop()
+
+    loop.add_signal_handler(signal.SIGTERM, lambda: loop.create_task(shutdown()))
+    loop.run_forever()
+
+
+def launch_protocol_listener(
+    initial_peers: Sequence[Multiaddr] = (),
+) -> Tuple[DHTID, mp.Process, PeerID, List[Multiaddr]]:
+    remote_conn, local_conn = mp.Pipe()
+    dht_id = DHTID.generate()
+    process = mp.Process(target=run_protocol_listener, args=(dht_id, remote_conn, initial_peers), daemon=True)
+    process.start()
+    peer_id, visible_maddrs = local_conn.recv()
+
+    return dht_id, process, peer_id, visible_maddrs
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_dht_protocol():
+    peer1_node_id, peer1_proc, peer1_id, peer1_maddrs = launch_protocol_listener()
+    peer2_node_id, peer2_proc, peer2_id, _ = launch_protocol_listener(initial_peers=peer1_maddrs)
+
+    for client_mode in [True, False]:  # note: order matters, this test assumes that first run uses client mode
+        peer_id = DHTID.generate()
+        p2p = await P2P.create(initial_peers=peer1_maddrs)
+        protocol = await DHTProtocol.create(
+            p2p, peer_id, bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, client_mode=client_mode
+        )
+        logger.info(f"Self id={protocol.node_id}")
+
+        assert peer1_node_id == await protocol.call_ping(peer1_id)
+
+        key, value, expiration = DHTID.generate(), [random.random(), {"ololo": "pyshpysh"}], get_dht_time() + 1e3
+        store_ok = await protocol.call_store(peer1_id, [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
+        assert all(store_ok), "DHT rejected a trivial store"
+
+        # peer 1 must know about peer 2
+        (recv_value_bytes, recv_expiration), nodes_found = (await protocol.call_find(peer1_id, [key]))[key]
+        recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
+        (recv_id, recv_peer_id) = next(iter(nodes_found.items()))
+        assert (
+            recv_id == peer2_node_id and recv_peer_id == peer2_id
+        ), f"expected id={peer2_node_id}, peer={peer2_id} but got {recv_id}, {recv_peer_id}"
+
+        assert recv_value == value and recv_expiration == expiration, (
+            f"call_find_value expected {value} (expires by {expiration}) "
+            f"but got {recv_value} (expires by {recv_expiration})"
+        )
+
+        # peer 2 must know about peer 1, but not have a *random* nonexistent value
+        dummy_key = DHTID.generate()
+        empty_item, nodes_found_2 = (await protocol.call_find(peer2_id, [dummy_key]))[dummy_key]
+        assert empty_item is None, "Non-existent keys shouldn't have values"
+        (recv_id, recv_peer_id) = next(iter(nodes_found_2.items()))
+        assert (
+            recv_id == peer1_node_id and recv_peer_id == peer1_id
+        ), f"expected id={peer1_node_id}, peer={peer1_id} but got {recv_id}, {recv_peer_id}"
+
+        # cause a non-response by querying a nonexistent peer
+        assert not await protocol.call_find(PeerID.from_base58("fakeid"), [key])
+
+        # store/get a dictionary with sub-keys
+        nested_key, subkey1, subkey2 = DHTID.generate(), "foo", "bar"
+        value1, value2 = [random.random(), {"ololo": "pyshpysh"}], "abacaba"
+        assert await protocol.call_store(
+            peer1_id,
+            keys=[nested_key],
+            values=[hivemind.MSGPackSerializer.dumps(value1)],
+            expiration_time=[expiration],
+            subkeys=[subkey1],
+        )
+        assert await protocol.call_store(
+            peer1_id,
+            keys=[nested_key],
+            values=[hivemind.MSGPackSerializer.dumps(value2)],
+            expiration_time=[expiration + 5],
+            subkeys=[subkey2],
+        )
+        (recv_dict, recv_expiration), nodes_found = (await protocol.call_find(peer1_id, [nested_key]))[nested_key]
+        assert isinstance(recv_dict, DictionaryDHTValue)
+        assert len(recv_dict.data) == 2 and recv_expiration == expiration + 5
+        assert recv_dict.data[subkey1] == (protocol.serializer.dumps(value1), expiration)
+        assert recv_dict.data[subkey2] == (protocol.serializer.dumps(value2), expiration + 5)
+
+        if not client_mode:
+            await p2p.shutdown()
+
+    peer1_proc.terminate()
+    peer2_proc.terminate()
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_empty_table():
+    """Test RPC methods with empty routing table"""
+    peer_id, peer_proc, peer_peer_id, peer_maddrs = launch_protocol_listener()
+
+    p2p = await P2P.create(initial_peers=peer_maddrs)
+    protocol = await DHTProtocol.create(
+        p2p, DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, client_mode=True
+    )
+
+    key, value, expiration = DHTID.generate(), [random.random(), {"ololo": "pyshpysh"}], get_dht_time() + 1e3
+
+    empty_item, nodes_found = (await protocol.call_find(peer_peer_id, [key]))[key]
+    assert empty_item is None and len(nodes_found) == 0
+    assert all(await protocol.call_store(peer_peer_id, [key], [hivemind.MSGPackSerializer.dumps(value)], expiration))
+
+    (recv_value_bytes, recv_expiration), nodes_found = (await protocol.call_find(peer_peer_id, [key]))[key]
+    recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
+    assert len(nodes_found) == 0
+    assert recv_value == value and recv_expiration == expiration
+
+    assert peer_id == await protocol.call_ping(peer_peer_id)
+    assert not await protocol.call_ping(PeerID.from_base58("fakeid"))
+    peer_proc.terminate()

+ 30 - 5
tests/test_p2p_daemon.py

@@ -10,7 +10,7 @@ import pytest
 from multiaddr import Multiaddr
 from multiaddr import Multiaddr
 
 
 from hivemind.p2p import P2P, P2PDaemonError, P2PHandlerError
 from hivemind.p2p import P2P, P2PDaemonError, P2PHandlerError
-from hivemind.proto import dht_pb2
+from hivemind.proto import dht_pb2, test_pb2
 from hivemind.utils.networking import get_free_port
 from hivemind.utils.networking import get_free_port
 from hivemind.utils.serializer import MSGPackSerializer
 from hivemind.utils.serializer import MSGPackSerializer
 
 
@@ -36,13 +36,13 @@ async def test_daemon_killed_on_del():
 
 
 @pytest.mark.asyncio
 @pytest.mark.asyncio
 async def test_startup_error_message():
 async def test_startup_error_message():
-    with pytest.raises(P2PDaemonError, match=r"Failed to connect to bootstrap peers"):
+    with pytest.raises(P2PDaemonError, match=r"(?i)Failed to connect to bootstrap peers"):
         await P2P.create(
         await P2P.create(
             initial_peers=[f"/ip4/127.0.0.1/tcp/{get_free_port()}/p2p/QmdaK4LUeQaKhqSFPRu9N7MvXUEWDxWwtCvPrS444tCgd1"]
             initial_peers=[f"/ip4/127.0.0.1/tcp/{get_free_port()}/p2p/QmdaK4LUeQaKhqSFPRu9N7MvXUEWDxWwtCvPrS444tCgd1"]
         )
         )
 
 
     with pytest.raises(P2PDaemonError, match=r"Daemon failed to start in .+ seconds"):
     with pytest.raises(P2PDaemonError, match=r"Daemon failed to start in .+ seconds"):
-        await P2P.create(startup_timeout=0.1)  # Test that startup_timeout works
+        await P2P.create(startup_timeout=0.01)  # Test that startup_timeout works
 
 
 
 
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(
@@ -63,9 +63,9 @@ async def test_transports(host_maddrs: List[Multiaddr]):
     await client.wait_for_at_least_n_peers(1)
     await client.wait_for_at_least_n_peers(1)
 
 
     peers = await client.list_peers()
     peers = await client.list_peers()
-    assert len(peers) == 1
+    assert len({p.peer_id for p in peers}) == 1
     peers = await server.list_peers()
     peers = await server.list_peers()
-    assert len(peers) == 1
+    assert len({p.peer_id for p in peers}) == 1
 
 
 
 
 @pytest.mark.asyncio
 @pytest.mark.asyncio
@@ -83,6 +83,31 @@ async def test_daemon_replica_does_not_affect_primary():
     assert not is_process_running(child_pid)
     assert not is_process_running(child_pid)
 
 
 
 
+@pytest.mark.asyncio
+async def test_unary_handler_edge_cases():
+    p2p = await P2P.create()
+    p2p_replica = await P2P.replicate(p2p.daemon_listen_maddr)
+
+    async def square_handler(data: test_pb2.TestRequest, context):
+        return test_pb2.TestResponse(number=data.number ** 2)
+
+    await p2p.add_protobuf_handler("square", square_handler, test_pb2.TestRequest)
+
+    # try adding a duplicate handler
+    with pytest.raises(P2PDaemonError):
+        await p2p.add_protobuf_handler("square", square_handler, test_pb2.TestRequest)
+
+    # try adding a duplicate handler from replicated p2p
+    with pytest.raises(P2PDaemonError):
+        await p2p_replica.add_protobuf_handler("square", square_handler, test_pb2.TestRequest)
+
+    # try dialing yourself
+    with pytest.raises(P2PDaemonError):
+        await p2p_replica.call_protobuf_handler(
+            p2p.peer_id, "square", test_pb2.TestRequest(number=41), test_pb2.TestResponse
+        )
+
+
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(
     "should_cancel,replicate",
     "should_cancel,replicate",
     [
     [

+ 21 - 13
tests/test_p2p_daemon_bindings.py

@@ -18,7 +18,7 @@ from hivemind.p2p.p2p_daemon_bindings.utils import (
 )
 from hivemind.proto import p2pd_pb2 as p2pd_pb

-from test_utils.p2p_daemon import connect_safe, make_p2pd_pair_ip4
+from test_utils.p2p_daemon import connect_safe, make_p2pd_pair_unix


 def test_raise_if_failed_raises():
@@ -61,7 +61,7 @@ ENABLE_CONTROL = True
 ENABLE_CONNMGR = False
 ENABLE_DHT = False
 ENABLE_PUBSUB = False
-FUNC_MAKE_P2PD_PAIR = make_p2pd_pair_ip4
+FUNC_MAKE_P2PD_PAIR = make_p2pd_pair_unix


 class MockReader(io.BytesIO):
@@ -144,6 +144,12 @@ def test_peer_id():
     peer_id_3 = PeerID.from_base58("QmbmfNDEth7Ucvjuxiw3SP3E4PoJzbk7g4Ge6ZDigbCsNp")
     assert PEER_ID != peer_id_3

+    a = PeerID.from_base58("bob")
+    b = PeerID.from_base58("eve")
+    assert a < b and b > a and not (b < a) and not (a > b)
+    with pytest.raises(TypeError):
+        assert a < object()
+

 def test_stream_info():
     proto = "123"
@@ -193,24 +199,31 @@ def test_parse_conn_protocol_invalid(maddr_str):


 @pytest.mark.parametrize("control_maddr_str", ("/unix/123", "/ip4/127.0.0.1/tcp/6666"))
-def test_client_ctor_control_maddr(control_maddr_str):
+@pytest.mark.asyncio
+async def test_client_create_control_maddr(control_maddr_str):
     c = DaemonConnector(Multiaddr(control_maddr_str))
     assert c.control_maddr == Multiaddr(control_maddr_str)


-def test_client_ctor_default_control_maddr():
+def test_client_create_default_control_maddr():
     c = DaemonConnector()
     assert c.control_maddr == Multiaddr(DaemonConnector.DEFAULT_CONTROL_MADDR)


 @pytest.mark.parametrize("listen_maddr_str", ("/unix/123", "/ip4/127.0.0.1/tcp/6666"))
-def test_control_client_ctor_listen_maddr(listen_maddr_str):
-    c = ControlClient(daemon_connector=DaemonConnector(), listen_maddr=Multiaddr(listen_maddr_str))
+@pytest.mark.asyncio
+async def test_control_client_create_listen_maddr(listen_maddr_str):
+    c = await ControlClient.create(
+        daemon_connector=DaemonConnector(),
+        listen_maddr=Multiaddr(listen_maddr_str),
+        use_persistent_conn=False,
+    )
     assert c.listen_maddr == Multiaddr(listen_maddr_str)


-def test_control_client_ctor_default_listen_maddr():
-    c = ControlClient(daemon_connector=DaemonConnector())
+@pytest.mark.asyncio
+async def test_control_client_create_default_listen_maddr():
+    c = await ControlClient.create(daemon_connector=DaemonConnector(), use_persistent_conn=False)
     assert c.listen_maddr == Multiaddr(ControlClient.DEFAULT_LISTEN_MADDR)


@@ -376,11 +389,6 @@ async def p2pcs():
         yield tuple(p2pd_tuple.client for p2pd_tuple in p2pd_tuples)


-@pytest.mark.asyncio
-async def test_client_identify_unix_socket(p2pcs):
-    await p2pcs[0].identify()
-
-
 @pytest.mark.asyncio
 async def test_client_identify(p2pcs):
     await p2pcs[0].identify()

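The new assertions in test_peer_id above pin down an ordering contract for PeerID: instances compare against each other (presumably so peer collections can be sorted deterministically), while comparison with an unrelated type raises TypeError. A minimal sketch of that contract, with the import path assumed to match the bindings module these tests use:

from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID  # import path assumed

a = PeerID.from_base58("bob")
b = PeerID.from_base58("eve")

# PeerID instances are mutually ordered, so they can be compared and sorted
assert a < b and b > a
assert sorted([b, a]) == [a, b]

# comparing against a non-PeerID object is rejected instead of being silently ordered
try:
    a < object()
except TypeError:
    pass
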
+ 3 - 2
tests/test_p2p_servicer.py

@@ -5,6 +5,7 @@ import pytest

 from hivemind.p2p import P2P, P2PContext, ServicerBase
 from hivemind.proto import test_pb2
+from hivemind.utils.asyncio import anext


 @pytest.fixture
@@ -139,9 +140,9 @@ async def test_unary_stream_cancel(server_client, cancel_reason):
         writer.close()
     elif cancel_reason == "close_generator":
         stub = ExampleServicer.get_stub(client, server.peer_id)
-        iter = stub.rpc_wait(test_pb2.TestRequest(number=10)).__aiter__()
+        iter = stub.rpc_wait(test_pb2.TestRequest(number=10))

-        assert await iter.__anext__() == test_pb2.TestResponse(number=11)
+        assert await anext(iter) == test_pb2.TestResponse(number=11)
         await asyncio.sleep(0.25)

         await iter.aclose()

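test_unary_stream_cancel now consumes the response stream through the anext helper from hivemind.utils.asyncio instead of calling __aiter__/__anext__ directly. A minimal sketch of that usage pattern; the async generator below is only a stand-in for the stream returned by stub.rpc_wait:

import asyncio

from hivemind.utils.asyncio import anext


async def responses():
    # stand-in for the rpc_wait response stream
    for number in (11, 12, 13):
        yield number
        await asyncio.sleep(0.1)


async def main():
    stream = responses()
    assert await anext(stream) == 11  # same as `await stream.__anext__()`
    await stream.aclose()  # closing the generator cancels the rest of the stream


asyncio.run(main())
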
+ 54 - 1
tests/test_util_modules.py

@@ -13,7 +13,17 @@ from hivemind.proto.dht_pb2_grpc import DHTStub
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
 from hivemind.utils import DHTExpiration, HeapEntry, MSGPackSerializer, ValueWithExpiration
-from hivemind.utils.asyncio import achain, aenumerate, aiter, amap_in_executor, anext, azip
+from hivemind.utils.asyncio import (
+    achain,
+    aenumerate,
+    afirst,
+    aiter,
+    amap_in_executor,
+    anext,
+    asingle,
+    azip,
+    cancel_and_wait,
+)
 from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.mpfuture import InvalidStateError

@@ -498,3 +508,46 @@ async def test_asyncio_utils():
         await anext(iterator)

     assert [item async for item in achain(_aiterate(), aiter(*range(5)))] == ["foo", "bar", "baz"] + list(range(5))
+
+    assert await asingle(aiter(1)) == 1
+    with pytest.raises(ValueError):
+        await asingle(aiter())
+    with pytest.raises(ValueError):
+        await asingle(aiter(1, 2, 3))
+
+    assert await afirst(aiter(1)) == 1
+    assert await afirst(aiter()) is None
+    assert await afirst(aiter(), -1) == -1
+    assert await afirst(aiter(1, 2, 3)) == 1
+
+
+@pytest.mark.asyncio
+async def test_cancel_and_wait():
+    finished_gracefully = False
+
+    async def coro_with_finalizer():
+        nonlocal finished_gracefully
+
+        try:
+            await asyncio.Event().wait()
+        except asyncio.CancelledError:
+            await asyncio.sleep(0.05)
+            finished_gracefully = True
+            raise
+
+    task = asyncio.create_task(coro_with_finalizer())
+    await asyncio.sleep(0.05)
+    assert await cancel_and_wait(task)
+    assert finished_gracefully
+
+    async def coro_with_result():
+        return 777
+
+    async def coro_with_error():
+        raise ValueError("error")
+
+    task_with_result = asyncio.create_task(coro_with_result())
+    task_with_error = asyncio.create_task(coro_with_error())
+    await asyncio.sleep(0.05)
+    assert not await cancel_and_wait(task_with_result)
+    assert not await cancel_and_wait(task_with_error)

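Judging by the assertions in test_cancel_and_wait above, cancel_and_wait cancels a task, waits for its cancellation handlers to run, and returns True only if the task was actually cancelled (False if it had already finished with a result or an error). A sketch of the shutdown pattern this enables; the heartbeat task here is purely hypothetical:

import asyncio

from hivemind.utils.asyncio import cancel_and_wait


async def heartbeat():
    # hypothetical background task that must run cleanup before exiting
    try:
        while True:
            await asyncio.sleep(1.0)
    except asyncio.CancelledError:
        await asyncio.sleep(0.01)  # e.g. flush pending state
        raise


async def main():
    task = asyncio.create_task(heartbeat())
    await asyncio.sleep(0.1)
    assert await cancel_and_wait(task)  # True: the task was running and got cancelled


asyncio.run(main())
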
+ 17 - 20
tests/test_utils/p2p_daemon.py

@@ -4,7 +4,7 @@ import os
 import subprocess
 import time
 import uuid
-from contextlib import asynccontextmanager
+from contextlib import asynccontextmanager, suppress
 from typing import NamedTuple

 from multiaddr import Multiaddr, protocols
@@ -57,7 +57,7 @@ class Daemon:

     def _run(self):
         cmd_list = [P2PD_PATH, f"-listen={str(self.control_maddr)}"]
-        cmd_list += [f"-hostAddrs=/ip4/127.0.0.1/tcp/{get_free_port()}"]
+        cmd_list += ["-hostAddrs=/ip4/127.0.0.1/tcp/0"]
         if self.enable_connmgr:
             cmd_list += ["-connManager=true", "-connLo=1", "-connHi=2", "-connGrace=0"]
         if self.enable_dht:
@@ -107,24 +107,21 @@ async def make_p2pd_pair_unix(enable_control, enable_connmgr, enable_dht, enable
     name = str(uuid.uuid4())[:8]
     control_maddr = Multiaddr(f"/unix/tmp/test_p2pd_control_{name}.sock")
     listen_maddr = Multiaddr(f"/unix/tmp/test_p2pd_listen_{name}.sock")
-    # Remove the existing unix socket files if they are existing
     try:
-        os.unlink(control_maddr.value_for_protocol(protocols.P_UNIX))
-    except FileNotFoundError:
-        pass
-    try:
-        os.unlink(listen_maddr.value_for_protocol(protocols.P_UNIX))
-    except FileNotFoundError:
-        pass
-    async with _make_p2pd_pair(
-        control_maddr=control_maddr,
-        listen_maddr=listen_maddr,
-        enable_control=enable_control,
-        enable_connmgr=enable_connmgr,
-        enable_dht=enable_dht,
-        enable_pubsub=enable_pubsub,
-    ) as pair:
-        yield pair
+        async with _make_p2pd_pair(
+            control_maddr=control_maddr,
+            listen_maddr=listen_maddr,
+            enable_control=enable_control,
+            enable_connmgr=enable_connmgr,
+            enable_dht=enable_dht,
+            enable_pubsub=enable_pubsub,
+        ) as pair:
+            yield pair
+    finally:
+        with suppress(FileNotFoundError):
+            os.unlink(control_maddr.value_for_protocol(protocols.P_UNIX))
+        with suppress(FileNotFoundError):
+            os.unlink(listen_maddr.value_for_protocol(protocols.P_UNIX))


 @asynccontextmanager
@@ -160,7 +157,7 @@ async def _make_p2pd_pair(
     )
     # wait for daemon ready
     await p2pd.wait_until_ready()
-    client = Client(control_maddr=control_maddr, listen_maddr=listen_maddr)
+    client = await Client.create(control_maddr=control_maddr, listen_maddr=listen_maddr)
     try:
         async with client.listen():
             yield DaemonTuple(daemon=p2pd, client=client)
             yield DaemonTuple(daemon=p2pd, client=client)