فهرست منبع

Implement shortest-path routing for inference (#362)

This PR:

1. **Adds shortest path routing for inference.** We build a graph with client-server and server-server latencies and compute costs, as well as empirically measured overheads. For client-server latencies, we ping possible first and last servers in a sequence in `SequenceManager.update()`. We penalize servers who may not have enough cache for our request. This uses info added to DHT in #355, #356, #358.

2. **Makes a server ping neighboring servers in addition to next ones.** This is to get an opportunity to change the server even before we use all its blocks (e.g., because a neighboring server is faster). This feature is not enabled though, since it increases graph size for N servers to O(N^2) - but we may enable it if needed.

3. **Fixes a `SequenceManager` bug with the first `update()`.** Previously, this update was likely to produce incorrect information and cause to `MissingBlocksErrors` until the next update happens.
Alexander Borzunov 2 سال پیش
والد
کامیت
62d9ed5ce7

+ 1 - 0
setup.cfg

@@ -48,6 +48,7 @@ install_requires =
     sentencepiece>=0.1.99
     peft@git+https://github.com/huggingface/peft@5884bdbea49e5e71e2cd06ecfa484bb635063735
     safetensors>=0.3.1
+    Dijkstar>=2.6.0
 
 [options.extras_require]
 dev =

+ 1 - 1
src/petals/__init__.py

@@ -11,7 +11,7 @@ from petals.models import *
 from petals.utils import *
 from petals.utils.logging import initialize_logs as _initialize_logs
 
-__version__ = "1.2.0.dev2"
+__version__ = "1.2.0.dev3"
 
 
 if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):

+ 1 - 1
src/petals/cli/run_server.py

@@ -84,7 +84,7 @@ def main():
     parser.add_argument('--attn_cache_tokens', type=int, default=8192,
                         help='The number of past attention key/value pairs that will be stored between inference steps. '
                              'Default: 8192 (4 simultaneous sessions of up to 2048 tokens).')
-    parser.add_argument('--alloc_timeout', type=float, default=60,
+    parser.add_argument('--alloc_timeout', type=float, default=5,
                         help='If the cache is full, the server will wait for this number of seconds hoping that some memory will be freed '
                              'before rejecting the request')
     parser.add_argument('--revision', type=str, default=None,

+ 3 - 1
src/petals/client/inference_session.py

@@ -340,7 +340,9 @@ class InferenceSession:
                 f"from block {block_idx} to {update_end} will be regenerated"
             )
 
-        updated_spans = self._sequence_manager.make_sequence(block_idx, update_end, mode="min_latency")
+        updated_spans = self._sequence_manager.make_sequence(
+            block_idx, update_end, mode="min_latency", cache_tokens_needed=self._max_length
+        )
         # make_sequence() could return a longer sequence
         updated_spans[-1].end = min(updated_spans[-1].end, update_end)
         updated_sessions = self._enter_server_sessions(updated_spans)

+ 173 - 23
src/petals/client/routing/sequence_manager.py

@@ -10,6 +10,7 @@ import time
 from typing import Any, Collection, Dict, List, Optional, Sequence, Union
 from weakref import WeakMethod
 
+import dijkstar
 import numpy as np
 from hivemind import DHT, P2P, MSGPackSerializer, PeerID
 from hivemind.dht.node import Blacklist
@@ -23,6 +24,8 @@ from petals.client.routing.spending_policy import NoSpendingPolicy
 from petals.constants import PUBLIC_INITIAL_PEERS
 from petals.data_structures import ModuleUID, RemoteSpanInfo, ServerState
 from petals.server.handler import TransformerConnectionHandler
+from petals.utils.ping import PingAggregator
+from petals.utils.random import sample_up_to
 
 logger = get_logger(__name__)
 
@@ -33,6 +36,7 @@ class SequenceManagerConfig:
     dht_prefix: Optional[str] = None  # a prefix for all dht keys that correspond to this model (default: model name)
     daemon_startup_timeout: int = 60  # timeout for the libp2p daemon connecting to initial peers
 
+    show_route: Union[str, bool] = "inference"  # show chosen route through servers. one of [False, "inference", True]
     allowed_servers: Optional[Collection[Union[PeerID, str]]] = None  # if defined, send requests only to these servers
     use_server_to_server: bool = True  # Use direct server-to-server communication
 
@@ -43,7 +47,10 @@ class SequenceManagerConfig:
     min_backoff: float = 1  # after a repeated failure, sleep for this many seconds times 2 ** (num_failures - 1)
     max_backoff: float = 60  # limit maximal sleep time between retries to this value
     ban_timeout: float = 15  # when a remote peer fails to respond, prevent routing to that peer for this many seconds
-    active_adapter: Optional[str] = None
+    active_adapter: Optional[str] = None  # name of active LoRA adapter (usually, Hugging Face repo)
+
+    max_pinged: int = 5  # max servers to ping from each sequence side, per update
+    ping_timeout: float = 2  # max time to wait for pings, per update
 
 
 @dataclasses.dataclass
@@ -79,7 +86,6 @@ class RemoteSequenceManager:
         *,
         dht: Optional[DHT] = None,
         state: Optional[SequenceManagerState] = None,
-        active_adapter: Optional[str] = None,
     ):
         assert config.initial_peers or dht is not None, "Please specify `config.initial_peers` or `dht`"
         assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
