@@ -166,7 +166,10 @@ class Server:
         check_device_balance(self.tensor_parallel_devices)
 
         if quant_type is None:
-            quant_type = QuantType.INT8 if device.type == "cuda" else QuantType.NONE
+            if device.type == "cuda":
+                quant_type = QuantType.NF4 if self.block_config.model_type == "llama" else QuantType.INT8
+            else:
+                quant_type = QuantType.NONE
         self.quant_type = quant_type
         logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, quant_type)} format")
 