@@ -1,7 +1,9 @@
 from __future__ import annotations

 import asyncio
-import contextlib
+import itertools
+import logging
+import time
 from typing import AsyncIterator, List, Optional

 import torch
@@ -18,6 +20,7 @@ from hivemind import (
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.p2p import StubBase
 from hivemind.proto import runtime_pb2
+from hivemind.utils.asyncio import aiter_with_timeout

 from src.client.sequence_manager import RemoteSequenceManager
 from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
@@ -56,18 +59,22 @@ class RemoteServerInferenceSession:
         self.closed = False

     @classmethod
-    async def _create(
-        cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: Optional[float] = None, **metadata
+    async def create(
+        cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: float, **metadata
     ) -> RemoteServerInferenceSession:
         """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
         inputs_queue = asyncio.Queue()
-        outputs_stream = await stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue, timeout), timeout=timeout)
+        outputs_stream = await asyncio.wait_for(
+            stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)),
+            timeout,
+        )
+        outputs_stream = aiter_with_timeout(outputs_stream, timeout)
         return cls(uid, rpc_info, inputs_queue, outputs_stream, **metadata)

     @staticmethod
-    async def _read_inputs_from_queue(queue: asyncio.Queue, timeout: Optional[float]) -> AsyncIterator:
+    async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[float] = None) -> AsyncIterator:
         while True:
-            next_input_message = await asyncio.wait_for(queue.get(), timeout)
+            next_input_message = await asyncio.wait_for(queue.get(), input_timeout)
             yield next_input_message
             if not next_input_message.uid and not next_input_message.tensors:
                 break  # this message means "done sending"
@@ -159,33 +166,39 @@ class RemoteSequentialInferenceSession:
     An interface to a multi-step *inference* session for a sequence of remote transformer blocks
     """

-    def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, timeout: Optional[float] = None, **metadata):
+    def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, **metadata):
         self.sequence_manager = sequence_manager
         self.p2p = p2p
         self.closed = False
-        self.chosen_spans: List[RemoteSpanInfo] = []
-        self.stack = contextlib.ExitStack()
-        self.inference_sessions: List[RemoteServerInferenceSession] = []
+        self.chosen_spans = []
+        self.server_sessions = []
         self.metadata = metadata
-        self.timeout = timeout

-    def __enter__(self):
-        assert not self.closed and not self.chosen_spans
-        self.stack.__enter__()
-        # TODO(yozh) replace this code with a fault-tolerant chain that can be reconstructed if some peers fail
-        self.chosen_spans.extend(self.sequence_manager.make_sequence())
-
-        for chosen_span in self.chosen_spans:
-            stub = TransformerConnectionHandler.get_stub(self.p2p, chosen_span.peer_id)
-            span_uids: str = CHAIN_DELIMITER.join(self.sequence_manager.block_uids[chosen_span.start : chosen_span.end])
-            inference_session = RemoteExpertWorker.run_coroutine(
-                RemoteServerInferenceSession._create(
-                    stub, span_uids, rpc_info=self.sequence_manager.rpc_info, timeout=self.timeout, **self.metadata
+    def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[RemoteServerInferenceSession]:
+        server_sessions = []
+        for span in chosen_spans:
+            stub = TransformerConnectionHandler.get_stub(self.p2p, span.peer_id)
+            span_uids = CHAIN_DELIMITER.join(self.sequence_manager.block_uids[span.start : span.end])
+            session = RemoteExpertWorker.run_coroutine(
+                RemoteServerInferenceSession.create(
+                    stub, span_uids, rpc_info=self.sequence_manager.rpc_info, timeout=self.sequence_manager.timeout,
+                    **self.metadata
                 )
             )
-            self.inference_sessions.append(inference_session)
-            self.stack.enter_context(inference_session)
+            server_sessions.append(session)
+            session.__enter__()
+        return server_sessions

+    def _exit_server_sessions(self, server_sessions: List[RemoteServerInferenceSession], *, verbose: bool) -> None:
+        exc_loglevel = logging.WARNING if verbose else logging.DEBUG
+        for session in reversed(server_sessions):
+            try:
+                session.__exit__(None, None, None)
+            except Exception:
+                logger.log(exc_loglevel, "Caught exception while closing connection to server:", exc_info=True)
+
+    def __enter__(self):
+        assert not self.closed and not self.chosen_spans
         return self

     def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs):
@@ -196,17 +209,49 @@
             prompts = DUMMY
         else:
             assert prompts.ndim == 4 and prompts.shape[0] == len(self.sequence_manager)
-        for session in self.inference_sessions:
-            outputs = session.step(inputs, prompts[self.chosen_spans[0].start : self.chosen_spans[0].end], **kwargs)
-            assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
-            inputs = outputs
+
+        server_idx = 0
+        block_idx = 0
+        while block_idx < len(self.sequence_manager):
+            for attempt_no in itertools.count():
+                logger.debug(f"Inference: block {block_idx}, attempt {attempt_no}")
+                try:
+                    if not self.chosen_spans or not self.server_sessions or attempt_no >= 1:
+                        self._exit_server_sessions(self.server_sessions[server_idx:], verbose=False)
+                        self.server_sessions[server_idx:] = []
+                        self.chosen_spans[server_idx:] = []
+
+                        self.sequence_manager.update_()
+                        backup_spans = self.sequence_manager.make_sequence(block_idx)
+                        self.chosen_spans.extend(backup_spans)
+                        self.server_sessions.extend(self._enter_server_sessions(backup_spans))
+                        logger.debug(f"Found path from block {block_idx} via {len(backup_spans)} servers")
+
+                    session = self.server_sessions[server_idx]
+                    span = self.chosen_spans[server_idx]
+
+                    outputs = session.step(inputs, prompts[span.start : span.end], **kwargs)
+                    assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
+                    inputs = outputs
+
+                    server_idx += 1
+                    block_idx = span.end
+                    break
+                except Exception as e:
+                    delay = self.sequence_manager.min_backoff * 2**attempt_no
+                    logger.warning(
+                        f"Caught exception when running inference from block {block_idx} "
+                        f"(retry in {delay:.2f} sec): {repr(e)}"
+                    )
+                    logger.debug("See detailed traceback below:", exc_info=True)
+                    time.sleep(delay)
         return inputs

     def close(self, *exc_details):
         """Finish a given inference session, close the underlying connection"""
         if not self.closed:
-            self.stack.__exit__(*exc_details or (None, None, None))
-            self.inference_sessions.clear()
+            self._exit_server_sessions(self.server_sessions, verbose=True)
+            self.server_sessions.clear()
             self.closed = True

     def __exit__(self, *exc_details):
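
A minimal usage sketch of the fault-tolerant session introduced above (illustrative only, not part of the diff; sequence_manager, p2p, and hidden_states are assumed to be created by the surrounding RemoteSequential code):

    # Hypothetical driver code: step() now reroutes around failed servers and
    # retries with exponential backoff instead of aborting the whole session.
    session = RemoteSequentialInferenceSession(sequence_manager, p2p)
    with session:
        for hidden_step in hidden_states.split(1, dim=1):  # one token at a time
            hidden_step = session.step(hidden_step)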