|
@@ -212,39 +212,56 @@ class InferenceSession:
|
|
|
|
|
|
server_idx = 0
|
|
|
block_idx = 0
|
|
|
- recovery_mode = False
|
|
|
+ recovery_end = None # Recovery mode is disabled until a failure happens
|
|
|
while block_idx < len(self._sequence_manager):
|
|
|
for attempt_no in itertools.count():
|
|
|
logger.debug(f"Inference: block {block_idx}, attempt {attempt_no}")
|
|
|
try:
|
|
|
if attempt_no >= 1:
|
|
|
- self._exit_server_sessions(self._server_sessions[server_idx:])
|
|
|
- self._server_sessions[server_idx:] = []
|
|
|
- self._chosen_spans[server_idx:] = []
|
|
|
- self._server_inputs[server_idx + 1 :] = []
|
|
|
-
|
|
|
self._sequence_manager.update_()
|
|
|
- recovery_mode = True
|
|
|
+ if server_idx < len(self._chosen_spans):
|
|
|
+ recovery_end = self._chosen_spans[server_idx].end
|
|
|
+ else:
|
|
|
+ recovery_end = len(self._sequence_manager)
|
|
|
if attempt_no == 1:
|
|
|
- logger.info("Entering recovery mode, remote attention caches will be regenerated")
|
|
|
+ logger.info(
|
|
|
+ f"Entering recovery mode, remote attention caches "
|
|
|
+ f"from block {block_idx} to {recovery_end} will be regenerated"
|
|
|
+ )
|
|
|
|
|
|
if not self._chosen_spans or not self._server_sessions or attempt_no >= 1:
|
|
|
- backup_spans = self._sequence_manager.make_sequence(block_idx)
|
|
|
- self._chosen_spans.extend(backup_spans)
|
|
|
- self._server_sessions.extend(self._enter_server_sessions(backup_spans))
|
|
|
- logger.debug(f"Found path from block {block_idx} via {len(backup_spans)} servers")
|
|
|
+ backup_spans = self._sequence_manager.make_sequence(block_idx, recovery_end)
|
|
|
+ if recovery_end is not None:
|
|
|
+ # make_sequence() could return a longer sequence
|
|
|
+ backup_spans[-1].end = min(backup_spans[-1].end, recovery_end)
|
|
|
+
|
|
|
+ # The code below works even if server `server_idx` is not added yet
|
|
|
+ self._exit_server_sessions(self._server_sessions[server_idx : server_idx + 1])
|
|
|
+ self._chosen_spans[server_idx : server_idx + 1] = backup_spans
|
|
|
+ self._server_sessions[server_idx : server_idx + 1] = self._enter_server_sessions(backup_spans)
|
|
|
+ if server_idx >= len(self._server_inputs):
|
|
|
+ self._server_inputs.append(None)
|
|
|
+ self._server_inputs[server_idx + 1 : server_idx + 1] = [None] * (len(backup_spans) - 1)
|
|
|
+ logger.debug(
|
|
|
+ f"Found path from block {block_idx} to {recovery_end} via {len(backup_spans)} servers"
|
|
|
+ )
|
|
|
|
|
|
session = self._server_sessions[server_idx]
|
|
|
span = self._chosen_spans[server_idx]
|
|
|
|
|
|
- if server_idx == len(self._server_inputs):
|
|
|
- self._server_inputs.append(inputs)
|
|
|
+ if self._server_inputs[server_idx] is None:
|
|
|
+ self._server_inputs[server_idx] = inputs
|
|
|
elif self._server_inputs[server_idx].shape[1] == self._position:
|
|
|
- self._server_inputs[server_idx] = torch.cat([self._server_inputs[server_idx], inputs], dim=1)
|
|
|
+ self._server_inputs[server_idx] = torch.cat(
|
|
|
+ [self._server_inputs[server_idx], inputs[:, -n_input_tokens:]], dim=1
|
|
|
+ )
|
|
|
assert self._server_inputs[server_idx].shape[1] == self._position + n_input_tokens
|
|
|
|
|
|
- if recovery_mode:
|
|
|
- inputs = self._server_inputs[server_idx] # Take full inputs including prefix
|
|
|
+ if recovery_end is not None and block_idx < recovery_end:
|
|
|
+ inputs = self._server_inputs[server_idx] # Pass full inputs including prefix
|
|
|
+ elif block_idx == recovery_end:
|
|
|
+ inputs = inputs[:, -n_input_tokens:] # No need to pass prefix further
|
|
|
+
|
|
|
outputs = session.step(inputs, prompts[span.start : span.end], **kwargs)
|
|
|
assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
|
|
|
|