Browse Source

Change runtime.py to choose tasks with lowest (instead of highest) priority (#505)

Currently, the priority is set to the timestamp of the earliest undispatched task.
Choosing earliest tasks will reduce the maximum waiting time when queue is nonempty

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
Co-authored-by: Pavel Samygin <44449246+greenfatguy@users.noreply.github.com>
(cherry picked from commit 6395e89f9882419908ef54cdb1148d3da50f65fd)
justheuristic 2 năm trước cách đây
mục cha
commit
2ba000b328
2 tập tin đã thay đổi với 4 bổ sung4 xóa
  1. 2 2
      hivemind/moe/server/runtime.py
  2. 2 2
      hivemind/moe/server/task_pool.py

+ 2 - 2
hivemind/moe/server/runtime.py

@@ -143,8 +143,8 @@ class Runtime(threading.Thread):
                 if self.SHUTDOWN_TRIGGER in ready_objects:
                     break  # someone asked us to shutdown, break from the loop
 
-                logger.debug("Choosing the pool with highest priority")
-                pool = max(ready_objects, key=lambda pool: pool.priority)
+                logger.debug("Choosing the pool with first priority")
+                pool = min(ready_objects, key=lambda pool: pool.priority)
 
                 logger.debug(f"Loading batch from {pool.name}")
                 batch_index, batch_tensors = pool.load_batch_to_runtime(timeout, self.device)

+ 2 - 2
hivemind/moe/server/task_pool.py

@@ -27,7 +27,7 @@ class TaskPoolBase(mp.context.ForkProcess, metaclass=ABCMeta):
     def __init__(self, process_func: callable, daemon=True, **kwargs):
         super().__init__(daemon=daemon, **kwargs)
         self.process_func = process_func
-        self._priority = mp.Value(ctypes.c_double, 1.0)  # higher priority = the more urgent to process this pool
+        self._priority = mp.Value(ctypes.c_double, 1.0)  # lower priority = the more urgent to process this pool
 
     @abstractmethod
     def run(self):
@@ -170,7 +170,7 @@ class TaskPool(TaskPoolBase):
             for skip_i in range(prev_num_tasks):
                 finished_task_timestamp = (
                     self.undispatched_task_timestamps.get()
-                )  # earlier timestamp = higher priority
+                )  # earlier timestamp = smaller (better) priority, earlier processing
                 if skip_i == prev_num_tasks - 1:
                     self.priority = finished_task_timestamp