
Convert AllReduceRunner, Matchmaking, and GroupKeyManager to libp2p backend

Aleksandr Borzunov committed 4 years ago
commit a8fcb0a609

+ 41 - 47
hivemind/averaging/allreduce.py

@@ -2,14 +2,14 @@ import asyncio
 from typing import Sequence, Dict, Tuple, AsyncIterator, Any, Optional
 from enum import Enum
 
-import grpc
 import torch
 
 from hivemind.averaging.partition import TensorPartContainer, TensorPartReducer, AllreduceException
-from hivemind.utils import Endpoint, get_logger, ChannelCache
+from hivemind.p2p import P2P, P2PContext, PeerID as Endpoint, ServicerBase, StubBase
+from hivemind.utils import get_logger
 from hivemind.utils.asyncio import anext, achain, aiter, aenumerate, amap_in_executor
 from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
-from hivemind.proto import averaging_pb2_grpc, averaging_pb2
+from hivemind.proto import averaging_pb2
 
 # flavour types
 GroupID = bytes
@@ -22,7 +22,7 @@ class AveragingMode(Enum):
     AUX = 2
 
 
-class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
+class AllReduceRunner(ServicerBase):
     """
     An internal class that runs butterfly AllReduce in a predefined group of averagers
 
@@ -43,9 +43,9 @@ class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
     def __init__(
         self,
         *,
+        p2p: P2P,
         group_id: GroupID,
         tensors: Sequence[torch.Tensor],
-        endpoint: Endpoint,
         ordered_group_endpoints: Sequence[Endpoint],
         peer_fractions: Tuple[float, ...],
         weights: Optional[Sequence[float]] = None,
@@ -53,7 +53,10 @@ class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
         gathered: Optional[Dict[Endpoint, Any]] = None,
         **kwargs,
     ):
-        assert endpoint in ordered_group_endpoints, "endpoint is not a part of the group"
+        self._p2p = p2p
+        self.endpoint = p2p.id
+        assert self.endpoint in ordered_group_endpoints, "endpoint is not a part of the group"
+
         modes = modes or tuple(AveragingMode.CLIENT if frac == 0 else AveragingMode.NODE for frac in peer_fractions)
         weights = weights or tuple(int(mode != AveragingMode.AUX) for mode in modes)
         assert len(weights) == len(modes) == len(ordered_group_endpoints), "lists have inconsistent length"
@@ -62,7 +65,7 @@ class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
             assert mode != AveragingMode.CLIENT or frac == 0, "client-mode peer should have zero all-reduce fraction"
             assert mode != AveragingMode.AUX or weight == 0, "auxiliary peer should have zero averaging weight"
 
-        self.group_id, self.endpoint, self.ordered_group_endpoints = group_id, endpoint, ordered_group_endpoints
+        self.group_id, self.ordered_group_endpoints = group_id, ordered_group_endpoints
         self.modes, self.peer_fractions, self.gathered = modes, peer_fractions, gathered
 
         self._future = asyncio.Future()
@@ -95,8 +98,10 @@ class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
     def group_size(self):
         return len(self.ordered_group_endpoints)
 
-    def _get_peer_stub(self, peer: Endpoint) -> averaging_pb2_grpc.DecentralizedAveragingStub:
-        return ChannelCache.get_stub(peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
+    def _get_stub(self, peer: Endpoint) -> StubBase:
+        from hivemind.averaging.averager import DecentralizedAverager
+
+        return DecentralizedAverager.get_stub(self._p2p, peer)
 
     async def run(self) -> AsyncIterator[torch.Tensor]:
         """Run all-reduce, return differences between averaged and original tensors as they are computed"""
@@ -136,46 +141,35 @@ class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
 
 
         else:
             loop = asyncio.get_event_loop()
-            stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
-            write_task = asyncio.create_task(self._write_to_peer(stream, peer_index))
-
-            try:
-                code = None
-                async for part_index, msg in aenumerate(stream):
-                    if code is None:
-                        code = msg.code
-                    averaged_part_delta = await loop.run_in_executor(None, deserialize_torch_tensor, msg.tensor_part)
-                    self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part_delta)
-                await write_task
-
-                if code != averaging_pb2.AVERAGED_PART:
-                    raise AllreduceException(
-                        f"peer {peer_endpoint} returned {averaging_pb2.MessageCode.Name(code)} "
-                        f"instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)}"
-                        f", allreduce failed"
-                    )
-            finally:
-                if not write_task.done():
-                    write_task.cancel()
-
-    async def _write_to_peer(self, stream: grpc.aio.StreamStreamCall, peer_index: int):
+            code = None
+            stream = self._get_stub(peer_endpoint).rpc_aggregate_part(self._generate_input_for_peer(peer_index))
+            async for part_index, msg in aenumerate(stream):
+                if code is None:
+                    code = msg.code
+                averaged_part_delta = await loop.run_in_executor(None, deserialize_torch_tensor, msg.tensor_part)
+                self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part_delta)
+
+            if code != averaging_pb2.AVERAGED_PART:
+                raise AllreduceException(
+                    f"peer {peer_endpoint} returned {averaging_pb2.MessageCode.Name(code)} "
+                    f"instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)}"
+                    f", allreduce failed"
+                )
+
+    async def _generate_input_for_peer(self, peer_index: int) -> AsyncIterator[averaging_pb2.AveragingData]:
         parts_aiter = self.tensor_part_container.iterate_input_parts_for(peer_index)
         first_part = await anext(parts_aiter)
-        await stream.write(
-            averaging_pb2.AveragingData(
-                code=averaging_pb2.PART_FOR_AVERAGING,
-                group_id=self.group_id,
-                endpoint=self.endpoint,
-                tensor_part=first_part,
-            )
+        yield averaging_pb2.AveragingData(
+            code=averaging_pb2.PART_FOR_AVERAGING,
+            group_id=self.group_id,
+            endpoint=self.endpoint.to_base58(),
+            tensor_part=first_part,
         )
         async for part in parts_aiter:
-            await stream.write(averaging_pb2.AveragingData(tensor_part=part))
-
-        await stream.done_writing()
+            yield averaging_pb2.AveragingData(tensor_part=part)
 
     async def rpc_aggregate_part(
-        self, stream: AsyncIterator[averaging_pb2.AveragingData], context: grpc.ServicerContext
+        self, stream: AsyncIterator[averaging_pb2.AveragingData], _: P2PContext
     ) -> AsyncIterator[averaging_pb2.AveragingData]:
         """a peer sends us a part of his tensor; we should average it with other peers and return the difference"""
         request: averaging_pb2.AveragingData = await anext(stream)
@@ -186,7 +180,7 @@ class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
 
 
         elif request.code == averaging_pb2.PART_FOR_AVERAGING:
             try:
-                sender_index = self.sender_endpoints.index(request.endpoint)
+                sender_index = self.sender_endpoints.index(Endpoint.from_base58(request.endpoint))
                 async for msg in self._accumulate_parts_streaming(achain(aiter(request), stream), sender_index):
                     yield msg
 
@@ -224,9 +218,9 @@ class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
             yield averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=serialized_delta)
 
     async def _send_error_to_peer(self, peer_endpoint: Endpoint, code: averaging_pb2.MessageCode):
-        stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
-        await stream.write(averaging_pb2.AveragingData(group_id=self.group_id, endpoint=self.endpoint, code=code))
-        await stream.done_writing()
+        error = averaging_pb2.AveragingData(group_id=self.group_id, endpoint=self.endpoint.to_base58(), code=code)
+        async for _ in self._get_stub(peer_endpoint).rpc_aggregate_part(aiter(error)):
+            pass
 
     def finalize(self, *, cancel: bool = False, exception: Optional[BaseException] = None):
         """finish or terminate AllReduceRunner, propagate any errors / cancellations to peers."""

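The core change in allreduce.py: instead of opening a grpc stream and pushing parts with write()/done_writing(), the runner now passes an async generator (_generate_input_for_peer) to the stub call, and the libp2p backend consumes it lazily while streaming responses back. A minimal self-contained sketch of that request/response streaming pattern, using plain asyncio and dicts in place of averaging_pb2.AveragingData (generate_requests and fake_rpc_aggregate_part are illustrative names, not hivemind APIs):

    import asyncio
    from typing import AsyncIterator

    async def generate_requests(parts) -> AsyncIterator[dict]:
        # analogous to _generate_input_for_peer: the first message carries metadata,
        # the rest carry only tensor parts
        first, *rest = parts
        yield {"code": "PART_FOR_AVERAGING", "tensor_part": first}
        for part in rest:
            yield {"tensor_part": part}

    async def fake_rpc_aggregate_part(requests: AsyncIterator[dict]) -> AsyncIterator[dict]:
        # stands in for the stub call: consumes the request stream lazily and
        # yields one response per request, as the streaming RPC would
        async for request in requests:
            yield {"code": "AVERAGED_PART", "tensor_part": request["tensor_part"]}

    async def main():
        async for response in fake_rpc_aggregate_part(generate_requests([b"a", b"b", b"c"])):
            print(response)

    asyncio.run(main())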
+ 6 - 6
hivemind/averaging/averager.py

@@ -188,7 +188,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
 
     @property
     def endpoint(self) -> Endpoint:
-        return self.p2p.id
+        return self._p2p.id
 
     def run(self):
         """
@@ -207,14 +207,14 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
 
             async def _run():
-                self.p2p = await self.dht.replicate_p2p()
+                self._p2p = await self.dht.replicate_p2p()
                 if not self.client_mode:
-                    await self.add_p2p_handlers(self.p2p)
+                    await self.add_p2p_handlers(self._p2p)
                 else:
                     logger.debug(f"The averager is running in client mode.")
 
                 self._matchmaking = Matchmaking(
-                    self.endpoint, self.schema_hash, self.dht, **self.matchmaking_kwargs, client_mode=self.client_mode
+                    self._p2p, self.schema_hash, self.dht, **self.matchmaking_kwargs, client_mode=self.client_mode
                 )
                 if not self.client_mode:
                     asyncio.create_task(self._declare_for_download_periodically())
@@ -379,9 +379,9 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
 
             async with self.get_tensors_async() as local_tensors:
                 allreduce = AllReduceRunner(
+                    p2p=self._p2p,
                     group_id=group_info.group_id,
                     tensors=local_tensors,
-                    endpoint=self.endpoint,
                     ordered_group_endpoints=group_info.endpoints,
                     peer_fractions=peer_fractions,
                     weights=weights,
@@ -551,7 +551,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                 if peer != self.endpoint:
                     logger.info(f"Downloading parameters from peer {peer}")
                     try:
-                        stub = self.get_stub(self.p2p, peer)
+                        stub = self.get_stub(self._p2p, peer)
                         stream = stub.rpc_download_state(averaging_pb2.DownloadRequest())
                         current_tensor_parts, tensors = [], []
                         async for message in stream:

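In averager.py, the averager keeps the P2P instance replicated from the DHT in self._p2p and relies on ServicerBase for both sides of each RPC: add_p2p_handlers() registers its rpc_* coroutines on the daemon (skipped in client mode), while get_stub() builds a caller for a remote peer. A rough sketch assembled from the calls visible in this diff (wire_up is a hypothetical helper, not part of the commit):

    import hivemind
    from hivemind.averaging.averager import DecentralizedAverager
    from hivemind.p2p import P2P, PeerID
    from hivemind.proto import averaging_pb2

    async def wire_up(averager: DecentralizedAverager, dht: hivemind.DHT, peer: PeerID):
        # the same three steps the averager's _run() coroutine performs in this commit
        p2p: P2P = await dht.replicate_p2p()      # reuse the DHT's p2p daemon
        await averager.add_p2p_handlers(p2p)      # serve rpc_* handlers (not done in client mode)
        stub = averager.get_stub(p2p, peer)       # per-peer stub, replacing ChannelCache stubs
        return stub.rpc_download_state(averaging_pb2.DownloadRequest())  # async iterator of responses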
+ 6 - 5
hivemind/averaging/key_manager.py

@@ -5,9 +5,10 @@ from typing import Optional, List, Tuple
 
 
 import numpy as np
 
-from hivemind.dht import DHT
 from hivemind.averaging.group_info import GroupInfo
-from hivemind.utils import get_logger, Endpoint, DHTExpiration, get_dht_time, ValueWithExpiration
+from hivemind.dht import DHT
+from hivemind.p2p import PeerID as Endpoint
+from hivemind.utils import get_logger, DHTExpiration, get_dht_time, ValueWithExpiration
 
 GroupKey = str
 GROUP_PATTERN = re.compile("^(([^.])+)[.]0b[01]*$")  # e.g. bert_exp4_averaging.0b01001101
@@ -72,7 +73,7 @@ class GroupKeyManager:
         expiration_time = expiration_time if looking_for_group else float(np.nextafter(expiration_time, float("inf")))
         return await self.dht.store(
             key=group_key,
-            subkey=endpoint,
+            subkey=endpoint.to_base58(),
             value=looking_for_group,
             expiration_time=expiration_time,
             return_future=True,
@@ -93,11 +94,11 @@ class GroupKeyManager:
             logger.debug(f"Allreduce group not found: {group_key}, creating new group.")
             return []
         averagers = [
-            (key, entry.expiration_time)
+            (Endpoint.from_base58(key), entry.expiration_time)
             for key, entry in result.value.items()
             if key != self.RESERVED_KEY_FOR_NBITS and (not only_active or entry.value is True)
         ]
-        num_active_averagers = len([key for key, entry in result.value.items() if entry.value is True])
+        num_active_averagers = sum(1 for entry in result.value.values() if entry.value is True)
 
         suggested_nbits = self.get_suggested_nbits(result)
         if (

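GroupKeyManager now identifies averagers by libp2p peer IDs rather than "ip:port" strings: when declaring itself under a group key it stores the ID in base58 form as the DHT subkey, and it parses the subkeys back into PeerID objects when reading the group key. A short sketch of that round trip (PeerID(b"localhvost") mirrors the updated test; the variable names are illustrative):

    from hivemind.p2p import PeerID

    peer = PeerID(b"localhvost")                 # same construction the updated tests use
    subkey = peer.to_base58()                    # what declare_averager() now stores in the DHT
    assert PeerID.from_base58(subkey) == peer    # what get_averagers() does when reading entries back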
+ 31 - 28
hivemind/averaging/matchmaking.py

@@ -3,26 +3,25 @@
 from __future__ import annotations
 
 import contextlib
+from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID
 import random
 from math import isfinite
 from typing import Optional, AsyncIterator, Set, Tuple, Dict
 import concurrent.futures
 import asyncio
 
-import grpc
-import grpc._cython.cygrpc
-
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.key_manager import GroupKeyManager, GroupKey
 from hivemind.dht import DHT, DHTID, DHTExpiration
-from hivemind.utils import get_logger, Endpoint, timed_storage, TimedStorage, get_dht_time
-from hivemind.proto import averaging_pb2, averaging_pb2_grpc
-from hivemind.utils.grpc import ChannelCache
+from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID as Endpoint
+from hivemind.utils import get_logger, timed_storage, TimedStorage, get_dht_time
+from hivemind.utils.asyncio import anext
+from hivemind.proto import averaging_pb2
 
 logger = get_logger(__name__)
 
 
-class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
+class Matchmaking:
     f"""
     An internal class that is used to form groups of averages for running allreduce
     See DecentralizedAverager docstring for the detailed description of all parameters
@@ -37,7 +36,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
 
 
     def __init__(
         self,
-        endpoint: Endpoint,
+        p2p: P2P,
         schema_hash: bytes,
         dht: DHT,
         *,
@@ -57,8 +56,10 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
             )
 
         super().__init__()
-        self.endpoint, self.schema_hash = endpoint, schema_hash
-        self.group_key_manager = GroupKeyManager(dht, endpoint, prefix, initial_group_bits, target_group_size)
+        self._p2p = p2p
+        self.endpoint = p2p.id
+        self.schema_hash = schema_hash
+        self.group_key_manager = GroupKeyManager(dht, self.endpoint, prefix, initial_group_bits, target_group_size)
         self.target_group_size, self.min_group_size = target_group_size, min_group_size
         self.averaging_expiration, self.request_timeout = averaging_expiration, request_timeout
         self.client_mode = client_mode
@@ -71,7 +72,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
 
 
         self.current_leader: Optional[Endpoint] = None  # iff i am a follower, this is a link to my current leader
         self.current_followers: Dict[Endpoint, averaging_pb2.JoinRequest] = {}  # my current followers excluding myself
-        self.potential_leaders = PotentialLeaders(endpoint, averaging_expiration, target_group_size)
+        self.potential_leaders = PotentialLeaders(self.endpoint, averaging_expiration, target_group_size)
         self.data_for_gather: Optional[bytes] = None
 
     @property
@@ -169,20 +170,23 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
           The originally specified leader can disband group and redirect us to a different leader
         """
         assert self.is_looking_for_group and self.current_leader is None
-        call: Optional[grpc.aio.UnaryStreamCall] = None
+        stream: AsyncIterator[averaging_pb2.MessageFromLeader] = None
         try:
             async with self.lock_request_join_group:
-                leader_stub = ChannelCache.get_stub(leader, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
-                call = leader_stub.rpc_join_group(
+                from hivemind.averaging.averager import DecentralizedAverager
+
+                leader_stub = DecentralizedAverager.get_stub(self._p2p, leader)
+
+                stream = leader_stub.rpc_join_group(
                     averaging_pb2.JoinRequest(
-                        endpoint=self.endpoint,
+                        endpoint=self.endpoint.to_base58(),
                         schema_hash=self.schema_hash,
                         expiration=expiration_time,
                         client_mode=self.client_mode,
                         gather=self.data_for_gather,
                     )
-                )
-                message = await asyncio.wait_for(call.read(), timeout=self.request_timeout)
+                ).__aiter__()
+                message = await asyncio.wait_for(anext(stream), timeout=self.request_timeout)
 
                 if message.code == averaging_pb2.ACCEPTED:
                     logger.debug(f"{self.endpoint} - joining the group of {leader}; waiting for peers")
@@ -198,7 +202,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
 
 
             async with self.potential_leaders.pause_search():
                 time_to_expiration = max(expiration_time - get_dht_time(), 0.0)
-                message = await asyncio.wait_for(call.read(), time_to_expiration + self.request_timeout)
+                message = await asyncio.wait_for(anext(stream), time_to_expiration + self.request_timeout)
 
                 if message.code == averaging_pb2.BEGIN_ALLREDUCE:
                     async with self.lock_request_join_group:
@@ -208,7 +212,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
                 if message.suggested_leader and message.suggested_leader != self.endpoint:
                     logger.debug(f"{self} - leader disbanded group and redirected us to {message.suggested_leader}")
                     self.current_leader = None
-                    call.cancel()
+                    await stream.aclose()
                     return await self.request_join_group(message.suggested_leader, expiration_time)
                 else:
                     logger.debug(f"{self} - leader disbanded group")
@@ -218,23 +222,22 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
             return None
         except asyncio.TimeoutError:
             logger.debug(f"{self} - potential leader {leader} did not respond within {self.request_timeout}")
-            if call is not None:
-                call.cancel()
             return None
-        except (grpc.RpcError, grpc.aio.AioRpcError, grpc._cython.cygrpc.InternalError, StopAsyncIteration) as e:
+        except (P2PHandlerError, StopAsyncIteration) as e:
             logger.error(f"{self} - failed to request potential leader {leader}: {e}")
             return None
 
         finally:
             self.was_accepted_to_group.clear()
             self.current_leader = None
-            if call is not None:
-                await call.code()
+            if stream is not None:
+                await stream.aclose()
 
     async def rpc_join_group(
-        self, request: averaging_pb2.JoinRequest, context: grpc.ServicerContext
+        self, request: averaging_pb2.JoinRequest, _: P2PContext
     ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
         """accept or reject a join request from another averager; if accepted, run him through allreduce steps"""
+        request_endpoint = PeerID.from_base58(request.endpoint)
         try:
             async with self.lock_request_join_group:
                 reason_to_reject = self._check_reasons_to_reject(request)
@@ -242,7 +245,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
                     yield reason_to_reject
                     return
 
-                self.current_followers[request.endpoint] = request
+                self.current_followers[request_endpoint] = request
                 yield averaging_pb2.MessageFromLeader(code=averaging_pb2.ACCEPTED)
 
                 if len(self.current_followers) + 1 >= self.target_group_size and not self.assembled_group.done():
@@ -270,7 +273,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
                 self.was_accepted_to_group.is_set()
                 or not self.assembled_group.done()
                 or self.assembled_group.cancelled()
-                or request.endpoint not in self.assembled_group.result()
+                or request_endpoint not in self.assembled_group.result()
             ):
                 if self.current_leader is not None:
                     # outcome 3: found by a leader with higher priority, send our followers to him
@@ -296,7 +299,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
             yield averaging_pb2.MessageFromLeader(code=averaging_pb2.INTERNAL_ERROR)
 
         finally:  # note: this code is guaranteed to run even if the coroutine is destroyed prematurely
-            self.current_followers.pop(request.endpoint, None)
+            self.current_followers.pop(request_endpoint, None)
             self.follower_was_discarded.set()
 
     def _check_reasons_to_reject(

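Matchmaking's leader requests follow the same shift: rpc_join_group() now returns an async iterator, so the follower reads messages with anext() under asyncio.wait_for() and closes the stream with aclose() instead of cancelling a grpc call. A self-contained sketch of that consumption pattern in plain asyncio (fake_rpc_join_group stands in for the stub call; hivemind's anext helper is written as __anext__ here):

    import asyncio
    from typing import AsyncIterator

    async def fake_rpc_join_group() -> AsyncIterator[str]:
        # stands in for leader_stub.rpc_join_group(...): a lazily produced stream of messages
        yield "ACCEPTED"
        await asyncio.sleep(0.1)
        yield "BEGIN_ALLREDUCE"

    async def main():
        stream = fake_rpc_join_group().__aiter__()
        try:
            first = await asyncio.wait_for(stream.__anext__(), timeout=1.0)   # was call.read() under grpc
            print("leader said:", first)
            second = await asyncio.wait_for(stream.__anext__(), timeout=1.0)
            print("leader said:", second)
        finally:
            await stream.aclose()   # replaces call.cancel() / await call.code()

    asyncio.run(main())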
+ 1 - 1
hivemind/averaging/partition.py

@@ -32,7 +32,7 @@ class TensorPartContainer:
         self,
         tensors: Sequence[torch.Tensor],
         peer_fractions: Sequence[float],
-        compression_type: Union[type(CompressionType), Sequence[type(CompressionType)]] = CompressionType.NONE,
+        compression_type: Union['CompressionType', Sequence['CompressionType']] = CompressionType.NONE,
         part_size_bytes: int = 2 ** 20,
         prefetch: int = 1,
     ):

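The partition.py change only touches the annotation: with the protobuf Python runtime, CompressionType is an EnumTypeWrapper instance, so type(CompressionType) names the wrapper class rather than the set of enum values, and a string annotation is used instead. A quick check (illustrative only):

    from hivemind.proto.runtime_pb2 import CompressionType

    # CompressionType is a protobuf enum wrapper object; its members are plain ints
    print(type(CompressionType))      # the wrapper class, not a class of enum values
    print(CompressionType.NONE)       # 0, the default used above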
+ 4 - 2
hivemind/p2p/servicer.py

@@ -75,8 +75,10 @@ class ServicerBase:
 
 
                 spec = inspect.getfullargspec(method)
                 if len(spec.args) < 3:
-                    raise ValueError(f"{handle_name} is expected to at least three positional arguments "
-                                     f"(self: TServicer, request: TInputProtobuf, context: hivemind.p2p.P2PContext)")
+                    raise ValueError(
+                        f"{handle_name} is expected to at least three positional arguments "
+                        f"(self: TServicer, request: TInputProtobuf, context: hivemind.p2p.P2PContext)"
+                    )
                 request_arg = spec.args[1]
                 hints = get_type_hints(method)
                 try:

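servicer.py only reformats the error raised when an rpc_* handler has fewer than three positional arguments; the requirement itself is unchanged: every handler takes (self, request, context) with a type hint on the request argument, which ServicerBase inspects to pick the protobuf type. A hypothetical servicer following that contract, mirroring AllReduceRunner.rpc_aggregate_part (ExampleServicer is illustrative, not part of the commit):

    from typing import AsyncIterator

    from hivemind.p2p import P2PContext, ServicerBase
    from hivemind.proto import averaging_pb2

    class ExampleServicer(ServicerBase):
        # each rpc_* handler must accept (self, request, context) and annotate the request type,
        # which is exactly what the length check above enforces
        async def rpc_aggregate_part(
            self, stream: AsyncIterator[averaging_pb2.AveragingData], context: P2PContext
        ) -> AsyncIterator[averaging_pb2.AveragingData]:
            async for message in stream:
                yield message   # echo each part back, for illustration only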
+ 0 - 109
tests/log

@@ -1,109 +0,0 @@
-[2021/07/10 11:53:40.062][DEBUG][utils.grpc.ChannelCache:49] Eviction period = 600.0s, max channels = 4096
-[2021/07/10 11:53:40.247][DEBUG][asyncio.__init__:59] Using selector: EpollSelector
-[2021/07/10 11:53:40.654][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe5d8f481f0> opens connection to /unix/tmp/hivemind-p2pd-B4tnhAq6U1U.sock
-[2021/07/10 11:53:40.672][DEBUG][p2p.p2p_daemon._ping_daemon:193] Launched p2pd with id = QmTqNtzhfJrChohykAkiZQ8xBJW8995EfQynqAf37BNCak, host multiaddrs = (<Multiaddr /ip4/127.0.0.1/tcp/38633>,)
-[2021/07/10 11:53:40.673][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe5d8f481f0> opens connection to /unix/tmp/hivemind-p2pd-B4tnhAq6U1U.sock
-[2021/07/10 11:53:40.674][DEBUG][p2p.p2p_daemon_bindings.control.listen:108] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.ControlClient object at 0x7fe5d8f48250> starts listening to /unix/tmp/hivemind-p2pclient-B4tnhAq6U1U.sock
-[2021/07/10 11:53:40.674][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe5d8f481f0> opens connection to /unix/tmp/hivemind-p2pd-B4tnhAq6U1U.sock
-[2021/07/10 11:53:40.675][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe5d8f481f0> opens connection to /unix/tmp/hivemind-p2pd-B4tnhAq6U1U.sock
-[2021/07/10 11:53:40.676][INFO][root.run_protocol_listener:40] Started peer id=DHTID(0x3db9c894718f96bdd909afcef13a545941e7e1e4) visible_maddrs=[<Multiaddr /ip4/127.0.0.1/tcp/38633/p2p/QmTqNtzhfJrChohykAkiZQ8xBJW8995EfQynqAf37BNCak>]
-[2021/07/10 11:53:40.683][DEBUG][asyncio.__init__:59] Using selector: EpollSelector
-[2021/07/10 11:53:41.101][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe6623e2d30> opens connection to /unix/tmp/hivemind-p2pd-vXE2xWbi9bg.sock
-[2021/07/10 11:53:41.103][DEBUG][p2p.p2p_daemon._ping_daemon:193] Launched p2pd with id = Qma6rFAHFdwecmQKqWtU59LRqGAeXDam1rLc1u5ueNbwXm, host multiaddrs = (<Multiaddr /ip4/127.0.0.1/tcp/41695>,)
-[2021/07/10 11:53:41.105][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe6623e2d30> opens connection to /unix/tmp/hivemind-p2pd-vXE2xWbi9bg.sock
-[2021/07/10 11:53:41.105][DEBUG][p2p.p2p_daemon_bindings.control.listen:108] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.ControlClient object at 0x7fe6623e2dc0> starts listening to /unix/tmp/hivemind-p2pclient-vXE2xWbi9bg.sock
-[2021/07/10 11:53:41.106][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe6623e2d30> opens connection to /unix/tmp/hivemind-p2pd-vXE2xWbi9bg.sock
-[2021/07/10 11:53:41.106][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe6623e2d30> opens connection to /unix/tmp/hivemind-p2pd-vXE2xWbi9bg.sock
-[2021/07/10 11:53:41.107][INFO][root.run_protocol_listener:40] Started peer id=DHTID(0xa7087d197c0f2758c6bff7256e3cb78c5152c137) visible_maddrs=[<Multiaddr /ip4/127.0.0.1/tcp/41695/p2p/Qma6rFAHFdwecmQKqWtU59LRqGAeXDam1rLc1u5ueNbwXm>]
-[2021/07/10 11:53:41.108][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe6623e2d30> opens connection to /unix/tmp/hivemind-p2pd-vXE2xWbi9bg.sock
-[2021/07/10 11:53:41.109][DEBUG][p2p.p2p_daemon_bindings.control._handler:83] New incoming stream: <StreamInfo peer_id=Qma6rFAHFdwecmQKqWtU59LRqGAeXDam1rLc1u5ueNbwXm addr=/ip4/127.0.0.1/tcp/41695 proto=DHTProtocol.rpc_ping>
-[2021/07/10 11:53:41.111][DEBUG][asyncio.__init__:59] Using selector: EpollSelector
-[2021/07/10 11:53:41.528][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe6623e5d90> opens connection to /unix/tmp/hivemind-p2pd-VRUjGSy2BP8.sock
-[2021/07/10 11:53:41.530][DEBUG][p2p.p2p_daemon._ping_daemon:193] Launched p2pd with id = QmYqP7ECXjtxaTi54eznbk2tgq9D43xpyL8hRYsksRMQ8Q, host multiaddrs = (<Multiaddr /ip4/127.0.0.1/tcp/34419>,)
-[2021/07/10 11:53:41.531][INFO][root.test_dht_protocol:81] Self id=DHTID(0xc147c315f7abd7fff33ccb0bd7b35cb0215ca616)
-[2021/07/10 11:53:41.531][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe6623e5d90> opens connection to /unix/tmp/hivemind-p2pd-VRUjGSy2BP8.sock
-[2021/07/10 11:53:41.532][DEBUG][p2p.p2p_daemon_bindings.control._handler:83] New incoming stream: <StreamInfo peer_id=QmYqP7ECXjtxaTi54eznbk2tgq9D43xpyL8hRYsksRMQ8Q addr=/ip4/127.0.0.1/tcp/34419 proto=DHTProtocol.rpc_ping>
-[2021/07/10 11:53:41.534][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe6623e5d90> opens connection to /unix/tmp/hivemind-p2pd-VRUjGSy2BP8.sock
-[2021/07/10 11:53:41.535][DEBUG][p2p.p2p_daemon_bindings.control._handler:83] New incoming stream: <StreamInfo peer_id=QmYqP7ECXjtxaTi54eznbk2tgq9D43xpyL8hRYsksRMQ8Q addr=/ip4/127.0.0.1/tcp/34419 proto=DHTProtocol.rpc_store>
-[2021/07/10 11:53:41.536][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe6623e5d90> opens connection to /unix/tmp/hivemind-p2pd-VRUjGSy2BP8.sock
-[2021/07/10 11:53:41.536][DEBUG][p2p.p2p_daemon_bindings.control._handler:83] New incoming stream: <StreamInfo peer_id=QmYqP7ECXjtxaTi54eznbk2tgq9D43xpyL8hRYsksRMQ8Q addr=/ip4/127.0.0.1/tcp/34419 proto=DHTProtocol.rpc_find>
-[2021/07/10 11:53:41.538][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe6623e5d90> opens connection to /unix/tmp/hivemind-p2pd-VRUjGSy2BP8.sock
-[2021/07/10 11:53:41.539][DEBUG][p2p.p2p_daemon_bindings.control._handler:83] New incoming stream: <StreamInfo peer_id=QmYqP7ECXjtxaTi54eznbk2tgq9D43xpyL8hRYsksRMQ8Q addr=/ip4/127.0.0.1/tcp/34419 proto=DHTProtocol.rpc_find>
-[2021/07/10 11:53:41.541][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe6623e5d90> opens connection to /unix/tmp/hivemind-p2pd-VRUjGSy2BP8.sock
-[2021/07/10 11:53:41.542][ERROR][dht.protocol.call_find:278] DHTProtocol failed to find at fakeid
-Traceback (most recent call last):
-  File "/home/borzunov/hivemind/hivemind/dht/protocol.py", line 245, in call_find
-    response = await self.get_stub(peer).rpc_find(find_request, timeout=self.wait_timeout)
-  File "/home/borzunov/hivemind/hivemind/utils/auth.py", line 199, in wrapped_rpc
-    response = await method(request, *args, **kwargs)
-  File "/home/borzunov/hivemind/hivemind/p2p/servicer.py", line 72, in caller
-    return await asyncio.wait_for(
-  File "/home/borzunov/anaconda3/lib/python3.8/asyncio/tasks.py", line 483, in wait_for
-    return fut.result()
-  File "/home/borzunov/hivemind/hivemind/p2p/p2p_daemon.py", line 374, in call_protobuf_handler
-    stream_info, reader, writer = await self._client.stream_open(peer_id, (handler_name,))
-  File "/home/borzunov/hivemind/hivemind/p2p/p2p_daemon_bindings/p2pclient.py", line 76, in stream_open
-    return await self.control.stream_open(peer_id=peer_id, protocols=protocols)
-  File "/home/borzunov/hivemind/hivemind/p2p/p2p_daemon_bindings/control.py", line 185, in stream_open
-    raise_if_failed(resp)
-  File "/home/borzunov/hivemind/hivemind/p2p/p2p_daemon_bindings/utils.py", line 60, in raise_if_failed
-    raise ControlFailure(f"Connect failed. msg={response.error.msg}")
-hivemind.p2p.p2p_daemon_bindings.utils.ControlFailure: Connect failed. msg=length greater than remaining number of bytes in buffer
-[2021/07/10 11:53:41.544][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe6623e5d90> opens connection to /unix/tmp/hivemind-p2pd-VRUjGSy2BP8.sock
-[2021/07/10 11:53:41.545][DEBUG][p2p.p2p_daemon_bindings.control._handler:83] New incoming stream: <StreamInfo peer_id=QmYqP7ECXjtxaTi54eznbk2tgq9D43xpyL8hRYsksRMQ8Q addr=/ip4/127.0.0.1/tcp/34419 proto=DHTProtocol.rpc_store>
-[2021/07/10 11:53:41.546][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe6623e5d90> opens connection to /unix/tmp/hivemind-p2pd-VRUjGSy2BP8.sock
-[2021/07/10 11:53:41.547][DEBUG][p2p.p2p_daemon_bindings.control._handler:83] New incoming stream: <StreamInfo peer_id=QmYqP7ECXjtxaTi54eznbk2tgq9D43xpyL8hRYsksRMQ8Q addr=/ip4/127.0.0.1/tcp/34419 proto=DHTProtocol.rpc_store>
-[2021/07/10 11:53:41.547][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe6623e5d90> opens connection to /unix/tmp/hivemind-p2pd-VRUjGSy2BP8.sock
-[2021/07/10 11:53:41.548][DEBUG][p2p.p2p_daemon_bindings.control._handler:83] New incoming stream: <StreamInfo peer_id=QmYqP7ECXjtxaTi54eznbk2tgq9D43xpyL8hRYsksRMQ8Q addr=/ip4/127.0.0.1/tcp/34419 proto=DHTProtocol.rpc_find>
-[2021/07/10 11:53:41.958][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe66237f0a0> opens connection to /unix/tmp/hivemind-p2pd-Pswf-De9YNA.sock
-[2021/07/10 11:53:41.960][DEBUG][p2p.p2p_daemon._ping_daemon:193] Launched p2pd with id = QmfJuUCnS8PFAazbfgiExgpYGDCGnq5pBCQBSFZjVB71mE, host multiaddrs = (<Multiaddr /ip4/127.0.0.1/tcp/35705>,)
-[2021/07/10 11:53:41.961][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe66237f0a0> opens connection to /unix/tmp/hivemind-p2pd-Pswf-De9YNA.sock
-[2021/07/10 11:53:41.962][DEBUG][p2p.p2p_daemon_bindings.control.listen:108] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.ControlClient object at 0x7fe66237f220> starts listening to /unix/tmp/hivemind-p2pclient-Pswf-De9YNA.sock
-[2021/07/10 11:53:41.963][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe66237f0a0> opens connection to /unix/tmp/hivemind-p2pd-Pswf-De9YNA.sock
-[2021/07/10 11:53:41.964][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe66237f0a0> opens connection to /unix/tmp/hivemind-p2pd-Pswf-De9YNA.sock
-[2021/07/10 11:53:41.965][INFO][root.test_dht_protocol:81] Self id=DHTID(0x493ce245bb79ab0900ad291134df41125dfcf217)
-[2021/07/10 11:53:41.966][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe66237f0a0> opens connection to /unix/tmp/hivemind-p2pd-Pswf-De9YNA.sock
-[2021/07/10 11:53:41.967][DEBUG][p2p.p2p_daemon_bindings.control._handler:83] New incoming stream: <StreamInfo peer_id=QmfJuUCnS8PFAazbfgiExgpYGDCGnq5pBCQBSFZjVB71mE addr=/ip4/127.0.0.1/tcp/35705 proto=DHTProtocol.rpc_ping>
-[2021/07/10 11:53:41.968][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe66237f0a0> opens connection to /unix/tmp/hivemind-p2pd-Pswf-De9YNA.sock
-[2021/07/10 11:53:41.969][DEBUG][p2p.p2p_daemon_bindings.control._handler:83] New incoming stream: <StreamInfo peer_id=QmfJuUCnS8PFAazbfgiExgpYGDCGnq5pBCQBSFZjVB71mE addr=/ip4/127.0.0.1/tcp/35705 proto=DHTProtocol.rpc_store>
-[2021/07/10 11:53:41.970][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe66237f0a0> opens connection to /unix/tmp/hivemind-p2pd-Pswf-De9YNA.sock
-[2021/07/10 11:53:41.971][DEBUG][p2p.p2p_daemon_bindings.control._handler:83] New incoming stream: <StreamInfo peer_id=QmfJuUCnS8PFAazbfgiExgpYGDCGnq5pBCQBSFZjVB71mE addr=/ip4/127.0.0.1/tcp/35705 proto=DHTProtocol.rpc_find>
-[2021/07/10 11:53:41.973][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe66237f0a0> opens connection to /unix/tmp/hivemind-p2pd-Pswf-De9YNA.sock
-[2021/07/10 11:53:41.974][DEBUG][p2p.p2p_daemon_bindings.control._handler:83] New incoming stream: <StreamInfo peer_id=QmfJuUCnS8PFAazbfgiExgpYGDCGnq5pBCQBSFZjVB71mE addr=/ip4/127.0.0.1/tcp/35705 proto=DHTProtocol.rpc_find>
-[2021/07/10 11:53:41.975][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe66237f0a0> opens connection to /unix/tmp/hivemind-p2pd-Pswf-De9YNA.sock
-[2021/07/10 11:53:41.975][ERROR][dht.protocol.call_find:278] DHTProtocol failed to find at fakeid
-Traceback (most recent call last):
-  File "/home/borzunov/hivemind/hivemind/dht/protocol.py", line 245, in call_find
-    response = await self.get_stub(peer).rpc_find(find_request, timeout=self.wait_timeout)
-  File "/home/borzunov/hivemind/hivemind/utils/auth.py", line 199, in wrapped_rpc
-    response = await method(request, *args, **kwargs)
-  File "/home/borzunov/hivemind/hivemind/p2p/servicer.py", line 72, in caller
-    return await asyncio.wait_for(
-  File "/home/borzunov/anaconda3/lib/python3.8/asyncio/tasks.py", line 483, in wait_for
-    return fut.result()
-  File "/home/borzunov/hivemind/hivemind/p2p/p2p_daemon.py", line 374, in call_protobuf_handler
-    stream_info, reader, writer = await self._client.stream_open(peer_id, (handler_name,))
-  File "/home/borzunov/hivemind/hivemind/p2p/p2p_daemon_bindings/p2pclient.py", line 76, in stream_open
-    return await self.control.stream_open(peer_id=peer_id, protocols=protocols)
-  File "/home/borzunov/hivemind/hivemind/p2p/p2p_daemon_bindings/control.py", line 185, in stream_open
-    raise_if_failed(resp)
-  File "/home/borzunov/hivemind/hivemind/p2p/p2p_daemon_bindings/utils.py", line 60, in raise_if_failed
-    raise ControlFailure(f"Connect failed. msg={response.error.msg}")
-hivemind.p2p.p2p_daemon_bindings.utils.ControlFailure: Connect failed. msg=length greater than remaining number of bytes in buffer
-[2021/07/10 11:53:41.976][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe66237f0a0> opens connection to /unix/tmp/hivemind-p2pd-Pswf-De9YNA.sock
-[2021/07/10 11:53:41.977][DEBUG][p2p.p2p_daemon_bindings.control._handler:83] New incoming stream: <StreamInfo peer_id=QmfJuUCnS8PFAazbfgiExgpYGDCGnq5pBCQBSFZjVB71mE addr=/ip4/127.0.0.1/tcp/35705 proto=DHTProtocol.rpc_store>
-[2021/07/10 11:53:41.978][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe66237f0a0> opens connection to /unix/tmp/hivemind-p2pd-Pswf-De9YNA.sock
-[2021/07/10 11:53:41.978][DEBUG][p2p.p2p_daemon_bindings.control._handler:83] New incoming stream: <StreamInfo peer_id=QmfJuUCnS8PFAazbfgiExgpYGDCGnq5pBCQBSFZjVB71mE addr=/ip4/127.0.0.1/tcp/35705 proto=DHTProtocol.rpc_store>
-[2021/07/10 11:53:41.979][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe66237f0a0> opens connection to /unix/tmp/hivemind-p2pd-Pswf-De9YNA.sock
-[2021/07/10 11:53:41.980][DEBUG][p2p.p2p_daemon_bindings.control._handler:83] New incoming stream: <StreamInfo peer_id=QmfJuUCnS8PFAazbfgiExgpYGDCGnq5pBCQBSFZjVB71mE addr=/ip4/127.0.0.1/tcp/35705 proto=DHTProtocol.rpc_find>
-2021-07-10T11:53:41.981+0300	ERROR	p2pd	error accepting connection	{"error": "accept unix /tmp/hivemind-p2pd-Pswf-De9YNA.sock: use of closed network connection"}
-[2021/07/10 11:53:41.985][DEBUG][p2p.p2p_daemon._terminate:401] Terminated p2pd with id = QmfJuUCnS8PFAazbfgiExgpYGDCGnq5pBCQBSFZjVB71mE
-2021-07-10T11:53:41.985+0300	ERROR	p2pd	error accepting connection	{"error": "accept unix /tmp/hivemind-p2pd-B4tnhAq6U1U.sock: use of closed network connection"}
-[2021/07/10 11:53:41.988][DEBUG][p2p.p2p_daemon._terminate:401] Terminated p2pd with id = Qma6rFAHFdwecmQKqWtU59LRqGAeXDam1rLc1u5ueNbwXm
-[2021/07/10 11:53:41.988][DEBUG][p2p.p2p_daemon._terminate:401] Terminated p2pd with id = Qma6rFAHFdwecmQKqWtU59LRqGAeXDam1rLc1u5ueNbwXm
-[2021/07/10 11:53:41.989][INFO][root.shutdown:49] Finished peer id=DHTID(0xa7087d197c0f2758c6bff7256e3cb78c5152c137) maddrs=[<Multiaddr /ip4/127.0.0.1/tcp/41695/p2p/Qma6rFAHFdwecmQKqWtU59LRqGAeXDam1rLc1u5ueNbwXm>]
-[2021/07/10 11:53:41.989][DEBUG][p2p.p2p_daemon._terminate:401] Terminated p2pd with id = QmTqNtzhfJrChohykAkiZQ8xBJW8995EfQynqAf37BNCak
-[2021/07/10 11:53:41.989][DEBUG][p2p.p2p_daemon._terminate:401] Terminated p2pd with id = QmTqNtzhfJrChohykAkiZQ8xBJW8995EfQynqAf37BNCak
-[2021/07/10 11:53:41.990][INFO][root.shutdown:49] Finished peer id=DHTID(0x3db9c894718f96bdd909afcef13a545941e7e1e4) maddrs=[<Multiaddr /ip4/127.0.0.1/tcp/38633/p2p/QmTqNtzhfJrChohykAkiZQ8xBJW8995EfQynqAf37BNCak>]
-[2021/07/10 11:53:41.990][INFO][root.shutdown:49] Finished peer id=DHTID(0x3db9c894718f96bdd909afcef13a545941e7e1e4) maddrs=[<Multiaddr /ip4/127.0.0.1/tcp/38633/p2p/QmTqNtzhfJrChohykAkiZQ8xBJW8995EfQynqAf37BNCak>]
-[2021/07/10 11:53:42.033][DEBUG][p2p.p2p_daemon._terminate:401] Terminated p2pd with id = QmYqP7ECXjtxaTi54eznbk2tgq9D43xpyL8hRYsksRMQ8Q

+ 19 - 31
tests/test_allreduce.py

@@ -3,16 +3,15 @@ import random
 import time
 from typing import Sequence
 
-import grpc
 import pytest
 import torch
 
-from hivemind import aenumerate, Endpoint
+from hivemind import aenumerate
 from hivemind.averaging.allreduce import AllReduceRunner, AveragingMode
 from hivemind.averaging.partition import TensorPartContainer, TensorPartReducer
-from hivemind.proto import averaging_pb2_grpc
+from hivemind.p2p import P2P, StubBase
 from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.utils import deserialize_torch_tensor, ChannelCache
+from hivemind.utils import deserialize_torch_tensor
 
 
 @pytest.mark.forked
@@ -152,19 +151,6 @@ async def test_reducer(num_senders: int, num_parts: int, synchronize_prob: float
             assert torch.allclose(averaging_result, reference_tensor, rtol=1e-3, atol=1e-5)
 
 
-class AllreduceRunnerForTesting(AllReduceRunner):
-    """a version of AllReduceRunner that was monkey-patched to accept custom endpoint names"""
-
-    def __init__(self, *args, peer_endpoints, **kwargs):
-        self.__peer_endpoints = peer_endpoints
-        super().__init__(*args, **kwargs)
-
-    def _get_peer_stub(self, peer: Endpoint) -> averaging_pb2_grpc.DecentralizedAveragingStub:
-        return ChannelCache.get_stub(
-            self.__peer_endpoints[peer], averaging_pb2_grpc.DecentralizedAveragingStub, aio=True
-        )
-
-
 NODE, CLIENT, AUX = AveragingMode.NODE, AveragingMode.CLIENT, AveragingMode.AUX
 
 
@@ -190,8 +176,18 @@ NODE, CLIENT, AUX = AveragingMode.NODE, AveragingMode.CLIENT, AveragingMode.AUX
 async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions, part_size_bytes):
     """Run group allreduce protocol manually without grpc, see if the internal logic is working as intended"""
 
-    peers = "alice", "bob", "carol", "colab"
+    class AllreduceRunnerForTesting(AllReduceRunner):
+        def _get_stub(self, peer: str) -> StubBase:
+            return AllReduceRunner.get_stub(self._p2p, peer)
+
+    p2ps = []
+    initial_peers = []
+    for _ in range(4):
+        instance = await P2P.create(initial_peers=initial_peers)
+        p2ps.append(instance)
+        initial_peers.extend(await instance.get_visible_maddrs())
 
+    peers = [instance.id for instance in p2ps]
     tensors_by_peer = {
         peer: [torch.randn(3, 128), torch.rand(32), torch.tensor(i, dtype=torch.float32)]
         for i, peer in enumerate(peers)
@@ -199,28 +195,20 @@ async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions,
 
 
     group_id = random.getrandbits(160).to_bytes(length=20, byteorder="big")
 
-    servers = []
     allreduce_protocols = []
-    peer_endpoints = {}
-
-    for peer in peers:
-        server = grpc.aio.server()
+    for p2p, peer in zip(p2ps, peers):
         allreduce_protocol = AllreduceRunnerForTesting(
+            p2p=p2p,
             group_id=group_id,
-            endpoint=peer,
             tensors=[x.clone() for x in tensors_by_peer[peer]],
             ordered_group_endpoints=peers,
             peer_fractions=peer_fractions,
             modes=peer_modes,
             weights=averaging_weights,
-            peer_endpoints=peer_endpoints,
             part_size_bytes=part_size_bytes,
         )
-        averaging_pb2_grpc.add_DecentralizedAveragingServicer_to_server(allreduce_protocol, server)
-        peer_endpoints[peer] = f"127.0.0.1:{server.add_insecure_port('127.0.0.1:*')}"
+        await allreduce_protocol.add_p2p_handlers(p2p)
        allreduce_protocols.append(allreduce_protocol)
-        servers.append(server)
-        await server.start()
 
     async def _run_allreduce_inplace(allreduce: AllReduceRunner):
         async for tensor_index, tensor_delta in aenumerate(allreduce):
@@ -244,5 +232,5 @@ async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions,
         assert len(output_tensors) == len(targets_for_peer)
         assert all(torch.allclose(our, ref, atol=1e-6, rtol=0) for our, ref in zip(output_tensors, targets_for_peer))
 
-    for server in servers:
-        await server.stop(grace=1)
+    for instance in p2ps:
+        await instance.shutdown()

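test_allreduce.py now spins up a real local swarm: four P2P daemons are created in sequence, each bootstrapped with the visible multiaddrs of the ones before it, and their peer IDs serve as the group endpoints. The same bootstrap pattern extracted into a sketch (spawn_local_swarm is a hypothetical helper; running it requires hivemind's p2pd daemon to be available):

    import asyncio
    from hivemind.p2p import P2P

    async def spawn_local_swarm(n: int = 4):
        # each new daemon is given the multiaddrs of the ones created before it
        p2ps, initial_peers = [], []
        for _ in range(n):
            instance = await P2P.create(initial_peers=initial_peers)
            p2ps.append(instance)
            initial_peers.extend(await instance.get_visible_maddrs())
        return p2ps

    async def main():
        p2ps = await spawn_local_swarm()
        print([instance.id for instance in p2ps])   # these PeerIDs are the group "endpoints" now
        for instance in p2ps:
            await instance.shutdown()

    asyncio.run(main())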
+ 18 - 10
tests/test_averaging.py

@@ -9,15 +9,19 @@ import hivemind.averaging.averager
 from hivemind.averaging.allreduce import AveragingMode
 from hivemind.averaging.key_manager import GroupKeyManager
 from hivemind.averaging.load_balancing import load_balance_peers
+from hivemind.p2p import PeerID
 from hivemind.proto.runtime_pb2 import CompressionType
 
 
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_key_manager():
+    localhvost = PeerID(b"localhvost")
+    localhvost2 = PeerID(b"localhvost2")
+
     key_manager = GroupKeyManager(
         hivemind.DHT(start=True),
-        endpoint="localhvost",
+        endpoint=localhvost,
         prefix="test_averaging",
         initial_group_bits="10110",
         target_group_size=2,
@@ -25,24 +29,24 @@ async def test_key_manager():
 
 
     t = hivemind.get_dht_time()
     key = key_manager.current_key
-    await key_manager.declare_averager(key, "localhvost", expiration_time=t + 60)
-    await key_manager.declare_averager(key, "localhvost2", expiration_time=t + 61)
+    await key_manager.declare_averager(key, localhvost, expiration_time=t + 60)
+    await key_manager.declare_averager(key, localhvost2, expiration_time=t + 61)
 
     q1 = await key_manager.get_averagers(key, only_active=True)
 
-    await key_manager.declare_averager(key, "localhvost", expiration_time=t + 66)
+    await key_manager.declare_averager(key, localhvost, expiration_time=t + 66)
     q2 = await key_manager.get_averagers(key, only_active=True)
 
-    await key_manager.declare_averager(key, "localhvost2", expiration_time=t + 61, looking_for_group=False)
+    await key_manager.declare_averager(key, localhvost2, expiration_time=t + 61, looking_for_group=False)
     q3 = await key_manager.get_averagers(key, only_active=True)
     q4 = await key_manager.get_averagers(key, only_active=False)
 
     q5 = await key_manager.get_averagers("nonexistent_key.0b0101", only_active=False)
 
-    assert len(q1) == 2 and ("localhvost", t + 60) in q1 and ("localhvost2", t + 61) in q1
-    assert len(q2) == 2 and ("localhvost", t + 66) in q2 and ("localhvost2", t + 61) in q2
-    assert len(q3) == 1 and ("localhvost", t + 66) in q3
-    assert len(q4) == 2 and ("localhvost", t + 66) in q4 and ("localhvost2", t + 61) in q2
+    assert len(q1) == 2 and (localhvost, t + 60) in q1 and (localhvost2, t + 61) in q1
+    assert len(q2) == 2 and (localhvost, t + 66) in q2 and (localhvost2, t + 61) in q2
+    assert len(q3) == 1 and (localhvost, t + 66) in q3
+    assert len(q4) == 2 and (localhvost, t + 66) in q4 and (localhvost2, t + 61) in q2
     assert len(q5) == 0
 
 
@@ -459,7 +463,11 @@ def test_load_state_from_peers():
 def test_getset_bits():
     dht = hivemind.DHT(start=True)
     averager = hivemind.averaging.DecentralizedAverager(
-        [torch.randn(3)], dht=dht, start=True, prefix="test_prefix", target_group_size=2,
+        [torch.randn(3)],
+        dht=dht,
+        start=True,
+        prefix="test_prefix",
+        target_group_size=2,
     )
     averager.set_group_bits("00101011101010")
     assert averager.get_group_bits() == "00101011101010"