Ver código fonte

fix import order

justheuristic 3 anos atrás
pai
commit
c792f50b9c

+ 4 - 4
src/__init__.py

@@ -1,5 +1,5 @@
-from .bloom import *
-from .client import RemoteTransformerBlock, RemoteTransformerBlockInferenceSession
-from .dht_utils import declare_active_modules, get_remote_module
+from src.bloom import *
+from src.client import *
+from src.dht_utils import declare_active_modules, get_remote_module
 
-__version__ = "0.1"
+__version__ = "0.2"

+ 3 - 1
src/bloom/__init__.py

@@ -1 +1,3 @@
-from src.bloom.model import BloomBlock, BloomForYou, BloomModel, DistributedBloomConfig
+from src.bloom import *
+from src.client import *
+from src.server import *

+ 3 - 0
src/client/__init__.py

@@ -1 +1,4 @@
 from src.client.remote_block import RemoteTransformerBlock, RemoteTransformerBlockInferenceSession
+from src.client.remote_model import DistributedBloomConfig, DistributedBloomModel, DistributedBloomForCausalLM
+from src.client.remote_sequence_info import RemoteSequenceInfo
+from src.client.remote_sequential import RemoteSequential

+ 3 - 2
src/client/remote_sequential.py

@@ -10,7 +10,8 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.moe.expert_uid import ExpertInfo
 from torch import nn
 
-from src import DistributedBloomConfig, RemoteTransformerBlock
+import src
+from src.client.remote_block import RemoteTransformerBlock
 from src.client.remote_sequence_info import RemoteSequenceInfo
 from src.data_structures import UID_DELIMITER
 from src.dht_utils import _create_remote_modules_from_infos
@@ -24,7 +25,7 @@ class RemoteSequential(nn.Module):
     A sequence of transformer blocks hosted by the swarm.
     """
 
-    def __init__(self, config: DistributedBloomConfig, dht: DHT, prefix: str, max_retries: int = 3):
+    def __init__(self, config: src.DistributedBloomConfig, dht: DHT, prefix: str, max_retries: int = 3):
         logger.warning(f"{self.__class__.__name__} is in active development; expect adventures")
         if prefix.endswith(UID_DELIMITER):
             logger.warning(