add benchmarks

Denis Mazur, 4 years ago
commit 87012861ff

+ 34 - 19
benchmarks/benchmark_throughput-p2p.py → benchmarks/benchmark_throughput_p2p.py

@@ -7,7 +7,7 @@ import time
 import torch

 import hivemind
-from hivemind import get_free_port
+from hivemind.dht import DHT
 from hivemind.moe.server import layers
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger
@@ -28,11 +28,20 @@ def print_device_info(device=None):
         logger.info(f"Cached:   {round(torch.cuda.memory_cached(0) / 1024 ** 3, 1)} GB")


-def client_process(can_start, benchmarking_failed, port, num_experts, batch_size, hid_dim, num_batches, backprop=True):
+def client_process(
+    can_start,
+    benchmarking_failed,
+    server_peer_info,
+    num_experts,
+    batch_size,
+    hid_dim,
+    num_batches,
+    backprop=True,
+) -> None:
     torch.set_num_threads(1)
     can_start.wait()
     experts = [
-        hivemind.RemoteExpert(f"expert{i}", endpoint=f"{hivemind.LOCALHOST}:{port}") for i in range(num_experts)
+        hivemind.RemoteExpert(f"expert.{i}", server_peer_info=server_peer_info) for i in range(num_experts)
     ]

     try:
@@ -48,17 +57,16 @@ def client_process(can_start, benchmarking_failed, port, num_experts, batch_size


 def benchmark_throughput(
-    num_experts=16,
+    num_experts=1,
     num_handlers=None,
-    num_clients=128,
+    num_clients=1,
     num_batches_per_client=16,
     expert_cls="ffn",
     hid_dim=1024,
-    batch_size=2048,
+    batch_size=16,
     max_batch_size=None,
     backprop=True,
     device=None,
-    port=None,
 ):
     assert (
         not hasattr(torch.cuda, "is_initialized")
@@ -66,7 +74,6 @@ def benchmark_throughput(
         or torch.device(device) == torch.device("cpu")
     )
     assert expert_cls in layers.name_to_block
-    port = port or get_free_port()
     max_batch_size = max_batch_size or batch_size * 4
     num_handlers = max(1, num_handlers or num_clients // 2)
     benchmarking_failed = mp.Event()
@@ -74,8 +81,12 @@ def benchmark_throughput(
     timestamps = dict(started=time.perf_counter())

     try:
-        # start clients and await server
-        # Note: client processes must be launched BEFORE touching gpu, even torch.cuda.is_available can cause trouble
+        server_dht = DHT(start=True)
+        server_dht_peer_info = hivemind.PeerInfo(
+            peer_id=server_dht.peer_id,
+            addrs=[addr.decapsulate("/p2p/" + addr.get("p2p")) for addr in server_dht.get_visible_maddrs()],
+        )
+
         clients = [
             mp.Process(
                 target=client_process,
@@ -83,30 +94,29 @@ def benchmark_throughput(
                 args=(
                     can_start,
                     benchmarking_failed,
-                    port,
+                    server_dht_peer_info,
                     num_experts,
                     batch_size,
                     hid_dim,
                     num_batches_per_client,
                     backprop,
                 ),
+                daemon=True,
             )
             for i in range(num_clients)
         ]

         for client in clients:
-            client.daemon = True
             client.start()

         timestamps["launched_clients"] = timestamps["began_launching_server"] = time.perf_counter()

-        # start server
         device = device or ("cuda" if torch.cuda.is_available() else "cpu")
         experts = {}
         for i in range(num_experts):
             expert = torch.jit.script(layers.name_to_block[expert_cls](hid_dim))
-            experts[f"expert{i}"] = hivemind.ExpertBackend(
-                name=f"expert{i}",
+            experts[f"expert.{i}"] = hivemind.ExpertBackend(
+                name=f"expert.{i}",
                 expert=expert,
                 optimizer=torch.optim.Adam(expert.parameters()),
                 args_schema=(hivemind.BatchTensorDescriptor(hid_dim),),
@@ -114,11 +124,11 @@ def benchmark_throughput(
                 max_batch_size=max_batch_size,
             )
         timestamps["created_experts"] = time.perf_counter()
+
         server = hivemind.moe.Server(
-            None,
-            experts,
-            listen_on=f"{hivemind.LOCALHOST}:{port}",
-            num_connection_handlers=num_handlers,
+            dht=server_dht,
+            expert_backends=experts,
+            num_connection_handlers=1,  # TODO: support greater number
             device=device,
         )
         server.start()
@@ -127,8 +137,13 @@ def benchmark_throughput(
         can_start.set()

         for client in clients:
+            print("joining clients")
+            print(client.is_alive())
             client.join()
+            print("i'm here")
+
         timestamps["clients_finished"] = time.perf_counter()
+
     except BaseException as e:
         benchmarking_failed.set()
         raise e
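
Taken together, the benchmark's networking changes replace host:port endpoints with libp2p peer identities. A minimal sketch of the new wiring, assembled from the diff above (the uid "expert.0" follows the new dotted naming; treat this as illustrative rather than a tested snippet):

    import hivemind
    from hivemind.dht import DHT

    # Server side: start a DHT node and derive a PeerInfo that clients can dial.
    server_dht = DHT(start=True)
    server_peer_info = hivemind.PeerInfo(
        peer_id=server_dht.peer_id,
        # drop the trailing /p2p/<peer_id> component, keeping only transport addresses
        addrs=[addr.decapsulate("/p2p/" + addr.get("p2p")) for addr in server_dht.get_visible_maddrs()],
    )

    # Client side: dial an expert by peer info instead of f"{LOCALHOST}:{port}".
    expert = hivemind.RemoteExpert("expert.0", server_peer_info=server_peer_info)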

+ 1 - 2
hivemind/hivemind_cli/run_server.py

@@ -17,8 +17,7 @@ def main():
     # fmt:off
     parser = configargparse.ArgParser(default_config_files=["config.yml"])
     parser.add('-c', '--config', required=False, is_config_file=True, help='config file path')
-    parser.add_argument('--listen_on', type=str, default='0.0.0.0:*', required=False,
-                        help="'localhost' for local connections only, '0.0.0.0' for ipv4 '[::]' for ipv6")
+
     parser.add_argument('--num_experts', type=int, default=None, required=False, help="The number of experts to serve")
     parser.add_argument('--expert_pattern', type=str, default=None, required=False,
                         help='all expert uids will follow this pattern, e.g. "myexpert.[0:256].[0:1024]" will'

+ 8 - 1
hivemind/moe/__init__.py

@@ -1,2 +1,9 @@
 from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
-from hivemind.moe.server import ExpertBackend, Server, declare_experts, get_experts, register_expert_class, ConnectionHandler
+from hivemind.moe.server import (
+    ConnectionHandler,
+    ExpertBackend,
+    Server,
+    declare_experts,
+    get_experts,
+    register_expert_class,
+)

+ 15 - 22
hivemind/moe/client/expert.py

@@ -1,32 +1,23 @@
-import asyncio
-from concurrent.futures import Future
-import multiprocessing as mp
 import pickle
-from typing import Any, Dict, Optional, Tuple, Awaitable
-from threading import Thread
+from concurrent.futures import Future
 from queue import Queue
+from threading import Thread
+from typing import Any, Awaitable, Dict, Optional, Tuple

 import torch
 import torch.nn as nn
 from torch.autograd.function import once_differentiable

-
-#from hivemind.moe.server.connection_handler import ConnectionHandler
-from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
-from hivemind.utils import nested_compare, nested_flatten, nested_pack, switch_to_uvloop
+import hivemind
 from hivemind.p2p import P2P, PeerInfo, StubBase
 from hivemind.proto import runtime_pb2
-
+from hivemind.utils import nested_compare, nested_flatten, nested_pack, switch_to_uvloop
+from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor

 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert


-import hivemind
-
-#ConnectionHandlerStub = hivemind.moe.server.connection_handler.ConnectionHandler._stub_type
-
-
-def _get_expert_stub(p2p: P2P, server_peer_info: PeerInfo): # -> ConnectionHandlerStub:
+def _get_expert_stub(p2p: P2P, server_peer_info: PeerInfo):  # -> ConnectionHandlerStub:
     return hivemind.moe.server.connection_handler.ConnectionHandler.get_stub(p2p, server_peer_info.peer_id)


@@ -39,7 +30,7 @@ class RemoteExpert(nn.Module):
     :param uid: unique expert identifier
     """

-    def __init__(self, uid, server_peer_info: PeerInfo, p2p: Optional[P2P] = None):
+    def __init__(self, uid, server_peer_info: PeerInfo, p2p: Optional[P2P] = None, connect: bool = True):
         super().__init__()
         self.uid, self.server_peer_info = uid, server_peer_info
         self._info = None
@@ -49,7 +40,8 @@ class RemoteExpert(nn.Module):
         else:
             self.p2p = p2p

-        _RemoteModuleCall.run_coroutine(self.p2p._client.connect(server_peer_info.peer_id, server_peer_info.addrs))
+        if connect:
+            _RemoteModuleCall.run_coroutine(self.p2p._client.connect(server_peer_info.peer_id, server_peer_info.addrs))

     @property
     def stub(self) -> StubBase:
@@ -66,6 +58,7 @@ class RemoteExpert(nn.Module):

         if not nested_compare(forward_inputs, self.info["forward_schema"]):
             raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
+
         flat_outputs = _RemoteModuleCall.apply(DUMMY, self.uid, self.stub, self.info, *nested_flatten(forward_inputs))

         # Note: we send DUMMY to prevent torch from excluding expert from backward if no other inputs require grad
@@ -100,7 +93,6 @@ class _RemoteModuleCall(torch.autograd.Function):
                 except Exception as e:
                     future.set_exception(e)
                     continue
-
                 future.set_result(result)

         loop.run_until_complete(receive_tasks())
@@ -108,7 +100,7 @@ class _RemoteModuleCall(torch.autograd.Function):
     @classmethod
     def run_coroutine(cls, coro: Awaitable, return_future: bool = False):
         if cls._event_thread is None:
-            cls._event_thread = Thread(target=cls._run)
+            cls._event_thread = Thread(target=cls._run, daemon=True)
             cls._event_thread.start()

         future = Future()
@@ -117,7 +109,8 @@ class _RemoteModuleCall(torch.autograd.Function):
         if return_future:
             return future

-        return future.result()
+        result = future.result()
+        return result

     @classmethod
     def forward(
@@ -125,7 +118,7 @@ class _RemoteModuleCall(torch.autograd.Function):
         ctx,
         dummy: torch.Tensor,
         uid: str,
-        stub,#: ConnectionHandlerStub,
+        stub,  #: ConnectionHandlerStub,
         info: Dict[str, Any],
         *inputs: torch.Tensor,
     ) -> Tuple[torch.Tensor, ...]:
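
The run_coroutine changes above rely on a common pattern: a single daemon thread owns an asyncio event loop, and synchronous callers submit coroutines through a queue and block on a concurrent.futures.Future. A self-contained sketch of that pattern (not hivemind's exact code; the names below are illustrative):

    import asyncio
    from concurrent.futures import Future
    from queue import Queue
    from threading import Thread

    _task_queue: Queue = Queue()

    def _loop_worker() -> None:
        # The daemon thread owns its own event loop and serves submitted coroutines forever.
        loop = asyncio.new_event_loop()

        async def receive_tasks():
            while True:
                # The blocking Queue.get runs in the default executor so the loop stays responsive.
                coro, future = await loop.run_in_executor(None, _task_queue.get)
                try:
                    result = await coro
                except Exception as e:
                    future.set_exception(e)
                    continue
                future.set_result(result)

        loop.run_until_complete(receive_tasks())

    _event_thread = Thread(target=_loop_worker, daemon=True)  # daemon=True mirrors the fix above
    _event_thread.start()

    def run_coroutine(coro, return_future: bool = False):
        # Synchronous callers block on the Future unless they ask for it back.
        future = Future()
        _task_queue.put((coro, future))
        return future if return_future else future.result()

Marking the thread as a daemon matters here: without it, the forever-running loop thread keeps the process alive after the main thread exits.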

+ 2 - 9
hivemind/moe/server/__init__.py

@@ -58,7 +58,6 @@ class Server(threading.Thread):
         self,
         dht: Optional[DHT],
         expert_backends: Dict[str, ExpertBackend],
-        listen_on: Endpoint = "0.0.0.0:*",
         num_connection_handlers: int = 1,
         update_period: int = 30,
         start=False,
@@ -67,9 +66,6 @@ class Server(threading.Thread):
     ):
         super().__init__()
         self.dht, self.experts, self.update_period = dht, expert_backends, update_period
-        if get_port(listen_on) is None:
-            listen_on = replace_port(listen_on, new_port=get_free_port())
-        self.listen_on, self.port = listen_on, get_port(listen_on)

         self.conn_handlers = [ConnectionHandler(dht, self.experts) for _ in range(1)]
         if checkpoint_dir is not None:
@@ -82,7 +78,7 @@ class Server(threading.Thread):
             self.dht_handler_thread = DHTHandlerThread(
                 experts=self.experts,
                 dht=self.dht,
-                endpoint=self.listen_on,
+                peer_id=self.dht.peer_id,
                 update_period=self.update_period,
                 daemon=True,
             )
@@ -93,7 +89,6 @@ class Server(threading.Thread):
     @classmethod
     def create(
         cls,
-        listen_on="0.0.0.0:*",
         num_experts: int = None,
         expert_uids: str = None,
         expert_pattern: str = None,
@@ -222,7 +217,6 @@ class Server(threading.Thread):
         return cls(
             dht,
             experts,
-            listen_on=listen_on,
             num_connection_handlers=num_handlers,
             device=device,
             checkpoint_dir=checkpoint_dir,
@@ -235,8 +229,7 @@ class Server(threading.Thread):
         Starts Server in the current thread. Initializes dht if necessary, starts connection handlers,
         runs Runtime (self.runtime) to process incoming requests.
         """
-        logger.info(f"Server started at {self.listen_on}")
-        logger.info(f"Got {len(self.experts)} experts:")
+        logger.info(f"Server started with {len(self.experts)} experts:")
         for expert_name, backend in self.experts.items():
             num_parameters = sum(p.numel() for p in backend.expert.parameters() if p.requires_grad)
             logger.info(f"{expert_name}: {backend.expert.__class__.__name__}, {num_parameters} parameters")

+ 6 - 3
hivemind/moe/server/connection_handler.py

@@ -5,11 +5,11 @@ from typing import Dict

 import torch

-from hivemind.p2p import P2PContext, ServicerBase
 from hivemind.dht import DHT
 from hivemind.moe.server.expert_backend import ExpertBackend
+from hivemind.p2p import P2PContext, ServicerBase
 from hivemind.proto import runtime_pb2
-from hivemind.utils import get_logger, nested_flatten, MPFuture
+from hivemind.utils import MPFuture, get_logger, nested_flatten
 from hivemind.utils.asyncio import switch_to_uvloop
 from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor

@@ -45,6 +45,7 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
             except Exception as e:
                 self.ready.set_exception(e)
                 return
+
         self.ready.set_result(None)

         try:
@@ -65,7 +66,9 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):

         return runtime_pb2.ExpertResponse(tensors=serialized_response)

-    async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
+    async def rpc_backward(
+        self, request: runtime_pb2.ExpertRequest, context: P2PContext
+    ) -> runtime_pb2.ExpertResponse:
         inputs_and_grad_outputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         future = self.experts[request.uid].backward_pool.submit_task(*inputs_and_grad_outputs)
         serialized_response = [
+ 13 - 13
hivemind/moe/server/dht_handler.py

@@ -14,27 +14,27 @@ from hivemind.moe.server.expert_uid import (
     is_valid_uid,
     split_uid,
 )
-from hivemind.utils import Endpoint, get_dht_time, get_port
+from hivemind.p2p import PeerID, PeerInfo
+from hivemind.utils import get_dht_time


 class DHTHandlerThread(threading.Thread):
-    def __init__(self, experts, dht: DHT, endpoint: Endpoint, update_period: int = 5, **kwargs):
+    def __init__(self, experts, dht: DHT, peer_id: PeerID, update_period: int = 5, **kwargs):
         super().__init__(**kwargs)
-        assert get_port(endpoint) is not None
-        self.endpoint = endpoint
+        self.peer_id = peer_id
         self.experts = experts
         self.dht = dht
         self.update_period = update_period
         self.stop = threading.Event()

     def run(self) -> None:
-        declare_experts(self.dht, self.experts.keys(), self.endpoint)
+        declare_experts(self.dht, self.experts.keys(), self.peer_id)
         while not self.stop.wait(self.update_period):
-            declare_experts(self.dht, self.experts.keys(), self.endpoint)
+            declare_experts(self.dht, self.experts.keys(), self.peer_id)


 def declare_experts(
-    dht: DHT, uids: Sequence[ExpertUID], endpoint: Endpoint, expiration: DHTExpiration = 300, wait: bool = True
+    dht: DHT, uids: Sequence[ExpertUID], peer_id: PeerID, expiration: DHTExpiration = 300, wait: bool = True
 ) -> Dict[ExpertUID, bool]:
     """
     Make experts visible to all DHT peers; update timestamps if declared previously.
@@ -49,22 +49,22 @@ def declare_experts(
     for uid in uids:
         assert is_valid_uid(uid), f"{uid} is not a valid expert uid. All uids must follow {UID_PATTERN.pattern}"
     return dht.run_coroutine(
-        partial(_declare_experts, uids=list(uids), endpoint=endpoint, expiration=expiration), return_future=not wait
+        partial(_declare_experts, uids=list(uids), peer_id=peer_id, expiration=expiration), return_future=not wait
     )


 async def _declare_experts(
-    dht: DHT, node: DHTNode, uids: List[ExpertUID], endpoint: Endpoint, expiration: DHTExpiration
+    dht: DHT, node: DHTNode, uids: List[ExpertUID], peer_id: PeerID, expiration: DHTExpiration
 ) -> Dict[ExpertUID, bool]:
     num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
     expiration_time = get_dht_time() + expiration
     data_to_store: Dict[Tuple[ExpertPrefix, Optional[Coordinate]], DHTValue] = {}
     for uid in uids:
-        data_to_store[uid, None] = endpoint
+        data_to_store[uid, None] = peer_id.to_base58()
         prefix = uid if uid.count(UID_DELIMITER) > 1 else f"{uid}{UID_DELIMITER}{FLAT_EXPERT}"
         for i in range(prefix.count(UID_DELIMITER) - 1):
             prefix, last_coord = split_uid(prefix)
-            data_to_store[prefix, last_coord] = [uid, endpoint]
+            data_to_store[prefix, last_coord] = [uid, peer_id.to_base58()]

     keys, maybe_subkeys, values = zip(*((key, subkey, value) for (key, subkey), value in data_to_store.items()))
     store_ok = await node.store_many(keys, values, expiration_time, subkeys=maybe_subkeys, num_workers=num_workers)
@@ -94,6 +94,6 @@ async def _get_experts(

     experts: List[Optional[RemoteExpert]] = [None] * len(uids)
     for i, uid in enumerate(uids):
-        if found[uid] is not None and isinstance(found[uid].value, Endpoint):
-            experts[i] = RemoteExpert(uid, found[uid].value)
+        if found[uid] is not None and isinstance(found[uid].value, PeerID):
+            experts[i] = RemoteExpert(uid, PeerInfo(peer_id=found[uid].value, addrs=[]))
     return experts
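
With endpoints gone, the DHT stores a base58 peer id under each expert uid. A minimal sketch of the intended declare/lookup round trip (assuming a single in-process DHT; "expert.0" is illustrative):

    from hivemind.dht import DHT
    from hivemind.moe.server import declare_experts, get_experts

    dht = DHT(start=True)

    # Server side: publish expert uids keyed by this peer's id (stored as base58 text).
    declare_experts(dht, ["expert.0"], dht.peer_id)

    # Client side: resolve the uid back to a RemoteExpert that carries a PeerInfo.
    expert, = get_experts(dht, ["expert.0"])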