Преглед изворног кода

Report server version and dht.client_mode in rpc_info(), check for updates on startup (#209)

This PR:

1. Shows the current Petals version and checks for updates on startup.
2. Reports the current version and DHT mode in `rpc_info()`, so it can be shown on http://health.petals.ml or used on clients for efficient routing.
Alexander Borzunov пре 2 године
родитељ
комит
6b12b0d050
5 измењених фајлова са 47 додато и 12 уклоњено
  1. 1 0
      setup.cfg
  2. 3 0
      src/petals/cli/run_server.py
  3. 16 11
      src/petals/server/handler.py
  4. 1 1
      src/petals/server/server.py
  5. 26 0
      src/petals/utils/version.py

+ 1 - 0
setup.cfg

@@ -42,6 +42,7 @@ install_requires =
     humanfriendly
     async-timeout>=4.0.2
     cpufeature>=0.2.0
+    packaging>=23.0
 
 [options.extras_require]
 dev =

+ 3 - 0
src/petals/cli/run_server.py

@@ -8,6 +8,7 @@ from humanfriendly import parse_size
 
 from petals.constants import PUBLIC_INITIAL_PEERS
 from petals.server.server import Server
+from petals.utils.version import validate_version
 
 logger = get_logger(__file__)
 
@@ -193,6 +194,8 @@ def main():
     if load_in_8bit is not None:
         args["load_in_8bit"] = load_in_8bit.lower() in ["true", "1"]
 
+    validate_version()
+
     server = Server(
         **args,
         host_maddrs=host_maddrs,

+ 16 - 11
src/petals/server/handler.py

@@ -24,6 +24,7 @@ from hivemind.utils.asyncio import amap_in_executor, anext
 from hivemind.utils.logging import get_logger
 from hivemind.utils.streaming import split_for_streaming
 
+import petals
 from petals.data_structures import CHAIN_DELIMITER, InferenceMetadata, ModuleUID
 from petals.server.backend import TransformerBackend
 from petals.server.memory_cache import Handle
@@ -382,19 +383,23 @@ class TransformerConnectionHandler(ConnectionHandler):
 
     async def rpc_info(self, request: runtime_pb2.ExpertUID, context: P2PContext) -> runtime_pb2.ExpertInfo:
         """Return metadata about stored block uids and current load"""
-        rpc_info = {}
-        if request.uid:
-            backend = self.module_backends[request.uid]
-            rpc_info.update(self.module_backends[request.uid].get_info())
-        else:
-            backend = next(iter(self.module_backends.values()))
-            # not saving keys to rpc_info since user did not request any uid
 
+        # When no uid is requested, the first registered backend is used to compute the cache stats
+        backend = self.module_backends[request.uid] if request.uid else next(iter(self.module_backends.values()))
         cache_bytes_left = max(0, backend.memory_cache.max_size_bytes - backend.memory_cache.current_size_bytes)
-        if CACHE_TOKENS_AVAILABLE in rpc_info:
-            raise RuntimeError(f"Block rpc_info dict has a reserved field {CACHE_TOKENS_AVAILABLE} : {rpc_info}")
-        rpc_info[CACHE_TOKENS_AVAILABLE] = cache_bytes_left // max(backend.cache_bytes_per_token.values())
-        return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(rpc_info))
+        # Server-level fields; these keys are reserved and must not be overridden by block-level info
+        result = {
+            "version": petals.__version__,
+            "dht_client_mode": self.dht.client_mode,
+            CACHE_TOKENS_AVAILABLE: cache_bytes_left // max(backend.cache_bytes_per_token.values()),
+        }
+
+        if request.uid:
+            # `backend` is already the requested block here, so there is no need to look it up again
+            block_info = backend.get_info()
+            common_keys = set(result.keys()) & set(block_info.keys())
+            if common_keys:
+                raise RuntimeError(f"The block's rpc_info has keys reserved for the server's rpc_info: {common_keys}")
+            # Block-specific keys are included only when the caller asked for a particular uid
+            result.update(block_info)
+
+        return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(result))
 
 
 async def _rpc_forward(

+ 1 - 1
src/petals/server/server.py

@@ -102,7 +102,7 @@ class Server:
                 f"Cannot use model name as prefix (contains '{UID_DELIMITER}' or '{CHAIN_DELIMITER}'); "
                 f"Please specify --prefix manually when starting a server"
             )
-            logger.info(f"Automatic dht prefix: {prefix}")
+            logger.debug(f"Automatic dht prefix: {prefix}")
         self.prefix = prefix
 
         if expiration is None:

+ 26 - 0
src/petals/utils/version.py

@@ -0,0 +1,26 @@
+import requests
+from hivemind.utils.logging import TextStyle, get_logger
+from packaging.version import parse
+
+import petals
+
+logger = get_logger(__file__)
+
+
+def validate_version(timeout: float = 10.0):
+    """Log the running Petals version and suggest an upgrade if PyPI lists a newer stable release.
+
+    The check is best-effort: any failure (network error, timeout, bad HTTP status, malformed JSON,
+    unparsable version string) is logged as a warning and never propagates to the caller.
+
+    :param timeout: maximum number of seconds to wait for PyPI before giving up on the check
+    """
+    logger.info(f"Running {TextStyle.BOLD}Petals {petals.__version__}{TextStyle.RESET}")
+    try:
+        # requests waits indefinitely unless a timeout is given; bound it so that
+        # an unresponsive PyPI cannot stall the caller (this runs during server startup)
+        r = requests.get("https://pypi.python.org/pypi/petals/json", timeout=timeout)
+        r.raise_for_status()
+        response = r.json()
+
+        versions = [parse(ver) for ver in response.get("releases")]
+        # Pre-releases are ignored so that users are only pointed to stable versions
+        latest = max(ver for ver in versions if not ver.is_prerelease)
+
+        if parse(petals.__version__) < latest:
+            logger.info(
+                f"A newer version {latest} is available. Please upgrade with: "
+                f"{TextStyle.BOLD}pip install --upgrade petals{TextStyle.RESET}"
+            )
+    except Exception:
+        # Deliberately broad: a failed update check must never prevent the caller from continuing
+        logger.warning("Failed to fetch the latest Petals version from PyPI:", exc_info=True)