5
0
Эх сурвалжийг харах

Implement fault-tolerant inference

Aleksandr Borzunov 2 жил өмнө
parent
commit
f6622bcff7

+ 76 - 31
src/client/inference_session.py

@@ -1,7 +1,9 @@
 from __future__ import annotations
 
 import asyncio
-import contextlib
+import itertools
+import logging
+import time
 from typing import AsyncIterator, List, Optional
 
 import torch
@@ -18,6 +20,7 @@ from hivemind import (
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.p2p import StubBase
 from hivemind.proto import runtime_pb2
+from hivemind.utils.asyncio import aiter_with_timeout
 
 from src.client.sequence_manager import RemoteSequenceManager
 from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
@@ -56,18 +59,22 @@ class RemoteServerInferenceSession:
         self.closed = False
 
     @classmethod
-    async def _create(
-        cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: Optional[float] = None, **metadata
+    async def create(
+        cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: float, **metadata
     ) -> RemoteServerInferenceSession:
         """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
         inputs_queue = asyncio.Queue()
-        outputs_stream = await stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue, timeout), timeout=timeout)
+        outputs_stream = await asyncio.wait_for(
+            stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)),
+            timeout,
+        )
+        outputs_stream = aiter_with_timeout(outputs_stream, timeout)
         return cls(uid, rpc_info, inputs_queue, outputs_stream, **metadata)
 
     @staticmethod
-    async def _read_inputs_from_queue(queue: asyncio.Queue, timeout: Optional[float]) -> AsyncIterator:
+    async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[float] = None) -> AsyncIterator:
         while True:
-            next_input_message = await asyncio.wait_for(queue.get(), timeout)
+            next_input_message = await asyncio.wait_for(queue.get(), input_timeout)
             yield next_input_message
             if not next_input_message.uid and not next_input_message.tensors:
                 break  # this message means "done sending"
@@ -159,33 +166,39 @@ class RemoteSequentialInferenceSession:
     An interface to a multi-step *inference* session for a sequence of remote transformer blocks
     """
 
-    def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, timeout: Optional[float] = None, **metadata):
+    def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, **metadata):
         self.sequence_manager = sequence_manager
         self.p2p = p2p
         self.closed = False
-        self.chosen_spans: List[RemoteSpanInfo] = []
-        self.stack = contextlib.ExitStack()
-        self.inference_sessions: List[RemoteServerInferenceSession] = []
+        self.chosen_spans = []
+        self.server_sessions = []
         self.metadata = metadata
-        self.timeout = timeout
 
-    def __enter__(self):
-        assert not self.closed and not self.chosen_spans
-        self.stack.__enter__()
-        # TODO(yozh) replace this code with a fault-tolerant chain that can be reconstructed if some peers fail
-        self.chosen_spans.extend(self.sequence_manager.make_sequence())
-
-        for chosen_span in self.chosen_spans:
-            stub = TransformerConnectionHandler.get_stub(self.p2p, chosen_span.peer_id)
-            span_uids: str = CHAIN_DELIMITER.join(self.sequence_manager.block_uids[chosen_span.start : chosen_span.end])
-            inference_session = RemoteExpertWorker.run_coroutine(
-                RemoteServerInferenceSession._create(
-                    stub, span_uids, rpc_info=self.sequence_manager.rpc_info, timeout=self.timeout, **self.metadata
+    def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[RemoteServerInferenceSession]:
+        server_sessions = []
+        for span in chosen_spans:
+            stub = TransformerConnectionHandler.get_stub(self.p2p, span.peer_id)
+            span_uids = CHAIN_DELIMITER.join(self.sequence_manager.block_uids[span.start : span.end])
+            session = RemoteExpertWorker.run_coroutine(
+                RemoteServerInferenceSession.create(
+                    stub, span_uids, rpc_info=self.sequence_manager.rpc_info, timeout=self.sequence_manager.timeout,
+                    **self.metadata
                 )
             )
-            self.inference_sessions.append(inference_session)
-            self.stack.enter_context(inference_session)
+            server_sessions.append(session)
+            session.__enter__()
+        return server_sessions
 
+    def _exit_server_sessions(self, server_sessions: List[RemoteServerInferenceSession], *, verbose: bool) -> None:
+        exc_loglevel = logging.WARNING if verbose else logging.DEBUG
+        for session in reversed(server_sessions):
+            try:
+                session.__exit__(None, None, None)
+            except Exception:
+                logger.log(exc_loglevel, "Caught exception while closing connection to server:", exc_info=True)
+
+    def __enter__(self):
+        assert not self.closed and not self.chosen_spans
         return self
 
     def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs):
@@ -196,17 +209,49 @@ class RemoteSequentialInferenceSession:
             prompts = DUMMY
         else:
             assert prompts.ndim == 4 and prompts.shape[0] == len(self.sequence_manager)
-        for session in self.inference_sessions:
-            outputs = session.step(inputs, prompts[self.chosen_spans[0].start : self.chosen_spans[0].end], **kwargs)
-            assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
-            inputs = outputs
+
+        server_idx = 0
+        block_idx = 0
+        while block_idx < len(self.sequence_manager):
+            for attempt_no in itertools.count():
+                logger.debug(f"Inference: block {block_idx}, attempt {attempt_no}")
+                try:
+                    if not self.chosen_spans or not self.server_sessions or attempt_no >= 1:
+                        self._exit_server_sessions(self.server_sessions[server_idx:], verbose=False)
+                        self.server_sessions[server_idx:] = []
+                        self.chosen_spans[server_idx:] = []
+
+                        self.sequence_manager.update_()
+                        backup_spans = self.sequence_manager.make_sequence(block_idx)
+                        self.chosen_spans.extend(backup_spans)
+                        self.server_sessions.extend(self._enter_server_sessions(backup_spans))
+                        logger.debug(f"Found path from block {block_idx} via {len(backup_spans)} servers")
+
+                    session = self.server_sessions[server_idx]
+                    span = self.chosen_spans[server_idx]
+
+                    outputs = session.step(inputs, prompts[span.start : span.end], **kwargs)
+                    assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
+                    inputs = outputs
+
+                    server_idx += 1
+                    block_idx = span.end
+                    break
+                except Exception as e:
+                    delay = self.sequence_manager.min_backoff * 2**attempt_no
+                    logger.warning(
+                        f"Caught exception when running inference from block {block_idx} "
+                        f"(retry in {delay:.2f} sec): {repr(e)}"
+                    )
+                    logger.debug("See detailed traceback below:", exc_info=True)
+                    time.sleep(delay)
         return inputs
 
     def close(self, *exc_details):
         """Finish a given inference session, close the underlying connection"""
         if not self.closed:
-            self.stack.__exit__(*exc_details or (None, None, None))
-            self.inference_sessions.clear()
+            self._exit_server_sessions(self.server_sessions, verbose=True)
+            self.server_sessions.clear()
             self.closed = True
 
     def __exit__(self, *exc_details):