Selaa lähdekoodia

InferenceSession: Replace only a segment of spans instead of everything
until the end

Aleksandr Borzunov 2 vuotta sitten
vanhempi
commit
3bc06f0002
1 muutettua tiedostoa jossa 34 lisäystä ja 17 poistoa
  1. 34 17
      src/client/inference_session.py

+ 34 - 17
src/client/inference_session.py

@@ -212,39 +212,56 @@ class InferenceSession:
 
         server_idx = 0
         block_idx = 0
-        recovery_mode = False
+        recovery_end = None  # Recovery mode is disabled until a failure happens
         while block_idx < len(self._sequence_manager):
             for attempt_no in itertools.count():
                 logger.debug(f"Inference: block {block_idx}, attempt {attempt_no}")
                 try:
                     if attempt_no >= 1:
-                        self._exit_server_sessions(self._server_sessions[server_idx:])
-                        self._server_sessions[server_idx:] = []
-                        self._chosen_spans[server_idx:] = []
-                        self._server_inputs[server_idx + 1 :] = []
-
                         self._sequence_manager.update_()
-                        recovery_mode = True
+                        if server_idx < len(self._chosen_spans):
+                            recovery_end = self._chosen_spans[server_idx].end
+                        else:
+                            recovery_end = len(self._sequence_manager)
                         if attempt_no == 1:
-                            logger.info("Entering recovery mode, remote attention caches will be regenerated")
+                            logger.info(
+                                f"Entering recovery mode, remote attention caches "
+                                f"from block {block_idx} to {recovery_end} will be regenerated"
+                            )
 
                     if not self._chosen_spans or not self._server_sessions or attempt_no >= 1:
-                        backup_spans = self._sequence_manager.make_sequence(block_idx)
-                        self._chosen_spans.extend(backup_spans)
-                        self._server_sessions.extend(self._enter_server_sessions(backup_spans))
-                        logger.debug(f"Found path from block {block_idx} via {len(backup_spans)} servers")
+                        backup_spans = self._sequence_manager.make_sequence(block_idx, recovery_end)
+                        if recovery_end is not None:
+                            # make_sequence() could return a longer sequence
+                            backup_spans[-1].end = min(backup_spans[-1].end, recovery_end)
+
+                        # The code below works even if server `server_idx` is not added yet
+                        self._exit_server_sessions(self._server_sessions[server_idx : server_idx + 1])
+                        self._chosen_spans[server_idx : server_idx + 1] = backup_spans
+                        self._server_sessions[server_idx : server_idx + 1] = self._enter_server_sessions(backup_spans)
+                        if server_idx >= len(self._server_inputs):
+                            self._server_inputs.append(None)
+                        self._server_inputs[server_idx + 1 : server_idx + 1] = [None] * (len(backup_spans) - 1)
+                        logger.debug(
+                            f"Found path from block {block_idx} to {recovery_end} via {len(backup_spans)} servers"
+                        )
 
                     session = self._server_sessions[server_idx]
                     span = self._chosen_spans[server_idx]
 
-                    if server_idx == len(self._server_inputs):
-                        self._server_inputs.append(inputs)
+                    if self._server_inputs[server_idx] is None:
+                        self._server_inputs[server_idx] = inputs
                     elif self._server_inputs[server_idx].shape[1] == self._position:
-                        self._server_inputs[server_idx] = torch.cat([self._server_inputs[server_idx], inputs], dim=1)
+                        self._server_inputs[server_idx] = torch.cat(
+                            [self._server_inputs[server_idx], inputs[:, -n_input_tokens:]], dim=1
+                        )
                     assert self._server_inputs[server_idx].shape[1] == self._position + n_input_tokens
 
-                    if recovery_mode:
-                        inputs = self._server_inputs[server_idx]  # Take full inputs including prefix
+                    if recovery_end is not None and block_idx < recovery_end:
+                        inputs = self._server_inputs[server_idx]  # Pass full inputs including prefix
+                    elif block_idx == recovery_end:
+                        inputs = inputs[:, -n_input_tokens:]  # No need to pass prefix further
+
                     outputs = session.step(inputs, prompts[span.start : span.end], **kwargs)
                     assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"