|
@@ -83,6 +83,17 @@ class _ServerInferenceSession:
|
|
if not next_input_message.uid and not next_input_message.tensors:
|
|
if not next_input_message.uid and not next_input_message.tensors:
|
|
break # this message means "done sending"
|
|
break # this message means "done sending"
|
|
|
|
|
|
|
|
+ @property
|
|
|
|
+ def position(self):
|
|
|
|
+ return self._position
|
|
|
|
+
|
|
|
|
+ @position.setter
|
|
|
|
+ def position(self, start_from_position: int):
|
|
|
|
+ assert start_from_position <= self._position
|
|
|
|
+ self._position = start_from_position
|
|
|
|
+ if self.history is not None and self.history.shape[1] >= start_from_position:
|
|
|
|
+ self.history = self.history[:, :start_from_position, :] if start_from_position > 0 else None
|
|
|
|
+
|
|
def step(
|
|
def step(
|
|
self,
|
|
self,
|
|
inputs: torch.Tensor,
|
|
inputs: torch.Tensor,
|
|
@@ -90,7 +101,6 @@ class _ServerInferenceSession:
|
|
hypo_ids: torch.LongTensor,
|
|
hypo_ids: torch.LongTensor,
|
|
*,
|
|
*,
|
|
step_id: str,
|
|
step_id: str,
|
|
- start_from_position: int,
|
|
|
|
) -> torch.Tensor:
|
|
) -> torch.Tensor:
|
|
"""
|
|
"""
|
|
Inference step: send a chunk of input tensors and receive a chunk of outputs
|
|
Inference step: send a chunk of input tensors and receive a chunk of outputs
|
|
@@ -100,12 +110,6 @@ class _ServerInferenceSession:
|
|
if self.closed:
|
|
if self.closed:
|
|
raise Exception("Session is closed, cannot perform step")
|
|
raise Exception("Session is closed, cannot perform step")
|
|
|
|
|
|
- if start_from_position is not None:
|
|
|
|
- assert start_from_position <= self._position
|
|
|
|
- self._position = start_from_position
|
|
|
|
- if self.history is not None and self.history.shape[1] >= start_from_position:
|
|
|
|
- self.history = self.history[:, :start_from_position, :] if start_from_position > 0 else None
|
|
|
|
-
|
|
|
|
n_input_tokens = inputs.shape[1]
|
|
n_input_tokens = inputs.shape[1]
|
|
if self.history is None:
|
|
if self.history is None:
|
|
self.history = inputs
|
|
self.history = inputs
|
|
@@ -127,8 +131,8 @@ class _ServerInferenceSession:
|
|
request_metadata = dict(session_id=self.session_id, step_id=step_id)
|
|
request_metadata = dict(session_id=self.session_id, step_id=step_id)
|
|
if not self.stepped:
|
|
if not self.stepped:
|
|
request_metadata.update(self.session_metadata)
|
|
request_metadata.update(self.session_metadata)
|
|
- if start_from_position is not None:
|
|
|
|
- request_metadata["start_from_position"] = start_from_position
|
|
|
|
|
|
+ if self._position is not None:
|
|
|
|
+ request_metadata["start_from_position"] = self._position
|
|
elif self.config.use_server_to_server:
|
|
elif self.config.use_server_to_server:
|
|
next_servers = self._collect_next_servers()
|
|
next_servers = self._collect_next_servers()
|
|
if next_servers:
|
|
if next_servers:
|
|
@@ -235,6 +239,13 @@ class InferenceSession:
|
|
def position(self) -> int:
|
|
def position(self) -> int:
|
|
return self._position
|
|
return self._position
|
|
|
|
|
|
|
|
+ @position.setter
|
|
|
|
+ def position(self, start_from_position: int) -> None:
|
|
|
|
+ self._position = start_from_position
|
|
|
|
+ for session in self._server_sessions:
|
|
|
|
+ assert isinstance(session, _ServerInferenceSession)
|
|
|
|
+ session.position = start_from_position
|
|
|
|
+
|
|
def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[_ServerInferenceSession]:
|
|
def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[_ServerInferenceSession]:
|
|
server_sessions = []
|
|
server_sessions = []
|
|
try:
|
|
try:
|
|
@@ -275,12 +286,8 @@ class InferenceSession:
|
|
inputs: torch.Tensor,
|
|
inputs: torch.Tensor,
|
|
prompts: Optional[torch.Tensor] = None,
|
|
prompts: Optional[torch.Tensor] = None,
|
|
hypo_ids: Optional[torch.Tensor] = None,
|
|
hypo_ids: Optional[torch.Tensor] = None,
|
|
- start_from_position: Optional[int] = None,
|
|
|
|
) -> torch.Tensor:
|
|
) -> torch.Tensor:
|
|
|
|
|
|
- if start_from_position is not None:
|
|
|
|
- self._position = start_from_position
|
|
|
|
-
|
|
|
|
assert not self._closed
|
|
assert not self._closed
|
|
if torch.is_grad_enabled():
|
|
if torch.is_grad_enabled():
|
|
logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
|
|
logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
|
|
@@ -329,7 +336,6 @@ class InferenceSession:
|
|
prompts[server_session.span.start : server_session.span.end],
|
|
prompts[server_session.span.start : server_session.span.end],
|
|
hypo_ids,
|
|
hypo_ids,
|
|
step_id=step_id,
|
|
step_id=step_id,
|
|
- start_from_position=start_from_position,
|
|
|
|
)
|
|
)
|
|
|
|
|
|
server_idx += 1
|
|
server_idx += 1
|