
Return available cache size in rpc_info() (#191)

This PR makes servers report their free cache size in rpc_info(), measured in tokens * layers to keep the value compression-agnostic.

This value is intended to be used when calling make_sequence(optimize="inference").
justheuristic 2 years ago
parent commit 5f58f00649
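
For context, here is roughly how a client can read the new field once a server with this commit is running. This is a minimal sketch, not part of the diff: it assumes a live swarm, takes INITIAL_PEERS and MODEL_NAME from the test utilities (as the test below does), and leaves the actual server-selection heuristic behind make_sequence(optimize="inference") for a follow-up.

    import hivemind
    from test_utils import *  # provides INITIAL_PEERS and MODEL_NAME in the test environment

    from petals.client import DistributedBloomConfig
    from petals.data_structures import UID_DELIMITER
    from petals.dht_utils import get_remote_sequence
    from petals.server.handler import CACHE_TOKENS_AVAILABLE

    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
    config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
    blocks = get_remote_sequence(dht, 0, 2, config, f"{MODEL_NAME}{UID_DELIMITER}")

    # CACHE_TOKENS_AVAILABLE is measured in tokens * layers, so it is comparable
    # across servers regardless of their cache compression settings.
    free_tokens = blocks.sequence_manager.rpc_info[CACHE_TOKENS_AVAILABLE]
    print(f"server-side cache has room for {free_tokens} more tokens")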

+ 5 - 0
src/petals/server/backend.py

@@ -1,6 +1,7 @@
 """Code for serving bloom blocks via hivemind-server"""
 from __future__ import annotations
 
+from collections import Counter
 from itertools import chain
 from typing import Any, Dict, Sequence, Tuple
 
@@ -64,6 +65,10 @@ class TransformerBackend(ModuleBackend):
             self.kwargs_schema,
         )
 
+        self.cache_bytes_per_token: Dict[torch.device, int] = Counter()
+        for descr in self.get_inference_cache_descriptors(batch_size=1, max_length=1):
+            self.cache_bytes_per_token[descr.device] += descr.numel() * torch.finfo(descr.dtype).bits // 8
+
     def get_inference_cache_descriptors(self, batch_size: int, max_length: int) -> Sequence[TensorDescriptor]:
         """Create tensor descriptors for attention cache tensors used during inference_step"""
         head_dim = self.config.hidden_size // self.config.n_head
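
To make the new counter concrete: with batch_size=1 and max_length=1, each descriptor covers the cache for exactly one token of one block, so summing numel() times bytes-per-element over all descriptors yields the per-token, per-block cache cost on each device. A worked sketch with hypothetical BLOOM-like numbers (hidden_size=14336, n_head=112, bfloat16 keys and values; real servers take these from the model config):

    from collections import Counter

    import torch

    hidden_size, n_head = 14336, 112       # hypothetical config values
    head_dim = hidden_size // n_head       # 128, as in the code above

    device = torch.device("cuda:0")
    cache_bytes_per_token = Counter()
    for _name in ("keys", "values"):       # one descriptor per attention cache tensor
        numel = 1 * n_head * head_dim * 1  # batch_size * n_head * head_dim * max_length
        cache_bytes_per_token[device] += numel * torch.finfo(torch.bfloat16).bits // 8

    print(cache_bytes_per_token[device])   # 2 * 14336 * 2 bytes = 57344 bytes per token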

+ 18 - 0
src/petals/server/handler.py

@@ -33,6 +33,8 @@ from petals.utils.misc import DUMMY, is_dummy
 
 logger = get_logger(__file__)
 
+CACHE_TOKENS_AVAILABLE = "cache_tokens_available"
+
 
 class TransformerConnectionHandler(ConnectionHandler):
     """Handles three request types: forward, backward and forward-incremental (inference)"""
@@ -378,6 +380,22 @@ class TransformerConnectionHandler(ConnectionHandler):
         else:
             logger.warning(f"{message}: {warning}")
 
+    async def rpc_info(self, request: runtime_pb2.ExpertUID, context: P2PContext) -> runtime_pb2.ExpertInfo:
+        """Return metadata about stored block uids and current load"""
+        rpc_info = {}
+        if request.uid:
+            backend = self.module_backends[request.uid]
+            rpc_info.update(backend.get_info())
+        else:
+            backend = next(iter(self.module_backends.values()))
+            # not saving keys to rpc_info since user did not request any uid
+
+        cache_bytes_left = max(0, backend.memory_cache.max_size_bytes - backend.memory_cache.current_size_bytes)
+        if CACHE_TOKENS_AVAILABLE in rpc_info:
+            raise RuntimeError(f"Block rpc_info dict has a reserved field {CACHE_TOKENS_AVAILABLE} : {rpc_info}")
+        rpc_info[CACHE_TOKENS_AVAILABLE] = cache_bytes_left // max(backend.cache_bytes_per_token.values())
+        return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(rpc_info))
+
 
 async def _rpc_forward(
     *flat_tensors: torch.Tensor,
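
On the wire, rpc_info now carries one extra integer per server: the free cache bytes divided by the per-token cost of the most expensive device. A client-side check could use it like this (a sketch, not part of this commit; it relies only on the fact, asserted by the test below, that an inference session of max_length tokens spanning num_blocks blocks on one server reserves max_length * num_blocks cache tokens):

    from petals.server.handler import CACHE_TOKENS_AVAILABLE

    def fits_on_server(rpc_info: dict, max_length: int, num_blocks: int) -> bool:
        # CACHE_TOKENS_AVAILABLE is measured in tokens * layers, so a session of
        # max_length tokens over num_blocks blocks needs max_length * num_blocks slots.
        return rpc_info[CACHE_TOKENS_AVAILABLE] >= max_length * num_blocks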

+ 39 - 0
tests/test_server_stats.py

@@ -0,0 +1,39 @@
+import time
+
+import hivemind
+import pytest
+import torch
+from test_utils import *
+
+from petals.client import DistributedBloomConfig
+from petals.data_structures import UID_DELIMITER
+from petals.dht_utils import get_remote_sequence
+from petals.server.handler import CACHE_TOKENS_AVAILABLE
+
+
+@pytest.mark.forked
+def test_server_info(block_from: int = 22, block_to: int = 24, max_length: int = 100, max_length2: int = 50):
+    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
+    config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
+
+    blocks1 = get_remote_sequence(dht, block_from, block_to, config, f"{MODEL_NAME}{UID_DELIMITER}")
+    blocks2 = get_remote_sequence(dht, block_to - 1, block_to, config, f"{MODEL_NAME}{UID_DELIMITER}")
+    info_before = blocks1.sequence_manager.rpc_info
+
+    with blocks1.inference_session(max_length=max_length) as sess:
+        sess.step(torch.randn(1, 1, config.hidden_size))
+        blocks1.sequence_manager._rpc_info = None  # invalidate cache
+        info_inside = blocks1.sequence_manager.rpc_info
+
+        with blocks2.inference_session(max_length=max_length2) as sess2:
+            sess2.step(torch.randn(1, 1, config.hidden_size))
+            blocks2.sequence_manager._rpc_info = None  # invalidate cache
+            info_inside2 = blocks2.sequence_manager.rpc_info
+
+    time.sleep(0.1)
+    blocks1.sequence_manager._rpc_info = None  # invalidate cache
+    info_after = blocks1.sequence_manager.rpc_info
+
+    assert info_before[CACHE_TOKENS_AVAILABLE] == info_after[CACHE_TOKENS_AVAILABLE]
+    assert info_before[CACHE_TOKENS_AVAILABLE] - info_inside[CACHE_TOKENS_AVAILABLE] == max_length * len(blocks1)
+    assert info_inside[CACHE_TOKENS_AVAILABLE] - info_inside2[CACHE_TOKENS_AVAILABLE] == max_length2 * len(blocks2)
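
For the default arguments, the accounting works out as follows (the same arithmetic as the assertions above, spelled out):

    # blocks1 spans blocks [22, 24) => 2 blocks; blocks2 spans [23, 24) => 1 block.
    num_blocks1, num_blocks2 = 24 - 22, 24 - 23
    assert 100 * num_blocks1 == 200  # tokens * layers reserved while sess is open
    assert 50 * num_blocks2 == 50    # additional reservation while sess2 is open
    # Once both sessions exit, the reservations are freed, so info_after == info_before.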