
Remove no-op process in PrioritizedTaskPool (#484)

Please revert this if you ever need to make `PrioritizedTaskPool` a process again.
Alexander Borzunov, 2 years ago
Commit 459933f846

1 file changed, 14 insertions(+), 29 deletions(-):
  1. src/petals/server/task_pool.py

src/petals/server/task_pool.py  (+14, -29)

@@ -9,7 +9,6 @@ from typing import Any, List, Optional, Sequence, Tuple, Union
 
 import torch
 from hivemind import get_logger
-from hivemind.moe.server.task_pool import TaskPoolBase
 from hivemind.utils.mpfuture import ALL_STATES, MPFuture
 
 logger = get_logger(__name__)
@@ -27,7 +26,7 @@ class Task:
         return self.future._uid
 
 
-class PrioritizedTaskPool(TaskPoolBase):
+class PrioritizedTaskPool(threading.Thread):
     """
     Aggregates requests from multiple ConnectionHandler instances, orders them for processing in Runtime, then
     returns results (or exception) to the corresponding ConnectionHandler. Runs a background process.
@@ -57,52 +56,41 @@ class PrioritizedTaskPool(TaskPoolBase):
         daemon=True,
         start=False,
     ):
-        super().__init__(process_func, daemon=daemon, name=name)
+        super().__init__(daemon=daemon, name=name)
+        self.process_func = process_func
+        # the lower the priority is, the more urgent it is to process this pool
+        self._priority = mp.Value(ctypes.c_double, 1.0)
+
         self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
         self.device = device
 
         self.submitted_tasks = mp.SimpleQueue()  # interaction with ConnectionHandlers
         self._ordered_tasks = PriorityQueue()  # interaction with Runtime - only valid inside Runtime
 
-        self._prioritizer_thread = threading.Thread(
-            name=self.name + "_prioritizer",
-            target=self._prioritize_tasks,
-            args=[self.submitted_tasks, self._ordered_tasks],
-            daemon=True,
-        )
         self._dispatched_tasks = {}
         self.batch_receiver, self.batch_sender = mp.Pipe(duplex=False)
         self._oldest_undispatched_timestamp = mp.Value(ctypes.c_double, 1.0)
         self.priority = float("inf"), float("inf")  # (first task priority, first task timestamp)
 
-        self._stop = mp.Event()
         if start:
             self.start()
 
-    @staticmethod
-    def _prioritize_tasks(submitted_tasks: mp.SimpleQueue, ordered_tasks: PriorityQueue):
+    def run(self):
         """Read tasks from incoming queue and put them into a local priority queue"""
         while True:
-            task = submitted_tasks.get()
+            task = self.submitted_tasks.get()
             if task is None:
                 logger.debug("Shutting down prioritizer thread")
                 break
 
-            ordered_tasks.put(task, block=True)
-
-    def start(self):
-        assert not self.is_alive() and not self._prioritizer_thread.is_alive()
-        self._prioritizer_thread.start()
-        super().start()
+            self._ordered_tasks.put(task, block=True)
 
-    def shutdown(self, timeout: float = 3):
-        self.submitted_tasks.put(None)  # Shuts down self._prioritizer_thread
-        self._stop.set()
+    def terminate(self):
+        """An alias for hivemind.Runtime that assumes that each TaskPool is a process"""
+        self.shutdown()
 
-        self.join(timeout)
-        if self.is_alive():
-            logger.warning(f"{self.__class__.__name__} failed to shut down gracefully, sending SIGTERM")
-            self.terminate()
+    def shutdown(self):
+        self.submitted_tasks.put(None)  # Shuts down self.run()
 
     def submit_task(self, *args: Any, priority: float = 0.0) -> MPFuture:
         """Add task to this pool's queue, return Future for its output"""
@@ -163,9 +151,6 @@ class PrioritizedTaskPool(TaskPoolBase):
         else:
             task.future.set_exception(exception)
 
-    def run(self, *args, **kwargs):
-        self._stop.wait()
-
     @property
     def empty(self):
         return not self.batch_receiver.poll()
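
For context, a minimal usage sketch of the pool after this change (not part of the commit). It assumes only the parameters visible in the diff above (process_func, max_batch_size, name, start, and submit_task's priority argument); double_batch is a hypothetical stand-in for the function a hivemind Runtime would apply to each batch.

    import torch

    from petals.server.task_pool import PrioritizedTaskPool


    def double_batch(batch: torch.Tensor) -> torch.Tensor:
        """Hypothetical processing function; a real server would run a transformer block here."""
        return batch * 2


    # After this commit the pool is a daemon thread rather than a child process,
    # so start() spawns a thread instead of forking.
    pool = PrioritizedTaskPool(double_batch, max_batch_size=8, name="demo_pool", start=True)

    # Lower priority values are more urgent; submit_task() returns an MPFuture for the output.
    future = pool.submit_task(torch.ones(1, 4), priority=0.0)

    # shutdown() enqueues None as a poison pill, which makes run() exit; terminate() is kept
    # only as an alias for hivemind.Runtime, which assumes each TaskPool is a process.
    pool.shutdown()

Note that in a real server the Runtime consumes and processes the ordered batches; without one, the future in this sketch would never resolve.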