
Convert AllReduceRunner, Matchmaking, and GroupKeyManager to libp2p backend

Aleksandr Borzunov 4 years ago
Parent
Commit
a8fcb0a609

+ 41 - 47
hivemind/averaging/allreduce.py

@@ -2,14 +2,14 @@ import asyncio
 from typing import Sequence, Dict, Tuple, AsyncIterator, Any, Optional
 from enum import Enum
 
-import grpc
 import torch
 
 from hivemind.averaging.partition import TensorPartContainer, TensorPartReducer, AllreduceException
-from hivemind.utils import Endpoint, get_logger, ChannelCache
+from hivemind.p2p import P2P, P2PContext, PeerID as Endpoint, ServicerBase, StubBase
+from hivemind.utils import get_logger
 from hivemind.utils.asyncio import anext, achain, aiter, aenumerate, amap_in_executor
 from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
-from hivemind.proto import averaging_pb2_grpc, averaging_pb2
+from hivemind.proto import averaging_pb2
 
 # flavour types
 GroupID = bytes
@@ -22,7 +22,7 @@ class AveragingMode(Enum):
     AUX = 2
 
 
-class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
+class AllReduceRunner(ServicerBase):
     """
     An internal class that runs butterfly AllReduce in a predefined group of averagers
 
@@ -43,9 +43,9 @@ class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
     def __init__(
         self,
         *,
+        p2p: P2P,
         group_id: GroupID,
         tensors: Sequence[torch.Tensor],
-        endpoint: Endpoint,
         ordered_group_endpoints: Sequence[Endpoint],
         peer_fractions: Tuple[float, ...],
         weights: Optional[Sequence[float]] = None,
@@ -53,7 +53,10 @@ class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
         gathered: Optional[Dict[Endpoint, Any]] = None,
         **kwargs,
     ):
-        assert endpoint in ordered_group_endpoints, "endpoint is not a part of the group"
+        self._p2p = p2p
+        self.endpoint = p2p.id
+        assert self.endpoint in ordered_group_endpoints, "endpoint is not a part of the group"
+
         modes = modes or tuple(AveragingMode.CLIENT if frac == 0 else AveragingMode.NODE for frac in peer_fractions)
         weights = weights or tuple(int(mode != AveragingMode.AUX) for mode in modes)
         assert len(weights) == len(modes) == len(ordered_group_endpoints), "lists have inconsistent length"
@@ -62,7 +65,7 @@ class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
             assert mode != AveragingMode.CLIENT or frac == 0, "client-mode peer should have zero all-reduce fraction"
             assert mode != AveragingMode.AUX or weight == 0, "auxiliary peer should have zero averaging weight"
 
-        self.group_id, self.endpoint, self.ordered_group_endpoints = group_id, endpoint, ordered_group_endpoints
+        self.group_id, self.ordered_group_endpoints = group_id, ordered_group_endpoints
         self.modes, self.peer_fractions, self.gathered = modes, peer_fractions, gathered
 
         self._future = asyncio.Future()
@@ -95,8 +98,10 @@ class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
     def group_size(self):
         return len(self.ordered_group_endpoints)
 
-    def _get_peer_stub(self, peer: Endpoint) -> averaging_pb2_grpc.DecentralizedAveragingStub:
-        return ChannelCache.get_stub(peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
+    def _get_stub(self, peer: Endpoint) -> StubBase:
+        from hivemind.averaging.averager import DecentralizedAverager
+
+        return DecentralizedAverager.get_stub(self._p2p, peer)
 
     async def run(self) -> AsyncIterator[torch.Tensor]:
         """Run all-reduce, return differences between averaged and original tensors as they are computed"""
@@ -136,46 +141,35 @@ class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
 
         else:
             loop = asyncio.get_event_loop()
-            stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
-            write_task = asyncio.create_task(self._write_to_peer(stream, peer_index))
-
-            try:
-                code = None
-                async for part_index, msg in aenumerate(stream):
-                    if code is None:
-                        code = msg.code
-                    averaged_part_delta = await loop.run_in_executor(None, deserialize_torch_tensor, msg.tensor_part)
-                    self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part_delta)
-                await write_task
-
-                if code != averaging_pb2.AVERAGED_PART:
-                    raise AllreduceException(
-                        f"peer {peer_endpoint} returned {averaging_pb2.MessageCode.Name(code)} "
-                        f"instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)}"
-                        f", allreduce failed"
-                    )
-            finally:
-                if not write_task.done():
-                    write_task.cancel()
-
-    async def _write_to_peer(self, stream: grpc.aio.StreamStreamCall, peer_index: int):
+            code = None
+            stream = self._get_stub(peer_endpoint).rpc_aggregate_part(self._generate_input_for_peer(peer_index))
+            async for part_index, msg in aenumerate(stream):
+                if code is None:
+                    code = msg.code
+                averaged_part_delta = await loop.run_in_executor(None, deserialize_torch_tensor, msg.tensor_part)
+                self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part_delta)
+
+            if code != averaging_pb2.AVERAGED_PART:
+                raise AllreduceException(
+                    f"peer {peer_endpoint} returned {averaging_pb2.MessageCode.Name(code)} "
+                    f"instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)}"
+                    f", allreduce failed"
+                )
+
+    async def _generate_input_for_peer(self, peer_index: int) -> AsyncIterator[averaging_pb2.AveragingData]:
         parts_aiter = self.tensor_part_container.iterate_input_parts_for(peer_index)
         first_part = await anext(parts_aiter)
-        await stream.write(
-            averaging_pb2.AveragingData(
-                code=averaging_pb2.PART_FOR_AVERAGING,
-                group_id=self.group_id,
-                endpoint=self.endpoint,
-                tensor_part=first_part,
-            )
+        yield averaging_pb2.AveragingData(
+            code=averaging_pb2.PART_FOR_AVERAGING,
+            group_id=self.group_id,
+            endpoint=self.endpoint.to_base58(),
+            tensor_part=first_part,
         )
         async for part in parts_aiter:
-            await stream.write(averaging_pb2.AveragingData(tensor_part=part))
-
-        await stream.done_writing()
+            yield averaging_pb2.AveragingData(tensor_part=part)
 
     async def rpc_aggregate_part(
-        self, stream: AsyncIterator[averaging_pb2.AveragingData], context: grpc.ServicerContext
+        self, stream: AsyncIterator[averaging_pb2.AveragingData], _: P2PContext
     ) -> AsyncIterator[averaging_pb2.AveragingData]:
         """a peer sends us a part of his tensor; we should average it with other peers and return the difference"""
         request: averaging_pb2.AveragingData = await anext(stream)
@@ -186,7 +180,7 @@ class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
 
         elif request.code == averaging_pb2.PART_FOR_AVERAGING:
             try:
-                sender_index = self.sender_endpoints.index(request.endpoint)
+                sender_index = self.sender_endpoints.index(Endpoint.from_base58(request.endpoint))
                 async for msg in self._accumulate_parts_streaming(achain(aiter(request), stream), sender_index):
                     yield msg
 
@@ -224,9 +218,9 @@ class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
             yield averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=serialized_delta)
 
     async def _send_error_to_peer(self, peer_endpoint: Endpoint, code: averaging_pb2.MessageCode):
-        stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
-        await stream.write(averaging_pb2.AveragingData(group_id=self.group_id, endpoint=self.endpoint, code=code))
-        await stream.done_writing()
+        error = averaging_pb2.AveragingData(group_id=self.group_id, endpoint=self.endpoint.to_base58(), code=code)
+        async for _ in self._get_stub(peer_endpoint).rpc_aggregate_part(aiter(error)):
+            pass
 
     def finalize(self, *, cancel: bool = False, exception: Optional[BaseException] = None):
         """finish or terminate AllReduceRunner, propagate any errors / cancellations to peers."""

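The core change in this file: instead of opening a grpc bidirectional stream and calling stream.write() / done_writing(), the caller builds an async generator of request messages and hands it to the libp2p stub, which returns an async iterator of responses. A minimal sketch of that call shape, assuming a stub that exposes rpc_aggregate_part as above (the helper names make_requests and call_peer are illustrative, not part of this commit):

from typing import AsyncIterator

from hivemind.proto import averaging_pb2


async def make_requests(parts) -> AsyncIterator[averaging_pb2.AveragingData]:
    # metadata travels only on the first message, mirroring _generate_input_for_peer
    first = True
    async for part in parts:
        if first:
            yield averaging_pb2.AveragingData(code=averaging_pb2.PART_FOR_AVERAGING, tensor_part=part)
            first = False
        else:
            yield averaging_pb2.AveragingData(tensor_part=part)


async def call_peer(stub, parts) -> AsyncIterator[averaging_pb2.AveragingData]:
    # the stub consumes the request generator and streams responses back
    async for reply in stub.rpc_aggregate_part(make_requests(parts)):
        if reply.code != averaging_pb2.AVERAGED_PART:
            raise RuntimeError(f"peer returned {averaging_pb2.MessageCode.Name(reply.code)}")
        yield reply
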
+ 6 - 6
hivemind/averaging/averager.py

@@ -188,7 +188,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
     @property
     def endpoint(self) -> Endpoint:
-        return self.p2p.id
+        return self._p2p.id
 
     def run(self):
         """
@@ -207,14 +207,14 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
 
             async def _run():
-                self.p2p = await self.dht.replicate_p2p()
+                self._p2p = await self.dht.replicate_p2p()
                 if not self.client_mode:
-                    await self.add_p2p_handlers(self.p2p)
+                    await self.add_p2p_handlers(self._p2p)
                 else:
                     logger.debug(f"The averager is running in client mode.")
 
                 self._matchmaking = Matchmaking(
-                    self.endpoint, self.schema_hash, self.dht, **self.matchmaking_kwargs, client_mode=self.client_mode
+                    self._p2p, self.schema_hash, self.dht, **self.matchmaking_kwargs, client_mode=self.client_mode
                 )
                 if not self.client_mode:
                     asyncio.create_task(self._declare_for_download_periodically())
@@ -379,9 +379,9 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
             async with self.get_tensors_async() as local_tensors:
                 allreduce = AllReduceRunner(
+                    p2p=self._p2p,
                     group_id=group_info.group_id,
                     tensors=local_tensors,
-                    endpoint=self.endpoint,
                     ordered_group_endpoints=group_info.endpoints,
                     peer_fractions=peer_fractions,
                     weights=weights,
@@ -551,7 +551,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                 if peer != self.endpoint:
                     logger.info(f"Downloading parameters from peer {peer}")
                     try:
-                        stub = self.get_stub(self.p2p, peer)
+                        stub = self.get_stub(self._p2p, peer)
                         stream = stub.rpc_download_state(averaging_pb2.DownloadRequest())
                         current_tensor_parts, tensors = [], []
                         async for message in stream:

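The wiring above amounts to: the averager no longer runs its own grpc server; it borrows the libp2p instance already spawned by its DHT, registers its rpc_* handlers on it, and uses the resulting PeerID as its endpoint. A condensed sketch of that startup path, assuming a ServicerBase subclass (the function name _start_networking is illustrative):

async def _start_networking(servicer, dht, client_mode: bool):
    # reuse the libp2p daemon created by the DHT process instead of a dedicated grpc server
    p2p = await dht.replicate_p2p()
    if not client_mode:
        # expose every rpc_* coroutine of the servicer as a libp2p protobuf handler
        await servicer.add_p2p_handlers(p2p)
    # p2p.id (a PeerID) now doubles as the peer's endpoint
    return p2p
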
+ 6 - 5
hivemind/averaging/key_manager.py

@@ -5,9 +5,10 @@ from typing import Optional, List, Tuple
 
 import numpy as np
 
-from hivemind.dht import DHT
 from hivemind.averaging.group_info import GroupInfo
-from hivemind.utils import get_logger, Endpoint, DHTExpiration, get_dht_time, ValueWithExpiration
+from hivemind.dht import DHT
+from hivemind.p2p import PeerID as Endpoint
+from hivemind.utils import get_logger, DHTExpiration, get_dht_time, ValueWithExpiration
 
 GroupKey = str
 GROUP_PATTERN = re.compile("^(([^.])+)[.]0b[01]*$")  # e.g. bert_exp4_averaging.0b01001101
@@ -72,7 +73,7 @@ class GroupKeyManager:
         expiration_time = expiration_time if looking_for_group else float(np.nextafter(expiration_time, float("inf")))
         return await self.dht.store(
             key=group_key,
-            subkey=endpoint,
+            subkey=endpoint.to_base58(),
             value=looking_for_group,
             expiration_time=expiration_time,
             return_future=True,
@@ -93,11 +94,11 @@ class GroupKeyManager:
             logger.debug(f"Allreduce group not found: {group_key}, creating new group.")
             return []
         averagers = [
-            (key, entry.expiration_time)
+            (Endpoint.from_base58(key), entry.expiration_time)
             for key, entry in result.value.items()
             if key != self.RESERVED_KEY_FOR_NBITS and (not only_active or entry.value is True)
         ]
-        num_active_averagers = len([key for key, entry in result.value.items() if entry.value is True])
+        num_active_averagers = sum(1 for entry in result.value.values() if entry.value is True)
 
         suggested_nbits = self.get_suggested_nbits(result)
         if (

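A PeerID is not a string, so it cannot be written to the DHT as a subkey directly; the code above serializes it with to_base58() on store and restores it with PeerID.from_base58() when reading entries back. A short round-trip sketch (the raw bytes are illustrative; in practice the PeerID comes from p2p.id):

from hivemind.p2p import PeerID

peer = PeerID(b"some peer id bytes")
subkey = peer.to_base58()                  # stored as the DHT subkey
assert PeerID.from_base58(subkey) == peer  # decoded back when reading the entry
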
+ 31 - 28
hivemind/averaging/matchmaking.py

@@ -3,26 +3,25 @@
 from __future__ import annotations
 
 import contextlib
+from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID
 import random
 from math import isfinite
 from typing import Optional, AsyncIterator, Set, Tuple, Dict
 import concurrent.futures
 import asyncio
 
-import grpc
-import grpc._cython.cygrpc
-
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.key_manager import GroupKeyManager, GroupKey
 from hivemind.dht import DHT, DHTID, DHTExpiration
-from hivemind.utils import get_logger, Endpoint, timed_storage, TimedStorage, get_dht_time
-from hivemind.proto import averaging_pb2, averaging_pb2_grpc
-from hivemind.utils.grpc import ChannelCache
+from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID as Endpoint
+from hivemind.utils import get_logger, timed_storage, TimedStorage, get_dht_time
+from hivemind.utils.asyncio import anext
+from hivemind.proto import averaging_pb2
 
 logger = get_logger(__name__)
 
 
-class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
+class Matchmaking:
     f"""
     An internal class that is used to form groups of averages for running allreduce
     See DecentralizedAverager docstring for the detailed description of all parameters
@@ -37,7 +36,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
 
     def __init__(
         self,
-        endpoint: Endpoint,
+        p2p: P2P,
         schema_hash: bytes,
         dht: DHT,
         *,
@@ -57,8 +56,10 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
             )
 
         super().__init__()
-        self.endpoint, self.schema_hash = endpoint, schema_hash
-        self.group_key_manager = GroupKeyManager(dht, endpoint, prefix, initial_group_bits, target_group_size)
+        self._p2p = p2p
+        self.endpoint = p2p.id
+        self.schema_hash = schema_hash
+        self.group_key_manager = GroupKeyManager(dht, self.endpoint, prefix, initial_group_bits, target_group_size)
         self.target_group_size, self.min_group_size = target_group_size, min_group_size
         self.averaging_expiration, self.request_timeout = averaging_expiration, request_timeout
         self.client_mode = client_mode
@@ -71,7 +72,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
 
         self.current_leader: Optional[Endpoint] = None  # iff i am a follower, this is a link to my current leader
         self.current_followers: Dict[Endpoint, averaging_pb2.JoinRequest] = {}  # my current followers excluding myself
-        self.potential_leaders = PotentialLeaders(endpoint, averaging_expiration, target_group_size)
+        self.potential_leaders = PotentialLeaders(self.endpoint, averaging_expiration, target_group_size)
         self.data_for_gather: Optional[bytes] = None
 
     @property
@@ -169,20 +170,23 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
           The originally specified leader can disband group and redirect us to a different leader
         """
         assert self.is_looking_for_group and self.current_leader is None
-        call: Optional[grpc.aio.UnaryStreamCall] = None
+        stream: Optional[AsyncIterator[averaging_pb2.MessageFromLeader]] = None
         try:
             async with self.lock_request_join_group:
-                leader_stub = ChannelCache.get_stub(leader, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
-                call = leader_stub.rpc_join_group(
+                from hivemind.averaging.averager import DecentralizedAverager
+
+                leader_stub = DecentralizedAverager.get_stub(self._p2p, leader)
+
+                stream = leader_stub.rpc_join_group(
                     averaging_pb2.JoinRequest(
-                        endpoint=self.endpoint,
+                        endpoint=self.endpoint.to_base58(),
                         schema_hash=self.schema_hash,
                         expiration=expiration_time,
                         client_mode=self.client_mode,
                         gather=self.data_for_gather,
                     )
-                )
-                message = await asyncio.wait_for(call.read(), timeout=self.request_timeout)
+                ).__aiter__()
+                message = await asyncio.wait_for(anext(stream), timeout=self.request_timeout)
 
                 if message.code == averaging_pb2.ACCEPTED:
                     logger.debug(f"{self.endpoint} - joining the group of {leader}; waiting for peers")
@@ -198,7 +202,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
 
             async with self.potential_leaders.pause_search():
                 time_to_expiration = max(expiration_time - get_dht_time(), 0.0)
-                message = await asyncio.wait_for(call.read(), time_to_expiration + self.request_timeout)
+                message = await asyncio.wait_for(anext(stream), time_to_expiration + self.request_timeout)
 
                 if message.code == averaging_pb2.BEGIN_ALLREDUCE:
                     async with self.lock_request_join_group:
@@ -208,7 +212,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
                 if message.suggested_leader and message.suggested_leader != self.endpoint:
                     logger.debug(f"{self} - leader disbanded group and redirected us to {message.suggested_leader}")
                     self.current_leader = None
-                    call.cancel()
+                    await stream.aclose()
                     return await self.request_join_group(message.suggested_leader, expiration_time)
                 else:
                     logger.debug(f"{self} - leader disbanded group")
@@ -218,23 +222,22 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
             return None
         except asyncio.TimeoutError:
             logger.debug(f"{self} - potential leader {leader} did not respond within {self.request_timeout}")
-            if call is not None:
-                call.cancel()
             return None
-        except (grpc.RpcError, grpc.aio.AioRpcError, grpc._cython.cygrpc.InternalError, StopAsyncIteration) as e:
+        except (P2PHandlerError, StopAsyncIteration) as e:
             logger.error(f"{self} - failed to request potential leader {leader}: {e}")
             return None
 
         finally:
             self.was_accepted_to_group.clear()
             self.current_leader = None
-            if call is not None:
-                await call.code()
+            if stream is not None:
+                await stream.aclose()
 
     async def rpc_join_group(
-        self, request: averaging_pb2.JoinRequest, context: grpc.ServicerContext
+        self, request: averaging_pb2.JoinRequest, _: P2PContext
     ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
         """accept or reject a join request from another averager; if accepted, run him through allreduce steps"""
+        request_endpoint = PeerID.from_base58(request.endpoint)
         try:
             async with self.lock_request_join_group:
                 reason_to_reject = self._check_reasons_to_reject(request)
@@ -242,7 +245,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
                     yield reason_to_reject
                     return
 
-                self.current_followers[request.endpoint] = request
+                self.current_followers[request_endpoint] = request
                 yield averaging_pb2.MessageFromLeader(code=averaging_pb2.ACCEPTED)
 
                 if len(self.current_followers) + 1 >= self.target_group_size and not self.assembled_group.done():
@@ -270,7 +273,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
                 self.was_accepted_to_group.is_set()
                 or not self.assembled_group.done()
                 or self.assembled_group.cancelled()
-                or request.endpoint not in self.assembled_group.result()
+                or request_endpoint not in self.assembled_group.result()
             ):
                 if self.current_leader is not None:
                     # outcome 3: found by a leader with higher priority, send our followers to him
@@ -296,7 +299,7 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
             yield averaging_pb2.MessageFromLeader(code=averaging_pb2.INTERNAL_ERROR)
 
         finally:  # note: this code is guaranteed to run even if the coroutine is destroyed prematurely
-            self.current_followers.pop(request.endpoint, None)
+            self.current_followers.pop(request_endpoint, None)
             self.follower_was_discarded.set()
 
     def _check_reasons_to_reject(

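The same substitution drives the leader handshake: the unary-stream grpc call (call.read(), call.cancel(), call.code()) becomes an async iterator returned by the stub, read with anext() under a timeout and closed with aclose() on every exit path. A minimal sketch of that shape, assuming a stub with rpc_join_group as above (the function name ask_leader and the timeout handling are illustrative):

import asyncio

from hivemind.proto import averaging_pb2
from hivemind.utils.asyncio import anext


async def ask_leader(stub, request: averaging_pb2.JoinRequest, timeout: float):
    # the stub returns an async iterator of MessageFromLeader messages
    stream = stub.rpc_join_group(request).__aiter__()
    try:
        return await asyncio.wait_for(anext(stream), timeout)
    finally:
        # release the underlying libp2p stream even if the timeout fires
        await stream.aclose()
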
+ 1 - 1
hivemind/averaging/partition.py

@@ -32,7 +32,7 @@ class TensorPartContainer:
         self,
         tensors: Sequence[torch.Tensor],
         peer_fractions: Sequence[float],
-        compression_type: Union[type(CompressionType), Sequence[type(CompressionType)]] = CompressionType.NONE,
+        compression_type: Union['CompressionType', Sequence['CompressionType']] = CompressionType.NONE,
         part_size_bytes: int = 2 ** 20,
         prefetch: int = 1,
     ):

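The new annotation makes explicit that compression_type accepts either a single CompressionType for all tensors or a per-tensor sequence. A hedged usage sketch (tensor shapes and fractions are illustrative; the per-tensor sequence is assumed to have one entry per tensor):

import torch

from hivemind.averaging.partition import TensorPartContainer
from hivemind.proto.runtime_pb2 import CompressionType

tensors = [torch.randn(3, 128), torch.rand(32)]
container = TensorPartContainer(
    tensors,
    peer_fractions=(0.5, 0.5),
    # one value for all tensors, or one CompressionType per tensor
    compression_type=(CompressionType.FLOAT16, CompressionType.NONE),
)
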
+ 4 - 2
hivemind/p2p/servicer.py

@@ -75,8 +75,10 @@ class ServicerBase:
 
                 spec = inspect.getfullargspec(method)
                 if len(spec.args) < 3:
-                    raise ValueError(f"{handle_name} is expected to at least three positional arguments "
-                                     f"(self: TServicer, request: TInputProtobuf, context: hivemind.p2p.P2PContext)")
+                    raise ValueError(
+                        f"{handle_name} is expected to have at least three positional arguments "
+                        f"(self: TServicer, request: TInputProtobuf, context: hivemind.p2p.P2PContext)"
+                    )
                 request_arg = spec.args[1]
                 hints = get_type_hints(method)
                 try:

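The reformatted check enforces that every rpc_* handler declares at least (self, request, context), since ServicerBase reads the request type from the handler's annotations. A minimal conforming handler, assuming a ServicerBase subclass (the class name ExampleServicer is illustrative):

from typing import AsyncIterator

from hivemind.p2p import P2PContext, ServicerBase
from hivemind.proto import averaging_pb2


class ExampleServicer(ServicerBase):
    async def rpc_join_group(
        self, request: averaging_pb2.JoinRequest, context: P2PContext
    ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
        # the three positional arguments above are exactly what the check requires
        yield averaging_pb2.MessageFromLeader(code=averaging_pb2.ACCEPTED)
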
+ 0 - 109
tests/log

@@ -1,109 +0,0 @@
-[2021/07/10 11:53:40.062][DEBUG][utils.grpc.ChannelCache:49] Eviction period = 600.0s, max channels = 4096
-[2021/07/10 11:53:40.247][DEBUG][asyncio.__init__:59] Using selector: EpollSelector
-[2021/07/10 11:53:40.654][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe5d8f481f0> opens connection to /unix/tmp/hivemind-p2pd-B4tnhAq6U1U.sock
-[2021/07/10 11:53:40.672][DEBUG][p2p.p2p_daemon._ping_daemon:193] Launched p2pd with id = QmTqNtzhfJrChohykAkiZQ8xBJW8995EfQynqAf37BNCak, host multiaddrs = (<Multiaddr /ip4/127.0.0.1/tcp/38633>,)
-[2021/07/10 11:53:40.673][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe5d8f481f0> opens connection to /unix/tmp/hivemind-p2pd-B4tnhAq6U1U.sock
-[2021/07/10 11:53:40.674][DEBUG][p2p.p2p_daemon_bindings.control.listen:108] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.ControlClient object at 0x7fe5d8f48250> starts listening to /unix/tmp/hivemind-p2pclient-B4tnhAq6U1U.sock
-[2021/07/10 11:53:40.674][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe5d8f481f0> opens connection to /unix/tmp/hivemind-p2pd-B4tnhAq6U1U.sock
-[2021/07/10 11:53:40.675][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe5d8f481f0> opens connection to /unix/tmp/hivemind-p2pd-B4tnhAq6U1U.sock
-[2021/07/10 11:53:40.676][INFO][root.run_protocol_listener:40] Started peer id=DHTID(0x3db9c894718f96bdd909afcef13a545941e7e1e4) visible_maddrs=[<Multiaddr /ip4/127.0.0.1/tcp/38633/p2p/QmTqNtzhfJrChohykAkiZQ8xBJW8995EfQynqAf37BNCak>]
-[2021/07/10 11:53:40.683][DEBUG][asyncio.__init__:59] Using selector: EpollSelector
-[2021/07/10 11:53:41.101][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe6623e2d30> opens connection to /unix/tmp/hivemind-p2pd-vXE2xWbi9bg.sock
-[2021/07/10 11:53:41.103][DEBUG][p2p.p2p_daemon._ping_daemon:193] Launched p2pd with id = Qma6rFAHFdwecmQKqWtU59LRqGAeXDam1rLc1u5ueNbwXm, host multiaddrs = (<Multiaddr /ip4/127.0.0.1/tcp/41695>,)
-[2021/07/10 11:53:41.105][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe6623e2d30> opens connection to /unix/tmp/hivemind-p2pd-vXE2xWbi9bg.sock
-[2021/07/10 11:53:41.105][DEBUG][p2p.p2p_daemon_bindings.control.listen:108] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.ControlClient object at 0x7fe6623e2dc0> starts listening to /unix/tmp/hivemind-p2pclient-vXE2xWbi9bg.sock
-[2021/07/10 11:53:41.106][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe6623e2d30> opens connection to /unix/tmp/hivemind-p2pd-vXE2xWbi9bg.sock
-[2021/07/10 11:53:41.106][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe6623e2d30> opens connection to /unix/tmp/hivemind-p2pd-vXE2xWbi9bg.sock
-[2021/07/10 11:53:41.107][INFO][root.run_protocol_listener:40] Started peer id=DHTID(0xa7087d197c0f2758c6bff7256e3cb78c5152c137) visible_maddrs=[<Multiaddr /ip4/127.0.0.1/tcp/41695/p2p/Qma6rFAHFdwecmQKqWtU59LRqGAeXDam1rLc1u5ueNbwXm>]
-[2021/07/10 11:53:41.108][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe6623e2d30> opens connection to /unix/tmp/hivemind-p2pd-vXE2xWbi9bg.sock
-[2021/07/10 11:53:41.109][DEBUG][p2p.p2p_daemon_bindings.control._handler:83] New incoming stream: <StreamInfo peer_id=Qma6rFAHFdwecmQKqWtU59LRqGAeXDam1rLc1u5ueNbwXm addr=/ip4/127.0.0.1/tcp/41695 proto=DHTProtocol.rpc_ping>
-[2021/07/10 11:53:41.111][DEBUG][asyncio.__init__:59] Using selector: EpollSelector
-[2021/07/10 11:53:41.528][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe6623e5d90> opens connection to /unix/tmp/hivemind-p2pd-VRUjGSy2BP8.sock
-[2021/07/10 11:53:41.530][DEBUG][p2p.p2p_daemon._ping_daemon:193] Launched p2pd with id = QmYqP7ECXjtxaTi54eznbk2tgq9D43xpyL8hRYsksRMQ8Q, host multiaddrs = (<Multiaddr /ip4/127.0.0.1/tcp/34419>,)
-[2021/07/10 11:53:41.531][INFO][root.test_dht_protocol:81] Self id=DHTID(0xc147c315f7abd7fff33ccb0bd7b35cb0215ca616)
-[2021/07/10 11:53:41.531][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe6623e5d90> opens connection to /unix/tmp/hivemind-p2pd-VRUjGSy2BP8.sock
-[2021/07/10 11:53:41.532][DEBUG][p2p.p2p_daemon_bindings.control._handler:83] New incoming stream: <StreamInfo peer_id=QmYqP7ECXjtxaTi54eznbk2tgq9D43xpyL8hRYsksRMQ8Q addr=/ip4/127.0.0.1/tcp/34419 proto=DHTProtocol.rpc_ping>
-[2021/07/10 11:53:41.534][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe6623e5d90> opens connection to /unix/tmp/hivemind-p2pd-VRUjGSy2BP8.sock
-[2021/07/10 11:53:41.535][DEBUG][p2p.p2p_daemon_bindings.control._handler:83] New incoming stream: <StreamInfo peer_id=QmYqP7ECXjtxaTi54eznbk2tgq9D43xpyL8hRYsksRMQ8Q addr=/ip4/127.0.0.1/tcp/34419 proto=DHTProtocol.rpc_store>
-[2021/07/10 11:53:41.536][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe6623e5d90> opens connection to /unix/tmp/hivemind-p2pd-VRUjGSy2BP8.sock
-[2021/07/10 11:53:41.536][DEBUG][p2p.p2p_daemon_bindings.control._handler:83] New incoming stream: <StreamInfo peer_id=QmYqP7ECXjtxaTi54eznbk2tgq9D43xpyL8hRYsksRMQ8Q addr=/ip4/127.0.0.1/tcp/34419 proto=DHTProtocol.rpc_find>
-[2021/07/10 11:53:41.538][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe6623e5d90> opens connection to /unix/tmp/hivemind-p2pd-VRUjGSy2BP8.sock
-[2021/07/10 11:53:41.539][DEBUG][p2p.p2p_daemon_bindings.control._handler:83] New incoming stream: <StreamInfo peer_id=QmYqP7ECXjtxaTi54eznbk2tgq9D43xpyL8hRYsksRMQ8Q addr=/ip4/127.0.0.1/tcp/34419 proto=DHTProtocol.rpc_find>
-[2021/07/10 11:53:41.541][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe6623e5d90> opens connection to /unix/tmp/hivemind-p2pd-VRUjGSy2BP8.sock
-[2021/07/10 11:53:41.542][ERROR][dht.protocol.call_find:278] DHTProtocol failed to find at fakeid
-Traceback (most recent call last):
-  File "/home/borzunov/hivemind/hivemind/dht/protocol.py", line 245, in call_find
-    response = await self.get_stub(peer).rpc_find(find_request, timeout=self.wait_timeout)
-  File "/home/borzunov/hivemind/hivemind/utils/auth.py", line 199, in wrapped_rpc
-    response = await method(request, *args, **kwargs)
-  File "/home/borzunov/hivemind/hivemind/p2p/servicer.py", line 72, in caller
-    return await asyncio.wait_for(
-  File "/home/borzunov/anaconda3/lib/python3.8/asyncio/tasks.py", line 483, in wait_for
-    return fut.result()
-  File "/home/borzunov/hivemind/hivemind/p2p/p2p_daemon.py", line 374, in call_protobuf_handler
-    stream_info, reader, writer = await self._client.stream_open(peer_id, (handler_name,))
-  File "/home/borzunov/hivemind/hivemind/p2p/p2p_daemon_bindings/p2pclient.py", line 76, in stream_open
-    return await self.control.stream_open(peer_id=peer_id, protocols=protocols)
-  File "/home/borzunov/hivemind/hivemind/p2p/p2p_daemon_bindings/control.py", line 185, in stream_open
-    raise_if_failed(resp)
-  File "/home/borzunov/hivemind/hivemind/p2p/p2p_daemon_bindings/utils.py", line 60, in raise_if_failed
-    raise ControlFailure(f"Connect failed. msg={response.error.msg}")
-hivemind.p2p.p2p_daemon_bindings.utils.ControlFailure: Connect failed. msg=length greater than remaining number of bytes in buffer
-[2021/07/10 11:53:41.544][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe6623e5d90> opens connection to /unix/tmp/hivemind-p2pd-VRUjGSy2BP8.sock
-[2021/07/10 11:53:41.545][DEBUG][p2p.p2p_daemon_bindings.control._handler:83] New incoming stream: <StreamInfo peer_id=QmYqP7ECXjtxaTi54eznbk2tgq9D43xpyL8hRYsksRMQ8Q addr=/ip4/127.0.0.1/tcp/34419 proto=DHTProtocol.rpc_store>
-[2021/07/10 11:53:41.546][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe6623e5d90> opens connection to /unix/tmp/hivemind-p2pd-VRUjGSy2BP8.sock
-[2021/07/10 11:53:41.547][DEBUG][p2p.p2p_daemon_bindings.control._handler:83] New incoming stream: <StreamInfo peer_id=QmYqP7ECXjtxaTi54eznbk2tgq9D43xpyL8hRYsksRMQ8Q addr=/ip4/127.0.0.1/tcp/34419 proto=DHTProtocol.rpc_store>
-[2021/07/10 11:53:41.547][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe6623e5d90> opens connection to /unix/tmp/hivemind-p2pd-VRUjGSy2BP8.sock
-[2021/07/10 11:53:41.548][DEBUG][p2p.p2p_daemon_bindings.control._handler:83] New incoming stream: <StreamInfo peer_id=QmYqP7ECXjtxaTi54eznbk2tgq9D43xpyL8hRYsksRMQ8Q addr=/ip4/127.0.0.1/tcp/34419 proto=DHTProtocol.rpc_find>
-[2021/07/10 11:53:41.958][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe66237f0a0> opens connection to /unix/tmp/hivemind-p2pd-Pswf-De9YNA.sock
-[2021/07/10 11:53:41.960][DEBUG][p2p.p2p_daemon._ping_daemon:193] Launched p2pd with id = QmfJuUCnS8PFAazbfgiExgpYGDCGnq5pBCQBSFZjVB71mE, host multiaddrs = (<Multiaddr /ip4/127.0.0.1/tcp/35705>,)
-[2021/07/10 11:53:41.961][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe66237f0a0> opens connection to /unix/tmp/hivemind-p2pd-Pswf-De9YNA.sock
-[2021/07/10 11:53:41.962][DEBUG][p2p.p2p_daemon_bindings.control.listen:108] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.ControlClient object at 0x7fe66237f220> starts listening to /unix/tmp/hivemind-p2pclient-Pswf-De9YNA.sock
-[2021/07/10 11:53:41.963][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe66237f0a0> opens connection to /unix/tmp/hivemind-p2pd-Pswf-De9YNA.sock
-[2021/07/10 11:53:41.964][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe66237f0a0> opens connection to /unix/tmp/hivemind-p2pd-Pswf-De9YNA.sock
-[2021/07/10 11:53:41.965][INFO][root.test_dht_protocol:81] Self id=DHTID(0x493ce245bb79ab0900ad291134df41125dfcf217)
-[2021/07/10 11:53:41.966][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe66237f0a0> opens connection to /unix/tmp/hivemind-p2pd-Pswf-De9YNA.sock
-[2021/07/10 11:53:41.967][DEBUG][p2p.p2p_daemon_bindings.control._handler:83] New incoming stream: <StreamInfo peer_id=QmfJuUCnS8PFAazbfgiExgpYGDCGnq5pBCQBSFZjVB71mE addr=/ip4/127.0.0.1/tcp/35705 proto=DHTProtocol.rpc_ping>
-[2021/07/10 11:53:41.968][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe66237f0a0> opens connection to /unix/tmp/hivemind-p2pd-Pswf-De9YNA.sock
-[2021/07/10 11:53:41.969][DEBUG][p2p.p2p_daemon_bindings.control._handler:83] New incoming stream: <StreamInfo peer_id=QmfJuUCnS8PFAazbfgiExgpYGDCGnq5pBCQBSFZjVB71mE addr=/ip4/127.0.0.1/tcp/35705 proto=DHTProtocol.rpc_store>
-[2021/07/10 11:53:41.970][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe66237f0a0> opens connection to /unix/tmp/hivemind-p2pd-Pswf-De9YNA.sock
-[2021/07/10 11:53:41.971][DEBUG][p2p.p2p_daemon_bindings.control._handler:83] New incoming stream: <StreamInfo peer_id=QmfJuUCnS8PFAazbfgiExgpYGDCGnq5pBCQBSFZjVB71mE addr=/ip4/127.0.0.1/tcp/35705 proto=DHTProtocol.rpc_find>
-[2021/07/10 11:53:41.973][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe66237f0a0> opens connection to /unix/tmp/hivemind-p2pd-Pswf-De9YNA.sock
-[2021/07/10 11:53:41.974][DEBUG][p2p.p2p_daemon_bindings.control._handler:83] New incoming stream: <StreamInfo peer_id=QmfJuUCnS8PFAazbfgiExgpYGDCGnq5pBCQBSFZjVB71mE addr=/ip4/127.0.0.1/tcp/35705 proto=DHTProtocol.rpc_find>
-[2021/07/10 11:53:41.975][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe66237f0a0> opens connection to /unix/tmp/hivemind-p2pd-Pswf-De9YNA.sock
-[2021/07/10 11:53:41.975][ERROR][dht.protocol.call_find:278] DHTProtocol failed to find at fakeid
-Traceback (most recent call last):
-  File "/home/borzunov/hivemind/hivemind/dht/protocol.py", line 245, in call_find
-    response = await self.get_stub(peer).rpc_find(find_request, timeout=self.wait_timeout)
-  File "/home/borzunov/hivemind/hivemind/utils/auth.py", line 199, in wrapped_rpc
-    response = await method(request, *args, **kwargs)
-  File "/home/borzunov/hivemind/hivemind/p2p/servicer.py", line 72, in caller
-    return await asyncio.wait_for(
-  File "/home/borzunov/anaconda3/lib/python3.8/asyncio/tasks.py", line 483, in wait_for
-    return fut.result()
-  File "/home/borzunov/hivemind/hivemind/p2p/p2p_daemon.py", line 374, in call_protobuf_handler
-    stream_info, reader, writer = await self._client.stream_open(peer_id, (handler_name,))
-  File "/home/borzunov/hivemind/hivemind/p2p/p2p_daemon_bindings/p2pclient.py", line 76, in stream_open
-    return await self.control.stream_open(peer_id=peer_id, protocols=protocols)
-  File "/home/borzunov/hivemind/hivemind/p2p/p2p_daemon_bindings/control.py", line 185, in stream_open
-    raise_if_failed(resp)
-  File "/home/borzunov/hivemind/hivemind/p2p/p2p_daemon_bindings/utils.py", line 60, in raise_if_failed
-    raise ControlFailure(f"Connect failed. msg={response.error.msg}")
-hivemind.p2p.p2p_daemon_bindings.utils.ControlFailure: Connect failed. msg=length greater than remaining number of bytes in buffer
-[2021/07/10 11:53:41.976][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe66237f0a0> opens connection to /unix/tmp/hivemind-p2pd-Pswf-De9YNA.sock
-[2021/07/10 11:53:41.977][DEBUG][p2p.p2p_daemon_bindings.control._handler:83] New incoming stream: <StreamInfo peer_id=QmfJuUCnS8PFAazbfgiExgpYGDCGnq5pBCQBSFZjVB71mE addr=/ip4/127.0.0.1/tcp/35705 proto=DHTProtocol.rpc_store>
-[2021/07/10 11:53:41.978][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe66237f0a0> opens connection to /unix/tmp/hivemind-p2pd-Pswf-De9YNA.sock
-[2021/07/10 11:53:41.978][DEBUG][p2p.p2p_daemon_bindings.control._handler:83] New incoming stream: <StreamInfo peer_id=QmfJuUCnS8PFAazbfgiExgpYGDCGnq5pBCQBSFZjVB71mE addr=/ip4/127.0.0.1/tcp/35705 proto=DHTProtocol.rpc_store>
-[2021/07/10 11:53:41.979][DEBUG][p2p.p2p_daemon_bindings.control.open_connection:57] DaemonConnector <hivemind.p2p.p2p_daemon_bindings.control.DaemonConnector object at 0x7fe66237f0a0> opens connection to /unix/tmp/hivemind-p2pd-Pswf-De9YNA.sock
-[2021/07/10 11:53:41.980][DEBUG][p2p.p2p_daemon_bindings.control._handler:83] New incoming stream: <StreamInfo peer_id=QmfJuUCnS8PFAazbfgiExgpYGDCGnq5pBCQBSFZjVB71mE addr=/ip4/127.0.0.1/tcp/35705 proto=DHTProtocol.rpc_find>
-2021-07-10T11:53:41.981+0300	ERROR	p2pd	error accepting connection	{"error": "accept unix /tmp/hivemind-p2pd-Pswf-De9YNA.sock: use of closed network connection"}
-[2021/07/10 11:53:41.985][DEBUG][p2p.p2p_daemon._terminate:401] Terminated p2pd with id = QmfJuUCnS8PFAazbfgiExgpYGDCGnq5pBCQBSFZjVB71mE
-2021-07-10T11:53:41.985+0300	ERROR	p2pd	error accepting connection	{"error": "accept unix /tmp/hivemind-p2pd-B4tnhAq6U1U.sock: use of closed network connection"}
-[2021/07/10 11:53:41.988][DEBUG][p2p.p2p_daemon._terminate:401] Terminated p2pd with id = Qma6rFAHFdwecmQKqWtU59LRqGAeXDam1rLc1u5ueNbwXm
-[2021/07/10 11:53:41.988][DEBUG][p2p.p2p_daemon._terminate:401] Terminated p2pd with id = Qma6rFAHFdwecmQKqWtU59LRqGAeXDam1rLc1u5ueNbwXm
-[2021/07/10 11:53:41.989][INFO][root.shutdown:49] Finished peer id=DHTID(0xa7087d197c0f2758c6bff7256e3cb78c5152c137) maddrs=[<Multiaddr /ip4/127.0.0.1/tcp/41695/p2p/Qma6rFAHFdwecmQKqWtU59LRqGAeXDam1rLc1u5ueNbwXm>]
-[2021/07/10 11:53:41.989][DEBUG][p2p.p2p_daemon._terminate:401] Terminated p2pd with id = QmTqNtzhfJrChohykAkiZQ8xBJW8995EfQynqAf37BNCak
-[2021/07/10 11:53:41.989][DEBUG][p2p.p2p_daemon._terminate:401] Terminated p2pd with id = QmTqNtzhfJrChohykAkiZQ8xBJW8995EfQynqAf37BNCak
-[2021/07/10 11:53:41.990][INFO][root.shutdown:49] Finished peer id=DHTID(0x3db9c894718f96bdd909afcef13a545941e7e1e4) maddrs=[<Multiaddr /ip4/127.0.0.1/tcp/38633/p2p/QmTqNtzhfJrChohykAkiZQ8xBJW8995EfQynqAf37BNCak>]
-[2021/07/10 11:53:41.990][INFO][root.shutdown:49] Finished peer id=DHTID(0x3db9c894718f96bdd909afcef13a545941e7e1e4) maddrs=[<Multiaddr /ip4/127.0.0.1/tcp/38633/p2p/QmTqNtzhfJrChohykAkiZQ8xBJW8995EfQynqAf37BNCak>]
-[2021/07/10 11:53:42.033][DEBUG][p2p.p2p_daemon._terminate:401] Terminated p2pd with id = QmYqP7ECXjtxaTi54eznbk2tgq9D43xpyL8hRYsksRMQ8Q

+ 19 - 31
tests/test_allreduce.py

@@ -3,16 +3,15 @@ import random
 import time
 from typing import Sequence
 
-import grpc
 import pytest
 import torch
 
-from hivemind import aenumerate, Endpoint
+from hivemind import aenumerate
 from hivemind.averaging.allreduce import AllReduceRunner, AveragingMode
 from hivemind.averaging.partition import TensorPartContainer, TensorPartReducer
-from hivemind.proto import averaging_pb2_grpc
+from hivemind.p2p import P2P, StubBase
 from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.utils import deserialize_torch_tensor, ChannelCache
+from hivemind.utils import deserialize_torch_tensor
 
 
 @pytest.mark.forked
@@ -152,19 +151,6 @@ async def test_reducer(num_senders: int, num_parts: int, synchronize_prob: float
             assert torch.allclose(averaging_result, reference_tensor, rtol=1e-3, atol=1e-5)
 
 
-class AllreduceRunnerForTesting(AllReduceRunner):
-    """a version of AllReduceRunner that was monkey-patched to accept custom endpoint names"""
-
-    def __init__(self, *args, peer_endpoints, **kwargs):
-        self.__peer_endpoints = peer_endpoints
-        super().__init__(*args, **kwargs)
-
-    def _get_peer_stub(self, peer: Endpoint) -> averaging_pb2_grpc.DecentralizedAveragingStub:
-        return ChannelCache.get_stub(
-            self.__peer_endpoints[peer], averaging_pb2_grpc.DecentralizedAveragingStub, aio=True
-        )
-
-
 NODE, CLIENT, AUX = AveragingMode.NODE, AveragingMode.CLIENT, AveragingMode.AUX
 
 
@@ -190,8 +176,18 @@ NODE, CLIENT, AUX = AveragingMode.NODE, AveragingMode.CLIENT, AveragingMode.AUX
 async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions, part_size_bytes):
     """Run group allreduce protocol manually without grpc, see if the internal logic is working as intended"""
 
-    peers = "alice", "bob", "carol", "colab"
+    class AllreduceRunnerForTesting(AllReduceRunner):
+        def _get_stub(self, peer: str) -> StubBase:
+            return AllReduceRunner.get_stub(self._p2p, peer)
+
+    p2ps = []
+    initial_peers = []
+    for _ in range(4):
+        instance = await P2P.create(initial_peers=initial_peers)
+        p2ps.append(instance)
+        initial_peers.extend(await instance.get_visible_maddrs())
 
+    peers = [instance.id for instance in p2ps]
     tensors_by_peer = {
         peer: [torch.randn(3, 128), torch.rand(32), torch.tensor(i, dtype=torch.float32)]
         for i, peer in enumerate(peers)
@@ -199,28 +195,20 @@ async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions,
 
     group_id = random.getrandbits(160).to_bytes(length=20, byteorder="big")
 
-    servers = []
     allreduce_protocols = []
-    peer_endpoints = {}
-
-    for peer in peers:
-        server = grpc.aio.server()
+    for p2p, peer in zip(p2ps, peers):
         allreduce_protocol = AllreduceRunnerForTesting(
+            p2p=p2p,
             group_id=group_id,
-            endpoint=peer,
             tensors=[x.clone() for x in tensors_by_peer[peer]],
             ordered_group_endpoints=peers,
             peer_fractions=peer_fractions,
             modes=peer_modes,
             weights=averaging_weights,
-            peer_endpoints=peer_endpoints,
             part_size_bytes=part_size_bytes,
         )
-        averaging_pb2_grpc.add_DecentralizedAveragingServicer_to_server(allreduce_protocol, server)
-        peer_endpoints[peer] = f"127.0.0.1:{server.add_insecure_port('127.0.0.1:*')}"
+        await allreduce_protocol.add_p2p_handlers(p2p)
         allreduce_protocols.append(allreduce_protocol)
-        servers.append(server)
-        await server.start()
 
     async def _run_allreduce_inplace(allreduce: AllReduceRunner):
         async for tensor_index, tensor_delta in aenumerate(allreduce):
@@ -244,5 +232,5 @@ async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions,
         assert len(output_tensors) == len(targets_for_peer)
         assert all(torch.allclose(our, ref, atol=1e-6, rtol=0) for our, ref in zip(output_tensors, targets_for_peer))
 
-    for server in servers:
-        await server.stop(grace=1)
+    for instance in p2ps:
+        await instance.shutdown()

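The test now builds a real libp2p swarm instead of in-process grpc servers: each new P2P daemon bootstraps from the visible multiaddrs of the daemons started before it, and the daemon's PeerID serves as the peer name. A condensed sketch of that bootstrap loop (the helper name make_p2p_swarm is illustrative):

from typing import List

from hivemind.p2p import P2P


async def make_p2p_swarm(n: int = 4) -> List[P2P]:
    instances, initial_peers = [], []
    for _ in range(n):
        # each later daemon connects to every daemon started so far
        instance = await P2P.create(initial_peers=initial_peers)
        instances.append(instance)
        initial_peers.extend(await instance.get_visible_maddrs())
    return instances  # call `await instance.shutdown()` on each when done
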
+ 18 - 10
tests/test_averaging.py

@@ -9,15 +9,19 @@ import hivemind.averaging.averager
 from hivemind.averaging.allreduce import AveragingMode
 from hivemind.averaging.key_manager import GroupKeyManager
 from hivemind.averaging.load_balancing import load_balance_peers
+from hivemind.p2p import PeerID
 from hivemind.proto.runtime_pb2 import CompressionType
 
 
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_key_manager():
+    localhvost = PeerID(b"localhvost")
+    localhvost2 = PeerID(b"localhvost2")
+
     key_manager = GroupKeyManager(
         hivemind.DHT(start=True),
-        endpoint="localhvost",
+        endpoint=localhvost,
         prefix="test_averaging",
         initial_group_bits="10110",
         target_group_size=2,
@@ -25,24 +29,24 @@ async def test_key_manager():
 
     t = hivemind.get_dht_time()
     key = key_manager.current_key
-    await key_manager.declare_averager(key, "localhvost", expiration_time=t + 60)
-    await key_manager.declare_averager(key, "localhvost2", expiration_time=t + 61)
+    await key_manager.declare_averager(key, localhvost, expiration_time=t + 60)
+    await key_manager.declare_averager(key, localhvost2, expiration_time=t + 61)
 
     q1 = await key_manager.get_averagers(key, only_active=True)
 
-    await key_manager.declare_averager(key, "localhvost", expiration_time=t + 66)
+    await key_manager.declare_averager(key, localhvost, expiration_time=t + 66)
     q2 = await key_manager.get_averagers(key, only_active=True)
 
-    await key_manager.declare_averager(key, "localhvost2", expiration_time=t + 61, looking_for_group=False)
+    await key_manager.declare_averager(key, localhvost2, expiration_time=t + 61, looking_for_group=False)
     q3 = await key_manager.get_averagers(key, only_active=True)
     q4 = await key_manager.get_averagers(key, only_active=False)
 
     q5 = await key_manager.get_averagers("nonexistent_key.0b0101", only_active=False)
 
-    assert len(q1) == 2 and ("localhvost", t + 60) in q1 and ("localhvost2", t + 61) in q1
-    assert len(q2) == 2 and ("localhvost", t + 66) in q2 and ("localhvost2", t + 61) in q2
-    assert len(q3) == 1 and ("localhvost", t + 66) in q3
-    assert len(q4) == 2 and ("localhvost", t + 66) in q4 and ("localhvost2", t + 61) in q2
+    assert len(q1) == 2 and (localhvost, t + 60) in q1 and (localhvost2, t + 61) in q1
+    assert len(q2) == 2 and (localhvost, t + 66) in q2 and (localhvost2, t + 61) in q2
+    assert len(q3) == 1 and (localhvost, t + 66) in q3
+    assert len(q4) == 2 and (localhvost, t + 66) in q4 and (localhvost2, t + 61) in q2
     assert len(q5) == 0
 
 
@@ -459,7 +463,11 @@ def test_load_state_from_peers():
 def test_getset_bits():
     dht = hivemind.DHT(start=True)
     averager = hivemind.averaging.DecentralizedAverager(
-        [torch.randn(3)], dht=dht, start=True, prefix="test_prefix", target_group_size=2,
+        [torch.randn(3)],
+        dht=dht,
+        start=True,
+        prefix="test_prefix",
+        target_group_size=2,
     )
     averager.set_group_bits("00101011101010")
     assert averager.get_group_bits() == "00101011101010"