Jelajahi Sumber

Make inference, forward, and backward fully fault-tolerant (#91)

Alexander Borzunov 2 tahun lalu
induk
melakukan
11d6ba683c

+ 1 - 1
requirements.txt

@@ -4,5 +4,5 @@ accelerate==0.10.0
 huggingface-hub==0.7.0
 transformers==4.21.3
 protobuf>=3.12.2,<4.0.0
-git+https://github.com/learning-at-home/hivemind@94c985d2dc7a79a091e46c755e9f2f4469b164c7
+git+https://github.com/learning-at-home/hivemind@8f258b4b3688f671208bf323359cb967b25d640a
 humanfriendly

+ 1 - 1
src/client/__init__.py

@@ -1,4 +1,4 @@
-from src.client.inference_session import RemoteSequentialInferenceSession, RemoteTransformerBlockInferenceSession
+from src.client.inference_session import InferenceSession
 from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
 from src.client.remote_sequential import RemoteSequential, RemoteTransformerBlock
 from src.client.sequence_manager import RemoteSequenceManager

+ 159 - 53
src/client/inference_session.py

@@ -1,7 +1,8 @@
 from __future__ import annotations
 
 import asyncio
-import contextlib
+import itertools
+import time
 from typing import AsyncIterator, List, Optional
 
 import torch
@@ -13,26 +14,25 @@ from hivemind import (
     get_logger,
     nested_flatten,
     serialize_torch_tensor,
-    use_hivemind_log_handler,
 )
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.p2p import StubBase
 from hivemind.proto import runtime_pb2
+from hivemind.utils.asyncio import aiter_with_timeout
 
 from src.client.sequence_manager import RemoteSequenceManager
 from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
 from src.server.handler import TransformerConnectionHandler
 from src.utils.misc import DUMMY, is_dummy
 
-use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 
-class RemoteTransformerBlockInferenceSession:
+class _ServerInferenceSession:
     """
-    An interface to a single multi-step *inference* session for a specific remote module on a specific server
+    An interface to a single multi-step *inference* session for a a set of blocks on a specific server.
 
-    :note: this inference session is *not* fault-tolerant out of the box
+    :note: This class is *not* fault-tolerant out of the box.
     """
 
     def __init__(
@@ -42,32 +42,35 @@ class RemoteTransformerBlockInferenceSession:
         inputs_queue: asyncio.Queue,
         outputs_aiter: AsyncIterator,
         *,
+        timeout: float,
         max_length: int,
         points: int = 0,
     ):
         self.uid, self.rpc_info = uid, rpc_info
         self.num_blocks = uid.count(CHAIN_DELIMITER) + 1
-        # warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
-        # using them in any other EventLoop may cause side-effects including, headaches, diarrhea, and loss of sleep
         self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
         self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
+        self.timeout = timeout
         self._serialized_metadata = MSGPackSerializer.dumps(dict(max_length=max_length, points=points))
         self.stepped = False
         self.closed = False
 
     @classmethod
-    async def _create(
-        cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: Optional[float] = None, **metadata
-    ) -> RemoteTransformerBlockInferenceSession:
+    async def create(
+        cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: float, **metadata
+    ) -> _ServerInferenceSession:
         """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
         inputs_queue = asyncio.Queue()
-        outputs_stream = await stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue, timeout), timeout=timeout)
-        return cls(uid, rpc_info, inputs_queue, outputs_stream, **metadata)
+        outputs_stream = await asyncio.wait_for(
+            stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)),
+            timeout,
+        )
+        return cls(uid, rpc_info, inputs_queue, outputs_stream, timeout=timeout, **metadata)
 
     @staticmethod
-    async def _read_inputs_from_queue(queue: asyncio.Queue, timeout: Optional[float]) -> AsyncIterator:
+    async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[float] = None) -> AsyncIterator:
         while True:
-            next_input_message = await asyncio.wait_for(queue.get(), timeout)
+            next_input_message = await asyncio.wait_for(queue.get(), input_timeout)
             yield next_input_message
             if not next_input_message.uid and not next_input_message.tensors:
                 break  # this message means "done sending"
@@ -77,7 +80,7 @@ class RemoteTransformerBlockInferenceSession:
         new_hidden_states: torch.Tensor,
         prompts: Optional[torch.Tensor] = None,
         hypo_ids: Optional[torch.Tensor] = None,
-    ):
+    ) -> torch.Tensor:
         """
         Inference step: send a chunk of input tesors and receive a chunk of outputs
         :prompts: optional DEEP prompts, added to a prefix of each layer's outputs,
@@ -122,7 +125,7 @@ class RemoteTransformerBlockInferenceSession:
         """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
         await self._inputs_queue.put(inputs_serialized)
         self.stepped = True
-        return await anext(self._outputs_stream)
+        return await asyncio.wait_for(anext(self._outputs_stream), self.timeout)
 
     def close(self):
         """Finish a given inference session, close the underlying connection"""
@@ -154,60 +157,163 @@ class RemoteTransformerBlockInferenceSession:
         self.close()
 
 
-class RemoteSequentialInferenceSession:
+class InferenceSession:
     """
     An interface to a multi-step *inference* session for a sequence of remote transformer blocks
     """
 
-    def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, timeout: Optional[float] = None, **metadata):
-        self.sequence_manager = sequence_manager
-        self.p2p = p2p
-        self.closed = False
-        self.chosen_spans: List[RemoteSpanInfo] = []
-        self.stack = contextlib.ExitStack()
-        self.inference_sessions: List[RemoteTransformerBlockInferenceSession] = []
-        self.metadata = metadata
-        self.timeout = timeout
+    def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, max_length: int, **metadata):
+        self._sequence_manager = sequence_manager
+        self._p2p = p2p
+        self._closed = False
+        self._chosen_spans = []
+        self._server_sessions = []
+        self._server_inputs = []  # Used in case of server failures to regenerate attention caches on new servers
+        self._position = 0
+        self._max_length = max_length
+        self._metadata = metadata
 
-    def __enter__(self):
-        assert not self.closed and not self.chosen_spans
-        self.stack.__enter__()
-        # TODO(yozh) replace this code with a fault-tolerant chain that can be reconstructed if some peers fail
-        self.chosen_spans.extend(self.sequence_manager.make_sequence())
-
-        for chosen_span in self.chosen_spans:
-            stub = TransformerConnectionHandler.get_stub(self.p2p, chosen_span.peer_id)
-            span_uids: str = CHAIN_DELIMITER.join(self.sequence_manager.block_uids[chosen_span.start : chosen_span.end])
-            inference_session = RemoteExpertWorker.run_coroutine(
-                RemoteTransformerBlockInferenceSession._create(
-                    stub, span_uids, rpc_info=self.sequence_manager.rpc_info, timeout=self.timeout, **self.metadata
+    def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[_ServerInferenceSession]:
+        server_sessions = []
+        try:
+            for span in chosen_spans:
+                stub = TransformerConnectionHandler.get_stub(self._p2p, span.peer_id)
+                span_uids = CHAIN_DELIMITER.join(self._sequence_manager.block_uids[span.start : span.end])
+                session = RemoteExpertWorker.run_coroutine(
+                    _ServerInferenceSession.create(
+                        stub,
+                        span_uids,
+                        rpc_info=self._sequence_manager.rpc_info,
+                        timeout=self._sequence_manager.timeout,
+                        max_length=self._max_length,
+                        **self._metadata,
+                    )
                 )
-            )
-            self.inference_sessions.append(inference_session)
-            self.stack.enter_context(inference_session)
+                server_sessions.append(session)
+                session.__enter__()
+            return server_sessions
+        except:
+            self._exit_server_sessions(server_sessions)
+            raise
+
+    def _exit_server_sessions(self, server_sessions: List[_ServerInferenceSession]) -> None:
+        for session in reversed(server_sessions):
+            try:
+                session.__exit__(None, None, None)
+            except Exception:
+                logger.debug("Caught exception while closing connection to server:", exc_info=True)
 
+    def __enter__(self) -> "InferenceSession":
+        assert not self._closed and not self._chosen_spans
         return self
 
-    def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs):
-        assert not self.closed
+    def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
+        assert not self._closed
         if torch.is_grad_enabled():
             logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
+
+        n_blocks = len(self._sequence_manager)
         if prompts is None or is_dummy(prompts):
             prompts = DUMMY
         else:
-            assert prompts.ndim == 4 and prompts.shape[0] == len(self.sequence_manager)
-        for session in self.inference_sessions:
-            outputs = session.step(inputs, prompts[self.chosen_spans[0].start : self.chosen_spans[0].end], **kwargs)
-            assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
-            inputs = outputs
+            assert prompts.ndim == 4 and prompts.shape[0] == n_blocks
+
+        n_input_tokens = inputs.shape[1]
+        if self._position + n_input_tokens > self._max_length:
+            raise ValueError(
+                f"Maximum length exceeded: prefix {self._position} + current {n_input_tokens} exceeds pre-allocated maximum {self._max_length}"
+            )
+
+        server_idx = 0
+        block_idx = 0
+        recovery_until = -1  # Recovery mode is disabled until a failure happens
+        while block_idx < n_blocks:
+            for attempt_no in itertools.count():
+                logger.debug(f"Inference: block {block_idx}, attempt {attempt_no}")
+                try:
+                    if attempt_no >= 1:
+                        self._sequence_manager.update_()
+                    if not self._chosen_spans or not self._server_sessions or attempt_no >= 1:
+                        # If there is a failed server session, this code closes it
+                        self._exit_server_sessions(self._server_sessions[server_idx : server_idx + 1])
+
+                        n_prev_spans = len(self._chosen_spans)
+                        update_end = self._chosen_spans[server_idx].end if server_idx < n_prev_spans else n_blocks
+                        if attempt_no >= 1 and update_end > recovery_until:
+                            logger.info(
+                                f"Due to a server failure, remote attention caches "
+                                f"from block {block_idx} to {update_end} will be regenerated"
+                            )
+                        recovery_until = max(recovery_until, update_end)
+
+                        updated_spans = self._sequence_manager.make_sequence(block_idx, update_end)
+                        # make_sequence() could return a longer sequence
+                        updated_spans[-1].end = min(updated_spans[-1].end, update_end)
+                        updated_sessions = self._enter_server_sessions(updated_spans)
+                        logger.debug(
+                            f"Found path from block {block_idx} to {update_end} via {len(updated_spans)} servers"
+                        )
+
+                        # If there is a failed span, this code replaces it, otherwise it just adds new ones
+                        self._chosen_spans[server_idx : server_idx + 1] = updated_spans
+                        self._server_sessions[server_idx : server_idx + 1] = updated_sessions
+                        recovery_inputs = self._server_inputs[server_idx] if server_idx < n_prev_spans else None
+                        self._server_inputs[server_idx : server_idx + 1] = [recovery_inputs] + [None] * (
+                            len(updated_spans) - 1
+                        )
+                        assert len(self._chosen_spans) == len(self._server_sessions) == len(self._server_inputs), (
+                            f"Broken state: {len(self._chosen_spans)} spans, {len(self._server_sessions)} sessions, "
+                            f"{len(self._server_inputs)} inputs"
+                        )
+
+                    session = self._server_sessions[server_idx]
+                    span = self._chosen_spans[server_idx]
+
+                    if self._server_inputs[server_idx] is None:
+                        self._server_inputs[server_idx] = inputs
+                    elif self._server_inputs[server_idx].shape[1] == self._position:
+                        self._server_inputs[server_idx] = torch.cat(
+                            [self._server_inputs[server_idx], inputs[:, -n_input_tokens:]], dim=1
+                        )
+                    assert self._server_inputs[server_idx].shape[1] == self._position + n_input_tokens, (
+                        f"Broken input cache: server_idx={server_idx} shape={self._server_inputs[server_idx].shape} "
+                        f"position={self._position} n_input_tokens={n_input_tokens}"
+                    )
+
+                    if not session.stepped:
+                        inputs = self._server_inputs[server_idx]  # Pass full inputs including prefix
+                    else:
+                        inputs = inputs[:, -n_input_tokens:]  # No need to pass prefix further
+
+                    outputs = session.step(inputs, prompts[span.start : span.end], **kwargs)
+                    assert (
+                        inputs.shape == outputs.shape
+                    ), f"Shape mismatch: inputs.shape={inputs.shape}, outputs.shape={outputs.shape})"
+
+                    inputs = outputs
+                    server_idx += 1
+                    block_idx = span.end
+                    break
+                except Exception as e:
+                    delay = self._sequence_manager.get_retry_delay(attempt_no)
+                    logger.warning(
+                        f"Caught exception when running inference from block {block_idx} "
+                        f"(retry in {delay:.0f} sec): {repr(e)}"
+                    )
+                    logger.debug("See detailed traceback below:", exc_info=True)
+                    time.sleep(delay)
+
+        self._position += n_input_tokens
         return inputs
 
     def close(self, *exc_details):
         """Finish a given inference session, close the underlying connection"""
-        if not self.closed:
-            self.stack.__exit__(*exc_details or (None, None, None))
-            self.inference_sessions.clear()
-            self.closed = True
+        if not self._closed:
+            self._server_inputs.clear()
+            self._exit_server_sessions(self._server_sessions)
+            self._server_sessions.clear()
+            self._chosen_spans.clear()
+            self._closed = True
 
     def __exit__(self, *exc_details):
         self.close(*exc_details)

+ 3 - 3
src/client/remote_sequential.py

@@ -8,7 +8,7 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from torch import nn
 
 import src
-from src.client.inference_session import RemoteSequentialInferenceSession
+from src.client.inference_session import InferenceSession
 from src.client.sequence_manager import RemoteSequenceManager
 from src.client.sequential_autograd import _RemoteSequentialAutogradFunction
 from src.data_structures import UID_DELIMITER
@@ -80,9 +80,9 @@ class RemoteSequential(nn.Module):
     def __len__(self):
         return len(self.sequence_manager)
 
-    def inference_session(self, **kwargs) -> RemoteSequentialInferenceSession:
+    def inference_session(self, **kwargs) -> InferenceSession:
         self.sequence_manager.update_()
-        return RemoteSequentialInferenceSession(self.sequence_manager, self.p2p, **kwargs)
+        return InferenceSession(self.sequence_manager, self.p2p, **kwargs)
 
     def extra_repr(self) -> str:
         return f"modules={self.sequence_manager.block_uids[0]}..{self.sequence_manager.block_uids[-1]}"

+ 5 - 0
src/client/sequence_manager.py

@@ -160,3 +160,8 @@ class RemoteSequenceManager:
                     else:
                         logger.warning(f"Tried to call rpc_info, but caught {repr(e)}", exc_info=True)
         return self._rpc_info
+
+    def get_retry_delay(self, attempt_no: int) -> float:
+        if attempt_no == 0:
+            return 0
+        return self.min_backoff * 2 ** (attempt_no - 1)

+ 51 - 25
src/client/sequential_autograd.py

@@ -3,11 +3,12 @@ A PyTorch autograd function that runs forward/backward on a sequence of remote s
 """
 import asyncio
 import itertools
-import logging
+from collections import deque
 from typing import List, Optional, Sequence, Tuple
 
 import torch
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from hivemind.utils.logging import get_logger
 
 from src.client.remote_forward_backward import run_remote_backward, run_remote_forward
 from src.client.sequence_manager import RemoteSequenceManager
@@ -15,6 +16,8 @@ from src.data_structures import CHAIN_DELIMITER, RemoteSpanInfo
 from src.server.handler import TransformerConnectionHandler
 from src.utils.misc import DUMMY, is_dummy
 
+logger = get_logger(__file__)
+
 MAX_TOKENS_IN_BATCH = 1024
 
 
@@ -39,19 +42,30 @@ async def sequential_forward(
         sequence_manager.block_uids
     )  # should be n_layers - 1 but add extra prompts for convenience
 
-    sequences = sequence_manager.make_sequence(start_index, end_index)
+    sequences = deque()
     intermediate_inputs = []
     done_sequences = []
     outputs = inputs
 
-    while len(sequences) > 0:
+    block_idx = start_index
+    while block_idx < end_index:
         for attempt_no in itertools.count():
-            span = sequences.pop(0)
-            span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
+            logger.debug(f"Forward: block {block_idx}, attempt {attempt_no}")
             try:
+                if attempt_no >= 1:
+                    sequence_manager.update_()
+                if not sequences or attempt_no >= 1:
+                    sequences = deque(sequence_manager.make_sequence(block_idx, end_index))
+                    # make_sequence() could return a longer sequence
+                    sequences[-1].end = min(sequences[-1].end, end_index)
+                    logger.debug(f"Found path from block {block_idx} to {end_index} via {len(sequences)} servers")
+
+                span = sequences.popleft()
+
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
                 inputs_and_prompts = [inputs, prompts[span.start : span.end]]
 
+                span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                 (outputs,) = await run_remote_forward(
                     span_uids, stub, sequence_manager.rpc_info, *inputs_and_prompts, timeout=sequence_manager.timeout
                 )
@@ -64,14 +78,16 @@ async def sequential_forward(
                 done_sequences.append(span)
 
                 inputs = outputs
+                block_idx = span.end
                 break
             except Exception as e:
-                logging.warning(f"Caught {e} when running forward for chain {span.start}-{span.end}", exc_info=True)
-                await asyncio.sleep(sequence_manager.min_backoff * 2**attempt_no)
-
-                backup_sequences = sequence_manager.make_sequence(span.start)
-                assert backup_sequences[0].start == span.start
-                sequences = backup_sequences
+                delay = sequence_manager.get_retry_delay(attempt_no)
+                logger.warning(
+                    f"Caught exception when running forward from block {block_idx} "
+                    f"(retry in {delay:.0f} sec): {repr(e)}"
+                )
+                logger.debug("See detailed traceback below:", exc_info=True)
+                await asyncio.sleep(delay)
 
     return outputs, intermediate_inputs, done_sequences
 
@@ -91,11 +107,26 @@ async def sequential_backward(
 
     grad_prompts_reversed = []
     while len(forward_sequences) > 0 and len(intermediate_inputs) > 0:
+        inputs = intermediate_inputs.pop()
+        span = forward_sequences.pop()
         for attempt_no in itertools.count():
-            inputs = intermediate_inputs.pop(-1)
-            span = forward_sequences.pop(-1)
-            span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
+            logger.debug(f"Backward: block {span.end - 1}, attempt {attempt_no}")
             try:
+                if attempt_no >= 1:
+                    sequence_manager.update_()
+                    _, backup_inputs, backup_sequences = await sequential_forward(
+                        inputs, prompts, sequence_manager, start_index=span.start, end_index=span.end
+                    )
+                    assert len(backup_inputs) == len(backup_sequences)
+                    assert backup_sequences[0].start == span.start
+                    assert backup_sequences[-1].end == span.end
+
+                    intermediate_inputs.extend(backup_inputs)
+                    forward_sequences.extend(backup_sequences)
+                    inputs = intermediate_inputs.pop()
+                    span = forward_sequences.pop()
+
+                span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
                 grad_outputs, *span_grad_prompts = await run_remote_backward(
                     span_uids,
@@ -110,18 +141,13 @@ async def sequential_backward(
                 grad_prompts_reversed.extend(span_grad_prompts)
                 break
             except Exception as e:
-                logging.warning(f"Caught {e} when running backward for chain {span.start}-{span.end}", exc_info=True)
-                await asyncio.sleep(sequence_manager.min_backoff * 2**attempt_no)
-
-                _, backup_intermediate_inputs, backup_forward_sequences = await sequential_forward(
-                    inputs, prompts, sequence_manager, start_index=span.start, end_index=span.end
+                delay = sequence_manager.get_retry_delay(attempt_no)
+                logger.warning(
+                    f"Caught exception when running backward between blocks {span.start}-{span.end} "
+                    f"(retry in {delay:.0f} sec): {repr(e)}"
                 )
-                assert len(intermediate_inputs) == len(forward_sequences)
-                assert backup_forward_sequences[0].start == span.start
-                assert backup_forward_sequences[-1].end == span.end
-
-                forward_sequences.extend(backup_forward_sequences)
-                intermediate_inputs.extend(backup_intermediate_inputs)
+                logger.debug("See detailed traceback below:", exc_info=True)
+                await asyncio.sleep(delay)
 
     # For now, we do not support mixed dummy and grad prompts
     # Concat in num_layer dimension

+ 1 - 1
tests/test_block_exact_match.py

@@ -33,7 +33,7 @@ def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3):
                 outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
 
             # test that max length is respected
-            with pytest.raises(P2PHandlerError) as exc_info:
+            with pytest.raises(ValueError, match=r"Maximum length exceeded") as exc_info:
                 sess.step(inputs[:, -1:, :])
             assert "Maximum length exceeded" in repr(exc_info.value)