xtinkt před 1 rokem
rodič
revize
4285ddbd7b

+ 18 - 3
src/petals/client/inference_session.py

@@ -84,7 +84,8 @@ class _ServerInferenceSession:
                 break  # this message means "done sending"
 
     def step(
-        self, inputs: torch.Tensor, prompts: torch.Tensor, hypo_ids: torch.LongTensor, *, step_id: str
+        self, inputs: torch.Tensor, prompts: torch.Tensor, hypo_ids: torch.LongTensor, *,
+        step_id: str, last_validated_position: int
     ) -> torch.Tensor:
         """
         Inference step: send a chunk of input tensors and receive a chunk of outputs
@@ -94,6 +95,12 @@ class _ServerInferenceSession:
         if self.closed:
             raise Exception("Session is closed, cannot perform step")
 
+        if last_validated_position is not None:
+            assert last_validated_position <= self._position
+            self._position = last_validated_position
+            if self.history is not None and self.history.shape[1] >= last_validated_position:
+                self.history = self.history[:, :last_validated_position, :] if last_validated_position > 0 else None
+
         n_input_tokens = inputs.shape[1]
         if self.history is None:
             self.history = inputs
@@ -115,6 +122,8 @@ class _ServerInferenceSession:
         request_metadata = dict(session_id=self.session_id, step_id=step_id)
         if not self.stepped:
             request_metadata.update(self.session_metadata)
+        if last_validated_position is not None:
+            request_metadata["last_validated_position"] = last_validated_position
         elif self.config.use_server_to_server:
             next_servers = self._collect_next_servers()
             if next_servers:
@@ -257,8 +266,13 @@ class InferenceSession:
         return self
 
     def step(
-        self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, hypo_ids: Optional[torch.Tensor] = None
+        self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None,
+        hypo_ids: Optional[torch.Tensor] = None, last_validated_position: Optional[int] = None
     ) -> torch.Tensor:
+
+        if last_validated_position is not None:
+            self._position = last_validated_position
+
         assert not self._closed
         if torch.is_grad_enabled():
             logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
@@ -303,7 +317,8 @@ class InferenceSession:
 
                     server_session = self._server_sessions[server_idx]
                     inputs = server_session.step(
-                        inputs, prompts[server_session.span.start : server_session.span.end], hypo_ids, step_id=step_id
+                        inputs, prompts[server_session.span.start : server_session.span.end], hypo_ids,
+                        step_id=step_id, last_validated_position=last_validated_position
                     )
 
                     server_idx += 1

binární
src/petals/server/.handler.py.swp


+ 5 - 0
src/petals/server/block_functions.py

@@ -160,6 +160,11 @@ async def iterate_rpc_inference(
     point_per_piece = points / max_length if max_length > 0 else 0.0
 
     async for request, step_metadata in input_iterator:
+        if "last_validated_position" in step_metadata:
+            last_validated_position = min(step_metadata["last_validated_position"], prefix_length)
+            assert prefix_length >= last_validated_position, f"prefix_length={prefix_length}, last_validated_position={last_validated_position}"
+            prefix_length = last_validated_position
+
         flat_tensors = tuple(deserialize_torch_tensor(tensor) for tensor in request.tensors)
         if args_structure is not None:
             # TODO: kwargs currently is unused, it can be used later for peft-like adaptation

+ 1 - 0
src/petals/server/handler.py

@@ -150,6 +150,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 max_length = metadata.get("max_length")
                 points = metadata.get("points", 0)
                 session_id = metadata.get("session_id")
+                last_validated_position = metadata.get("last_validated_position", None)
                 alloc_timeout = float(metadata.get("alloc_timeout", 0.0))
                 args_structure = metadata.get("args_structure")
                 if not requested_uids:

+ 35 - 0
tests/test_speculative_generation.py

@@ -0,0 +1,35 @@
+import random
+
+import pytest
+import torch
+
+from petals import AutoDistributedConfig, RemoteSequential
+from petals.server.block_functions import MAX_SHORT_INFERENCE_TOKENS
+from petals.server.from_pretrained import load_pretrained_block
+from test_utils import *
+
+
+@pytest.mark.forked
+def test_remote_block_with_cache_invalidation_exact_match(atol_forward=1e-4, atol_inference=1e-3):
+    config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    remote_sequential = RemoteSequential(config)
+
+    block_index = random.randint(0, config.num_hidden_layers - 1)
+    remote_block = remote_sequential[block_index]
+
+    inputs = torch.randn(1, MAX_SHORT_INFERENCE_TOKENS - 50, config.hidden_size)
+    short_inputs = torch.randn(1, MAX_SHORT_INFERENCE_TOKENS - 50, config.hidden_size)
+    short_inputs[:, :2, :] = inputs[:, :2, :]
+
+    initial_outputs_inference = None
+    secondary_outputs_inference = None
+    with torch.inference_mode():
+        with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
+            initial_outputs_inference = sess.step(inputs)
+            secondary_outputs_inference = sess.step(short_inputs[:, 2:, :], last_validated_position=2)
+            result = torch.cat([initial_outputs_inference[:, :2, :], secondary_outputs_inference], dim=1)
+
+    ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
+    (outputs_local,) = ref_block(short_inputs)
+
+    assert torch.allclose(outputs_local, result, rtol=0, atol=atol_inference)