瀏覽代碼

Don't use NF4 default

Aleksandr Borzunov 2 年之前
父節點
當前提交
753f8df594
共有 1 個文件被更改,包括 3 次插入和 1 次刪除
  1. src/petals/server/server.py  (+3 −1)

+3 −1
src/petals/server/server.py

@@ -171,9 +171,11 @@ class Server:

         if quant_type is None:
             if device.type == "cuda":
-                quant_type = QuantType.NF4 if self.block_config.model_type == "llama" else QuantType.INT8
+                quant_type = QuantType.INT8
             else:
                 quant_type = QuantType.NONE
+        elif quant_type == QuantType.NF4:
+            raise RuntimeError("4-bit quantization (NF4) is not supported on AMD GPUs!")
         self.quant_type = quant_type
         logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, quant_type)} format")