|
@@ -10,7 +10,7 @@ import uuid
|
|
|
from collections import namedtuple
|
|
|
from concurrent.futures import Future
|
|
|
from queue import Empty
|
|
|
-from typing import List, Tuple, Dict, Any
|
|
|
+from typing import List, Tuple, Dict, Any, Generator
|
|
|
|
|
|
import torch
|
|
|
|
|
@@ -33,13 +33,9 @@ class TaskPoolBase(mp.Process):
|
|
|
def submit_task(self, *args: torch.Tensor) -> Future:
|
|
|
raise NotImplementedError()
|
|
|
|
|
|
- def form_batch(self, *args, **kwargs) -> List[Task]:
|
|
|
+ def iterate_minibatches(self, *args, **kwargs) -> Generator[List[Task], None, None]:
|
|
|
raise NotImplementedError()
|
|
|
|
|
|
- def iterate_minibatches(self, *args, **kwargs):
|
|
|
- while True:
|
|
|
- yield self.form_batch(*args, **kwargs)
|
|
|
-
|
|
|
@property
|
|
|
def priority(self):
|
|
|
return self._priority.value
|
|
@@ -92,31 +88,42 @@ class TaskPool(TaskPoolBase):
|
|
|
def submit_task(self, *args: torch.Tensor) -> Future:
|
|
|
""" Add task to this pool's queue, return Future for its output """
|
|
|
future1, future2 = SharedFuture.make_pair()
|
|
|
- self.tasks.put(Task(future1, args))
|
|
|
- self.undispatched_task_timestamps.put(time.time())
|
|
|
+ task = Task(future1, args)
|
|
|
+ if self.get_task_size(task) > self.max_batch_size:
|
|
|
+ exc = ValueError(f"Task size greater than max_batch_size ({self.max_batch_size}), it will never be finished")
|
|
|
+ future2.set_exception(exc)
|
|
|
+ else:
|
|
|
+ self.tasks.put(task)
|
|
|
+ self.undispatched_task_timestamps.put(time.time())
|
|
|
return future2
|
|
|
|
|
|
- def form_batch(self) -> List[Task]:
|
|
|
- batch_tasks = []
|
|
|
+ def iterate_minibatches(self, *args, **kwargs):
|
|
|
+ batch = []
|
|
|
total_size = 0
|
|
|
|
|
|
- while total_size < self.max_batch_size:
|
|
|
+ while True:
|
|
|
if total_size >= self.min_batch_size and self.tasks.empty():
|
|
|
- break # timeout reached, returning incomplete batch
|
|
|
-
|
|
|
+ yield batch
|
|
|
+ batch = []
|
|
|
+ total_size = 0
|
|
|
try:
|
|
|
task = self.tasks.get(timeout=self.timeout)
|
|
|
except Empty:
|
|
|
exc = TimeoutError(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet.")
|
|
|
- for task in batch_tasks:
|
|
|
+ for task in batch:
|
|
|
task.future.set_exception(exc)
|
|
|
raise exc
|
|
|
|
|
|
- if task.future.set_running_or_notify_cancel():
|
|
|
- batch_tasks.append(task)
|
|
|
- total_size += self.get_task_size(task)
|
|
|
+ task_size = self.get_task_size(task)
|
|
|
|
|
|
- return batch_tasks
|
|
|
+ if total_size + task_size > self.max_batch_size:
|
|
|
+ yield batch
|
|
|
+ batch = []
|
|
|
+ total_size = 0
|
|
|
+
|
|
|
+ if task.future.set_running_or_notify_cancel():
|
|
|
+ batch.append(task)
|
|
|
+ total_size += task_size
|
|
|
|
|
|
def run(self, *args, **kwargs):
|
|
|
print(f'Starting pool, pid={os.getpid()}')
|