update tests to actually use new code

justheuristic 2 years ago
parent
commit
14eb6bba08
2 files changed with 15 additions and 5 deletions
  1. src/petals/utils/linear8bitlt_patch.py (+6, -1)
  2. tests/test_linear8bitlt.py (+9, -4)

src/petals/utils/linear8bitlt_patch.py (+6, -1)

@@ -70,7 +70,12 @@ class CustomLinear8bitLt(Linear8bitLt):
     def __init__(self, *args, memory_efficient_backward: bool = False, **kwargs):
         assert not memory_efficient_backward, "memory_efficient_backward is no longer used"
         super().__init__(*args, **kwargs)
-        self.state = CustomMatmulLtState(**dataclasses.asdict(self.state))
+        old_state, self.state = self.state, CustomMatmulLtState()
+        self.state.threshold = old_state.threshold
+        self.state.has_fp16_weights = old_state.has_fp16_weights
+        self.state.memory_efficient_backward = old_state.memory_efficient_backward
+        if old_state.threshold > 0.0 and not old_state.has_fp16_weights:
+            self.state.use_pool = True

     def forward(self, x: torch.Tensor):
         self.state.is_training = self.training
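The __init__ hunk above replaces the CustomMatmulLtState(**dataclasses.asdict(self.state)) shortcut, which attempts to copy every field of the inherited state (including tensor buffers), with an explicit copy of the three scalar settings plus a recomputed use_pool flag. A minimal usage sketch of the patched layer, assuming petals and bitsandbytes are importable and the constructor keeps the signature shown in the hunk:

    from petals.utils.linear8bitlt_patch import CustomLinear8bitLt

    # Build the patched layer with int8 weights and the outlier threshold enabled.
    layer = CustomLinear8bitLt(1024, 4096, has_fp16_weights=False, threshold=6.0)

    # Only the scalar configuration is carried over into the fresh state object;
    # with threshold > 0 and no fp16 weights, the outlier pool is switched on.
    assert layer.state.threshold == 6.0
    assert not layer.state.has_fp16_weights
    assert layer.state.use_pool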

tests/test_linear8bitlt.py (+9, -4)

@@ -39,9 +39,10 @@ def test_linear_exact_match():
         threshold=6.0,
         memory_efficient_backward=True,
     )
-    linear8bitlt.weight = bnb.nn.Int8Params(linear.weight.data, requires_grad=False, has_fp16_weights=False).to(
+    linear8bitlt.weight = bnb.nn.Int8Params(linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False).to(
         linear.weight.dtype
     )
+    linear8bitlt.bias = linear.bias
     linear8bitlt.cuda()

     linear_custom = CustomLinear8bitLt(
@@ -51,10 +52,11 @@ def test_linear_exact_match():
         has_fp16_weights=False,
         threshold=6.0,
     )
-    linear_custom.weight = bnb.nn.Int8Params(linear.weight.data, requires_grad=False, has_fp16_weights=False).to(
+    linear_custom.weight = bnb.nn.Int8Params(linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False).to(
         linear.weight.dtype
     )
-    linear8bitlt.cuda()
+    linear_custom.bias = linear.bias
+    linear_custom.cuda()

     x_ref = x.clone().cuda().requires_grad_(True)
     x_ours = x.clone().cuda().requires_grad_(True)
@@ -62,7 +64,10 @@ def test_linear_exact_match():
     grad_proj = torch.randn_like(fx_ref)
     (fx_ref * grad_proj).mean().backward()

-    fx_ours = linear8bitlt(x_ours).float()
+    fx_ours = linear_custom(x_ours).float()
     (fx_ours * grad_proj).mean().backward()
     assert torch.equal(fx_ref, fx_ours)
     assert torch.allclose(x_ref.grad, x_ours.grad)
+    assert not linear_custom.state.has_fp16_weights
+    assert linear_custom.state.CB is None
+    assert linear_custom.state.CxB is not None
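The new assertions pin down what the commit message refers to: fx_ours now goes through linear_custom, so the test actually exercises CustomLinear8bitLt, and after the forward and backward passes the int8 weight is expected to live only in the GPU-tiled CxB buffer while the row-major CB copy has been released. A hedged sketch of that check in isolation (requires a CUDA device; the CB/CxB field names come from bitsandbytes' MatmulLtState):

    import torch
    from petals.utils.linear8bitlt_patch import CustomLinear8bitLt

    layer = CustomLinear8bitLt(32, 64, bias=False, has_fp16_weights=False, threshold=6.0).cuda()
    out = layer(torch.randn(3, 32, dtype=torch.half, device="cuda"))

    # With has_fp16_weights=False, the first forward pass keeps the quantized
    # weight in the tiled CxB layout and frees the row-major CB copy.
    assert layer.state.CB is None and layer.state.CxB is not None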