|
@@ -20,7 +20,7 @@ use_hivemind_log_handler("in_root_logger")
|
|
|
logger = get_logger(__file__)
|
|
|
|
|
|
|
|
|
-class RemoteSequential(nn.Sequential):
|
|
|
+class RemoteSequential(nn.Module):
|
|
|
"""
|
|
|
A sequence of transformer blocks hosted by the swarm.
|
|
|
"""
|
|
@@ -41,9 +41,9 @@ class RemoteSequential(nn.Sequential):
|
|
|
self.max_retries = max_retries
|
|
|
self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
|
|
|
|
|
|
- self.block_uids = tuple(f"{prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer))
|
|
|
- logger.debug(f"Remote block uids: {self.block_uids}")
|
|
|
- self.remote_sequence_info = RemoteSequenceInfo(dht, self.block_uids)
|
|
|
+ block_uids = tuple(f"{prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer))
|
|
|
+ logger.debug(f"Remote block uids: {block_uids}")
|
|
|
+ self.remote_sequence_info = RemoteSequenceInfo(dht, block_uids)
|
|
|
|
|
|
def forward(self, inputs: torch.Tensor):
|
|
|
assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3 and inputs.shape[-1] == self.config.n_embed
|
|
@@ -72,6 +72,9 @@ class RemoteSequential(nn.Sequential):
|
|
|
for block_index in range(self.config.n_layer):
|
|
|
yield self[block_index]
|
|
|
|
|
|
+ def __len__(self):
|
|
|
+ return len(self.remote_sequence_info)
|
|
|
+
|
|
|
def inference_session(self) -> RemoteSequentialInferenceSession:
|
|
|
self.remote_sequence_info.update_()
|
|
|
return RemoteSequentialInferenceSession(self.remote_sequence_info, self.p2p)
|