|
@@ -286,7 +286,6 @@ class ModuleContainer(threading.Thread):
|
|
|
)
|
|
|
|
|
|
if load_in_8bit:
|
|
|
- dtype = block.input_layernorm.weight.dtype
|
|
|
block = replace_8bit_linear(block)
|
|
|
|
|
|
block = block.to(device)
|
|
@@ -300,13 +299,13 @@ class ModuleContainer(threading.Thread):
|
|
|
backend_dtype=None if torch_dtype == "auto" else torch_dtype,
|
|
|
args_schema=(
|
|
|
BatchTensorDescriptor(
|
|
|
- 1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
|
|
|
+ 1, 2048, block_config.hidden_size, dtype=torch_dtype, compression=compression
|
|
|
),
|
|
|
),
|
|
|
kwargs_schema={},
|
|
|
outputs_schema=(
|
|
|
BatchTensorDescriptor(
|
|
|
- 1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
|
|
|
+ 1, 2048, block_config.hidden_size, dtype=torch_dtype, compression=compression
|
|
|
),
|
|
|
),
|
|
|
min_batch_size=min_batch_size,
|