Browse Source

InferenceSession: Fix the case when a failure happens while recovering
from another failure

Aleksandr Borzunov 2 years ago
parent
commit
226fe91f6f
1 changed file with 22 additions and 21 deletions
  1. 22 21
      src/client/inference_session.py

+ 22 - 21
src/client/inference_session.py

@@ -207,10 +207,13 @@ class InferenceSession:
         assert not self._closed
         if torch.is_grad_enabled():
             logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
+
+        n_blocks = len(self._sequence_manager)
         if prompts is None or is_dummy(prompts):
             prompts = DUMMY
         else:
-            assert prompts.ndim == 4 and prompts.shape[0] == len(self._sequence_manager)
+            assert prompts.ndim == 4 and prompts.shape[0] == n_blocks
+
         n_input_tokens = inputs.shape[1]
         if self._position + n_input_tokens > self._max_length:
             raise ValueError(
@@ -219,38 +222,36 @@ class InferenceSession:
 
         server_idx = 0
         block_idx = 0
-        recovery_end = None  # Recovery mode is disabled until a failure happens
-        while block_idx < len(self._sequence_manager):
+        recovery_until = -1  # Recovery mode is disabled until a failure happens
+        while block_idx < n_blocks:
             for attempt_no in itertools.count():
                 logger.debug(f"Inference: block {block_idx}, attempt {attempt_no}")
                 try:
                     if attempt_no >= 1:
                         self._sequence_manager.update_()
-                        if server_idx < len(self._chosen_spans):
-                            recovery_end = self._chosen_spans[server_idx].end
-                        else:
-                            recovery_end = len(self._sequence_manager)
-                        if attempt_no == 1:
+                    if not self._chosen_spans or not self._server_sessions or attempt_no >= 1:
+                        n_spans = len(self._chosen_spans)
+                        update_end = self._chosen_spans[server_idx].end if server_idx < n_spans else n_blocks
+                        if attempt_no == 1 and update_end > recovery_until:
                             logger.info(
-                                f"Entering recovery mode, remote attention caches "
-                                f"from block {block_idx} to {recovery_end} will be regenerated"
+                                f"Due to a server failure, remote attention caches "
+                                f"from block {block_idx} to {update_end} will be regenerated"
                             )
+                        recovery_until = max(recovery_until, update_end)
 
-                    if not self._chosen_spans or not self._server_sessions or attempt_no >= 1:
-                        backup_spans = self._sequence_manager.make_sequence(block_idx, recovery_end)
-                        if recovery_end is not None:
-                            # make_sequence() could return a longer sequence
-                            backup_spans[-1].end = min(backup_spans[-1].end, recovery_end)
+                        updated_spans = self._sequence_manager.make_sequence(block_idx, update_end)
+                        # make_sequence() could return a longer sequence
+                        updated_spans[-1].end = min(updated_spans[-1].end, update_end)
 
                         # The code below works even if server `server_idx` is not added yet
                         self._exit_server_sessions(self._server_sessions[server_idx : server_idx + 1])
-                        self._chosen_spans[server_idx : server_idx + 1] = backup_spans
-                        self._server_sessions[server_idx : server_idx + 1] = self._enter_server_sessions(backup_spans)
+                        self._chosen_spans[server_idx : server_idx + 1] = updated_spans
+                        self._server_sessions[server_idx : server_idx + 1] = self._enter_server_sessions(updated_spans)
                         if server_idx >= len(self._server_inputs):
                             self._server_inputs.append(None)
-                        self._server_inputs[server_idx + 1 : server_idx + 1] = [None] * (len(backup_spans) - 1)
+                        self._server_inputs[server_idx + 1 : server_idx + 1] = [None] * (len(updated_spans) - 1)
                         logger.debug(
-                            f"Found path from block {block_idx} to {recovery_end} via {len(backup_spans)} servers"
+                            f"Found path from block {block_idx} to {update_end} via {len(updated_spans)} servers"
                         )
 
                     session = self._server_sessions[server_idx]
@@ -264,9 +265,9 @@ class InferenceSession:
                         )
                     assert self._server_inputs[server_idx].shape[1] == self._position + n_input_tokens
 
-                    if recovery_end is not None and block_idx < recovery_end:
+                    if block_idx < recovery_until:
                         inputs = self._server_inputs[server_idx]  # Pass full inputs including prefix
-                    elif block_idx == recovery_end:
+                    elif block_idx == recovery_until:
                         inputs = inputs[:, -n_input_tokens:]  # No need to pass prefix further
 
                     outputs = session.step(inputs, prompts[span.start : span.end], **kwargs)