Browse Source

Create model index in DHT (#491)

This PR creates an index of models hosted in the swarm - it is useful to know which custom models users run and display them at https://health.petals.dev as "not officially supported" models.
Alexander Borzunov 2 years ago
parent
commit
6ef6bf5fa2

+ 1 - 1
src/petals/client/inference_session.py

@@ -343,7 +343,7 @@ class InferenceSession:
         n_prev_spans = len(self._server_sessions)
         update_end = self._server_sessions[server_idx].span.end if server_idx < n_prev_spans else self.num_blocks
         if attempt_no >= 1:
-            logger.info(
+            logger.debug(
                 f"Due to a server failure, remote attention caches "
                 f"from block {block_idx} to {update_end} will be regenerated"
             )

+ 13 - 0
src/petals/data_structures.py

@@ -20,6 +20,19 @@ class ServerState(Enum):
 RPS = pydantic.confloat(ge=0, allow_inf_nan=False, strict=True)
 
 
+@pydantic.dataclasses.dataclass
+class ModelInfo:
+    num_blocks: int
+    repository: Optional[str] = None
+
+    def to_dict(self) -> dict:
+        return dataclasses.asdict(self)
+
+    @classmethod
+    def from_dict(cls, source: dict):
+        return cls(**source)
+
+
 @pydantic.dataclasses.dataclass
 class ServerInfo:
     state: ServerState

+ 19 - 3
src/petals/server/server.py

@@ -3,6 +3,7 @@ from __future__ import annotations
 import gc
 import math
 import multiprocessing as mp
+import os
 import random
 import threading
 import time
@@ -21,7 +22,7 @@ from transformers import PretrainedConfig
 
 import petals
 from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
-from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerInfo, ServerState
+from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModelInfo, ServerInfo, ServerState
 from petals.server import block_selection
 from petals.server.backend import TransformerBackend, merge_inference_pools_inplace
 from petals.server.block_utils import get_block_size, resolve_block_dtype
@@ -259,6 +260,9 @@ class Server:
             using_relay=reachable_via_relay,
             **throughput_info,
         )
+        self.model_info = ModelInfo(num_blocks=self.block_config.num_hidden_layers)
+        if not os.path.isdir(converted_model_name_or_path):
+            self.model_info.repository = "https://huggingface.co/" + converted_model_name_or_path
 
         self.balance_quality = balance_quality
         self.mean_balance_check_period = mean_balance_check_period
@@ -330,6 +334,7 @@ class Server:
                 block_config=self.block_config,
                 attn_cache_bytes=self.attn_cache_bytes,
                 server_info=self.server_info,
+                model_info=self.model_info,
                 block_indices=block_indices,
                 num_handlers=self.num_handlers,
                 min_batch_size=self.min_batch_size,
@@ -436,6 +441,7 @@ class ModuleContainer(threading.Thread):
         block_config: PretrainedConfig,
         attn_cache_bytes: int,
         server_info: ServerInfo,
+        model_info: ModelInfo,
         block_indices: List[int],
         min_batch_size: int,
         max_batch_size: int,
@@ -463,6 +469,7 @@ class ModuleContainer(threading.Thread):
             module_uids,
             dht,
             server_info,
+            model_info,
             block_config=block_config,
             memory_cache=memory_cache,
             update_period=update_period,
@@ -671,6 +678,7 @@ class ModuleAnnouncerThread(threading.Thread):
         module_uids: List[str],
         dht: DHT,
         server_info: ServerInfo,
+        model_info: ModelInfo,
         *,
         block_config: PretrainedConfig,
         memory_cache: MemoryCache,
@@ -683,6 +691,7 @@ class ModuleAnnouncerThread(threading.Thread):
         self.module_uids = module_uids
         self.dht = dht
         self.server_info = server_info
+        self.model_info = model_info
         self.memory_cache = memory_cache
 
         self.bytes_per_token = block_config.hidden_size * get_size_in_bytes(DTYPE_MAP[server_info.torch_dtype])
@@ -693,10 +702,10 @@ class ModuleAnnouncerThread(threading.Thread):
         self.trigger = threading.Event()
 
         self.max_pinged = max_pinged
-        dht_prefix = module_uids[0].split(UID_DELIMITER)[0]
+        self.dht_prefix = module_uids[0].split(UID_DELIMITER)[0]
         block_indices = [int(uid.split(UID_DELIMITER)[-1]) for uid in module_uids]
         start_block, end_block = min(block_indices), max(block_indices) + 1
-        self.next_uids = [f"{dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block + 1, end_block + 1)]
+        self.next_uids = [f"{self.dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block + 1, end_block + 1)]
         self.ping_aggregator = PingAggregator(self.dht)
 
     def run(self) -> None:
@@ -720,6 +729,13 @@ class ModuleAnnouncerThread(threading.Thread):
             )
             if self.server_info.state == ServerState.OFFLINE:
                 break
+            if not self.dht_prefix.startswith("_"):  # Not private
+                self.dht.store(
+                    key="_petals.models",
+                    subkey=self.dht_prefix,
+                    value=self.model_info.to_dict(),
+                    expiration_time=get_dht_time() + self.expiration,
+                )
 
             delay = self.update_period - (time.perf_counter() - start_time)
             if delay < 0: