  1. """
  2. Task pool is responsible for receiving tasks and grouping them together for processing (but not processing itself)
  3. """
  4. import ctypes
  5. import multiprocessing as mp
  6. import os
  7. import threading
  8. import time
  9. import uuid
  10. from collections import namedtuple
  11. from concurrent.futures import Future
  12. from queue import Empty
  13. from typing import List, Tuple, Dict, Any
  14. import torch
  15. from ..utils import SharedFuture
  16. Task = namedtuple("Task", ("future", "args"))

class TaskPoolBase(mp.Process):
    """ A pool that accepts tasks and forms batches for parallel processing, interacts with TesseractRuntime """

    def __init__(self, process_func: callable):
        super().__init__()
        self.process_func = process_func
        self._priority = mp.Value(ctypes.c_double, 1.0)  # higher priority = more urgent to process this pool

    def run(self):
        raise NotImplementedError()

    def submit_task(self, *args: torch.Tensor) -> Future:
        raise NotImplementedError()

    def form_batch(self, *args, **kwargs) -> List[Task]:
        raise NotImplementedError()

    def iterate_minibatches(self, *args, **kwargs):
        while True:
            yield self.form_batch(*args, **kwargs)

    @property
    def priority(self):
        return self._priority.value

    @priority.setter
    def priority(self, value):
        self._priority.value = float(value)

    @property
    def empty(self):
        raise NotImplementedError()


class TaskPool(TaskPoolBase):
    """
    Request aggregator that accepts processing requests, groups them into batches, waits for TesseractRuntime
    to process these batches and dispatches results back to request sources. Operates as a background process.

    :param process_func: function to be applied to every formed batch; called by TesseractRuntime
        Note that process_func should accept only \*args Tensors and return a flat tuple of Tensors
    :param max_batch_size: process at most this many inputs in a batch (a task may contain one or several inputs)
    :param min_batch_size: process at least this many inputs in a batch, otherwise wait for more
    :param timeout: wait for a subsequent task for at most this many seconds
    :param pool_size: store at most this many unprocessed tasks in a queue
    :param prefetch_batches: prepare up to this many *batches* in background for faster off-loading to runtime
    :param uid: pool identifier used for shared array allocation
    :param start: if True, start automatically at the end of __init__
    """

    def __init__(self, process_func: callable, max_batch_size: int, min_batch_size=1,
                 timeout=None, pool_size=None, prefetch_batches=1, uid=None, start=False):
        super().__init__(process_func)
        self.min_batch_size, self.max_batch_size, self.timeout = min_batch_size, max_batch_size, timeout
        self.uid = uid or uuid.uuid4()
        self.prefetch_batches = prefetch_batches

        # interaction with ConnectionHandlers
        self.tasks = mp.Queue(maxsize=pool_size or 0)
        self.undispatched_task_timestamps = mp.SimpleQueue()

        # interaction with TesseractRuntime
        self.batch_receiver, self.batch_sender = mp.Pipe(duplex=False)  # send/recv arrays that contain batch inputs
        self.batch_received = mp.Event()  # runtime can notify pool that it can send next batch
        self.outputs_receiver, self.outputs_sender = mp.Pipe(duplex=False)  # send/recv arrays that contain outputs

        if start:
            self.start()

    def submit_task(self, *args: torch.Tensor) -> Future:
        """ Add task to this pool's queue, return Future for its output """
        future1, future2 = SharedFuture.make_pair()
        self.tasks.put(Task(future1, args))
        self.undispatched_task_timestamps.put(time.time())
        return future2

    def form_batch(self) -> List[Task]:
        batch_tasks = []
        total_size = 0

        while total_size < self.max_batch_size:
            if total_size >= self.min_batch_size and self.tasks.empty():
                break  # no more tasks in queue, dispatch an incomplete (but sufficient) batch

            try:
                task = self.tasks.get(timeout=self.timeout)
            except Empty:
                exc = TimeoutError(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet.")
                for task in batch_tasks:
                    task.future.set_exception(exc)
                raise exc

            if task.future.set_running_or_notify_cancel():
                batch_tasks.append(task)
                total_size += self.get_task_size(task)

        return batch_tasks

    def run(self, *args, **kwargs):
        print(f'Starting pool, pid={os.getpid()}')
        pending_batches = {}  # Dict[batch_index, List[Task]] for each batch currently in runtime
        output_thread = threading.Thread(target=self._pool_output_loop, args=[pending_batches],
                                         name=f'{self.uid}-pool_output_loop')
        try:
            output_thread.start()
            self._pool_input_loop(pending_batches, *args, **kwargs)
        except BaseException as e:
            # terminate output loop
            self.outputs_sender.send(e)
            output_thread.join()
            raise e

    def _pool_input_loop(self, pending_batches: Dict[Any, List[Task]], *args, **kwargs):
        """ Infinite loop: aggregate tasks into batches and send them to runtime """
        prev_num_tasks = 0  # number of tasks currently in shared buffer
        batch_index = max(pending_batches.keys(), default=0)
        batch_iterator = self.iterate_minibatches(*args, **kwargs)
        self.batch_received.set()  # initial state: no batches/outputs pending

        while True:
            self.batch_received.wait()  # wait for runtime to receive (copy) previous batch

            # SIDE-EFFECT - compute pool priority from timestamp of earliest undispatched task
            # assumes that tasks are processed in the same order as they are created
            for skip_i in range(prev_num_tasks):
                finished_task_timestamp = self.undispatched_task_timestamps.get()  # earlier timestamp = higher priority
                if skip_i == prev_num_tasks - 1:
                    self.priority = finished_task_timestamp

            batch_tasks = next(batch_iterator)
            # save batch futures, _pool_output_loop will deliver on them later
            pending_batches[batch_index] = batch_tasks

            # concatenate task inputs along dim 0 and move them to shared memory for the runtime process
            batch_inputs = [
                torch.cat([task.args[i] for task in batch_tasks]).share_memory_()
                for i in range(len(batch_tasks[0].args))
            ]

            self.batch_received.clear()  # sending next batch...
            self.batch_sender.send((batch_index, batch_inputs))
            prev_num_tasks = len(batch_tasks)
            batch_index += 1

    def _pool_output_loop(self, pending_batches: Dict[Any, List[Task]]):
        """ Infinite loop: receive results from runtime and dispatch them to task Futures """
        while True:
            payload = self.outputs_receiver.recv()
            if isinstance(payload, BaseException):
                raise payload
            else:
                batch_index, batch_outputs = payload

            # split batch into partitions for individual tasks
            batch_tasks = pending_batches.pop(batch_index)
            task_sizes = [self.get_task_size(task) for task in batch_tasks]
            outputs_per_task = zip(*(torch.split_with_sizes(array, task_sizes, dim=0) for array in batch_outputs))

            # dispatch results to futures
            for task, task_outputs in zip(batch_tasks, outputs_per_task):
                task.future.set_result(tuple(task_outputs))

    @property
    def empty(self):
        return not self.batch_receiver.poll()

    def load_batch_to_runtime(self, timeout=None, device=None) -> Tuple[Any, List[torch.Tensor]]:
        """ receive the next batch of input tensors prepared by the pool """
        if not self.batch_receiver.poll(timeout):
            raise TimeoutError()

        batch_index, batch_inputs = self.batch_receiver.recv()
        self.batch_received.set()  # pool can now prepare next batch
        batch_inputs = [tensor.to(device, non_blocking=True) for tensor in batch_inputs]
        return batch_index, batch_inputs

    def send_outputs_from_runtime(self, batch_index: int, batch_outputs: List[torch.Tensor]):
        """ send results for a processed batch, previously loaded through load_batch_to_runtime """
        batch_outputs = [tensor.to(device='cpu').share_memory_() for tensor in batch_outputs]
        self.outputs_sender.send((batch_index, batch_outputs))

    def get_task_size(self, task: Task) -> int:
        """ compute task processing complexity (used for batching); defaults to batch size """
        return len(task.args[0]) if task.args else 1
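

# ------------------------------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the pool API): a minimal end-to-end round trip in
# which the main process plays both roles - a ConnectionHandler that submits a task and a
# TesseractRuntime that processes the resulting batch. The toy `double_all` process_func is a
# hypothetical stand-in; the sketch assumes a 'fork' start method and that SharedFuture results
# can be awaited with .result() from the submitting process.
if __name__ == '__main__':
    def double_all(*batch_inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """ toy process_func: accepts only *args tensors and returns a flat tuple of tensors """
        return tuple(tensor * 2 for tensor in batch_inputs)

    # start the pool as a background process; timeout=None makes form_batch wait for tasks forever
    pool = TaskPool(double_all, max_batch_size=4, min_batch_size=1, start=True)

    # ConnectionHandler side: enqueue one task whose single input has batch dimension 1
    future = pool.submit_task(torch.ones(1, 3))

    # TesseractRuntime side: fetch the aggregated batch, apply process_func, send outputs back
    batch_index, batch_inputs = pool.load_batch_to_runtime(timeout=5.0)
    pool.send_outputs_from_runtime(batch_index, list(pool.process_func(*batch_inputs)))

    print(future.result(timeout=5.0))  # a tuple with one tensor of twos, delivered by the pool
    pool.terminate()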