Browse Source

Rename RemoteSequentialInferenceSession => InferenceSession

Aleksandr Borzunov 2 năm trước cách đây
mục cha
commit
8d47e38251

+ 1 - 1
src/client/__init__.py

@@ -1,4 +1,4 @@
-from src.client.inference_session import RemoteSequentialInferenceSession
+from src.client.inference_session import InferenceSession
 from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
 from src.client.remote_sequential import RemoteSequential, RemoteTransformerBlock
 from src.client.sequence_manager import RemoteSequenceManager

+ 1 - 1
src/client/inference_session.py

@@ -159,7 +159,7 @@ class _ServerInferenceSession:
         self.close()
 
 
-class RemoteSequentialInferenceSession:
+class InferenceSession:
     """
     An interface to a multi-step *inference* session for a sequence of remote transformer blocks
     """

+ 3 - 3
src/client/remote_sequential.py

@@ -8,7 +8,7 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from torch import nn
 
 import src
-from src.client.inference_session import RemoteSequentialInferenceSession
+from src.client.inference_session import InferenceSession
 from src.client.sequence_manager import RemoteSequenceManager
 from src.client.sequential_autograd import _RemoteSequentialAutogradFunction
 from src.data_structures import UID_DELIMITER
@@ -80,9 +80,9 @@ class RemoteSequential(nn.Module):
     def __len__(self):
         return len(self.sequence_manager)
 
-    def inference_session(self, **kwargs) -> RemoteSequentialInferenceSession:
+    def inference_session(self, **kwargs) -> InferenceSession:
         self.sequence_manager.update_()
-        return RemoteSequentialInferenceSession(self.sequence_manager, self.p2p, **kwargs)
+        return InferenceSession(self.sequence_manager, self.p2p, **kwargs)
 
     def extra_repr(self) -> str:
         return f"modules={self.sequence_manager.block_uids[0]}..{self.sequence_manager.block_uids[-1]}"