quantization.py

import math
import os
import warnings
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from typing import Tuple

import numpy as np
import torch

from hivemind.compression.base import CompressionBase, CompressionInfo
from hivemind.proto import runtime_pb2

warnings.filterwarnings("ignore", module="bitsandbytes", category=UserWarning)

EXECUTOR = ThreadPoolExecutor(max_workers=int(os.environ.get("QUANTIZATION_THREADS", 128)))


class Quantization(CompressionBase, ABC):
    codebook_dtype, indices_dtype = np.float32, np.uint8

    @abstractmethod
    def quantize(self, tensor: torch.Tensor, allow_inplace: bool = False) -> Tuple[np.ndarray, np.ndarray]:
        """Convert tensor into a pair of (indices, codebook)"""
        ...

    def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
        quantized, codebook = self.quantize(tensor.detach(), allow_inplace=allow_inplace)
        return runtime_pb2.Tensor(
            compression=self.compression_type,
            buffer=b"".join((np.int64(len(codebook)).tobytes(), codebook.tobytes(), quantized.tobytes())),
            size=tensor.shape,
            dtype=tensor.numpy().dtype.name,
            requires_grad=tensor.requires_grad,
        )

    def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
        codebook_size = int(np.frombuffer(serialized_tensor.buffer, count=1, dtype=np.int64))
        codebook = np.frombuffer(serialized_tensor.buffer, offset=8, count=codebook_size, dtype=self.codebook_dtype)
        quantized = np.frombuffer(serialized_tensor.buffer, offset=8 + codebook.nbytes, dtype=self.indices_dtype)
        quantized = torch.as_tensor(quantized, dtype=torch.int64).reshape(tuple(serialized_tensor.size))
        codebook = torch.as_tensor(np.asarray(codebook, dtype=serialized_tensor.dtype))
        return codebook[quantized]

    def estimate_compression_ratio(self, info: CompressionInfo) -> float:
        return self.n_bits / torch.finfo(info.descriptor.dtype).bits

    @property
    def n_bits(self):
        return self.indices_dtype(1).itemsize * 8

    @property
    def n_bins(self):
        return 2**self.n_bits


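# Wire format produced by Quantization.compress and parsed by extract
# (a summary of the code above, not new behavior):
#   [int64: codebook length][codebook: float32 values][indices: one uint8 per element]
# extract() then dequantizes with a single gather, codebook[indices].

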
class Uniform8BitQuantization(Quantization):
    RANGE_IN_SIGMAS: int = 6
    compression_type = runtime_pb2.UNIFORM_8BIT

    def quantize(self, tensor: torch.Tensor, allow_inplace: bool = False) -> Tuple[np.ndarray, np.ndarray]:
        offset = self.n_bins // 2
        shift = tensor.mean()
        centered_tensor = tensor.sub_(shift) if allow_inplace else tensor - shift
        std_unbiased = centered_tensor.norm() / math.sqrt(centered_tensor.numel() - 1)
        scale = self.RANGE_IN_SIGMAS * std_unbiased / self.n_bins
        quantized = torch.quantize_per_tensor(centered_tensor, scale, offset, torch.quint8).int_repr()
        lookup = average_buckets(tensor, quantized, self.n_bins)
        return np.asarray(quantized, dtype=self.indices_dtype), np.asarray(lookup, dtype=self.codebook_dtype)


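# A minimal round-trip sketch (not part of the original module). Note that
# compress() above never reads its `info` argument, so this demo passes
# info=None; real hivemind code would pass a proper CompressionInfo instead.
def _demo_uniform_roundtrip() -> None:
    tensor = torch.randn(4096)
    compressor = Uniform8BitQuantization()
    restored = compressor.extract(compressor.compress(tensor, info=None))
    # 256 bins spread over +-3 sigma: expect a small but nonzero error
    assert restored.shape == tensor.shape
    assert torch.mean(torch.abs(restored - tensor)) < 0.1

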
class Quantile8BitQuantization(Quantization):
    compression_type = runtime_pb2.QUANTILE_8BIT

    def quantize(self, tensor: torch.Tensor, allow_inplace: bool = False) -> Tuple[np.ndarray, np.ndarray]:
        tensor = tensor.detach().float()
        borders = torch.as_tensor(quantile_qq_approximation(tensor.numpy(), self.n_bins + 1)[1:-1])
        quantized = torch.clamp_(torch.bucketize(tensor, borders), 0, self.n_bins - 1)
        codebook = average_buckets(tensor, quantized, self.n_bins)
        return quantized.numpy().astype(np.uint8), codebook.numpy()


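# How the borders work (a gloss on the code above): n_bins + 1 approximate
# quantiles define n_bins roughly equally-populated buckets; the outermost two
# quantiles are dropped because bucketize() already sends values below the
# first border to bucket 0 and values above the last border to bucket
# n_bins - 1. Unlike Uniform8BitQuantization, bin placement here adapts to the
# empirical distribution, which loses less information on heavy-tailed tensors.

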
def average_buckets(tensor: torch.Tensor, quant_weight: torch.Tensor, n_bins: int):
    """Return the average value in each bucket"""
    bin_sums = torch.zeros(n_bins).scatter_add_(0, quant_weight.flatten().long(), tensor.flatten())
    bin_counts = torch.clamp_min_(torch.bincount(quant_weight.flatten(), minlength=n_bins), 1)
    lookup = bin_sums / bin_counts
    return lookup


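# Worked example (values chosen for illustration):
#   tensor = [0.1, 0.2, 0.9], quant_weight = [0, 0, 1], n_bins = 2
#   bin_sums = [0.3, 0.9], bin_counts = [2, 1]  ->  lookup = [0.15, 0.9]
# clamp_min_(..., 1) keeps empty bins at count 1, so they map to 0.0 rather
# than producing NaN from a division by zero.

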
def get_chunk_size(num_elements: int, min_chunk_size: int) -> int:
    """Adjust chunk_size to minimize imbalance between chunk sizes"""
    if min_chunk_size >= num_elements:
        return min_chunk_size
    leftover_elements = num_elements % min_chunk_size
    num_chunks = num_elements // min_chunk_size
    return min_chunk_size + (leftover_elements - 1) // num_chunks + 1


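# Worked example: num_elements = 250_000, min_chunk_size = 100_000.
# Naive chunking would give 100_000 + 100_000 + 50_000 (a short last chunk);
# here leftover_elements = 50_000 and num_chunks = 2, so the function returns
# 100_000 + 24_999 + 1 = 125_000, i.e. two equal chunks of 125_000 elements.

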
def quantile_qq_approximation(array: np.ndarray, n_quantiles: int, min_chunk_size: int = 10**5) -> np.ndarray:
    """Estimate uniform quantiles of data using quantile-of-quantiles. Runs in parallel."""
    if not array.data.c_contiguous and array.data.f_contiguous:
        array = array.T
    array = np.ascontiguousarray(array.reshape(-1))
    quantiles = np.linspace(0.0, 1.0, num=n_quantiles, dtype=array.dtype)
    chunk_size = get_chunk_size(len(array), min_chunk_size)
    num_chunks = (len(array) - 1) // chunk_size + 1
    partition_quantiles = np.empty((num_chunks, len(quantiles)), dtype=array.dtype)
    jobs = []
    for i in range(num_chunks):
        chunk = slice(chunk_size * i, chunk_size * (i + 1))
        jobs.append(EXECUTOR.submit(np.quantile, array[chunk], quantiles, out=partition_quantiles[i]))

    for job in jobs:
        job.result()
    return np.quantile(partition_quantiles, quantiles)


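# Usage sketch (sizes are illustrative): per-chunk quantiles are computed in
# parallel on EXECUTOR, then the final np.quantile call takes quantiles of the
# per-chunk quantiles. This two-level estimate is cheap and is accurate when
# the chunks are similarly distributed, a reasonable assumption for gradients.
#   borders = quantile_qq_approximation(np.random.randn(10**7).astype(np.float32), 257)

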
BNB_MISSING_MESSAGE = """BlockwiseQuantization requires bitsandbytes to function properly.
Please install it with `pip install bitsandbytes`
or follow the installation instructions at https://github.com/TimDettmers/bitsandbytes."""


class BlockwiseQuantization(Quantization):
    compression_type = runtime_pb2.BLOCKWISE_8BIT
    codebook_dtype, indices_dtype = np.float32, np.uint8

    def quantize(
        self, tensor: torch.Tensor, allow_inplace: bool = False
    ) -> Tuple[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
        try:
            # The actual import only runs on the first call; later calls copy cached references
            from bitsandbytes.functional import quantize_blockwise
        except ImportError:
            raise ImportError(BNB_MISSING_MESSAGE)

        quantized, (absmax, codebook) = quantize_blockwise(tensor)
        return quantized.numpy(), (absmax.numpy(), codebook.numpy())
    def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
        tensor = tensor.detach()
        # replace(), not lstrip(): str.lstrip("torch.") strips *characters*, not a prefix,
        # and would mangle dtype names starting with those letters (e.g. "torch.complex64" -> "mplex64")
        dtype_name = str(tensor.dtype).replace("torch.", "")
        if tensor.dtype == torch.bfloat16:
            tensor = tensor.to(torch.float32)

        quantized, (absmax, codebook) = self.quantize(tensor, allow_inplace=allow_inplace)

        serialized_data = (
            np.int64(len(absmax)).tobytes(),
            np.int64(len(codebook)).tobytes(),
            absmax.tobytes(),
            codebook.tobytes(),
            quantized.tobytes(),
        )

        return runtime_pb2.Tensor(
            buffer=b"".join(serialized_data),
            size=tensor.shape,
            requires_grad=tensor.requires_grad,
            dtype=dtype_name,
            compression=self.compression_type,
        )
    def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
        try:
            from bitsandbytes.functional import dequantize_blockwise
        except ImportError:
            raise ImportError(BNB_MISSING_MESSAGE)

        absmax_size = int(np.frombuffer(serialized_tensor.buffer, count=1, dtype=np.int64))
        codebook_size = int(np.frombuffer(serialized_tensor.buffer, offset=8, count=1, dtype=np.int64))
        absmax = np.frombuffer(serialized_tensor.buffer, offset=16, count=absmax_size, dtype=self.codebook_dtype)
        codebook = np.frombuffer(
            serialized_tensor.buffer, offset=16 + absmax.nbytes, count=codebook_size, dtype=self.codebook_dtype
        )
        quantized = np.frombuffer(
            serialized_tensor.buffer, offset=16 + absmax.nbytes + codebook.nbytes, dtype=self.indices_dtype
        )

        absmax = torch.as_tensor(absmax)
        codebook = torch.as_tensor(codebook)
        quantized = torch.as_tensor(quantized).reshape(tuple(serialized_tensor.size))
        result = dequantize_blockwise(quantized, (absmax, codebook))  # Always returns a float32 tensor
        result = result.to(dtype=getattr(torch, serialized_tensor.dtype))
        return result
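

# A minimal round-trip sketch (not part of the original module). It assumes a
# bitsandbytes version whose quantize_blockwise returns (tensor, (absmax,
# codebook)), matching the calls above, and it is skipped when bnb is missing.
def _demo_blockwise_roundtrip() -> None:
    try:
        import bitsandbytes  # noqa: F401
    except ImportError:
        return
    tensor = torch.randn(4096)
    compressor = BlockwiseQuantization()
    serialized = compressor.compress(tensor, info=None)  # info is unused above
    restored = compressor.extract(serialized)
    assert restored.dtype == tensor.dtype
    assert torch.mean(torch.abs(restored - tensor)) < 0.1


if __name__ == "__main__":
    _demo_uniform_roundtrip()
    _demo_blockwise_roundtrip()
    print("round-trip demos passed")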