|
@@ -1,5 +1,5 @@
|
|
-import torch
|
|
|
|
import bitsandbytes as bnb
|
|
import bitsandbytes as bnb
|
|
|
|
+import torch
|
|
|
|
|
|
|
|
|
|
def replace_8bit_linear(model, threshold=6.0):
|
|
def replace_8bit_linear(model, threshold=6.0):
|
|
@@ -31,4 +31,4 @@ def replace_8bit_linear(model, threshold=6.0):
|
|
has_fp16_weights=False,
|
|
has_fp16_weights=False,
|
|
threshold=threshold,
|
|
threshold=threshold,
|
|
).to(model._modules[n].weight.device)
|
|
).to(model._modules[n].weight.device)
|
|
- return model
|
|
|
|
|
|
+ return model
|