|
@@ -24,6 +24,7 @@ from hivemind.utils.asyncio import amap_in_executor, anext
|
|
|
from hivemind.utils.logging import get_logger
|
|
|
from hivemind.utils.streaming import split_for_streaming
|
|
|
|
|
|
+import petals
|
|
|
from petals.data_structures import CHAIN_DELIMITER, InferenceMetadata, ModuleUID
|
|
|
from petals.server.backend import TransformerBackend
|
|
|
from petals.server.memory_cache import Handle
|
|
@@ -382,19 +383,23 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
|
|
|
async def rpc_info(self, request: runtime_pb2.ExpertUID, context: P2PContext) -> runtime_pb2.ExpertInfo:
|
|
|
"""Return metadata about stored block uids and current load"""
|
|
|
- rpc_info = {}
|
|
|
- if request.uid:
|
|
|
- backend = self.module_backends[request.uid]
|
|
|
- rpc_info.update(self.module_backends[request.uid].get_info())
|
|
|
- else:
|
|
|
- backend = next(iter(self.module_backends.values()))
|
|
|
- # not saving keys to rpc_info since user did not request any uid
|
|
|
|
|
|
+ backend = self.module_backends[request.uid] if request.uid else next(iter(self.module_backends.values()))
|
|
|
cache_bytes_left = max(0, backend.memory_cache.max_size_bytes - backend.memory_cache.current_size_bytes)
|
|
|
- if CACHE_TOKENS_AVAILABLE in rpc_info:
|
|
|
- raise RuntimeError(f"Block rpc_info dict has a reserved field {CACHE_TOKENS_AVAILABLE} : {rpc_info}")
|
|
|
- rpc_info[CACHE_TOKENS_AVAILABLE] = cache_bytes_left // max(backend.cache_bytes_per_token.values())
|
|
|
- return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(rpc_info))
|
|
|
+ result = {
|
|
|
+ "version": petals.__version__,
|
|
|
+ "dht_client_mode": self.dht.client_mode,
|
|
|
+ CACHE_TOKENS_AVAILABLE: cache_bytes_left // max(backend.cache_bytes_per_token.values()),
|
|
|
+ }
|
|
|
+
|
|
|
+ if request.uid:
|
|
|
+ block_info = self.module_backends[request.uid].get_info()
|
|
|
+ common_keys = set(result.keys()) & set(block_info.keys())
|
|
|
+ if common_keys:
|
|
|
+ raise RuntimeError(f"The block's rpc_info has keys reserved for the server's rpc_info: {common_keys}")
|
|
|
+ result.update(block_info)
|
|
|
+
|
|
|
+ return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(result))
|
|
|
|
|
|
|
|
|
async def _rpc_forward(
|