瀏覽代碼

Improve InferenceSession typing

Aleksandr Borzunov 2 年之前
父節點
當前提交
90654da4e1
共有 1 個文件被更改,包括 3 次插入、3 次刪除
  1. 3 3
      src/client/inference_session.py

+ 3 - 3
src/client/inference_session.py

@@ -80,7 +80,7 @@ class _ServerInferenceSession:
         new_hidden_states: torch.Tensor,
         prompts: Optional[torch.Tensor] = None,
         hypo_ids: Optional[torch.Tensor] = None,
-    ):
+    ) -> torch.Tensor:
         """
         Inference step: send a chunk of input tesors and receive a chunk of outputs
         :prompts: optional DEEP prompts, added to a prefix of each layer's outputs,
@@ -203,11 +203,11 @@ class InferenceSession:
             except Exception:
                 logger.debug("Caught exception while closing connection to server:", exc_info=True)
 
 
-    def __enter__(self):
+    def __enter__(self) -> "InferenceSession":
         assert not self._closed and not self._chosen_spans
         return self
 
 
-    def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs):
+    def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
         assert not self._closed
         if torch.is_grad_enabled():
             logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")