test_linear8bitlt.py

import bitsandbytes as bnb
import pytest
import torch
from bitsandbytes import functional as F

from petals.utils.linear8bitlt_patch import CustomLinear8bitLt, get_inverse_transform_indices, undo_layout
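
# Check that undo_layout() exactly inverts the tiled GPU weight layouts ("col_turing" and
# "col_ampere") produced by bitsandbytes' F.transform(), restoring the original row-major tensor.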
@pytest.mark.skipif(
    not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 5),
    reason="this test requires a turing-generation or newer GPU, see bitsandbytes docs",
)
def test_layout_exact_match():
    x = (torch.randn(14336 * 3, 14336) * 10).to(torch.int8).cuda()
    for tile_size, order in ((8, 32), "col_turing"), ((32, 32), "col_ampere"):
        transform = lambda x: F.transform(x.cuda(), from_order="row", to_order=order)[0].to(x.device)
        tile_indices = get_inverse_transform_indices(transform, tile_size)
        cxb = transform(x)
        torch.cuda.synchronize()
        restored_x = undo_layout(cxb, tile_indices)
        torch.cuda.synchronize()
        assert restored_x.is_contiguous()
        assert torch.all(torch.eq(restored_x, x))
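

# Compare CustomLinear8bitLt against the reference bnb.nn.Linear8bitLt (built with
# memory_efficient_backward=True): forward outputs must match exactly, input gradients
# must match closely, and the int8 weight must end up in the tiled CxB buffer (CB unset).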
@pytest.mark.skipif(
    not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 5),
    reason="this test requires a turing-generation or newer GPU, see bitsandbytes docs",
)
def test_linear_exact_match():
    linear = torch.nn.Linear(1024, 3072)
    x = torch.randn(3, 1024, dtype=torch.half)
    linear8bitlt = bnb.nn.Linear8bitLt(
        linear.in_features,
        linear.out_features,
        linear.bias is not None,
        has_fp16_weights=False,
        threshold=6.0,
        memory_efficient_backward=True,
    )
    linear8bitlt.weight = bnb.nn.Int8Params(linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False).to(
        linear.weight.dtype
    )
    linear8bitlt.bias = linear.bias
    linear8bitlt.cuda()

    linear_custom = CustomLinear8bitLt(
        linear.in_features,
        linear.out_features,
        linear.bias is not None,
        has_fp16_weights=False,
        threshold=6.0,
    )
    linear_custom.weight = bnb.nn.Int8Params(
        linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False
    ).to(linear.weight.dtype)
    linear_custom.bias = linear.bias
    linear_custom.cuda()

    x_ref = x.clone().cuda().requires_grad_(True)
    x_ours = x.clone().cuda().requires_grad_(True)
    fx_ref = linear8bitlt(x_ref).float()
    grad_proj = torch.randn_like(fx_ref)
    (fx_ref * grad_proj).mean().backward()

    fx_ours = linear_custom(x_ours).float()
    (fx_ours * grad_proj).mean().backward()

    assert torch.equal(fx_ref, fx_ours)
    assert torch.allclose(x_ref.grad, x_ours.grad)
    assert not linear_custom.state.has_fp16_weights
    assert linear_custom.state.CB is None
    assert linear_custom.state.CxB is not None
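

# With force_no_igemmlt=True, CustomLinear8bitLt should avoid the tiled int8 kernel path,
# keep the plain CB buffer instead of CxB, and approximately match a regular fp16
# nn.Linear within the tolerances below.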
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
def test_linear_no_igemmlt():
    linear = torch.nn.Linear(1024, 3072)
    x = torch.randn(3, 1024, dtype=torch.half)
    linear_custom = CustomLinear8bitLt(
        linear.in_features,
        linear.out_features,
        linear.bias is not None,
        has_fp16_weights=False,
        threshold=6.0,
    )
    linear_custom.state.force_no_igemmlt = True

    linear_custom.weight = bnb.nn.Int8Params(
        linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False
    ).to(linear.weight.dtype)
    linear_custom.bias = linear.bias
    linear_custom.cuda()
    linear.half().cuda()

    x_ref = x.clone().cuda().requires_grad_(True)
    x_ours = x.clone().cuda().requires_grad_(True)
    fx_ref = linear(x_ref).float()
    grad_proj = torch.randn_like(fx_ref)
    (fx_ref * grad_proj).mean().backward()

    fx_ours = linear_custom(x_ours).float()
    (fx_ours * grad_proj).mean().backward()

    assert torch.allclose(fx_ref, fx_ours, atol=0.02)
    assert torch.allclose(x_ref.grad, x_ours.grad, atol=0.01)
    assert not linear_custom.state.has_fp16_weights
    assert linear_custom.state.CB is not None
    assert linear_custom.state.CxB is None