@@ -207,10 +207,13 @@ class InferenceSession:
         assert not self._closed
         if torch.is_grad_enabled():
             logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
+
+        n_blocks = len(self._sequence_manager)
         if prompts is None or is_dummy(prompts):
             prompts = DUMMY
         else:
-            assert prompts.ndim == 4 and prompts.shape[0] == len(self._sequence_manager)
+            assert prompts.ndim == 4 and prompts.shape[0] == n_blocks
+
         n_input_tokens = inputs.shape[1]
         if self._position + n_input_tokens > self._max_length:
             raise ValueError(
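A minimal sketch of the shape contract enforced by the hoisted `n_blocks` assertion above. The patch only guarantees `ndim == 4` and `shape[0] == n_blocks`; the remaining dimension names and all sizes below are illustrative assumptions, not from the patch:

    import torch

    n_blocks, batch_size, prefix_len, hidden_size = 24, 1, 16, 4096  # assumed sizes, not from the patch
    # One tuned prefix per remote block: dim 0 must match the number of blocks in the route
    prompts = torch.zeros(n_blocks, batch_size, prefix_len, hidden_size)
    assert prompts.ndim == 4 and prompts.shape[0] == n_blocks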
@@ -219,38 +222,36 @@ class InferenceSession:
 
         server_idx = 0
         block_idx = 0
-        recovery_end = None  # Recovery mode is disabled until a failure happens
-        while block_idx < len(self._sequence_manager):
+        recovery_until = -1  # Recovery mode is disabled until a failure happens
+        while block_idx < n_blocks:
             for attempt_no in itertools.count():
                 logger.debug(f"Inference: block {block_idx}, attempt {attempt_no}")
                 try:
                     if attempt_no >= 1:
                         self._sequence_manager.update_()
-                        if server_idx < len(self._chosen_spans):
-                            recovery_end = self._chosen_spans[server_idx].end
-                        else:
-                            recovery_end = len(self._sequence_manager)
-                        if attempt_no == 1:
+                    if not self._chosen_spans or not self._server_sessions or attempt_no >= 1:
+                        n_spans = len(self._chosen_spans)
+                        update_end = self._chosen_spans[server_idx].end if server_idx < n_spans else n_blocks
+                        if attempt_no == 1 and update_end > recovery_until:
                             logger.info(
-                                f"Entering recovery mode, remote attention caches "
-                                f"from block {block_idx} to {recovery_end} will be regenerated"
+                                f"Due to a server failure, remote attention caches "
+                                f"from block {block_idx} to {update_end} will be regenerated"
                             )
+                        recovery_until = max(recovery_until, update_end)
 
-                    if not self._chosen_spans or not self._server_sessions or attempt_no >= 1:
-                        backup_spans = self._sequence_manager.make_sequence(block_idx, recovery_end)
-                        if recovery_end is not None:
-                            # make_sequence() could return a longer sequence
-                            backup_spans[-1].end = min(backup_spans[-1].end, recovery_end)
+                        updated_spans = self._sequence_manager.make_sequence(block_idx, update_end)
+                        # make_sequence() could return a longer sequence
+                        updated_spans[-1].end = min(updated_spans[-1].end, update_end)
 
                         # The code below works even if server `server_idx` is not added yet
                         self._exit_server_sessions(self._server_sessions[server_idx : server_idx + 1])
-                        self._chosen_spans[server_idx : server_idx + 1] = backup_spans
-                        self._server_sessions[server_idx : server_idx + 1] = self._enter_server_sessions(backup_spans)
+                        self._chosen_spans[server_idx : server_idx + 1] = updated_spans
+                        self._server_sessions[server_idx : server_idx + 1] = self._enter_server_sessions(updated_spans)
                         if server_idx >= len(self._server_inputs):
                             self._server_inputs.append(None)
-                        self._server_inputs[server_idx + 1 : server_idx + 1] = [None] * (len(backup_spans) - 1)
+                        self._server_inputs[server_idx + 1 : server_idx + 1] = [None] * (len(updated_spans) - 1)
                         logger.debug(
-                            f"Found path from block {block_idx} to {recovery_end} via {len(backup_spans)} servers"
+                            f"Found path from block {block_idx} to {update_end} via {len(updated_spans)} servers"
                         )
 
                     session = self._server_sessions[server_idx]
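The splice in the hunk above is plain Python slice assignment: the single failed span at `server_idx` is replaced in place by however many spans the freshly found route needs, and `_server_inputs` is padded with matching `None` placeholders. A standalone sketch of that pattern, with hypothetical span names that are not from the patch:

    chosen_spans = ["span_0_8", "span_8_16", "span_16_24"]  # hypothetical route
    updated_spans = ["span_8_12", "span_12_16"]             # new route for the failed segment
    server_idx = 1
    chosen_spans[server_idx : server_idx + 1] = updated_spans
    server_inputs = [None, "cached", None]
    server_inputs[server_idx + 1 : server_idx + 1] = [None] * (len(updated_spans) - 1)
    print(chosen_spans)   # ['span_0_8', 'span_8_12', 'span_12_16', 'span_16_24']
    print(server_inputs)  # [None, 'cached', None, None]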
@@ -264,9 +265,9 @@ class InferenceSession:
                     )
                     assert self._server_inputs[server_idx].shape[1] == self._position + n_input_tokens
 
-                    if recovery_end is not None and block_idx < recovery_end:
+                    if block_idx < recovery_until:
                         inputs = self._server_inputs[server_idx]  # Pass full inputs including prefix
-                    elif block_idx == recovery_end:
+                    elif block_idx == recovery_until:
                         inputs = inputs[:, -n_input_tokens:]  # No need to pass prefix further
 
                     outputs = session.step(inputs, prompts[span.start : span.end], **kwargs)
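A sketch of how the recovery handoff in the last hunk plays out on tensor shapes, with assumed sizes that are not from the patch: while `block_idx < recovery_until`, the full history of `position + n_input_tokens` tokens is re-sent so replacement servers can rebuild their attention caches; once the loop reaches `recovery_until`, only the new tokens travel further.

    import torch

    position, n_input_tokens, hidden_size = 100, 3, 4096  # assumed sizes
    full_history = torch.zeros(1, position + n_input_tokens, hidden_size)  # replayed while block_idx < recovery_until
    trimmed = full_history[:, -n_input_tokens:]  # passed on once block_idx == recovery_until
    assert trimmed.shape == (1, n_input_tokens, hidden_size)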