|
@@ -41,7 +41,7 @@ class RemoteSequential(nn.Sequential):
|
|
|
|
|
|
self.block_uids = tuple(f"{prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer))
|
|
|
logger.debug(f"Remote block uids: {self.block_uids}")
|
|
|
- self.remote_model_info = RemoteSequenceInfo(dht, self.block_uids)
|
|
|
+ self.remote_sequence_info = RemoteSequenceInfo(dht, self.block_uids)
|
|
|
|
|
|
def forward(self, inputs: torch.Tensor):
|
|
|
assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3 and inputs.shape[-1] == self.config.n_embed
|
|
@@ -63,7 +63,7 @@ class RemoteSequential(nn.Sequential):
|
|
|
|
|
|
def __getitem__(self, block_index: int):
|
|
|
assert 0 <= block_index < self.config.n_layer
|
|
|
- (module,) = _create_remote_modules_from_infos([self.block_infos[block_index]], self.p2p)
|
|
|
+ (module,) = _create_remote_modules_from_infos([self.remote_sequence_info.block_infos[block_index]], self.p2p)
|
|
|
return module
|
|
|
|
|
|
def __iter__(self):
|
|
@@ -71,9 +71,8 @@ class RemoteSequential(nn.Sequential):
|
|
|
yield self[block_index]
|
|
|
|
|
|
def inference_session(self) -> RemoteSequentialInferenceSession:
|
|
|
- self.remote_model_info.update_()
|
|
|
- return RemoteSequentialInferenceSession(self.remote_model_info)
|
|
|
-
|
|
|
+ self.remote_sequence_info.update_()
|
|
|
+ return RemoteSequentialInferenceSession(self.remote_sequence_info)
|
|
|
|
|
|
|
|
|
|