|
@@ -23,7 +23,7 @@ class TransformerBackend(ModuleBackend):
|
|
for name, buf in self.module.named_buffers():
|
|
for name, buf in self.module.named_buffers():
|
|
assert not buf.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
|
|
assert not buf.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
|
|
|
|
|
|
- self.inference_pool = TaskPool(self.inference_step, max_batch_size=4096, name=f"{self.name}_inference")
|
|
|
|
|
|
+ self.inference_pool = TaskPool(self.inference_step, max_batch_size=1, name=f"{self.name}_inference")
|
|
|
|
|
|
def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
|
|
def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
|
|
with torch.inference_mode():
|
|
with torch.inference_mode():
|