|
@@ -10,7 +10,8 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
|
|
|
from hivemind.moe.expert_uid import ExpertInfo
|
|
|
from torch import nn
|
|
|
|
|
|
-from src import DistributedBloomConfig, RemoteTransformerBlock
|
|
|
+import src
|
|
|
+from src.client.remote_block import RemoteTransformerBlock
|
|
|
from src.client.remote_sequence_info import RemoteSequenceInfo
|
|
|
from src.data_structures import UID_DELIMITER
|
|
|
from src.dht_utils import _create_remote_modules_from_infos
|
|
@@ -24,7 +25,7 @@ class RemoteSequential(nn.Module):
|
|
|
A sequence of transformer blocks hosted by the swarm.
|
|
|
"""
|
|
|
|
|
|
- def __init__(self, config: DistributedBloomConfig, dht: DHT, prefix: str, max_retries: int = 3):
|
|
|
+ def __init__(self, config: src.DistributedBloomConfig, dht: DHT, prefix: str, max_retries: int = 3):
|
|
|
logger.warning(f"{self.__class__.__name__} is in active development; expect adventures")
|
|
|
if prefix.endswith(UID_DELIMITER):
|
|
|
logger.warning(
|