
Split compression strategies into separate classes (#366)

* implement compression strategies as subclasses
   * [x] NONE
   * [x] FLOAT16
   * [x] MEANSTD_16BIT
   * [x] QUANTILE_8BIT
   * [x] UNIFORM_8BIT
* [x] update load balancing in AllReduceRunner to account for compression ratio
* [x] pass tensor info into DecentralizedAverager, use defaults with indices
* [x] pass correct tensor info in TrainingAverager
* [x] pass CompressionInfos into TensorPartContainer
* [x] chunk -> part
* update all tests
  * [x] test_util_modules
  * [x] test_allreduce
  * [x] test_averaging
  * [x] test_training
  * [x] test_moe/test_expert_backend?
* implement adaptive compression strategies
  * [x] size-adaptive
  * [x] role-adaptive
  * [x] manual, for a list of tensors
* [x] separate compression strategy for load_state_from_peers
* [x] remove compression_type in ExpertBackend in favour of Compression
* [x] co-author @mponty and @Vsevolod-pl for attribution

__Testing:__
* [x] extract compression tests into a separate file
* [x] tests for adaptive strategies
* [x] test load_state_from_peers with compression
* [x] ensure that examples/albert works with mixed compression strategies
* [x] ensure that compression ratio is correctly accounted for in DecentralizedAverager
* [x] test load_state_from_peers compression

Co-authored-by: Alexander Borzunov <borzunov.alexander@gmail.com>
Co-authored-by: mponty <heapnhash@gmail.com>
Co-authored-by: Vsevolod-pl <vsevolod-pl@yandex.ru>
Co-authored-by: Michael Diskin <yhn112@users.noreply.github.com>
Co-authored-by: Alexander Borzunov <hxrussia@gmail.com>
justheuristic 4 years ago
parent revision 1f54faf82f
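The user-facing change: compression is now configured with strategy objects rather than `CompressionType` enum values, and the averager takes a separate strategy for `load_state_from_peers`. A minimal sketch of the new style — the DHT setup, tensor shapes, prefix and group size below are illustrative placeholders, not part of this commit:

```python
import torch
import hivemind

dht = hivemind.DHT(start=True)  # toy single-node DHT, for illustration only

averager = hivemind.DecentralizedAverager(
    averaged_tensors=[torch.zeros(1024)],             # tensors to be averaged with peers
    dht=dht,
    prefix="demo_averaging",                          # illustrative experiment prefix
    target_group_size=2,
    compression=hivemind.Float16Compression(),        # replaces compression_type=CompressionType.FLOAT16
    state_compression=hivemind.Float16Compression(),  # used when serving load_state_from_peers
    start=True,
)
```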

+ 1 - 1
benchmarks/benchmark_tensor_compression.py

@@ -3,8 +3,8 @@ import time
 
 import torch
 
+from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.logging import get_logger
 
 logger = get_logger(__name__)

+ 0 - 3
examples/albert/arguments.py

@@ -93,9 +93,6 @@ class CollaborativeOptimizerArguments:
         default=100.0,
         metadata={"help": "Available network bandwidth, in mbps (used for load balancing in all-reduce)"},
     )
-    compression: str = field(
-        default="FLOAT16", metadata={"help": "Use this compression when averaging parameters/gradients"}
-    )
 
 
 @dataclass

+ 1 - 2
examples/albert/run_trainer.py

@@ -18,7 +18,6 @@ from transformers.trainer import Trainer
 from transformers.trainer_utils import is_main_process
 
 import hivemind
-from hivemind.utils.compression import CompressionType
 
 import utils
 from arguments import AlbertTrainingArguments, AveragerArguments, CollaborationArguments, DatasetArguments
@@ -262,7 +261,7 @@ def main():
         dht=dht,
         scheduler=scheduler,
         prefix=collaboration_args.experiment_prefix,
-        compression_type=CompressionType.Value(collaboration_args.compression),
+        compression=hivemind.Float16Compression(),
         batch_size_per_step=total_batch_size_per_step,
         bandwidth=collaboration_args.bandwidth,
         target_batch_size=adjusted_target_batch_size,

+ 1 - 2
examples/albert/run_training_monitor.py

@@ -13,7 +13,6 @@ from torch_optimizer import Lamb
 from transformers import AlbertConfig, AlbertForPreTraining, HfArgumentParser
 
 import hivemind
-from hivemind.utils.compression import CompressionType
 
 import utils
 from arguments import AveragerArguments, BaseTrainingArguments, CollaborativeOptimizerArguments
@@ -101,7 +100,7 @@ class CheckpointHandler:
             opt=opt,
             dht=dht,
             prefix=experiment_prefix,
-            compression_type=CompressionType.Value(collab_optimizer_args.compression),
+            compression_type=hivemind.Float16Compression(),
             bandwidth=collab_optimizer_args.bandwidth,
             target_batch_size=adjusted_target_batch_size,
             client_mode=collab_optimizer_args.client_mode,

+ 1 - 0
hivemind/__init__.py

@@ -1,4 +1,5 @@
 from hivemind.averaging import DecentralizedAverager, TrainingAverager
+from hivemind.compression import *
 from hivemind.dht import DHT
 from hivemind.moe import (
     ExpertBackend,

+ 1 - 1
hivemind/averaging/allreduce.py

@@ -5,11 +5,11 @@ from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Type
 import torch
 
 from hivemind.averaging.partition import AllreduceException, TensorPartContainer, TensorPartReducer
+from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase, StubBase
 from hivemind.proto import averaging_pb2
 from hivemind.utils import get_logger
 from hivemind.utils.asyncio import achain, aenumerate, afirst, amap_in_executor, anext, as_aiter
-from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 
 # flavour types
 GroupID = bytes

+ 35 - 15
hivemind/averaging/averager.py

@@ -9,7 +9,6 @@ import multiprocessing as mp
 import os
 import threading
 import weakref
-from concurrent.futures.thread import ThreadPoolExecutor
 from dataclasses import asdict
 from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Union
 
@@ -21,12 +20,18 @@ from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.load_balancing import load_balance_peers
 from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
 from hivemind.averaging.partition import DEFAULT_PART_SIZE_BYTES
+from hivemind.compression import (
+    CompressionBase,
+    CompressionInfo,
+    NoCompression,
+    deserialize_torch_tensor,
+    serialize_torch_tensor,
+)
 from hivemind.dht import DHT, DHTID
 from hivemind.p2p import P2PContext, P2PHandlerError, PeerID, ServicerBase
-from hivemind.proto import averaging_pb2, runtime_pb2
+from hivemind.proto import averaging_pb2
 from hivemind.utils import MPFuture, TensorDescriptor, get_logger
 from hivemind.utils.asyncio import achain, aiter_with_timeout, anext, as_aiter, switch_to_uvloop
-from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.grpc import combine_from_streaming, split_for_streaming
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration, get_dht_time
@@ -51,7 +56,9 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     :param initial_group_bits: a string of bits ('0' and '1') that define the initial group key (bucket index)
     :param averaging_expiration: attempt to find a group for this many seconds, otherwise try again
       note - this expiration time only applies to looking for group, passing tensors in allreduce may take more time
-    :param compression_type: optionally compress tensors with this compression algorithm before sending them to peers
+    :param compression: optionally compress tensors with this compression algorithm before running all-reduce
+    :param state_compression: a separate compression strategy for load_state_from_peers (default = no compression)
+    :param tensor_infos: CompressionInfo for each respective tensor; this determines how the tensor will be compressed
     :param allreduce_timeout: spend at most this many seconds for allreduce (after group is formed)
     :param averaging_alpha: optional "learning rate" for averaging. If specified, local parameters will be shifted
       towards the (estimated) average by this coefficient. By default, local parameters are set equal to average.
@@ -102,7 +109,9 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         averaging_alpha: float = 1.0,
         part_size_bytes: int = DEFAULT_PART_SIZE_BYTES,
         allreduce_timeout: Optional[float] = None,
-        compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE,
+        compression: CompressionBase = NoCompression(),
+        state_compression: CompressionBase = NoCompression(),
+        tensor_infos: Optional[Sequence[CompressionInfo]] = None,
         bandwidth: Optional[float] = None,
         min_vector_size: int = 0,
         auxiliary: bool = False,
@@ -158,7 +167,9 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             request_timeout=request_timeout,
         )
         self.allreduce_kwargs = dict(
-            compression_type=compression_type, part_size_bytes=part_size_bytes, min_vector_size=min_vector_size
+            compression=compression,
+            part_size_bytes=part_size_bytes,
+            min_vector_size=min_vector_size,
         )
         self._averaging_alpha, self._allreduce_timeout = averaging_alpha, allreduce_timeout
         self._running_groups: Dict[GroupID, AllReduceRunner] = {}  # one or more assembled groups that run all-reduce
@@ -169,6 +180,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         if allow_state_sharing is None:
             allow_state_sharing = not client_mode and not auxiliary
         self.allow_state_sharing = allow_state_sharing
+        self.state_compression = state_compression
+        self.tensor_infos = tensor_infos
 
         self._ready = MPFuture()
         # note: we create a background thread weakref and with daemon=True to ensure garbage collection
@@ -363,7 +376,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
                     future.set_result(
                         await asyncio.wait_for(
-                            self._run_allreduce(group_info, **self.allreduce_kwargs), self._allreduce_timeout
+                            self._run_allreduce(group_info, tensor_infos=self.tensor_infos, **self.allreduce_kwargs),
+                            timeout=self._allreduce_timeout,
                         )
                     )
                     # averaging is finished, loop will now exit
@@ -529,24 +543,27 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         """
         if not self.allow_state_sharing:
             return  # deny request and direct peer to the next prospective averager
-        metadata, tensors = await self._get_current_state_from_host_process()
+        metadata, tensors, infos = await self._get_current_state_from_host_process()
+        if infos is None:
+            infos = [CompressionInfo.from_tensor(tensor, key=i) for i, tensor in enumerate(tensors)]
+        assert len(tensors) == len(infos)
 
-        for tensor in tensors:
-            for part in split_for_streaming(serialize_torch_tensor(tensor)):
+        for tensor, info in zip(tensors, infos):
+            for part in split_for_streaming(self.state_compression.compress(tensor, info, allow_inplace=False)):
                 if metadata is not None:
                     yield averaging_pb2.DownloadData(tensor_part=part, metadata=metadata)
                     metadata = None
                 else:
                     yield averaging_pb2.DownloadData(tensor_part=part)
 
-    def get_current_state(self) -> Tuple[Any, Sequence[torch.Tensor]]:
+    def get_current_state(self) -> Tuple[Any, Sequence[torch.Tensor], Sequence[CompressionInfo]]:
         """
         Get current state and send it to a peer. executed in the host process. Meant to be overriden.
         :returns: a tuple of (small metadata, sequence of torch tensors)
         :note: metadata must be seriablizable with self.serializer (default = MSGPackSerializer)
         """
         with self.get_tensors() as tensors:
-            return dict(group_key=self.get_group_bits()), tensors
+            return dict(group_key=self.get_group_bits()), tensors, self.tensor_infos
 
     async def _get_current_state_from_host_process(self):
         """Executed in the averager process inside rpc_download_state"""
@@ -618,7 +635,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
         finally:
             if not future.done():
-                logger.warning("Averager could not load state from peers: none of the requests succeeded.")
                 future.set_result(None)
 
     def get_group_bits(self, wait: bool = True):
@@ -681,7 +697,11 @@ def _background_thread_fetch_current_state(
             get_current_state = get_current_state_ref()
             if get_current_state is None:
                 break
-            state_metadata, state_tensors = get_current_state()
+            state = get_current_state()
+            assert 0 < len(state) <= 3
+            if len(state) != 3:
+                state = tuple(state + (None,) * (3 - len(state)))
+            state_metadata, state_tensors, tensor_infos = state
             del get_current_state
 
             state_metadata = serializer.dumps(state_metadata)
@@ -689,7 +709,7 @@ def _background_thread_fetch_current_state(
                 tensor.cpu().detach().requires_grad_(tensor.requires_grad) for tensor in state_tensors
             )
             # note: we cast tensors to CPU on host side to avoid initializing cuda in the guest process
-            future.set_result((state_metadata, state_tensors))
+            future.set_result((state_metadata, state_tensors, tensor_infos))
         except BaseException as e:
             future.set_exception(e)
             logger.warning(e)

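`get_current_state` now returns a third element with per-tensor `CompressionInfo`, and `_background_thread_fetch_current_state` pads older two-element returns with `None`. A hypothetical subclass that tags its shared tensors as parameters could therefore look like this (the role choice is illustrative):

```python
import hivemind
from hivemind.compression import CompressionInfo, TensorRole

class TaggedAverager(hivemind.DecentralizedAverager):
    def get_current_state(self):
        # metadata must stay serializable with self.serializer (MSGPack by default)
        with self.get_tensors() as tensors:
            infos = [
                CompressionInfo.from_tensor(tensor, key=i, role=TensorRole.PARAMETER)
                for i, tensor in enumerate(tensors)
            ]
            return dict(group_key=self.get_group_bits()), tensors, infos
```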
+ 22 - 16
hivemind/averaging/partition.py

@@ -3,14 +3,14 @@ Auxiliary data structures for AllReduceRunner
 """
 import asyncio
 from collections import deque
-from typing import AsyncIterable, AsyncIterator, Optional, Sequence, Tuple, TypeVar, Union
+from typing import AsyncIterable, AsyncIterator, Optional, Sequence, Tuple, TypeVar
 
 import numpy as np
 import torch
 
-from hivemind.proto.runtime_pb2 import CompressionType, Tensor
+from hivemind.compression import CompressionBase, CompressionInfo, NoCompression
+from hivemind.proto import runtime_pb2
 from hivemind.utils.asyncio import amap_in_executor
-from hivemind.utils.compression import get_nbytes_per_value, serialize_torch_tensor
 
 T = TypeVar("T")
 DEFAULT_PART_SIZE_BYTES = 2 ** 16
@@ -22,8 +22,9 @@ class TensorPartContainer:
     The class is designed to avoid excessive memory allocation and run all heavy computation in background
     :param tensors: local tensors to be split and aggregated
     :param peer_fractions: for each peer, a target fraction of vector elements that this peer should average
-    :param compression_type: optionally compress tensors with this compression algorithm before sending them to peers
+    :param compression: optionally compress tensors with this compression algorithm before sending them to peers
     :param part_size_bytes: greedily split tensors into parts of up to this many bytes (after compression)
+    :param tensor_infos: CompressionInfo for each respective tensor; this determines how the tensor will be compressed
     :param prefetch: when compressing, pre-compute this many compressed tensors in background
     """
 
@@ -31,16 +32,19 @@ class TensorPartContainer:
         self,
         tensors: Sequence[torch.Tensor],
         peer_fractions: Sequence[float],
-        compression_type: Union["CompressionType", Sequence["CompressionType"]] = CompressionType.NONE,
+        compression: CompressionBase = NoCompression(),
         part_size_bytes: int = DEFAULT_PART_SIZE_BYTES,
+        tensor_infos: Optional[Sequence[CompressionInfo]] = None,
         prefetch: int = 5,
     ):
-        if not isinstance(compression_type, Sequence):
-            compression_type = [compression_type] * len(tensors)
-        assert len(compression_type) == len(tensors), "compression types do not match the number of tensors"
+        if tensor_infos is None:
+            tensor_infos = tuple(CompressionInfo.from_tensor(x, key=i) for i, x in enumerate(tensors))
+        assert len(tensor_infos) == len(tensors), "compression types do not match the number of tensors"
         self.local_tensors, self.peer_fractions, self.group_size = tensors, peer_fractions, len(peer_fractions)
-        self.compression_type, self.part_size_bytes, self.prefetch = compression_type, part_size_bytes, prefetch
+        self.compression, self.part_size_bytes, self.tensor_infos = compression, part_size_bytes, tensor_infos
         self.total_size = sum(tensor.numel() for tensor in tensors)
+        self.prefetch = prefetch
+
         self._input_parts_by_peer = [deque() for _ in range(self.group_size)]
         self._output_parts_by_peer = [deque() for _ in range(self.group_size)]
         self._inputs_consumed_by_peer = [False for _ in range(self.group_size)]
@@ -56,11 +60,13 @@ class TensorPartContainer:
         pivots = (np.cumsum(peer_fractions) / np.sum(peer_fractions) * self.total_size).astype(np.int64)
         pivots[-1] = self.total_size
 
-        for tensor, tensor_compression in zip(self.local_tensors, compression_type):
-            part_size_values = int(part_size_bytes / get_nbytes_per_value(tensor.dtype, tensor_compression))
+        for tensor, info in zip(self.local_tensors, self.tensor_infos):
+            bytes_per_value = tensor.element_size() * compression.estimate_compression_ratio(info)
+            part_size_values = int(part_size_bytes / bytes_per_value)
             tensor_parts = tensor.detach().view(-1).split(part_size_values)
             self.num_parts_by_tensor.append(len(tensor_parts))
-            for part in tensor_parts:
+            for part_index, part in enumerate(tensor_parts):
+                part_info = info.get_part(part_index, part_size_values)
                 if current_length + len(part) > pivots[current_peer_index]:
                     # switch to next peer; if a part lands between parts of two or
                     # more peers, assign that part to the peer with highest intersection
@@ -71,9 +77,9 @@ class TensorPartContainer:
                         current_peer_part_end = min(current_length + len(part), pivots[current_peer_index])
                         peer_intersections.append(current_peer_part_end - pivots[current_peer_index - 1])
                     assigned_peer_index = prev_peer_index + np.argmax(peer_intersections)
-                    self._input_parts_by_peer[assigned_peer_index].append((part, tensor_compression))
+                    self._input_parts_by_peer[assigned_peer_index].append((part, part_info))
                 else:
-                    self._input_parts_by_peer[current_peer_index].append((part, tensor_compression))
+                    self._input_parts_by_peer[current_peer_index].append((part, part_info))
                 current_length += len(part)
 
         assert current_length == self.total_size
@@ -89,7 +95,7 @@ class TensorPartContainer:
         return input_parts
 
     @torch.no_grad()
-    async def iterate_input_parts_for(self, peer_index: int) -> AsyncIterator[Tensor]:
+    async def iterate_input_parts_for(self, peer_index: int) -> AsyncIterator[runtime_pb2.Tensor]:
         """iterate serialized tensor parts for a peer at a given index. Run serialization in background."""
         assert not self._inputs_consumed_by_peer[peer_index], "input parts of a given peer are already deallocated."
         self._inputs_consumed_by_peer[peer_index] = True
@@ -99,7 +105,7 @@ class TensorPartContainer:
                 yield self._input_parts_by_peer[peer_index].popleft()
 
         async for serialized_part in amap_in_executor(
-            lambda x_and_compr: serialize_torch_tensor(*x_and_compr), _aiterate_parts(), max_prefetch=self.prefetch
+            lambda x_and_info: self.compression.compress(*x_and_info), _aiterate_parts(), max_prefetch=self.prefetch
         ):
             yield serialized_part
 

+ 47 - 12
hivemind/averaging/training.py

@@ -8,6 +8,7 @@ from typing import Dict, Iterator, Optional, Sequence
 import torch
 
 from hivemind.averaging import DecentralizedAverager
+from hivemind.compression import CompressionInfo, TensorRole
 from hivemind.utils import get_logger, nested_flatten, nested_pack
 
 logger = get_logger(__name__)
@@ -41,23 +42,28 @@ class TrainingAverager(DecentralizedAverager):
         average_gradients: bool,
         average_opt_statistics: Sequence[str] = (),
         extra_tensors: Sequence[torch.Tensor] = (),
+        parameter_names: Optional[Sequence[str]] = None,
         initialize_optimizer: bool = True,
         **kwargs
     ):
+        if initialize_optimizer:
+            initialize_optimizer_state(opt)  # note: this will run one optimizer step!
+        if parameter_names is None:
+            parameter_names = tuple(i for group in opt.param_groups for i in range(len(group["params"])))
 
         self.opt, self.extra_tensors, self.local_step = opt, tuple(extra_tensors), 0
         self.opt_statistics = tuple(average_opt_statistics)
         self.average_parameters, self.average_gradients = average_parameters, average_gradients
+        self.parameter_names = parameter_names
         self.step_executor = ThreadPoolExecutor(max_workers=1)
         self.lock_averager_step = Lock()
         self.pending_updates_done = Event()
         self.pending_updates_done.set()
-        if initialize_optimizer:
-            initialize_optimizer_state(opt)  # note: this will run one optimizer step!
 
         with torch.no_grad():
             averaged_tensors = [tensor.detach().cpu().float().clone() for tensor in self.local_tensors()]
-        super().__init__(averaged_tensors=averaged_tensors, **kwargs)
+
+        super().__init__(averaged_tensors=averaged_tensors, tensor_infos=list(self.tensor_infos()), **kwargs)
 
     def step(self, data_lock: Optional[Lock] = None, wait: bool = True, **kwargs):
         """
@@ -119,13 +125,8 @@ class TrainingAverager(DecentralizedAverager):
             self.local_step += 1
             return gathered
 
-    def local_tensors(self, replace_none: bool = True) -> Iterator[torch.Tensor]:
-        """
-        Iterate local trainer's tensors that should be averaged with peers
-
-        :param replace_none: if True and average_gradients is True, None grads will be replaced with a zero tensors
-          Otherwise, such gradients will be skipped. (this may cause inconsistencies with averaged_tensors)
-        """
+    def local_tensors(self) -> Iterator[torch.Tensor]:
+        """Iterate local trainer's tensors that should be averaged with peers"""
         if self.average_parameters:
             for param_group in self.opt.param_groups:
                 yield from param_group["params"]
@@ -134,7 +135,7 @@ class TrainingAverager(DecentralizedAverager):
                 for param in param_group["params"]:
                     if param.grad is not None:
                         yield param.grad
-                    elif replace_none:
+                    else:
                         yield torch.zeros_like(param)
         for stats in self.opt_statistics:
             for param_group in self.opt.param_groups:
@@ -142,6 +143,26 @@ class TrainingAverager(DecentralizedAverager):
                     yield self.opt.state[param][stats]
         yield from iter(self.extra_tensors)
 
+    def tensor_infos(self):
+        """Get CompressionInfo for each tensor, accounting for its role and specification"""
+        params = tuple(param for param_group in self.opt.param_groups for param in param_group["params"])
+        assert len(params) == len(self.parameter_names)
+        if self.average_parameters:
+            for param, key in zip(params, self.parameter_names):
+                yield CompressionInfo.from_tensor(param, key=key, role=TensorRole.PARAMETER)
+        if self.average_gradients:
+            for param, key in zip(params, self.parameter_names):
+                if param.grad is not None:
+                    grad = param.grad if param.grad is not None else torch.zeros_like(param)
+                    yield CompressionInfo.from_tensor(grad, key=key, role=TensorRole.GRADIENT)
+        for stats in self.opt_statistics:
+            for param, key in zip(params, self.parameter_names):
+                yield CompressionInfo.from_tensor(
+                    self.opt.state[param][stats], key=(key, stats), role=TensorRole.OPTIMIZER
+                )
+        for i, extra_tensor in enumerate(self.extra_tensors):
+            yield CompressionInfo.from_tensor(extra_tensor, key=i, role=TensorRole.UNSPECIFIED)
+
     def get_current_state(self):
         """
         Get current model/optimizer state and when requested by a newbie peer. executed in the host process.
@@ -151,11 +172,25 @@ class TrainingAverager(DecentralizedAverager):
             optimized_parameters = tuple(
                 param.detach().cpu() for param_group in self.opt.param_groups for param in param_group["params"]
             )
+            parameter_infos = [
+                CompressionInfo.from_tensor(param, key=key, role=TensorRole.PARAMETER)
+                for param, key in zip(optimized_parameters, self.parameter_names)
+            ]
             extra_tensors = tuple(tensor.detach().cpu() for tensor in self.extra_tensors)
+            extra_infos = [
+                CompressionInfo.from_tensor(extra_tensor, key=i, role=TensorRole.UNSPECIFIED)
+                for i, extra_tensor in enumerate(extra_tensors)
+            ]
             optimizer_metadata, optimizer_tensors = dump_optimizer_state(self.opt)
+            optimizer_infos = [
+                CompressionInfo.from_tensor(opt_tensor, key=i, role=TensorRole.OPTIMIZER)
+                for i, opt_tensor in enumerate(optimizer_tensors)
+            ]
 
         metadata = dict(step=self.local_step, group_bits=self.get_group_bits(), optimizer_metadata=optimizer_metadata)
-        return metadata, list(chain(optimized_parameters, extra_tensors, optimizer_tensors))
+        all_tensors = list(chain(optimized_parameters, extra_tensors, optimizer_tensors))
+        all_tensor_infos = list(chain(parameter_infos, extra_infos, optimizer_infos))
+        return metadata, all_tensors, all_tensor_infos
 
     def load_state_from_peers(self, **kwargs):
         """

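A sketch of how the new `parameter_names` argument could be used so that `tensor_infos` tags parameters, gradients and optimizer statistics by name; the model, optimizer and matchmaking settings below are placeholders rather than part of this commit:

```python
import torch
import hivemind

model = torch.nn.Linear(16, 4)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
dht = hivemind.DHT(start=True)

averager = hivemind.TrainingAverager(
    opt,
    average_parameters=True,
    average_gradients=True,
    average_opt_statistics=("exp_avg_sq",),                          # Adam's second-moment buffers
    parameter_names=[name for name, _ in model.named_parameters()],  # one name per optimized parameter
    dht=dht,
    prefix="demo_training",
    target_group_size=2,
    start=True,
)
```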
+ 52 - 0
hivemind/compression/__init__.py

@@ -0,0 +1,52 @@
+"""
+Compression strategies that reduce the network communication in .averaging, .optim and .moe
+"""
+
+import warnings
+from typing import Dict, Optional
+
+import torch
+
+from hivemind.compression.adaptive import PerTensorCompression, RoleAdaptiveCompression, SizeAdaptiveCompression
+from hivemind.compression.base import CompressionBase, CompressionInfo, NoCompression, TensorRole
+from hivemind.compression.floating import Float16Compression, ScaledFloat16Compression
+from hivemind.compression.quantization import Quantile8BitQuantization, Uniform8BitQuantization
+from hivemind.proto import runtime_pb2
+
+warnings.filterwarnings("ignore", message="The given NumPy array is not writeable", category=UserWarning)
+
+
+BASE_COMPRESSION_TYPES: Dict[str, CompressionBase] = dict(
+    NONE=NoCompression(),
+    FLOAT16=Float16Compression(),
+    MEANSTD_16BIT=ScaledFloat16Compression(),
+    QUANTILE_8BIT=Quantile8BitQuantization(),
+    UNIFORM_8BIT=Uniform8BitQuantization(),
+)
+
+for key in runtime_pb2.CompressionType.keys():
+    assert key in BASE_COMPRESSION_TYPES, f"Compression type {key} does not have a registered deserializer."
+    actual_compression_type = BASE_COMPRESSION_TYPES[key].compression_type
+    assert (
+        runtime_pb2.CompressionType.Name(actual_compression_type) == key
+    ), f"Compression strategy for {key} has inconsistent type"
+
+
+def serialize_torch_tensor(
+    tensor: torch.Tensor,
+    compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE,
+    info: Optional[CompressionInfo] = None,
+    allow_inplace: bool = False,
+    **kwargs,
+) -> runtime_pb2.Tensor:
+    """Serialize a given tensor into a protobuf message using the specified compression strategy"""
+    assert tensor.device == torch.device("cpu")
+    compression = BASE_COMPRESSION_TYPES[runtime_pb2.CompressionType.Name(compression_type)]
+    info = info or CompressionInfo.from_tensor(tensor, **kwargs)
+    return compression.compress(tensor, info, allow_inplace)
+
+
+def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
+    """Restore a pytorch tensor from a protobuf message"""
+    compression = BASE_COMPRESSION_TYPES[runtime_pb2.CompressionType.Name(serialized_tensor.compression)]
+    return compression.extract(serialized_tensor).requires_grad_(serialized_tensor.requires_grad)

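The enum-based helpers are kept so that existing call sites elsewhere in this diff keep working; they simply dispatch through `BASE_COMPRESSION_TYPES`. A small roundtrip sketch with an illustrative tensor:

```python
import torch
from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.proto.runtime_pb2 import CompressionType

tensor = torch.randn(4, 1024)
serialized = serialize_torch_tensor(tensor, CompressionType.MEANSTD_16BIT)  # a runtime_pb2.Tensor message
restored = deserialize_torch_tensor(serialized)

assert restored.shape == tensor.shape
print((restored - tensor).abs().mean().item())  # small but nonzero: 16-bit compression is lossy
```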
+ 67 - 0
hivemind/compression/adaptive.py

@@ -0,0 +1,67 @@
+from abc import ABC, abstractmethod
+from typing import Mapping, Sequence, Union
+
+import torch
+
+import hivemind
+from hivemind.compression.base import CompressionBase, CompressionInfo, Key, NoCompression, TensorRole
+from hivemind.proto import runtime_pb2
+
+
+class AdaptiveCompressionBase(CompressionBase, ABC):
+    @abstractmethod
+    def choose_compression(self, info: CompressionInfo) -> CompressionBase:
+        ...
+
+    def estimate_compression_ratio(self, info: CompressionInfo) -> float:
+        return self.choose_compression(info).estimate_compression_ratio(info)
+
+    def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
+        return self.choose_compression(info).compress(tensor, info=info, allow_inplace=allow_inplace)
+
+    def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
+        return hivemind.compression.deserialize_torch_tensor(serialized_tensor)
+
+
+class SizeAdaptiveCompression(AdaptiveCompressionBase):
+    """Use the greater_equal compression for tensors with at least :threshold: elements and the less compression otherwise"""
+
+    def __init__(self, threshold: int, less: CompressionBase, greater_equal: CompressionBase):
+        self.threshold, self.less, self.greater_equal = threshold, less, greater_equal
+
+    def choose_compression(self, info: CompressionInfo) -> CompressionBase:
+        return self.greater_equal if info.descriptor.numel() >= self.threshold else self.less
+
+
+class RoleAdaptiveCompression(AdaptiveCompressionBase):
+    """Compress a tensor based on its role in training. Any non-specified compressions will use the "default" option"""
+
+    def __init__(
+        self,
+        *,
+        activation: CompressionBase = None,
+        parameter: CompressionBase = None,
+        gradient: CompressionBase = None,
+        optimizer: CompressionBase = None,
+        default: CompressionBase = NoCompression()
+    ):
+        self.role_compressions = {
+            TensorRole.ACTIVATION: activation or default,
+            TensorRole.PARAMETER: parameter or default,
+            TensorRole.GRADIENT: gradient or default,
+            TensorRole.OPTIMIZER: optimizer or default,
+            TensorRole.UNSPECIFIED: default,
+        }
+
+    def choose_compression(self, info: CompressionInfo) -> CompressionBase:
+        return self.role_compressions[info.role]
+
+
+class PerTensorCompression(AdaptiveCompressionBase):
+    """Manually specify the compression strategy depending on tensor key"""
+
+    def __init__(self, tensor_compressions: Union[Sequence[CompressionBase], Mapping[Key, CompressionBase]]):
+        self.tensor_compressions = tensor_compressions
+
+    def choose_compression(self, info: CompressionInfo) -> CompressionBase:
+        return self.tensor_compressions[info.key]

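Because each wrapper is itself a `CompressionBase`, the adaptive strategies compose. A hypothetical configuration that keeps small gradients in float16, quantizes large ones to 8 bits, and compresses parameters to float16 (the threshold and the specific choices are illustrative):

```python
from hivemind.compression import (
    Float16Compression,
    NoCompression,
    RoleAdaptiveCompression,
    SizeAdaptiveCompression,
    Uniform8BitQuantization,
)

# gradients: 8-bit quantization only pays off for sufficiently large tensors
gradient_compression = SizeAdaptiveCompression(
    threshold=2 ** 16, less=Float16Compression(), greater_equal=Uniform8BitQuantization()
)

averaging_compression = RoleAdaptiveCompression(
    parameter=Float16Compression(),
    gradient=gradient_compression,
    optimizer=Float16Compression(),
    default=NoCompression(),  # anything without a known role (e.g. extra_tensors) stays uncompressed
)
```

The resulting object can be passed as `compression=...` to `DecentralizedAverager` or `TrainingAverager`; the role information comes from the `tensor_infos` introduced in this change.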
+ 89 - 0
hivemind/compression/base.py

@@ -0,0 +1,89 @@
+import dataclasses
+from abc import ABC, abstractmethod
+from enum import Enum, auto
+from typing import Any, Optional
+
+import numpy as np
+import torch
+
+from hivemind.proto import runtime_pb2
+from hivemind.utils.tensor_descr import TensorDescriptor
+
+Key = Any
+
+
+class TensorRole(Enum):
+    ACTIVATION = auto()
+    PARAMETER = auto()
+    GRADIENT = auto()
+    OPTIMIZER = auto()
+    UNSPECIFIED = auto()
+
+
+@dataclasses.dataclass(frozen=True)
+class CompressionInfo:
+    """Auxiliary data structure that contains information about the tensor that determines how it is compressed"""
+
+    key: Key  # name or index of the tensor from named parameters, optimizer state dict or i/o structure
+    descriptor: TensorDescriptor  # data structure that defines shape, dtype, layout and device information
+    role: TensorRole = TensorRole.UNSPECIFIED  # which role does the tensor play with respect to the model
+    part_index: int = 0  # if tensor is sliced into parts, this represents the index within one tensor
+    part_size: Optional[int] = None  # if tensor is sliced into parts, this is the _maximum_ number of values per part
+
+    @classmethod
+    def from_tensor(cls, tensor: torch.Tensor, key: Key = None, descriptor: TensorDescriptor = None, **kwargs):
+        return cls(key, descriptor or TensorDescriptor.from_tensor(tensor), **kwargs)
+
+    def get_part(self, part_index: int, part_size: Optional[int]):
+        return CompressionInfo(self.key, self.descriptor, self.role, part_index=part_index, part_size=part_size)
+
+
+class CompressionBase(ABC):
+    """A base class that applies compression algorithm to a pytorch tensor"""
+
+    compression_type: runtime_pb2.CompressionType
+
+    @abstractmethod
+    def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
+        """
+        Applies compression algorithm to a tensor based on their meta-parameters
+
+        :param tensor: a pytorch tensor to compress; depending on the application, it is a full tensor or a part
+        :param info: meta-information about the tensor; if partitioning is used, this still describes the full tensor
+        :param allow_inplace: if True, compression can (but doesn't have to) modify tensor in-place for efficiency
+        :returns: a protobuf message that encodes the tensor
+        """
+        ...
+
+    @abstractmethod
+    def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
+        """Create a pytorch tensor from the serialized outputs of .compress"""
+        ...
+
+    @abstractmethod
+    def estimate_compression_ratio(self, info: CompressionInfo) -> float:
+        """Estimate the compression ratio without doing the actual compression; lower ratio = better compression"""
+        ...
+
+
+class NoCompression(CompressionBase):
+    """A dummy compression strategy that preserves the original tensor as is."""
+
+    compression_type = runtime_pb2.CompressionType.NONE
+
+    def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
+        array = tensor.numpy()
+        return runtime_pb2.Tensor(
+            compression=self.compression_type,
+            buffer=array.tobytes(),
+            size=array.shape,
+            dtype=array.dtype.name,
+            requires_grad=tensor.requires_grad,
+        )
+
+    def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
+        array = np.frombuffer(serialized_tensor.buffer, dtype=np.dtype(serialized_tensor.dtype))
+        return torch.as_tensor(array).reshape(tuple(serialized_tensor.size))
+
+    def estimate_compression_ratio(self, info: CompressionInfo) -> float:
+        return 1.0

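A quick sketch of how `CompressionInfo` and the no-op strategy fit together; the key and role values below are arbitrary examples:

```python
import torch
from hivemind.compression import CompressionInfo, NoCompression, TensorRole

tensor = torch.randn(3, 5)
info = CompressionInfo.from_tensor(tensor, key="ffn.weight", role=TensorRole.PARAMETER)

strategy = NoCompression()
serialized = strategy.compress(tensor, info)   # runtime_pb2.Tensor carrying the raw bytes
restored = strategy.extract(serialized)

assert torch.allclose(restored, tensor)        # lossless: nothing was actually compressed
part_info = info.get_part(part_index=1, part_size=8)  # per-part metadata used by TensorPartContainer
```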
+ 92 - 0
hivemind/compression/floating.py

@@ -0,0 +1,92 @@
+import math
+
+import numpy as np
+import torch
+
+from hivemind.compression.base import CompressionBase, CompressionInfo
+from hivemind.proto import runtime_pb2
+
+
+class Float16Compression(CompressionBase):
+    compression_type = runtime_pb2.CompressionType.FLOAT16
+    FP16_MIN, FP16_MAX = torch.finfo(torch.float16).min, torch.finfo(torch.float16).max
+
+    def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
+        dtype_name = tensor.numpy().dtype.name
+        tensor = tensor.detach().cpu().float()
+        tensor = tensor if allow_inplace else tensor.clone()
+        tensor = tensor.clamp_(self.FP16_MIN, self.FP16_MAX).to(torch.float16)
+        return runtime_pb2.Tensor(
+            compression=self.compression_type,
+            buffer=tensor.numpy().tobytes(),
+            size=tensor.shape,
+            dtype=dtype_name,
+            requires_grad=tensor.requires_grad,
+        )
+
+    def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
+        original_dtype = np.dtype(serialized_tensor.dtype)
+        array = np.frombuffer(serialized_tensor.buffer, dtype=np.float16)
+        return torch.as_tensor(np.asarray(array, dtype=original_dtype)).reshape(tuple(serialized_tensor.size))
+
+    def estimate_compression_ratio(self, info: CompressionInfo) -> float:
+        return 16.0 / get_num_bits(info.descriptor.dtype)
+
+
+class ScaledFloat16Compression(Float16Compression):
+    """A compression strategy that applies mean-std scaling over last axis before casting to float16"""
+
+    compression_type = runtime_pb2.CompressionType.MEANSTD_16BIT
+    FP32_BYTES = torch.finfo(torch.float32).bits // 8
+    FP32_EPS = torch.finfo(torch.float32).eps
+
+    def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
+        dtype_name = tensor.numpy().dtype.name
+        tensor = tensor.detach().cpu().float()
+        tensor = tensor if allow_inplace else tensor.clone()
+        means = torch.mean(tensor, dim=-1, keepdim=True)
+        tensor.sub_(means)
+        stds = tensor.norm(dim=-1, keepdim=True) / math.sqrt(tensor.shape[-1])
+        stds.clamp_min_(self.FP32_EPS)
+        tensor.div_(stds)
+        tensor = tensor.clamp_(self.FP16_MIN, self.FP16_MAX).to(torch.float16)
+
+        data = b"".join((tensor.numpy().tobytes(), means.float().numpy().tobytes(), stds.float().numpy().tobytes()))
+
+        return runtime_pb2.Tensor(
+            compression=self.compression_type,
+            buffer=data,
+            size=tensor.shape,
+            dtype=dtype_name,
+            requires_grad=tensor.requires_grad,
+        )
+
+    def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
+        stats_shape = list(serialized_tensor.size)
+        stats_shape[-1] = 1
+        stats_count = np.prod(stats_shape)
+        means_offset = len(serialized_tensor.buffer) - 2 * stats_count * self.FP32_BYTES
+        stds_offset = len(serialized_tensor.buffer) - stats_count * self.FP32_BYTES
+
+        array = np.frombuffer(serialized_tensor.buffer, dtype=np.float16, count=np.prod(serialized_tensor.size))
+        means = np.frombuffer(serialized_tensor.buffer, dtype=np.float32, offset=means_offset, count=stats_count)
+        stds = np.frombuffer(serialized_tensor.buffer, dtype=np.float32, offset=stds_offset, count=stats_count)
+
+        means = torch.as_tensor(means).reshape(stats_shape)
+        stds = torch.as_tensor(stds).reshape(stats_shape)
+        tensor = torch.as_tensor(np.asarray(array, dtype=serialized_tensor.dtype)).reshape(
+            list(serialized_tensor.size)
+        )
+        return tensor.mul_(stds).add_(means)
+
+
+def get_num_bits(dtype: torch.dtype) -> int:
+    if dtype == torch.bool:
+        return 8  # see https://github.com/pytorch/pytorch/issues/41571
+    elif dtype.is_floating_point:
+        return torch.finfo(dtype).bits
+    else:
+        try:
+            return torch.iinfo(dtype).bits
+        except TypeError:
+            raise TypeError(f"Could not infer size for tensor type {dtype}")

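Both strategies serialize to 16 bits per value (plus per-row float32 mean and std for the scaled variant) and are lossy; a small roundtrip sketch with a toy tensor:

```python
import torch
from hivemind.compression import CompressionInfo, Float16Compression, ScaledFloat16Compression

tensor = torch.randn(2, 2048)
info = CompressionInfo.from_tensor(tensor, key=0)

for strategy in (Float16Compression(), ScaledFloat16Compression()):
    restored = strategy.extract(strategy.compress(tensor, info))
    print(type(strategy).__name__, (restored - tensor).abs().max().item())  # small but nonzero error
```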
+ 114 - 0
hivemind/compression/quantization.py

@@ -0,0 +1,114 @@
+import math
+import os
+from abc import ABC, abstractmethod
+from concurrent.futures import ThreadPoolExecutor
+from typing import Tuple
+
+import numpy as np
+import torch
+
+from hivemind.compression.base import CompressionBase, CompressionInfo
+from hivemind.proto import runtime_pb2
+
+EXECUTOR = ThreadPoolExecutor(max_workers=int(os.environ.get("QUANTIZATION_THREADS", 128)))
+
+
+class Quantization(CompressionBase, ABC):
+    codebook_dtype, indices_dtype = np.float32, np.uint8
+
+    @abstractmethod
+    def quantize(self, tensor: torch.Tensor, allow_inplace: bool = False) -> Tuple[np.ndarray, np.ndarray]:
+        """Convert tensor into a pair of (indices, codebook)"""
+        ...
+
+    def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
+        quantized, codebook = self.quantize(tensor.detach(), allow_inplace=allow_inplace)
+        return runtime_pb2.Tensor(
+            compression=self.compression_type,
+            buffer=b"".join((np.int64(len(codebook)).tobytes(), codebook.tobytes(), quantized.tobytes())),
+            size=tensor.shape,
+            dtype=tensor.numpy().dtype.name,
+            requires_grad=tensor.requires_grad,
+        )
+
+    def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
+        codebook_size = int(np.frombuffer(serialized_tensor.buffer, count=1, dtype=np.int64))
+        codebook = np.frombuffer(serialized_tensor.buffer, offset=8, count=codebook_size, dtype=self.codebook_dtype)
+        quantized = np.frombuffer(serialized_tensor.buffer, offset=8 + codebook.nbytes, dtype=self.indices_dtype)
+        quantized = torch.as_tensor(quantized, dtype=torch.int64).reshape(tuple(serialized_tensor.size))
+        codebook = torch.as_tensor(np.asarray(codebook, dtype=serialized_tensor.dtype))
+        return codebook[quantized]
+
+    def estimate_compression_ratio(self, info: CompressionInfo) -> float:
+        return self.n_bits / torch.finfo(info.descriptor.dtype).bits
+
+    @property
+    def n_bits(self):
+        return self.indices_dtype(1).itemsize * 8
+
+    @property
+    def n_bins(self):
+        return 2 ** self.n_bits
+
+
+class Uniform8BitQuantization(Quantization):
+    RANGE_IN_SIGMAS: int = 6
+    compression_type = runtime_pb2.UNIFORM_8BIT
+
+    def quantize(self, tensor: torch.Tensor, allow_inplace: bool = False) -> Tuple[np.ndarray, np.ndarray]:
+        offset = self.n_bins // 2
+        shift = tensor.mean()
+        centered_tensor = tensor.sub_(shift) if allow_inplace else tensor - shift
+        std_unbiased = centered_tensor.norm() / math.sqrt(centered_tensor.numel() - 1)
+        scale = self.RANGE_IN_SIGMAS * std_unbiased / self.n_bins
+        quantized = torch.quantize_per_tensor(centered_tensor, scale, offset, torch.quint8).int_repr()
+        lookup = average_buckets(tensor, quantized, self.n_bins)
+        return np.asarray(quantized, dtype=self.indices_dtype), np.asarray(lookup, dtype=self.codebook_dtype)
+
+
+class Quantile8BitQuantization(Quantization):
+    compression_type = runtime_pb2.QUANTILE_8BIT
+
+    def quantize(self, tensor: torch.Tensor, allow_inplace: bool = False) -> Tuple[np.ndarray, np.ndarray]:
+        tensor = tensor.detach().float()
+        borders = torch.as_tensor(quantile_qq_approximation(tensor.numpy(), self.n_bins + 1)[1:-1])
+        quantized = torch.clamp_(torch.bucketize(tensor, borders), 0, self.n_bins - 1)
+        codebook = average_buckets(tensor, quantized, self.n_bins)
+        return quantized.numpy().astype(np.uint8), codebook.numpy()
+
+
+def average_buckets(tensor: torch.Tensor, quant_weight: torch.Tensor, n_bins: int):
+    """Return the average value in each bucket"""
+    bin_sums = torch.zeros(n_bins).scatter_add_(0, quant_weight.flatten().long(), tensor.flatten())
+    bin_counts = torch.clamp_min_(torch.bincount(quant_weight.flatten(), minlength=n_bins), 1)
+    lookup = bin_sums / bin_counts
+    return lookup
+
+
+def get_chunk_size(num_elements: int, min_chunk_size: int) -> int:
+    """Adjust chunk_size to minimize imbalance between chunk sizes"""
+    if min_chunk_size >= num_elements:
+        return min_chunk_size
+    leftover_elements = num_elements % min_chunk_size
+    num_chunks = num_elements // min_chunk_size
+    return min_chunk_size + (leftover_elements - 1) // num_chunks + 1
+
+
+def quantile_qq_approximation(array: np.ndarray, n_quantiles: int, min_chunk_size: int = 10 ** 5) -> np.ndarray:
+    """Estimate uniform quantiles of data using quantile-of-quantiles. Runs in parallel."""
+    if not array.data.c_contiguous and array.data.f_contiguous:
+        array = array.T
+    array = np.ascontiguousarray(array.reshape(-1))
+    quantiles = np.linspace(0.0, 1.0, num=n_quantiles, dtype=array.dtype)
+    chunk_size = get_chunk_size(len(array), min_chunk_size)
+    num_chunks = (len(array) - 1) // chunk_size + 1
+    partition_quantiles = np.empty((num_chunks, len(quantiles)), dtype=array.dtype)
+
+    jobs = []
+    for i in range(num_chunks):
+        chunk = slice(chunk_size * i, chunk_size * (i + 1))
+        jobs.append(EXECUTOR.submit(np.quantile, array[chunk], quantiles, out=partition_quantiles[i]))
+
+    for job in jobs:
+        job.result()
+    return np.quantile(partition_quantiles, quantiles)

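A roundtrip sketch for the 8-bit strategies; the serialized buffer stores a 256-entry float32 codebook plus one byte per value, hence roughly a 4x reduction over float32 for large tensors:

```python
import torch
from hivemind.compression import CompressionInfo, Uniform8BitQuantization

tensor = torch.randn(32, 1024)
info = CompressionInfo.from_tensor(tensor, key="grad")

quantizer = Uniform8BitQuantization()
serialized = quantizer.compress(tensor, info)
restored = quantizer.extract(serialized)

print(len(serialized.buffer), tensor.numel() * tensor.element_size())  # ~8 + 1024 + 32768 bytes vs 131072
print((restored - tensor).abs().mean().item())                         # quantization error
```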
+ 0 - 1
hivemind/dht/__init__.py

@@ -17,7 +17,6 @@ from __future__ import annotations
 import asyncio
 import multiprocessing as mp
 import os
-from concurrent.futures import ThreadPoolExecutor
 from functools import partial
 from typing import Awaitable, Callable, Iterable, List, Optional, Sequence, TypeVar, Union
 

+ 1 - 1
hivemind/moe/client/expert.py

@@ -5,9 +5,9 @@ import torch
 import torch.nn as nn
 from torch.autograd.function import once_differentiable
 
+from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
 from hivemind.utils import Endpoint, nested_compare, nested_flatten, nested_pack
-from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.grpc import ChannelCache
 
 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert

+ 1 - 1
hivemind/moe/client/moe.py

@@ -10,12 +10,12 @@ import torch.nn as nn
 from torch.autograd.function import once_differentiable
 
 import hivemind
+from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.moe.client.beam_search import MoEBeamSearcher
 from hivemind.moe.client.expert import DUMMY, RemoteExpert, _get_expert_stub
 from hivemind.moe.server.expert_uid import UID_DELIMITER
 from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
 from hivemind.utils import nested_flatten, nested_map, nested_pack
-from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.logging import get_logger
 
 logger = get_logger(__name__)

+ 1 - 1
hivemind/moe/server/connection_handler.py

@@ -6,11 +6,11 @@ from typing import Dict
 import grpc
 import torch
 
+from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.moe.server.expert_backend import ExpertBackend
 from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
 from hivemind.utils import Endpoint, get_logger, nested_flatten
 from hivemind.utils.asyncio import switch_to_uvloop
-from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.grpc import GRPC_KEEPALIVE_OPTIONS
 
 logger = get_logger(__name__)

+ 0 - 1
hivemind/utils/__init__.py

@@ -1,5 +1,4 @@
 from hivemind.utils.asyncio import *
-from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.grpc import *
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger

+ 0 - 209
hivemind/utils/compression.py

@@ -1,209 +0,0 @@
-import os
-import warnings
-from concurrent.futures import ThreadPoolExecutor
-from typing import Optional, Sequence, Tuple
-
-import numpy as np
-import torch
-
-from hivemind.proto import runtime_pb2
-from hivemind.proto.runtime_pb2 import CompressionType
-
-FP32_EPS = 1e-06
-NUM_BYTES_FLOAT32 = 4
-NUM_BYTES_FLOAT16 = 2
-NUM_BITS_QUANTILE_COMPRESSION = 8
-NUM_COMPRESSION_QUANTILES = 2 ** NUM_BITS_QUANTILE_COMPRESSION
-UNIFORM_BUCKETS_STD_RANGE = 6
-FP16_MAX = 65_504
-UINT8_RANGE = 256
-
-COMPRESSION_EXECUTOR = ThreadPoolExecutor(max_workers=int(os.environ.get("QUANTILE_COMPRESSION_THREADS", 128)))
-
-warnings.filterwarnings("ignore", message="The given NumPy array is not writeable", category=UserWarning)
-
-
-def _quantile_encode_approx(tensor: torch.Tensor, n_bits: int) -> Tuple[torch.Tensor, torch.Tensor]:
-    n_bins = 2 ** n_bits
-    borders = torch.as_tensor(_quantile_qq_approximation(tensor.numpy(), n_bins + 1)[1:-1])
-    quant_weight = torch.clamp_(torch.bucketize(tensor, borders), 0, n_bins - 1)
-    lookup = average_buckets(tensor, quant_weight, n_bins)
-    return quant_weight, lookup
-
-
-def average_buckets(tensor: torch.Tensor, quant_weight: torch.Tensor, n_bins: int):
-    bin_sums = torch.zeros(n_bins).scatter_add_(0, quant_weight.flatten().long(), tensor.flatten())
-    bin_counts = torch.clamp_min_(torch.bincount(quant_weight.flatten(), minlength=n_bins), 1)
-    lookup = bin_sums / bin_counts
-    return lookup
-
-
-def _quantile_qq_approximation(array: np.array, n_quantiles: int, min_chunk_size: int = 10 ** 5) -> np.ndarray:
-    """Estimate uniform quantiles of data using quantile-of-quantiles. Runs in parallel."""
-    if not array.data.c_contiguous and array.data.f_contiguous:
-        array = array.T
-    array = np.ascontiguousarray(array.reshape(-1))
-    quantiles = np.linspace(0.0, 1.0, num=n_quantiles, dtype=array.dtype)
-    chunk_size = _get_chunk_size(len(array), min_chunk_size)
-    num_chunks = (len(array) - 1) // chunk_size + 1
-    partition_quantiles = np.empty((num_chunks, len(quantiles)), dtype=array.dtype)
-
-    jobs = []
-    for i in range(num_chunks):
-        chunk = slice(chunk_size * i, chunk_size * (i + 1))
-        jobs.append(COMPRESSION_EXECUTOR.submit(np.quantile, array[chunk], quantiles, out=partition_quantiles[i]))
-
-    for job in jobs:
-        job.result()
-    return np.quantile(partition_quantiles, quantiles)
-
-
-def _get_chunk_size(num_elements: int, min_chunk_size: int) -> int:
-    """Adjust chunk_size to minimize imbalance between chunk sizes"""
-    if min_chunk_size >= num_elements:
-        return min_chunk_size
-    leftover_elements = num_elements % min_chunk_size
-    num_chunks = num_elements // min_chunk_size
-    return min_chunk_size + (leftover_elements - 1) // num_chunks + 1
-
-
-def _uint8_uniform_buckets_encode(tensor: torch.Tensor, range_in_sigmas: float):
-    offset = UINT8_RANGE // 2
-    shift = tensor.mean()
-    scale = range_in_sigmas * tensor.std() / UINT8_RANGE
-
-    quant_weight = torch.quantize_per_tensor(tensor - shift, scale, offset, torch.quint8).int_repr()
-    lookup = average_buckets(tensor, quant_weight, UINT8_RANGE)
-    return quant_weight, lookup
-
-
-def serialize_torch_tensor(
-    tensor: torch.Tensor, compression_type=CompressionType.NONE, allow_inplace=False
-) -> runtime_pb2.Tensor:
-    assert tensor.device == torch.device("cpu")
-    if compression_type == CompressionType.MEANSTD_16BIT:
-        assert tensor.dtype == torch.float32
-
-        tensor = tensor if allow_inplace else tensor.clone()
-        means = torch.mean(tensor, dim=-1, keepdim=True)
-        tensor.sub_(means)
-
-        stds = torch.square(tensor).sum(dim=-1, keepdim=True).div_(tensor.shape[-1]).sqrt_()
-        stds.clamp_min_(FP32_EPS)
-        tensor.div_(stds)
-        tensor = tensor.clamp_(-FP16_MAX, FP16_MAX).to(torch.float16)
-
-        data = b"".join((tensor.numpy().tobytes(), means.numpy().tobytes(), stds.numpy().tobytes()))
-
-        proto = runtime_pb2.Tensor(
-            compression=compression_type,
-            buffer=data,
-            size=tensor.shape,
-            dtype="compressed_float32",
-            requires_grad=tensor.requires_grad,
-        )
-    elif compression_type == CompressionType.FLOAT16:
-        assert tensor.dtype == torch.float32
-
-        tensor = tensor if allow_inplace else tensor.clone()
-        tensor = tensor.clamp_(-FP16_MAX, FP16_MAX).to(torch.float16)
-
-        data = tensor.numpy().tobytes()
-
-        proto = runtime_pb2.Tensor(
-            compression=compression_type,
-            buffer=data,
-            size=tensor.shape,
-            dtype="clamped_float32",
-            requires_grad=tensor.requires_grad,
-        )
-    elif compression_type == CompressionType.NONE:
-        array = tensor.numpy()
-        proto = runtime_pb2.Tensor(
-            compression=compression_type,
-            buffer=array.tobytes(),
-            size=array.shape,
-            dtype=array.dtype.name,
-            requires_grad=tensor.requires_grad,
-        )
-    elif compression_type in (CompressionType.QUANTILE_8BIT, CompressionType.UNIFORM_8BIT):
-        assert tensor.dtype == torch.float32
-
-        if compression_type == CompressionType.QUANTILE_8BIT:
-            quantized, lookup = _quantile_encode_approx(tensor.detach(), NUM_BITS_QUANTILE_COMPRESSION)
-        elif compression_type == CompressionType.UNIFORM_8BIT:
-            quantized, lookup = _uint8_uniform_buckets_encode(tensor.detach(), UNIFORM_BUCKETS_STD_RANGE)
-        data = b"".join((lookup.numpy().tobytes(), quantized.numpy().astype(np.uint8).tobytes()))
-
-        proto = runtime_pb2.Tensor(
-            compression=compression_type,
-            buffer=data,
-            size=tensor.shape,
-            dtype="compressed_float32",
-            requires_grad=tensor.requires_grad,
-        )
-    else:
-        raise ValueError(f"Unknown compression type: {compression_type}")
-
-    return proto
-
-
-def construct_torch_tensor(array: np.ndarray, size: Sequence, dtype: Optional[torch.dtype] = None):
-    """Helper conversion function that handles edge case with scalar deserialization"""
-    if size:
-        return torch.as_tensor(array, dtype=dtype).view(*size)
-    else:
-        return torch.as_tensor(array, dtype=dtype)
-
-
-def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
-    if serialized_tensor.compression == CompressionType.NONE:
-        array = np.frombuffer(serialized_tensor.buffer, dtype=np.dtype(serialized_tensor.dtype))
-        tensor = construct_torch_tensor(array, serialized_tensor.size)
-
-    elif serialized_tensor.compression == CompressionType.MEANSTD_16BIT:
-        stats_size = list(serialized_tensor.size)
-        stats_size[-1] = 1
-        stats_count = np.prod(stats_size)
-
-        means = serialized_tensor.buffer[-2 * NUM_BYTES_FLOAT32 * stats_count : -NUM_BYTES_FLOAT32 * stats_count]
-        stds = serialized_tensor.buffer[-NUM_BYTES_FLOAT32 * stats_count :]
-        means = construct_torch_tensor(np.frombuffer(means, dtype=np.float32), stats_size)
-        stds = construct_torch_tensor(np.frombuffer(stds, dtype=np.float32), stats_size)
-
-        array = np.frombuffer(serialized_tensor.buffer[: -8 * stats_count], dtype=np.float16)
-        tensor = construct_torch_tensor(array, serialized_tensor.size, torch.float32).mul_(stds).add_(means)
-
-    elif serialized_tensor.compression == CompressionType.FLOAT16:
-        array = np.frombuffer(serialized_tensor.buffer, dtype=np.float16)
-        tensor = construct_torch_tensor(array, serialized_tensor.size, torch.float32)
-
-    elif serialized_tensor.compression in (CompressionType.QUANTILE_8BIT, CompressionType.UNIFORM_8BIT):
-        if serialized_tensor.compression == CompressionType.QUANTILE_8BIT:
-            lookup_size = NUM_COMPRESSION_QUANTILES * NUM_BYTES_FLOAT32
-        else:
-            lookup_size = UINT8_RANGE * NUM_BYTES_FLOAT32
-        lookup = serialized_tensor.buffer[:lookup_size]
-        quantized = serialized_tensor.buffer[lookup_size:]
-        lookup = torch.as_tensor(np.frombuffer(lookup, dtype=np.float32))
-        quantized = np.frombuffer(quantized, dtype=np.uint8)
-        quantized = construct_torch_tensor(quantized, serialized_tensor.size, dtype=torch.int64)
-        tensor = lookup[quantized]
-
-    else:
-        raise ValueError(f"Unknown compression type: {serialized_tensor.compression}")
-
-    tensor.requires_grad_(serialized_tensor.requires_grad)
-    return tensor
-
-
-def get_nbytes_per_value(dtype: torch.dtype, compression: CompressionType) -> int:
-    """returns the number of bytes per value for a given tensor (excluding metadata)"""
-    if compression in (CompressionType.QUANTILE_8BIT, CompressionType.UNIFORM_8BIT):
-        return 1
-    elif compression in (CompressionType.FLOAT16, CompressionType.MEANSTD_16BIT):
-        return 2
-    elif compression == CompressionType.NONE:
-        return torch.finfo(dtype).bits // 8
-    else:
-        raise NotImplementedError(f"Unknown compression type: {CompressionType.Name(compression)}")
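The per-value byte counts above (1 byte for the 8-bit schemes, 2 for the 16-bit ones, 4 for raw float32) are what allow the load balancer to account for each tensor's compression ratio. A rough illustration with made-up sizes (plain Python, not hivemind code):

```python
# Approximate traffic per peer for a 10M-element tensor under each scheme.
bytes_per_value = {"NONE (float32)": 4, "FLOAT16 / MEANSTD_16BIT": 2, "QUANTILE_8BIT / UNIFORM_8BIT": 1}
numel = 10_000_000
for scheme, nbytes in bytes_per_value.items():
    print(f"{scheme}: {numel * nbytes / 1e6:.0f} MB")
```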

+ 11 - 4
hivemind/utils/tensor_descr.py

@@ -1,6 +1,10 @@
+from __future__ import annotations
+
 import warnings
 from dataclasses import asdict, dataclass
+from typing import Tuple
 
+import numpy as np
 import torch
 
 from hivemind.proto.runtime_pb2 import CompressionType
@@ -29,11 +33,14 @@ class TensorDescriptor(DescriptorBase):
     compression: CompressionType = CompressionType.NONE
 
     @property
-    def shape(self):
+    def shape(self) -> Tuple[int, ...]:
         return self.size
 
+    def numel(self) -> int:
+        return int(np.prod(self.size))
+
     @classmethod
-    def from_tensor(cls, tensor: torch.Tensor):
+    def from_tensor(cls, tensor: torch.Tensor) -> TensorDescriptor:
         return cls(
             tensor.shape, tensor.dtype, tensor.layout, tensor.device, tensor.requires_grad, _safe_check_pinned(tensor)
         )
@@ -55,7 +62,7 @@ class BatchTensorDescriptor(TensorDescriptor):
         super().__init__((None, *instance_size), **kwargs)
 
     @classmethod
-    def from_tensor(cls, tensor: torch.Tensor, compression=CompressionType.NONE):
+    def from_tensor(cls, tensor: torch.Tensor, compression=CompressionType.NONE) -> BatchTensorDescriptor:
         return cls(
             *tensor.shape[1:],
             dtype=tensor.dtype,
@@ -66,7 +73,7 @@ class BatchTensorDescriptor(TensorDescriptor):
             compression=compression if tensor.is_floating_point() else CompressionType.NONE
         )
 
-    def make_empty(self, *batch_size, **kwargs):
+    def make_empty(self, *batch_size: int, **kwargs) -> torch.Tensor:
         assert self.shape[0] is None, "Make sure 0-th dimension is not specified (set to None)"
         return super().make_empty(size=(*batch_size, *self.shape[1:]), **kwargs)
 

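For context, `numel()` gives the element count straight from the descriptor (useful when estimating compressed part sizes), and `make_empty` now carries explicit annotations. A small usage sketch (illustrative only; it assumes the base `TensorDescriptor.make_empty(**kwargs)` allocates a tensor from the stored properties, as the call above implies):

```python
import torch

from hivemind.utils.tensor_descr import BatchTensorDescriptor, TensorDescriptor

descr = TensorDescriptor.from_tensor(torch.randn(4, 16))
print(descr.numel())  # 64, i.e. int(np.prod(size))

batch_descr = BatchTensorDescriptor.from_tensor(torch.randn(8, 64))
empty = batch_descr.make_empty(16)  # the batch dimension is supplied at call time
print(tuple(empty.shape))  # (16, 64)
```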
+ 3 - 3
tests/test_allreduce.py

@@ -6,12 +6,12 @@ from typing import Sequence
 import pytest
 import torch
 
-from hivemind import aenumerate
+from hivemind import Quantile8BitQuantization, aenumerate
 from hivemind.averaging.allreduce import AllReduceRunner, AveragingMode
 from hivemind.averaging.partition import TensorPartContainer, TensorPartReducer
+from hivemind.compression import deserialize_torch_tensor
 from hivemind.p2p import P2P, StubBase
 from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.utils import deserialize_torch_tensor
 
 
 @pytest.mark.forked
@@ -83,7 +83,7 @@ async def test_partitioning_asynchronous():
     tensors = [torch.randn(2048, 2048), torch.randn(1024, 4096), torch.randn(4096, 1024), torch.randn(30_000, 1024)]
     peer_fractions = [0.4, 0.3, 0.2, 0.1]
 
-    partition = TensorPartContainer(tensors, peer_fractions, compression_type=CompressionType.QUANTILE_8BIT)
+    partition = TensorPartContainer(tensors, peer_fractions, compression=Quantile8BitQuantization())
     read_started, read_finished = asyncio.Event(), asyncio.Event()
 
     async def write_tensors():

+ 0 - 57
tests/test_averaging.py

@@ -1,5 +1,4 @@
 import random
-import time
 
 import numpy as np
 import pytest
@@ -12,7 +11,6 @@ from hivemind.averaging.key_manager import GroupKeyManager
 from hivemind.averaging.load_balancing import load_balance_peers
 from hivemind.averaging.partition import AllreduceException
 from hivemind.p2p import PeerID
-from hivemind.proto.runtime_pb2 import CompressionType
 
 from test_utils.dht_swarms import launch_dht_instances
 
@@ -169,61 +167,6 @@ def test_allreduce_weighted(n_client_mode_peers: int = 2):
         process.shutdown()
 
 
-@pytest.mark.forked
-def test_allreduce_compression():
-    """this test ensures that compression works correctly when multiple tensors have different compression types"""
-
-    tensors1 = [torch.linspace(0, 500, 1000) ** 0.5, torch.randn(1000)]
-    tensors2 = [torch.linspace(300, 800, 1000) ** 0.5, torch.randn(1000)]
-    results = {}
-
-    FLOAT16, UINT8 = CompressionType.FLOAT16, CompressionType.UNIFORM_8BIT
-
-    for compression_type_pair in [(FLOAT16, FLOAT16), (FLOAT16, UINT8), (UINT8, FLOAT16), (UINT8, UINT8)]:
-        dht_instances = launch_dht_instances(2)
-        averager1 = hivemind.averaging.DecentralizedAverager(
-            [x.clone() for x in tensors1],
-            dht=dht_instances[0],
-            compression_type=compression_type_pair,
-            client_mode=True,
-            target_group_size=2,
-            prefix="mygroup",
-            start=True,
-        )
-        averager2 = hivemind.averaging.DecentralizedAverager(
-            [x.clone() for x in tensors2],
-            dht=dht_instances[1],
-            compression_type=compression_type_pair,
-            target_group_size=2,
-            prefix="mygroup",
-            start=True,
-        )
-
-        for future in averager1.step(wait=False), averager2.step(wait=False):
-            future.result()
-
-        with averager1.get_tensors() as averaged_tensors:
-            results[compression_type_pair] = averaged_tensors
-
-        for instance in [averager1, averager2] + dht_instances:
-            instance.shutdown()
-
-    assert torch.allclose(results[UINT8, FLOAT16][0], results[UINT8, UINT8][0])
-    assert torch.allclose(results[UINT8, FLOAT16][1], results[FLOAT16, FLOAT16][1])
-    assert torch.allclose(results[UINT8, UINT8][1], results[FLOAT16, UINT8][1])
-    assert torch.allclose(results[FLOAT16, UINT8][0], results[FLOAT16, FLOAT16][0])
-
-    assert not torch.allclose(results[UINT8, FLOAT16][1], results[UINT8, UINT8][1])
-    assert not torch.allclose(results[UINT8, FLOAT16][0], results[FLOAT16, FLOAT16][0])
-    assert not torch.allclose(results[UINT8, UINT8][0], results[FLOAT16, UINT8][0])
-    assert not torch.allclose(results[FLOAT16, UINT8][1], results[FLOAT16, FLOAT16][1])
-
-    reference = [(tensors1[i] + tensors2[i]) / 2 for i in range(len(tensors1))]
-    for i in range(2):
-        assert 0 < torch.mean(torch.square(results[FLOAT16, FLOAT16][i] - reference[i])).item() <= 1e-5
-        assert 1e-5 < torch.mean(torch.square(results[UINT8, UINT8][i] - reference[i])).item() <= 1e-2
-
-
 def compute_mean_std(averagers, unbiased=True):
     results = []
     for averager in averagers:

+ 213 - 0
tests/test_compression.py

@@ -0,0 +1,213 @@
+import multiprocessing as mp
+from ctypes import c_int32
+
+import pytest
+import torch
+import torch.nn as nn
+
+import hivemind
+from hivemind.compression import (
+    CompressionBase,
+    CompressionInfo,
+    Float16Compression,
+    NoCompression,
+    PerTensorCompression,
+    RoleAdaptiveCompression,
+    SizeAdaptiveCompression,
+    Uniform8BitQuantization,
+    deserialize_torch_tensor,
+    serialize_torch_tensor,
+)
+from hivemind.compression.adaptive import AdaptiveCompressionBase
+from hivemind.proto.runtime_pb2 import CompressionType
+
+from test_utils.dht_swarms import launch_dht_instances
+
+
+@pytest.mark.forked
+def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
+    torch.manual_seed(0)
+    X = torch.randn(*size)
+    assert torch.allclose(deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.NONE)), X)
+    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.MEANSTD_16BIT)) - X
+    assert error.square().mean() < alpha
+    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.FLOAT16)) - X
+    assert error.square().mean() < alpha
+    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.QUANTILE_8BIT)) - X
+    assert error.square().mean() < beta
+    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.UNIFORM_8BIT)) - X
+    assert error.square().mean() < beta
+
+    zeros = torch.zeros(5, 5)
+    for compression_type in CompressionType.values():
+        assert deserialize_torch_tensor(serialize_torch_tensor(zeros, compression_type)).isfinite().all()
+
+
+@pytest.mark.forked
+def test_serialize_tensor():
+    def _check(tensor, compression, rtol=1e-5, atol=1e-8, chunk_size=30 * 1024):
+        serialized_tensor = serialize_torch_tensor(tensor, compression)
+        chunks = list(hivemind.split_for_streaming(serialized_tensor, chunk_size))
+        assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
+        restored = hivemind.combine_from_streaming(chunks)
+        assert torch.allclose(deserialize_torch_tensor(restored), tensor, rtol=rtol, atol=atol)
+
+    tensor = torch.randn(512, 12288)
+    for chunk_size in [1024, 64 * 1024, 64 * 1024 + 1, 10 ** 9]:
+        _check(tensor, CompressionType.NONE, chunk_size=chunk_size)
+
+    _check(tensor, CompressionType.FLOAT16, rtol=0.0, atol=1e-2)
+    _check(torch.randint(0, 100, (512, 1, 1)), CompressionType.NONE)
+    _check(torch.tensor(1.0), CompressionType.NONE)
+    _check(torch.tensor(1.0), CompressionType.FLOAT16)
+
+
+@pytest.mark.forked
+def test_allreduce_compression():
+    """this test ensures that compression works correctly when multiple tensors have different compression types"""
+
+    tensors1 = [torch.linspace(0, 500, 1000) ** 0.5, torch.randn(1000)]
+    tensors2 = [torch.linspace(300, 800, 1000) ** 0.5, torch.randn(1000)]
+    results = {}
+
+    FLOAT16, UINT8 = Float16Compression(), Uniform8BitQuantization()
+
+    for compression_type_pair in [(FLOAT16, FLOAT16), (FLOAT16, UINT8), (UINT8, FLOAT16), (UINT8, UINT8)]:
+        dht_instances = launch_dht_instances(2)
+        averager1 = hivemind.averaging.DecentralizedAverager(
+            [x.clone() for x in tensors1],
+            dht=dht_instances[0],
+            compression=PerTensorCompression(compression_type_pair),
+            client_mode=True,
+            target_group_size=2,
+            prefix="mygroup",
+            start=True,
+        )
+        averager2 = hivemind.averaging.DecentralizedAverager(
+            [x.clone() for x in tensors2],
+            dht=dht_instances[1],
+            compression=PerTensorCompression(compression_type_pair),
+            target_group_size=2,
+            prefix="mygroup",
+            start=True,
+        )
+
+        for future in averager1.step(wait=False), averager2.step(wait=False):
+            future.result()
+
+        with averager1.get_tensors() as averaged_tensors:
+            results[compression_type_pair] = averaged_tensors
+
+        for instance in [averager1, averager2] + dht_instances:
+            instance.shutdown()
+
+    assert torch.allclose(results[UINT8, FLOAT16][0], results[UINT8, UINT8][0])
+    assert torch.allclose(results[UINT8, FLOAT16][1], results[FLOAT16, FLOAT16][1])
+    assert torch.allclose(results[UINT8, UINT8][1], results[FLOAT16, UINT8][1])
+    assert torch.allclose(results[FLOAT16, UINT8][0], results[FLOAT16, FLOAT16][0])
+
+    assert not torch.allclose(results[UINT8, FLOAT16][1], results[UINT8, UINT8][1])
+    assert not torch.allclose(results[UINT8, FLOAT16][0], results[FLOAT16, FLOAT16][0])
+    assert not torch.allclose(results[UINT8, UINT8][0], results[FLOAT16, UINT8][0])
+    assert not torch.allclose(results[FLOAT16, UINT8][1], results[FLOAT16, FLOAT16][1])
+
+    reference = [(tensors1[i] + tensors2[i]) / 2 for i in range(len(tensors1))]
+    for i in range(2):
+        assert 0 < torch.mean(torch.square(results[FLOAT16, FLOAT16][i] - reference[i])).item() <= 1e-5
+        assert 1e-5 < torch.mean(torch.square(results[UINT8, UINT8][i] - reference[i])).item() <= 1e-2
+
+
+class TrackedCompression(AdaptiveCompressionBase):
+    def __init__(self, compression: CompressionBase):
+        self.compression = compression
+        self.mp_counter, self.mp_part_size = mp.Value(c_int32, 0), mp.Value(c_int32, 0)
+        super().__init__()
+
+    def choose_compression(self, info: CompressionInfo) -> CompressionBase:
+        return self.compression
+
+    def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False):
+        self.mp_counter.value += 1
+        if info.part_size is not None:
+            self.mp_part_size.value = max(self.mp_part_size.value, info.part_size)
+        return self.compression.compress(tensor, info=info, allow_inplace=allow_inplace)
+
+
+def make_params():
+    return [
+        nn.Parameter(x)
+        for x in (
+            torch.randn([]),
+            torch.randn(1),
+            torch.randn(100),
+            torch.randn(1_000),
+            torch.randn(5_000),
+            torch.randn(10_000),
+        )
+    ]
+
+
+@pytest.mark.forked
+def test_adaptive_compression():
+    UINT8 = TrackedCompression(Uniform8BitQuantization())
+    FLOAT16 = TrackedCompression(Float16Compression())
+    FLOAT32 = TrackedCompression(NoCompression())
+    STATE_FP16 = TrackedCompression(Float16Compression())
+    STATE_FP32 = TrackedCompression(NoCompression())
+
+    averaging_compression_adaptive = RoleAdaptiveCompression(
+        parameter=FLOAT16,
+        gradient=SizeAdaptiveCompression(threshold=1_000, less=FLOAT16, greater_equal=UINT8),
+        optimizer=FLOAT32,
+        default=FLOAT32,
+    )
+
+    state_compression_adaptive = SizeAdaptiveCompression(
+        threshold=500,
+        less=STATE_FP32,
+        greater_equal=STATE_FP16,
+    )
+
+    averager1 = hivemind.TrainingAverager(
+        opt=torch.optim.Adam(make_params()),
+        average_parameters=True,
+        average_gradients=True,
+        average_opt_statistics=("exp_avg",),
+        compression=averaging_compression_adaptive,
+        state_compression=state_compression_adaptive,
+        prefix="test_avgr",
+        target_group_size=2,
+        part_size_bytes=5_000,
+        start=True,
+        dht=hivemind.DHT(start=True),
+    )
+
+    averager2 = hivemind.TrainingAverager(
+        opt=torch.optim.Adam(make_params()),
+        average_parameters=True,
+        average_gradients=True,
+        average_opt_statistics=("exp_avg",),
+        compression=averaging_compression_adaptive,
+        state_compression=state_compression_adaptive,
+        prefix="test_avgr",
+        target_group_size=2,
+        part_size_bytes=5_000,
+        start=True,
+        dht=hivemind.DHT(initial_peers=averager1.dht.get_visible_maddrs(), start=True),
+    )
+
+    futures = [averager1.step(wait=False), averager2.step(wait=False)]
+
+    for future in futures:
+        future.result()
+
+    assert UINT8.mp_counter.value == 4  # half gradients: 3 tensors, 1 is split
+    assert UINT8.mp_part_size.value == 5_000  # single byte tensors
+    assert FLOAT16.mp_counter.value == 13  # parameters and half gradients
+    assert FLOAT16.mp_part_size.value == 2_500  # two-byte tensors
+    assert FLOAT32.mp_counter.value == 16  # statistics
+    assert FLOAT32.mp_part_size.value == 1250  # four-byte tensors
+
+    averager1.load_state_from_peers()
+    assert STATE_FP16.mp_counter.value == STATE_FP32.mp_counter.value == 9
+    assert STATE_FP16.mp_part_size.value == STATE_FP32.mp_part_size.value == 0  # not partitioned

+ 1 - 51
tests/test_util_modules.py

@@ -9,6 +9,7 @@ import pytest
 import torch
 
 import hivemind
+from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.proto.dht_pb2_grpc import DHTStub
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
@@ -25,7 +26,6 @@ from hivemind.utils.asyncio import (
     azip,
     cancel_and_wait,
 )
-from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.mpfuture import InvalidStateError
 
 
@@ -323,24 +323,6 @@ def test_many_futures():
     p.join()
 
 
-def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
-    torch.manual_seed(0)
-    X = torch.randn(*size)
-    assert torch.allclose(deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.NONE)), X)
-    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.MEANSTD_16BIT)) - X
-    assert error.square().mean() < alpha
-    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.FLOAT16)) - X
-    assert error.square().mean() < alpha
-    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.QUANTILE_8BIT)) - X
-    assert error.square().mean() < beta
-    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.UNIFORM_8BIT)) - X
-    assert error.square().mean() < beta
-
-    zeros = torch.zeros(5, 5)
-    for compression_type in CompressionType.values():
-        assert deserialize_torch_tensor(serialize_torch_tensor(zeros, compression_type)).isfinite().all()
-
-
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_channel_cache():
@@ -385,38 +367,6 @@ async def test_channel_cache():
             assert (ci is cj) == ((ci, cj) in duplicates), (i, j)
 
 
-def test_serialize_tensor():
-    tensor = torch.randn(512, 12288)
-
-    serialized_tensor = serialize_torch_tensor(tensor, CompressionType.NONE)
-    for chunk_size in [1024, 64 * 1024, 64 * 1024 + 1, 10 ** 9]:
-        chunks = list(hivemind.split_for_streaming(serialized_tensor, chunk_size))
-        assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
-        restored = hivemind.combine_from_streaming(chunks)
-        assert torch.allclose(deserialize_torch_tensor(restored), tensor)
-
-    chunk_size = 30 * 1024
-    serialized_tensor = serialize_torch_tensor(tensor, CompressionType.FLOAT16)
-    chunks = list(hivemind.split_for_streaming(serialized_tensor, chunk_size))
-    assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
-    restored = hivemind.combine_from_streaming(chunks)
-    assert torch.allclose(deserialize_torch_tensor(restored), tensor, rtol=0, atol=1e-2)
-
-    tensor = torch.randint(0, 100, (512, 1, 1))
-    serialized_tensor = serialize_torch_tensor(tensor, CompressionType.NONE)
-    chunks = list(hivemind.split_for_streaming(serialized_tensor, chunk_size))
-    assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
-    restored = hivemind.combine_from_streaming(chunks)
-    assert torch.allclose(deserialize_torch_tensor(restored), tensor)
-
-    scalar = torch.tensor(1.0)
-    serialized_scalar = serialize_torch_tensor(scalar, CompressionType.NONE)
-    assert torch.allclose(deserialize_torch_tensor(serialized_scalar), scalar)
-
-    serialized_scalar = serialize_torch_tensor(scalar, CompressionType.FLOAT16)
-    assert torch.allclose(deserialize_torch_tensor(serialized_scalar), scalar)
-
-
 def test_serialize_tuple():
     test_pairs = (
         ((1, 2, 3), [1, 2, 3]),
         ((1, 2, 3), [1, 2, 3]),