@@ -36,22 +36,10 @@ class PrioritizedTaskPool(TaskPool):
 
         self.priority_queue = mp.Queue(maxsize=self.tasks._maxsize)
         self.prioritized_task_queue = PriorityQueue(maxsize=self.tasks._maxsize)
-        self.undispatched_task_priorities = mp.SimpleQueue()
-        self._timestamp = mp.Value(ctypes.c_double, 1.0)
-
-    @property
-    def priority(self):
-        return (-self._priority.value, -self._timestamp.value)
-
-    @priority.setter
-    def priority(self, priority_tuple: Sequence[float]):
-        assert len(priority_tuple) == 2, "pool priority must be a tuple of (priority, time_submitted)"
-        self._priority.value, self._timestamp.value = map(float, priority_tuple)
 
     def submit_task(self, *args: torch.Tensor, priority: float = 0.0) -> Future:
         f = super().submit_task(*args)
        self.priority_queue.put(priority)
-        self.undispatched_task_priorities.put(priority)
         # TODO use a single queue here
         return f
 
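Side note (not part of the patch): one reading of the "TODO use a single queue here" comment above is to fold the per-task priority and the task itself into a single PriorityQueue entry, instead of keeping a separate queue of priorities next to the task queue. A minimal sketch of that idea follows; the names (SingleQueueSketch, submit, next_task) are illustrative stand-ins, not the project's API.

import itertools
from queue import PriorityQueue

class SingleQueueSketch:
    """One queue carries (priority, seq, payload); no separate priority queue is needed."""

    def __init__(self, maxsize: int = 0):
        self._queue = PriorityQueue(maxsize=maxsize)
        self._seq = itertools.count()  # monotonic tiebreaker, keeps FIFO order within a priority

    def submit(self, payload, priority: float = 0.0) -> None:
        # Tuples compare element-wise: a lower priority value is dispatched first,
        # and the seq number breaks ties before the (non-comparable) payload is reached.
        self._queue.put((priority, next(self._seq), payload))

    def next_task(self):
        priority, _, payload = self._queue.get()
        return priority, payload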
@@ -84,39 +72,6 @@ class PrioritizedTaskPool(TaskPool):
         output_thread.join()
         priority_thread.join()
 
-    def _pool_input_loop(self, pending_batches: Dict[Any, List[Task]], *args, **kwargs):
-        """Infinite loop: aggregate tasks into batches and send them to runtime"""
-
-        prev_num_tasks = 0  # number of tasks currently in shared buffer
-        batch_index = max(pending_batches.keys(), default=0)
-        batch_iterator = self.iterate_minibatches(*args, **kwargs)
-
-        while True:
-            # SIDE-EFFECT - compute pool priority from timestamp of earliest undispatched task
-            # assumes that tasks are processed in the same order as they are created
-            for skip_i in range(prev_num_tasks):
-                dispatched_task_timestamp = self.undispatched_task_timestamps.get()
-                dispatched_task_priority = self.undispatched_task_priorities.get()
-                if skip_i == prev_num_tasks - 1:
-                    self.priority = (dispatched_task_priority, dispatched_task_timestamp)
-
-            logger.debug(f"{self.name} getting next batch")
-            batch_tasks = next(batch_iterator)
-            # save batch futures, _output_loop will deliver on them later
-            pending_batches[batch_index] = batch_tasks
-
-            logger.debug(f"{self.name}, batch {batch_index}: aggregating inputs")
-            # find or create shared arrays for current batch size
-            batch_inputs = [torch.cat([task.args[i] for task in batch_tasks]) for i in range(len(batch_tasks[0].args))]
-            batch_inputs = [inp.detach().requires_grad_(inp.requires_grad).share_memory_() for inp in batch_inputs]
-
-            logger.debug(f"{self.name}, batch {batch_index}: sending to runtime")
-            self.batch_sender.send((batch_index, batch_inputs))
-            logger.debug(f"{self.name}, batch {batch_index}: sent to runtime")
-            prev_num_tasks = len(batch_tasks)
-            batch_index += 1
-
-
     # TODO: this is a copy-paste of the original method, except that we use different queue
     def iterate_minibatches(self, *args, **kwargs):
         """Form minibatches by grouping one or more tasks together up to self.max_batch_size"""
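Side note (not part of the patch): iterate_minibatches above is described as grouping one or more tasks together up to self.max_batch_size, drawing them from the queue that replaces the removed plumbing. The sketch below shows one way such grouping can look when entries arrive as (priority, seq, task) tuples in a queue.PriorityQueue; the entry layout and the Task stand-in are assumptions of this sketch, not the project's data structures.

import queue
from dataclasses import dataclass
from typing import Any, List

@dataclass
class Task:
    args: Any
    size: int = 1  # how many samples this task contributes to a batch

def form_minibatch(task_queue: "queue.PriorityQueue", max_batch_size: int) -> List[Task]:
    _, _, task = task_queue.get()        # block until at least one task is available
    batch, total_size = [task], task.size
    while total_size < max_batch_size:   # greedily add more ready tasks while they fit
        try:
            _, _, task = task_queue.get_nowait()
        except queue.Empty:
            break
        batch.append(task)
        total_size += task.size
    return batch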