
Replace .make_sequence(..., mode="random") with mode="max_throughput" (#313)

We need to sample the next server with probability proportional to its throughput to actually achieve maximum throughput for fine-tuning.

As an example, imagine 3 servers with throughputs [1000, 500, 1] hosting the same blocks. Under uniform sampling, each server is chosen with probability 1/3, so the nearly-dead server becomes the bottleneck for a third of all requests. Under throughput-weighted sampling, it is chosen with probability 1/1501 ≈ 0.07%, and the load is split between the two fast servers in proportion to what they can actually handle.
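A minimal standalone sketch of that comparison (NumPy only, mirroring the `np.random.choice(..., p=span_weights / span_weights.sum())` call in the `sequence_manager.py` diff below):

```python
import numpy as np

throughputs = np.array([1000.0, 500.0, 1.0])

# Uniform sampling (old mode="random"): every server is equally likely.
uniform_p = np.full(len(throughputs), 1.0 / len(throughputs))

# Throughput-weighted sampling (new mode="max_throughput").
weighted_p = throughputs / throughputs.sum()

print("uniform: ", np.round(uniform_p, 4))   # [0.3333 0.3333 0.3333]
print("weighted:", np.round(weighted_p, 4))  # [0.6662 0.3331 0.0007]
```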
Alexander Borzunov 2 years ago
parent
commit
6137b1b4b0

+ 1 - 1
src/petals/client/inference_session.py

@@ -253,7 +253,7 @@ class InferenceSession:
                             )
                         recovery_until = max(recovery_until, update_end)
 
-                        updated_spans = self._sequence_manager.make_sequence(block_idx, update_end, mode="fastest")
+                        updated_spans = self._sequence_manager.make_sequence(block_idx, update_end, mode="min_latency")
                         # make_sequence() could return a longer sequence
                         updated_spans[-1].end = min(updated_spans[-1].end, update_end)
                         updated_sessions = self._enter_server_sessions(updated_spans)

+ 4 - 2
src/petals/client/routing/sequence_info.py

@@ -77,7 +77,9 @@ class RemoteSequenceInfo:
                     if server.state != ServerState.ONLINE:
                         continue
                     if peer_id not in active_spans:
-                        active_spans[peer_id] = RemoteSpanInfo(start=block_index, end=block_index + 1, peer_id=peer_id)
+                        active_spans[peer_id] = RemoteSpanInfo(
+                            peer_id=peer_id, start=block_index, end=block_index + 1, throughput=server.throughput
+                        )
                     else:  # peer_id in active_spans
                         active_spans[peer_id].end = block_index + 1
 
@@ -91,7 +93,7 @@ class RemoteSequenceInfo:
                     closed_spans.append(active_spans.pop(peer_id))
         assert not active_spans, f"spans: {active_spans}"
 
-        closed_spans.sort(key=lambda span: span.end - span.start, reverse=True)
+        closed_spans.sort(key=lambda span: span.length, reverse=True)
 
         spans_containing_block = tuple(list() for _ in range(len(block_infos)))
         for span in closed_spans:

+ 8 - 8
src/petals/client/routing/sequence_manager.py

@@ -115,14 +115,14 @@ class RemoteSequenceManager:
             self._need_latest_infos = True
 
     def make_sequence(
-        self, start_index: int = 0, end_index: Optional[int] = None, mode: str = "random"
+        self, start_index: int = 0, end_index: Optional[int] = None, *, mode: str
     ) -> List[RemoteSpanInfo]:
         """
         Form a sequence of remote servers that collectively serve all consecutive layers
 
         :param start_index: optional index of the first module in a sequence, default = the first of block_uids
         :param end_index: optional index of the last module (non-inclusive), default = after last of block uids
-        :param mode: either random or fastest
+        :param mode: one of ["max_throughput", "min_latency"]
         """
         with self._thread_start_lock:
             if not self.is_alive():
@@ -137,17 +137,17 @@ class RemoteSequenceManager:
             candidate_spans = self.state.sequence_info.spans_containing_block[current_index]
             if not candidate_spans:
                 raise MissingBlocksError(current_index)
-            if mode == "random":
-                chosen_span = random.choice(candidate_spans)  # TODO this should be replaced with proper load balancing
-            elif mode == "fastest":
-                # note: this too is a heuristic that will be replaced once we integrate fastest wall time routing
+
+            if mode == "max_throughput":
+                span_weights = np.array([span.throughput for span in candidate_spans], dtype=np.float64)
+            elif mode == "min_latency":
                 span_weights = np.array([span.end - current_index for span in candidate_spans], dtype=np.float64)
-                chosen_span = np.random.choice(candidate_spans, p=span_weights / span_weights.sum())
             else:
                 raise RuntimeError(f"Unexpected mode {mode}")
+            chosen_span = np.random.choice(candidate_spans, p=span_weights / span_weights.sum())
 
             assert chosen_span.start <= current_index < chosen_span.end
-            span_sequence.append(RemoteSpanInfo(start=current_index, end=chosen_span.end, peer_id=chosen_span.peer_id))
+            span_sequence.append(dataclasses.replace(chosen_span, start=current_index))
             current_index = chosen_span.end
 
         route_repr = " => ".join([f"{span.start}:{span.end} via …{str(span.peer_id)[-6:]}" for span in span_sequence])

+ 1 - 1
src/petals/client/sequential_autograd.py

@@ -60,7 +60,7 @@ async def sequential_forward(
             span = None
             try:
                 if not sequences or attempt_no >= 1:
-                    sequences = deque(sequence_manager.make_sequence(block_idx, end_index, mode="random"))
+                    sequences = deque(sequence_manager.make_sequence(block_idx, end_index, mode="max_throughput"))
                     # make_sequence() could return a longer sequence
                     sequences[-1].end = min(sequences[-1].end, end_index)
                     logger.debug(f"Found path from block {block_idx} to {end_index} via {len(sequences)} servers")

+ 6 - 1
src/petals/data_structures.py

@@ -39,9 +39,14 @@ class RemoteModuleInfo:
 class RemoteSpanInfo:
     """A chain of remote blocks served by one specific remote peer"""
 
+    peer_id: PeerID
     start: int
     end: int
-    peer_id: PeerID
+    throughput: float
+
+    @property
+    def length(self):
+        return self.end - self.start
 
 
 RPCInfo = Dict[str, Any]

+ 1 - 1
tests/test_sequence_manager.py

@@ -14,7 +14,7 @@ logger = get_logger(__name__)
 
 
 @pytest.mark.forked
-@pytest.mark.parametrize("mode", ["fastest", "random"])
+@pytest.mark.parametrize("mode", ["max_throughput", "min_latency"])
 def test_sequence_manager_basics(mode: str):
     config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
     dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)