Parcourir la source

fix reference

justheuristic il y a 3 ans
Parent
commit
a4bdce32c1
1 fichiers modifiés avec 4 ajouts et 5 suppressions
  1. 4 5
      src/client/remote_sequential.py

+ 4 - 5
src/client/remote_sequential.py

@@ -41,7 +41,7 @@ class RemoteSequential(nn.Sequential):
 
         self.block_uids = tuple(f"{prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer))
         logger.debug(f"Remote block uids: {self.block_uids}")
-        self.remote_model_info = RemoteSequenceInfo(dht, self.block_uids)
+        self.remote_sequence_info = RemoteSequenceInfo(dht, self.block_uids)
 
     def forward(self, inputs: torch.Tensor):
         assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3 and inputs.shape[-1] == self.config.n_embed
@@ -63,7 +63,7 @@ class RemoteSequential(nn.Sequential):
 
     def __getitem__(self, block_index: int):
         assert 0 <= block_index < self.config.n_layer
-        (module,) = _create_remote_modules_from_infos([self.block_infos[block_index]], self.p2p)
+        (module,) = _create_remote_modules_from_infos([self.remote_sequence_info.block_infos[block_index]], self.p2p)
         return module
 
     def __iter__(self):
@@ -71,9 +71,8 @@ class RemoteSequential(nn.Sequential):
             yield self[block_index]
 
     def inference_session(self) -> RemoteSequentialInferenceSession:
-        self.remote_model_info.update_()
-        return RemoteSequentialInferenceSession(self.remote_model_info)
-
+        self.remote_sequence_info.update_()
+        return RemoteSequentialInferenceSession(self.remote_sequence_info)