|
@@ -10,6 +10,7 @@ import time
|
|
|
from typing import Any, Collection, Dict, List, Optional, Sequence, Union
|
|
|
from weakref import WeakMethod
|
|
|
|
|
|
+import dijkstar
|
|
|
import numpy as np
|
|
|
from hivemind import DHT, P2P, MSGPackSerializer, PeerID
|
|
|
from hivemind.dht.node import Blacklist
|
|
@@ -23,6 +24,8 @@ from petals.client.routing.spending_policy import NoSpendingPolicy
|
|
|
from petals.constants import PUBLIC_INITIAL_PEERS
|
|
|
from petals.data_structures import ModuleUID, RemoteSpanInfo, ServerState
|
|
|
from petals.server.handler import TransformerConnectionHandler
|
|
|
+from petals.utils.ping import PingAggregator
|
|
|
+from petals.utils.random import sample_up_to
|
|
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
@@ -33,6 +36,7 @@ class SequenceManagerConfig:
|
|
|
dht_prefix: Optional[str] = None # a prefix for all dht keys that correspond to this model (default: model name)
|
|
|
daemon_startup_timeout: int = 60 # timeout for the libp2p daemon connecting to initial peers
|
|
|
|
|
|
+ show_route: Union[str, bool] = "inference" # show chosen route through servers. one of [False, "inference", True]
|
|
|
allowed_servers: Optional[Collection[Union[PeerID, str]]] = None # if defined, send requests only to these servers
|
|
|
use_server_to_server: bool = True # Use direct server-to-server communication
|
|
|
|
|
@@ -43,7 +47,10 @@ class SequenceManagerConfig:
|
|
|
min_backoff: float = 1 # after a repeated failure, sleep for this many seconds times 2 ** (num_failures - 1)
|
|
|
max_backoff: float = 60 # limit maximal sleep time between retries to this value
|
|
|
ban_timeout: float = 15 # when a remote peer fails to respond, prevent routing to that peer for this many seconds
|
|
|
- active_adapter: Optional[str] = None
|
|
|
+ active_adapter: Optional[str] = None # name of active LoRA adapter (usually, Hugging Face repo)
|
|
|
+
|
|
|
+ max_pinged: int = 5 # max servers to ping from each sequence side, per update
|
|
|
+ ping_timeout: float = 2 # max time to wait for pings, per update
|
|
|
|
|
|
|
|
|
@dataclasses.dataclass
|
|
@@ -79,7 +86,6 @@ class RemoteSequenceManager:
|
|
|
*,
|
|
|
dht: Optional[DHT] = None,
|
|
|
state: Optional[SequenceManagerState] = None,
|
|
|
- active_adapter: Optional[str] = None,
|
|
|
):
|
|
|
assert config.initial_peers or dht is not None, "Please specify `config.initial_peers` or `dht`"
|
|
|
assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
|
|
@@ -94,7 +100,7 @@ class RemoteSequenceManager:
|
|
|
dht = DHT(
|
|
|
initial_peers=config.initial_peers,
|
|
|
client_mode=True,
|
|
|
- num_workers=config.num_hidden_layers,
|
|
|
+ num_workers=32,
|
|
|
startup_timeout=config.daemon_startup_timeout,
|
|
|
start=True,
|
|
|
)
|
|
@@ -109,25 +115,25 @@ class RemoteSequenceManager:
|
|
|
self._thread_start_lock = threading.Lock()
|
|
|
self.policy = NoSpendingPolicy()
|
|
|
|
|
|
+ self.ping_aggregator = PingAggregator(dht)
|
|
|
+
|
|
|
if state.banned_peers is None:
|
|
|
state.banned_peers = Blacklist(base_time=config.ban_timeout, backoff_rate=2.0)
|
|
|
if state.sequence_info is None:
|
|
|
state.sequence_info = RemoteSequenceInfo.make_empty(block_uids)
|
|
|
|
|
|
- if state.sequence_info.last_updated_time is None:
|
|
|
- # Pre-fetch module infos in DHT in parallel with .from_pretrained(), then use cached records
|
|
|
- # in the first _update() instead of the latest ones. This makes the first .update() faster.
|
|
|
- petals.dht_utils.get_remote_module_infos(
|
|
|
- self.dht, self.block_uids, active_adapter=active_adapter, latest=True, return_future=True
|
|
|
- )
|
|
|
- self._need_latest_infos = False
|
|
|
- else:
|
|
|
+ if state.sequence_info.last_updated_time is not None:
|
|
|
assert block_uids == state.sequence_info.block_uids
|
|
|
self._thread.ready.set() # no need to await the first dht fetch
|
|
|
self._need_latest_infos = True
|
|
|
|
|
|
def make_sequence(
|
|
|
- self, start_index: int = 0, end_index: Optional[int] = None, *, mode: str
|
|
|
+ self,
|
|
|
+ start_index: int = 0,
|
|
|
+ end_index: Optional[int] = None,
|
|
|
+ *,
|
|
|
+ mode: str,
|
|
|
+ cache_tokens_needed: Optional[int] = None,
|
|
|
) -> List[RemoteSpanInfo]:
|
|
|
"""
|
|
|
Form a sequence of remote servers that collectively serve all consecutive layers
|
|
@@ -143,6 +149,150 @@ class RemoteSequenceManager:
|
|
|
self.update(wait=True) # this will await an existing update or trigger a new one (if not updating)
|
|
|
|
|
|
end_index = end_index if end_index is not None else len(self)
|
|
|
+
|
|
|
+ if mode == "min_latency":
|
|
|
+ span_sequence = self._make_sequence_with_min_latency(
|
|
|
+ start_index, end_index, cache_tokens_needed=cache_tokens_needed
|
|
|
+ )
|
|
|
+ elif mode == "max_throughput":
|
|
|
+ span_sequence = self._make_sequence_with_max_throughput(start_index, end_index)
|
|
|
+ else:
|
|
|
+ raise RuntimeError(f"Unexpected mode {mode}")
|
|
|
+
|
|
|
+ if self.config.show_route is True or (mode == "min_latency" and self.config.show_route == "inference"):
|
|
|
+ route_repr = " => ".join(
|
|
|
+ [f"{span.start}:{span.end} via …{str(span.peer_id)[-6:]}" for span in span_sequence]
|
|
|
+ )
|
|
|
+ logger.info(f"Route found: {route_repr}")
|
|
|
+ return span_sequence
|
|
|
+
|
|
|
+ def _make_sequence_with_min_latency(
|
|
|
+ self, start_index: int, end_index: int, *, cache_tokens_needed: Optional[int]
|
|
|
+ ) -> List[RemoteSpanInfo]:
|
|
|
+ if start_index == end_index:
|
|
|
+ return []
|
|
|
+
|
|
|
+ with self.lock_changes:
|
|
|
+ missing_blocks = [
|
|
|
+ block_idx
|
|
|
+ for block_idx in range(start_index, end_index)
|
|
|
+ if not self.state.sequence_info.spans_containing_block[block_idx]
|
|
|
+ ]
|
|
|
+ if missing_blocks:
|
|
|
+ raise MissingBlocksError(missing_blocks)
|
|
|
+ server_infos = {
|
|
|
+ span.peer_id: span.server_info
|
|
|
+ for block_idx in range(start_index, end_index)
|
|
|
+ for span in self.state.sequence_info.spans_containing_block[block_idx]
|
|
|
+ }
|
|
|
+
|
|
|
+ graph = self._build_inference_graph(start_index, end_index, cache_tokens_needed=cache_tokens_needed)
|
|
|
+
|
|
|
+ path = dijkstar.find_path(graph, "start", "end")
|
|
|
+ logger.debug(f"Path info: {path}")
|
|
|
+ if start_index == 0 and end_index == len(self):
|
|
|
+ logger.debug(f"Expected speed: {1 / path.total_cost:.1f} steps/sec")
|
|
|
+
|
|
|
+ span_sequence = []
|
|
|
+ for peer_id, block_idx in path.nodes[1:-1]:
|
|
|
+ if not span_sequence or span_sequence[-1].peer_id != peer_id:
|
|
|
+ span_sequence.append(RemoteSpanInfo(peer_id, block_idx, block_idx, server_infos[peer_id]))
|
|
|
+ else:
|
|
|
+ span_sequence[-1].end = block_idx
|
|
|
+
|
|
|
+ # Remove empty spans that can appear if we don't force to go to the end of each server and network delays
|
|
|
+ # don't follow the triangle inequality (i.e., delay(A, B) + delay(B, C) < delay(A, C)) due to measurement errors
|
|
|
+ span_sequence = [span for span in span_sequence if span.length > 0]
|
|
|
+
|
|
|
+ return span_sequence
|
|
|
+
|
|
|
+ def _build_inference_graph(
|
|
|
+ self,
|
|
|
+ start_index: int,
|
|
|
+ end_index: int,
|
|
|
+ *,
|
|
|
+ cache_tokens_needed: Optional[int],
|
|
|
+ overhead_coeff: float = 1.82, # Backend overhead (empirically measured)
|
|
|
+ overhead_delay: float = 0.018, # Serialization overhead (empirically measured)
|
|
|
+ default_inference_rps: float = 300, # If inference RPS unknown
|
|
|
+ alloc_delay: float = 10, # If not enough cache left, we penalize the edge
|
|
|
+ ) -> dijkstar.Graph:
|
|
|
+ missing_blocks = [
|
|
|
+ block_idx
|
|
|
+ for block_idx in range(start_index, end_index)
|
|
|
+ if not self.state.sequence_info.spans_containing_block[block_idx]
|
|
|
+ ]
|
|
|
+ if missing_blocks:
|
|
|
+ raise MissingBlocksError(missing_blocks)
|
|
|
+
|
|
|
+ client_server_rtts = self.ping_aggregator.to_dict()
|
|
|
+
|
|
|
+ graph = dijkstar.Graph()
|
|
|
+
|
|
|
+ # Client -> server network delays
|
|
|
+ for span in self.state.sequence_info.spans_containing_block[start_index]:
|
|
|
+ delay = self._rtt_to_delay(client_server_rtts.get(span.peer_id))
|
|
|
+ delay += overhead_delay
|
|
|
+ if not self._has_cache_for(span, cache_tokens_needed):
|
|
|
+ delay += alloc_delay
|
|
|
+ graph.add_edge("start", (span.peer_id, start_index), delay)
|
|
|
+
|
|
|
+ # Server -> client network delays
|
|
|
+ for span in self.state.sequence_info.spans_containing_block[end_index - 1]:
|
|
|
+ delay = self._rtt_to_delay(client_server_rtts.get(span.peer_id))
|
|
|
+ graph.add_edge((span.peer_id, end_index), "end", delay)
|
|
|
+
|
|
|
+ # Server -> server network delays
|
|
|
+ for block_idx in range(start_index + 1, end_index):
|
|
|
+ for cur_span in self.state.sequence_info.spans_containing_block[block_idx - 1]:
|
|
|
+ if cur_span.end != block_idx:
|
|
|
+ # If we choose a server, we force to go to the end of it before switching to a new one
|
|
|
+ # to avoid O(N^2) graphs for N servers
|
|
|
+ continue
|
|
|
+
|
|
|
+ for next_span in self.state.sequence_info.spans_containing_block[block_idx]:
|
|
|
+ rtt = None
|
|
|
+ if cur_span.server_info.next_pings is not None:
|
|
|
+ rtt = cur_span.server_info.next_pings.get(next_span.peer_id.to_base58())
|
|
|
+ delay = self._rtt_to_delay(rtt)
|
|
|
+ delay += overhead_delay
|
|
|
+ if not self._has_cache_for(next_span, cache_tokens_needed):
|
|
|
+ delay += alloc_delay
|
|
|
+ graph.add_edge((cur_span.peer_id, block_idx), (next_span.peer_id, block_idx), delay)
|
|
|
+
|
|
|
+ # Per-block compute delays (time for a server to run inference on one block)
|
|
|
+ for span in self.state.sequence_info.spans_by_priority:
|
|
|
+ for block_idx in range(max(span.start, start_index), min(span.end, end_index)):
|
|
|
+ inference_rps = span.server_info.inference_rps
|
|
|
+ if inference_rps is None:
|
|
|
+ inference_rps = default_inference_rps
|
|
|
+ graph.add_edge((span.peer_id, block_idx), (span.peer_id, block_idx + 1), overhead_coeff / inference_rps)
|
|
|
+
|
|
|
+ return graph
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def _rtt_to_delay(
|
|
|
+ rtt: float,
|
|
|
+ *,
|
|
|
+ default_delay: float = 0.15, # If network delay unknown
|
|
|
+ max_delay: float = 5, # If unreachable, we don't want to discard the edge completely
|
|
|
+ ) -> float:
|
|
|
+ if rtt is None:
|
|
|
+ return default_delay
|
|
|
+ return min(rtt / 2, max_delay)
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def _has_cache_for(span: RemoteSpanInfo, cache_tokens_needed: Optional[int] = None) -> bool:
|
|
|
+ if cache_tokens_needed is None or span.server_info.cache_tokens_left is None:
|
|
|
+ return True
|
|
|
+
|
|
|
+ # Here, `span` contains all blocks hosted by a server - but we won't necessarily run all of them through
|
|
|
+ # this particular server in our path. It is difficult to estimate how many blocks we'll use at this stage,
|
|
|
+ # so we assume that we'll use all of them (the worst case for the cache size) and get a pessimistic estimate.
|
|
|
+ # This is okay since false positives are more costly than false negatives here.
|
|
|
+ return cache_tokens_needed * 2 * span.length <= span.server_info.cache_tokens_left
|
|
|
+
|
|
|
+ def _make_sequence_with_max_throughput(self, start_index: int, end_index: int) -> List[RemoteSpanInfo]:
|
|
|
span_sequence = []
|
|
|
current_index = start_index
|
|
|
while current_index < end_index:
|
|
@@ -150,20 +300,12 @@ class RemoteSequenceManager:
|
|
|
if not candidate_spans:
|
|
|
raise MissingBlocksError(current_index)
|
|
|
|
|
|
- if mode == "max_throughput":
|
|
|
- span_weights = np.array([span.server_info.throughput for span in candidate_spans], dtype=np.float64)
|
|
|
- elif mode == "min_latency":
|
|
|
- span_weights = np.array([span.end - current_index for span in candidate_spans], dtype=np.float64)
|
|
|
- else:
|
|
|
- raise RuntimeError(f"Unexpected mode {mode}")
|
|
|
+ span_weights = np.array([span.server_info.throughput for span in candidate_spans], dtype=np.float64)
|
|
|
chosen_span = np.random.choice(candidate_spans, p=span_weights / span_weights.sum())
|
|
|
|
|
|
assert chosen_span.start <= current_index < chosen_span.end
|
|
|
span_sequence.append(dataclasses.replace(chosen_span, start=current_index))
|
|
|
current_index = chosen_span.end
|
|
|
-
|
|
|
- route_repr = " => ".join([f"{span.start}:{span.end} via …{str(span.peer_id)[-6:]}" for span in span_sequence])
|
|
|
- logger.debug(f"Route found: {route_repr}")
|
|
|
return span_sequence
|
|
|
|
|
|
def __getitem__(self, ix: Union[int, slice]) -> RemoteSequenceManager:
|
|
@@ -182,10 +324,10 @@ class RemoteSequenceManager:
|
|
|
|
|
|
def _update(self):
|
|
|
"""Perform an immediate and synchronous refresh, may take time"""
|
|
|
+
|
|
|
new_block_infos = petals.dht_utils.get_remote_module_infos(
|
|
|
- self.dht, self.block_uids, active_adapter=self.config.active_adapter, latest=self._need_latest_infos
|
|
|
+ self.dht, self.block_uids, active_adapter=self.config.active_adapter, latest=True
|
|
|
)
|
|
|
- self._need_latest_infos = True # All future _update() should use latest infos
|
|
|
|
|
|
for block_info in new_block_infos:
|
|
|
if not block_info:
|
|
@@ -217,6 +359,14 @@ class RemoteSequenceManager:
|
|
|
|
|
|
with self.lock_changes:
|
|
|
self.state.sequence_info.update_(new_block_infos)
|
|
|
+
|
|
|
+ first_servers = [span.peer_id for span in self.state.sequence_info.spans_containing_block[0]]
|
|
|
+ last_servers = [span.peer_id for span in self.state.sequence_info.spans_containing_block[-1]]
|
|
|
+
|
|
|
+ pinged_servers = set(sample_up_to(first_servers, self.config.max_pinged))
|
|
|
+ pinged_servers |= set(sample_up_to(last_servers, self.config.max_pinged))
|
|
|
+ self.ping_aggregator.ping(list(pinged_servers), wait_timeout=self.config.ping_timeout)
|
|
|
+
|
|
|
self.ready.set()
|
|
|
|
|
|
def on_request_failure(self, peer_id: Optional[PeerID]):
|