dbaranchuk 3 năm trước cách đây
mục cha
commit
f7555a3e3d
1 tập tin đã thay đổi với 1 bổ sung3 xóa
  1. 1 3
      src/utils/convert_8bit.py

+ 1 - 3
src/utils/convert_8bit.py

@@ -31,8 +31,6 @@ def replace_8bit_linear(model, threshold=6.0):
                 threshold=threshold,
             )
             model._modules[n].weight = bnb.nn.Int8Params(
-                module.weight.data, 
-                requires_grad=False,
-                has_fp16_weights=False
+                module.weight.data, requires_grad=False, has_fp16_weights=False
             ).to(module.weight.dtype)
     return model