瀏覽代碼

Store (start_block, end_block) in each DHT record for reliability (#510)

This PR fixes gaps in the DHT server info caused by unavailable DHT keys. Now, one DHT key is enough to get info about all blocks hosted by a server - so we'll see info until all keys are unavailable.

Also, this PR refactors `petals.client.routing` and `petals.server.block_selection` modules to use the common `compute_spans()` function (defined in `petals.utils.dht`) and `RemoteSpanInfo` class (defined in `petals.data_structures`).
Alexander Borzunov 1 年之前
父節點
當前提交
5ce4f1a159

+ 13 - 52
src/petals/client/routing/sequence_info.py

@@ -1,17 +1,15 @@
 import dataclasses
 import time
-from typing import Iterable, List, Optional, Sequence, Tuple, Type, TypeVar
+from typing import Iterable, List, Optional, Tuple
 
 from hivemind import get_logger
 
 from petals.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
+from petals.utils.dht import compute_spans
 
 logger = get_logger(__name__)
 
 
-T = TypeVar("T")
-
-
 @dataclasses.dataclass
 class RemoteSequenceInfo:
     """
@@ -30,7 +28,7 @@ class RemoteSequenceInfo:
     last_updated_time: Optional[float]
 
     @classmethod
-    def make_empty(cls: Type[T], block_uids: Iterable[ModuleUID]) -> T:
+    def make_empty(cls, block_uids: Iterable[ModuleUID]) -> "RemoteSequenceInfo":
         block_uids = tuple(block_uids)
         empty_block_infos = tuple(RemoteModuleInfo(uid, {}) for uid in block_uids)
         empty_spans = tuple([] for _ in range(len(block_uids)))
@@ -39,7 +37,7 @@ class RemoteSequenceInfo:
     def __getitem__(self, ix: slice):
         assert isinstance(ix, slice)
         block_uids, block_infos = self.block_uids[ix], self.block_infos[ix]
-        spans_by_priority, spans_containing_block = self.compute_spans(block_infos)
+        spans_by_priority, spans_containing_block = self._sort_spans(block_infos)
         return RemoteSequenceInfo(
             block_uids, block_infos, spans_by_priority, spans_containing_block, self.last_updated_time
         )
@@ -47,60 +45,23 @@ class RemoteSequenceInfo:
     def __len__(self):
         return len(self.block_uids)
 
-    def update_(self, new_block_infos: List[Optional[RemoteModuleInfo]]):
+    def update_(self, new_block_infos: List[RemoteModuleInfo]):
         assert len(new_block_infos) == len(self.block_uids)
         for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
-            if info is None:
-                logger.debug(f"Found no block info for block {uid}")
-                continue
-            if not isinstance(info, RemoteModuleInfo):
-                logger.warning(f"Unexpected dht entry type for {uid}: {info}")
-                continue
-            if not info.servers:
-                logger.debug(f"Found no active peers for block {uid}")
-                continue
-            if info.uid != uid:
-                logger.warning(f"The DHT entry for {uid} actually points to {info.uid}")
-                continue
+            assert uid == info.uid, f"The DHT entry for {uid} actually points to {info.uid}"
             self.block_infos[block_index].servers = info.servers
 
-        self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos)
+        self.spans_by_priority, self.spans_containing_block = self._sort_spans(self.block_infos)
         self.last_updated_time = time.perf_counter()
 
     @staticmethod
-    def compute_spans(block_infos: Sequence[RemoteModuleInfo]):
-        closed_spans = []
-        active_spans = {}
-        for block_index, info in enumerate(block_infos):
-            if info is not None:
-                for peer_id, server_info in info.servers.items():
-                    if server_info.state != ServerState.ONLINE:
-                        continue
-                    if peer_id not in active_spans:
-                        active_spans[peer_id] = RemoteSpanInfo(
-                            peer_id=peer_id,
-                            start=block_index,
-                            end=block_index + 1,
-                            server_info=server_info,
-                        )
-                    else:  # peer_id in active_spans
-                        active_spans[peer_id].end = block_index + 1
-
-            for peer_id in list(active_spans.keys()):
-                if (
-                    info is None
-                    or peer_id not in info.servers
-                    or info.servers[peer_id].state != ServerState.ONLINE
-                    or block_index == len(block_infos) - 1
-                ):
-                    closed_spans.append(active_spans.pop(peer_id))
-        assert not active_spans, f"spans: {active_spans}"
-
-        closed_spans.sort(key=lambda span: span.length, reverse=True)
+    def _sort_spans(block_infos: List[RemoteModuleInfo]):
+        spans_by_priority = list(compute_spans(block_infos, min_state=ServerState.ONLINE).values())
+        spans_by_priority.sort(key=lambda span: span.length, reverse=True)
 
-        spans_containing_block = tuple(list() for _ in range(len(block_infos)))
-        for span in closed_spans:
+        spans_containing_block = tuple([] for _ in range(len(block_infos)))
+        for span in spans_by_priority:
             for block_index in range(span.start, span.end):
                 spans_containing_block[block_index].append(span)
 
-        return closed_spans, spans_containing_block
+        return spans_by_priority, spans_containing_block

+ 0 - 4
src/petals/client/routing/sequence_manager.py

@@ -117,7 +117,6 @@ class RemoteSequenceManager:
         if state.sequence_info.last_updated_time is not None:
             assert block_uids == state.sequence_info.block_uids
             self._thread.ready.set()  # no need to await the first dht fetch
-            self._need_latest_infos = True
 
     @staticmethod
     def _peer_ids_to_set(peer_ids: Optional[Sequence[Union[PeerID, str]]]) -> Optional[Set[PeerID]]:
@@ -346,9 +345,6 @@ class RemoteSequenceManager:
         )
 
         for block_info in new_block_infos:
-            if not block_info:
-                continue
-
             # Apply allow and block lists
             block_info.servers = {
                 peer_id: server_info

+ 26 - 9
src/petals/data_structures.py

@@ -11,18 +11,15 @@ UID_DELIMITER = "."  # delimits parts of one module uid, e.g. "bloom.transformer
 CHAIN_DELIMITER = " "  # delimits multiple uids in a sequence, e.g. "bloom.layer3 bloom.layer4"
 
 
-class ServerState(Enum):
-    OFFLINE = 0
-    JOINING = 1
-    ONLINE = 2
-
-
-RPS = pydantic.confloat(ge=0, allow_inf_nan=False, strict=True)
+def parse_uid(uid: ModuleUID) -> Tuple[str, int]:
+    assert CHAIN_DELIMITER not in uid, "parse_uid() does not support chained UIDs"
+    dht_prefix, index = uid.split(UID_DELIMITER)
+    return dht_prefix, int(index)
 
 
 @pydantic.dataclasses.dataclass
 class ModelInfo:
-    num_blocks: int
+    num_blocks: pydantic.conint(ge=1, strict=True)
     repository: Optional[str] = None
 
     def to_dict(self) -> dict:
@@ -33,11 +30,23 @@ class ModelInfo:
         return cls(**source)
 
 
+class ServerState(Enum):
+    OFFLINE = 0
+    JOINING = 1
+    ONLINE = 2
+
+
+RPS = pydantic.confloat(ge=0, allow_inf_nan=False, strict=True)
+
+
 @pydantic.dataclasses.dataclass
 class ServerInfo:
     state: ServerState
     throughput: RPS
 
+    start_block: Optional[pydantic.conint(ge=0, strict=True)] = None
+    end_block: Optional[pydantic.conint(ge=0, strict=True)] = None
+
     public_name: Optional[str] = None
     version: Optional[str] = None
 
@@ -83,9 +92,17 @@ class RemoteSpanInfo:
     server_info: ServerInfo
 
     @property
-    def length(self):
+    def length(self) -> int:
         return self.end - self.start
 
+    @property
+    def state(self) -> ServerState:
+        return self.server_info.state
+
+    @property
+    def throughput(self) -> float:
+        return self.server_info.throughput
+
 
 RPCInfo = Dict[str, Any]
 

+ 24 - 48
src/petals/server/block_selection.py

@@ -1,54 +1,23 @@
-from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List
 
 import numpy as np
 from hivemind import PeerID, get_logger
 
-from petals.data_structures import RemoteModuleInfo, ServerState
-
-__all__ = ["choose_best_blocks", "should_choose_other_blocks"]
+from petals.data_structures import RemoteModuleInfo, RemoteSpanInfo, ServerState
+from petals.utils.dht import compute_spans
 
 logger = get_logger(__name__)
 
 
-@dataclass
-class Span:
-    start: int
-    end: int
-    throughput: float
-    state: ServerState
-
-    @property
-    def length(self):
-        return self.end - self.start
-
-    def move_to(self, new_start: int) -> None:
-        self.start, self.end = new_start, new_start + self.length
-
-
-def compute_spans(module_infos: List[Optional[RemoteModuleInfo]]) -> Tuple[Dict[PeerID, Span], np.ndarray]:
-    spans = {}
-    throughputs = np.zeros(len(module_infos))
-    for block, module in enumerate(module_infos):
-        if module is None:
-            continue
-
-        # We sort servers here to ensure that we get exactly the same throughputs for a given set of servers.
-        # If the order were not defined, we would get slightly different values due to floating point errors,
-        # which may cause excess block replacements.
-        for peer_id, server in sorted(module.servers.items()):
-            if server.state == ServerState.OFFLINE:
-                continue
+def compute_throughputs(spans: Dict[PeerID, RemoteSpanInfo], *, total_blocks: int) -> np.ndarray:
+    # We sort servers here to ensure that we get exactly the same throughputs for a given set of servers.
+    # If the order were not defined, we would get slightly different values due to floating point errors,
+    # which may cause excess block replacements.
 
-            if peer_id in spans:
-                spans[peer_id].start = min(spans[peer_id].start, block)
-                spans[peer_id].end = max(spans[peer_id].start, block + 1)
-            else:
-                spans[peer_id] = Span(start=block, end=block + 1, throughput=server.throughput, state=server.state)
-
-            throughputs[block] += server.throughput
-
-    return spans, throughputs
+    throughputs = np.zeros(total_blocks)
+    for span in sorted(spans.values(), key=lambda span: span.peer_id):
+        throughputs[span.start : span.end] += span.throughput
+    return throughputs
 
 
 def _choose_best_start(throughputs: np.ndarray, num_blocks: int) -> int:
@@ -56,19 +25,26 @@ def _choose_best_start(throughputs: np.ndarray, num_blocks: int) -> int:
     return min(options)[-1]
 
 
-def choose_best_blocks(num_blocks: int, module_infos: List[Optional[RemoteModuleInfo]]) -> List[int]:
-    _, throughputs = compute_spans(module_infos)
+def choose_best_blocks(num_blocks: int, module_infos: List[RemoteModuleInfo]) -> List[int]:
+    spans = compute_spans(module_infos, min_state=ServerState.JOINING)
+    throughputs = compute_throughputs(spans, total_blocks=len(module_infos))
+
     start = _choose_best_start(throughputs, num_blocks)
     return list(range(start, start + num_blocks))
 
 
+def _move_span(span: RemoteSpanInfo, new_start: int):
+    span.start, span.end = new_start, new_start + span.length
+
+
 def should_choose_other_blocks(
-    local_peer_id: PeerID, module_infos: List[Optional[RemoteModuleInfo]], balance_quality: float
+    local_peer_id: PeerID, module_infos: List[RemoteModuleInfo], balance_quality: float
 ) -> bool:
     if balance_quality > 1.0:
         return True  # Forces rebalancing on each check (may be used for debugging purposes)
 
-    spans, throughputs = compute_spans(module_infos)
+    spans = compute_spans(module_infos, min_state=ServerState.JOINING)
+    throughputs = compute_throughputs(spans, total_blocks=len(module_infos))
     initial_throughput = throughputs.min()
     eps = 1e-3
 
@@ -88,7 +64,7 @@ def should_choose_other_blocks(
         return False  # This server is on its best place already
 
     throughputs[local_span.start : local_span.end] += local_span.throughput * eps
-    local_span.move_to(new_start)
+    _move_span(local_span, new_start)
     throughputs[local_span.start : local_span.end] += local_span.throughput
 
     moved = True
@@ -105,7 +81,7 @@ def should_choose_other_blocks(
 
             throughputs[span.start : span.end] += span.throughput * eps
             if span.start != new_start:
-                span.move_to(new_start)
+                _move_span(span, new_start)
                 moved = True
             throughputs[span.start : span.end] += span.throughput
 

+ 15 - 12
src/petals/server/server.py

@@ -23,7 +23,7 @@ from transformers import PretrainedConfig
 
 import petals
 from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
-from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModelInfo, ServerInfo, ServerState
+from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModelInfo, ServerInfo, ServerState, parse_uid
 from petals.server import block_selection
 from petals.server.backend import TransformerBackend, merge_inference_pools_inplace
 from petals.server.block_utils import get_block_size, resolve_block_dtype
@@ -220,11 +220,10 @@ class Server:
             num_blocks = min(num_blocks, self.block_config.num_hidden_layers)
         if block_indices is not None:
             try:
-                first_block_index, last_block_index = block_indices.split(":")
-                first_block_index, last_block_index = map(int, map(str.strip, (first_block_index, last_block_index)))
+                start_block, end_block = [int(index.strip()) for index in block_indices.split(":")]
             except Exception as e:
                 raise ValueError(f"Failed to parse `--block_indices {block_indices}`, must be start:end (e.g. 0:18)")
-            block_indices = range(first_block_index, last_block_index)
+            block_indices = range(start_block, end_block)
             num_blocks = len(block_indices)
         self.strict_block_indices, self.num_blocks = block_indices, num_blocks
 
@@ -703,11 +702,16 @@ class ModuleAnnouncerThread(threading.Thread):
         self.expiration = expiration
         self.trigger = threading.Event()
 
+        self.dht_prefix = parse_uid(module_uids[0])[0]
+        block_indices = [parse_uid(uid)[1] for uid in module_uids]
+        self.server_info.start_block = min(block_indices)
+        self.server_info.end_block = max(block_indices) + 1
+
         self.max_pinged = max_pinged
-        self.dht_prefix = module_uids[0].split(UID_DELIMITER)[0]
-        block_indices = [int(uid.split(UID_DELIMITER)[-1]) for uid in module_uids]
-        start_block, end_block = min(block_indices), max(block_indices) + 1
-        self.next_uids = [f"{self.dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block + 1, end_block + 1)]
+        self.next_uids = [
+            f"{self.dht_prefix}{UID_DELIMITER}{i}"
+            for i in range(self.server_info.start_block + 1, self.server_info.end_block + 1)
+        ]
         self.ping_aggregator = PingAggregator(self.dht)
 
     def run(self) -> None:
@@ -755,12 +759,11 @@ class ModuleAnnouncerThread(threading.Thread):
 
     def _ping_next_servers(self) -> Dict[hivemind.PeerID, float]:
         module_infos = get_remote_module_infos(self.dht, self.next_uids, latest=True)
-        middle_servers = {peer_id for info in module_infos[:-1] if info is not None for peer_id in info.servers}
+        middle_servers = {peer_id for info in module_infos[:-1] for peer_id in info.servers}
         pinged_servers = set(sample_up_to(middle_servers, self.max_pinged))
         pinged_servers.discard(self.dht.peer_id)
-        if module_infos[-1] is not None:
-            # Sample servers hosting the block after the last one (most likely continuations) separately
-            pinged_servers |= set(sample_up_to(module_infos[-1].servers, self.max_pinged))
+        # Sample servers hosting the block after the last one (most likely continuations) separately
+        pinged_servers |= set(sample_up_to(module_infos[-1].servers, self.max_pinged))
         self.ping_aggregator.ping(list(pinged_servers))
 
 

+ 41 - 12
src/petals/utils/dht.py

@@ -11,7 +11,16 @@ from hivemind.dht import DHT, DHTNode, DHTValue
 from hivemind.p2p import PeerID
 from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger
 
-from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo
+from petals.data_structures import (
+    CHAIN_DELIMITER,
+    UID_DELIMITER,
+    ModuleUID,
+    RemoteModuleInfo,
+    RemoteSpanInfo,
+    ServerInfo,
+    ServerState,
+    parse_uid,
+)
 
 logger = get_logger(__name__)
 
@@ -70,7 +79,7 @@ def get_remote_module_infos(
     *,
     latest: bool = False,
     return_future: bool = False,
-) -> Union[List[Optional[RemoteModuleInfo]], MPFuture]:
+) -> Union[List[RemoteModuleInfo], MPFuture]:
     return dht.run_coroutine(
         partial(
             _get_remote_module_infos,
@@ -90,7 +99,7 @@ async def _get_remote_module_infos(
     active_adapter: Optional[str],
     expiration_time: Optional[DHTExpiration],
     latest: bool,
-) -> List[Optional[RemoteModuleInfo]]:
+) -> List[RemoteModuleInfo]:
     if latest:
         assert expiration_time is None, "You should define either `expiration_time` or `latest`, not both"
         expiration_time = math.inf
@@ -99,14 +108,14 @@ async def _get_remote_module_infos(
     num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
     found: Dict[ModuleUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
 
-    modules: List[Optional[RemoteModuleInfo]] = [None] * len(uids)
-    for i, uid in enumerate(uids):
-        metadata = found[uid]
+    modules = [RemoteModuleInfo(uid=uid, servers={}) for uid in uids]
+    for module_info in modules:
+        metadata = found[module_info.uid]
         if metadata is None or not isinstance(metadata.value, dict):
             if metadata is not None:
-                logger.warning(f"Incorrect metadata for {uid}: {metadata}")
+                logger.warning(f"Incorrect metadata for {module_info.uid}: {metadata}")
             continue
-        servers = {}
+
         for peer_id, server_info in metadata.value.items():
             try:
                 peer_id = PeerID.from_base58(peer_id)
@@ -116,9 +125,29 @@ async def _get_remote_module_infos(
                     logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}")
                     continue
 
-                servers[peer_id] = server_info
+                module_info.servers[peer_id] = server_info
             except (TypeError, ValueError) as e:
-                logger.warning(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")
-        if servers:
-            modules[i] = RemoteModuleInfo(uid, servers)
+                logger.warning(f"Incorrect peer entry for uid={module_info.uid}, peer_id={peer_id}: {e}")
     return modules
+
+
+def compute_spans(module_infos: List[RemoteModuleInfo], *, min_state: ServerState) -> Dict[PeerID, RemoteSpanInfo]:
+    block_offset = parse_uid(module_infos[0].uid)[1] if module_infos else 0
+    num_blocks = len(module_infos)
+
+    spans = {}
+    for block_idx, module_info in enumerate(module_infos):
+        for peer_id, server_info in sorted(module_info.servers.items()):
+            if server_info.state.value < min_state.value:
+                continue
+
+            if peer_id not in spans or spans[peer_id].state.value < server_info.state.value:
+                spans[peer_id] = RemoteSpanInfo(
+                    peer_id=peer_id, start=block_idx, end=block_idx + 1, server_info=server_info
+                )
+                if server_info.start_block is not None and server_info.end_block is not None:
+                    spans[peer_id].start = max(server_info.start_block - block_offset, 0)
+                    spans[peer_id].end = min(server_info.end_block - block_offset, num_blocks)
+            elif spans[peer_id].state == server_info.state:
+                spans[peer_id].end = max(spans[peer_id].end, block_idx + 1)
+    return spans