Aleksandr Borzunov · 2 years ago
parent
commit ef1d73477e
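
This commit renames the project's top-level Python package from "src" to "petals", so every internal import switches to the new path. A minimal sketch of client code after the rename (an illustration, not part of the commit; the import paths and the MODEL_NAME / INITIAL_PEERS placeholders are taken from the hunks and tests below):

    import hivemind
    from petals.client import DistributedBloomConfig
    from petals.dht_utils import get_remote_sequence

    # Same calls as in tests/test_chained_calls.py, now against the petals package
    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
    config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
    remote_blocks = get_remote_sequence(dht, 3, 6, config)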

+ 2 - 2
src/petals/bloom/__init__.py

@@ -1,2 +1,2 @@
-from src.bloom.block import BloomBlock
-from src.bloom.model import BloomConfig, BloomForCausalLM, BloomModel, BloomPreTrainedModel
+from petals.bloom.block import BloomBlock
+from petals.bloom.model import BloomConfig, BloomForCausalLM, BloomModel, BloomPreTrainedModel

+ 1 - 1
src/petals/bloom/block.py

@@ -9,7 +9,7 @@ import torch
 import torch.nn as nn
 import torch.nn.quantized.dynamic.modules.linear
 
-from src.bloom.ops import (
+from petals.bloom.ops import (
     BloomGelu,
     BloomScaledSoftmax,
     attention_mask_func,

+ 1 - 1
src/petals/bloom/from_pretrained.py

@@ -15,7 +15,7 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from transformers.modeling_utils import WEIGHTS_NAME
 from transformers.utils.hub import cached_path, hf_bucket_url
 
-from src.bloom import BloomBlock, BloomConfig
+from petals.bloom import BloomBlock, BloomConfig
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)

+ 1 - 1
src/petals/bloom/model.py

@@ -26,7 +26,7 @@ from transformers.models.bloom.configuration_bloom import BloomConfig
 from transformers.models.bloom.modeling_bloom import BloomPreTrainedModel
 from transformers.utils import logging
 
-from src.bloom.block import BloomBlock
+from petals.bloom.block import BloomBlock
 
 use_hivemind_log_handler("in_root_logger")
 logger = logging.get_logger(__file__)

+ 3 - 3
src/petals/cli/inference_one_block.py

@@ -4,9 +4,9 @@ import torch
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from tqdm.auto import trange
 
-from src.bloom.block import BloomBlock
-from src.bloom.model import BloomConfig
-from src.bloom.ops import build_alibi_tensor
+from petals.bloom.block import BloomBlock
+from petals.bloom.model import BloomConfig
+from petals.bloom.ops import build_alibi_tensor
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)

+ 2 - 2
src/petals/cli/run_server.py

@@ -6,8 +6,8 @@ from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from humanfriendly import parse_size
 
-from src.constants import PUBLIC_INITIAL_PEERS
-from src.server.server import Server
+from petals.constants import PUBLIC_INITIAL_PEERS
+from petals.server.server import Server
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)

+ 5 - 5
src/petals/client/__init__.py

@@ -1,5 +1,5 @@
-from src.client.inference_session import InferenceSession
-from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
-from src.client.remote_sequential import RemoteSequential, RemoteTransformerBlock
-from src.client.sequence_manager import RemoteSequenceManager
-from src.client.spending_policy import NoSpendingPolicy, SpendingPolicyBase
+from petals.client.inference_session import InferenceSession
+from petals.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
+from petals.client.remote_sequential import RemoteSequential, RemoteTransformerBlock
+from petals.client.sequence_manager import RemoteSequenceManager
+from petals.client.spending_policy import NoSpendingPolicy, SpendingPolicyBase

+ 4 - 4
src/petals/client/inference_session.py

@@ -20,10 +20,10 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.p2p import StubBase
 from hivemind.proto import runtime_pb2
 
-from src.client.sequence_manager import RemoteSequenceManager
-from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
-from src.server.handler import TransformerConnectionHandler
-from src.utils.misc import DUMMY, is_dummy
+from petals.client.sequence_manager import RemoteSequenceManager
+from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
+from petals.server.handler import TransformerConnectionHandler
+from petals.utils.misc import DUMMY, is_dummy
 
 logger = get_logger(__file__)
 

+ 1 - 1
src/petals/client/remote_forward_backward.py

@@ -13,7 +13,7 @@ from hivemind.proto import runtime_pb2
 from hivemind.utils.asyncio import aiter_with_timeout, iter_as_aiter
 from hivemind.utils.streaming import split_for_streaming
 
-from src.data_structures import ModuleUID, RPCInfo
+from petals.data_structures import ModuleUID, RPCInfo
 
 
 async def _forward_unary(

+ 2 - 2
src/petals/client/remote_generation.py

@@ -3,14 +3,14 @@ from typing import List, Optional
 import torch
 from hivemind.utils.logging import get_logger
 
-from src.utils.generation_algorithms import (
+from petals.utils.generation_algorithms import (
     BeamSearchAlgorithm,
     DecodingAlgorithm,
     GreedyAlgorithm,
     NucleusAlgorithm,
     TopKAlgorithm,
 )
-from src.utils.generation_constraints import ABCBloomConstraint, EosConstraint
+from petals.utils.generation_constraints import ABCBloomConstraint, EosConstraint
 
 logger = get_logger(__file__)
 

+ 5 - 5
src/petals/client/remote_model.py

@@ -7,7 +7,7 @@ import torch.nn as nn
 from hivemind import get_logger, use_hivemind_log_handler
 from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
 
-from src.bloom.model import (
+from petals.bloom.model import (
     BloomConfig,
     BloomForCausalLM,
     BloomForSequenceClassification,
@@ -15,10 +15,10 @@ from src.bloom.model import (
     BloomPreTrainedModel,
     LMHead,
 )
-from src.client.remote_generation import RemoteGenerationMixin
-from src.client.remote_sequential import RemoteSequential
-from src.constants import PUBLIC_INITIAL_PEERS
-from src.utils.misc import DUMMY
+from petals.client.remote_generation import RemoteGenerationMixin
+from petals.client.remote_sequential import RemoteSequential
+from petals.constants import PUBLIC_INITIAL_PEERS
+from petals.utils.misc import DUMMY
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)

+ 7 - 7
src/petals/client/remote_sequential.py

@@ -7,12 +7,12 @@ from hivemind import DHT, P2P, get_logger, use_hivemind_log_handler
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from torch import nn
 
-import src
-from src.client.inference_session import InferenceSession
-from src.client.sequence_manager import RemoteSequenceManager
-from src.client.sequential_autograd import _RemoteSequentialAutogradFunction
-from src.data_structures import UID_DELIMITER
-from src.utils.misc import DUMMY
+from petals.client.inference_session import InferenceSession
+from petals.client.remote_model import DistributedBloomConfig
+from petals.client.sequence_manager import RemoteSequenceManager
+from petals.client.sequential_autograd import _RemoteSequentialAutogradFunction
+from petals.data_structures import UID_DELIMITER
+from petals.utils.misc import DUMMY
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -25,7 +25,7 @@ class RemoteSequential(nn.Module):
 
     def __init__(
         self,
-        config: src.DistributedBloomConfig,
+        config: DistributedBloomConfig,
         dht: DHT,
         dht_prefix: Optional[str] = None,
         p2p: Optional[P2P] = None,

+ 4 - 4
src/petals/client/sequence_manager.py

@@ -9,10 +9,10 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.proto import runtime_pb2
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
-from src.client.spending_policy import NoSpendingPolicy
-from src.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
-from src.dht_utils import get_remote_module_infos
-from src.server.handler import TransformerConnectionHandler
+from petals.client.spending_policy import NoSpendingPolicy
+from petals.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
+from petals.dht_utils import get_remote_module_infos
+from petals.server.handler import TransformerConnectionHandler
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)

+ 5 - 5
src/petals/client/sequential_autograd.py

@@ -11,11 +11,11 @@ import torch
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.utils.logging import get_logger
 
-from src.client.remote_forward_backward import run_remote_backward, run_remote_forward
-from src.client.sequence_manager import RemoteSequenceManager
-from src.data_structures import CHAIN_DELIMITER, RemoteSpanInfo
-from src.server.handler import TransformerConnectionHandler
-from src.utils.misc import DUMMY, is_dummy
+from petals.client.remote_forward_backward import run_remote_backward, run_remote_forward
+from petals.client.sequence_manager import RemoteSequenceManager
+from petals.data_structures import CHAIN_DELIMITER, RemoteSpanInfo
+from petals.server.handler import TransformerConnectionHandler
+from petals.utils.misc import DUMMY, is_dummy
 
 logger = get_logger(__file__)
 

+ 14 - 14
src/petals/dht_utils.py

@@ -12,8 +12,8 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.p2p import PeerID
 from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger, use_hivemind_log_handler
 
-import src
-from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo, ServerState
+from petals.client import DistributedBloomConfig, RemoteSequential, RemoteSequenceManager, RemoteTransformerBlock
+from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo, ServerState
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -76,10 +76,10 @@ def get_remote_sequence(
     dht: DHT,
     start: int,
     stop: int,
-    config: src.DistributedBloomConfig,
+    config: DistributedBloomConfig,
     dht_prefix: Optional[str] = None,
     return_future: bool = False,
-) -> Union[src.RemoteSequential, MPFuture]:
+) -> Union[RemoteSequential, MPFuture]:
     return RemoteExpertWorker.run_coroutine(
         _get_remote_sequence(dht, start, stop, config, dht_prefix), return_future=return_future
     )
@@ -89,22 +89,22 @@ async def _get_remote_sequence(
     dht: DHT,
     start: int,
     stop: int,
-    config: src.DistributedBloomConfig,
+    config: DistributedBloomConfig,
     dht_prefix: Optional[str] = None,
-) -> src.RemoteSequential:
+) -> RemoteSequential:
     uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start, stop)]
     p2p = await dht.replicate_p2p()
-    manager = src.RemoteSequenceManager(dht, uids, p2p)
-    return src.RemoteSequential(config, dht, dht_prefix, p2p, manager)
+    manager = RemoteSequenceManager(dht, uids, p2p)
+    return RemoteSequential(config, dht, dht_prefix, p2p, manager)
 
 
 def get_remote_module(
     dht: DHT,
     uid_or_uids: Union[ModuleUID, List[ModuleUID]],
-    config: src.DistributedBloomConfig,
+    config: DistributedBloomConfig,
     dht_prefix: Optional[str] = None,
     return_future: bool = False,
-) -> Union[Union[src.RemoteTransformerBlock, List[src.RemoteTransformerBlock]], MPFuture]:
+) -> Union[Union[RemoteTransformerBlock, List[RemoteTransformerBlock]], MPFuture]:
     """
     :param uid_or_uids: find one or more modules with these ids from across the DHT
     :param config: model config, usually taken from .from_pretrained(MODEL_NAME)
@@ -119,15 +119,15 @@ def get_remote_module(
 async def _get_remote_module(
     dht: DHT,
     uid_or_uids: Union[ModuleUID, List[ModuleUID]],
-    config: src.DistributedBloomConfig,
+    config: DistributedBloomConfig,
     dht_prefix: Optional[str] = None,
-) -> Union[src.RemoteTransformerBlock, List[src.RemoteTransformerBlock]]:
+) -> Union[RemoteTransformerBlock, List[RemoteTransformerBlock]]:
     single_uid = isinstance(uid_or_uids, ModuleUID)
     uids = [uid_or_uids] if single_uid else uid_or_uids
     p2p = await dht.replicate_p2p()
-    managers = (src.RemoteSequenceManager(dht, [uid], p2p) for uid in uids)
+    managers = (RemoteSequenceManager(dht, [uid], p2p) for uid in uids)
     modules = [
-        src.RemoteTransformerBlock(config, dht, dht_prefix=dht_prefix, p2p=p2p, sequence_manager=m) for m in managers
+        RemoteTransformerBlock(config, dht, dht_prefix=dht_prefix, p2p=p2p, sequence_manager=m) for m in managers
     ]
     return modules[0] if single_uid else modules
 

+ 4 - 4
src/petals/server/backend.py

@@ -6,10 +6,10 @@ from hivemind import BatchTensorDescriptor, use_hivemind_log_handler
 from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.utils import get_logger
 
-from src.bloom.from_pretrained import BloomBlock
-from src.server.cache import MemoryCache
-from src.server.task_pool import PrioritizedTaskPool
-from src.utils.misc import is_dummy
+from petals.bloom.from_pretrained import BloomBlock
+from petals.server.cache import MemoryCache
+from petals.server.task_pool import PrioritizedTaskPool
+from petals.utils.misc import is_dummy
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)

+ 1 - 1
src/petals/server/block_selection.py

@@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Tuple
 import numpy as np
 from hivemind import PeerID, get_logger
 
-from src.data_structures import RemoteModuleInfo, ServerState
+from petals.data_structures import RemoteModuleInfo, ServerState
 
 __all__ = ["choose_best_blocks", "should_choose_other_blocks"]
 

+ 5 - 5
src/petals/server/handler.py

@@ -21,11 +21,11 @@ from hivemind.utils.asyncio import amap_in_executor, anext, as_aiter
 from hivemind.utils.logging import get_logger
 from hivemind.utils.streaming import split_for_streaming
 
-from src.data_structures import CHAIN_DELIMITER, ModuleUID
-from src.server.backend import TransformerBackend
-from src.server.task_pool import PrioritizedTaskPool
-from src.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase
-from src.utils.misc import DUMMY, is_dummy
+from petals.data_structures import CHAIN_DELIMITER, ModuleUID
+from petals.server.backend import TransformerBackend
+from petals.server.task_pool import PrioritizedTaskPool
+from petals.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase
+from petals.utils.misc import DUMMY, is_dummy
 
 logger = get_logger(__file__)
 

+ 5 - 8
tests/test_block_exact_match.py

@@ -3,16 +3,13 @@ import random
 import hivemind
 import pytest
 import torch
-import transformers
-from hivemind import P2PHandlerError
 from test_utils import *
 
-import src
-from src import DistributedBloomConfig
-from src.bloom.from_pretrained import load_pretrained_block
-from src.client.remote_sequential import RemoteTransformerBlock
-from src.data_structures import UID_DELIMITER
-from src.dht_utils import get_remote_module
+from petals.client import DistributedBloomConfig
+from petals.bloom.from_pretrained import load_pretrained_block
+from petals.client.remote_sequential import RemoteTransformerBlock
+from petals.data_structures import UID_DELIMITER
+from petals.dht_utils import get_remote_module
 
 
 @pytest.mark.forked

+ 6 - 6
tests/test_chained_calls.py

@@ -9,16 +9,16 @@ import pytest
 import torch
 from test_utils import *
 
-import src
-from src.bloom.from_pretrained import load_pretrained_block
-from src.client.remote_sequential import RemoteSequential
-from src.dht_utils import get_remote_sequence
+from petals.bloom.from_pretrained import load_pretrained_block
+from petals.client import DistributedBloomConfig
+from petals.client.remote_sequential import RemoteSequential
+from petals.dht_utils import get_remote_sequence
 
 
 @pytest.mark.forked
 def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
-    config = src.DistributedBloomConfig.from_pretrained(MODEL_NAME)
+    config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
     remote_blocks = get_remote_sequence(dht, 3, 6, config)
     assert isinstance(remote_blocks, RemoteSequential)
 
@@ -47,7 +47,7 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq
 @pytest.mark.forked
 def test_chained_inference_exact_match(atol_inference=1e-4):
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
-    config = src.DistributedBloomConfig.from_pretrained(MODEL_NAME)
+    config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
     remote_blocks = get_remote_sequence(dht, 3, 5, config)
     assert isinstance(remote_blocks, RemoteSequential)
 

+ 2 - 2
tests/test_full_model.py

@@ -5,8 +5,8 @@ from hivemind import get_logger, use_hivemind_log_handler
 from test_utils import *
 from transformers.generation_utils import BeamSearchScorer
 
-from src.bloom.model import BloomForCausalLM
-from src.client.remote_model import DistributedBloomForCausalLM
+from petals.bloom.model import BloomForCausalLM
+from petals.client.remote_model import DistributedBloomForCausalLM
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)

+ 2 - 2
tests/test_priority_pool.py

@@ -4,8 +4,8 @@ import time
 import pytest
 import torch
 
-from src.server.runtime import Runtime
-from src.server.task_pool import PrioritizedTaskPool
+from petals.server.runtime import Runtime
+from petals.server.task_pool import PrioritizedTaskPool
 
 
 @pytest.mark.forked

+ 3 - 3
tests/test_remote_sequential.py

@@ -3,9 +3,9 @@ import torch
 from hivemind import DHT, get_logger, use_hivemind_log_handler
 from test_utils import *
 
-from src import RemoteSequential
-from src.bloom.from_pretrained import load_pretrained_block
-from src.client.remote_model import DistributedBloomConfig
+from petals.client import RemoteSequential
+from petals.bloom.from_pretrained import load_pretrained_block
+from petals.client.remote_model import DistributedBloomConfig
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)