|
@@ -12,7 +12,7 @@ from torch import nn
|
|
|
|
|
|
import src
|
|
import src
|
|
from src.client.remote_block import RemoteTransformerBlock
|
|
from src.client.remote_block import RemoteTransformerBlock
|
|
-from src.client.remote_sequence_info import RemoteSequenceInfo
|
|
|
|
|
|
+from src.client.sequence_manager import RemoteSequenceManager
|
|
from src.data_structures import UID_DELIMITER
|
|
from src.data_structures import UID_DELIMITER
|
|
from src.dht_utils import _create_remote_modules_from_infos
|
|
from src.dht_utils import _create_remote_modules_from_infos
|
|
|
|
|
|
@@ -44,7 +44,7 @@ class RemoteSequential(nn.Module):
|
|
block_uids = tuple(f"{prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer))
|
|
block_uids = tuple(f"{prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer))
|
|
|
|
|
|
logger.debug(f"Remote block uids: {block_uids}")
|
|
logger.debug(f"Remote block uids: {block_uids}")
|
|
- self.remote_sequence_info = RemoteSequenceInfo(dht, block_uids)
|
|
|
|
|
|
+ self.remote_sequence_info = RemoteSequenceManager(dht, block_uids)
|
|
|
|
|
|
def forward(self, inputs: torch.Tensor):
|
|
def forward(self, inputs: torch.Tensor):
|
|
assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3 and inputs.shape[-1] == self.config.n_embed
|
|
assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3 and inputs.shape[-1] == self.config.n_embed
|
|
@@ -84,7 +84,7 @@ class RemoteSequential(nn.Module):
|
|
class RemoteSequentialInferenceSession:
|
|
class RemoteSequentialInferenceSession:
|
|
"""An interface to a multi-step *inference* session for a sequence of remote transformer blocks"""
|
|
"""An interface to a multi-step *inference* session for a sequence of remote transformer blocks"""
|
|
|
|
|
|
- def __init__(self, remote_sequence_info: RemoteSequenceInfo, p2p: P2P):
|
|
|
|
|
|
+ def __init__(self, remote_sequence_info: RemoteSequenceManager, p2p: P2P):
|
|
self.remote_sequence_info = remote_sequence_info
|
|
self.remote_sequence_info = remote_sequence_info
|
|
self.p2p = p2p
|
|
self.p2p = p2p
|
|
self.closed = False
|
|
self.closed = False
|