@@ -162,7 +162,7 @@ class InferenceSession:
     An interface to a multi-step *inference* session for a sequence of remote transformer blocks
     """

-    def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, **metadata):
+    def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, max_length: int, **metadata):
         self._sequence_manager = sequence_manager
         self._p2p = p2p
         self._closed = False
@@ -170,6 +170,7 @@ class InferenceSession:
         self._server_sessions = []
         self._server_inputs = []  # Used in case of server failures to regenerate attention caches on new servers
         self._position = 0
+        self._max_length = max_length
         self._metadata = metadata

     def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[_ServerInferenceSession]:
@@ -183,6 +184,7 @@ class InferenceSession:
                 span_uids,
                 rpc_info=self._sequence_manager.rpc_info,
                 timeout=self._sequence_manager.timeout,
+                max_length=self._max_length,
                 **self._metadata,
             )
         )
@@ -210,6 +212,10 @@ class InferenceSession:
         else:
             assert prompts.ndim == 4 and prompts.shape[0] == len(self._sequence_manager)
         n_input_tokens = inputs.shape[1]
+        if self._position + n_input_tokens > self._max_length:
+            raise ValueError(
+                f"Maximum length exceeded: prefix {self._position} + current {n_input_tokens} exceeds pre-allocated maximum {self._max_length}"
+            )

         server_idx = 0
         block_idx = 0