|
@@ -31,11 +31,11 @@ use_hivemind_log_handler("in_root_logger")
|
|
|
logger = get_logger(__file__)
|
|
|
|
|
|
|
|
|
-class RemoteServerInferenceSession:
|
|
|
+class _ServerInferenceSession:
|
|
|
"""
|
|
|
An interface to a single multi-step *inference* session for a set of blocks on a specific server.
|
|
|
|
|
|
- :note: this inference session is *not* fault-tolerant out of the box
|
|
|
+ :note: This class is *not* fault-tolerant out of the box.
|
|
|
"""
|
|
|
|
|
|
def __init__(
|
|
@@ -50,8 +50,6 @@ class RemoteServerInferenceSession:
|
|
|
):
|
|
|
self.uid, self.rpc_info = uid, rpc_info
|
|
|
self.num_blocks = uid.count(CHAIN_DELIMITER) + 1
|
|
|
- # warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
|
|
|
- # using them in any other EventLoop may cause side-effects including, headaches, diarrhea, and loss of sleep
|
|
|
self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
|
|
|
self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
|
|
|
self._serialized_metadata = MSGPackSerializer.dumps(dict(max_length=max_length, points=points))
|
|
@@ -61,7 +59,7 @@ class RemoteServerInferenceSession:
|
|
|
@classmethod
|
|
|
async def create(
|
|
|
cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: float, **metadata
|
|
|
- ) -> RemoteServerInferenceSession:
|
|
|
+ ) -> _ServerInferenceSession:
|
|
|
"""Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
|
|
|
inputs_queue = asyncio.Queue()
|
|
|
outputs_stream = await asyncio.wait_for(
|
|
@@ -176,13 +174,13 @@ class RemoteSequentialInferenceSession:
|
|
|
self._position = 0
|
|
|
self._metadata = metadata
|
|
|
|
|
|
- def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[RemoteServerInferenceSession]:
|
|
|
+ def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[_ServerInferenceSession]:
|
|
|
server_sessions = []
|
|
|
for span in chosen_spans:
|
|
|
stub = TransformerConnectionHandler.get_stub(self._p2p, span.peer_id)
|
|
|
span_uids = CHAIN_DELIMITER.join(self._sequence_manager.block_uids[span.start : span.end])
|
|
|
session = RemoteExpertWorker.run_coroutine(
|
|
|
- RemoteServerInferenceSession.create(
|
|
|
+ _ServerInferenceSession.create(
|
|
|
stub, span_uids, rpc_info=self._sequence_manager.rpc_info, timeout=self._sequence_manager.timeout,
|
|
|
**self._metadata
|
|
|
)
|
|
@@ -191,7 +189,7 @@ class RemoteSequentialInferenceSession:
|
|
|
session.__enter__()
|
|
|
return server_sessions
|
|
|
|
|
|
- def _exit_server_sessions(self, server_sessions: List[RemoteServerInferenceSession], *, verbose: bool) -> None:
|
|
|
+ def _exit_server_sessions(self, server_sessions: List[_ServerInferenceSession], *, verbose: bool) -> None:
|
|
|
exc_loglevel = logging.WARNING if verbose else logging.DEBUG
|
|
|
for session in reversed(server_sessions):
|
|
|
try:
|