justheuristic 2 lat temu
rodzic
commit
6a07bea20d
1 zmienionych plików z 6 dodań i 2 usunięć
  1. 6 2
      src/server/backend.py

+ 6 - 2
src/server/backend.py

@@ -31,8 +31,12 @@ class TransformerBackend(ModuleBackend):
         self.inference_pool = PrioritizedTaskPool(
             self.inference_step, max_batch_size=max_batch_size, name=f"{self.name}_inference"
         )
-        self.forward_pool = PrioritizedTaskPool(self.forward, max_batch_size=max_batch_size, name=f"{self.name}_forward")
-        self.backward_pool = PrioritizedTaskPool(self.backward, max_batch_size=max_batch_size, name=f"{self.name}_backward")
+        self.forward_pool = PrioritizedTaskPool(
+            self.forward, max_batch_size=max_batch_size, name=f"{self.name}_forward"
+        )
+        self.backward_pool = PrioritizedTaskPool(
+            self.backward, max_batch_size=max_batch_size, name=f"{self.name}_backward"
+        )
         self.dtype = backend_dtype if backend_dtype else self.module.input_layernorm.weight.dtype
         self.inference_schema = (
             (