Refactor naming and serialization for PeerIDs (#339)

This PR follows up on #323 and performs the remaining mass refactors:

1. Rename `Endpoint` to `PeerID` in averager (+ related variable names)
2. Rename the `P2P.id` field to `P2P.peer_id` (because the local peer ID is stored in the `.peer_id` fields in all other classes)
3. Serialize `PeerID`s as `bytes` instead of Base58 strings (see the sketch after this list)
4. Remove `JoinRequest.peer_id` and `AveragingData.peer_id` fields (they duplicate `context.remote_id`)
5. Remove the `DecentralizedAveraging` gRPC interface (not used anymore)
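
For readers skimming the diff, here is a minimal sketch of what change 3 looks like from calling code. It uses only the `PeerID` operations that appear in the diff (`to_base58`, `from_base58`, `to_bytes`, and the bytes constructor); the literal byte string and variable names are placeholders, not values from the PR:

```python
from hivemind.p2p import PeerID

# In real code the local ID comes from an existing object, e.g. `p2p.peer_id` or `dht.peer_id`;
# an arbitrary byte string is wrapped here only to demonstrate the round-trips.
peer_id = PeerID(b"example peer id bytes")

# Before this PR: peer IDs crossed protobuf and DHT boundaries as Base58 strings.
as_base58 = peer_id.to_base58()                 # str
assert PeerID.from_base58(as_base58) == peer_id

# After this PR: protobuf fields and DHT subkeys carry the raw bytes instead,
# and the receiving side reconstructs the PeerID directly from those bytes.
as_bytes = peer_id.to_bytes()                   # bytes
assert PeerID(as_bytes) == peer_id
```

The same reasoning explains change 4: RPC handlers already receive the caller's identity as `context.remote_id`, so carrying an extra `peer_id`/`endpoint` field inside `JoinRequest` and `AveragingData` was redundant.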
Alexander Borzunov committed 4 years ago · commit 0774937a93

benchmarks/benchmark_averaging.py  +1 -1

@@ -70,7 +70,7 @@ def benchmark_averaging(
         processes.update({dht, averager})

         logger.info(
-            f"Averager {index}: started on endpoint {averager.endpoint}, group_bits: {averager.get_group_bits()}"
+            f"Averager {index}: started with peer id {averager.peer_id}, group_bits: {averager.get_group_bits()}"
         )
         for step in range(num_rounds):
             try:

hivemind/averaging/allreduce.py  +41 -42

@@ -5,7 +5,7 @@ from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Type
 import torch

 from hivemind.averaging.partition import TensorPartContainer, TensorPartReducer, AllreduceException
-from hivemind.p2p import P2P, P2PContext, PeerID as Endpoint, ServicerBase, StubBase
+from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase, StubBase
 from hivemind.utils import get_logger
 from hivemind.utils.asyncio import anext, achain, aiter, aenumerate, amap_in_executor, asingle
 from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
@@ -38,11 +38,11 @@ class AllReduceRunner(ServicerBase):
     :param group_id: unique identifier of this specific all-reduce run
     :param tensors: local tensors that should be averaged with groupmates
     :param tensors: local tensors that should be averaged with groupmates
-    :param endpoint: your endpoint, must be included in ordered_group_endpoints
-    :param ordered_group_endpoints: group endpoints ordered s.t. i-th endpoint is responsible for averaging i-th part
+    :param peer_id: your peer_id, must be included in ordered_peer_ids
+    :param ordered_peer_ids: group peer_ids ordered s.t. i-th peer_id is responsible for averaging i-th part
     :param peer_fractions: for each peer, a target fraction of vector elements that this peer should average
       (the actual number of values by peer will be nearly proportional, but there are no exact guarantees)
-    :param modes: AveragingMode for each peer in ordered_group_endpoints (normal, client-only or auxiliary)
+    :param modes: AveragingMode for each peer in ordered_peer_ids (normal, client-only or auxiliary)
     :param weights: scaling coefficients for weighted averaging (default = equal weights for all non-aux peers)
     :param gathered: additional user-defined data collected from this group
     :param kwargs: additional paramters (e.g. part_size_bytes) will be passed to TensorPartContainer
@@ -56,16 +56,16 @@ class AllReduceRunner(ServicerBase):
         prefix: Optional[str],
         group_id: GroupID,
         tensors: Sequence[torch.Tensor],
-        ordered_group_endpoints: Sequence[Endpoint],
+        ordered_peer_ids: Sequence[PeerID],
         peer_fractions: Tuple[float, ...],
         weights: Optional[Sequence[float]] = None,
         modes: Optional[Sequence[AveragingMode]] = None,
-        gathered: Optional[Dict[Endpoint, Any]] = None,
+        gathered: Optional[Dict[PeerID, Any]] = None,
         **kwargs,
     ):
         self._p2p = p2p
-        self.endpoint = p2p.id
-        assert self.endpoint in ordered_group_endpoints, "endpoint is not a part of the group"
+        self.peer_id = p2p.peer_id
+        assert self.peer_id in ordered_peer_ids, "peer_id is not a part of the group"

         if not issubclass(servicer_type, ServicerBase):
             raise TypeError("`servicer_type` is expected to be a ServicerBase subclass")
@@ -74,60 +74,60 @@ class AllReduceRunner(ServicerBase):

         modes = modes or tuple(AveragingMode.CLIENT if frac == 0 else AveragingMode.NODE for frac in peer_fractions)
         weights = weights or tuple(int(mode != AveragingMode.AUX) for mode in modes)
-        assert len(weights) == len(modes) == len(ordered_group_endpoints), "lists have inconsistent length"
+        assert len(weights) == len(modes) == len(ordered_peer_ids), "lists have inconsistent length"
         assert any(mode != AveragingMode.CLIENT for mode in modes), "cannot run allreduce without reducers"
         for mode, frac, weight in zip(modes, peer_fractions, weights):
             assert mode != AveragingMode.CLIENT or frac == 0, "client-mode peer should have zero all-reduce fraction"
             assert mode != AveragingMode.AUX or weight == 0, "auxiliary peer should have zero averaging weight"

-        self.group_id, self.ordered_group_endpoints = group_id, ordered_group_endpoints
+        self.group_id, self.ordered_peer_ids = group_id, ordered_peer_ids
         self.modes, self.peer_fractions, self.gathered = modes, peer_fractions, gathered

         self._future = asyncio.Future()

-        self.sender_endpoints, self.sender_weights = [], []
-        for endpoint, weight, mode in zip(self.ordered_group_endpoints, weights, modes):
+        self.sender_peer_ids, self.sender_weights = [], []
+        for peer_id, weight, mode in zip(self.ordered_peer_ids, weights, modes):
             if mode != AveragingMode.AUX:
-                self.sender_endpoints.append(endpoint)
+                self.sender_peer_ids.append(peer_id)
                 self.sender_weights.append(weight)

-        endpoint_index = self.ordered_group_endpoints.index(self.endpoint)
+        peer_id_index = self.ordered_peer_ids.index(self.peer_id)
         self.tensor_part_container = TensorPartContainer(tensors, peer_fractions, **kwargs)
-        self.parts_for_local_averaging = self.tensor_part_container.get_raw_input_parts(endpoint_index)
+        self.parts_for_local_averaging = self.tensor_part_container.get_raw_input_parts(peer_id_index)
         self.tensor_part_reducer = TensorPartReducer(
             tuple(part.shape for part in self.parts_for_local_averaging),
-            len(self.sender_endpoints),
+            len(self.sender_peer_ids),
             self.sender_weights,
         )

     def __repr__(self):
-        return f"{self.__class__.__name__}({self.endpoint}, group_size={self.group_size})"
+        return f"{self.__class__.__name__}({self.peer_id}, group_size={self.group_size})"

     def __aiter__(self):
         return self.run()

-    def __contains__(self, endpoint: Endpoint):
-        return endpoint in self.ordered_group_endpoints
+    def __contains__(self, peer_id: PeerID):
+        return peer_id in self.ordered_peer_ids

     @property
     def group_size(self):
-        return len(self.ordered_group_endpoints)
+        return len(self.ordered_peer_ids)

-    def _get_peer_stub(self, peer: Endpoint) -> StubBase:
+    def _get_peer_stub(self, peer: PeerID) -> StubBase:
         return self._servicer_type.get_stub(self._p2p, peer, namespace=self._prefix)

     async def run(self) -> AsyncIterator[torch.Tensor]:
         """Run all-reduce, return differences between averaged and original tensors as they are computed"""
         pending_tasks = set()
         try:
-            if len(self.sender_endpoints) == 0:
+            if len(self.sender_peer_ids) == 0:
                 logger.debug(f"{self} - finished all-reduce early: all peers are auxiliaries ({self.modes})")
                 self.finalize()

-            elif self.endpoint in self.sender_endpoints:
-                for endpoint, parts in zip(self.ordered_group_endpoints, self.tensor_part_container.num_parts_by_peer):
+            elif self.peer_id in self.sender_peer_ids:
+                for peer_id, parts in zip(self.ordered_peer_ids, self.tensor_part_container.num_parts_by_peer):
                     if parts != 0:
-                        pending_tasks.add(asyncio.create_task(self._communicate_with_peer(endpoint)))
+                        pending_tasks.add(asyncio.create_task(self._communicate_with_peer(peer_id)))

                 async for averaged_tensor_delta in self.tensor_part_container.iterate_output_tensors():
                     yield averaged_tensor_delta  # delta = averaged_tensor - original_tensor
@@ -143,11 +143,11 @@ class AllReduceRunner(ServicerBase):
                 task.cancel()
             raise

-    async def _communicate_with_peer(self, peer_endpoint: Endpoint):
+    async def _communicate_with_peer(self, peer_id: PeerID):
         """Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors"""
-        peer_index = self.ordered_group_endpoints.index(peer_endpoint)
-        if peer_endpoint == self.endpoint:
-            sender_index = self.sender_endpoints.index(peer_endpoint)
+        peer_index = self.ordered_peer_ids.index(peer_id)
+        if peer_id == self.peer_id:
+            sender_index = self.sender_peer_ids.index(peer_id)
             for part_index, tensor_part in enumerate(self.parts_for_local_averaging):
                 averaged_part = await self.tensor_part_reducer.accumulate_part(sender_index, part_index, tensor_part)
                 self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part - tensor_part)
@@ -155,7 +155,7 @@ class AllReduceRunner(ServicerBase):
         else:
             loop = asyncio.get_event_loop()
             code = None
-            stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part(self._generate_input_for_peer(peer_index))
+            stream = self._get_peer_stub(peer_id).rpc_aggregate_part(self._generate_input_for_peer(peer_index))
             async for part_index, msg in aenumerate(stream):
                 if code is None:
                     code = msg.code
@@ -164,7 +164,7 @@ class AllReduceRunner(ServicerBase):

             if code != averaging_pb2.AVERAGED_PART:
                 raise AllreduceException(
-                    f"peer {peer_endpoint} returned {averaging_pb2.MessageCode.Name(code)} "
+                    f"peer {peer_id} returned {averaging_pb2.MessageCode.Name(code)} "
                     f"instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)}"
                     f", allreduce failed"
                 )
@@ -175,14 +175,13 @@ class AllReduceRunner(ServicerBase):
         yield averaging_pb2.AveragingData(
             code=averaging_pb2.PART_FOR_AVERAGING,
             group_id=self.group_id,
-            endpoint=self.endpoint.to_base58(),
             tensor_part=first_part,
         )
         async for part in parts_aiter:
             yield averaging_pb2.AveragingData(tensor_part=part)

     async def rpc_aggregate_part(
-        self, stream: AsyncIterator[averaging_pb2.AveragingData], _context: P2PContext
+        self, stream: AsyncIterator[averaging_pb2.AveragingData], context: P2PContext
     ) -> AsyncIterator[averaging_pb2.AveragingData]:
         """a peer sends us a part of his tensor; we should average it with other peers and return the difference"""
         request: averaging_pb2.AveragingData = await anext(stream)
@@ -193,7 +192,7 @@ class AllReduceRunner(ServicerBase):

         elif request.code == averaging_pb2.PART_FOR_AVERAGING:
             try:
-                sender_index = self.sender_endpoints.index(Endpoint.from_base58(request.endpoint))
+                sender_index = self.sender_peer_ids.index(context.remote_id)
                 async for msg in self._accumulate_parts_streaming(achain(aiter(request), stream), sender_index):
                     yield msg

@@ -202,8 +201,8 @@ class AllReduceRunner(ServicerBase):
                 yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
         else:
             error_code = averaging_pb2.MessageCode.Name(request.code)
-            logger.debug(f"{self} - peer {request.endpoint} sent {error_code}, allreduce cannot continue")
-            self.finalize(exception=AllreduceException(f"peer {request.endpoint} sent {error_code}."))
+            logger.debug(f"{self} - peer {context.remote_id} sent {error_code}, allreduce cannot continue")
+            self.finalize(exception=AllreduceException(f"peer {context.remote_id} sent {error_code}."))
             yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)

     def _check_reasons_to_reject(self, request: averaging_pb2.AveragingData) -> Optional[averaging_pb2.AveragingData]:
@@ -230,10 +229,10 @@ class AllReduceRunner(ServicerBase):
             )
             yield averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=serialized_delta)

-    async def _send_error_to_peer(self, peer_endpoint: Endpoint, code: averaging_pb2.MessageCode):
-        error = averaging_pb2.AveragingData(group_id=self.group_id, endpoint=self.endpoint.to_base58(), code=code)
+    async def _send_error_to_peer(self, peer_id: PeerID, code: averaging_pb2.MessageCode):
+        error = averaging_pb2.AveragingData(group_id=self.group_id, code=code)
         # In case of reporting the error, we expect the response stream to contain exactly one item
-        await asingle(self._get_peer_stub(peer_endpoint).rpc_aggregate_part(aiter(error)))
+        await asingle(self._get_peer_stub(peer_id).rpc_aggregate_part(aiter(error)))

     def finalize(self, *, cancel: bool = False, exception: Optional[BaseException] = None):
         """finish or terminate AllReduceRunner, propagate any errors / cancellations to peers."""
@@ -246,9 +245,9 @@ class AllReduceRunner(ServicerBase):
             else:
                 code = averaging_pb2.INTERNAL_ERROR
             logger.debug(f"{self} - notifying peers about {averaging_pb2.MessageCode.Name(code)}")
-            for peer_endpoint, mode in zip(self.ordered_group_endpoints, self.modes):
-                if peer_endpoint != self.endpoint and mode != AveragingMode.CLIENT:
-                    pending_tasks.add(asyncio.create_task(self._send_error_to_peer(peer_endpoint, code)))
+            for peer_id, mode in zip(self.ordered_peer_ids, self.modes):
+                if peer_id != self.peer_id and mode != AveragingMode.CLIENT:
+                    pending_tasks.add(asyncio.create_task(self._send_error_to_peer(peer_id, code)))

         if not self._future.done():
             if cancel:

hivemind/averaging/averager.py  +10 -10

@@ -22,7 +22,7 @@ from hivemind.averaging.load_balancing import load_balance_peers
 from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
 from hivemind.averaging.partition import DEFAULT_PART_SIZE_BYTES
 from hivemind.dht import DHT, DHTID
-from hivemind.p2p import P2PContext, P2PHandlerError, PeerID as Endpoint, ServicerBase
+from hivemind.p2p import P2PContext, P2PHandlerError, PeerID, ServicerBase
 from hivemind.proto import averaging_pb2, runtime_pb2
 from hivemind.utils import MPFuture, get_logger, TensorDescriptor
 from hivemind.utils.asyncio import anext, achain, aiter, switch_to_uvloop
@@ -194,7 +194,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             self._allow_state_sharing.value = value

     @property
-    def endpoint(self) -> Endpoint:
+    def peer_id(self) -> PeerID:
         return self.dht.peer_id

     def run(self):
@@ -281,7 +281,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         timeout: Optional[float] = None,
         allow_retries: bool = True,
         wait: bool = True,
-    ) -> Union[Optional[Dict[Endpoint, GatheredData]], MPFuture]:
+    ) -> Union[Optional[Dict[PeerID, GatheredData]], MPFuture]:
         """
         Set up the averager to look for a group and run one round of averaging, return True on success, False on failure

@@ -375,7 +375,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
         try:
             weights, bandwidths, mode_ids, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
-            user_gathered = dict(zip(group_info.endpoints, map(self.serializer.loads, user_gathered)))
+            user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered)))
             modes = tuple(map(AveragingMode, mode_ids))

             # compute optimal part sizes from peer bandwidths; TODO: replace with proper load balancing
@@ -393,7 +393,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     prefix=self.prefix,
                     group_id=group_info.group_id,
                     tensors=local_tensors,
-                    ordered_group_endpoints=group_info.endpoints,
+                    ordered_peer_ids=group_info.peer_ids,
                     peer_fractions=peer_fractions,
                     weights=weights,
                     gathered=user_gathered,
@@ -406,7 +406,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     # actually run all-reduce
                     averaging_outputs = [output async for output in allreduce]

-                    if modes[group_info.endpoints.index(self.endpoint)] != AveragingMode.AUX:
+                    if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
                         assert len(local_tensors) == len(self._averaged_tensors)
                         for tensor, update in zip(local_tensors, averaging_outputs):
                             tensor.add_(update, alpha=self._averaging_alpha)
@@ -481,7 +481,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                    asyncio.wait_for(
                         self.dht.store(
                             download_key,
-                            subkey=self.endpoint.to_base58(),
+                            subkey=self.peer_id.to_bytes(),
                             value=self.last_updated,
                             expiration_time=get_dht_time() + self._matchmaking.averaging_expiration,
                             return_future=True,
@@ -547,8 +547,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             key_manager = self._matchmaking.group_key_manager
             peer_priority, _ = self.dht.get(f"{key_manager.prefix}.all_averagers", latest=True) or ({}, None)
             peer_priority = {
-                Endpoint.from_base58(peer): float(info.value)
-                for peer, info in peer_priority.items()
+                PeerID(peer_id): float(info.value)
+                for peer_id, info in peer_priority.items()
                 if isinstance(info, ValueWithExpiration) and isinstance(info.value, (float, int))
             }

@@ -559,7 +559,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):

             metadata = None
             for peer in sorted(peer_priority.keys(), key=peer_priority.get, reverse=True):
-                if peer != self.endpoint:
+                if peer != self.peer_id:
                     logger.info(f"Downloading parameters from peer {peer}")
                     try:
                         stub = self.get_stub(self._p2p, peer, namespace=self.prefix)

hivemind/averaging/group_info.py  +6 -6

@@ -1,7 +1,7 @@
 from dataclasses import dataclass
 from typing import Tuple

-from hivemind.utils import Endpoint
+from hivemind.p2p import PeerID


 @dataclass(frozen=True)
@@ -9,12 +9,12 @@ class GroupInfo:
     """A group of peers assembled through decentralized matchmaking"""

     group_id: bytes  # random unique bytestring that describes the current group, generated by group leader
-    endpoints: Tuple[Endpoint, ...]  # an ordered sequence of endpoints of each groupmate
-    gathered: Tuple[bytes, ...]  # binary metadata gathered from all peers by leader, same order as endpoints
+    peer_ids: Tuple[PeerID, ...]  # an ordered sequence of peer_ids of each groupmate
+    gathered: Tuple[bytes, ...]  # binary metadata gathered from all peers by leader, same order as peer_ids

     @property
     def group_size(self):
-        return len(self.endpoints)
+        return len(self.peer_ids)

-    def __contains__(self, endpoint: Endpoint):
-        return endpoint in self.endpoints
+    def __contains__(self, peer_id: PeerID):
+        return peer_id in self.peer_ids

hivemind/averaging/key_manager.py  +13 -13

@@ -7,7 +7,7 @@ import numpy as np

 from hivemind.averaging.group_info import GroupInfo
 from hivemind.dht import DHT
-from hivemind.p2p import PeerID as Endpoint
+from hivemind.p2p import PeerID
 from hivemind.utils import get_logger, DHTExpiration, get_dht_time, ValueWithExpiration

 GroupKey = str
@@ -44,7 +44,7 @@ class GroupKeyManager:
             initial_group_nbits = self.get_suggested_nbits(search_result) or 0
             initial_group_bits = "".join(random.choice("01") for _ in range(initial_group_nbits))
         self.dht, self.prefix, self.group_bits = dht, prefix, initial_group_bits
-        self.endpoint = dht.peer_id
+        self.peer_id = dht.peer_id
         self.target_group_size = target_group_size
         self.insufficient_size = insufficient_size or max(1, target_group_size // 2)
         self.excessive_size = excessive_size or target_group_size * 3
@@ -56,13 +56,13 @@ class GroupKeyManager:
         return f"{self.prefix}.0b{self.group_bits}"

     async def declare_averager(
-        self, group_key: GroupKey, endpoint: Endpoint, expiration_time: float, looking_for_group: bool = True
+        self, group_key: GroupKey, peer_id: PeerID, expiration_time: float, looking_for_group: bool = True
     ) -> bool:
         """
         Add (or remove) the averager to a given allreduce bucket

         :param group_key: allreduce group key, e.g. my_averager.0b011011101
-        :param endpoint: averager public endpoint for incoming requests
+        :param peer_id: averager public peer_id for incoming requests
         :param expiration_time: intent to run allreduce before this timestamp
         :param looking_for_group: by default (True), declare the averager as "looking for group" in a given group;
           If False, this will instead mark that the averager as no longer looking for group, (e.g. it already finished)
@@ -73,20 +73,20 @@ class GroupKeyManager:
         expiration_time = expiration_time if looking_for_group else float(np.nextafter(expiration_time, float("inf")))
         return await self.dht.store(
             key=group_key,
-            subkey=endpoint.to_base58(),
+            subkey=peer_id.to_bytes(),
             value=looking_for_group,
             expiration_time=expiration_time,
             return_future=True,
         )

-    async def get_averagers(self, group_key: GroupKey, only_active: bool) -> List[Tuple[Endpoint, DHTExpiration]]:
+    async def get_averagers(self, group_key: GroupKey, only_active: bool) -> List[Tuple[PeerID, DHTExpiration]]:
         """
         Find and return averagers that were declared with a given all-reduce key

         :param group_key: finds averagers that have the this group key, e.g. my_averager.0b011011101
         :param only_active: if True, return only active averagers that are looking for group (i.e. with value = True)
             if False, return all averagers under a given group_key regardless of value
-        :return: endpoints and expirations of every matching averager
+        :return: peer_ids and expirations of every matching averager
         """
         assert is_valid_group(group_key), f"Group key {group_key} is invalid, must follow {GROUP_PATTERN}"
         result = await self.dht.get(group_key, latest=True, return_future=True)
@@ -94,7 +94,7 @@ class GroupKeyManager:
             logger.debug(f"Allreduce group not found: {group_key}, creating new group.")
             return []
         averagers = [
-            (Endpoint.from_base58(key), looking_for_group.expiration_time)
+            (PeerID(key), looking_for_group.expiration_time)
             for key, looking_for_group in result.value.items()
             if key != self.RESERVED_KEY_FOR_NBITS and (not only_active or looking_for_group.value)
         ]
@@ -111,10 +111,10 @@ class GroupKeyManager:
             and suggested_nbits != self.suggested_nbits
         ):
             self.suggested_nbits = suggested_nbits
-            logger.warning(f"{self.endpoint} - another averager suggested {self.suggested_nbits}-bit keys")
+            logger.warning(f"{self.peer_id} - another averager suggested {self.suggested_nbits}-bit keys")
         elif num_active_averagers >= self.excessive_size:
             self.suggested_nbits = max(suggested_nbits or 0, len(self.group_bits) + 1)
-            logger.warning(f"{self.endpoint} - too many peers in bucket, switching to {self.suggested_nbits}-bit keys")
+            logger.warning(f"{self.peer_id} - too many peers in bucket, switching to {self.suggested_nbits}-bit keys")
         return averagers

     async def declare_nbits(self, group_key: GroupKey, nbits: int, expiration_time: DHTExpiration) -> bool:
@@ -141,12 +141,12 @@ class GroupKeyManager:
     async def update_key_on_group_assembled(self, group_info: GroupInfo, is_leader: bool = True):
         """this function is triggered every time an averager finds an allreduce group"""
         rng = random.Random(group_info.group_id)
-        index = group_info.endpoints.index(self.endpoint)
+        index = group_info.peer_ids.index(self.peer_id)
         generalized_index = rng.sample(range(self.target_group_size), group_info.group_size)[index]
         nbits = int(np.ceil(np.log2(self.target_group_size)))
         new_bits = bin(generalized_index)[2:].rjust(nbits, "0")
         self.group_bits = (self.group_bits + new_bits)[-len(self.group_bits) :] if self.group_bits else ""
-        logger.debug(f"{self.endpoint} - updated group key to {self.group_bits}")
+        logger.debug(f"{self.peer_id} - updated group key to {self.group_bits}")

         if is_leader and self.insufficient_size < group_info.group_size < self.excessive_size:
             asyncio.create_task(self.notify_stragglers())
@@ -161,7 +161,7 @@ class GroupKeyManager:
         new_nbits = self.suggested_nbits if self.suggested_nbits is not None else len(self.group_bits) - 1
         prev_nbits, self.group_bits = self.group_bits, self.group_bits[-new_nbits:] if new_nbits else ""
         if self.group_bits != prev_nbits:
-            logger.warning(f"{self.endpoint} - switching to {len(self.group_bits)}-bit keys")
+            logger.warning(f"{self.peer_id} - switching to {len(self.group_bits)}-bit keys")
         self.suggested_nbits = None

     async def notify_stragglers(self):

hivemind/averaging/matchmaking.py  +49 -57

@@ -12,7 +12,7 @@ from typing import AsyncIterator, Dict, Optional, Set, Tuple, Type
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.key_manager import GroupKeyManager, GroupKey
 from hivemind.dht import DHT, DHTID, DHTExpiration
-from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID as Endpoint, ServicerBase
+from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
 from hivemind.utils import get_logger, timed_storage, TimedStorage, get_dht_time
 from hivemind.utils.asyncio import anext
 from hivemind.proto import averaging_pb2
@@ -63,7 +63,7 @@ class Matchmaking:
         self._servicer_type = servicer_type
         self._prefix = prefix

-        self.endpoint = p2p.id
+        self.peer_id = p2p.peer_id
         self.schema_hash = schema_hash
         self.group_key_manager = GroupKeyManager(dht, prefix, initial_group_bits, target_group_size)
         self.target_group_size, self.min_group_size = target_group_size, min_group_size
@@ -76,9 +76,9 @@ class Matchmaking:
         self.was_accepted_to_group = asyncio.Event()
         self.assembled_group = asyncio.Future()

-        self.current_leader: Optional[Endpoint] = None  # iff i am a follower, this is a link to my current leader
-        self.current_followers: Dict[Endpoint, averaging_pb2.JoinRequest] = {}  # my current followers excluding myself
-        self.potential_leaders = PotentialLeaders(self.endpoint, averaging_expiration, target_group_size)
+        self.current_leader: Optional[PeerID] = None  # iff i am a follower, this is a link to my current leader
+        self.current_followers: Dict[PeerID, averaging_pb2.JoinRequest] = {}  # my current followers excluding myself
+        self.potential_leaders = PotentialLeaders(self.peer_id, averaging_expiration, target_group_size)
         self.data_for_gather: Optional[bytes] = None

     @property
@@ -94,7 +94,7 @@ class Matchmaking:
                 lfg_status += f" leading {len(self.current_followers)} followers,"
         schema_hash_repr = f"{self.schema_hash[0]}...{self.schema_hash[-8:]}"
         return (
-            f"{self.__class__.__name__}(endpoint={self.endpoint}, schema={schema_hash_repr}, {lfg_status}"
+            f"{self.__class__.__name__}(peer_id={self.peer_id}, schema={schema_hash_repr}, {lfg_status}"
             f" current key = {self.group_key_manager.current_key}, client_mode={self.client_mode})"
         )

@@ -167,7 +167,7 @@ class Matchmaking:
                         self.assembled_group.set_exception(e)
                     raise e

-    async def request_join_group(self, leader: Endpoint, expiration_time: DHTExpiration) -> Optional[GroupInfo]:
+    async def request_join_group(self, leader: PeerID, expiration_time: DHTExpiration) -> Optional[GroupInfo]:
         """
         :param leader: request this peer to be your leader for allreduce
         :param expiration_time: inform leader that we intend to begin averaging before this expiration_time
@@ -183,7 +183,6 @@ class Matchmaking:

                 stream = leader_stub.rpc_join_group(
                     averaging_pb2.JoinRequest(
-                        endpoint=self.endpoint.to_base58(),
                         schema_hash=self.schema_hash,
                         expiration=expiration_time,
                         client_mode=self.client_mode,
@@ -194,7 +193,7 @@ class Matchmaking:
                 message = await asyncio.wait_for(anext(stream), timeout=self.request_timeout)

                 if message.code == averaging_pb2.ACCEPTED:
-                    logger.debug(f"{self.endpoint} - joining the group of {leader}; waiting for peers")
+                    logger.debug(f"{self.peer_id} - joining the group of {leader}; waiting for peers")
                     self.current_leader = leader
                     self.was_accepted_to_group.set()
                     if len(self.current_followers) > 0:
@@ -202,7 +201,7 @@ class Matchmaking:

             if message.code != averaging_pb2.ACCEPTED:
                 code = averaging_pb2.MessageCode.Name(message.code)
-                logger.debug(f"{self.endpoint} - requested {leader} to be my leader, but got rejected with {code}")
+                logger.debug(f"{self.peer_id} - requested {leader} to be my leader, but got rejected with {code}")
                 return None

             async with self.potential_leaders.pause_search():
@@ -215,8 +214,8 @@ class Matchmaking:

             if message.code in (averaging_pb2.GROUP_DISBANDED, averaging_pb2.CANCELLED):
                 if message.suggested_leader:
-                    suggested_leader = Endpoint.from_base58(message.suggested_leader)
-                    if suggested_leader != self.endpoint:
+                    suggested_leader = PeerID(message.suggested_leader)
+                    if suggested_leader != self.peer_id:
                         logger.debug(f"{self} - leader disbanded group and redirected us to {suggested_leader}")
                         self.current_leader = None
                         await stream.aclose()
@@ -240,19 +239,17 @@ class Matchmaking:
                 await stream.aclose()

     async def rpc_join_group(
-        self, request: averaging_pb2.JoinRequest, _context: P2PContext
+        self, request: averaging_pb2.JoinRequest, context: P2PContext
     ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
         """accept or reject a join request from another averager; if accepted, run him through allreduce steps"""
-        request_endpoint = None
         try:
             async with self.lock_request_join_group:
-                reason_to_reject = self._check_reasons_to_reject(request)
+                reason_to_reject = self._check_reasons_to_reject(request, context)
                 if reason_to_reject is not None:
                     yield reason_to_reject
                     return

-                request_endpoint = Endpoint.from_base58(request.endpoint)
-                self.current_followers[request_endpoint] = request
+                self.current_followers[context.remote_id] = request
                 yield averaging_pb2.MessageFromLeader(code=averaging_pb2.ACCEPTED)

                 if len(self.current_followers) + 1 >= self.target_group_size and not self.assembled_group.done():
@@ -280,12 +277,12 @@ class Matchmaking:
                 self.was_accepted_to_group.is_set()
                 or not self.assembled_group.done()
                 or self.assembled_group.cancelled()
-                or request_endpoint not in self.assembled_group.result()
+                or context.remote_id not in self.assembled_group.result()
             ):
                 if self.current_leader is not None:
                     # outcome 3: found by a leader with higher priority, send our followers to him
                     yield averaging_pb2.MessageFromLeader(
-                        code=averaging_pb2.GROUP_DISBANDED, suggested_leader=self.current_leader.to_base58()
+                        code=averaging_pb2.GROUP_DISBANDED, suggested_leader=self.current_leader.to_bytes()
                     )
                     return
                 else:
@@ -296,7 +293,7 @@ class Matchmaking:
             yield averaging_pb2.MessageFromLeader(
                 code=averaging_pb2.BEGIN_ALLREDUCE,
                 group_id=group_info.group_id,
-                ordered_group_endpoints=[item.to_base58() for item in group_info.endpoints],
+                ordered_peer_ids=[item.to_bytes() for item in group_info.peer_ids],
                 gathered=group_info.gathered,
             )
         except (concurrent.futures.CancelledError, asyncio.CancelledError):
@@ -306,27 +303,22 @@ class Matchmaking:
             yield averaging_pb2.MessageFromLeader(code=averaging_pb2.INTERNAL_ERROR)

         finally:  # note: this code is guaranteed to run even if the coroutine is destroyed prematurely
-            self.current_followers.pop(request_endpoint, None)
+            self.current_followers.pop(context.remote_id, None)
             self.follower_was_discarded.set()

     def _check_reasons_to_reject(
-        self, request: averaging_pb2.JoinRequest
+        self, request: averaging_pb2.JoinRequest, context: P2PContext
     ) -> Optional[averaging_pb2.MessageFromLeader]:
         """:returns: if accepted, return None, otherwise return a reason for rejection"""
         if not self.is_looking_for_group or self.assembled_group.done():
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.NOT_LOOKING_FOR_GROUP)

-        try:
-            request_endpoint = Endpoint.from_base58(request.endpoint)
-        except (ValueError, TypeError):
-            request_endpoint = None
         if (
             request.ListFields() == 3
             and not isinstance(request.schema_hash, bytes)
             or len(request.schema_hash) == 0
             or not isinstance(request.expiration, DHTExpiration)
             or not isfinite(request.expiration)
-            or request_endpoint is None
             or self.client_mode
             or not isinstance(request.group_key, GroupKey)
         ):
@@ -342,10 +334,10 @@ class Matchmaking:
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.BAD_EXPIRATION_TIME)
         elif self.current_leader is not None:
             return averaging_pb2.MessageFromLeader(
-                code=averaging_pb2.NOT_A_LEADER, suggested_leader=self.current_leader.to_base58()
+                code=averaging_pb2.NOT_A_LEADER, suggested_leader=self.current_leader.to_bytes()
             )
-        elif request_endpoint == self.endpoint or request_endpoint in self.current_followers:
-            return averaging_pb2.MessageFromLeader(code=averaging_pb2.DUPLICATE_ENDPOINT)
+        elif context.remote_id == self.peer_id or context.remote_id in self.current_followers:
+            return averaging_pb2.MessageFromLeader(code=averaging_pb2.DUPLICATE_PEER_ID)
         elif len(self.current_followers) + 1 >= self.target_group_size:
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_IS_FULL)
         else:
@@ -355,35 +347,35 @@ class Matchmaking:
         """Form up all current followers into a group and gather metadata"""
         assert self.lock_looking_for_group.locked() and self.lock_request_join_group.locked() and not self.client_mode
         assert not self.assembled_group.done()
-        group_id = DHTID.generate().to_bytes()  # note: both groupd_id and the order of endpoints must be random
-        ordered_group_endpoints = list(self.current_followers)
-        ordered_group_endpoints.append(self.endpoint)
-        random.shuffle(ordered_group_endpoints)
+        group_id = DHTID.generate().to_bytes()  # note: both groupd_id and the order of peer_ids must be random
+        ordered_peer_ids = list(self.current_followers)
+        ordered_peer_ids.append(self.peer_id)
+        random.shuffle(ordered_peer_ids)

         gathered = tuple(
-            self.data_for_gather if endpoint == self.endpoint else self.current_followers[endpoint].gather
-            for endpoint in ordered_group_endpoints
+            self.data_for_gather if peer_id == self.peer_id else self.current_followers[peer_id].gather
+            for peer_id in ordered_peer_ids
         )

-        logger.debug(f"{self.endpoint} - assembled group of {len(ordered_group_endpoints)} peers.")
-        group_info = GroupInfo(group_id, tuple(ordered_group_endpoints), gathered)
+        logger.debug(f"{self.peer_id} - assembled group of {len(ordered_peer_ids)} peers.")
+        group_info = GroupInfo(group_id, tuple(ordered_peer_ids), gathered)
         await self.group_key_manager.update_key_on_group_assembled(group_info, is_leader=True)
         self.assembled_group.set_result(group_info)
         return group_info

-    async def follower_assemble_group(self, leader: Endpoint, msg: averaging_pb2.MessageFromLeader) -> GroupInfo:
+    async def follower_assemble_group(self, leader: PeerID, msg: averaging_pb2.MessageFromLeader) -> GroupInfo:
         """Form a group from using peers and metadata provided by our leader"""
         assert self.lock_looking_for_group.locked() and self.lock_request_join_group.locked()
         assert not self.assembled_group.done()
         assert self.current_leader == leader, f"averager does not follow {leader} (actual: {self.current_leader})"

         group_id = msg.group_id
-        ordered_group_endpoints = [Endpoint.from_base58(item) for item in msg.ordered_group_endpoints]
-        assert self.endpoint in ordered_group_endpoints, "Leader sent us group_endpoints that does not contain us!"
-        assert len(ordered_group_endpoints) == len(msg.gathered)
+        ordered_peer_ids = [PeerID(item) for item in msg.ordered_peer_ids]
+        assert self.peer_id in ordered_peer_ids, "Leader sent us group_peer_ids that does not contain us!"
+        assert len(ordered_peer_ids) == len(msg.gathered)

-        logger.debug(f"{self.endpoint} - follower assembled group with leader {leader}.")
-        group_info = GroupInfo(group_id, tuple(ordered_group_endpoints), tuple(msg.gathered))
+        logger.debug(f"{self.peer_id} - follower assembled group with leader {leader}.")
+        group_info = GroupInfo(group_id, tuple(ordered_peer_ids), tuple(msg.gathered))
         await self.group_key_manager.update_key_on_group_assembled(group_info)
         self.assembled_group.set_result(group_info)
         return group_info
@@ -397,13 +389,13 @@ class Matchmaking:
 class PotentialLeaders:
     """An utility class that searches for averagers that could become our leaders"""
 
 
-    def __init__(self, endpoint: Endpoint, averaging_expiration: DHTExpiration, target_group_size: Optional[int]):
-        self.endpoint, self.averaging_expiration = endpoint, averaging_expiration
+    def __init__(self, peer_id: PeerID, averaging_expiration: DHTExpiration, target_group_size: Optional[int]):
+        self.peer_id, self.averaging_expiration = peer_id, averaging_expiration
         self.target_group_size = target_group_size
         self.target_group_size = target_group_size
         self.running, self.update_triggered, self.update_finished = asyncio.Event(), asyncio.Event(), asyncio.Event()
         self.running, self.update_triggered, self.update_finished = asyncio.Event(), asyncio.Event(), asyncio.Event()
         self.declared_expiration, self.lock_search, self.lock_declare = asyncio.Event(), asyncio.Lock(), asyncio.Lock()
         self.declared_expiration, self.lock_search, self.lock_declare = asyncio.Event(), asyncio.Lock(), asyncio.Lock()
-        self.leader_queue = TimedStorage[Endpoint, DHTExpiration]()
-        self.past_attempts: Set[Tuple[Endpoint, DHTExpiration]] = set()
+        self.leader_queue = TimedStorage[PeerID, DHTExpiration]()
+        self.past_attempts: Set[Tuple[PeerID, DHTExpiration]] = set()
         self.declared_expiration_time = float("inf")
         self.declared_expiration_time = float("inf")
         self.declared_group_key: Optional[GroupKey] = None
         self.declared_group_key: Optional[GroupKey] = None
         self.max_assured_time = float("-inf")
         self.max_assured_time = float("-inf")
@@ -450,7 +442,7 @@ class PotentialLeaders:
             else:
             else:
                 self.running.clear()
                 self.running.clear()
 
 
-    async def pop_next_leader(self) -> Endpoint:
+    async def pop_next_leader(self) -> PeerID:
         """Remove and return the next most suitable leader or throw an exception if reached timeout"""
         """Remove and return the next most suitable leader or throw an exception if reached timeout"""
         assert self.running.is_set(), "Not running search at the moment"
         assert self.running.is_set(), "Not running search at the moment"
         while True:
         while True:
@@ -459,9 +451,9 @@ class PotentialLeaders:
             if maybe_next_leader is None or self.max_assured_time <= entry.expiration_time <= self.search_end_time:
             if maybe_next_leader is None or self.max_assured_time <= entry.expiration_time <= self.search_end_time:
                 self.update_triggered.set()
                 self.update_triggered.set()
 
 
-            if maybe_next_leader is None or (entry.expiration_time, maybe_next_leader.to_base58()) > (
+            if maybe_next_leader is None or (entry.expiration_time, maybe_next_leader.to_bytes()) > (
                 self.declared_expiration_time,
                 self.declared_expiration_time,
-                self.endpoint.to_base58(),
+                self.peer_id.to_bytes(),
             ):
             ):
                 await asyncio.wait(
                 await asyncio.wait(
                     {self.update_finished.wait(), self.declared_expiration.wait()}, return_when=asyncio.FIRST_COMPLETED
                     {self.update_finished.wait(), self.declared_expiration.wait()}, return_when=asyncio.FIRST_COMPLETED
@@ -496,7 +488,7 @@ class PotentialLeaders:
 
 
                 self.leader_queue.clear()
                 self.leader_queue.clear()
                 for peer, peer_expiration_time in new_peers:
                 for peer, peer_expiration_time in new_peers:
-                    if peer == self.endpoint or (peer, peer_expiration_time) in self.past_attempts:
+                    if peer == self.peer_id or (peer, peer_expiration_time) in self.past_attempts:
                         continue
                         continue
                     self.leader_queue.store(peer, peer_expiration_time, peer_expiration_time)
                     self.leader_queue.store(peer, peer_expiration_time, peer_expiration_time)
                     self.max_assured_time = max(self.max_assured_time, peer_expiration_time - DISCREPANCY)
                     self.max_assured_time = max(self.max_assured_time, peer_expiration_time - DISCREPANCY)
@@ -512,7 +504,7 @@ class PotentialLeaders:
         except (concurrent.futures.CancelledError, asyncio.CancelledError):
         except (concurrent.futures.CancelledError, asyncio.CancelledError):
             return  # note: this is a compatibility layer for python3.7
             return  # note: this is a compatibility layer for python3.7
         except Exception as e:
         except Exception as e:
-            logger.error(f"{self.endpoint} - caught {type(e)}: {e}")
+            logger.error(f"{self.peer_id} - caught {type(e)}: {e}")
             raise
             raise
 
 
     async def _declare_averager_periodically(self, key_manager: GroupKeyManager):
     async def _declare_averager_periodically(self, key_manager: GroupKeyManager):
@@ -525,21 +517,21 @@ class PotentialLeaders:
                     self.declared_group_key = group_key = key_manager.current_key
                     self.declared_group_key = group_key = key_manager.current_key
                     self.declared_expiration_time = new_expiration_time
                     self.declared_expiration_time = new_expiration_time
                     self.declared_expiration.set()
                     self.declared_expiration.set()
-                    await key_manager.declare_averager(group_key, self.endpoint, expiration_time=new_expiration_time)
+                    await key_manager.declare_averager(group_key, self.peer_id, expiration_time=new_expiration_time)
                     await asyncio.sleep(self.declared_expiration_time - get_dht_time())
                     await asyncio.sleep(self.declared_expiration_time - get_dht_time())
                     if self.running.is_set() and len(self.leader_queue) == 0:
                     if self.running.is_set() and len(self.leader_queue) == 0:
                         await key_manager.update_key_on_not_enough_peers()
                         await key_manager.update_key_on_not_enough_peers()
             except (concurrent.futures.CancelledError, asyncio.CancelledError):
             except (concurrent.futures.CancelledError, asyncio.CancelledError):
                 pass  # note: this is a compatibility layer for python3.7
                 pass  # note: this is a compatibility layer for python3.7
             except Exception as e:  # note: we catch exceptions here because otherwise they are never printed
             except Exception as e:  # note: we catch exceptions here because otherwise they are never printed
-                logger.error(f"{self.endpoint} - caught {type(e)}: {e}")
+                logger.error(f"{self.peer_id} - caught {type(e)}: {e}")
             finally:
             finally:
                 if self.declared_group_key is not None:
                 if self.declared_group_key is not None:
                     prev_declared_key, prev_expiration_time = self.declared_group_key, self.declared_expiration_time
                     prev_declared_key, prev_expiration_time = self.declared_group_key, self.declared_expiration_time
                     self.declared_group_key, self.declared_expiration_time = None, float("inf")
                     self.declared_group_key, self.declared_expiration_time = None, float("inf")
-                    self.leader_queue, self.max_assured_time = TimedStorage[Endpoint, DHTExpiration](), float("-inf")
+                    self.leader_queue, self.max_assured_time = TimedStorage[PeerID, DHTExpiration](), float("-inf")
                     await key_manager.declare_averager(
                     await key_manager.declare_averager(
-                        prev_declared_key, self.endpoint, prev_expiration_time, looking_for_group=False
+                        prev_declared_key, self.peer_id, prev_expiration_time, looking_for_group=False
                     )
                     )
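Note on the Matchmaking hunks above: the join request no longer carries a self-reported address, so the leader identifies each caller by the libp2p identity that the daemon attaches to the stream (`P2PContext.remote_id`), and followers rebuild `PeerID` objects from the raw bytes in `ordered_peer_ids`. A minimal sketch of the duplicate-follower check under that convention (the standalone helper below is illustrative, not code from this PR):

```python
from typing import Optional, Set

from hivemind.p2p import P2PContext, PeerID
from hivemind.proto import averaging_pb2


def reject_duplicate_follower(
    context: P2PContext, local_peer_id: PeerID, current_followers: Set[PeerID]
) -> Optional[averaging_pb2.MessageFromLeader]:
    """Illustrative helper: the caller's identity comes from the stream context, not from the request body."""
    if context.remote_id == local_peer_id or context.remote_id in current_followers:
        return averaging_pb2.MessageFromLeader(code=averaging_pb2.DUPLICATE_PEER_ID)
    return None  # this check has no objection; other checks may still reject the request
```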
 
 
 
 

+ 1 - 1
hivemind/dht/node.py

@@ -207,7 +207,7 @@ class DHTNode:
             record_validator,
             authorizer,
         )
-        self.peer_id = p2p.id
+        self.peer_id = p2p.peer_id

         if initial_peers:
             initial_peers = {PeerID.from_base58(Multiaddr(item)["p2p"]) for item in initial_peers}

+ 2 - 2
hivemind/dht/protocol.py

@@ -296,7 +296,7 @@ class DHTProtocol(ServicerBase):
                 nearest = dict(
                     zip(
                         map(DHTID.from_bytes, result.nearest_node_ids),
-                        map(PeerID.from_base58, result.nearest_peer_ids),
+                        map(PeerID, result.nearest_peer_ids),
                     )
                 )

@@ -359,7 +359,7 @@ class DHTProtocol(ServicerBase):
                 key_id, k=self.bucket_size, exclude=DHTID.from_bytes(request.peer.node_id)
             ):
                 item.nearest_node_ids.append(node_id.to_bytes())
-                item.nearest_peer_ids.append(peer_id.to_base58())
+                item.nearest_peer_ids.append(peer_id.to_bytes())
             response.results.append(item)
         return response
 
 

+ 2 - 2
hivemind/optim/collaborative.py

@@ -42,7 +42,7 @@ class CollaborationState:


 class TrainingState(BaseModel):
-    peer_id: str
+    peer_id: bytes
     step: conint(ge=0, strict=True)
     samples_accumulated: conint(ge=0, strict=True)
     samples_per_second: confloat(ge=0.0, strict=True)
@@ -354,7 +354,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             with self.lock_local_progress:
                 current_time = get_dht_time()
                 local_state_info = TrainingState(
-                    peer_id=self.averager.endpoint.to_base58(),
+                    peer_id=self.averager.peer_id.to_bytes(),
                     step=self.local_step,
                     samples_accumulated=self.local_samples_accumulated,
                     samples_per_second=self.performance_ema.samples_per_second,
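Since `TrainingState.peer_id` is now `bytes`, any code that reads these records back from the DHT should wrap the raw value into a `PeerID` instead of Base58-decoding a string. A one-line sketch of that consumer side (the helper name is illustrative):

```python
from hivemind.p2p import PeerID


def training_state_peer(state) -> PeerID:
    """Illustrative: `state` is a TrainingState as defined above; its peer_id field holds raw PeerID bytes."""
    return PeerID(state.peer_id)
```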

+ 8 - 8
hivemind/p2p/p2p_daemon.py

@@ -44,7 +44,7 @@ class P2P:
       - `P2P.add_binary_stream_handler` transfers raw data using bi-directional streaming interface

     To access these handlers, a P2P instance can `P2P.call_protobuf_handler`/`P2P.call_binary_stream_handler`,
-    using the recipient's unique `P2P.id` and the name of the corresponding handler.
+    using the recipient's unique `P2P.peer_id` and the name of the corresponding handler.
     """

     HEADER_LEN = 8
@@ -65,7 +65,7 @@ class P2P:
     _UNIX_SOCKET_PREFIX = "/unix/tmp/hivemind-"

     def __init__(self):
-        self.id = None
+        self.peer_id = None
         self._child = None
         self._alive = False
         self._listen_task = None
@@ -212,8 +212,8 @@ class P2P:
         return self

     async def _ping_daemon(self) -> None:
-        self.id, self._visible_maddrs = await self._client.identify()
-        logger.debug(f"Launched p2pd with id = {self.id}, host multiaddrs = {self._visible_maddrs}")
+        self.peer_id, self._visible_maddrs = await self._client.identify()
+        logger.debug(f"Launched p2pd with peer id = {self.peer_id}, host multiaddrs = {self._visible_maddrs}")

     async def get_visible_maddrs(self, latest: bool = False) -> List[Multiaddr]:
         """
@@ -226,9 +226,9 @@ class P2P:
             _, self._visible_maddrs = await self._client.identify()

         if not self._visible_maddrs:
-            raise ValueError(f"No multiaddrs found for peer {self.id}")
+            raise ValueError(f"No multiaddrs found for peer {self.peer_id}")

-        p2p_maddr = Multiaddr(f"/p2p/{self.id.to_base58()}")
+        p2p_maddr = Multiaddr(f"/p2p/{self.peer_id.to_base58()}")
         return [addr.encapsulate(p2p_maddr) for addr in self._visible_maddrs]

     async def list_peers(self) -> List[PeerInfo]:
@@ -312,7 +312,7 @@ class P2P:
         ) -> None:
             context = P2PContext(
                 handle_name=name,
-                local_id=self.id,
+                local_id=self.peer_id,
                 remote_id=stream_info.peer_id,
                 remote_maddr=stream_info.addr,
             )
@@ -475,7 +475,7 @@ class P2P:
         if self._child is not None and self._child.poll() is None:
             self._child.terminate()
             self._child.wait()
-            logger.debug(f"Terminated p2pd with id = {self.id}")
+            logger.debug(f"Terminated p2pd with id = {self.peer_id}")

             with suppress(FileNotFoundError):
                 os.remove(self._daemon_listen_maddr["unix"])
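Downstream code that used to read `P2P.id` should now read `P2P.peer_id`; the calls themselves are unchanged in this diff. A small usage sketch under that assumption:

```python
import asyncio

from hivemind.p2p import P2P


async def main():
    p2p = await P2P.create()
    print("local peer id:", p2p.peer_id.to_base58())       # formerly p2p.id
    print("visible maddrs:", await p2p.get_visible_maddrs())
    await p2p.shutdown()


asyncio.run(main())
```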

+ 6 - 14
hivemind/proto/averaging.proto

@@ -2,13 +2,6 @@ syntax = "proto3";
 import "runtime.proto";
 import "runtime.proto";
 
 
 
 
-// Runs alongside each trainer to perform gating function averaging every now and then. Read more: client/averaging.py
-service DecentralizedAveraging {
-  rpc rpc_join_group(JoinRequest) returns (stream MessageFromLeader);  // assemble a group for allreduce
-  rpc rpc_aggregate_part(stream AveragingData) returns (stream AveragingData);  // send local part => get average part
-  rpc rpc_download_state(DownloadRequest) returns (stream DownloadData);
-}
-
 enum MessageCode {
 enum MessageCode {
   NO_CODE = 0;               // Default value that should not be used explicitly
   NO_CODE = 0;               // Default value that should not be used explicitly
   REQUEST_JOIN = 1;          // "Dear maybe leader, will you have me in your group as a follower?"
   REQUEST_JOIN = 1;          // "Dear maybe leader, will you have me in your group as a follower?"
@@ -21,7 +14,7 @@ enum MessageCode {
   BAD_EXPIRATION_TIME = 8;   // "I will not accept you. I cannot guarantee that we begin before you expire."
   BAD_EXPIRATION_TIME = 8;   // "I will not accept you. I cannot guarantee that we begin before you expire."
   BAD_SCHEMA_HASH = 9;       // "I will not accept you. I am not averaging the samy type of tensors as you."
   BAD_SCHEMA_HASH = 9;       // "I will not accept you. I am not averaging the samy type of tensors as you."
   BAD_GROUP_ID = 10;         // "I will not accept your request, your group id does not match with any groups i'm in."
   BAD_GROUP_ID = 10;         // "I will not accept your request, your group id does not match with any groups i'm in."
-  DUPLICATE_ENDPOINT = 11;   // "I will not accept you, i already have exactly the same endpoint in my current group."
+  DUPLICATE_PEER_ID = 11;    // "I will not accept you, i already have exactly the same peer id in my current group."
   GROUP_IS_FULL = 12;        // "I will not accept you, my group already contains too many peers."
   GROUP_IS_FULL = 12;        // "I will not accept you, my group already contains too many peers."
   NOT_LOOKING_FOR_GROUP = 13;// "I'm not available at the moment. Please, get lost."
   NOT_LOOKING_FOR_GROUP = 13;// "I'm not available at the moment. Please, get lost."
   PROTOCOL_VIOLATION = 14;   // "You did something so unspeakable that i don't have a special code for that."
   PROTOCOL_VIOLATION = 14;   // "You did something so unspeakable that i don't have a special code for that."
@@ -32,7 +25,6 @@ enum MessageCode {
 }
 }
 
 
 message JoinRequest {
 message JoinRequest {
-  string endpoint = 1;          // A follower accepts incoming allreduce requests at this address
   bytes schema_hash = 2;        // A hash that describes follower's tensors (shapes, num tensors, etc)
   bytes schema_hash = 2;        // A hash that describes follower's tensors (shapes, num tensors, etc)
   double expiration = 3;        // Follower would like to **begin** all_reduce by this point in time
   double expiration = 3;        // Follower would like to **begin** all_reduce by this point in time
   bytes gather = 4;             // optional metadata that is gathered from all peers (e.g. batch size or current loss)
   bytes gather = 4;             // optional metadata that is gathered from all peers (e.g. batch size or current loss)
@@ -42,16 +34,16 @@ message JoinRequest {
 
 
 message MessageFromLeader {
 message MessageFromLeader {
   MessageCode code = 1;
   MessageCode code = 1;
-  bytes group_id = 2;        // a unique identifier of this group, only valid until allreduce is finished/failed
-  string suggested_leader = 3;  // if peer is already in a group, it'll provide us with an endpoint of its leader
-  repeated string ordered_group_endpoints = 4;  // a sequence of peers, each responsible for one shard during averaging
-  repeated bytes gathered = 5;  // metadata (gather) from all groupmates in the same order as their endpoints
+  bytes group_id = 2;           // a unique identifier of this group, only valid until allreduce is finished/failed
+  bytes suggested_leader = 3;   // if peer is already in a group, it'll provide us with a peer id of its leader
+  repeated bytes ordered_peer_ids = 4;  // a sequence of peers, each responsible for one shard during averaging
+  repeated bytes gathered = 5;  // metadata (gather) from all groupmates in the same order as their peer ids
 }
 }
 
 
 message AveragingData {
 message AveragingData {
   MessageCode code = 1;     // in case of a protocol violation, this will be the error message
   MessageCode code = 1;     // in case of a protocol violation, this will be the error message
   bytes group_id = 2;       // a unique group identifier, same as in MessageFromLeader
   bytes group_id = 2;       // a unique group identifier, same as in MessageFromLeader
-  string endpoint = 3;      // sender's rpc endpoint, used for coordination
+  bytes peer_id = 3;        // sender's rpc peer_id, used for coordination
   Tensor tensor_part = 4;   // either peer's local tensor part (rpc input) or group average of this part (rpc output)
   Tensor tensor_part = 4;   // either peer's local tensor part (rpc input) or group average of this part (rpc output)
   bytes metadata = 5;       // reserved user-extendable metadata
   bytes metadata = 5;       // reserved user-extendable metadata
 }
 }
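Every `PeerID` in these messages now travels as raw bytes rather than a Base58 string: writers call `PeerID.to_bytes()`, readers wrap the payload with `PeerID(raw_bytes)`. A minimal round-trip sketch of that convention (both helper functions are illustrative, not part of the codebase):

```python
from typing import List, Sequence

from hivemind.p2p import PeerID
from hivemind.proto import averaging_pb2


def pack_group(group_id: bytes, ordered_peer_ids: Sequence[PeerID]) -> averaging_pb2.MessageFromLeader:
    """Illustrative: each PeerID is serialized directly as bytes, with no Base58 step."""
    return averaging_pb2.MessageFromLeader(
        group_id=group_id,
        ordered_peer_ids=[peer_id.to_bytes() for peer_id in ordered_peer_ids],
    )


def unpack_group(msg: averaging_pb2.MessageFromLeader) -> List[PeerID]:
    """Illustrative: the inverse direction, as used by followers once the group is assembled."""
    return [PeerID(raw) for raw in msg.ordered_peer_ids]
```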

+ 2 - 2
hivemind/proto/dht.proto

@@ -65,8 +65,8 @@ message FindResult {
   double expiration_time = 3;          // n/a  | expiration time  | DictionaryDHTValue.latest_expiration_time

   // two aligned arrays: DHTIDs and PeerIDs for nearest peers (sorted by XOR distance)
-  repeated bytes nearest_node_ids = 4;      // DHTIDs of the nearest peers serialized with node_id.to_bytes()
-  repeated string nearest_peer_ids = 5;     // Base58-serialized libp2p PeerIDs of the nearest peers
+  repeated bytes nearest_node_ids = 4;  // DHTIDs of the nearest peers serialized with node_id.to_bytes()
+  repeated bytes nearest_peer_ids = 5;  // libp2p PeerIDs of the nearest peers
 }

 message FindResponse {

+ 3 - 3
tests/test_allreduce.py

@@ -178,7 +178,7 @@ async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions,
     visible_maddrs = await p2ps[0].get_visible_maddrs()
     p2ps += await asyncio.gather(*[P2P.create(initial_peers=visible_maddrs) for _ in range(3)])

-    peers = [instance.id for instance in p2ps]
+    peers = [instance.peer_id for instance in p2ps]
     tensors_by_peer = {
         peer: [torch.randn(3, 128), torch.rand(32), torch.tensor(i, dtype=torch.float32)]
         for i, peer in enumerate(peers)
@@ -193,8 +193,8 @@ async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions,
             servicer_type=AllReduceRunner,
             prefix=None,
             group_id=group_id,
-            tensors=[x.clone() for x in tensors_by_peer[p2p.id]],
-            ordered_group_endpoints=peers,
+            tensors=[x.clone() for x in tensors_by_peer[p2p.peer_id]],
+            ordered_peer_ids=peers,
             peer_fractions=peer_fractions,
             modes=peer_modes,
             weights=averaging_weights,

+ 4 - 4
tests/test_averaging.py

@@ -95,7 +95,7 @@ def _test_allreduce_once(n_clients, n_aux):
     for future in futures:
         result = future.result()
         for averager in averagers:
-            assert averager.endpoint in result
+            assert averager.peer_id in result

     for averager in averagers:
         if averager.mode != AveragingMode.AUX:
@@ -291,13 +291,13 @@ def test_allgather(n_averagers=8, target_group_size=4):
         futures.append(averager.step(wait=False, gather=dict(batch_size=123 + i, foo="bar")))

     reference_metadata = {
-        averager.endpoint: dict(batch_size=123 + i, foo="bar") for i, averager in enumerate(averagers)
+        averager.peer_id: dict(batch_size=123 + i, foo="bar") for i, averager in enumerate(averagers)
     }
     for future in futures:
         gathered = future.result()
         assert len(gathered) == target_group_size
-        for endpoint in gathered:
-            assert gathered[endpoint] == reference_metadata[endpoint]
+        for peer_id in gathered:
+            assert gathered[peer_id] == reference_metadata[peer_id]

     for process in averagers + dht_instances:
         process.shutdown()

+ 1 - 1
tests/test_dht.py

@@ -103,5 +103,5 @@ async def test_dht_get_visible_maddrs():
     p2p = await hivemind.p2p.P2P.create(announce_maddrs=[dummy_endpoint])
     dht = hivemind.DHT(start=True, p2p=await p2p.replicate(p2p.daemon_listen_maddr))

-    assert dht.get_visible_maddrs() == [dummy_endpoint.encapsulate(f"/p2p/{p2p.id}")]
+    assert dht.get_visible_maddrs() == [dummy_endpoint.encapsulate(f"/p2p/{p2p.peer_id}")]
     dht.shutdown()

+ 1 - 1
tests/test_dht_node.py

@@ -44,7 +44,7 @@ def run_protocol_listener(
     for peer_id in maddrs_to_peer_ids(initial_peers):
         loop.run_until_complete(protocol.call_ping(peer_id))

-    maddr_conn.send((p2p.id, visible_maddrs))
+    maddr_conn.send((p2p.peer_id, visible_maddrs))

     async def shutdown():
         await p2p.shutdown()

+ 12 - 12
tests/test_p2p_daemon.py

@@ -92,7 +92,7 @@ async def test_call_protobuf_handler(should_cancel, replicate, handle_name="hand
         except asyncio.CancelledError:
             nonlocal handler_cancelled
             handler_cancelled = True
-        return dht_pb2.PingResponse(peer=dht_pb2.NodeInfo(node_id=server.id.to_bytes()), available=True)
+        return dht_pb2.PingResponse(peer=dht_pb2.NodeInfo(node_id=server.peer_id.to_bytes()), available=True)

     server_pid = server_primary._child.pid
     await server.add_protobuf_handler(handle_name, ping_handler, dht_pb2.PingRequest)
@@ -104,12 +104,12 @@ async def test_call_protobuf_handler(should_cancel, replicate, handle_name="hand
     assert is_process_running(client_pid)
     await client.wait_for_at_least_n_peers(1)

-    ping_request = dht_pb2.PingRequest(peer=dht_pb2.NodeInfo(node_id=client.id.to_bytes()), validate=True)
-    expected_response = dht_pb2.PingResponse(peer=dht_pb2.NodeInfo(node_id=server.id.to_bytes()), available=True)
+    ping_request = dht_pb2.PingRequest(peer=dht_pb2.NodeInfo(node_id=client.peer_id.to_bytes()), validate=True)
+    expected_response = dht_pb2.PingResponse(peer=dht_pb2.NodeInfo(node_id=server.peer_id.to_bytes()), available=True)

     if should_cancel:
         call_task = asyncio.create_task(
-            client.call_protobuf_handler(server.id, handle_name, ping_request, dht_pb2.PingResponse)
+            client.call_protobuf_handler(server.peer_id, handle_name, ping_request, dht_pb2.PingResponse)
         )
         await asyncio.sleep(0.25)

@@ -119,7 +119,7 @@ async def test_call_protobuf_handler(should_cancel, replicate, handle_name="hand
         assert handler_cancelled
     else:
         actual_response = await client.call_protobuf_handler(
-            server.id, handle_name, ping_request, dht_pb2.PingResponse
+            server.peer_id, handle_name, ping_request, dht_pb2.PingResponse
         )
         assert actual_response == expected_response
         assert not handler_cancelled
@@ -147,10 +147,10 @@ async def test_call_protobuf_handler_error(handle_name="handle"):
     assert is_process_running(client_pid)
     await client.wait_for_at_least_n_peers(1)

-    ping_request = dht_pb2.PingRequest(peer=dht_pb2.NodeInfo(node_id=client.id.to_bytes()), validate=True)
+    ping_request = dht_pb2.PingRequest(peer=dht_pb2.NodeInfo(node_id=client.peer_id.to_bytes()), validate=True)

     with pytest.raises(P2PHandlerError) as excinfo:
-        await client.call_protobuf_handler(server.id, handle_name, ping_request, dht_pb2.PingResponse)
+        await client.call_protobuf_handler(server.peer_id, handle_name, ping_request, dht_pb2.PingResponse)
     assert "boom" in str(excinfo.value)

     await server.shutdown()
@@ -196,7 +196,7 @@ async def test_call_peer_single_process():

     await client.wait_for_at_least_n_peers(1)

-    _, reader, writer = await client.call_binary_stream_handler(server.id, handler_name)
+    _, reader, writer = await client.call_binary_stream_handler(server.peer_id, handler_name)
     await validate_square_stream(reader, writer)

     await server.shutdown()
@@ -213,7 +213,7 @@ async def run_server(handler_name, server_side, response_received):

     await server.add_binary_stream_handler(handler_name, handle_square_stream)

-    server_side.send(server.id)
+    server_side.send(server.peer_id)
     server_side.send(await server.get_visible_maddrs())
     while response_received.value == 0:
         await asyncio.sleep(0.5)
@@ -281,7 +281,7 @@ async def test_error_closes_connection():

     await client.wait_for_at_least_n_peers(1)

-    _, reader, writer = await client.call_binary_stream_handler(server.id, handler_name)
+    _, reader, writer = await client.call_binary_stream_handler(server.peer_id, handler_name)
     with closing(writer):
         await P2P.send_raw_data(b"raise_error", writer)
         with pytest.raises(asyncio.IncompleteReadError):  # Means that the connection is closed
@@ -290,7 +290,7 @@ async def test_error_closes_connection():
     # Despite the handler raised an exception, the server did not crash and ready for next requests
     assert is_process_running(server_pid)

-    _, reader, writer = await client.call_binary_stream_handler(server.id, handler_name)
+    _, reader, writer = await client.call_binary_stream_handler(server.peer_id, handler_name)
     with closing(writer):
         await P2P.send_raw_data(b"behave_normally", writer)
         assert await P2P.receive_raw_data(reader) == b"okay"
@@ -309,7 +309,7 @@ async def test_handlers_on_different_replicas():
             await P2P.send_raw_data(key, writer)

     server_primary = await P2P.create()
-    server_id = server_primary.id
+    server_id = server_primary.peer_id
     await server_primary.add_binary_stream_handler("handle_primary", partial(handler, key=b"primary"))

     server_replica1 = await replicate_if_needed(server_primary, True)

+ 6 - 6
tests/test_p2p_servicer.py

@@ -25,7 +25,7 @@ async def test_unary_unary(server_client):
     server, client = server_client
     servicer = ExampleServicer()
     await servicer.add_p2p_handlers(server)
-    stub = ExampleServicer.get_stub(client, server.id)
+    stub = ExampleServicer.get_stub(client, server.peer_id)

     assert await stub.rpc_square(test_pb2.TestRequest(number=10)) == test_pb2.TestResponse(number=100)

@@ -44,7 +44,7 @@ async def test_stream_unary(server_client):
     server, client = server_client
     servicer = ExampleServicer()
     await servicer.add_p2p_handlers(server)
-    stub = ExampleServicer.get_stub(client, server.id)
+    stub = ExampleServicer.get_stub(client, server.peer_id)

     async def generate_requests() -> AsyncIterator[test_pb2.TestRequest]:
         for i in range(10):
@@ -65,7 +65,7 @@ async def test_unary_stream(server_client):
     server, client = server_client
     servicer = ExampleServicer()
     await servicer.add_p2p_handlers(server)
-    stub = ExampleServicer.get_stub(client, server.id)
+    stub = ExampleServicer.get_stub(client, server.peer_id)

     i = 0
     async for item in stub.rpc_count(test_pb2.TestRequest(number=10)):
@@ -87,7 +87,7 @@ async def test_stream_stream(server_client):
     server, client = server_client
     servicer = ExampleServicer()
     await servicer.add_p2p_handlers(server)
-    stub = ExampleServicer.get_stub(client, server.id)
+    stub = ExampleServicer.get_stub(client, server.peer_id)

     async def generate_requests() -> AsyncIterator[test_pb2.TestRequest]:
         for i in range(10):
@@ -128,7 +128,7 @@ async def test_unary_stream_cancel(server_client, cancel_reason):
     await servicer.add_p2p_handlers(server)

     if cancel_reason == "close_connection":
-        _, reader, writer = await client.call_binary_stream_handler(server.id, "ExampleServicer.rpc_wait")
+        _, reader, writer = await client.call_binary_stream_handler(server.peer_id, "ExampleServicer.rpc_wait")
         await P2P.send_protobuf(test_pb2.TestRequest(number=10), writer)
         await P2P.send_protobuf(P2P.END_OF_STREAM, writer)

@@ -138,7 +138,7 @@ async def test_unary_stream_cancel(server_client, cancel_reason):

         writer.close()
     elif cancel_reason == "close_generator":
-        stub = ExampleServicer.get_stub(client, server.id)
+        stub = ExampleServicer.get_stub(client, server.peer_id)
         iter = stub.rpc_wait(test_pb2.TestRequest(number=10)).__aiter__()

         assert await iter.__anext__() == test_pb2.TestResponse(number=11)

+ 4 - 4
tests/test_utils/dht_swarms.py

@@ -52,15 +52,15 @@ def launch_swarm_in_separate_processes(
         proc.start()
         processes.append(proc)

-        node_id, peer_endpoint, peer_maddrs = info_queue.get()
-        dht[peer_endpoint] = node_id
+        node_id, peer_id, peer_maddrs = info_queue.get()
+        dht[peer_id] = node_id
         swarm_maddrs.append(peer_maddrs)

     def collect_info():
         while True:
-            node_id, peer_endpoint, peer_maddrs = info_queue.get()
+            node_id, peer_id, peer_maddrs = info_queue.get()
             with info_lock:
-                dht[peer_endpoint] = node_id
+                dht[peer_id] = node_id
                 swarm_maddrs.append(peer_maddrs)

                 if len(dht) == n_peers: