xtinkt 1 year ago
commit 9aecb3f39e
2 files changed, 20 insertions and 7 deletions
  1. + 17 - 6
      src/petals/client/inference_session.py
  2. + 3 - 1
      src/petals/server/block_functions.py

+ 17 - 6
src/petals/client/inference_session.py

@@ -84,8 +84,13 @@ class _ServerInferenceSession:
                 break  # this message means "done sending"
 
     def step(
-        self, inputs: torch.Tensor, prompts: torch.Tensor, hypo_ids: torch.LongTensor, *,
-        step_id: str, start_from_position: int
+        self,
+        inputs: torch.Tensor,
+        prompts: torch.Tensor,
+        hypo_ids: torch.LongTensor,
+        *,
+        step_id: str,
+        start_from_position: int,
     ) -> torch.Tensor:
         """
         Inference step: send a chunk of input tensors and receive a chunk of outputs
@@ -266,8 +271,11 @@ class InferenceSession:
         return self
 
     def step(
-        self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None,
-        hypo_ids: Optional[torch.Tensor] = None, start_from_position: Optional[int] = None
+        self,
+        inputs: torch.Tensor,
+        prompts: Optional[torch.Tensor] = None,
+        hypo_ids: Optional[torch.Tensor] = None,
+        start_from_position: Optional[int] = None,
     ) -> torch.Tensor:
 
         if start_from_position is not None:
@@ -317,8 +325,11 @@ class InferenceSession:
 
                     server_session = self._server_sessions[server_idx]
                     inputs = server_session.step(
-                        inputs, prompts[server_session.span.start : server_session.span.end], hypo_ids,
-                        step_id=step_id, start_from_position=start_from_position
+                        inputs,
+                        prompts[server_session.span.start : server_session.span.end],
+                        hypo_ids,
+                        step_id=step_id,
+                        start_from_position=start_from_position,
                     )
 
                     server_idx += 1

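For context, a minimal sketch of how the new keyword argument might be used from the client side. It assumes a Petals model exposing `inference_session(max_length=...)`; the `model` and `hidden_size` names and the tensor shapes below are illustrative assumptions, not part of this diff:

    import torch

    # Assumptions for illustration: `model` is a distributed Petals model
    # and `hidden_size` is its embedding width; neither is defined here.
    with model.inference_session(max_length=64) as sess:
        # Feed hidden states for 10 prefix tokens; the remote attention
        # caches now hold positions 0..9.
        prefix = torch.randn(1, 10, hidden_size)
        sess.step(prefix)

        # Rewind the remote caches to position 6 and overwrite the tail;
        # this is what the new `start_from_position` argument enables.
        sess.step(prefix[:, 6:], start_from_position=6)
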
+ 3 - 1
src/petals/server/block_functions.py

@@ -162,7 +162,9 @@ async def iterate_rpc_inference(
     async for request, step_metadata in input_iterator:
         if "start_from_position" in step_metadata:
             start_from_position = step_metadata["start_from_position"]
-            assert prefix_length >= start_from_position, f"prefix_length={prefix_length}, start_from_position={start_from_position}"
+            assert (
+                prefix_length >= start_from_position
+            ), f"prefix_length={prefix_length}, start_from_position={start_from_position}"
             prefix_length = start_from_position
 
         flat_tensors = tuple(deserialize_torch_tensor(tensor) for tensor in request.tensors)
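The server-side check only permits rewinding backwards: a step may reset `prefix_length` to an earlier position, but never past what is already cached, since positions beyond the cache were never written. A toy illustration of the invariant, with assumed values (a server holding 10 cached positions receiving a rewind to position 6):

    # Assumed values for illustration only.
    prefix_length = 10
    step_metadata = {"start_from_position": 6}

    if "start_from_position" in step_metadata:
        start_from_position = step_metadata["start_from_position"]
        # Rewinding past the cache would reference undefined positions,
        # hence the assertion.
        assert prefix_length >= start_from_position
        prefix_length = start_from_position  # later tokens overwrite 6..9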