# task_pool.py

  1. """
  2. Task pool is responsible for receiving tasks and grouping them together for processing (but not processing itself)
  3. """
import ctypes
import multiprocessing as mp
import os
import threading
import time
import uuid
from collections import namedtuple
from concurrent.futures import Future
from queue import Empty
from typing import List, Tuple, Dict, Any

import torch

from ..utils import SharedFuture

Task = namedtuple("Task", ("future", "args"))


class TaskPoolBase(mp.Process):
    """ A pool that accepts tasks and forms batches for parallel processing, interacts with TesseractRuntime """

    def __init__(self, process_func: callable):
        super().__init__()
        self.process_func = process_func
        self._priority = mp.Value(ctypes.c_double, 1.0)  # higher priority = more urgent to process this pool

    def run(self):
        raise NotImplementedError()

    def submit_task(self, *args: torch.Tensor) -> Future:
        raise NotImplementedError()

    def form_batch(self, *args, **kwargs) -> List[Task]:
        raise NotImplementedError()

    def iterate_minibatches(self, *args, **kwargs):
        while True:
            yield self.form_batch(*args, **kwargs)

    @property
    def priority(self):
        return self._priority.value

    @priority.setter
    def priority(self, value):
        self._priority.value = float(value)

    @property
    def empty(self):
        raise NotImplementedError()


class TaskPool(TaskPoolBase):

    def __init__(self, process_func: callable, max_batch_size: int, min_batch_size=1,
                 timeout=None, pool_size=None, prefetch_batches=1, uid=None, start=False):
        """
        Naive implementation of a task pool that forms batches from the earliest submitted tasks

        :param process_func: function to be applied to every formed batch; called by TesseractRuntime
            Note: process_func should accept only *args Tensors and return a list of output Tensors
        :param max_batch_size: process at most this many inputs in a batch (a task may contain one or several inputs)
        :param min_batch_size: process at least this many inputs in a batch, otherwise wait for more
        :param timeout: wait for a subsequent task for at most this many seconds
        :param pool_size: store at most this many unprocessed tasks in a queue
        :param prefetch_batches: prepare up to this many *batches* in background for faster off-loading to runtime
        :param uid: pool identifier used for shared array allocation
        :param start: if True, start automatically at the end of __init__
        """
        super().__init__(process_func)
        self.min_batch_size, self.max_batch_size, self.timeout = min_batch_size, max_batch_size, timeout
        self.uid = uid or uuid.uuid4()
        self.prefetch_batches = prefetch_batches

        # interaction with ConnectionHandlers
        self.tasks = mp.Queue(maxsize=pool_size or 0)
        self.undispatched_task_timestamps = mp.SimpleQueue()

        # interaction with TesseractRuntime
        self.batch_receiver, self.batch_sender = mp.Pipe(duplex=False)  # send/recv arrays that contain batch inputs
        self.batch_received = mp.Event()  # runtime can notify pool that it can send next batch
        self.outputs_receiver, self.outputs_sender = mp.Pipe(duplex=False)  # send/recv arrays that contain outputs

        if start:
            self.start()
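
    # A minimal usage sketch of the constructor contract above (hypothetical names, not part of this module):
    # process_func receives *args tensors and returns a list of output tensors, e.g.
    #
    #     def square(*batch_inputs): return [x ** 2 for x in batch_inputs]
    #     pool = TaskPool(square, max_batch_size=16, start=True)
    #     future = pool.submit_task(torch.randn(4, 8))  # resolved once a runtime processes the batch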

    def submit_task(self, *args: torch.Tensor) -> Future:
        future1, future2 = SharedFuture.make_pair()
        self.tasks.put(Task(future1, args))
        self.undispatched_task_timestamps.put(time.time())
        return future2

    def form_batch(self) -> List[Task]:
        batch_tasks = []
        total_size = 0

        while total_size < self.max_batch_size:
            if total_size >= self.min_batch_size and self.tasks.empty():
                break  # we already have enough inputs and no tasks are waiting: return the batch as-is

            try:
                task = self.tasks.get(timeout=self.timeout)
            except Empty:
                exc = TimeoutError(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet.")
                for task in batch_tasks:
                    task.future.set_exception(exc)
                raise exc

            if task.future.set_running_or_notify_cancel():
                batch_tasks.append(task)
                total_size += self.get_task_size(task)

        return batch_tasks
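
    # Illustrative trace of the batching policy above (hypothetical sizes): with min_batch_size=2,
    # max_batch_size=8 and queued tasks of 3, 3 and 4 rows, form_batch returns all three tasks
    # (total_size=10). max_batch_size is only checked before fetching the next task, so the final
    # task in a batch may overshoot it.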

    def run(self, *args, **kwargs):
        print(f'Starting pool, pid={os.getpid()}')
        pending_batches = {}  # Dict[batch_index, List[Task]] for each batch currently in runtime
        output_thread = threading.Thread(target=self._pool_output_loop, args=[pending_batches],
                                         name=f'{self.uid}-pool_output_loop')
        try:
            output_thread.start()
            self._pool_input_loop(pending_batches, *args, **kwargs)
        except BaseException as e:
            # terminate output loop
            self.outputs_sender.send(e)
            output_thread.join()
            raise e

    def _pool_input_loop(self, pending_batches: Dict[Any, List[Task]], *args, **kwargs):
        """ Infinite loop: aggregate tasks into batches and send them to runtime """
        prev_num_tasks = 0  # number of tasks in the previously dispatched batch
        batch_index = max(pending_batches.keys(), default=0)
        batch_iterator = self.iterate_minibatches(*args, **kwargs)
        self.batch_received.set()  # initial state: no batches/outputs pending

        while True:
            self.batch_received.wait()  # wait for runtime to receive (copy) previous batch

            # SIDE-EFFECT - compute pool priority from timestamp of earliest undispatched task
            # assumes that tasks are processed in the same order as they are created
            for skip_i in range(prev_num_tasks):
                finished_task_timestamp = self.undispatched_task_timestamps.get()  # earlier timestamp = higher priority
                if skip_i == prev_num_tasks - 1:
                    self.priority = finished_task_timestamp

            batch_tasks = next(batch_iterator)
            # save batch futures, _output_loop will deliver on them later
            pending_batches[batch_index] = batch_tasks

            # concatenate task inputs into shared-memory tensors for the current batch
            batch_inputs = [
                torch.cat([task.args[i] for task in batch_tasks]).share_memory_()
                for i in range(len(batch_tasks[0].args))
            ]

            self.batch_received.clear()  # sending next batch...
            self.batch_sender.send((batch_index, batch_inputs))
            prev_num_tasks = len(batch_tasks)
            batch_index += 1

    def _pool_output_loop(self, pending_batches: Dict[Any, List[Task]]):
        """ Infinite loop: receive results from runtime and dispatch them to task Futures """
        while True:
            payload = self.outputs_receiver.recv()
            if isinstance(payload, BaseException):
                raise payload
            else:
                batch_index, batch_outputs = payload

            # split batch into partitions for individual tasks
            batch_tasks = pending_batches.pop(batch_index)
            task_sizes = [self.get_task_size(task) for task in batch_tasks]
            outputs_per_task = zip(*(torch.split_with_sizes(array, task_sizes, dim=0) for array in batch_outputs))

            # dispatch results to futures
            for task, task_outputs in zip(batch_tasks, outputs_per_task):
                task.future.set_result(tuple(task_outputs))
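
    # Worked example of the dispatch above (hypothetical sizes): for two tasks of 3 and 5 rows and
    # a single output tensor of 8 rows, split_with_sizes yields a 3-row and a 5-row chunk, and each
    # task's future receives a 1-element tuple with its chunk.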

    @property
    def empty(self):
        return not self.batch_receiver.poll()

    def load_batch_to_runtime(self, timeout=None, device=None) -> Tuple[Any, List[torch.Tensor]]:
        """ receive the next batch of input tensors and move them to the requested device """
        if not self.batch_receiver.poll(timeout):
            raise TimeoutError()

        batch_index, batch_inputs = self.batch_receiver.recv()
        self.batch_received.set()  # pool can now prepare next batch
        batch_inputs = [tensor.to(device, non_blocking=True) for tensor in batch_inputs]
        return batch_index, batch_inputs

    def send_outputs_from_runtime(self, batch_index: int, batch_outputs: List[torch.Tensor]):
        """ send results for a processed batch, previously loaded through load_batch_to_runtime """
        batch_outputs = [tensor.to(device='cpu').share_memory_() for tensor in batch_outputs]
        self.outputs_sender.send((batch_index, batch_outputs))

    def get_task_size(self, task: Task) -> int:
        """ compute task processing complexity (used for batching); defaults to batch size """
        return len(task.args[0]) if task.args else 1
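

# ---------------------------------------------------------------------------
# Illustrative sketch only: how a runtime-like consumer could drive a TaskPool.
# `_example_runtime_step` is a hypothetical helper that is not used anywhere in
# this package; it simply chains the three runtime-facing calls above
# (load_batch_to_runtime -> process_func -> send_outputs_from_runtime).
# ---------------------------------------------------------------------------
def _example_runtime_step(pool: TaskPool, device=None, timeout=None):
    """ Process a single pending batch from `pool`; a real runtime would repeat this for each of its pools (see `priority`) """
    batch_index, batch_inputs = pool.load_batch_to_runtime(timeout=timeout, device=device)
    batch_outputs = pool.process_func(*batch_inputs)
    pool.send_outputs_from_runtime(batch_index, batch_outputs)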