Implement shortest-path routing for inference (#362)

This PR:

1. **Adds shortest-path routing for inference.** We build a graph whose edge costs combine client-server and server-server latencies, per-block compute costs, and empirically measured overheads (a minimal sketch of the idea is shown after this list). For client-server latencies, we ping the possible first and last servers of a sequence in `SequenceManager.update()`. We penalize servers that may not have enough cache for our request. This uses the info added to the DHT in #355, #356, #358.

2. **Makes a server ping neighboring servers in addition to the next ones.** This gives the client an opportunity to switch servers even before it has used all of a server's blocks (e.g., because a neighboring server is faster). This feature is not enabled yet, since it grows the graph for N servers to O(N^2) edges, but we may enable it if needed.

3. **Fixes a `SequenceManager` bug with the first `update()`.** Previously, this update was likely to produce incorrect information and cause `MissingBlocksError`s until the next update happened.
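
For intuition, here is a minimal, self-contained sketch of how the routing problem reduces to a shortest-path search with the newly added Dijkstar dependency. This is not the code from this PR: the server names, hosted block ranges, and latencies below are made up, and the real `RemoteSequenceManager._build_inference_graph()` derives edge costs from measured RTTs, reported inference throughput, and cache availability instead.

```python
# Toy sketch (not this PR's code): min-latency routing as a shortest-path problem
# using the Dijkstar package added in setup.cfg. All values are illustrative.
from dijkstar import Graph, find_path

n_blocks = 4
servers = {
    "serverA": dict(blocks=range(0, 4), step_delay=0.010),  # hosts blocks 0..3
    "serverB": dict(blocks=range(2, 4), step_delay=0.004),  # hosts blocks 2..3, faster
}
client_rtt = {"serverA": 0.05, "serverB": 0.05}   # client <-> server RTT, seconds
server_rtt = {("serverA", "serverB"): 0.002}      # server <-> server RTT, seconds

graph = Graph()
for name, info in servers.items():
    blocks = info["blocks"]
    if blocks.start == 0:  # this server can be the first hop
        graph.add_edge("client_start", (name, 0), client_rtt[name] / 2)
    if blocks.stop == n_blocks:  # this server can be the last hop
        graph.add_edge((name, n_blocks), "client_end", client_rtt[name] / 2)
    for i in blocks:  # compute cost of running block i on this server
        graph.add_edge((name, i), (name, i + 1), info["step_delay"])
# Switching from serverA to serverB at block 2 costs one extra network hop
graph.add_edge(("serverA", 2), ("serverB", 2), server_rtt[("serverA", "serverB")] / 2)

path = find_path(graph, "client_start", "client_end")
print(path.nodes)       # e.g. ['client_start', ('serverA', 0), ..., ('serverB', 4), 'client_end']
print(path.total_cost)  # estimated latency of one inference step, in seconds
```

The shortest path automatically decides whether switching to a faster server in the middle of the sequence is worth the extra network hop.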
Alexander Borzunov · 2 years ago · commit 62d9ed5ce7

+ 1 - 0
setup.cfg

@@ -48,6 +48,7 @@ install_requires =
     sentencepiece>=0.1.99
     peft@git+https://github.com/huggingface/peft@5884bdbea49e5e71e2cd06ecfa484bb635063735
     safetensors>=0.3.1
+    Dijkstar>=2.6.0
 
 [options.extras_require]
 dev =

+ 1 - 1
src/petals/__init__.py

@@ -11,7 +11,7 @@ from petals.models import *
 from petals.utils import *
 from petals.utils.logging import initialize_logs as _initialize_logs
 
-__version__ = "1.2.0.dev2"
+__version__ = "1.2.0.dev3"
 
 
 if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):

+ 1 - 1
src/petals/cli/run_server.py

@@ -84,7 +84,7 @@ def main():
     parser.add_argument('--attn_cache_tokens', type=int, default=8192,
                         help='The number of past attention key/value pairs that will be stored between inference steps. '
                              'Default: 8192 (4 simultaneous sessions of up to 2048 tokens).')
-    parser.add_argument('--alloc_timeout', type=float, default=60,
+    parser.add_argument('--alloc_timeout', type=float, default=5,
                         help='If the cache is full, the server will wait for this number of seconds hoping that some memory will be freed '
                              'before rejecting the request')
     parser.add_argument('--revision', type=str, default=None,

+ 3 - 1
src/petals/client/inference_session.py

@@ -340,7 +340,9 @@ class InferenceSession:
                 f"from block {block_idx} to {update_end} will be regenerated"
                 f"from block {block_idx} to {update_end} will be regenerated"
             )
             )
 
 
-        updated_spans = self._sequence_manager.make_sequence(block_idx, update_end, mode="min_latency")
+        updated_spans = self._sequence_manager.make_sequence(
+            block_idx, update_end, mode="min_latency", cache_tokens_needed=self._max_length
+        )
         # make_sequence() could return a longer sequence
         # make_sequence() could return a longer sequence
         updated_spans[-1].end = min(updated_spans[-1].end, update_end)
         updated_spans[-1].end = min(updated_spans[-1].end, update_end)
         updated_sessions = self._enter_server_sessions(updated_spans)
         updated_sessions = self._enter_server_sessions(updated_spans)

+ 173 - 23
src/petals/client/routing/sequence_manager.py

@@ -10,6 +10,7 @@ import time
 from typing import Any, Collection, Dict, List, Optional, Sequence, Union
 from weakref import WeakMethod
 
+import dijkstar
 import numpy as np
 from hivemind import DHT, P2P, MSGPackSerializer, PeerID
 from hivemind.dht.node import Blacklist
@@ -23,6 +24,8 @@ from petals.client.routing.spending_policy import NoSpendingPolicy
 from petals.constants import PUBLIC_INITIAL_PEERS
 from petals.data_structures import ModuleUID, RemoteSpanInfo, ServerState
 from petals.server.handler import TransformerConnectionHandler
+from petals.utils.ping import PingAggregator
+from petals.utils.random import sample_up_to
 
 logger = get_logger(__name__)
 
@@ -33,6 +36,7 @@ class SequenceManagerConfig:
     dht_prefix: Optional[str] = None  # a prefix for all dht keys that correspond to this model (default: model name)
     daemon_startup_timeout: int = 60  # timeout for the libp2p daemon connecting to initial peers
 
+    show_route: Union[str, bool] = "inference"  # show chosen route through servers. one of [False, "inference", True]
     allowed_servers: Optional[Collection[Union[PeerID, str]]] = None  # if defined, send requests only to these servers
     use_server_to_server: bool = True  # Use direct server-to-server communication
 
@@ -43,7 +47,10 @@ class SequenceManagerConfig:
     min_backoff: float = 1  # after a repeated failure, sleep for this many seconds times 2 ** (num_failures - 1)
     max_backoff: float = 60  # limit maximal sleep time between retries to this value
     ban_timeout: float = 15  # when a remote peer fails to respond, prevent routing to that peer for this many seconds
-    active_adapter: Optional[str] = None
+    active_adapter: Optional[str] = None  # name of active LoRA adapter (usually, Hugging Face repo)
+
+    max_pinged: int = 5  # max servers to ping from each sequence side, per update
+    ping_timeout: float = 2  # max time to wait for pings, per update
 
 
 @dataclasses.dataclass
@@ -79,7 +86,6 @@ class RemoteSequenceManager:
         *,
         dht: Optional[DHT] = None,
         state: Optional[SequenceManagerState] = None,
-        active_adapter: Optional[str] = None,
     ):
         assert config.initial_peers or dht is not None, "Please specify `config.initial_peers` or `dht`"
         assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
@@ -94,7 +100,7 @@ class RemoteSequenceManager:
             dht = DHT(
                 initial_peers=config.initial_peers,
                 client_mode=True,
-                num_workers=config.num_hidden_layers,
+                num_workers=32,
                 startup_timeout=config.daemon_startup_timeout,
                 start=True,
             )
@@ -109,25 +115,25 @@ class RemoteSequenceManager:
         self._thread_start_lock = threading.Lock()
         self.policy = NoSpendingPolicy()
 
+        self.ping_aggregator = PingAggregator(dht)
+
         if state.banned_peers is None:
             state.banned_peers = Blacklist(base_time=config.ban_timeout, backoff_rate=2.0)
         if state.sequence_info is None:
             state.sequence_info = RemoteSequenceInfo.make_empty(block_uids)
 
-        if state.sequence_info.last_updated_time is None:
-            # Pre-fetch module infos in DHT in parallel with .from_pretrained(), then use cached records
-            # in the first _update() instead of the latest ones. This makes the first .update() faster.
-            petals.dht_utils.get_remote_module_infos(
-                self.dht, self.block_uids, active_adapter=active_adapter, latest=True, return_future=True
-            )
-            self._need_latest_infos = False
-        else:
+        if state.sequence_info.last_updated_time is not None:
             assert block_uids == state.sequence_info.block_uids
             self._thread.ready.set()  # no need to await the first dht fetch
             self._need_latest_infos = True
 
     def make_sequence(
-        self, start_index: int = 0, end_index: Optional[int] = None, *, mode: str
+        self,
+        start_index: int = 0,
+        end_index: Optional[int] = None,
+        *,
+        mode: str,
+        cache_tokens_needed: Optional[int] = None,
     ) -> List[RemoteSpanInfo]:
         """
         Form a sequence of remote servers that collectively serve all consecutive layers
@@ -143,6 +149,150 @@ class RemoteSequenceManager:
             self.update(wait=True)  # this will await an existing update or trigger a new one (if not updating)
 
         end_index = end_index if end_index is not None else len(self)
+
+        if mode == "min_latency":
+            span_sequence = self._make_sequence_with_min_latency(
+                start_index, end_index, cache_tokens_needed=cache_tokens_needed
+            )
+        elif mode == "max_throughput":
+            span_sequence = self._make_sequence_with_max_throughput(start_index, end_index)
+        else:
+            raise RuntimeError(f"Unexpected mode {mode}")
+
+        if self.config.show_route is True or (mode == "min_latency" and self.config.show_route == "inference"):
+            route_repr = " => ".join(
+                [f"{span.start}:{span.end} via …{str(span.peer_id)[-6:]}" for span in span_sequence]
+            )
+            logger.info(f"Route found: {route_repr}")
+        return span_sequence
+
+    def _make_sequence_with_min_latency(
+        self, start_index: int, end_index: int, *, cache_tokens_needed: Optional[int]
+    ) -> List[RemoteSpanInfo]:
+        if start_index == end_index:
+            return []
+
+        with self.lock_changes:
+            missing_blocks = [
+                block_idx
+                for block_idx in range(start_index, end_index)
+                if not self.state.sequence_info.spans_containing_block[block_idx]
+            ]
+            if missing_blocks:
+                raise MissingBlocksError(missing_blocks)
+            server_infos = {
+                span.peer_id: span.server_info
+                for block_idx in range(start_index, end_index)
+                for span in self.state.sequence_info.spans_containing_block[block_idx]
+            }
+
+            graph = self._build_inference_graph(start_index, end_index, cache_tokens_needed=cache_tokens_needed)
+
+        path = dijkstar.find_path(graph, "start", "end")
+        logger.debug(f"Path info: {path}")
+        if start_index == 0 and end_index == len(self):
+            logger.debug(f"Expected speed: {1 / path.total_cost:.1f} steps/sec")
+
+        span_sequence = []
+        for peer_id, block_idx in path.nodes[1:-1]:
+            if not span_sequence or span_sequence[-1].peer_id != peer_id:
+                span_sequence.append(RemoteSpanInfo(peer_id, block_idx, block_idx, server_infos[peer_id]))
+            else:
+                span_sequence[-1].end = block_idx
+
+        # Remove empty spans that can appear if we don't force the path to go to the end of each server and network
+        # delays violate the triangle inequality (delay(A, B) + delay(B, C) < delay(A, C)) due to measurement errors
+        span_sequence = [span for span in span_sequence if span.length > 0]
+
+        return span_sequence
+
+    def _build_inference_graph(
+        self,
+        start_index: int,
+        end_index: int,
+        *,
+        cache_tokens_needed: Optional[int],
+        overhead_coeff: float = 1.82,  # Backend overhead (empirically measured)
+        overhead_delay: float = 0.018,  # Serialization overhead (empirically measured)
+        default_inference_rps: float = 300,  # If inference RPS unknown
+        alloc_delay: float = 10,  # If not enough cache left, we penalize the edge
+    ) -> dijkstar.Graph:
+        missing_blocks = [
+            block_idx
+            for block_idx in range(start_index, end_index)
+            if not self.state.sequence_info.spans_containing_block[block_idx]
+        ]
+        if missing_blocks:
+            raise MissingBlocksError(missing_blocks)
+
+        client_server_rtts = self.ping_aggregator.to_dict()
+
+        graph = dijkstar.Graph()
+
+        # Client -> server network delays
+        for span in self.state.sequence_info.spans_containing_block[start_index]:
+            delay = self._rtt_to_delay(client_server_rtts.get(span.peer_id))
+            delay += overhead_delay
+            if not self._has_cache_for(span, cache_tokens_needed):
+                delay += alloc_delay
+            graph.add_edge("start", (span.peer_id, start_index), delay)
+
+        # Server -> client network delays
+        for span in self.state.sequence_info.spans_containing_block[end_index - 1]:
+            delay = self._rtt_to_delay(client_server_rtts.get(span.peer_id))
+            graph.add_edge((span.peer_id, end_index), "end", delay)
+
+        # Server -> server network delays
+        for block_idx in range(start_index + 1, end_index):
+            for cur_span in self.state.sequence_info.spans_containing_block[block_idx - 1]:
+                if cur_span.end != block_idx:
+                    # If we choose a server, we force the path to go to its end before switching to a new one,
+                    # to avoid O(N^2) graphs for N servers
+                    continue
+
+                for next_span in self.state.sequence_info.spans_containing_block[block_idx]:
+                    rtt = None
+                    if cur_span.server_info.next_pings is not None:
+                        rtt = cur_span.server_info.next_pings.get(next_span.peer_id.to_base58())
+                    delay = self._rtt_to_delay(rtt)
+                    delay += overhead_delay
+                    if not self._has_cache_for(next_span, cache_tokens_needed):
+                        delay += alloc_delay
+                    graph.add_edge((cur_span.peer_id, block_idx), (next_span.peer_id, block_idx), delay)
+
+        # Compute delays
+        for span in self.state.sequence_info.spans_by_priority:
+            for block_idx in range(max(span.start, start_index), min(span.end, end_index)):
+                inference_rps = span.server_info.inference_rps
+                if inference_rps is None:
+                    inference_rps = default_inference_rps
+                graph.add_edge((span.peer_id, block_idx), (span.peer_id, block_idx + 1), overhead_coeff / inference_rps)
+
+        return graph
+
+    @staticmethod
+    def _rtt_to_delay(
+        rtt: Optional[float],
+        *,
+        default_delay: float = 0.15,  # If network delay unknown
+        max_delay: float = 5,  # If unreachable, we don't want to discard the edge completely
+    ) -> float:
+        if rtt is None:
+            return default_delay
+        return min(rtt / 2, max_delay)
+
+    @staticmethod
+    def _has_cache_for(span: RemoteSpanInfo, cache_tokens_needed: Optional[int] = None) -> bool:
+        if cache_tokens_needed is None or span.server_info.cache_tokens_left is None:
+            return True
+
+        # Here, `span` contains all blocks hosted by a server - but we won't necessarily run all of them through
+        # this particular server in our path. It is difficult to estimate how many blocks we'll use at this stage,
+        # so we assume that we'll use all of them (the worst case for the cache size) and get a pessimistic estimate.
+        # This is okay since false positives are more costly than false negatives here.
+        return cache_tokens_needed * 2 * span.length <= span.server_info.cache_tokens_left
+
+    def _make_sequence_with_max_throughput(self, start_index: int, end_index: int) -> List[RemoteSpanInfo]:
         span_sequence = []
         current_index = start_index
         while current_index < end_index:
@@ -150,20 +300,12 @@ class RemoteSequenceManager:
             if not candidate_spans:
                 raise MissingBlocksError(current_index)
 
-            if mode == "max_throughput":
-                span_weights = np.array([span.server_info.throughput for span in candidate_spans], dtype=np.float64)
-            elif mode == "min_latency":
-                span_weights = np.array([span.end - current_index for span in candidate_spans], dtype=np.float64)
-            else:
-                raise RuntimeError(f"Unexpected mode {mode}")
+            span_weights = np.array([span.server_info.throughput for span in candidate_spans], dtype=np.float64)
             chosen_span = np.random.choice(candidate_spans, p=span_weights / span_weights.sum())
 
             assert chosen_span.start <= current_index < chosen_span.end
             span_sequence.append(dataclasses.replace(chosen_span, start=current_index))
             current_index = chosen_span.end
-
-        route_repr = " => ".join([f"{span.start}:{span.end} via …{str(span.peer_id)[-6:]}" for span in span_sequence])
-        logger.debug(f"Route found: {route_repr}")
         return span_sequence
 
     def __getitem__(self, ix: Union[int, slice]) -> RemoteSequenceManager:
@@ -182,10 +324,10 @@ class RemoteSequenceManager:
 
 
     def _update(self):
         """Perform an immediate and synchronous refresh, may take time"""
+
         new_block_infos = petals.dht_utils.get_remote_module_infos(
-            self.dht, self.block_uids, active_adapter=self.config.active_adapter, latest=self._need_latest_infos
+            self.dht, self.block_uids, active_adapter=self.config.active_adapter, latest=True
         )
-        self._need_latest_infos = True  # All future _update() should use latest infos
 
         for block_info in new_block_infos:
             if not block_info:
@@ -217,6 +359,14 @@ class RemoteSequenceManager:
 
 
         with self.lock_changes:
             self.state.sequence_info.update_(new_block_infos)
+
+            first_servers = [span.peer_id for span in self.state.sequence_info.spans_containing_block[0]]
+            last_servers = [span.peer_id for span in self.state.sequence_info.spans_containing_block[-1]]
+
+        pinged_servers = set(sample_up_to(first_servers, self.config.max_pinged))
+        pinged_servers |= set(sample_up_to(last_servers, self.config.max_pinged))
+        self.ping_aggregator.ping(list(pinged_servers), wait_timeout=self.config.ping_timeout)
+
         self.ready.set()
 
     def on_request_failure(self, peer_id: Optional[PeerID]):

+ 16 - 15
src/petals/server/server.py

@@ -32,6 +32,7 @@ from petals.server.throughput import get_dtype_name, get_server_throughput
 from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.convert_block import QuantType, check_device_balance, convert_block
 from petals.utils.ping import PingAggregator
+from petals.utils.random import sample_up_to
 from petals.utils.version import get_compatible_model_repo
 
 logger = get_logger(__name__)
@@ -61,7 +62,7 @@ class Server:
         cache_dir: Optional[str] = None,
         max_disk_space: Optional[int] = None,
         attn_cache_tokens: int = 8192,
-        alloc_timeout: float = 60,
+        alloc_timeout: float = 5,
         device: Optional[Union[str, torch.device]] = None,
         compression=CompressionType.NONE,
         stats_report_interval: Optional[int] = None,
@@ -637,7 +638,6 @@ class ModuleAnnouncerThread(threading.Thread):
         update_period: float,
         expiration: float,
         max_pinged: int = 5,
-        max_reported: int = 10,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -650,10 +650,11 @@ class ModuleAnnouncerThread(threading.Thread):
         self.expiration = expiration
         self.trigger = threading.Event()
 
-        self.max_pinged, self.max_reported = max_pinged, max_reported
-        last_uid = max(module_uids, key=lambda uid: int(uid.split(UID_DELIMITER)[-1]))
-        dht_prefix, block_index = last_uid.split(UID_DELIMITER)
-        self.next_uid = f"{dht_prefix}{UID_DELIMITER}{int(block_index) + 1}"
+        self.max_pinged = max_pinged
+        dht_prefix = module_uids[0].split(UID_DELIMITER)[0]
+        block_indices = [int(uid.split(UID_DELIMITER)[-1]) for uid in module_uids]
+        start_block, end_block = min(block_indices), max(block_indices) + 1
+        self.next_uids = [f"{dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block + 1, end_block + 1)]
         self.ping_aggregator = PingAggregator(self.dht)
 
     def run(self) -> None:
@@ -664,7 +665,7 @@ class ModuleAnnouncerThread(threading.Thread):
             if self.server_info.state != ServerState.OFFLINE:
                 self._ping_next_servers()
                 self.server_info.next_pings = {
-                    peer_id.to_base58(): rtt for peer_id, rtt in self.ping_aggregator.fastest(self.max_reported).items()
+                    peer_id.to_base58(): rtt for peer_id, rtt in self.ping_aggregator.to_dict().items()
                 }
             else:
                 self.server_info.next_pings = None  # No need to ping if we're disconnecting
@@ -691,14 +692,14 @@ class ModuleAnnouncerThread(threading.Thread):
             self.join()
 
     def _ping_next_servers(self) -> Dict[hivemind.PeerID, float]:
-        [module_info] = get_remote_module_infos(self.dht, [self.next_uid], latest=True)
-        if module_info is None:
-            return
-
-        next_servers = list(module_info.servers)
-        if len(next_servers) > self.max_pinged:
-            next_servers = random.sample(next_servers, self.max_pinged)
-        self.ping_aggregator.ping(next_servers)
+        module_infos = get_remote_module_infos(self.dht, self.next_uids, latest=True)
+        middle_servers = {peer_id for info in module_infos[:-1] if info is not None for peer_id in info.servers}
+        pinged_servers = set(sample_up_to(middle_servers, self.max_pinged))
+        pinged_servers.discard(self.dht.peer_id)
+        if module_infos[-1] is not None:
+            # Sample servers hosting the block after the last one (most likely continuations) separately
+            pinged_servers |= set(sample_up_to(module_infos[-1].servers, self.max_pinged))
+        self.ping_aggregator.ping(list(pinged_servers))
 
 
 class RuntimeWithDeduplicatedPools(Runtime):

+ 15 - 14
src/petals/utils/ping.py

@@ -1,5 +1,6 @@
 import asyncio
 import math
+import threading
 import time
 from functools import partial
 from typing import Dict, Sequence
@@ -34,27 +35,27 @@ async def ping_parallel(peer_ids: Sequence[hivemind.PeerID], *args, **kwargs) ->
 
 
 
 
 class PingAggregator:
-    def __init__(self, dht: hivemind.DHT, *, ema_alpha: float = 0.2, expiration: float = 3600):
+    def __init__(self, dht: hivemind.DHT, *, ema_alpha: float = 0.2, expiration: float = 300):
         self.dht = dht
         self.ema_alpha = ema_alpha
         self.expiration = expiration
         self.ping_emas = hivemind.TimedStorage()
+        self.lock = threading.Lock()
 
-    def ping(self, peer_ids: Sequence[hivemind.PeerID], **kwargs):
+    def ping(self, peer_ids: Sequence[hivemind.PeerID], **kwargs) -> None:
         current_rtts = self.dht.run_coroutine(partial(ping_parallel, peer_ids, **kwargs))
         logger.debug(f"Current RTTs: {current_rtts}")
 
-        expiration = hivemind.get_dht_time() + self.expiration
-        for peer_id, rtt in current_rtts.items():
-            prev_rtt = self.ping_emas.get(peer_id)
-            if prev_rtt is not None and prev_rtt.value != math.inf:
-                rtt = self.ema_alpha * rtt + (1 - self.ema_alpha) * prev_rtt.value  # Exponential smoothing
-            self.ping_emas.store(peer_id, rtt, expiration)
+        with self.lock:
+            expiration = hivemind.get_dht_time() + self.expiration
+            for peer_id, rtt in current_rtts.items():
+                prev_rtt = self.ping_emas.get(peer_id)
+                if prev_rtt is not None and prev_rtt.value != math.inf:
+                    rtt = self.ema_alpha * rtt + (1 - self.ema_alpha) * prev_rtt.value  # Exponential smoothing
+                self.ping_emas.store(peer_id, rtt, expiration)
 
-    def fastest(self, n_peers: int) -> Dict[hivemind.PeerID, float]:
-        with self.ping_emas.freeze():
+    def to_dict(self) -> Dict[hivemind.PeerID, float]:
+        with self.lock, self.ping_emas.freeze():
             smoothed_rtts = {peer_id: rtt.value for peer_id, rtt in self.ping_emas.items()}
-        logger.debug(f"Smothed RTTs: {smoothed_rtts}")
-
-        fastest_rtts = sorted(smoothed_rtts.items(), key=lambda item: item[1])[:n_peers]
-        return dict(fastest_rtts)
+            logger.debug(f"Smothed RTTs: {smoothed_rtts}")
+            return smoothed_rtts
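
The `ping()` method above smooths repeated RTT measurements with a standard exponential moving average. A standalone illustration of that formula (the RTT values are made up):

```python
# Exponential smoothing as used in PingAggregator.ping(), with toy RTT measurements.
ema_alpha = 0.2
measurements = [0.080, 0.120, 0.060]  # hypothetical RTTs in seconds

smoothed = None
for rtt in measurements:
    # The first measurement is taken as-is; later ones are blended with the previous estimate
    smoothed = rtt if smoothed is None else ema_alpha * rtt + (1 - ema_alpha) * smoothed
    print(round(smoothed, 4))  # 0.08, 0.088, 0.0824
```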

+ 12 - 0
src/petals/utils/random.py

@@ -0,0 +1,12 @@
+import random
+from typing import Collection, List, TypeVar
+
+T = TypeVar("T")
+
+
+def sample_up_to(population: Collection[T], k: int) -> List[T]:
+    if not isinstance(population, list):
+        population = list(population)
+    if len(population) > k:
+        population = random.sample(population, k)
+    return population
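
A quick usage note for this helper (the peer list below is illustrative): it returns the whole population as a list when there are at most `k` items, and a uniform random sample of `k` items otherwise.

```python
from petals.utils.random import sample_up_to

peers = ["peerA", "peerB", "peerC"]  # hypothetical peer IDs
print(sample_up_to(peers, 5))  # all 3 peers: the population is smaller than k
print(sample_up_to(peers, 2))  # 2 peers chosen uniformly at random
```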