@@ -172,6 +172,8 @@ class RemoteSequentialInferenceSession:
         self.closed = False
         self.chosen_spans = []
         self.server_sessions = []
+        self.server_inputs = []  # Used in case of server failures to regenerate attention caches on new servers
+        self.position = 0
         self.metadata = metadata

     def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[RemoteServerInferenceSession]:
@@ -209,19 +211,27 @@ class RemoteSequentialInferenceSession:
             prompts = DUMMY
         else:
             assert prompts.ndim == 4 and prompts.shape[0] == len(self.sequence_manager)
+        n_input_tokens = inputs.shape[1]

         server_idx = 0
         block_idx = 0
+        recovery_mode = False
         while block_idx < len(self.sequence_manager):
             for attempt_no in itertools.count():
                 logger.debug(f"Inference: block {block_idx}, attempt {attempt_no}")
                 try:
                     if not self.chosen_spans or not self.server_sessions or attempt_no >= 1:
-                        self._exit_server_sessions(self.server_sessions[server_idx:], verbose=False)
-                        self.server_sessions[server_idx:] = []
-                        self.chosen_spans[server_idx:] = []
+                        if attempt_no >= 1:
+                            self._exit_server_sessions(self.server_sessions[server_idx:], verbose=False)
+                            self.server_sessions[server_idx:] = []
+                            self.chosen_spans[server_idx:] = []
+                            self.server_inputs[server_idx + 1:] = []
+
+                            self.sequence_manager.update_()
+                            recovery_mode = True
+                            if attempt_no == 1:
+                                logger.info("Entering recovery mode, remote attention caches will be regenerated")

-                        self.sequence_manager.update_()
                         backup_spans = self.sequence_manager.make_sequence(block_idx)
                         self.chosen_spans.extend(backup_spans)
                         self.server_sessions.extend(self._enter_server_sessions(backup_spans))
@@ -230,6 +240,14 @@ class RemoteSequentialInferenceSession:
                     session = self.server_sessions[server_idx]
                     span = self.chosen_spans[server_idx]

+                    if server_idx == len(self.server_inputs):
+                        self.server_inputs.append(inputs)
+                    elif self.server_inputs[server_idx].shape[1] == self.position:
+                        self.server_inputs[server_idx] = torch.cat([self.server_inputs[server_idx], inputs], dim=1)
+                    assert self.server_inputs[server_idx].shape[1] == self.position + n_input_tokens
+
+                    if recovery_mode:
+                        inputs = self.server_inputs[server_idx]  # Take full inputs including prefix
                     outputs = session.step(inputs, prompts[span.start : span.end], **kwargs)
                     assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
                     inputs = outputs
@@ -241,17 +259,21 @@ class RemoteSequentialInferenceSession:
                     delay = self.sequence_manager.min_backoff * 2**attempt_no
                     logger.warning(
                         f"Caught exception when running inference from block {block_idx} "
-                        f"(retry in {delay:.2f} sec): {repr(e)}"
+                        f"(retry in {delay:.1f} sec): {repr(e)}"
                     )
                     logger.debug("See detailed traceback below:", exc_info=True)
                     time.sleep(delay)
+
+        self.position += n_input_tokens
         return inputs

     def close(self, *exc_details):
         """Finish a given inference session, close the underlying connection"""
         if not self.closed:
+            self.server_inputs.clear()
             self._exit_server_sessions(self.server_sessions, verbose=True)
             self.server_sessions.clear()
+            self.chosen_spans.clear()
             self.closed = True

     def __exit__(self, *exc_details):
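
For context, here is a minimal standalone sketch of the recovery logic this diff introduces. ToyServer and run_step below are hypothetical stand-ins, not the session's real API: the point is only that the client caches every tensor it sends to each server (server_inputs) and tracks its token position, so that when a failed server is swapped out, the replacement can be fed the full cached prefix to regenerate its attention cache.

import torch

class ToyServer:
    """Stand-in for a remote server: only tracks its attention cache length."""

    def __init__(self):
        self.cache_length = 0

    def step(self, inputs: torch.Tensor) -> torch.Tensor:
        self.cache_length += inputs.shape[1]  # a real server would extend its KV cache here
        return inputs + 1  # pretend to run a span of transformer blocks

servers = [ToyServer(), ToyServer()]
server_inputs = []  # cached per-server inputs, as in self.server_inputs
position = 0        # tokens processed so far, as in self.position

def run_step(inputs, recovery_mode=False):
    global position
    n_input_tokens = inputs.shape[1]
    for server_idx, server in enumerate(servers):
        # Cache this server's inputs so a replacement server can replay them later
        if server_idx == len(server_inputs):
            server_inputs.append(inputs)
        elif server_inputs[server_idx].shape[1] == position:
            server_inputs[server_idx] = torch.cat([server_inputs[server_idx], inputs], dim=1)
        assert server_inputs[server_idx].shape[1] == position + n_input_tokens

        if recovery_mode and server.cache_length == 0:
            inputs = server_inputs[server_idx]  # replay the full prefix into an empty cache
        inputs = server.step(inputs)[:, -n_input_tokens:]  # pass only the new tokens onward
    position += n_input_tokens
    return inputs

run_step(torch.zeros(1, 3, 8))                      # prefill: 3 tokens
servers[1] = ToyServer()                            # server 1 "fails" and is replaced
run_step(torch.zeros(1, 1, 8), recovery_mode=True)  # 1 new token, prefix replayed
assert servers[1].cache_length == 4                 # 3 replayed + 1 new

The design trades client-side memory (one cached hidden-state tensor per server span) for the ability to survive mid-generation server failures without restarting the whole session.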