Aleksandr Borzunov 2 vuotta sitten
vanhempi
commit
8c50f65cf2
1 muutettua tiedostoa jossa 12 lisäystä ja 7 poistoa
  1. 12 7
      src/client/inference_session.py

+ 12 - 7
src/client/inference_session.py

@@ -258,10 +258,13 @@ class InferenceSession:
                         self._chosen_spans[server_idx : server_idx + 1] = updated_spans
                         self._chosen_spans[server_idx : server_idx + 1] = updated_spans
                         self._server_sessions[server_idx : server_idx + 1] = updated_sessions
                         self._server_sessions[server_idx : server_idx + 1] = updated_sessions
                         recovery_inputs = self._server_inputs[server_idx] if server_idx < n_prev_spans else None
                         recovery_inputs = self._server_inputs[server_idx] if server_idx < n_prev_spans else None
-                        self._server_inputs[server_idx : server_idx + 1] = [recovery_inputs] + [None] * (len(updated_spans) - 1)
-                        assert len(self._chosen_spans) == len(self._server_sessions) == len(self._server_inputs), \
-                            f"Broken state: {len(self._chosen_spans)} spans, {len(self._server_sessions)} sessions, " \
+                        self._server_inputs[server_idx : server_idx + 1] = [recovery_inputs] + [None] * (
+                            len(updated_spans) - 1
+                        )
+                        assert len(self._chosen_spans) == len(self._server_sessions) == len(self._server_inputs), (
+                            f"Broken state: {len(self._chosen_spans)} spans, {len(self._server_sessions)} sessions, "
                             f"{len(self._server_inputs)} inputs"
                             f"{len(self._server_inputs)} inputs"
+                        )
 
 
                     session = self._server_sessions[server_idx]
                     session = self._server_sessions[server_idx]
                     span = self._chosen_spans[server_idx]
                     span = self._chosen_spans[server_idx]
@@ -272,9 +275,10 @@ class InferenceSession:
                         self._server_inputs[server_idx] = torch.cat(
                         self._server_inputs[server_idx] = torch.cat(
                             [self._server_inputs[server_idx], inputs[:, -n_input_tokens:]], dim=1
                             [self._server_inputs[server_idx], inputs[:, -n_input_tokens:]], dim=1
                         )
                         )
-                    assert self._server_inputs[server_idx].shape[1] == self._position + n_input_tokens, \
-                        f"Broken input cache: server_idx={server_idx} shape={self._server_inputs[server_idx].shape} " \
+                    assert self._server_inputs[server_idx].shape[1] == self._position + n_input_tokens, (
+                        f"Broken input cache: server_idx={server_idx} shape={self._server_inputs[server_idx].shape} "
                         f"position={self._position} n_input_tokens={n_input_tokens}"
                         f"position={self._position} n_input_tokens={n_input_tokens}"
+                    )
 
 
                     if not session.stepped:
                     if not session.stepped:
                         inputs = self._server_inputs[server_idx]  # Pass full inputs including prefix
                         inputs = self._server_inputs[server_idx]  # Pass full inputs including prefix
@@ -282,8 +286,9 @@ class InferenceSession:
                         inputs = inputs[:, -n_input_tokens:]  # No need to pass prefix further
                         inputs = inputs[:, -n_input_tokens:]  # No need to pass prefix further
 
 
                     outputs = session.step(inputs, prompts[span.start : span.end], **kwargs)
                     outputs = session.step(inputs, prompts[span.start : span.end], **kwargs)
-                    assert inputs.shape == outputs.shape, \
-                        f"Shape mismatch: inputs.shape={inputs.shape}, outputs.shape={outputs.shape})"
+                    assert (
+                        inputs.shape == outputs.shape
+                    ), f"Shape mismatch: inputs.shape={inputs.shape}, outputs.shape={outputs.shape})"
 
 
                     inputs = outputs
                     inputs = outputs
                     server_idx += 1
                     server_idx += 1