
Remove duplicate log entries, report aggregate runtime performance and parameter count (#135)

Max Ryabinin · 4 years ago · commit d6ac1fbd8a

+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.utils import *
 
-__version__ = '0.8.20'
+__version__ = '0.8.21'

+ 1 - 1
hivemind/client/averaging/__init__.py

@@ -23,7 +23,7 @@ from hivemind.proto import averaging_pb2, averaging_pb2_grpc, runtime_pb2
 StreamCallToLeader = grpc.aio.UnaryStreamCall[averaging_pb2.JoinRequest, averaging_pb2.MessageFromLeader]
 
 INITIAL_GROUP_NBITS = 3
-logger = get_logger(__file__)
+logger = get_logger(__name__)
 
 
 class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragingServicer):

+ 1 - 1
hivemind/client/averaging/matchmaking.py

@@ -20,7 +20,7 @@ from hivemind.proto import averaging_pb2, averaging_pb2_grpc, runtime_pb2
 from hivemind.utils.grpc import ChannelCache
 
 
-logger = get_logger(__file__)
+logger = get_logger(__name__)
 
 
 class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
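
A note on the get_logger(__file__) -> get_logger(__name__) replacements in this commit: __file__ names the logger after a filesystem path (ending in .py and varying between installations), while __name__ gives the dotted import path that the logging hierarchy expects. A minimal illustration in plain Python (not hivemind's get_logger):

import logging

# __file__ produces a path-like logger name, e.g. '/opt/hivemind/client/averaging/matchmaking.py'
# (the exact path depends on the installation); __name__ produces the stable dotted module path,
# e.g. 'hivemind.client.averaging.matchmaking'.
logger_by_path = logging.getLogger(__file__)
logger_by_module = logging.getLogger(__name__)
assert logger_by_path is not logger_by_module  # two distinct loggers for the same module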

+ 5 - 2
hivemind/dht/__init__.py

@@ -34,7 +34,7 @@ logger = get_logger(__name__)
 ExpertUID, ExpertPrefix, Coordinate, Score = str, str, int, float
 UidEndpoint = NamedTuple("UidEndpoint", [('uid', ExpertUID), ('endpoint', Endpoint)])
 UID_DELIMITER = '.'  # when declaring experts, DHT store all prefixes of that expert's uid, split over this prefix
-FLAT_EXPERT = -1     # grid prefix reserved for storing 1d expert uids. Used to speed up find_best_experts in 1d case.
+FLAT_EXPERT = -1  # grid prefix reserved for storing 1d expert uids. Used to speed up find_best_experts in 1d case.
 UID_PATTERN = re.compile('^(([^.])+)([.](?:[0]|([1-9]([0-9]*))))+$')  # e.g. ffn_expert.98.76.54 - prefix + some dims
 PREFIX_PATTERN = re.compile('^(([^.])+)([.](?:[0]|([1-9]([0-9]*))))*[.]$')  # e.g. expert. or ffn.45. (ends with ".")
 #  formally, prefixes = {uid.split(UID_DELIMITER)[:length] for length in range(1, uid.count(UID_DELIMITER) + 2)}
@@ -160,7 +160,10 @@ class DHT(mp.Process):
                 method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv)
                 asyncio.create_task(getattr(self, method)(node, *args, **kwargs))
 
-        loop.run_until_complete(_run())
+        try:
+            loop.run_until_complete(_run())
+        except KeyboardInterrupt:
+            logger.debug("Caught KeyboardInterrupt, shutting down")
 
     def run_in_background(self, await_ready=True, timeout=None):
         """

+ 28 - 41
hivemind/server/__init__.py

@@ -69,7 +69,7 @@ class Server(threading.Thread):
     @staticmethod
     def create(listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str = None, expert_pattern: str = None,
                expert_cls='ffn', hidden_dim=1024, optim_cls=torch.optim.Adam, num_handlers=None, max_batch_size=4096,
-               device=None, no_dht=False, initial_peers=(), dht_port=None, verbose=True,
+               device=None, no_dht=False, initial_peers=(), dht_port=None,
                compression=CompressionType.NONE, *, start: bool, **kwargs) -> Server:
         """
         Instantiate a server with several identical experts. See argparse comments below for details
@@ -91,31 +91,29 @@ class Server(threading.Thread):
         :param dht_port:  DHT node will listen on this port, default = find open port
            You can then use this node as initial peer for subsequent servers.
 
-        :param verbose: whether to print server started / finished / terminated events
         :param compression: if specified, use this compression to pack all inputs, outputs and gradients by all experts
             hosted on this server. For a more fine-grained compression, start server in python and specify compression
             for each BatchTensorProto in ExpertBackend for the respective experts.
 
         :param start: if True, starts server right away and returns when server is ready for requests
         """
-        if verbose and len(kwargs) != 0:
-            print("Ignored kwargs:", kwargs)
+        if len(kwargs) != 0:
+            logger.info(f"Ignored kwargs: {kwargs}")
         assert expert_cls in name_to_block
+        assert (expert_pattern is None and num_experts is None) or (expert_uids is None) or (num_experts == 0), \
+            "Please provide either expert_uids *or* num_experts and expert_pattern, but not both"
 
-        # initialize dht
-        dht = None
-        if not no_dht:
-            logger.info(f"Bootstrapping DHT node, initial peers = {initial_peers}")
+        if no_dht:
+            dht = None
+        else:
             dht_endpoint = replace_port(listen_on, dht_port or hivemind.find_open_port())
             dht = hivemind.DHT(initial_peers=initial_peers, start=True, listen_on=dht_endpoint)
-            if verbose:
-                logger.info(f"Running dht node on port {dht.port}")
+            logger.info(f"Running DHT node on port {dht.port}, initial peers = {initial_peers}")
 
         # get expert uids
-        assert (expert_pattern is None and num_experts is None) or (expert_uids is None) or (num_experts == 0), \
-            "Please provide either expert_uids *or* num_experts and expert_pattern, but not both"
         if expert_uids is None:
             assert num_experts is not None, "Please specify either expert_uids or num_experts [and expert_pattern]"
+            logger.info(f"Generating expert uids from pattern {expert_pattern}")
             expert_uids = generate_uids_from_pattern(num_experts, expert_pattern, dht=dht)
 
         num_experts = len(expert_uids)
@@ -130,7 +128,6 @@ class Server(threading.Thread):
             args_schema = (hivemind.BatchTensorDescriptor.from_tensor(sample_input, compression),)
 
         # initialize experts
-
         experts = {}
         for expert_uid in expert_uids:
             expert = name_to_block[expert_cls](hidden_dim)
@@ -139,18 +136,10 @@ class Server(threading.Thread):
                                                          outputs_schema=hivemind.BatchTensorDescriptor(
                                                              hidden_dim, compression=compression),
                                                          opt=optim_cls(expert.parameters()),
-                                                         max_batch_size=max_batch_size,
-                                                         )
-        # actually start server
-        server = Server(
-            dht, experts, listen_on=listen_on,
-            num_connection_handlers=num_handlers, device=device)
+                                                         max_batch_size=max_batch_size)
 
-        if start:
-            server.run_in_background(await_ready=True)
-            if verbose:
-                logger.info(f"Server started at {server.listen_on}")
-                logger.info(f"Got {len(experts)} active experts of type {expert_cls}: {list(experts.keys())}")
+        server = Server(dht, experts, listen_on=listen_on, num_connection_handlers=num_handlers, device=device,
+                        start=start)
         return server
 
     def run(self):
@@ -158,6 +147,12 @@ class Server(threading.Thread):
         Starts Server in the current thread. Initializes dht if necessary, starts connection handlers,
         runs Runtime (self.runtime) to process incoming requests.
         """
+        logger.info(f"Server started at {self.listen_on}")
+        logger.info(f"Got {len(self.experts)} experts:")
+        for expert_name, backend in self.experts.items():
+            num_parameters = sum(p.numel() for p in backend.expert.parameters() if p.requires_grad)
+            logger.info(f"{expert_name}: {backend.expert.__class__.__name__}, {num_parameters} parameters")
+
         if self.dht:
             if not self.dht.is_alive():
                 self.dht.run_in_background(await_ready=True)
@@ -172,8 +167,6 @@ class Server(threading.Thread):
         for process in self.conn_handlers:
             if not process.is_alive():
                 process.start()
-
-        for process in self.conn_handlers:
             process.ready.wait()
 
         self.runtime.run()
@@ -227,11 +220,10 @@ class Server(threading.Thread):
 
 
 @contextmanager
-def background_server(*args, shutdown_timeout=5, verbose=True, **kwargs) -> Tuple[hivemind.Endpoint, hivemind.Endpoint]:
+def background_server(*args, shutdown_timeout=5, **kwargs) -> Tuple[hivemind.Endpoint, hivemind.Endpoint]:
     """ A context manager that creates server in a background thread, awaits .ready on entry and shutdowns on exit """
     pipe, runners_pipe = mp.Pipe(duplex=True)
-    runner = mp.get_context("spawn").Process(
-        target=_server_runner, args=(runners_pipe, *args), kwargs=dict(verbose=verbose, **kwargs))
+    runner = mp.get_context("spawn").Process(target=_server_runner, args=(runners_pipe, *args), kwargs=kwargs)
 
     try:
         runner.start()
@@ -240,15 +232,13 @@ def background_server(*args, shutdown_timeout=5, verbose=True, **kwargs) -> Tupl
     finally:
         runner.join(timeout=shutdown_timeout)
         if runner.is_alive():
-            if verbose:
-                logger.info("Server failed to shutdown gracefully, terminating it the hard way...")
+            logger.info("Server failed to shutdown gracefully, terminating it the hard way...")
             runner.kill()
-            if verbose:
-                logger.info("Server terminated.")
+            logger.info("Server terminated.")
 
 
-def _server_runner(pipe, *args, verbose, **kwargs):
-    server = Server.create(*args, verbose=verbose, start=True, **kwargs)
+def _server_runner(pipe, *args, **kwargs):
+    server = Server.create(*args, start=True, **kwargs)
     try:
         if server.dht is not None:
             dht_listen_on = hivemind.replace_port(server.dht.listen_on, server.dht.port)
@@ -257,12 +247,10 @@ def _server_runner(pipe, *args, verbose, **kwargs):
         pipe.send((server.listen_on, dht_listen_on))
         pipe.recv()  # wait for shutdown signal
     finally:
-        if verbose:
-            logger.info("Shutting down server...")
+        logger.info("Shutting down server...")
         server.shutdown()
         server.join()
-        if verbose:
-            logger.info("Server shut down successfully.")
+        logger.info("Server shut down.")
 
 
 def generate_uids_from_pattern(num_experts: int, expert_pattern: Optional[str], dht: Optional[DHT] = None,
@@ -277,7 +265,6 @@ def generate_uids_from_pattern(num_experts: int, expert_pattern: Optional[str],
     :note: this method is not strictly process-safe. If several servers run it concurrently, they have
      a small chance of sampling duplicate expert uids.
     """
-    logger.info("Generating expert uids...")
     remaining_attempts = attempts_per_expert * num_experts
     found_uids, attempted_uids = list(), set()
 
@@ -298,7 +285,7 @@ def generate_uids_from_pattern(num_experts: int, expert_pattern: Optional[str],
             except KeyboardInterrupt as e:
                 raise e
             except Exception as e:
-                raise ValueError(f"Expert pattern {expert_pattern} has invalid block {block} , {e}")
+                raise ValueError(f"Expert pattern {expert_pattern} has invalid block {block}, {e}")
         return hivemind.dht.UID_DELIMITER.join(uid)
 
     while remaining_attempts > 0 and len(found_uids) < num_experts:
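
With the verbose flag removed from Server.create and background_server, verbosity is now controlled entirely by the logging configuration. A hedged usage sketch of the context manager after this change (the parameter values are illustrative, and the yielded pair follows the Tuple[hivemind.Endpoint, hivemind.Endpoint] annotation above):

from hivemind.server import background_server

# num_experts, hidden_dim and device below are example values, not defaults taken from this diff
with background_server(num_experts=2, hidden_dim=64, device='cpu') as (server_endpoint, dht_endpoint):
    print(f"experts served at {server_endpoint}, DHT node at {dht_endpoint}")
# on exit the server is asked to shut down; if it is still alive after shutdown_timeout it is killed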

+ 62 - 4
hivemind/server/runtime.py

@@ -1,9 +1,13 @@
 import multiprocessing as mp
 import multiprocessing.pool
 import threading
+from collections import defaultdict
 from itertools import chain
+from queue import SimpleQueue
 from selectors import DefaultSelector, EVENT_READ
-from typing import Dict
+from statistics import mean
+from time import time
+from typing import Dict, NamedTuple
 
 import torch
 from prefetch_generator import BackgroundGenerator
@@ -31,14 +35,15 @@ class Runtime(threading.Thread):
 
     :param expert_backends: a dict [expert uid -> ExpertBackend]
     :param prefetch_batches: form up to this many batches in advance
-    :param start: start runtime immediately (at the end of __init__)
     :param sender_threads: dispatches outputs from finished batches using this many asynchronous threads
     :param device: if specified, moves all experts and data to this device via .to(device=device).
       If you want to manually specify devices for each expert (in their forward pass), leave device=None (default)
+
+    :param stats_report_interval: interval to collect and log statistics about runtime performance
     """
 
     def __init__(self, expert_backends: Dict[str, ExpertBackend], prefetch_batches=64, sender_threads: int = 1,
-                 device: torch.device = None):
+                 device: torch.device = None, stats_report_interval=30):
         super().__init__()
         self.expert_backends = expert_backends
         self.pools = tuple(chain(*(expert.get_pools() for expert in expert_backends.values())))
@@ -46,6 +51,8 @@ class Runtime(threading.Thread):
         self.shutdown_recv, self.shutdown_send = mp.Pipe(duplex=False)
         self.ready = mp.Event()  # event is set iff server is currently running and ready to accept batches
 
+        self.stats_reporter = StatsReporter(stats_report_interval)
+
     def run(self):
         for pool in self.pools:
             if not pool.is_alive():
@@ -57,15 +64,25 @@ class Runtime(threading.Thread):
         with mp.pool.ThreadPool(self.sender_threads) as output_sender_pool:
             try:
                 self.ready.set()
+                self.stats_reporter.start()
                 logger.info("Started")
                 for pool, batch_index, batch in BackgroundGenerator(
                         self.iterate_minibatches_from_pools(), self.prefetch_batches):
                     logger.debug(f"Processing batch {batch_index} from pool {pool.uid}")
+
+                    start = time()
                     outputs = pool.process_func(*batch)
-                    logger.info(f"Pool {pool.uid}: batch {batch_index} processed, size {outputs[0].size(0)}")
+                    batch_processing_time = time() - start
+
+                    batch_size = outputs[0].size(0)
+                    logger.debug(f"Pool {pool.uid}: batch {batch_index} processed, size {batch_size}")
+                    self.stats_reporter.report_stats(pool.uid, batch_size, batch_processing_time)
+
                     output_sender_pool.apply_async(pool.send_outputs_from_runtime, args=[batch_index, outputs])
             finally:
                 logger.info("Shutting down")
+                self.stats_reporter.stop.set()
+                self.stats_reporter.join()
                 self.shutdown()
 
     SHUTDOWN_TRIGGER = "RUNTIME SHUTDOWN TRIGGERED"
@@ -103,3 +120,44 @@ class Runtime(threading.Thread):
                 batch_index, batch_tensors = pool.load_batch_to_runtime(timeout, self.device)
                 logger.debug(f"Loaded batch from {pool.uid}")
                 yield pool, batch_index, batch_tensors
+
+
+BatchStats = NamedTuple('BatchStats', (('batch_size', int), ('processing_time', float)))
+
+
+class StatsReporter(threading.Thread):
+    def __init__(self, report_interval: int):
+        super().__init__()
+        self.report_interval = report_interval
+        self.stop = threading.Event()
+        self.stats_queue = SimpleQueue()
+
+    def run(self):
+        while not self.stop.wait(self.report_interval):
+            pool_batch_stats = defaultdict(list)
+            while not self.stats_queue.empty():
+                pool_uid, batch_stats = self.stats_queue.get()
+                pool_batch_stats[pool_uid].append(batch_stats)
+
+            total_processed_batches = sum(len(pool_stats) for pool_stats in pool_batch_stats.values())
+            logger.info(f'Processed {total_processed_batches} batches in last {self.report_interval} seconds:')
+            for pool_uid, pool_stats in pool_batch_stats.items():
+                total_batches = len(pool_stats)
+                total_examples = sum(batch_stats.batch_size for batch_stats in pool_stats)
+                avg_batch_size = mean(batch_stats.batch_size for batch_stats in pool_stats)
+                total_time = sum(batch_stats.processing_time for batch_stats in pool_stats)
+                batches_to_time = total_batches / total_time
+                batch_performance = f'{batches_to_time:.2f} batches/s' if batches_to_time > 1 else f'{1 / batches_to_time:.2f} s/batch'
+
+                examples_to_time = total_examples / total_time
+                example_performance = (f'{examples_to_time:.2f} examples/s' if examples_to_time > 1
+                                       else f'{1 / examples_to_time:.2f} s/example')
+
+                logger.info(f'{pool_uid}: '
+                            f'{total_batches} batches ({batch_performance}), '
+                            f'{total_examples} examples ({example_performance}), '
+                            f'avg batch size {avg_batch_size:.2f}')
+
+    def report_stats(self, pool_uid, batch_size, processing_time):
+        batch_stats = BatchStats(batch_size, processing_time)
+        self.stats_queue.put_nowait((pool_uid, batch_stats))
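
To make the aggregate numbers concrete, this is the arithmetic StatsReporter performs for a single pool over one reporting interval (toy values, plain Python rather than hivemind code):

from statistics import mean

# three batches recorded as (batch_size, processing_time) pairs
pool_stats = [(4, 0.05), (8, 0.07), (6, 0.06)]

total_time = sum(t for _, t in pool_stats)                # ~0.18 s spent inside process_func
print(len(pool_stats) / total_time)                       # ~16.67 batches/s
print(sum(size for size, _ in pool_stats) / total_time)   # ~100.0 examples/s
print(mean(size for size, _ in pool_stats))               # 6.0 average batch size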

+ 52 - 45
hivemind/server/task_pool.py

@@ -142,55 +142,62 @@ class TaskPool(TaskPoolBase):
 
     def _pool_input_loop(self, pending_batches: Dict[Any, List[Task]], *args, **kwargs):
         """ Infinite loop: aggregate tasks into batches and send them to runtime """
-        prev_num_tasks = 0  # number of tasks currently in shared buffer
-        batch_index = max(pending_batches.keys(), default=0)
-        batch_iterator = self.iterate_minibatches(*args, **kwargs)
-
-        while True:
-            # SIDE-EFFECT - compute pool priority from timestamp of earliest undispatched task
-            # assumes that tasks are processed in the same order as they are created
-            for skip_i in range(prev_num_tasks):
-                finished_task_timestamp = self.undispatched_task_timestamps.get()  # earlier timestamp = higher priority
-                if skip_i == prev_num_tasks - 1:
-                    self.priority = finished_task_timestamp
-
-            logger.debug(f"{self.uid} getting next batch")
-            batch_tasks = next(batch_iterator)
-            # save batch futures, _output_loop will deliver on them later
-            pending_batches[batch_index] = batch_tasks
-
-            logger.debug(f"{self.uid}, batch  {batch_index}: aggregating inputs")
-            # find or create shared arrays for current batch size
-            batch_inputs = [torch.cat([task.args[i] for task in batch_tasks]) for i in range(len(batch_tasks[0].args))]
-            batch_inputs = [inp.detach().requires_grad_(inp.requires_grad).share_memory_() for inp in batch_inputs]
-
-            logger.debug(f"{self.uid}, batch {batch_index}: sending to runtime")
-            self.batch_sender.send((batch_index, batch_inputs))
-            logger.debug(f"{self.uid}, batch {batch_index}: sent to runtime")
-            prev_num_tasks = len(batch_tasks)
-            batch_index += 1
+        try:
+            prev_num_tasks = 0  # number of tasks currently in shared buffer
+            batch_index = max(pending_batches.keys(), default=0)
+            batch_iterator = self.iterate_minibatches(*args, **kwargs)
+
+            while True:
+                # SIDE-EFFECT - compute pool priority from timestamp of earliest undispatched task
+                # assumes that tasks are processed in the same order as they are created
+                for skip_i in range(prev_num_tasks):
+                    finished_task_timestamp = self.undispatched_task_timestamps.get()  # earlier timestamp = higher priority
+                    if skip_i == prev_num_tasks - 1:
+                        self.priority = finished_task_timestamp
+
+                logger.debug(f"{self.uid} getting next batch")
+                batch_tasks = next(batch_iterator)
+                # save batch futures, _output_loop will deliver on them later
+                pending_batches[batch_index] = batch_tasks
+
+                logger.debug(f"{self.uid}, batch  {batch_index}: aggregating inputs")
+                # find or create shared arrays for current batch size
+                batch_inputs = [torch.cat([task.args[i] for task in batch_tasks]) for i in
+                                range(len(batch_tasks[0].args))]
+                batch_inputs = [inp.detach().requires_grad_(inp.requires_grad).share_memory_() for inp in batch_inputs]
+
+                logger.debug(f"{self.uid}, batch {batch_index}: sending to runtime")
+                self.batch_sender.send((batch_index, batch_inputs))
+                logger.debug(f"{self.uid}, batch {batch_index}: sent to runtime")
+                prev_num_tasks = len(batch_tasks)
+                batch_index += 1
+        except KeyboardInterrupt:
+            logger.debug('Caught KeyboardInterrupt, shutting down')
 
     def _pool_output_loop(self, pending_batches: Dict[Any, List[Task]]):
         """ Infinite loop: receive results from runtime and dispatch them to task Futures """
 
-        while True:
-            logger.debug(f"{self.uid} waiting for results from runtime")
-            payload = self.outputs_receiver.recv()
-            if isinstance(payload, BaseException):
-                raise payload
-            else:
-                batch_index, batch_outputs = payload
-            logger.debug(f"{self.uid}, batch {batch_index}: got results")
-
-            # split batch into partitions for individual tasks
-            batch_tasks = pending_batches.pop(batch_index)
-            task_sizes = [self.get_task_size(task) for task in batch_tasks]
-            outputs_per_task = zip(*(torch.split_with_sizes(tensor, task_sizes, dim=0) for tensor in batch_outputs))
-            logger.debug(f"{self.uid}, batch {batch_index}: sending outputs to handlers")
-
-            # dispatch results to futures
-            for task, task_outputs in zip(batch_tasks, outputs_per_task):
-                task.future.set_result(tuple(task_outputs))
+        try:
+            while True:
+                logger.debug(f"{self.uid} waiting for results from runtime")
+                payload = self.outputs_receiver.recv()
+                if isinstance(payload, BaseException):
+                    raise payload
+                else:
+                    batch_index, batch_outputs = payload
+                logger.debug(f"{self.uid}, batch {batch_index}: got results")
+
+                # split batch into partitions for individual tasks
+                batch_tasks = pending_batches.pop(batch_index)
+                task_sizes = [self.get_task_size(task) for task in batch_tasks]
+                outputs_per_task = zip(*(torch.split_with_sizes(tensor, task_sizes, dim=0) for tensor in batch_outputs))
+                logger.debug(f"{self.uid}, batch {batch_index}: sending outputs to handlers")
+
+                # dispatch results to futures
+                for task, task_outputs in zip(batch_tasks, outputs_per_task):
+                    task.future.set_result(tuple(task_outputs))
+        except KeyboardInterrupt:
+            logger.debug(f"Caught KeyboardInterrupt, shutting down")
 
     @property
     def empty(self):
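
Both loops above are wrapped in the same pattern: swallow KeyboardInterrupt so that a Ctrl+C in the parent process does not print a traceback from every worker loop. A generic sketch of that shape (do_work and logger are placeholders, not hivemind names):

def worker_loop(do_work, logger):
    # do_work stands in for the loop body (aggregating inputs or dispatching outputs above)
    try:
        while True:
            do_work()
    except KeyboardInterrupt:
        # exit quietly; the parent process drives the actual shutdown
        logger.debug("Caught KeyboardInterrupt, shutting down")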

+ 1 - 1
hivemind/utils/grpc.py

@@ -16,7 +16,7 @@ from hivemind.utils.timed_storage import TimedStorage, get_dht_time, ValueWithEx
 from hivemind.utils.networking import Endpoint
 from hivemind.utils.logging import get_logger
 
-logger = get_logger(__file__)
+logger = get_logger(__name__)
 
 Stub = TypeVar("Stub")
 

+ 1 - 0
hivemind/utils/logging.py

@@ -15,4 +15,5 @@ def get_logger(module_name: str) -> logging.Logger:
     logger = logging.getLogger(name_without_prefix)
     logger.setLevel(loglevel)
     logger.addHandler(handler)
+    logger.propagate = False
     return logger
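
The logger.propagate = False line addresses the duplicate entries named in the commit title: once a logger has its own handler, letting records also propagate to a handler on the root logger prints every message twice. A standalone illustration with the standard logging module (not hivemind code):

import logging, sys

logging.basicConfig(stream=sys.stdout)             # installs a handler on the root logger
log = logging.getLogger("demo.module")
log.addHandler(logging.StreamHandler(sys.stdout))  # the module's own handler, as get_logger does

log.warning("hello")        # appears twice: once via the module handler, once via the root handler
log.propagate = False
log.warning("hello again")  # appears once: the record no longer reaches the root handler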

+ 7 - 1
scripts/run_server.py

@@ -6,6 +6,9 @@ import torch
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.server import Server
 from hivemind.utils.threading import increase_file_limit
+from hivemind.utils.logging import get_logger
+
+logger = get_logger(__name__)
 
 
 def main():
@@ -62,9 +65,12 @@ def main():
     else:
         compression = getattr(CompressionType, compression_type)
 
+    server = Server.create(**args, optim_cls=optim_cls, start=True, compression=compression)
+
     try:
-        server = Server.create(**args, optim_cls=optim_cls, start=True, verbose=True, compression=compression)
         server.join()
+    except KeyboardInterrupt:
+        logger.info("Caught KeyboardInterrupt, shutting down")
     finally:
         server.shutdown()
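
Moving Server.create out of the try block also removes a latent NameError: previously, if Server.create raised, the finally clause referenced a server variable that was never bound. The general shape of the fix (acquire_resource and the method names are placeholders):

import logging
logger = logging.getLogger(__name__)

def serve_forever(acquire_resource):
    resource = acquire_resource()   # placeholder for Server.create(..., start=True)
    try:
        resource.join()             # placeholder for server.join()
    except KeyboardInterrupt:
        logger.info("Caught KeyboardInterrupt, shutting down")
    finally:
        resource.shutdown()         # always bound, because creation happened before the try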
 

+ 1 - 1
tests/benchmark_dht.py

@@ -7,7 +7,7 @@ from tqdm import trange
 import hivemind
 from hivemind.utils.threading import increase_file_limit
 
-logger = hivemind.get_logger(__file__)
+logger = hivemind.get_logger(__name__)
 
 
 def random_endpoint() -> hivemind.Endpoint: