Browse Source

Fix Linear8bitlt state config, update tests (#112)

* fix state initializer
* update tests to actually use new code
* keep bias during quantization
justheuristic 2 years ago
parent
commit
01838f9a99

+ 1 - 2
src/petals/utils/convert_8bit.py

@@ -1,5 +1,3 @@
-import os
-
 import bitsandbytes as bnb
 import torch
 
@@ -37,4 +35,5 @@ def replace_8bit_linear(model, threshold=6.0):
             model._modules[n].weight = bnb.nn.Int8Params(
                 module.weight.data, requires_grad=False, has_fp16_weights=False
             ).to(module.weight.dtype)
+            model._modules[n].bias = module.bias
     return model

+ 6 - 1
src/petals/utils/linear8bitlt_patch.py

@@ -70,7 +70,12 @@ class CustomLinear8bitLt(Linear8bitLt):
     def __init__(self, *args, memory_efficient_backward: bool = False, **kwargs):
         assert not memory_efficient_backward, "memory_efficient_backward is no longer used"
         super().__init__(*args, **kwargs)
-        self.state = CustomMatmulLtState(**dataclasses.asdict(self.state))
+        old_state, self.state = self.state, CustomMatmulLtState()
+        self.state.threshold = old_state.threshold
+        self.state.has_fp16_weights = old_state.has_fp16_weights
+        self.state.memory_efficient_backward = old_state.memory_efficient_backward
+        if old_state.threshold > 0.0 and not old_state.has_fp16_weights:
+            self.state.use_pool = True
 
     def forward(self, x: torch.Tensor):
         self.state.is_training = self.training

+ 11 - 6
tests/test_linear8bitlt.py

@@ -39,9 +39,10 @@ def test_linear_exact_match():
         threshold=6.0,
         memory_efficient_backward=True,
     )
-    linear8bitlt.weight = bnb.nn.Int8Params(linear.weight.data, requires_grad=False, has_fp16_weights=False).to(
+    linear8bitlt.weight = bnb.nn.Int8Params(linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False).to(
         linear.weight.dtype
     )
+    linear8bitlt.bias = linear.bias
     linear8bitlt.cuda()
 
     linear_custom = CustomLinear8bitLt(
@@ -51,10 +52,11 @@ def test_linear_exact_match():
         has_fp16_weights=False,
         threshold=6.0,
     )
-    linear_custom.weight = bnb.nn.Int8Params(linear.weight.data, requires_grad=False, has_fp16_weights=False).to(
-        linear.weight.dtype
-    )
-    linear8bitlt.cuda()
+    linear_custom.weight = bnb.nn.Int8Params(
+        linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False
+    ).to(linear.weight.dtype)
+    linear_custom.bias = linear.bias
+    linear_custom.cuda()
 
     x_ref = x.clone().cuda().requires_grad_(True)
     x_ours = x.clone().cuda().requires_grad_(True)
@@ -62,7 +64,10 @@ def test_linear_exact_match():
     grad_proj = torch.randn_like(fx_ref)
     (fx_ref * grad_proj).mean().backward()
 
-    fx_ours = linear8bitlt(x_ours).float()
+    fx_ours = linear_custom(x_ours).float()
     (fx_ours * grad_proj).mean().backward()
     assert torch.equal(fx_ref, fx_ours)
     assert torch.allclose(x_ref.grad, x_ours.grad)
+    assert not linear_custom.state.has_fp16_weights
+    assert linear_custom.state.CB is None
+    assert linear_custom.state.CxB is not None