@@ -22,7 +22,7 @@ from petals.bloom.block import WrappedBloomBlock
from petals.server.block_utils import get_block_size
from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
-logger = get_logger(__file__)
+logger = get_logger(__name__)
CLIENT_BRANCH = "main"
BLOCK_BRANCH_PREFIX = "block_"
@@ -13,7 +13,7 @@ from hivemind import get_logger
from torch import nn
from transformers import BloomConfig
class LMHead(nn.Module):
@@ -13,7 +13,7 @@ from transformers.models.bloom.modeling_bloom import BloomModel
from petals.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH
from petals.client import DistributedBloomConfig
DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
@@ -8,7 +8,7 @@ from transformers.models.bloom.modeling_bloom import build_alibi_tensor
from petals.bloom.block import BloomBlock
logger.warning("inference_one_block will soon be deprecated in favour of tests!")
@@ -10,7 +10,7 @@ from petals.constants import PUBLIC_INITIAL_PEERS
from petals.server.server import Server
from petals.utils.version import validate_version
def main():
@@ -25,7 +25,7 @@ from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, R
from petals.server.handler import TransformerConnectionHandler
from petals.utils.misc import DUMMY, is_dummy
class _ServerInferenceSession:
@@ -15,7 +15,7 @@ from petals.utils.generation_algorithms import (
)
from petals.utils.generation_constraints import ABCBloomConstraint, EosConstraint
class RemoteGenerationMixin:
@@ -21,7 +21,7 @@ from petals.client.remote_sequential import RemoteSequential
from petals.constants import PUBLIC_INITIAL_PEERS
from petals.utils.misc import DUMMY
class DistributedBloomConfig(BloomConfig):
@@ -14,7 +14,7 @@ from petals.client.sequential_autograd import _RemoteSequentialAutogradFunction
from petals.data_structures import UID_DELIMITER
class RemoteSequential(nn.Module):
@@ -6,7 +6,7 @@ from hivemind import get_logger
from petals.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
T = TypeVar("T")
@@ -23,7 +23,7 @@ from petals.client.routing.spending_policy import NoSpendingPolicy
from petals.data_structures import ModuleUID, RemoteSpanInfo, ServerState
class RemoteSequenceManager:
@@ -18,7 +18,7 @@ from petals.data_structures import CHAIN_DELIMITER, RemoteSpanInfo
MAX_TOKENS_IN_BATCH = 1024
@@ -15,7 +15,7 @@ from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger
import petals.client
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo, ServerState
def declare_active_modules(
@@ -20,7 +20,7 @@ from petals.server.memory_cache import Handle, MemoryCache
from petals.server.task_pool import PrioritizedTaskPool
from petals.utils.misc import is_dummy
class TransformerBackend(ModuleBackend):
@@ -8,7 +8,7 @@ from petals.data_structures import RemoteModuleInfo, ServerState
__all__ = ["choose_best_blocks", "should_choose_other_blocks"]
@dataclass
@@ -32,7 +32,7 @@ from petals.server.task_pool import PrioritizedTaskPool
from petals.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase
CACHE_TOKENS_AVAILABLE = "cache_tokens_available"
@@ -18,7 +18,7 @@ from hivemind.utils import TensorDescriptor, get_logger
from petals.utils.asyncio import shield_and_wait
Handle = int
@@ -31,7 +31,7 @@ from petals.server.throughput import get_dtype_name, get_host_throughput
from petals.utils.convert_block import check_device_balance, convert_block
from petals.utils.disk_cache import DEFAULT_CACHE_DIR
class Server:
@@ -12,7 +12,7 @@ from hivemind import get_logger
from hivemind.moe.server.task_pool import TaskPoolBase
from hivemind.utils.mpfuture import ALL_STATES, MPFuture
@dataclass(order=True, frozen=True)
@@ -16,7 +16,7 @@ from petals.server.block_utils import resolve_block_dtype
from petals.utils.convert_block import convert_block
try:
import speedtest
@@ -15,7 +15,7 @@ from transformers.models.bloom.modeling_bloom import BloomAttention
from petals.bloom.block import WrappedBloomBlock
use_hivemind_log_handler("in_root_logger")
def convert_block(
@@ -8,7 +8,7 @@ from typing import Optional
import huggingface_hub
from hivemind.utils.logging import get_logger
DEFAULT_CACHE_DIR = os.getenv("PETALS_CACHE", Path(Path.home(), ".cache", "petals"))
@@ -4,7 +4,7 @@ from packaging.version import parse
import petals
def validate_version():
@@ -8,7 +8,7 @@ from transformers.models.bloom import BloomForCausalLM
from petals.client.remote_model import DistributedBloomForCausalLM
@pytest.mark.forked
@@ -10,7 +10,7 @@ from petals.client import RemoteSequenceManager, RemoteSequential
from petals.client.remote_model import DistributedBloomConfig