Преглед изворни кода

Add option to rollback inference for a certain number of steps (#588)

* fix

* fix

* fix

* fix

* fix

* fix

* style
Anton Sinitsin пре 1 година
родитељ
комит
c0a4d2e3d5

+ 29 - 3
src/petals/client/inference_session.py

@@ -84,7 +84,13 @@ class _ServerInferenceSession:
                 break  # this message means "done sending"
 
     def step(
-        self, inputs: torch.Tensor, prompts: torch.Tensor, hypo_ids: torch.LongTensor, *, step_id: str
+        self,
+        inputs: torch.Tensor,
+        prompts: torch.Tensor,
+        hypo_ids: torch.LongTensor,
+        *,
+        step_id: str,
+        start_from_position: int,
     ) -> torch.Tensor:
         """
         Inference step: send a chunk of input tensors and receive a chunk of outputs
@@ -94,6 +100,12 @@ class _ServerInferenceSession:
         if self.closed:
             raise Exception("Session is closed, cannot perform step")
 
+        if start_from_position is not None:
+            assert start_from_position <= self._position
+            self._position = start_from_position
+            if self.history is not None and self.history.shape[1] >= start_from_position:
+                self.history = self.history[:, :start_from_position, :] if start_from_position > 0 else None
+
         n_input_tokens = inputs.shape[1]
         if self.history is None:
             self.history = inputs
@@ -115,6 +127,8 @@ class _ServerInferenceSession:
         request_metadata = dict(session_id=self.session_id, step_id=step_id)
         if not self.stepped:
             request_metadata.update(self.session_metadata)
+        if start_from_position is not None:
+            request_metadata["start_from_position"] = start_from_position
         elif self.config.use_server_to_server:
             next_servers = self._collect_next_servers()
             if next_servers:
@@ -257,8 +271,16 @@ class InferenceSession:
         return self
 
     def step(
-        self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, hypo_ids: Optional[torch.Tensor] = None
+        self,
+        inputs: torch.Tensor,
+        prompts: Optional[torch.Tensor] = None,
+        hypo_ids: Optional[torch.Tensor] = None,
+        start_from_position: Optional[int] = None,
     ) -> torch.Tensor:
+
+        if start_from_position is not None:
+            self._position = start_from_position
+
         assert not self._closed
         if torch.is_grad_enabled():
             logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
@@ -303,7 +325,11 @@ class InferenceSession:
 
                     server_session = self._server_sessions[server_idx]
                     inputs = server_session.step(
-                        inputs, prompts[server_session.span.start : server_session.span.end], hypo_ids, step_id=step_id
+                        inputs,
+                        prompts[server_session.span.start : server_session.span.end],
+                        hypo_ids,
+                        step_id=step_id,
+                        start_from_position=start_from_position,
                     )
 
                     server_idx += 1

+ 7 - 0
src/petals/server/block_functions.py

@@ -160,6 +160,13 @@ async def iterate_rpc_inference(
     point_per_piece = points / max_length if max_length > 0 else 0.0
 
     async for request, step_metadata in input_iterator:
+        if "start_from_position" in step_metadata:
+            start_from_position = step_metadata["start_from_position"]
+            assert (
+                prefix_length >= start_from_position,
+            ), f"prefix_length={prefix_length}, start_from_position={start_from_position}"
+            prefix_length = start_from_position
+
         flat_tensors = tuple(deserialize_torch_tensor(tensor) for tensor in request.tensors)
         if args_structure is not None:
             # TODO: kwargs currently is unused, it can be used later for peft-like adaptation

+ 35 - 0
tests/test_speculative_generation.py

@@ -0,0 +1,35 @@
+import random
+
+import pytest
+import torch
+
+from petals import AutoDistributedConfig, RemoteSequential
+from petals.server.block_functions import MAX_SHORT_INFERENCE_TOKENS
+from petals.server.from_pretrained import load_pretrained_block
+from test_utils import *
+
+
+@pytest.mark.forked
+def test_remote_block_with_cache_invalidation_exact_match(atol_forward=1e-4, atol_inference=1e-3):
+    config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    remote_sequential = RemoteSequential(config)
+
+    block_index = random.randint(0, config.num_hidden_layers - 1)
+    remote_block = remote_sequential[block_index]
+
+    inputs = torch.randn(1, MAX_SHORT_INFERENCE_TOKENS - 50, config.hidden_size)
+    short_inputs = torch.randn(1, MAX_SHORT_INFERENCE_TOKENS - 50, config.hidden_size)
+    short_inputs[:, :2, :] = inputs[:, :2, :]
+
+    initial_outputs_inference = None
+    secondary_outputs_inference = None
+    with torch.inference_mode():
+        with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
+            initial_outputs_inference = sess.step(inputs)
+            secondary_outputs_inference = sess.step(short_inputs[:, 2:, :], start_from_position=2)
+            result = torch.cat([initial_outputs_inference[:, :2, :], secondary_outputs_inference], dim=1)
+
+    ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
+    (outputs_local,) = ref_block(short_inputs)
+
+    assert torch.allclose(outputs_local, result, rtol=0, atol=atol_inference)