|
@@ -1,11 +1,12 @@
|
|
|
"""Code for serving bloom blocks via hivemind-server"""
|
|
|
+import ctypes
|
|
|
import multiprocessing as mp
|
|
|
import os
|
|
|
import threading
|
|
|
from concurrent.futures import Future
|
|
|
from dataclasses import dataclass, field
|
|
|
from queue import Empty, PriorityQueue
|
|
|
-from typing import Optional, Sequence, Tuple
|
|
|
+from typing import Optional, Sequence, Tuple, Dict, Any, List
|
|
|
|
|
|
import torch
|
|
|
from hivemind import use_hivemind_log_handler
|
|
@@ -34,10 +35,23 @@ class PrioritizedTaskPool(TaskPool):
|
|
|
|
|
|
self.priority_queue = mp.Queue(maxsize=self.tasks._maxsize)
|
|
|
self.prioritized_task_queue = PriorityQueue(maxsize=self.tasks._maxsize)
|
|
|
+ self.undispatched_task_priorities = mp.SimpleQueue()
|
|
|
+ self._timestamp = mp.Value(ctypes.c_double, 1.0)
|
|
|
+
|
|
|
@property
def priority(self):
    """Return this pool's priority as the pair ``(-priority, -timestamp)``.

    Both components are read from the shared ``_priority`` and ``_timestamp``
    values and negated on the way out; the setter stores them un-negated.
    """
    negated_priority = -self._priority.value
    negated_timestamp = -self._timestamp.value
    return (negated_priority, negated_timestamp)


@priority.setter
def priority(self, priority_tuple: Sequence[float]):
    """Store a ``(priority, time_submitted)`` pair into the shared values.

    :param priority_tuple: exactly two numbers; each is converted with ``float``
    :raises AssertionError: if ``priority_tuple`` does not have exactly two items
    """
    assert len(priority_tuple) == 2, "pool priority must be a tuple of (priority, time_submitted)"
    # convert both components first, so a bad value leaves the stored state untouched
    new_priority, new_timestamp = (float(component) for component in priority_tuple)
    self._priority.value = new_priority
    self._timestamp.value = new_timestamp
|
|
|
|
|
|
def submit_task(self, *args: torch.Tensor, priority: float = 0.0) -> Future:
    """Submit a task through the parent pool and record its priority.

    :param args: input tensors, forwarded unchanged to the parent ``submit_task``
    :param priority: value pushed to both ``priority_queue`` and
        ``undispatched_task_priorities`` for this task
    :returns: the future returned by the parent ``submit_task``
    """
    task_future = super().submit_task(*args)
    # the priority is recorded twice: once per queue, in this order
    self.priority_queue.put(priority)
    self.undispatched_task_priorities.put(priority)
    # TODO use a single queue here
    return task_future
|
|
|
|
|
|
def _priortize_tasks(self):
|
|
@@ -69,20 +83,59 @@ class PrioritizedTaskPool(TaskPool):
|
|
|
output_thread.join()
|
|
|
priority_thread.join()
|
|
|
|
|
|
def _pool_input_loop(self, pending_batches: Dict[Any, List[Task]], *args, **kwargs):
    """Run forever: pull batches of tasks, merge their inputs and send them to runtime.

    :param pending_batches: mapping from batch index to the tasks of that batch;
        each new batch is stored here before its inputs are sent
    :param args: forwarded to ``self.iterate_minibatches``
    :param kwargs: forwarded to ``self.iterate_minibatches``

    NOTE(review): ``undispatched_task_timestamps`` and ``batch_sender`` are not
    defined in the visible part of this class; presumably inherited - confirm.
    """
    num_sent_previously = 0  # how many tasks the previous iteration sent out
    batch_index = max(pending_batches.keys(), default=0)
    minibatches = self.iterate_minibatches(*args, **kwargs)

    while True:
        # SIDE-EFFECT: drop the bookkeeping entries of the tasks sent last time and
        # set the pool priority from the last entry removed.
        # Relies on tasks being processed in the same order as they were created.
        last_timestamp = last_priority = None
        for _ in range(num_sent_previously):
            last_timestamp = self.undispatched_task_timestamps.get()
            last_priority = self.undispatched_task_priorities.get()
        if num_sent_previously > 0:
            self.priority = (last_priority, last_timestamp)

        logger.debug(f"{self.name} getting next batch")
        current_tasks = next(minibatches)
        # remember the batch under its index before its inputs are sent
        pending_batches[batch_index] = current_tasks

        logger.debug(f"{self.name}, batch {batch_index}: aggregating inputs")
        # concatenate the i-th argument of every task; the first task fixes the arity
        num_arguments = len(current_tasks[0].args)
        merged_inputs = [
            torch.cat([task.args[arg_index] for task in current_tasks]) for arg_index in range(num_arguments)
        ]
        # detach each tensor, keep its requires_grad flag, then move it to shared memory
        shared_inputs = [
            tensor.detach().requires_grad_(tensor.requires_grad).share_memory_() for tensor in merged_inputs
        ]

        logger.debug(f"{self.name}, batch {batch_index}: sending to runtime")
        self.batch_sender.send((batch_index, shared_inputs))
        logger.debug(f"{self.name}, batch {batch_index}: sent to runtime")

        num_sent_previously = len(current_tasks)
        batch_index += 1
|
|
|
+
|
|
|
+
|
|
|
# TODO: this is a copy-paste of the original method, except that we use a different queue
|
|
|
def iterate_minibatches(self, *args, **kwargs):
    """Yield one-task batches taken from ``self.prioritized_task_queue``, forever.

    Each yielded value is a list holding the inner task of a single queue entry
    (``[task.task]``), so no grouping into larger batches happens here.
    ``*args`` and ``**kwargs`` are accepted for signature compatibility and unused.

    A queue timeout logs a warning and retries. An entry whose future reports it
    should not run is skipped, and ``InvalidStateError`` from the future is logged
    at debug level and skipped.
    """
    while True:
        try:
            logger.debug(f"{self.name} getting next task")
            task: PrioritizedTask = self.prioritized_task_queue.get(timeout=self.timeout)
        except Empty:
            # nothing arrived within self.timeout; keep waiting
            logger.warning(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet")
            continue

        try:
            # only hand out tasks whose future could be moved to the running state
            if task.task.future.set_running_or_notify_cancel():
                yield [task.task]
        except InvalidStateError as e:
            logger.debug(f"Failed to add task to batch: {task.task.future} raised {e}")
|
|
|
|