justheuristic 2 ani în urmă
părinte
comite
b93739df9c
1 a modificat fișierele cu 3 adăugiri și 3 ștergeri
  1. 3 3
      tests/test_linear8bitlt.py

+ 3 - 3
tests/test_linear8bitlt.py

@@ -52,9 +52,9 @@ def test_linear_exact_match():
         has_fp16_weights=False,
         threshold=6.0,
     )
-    linear_custom.weight = bnb.nn.Int8Params(linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False).to(
-        linear.weight.dtype
-    )
+    linear_custom.weight = bnb.nn.Int8Params(
+        linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False
+    ).to(linear.weight.dtype)
     linear_custom.bias = linear.bias
     linear_custom.cuda()