|
@@ -8,7 +8,7 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
|
|
|
from torch import nn
|
|
|
|
|
|
import src
|
|
|
-from src.client.inference_session import RemoteSequentialInferenceSession
|
|
|
+from src.client.inference_session import InferenceSession
|
|
|
from src.client.sequence_manager import RemoteSequenceManager
|
|
|
from src.client.sequential_autograd import _RemoteSequentialAutogradFunction
|
|
|
from src.data_structures import UID_DELIMITER
|
|
@@ -80,9 +80,9 @@ class RemoteSequential(nn.Module):
|
|
|
def __len__(self):
|
|
|
return len(self.sequence_manager)
|
|
|
|
|
|
- def inference_session(self, **kwargs) -> RemoteSequentialInferenceSession:
|
|
|
+ def inference_session(self, **kwargs) -> InferenceSession:
|
|
|
self.sequence_manager.update_()
|
|
|
- return RemoteSequentialInferenceSession(self.sequence_manager, self.p2p, **kwargs)
|
|
|
+ return InferenceSession(self.sequence_manager, self.p2p, **kwargs)
|
|
|
|
|
|
def extra_repr(self) -> str:
|
|
|
return f"modules={self.sequence_manager.block_uids[0]}..{self.sequence_manager.block_uids[-1]}"
|