|
@@ -27,11 +27,12 @@ class TransformerBackend(ModuleBackend):
|
|
|
for name, buf in self.module.named_buffers():
|
|
|
assert not buf.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
|
|
|
|
|
|
+ max_batch_size = self.forward_pool.max_batch_size
|
|
|
self.inference_pool = PrioritizedTaskPool(
|
|
|
- self.inference_step, max_batch_size=self.forward_pool.max_batch_size, name=f"{self.name}_inference"
|
|
|
+ self.inference_step, max_batch_size=max_batch_size, name=f"{self.name}_inference"
|
|
|
)
|
|
|
- self.forward_pool = PrioritizedTaskPool(self.forward, name=f"{self.name}_forward", **kwargs)
|
|
|
- self.backward_pool = PrioritizedTaskPool(self.backward, name=f"{self.name}_backward", **kwargs)
|
|
|
+ self.forward_pool = PrioritizedTaskPool(self.forward, max_batch_size=max_batch_size, name=f"{self.name}_forward")
|
|
|
+ self.backward_pool = PrioritizedTaskPool(self.backward, max_batch_size=max_batch_size, name=f"{self.name}_backward")
|
|
|
self.dtype = backend_dtype if backend_dtype else self.module.input_layernorm.weight.dtype
|
|
|
self.inference_schema = (
|
|
|
(
|