task_pool.py

  1. """
  2. Task pool is responsible for receiving tasks and grouping them together for processing (but not processing itself)
  3. """
  4. import ctypes
  5. import multiprocessing as mp
  6. import os
  7. import threading
  8. import time
  9. from abc import ABCMeta, abstractmethod
  10. from collections import namedtuple
  11. from concurrent.futures import Future
  12. from queue import Empty
  13. from typing import Any, Callable, Dict, Generator, List, Tuple
  14. import torch
  15. from hivemind.utils import get_logger
  16. from hivemind.utils.mpfuture import InvalidStateError, MPFuture
  17. logger = get_logger(__name__)
  18. Task = namedtuple("Task", ("future", "args"))
  19. class TaskPoolBase(mp.context.ForkProcess, metaclass=ABCMeta):
  20. """A pool that accepts tasks and forms batches for parallel processing, interacts with Runtime"""
  21. def __init__(self, process_func: Callable, daemon=True, **kwargs):
  22. super().__init__(daemon=daemon, **kwargs)
  23. self.process_func = process_func
  24. self._priority = mp.Value(ctypes.c_double, 1.0) # higher priority = the more urgent to process this pool
  25. @abstractmethod
  26. def run(self):
  27. pass
  28. @abstractmethod
  29. def submit_task(self, *args: torch.Tensor) -> Future:
  30. pass
  31. @abstractmethod
  32. def iterate_minibatches(self, *args, **kwargs) -> Generator[List[Task], None, None]:
  33. pass
  34. @property
  35. def priority(self):
  36. return self._priority.value
  37. @priority.setter
  38. def priority(self, value):
  39. self._priority.value = float(value)
  40. @property
  41. @abstractmethod
  42. def empty(self):
  43. pass
class TaskPool(TaskPoolBase):
    """
    Request aggregator that accepts processing requests, groups them into batches, waits for Runtime
    to process these batches and dispatches results back to request sources. Operates as a background process.

    :param process_func: function to be applied to every formed batch; called by Runtime
      Note that process_func should accept only positional args (Tensors) and return a flat tuple of Tensors
    :param max_batch_size: process at most this many inputs in a batch (a task may contain one or several inputs)
    :param name: pool name
    :param min_batch_size: process at least this many inputs in a batch, otherwise wait for more
    :param timeout: wait for a subsequent task for at most this many seconds
    :param pool_size: store at most this many unprocessed tasks in a queue
    :param prefetch_batches: prepare up to this many *batches* in background for faster off-loading to runtime
    :param start: if True, start automatically at the end of __init__

    A minimal usage sketch appears at the bottom of this file.
    """
    def __init__(
        self,
        process_func: Callable,
        max_batch_size: int,
        name: str,
        min_batch_size=1,
        timeout=None,
        pool_size=None,
        prefetch_batches=1,
        daemon=True,
        start=False,
    ):
        super().__init__(process_func, daemon=daemon, name=name)
        self.min_batch_size, self.max_batch_size, self.timeout = min_batch_size, max_batch_size, timeout
        self.prefetch_batches = prefetch_batches

        # interaction with ConnectionHandlers
        self.tasks = mp.Queue(maxsize=pool_size or 0)
        self.undispatched_task_timestamps = mp.SimpleQueue()

        # interaction with Runtime
        self.batch_receiver, self.batch_sender = mp.Pipe(duplex=False)  # send/recv arrays that contain batch inputs
        self.outputs_receiver, self.outputs_sender = mp.Pipe(duplex=False)  # send/recv arrays that contain outputs

        if start:
            self.start()

    def submit_task(self, *args: torch.Tensor) -> Future:
        """Add task to this pool's queue, return Future for its output"""
        task = Task(MPFuture(), args)
        if self.get_task_size(task) > self.max_batch_size:
            exc = ValueError(f"Task size greater than max_batch_size ({self.max_batch_size}), it can't be processed")
            task.future.set_exception(exc)
        else:
            self.tasks.put(task)
            self.undispatched_task_timestamps.put(time.time())
        return task.future
    def iterate_minibatches(self, *args, **kwargs):
        """Form minibatches by grouping one or more tasks together up to self.max_batch_size"""
        batch = []
        total_size = 0

        while True:
            if total_size >= self.min_batch_size and self.tasks.empty():
                yield batch
                batch = []
                total_size = 0
            try:
                logger.debug(f"{self.name} getting next task")
                task = self.tasks.get(timeout=self.timeout)
            except Empty:
                logger.warning(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet")
                continue

            task_size = self.get_task_size(task)

            if total_size + task_size > self.max_batch_size:
                yield batch
                batch = []
                total_size = 0

            try:
                if task.future.set_running_or_notify_cancel():
                    batch.append(task)
                    total_size += task_size
            except InvalidStateError as e:
                logger.debug(f"Failed to add task to batch: {task.future} raised {e}")
    def run(self, *args, **kwargs):
        torch.set_num_threads(1)
        logger.info(f"{self.name} starting, pid={os.getpid()}")
        pending_batches = {}  # Dict[batch uuid, List[MPFuture]] for each batch currently in runtime

        output_thread = threading.Thread(
            target=self._pool_output_loop, args=[pending_batches], name=f"{self.name}_output", daemon=True
        )

        try:
            output_thread.start()
            self._pool_input_loop(pending_batches, *args, **kwargs)
        except KeyboardInterrupt:
            logger.debug("Caught KeyboardInterrupt, shutting down")
        finally:
            output_thread.join()
    def _pool_input_loop(self, pending_batches: Dict[Any, List[Task]], *args, **kwargs):
        """Infinite loop: aggregate tasks into batches and send them to runtime"""
        prev_num_tasks = 0  # number of tasks currently in shared buffer
        batch_index = max(pending_batches.keys(), default=0)
        batch_iterator = self.iterate_minibatches(*args, **kwargs)

        while True:
            # SIDE-EFFECT - compute pool priority from timestamp of earliest undispatched task
            # assumes that tasks are processed in the same order as they are created
            for skip_i in range(prev_num_tasks):
                finished_task_timestamp = (
                    self.undispatched_task_timestamps.get()
                )  # earlier timestamp = higher priority
                if skip_i == prev_num_tasks - 1:
                    self.priority = finished_task_timestamp

            logger.debug(f"{self.name} getting next batch")
            batch_tasks = next(batch_iterator)
            # save batch futures, _output_loop will deliver on them later
            pending_batches[batch_index] = batch_tasks

            logger.debug(f"{self.name}, batch {batch_index}: aggregating inputs")
            # find or create shared arrays for current batch size
            batch_inputs = [
                torch.cat([task.args[i] for task in batch_tasks]) for i in range(len(batch_tasks[0].args))
            ]
            batch_inputs = [inp.detach().requires_grad_(inp.requires_grad).share_memory_() for inp in batch_inputs]

            logger.debug(f"{self.name}, batch {batch_index}: sending to runtime")
            self.batch_sender.send((batch_index, batch_inputs))
            logger.debug(f"{self.name}, batch {batch_index}: sent to runtime")
            prev_num_tasks = len(batch_tasks)
            batch_index += 1
    def _pool_output_loop(self, pending_batches: Dict[Any, List[Task]]):
        """Infinite loop: receive results from runtime and dispatch them to task Futures"""
        while True:
            logger.debug(f"{self.name} waiting for results from runtime")
            batch_index, batch_outputs = self.outputs_receiver.recv()
            logger.debug(f"{self.name}, batch {batch_index}: got results")

            # split batch into partitions for individual tasks
            batch_tasks = pending_batches.pop(batch_index)
            task_sizes = [self.get_task_size(task) for task in batch_tasks]
            outputs_per_task = zip(*(torch.split_with_sizes(tensor, task_sizes, dim=0) for tensor in batch_outputs))
            logger.debug(f"{self.name}, batch {batch_index}: sending outputs to handlers")

            # dispatch results to futures
            for task, task_outputs in zip(batch_tasks, outputs_per_task):
                try:
                    task.future.set_result(tuple(task_outputs))
                except InvalidStateError as e:
                    logger.debug(f"Failed to send task result due to an exception: {e}")
    @property
    def empty(self):
        return not self.batch_receiver.poll()

    def load_batch_to_runtime(self, timeout=None, device=None) -> Tuple[Any, List[torch.Tensor]]:
        """receive the next batch of input tensors, moved to the requested device"""
        if not self.batch_receiver.poll(timeout):
            raise TimeoutError()

        batch_index, batch_inputs = self.batch_receiver.recv()
        batch_inputs = [tensor.to(device, non_blocking=True) for tensor in batch_inputs]
        return batch_index, batch_inputs

    def send_outputs_from_runtime(self, batch_index: int, batch_outputs: List[torch.Tensor]):
        """send results for a processed batch, previously loaded through load_batch_to_runtime"""
        batch_outputs = [
            tensor.to(device="cpu").share_memory_().detach().requires_grad_(tensor.requires_grad)
            for tensor in batch_outputs
        ]
        self.outputs_sender.send((batch_index, batch_outputs))

    def get_task_size(self, task: Task) -> int:
        """compute task processing complexity (used for batching); defaults to batch size"""
        return len(task.args[0]) if task.args else 1
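

# ---------------------------------------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the library API): it shows how a TaskPool is typically driven,
# with the runtime side imitated inline via load_batch_to_runtime / send_outputs_from_runtime (in hivemind
# this is normally done by the Runtime class and the client side by ConnectionHandlers). The function name
# `double_inputs` and all numeric values below are arbitrary examples; the pool runs as a forked process,
# so this sketch assumes a platform that supports fork.
if __name__ == "__main__":

    def double_inputs(batch: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        # process_func takes positional tensors and returns a flat tuple of tensors
        return (batch * 2,)

    pool = TaskPool(double_inputs, max_batch_size=16, name="demo_pool", start=True)

    # client side: submit one task; its size is the first dimension of args[0] (here, 3)
    future = pool.submit_task(torch.randn(3, 4))

    # runtime side: fetch the aggregated batch, apply process_func, return the outputs to the pool
    batch_index, batch_inputs = pool.load_batch_to_runtime(timeout=10.0)
    batch_outputs = pool.process_func(*batch_inputs)
    pool.send_outputs_from_runtime(batch_index, list(batch_outputs))

    # the pool's output loop splits the batch back into per-task results and fulfills each task's future
    print(future.result(timeout=10.0))  # a tuple with one (3, 4) tensor: the doubled input
    pool.terminate()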