Quellcode durchsuchen

Fix logging: do not duplicate lines, enable colors in Colab (#156)

Alexander Borzunov vor 2 Jahren
Ursprung
Commit
668b736031

+ 4 - 0
src/petals/__init__.py

@@ -1 +1,5 @@
+import petals.utils.logging
+
 __version__ = "1.0alpha1"
+
+petals.utils.logging.initialize_logs()

+ 1 - 2
src/petals/bloom/from_pretrained.py

@@ -13,7 +13,7 @@ import time
 from typing import Optional, OrderedDict, Union
 
 import torch
-from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from hivemind.utils.logging import get_logger
 from transformers.modeling_utils import WEIGHTS_NAME
 from transformers.models.bloom.configuration_bloom import BloomConfig
 from transformers.utils import get_file_from_repo
@@ -22,7 +22,6 @@ from petals.bloom.block import WrappedBloomBlock
 from petals.server.block_utils import get_block_size
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
 
-use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 CLIENT_BRANCH = "main"

+ 2 - 4
src/petals/bloom/modeling_utils.py

@@ -7,13 +7,11 @@ See commit history for authorship.
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
-from hivemind import use_hivemind_log_handler
+from hivemind import get_logger
 from torch import nn
 from transformers import BloomConfig
-from transformers.utils import logging
 
-use_hivemind_log_handler("in_root_logger")
-logger = logging.get_logger(__file__)
+logger = get_logger(__file__)
 
 
 class LMHead(nn.Module):

+ 1 - 2
src/petals/cli/convert_model.py

@@ -5,7 +5,7 @@ import psutil
 import torch.backends.quantized
 import torch.nn as nn
 import transformers
-from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from hivemind.utils.logging import get_logger
 from huggingface_hub import Repository
 from tqdm.auto import tqdm
 from transformers.models.bloom.modeling_bloom import BloomModel
@@ -13,7 +13,6 @@ from transformers.models.bloom.modeling_bloom import BloomModel
 from petals.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH
 from petals.client import DistributedBloomConfig
 
-use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")

+ 1 - 2
src/petals/cli/inference_one_block.py

@@ -1,14 +1,13 @@
 import argparse
 
 import torch
-from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from hivemind.utils.logging import get_logger
 from tqdm.auto import trange
 from transformers import BloomConfig
 from transformers.models.bloom.modeling_bloom import build_alibi_tensor
 
 from petals.bloom.block import BloomBlock
 
-use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 logger.warning("inference_one_block will soon be deprecated in favour of tests!")

+ 1 - 2
src/petals/cli/run_server.py

@@ -3,13 +3,12 @@ import argparse
 import configargparse
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.limits import increase_file_limit
-from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from hivemind.utils.logging import get_logger
 from humanfriendly import parse_size
 
 from petals.constants import PUBLIC_INITIAL_PEERS
 from petals.server.server import Server
 
-use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 

+ 1 - 6
src/petals/client/remote_model.py

@@ -5,7 +5,7 @@ from typing import List, Optional
 import hivemind
 import torch
 import torch.nn as nn
-from hivemind.utils.logging import get_logger, loglevel, use_hivemind_log_handler
+from hivemind.utils.logging import get_logger
 from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
 from transformers.models.bloom import (
     BloomConfig,
@@ -21,13 +21,8 @@ from petals.client.remote_sequential import RemoteSequential
 from petals.constants import PUBLIC_INITIAL_PEERS
 from petals.utils.misc import DUMMY
 
-use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
-# We suppress asyncio error logs by default since they are mostly not relevant for the end user
-asyncio_loglevel = os.getenv("PETALS_ASYNCIO_LOGLEVEL", "FATAL" if loglevel != "DEBUG" else "DEBUG")
-get_logger("asyncio").setLevel(asyncio_loglevel)
-
 
 class DistributedBloomConfig(BloomConfig):
     """

+ 1 - 2
src/petals/client/remote_sequential.py

@@ -3,7 +3,7 @@ from __future__ import annotations
 from typing import Optional, Union
 
 import torch
-from hivemind import DHT, P2P, get_logger, use_hivemind_log_handler
+from hivemind import DHT, P2P, get_logger
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from torch import nn
 
@@ -14,7 +14,6 @@ from petals.client.sequential_autograd import _RemoteSequentialAutogradFunction
 from petals.data_structures import UID_DELIMITER
 from petals.utils.misc import DUMMY
 
-use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 

+ 1 - 2
src/petals/client/routing/sequence_info.py

@@ -2,11 +2,10 @@ import dataclasses
 import time
 from typing import Iterable, List, Optional, Sequence, Tuple, Type, TypeVar
 
-from hivemind import get_logger, use_hivemind_log_handler
+from hivemind import get_logger
 
 from petals.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
 
-use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 

+ 1 - 2
src/petals/client/routing/sequence_manager.py

@@ -14,7 +14,7 @@ from hivemind.dht.node import Blacklist
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.p2p import P2PHandlerError
 from hivemind.proto import runtime_pb2
-from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from hivemind.utils.logging import get_logger
 
 import petals.dht_utils
 from petals.client.routing.sequence_info import RemoteSequenceInfo
@@ -22,7 +22,6 @@ from petals.client.routing.spending_policy import NoSpendingPolicy
 from petals.data_structures import ModuleUID, RemoteSpanInfo, ServerState
 from petals.server.handler import TransformerConnectionHandler
 
-use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 

+ 1 - 2
src/petals/dht_utils.py

@@ -10,12 +10,11 @@ from typing import Dict, List, Optional, Sequence, Union
 from hivemind.dht import DHT, DHTNode, DHTValue
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.p2p import PeerID
-from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger, use_hivemind_log_handler
+from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger
 
 import petals.client
 from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo, ServerState
 
-use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 

+ 1 - 2
src/petals/server/backend.py

@@ -2,7 +2,7 @@
 from typing import Any, Dict, Sequence, Tuple
 
 import torch
-from hivemind import BatchTensorDescriptor, use_hivemind_log_handler
+from hivemind import BatchTensorDescriptor
 from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.utils import get_logger
 
@@ -11,7 +11,6 @@ from petals.server.memory_cache import MemoryCache
 from petals.server.task_pool import PrioritizedTaskPool
 from petals.utils.misc import is_dummy
 
-use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 

+ 0 - 2
src/petals/server/memory_cache.py

@@ -14,10 +14,8 @@ from typing import AsyncContextManager, Dict, Optional, Union
 
 import hivemind
 import torch
-from hivemind import use_hivemind_log_handler
 from hivemind.utils import TensorDescriptor, get_logger
 
-use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 Handle = int

+ 1 - 2
src/petals/server/server.py

@@ -16,7 +16,7 @@ from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescripto
 from hivemind.moe.server.layers import add_custom_models_from_file
 from hivemind.moe.server.runtime import Runtime
 from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from hivemind.utils.logging import get_logger
 from transformers import BloomConfig
 
 from petals.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
@@ -32,7 +32,6 @@ from petals.server.throughput import get_host_throughput
 from petals.utils.convert_8bit import replace_8bit_linear
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
 
-use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 

+ 1 - 2
src/petals/server/task_pool.py

@@ -8,11 +8,10 @@ from queue import PriorityQueue
 from typing import Any, List, Optional, Sequence, Tuple
 
 import torch
-from hivemind import get_logger, use_hivemind_log_handler
+from hivemind import get_logger
 from hivemind.moe.server.task_pool import TaskPoolBase
 from hivemind.utils.mpfuture import ALL_STATES, MPFuture
 
-use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 

+ 1 - 2
src/petals/server/throughput.py

@@ -8,7 +8,7 @@ from pathlib import Path
 from typing import Optional, Union
 
 import torch
-from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from hivemind.utils.logging import get_logger
 from transformers import BloomConfig
 
 from petals.bloom.block import WrappedBloomBlock
@@ -16,7 +16,6 @@ from petals.server.block_utils import resolve_block_dtype
 from petals.utils.convert_8bit import replace_8bit_linear
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
 
-use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 

+ 34 - 0
src/petals/utils/logging.py

@@ -0,0 +1,34 @@
+import importlib
+import os
+
+from hivemind.utils import logging as hm_logging
+
+
+def in_jupyter() -> bool:
+    """Check if the code is run in Jupyter or Colab"""
+
+    try:
+        __IPYTHON__
+        return True
+    except NameError:
+        return False
+
+
+def initialize_logs():
+    """Initialize Petals logging tweaks. This function is called when you import the `petals` module."""
+
+    # Env var PETALS_LOGGING=False prohibits Petals do anything with logs
+    if os.getenv("PETALS_LOGGING", "True").lower() in ("false", "0"):
+        return
+
+    if in_jupyter():
+        os.environ["HIVEMIND_COLORS"] = "True"
+    importlib.reload(hm_logging)
+
+    hm_logging.get_logger().handlers.clear()  # Remove extra default handlers on Colab
+    hm_logging.use_hivemind_log_handler("in_root_logger")
+
+    # We suppress asyncio error logs by default since they are mostly not relevant for the end user,
+    # unless there is env var PETALS_ASYNCIO_LOGLEVEL
+    asyncio_loglevel = os.getenv("PETALS_ASYNCIO_LOGLEVEL", "FATAL" if hm_logging.loglevel != "DEBUG" else "DEBUG")
+    hm_logging.get_logger("asyncio").setLevel(asyncio_loglevel)

+ 1 - 2
tests/conftest.py

@@ -5,10 +5,9 @@ from contextlib import suppress
 import psutil
 import pytest
 from hivemind.utils.crypto import RSAPrivateKey
-from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from hivemind.utils.logging import get_logger
 from hivemind.utils.mpfuture import MPFuture
 
-use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
 
 

+ 1 - 2
tests/test_full_model.py

@@ -1,14 +1,13 @@
 import pytest
 import torch
 import transformers
-from hivemind import get_logger, use_hivemind_log_handler
+from hivemind import get_logger
 from test_utils import *
 from transformers.generation import BeamSearchScorer
 from transformers.models.bloom import BloomForCausalLM
 
 from petals.client.remote_model import DistributedBloomForCausalLM
 
-use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 

+ 1 - 2
tests/test_remote_sequential.py

@@ -1,6 +1,6 @@
 import pytest
 import torch
-from hivemind import DHT, BatchTensorDescriptor, get_logger, use_hivemind_log_handler
+from hivemind import DHT, BatchTensorDescriptor, get_logger
 from hivemind.proto import runtime_pb2
 from test_utils import *
 
@@ -9,7 +9,6 @@ from petals.client import RemoteSequenceManager, RemoteSequential
 from petals.client.remote_model import DistributedBloomConfig
 from petals.data_structures import UID_DELIMITER
 
-use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 

+ 1 - 2
tests/test_sequence_manager.py

@@ -3,14 +3,13 @@ import time
 
 import pytest
 import torch
-from hivemind import DHT, get_logger, use_hivemind_log_handler
+from hivemind import DHT, get_logger
 from test_utils import *
 
 from petals.client import RemoteSequenceManager, RemoteSequential
 from petals.client.remote_model import DistributedBloomConfig
 from petals.data_structures import UID_DELIMITER
 
-use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)