浏览代码

Temporary fix: check that batch has min_batch_size elements before getting

Max Ryabinin 5 年之前
父节点
当前提交
3fa788e54b
共有 1 个文件被更改,包括 5 次插入1 次删除
  1. 5 1
      hivemind/runtime/task_pool.py

+ 5 - 1
hivemind/runtime/task_pool.py

@@ -102,6 +102,10 @@ class TaskPool(TaskPoolBase):
         total_size = 0
 
         while True:
+            if total_size >= self.min_batch_size and self.tasks.empty():
+                yield batch
+                batch = []
+                total_size = 0
             try:
                 task = self.tasks.get(timeout=self.timeout)
             except Empty:
@@ -112,7 +116,7 @@ class TaskPool(TaskPoolBase):
 
             task_size = self.get_task_size(task)
 
-            if total_size + task_size > self.max_batch_size or total_size >= self.min_batch_size and self.tasks.empty():
+            if total_size + task_size > self.max_batch_size:
                 yield batch
                 batch = []
                 total_size = 0