|
@@ -10,7 +10,7 @@ import uuid
|
|
from collections import namedtuple
|
|
from collections import namedtuple
|
|
from concurrent.futures import Future
|
|
from concurrent.futures import Future
|
|
from queue import Empty
|
|
from queue import Empty
|
|
-from typing import List, Tuple, Dict, Any
|
|
|
|
|
|
+from typing import List, Tuple, Dict, Any, Generator
|
|
|
|
|
|
import torch
|
|
import torch
|
|
|
|
|
|
@@ -33,13 +33,9 @@ class TaskPoolBase(mp.Process):
|
|
def submit_task(self, *args: torch.Tensor) -> Future:
|
|
def submit_task(self, *args: torch.Tensor) -> Future:
|
|
raise NotImplementedError()
|
|
raise NotImplementedError()
|
|
|
|
|
|
- def form_batch(self, *args, **kwargs) -> List[Task]:
|
|
|
|
|
|
+ def iterate_minibatches(self, *args, **kwargs) -> Generator[List[Task], None, None]:
|
|
raise NotImplementedError()
|
|
raise NotImplementedError()
|
|
|
|
|
|
- def iterate_minibatches(self, *args, **kwargs):
|
|
|
|
- while True:
|
|
|
|
- yield self.form_batch(*args, **kwargs)
|
|
|
|
-
|
|
|
|
@property
|
|
@property
|
|
def priority(self):
|
|
def priority(self):
|
|
return self._priority.value
|
|
return self._priority.value
|
|
@@ -101,34 +97,30 @@ class TaskPool(TaskPoolBase):
|
|
self.undispatched_task_timestamps.put(time.time())
|
|
self.undispatched_task_timestamps.put(time.time())
|
|
return future2
|
|
return future2
|
|
|
|
|
|
- def form_batch(self) -> List[Task]:
|
|
|
|
- batch_tasks = []
|
|
|
|
|
|
+ def iterate_minibatches(self, *args, **kwargs):
|
|
|
|
+ batch = []
|
|
total_size = 0
|
|
total_size = 0
|
|
|
|
|
|
- while total_size < self.max_batch_size:
|
|
|
|
- if total_size >= self.min_batch_size and self.tasks.empty():
|
|
|
|
- break # timeout reached, returning incomplete batch
|
|
|
|
-
|
|
|
|
|
|
+ while True:
|
|
try:
|
|
try:
|
|
task = self.tasks.get(timeout=self.timeout)
|
|
task = self.tasks.get(timeout=self.timeout)
|
|
except Empty:
|
|
except Empty:
|
|
exc = TimeoutError(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet.")
|
|
exc = TimeoutError(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet.")
|
|
- for task in batch_tasks:
|
|
|
|
|
|
+ for task in batch:
|
|
task.future.set_exception(exc)
|
|
task.future.set_exception(exc)
|
|
raise exc
|
|
raise exc
|
|
|
|
|
|
task_size = self.get_task_size(task)
|
|
task_size = self.get_task_size(task)
|
|
|
|
|
|
- if total_size + task_size > self.max_batch_size: # adding now will exceed max_batch_size, put it back
|
|
|
|
- self.tasks.put(task)
|
|
|
|
- self.undispatched_task_timestamps.put(self.undispatched_task_timestamps.get())
|
|
|
|
|
|
+ if total_size + task_size > self.max_batch_size or total_size >= self.min_batch_size and self.tasks.empty():
|
|
|
|
+ yield batch
|
|
|
|
+ batch = []
|
|
|
|
+ total_size = 0
|
|
|
|
|
|
if task.future.set_running_or_notify_cancel():
|
|
if task.future.set_running_or_notify_cancel():
|
|
- batch_tasks.append(task)
|
|
|
|
|
|
+ batch.append(task)
|
|
total_size += task_size
|
|
total_size += task_size
|
|
|
|
|
|
- return batch_tasks
|
|
|
|
-
|
|
|
|
def run(self, *args, **kwargs):
|
|
def run(self, *args, **kwargs):
|
|
print(f'Starting pool, pid={os.getpid()}')
|
|
print(f'Starting pool, pid={os.getpid()}')
|
|
pending_batches = {} # Dict[batch uuid, List[SharedFuture]] for each batch currently in runtime
|
|
pending_batches = {} # Dict[batch uuid, List[SharedFuture]] for each batch currently in runtime
|