@@ -1,5 +1,5 @@
-import torch
import bitsandbytes as bnb
+import torch
def replace_8bit_linear(model, threshold=6.0):
@@ -31,4 +31,4 @@ def replace_8bit_linear(model, threshold=6.0):
has_fp16_weights=False,
threshold=threshold,
).to(model._modules[n].weight.device)
- return model
+ return model