瀏覽代碼

Enable log handler in benchmarks and run_server (#380)

Currently, messages from the root logger are not displayed using the hivemind logger (in particular, HIVEMIND_LOGLEVEL is ignored for the root logger, so info-level messages are not visible). This PR fixes that by allowing the hivemind log handler to be installed in the root logger via use_hivemind_log_handler("in_root_logger").

Also, it renames LOGLEVEL to HIVEMIND_LOGLEVEL.

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
Alexander Borzunov 4 年之前
父節點
當前提交
fcad8d0d63

+ 4 - 1
benchmarks/benchmark_averaging.py

@@ -7,8 +7,11 @@ import torch
 
 import hivemind
 from hivemind.proto import runtime_pb2
-from hivemind.utils import LOCALHOST, get_logger, increase_file_limit
+from hivemind.utils.limits import increase_file_limit
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from hivemind.utils.networking import LOCALHOST
 
+use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
 
 

+ 3 - 1
benchmarks/benchmark_dht.py

@@ -7,8 +7,10 @@ from tqdm import trange
 import hivemind
 from hivemind.moe.server import declare_experts, get_experts
 from hivemind.utils.limits import increase_file_limit
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
-logger = hivemind.get_logger(__name__)
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__name__)
 
 
 def random_endpoint() -> hivemind.Endpoint:

+ 2 - 1
benchmarks/benchmark_tensor_compression.py

@@ -5,8 +5,9 @@ import torch
 
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.utils.logging import get_logger
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
+use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
 
 

+ 2 - 1
benchmarks/benchmark_throughput.py

@@ -10,8 +10,9 @@ import hivemind
 from hivemind import get_free_port
 from hivemind.moe.server import layers
 from hivemind.utils.limits import increase_file_limit
-from hivemind.utils.logging import get_logger
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
+use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
 
 

+ 1 - 1
examples/albert/run_trainer.py

@@ -23,7 +23,7 @@ import utils
 from arguments import AlbertTrainingArguments, AveragerArguments, CollaborationArguments, DatasetArguments
 
 use_hivemind_log_handler("in_root_logger")
-logger = get_logger()
+logger = get_logger(__name__)
 
 LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
 

+ 1 - 1
examples/albert/run_training_monitor.py

@@ -18,7 +18,7 @@ import utils
 from arguments import AveragerArguments, BaseTrainingArguments, CollaborativeOptimizerArguments
 
 use_hivemind_log_handler("in_root_logger")
-logger = get_logger()
+logger = get_logger(__name__)
 
 
 @dataclass

+ 2 - 1
hivemind/hivemind_cli/run_server.py

@@ -8,8 +8,9 @@ from hivemind.moe.server import Server
 from hivemind.moe.server.layers import schedule_name_to_scheduler
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.limits import increase_file_limit
-from hivemind.utils.logging import get_logger
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
+use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
 
 

+ 1 - 1
hivemind/utils/__init__.py

@@ -1,7 +1,7 @@
 from hivemind.utils.asyncio import *
 from hivemind.utils.grpc import *
 from hivemind.utils.limits import increase_file_limit
-from hivemind.utils.logging import get_logger
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from hivemind.utils.mpfuture import *
 from hivemind.utils.nested import *
 from hivemind.utils.networking import *

+ 9 - 2
hivemind/utils/logging.py

@@ -7,7 +7,7 @@ from typing import Optional, Union
 
 logging.addLevelName(logging.WARNING, "WARN")
 
-loglevel = os.getenv("LOGLEVEL", "INFO")
+loglevel = os.getenv("HIVEMIND_LOGLEVEL", "INFO")
 
 _env_colors = os.getenv("HIVEMIND_COLORS")
 if _env_colors is not None:
@@ -96,7 +96,12 @@ def _initialize_if_necessary():
 
 def get_logger(name: Optional[str] = None) -> logging.Logger:
     """
-    Same as ``logging.getLogger()`` but ensures that the default log handler is initialized.
+    Same as ``logging.getLogger()`` but ensures that the default hivemind log handler is initialized.
+
+    :note: By default, the hivemind log handler (that reads the ``HIVEMIND_LOGLEVEL`` env variable and uses
+           the colored log formatter) is only applied to messages logged inside the hivemind package.
+           If you want to extend this handler to other loggers in your application, call
+           ``use_hivemind_log_handler("in_root_logger")``.
     """
 
     _initialize_if_necessary()
@@ -138,6 +143,8 @@ def use_hivemind_log_handler(where: Union[HandlerMode, str]) -> None:
         # We allow `where` to be a string, so a developer does not have to import the enum for one usage
         where = HandlerMode[where.upper()]
 
+    _initialize_if_necessary()
+
     if where == _current_mode:
         return
 

+ 2 - 1
tests/conftest.py

@@ -6,9 +6,10 @@ from contextlib import suppress
 import psutil
 import pytest
 
-from hivemind.utils.logging import get_logger
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from hivemind.utils.mpfuture import MPFuture, SharedBytes
 
+use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)