|
@@ -195,21 +195,35 @@ class TaskPool(TaskPoolBase):
|
|
|
|
|
|
while True:
|
|
|
logger.debug(f"{self.name} waiting for results from runtime")
|
|
|
- batch_index, batch_outputs = self.outputs_receiver.recv()
|
|
|
- logger.debug(f"{self.name}, batch {batch_index}: got results")
|
|
|
-
|
|
|
- # split batch into partitions for individual tasks
|
|
|
+ batch_index, batch_outputs_or_exception = self.outputs_receiver.recv()
|
|
|
batch_tasks = pending_batches.pop(batch_index)
|
|
|
- task_sizes = [self.get_task_size(task) for task in batch_tasks]
|
|
|
- outputs_per_task = zip(*(torch.split_with_sizes(tensor, task_sizes, dim=0) for tensor in batch_outputs))
|
|
|
- logger.debug(f"{self.name}, batch {batch_index}: sending outputs to handlers")
|
|
|
|
|
|
- # dispatch results to futures
|
|
|
- for task, task_outputs in zip(batch_tasks, outputs_per_task):
|
|
|
- try:
|
|
|
- task.future.set_result(tuple(task_outputs))
|
|
|
- except InvalidStateError as e:
|
|
|
- logger.debug(f"Failed to send task result due to an exception: {e}")
|
|
|
+ if isinstance(batch_outputs_or_exception, BaseException):
|
|
|
+ logger.debug(f"{self.name}, batch {batch_index}: got exception, propagating to handlers")
|
|
|
+ exception = batch_outputs_or_exception
|
|
|
+ for task in batch_tasks:
|
|
|
+ try:
|
|
|
+ task.future.set_exception(exception)
|
|
|
+ except InvalidStateError as e:
|
|
|
+ logger.debug(f"Failed to send runtime error to a task: {e}")
|
|
|
+
|
|
|
+ else:
|
|
|
+ logger.debug(f"{self.name}, batch {batch_index}: got results")
|
|
|
+ batch_outputs = batch_outputs_or_exception
|
|
|
+
|
|
|
+ # split batch into partitions for individual tasks
|
|
|
+ task_sizes = [self.get_task_size(task) for task in batch_tasks]
|
|
|
+ outputs_per_task = zip(
|
|
|
+ *(torch.split_with_sizes(tensor, task_sizes, dim=0) for tensor in batch_outputs)
|
|
|
+ )
|
|
|
+ logger.debug(f"{self.name}, batch {batch_index}: sending outputs to handlers")
|
|
|
+
|
|
|
+ # dispatch results to futures
|
|
|
+ for task, task_outputs in zip(batch_tasks, outputs_per_task):
|
|
|
+ try:
|
|
|
+ task.future.set_result(tuple(task_outputs))
|
|
|
+ except InvalidStateError as e:
|
|
|
+ logger.debug(f"Failed to send task result due to an exception: {e}")
|
|
|
|
|
|
@property
|
|
|
def empty(self):
|
|
@@ -232,6 +246,9 @@ class TaskPool(TaskPoolBase):
|
|
|
]
|
|
|
self.outputs_sender.send((batch_index, batch_outputs))
|
|
|
|
|
|
def send_exception_from_runtime(self, batch_index: int, exception: BaseException):
    """Send an exception, paired with its batch index, through ``self.outputs_sender``.

    :param batch_index: index of the batch the exception belongs to
    :param exception: the exception object to send in place of batch outputs
    """
    message = (batch_index, exception)
    self.outputs_sender.send(message)
|
|
|
+
|
|
|
def get_task_size(self, task: Task) -> int:
    """Compute task processing complexity (used for batching).

    :param task: the task to measure; only ``task.args`` is inspected
    :returns: the length of the task's first argument, or 1 if the task has no arguments
    """
    if not task.args:
        # no arguments to measure: count the task as a single unit
        return 1
    first_argument = task.args[0]
    return len(first_argument)
|