justheuristic пре 2 година
родитељ
комит
b93739df9c
1 измењених фајлова са 3 додато и 3 уклоњено
  1. 3 3
      tests/test_linear8bitlt.py

+ 3 - 3
tests/test_linear8bitlt.py

@@ -52,9 +52,9 @@ def test_linear_exact_match():
         has_fp16_weights=False,
         threshold=6.0,
     )
-    linear_custom.weight = bnb.nn.Int8Params(linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False).to(
-        linear.weight.dtype
-    )
+    linear_custom.weight = bnb.nn.Int8Params(
+        linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False
+    ).to(linear.weight.dtype)
     linear_custom.bias = linear.bias
     linear_custom.cuda()