Quellcode durchsuchen

Merge branch 'main' into test_set_position

justheuristic vor 1 Jahr
Ursprung
Commit
5e4d884fa2
2 geänderte Dateien mit 7 neuen und 1 gelöschten Zeilen
  1. src/petals/client/inference_session.py (+6 −1)
  2. tests/test_speculative_generation.py (+1 −0)

+ 6 - 1
src/petals/client/inference_session.py

@@ -110,6 +110,12 @@ class _ServerInferenceSession:
         if self.closed:
             raise Exception("Session is closed, cannot perform step")
 
+        if start_from_position is not None:
+            assert start_from_position <= self._position
+            self._position = start_from_position
+            if self.history is not None and self.history.shape[1] >= start_from_position:
+                self.history = self.history[:, :start_from_position, :] if start_from_position > 0 else None
+
         n_input_tokens = inputs.shape[1]
         if self.history is None:
             self.history = inputs
@@ -287,7 +293,6 @@ class InferenceSession:
         prompts: Optional[torch.Tensor] = None,
         hypo_ids: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
         assert not self._closed
         if torch.is_grad_enabled():
             logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")

+ 1 - 0
tests/test_speculative_generation.py

@@ -29,6 +29,7 @@ def test_remote_block_with_cache_invalidation_exact_match(atol_forward=1e-4, ato
 
             sess.position = 2
             secondary_outputs_inference = sess.step(short_inputs[:, 2:, :])
+            secondary_outputs_inference = sess.step(short_inputs[:, 2:, :], start_from_position=2)
             result = torch.cat([initial_outputs_inference[:, :2, :], secondary_outputs_inference], dim=1)
 
     ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)