Browse Source

create pools with proper max batch size

justheuristic 2 years ago
parent
commit
6a41bbfaa5
2 changed files with 8 additions and 3 deletions:
  1. src/server/backend.py (+4, −3)
  2. src/server/server.py (+4, −0)

+4 additions, −3 deletions
src/server/backend.py

@@ -27,11 +27,12 @@ class TransformerBackend(ModuleBackend):
         for name, buf in self.module.named_buffers():
             assert not buf.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
 
+        max_batch_size = self.forward_pool.max_batch_size
         self.inference_pool = PrioritizedTaskPool(
-            self.inference_step, max_batch_size=self.forward_pool.max_batch_size, name=f"{self.name}_inference"
+            self.inference_step, max_batch_size=max_batch_size, name=f"{self.name}_inference"
         )
-        self.forward_pool = PrioritizedTaskPool(self.forward, name=f"{self.name}_forward", **kwargs)
-        self.backward_pool = PrioritizedTaskPool(self.backward, name=f"{self.name}_backward", **kwargs)
+        self.forward_pool = PrioritizedTaskPool(self.forward, max_batch_size=max_batch_size, name=f"{self.name}_forward")
+        self.backward_pool = PrioritizedTaskPool(self.backward, max_batch_size=max_batch_size, name=f"{self.name}_backward")
         self.dtype = backend_dtype if backend_dtype else self.module.input_layernorm.weight.dtype
         self.inference_schema = (
             (

+4 additions, −0 deletions
src/server/server.py

@@ -118,6 +118,8 @@ class Server(threading.Thread):
         custom_module_path=None,
         update_period: float = 30,
         expiration: Optional[float] = None,
+        prefetch_batches: int = 1,
+        sender_threads: int = 1,
         max_block_selection_delay: float = 1,
         use_auth_token: Optional[str] = None,
         load_in_8bit: bool = False,
@@ -236,6 +238,8 @@ class Server(threading.Thread):
             stats_report_interval=stats_report_interval,
             update_period=update_period,
             expiration=expiration,
+            prefetch_batches=prefetch_batches,
+            sender_threads=sender_threads,
             start=start,
         )