Parcourir la source

Fix max_length

Aleksandr Borzunov il y a 2 ans
Parent
commit
01cffeba5d
2 fichiers modifiés avec 8 ajouts et 2 suppressions
  1. 7 1
      src/client/inference_session.py
  2. 1 1
      tests/test_block_exact_match.py

+ 7 - 1
src/client/inference_session.py

@@ -162,7 +162,7 @@ class InferenceSession:
     An interface to a multi-step *inference* session for a sequence of remote transformer blocks
     """
 
-    def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, **metadata):
+    def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, max_length: int, **metadata):
         self._sequence_manager = sequence_manager
         self._p2p = p2p
         self._closed = False
@@ -170,6 +170,7 @@ class InferenceSession:
         self._server_sessions = []
         self._server_inputs = []  # Used in case of server failures to regenerate attention caches on new servers
         self._position = 0
+        self._max_length = max_length
         self._metadata = metadata
 
     def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[_ServerInferenceSession]:
@@ -183,6 +184,7 @@ class InferenceSession:
                     span_uids,
                     rpc_info=self._sequence_manager.rpc_info,
                     timeout=self._sequence_manager.timeout,
+                    max_length=self._max_length,
                     **self._metadata,
                 )
             )
@@ -210,6 +212,10 @@ class InferenceSession:
         else:
             assert prompts.ndim == 4 and prompts.shape[0] == len(self._sequence_manager)
         n_input_tokens = inputs.shape[1]
+        if self._position + n_input_tokens > self._max_length:
+            raise ValueError(
+                f"Maximum length exceeded: prefix {self._position} + current {n_input_tokens} exceeds pre-allocated maximum {self._max_length}"
+            )
 
         server_idx = 0
         block_idx = 0

+ 1 - 1
tests/test_block_exact_match.py

@@ -33,7 +33,7 @@ def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3):
                 outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
 
             # test that max length is respected
-            with pytest.raises(P2PHandlerError) as exc_info:
+            with pytest.raises(ValueError, match=r"Maximum length exceeded") as exc_info:
                 sess.step(inputs[:, -1:, :])
             assert "Maximum length exceeded" in repr(exc_info.value)