
Remove duplicate log entries, report aggregate runtime performance and parameter count (#135)

Max Ryabinin 4 years ago
commit d6ac1fbd8a

+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.utils import *
 
-__version__ = '0.8.20'
+__version__ = '0.8.21'

+ 1 - 1
hivemind/client/averaging/__init__.py

@@ -23,7 +23,7 @@ from hivemind.proto import averaging_pb2, averaging_pb2_grpc, runtime_pb2
 StreamCallToLeader = grpc.aio.UnaryStreamCall[averaging_pb2.JoinRequest, averaging_pb2.MessageFromLeader]
 
 INITIAL_GROUP_NBITS = 3
-logger = get_logger(__file__)
+logger = get_logger(__name__)
 
 
 class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragingServicer):

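Note on the get_logger(__file__) -> get_logger(__name__) change repeated across this commit: __file__ is a filesystem path, so each logger was registered under a path-like name instead of its dotted module name, which sidesteps the standard logging hierarchy. A minimal sketch of the difference (the path below is illustrative):

import logging

logging.getLogger(__file__)  # named e.g. '/home/user/hivemind/utils/grpc.py'
logging.getLogger(__name__)  # named e.g. 'hivemind.utils.grpc'

# Dotted names form a hierarchy, so configuring the parent also covers children:
logging.getLogger('hivemind').setLevel(logging.DEBUG)
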
+ 1 - 1
hivemind/client/averaging/matchmaking.py

@@ -20,7 +20,7 @@ from hivemind.proto import averaging_pb2, averaging_pb2_grpc, runtime_pb2
 from hivemind.utils.grpc import ChannelCache
 
 
-logger = get_logger(__file__)
+logger = get_logger(__name__)
 
 
 class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):

+ 5 - 2
hivemind/dht/__init__.py

@@ -34,7 +34,7 @@ logger = get_logger(__name__)
 ExpertUID, ExpertPrefix, Coordinate, Score = str, str, int, float
 UidEndpoint = NamedTuple("UidEndpoint", [('uid', ExpertUID), ('endpoint', Endpoint)])
 UID_DELIMITER = '.'  # when declaring experts, DHT store all prefixes of that expert's uid, split over this prefix
-FLAT_EXPERT = -1     # grid prefix reserved for storing 1d expert uids. Used to speed up find_best_experts in 1d case.
+FLAT_EXPERT = -1  # grid prefix reserved for storing 1d expert uids. Used to speed up find_best_experts in 1d case.
 UID_PATTERN = re.compile('^(([^.])+)([.](?:[0]|([1-9]([0-9]*))))+$')  # e.g. ffn_expert.98.76.54 - prefix + some dims
 PREFIX_PATTERN = re.compile('^(([^.])+)([.](?:[0]|([1-9]([0-9]*))))*[.]$')  # e.g. expert. or ffn.45. (ends with ".")
 #  formally, prefixes = {uid.split(UID_DELIMITER)[:length] for length in range(1, uid.count(UID_DELIMITER) + 2)}
@@ -160,7 +160,10 @@ class DHT(mp.Process):
                 method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv)
                 asyncio.create_task(getattr(self, method)(node, *args, **kwargs))
 
-        loop.run_until_complete(_run())
+        try:
+            loop.run_until_complete(_run())
+        except KeyboardInterrupt:
+            logger.debug("Caught KeyboardInterrupt, shutting down")
 
     def run_in_background(self, await_ready=True, timeout=None):
         """

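Why the DHT loop now swallows KeyboardInterrupt: pressing Ctrl+C sends SIGINT to the whole foreground process group, so the DHT child process raises its own KeyboardInterrupt and previously printed a full traceback alongside the main process. A minimal sketch of the pattern, with a stand-in coroutine instead of the real _run():

import asyncio
import logging

logger = logging.getLogger(__name__)

async def _serve_forever():  # stand-in for the DHT's internal _run() coroutine
    while True:
        await asyncio.sleep(1)

loop = asyncio.new_event_loop()
try:
    loop.run_until_complete(_serve_forever())
except KeyboardInterrupt:
    logger.debug("Caught KeyboardInterrupt, shutting down")
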
+ 28 - 41
hivemind/server/__init__.py

@@ -69,7 +69,7 @@ class Server(threading.Thread):
     @staticmethod
     def create(listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str = None, expert_pattern: str = None,
                expert_cls='ffn', hidden_dim=1024, optim_cls=torch.optim.Adam, num_handlers=None, max_batch_size=4096,
-               device=None, no_dht=False, initial_peers=(), dht_port=None, verbose=True,
+               device=None, no_dht=False, initial_peers=(), dht_port=None,
                compression=CompressionType.NONE, *, start: bool, **kwargs) -> Server:
         """
         Instantiate a server with several identical experts. See argparse comments below for details
@@ -91,31 +91,29 @@ class Server(threading.Thread):
         :param dht_port:  DHT node will listen on this port, default = find open port
            You can then use this node as initial peer for subsequent servers.
 
-        :param verbose: whether to print server started / finished / terminated events
         :param compression: if specified, use this compression to pack all inputs, outputs and gradients by all experts
             hosted on this server. For a more fine-grained compression, start server in python and specify compression
             for each BatchTensorProto in ExpertBackend for the respective experts.
 
         :param start: if True, starts server right away and returns when server is ready for requests
         """
-        if verbose and len(kwargs) != 0:
-            print("Ignored kwargs:", kwargs)
+        if len(kwargs) != 0:
+        logger.info(f"Ignored kwargs: {kwargs}")
         assert expert_cls in name_to_block
+        assert (expert_pattern is None and num_experts is None) or (expert_uids is None) or (num_experts == 0), \
+            "Please provide either expert_uids *or* num_experts and expert_pattern, but not both"
 
 
-        # initialize dht
-        dht = None
-        if not no_dht:
-            logger.info(f"Bootstrapping DHT node, initial peers = {initial_peers}")
+        if no_dht:
+            dht = None
+        else:
             dht_endpoint = replace_port(listen_on, dht_port or hivemind.find_open_port())
             dht = hivemind.DHT(initial_peers=initial_peers, start=True, listen_on=dht_endpoint)
-            if verbose:
-                logger.info(f"Running dht node on port {dht.port}")
+            logger.info(f"Running DHT node on port {dht.port}, initial peers = {initial_peers}")
 
         # get expert uids
-        assert (expert_pattern is None and num_experts is None) or (expert_uids is None) or (num_experts == 0), \
-            "Please provide either expert_uids *or* num_experts and expert_pattern, but not both"
         if expert_uids is None:
             assert num_experts is not None, "Please specify either expert_uids or num_experts [and expert_pattern]"
+            logger.info(f"Generating expert uids from pattern {expert_pattern}")
             expert_uids = generate_uids_from_pattern(num_experts, expert_pattern, dht=dht)
 
         num_experts = len(expert_uids)
@@ -130,7 +128,6 @@ class Server(threading.Thread):
             args_schema = (hivemind.BatchTensorDescriptor.from_tensor(sample_input, compression),)
 
         # initialize experts
-
         experts = {}
         for expert_uid in expert_uids:
             expert = name_to_block[expert_cls](hidden_dim)
@@ -139,18 +136,10 @@ class Server(threading.Thread):
                                                          outputs_schema=hivemind.BatchTensorDescriptor(
                                                              hidden_dim, compression=compression),
                                                          opt=optim_cls(expert.parameters()),
-                                                         max_batch_size=max_batch_size,
-                                                         )
-        # actually start server
-        server = Server(
-            dht, experts, listen_on=listen_on,
-            num_connection_handlers=num_handlers, device=device)
+                                                         max_batch_size=max_batch_size)
 
-        if start:
-            server.run_in_background(await_ready=True)
-            if verbose:
-                logger.info(f"Server started at {server.listen_on}")
-                logger.info(f"Got {len(experts)} active experts of type {expert_cls}: {list(experts.keys())}")
+        server = Server(dht, experts, listen_on=listen_on, num_connection_handlers=num_handlers, device=device,
+                        start=start)
         return server
 
     def run(self):
@@ -158,6 +147,12 @@ class Server(threading.Thread):
         Starts Server in the current thread. Initializes dht if necessary, starts connection handlers,
         runs Runtime (self.runtime) to process incoming requests.
         """
+        logger.info(f"Server started at {self.listen_on}")
+        logger.info(f"Got {len(self.experts)} experts:")
+        for expert_name, backend in self.experts.items():
+            num_parameters = sum(p.numel() for p in backend.expert.parameters() if p.requires_grad)
+            logger.info(f"{expert_name}: {backend.expert.__class__.__name__}, {num_parameters} parameters")
+
         if self.dht:
             if not self.dht.is_alive():
                 self.dht.run_in_background(await_ready=True)
@@ -172,8 +167,6 @@ class Server(threading.Thread):
         for process in self.conn_handlers:
             if not process.is_alive():
                 process.start()
-
-        for process in self.conn_handlers:
             process.ready.wait()
 
         self.runtime.run()
@@ -227,11 +220,10 @@ class Server(threading.Thread):
 
 
 @contextmanager
-def background_server(*args, shutdown_timeout=5, verbose=True, **kwargs) -> Tuple[hivemind.Endpoint, hivemind.Endpoint]:
+def background_server(*args, shutdown_timeout=5, **kwargs) -> Tuple[hivemind.Endpoint, hivemind.Endpoint]:
     """ A context manager that creates server in a background thread, awaits .ready on entry and shutdowns on exit """
     """ A context manager that creates server in a background thread, awaits .ready on entry and shutdowns on exit """
     pipe, runners_pipe = mp.Pipe(duplex=True)
     pipe, runners_pipe = mp.Pipe(duplex=True)
-    runner = mp.get_context("spawn").Process(
-        target=_server_runner, args=(runners_pipe, *args), kwargs=dict(verbose=verbose, **kwargs))
+    runner = mp.get_context("spawn").Process(target=_server_runner, args=(runners_pipe, *args), kwargs=kwargs)
 
     try:
         runner.start()
@@ -240,15 +232,13 @@ def background_server(*args, shutdown_timeout=5, verbose=True, **kwargs) -> Tuple[hivemind.Endpoint, hivemind.Endpoint]:
     finally:
         runner.join(timeout=shutdown_timeout)
         if runner.is_alive():
-            if verbose:
-                logger.info("Server failed to shutdown gracefully, terminating it the hard way...")
+            logger.info("Server failed to shutdown gracefully, terminating it the hard way...")
             runner.kill()
-            if verbose:
-                logger.info("Server terminated.")
+            logger.info("Server terminated.")
 
 
 
 
-def _server_runner(pipe, *args, verbose, **kwargs):
-    server = Server.create(*args, verbose=verbose, start=True, **kwargs)
+def _server_runner(pipe, *args, **kwargs):
+    server = Server.create(*args, start=True, **kwargs)
     try:
         if server.dht is not None:
             dht_listen_on = hivemind.replace_port(server.dht.listen_on, server.dht.port)
@@ -257,12 +247,10 @@ def _server_runner(pipe, *args, verbose, **kwargs):
         pipe.send((server.listen_on, dht_listen_on))
         pipe.recv()  # wait for shutdown signal
     finally:
-        if verbose:
-            logger.info("Shutting down server...")
+        logger.info("Shutting down server...")
         server.shutdown()
         server.join()
-        if verbose:
-            logger.info("Server shut down successfully.")
+        logger.info("Server shut down.")
 
 
 def generate_uids_from_pattern(num_experts: int, expert_pattern: Optional[str], dht: Optional[DHT] = None,
@@ -277,7 +265,6 @@ def generate_uids_from_pattern(num_experts: int, expert_pattern: Optional[str],
     :note: this method is not strictly process-safe. If several servers run it concurrently, they have
      a small chance of sampling duplicate expert uids.
     """
-    logger.info("Generating expert uids...")
     remaining_attempts = attempts_per_expert * num_experts
     found_uids, attempted_uids = list(), set()
 
@@ -298,7 +285,7 @@ def generate_uids_from_pattern(num_experts: int, expert_pattern: Optional[str],
             except KeyboardInterrupt as e:
                 raise e
             except Exception as e:
-                raise ValueError(f"Expert pattern {expert_pattern} has invalid block {block} , {e}")
+                raise ValueError(f"Expert pattern {expert_pattern} has invalid block {block}, {e}")
         return hivemind.dht.UID_DELIMITER.join(uid)
 
     while remaining_attempts > 0 and len(found_uids) < num_experts:

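The per-expert report added to Server.run uses the standard PyTorch idiom for counting trainable parameters: sum numel() over parameters with requires_grad. A self-contained sketch (the two-layer expert below is illustrative, not hivemind's actual 'ffn' block):

import torch.nn as nn

expert = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024))
num_parameters = sum(p.numel() for p in expert.parameters() if p.requires_grad)
# 1024*4096 + 4096 + 4096*1024 + 1024 = 8,393,728 (two weight matrices + biases)
print(f"{expert.__class__.__name__}: {num_parameters} parameters")
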
+ 62 - 4
hivemind/server/runtime.py

@@ -1,9 +1,13 @@
 import multiprocessing as mp
 import multiprocessing.pool
 import threading
+from collections import defaultdict
 from itertools import chain
+from queue import SimpleQueue
 from selectors import DefaultSelector, EVENT_READ
-from typing import Dict
+from statistics import mean
+from time import time
+from typing import Dict, NamedTuple
 
 import torch
 from prefetch_generator import BackgroundGenerator
@@ -31,14 +35,15 @@ class Runtime(threading.Thread):
 
     :param expert_backends: a dict [expert uid -> ExpertBackend]
     :param prefetch_batches: form up to this many batches in advance
-    :param start: start runtime immediately (at the end of __init__)
     :param sender_threads: dispatches outputs from finished batches using this many asynchronous threads
     :param device: if specified, moves all experts and data to this device via .to(device=device).
       If you want to manually specify devices for each expert (in their forward pass), leave device=None (default)
+
+    :param stats_report_interval: interval to collect and log statistics about runtime performance
     """
     """
 
 
     def __init__(self, expert_backends: Dict[str, ExpertBackend], prefetch_batches=64, sender_threads: int = 1,
     def __init__(self, expert_backends: Dict[str, ExpertBackend], prefetch_batches=64, sender_threads: int = 1,
-                 device: torch.device = None):
+                 device: torch.device = None, stats_report_interval=30):
         super().__init__()
         self.expert_backends = expert_backends
         self.pools = tuple(chain(*(expert.get_pools() for expert in expert_backends.values())))
@@ -46,6 +51,8 @@ class Runtime(threading.Thread):
         self.shutdown_recv, self.shutdown_send = mp.Pipe(duplex=False)
         self.ready = mp.Event()  # event is set iff server is currently running and ready to accept batches
 
+        self.stats_reporter = StatsReporter(stats_report_interval)
+
     def run(self):
         for pool in self.pools:
             if not pool.is_alive():
@@ -57,15 +64,25 @@ class Runtime(threading.Thread):
         with mp.pool.ThreadPool(self.sender_threads) as output_sender_pool:
             try:
                 self.ready.set()
+                self.stats_reporter.start()
                 logger.info("Started")
                 logger.info("Started")
                 for pool, batch_index, batch in BackgroundGenerator(
                 for pool, batch_index, batch in BackgroundGenerator(
                         self.iterate_minibatches_from_pools(), self.prefetch_batches):
                         self.iterate_minibatches_from_pools(), self.prefetch_batches):
                     logger.debug(f"Processing batch {batch_index} from pool {pool.uid}")
                     logger.debug(f"Processing batch {batch_index} from pool {pool.uid}")
+
+                    start = time()
                     outputs = pool.process_func(*batch)
-                    logger.info(f"Pool {pool.uid}: batch {batch_index} processed, size {outputs[0].size(0)}")
+                    batch_processing_time = time() - start
+
+                    batch_size = outputs[0].size(0)
+                    logger.debug(f"Pool {pool.uid}: batch {batch_index} processed, size {batch_size}")
+                    self.stats_reporter.report_stats(pool.uid, batch_size, batch_processing_time)
+
                     output_sender_pool.apply_async(pool.send_outputs_from_runtime, args=[batch_index, outputs])
             finally:
                 logger.info("Shutting down")
+                self.stats_reporter.stop.set()
+                self.stats_reporter.join()
                 self.shutdown()
 
     SHUTDOWN_TRIGGER = "RUNTIME SHUTDOWN TRIGGERED"
@@ -103,3 +120,44 @@ class Runtime(threading.Thread):
                 batch_index, batch_tensors = pool.load_batch_to_runtime(timeout, self.device)
                 logger.debug(f"Loaded batch from {pool.uid}")
                 yield pool, batch_index, batch_tensors
+
+
+BatchStats = NamedTuple('BatchStats', (('batch_size', int), ('processing_time', float)))
+
+
+class StatsReporter(threading.Thread):
+    def __init__(self, report_interval: int):
+        super().__init__()
+        self.report_interval = report_interval
+        self.stop = threading.Event()
+        self.stats_queue = SimpleQueue()
+
+    def run(self):
+        while not self.stop.wait(self.report_interval):
+            pool_batch_stats = defaultdict(list)
+            while not self.stats_queue.empty():
+                pool_uid, batch_stats = self.stats_queue.get()
+                pool_batch_stats[pool_uid].append(batch_stats)
+
+            total_processed_batches = sum(len(pool_stats) for pool_stats in pool_batch_stats.values())
+            logger.info(f'Processed {total_processed_batches} batches in last {self.report_interval} seconds:')
+            for pool_uid, pool_stats in pool_batch_stats.items():
+                total_batches = len(pool_stats)
+                total_examples = sum(batch_stats.batch_size for batch_stats in pool_stats)
+                avg_batch_size = mean(batch_stats.batch_size for batch_stats in pool_stats)
+                total_time = sum(batch_stats.processing_time for batch_stats in pool_stats)
+                batches_to_time = total_batches / total_time
+                batch_performance = (f'{batches_to_time:.2f} batches/s' if batches_to_time > 1
+                                     else f'{total_time / total_batches:.2f} s/batch')
+
+                examples_to_time = total_examples / total_time
+                example_performance = (f'{examples_to_time:.2f} examples/s' if examples_to_time > 1
+                                       else f'{total_time / total_examples:.2f} s/example')
+
+                logger.info(f'{pool_uid}: '
+                            f'{total_batches} batches ({batch_performance}), '
+                            f'{total_examples} examples ({example_performance}), '
+                            f'avg batch size {avg_batch_size:.2f}')
+
+    def report_stats(self, pool_uid, batch_size, processing_time):
+        batch_stats = BatchStats(batch_size, processing_time)
+        self.stats_queue.put_nowait((pool_uid, batch_stats))

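How the new StatsReporter fits together: Runtime calls report_stats() after each processed batch, which enqueues a (pool_uid, BatchStats) pair on a thread-safe SimpleQueue; the reporter thread wakes every report_interval seconds, drains the queue, groups stats per pool and logs throughput. A worked example of the reported numbers, assuming one pool processed 120 batches of 256 examples in 30 s of compute time:

total_batches, total_examples, total_time = 120, 120 * 256, 30.0
batches_to_time = total_batches / total_time     # 4.00  -> logged as '4.00 batches/s'
examples_to_time = total_examples / total_time   # 1024.00 -> '1024.00 examples/s'
avg_batch_size = total_examples / total_batches  # 256.00 -> 'avg batch size 256.00'
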
+ 52 - 45
hivemind/server/task_pool.py

@@ -142,55 +142,62 @@ class TaskPool(TaskPoolBase):
 
     def _pool_input_loop(self, pending_batches: Dict[Any, List[Task]], *args, **kwargs):
         """ Infinite loop: aggregate tasks into batches and send them to runtime """
-        prev_num_tasks = 0  # number of tasks currently in shared buffer
-        batch_index = max(pending_batches.keys(), default=0)
-        batch_iterator = self.iterate_minibatches(*args, **kwargs)
-
-        while True:
-            # SIDE-EFFECT - compute pool priority from timestamp of earliest undispatched task
-            # assumes that tasks are processed in the same order as they are created
-            for skip_i in range(prev_num_tasks):
-                finished_task_timestamp = self.undispatched_task_timestamps.get()  # earlier timestamp = higher priority
-                if skip_i == prev_num_tasks - 1:
-                    self.priority = finished_task_timestamp
-
-            logger.debug(f"{self.uid} getting next batch")
-            batch_tasks = next(batch_iterator)
-            # save batch futures, _output_loop will deliver on them later
-            pending_batches[batch_index] = batch_tasks
-
-            logger.debug(f"{self.uid}, batch  {batch_index}: aggregating inputs")
-            # find or create shared arrays for current batch size
-            batch_inputs = [torch.cat([task.args[i] for task in batch_tasks]) for i in range(len(batch_tasks[0].args))]
-            batch_inputs = [inp.detach().requires_grad_(inp.requires_grad).share_memory_() for inp in batch_inputs]
-
-            logger.debug(f"{self.uid}, batch {batch_index}: sending to runtime")
-            self.batch_sender.send((batch_index, batch_inputs))
-            logger.debug(f"{self.uid}, batch {batch_index}: sent to runtime")
-            prev_num_tasks = len(batch_tasks)
-            batch_index += 1
+        try:
+            prev_num_tasks = 0  # number of tasks currently in shared buffer
+            batch_index = max(pending_batches.keys(), default=0)
+            batch_iterator = self.iterate_minibatches(*args, **kwargs)
+
+            while True:
+                # SIDE-EFFECT - compute pool priority from timestamp of earliest undispatched task
+                # assumes that tasks are processed in the same order as they are created
+                for skip_i in range(prev_num_tasks):
+                    finished_task_timestamp = self.undispatched_task_timestamps.get()  # earlier timestamp = higher priority
+                    if skip_i == prev_num_tasks - 1:
+                        self.priority = finished_task_timestamp
+
+                logger.debug(f"{self.uid} getting next batch")
+                batch_tasks = next(batch_iterator)
+                # save batch futures, _output_loop will deliver on them later
+                pending_batches[batch_index] = batch_tasks
+
+                logger.debug(f"{self.uid}, batch  {batch_index}: aggregating inputs")
+                # find or create shared arrays for current batch size
+                batch_inputs = [torch.cat([task.args[i] for task in batch_tasks]) for i in
+                                range(len(batch_tasks[0].args))]
+                batch_inputs = [inp.detach().requires_grad_(inp.requires_grad).share_memory_() for inp in batch_inputs]
+
+                logger.debug(f"{self.uid}, batch {batch_index}: sending to runtime")
+                self.batch_sender.send((batch_index, batch_inputs))
+                logger.debug(f"{self.uid}, batch {batch_index}: sent to runtime")
+                prev_num_tasks = len(batch_tasks)
+                batch_index += 1
+        except KeyboardInterrupt:
+            logger.debug('Caught KeyboardInterrupt, shutting down')
 
     def _pool_output_loop(self, pending_batches: Dict[Any, List[Task]]):
         """ Infinite loop: receive results from runtime and dispatch them to task Futures """
 
-        while True:
-            logger.debug(f"{self.uid} waiting for results from runtime")
-            payload = self.outputs_receiver.recv()
-            if isinstance(payload, BaseException):
-                raise payload
-            else:
-                batch_index, batch_outputs = payload
-            logger.debug(f"{self.uid}, batch {batch_index}: got results")
-
-            # split batch into partitions for individual tasks
-            batch_tasks = pending_batches.pop(batch_index)
-            task_sizes = [self.get_task_size(task) for task in batch_tasks]
-            outputs_per_task = zip(*(torch.split_with_sizes(tensor, task_sizes, dim=0) for tensor in batch_outputs))
-            logger.debug(f"{self.uid}, batch {batch_index}: sending outputs to handlers")
-
-            # dispatch results to futures
-            for task, task_outputs in zip(batch_tasks, outputs_per_task):
-                task.future.set_result(tuple(task_outputs))
+        try:
+            while True:
+                logger.debug(f"{self.uid} waiting for results from runtime")
+                payload = self.outputs_receiver.recv()
+                if isinstance(payload, BaseException):
+                    raise payload
+                else:
+                    batch_index, batch_outputs = payload
+                logger.debug(f"{self.uid}, batch {batch_index}: got results")
+
+                # split batch into partitions for individual tasks
+                batch_tasks = pending_batches.pop(batch_index)
+                task_sizes = [self.get_task_size(task) for task in batch_tasks]
+                outputs_per_task = zip(*(torch.split_with_sizes(tensor, task_sizes, dim=0) for tensor in batch_outputs))
+                logger.debug(f"{self.uid}, batch {batch_index}: sending outputs to handlers")
+
+                # dispatch results to futures
+                for task, task_outputs in zip(batch_tasks, outputs_per_task):
+                    task.future.set_result(tuple(task_outputs))
+        except KeyboardInterrupt:
+            logger.debug(f"Caught KeyboardInterrupt, shutting down")
 
     @property
     def empty(self):

+ 1 - 1
hivemind/utils/grpc.py

@@ -16,7 +16,7 @@ from hivemind.utils.timed_storage import TimedStorage, get_dht_time, ValueWithEx
 from hivemind.utils.networking import Endpoint
 from hivemind.utils.logging import get_logger
 
-logger = get_logger(__file__)
+logger = get_logger(__name__)
 
 Stub = TypeVar("Stub")
 

+ 1 - 0
hivemind/utils/logging.py

@@ -15,4 +15,5 @@ def get_logger(module_name: str) -> logging.Logger:
     logger = logging.getLogger(name_without_prefix)
     logger.setLevel(loglevel)
     logger.addHandler(handler)
+    logger.propagate = False
     return logger

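The single propagate = False line is what actually removes the duplicate log entries: get_logger attaches its own handler to each named logger, but every record also propagated up to ancestor loggers, so any handler on the root logger printed the same message a second time. A minimal reproduction, assuming a handler is configured on the root:

import logging

logging.basicConfig()  # attaches a handler to the root logger

child = logging.getLogger('hivemind.server')
child.addHandler(logging.StreamHandler())

child.warning('hello')   # emitted twice: by child's handler and by the root's
child.propagate = False
child.warning('hello')   # emitted once
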
+ 7 - 1
scripts/run_server.py

@@ -6,6 +6,9 @@ import torch
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.server import Server
 from hivemind.utils.threading import increase_file_limit
+from hivemind.utils.logging import get_logger
+
+logger = get_logger(__name__)
 
 
 def main():
@@ -62,9 +65,12 @@ def main():
     else:
         compression = getattr(CompressionType, compression_type)
 
+    server = Server.create(**args, optim_cls=optim_cls, start=True, compression=compression)
+
     try:
-        server = Server.create(**args, optim_cls=optim_cls, start=True, verbose=True, compression=compression)
         server.join()
+    except KeyboardInterrupt:
+        logger.info("Caught KeyboardInterrupt, shutting down")
     finally:
         server.shutdown()
 

+ 1 - 1
tests/benchmark_dht.py

@@ -7,7 +7,7 @@ from tqdm import trange
 import hivemind
 from hivemind.utils.threading import increase_file_limit
 
-logger = hivemind.get_logger(__file__)
+logger = hivemind.get_logger(__name__)
 
 
 def random_endpoint() -> hivemind.Endpoint: