
Move compression-related code to hivemind.utils.compression (#213)

* Move compression-related code to hivemind.utils.compression

* Remove copies during deserialization, silence warning
Max Ryabinin, 4 years ago
commit 916c3db52d
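With this change, the tensor (de)serialization helpers are imported from hivemind.utils.compression instead of hivemind.utils.grpc. A minimal round-trip sketch against the new import path (shapes and tolerance are illustrative):

import torch

from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor

# serialize a CPU float32 tensor with lossy float16 compression, then restore it
tensor = torch.randn(128, 64)
proto = serialize_torch_tensor(tensor, CompressionType.FLOAT16)
restored = deserialize_torch_tensor(proto)
assert torch.allclose(restored, tensor, rtol=0, atol=1e-2)  # float16 is lossy, allow a small tolerance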

+ 2 - 2
hivemind/client/averaging/__init__.py

@@ -25,8 +25,8 @@ from hivemind.client.averaging.load_balancing import load_balance_peers
 from hivemind.client.averaging.matchmaking import Matchmaking, MatchmakingException
 from hivemind.client.averaging.group_info import GroupInfo
 from hivemind.proto import averaging_pb2, averaging_pb2_grpc, runtime_pb2
-from hivemind.utils.grpc import ChannelCache, GRPC_KEEPALIVE_OPTIONS, \
-    serialize_torch_tensor, deserialize_torch_tensor, split_for_streaming, combine_from_streaming
+from hivemind.utils.grpc import ChannelCache, GRPC_KEEPALIVE_OPTIONS, split_for_streaming, combine_from_streaming
+from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
 from hivemind.utils.asyncio import anext, achain, aiter, switch_to_uvloop
 from hivemind.utils.timed_storage import get_dht_time, ValueWithExpiration, DHTExpiration
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase

+ 2 - 1
hivemind/client/averaging/allreduce.py

@@ -5,7 +5,8 @@ import grpc
 import torch
 
 from hivemind.utils import Endpoint, get_logger, ChannelCache, anext
-from hivemind.utils import serialize_torch_tensor, deserialize_torch_tensor, split_for_streaming, combine_from_streaming
+from hivemind.utils import split_for_streaming, combine_from_streaming
+from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
 from hivemind.proto import averaging_pb2_grpc, runtime_pb2, averaging_pb2
 
 # flavour types

+ 2 - 1
hivemind/client/expert.py

@@ -7,7 +7,8 @@ from torch.autograd.function import once_differentiable
 
 from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
 from hivemind.utils import nested_flatten, nested_pack, nested_compare, Endpoint
-from hivemind.utils.grpc import serialize_torch_tensor, deserialize_torch_tensor, ChannelCache
+from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.utils.grpc import ChannelCache
 
 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert
 

+ 4 - 4
hivemind/client/moe.py

@@ -5,17 +5,17 @@ from queue import Queue, Empty
 from typing import Tuple, List, Optional, Dict, Any
 
 import grpc
-
 import torch
 import torch.nn as nn
 from torch.autograd.function import once_differentiable
 
 import hivemind
-from hivemind.client.expert import RemoteExpert, DUMMY, _get_expert_stub
-from hivemind.server.expert_uid import UID_DELIMITER
 from hivemind.client.beam_search import MoEBeamSearcher
+from hivemind.client.expert import RemoteExpert, DUMMY, _get_expert_stub
 from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
-from hivemind.utils import nested_pack, nested_flatten, serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.server.expert_uid import UID_DELIMITER
+from hivemind.utils import nested_pack, nested_flatten
+from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
 from hivemind.utils.logging import get_logger
 
 logger = get_logger(__name__)

+ 2 - 6
hivemind/hivemind_cli/run_server.py

@@ -51,8 +51,7 @@ def main():
     parser.add_argument('--increase_file_limit', action='store_true',
                         help='On *nix, this will increase the max number of processes '
                              'a server can spawn before hitting "Too many open files"; Use at your own risk.')
-    parser.add_argument('--compression', type=str, default='NONE', required=False, help='Tensor compression '
-                        'parameter for grpc. Can be NONE, MEANSTD or FLOAT16')
+    parser.add_argument('--compression', type=str, default='NONE', required=False, help='Tensor compression for gRPC')
     parser.add_argument('--checkpoint_dir', type=Path, required=False, help='Directory to store expert checkpoints')
     parser.add_argument('--stats_report_interval', type=int, required=False,
                         help='Interval between two reports of batch processing performance statistics')
@@ -74,10 +73,7 @@ def main():
         increase_file_limit()
 
     compression_type = args.pop("compression")
-    if compression_type == "MEANSTD":
-        compression = CompressionType.MEANSTD_LAST_AXIS_FLOAT16
-    else:
-        compression = getattr(CompressionType, compression_type)
+    compression = getattr(CompressionType, compression_type)
 
     server = Server.create(**args, optim_cls=optim_cls, start=True, compression=compression)
 

+ 1 - 1
hivemind/proto/runtime.proto

@@ -28,7 +28,7 @@ message ExpertResponse {
 
 enum CompressionType{
   NONE = 0;
-  MEANSTD_LAST_AXIS_FLOAT16 = 1;
+  MEANSTD_16BIT = 1;
   FLOAT16 = 2;
   QUANTILE_8BIT = 3;
 }
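With the enum value renamed to MEANSTD_16BIT, every supported value of the --compression flag now matches an enum member name, so run_server.py above can resolve it with a plain getattr. A minimal sketch of the lookup (the string literal is an example value):

from hivemind.proto.runtime_pb2 import CompressionType

# NONE, MEANSTD_16BIT, FLOAT16 and QUANTILE_8BIT all resolve by name
compression = getattr(CompressionType, "MEANSTD_16BIT")
assert compression == CompressionType.MEANSTD_16BIT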

+ 3 - 2
hivemind/server/connection_handler.py

@@ -8,9 +8,10 @@ import torch
 
 from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
 from hivemind.server.expert_backend import ExpertBackend
-from hivemind.utils import get_logger, serialize_torch_tensor, deserialize_torch_tensor, Endpoint, nested_flatten
-from hivemind.utils.grpc import GRPC_KEEPALIVE_OPTIONS
+from hivemind.utils import get_logger, Endpoint, nested_flatten
 from hivemind.utils.asyncio import switch_to_uvloop
+from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.utils.grpc import GRPC_KEEPALIVE_OPTIONS
 
 logger = get_logger(__name__)
 

+ 7 - 6
hivemind/utils/__init__.py

@@ -1,10 +1,11 @@
-from hivemind.utils.networking import *
+from hivemind.utils.asyncio import *
+from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.utils.grpc import *
+from hivemind.utils.logging import get_logger
+from hivemind.utils.mpfuture import *
 from hivemind.utils.nested import *
-from hivemind.utils.tensor_descr import *
+from hivemind.utils.networking import *
 from hivemind.utils.serializer import *
-from hivemind.utils.mpfuture import *
+from hivemind.utils.tensor_descr import *
 from hivemind.utils.threading import *
-from hivemind.utils.grpc import *
 from hivemind.utils.timed_storage import *
-from hivemind.utils.logging import get_logger
-from hivemind.utils.asyncio import *

+ 164 - 0
hivemind/utils/compression.py

@@ -0,0 +1,164 @@
+from typing import Tuple, Sequence, Optional
+
+import numpy as np
+import torch
+import warnings
+
+from hivemind.proto import runtime_pb2
+from hivemind.proto.runtime_pb2 import CompressionType
+from hivemind.utils.threading import run_in_background
+
+FP16_MAX = 65_504
+NUM_BYTES_FLOAT32 = 4
+NUM_BYTES_FLOAT16 = 2
+NUM_BITS_QUANTILE_COMPRESSION = 8
+NUM_COMPRESSION_QUANTILES = 2 ** NUM_BITS_QUANTILE_COMPRESSION
+
+warnings.filterwarnings("ignore", message="The given NumPy array is not writeable", category=UserWarning)
+
+
+def _quantile_encode_approx(tensor: torch.Tensor, n_bits: int) -> Tuple[torch.Tensor, torch.Tensor]:
+    n_bins = 2 ** n_bits
+    borders = torch.as_tensor(_quantile_qq_approximation(tensor.numpy(), n_bins + 1)[1:-1])
+    quant_weight = torch.clamp_(torch.bucketize(tensor, borders), 0, n_bins - 1)
+    bin_sums = torch.zeros(n_bins).scatter_add_(0, quant_weight.flatten(), tensor.flatten())
+    bin_counts = torch.clamp_min_(torch.bincount(quant_weight.flatten(), minlength=n_bins), 1)
+    lookup = bin_sums / bin_counts
+    return quant_weight, lookup
+
+
+def _quantile_qq_approximation(array: np.array, n_quantiles: int, min_chunk_size: int = 10 ** 5) -> np.ndarray:
+    """ Estimate uniform quantiles of data using quantile-of-quantiles. Runs in parallel. """
+    if not array.data.c_contiguous and array.data.f_contiguous:
+        array = array.T
+    array = np.ascontiguousarray(array.reshape(-1))
+    quantiles = np.linspace(0., 1., num=n_quantiles, dtype=array.dtype)
+    chunk_size = _get_chunk_size(len(array), min_chunk_size)
+    num_chunks = (len(array) - 1) // chunk_size + 1
+    partition_quantiles = np.empty((num_chunks, len(quantiles)), dtype=array.dtype)
+
+    jobs = []
+    for i in range(num_chunks):
+        chunk = slice(chunk_size * i, chunk_size * (i + 1))
+        jobs.append(run_in_background(
+            np.quantile, array[chunk], quantiles, out=partition_quantiles[i]))
+
+    for job in jobs:
+        job.result()
+    return np.quantile(partition_quantiles, quantiles)
+
+
+def _get_chunk_size(num_elements: int, min_chunk_size: int) -> int:
+    """ Adjust chunk_size to minimize imbalance between chunk sizes """
+    if min_chunk_size >= num_elements:
+        return min_chunk_size
+    leftover_elements = num_elements % min_chunk_size
+    num_chunks = num_elements // min_chunk_size
+    return min_chunk_size + (leftover_elements - 1) // num_chunks + 1
+
+
+def serialize_torch_tensor(tensor: torch.Tensor, compression_type=CompressionType.NONE,
+                           allow_inplace=False) -> runtime_pb2.Tensor:
+    assert tensor.device == torch.device('cpu')
+    if compression_type == CompressionType.MEANSTD_16BIT:
+        assert tensor.dtype == torch.float32
+
+        tensor = tensor if allow_inplace else tensor.clone()
+        means = torch.mean(tensor, dim=-1, keepdim=True)
+        tensor.sub_(means)
+
+        stds = torch.square(tensor).sum(dim=-1, keepdim=True).div_(tensor.shape[-1]).sqrt_()
+        tensor.div_(stds)
+        tensor = tensor.clamp_(-FP16_MAX, FP16_MAX).to(torch.float16)
+
+        data = b''.join((tensor.numpy().tobytes(), means.numpy().tobytes(), stds.numpy().tobytes()))
+
+        proto = runtime_pb2.Tensor(
+            compression=compression_type,
+            buffer=data,
+            size=tensor.shape,
+            dtype='compressed_float32',
+            requires_grad=tensor.requires_grad)
+    elif compression_type == CompressionType.FLOAT16:
+        assert tensor.dtype == torch.float32
+
+        tensor = tensor if allow_inplace else tensor.clone()
+        tensor = tensor.clamp_(-FP16_MAX, FP16_MAX).to(torch.float16)
+
+        data = tensor.numpy().tobytes()
+
+        proto = runtime_pb2.Tensor(
+            compression=compression_type,
+            buffer=data,
+            size=tensor.shape,
+            dtype='clamped_float32',
+            requires_grad=tensor.requires_grad)
+    elif compression_type == CompressionType.NONE:
+        array = tensor.numpy()
+        proto = runtime_pb2.Tensor(
+            compression=compression_type,
+            buffer=array.tobytes(),
+            size=array.shape,
+            dtype=array.dtype.name,
+            requires_grad=tensor.requires_grad)
+    elif compression_type == CompressionType.QUANTILE_8BIT:
+        assert tensor.dtype == torch.float32
+
+        quantized, lookup = _quantile_encode_approx(tensor.detach(), NUM_BITS_QUANTILE_COMPRESSION)
+        data = b''.join((lookup.numpy().tobytes(), quantized.numpy().astype(np.uint8).tobytes()))
+
+        proto = runtime_pb2.Tensor(
+            compression=compression_type,
+            buffer=data,
+            size=tensor.shape,
+            dtype='compressed_float32',
+            requires_grad=tensor.requires_grad)
+    else:
+        raise ValueError(f"Unknown compression type: {compression_type}")
+
+    return proto
+
+
+def construct_torch_tensor(array: np.ndarray, size: Sequence, dtype: Optional[torch.dtype] = None):
+    """ Helper conversion function that handles edge case with scalar deserialization """
+    if size:
+        return torch.as_tensor(array, dtype=dtype).view(*size)
+    else:
+        return torch.as_tensor(array, dtype=dtype)
+
+
+def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
+    if serialized_tensor.compression == CompressionType.NONE:
+        array = np.frombuffer(serialized_tensor.buffer, dtype=np.dtype(serialized_tensor.dtype))
+        tensor = construct_torch_tensor(array, serialized_tensor.size)
+
+    elif serialized_tensor.compression == CompressionType.MEANSTD_16BIT:
+        stats_size = list(serialized_tensor.size)
+        stats_size[-1] = 1
+        stats_count = np.prod(stats_size)
+
+        means = serialized_tensor.buffer[-2 * NUM_BYTES_FLOAT32 * stats_count: -NUM_BYTES_FLOAT32 * stats_count]
+        stds = serialized_tensor.buffer[-NUM_BYTES_FLOAT32 * stats_count:]
+        means = construct_torch_tensor(np.frombuffer(means, dtype=np.float32), stats_size)
+        stds = construct_torch_tensor(np.frombuffer(stds, dtype=np.float32), stats_size)
+
+        array = np.frombuffer(serialized_tensor.buffer[:-8 * stats_count], dtype=np.float16)
+        tensor = construct_torch_tensor(array, serialized_tensor.size, torch.float32).mul_(stds).add_(means)
+
+    elif serialized_tensor.compression == CompressionType.FLOAT16:
+        array = np.frombuffer(serialized_tensor.buffer, dtype=np.float16)
+        tensor = construct_torch_tensor(array, serialized_tensor.size, torch.float32)
+
+    elif serialized_tensor.compression == CompressionType.QUANTILE_8BIT:
+        lookup = serialized_tensor.buffer[:NUM_COMPRESSION_QUANTILES * NUM_BYTES_FLOAT32]
+        quantized = serialized_tensor.buffer[NUM_COMPRESSION_QUANTILES * NUM_BYTES_FLOAT32:]
+        lookup = torch.as_tensor(np.frombuffer(lookup, dtype=np.float32))
+        quantized = np.frombuffer(quantized, dtype=np.uint8)
+        quantized = construct_torch_tensor(quantized, serialized_tensor.size, dtype=torch.int64)
+        tensor = lookup[quantized]
+
+    else:
+        raise ValueError(f"Unknown compression type: {serialized_tensor.compression}")
+
+    tensor.requires_grad_(serialized_tensor.requires_grad)
+    return tensor
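For reference, MEANSTD_16BIT serialization packs the float16 payload first, followed by the per-row float32 means and stds; the negative slicing in deserialize_torch_tensor (where 8 == 2 * NUM_BYTES_FLOAT32) unpacks that layout. A small sketch of the byte arithmetic, assuming an illustrative (4, 10) tensor:

import numpy as np

size = (4, 10)                             # example shape, float32 input
stats_size = list(size)
stats_size[-1] = 1                         # one mean and one std per row
stats_count = int(np.prod(stats_size))     # 4

payload_bytes = int(np.prod(size)) * 2     # 40 float16 values -> 80 bytes
stats_bytes = stats_count * 4              # 4 float32 values -> 16 bytes for means, 16 for stds
total_bytes = payload_bytes + 2 * stats_bytes  # 112 bytes in serialized_tensor.buffer

# buffer[:-8 * stats_count]                   -> float16 payload
# buffer[-8 * stats_count: -4 * stats_count]  -> means
# buffer[-4 * stats_count:]                   -> stds
assert total_bytes - 8 * stats_count == payload_bytes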

+ 1 - 156
hivemind/utils/grpc.py

@@ -6,17 +6,13 @@ from __future__ import annotations
 
 import os
 import threading
-from typing import NamedTuple, Tuple, Optional, Union, Any, Dict, TypeVar, Type, Iterator, Iterable, Sequence
+from typing import NamedTuple, Tuple, Optional, Union, Any, Dict, TypeVar, Type, Iterator, Iterable
 
 import grpc
-import numpy as np
-import torch
 
-from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.proto import runtime_pb2
 from hivemind.utils.logging import get_logger
 from hivemind.utils.networking import Endpoint
-from hivemind.utils.threading import run_in_background
 from hivemind.utils.timed_storage import TimedStorage, get_dht_time, ValueWithExpiration
 
 logger = get_logger(__name__)
@@ -32,11 +28,6 @@ GRPC_KEEPALIVE_OPTIONS = (
     ('grpc.http2.min_ping_interval_without_data_ms', 10 * 1000),
 )
 
-NUM_BYTES_FLOAT32 = 4
-NUM_BYTES_FLOAT16 = 2
-NUM_BITS_QUANTILE_COMPRESSION = 8
-NUM_COMPRESSION_QUANTILES = 2 ** NUM_BITS_QUANTILE_COMPRESSION
-
 
 class ChannelInfo(NamedTuple):
     target: Endpoint
@@ -167,152 +158,6 @@ class ChannelCache(TimedStorage[ChannelInfo, Tuple[Union[grpc.Channel, grpc.aio.
         raise ValueError(f"Please use {self.__class__.__name__}.get_stub to get or create stubs")
 
 
-def quantile_encode_approx(tensor: torch.Tensor, n_bits: int) -> Tuple[torch.Tensor, torch.Tensor]:
-    n_bins = 2 ** n_bits
-    borders = torch.as_tensor(quantile_qq_approximation(tensor.numpy(), n_bins + 1)[1:-1])
-    quant_weight = torch.clamp_(torch.bucketize(tensor, borders), 0, n_bins - 1)
-    bin_sums = torch.zeros(n_bins).scatter_add_(0, quant_weight.flatten(), tensor.flatten())
-    bin_counts = torch.clamp_min_(torch.bincount(quant_weight.flatten(), minlength=n_bins), 1)
-    lookup = bin_sums / bin_counts
-    return quant_weight, lookup
-
-
-def quantile_qq_approximation(array: np.array, n_quantiles: int, min_chunk_size: int = 10 ** 5) -> np.ndarray:
-    """ Estimate uniform quantiles of data using quantile-of-quantiles. Runs in parallel. """
-    if not array.data.c_contiguous and array.data.f_contiguous:
-        array = array.T
-    array = np.ascontiguousarray(array.reshape(-1))
-    quantiles = np.linspace(0., 1., num=n_quantiles, dtype=array.dtype)
-    chunk_size = get_chunk_size(len(array), min_chunk_size)
-    num_chunks = (len(array) - 1) // chunk_size + 1
-    partition_quantiles = np.empty((num_chunks, len(quantiles)), dtype=array.dtype)
-
-    jobs = []
-    for i in range(num_chunks):
-        chunk = slice(chunk_size * i, chunk_size * (i + 1))
-        jobs.append(run_in_background(
-            np.quantile, array[chunk], quantiles, out=partition_quantiles[i]))
-
-    for job in jobs:
-        job.result()
-    return np.quantile(partition_quantiles, quantiles)
-
-
-def get_chunk_size(num_elements: int, min_chunk_size: int) -> int:
-    """ Adjust chunk_size to minimize imbalance between chunk sizes """
-    if min_chunk_size >= num_elements:
-        return min_chunk_size
-    leftover_elements = num_elements % min_chunk_size
-    num_chunks = num_elements // min_chunk_size
-    return min_chunk_size + (leftover_elements - 1) // num_chunks + 1
-
-
-FP16_MAX = 65_504
-
-
-def serialize_torch_tensor(tensor: torch.Tensor, compression_type=CompressionType.NONE,
-                           allow_inplace=False) -> runtime_pb2.Tensor:
-    assert tensor.device == torch.device('cpu')
-    if compression_type == CompressionType.MEANSTD_LAST_AXIS_FLOAT16:
-        assert tensor.dtype == torch.float32
-
-        tensor = tensor if allow_inplace else tensor.clone()
-        means = torch.mean(tensor, dim=-1, keepdim=True)
-        tensor.sub_(means)
-
-        stds = torch.square(tensor).sum(dim=-1, keepdim=True).div_(tensor.shape[-1]).sqrt_()
-        tensor.div_(stds)
-        tensor = tensor.clamp_(-FP16_MAX, FP16_MAX).to(torch.float16)
-
-        data = b''.join((tensor.numpy().tobytes(), means.numpy().tobytes(), stds.numpy().tobytes()))
-
-        proto = runtime_pb2.Tensor(
-            compression=compression_type,
-            buffer=data,
-            size=tensor.shape,
-            dtype='compressed_float32',
-            requires_grad=tensor.requires_grad)
-    elif compression_type == CompressionType.FLOAT16:
-        assert tensor.dtype == torch.float32
-
-        tensor = tensor if allow_inplace else tensor.clone()
-        tensor = tensor.clamp_(-FP16_MAX, FP16_MAX).to(torch.float16)
-
-        data = tensor.numpy().tobytes()
-
-        proto = runtime_pb2.Tensor(
-            compression=compression_type,
-            buffer=data,
-            size=tensor.shape,
-            dtype='clamped_float32',
-            requires_grad=tensor.requires_grad)
-    elif compression_type == CompressionType.NONE:
-        array = tensor.numpy()
-        proto = runtime_pb2.Tensor(
-            compression=compression_type,
-            buffer=array.tobytes(),
-            size=array.shape,
-            dtype=array.dtype.name,
-            requires_grad=tensor.requires_grad)
-    elif compression_type == CompressionType.QUANTILE_8BIT:
-        assert tensor.dtype == torch.float32
-
-        quantized, lookup = quantile_encode_approx(tensor.detach(), NUM_BITS_QUANTILE_COMPRESSION)
-        data = b''.join((lookup.numpy().tobytes(), quantized.numpy().astype(np.uint8).tobytes()))
-
-        proto = runtime_pb2.Tensor(
-            compression=compression_type,
-            buffer=data,
-            size=tensor.shape,
-            dtype='compressed_float32',
-            requires_grad=tensor.requires_grad)
-    else:
-        raise ValueError(f"Unknown compression type: {compression_type}")
-
-    return proto
-
-
-def construct_torch_tensor(array: np.ndarray, size: Sequence, dtype: Optional[torch.dtype]=None):
-    """ Helper conversion function that handles edge case with scalar deserialization """
-    if size:
-        return torch.as_tensor(array, dtype=dtype).view(*size)
-    else:
-        return torch.as_tensor(array, dtype=dtype)
-
-
-def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
-    # TODO avoid copying the array (need to silence pytorch warning, because array is not writable)
-    if serialized_tensor.compression == CompressionType.NONE:
-        array = np.frombuffer(serialized_tensor.buffer, dtype=np.dtype(serialized_tensor.dtype)).copy()
-        tensor = construct_torch_tensor(array, serialized_tensor.size)
-    elif serialized_tensor.compression == CompressionType.MEANSTD_LAST_AXIS_FLOAT16:
-        stats_size = list(serialized_tensor.size)
-        stats_size[-1] = 1
-        stats_count = np.prod(stats_size)
-        means = serialized_tensor.buffer[-2 * NUM_BYTES_FLOAT32 * stats_count: -NUM_BYTES_FLOAT32 * stats_count]
-        stds = serialized_tensor.buffer[-NUM_BYTES_FLOAT32 * stats_count:]
-        means = torch.as_tensor(np.frombuffer(means, dtype=np.float32).copy()).view(*stats_size)
-        stds = torch.as_tensor(np.frombuffer(stds, dtype=np.float32).copy()).view(*stats_size)
-
-        array = np.frombuffer(serialized_tensor.buffer[:-8 * stats_count], dtype=np.float16).copy()
-        tensor = construct_torch_tensor(array, serialized_tensor.size, torch.float32).mul_(stds).add_(means)
-    elif serialized_tensor.compression == CompressionType.FLOAT16:
-        array = np.frombuffer(serialized_tensor.buffer, dtype=np.float16).copy()
-        tensor = construct_torch_tensor(array, serialized_tensor.size, torch.float32)
-    elif serialized_tensor.compression == CompressionType.QUANTILE_8BIT:
-        lookup = serialized_tensor.buffer[:NUM_COMPRESSION_QUANTILES * NUM_BYTES_FLOAT32]
-        quantized = serialized_tensor.buffer[NUM_COMPRESSION_QUANTILES * NUM_BYTES_FLOAT32:]
-        lookup = torch.as_tensor(np.frombuffer(lookup, dtype=np.float32).copy())
-        quantized = np.frombuffer(quantized, dtype=np.uint8).copy()
-        quantized = construct_torch_tensor(quantized, serialized_tensor.size, dtype=torch.int64)
-        tensor = lookup[quantized]
-    else:
-        raise ValueError(f"Unknown compression type: {serialized_tensor.compression}")
-
-    tensor.requires_grad_(serialized_tensor.requires_grad)
-    return tensor
-
-
 def split_for_streaming(serialized_tensor: runtime_pb2.Tensor, chunk_size_bytes: int) -> Iterator[runtime_pb2.Tensor]:
     """ Split serialized_tensor into multiple chunks for gRPC streaming """
     buffer = memoryview(serialized_tensor.buffer)
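The streaming helpers stay in hivemind.utils.grpc; only the compression routines move. A hedged sketch of how the two modules combine after this change (mirrors test_serialize_tensor below):

import torch
import hivemind
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor

tensor = torch.randn(512, 512)
serialized = serialize_torch_tensor(tensor, CompressionType.NONE)
chunks = list(hivemind.utils.split_for_streaming(serialized, chunk_size_bytes=16384))
restored = hivemind.utils.combine_from_streaming(chunks)
assert torch.allclose(deserialize_torch_tensor(restored), tensor)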

+ 1 - 1
hivemind/utils/threading.py

@@ -1,7 +1,7 @@
 import os
 from concurrent.futures import Future, ThreadPoolExecutor
 
-from hivemind.utils import get_logger
+from hivemind.utils.logging import get_logger
 
 logger = get_logger(__name__)
 

+ 1 - 1
tests/benchmark_tensor_compression.py

@@ -3,7 +3,7 @@ import argparse
 import torch
 
 from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.utils import serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
 
 
 def benchmark_compression(tensor: torch.Tensor, compression_type: CompressionType) -> float:

+ 18 - 18
tests/test_util_modules.py

@@ -9,7 +9,8 @@ from hivemind.proto.dht_pb2_grpc import DHTStub
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
 import hivemind
-from hivemind.utils import MSGPackSerializer, serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.utils import MSGPackSerializer
+from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
 from hivemind.utils.mpfuture import FutureStateError
 
 
@@ -128,7 +129,7 @@ def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
     torch.manual_seed(0)
     X = torch.randn(*size)
     assert torch.allclose(deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.NONE)), X)
-    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.MEANSTD_LAST_AXIS_FLOAT16)) - X
+    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.MEANSTD_16BIT)) - X
     assert error.square().mean() < alpha
     error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.FLOAT16)) - X
     assert error.square().mean() < alpha
@@ -176,33 +177,33 @@ async def test_channel_cache():
 def test_serialize_tensor():
     tensor = torch.randn(512, 12288)
 
-    serialized_tensor = hivemind.serialize_torch_tensor(tensor, hivemind.CompressionType.NONE)
+    serialized_tensor = serialize_torch_tensor(tensor, CompressionType.NONE)
     for chunk_size in [1024, 64 * 1024, 64 * 1024 + 1, 10 ** 9]:
         chunks = list(hivemind.split_for_streaming(serialized_tensor, chunk_size))
         assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
         restored = hivemind.combine_from_streaming(chunks)
-        assert torch.allclose(hivemind.deserialize_torch_tensor(restored), tensor)
+        assert torch.allclose(deserialize_torch_tensor(restored), tensor)
 
     chunk_size = 30 * 1024
-    serialized_tensor = hivemind.serialize_torch_tensor(tensor, hivemind.CompressionType.FLOAT16)
+    serialized_tensor = serialize_torch_tensor(tensor, CompressionType.FLOAT16)
     chunks = list(hivemind.split_for_streaming(serialized_tensor, chunk_size))
     assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
     restored = hivemind.combine_from_streaming(chunks)
-    assert torch.allclose(hivemind.deserialize_torch_tensor(restored), tensor, rtol=0, atol=1e-2)
+    assert torch.allclose(deserialize_torch_tensor(restored), tensor, rtol=0, atol=1e-2)
 
     tensor = torch.randint(0, 100, (512, 1, 1))
-    serialized_tensor = hivemind.serialize_torch_tensor(tensor, hivemind.CompressionType.NONE)
+    serialized_tensor = serialize_torch_tensor(tensor, CompressionType.NONE)
     chunks = list(hivemind.split_for_streaming(serialized_tensor, chunk_size))
     assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
     restored = hivemind.combine_from_streaming(chunks)
-    assert torch.allclose(hivemind.deserialize_torch_tensor(restored), tensor)
+    assert torch.allclose(deserialize_torch_tensor(restored), tensor)
 
     scalar = torch.tensor(1.)
-    serialized_scalar = hivemind.serialize_torch_tensor(scalar, hivemind.CompressionType.NONE)
-    assert torch.allclose(hivemind.deserialize_torch_tensor(serialized_scalar), scalar)
+    serialized_scalar = serialize_torch_tensor(scalar, CompressionType.NONE)
+    assert torch.allclose(deserialize_torch_tensor(serialized_scalar), scalar)
 
-    serialized_scalar = hivemind.serialize_torch_tensor(scalar, hivemind.CompressionType.FLOAT16)
-    assert torch.allclose(hivemind.deserialize_torch_tensor(serialized_scalar), scalar)
+    serialized_scalar = serialize_torch_tensor(scalar, CompressionType.FLOAT16)
+    assert torch.allclose(deserialize_torch_tensor(serialized_scalar), scalar)
 
 
 def test_serialize_tuple():
@@ -221,7 +222,7 @@ def test_serialize_tuple():
 
 def test_split_parts():
     tensor = torch.randn(910, 512)
-    serialized_tensor_part = hivemind.utils.serialize_torch_tensor(tensor, allow_inplace=False)
+    serialized_tensor_part = serialize_torch_tensor(tensor, allow_inplace=False)
     chunks1 = list(hivemind.utils.split_for_streaming(serialized_tensor_part, 16384))
     assert len(chunks1) == int(np.ceil(tensor.numel() * tensor.element_size() / 16384))
 
@@ -231,8 +232,7 @@ def test_split_parts():
     chunks3 = list(hivemind.utils.split_for_streaming(serialized_tensor_part, 10 ** 9))
     assert len(chunks3) == 1
 
-    compressed_tensor_part = hivemind.utils.serialize_torch_tensor(tensor, hivemind.CompressionType.FLOAT16,
-                                                                   allow_inplace=False)
+    compressed_tensor_part = serialize_torch_tensor(tensor, CompressionType.FLOAT16, allow_inplace=False)
     chunks4 = list(hivemind.utils.split_for_streaming(compressed_tensor_part, 16384))
     assert len(chunks4) == int(np.ceil(tensor.numel() * 2 / 16384))
 
@@ -241,16 +241,16 @@ def test_split_parts():
     combined3 = hivemind.utils.combine_from_streaming(chunks3)
     combined4 = hivemind.utils.combine_from_streaming(chunks4)
     for combined in combined1, combined2, combined3:
-        assert torch.allclose(tensor, hivemind.deserialize_torch_tensor(combined), rtol=1e-5, atol=1e-8)
+        assert torch.allclose(tensor, deserialize_torch_tensor(combined), rtol=1e-5, atol=1e-8)
 
-    assert torch.allclose(tensor, hivemind.deserialize_torch_tensor(combined4), rtol=1e-3, atol=1e-3)
+    assert torch.allclose(tensor, deserialize_torch_tensor(combined4), rtol=1e-3, atol=1e-3)
 
     combined_incomplete = hivemind.utils.combine_from_streaming(chunks4[:5])
     combined_incomplete2 = hivemind.utils.combine_from_streaming(chunks4[:1])
     combined_incomplete3 = hivemind.utils.combine_from_streaming(chunks4[:-1])
     for combined in combined_incomplete, combined_incomplete2, combined_incomplete3:
         with pytest.raises(RuntimeError):
-            hivemind.deserialize_torch_tensor(combined)
+            deserialize_torch_tensor(combined)
             # note: we rely on this being RuntimeError in hivemind.client.averager.allreduce.AllreduceProtocol