@@ -110,6 +110,12 @@ class _ServerInferenceSession:
         if self.closed:
             raise Exception("Session is closed, cannot perform step")

+        if start_from_position is not None:
+            assert start_from_position <= self._position
+            self._position = start_from_position
+            if self.history is not None and self.history.shape[1] >= start_from_position:
+                self.history = self.history[:, :start_from_position, :] if start_from_position > 0 else None
+
         n_input_tokens = inputs.shape[1]
         if self.history is None:
             self.history = inputs
@@ -287,7 +293,6 @@ class InferenceSession:
         prompts: Optional[torch.Tensor] = None,
         hypo_ids: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
         assert not self._closed
         if torch.is_grad_enabled():
             logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")