Переглянути джерело

Don't import the client into top-level petals package

Aleksandr Borzunov 2 роки тому
батько
коміт
43e13e1a12

+ 1 - 1
README.md

@@ -140,7 +140,7 @@ Once you have enough servers, you can use them to train and/or inference the mo
 ```python
 import torch
 import torch.nn.functional as F
-from petals import BloomTokenizerFast, DistributedBloomForCausalLM
+from petals.client import BloomTokenizerFast, DistributedBloomForCausalLM
 
 initial_peers = [TODO_put_one_or_more_server_addresses_here]  # e.g. ["/ip4/127.0.0.1/tcp/more/stuff/here"]
 tokenizer = BloomTokenizerFast.from_pretrained("bloom-testing/test-bloomd-560m-main")

+ 0 - 4
src/petals/__init__.py

@@ -1,5 +1 @@
-from petals.bloom import *
-from petals.client import *
-from petals.dht_utils import declare_active_modules, get_remote_module
-
 __version__ = "1.0alpha1"

+ 1 - 1
src/petals/server/server.py

@@ -16,8 +16,8 @@ from hivemind.moe.server.runtime import Runtime
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
-from petals import BloomConfig, declare_active_modules
 from petals.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
+from petals.client import BloomConfig, declare_active_modules
 from petals.constants import PUBLIC_INITIAL_PEERS
 from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
 from petals.dht_utils import get_remote_module_infos

+ 1 - 2
tests/test_block_exact_match.py

@@ -7,8 +7,7 @@ import transformers
 from hivemind import P2PHandlerError
 from test_utils import *
 
-import petals
-from petals import DistributedBloomConfig
+from petals.client import DistributedBloomConfig
 from petals.bloom.from_pretrained import load_pretrained_block
 from petals.client.remote_sequential import RemoteTransformerBlock
 from petals.data_structures import UID_DELIMITER

+ 1 - 1
tests/test_remote_sequential.py

@@ -3,8 +3,8 @@ import torch
 from hivemind import DHT, get_logger, use_hivemind_log_handler
 from test_utils import *
 
-from petals import RemoteSequential
 from petals.bloom.from_pretrained import load_pretrained_block
+from petals.client import RemoteSequential
 from petals.client.remote_model import DistributedBloomConfig
 
 use_hivemind_log_handler("in_root_logger")