瀏覽代碼

Update convert_8bit.py

justheuristic 2 年之前
父節點
當前提交
0e5e93af7c
共有 1 個文件被更改,包括 5 次插入0 次删除
  1. 5 0
      src/utils/convert_8bit.py

+ 5 - 0
src/utils/convert_8bit.py

@@ -1,7 +1,11 @@
 import bitsandbytes as bnb
+import os
 import torch
 
 
+PETALS_8BIT_BACKWARD = bool(int(os.environ.get("PETALS_8BIT_BACKWARD", 0)))
+
+
 def replace_8bit_linear(model, threshold=6.0):
     """
     A helper function to convert all `torch.nn.Linear` modules to `bnb.nn.Linear8bit` modules from the `bitsandbytes`
@@ -29,6 +33,7 @@ def replace_8bit_linear(model, threshold=6.0):
                 module.bias is not None,
                 has_fp16_weights=False,
                 threshold=threshold,
+                memory_efficient_backward=PETALS_8BIT_BACKWARD,
             )
             model._modules[n].weight = bnb.nn.Int8Params(
                 module.weight.data, requires_grad=False, has_fp16_weights=False