
Merge remote-tracking branch 'origin/master' into Checkpointer

Michael Diskin 4 years ago
parent
commit
ba3906d498

+ 1 - 1
benchmarks/benchmark_averaging.py

@@ -80,7 +80,7 @@ def benchmark_averaging(
             with lock_stats:
                 successful_steps += int(success)
                 total_steps += 1
-            logger.info(f"Averager {index}: {'finished' if success else 'failed'} step {step}")
+            logger.info(f"Averager {index}: {'finished' if success else 'failed'} step #{step}")
         logger.info(f"Averager {index}: done.")
 
     threads = []

+ 11 - 21
examples/albert/run_trainer.py

@@ -1,6 +1,5 @@
 #!/usr/bin/env python
 
-import logging
 import os
 import pickle
 from dataclasses import asdict
@@ -18,32 +17,22 @@ from transformers.trainer import Trainer
 from transformers.trainer_utils import is_main_process
 
 import hivemind
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
 import utils
 from arguments import AlbertTrainingArguments, AveragerArguments, CollaborationArguments, DatasetArguments
 
-logger = logging.getLogger(__name__)
-LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger()
 
+LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
 
-def setup_logging(training_args):
-    logging.basicConfig(
-        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
-        datefmt="%m/%d/%Y %H:%M:%S",
-        level=logging.INFO if is_main_process(training_args.local_rank) else logging.WARN,
-    )
 
-    # Log on each process the small summary:
-    logger.warning(
-        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
-        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
-    )
-    # Set the verbosity to info of the Transformers logger (on main process only):
-    if is_main_process(training_args.local_rank):
+def setup_transformers_logging(process_rank: int):
+    if is_main_process(process_rank):
         transformers.utils.logging.set_verbosity_info()
-        transformers.utils.logging.enable_default_handler()
-        transformers.utils.logging.enable_explicit_format()
-    logger.info("Training/evaluation parameters %s", training_args)
+        transformers.utils.logging.disable_default_handler()
+        transformers.utils.logging.enable_propagation()
 
 
 def get_model(training_args, config, tokenizer):
@@ -149,7 +138,7 @@ class CollaborativeCallback(transformers.TrainerCallback):
                     loss=self.loss,
                     mini_steps=self.steps,
                 )
-                logger.info(f"Step {self.collaborative_optimizer.local_step}")
+                logger.info(f"Step #{self.collaborative_optimizer.local_step}")
                 logger.info(f"Your current contribution: {self.total_samples_processed} samples")
                 logger.info(f"Performance: {samples_per_second} samples per second.")
                 if self.steps:
@@ -220,7 +209,8 @@ def main():
     if len(collaboration_args.initial_peers) == 0:
         raise ValueError("Please specify at least one network endpoint in initial peers.")
 
-    setup_logging(training_args)
+    setup_transformers_logging(training_args.local_rank)
+    logger.info(f"Training/evaluation parameters:\n{training_args}")
 
     # Set seed before initializing model.
     set_seed(training_args.seed)

+ 4 - 3
examples/albert/run_training_monitor.py

@@ -1,6 +1,5 @@
 #!/usr/bin/env python
 
-import logging
 import time
 from dataclasses import asdict, dataclass, field
 from ipaddress import ip_address
@@ -13,11 +12,13 @@ from torch_optimizer import Lamb
 from transformers import AlbertConfig, AlbertForPreTraining, HfArgumentParser
 
 import hivemind
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
 import utils
 from arguments import AveragerArguments, BaseTrainingArguments, CollaborativeOptimizerArguments
 
-logger = logging.getLogger(__name__)
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger()
 
 
 @dataclass
@@ -139,7 +140,7 @@ class CheckpointHandler:
         self.model.push_to_hub(
             repo_name=self.repo_path,
             repo_url=self.repo_url,
-            commit_message=f"Step {current_step}, loss {current_loss:.3f}",
+            commit_message=f"Step #{current_step}, loss {current_loss:.3f}",
         )
         logger.info("Finished uploading to Model Hub")
 

+ 10 - 10
hivemind/optim/collaborative.py

@@ -153,7 +153,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         self.performance_ema = PerformanceEMA(alpha=performance_ema_alpha)
         self.last_step_time = None
 
-        self.collaboration_state = self.fetch_collaboration_state()
+        self.collaboration_state = self._fetch_state()
         self.lock_collaboration_state, self.collaboration_state_updated = Lock(), Event()
         self.lock_local_progress, self.should_report_progress = Lock(), Event()
         self.progress_reporter = Thread(target=self.report_training_progress, daemon=True, name=f"{self}.reporter")
@@ -248,8 +248,8 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         if not self.collaboration_state.ready_for_step:
             return
 
-        logger.log(self.status_loglevel, f"Beginning global optimizer step {self.collaboration_state.optimizer_step}")
-        self.collaboration_state = self.fetch_collaboration_state()
+        logger.log(self.status_loglevel, f"Beginning global optimizer step #{self.collaboration_state.optimizer_step}")
+        self.collaboration_state = self._fetch_state()
         self.collaboration_state_updated.set()
 
         if not self.is_synchronized:
@@ -299,8 +299,8 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         if not self.collaboration_state.ready_for_step:
             return
 
-        logger.log(self.status_loglevel, f"Beginning global optimizer step {self.collaboration_state.optimizer_step}")
-        self.collaboration_state = self.fetch_collaboration_state()
+        logger.log(self.status_loglevel, f"Beginning global optimizer step #{self.collaboration_state.optimizer_step}")
+        self.collaboration_state = self._fetch_state()
         self.collaboration_state_updated.set()
 
         with self.lock_collaboration_state:
@@ -403,9 +403,9 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                 continue  # if state was updated externally, reset timer
 
             with self.lock_collaboration_state:
-                self.collaboration_state = self.fetch_collaboration_state()
+                self.collaboration_state = self._fetch_state()
 
-    def fetch_collaboration_state(self) -> CollaborationState:
+    def _fetch_state(self) -> CollaborationState:
         """Read performance statistics reported by peers, estimate progress towards next batch"""
         response, _expiration = self.dht.get(self.training_progress_key, latest=True) or (None, -float("inf"))
         current_time = get_dht_time()
@@ -463,9 +463,9 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         )
         logger.log(
             self.status_loglevel,
-            f"Collaboration accumulated {total_samples_accumulated} samples from "
-            f"{num_peers} peers; ETA {estimated_time_to_next_step:.2f} seconds "
-            f"(refresh in {time_to_next_fetch:.2f}s.)",
+            f"{self.prefix} accumulated {total_samples_accumulated} samples from "
+            f"{num_peers} peers for step #{global_optimizer_step}. "
+            f"ETA {estimated_time_to_next_step:.2f} sec (refresh in {time_to_next_fetch:.2f} sec)",
         )
         return CollaborationState(
             global_optimizer_step,

+ 9 - 2
hivemind/p2p/p2p_daemon.py

@@ -15,7 +15,7 @@ from multiaddr import Multiaddr
 
 import hivemind.hivemind_cli as cli
 import hivemind.p2p.p2p_daemon_bindings.p2pclient as p2pclient
-from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError, P2PHandlerError
+from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, P2PDaemonError, P2PHandlerError
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
 from hivemind.proto.p2pd_pb2 import RPCError
 from hivemind.utils.asyncio import as_aiter, asingle
@@ -98,6 +98,7 @@ class P2P:
         use_relay: bool = True,
         use_relay_hop: bool = False,
         use_relay_discovery: bool = False,
+        persistent_conn_max_msg_size: int = DEFAULT_MAX_MSG_SIZE,
     ) -> "P2P":
         """
         Start a new p2pd process and connect to it.
@@ -168,6 +169,7 @@ class P2P:
             relayHop=use_relay_hop,
             relayHopLimit=relay_hop_limit,
             tls=tls,
+            persistentConnMaxMsgSize=persistent_conn_max_msg_size,
             **process_kwargs,
         )
 
@@ -189,7 +191,12 @@ class P2P:
             await self.shutdown()
             raise P2PDaemonError(f"Daemon failed to start in {startup_timeout:.1f} seconds")
 
-        self._client = await p2pclient.Client.create(self._daemon_listen_maddr, self._client_listen_maddr)
+        self._client = await p2pclient.Client.create(
+            control_maddr=self._daemon_listen_maddr,
+            listen_maddr=self._client_listen_maddr,
+            persistent_conn_max_msg_size=persistent_conn_max_msg_size,
+        )
+
         await self._ping_daemon()
         return self
 

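A minimal usage sketch of the new persistent_conn_max_msg_size keyword (the 16 MiB value and the asyncio entry point are illustrative, not part of the commit):

    import asyncio

    from hivemind.p2p import P2P


    async def main():
        # Raise the persistent-connection message limit from the 4 MiB default
        # (DEFAULT_MAX_MSG_SIZE) to 16 MiB for this daemon instance
        p2p = await P2P.create(persistent_conn_max_msg_size=16 * 1024 ** 2)
        try:
            ...  # register handlers / make unary calls here
        finally:
            await p2p.shutdown()


    asyncio.run(main())

The value is both passed to the daemon process (persistentConnMaxMsgSize) and forwarded to p2pclient.Client.create, so both ends of the persistent connection apply the same limit.
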
+ 26 - 6
hivemind/p2p/p2p_daemon_bindings/control.py

@@ -26,6 +26,8 @@ SUPPORT_CONN_PROTOCOLS = (
 SUPPORTED_PROTOS = (protocols.protocol_with_code(proto) for proto in SUPPORT_CONN_PROTOCOLS)
 logger = get_logger(__name__)
 
+DEFAULT_MAX_MSG_SIZE = 4 * 1024 ** 2
+
 
 def parse_conn_protocol(maddr: Multiaddr) -> int:
     proto_codes = set(proto.code for proto in maddr.protocols())
@@ -84,10 +86,13 @@ class ControlClient:
         daemon_connector: DaemonConnector,
         listen_maddr: Multiaddr = Multiaddr(DEFAULT_LISTEN_MADDR),
         *,
-        _initialized_with_create=False,
+        _initialized_with_create: bool = False,
+        persistent_conn_max_msg_size: int = DEFAULT_MAX_MSG_SIZE,
     ) -> None:
         assert _initialized_with_create, "Please use ControlClient.create coroutine to spawn new control instances"
 
+        self.persistent_conn_max_msg_size = persistent_conn_max_msg_size
+
         self.listen_maddr = listen_maddr
         self.daemon_connector = daemon_connector
         self.handlers: Dict[str, StreamHandler] = {}
@@ -107,8 +112,14 @@ class ControlClient:
         daemon_connector: DaemonConnector,
         listen_maddr: Multiaddr = Multiaddr(DEFAULT_LISTEN_MADDR),
         use_persistent_conn: bool = True,
+        persistent_conn_max_msg_size=2 << 22,
     ) -> "ControlClient":
-        control = cls(daemon_connector, listen_maddr, _initialized_with_create=True)
+        control = cls(
+            daemon_connector,
+            listen_maddr,
+            _initialized_with_create=True,
+            persistent_conn_max_msg_size=persistent_conn_max_msg_size,
+        )
 
         if use_persistent_conn:
             await control._ensure_persistent_conn()
@@ -207,12 +218,18 @@ class ControlClient:
         except Exception as e:
             response = p2pd_pb.CallUnaryResponse(error=repr(e).encode())
 
-        await self._pending_messages.put(
-            p2pd_pb.PersistentConnectionRequest(
+        payload = p2pd_pb.PersistentConnectionRequest(callId=call_id.bytes, unaryResponse=response)
+        if payload.ByteSize() <= self.persistent_conn_max_msg_size:
+            await self._pending_messages.put(payload)
+        else:
+            error_msg = p2pd_pb.PersistentConnectionRequest(
                 callId=call_id.bytes,
-                unaryResponse=response,
+                callUnaryResponse=p2pd_pb.CallUnaryResponse(
+                    error=b"response size exceeds message size limit",
+                ),
             )
-        )
+            await self._pending_messages.put(error_msg)
+
         self._handler_tasks.pop(call_id)
 
     async def _cancel_unary_call(self, call_id: UUID):
@@ -255,6 +272,9 @@ class ControlClient:
             callUnary=call_unary_req,
         )
 
+        if req.ByteSize() > self.persistent_conn_max_msg_size:
+            raise P2PDaemonError(f"Message size exceeds set limit {self.persistent_conn_max_msg_size}")
+
         try:
             self._pending_calls[call_id] = asyncio.Future()
             await self._pending_messages.put(req)

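Both new checks compare the serialized protobuf length (ByteSize()) against persistent_conn_max_msg_size: an oversized outgoing request is rejected with P2PDaemonError before it is queued, while an oversized handler response is replaced by a CallUnaryResponse carrying an error string. A standalone sketch of the guard pattern (the helper name and ValueError are illustrative stand-ins, not hivemind API):

    DEFAULT_MAX_MSG_SIZE = 4 * 1024 ** 2  # 4 MiB, matching the new default above

    def check_msg_size(msg, limit: int = DEFAULT_MAX_MSG_SIZE) -> None:
        # msg is any protobuf message; ByteSize() returns its serialized length
        if msg.ByteSize() > limit:
            raise ValueError(f"Message size {msg.ByteSize()} exceeds set limit {limit}")
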
+ 19 - 3
hivemind/p2p/p2p_daemon_bindings/p2pclient.py

@@ -10,7 +10,13 @@ from typing import AsyncIterator, Iterable, Sequence, Tuple
 
 from multiaddr import Multiaddr
 
-from hivemind.p2p.p2p_daemon_bindings.control import ControlClient, DaemonConnector, StreamHandler, TUnaryHandler
+from hivemind.p2p.p2p_daemon_bindings.control import (
+    DEFAULT_MAX_MSG_SIZE,
+    ControlClient,
+    DaemonConnector,
+    StreamHandler,
+    TUnaryHandler,
+)
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
 
 
@@ -22,11 +28,21 @@ class Client:
         self.control = None
 
     @classmethod
-    async def create(cls, control_maddr: Multiaddr = None, listen_maddr: Multiaddr = None) -> "Client":
+    async def create(
+        cls,
+        control_maddr: Multiaddr = None,
+        listen_maddr: Multiaddr = None,
+        *,
+        persistent_conn_max_msg_size: int = DEFAULT_MAX_MSG_SIZE,
+    ) -> "Client":
         client = cls(_initialized_with_create=True)
 
         daemon_connector = DaemonConnector(control_maddr=control_maddr)
-        client.control = await ControlClient.create(daemon_connector=daemon_connector, listen_maddr=listen_maddr)
+        client.control = await ControlClient.create(
+            daemon_connector=daemon_connector,
+            listen_maddr=listen_maddr,
+            persistent_conn_max_msg_size=persistent_conn_max_msg_size,
+        )
 
         return client
 

+ 90 - 15
hivemind/utils/logging.py

@@ -1,6 +1,11 @@
 import logging
 import os
 import sys
+import threading
+from enum import Enum
+from typing import Optional, Union
+
+logging.addLevelName(logging.WARNING, "WARN")
 
 loglevel = os.getenv("LOGLEVEL", "INFO")
 
@@ -11,6 +16,17 @@ else:
     use_colors = sys.stderr.isatty()
 
 
+class HandlerMode(Enum):
+    NOWHERE = 0
+    IN_HIVEMIND = 1
+    IN_ROOT_LOGGER = 2
+
+
+_init_lock = threading.RLock()
+_current_mode = HandlerMode.IN_HIVEMIND
+_default_handler = None
+
+
 class TextStyle:
     """
     ANSI escape codes. Details: https://en.wikipedia.org/wiki/ANSI_escape_code#Colors
@@ -60,23 +76,82 @@ class CustomFormatter(logging.Formatter):
         return super().format(record)
 
 
-def get_logger(module_name: str) -> logging.Logger:
-    # trim package name
-    name_without_prefix = ".".join(module_name.split(".")[1:])
+def _initialize_if_necessary():
+    global _current_mode, _default_handler
 
-    logging.addLevelName(logging.WARNING, "WARN")
-    formatter = CustomFormatter(
-        fmt="{asctime}.{msecs:03.0f} [{bold}{levelcolor}{levelname}{reset}] [{bold}{caller}{reset}] {message}",
-        style="{",
-        datefmt="%b %d %H:%M:%S",
-    )
-    handler = logging.StreamHandler()
-    handler.setFormatter(formatter)
-    logger = logging.getLogger(name_without_prefix)
-    logger.setLevel(loglevel)
-    logger.addHandler(handler)
+    with _init_lock:
+        if _default_handler is not None:
+            return
+
+        formatter = CustomFormatter(
+            fmt="{asctime}.{msecs:03.0f} [{bold}{levelcolor}{levelname}{reset}] [{bold}{caller}{reset}] {message}",
+            style="{",
+            datefmt="%b %d %H:%M:%S",
+        )
+        _default_handler = logging.StreamHandler()
+        _default_handler.setFormatter(formatter)
+
+        _enable_default_handler("hivemind")
+
+
+def get_logger(name: Optional[str] = None) -> logging.Logger:
+    """
+    Same as ``logging.getLogger()`` but ensures that the default log handler is initialized.
+    """
+
+    _initialize_if_necessary()
+    return logging.getLogger(name)
+
+
+def _enable_default_handler(name: str) -> None:
+    logger = get_logger(name)
+    logger.addHandler(_default_handler)
     logger.propagate = False
-    return logger
+    logger.setLevel(loglevel)
+
+
+def _disable_default_handler(name: str) -> None:
+    logger = get_logger(name)
+    logger.removeHandler(_default_handler)
+    logger.propagate = True
+    logger.setLevel(logging.NOTSET)
+
+
+def use_hivemind_log_handler(where: Union[HandlerMode, str]) -> None:
+    """
+    Choose loggers where the default hivemind log handler is applied. Options for the ``where`` argument are:
+
+    * "in_hivemind" (default): Use the hivemind log handler in the loggers of the ``hivemind`` package.
+                               Don't propagate their messages to the root logger.
+    * "nowhere": Don't use the hivemind log handler anywhere.
+                 Propagate the ``hivemind`` messages to the root logger.
+    * "in_root_logger": Use the hivemind log handler in the root logger
+                        (that is, in all application loggers until they disable propagation to the root logger).
+                        Propagate the ``hivemind`` messages to the root logger.
+
+    The options may be defined as strings (case-insensitive) or values from the HandlerMode enum.
+    """
+
+    global _current_mode
+
+    if isinstance(where, str):
+        # We allow `where` to be a string, so a developer does not have to import the enum for one usage
+        where = HandlerMode[where.upper()]
+
+    if where == _current_mode:
+        return
+
+    if _current_mode == HandlerMode.IN_HIVEMIND:
+        _disable_default_handler("hivemind")
+    elif _current_mode == HandlerMode.IN_ROOT_LOGGER:
+        _disable_default_handler(None)
+
+    _current_mode = where
+
+    if _current_mode == HandlerMode.IN_HIVEMIND:
+        _enable_default_handler("hivemind")
+    elif _current_mode == HandlerMode.IN_ROOT_LOGGER:
+        _enable_default_handler(None)
 
 
 def golog_level_to_python(level: str) -> int:

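A short usage sketch of the reworked logging API (the calling module is hypothetical; the modes are the ones documented in use_hivemind_log_handler above):

    from hivemind.utils.logging import get_logger, use_hivemind_log_handler

    # Attach the hivemind handler to the root logger so that application loggers
    # get the same formatting as hivemind's own messages (this is what the
    # updated ALBERT examples do)
    use_hivemind_log_handler("in_root_logger")

    logger = get_logger(__name__)
    logger.info("This record is formatted by the hivemind handler")

    # Alternatively, detach the handler everywhere and let the application
    # configure logging on its own:
    # use_hivemind_log_handler("nowhere")

By default ("in_hivemind"), only loggers inside the hivemind package use the handler and do not propagate their records to the root logger.
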
+ 5 - 5
setup.py

@@ -14,9 +14,10 @@ from setuptools import find_packages, setup
 from setuptools.command.build_py import build_py
 from setuptools.command.develop import develop
 
-P2PD_VERSION = "v0.3.5"
-P2PD_CHECKSUM = "affea8ec63dbe2423ef7453718b5798d"
+P2PD_VERSION = "v0.3.6"
+P2PD_CHECKSUM = "627d0c3b475a29331fdfd1667e828f6d"
 LIBP2P_TAR_URL = f"https://github.com/learning-at-home/go-libp2p-daemon/archive/refs/tags/{P2PD_VERSION}.tar.gz"
+P2PD_BINARY_URL = f"https://github.com/learning-at-home/go-libp2p-daemon/releases/download/{P2PD_VERSION}/p2pd"
 
 here = os.path.abspath(os.path.dirname(__file__))
 
@@ -85,11 +86,10 @@ def download_p2p_daemon():
     binary_path = os.path.join(install_path, "p2pd")
     if not os.path.exists(binary_path) or md5(binary_path) != P2PD_CHECKSUM:
         print("Downloading Peer to Peer Daemon")
-        url = f"https://github.com/learning-at-home/go-libp2p-daemon/releases/download/{P2PD_VERSION}/p2pd"
-        urllib.request.urlretrieve(url, binary_path)
+        urllib.request.urlretrieve(P2PD_BINARY_URL, binary_path)
         os.chmod(binary_path, 0o777)
         if md5(binary_path) != P2PD_CHECKSUM:
-            raise RuntimeError(f"Downloaded p2pd binary from {url} does not match with md5 checksum")
+            raise RuntimeError(f"Downloaded p2pd binary from {P2PD_BINARY_URL} does not match with md5 checksum")
 
 
 class BuildPy(build_py):