|
@@ -1,6 +1,7 @@
|
|
|
"""
|
|
|
Utilities for running GRPC services: compile protobuf, patch legacy versions, etc
|
|
|
"""
|
|
|
+
|
|
|
from __future__ import annotations
|
|
|
|
|
|
import os
|
|
@@ -10,11 +11,12 @@ from typing import NamedTuple, Tuple, Optional, Union, Any, Dict, TypeVar, Type,
|
|
|
import grpc
|
|
|
import numpy as np
|
|
|
import torch
|
|
|
-from hivemind.proto.runtime_pb2 import CompressionType
|
|
|
|
|
|
+from hivemind.proto.runtime_pb2 import CompressionType
|
|
|
from hivemind.proto import runtime_pb2
|
|
|
from hivemind.utils.logging import get_logger
|
|
|
from hivemind.utils.networking import Endpoint
|
|
|
+from hivemind.utils.threading import run_in_background
|
|
|
from hivemind.utils.timed_storage import TimedStorage, get_dht_time, ValueWithExpiration
|
|
|
|
|
|
logger = get_logger(__name__)
|
|
@@ -30,6 +32,11 @@ GRPC_KEEPALIVE_OPTIONS = (
|
|
|
('grpc.http2.min_ping_interval_without_data_ms', 10 * 1000),
|
|
|
)
|
|
|
|
|
|
+NUM_BYTES_FLOAT32 = 4
|
|
|
+NUM_BYTES_FLOAT16 = 2
|
|
|
+NUM_BITS_QUANTILE_COMPRESSION = 8
|
|
|
+NUM_COMPRESSION_QUANTILES = 2 ** NUM_BITS_QUANTILE_COMPRESSION
|
|
|
+
|
|
|
|
|
|
class ChannelInfo(NamedTuple):
|
|
|
target: Endpoint
|
|
@@ -160,6 +167,46 @@ class ChannelCache(TimedStorage[ChannelInfo, Tuple[Union[grpc.Channel, grpc.aio.
|
|
|
raise ValueError(f"Please use {self.__class__.__name__}.get_stub to get or create stubs")
|
|
|
|
|
|
|
|
|
+def quantile_encode_approx(tensor: torch.Tensor, n_bits: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
|
+ n_bins = 2 ** n_bits
|
|
|
+ borders = torch.as_tensor(quantile_qq_approximation(tensor.numpy(), n_bins + 1)[1:-1])
|
|
|
+ quant_weight = torch.clamp_(torch.bucketize(tensor, borders), 0, n_bins - 1)
|
|
|
+ bin_sums = torch.zeros(n_bins).scatter_add_(0, quant_weight.flatten(), tensor.flatten())
|
|
|
+ bin_counts = torch.clamp_min_(torch.bincount(quant_weight.flatten(), minlength=n_bins), 1)
|
|
|
+ lookup = bin_sums / bin_counts
|
|
|
+ return quant_weight, lookup
|
|
|
+
|
|
|
+
|
|
|
+def quantile_qq_approximation(array: np.array, n_quantiles: int, min_chunk_size: int = 10 ** 5) -> np.ndarray:
|
|
|
+ """ Estimate uniform quantiles of data using quantile-of-quantiles. Runs in parallel. """
|
|
|
+ if not array.data.c_contiguous and array.data.f_contiguous:
|
|
|
+ array = array.T
|
|
|
+ array = np.ascontiguousarray(array.reshape(-1))
|
|
|
+ quantiles = np.linspace(0., 1., num=n_quantiles, dtype=array.dtype)
|
|
|
+ chunk_size = get_chunk_size(len(array), min_chunk_size)
|
|
|
+ num_chunks = (len(array) - 1) // chunk_size + 1
|
|
|
+ partition_quantiles = np.empty((num_chunks, len(quantiles)), dtype=array.dtype)
|
|
|
+
|
|
|
+ jobs = []
|
|
|
+ for i in range(num_chunks):
|
|
|
+ chunk = slice(chunk_size * i, chunk_size * (i + 1))
|
|
|
+ jobs.append(run_in_background(
|
|
|
+ np.quantile, array[chunk], quantiles, out=partition_quantiles[i]))
|
|
|
+
|
|
|
+ for job in jobs:
|
|
|
+ job.result()
|
|
|
+ return np.quantile(partition_quantiles, quantiles)
|
|
|
+
|
|
|
+
|
|
|
+def get_chunk_size(num_elements: int, min_chunk_size: int) -> int:
|
|
|
+ """ Adjust chunk_size to minimize imbalance between chunk sizes """
|
|
|
+ if min_chunk_size >= num_elements:
|
|
|
+ return min_chunk_size
|
|
|
+ leftover_elements = num_elements % min_chunk_size
|
|
|
+ num_chunks = num_elements // min_chunk_size
|
|
|
+ return min_chunk_size + (leftover_elements - 1) // num_chunks + 1
|
|
|
+
|
|
|
+
|
|
|
FP16_MAX = 65_504
|
|
|
|
|
|
|
|
@@ -207,6 +254,18 @@ def serialize_torch_tensor(tensor: torch.Tensor, compression_type=CompressionTyp
|
|
|
size=array.shape,
|
|
|
dtype=array.dtype.name,
|
|
|
requires_grad=tensor.requires_grad)
|
|
|
+ elif compression_type == CompressionType.QUANTILE_8BIT:
|
|
|
+ assert tensor.dtype == torch.float32
|
|
|
+
|
|
|
+ quantized, lookup = quantile_encode_approx(tensor.detach(), NUM_BITS_QUANTILE_COMPRESSION)
|
|
|
+ data = b''.join((lookup.numpy().tobytes(), quantized.numpy().astype(np.uint8).tobytes()))
|
|
|
+
|
|
|
+ proto = runtime_pb2.Tensor(
|
|
|
+ compression=compression_type,
|
|
|
+ buffer=data,
|
|
|
+ size=tensor.shape,
|
|
|
+ dtype='compressed_float32',
|
|
|
+ requires_grad=tensor.requires_grad)
|
|
|
else:
|
|
|
raise ValueError(f"Unknown compression type: {compression_type}")
|
|
|
|
|
@@ -230,15 +289,23 @@ def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Ten
|
|
|
stats_size = list(serialized_tensor.size)
|
|
|
stats_size[-1] = 1
|
|
|
stats_count = np.prod(stats_size)
|
|
|
- means = serialized_tensor.buffer[-8 * stats_count:-4 * stats_count]
|
|
|
- stds = serialized_tensor.buffer[-4 * stats_count:]
|
|
|
+ means = serialized_tensor.buffer[-2 * NUM_BYTES_FLOAT32 * stats_count: -NUM_BYTES_FLOAT32 * stats_count]
|
|
|
+ stds = serialized_tensor.buffer[-NUM_BYTES_FLOAT32 * stats_count:]
|
|
|
means = torch.as_tensor(np.frombuffer(means, dtype=np.float32).copy()).view(*stats_size)
|
|
|
stds = torch.as_tensor(np.frombuffer(stds, dtype=np.float32).copy()).view(*stats_size)
|
|
|
+
|
|
|
array = np.frombuffer(serialized_tensor.buffer[:-8 * stats_count], dtype=np.float16).copy()
|
|
|
tensor = construct_torch_tensor(array, serialized_tensor.size, torch.float32).mul_(stds).add_(means)
|
|
|
elif serialized_tensor.compression == CompressionType.FLOAT16:
|
|
|
array = np.frombuffer(serialized_tensor.buffer, dtype=np.float16).copy()
|
|
|
tensor = construct_torch_tensor(array, serialized_tensor.size, torch.float32)
|
|
|
+ elif serialized_tensor.compression == CompressionType.QUANTILE_8BIT:
|
|
|
+ lookup = serialized_tensor.buffer[:NUM_COMPRESSION_QUANTILES * NUM_BYTES_FLOAT32]
|
|
|
+ quantized = serialized_tensor.buffer[NUM_COMPRESSION_QUANTILES * NUM_BYTES_FLOAT32:]
|
|
|
+ lookup = torch.as_tensor(np.frombuffer(lookup, dtype=np.float32).copy())
|
|
|
+ quantized = np.frombuffer(quantized, dtype=np.uint8).copy()
|
|
|
+ quantized = construct_torch_tensor(quantized, serialized_tensor.size, dtype=torch.int64)
|
|
|
+ tensor = lookup[quantized]
|
|
|
else:
|
|
|
raise ValueError(f"Unknown compression type: {serialized_tensor.compression}")
|
|
|
|