  1. """
  2. Task pool is responsible for receiving tasks and grouping them together for processing (but not processing itself)
  3. """
  4. import ctypes
  5. import multiprocessing as mp
  6. import os
  7. import threading
  8. import time
  9. import uuid
  10. from abc import ABCMeta, abstractmethod
  11. from collections import namedtuple
  12. from concurrent.futures import Future
  13. from queue import Empty
  14. from typing import List, Tuple, Dict, Any, Generator
  15. import torch
  16. from hivemind.utils import MPFuture, get_logger, FutureStateError
  17. logger = get_logger(__name__)
  18. Task = namedtuple("Task", ("future", "args"))
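# A Task bundles the input tensors of one request (`args`) with an MPFuture (`future`)
# through which the processed outputs are later delivered back to the submitting process.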


class TaskPoolBase(mp.context.ForkProcess, metaclass=ABCMeta):
    """ A pool that accepts tasks and forms batches for parallel processing, interacts with Runtime """

    def __init__(self, process_func: callable, daemon=True):
        super().__init__(daemon=daemon)
        self.process_func = process_func
        self._priority = mp.Value(ctypes.c_double, 1.0)  # higher priority = the more urgent to process this pool

    @abstractmethod
    def run(self):
        pass

    @abstractmethod
    def submit_task(self, *args: torch.Tensor) -> Future:
        pass

    @abstractmethod
    def iterate_minibatches(self, *args, **kwargs) -> Generator[List[Task], None, None]:
        pass

    @property
    def priority(self):
        return self._priority.value

    @priority.setter
    def priority(self, value):
        self._priority.value = float(value)

    @property
    @abstractmethod
    def empty(self):
        pass


class TaskPool(TaskPoolBase):
    """
    Request aggregator that accepts processing requests, groups them into batches, waits for Runtime
    to process these batches and dispatches results back to request sources. Operates as a background process.

    :param process_func: function to be applied to every formed batch; called by Runtime
        Note that process_func should accept only positional args (Tensors) and return a flat tuple of Tensors
    :param max_batch_size: process at most this many inputs in a batch (a task may contain one or several inputs)
    :param min_batch_size: process at least this many inputs in a batch, otherwise wait for more
    :param timeout: wait for a subsequent task for at most this many seconds
    :param pool_size: store at most this many unprocessed tasks in a queue
    :param prefetch_batches: prepare up to this many *batches* in background for faster off-loading to runtime
    :param uid: pool identifier used for shared array allocation
    :param start: if True, start automatically at the end of __init__

    A minimal usage sketch (request side and runtime side) is given at the bottom of this file.
    """

    def __init__(self, process_func: callable, max_batch_size: int, min_batch_size=1,
                 timeout=None, pool_size=None, prefetch_batches=1, uid=None, daemon=True, start=False):
        super().__init__(process_func, daemon=daemon)
        self.min_batch_size, self.max_batch_size, self.timeout = min_batch_size, max_batch_size, timeout
        self.uid = uid or uuid.uuid4()
        self.prefetch_batches = prefetch_batches

        # interaction with ConnectionHandlers
        self.tasks = mp.Queue(maxsize=pool_size or 0)
        self.undispatched_task_timestamps = mp.SimpleQueue()

        # interaction with Runtime
        self.batch_receiver, self.batch_sender = mp.Pipe(duplex=False)  # send/recv arrays that contain batch inputs
        self.outputs_receiver, self.outputs_sender = mp.Pipe(duplex=False)  # send/recv arrays that contain outputs

        if start:
            self.start()

    def submit_task(self, *args: torch.Tensor) -> Future:
        """ Add task to this pool's queue, return Future for its output """
        future1, future2 = MPFuture.make_pair()
        task = Task(future1, args)
        if self.get_task_size(task) > self.max_batch_size:
            exc = ValueError(f"Task size exceeds max_batch_size ({self.max_batch_size}), it cannot be processed")
            future2.set_exception(exc)
        else:
            self.tasks.put(task)
            self.undispatched_task_timestamps.put(time.time())
        return future2

    def iterate_minibatches(self, *args, **kwargs):
        """ Form minibatches by grouping one or more tasks together up to self.max_batch_size """
        batch = []
        total_size = 0

        while True:
            if total_size >= self.min_batch_size and self.tasks.empty():
                # queue is drained and the batch is already large enough: flush it
                yield batch
                batch = []
                total_size = 0
            try:
                logger.debug(f"{self.uid} getting next task")
                task = self.tasks.get(timeout=self.timeout)
            except Empty:
                logger.warning(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet")
                continue

            task_size = self.get_task_size(task)
            if total_size + task_size > self.max_batch_size:
                # adding this task would overflow the batch: flush the current batch first
                yield batch
                batch = []
                total_size = 0

            try:
                if task.future.set_running_or_notify_cancel():
                    batch.append(task)
                    total_size += task_size
            except FutureStateError as e:
                logger.debug(f"Failed to add task to batch: {task.future} raised {e}")

    def run(self, *args, **kwargs):
        torch.set_num_threads(1)
        logger.info(f'{self.uid} starting, pid={os.getpid()}')
        pending_batches = {}  # Dict[batch index, List[Task]] for each batch currently in runtime
        # the input loop (this thread) aggregates tasks and ships batches to Runtime;
        # the output loop (background thread) receives results and delivers them to task futures
        output_thread = threading.Thread(target=self._pool_output_loop, args=[pending_batches],
                                         name=f'{self.uid}_output')
        try:
            output_thread.start()
            self._pool_input_loop(pending_batches, *args, **kwargs)
        except BaseException as e:
            # terminate output loop
            self.outputs_sender.send(e)
            output_thread.join()
            raise e

    def _pool_input_loop(self, pending_batches: Dict[Any, List[Task]], *args, **kwargs):
        """ Infinite loop: aggregate tasks into batches and send them to runtime """
        try:
            prev_num_tasks = 0  # number of tasks currently in shared buffer
            batch_index = max(pending_batches.keys(), default=0)
            batch_iterator = self.iterate_minibatches(*args, **kwargs)

            while True:
                # SIDE-EFFECT - compute pool priority from timestamp of earliest undispatched task
                # assumes that tasks are processed in the same order as they are created
                for skip_i in range(prev_num_tasks):
                    finished_task_timestamp = self.undispatched_task_timestamps.get()  # earlier timestamp = higher priority
                    if skip_i == prev_num_tasks - 1:
                        self.priority = finished_task_timestamp

                logger.debug(f"{self.uid} getting next batch")
                batch_tasks = next(batch_iterator)
                # save batch futures, _output_loop will deliver on them later
                pending_batches[batch_index] = batch_tasks

                logger.debug(f"{self.uid}, batch {batch_index}: aggregating inputs")
                # concatenate task inputs along dim 0 and move them to shared memory for the runtime process
                batch_inputs = [torch.cat([task.args[i] for task in batch_tasks])
                                for i in range(len(batch_tasks[0].args))]
                batch_inputs = [inp.detach().requires_grad_(inp.requires_grad).share_memory_() for inp in batch_inputs]

                logger.debug(f"{self.uid}, batch {batch_index}: sending to runtime")
                self.batch_sender.send((batch_index, batch_inputs))
                logger.debug(f"{self.uid}, batch {batch_index}: sent to runtime")
                prev_num_tasks = len(batch_tasks)
                batch_index += 1
        except KeyboardInterrupt:
            logger.debug('Caught KeyboardInterrupt, shutting down')

    def _pool_output_loop(self, pending_batches: Dict[Any, List[Task]]):
        """ Infinite loop: receive results from runtime and dispatch them to task Futures """
        try:
            while True:
                logger.debug(f"{self.uid} waiting for results from runtime")
                payload = self.outputs_receiver.recv()
                if isinstance(payload, BaseException):
                    raise payload
                else:
                    batch_index, batch_outputs = payload
                logger.debug(f"{self.uid}, batch {batch_index}: got results")

                # split batch into partitions for individual tasks
                batch_tasks = pending_batches.pop(batch_index)
                task_sizes = [self.get_task_size(task) for task in batch_tasks]
                outputs_per_task = zip(*(torch.split_with_sizes(tensor, task_sizes, dim=0) for tensor in batch_outputs))
                logger.debug(f"{self.uid}, batch {batch_index}: sending outputs to handlers")

                # dispatch results to futures
                for task, task_outputs in zip(batch_tasks, outputs_per_task):
                    try:
                        task.future.set_result(tuple(task_outputs))
                    except FutureStateError as e:
                        logger.debug(f"Failed to send task result due to an exception: {e}")
        except KeyboardInterrupt:
            logger.debug("Caught KeyboardInterrupt, shutting down")

    @property
    def empty(self):
        return not self.batch_receiver.poll()

    def load_batch_to_runtime(self, timeout=None, device=None) -> Tuple[Any, List[torch.Tensor]]:
        """ receive next batch of input tensors and move them to the requested device """
        if not self.batch_receiver.poll(timeout):
            raise TimeoutError()
        batch_index, batch_inputs = self.batch_receiver.recv()
        batch_inputs = [tensor.to(device, non_blocking=True) for tensor in batch_inputs]
        return batch_index, batch_inputs

    def send_outputs_from_runtime(self, batch_index: int, batch_outputs: List[torch.Tensor]):
        """ send results for a processed batch, previously loaded through load_batch_to_runtime """
        batch_outputs = [tensor.to(device='cpu').share_memory_().detach().requires_grad_(tensor.requires_grad)
                         for tensor in batch_outputs]
        self.outputs_sender.send((batch_index, batch_outputs))

    def get_task_size(self, task: Task) -> int:
        """ compute task processing complexity (used for batching); defaults to batch size """
        return len(task.args[0]) if task.args else 1
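

# Illustrative smoke-test sketch of how the pieces above fit together: the main process stands in
# for hivemind.server.Runtime, which is what normally calls load_batch_to_runtime /
# send_outputs_from_runtime in a real server. `double_inputs` is a toy process_func used only here;
# the sketch assumes small CPU tensors and relies on the fork-based process that TaskPoolBase inherits.
if __name__ == "__main__":
    def double_inputs(x: torch.Tensor):
        # toy process_func: accepts positional tensors, returns a flat tuple of tensors
        return (x * 2,)

    pool = TaskPool(double_inputs, max_batch_size=16, start=True)
    try:
        # a ConnectionHandler would call submit_task and await the returned future
        future = pool.submit_task(torch.ones(3, 4))
        # Runtime would fetch the next batch, apply process_func and ship the outputs back
        batch_index, batch_inputs = pool.load_batch_to_runtime(timeout=5.0)
        pool.send_outputs_from_runtime(batch_index, list(pool.process_func(*batch_inputs)))
        print(future.result())  # a tuple with one (3, 4) tensor filled with twos
    finally:
        pool.terminate()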