Ver Fonte

Remove no-op process in PrioritizedTaskPool (#484)

Please revert this if you ever need to make `PrioritizedTaskPool` a process again.
Alexander Borzunov há 2 anos atrás
pai
commit
459933f846
1 ficheiros alterados com 14 adições e 29 exclusões
  1. 14 29
      src/petals/server/task_pool.py

+ 14 - 29
src/petals/server/task_pool.py

@@ -9,7 +9,6 @@ from typing import Any, List, Optional, Sequence, Tuple, Union
 
 import torch
 from hivemind import get_logger
-from hivemind.moe.server.task_pool import TaskPoolBase
 from hivemind.utils.mpfuture import ALL_STATES, MPFuture
 
 logger = get_logger(__name__)
@@ -27,7 +26,7 @@ class Task:
         return self.future._uid
 
 
-class PrioritizedTaskPool(TaskPoolBase):
+class PrioritizedTaskPool(threading.Thread):
     """
     Aggregates requests from multiple ConnectionHandler instances, orders them for processing in Runtime, then
     returns results (or exception) to the corresponding ConnectionHandler. Runs a background process.
@@ -57,52 +56,41 @@ class PrioritizedTaskPool(TaskPoolBase):
         daemon=True,
         start=False,
     ):
-        super().__init__(process_func, daemon=daemon, name=name)
+        super().__init__(daemon=daemon, name=name)
+        self.process_func = process_func
+        # the lower the priority is, the more urgent it is to process this pool
+        self._priority = mp.Value(ctypes.c_double, 1.0)
+
         self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
         self.device = device
 
         self.submitted_tasks = mp.SimpleQueue()  # interaction with ConnectionHandlers
         self._ordered_tasks = PriorityQueue()  # interaction with Runtime - only valid inside Runtime
 
-        self._prioritizer_thread = threading.Thread(
-            name=self.name + "_prioritizer",
-            target=self._prioritize_tasks,
-            args=[self.submitted_tasks, self._ordered_tasks],
-            daemon=True,
-        )
         self._dispatched_tasks = {}
         self.batch_receiver, self.batch_sender = mp.Pipe(duplex=False)
         self._oldest_undispatched_timestamp = mp.Value(ctypes.c_double, 1.0)
         self.priority = float("inf"), float("inf")  # (first task priority, first task timestamp)
 
-        self._stop = mp.Event()
         if start:
             self.start()
 
-    @staticmethod
-    def _prioritize_tasks(submitted_tasks: mp.SimpleQueue, ordered_tasks: PriorityQueue):
+    def run(self):
         """Read tasks from incoming queue and put them into a local priority queue"""
         while True:
-            task = submitted_tasks.get()
+            task = self.submitted_tasks.get()
             if task is None:
                 logger.debug("Shutting down prioritizer thread")
                 break
 
-            ordered_tasks.put(task, block=True)
-
-    def start(self):
-        assert not self.is_alive() and not self._prioritizer_thread.is_alive()
-        self._prioritizer_thread.start()
-        super().start()
+            self._ordered_tasks.put(task, block=True)
 
-    def shutdown(self, timeout: float = 3):
-        self.submitted_tasks.put(None)  # Shuts down self._prioritizer_thread
-        self._stop.set()
+    def terminate(self):
+        """An alias for hivemind.Runtime that assumes that each TaskPool is a process"""
+        self.shutdown()
 
-        self.join(timeout)
-        if self.is_alive():
-            logger.warning(f"{self.__class__.__name__} failed to shut down gracefully, sending SIGTERM")
-            self.terminate()
+    def shutdown(self):
+        self.submitted_tasks.put(None)  # Shuts down self.run()
 
     def submit_task(self, *args: Any, priority: float = 0.0) -> MPFuture:
         """Add task to this pool's queue, return Future for its output"""
@@ -163,9 +151,6 @@ class PrioritizedTaskPool(TaskPoolBase):
         else:
             task.future.set_exception(exception)
 
-    def run(self, *args, **kwargs):
-        self._stop.wait()
-
     @property
     def empty(self):
         return not self.batch_receiver.poll()