Преглед изворни кода

refactor priority pool, copy-paste runtime from hivemind (but without bugs!)

justheuristic пре 2 година
родитељ
комит
65ad7a131e
6 измењених фајлова са 412 додато и 59 уклоњено
  1. 4 0
      cli/run_server.py
  2. 1 2
      src/server/backend.py
  3. 198 0
      src/server/runtime.py
  4. 147 57
      src/server/task_pool.py
  5. 0 0
      tests/test_point_system.py
  6. 62 0
      tests/test_priority_pool.py

+ 4 - 0
cli/run_server.py

@@ -34,6 +34,10 @@ def main():
                         help='Minimum required batch size for all expert operations')
     parser.add_argument('--max_batch_size', type=int, default=16384,
                         help='The total number of tokens in the same batch will not exceed this value')
+    parser.add_argument('--prefetch_batches', type=int, default=1, required=False,
+                        help='Pre-form this many subsequent batches while GPU is processing the current one')
+    parser.add_argument('--sender_threads', type=int, default=1, required=False,
+                        help='Use this many threads to pass results/exceptions from Runtime to Pools')
     parser.add_argument('--inference_max_length', type=int, default=16384,
                         help='Maximum total sequence length permitted per inference, defaults to 16384 tokens')
     parser.add_argument('--cache_dir', type=str, default=None, 

+ 1 - 2
src/server/backend.py

@@ -4,7 +4,6 @@ from typing import Any, Dict, Optional, Sequence, Tuple
 import torch
 from hivemind import BatchTensorDescriptor, use_hivemind_log_handler
 from hivemind.moe.server.module_backend import ModuleBackend
-from hivemind.moe.server.task_pool import TaskPool
 from hivemind.utils import get_logger
 
 from src.bloom.from_pretrained import BloomBlock
@@ -72,7 +71,7 @@ class TransformerBackend(ModuleBackend):
                 cache[1, :, prefix_length:new_length, :] = new_v[:, prefix_length:new_length]
                 return (hidden_states,)
 
-    def get_pools(self) -> Sequence[TaskPool]:
+    def get_pools(self) -> Sequence[PrioritizedTaskPool]:
         return self.forward_pool, self.backward_pool, self.inference_pool
 
     def get_info(self) -> Dict[str, Any]:

+ 198 - 0
src/server/runtime.py

@@ -0,0 +1,198 @@
+import multiprocessing as mp
+import multiprocessing.pool
+import threading
+from collections import defaultdict
+from itertools import chain
+from queue import SimpleQueue
+from selectors import EVENT_READ, DefaultSelector
+from statistics import mean
+from time import time
+from typing import Dict, NamedTuple, Optional
+
+import torch
+from hivemind.moe.server.module_backend import ModuleBackend
+from hivemind.utils import get_logger
+from prefetch_generator import BackgroundGenerator
+
+logger = get_logger(__name__)
+
+
+class Runtime(threading.Thread):
+    """
+    A group of processes that processes incoming requests for multiple module backends on a shared device.
+    Runtime is usually created and managed by Server, humans need not apply.
+
+    For debugging, you can start runtime manually with .start() or .run()
+
+    >>> module_backends = {'expert_name': ModuleBackend(**kwargs)}
+    >>> runtime = Runtime(module_backends)
+    >>> runtime.start()  # start runtime in background thread. To start in current thread, use runtime.run()
+    >>> runtime.ready.wait()  # await for runtime to load all experts on device and create request pools
+    >>> future = runtime.module_backends['expert_name'].forward_pool.submit_task(*module_inputs)
+    >>> print("Returned:", future.result())
+    >>> runtime.shutdown()
+
+    :param module_backends: a dict [expert uid -> ModuleBackend]
+    :param prefetch_batches: form up to this many batches in advance
+    :param sender_threads: dispatches outputs from finished batches using this many asynchronous threads
+    :param device: if specified, moves all experts and data to this device via .to(device=device).
+      If you want to manually specify devices for each expert (in their forward pass), leave device=None (default)
+
+    :param stats_report_interval: interval to collect and log statistics about runtime performance
+    """
+
+    SHUTDOWN_TRIGGER = "RUNTIME SHUTDOWN TRIGGERED"
+
+    def __init__(
+        self,
+        module_backends: Dict[str, ModuleBackend],
+        prefetch_batches: int = 1,
+        sender_threads: int = 1,
+        device: torch.device = None,
+        stats_report_interval: Optional[int] = None,
+    ):
+        super().__init__()
+        self.module_backends = module_backends
+        self.pools = tuple(chain(*(backend.get_pools() for backend in module_backends.values())))
+        self.device, self.prefetch_batches, self.sender_threads = device, prefetch_batches, sender_threads
+        self.shutdown_recv, self.shutdown_send = mp.Pipe(duplex=False)
+        self.shutdown_trigger = mp.Event()
+        self.ready = mp.Event()  # event is set iff server is currently running and ready to accept batches
+
+        self.stats_report_interval = stats_report_interval
+        if self.stats_report_interval is not None:
+            self.stats_reporter = StatsReporter(self.stats_report_interval)
+
+    def run(self):
+        for pool in self.pools:
+            if not pool.is_alive():
+                pool.start()
+        if self.device is not None:
+            for backend in self.module_backends.values():
+                backend.module.to(self.device)
+
+        with mp.pool.ThreadPool(self.sender_threads) as output_sender_pool:
+            try:
+                self.ready.set()
+                if self.stats_report_interval is not None:
+                    self.stats_reporter.start()
+                logger.info("Started")
+
+                batch_iterator = self.iterate_minibatches_from_pools()
+                if self.prefetch_batches > 0:
+                    batch_iterator = BackgroundGenerator(batch_iterator, self.prefetch_batches)
+
+                for pool, batch_index, batch in batch_iterator:
+                    logger.debug(f"Processing batch {batch_index} from pool {pool.name}")
+
+                    start = time()
+                    try:
+                        outputs = pool.process_func(*batch)
+                        output_sender_pool.apply_async(pool.send_outputs_from_runtime, args=[batch_index, outputs])
+
+                        batch_processing_time = time() - start
+
+                        batch_size = outputs[0].size(0)
+                        logger.debug(f"Pool {pool.name}: batch {batch_index} processed, size {batch_size}")
+
+                        if self.stats_report_interval is not None:
+                            self.stats_reporter.report_stats(pool.name, batch_size, batch_processing_time)
+
+                    except KeyboardInterrupt:
+                        raise
+                    except BaseException as exception:
+                        logger.exception(f"Caught {exception}, attempting to recover")
+                        output_sender_pool.apply_async(pool.send_exception_from_runtime, args=[batch_index, exception])
+
+            finally:
+                if not self.shutdown_trigger.is_set():
+                    self.shutdown()
+
+    def shutdown(self):
+        """Gracefully terminate a running runtime."""
+        logger.info("Shutting down")
+        self.ready.clear()
+
+        if self.stats_report_interval is not None:
+            self.stats_reporter.stop.set()
+            self.stats_reporter.join()
+
+        logger.debug("Terminating pools")
+        for pool in self.pools:
+            if pool.is_alive():
+                pool.shutdown()
+        logger.debug("Pools terminated")
+
+        # trigger background thread to shutdown
+        self.shutdown_send.send(self.SHUTDOWN_TRIGGER)
+        self.shutdown_trigger.set()
+
+    def iterate_minibatches_from_pools(self, timeout=None):
+        """
+        Chooses pool according to priority, then copies exposed batch and frees the buffer
+        """
+        with DefaultSelector() as selector:
+            for pool in self.pools:
+                selector.register(pool.batch_receiver, EVENT_READ, pool)
+            selector.register(self.shutdown_recv, EVENT_READ, self.SHUTDOWN_TRIGGER)
+
+            while True:
+                # wait until at least one batch_receiver becomes available
+                logger.debug("Waiting for inputs from task pools")
+                ready_fds = selector.select()
+                ready_objects = {key.data for (key, events) in ready_fds}
+                if self.SHUTDOWN_TRIGGER in ready_objects:
+                    break  # someone asked us to shutdown, break from the loop
+
+                logger.debug("Choosing the pool with first priority")
+
+                pool = min(ready_objects, key=lambda pool: pool.priority)
+
+                logger.debug(f"Loading batch from {pool.name}")
+                batch_index, batch_tensors = pool.load_batch_to_runtime(timeout, self.device)
+                logger.debug(f"Loaded batch from {pool.name}")
+                yield pool, batch_index, batch_tensors
+
+
+BatchStats = NamedTuple("BatchStats", (("batch_size", int), ("processing_time", float)))
+
+
+class StatsReporter(threading.Thread):
+    def __init__(self, report_interval: int):
+        super().__init__()
+        self.report_interval = report_interval
+        self.stop = threading.Event()
+        self.stats_queue = SimpleQueue()
+
+    def run(self):
+        while not self.stop.wait(self.report_interval):
+            pool_batch_stats = defaultdict(list)
+            while not self.stats_queue.empty():
+                pool_uid, batch_stats = self.stats_queue.get()
+                pool_batch_stats[pool_uid].append(batch_stats)
+
+            total_processed_batches = sum(len(pool_stats) for pool_stats in pool_batch_stats.values())
+            logger.info(f"Processed {total_processed_batches} batches in last {self.report_interval} seconds:")
+            for pool_uid, pool_stats in pool_batch_stats.items():
+                total_batches = len(pool_stats)
+                total_examples = sum(batch_stats.batch_size for batch_stats in pool_stats)
+                avg_batch_size = mean(batch_stats.batch_size for batch_stats in pool_stats)
+                total_time = sum(batch_stats.processing_time for batch_stats in pool_stats)
+                batches_to_time = total_batches / total_time
+                batch_performance = f"{batches_to_time:.2f} " + ("batches/s" if batches_to_time > 1 else "s/batch")
+
+                examples_to_time = total_examples / total_time
+                example_performance = f"{examples_to_time:.2f} " + (
+                    "examples/s" if examples_to_time > 1 else "s/example"
+                )
+
+                logger.info(
+                    f"{pool_uid}: "
+                    f"{total_batches} batches ({batch_performance}), "
+                    f"{total_examples} examples ({example_performance}), "
+                    f"avg batch size {avg_batch_size:.2f}"
+                )
+
+    def report_stats(self, pool_uid, batch_size, processing_time):
+        batch_stats = BatchStats(batch_size, processing_time)
+        self.stats_queue.put_nowait((pool_uid, batch_stats))

+ 147 - 57
src/server/task_pool.py

@@ -1,87 +1,177 @@
+import ctypes
 import multiprocessing as mp
 import os
 import threading
+import time
 from concurrent.futures import Future, InvalidStateError
 from dataclasses import dataclass, field
-from queue import PriorityQueue, Empty
-from typing import Sequence
+from queue import Empty, PriorityQueue
+from typing import Any, Dict, Generator, List, Optional, Sequence, Tuple
 
 import torch
-from hivemind import MPFuture, use_hivemind_log_handler, get_logger
+from hivemind import MPFuture, get_logger, use_hivemind_log_handler
 from hivemind.moe.server.task_pool import TaskPoolBase
 
-
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 
 @dataclass(order=True, frozen=True)
-class PrioritizedTask:
+class Task:
     priority: float
+    time_submitted: float
     future: MPFuture = field(compare=False)
-    args: Sequence[torch.Tensor]  = field(compare=False)
+    args: Sequence[torch.Tensor] = field(compare=False)
+
+    @property
+    def uid(self) -> int:
+        return self.future._uid
 
 
 class PrioritizedTaskPool(TaskPoolBase):
     """
+    Aggregates requests from multiple ConnectionHandler instances, orders them for processing in Runtime, then
+    returns results (or exception) to the corresponding ConnectionHandler. Runs a background process.
+    A single PrioritizedTaskPool services a specific function (e.g. layer1.forward, layer2.forward or layer1.backward)
+
+    :note: unlike hivemind.moe TaskPool, this pool does *not* combine incoming requests into batches.
+      This would require grouping requests of different length.
+
+    :param process_func: function to be applied to every formed batch; called by Runtime
+        Note that process_func should accept only positional args (Tensors) and return a flat tuple of Tensors
+    :param max_batch_size: process at most this many inputs in a batch (task contains have one or several inputs)
+         Measured in the total number of tokens (i.e. batch size * sequence length)
 
+    :param name: pool name, used for logging
+    :param min_batch_size: process at least this many inputs in a batch, otherwise wait for more
+    :param start: if True, start automatically at the end of __init__
     """
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
 
-        assert self.min_batch_size == 1, "PriorityTaskPool supports no batching"
+    def __init__(
+        self,
+        process_func: callable,
+        max_batch_size: int,
+        name: str,
+        min_batch_size=1,
+        daemon=True,
+        start=False,
+    ):
+        super().__init__(process_func, daemon=daemon, name=name)
+        self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
+
+        self.submitted_tasks = mp.SimpleQueue()  # interaction with ConnectionHandlers
+        self._ordered_tasks = PriorityQueue()  # interaction with Runtime - only valid inside Runtime
+
+        self._prioritizer_thread = threading.Thread(
+            name=self.name + "_prioritizer",
+            target=self._prioritize_tasks,
+            args=[self.submitted_tasks, self._ordered_tasks],
+            daemon=True,
+        )
+        self._dispatched_tasks = {}
+        self.batch_receiver, self.batch_sender = mp.Pipe(duplex=False)
+        self._oldest_undispatched_timestamp = mp.Value(ctypes.c_double, 1.0)
+        self.priority = float("inf"), float("inf")  # (first task priority, first task timestamp)
+        if start:
+            self.start()
+
+    @staticmethod
+    def _prioritize_tasks(submitted_tasks: mp.SimpleQueue, ordered_tasks: PriorityQueue):
+        """Read tasks from incoming queue and put them into a local priority queue"""
+        while True:
+            task = submitted_tasks.get()
+            if task is None:
+                logger.debug("Shutting down prioritizer thread")
+                break
 
-        self.priority_queue = mp.Queue(maxsize=self.tasks._maxsize)
-        self.prioritized_task_queue = PriorityQueue(maxsize=self.tasks._maxsize)
+            ordered_tasks.put(task, block=True)
 
-    def submit_task(self, *args: torch.Tensor, priority: float = 0.0) -> Future:
-        f = super().submit_task(*args)
-        self.priority_queue.put(priority)
-        # TODO use a single queue here
-        return f
+    def start(self):
+        assert not self.is_alive() and not self._prioritizer_thread.is_alive()
+        self._prioritizer_thread.start()
+        super().start()
 
-    def _priortize_tasks(self):
-        """Infinite loop prioritizing incoming tasks"""
-        while True:
-            task = self.tasks.get(block=True)
-            priority = self.priority_queue.get(block=True)
-            self.prioritized_task_queue.put(PrioritizedTask(priority, task), block=True)
+    def shutdown(self, timeout: Optional[float] = None):
+        self.submitted_tasks.put(None)
+        self.terminate()
+        self._prioritizer_thread.join(timeout)
+
+    def submit_task(self, *args: torch.Tensor, priority: float = 0.0) -> Future:
+        """Add task to this pool's queue, return Future for its output"""
+        task = Task(priority, time.monotonic(), MPFuture(), args)
+        if self.get_task_size(task) > self.max_batch_size:
+            exc = ValueError(f"Task size greater than max_batch_size ({self.max_batch_size}), it can't be processed")
+            task.future.set_exception(exc)
+        else:
+            self.submitted_tasks.put(task)
+            self.batch_sender.send(None)  # use this pipe to count the number of unfinished batches
+            if (task.priority, task.time_submitted) < self.priority:
+                self.priority = (task.priority, task.time_submitted)
+        return task.future
+
+    def get_task_size(self, task: Task) -> int:
+        """compute task processing complexity; defaults to the total number of tokens"""
+        if task.args and task.args[0].ndim >= 2:
+            return task.args[0].shape[0] * task.args[0].shape[1]
+        return 1
+
+    def load_batch_to_runtime(
+        self, timeout: Optional[float] = None, device: Optional[torch.device] = None
+    ) -> Tuple[Any, List[torch.Tensor]]:
+        """receive next batch of arrays"""
+        task = self._ordered_tasks.get(block=True, timeout=timeout)
+        batch_inputs = [
+            tensor.detach().to(device, non_blocking=True).requires_grad_(tensor.requires_grad) for tensor in task.args
+        ]
+        self._dispatched_tasks[task.uid] = task
+        self.batch_receiver.recv()  # reduce the number of active batches
+        if not self._ordered_tasks.empty():
+            first_remaining_task: Task = self._ordered_tasks.queue[0]
+            self.priority = (first_remaining_task.priority, first_remaining_task.time_submitted)
+        return task.uid, batch_inputs
+
+    def send_outputs_from_runtime(self, uid: int, batch_outputs: List[torch.Tensor]):
+        """send results for a processed batch, previously loaded through load_batch_to_runtime"""
+        batch_outputs = [
+            tensor.to(device="cpu").share_memory_().detach().requires_grad_(tensor.requires_grad)
+            for tensor in batch_outputs
+        ]
+
+        task = self._dispatched_tasks.pop(uid, None)
+        if task is None:
+            logger.error(
+                f"Internal error: task task with index {uid} is missing from the dictionary; " f"Could not set result"
+            )
+        else:
+            task.future.set_result(batch_outputs)
+
+    def send_exception_from_runtime(self, uid: int, exception: BaseException):
+        task = self._dispatched_tasks.pop(uid, None)
+        if task is None:
+            logger.error(
+                f"Internal error: task task with index {uid} is missing from the dictionary; "
+                f"Could not set exception {exception}"
+            )
+        else:
+            task.future.set_exception(exception)
 
     def run(self, *args, **kwargs):
-        torch.set_num_threads(1)
-        logger.info(f"{self.name} starting, pid={os.getpid()}")
-        pending_batches = {}  # Dict[batch uuid, List[MPFuture]] for each batch currently in runtime
+        mp.Event().wait()
 
-        output_thread = threading.Thread(
-            target=self._pool_output_loop, args=[pending_batches], name=f"{self.name}_output", daemon=True
-        )
-        priority_thread = threading.Thread(
-            target=self._priortize_tasks, args=[], name=f"{self.name}_priority", daemon=True
-        )
+    @property
+    def empty(self):
+        return not self.batch_receiver.poll()
 
-        try:
-            output_thread.start()
-            priority_thread.start()
-            self._pool_input_loop(pending_batches, *args, **kwargs)
-        except KeyboardInterrupt:
-            logger.debug("Caught KeyboardInterrupt, shutting down")
-        finally:
-            output_thread.join()
-            priority_thread.join()
-
-    # TODO: this is a copy-paste of the original method, except that we use different queue
-    def iterate_minibatches(self, *args, **kwargs):
-        """Form minibatches by grouping one or more tasks together up to self.max_batch_size"""
-        while True:
-            try:
-                logger.debug(f"{self.name} getting next task")
-                task: PrioritizedTask = self.prioritized_task_queue.get(timeout=self.timeout)
-            except Empty:
-                logger.warning(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet")
-                continue
-
-            try:
-                if task.task.future.set_running_or_notify_cancel():
-                    yield [task.task]
-            except InvalidStateError as e:
-                logger.debug(f"Failed to add task to batch: {task.task.future} raised {e}")
+    @property
+    def priority(self) -> Tuple[float, float]:
+        """The priority of this pool equals the (priority, timestamp) of the most important task in it."""
+        return float(self._priority.value), float(self._oldest_undispatched_timestamp.value)
+
+    @priority.setter
+    def priority(self, item: Tuple[float, float]):
+        assert len(item) == 2
+        self._priority.value = float(item[0])
+        self._oldest_undispatched_timestamp.value = float(item[1])
+
+    def iterate_minibatches(self, *args, **kwargs) -> Generator[List[Task], None, None]:
+        raise NotImplementedError()

+ 0 - 0
tests/test_dust_payment.py → tests/test_point_system.py


+ 62 - 0
tests/test_priority_pool.py

@@ -0,0 +1,62 @@
+import multiprocessing as mp
+import time
+
+import pytest
+import torch
+
+from src.server.runtime import Runtime
+from src.server.task_pool import PrioritizedTaskPool
+
+
+@pytest.mark.forked
+def test_priority_pools():
+    outputs_queue = mp.SimpleQueue()
+    results_valid = mp.Event()
+
+    def dummy_pool_func(x):
+        time.sleep(0.1)
+        y = x**2
+        outputs_queue.put((x, y))
+        return (y,)
+
+    class DummyBackend:
+        def __init__(self, pools):
+            self.pools = pools
+
+        def get_pools(self):
+            return self.pools
+
+    pools = (
+        PrioritizedTaskPool(dummy_pool_func, name="A", max_batch_size=1),
+        PrioritizedTaskPool(dummy_pool_func, name="B", max_batch_size=1),
+    )
+
+    runtime = Runtime({str(i): DummyBackend([pool]) for i, pool in enumerate(pools)}, prefetch_batches=0)
+    runtime.start()
+
+    def process_tasks():
+        futures = []
+        futures.append(pools[0].submit_task(torch.tensor([0]), priority=1))
+        futures.append(pools[0].submit_task(torch.tensor([1]), priority=1))
+        time.sleep(0.01)
+        futures.append(pools[1].submit_task(torch.tensor([2]), priority=1))
+        futures.append(pools[0].submit_task(torch.tensor([3]), priority=2))
+        futures.append(pools[0].submit_task(torch.tensor([4]), priority=10))
+        futures.append(pools[0].submit_task(torch.tensor([5]), priority=0))
+        futures.append(pools[0].submit_task(torch.tensor([6]), priority=1))
+        futures.append(pools[1].submit_task(torch.tensor([7]), priority=11))
+        futures.append(pools[1].submit_task(torch.tensor([8]), priority=1))
+        for i, f in enumerate(futures):
+            assert f.result()[0].item() == i**2
+        results_valid.set()
+
+    proc = mp.Process(target=process_tasks)
+    proc.start()
+    proc.join()
+    assert results_valid.is_set()
+
+    ordered_outputs = []
+    while not outputs_queue.empty():
+        ordered_outputs.append(outputs_queue.get()[0].item())
+
+    assert ordered_outputs == [0, 5, 1, 2, 6, 8, 3, 4, 7]