@@ -258,10 +258,13 @@ class InferenceSession:
         self._chosen_spans[server_idx : server_idx + 1] = updated_spans
         self._server_sessions[server_idx : server_idx + 1] = updated_sessions
         recovery_inputs = self._server_inputs[server_idx] if server_idx < n_prev_spans else None
-        self._server_inputs[server_idx : server_idx + 1] = [recovery_inputs] + [None] * (len(updated_spans) - 1)
-        assert len(self._chosen_spans) == len(self._server_sessions) == len(self._server_inputs), \
-            f"Broken state: {len(self._chosen_spans)} spans, {len(self._server_sessions)} sessions, " \
+        self._server_inputs[server_idx : server_idx + 1] = [recovery_inputs] + [None] * (
+            len(updated_spans) - 1
+        )
+        assert len(self._chosen_spans) == len(self._server_sessions) == len(self._server_inputs), (
+            f"Broken state: {len(self._chosen_spans)} spans, {len(self._server_sessions)} sessions, "
             f"{len(self._server_inputs)} inputs"
+        )
 
         session = self._server_sessions[server_idx]
         span = self._chosen_spans[server_idx]
@@ -272,9 +275,10 @@ class InferenceSession:
             self._server_inputs[server_idx] = torch.cat(
                 [self._server_inputs[server_idx], inputs[:, -n_input_tokens:]], dim=1
             )
-        assert self._server_inputs[server_idx].shape[1] == self._position + n_input_tokens, \
-            f"Broken input cache: server_idx={server_idx} shape={self._server_inputs[server_idx].shape} " \
+        assert self._server_inputs[server_idx].shape[1] == self._position + n_input_tokens, (
+            f"Broken input cache: server_idx={server_idx} shape={self._server_inputs[server_idx].shape} "
             f"position={self._position} n_input_tokens={n_input_tokens}"
+        )
 
         if not session.stepped:
             inputs = self._server_inputs[server_idx]  # Pass full inputs including prefix
@@ -282,8 +286,9 @@ class InferenceSession:
             inputs = inputs[:, -n_input_tokens:]  # No need to pass prefix further
 
         outputs = session.step(inputs, prompts[span.start : span.end], **kwargs)
-        assert inputs.shape == outputs.shape, \
-            f"Shape mismatch: inputs.shape={inputs.shape}, outputs.shape={outputs.shape})"
+        assert (
+            inputs.shape == outputs.shape
+        ), f"Shape mismatch: inputs.shape={inputs.shape}, outputs.shape={outputs.shape}"
 
         inputs = outputs
         server_idx += 1
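
Note on the recurring change: each hunk above replaces a backslash-continued assert message with the parenthesized form. A minimal standalone sketch of the two styles, using illustrative variable names that do not come from the diff:

    xs = [1, 2, 3]
    ys = [4, 5, 6]

    # Old style: backslash continuations. Fragile, since any trailing
    # whitespace after a "\" is a SyntaxError.
    assert len(xs) == len(ys), \
        f"length mismatch: {len(xs)} " \
        f"vs {len(ys)}"

    # New style: parentheses wrap only the message; adjacent f-string
    # literals concatenate implicitly, so no backslashes are needed.
    assert len(xs) == len(ys), (
        f"length mismatch: {len(xs)} "
        f"vs {len(ys)}"
    )

    # Caution: the parentheses must wrap the message only, never the whole
    # statement. `assert (cond, "msg")` tests a two-element tuple, which is
    # always truthy, so such an assert could never fail.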
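
The assertion in the second hunk encodes the input-cache invariant: after the current chunk is concatenated onto the cached prefix, the cached sequence length must equal the previous position plus the number of new tokens. A toy check of the same invariant, with illustrative shapes:

    import torch

    position = 5        # tokens already processed by this server
    n_input_tokens = 3  # tokens in the current chunk
    hidden = 16

    cache = torch.zeros(1, position, hidden)        # (batch, seq, hidden) prefix
    chunk = torch.zeros(1, n_input_tokens, hidden)  # newly arrived tokens

    cache = torch.cat([cache, chunk], dim=1)
    assert cache.shape[1] == position + n_input_tokens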