Browse Source

default to fifo

justheuristic 3 năm trước
mục cha
commit
64a2a24911
1 tập tin đã thay đổi với 55 bổ sung và 2 xóa
  1. 55 2
      src/server/backend.py

+ 55 - 2
src/server/backend.py

@@ -1,11 +1,12 @@
 """Code for serving bloom blocks via hivemind-server"""
+import ctypes
 import multiprocessing as mp
 import os
 import threading
 from concurrent.futures import Future
 from dataclasses import dataclass, field
 from queue import Empty, PriorityQueue
-from typing import Optional, Sequence, Tuple
+from typing import Optional, Sequence, Tuple, Dict, Any, List
 
 import torch
 from hivemind import use_hivemind_log_handler
@@ -34,10 +35,23 @@ class PrioritizedTaskPool(TaskPool):
 
         self.priority_queue = mp.Queue(maxsize=self.tasks._maxsize)
         self.prioritized_task_queue = PriorityQueue(maxsize=self.tasks._maxsize)
+        self.undispatched_task_priorities = mp.SimpleQueue()
+        self._timestamp = mp.Value(ctypes.c_double, 1.0)
+
+    @property
+    def priority(self):
+        return (-self._priority.value, -self._timestamp.value)
+
+    @priority.setter
+    def priority(self, priority_tuple: Sequence[float]):
+        assert len(priority_tuple) == 2, "pool priority must be a tuple of (priority, time_submitted)"
+        self._priority.value, self._timestamp.value = map(float, priority_tuple)
 
     def submit_task(self, *args: torch.Tensor, priority: float = 0.0) -> Future:
         f = super().submit_task(*args)
         self.priority_queue.put(priority)
+        self.undispatched_task_priorities.put(priority)
+        # TODO use a single queue here
         return f
 
     def _priortize_tasks(self):
@@ -69,20 +83,59 @@ class PrioritizedTaskPool(TaskPool):
             output_thread.join()
             priority_thread.join()
 
+    def _pool_input_loop(self, pending_batches: Dict[Any, List[Task]], *args, **kwargs):
+        """Infinite loop: aggregate tasks into batches and send them to runtime"""
+
+        prev_num_tasks = 0  # number of tasks currently in shared buffer
+        batch_index = max(pending_batches.keys(), default=0)
+        batch_iterator = self.iterate_minibatches(*args, **kwargs)
+
+        while True:
+            # SIDE-EFFECT - compute pool priority from timestamp of earliest undispatched task
+            # assumes that tasks are processed in the same order as they are created
+            for skip_i in range(prev_num_tasks):
+                dispatched_task_timestamp = self.undispatched_task_timestamps.get()
+                dispatched_task_priority = self.undispatched_task_priorities.get()
+                if skip_i == prev_num_tasks - 1:
+                    self.priority = (dispatched_task_priority, dispatched_task_timestamp)
+
+            logger.debug(f"{self.name} getting next batch")
+            batch_tasks = next(batch_iterator)
+            # save batch futures, _output_loop will deliver on them later
+            pending_batches[batch_index] = batch_tasks
+
+            logger.debug(f"{self.name}, batch  {batch_index}: aggregating inputs")
+            # find or create shared arrays for current batch size
+            batch_inputs = [torch.cat([task.args[i] for task in batch_tasks]) for i in range(len(batch_tasks[0].args))]
+            batch_inputs = [inp.detach().requires_grad_(inp.requires_grad).share_memory_() for inp in batch_inputs]
+
+            logger.debug(f"{self.name}, batch {batch_index}: sending to runtime")
+            self.batch_sender.send((batch_index, batch_inputs))
+            logger.debug(f"{self.name}, batch {batch_index}: sent to runtime")
+            prev_num_tasks = len(batch_tasks)
+            batch_index += 1
+
+
     # TODO: this is a copy-paste of the original method, except that we use different queue
     def iterate_minibatches(self, *args, **kwargs):
         """Form minibatches by grouping one or more tasks together up to self.max_batch_size"""
+        print('IN iterate_minibatches')
         while True:
             try:
                 logger.debug(f"{self.name} getting next task")
                 task: PrioritizedTask = self.prioritized_task_queue.get(timeout=self.timeout)
+                print('IN iterate_minibatches - 1')
             except Empty:
                 logger.warning(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet")
+                print('IN iterate_minibatches - 2')
                 continue
 
+            print('IN iterate_minibatches - 3')
             try:
                 if task.task.future.set_running_or_notify_cancel():
-                    yield [task]
+                    print('IN iterate_minibatches - 4')
+                    yield [task.task]
+                    print('IN iterate_minibatches - 5')
             except InvalidStateError as e:
                 logger.debug(f"Failed to add task to batch: {task.task.future} raised {e}")