Browse Source

Temporary fix: check that batch has min_batch_size elements before getting

Max Ryabinin 5 years ago
parent
commit
3fa788e54b
1 changed files with 5 additions and 1 deletions
  1. 5 1
      hivemind/runtime/task_pool.py

+ 5 - 1
hivemind/runtime/task_pool.py

@@ -102,6 +102,10 @@ class TaskPool(TaskPoolBase):
         total_size = 0
         total_size = 0
 
 
         while True:
         while True:
+            if total_size >= self.min_batch_size and self.tasks.empty():
+                yield batch
+                batch = []
+                total_size = 0
             try:
             try:
                 task = self.tasks.get(timeout=self.timeout)
                 task = self.tasks.get(timeout=self.timeout)
             except Empty:
             except Empty:
@@ -112,7 +116,7 @@ class TaskPool(TaskPoolBase):
 
 
             task_size = self.get_task_size(task)
             task_size = self.get_task_size(task)
 
 
-            if total_size + task_size > self.max_batch_size or total_size >= self.min_batch_size and self.tasks.empty():
+            if total_size + task_size > self.max_batch_size:
                 yield batch
                 yield batch
                 batch = []
                 batch = []
                 total_size = 0
                 total_size = 0