소스 검색

modify linear8bitlt to support pre-turing architectures

justheuristic 2 년 전
부모
커밋
8ddde143b5
2개의 변경된 파일198개의 추가작업 그리고 3개의 파일을 삭제
  1. 162 3
      src/petals/utils/linear8bitlt_patch.py
  2. 36 0
      tests/test_linear8bitlt.py

+ 162 - 3
src/petals/utils/linear8bitlt_patch.py

@@ -9,11 +9,12 @@ Based on: https://github.com/TimDettmers/bitsandbytes/blob/main/csrc/kernels.cu#
 Exact match tests: see $REPO/tests/test_linear8bitlt.py
 """
 import dataclasses
+import warnings
 from typing import Optional, Tuple
 
 import bitsandbytes.functional as F
 import torch
-from bitsandbytes.autograd._functions import MatMul8bitLt, MatmulLtState
+from bitsandbytes.autograd._functions import MatMul8bitLt, MatmulLtState, GlobalOutlierPooler, prod
 from bitsandbytes.nn import Linear8bitLt
 
 
@@ -88,7 +89,7 @@ class CustomLinear8bitLt(Linear8bitLt):
 
         out = custom_matmul8bitlt(x, self.weight, bias=self.bias, state=self.state)
         if not self.state.has_fp16_weights:
-            if self.state.CB is not None:
+            if self.state.CB is not None and self.state.CxB is not None:
                 # we converted 8-bit row major to turing/ampere format in the first inference pass
                 # we no longer need the row-major weight
                 del self.state.CB
@@ -99,6 +100,7 @@ class CustomLinear8bitLt(Linear8bitLt):
 @dataclasses.dataclass(init=True)
 class CustomMatmulLtState(MatmulLtState):
     tile_indices: Optional[torch.Tensor] = None
+    force_no_igemmlt: bool = False
 
     def get_tile_size(self):
         assert self.formatB in (
@@ -123,9 +125,166 @@ def custom_matmul8bitlt(
 
 
 class CustomMatMul8bitLt(MatMul8bitLt):
-    # forward is the same as in inference-only CxB
+    # forward is the same, but we added the fallback for pre-turing GPUs
     # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
 
+    @staticmethod
+    def forward(ctx, A, B, out=None, bias=None, state=CustomMatmulLtState):
+        using_igemmlt = torch.cuda.get_device_capability(device=A.device) >= (7, 5) and not state.force_no_igemmlt
+        # default to pytorch behavior if inputs are empty
+        ctx.is_empty = False
+        if prod(A.shape) == 0:
+            ctx.is_empty = True
+            ctx.A = A
+            ctx.B = B
+            ctx.bias = bias
+            if A.shape[-1] == B.shape[0]:
+                return torch.empty(A.shape[:-1]+B.shape[1:], dtype=A.dtype, device=A.device)
+            else:
+                return torch.empty(A.shape[:-1]+B.shape[:1], dtype=A.dtype, device=A.device)
+
+        # 1. Quantize A
+        # 2. Quantize B
+        # 3. Matmul
+        # 4. Mixed-precision decomposition matmul
+        # 5. Save state
+        formatB = state.formatB
+        input_shape = A.shape
+        if state.outlier_pool is None:
+            state.outlier_pool = GlobalOutlierPooler.get_instance()
+
+        # Cast A to fp16
+        if A.dtype != torch.float16:
+            warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
+
+        # 1. Quantize A
+        if len(A.shape) == 3:
+            A = A.view(-1, A.shape[-1]).contiguous()
+        CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(
+            A.to(torch.float16), threshold=state.threshold
+        )
+
+        if state.threshold > 0.0 and coo_tensorA is not None:
+            if state.has_fp16_weights:
+                idx = torch.unique(coo_tensorA.colidx).long()
+                CA[:, idx] = 0
+                CAt[:, idx] = 0
+                subA = A[:, idx]
+                state.subB = B[:, idx].t().contiguous()
+                state.idx = idx
+            else:
+                if state.CxB is None and using_igemmlt:
+                    # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
+                    # we also need to convert it to the turing/ampere format
+                    state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
+        else:
+            if not state.has_fp16_weights and state.CxB is None and using_igemmlt:
+                state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
+            subA = None
+
+        # 2. Quantize B
+        if state.has_fp16_weights:
+            has_grad = True if (getattr(B, "grad", None) is not None) else False
+            is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
+            if is_transposed:
+                B = B.contiguous()
+
+            if (state.is_training and not has_grad) or state.CxB is None:
+                state.reset_grads()
+                (
+                    CB,
+                    state.CBt,
+                    state.SCB,
+                    state.SCBt,
+                    coo_tensorB,
+                ) = F.double_quant(B.to(torch.float16))
+                if using_igemmlt:
+                    state.CxB, state.SB = F.transform(CB, to_order=formatB)
+                else:
+                    state.CB = CB
+        else:
+            has_grad = False
+
+        if coo_tensorA is not None and not state.has_fp16_weights:
+            # extract outliers
+
+            outlier_idx = torch.unique(coo_tensorA.colidx)
+            state.idx = outlier_idx
+            # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
+            # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
+            #    # do not use pool for 2nd FFN layer
+            #    state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
+            # else:
+            #    state.idx = outlier_idx
+            if state.CxB is not None:
+                outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
+            else:
+                outliers = state.CB[:, state.idx.long()].clone()
+
+            state.subB = (
+                (outliers * state.SCB.view(-1, 1) / 127.0)
+                .t()
+                .contiguous()
+                .to(A.dtype)
+            )
+            CA[:, state.idx.long()] = 0
+            CAt[:, state.idx.long()] = 0
+            subA = A[:, state.idx.long()]
+
+        shapeB = state.SB[0] if state.SB else B.shape
+
+        if len(input_shape) == 3:
+            output_shape = (input_shape[0], input_shape[1], shapeB[0])
+        else:
+            output_shape = (input_shape[0], shapeB[0])
+
+        # 3. Matmul
+        if using_igemmlt:
+            C32A, SA = F.transform(CA, "col32")
+            out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
+            if bias is None or bias.dtype == torch.float16:
+                output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
+                output = output.to(A.dtype)
+            else:  # apply bias separately
+                output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
+                output = output.to(A.dtype).add_(bias)
+
+        else:
+            A_wo_outliers = A.clone()
+            if state.idx is not None:
+                A_wo_outliers[:, state.idx.long()] = 0
+            output = torch.nn.functional.linear(A_wo_outliers, state.CB.to(A.dtype))
+            output = output.mul_(state.SCB.unsqueeze(0).mul(1.0 / 127.0))
+            if bias is not None:
+                output = output.add_(bias)
+
+        # we apply the fused bias here
+
+
+        # 4. Mixed-precision decomposition matmul
+        if coo_tensorA is not None and subA is not None:
+            output += torch.matmul(subA, state.subB)
+
+        # 5. Save state
+        ctx.state = state
+
+        ctx.formatB = formatB
+        ctx.grad_shape = input_shape
+        ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
+
+        if any(ctx.needs_input_grad[:2]):
+            ctx.tensors = (CAt, subA)
+            ctx.tensor_states = (SCAt, state.idx)
+        else:
+            ctx.tensors = [None, None]
+            ctx.tensor_states = (None, None)
+            ctx.save_for_backward(None, None)
+
+
+        clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
+        return clone_func(output.view(output_shape))
+
+
     @staticmethod
     def backward(ctx, grad_output):
         if ctx.is_empty:

+ 36 - 0
tests/test_linear8bitlt.py

@@ -71,3 +71,39 @@ def test_linear_exact_match():
     assert not linear_custom.state.has_fp16_weights
     assert linear_custom.state.CB is None
     assert linear_custom.state.CxB is not None
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
+def test_linear_no_igemmlt():
+    linear = torch.nn.Linear(1024, 3072)
+    x = torch.randn(3, 1024, dtype=torch.half)
+    linear_custom = CustomLinear8bitLt(
+        linear.in_features,
+        linear.out_features,
+        linear.bias is not None,
+        has_fp16_weights=False,
+        threshold=6.0,
+    )
+    linear_custom.state.force_no_igemmlt = True
+
+    linear_custom.weight = bnb.nn.Int8Params(linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False).to(
+        linear.weight.dtype
+    )
+    linear_custom.bias = linear.bias
+    linear_custom.cuda()
+    linear.half().cuda()
+
+    x_ref = x.clone().cuda().requires_grad_(True)
+    x_ours = x.clone().cuda().requires_grad_(True)
+    fx_ref = linear(x_ref).float()
+    grad_proj = torch.randn_like(fx_ref)
+    (fx_ref * grad_proj).mean().backward()
+
+    fx_ours = linear_custom(x_ours).float()
+    (fx_ours * grad_proj).mean().backward()
+    assert torch.allclose(fx_ref, fx_ours, atol=0.02)
+    assert torch.allclose(x_ref.grad, x_ours.grad, atol=0.01)
+    assert not linear_custom.state.has_fp16_weights
+    assert linear_custom.state.CB is not None
+    assert linear_custom.state.CxB is None
+