
Use 4-bit for llama by default, use bitsandbytes 0.40.0.post3 (#340)

NF4 inference with bitsandbytes 0.40.0.post3 is ~2x faster than int8 inference, though NF4 training is still ~3x slower; see:

- [bitsandbytes 0.40.0 Release notes](https://github.com/TimDettmers/bitsandbytes/releases/tag/0.40.0)
- [RPS benchmarks](https://github.com/bigscience-workshop/petals/pull/333#issuecomment-1614040385)

We've decided to use NF4 by default for LLaMA.
Alexander Borzunov committed 2 years ago (commit fa095f6461)
2 changed files with 5 additions and 2 deletions:

1. setup.cfg (+1, -1)
2. src/petals/server/server.py (+4, -1)

setup.cfg (+1, -1)

@@ -32,7 +32,7 @@ packages = find:
 python_requires = >=3.7
 install_requires =
     torch>=1.12
-    bitsandbytes==0.39.1
+    bitsandbytes==0.40.0.post3
     accelerate>=0.16.0,<1.0.0
     huggingface-hub>=0.11.1,<1.0.0
     tokenizers>=0.13.3

src/petals/server/server.py (+4, -1)

@@ -166,7 +166,10 @@ class Server:
             check_device_balance(self.tensor_parallel_devices)
 
         if quant_type is None:
-            quant_type = QuantType.INT8 if device.type == "cuda" else QuantType.NONE
+            if device.type == "cuda":
+                quant_type = QuantType.NF4 if self.block_config.model_type == "llama" else QuantType.INT8
+            else:
+                quant_type = QuantType.NONE
         self.quant_type = quant_type
         logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, quant_type)} format")