|
@@ -15,30 +15,6 @@ use_hivemind_log_handler("in_root_logger")
|
|
|
logger = get_logger(__file__)
|
|
|
|
|
|
|
|
|
-class InferenceTaskPool(TaskPool):
|
|
|
- def __init__(self, *args, **kwargs):
|
|
|
- super().__init__(*args, **kwargs)
|
|
|
-
|
|
|
- assert self.min_batch_size == 1, "min_batch_size in InferenceTaskPool cannot be greater 1"
|
|
|
-
|
|
|
- def iterate_minibatches(self, *args, **kwargs):
|
|
|
- """Form minibatches by grouping one or more tasks together up to self.max_batch_size"""
|
|
|
-
|
|
|
- while True:
|
|
|
- try:
|
|
|
- logger.debug(f"{self.name} getting next task")
|
|
|
- task = self.tasks.get(timeout=self.timeout)
|
|
|
- except Empty:
|
|
|
- logger.warning(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet")
|
|
|
- continue
|
|
|
-
|
|
|
- try:
|
|
|
- if task.future.set_running_or_notify_cancel():
|
|
|
- yield [task]
|
|
|
- except InvalidStateError as e:
|
|
|
- logger.debug(f"Failed to add task to batch: {task.future} raised {e}")
|
|
|
-
|
|
|
-
|
|
|
class InferenceTaskPool(TaskPool):
|
|
|
def __init__(self, *args, **kwargs):
|
|
|
super().__init__(*args, **kwargs)
|