Browse Source

Merge remote-tracking branch 'origin/master' into TPU

Michael Diskin 4 years ago
parent
commit
a2347ece86

+ 0 - 2
.github/workflows/run-tests.yml

@@ -34,7 +34,6 @@ jobs:
        run: |
          cd tests
          pytest --durations=0 --durations-min=1.0 -v
-
  build_and_test_p2pd:
    runs-on: ubuntu-latest
    timeout-minutes: 10
@@ -61,7 +60,6 @@ jobs:
        run: |
          cd tests
          pytest -k "p2p" -v
-
  codecov_in_develop_mode:

    runs-on: ubuntu-latest

+ 1 - 0
.gitignore

@@ -54,6 +54,7 @@ coverage.xml
.project
.pydevproject
.idea
+.vscode
.ipynb_checkpoints

# Rope

+ 7 - 0
examples/albert/arguments.py

@@ -34,6 +34,13 @@ class BaseTrainingArguments:
        default_factory=list,
        metadata={"help": "Visible multiaddrs the host announces for external connections from other p2p instances"},
    )
+    identity_path: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Path to a pre-generated private key file. If defined, makes the peer ID deterministic. "
+            "May be generated using ``./p2p-keygen`` from ``go-libp2p-daemon``."
+        },
+    )


@dataclass

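A note on the new option: identity_path points to a pre-generated libp2p private key, which makes the peer ID deterministic across restarts. A minimal sketch of how it can be used (the file name "p2p.key" is a placeholder; the key itself would be generated with ./p2p-keygen from go-libp2p-daemon, as the help string says):

    from hivemind import DHT

    # "p2p.key" is a hypothetical path to a key made with ./p2p-keygen;
    # DHT forwards identity_path to the underlying p2p daemon (see run_trainer.py below)
    dht = DHT(host_maddrs=["/ip4/127.0.0.1/tcp/0"], identity_path="p2p.key", start=True)
    print(dht.peer_id)  # stays the same across restarts because the private key is fixed
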
+ 1 - 0
examples/albert/run_trainer.py

@@ -247,6 +247,7 @@ def main():
        use_ipfs=collaboration_args.use_ipfs,
        host_maddrs=collaboration_args.host_maddrs,
        announce_maddrs=collaboration_args.announce_maddrs,
+        identity_path=collaboration_args.identity_path,
    )
    utils.log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=collaboration_args.use_ipfs)


+ 1 - 0
examples/albert/run_training_monitor.py

@@ -168,6 +168,7 @@ if __name__ == "__main__":
        use_ipfs=monitor_args.use_ipfs,
        host_maddrs=monitor_args.host_maddrs,
        announce_maddrs=monitor_args.announce_maddrs,
+        identity_path=monitor_args.identity_path,
    )
    utils.log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=monitor_args.use_ipfs)


+ 1 - 1
hivemind/__init__.py

@@ -19,4 +19,4 @@ from hivemind.optim import (
from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, PeerInfo
from hivemind.utils import *

-__version__ = "1.0.0.dev0"
+__version__ = "1.0.0dev0"

+ 1 - 1
hivemind/averaging/allreduce.py

@@ -8,7 +8,7 @@ from hivemind.averaging.partition import AllreduceException, TensorPartContainer
from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase, StubBase
from hivemind.proto import averaging_pb2
from hivemind.utils import get_logger
-from hivemind.utils.asyncio import achain, aenumerate, aiter, amap_in_executor, anext, asingle
+from hivemind.utils.asyncio import achain, aenumerate, afirst, aiter, amap_in_executor, anext
from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor

# flavour types

+ 8 - 3
hivemind/averaging/averager.py

@@ -96,7 +96,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
        prefix: str,
        target_group_size: int,
        min_group_size: int = 2,
-        initial_group_bits: Optional[str] = None,
+        initial_group_bits: str = "",
        averaging_expiration: float = 15,
        request_timeout: float = 3,
        averaging_alpha: float = 1.0,
@@ -117,7 +117,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         ), "bandwidth must be a non-negative float32"
         ), "bandwidth must be a non-negative float32"
         if not is_power_of_two(target_group_size):
         if not is_power_of_two(target_group_size):
             logger.warning("It is recommended to set target_group_size to a power of 2.")
             logger.warning("It is recommended to set target_group_size to a power of 2.")
-        assert initial_group_bits is None or all(bit in "01" for bit in initial_group_bits)
+        assert all(bit in "01" for bit in initial_group_bits)
        assert not client_mode or not auxiliary, "auxiliary peers must accept incoming connections"

        super().__init__()
@@ -241,7 +241,12 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                self._ready.set_result(None)

                while True:
-                    method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
+                    try:
+                        method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
+                    except (OSError, ConnectionError) as e:
+                        logger.exception(e)
+                        await asyncio.sleep(self._matchmaking.request_timeout)
+                        continue
                    task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
                    if method == "_shutdown":
                        await task

+ 12 - 93
hivemind/averaging/key_manager.py

@@ -1,4 +1,3 @@
-import asyncio
import random
import re
from typing import List, Optional, Tuple
@@ -25,31 +24,17 @@ class GroupKeyManager:
    Utility class that declares and fetches averaging-related keys using a DHT
    """

-    RESERVED_KEY_FOR_NBITS = "::NBITS"
-
    def __init__(
        self,
        dht: DHT,
        prefix: str,
-        initial_group_bits: Optional[str],
+        initial_group_bits: str,
        target_group_size: int,
-        insufficient_size: Optional[int] = None,
-        excessive_size: Optional[int] = None,
-        nbits_expiration: float = 60,
-        nbits_rewrite_grace_period: float = 15,
    ):
-        assert initial_group_bits is None or all(bit in "01" for bit in initial_group_bits)
-        if initial_group_bits is None:
-            search_result = dht.get(f"{prefix}.0b", latest=True)
-            initial_group_nbits = self.get_suggested_nbits(search_result) or 0
-            initial_group_bits = "".join(random.choice("01") for _ in range(initial_group_nbits))
+        assert all(bit in "01" for bit in initial_group_bits)
        self.dht, self.prefix, self.group_bits = dht, prefix, initial_group_bits
-        self.peer_id = dht.peer_id
        self.target_group_size = target_group_size
-        self.insufficient_size = insufficient_size or max(1, target_group_size // 2)
-        self.excessive_size = excessive_size or target_group_size * 3
-        self.nbits_expiration, self.nbits_grace_period = nbits_expiration, nbits_rewrite_grace_period
-        self.suggested_nbits: Optional[int] = None
+        self.peer_id = dht.peer_id

    @property
    def current_key(self) -> GroupKey:
@@ -93,51 +78,16 @@ class GroupKeyManager:
        if result is None or not isinstance(result.value, dict):
            logger.debug(f"Allreduce group not found: {group_key}, creating new group.")
            return []
-        averagers = [
-            (PeerID(key), looking_for_group.expiration_time)
-            for key, looking_for_group in result.value.items()
-            if key != self.RESERVED_KEY_FOR_NBITS and (not only_active or looking_for_group.value)
-        ]
-        num_active_averagers = sum(
-            1
-            for key, looking_for_group in result.value.items()
-            if key != self.RESERVED_KEY_FOR_NBITS and looking_for_group.value
-        )
-
-        suggested_nbits = self.get_suggested_nbits(result)
-        if (
-            suggested_nbits is not None
-            and suggested_nbits != len(self.group_bits)
-            and suggested_nbits != self.suggested_nbits
-        ):
-            self.suggested_nbits = suggested_nbits
-            logger.warning(f"{self.peer_id} - another averager suggested {self.suggested_nbits}-bit keys")
-        elif num_active_averagers >= self.excessive_size:
-            self.suggested_nbits = max(suggested_nbits or 0, len(self.group_bits) + 1)
-            logger.warning(f"{self.peer_id} - too many peers in bucket, switching to {self.suggested_nbits}-bit keys")
+        averagers = []
+        for key, looking_for_group in result.value.items():
+            try:
+                if only_active and not looking_for_group.value:
+                    continue
+                averagers.append((PeerID(key), looking_for_group.expiration_time))
+            except Exception as e:
+                logger.warning(f"Could not parse group key {key} ({looking_for_group}, exc={e})")
        return averagers

-    async def declare_nbits(self, group_key: GroupKey, nbits: int, expiration_time: DHTExpiration) -> bool:
-        """notify other peers that they can run averaging at this depth"""
-        return await self.dht.store(
-            key=group_key,
-            subkey=self.RESERVED_KEY_FOR_NBITS,
-            value=nbits,
-            expiration_time=expiration_time,
-            return_future=True,
-        )
-
-    @classmethod
-    def get_suggested_nbits(cls, search_result: Optional[ValueWithExpiration]) -> Optional[int]:
-        if (
-            isinstance(search_result, ValueWithExpiration)
-            and cls.RESERVED_KEY_FOR_NBITS in search_result.value
-            and isinstance(search_result.value[cls.RESERVED_KEY_FOR_NBITS].value, int)
-        ):
-            return search_result.value[cls.RESERVED_KEY_FOR_NBITS].value
-        else:
-            return None
-
     async def update_key_on_group_assembled(self, group_info: GroupInfo, is_leader: bool = True):
    async def update_key_on_group_assembled(self, group_info: GroupInfo, is_leader: bool = True):
        """this function is triggered every time an averager finds an allreduce group"""
         rng = random.Random(group_info.group_id)
@@ -148,37 +98,6 @@ class GroupKeyManager:
        self.group_bits = (self.group_bits + new_bits)[-len(self.group_bits) :] if self.group_bits else ""
        logger.debug(f"{self.peer_id} - updated group key to {self.group_bits}")

-        if is_leader and self.insufficient_size < group_info.group_size < self.excessive_size:
-            asyncio.create_task(self.notify_stragglers())
-        if self.suggested_nbits is not None and self.suggested_nbits != len(self.group_bits):
-            num_extra_bits = max(0, self.suggested_nbits - len(self.group_bits))
-            self.group_bits = "".join((random.choice("01") for _ in range(num_extra_bits))) + self.group_bits
-            self.group_bits = self.group_bits[-self.suggested_nbits :]
-        self.suggested_nbits = None
-
    async def update_key_on_not_enough_peers(self):
        """this function is triggered whenever averager fails to assemble group within timeout"""
-        new_nbits = self.suggested_nbits if self.suggested_nbits is not None else len(self.group_bits) - 1
-        prev_nbits, self.group_bits = self.group_bits, self.group_bits[-new_nbits:] if new_nbits else ""
-        if self.group_bits != prev_nbits:
-            logger.warning(f"{self.peer_id} - switching to {len(self.group_bits)}-bit keys")
-        self.suggested_nbits = None
-
-    async def notify_stragglers(self):
-        """Find averagers that have fewer nbits and redirect them to your current nbits"""
-        for nbits in reversed(range(1, len(self.group_bits) - 1)):
-            preceding_key = f"{self.prefix}.0b{self.group_bits[-nbits:] if nbits else ''}"
-            preceding_data, _ = await self.dht.get(preceding_key, latest=False, return_future=True) or ({}, None)
-
-            if len(preceding_data) > 0 and self.RESERVED_KEY_FOR_NBITS not in preceding_data:
-                await self.declare_nbits(preceding_key, len(self.group_bits), get_dht_time() + self.nbits_expiration)
-                break
-
-        root_data, _ = await self.dht.get(f"{self.prefix}.0b", latest=False, return_future=True) or ({}, None)
-        if (
-            isinstance(root_data, dict)
-            and root_data.get(self.RESERVED_KEY_FOR_NBITS, (None, -float("inf")))[1]
-            > get_dht_time() + self.nbits_grace_period
-        ):
-            return
-        await self.declare_nbits(f"{self.prefix}.0b", len(self.group_bits), get_dht_time() + self.nbits_expiration)
+        pass  # to be implemented in subclasses

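For context on update_key_on_group_assembled above: every peer in an assembled group seeds a random generator with the shared group_id, so all of them shift the same fresh bits into the group key from the right while keeping its length. A toy illustration with made-up values:

    # rng = random.Random(group_info.group_id) yields identical new_bits on every peer
    group_bits, new_bits = "0110", "10"
    group_bits = (group_bits + new_bits)[-len(group_bits):]
    assert group_bits == "1010"  # oldest bits fall off the left, new bits enter on the right
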
+ 30 - 42
hivemind/averaging/matchmaking.py

@@ -15,7 +15,7 @@ from hivemind.dht import DHT, DHTID, DHTExpiration
from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
from hivemind.proto import averaging_pb2
from hivemind.utils import TimedStorage, get_dht_time, get_logger, timed_storage
-from hivemind.utils.asyncio import anext
+from hivemind.utils.asyncio import anext, cancel_and_wait

logger = get_logger(__name__)

@@ -45,7 +45,7 @@ class Matchmaking:
        min_group_size: int,
        request_timeout: float,
        client_mode: bool,
-        initial_group_bits: Optional[str] = None,
+        initial_group_bits: str = "",
        averaging_expiration: float = 15,
    ):
        assert "." not in prefix, "group prefix must be a string without ."
@@ -127,10 +127,9 @@ class Matchmaking:
                raise

            finally:
-                if not request_leaders_task.done():
-                    request_leaders_task.cancel()
-                if not self.assembled_group.done():
-                    self.assembled_group.cancel()
+                await cancel_and_wait(request_leaders_task)
+                self.assembled_group.cancel()
+
                while len(self.current_followers) > 0:
                    await self.follower_was_discarded.wait()
                    self.follower_was_discarded.clear()
@@ -229,7 +228,7 @@ class Matchmaking:
             logger.debug(f"{self} - potential leader {leader} did not respond within {self.request_timeout}")
             logger.debug(f"{self} - potential leader {leader} did not respond within {self.request_timeout}")
             return None
             return None
        except (P2PHandlerError, StopAsyncIteration) as e:
-            logger.error(f"{self} - failed to request potential leader {leader}: {e}")
+            logger.exception(f"{self} - failed to request potential leader {leader}:")
            return None

        finally:
@@ -413,10 +412,9 @@ class PotentialLeaders:
            try:
                yield self
            finally:
-                if not update_queue_task.done():
-                    update_queue_task.cancel()
-                if declare and not declare_averager_task.done():
-                    declare_averager_task.cancel()
+                await cancel_and_wait(update_queue_task)
+                if declare:
+                    await cancel_and_wait(declare_averager_task)

                for field in (
                    self.past_attempts,
@@ -477,37 +475,31 @@ class PotentialLeaders:
        else:
            return min(get_dht_time() + self.averaging_expiration, self.search_end_time)

-    async def _update_queue_periodically(self, key_manager: GroupKeyManager):
-        try:
-            DISCREPANCY = timed_storage.MAX_DHT_TIME_DISCREPANCY_SECONDS
-            while get_dht_time() < self.search_end_time:
-                new_peers = await key_manager.get_averagers(key_manager.current_key, only_active=True)
-                self.max_assured_time = max(
-                    self.max_assured_time, get_dht_time() + self.averaging_expiration - DISCREPANCY
-                )
+    async def _update_queue_periodically(self, key_manager: GroupKeyManager) -> None:
+        DISCREPANCY = timed_storage.MAX_DHT_TIME_DISCREPANCY_SECONDS
+        while get_dht_time() < self.search_end_time:
+            new_peers = await key_manager.get_averagers(key_manager.current_key, only_active=True)
+            self.max_assured_time = max(
+                self.max_assured_time, get_dht_time() + self.averaging_expiration - DISCREPANCY
+            )
-                self.leader_queue.clear()
-                for peer, peer_expiration_time in new_peers:
-                    if peer == self.peer_id or (peer, peer_expiration_time) in self.past_attempts:
-                        continue
-                    self.leader_queue.store(peer, peer_expiration_time, peer_expiration_time)
-                    self.max_assured_time = max(self.max_assured_time, peer_expiration_time - DISCREPANCY)
+            self.leader_queue.clear()
+            for peer, peer_expiration_time in new_peers:
+                if peer == self.peer_id or (peer, peer_expiration_time) in self.past_attempts:
+                    continue
+                self.leader_queue.store(peer, peer_expiration_time, peer_expiration_time)
+                self.max_assured_time = max(self.max_assured_time, peer_expiration_time - DISCREPANCY)
-                self.update_finished.set()
+            self.update_finished.set()
-                await asyncio.wait(
-                    {self.running.wait(), self.update_triggered.wait()},
-                    return_when=asyncio.ALL_COMPLETED,
-                    timeout=self.search_end_time - get_dht_time() if isfinite(self.search_end_time) else None,
-                )
-                self.update_triggered.clear()
-        except (concurrent.futures.CancelledError, asyncio.CancelledError):
-            return  # note: this is a compatibility layer for python3.7
-        except Exception as e:
-            logger.error(f"{self.peer_id} - caught {type(e)}: {e}")
-            raise
+            await asyncio.wait(
+                {self.running.wait(), self.update_triggered.wait()},
+                return_when=asyncio.ALL_COMPLETED,
+                timeout=self.search_end_time - get_dht_time() if isfinite(self.search_end_time) else None,
+            )
+            self.update_triggered.clear()
-    async def _declare_averager_periodically(self, key_manager: GroupKeyManager):
+    async def _declare_averager_periodically(self, key_manager: GroupKeyManager) -> None:
        async with self.lock_declare:
            try:
                while True:
@@ -521,10 +513,6 @@ class PotentialLeaders:
                    await asyncio.sleep(self.declared_expiration_time - get_dht_time())
                    if self.running.is_set() and len(self.leader_queue) == 0:
                        await key_manager.update_key_on_not_enough_peers()
-            except (concurrent.futures.CancelledError, asyncio.CancelledError):
-                pass  # note: this is a compatibility layer for python3.7
-            except Exception as e:  # note: we catch exceptions here because otherwise they are never printed
-                logger.error(f"{self.peer_id} - caught {type(e)}: {e}")
            finally:
                if self.declared_group_key is not None:
                    prev_declared_key, prev_expiration_time = self.declared_group_key, self.declared_expiration_time

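The matchmaking changes above swap the "cancel if not done" pattern for cancel_and_wait, which also waits until the task has actually finished cancelling before the finally block proceeds. Its implementation is not part of this diff; a plausible sketch of the assumed semantics:

    import asyncio

    async def cancel_and_wait(task: asyncio.Task) -> bool:
        # request cancellation, then wait for the task to unwind (assumed behavior,
        # not the verbatim hivemind.utils.asyncio implementation)
        task.cancel()
        try:
            await task
            return False
        except asyncio.CancelledError:
            return True
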
+ 19 - 12
hivemind/dht/__init__.py

@@ -27,7 +27,7 @@ from hivemind.dht.node import DEFAULT_NUM_WORKERS, DHTNode
from hivemind.dht.routing import DHTID, DHTKey, DHTValue, Subkey
from hivemind.dht.validation import CompositeValidator, RecordValidatorBase
from hivemind.p2p import P2P, PeerID
-from hivemind.utils import DHTExpiration, MPFuture, ValueWithExpiration, await_cancelled, get_logger, switch_to_uvloop
+from hivemind.utils import DHTExpiration, MPFuture, ValueWithExpiration, get_logger, switch_to_uvloop

logger = get_logger(__name__)

@@ -61,6 +61,7 @@ class DHT(mp.Process):
        initial_peers: Optional[Sequence[Union[Multiaddr, str]]] = None,
        *,
        start: bool,
+        p2p: Optional[P2P] = None,
        daemon: bool = True,
        num_workers: int = DEFAULT_NUM_WORKERS,
        record_validators: Iterable[RecordValidatorBase] = (),
@@ -94,6 +95,8 @@ class DHT(mp.Process):
        self._client_mode = None
        self._p2p_replica = None

+        self._daemon_listen_maddr = p2p.daemon_listen_maddr if p2p is not None else None
+
        if start:
            self.run_in_background(await_ready=await_ready)

@@ -105,10 +108,16 @@ class DHT(mp.Process):

            async def _run():
                try:
+                    if self._daemon_listen_maddr is not None:
+                        replicated_p2p = await P2P.replicate(self._daemon_listen_maddr)
+                    else:
+                        replicated_p2p = None
+
                    self._node = await DHTNode.create(
                        initial_peers=self.initial_peers,
                        num_workers=self.num_workers,
                        record_validator=self._record_validator,
+                        p2p=replicated_p2p,
                        **self.kwargs,
                    )
                except Exception as e:
@@ -119,7 +128,12 @@ class DHT(mp.Process):
                self._ready.set_result(None)

                while True:
-                    method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
+                    try:
+                        method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
+                    except (OSError, ConnectionError) as e:
+                        logger.exception(e)
+                        await asyncio.sleep(self._node.protocol.wait_timeout)
+                        continue
                    task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
                    if method == "_shutdown":
                        await task
@@ -247,18 +261,11 @@ class DHT(mp.Process):
    async def _run_coroutine(
        self, coro: Callable[[DHT, DHTNode], Awaitable[ReturnType]], future: MPFuture[ReturnType]
    ):
-        main_task = asyncio.create_task(coro(self, self._node))
-        cancel_task = asyncio.create_task(await_cancelled(future))
        try:
-            await asyncio.wait({main_task, cancel_task}, return_when=asyncio.FIRST_COMPLETED)
-            if future.cancelled():
-                main_task.cancel()
-            else:
-                future.set_result(await main_task)
+            future.set_result(await coro(self, self._node))
        except BaseException as e:
-            logger.exception(f"Caught an exception when running a coroutine: {e}")
-            if not future.done():
-                future.set_exception(e)
+            logger.exception("Caught an exception when running a coroutine:")
+            future.set_exception(e)

    def add_validators(self, record_validators: Iterable[RecordValidatorBase]) -> None:
        if not self._ready.done():

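The new p2p argument lets a DHT reuse an existing p2pd daemon instead of spawning its own: only daemon_listen_maddr crosses the process boundary, and the DHT process reattaches with P2P.replicate in _run() above. A minimal sketch (the surrounding coroutine is illustrative):

    import asyncio
    from hivemind.dht import DHT
    from hivemind.p2p import P2P

    async def main():
        p2p = await P2P.create()           # one shared daemon process
        dht = DHT(p2p=p2p, start=True)     # replicates that daemon instead of starting another
        print(dht.get_visible_maddrs())

    asyncio.run(main())
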
+ 1 - 1
hivemind/dht/protocol.py

@@ -81,7 +81,7 @@ class DHTProtocol(ServicerBase):

    def __init__(self, *, _initialized_with_create=False):
        """Internal init method. Please use DHTProtocol.create coroutine to spawn new protocol instances"""
-        assert _initialized_with_create, " Please use DHTProtocol.create coroutine to spawn new protocol instances "
+        assert _initialized_with_create, "Please use DHTProtocol.create coroutine to spawn new protocol instances"
        super().__init__()

    def get_stub(self, peer: PeerID) -> AuthRPCWrapper:

+ 99 - 45
hivemind/p2p/p2p_daemon.py

@@ -5,12 +5,14 @@ from collections.abc import AsyncIterable as AsyncIterableABC
from contextlib import closing, suppress
from dataclasses import dataclass
from importlib.resources import path
-from typing import Any, AsyncIterator, Awaitable, Callable, List, Optional, Sequence, Tuple, TypeVar, Union
+from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional, Sequence, Tuple, Type, TypeVar, Union
+from google.protobuf.message import Message
from multiaddr import Multiaddr

import hivemind.hivemind_cli as cli
import hivemind.p2p.p2p_daemon_bindings.p2pclient as p2pclient
+from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError, P2PHandlerError
from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
from hivemind.proto.p2pd_pb2 import RPCError
from hivemind.utils.asyncio import aiter, asingle
@@ -27,7 +29,6 @@ class P2PContext(object):
    handle_name: str
    local_id: PeerID
    remote_id: PeerID = None
-    remote_maddr: Multiaddr = None


class P2P:
@@ -65,6 +66,7 @@ class P2P:

    def __init__(self):
        self.peer_id = None
+        self._client = None
        self._child = None
        self._alive = False
        self._reader_task = None
@@ -74,43 +76,50 @@ class P2P:
    async def create(
        cls,
        initial_peers: Optional[Sequence[Union[Multiaddr, str]]] = None,
-        use_ipfs: bool = False,
-        host_maddrs: Optional[Sequence[Union[Multiaddr, str]]] = ("/ip4/127.0.0.1/tcp/0",),
+        *,
        announce_maddrs: Optional[Sequence[Union[Multiaddr, str]]] = None,
-        quic: bool = False,
-        tls: bool = True,
+        auto_nat: bool = True,
        conn_manager: bool = True,
        dht_mode: str = "dht_server",
        force_reachability: Optional[str] = None,
+        host_maddrs: Optional[Sequence[Union[Multiaddr, str]]] = ("/ip4/127.0.0.1/tcp/0",),
+        identity_path: Optional[str] = None,
+        idle_timeout: float = 30,
        nat_port_map: bool = True,
-        auto_nat: bool = True,
+        quic: bool = False,
+        relay_hop_limit: int = 0,
+        startup_timeout: float = 15,
+        tls: bool = True,
+        use_auto_relay: bool = False,
+        use_ipfs: bool = False,
        use_relay: bool = True,
        use_relay_hop: bool = False,
        use_relay_discovery: bool = False,
-        use_auto_relay: bool = False,
-        relay_hop_limit: int = 0,
-        startup_timeout: float = 15,
     ) -> "P2P":
     ) -> "P2P":
         """
         """
        Start a new p2pd process and connect to it.
        :param initial_peers: List of bootstrap peers
-        :param use_ipfs: Bootstrap to IPFS (incompatible with initial_peers)
-        :param host_maddrs: Multiaddrs to listen for external connections from other p2p instances
+        :param auto_nat: Enables the AutoNAT service
        :param announce_maddrs: Visible multiaddrs that the peer will announce
-          for external connections from other p2p instances
-        :param quic: Enables the QUIC transport
-        :param tls: Enables TLS1.3 channel security protocol
+                                for external connections from other p2p instances
        :param conn_manager: Enables the Connection Manager
        :param dht_mode: DHT mode (dht_client/dht_server/dht)
        :param force_reachability: Force reachability mode (public/private)
+        :param host_maddrs: Multiaddrs to listen for external connections from other p2p instances
+        :param identity_path: Path to a pre-generated private key file. If defined, makes the peer ID deterministic.
+                              May be generated using ``./p2p-keygen`` from ``go-libp2p-daemon``.
+        :param idle_timeout: kill daemon if client has been idle for a given number of
+                             seconds before opening persistent streams
        :param nat_port_map: Enables NAT port mapping
-        :param auto_nat: Enables the AutoNAT service
+        :param quic: Enables the QUIC transport
+        :param relay_hop_limit: sets the hop limit for hop relays
+        :param startup_timeout: raise a P2PDaemonError if the daemon does not start in ``startup_timeout`` seconds
+        :param tls: Enables TLS1.3 channel security protocol
+        :param use_auto_relay: enables autorelay
+        :param use_ipfs: Bootstrap to IPFS (incompatible with initial_peers)
        :param use_relay: enables circuit relay
        :param use_relay_hop: enables hop for relay
        :param use_relay_discovery: enables passive discovery for relay
-        :param use_auto_relay: enables autorelay
-        :param relay_hop_limit: sets the hop limit for hop relays
-        :param startup_timeout: raise a P2PDaemonError if the daemon does not start in ``startup_timeout`` seconds
        :return: a wrapper for the p2p daemon
        """

@@ -136,21 +145,24 @@ class P2P:
        ]:
            if value:
                process_kwargs[param] = self._maddrs_to_str(value)
+        if identity_path is not None:
+            process_kwargs["id"] = identity_path

        proc_args = self._make_process_args(
            str(p2pd_path),
-            listen=self._daemon_listen_maddr,
-            quic=quic,
-            tls=tls,
+            autoRelay=use_auto_relay,
+            autonat=auto_nat,
+            b=need_bootstrap,
            connManager=conn_manager,
+            idleTimeout=f"{idle_timeout}s",
+            listen=self._daemon_listen_maddr,
            natPortMap=nat_port_map,
-            autonat=auto_nat,
+            quic=quic,
            relay=use_relay,
-            relayHop=use_relay_hop,
            relayDiscovery=use_relay_discovery,
-            autoRelay=use_auto_relay,
+            relayHop=use_relay_hop,
            relayHopLimit=relay_hop_limit,
-            b=need_bootstrap,
+            tls=tls,
            **process_kwargs,
        )

@@ -167,7 +179,7 @@ class P2P:
            await self.shutdown()
            raise P2PDaemonError(f"Daemon failed to start in {startup_timeout:.1f} seconds")

-        self._client = p2pclient.Client(self._daemon_listen_maddr, self._client_listen_maddr)
+        self._client = await p2pclient.Client.create(self._daemon_listen_maddr, self._client_listen_maddr)
        await self._ping_daemon()
        return self

@@ -189,7 +201,7 @@ class P2P:
        self._daemon_listen_maddr = daemon_listen_maddr
        self._client_listen_maddr = Multiaddr(cls._UNIX_SOCKET_PREFIX + f"p2pclient-{socket_uid}.sock")

-        self._client = p2pclient.Client(self._daemon_listen_maddr, self._client_listen_maddr)
+        self._client = await p2pclient.Client.create(self._daemon_listen_maddr, self._client_listen_maddr)

        await self._ping_daemon()
        return self
@@ -258,7 +270,7 @@ class P2P:

    @staticmethod
    async def receive_protobuf(
-        input_protobuf_type: type, reader: asyncio.StreamReader
+        input_protobuf_type: Type[Message], reader: asyncio.StreamReader
    ) -> Tuple[Optional[TInputProtobuf], Optional[RPCError]]:
        msg_type = await reader.readexactly(1)
        if msg_type == P2P.MESSAGE_MARKER:
@@ -279,7 +291,7 @@ class P2P:
        self,
        name: str,
        handler: Callable[[TInputStream, P2PContext], TOutputStream],
-        input_protobuf_type: type,
+        input_protobuf_type: Type[Message],
        max_prefetch: int = 5,
    ) -> None:
        """
@@ -297,7 +309,6 @@ class P2P:
                handle_name=name,
                local_id=self.peer_id,
                remote_id=stream_info.peer_id,
-                remote_maddr=stream_info.addr,
            )
            requests = asyncio.Queue(max_prefetch)

@@ -349,7 +360,7 @@ class P2P:
        await self.add_binary_stream_handler(name, _handle_stream)

    async def _iterate_protobuf_stream_handler(
-        self, peer_id: PeerID, name: str, requests: TInputStream, output_protobuf_type: type
+        self, peer_id: PeerID, name: str, requests: TInputStream, output_protobuf_type: Type[Message]
    ) -> TOutputStream:
        _, reader, writer = await self.call_binary_stream_handler(peer_id, name)

@@ -381,15 +392,22 @@ class P2P:
        handler: Callable[
            [Union[TInputProtobuf, TInputStream], P2PContext], Union[Awaitable[TOutputProtobuf], TOutputStream]
        ],
-        input_protobuf_type: type,
+        input_protobuf_type: Type[Message],
        *,
        stream_input: bool = False,
+        stream_output: bool = False,
    ) -> None:
        """
        :param stream_input: If True, assume ``handler`` to take ``TInputStream``
                             (not just ``TInputProtobuf``) as input.
+        :param stream_output: If True, assume ``handler`` to return ``TOutputStream``
+                              (not ``Awaitable[TOutputProtobuf]``).
         """
         """
 
 
+        if not stream_input and not stream_output:
+            await self._add_protobuf_unary_handler(name, handler, input_protobuf_type)
+            return
+
        async def _stream_handler(requests: P2P.TInputStream, context: P2PContext) -> P2P.TOutputStream:
            input = requests if stream_input else await asingle(requests)
            output = handler(input, context)
@@ -402,23 +420,65 @@ class P2P:

        await self._add_protobuf_stream_handler(name, _stream_handler, input_protobuf_type)

+    async def _add_protobuf_unary_handler(
+        self,
+        handle_name: str,
+        handler: Callable[[TInputProtobuf, P2PContext], Awaitable[TOutputProtobuf]],
+        input_protobuf_type: Type[Message],
+    ) -> None:
+        """
+        Register a request-response (unary) handler. Unary requests and responses
+        are sent through persistent multiplexed connections to the daemon for the
+        sake of reducing the number of open files.
+        :param handle_name: name of the handler (protocol id)
+        :param handler: function handling the unary requests
+        :param input_protobuf_type: protobuf type of the request
+        """
+
+        async def _unary_handler(request: bytes, remote_id: PeerID) -> bytes:
+            input_serialized = input_protobuf_type.FromString(request)
+            context = P2PContext(
+                handle_name=handle_name,
+                local_id=self.peer_id,
+                remote_id=remote_id,
+            )
+
+            response = await handler(input_serialized, context)
+            return response.SerializeToString()
+
+        await self._client.add_unary_handler(handle_name, _unary_handler)
+
    async def call_protobuf_handler(
        self,
        peer_id: PeerID,
        name: str,
        input: Union[TInputProtobuf, TInputStream],
-        output_protobuf_type: type,
+        output_protobuf_type: Type[Message],
    ) -> Awaitable[TOutputProtobuf]:
-        requests = input if isinstance(input, AsyncIterableABC) else aiter(input)
-        responses = self._iterate_protobuf_stream_handler(peer_id, name, requests, output_protobuf_type)
+
+        if not isinstance(input, AsyncIterableABC):
+            return await self._call_unary_protobuf_handler(peer_id, name, input, output_protobuf_type)
+
+        responses = self._iterate_protobuf_stream_handler(peer_id, name, input, output_protobuf_type)
        return await asingle(responses)

+    async def _call_unary_protobuf_handler(
+        self,
+        peer_id: PeerID,
+        handle_name: str,
+        input: TInputProtobuf,
+        output_protobuf_type: Type[Message],
+    ) -> Awaitable[TOutputProtobuf]:
+        serialized_input = input.SerializeToString()
+        response = await self._client.call_unary_handler(peer_id, handle_name, serialized_input)
+        return output_protobuf_type.FromString(response)
+
     def iterate_protobuf_handler(
    def iterate_protobuf_handler(
        self,
        peer_id: PeerID,
        name: str,
        input: Union[TInputProtobuf, TInputStream],
+        output_protobuf_type: Type[Message],
     ) -> TOutputStream:
     ) -> TOutputStream:
        requests = input if isinstance(input, AsyncIterableABC) else aiter(input)
        return self._iterate_protobuf_stream_handler(peer_id, name, requests, output_protobuf_type)
@@ -453,6 +513,8 @@ class P2P:
            await self._child.wait()

    def _terminate(self) -> None:
+        if self._client is not None:
+            self._client.close()
        if self._listen_task is not None:
            self._listen_task.cancel()
        if self._reader_task is not None:
@@ -501,11 +563,3 @@ class P2P:

        if not ready.done():
            ready.set_exception(P2PDaemonError(f"Daemon failed to start: {last_line}"))
-
-
-class P2PDaemonError(RuntimeError):
-    pass
-
-
-class P2PHandlerError(Exception):
-    pass

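With neither stream_input nor stream_output set, add_protobuf_handler now registers a unary handler served over the persistent daemon connection rather than a per-call stream. A sketch of a round trip (the protocol name is arbitrary, and dht_pb2.PingRequest/PingResponse are borrowed here purely as convenient protobuf types):

    import asyncio
    from hivemind.p2p import P2P
    from hivemind.proto import dht_pb2

    async def main():
        alice = await P2P.create()
        bob = await P2P.create(initial_peers=await alice.get_visible_maddrs())

        # no stream_input/stream_output flags, so this takes the new unary path
        async def handle_ping(request: dht_pb2.PingRequest, context) -> dht_pb2.PingResponse:
            return dht_pb2.PingResponse()

        await alice.add_protobuf_handler("/test/ping/1.0.0", handle_ping, dht_pb2.PingRequest)
        response = await bob.call_protobuf_handler(
            alice.peer_id, "/test/ping/1.0.0", dht_pb2.PingRequest(), dht_pb2.PingResponse
        )
        await asyncio.gather(alice.shutdown(), bob.shutdown())

    asyncio.run(main())
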
+ 189 - 3
hivemind/p2p/p2p_daemon_bindings/control.py

@@ -5,8 +5,9 @@ Author: Kevin Mai-Husan Chia
 """
 """
 
 
 import asyncio
 import asyncio
-from contextlib import asynccontextmanager
-from typing import AsyncIterator, Awaitable, Callable, Dict, Iterable, Sequence, Tuple
+from contextlib import asynccontextmanager, closing
+from typing import AsyncIterator, Awaitable, Callable, Dict, Iterable, Optional, Sequence, Tuple
+from uuid import UUID, uuid4

from multiaddr import Multiaddr, protocols

@@ -54,17 +55,75 @@ class DaemonConnector:
        else:
            raise ValueError(f"Protocol not supported: {protocols.protocol_with_code(self.proto_code)}")

+    async def open_persistent_connection(self) -> (asyncio.StreamReader, asyncio.StreamWriter):
+        """
+        Open connection to daemon and upgrade it to a persistent one
+        """
+        reader, writer = await self.open_connection()
+        req = p2pd_pb.Request(type=p2pd_pb.Request.PERSISTENT_CONN_UPGRADE)
+        await write_pbmsg(writer, req)
+
+        response = p2pd_pb.Response()
+        await read_pbmsg_safe(reader, response)
+
+        if response.type == "ERROR":
+            raise P2PDaemonError(response.error.msg)
+
+        return reader, writer
+
+
+TUnaryHandler = Callable[[bytes, PeerID], Awaitable[bytes]]
+CallID = UUID
+

class ControlClient:
    DEFAULT_LISTEN_MADDR = "/unix/tmp/p2pclient.sock"

    def __init__(
-        self, daemon_connector: DaemonConnector, listen_maddr: Multiaddr = Multiaddr(DEFAULT_LISTEN_MADDR)
+        self,
+        daemon_connector: DaemonConnector,
+        listen_maddr: Multiaddr = Multiaddr(DEFAULT_LISTEN_MADDR),
+        *,
+        _initialized_with_create=False,
    ) -> None:
+        assert _initialized_with_create, "Please use ControlClient.create coroutine to spawn new control instances"
+
        self.listen_maddr = listen_maddr
        self.daemon_connector = daemon_connector
        self.handlers: Dict[str, StreamHandler] = {}

+        self.unary_handlers: Dict[str, TUnaryHandler] = {}
+
+        self._pending_messages: asyncio.Queue[p2pd_pb.PersistentConnectionRequest] = asyncio.Queue()
+        self._pending_calls: Dict[CallID, asyncio.Future[bytes]] = {}
+        self._handler_tasks: Dict[CallID, asyncio.Task] = {}
+
+        self._read_task: Optional[asyncio.Task] = None
+        self._write_task: Optional[asyncio.Task] = None
+
+    @classmethod
+    async def create(
+        cls,
+        daemon_connector: DaemonConnector,
+        listen_maddr: Multiaddr = Multiaddr(DEFAULT_LISTEN_MADDR),
+        use_persistent_conn: bool = True,
+    ) -> "ControlClient":
+        control = cls(daemon_connector, listen_maddr, _initialized_with_create=True)
+
+        if use_persistent_conn:
+            await control._ensure_persistent_conn()
+
+        return control
+
+    def close(self) -> None:
+        if self._read_task is not None:
+            self._read_task.cancel()
+        if self._write_task is not None:
+            self._write_task.cancel()
+
+    def __del__(self):
+        self.close()
+
    async def _handler(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
        pb_stream_info = p2pd_pb.StreamInfo()  # type: ignore
        await read_pbmsg_safe(reader, pb_stream_info)
@@ -93,6 +152,121 @@ class ControlClient:
        async with server:
            yield self

+    async def _read_from_persistent_conn(self, reader: asyncio.StreamReader):
+        while True:
+            resp = p2pd_pb.PersistentConnectionResponse()
+            try:
+                await read_pbmsg_safe(reader, resp)
+            except asyncio.IncompleteReadError:
+                break
+
+            call_id = UUID(bytes=resp.callId)
+
+            if resp.HasField("callUnaryResponse"):
+                if call_id in self._pending_calls and resp.callUnaryResponse.HasField("response"):
+                    self._pending_calls[call_id].set_result(resp.callUnaryResponse.response)
+                elif call_id in self._pending_calls and resp.callUnaryResponse.HasField("error"):
+                    remote_exc = P2PHandlerError(resp.callUnaryResponse.error.decode(errors="ignore"))
+                    self._pending_calls[call_id].set_exception(remote_exc)
+                else:
+                    logger.debug(f"Received unexpected unary call: {resp}")
+
+            elif resp.HasField("requestHandling"):
+                handler_task = asyncio.create_task(self._handle_persistent_request(call_id, resp.requestHandling))
+                self._handler_tasks[call_id] = handler_task
+
+            elif call_id in self._handler_tasks and resp.HasField("cancel"):
+                self._handler_tasks[call_id].cancel()
+
+            elif call_id in self._pending_calls and resp.HasField("daemonError"):
+                daemon_exc = P2PDaemonError(resp.daemonError.message)
+                self._pending_calls[call_id].set_exception(daemon_exc)
+
+            elif call_id in self._pending_calls:
+                self._pending_calls[call_id].set_result(None)
+
+            else:
+                logger.debug(f"Received unexpected response from daemon: {resp}")
+
+    async def _write_to_persistent_conn(self, writer: asyncio.StreamWriter):
+        with closing(writer):
+            while True:
+                msg = await self._pending_messages.get()
+                await write_pbmsg(writer, msg)
+
+    async def _handle_persistent_request(self, call_id: UUID, request: p2pd_pb.CallUnaryRequest):
+        if request.proto not in self.unary_handlers:
+            logger.warning(f"Protocol {request.proto} not supported")
+            return
+
+        try:
+            remote_id = PeerID(request.peer)
+            response_payload: bytes = await self.unary_handlers[request.proto](request.data, remote_id)
+            response = p2pd_pb.CallUnaryResponse(response=response_payload)
+
+        except Exception as e:
+            response = p2pd_pb.CallUnaryResponse(error=repr(e).encode())
+
+        await self._pending_messages.put(
+            p2pd_pb.PersistentConnectionRequest(
+                callId=call_id.bytes,
+                unaryResponse=response,
+            )
+        )
+        self._handler_tasks.pop(call_id)
+
+    async def _cancel_unary_call(self, call_id: UUID):
+        await self._pending_messages.put(
+            p2pd_pb.PersistentConnectionRequest(
+                callId=call_id.bytes,
+                cancel=p2pd_pb.Cancel(),
+            ),
+        )
+
+    async def _ensure_persistent_conn(self):
+        reader, writer = await self.daemon_connector.open_persistent_connection()
+
+        self._read_task = asyncio.create_task(self._read_from_persistent_conn(reader))
+        self._write_task = asyncio.create_task(self._write_to_persistent_conn(writer))
+
+    async def add_unary_handler(self, proto: str, handler: TUnaryHandler):
+        call_id = uuid4()
+
+        add_unary_handler_req = p2pd_pb.AddUnaryHandlerRequest(proto=proto)
+        req = p2pd_pb.PersistentConnectionRequest(callId=call_id.bytes, addUnaryHandler=add_unary_handler_req)
+
+        if self.unary_handlers.get(proto):
+            raise P2PDaemonError(f"Handler for protocol {proto} already registered")
+        self.unary_handlers[proto] = handler
+
+        self._pending_calls[call_id] = asyncio.Future()
+        await self._pending_messages.put(req)
+        await self._pending_calls[call_id]
+
+    async def call_unary_handler(self, peer_id: PeerID, proto: str, data: bytes) -> bytes:
+        call_id = uuid4()
+        call_unary_req = p2pd_pb.CallUnaryRequest(
+            peer=peer_id.to_bytes(),
+            proto=proto,
+            data=data,
+        )
+        req = p2pd_pb.PersistentConnectionRequest(
+            callId=call_id.bytes,
+            callUnary=call_unary_req,
+        )
+
+        try:
+            self._pending_calls[call_id] = asyncio.Future()
+            await self._pending_messages.put(req)
+            return await self._pending_calls[call_id]
+
+        except asyncio.CancelledError:
+            await self._cancel_unary_call(call_id)
+            raise
+
+        finally:
+            self._pending_calls.pop(call_id, None)
+
    async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
        reader, writer = await self.daemon_connector.open_connection()
        req = p2pd_pb.Request(type=p2pd_pb.Request.IDENTIFY)
@@ -179,3 +353,15 @@ class ControlClient:

        # if success, add the handler to the dict
        self.handlers[proto] = handler_cb
+
+
+class P2PHandlerError(Exception):
+    """
+    Raised if a remote peer's handler raised an exception while processing the request
+    """
+
+
+class P2PDaemonError(Exception):
+    """
+    Raised if the daemon failed to handle the request
+    """

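The persistent-connection methods above all follow one correlation pattern: a single writer task drains `_pending_messages`, while each in-flight call parks on an `asyncio.Future` in `_pending_calls`, keyed by UUID, until the reader task resolves it. Below is a minimal self-contained sketch of that pattern; the `CallMultiplexer` class and its method names are illustrative, not part of hivemind:

    import asyncio
    from typing import Dict
    from uuid import UUID, uuid4

    class CallMultiplexer:
        """Toy model of ControlClient's persistent connection: one queue out, futures back in."""

        def __init__(self) -> None:
            self._pending_messages: asyncio.Queue = asyncio.Queue()
            self._pending_calls: Dict[UUID, asyncio.Future] = {}

        async def call(self, payload: bytes) -> bytes:
            call_id = uuid4()
            self._pending_calls[call_id] = asyncio.Future()
            await self._pending_messages.put((call_id, payload))  # drained by the writer task
            try:
                return await self._pending_calls[call_id]  # resolved by the reader task
            finally:
                self._pending_calls.pop(call_id, None)  # mirrors the cleanup in call_unary_handler

        def on_response(self, call_id: UUID, result: bytes) -> None:
            future = self._pending_calls.get(call_id)
            if future is not None and not future.done():
                future.set_result(result)

`call_unary_handler` adds one refinement on top of this pattern: on `asyncio.CancelledError` it first enqueues a `Cancel` message so the daemon can abort the remote call before re-raising.
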
+ 3 - 0
hivemind/p2p/p2p_daemon_bindings/datastructures.py

@@ -131,6 +131,9 @@ class PeerInfo:
     def __str__(self):
         return f"{self.peer_id.pretty()} {','.join(str(a) for a in self.addrs)}"

+    def __repr__(self):
+        return f"PeerInfo(peer_id={repr(self.peer_id)}, addrs={repr(self.addrs)})"
+

 class InvalidAddrError(ValueError):
     pass

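With `__repr__` defined, debug output mirrors the constructor form instead of the space-separated `__str__`; schematically (the nested reprs depend on `PeerID` and `Multiaddr` themselves):

    print(repr(peer_info))  # -> PeerInfo(peer_id=<PeerID ...>, addrs=[<Multiaddr /ip4/...>])
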
+ 24 - 3
hivemind/p2p/p2p_daemon_bindings/p2pclient.py

@@ -10,16 +10,31 @@ from typing import AsyncIterator, Iterable, Sequence, Tuple

 from multiaddr import Multiaddr

-from hivemind.p2p.p2p_daemon_bindings.control import ControlClient, DaemonConnector, StreamHandler
+from hivemind.p2p.p2p_daemon_bindings.control import ControlClient, DaemonConnector, StreamHandler, TUnaryHandler
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo


 class Client:
     control: ControlClient

-    def __init__(self, control_maddr: Multiaddr = None, listen_maddr: Multiaddr = None) -> None:
+    def __init__(self, *, _initialized_with_create=False) -> None:
+        assert _initialized_with_create, "Please use Client.create coroutine to spawn new client instances"
+        self.control = None
+
+    @classmethod
+    async def create(cls, control_maddr: Multiaddr = None, listen_maddr: Multiaddr = None) -> "Client":
+        client = cls(_initialized_with_create=True)
+
         daemon_connector = DaemonConnector(control_maddr=control_maddr)
-        self.control = ControlClient(daemon_connector=daemon_connector, listen_maddr=listen_maddr)
+        client.control = await ControlClient.create(daemon_connector=daemon_connector, listen_maddr=listen_maddr)
+
+        return client
+
+    def close(self) -> None:
+        self.control.close()
+
+    def __del__(self):
+        self.close()

     @asynccontextmanager
     async def listen(self) -> AsyncIterator["Client"]:
@@ -30,6 +45,12 @@ class Client:
         async with self.control.listen():
             yield self

+    async def add_unary_handler(self, proto: str, handler: TUnaryHandler):
+        await self.control.add_unary_handler(proto, handler)
+
+    async def call_unary_handler(self, peer_id: PeerID, proto: str, data: bytes) -> bytes:
+        return await self.control.call_unary_handler(peer_id, proto, data)
+
     async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
         """
         Get current node peer id and list of addresses

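Since the constructor is now private, clients are obtained through the awaitable factory, and the unary-handler methods proxy straight to `ControlClient`. A hedged usage sketch: the control multiaddrs and protocol name below are placeholders, and it assumes two running daemons that are already connected to each other:

    import asyncio
    from multiaddr import Multiaddr
    from hivemind.p2p.p2p_daemon_bindings.p2pclient import Client

    async def main() -> None:
        # Client() would raise an AssertionError; Client.create awaits ControlClient.create
        alice = await Client.create(control_maddr=Multiaddr("/unix/tmp/p2pd-alice.sock"))
        bob = await Client.create(control_maddr=Multiaddr("/unix/tmp/p2pd-bob.sock"))

        async def echo(data: bytes, remote_id) -> bytes:
            return data  # a unary handler maps request bytes to response bytes

        await alice.add_unary_handler("/example/echo", echo)
        alice_id, _addrs = await alice.identify()
        assert await bob.call_unary_handler(alice_id, "/example/echo", b"ping") == b"ping"

    asyncio.run(main())
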
+ 13 - 7
hivemind/p2p/servicer.py

@@ -125,13 +125,19 @@ class ServicerBase:
         self._collect_rpc_handlers()

         servicer = self if wrapper is None else wrapper
-        for handler in self._rpc_handlers:
-            await p2p.add_protobuf_handler(
-                self._get_handle_name(namespace, handler.method_name),
-                getattr(servicer, handler.method_name),
-                handler.request_type,
-                stream_input=handler.stream_input,
-            )
+
+        await asyncio.gather(
+            *[
+                p2p.add_protobuf_handler(
+                    self._get_handle_name(namespace, handler.method_name),
+                    getattr(servicer, handler.method_name),
+                    handler.request_type,
+                    stream_input=handler.stream_input,
+                    stream_output=handler.stream_output,
+                )
+                for handler in self._rpc_handlers
+            ]
+        )

     @classmethod
     def get_stub(cls, p2p: P2P, peer: PeerID, *, namespace: Optional[str] = None) -> StubBase:

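Registering the handlers with `asyncio.gather` makes the daemon round-trips run concurrently instead of one by one; the new version also forwards `stream_output`, which the sequential loop dropped. A toy timing illustration of the change (handler names are hypothetical):

    import asyncio

    async def register(name: str) -> None:
        await asyncio.sleep(0.1)  # stands in for one add_protobuf_handler round-trip

    async def register_all(names) -> None:
        # old: `for name in names: await register(name)` -> ~0.1 s per handler
        # new: all registrations in flight at once        -> ~0.1 s total
        await asyncio.gather(*[register(name) for name in names])

    asyncio.run(register_all(["rpc_join", "rpc_aggregate", "rpc_download"]))
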
+ 59 - 10
hivemind/proto/p2pd.proto

@@ -8,15 +8,17 @@ package p2pclient.p2pd.pb;

 message Request {
   enum Type {
-    IDENTIFY       = 0;
-    CONNECT        = 1;
-    STREAM_OPEN    = 2;
-    STREAM_HANDLER = 3;
-    DHT            = 4;
-    LIST_PEERS     = 5;
-    CONNMANAGER    = 6;
-    DISCONNECT     = 7;
-    PUBSUB         = 8;
+    IDENTIFY                 = 0;
+    CONNECT                  = 1;
+    STREAM_OPEN              = 2;
+    STREAM_HANDLER           = 3;
+    DHT                      = 4;
+    LIST_PEERS               = 5;
+    CONNMANAGER              = 6;
+    DISCONNECT               = 7;
+    PUBSUB                   = 8;
+
+    PERSISTENT_CONN_UPGRADE  = 9;
   }

   required Type type = 1;
@@ -45,6 +47,29 @@ message Response {
   optional PSResponse pubsub = 7;
 }

+message PersistentConnectionRequest {
+  required bytes callId = 1;
+
+  oneof message {
+    AddUnaryHandlerRequest addUnaryHandler = 2;
+    CallUnaryRequest callUnary = 3;
+    CallUnaryResponse unaryResponse = 4;
+    Cancel cancel = 5;
+  }
+}
+
+message PersistentConnectionResponse {
+  required bytes callId = 1;
+
+  oneof message {
+    CallUnaryResponse callUnaryResponse = 2;
+    CallUnaryRequest requestHandling = 3;
+    DaemonError daemonError = 4;
+    Cancel cancel = 5;
+  }
+}
+
+
 message IdentifyResponse {
   required bytes id = 1;
   repeated bytes addrs = 2;
@@ -148,7 +173,7 @@ message PSRequest {
 }

 message PSMessage {
-  optional bytes from_id = 1;
+  optional bytes from = 1;
   optional bytes data = 2;
   optional bytes seqno = 3;
   repeated string topicIDs = 4;
@@ -161,6 +186,30 @@ message PSResponse {
   repeated bytes peerIDs = 2;
 }

+message CallUnaryRequest {
+  required bytes peer = 1;
+  required string proto = 2;
+  required bytes data = 3;
+}
+
+message CallUnaryResponse {
+  oneof result {
+    bytes response = 1;
+    bytes error = 2;
+  }
+}
+
+message AddUnaryHandlerRequest {
+  required string proto = 1;
+}
+
+message DaemonError {
+  optional string message = 1;
+}
+
+message Cancel {
+}
+
 message RPCError {
   optional string message = 1;
 }

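Both envelope messages carry a required `callId` plus exactly one `oneof` branch, which is how the reader loop dispatches. A sketch of populating the request side, assuming the file is compiled to `hivemind.proto.p2pd_pb2` as hivemind's other protos are (the peer bytes are a placeholder):

    from uuid import uuid4
    from hivemind.proto import p2pd_pb2 as p2pd_pb

    req = p2pd_pb.PersistentConnectionRequest(
        callId=uuid4().bytes,
        callUnary=p2pd_pb.CallUnaryRequest(
            peer=b"<PeerID.to_bytes() of the callee>",  # placeholder
            proto="/example/echo",
            data=b"ping",
        ),
    )
    assert req.WhichOneof("message") == "callUnary"  # only one branch may be set
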
+ 16 - 1
hivemind/utils/asyncio.py

@@ -1,4 +1,5 @@
 import asyncio
+import concurrent.futures
 from concurrent.futures import ThreadPoolExecutor
 from typing import AsyncIterable, AsyncIterator, Awaitable, Callable, Optional, Tuple, TypeVar, Union

@@ -81,12 +82,26 @@ async def await_cancelled(awaitable: Awaitable) -> bool:
     try:
         await awaitable
         return False
-    except asyncio.CancelledError:
+    except (asyncio.CancelledError, concurrent.futures.CancelledError):
+        # In Python 3.7, awaiting a cancelled asyncio.Future raises concurrent.futures.CancelledError
+        # instead of asyncio.CancelledError
         return True
     except BaseException:
+        logger.exception(f"Exception in {awaitable}:")
         return False


+async def cancel_and_wait(awaitable: Awaitable) -> bool:
+    """
+    Cancels ``awaitable`` and waits for its cancellation.
+    In case of ``asyncio.Task``, helps to avoid ``Task was destroyed but it is pending!`` errors.
+    In case of ``asyncio.Future``, this is equivalent to ``future.cancel()``.
+    """
+
+    awaitable.cancel()
+    return await await_cancelled(awaitable)
+
+
 async def amap_in_executor(
     func: Callable[..., T],
     *iterables: AsyncIterable,

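A short usage sketch for the new helper: unlike a bare `task.cancel()`, it waits until the task has actually unwound, so finalizers run before the owner proceeds:

    import asyncio
    from hivemind.utils.asyncio import cancel_and_wait

    async def main() -> None:
        task = asyncio.create_task(asyncio.sleep(3600))
        await asyncio.sleep(0)  # give the task a chance to start
        assert await cancel_and_wait(task)  # True: the task ended with CancelledError

    asyncio.run(main())
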
+ 1 - 1
hivemind/utils/mpfuture.py

@@ -53,7 +53,7 @@ class SharedBytes:
         """Create another shared byte value, represented as a scalar uint8 tensor"""
         """Create another shared byte value, represented as a scalar uint8 tensor"""
         with cls._lock:
         with cls._lock:
             if cls._pid != os.getpid() or cls._buffer is None or cls._index >= len(cls._buffer):
             if cls._pid != os.getpid() or cls._buffer is None or cls._index >= len(cls._buffer):
-                buffer_size = os.environ.get("HIVEMIND_SHM_BUFFER_SIZE", 4096)
+                buffer_size = int(os.environ.get("HIVEMIND_SHM_BUFFER_SIZE", 16))
                 cls._pid = os.getpid()
                 cls._buffer = torch.empty([buffer_size], dtype=torch.uint8).share_memory_()
                 cls._index = 0

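The `int(...)` here fixes a latent type bug: `os.environ` values are always strings, so whenever `HIVEMIND_SHM_BUFFER_SIZE` was set, the old code handed a `str` to `torch.empty`. A quick demonstration:

    import os

    os.environ["HIVEMIND_SHM_BUFFER_SIZE"] = "4096"
    raw = os.environ.get("HIVEMIND_SHM_BUFFER_SIZE", 16)
    assert isinstance(raw, str)  # the old code used this value directly
    assert int(raw) == 4096      # the new code converts it (and shrinks the default to 16)
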
+ 2 - 2
setup.py

@@ -14,8 +14,8 @@ from setuptools import find_packages, setup
 from setuptools.command.build_py import build_py
 from setuptools.command.develop import develop

-P2PD_VERSION = "v0.3.1"
-P2PD_CHECKSUM = "15292b880c6b31f5b3c36084b3acc17f"
+P2PD_VERSION = "v0.3.4"
+P2PD_CHECKSUM = "194dca06116fdd36bc4b681d18f3b9cb"
 LIBP2P_TAR_URL = f"https://github.com/learning-at-home/go-libp2p-daemon/archive/refs/tags/{P2PD_VERSION}.tar.gz"

 here = os.path.abspath(os.path.dirname(__file__))

+ 5 - 2
tests/test_averaging.py

@@ -10,6 +10,7 @@ import hivemind.averaging.averager
 from hivemind.averaging.allreduce import AveragingMode
 from hivemind.averaging.key_manager import GroupKeyManager
 from hivemind.averaging.load_balancing import load_balance_peers
+from hivemind.averaging.partition import AllreduceException
 from hivemind.p2p import PeerID
 from hivemind.proto.runtime_pb2 import CompressionType

@@ -363,9 +364,11 @@ def test_too_few_peers():
         )
         for i, dht in enumerate(dht_instances)
     ]
-    step_futures = [averager.step(wait=False) for averager in averagers]
+    step_futures = [averager.step(wait=False, timeout=2) for averager in averagers]
+
     for future in step_futures:
-        assert len(future.result()) == 2
+        with pytest.raises(AllreduceException):
+            future.result()

     for process in averagers + dht_instances:
         process.shutdown()

+ 2 - 2
tests/test_dht.py

@@ -14,7 +14,7 @@ from test_utils.dht_swarms import launch_dht_instances

 @pytest.mark.asyncio
 async def test_startup_error():
-    with pytest.raises(hivemind.p2p.P2PDaemonError, match=r"Failed to connect to bootstrap peers"):
+    with pytest.raises(hivemind.p2p.P2PDaemonError, match=r"(?i)Failed to connect to bootstrap peers"):
         hivemind.DHT(
             initial_peers=[f"/ip4/127.0.0.1/tcp/{get_free_port()}/p2p/QmdaK4LUeQaKhqSFPRu9N7MvXUEWDxWwtCvPrS444tCgd1"],
             start=True,
@@ -118,7 +118,7 @@ async def test_dht_get_visible_maddrs():

     dummy_endpoint = Multiaddr("/ip4/123.45.67.89/tcp/31337")
     p2p = await hivemind.p2p.P2P.create(announce_maddrs=[dummy_endpoint])
-    dht = hivemind.DHT(start=True, p2p=await p2p.replicate(p2p.daemon_listen_maddr))
+    dht = hivemind.DHT(start=True, p2p=p2p)

     assert dht.get_visible_maddrs() == [dummy_endpoint.encapsulate(f"/p2p/{p2p.peer_id}")]
     dht.shutdown()

+ 29 - 4
tests/test_p2p_daemon.py

@@ -10,7 +10,7 @@ import pytest
 from multiaddr import Multiaddr

 from hivemind.p2p import P2P, P2PDaemonError, P2PHandlerError
-from hivemind.proto import dht_pb2
+from hivemind.proto import dht_pb2, test_pb2
 from hivemind.utils.networking import get_free_port
 from hivemind.utils.serializer import MSGPackSerializer

@@ -36,7 +36,7 @@ async def test_daemon_killed_on_del():

 @pytest.mark.asyncio
 async def test_startup_error_message():
-    with pytest.raises(P2PDaemonError, match=r"Failed to connect to bootstrap peers"):
+    with pytest.raises(P2PDaemonError, match=r"(?i)Failed to connect to bootstrap peers"):
         await P2P.create(
             initial_peers=[f"/ip4/127.0.0.1/tcp/{get_free_port()}/p2p/QmdaK4LUeQaKhqSFPRu9N7MvXUEWDxWwtCvPrS444tCgd1"]
         )
@@ -63,9 +63,9 @@ async def test_transports(host_maddrs: List[Multiaddr]):
     await client.wait_for_at_least_n_peers(1)

     peers = await client.list_peers()
-    assert len(peers) == 1
+    assert len({p.peer_id for p in peers}) == 1
     peers = await server.list_peers()
-    assert len(peers) == 1
+    assert len({p.peer_id for p in peers}) == 1


 @pytest.mark.asyncio
@@ -83,6 +83,31 @@ async def test_daemon_replica_does_not_affect_primary():
     assert not is_process_running(child_pid)


+@pytest.mark.asyncio
+async def test_unary_handler_edge_cases():
+    p2p = await P2P.create()
+    p2p_replica = await P2P.replicate(p2p.daemon_listen_maddr)
+
+    async def square_handler(data: test_pb2.TestRequest, context):
+        return test_pb2.TestResponse(number=data.number ** 2)
+
+    await p2p.add_protobuf_handler("square", square_handler, test_pb2.TestRequest)
+
+    # try adding a duplicate handler
+    with pytest.raises(P2PDaemonError):
+        await p2p.add_protobuf_handler("square", square_handler, test_pb2.TestRequest)
+
+    # try adding a duplicate handler from replicated p2p
+    with pytest.raises(P2PDaemonError):
+        await p2p_replica.add_protobuf_handler("square", square_handler, test_pb2.TestRequest)
+
+    # try dialing yourself
+    with pytest.raises(P2PDaemonError):
+        await p2p_replica.call_protobuf_handler(
+            p2p.peer_id, "square", test_pb2.TestRequest(number=41), test_pb2.TestResponse
+        )
+
+
 @pytest.mark.parametrize(
     "should_cancel,replicate",
     [

+ 13 - 6
tests/test_p2p_daemon_bindings.py

@@ -199,24 +199,31 @@ def test_parse_conn_protocol_invalid(maddr_str):


 @pytest.mark.parametrize("control_maddr_str", ("/unix/123", "/ip4/127.0.0.1/tcp/6666"))
-def test_client_ctor_control_maddr(control_maddr_str):
+@pytest.mark.asyncio
+async def test_client_create_control_maddr(control_maddr_str):
     c = DaemonConnector(Multiaddr(control_maddr_str))
     assert c.control_maddr == Multiaddr(control_maddr_str)


-def test_client_ctor_default_control_maddr():
+def test_client_create_default_control_maddr():
     c = DaemonConnector()
     assert c.control_maddr == Multiaddr(DaemonConnector.DEFAULT_CONTROL_MADDR)


 @pytest.mark.parametrize("listen_maddr_str", ("/unix/123", "/ip4/127.0.0.1/tcp/6666"))
 @pytest.mark.parametrize("listen_maddr_str", ("/unix/123", "/ip4/127.0.0.1/tcp/6666"))
-def test_control_client_ctor_listen_maddr(listen_maddr_str):
-    c = ControlClient(daemon_connector=DaemonConnector(), listen_maddr=Multiaddr(listen_maddr_str))
+@pytest.mark.asyncio
+async def test_control_client_create_listen_maddr(listen_maddr_str):
+    c = await ControlClient.create(
+        daemon_connector=DaemonConnector(),
+        listen_maddr=Multiaddr(listen_maddr_str),
+        use_persistent_conn=False,
+    )
     assert c.listen_maddr == Multiaddr(listen_maddr_str)


-def test_control_client_ctor_default_listen_maddr():
-    c = ControlClient(daemon_connector=DaemonConnector())
+@pytest.mark.asyncio
+async def test_control_client_create_default_listen_maddr():
+    c = await ControlClient.create(daemon_connector=DaemonConnector(), use_persistent_conn=False)
     assert c.listen_maddr == Multiaddr(ControlClient.DEFAULT_LISTEN_MADDR)



+ 43 - 1
tests/test_util_modules.py

@@ -13,7 +13,17 @@ from hivemind.proto.dht_pb2_grpc import DHTStub
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
 from hivemind.utils import DHTExpiration, HeapEntry, MSGPackSerializer, ValueWithExpiration
-from hivemind.utils.asyncio import achain, aenumerate, afirst, aiter, amap_in_executor, anext, asingle, azip
+from hivemind.utils.asyncio import (
+    achain,
+    aenumerate,
+    afirst,
+    aiter,
+    amap_in_executor,
+    anext,
+    asingle,
+    azip,
+    cancel_and_wait,
+)
 from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.mpfuture import InvalidStateError

@@ -509,3 +519,35 @@ async def test_asyncio_utils():
     assert await afirst(aiter()) is None
     assert await afirst(aiter(), -1) == -1
     assert await afirst(aiter(1, 2, 3)) == 1
+
+
+@pytest.mark.asyncio
+async def test_cancel_and_wait():
+    finished_gracefully = False
+
+    async def coro_with_finalizer():
+        nonlocal finished_gracefully
+
+        try:
+            await asyncio.Event().wait()
+        except asyncio.CancelledError:
+            await asyncio.sleep(0.05)
+            finished_gracefully = True
+            raise
+
+    task = asyncio.create_task(coro_with_finalizer())
+    await asyncio.sleep(0.05)
+    assert await cancel_and_wait(task)
+    assert finished_gracefully
+
+    async def coro_with_result():
+        return 777
+
+    async def coro_with_error():
+        raise ValueError("error")
+
+    task_with_result = asyncio.create_task(coro_with_result())
+    task_with_error = asyncio.create_task(coro_with_error())
+    await asyncio.sleep(0.05)
+    assert not await cancel_and_wait(task_with_result)
+    assert not await cancel_and_wait(task_with_error)

+ 1 - 1
tests/test_utils/p2p_daemon.py

@@ -157,7 +157,7 @@ async def _make_p2pd_pair(
     )
     # wait for daemon ready
     await p2pd.wait_until_ready()
-    client = Client(control_maddr=control_maddr, listen_maddr=listen_maddr)
+    client = await Client.create(control_maddr=control_maddr, listen_maddr=listen_maddr)
     try:
         async with client.listen():
             yield DaemonTuple(daemon=p2pd, client=client)