justheuristic 2 years ago
parent
commit fb0aa13054
1 changed file with 0 additions and 45 deletions

src/server/backend.py  +0 -45

@@ -36,22 +36,10 @@ class PrioritizedTaskPool(TaskPool):
 
         self.priority_queue = mp.Queue(maxsize=self.tasks._maxsize)
         self.prioritized_task_queue = PriorityQueue(maxsize=self.tasks._maxsize)
-        self.undispatched_task_priorities = mp.SimpleQueue()
-        self._timestamp = mp.Value(ctypes.c_double, 1.0)
-
-    @property
-    def priority(self):
-        return (-self._priority.value, -self._timestamp.value)
-
-    @priority.setter
-    def priority(self, priority_tuple: Sequence[float]):
-        assert len(priority_tuple) == 2, "pool priority must be a tuple of (priority, time_submitted)"
-        self._priority.value, self._timestamp.value = map(float, priority_tuple)
 
     def submit_task(self, *args: torch.Tensor, priority: float = 0.0) -> Future:
         f = super().submit_task(*args)
         self.priority_queue.put(priority)
-        self.undispatched_task_priorities.put(priority)
         # TODO use a single queue here
         return f
 
@@ -84,39 +72,6 @@ class PrioritizedTaskPool(TaskPool):
             output_thread.join()
             priority_thread.join()
 
-    def _pool_input_loop(self, pending_batches: Dict[Any, List[Task]], *args, **kwargs):
-        """Infinite loop: aggregate tasks into batches and send them to runtime"""
-
-        prev_num_tasks = 0  # number of tasks currently in shared buffer
-        batch_index = max(pending_batches.keys(), default=0)
-        batch_iterator = self.iterate_minibatches(*args, **kwargs)
-
-        while True:
-            # SIDE-EFFECT - compute pool priority from timestamp of earliest undispatched task
-            # assumes that tasks are processed in the same order as they are created
-            for skip_i in range(prev_num_tasks):
-                dispatched_task_timestamp = self.undispatched_task_timestamps.get()
-                dispatched_task_priority = self.undispatched_task_priorities.get()
-                if skip_i == prev_num_tasks - 1:
-                    self.priority = (dispatched_task_priority, dispatched_task_timestamp)
-
-            logger.debug(f"{self.name} getting next batch")
-            batch_tasks = next(batch_iterator)
-            # save batch futures, _output_loop will deliver on them later
-            pending_batches[batch_index] = batch_tasks
-
-            logger.debug(f"{self.name}, batch  {batch_index}: aggregating inputs")
-            # find or create shared arrays for current batch size
-            batch_inputs = [torch.cat([task.args[i] for task in batch_tasks]) for i in range(len(batch_tasks[0].args))]
-            batch_inputs = [inp.detach().requires_grad_(inp.requires_grad).share_memory_() for inp in batch_inputs]
-
-            logger.debug(f"{self.name}, batch {batch_index}: sending to runtime")
-            self.batch_sender.send((batch_index, batch_inputs))
-            logger.debug(f"{self.name}, batch {batch_index}: sent to runtime")
-            prev_num_tasks = len(batch_tasks)
-            batch_index += 1
-
-
     # TODO: this is a copy-paste of the original method, except that we use different queue
     def iterate_minibatches(self, *args, **kwargs):
         """Form minibatches by grouping one or more tasks together up to self.max_batch_size"""