task_pool.py 9.9 KB

  1. """
  2. Task pool is responsible for receiving tasks and grouping them together for processing (but not processing itself)
  3. """
  4. import ctypes
  5. import multiprocessing as mp
  6. import multiprocessing.context
  7. import os
  8. import threading
  9. import time
  10. import uuid
  11. from collections import namedtuple
  12. from concurrent.futures import Future
  13. from queue import Empty
  14. from typing import List, Tuple, Dict, Any, Generator
  15. import torch
  16. from hivemind.utils import MPFuture, get_logger
  17. logger = get_logger(__name__)
  18. Task = namedtuple("Task", ("future", "args"))


class TaskPoolBase(mp.context.ForkProcess):
    """ A pool that accepts tasks and forms batches for parallel processing, interacts with Runtime """

    def __init__(self, process_func: callable, daemon=True):
        super().__init__(daemon=daemon)
        self.process_func = process_func
        self._priority = mp.Value(ctypes.c_double, 1.0)  # higher priority = more urgent to process this pool

    def run(self):
        raise NotImplementedError()

    def submit_task(self, *args: torch.Tensor) -> Future:
        raise NotImplementedError()

    def iterate_minibatches(self, *args, **kwargs) -> Generator[List[Task], None, None]:
        raise NotImplementedError()

    @property
    def priority(self):
        return self._priority.value

    @priority.setter
    def priority(self, value):
        self._priority.value = float(value)

    @property
    def empty(self):
        raise NotImplementedError()


class TaskPool(TaskPoolBase):
    """
    Request aggregator that accepts processing requests, groups them into batches, waits for Runtime
    to process these batches and dispatches results back to request sources. Operates as a background process.

    :param process_func: function to be applied to every formed batch; called by Runtime
        Note that process_func should accept only positional args (Tensors) and return a flat tuple of Tensors
    :param max_batch_size: process at most this many inputs in a batch (a task may contain one or several inputs)
    :param min_batch_size: process at least this many inputs in a batch, otherwise wait for more
    :param timeout: wait for a subsequent task for at most this many seconds
    :param pool_size: store at most this many unprocessed tasks in a queue
    :param prefetch_batches: prepare up to this many *batches* in background for faster off-loading to runtime
    :param uid: pool identifier used for shared array allocation
    :param start: if True, start automatically at the end of __init__
    """

    def __init__(self, process_func: callable, max_batch_size: int, min_batch_size=1,
                 timeout=None, pool_size=None, prefetch_batches=1, uid=None, daemon=True, start=False):
        super().__init__(process_func, daemon=daemon)
        self.min_batch_size, self.max_batch_size, self.timeout = min_batch_size, max_batch_size, timeout
        self.uid = uid or uuid.uuid4()
        self.prefetch_batches = prefetch_batches

        # interaction with ConnectionHandlers
        self.tasks = mp.Queue(maxsize=pool_size or 0)
        self.undispatched_task_timestamps = mp.SimpleQueue()

        # interaction with Runtime
        self.batch_receiver, self.batch_sender = mp.Pipe(duplex=False)  # send/recv arrays that contain batch inputs
        self.outputs_receiver, self.outputs_sender = mp.Pipe(duplex=False)  # send/recv arrays that contain outputs

        if start:
            self.start()

    def submit_task(self, *args: torch.Tensor) -> Future:
        """ Add task to this pool's queue, return Future for its output """
        future1, future2 = MPFuture.make_pair()
        task = Task(future1, args)
        if self.get_task_size(task) > self.max_batch_size:
            exc = ValueError(f"Task size greater than max_batch_size ({self.max_batch_size}), it can't be processed")
            future2.set_exception(exc)
        else:
            self.tasks.put(task)
            self.undispatched_task_timestamps.put(time.time())
        return future2
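
    # Caller-side usage (illustrative sketch only; `pool` and `input_tensor` are hypothetical names,
    # not defined in this module):
    #   future = pool.submit_task(input_tensor)  # enqueue the request, returns immediately
    #   outputs = future.result()                # blocks until the Runtime has processed the batch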

    def iterate_minibatches(self, *args, **kwargs):
        """ Form minibatches by grouping one or more tasks together up to self.max_batch_size """
        batch = []
        total_size = 0

        while True:
            if total_size >= self.min_batch_size and self.tasks.empty():
                yield batch
                batch = []
                total_size = 0
            try:
                logger.debug(f"{self.uid} getting next task")
                task = self.tasks.get(timeout=self.timeout)
            except Empty:
                logger.warning(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet")
                continue

            task_size = self.get_task_size(task)

            if total_size + task_size > self.max_batch_size:
                yield batch
                batch = []
                total_size = 0

            if task.future.set_running_or_notify_cancel():
                batch.append(task)
                total_size += task_size
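
    # For example, with max_batch_size=16 and three tasks of sizes 6, 6 and 8 already in the queue,
    # the generator above yields the two size-6 tasks as one 12-input batch and the size-8 task as
    # the next batch, since 12 + 8 would exceed max_batch_size.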

    def run(self, *args, **kwargs):
        torch.set_num_threads(1)
        logger.info(f'{self.uid} starting, pid={os.getpid()}')
        pending_batches = {}  # Dict[batch index, List[Task]] for each batch currently in runtime
        output_thread = threading.Thread(target=self._pool_output_loop, args=[pending_batches],
                                         name=f'{self.uid}_output')
        try:
            output_thread.start()
            self._pool_input_loop(pending_batches, *args, **kwargs)
        except BaseException as e:
            # terminate output loop
            self.outputs_sender.send(e)
            output_thread.join()
            raise e

    def _pool_input_loop(self, pending_batches: Dict[Any, List[Task]], *args, **kwargs):
        """ Infinite loop: aggregate tasks into batches and send them to runtime """
        try:
            prev_num_tasks = 0  # number of tasks currently in shared buffer
            batch_index = max(pending_batches.keys(), default=0)
            batch_iterator = self.iterate_minibatches(*args, **kwargs)

            while True:
                # SIDE-EFFECT - compute pool priority from timestamp of earliest undispatched task
                # assumes that tasks are processed in the same order as they are created
                for skip_i in range(prev_num_tasks):
                    finished_task_timestamp = self.undispatched_task_timestamps.get()  # earlier timestamp = higher priority
                    if skip_i == prev_num_tasks - 1:
                        self.priority = finished_task_timestamp

                logger.debug(f"{self.uid} getting next batch")
                batch_tasks = next(batch_iterator)
                # save batch futures, _output_loop will deliver on them later
                pending_batches[batch_index] = batch_tasks

                logger.debug(f"{self.uid}, batch {batch_index}: aggregating inputs")
                # concatenate task inputs along dim 0 and move them to shared memory for the runtime process
                batch_inputs = [torch.cat([task.args[i] for task in batch_tasks])
                                for i in range(len(batch_tasks[0].args))]
                batch_inputs = [inp.detach().requires_grad_(inp.requires_grad).share_memory_() for inp in batch_inputs]

                logger.debug(f"{self.uid}, batch {batch_index}: sending to runtime")
                self.batch_sender.send((batch_index, batch_inputs))
                logger.debug(f"{self.uid}, batch {batch_index}: sent to runtime")
                prev_num_tasks = len(batch_tasks)
                batch_index += 1
        except KeyboardInterrupt:
            logger.debug('Caught KeyboardInterrupt, shutting down')

    def _pool_output_loop(self, pending_batches: Dict[Any, List[Task]]):
        """ Infinite loop: receive results from runtime and dispatch them to task Futures """
        try:
            while True:
                logger.debug(f"{self.uid} waiting for results from runtime")
                payload = self.outputs_receiver.recv()
                if isinstance(payload, BaseException):
                    raise payload
                else:
                    batch_index, batch_outputs = payload
                logger.debug(f"{self.uid}, batch {batch_index}: got results")

                # split batch into partitions for individual tasks
                batch_tasks = pending_batches.pop(batch_index)
                task_sizes = [self.get_task_size(task) for task in batch_tasks]
                outputs_per_task = zip(*(torch.split_with_sizes(tensor, task_sizes, dim=0) for tensor in batch_outputs))
                logger.debug(f"{self.uid}, batch {batch_index}: sending outputs to handlers")

                # dispatch results to futures
                for task, task_outputs in zip(batch_tasks, outputs_per_task):
                    task.future.set_result(tuple(task_outputs))
        except KeyboardInterrupt:
            logger.debug("Caught KeyboardInterrupt, shutting down")

    @property
    def empty(self):
        return not self.batch_receiver.poll()

    def load_batch_to_runtime(self, timeout=None, device=None) -> Tuple[Any, List[torch.Tensor]]:
        """ receive the next batch of input tensors and move them to the specified device """
        if not self.batch_receiver.poll(timeout):
            raise TimeoutError()

        batch_index, batch_inputs = self.batch_receiver.recv()
        batch_inputs = [tensor.to(device, non_blocking=True) for tensor in batch_inputs]
        return batch_index, batch_inputs

    def send_outputs_from_runtime(self, batch_index: int, batch_outputs: List[torch.Tensor]):
        """ send results for a processed batch, previously loaded through load_batch_to_runtime """
        batch_outputs = [tensor.to(device='cpu').share_memory_().detach().requires_grad_(tensor.requires_grad)
                         for tensor in batch_outputs]
        self.outputs_sender.send((batch_index, batch_outputs))

    def get_task_size(self, task: Task) -> int:
        """ compute task processing complexity (used for batching); defaults to batch size """
        return len(task.args[0]) if task.args else 1
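

# Illustrative sketch (not part of hivemind): a minimal runtime-side loop that consumes batches
# from a started TaskPool, applies its process_func, and ships the results back. The function name
# and its arguments are hypothetical; hivemind's actual Runtime serves several pools at once and
# schedules them by the pools' priority values.
def _example_runtime_loop(pool: TaskPool, device=None):
    while True:
        try:
            # wait until the pool has aggregated a batch, then move its tensors to `device`
            batch_index, batch_inputs = pool.load_batch_to_runtime(timeout=1.0, device=device)
        except TimeoutError:
            continue  # no batch was ready within the timeout; poll again
        # apply the user-supplied function to the whole batch at once
        batch_outputs = pool.process_func(*batch_inputs)
        # return the outputs; the pool's output loop splits them per task and resolves the futures
        pool.send_outputs_from_runtime(batch_index, list(batch_outputs))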