
Avoid synchronous updates, ban peers based on request outcome (#127)

- sequence_manager now keeps itself up to date - there is no need to update it manually
- if a peer fails a request, the sequence manager bans that peer temporarily; ban times grow with the peer's failure streak (see the sketch below)



Co-authored-by: Alexander Borzunov <borzunov.alexander@gmail.com>
justheuristic 2 years ago
parent
commit
68c85e7492
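
For context on the ban mechanism: hivemind's Blacklist is constructed here with base_time=ban_timeout and backoff_rate=2.0, so each consecutive failure is expected to extend the ban. The sketch below illustrates that idea only; SimpleBlacklist and its fields are hypothetical stand-ins, not the actual hivemind implementation.

import time
from typing import Dict

class SimpleBlacklist:
    """Illustrative stand-in for hivemind.dht.node.Blacklist (assumed behavior, not the real class)."""

    def __init__(self, base_time: float = 60.0, backoff_rate: float = 2.0):
        self.base_time, self.backoff_rate = base_time, backoff_rate
        self.banned_until: Dict[str, float] = {}   # peer_id -> timestamp when the ban expires
        self.failure_streak: Dict[str, int] = {}   # peer_id -> number of consecutive failed requests

    def register_failure(self, peer_id: str) -> None:
        streak = self.failure_streak.get(peer_id, 0)
        ban_duration = self.base_time * (self.backoff_rate ** streak)  # grows with the failure streak
        self.banned_until[peer_id] = time.monotonic() + ban_duration
        self.failure_streak[peer_id] = streak + 1

    def register_success(self, peer_id: str) -> None:
        # a successful request clears the failure streak and lifts any remaining ban
        self.failure_streak.pop(peer_id, None)
        self.banned_until.pop(peer_id, None)

    def __contains__(self, peer_id: str) -> bool:
        # used like `if peer_id in banned_peers:` when filtering routing candidates
        return time.monotonic() < self.banned_until.get(peer_id, float("-inf"))

Under these assumptions, with ban_timeout=60 a peer's first failure bans it for 60 s, a second consecutive failure for 120 s, a third for 240 s, and so on, until one successful request resets the streak.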

+ 4 - 2
src/petals/client/inference_session.py

@@ -235,9 +235,8 @@ class InferenceSession:
         while block_idx < n_blocks:
             for attempt_no in itertools.count():
                 logger.debug(f"Inference: block {block_idx}, attempt {attempt_no}")
+                span = None
                 try:
-                    if attempt_no >= 1:
-                        self._sequence_manager.update(wait=True)
                     if not self._chosen_spans or not self._server_sessions or attempt_no >= 1:
                         # If there is a failed server session, this code closes it
                         self._exit_server_sessions(self._server_sessions[server_idx : server_idx + 1])
@@ -298,8 +297,11 @@ class InferenceSession:
                     inputs = outputs
                     server_idx += 1
                     block_idx = span.end
+                    self._sequence_manager.on_request_success(span.peer_id)
                     break
                 except Exception as e:
+                    if span is not None:
+                        self._sequence_manager.on_request_failure(span.peer_id)
                     delay = self._sequence_manager.get_retry_delay(attempt_no)
                     logger.warning(
                         f"Caught exception when running inference from block {block_idx} "

+ 42 - 4
src/petals/client/routing/sequence_manager.py

@@ -8,7 +8,8 @@ import time
 from typing import Any, Dict, List, Optional, Sequence, Union
 from weakref import WeakMethod
 
-from hivemind import DHT, P2P, MSGPackSerializer
+from hivemind import DHT, P2P, MSGPackSerializer, PeerID
+from hivemind.dht.node import Blacklist
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.proto import runtime_pb2
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
@@ -39,6 +40,7 @@ class RemoteSequenceManager:
     :param min_backoff: after a repeated failure, sleep for this many seconds times 2 ^ (num_failures - 1)
     :param sequence_info: optionally, specify pre-generated sequence info. by default, create a new one using dht
     :param rpc_info: optionally, specify rpc info (communicated tensor shapes and compression) to save time
+    :param ban_timeout: when a remote peer fails to respond, prevent routing to that peer for this many seconds
     :param start: start the background thread (see the note below). If false, you will need to start it manually.
     :note: RemoteSequenceManager takes up some CPU and network I/O to operate in background. It is recommended to avoid
       running redundant sequence managers for the same set of layers.
@@ -53,18 +55,21 @@ class RemoteSequenceManager:
         update_period: float = 30,
         request_timeout: float = 30,
         min_backoff: float = 1,
+        ban_timeout: float = 60,
         sequence_info: Optional[RemoteSequenceInfo] = None,
         rpc_info: Optional[dict] = None,
+        banned_peers: Optional[Blacklist] = None,
         *,  # dear dev, if you add more parameters to this class, please make sure to handle them in __getitem__ (below)
         start: bool,
     ):
         assert len(block_uids) > 0, "Sequences must contain at least one block"
         self.dht, self.p2p = dht, p2p
-        self.request_timeout, self.min_backoff = request_timeout, min_backoff
+        self.request_timeout, self.ban_timeout, self.min_backoff = request_timeout, ban_timeout, min_backoff
         self.lock_changes = threading.Lock()
         self._thread = _SequenceManagerUpdateThread(update_period, WeakMethod(self._update))
         self.policy = NoSpendingPolicy()
         self._rpc_info = rpc_info
+        self.banned_peers = Blacklist(base_time=ban_timeout, backoff_rate=2.0) if banned_peers is None else banned_peers
 
         if sequence_info is None:
             self.sequence_info = RemoteSequenceInfo.make_empty(block_uids)
@@ -97,7 +102,7 @@ class RemoteSequenceManager:
             logger.error("Using a sequence manager that is not running: it has either crashed or never started")
         if not self.ready.is_set():
             logger.warning("Remote SequenceManager is still searching for routes, waiting for it to become ready")
-            self.ready.wait()
+            self.update(wait=True)  # this will await an existing update or trigger a new one (if not updating)
 
         end_index = end_index if end_index is not None else len(self)
         span_sequence = []
@@ -123,9 +128,11 @@ class RemoteSequenceManager:
             self.p2p,
             update_period=self._thread.update_period,
             request_timeout=self.request_timeout,
+            ban_timeout=self.ban_timeout,
             min_backoff=self.min_backoff,
             sequence_info=self.sequence_info[ix],
             rpc_info=self._rpc_info,
+            banned_peers=self.banned_peers,
             start=True,
         )
 
@@ -143,6 +150,14 @@ class RemoteSequenceManager:
                 new_block_infos = petals.dht_utils.get_remote_module_infos(
                     self.dht, self.block_uids, expiration_time=float("inf")
                 )
+                for block_info in new_block_infos:
+                    if not block_info:
+                        continue
+                    for peer_id in tuple(block_info.servers.keys()):
+                        if peer_id in self.banned_peers:
+                            logger.debug(f"Ignoring banned {peer_id} for block {block_info.uid}")
+                            block_info.servers.pop(peer_id)
+
                 with self.lock_changes:
                     self.sequence_info.update_(new_block_infos)
                 missing_blocks = [i for i in range(len(self)) if not self.sequence_info.spans_containing_block[i]]
@@ -158,6 +173,24 @@ class RemoteSequenceManager:
                 logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
                 time.sleep(delay)
 
+    def on_request_failure(self, peer_id: PeerID):
+        """remove a given peer from the routing table. If the routing is no longer possible, trigger an update"""
+        logger.info(f"Peer {peer_id} did not respond, banning it temporarily")
+        self.banned_peers.register_failure(peer_id)
+        with self.lock_changes:
+            should_update = False
+            for info in self.sequence_info.block_infos:
+                info.servers.pop(peer_id, None)
+                if not info.servers:
+                    should_update = True
+            if should_update:
+                self.ready.clear()
+                self.update(wait=False)
+
+    def on_request_success(self, peer_id: PeerID):
+        """if peer has a failure streak, clear that streak"""
+        self.banned_peers.register_success(peer_id)
+
     def __len__(self):
         return len(self.block_uids)
 
@@ -178,16 +211,21 @@ class RemoteSequenceManager:
         """Return the rpc_info queried from one of the servers that hold the first block"""
         if self._rpc_info is None:
             for attempt_no in itertools.count():
+                peer_id = None
                 try:
-                    self._update()
+                    if not self.ready.is_set():
+                        self.update(wait=True)
                     peer_id, _ = random.choice(list(self.sequence_info.block_infos[0].servers.items()))
                     stub = TransformerConnectionHandler.get_stub(self.p2p, peer_id)
                     outputs = RemoteExpertWorker.run_coroutine(
                         stub.rpc_info(runtime_pb2.ExpertUID(uid=self.block_uids[0]))
                     )
                     self._rpc_info = MSGPackSerializer.loads(outputs.serialized_info)
+                    self.on_request_success(peer_id)
                     break
                 except Exception as e:
+                    if peer_id is not None:
+                        self.on_request_failure(peer_id)
                     delay = self.get_retry_delay(attempt_no)
                     logger.warning(
                         f"Caught exception when gathering information from peer {peer_id} "

+ 7 - 3
src/petals/client/sequential_autograd.py

@@ -57,9 +57,8 @@ async def sequential_forward(
     while block_idx < end_index:
         for attempt_no in itertools.count():
             logger.debug(f"Forward: block {block_idx}, attempt {attempt_no}")
+            span = None
             try:
-                if attempt_no >= 1:
-                    sequence_manager._update()
                 if not sequences or attempt_no >= 1:
                     sequences = deque(sequence_manager.make_sequence(block_idx, end_index))
                     # make_sequence() could return a longer sequence
@@ -91,8 +90,11 @@ async def sequential_forward(
 
                 inputs = outputs
                 block_idx = span.end
+                sequence_manager.on_request_success(span.peer_id)
                 break
             except Exception as e:
+                if span is not None:
+                    sequence_manager.on_request_failure(span.peer_id)
                 delay = sequence_manager.get_retry_delay(attempt_no)
                 logger.warning(
                     f"Caught exception when running forward from block {block_idx} "
@@ -137,7 +139,6 @@ async def sequential_backward(
             logger.debug(f"Backward: block {span.end - 1}, attempt {attempt_no}")
             try:
                 if attempt_no >= 1:
-                    sequence_manager.update(wait=True)
                     _, backup_inputs, backup_sequences = await sequential_forward(
                         inputs, prompts, sequence_manager, start_index=span.start, end_index=span.end
                     )
@@ -167,8 +168,11 @@ async def sequential_backward(
                 )
                 grad_outputs = [grad_outputs]
                 grad_prompts_reversed.extend(span_grad_prompts)
+                sequence_manager.on_request_success(span.peer_id)
                 break
             except Exception as e:
+                if span is not None:
+                    sequence_manager.on_request_failure(span.peer_id)
                 delay = sequence_manager.get_retry_delay(attempt_no)
                 logger.warning(
                     f"Caught exception when running backward between blocks {span.start}-{span.end} "