Browse source code

Get rid of form_batch

Max Ryabinin 5 years ago
parent
commit
9f73e00241
1 changed file with 11 additions and 19 deletions
hivemind/runtime/task_pool.py  +11 -19

+ 11 - 19
hivemind/runtime/task_pool.py

@@ -10,7 +10,7 @@ import uuid
 from collections import namedtuple
 from concurrent.futures import Future
 from queue import Empty
-from typing import List, Tuple, Dict, Any
+from typing import List, Tuple, Dict, Any, Generator
 
 import torch
 
@@ -33,13 +33,9 @@ class TaskPoolBase(mp.Process):
     def submit_task(self, *args: torch.Tensor) -> Future:
         raise NotImplementedError()
 
-    def form_batch(self, *args, **kwargs) -> List[Task]:
+    def iterate_minibatches(self, *args, **kwargs) -> Generator[List[Task], None, None]:
         raise NotImplementedError()
 
-    def iterate_minibatches(self, *args, **kwargs):
-        while True:
-            yield self.form_batch(*args, **kwargs)
-
     @property
     def priority(self):
         return self._priority.value
@@ -101,34 +97,30 @@ class TaskPool(TaskPoolBase):
             self.undispatched_task_timestamps.put(time.time())
         return future2
 
-    def form_batch(self) -> List[Task]:
-        batch_tasks = []
+    def iterate_minibatches(self, *args, **kwargs):
+        batch = []
         total_size = 0
 
-        while total_size < self.max_batch_size:
-            if total_size >= self.min_batch_size and self.tasks.empty():
-                break  # timeout reached, returning incomplete batch
-
+        while True:
             try:
                 task = self.tasks.get(timeout=self.timeout)
             except Empty:
                 exc = TimeoutError(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet.")
-                for task in batch_tasks:
+                for task in batch:
                     task.future.set_exception(exc)
                 raise exc
 
             task_size = self.get_task_size(task)
 
-            if total_size + task_size > self.max_batch_size:  # adding now will exceed max_batch_size, put it back
-                self.tasks.put(task)
-                self.undispatched_task_timestamps.put(self.undispatched_task_timestamps.get())
+            if total_size + task_size > self.max_batch_size or total_size >= self.min_batch_size and self.tasks.empty():
+                yield batch
+                batch = []
+                total_size = 0
 
             if task.future.set_running_or_notify_cancel():
-                batch_tasks.append(task)
+                batch.append(task)
                 total_size += task_size
 
-        return batch_tasks
-
     def run(self, *args, **kwargs):
         print(f'Starting pool, pid={os.getpid()}')
         pending_batches = {}  # Dict[batch uuid, List[SharedFuture]] for each batch currently in runtime
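
Below is a minimal usage sketch (not part of this commit) of how a runtime loop might consume the new generator-based iterate_minibatches; the process_batch callable and the Task.args field are illustrative assumptions, not confirmed by this diff.

# Hypothetical consumer of TaskPool.iterate_minibatches() after this change.
# The generator yields List[Task] batches indefinitely and raises TimeoutError
# (after setting that exception on the pending futures) if no task arrives
# within pool.timeout.
def drain_pool(pool, process_batch):
    try:
        for batch in pool.iterate_minibatches():
            if not batch:
                continue  # defensively skip a batch yielded with no accepted tasks
            # process_batch is an assumed batched handler returning one output per task
            outputs = process_batch([task.args for task in batch])
            for task, output in zip(batch, outputs):
                task.future.set_result(output)
    except TimeoutError:
        pass  # iterate_minibatches already failed the futures it had collected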