Ver Fonte

Fix imports

Aleksandr Borzunov há 2 anos atrás
pai
commit
4f366f834e

+ 3 - 3
src/petals/cli/convert_model.py

@@ -7,11 +7,11 @@ import torch.nn as nn
 import transformers
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from huggingface_hub import Repository
-from src.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH
-from src.client import DistributedBloomConfig
 from tqdm.auto import tqdm
 
-from src import BloomModel
+from petals.bloom import BloomModel
+from petals.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH
+from petals.client import DistributedBloomConfig
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)

+ 12 - 12
src/petals/server/server.py

@@ -15,18 +15,18 @@ from hivemind.moe.server.layers import add_custom_models_from_file
 from hivemind.moe.server.runtime import Runtime
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
-from src.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
-from src.constants import PUBLIC_INITIAL_PEERS
-from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
-from src.dht_utils import get_remote_module_infos
-from src.server import block_selection
-from src.server.backend import TransformerBackend
-from src.server.cache import MemoryCache
-from src.server.handler import TransformerConnectionHandler
-from src.server.throughput import get_host_throughput
-from src.utils.convert_8bit import replace_8bit_linear
-
-from src import BloomConfig, declare_active_modules
+
+from petals.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
+from petals.bloom.model import BloomConfig
+from petals.constants import PUBLIC_INITIAL_PEERS
+from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
+from petals.dht_utils import declare_active_modules, get_remote_module_infos
+from petals.server import block_selection
+from petals.server.backend import TransformerBackend
+from petals.server.cache import MemoryCache
+from petals.server.handler import TransformerConnectionHandler
+from petals.server.throughput import get_host_throughput
+from petals.utils.convert_8bit import replace_8bit_linear
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)

+ 4 - 3
src/petals/server/throughput.py

@@ -10,9 +10,10 @@ from typing import Dict, Union
 
 import torch
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
-from src.bloom.block import BloomBlock
-from src.bloom.model import BloomConfig
-from src.bloom.ops import build_alibi_tensor
+
+from petals.bloom.block import BloomBlock
+from petals.bloom.model import BloomConfig
+from petals.bloom.ops import build_alibi_tensor
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)