Browse Source

Move SequenceManagerConfig -> ClientConfig, petals.dht_utils -> petals.utils.dht (#463)

Alexander Borzunov 2 năm trước cách đây
mục cha
commit
063e94b4c8

+ 2 - 2
src/petals/client/__init__.py

@@ -1,4 +1,4 @@
+from petals.client.config import ClientConfig
 from petals.client.inference_session import InferenceSession
 from petals.client.remote_sequential import RemoteSequential
-from petals.client.routing.sequence_manager import RemoteSequenceManager
-from petals.client.routing.spending_policy import NoSpendingPolicy, SpendingPolicyBase
+from petals.client.routing import NoSpendingPolicy, RemoteSequenceManager, SpendingPolicyBase

+ 31 - 0
src/petals/client/config.py

@@ -0,0 +1,31 @@
+import dataclasses
+from typing import Optional, Sequence, Union
+
+from hivemind import PeerID
+
+from petals.constants import PUBLIC_INITIAL_PEERS
+
+
+@dataclasses.dataclass
+class ClientConfig:
+    initial_peers: Sequence[str] = tuple(PUBLIC_INITIAL_PEERS)  # a list of initial peers for hivemind DHT
+    dht_prefix: Optional[str] = None  # a prefix for all dht keys that correspond to this model (default: model name)
+    daemon_startup_timeout: int = 60  # timeout for the libp2p daemon connecting to initial peers
+
+    show_route: Union[str, bool] = "inference"  # show chosen route through servers. one of [False, "inference", True]
+    allowed_servers: Optional[Sequence[Union[PeerID, str]]] = None  # if defined, send requests only to these servers
+    blocked_servers: Optional[Sequence[Union[PeerID, str]]] = None  # if defined, do not use these servers
+    use_server_to_server: bool = True  # Use direct server-to-server communication
+
+    connect_timeout: float = 5  # timeout for opening a connection
+    request_timeout: float = 3 * 60  # timeout for forward/backward/inference requests
+    update_period: float = 60  # refresh DHT information once in this many seconds
+
+    max_retries: Optional[int] = None  # max number retries before the client raises an exception (default: inf)
+    min_backoff: float = 1  # after a repeated failure, sleep for this many seconds times 2 ** (num_failures - 1)
+    max_backoff: float = 60  # limit maximal sleep time between retries to this value
+    ban_timeout: float = 15  # when a remote peer fails to respond, prevent routing to that peer for this many seconds
+    active_adapter: Optional[str] = None  # name of active LoRA adapter (usually, Hugging Face repo)
+
+    max_pinged: int = 3  # max servers to ping from each sequence side, per update
+    ping_timeout: float = 2  # max time to wait for pings, per update

+ 4 - 3
src/petals/client/inference_session.py

@@ -13,7 +13,8 @@ from hivemind.p2p import P2P
 from hivemind.proto import runtime_pb2
 from hivemind.utils.tensor_descr import BatchTensorDescriptor
 
-from petals.client.routing.sequence_manager import RemoteSequenceManager, SequenceManagerConfig, maybe_log_traceback
+from petals.client.config import ClientConfig
+from petals.client.routing import RemoteSequenceManager, maybe_log_traceback
 from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
 from petals.server.handler import TransformerConnectionHandler
 from petals.utils.misc import DUMMY, DUMMY_INT64, is_dummy
@@ -31,7 +32,7 @@ class _ServerInferenceSession:
 
     def __init__(
         self,
-        config: SequenceManagerConfig,
+        config: ClientConfig,
         span: RemoteSpanInfo,
         uid: ModuleUID,
         rpc_info: RPCInfo,
@@ -58,7 +59,7 @@ class _ServerInferenceSession:
     @classmethod
     async def create(
         cls,
-        config: SequenceManagerConfig,
+        config: ClientConfig,
         p2p: P2P,
         span: RemoteSpanInfo,
         uid: ModuleUID,

+ 7 - 7
src/petals/client/remote_forward_backward.py

@@ -14,12 +14,12 @@ from hivemind.utils.asyncio import aiter_with_timeout, iter_as_aiter
 from hivemind.utils.streaming import split_for_streaming
 from hivemind.utils.tensor_descr import BatchTensorDescriptor
 
-from petals.client.routing.sequence_manager import SequenceManagerConfig
+from petals.client.config import ClientConfig
 from petals.data_structures import ModuleUID, RPCInfo
 
 
 async def _forward_unary(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: SequenceManagerConfig, **kwargs
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
 ) -> List[torch.Tensor]:
     outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
         runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs),
@@ -29,7 +29,7 @@ async def _forward_unary(
 
 
 async def _backward_unary(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: SequenceManagerConfig, **kwargs
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
 ) -> List[torch.Tensor]:
     grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
         runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs),
@@ -39,7 +39,7 @@ async def _backward_unary(
 
 
 async def _forward_stream(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: SequenceManagerConfig, **kwargs
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
 ) -> List[torch.Tensor]:
     parts = (
         runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
@@ -52,7 +52,7 @@ async def _forward_stream(
 
 
 async def _backward_stream(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: SequenceManagerConfig, **kwargs
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
 ) -> List[torch.Tensor]:
     parts = (
         runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
@@ -69,7 +69,7 @@ async def run_remote_forward(
     stub: StubBase,
     rpc_info: RPCInfo,
     *inputs: torch.Tensor,
-    config: SequenceManagerConfig,
+    config: ClientConfig,
     metadata: Optional[bytes] = None,
     **kwargs,
 ) -> Tuple[torch.Tensor, ...]:
@@ -115,7 +115,7 @@ async def run_remote_backward(
     stub: StubBase,
     rpc_info: RPCInfo,
     *inputs_and_grad_outputs: torch.Tensor,
-    config: SequenceManagerConfig,
+    config: ClientConfig,
     metadata: Optional[bytes] = None,
     **kwargs,
 ) -> Sequence[torch.Tensor]:

+ 3 - 2
src/petals/client/remote_sequential.py

@@ -6,8 +6,9 @@ import torch
 from hivemind import DHT, get_logger
 from torch import nn
 
+from petals.client.config import ClientConfig
 from petals.client.inference_session import InferenceSession
-from petals.client.routing.sequence_manager import RemoteSequenceManager, SequenceManagerConfig
+from petals.client.routing import RemoteSequenceManager
 from petals.client.sequential_autograd import _RemoteSequentialAutogradFunction
 from petals.data_structures import UID_DELIMITER
 from petals.utils.misc import DUMMY
@@ -22,7 +23,7 @@ class RemoteSequential(nn.Module):
 
     def __init__(
         self,
-        config: SequenceManagerConfig,
+        config: ClientConfig,
         *,
         sequence_manager: Optional[RemoteSequenceManager] = None,
         dht: Optional[DHT] = None,

+ 2 - 1
src/petals/client/routing/__init__.py

@@ -1 +1,2 @@
-"""Client-side functions responsible for choosing the best server, """
+from petals.client.routing.sequence_manager import RemoteSequenceManager, maybe_log_traceback
+from petals.client.routing.spending_policy import NoSpendingPolicy, SpendingPolicyBase

+ 16 - 29
src/petals/client/routing/sequence_manager.py

@@ -7,7 +7,8 @@ import logging
 import random
 import threading
 import time
-from typing import Any, Collection, Dict, List, Optional, Sequence, Set, Union
+import warnings
+from typing import Any, Dict, List, Optional, Sequence, Set, Union
 from weakref import WeakMethod
 
 import dijkstar
@@ -18,41 +19,27 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.proto import runtime_pb2
 from hivemind.utils.logging import get_logger
 
-import petals.dht_utils
+from petals.client.config import ClientConfig
 from petals.client.routing.sequence_info import RemoteSequenceInfo
 from petals.client.routing.spending_policy import NoSpendingPolicy
-from petals.constants import PUBLIC_INITIAL_PEERS
 from petals.data_structures import ModuleUID, RemoteSpanInfo, ServerState
 from petals.server.handler import TransformerConnectionHandler
+from petals.utils.dht import get_remote_module_infos
 from petals.utils.ping import PingAggregator
 from petals.utils.random import sample_up_to
 
 logger = get_logger(__name__)
 
 
-@dataclasses.dataclass
-class SequenceManagerConfig:
-    initial_peers: Sequence[str] = tuple(PUBLIC_INITIAL_PEERS)  # a list of initial peers for hivemind DHT
-    dht_prefix: Optional[str] = None  # a prefix for all dht keys that correspond to this model (default: model name)
-    daemon_startup_timeout: int = 60  # timeout for the libp2p daemon connecting to initial peers
-
-    show_route: Union[str, bool] = "inference"  # show chosen route through servers. one of [False, "inference", True]
-    allowed_servers: Optional[Collection[Union[PeerID, str]]] = None  # if defined, send requests only to these servers
-    blocked_servers: Optional[Collection[Union[PeerID, str]]] = None  # if defined, do not use these servers
-    use_server_to_server: bool = True  # Use direct server-to-server communication
-
-    connect_timeout: float = 5  # timeout for opening a connection
-    request_timeout: float = 3 * 60  # timeout for forward/backward/inference requests
-    update_period: float = 60  # refresh DHT information once in this many seconds
-
-    max_retries: Optional[int] = None  # max number retries before the client raises an exception (default: inf)
-    min_backoff: float = 1  # after a repeated failure, sleep for this many seconds times 2 ** (num_failures - 1)
-    max_backoff: float = 60  # limit maximal sleep time between retries to this value
-    ban_timeout: float = 15  # when a remote peer fails to respond, prevent routing to that peer for this many seconds
-    active_adapter: Optional[str] = None  # name of active LoRA adapter (usually, Hugging Face repo)
-
-    max_pinged: int = 3  # max servers to ping from each sequence side, per update
-    ping_timeout: float = 2  # max time to wait for pings, per update
+class SequenceManagerConfig(ClientConfig):
+    def __init__(self, *args, **kwargs):
+        warnings.warn(
+            "petals.client.routing.SequenceManagerConfig has been moved to petals.ClientConfig. "
+            "This alias will be removed in Petals 2.2.0+",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        super().__init__(*args, **kwargs)
 
 
 @dataclasses.dataclass
@@ -83,7 +70,7 @@ class RemoteSequenceManager:
 
     def __init__(
         self,
-        config: SequenceManagerConfig,
+        config: ClientConfig,
         block_uids: Sequence[ModuleUID],
         *,
         dht: Optional[DHT] = None,
@@ -133,7 +120,7 @@ class RemoteSequenceManager:
             self._need_latest_infos = True
 
     @staticmethod
-    def _peer_ids_to_set(peer_ids: Optional[Collection[Union[PeerID, str]]]) -> Optional[Set[PeerID]]:
+    def _peer_ids_to_set(peer_ids: Optional[Sequence[Union[PeerID, str]]]) -> Optional[Set[PeerID]]:
         if peer_ids is None:
             return None
 
@@ -354,7 +341,7 @@ class RemoteSequenceManager:
     def _update(self):
         """Perform an immediate and synchronous refresh, may take time"""
 
-        new_block_infos = petals.dht_utils.get_remote_module_infos(
+        new_block_infos = get_remote_module_infos(
             self.dht, self.block_uids, active_adapter=self.config.active_adapter, latest=True
         )
 

+ 1 - 1
src/petals/client/sequential_autograd.py

@@ -12,7 +12,7 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.utils.logging import get_logger
 
 from petals.client.remote_forward_backward import run_remote_backward, run_remote_forward
-from petals.client.routing.sequence_manager import RemoteSequenceManager, maybe_log_traceback
+from petals.client.routing import RemoteSequenceManager, maybe_log_traceback
 from petals.data_structures import CHAIN_DELIMITER, RemoteSpanInfo
 from petals.server.handler import TransformerConnectionHandler
 from petals.utils.misc import DUMMY, is_dummy

+ 2 - 2
src/petals/data_structures.py

@@ -6,8 +6,6 @@ import pydantic
 from hivemind import PeerID
 from hivemind.moe.expert_uid import ExpertUID
 
-from petals.server.memory_cache import Handle
-
 ModuleUID = str
 UID_DELIMITER = "."  # delimits parts of one module uid, e.g. "bloom.transformer.h.4.self_attention"
 CHAIN_DELIMITER = " "  # delimits multiple uids in a sequence, e.g. "bloom.layer3 bloom.layer4"
@@ -78,6 +76,8 @@ class RemoteSpanInfo:
 
 RPCInfo = Dict[str, Any]
 
+Handle = int
+
 
 @dataclasses.dataclass(frozen=True)
 class InferenceMetadata:

+ 7 - 122
src/petals/dht_utils.py

@@ -1,124 +1,9 @@
-"""
-Utilities for declaring and retrieving active model layers using a shared DHT.
-"""
-from __future__ import annotations
+import warnings
 
-import math
-from functools import partial
-from typing import Dict, List, Optional, Sequence, Union
+warnings.warn(
+    "petals.dht_utils has been moved to petals.utils.dht. This alias will be removed in Petals 2.2.0+",
+    DeprecationWarning,
+    stacklevel=2,
+)
 
-from hivemind.dht import DHT, DHTNode, DHTValue
-from hivemind.p2p import PeerID
-from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger
-
-from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo
-
-logger = get_logger(__name__)
-
-
-def declare_active_modules(
-    dht: DHT,
-    uids: Sequence[ModuleUID],
-    server_info: ServerInfo,
-    expiration_time: DHTExpiration,
-    wait: bool = True,
-) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:
-    """
-    Declare that your node serves the specified modules; update timestamps if declared previously
-
-    :param uids: a list of module ids to declare
-    :param wait: if True, awaits for declaration to finish, otherwise runs in background
-    :param throughput: specify your performance in terms of compute throughput
-    :param expiration_time: declared modules will be visible for this many seconds
-    :returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
-    """
-    if isinstance(uids, str):
-        uids = [uids]
-    if not isinstance(uids, list):
-        uids = list(uids)
-    for uid in uids:
-        assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid
-
-    return dht.run_coroutine(
-        partial(_declare_active_modules, uids=uids, server_info=server_info, expiration_time=expiration_time),
-        return_future=not wait,
-    )
-
-
-async def _declare_active_modules(
-    dht: DHT,
-    node: DHTNode,
-    uids: List[ModuleUID],
-    server_info: ServerInfo,
-    expiration_time: DHTExpiration,
-) -> Dict[ModuleUID, bool]:
-    num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
-    return await node.store_many(
-        keys=uids,
-        subkeys=[dht.peer_id.to_base58()] * len(uids),
-        values=[server_info.to_tuple()] * len(uids),
-        expiration_time=expiration_time,
-        num_workers=num_workers,
-    )
-
-
-def get_remote_module_infos(
-    dht: DHT,
-    uids: Sequence[ModuleUID],
-    expiration_time: Optional[DHTExpiration] = None,
-    active_adapter: Optional[str] = None,
-    *,
-    latest: bool = False,
-    return_future: bool = False,
-) -> Union[List[Optional[RemoteModuleInfo]], MPFuture]:
-    return dht.run_coroutine(
-        partial(
-            _get_remote_module_infos,
-            uids=uids,
-            active_adapter=active_adapter,
-            expiration_time=expiration_time,
-            latest=latest,
-        ),
-        return_future=return_future,
-    )
-
-
-async def _get_remote_module_infos(
-    dht: DHT,
-    node: DHTNode,
-    uids: List[ModuleUID],
-    active_adapter: Optional[str],
-    expiration_time: Optional[DHTExpiration],
-    latest: bool,
-) -> List[Optional[RemoteModuleInfo]]:
-    if latest:
-        assert expiration_time is None, "You should define either `expiration_time` or `latest`, not both"
-        expiration_time = math.inf
-    elif expiration_time is None:
-        expiration_time = get_dht_time()
-    num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
-    found: Dict[ModuleUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
-
-    modules: List[Optional[RemoteModuleInfo]] = [None] * len(uids)
-    for i, uid in enumerate(uids):
-        metadata = found[uid]
-        if metadata is None or not isinstance(metadata.value, dict):
-            if metadata is not None:
-                logger.warning(f"Incorrect metadata for {uid}: {metadata}")
-            continue
-        servers = {}
-        for peer_id, server_info in metadata.value.items():
-            try:
-                peer_id = PeerID.from_base58(peer_id)
-                server_info = ServerInfo.from_tuple(server_info.value)
-
-                if active_adapter and active_adapter not in server_info.adapters:
-                    logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}")
-                    continue
-
-                servers[peer_id] = server_info
-            except (TypeError, ValueError) as e:
-                logger.warning(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")
-        if servers:
-            modules[i] = RemoteModuleInfo(uid, servers)
-    return modules
+from petals.utils.dht import *

+ 2 - 2
src/petals/models/bloom/config.py

@@ -5,15 +5,15 @@ from hivemind import get_logger
 from transformers.models.bloom import BloomConfig
 from transformers.models.bloom.modeling_bloom import BloomAttention
 
+from petals.client.config import ClientConfig
 from petals.client.lm_head import LMHeadConfig
 from petals.client.ptune import PTuneConfig
-from petals.client.routing.sequence_manager import SequenceManagerConfig
 from petals.models.bloom.block import WrappedBloomBlock
 
 logger = get_logger(__name__)
 
 
-class DistributedBloomConfig(BloomConfig, SequenceManagerConfig, PTuneConfig, LMHeadConfig):
+class DistributedBloomConfig(BloomConfig, ClientConfig, PTuneConfig, LMHeadConfig):
     block_class = WrappedBloomBlock
     attn_class = BloomAttention
     block_prefix = "h"

+ 2 - 2
src/petals/models/llama/config.py

@@ -5,15 +5,15 @@ from hivemind import get_logger
 from transformers.models.llama import LlamaConfig
 from transformers.models.llama.modeling_llama import LlamaAttention
 
+from petals.client.config import ClientConfig
 from petals.client.lm_head import LMHeadConfig
 from petals.client.ptune import PTuneConfig
-from petals.client.routing.sequence_manager import SequenceManagerConfig
 from petals.models.llama.block import WrappedLlamaBlock
 
 logger = get_logger(__name__)
 
 
-class DistributedLlamaConfig(LlamaConfig, SequenceManagerConfig, PTuneConfig, LMHeadConfig):
+class DistributedLlamaConfig(LlamaConfig, ClientConfig, PTuneConfig, LMHeadConfig):
     block_class = WrappedLlamaBlock
     attn_class = LlamaAttention
     block_prefix = "model.layers"

+ 1 - 2
src/petals/server/block_functions.py

@@ -12,9 +12,8 @@ from hivemind.proto import runtime_pb2
 from hivemind.utils.logging import get_logger
 from hivemind.utils.nested import nested_flatten
 
-from petals.data_structures import InferenceMetadata
+from petals.data_structures import Handle, InferenceMetadata
 from petals.server.backend import TransformerBackend
-from petals.server.memory_cache import Handle
 from petals.server.task_pool import PrioritizedTaskPool
 from petals.server.task_prioritizer import TaskPrioritizerBase
 from petals.utils.convert_block import QuantType

+ 1 - 2
src/petals/server/handler.py

@@ -29,10 +29,9 @@ from hivemind.utils.logging import get_logger
 from hivemind.utils.streaming import split_for_streaming
 
 import petals
-from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID
+from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, Handle, ModuleUID
 from petals.server.backend import TransformerBackend
 from petals.server.block_functions import iterate_rpc_inference, run_rpc_backward, run_rpc_forward
-from petals.server.memory_cache import Handle
 from petals.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase
 from petals.utils.convert_block import QuantType
 

+ 1 - 2
src/petals/server/memory_cache.py

@@ -16,12 +16,11 @@ import hivemind
 import torch
 from hivemind.utils import TensorDescriptor, get_logger
 
+from petals.data_structures import Handle
 from petals.utils.asyncio import shield_and_wait
 
 logger = get_logger(__name__)
 
-Handle = int
-
 
 class MemoryCache:
     """A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs"""

+ 1 - 1
src/petals/server/server.py

@@ -20,7 +20,6 @@ from transformers import PretrainedConfig
 import petals
 from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
 from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerInfo, ServerState
-from petals.dht_utils import declare_active_modules, get_remote_module_infos
 from petals.server import block_selection
 from petals.server.backend import TransformerBackend, merge_inference_pools_inplace
 from petals.server.block_utils import get_block_size, resolve_block_dtype
@@ -31,6 +30,7 @@ from petals.server.reachability import ReachabilityProtocol, check_direct_reacha
 from petals.server.throughput import get_dtype_name, get_server_throughput
 from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.convert_block import QuantType, check_device_balance, convert_block
+from petals.utils.dht import declare_active_modules, get_remote_module_infos
 from petals.utils.ping import PingAggregator
 from petals.utils.random import sample_up_to
 from petals.utils.version import get_compatible_model_repo

+ 1 - 0
src/petals/utils/__init__.py

@@ -4,3 +4,4 @@ from petals.utils.auto_config import (
     AutoDistributedModelForCausalLM,
     AutoDistributedModelForSequenceClassification,
 )
+from petals.utils.dht import declare_active_modules, get_remote_module_infos

+ 124 - 0
src/petals/utils/dht.py

@@ -0,0 +1,124 @@
+"""
+Utilities for declaring and retrieving active model layers using a shared DHT.
+"""
+from __future__ import annotations
+
+import math
+from functools import partial
+from typing import Dict, List, Optional, Sequence, Union
+
+from hivemind.dht import DHT, DHTNode, DHTValue
+from hivemind.p2p import PeerID
+from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger
+
+from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo
+
+logger = get_logger(__name__)
+
+
+def declare_active_modules(
+    dht: DHT,
+    uids: Sequence[ModuleUID],
+    server_info: ServerInfo,
+    expiration_time: DHTExpiration,
+    wait: bool = True,
+) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:
+    """
+    Declare that your node serves the specified modules; update timestamps if declared previously
+
+    :param uids: a list of module ids to declare
+    :param wait: if True, awaits for declaration to finish, otherwise runs in background
+    :param throughput: specify your performance in terms of compute throughput
+    :param expiration_time: declared modules will be visible for this many seconds
+    :returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
+    """
+    if isinstance(uids, str):
+        uids = [uids]
+    if not isinstance(uids, list):
+        uids = list(uids)
+    for uid in uids:
+        assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid
+
+    return dht.run_coroutine(
+        partial(_declare_active_modules, uids=uids, server_info=server_info, expiration_time=expiration_time),
+        return_future=not wait,
+    )
+
+
+async def _declare_active_modules(
+    dht: DHT,
+    node: DHTNode,
+    uids: List[ModuleUID],
+    server_info: ServerInfo,
+    expiration_time: DHTExpiration,
+) -> Dict[ModuleUID, bool]:
+    num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
+    return await node.store_many(
+        keys=uids,
+        subkeys=[dht.peer_id.to_base58()] * len(uids),
+        values=[server_info.to_tuple()] * len(uids),
+        expiration_time=expiration_time,
+        num_workers=num_workers,
+    )
+
+
+def get_remote_module_infos(
+    dht: DHT,
+    uids: Sequence[ModuleUID],
+    expiration_time: Optional[DHTExpiration] = None,
+    active_adapter: Optional[str] = None,
+    *,
+    latest: bool = False,
+    return_future: bool = False,
+) -> Union[List[Optional[RemoteModuleInfo]], MPFuture]:
+    return dht.run_coroutine(
+        partial(
+            _get_remote_module_infos,
+            uids=uids,
+            active_adapter=active_adapter,
+            expiration_time=expiration_time,
+            latest=latest,
+        ),
+        return_future=return_future,
+    )
+
+
+async def _get_remote_module_infos(
+    dht: DHT,
+    node: DHTNode,
+    uids: List[ModuleUID],
+    active_adapter: Optional[str],
+    expiration_time: Optional[DHTExpiration],
+    latest: bool,
+) -> List[Optional[RemoteModuleInfo]]:
+    if latest:
+        assert expiration_time is None, "You should define either `expiration_time` or `latest`, not both"
+        expiration_time = math.inf
+    elif expiration_time is None:
+        expiration_time = get_dht_time()
+    num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
+    found: Dict[ModuleUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
+
+    modules: List[Optional[RemoteModuleInfo]] = [None] * len(uids)
+    for i, uid in enumerate(uids):
+        metadata = found[uid]
+        if metadata is None or not isinstance(metadata.value, dict):
+            if metadata is not None:
+                logger.warning(f"Incorrect metadata for {uid}: {metadata}")
+            continue
+        servers = {}
+        for peer_id, server_info in metadata.value.items():
+            try:
+                peer_id = PeerID.from_base58(peer_id)
+                server_info = ServerInfo.from_tuple(server_info.value)
+
+                if active_adapter and active_adapter not in server_info.adapters:
+                    logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}")
+                    continue
+
+                servers[peer_id] = server_info
+            except (TypeError, ValueError) as e:
+                logger.warning(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")
+        if servers:
+            modules[i] = RemoteModuleInfo(uid, servers)
+    return modules