|
@@ -80,7 +80,7 @@ class _ServerInferenceSession:
|
|
|
new_hidden_states: torch.Tensor,
|
|
|
prompts: Optional[torch.Tensor] = None,
|
|
|
hypo_ids: Optional[torch.Tensor] = None,
|
|
|
- ):
|
|
|
+ ) -> torch.Tensor:
|
|
|
"""
|
|
|
Inference step: send a chunk of input tensors and receive a chunk of outputs
|
|
|
:prompts: optional DEEP prompts, added to a prefix of each layer's outputs,
|
|
@@ -203,11 +203,11 @@ class InferenceSession:
|
|
|
except Exception:
|
|
|
logger.debug("Caught exception while closing connection to server:", exc_info=True)
|
|
|
|
|
|
- def __enter__(self):
|
|
|
+ def __enter__(self) -> "InferenceSession":
|
|
|
assert not self._closed and not self._chosen_spans
|
|
|
return self
|
|
|
|
|
|
- def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs):
|
|
|
+ def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
|
|
|
assert not self._closed
|
|
|
if torch.is_grad_enabled():
|
|
|
logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
|