
Make log handlers configurable, shorten entries (#378)

1. Fix bugs: make `get_logger()` idempotent and don't trim the actual logger name.
2. Allow a developer to choose where the default hivemind log handler is enabled: in hivemind, in the root logger, or nowhere (see the usage sketch below).
3. Enable the `in_root_logger` mode in `examples/albert`, so that all messages (from `__main__`, `transformers`, and `hivemind` itself) consistently follow the hivemind style.
4. Change some log messages to improve their presentation.

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
Alexander Borzunov 4 years ago
parent
commit
b84f62bc08
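
A minimal usage sketch of the new logging API, assuming only the names introduced in the diffs below:

    from hivemind.utils.logging import get_logger, use_hivemind_log_handler

    # Attach the hivemind handler to the root logger so that messages from
    # __main__, third-party libraries, and hivemind itself share one format.
    use_hivemind_log_handler("in_root_logger")

    # get_logger() is now idempotent and no longer trims the logger name.
    logger = get_logger(__name__)
    logger.info("This message follows the hivemind log style")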

+ 1 - 1
benchmarks/benchmark_averaging.py

@@ -80,7 +80,7 @@ def benchmark_averaging(
             with lock_stats:
                 successful_steps += int(success)
                 total_steps += 1
-            logger.info(f"Averager {index}: {'finished' if success else 'failed'} step {step}")
+            logger.info(f"Averager {index}: {'finished' if success else 'failed'} step #{step}")
         logger.info(f"Averager {index}: done.")
 
     threads = []

+ 11 - 21
examples/albert/run_trainer.py

@@ -1,6 +1,5 @@
 #!/usr/bin/env python
 
-import logging
 import os
 import pickle
 from dataclasses import asdict
@@ -18,32 +17,22 @@ from transformers.trainer import Trainer
 from transformers.trainer_utils import is_main_process
 
 import hivemind
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
 import utils
 from arguments import AlbertTrainingArguments, AveragerArguments, CollaborationArguments, DatasetArguments
 
-logger = logging.getLogger(__name__)
-LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger()
 
+LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
 
-def setup_logging(training_args):
-    logging.basicConfig(
-        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
-        datefmt="%m/%d/%Y %H:%M:%S",
-        level=logging.INFO if is_main_process(training_args.local_rank) else logging.WARN,
-    )
 
-    # Log on each process the small summary:
-    logger.warning(
-        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
-        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
-    )
-    # Set the verbosity to info of the Transformers logger (on main process only):
-    if is_main_process(training_args.local_rank):
+def setup_transformers_logging(process_rank: int):
+    if is_main_process(process_rank):
         transformers.utils.logging.set_verbosity_info()
-        transformers.utils.logging.enable_default_handler()
-        transformers.utils.logging.enable_explicit_format()
-    logger.info("Training/evaluation parameters %s", training_args)
+        transformers.utils.logging.disable_default_handler()
+        transformers.utils.logging.enable_propagation()
 
 
 def get_model(training_args, config, tokenizer):
@@ -149,7 +138,7 @@ class CollaborativeCallback(transformers.TrainerCallback):
                     loss=self.loss,
                     mini_steps=self.steps,
                 )
-                logger.info(f"Step {self.collaborative_optimizer.local_step}")
+                logger.info(f"Step #{self.collaborative_optimizer.local_step}")
                 logger.info(f"Your current contribution: {self.total_samples_processed} samples")
                 logger.info(f"Performance: {samples_per_second} samples per second.")
                 if self.steps:
@@ -220,7 +209,8 @@ def main():
     if len(collaboration_args.initial_peers) == 0:
         raise ValueError("Please specify at least one network endpoint in initial peers.")
 
-    setup_logging(training_args)
+    setup_transformers_logging(training_args.local_rank)
+    logger.info(f"Training/evaluation parameters:\n{training_args}")
 
     # Set seed before initializing model.
     set_seed(training_args.seed)

+ 4 - 3
examples/albert/run_training_monitor.py

@@ -1,6 +1,5 @@
 #!/usr/bin/env python
 
-import logging
 import time
 from dataclasses import asdict, dataclass, field
 from ipaddress import ip_address
@@ -13,11 +12,13 @@ from torch_optimizer import Lamb
 from transformers import AlbertConfig, AlbertForPreTraining, HfArgumentParser
 
 import hivemind
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
 import utils
 from arguments import AveragerArguments, BaseTrainingArguments, CollaborativeOptimizerArguments
 
-logger = logging.getLogger(__name__)
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger()
 
 
 @dataclass
@@ -139,7 +140,7 @@ class CheckpointHandler:
         self.model.push_to_hub(
             repo_name=self.repo_path,
             repo_url=self.repo_url,
-            commit_message=f"Step {current_step}, loss {current_loss:.3f}",
+            commit_message=f"Step #{current_step}, loss {current_loss:.3f}",
         )
         logger.info("Finished uploading to Model Hub")
 

+ 10 - 10
hivemind/optim/collaborative.py

@@ -153,7 +153,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         self.performance_ema = PerformanceEMA(alpha=performance_ema_alpha)
         self.last_step_time = None
 
-        self.collaboration_state = self.fetch_collaboration_state()
+        self.collaboration_state = self._fetch_state()
         self.lock_collaboration_state, self.collaboration_state_updated = Lock(), Event()
         self.lock_local_progress, self.should_report_progress = Lock(), Event()
         self.progress_reporter = Thread(target=self.report_training_progress, daemon=True, name=f"{self}.reporter")
@@ -237,8 +237,8 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         if not self.collaboration_state.ready_for_step:
             return
 
-        logger.log(self.status_loglevel, f"Beginning global optimizer step {self.collaboration_state.optimizer_step}")
-        self.collaboration_state = self.fetch_collaboration_state()
+        logger.log(self.status_loglevel, f"Beginning global optimizer step #{self.collaboration_state.optimizer_step}")
+        self.collaboration_state = self._fetch_state()
         self.collaboration_state_updated.set()
 
         if not self.is_synchronized:
@@ -288,8 +288,8 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         if not self.collaboration_state.ready_for_step:
             return
 
-        logger.log(self.status_loglevel, f"Beginning global optimizer step {self.collaboration_state.optimizer_step}")
-        self.collaboration_state = self.fetch_collaboration_state()
+        logger.log(self.status_loglevel, f"Beginning global optimizer step #{self.collaboration_state.optimizer_step}")
+        self.collaboration_state = self._fetch_state()
         self.collaboration_state_updated.set()
 
         with self.lock_collaboration_state:
@@ -392,9 +392,9 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                 continue  # if state was updated externally, reset timer
 
             with self.lock_collaboration_state:
-                self.collaboration_state = self.fetch_collaboration_state()
+                self.collaboration_state = self._fetch_state()
 
-    def fetch_collaboration_state(self) -> CollaborationState:
+    def _fetch_state(self) -> CollaborationState:
         """Read performance statistics reported by peers, estimate progress towards next batch"""
         response, _expiration = self.dht.get(self.training_progress_key, latest=True) or (None, -float("inf"))
         current_time = get_dht_time()
@@ -452,9 +452,9 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         )
         logger.log(
             self.status_loglevel,
-            f"Collaboration accumulated {total_samples_accumulated} samples from "
-            f"{num_peers} peers; ETA {estimated_time_to_next_step:.2f} seconds "
-            f"(refresh in {time_to_next_fetch:.2f}s.)",
+            f"{self.prefix} accumulated {total_samples_accumulated} samples from "
+            f"{num_peers} peers for step #{global_optimizer_step}. "
+            f"ETA {estimated_time_to_next_step:.2f} sec (refresh in {time_to_next_fetch:.2f} sec)",
         )
         return CollaborationState(
             global_optimizer_step,

+ 90 - 15
hivemind/utils/logging.py

@@ -1,6 +1,11 @@
 import logging
 import os
 import sys
+import threading
+from enum import Enum
+from typing import Optional, Union
+
+logging.addLevelName(logging.WARNING, "WARN")
 
 loglevel = os.getenv("LOGLEVEL", "INFO")
 
@@ -11,6 +16,17 @@ else:
     use_colors = sys.stderr.isatty()
 
 
+class HandlerMode(Enum):
+    NOWHERE = 0
+    IN_HIVEMIND = 1
+    IN_ROOT_LOGGER = 2
+
+
+_init_lock = threading.RLock()
+_current_mode = HandlerMode.IN_HIVEMIND
+_default_handler = None
+
+
 class TextStyle:
     """
     ANSI escape codes. Details: https://en.wikipedia.org/wiki/ANSI_escape_code#Colors
@@ -60,23 +76,82 @@ class CustomFormatter(logging.Formatter):
         return super().format(record)
 
 
-def get_logger(module_name: str) -> logging.Logger:
-    # trim package name
-    name_without_prefix = ".".join(module_name.split(".")[1:])
+def _initialize_if_necessary():
+    global _current_mode, _default_handler
 
-    logging.addLevelName(logging.WARNING, "WARN")
-    formatter = CustomFormatter(
-        fmt="{asctime}.{msecs:03.0f} [{bold}{levelcolor}{levelname}{reset}] [{bold}{caller}{reset}] {message}",
-        style="{",
-        datefmt="%b %d %H:%M:%S",
-    )
-    handler = logging.StreamHandler()
-    handler.setFormatter(formatter)
-    logger = logging.getLogger(name_without_prefix)
-    logger.setLevel(loglevel)
-    logger.addHandler(handler)
+    with _init_lock:
+        if _default_handler is not None:
+            return
+
+        formatter = CustomFormatter(
+            fmt="{asctime}.{msecs:03.0f} [{bold}{levelcolor}{levelname}{reset}] [{bold}{caller}{reset}] {message}",
+            style="{",
+            datefmt="%b %d %H:%M:%S",
+        )
+        _default_handler = logging.StreamHandler()
+        _default_handler.setFormatter(formatter)
+
+        _enable_default_handler("hivemind")
+
+
+def get_logger(name: Optional[str] = None) -> logging.Logger:
+    """
+    Same as ``logging.getLogger()`` but ensures that the default log handler is initialized.
+    """
+
+    _initialize_if_necessary()
+    return logging.getLogger(name)
+
+
+def _enable_default_handler(name: str) -> None:
+    logger = get_logger(name)
+    logger.addHandler(_default_handler)
     logger.propagate = False
-    return logger
+    logger.setLevel(loglevel)
+
+
+def _disable_default_handler(name: str) -> None:
+    logger = get_logger(name)
+    logger.removeHandler(_default_handler)
+    logger.propagate = True
+    logger.setLevel(logging.NOTSET)
+
+
+def use_hivemind_log_handler(where: Union[HandlerMode, str]) -> None:
+    """
+    Choose loggers where the default hivemind log handler is applied. Options for the ``where`` argument are:
+
+    * "in_hivemind" (default): Use the hivemind log handler in the loggers of the ``hivemind`` package.
+                               Don't propagate their messages to the root logger.
+    * "nowhere": Don't use the hivemind log handler anywhere.
+                 Propagate the ``hivemind`` messages to the root logger.
+    * "in_root_logger": Use the hivemind log handler in the root logger
+                        (that is, in all application loggers until they disable propagation to the root logger).
+                        Propagate the ``hivemind`` messages to the root logger.
+
+    The options may be defined as strings (case-insensitive) or values from the HandlerMode enum.
+    """
+
+    global _current_mode
+
+    if isinstance(where, str):
+        # We allow `where` to be a string, so a developer does not have to import the enum for one usage
+        where = HandlerMode[where.upper()]
+
+    if where == _current_mode:
+        return
+
+    if _current_mode == HandlerMode.IN_HIVEMIND:
+        _disable_default_handler("hivemind")
+    elif _current_mode == HandlerMode.IN_ROOT_LOGGER:
+        _disable_default_handler(None)
+
+    _current_mode = where
+
+    if _current_mode == HandlerMode.IN_HIVEMIND:
+        _enable_default_handler("hivemind")
+    elif _current_mode == HandlerMode.IN_ROOT_LOGGER:
+        _enable_default_handler(None)
 
 
 def golog_level_to_python(level: str) -> int:
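
To illustrate the docstring above, a short sketch of switching between the three handler modes; the string and enum forms are interchangeable, and "my_app" is a hypothetical application logger name:

    import logging

    from hivemind.utils.logging import HandlerMode, use_hivemind_log_handler

    use_hivemind_log_handler("in_hivemind")               # default: handler attached to the "hivemind" logger only
    use_hivemind_log_handler("nowhere")                   # detach the handler; hivemind messages propagate to the root logger
    use_hivemind_log_handler(HandlerMode.IN_ROOT_LOGGER)  # enum values are accepted as well as strings

    # In IN_ROOT_LOGGER mode, application loggers that propagate to the root logger
    # are formatted by the hivemind handler too:
    logging.getLogger("my_app").info("formatted in the hivemind style")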