Merge pull request #26 from learning-at-home/form_batch_fix

Enforce max_batch_size
justheuristic 5 years ago
parent
commit
b092c4c00c
1 changed file with 25 additions and 18 deletions
hivemind/runtime/task_pool.py  +25 −18

@@ -10,7 +10,7 @@ import uuid
 from collections import namedtuple
 from concurrent.futures import Future
 from queue import Empty
-from typing import List, Tuple, Dict, Any
+from typing import List, Tuple, Dict, Any, Generator
 
 import torch
 
@@ -33,13 +33,9 @@ class TaskPoolBase(mp.Process):
     def submit_task(self, *args: torch.Tensor) -> Future:
         raise NotImplementedError()
 
-    def form_batch(self, *args, **kwargs) -> List[Task]:
+    def iterate_minibatches(self, *args, **kwargs) -> Generator[List[Task], None, None]:
         raise NotImplementedError()
 
-    def iterate_minibatches(self, *args, **kwargs):
-        while True:
-            yield self.form_batch(*args, **kwargs)
-
     @property
     def priority(self):
         return self._priority.value
@@ -92,31 +88,42 @@ class TaskPool(TaskPoolBase):
     def submit_task(self, *args: torch.Tensor) -> Future:
         """ Add task to this pool's queue, return Future for its output """
         future1, future2 = SharedFuture.make_pair()
-        self.tasks.put(Task(future1, args))
-        self.undispatched_task_timestamps.put(time.time())
+        task = Task(future1, args)
+        if self.get_task_size(task) > self.max_batch_size:
+            exc = ValueError(f"Task size greater than max_batch_size ({self.max_batch_size}), it will never be finished")
+            future2.set_exception(exc)
+        else:
+            self.tasks.put(task)
+            self.undispatched_task_timestamps.put(time.time())
         return future2
 
-    def form_batch(self) -> List[Task]:
-        batch_tasks = []
+    def iterate_minibatches(self, *args, **kwargs):
+        batch = []
         total_size = 0
 
-        while total_size < self.max_batch_size:
+        while True:
             if total_size >= self.min_batch_size and self.tasks.empty():
-                break  # timeout reached, returning incomplete batch
-
+                yield batch
+                batch = []
+                total_size = 0
             try:
                 task = self.tasks.get(timeout=self.timeout)
             except Empty:
                 exc = TimeoutError(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet.")
-                for task in batch_tasks:
+                for task in batch:
                     task.future.set_exception(exc)
                 raise exc
 
-            if task.future.set_running_or_notify_cancel():
-                batch_tasks.append(task)
-                total_size += self.get_task_size(task)
+            task_size = self.get_task_size(task)
 
-        return batch_tasks
+            if total_size + task_size > self.max_batch_size:
+                yield batch
+                batch = []
+                total_size = 0
+
+            if task.future.set_running_or_notify_cancel():
+                batch.append(task)
+                total_size += task_size
 
     def run(self, *args, **kwargs):
         print(f'Starting pool, pid={os.getpid()}')
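
For context, below is a minimal, self-contained sketch of the batching policy this commit enforces. It uses a plain queue.Queue and a toy Task namedtuple instead of hivemind's SharedFuture and torch tensors; the task names, sizes, and thresholds are invented for illustration, and the timeout path simply stops iterating instead of raising TimeoutError on every pending future.

import queue
from collections import namedtuple

# Toy stand-in for hivemind's Task (which carries a SharedFuture and tensors).
Task = namedtuple('Task', ['name', 'size'])

def iterate_minibatches(tasks, min_batch_size=1, max_batch_size=4, timeout=0.1):
    """Yield batches whose total size never exceeds max_batch_size."""
    batch, total_size = [], 0
    while True:
        # Flush early once the minimum is met and no more tasks are waiting.
        if total_size >= min_batch_size and tasks.empty():
            yield batch
            batch, total_size = [], 0
        try:
            task = tasks.get(timeout=timeout)
        except queue.Empty:
            if batch:
                yield batch
            return  # simplified: stop instead of raising TimeoutError
        # Flush the current batch if adding this task would overflow it.
        if total_size + task.size > max_batch_size:
            yield batch
            batch, total_size = [], 0
        batch.append(task)
        total_size += task.size

tasks = queue.Queue()
for i, size in enumerate([2, 2, 1, 3, 4, 7]):
    if size > 4:
        # mirrors the new submit_task check: oversized tasks never enter the queue
        print(f"task{i} rejected: size {size} > max_batch_size")
        continue
    tasks.put(Task(f"task{i}", size))

for batch in iterate_minibatches(tasks):
    print([t.name for t in batch], 'total size:', sum(t.size for t in batch))

Unlike the removed form_batch, which only stopped once total_size had already reached max_batch_size (and could therefore overshoot it), the new iterate_minibatches checks the would-be total before appending, so no yielded batch exceeds the limit; tasks that could never fit are rejected up front in submit_task with a ValueError set on their future.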