@@ -15,7 +15,6 @@ from hivemind.utils import InvalidStateError, get_logger
 from src.bloom.from_pretrained import BloomBlock
 from src.server.cache import MemoryCache
-from src.server.task_broker import DustBrokerBase, SimpleBroker
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
@@ -23,28 +22,30 @@ logger = get_logger(__file__)
 
 @dataclass(order=True)
 class PrioritizedTask:
-    value: int
+    priority: float
     task: Task = field(compare=False)
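
A note on the ordering semantics introduced here: with `order=True`, dataclass comparisons use only the fields that participate in comparison, and `field(compare=False)` excludes `task`, so `PrioritizedTask` objects are ordered by `priority` alone. `queue.PriorityQueue` pops the smallest element first, so a lower `priority` value means earlier service. A minimal standalone sketch (with `Task` stubbed as a string purely for illustration):

```python
from dataclasses import dataclass, field
from queue import PriorityQueue
from typing import Any

@dataclass(order=True)
class PrioritizedTask:
    priority: float
    task: Any = field(compare=False)  # excluded from comparison, so the queue never compares tasks

q = PriorityQueue()
q.put(PrioritizedTask(priority=2.0, task="background request"))
q.put(PrioritizedTask(priority=0.5, task="urgent request"))
assert q.get().task == "urgent request"  # the smallest priority value is popped first
```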
 
 
 class PrioritizedTaskPool(TaskPool):
-    def __init__(self, *args, broker: DustBrokerBase = SimpleBroker(), **kwargs):
+    def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.broker = broker
-        self.dust_queue = mp.Queue(maxsize=self.tasks.maxsize)
-        self.priority_queue = PriorityQueue(maxsize=self.tasks.maxsize)
 
-    def submit_task(self, *args: torch.Tensor, dust: float = 0.0) -> Future:
+        assert self.min_batch_size == 1, "PrioritizedTaskPool does not support batching"
+
+        self.priority_queue = mp.Queue(maxsize=self.tasks._maxsize)
+        self.prioritized_task_queue = PriorityQueue(maxsize=self.tasks._maxsize)
+
+    def submit_task(self, *args: torch.Tensor, priority: float = 0.0) -> Future:
         f = super().submit_task(*args)
-        self.dust_queue.put(dust)
+        self.priority_queue.put(priority)
         return f
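
`submit_task` fans one submission out to two multiprocessing queues: the tensors go into the inherited `self.tasks` queue via the parent class, while the bare float goes into `self.priority_queue`. The `_priortize_tasks` loop below pops one item from each, so the pairing stays correct only because both queues preserve FIFO order and are written and drained in lockstep. A hypothetical caller-side sketch (construction follows hivemind's `TaskPool(process_func, max_batch_size, name)` signature; the tensor shape and names are illustrative):

```python
# Hypothetical usage: lower priority values are served sooner.
pool = PrioritizedTaskPool(block.forward, max_batch_size=1, name="block0_forward")
pool.start()  # TaskPool is a multiprocessing.Process; its loop must be running

future = pool.submit_task(torch.randn(1, 2048), priority=0.1)
hidden_states = future.result()
```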
 
     def _priortize_tasks(self):
         """Infinite loop prioritizing incoming tasks"""
         while True:
             task = self.tasks.get(block=True)
-            dust = self.dust_queue.get(block=True)
-            self.priority_queue.put(PrioritizedTask(-self.broker(task, dust), task), block=True)
+            priority = self.priority_queue.get(block=True)
+            self.prioritized_task_queue.put(PrioritizedTask(priority, task), block=True)
 
     def run(self, *args, **kwargs):
         torch.set_num_threads(1)
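
The rest of `run` is elided from this hunk. For readers following along, the expected shape of the method is to start the prioritizer loop on a background thread and then fall through to the parent class's processing loop; a sketch under that assumption:

```python
# Assumed structure of run(); only its first line is visible in the diff.
import threading

def run(self, *args, **kwargs):
    torch.set_num_threads(1)  # keep each pool process to a single intra-op thread
    # Continuously move (task, priority) pairs into prioritized_task_queue.
    threading.Thread(target=self._priortize_tasks, daemon=True).start()
    super().run(*args, **kwargs)
```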
 
@@ -71,58 +72,19 @@ class PrioritizedTaskPool(TaskPool):
     # TODO: this is a copy-paste of the original method, except that we use a different queue
     def iterate_minibatches(self, *args, **kwargs):
         """Form minibatches by grouping one or more tasks together up to self.max_batch_size"""
-        batch = []
-        total_size = 0
-
-        while True:
-            if total_size >= self.min_batch_size and self.priority_queue.empty():
-                yield batch
-                batch = []
-                total_size = 0
-            try:
-                logger.debug(f"{self.name} getting next task")
-                task = self.priority_queue.get(timeout=self.timeout)
-            except Empty:
-                logger.warning(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet")
-                continue
-
-            task_size = self.get_task_size(task)
-
-            if total_size + task_size > self.max_batch_size:
-                yield batch
-                batch = []
-                total_size = 0
-
-            try:
-                if task.future.set_running_or_notify_cancel():
-                    batch.append(task)
-                    total_size += task_size
-            except InvalidStateError as e:
-                logger.debug(f"Failed to add task to batch: {task.future} raised {e}")
-
-
-class InferenceTaskPool(TaskPool):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-        assert self.min_batch_size == 1, "min_batch_size in InferenceTaskPool cannot be greater than 1"
-
-    def iterate_minibatches(self, *args, **kwargs):
-        """Form minibatches by grouping one or more tasks together up to self.max_batch_size"""
-
         while True:
             try:
                 logger.debug(f"{self.name} getting next task")
-                task = self.tasks.get(timeout=self.timeout)
+                task: PrioritizedTask = self.prioritized_task_queue.get(timeout=self.timeout)
             except Empty:
                 logger.warning(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet")
                 continue
 
             try:
-                if task.future.set_running_or_notify_cancel():
-                    yield [task]
+                if task.task.future.set_running_or_notify_cancel():
+                    yield [task.task]
             except InvalidStateError as e:
-                logger.debug(f"Failed to add task to batch: {task.future} raised {e}")
+                logger.debug(f"Failed to add task to batch: {task.task.future} raised {e}")
 
 
 class TransformerBackend(ModuleBackend):
@@ -137,9 +99,11 @@ class TransformerBackend(ModuleBackend):
|
|
|
for name, buf in self.module.named_buffers():
|
|
|
assert not buf.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
|
|
|
|
|
|
- self.inference_pool = InferenceTaskPool(
|
|
|
+ self.inference_pool = PrioritizedTaskPool(
|
|
|
self.inference_step, max_batch_size=self.forward_pool.max_batch_size, name=f"{self.name}_inference"
|
|
|
)
|
|
|
+ self.forward_pool = PrioritizedTaskPool(self.forward, name=f"{self.name}_forward", **kwargs)
|
|
|
+ self.backward_pool = PrioritizedTaskPool(self.backward, name=f"{self.name}_backward", **kwargs)
|
|
|
self.dtype = backend_dtype if backend_dtype else self.module.input_layernorm.weight.dtype
|
|
|
|
|
|
def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
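
With all three pools prioritized, a server-side handler can rank work across request types, for example letting interactive inference jump ahead of forward and backward passes from training clients. The scheme below is purely illustrative; the diff itself fixes no convention beyond lower value wins:

```python
# Hypothetical priority scheme: lower value = served sooner.
INFERENCE, FORWARD, BACKWARD = 0.0, 1.0, 2.0

backend.inference_pool.submit_task(cache_metadata, hidden_states, priority=INFERENCE)
backend.forward_pool.submit_task(hidden_states, priority=FORWARD)
backend.backward_pool.submit_task(hidden_states, grad_outputs, priority=BACKWARD)
```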