@@ -94,7 +100,7 @@ class RemoteSequenceManager:
             dht = DHT(
                 initial_peers=config.initial_peers,
                 client_mode=True,
-                num_workers=config.num_hidden_layers,
+                num_workers=32,
                 startup_timeout=config.daemon_startup_timeout,
                 start=True,
             )
@@ -109,25 +115,25 @@ class RemoteSequenceManager:
         self._thread_start_lock = threading.Lock()
         self.policy = NoSpendingPolicy()
 
+        self.ping_aggregator = PingAggregator(dht)
+
         if state.banned_peers is None:
             state.banned_peers = Blacklist(base_time=config.ban_timeout, backoff_rate=2.0)
         if state.sequence_info is None:
             state.sequence_info = RemoteSequenceInfo.make_empty(block_uids)
 
-        if state.sequence_info.last_updated_time is None:
-            # Pre-fetch module infos in DHT in parallel with .from_pretrained(), then use cached records
-            # in the first _update() instead of the latest ones. This makes the first .update() faster.
-            petals.dht_utils.get_remote_module_infos(
-                self.dht, self.block_uids, active_adapter=active_adapter, latest=True, return_future=True
-            )
-            self._need_latest_infos = False
-        else:
+        if state.sequence_info.last_updated_time is not None:
             assert block_uids == state.sequence_info.block_uids
             self._thread.ready.set()  # no need to await the first dht fetch
             self._need_latest_infos = True
 
     def make_sequence(
-        self, start_index: int = 0, end_index: Optional[int] = None, *, mode: str
+        self,
+        start_index: int = 0,
+        end_index: Optional[int] = None,
+        *,
+        mode: str,
+        cache_tokens_needed: Optional[int] = None,
     ) -> List[RemoteSpanInfo]:
         """
         Form a sequence of remote servers that collectively serve all consecutive layers
@@ -143,6 +149,150 @@ class RemoteSequenceManager:
             self.update(wait=True)  # this will await an existing update or trigger a new one (if not updating)
 
         end_index = end_index if end_index is not None else len(self)
+
+        if mode == "min_latency":
+            span_sequence = self._make_sequence_with_min_latency(
+                start_index, end_index, cache_tokens_needed=cache_tokens_needed
+            )
+        elif mode == "max_throughput":
+            span_sequence = self._make_sequence_with_max_throughput(start_index, end_index)
+        else:
+            raise RuntimeError(f"Unexpected mode {mode}")
+
+        if self.config.show_route is True or (mode == "min_latency" and self.config.show_route == "inference"):
+            route_repr = " => ".join(
+                [f"{span.start}:{span.end} via …{str(span.peer_id)[-6:]}" for span in span_sequence]
+            )
+            logger.info(f"Route found: {route_repr}")
+        return span_sequence
+
+    def _make_sequence_with_min_latency(
+        self, start_index: int, end_index: int, *, cache_tokens_needed: Optional[int]
+    ) -> List[RemoteSpanInfo]:
+        if start_index == end_index:
+            return []
+
+        with self.lock_changes:
+            missing_blocks = [
+                block_idx
+                for block_idx in range(start_index, end_index)
+                if not self.state.sequence_info.spans_containing_block[block_idx]
+            ]
+            if missing_blocks:
+                raise MissingBlocksError(missing_blocks)
+            server_infos = {
+                span.peer_id: span.server_info
+                for block_idx in range(start_index, end_index)
+                for span in self.state.sequence_info.spans_containing_block[block_idx]
+            }
+
+            graph = self._build_inference_graph(start_index, end_index, cache_tokens_needed=cache_tokens_needed)
+
+        path = dijkstar.find_path(graph, "start", "end")
+        logger.debug(f"Path info: {path}")
+        if start_index == 0 and end_index == len(self):
+            logger.debug(f"Expected speed: {1 / path.total_cost:.1f} steps/sec")
+
+        span_sequence = []
+        for peer_id, block_idx in path.nodes[1:-1]:
+            if not span_sequence or span_sequence[-1].peer_id != peer_id:
+                span_sequence.append(RemoteSpanInfo(peer_id, block_idx, block_idx, server_infos[peer_id]))
+            else:
+                span_sequence[-1].end = block_idx
+
+        # Remove empty spans that can appear if we don't force to go to the end of each server and network delay
+        # don't follow triangle inequality (delay(A, B) + delay(B, C) < delay(A, C)) due to measurement errors
+        span_sequence = [span for span in span_sequence if span.length > 0]
+
+        return span_sequence
+
+    def _build_inference_graph(
+        self,
+        start_index: int,
+        end_index: int,
+        *,
+        cache_tokens_needed: Optional[int],
+        overhead_coeff: float = 1.82,  # Backend overhead (empirically measured)
+        overhead_delay: float = 0.018,  # Serialization overhead (empirically measured)
+        default_inference_rps: float = 300,  # If inference RPS unknown
+        alloc_delay: float = 10,  # If not enough cache left, we penalize the edge
+    ) -> dijkstar.Graph:
+        missing_blocks = [
+            block_idx
+            for block_idx in range(start_index, end_index)
+            if not self.state.sequence_info.spans_containing_block[block_idx]
+        ]
+        if missing_blocks:
+            raise MissingBlocksError(missing_blocks)
+
+        client_server_rtts = self.ping_aggregator.to_dict()
+
+        graph = dijkstar.Graph()
+
+        # Clent -> server network delays
+        for span in self.state.sequence_info.spans_containing_block[start_index]:
+            delay = self._rtt_to_delay(client_server_rtts.get(span.peer_id))
+            delay += overhead_delay
+            if not self._has_cache_for(span, cache_tokens_needed):
+                delay += alloc_delay
+            graph.add_edge("start", (span.peer_id, start_index), delay)
+
+        # Server -> client network delays
+        for span in self.state.sequence_info.spans_containing_block[end_index - 1]:
+            delay = self._rtt_to_delay(client_server_rtts.get(span.peer_id))
+            graph.add_edge((span.peer_id, end_index), "end", delay)
+
+        # Server -> server network delays
+        for block_idx in range(start_index + 1, end_index):
+            for cur_span in self.state.sequence_info.spans_containing_block[block_idx - 1]:
+                if cur_span.end != block_idx:
+                    # If we choose a server, we force to go to the end of it before switching to a new one
+                    # to avoid O(N^2) graphs for N servers
+                    continue
+
+                for next_span in self.state.sequence_info.spans_containing_block[block_idx]:
+                    rtt = None
+                    if cur_span.server_info.next_pings is not None:
+                        rtt = cur_span.server_info.next_pings.get(next_span.peer_id.to_base58())
+                    delay = self._rtt_to_delay(rtt)
+                    delay += overhead_delay
+                    if not self._has_cache_for(next_span, cache_tokens_needed):
+                        delay += alloc_delay
+                    graph.add_edge((cur_span.peer_id, block_idx), (next_span.peer_id, block_idx), delay)
+
+        # Compute delays
+        for span in self.state.sequence_info.spans_by_priority:
+            for block_idx in range(max(span.start, start_index), min(span.end, end_index)):
+                inference_rps = span.server_info.inference_rps
+                if inference_rps is None:
+                    inference_rps = default_inference_rps
+                graph.add_edge((span.peer_id, block_idx), (span.peer_id, block_idx + 1), overhead_coeff / inference_rps)
+
+        return graph
+
+    @staticmethod
+    def _rtt_to_delay(
+        rtt: float,
+        *,
+        default_delay: float = 0.15,  # If network delay unknown
+        max_delay: float = 5,  # If unreachable, we don't want to discard the edge completely
+    ) -> float:
+        if rtt is None:
+            return default_delay
+        return min(rtt / 2, max_delay)
+
+    @staticmethod
+    def _has_cache_for(span: RemoteSpanInfo, cache_tokens_needed: Optional[int] = None) -> bool:
+        if cache_tokens_needed is None or span.server_info.cache_tokens_left is None:
+            return True
+
+        # Here, `span` contains all blocks hosted by a server - but we won't necessarily run all of them through
+        # this particular server in our path. It is difficult to estimate how many blocks we'll use at this stage,
+        # so we assume that we'll use all of them (the worst case for the cache size) and get a pessimistic estimate.
+        # This is okay since false positives are more costly than false negatives here.
+        return cache_tokens_needed * 2 * span.length <= span.server_info.cache_tokens_left
+
+    def _make_sequence_with_max_throughput(self, start_index: int, end_index: int) -> List[RemoteSpanInfo]:
         span_sequence = []
         current_index = start_index
         while current_index < end_index:
@@ -150,20 +300,12 @@ class RemoteSequenceManager:
             if not candidate_spans:
                 raise MissingBlocksError(current_index)
 
-            if mode == "max_throughput":
-                span_weights = np.array([span.server_info.throughput for span in candidate_spans], dtype=np.float64)
-            elif mode == "min_latency":
-                span_weights = np.array([span.end - current_index for span in candidate_spans], dtype=np.float64)
-            else:
-                raise RuntimeError(f"Unexpected mode {mode}")
+            span_weights = np.array([span.server_info.throughput for span in candidate_spans], dtype=np.float64)
             chosen_span = np.random.choice(candidate_spans, p=span_weights / span_weights.sum())
 
             assert chosen_span.start <= current_index < chosen_span.end
             span_sequence.append(dataclasses.replace(chosen_span, start=current_index))
             current_index = chosen_span.end
-
-        route_repr = " => ".join([f"{span.start}:{span.end} via …{str(span.peer_id)[-6:]}" for span in span_sequence])
-        logger.debug(f"Route found: {route_repr}")
         return span_sequence
 
     def __getitem__(self, ix: Union[int, slice]) -> RemoteSequenceManager:
@@ -182,10 +324,10 @@ class RemoteSequenceManager:
 
     def _update(self):
         """Perform an immediate and synchronous refresh, may take time"""
+
         new_block_infos = petals.dht_utils.get_remote_module_infos(
-            self.dht, self.block_uids, active_adapter=self.config.active_adapter, latest=self._need_latest_infos
+            self.dht, self.block_uids, active_adapter=self.config.active_adapter, latest=True
         )
-        self._need_latest_infos = True  # All future _update() should use latest infos
 
         for block_info in new_block_infos:
             if not block_info:
@@ -217,6 +359,14 @@ class RemoteSequenceManager:
 
         with self.lock_changes:
             self.state.sequence_info.update_(new_block_infos)
+
+            first_servers = [span.peer_id for span in self.state.sequence_info.spans_containing_block[0]]
+            last_servers = [span.peer_id for span in self.state.sequence_info.spans_containing_block[-1]]
+
+        pinged_servers = set(sample_up_to(first_servers, self.config.max_pinged))
+        pinged_servers |= set(sample_up_to(last_servers, self.config.max_pinged))
+        self.ping_aggregator.ping(list(pinged_servers), wait_timeout=self.config.ping_timeout)
+
         self.ready.set()
 
     def on_request_failure(self, peer_id: Optional[PeerID]):

+ 16 - 15
src/petals/server/server.py

@@ -32,6 +32,7 @@ from petals.server.throughput import get_dtype_name, get_server_throughput
 from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.convert_block import QuantType, check_device_balance, convert_block
 from petals.utils.ping import PingAggregator
+from petals.utils.random import sample_up_to
 from petals.utils.version import get_compatible_model_repo
 
 logger = get_logger(__name__)
@@ -61,7 +62,7 @@ class Server:
         cache_dir: Optional[str] = None,
         max_disk_space: Optional[int] = None,
         attn_cache_tokens: int = 8192,
-        alloc_timeout: float = 60,
+        alloc_timeout: float = 5,
         device: Optional[Union[str, torch.device]] = None,
         compression=CompressionType.NONE,
         stats_report_interval: Optional[int] = None,
@@ -637,7 +638,6 @@ class ModuleAnnouncerThread(threading.Thread):
         update_period: float,
         expiration: float,
         max_pinged: int = 5,
-        max_reported: int = 10,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -650,10 +650,11 @@ class ModuleAnnouncerThread(threading.Thread):
         self.expiration = expiration
         self.trigger = threading.Event()
 
-        self.max_pinged, self.max_reported = max_pinged, max_reported
-        last_uid = max(module_uids, key=lambda uid: int(uid.split(UID_DELIMITER)[-1]))
-        dht_prefix, block_index = last_uid.split(UID_DELIMITER)
-        self.next_uid = f"{dht_prefix}{UID_DELIMITER}{int(block_index) + 1}"
+        self.max_pinged = max_pinged
+        dht_prefix = module_uids[0].split(UID_DELIMITER)[0]
+        block_indices = [int(uid.split(UID_DELIMITER)[-1]) for uid in module_uids]
+        start_block, end_block = min(block_indices), max(block_indices) + 1
+        self.next_uids = [f"{dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block + 1, end_block + 1)]
         self.ping_aggregator = PingAggregator(self.dht)
 
     def run(self) -> None:
@@ -664,7 +665,7 @@ class ModuleAnnouncerThread(threading.Thread):
             if self.server_info.state != ServerState.OFFLINE:
                 self._ping_next_servers()
                 self.server_info.next_pings = {
-                    peer_id.to_base58(): rtt for peer_id, rtt in self.ping_aggregator.fastest(self.max_reported).items()
+                    peer_id.to_base58(): rtt for peer_id, rtt in self.ping_aggregator.to_dict().items()
                 }
             else:
                 self.server_info.next_pings = None  # No need to ping if we're disconnecting
@@ -691,14 +692,14 @@ class ModuleAnnouncerThread(threading.Thread):
             self.join()
 
     def _ping_next_servers(self) -> Dict[hivemind.PeerID, float]:
-        [module_info] = get_remote_module_infos(self.dht, [self.next_uid], latest=True)
-        if module_info is None:
-            return
-
-        next_servers = list(module_info.servers)
-        if len(next_servers) > self.max_pinged:
-            next_servers = random.sample(next_servers, self.max_pinged)
-        self.ping_aggregator.ping(next_servers)
+        module_infos = get_remote_module_infos(self.dht, self.next_uids, latest=True)
+        middle_servers = {peer_id for info in module_infos[:-1] if info is not None for peer_id in info.servers}
+        pinged_servers = set(sample_up_to(middle_servers, self.max_pinged))
+        pinged_servers.discard(self.dht.peer_id)
+        if module_infos[-1] is not None:
+            # Sample servers hosting the block after the last one (most likely continuations) separately
+            pinged_servers |= set(sample_up_to(module_infos[-1].servers, self.max_pinged))
+        self.ping_aggregator.ping(list(pinged_servers))
 
 
 class RuntimeWithDeduplicatedPools(Runtime):

+ 15 - 14
src/petals/utils/ping.py

@@ -1,5 +1,6 @@
 import asyncio
 import math
+import threading
 import time
 from functools import partial
 from typing import Dict, Sequence
@@ -34,27 +35,27 @@ async def ping_parallel(peer_ids: Sequence[hivemind.PeerID], *args, **kwargs) ->
 
 
 class PingAggregator:
-    def __init__(self, dht: hivemind.DHT, *, ema_alpha: float = 0.2, expiration: float = 3600):
+    def __init__(self, dht: hivemind.DHT, *, ema_alpha: float = 0.2, expiration: float = 300):
         self.dht = dht
         self.ema_alpha = ema_alpha
         self.expiration = expiration
         self.ping_emas = hivemind.TimedStorage()
+        self.lock = threading.Lock()
 
-    def ping(self, peer_ids: Sequence[hivemind.PeerID], **kwargs):
+    def ping(self, peer_ids: Sequence[hivemind.PeerID], **kwargs) -> None:
         current_rtts = self.dht.run_coroutine(partial(ping_parallel, peer_ids, **kwargs))
         logger.debug(f"Current RTTs: {current_rtts}")
 
-        expiration = hivemind.get_dht_time() + self.expiration
-        for peer_id, rtt in current_rtts.items():
-            prev_rtt = self.ping_emas.get(peer_id)
-            if prev_rtt is not None and prev_rtt.value != math.inf:
-                rtt = self.ema_alpha * rtt + (1 - self.ema_alpha) * prev_rtt.value  # Exponential smoothing
-            self.ping_emas.store(peer_id, rtt, expiration)
+        with self.lock:
+            expiration = hivemind.get_dht_time() + self.expiration
+            for peer_id, rtt in current_rtts.items():
+                prev_rtt = self.ping_emas.get(peer_id)
+                if prev_rtt is not None and prev_rtt.value != math.inf:
+                    rtt = self.ema_alpha * rtt + (1 - self.ema_alpha) * prev_rtt.value  # Exponential smoothing
+                self.ping_emas.store(peer_id, rtt, expiration)
 
-    def fastest(self, n_peers: int) -> Dict[hivemind.PeerID, float]:
-        with self.ping_emas.freeze():
+    def to_dict(self) -> Dict[hivemind.PeerID, float]:
+        with self.lock, self.ping_emas.freeze():
             smoothed_rtts = {peer_id: rtt.value for peer_id, rtt in self.ping_emas.items()}
-        logger.debug(f"Smothed RTTs: {smoothed_rtts}")
-
-        fastest_rtts = sorted(smoothed_rtts.items(), key=lambda item: item[1])[:n_peers]
-        return dict(fastest_rtts)
+            logger.debug(f"Smothed RTTs: {smoothed_rtts}")
+            return smoothed_rtts

+ 12 - 0
src/petals/utils/random.py

@@ -0,0 +1,12 @@
+import random
+from typing import Collection, TypeVar
+
+T = TypeVar("T")
+
+
+def sample_up_to(population: Collection[T], k: int) -> T:
+    if not isinstance(population, list):
+        population = list(population)
+    if len(population) > k:
+        population = random.sample(population, k)
+    return population