justheuristic 2 năm trước cách đây
mục cha
commit
b93739df9c
1 tập tin đã thay đổi với 3 bổ sung3 xóa
  1. 3 3
      tests/test_linear8bitlt.py

+ 3 - 3
tests/test_linear8bitlt.py

@@ -52,9 +52,9 @@ def test_linear_exact_match():
         has_fp16_weights=False,
         threshold=6.0,
     )
-    linear_custom.weight = bnb.nn.Int8Params(linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False).to(
-        linear.weight.dtype
-    )
+    linear_custom.weight = bnb.nn.Int8Params(
+        linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False
+    ).to(linear.weight.dtype)
     linear_custom.bias = linear.bias
     linear_custom.cuda()