convert_8bit.py

import bitsandbytes as bnb
import torch
from accelerate import init_empty_weights


def replace_8bit_linear(model, threshold=6.0):
  4. """
  5. A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules from the `bitsandbytes`
  6. library. This will enable running your models using mixed int8 precision as described by the paper `GPT3.int8():
  7. 8-bit Matrix Multiplication for Transformers at Scale`. Make sure `bitsandbytes` compiled with the correct CUDA
  8. version of your hardware is installed before running this function. `pip install -i https://test.pypi.org/simple/
  9. bitsandbytes-cudaXXX` with `XXX` is your CUDA version (e.g., 11.6 = 116)
  10. The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` that should
  11. be kept as a `torch.nn.Linear` module. The replacement is done under `init_empty_weights` context manager so no
  12. CPU/GPU memory is required to run this function.
  13. Parameters:
  14. model (`torch.nn.Module`):
  15. Input model or `torch.nn.Module` as the function is run recursively.
  16. threshold (`float`, *optional*):
  17. `int8_threshold` for outlier detection as described in the formentioned paper. This parameters is set to
  18. `6.0` as described by the paper.
  19. """
    for n, module in model.named_children():
        # Recurse into submodules that themselves contain children.
        if len(list(module.children())) > 0:
            replace_8bit_linear(module, threshold)

        if isinstance(module, torch.nn.Linear) and n != "lm_head":
            # Build the int8 replacement on the meta device so no real memory
            # is allocated; the checkpoint weights are loaded into it later.
            with init_empty_weights():
                model._modules[n] = bnb.nn.Linear8bitLt(
                    module.in_features,
                    module.out_features,
                    module.bias is not None,
                    has_fp16_weights=False,
                    threshold=threshold,
                )
    return model
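

# A minimal usage sketch, assuming `bitsandbytes` and `accelerate` are
# installed; the toy module names (`block`, `lm_head`) are illustrative only,
# not part of the function's API. It checks that every `torch.nn.Linear`
# except the one named `lm_head` is swapped for `bnb.nn.Linear8bitLt`.
if __name__ == "__main__":
    toy = torch.nn.ModuleDict(
        {
            "block": torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU()),
            "lm_head": torch.nn.Linear(64, 128),
        }
    )
    toy = replace_8bit_linear(toy)

    # The nested Linear was converted; `lm_head` kept its original class.
    assert isinstance(toy["block"][0], bnb.nn.Linear8bitLt)
    assert not isinstance(toy["lm_head"], bnb.nn.Linear8bitLt)
    # Note: the replacement modules hold empty (meta-device) weights, so the
    # original checkpoint weights still have to be loaded before inference
    # (in `transformers`, the `from_pretrained` int8 path handles this step).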