
Split compression strategies into separate classes (#366)

* implement compression strategies as subclasses
  * [x] NONE
  * [x] FLOAT16
  * [x] MEANSTD_16BIT
  * [x] QUANTILE_8BIT
  * [x] UNIFORM_8BIT
* [x] update load balancing in AllReduceRunner to account for compression ratio
* [x] pass tensor info into DecentralizedAverager, use defaults with indices
* [x] pass correct tensor info in TrainingAverager
* [x] pass CompressionInfos into TensorPartContainer
* [x] chunk -> part
* update all tests
  * [x] test_util_modules
  * [x] test_allreduce
  * [x] test_averaging
  * [x] test_training
  * [x] test_moe/test_expert_backend?
* implement adaptive compression strategies (usage sketch after this checklist)
  * [x] size-adaptive
  * [x] role-adaptive
  * [x] manual, for a list of tensors
* [x] separate compression strategy for load_state_from_peers
* [x] remove compression_type in ExpertBackend in favour of Compression
* [x] co-author @mponty and @Vsevolod-pl for attribution
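
A minimal usage sketch of the adaptive strategies and the separate state compression added here; apart from `compression` and `state_compression`, the constructor arguments (`dht`, `prefix`, `target_group_size`, `start`) follow the pre-existing `DecentralizedAverager` API and the values are illustrative:

```python
import torch
import hivemind
from hivemind.compression import Float16Compression, SizeAdaptiveCompression, Uniform8BitQuantization

# quantize large tensors to 8 bits, keep small ones in float16
averaging_compression = SizeAdaptiveCompression(
    threshold=2 ** 16, less=Float16Compression(), greater_equal=Uniform8BitQuantization()
)

averager = hivemind.DecentralizedAverager(
    averaged_tensors=[torch.zeros(1024), torch.zeros(2 ** 20)],
    dht=hivemind.DHT(start=True),
    prefix="demo_run",
    target_group_size=2,
    compression=averaging_compression,        # applied to tensor parts during all-reduce
    state_compression=Float16Compression(),   # applied when serving load_state_from_peers
    start=True,
)
```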

__Testing:__
* [x] extract compression tests into a separate file (a minimal round-trip sketch follows this list)
* [x] tests for adaptive strategies
* [x] test load_state_from_peers with compression
* [x] ensure that examples/albert works with mixed compression strategies
* [x] ensure that compression ratio is correctly accounted for in DecentralizedAverager
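
One possible shape for such a round-trip check, a sketch that assumes only the API added in this commit (not the contents of the actual test file):

```python
import pytest
import torch

from hivemind.compression import CompressionInfo, Float16Compression, Uniform8BitQuantization


@pytest.mark.parametrize("compression", [Float16Compression(), Uniform8BitQuantization()])
def test_compression_roundtrip(compression):
    tensor = torch.randn(128, 96)
    info = CompressionInfo.from_tensor(tensor, key="weight")
    restored = compression.extract(compression.compress(tensor, info))
    assert restored.shape == tensor.shape
    # loose bound that also holds for 8-bit quantization of a standard normal tensor
    assert torch.mean(torch.abs(restored - tensor)) < 0.05
```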

Co-authored-by: Alexander Borzunov <borzunov.alexander@gmail.com>
Co-authored-by: mponty <heapnhash@gmail.com>
Co-authored-by: Vsevolod-pl <vsevolod-pl@yandex.ru>
Co-authored-by: Michael Diskin <yhn112@users.noreply.github.com>
Co-authored-by: Alexander Borzunov <hxrussia@gmail.com>
justheuristic 4 years ago
parent
commit
1f54faf82f

+ 1 - 1
benchmarks/benchmark_tensor_compression.py

@@ -3,8 +3,8 @@ import time
 
 import torch
 
+from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.logging import get_logger
 
 logger = get_logger(__name__)

+ 0 - 3
examples/albert/arguments.py

@@ -93,9 +93,6 @@ class CollaborativeOptimizerArguments:
         default=100.0,
         metadata={"help": "Available network bandwidth, in mbps (used for load balancing in all-reduce)"},
     )
-    compression: str = field(
-        default="FLOAT16", metadata={"help": "Use this compression when averaging parameters/gradients"}
-    )
 
 
 @dataclass

+ 1 - 2
examples/albert/run_trainer.py

@@ -18,7 +18,6 @@ from transformers.trainer import Trainer
 from transformers.trainer_utils import is_main_process
 
 import hivemind
-from hivemind.utils.compression import CompressionType
 
 import utils
 from arguments import AlbertTrainingArguments, AveragerArguments, CollaborationArguments, DatasetArguments
@@ -262,7 +261,7 @@ def main():
         dht=dht,
         scheduler=scheduler,
         prefix=collaboration_args.experiment_prefix,
-        compression_type=CompressionType.Value(collaboration_args.compression),
+        compression=hivemind.Float16Compression(),
         batch_size_per_step=total_batch_size_per_step,
         bandwidth=collaboration_args.bandwidth,
         target_batch_size=adjusted_target_batch_size,

+ 1 - 2
examples/albert/run_training_monitor.py

@@ -13,7 +13,6 @@ from torch_optimizer import Lamb
 from transformers import AlbertConfig, AlbertForPreTraining, HfArgumentParser
 
 import hivemind
-from hivemind.utils.compression import CompressionType
 
 import utils
 from arguments import AveragerArguments, BaseTrainingArguments, CollaborativeOptimizerArguments
@@ -101,7 +100,7 @@ class CheckpointHandler:
             opt=opt,
             dht=dht,
             prefix=experiment_prefix,
-            compression_type=CompressionType.Value(collab_optimizer_args.compression),
+            compression_type=hivemind.Float16Compression(),
             bandwidth=collab_optimizer_args.bandwidth,
             target_batch_size=adjusted_target_batch_size,
             client_mode=collab_optimizer_args.client_mode,

+ 1 - 0
hivemind/__init__.py

@@ -1,4 +1,5 @@
 from hivemind.averaging import DecentralizedAverager, TrainingAverager
+from hivemind.compression import *
 from hivemind.dht import DHT
 from hivemind.moe import (
     ExpertBackend,

+ 1 - 1
hivemind/averaging/allreduce.py

@@ -5,11 +5,11 @@ from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Type
 import torch
 
 from hivemind.averaging.partition import AllreduceException, TensorPartContainer, TensorPartReducer
+from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase, StubBase
 from hivemind.proto import averaging_pb2
 from hivemind.utils import get_logger
 from hivemind.utils.asyncio import achain, aenumerate, afirst, amap_in_executor, anext, as_aiter
-from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 
 # flavour types
 GroupID = bytes

+ 35 - 15
hivemind/averaging/averager.py

@@ -9,7 +9,6 @@ import multiprocessing as mp
 import os
 import threading
 import weakref
-from concurrent.futures.thread import ThreadPoolExecutor
 from dataclasses import asdict
 from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Union
 
@@ -21,12 +20,18 @@ from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.load_balancing import load_balance_peers
 from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
 from hivemind.averaging.partition import DEFAULT_PART_SIZE_BYTES
+from hivemind.compression import (
+    CompressionBase,
+    CompressionInfo,
+    NoCompression,
+    deserialize_torch_tensor,
+    serialize_torch_tensor,
+)
 from hivemind.dht import DHT, DHTID
 from hivemind.p2p import P2PContext, P2PHandlerError, PeerID, ServicerBase
-from hivemind.proto import averaging_pb2, runtime_pb2
+from hivemind.proto import averaging_pb2
 from hivemind.utils import MPFuture, TensorDescriptor, get_logger
 from hivemind.utils.asyncio import achain, aiter_with_timeout, anext, as_aiter, switch_to_uvloop
-from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.grpc import combine_from_streaming, split_for_streaming
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration, get_dht_time
@@ -51,7 +56,9 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     :param initial_group_bits: a string of bits ('0' and '1') that define the initial group key (bucket index)
     :param averaging_expiration: attempt to find a group for this many seconds, otherwise try again
       note - this expiration time only applies to looking for group, passing tensors in allreduce may take more time
-    :param compression_type: optionally compress tensors with this compression algorithm before sending them to peers
+    :param compression: optionally compress tensors with this compression algorithm before running all-reduce
+    :param state_compression: a separate compression strategy for load_state_from_peers (default = no compression)
+    :param tensor_infos: CompressionInfo for each respective tensor; this determines how the tensor will be compressed
     :param allreduce_timeout: spend at most this many seconds for allreduce (after group is formed)
     :param averaging_alpha: optional "learning rate" for averaging. If specified, local parameters will be shifted
       towards the (estimated) average by this coefficient. By default, local parameters are set equal to average.
@@ -102,7 +109,9 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         averaging_alpha: float = 1.0,
         part_size_bytes: int = DEFAULT_PART_SIZE_BYTES,
         allreduce_timeout: Optional[float] = None,
-        compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE,
+        compression: CompressionBase = NoCompression(),
+        state_compression: CompressionBase = NoCompression(),
+        tensor_infos: Optional[Sequence[CompressionInfo]] = None,
         bandwidth: Optional[float] = None,
         min_vector_size: int = 0,
         auxiliary: bool = False,
@@ -158,7 +167,9 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             request_timeout=request_timeout,
         )
         self.allreduce_kwargs = dict(
-            compression_type=compression_type, part_size_bytes=part_size_bytes, min_vector_size=min_vector_size
+            compression=compression,
+            part_size_bytes=part_size_bytes,
+            min_vector_size=min_vector_size,
         )
         self._averaging_alpha, self._allreduce_timeout = averaging_alpha, allreduce_timeout
         self._running_groups: Dict[GroupID, AllReduceRunner] = {}  # one or more assembled groups that run all-reduce
@@ -169,6 +180,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         if allow_state_sharing is None:
             allow_state_sharing = not client_mode and not auxiliary
         self.allow_state_sharing = allow_state_sharing
+        self.state_compression = state_compression
+        self.tensor_infos = tensor_infos
 
         self._ready = MPFuture()
         # note: we create a background thread weakref and with daemon=True to ensure garbage collection
@@ -363,7 +376,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
                     future.set_result(
                         await asyncio.wait_for(
-                            self._run_allreduce(group_info, **self.allreduce_kwargs), self._allreduce_timeout
+                            self._run_allreduce(group_info, tensor_infos=self.tensor_infos, **self.allreduce_kwargs),
+                            timeout=self._allreduce_timeout,
                         )
                     )
                     # averaging is finished, loop will now exit
@@ -529,24 +543,27 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         """
         if not self.allow_state_sharing:
             return  # deny request and direct peer to the next prospective averager
-        metadata, tensors = await self._get_current_state_from_host_process()
+        metadata, tensors, infos = await self._get_current_state_from_host_process()
+        if infos is None:
+            infos = [CompressionInfo.from_tensor(tensor, key=i) for i, tensor in enumerate(tensors)]
+        assert len(tensors) == len(infos)
 
-        for tensor in tensors:
-            for part in split_for_streaming(serialize_torch_tensor(tensor)):
+        for tensor, info in zip(tensors, infos):
+            for part in split_for_streaming(self.state_compression.compress(tensor, info, allow_inplace=False)):
                 if metadata is not None:
                     yield averaging_pb2.DownloadData(tensor_part=part, metadata=metadata)
                     metadata = None
                 else:
                     yield averaging_pb2.DownloadData(tensor_part=part)
 
-    def get_current_state(self) -> Tuple[Any, Sequence[torch.Tensor]]:
+    def get_current_state(self) -> Tuple[Any, Sequence[torch.Tensor], Sequence[CompressionInfo]]:
         """
         Get current state and send it to a peer. Executed in the host process. Meant to be overridden.
         :returns: a tuple of (small metadata, sequence of torch tensors)
         :note: metadata must be serializable with self.serializer (default = MSGPackSerializer)
         """
         with self.get_tensors() as tensors:
-            return dict(group_key=self.get_group_bits()), tensors
+            return dict(group_key=self.get_group_bits()), tensors, self.tensor_infos
 
     async def _get_current_state_from_host_process(self):
         """Executed in the averager process inside rpc_download_state"""
@@ -618,7 +635,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
         finally:
             if not future.done():
-                logger.warning("Averager could not load state from peers: none of the requests succeeded.")
                 future.set_result(None)
 
     def get_group_bits(self, wait: bool = True):
@@ -681,7 +697,11 @@ def _background_thread_fetch_current_state(
             get_current_state = get_current_state_ref()
             if get_current_state is None:
                 break
-            state_metadata, state_tensors = get_current_state()
+            state = get_current_state()
+            assert 0 < len(state) <= 3
+            if len(state) != 3:
+                state = tuple(state + (None,) * (3 - len(state)))
+            state_metadata, state_tensors, tensor_infos = state
             del get_current_state
 
             state_metadata = serializer.dumps(state_metadata)
@@ -689,7 +709,7 @@ def _background_thread_fetch_current_state(
                 tensor.cpu().detach().requires_grad_(tensor.requires_grad) for tensor in state_tensors
             )
             # note: we cast tensors to CPU on host side to avoid initializing cuda in the guest process
-            future.set_result((state_metadata, state_tensors))
+            future.set_result((state_metadata, state_tensors, tensor_infos))
         except BaseException as e:
             future.set_exception(e)
             logger.warning(e)
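
`get_current_state` may now return a third element with per-tensor `CompressionInfo`; `_background_thread_fetch_current_state` pads shorter tuples with `None`, so existing two-element overrides keep working. A hedged sketch of a subclass that supplies infos (the PARAMETER role is illustrative):

```python
import hivemind
from hivemind.compression import CompressionInfo, TensorRole


class MyAverager(hivemind.DecentralizedAverager):
    def get_current_state(self):
        # metadata must remain serializable with self.serializer (MSGPackSerializer by default)
        with self.get_tensors() as tensors:
            infos = [
                CompressionInfo.from_tensor(tensor, key=index, role=TensorRole.PARAMETER)
                for index, tensor in enumerate(tensors)
            ]
            return dict(group_key=self.get_group_bits()), tensors, infos
```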

+ 22 - 16
hivemind/averaging/partition.py

@@ -3,14 +3,14 @@ Auxiliary data structures for AllReduceRunner
 """
 import asyncio
 from collections import deque
-from typing import AsyncIterable, AsyncIterator, Optional, Sequence, Tuple, TypeVar, Union
+from typing import AsyncIterable, AsyncIterator, Optional, Sequence, Tuple, TypeVar
 
 import numpy as np
 import torch
 
-from hivemind.proto.runtime_pb2 import CompressionType, Tensor
+from hivemind.compression import CompressionBase, CompressionInfo, NoCompression
+from hivemind.proto import runtime_pb2
 from hivemind.utils.asyncio import amap_in_executor
-from hivemind.utils.compression import get_nbytes_per_value, serialize_torch_tensor
 
 T = TypeVar("T")
 DEFAULT_PART_SIZE_BYTES = 2 ** 16
@@ -22,8 +22,9 @@ class TensorPartContainer:
     The class is designed to avoid excessive memory allocation and run all heavy computation in background
     :param tensors: local tensors to be split and aggregated
     :param peer_fractions: for each peer, a target fraction of vector elements that this peer should average
-    :param compression_type: optionally compress tensors with this compression algorithm before sending them to peers
+    :param compression: optionally compress tensors with this compression algorithm before sending them to peers
     :param part_size_bytes: greedily split tensors into parts of up to this many bytes (after compression)
+    :param tensor_infos: CompressionInfo for each respective tensor; this determines how the tensor will be compressed
     :param prefetch: when compressing, pre-compute this many compressed tensors in background
     """
 
@@ -31,16 +32,19 @@ class TensorPartContainer:
         self,
         tensors: Sequence[torch.Tensor],
         peer_fractions: Sequence[float],
-        compression_type: Union["CompressionType", Sequence["CompressionType"]] = CompressionType.NONE,
+        compression: CompressionBase = NoCompression(),
         part_size_bytes: int = DEFAULT_PART_SIZE_BYTES,
+        tensor_infos: Optional[Sequence[CompressionInfo]] = None,
         prefetch: int = 5,
     ):
-        if not isinstance(compression_type, Sequence):
-            compression_type = [compression_type] * len(tensors)
-        assert len(compression_type) == len(tensors), "compression types do not match the number of tensors"
+        if tensor_infos is None:
+            tensor_infos = tuple(CompressionInfo.from_tensor(x, key=i) for i, x in enumerate(tensors))
+        assert len(tensor_infos) == len(tensors), "compression types do not match the number of tensors"
         self.local_tensors, self.peer_fractions, self.group_size = tensors, peer_fractions, len(peer_fractions)
-        self.compression_type, self.part_size_bytes, self.prefetch = compression_type, part_size_bytes, prefetch
+        self.compression, self.part_size_bytes, self.tensor_infos = compression, part_size_bytes, tensor_infos
         self.total_size = sum(tensor.numel() for tensor in tensors)
+        self.prefetch = prefetch
+
         self._input_parts_by_peer = [deque() for _ in range(self.group_size)]
         self._output_parts_by_peer = [deque() for _ in range(self.group_size)]
         self._inputs_consumed_by_peer = [False for _ in range(self.group_size)]
@@ -56,11 +60,13 @@ class TensorPartContainer:
         pivots = (np.cumsum(peer_fractions) / np.sum(peer_fractions) * self.total_size).astype(np.int64)
         pivots[-1] = self.total_size
 
-        for tensor, tensor_compression in zip(self.local_tensors, compression_type):
-            part_size_values = int(part_size_bytes / get_nbytes_per_value(tensor.dtype, tensor_compression))
+        for tensor, info in zip(self.local_tensors, self.tensor_infos):
+            bytes_per_value = tensor.element_size() * compression.estimate_compression_ratio(info)
+            part_size_values = int(part_size_bytes / bytes_per_value)
             tensor_parts = tensor.detach().view(-1).split(part_size_values)
             self.num_parts_by_tensor.append(len(tensor_parts))
-            for part in tensor_parts:
+            for part_index, part in enumerate(tensor_parts):
+                part_info = info.get_part(part_index, part_size_values)
                 if current_length + len(part) > pivots[current_peer_index]:
                     # switch to next peer; if a part lands between parts of two or
                     # more peers, assign that part to the peer with highest intersection
@@ -71,9 +77,9 @@ class TensorPartContainer:
                         current_peer_part_end = min(current_length + len(part), pivots[current_peer_index])
                         peer_intersections.append(current_peer_part_end - pivots[current_peer_index - 1])
                     assigned_peer_index = prev_peer_index + np.argmax(peer_intersections)
-                    self._input_parts_by_peer[assigned_peer_index].append((part, tensor_compression))
+                    self._input_parts_by_peer[assigned_peer_index].append((part, part_info))
                 else:
-                    self._input_parts_by_peer[current_peer_index].append((part, tensor_compression))
+                    self._input_parts_by_peer[current_peer_index].append((part, part_info))
                 current_length += len(part)
 
         assert current_length == self.total_size
@@ -89,7 +95,7 @@ class TensorPartContainer:
         return input_parts
 
     @torch.no_grad()
-    async def iterate_input_parts_for(self, peer_index: int) -> AsyncIterator[Tensor]:
+    async def iterate_input_parts_for(self, peer_index: int) -> AsyncIterator[runtime_pb2.Tensor]:
         """iterate serialized tensor parts for a peer at a given index. Run serialization in background."""
         assert not self._inputs_consumed_by_peer[peer_index], "input parts of a given peer are already deallocated."
         self._inputs_consumed_by_peer[peer_index] = True
@@ -99,7 +105,7 @@ class TensorPartContainer:
                 yield self._input_parts_by_peer[peer_index].popleft()
 
         async for serialized_part in amap_in_executor(
-            lambda x_and_compr: serialize_torch_tensor(*x_and_compr), _aiterate_parts(), max_prefetch=self.prefetch
+            lambda x_and_info: self.compression.compress(*x_and_info), _aiterate_parts(), max_prefetch=self.prefetch
         ):
             yield serialized_part
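
The number of values per part is now derived from the compression's estimated ratio rather than from a hard-coded byte table. A worked sketch for a float32 tensor under FLOAT16 compression and the default part size:

```python
import torch

from hivemind.compression import CompressionInfo, Float16Compression

part_size_bytes = 2 ** 16  # DEFAULT_PART_SIZE_BYTES
tensor = torch.randn(1_000_000)  # float32, tensor.element_size() == 4 bytes
info = CompressionInfo.from_tensor(tensor, key=0)

bytes_per_value = tensor.element_size() * Float16Compression().estimate_compression_ratio(info)
part_size_values = int(part_size_bytes / bytes_per_value)  # 4 * 0.5 = 2 bytes per value -> 32768 values
```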
 

+ 47 - 12
hivemind/averaging/training.py

@@ -8,6 +8,7 @@ from typing import Dict, Iterator, Optional, Sequence
 import torch
 
 from hivemind.averaging import DecentralizedAverager
+from hivemind.compression import CompressionInfo, TensorRole
 from hivemind.utils import get_logger, nested_flatten, nested_pack
 
 logger = get_logger(__name__)
@@ -41,23 +42,28 @@ class TrainingAverager(DecentralizedAverager):
         average_gradients: bool,
         average_opt_statistics: Sequence[str] = (),
         extra_tensors: Sequence[torch.Tensor] = (),
+        parameter_names: Optional[Sequence[str]] = None,
         initialize_optimizer: bool = True,
         **kwargs
     ):
+        if initialize_optimizer:
+            initialize_optimizer_state(opt)  # note: this will run one optimizer step!
+        if parameter_names is None:
+            parameter_names = tuple(i for group in opt.param_groups for i in range(len(group["params"])))
 
         self.opt, self.extra_tensors, self.local_step = opt, tuple(extra_tensors), 0
         self.opt_statistics = tuple(average_opt_statistics)
         self.average_parameters, self.average_gradients = average_parameters, average_gradients
+        self.parameter_names = parameter_names
         self.step_executor = ThreadPoolExecutor(max_workers=1)
         self.lock_averager_step = Lock()
         self.pending_updates_done = Event()
         self.pending_updates_done.set()
-        if initialize_optimizer:
-            initialize_optimizer_state(opt)  # note: this will run one optimizer step!
 
         with torch.no_grad():
             averaged_tensors = [tensor.detach().cpu().float().clone() for tensor in self.local_tensors()]
-        super().__init__(averaged_tensors=averaged_tensors, **kwargs)
+
+        super().__init__(averaged_tensors=averaged_tensors, tensor_infos=list(self.tensor_infos()), **kwargs)
 
     def step(self, data_lock: Optional[Lock] = None, wait: bool = True, **kwargs):
         """
@@ -119,13 +125,8 @@ class TrainingAverager(DecentralizedAverager):
             self.local_step += 1
             return gathered
 
-    def local_tensors(self, replace_none: bool = True) -> Iterator[torch.Tensor]:
-        """
-        Iterate local trainer's tensors that should be averaged with peers
-
-        :param replace_none: if True and average_gradients is True, None grads will be replaced with a zero tensors
-          Otherwise, such gradients will be skipped. (this may cause inconsistencies with averaged_tensors)
-        """
+    def local_tensors(self) -> Iterator[torch.Tensor]:
+        """Iterate local trainer's tensors that should be averaged with peers"""
         if self.average_parameters:
             for param_group in self.opt.param_groups:
                 yield from param_group["params"]
@@ -134,7 +135,7 @@ class TrainingAverager(DecentralizedAverager):
                 for param in param_group["params"]:
                     if param.grad is not None:
                         yield param.grad
-                    elif replace_none:
+                    else:
                         yield torch.zeros_like(param)
         for stats in self.opt_statistics:
             for param_group in self.opt.param_groups:
@@ -142,6 +143,26 @@ class TrainingAverager(DecentralizedAverager):
                     yield self.opt.state[param][stats]
         yield from iter(self.extra_tensors)
 
+    def tensor_infos(self):
+        """Get CompressionInfo for each tensor, accounting for its role and specification"""
+        params = tuple(param for param_group in self.opt.param_groups for param in param_group["params"])
+        assert len(params) == len(self.parameter_names)
+        if self.average_parameters:
+            for param, key in zip(params, self.parameter_names):
+                yield CompressionInfo.from_tensor(param, key=key, role=TensorRole.PARAMETER)
+        if self.average_gradients:
+            for param, key in zip(params, self.parameter_names):
+                if param.grad is not None:
+                    grad = param.grad if param.grad is not None else torch.zeros_like(param)
+                    yield CompressionInfo.from_tensor(grad, key=key, role=TensorRole.GRADIENT)
+        for stats in self.opt_statistics:
+            for param, key in zip(params, self.parameter_names):
+                yield CompressionInfo.from_tensor(
+                    self.opt.state[param][stats], key=(key, stats), role=TensorRole.OPTIMIZER
+                )
+        for i, extra_tensor in enumerate(self.extra_tensors):
+            yield CompressionInfo.from_tensor(extra_tensor, key=i, role=TensorRole.UNSPECIFIED)
+
     def get_current_state(self):
         """
         Get current model/optimizer state when requested by a newbie peer. Executed in the host process.
@@ -151,11 +172,25 @@ class TrainingAverager(DecentralizedAverager):
             optimized_parameters = tuple(
                 param.detach().cpu() for param_group in self.opt.param_groups for param in param_group["params"]
             )
+            parameter_infos = [
+                CompressionInfo.from_tensor(param, key=key, role=TensorRole.PARAMETER)
+                for param, key in zip(optimized_parameters, self.parameter_names)
+            ]
             extra_tensors = tuple(tensor.detach().cpu() for tensor in self.extra_tensors)
+            extra_infos = [
+                CompressionInfo.from_tensor(extra_tensor, key=i, role=TensorRole.UNSPECIFIED)
+                for i, extra_tensor in enumerate(extra_tensors)
+            ]
             optimizer_metadata, optimizer_tensors = dump_optimizer_state(self.opt)
+            optimizer_infos = [
+                CompressionInfo.from_tensor(opt_tensor, key=i, role=TensorRole.OPTIMIZER)
+                for i, opt_tensor in enumerate(optimizer_tensors)
+            ]
 
         metadata = dict(step=self.local_step, group_bits=self.get_group_bits(), optimizer_metadata=optimizer_metadata)
-        return metadata, list(chain(optimized_parameters, extra_tensors, optimizer_tensors))
+        all_tensors = list(chain(optimized_parameters, extra_tensors, optimizer_tensors))
+        all_tensor_infos = list(chain(parameter_infos, extra_infos, optimizer_infos))
+        return metadata, all_tensors, all_tensor_infos
 
     def load_state_from_peers(self, **kwargs):
         """

+ 52 - 0
hivemind/compression/__init__.py

@@ -0,0 +1,52 @@
+"""
+Compression strategies that reduce the network communication in .averaging, .optim and .moe
+"""
+
+import warnings
+from typing import Dict, Optional
+
+import torch
+
+from hivemind.compression.adaptive import PerTensorCompression, RoleAdaptiveCompression, SizeAdaptiveCompression
+from hivemind.compression.base import CompressionBase, CompressionInfo, NoCompression, TensorRole
+from hivemind.compression.floating import Float16Compression, ScaledFloat16Compression
+from hivemind.compression.quantization import Quantile8BitQuantization, Uniform8BitQuantization
+from hivemind.proto import runtime_pb2
+
+warnings.filterwarnings("ignore", message="The given NumPy array is not writeable", category=UserWarning)
+
+
+BASE_COMPRESSION_TYPES: Dict[str, CompressionBase] = dict(
+    NONE=NoCompression(),
+    FLOAT16=Float16Compression(),
+    MEANSTD_16BIT=ScaledFloat16Compression(),
+    QUANTILE_8BIT=Quantile8BitQuantization(),
+    UNIFORM_8BIT=Uniform8BitQuantization(),
+)
+
+for key in runtime_pb2.CompressionType.keys():
+    assert key in BASE_COMPRESSION_TYPES, f"Compression type {key} does not have a registered deserializer."
+    actual_compression_type = BASE_COMPRESSION_TYPES[key].compression_type
+    assert (
+        runtime_pb2.CompressionType.Name(actual_compression_type) == key
+    ), f"Compression strategy for {key} has inconsistent type"
+
+
+def serialize_torch_tensor(
+    tensor: torch.Tensor,
+    compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE,
+    info: Optional[CompressionInfo] = None,
+    allow_inplace: bool = False,
+    **kwargs,
+) -> runtime_pb2.Tensor:
+    """Serialize a given tensor into a protobuf message using the specified compression strategy"""
+    assert tensor.device == torch.device("cpu")
+    compression = BASE_COMPRESSION_TYPES[runtime_pb2.CompressionType.Name(compression_type)]
+    info = info or CompressionInfo.from_tensor(tensor, **kwargs)
+    return compression.compress(tensor, info, allow_inplace)
+
+
+def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
+    """Restore a pytorch tensor from a protobuf message"""
+    compression = BASE_COMPRESSION_TYPES[runtime_pb2.CompressionType.Name(serialized_tensor.compression)]
+    return compression.extract(serialized_tensor).requires_grad_(serialized_tensor.requires_grad)
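
The relocated `serialize_torch_tensor` keeps the old enum-based entry point and dispatches through `BASE_COMPRESSION_TYPES`; a small round-trip sketch:

```python
import torch

from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.proto.runtime_pb2 import CompressionType

tensor = torch.randn(1024)
serialized = serialize_torch_tensor(tensor, CompressionType.MEANSTD_16BIT)  # dispatches to ScaledFloat16Compression
restored = deserialize_torch_tensor(serialized)
assert torch.allclose(restored, tensor, rtol=0, atol=1e-2)
```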

+ 67 - 0
hivemind/compression/adaptive.py

@@ -0,0 +1,67 @@
+from abc import ABC, abstractmethod
+from typing import Mapping, Sequence, Union
+
+import torch
+
+import hivemind
+from hivemind.compression.base import CompressionBase, CompressionInfo, Key, NoCompression, TensorRole
+from hivemind.proto import runtime_pb2
+
+
+class AdaptiveCompressionBase(CompressionBase, ABC):
+    @abstractmethod
+    def choose_compression(self, info: CompressionInfo) -> CompressionBase:
+        ...
+
+    def estimate_compression_ratio(self, info: CompressionInfo) -> float:
+        return self.choose_compression(info).estimate_compression_ratio(info)
+
+    def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
+        return self.choose_compression(info).compress(tensor, info=info, allow_inplace=allow_inplace)
+
+    def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
+        return hivemind.compression.deserialize_torch_tensor(serialized_tensor)
+
+
+class SizeAdaptiveCompression(AdaptiveCompressionBase):
+    """Apply compression strategy 1 if tensor has more than :threshold: elements and strategy 2 otherwise"""
+
+    def __init__(self, threshold: int, less: CompressionBase, greater_equal: CompressionBase):
+        self.threshold, self.less, self.greater_equal = threshold, less, greater_equal
+
+    def choose_compression(self, info: CompressionInfo) -> CompressionBase:
+        return self.greater_equal if info.descriptor.numel() >= self.threshold else self.less
+
+
+class RoleAdaptiveCompression(AdaptiveCompressionBase):
+    """Compress a tensor based on its role in training. Any non-specified compressions will use the "default" option"""
+
+    def __init__(
+        self,
+        *,
+        activation: CompressionBase = None,
+        parameter: CompressionBase = None,
+        gradient: CompressionBase = None,
+        optimizer: CompressionBase = None,
+        default: CompressionBase = NoCompression()
+    ):
+        self.role_compressions = {
+            TensorRole.ACTIVATION: activation or default,
+            TensorRole.PARAMETER: parameter or default,
+            TensorRole.GRADIENT: gradient or default,
+            TensorRole.OPTIMIZER: optimizer or default,
+            TensorRole.UNSPECIFIED: default,
+        }
+
+    def choose_compression(self, info: CompressionInfo) -> CompressionBase:
+        return self.role_compressions[info.role]
+
+
+class PerTensorCompression(AdaptiveCompressionBase):
+    """Manually specify the compression strategy depending on tensor key"""
+
+    def __init__(self, tensor_compressions: Union[Sequence[CompressionBase], Mapping[Key, CompressionBase]]):
+        self.tensor_compressions = tensor_compressions
+
+    def choose_compression(self, info: CompressionInfo) -> CompressionBase:
+        return self.tensor_compressions[info.key]
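
A brief sketch of the manual option; the keys must match `CompressionInfo.key` of the tensors being sent, and the names below are purely illustrative:

```python
from hivemind.compression import Float16Compression, NoCompression, PerTensorCompression, Uniform8BitQuantization

compression = PerTensorCompression(
    {
        "word_embeddings.weight": Uniform8BitQuantization(),
        "layer_norm.bias": NoCompression(),
        "attention.query.weight": Float16Compression(),
    }
)
```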

+ 89 - 0
hivemind/compression/base.py

@@ -0,0 +1,89 @@
+import dataclasses
+from abc import ABC, abstractmethod
+from enum import Enum, auto
+from typing import Any, Optional
+
+import numpy as np
+import torch
+
+from hivemind.proto import runtime_pb2
+from hivemind.utils.tensor_descr import TensorDescriptor
+
+Key = Any
+
+
+class TensorRole(Enum):
+    ACTIVATION = auto()
+    PARAMETER = auto()
+    GRADIENT = auto()
+    OPTIMIZER = auto()
+    UNSPECIFIED = auto()
+
+
+@dataclasses.dataclass(frozen=True)
+class CompressionInfo:
+    """Auxiliary data structure that contains information about the tensor that determines how it is compressed"""
+
+    key: Key  # name or index of the tensor from named parameters, optimizer state dict or i/o structure
+    descriptor: TensorDescriptor  # data structure that defines shape, dtype, layout and device information
+    role: TensorRole = TensorRole.UNSPECIFIED  # which role does the tensor play with respect to the model
+    part_index: int = 0  # if tensor is sliced into parts, this represents the index within one tensor
+    part_size: Optional[int] = None  # if tensor is sliced into parts, this is the _maximum_ number of values per part
+
+    @classmethod
+    def from_tensor(cls, tensor: torch.Tensor, key: Key = None, descriptor: TensorDescriptor = None, **kwargs):
+        return cls(key, descriptor or TensorDescriptor.from_tensor(tensor), **kwargs)
+
+    def get_part(self, part_index: int, part_size: Optional[int]):
+        return CompressionInfo(self.key, self.descriptor, self.role, part_index=part_index, part_size=part_size)
+
+
+class CompressionBase(ABC):
+    """A base class that applies compression algorithm to a pytorch tensor"""
+
+    compression_type: runtime_pb2.CompressionType
+
+    @abstractmethod
+    def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
+        """
+        Apply the compression algorithm to a tensor based on its meta-parameters
+
+        :param tensor: a pytorch tensor to compress; depending on the application, it is a full tensor or a part
+        :param info: meta-information about the tensor; if partitioning is used, this still describes the full tensor
+        :param allow_inplace: if True, compression can (but does not have to) modify the tensor in-place for efficiency
+        :returns: a protobuf message that encodes the tensor
+        """
+        ...
+
+    @abstractmethod
+    def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
+        """Create a pytorch tensor from the serialized outputs of .compress"""
+        ...
+
+    @abstractmethod
+    def estimate_compression_ratio(self, info: CompressionInfo) -> float:
+        """Estimate the compression ratio without doing the actual compression; lower ratio = better compression"""
+        ...
+
+
+class NoCompression(CompressionBase):
+    """A dummy compression strategy that preserves the original tensor as is."""
+
+    compression_type = runtime_pb2.CompressionType.NONE
+
+    def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
+        array = tensor.numpy()
+        return runtime_pb2.Tensor(
+            compression=self.compression_type,
+            buffer=array.tobytes(),
+            size=array.shape,
+            dtype=array.dtype.name,
+            requires_grad=tensor.requires_grad,
+        )
+
+    def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
+        array = np.frombuffer(serialized_tensor.buffer, dtype=np.dtype(serialized_tensor.dtype))
+        return torch.as_tensor(array).reshape(tuple(serialized_tensor.size))
+
+    def estimate_compression_ratio(self, info: CompressionInfo) -> float:
+        return 1.0

+ 92 - 0
hivemind/compression/floating.py

@@ -0,0 +1,92 @@
+import math
+
+import numpy as np
+import torch
+
+from hivemind.compression.base import CompressionBase, CompressionInfo
+from hivemind.proto import runtime_pb2
+
+
+class Float16Compression(CompressionBase):
+    compression_type = runtime_pb2.CompressionType.FLOAT16
+    FP16_MIN, FP16_MAX = torch.finfo(torch.float16).min, torch.finfo(torch.float16).max
+
+    def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
+        dtype_name = tensor.numpy().dtype.name
+        tensor = tensor.detach().cpu().float()
+        tensor = tensor if allow_inplace else tensor.clone()
+        tensor = tensor.clamp_(self.FP16_MIN, self.FP16_MAX).to(torch.float16)
+        return runtime_pb2.Tensor(
+            compression=self.compression_type,
+            buffer=tensor.numpy().tobytes(),
+            size=tensor.shape,
+            dtype=dtype_name,
+            requires_grad=tensor.requires_grad,
+        )
+
+    def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
+        original_dtype = np.dtype(serialized_tensor.dtype)
+        array = np.frombuffer(serialized_tensor.buffer, dtype=np.float16)
+        return torch.as_tensor(np.asarray(array, dtype=original_dtype)).reshape(tuple(serialized_tensor.size))
+
+    def estimate_compression_ratio(self, info: CompressionInfo) -> float:
+        return 16.0 / get_num_bits(info.descriptor.dtype)
+
+
+class ScaledFloat16Compression(Float16Compression):
+    """A compression strategy that applies mean-std scaling over last axis before casting to float16"""
+
+    compression_type = runtime_pb2.CompressionType.MEANSTD_16BIT
+    FP32_BYTES = torch.finfo(torch.float32).bits // 8
+    FP32_EPS = torch.finfo(torch.float32).eps
+
+    def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
+        dtype_name = tensor.numpy().dtype.name
+        tensor = tensor.detach().cpu().float()
+        tensor = tensor if allow_inplace else tensor.clone()
+        means = torch.mean(tensor, dim=-1, keepdim=True)
+        tensor.sub_(means)
+        stds = tensor.norm(dim=-1, keepdim=True) / math.sqrt(tensor.shape[-1])
+        stds.clamp_min_(self.FP32_EPS)
+        tensor.div_(stds)
+        tensor = tensor.clamp_(self.FP16_MIN, self.FP16_MAX).to(torch.float16)
+
+        data = b"".join((tensor.numpy().tobytes(), means.float().numpy().tobytes(), stds.float().numpy().tobytes()))
+
+        return runtime_pb2.Tensor(
+            compression=self.compression_type,
+            buffer=data,
+            size=tensor.shape,
+            dtype=dtype_name,
+            requires_grad=tensor.requires_grad,
+        )
+
+    def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
+        stats_shape = list(serialized_tensor.size)
+        stats_shape[-1] = 1
+        stats_count = np.prod(stats_shape)
+        means_offset = len(serialized_tensor.buffer) - 2 * stats_count * self.FP32_BYTES
+        stds_offset = len(serialized_tensor.buffer) - stats_count * self.FP32_BYTES
+
+        array = np.frombuffer(serialized_tensor.buffer, dtype=np.float16, count=np.prod(serialized_tensor.size))
+        means = np.frombuffer(serialized_tensor.buffer, dtype=np.float32, offset=means_offset, count=stats_count)
+        stds = np.frombuffer(serialized_tensor.buffer, dtype=np.float32, offset=stds_offset, count=stats_count)
+
+        means = torch.as_tensor(means).reshape(stats_shape)
+        stds = torch.as_tensor(stds).reshape(stats_shape)
+        tensor = torch.as_tensor(np.asarray(array, dtype=serialized_tensor.dtype)).reshape(
+            list(serialized_tensor.size)
+        )
+        return tensor.mul_(stds).add_(means)
+
+
+def get_num_bits(dtype: torch.dtype) -> int:
+    if dtype == torch.bool:
+        return 8  # see https://github.com/pytorch/pytorch/issues/41571
+    elif dtype.is_floating_point:
+        return torch.finfo(dtype).bits
+    else:
+        try:
+            return torch.iinfo(dtype).bits
+        except TypeError:
+            raise TypeError(f"Could not infer size for tensor type {dtype}")
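
The per-row mean/std scaling is what keeps MEANSTD_16BIT usable for tensors whose values do not fit into float16; a quick round-trip sketch:

```python
import torch

from hivemind.compression import CompressionInfo, ScaledFloat16Compression

tensor = 1e6 * torch.randn(32, 64)  # plain FLOAT16 would clamp these values at ~65504
info = CompressionInfo.from_tensor(tensor, key="scores")
compression = ScaledFloat16Compression()

restored = compression.extract(compression.compress(tensor, info))
relative_error = (restored - tensor).norm() / tensor.norm()
assert relative_error < 1e-3  # per-row means and stds are transmitted in float32
```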

+ 114 - 0
hivemind/compression/quantization.py

@@ -0,0 +1,114 @@
+import math
+import os
+from abc import ABC, abstractmethod
+from concurrent.futures import ThreadPoolExecutor
+from typing import Tuple
+
+import numpy as np
+import torch
+
+from hivemind.compression.base import CompressionBase, CompressionInfo
+from hivemind.proto import runtime_pb2
+
+EXECUTOR = ThreadPoolExecutor(max_workers=int(os.environ.get("QUANTIZATION_THREADS", 128)))
+
+
+class Quantization(CompressionBase, ABC):
+    codebook_dtype, indices_dtype = np.float32, np.uint8
+
+    @abstractmethod
+    def quantize(self, tensor: torch.Tensor, allow_inplace: bool = False) -> Tuple[np.ndarray, np.ndarray]:
+        """Convert tensor into a pair of (indices, codebook)"""
+        ...
+
+    def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
+        quantized, codebook = self.quantize(tensor.detach(), allow_inplace=allow_inplace)
+        return runtime_pb2.Tensor(
+            compression=self.compression_type,
+            buffer=b"".join((np.int64(len(codebook)).tobytes(), codebook.tobytes(), quantized.tobytes())),
+            size=tensor.shape,
+            dtype=tensor.numpy().dtype.name,
+            requires_grad=tensor.requires_grad,
+        )
+
+    def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
+        codebook_size = int(np.frombuffer(serialized_tensor.buffer, count=1, dtype=np.int64))
+        codebook = np.frombuffer(serialized_tensor.buffer, offset=8, count=codebook_size, dtype=self.codebook_dtype)
+        quantized = np.frombuffer(serialized_tensor.buffer, offset=8 + codebook.nbytes, dtype=self.indices_dtype)
+        quantized = torch.as_tensor(quantized, dtype=torch.int64).reshape(tuple(serialized_tensor.size))
+        codebook = torch.as_tensor(np.asarray(codebook, dtype=serialized_tensor.dtype))
+        return codebook[quantized]
+
+    def estimate_compression_ratio(self, info: CompressionInfo) -> float:
+        return self.n_bits / torch.finfo(info.descriptor.dtype).bits
+
+    @property
+    def n_bits(self):
+        return self.indices_dtype(1).itemsize * 8
+
+    @property
+    def n_bins(self):
+        return 2 ** self.n_bits
+
+
+class Uniform8BitQuantization(Quantization):
+    RANGE_IN_SIGMAS: int = 6
+    compression_type = runtime_pb2.UNIFORM_8BIT
+
+    def quantize(self, tensor: torch.Tensor, allow_inplace: bool = False) -> Tuple[np.ndarray, np.ndarray]:
+        offset = self.n_bins // 2
+        shift = tensor.mean()
+        centered_tensor = tensor.sub_(shift) if allow_inplace else tensor - shift
+        std_unbiased = centered_tensor.norm() / math.sqrt(centered_tensor.numel() - 1)
+        scale = self.RANGE_IN_SIGMAS * std_unbiased / self.n_bins
+        quantized = torch.quantize_per_tensor(centered_tensor, scale, offset, torch.quint8).int_repr()
+        lookup = average_buckets(tensor, quantized, self.n_bins)
+        return np.asarray(quantized, dtype=self.indices_dtype), np.asarray(lookup, dtype=self.codebook_dtype)
+
+
+class Quantile8BitQuantization(Quantization):
+    compression_type = runtime_pb2.QUANTILE_8BIT
+
+    def quantize(self, tensor: torch.Tensor, allow_inplace: bool = False) -> Tuple[np.ndarray, np.ndarray]:
+        tensor = tensor.detach().float()
+        borders = torch.as_tensor(quantile_qq_approximation(tensor.numpy(), self.n_bins + 1)[1:-1])
+        quantized = torch.clamp_(torch.bucketize(tensor, borders), 0, self.n_bins - 1)
+        codebook = average_buckets(tensor, quantized, self.n_bins)
+        return quantized.numpy().astype(np.uint8), codebook.numpy()
+
+
+def average_buckets(tensor: torch.Tensor, quant_weight: torch.Tensor, n_bins: int):
+    """Return the average value in each bucket"""
+    bin_sums = torch.zeros(n_bins).scatter_add_(0, quant_weight.flatten().long(), tensor.flatten())
+    bin_counts = torch.clamp_min_(torch.bincount(quant_weight.flatten(), minlength=n_bins), 1)
+    lookup = bin_sums / bin_counts
+    return lookup
+
+
+def get_chunk_size(num_elements: int, min_chunk_size: int) -> int:
+    """Adjust chunk_size to minimize imbalance between chunk sizes"""
+    if min_chunk_size >= num_elements:
+        return min_chunk_size
+    leftover_elements = num_elements % min_chunk_size
+    num_chunks = num_elements // min_chunk_size
+    return min_chunk_size + (leftover_elements - 1) // num_chunks + 1
+
+
+def quantile_qq_approximation(array: np.ndarray, n_quantiles: int, min_chunk_size: int = 10 ** 5) -> np.ndarray:
+    """Estimate uniform quantiles of data using quantile-of-quantiles. Runs in parallel."""
+    if not array.data.c_contiguous and array.data.f_contiguous:
+        array = array.T
+    array = np.ascontiguousarray(array.reshape(-1))
+    quantiles = np.linspace(0.0, 1.0, num=n_quantiles, dtype=array.dtype)
+    chunk_size = get_chunk_size(len(array), min_chunk_size)
+    num_chunks = (len(array) - 1) // chunk_size + 1
+    partition_quantiles = np.empty((num_chunks, len(quantiles)), dtype=array.dtype)
+
+    jobs = []
+    for i in range(num_chunks):
+        chunk = slice(chunk_size * i, chunk_size * (i + 1))
+        jobs.append(EXECUTOR.submit(np.quantile, array[chunk], quantiles, out=partition_quantiles[i]))
+
+    for job in jobs:
+        job.result()
+    return np.quantile(partition_quantiles, quantiles)
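
Both 8-bit schemes report the same estimated ratio, which is what the updated partitioning and load balancing consume; a small sketch:

```python
import torch

from hivemind.compression import CompressionInfo, Quantile8BitQuantization, Uniform8BitQuantization

info = CompressionInfo.from_tensor(torch.randn(4096), key="grad")
for quantization in (Uniform8BitQuantization(), Quantile8BitQuantization()):
    print(type(quantization).__name__, quantization.estimate_compression_ratio(info))  # 8 / 32 = 0.25
```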

+ 0 - 1
hivemind/dht/__init__.py

@@ -17,7 +17,6 @@ from __future__ import annotations
 import asyncio
 import multiprocessing as mp
 import os
-from concurrent.futures import ThreadPoolExecutor
 from functools import partial
 from typing import Awaitable, Callable, Iterable, List, Optional, Sequence, TypeVar, Union
 

+ 1 - 1
hivemind/moe/client/expert.py

@@ -5,9 +5,9 @@ import torch
 import torch.nn as nn
 from torch.autograd.function import once_differentiable
 
+from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
 from hivemind.utils import Endpoint, nested_compare, nested_flatten, nested_pack
-from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.grpc import ChannelCache
 
 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert

+ 1 - 1
hivemind/moe/client/moe.py

@@ -10,12 +10,12 @@ import torch.nn as nn
 from torch.autograd.function import once_differentiable
 
 import hivemind
+from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.moe.client.beam_search import MoEBeamSearcher
 from hivemind.moe.client.expert import DUMMY, RemoteExpert, _get_expert_stub
 from hivemind.moe.server.expert_uid import UID_DELIMITER
 from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
 from hivemind.utils import nested_flatten, nested_map, nested_pack
-from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.logging import get_logger
 
 logger = get_logger(__name__)

+ 1 - 1
hivemind/moe/server/connection_handler.py

@@ -6,11 +6,11 @@ from typing import Dict
 import grpc
 import torch
 
+from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.moe.server.expert_backend import ExpertBackend
 from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
 from hivemind.utils import Endpoint, get_logger, nested_flatten
 from hivemind.utils.asyncio import switch_to_uvloop
-from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.grpc import GRPC_KEEPALIVE_OPTIONS
 
 logger = get_logger(__name__)

+ 0 - 1
hivemind/utils/__init__.py

@@ -1,5 +1,4 @@
 from hivemind.utils.asyncio import *
-from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.grpc import *
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger

+ 0 - 209
hivemind/utils/compression.py

@@ -1,209 +0,0 @@
-import os
-import warnings
-from concurrent.futures import ThreadPoolExecutor
-from typing import Optional, Sequence, Tuple
-
-import numpy as np
-import torch
-
-from hivemind.proto import runtime_pb2
-from hivemind.proto.runtime_pb2 import CompressionType
-
-FP32_EPS = 1e-06
-NUM_BYTES_FLOAT32 = 4
-NUM_BYTES_FLOAT16 = 2
-NUM_BITS_QUANTILE_COMPRESSION = 8
-NUM_COMPRESSION_QUANTILES = 2 ** NUM_BITS_QUANTILE_COMPRESSION
-UNIFORM_BUCKETS_STD_RANGE = 6
-FP16_MAX = 65_504
-UINT8_RANGE = 256
-
-COMPRESSION_EXECUTOR = ThreadPoolExecutor(max_workers=int(os.environ.get("QUANTILE_COMPRESSION_THREADS", 128)))
-
-warnings.filterwarnings("ignore", message="The given NumPy array is not writeable", category=UserWarning)
-
-
-def _quantile_encode_approx(tensor: torch.Tensor, n_bits: int) -> Tuple[torch.Tensor, torch.Tensor]:
-    n_bins = 2 ** n_bits
-    borders = torch.as_tensor(_quantile_qq_approximation(tensor.numpy(), n_bins + 1)[1:-1])
-    quant_weight = torch.clamp_(torch.bucketize(tensor, borders), 0, n_bins - 1)
-    lookup = average_buckets(tensor, quant_weight, n_bins)
-    return quant_weight, lookup
-
-
-def average_buckets(tensor: torch.Tensor, quant_weight: torch.Tensor, n_bins: int):
-    bin_sums = torch.zeros(n_bins).scatter_add_(0, quant_weight.flatten().long(), tensor.flatten())
-    bin_counts = torch.clamp_min_(torch.bincount(quant_weight.flatten(), minlength=n_bins), 1)
-    lookup = bin_sums / bin_counts
-    return lookup
-
-
-def _quantile_qq_approximation(array: np.array, n_quantiles: int, min_chunk_size: int = 10 ** 5) -> np.ndarray:
-    """Estimate uniform quantiles of data using quantile-of-quantiles. Runs in parallel."""
-    if not array.data.c_contiguous and array.data.f_contiguous:
-        array = array.T
-    array = np.ascontiguousarray(array.reshape(-1))
-    quantiles = np.linspace(0.0, 1.0, num=n_quantiles, dtype=array.dtype)
-    chunk_size = _get_chunk_size(len(array), min_chunk_size)
-    num_chunks = (len(array) - 1) // chunk_size + 1
-    partition_quantiles = np.empty((num_chunks, len(quantiles)), dtype=array.dtype)
-
-    jobs = []
-    for i in range(num_chunks):
-        chunk = slice(chunk_size * i, chunk_size * (i + 1))
-        jobs.append(COMPRESSION_EXECUTOR.submit(np.quantile, array[chunk], quantiles, out=partition_quantiles[i]))
-
-    for job in jobs:
-        job.result()
-    return np.quantile(partition_quantiles, quantiles)
-
-
-def _get_chunk_size(num_elements: int, min_chunk_size: int) -> int:
-    """Adjust chunk_size to minimize imbalance between chunk sizes"""
-    if min_chunk_size >= num_elements:
-        return min_chunk_size
-    leftover_elements = num_elements % min_chunk_size
-    num_chunks = num_elements // min_chunk_size
-    return min_chunk_size + (leftover_elements - 1) // num_chunks + 1
-
-
-def _uint8_uniform_buckets_encode(tensor: torch.Tensor, range_in_sigmas: float):
-    offset = UINT8_RANGE // 2
-    shift = tensor.mean()
-    scale = range_in_sigmas * tensor.std() / UINT8_RANGE
-
-    quant_weight = torch.quantize_per_tensor(tensor - shift, scale, offset, torch.quint8).int_repr()
-    lookup = average_buckets(tensor, quant_weight, UINT8_RANGE)
-    return quant_weight, lookup
-
-
-def serialize_torch_tensor(
-    tensor: torch.Tensor, compression_type=CompressionType.NONE, allow_inplace=False
-) -> runtime_pb2.Tensor:
-    assert tensor.device == torch.device("cpu")
-    if compression_type == CompressionType.MEANSTD_16BIT:
-        assert tensor.dtype == torch.float32
-
-        tensor = tensor if allow_inplace else tensor.clone()
-        means = torch.mean(tensor, dim=-1, keepdim=True)
-        tensor.sub_(means)
-
-        stds = torch.square(tensor).sum(dim=-1, keepdim=True).div_(tensor.shape[-1]).sqrt_()
-        stds.clamp_min_(FP32_EPS)
-        tensor.div_(stds)
-        tensor = tensor.clamp_(-FP16_MAX, FP16_MAX).to(torch.float16)
-
-        data = b"".join((tensor.numpy().tobytes(), means.numpy().tobytes(), stds.numpy().tobytes()))
-
-        proto = runtime_pb2.Tensor(
-            compression=compression_type,
-            buffer=data,
-            size=tensor.shape,
-            dtype="compressed_float32",
-            requires_grad=tensor.requires_grad,
-        )
-    elif compression_type == CompressionType.FLOAT16:
-        assert tensor.dtype == torch.float32
-
-        tensor = tensor if allow_inplace else tensor.clone()
-        tensor = tensor.clamp_(-FP16_MAX, FP16_MAX).to(torch.float16)
-
-        data = tensor.numpy().tobytes()
-
-        proto = runtime_pb2.Tensor(
-            compression=compression_type,
-            buffer=data,
-            size=tensor.shape,
-            dtype="clamped_float32",
-            requires_grad=tensor.requires_grad,
-        )
-    elif compression_type == CompressionType.NONE:
-        array = tensor.numpy()
-        proto = runtime_pb2.Tensor(
-            compression=compression_type,
-            buffer=array.tobytes(),
-            size=array.shape,
-            dtype=array.dtype.name,
-            requires_grad=tensor.requires_grad,
-        )
-    elif compression_type in (CompressionType.QUANTILE_8BIT, CompressionType.UNIFORM_8BIT):
-        assert tensor.dtype == torch.float32
-
-        if compression_type == CompressionType.QUANTILE_8BIT:
-            quantized, lookup = _quantile_encode_approx(tensor.detach(), NUM_BITS_QUANTILE_COMPRESSION)
-        elif compression_type == CompressionType.UNIFORM_8BIT:
-            quantized, lookup = _uint8_uniform_buckets_encode(tensor.detach(), UNIFORM_BUCKETS_STD_RANGE)
-        data = b"".join((lookup.numpy().tobytes(), quantized.numpy().astype(np.uint8).tobytes()))
-
-        proto = runtime_pb2.Tensor(
-            compression=compression_type,
-            buffer=data,
-            size=tensor.shape,
-            dtype="compressed_float32",
-            requires_grad=tensor.requires_grad,
-        )
-    else:
-        raise ValueError(f"Unknown compression type: {compression_type}")
-
-    return proto
-
-
-def construct_torch_tensor(array: np.ndarray, size: Sequence, dtype: Optional[torch.dtype] = None):
-    """Helper conversion function that handles edge case with scalar deserialization"""
-    if size:
-        return torch.as_tensor(array, dtype=dtype).view(*size)
-    else:
-        return torch.as_tensor(array, dtype=dtype)
-
-
-def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
-    if serialized_tensor.compression == CompressionType.NONE:
-        array = np.frombuffer(serialized_tensor.buffer, dtype=np.dtype(serialized_tensor.dtype))
-        tensor = construct_torch_tensor(array, serialized_tensor.size)
-
-    elif serialized_tensor.compression == CompressionType.MEANSTD_16BIT:
-        stats_size = list(serialized_tensor.size)
-        stats_size[-1] = 1
-        stats_count = np.prod(stats_size)
-
-        means = serialized_tensor.buffer[-2 * NUM_BYTES_FLOAT32 * stats_count : -NUM_BYTES_FLOAT32 * stats_count]
-        stds = serialized_tensor.buffer[-NUM_BYTES_FLOAT32 * stats_count :]
-        means = construct_torch_tensor(np.frombuffer(means, dtype=np.float32), stats_size)
-        stds = construct_torch_tensor(np.frombuffer(stds, dtype=np.float32), stats_size)
-
-        array = np.frombuffer(serialized_tensor.buffer[: -8 * stats_count], dtype=np.float16)
-        tensor = construct_torch_tensor(array, serialized_tensor.size, torch.float32).mul_(stds).add_(means)
-
-    elif serialized_tensor.compression == CompressionType.FLOAT16:
-        array = np.frombuffer(serialized_tensor.buffer, dtype=np.float16)
-        tensor = construct_torch_tensor(array, serialized_tensor.size, torch.float32)
-
-    elif serialized_tensor.compression in (CompressionType.QUANTILE_8BIT, CompressionType.UNIFORM_8BIT):
-        if serialized_tensor.compression == CompressionType.QUANTILE_8BIT:
-            lookup_size = NUM_COMPRESSION_QUANTILES * NUM_BYTES_FLOAT32
-        else:
-            lookup_size = UINT8_RANGE * NUM_BYTES_FLOAT32
-        lookup = serialized_tensor.buffer[:lookup_size]
-        quantized = serialized_tensor.buffer[lookup_size:]
-        lookup = torch.as_tensor(np.frombuffer(lookup, dtype=np.float32))
-        quantized = np.frombuffer(quantized, dtype=np.uint8)
-        quantized = construct_torch_tensor(quantized, serialized_tensor.size, dtype=torch.int64)
-        tensor = lookup[quantized]
-
-    else:
-        raise ValueError(f"Unknown compression type: {serialized_tensor.compression}")
-
-    tensor.requires_grad_(serialized_tensor.requires_grad)
-    return tensor
-
-
-def get_nbytes_per_value(dtype: torch.dtype, compression: CompressionType) -> int:
-    """returns the number of bytes per value for a given tensor (excluding metadata)"""
-    if compression in (CompressionType.QUANTILE_8BIT, CompressionType.UNIFORM_8BIT):
-        return 1
-    elif compression in (CompressionType.FLOAT16, CompressionType.MEANSTD_16BIT):
-        return 2
-    elif compression == CompressionType.NONE:
-        return torch.finfo(dtype).bits // 8
-    else:
-        raise NotImplementedError(f"Unknown compression type: {CompressionType.Name(compression)}")

+ 11 - 4
hivemind/utils/tensor_descr.py

@@ -1,6 +1,10 @@
+from __future__ import annotations
+
 import warnings
 from dataclasses import asdict, dataclass
+from typing import Tuple
 
+import numpy as np
 import torch
 
 from hivemind.proto.runtime_pb2 import CompressionType
@@ -29,11 +33,14 @@ class TensorDescriptor(DescriptorBase):
     compression: CompressionType = CompressionType.NONE
 
     @property
-    def shape(self):
+    def shape(self) -> Tuple[int, ...]:
         return self.size
 
+    def numel(self) -> int:
+        return int(np.prod(self.size))
+
     @classmethod
-    def from_tensor(cls, tensor: torch.Tensor):
+    def from_tensor(cls, tensor: torch.Tensor) -> TensorDescriptor:
         return cls(
             tensor.shape, tensor.dtype, tensor.layout, tensor.device, tensor.requires_grad, _safe_check_pinned(tensor)
         )
@@ -55,7 +62,7 @@ class BatchTensorDescriptor(TensorDescriptor):
         super().__init__((None, *instance_size), **kwargs)
 
     @classmethod
-    def from_tensor(cls, tensor: torch.Tensor, compression=CompressionType.NONE):
+    def from_tensor(cls, tensor: torch.Tensor, compression=CompressionType.NONE) -> BatchTensorDescriptor:
         return cls(
             *tensor.shape[1:],
             dtype=tensor.dtype,
@@ -66,7 +73,7 @@ class BatchTensorDescriptor(TensorDescriptor):
             compression=compression if tensor.is_floating_point() else CompressionType.NONE
         )
 
-    def make_empty(self, *batch_size, **kwargs):
+    def make_empty(self, *batch_size: int, **kwargs) -> torch.Tensor:
         assert self.shape[0] is None, "Make sure 0-th dimension is not specified (set to None)"
         return super().make_empty(size=(*batch_size, *self.shape[1:]), **kwargs)
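
For reference, a small sketch of how the annotated `TensorDescriptor` API behaves after this change; it only relies on names visible in the hunk above, so treat it as an illustration rather than part of the commit:

```python
import torch
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.tensor_descr import BatchTensorDescriptor, TensorDescriptor

descr = TensorDescriptor.from_tensor(torch.zeros(4, 8))
assert descr.shape == (4, 8)  # `shape` is an alias for `size`
assert descr.numel() == 32    # new helper: product of all dimensions

# batch descriptors leave the 0-th (batch) dimension unspecified
batch_descr = BatchTensorDescriptor.from_tensor(torch.zeros(16, 4, 8))
assert batch_descr.shape == (None, 4, 8)

# compression is only retained for floating-point tensors
int_descr = BatchTensorDescriptor.from_tensor(
    torch.zeros(16, 4, dtype=torch.int64), compression=CompressionType.FLOAT16
)
assert int_descr.compression == CompressionType.NONE
```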
 

+ 3 - 3
tests/test_allreduce.py

@@ -6,12 +6,12 @@ from typing import Sequence
 import pytest
 import torch
 
-from hivemind import aenumerate
+from hivemind import Quantile8BitQuantization, aenumerate
 from hivemind.averaging.allreduce import AllReduceRunner, AveragingMode
 from hivemind.averaging.partition import TensorPartContainer, TensorPartReducer
+from hivemind.compression import deserialize_torch_tensor
 from hivemind.p2p import P2P, StubBase
 from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.utils import deserialize_torch_tensor
 
 
 @pytest.mark.forked
@@ -83,7 +83,7 @@ async def test_partitioning_asynchronous():
     tensors = [torch.randn(2048, 2048), torch.randn(1024, 4096), torch.randn(4096, 1024), torch.randn(30_000, 1024)]
     peer_fractions = [0.4, 0.3, 0.2, 0.1]
 
-    partition = TensorPartContainer(tensors, peer_fractions, compression_type=CompressionType.QUANTILE_8BIT)
+    partition = TensorPartContainer(tensors, peer_fractions, compression=Quantile8BitQuantization())
     read_started, read_finished = asyncio.Event(), asyncio.Event()
 
     async def write_tensors():

+ 0 - 57
tests/test_averaging.py

@@ -1,5 +1,4 @@
 import random
-import time
 
 import numpy as np
 import pytest
@@ -12,7 +11,6 @@ from hivemind.averaging.key_manager import GroupKeyManager
 from hivemind.averaging.load_balancing import load_balance_peers
 from hivemind.averaging.partition import AllreduceException
 from hivemind.p2p import PeerID
-from hivemind.proto.runtime_pb2 import CompressionType
 
 from test_utils.dht_swarms import launch_dht_instances
 
@@ -169,61 +167,6 @@ def test_allreduce_weighted(n_client_mode_peers: int = 2):
         process.shutdown()
 
 
-@pytest.mark.forked
-def test_allreduce_compression():
-    """this test ensures that compression works correctly when multiple tensors have different compression types"""
-
-    tensors1 = [torch.linspace(0, 500, 1000) ** 0.5, torch.randn(1000)]
-    tensors2 = [torch.linspace(300, 800, 1000) ** 0.5, torch.randn(1000)]
-    results = {}
-
-    FLOAT16, UINT8 = CompressionType.FLOAT16, CompressionType.UNIFORM_8BIT
-
-    for compression_type_pair in [(FLOAT16, FLOAT16), (FLOAT16, UINT8), (UINT8, FLOAT16), (UINT8, UINT8)]:
-        dht_instances = launch_dht_instances(2)
-        averager1 = hivemind.averaging.DecentralizedAverager(
-            [x.clone() for x in tensors1],
-            dht=dht_instances[0],
-            compression_type=compression_type_pair,
-            client_mode=True,
-            target_group_size=2,
-            prefix="mygroup",
-            start=True,
-        )
-        averager2 = hivemind.averaging.DecentralizedAverager(
-            [x.clone() for x in tensors2],
-            dht=dht_instances[1],
-            compression_type=compression_type_pair,
-            target_group_size=2,
-            prefix="mygroup",
-            start=True,
-        )
-
-        for future in averager1.step(wait=False), averager2.step(wait=False):
-            future.result()
-
-        with averager1.get_tensors() as averaged_tensors:
-            results[compression_type_pair] = averaged_tensors
-
-        for instance in [averager1, averager2] + dht_instances:
-            instance.shutdown()
-
-    assert torch.allclose(results[UINT8, FLOAT16][0], results[UINT8, UINT8][0])
-    assert torch.allclose(results[UINT8, FLOAT16][1], results[FLOAT16, FLOAT16][1])
-    assert torch.allclose(results[UINT8, UINT8][1], results[FLOAT16, UINT8][1])
-    assert torch.allclose(results[FLOAT16, UINT8][0], results[FLOAT16, FLOAT16][0])
-
-    assert not torch.allclose(results[UINT8, FLOAT16][1], results[UINT8, UINT8][1])
-    assert not torch.allclose(results[UINT8, FLOAT16][0], results[FLOAT16, FLOAT16][0])
-    assert not torch.allclose(results[UINT8, UINT8][0], results[FLOAT16, UINT8][0])
-    assert not torch.allclose(results[FLOAT16, UINT8][1], results[FLOAT16, FLOAT16][1])
-
-    reference = [(tensors1[i] + tensors2[i]) / 2 for i in range(len(tensors1))]
-    for i in range(2):
-        assert 0 < torch.mean(torch.square(results[FLOAT16, FLOAT16][i] - reference[i])).item() <= 1e-5
-        assert 1e-5 < torch.mean(torch.square(results[UINT8, UINT8][i] - reference[i])).item() <= 1e-2
-
-
 def compute_mean_std(averagers, unbiased=True):
     results = []
     for averager in averagers:

+ 213 - 0
tests/test_compression.py

@@ -0,0 +1,213 @@
+import multiprocessing as mp
+from ctypes import c_int32
+
+import pytest
+import torch
+import torch.nn as nn
+
+import hivemind
+from hivemind.compression import (
+    CompressionBase,
+    CompressionInfo,
+    Float16Compression,
+    NoCompression,
+    PerTensorCompression,
+    RoleAdaptiveCompression,
+    SizeAdaptiveCompression,
+    Uniform8BitQuantization,
+    deserialize_torch_tensor,
+    serialize_torch_tensor,
+)
+from hivemind.compression.adaptive import AdaptiveCompressionBase
+from hivemind.proto.runtime_pb2 import CompressionType
+
+from test_utils.dht_swarms import launch_dht_instances
+
+
+@pytest.mark.forked
+def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
+    torch.manual_seed(0)
+    X = torch.randn(*size)
+    assert torch.allclose(deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.NONE)), X)
+    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.MEANSTD_16BIT)) - X
+    assert error.square().mean() < alpha
+    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.FLOAT16)) - X
+    assert error.square().mean() < alpha
+    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.QUANTILE_8BIT)) - X
+    assert error.square().mean() < beta
+    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.UNIFORM_8BIT)) - X
+    assert error.square().mean() < beta
+
+    zeros = torch.zeros(5, 5)
+    for compression_type in CompressionType.values():
+        assert deserialize_torch_tensor(serialize_torch_tensor(zeros, compression_type)).isfinite().all()
+
+
+@pytest.mark.forked
+def test_serialize_tensor():
+    def _check(tensor, compression, rtol=1e-5, atol=1e-8, chunk_size=30 * 1024):
+        serialized_tensor = serialize_torch_tensor(tensor, compression)
+        chunks = list(hivemind.split_for_streaming(serialized_tensor, chunk_size))
+        assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
+        restored = hivemind.combine_from_streaming(chunks)
+        assert torch.allclose(deserialize_torch_tensor(restored), tensor, rtol=rtol, atol=atol)
+
+    tensor = torch.randn(512, 12288)
+    for chunk_size in [1024, 64 * 1024, 64 * 1024 + 1, 10 ** 9]:
+        _check(tensor, CompressionType.NONE, chunk_size=chunk_size)
+
+    _check(tensor, CompressionType.FLOAT16, rtol=0.0, atol=1e-2)
+    _check(torch.randint(0, 100, (512, 1, 1)), CompressionType.NONE)
+    _check(torch.tensor(1.0), CompressionType.NONE)
+    _check(torch.tensor(1.0), CompressionType.FLOAT16)
+
+
+@pytest.mark.forked
+def test_allreduce_compression():
+    """this test ensures that compression works correctly when multiple tensors have different compression types"""
+
+    tensors1 = [torch.linspace(0, 500, 1000) ** 0.5, torch.randn(1000)]
+    tensors2 = [torch.linspace(300, 800, 1000) ** 0.5, torch.randn(1000)]
+    results = {}
+
+    FLOAT16, UINT8 = Float16Compression(), Uniform8BitQuantization()
+
+    for compression_type_pair in [(FLOAT16, FLOAT16), (FLOAT16, UINT8), (UINT8, FLOAT16), (UINT8, UINT8)]:
+        dht_instances = launch_dht_instances(2)
+        averager1 = hivemind.averaging.DecentralizedAverager(
+            [x.clone() for x in tensors1],
+            dht=dht_instances[0],
+            compression=PerTensorCompression(compression_type_pair),
+            client_mode=True,
+            target_group_size=2,
+            prefix="mygroup",
+            start=True,
+        )
+        averager2 = hivemind.averaging.DecentralizedAverager(
+            [x.clone() for x in tensors2],
+            dht=dht_instances[1],
+            compression=PerTensorCompression(compression_type_pair),
+            target_group_size=2,
+            prefix="mygroup",
+            start=True,
+        )
+
+        for future in averager1.step(wait=False), averager2.step(wait=False):
+            future.result()
+
+        with averager1.get_tensors() as averaged_tensors:
+            results[compression_type_pair] = averaged_tensors
+
+        for instance in [averager1, averager2] + dht_instances:
+            instance.shutdown()
+
+    assert torch.allclose(results[UINT8, FLOAT16][0], results[UINT8, UINT8][0])
+    assert torch.allclose(results[UINT8, FLOAT16][1], results[FLOAT16, FLOAT16][1])
+    assert torch.allclose(results[UINT8, UINT8][1], results[FLOAT16, UINT8][1])
+    assert torch.allclose(results[FLOAT16, UINT8][0], results[FLOAT16, FLOAT16][0])
+
+    assert not torch.allclose(results[UINT8, FLOAT16][1], results[UINT8, UINT8][1])
+    assert not torch.allclose(results[UINT8, FLOAT16][0], results[FLOAT16, FLOAT16][0])
+    assert not torch.allclose(results[UINT8, UINT8][0], results[FLOAT16, UINT8][0])
+    assert not torch.allclose(results[FLOAT16, UINT8][1], results[FLOAT16, FLOAT16][1])
+
+    reference = [(tensors1[i] + tensors2[i]) / 2 for i in range(len(tensors1))]
+    for i in range(2):
+        assert 0 < torch.mean(torch.square(results[FLOAT16, FLOAT16][i] - reference[i])).item() <= 1e-5
+        assert 1e-5 < torch.mean(torch.square(results[UINT8, UINT8][i] - reference[i])).item() <= 1e-2
+
+
+class TrackedCompression(AdaptiveCompressionBase):
+    def __init__(self, compression: CompressionBase):
+        self.compression = compression
+        self.mp_counter, self.mp_part_size = mp.Value(c_int32, 0), mp.Value(c_int32, 0)
+        super().__init__()
+
+    def choose_compression(self, info: CompressionInfo) -> CompressionBase:
+        return self.compression
+
+    def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False):
+        self.mp_counter.value += 1
+        if info.part_size is not None:
+            self.mp_part_size.value = max(self.mp_part_size.value, info.part_size)
+        return self.compression.compress(tensor, info=info, allow_inplace=allow_inplace)
+
+
+def make_params():
+    return [
+        nn.Parameter(x)
+        for x in (
+            torch.randn([]),
+            torch.randn(1),
+            torch.randn(100),
+            torch.randn(1_000),
+            torch.randn(5_000),
+            torch.randn(10_000),
+        )
+    ]
+
+
+@pytest.mark.forked
+def test_adaptive_compression():
+    UINT8 = TrackedCompression(Uniform8BitQuantization())
+    FLOAT16 = TrackedCompression(Float16Compression())
+    FLOAT32 = TrackedCompression(NoCompression())
+    STATE_FP16 = TrackedCompression(Float16Compression())
+    STATE_FP32 = TrackedCompression(NoCompression())
+
+    averaging_compression_adaptive = RoleAdaptiveCompression(
+        parameter=FLOAT16,
+        gradient=SizeAdaptiveCompression(threshold=1_000, less=FLOAT16, greater_equal=UINT8),
+        optimizer=FLOAT32,
+        default=FLOAT32,
+    )
+
+    state_compression_adaptive = SizeAdaptiveCompression(
+        threshold=500,
+        less=STATE_FP32,
+        greater_equal=STATE_FP16,
+    )
+
+    averager1 = hivemind.TrainingAverager(
+        opt=torch.optim.Adam(make_params()),
+        average_parameters=True,
+        average_gradients=True,
+        average_opt_statistics=("exp_avg",),
+        compression=averaging_compression_adaptive,
+        state_compression=state_compression_adaptive,
+        prefix="test_avgr",
+        target_group_size=2,
+        part_size_bytes=5_000,
+        start=True,
+        dht=hivemind.DHT(start=True),
+    )
+
+    averager2 = hivemind.TrainingAverager(
+        opt=torch.optim.Adam(make_params()),
+        average_parameters=True,
+        average_gradients=True,
+        average_opt_statistics=("exp_avg",),
+        compression=averaging_compression_adaptive,
+        state_compression=state_compression_adaptive,
+        prefix="test_avgr",
+        target_group_size=2,
+        part_size_bytes=5_000,
+        start=True,
+        dht=hivemind.DHT(initial_peers=averager1.dht.get_visible_maddrs(), start=True),
+    )
+
+    futures = [averager1.step(wait=False), averager2.step(wait=False)]
+
+    for future in futures:
+        future.result()
+
+    assert UINT8.mp_counter.value == 4  # half gradients: 3 tensors, 1 is split
+    assert UINT8.mp_part_size.value == 5_000  # uint8 values take one byte, so a 5_000-byte part holds 5_000 values
+    assert FLOAT16.mp_counter.value == 13  # parameters and half gradients
+    assert FLOAT16.mp_part_size.value == 2_500  # float16 values take two bytes: 2_500 values per part
+    assert FLOAT32.mp_counter.value == 16  # optimizer statistics (exp_avg) stay in fp32
+    assert FLOAT32.mp_part_size.value == 1250  # float32 values take four bytes: 1_250 values per part
+
+    averager1.load_state_from_peers()
+    assert STATE_FP16.mp_counter.value == STATE_FP32.mp_counter.value == 9
+    assert STATE_FP16.mp_part_size.value == STATE_FP32.mp_part_size.value == 0  # not partitioned
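
The test above exercises the new adaptive strategies end to end. As a plainer illustration, here is how the same building blocks compose for a `DecentralizedAverager`; the tensor values and the size threshold are arbitrary, and a real run needs a second peer for `.step()` to succeed:

```python
import torch
import hivemind
from hivemind.compression import (
    Float16Compression,
    NoCompression,
    RoleAdaptiveCompression,
    SizeAdaptiveCompression,
    Uniform8BitQuantization,
)

# parameters in fp16; small gradients in fp16, large ones quantized to 8 bit; everything else uncompressed
compression = RoleAdaptiveCompression(
    parameter=Float16Compression(),
    gradient=SizeAdaptiveCompression(
        threshold=2 ** 16, less=Float16Compression(), greater_equal=Uniform8BitQuantization()
    ),
    optimizer=NoCompression(),
    default=NoCompression(),
)

averager = hivemind.averaging.DecentralizedAverager(
    [torch.zeros(3)],
    dht=hivemind.DHT(start=True),
    prefix="demo",
    target_group_size=2,
    compression=compression,
    start=True,
)
```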

+ 1 - 51
tests/test_util_modules.py

@@ -9,6 +9,7 @@ import pytest
 import torch
 
 import hivemind
+from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.proto.dht_pb2_grpc import DHTStub
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
@@ -25,7 +26,6 @@ from hivemind.utils.asyncio import (
     azip,
     cancel_and_wait,
 )
-from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.mpfuture import InvalidStateError
 
 
@@ -323,24 +323,6 @@ def test_many_futures():
     p.join()
 
 
-def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
-    torch.manual_seed(0)
-    X = torch.randn(*size)
-    assert torch.allclose(deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.NONE)), X)
-    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.MEANSTD_16BIT)) - X
-    assert error.square().mean() < alpha
-    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.FLOAT16)) - X
-    assert error.square().mean() < alpha
-    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.QUANTILE_8BIT)) - X
-    assert error.square().mean() < beta
-    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.UNIFORM_8BIT)) - X
-    assert error.square().mean() < beta
-
-    zeros = torch.zeros(5, 5)
-    for compression_type in CompressionType.values():
-        assert deserialize_torch_tensor(serialize_torch_tensor(zeros, compression_type)).isfinite().all()
-
-
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_channel_cache():
@@ -385,38 +367,6 @@ async def test_channel_cache():
             assert (ci is cj) == ((ci, cj) in duplicates), (i, j)
 
 
-def test_serialize_tensor():
-    tensor = torch.randn(512, 12288)
-
-    serialized_tensor = serialize_torch_tensor(tensor, CompressionType.NONE)
-    for chunk_size in [1024, 64 * 1024, 64 * 1024 + 1, 10 ** 9]:
-        chunks = list(hivemind.split_for_streaming(serialized_tensor, chunk_size))
-        assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
-        restored = hivemind.combine_from_streaming(chunks)
-        assert torch.allclose(deserialize_torch_tensor(restored), tensor)
-
-    chunk_size = 30 * 1024
-    serialized_tensor = serialize_torch_tensor(tensor, CompressionType.FLOAT16)
-    chunks = list(hivemind.split_for_streaming(serialized_tensor, chunk_size))
-    assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
-    restored = hivemind.combine_from_streaming(chunks)
-    assert torch.allclose(deserialize_torch_tensor(restored), tensor, rtol=0, atol=1e-2)
-
-    tensor = torch.randint(0, 100, (512, 1, 1))
-    serialized_tensor = serialize_torch_tensor(tensor, CompressionType.NONE)
-    chunks = list(hivemind.split_for_streaming(serialized_tensor, chunk_size))
-    assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
-    restored = hivemind.combine_from_streaming(chunks)
-    assert torch.allclose(deserialize_torch_tensor(restored), tensor)
-
-    scalar = torch.tensor(1.0)
-    serialized_scalar = serialize_torch_tensor(scalar, CompressionType.NONE)
-    assert torch.allclose(deserialize_torch_tensor(serialized_scalar), scalar)
-
-    serialized_scalar = serialize_torch_tensor(scalar, CompressionType.FLOAT16)
-    assert torch.allclose(deserialize_torch_tensor(serialized_scalar), scalar)
-
-
 def test_serialize_tuple():
     test_pairs = (
         ((1, 2, 3), [1, 2, 3]),