
Improve InferenceSession typing

Aleksandr Borzunov 2 years ago
parent
commit
90654da4e1
1 changed file with 3 additions and 3 deletions
  1. src/client/inference_session.py (+3 −3)

+ 3 - 3
src/client/inference_session.py

@@ -80,7 +80,7 @@ class _ServerInferenceSession:
         new_hidden_states: torch.Tensor,
         prompts: Optional[torch.Tensor] = None,
         hypo_ids: Optional[torch.Tensor] = None,
-    ):
+    ) -> torch.Tensor:
         """
         Inference step: send a chunk of input tensors and receive a chunk of outputs
         :prompts: optional DEEP prompts, added to a prefix of each layer's outputs,
@@ -203,11 +203,11 @@ class InferenceSession:
             except Exception:
                 logger.debug("Caught exception while closing connection to server:", exc_info=True)
 
-    def __enter__(self):
+    def __enter__(self) -> "InferenceSession":
         assert not self._closed and not self._chosen_spans
         return self
 
-    def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs):
+    def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
         assert not self._closed
         if torch.is_grad_enabled():
             logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")