
Use PeerID exclusively to address MoE experts (#479)

Changed declare_experts / RemoteExpert to address servers by their libp2p PeerID only, instead of the full multiaddress.
This slightly reduces code complexity and makes it easier to share experts from peers with dynamic IP addresses.
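
For illustration, a minimal sketch of the client-side API after this change (the expert uid and the local DHT standing in for a server are arbitrary; exact signatures are in the diffs below):

```python
from hivemind.dht import DHT
from hivemind.moe.client.expert import RemoteExpert
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.moe.expert_uid import ExpertInfo
from hivemind.p2p import P2P

dht = DHT(start=True)  # stands in for a server-side DHT peer hosting the expert

# a client only needs the server's PeerID (not its multiaddresses) to reach an expert
p2p = RemoteExpertWorker.run_coroutine(P2P.create(initial_peers=dht.get_visible_maddrs()))
expert = RemoteExpert(expert_info=ExpertInfo(uid="expert.0", peer_id=dht.peer_id), p2p=p2p)
print(expert.uid, expert.peer_id)  # the underlying stub is built via get_server_stub(p2p, peer_id)
```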

It also fixes a DHT edge case I discovered while working on this.

Minor changes:
- fixed an edge case: previously, DHT would **freeze** if you accessed DHT.peer_id or otherwise called .run_coroutine from inside another run_coroutine (see the sketch after this list)
- merged RemoteExpertInfo and UidEndpoint into one structure (ExpertInfo), now in expert_uid.py
- moved expert_uid.py from hivemind.moe.server to hivemind.moe to avoid circular imports
- renamed get_expert_stub to get_server_stub since it is not expert-specific
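
A rough sketch of the previously-freezing pattern mentioned above (the coroutine name is made up; the guard it describes mirrors the dht.py diff below):

```python
import hivemind

dht = hivemind.DHT(start=True)

async def report_own_peer_id(dht_: hivemind.DHT, node) -> str:
    # this coroutine runs inside the DHT process; before this PR, dht_.peer_id would call
    # run_coroutine again and freeze, since the event loop is already busy running this coroutine.
    # Now DHT.peer_id short-circuits to dht_._node.peer_id when os.getpid() == dht_.pid,
    # and the external get/store/run_coroutine paths assert against this misuse.
    return dht_.peer_id.to_base58()

print(dht.run_coroutine(report_own_peer_id))
```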


Co-authored-by: Aleksandr Borzunov <borzunov.alexander@gmail.com>
Co-authored-by: Pavel Samygin <samygin@phystech.edu>
Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic committed 3 years ago
commit 25366a1436

+ 4 - 4
benchmarks/benchmark_throughput.py

@@ -7,11 +7,12 @@ import time
 import torch
 
 from hivemind.dht import DHT
-from hivemind.moe.client.expert import RemoteExpert, RemoteExpertInfo
+from hivemind.moe.client.expert import RemoteExpert
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from hivemind.moe.expert_uid import ExpertInfo
 from hivemind.moe.server import ExpertBackend, Server
 from hivemind.moe.server.layers import name_to_block
-from hivemind.p2p import P2P, PeerInfo
+from hivemind.p2p import P2P
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from hivemind.utils.tensor_descr import BatchTensorDescriptor
@@ -48,9 +49,8 @@ def client_process(
     can_start.wait()
 
     p2p = RemoteExpertWorker.run_coroutine(P2P.create(initial_peers=server_maddrs))
-    peer_info = PeerInfo(server_peer_id, server_maddrs)
     experts = [
-        RemoteExpert(expert_info=RemoteExpertInfo(uid=f"expert.{i}", peer_info=peer_info), p2p=p2p)
+        RemoteExpert(expert_info=ExpertInfo(uid=f"expert.{i}", peer_id=server_peer_id), p2p=p2p)
         for i in range(num_experts)
     ]
 

+ 8 - 1
hivemind/dht/dht.py

@@ -169,6 +169,7 @@ class DHT(mp.Process):
         :param kwargs: parameters forwarded to DHTNode.get_many_by_id
         :returns: (value, expiration time); if value was not found, returns None
         """
+        assert os.getpid() != self.pid, "calling *external* DHT interface from inside DHT will result in a deadlock"
         future = MPFuture()
         self._outer_pipe.send(("_get", [], dict(key=key, latest=latest, future=future, **kwargs)))
         return future if return_future else future.result()
@@ -202,6 +203,7 @@ class DHT(mp.Process):
         :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
         :returns: True if store succeeds, False if it fails (due to no response or newer value)
         """
+        assert os.getpid() != self.pid, "calling *external* DHT interface from inside DHT will result in a deadlock"
         future = MPFuture()
         self._outer_pipe.send(
             (
@@ -246,6 +248,7 @@ class DHT(mp.Process):
           or use asyncio.get_event_loop().run_in_executor(...) to prevent coroutine from blocking background DHT tasks
         :note: when run_coroutine is called with return_future=False, MPFuture can be cancelled to interrupt the task.
         """
+        assert os.getpid() != self.pid, "calling *external* DHT interface from inside DHT will result in a deadlock"
         future = MPFuture()
         self._outer_pipe.send(("_run_coroutine", [], dict(coro=coro, future=future)))
         return future if return_future else future.result()
@@ -275,7 +278,11 @@ class DHT(mp.Process):
     @property
     def peer_id(self) -> PeerID:
         if self._peer_id is None:
-            self._peer_id = self.run_coroutine(DHT._get_peer_id)
+            if os.getpid() == self.pid:
+                self._peer_id = self._node.peer_id
+            else:
+                # note: we cannot run_coroutine from the same pid because it would deadlock the event loop
+                self._peer_id = self.run_coroutine(DHT._get_peer_id)
         return self._peer_id
 
     @staticmethod

+ 51 - 60
hivemind/moe/client/beam_search.py

@@ -5,25 +5,21 @@ from functools import partial
 from typing import Deque, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
 
 from hivemind.dht import DHT, DHTExpiration, DHTNode
-from hivemind.moe.client.expert import (
-    RemoteExpert,
-    RemoteExpertInfo,
-    batch_create_remote_experts,
-    create_remote_experts,
-)
-from hivemind.moe.server.expert_uid import (
+from hivemind.moe.client.expert import RemoteExpert, batch_create_remote_experts, create_remote_experts
+from hivemind.moe.expert_uid import (
     FLAT_EXPERT,
     PREFIX_PATTERN,
     UID_DELIMITER,
     Coordinate,
+    ExpertInfo,
     ExpertPrefix,
     ExpertUID,
     Score,
-    UidEndpoint,
     is_valid_prefix,
+    is_valid_uid,
 )
-from hivemind.p2p import PeerInfo
-from hivemind.utils import MPFuture, get_dht_time, get_logger
+from hivemind.p2p import PeerID
+from hivemind.utils import MPFuture, ValueWithExpiration, get_dht_time, get_logger
 
 logger = get_logger(__name__)
 
@@ -100,7 +96,7 @@ class MoEBeamSearcher:
 
     def get_initial_beam(
         self, scores: Sequence[float], beam_size: int, return_future: bool = False
-    ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
+    ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, ExpertInfo]]]:
         """
         :param scores: prefer suffix coordinates that have highest scores
         :param beam_size: select this many active suffixes with highest scores
@@ -130,9 +126,9 @@ class MoEBeamSearcher:
         negative_caching: bool,
         cache_expiration: DHTExpiration,
         num_workers: Optional[int] = None,
-    ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
+    ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, ExpertInfo]]]:
         num_workers = num_workers or dht.num_workers or beam_size
-        beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = []
+        beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, ExpertInfo]]] = []
         unattempted_indices: List[Coordinate] = sorted(
             range(len(scores)), key=scores.__getitem__
         )  # from worst to best
@@ -150,13 +146,7 @@ class MoEBeamSearcher:
             try:
                 maybe_prefix_data = await pending_task
                 if maybe_prefix_data is not None and isinstance(maybe_prefix_data.value, dict):
-                    successors = {
-                        coord: UidEndpoint(uid=match.value[0], peer_info=PeerInfo.from_tuple(match.value[1]))
-                        for coord, match in maybe_prefix_data.value.items()
-                        if isinstance(coord, Coordinate)
-                        and isinstance(getattr(match, "value", None), list)
-                        and len(match.value) == 2
-                    }
+                    successors = MoEBeamSearcher._select_valid_entries(maybe_prefix_data)
                     if successors:
                         beam.append((scores[pending_best_index], pending_best_prefix, successors))
                 elif maybe_prefix_data is None and negative_caching:
@@ -178,7 +168,7 @@ class MoEBeamSearcher:
 
     def get_active_successors(
         self, prefixes: List[ExpertPrefix], grid_size: Optional[int] = None, return_future: bool = False
-    ) -> Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]]:
+    ) -> Dict[ExpertPrefix, Dict[Coordinate, ExpertInfo]]:
         """
         :param prefixes: a list of prefix for which to find active successor uids
         :param grid_size: if specified, only return successors if they are in range [0, grid_size)
@@ -201,6 +191,22 @@ class MoEBeamSearcher:
             return_future=return_future,
         )
 
+    @staticmethod
+    def _select_valid_entries(entry: ValueWithExpiration, grid_size: Optional[int] = None):
+        if not isinstance(entry, ValueWithExpiration) or not isinstance(entry.value, dict):
+            return {}
+        return {
+            coord: ExpertInfo(uid=match.value[0], peer_id=PeerID.from_base58(match.value[1]))
+            for coord, match in entry.value.items()
+            if isinstance(coord, Coordinate)
+            and (grid_size is None or 0 <= coord < grid_size)
+            and isinstance(match, ValueWithExpiration)
+            and isinstance(match.value, tuple)
+            and len(match.value) == 2
+            and is_valid_uid(match.value[0])
+            and isinstance(match.value[1], str)
+        }
+
     @staticmethod
     async def _get_active_successors(
         dht: DHT,
@@ -210,28 +216,18 @@ class MoEBeamSearcher:
         negative_caching: bool,
         cache_expiration: DHTExpiration,
         num_workers: Optional[int] = None,
-    ) -> Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]]:
+    ) -> Dict[ExpertPrefix, Dict[Coordinate, ExpertInfo]]:
         grid_size = grid_size or float("inf")
         num_workers = num_workers or min(len(prefixes), dht.num_workers or len(prefixes))
         dht_responses = await node.get_many(keys=prefixes, num_workers=num_workers)
-        successors: Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]] = {}
+        successors: Dict[ExpertPrefix, Dict[Coordinate, ExpertInfo]] = {}
         for prefix, found in dht_responses.items():
-            if found and isinstance(found.value, dict):
-                successors[prefix] = {
-                    coord: UidEndpoint(uid=match.value[0], peer_info=PeerInfo.from_tuple(match.value[1]))
-                    for coord, match in found.value.items()
-                    if isinstance(coord, Coordinate)
-                    and 0 <= coord < grid_size
-                    and isinstance(getattr(match, "value", None), list)
-                    and len(match.value) == 2
-                }
-            else:
-                successors[prefix] = {}
-                if found is None and negative_caching:
-                    logger.debug(f"DHT negative caching: storing a 'no prefix' entry for {prefix}")
-                    asyncio.create_task(
-                        node.store(prefix, subkey=-1, value=None, expiration_time=get_dht_time() + cache_expiration)
-                    )
+            successors[prefix] = MoEBeamSearcher._select_valid_entries(found, grid_size)
+            if not successors[prefix] and negative_caching:
+                logger.debug(f"DHT negative caching: storing a 'no prefix' entry for {prefix}")
+                asyncio.create_task(
+                    node.store(prefix, subkey=-1, value=None, expiration_time=get_dht_time() + cache_expiration)
+                )
         return successors
 
     def find_best_experts(
@@ -246,7 +242,6 @@ class MoEBeamSearcher:
          After time_budget is reached, beam search won't search for more experts and instead fall back on local cache
          Please note that any queries that fall outside the budget will still be performed in background and cached
          for subsequent iterations as long as DHTNode.cache_locally is True
-        :param num_workers: use up to this many concurrent workers to search DHT
         :param return_future: if set to True, returns MPFuture that can be awaited to get the actual result
         :returns: a list that contains *up to* k_best RemoteExpert instances
         """
@@ -263,7 +258,6 @@ class MoEBeamSearcher:
             ),
             return_future,
         )
-
         return create_remote_experts(result, self.dht, return_future)
 
     @classmethod
@@ -277,23 +271,23 @@ class MoEBeamSearcher:
         negative_caching: bool,
         cache_expiration: DHTExpiration,
         num_workers: Optional[int] = None,
-    ) -> List[RemoteExpertInfo]:
+    ) -> List[ExpertInfo]:
         num_workers = num_workers or min(beam_size, dht.num_workers or beam_size)
 
         # form initial beam from top-k active L1 prefixes, each row is (score, uid prefix, possible suffixes)
-        beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = await cls._get_initial_beam(
+        beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, ExpertInfo]]] = await cls._get_initial_beam(
             dht, node, prefix, beam_size, grid_scores[0], negative_caching, min(beam_size, num_workers)
         )
 
-        best_experts_heap: List[Tuple[Score, UidEndpoint]] = []  # max-heap of expert uids/endpoints ordered by scores
+        best_experts_heap: List[Tuple[Score, ExpertInfo]] = []  # max-heap of expert infos ordered by scores
         unique_experts: Set[ExpertUID] = set()
 
         for dim_index in range(1, len(grid_scores) - 1):
-            for score, uid_endpoint in cls._iterate_matching_experts(beam, grid_scores):
-                if uid_endpoint.uid not in unique_experts:
+            for score, expert_info in cls._iterate_matching_experts(beam, grid_scores):
+                if expert_info.uid not in unique_experts:
                     push_and_maybe_pop = heapq.heappush if len(best_experts_heap) < beam_size else heapq.heappushpop
-                    push_and_maybe_pop(best_experts_heap, (score, uid_endpoint))
-                    unique_experts.add(uid_endpoint.uid)
+                    push_and_maybe_pop(best_experts_heap, (score, expert_info))
+                    unique_experts.add(expert_info.uid)
 
             # form new beam using successors from the current beam
             dim_scores = grid_scores[dim_index]
@@ -306,6 +300,7 @@ class MoEBeamSearcher:
                     if isinstance(next_coord, int) and 0 <= next_coord < len(dim_scores)
                 ),
             )
+
             _, best_uid_prefixes = zip(*best_active_pairs)
 
             # search DHT for next step suffixes
@@ -324,22 +319,18 @@ class MoEBeamSearcher:
                 break
 
         # add best experts from the final beam
-        for score, uid_endpoint in cls._iterate_matching_experts(beam, grid_scores):
-            if uid_endpoint.uid not in unique_experts:
+        for score, expert_info in cls._iterate_matching_experts(beam, grid_scores):
+            if expert_info.uid not in unique_experts:
                 push_and_maybe_pop = heapq.heappush if len(best_experts_heap) < beam_size else heapq.heappushpop
-                push_and_maybe_pop(best_experts_heap, (score, uid_endpoint))
-                unique_experts.add(uid_endpoint.uid)
+                push_and_maybe_pop(best_experts_heap, (score, expert_info))
+                unique_experts.add(expert_info.uid)
 
-        best_experts = [
-            RemoteExpertInfo(uid_endpoint.uid, uid_endpoint.peer_info)
-            for _, uid_endpoint in sorted(best_experts_heap, reverse=True)
-        ]
-        return best_experts
+        return [expert_info for _, expert_info in sorted(best_experts_heap, reverse=True)]
 
     @staticmethod
     def _iterate_matching_experts(
-        beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]], grid_scores: Sequence[Sequence[float]]
-    ) -> Iterator[Tuple[Score, UidEndpoint]]:
+        beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, ExpertInfo]]], grid_scores: Sequence[Sequence[float]]
+    ) -> Iterator[Tuple[Score, ExpertInfo]]:
         """iterate over all exemplar experts attached to current beam"""
         for score, prefix, suffixes in beam:
             for next_coord, match in suffixes.items():
@@ -399,7 +390,7 @@ class MoEBeamSearcher:
         beam_size: int,
         negative_caching: bool,
         num_workers: Optional[int],
-    ) -> Sequence[Sequence[RemoteExpertInfo]]:
+    ) -> Sequence[Sequence[ExpertInfo]]:
         batch_grid_scores = [
             [tuple(grid_score[i]) for grid_score in batch_grid_scores] for i in range(len(batch_grid_scores[0]))
         ]

+ 13 - 20
hivemind/moe/client/expert.py

@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 from concurrent.futures import Future
-from dataclasses import dataclass
 from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
 
 import torch
@@ -12,7 +11,8 @@ from hivemind import moe
 from hivemind.compression import deserialize_tensor_stream, deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.dht import DHT
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
-from hivemind.p2p import P2P, PeerInfo, StubBase
+from hivemind.moe.expert_uid import ExpertInfo
+from hivemind.p2p import P2P, PeerID, StubBase
 from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
 from hivemind.proto import runtime_pb2
 from hivemind.utils.asyncio import amap_in_executor, iter_as_aiter
@@ -24,16 +24,9 @@ from hivemind.utils.streaming import split_for_streaming
 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert
 
 
-def get_expert_stub(p2p: P2P, server_peer_info: PeerInfo) -> "ConnectionHandlerStub":
-    return moe.server.connection_handler.ConnectionHandler.get_stub(p2p, server_peer_info.peer_id)
-
-
-@dataclass(frozen=True)
-class RemoteExpertInfo:
-    """A simple data class containing uid of expert and server PeerInfo"""
-
-    uid: str
-    peer_info: PeerInfo
+def get_server_stub(p2p: P2P, server_peer_id: PeerID) -> "ConnectionHandlerStub":
+    """Create an RPC stub that can send requests to any expert on the specified remote server"""
+    return moe.server.connection_handler.ConnectionHandler.get_stub(p2p, server_peer_id)
 
 
 class RemoteExpert(nn.Module):
@@ -47,7 +40,7 @@ class RemoteExpert(nn.Module):
     :param p2p: P2P instance connected to the running p2pd
     """
 
-    def __init__(self, expert_info: RemoteExpertInfo, p2p: P2P):
+    def __init__(self, expert_info: ExpertInfo, p2p: P2P):
         super().__init__()
         self._info, self.p2p = expert_info, p2p
         self._rpc_info = None
@@ -57,12 +50,12 @@ class RemoteExpert(nn.Module):
         return self._info.uid
 
     @property
-    def server_peer_info(self):
-        return self._info.peer_info
+    def peer_id(self) -> PeerID:
+        return self._info.peer_id
 
     @property
     def stub(self) -> StubBase:
-        return get_expert_stub(self.p2p, self.server_peer_info)
+        return get_server_stub(self.p2p, self.peer_id)
 
     def forward(self, *args, **kwargs):
         """Call RemoteExpert for the specified inputs and return its output(s). Compatible with pytorch.autograd."""
@@ -89,10 +82,10 @@ class RemoteExpert(nn.Module):
         return self._rpc_info
 
     def extra_repr(self):
-        return f"uid={self.uid}, server_peer_info={self.server_peer_info}"
+        return f"uid={self.uid}, server_peer_id={self.peer_id}"
 
 
-def _create_remote_experts(infos: Sequence[Optional[RemoteExpertInfo]], p2p: P2P) -> List[Optional[RemoteExpert]]:
+def _create_remote_experts(infos: Sequence[Optional[ExpertInfo]], p2p: P2P) -> List[Optional[RemoteExpert]]:
     experts: List[Optional[RemoteExpert]] = []
     for info in infos:
         if info is not None:
@@ -103,7 +96,7 @@ def _create_remote_experts(infos: Sequence[Optional[RemoteExpertInfo]], p2p: P2P
 
 
 def create_remote_experts(
-    infos: Union[Sequence[Optional[RemoteExpertInfo]], MPFuture], dht: DHT, return_future: bool = False
+    infos: Union[Sequence[Optional[ExpertInfo]], MPFuture], dht: DHT, return_future: bool = False
 ) -> Union[List[Optional[RemoteExpert]], Future]:
     if return_future:
 
@@ -118,7 +111,7 @@ def create_remote_experts(
 
 
 def batch_create_remote_experts(
-    infos: Union[Sequence[Sequence[Optional[RemoteExpertInfo]]], MPFuture],
+    infos: Union[Sequence[Sequence[Optional[ExpertInfo]]], MPFuture],
     dht: DHT,
     return_future: bool = False,
 ) -> Union[List[List[Optional[RemoteExpert]]], Future]:

+ 4 - 4
hivemind/moe/client/moe.py

@@ -12,9 +12,9 @@ from torch.autograd.function import once_differentiable
 from hivemind.compression import serialize_torch_tensor
 from hivemind.dht import DHT
 from hivemind.moe.client.beam_search import MoEBeamSearcher
-from hivemind.moe.client.expert import DUMMY, RemoteExpert, expert_backward, expert_forward, get_expert_stub
+from hivemind.moe.client.expert import DUMMY, RemoteExpert, expert_backward, expert_forward, get_server_stub
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
-from hivemind.moe.server.expert_uid import UID_DELIMITER
+from hivemind.moe.expert_uid import UID_DELIMITER
 from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError
 from hivemind.utils import nested_flatten, nested_map, nested_pack
 from hivemind.utils.logging import get_logger
@@ -227,7 +227,7 @@ class _RemoteCallMany(torch.autograd.Function):
         pending_tasks: Dict[Future, Tuple[int, int]] = {}
         for i in range(num_samples):
             for j, expert in enumerate(experts_per_sample[i]):
-                stub = get_expert_stub(expert.p2p, expert.server_peer_info)
+                stub = get_server_stub(expert.p2p, expert.peer_id)
                 serialized_tensors = (
                     serialize_torch_tensor(tensor, proto.compression)
                     for tensor, proto in zip(flat_inputs_per_sample[i], nested_flatten(info["forward_schema"]))
@@ -321,7 +321,7 @@ class _RemoteCallMany(torch.autograd.Function):
             alive_ii.cpu().numpy(), alive_jj.cpu().numpy(), inputs_per_expert, grad_outputs_per_expert
         ):
             expert: RemoteExpert = expert_per_sample[i.item()][j.item()]
-            stub = get_expert_stub(expert.p2p, expert.server_peer_info)
+            stub = get_server_stub(expert.p2p, expert.peer_id)
             inputs_and_grad_outputs = tuple(nested_flatten((inputs_ij, grad_outputs_ij)))
             serialized_tensors = (
                 serialize_torch_tensor(tensor, proto.compression)

+ 1 - 1
hivemind/moe/client/switch_moe.py

@@ -6,7 +6,7 @@ import torch
 
 from hivemind.moe.client.expert import DUMMY, RemoteExpert
 from hivemind.moe.client.moe import RemoteMixtureOfExperts, _RemoteCallMany
-from hivemind.moe.server.expert_uid import UID_DELIMITER
+from hivemind.moe.expert_uid import UID_DELIMITER
 from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError
 from hivemind.utils import nested_flatten, nested_pack
 from hivemind.utils.logging import get_logger

+ 4 - 2
hivemind/moe/server/expert_uid.py → hivemind/moe/expert_uid.py

@@ -1,10 +1,12 @@
+from __future__ import annotations
+
 import re
 from typing import NamedTuple, Tuple, Union
 
-from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerInfo
+from hivemind.p2p import PeerID
 
 ExpertUID, ExpertPrefix, Coordinate, Score = str, str, int, float
-UidEndpoint = NamedTuple("UidEndpoint", [("uid", ExpertUID), ("peer_info", PeerInfo)])
+ExpertInfo = NamedTuple("ExpertInfo", [("uid", ExpertUID), ("peer_id", PeerID)])
 UID_DELIMITER = "."  # when declaring experts, DHT stores all prefixes of that expert's uid, split over this delimiter
 FLAT_EXPERT = -1  # grid prefix reserved for storing 1d expert uids. Used to speed up find_best_experts in 1d case.
 UID_PATTERN = re.compile("^(([^.])+)([.](?:[0]|([1-9]([0-9]*))))+$")  # e.g. ffn_expert.98.76.54 - prefix + some dims
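
For reference, a small illustration of the uid helpers kept in this module (behavior inferred from the surrounding code; the values are arbitrary):

```python
from hivemind.moe.expert_uid import is_valid_prefix, is_valid_uid, split_uid

assert is_valid_uid("ffn_expert.98.76.54")    # full uid: name + "."-separated integer coordinates
assert is_valid_prefix("ffn_expert.98.76.")   # prefixes end with UID_DELIMITER
assert split_uid("ffn_expert.98.76.54") == ("ffn_expert.98.76.", 54)
```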

+ 17 - 16
hivemind/moe/server/dht_handler.py

@@ -3,18 +3,19 @@ from functools import partial
 from typing import Dict, List, Optional, Sequence, Tuple, Union
 
 from hivemind.dht import DHT, DHTExpiration, DHTNode, DHTValue
-from hivemind.moe.client.expert import RemoteExpert, RemoteExpertInfo, create_remote_experts
-from hivemind.moe.server.expert_uid import (
+from hivemind.moe.client.expert import RemoteExpert, create_remote_experts
+from hivemind.moe.expert_uid import (
     FLAT_EXPERT,
     UID_DELIMITER,
     UID_PATTERN,
     Coordinate,
+    ExpertInfo,
     ExpertPrefix,
     ExpertUID,
     is_valid_uid,
     split_uid,
 )
-from hivemind.p2p import PeerID, PeerInfo
+from hivemind.p2p import PeerID
 from hivemind.utils import MPFuture, get_dht_time
 
 
@@ -44,27 +45,27 @@ def declare_experts(
     :returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
     """
     assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
+    if not isinstance(uids, list):
+        uids = list(uids)
     for uid in uids:
         assert is_valid_uid(uid), f"{uid} is not a valid expert uid. All uids must follow {UID_PATTERN.pattern}"
-    addrs = tuple(str(a.decapsulate("/p2p/" + a.get("p2p"))) for a in dht.get_visible_maddrs())
-    return dht.run_coroutine(
-        partial(_declare_experts, uids=list(uids), peer_id=dht.peer_id, addrs=addrs, expiration=expiration),
-        return_future=not wait,
-    )
+    return dht.run_coroutine(partial(_declare_experts, uids=uids, expiration=expiration), return_future=not wait)
 
 
 async def _declare_experts(
-    dht: DHT, node: DHTNode, uids: List[ExpertUID], peer_id: PeerID, addrs: Tuple[str], expiration: DHTExpiration
+    dht: DHT, node: DHTNode, uids: List[ExpertUID], expiration: DHTExpiration
 ) -> Dict[ExpertUID, bool]:
     num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
     expiration_time = get_dht_time() + expiration
     data_to_store: Dict[Tuple[ExpertPrefix, Optional[Coordinate]], DHTValue] = {}
+    peer_id_base58 = dht.peer_id.to_base58()
+
     for uid in uids:
-        data_to_store[uid, None] = (peer_id.to_base58(), addrs)
+        data_to_store[uid, None] = peer_id_base58
         prefix = uid if uid.count(UID_DELIMITER) > 1 else f"{uid}{UID_DELIMITER}{FLAT_EXPERT}"
         for i in range(prefix.count(UID_DELIMITER) - 1):
             prefix, last_coord = split_uid(prefix)
-            data_to_store[prefix, last_coord] = [uid, (peer_id.to_base58(), addrs)]
+            data_to_store[prefix, last_coord] = (uid, peer_id_base58)
 
     keys, maybe_subkeys, values = zip(*((key, subkey, value) for (key, subkey), value in data_to_store.items()))
     store_ok = await node.store_many(keys, values, expiration_time, subkeys=maybe_subkeys, num_workers=num_workers)
@@ -87,15 +88,15 @@ def get_experts(
 
 async def _get_experts(
     dht: DHT, node: DHTNode, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration]
-) -> List[Optional[RemoteExpertInfo]]:
+) -> List[Optional[ExpertInfo]]:
     if expiration_time is None:
         expiration_time = get_dht_time()
     num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
     found: Dict[ExpertUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
 
-    experts: List[Optional[RemoteExpert]] = [None] * len(uids)
+    experts: List[Optional[ExpertInfo]] = [None] * len(uids)
     for i, uid in enumerate(uids):
-        expert_info_for_uid = found[uid]
-        if expert_info_for_uid is not None and isinstance(expert_info_for_uid.value, tuple):
-            experts[i] = RemoteExpertInfo(uid, PeerInfo.from_tuple(expert_info_for_uid.value))
+        server_peer_id = found[uid]
+        if server_peer_id is not None and isinstance(server_peer_id.value, str):
+            experts[i] = ExpertInfo(uid, PeerID.from_base58(server_peer_id.value))
     return experts
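
For context, a sketch of what the new storage scheme means in practice (the uid is arbitrary): declare_experts now writes only the server's base58 peer ID under each uid, plus (uid, peer_id) pairs under each prefix subkey, and get_experts resolves them back into RemoteExpert instances addressed by PeerID:

```python
from hivemind.dht import DHT
from hivemind.moe.server import declare_experts, get_experts

dht = DHT(start=True)

# "ffn.3.7" is stored as "ffn.3.7" -> "<peer_id_base58>",
# plus prefix subkeys such as ("ffn.3.", 7) -> ("ffn.3.7", "<peer_id_base58>")
assert all(declare_experts(dht, ["ffn.3.7"]).values())

(expert,) = get_experts(dht, ["ffn.3.7"])
assert expert.peer_id == dht.peer_id  # the server is located by PeerID alone
```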

+ 2 - 3
hivemind/moe/server/server.py

@@ -6,17 +6,16 @@ import threading
 from contextlib import contextmanager
 from functools import partial
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional
 
 import torch
-from multiaddr import Multiaddr
 
 from hivemind.dht import DHT
+from hivemind.moe.expert_uid import UID_DELIMITER
 from hivemind.moe.server.checkpoints import CheckpointSaver, is_directory, load_experts
 from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.moe.server.dht_handler import DHTHandlerThread, get_experts
 from hivemind.moe.server.expert_backend import ExpertBackend
-from hivemind.moe.server.expert_uid import UID_DELIMITER
 from hivemind.moe.server.layers import (
     add_custom_models_from_file,
     name_to_block,

+ 0 - 6
hivemind/p2p/p2p_daemon_bindings/datastructures.py

@@ -128,12 +128,6 @@ class PeerInfo:
         addrs = [Multiaddr(addr) for addr in peer_info_pb.addrs]
         return PeerInfo(peer_id, addrs)
 
-    @classmethod
-    def from_tuple(cls, value: Tuple[str, Sequence[str]]) -> "PeerInfo":
-        peer_id = PeerID.from_base58(value[0])
-        addrs = [Multiaddr(addr) for addr in value[1]]
-        return PeerInfo(peer_id, addrs)
-
     def __str__(self):
         return f"{self.peer_id.pretty()} {','.join(str(a) for a in self.addrs)}"
 

+ 6 - 5
tests/test_custom_experts.py

@@ -4,7 +4,8 @@ import pytest
 import torch
 
 from hivemind.dht import DHT
-from hivemind.moe.client.expert import RemoteExpertInfo, create_remote_experts
+from hivemind.moe.client.expert import create_remote_experts
+from hivemind.moe.expert_uid import ExpertInfo
 from hivemind.moe.server import background_server
 
 CUSTOM_EXPERTS_PATH = os.path.join(os.path.dirname(__file__), "test_utils", "custom_networks.py")
@@ -23,8 +24,8 @@ def test_custom_expert(hid_dim=16):
         dht = DHT(initial_peers=server_peer_info.addrs, start=True)
         expert0, expert1 = create_remote_experts(
             [
-                RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info),
-                RemoteExpertInfo(uid="expert.1", peer_info=server_peer_info),
+                ExpertInfo(uid="expert.0", peer_id=server_peer_info.peer_id),
+                ExpertInfo(uid="expert.1", peer_id=server_peer_info.peer_id),
             ],
             dht=dht,
         )
@@ -54,8 +55,8 @@ def test_multihead_expert(hid_dim=16):
         dht = DHT(initial_peers=server_peer_info.addrs, start=True)
         expert0, expert1 = create_remote_experts(
             [
-                RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info),
-                RemoteExpertInfo(uid="expert.1", peer_info=server_peer_info),
+                ExpertInfo(uid="expert.0", peer_id=server_peer_info.peer_id),
+                ExpertInfo(uid="expert.1", peer_id=server_peer_info.peer_id),
             ],
             dht=dht,
         )

+ 7 - 10
tests/test_dht_experts.py

@@ -8,9 +8,8 @@ import pytest
 import hivemind
 from hivemind.dht import DHTNode
 from hivemind.moe.client.beam_search import MoEBeamSearcher
+from hivemind.moe.expert_uid import ExpertInfo, is_valid_prefix, is_valid_uid, split_uid
 from hivemind.moe.server import declare_experts, get_experts
-from hivemind.moe.server.expert_uid import UidEndpoint, is_valid_prefix, is_valid_uid, split_uid
-from hivemind.p2p import PeerInfo
 
 
 @pytest.mark.forked
@@ -35,7 +34,7 @@ def test_store_get_experts(n_peers=10):
     declare_experts(other_peer, [other_expert])
     first_notfound, first_found = get_experts(first_peer, ["foobar", other_expert])
     assert isinstance(first_found, hivemind.RemoteExpert)
-    assert first_found.server_peer_info.peer_id == other_peer.peer_id
+    assert first_found.peer_id == other_peer.peer_id
     assert first_notfound is None
 
     # test graceful shutdown
@@ -45,7 +44,7 @@ def test_store_get_experts(n_peers=10):
     remaining_peer1 = random.choice([peer for peer in peers if peer.is_alive()])
     remaining_peer2 = random.choice([peer for peer in peers if peer.is_alive()])
     assert all(declare_experts(remaining_peer1, ["new_expert.1"]))
-    assert get_experts(remaining_peer2, ["new_expert.1"])[0].server_peer_info.peer_id == remaining_peer1.peer_id
+    assert get_experts(remaining_peer2, ["new_expert.1"])[0].peer_id == remaining_peer1.peer_id
 
 
 @pytest.mark.forked
@@ -96,7 +95,7 @@ def test_dht_single_node():
     assert len(declare_experts(node, ["e.1.2.3", "e.1.2.5", "e.2.0"])) == 7
 
     for expert in get_experts(node, ["expert.3", "expert.2"]):
-        assert expert.server_peer_info.peer_id == node.peer_id
+        assert expert.peer_id == node.peer_id
 
     assert all(declare_experts(node, ["expert.5", "expert.2"]).values())
     found_experts = beam_search.find_best_experts([(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0)], beam_size=2)
@@ -105,11 +104,9 @@ def test_dht_single_node():
     successors = beam_search.get_active_successors(["e.1.2.", "e.2.", "e.4.5."])
     assert len(successors["e.1.2."]) == 2
 
-    peer_info = PeerInfo(node.peer_id, [a.decapsulate("/p2p/" + a.get("p2p")) for a in node.get_visible_maddrs()])
-
-    assert successors["e.1.2."][3] == UidEndpoint("e.1.2.3", peer_info)
-    assert successors["e.1.2."][5] == UidEndpoint("e.1.2.5", peer_info)
-    assert len(successors["e.2."]) == 1 and successors["e.2."][0] == UidEndpoint("e.2.0", peer_info)
+    assert successors["e.1.2."][3] == ExpertInfo("e.1.2.3", node.peer_id)
+    assert successors["e.1.2."][5] == ExpertInfo("e.1.2.5", node.peer_id)
+    assert len(successors["e.2."]) == 1 and successors["e.2."][0] == ExpertInfo("e.2.0", node.peer_id)
     assert successors["e.4.5."] == {}
 
     initial_beam = beam_search.get_initial_beam((3, 2, 1, 0, -1, -2, -3), beam_size=3)

+ 9 - 8
tests/test_moe.py

@@ -3,9 +3,10 @@ import pytest
 import torch
 
 from hivemind.dht import DHT
-from hivemind.moe.client.expert import RemoteExpert, RemoteExpertInfo, create_remote_experts
+from hivemind.moe.client.expert import RemoteExpert, create_remote_experts
 from hivemind.moe.client.moe import DUMMY, RemoteMixtureOfExperts, _RemoteCallMany
 from hivemind.moe.client.switch_moe import RemoteSwitchMixtureOfExperts
+from hivemind.moe.expert_uid import ExpertInfo
 from hivemind.moe.server import ExpertBackend, Server, background_server, declare_experts
 from hivemind.moe.server.layers import name_to_block
 from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError
@@ -77,10 +78,10 @@ def test_call_many(hidden_dim=16):
 
         dht = DHT(initial_peers=server_peer_info.addrs, start=True)
         e0, e1, e2, e3, e4 = create_remote_experts(
-            [RemoteExpertInfo(uid=f"expert.{i}", peer_info=server_peer_info) for i in range(5)],
+            [ExpertInfo(uid=f"expert.{i}", peer_id=server_peer_info.peer_id) for i in range(5)],
             dht,
         )
-        e5 = RemoteExpert(RemoteExpertInfo(f"thisshouldnotexist", server_peer_info), None)
+        e5 = RemoteExpert(ExpertInfo(f"thisshouldnotexist", server_peer_info), None)
 
         mask, expert_outputs = _RemoteCallMany.apply(
             DUMMY,
@@ -137,8 +138,8 @@ def test_remote_module_call(hidden_dim=16):
         dht = DHT(initial_peers=server_peer_info.addrs, start=True)
         real_expert, fake_expert = create_remote_experts(
             [
-                RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info),
-                RemoteExpertInfo(uid="oiasfjiasjf", peer_info=server_peer_info),
+                ExpertInfo(uid="expert.0", peer_id=server_peer_info.peer_id),
+                ExpertInfo(uid="oiasfjiasjf", peer_id=server_peer_info.peer_id),
             ],
             dht=dht,
         )
@@ -181,7 +182,7 @@ def test_beam_search_correctness():
         # reference: independently find :beam_size: best experts with exhaustive search
         all_scores = dmoe.compute_expert_scores(
             [dim_scores.unsqueeze(0) for dim_scores in grid_scores],
-            [[RemoteExpert(RemoteExpertInfo(uid, None), None) for uid in all_expert_uids]],
+            [[RemoteExpert(ExpertInfo(uid, None), None) for uid in all_expert_uids]],
         )[0]
         true_best_scores = sorted(all_scores.cpu().detach().numpy(), reverse=True)[: len(chosen_experts)]
 
@@ -205,7 +206,7 @@ def test_determinism(hidden_dim=16):
     ) as server_peer_info:
         dht = DHT(initial_peers=server_peer_info.addrs, start=True)
         expert = create_remote_experts(
-            [RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info)],
+            [ExpertInfo(uid="expert.0", peer_id=server_peer_info.peer_id)],
             dht=dht,
         )[0]
 
@@ -231,7 +232,7 @@ def test_compute_expert_scores():
         jj = [[2, 2, 1], [0, 1, 2, 0, 1], [0], [1, 2]]
         batch_experts = [
             [
-                RemoteExpert(RemoteExpertInfo(f"expert.{ii[batch_i][expert_i]}.{jj[batch_i][expert_i]}", None), None)
+                RemoteExpert(ExpertInfo(f"expert.{ii[batch_i][expert_i]}.{jj[batch_i][expert_i]}", None), None)
                 for expert_i in range(len(ii[batch_i]))
             ]
             for batch_i in range(len(ii))

+ 4 - 3
tests/test_training.py

@@ -9,7 +9,8 @@ from sklearn.datasets import load_digits
 
 from hivemind import DHT
 from hivemind.moe.client import RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
-from hivemind.moe.client.expert import RemoteExpertInfo, create_remote_experts
+from hivemind.moe.client.expert import create_remote_experts
+from hivemind.moe.expert_uid import ExpertInfo
 from hivemind.moe.server import background_server
 from hivemind.optim import DecentralizedAdam, DecentralizedSGD
 
@@ -26,8 +27,8 @@ def test_training(max_steps: int = 100, threshold: float = 0.9):
         dht = DHT(initial_peers=server_peer_info.addrs, start=True)
         expert1, expert2 = create_remote_experts(
             [
-                RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info),
-                RemoteExpertInfo(uid="expert.1", peer_info=server_peer_info),
+                ExpertInfo(uid="expert.0", peer_id=server_peer_info.peer_id),
+                ExpertInfo(uid="expert.1", peer_id=server_peer_info.peer_id),
             ],
             dht=dht,
         )