import importlib.util
import math
import os
import warnings
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from typing import Tuple

import numpy as np
import torch

# Import bitsandbytes only if it is installed; BlockwiseQuantization detects the
# missing import at call time and raises a helpful error instead.
if importlib.util.find_spec("bitsandbytes") is not None:
    warnings.filterwarnings("ignore", module="bitsandbytes", category=UserWarning)
    from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise

from hivemind.compression.base import CompressionBase, CompressionInfo
from hivemind.proto import runtime_pb2

# Shared thread pool for parallel chunk-wise quantile estimation (see quantile_qq_approximation)
EXECUTOR = ThreadPoolExecutor(max_workers=int(os.environ.get("QUANTIZATION_THREADS", 128)))


class Quantization(CompressionBase, ABC):
    codebook_dtype, indices_dtype = np.float32, np.uint8

    @abstractmethod
    def quantize(self, tensor: torch.Tensor, allow_inplace: bool = False) -> Tuple[np.ndarray, np.ndarray]:
        """Convert tensor into a pair of (indices, codebook)"""
        ...

    def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
        quantized, codebook = self.quantize(tensor.detach(), allow_inplace=allow_inplace)
        return runtime_pb2.Tensor(
            compression=self.compression_type,
            buffer=b"".join((np.int64(len(codebook)).tobytes(), codebook.tobytes(), quantized.tobytes())),
            size=tensor.shape,
            dtype=tensor.numpy().dtype.name,
            requires_grad=tensor.requires_grad,
        )

    def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
        # Read the int64 codebook length, then the codebook itself, then the flat indices
        codebook_size = int(np.frombuffer(serialized_tensor.buffer, count=1, dtype=np.int64))
        codebook = np.frombuffer(serialized_tensor.buffer, offset=8, count=codebook_size, dtype=self.codebook_dtype)
        quantized = np.frombuffer(serialized_tensor.buffer, offset=8 + codebook.nbytes, dtype=self.indices_dtype)
        quantized = torch.as_tensor(quantized, dtype=torch.int64).reshape(tuple(serialized_tensor.size))
        codebook = torch.as_tensor(np.asarray(codebook, dtype=serialized_tensor.dtype))
        # Dequantize by looking up every index in the codebook
        return codebook[quantized]

    def estimate_compression_ratio(self, info: CompressionInfo) -> float:
        return self.n_bits / torch.finfo(info.descriptor.dtype).bits

    @property
    def n_bits(self):
        return self.indices_dtype(1).itemsize * 8

    @property
    def n_bins(self):
        return 2**self.n_bits
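

# Note (reconstructed from compress/extract above): the serialized buffer layout is
#
#     [codebook_size : int64][codebook : codebook_size x codebook_dtype][indices : indices_dtype ...]
#
# i.e. an 8-byte length prefix, the codebook values, and the flat index array,
# which extract() reshapes back to serialized_tensor.size.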


class Uniform8BitQuantization(Quantization):
    RANGE_IN_SIGMAS: int = 6
    compression_type = runtime_pb2.UNIFORM_8BIT

    def quantize(self, tensor: torch.Tensor, allow_inplace: bool = False) -> Tuple[np.ndarray, np.ndarray]:
        offset = self.n_bins // 2
        shift = tensor.mean()
        centered_tensor = tensor.sub_(shift) if allow_inplace else tensor - shift
        std_unbiased = centered_tensor.norm() / math.sqrt(centered_tensor.numel() - 1)
        # Spread n_bins uniformly over +- RANGE_IN_SIGMAS / 2 standard deviations around the mean
        scale = self.RANGE_IN_SIGMAS * std_unbiased / self.n_bins
        quantized = torch.quantize_per_tensor(centered_tensor, scale, offset, torch.quint8).int_repr()
        lookup = average_buckets(tensor, quantized, self.n_bins)
        return np.asarray(quantized, dtype=self.indices_dtype), np.asarray(lookup, dtype=self.codebook_dtype)
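

# A minimal round-trip sketch (not part of the original module). The compress()
# implementations in this file never read `info`, so the demo passes info=None;
# real hivemind code would construct a proper CompressionInfo. The error bound
# is illustrative: with 256 bins spanning ~6 sigma, the mean absolute error of
# a standard-normal tensor stays far below 0.1.
def _demo_uniform8bit_roundtrip() -> None:
    compressor = Uniform8BitQuantization()
    tensor = torch.randn(1000)
    restored = compressor.extract(compressor.compress(tensor, info=None))
    assert restored.shape == tensor.shape
    assert torch.mean(torch.abs(restored - tensor)) < 0.1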


class Quantile8BitQuantization(Quantization):
    compression_type = runtime_pb2.QUANTILE_8BIT

    def quantize(self, tensor: torch.Tensor, allow_inplace: bool = False) -> Tuple[np.ndarray, np.ndarray]:
        tensor = tensor.detach().float()
        # Estimate n_bins + 1 evenly spaced quantiles and use the interior ones as bin borders
        borders = torch.as_tensor(quantile_qq_approximation(tensor.numpy(), self.n_bins + 1)[1:-1])
        quantized = torch.clamp_(torch.bucketize(tensor, borders), 0, self.n_bins - 1)
        codebook = average_buckets(tensor, quantized, self.n_bins)
        return quantized.numpy().astype(np.uint8), codebook.numpy()
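

# Unlike Uniform8BitQuantization, the bin borders here adapt to the empirical
# distribution of the tensor: each of the 256 buckets receives roughly the same
# number of elements, which tends to be more robust to heavy tails and outliers.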


def average_buckets(tensor: torch.Tensor, quant_weight: torch.Tensor, n_bins: int):
    """Return the average value in each bucket"""
    bin_sums = torch.zeros(n_bins).scatter_add_(0, quant_weight.flatten().long(), tensor.flatten())
    bin_counts = torch.clamp_min_(torch.bincount(quant_weight.flatten(), minlength=n_bins), 1)
    lookup = bin_sums / bin_counts
    return lookup
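

# Worked example (hypothetical values): with tensor = [1.0, 2.0, 10.0] and
# quant_weight = [0, 0, 1] over n_bins = 2, bin 0 averages (1.0 + 2.0) / 2 = 1.5
# and bin 1 holds 10.0, so the lookup table is [1.5, 10.0]. Empty bins keep a
# count of 1 (via clamp_min_) to avoid division by zero and decode to 0.0.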


def get_chunk_size(num_elements: int, min_chunk_size: int) -> int:
    """Adjust chunk_size to minimize imbalance between chunk sizes"""
    if min_chunk_size >= num_elements:
        return min_chunk_size
    leftover_elements = num_elements % min_chunk_size
    num_chunks = num_elements // min_chunk_size
    return min_chunk_size + (leftover_elements - 1) // num_chunks + 1
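

# Worked example: for num_elements = 10**6 and min_chunk_size = 3 * 10**5 there
# are num_chunks = 3 full chunks with leftover_elements = 10**5, so the adjusted
# size is 3*10**5 + (10**5 - 1) // 3 + 1 = 333334. That yields chunks of
# 333334, 333334, and 333332 elements instead of a tiny trailing chunk.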


def quantile_qq_approximation(array: np.ndarray, n_quantiles: int, min_chunk_size: int = 10**5) -> np.ndarray:
    """Estimate uniform quantiles of data using quantile-of-quantiles. Runs in parallel."""
    if not array.data.c_contiguous and array.data.f_contiguous:
        # A Fortran-ordered array transposes into a C-contiguous view, making the flattening below cheap
        array = array.T
    array = np.ascontiguousarray(array.reshape(-1))
    quantiles = np.linspace(0.0, 1.0, num=n_quantiles, dtype=array.dtype)
    chunk_size = get_chunk_size(len(array), min_chunk_size)
    num_chunks = (len(array) - 1) // chunk_size + 1
    partition_quantiles = np.empty((num_chunks, len(quantiles)), dtype=array.dtype)
    jobs = []
    for i in range(num_chunks):
        chunk = slice(chunk_size * i, chunk_size * (i + 1))
        # Compute each chunk's quantiles in parallel, writing into its row of partition_quantiles
        jobs.append(EXECUTOR.submit(np.quantile, array[chunk], quantiles, out=partition_quantiles[i]))
    for job in jobs:
        job.result()
    # Quantiles of the per-chunk quantiles approximate the quantiles of the full array
    return np.quantile(partition_quantiles, quantiles)
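

# A quick sanity sketch (illustrative sizes): for 256 quantile-based bins,
# Quantile8BitQuantization.quantize requests n_bins + 1 = 257 border candidates.
def _demo_quantile_approximation() -> None:
    data = np.random.randn(10**6).astype(np.float32)
    borders = quantile_qq_approximation(data, 257)
    assert borders.shape == (257,)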


BNB_MISSING_MESSAGE = """BlockwiseQuantization requires bitsandbytes to function properly.
Please install it with `pip install bitsandbytes`
or follow the instructions at https://github.com/TimDettmers/bitsandbytes."""


class BlockwiseQuantization(Quantization):
    compression_type = runtime_pb2.BLOCKWISE_8BIT
    codebook_dtype, indices_dtype = np.float32, np.uint8

    def quantize(
        self, tensor: torch.Tensor, allow_inplace: bool = False
    ) -> Tuple[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
        try:
            quantized, (absmax, codebook) = quantize_blockwise(tensor)
        except NameError:
            # quantize_blockwise is only defined if the bitsandbytes import above succeeded
            raise ImportError(BNB_MISSING_MESSAGE)
        return quantized.numpy(), (absmax.numpy(), codebook.numpy())

    def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
        quantized, (absmax, codebook) = self.quantize(tensor.detach(), allow_inplace=allow_inplace)
        # Buffer layout: two int64 length prefixes, then absmax, codebook, and the quantized values
        serialized_data = (
            np.int64(len(absmax)).tobytes(),
            np.int64(len(codebook)).tobytes(),
            absmax.tobytes(),
            codebook.tobytes(),
            quantized.tobytes(),
        )
        return runtime_pb2.Tensor(
            buffer=b"".join(serialized_data),
            size=tensor.shape,
            requires_grad=tensor.requires_grad,
            dtype=tensor.numpy().dtype.name,
            compression=self.compression_type,
        )

    def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
        absmax_size = int(np.frombuffer(serialized_tensor.buffer, count=1, dtype=np.int64))
        codebook_size = int(np.frombuffer(serialized_tensor.buffer, offset=8, count=1, dtype=np.int64))
        absmax = np.frombuffer(serialized_tensor.buffer, offset=16, count=absmax_size, dtype=self.codebook_dtype)
        codebook = np.frombuffer(
            serialized_tensor.buffer, offset=16 + absmax.nbytes, count=codebook_size, dtype=self.codebook_dtype
        )
        quantized = np.frombuffer(
            serialized_tensor.buffer, offset=16 + absmax.nbytes + codebook.nbytes, dtype=self.indices_dtype
        )
        absmax = torch.as_tensor(absmax)
        codebook = torch.as_tensor(codebook)
        quantized = torch.as_tensor(quantized).reshape(tuple(serialized_tensor.size))
        try:
            return dequantize_blockwise(quantized, (absmax, codebook))
        except NameError:
            raise ImportError(BNB_MISSING_MESSAGE)
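

if __name__ == "__main__":
    # Illustrative smoke test (not part of the original module): round-trip a
    # random tensor through each compressor available in this environment.
    # `info` is unused by these compress() implementations, so None suffices.
    _demo_uniform8bit_roundtrip()
    _demo_quantile_approximation()
    compressors = [Uniform8BitQuantization(), Quantile8BitQuantization()]
    if importlib.util.find_spec("bitsandbytes") is not None:
        compressors.append(BlockwiseQuantization())
    reference = torch.randn(64, 64)
    for compressor in compressors:
        restored = compressor.extract(compressor.compress(reference, info=None))
        error = float(torch.mean(torch.abs(restored - reference)))
        print(f"{type(compressor).__name__}: mean abs error = {error:.4f}")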