Prechádzať zdrojové kódy

Rename Remote{TransformerBlock => Server}InferenceSession

Aleksandr Borzunov 2 rokov pred
rodič
commit
bd10d15e6e
2 zmenil súbory, kde vykonal 6 pridanie a 6 odobranie
  1. 1 1
      src/client/__init__.py
  2. 5 5
      src/client/inference_session.py

+ 1 - 1
src/client/__init__.py

@@ -1,4 +1,4 @@
-from src.client.inference_session import RemoteSequentialInferenceSession, RemoteTransformerBlockInferenceSession
+from src.client.inference_session import RemoteSequentialInferenceSession, RemoteServerInferenceSession
 from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
 from src.client.remote_sequential import RemoteSequential, RemoteTransformerBlock
 from src.client.sequence_manager import RemoteSequenceManager

+ 5 - 5
src/client/inference_session.py

@@ -28,9 +28,9 @@ use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 
-class RemoteTransformerBlockInferenceSession:
+class RemoteServerInferenceSession:
     """
-    An interface to a single multi-step *inference* session for a specific remote module on a specific server
+    An interface to a single multi-step *inference* session for a a set of blocks on a specific server.
 
     :note: this inference session is *not* fault-tolerant out of the box
     """
@@ -58,7 +58,7 @@ class RemoteTransformerBlockInferenceSession:
     @classmethod
     async def _create(
         cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: Optional[float] = None, **metadata
-    ) -> RemoteTransformerBlockInferenceSession:
+    ) -> RemoteServerInferenceSession:
         """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
         inputs_queue = asyncio.Queue()
         outputs_stream = await stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue, timeout), timeout=timeout)
@@ -165,7 +165,7 @@ class RemoteSequentialInferenceSession:
         self.closed = False
         self.chosen_spans: List[RemoteSpanInfo] = []
         self.stack = contextlib.ExitStack()
-        self.inference_sessions: List[RemoteTransformerBlockInferenceSession] = []
+        self.inference_sessions: List[RemoteServerInferenceSession] = []
         self.metadata = metadata
         self.timeout = timeout
 
@@ -179,7 +179,7 @@ class RemoteSequentialInferenceSession:
             stub = TransformerConnectionHandler.get_stub(self.p2p, chosen_span.peer_id)
             span_uids: str = CHAIN_DELIMITER.join(self.sequence_manager.block_uids[chosen_span.start : chosen_span.end])
             inference_session = RemoteExpertWorker.run_coroutine(
-                RemoteTransformerBlockInferenceSession._create(
+                RemoteServerInferenceSession._create(
                     stub, span_uids, rpc_info=self.sequence_manager.rpc_info, timeout=self.timeout, **self.metadata
                 )
             )