@@ -120,8 +120,8 @@ def quantile_qq_approximation(array: np.ndarray, n_quantiles: int, min_chunk_siz
     return np.quantile(partition_quantiles, quantiles)
 
 
-BNB_MISSING_MESSAGE = """BlockwiseQuantization requires bitsandbytes to function properly.
-Please install it with `pip install bitsandbytes`
+BNB_MISSING_MESSAGE = """BlockwiseQuantization requires bitsandbytes to function properly.
+Please install it with `pip install bitsandbytes`
 or using the instruction from https://github.com/TimDettmers/bitsandbytes."""
 
 
@@ -139,7 +139,12 @@ class BlockwiseQuantization(Quantization):
         return quantized.numpy(), (absmax.numpy(), codebook.numpy())
 
     def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
-        quantized, (absmax, codebook) = self.quantize(tensor.detach(), allow_inplace=allow_inplace)
+        tensor = tensor.detach()
+        dtype_name = str(tensor.dtype).lstrip("torch.")
+        if tensor.dtype == torch.bfloat16:
+            tensor = tensor.to(torch.float32)
+
+        quantized, (absmax, codebook) = self.quantize(tensor, allow_inplace=allow_inplace)
 
         serialized_data = (
             np.int64(len(absmax)).tobytes(),
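
Note on the bfloat16 handling above: the tensor is upcast to float32 before quantization, presumably because bitsandbytes' `quantize_blockwise` does not accept bfloat16 inputs and the old serialization path relied on `.numpy()`, which cannot represent bfloat16 at all. The original dtype name is recorded first so decompression can cast back. A minimal sketch of that bookkeeping (not part of the diff; also note `str.lstrip` strips a character *set*, not a prefix, which happens to be harmless for these float dtype names):

```python
import torch

# Sketch: how the dtype name recorded in compress() round-trips
# back to a torch dtype in extract().
for dtype in (torch.float32, torch.float16, torch.bfloat16):
    name = str(dtype).lstrip("torch.")  # e.g. "torch.bfloat16" -> "bfloat16"
    # lstrip drops leading characters from the set {t, o, r, c, h, .};
    # this acts like prefix removal only because none of these dtype
    # names starts with a character from that set.
    assert getattr(torch, name) is dtype
```
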
@@ -153,7 +158,7 @@ class BlockwiseQuantization(Quantization):
             buffer=b"".join(serialized_data),
             size=tensor.shape,
             requires_grad=tensor.requires_grad,
-            dtype=tensor.numpy().dtype.name,
+            dtype=dtype_name,
             compression=self.compression_type,
         )
 
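
The `dtype` field change above is what makes the round trip possible: the old `tensor.numpy().dtype.name` raises for bfloat16 tensors (NumPy has no bfloat16 dtype), and after the upcast it would have reported `float32`, losing the original dtype. A quick illustration of the failure mode (exact error text may vary by torch version):

```python
import torch

t = torch.zeros(4, dtype=torch.bfloat16)
try:
    t.numpy()  # NumPy cannot represent bfloat16
except TypeError as e:
    print(e)  # e.g. "Got unsupported ScalarType BFloat16"
```
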
@@ -172,6 +177,8 @@ class BlockwiseQuantization(Quantization):
         codebook = torch.as_tensor(codebook)
         quantized = torch.as_tensor(quantized).reshape(tuple(serialized_tensor.size))
         try:
-            return dequantize_blockwise(quantized, (absmax, codebook))
+            result = dequantize_blockwise(quantized, (absmax, codebook))  # Always returns a float32 tensor
         except NameError:
             raise ImportError(BNB_MISSING_MESSAGE)
+        result = result.to(dtype=getattr(torch, serialized_tensor.dtype))
+        return result
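
Taken together, `compress()` now records the original dtype name and `extract()` casts the dequantized float32 result back to it. A hypothetical round-trip check (assumes bitsandbytes is installed; the import paths and the `CompressionInfo.from_tensor` helper are taken from hivemind's compression module and may differ across versions):

```python
import torch
from hivemind.compression.base import CompressionInfo
from hivemind.compression.quantization import BlockwiseQuantization

quantizer = BlockwiseQuantization()
tensor = torch.randn(128, dtype=torch.bfloat16)

serialized = quantizer.compress(tensor, CompressionInfo.from_tensor(tensor))
restored = quantizer.extract(serialized)

# Dequantization yields float32 internally; the final cast restores bfloat16.
assert restored.dtype == torch.bfloat16
assert restored.shape == tensor.shape
```
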