@@ -167,24 +167,24 @@ class RemoteSequentialInferenceSession:
     """

     def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, **metadata):
-        self.sequence_manager = sequence_manager
-        self.p2p = p2p
-        self.closed = False
-        self.chosen_spans = []
-        self.server_sessions = []
-        self.server_inputs = []  # Used in case of server failures to regenerate attention caches on new servers
-        self.position = 0
-        self.metadata = metadata
+        self._sequence_manager = sequence_manager
+        self._p2p = p2p
+        self._closed = False
+        self._chosen_spans = []
+        self._server_sessions = []
+        self._server_inputs = []  # Used in case of server failures to regenerate attention caches on new servers
+        self._position = 0
+        self._metadata = metadata

     def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[RemoteServerInferenceSession]:
         server_sessions = []
         for span in chosen_spans:
-            stub = TransformerConnectionHandler.get_stub(self.p2p, span.peer_id)
-            span_uids = CHAIN_DELIMITER.join(self.sequence_manager.block_uids[span.start : span.end])
+            stub = TransformerConnectionHandler.get_stub(self._p2p, span.peer_id)
+            span_uids = CHAIN_DELIMITER.join(self._sequence_manager.block_uids[span.start : span.end])
             session = RemoteExpertWorker.run_coroutine(
                 RemoteServerInferenceSession.create(
-                    stub, span_uids, rpc_info=self.sequence_manager.rpc_info, timeout=self.sequence_manager.timeout,
-                    **self.metadata
+                    stub, span_uids, rpc_info=self._sequence_manager.rpc_info, timeout=self._sequence_manager.timeout,
+                    **self._metadata
                 )
             )
             server_sessions.append(session)
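
This hunk makes all of the session's state private (`_sequence_manager`, `_p2p`, `_closed`, `_chosen_spans`, `_server_sessions`, `_server_inputs`, `_position`, `_metadata`), leaving `step`, `close`, and the context-manager protocol as the public surface. A minimal usage sketch under that assumption; the `sequence_manager`, `p2p`, and tensor variables are hypothetical placeholders, not part of this diff:

    # Hypothetical caller code; variable names are illustrative placeholders.
    with RemoteSequentialInferenceSession(sequence_manager, p2p) as session:
        hidden = session.step(prefix_embeddings)      # prefill: send the whole prompt
        hidden = session.step(next_token_embedding)   # one autoregressive step
    # Leaving the `with` block calls close(), which tears down all server sessions.
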
@@ -200,54 +200,54 @@ class RemoteSequentialInferenceSession:
                 logger.log(exc_loglevel, "Caught exception while closing connection to server:", exc_info=True)

     def __enter__(self):
-        assert not self.closed and not self.chosen_spans
+        assert not self._closed and not self._chosen_spans
         return self

     def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs):
-        assert not self.closed
+        assert not self._closed
         if torch.is_grad_enabled():
             logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
         if prompts is None or is_dummy(prompts):
             prompts = DUMMY
         else:
-            assert prompts.ndim == 4 and prompts.shape[0] == len(self.sequence_manager)
+            assert prompts.ndim == 4 and prompts.shape[0] == len(self._sequence_manager)
         n_input_tokens = inputs.shape[1]

         server_idx = 0
         block_idx = 0
         recovery_mode = False
-        while block_idx < len(self.sequence_manager):
+        while block_idx < len(self._sequence_manager):
             for attempt_no in itertools.count():
                 logger.debug(f"Inference: block {block_idx}, attempt {attempt_no}")
                 try:
-                    if not self.chosen_spans or not self.server_sessions or attempt_no >= 1:
+                    if not self._chosen_spans or not self._server_sessions or attempt_no >= 1:
                         if attempt_no >= 1:
-                            self._exit_server_sessions(self.server_sessions[server_idx:], verbose=False)
-                            self.server_sessions[server_idx:] = []
-                            self.chosen_spans[server_idx:] = []
-                            self.server_inputs[server_idx + 1:] = []
+                            self._exit_server_sessions(self._server_sessions[server_idx:], verbose=False)
+                            self._server_sessions[server_idx:] = []
+                            self._chosen_spans[server_idx:] = []
+                            self._server_inputs[server_idx + 1:] = []

-                            self.sequence_manager.update_()
+                            self._sequence_manager.update_()
                             recovery_mode = True
                             if attempt_no == 1:
                                 logger.info("Entering recovery mode, remote attention caches will be regenerated")

-                        backup_spans = self.sequence_manager.make_sequence(block_idx)
-                        self.chosen_spans.extend(backup_spans)
-                        self.server_sessions.extend(self._enter_server_sessions(backup_spans))
+                        backup_spans = self._sequence_manager.make_sequence(block_idx)
+                        self._chosen_spans.extend(backup_spans)
+                        self._server_sessions.extend(self._enter_server_sessions(backup_spans))
                         logger.debug(f"Found path from block {block_idx} via {len(backup_spans)} servers")

-                    session = self.server_sessions[server_idx]
-                    span = self.chosen_spans[server_idx]
+                    session = self._server_sessions[server_idx]
+                    span = self._chosen_spans[server_idx]

-                    if server_idx == len(self.server_inputs):
-                        self.server_inputs.append(inputs)
-                    elif self.server_inputs[server_idx].shape[1] == self.position:
-                        self.server_inputs[server_idx] = torch.cat([self.server_inputs[server_idx], inputs], dim=1)
-                    assert self.server_inputs[server_idx].shape[1] == self.position + n_input_tokens
+                    if server_idx == len(self._server_inputs):
+                        self._server_inputs.append(inputs)
+                    elif self._server_inputs[server_idx].shape[1] == self._position:
+                        self._server_inputs[server_idx] = torch.cat([self._server_inputs[server_idx], inputs], dim=1)
+                    assert self._server_inputs[server_idx].shape[1] == self._position + n_input_tokens

                     if recovery_mode:
-                        inputs = self.server_inputs[server_idx]  # Take full inputs including prefix
+                        inputs = self._server_inputs[server_idx]  # Take full inputs including prefix
                     outputs = session.step(inputs, prompts[span.start : span.end], **kwargs)
                     assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
                     inputs = outputs
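
The retry path in the hunk above pairs exponential backoff with cache regeneration: each failed attempt discards the sessions from the failed server onward, refreshes routing via `_sequence_manager.update_()`, and replays each replacement server's full input prefix from `_server_inputs` so its attention cache can be rebuilt. A runnable sketch of the delay schedule computed in the `except` branch below; the `min_backoff` value of 1.0 is an assumption, the real value comes from the sequence manager:

    # Illustrative backoff schedule; min_backoff is a placeholder value.
    min_backoff = 1.0
    for attempt_no in range(4):
        print(f"attempt {attempt_no}: sleep {min_backoff * 2 ** attempt_no:.1f} s")
    # attempt 0: sleep 1.0 s ... attempt 3: sleep 8.0 s
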
@@ -256,25 +256,25 @@ class RemoteSequentialInferenceSession:
                     block_idx = span.end
                     break
                 except Exception as e:
-                    delay = self.sequence_manager.min_backoff * 2**attempt_no
+                    delay = self._sequence_manager.min_backoff * 2**attempt_no
                     logger.warning(
                         f"Caught exception when running inference from block {block_idx} "
-                        f"(retry in {delay:.1f} sec): {repr(e)}"
+                        f"(retry in {delay:.0f} sec): {repr(e)}"
                     )
                     logger.debug("See detailed traceback below:", exc_info=True)
                     time.sleep(delay)

-        self.position += n_input_tokens
+        self._position += n_input_tokens
         return inputs

     def close(self, *exc_details):
         """Finish a given inference session, close the underlying connection"""
-        if not self.closed:
-            self.server_inputs.clear()
-            self._exit_server_sessions(self.server_sessions, verbose=True)
-            self.server_sessions.clear()
-            self.chosen_spans.clear()
-            self.closed = True
+        if not self._closed:
+            self._server_inputs.clear()
+            self._exit_server_sessions(self._server_sessions, verbose=True)
+            self._server_sessions.clear()
+            self._chosen_spans.clear()
+            self._closed = True

     def __exit__(self, *exc_details):
         self.close(*exc_details)
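
For reference, a self-contained sketch of the per-server input-cache invariant that `step` maintains before each remote call; the tensor layout `(batch, seq_len, hidden)` and the concrete sizes are assumptions for illustration:

    # Sketch of the invariant: after appending the new tokens, the cached
    # inputs for a server cover exactly position + n_input_tokens tokens,
    # which is what lets recovery mode resend the full prefix to a new server.
    import torch

    position, n_input_tokens, hidden = 5, 1, 8
    cached_inputs = torch.randn(1, position, hidden)     # prefix already sent to this server
    new_inputs = torch.randn(1, n_input_tokens, hidden)  # tokens for the current step
    cached_inputs = torch.cat([cached_inputs, new_inputs], dim=1)
    assert cached_inputs.shape[1] == position + n_input_tokens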