|
@@ -92,8 +92,13 @@ class TaskPool(TaskPoolBase):
|
|
|
def submit_task(self, *args: torch.Tensor) -> Future:
|
|
|
""" Add task to this pool's queue, return Future for its output """
|
|
|
future1, future2 = SharedFuture.make_pair()
|
|
|
- self.tasks.put(Task(future1, args))
|
|
|
- self.undispatched_task_timestamps.put(time.time())
|
|
|
+ task = Task(future1, args)
|
|
|
+ if self.get_task_size(task) > self.max_batch_size:
|
|
|
+ exc = ValueError("Task size greater than max_batch_size, it will never be finished")
|
|
|
+ future2.set_exception(exc)
|
|
|
+ else:
|
|
|
+ self.tasks.put(task)
|
|
|
+ self.undispatched_task_timestamps.put(time.time())
|
|
|
return future2
|
|
|
|
|
|
def form_batch(self) -> List[Task]:
|
|
@@ -112,9 +117,15 @@ class TaskPool(TaskPoolBase):
|
|
|
task.future.set_exception(exc)
|
|
|
raise exc
|
|
|
|
|
|
+ task_size = self.get_task_size(task)
|
|
|
+
|
|
|
+ if total_size + task_size > self.max_batch_size: # adding now will exceed max_batch_size, put it back
|
|
|
+ self.tasks.put(task)
|
|
|
+ self.undispatched_task_timestamps.put(self.undispatched_task_timestamps.get())
|
|
|
+
|
|
|
if task.future.set_running_or_notify_cancel():
|
|
|
batch_tasks.append(task)
|
|
|
- total_size += self.get_task_size(task)
|
|
|
+ total_size += task_size
|
|
|
|
|
|
return batch_tasks
|
|
|
|