@@ -31,8 +31,12 @@ class TransformerBackend(ModuleBackend):
         self.inference_pool = PrioritizedTaskPool(
             self.inference_step, max_batch_size=max_batch_size, name=f"{self.name}_inference"
         )
-        self.forward_pool = PrioritizedTaskPool(self.forward, max_batch_size=max_batch_size, name=f"{self.name}_forward")
-        self.backward_pool = PrioritizedTaskPool(self.backward, max_batch_size=max_batch_size, name=f"{self.name}_backward")
+        self.forward_pool = PrioritizedTaskPool(
+            self.forward, max_batch_size=max_batch_size, name=f"{self.name}_forward"
+        )
+        self.backward_pool = PrioritizedTaskPool(
+            self.backward, max_batch_size=max_batch_size, name=f"{self.name}_backward"
+        )
         self.dtype = backend_dtype if backend_dtype else self.module.input_layernorm.weight.dtype
         self.inference_schema = (
             (