
Enforce max_batch_size

Max Ryabinin 5 years ago
parent
commit
aa5377aa5a
1 changed file with 14 additions and 3 deletions
  1. +14 −3 hivemind/runtime/task_pool.py

+14 −3
hivemind/runtime/task_pool.py

@@ -92,8 +92,13 @@ class TaskPool(TaskPoolBase):
     def submit_task(self, *args: torch.Tensor) -> Future:
         """ Add task to this pool's queue, return Future for its output """
         future1, future2 = SharedFuture.make_pair()
-        self.tasks.put(Task(future1, args))
-        self.undispatched_task_timestamps.put(time.time())
+        task = Task(future1, args)
+        if self.get_task_size(task) > self.max_batch_size:
+            exc = ValueError("Task size greater than max_batch_size, it will never be finished")
+            future2.set_exception(exc)
+        else:
+            self.tasks.put(task)
+            self.undispatched_task_timestamps.put(time.time())
         return future2

     def form_batch(self) -> List[Task]:
@@ -112,9 +117,15 @@ class TaskPool(TaskPoolBase):
                     task.future.set_exception(exc)
                 raise exc

+            task_size = self.get_task_size(task)
+
+            if total_size + task_size > self.max_batch_size:  # adding now will exceed max_batch_size, put it back
+                self.tasks.put(task)
+                self.undispatched_task_timestamps.put(self.undispatched_task_timestamps.get())
+
             if task.future.set_running_or_notify_cancel():
                 batch_tasks.append(task)
-                total_size += self.get_task_size(task)
+                total_size += task_size

         return batch_tasks
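
A rough standalone sketch of the behavior this commit enforces (assumed names, not hivemind's actual TaskPool: the Task dataclass, queues, and size function are simplified stand-ins, and the break statements are added here so the toy loop terminates cleanly; the commit itself only re-enqueues the deferred task). An oversized task gets its future failed at submission time, while form_batch puts back a task that would push the current batch past max_batch_size:

# Standalone sketch of the two checks above; not hivemind's real API.
import time
from concurrent.futures import Future
from dataclasses import dataclass
from queue import Empty, Queue
from typing import List, Tuple


@dataclass
class Task:
    future: Future
    args: Tuple[int, ...]  # toy convention: task size = number of args


class ToyTaskPool:
    def __init__(self, max_batch_size: int, min_batch_size: int = 1):
        self.max_batch_size = max_batch_size
        self.min_batch_size = min_batch_size
        self.tasks: Queue = Queue()
        self.undispatched_task_timestamps: Queue = Queue()

    def get_task_size(self, task: Task) -> int:
        return len(task.args)

    def submit_task(self, *args: int) -> Future:
        """Fail the future immediately if the task can never fit in one batch."""
        future = Future()
        task = Task(future, args)
        if self.get_task_size(task) > self.max_batch_size:
            future.set_exception(
                ValueError("Task size greater than max_batch_size, it will never be finished"))
        else:
            self.tasks.put(task)
            self.undispatched_task_timestamps.put(time.time())
        return future

    def form_batch(self) -> List[Task]:
        """Collect tasks until min_batch_size is reached, never exceeding max_batch_size."""
        batch_tasks, total_size = [], 0
        while total_size < self.min_batch_size:
            try:
                task = self.tasks.get(timeout=0.1)
            except Empty:
                break  # toy simplification: give up instead of raising
            task_size = self.get_task_size(task)
            if total_size + task_size > self.max_batch_size:
                # adding now would exceed max_batch_size: put it back for the next batch
                self.tasks.put(task)
                self.undispatched_task_timestamps.put(self.undispatched_task_timestamps.get())
                break  # stop here so the deferred task is not also batched
            if task.future.set_running_or_notify_cancel():
                batch_tasks.append(task)
                total_size += task_size
        return batch_tasks


pool = ToyTaskPool(max_batch_size=4)
rejected = pool.submit_task(1, 2, 3, 4, 5)    # size 5 > 4: fails at submission
print(type(rejected.exception()).__name__)    # -> ValueError
accepted = pool.submit_task(1, 2, 3)          # size 3 <= 4: queued normally
print(len(pool.form_batch()))                 # -> 1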