|
@@ -28,9 +28,9 @@ use_hivemind_log_handler("in_root_logger")
|
|
|
logger = get_logger(__file__)
|
|
|
|
|
|
|
|
|
-class RemoteTransformerBlockInferenceSession:
|
|
|
+class RemoteServerInferenceSession:
|
|
|
"""
|
|
|
- An interface to a single multi-step *inference* session for a specific remote module on a specific server
|
|
|
+ An interface to a single multi-step *inference* session for a a set of blocks on a specific server.
|
|
|
|
|
|
:note: this inference session is *not* fault-tolerant out of the box
|
|
|
"""
|
|
@@ -58,7 +58,7 @@ class RemoteTransformerBlockInferenceSession:
|
|
|
@classmethod
|
|
|
async def _create(
|
|
|
cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: Optional[float] = None, **metadata
|
|
|
- ) -> RemoteTransformerBlockInferenceSession:
|
|
|
+ ) -> RemoteServerInferenceSession:
|
|
|
"""Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
|
|
|
inputs_queue = asyncio.Queue()
|
|
|
outputs_stream = await stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue, timeout), timeout=timeout)
|
|
@@ -165,7 +165,7 @@ class RemoteSequentialInferenceSession:
|
|
|
self.closed = False
|
|
|
self.chosen_spans: List[RemoteSpanInfo] = []
|
|
|
self.stack = contextlib.ExitStack()
|
|
|
- self.inference_sessions: List[RemoteTransformerBlockInferenceSession] = []
|
|
|
+ self.inference_sessions: List[RemoteServerInferenceSession] = []
|
|
|
self.metadata = metadata
|
|
|
self.timeout = timeout
|
|
|
|
|
@@ -179,7 +179,7 @@ class RemoteSequentialInferenceSession:
|
|
|
stub = TransformerConnectionHandler.get_stub(self.p2p, chosen_span.peer_id)
|
|
|
span_uids: str = CHAIN_DELIMITER.join(self.sequence_manager.block_uids[chosen_span.start : chosen_span.end])
|
|
|
inference_session = RemoteExpertWorker.run_coroutine(
|
|
|
- RemoteTransformerBlockInferenceSession._create(
|
|
|
+ RemoteServerInferenceSession._create(
|
|
|
stub, span_uids, rpc_info=self.sequence_manager.rpc_info, timeout=self.timeout, **self.metadata
|
|
|
)
|
|
|
)
|