
Add quantile compression (#182)

* Implemented quantile compression

* Implemented test for quantile compression

* Named most of the magic constants

* Renamed test_vector_compression to test_tensor_compression (the functions under test are serialize_torch_tensor and deserialize_torch_tensor)

* Implemented benchmark for different compression types

Co-authored-by: justheuristic <justheuristic@gmail.com>
Vsevolod-pl, 4 years ago
Parent
Commit
b9c02ac191
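
For context, a minimal round-trip sketch of the new compression mode (usage mirrors the updated test below; the 1e-3 error bound is an illustrative assumption for standard-normal data, not a guarantee):

import torch
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils import serialize_torch_tensor, deserialize_torch_tensor

# serialize a tensor with 8-bit quantile compression, then restore it
x = torch.randn(128, 128)
restored = deserialize_torch_tensor(serialize_torch_tensor(x, CompressionType.QUANTILE_8BIT))
assert restored.shape == x.shape
assert (restored - x).square().mean() < 1e-3  # assumed tolerance, see test_tensor_compression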

+ 1 - 0
hivemind/proto/runtime.proto

@@ -30,6 +30,7 @@ enum CompressionType{
   NONE = 0;
   MEANSTD_LAST_AXIS_FLOAT16 = 1;
   FLOAT16 = 2;
+  QUANTILE_8BIT = 3;
 }
 
 message Tensor {
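
The generated Python enum picks the new value up automatically; a quick sanity check using the standard protobuf enum accessors (illustrative, not part of the commit):

from hivemind.proto.runtime_pb2 import CompressionType

assert CompressionType.QUANTILE_8BIT == 3
assert CompressionType.Name(3) == 'QUANTILE_8BIT'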

+ 70 - 3
hivemind/utils/grpc.py

@@ -1,6 +1,7 @@
 """
 Utilities for running GRPC services: compile protobuf, patch legacy versions, etc
 """
+
 from __future__ import annotations
 
 import os
@@ -10,11 +11,12 @@ from typing import NamedTuple, Tuple, Optional, Union, Any, Dict, TypeVar, Type,
 import grpc
 import numpy as np
 import torch
-from hivemind.proto.runtime_pb2 import CompressionType
 
+from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.proto import runtime_pb2
 from hivemind.utils.logging import get_logger
 from hivemind.utils.networking import Endpoint
+from hivemind.utils.threading import run_in_background
 from hivemind.utils.timed_storage import TimedStorage, get_dht_time, ValueWithExpiration
 
 logger = get_logger(__name__)
@@ -30,6 +32,11 @@ GRPC_KEEPALIVE_OPTIONS = (
     ('grpc.http2.min_ping_interval_without_data_ms', 10 * 1000),
 )
 
+NUM_BYTES_FLOAT32 = 4
+NUM_BYTES_FLOAT16 = 2
+NUM_BITS_QUANTILE_COMPRESSION = 8
+NUM_COMPRESSION_QUANTILES = 2 ** NUM_BITS_QUANTILE_COMPRESSION
+
 
 class ChannelInfo(NamedTuple):
     target: Endpoint
@@ -160,6 +167,46 @@ class ChannelCache(TimedStorage[ChannelInfo, Tuple[Union[grpc.Channel, grpc.aio.
         raise ValueError(f"Please use {self.__class__.__name__}.get_stub to get or create stubs")
 
 
+def quantile_encode_approx(tensor: torch.Tensor, n_bits: int) -> Tuple[torch.Tensor, torch.Tensor]:
+    n_bins = 2 ** n_bits
+    borders = torch.as_tensor(quantile_qq_approximation(tensor.numpy(), n_bins + 1)[1:-1])
+    quant_weight = torch.clamp_(torch.bucketize(tensor, borders), 0, n_bins - 1)
+    bin_sums = torch.zeros(n_bins).scatter_add_(0, quant_weight.flatten(), tensor.flatten())
+    bin_counts = torch.clamp_min_(torch.bincount(quant_weight.flatten(), minlength=n_bins), 1)
+    lookup = bin_sums / bin_counts
+    return quant_weight, lookup
+
+
+def quantile_qq_approximation(array: np.ndarray, n_quantiles: int, min_chunk_size: int = 10 ** 5) -> np.ndarray:
+    """ Estimate uniform quantiles of data using quantile-of-quantiles. Runs in parallel. """
+    if not array.data.c_contiguous and array.data.f_contiguous:
+        array = array.T
+    array = np.ascontiguousarray(array.reshape(-1))
+    quantiles = np.linspace(0., 1., num=n_quantiles, dtype=array.dtype)
+    chunk_size = get_chunk_size(len(array), min_chunk_size)
+    num_chunks = (len(array) - 1) // chunk_size + 1
+    partition_quantiles = np.empty((num_chunks, len(quantiles)), dtype=array.dtype)
+
+    jobs = []
+    for i in range(num_chunks):
+        chunk = slice(chunk_size * i, chunk_size * (i + 1))
+        jobs.append(run_in_background(
+            np.quantile, array[chunk], quantiles, out=partition_quantiles[i]))
+
+    for job in jobs:
+        job.result()
+    return np.quantile(partition_quantiles, quantiles)
+
+
+def get_chunk_size(num_elements: int, min_chunk_size: int) -> int:
+    """ Adjust chunk_size to minimize imbalance between chunk sizes """
+    if min_chunk_size >= num_elements:
+        return min_chunk_size
+    leftover_elements = num_elements % min_chunk_size
+    num_chunks = num_elements // min_chunk_size
+    return min_chunk_size + (leftover_elements - 1) // num_chunks + 1
+
+
 FP16_MAX = 65_504
 
 
@@ -207,6 +254,18 @@ def serialize_torch_tensor(tensor: torch.Tensor, compression_type=CompressionTyp
             size=array.shape,
             dtype=array.dtype.name,
             requires_grad=tensor.requires_grad)
+    elif compression_type == CompressionType.QUANTILE_8BIT:
+        assert tensor.dtype == torch.float32
+
+        quantized, lookup = quantile_encode_approx(tensor.detach(), NUM_BITS_QUANTILE_COMPRESSION)
+        data = b''.join((lookup.numpy().tobytes(), quantized.numpy().astype(np.uint8).tobytes()))
+
+        proto = runtime_pb2.Tensor(
+            compression=compression_type,
+            buffer=data,
+            size=tensor.shape,
+            dtype='compressed_float32',
+            requires_grad=tensor.requires_grad)
     else:
         raise ValueError(f"Unknown compression type: {compression_type}")
 
@@ -230,15 +289,23 @@ def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Ten
         stats_size = list(serialized_tensor.size)
         stats_size[-1] = 1
         stats_count = np.prod(stats_size)
-        means = serialized_tensor.buffer[-8 * stats_count:-4 * stats_count]
-        stds = serialized_tensor.buffer[-4 * stats_count:]
+        means = serialized_tensor.buffer[-2 * NUM_BYTES_FLOAT32 * stats_count: -NUM_BYTES_FLOAT32 * stats_count]
+        stds = serialized_tensor.buffer[-NUM_BYTES_FLOAT32 * stats_count:]
         means = torch.as_tensor(np.frombuffer(means, dtype=np.float32).copy()).view(*stats_size)
         stds = torch.as_tensor(np.frombuffer(stds, dtype=np.float32).copy()).view(*stats_size)
+
         array = np.frombuffer(serialized_tensor.buffer[:-8 * stats_count], dtype=np.float16).copy()
         tensor = construct_torch_tensor(array, serialized_tensor.size, torch.float32).mul_(stds).add_(means)
     elif serialized_tensor.compression == CompressionType.FLOAT16:
         array = np.frombuffer(serialized_tensor.buffer, dtype=np.float16).copy()
         tensor = construct_torch_tensor(array, serialized_tensor.size, torch.float32)
+    elif serialized_tensor.compression == CompressionType.QUANTILE_8BIT:
+        lookup = serialized_tensor.buffer[:NUM_COMPRESSION_QUANTILES * NUM_BYTES_FLOAT32]
+        quantized = serialized_tensor.buffer[NUM_COMPRESSION_QUANTILES * NUM_BYTES_FLOAT32:]
+        lookup = torch.as_tensor(np.frombuffer(lookup, dtype=np.float32).copy())
+        quantized = np.frombuffer(quantized, dtype=np.uint8).copy()
+        quantized = construct_torch_tensor(quantized, serialized_tensor.size, dtype=torch.int64)
+        tensor = lookup[quantized]
     else:
         raise ValueError(f"Unknown compression type: {serialized_tensor.compression}")
 

+ 31 - 0
tests/benchmark_tensor_compression.py

@@ -0,0 +1,31 @@
+import time
+import argparse
+import torch
+
+from hivemind.proto.runtime_pb2 import CompressionType
+from hivemind.utils import serialize_torch_tensor, deserialize_torch_tensor
+
+
+def benchmark_compression(tensor: torch.Tensor, compression_type: CompressionType) -> float:
+    t = time.time()
+    deserialize_torch_tensor(serialize_torch_tensor(tensor, compression_type))
+    return time.time() - t
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--size', type=int, default=10000000, required=False)
+    parser.add_argument('--seed', type=int, default=7348, required=False)
+    parser.add_argument('--num_iters', type=int, default=30, required=False)
+
+    args = parser.parse_args()
+
+    torch.manual_seed(args.seed)
+    X = torch.randn(args.size)
+
+    for name, compression_type in CompressionType.items():
+        tm = 0
+        for i in range(args.num_iters):
+            tm += benchmark_compression(X, compression_type)
+        tm /= args.num_iters
+        print(f"Compression type: {name}, time: {tm}")

+ 6 - 5
tests/test_util_modules.py

@@ -8,6 +8,8 @@ from hivemind.proto.dht_pb2_grpc import DHTStub
 from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
 from hivemind.utils import MSGPackSerializer
 from concurrent.futures import CancelledError
+from hivemind.proto.runtime_pb2 import CompressionType
+from hivemind.utils import serialize_torch_tensor, deserialize_torch_tensor
 
 
 def test_mpfuture_result():
@@ -121,17 +123,16 @@ async def test_await_mpfuture():
             await future
 
 
-
-def test_vector_compression(size=(128, 128, 64), alpha=5e-08):
+def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
     torch.manual_seed(0)
-    from hivemind.proto.runtime_pb2 import CompressionType
-    from hivemind.utils import serialize_torch_tensor, deserialize_torch_tensor
     X = torch.randn(*size)
     assert torch.allclose(deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.NONE)), X)
-    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.MEANSTD_LAST_AXIS_FLOAT16))-X
+    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.MEANSTD_LAST_AXIS_FLOAT16)) - X
     assert error.square().mean() < alpha
     error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.FLOAT16)) - X
     assert error.square().mean() < alpha
+    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.QUANTILE_8BIT)) - X
+    assert error.square().mean() < beta
 
 
 @pytest.mark.forked