Browse Source

Fix InferenceSession edge cases

Aleksandr Borzunov 2 years ago
parent
commit
292e359731
1 changed file with 42 additions and 30 deletions

src/client/inference_session.py (+42, −30)

@@ -175,22 +175,26 @@ class InferenceSession:
 
 
     def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[_ServerInferenceSession]:
         server_sessions = []
-        for span in chosen_spans:
-            stub = TransformerConnectionHandler.get_stub(self._p2p, span.peer_id)
-            span_uids = CHAIN_DELIMITER.join(self._sequence_manager.block_uids[span.start : span.end])
-            session = RemoteExpertWorker.run_coroutine(
-                _ServerInferenceSession.create(
-                    stub,
-                    span_uids,
-                    rpc_info=self._sequence_manager.rpc_info,
-                    timeout=self._sequence_manager.timeout,
-                    max_length=self._max_length,
-                    **self._metadata,
+        try:
+            for span in chosen_spans:
+                stub = TransformerConnectionHandler.get_stub(self._p2p, span.peer_id)
+                span_uids = CHAIN_DELIMITER.join(self._sequence_manager.block_uids[span.start : span.end])
+                session = RemoteExpertWorker.run_coroutine(
+                    _ServerInferenceSession.create(
+                        stub,
+                        span_uids,
+                        rpc_info=self._sequence_manager.rpc_info,
+                        timeout=self._sequence_manager.timeout,
+                        max_length=self._max_length,
+                        **self._metadata,
+                    )
                 )
-            )
-            server_sessions.append(session)
-            session.__enter__()
-        return server_sessions
+                server_sessions.append(session)
+                session.__enter__()
+            return server_sessions
+        except:
+            self._exit_server_sessions(server_sessions)
+            raise

     def _exit_server_sessions(self, server_sessions: List[_ServerInferenceSession]) -> None:
         for session in reversed(server_sessions):
@@ -230,9 +234,12 @@ class InferenceSession:
                     if attempt_no >= 1:
                         self._sequence_manager.update_()
                     if not self._chosen_spans or not self._server_sessions or attempt_no >= 1:
-                        n_spans = len(self._chosen_spans)
-                        update_end = self._chosen_spans[server_idx].end if server_idx < n_spans else n_blocks
-                        if attempt_no == 1 and update_end > recovery_until:
+                        # If there is a failed server session, this code closes it
+                        self._exit_server_sessions(self._server_sessions[server_idx : server_idx + 1])
+
+                        n_prev_spans = len(self._chosen_spans)
+                        update_end = self._chosen_spans[server_idx].end if server_idx < n_prev_spans else n_blocks
+                        if attempt_no >= 1 and update_end > recovery_until:
                             logger.info(
                                 f"Due to a server failure, remote attention caches "
                                 f"from block {block_idx} to {update_end} will be regenerated"
@@ -242,18 +249,20 @@ class InferenceSession:
                         updated_spans = self._sequence_manager.make_sequence(block_idx, update_end)
                         # make_sequence() could return a longer sequence
                         updated_spans[-1].end = min(updated_spans[-1].end, update_end)
-
-                        # The code below works even if server `server_idx` is not added yet
-                        self._exit_server_sessions(self._server_sessions[server_idx : server_idx + 1])
-                        self._chosen_spans[server_idx : server_idx + 1] = updated_spans
-                        self._server_sessions[server_idx : server_idx + 1] = self._enter_server_sessions(updated_spans)
-                        if server_idx >= len(self._server_inputs):
-                            self._server_inputs.append(None)
-                        self._server_inputs[server_idx + 1 : server_idx + 1] = [None] * (len(updated_spans) - 1)
+                        updated_sessions = self._enter_server_sessions(updated_spans)
                         logger.debug(
                             f"Found path from block {block_idx} to {update_end} via {len(updated_spans)} servers"
                         )

+                        # If there is a failed span, this code replaces it, otherwise it just adds new ones
+                        self._chosen_spans[server_idx : server_idx + 1] = updated_spans
+                        self._server_sessions[server_idx : server_idx + 1] = updated_sessions
+                        recovery_inputs = self._server_inputs[server_idx] if server_idx < n_prev_spans else None
+                        self._server_inputs[server_idx : server_idx + 1] = [recovery_inputs] + [None] * (len(updated_spans) - 1)
+                        assert len(self._chosen_spans) == len(self._server_sessions) == len(self._server_inputs), \
+                            f"Broken state: {len(self._chosen_spans)} spans, {len(self._server_sessions)} sessions, " \
+                            f"{len(self._server_inputs)} inputs"
+
                     session = self._server_sessions[server_idx]
                     span = self._chosen_spans[server_idx]
 
 
@@ -263,15 +272,18 @@ class InferenceSession:
                         self._server_inputs[server_idx] = torch.cat(
                             [self._server_inputs[server_idx], inputs[:, -n_input_tokens:]], dim=1
                         )
-                    assert self._server_inputs[server_idx].shape[1] == self._position + n_input_tokens
+                    assert self._server_inputs[server_idx].shape[1] == self._position + n_input_tokens, \
+                        f"Broken input cache: server_idx={server_idx} shape={self._server_inputs[server_idx].shape} " \
+                        f"position={self._position} n_input_tokens={n_input_tokens}"

-                    if block_idx < recovery_until:
+                    if not session.stepped:
                         inputs = self._server_inputs[server_idx]  # Pass full inputs including prefix
-                    elif block_idx == recovery_until:
+                    else:
                         inputs = inputs[:, -n_input_tokens:]  # No need to pass prefix further

                     outputs = session.step(inputs, prompts[span.start : span.end], **kwargs)
-                    assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
+                    assert inputs.shape == outputs.shape, \
+                        f"Shape mismatch: inputs.shape={inputs.shape}, outputs.shape={outputs.shape})"

                     inputs = outputs
                     server_idx += 1