Ver Fonte

Rename RemoteTransformerBlockInferenceSession => _ServerInferenceSession

Aleksandr Borzunov há 2 anos atrás
pai
commit
a232f13869
2 ficheiros alterados com 7 adições e 9 exclusões
  1. 1 1
      src/client/__init__.py
  2. 6 8
      src/client/inference_session.py

+ 1 - 1
src/client/__init__.py

@@ -1,4 +1,4 @@
-from src.client.inference_session import RemoteSequentialInferenceSession, RemoteServerInferenceSession
+from src.client.inference_session import RemoteSequentialInferenceSession
 from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
 from src.client.remote_sequential import RemoteSequential, RemoteTransformerBlock
 from src.client.sequence_manager import RemoteSequenceManager

+ 6 - 8
src/client/inference_session.py

@@ -31,11 +31,11 @@ use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 
-class RemoteServerInferenceSession:
+class _ServerInferenceSession:
     """
     An interface to a single multi-step *inference* session for a a set of blocks on a specific server.
 
-    :note: this inference session is *not* fault-tolerant out of the box
+    :note: This class is *not* fault-tolerant out of the box.
     """
 
     def __init__(
@@ -50,8 +50,6 @@ class RemoteServerInferenceSession:
     ):
         self.uid, self.rpc_info = uid, rpc_info
         self.num_blocks = uid.count(CHAIN_DELIMITER) + 1
-        # warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
-        # using them in any other EventLoop may cause side-effects including, headaches, diarrhea, and loss of sleep
         self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
         self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
         self._serialized_metadata = MSGPackSerializer.dumps(dict(max_length=max_length, points=points))
@@ -61,7 +59,7 @@ class RemoteServerInferenceSession:
     @classmethod
     async def create(
         cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: float, **metadata
-    ) -> RemoteServerInferenceSession:
+    ) -> _ServerInferenceSession:
         """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
         inputs_queue = asyncio.Queue()
         outputs_stream = await asyncio.wait_for(
@@ -176,13 +174,13 @@ class RemoteSequentialInferenceSession:
         self._position = 0
         self._metadata = metadata
 
-    def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[RemoteServerInferenceSession]:
+    def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[_ServerInferenceSession]:
         server_sessions = []
         for span in chosen_spans:
             stub = TransformerConnectionHandler.get_stub(self._p2p, span.peer_id)
             span_uids = CHAIN_DELIMITER.join(self._sequence_manager.block_uids[span.start : span.end])
             session = RemoteExpertWorker.run_coroutine(
-                RemoteServerInferenceSession.create(
+                _ServerInferenceSession.create(
                     stub, span_uids, rpc_info=self._sequence_manager.rpc_info, timeout=self._sequence_manager.timeout,
                     **self._metadata
                 )
@@ -191,7 +189,7 @@ class RemoteSequentialInferenceSession:
             session.__enter__()
         return server_sessions
 
-    def _exit_server_sessions(self, server_sessions: List[RemoteServerInferenceSession], *, verbose: bool) -> None:
+    def _exit_server_sessions(self, server_sessions: List[_ServerInferenceSession], *, verbose: bool) -> None:
         exc_loglevel = logging.WARNING if verbose else logging.DEBUG
         for session in reversed(server_sessions):
             try: