|
@@ -84,8 +84,13 @@ class _ServerInferenceSession:
|
|
break # this message means "done sending"
|
|
break # this message means "done sending"
|
|
|
|
|
|
def step(
|
|
def step(
|
|
- self, inputs: torch.Tensor, prompts: torch.Tensor, hypo_ids: torch.LongTensor, *,
|
|
|
|
- step_id: str, start_from_position: int
|
|
|
|
|
|
+ self,
|
|
|
|
+ inputs: torch.Tensor,
|
|
|
|
+ prompts: torch.Tensor,
|
|
|
|
+ hypo_ids: torch.LongTensor,
|
|
|
|
+ *,
|
|
|
|
+ step_id: str,
|
|
|
|
+ start_from_position: int,
|
|
) -> torch.Tensor:
|
|
) -> torch.Tensor:
|
|
"""
|
|
"""
|
|
Inference step: send a chunk of input tensors and receive a chunk of outputs
|
|
Inference step: send a chunk of input tensors and receive a chunk of outputs
|
|
@@ -266,8 +271,11 @@ class InferenceSession:
|
|
return self
|
|
return self
|
|
|
|
|
|
def step(
|
|
def step(
|
|
- self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None,
|
|
|
|
- hypo_ids: Optional[torch.Tensor] = None, start_from_position: Optional[int] = None
|
|
|
|
|
|
+ self,
|
|
|
|
+ inputs: torch.Tensor,
|
|
|
|
+ prompts: Optional[torch.Tensor] = None,
|
|
|
|
+ hypo_ids: Optional[torch.Tensor] = None,
|
|
|
|
+ start_from_position: Optional[int] = None,
|
|
) -> torch.Tensor:
|
|
) -> torch.Tensor:
|
|
|
|
|
|
if start_from_position is not None:
|
|
if start_from_position is not None:
|
|
@@ -317,8 +325,11 @@ class InferenceSession:
|
|
|
|
|
|
server_session = self._server_sessions[server_idx]
|
|
server_session = self._server_sessions[server_idx]
|
|
inputs = server_session.step(
|
|
inputs = server_session.step(
|
|
- inputs, prompts[server_session.span.start : server_session.span.end], hypo_ids,
|
|
|
|
- step_id=step_id, start_from_position=start_from_position
|
|
|
|
|
|
+ inputs,
|
|
|
|
+ prompts[server_session.span.start : server_session.span.end],
|
|
|
|
+ hypo_ids,
|
|
|
|
+ step_id=step_id,
|
|
|
|
+ start_from_position=start_from_position,
|
|
)
|
|
)
|
|
|
|
|
|
server_idx += 1
|
|
server_idx += 1
|