
test running inference session with position getter/setter

Your Name 1 year ago
Parent
Commit e45532793d
2 changed files with 23 additions and 15 deletions
  1. src/petals/client/inference_session.py  +20 -14
  2. tests/test_speculative_generation.py  +3 -1

+20 -14
src/petals/client/inference_session.py

@@ -83,6 +83,17 @@ class _ServerInferenceSession:
             if not next_input_message.uid and not next_input_message.tensors:
                 break  # this message means "done sending"
 
+    @property
+    def position(self):
+        return self._position
+
+    @position.setter
+    def position(self, start_from_position: int):
+        assert start_from_position <= self._position
+        self._position = start_from_position
+        if self.history is not None and self.history.shape[1] >= start_from_position:
+            self.history = self.history[:, :start_from_position, :] if start_from_position > 0 else None
+
     def step(
         self,
         inputs: torch.Tensor,
@@ -90,7 +101,6 @@ class _ServerInferenceSession:
         hypo_ids: torch.LongTensor,
         *,
         step_id: str,
-        start_from_position: int,
     ) -> torch.Tensor:
         """
         Inference step: send a chunk of input tensors and receive a chunk of outputs
@@ -100,12 +110,6 @@ class _ServerInferenceSession:
         if self.closed:
             raise Exception("Session is closed, cannot perform step")
 
-        if start_from_position is not None:
-            assert start_from_position <= self._position
-            self._position = start_from_position
-            if self.history is not None and self.history.shape[1] >= start_from_position:
-                self.history = self.history[:, :start_from_position, :] if start_from_position > 0 else None
-
         n_input_tokens = inputs.shape[1]
         if self.history is None:
             self.history = inputs
@@ -127,8 +131,8 @@ class _ServerInferenceSession:
         request_metadata = dict(session_id=self.session_id, step_id=step_id)
         if not self.stepped:
             request_metadata.update(self.session_metadata)
-        if start_from_position is not None:
-            request_metadata["start_from_position"] = start_from_position
+        if self._position is not None:
+            request_metadata["start_from_position"] = self._position
         elif self.config.use_server_to_server:
             next_servers = self._collect_next_servers()
             if next_servers:
@@ -235,6 +239,13 @@ class InferenceSession:
     def position(self) -> int:
         return self._position
 
+    @position.setter
+    def position(self, start_from_position: int) -> None:
+        self._position = start_from_position
+        for session in self._server_sessions:
+            assert isinstance(session, _ServerInferenceSession)
+            session.position = start_from_position
+
     def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[_ServerInferenceSession]:
         server_sessions = []
         try:
@@ -275,12 +286,8 @@ class InferenceSession:
         inputs: torch.Tensor,
         prompts: Optional[torch.Tensor] = None,
         hypo_ids: Optional[torch.Tensor] = None,
-        start_from_position: Optional[int] = None,
     ) -> torch.Tensor:
 
-        if start_from_position is not None:
-            self._position = start_from_position
-
         assert not self._closed
         if torch.is_grad_enabled():
             logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
@@ -329,7 +336,6 @@ class InferenceSession:
                         prompts[server_session.span.start : server_session.span.end],
                         hypo_ids,
                         step_id=step_id,
-                        start_from_position=start_from_position,
                     )
 
                     server_idx += 1

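Note on the new setter: it rolls a server session back to an earlier token position and truncates the locally cached inputs to match, so the next step() re-sends everything past that point. Below is a minimal self-contained sketch of this rollback semantics; the _RollbackDemo class and tensor shapes are illustrative stand-ins, not part of Petals.

import torch
from typing import Optional

class _RollbackDemo:
    """Illustrative stand-in for _ServerInferenceSession's position logic."""

    def __init__(self):
        self._position = 0  # number of tokens processed so far
        self.history: Optional[torch.Tensor] = None  # cached inputs, [batch, seq, hidden]

    @property
    def position(self) -> int:
        return self._position

    @position.setter
    def position(self, start_from_position: int) -> None:
        assert start_from_position <= self._position  # can only rewind, never skip ahead
        self._position = start_from_position
        # Keep only the still-valid prefix; rewinding to 0 drops the cache entirely
        if self.history is not None and self.history.shape[1] >= start_from_position:
            self.history = self.history[:, :start_from_position, :] if start_from_position > 0 else None

demo = _RollbackDemo()
demo.history, demo._position = torch.randn(1, 5, 16), 5  # pretend 5 tokens were processed
demo.position = 2  # rewind to token 2: history is truncated to the first 2 tokens
assert demo.position == 2 and demo.history.shape[1] == 2
demo.position = 0  # full reset
assert demo.history is None
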
+3 -1
tests/test_speculative_generation.py

@@ -26,7 +26,9 @@ def test_remote_block_with_cache_invalidation_exact_match(atol_forward=1e-4, ato
     with torch.inference_mode():
         with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
             initial_outputs_inference = sess.step(inputs)
-            secondary_outputs_inference = sess.step(short_inputs[:, 2:, :], start_from_position=2)
+
+            sess.position = 2
+            secondary_outputs_inference = sess.step(short_inputs[:, 2:, :])
             result = torch.cat([initial_outputs_inference[:, :2, :], secondary_outputs_inference], dim=1)
 
     ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
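
For callers, the rollback thus moves from a per-step keyword argument to session state: the InferenceSession.position setter propagates the new position to every underlying _ServerInferenceSession before the next step. A hedged before/after sketch (sess and new_inputs are illustrative names, used as in the test above):

# before this commit: rollback was requested on each call
#   outputs = sess.step(new_inputs, start_from_position=2)

# after this commit: set the position once, then step normally
sess.position = 2                # rewinds cached state on all server sessions
outputs = sess.step(new_inputs)  # new_inputs covers tokens from position 2 onward

This is the pattern speculative generation needs: draft tokens that turn out to be wrong can be discarded by rewinding the position, which is what the updated cache-invalidation test exercises.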