xtinkt 1 year ago
parent commit 269028d0e6
3 changed files with 17 additions and 17 deletions
  1. src/petals/client/inference_session.py (+12 −12)
  2. src/petals/server/block_functions.py (+4 −4)
  3. tests/test_speculative_generation.py (+1 −1)

+ 12 - 12
src/petals/client/inference_session.py

@@ -85,7 +85,7 @@ class _ServerInferenceSession:
 
     def step(
         self, inputs: torch.Tensor, prompts: torch.Tensor, hypo_ids: torch.LongTensor, *,
-        step_id: str, last_validated_position: int
+        step_id: str, start_from_position: int
     ) -> torch.Tensor:
         """
         Inference step: send a chunk of input tensors and receive a chunk of outputs
@@ -95,11 +95,11 @@ class _ServerInferenceSession:
         if self.closed:
             raise Exception("Session is closed, cannot perform step")
 
-        if last_validated_position is not None:
-            assert last_validated_position <= self._position
-            self._position = last_validated_position
-            if self.history is not None and self.history.shape[1] >= last_validated_position:
-                self.history = self.history[:, :last_validated_position, :] if last_validated_position > 0 else None
+        if start_from_position is not None:
+            assert start_from_position <= self._position
+            self._position = start_from_position
+            if self.history is not None and self.history.shape[1] >= start_from_position:
+                self.history = self.history[:, :start_from_position, :] if start_from_position > 0 else None
 
         n_input_tokens = inputs.shape[1]
         if self.history is None:
@@ -122,8 +122,8 @@ class _ServerInferenceSession:
         request_metadata = dict(session_id=self.session_id, step_id=step_id)
         if not self.stepped:
             request_metadata.update(self.session_metadata)
-        if last_validated_position is not None:
-            request_metadata["last_validated_position"] = last_validated_position
+        if start_from_position is not None:
+            request_metadata["start_from_position"] = start_from_position
         elif self.config.use_server_to_server:
             next_servers = self._collect_next_servers()
             if next_servers:
@@ -267,11 +267,11 @@ class InferenceSession:
 
     def step(
         self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None,
-        hypo_ids: Optional[torch.Tensor] = None, last_validated_position: Optional[int] = None
+        hypo_ids: Optional[torch.Tensor] = None, start_from_position: Optional[int] = None
     ) -> torch.Tensor:
 
-        if last_validated_position is not None:
-            self._position = last_validated_position
+        if start_from_position is not None:
+            self._position = start_from_position
 
         assert not self._closed
         if torch.is_grad_enabled():
@@ -318,7 +318,7 @@ class InferenceSession:
                     server_session = self._server_sessions[server_idx]
                     inputs = server_session.step(
                         inputs, prompts[server_session.span.start : server_session.span.end], hypo_ids,
-                        step_id=step_id, last_validated_position=last_validated_position
+                        step_id=step_id, start_from_position=start_from_position
                     )
 
                     server_idx += 1

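The client-side hunks above rename the rollback parameter from `last_validated_position` to `start_from_position`: when a speculative draft is partially rejected, the caller passes the last accepted position so the session rewinds `self._position` and truncates the cached history before processing the replacement chunk. A minimal usage sketch follows; `sess`, the tensor shapes, and the helper name are assumptions for illustration, not part of this commit:

import torch

def resubmit_suffix(sess, hidden: torch.Tensor, accepted: int) -> torch.Tensor:
    # `sess` is an InferenceSession whose cache already covers hidden[:, :accepted].
    # Rewind the server-side cache to `accepted` tokens, then re-run the rest.
    # Note start_from_position must not exceed the session's current position,
    # per the assert in _ServerInferenceSession.step above.
    suffix = hidden[:, accepted:, :]
    return sess.step(suffix, start_from_position=accepted)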
+ 4 - 4
src/petals/server/block_functions.py

@@ -160,10 +160,10 @@ async def iterate_rpc_inference(
     point_per_piece = points / max_length if max_length > 0 else 0.0
 
     async for request, step_metadata in input_iterator:
-        if "last_validated_position" in step_metadata:
-            last_validated_position = step_metadata["last_validated_position"]
-            assert prefix_length >= last_validated_position, f"prefix_length={prefix_length}, last_validated_position={last_validated_position}"
-            prefix_length = last_validated_position
+        if "start_from_position" in step_metadata:
+            start_from_position = step_metadata["start_from_position"]
+            assert prefix_length >= start_from_position, f"prefix_length={prefix_length}, start_from_position={start_from_position}"
+            prefix_length = start_from_position
 
         flat_tensors = tuple(deserialize_torch_tensor(tensor) for tensor in request.tensors)
         if args_structure is not None:
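On the server, `start_from_position` arrives through the step metadata and may only rewind `prefix_length`, never advance it, which is what the assert enforces. A toy model of that bookkeeping, under assumed names (an illustration, not Petals code):

def advance_prefix(prefix_length: int, n_new_tokens: int, start_from_position=None) -> int:
    # The client may only rewind the cached prefix, never skip ahead of it.
    if start_from_position is not None:
        assert prefix_length >= start_from_position
        prefix_length = start_from_position
    return prefix_length + n_new_tokens

# e.g. advance_prefix(5, 3, start_from_position=2) == 5:
# positions 2..4 are recomputed over the rewound cache.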

+ 1 - 1
tests/test_speculative_generation.py

@@ -26,7 +26,7 @@ def test_remote_block_with_cache_invalidation_exact_match(atol_forward=1e-4, ato
     with torch.inference_mode():
         with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
             initial_outputs_inference = sess.step(inputs)
-            secondary_outputs_inference = sess.step(short_inputs[:, 2:, :], last_validated_position=2)
+            secondary_outputs_inference = sess.step(short_inputs[:, 2:, :], start_from_position=2)
             result = torch.cat([initial_outputs_inference[:, :2, :], secondary_outputs_inference], dim=1)
 
     ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
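The updated test runs one full step, rewinds the cache to position 2, replays modified inputs, and stitches the two outputs together. The diff ends before the final comparison; a hedged guess at how such an exact-match check typically concludes (hypothetical continuation, not shown in this commit):

with torch.inference_mode():
    (ref_outputs,) = ref_block(short_inputs)  # hypothetical reference forward pass
torch.testing.assert_close(result, ref_outputs, atol=atol_forward, rtol=0)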