@@ -37,4 +37,5 @@ def replace_8bit_linear(model, threshold=6.0):
model._modules[n].weight = bnb.nn.Int8Params(
module.weight.data, requires_grad=False, has_fp16_weights=False
).to(module.weight.dtype)
+ model._modules[n].bias = module.bias
return model