فهرست منبع

Use length-weighted sampling in routing for inference (#204)

This pull request implements a simple routing optimization that is (1) greedy and (2) latency-agnostic; it should speed up both of our use cases.

Why this exists: our effort to merge full routing (ping-aware, throughput-aware, Dijkstra-based) is in a sorry state across several branches; merging it into main would take many days.

Co-authored-by: Aleksandr Borzunov <borzunov.alexander@gmail.com>
justheuristic 2 سال پیش
والد
کامیت
012f840f7e

+ 1 - 1
src/petals/client/inference_session.py

@@ -255,7 +255,7 @@ class InferenceSession:
                             )
                         recovery_until = max(recovery_until, update_end)
 
-                        updated_spans = self._sequence_manager.make_sequence(block_idx, update_end)
+                        updated_spans = self._sequence_manager.make_sequence(block_idx, update_end, mode="fastest")
                         # make_sequence() could return a longer sequence
                         updated_spans[-1].end = min(updated_spans[-1].end, update_end)
                         updated_sessions = self._enter_server_sessions(updated_spans)

+ 13 - 2
src/petals/client/routing/sequence_manager.py

@@ -9,6 +9,7 @@ import time
 from typing import Any, Dict, List, Optional, Sequence, Union
 from weakref import WeakMethod
 
+import numpy as np
 from hivemind import DHT, P2P, MSGPackSerializer, PeerID
 from hivemind.dht.node import Blacklist
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
@@ -92,12 +93,15 @@ class RemoteSequenceManager:
         if await_ready:
             self._thread.ready.wait(timeout)
 
-    def make_sequence(self, start_index: int = 0, end_index: Optional[int] = None) -> List[RemoteSpanInfo]:
+    def make_sequence(
+        self, start_index: int = 0, end_index: Optional[int] = None, mode: str = "random"
+    ) -> List[RemoteSpanInfo]:
         """
         Form a sequence of remote servers that collectively serve all consecutive layers
 
         :param start_index: optional index of the first module in a sequence, default = the first of block_uids
         :param end_index: optional index of the last module (non-inclusive), default = after last of block uids
+        :param mode: either random or fastest
         """
         if not self.is_alive():
             logger.error("Using a sequence manager that is not running: it has either crashed or never started")
@@ -110,7 +114,14 @@ class RemoteSequenceManager:
         current_index = start_index
         while current_index < end_index:
             candidate_spans = self.sequence_info.spans_containing_block[current_index]
-            chosen_span = random.choice(candidate_spans)  # TODO this should be replaced with proper load balancing
+            if mode == "random":
+                chosen_span = random.choice(candidate_spans)  # TODO this should be replaced with proper load balancing
+            elif mode == "fastest":
+                # note: this too is a heuristic that will be replaced once we integrate fastest wall time routing
+                span_weights = np.array([span.end - current_index for span in candidate_spans], dtype=np.float64)
+                chosen_span = np.random.choice(candidate_spans, p=span_weights / span_weights.sum())
+            else:
+                raise RuntimeError(f"Unexpected mode {mode}")
 
             assert chosen_span.start <= current_index < chosen_span.end
             span_sequence.append(RemoteSpanInfo(start=current_index, end=chosen_span.end, peer_id=chosen_span.peer_id))

+ 1 - 1
src/petals/client/sequential_autograd.py

@@ -60,7 +60,7 @@ async def sequential_forward(
             span = None
             try:
                 if not sequences or attempt_no >= 1:
-                    sequences = deque(sequence_manager.make_sequence(block_idx, end_index))
+                    sequences = deque(sequence_manager.make_sequence(block_idx, end_index, mode="random"))
                     # make_sequence() could return a longer sequence
                     sequences[-1].end = min(sequences[-1].end, end_index)
                     logger.debug(f"Found path from block {block_idx} to {end_index} via {len(sequences)} servers")

+ 3 - 2
tests/test_sequence_manager.py

@@ -14,7 +14,8 @@ logger = get_logger(__file__)
 
 
 @pytest.mark.forked
-def test_sequence_manager_shutdown():
+@pytest.mark.parametrize("mode", ["fastest", "random"])
+def test_sequence_manager_basics(mode: str):
     config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
     dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
     sequential = RemoteSequential(config, dht)
@@ -28,7 +29,7 @@ def test_sequence_manager_shutdown():
         sequence_manager=TestSequenceManager(dht, block_uids, sequential.p2p, _was_shut_down=shutdown_evt, start=True),
     )
 
-    sequence = sequential.sequence_manager.make_sequence()
+    sequence = sequential.sequence_manager.make_sequence(mode=mode)
     assert all(sequence[i].peer_id != sequence[i + 1].peer_id for i in range(len(sequence) - 1))
 
     assert sequential.sequence_manager.is_alive()