@@ -175,22 +175,26 @@ class InferenceSession:
 
     def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[_ServerInferenceSession]:
         server_sessions = []
-        for span in chosen_spans:
-            stub = TransformerConnectionHandler.get_stub(self._p2p, span.peer_id)
-            span_uids = CHAIN_DELIMITER.join(self._sequence_manager.block_uids[span.start : span.end])
-            session = RemoteExpertWorker.run_coroutine(
-                _ServerInferenceSession.create(
-                    stub,
-                    span_uids,
-                    rpc_info=self._sequence_manager.rpc_info,
-                    timeout=self._sequence_manager.timeout,
-                    max_length=self._max_length,
-                    **self._metadata,
+        try:
+            for span in chosen_spans:
+                stub = TransformerConnectionHandler.get_stub(self._p2p, span.peer_id)
+                span_uids = CHAIN_DELIMITER.join(self._sequence_manager.block_uids[span.start : span.end])
+                session = RemoteExpertWorker.run_coroutine(
+                    _ServerInferenceSession.create(
+                        stub,
+                        span_uids,
+                        rpc_info=self._sequence_manager.rpc_info,
+                        timeout=self._sequence_manager.timeout,
+                        max_length=self._max_length,
+                        **self._metadata,
+                    )
                 )
-            )
-            server_sessions.append(session)
-            session.__enter__()
-        return server_sessions
+                server_sessions.append(session)
+                session.__enter__()
+            return server_sessions
+        except:
+            self._exit_server_sessions(server_sessions)
+            raise
 
     def _exit_server_sessions(self, server_sessions: List[_ServerInferenceSession]) -> None:
         for session in reversed(server_sessions):
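The hunk above makes server session setup all-or-nothing: if opening any session in the chain fails, every session opened so far is closed before the exception propagates, so no connections leak. A minimal standalone sketch of this enter-or-unwind pattern, with a hypothetical `Resource` stand-in for `_ServerInferenceSession` (none of these names are part of Petals):

```python
from typing import List


class Resource:
    """Hypothetical stand-in for _ServerInferenceSession."""

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        pass


def enter_all(resources: List[Resource]) -> List[Resource]:
    entered: List[Resource] = []
    try:
        for res in resources:
            entered.append(res)  # Appended before entering, so a failing __enter__ is still unwound
            res.__enter__()
        return entered
    except:
        # Mirrors the hunk: a bare except also catches KeyboardInterrupt,
        # so partially opened resources are closed before re-raising
        for res in reversed(entered):
            res.__exit__(None, None, None)
        raise
```

Re-raising after the rollback preserves the original traceback, so the retry logic in `step()` below still sees the failure.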
@@ -230,9 +234,12 @@ class InferenceSession:
                 if attempt_no >= 1:
                     self._sequence_manager.update_()
                 if not self._chosen_spans or not self._server_sessions or attempt_no >= 1:
-                    n_spans = len(self._chosen_spans)
-                    update_end = self._chosen_spans[server_idx].end if server_idx < n_spans else n_blocks
-                    if attempt_no == 1 and update_end > recovery_until:
+                    # If there is a failed server session, this code closes it
+                    self._exit_server_sessions(self._server_sessions[server_idx : server_idx + 1])
+
+                    n_prev_spans = len(self._chosen_spans)
+                    update_end = self._chosen_spans[server_idx].end if server_idx < n_prev_spans else n_blocks
+                    if attempt_no >= 1 and update_end > recovery_until:
                         logger.info(
                             f"Due to a server failure, remote attention caches "
                             f"from block {block_idx} to {update_end} will be regenerated"
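A detail worth noting in this hunk: the failed session is closed via the slice `self._server_sessions[server_idx : server_idx + 1]` rather than plain indexing, so the close is a no-op when no session exists at that position yet. Out-of-range list slices in Python return an empty list instead of raising, which a short snippet can confirm:

```python
sessions = ["session_0", "session_1"]

# In range: a one-element list containing the (possibly failed) session
print(sessions[1:2])  # ['session_1']

# Out of range: an empty list, so _exit_server_sessions() has nothing to close
print(sessions[5:6])  # []
```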
@@ -242,18 +249,20 @@ class InferenceSession:
                     updated_spans = self._sequence_manager.make_sequence(block_idx, update_end)
                     # make_sequence() could return a longer sequence
                     updated_spans[-1].end = min(updated_spans[-1].end, update_end)
-
-                    # The code below works even if server `server_idx` is not added yet
-                    self._exit_server_sessions(self._server_sessions[server_idx : server_idx + 1])
-                    self._chosen_spans[server_idx : server_idx + 1] = updated_spans
-                    self._server_sessions[server_idx : server_idx + 1] = self._enter_server_sessions(updated_spans)
-                    if server_idx >= len(self._server_inputs):
-                        self._server_inputs.append(None)
-                    self._server_inputs[server_idx + 1 : server_idx + 1] = [None] * (len(updated_spans) - 1)
+                    updated_sessions = self._enter_server_sessions(updated_spans)
                     logger.debug(
                         f"Found path from block {block_idx} to {update_end} via {len(updated_spans)} servers"
                     )
 
+                    # If there is a failed span, this code replaces it, otherwise it just adds new ones
+                    self._chosen_spans[server_idx : server_idx + 1] = updated_spans
+                    self._server_sessions[server_idx : server_idx + 1] = updated_sessions
+                    recovery_inputs = self._server_inputs[server_idx] if server_idx < n_prev_spans else None
+                    self._server_inputs[server_idx : server_idx + 1] = [recovery_inputs] + [None] * (len(updated_spans) - 1)
+                    assert len(self._chosen_spans) == len(self._server_sessions) == len(self._server_inputs), \
+                        f"Broken state: {len(self._chosen_spans)} spans, {len(self._server_sessions)} sessions, " \
+                        f"{len(self._server_inputs)} inputs"
+
                 session = self._server_sessions[server_idx]
                 span = self._chosen_spans[server_idx]
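The splice assignments above rely on `lst[i : i + 1] = items` replacing a single element with any number of new ones (or simply appending when `i == len(lst)`), which is what keeps `_chosen_spans`, `_server_sessions`, and `_server_inputs` aligned after recovery. A toy walk-through with made-up values:

```python
# One failed span at index 1 is replaced by a two-server path
chosen_spans = ["span_A", "span_B"]
updated_spans = ["span_B1", "span_B2"]
server_idx, n_prev_spans = 1, len(chosen_spans)

chosen_spans[server_idx : server_idx + 1] = updated_spans
print(chosen_spans)  # ['span_A', 'span_B1', 'span_B2']

# Cached inputs for the failed position are kept for replay; the extra
# servers get None placeholders so all three lists stay the same length
server_inputs = ["inputs_A", "inputs_B"]
recovery_inputs = server_inputs[server_idx] if server_idx < n_prev_spans else None
server_inputs[server_idx : server_idx + 1] = [recovery_inputs] + [None] * (len(updated_spans) - 1)
print(server_inputs)  # ['inputs_A', 'inputs_B', None]
assert len(chosen_spans) == len(server_inputs)
```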
@@ -263,15 +272,18 @@ class InferenceSession:
                     self._server_inputs[server_idx] = torch.cat(
                         [self._server_inputs[server_idx], inputs[:, -n_input_tokens:]], dim=1
                     )
-                assert self._server_inputs[server_idx].shape[1] == self._position + n_input_tokens
+                assert self._server_inputs[server_idx].shape[1] == self._position + n_input_tokens, \
+                    f"Broken input cache: server_idx={server_idx} shape={self._server_inputs[server_idx].shape} " \
+                    f"position={self._position} n_input_tokens={n_input_tokens}"
 
-                if block_idx < recovery_until:
+                if not session.stepped:
                     inputs = self._server_inputs[server_idx]  # Pass full inputs including prefix
-                elif block_idx == recovery_until:
+                else:
                     inputs = inputs[:, -n_input_tokens:]  # No need to pass prefix further
 
                 outputs = session.step(inputs, prompts[span.start : span.end], **kwargs)
-                assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
+                assert inputs.shape == outputs.shape, \
+                    f"Shape mismatch: inputs.shape={inputs.shape}, outputs.shape={outputs.shape}"
                 inputs = outputs
                 server_idx += 1
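The last hunk replaces the `recovery_until` bookkeeping with the server session's own `stepped` flag: a freshly (re)opened server has an empty attention cache and must replay the full input prefix, while a server that has already stepped only needs the newest tokens. A sketch of that dispatch under assumed tensor shapes (`select_inputs` is illustrative, not a Petals function):

```python
import torch


def select_inputs(
    stepped: bool, cached_inputs: torch.Tensor, inputs: torch.Tensor, n_input_tokens: int
) -> torch.Tensor:
    if not stepped:
        return cached_inputs  # Full prefix, lets the server rebuild its attention cache
    return inputs[:, -n_input_tokens:]  # Only the new tokens; the prefix is already cached


# Hypothetical shapes: batch=1, 10 prefix tokens + 1 new token, hidden=4
cached = torch.randn(1, 11, 4)
inputs = torch.randn(1, 11, 4)
assert select_inputs(False, cached, inputs, 1).shape[1] == 11  # replay everything
assert select_inputs(True, cached, inputs, 1).shape[1] == 1    # just the new token
```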