Share more info about a server in DHT (#355)

Alexander Borzunov 2 years ago
parent
commit
2c8959e713

+ 1 - 1
setup.cfg

@@ -38,7 +38,7 @@ install_requires =
     tokenizers>=0.13.3
     transformers>=4.30.1,<5.0.0
     speedtest-cli==2.1.3
-    pydantic>=1.8.1,<2.0  # 2.0 is incompatible with hivemind==1.1.8
+    pydantic>=1.10,<2.0  # 2.0 is incompatible with hivemind==1.1.8
     hivemind==1.1.8
     tensor_parallel==1.0.23
     humanfriendly
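
The lower bound moves from 1.8.1 to 1.10 presumably because the new ServerInfo dataclass below relies on pydantic.confloat(..., allow_inf_nan=False, strict=True), and the allow_inf_nan constraint only appeared in the pydantic 1.10 series. A minimal sketch of the validation this pin enables (hypothetical field name, assumes pydantic 1.10.x):

    import pydantic

    @pydantic.dataclasses.dataclass
    class Example:
        # strict=True rejects non-float values, allow_inf_nan=False rejects inf/NaN
        throughput: pydantic.confloat(ge=0, allow_inf_nan=False, strict=True)

    Example(throughput=1.5)                      # passes validation
    try:
        Example(throughput=float("nan"))         # rejected by allow_inf_nan=False
    except pydantic.ValidationError as e:
        print(e)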

+ 1 - 1
src/petals/__init__.py

@@ -11,7 +11,7 @@ from petals.models import *
 from petals.utils import *
 from petals.utils.logging import initialize_logs as _initialize_logs
 
-__version__ = "1.2.0.dev1"
+__version__ = "1.2.0.dev2"
 
 
 if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):

+ 3 - 2
src/petals/cli/run_server.py

@@ -146,8 +146,9 @@ def main():
                         help="Skip checking this server's reachability via health.petals.ml "
                              "when connecting to the public swarm. If you connect to a private swarm, "
                              "the check is skipped by default. Use this option only if you know what you are doing")
-    
-    parser.add_argument("--adapters", nargs='+', default=None, help="List of pretrained LoRA adapters that can be used for inference or training.")
+
+    parser.add_argument("--adapters", nargs='+', default=(),
+                        help="List of pre-loaded LoRA adapters that can be used for inference or training")
 
     # fmt:on
     args = vars(parser.parse_args())
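
A small sketch of how the new default behaves (adapter names are hypothetical): omitting --adapters now yields an empty tuple instead of None, so downstream code can iterate over the adapters without a None check.

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--adapters", nargs="+", default=(),
                        help="List of pre-loaded LoRA adapters that can be used for inference or training")

    print(vars(parser.parse_args([])))                 # {'adapters': ()}
    print(vars(parser.parse_args(["--adapters", "org/lora-a", "org/lora-b"])))
    # {'adapters': ['org/lora-a', 'org/lora-b']}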

+ 25 - 8
src/petals/data_structures.py

@@ -1,10 +1,8 @@
-from __future__ import annotations
-
 import dataclasses
-from dataclasses import dataclass
 from enum import Enum
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Sequence, Tuple
 
+import pydantic
 from hivemind import PeerID
 from hivemind.moe.expert_uid import ExpertUID
 
@@ -21,13 +19,32 @@ class ServerState(Enum):
     ONLINE = 2
 
 
-@dataclass
+@pydantic.dataclasses.dataclass
 class ServerInfo:
     state: ServerState
-    throughput: float
+    throughput: pydantic.confloat(ge=0, allow_inf_nan=False, strict=True)
+
+    adapters: Sequence[str] = ()
+    version: Optional[str] = None
+    torch_dtype: Optional[str] = None
+    quant_type: Optional[str] = None
+    using_relay: Optional[bool] = None
+    cache_tokens_left: Optional[pydantic.conint(ge=0, strict=True)] = None
+
+    def to_tuple(self) -> Tuple[int, float, dict]:
+        extra_info = dataclasses.asdict(self)
+        del extra_info["state"], extra_info["throughput"]
+        return (self.state.value, self.throughput, extra_info)
+
+    @classmethod
+    def from_tuple(cls, source: tuple):
+        state, throughput = source[:2]
+        extra_info = source[2] if len(source) > 2 else {}
+        # pydantic will validate existing fields and ignore extra ones
+        return cls(state=ServerState(state), throughput=throughput, **extra_info)
 
 
-@dataclass
+@dataclasses.dataclass
 class RemoteModuleInfo:
     """A remote module that is served by one or more servers"""
 
@@ -35,7 +52,7 @@ class RemoteModuleInfo:
     servers: Dict[PeerID, ServerInfo]
 
 
-@dataclass
+@dataclasses.dataclass
 class RemoteSpanInfo:
     """A chain of remote blocks served by one specific remote peer"""
 

+ 11 - 30
src/petals/dht_utils.py

@@ -11,7 +11,7 @@ from hivemind.dht import DHT, DHTNode, DHTValue
 from hivemind.p2p import PeerID
 from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger
 
-from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo, ServerState
+from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo
 
 logger = get_logger(__name__)
 
@@ -19,10 +19,8 @@ logger = get_logger(__name__)
 def declare_active_modules(
     dht: DHT,
     uids: Sequence[ModuleUID],
+    server_info: ServerInfo,
     expiration_time: DHTExpiration,
-    state: ServerState,
-    throughput: float,
-    adapters: Optional[Sequence[str]] = None,
     wait: bool = True,
 ) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:
     """
@@ -42,14 +40,7 @@ def declare_active_modules(
         assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid
 
     return dht.run_coroutine(
-        partial(
-            _declare_active_modules,
-            uids=uids,
-            expiration_time=expiration_time,
-            state=state,
-            throughput=throughput,
-            adapters=list(adapters or []),
-        ),
+        partial(_declare_active_modules, uids=uids, server_info=server_info, expiration_time=expiration_time),
         return_future=not wait,
     )
 
@@ -58,16 +49,14 @@ async def _declare_active_modules(
     dht: DHT,
     node: DHTNode,
     uids: List[ModuleUID],
+    server_info: ServerInfo,
     expiration_time: DHTExpiration,
-    state: ServerState,
-    throughput: float,
-    adapters: List[str],
 ) -> Dict[ModuleUID, bool]:
     num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
     return await node.store_many(
         keys=uids,
         subkeys=[dht.peer_id.to_base58()] * len(uids),
-        values=[(state.value, throughput, dict(adapters=adapters))] * len(uids),
+        values=[server_info.to_tuple()] * len(uids),
         expiration_time=expiration_time,
         num_workers=num_workers,
     )
@@ -115,29 +104,21 @@ async def _get_remote_module_infos(
         metadata = found[uid]
         if metadata is None or not isinstance(metadata.value, dict):
             if metadata is not None:
-                logger.error(f"Incorrect metadata for {uid}: {metadata}")
+                logger.warning(f"Incorrect metadata for {uid}: {metadata}")
             continue
         servers = {}
         for peer_id, server_info in metadata.value.items():
             try:
                 peer_id = PeerID.from_base58(peer_id)
-                state, throughput = server_info.value[:2]
-                extra_info = server_info.value[2] if len(server_info.value) > 2 else {}
-                adapters = extra_info.get("adapters", [])
-                if bool(active_adapter) and active_adapter not in adapters:
+                server_info = ServerInfo.from_tuple(server_info.value)
+
+                if active_adapter and active_adapter not in server_info.adapters:
                     logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}")
                     continue
 
-                if not (
-                    isinstance(state, int)
-                    and isinstance(throughput, float)
-                    and math.isfinite(throughput)
-                    and throughput >= 0.0
-                ):
-                    raise ValueError(f"Invalid server info: {server_info}")
-                servers[peer_id] = ServerInfo(ServerState(state), throughput)
+                servers[peer_id] = server_info
             except (TypeError, ValueError) as e:
-                logger.error(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")
+                logger.warning(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")
         if servers:
             modules[i] = RemoteModuleInfo(uid, servers)
     return modules
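
Validation of the stored tuple now lives in ServerInfo.from_tuple instead of the hand-written isinstance/math.isfinite checks. Since pydantic v1's ValidationError is a ValueError subclass, the existing except (TypeError, ValueError) clause still catches malformed entries, which are now logged as warnings rather than errors. A sketch with a hypothetical malformed entry:

    from petals.data_structures import ServerInfo

    try:
        ServerInfo.from_tuple((2, -5.0))      # negative throughput violates ge=0
    except ValueError as e:                   # pydantic.ValidationError subclasses ValueError
        print(type(e).__name__, e)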

+ 0 - 2
src/petals/models/bloom/config.py

@@ -9,8 +9,6 @@ from petals.client.lm_head import LMHeadConfig
 from petals.client.ptune import PTuneConfig
 from petals.client.routing.sequence_manager import SequenceManagerConfig
 from petals.models.bloom.block import WrappedBloomBlock
-from petals.utils.auto_config import AutoDistributedConfig
-from petals.utils.version import get_compatible_model_repo
 
 logger = get_logger(__name__)
 

+ 1 - 3
src/petals/models/llama/config.py

@@ -9,7 +9,6 @@ from petals.client.lm_head import LMHeadConfig
 from petals.client.ptune import PTuneConfig
 from petals.client.routing.sequence_manager import SequenceManagerConfig
 from petals.models.llama.block import WrappedLlamaBlock
-from petals.utils.auto_config import AutoDistributedConfig
 
 logger = get_logger(__name__)
 
@@ -31,8 +30,7 @@ class DistributedLlamaConfig(LlamaConfig, SequenceManagerConfig, PTuneConfig, LM
         loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
         if loading_from_repo and dht_prefix is None:
             dht_prefix = str(model_name_or_path)
-            if "/" in dht_prefix:  # If present, strip repository name to merge blocks hosted by different accounts
-                dht_prefix = dht_prefix[dht_prefix.rfind("/") + 1 :]
+            dht_prefix = dht_prefix.split("/")[-1]  # Use only repo name to merge blocks hosted by different accounts
             if not dht_prefix.endswith("-hf"):
                 dht_prefix += "-hf"
             logger.info(f"Using DHT prefix: {dht_prefix}")

+ 1 - 2
src/petals/server/handler.py

@@ -562,11 +562,10 @@ class TransformerConnectionHandler(ConnectionHandler):
         """Return metadata about stored block uids and current load"""
 
         backend = self.module_backends[request.uid] if request.uid else next(iter(self.module_backends.values()))
-        cache_bytes_left = max(0, backend.memory_cache.max_size_bytes - backend.memory_cache.current_size_bytes)
         result = {
             "version": petals.__version__,
             "dht_client_mode": self.dht.client_mode,
-            CACHE_TOKENS_AVAILABLE: cache_bytes_left // max(backend.cache_bytes_per_token.values()),
+            CACHE_TOKENS_AVAILABLE: backend.memory_cache.bytes_left // max(backend.cache_bytes_per_token.values()),
         }
 
         if request.uid:

+ 4 - 0
src/petals/server/memory_cache.py

@@ -47,6 +47,10 @@ class MemoryCache:
     def current_size_bytes(self, value: int):
         self._current_size.value = value
 
+    @property
+    def bytes_left(self) -> int:
+        return self.max_size_bytes - self.current_size_bytes
+
     @property
     def handle_counter(self) -> int:
         return self._handle_counter.value
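
The new bytes_left property centralizes the subtraction that the handler above previously did inline for CACHE_TOKENS_AVAILABLE, and that the announcer in server.py now reuses for cache_tokens_left. A worked example with hypothetical numbers:

    max_size_bytes = 4 * 2**30                   # 4 GiB attention cache
    current_size_bytes = 1 * 2**30               # 1 GiB already allocated
    bytes_left = max_size_bytes - current_size_bytes

    cache_bytes_per_token = 4096 * 16 // 8       # e.g. hidden_size=4096 at 16-bit precision
    print(bytes_left // cache_bytes_per_token)   # 393216 tokens still fit in the cache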

+ 46 - 36
src/petals/server/server.py

@@ -16,8 +16,9 @@ from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.logging import get_logger
 from transformers import PretrainedConfig
 
+import petals
 from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
-from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
+from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerInfo, ServerState
 from petals.dht_utils import declare_active_modules, get_remote_module_infos
 from petals.server import block_selection
 from petals.server.backend import TransformerBackend, merge_inference_pools_inplace
@@ -29,7 +30,6 @@ from petals.server.reachability import ReachabilityProtocol, check_direct_reacha
 from petals.server.throughput import get_dtype_name, get_server_throughput
 from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.convert_block import QuantType, check_device_balance, convert_block
-from petals.utils.disk_cache import DEFAULT_CACHE_DIR
 from petals.utils.version import get_compatible_model_repo
 
 logger = get_logger(__name__)
@@ -81,7 +81,7 @@ class Server:
         dht_client_mode: Optional[bool] = None,
         use_relay: bool = True,
         use_auto_relay: bool = True,
-        adapters: Optional[List[str]] = None,
+        adapters: Sequence[str] = (),
         **kwargs,
     ):
         """Create a server with one or more bloom blocks. See run_server.py for documentation."""
@@ -215,7 +215,15 @@ class Server:
                 force_eval=(throughput == "eval"),
                 cache_dir=cache_dir,
             )
-        self.throughput = throughput
+        self.server_info = ServerInfo(
+            state=ServerState.JOINING,
+            throughput=throughput,
+            adapters=tuple(adapters),
+            version=petals.__version__,
+            torch_dtype=str(torch_dtype).lstrip("torch."),
+            quant_type=quant_type.name.lower(),
+            using_relay=self.dht.client_mode,
+        )
 
         self.balance_quality = balance_quality
         self.mean_balance_check_period = mean_balance_check_period
@@ -283,7 +291,7 @@ class Server:
                 block_config=self.block_config,
                 attn_cache_bytes=self.attn_cache_bytes,
                 alloc_timeout=self.alloc_timeout,
-                throughput=self.throughput,
+                server_info=self.server_info,
                 block_indices=block_indices,
                 num_handlers=self.num_handlers,
                 min_batch_size=self.min_batch_size,
@@ -307,7 +315,6 @@ class Server:
                 quant_type=self.quant_type,
                 tensor_parallel_devices=self.tensor_parallel_devices,
                 should_validate_reachability=self.should_validate_reachability,
-                adapters=self.adapters,
                 start=True,
             )
             try:
@@ -385,7 +392,7 @@ class ModuleContainer(threading.Thread):
         block_config: PretrainedConfig,
         attn_cache_bytes: int,
         alloc_timeout: float,
-        throughput: float,
+        server_info: ServerInfo,
         block_indices: List[int],
         min_batch_size: int,
         max_batch_size: int,
@@ -401,16 +408,18 @@ class ModuleContainer(threading.Thread):
         quant_type: QuantType,
         tensor_parallel_devices: Sequence[torch.device],
         should_validate_reachability: bool,
-        adapters: Optional[List[str]] = None,
         **kwargs,
     ) -> ModuleContainer:
         module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices]
+        memory_cache = MemoryCache(attn_cache_bytes, alloc_timeout)
+
+        server_info.state = ServerState.JOINING
         joining_announcer = ModuleAnnouncerThread(
             module_uids,
             dht,
-            ServerState.JOINING,
-            adapters=adapters,
-            throughput=throughput,
+            server_info,
+            block_config=block_config,
+            memory_cache=memory_cache,
             update_period=update_period,
             expiration=expiration,
             daemon=True,
@@ -420,7 +429,6 @@ class ModuleContainer(threading.Thread):
 
         assert len(tensor_parallel_devices) >= 1 and all(isinstance(d, torch.device) for d in tensor_parallel_devices)
 
-        memory_cache = MemoryCache(attn_cache_bytes, alloc_timeout)
         blocks = {}
         try:
             for module_uid, block_index in zip(module_uids, block_indices):
@@ -441,7 +449,7 @@ class ModuleContainer(threading.Thread):
                     tensor_parallel_devices,
                     device,
                     quant_type,
-                    adapters=adapters,
+                    adapters=server_info.adapters,
                     freeze=True,
                     use_auth_token=use_auth_token,
                     cache_dir=cache_dir,
@@ -477,13 +485,12 @@ class ModuleContainer(threading.Thread):
 
             joining_announcer.stop.set()
             joining_announcer.join()
+            server_info.state = ServerState.OFFLINE
             declare_active_modules(
                 dht,
                 module_uids,
+                server_info,
                 expiration_time=get_dht_time() + expiration,
-                state=ServerState.OFFLINE,
-                throughput=throughput,
-                adapters=adapters,
             )
             logger.info(f"Announced that blocks {module_uids} are offline")
             raise
@@ -497,8 +504,9 @@ class ModuleContainer(threading.Thread):
             dht,
             dht_prefix,
             blocks,
-            adapters=adapters,
-            throughput=throughput,
+            block_config=block_config,
+            memory_cache=memory_cache,
+            server_info=server_info,
             update_period=update_period,
             expiration=expiration,
             **kwargs,
@@ -510,10 +518,11 @@ class ModuleContainer(threading.Thread):
         dht_prefix: str,
         module_backends: Dict[str, TransformerBackend],
         *,
+        block_config: PretrainedConfig,
+        memory_cache: MemoryCache,
         inference_max_length: int,
         num_handlers: int,
-        throughput: float,
-        adapters: Optional[Sequence[str]],
+        server_info: ServerInfo,
         update_period: float,
         expiration: Optional[float] = None,
         request_timeout: float,
@@ -525,7 +534,7 @@ class ModuleContainer(threading.Thread):
         super().__init__()
 
         self.dht, self.module_backends = dht, module_backends
-        self.throughput, self.update_period, self.expiration = throughput, update_period, expiration
+        self.server_info, self.update_period, self.expiration = server_info, update_period, expiration
 
         self.push_manager = mp.Manager()
         self.push_manager.__enter__()
@@ -534,7 +543,7 @@ class ModuleContainer(threading.Thread):
             TransformerConnectionHandler(
                 dht,
                 self.module_backends,
-                adapters=adapters,
+                adapters=server_info.adapters,
                 dht_prefix=dht_prefix,
                 push_manager=self.push_manager,
                 session_queues=session_queues,
@@ -548,12 +557,14 @@ class ModuleContainer(threading.Thread):
 
         self.runtime = RuntimeWithDeduplicatedPools(self.module_backends, device=None, **kwargs)
         # note: We set device=None in runtime to avoid moving all modules to device 0 in runtime.run(). tensor_parallel has already moved it as needed.
+
+        self.server_info.state = ServerState.ONLINE
         self.online_announcer = ModuleAnnouncerThread(
             list(self.module_backends.keys()),
             dht,
-            ServerState.ONLINE,
-            adapters=adapters,
-            throughput=throughput,
+            self.server_info,
+            block_config=block_config,
+            memory_cache=memory_cache,
             update_period=update_period,
             expiration=expiration,
             daemon=True,
@@ -613,12 +624,12 @@ class ModuleContainer(threading.Thread):
         self.online_announcer.stop.set()
         self.online_announcer.join()
 
+        self.server_info.state = ServerState.OFFLINE
         declare_active_modules(
             self.dht,
             self.module_backends.keys(),
+            self.server_info,
             expiration_time=get_dht_time() + self.expiration,
-            state=ServerState.OFFLINE,
-            throughput=self.throughput,
         )
         logger.info(f"Announced that blocks {list(self.module_backends.keys())} are offline")
 
@@ -651,10 +662,10 @@ class ModuleAnnouncerThread(threading.Thread):
         self,
         module_uids: List[str],
         dht: DHT,
-        state: ServerState,
-        adapters: Optional[Sequence[str]],
+        server_info: ServerInfo,
         *,
-        throughput: float,
+        block_config: PretrainedConfig,
+        memory_cache: MemoryCache,
         update_period: float = 30,
         expiration: float,
         **kwargs,
@@ -662,22 +673,21 @@ class ModuleAnnouncerThread(threading.Thread):
         super().__init__(**kwargs)
         self.module_uids = module_uids
         self.dht = dht
-        self.state = state
-        self.adapters = adapters
-        self.throughput = throughput
+        self.server_info = server_info
+        self.memory_cache = memory_cache
+        self.bytes_per_token = block_config.hidden_size * torch.finfo(DTYPE_MAP[server_info.torch_dtype]).bits // 8
         self.update_period = update_period
         self.expiration = expiration
         self.stop = threading.Event()
 
     def run(self) -> None:
         while True:
+            self.server_info.cache_tokens_left = self.memory_cache.bytes_left // self.bytes_per_token
             declare_active_modules(
                 self.dht,
                 self.module_uids,
+                self.server_info,
                 expiration_time=get_dht_time() + self.expiration,
-                state=self.state,
-                throughput=self.throughput,
-                adapters=self.adapters,
             )
             if self.stop.wait(self.update_period):
                 break
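
Taken together, the server.py changes thread a single mutable ServerInfo through every announcement: the server updates state (and the announcer refreshes cache_tokens_left each period) before re-publishing the same object via declare_active_modules. A rough lifecycle sketch, not the real threads:

    from petals.data_structures import ServerInfo, ServerState

    server_info = ServerInfo(state=ServerState.JOINING, throughput=100.0)

    # while blocks are loading, ModuleAnnouncerThread keeps announcing JOINING;
    # once handlers and runtime are up, the state flips before the online announcer starts:
    server_info.state = ServerState.ONLINE

    # each update period, the announcer refreshes the cache estimate before publishing:
    server_info.cache_tokens_left = 3 * 2**30 // 8192     # hypothetical bytes_left // bytes_per_token

    # on shutdown, the same object is announced one last time as OFFLINE:
    server_info.state = ServerState.OFFLINE
    # declare_active_modules(dht, module_uids, server_info, expiration_time=get_dht_time() + expiration)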

+ 2 - 2
src/petals/utils/convert_block.py

@@ -2,7 +2,7 @@
 Tools for converting transformer blocks, applying quantization and/or tensor parallelism
 """
 import re
-from typing import List, Optional, Sequence
+from typing import Optional, Sequence
 
 import tensor_parallel as tp
 import torch
@@ -25,7 +25,7 @@ def convert_block(
     output_device: torch.device,
     quant_type: QuantType,
     freeze: bool = True,
-    adapters: Optional[List[str]] = None,
+    adapters: Optional[Sequence[str]] = None,
     **kwargs,
 ) -> tp.TensorParallel:
     """