Ver código fonte

Make forward more fault-tolerant

Aleksandr Borzunov 2 anos atrás
pai
commit
87fd00ead9
2 arquivos alterados com 37 adições e 25 exclusões
  1. + 11 - 13
      src/client/inference_session.py
  2. + 26 - 12
      src/client/sequential_autograd.py

+ 11 - 13
src/client/inference_session.py

@@ -14,7 +14,6 @@ from hivemind import (
     get_logger,
     nested_flatten,
     serialize_torch_tensor,
-    use_hivemind_log_handler,
 )
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.p2p import StubBase
@@ -26,7 +25,6 @@ from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCI
 from src.server.handler import TransformerConnectionHandler
 from src.utils.misc import DUMMY, is_dummy
 
-use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 
@@ -219,18 +217,18 @@ class InferenceSession:
             for attempt_no in itertools.count():
                 logger.debug(f"Inference: block {block_idx}, attempt {attempt_no}")
                 try:
-                    if not self._chosen_spans or not self._server_sessions or attempt_no >= 1:
-                        if attempt_no >= 1:
-                            self._exit_server_sessions(self._server_sessions[server_idx:])
-                            self._server_sessions[server_idx:] = []
-                            self._chosen_spans[server_idx:] = []
-                            self._server_inputs[server_idx + 1 :] = []
+                    if attempt_no >= 1:
+                        self._exit_server_sessions(self._server_sessions[server_idx:])
+                        self._server_sessions[server_idx:] = []
+                        self._chosen_spans[server_idx:] = []
+                        self._server_inputs[server_idx + 1 :] = []
 
-                            self._sequence_manager.update_()
-                            recovery_mode = True
-                            if attempt_no == 1:
-                                logger.info("Entering recovery mode, remote attention caches will be regenerated")
+                        self._sequence_manager.update_()
+                        recovery_mode = True
+                        if attempt_no == 1:
+                            logger.info("Entering recovery mode, remote attention caches will be regenerated")
 
+                    if not self._chosen_spans or not self._server_sessions or attempt_no >= 1:
                         backup_spans = self._sequence_manager.make_sequence(block_idx)
                         self._chosen_spans.extend(backup_spans)
                         self._server_sessions.extend(self._enter_server_sessions(backup_spans))
@@ -249,8 +247,8 @@ class InferenceSession:
                         inputs = self._server_inputs[server_idx]  # Take full inputs including prefix
                     outputs = session.step(inputs, prompts[span.start : span.end], **kwargs)
                     assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
-                    inputs = outputs
 
+                    inputs = outputs
                     server_idx += 1
                     block_idx = span.end
                     break

+ 26 - 12
src/client/sequential_autograd.py

@@ -3,11 +3,12 @@ A PyTorch autograd function that runs forward/backward on a sequence of remote s
 """
 import asyncio
 import itertools
-import logging
+from collections import deque
 from typing import List, Optional, Sequence, Tuple
 
 import torch
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from hivemind.utils.logging import get_logger
 
 from src.client.remote_forward_backward import run_remote_backward, run_remote_forward
 from src.client.sequence_manager import RemoteSequenceManager
@@ -15,6 +16,8 @@ from src.data_structures import CHAIN_DELIMITER, RemoteSpanInfo
 from src.server.handler import TransformerConnectionHandler
 from src.utils.misc import DUMMY, is_dummy
 
+logger = get_logger(__file__)
+
 MAX_TOKENS_IN_BATCH = 1024
 
 
@@ -39,16 +42,25 @@ async def sequential_forward(
         sequence_manager.block_uids
     )  # should be n_layers - 1 but add extra prompts for convenience
 
-    sequences = sequence_manager.make_sequence(start_index, end_index)
+    sequences = deque()
     intermediate_inputs = []
     done_sequences = []
     outputs = inputs
 
-    while len(sequences) > 0:
+    block_idx = start_index
+    while block_idx < len(sequence_manager):
         for attempt_no in itertools.count():
-            span = sequences.pop(0)
-            span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
+            logger.debug(f"Forward: block {block_idx}, attempt {attempt_no}")
             try:
+                if attempt_no >= 1:
+                    sequence_manager.update_()
+                if not sequences or attempt_no >= 1:
+                    sequences = deque(sequence_manager.make_sequence(block_idx, end_index))
+                    logger.debug(f"Found path from block {block_idx} via {len(sequences)} servers")
+
+                span = sequences.popleft()
+                span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
+
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
                 inputs_and_prompts = [inputs, prompts[span.start : span.end]]
 
@@ -64,14 +76,16 @@ async def sequential_forward(
                 done_sequences.append(span)
 
                 inputs = outputs
+                block_idx = span.end
                 break
             except Exception as e:
-                logging.warning(f"Caught {e} when running forward for chain {span.start}-{span.end}", exc_info=True)
-                await asyncio.sleep(sequence_manager.min_backoff * 2**attempt_no)
-
-                backup_sequences = sequence_manager.make_sequence(span.start)
-                assert backup_sequences[0].start == span.start
-                sequences = backup_sequences
+                delay = sequence_manager.min_backoff * 2**attempt_no
+                logger.warning(
+                    f"Caught exception when running forward from block {block_idx} "
+                    f"(retry in {delay:.0f} sec): {repr(e)}"
+                )
+                logger.debug("See detailed traceback below:", exc_info=True)
+                await asyncio.sleep(delay)
 
     return outputs, intermediate_inputs, done_sequences
 
@@ -110,7 +124,7 @@ async def sequential_backward(
                 grad_prompts_reversed.extend(span_grad_prompts)
                 break
             except Exception as e:
-                logging.warning(f"Caught {e} when running backward for chain {span.start}-{span.end}", exc_info=True)
+                logger.warning(f"Caught {e} when running backward for chain {span.start}-{span.end}", exc_info=True)
                 await asyncio.sleep(sequence_manager.min_backoff * 2**attempt_no)
 
                 _, backup_intermediate_inputs, backup_forward_sequences = await sequential_forward(