Fix inference and rpc_info() fault tolerance (#131)

Alexander Borzunov 2 years ago
parent commit f56edaa13f

+ 3 - 4
src/petals/client/inference_session.py

@@ -20,7 +20,7 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.p2p import StubBase
 from hivemind.proto import runtime_pb2
 
-from petals.client.routing.sequence_manager import RemoteSequenceManager
+from petals.client.routing.sequence_manager import RemoteSequenceManager, maybe_log_traceback
 from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
 from petals.server.handler import TransformerConnectionHandler
 from petals.utils.misc import DUMMY, is_dummy
@@ -307,12 +307,11 @@ class InferenceSession:
                         f"Caught exception when running inference from block {block_idx} "
                         f"(retry in {delay:.0f} sec): {repr(e)}"
                     )
-                    traceback_level = logging.DEBUG if str(e) else logging.WARNING
-                    logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
+                    maybe_log_traceback(e)
                     time.sleep(delay)
 
         self._position += n_input_tokens
-
+        inputs = inputs[:, -n_input_tokens:]
         outputs = inputs.to(device=inputs_device, dtype=inputs_dtype)
         return outputs
 
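The trailing slice matters because a failed step can be retried through a different span of servers, in which case the hidden states sent on the retry may carry previously processed history in front of the tokens submitted in this call, and only the last n_input_tokens positions should be returned to the caller. A minimal shape sketch with made-up dimensions (the [batch, sequence, hidden] layout is assumed for illustration, not quoted from the diff):

    # Hypothetical shapes: after a retry, `inputs` may include re-sent history
    # ahead of the tokens submitted in this step, so only the trailing
    # n_input_tokens positions belong to the current call.
    import torch

    n_input_tokens, history_len, hidden_size = 3, 5, 8
    inputs = torch.zeros(1, history_len + n_input_tokens, hidden_size)  # [batch, seq, hidden]

    outputs = inputs[:, -n_input_tokens:]  # keep only the newly processed positions
    assert outputs.shape == (1, n_input_tokens, hidden_size)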

+ 19 - 6
src/petals/client/routing/sequence_manager.py

@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import asyncio
 import itertools
 import logging
 import random
@@ -17,7 +18,7 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 import petals.dht_utils
 from petals.client.routing.sequence_info import RemoteSequenceInfo
 from petals.client.routing.spending_policy import NoSpendingPolicy
-from petals.data_structures import ModuleUID, RemoteSpanInfo
+from petals.data_structures import ModuleUID, RemoteSpanInfo, ServerState
 from petals.server.handler import TransformerConnectionHandler
 
 use_hivemind_log_handler("in_root_logger")
@@ -169,8 +170,7 @@ class RemoteSequenceManager:
             except Exception as e:
                 delay = self.get_retry_delay(attempt_no)
                 logger.warning(f"Could not find route through the model: {repr(e)} (retry in {delay:.0f} sec)")
-                traceback_level = logging.DEBUG if str(e) else logging.WARNING
-                logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
+                maybe_log_traceback(e)
                 time.sleep(delay)
 
     def on_request_failure(self, peer_id: PeerID):
@@ -215,7 +215,16 @@ class RemoteSequenceManager:
                 try:
                     if not self.ready.is_set():
                         self.update(wait=True)
-                    peer_id, _ = random.choice(list(self.sequence_info.block_infos[0].servers.items()))
+
+                    active_servers = [
+                        peer_id
+                        for peer_id, server in self.sequence_info.block_infos[0].servers.items()
+                        if server.state == ServerState.ONLINE
+                    ]
+                    if not active_servers:
+                        raise MissingBlocksError("no servers holding the first block are online")
+                    peer_id = random.choice(active_servers)
+
                     stub = TransformerConnectionHandler.get_stub(self.p2p, peer_id)
                     outputs = RemoteExpertWorker.run_coroutine(
                         stub.rpc_info(runtime_pb2.ExpertUID(uid=self.block_uids[0]))
@@ -231,8 +240,7 @@ class RemoteSequenceManager:
                         f"Caught exception when gathering information from peer {peer_id} "
                         f"(retry in {delay:.0f} sec): {repr(e)}"
                     )
-                    traceback_level = logging.DEBUG if str(e) else logging.WARNING
-                    logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
+                    maybe_log_traceback(e)
                     time.sleep(delay)
 
         return self._rpc_info
@@ -298,6 +306,11 @@ class _SequenceManagerUpdateThread(threading.Thread):
             self.shutdown()
 
 
+def maybe_log_traceback(exc: Exception):
+    traceback_level = logging.DEBUG if str(exc) or isinstance(exc, asyncio.TimeoutError) else logging.WARNING
+    logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
+
+
 class MissingBlocksError(Exception):
     def __repr__(self):
         return self.args[0]
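The rpc_info() change makes the peer choice skip servers that are announced in the DHT but not actually serving yet. The sketch below mirrors that selection logic with stand-in classes; ServerState, the server-info objects, and the peer-to-info mapping are simplified assumptions for illustration, not the real petals data structures:

    # Stand-in sketch of "pick a random ONLINE server for the first block".
    import enum
    import random

    class ServerState(enum.Enum):
        OFFLINE = 0
        JOINING = 1
        ONLINE = 2

    class ServerInfo:
        def __init__(self, state):
            self.state = state

    servers = {
        "peer_A": ServerInfo(ServerState.ONLINE),
        "peer_B": ServerInfo(ServerState.JOINING),   # still loading blocks
        "peer_C": ServerInfo(ServerState.OFFLINE),   # stale DHT entry
    }

    active_servers = [peer_id for peer_id, info in servers.items() if info.state == ServerState.ONLINE]
    if not active_servers:
        raise RuntimeError("no servers holding the first block are online")
    peer_id = random.choice(active_servers)  # only "peer_A" can be chosen here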

+ 3 - 5
src/petals/client/sequential_autograd.py

@@ -13,7 +13,7 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.utils.logging import get_logger
 
 from petals.client.remote_forward_backward import run_remote_backward, run_remote_forward
-from petals.client.routing.sequence_manager import RemoteSequenceManager
+from petals.client.routing.sequence_manager import RemoteSequenceManager, maybe_log_traceback
 from petals.data_structures import CHAIN_DELIMITER, RemoteSpanInfo
 from petals.server.handler import TransformerConnectionHandler
 from petals.utils.misc import DUMMY, is_dummy
@@ -100,8 +100,7 @@ async def sequential_forward(
                     f"Caught exception when running forward from block {block_idx} "
                     f"(retry in {delay:.0f} sec): {repr(e)}"
                 )
-                traceback_level = logging.DEBUG if str(e) else logging.WARNING
-                logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
+                maybe_log_traceback(e)
                 await asyncio.sleep(delay)
 
     outputs = inputs.to(device=inputs_device, dtype=inputs_dtype)
@@ -178,8 +177,7 @@ async def sequential_backward(
                     f"Caught exception when running backward between blocks {span.start}-{span.end} "
                     f"(retry in {delay:.0f} sec): {repr(e)}"
                 )
-                traceback_level = logging.DEBUG if str(e) else logging.WARNING
-                logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
+                maybe_log_traceback(e)
                 await asyncio.sleep(delay)
 
     # For now, we do not support mixed dummy and grad prompts
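All three retry loops now delegate to the shared maybe_log_traceback() helper added in sequence_manager.py. A usage sketch (assumed, not part of the diff): the helper is meant to be called from inside an except block so that exc_info=True attaches the active traceback; exceptions that carry a message, and bare asyncio.TimeoutError whose str() is empty, get a DEBUG-level traceback, while other message-less exceptions are surfaced at WARNING because the traceback is the only information available.

    import asyncio
    import logging

    logging.basicConfig(level=logging.DEBUG)
    logger = logging.getLogger(__name__)

    def maybe_log_traceback(exc: Exception):
        # Demote the traceback to DEBUG unless the exception has no message at all
        # (then the traceback is the only clue); TimeoutError is special-cased
        # because its str() is empty but its traceback is rarely informative.
        level = logging.DEBUG if str(exc) or isinstance(exc, asyncio.TimeoutError) else logging.WARNING
        logger.log(level, "See detailed traceback below:", exc_info=True)

    try:
        raise asyncio.TimeoutError()  # stringifies to "", special-cased to DEBUG
    except Exception as e:
        logger.warning(f"Caught {repr(e)}, retrying")
        maybe_log_traceback(e)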

+ 1 - 1
src/petals/server/throughput.py

@@ -146,7 +146,7 @@ def measure_compute_rps(
 
 
 def get_device_name(device: torch.device) -> str:
-    return f"{torch.cuda.get_device_name(device)} GPU" if device == "cuda" else "CPU"
+    return f"{torch.cuda.get_device_name(device)} GPU" if device.type == "cuda" else "CPU"
 
 
 def get_dtype_name(dtype: torch.dtype, load_in_8bit: bool) -> str:
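The throughput fix addresses device detection: get_device_name() receives a torch.device object, and comparing that object to the plain string "cuda" is not a device-type check (it also misses specific devices such as cuda:0), so the function could report "CPU" even when running on a GPU. A small illustration, assuming a PyTorch install (CUDA optional):

    import torch

    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

    # Old check: compares a torch.device against a str, not a device-type test.
    print(device == "cuda")        # does not identify "cuda:0" as a GPU
    # Fixed check: inspect the device type explicitly.
    print(device.type == "cuda")   # True for "cuda", "cuda:0", "cuda:1", ...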