task_pool.py

  1. """
  2. Task pool is responsible for receiving tasks and grouping them together for processing (but not processing itself)
  3. """
  4. import ctypes
  5. import multiprocessing as mp
  6. import os
  7. import threading
  8. import time
  9. import uuid
  10. from collections import namedtuple
  11. from concurrent.futures import Future
  12. from queue import Empty
  13. from typing import List, Tuple, Dict, Any
  14. import torch
  15. from ..utils import SharedFuture
  16. Task = namedtuple("Task", ("future", "args"))


class TaskPoolBase(mp.Process):
    """ A pool that accepts tasks and forms batches for parallel processing, interacts with TesseractRuntime """

    def __init__(self, process_func: callable):
        super().__init__()
        self.process_func = process_func
        self._priority = mp.Value(ctypes.c_double, 1.0)  # higher priority = more urgent to process this pool

    def run(self):
        raise NotImplementedError()

    def submit_task(self, *args: torch.Tensor) -> Future:
        raise NotImplementedError()

    def form_batch(self, *args, **kwargs) -> List[Task]:
        raise NotImplementedError()

    def iterate_minibatches(self, *args, **kwargs):
        while True:
            yield self.form_batch(*args, **kwargs)

    @property
    def priority(self):
        return self._priority.value

    @priority.setter
    def priority(self, value):
        self._priority.value = float(value)

    @property
    def empty(self):
        raise NotImplementedError()


class TaskPool(TaskPoolBase):
    """
    Request aggregator that accepts processing requests, groups them into batches, waits for TesseractRuntime
    to process these batches and dispatches results back to request sources. Operates as a background process.

    :param process_func: function to be applied to every formed batch; called by TesseractRuntime
        Note: process_func should accept only *args Tensors and return a list of output Tensors
    :param max_batch_size: process at most this many inputs in a batch (a task may contain one or several inputs)
    :param min_batch_size: process at least this many inputs in a batch, otherwise wait for more
    :param timeout: wait for a subsequent task for at most this many seconds
    :param pool_size: store at most this many unprocessed tasks in a queue
    :param prefetch_batches: prepare up to this many *batches* in background for faster off-loading to runtime
    :param uid: pool identifier used for shared array allocation
    :param start: if True, start automatically at the end of __init__
    """

    def __init__(self, process_func: callable, max_batch_size: int, min_batch_size=1,
                 timeout=None, pool_size=None, prefetch_batches=1, uid=None, start=False):
        super().__init__(process_func)
        self.min_batch_size, self.max_batch_size, self.timeout = min_batch_size, max_batch_size, timeout
        self.uid = uid or uuid.uuid4()
        self.prefetch_batches = prefetch_batches

        # interaction with ConnectionHandlers
        self.tasks = mp.Queue(maxsize=pool_size or 0)
        self.undispatched_task_timestamps = mp.SimpleQueue()

        # interaction with TesseractRuntime
        self.batch_receiver, self.batch_sender = mp.Pipe(duplex=False)  # send/recv arrays that contain batch inputs
        self.batch_received = mp.Event()  # runtime can notify pool that it can send next batch
        self.outputs_receiver, self.outputs_sender = mp.Pipe(duplex=False)  # send/recv arrays that contain outputs

        if start:
            self.start()

    def submit_task(self, *args: torch.Tensor) -> Future:
        # one end of the future pair travels to the pool with the task, the other is returned to the caller
        future1, future2 = SharedFuture.make_pair()
        self.tasks.put(Task(future1, args))
        self.undispatched_task_timestamps.put(time.time())
        return future2

    def form_batch(self) -> List[Task]:
        batch_tasks = []
        total_size = 0

        while total_size < self.max_batch_size:
            if total_size >= self.min_batch_size and self.tasks.empty():
                break  # enough inputs gathered and no tasks are waiting, return the incomplete batch

            try:
                task = self.tasks.get(timeout=self.timeout)
            except Empty:
                exc = TimeoutError(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet.")
                for task in batch_tasks:
                    task.future.set_exception(exc)
                raise exc

            if task.future.set_running_or_notify_cancel():
                batch_tasks.append(task)
                total_size += self.get_task_size(task)

        return batch_tasks

    def run(self, *args, **kwargs):
        print(f'Starting pool, pid={os.getpid()}')
        pending_batches = {}  # Dict[batch index, List[Task]] for each batch currently in runtime
        output_thread = threading.Thread(target=self._pool_output_loop, args=[pending_batches],
                                         name=f'{self.uid}-pool_output_loop')
        try:
            output_thread.start()
            self._pool_input_loop(pending_batches, *args, **kwargs)
        except BaseException as e:
            # terminate output loop by sending the exception through the outputs pipe
            self.outputs_sender.send(e)
            output_thread.join()
            raise e

    def _pool_input_loop(self, pending_batches: Dict[Any, List[Task]], *args, **kwargs):
        """ Infinite loop: aggregate tasks into batches and send them to runtime """
        prev_num_tasks = 0  # number of tasks currently in shared buffer
        batch_index = max(pending_batches.keys(), default=0)
        batch_iterator = self.iterate_minibatches(*args, **kwargs)
        self.batch_received.set()  # initial state: no batches/outputs pending

        while True:
            self.batch_received.wait()  # wait for runtime to receive (copy) previous batch

            # SIDE-EFFECT - compute pool priority from timestamp of earliest undispatched task
            # assumes that tasks are processed in the same order as they are created
            for skip_i in range(prev_num_tasks):
                finished_task_timestamp = self.undispatched_task_timestamps.get()  # earlier timestamp = higher priority
                if skip_i == prev_num_tasks - 1:
                    self.priority = finished_task_timestamp

            batch_tasks = next(batch_iterator)
            # save batch futures, _pool_output_loop will deliver on them later
            pending_batches[batch_index] = batch_tasks

            # concatenate task inputs into shared-memory tensors for the current batch
            batch_inputs = [
                torch.cat([task.args[i] for task in batch_tasks]).share_memory_()
                for i in range(len(batch_tasks[0].args))
            ]

            self.batch_received.clear()  # sending next batch...
            self.batch_sender.send((batch_index, batch_inputs))
            prev_num_tasks = len(batch_tasks)
            batch_index += 1

    def _pool_output_loop(self, pending_batches: Dict[Any, List[Task]]):
        """ Infinite loop: receive results from runtime and dispatch them to task Futures """
        while True:
            payload = self.outputs_receiver.recv()
            if isinstance(payload, BaseException):
                raise payload
            batch_index, batch_outputs = payload

            # split batch into partitions for individual tasks
            batch_tasks = pending_batches.pop(batch_index)
            task_sizes = [self.get_task_size(task) for task in batch_tasks]
            outputs_per_task = zip(*(torch.split_with_sizes(array, task_sizes, dim=0) for array in batch_outputs))

            # dispatch results to futures
            for task, task_outputs in zip(batch_tasks, outputs_per_task):
                task.future.set_result(tuple(task_outputs))

    @property
    def empty(self):
        return not self.batch_receiver.poll()

    def load_batch_to_runtime(self, timeout=None, device=None) -> Tuple[Any, List[torch.Tensor]]:
        """ receive the next batch of input tensors from the pool """
        if not self.batch_receiver.poll(timeout):
            raise TimeoutError()

        batch_index, batch_inputs = self.batch_receiver.recv()
        self.batch_received.set()  # pool can now prepare next batch
        batch_inputs = [tensor.to(device, non_blocking=True) for tensor in batch_inputs]
        return batch_index, batch_inputs

    def send_outputs_from_runtime(self, batch_index: int, batch_outputs: List[torch.Tensor]):
        """ send results for a processed batch, previously loaded through load_batch_to_runtime """
        batch_outputs = [tensor.to(device='cpu').share_memory_() for tensor in batch_outputs]
        self.outputs_sender.send((batch_index, batch_outputs))

    def get_task_size(self, task: Task) -> int:
        """ compute task processing complexity (used for batching); defaults to batch size """
        return len(task.args[0]) if task.args else 1
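

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). It shows the
# intended interaction: a client submits tasks and receives Futures, while the
# runtime (normally TesseractRuntime) pulls batches with load_batch_to_runtime
# and returns results with send_outputs_from_runtime. It assumes a fork-based
# multiprocessing start method, that SharedFuture delivers results across
# processes, and that the module is executed from within its package
# (the relative import above prevents running this file as a standalone script).
if __name__ == '__main__':
    def square(*inputs: torch.Tensor) -> List[torch.Tensor]:
        # toy process_func: accepts *args tensors, returns a list of output tensors
        return [x ** 2 for x in inputs]

    pool = TaskPool(square, max_batch_size=16, start=True)

    # client side (normally a ConnectionHandler): submit a task, keep the future
    future = pool.submit_task(torch.arange(4, dtype=torch.float32))

    # runtime side (normally TesseractRuntime): fetch a batch, process it, send outputs back
    batch_index, batch_inputs = pool.load_batch_to_runtime(timeout=5.0)
    pool.send_outputs_from_runtime(batch_index, square(*batch_inputs))

    print(future.result())  # expected: (tensor([0., 1., 4., 9.]),)
    pool.terminate()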