Jelajahi Sumber

Fix imports

Aleksandr Borzunov 2 tahun lalu
induk
melakukan
59db85174e

+ 2 - 3
README.md

@@ -140,11 +140,10 @@ Once your have enough servers, you can use them to train and/or inference the mo
 ```python
 import torch
 import torch.nn.functional as F
-from petals.ansformers
-from src import DistributedBloomForCausalLM
+from petals import BloomTokenizerFast, DistributedBloomForCausalLM
 
 initial_peers = [TODO_put_one_or_more_server_addresses_here]  # e.g. ["/ip4/127.0.0.1/tcp/more/stuff/here"]
-tokenizer = transformers.BloomTokenizerFast.from_pretrained("bloom-testing/test-bloomd-560m-main")
+tokenizer = BloomTokenizerFast.from_pretrained("bloom-testing/test-bloomd-560m-main")
 model = DistributedBloomForCausalLM.from_pretrained(
   "bloom-testing/test-bloomd-560m-main", initial_peers=initial_peers, low_cpu_mem_usage=True, torch_dtype=torch.float32
 )  # this model has only embeddings / logits, all transformer blocks rely on remote servers

+ 1 - 1
cli/convert_model.py

@@ -9,7 +9,7 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from huggingface_hub import Repository
 from tqdm.auto import tqdm
 
-from petals.import BloomModel
+from petals import BloomModel
 from petals.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH
 from petals.client import DistributedBloomConfig
 

+ 2 - 2
petals/client/remote_sequential.py

@@ -7,7 +7,7 @@ from hivemind import DHT, P2P, get_logger, use_hivemind_log_handler
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from torch import nn
 
-import src
+import petals
 from petals.client.inference_session import InferenceSession
 from petals.client.sequence_manager import RemoteSequenceManager
 from petals.client.sequential_autograd import _RemoteSequentialAutogradFunction
@@ -25,7 +25,7 @@ class RemoteSequential(nn.Module):
 
     def __init__(
         self,
-        config: src.DistributedBloomConfig,
+        config: petals.DistributedBloomConfig,
         dht: DHT,
         dht_prefix: Optional[str] = None,
         p2p: Optional[P2P] = None,

+ 13 - 13
petals/dht_utils.py

@@ -12,7 +12,7 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.p2p import PeerID
 from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger, use_hivemind_log_handler
 
-import src
+import petals
 from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo, ServerState
 
 use_hivemind_log_handler("in_root_logger")
@@ -76,10 +76,10 @@ def get_remote_sequence(
     dht: DHT,
     start: int,
     stop: int,
-    config: src.DistributedBloomConfig,
+    config: petals.DistributedBloomConfig,
     dht_prefix: Optional[str] = None,
     return_future: bool = False,
-) -> Union[src.RemoteSequential, MPFuture]:
+) -> Union[petals.RemoteSequential, MPFuture]:
     return RemoteExpertWorker.run_coroutine(
         _get_remote_sequence(dht, start, stop, config, dht_prefix), return_future=return_future
     )
@@ -89,22 +89,22 @@ async def _get_remote_sequence(
     dht: DHT,
     start: int,
     stop: int,
-    config: src.DistributedBloomConfig,
+    config: petals.DistributedBloomConfig,
     dht_prefix: Optional[str] = None,
-) -> src.RemoteSequential:
+) -> petals.RemoteSequential:
     uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start, stop)]
     p2p = await dht.replicate_p2p()
-    manager = src.RemoteSequenceManager(dht, uids, p2p)
-    return src.RemoteSequential(config, dht, dht_prefix, p2p, manager)
+    manager = petals.RemoteSequenceManager(dht, uids, p2p)
+    return petals.RemoteSequential(config, dht, dht_prefix, p2p, manager)
 
 
 def get_remote_module(
     dht: DHT,
     uid_or_uids: Union[ModuleUID, List[ModuleUID]],
-    config: src.DistributedBloomConfig,
+    config: petals.DistributedBloomConfig,
     dht_prefix: Optional[str] = None,
     return_future: bool = False,
-) -> Union[Union[src.RemoteTransformerBlock, List[src.RemoteTransformerBlock]], MPFuture]:
+) -> Union[Union[petals.RemoteTransformerBlock, List[petals.RemoteTransformerBlock]], MPFuture]:
     """
     :param uid_or_uids: find one or more modules with these ids from across the DHT
     :param config: model config, usualy taken by .from_pretrained(MODEL_NAME)
@@ -119,15 +119,15 @@ def get_remote_module(
 async def _get_remote_module(
     dht: DHT,
     uid_or_uids: Union[ModuleUID, List[ModuleUID]],
-    config: src.DistributedBloomConfig,
+    config: petals.DistributedBloomConfig,
     dht_prefix: Optional[str] = None,
-) -> Union[src.RemoteTransformerBlock, List[src.RemoteTransformerBlock]]:
+) -> Union[petals.RemoteTransformerBlock, List[petals.RemoteTransformerBlock]]:
     single_uid = isinstance(uid_or_uids, ModuleUID)
     uids = [uid_or_uids] if single_uid else uid_or_uids
     p2p = await dht.replicate_p2p()
-    managers = (src.RemoteSequenceManager(dht, [uid], p2p) for uid in uids)
+    managers = (petals.RemoteSequenceManager(dht, [uid], p2p) for uid in uids)
     modules = [
-        src.RemoteTransformerBlock(config, dht, dht_prefix=dht_prefix, p2p=p2p, sequence_manager=m) for m in managers
+        petals.RemoteTransformerBlock(config, dht, dht_prefix=dht_prefix, p2p=p2p, sequence_manager=m) for m in managers
     ]
     return modules[0] if single_uid else modules
 

+ 1 - 1
petals/server/server.py

@@ -16,7 +16,7 @@ from hivemind.moe.server.runtime import Runtime
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
-from petals.import BloomConfig, declare_active_modules
+from petals import BloomConfig, declare_active_modules
 from petals.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
 from petals.constants import PUBLIC_INITIAL_PEERS
 from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState

+ 1 - 1
petals/server/throughput.py

@@ -11,7 +11,7 @@ from typing import Dict, Union
 import torch
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
-from petals.import project_name
+from petals import project_name
 from petals.bloom.block import BloomBlock
 from petals.bloom.model import BloomConfig
 from petals.bloom.ops import build_alibi_tensor

+ 3 - 3
petals/src/__init__.py

@@ -1,6 +1,6 @@
-from src.bloom import *
-from src.client import *
-from src.dht_utils import declare_active_modules, get_remote_module
+from petals.bloom import *
+from petals.client import *
+from petals.dht_utils import declare_active_modules, get_remote_module
 
 project_name = "bloomd"
 __version__ = "0.2"

+ 2 - 2
petals/src/bloom/__init__.py

@@ -1,2 +1,2 @@
-from src.bloom.block import BloomBlock
-from src.bloom.model import BloomConfig, BloomForCausalLM, BloomModel, BloomPreTrainedModel
+from petals.bloom.block import BloomBlock
+from petals.bloom.model import BloomConfig, BloomForCausalLM, BloomModel, BloomPreTrainedModel

+ 1 - 1
petals/src/bloom/block.py

@@ -9,7 +9,7 @@ import torch
 import torch.nn as nn
 import torch.nn.quantized.dynamic.modules.linear
 
-from src.bloom.ops import (
+from petals.bloom.ops import (
     BloomGelu,
     BloomScaledSoftmax,
     attention_mask_func,

+ 1 - 1
petals/src/bloom/from_pretrained.py

@@ -15,7 +15,7 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from transformers.modeling_utils import WEIGHTS_NAME
 from transformers.utils.hub import cached_path, hf_bucket_url
 
-from src.bloom import BloomBlock, BloomConfig
+from petals.bloom import BloomBlock, BloomConfig
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)

+ 1 - 1
petals/src/bloom/model.py

@@ -26,7 +26,7 @@ from transformers.models.bloom.configuration_bloom import BloomConfig
 from transformers.models.bloom.modeling_bloom import BloomPreTrainedModel
 from transformers.utils import logging
 
-from src.bloom.block import BloomBlock
+from petals.bloom.block import BloomBlock
 
 use_hivemind_log_handler("in_root_logger")
 logger = logging.get_logger(__file__)

+ 5 - 5
petals/src/client/__init__.py

@@ -1,5 +1,5 @@
-from src.client.inference_session import InferenceSession
-from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
-from src.client.remote_sequential import RemoteSequential, RemoteTransformerBlock
-from src.client.sequence_manager import RemoteSequenceManager
-from src.client.spending_policy import NoSpendingPolicy, SpendingPolicyBase
+from petals.client.inference_session import InferenceSession
+from petals.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
+from petals.client.remote_sequential import RemoteSequential, RemoteTransformerBlock
+from petals.client.sequence_manager import RemoteSequenceManager
+from petals.client.spending_policy import NoSpendingPolicy, SpendingPolicyBase

+ 4 - 4
petals/src/client/inference_session.py

@@ -20,10 +20,10 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.p2p import StubBase
 from hivemind.proto import runtime_pb2
 
-from src.client.sequence_manager import RemoteSequenceManager
-from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
-from src.server.handler import TransformerConnectionHandler
-from src.utils.misc import DUMMY, is_dummy
+from petals.client.sequence_manager import RemoteSequenceManager
+from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
+from petals.server.handler import TransformerConnectionHandler
+from petals.utils.misc import DUMMY, is_dummy
 
 logger = get_logger(__file__)
 

+ 1 - 1
petals/src/client/remote_forward_backward.py

@@ -13,7 +13,7 @@ from hivemind.proto import runtime_pb2
 from hivemind.utils.asyncio import aiter_with_timeout, iter_as_aiter
 from hivemind.utils.streaming import split_for_streaming
 
-from src.data_structures import ModuleUID, RPCInfo
+from petals.data_structures import ModuleUID, RPCInfo
 
 
 async def _forward_unary(

+ 2 - 2
petals/src/client/remote_generation.py

@@ -3,14 +3,14 @@ from typing import List, Optional
 import torch
 from hivemind.utils.logging import get_logger
 
-from src.utils.generation_algorithms import (
+from petals.utils.generation_algorithms import (
     BeamSearchAlgorithm,
     DecodingAlgorithm,
     GreedyAlgorithm,
     NucleusAlgorithm,
     TopKAlgorithm,
 )
-from src.utils.generation_constraints import ABCBloomConstraint, EosConstraint
+from petals.utils.generation_constraints import ABCBloomConstraint, EosConstraint
 
 logger = get_logger(__file__)
 

+ 5 - 5
petals/src/client/remote_model.py

@@ -7,7 +7,7 @@ import torch.nn as nn
 from hivemind import get_logger, use_hivemind_log_handler
 from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
 
-from src.bloom.model import (
+from petals.bloom.model import (
     BloomConfig,
     BloomForCausalLM,
     BloomForSequenceClassification,
@@ -15,10 +15,10 @@ from src.bloom.model import (
     BloomPreTrainedModel,
     LMHead,
 )
-from src.client.remote_generation import RemoteGenerationMixin
-from src.client.remote_sequential import RemoteSequential
-from src.constants import PUBLIC_INITIAL_PEERS
-from src.utils.misc import DUMMY
+from petals.client.remote_generation import RemoteGenerationMixin
+from petals.client.remote_sequential import RemoteSequential
+from petals.constants import PUBLIC_INITIAL_PEERS
+from petals.utils.misc import DUMMY
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)

+ 7 - 7
petals/src/client/remote_sequential.py

@@ -7,12 +7,12 @@ from hivemind import DHT, P2P, get_logger, use_hivemind_log_handler
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from torch import nn
 
-import src
-from src.client.inference_session import InferenceSession
-from src.client.sequence_manager import RemoteSequenceManager
-from src.client.sequential_autograd import _RemoteSequentialAutogradFunction
-from src.data_structures import UID_DELIMITER
-from src.utils.misc import DUMMY
+import petals
+from petals.client.inference_session import InferenceSession
+from petals.client.sequence_manager import RemoteSequenceManager
+from petals.client.sequential_autograd import _RemoteSequentialAutogradFunction
+from petals.data_structures import UID_DELIMITER
+from petals.utils.misc import DUMMY
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -25,7 +25,7 @@ class RemoteSequential(nn.Module):
 
     def __init__(
         self,
-        config: src.DistributedBloomConfig,
+        config: petals.DistributedBloomConfig,
         dht: DHT,
         dht_prefix: Optional[str] = None,
         p2p: Optional[P2P] = None,

+ 4 - 4
petals/src/client/sequence_manager.py

@@ -9,10 +9,10 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.proto import runtime_pb2
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
-from src.client.spending_policy import NoSpendingPolicy
-from src.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
-from src.dht_utils import get_remote_module_infos
-from src.server.handler import TransformerConnectionHandler
+from petals.client.spending_policy import NoSpendingPolicy
+from petals.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
+from petals.dht_utils import get_remote_module_infos
+from petals.server.handler import TransformerConnectionHandler
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)

+ 5 - 5
petals/src/client/sequential_autograd.py

@@ -11,11 +11,11 @@ import torch
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.utils.logging import get_logger
 
-from src.client.remote_forward_backward import run_remote_backward, run_remote_forward
-from src.client.sequence_manager import RemoteSequenceManager
-from src.data_structures import CHAIN_DELIMITER, RemoteSpanInfo
-from src.server.handler import TransformerConnectionHandler
-from src.utils.misc import DUMMY, is_dummy
+from petals.client.remote_forward_backward import run_remote_backward, run_remote_forward
+from petals.client.sequence_manager import RemoteSequenceManager
+from petals.data_structures import CHAIN_DELIMITER, RemoteSpanInfo
+from petals.server.handler import TransformerConnectionHandler
+from petals.utils.misc import DUMMY, is_dummy
 
 logger = get_logger(__file__)
 

+ 14 - 14
petals/src/dht_utils.py

@@ -12,8 +12,8 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.p2p import PeerID
 from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger, use_hivemind_log_handler
 
-import src
-from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo, ServerState
+import petals
+from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo, ServerState
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -76,10 +76,10 @@ def get_remote_sequence(
     dht: DHT,
     start: int,
     stop: int,
-    config: src.DistributedBloomConfig,
+    config: petals.DistributedBloomConfig,
     dht_prefix: Optional[str] = None,
     return_future: bool = False,
-) -> Union[src.RemoteSequential, MPFuture]:
+) -> Union[petals.RemoteSequential, MPFuture]:
     return RemoteExpertWorker.run_coroutine(
         _get_remote_sequence(dht, start, stop, config, dht_prefix), return_future=return_future
     )
@@ -89,22 +89,22 @@ async def _get_remote_sequence(
     dht: DHT,
     start: int,
     stop: int,
-    config: src.DistributedBloomConfig,
+    config: petals.DistributedBloomConfig,
     dht_prefix: Optional[str] = None,
-) -> src.RemoteSequential:
+) -> petals.RemoteSequential:
     uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start, stop)]
     p2p = await dht.replicate_p2p()
-    manager = src.RemoteSequenceManager(dht, uids, p2p)
-    return src.RemoteSequential(config, dht, dht_prefix, p2p, manager)
+    manager = petals.RemoteSequenceManager(dht, uids, p2p)
+    return petals.RemoteSequential(config, dht, dht_prefix, p2p, manager)
 
 
 def get_remote_module(
     dht: DHT,
     uid_or_uids: Union[ModuleUID, List[ModuleUID]],
-    config: src.DistributedBloomConfig,
+    config: petals.DistributedBloomConfig,
     dht_prefix: Optional[str] = None,
     return_future: bool = False,
-) -> Union[Union[src.RemoteTransformerBlock, List[src.RemoteTransformerBlock]], MPFuture]:
+) -> Union[Union[petals.RemoteTransformerBlock, List[petals.RemoteTransformerBlock]], MPFuture]:
     """
     :param uid_or_uids: find one or more modules with these ids from across the DHT
     :param config: model config, usualy taken by .from_pretrained(MODEL_NAME)
@@ -119,15 +119,15 @@ def get_remote_module(
 async def _get_remote_module(
     dht: DHT,
     uid_or_uids: Union[ModuleUID, List[ModuleUID]],
-    config: src.DistributedBloomConfig,
+    config: petals.DistributedBloomConfig,
     dht_prefix: Optional[str] = None,
-) -> Union[src.RemoteTransformerBlock, List[src.RemoteTransformerBlock]]:
+) -> Union[petals.RemoteTransformerBlock, List[petals.RemoteTransformerBlock]]:
     single_uid = isinstance(uid_or_uids, ModuleUID)
     uids = [uid_or_uids] if single_uid else uid_or_uids
     p2p = await dht.replicate_p2p()
-    managers = (src.RemoteSequenceManager(dht, [uid], p2p) for uid in uids)
+    managers = (petals.RemoteSequenceManager(dht, [uid], p2p) for uid in uids)
     modules = [
-        src.RemoteTransformerBlock(config, dht, dht_prefix=dht_prefix, p2p=p2p, sequence_manager=m) for m in managers
+        petals.RemoteTransformerBlock(config, dht, dht_prefix=dht_prefix, p2p=p2p, sequence_manager=m) for m in managers
     ]
     return modules[0] if single_uid else modules
 

+ 4 - 4
petals/src/server/backend.py

@@ -6,10 +6,10 @@ from hivemind import BatchTensorDescriptor, use_hivemind_log_handler
 from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.utils import get_logger
 
-from src.bloom.from_pretrained import BloomBlock
-from src.server.cache import MemoryCache
-from src.server.task_pool import PrioritizedTaskPool
-from src.utils.misc import is_dummy
+from petals.bloom.from_pretrained import BloomBlock
+from petals.server.cache import MemoryCache
+from petals.server.task_pool import PrioritizedTaskPool
+from petals.utils.misc import is_dummy
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)

+ 1 - 1
petals/src/server/block_selection.py

@@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Tuple
 import numpy as np
 from hivemind import PeerID, get_logger
 
-from src.data_structures import RemoteModuleInfo, ServerState
+from petals.data_structures import RemoteModuleInfo, ServerState
 
 __all__ = ["choose_best_blocks", "should_choose_other_blocks"]
 

+ 5 - 5
petals/src/server/handler.py

@@ -21,11 +21,11 @@ from hivemind.utils.asyncio import amap_in_executor, anext, as_aiter
 from hivemind.utils.logging import get_logger
 from hivemind.utils.streaming import split_for_streaming
 
-from src.data_structures import CHAIN_DELIMITER, ModuleUID
-from src.server.backend import TransformerBackend
-from src.server.task_pool import PrioritizedTaskPool
-from src.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase
-from src.utils.misc import DUMMY, is_dummy
+from petals.data_structures import CHAIN_DELIMITER, ModuleUID
+from petals.server.backend import TransformerBackend
+from petals.server.task_pool import PrioritizedTaskPool
+from petals.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase
+from petals.utils.misc import DUMMY, is_dummy
 
 logger = get_logger(__file__)
 

+ 11 - 11
petals/src/server/server.py

@@ -16,17 +16,17 @@ from hivemind.moe.server.runtime import Runtime
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
-from src import BloomConfig, declare_active_modules
-from src.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
-from src.constants import PUBLIC_INITIAL_PEERS
-from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
-from src.dht_utils import get_remote_module_infos
-from src.server import block_selection
-from src.server.backend import TransformerBackend
-from src.server.cache import MemoryCache
-from src.server.handler import TransformerConnectionHandler
-from src.server.throughput import get_host_throughput
-from src.utils.convert_8bit import replace_8bit_linear
+from petals import BloomConfig, declare_active_modules
+from petals.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
+from petals.constants import PUBLIC_INITIAL_PEERS
+from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
+from petals.dht_utils import get_remote_module_infos
+from petals.server import block_selection
+from petals.server.backend import TransformerBackend
+from petals.server.cache import MemoryCache
+from petals.server.handler import TransformerConnectionHandler
+from petals.server.throughput import get_host_throughput
+from petals.utils.convert_8bit import replace_8bit_linear
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)

+ 4 - 4
petals/src/server/throughput.py

@@ -11,10 +11,10 @@ from typing import Dict, Union
 import torch
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
-from src import project_name
-from src.bloom.block import BloomBlock
-from src.bloom.model import BloomConfig
-from src.bloom.ops import build_alibi_tensor
+from petals import project_name
+from petals.bloom.block import BloomBlock
+from petals.bloom.model import BloomConfig
+from petals.bloom.ops import build_alibi_tensor
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)

+ 2 - 2
tests/test_block_exact_match.py

@@ -7,8 +7,8 @@ import transformers
 from hivemind import P2PHandlerError
 from test_utils import *
 
-import src
-from petals.import DistributedBloomConfig
+import petals
+from petals import DistributedBloomConfig
 from petals.bloom.from_pretrained import load_pretrained_block
 from petals.client.remote_sequential import RemoteTransformerBlock
 from petals.data_structures import UID_DELIMITER

+ 3 - 3
tests/test_chained_calls.py

@@ -9,7 +9,7 @@ import pytest
 import torch
 from test_utils import *
 
-import src
+import petals
 from petals.bloom.from_pretrained import load_pretrained_block
 from petals.client.remote_sequential import RemoteSequential
 from petals.dht_utils import get_remote_sequence
@@ -18,7 +18,7 @@ from petals.dht_utils import get_remote_sequence
 @pytest.mark.forked
 def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
-    config = src.DistributedBloomConfig.from_pretrained(MODEL_NAME)
+    config = petals.DistributedBloomConfig.from_pretrained(MODEL_NAME)
     remote_blocks = get_remote_sequence(dht, 3, 6, config)
     assert isinstance(remote_blocks, RemoteSequential)
 
@@ -47,7 +47,7 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq
 @pytest.mark.forked
 def test_chained_inference_exact_match(atol_inference=1e-4):
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
-    config = src.DistributedBloomConfig.from_pretrained(MODEL_NAME)
+    config = petals.DistributedBloomConfig.from_pretrained(MODEL_NAME)
     remote_blocks = get_remote_sequence(dht, 3, 5, config)
     assert isinstance(remote_blocks, RemoteSequential)
 

+ 1 - 1
tests/test_remote_sequential.py

@@ -3,7 +3,7 @@ import torch
 from hivemind import DHT, get_logger, use_hivemind_log_handler
 from test_utils import *
 
-from petals.import RemoteSequential
+from petals import RemoteSequential
 from petals.bloom.from_pretrained import load_pretrained_block
 from petals.client.remote_model import DistributedBloomConfig