|
@@ -52,9 +52,9 @@ def test_linear_exact_match():
|
|
|
has_fp16_weights=False,
|
|
|
threshold=6.0,
|
|
|
)
|
|
|
- linear_custom.weight = bnb.nn.Int8Params(linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False).to(
|
|
|
- linear.weight.dtype
|
|
|
- )
|
|
|
+ linear_custom.weight = bnb.nn.Int8Params(
|
|
|
+ linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False
|
|
|
+ ).to(linear.weight.dtype)
|
|
|
linear_custom.bias = linear.bias
|
|
|
linear_custom.cuda()
|
|
|
|