
Refactor naming and serialization for PeerIDs (#339)

This PR follows up on #323 and performs the remaining mass refactors:

1. Rename `Endpoint` to `PeerID` in averager (+ related variable names)
2. Rename the `P2P.id` field to `P2P.peer_id` (because all other classes store the local peer ID in a `.peer_id` field)
3. Serialize `PeerID`s as `bytes` instead of Base58 strings (see the sketch after this list)
4. Remove `JoinRequest.peer_id` and `AveragingData.peer_id` fields (they duplicate `context.remote_id`)
5. Remove the `DecentralizedAveraging` gRPC interface (not used anymore)
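
For reference, a minimal sketch (not part of the diff) of what change 3 means for callers, using only the `PeerID` operations that appear in the diffs below — `to_bytes()`, `to_base58()`, `from_base58()`, and the plain bytes constructor; the peer ID bytes in the example are hypothetical:

```python
from hivemind.p2p import PeerID

# Hypothetical multihash-like bytes standing in for a real peer ID
# (real values come from `p2p.peer_id` / `dht.peer_id`).
peer_id = PeerID(b"\x12\x20" + bytes(32))

# Old wire format (removed in this PR): Base58 strings in protobuf fields and DHT subkeys
as_base58 = peer_id.to_base58()                  # str
assert PeerID.from_base58(as_base58) == peer_id

# New wire format: raw bytes, restored with the plain constructor
as_bytes = peer_id.to_bytes()                    # bytes
assert PeerID(as_bytes) == peer_id
```

Change 4 follows the same idea: instead of trusting a peer ID embedded in `JoinRequest`/`AveragingData`, RPC handlers now identify the sender via `context.remote_id`, which the P2P layer fills in for every incoming call.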
Alexander Borzunov · 4 years ago
commit 0774937a93

+ 1 - 1
benchmarks/benchmark_averaging.py

@@ -70,7 +70,7 @@ def benchmark_averaging(
         processes.update({dht, averager})
 
         logger.info(
-            f"Averager {index}: started on endpoint {averager.endpoint}, group_bits: {averager.get_group_bits()}"
+            f"Averager {index}: started with peer id {averager.peer_id}, group_bits: {averager.get_group_bits()}"
         )
         for step in range(num_rounds):
             try:

+ 41 - 42
hivemind/averaging/allreduce.py

@@ -5,7 +5,7 @@ from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Type
 import torch
 
 from hivemind.averaging.partition import TensorPartContainer, TensorPartReducer, AllreduceException
-from hivemind.p2p import P2P, P2PContext, PeerID as Endpoint, ServicerBase, StubBase
+from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase, StubBase
 from hivemind.utils import get_logger
 from hivemind.utils.asyncio import anext, achain, aiter, aenumerate, amap_in_executor, asingle
 from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
@@ -38,11 +38,11 @@ class AllReduceRunner(ServicerBase):
     :param group_id: unique identifier of this specific all-reduce run
     :param tensors: local tensors that should be averaged with groupmates
-    :param endpoint: your endpoint, must be included in ordered_group_endpoints
-    :param ordered_group_endpoints: group endpoints ordered s.t. i-th endpoint is responsible for averaging i-th part
+    :param peer_id: your peer_id, must be included in ordered_peer_ids
+    :param ordered_peer_ids: group peer_ids ordered s.t. i-th peer_id is responsible for averaging i-th part
     :param peer_fractions: for each peer, a target fraction of vector elements that this peer should average
       (the actual number of values by peer will be nearly proportional, but there are no exact guarantees)
-    :param modes: AveragingMode for each peer in ordered_group_endpoints (normal, client-only or auxiliary)
+    :param modes: AveragingMode for each peer in ordered_peer_ids (normal, client-only or auxiliary)
     :param weights: scaling coefficients for weighted averaging (default = equal weights for all non-aux peers)
     :param gathered: additional user-defined data collected from this group
     :param kwargs: additional parameters (e.g. part_size_bytes) will be passed to TensorPartContainer
@@ -56,16 +56,16 @@ class AllReduceRunner(ServicerBase):
         prefix: Optional[str],
         group_id: GroupID,
         tensors: Sequence[torch.Tensor],
-        ordered_group_endpoints: Sequence[Endpoint],
+        ordered_peer_ids: Sequence[PeerID],
         peer_fractions: Tuple[float, ...],
         weights: Optional[Sequence[float]] = None,
         modes: Optional[Sequence[AveragingMode]] = None,
-        gathered: Optional[Dict[Endpoint, Any]] = None,
+        gathered: Optional[Dict[PeerID, Any]] = None,
         **kwargs,
     ):
         self._p2p = p2p
-        self.endpoint = p2p.id
-        assert self.endpoint in ordered_group_endpoints, "endpoint is not a part of the group"
+        self.peer_id = p2p.peer_id
+        assert self.peer_id in ordered_peer_ids, "peer_id is not a part of the group"
 
         if not issubclass(servicer_type, ServicerBase):
             raise TypeError("`servicer_type` is expected to be a ServicerBase subclass")
@@ -74,60 +74,60 @@ class AllReduceRunner(ServicerBase):
 
         modes = modes or tuple(AveragingMode.CLIENT if frac == 0 else AveragingMode.NODE for frac in peer_fractions)
         weights = weights or tuple(int(mode != AveragingMode.AUX) for mode in modes)
-        assert len(weights) == len(modes) == len(ordered_group_endpoints), "lists have inconsistent length"
+        assert len(weights) == len(modes) == len(ordered_peer_ids), "lists have inconsistent length"
         assert any(mode != AveragingMode.CLIENT for mode in modes), "cannot run allreduce without reducers"
         for mode, frac, weight in zip(modes, peer_fractions, weights):
             assert mode != AveragingMode.CLIENT or frac == 0, "client-mode peer should have zero all-reduce fraction"
             assert mode != AveragingMode.AUX or weight == 0, "auxiliary peer should have zero averaging weight"
 
-        self.group_id, self.ordered_group_endpoints = group_id, ordered_group_endpoints
+        self.group_id, self.ordered_peer_ids = group_id, ordered_peer_ids
         self.modes, self.peer_fractions, self.gathered = modes, peer_fractions, gathered
 
         self._future = asyncio.Future()
 
-        self.sender_endpoints, self.sender_weights = [], []
-        for endpoint, weight, mode in zip(self.ordered_group_endpoints, weights, modes):
+        self.sender_peer_ids, self.sender_weights = [], []
+        for peer_id, weight, mode in zip(self.ordered_peer_ids, weights, modes):
             if mode != AveragingMode.AUX:
-                self.sender_endpoints.append(endpoint)
+                self.sender_peer_ids.append(peer_id)
                 self.sender_weights.append(weight)
 
-        endpoint_index = self.ordered_group_endpoints.index(self.endpoint)
+        peer_id_index = self.ordered_peer_ids.index(self.peer_id)
         self.tensor_part_container = TensorPartContainer(tensors, peer_fractions, **kwargs)
-        self.parts_for_local_averaging = self.tensor_part_container.get_raw_input_parts(endpoint_index)
+        self.parts_for_local_averaging = self.tensor_part_container.get_raw_input_parts(peer_id_index)
         self.tensor_part_reducer = TensorPartReducer(
             tuple(part.shape for part in self.parts_for_local_averaging),
-            len(self.sender_endpoints),
+            len(self.sender_peer_ids),
             self.sender_weights,
         )
 
     def __repr__(self):
-        return f"{self.__class__.__name__}({self.endpoint}, group_size={self.group_size})"
+        return f"{self.__class__.__name__}({self.peer_id}, group_size={self.group_size})"
 
     def __aiter__(self):
         return self.run()
 
-    def __contains__(self, endpoint: Endpoint):
-        return endpoint in self.ordered_group_endpoints
+    def __contains__(self, peer_id: PeerID):
+        return peer_id in self.ordered_peer_ids
 
     @property
     def group_size(self):
-        return len(self.ordered_group_endpoints)
+        return len(self.ordered_peer_ids)
 
-    def _get_peer_stub(self, peer: Endpoint) -> StubBase:
+    def _get_peer_stub(self, peer: PeerID) -> StubBase:
         return self._servicer_type.get_stub(self._p2p, peer, namespace=self._prefix)
 
     async def run(self) -> AsyncIterator[torch.Tensor]:
         """Run all-reduce, return differences between averaged and original tensors as they are computed"""
         pending_tasks = set()
         try:
-            if len(self.sender_endpoints) == 0:
+            if len(self.sender_peer_ids) == 0:
                 logger.debug(f"{self} - finished all-reduce early: all peers are auxiliaries ({self.modes})")
                 self.finalize()
 
-            elif self.endpoint in self.sender_endpoints:
-                for endpoint, parts in zip(self.ordered_group_endpoints, self.tensor_part_container.num_parts_by_peer):
+            elif self.peer_id in self.sender_peer_ids:
+                for peer_id, parts in zip(self.ordered_peer_ids, self.tensor_part_container.num_parts_by_peer):
                     if parts != 0:
-                        pending_tasks.add(asyncio.create_task(self._communicate_with_peer(endpoint)))
+                        pending_tasks.add(asyncio.create_task(self._communicate_with_peer(peer_id)))
 
                 async for averaged_tensor_delta in self.tensor_part_container.iterate_output_tensors():
                     yield averaged_tensor_delta  # delta = averaged_tensor - original_tensor
@@ -143,11 +143,11 @@ class AllReduceRunner(ServicerBase):
                 task.cancel()
             raise
 
-    async def _communicate_with_peer(self, peer_endpoint: Endpoint):
+    async def _communicate_with_peer(self, peer_id: PeerID):
         """Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors"""
-        peer_index = self.ordered_group_endpoints.index(peer_endpoint)
-        if peer_endpoint == self.endpoint:
-            sender_index = self.sender_endpoints.index(peer_endpoint)
+        peer_index = self.ordered_peer_ids.index(peer_id)
+        if peer_id == self.peer_id:
+            sender_index = self.sender_peer_ids.index(peer_id)
             for part_index, tensor_part in enumerate(self.parts_for_local_averaging):
                 averaged_part = await self.tensor_part_reducer.accumulate_part(sender_index, part_index, tensor_part)
                 self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part - tensor_part)
@@ -155,7 +155,7 @@ class AllReduceRunner(ServicerBase):
         else:
             loop = asyncio.get_event_loop()
             code = None
-            stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part(self._generate_input_for_peer(peer_index))
+            stream = self._get_peer_stub(peer_id).rpc_aggregate_part(self._generate_input_for_peer(peer_index))
             async for part_index, msg in aenumerate(stream):
                 if code is None:
                     code = msg.code
@@ -164,7 +164,7 @@ class AllReduceRunner(ServicerBase):
 
             if code != averaging_pb2.AVERAGED_PART:
                 raise AllreduceException(
-                    f"peer {peer_endpoint} returned {averaging_pb2.MessageCode.Name(code)} "
+                    f"peer {peer_id} returned {averaging_pb2.MessageCode.Name(code)} "
                     f"instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)}"
                     f", allreduce failed"
                 )
@@ -175,14 +175,13 @@ class AllReduceRunner(ServicerBase):
         yield averaging_pb2.AveragingData(
             code=averaging_pb2.PART_FOR_AVERAGING,
             group_id=self.group_id,
-            endpoint=self.endpoint.to_base58(),
             tensor_part=first_part,
         )
         async for part in parts_aiter:
             yield averaging_pb2.AveragingData(tensor_part=part)
 
     async def rpc_aggregate_part(
-        self, stream: AsyncIterator[averaging_pb2.AveragingData], _context: P2PContext
+        self, stream: AsyncIterator[averaging_pb2.AveragingData], context: P2PContext
     ) -> AsyncIterator[averaging_pb2.AveragingData]:
         """a peer sends us a part of his tensor; we should average it with other peers and return the difference"""
         request: averaging_pb2.AveragingData = await anext(stream)
@@ -193,7 +192,7 @@ class AllReduceRunner(ServicerBase):
 
         elif request.code == averaging_pb2.PART_FOR_AVERAGING:
             try:
-                sender_index = self.sender_endpoints.index(Endpoint.from_base58(request.endpoint))
+                sender_index = self.sender_peer_ids.index(context.remote_id)
                 async for msg in self._accumulate_parts_streaming(achain(aiter(request), stream), sender_index):
                     yield msg
 
@@ -202,8 +201,8 @@ class AllReduceRunner(ServicerBase):
                 yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
         else:
             error_code = averaging_pb2.MessageCode.Name(request.code)
-            logger.debug(f"{self} - peer {request.endpoint} sent {error_code}, allreduce cannot continue")
-            self.finalize(exception=AllreduceException(f"peer {request.endpoint} sent {error_code}."))
+            logger.debug(f"{self} - peer {context.remote_id} sent {error_code}, allreduce cannot continue")
+            self.finalize(exception=AllreduceException(f"peer {context.remote_id} sent {error_code}."))
             yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
 
     def _check_reasons_to_reject(self, request: averaging_pb2.AveragingData) -> Optional[averaging_pb2.AveragingData]:
@@ -230,10 +229,10 @@ class AllReduceRunner(ServicerBase):
             )
             yield averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=serialized_delta)
 
-    async def _send_error_to_peer(self, peer_endpoint: Endpoint, code: averaging_pb2.MessageCode):
-        error = averaging_pb2.AveragingData(group_id=self.group_id, endpoint=self.endpoint.to_base58(), code=code)
+    async def _send_error_to_peer(self, peer_id: PeerID, code: averaging_pb2.MessageCode):
+        error = averaging_pb2.AveragingData(group_id=self.group_id, code=code)
         # In case of reporting the error, we expect the response stream to contain exactly one item
-        await asingle(self._get_peer_stub(peer_endpoint).rpc_aggregate_part(aiter(error)))
+        await asingle(self._get_peer_stub(peer_id).rpc_aggregate_part(aiter(error)))
 
     def finalize(self, *, cancel: bool = False, exception: Optional[BaseException] = None):
         """finish or terminate AllReduceRunner, propagate any errors / cancellations to peers."""
@@ -246,9 +245,9 @@ class AllReduceRunner(ServicerBase):
             else:
                 code = averaging_pb2.INTERNAL_ERROR
             logger.debug(f"{self} - notifying peers about {averaging_pb2.MessageCode.Name(code)}")
-            for peer_endpoint, mode in zip(self.ordered_group_endpoints, self.modes):
-                if peer_endpoint != self.endpoint and mode != AveragingMode.CLIENT:
-                    pending_tasks.add(asyncio.create_task(self._send_error_to_peer(peer_endpoint, code)))
+            for peer_id, mode in zip(self.ordered_peer_ids, self.modes):
+                if peer_id != self.peer_id and mode != AveragingMode.CLIENT:
+                    pending_tasks.add(asyncio.create_task(self._send_error_to_peer(peer_id, code)))
 
         if not self._future.done():
             if cancel:

+ 10 - 10
hivemind/averaging/averager.py

@@ -22,7 +22,7 @@ from hivemind.averaging.load_balancing import load_balance_peers
 from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
 from hivemind.averaging.partition import DEFAULT_PART_SIZE_BYTES
 from hivemind.dht import DHT, DHTID
-from hivemind.p2p import P2PContext, P2PHandlerError, PeerID as Endpoint, ServicerBase
+from hivemind.p2p import P2PContext, P2PHandlerError, PeerID, ServicerBase
 from hivemind.proto import averaging_pb2, runtime_pb2
 from hivemind.utils import MPFuture, get_logger, TensorDescriptor
 from hivemind.utils.asyncio import anext, achain, aiter, switch_to_uvloop
@@ -194,7 +194,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             self._allow_state_sharing.value = value
 
     @property
-    def endpoint(self) -> Endpoint:
+    def peer_id(self) -> PeerID:
         return self.dht.peer_id
 
     def run(self):
@@ -281,7 +281,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         timeout: Optional[float] = None,
         allow_retries: bool = True,
         wait: bool = True,
-    ) -> Union[Optional[Dict[Endpoint, GatheredData]], MPFuture]:
+    ) -> Union[Optional[Dict[PeerID, GatheredData]], MPFuture]:
         """
         Set up the averager to look for a group and run one round of averaging, return True on success, False on failure
 
@@ -375,7 +375,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
         try:
             weights, bandwidths, mode_ids, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
-            user_gathered = dict(zip(group_info.endpoints, map(self.serializer.loads, user_gathered)))
+            user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered)))
             modes = tuple(map(AveragingMode, mode_ids))
 
             # compute optimal part sizes from peer bandwidths; TODO: replace with proper load balancing
@@ -393,7 +393,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     prefix=self.prefix,
                     group_id=group_info.group_id,
                     tensors=local_tensors,
-                    ordered_group_endpoints=group_info.endpoints,
+                    ordered_peer_ids=group_info.peer_ids,
                     peer_fractions=peer_fractions,
                     weights=weights,
                     gathered=user_gathered,
@@ -406,7 +406,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     # actually run all-reduce
                     averaging_outputs = [output async for output in allreduce]
 
-                    if modes[group_info.endpoints.index(self.endpoint)] != AveragingMode.AUX:
+                    if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
                         assert len(local_tensors) == len(self._averaged_tensors)
                         for tensor, update in zip(local_tensors, averaging_outputs):
                             tensor.add_(update, alpha=self._averaging_alpha)
@@ -481,7 +481,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     asyncio.wait_for(
                         self.dht.store(
                             download_key,
-                            subkey=self.endpoint.to_base58(),
+                            subkey=self.peer_id.to_bytes(),
                             value=self.last_updated,
                             expiration_time=get_dht_time() + self._matchmaking.averaging_expiration,
                             return_future=True,
@@ -547,8 +547,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             key_manager = self._matchmaking.group_key_manager
             peer_priority, _ = self.dht.get(f"{key_manager.prefix}.all_averagers", latest=True) or ({}, None)
             peer_priority = {
-                Endpoint.from_base58(peer): float(info.value)
-                for peer, info in peer_priority.items()
+                PeerID(peer_id): float(info.value)
+                for peer_id, info in peer_priority.items()
                 if isinstance(info, ValueWithExpiration) and isinstance(info.value, (float, int))
             }
 
@@ -559,7 +559,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
             metadata = None
             for peer in sorted(peer_priority.keys(), key=peer_priority.get, reverse=True):
-                if peer != self.endpoint:
+                if peer != self.peer_id:
                     logger.info(f"Downloading parameters from peer {peer}")
                     try:
                         stub = self.get_stub(self._p2p, peer, namespace=self.prefix)

+ 6 - 6
hivemind/averaging/group_info.py

@@ -1,7 +1,7 @@
 from dataclasses import dataclass
 from typing import Tuple
 
-from hivemind.utils import Endpoint
+from hivemind.p2p import PeerID
 
 
 @dataclass(frozen=True)
@@ -9,12 +9,12 @@ class GroupInfo:
     """A group of peers assembled through decentralized matchmaking"""
 
     group_id: bytes  # random unique bytestring that describes the current group, generated by group leader
-    endpoints: Tuple[Endpoint, ...]  # an ordered sequence of endpoints of each groupmate
-    gathered: Tuple[bytes, ...]  # binary metadata gathered from all peers by leader, same order as endpoints
+    peer_ids: Tuple[PeerID, ...]  # an ordered sequence of peer_ids of each groupmate
+    gathered: Tuple[bytes, ...]  # binary metadata gathered from all peers by leader, same order as peer_ids
 
     @property
     def group_size(self):
-        return len(self.endpoints)
+        return len(self.peer_ids)
 
-    def __contains__(self, endpoint: Endpoint):
-        return endpoint in self.endpoints
+    def __contains__(self, peer_id: PeerID):
+        return peer_id in self.peer_ids

+ 13 - 13
hivemind/averaging/key_manager.py

@@ -7,7 +7,7 @@ import numpy as np
 
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.dht import DHT
-from hivemind.p2p import PeerID as Endpoint
+from hivemind.p2p import PeerID
 from hivemind.utils import get_logger, DHTExpiration, get_dht_time, ValueWithExpiration
 
 GroupKey = str
@@ -44,7 +44,7 @@ class GroupKeyManager:
             initial_group_nbits = self.get_suggested_nbits(search_result) or 0
             initial_group_bits = "".join(random.choice("01") for _ in range(initial_group_nbits))
         self.dht, self.prefix, self.group_bits = dht, prefix, initial_group_bits
-        self.endpoint = dht.peer_id
+        self.peer_id = dht.peer_id
         self.target_group_size = target_group_size
         self.insufficient_size = insufficient_size or max(1, target_group_size // 2)
         self.excessive_size = excessive_size or target_group_size * 3
@@ -56,13 +56,13 @@ class GroupKeyManager:
         return f"{self.prefix}.0b{self.group_bits}"
 
     async def declare_averager(
-        self, group_key: GroupKey, endpoint: Endpoint, expiration_time: float, looking_for_group: bool = True
+        self, group_key: GroupKey, peer_id: PeerID, expiration_time: float, looking_for_group: bool = True
     ) -> bool:
         """
         Add (or remove) the averager to a given allreduce bucket
 
         :param group_key: allreduce group key, e.g. my_averager.0b011011101
-        :param endpoint: averager public endpoint for incoming requests
+        :param peer_id: averager public peer_id for incoming requests
         :param expiration_time: intent to run allreduce before this timestamp
         :param looking_for_group: by default (True), declare the averager as "looking for group" in a given group;
           If False, this will instead mark that the averager as no longer looking for group, (e.g. it already finished)
@@ -73,20 +73,20 @@ class GroupKeyManager:
         expiration_time = expiration_time if looking_for_group else float(np.nextafter(expiration_time, float("inf")))
         return await self.dht.store(
             key=group_key,
-            subkey=endpoint.to_base58(),
+            subkey=peer_id.to_bytes(),
             value=looking_for_group,
             expiration_time=expiration_time,
             return_future=True,
         )
 
-    async def get_averagers(self, group_key: GroupKey, only_active: bool) -> List[Tuple[Endpoint, DHTExpiration]]:
+    async def get_averagers(self, group_key: GroupKey, only_active: bool) -> List[Tuple[PeerID, DHTExpiration]]:
         """
         Find and return averagers that were declared with a given all-reduce key
 
         :param group_key: finds averagers that have this group key, e.g. my_averager.0b011011101
         :param only_active: if True, return only active averagers that are looking for group (i.e. with value = True)
             if False, return all averagers under a given group_key regardless of value
-        :return: endpoints and expirations of every matching averager
+        :return: peer_ids and expirations of every matching averager
         """
         assert is_valid_group(group_key), f"Group key {group_key} is invalid, must follow {GROUP_PATTERN}"
         result = await self.dht.get(group_key, latest=True, return_future=True)
@@ -94,7 +94,7 @@ class GroupKeyManager:
             logger.debug(f"Allreduce group not found: {group_key}, creating new group.")
             return []
         averagers = [
-            (Endpoint.from_base58(key), looking_for_group.expiration_time)
+            (PeerID(key), looking_for_group.expiration_time)
             for key, looking_for_group in result.value.items()
             if key != self.RESERVED_KEY_FOR_NBITS and (not only_active or looking_for_group.value)
         ]
@@ -111,10 +111,10 @@ class GroupKeyManager:
             and suggested_nbits != self.suggested_nbits
         ):
             self.suggested_nbits = suggested_nbits
-            logger.warning(f"{self.endpoint} - another averager suggested {self.suggested_nbits}-bit keys")
+            logger.warning(f"{self.peer_id} - another averager suggested {self.suggested_nbits}-bit keys")
         elif num_active_averagers >= self.excessive_size:
             self.suggested_nbits = max(suggested_nbits or 0, len(self.group_bits) + 1)
-            logger.warning(f"{self.endpoint} - too many peers in bucket, switching to {self.suggested_nbits}-bit keys")
+            logger.warning(f"{self.peer_id} - too many peers in bucket, switching to {self.suggested_nbits}-bit keys")
         return averagers
 
     async def declare_nbits(self, group_key: GroupKey, nbits: int, expiration_time: DHTExpiration) -> bool:
@@ -141,12 +141,12 @@ class GroupKeyManager:
     async def update_key_on_group_assembled(self, group_info: GroupInfo, is_leader: bool = True):
         """this function is triggered every time an averager finds an allreduce group"""
         rng = random.Random(group_info.group_id)
-        index = group_info.endpoints.index(self.endpoint)
+        index = group_info.peer_ids.index(self.peer_id)
         generalized_index = rng.sample(range(self.target_group_size), group_info.group_size)[index]
         nbits = int(np.ceil(np.log2(self.target_group_size)))
         new_bits = bin(generalized_index)[2:].rjust(nbits, "0")
         self.group_bits = (self.group_bits + new_bits)[-len(self.group_bits) :] if self.group_bits else ""
-        logger.debug(f"{self.endpoint} - updated group key to {self.group_bits}")
+        logger.debug(f"{self.peer_id} - updated group key to {self.group_bits}")
 
         if is_leader and self.insufficient_size < group_info.group_size < self.excessive_size:
             asyncio.create_task(self.notify_stragglers())
@@ -161,7 +161,7 @@ class GroupKeyManager:
         new_nbits = self.suggested_nbits if self.suggested_nbits is not None else len(self.group_bits) - 1
         prev_nbits, self.group_bits = self.group_bits, self.group_bits[-new_nbits:] if new_nbits else ""
         if self.group_bits != prev_nbits:
-            logger.warning(f"{self.endpoint} - switching to {len(self.group_bits)}-bit keys")
+            logger.warning(f"{self.peer_id} - switching to {len(self.group_bits)}-bit keys")
         self.suggested_nbits = None
 
     async def notify_stragglers(self):

+ 49 - 57
hivemind/averaging/matchmaking.py

@@ -12,7 +12,7 @@ from typing import AsyncIterator, Dict, Optional, Set, Tuple, Type
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.key_manager import GroupKeyManager, GroupKey
 from hivemind.dht import DHT, DHTID, DHTExpiration
-from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID as Endpoint, ServicerBase
+from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
 from hivemind.utils import get_logger, timed_storage, TimedStorage, get_dht_time
 from hivemind.utils.asyncio import anext
 from hivemind.proto import averaging_pb2
@@ -63,7 +63,7 @@ class Matchmaking:
         self._servicer_type = servicer_type
         self._prefix = prefix
 
-        self.endpoint = p2p.id
+        self.peer_id = p2p.peer_id
         self.schema_hash = schema_hash
         self.group_key_manager = GroupKeyManager(dht, prefix, initial_group_bits, target_group_size)
         self.target_group_size, self.min_group_size = target_group_size, min_group_size
@@ -76,9 +76,9 @@ class Matchmaking:
         self.was_accepted_to_group = asyncio.Event()
         self.assembled_group = asyncio.Future()
 
-        self.current_leader: Optional[Endpoint] = None  # iff i am a follower, this is a link to my current leader
-        self.current_followers: Dict[Endpoint, averaging_pb2.JoinRequest] = {}  # my current followers excluding myself
-        self.potential_leaders = PotentialLeaders(self.endpoint, averaging_expiration, target_group_size)
+        self.current_leader: Optional[PeerID] = None  # iff i am a follower, this is a link to my current leader
+        self.current_followers: Dict[PeerID, averaging_pb2.JoinRequest] = {}  # my current followers excluding myself
+        self.potential_leaders = PotentialLeaders(self.peer_id, averaging_expiration, target_group_size)
         self.data_for_gather: Optional[bytes] = None
 
     @property
@@ -94,7 +94,7 @@ class Matchmaking:
                 lfg_status += f" leading {len(self.current_followers)} followers,"
         schema_hash_repr = f"{self.schema_hash[0]}...{self.schema_hash[-8:]}"
         return (
-            f"{self.__class__.__name__}(endpoint={self.endpoint}, schema={schema_hash_repr}, {lfg_status}"
+            f"{self.__class__.__name__}(peer_id={self.peer_id}, schema={schema_hash_repr}, {lfg_status}"
             f" current key = {self.group_key_manager.current_key}, client_mode={self.client_mode})"
         )
 
@@ -167,7 +167,7 @@ class Matchmaking:
                         self.assembled_group.set_exception(e)
                     raise e
 
-    async def request_join_group(self, leader: Endpoint, expiration_time: DHTExpiration) -> Optional[GroupInfo]:
+    async def request_join_group(self, leader: PeerID, expiration_time: DHTExpiration) -> Optional[GroupInfo]:
         """
         :param leader: request this peer to be your leader for allreduce
         :param expiration_time: inform leader that we intend to begin averaging before this expiration_time
@@ -183,7 +183,6 @@ class Matchmaking:
 
                 stream = leader_stub.rpc_join_group(
                     averaging_pb2.JoinRequest(
-                        endpoint=self.endpoint.to_base58(),
                         schema_hash=self.schema_hash,
                         expiration=expiration_time,
                         client_mode=self.client_mode,
@@ -194,7 +193,7 @@ class Matchmaking:
                 message = await asyncio.wait_for(anext(stream), timeout=self.request_timeout)
 
                 if message.code == averaging_pb2.ACCEPTED:
-                    logger.debug(f"{self.endpoint} - joining the group of {leader}; waiting for peers")
+                    logger.debug(f"{self.peer_id} - joining the group of {leader}; waiting for peers")
                     self.current_leader = leader
                     self.was_accepted_to_group.set()
                     if len(self.current_followers) > 0:
@@ -202,7 +201,7 @@ class Matchmaking:
 
             if message.code != averaging_pb2.ACCEPTED:
                 code = averaging_pb2.MessageCode.Name(message.code)
-                logger.debug(f"{self.endpoint} - requested {leader} to be my leader, but got rejected with {code}")
+                logger.debug(f"{self.peer_id} - requested {leader} to be my leader, but got rejected with {code}")
                 return None
 
             async with self.potential_leaders.pause_search():
@@ -215,8 +214,8 @@ class Matchmaking:
 
             if message.code in (averaging_pb2.GROUP_DISBANDED, averaging_pb2.CANCELLED):
                 if message.suggested_leader:
-                    suggested_leader = Endpoint.from_base58(message.suggested_leader)
-                    if suggested_leader != self.endpoint:
+                    suggested_leader = PeerID(message.suggested_leader)
+                    if suggested_leader != self.peer_id:
                         logger.debug(f"{self} - leader disbanded group and redirected us to {suggested_leader}")
                         self.current_leader = None
                         await stream.aclose()
@@ -240,19 +239,17 @@ class Matchmaking:
                 await stream.aclose()
 
     async def rpc_join_group(
-        self, request: averaging_pb2.JoinRequest, _context: P2PContext
+        self, request: averaging_pb2.JoinRequest, context: P2PContext
     ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
         """accept or reject a join request from another averager; if accepted, run him through allreduce steps"""
-        request_endpoint = None
         try:
             async with self.lock_request_join_group:
-                reason_to_reject = self._check_reasons_to_reject(request)
+                reason_to_reject = self._check_reasons_to_reject(request, context)
                 if reason_to_reject is not None:
                     yield reason_to_reject
                     return
 
-                request_endpoint = Endpoint.from_base58(request.endpoint)
-                self.current_followers[request_endpoint] = request
+                self.current_followers[context.remote_id] = request
                 yield averaging_pb2.MessageFromLeader(code=averaging_pb2.ACCEPTED)
 
                 if len(self.current_followers) + 1 >= self.target_group_size and not self.assembled_group.done():
@@ -280,12 +277,12 @@ class Matchmaking:
                 self.was_accepted_to_group.is_set()
                 or not self.assembled_group.done()
                 or self.assembled_group.cancelled()
-                or request_endpoint not in self.assembled_group.result()
+                or context.remote_id not in self.assembled_group.result()
             ):
                 if self.current_leader is not None:
                     # outcome 3: found by a leader with higher priority, send our followers to him
                     yield averaging_pb2.MessageFromLeader(
-                        code=averaging_pb2.GROUP_DISBANDED, suggested_leader=self.current_leader.to_base58()
+                        code=averaging_pb2.GROUP_DISBANDED, suggested_leader=self.current_leader.to_bytes()
                     )
                     return
                 else:
@@ -296,7 +293,7 @@ class Matchmaking:
             yield averaging_pb2.MessageFromLeader(
                 code=averaging_pb2.BEGIN_ALLREDUCE,
                 group_id=group_info.group_id,
-                ordered_group_endpoints=[item.to_base58() for item in group_info.endpoints],
+                ordered_peer_ids=[item.to_bytes() for item in group_info.peer_ids],
                 gathered=group_info.gathered,
             )
         except (concurrent.futures.CancelledError, asyncio.CancelledError):
@@ -306,27 +303,22 @@ class Matchmaking:
             yield averaging_pb2.MessageFromLeader(code=averaging_pb2.INTERNAL_ERROR)
 
         finally:  # note: this code is guaranteed to run even if the coroutine is destroyed prematurely
-            self.current_followers.pop(request_endpoint, None)
+            self.current_followers.pop(context.remote_id, None)
             self.follower_was_discarded.set()
 
     def _check_reasons_to_reject(
-        self, request: averaging_pb2.JoinRequest
+        self, request: averaging_pb2.JoinRequest, context: P2PContext
     ) -> Optional[averaging_pb2.MessageFromLeader]:
         """:returns: if accepted, return None, otherwise return a reason for rejection"""
         if not self.is_looking_for_group or self.assembled_group.done():
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.NOT_LOOKING_FOR_GROUP)
 
-        try:
-            request_endpoint = Endpoint.from_base58(request.endpoint)
-        except (ValueError, TypeError):
-            request_endpoint = None
         if (
             request.ListFields() == 3
             and not isinstance(request.schema_hash, bytes)
             or len(request.schema_hash) == 0
             or not isinstance(request.expiration, DHTExpiration)
             or not isfinite(request.expiration)
-            or request_endpoint is None
             or self.client_mode
             or not isinstance(request.group_key, GroupKey)
         ):
@@ -342,10 +334,10 @@ class Matchmaking:
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.BAD_EXPIRATION_TIME)
         elif self.current_leader is not None:
             return averaging_pb2.MessageFromLeader(
-                code=averaging_pb2.NOT_A_LEADER, suggested_leader=self.current_leader.to_base58()
+                code=averaging_pb2.NOT_A_LEADER, suggested_leader=self.current_leader.to_bytes()
             )
-        elif request_endpoint == self.endpoint or request_endpoint in self.current_followers:
-            return averaging_pb2.MessageFromLeader(code=averaging_pb2.DUPLICATE_ENDPOINT)
+        elif context.remote_id == self.peer_id or context.remote_id in self.current_followers:
+            return averaging_pb2.MessageFromLeader(code=averaging_pb2.DUPLICATE_PEER_ID)
         elif len(self.current_followers) + 1 >= self.target_group_size:
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_IS_FULL)
         else:
@@ -355,35 +347,35 @@ class Matchmaking:
         """Form up all current followers into a group and gather metadata"""
         assert self.lock_looking_for_group.locked() and self.lock_request_join_group.locked() and not self.client_mode
         assert not self.assembled_group.done()
-        group_id = DHTID.generate().to_bytes()  # note: both groupd_id and the order of endpoints must be random
-        ordered_group_endpoints = list(self.current_followers)
-        ordered_group_endpoints.append(self.endpoint)
-        random.shuffle(ordered_group_endpoints)
+        group_id = DHTID.generate().to_bytes()  # note: both group_id and the order of peer_ids must be random
+        ordered_peer_ids = list(self.current_followers)
+        ordered_peer_ids.append(self.peer_id)
+        random.shuffle(ordered_peer_ids)
 
         gathered = tuple(
-            self.data_for_gather if endpoint == self.endpoint else self.current_followers[endpoint].gather
-            for endpoint in ordered_group_endpoints
+            self.data_for_gather if peer_id == self.peer_id else self.current_followers[peer_id].gather
+            for peer_id in ordered_peer_ids
         )
 
-        logger.debug(f"{self.endpoint} - assembled group of {len(ordered_group_endpoints)} peers.")
-        group_info = GroupInfo(group_id, tuple(ordered_group_endpoints), gathered)
+        logger.debug(f"{self.peer_id} - assembled group of {len(ordered_peer_ids)} peers.")
+        group_info = GroupInfo(group_id, tuple(ordered_peer_ids), gathered)
         await self.group_key_manager.update_key_on_group_assembled(group_info, is_leader=True)
         self.assembled_group.set_result(group_info)
         return group_info
 
-    async def follower_assemble_group(self, leader: Endpoint, msg: averaging_pb2.MessageFromLeader) -> GroupInfo:
+    async def follower_assemble_group(self, leader: PeerID, msg: averaging_pb2.MessageFromLeader) -> GroupInfo:
         """Form a group from using peers and metadata provided by our leader"""
         assert self.lock_looking_for_group.locked() and self.lock_request_join_group.locked()
         assert not self.assembled_group.done()
         assert self.current_leader == leader, f"averager does not follow {leader} (actual: {self.current_leader})"
 
         group_id = msg.group_id
-        ordered_group_endpoints = [Endpoint.from_base58(item) for item in msg.ordered_group_endpoints]
-        assert self.endpoint in ordered_group_endpoints, "Leader sent us group_endpoints that does not contain us!"
-        assert len(ordered_group_endpoints) == len(msg.gathered)
+        ordered_peer_ids = [PeerID(item) for item in msg.ordered_peer_ids]
+        assert self.peer_id in ordered_peer_ids, "Leader sent us group_peer_ids that does not contain us!"
+        assert len(ordered_peer_ids) == len(msg.gathered)
 
-        logger.debug(f"{self.endpoint} - follower assembled group with leader {leader}.")
-        group_info = GroupInfo(group_id, tuple(ordered_group_endpoints), tuple(msg.gathered))
+        logger.debug(f"{self.peer_id} - follower assembled group with leader {leader}.")
+        group_info = GroupInfo(group_id, tuple(ordered_peer_ids), tuple(msg.gathered))
         await self.group_key_manager.update_key_on_group_assembled(group_info)
         self.assembled_group.set_result(group_info)
         return group_info
@@ -397,13 +389,13 @@ class Matchmaking:
 class PotentialLeaders:
     """An utility class that searches for averagers that could become our leaders"""
 
-    def __init__(self, endpoint: Endpoint, averaging_expiration: DHTExpiration, target_group_size: Optional[int]):
-        self.endpoint, self.averaging_expiration = endpoint, averaging_expiration
+    def __init__(self, peer_id: PeerID, averaging_expiration: DHTExpiration, target_group_size: Optional[int]):
+        self.peer_id, self.averaging_expiration = peer_id, averaging_expiration
         self.target_group_size = target_group_size
         self.running, self.update_triggered, self.update_finished = asyncio.Event(), asyncio.Event(), asyncio.Event()
         self.declared_expiration, self.lock_search, self.lock_declare = asyncio.Event(), asyncio.Lock(), asyncio.Lock()
-        self.leader_queue = TimedStorage[Endpoint, DHTExpiration]()
-        self.past_attempts: Set[Tuple[Endpoint, DHTExpiration]] = set()
+        self.leader_queue = TimedStorage[PeerID, DHTExpiration]()
+        self.past_attempts: Set[Tuple[PeerID, DHTExpiration]] = set()
         self.declared_expiration_time = float("inf")
         self.declared_group_key: Optional[GroupKey] = None
         self.max_assured_time = float("-inf")
@@ -450,7 +442,7 @@ class PotentialLeaders:
             else:
                 self.running.clear()
 
-    async def pop_next_leader(self) -> Endpoint:
+    async def pop_next_leader(self) -> PeerID:
         """Remove and return the next most suitable leader or throw an exception if reached timeout"""
         assert self.running.is_set(), "Not running search at the moment"
         while True:
@@ -459,9 +451,9 @@ class PotentialLeaders:
             if maybe_next_leader is None or self.max_assured_time <= entry.expiration_time <= self.search_end_time:
                 self.update_triggered.set()
 
-            if maybe_next_leader is None or (entry.expiration_time, maybe_next_leader.to_base58()) > (
+            if maybe_next_leader is None or (entry.expiration_time, maybe_next_leader.to_bytes()) > (
                 self.declared_expiration_time,
-                self.endpoint.to_base58(),
+                self.peer_id.to_bytes(),
             ):
                 await asyncio.wait(
                     {self.update_finished.wait(), self.declared_expiration.wait()}, return_when=asyncio.FIRST_COMPLETED
@@ -496,7 +488,7 @@ class PotentialLeaders:
 
                 self.leader_queue.clear()
                 for peer, peer_expiration_time in new_peers:
-                    if peer == self.endpoint or (peer, peer_expiration_time) in self.past_attempts:
+                    if peer == self.peer_id or (peer, peer_expiration_time) in self.past_attempts:
                         continue
                     self.leader_queue.store(peer, peer_expiration_time, peer_expiration_time)
                     self.max_assured_time = max(self.max_assured_time, peer_expiration_time - DISCREPANCY)
@@ -512,7 +504,7 @@ class PotentialLeaders:
         except (concurrent.futures.CancelledError, asyncio.CancelledError):
             return  # note: this is a compatibility layer for python3.7
         except Exception as e:
-            logger.error(f"{self.endpoint} - caught {type(e)}: {e}")
+            logger.error(f"{self.peer_id} - caught {type(e)}: {e}")
             raise
 
     async def _declare_averager_periodically(self, key_manager: GroupKeyManager):
@@ -525,21 +517,21 @@ class PotentialLeaders:
                     self.declared_group_key = group_key = key_manager.current_key
                     self.declared_expiration_time = new_expiration_time
                     self.declared_expiration.set()
-                    await key_manager.declare_averager(group_key, self.endpoint, expiration_time=new_expiration_time)
+                    await key_manager.declare_averager(group_key, self.peer_id, expiration_time=new_expiration_time)
                     await asyncio.sleep(self.declared_expiration_time - get_dht_time())
                     if self.running.is_set() and len(self.leader_queue) == 0:
                         await key_manager.update_key_on_not_enough_peers()
             except (concurrent.futures.CancelledError, asyncio.CancelledError):
                 pass  # note: this is a compatibility layer for python3.7
             except Exception as e:  # note: we catch exceptions here because otherwise they are never printed
-                logger.error(f"{self.endpoint} - caught {type(e)}: {e}")
+                logger.error(f"{self.peer_id} - caught {type(e)}: {e}")
             finally:
                 if self.declared_group_key is not None:
                     prev_declared_key, prev_expiration_time = self.declared_group_key, self.declared_expiration_time
                     self.declared_group_key, self.declared_expiration_time = None, float("inf")
-                    self.leader_queue, self.max_assured_time = TimedStorage[Endpoint, DHTExpiration](), float("-inf")
+                    self.leader_queue, self.max_assured_time = TimedStorage[PeerID, DHTExpiration](), float("-inf")
                     await key_manager.declare_averager(
-                        prev_declared_key, self.endpoint, prev_expiration_time, looking_for_group=False
+                        prev_declared_key, self.peer_id, prev_expiration_time, looking_for_group=False
                     )
 
 

+ 1 - 1
hivemind/dht/node.py

@@ -207,7 +207,7 @@ class DHTNode:
             record_validator,
             authorizer,
         )
-        self.peer_id = p2p.id
+        self.peer_id = p2p.peer_id
 
         if initial_peers:
             initial_peers = {PeerID.from_base58(Multiaddr(item)["p2p"]) for item in initial_peers}

+ 2 - 2
hivemind/dht/protocol.py

@@ -296,7 +296,7 @@ class DHTProtocol(ServicerBase):
                 nearest = dict(
                     zip(
                         map(DHTID.from_bytes, result.nearest_node_ids),
-                        map(PeerID.from_base58, result.nearest_peer_ids),
+                        map(PeerID, result.nearest_peer_ids),
                     )
                 )
 
@@ -359,7 +359,7 @@ class DHTProtocol(ServicerBase):
                 key_id, k=self.bucket_size, exclude=DHTID.from_bytes(request.peer.node_id)
             ):
                 item.nearest_node_ids.append(node_id.to_bytes())
-                item.nearest_peer_ids.append(peer_id.to_base58())
+                item.nearest_peer_ids.append(peer_id.to_bytes())
             response.results.append(item)
         return response
 

+ 2 - 2
hivemind/optim/collaborative.py

@@ -42,7 +42,7 @@ class CollaborationState:
 
 
 class TrainingState(BaseModel):
-    peer_id: str
+    peer_id: bytes
     step: conint(ge=0, strict=True)
     samples_accumulated: conint(ge=0, strict=True)
     samples_per_second: confloat(ge=0.0, strict=True)
@@ -354,7 +354,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             with self.lock_local_progress:
                 current_time = get_dht_time()
                 local_state_info = TrainingState(
-                    peer_id=self.averager.endpoint.to_base58(),
+                    peer_id=self.averager.peer_id.to_bytes(),
                     step=self.local_step,
                     samples_accumulated=self.local_samples_accumulated,
                     samples_per_second=self.performance_ema.samples_per_second,

+ 8 - 8
hivemind/p2p/p2p_daemon.py

@@ -44,7 +44,7 @@ class P2P:
       - `P2P.add_binary_stream_handler` transfers raw data using bi-directional streaming interface
 
     To access these handlers, a P2P instance can `P2P.call_protobuf_handler`/`P2P.call_binary_stream_handler`,
-    using the recipient's unique `P2P.id` and the name of the corresponding handler.
+    using the recipient's unique `P2P.peer_id` and the name of the corresponding handler.
     """
 
     HEADER_LEN = 8
@@ -65,7 +65,7 @@ class P2P:
     _UNIX_SOCKET_PREFIX = "/unix/tmp/hivemind-"
 
     def __init__(self):
-        self.id = None
+        self.peer_id = None
         self._child = None
         self._alive = False
         self._listen_task = None
@@ -212,8 +212,8 @@ class P2P:
         return self
 
     async def _ping_daemon(self) -> None:
-        self.id, self._visible_maddrs = await self._client.identify()
-        logger.debug(f"Launched p2pd with id = {self.id}, host multiaddrs = {self._visible_maddrs}")
+        self.peer_id, self._visible_maddrs = await self._client.identify()
+        logger.debug(f"Launched p2pd with peer id = {self.peer_id}, host multiaddrs = {self._visible_maddrs}")
 
     async def get_visible_maddrs(self, latest: bool = False) -> List[Multiaddr]:
         """
@@ -226,9 +226,9 @@ class P2P:
             _, self._visible_maddrs = await self._client.identify()
 
         if not self._visible_maddrs:
-            raise ValueError(f"No multiaddrs found for peer {self.id}")
+            raise ValueError(f"No multiaddrs found for peer {self.peer_id}")
 
-        p2p_maddr = Multiaddr(f"/p2p/{self.id.to_base58()}")
+        p2p_maddr = Multiaddr(f"/p2p/{self.peer_id.to_base58()}")
         return [addr.encapsulate(p2p_maddr) for addr in self._visible_maddrs]
 
     async def list_peers(self) -> List[PeerInfo]:
@@ -312,7 +312,7 @@ class P2P:
         ) -> None:
             context = P2PContext(
                 handle_name=name,
-                local_id=self.id,
+                local_id=self.peer_id,
                 remote_id=stream_info.peer_id,
                 remote_maddr=stream_info.addr,
             )
@@ -475,7 +475,7 @@ class P2P:
         if self._child is not None and self._child.poll() is None:
             self._child.terminate()
             self._child.wait()
-            logger.debug(f"Terminated p2pd with id = {self.id}")
+            logger.debug(f"Terminated p2pd with id = {self.peer_id}")
 
             with suppress(FileNotFoundError):
                 os.remove(self._daemon_listen_maddr["unix"])

+ 6 - 14
hivemind/proto/averaging.proto

@@ -2,13 +2,6 @@ syntax = "proto3";
 import "runtime.proto";
 
 
-// Runs alongside each trainer to perform gating function averaging every now and then. Read more: client/averaging.py
-service DecentralizedAveraging {
-  rpc rpc_join_group(JoinRequest) returns (stream MessageFromLeader);  // assemble a group for allreduce
-  rpc rpc_aggregate_part(stream AveragingData) returns (stream AveragingData);  // send local part => get average part
-  rpc rpc_download_state(DownloadRequest) returns (stream DownloadData);
-}
-
 enum MessageCode {
   NO_CODE = 0;               // Default value that should not be used explicitly
   REQUEST_JOIN = 1;          // "Dear maybe leader, will you have me in your group as a follower?"
@@ -21,7 +14,7 @@ enum MessageCode {
   BAD_EXPIRATION_TIME = 8;   // "I will not accept you. I cannot guarantee that we begin before you expire."
   BAD_SCHEMA_HASH = 9;       // "I will not accept you. I am not averaging the same type of tensors as you."
   BAD_GROUP_ID = 10;         // "I will not accept your request, your group id does not match with any groups i'm in."
-  DUPLICATE_ENDPOINT = 11;   // "I will not accept you, i already have exactly the same endpoint in my current group."
+  DUPLICATE_PEER_ID = 11;    // "I will not accept you, i already have exactly the same peer id in my current group."
   GROUP_IS_FULL = 12;        // "I will not accept you, my group already contains too many peers."
   NOT_LOOKING_FOR_GROUP = 13;// "I'm not available at the moment. Please, get lost."
   PROTOCOL_VIOLATION = 14;   // "You did something so unspeakable that i don't have a special code for that."
@@ -32,7 +25,6 @@ enum MessageCode {
 }
 
 message JoinRequest {
-  string endpoint = 1;          // A follower accepts incoming allreduce requests at this address
   bytes schema_hash = 2;        // A hash that describes follower's tensors (shapes, num tensors, etc)
   double expiration = 3;        // Follower would like to **begin** all_reduce by this point in time
   bytes gather = 4;             // optional metadata that is gathered from all peers (e.g. batch size or current loss)
@@ -42,16 +34,16 @@ message JoinRequest {
 
 message MessageFromLeader {
   MessageCode code = 1;
-  bytes group_id = 2;        // a unique identifier of this group, only valid until allreduce is finished/failed
-  string suggested_leader = 3;  // if peer is already in a group, it'll provide us with an endpoint of its leader
-  repeated string ordered_group_endpoints = 4;  // a sequence of peers, each responsible for one shard during averaging
-  repeated bytes gathered = 5;  // metadata (gather) from all groupmates in the same order as their endpoints
+  bytes group_id = 2;           // a unique identifier of this group, only valid until allreduce is finished/failed
+  bytes suggested_leader = 3;   // if peer is already in a group, it'll provide us with a peer id of its leader
+  repeated bytes ordered_peer_ids = 4;  // a sequence of peers, each responsible for one shard during averaging
+  repeated bytes gathered = 5;  // metadata (gather) from all groupmates in the same order as their peer ids
 }
 
 message AveragingData {
   MessageCode code = 1;     // in case of a protocol violation, this will be the error message
   bytes group_id = 2;       // a unique group identifier, same as in MessageFromLeader
-  string endpoint = 3;      // sender's rpc endpoint, used for coordination
+  bytes peer_id = 3;        // sender's rpc peer_id, used for coordination
   Tensor tensor_part = 4;   // either peer's local tensor part (rpc input) or group average of this part (rpc output)
   bytes metadata = 5;       // reserved user-extendable metadata
 }

+ 2 - 2
hivemind/proto/dht.proto

@@ -65,8 +65,8 @@ message FindResult {
   double expiration_time = 3;          // n/a  | expiration time  | DictionaryDHTValue.latest_expiration_time
 
   // two aligned arrays: DHTIDs and PeerIDs for nearest peers (sorted by XOR distance)
-  repeated bytes nearest_node_ids = 4;      // DHTIDs of the nearest peers serialized with node_id.to_bytes()
-  repeated string nearest_peer_ids = 5;     // Base58-serialized libp2p PeerIDs of the nearest peers
+  repeated bytes nearest_node_ids = 4;  // DHTIDs of the nearest peers serialized with node_id.to_bytes()
+  repeated bytes nearest_peer_ids = 5;  // libp2p PeerIDs of the nearest peers
 }
 
 message FindResponse {

+ 3 - 3
tests/test_allreduce.py

@@ -178,7 +178,7 @@ async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions,
     visible_maddrs = await p2ps[0].get_visible_maddrs()
     p2ps += await asyncio.gather(*[P2P.create(initial_peers=visible_maddrs) for _ in range(3)])
 
-    peers = [instance.id for instance in p2ps]
+    peers = [instance.peer_id for instance in p2ps]
     tensors_by_peer = {
         peer: [torch.randn(3, 128), torch.rand(32), torch.tensor(i, dtype=torch.float32)]
         for i, peer in enumerate(peers)
@@ -193,8 +193,8 @@ async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions,
             servicer_type=AllReduceRunner,
             prefix=None,
             group_id=group_id,
-            tensors=[x.clone() for x in tensors_by_peer[p2p.id]],
-            ordered_group_endpoints=peers,
+            tensors=[x.clone() for x in tensors_by_peer[p2p.peer_id]],
+            ordered_peer_ids=peers,
             peer_fractions=peer_fractions,
             modes=peer_modes,
             weights=averaging_weights,

+ 4 - 4
tests/test_averaging.py

@@ -95,7 +95,7 @@ def _test_allreduce_once(n_clients, n_aux):
     for future in futures:
         result = future.result()
         for averager in averagers:
-            assert averager.endpoint in result
+            assert averager.peer_id in result
 
     for averager in averagers:
         if averager.mode != AveragingMode.AUX:
@@ -291,13 +291,13 @@ def test_allgather(n_averagers=8, target_group_size=4):
         futures.append(averager.step(wait=False, gather=dict(batch_size=123 + i, foo="bar")))
 
     reference_metadata = {
-        averager.endpoint: dict(batch_size=123 + i, foo="bar") for i, averager in enumerate(averagers)
+        averager.peer_id: dict(batch_size=123 + i, foo="bar") for i, averager in enumerate(averagers)
     }
     for future in futures:
         gathered = future.result()
         assert len(gathered) == target_group_size
-        for endpoint in gathered:
-            assert gathered[endpoint] == reference_metadata[endpoint]
+        for peer_id in gathered:
+            assert gathered[peer_id] == reference_metadata[peer_id]
 
     for process in averagers + dht_instances:
         process.shutdown()

+ 1 - 1
tests/test_dht.py

@@ -103,5 +103,5 @@ async def test_dht_get_visible_maddrs():
     p2p = await hivemind.p2p.P2P.create(announce_maddrs=[dummy_endpoint])
     dht = hivemind.DHT(start=True, p2p=await p2p.replicate(p2p.daemon_listen_maddr))
 
-    assert dht.get_visible_maddrs() == [dummy_endpoint.encapsulate(f"/p2p/{p2p.id}")]
+    assert dht.get_visible_maddrs() == [dummy_endpoint.encapsulate(f"/p2p/{p2p.peer_id}")]
     dht.shutdown()

+ 1 - 1
tests/test_dht_node.py

@@ -44,7 +44,7 @@ def run_protocol_listener(
     for peer_id in maddrs_to_peer_ids(initial_peers):
         loop.run_until_complete(protocol.call_ping(peer_id))
 
-    maddr_conn.send((p2p.id, visible_maddrs))
+    maddr_conn.send((p2p.peer_id, visible_maddrs))
 
     async def shutdown():
         await p2p.shutdown()

+ 12 - 12
tests/test_p2p_daemon.py

@@ -92,7 +92,7 @@ async def test_call_protobuf_handler(should_cancel, replicate, handle_name="hand
         except asyncio.CancelledError:
             nonlocal handler_cancelled
             handler_cancelled = True
-        return dht_pb2.PingResponse(peer=dht_pb2.NodeInfo(node_id=server.id.to_bytes()), available=True)
+        return dht_pb2.PingResponse(peer=dht_pb2.NodeInfo(node_id=server.peer_id.to_bytes()), available=True)
 
     server_pid = server_primary._child.pid
     await server.add_protobuf_handler(handle_name, ping_handler, dht_pb2.PingRequest)
@@ -104,12 +104,12 @@ async def test_call_protobuf_handler(should_cancel, replicate, handle_name="hand
     assert is_process_running(client_pid)
     await client.wait_for_at_least_n_peers(1)
 
-    ping_request = dht_pb2.PingRequest(peer=dht_pb2.NodeInfo(node_id=client.id.to_bytes()), validate=True)
-    expected_response = dht_pb2.PingResponse(peer=dht_pb2.NodeInfo(node_id=server.id.to_bytes()), available=True)
+    ping_request = dht_pb2.PingRequest(peer=dht_pb2.NodeInfo(node_id=client.peer_id.to_bytes()), validate=True)
+    expected_response = dht_pb2.PingResponse(peer=dht_pb2.NodeInfo(node_id=server.peer_id.to_bytes()), available=True)
 
     if should_cancel:
         call_task = asyncio.create_task(
-            client.call_protobuf_handler(server.id, handle_name, ping_request, dht_pb2.PingResponse)
+            client.call_protobuf_handler(server.peer_id, handle_name, ping_request, dht_pb2.PingResponse)
         )
         await asyncio.sleep(0.25)
 
@@ -119,7 +119,7 @@ async def test_call_protobuf_handler(should_cancel, replicate, handle_name="hand
         assert handler_cancelled
     else:
         actual_response = await client.call_protobuf_handler(
-            server.id, handle_name, ping_request, dht_pb2.PingResponse
+            server.peer_id, handle_name, ping_request, dht_pb2.PingResponse
         )
         assert actual_response == expected_response
         assert not handler_cancelled
@@ -147,10 +147,10 @@ async def test_call_protobuf_handler_error(handle_name="handle"):
     assert is_process_running(client_pid)
     await client.wait_for_at_least_n_peers(1)
 
-    ping_request = dht_pb2.PingRequest(peer=dht_pb2.NodeInfo(node_id=client.id.to_bytes()), validate=True)
+    ping_request = dht_pb2.PingRequest(peer=dht_pb2.NodeInfo(node_id=client.peer_id.to_bytes()), validate=True)
 
     with pytest.raises(P2PHandlerError) as excinfo:
-        await client.call_protobuf_handler(server.id, handle_name, ping_request, dht_pb2.PingResponse)
+        await client.call_protobuf_handler(server.peer_id, handle_name, ping_request, dht_pb2.PingResponse)
     assert "boom" in str(excinfo.value)
 
     await server.shutdown()
@@ -196,7 +196,7 @@ async def test_call_peer_single_process():
 
     await client.wait_for_at_least_n_peers(1)
 
-    _, reader, writer = await client.call_binary_stream_handler(server.id, handler_name)
+    _, reader, writer = await client.call_binary_stream_handler(server.peer_id, handler_name)
     await validate_square_stream(reader, writer)
 
     await server.shutdown()
@@ -213,7 +213,7 @@ async def run_server(handler_name, server_side, response_received):
 
     await server.add_binary_stream_handler(handler_name, handle_square_stream)
 
-    server_side.send(server.id)
+    server_side.send(server.peer_id)
     server_side.send(await server.get_visible_maddrs())
     while response_received.value == 0:
         await asyncio.sleep(0.5)
@@ -281,7 +281,7 @@ async def test_error_closes_connection():
 
     await client.wait_for_at_least_n_peers(1)
 
-    _, reader, writer = await client.call_binary_stream_handler(server.id, handler_name)
+    _, reader, writer = await client.call_binary_stream_handler(server.peer_id, handler_name)
     with closing(writer):
         await P2P.send_raw_data(b"raise_error", writer)
         with pytest.raises(asyncio.IncompleteReadError):  # Means that the connection is closed
@@ -290,7 +290,7 @@ async def test_error_closes_connection():
     # Although the handler raised an exception, the server did not crash and is ready for the next requests
     assert is_process_running(server_pid)
 
-    _, reader, writer = await client.call_binary_stream_handler(server.id, handler_name)
+    _, reader, writer = await client.call_binary_stream_handler(server.peer_id, handler_name)
     with closing(writer):
         await P2P.send_raw_data(b"behave_normally", writer)
         assert await P2P.receive_raw_data(reader) == b"okay"
@@ -309,7 +309,7 @@ async def test_handlers_on_different_replicas():
             await P2P.send_raw_data(key, writer)
 
     server_primary = await P2P.create()
-    server_id = server_primary.id
+    server_id = server_primary.peer_id
     await server_primary.add_binary_stream_handler("handle_primary", partial(handler, key=b"primary"))
 
     server_replica1 = await replicate_if_needed(server_primary, True)
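
Condensing the calling convention from the tests above: handlers are now addressed by `P2P.peer_id` (formerly `P2P.id`). A sketch only, assuming `client` and `server` are connected `P2P` instances and the server has registered a protobuf handler named "handle" as in these tests:

    from hivemind.p2p import P2P
    from hivemind.proto import dht_pb2

    async def ping(client: P2P, server: P2P) -> bool:
        # build the request with the local peer id serialized as bytes
        request = dht_pb2.PingRequest(
            peer=dht_pb2.NodeInfo(node_id=client.peer_id.to_bytes()), validate=True
        )
        # address the remote handler by the server's PeerID
        response = await client.call_protobuf_handler(
            server.peer_id, "handle", request, dht_pb2.PingResponse
        )
        return response.available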

+ 6 - 6
tests/test_p2p_servicer.py

@@ -25,7 +25,7 @@ async def test_unary_unary(server_client):
     server, client = server_client
     servicer = ExampleServicer()
     await servicer.add_p2p_handlers(server)
-    stub = ExampleServicer.get_stub(client, server.id)
+    stub = ExampleServicer.get_stub(client, server.peer_id)
 
     assert await stub.rpc_square(test_pb2.TestRequest(number=10)) == test_pb2.TestResponse(number=100)
 
@@ -44,7 +44,7 @@ async def test_stream_unary(server_client):
     server, client = server_client
     servicer = ExampleServicer()
     await servicer.add_p2p_handlers(server)
-    stub = ExampleServicer.get_stub(client, server.id)
+    stub = ExampleServicer.get_stub(client, server.peer_id)
 
     async def generate_requests() -> AsyncIterator[test_pb2.TestRequest]:
         for i in range(10):
@@ -65,7 +65,7 @@ async def test_unary_stream(server_client):
     server, client = server_client
     servicer = ExampleServicer()
     await servicer.add_p2p_handlers(server)
-    stub = ExampleServicer.get_stub(client, server.id)
+    stub = ExampleServicer.get_stub(client, server.peer_id)
 
     i = 0
     async for item in stub.rpc_count(test_pb2.TestRequest(number=10)):
@@ -87,7 +87,7 @@ async def test_stream_stream(server_client):
     server, client = server_client
     servicer = ExampleServicer()
     await servicer.add_p2p_handlers(server)
-    stub = ExampleServicer.get_stub(client, server.id)
+    stub = ExampleServicer.get_stub(client, server.peer_id)
 
     async def generate_requests() -> AsyncIterator[test_pb2.TestRequest]:
         for i in range(10):
@@ -128,7 +128,7 @@ async def test_unary_stream_cancel(server_client, cancel_reason):
     await servicer.add_p2p_handlers(server)
 
     if cancel_reason == "close_connection":
-        _, reader, writer = await client.call_binary_stream_handler(server.id, "ExampleServicer.rpc_wait")
+        _, reader, writer = await client.call_binary_stream_handler(server.peer_id, "ExampleServicer.rpc_wait")
         await P2P.send_protobuf(test_pb2.TestRequest(number=10), writer)
         await P2P.send_protobuf(P2P.END_OF_STREAM, writer)
 
@@ -138,7 +138,7 @@ async def test_unary_stream_cancel(server_client, cancel_reason):
 
         writer.close()
     elif cancel_reason == "close_generator":
-        stub = ExampleServicer.get_stub(client, server.id)
+        stub = ExampleServicer.get_stub(client, server.peer_id)
         iter = stub.rpc_wait(test_pb2.TestRequest(number=10)).__aiter__()
 
         assert await iter.__anext__() == test_pb2.TestResponse(number=11)

+ 4 - 4
tests/test_utils/dht_swarms.py

@@ -52,15 +52,15 @@ def launch_swarm_in_separate_processes(
         proc.start()
         processes.append(proc)
 
-        node_id, peer_endpoint, peer_maddrs = info_queue.get()
-        dht[peer_endpoint] = node_id
+        node_id, peer_id, peer_maddrs = info_queue.get()
+        dht[peer_id] = node_id
         swarm_maddrs.append(peer_maddrs)
 
     def collect_info():
         while True:
-            node_id, peer_endpoint, peer_maddrs = info_queue.get()
+            node_id, peer_id, peer_maddrs = info_queue.get()
             with info_lock:
-                dht[peer_endpoint] = node_id
+                dht[peer_id] = node_id
                 swarm_maddrs.append(peer_maddrs)
 
                 if len(dht) == n_peers: