|
@@ -9,7 +9,6 @@ from typing import Any, List, Optional, Sequence, Tuple, Union
|
|
|
|
|
|
import torch
|
|
|
from hivemind import get_logger
|
|
|
-from hivemind.moe.server.task_pool import TaskPoolBase
|
|
|
from hivemind.utils.mpfuture import ALL_STATES, MPFuture
|
|
|
|
|
|
logger = get_logger(__name__)
|
|
@@ -27,7 +26,7 @@ class Task:
|
|
|
return self.future._uid
|
|
|
|
|
|
|
|
|
-class PrioritizedTaskPool(TaskPoolBase):
|
|
|
+class PrioritizedTaskPool(threading.Thread):
|
|
|
"""
|
|
|
Aggregates requests from multiple ConnectionHandler instances, orders them for processing in Runtime, then
|
|
|
returns results (or exception) to the corresponding ConnectionHandler. Runs a background process.
|
|
@@ -57,52 +56,41 @@ class PrioritizedTaskPool(TaskPoolBase):
|
|
|
daemon=True,
|
|
|
start=False,
|
|
|
):
|
|
|
- super().__init__(process_func, daemon=daemon, name=name)
|
|
|
+ super().__init__(daemon=daemon, name=name)
|
|
|
+ self.process_func = process_func
|
|
|
+ # the lower the priority is, the more urgent it is to process this pool
|
|
|
+ self._priority = mp.Value(ctypes.c_double, 1.0)
|
|
|
+
|
|
|
self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
|
|
|
self.device = device
|
|
|
|
|
|
self.submitted_tasks = mp.SimpleQueue() # interaction with ConnectionHandlers
|
|
|
self._ordered_tasks = PriorityQueue() # interaction with Runtime - only valid inside Runtime
|
|
|
|
|
|
- self._prioritizer_thread = threading.Thread(
|
|
|
- name=self.name + "_prioritizer",
|
|
|
- target=self._prioritize_tasks,
|
|
|
- args=[self.submitted_tasks, self._ordered_tasks],
|
|
|
- daemon=True,
|
|
|
- )
|
|
|
self._dispatched_tasks = {}
|
|
|
self.batch_receiver, self.batch_sender = mp.Pipe(duplex=False)
|
|
|
self._oldest_undispatched_timestamp = mp.Value(ctypes.c_double, 1.0)
|
|
|
self.priority = float("inf"), float("inf") # (first task priority, first task timestamp)
|
|
|
|
|
|
- self._stop = mp.Event()
|
|
|
if start:
|
|
|
self.start()
|
|
|
|
|
|
- @staticmethod
|
|
|
- def _prioritize_tasks(submitted_tasks: mp.SimpleQueue, ordered_tasks: PriorityQueue):
|
|
|
+ def run(self):
|
|
|
"""Read tasks from incoming queue and put them into a local priority queue"""
|
|
|
while True:
|
|
|
- task = submitted_tasks.get()
|
|
|
+ task = self.submitted_tasks.get()
|
|
|
if task is None:
|
|
|
logger.debug("Shutting down prioritizer thread")
|
|
|
break
|
|
|
|
|
|
- ordered_tasks.put(task, block=True)
|
|
|
-
|
|
|
- def start(self):
|
|
|
- assert not self.is_alive() and not self._prioritizer_thread.is_alive()
|
|
|
- self._prioritizer_thread.start()
|
|
|
- super().start()
|
|
|
+ self._ordered_tasks.put(task, block=True)
|
|
|
|
|
|
- def shutdown(self, timeout: float = 3):
|
|
|
- self.submitted_tasks.put(None) # Shuts down self._prioritizer_thread
|
|
|
- self._stop.set()
|
|
|
+ def terminate(self):
|
|
|
+ """An alias for hivemind.Runtime that assumes that each TaskPool is a process"""
|
|
|
+ self.shutdown()
|
|
|
|
|
|
- self.join(timeout)
|
|
|
- if self.is_alive():
|
|
|
- logger.warning(f"{self.__class__.__name__} failed to shut down gracefully, sending SIGTERM")
|
|
|
- self.terminate()
|
|
|
+ def shutdown(self):
|
|
|
+ self.submitted_tasks.put(None) # Shuts down self.run()
|
|
|
|
|
|
def submit_task(self, *args: Any, priority: float = 0.0) -> MPFuture:
|
|
|
"""Add task to this pool's queue, return Future for its output"""
|
|
@@ -163,9 +151,6 @@ class PrioritizedTaskPool(TaskPoolBase):
|
|
|
else:
|
|
|
task.future.set_exception(exception)
|
|
|
|
|
|
- def run(self, *args, **kwargs):
|
|
|
- self._stop.wait()
|
|
|
-
|
|
|
@property
|
|
|
def empty(self):
|
|
|
return not self.batch_receiver.poll()
|