compression.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. import os
  2. import warnings
  3. from concurrent.futures import ThreadPoolExecutor
  4. from typing import Optional, Sequence, Tuple
  5. import numpy as np
  6. import torch
  7. from hivemind.proto import runtime_pb2
  8. from hivemind.proto.runtime_pb2 import CompressionType
  9. FP32_EPS = 1e-06
  10. NUM_BYTES_FLOAT32 = 4
  11. NUM_BYTES_FLOAT16 = 2
  12. NUM_BITS_QUANTILE_COMPRESSION = 8
  13. NUM_COMPRESSION_QUANTILES = 2 ** NUM_BITS_QUANTILE_COMPRESSION
  14. UNIFORM_BUCKETS_STD_RANGE = 6
  15. FP16_MAX = 65_504
  16. UINT8_RANGE = 256
  17. COMPRESSION_EXECUTOR = ThreadPoolExecutor(max_workers=int(os.environ.get("QUANTILE_COMPRESSION_THREADS", 128)))
  18. warnings.filterwarnings("ignore", message="The given NumPy array is not writeable", category=UserWarning)
  19. def _quantile_encode_approx(tensor: torch.Tensor, n_bits: int) -> Tuple[torch.Tensor, torch.Tensor]:
  20. n_bins = 2 ** n_bits
  21. borders = torch.as_tensor(_quantile_qq_approximation(tensor.numpy(), n_bins + 1)[1:-1])
  22. quant_weight = torch.clamp_(torch.bucketize(tensor, borders), 0, n_bins - 1)
  23. lookup = average_buckets(tensor, quant_weight, n_bins)
  24. return quant_weight, lookup
  25. def average_buckets(tensor: torch.Tensor, quant_weight: torch.Tensor, n_bins: int):
  26. bin_sums = torch.zeros(n_bins).scatter_add_(0, quant_weight.flatten().long(), tensor.flatten())
  27. bin_counts = torch.clamp_min_(torch.bincount(quant_weight.flatten(), minlength=n_bins), 1)
  28. lookup = bin_sums / bin_counts
  29. return lookup
  30. def _quantile_qq_approximation(array: np.array, n_quantiles: int, min_chunk_size: int = 10 ** 5) -> np.ndarray:
  31. """Estimate uniform quantiles of data using quantile-of-quantiles. Runs in parallel."""
  32. if not array.data.c_contiguous and array.data.f_contiguous:
  33. array = array.T
  34. array = np.ascontiguousarray(array.reshape(-1))
  35. quantiles = np.linspace(0.0, 1.0, num=n_quantiles, dtype=array.dtype)
  36. chunk_size = _get_chunk_size(len(array), min_chunk_size)
  37. num_chunks = (len(array) - 1) // chunk_size + 1
  38. partition_quantiles = np.empty((num_chunks, len(quantiles)), dtype=array.dtype)
  39. jobs = []
  40. for i in range(num_chunks):
  41. chunk = slice(chunk_size * i, chunk_size * (i + 1))
  42. jobs.append(COMPRESSION_EXECUTOR.submit(np.quantile, array[chunk], quantiles, out=partition_quantiles[i]))
  43. for job in jobs:
  44. job.result()
  45. return np.quantile(partition_quantiles, quantiles)
  46. def _get_chunk_size(num_elements: int, min_chunk_size: int) -> int:
  47. """Adjust chunk_size to minimize imbalance between chunk sizes"""
  48. if min_chunk_size >= num_elements:
  49. return min_chunk_size
  50. leftover_elements = num_elements % min_chunk_size
  51. num_chunks = num_elements // min_chunk_size
  52. return min_chunk_size + (leftover_elements - 1) // num_chunks + 1
  53. def _uint8_uniform_buckets_encode(tensor: torch.Tensor, range_in_sigmas: float):
  54. offset = UINT8_RANGE // 2
  55. shift = tensor.mean()
  56. scale = range_in_sigmas * tensor.std() / UINT8_RANGE
  57. quant_weight = torch.quantize_per_tensor(tensor - shift, scale, offset, torch.quint8).int_repr()
  58. lookup = average_buckets(tensor, quant_weight, UINT8_RANGE)
  59. return quant_weight, lookup
  60. def serialize_torch_tensor(
  61. tensor: torch.Tensor, compression_type=CompressionType.NONE, allow_inplace=False
  62. ) -> runtime_pb2.Tensor:
  63. assert tensor.device == torch.device("cpu")
  64. if compression_type == CompressionType.MEANSTD_16BIT:
  65. assert tensor.dtype == torch.float32
  66. tensor = tensor if allow_inplace else tensor.clone()
  67. means = torch.mean(tensor, dim=-1, keepdim=True)
  68. tensor.sub_(means)
  69. stds = torch.square(tensor).sum(dim=-1, keepdim=True).div_(tensor.shape[-1]).sqrt_()
  70. stds.clamp_min_(FP32_EPS)
  71. tensor.div_(stds)
  72. tensor = tensor.clamp_(-FP16_MAX, FP16_MAX).to(torch.float16)
  73. data = b"".join((tensor.numpy().tobytes(), means.numpy().tobytes(), stds.numpy().tobytes()))
  74. proto = runtime_pb2.Tensor(
  75. compression=compression_type,
  76. buffer=data,
  77. size=tensor.shape,
  78. dtype="compressed_float32",
  79. requires_grad=tensor.requires_grad,
  80. )
  81. elif compression_type == CompressionType.FLOAT16:
  82. assert tensor.dtype == torch.float32
  83. tensor = tensor if allow_inplace else tensor.clone()
  84. tensor = tensor.clamp_(-FP16_MAX, FP16_MAX).to(torch.float16)
  85. data = tensor.numpy().tobytes()
  86. proto = runtime_pb2.Tensor(
  87. compression=compression_type,
  88. buffer=data,
  89. size=tensor.shape,
  90. dtype="clamped_float32",
  91. requires_grad=tensor.requires_grad,
  92. )
  93. elif compression_type == CompressionType.NONE:
  94. array = tensor.numpy()
  95. proto = runtime_pb2.Tensor(
  96. compression=compression_type,
  97. buffer=array.tobytes(),
  98. size=array.shape,
  99. dtype=array.dtype.name,
  100. requires_grad=tensor.requires_grad,
  101. )
  102. elif compression_type in (CompressionType.QUANTILE_8BIT, CompressionType.UNIFORM_8BIT):
  103. assert tensor.dtype == torch.float32
  104. if compression_type == CompressionType.QUANTILE_8BIT:
  105. quantized, lookup = _quantile_encode_approx(tensor.detach(), NUM_BITS_QUANTILE_COMPRESSION)
  106. elif compression_type == CompressionType.UNIFORM_8BIT:
  107. quantized, lookup = _uint8_uniform_buckets_encode(tensor.detach(), UNIFORM_BUCKETS_STD_RANGE)
  108. data = b"".join((lookup.numpy().tobytes(), quantized.numpy().astype(np.uint8).tobytes()))
  109. proto = runtime_pb2.Tensor(
  110. compression=compression_type,
  111. buffer=data,
  112. size=tensor.shape,
  113. dtype="compressed_float32",
  114. requires_grad=tensor.requires_grad,
  115. )
  116. else:
  117. raise ValueError(f"Unknown compression type: {compression_type}")
  118. return proto
  119. def construct_torch_tensor(array: np.ndarray, size: Sequence, dtype: Optional[torch.dtype] = None):
  120. """Helper conversion function that handles edge case with scalar deserialization"""
  121. if size:
  122. return torch.as_tensor(array, dtype=dtype).view(*size)
  123. else:
  124. return torch.as_tensor(array, dtype=dtype)
  125. def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
  126. if serialized_tensor.compression == CompressionType.NONE:
  127. array = np.frombuffer(serialized_tensor.buffer, dtype=np.dtype(serialized_tensor.dtype))
  128. tensor = construct_torch_tensor(array, serialized_tensor.size)
  129. elif serialized_tensor.compression == CompressionType.MEANSTD_16BIT:
  130. stats_size = list(serialized_tensor.size)
  131. stats_size[-1] = 1
  132. stats_count = np.prod(stats_size)
  133. means = serialized_tensor.buffer[-2 * NUM_BYTES_FLOAT32 * stats_count : -NUM_BYTES_FLOAT32 * stats_count]
  134. stds = serialized_tensor.buffer[-NUM_BYTES_FLOAT32 * stats_count :]
  135. means = construct_torch_tensor(np.frombuffer(means, dtype=np.float32), stats_size)
  136. stds = construct_torch_tensor(np.frombuffer(stds, dtype=np.float32), stats_size)
  137. array = np.frombuffer(serialized_tensor.buffer[: -8 * stats_count], dtype=np.float16)
  138. tensor = construct_torch_tensor(array, serialized_tensor.size, torch.float32).mul_(stds).add_(means)
  139. elif serialized_tensor.compression == CompressionType.FLOAT16:
  140. array = np.frombuffer(serialized_tensor.buffer, dtype=np.float16)
  141. tensor = construct_torch_tensor(array, serialized_tensor.size, torch.float32)
  142. elif serialized_tensor.compression in (CompressionType.QUANTILE_8BIT, CompressionType.UNIFORM_8BIT):
  143. if serialized_tensor.compression == CompressionType.QUANTILE_8BIT:
  144. lookup_size = NUM_COMPRESSION_QUANTILES * NUM_BYTES_FLOAT32
  145. else:
  146. lookup_size = UINT8_RANGE * NUM_BYTES_FLOAT32
  147. lookup = serialized_tensor.buffer[:lookup_size]
  148. quantized = serialized_tensor.buffer[lookup_size:]
  149. lookup = torch.as_tensor(np.frombuffer(lookup, dtype=np.float32))
  150. quantized = np.frombuffer(quantized, dtype=np.uint8)
  151. quantized = construct_torch_tensor(quantized, serialized_tensor.size, dtype=torch.int64)
  152. tensor = lookup[quantized]
  153. else:
  154. raise ValueError(f"Unknown compression type: {serialized_tensor.compression}")
  155. tensor.requires_grad_(serialized_tensor.requires_grad)
  156. return tensor
  157. def get_nbytes_per_value(dtype: torch.dtype, compression: CompressionType) -> int:
  158. """returns the number of bytes per value for a given tensor (excluding metadata)"""
  159. if compression in (CompressionType.QUANTILE_8BIT, CompressionType.UNIFORM_8BIT):
  160. return 1
  161. elif compression in (CompressionType.FLOAT16, CompressionType.MEANSTD_16BIT):
  162. return 2
  163. elif compression == CompressionType.NONE:
  164. return torch.finfo(dtype).bits // 8
  165. else:
  166. raise NotImplementedError(f"Unknown compression type: {CompressionType.Name(compression)}")