
Use PeerID exclusively to address MoE experts (#479)

Changed declare_experts / RemoteExpert to use only the p2p peer ID instead of the full multiaddress.
This slightly reduces code complexity and makes it easier to share experts from machines with dynamic IPs.

It also fixes a DHT edge case I discovered while working on this.

Minor changes:
- fixed an edge case: previously, DHT would **freeze** when accessing DHT.peer_id or otherwise calling .run_coroutine from inside another run_coroutine
- merged RemoteExpertInfo and UidEndpoint into one structure (ExpertInfo), now in expert_uid.py (see the usage sketch below)
- moved expert_uid.py from hivemind.moe.server to hivemind.moe to avoid circular imports
- renamed get_expert_stub to get_server_stub since it is not expert-specific
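
The change is easiest to see at a call site. Below is a minimal sketch of the new client-side flow, adapted from the updated benchmark in this diff; it assumes a hivemind Server is already running and that its `server_peer_id` and `server_maddrs` are known to the client.

from hivemind.moe.client.expert import RemoteExpert
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.moe.expert_uid import ExpertInfo
from hivemind.p2p import P2P

# the client's p2p daemon still connects via multiaddrs, but only as transport bootstrap
p2p = RemoteExpertWorker.run_coroutine(P2P.create(initial_peers=server_maddrs))

# experts are now addressed by (uid, peer_id) alone -- no multiaddress is stored or passed around
expert = RemoteExpert(expert_info=ExpertInfo(uid="expert.0", peer_id=server_peer_id), p2p=p2p)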


Co-authored-by: Aleksandr Borzunov <borzunov.alexander@gmail.com>
Co-authored-by: Pavel Samygin <samygin@phystech.edu>
Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic 3 years ago
commit 25366a1436

+ 4 - 4
benchmarks/benchmark_throughput.py

@@ -7,11 +7,12 @@ import time
 import torch

 from hivemind.dht import DHT
-from hivemind.moe.client.expert import RemoteExpert, RemoteExpertInfo
+from hivemind.moe.client.expert import RemoteExpert
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from hivemind.moe.expert_uid import ExpertInfo
 from hivemind.moe.server import ExpertBackend, Server
 from hivemind.moe.server.layers import name_to_block
-from hivemind.p2p import P2P, PeerInfo
+from hivemind.p2p import P2P
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from hivemind.utils.tensor_descr import BatchTensorDescriptor
@@ -48,9 +49,8 @@ def client_process(
     can_start.wait()

     p2p = RemoteExpertWorker.run_coroutine(P2P.create(initial_peers=server_maddrs))
-    peer_info = PeerInfo(server_peer_id, server_maddrs)
     experts = [
-        RemoteExpert(expert_info=RemoteExpertInfo(uid=f"expert.{i}", peer_info=peer_info), p2p=p2p)
+        RemoteExpert(expert_info=ExpertInfo(uid=f"expert.{i}", peer_id=server_peer_id), p2p=p2p)
         for i in range(num_experts)
     ]


+ 8 - 1
hivemind/dht/dht.py

@@ -169,6 +169,7 @@ class DHT(mp.Process):
         :param kwargs: parameters forwarded to DHTNode.get_many_by_id
         :returns: (value, expiration time); if value was not found, returns None
         """
+        assert os.getpid() != self.pid, "calling *external* DHT interface from inside DHT will result in a deadlock"
         future = MPFuture()
         self._outer_pipe.send(("_get", [], dict(key=key, latest=latest, future=future, **kwargs)))
         return future if return_future else future.result()
@@ -202,6 +203,7 @@
         :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
         :returns: True if store succeeds, False if it fails (due to no response or newer value)
         """
+        assert os.getpid() != self.pid, "calling *external* DHT interface from inside DHT will result in a deadlock"
         future = MPFuture()
         self._outer_pipe.send(
             (
@@ -246,6 +248,7 @@
           or use asyncio.get_event_loop().run_in_executor(...) to prevent coroutine from blocking background DHT tasks
         :note: when run_coroutine is called with return_future=False, MPFuture can be cancelled to interrupt the task.
         """
+        assert os.getpid() != self.pid, "calling *external* DHT interface from inside DHT will result in a deadlock"
         future = MPFuture()
         self._outer_pipe.send(("_run_coroutine", [], dict(coro=coro, future=future)))
         return future if return_future else future.result()
@@ -275,7 +278,11 @@
     @property
     def peer_id(self) -> PeerID:
         if self._peer_id is None:
-            self._peer_id = self.run_coroutine(DHT._get_peer_id)
+            if os.getpid() == self.pid:
+                self._peer_id = self._node.peer_id
+            else:
+                # note: we cannot run_coroutine from the same pid because it would deadlock the event loop
+                self._peer_id = self.run_coroutine(DHT._get_peer_id)
         return self._peer_id

     @staticmethod
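
For context on the assertions and the peer_id guard above: the freeze mentioned in the commit message happened because the external DHT interface (get / store / run_coroutine) sends a request to the DHT's own event loop and blocks on an MPFuture, which can never be fulfilled if the caller is already running inside that loop. A rough illustration of the unsupported vs. supported pattern (the callback names are made up for this sketch):

from hivemind.dht import DHT, DHTNode

async def bad_callback(dht: DHT, node: DHTNode):
    # runs inside the DHT process: this would wait on the very event loop that is
    # executing this coroutine -- previously a silent freeze, now caught by the assert
    return dht.get("some_key")

async def good_callback(dht: DHT, node: DHTNode):
    # use the DHTNode passed into the coroutine instead; dht.peer_id is also safe now,
    # since it short-circuits to the local node when called from inside the DHT process
    return await node.get("some_key"), dht.peer_id

dht = DHT(start=True)
result = dht.run_coroutine(good_callback)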

+ 51 - 60
hivemind/moe/client/beam_search.py

@@ -5,25 +5,21 @@ from functools import partial
 from typing import Deque, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union

 from hivemind.dht import DHT, DHTExpiration, DHTNode
-from hivemind.moe.client.expert import (
-    RemoteExpert,
-    RemoteExpertInfo,
-    batch_create_remote_experts,
-    create_remote_experts,
-)
-from hivemind.moe.server.expert_uid import (
+from hivemind.moe.client.expert import RemoteExpert, batch_create_remote_experts, create_remote_experts
+from hivemind.moe.expert_uid import (
     FLAT_EXPERT,
     PREFIX_PATTERN,
     UID_DELIMITER,
     Coordinate,
+    ExpertInfo,
     ExpertPrefix,
     ExpertUID,
     Score,
-    UidEndpoint,
     is_valid_prefix,
+    is_valid_uid,
 )
-from hivemind.p2p import PeerInfo
-from hivemind.utils import MPFuture, get_dht_time, get_logger
+from hivemind.p2p import PeerID
+from hivemind.utils import MPFuture, ValueWithExpiration, get_dht_time, get_logger

 logger = get_logger(__name__)

@@ -100,7 +96,7 @@ class MoEBeamSearcher:

     def get_initial_beam(
         self, scores: Sequence[float], beam_size: int, return_future: bool = False
-    ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
+    ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, ExpertInfo]]]:
         """
         :param scores: prefer suffix coordinates that have highest scores
         :param beam_size: select this many active suffixes with highest scores
@@ -130,9 +126,9 @@
         negative_caching: bool,
         cache_expiration: DHTExpiration,
         num_workers: Optional[int] = None,
-    ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
+    ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, ExpertInfo]]]:
         num_workers = num_workers or dht.num_workers or beam_size
-        beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = []
+        beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, ExpertInfo]]] = []
         unattempted_indices: List[Coordinate] = sorted(
             range(len(scores)), key=scores.__getitem__
         )  # from worst to best
@@ -150,13 +146,7 @@
             try:
                 maybe_prefix_data = await pending_task
                 if maybe_prefix_data is not None and isinstance(maybe_prefix_data.value, dict):
-                    successors = {
-                        coord: UidEndpoint(uid=match.value[0], peer_info=PeerInfo.from_tuple(match.value[1]))
-                        for coord, match in maybe_prefix_data.value.items()
-                        if isinstance(coord, Coordinate)
-                        and isinstance(getattr(match, "value", None), list)
-                        and len(match.value) == 2
-                    }
+                    successors = MoEBeamSearcher._select_valid_entries(maybe_prefix_data)
                     if successors:
                         beam.append((scores[pending_best_index], pending_best_prefix, successors))
                 elif maybe_prefix_data is None and negative_caching:
@@ -178,7 +168,7 @@

     def get_active_successors(
         self, prefixes: List[ExpertPrefix], grid_size: Optional[int] = None, return_future: bool = False
-    ) -> Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]]:
+    ) -> Dict[ExpertPrefix, Dict[Coordinate, ExpertInfo]]:
         """
         :param prefixes: a list of prefix for which to find active successor uids
         :param grid_size: if specified, only return successors if ther are in range [0, grid_size)
@@ -201,6 +191,22 @@
             return_future=return_future,
         )

+    @staticmethod
+    def _select_valid_entries(entry: ValueWithExpiration, grid_size: Optional[int] = None):
+        if not isinstance(entry, ValueWithExpiration) or not isinstance(entry.value, dict):
+            return {}
+        return {
+            coord: ExpertInfo(uid=match.value[0], peer_id=PeerID.from_base58(match.value[1]))
+            for coord, match in entry.value.items()
+            if isinstance(coord, Coordinate)
+            and (grid_size is None or 0 <= coord < grid_size)
+            and isinstance(match, ValueWithExpiration)
+            and isinstance(match.value, tuple)
+            and len(match.value) == 2
+            and is_valid_uid(match.value[0])
+            and isinstance(match.value[1], str)
+        }
+
     @staticmethod
     async def _get_active_successors(
         dht: DHT,
@@ -210,28 +216,18 @@
         negative_caching: bool,
         cache_expiration: DHTExpiration,
         num_workers: Optional[int] = None,
-    ) -> Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]]:
+    ) -> Dict[ExpertPrefix, Dict[Coordinate, ExpertInfo]]:
         grid_size = grid_size or float("inf")
         num_workers = num_workers or min(len(prefixes), dht.num_workers or len(prefixes))
         dht_responses = await node.get_many(keys=prefixes, num_workers=num_workers)
-        successors: Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]] = {}
+        successors: Dict[ExpertPrefix, Dict[Coordinate, ExpertInfo]] = {}
         for prefix, found in dht_responses.items():
-            if found and isinstance(found.value, dict):
-                successors[prefix] = {
-                    coord: UidEndpoint(uid=match.value[0], peer_info=PeerInfo.from_tuple(match.value[1]))
-                    for coord, match in found.value.items()
-                    if isinstance(coord, Coordinate)
-                    and 0 <= coord < grid_size
-                    and isinstance(getattr(match, "value", None), list)
-                    and len(match.value) == 2
-                }
-            else:
-                successors[prefix] = {}
-                if found is None and negative_caching:
-                    logger.debug(f"DHT negative caching: storing a 'no prefix' entry for {prefix}")
-                    asyncio.create_task(
-                        node.store(prefix, subkey=-1, value=None, expiration_time=get_dht_time() + cache_expiration)
-                    )
+            successors[prefix] = MoEBeamSearcher._select_valid_entries(found, grid_size)
+            if not successors[prefix] and negative_caching:
+                logger.debug(f"DHT negative caching: storing a 'no prefix' entry for {prefix}")
+                asyncio.create_task(
+                    node.store(prefix, subkey=-1, value=None, expiration_time=get_dht_time() + cache_expiration)
+                )
         return successors

     def find_best_experts(
@@ -246,7 +242,6 @@
          After time_budget is reached, beam search won't search for more experts and instead fall back on local cache
          Please note that any queries that fall outside the budget will still be performed in background and cached
          for subsequent iterations as long as DHTNode.cache_locally is True
-        :param num_workers: use up to this many concurrent workers to search DHT
         :param return_future: if set to True, returns MPFuture that can be awaited to get the actual result
         :returns: a list that contains *up to* k_best RemoteExpert instances
         """
@@ -263,7 +258,6 @@
             ),
             return_future,
         )
-
         return create_remote_experts(result, self.dht, return_future)

     @classmethod
@@ -277,23 +271,23 @@
         negative_caching: bool,
         cache_expiration: DHTExpiration,
         num_workers: Optional[int] = None,
-    ) -> List[RemoteExpertInfo]:
+    ) -> List[ExpertInfo]:
         num_workers = num_workers or min(beam_size, dht.num_workers or beam_size)

         # form initial beam from top-k active L1 prefixes, each row is (score, uid prefix, possible suffixes)
-        beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = await cls._get_initial_beam(
+        beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, ExpertInfo]]] = await cls._get_initial_beam(
             dht, node, prefix, beam_size, grid_scores[0], negative_caching, min(beam_size, num_workers)
         )

-        best_experts_heap: List[Tuple[Score, UidEndpoint]] = []  # max-heap of expert uids/endpoints ordered by scores
+        best_experts_heap: List[Tuple[Score, ExpertInfo]] = []  # max-heap of expert infos ordered by scores
         unique_experts: Set[ExpertUID] = set()

         for dim_index in range(1, len(grid_scores) - 1):
-            for score, uid_endpoint in cls._iterate_matching_experts(beam, grid_scores):
-                if uid_endpoint.uid not in unique_experts:
+            for score, expert_info in cls._iterate_matching_experts(beam, grid_scores):
+                if expert_info.uid not in unique_experts:
                     push_and_maybe_pop = heapq.heappush if len(best_experts_heap) < beam_size else heapq.heappushpop
-                    push_and_maybe_pop(best_experts_heap, (score, uid_endpoint))
-                    unique_experts.add(uid_endpoint.uid)
+                    push_and_maybe_pop(best_experts_heap, (score, expert_info))
+                    unique_experts.add(expert_info.uid)

             # form new beam using successors from the current beam
             dim_scores = grid_scores[dim_index]
@@ -306,6 +300,7 @@
                     if isinstance(next_coord, int) and 0 <= next_coord < len(dim_scores)
                 ),
             )
+
             _, best_uid_prefixes = zip(*best_active_pairs)

             # search DHT for next step suffixes
@@ -324,22 +319,18 @@
                 break

         # add best experts from the final beam
-        for score, uid_endpoint in cls._iterate_matching_experts(beam, grid_scores):
-            if uid_endpoint.uid not in unique_experts:
+        for score, expert_info in cls._iterate_matching_experts(beam, grid_scores):
+            if expert_info.uid not in unique_experts:
                 push_and_maybe_pop = heapq.heappush if len(best_experts_heap) < beam_size else heapq.heappushpop
-                push_and_maybe_pop(best_experts_heap, (score, uid_endpoint))
-                unique_experts.add(uid_endpoint.uid)
+                push_and_maybe_pop(best_experts_heap, (score, expert_info))
+                unique_experts.add(expert_info.uid)

-        best_experts = [
-            RemoteExpertInfo(uid_endpoint.uid, uid_endpoint.peer_info)
-            for _, uid_endpoint in sorted(best_experts_heap, reverse=True)
-        ]
-        return best_experts
+        return [expert_info for _, expert_info in sorted(best_experts_heap, reverse=True)]

     @staticmethod
     def _iterate_matching_experts(
-        beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]], grid_scores: Sequence[Sequence[float]]
-    ) -> Iterator[Tuple[Score, UidEndpoint]]:
+        beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, ExpertInfo]]], grid_scores: Sequence[Sequence[float]]
+    ) -> Iterator[Tuple[Score, ExpertInfo]]:
         """iterate over all exemplar experts attached to current beam"""
         for score, prefix, suffixes in beam:
             for next_coord, match in suffixes.items():
@@ -399,7 +390,7 @@
         beam_size: int,
         negative_caching: bool,
         num_workers: Optional[int],
-    ) -> Sequence[Sequence[RemoteExpertInfo]]:
+    ) -> Sequence[Sequence[ExpertInfo]]:
         batch_grid_scores = [
             [tuple(grid_score[i]) for grid_score in batch_grid_scores] for i in range(len(batch_grid_scores[0]))
         ]

+ 13 - 20
hivemind/moe/client/expert.py

@@ -1,7 +1,6 @@
 from __future__ import annotations

 from concurrent.futures import Future
-from dataclasses import dataclass
 from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union

 import torch
@@ -12,7 +11,8 @@ from hivemind import moe
 from hivemind.compression import deserialize_tensor_stream, deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.dht import DHT
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
-from hivemind.p2p import P2P, PeerInfo, StubBase
+from hivemind.moe.expert_uid import ExpertInfo
+from hivemind.p2p import P2P, PeerID, StubBase
 from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
 from hivemind.proto import runtime_pb2
 from hivemind.utils.asyncio import amap_in_executor, iter_as_aiter
@@ -24,16 +24,9 @@ from hivemind.utils.streaming import split_for_streaming
 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert


-def get_expert_stub(p2p: P2P, server_peer_info: PeerInfo) -> "ConnectionHandlerStub":
-    return moe.server.connection_handler.ConnectionHandler.get_stub(p2p, server_peer_info.peer_id)
-
-
-@dataclass(frozen=True)
-class RemoteExpertInfo:
-    """A simple data class containing uid of expert and server PeerInfo"""
-
-    uid: str
-    peer_info: PeerInfo
+def get_server_stub(p2p: P2P, server_peer_id: PeerID) -> "ConnectionHandlerStub":
+    """Create an RPC stub that can send requests to any expert on the specified remote server"""
+    return moe.server.connection_handler.ConnectionHandler.get_stub(p2p, server_peer_id)


 class RemoteExpert(nn.Module):
@@ -47,7 +40,7 @@ class RemoteExpert(nn.Module):
     :param p2p: P2P instance connected to the running p2pd
     """

-    def __init__(self, expert_info: RemoteExpertInfo, p2p: P2P):
+    def __init__(self, expert_info: ExpertInfo, p2p: P2P):
         super().__init__()
         self._info, self.p2p = expert_info, p2p
         self._rpc_info = None
@@ -57,12 +50,12 @@ class RemoteExpert(nn.Module):
         return self._info.uid

     @property
-    def server_peer_info(self):
-        return self._info.peer_info
+    def peer_id(self) -> PeerID:
+        return self._info.peer_id

     @property
     def stub(self) -> StubBase:
-        return get_expert_stub(self.p2p, self.server_peer_info)
+        return get_server_stub(self.p2p, self.peer_id)

     def forward(self, *args, **kwargs):
         """Call RemoteExpert for the specified inputs and return its output(s). Compatible with pytorch.autograd."""
@@ -89,10 +82,10 @@ class RemoteExpert(nn.Module):
         return self._rpc_info

     def extra_repr(self):
-        return f"uid={self.uid}, server_peer_info={self.server_peer_info}"
+        return f"uid={self.uid}, server_peer_id={self.peer_id}"


-def _create_remote_experts(infos: Sequence[Optional[RemoteExpertInfo]], p2p: P2P) -> List[Optional[RemoteExpert]]:
+def _create_remote_experts(infos: Sequence[Optional[ExpertInfo]], p2p: P2P) -> List[Optional[RemoteExpert]]:
     experts: List[Optional[RemoteExpert]] = []
     for info in infos:
         if info is not None:
@@ -103,7 +96,7 @@ def _create_remote_experts(infos: Sequence[Optional[RemoteExpertInfo]], p2p: P2P


 def create_remote_experts(
-    infos: Union[Sequence[Optional[RemoteExpertInfo]], MPFuture], dht: DHT, return_future: bool = False
+    infos: Union[Sequence[Optional[ExpertInfo]], MPFuture], dht: DHT, return_future: bool = False
 ) -> Union[List[Optional[RemoteExpert]], Future]:
     if return_future:

@@ -118,7 +111,7 @@ def create_remote_experts(


 def batch_create_remote_experts(
-    infos: Union[Sequence[Sequence[Optional[RemoteExpertInfo]]], MPFuture],
+    infos: Union[Sequence[Sequence[Optional[ExpertInfo]]], MPFuture],
     dht: DHT,
     return_future: bool = False,
 ) -> Union[List[List[Optional[RemoteExpert]]], Future]:

+ 4 - 4
hivemind/moe/client/moe.py

@@ -12,9 +12,9 @@ from torch.autograd.function import once_differentiable
 from hivemind.compression import serialize_torch_tensor
 from hivemind.dht import DHT
 from hivemind.moe.client.beam_search import MoEBeamSearcher
-from hivemind.moe.client.expert import DUMMY, RemoteExpert, expert_backward, expert_forward, get_expert_stub
+from hivemind.moe.client.expert import DUMMY, RemoteExpert, expert_backward, expert_forward, get_server_stub
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
-from hivemind.moe.server.expert_uid import UID_DELIMITER
+from hivemind.moe.expert_uid import UID_DELIMITER
 from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError
 from hivemind.utils import nested_flatten, nested_map, nested_pack
 from hivemind.utils.logging import get_logger
@@ -227,7 +227,7 @@ class _RemoteCallMany(torch.autograd.Function):
         pending_tasks: Dict[Future, Tuple[int, int]] = {}
         for i in range(num_samples):
             for j, expert in enumerate(experts_per_sample[i]):
-                stub = get_expert_stub(expert.p2p, expert.server_peer_info)
+                stub = get_server_stub(expert.p2p, expert.peer_id)
                 serialized_tensors = (
                     serialize_torch_tensor(tensor, proto.compression)
                     for tensor, proto in zip(flat_inputs_per_sample[i], nested_flatten(info["forward_schema"]))
@@ -321,7 +321,7 @@
             alive_ii.cpu().numpy(), alive_jj.cpu().numpy(), inputs_per_expert, grad_outputs_per_expert
         ):
             expert: RemoteExpert = expert_per_sample[i.item()][j.item()]
-            stub = get_expert_stub(expert.p2p, expert.server_peer_info)
+            stub = get_server_stub(expert.p2p, expert.peer_id)
             inputs_and_grad_outputs = tuple(nested_flatten((inputs_ij, grad_outputs_ij)))
             serialized_tensors = (
                 serialize_torch_tensor(tensor, proto.compression)

+ 1 - 1
hivemind/moe/client/switch_moe.py

@@ -6,7 +6,7 @@ import torch

 from hivemind.moe.client.expert import DUMMY, RemoteExpert
 from hivemind.moe.client.moe import RemoteMixtureOfExperts, _RemoteCallMany
-from hivemind.moe.server.expert_uid import UID_DELIMITER
+from hivemind.moe.expert_uid import UID_DELIMITER
 from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError
 from hivemind.utils import nested_flatten, nested_pack
 from hivemind.utils.logging import get_logger

+ 4 - 2
hivemind/moe/server/expert_uid.py → hivemind/moe/expert_uid.py

@@ -1,10 +1,12 @@
+from __future__ import annotations
+
 import re
 from typing import NamedTuple, Tuple, Union

-from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerInfo
+from hivemind.p2p import PeerID

 ExpertUID, ExpertPrefix, Coordinate, Score = str, str, int, float
-UidEndpoint = NamedTuple("UidEndpoint", [("uid", ExpertUID), ("peer_info", PeerInfo)])
+ExpertInfo = NamedTuple("ExpertInfo", [("uid", ExpertUID), ("peer_id", PeerID)])
 UID_DELIMITER = "."  # when declaring experts, DHT store all prefixes of that expert's uid, split over this prefix
 FLAT_EXPERT = -1  # grid prefix reserved for storing 1d expert uids. Used to speed up find_best_experts in 1d case.
 UID_PATTERN = re.compile("^(([^.])+)([.](?:[0]|([1-9]([0-9]*))))+$")  # e.g. ffn_expert.98.76.54 - prefix + some dims

+ 17 - 16
hivemind/moe/server/dht_handler.py

@@ -3,18 +3,19 @@ from functools import partial
 from typing import Dict, List, Optional, Sequence, Tuple, Union

 from hivemind.dht import DHT, DHTExpiration, DHTNode, DHTValue
-from hivemind.moe.client.expert import RemoteExpert, RemoteExpertInfo, create_remote_experts
-from hivemind.moe.server.expert_uid import (
+from hivemind.moe.client.expert import RemoteExpert, create_remote_experts
+from hivemind.moe.expert_uid import (
     FLAT_EXPERT,
     UID_DELIMITER,
     UID_PATTERN,
     Coordinate,
+    ExpertInfo,
     ExpertPrefix,
     ExpertUID,
     is_valid_uid,
     split_uid,
 )
-from hivemind.p2p import PeerID, PeerInfo
+from hivemind.p2p import PeerID
 from hivemind.utils import MPFuture, get_dht_time


@@ -44,27 +45,27 @@ def declare_experts(
     :returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
     """
     assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
+    if not isinstance(uids, list):
+        uids = list(uids)
     for uid in uids:
         assert is_valid_uid(uid), f"{uid} is not a valid expert uid. All uids must follow {UID_PATTERN.pattern}"
-    addrs = tuple(str(a.decapsulate("/p2p/" + a.get("p2p"))) for a in dht.get_visible_maddrs())
-    return dht.run_coroutine(
-        partial(_declare_experts, uids=list(uids), peer_id=dht.peer_id, addrs=addrs, expiration=expiration),
-        return_future=not wait,
-    )
+    return dht.run_coroutine(partial(_declare_experts, uids=uids, expiration=expiration), return_future=not wait)


 async def _declare_experts(
-    dht: DHT, node: DHTNode, uids: List[ExpertUID], peer_id: PeerID, addrs: Tuple[str], expiration: DHTExpiration
+    dht: DHT, node: DHTNode, uids: List[ExpertUID], expiration: DHTExpiration
 ) -> Dict[ExpertUID, bool]:
     num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
     expiration_time = get_dht_time() + expiration
     data_to_store: Dict[Tuple[ExpertPrefix, Optional[Coordinate]], DHTValue] = {}
+    peer_id_base58 = dht.peer_id.to_base58()
+
     for uid in uids:
-        data_to_store[uid, None] = (peer_id.to_base58(), addrs)
+        data_to_store[uid, None] = peer_id_base58
         prefix = uid if uid.count(UID_DELIMITER) > 1 else f"{uid}{UID_DELIMITER}{FLAT_EXPERT}"
         for i in range(prefix.count(UID_DELIMITER) - 1):
             prefix, last_coord = split_uid(prefix)
-            data_to_store[prefix, last_coord] = [uid, (peer_id.to_base58(), addrs)]
+            data_to_store[prefix, last_coord] = (uid, peer_id_base58)

     keys, maybe_subkeys, values = zip(*((key, subkey, value) for (key, subkey), value in data_to_store.items()))
     store_ok = await node.store_many(keys, values, expiration_time, subkeys=maybe_subkeys, num_workers=num_workers)
@@ -87,15 +88,15 @@ def get_experts(


 async def _get_experts(
     dht: DHT, node: DHTNode, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration]
-) -> List[Optional[RemoteExpertInfo]]:
+) -> List[Optional[ExpertInfo]]:
     if expiration_time is None:
         expiration_time = get_dht_time()
     num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
     found: Dict[ExpertUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)

-    experts: List[Optional[RemoteExpert]] = [None] * len(uids)
+    experts: List[Optional[ExpertInfo]] = [None] * len(uids)
     for i, uid in enumerate(uids):
-        expert_info_for_uid = found[uid]
-        if expert_info_for_uid is not None and isinstance(expert_info_for_uid.value, tuple):
-            experts[i] = RemoteExpertInfo(uid, PeerInfo.from_tuple(expert_info_for_uid.value))
+        server_peer_id = found[uid]
+        if server_peer_id is not None and isinstance(server_peer_id.value, str):
+            experts[i] = ExpertInfo(uid, PeerID.from_base58(server_peer_id.value))

+ 2 - 3
hivemind/moe/server/server.py

@@ -6,17 +6,16 @@ import threading
 from contextlib import contextmanager
 from contextlib import contextmanager
 from functools import partial
 from functools import partial
 from pathlib import Path
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional
 
 
 import torch
 import torch
-from multiaddr import Multiaddr
 
 
 from hivemind.dht import DHT
 from hivemind.dht import DHT
+from hivemind.moe.expert_uid import UID_DELIMITER
 from hivemind.moe.server.checkpoints import CheckpointSaver, is_directory, load_experts
 from hivemind.moe.server.checkpoints import CheckpointSaver, is_directory, load_experts
 from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.moe.server.dht_handler import DHTHandlerThread, get_experts
 from hivemind.moe.server.dht_handler import DHTHandlerThread, get_experts
 from hivemind.moe.server.expert_backend import ExpertBackend
 from hivemind.moe.server.expert_backend import ExpertBackend
-from hivemind.moe.server.expert_uid import UID_DELIMITER
 from hivemind.moe.server.layers import (
 from hivemind.moe.server.layers import (
     add_custom_models_from_file,
     add_custom_models_from_file,
     name_to_block,
     name_to_block,

+ 0 - 6
hivemind/p2p/p2p_daemon_bindings/datastructures.py

@@ -128,12 +128,6 @@ class PeerInfo:
         addrs = [Multiaddr(addr) for addr in peer_info_pb.addrs]
         return PeerInfo(peer_id, addrs)

-    @classmethod
-    def from_tuple(cls, value: Tuple[str, Sequence[str]]) -> "PeerInfo":
-        peer_id = PeerID.from_base58(value[0])
-        addrs = [Multiaddr(addr) for addr in value[1]]
-        return PeerInfo(peer_id, addrs)
-
     def __str__(self):
         return f"{self.peer_id.pretty()} {','.join(str(a) for a in self.addrs)}"


+ 6 - 5
tests/test_custom_experts.py

@@ -4,7 +4,8 @@ import pytest
 import torch

 from hivemind.dht import DHT
-from hivemind.moe.client.expert import RemoteExpertInfo, create_remote_experts
+from hivemind.moe.client.expert import create_remote_experts
+from hivemind.moe.expert_uid import ExpertInfo
 from hivemind.moe.server import background_server

 CUSTOM_EXPERTS_PATH = os.path.join(os.path.dirname(__file__), "test_utils", "custom_networks.py")
@@ -23,8 +24,8 @@ def test_custom_expert(hid_dim=16):
         dht = DHT(initial_peers=server_peer_info.addrs, start=True)
         expert0, expert1 = create_remote_experts(
             [
-                RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info),
-                RemoteExpertInfo(uid="expert.1", peer_info=server_peer_info),
+                ExpertInfo(uid="expert.0", peer_id=server_peer_info.peer_id),
+                ExpertInfo(uid="expert.1", peer_id=server_peer_info.peer_id),
             ],
             dht=dht,
         )
@@ -54,8 +55,8 @@ def test_multihead_expert(hid_dim=16):
         dht = DHT(initial_peers=server_peer_info.addrs, start=True)
         expert0, expert1 = create_remote_experts(
             [
-                RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info),
-                RemoteExpertInfo(uid="expert.1", peer_info=server_peer_info),
+                ExpertInfo(uid="expert.0", peer_id=server_peer_info.peer_id),
+                ExpertInfo(uid="expert.1", peer_id=server_peer_info.peer_id),
             ],
             dht=dht,
         )

+ 7 - 10
tests/test_dht_experts.py

@@ -8,9 +8,8 @@ import pytest
 import hivemind
 from hivemind.dht import DHTNode
 from hivemind.moe.client.beam_search import MoEBeamSearcher
+from hivemind.moe.expert_uid import ExpertInfo, is_valid_prefix, is_valid_uid, split_uid
 from hivemind.moe.server import declare_experts, get_experts
-from hivemind.moe.server.expert_uid import UidEndpoint, is_valid_prefix, is_valid_uid, split_uid
-from hivemind.p2p import PeerInfo


 @pytest.mark.forked
@@ -35,7 +34,7 @@ def test_store_get_experts(n_peers=10):
     declare_experts(other_peer, [other_expert])
     first_notfound, first_found = get_experts(first_peer, ["foobar", other_expert])
     assert isinstance(first_found, hivemind.RemoteExpert)
-    assert first_found.server_peer_info.peer_id == other_peer.peer_id
+    assert first_found.peer_id == other_peer.peer_id
     assert first_notfound is None

     # test graceful shutdown
@@ -45,7 +44,7 @@ def test_store_get_experts(n_peers=10):
     remaining_peer1 = random.choice([peer for peer in peers if peer.is_alive()])
     remaining_peer2 = random.choice([peer for peer in peers if peer.is_alive()])
     assert all(declare_experts(remaining_peer1, ["new_expert.1"]))
-    assert get_experts(remaining_peer2, ["new_expert.1"])[0].server_peer_info.peer_id == remaining_peer1.peer_id
+    assert get_experts(remaining_peer2, ["new_expert.1"])[0].peer_id == remaining_peer1.peer_id


 @pytest.mark.forked
@@ -96,7 +95,7 @@ def test_dht_single_node():
     assert len(declare_experts(node, ["e.1.2.3", "e.1.2.5", "e.2.0"])) == 7

     for expert in get_experts(node, ["expert.3", "expert.2"]):
-        assert expert.server_peer_info.peer_id == node.peer_id
+        assert expert.peer_id == node.peer_id

     assert all(declare_experts(node, ["expert.5", "expert.2"]).values())
     found_experts = beam_search.find_best_experts([(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0)], beam_size=2)
@@ -105,11 +104,9 @@
     successors = beam_search.get_active_successors(["e.1.2.", "e.2.", "e.4.5."])
     assert len(successors["e.1.2."]) == 2

-    peer_info = PeerInfo(node.peer_id, [a.decapsulate("/p2p/" + a.get("p2p")) for a in node.get_visible_maddrs()])
-
-    assert successors["e.1.2."][3] == UidEndpoint("e.1.2.3", peer_info)
-    assert successors["e.1.2."][5] == UidEndpoint("e.1.2.5", peer_info)
-    assert len(successors["e.2."]) == 1 and successors["e.2."][0] == UidEndpoint("e.2.0", peer_info)
+    assert successors["e.1.2."][3] == ExpertInfo("e.1.2.3", node.peer_id)
+    assert successors["e.1.2."][5] == ExpertInfo("e.1.2.5", node.peer_id)
+    assert len(successors["e.2."]) == 1 and successors["e.2."][0] == ExpertInfo("e.2.0", node.peer_id)
     assert successors["e.4.5."] == {}

     initial_beam = beam_search.get_initial_beam((3, 2, 1, 0, -1, -2, -3), beam_size=3)

+ 9 - 8
tests/test_moe.py

@@ -3,9 +3,10 @@ import pytest
 import torch

 from hivemind.dht import DHT
-from hivemind.moe.client.expert import RemoteExpert, RemoteExpertInfo, create_remote_experts
+from hivemind.moe.client.expert import RemoteExpert, create_remote_experts
 from hivemind.moe.client.moe import DUMMY, RemoteMixtureOfExperts, _RemoteCallMany
 from hivemind.moe.client.switch_moe import RemoteSwitchMixtureOfExperts
+from hivemind.moe.expert_uid import ExpertInfo
 from hivemind.moe.server import ExpertBackend, Server, background_server, declare_experts
 from hivemind.moe.server.layers import name_to_block
 from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError
@@ -77,10 +78,10 @@ def test_call_many(hidden_dim=16):

         dht = DHT(initial_peers=server_peer_info.addrs, start=True)
         e0, e1, e2, e3, e4 = create_remote_experts(
-            [RemoteExpertInfo(uid=f"expert.{i}", peer_info=server_peer_info) for i in range(5)],
+            [ExpertInfo(uid=f"expert.{i}", peer_id=server_peer_info.peer_id) for i in range(5)],
             dht,
         )
-        e5 = RemoteExpert(RemoteExpertInfo(f"thisshouldnotexist", server_peer_info), None)
+        e5 = RemoteExpert(ExpertInfo(f"thisshouldnotexist", server_peer_info), None)

         mask, expert_outputs = _RemoteCallMany.apply(
             DUMMY,
@@ -137,8 +138,8 @@ def test_remote_module_call(hidden_dim=16):
         dht = DHT(initial_peers=server_peer_info.addrs, start=True)
         real_expert, fake_expert = create_remote_experts(
             [
-                RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info),
-                RemoteExpertInfo(uid="oiasfjiasjf", peer_info=server_peer_info),
+                ExpertInfo(uid="expert.0", peer_id=server_peer_info.peer_id),
+                ExpertInfo(uid="oiasfjiasjf", peer_id=server_peer_info.peer_id),
             ],
             dht=dht,
         )
@@ -181,7 +182,7 @@ def test_beam_search_correctness():
         # reference: independently find :beam_size: best experts with exhaustive search
         all_scores = dmoe.compute_expert_scores(
             [dim_scores.unsqueeze(0) for dim_scores in grid_scores],
-            [[RemoteExpert(RemoteExpertInfo(uid, None), None) for uid in all_expert_uids]],
+            [[RemoteExpert(ExpertInfo(uid, None), None) for uid in all_expert_uids]],
         )[0]
         true_best_scores = sorted(all_scores.cpu().detach().numpy(), reverse=True)[: len(chosen_experts)]

@@ -205,7 +206,7 @@ def test_determinism(hidden_dim=16):
     ) as server_peer_info:
         dht = DHT(initial_peers=server_peer_info.addrs, start=True)
         expert = create_remote_experts(
-            [RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info)],
+            [ExpertInfo(uid="expert.0", peer_id=server_peer_info.peer_id)],
             dht=dht,
         )[0]

@@ -231,7 +232,7 @@ def test_compute_expert_scores():
         jj = [[2, 2, 1], [0, 1, 2, 0, 1], [0], [1, 2]]
         batch_experts = [
             [
-                RemoteExpert(RemoteExpertInfo(f"expert.{ii[batch_i][expert_i]}.{jj[batch_i][expert_i]}", None), None)
+                RemoteExpert(ExpertInfo(f"expert.{ii[batch_i][expert_i]}.{jj[batch_i][expert_i]}", None), None)
                 for expert_i in range(len(ii[batch_i]))
             ]
             for batch_i in range(len(ii))

+ 4 - 3
tests/test_training.py

@@ -9,7 +9,8 @@ from sklearn.datasets import load_digits

 from hivemind import DHT
 from hivemind.moe.client import RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
-from hivemind.moe.client.expert import RemoteExpertInfo, create_remote_experts
+from hivemind.moe.client.expert import create_remote_experts
+from hivemind.moe.expert_uid import ExpertInfo
 from hivemind.moe.server import background_server
 from hivemind.optim import DecentralizedAdam, DecentralizedSGD

@@ -26,8 +27,8 @@ def test_training(max_steps: int = 100, threshold: float = 0.9):
         dht = DHT(initial_peers=server_peer_info.addrs, start=True)
         expert1, expert2 = create_remote_experts(
             [
-                RemoteExpertInfo(uid="expert.0", peer_info=server_peer_info),
-                RemoteExpertInfo(uid="expert.1", peer_info=server_peer_info),
+                ExpertInfo(uid="expert.0", peer_id=server_peer_info.peer_id),
+                ExpertInfo(uid="expert.1", peer_id=server_peer_info.peer_id),
             ],
             dht=dht,
         )