Browse Source

Require TaskPoolBase to implement load_batch_to_runtime (#506)

The TaskPoolBase interface currently requires iterate_minibatches to be implemented. However, this method is not called by anything except TaskPool (internally). Runtime actually calls load_batch_to_runtime. This PR changes the interface to reflect that.

While we're at it, i've also changed prefetch generator so that it actually does not prefetch batches when prefetch_batches = 0. Previously, 0 would silently mean "unlimited",


Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic 2 years ago
parent
commit
41587e44d5
2 changed files with 8 additions and 8 deletions
  1. 6 6
      hivemind/moe/server/runtime.py
  2. 2 2
      hivemind/moe/server/task_pool.py

+ 6 - 6
hivemind/moe/server/runtime.py

@@ -79,9 +79,11 @@ class Runtime(threading.Thread):
                     self.stats_reporter.start()
                 logger.info("Started")
 
-                for pool, batch_index, batch in BackgroundGenerator(
-                    self.iterate_minibatches_from_pools(), self.prefetch_batches
-                ):
+                batch_iterator = self.iterate_minibatches_from_pools()
+                if self.prefetch_batches > 0:
+                    batch_iterator = BackgroundGenerator(batch_iterator, self.prefetch_batches)
+
+                for pool, batch_index, batch in batch_iterator:
                     logger.debug(f"Processing batch {batch_index} from pool {pool.name}")
 
                     start = time()
@@ -127,9 +129,7 @@ class Runtime(threading.Thread):
         self.shutdown_trigger.set()
 
     def iterate_minibatches_from_pools(self, timeout=None):
-        """
-        Chooses pool according to priority, then copies exposed batch and frees the buffer
-        """
+        """Iteratively select non-empty pool with highest priority and loads a batch from that pool"""
         with DefaultSelector() as selector:
             for pool in self.pools:
                 selector.register(pool.batch_receiver, EVENT_READ, pool)

+ 2 - 2
hivemind/moe/server/task_pool.py

@@ -38,7 +38,7 @@ class TaskPoolBase(mp.context.ForkProcess, metaclass=ABCMeta):
         pass
 
     @abstractmethod
-    def iterate_minibatches(self, *args, **kwargs) -> Generator[List[Task], None, None]:
+    def load_batch_to_runtime(self) -> Tuple[Any, List[torch.Tensor]]:
         pass
 
     @property
@@ -230,7 +230,7 @@ class TaskPool(TaskPoolBase):
         return not self.batch_receiver.poll()
 
     def load_batch_to_runtime(self, timeout=None, device=None) -> Tuple[Any, List[torch.Tensor]]:
-        """receive next batch of numpy arrays"""
+        """receive next batch of tensors"""
         if not self.batch_receiver.poll(timeout):
             raise TimeoutError()