
default to fifo

justheuristic committed 3 years ago
commit 64a2a24911

1 changed file with 55 additions and 2 deletions

src/server/backend.py  +55 −2

@@ -1,11 +1,12 @@
 """Code for serving bloom blocks via hivemind-server"""
+import ctypes
 import multiprocessing as mp
 import os
 import threading
 from concurrent.futures import Future
 from dataclasses import dataclass, field
 from queue import Empty, PriorityQueue
-from typing import Optional, Sequence, Tuple
+from typing import Optional, Sequence, Tuple, Dict, Any, List
 
 import torch
 from hivemind import use_hivemind_log_handler
@@ -34,10 +35,23 @@ class PrioritizedTaskPool(TaskPool):
 
         self.priority_queue = mp.Queue(maxsize=self.tasks._maxsize)
         self.prioritized_task_queue = PriorityQueue(maxsize=self.tasks._maxsize)
+        self.undispatched_task_priorities = mp.SimpleQueue()
+        self._timestamp = mp.Value(ctypes.c_double, 1.0)
+
+    @property
+    def priority(self):
+        return (-self._priority.value, -self._timestamp.value)
+
+    @priority.setter
+    def priority(self, priority_tuple: Sequence[float]):
+        assert len(priority_tuple) == 2, "pool priority must be a tuple of (priority, time_submitted)"
+        self._priority.value, self._timestamp.value = map(float, priority_tuple)
 
     def submit_task(self, *args: torch.Tensor, priority: float = 0.0) -> Future:
         f = super().submit_task(*args)
         self.priority_queue.put(priority)
+        self.undispatched_task_priorities.put(priority)
+        # TODO use a single queue here
         return f
 
     def _priortize_tasks(self):
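
Note on the hunk above: the `priority` getter negates both shared values before returning them, so that ordinary tuple comparison ranks a numerically lower submitted priority (and, on ties, an earlier timestamp) as the "largest" pool priority. A minimal standalone sketch of this pattern, using a hypothetical SharedPriority class rather than the pool itself:

import ctypes
import multiprocessing as mp

class SharedPriority:
    """Shares a (priority, timestamp) pair across processes via mp.Value slots."""

    def __init__(self):
        self._priority = mp.Value(ctypes.c_double, 0.0)
        self._timestamp = mp.Value(ctypes.c_double, 1.0)

    @property
    def priority(self):
        # negate so lower submitted values compare as "higher" pool priority
        return (-self._priority.value, -self._timestamp.value)

    @priority.setter
    def priority(self, priority_tuple):
        assert len(priority_tuple) == 2, "expected (priority, time_submitted)"
        self._priority.value, self._timestamp.value = map(float, priority_tuple)

urgent, backlog = SharedPriority(), SharedPriority()
urgent.priority = (0.0, 100.0)   # priority 0, submitted at t=100
backlog.priority = (1.0, 50.0)   # priority 1, submitted earlier at t=50
assert urgent.priority > backlog.priority  # lower priority value wins despite the later timestamp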
@@ -69,20 +83,59 @@ class PrioritizedTaskPool(TaskPool):
             output_thread.join()
             priority_thread.join()
 
+    def _pool_input_loop(self, pending_batches: Dict[Any, List[Task]], *args, **kwargs):
+        """Infinite loop: aggregate tasks into batches and send them to runtime"""
+
+        prev_num_tasks = 0  # number of tasks currently in shared buffer
+        batch_index = max(pending_batches.keys(), default=0)
+        batch_iterator = self.iterate_minibatches(*args, **kwargs)
+
+        while True:
+            # SIDE-EFFECT - compute pool priority from timestamp of earliest undispatched task
+            # assumes that tasks are processed in the same order as they are created
+            for skip_i in range(prev_num_tasks):
+                dispatched_task_timestamp = self.undispatched_task_timestamps.get()
+                dispatched_task_priority = self.undispatched_task_priorities.get()
+                if skip_i == prev_num_tasks - 1:
+                    self.priority = (dispatched_task_priority, dispatched_task_timestamp)
+
+            logger.debug(f"{self.name} getting next batch")
+            batch_tasks = next(batch_iterator)
+            # save batch futures, _output_loop will deliver on them later
+            pending_batches[batch_index] = batch_tasks
+
+            logger.debug(f"{self.name}, batch {batch_index}: aggregating inputs")
+            # find or create shared arrays for current batch size
+            batch_inputs = [torch.cat([task.args[i] for task in batch_tasks]) for i in range(len(batch_tasks[0].args))]
+            batch_inputs = [inp.detach().requires_grad_(inp.requires_grad).share_memory_() for inp in batch_inputs]
+
+            logger.debug(f"{self.name}, batch {batch_index}: sending to runtime")
+            self.batch_sender.send((batch_index, batch_inputs))
+            logger.debug(f"{self.name}, batch {batch_index}: sent to runtime")
+            prev_num_tasks = len(batch_tasks)
+            batch_index += 1
+
+
     # TODO: this is a copy-paste of the original method, except that we use different queue
     def iterate_minibatches(self, *args, **kwargs):
         """Form minibatches by grouping one or more tasks together up to self.max_batch_size"""
+        print('IN iterate_minibatches')
         while True:
             try:
                 logger.debug(f"{self.name} getting next task")
                 task: PrioritizedTask = self.prioritized_task_queue.get(timeout=self.timeout)
+                print('IN iterate_minibatches - 1')
             except Empty:
                 logger.warning(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet")
+                print('IN iterate_minibatches - 2')
                 continue
 
+            print('IN iterate_minibatches - 3')
             try:
                 if task.task.future.set_running_or_notify_cancel():
-                    yield [task]
+                    print('IN iterate_minibatches - 4')
+                    yield [task.task]
+                    print('IN iterate_minibatches - 5')
             except InvalidStateError as e:
                 logger.debug(f"Failed to add task to batch: {task.task.future} raised {e}")