|
@@ -3,20 +3,32 @@ import os
|
|
|
import sys
|
|
|
import threading
|
|
|
from enum import Enum
|
|
|
-from typing import Optional, Union
|
|
|
+from typing import Any, Optional, Union
|
|
|
|
|
|
-logging.addLevelName(logging.WARNING, "WARN")
|
|
|
|
|
|
+def in_ipython() -> bool:
|
|
|
+ """Check if the code is run in IPython, Jupyter, or Colab"""
|
|
|
+
|
|
|
+ try:
|
|
|
+ __IPYTHON__
|
|
|
+ return True
|
|
|
+ except NameError:
|
|
|
+ return False
|
|
|
+
|
|
|
+
|
|
|
+logging.addLevelName(logging.WARNING, "WARN")
|
|
|
loglevel = os.getenv("HIVEMIND_LOGLEVEL", "INFO")
|
|
|
|
|
|
+TRUE_CONSTANTS = ["TRUE", "1"]
|
|
|
+
|
|
|
_env_colors = os.getenv("HIVEMIND_COLORS")
|
|
|
if _env_colors is not None:
|
|
|
- use_colors = _env_colors.lower() == "true"
|
|
|
+ use_colors = _env_colors.upper() in TRUE_CONSTANTS
|
|
|
else:
|
|
|
- use_colors = sys.stderr.isatty()
|
|
|
+ use_colors = sys.stderr.isatty() or in_ipython()
|
|
|
|
|
|
-_env_log_caller = os.getenv("HIVEMIND_ALWAYS_LOG_CALLER")
|
|
|
-always_log_caller = _env_log_caller is not None and _env_log_caller.lower() == "true"
|
|
|
+_env_log_caller = os.getenv("HIVEMIND_ALWAYS_LOG_CALLER", "0")
|
|
|
+always_log_caller = _env_log_caller.upper() in TRUE_CONSTANTS
|
|
|
|
|
|
|
|
|
class HandlerMode(Enum):
|
|
@@ -30,7 +42,14 @@ _current_mode = HandlerMode.IN_HIVEMIND
|
|
|
_default_handler = None
|
|
|
|
|
|
|
|
|
-class TextStyle:
|
|
|
+class _DisableIfNoColors(type):
|
|
|
+ def __getattribute__(self, name: str) -> Any:
|
|
|
+ if name.isupper() and not use_colors:
|
|
|
+ return ""
|
|
|
+ return super().__getattribute__(name)
|
|
|
+
|
|
|
+
|
|
|
+class TextStyle(metaclass=_DisableIfNoColors):
|
|
|
"""
|
|
|
ANSI escape codes. Details: https://en.wikipedia.org/wiki/ANSI_escape_code#Colors
|
|
|
"""
|
|
@@ -42,11 +61,6 @@ class TextStyle:
|
|
|
PURPLE = "\033[35m"
|
|
|
ORANGE = "\033[38;5;208m" # From 8-bit palette
|
|
|
|
|
|
- if not use_colors:
|
|
|
- # Set the constants above to empty strings
|
|
|
- _codes = locals()
|
|
|
- _codes.update({_name: "" for _name in list(_codes) if _name.isupper()})
|
|
|
-
|
|
|
|
|
|
class CustomFormatter(logging.Formatter):
|
|
|
"""
|
|
@@ -115,14 +129,21 @@ def get_logger(name: Optional[str] = None) -> logging.Logger:
|
|
|
return logging.getLogger(name)
|
|
|
|
|
|
|
|
|
-def _enable_default_handler(name: str) -> None:
|
|
|
+def _enable_default_handler(name: Optional[str]) -> None:
|
|
|
logger = get_logger(name)
|
|
|
+
|
|
|
+ # Remove the extra default handler in the Colab's default logger before adding a new one
|
|
|
+ if isinstance(logger, logging.RootLogger):
|
|
|
+ for handler in list(logger.handlers):
|
|
|
+ if isinstance(handler, logging.StreamHandler) and handler.stream is sys.stderr:
|
|
|
+ logger.removeHandler(handler)
|
|
|
+
|
|
|
logger.addHandler(_default_handler)
|
|
|
logger.propagate = False
|
|
|
logger.setLevel(loglevel)
|
|
|
|
|
|
|
|
|
-def _disable_default_handler(name: str) -> None:
|
|
|
+def _disable_default_handler(name: Optional[str]) -> None:
|
|
|
logger = get_logger(name)
|
|
|
logger.removeHandler(_default_handler)
|
|
|
logger.propagate = True
|