Ver código fonte

Enforce max_batch_size

Max Ryabinin 5 anos atrás
pai
commit
aa5377aa5a
1 arquivos alterados com 14 adições e 3 exclusões
  1. 14 3
      hivemind/runtime/task_pool.py

+ 14 - 3
hivemind/runtime/task_pool.py

@@ -92,8 +92,13 @@ class TaskPool(TaskPoolBase):
     def submit_task(self, *args: torch.Tensor) -> Future:
         """ Add task to this pool's queue, return Future for its output """
         future1, future2 = SharedFuture.make_pair()
-        self.tasks.put(Task(future1, args))
-        self.undispatched_task_timestamps.put(time.time())
+        task = Task(future1, args)
+        if self.get_task_size(task) > self.max_batch_size:
+            exc = ValueError("Task size greater than max_batch_size, it will never be finished")
+            future2.set_exception(exc)
+        else:
+            self.tasks.put(task)
+            self.undispatched_task_timestamps.put(time.time())
         return future2
 
     def form_batch(self) -> List[Task]:
@@ -112,9 +117,15 @@ class TaskPool(TaskPoolBase):
                     task.future.set_exception(exc)
                 raise exc
 
+            task_size = self.get_task_size(task)
+
+            if total_size + task_size > self.max_batch_size:  # adding now will exceed max_batch_size, put it back
+                self.tasks.put(task)
+                self.undispatched_task_timestamps.put(self.undispatched_task_timestamps.get())
+
             if task.future.set_running_or_notify_cancel():
                 batch_tasks.append(task)
-                total_size += self.get_task_size(task)
+                total_size += task_size
 
         return batch_tasks