Merge remote-tracking branch 'origin/master' into Checkpointer

Michael Diskin · 3 years ago
commit 9a35201e25

+ 4 - 1
benchmarks/benchmark_averaging.py

@@ -7,8 +7,11 @@ import torch
 
 import hivemind
 from hivemind.proto import runtime_pb2
-from hivemind.utils import LOCALHOST, get_logger, increase_file_limit
+from hivemind.utils.limits import increase_file_limit
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from hivemind.utils.networking import LOCALHOST
 
+use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
 
 

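This three-line setup (submodule imports, one use_hivemind_log_handler("in_root_logger") call, then a module-level logger) recurs in every script touched by this commit. In isolation it looks like the following sketch — the log message is illustrative, everything else comes from the hunk above:

    # Opt the whole process into hivemind's log handler, then log as usual.
    from hivemind.utils.logging import get_logger, use_hivemind_log_handler

    use_hivemind_log_handler("in_root_logger")  # attach the handler via the root logger
    logger = get_logger(__name__)

    logger.info("formatted by hivemind's log handler")
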
+ 3 - 1
benchmarks/benchmark_dht.py

@@ -7,8 +7,10 @@ from tqdm import trange
 import hivemind
 from hivemind.moe.server import declare_experts, get_experts
 from hivemind.utils.limits import increase_file_limit
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
-logger = hivemind.get_logger(__name__)
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__name__)
 
 
 def random_endpoint() -> hivemind.Endpoint:

+ 2 - 1
benchmarks/benchmark_tensor_compression.py

@@ -5,8 +5,9 @@ import torch
 
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.utils.logging import get_logger
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
+use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
 
 

+ 2 - 1
benchmarks/benchmark_throughput.py

@@ -10,8 +10,9 @@ import hivemind
 from hivemind import get_free_port
 from hivemind.moe.server import layers
 from hivemind.utils.limits import increase_file_limit
-from hivemind.utils.logging import get_logger
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
+use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
 
 

+ 1 - 1
examples/albert/run_trainer.py

@@ -23,7 +23,7 @@ import utils
 from arguments import AlbertTrainingArguments, AveragerArguments, CollaborationArguments, DatasetArguments
 
 use_hivemind_log_handler("in_root_logger")
-logger = get_logger()
+logger = get_logger(__name__)
 
 LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
 

+ 1 - 1
examples/albert/run_training_monitor.py

@@ -18,7 +18,7 @@ import utils
 from arguments import AveragerArguments, BaseTrainingArguments, CollaborativeOptimizerArguments
 
 use_hivemind_log_handler("in_root_logger")
-logger = get_logger()
+logger = get_logger(__name__)
 
 
 @dataclass

+ 2 - 1
hivemind/hivemind_cli/run_server.py

@@ -8,8 +8,9 @@ from hivemind.moe.server import Server
 from hivemind.moe.server.layers import schedule_name_to_scheduler
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.limits import increase_file_limit
-from hivemind.utils.logging import get_logger
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
+use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
 
 

+ 1 - 1
hivemind/utils/__init__.py

@@ -1,7 +1,7 @@
 from hivemind.utils.asyncio import *
 from hivemind.utils.grpc import *
 from hivemind.utils.limits import increase_file_limit
-from hivemind.utils.logging import get_logger
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from hivemind.utils.mpfuture import *
 from hivemind.utils.nested import *
 from hivemind.utils.networking import *

+ 9 - 2
hivemind/utils/logging.py

@@ -7,7 +7,7 @@ from typing import Optional, Union
 
 logging.addLevelName(logging.WARNING, "WARN")
 
-loglevel = os.getenv("LOGLEVEL", "INFO")
+loglevel = os.getenv("HIVEMIND_LOGLEVEL", "INFO")
 
 _env_colors = os.getenv("HIVEMIND_COLORS")
 if _env_colors is not None:
@@ -96,7 +96,12 @@ def _initialize_if_necessary():
 
 def get_logger(name: Optional[str] = None) -> logging.Logger:
     """
-    Same as ``logging.getLogger()`` but ensures that the default log handler is initialized.
+    Same as ``logging.getLogger()`` but ensures that the default hivemind log handler is initialized.
+
+    :note: By default, the hivemind log handler (that reads the ``HIVEMIND_LOGLEVEL`` env variable and uses
+           the colored log formatter) is only applied to messages logged inside the hivemind package.
+           If you want to extend this handler to other loggers in your application, call
+           ``use_hivemind_log_handler("in_root_logger")``.
     """
 
     _initialize_if_necessary()
@@ -138,6 +143,8 @@ def use_hivemind_log_handler(where: Union[HandlerMode, str]) -> None:
         # We allow `where` to be a string, so a developer does not have to import the enum for one usage
         where = HandlerMode[where.upper()]
 
+    _initialize_if_necessary()
+
     if where == _current_mode:
         return
 

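The new docstring above describes two behaviors worth spelling out: verbosity is now read from HIVEMIND_LOGLEVEL (previously LOGLEVEL), and by default the handler only formats messages from loggers under the hivemind package. A hedged sketch of both — the logger names and messages here are illustrative, not part of this diff:

    # Verbosity is now controlled by HIVEMIND_LOGLEVEL instead of LOGLEVEL, e.g.:
    #   $ HIVEMIND_LOGLEVEL=DEBUG python my_script.py
    from hivemind.utils.logging import get_logger, use_hivemind_log_handler

    lib_logger = get_logger("hivemind.dht")  # covered by the hivemind handler by default
    app_logger = get_logger(__name__)        # plain logging defaults until the call below

    use_hivemind_log_handler("in_root_logger")  # extend the handler to the root logger
    app_logger.info("now formatted like hivemind's own messages")
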
+ 2 - 1
tests/conftest.py

@@ -6,9 +6,10 @@ from contextlib import suppress
 import psutil
 import pytest
 
-from hivemind.utils.logging import get_logger
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from hivemind.utils.mpfuture import MPFuture, SharedBytes
 
+use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)