Aleksandr Borzunov 2 years ago
parent
current commit
384dc80919

+ 2 - 2
src/petals/cli/convert_model.py

@@ -7,11 +7,11 @@ import torch.nn as nn
 import transformers
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from huggingface_hub import Repository
+from src.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH
+from src.client import DistributedBloomConfig
 from tqdm.auto import tqdm
 
 from src import BloomModel
-from src.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH
-from src.client import DistributedBloomConfig
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)

+ 1 - 1
src/petals/dht_utils.py

@@ -12,7 +12,7 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.p2p import PeerID
 from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger, use_hivemind_log_handler
 
-from petals.client import DistributedBloomConfig, RemoteSequential, RemoteSequenceManager, RemoteTransformerBlock
+from petals.client import DistributedBloomConfig, RemoteSequenceManager, RemoteSequential, RemoteTransformerBlock
 from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo, ServerState
 
 use_hivemind_log_handler("in_root_logger")

+ 2 - 2
src/petals/server/server.py

@@ -15,8 +15,6 @@ from hivemind.moe.server.layers import add_custom_models_from_file
 from hivemind.moe.server.runtime import Runtime
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
-
-from src import BloomConfig, declare_active_modules
 from src.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
 from src.constants import PUBLIC_INITIAL_PEERS
 from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
@@ -28,6 +26,8 @@ from src.server.handler import TransformerConnectionHandler
 from src.server.throughput import get_host_throughput
 from src.utils.convert_8bit import replace_8bit_linear
 
+from src import BloomConfig, declare_active_modules
+
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 

+ 0 - 1
src/petals/server/throughput.py

@@ -10,7 +10,6 @@ from typing import Dict, Union
 
 import torch
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
-
 from src.bloom.block import BloomBlock
 from src.bloom.model import BloomConfig
 from src.bloom.ops import build_alibi_tensor

+ 1 - 1
tests/test_block_exact_match.py

@@ -5,8 +5,8 @@ import pytest
 import torch
 from test_utils import *
 
-from petals.client import DistributedBloomConfig
 from petals.bloom.from_pretrained import load_pretrained_block
+from petals.client import DistributedBloomConfig
 from petals.client.remote_sequential import RemoteTransformerBlock
 from petals.data_structures import UID_DELIMITER
 from petals.dht_utils import get_remote_module

+ 1 - 1
tests/test_remote_sequential.py

@@ -3,8 +3,8 @@ import torch
 from hivemind import DHT, get_logger, use_hivemind_log_handler
 from test_utils import *
 
-from petals.client import RemoteSequential
 from petals.bloom.from_pretrained import load_pretrained_block
+from petals.client import RemoteSequential
 from petals.client.remote_model import DistributedBloomConfig
 
 use_hivemind_log_handler("in_root_logger")