Ver Fonte

Add uniform compression (#202)

* Add uniform compression to 8 bit
* Change lookup computation of `uint8_uniform_buckets_encode` for more stable training with int8 compression
refactor `quantile_encode_approx`
* Add UNIFORM_8BIT case to `test_tensor_compression`
* Fix possible bug with size of lookup in `deserialize_torch_tensor`
mponty há 4 anos atrás
pai
commit
0080028e25

+ 1 - 0
hivemind/proto/runtime.proto

@@ -31,6 +31,7 @@ enum CompressionType{
   MEANSTD_16BIT = 1;
   FLOAT16 = 2;
   QUANTILE_8BIT = 3;
+  UNIFORM_8BIT = 4;
 }
 
 message Tensor {

+ 32 - 7
hivemind/utils/compression.py

@@ -13,6 +13,9 @@ NUM_BYTES_FLOAT32 = 4
 NUM_BYTES_FLOAT16 = 2
 NUM_BITS_QUANTILE_COMPRESSION = 8
 NUM_COMPRESSION_QUANTILES = 2 ** NUM_BITS_QUANTILE_COMPRESSION
+UNIFORM_BUCKETS_STD_RANGE = 6
+FP16_MAX = 65_504
+UINT8_RANGE = 256
 
 warnings.filterwarnings("ignore", message="The given NumPy array is not writeable", category=UserWarning)
 
@@ -21,10 +24,15 @@ def _quantile_encode_approx(tensor: torch.Tensor, n_bits: int) -> Tuple[torch.Te
     n_bins = 2 ** n_bits
     borders = torch.as_tensor(_quantile_qq_approximation(tensor.numpy(), n_bins + 1)[1:-1])
     quant_weight = torch.clamp_(torch.bucketize(tensor, borders), 0, n_bins - 1)
-    bin_sums = torch.zeros(n_bins).scatter_add_(0, quant_weight.flatten(), tensor.flatten())
+    lookup = average_buckets(tensor, quant_weight, n_bins)
+    return quant_weight, lookup
+
+
+def average_buckets(tensor: torch.Tensor, quant_weight: torch.Tensor, n_bins: int):
+    bin_sums = torch.zeros(n_bins).scatter_add_(0, quant_weight.flatten().long(), tensor.flatten())
     bin_counts = torch.clamp_min_(torch.bincount(quant_weight.flatten(), minlength=n_bins), 1)
     lookup = bin_sums / bin_counts
-    return quant_weight, lookup
+    return lookup
 
 
 def _quantile_qq_approximation(array: np.array, n_quantiles: int, min_chunk_size: int = 10 ** 5) -> np.ndarray:
@@ -57,6 +65,16 @@ def _get_chunk_size(num_elements: int, min_chunk_size: int) -> int:
     return min_chunk_size + (leftover_elements - 1) // num_chunks + 1
 
 
+def _uint8_uniform_buckets_encode(tensor: torch.Tensor, range_in_sigmas: float):
+    offset = UINT8_RANGE // 2
+    shift = tensor.mean()
+    scale = range_in_sigmas * tensor.std() / UINT8_RANGE
+
+    quant_weight = torch.quantize_per_tensor(tensor - shift, scale, offset, torch.quint8).int_repr()
+    lookup = average_buckets(tensor, quant_weight, UINT8_RANGE)
+    return quant_weight, lookup
+
+
 def serialize_torch_tensor(tensor: torch.Tensor, compression_type=CompressionType.NONE,
                            allow_inplace=False) -> runtime_pb2.Tensor:
     assert tensor.device == torch.device('cpu')
@@ -101,10 +119,13 @@ def serialize_torch_tensor(tensor: torch.Tensor, compression_type=CompressionTyp
             size=array.shape,
             dtype=array.dtype.name,
             requires_grad=tensor.requires_grad)
-    elif compression_type == CompressionType.QUANTILE_8BIT:
+    elif compression_type in (CompressionType.QUANTILE_8BIT, CompressionType.UNIFORM_8BIT):
         assert tensor.dtype == torch.float32
 
-        quantized, lookup = _quantile_encode_approx(tensor.detach(), NUM_BITS_QUANTILE_COMPRESSION)
+        if compression_type == CompressionType.QUANTILE_8BIT:
+            quantized, lookup = _quantile_encode_approx(tensor.detach(), NUM_BITS_QUANTILE_COMPRESSION)
+        elif compression_type == CompressionType.UNIFORM_8BIT:
+            quantized, lookup = _uint8_uniform_buckets_encode(tensor.detach(), UNIFORM_BUCKETS_STD_RANGE)
         data = b''.join((lookup.numpy().tobytes(), quantized.numpy().astype(np.uint8).tobytes()))
 
         proto = runtime_pb2.Tensor(
@@ -149,9 +170,13 @@ def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Ten
         array = np.frombuffer(serialized_tensor.buffer, dtype=np.float16)
         tensor = construct_torch_tensor(array, serialized_tensor.size, torch.float32)
 
-    elif serialized_tensor.compression == CompressionType.QUANTILE_8BIT:
-        lookup = serialized_tensor.buffer[:NUM_COMPRESSION_QUANTILES * NUM_BYTES_FLOAT32]
-        quantized = serialized_tensor.buffer[NUM_COMPRESSION_QUANTILES * NUM_BYTES_FLOAT32:]
+    elif serialized_tensor.compression in (CompressionType.QUANTILE_8BIT, CompressionType.UNIFORM_8BIT):
+        if serialized_tensor.compression == CompressionType.QUANTILE_8BIT:
+            lookup_size = NUM_COMPRESSION_QUANTILES * NUM_BYTES_FLOAT32
+        else:
+            lookup_size = UINT8_RANGE * NUM_BYTES_FLOAT32
+        lookup = serialized_tensor.buffer[:lookup_size]
+        quantized = serialized_tensor.buffer[lookup_size:]
         lookup = torch.as_tensor(np.frombuffer(lookup, dtype=np.float32))
         quantized = np.frombuffer(quantized, dtype=np.uint8)
         quantized = construct_torch_tensor(quantized, serialized_tensor.size, dtype=torch.int64)

+ 2 - 1
tests/test_util_modules.py

@@ -135,7 +135,8 @@ def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
     assert error.square().mean() < alpha
     error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.QUANTILE_8BIT)) - X
     assert error.square().mean() < beta
-
+    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.UNIFORM_8BIT)) - X
+    assert error.square().mean() < beta
 
 @pytest.mark.forked
 @pytest.mark.asyncio