Parcourir la source

Regenerate attn caches when necessary

Aleksandr Borzunov il y a 2 ans
Parent
commit
55bea823c0
1 fichier modifié avec 27 ajouts et 5 suppressions
  1. 27 5
      src/client/inference_session.py
      src/client/inference_session.py

+ 27 - 5
src/client/inference_session.py

@@ -172,6 +172,8 @@ class RemoteSequentialInferenceSession:
         self.closed = False
         self.chosen_spans = []
         self.server_sessions = []
+        self.server_inputs = []  # Used in case of server failures to regenerate attention caches on new servers
+        self.position = 0
         self.metadata = metadata
 
     def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[RemoteServerInferenceSession]:
@@ -209,19 +211,27 @@ class RemoteSequentialInferenceSession:
             prompts = DUMMY
         else:
             assert prompts.ndim == 4 and prompts.shape[0] == len(self.sequence_manager)
+        n_input_tokens = inputs.shape[1]
 
         server_idx = 0
         block_idx = 0
+        recovery_mode = False
         while block_idx < len(self.sequence_manager):
             for attempt_no in itertools.count():
                 logger.debug(f"Inference: block {block_idx}, attempt {attempt_no}")
                 try:
                     if not self.chosen_spans or not self.server_sessions or attempt_no >= 1:
-                        self._exit_server_sessions(self.server_sessions[server_idx:], verbose=False)
-                        self.server_sessions[server_idx:] = []
-                        self.chosen_spans[server_idx:] = []
+                        if attempt_no >= 1:
+                            self._exit_server_sessions(self.server_sessions[server_idx:], verbose=False)
+                            self.server_sessions[server_idx:] = []
+                            self.chosen_spans[server_idx:] = []
+                            self.server_inputs[server_idx + 1:] = []
+
+                            self.sequence_manager.update_()
+                            recovery_mode = True
+                            if attempt_no == 1:
+                                logger.info("Entering recovery mode, remote attention caches will be regenerated")
 
-                        self.sequence_manager.update_()
                         backup_spans = self.sequence_manager.make_sequence(block_idx)
                         self.chosen_spans.extend(backup_spans)
                         self.server_sessions.extend(self._enter_server_sessions(backup_spans))
@@ -230,6 +240,14 @@ class RemoteSequentialInferenceSession:
                     session = self.server_sessions[server_idx]
                     span = self.chosen_spans[server_idx]
 
+                    if server_idx == len(self.server_inputs):
+                        self.server_inputs.append(inputs)
+                    elif self.server_inputs[server_idx].shape[1] == self.position:
+                        self.server_inputs[server_idx] = torch.cat([self.server_inputs[server_idx], inputs], dim=1)
+                    assert self.server_inputs[server_idx].shape[1] == self.position + n_input_tokens
+
+                    if recovery_mode:
+                        inputs = self.server_inputs[server_idx]  # Take full inputs including prefix
                     outputs = session.step(inputs, prompts[span.start : span.end], **kwargs)
                     assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
                     inputs = outputs
@@ -241,17 +259,21 @@ class RemoteSequentialInferenceSession:
                     delay = self.sequence_manager.min_backoff * 2**attempt_no
                     logger.warning(
                         f"Caught exception when running inference from block {block_idx} "
-                        f"(retry in {delay:.2f} sec): {repr(e)}"
+                        f"(retry in {delay:.1f} sec): {repr(e)}"
                     )
                     logger.debug("See detailed traceback below:", exc_info=True)
                     time.sleep(delay)
+
+        self.position += n_input_tokens
         return inputs
 
     def close(self, *exc_details):
         """Finish a given inference session, close the underlying connection"""
         if not self.closed:
+            self.server_inputs.clear()
             self._exit_server_sessions(self.server_sessions, verbose=True)
             self.server_sessions.clear()
+            self.chosen_spans.clear()
             self.closed = True
 
     def __exit__(self, *exc_details):