|
@@ -13,6 +13,9 @@ NUM_BYTES_FLOAT32 = 4
|
|
|
NUM_BYTES_FLOAT16 = 2
|
|
|
NUM_BITS_QUANTILE_COMPRESSION = 8
|
|
|
NUM_COMPRESSION_QUANTILES = 2 ** NUM_BITS_QUANTILE_COMPRESSION
|
|
|
+UNIFORM_BUCKETS_STD_RANGE = 6
|
|
|
+FP16_MAX = 65_504
|
|
|
+UINT8_RANGE = 256
|
|
|
|
|
|
warnings.filterwarnings("ignore", message="The given NumPy array is not writeable", category=UserWarning)
|
|
|
|
|
@@ -21,10 +24,15 @@ def _quantile_encode_approx(tensor: torch.Tensor, n_bits: int) -> Tuple[torch.Te
|
|
|
n_bins = 2 ** n_bits
|
|
|
borders = torch.as_tensor(_quantile_qq_approximation(tensor.numpy(), n_bins + 1)[1:-1])
|
|
|
quant_weight = torch.clamp_(torch.bucketize(tensor, borders), 0, n_bins - 1)
|
|
|
- bin_sums = torch.zeros(n_bins).scatter_add_(0, quant_weight.flatten(), tensor.flatten())
|
|
|
+ lookup = average_buckets(tensor, quant_weight, n_bins)
|
|
|
+ return quant_weight, lookup
|
|
|
+
|
|
|
+
|
|
|
+def average_buckets(tensor: torch.Tensor, quant_weight: torch.Tensor, n_bins: int):
|
|
|
+ bin_sums = torch.zeros(n_bins).scatter_add_(0, quant_weight.flatten().long(), tensor.flatten())
|
|
|
bin_counts = torch.clamp_min_(torch.bincount(quant_weight.flatten(), minlength=n_bins), 1)
|
|
|
lookup = bin_sums / bin_counts
|
|
|
- return quant_weight, lookup
|
|
|
+ return lookup
|
|
|
|
|
|
|
|
|
def _quantile_qq_approximation(array: np.array, n_quantiles: int, min_chunk_size: int = 10 ** 5) -> np.ndarray:
|
|
@@ -57,6 +65,16 @@ def _get_chunk_size(num_elements: int, min_chunk_size: int) -> int:
|
|
|
return min_chunk_size + (leftover_elements - 1) // num_chunks + 1
|
|
|
|
|
|
|
|
|
+def _uint8_uniform_buckets_encode(tensor: torch.Tensor, range_in_sigmas: float):
|
|
|
+ offset = UINT8_RANGE // 2
|
|
|
+ shift = tensor.mean()
|
|
|
+ scale = range_in_sigmas * tensor.std() / UINT8_RANGE
|
|
|
+
|
|
|
+ quant_weight = torch.quantize_per_tensor(tensor - shift, scale, offset, torch.quint8).int_repr()
|
|
|
+ lookup = average_buckets(tensor, quant_weight, UINT8_RANGE)
|
|
|
+ return quant_weight, lookup
|
|
|
+
|
|
|
+
|
|
|
def serialize_torch_tensor(tensor: torch.Tensor, compression_type=CompressionType.NONE,
|
|
|
allow_inplace=False) -> runtime_pb2.Tensor:
|
|
|
assert tensor.device == torch.device('cpu')
|
|
@@ -101,10 +119,13 @@ def serialize_torch_tensor(tensor: torch.Tensor, compression_type=CompressionTyp
|
|
|
size=array.shape,
|
|
|
dtype=array.dtype.name,
|
|
|
requires_grad=tensor.requires_grad)
|
|
|
- elif compression_type == CompressionType.QUANTILE_8BIT:
|
|
|
+ elif compression_type in (CompressionType.QUANTILE_8BIT, CompressionType.UNIFORM_8BIT):
|
|
|
assert tensor.dtype == torch.float32
|
|
|
|
|
|
- quantized, lookup = _quantile_encode_approx(tensor.detach(), NUM_BITS_QUANTILE_COMPRESSION)
|
|
|
+ if compression_type == CompressionType.QUANTILE_8BIT:
|
|
|
+ quantized, lookup = _quantile_encode_approx(tensor.detach(), NUM_BITS_QUANTILE_COMPRESSION)
|
|
|
+ elif compression_type == CompressionType.UNIFORM_8BIT:
|
|
|
+ quantized, lookup = _uint8_uniform_buckets_encode(tensor.detach(), UNIFORM_BUCKETS_STD_RANGE)
|
|
|
data = b''.join((lookup.numpy().tobytes(), quantized.numpy().astype(np.uint8).tobytes()))
|
|
|
|
|
|
proto = runtime_pb2.Tensor(
|
|
@@ -149,9 +170,13 @@ def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Ten
|
|
|
array = np.frombuffer(serialized_tensor.buffer, dtype=np.float16)
|
|
|
tensor = construct_torch_tensor(array, serialized_tensor.size, torch.float32)
|
|
|
|
|
|
- elif serialized_tensor.compression == CompressionType.QUANTILE_8BIT:
|
|
|
- lookup = serialized_tensor.buffer[:NUM_COMPRESSION_QUANTILES * NUM_BYTES_FLOAT32]
|
|
|
- quantized = serialized_tensor.buffer[NUM_COMPRESSION_QUANTILES * NUM_BYTES_FLOAT32:]
|
|
|
+ elif serialized_tensor.compression in (CompressionType.QUANTILE_8BIT, CompressionType.UNIFORM_8BIT):
|
|
|
+ if serialized_tensor.compression == CompressionType.QUANTILE_8BIT:
|
|
|
+ lookup_size = NUM_COMPRESSION_QUANTILES * NUM_BYTES_FLOAT32
|
|
|
+ else:
|
|
|
+ lookup_size = UINT8_RANGE * NUM_BYTES_FLOAT32
|
|
|
+ lookup = serialized_tensor.buffer[:lookup_size]
|
|
|
+ quantized = serialized_tensor.buffer[lookup_size:]
|
|
|
lookup = torch.as_tensor(np.frombuffer(lookup, dtype=np.float32))
|
|
|
quantized = np.frombuffer(quantized, dtype=np.uint8)
|
|
|
quantized = construct_torch_tensor(quantized, serialized_tensor.size, dtype=torch.int64)
|