|
@@ -79,9 +79,11 @@ class Runtime(threading.Thread):
|
|
|
self.stats_reporter.start()
|
|
|
logger.info("Started")
|
|
|
|
|
|
- for pool, batch_index, batch in BackgroundGenerator(
|
|
|
- self.iterate_minibatches_from_pools(), self.prefetch_batches
|
|
|
- ):
|
|
|
+ batch_iterator = self.iterate_minibatches_from_pools()
|
|
|
+ if self.prefetch_batches > 0:
|
|
|
+ batch_iterator = BackgroundGenerator(batch_iterator, self.prefetch_batches)
|
|
|
+
|
|
|
+ for pool, batch_index, batch in batch_iterator:
|
|
|
logger.debug(f"Processing batch {batch_index} from pool {pool.name}")
|
|
|
|
|
|
start = time()
|
|
@@ -127,9 +129,7 @@ class Runtime(threading.Thread):
|
|
|
self.shutdown_trigger.set()
|
|
|
|
|
|
def iterate_minibatches_from_pools(self, timeout=None):
|
|
|
- """
|
|
|
- Chooses pool according to priority, then copies exposed batch and frees the buffer
|
|
|
- """
|
|
|
+ """Iteratively select non-empty pool with highest priority and loads a batch from that pool"""
|
|
|
with DefaultSelector() as selector:
|
|
|
for pool in self.pools:
|
|
|
selector.register(pool.batch_receiver, EVENT_READ, pool)
|