|
@@ -14,8 +14,9 @@ from torch import nn
|
|
|
import src
|
|
|
from src.client.remote_block import RemoteTransformerBlock, RemoteTransformerBlockInferenceSession
|
|
|
from src.client.sequence_manager import RemoteSequenceManager
|
|
|
-from src.data_structures import UID_DELIMITER, RemoteSpanInfo
|
|
|
+from src.data_structures import UID_DELIMITER, RemoteSpanInfo, CHAIN_DELIMITER
|
|
|
from src.dht_utils import _create_remote_modules_from_infos
|
|
|
+from src.server.handler import TransformerConnectionHandler
|
|
|
|
|
|
use_hivemind_log_handler("in_root_logger")
|
|
|
logger = get_logger(__file__)
|
|
@@ -31,7 +32,6 @@ class RemoteSequential(nn.Module):
|
|
|
config: src.DistributedBloomConfig,
|
|
|
dht: DHT,
|
|
|
prefix: str,
|
|
|
- max_retries: int = 3,
|
|
|
p2p: Optional[P2P] = None,
|
|
|
sequence_manager: Optional[RemoteSequenceManager] = None,
|
|
|
):
|
|
@@ -47,13 +47,12 @@ class RemoteSequential(nn.Module):
|
|
|
self.config = config
|
|
|
self.dht = dht
|
|
|
self.prefix = prefix
|
|
|
- self.max_retries = max_retries
|
|
|
self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p()) if p2p is None else p2p
|
|
|
|
|
|
block_uids = [f"{prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)]
|
|
|
if sequence_manager is None:
|
|
|
logger.debug(f"Creating new sequence manager for block uids: {block_uids}")
|
|
|
- self.sequence_manager = RemoteSequenceManager(dht, block_uids)
|
|
|
+ self.sequence_manager = RemoteSequenceManager(dht, block_uids, self.p2p)
|
|
|
self.is_subsequence = False
|
|
|
else:
|
|
|
assert isinstance(sequence_manager.block_uids, list)
|
|
@@ -63,7 +62,7 @@ class RemoteSequential(nn.Module):
|
|
|
def forward(self, inputs: torch.Tensor):
|
|
|
assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3 and inputs.shape[-1] == self.config.n_embed
|
|
|
for block_index in range(self.config.n_layer):
|
|
|
- for retry_index in range(self.max_retries):
|
|
|
+ for retry_index in range(self.sequence_manager.max_retries):
|
|
|
try:
|
|
|
block = self[block_index]
|
|
|
(outputs,) = block(inputs)
|
|
@@ -72,7 +71,7 @@ class RemoteSequential(nn.Module):
|
|
|
inputs = outputs
|
|
|
break
|
|
|
except Exception as e:
|
|
|
- if retry_index == self.max_retries - 1:
|
|
|
+ if retry_index == self.sequence_manager.max_retries - 1:
|
|
|
raise e
|
|
|
else:
|
|
|
logging.debug(f"Caught {e} when running forward for block {block_index}", exc_info=True)
|
|
@@ -109,13 +108,14 @@ class RemoteSequential(nn.Module):
|
|
|
class RemoteSequentialInferenceSession:
|
|
|
"""An interface to a multi-step *inference* session for a sequence of remote transformer blocks"""
|
|
|
|
|
|
- def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P):
|
|
|
+ def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, timeout: Optional[float]=None):
|
|
|
self.sequence_manager = sequence_manager
|
|
|
self.p2p = p2p
|
|
|
self.closed = False
|
|
|
self.chosen_spans: List[RemoteSpanInfo] = []
|
|
|
self.stack = contextlib.ExitStack()
|
|
|
self.inference_sessions: List[RemoteTransformerBlockInferenceSession] = []
|
|
|
+ self.timeout = timeout
|
|
|
|
|
|
def __enter__(self):
|
|
|
assert not self.closed and not self.chosen_spans
|
|
@@ -124,17 +124,13 @@ class RemoteSequentialInferenceSession:
|
|
|
self.chosen_spans.extend(self.sequence_manager.make_sequence())
|
|
|
|
|
|
for chosen_span in self.chosen_spans:
|
|
|
- TransformerConnectionHandler.get_stub(self.p2p, self.peer_id)
|
|
|
-
|
|
|
- # TODO begin throwaway prototype code
|
|
|
- remote = RemoteTransformerBlock(self.sequence_manager.block_infos[current_block], self.p2p)
|
|
|
- _ = remote.info # TODO fix
|
|
|
- span_uids = self.sequence_manager.block_uids[current_block : chosen_span.end]
|
|
|
- remote._info = ExpertInfo(" ".join(span_uids), chosen_span.peer_id)
|
|
|
- self.inference_sessions.append(remote.inference_session())
|
|
|
- self.stack.enter_context(self.inference_sessions[-1])
|
|
|
- current_block = chosen_span.end
|
|
|
- # TODO end throwaway prototype code
|
|
|
+ stub = TransformerConnectionHandler.get_stub(self.p2p, chosen_span.peer_id)
|
|
|
+ span_uids: str = CHAIN_DELIMITER.join(self.sequence_manager.block_uids[chosen_span.start: chosen_span.end])
|
|
|
+ inference_session = RemoteExpertWorker.run_coroutine(RemoteTransformerBlockInferenceSession._create(
|
|
|
+ stub, span_uids, rpc_info=self.sequence_manager.rpc_info, timeout=self.timeout
|
|
|
+ ))
|
|
|
+ self.inference_sessions.append(inference_session)
|
|
|
+ self.stack.enter_context(inference_session)
|
|
|
|
|
|
return self
|
|
|
|