
Share more info about a server in DHT (#355)

Alexander Borzunov · 2 years ago
commit 2c8959e713

+ 1 - 1
setup.cfg

@@ -38,7 +38,7 @@ install_requires =
     tokenizers>=0.13.3
     transformers>=4.30.1,<5.0.0
     speedtest-cli==2.1.3
-    pydantic>=1.8.1,<2.0  # 2.0 is incompatible with hivemind==1.1.8
+    pydantic>=1.10,<2.0  # 2.0 is incompatible with hivemind==1.1.8
     hivemind==1.1.8
     tensor_parallel==1.0.23
     humanfriendly

+ 1 - 1
src/petals/__init__.py

@@ -11,7 +11,7 @@ from petals.models import *
 from petals.utils import *
 from petals.utils.logging import initialize_logs as _initialize_logs
 
-__version__ = "1.2.0.dev1"
+__version__ = "1.2.0.dev2"
 
 
 if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):

+ 3 - 2
src/petals/cli/run_server.py

@@ -146,8 +146,9 @@ def main():
                         help="Skip checking this server's reachability via health.petals.ml "
                              "when connecting to the public swarm. If you connect to a private swarm, "
                              "the check is skipped by default. Use this option only if you know what you are doing")
-    
-    parser.add_argument("--adapters", nargs='+', default=None, help="List of pretrained LoRA adapters that can be used for inference or training.")
+
+    parser.add_argument("--adapters", nargs='+', default=(),
+                        help="List of pre-loaded LoRA adapters that can be used for inference or training")
 
     # fmt:on
     args = vars(parser.parse_args())

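A minimal argparse sketch (not from this commit) of how the new --adapters flag behaves; the adapter names are placeholders. The default changed from None to an empty tuple, presumably so downstream code can iterate over the value without a None check.

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--adapters", nargs='+', default=(),
                        help="List of pre-loaded LoRA adapters that can be used for inference or training")

    args = parser.parse_args(["--adapters", "user/lora-a", "user/lora-b"])
    assert args.adapters == ["user/lora-a", "user/lora-b"]   # values come back as a list

    args = parser.parse_args([])
    assert args.adapters == ()                               # flag omitted: empty tuple, not None
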
+ 25 - 8
src/petals/data_structures.py

@@ -1,10 +1,8 @@
-from __future__ import annotations
-
 import dataclasses
-from dataclasses import dataclass
 from enum import Enum
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Sequence, Tuple
 
+import pydantic
 from hivemind import PeerID
 from hivemind.moe.expert_uid import ExpertUID
 
@@ -21,13 +19,32 @@ class ServerState(Enum):
     ONLINE = 2
 
 
-@dataclass
+@pydantic.dataclasses.dataclass
 class ServerInfo:
     state: ServerState
-    throughput: float
+    throughput: pydantic.confloat(ge=0, allow_inf_nan=False, strict=True)
+
+    adapters: Sequence[str] = ()
+    version: Optional[str] = None
+    torch_dtype: Optional[str] = None
+    quant_type: Optional[str] = None
+    using_relay: Optional[bool] = None
+    cache_tokens_left: Optional[pydantic.conint(ge=0, strict=True)] = None
+
+    def to_tuple(self) -> Tuple[int, float, dict]:
+        extra_info = dataclasses.asdict(self)
+        del extra_info["state"], extra_info["throughput"]
+        return (self.state.value, self.throughput, extra_info)
+
+    @classmethod
+    def from_tuple(cls, source: tuple):
+        state, throughput = source[:2]
+        extra_info = source[2] if len(source) > 2 else {}
+        # pydantic will validate existing fields and ignore extra ones
+        return cls(state=ServerState(state), throughput=throughput, **extra_info)
 
 
-@dataclass
+@dataclasses.dataclass
 class RemoteModuleInfo:
     """A remote module that is served by one or more servers"""
 
@@ -35,7 +52,7 @@ class RemoteModuleInfo:
     servers: Dict[PeerID, ServerInfo]
 
 
-@dataclass
+@dataclasses.dataclass
 class RemoteSpanInfo:
     """A chain of remote blocks served by one specific remote peer"""
 

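For reference, a minimal sketch of how the extended ServerInfo round-trips through the tuple format stored in the DHT (the throughput, adapter, and version values below are illustrative, not taken from this commit):

    from petals.data_structures import ServerInfo, ServerState

    info = ServerInfo(
        state=ServerState.ONLINE,
        throughput=100.0,                    # must be a finite float >= 0, enforced by pydantic
        adapters=("example/lora-adapter",),  # hypothetical adapter name
        version="1.2.0.dev2",
    )
    stored = info.to_tuple()                 # (2, 100.0, {"adapters": [...], "version": ..., ...})
    restored = ServerInfo.from_tuple(stored)
    assert restored.state == ServerState.ONLINE and "example/lora-adapter" in restored.adapters

Old-format entries, i.e. plain (state, throughput) tuples without the extra dict, still parse because from_tuple() treats the third element as optional.
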
+ 11 - 30
src/petals/dht_utils.py

@@ -11,7 +11,7 @@ from hivemind.dht import DHT, DHTNode, DHTValue
 from hivemind.p2p import PeerID
 from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger
 
-from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo, ServerState
+from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo
 
 logger = get_logger(__name__)
 
@@ -19,10 +19,8 @@ logger = get_logger(__name__)
 def declare_active_modules(
     dht: DHT,
     uids: Sequence[ModuleUID],
+    server_info: ServerInfo,
     expiration_time: DHTExpiration,
-    state: ServerState,
-    throughput: float,
-    adapters: Optional[Sequence[str]] = None,
     wait: bool = True,
 ) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:
     """
@@ -42,14 +40,7 @@ def declare_active_modules(
         assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid
 
     return dht.run_coroutine(
-        partial(
-            _declare_active_modules,
-            uids=uids,
-            expiration_time=expiration_time,
-            state=state,
-            throughput=throughput,
-            adapters=list(adapters or []),
-        ),
+        partial(_declare_active_modules, uids=uids, server_info=server_info, expiration_time=expiration_time),
         return_future=not wait,
     )
 
@@ -58,16 +49,14 @@ async def _declare_active_modules(
     dht: DHT,
     node: DHTNode,
     uids: List[ModuleUID],
+    server_info: ServerInfo,
     expiration_time: DHTExpiration,
-    state: ServerState,
-    throughput: float,
-    adapters: List[str],
 ) -> Dict[ModuleUID, bool]:
     num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
     return await node.store_many(
         keys=uids,
         subkeys=[dht.peer_id.to_base58()] * len(uids),
-        values=[(state.value, throughput, dict(adapters=adapters))] * len(uids),
+        values=[server_info.to_tuple()] * len(uids),
         expiration_time=expiration_time,
         num_workers=num_workers,
     )
@@ -115,29 +104,21 @@ async def _get_remote_module_infos(
         metadata = found[uid]
         if metadata is None or not isinstance(metadata.value, dict):
             if metadata is not None:
-                logger.error(f"Incorrect metadata for {uid}: {metadata}")
+                logger.warning(f"Incorrect metadata for {uid}: {metadata}")
             continue
         servers = {}
         for peer_id, server_info in metadata.value.items():
             try:
                 peer_id = PeerID.from_base58(peer_id)
-                state, throughput = server_info.value[:2]
-                extra_info = server_info.value[2] if len(server_info.value) > 2 else {}
-                adapters = extra_info.get("adapters", [])
-                if bool(active_adapter) and active_adapter not in adapters:
+                server_info = ServerInfo.from_tuple(server_info.value)
+
+                if active_adapter and active_adapter not in server_info.adapters:
                     logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}")
                     continue
 
-                if not (
-                    isinstance(state, int)
-                    and isinstance(throughput, float)
-                    and math.isfinite(throughput)
-                    and throughput >= 0.0
-                ):
-                    raise ValueError(f"Invalid server info: {server_info}")
-                servers[peer_id] = ServerInfo(ServerState(state), throughput)
+                servers[peer_id] = server_info
             except (TypeError, ValueError) as e:
-                logger.error(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")
+                logger.warning(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")
         if servers:
             modules[i] = RemoteModuleInfo(uid, servers)
     return modules

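For reference, a sketch of the updated declare_active_modules() call site under assumed values (the standalone DHT, module UID, and throughput below are illustrative only):

    from hivemind.dht import DHT
    from hivemind.utils import get_dht_time

    from petals.data_structures import ServerInfo, ServerState
    from petals.dht_utils import declare_active_modules

    dht = DHT(start=True)  # a standalone DHT node, just for illustration
    server_info = ServerInfo(state=ServerState.JOINING, throughput=100.0, adapters=("example/lora-adapter",))

    declare_active_modules(
        dht,
        uids=["bigscience/bloom-petals.0"],  # hypothetical module UID: "<dht_prefix>.<block_index>"
        server_info=server_info,             # replaces the former state=/throughput=/adapters= kwargs
        expiration_time=get_dht_time() + 30,
    )
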
+ 0 - 2
src/petals/models/bloom/config.py

@@ -9,8 +9,6 @@ from petals.client.lm_head import LMHeadConfig
 from petals.client.ptune import PTuneConfig
 from petals.client.routing.sequence_manager import SequenceManagerConfig
 from petals.models.bloom.block import WrappedBloomBlock
-from petals.utils.auto_config import AutoDistributedConfig
-from petals.utils.version import get_compatible_model_repo
 
 logger = get_logger(__name__)
 

+ 1 - 3
src/petals/models/llama/config.py

@@ -9,7 +9,6 @@ from petals.client.lm_head import LMHeadConfig
 from petals.client.ptune import PTuneConfig
 from petals.client.routing.sequence_manager import SequenceManagerConfig
 from petals.models.llama.block import WrappedLlamaBlock
-from petals.utils.auto_config import AutoDistributedConfig
 
 logger = get_logger(__name__)
 
@@ -31,8 +30,7 @@ class DistributedLlamaConfig(LlamaConfig, SequenceManagerConfig, PTuneConfig, LM
         loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
         if loading_from_repo and dht_prefix is None:
             dht_prefix = str(model_name_or_path)
-            if "/" in dht_prefix:  # If present, strip repository name to merge blocks hosted by different accounts
-                dht_prefix = dht_prefix[dht_prefix.rfind("/") + 1 :]
+            dht_prefix = dht_prefix.split("/")[-1]  # Use only repo name to merge blocks hosted by different accounts
             if not dht_prefix.endswith("-hf"):
                 dht_prefix += "-hf"
             logger.info(f"Using DHT prefix: {dht_prefix}")

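For illustration, the simplified prefix logic behaves as follows (the repo id is a placeholder):

    model_name_or_path = "some-org/Llama-2-7b-chat-hf"
    dht_prefix = str(model_name_or_path).split("/")[-1]  # "Llama-2-7b-chat-hf": account name dropped
    if not dht_prefix.endswith("-hf"):
        dht_prefix += "-hf"                               # already ends with "-hf", so unchanged here

The previous rfind()-based branch produced the same result; split("/")[-1] also covers names without a "/", since split() then returns the whole string.
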
+ 1 - 2
src/petals/server/handler.py

@@ -562,11 +562,10 @@ class TransformerConnectionHandler(ConnectionHandler):
         """Return metadata about stored block uids and current load"""
 
         backend = self.module_backends[request.uid] if request.uid else next(iter(self.module_backends.values()))
-        cache_bytes_left = max(0, backend.memory_cache.max_size_bytes - backend.memory_cache.current_size_bytes)
         result = {
             "version": petals.__version__,
             "dht_client_mode": self.dht.client_mode,
-            CACHE_TOKENS_AVAILABLE: cache_bytes_left // max(backend.cache_bytes_per_token.values()),
+            CACHE_TOKENS_AVAILABLE: backend.memory_cache.bytes_left // max(backend.cache_bytes_per_token.values()),
         }
 
         if request.uid:

+ 4 - 0
src/petals/server/memory_cache.py

@@ -47,6 +47,10 @@ class MemoryCache:
     def current_size_bytes(self, value: int):
         self._current_size.value = value
 
+    @property
+    def bytes_left(self) -> int:
+        return self.max_size_bytes - self.current_size_bytes
+
     @property
     def handle_counter(self) -> int:
         return self._handle_counter.value

+ 46 - 36
src/petals/server/server.py

@@ -16,8 +16,9 @@ from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.logging import get_logger
 from transformers import PretrainedConfig
 
+import petals
 from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
-from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
+from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerInfo, ServerState
 from petals.dht_utils import declare_active_modules, get_remote_module_infos
 from petals.server import block_selection
 from petals.server.backend import TransformerBackend, merge_inference_pools_inplace
@@ -29,7 +30,6 @@ from petals.server.reachability import ReachabilityProtocol, check_direct_reacha
 from petals.server.throughput import get_dtype_name, get_server_throughput
 from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.convert_block import QuantType, check_device_balance, convert_block
-from petals.utils.disk_cache import DEFAULT_CACHE_DIR
 from petals.utils.version import get_compatible_model_repo
 
 logger = get_logger(__name__)
@@ -81,7 +81,7 @@ class Server:
         dht_client_mode: Optional[bool] = None,
         use_relay: bool = True,
         use_auto_relay: bool = True,
-        adapters: Optional[List[str]] = None,
+        adapters: Sequence[str] = (),
         **kwargs,
     ):
         """Create a server with one or more bloom blocks. See run_server.py for documentation."""
@@ -215,7 +215,15 @@ class Server:
                 force_eval=(throughput == "eval"),
                 cache_dir=cache_dir,
             )
-        self.throughput = throughput
+        self.server_info = ServerInfo(
+            state=ServerState.JOINING,
+            throughput=throughput,
+            adapters=tuple(adapters),
+            version=petals.__version__,
+            torch_dtype=str(torch_dtype).lstrip("torch."),
+            quant_type=quant_type.name.lower(),
+            using_relay=self.dht.client_mode,
+        )
 
         self.balance_quality = balance_quality
         self.mean_balance_check_period = mean_balance_check_period
@@ -283,7 +291,7 @@ class Server:
                 block_config=self.block_config,
                 attn_cache_bytes=self.attn_cache_bytes,
                 alloc_timeout=self.alloc_timeout,
-                throughput=self.throughput,
+                server_info=self.server_info,
                 block_indices=block_indices,
                 num_handlers=self.num_handlers,
                 min_batch_size=self.min_batch_size,
@@ -307,7 +315,6 @@ class Server:
                 quant_type=self.quant_type,
                 tensor_parallel_devices=self.tensor_parallel_devices,
                 should_validate_reachability=self.should_validate_reachability,
-                adapters=self.adapters,
                 start=True,
             )
             try:
@@ -385,7 +392,7 @@ class ModuleContainer(threading.Thread):
         block_config: PretrainedConfig,
         attn_cache_bytes: int,
         alloc_timeout: float,
-        throughput: float,
+        server_info: ServerInfo,
         block_indices: List[int],
         min_batch_size: int,
         max_batch_size: int,
@@ -401,16 +408,18 @@ class ModuleContainer(threading.Thread):
         quant_type: QuantType,
         tensor_parallel_devices: Sequence[torch.device],
         should_validate_reachability: bool,
-        adapters: Optional[List[str]] = None,
         **kwargs,
     ) -> ModuleContainer:
         module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices]
+        memory_cache = MemoryCache(attn_cache_bytes, alloc_timeout)
+
+        server_info.state = ServerState.JOINING
         joining_announcer = ModuleAnnouncerThread(
             module_uids,
             dht,
-            ServerState.JOINING,
-            adapters=adapters,
-            throughput=throughput,
+            server_info,
+            block_config=block_config,
+            memory_cache=memory_cache,
             update_period=update_period,
             expiration=expiration,
             daemon=True,
@@ -420,7 +429,6 @@
 
         assert len(tensor_parallel_devices) >= 1 and all(isinstance(d, torch.device) for d in tensor_parallel_devices)
 
-        memory_cache = MemoryCache(attn_cache_bytes, alloc_timeout)
         blocks = {}
         try:
             for module_uid, block_index in zip(module_uids, block_indices):
@@ -441,7 +449,7 @@ class ModuleContainer(threading.Thread):
                     tensor_parallel_devices,
                     device,
                     quant_type,
-                    adapters=adapters,
+                    adapters=server_info.adapters,
                     freeze=True,
                     use_auth_token=use_auth_token,
                     cache_dir=cache_dir,
@@ -477,13 +485,12 @@
 
             joining_announcer.stop.set()
             joining_announcer.join()
+            server_info.state = ServerState.OFFLINE
             declare_active_modules(
                 dht,
                 module_uids,
+                server_info,
                 expiration_time=get_dht_time() + expiration,
-                state=ServerState.OFFLINE,
-                throughput=throughput,
-                adapters=adapters,
             )
             logger.info(f"Announced that blocks {module_uids} are offline")
             raise
@@ -497,8 +504,9 @@ class ModuleContainer(threading.Thread):
             dht,
             dht_prefix,
             blocks,
-            adapters=adapters,
-            throughput=throughput,
+            block_config=block_config,
+            memory_cache=memory_cache,
+            server_info=server_info,
             update_period=update_period,
             expiration=expiration,
             **kwargs,
@@ -510,10 +518,11 @@ class ModuleContainer(threading.Thread):
         dht_prefix: str,
         module_backends: Dict[str, TransformerBackend],
         *,
+        block_config: PretrainedConfig,
+        memory_cache: MemoryCache,
         inference_max_length: int,
         num_handlers: int,
-        throughput: float,
-        adapters: Optional[Sequence[str]],
+        server_info: ServerInfo,
         update_period: float,
         expiration: Optional[float] = None,
         request_timeout: float,
@@ -525,7 +534,7 @@ class ModuleContainer(threading.Thread):
         super().__init__()
 
         self.dht, self.module_backends = dht, module_backends
-        self.throughput, self.update_period, self.expiration = throughput, update_period, expiration
+        self.server_info, self.update_period, self.expiration = server_info, update_period, expiration
 
         self.push_manager = mp.Manager()
         self.push_manager.__enter__()
@@ -534,7 +543,7 @@ class ModuleContainer(threading.Thread):
             TransformerConnectionHandler(
                 dht,
                 self.module_backends,
-                adapters=adapters,
+                adapters=server_info.adapters,
                 dht_prefix=dht_prefix,
                 push_manager=self.push_manager,
                 session_queues=session_queues,
@@ -548,12 +557,14 @@
 
         self.runtime = RuntimeWithDeduplicatedPools(self.module_backends, device=None, **kwargs)
         # note: We set device=None in runtime to avoid moving all modules to device 0 in runtime.run(). tensor_parallel has already moved it as needed.
+
+        self.server_info.state = ServerState.ONLINE
         self.online_announcer = ModuleAnnouncerThread(
             list(self.module_backends.keys()),
             dht,
-            ServerState.ONLINE,
-            adapters=adapters,
-            throughput=throughput,
+            self.server_info,
+            block_config=block_config,
+            memory_cache=memory_cache,
             update_period=update_period,
             expiration=expiration,
             daemon=True,
@@ -613,12 +624,12 @@ class ModuleContainer(threading.Thread):
         self.online_announcer.stop.set()
         self.online_announcer.join()
 
+        self.server_info.state = ServerState.OFFLINE
         declare_active_modules(
             self.dht,
             self.module_backends.keys(),
+            self.server_info,
             expiration_time=get_dht_time() + self.expiration,
-            state=ServerState.OFFLINE,
-            throughput=self.throughput,
         )
         logger.info(f"Announced that blocks {list(self.module_backends.keys())} are offline")
 
@@ -651,10 +662,10 @@ class ModuleAnnouncerThread(threading.Thread):
         self,
         module_uids: List[str],
         dht: DHT,
-        state: ServerState,
-        adapters: Optional[Sequence[str]],
+        server_info: ServerInfo,
         *,
-        throughput: float,
+        block_config: PretrainedConfig,
+        memory_cache: MemoryCache,
         update_period: float = 30,
         expiration: float,
         **kwargs,
@@ -662,22 +673,21 @@ class ModuleAnnouncerThread(threading.Thread):
         super().__init__(**kwargs)
         self.module_uids = module_uids
         self.dht = dht
-        self.state = state
-        self.adapters = adapters
-        self.throughput = throughput
+        self.server_info = server_info
+        self.memory_cache = memory_cache
+        self.bytes_per_token = block_config.hidden_size * torch.finfo(DTYPE_MAP[server_info.torch_dtype]).bits // 8
         self.update_period = update_period
         self.expiration = expiration
         self.stop = threading.Event()
 
     def run(self) -> None:
         while True:
+            self.server_info.cache_tokens_left = self.memory_cache.bytes_left // self.bytes_per_token
             declare_active_modules(
                 self.dht,
                 self.module_uids,
+                self.server_info,
                 expiration_time=get_dht_time() + self.expiration,
-                state=self.state,
-                throughput=self.throughput,
-                adapters=self.adapters,
             )
             if self.stop.wait(self.update_period):
                 break

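For reference, the cache accounting that the announcer now publishes, under assumed numbers (the hidden size and free cache below are hypothetical, not from this commit):

    import torch

    hidden_size = 4096                                   # assumed block_config.hidden_size
    dtype_bits = torch.finfo(torch.float16).bits         # 16 bits for a float16 server

    bytes_per_token = hidden_size * dtype_bits // 8      # 8192 bytes of attention cache per token
    bytes_left = 8 * 1024**3                             # assume 8 GiB of the cache is currently free
    cache_tokens_left = bytes_left // bytes_per_token    # 1_048_576 tokens

    assert cache_tokens_left == 1_048_576

This is the value written into ServerInfo.cache_tokens_left on every announcement cycle, so clients can see how much attention cache a server has left.
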
+ 2 - 2
src/petals/utils/convert_block.py

@@ -2,7 +2,7 @@
 Tools for converting transformer blocks, applying quantization and/or tensor parallelism
 """
 import re
-from typing import List, Optional, Sequence
+from typing import Optional, Sequence
 
 import tensor_parallel as tp
 import torch
@@ -25,7 +25,7 @@ def convert_block(
     output_device: torch.device,
     quant_type: QuantType,
     freeze: bool = True,
-    adapters: Optional[List[str]] = None,
+    adapters: Optional[Sequence[str]] = None,
     **kwargs,
 ) -> tp.TensorParallel:
     """