
Merge branch 'master' of github.com:learning-at-home/hivemind into server-p2p

Denis Mazur committed 4 years ago
Commit 6328ffe814

+ 1 - 1
benchmarks/benchmark_averaging.py

@@ -80,7 +80,7 @@ def benchmark_averaging(
             with lock_stats:
                 successful_steps += int(success)
                 total_steps += 1
-            logger.info(f"Averager {index}: {'finished' if success else 'failed'} step {step}")
+            logger.info(f"Averager {index}: {'finished' if success else 'failed'} step #{step}")
         logger.info(f"Averager {index}: done.")
 
     threads = []

+ 1 - 1
benchmarks/benchmark_tensor_compression.py

@@ -3,8 +3,8 @@ import time
 
 import torch
 
+from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.logging import get_logger
 
 logger = get_logger(__name__)
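
Illustrative sketch of the relocated helpers (assumes hivemind with the new hivemind.compression package is installed; values are arbitrary):

    import torch
    from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
    from hivemind.proto.runtime_pb2 import CompressionType

    tensor = torch.randn(128, 128)
    # serialize with float16 compression, then restore an approximate copy
    serialized = serialize_torch_tensor(tensor, CompressionType.FLOAT16)
    restored = deserialize_torch_tensor(serialized)
    assert restored.shape == tensor.shape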

+ 0 - 3
examples/albert/arguments.py

@@ -93,9 +93,6 @@ class CollaborativeOptimizerArguments:
         default=100.0,
         metadata={"help": "Available network bandwidth, in mbps (used for load balancing in all-reduce)"},
     )
-    compression: str = field(
-        default="FLOAT16", metadata={"help": "Use this compression when averaging parameters/gradients"}
-    )
 
 
 @dataclass

+ 12 - 24
examples/albert/run_trainer.py

@@ -1,6 +1,5 @@
 #!/usr/bin/env python
 
-import logging
 import os
 import pickle
 from dataclasses import asdict
@@ -18,33 +17,22 @@ from transformers.trainer import Trainer
 from transformers.trainer_utils import is_main_process
 
 import hivemind
-from hivemind.utils.compression import CompressionType
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
 import utils
 from arguments import AlbertTrainingArguments, AveragerArguments, CollaborationArguments, DatasetArguments
 
-logger = logging.getLogger(__name__)
-LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger()
 
+LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
 
-def setup_logging(training_args):
-    logging.basicConfig(
-        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
-        datefmt="%m/%d/%Y %H:%M:%S",
-        level=logging.INFO if is_main_process(training_args.local_rank) else logging.WARN,
-    )
 
-    # Log on each process the small summary:
-    logger.warning(
-        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
-        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
-    )
-    # Set the verbosity to info of the Transformers logger (on main process only):
-    if is_main_process(training_args.local_rank):
+def setup_transformers_logging(process_rank: int):
+    if is_main_process(process_rank):
         transformers.utils.logging.set_verbosity_info()
-        transformers.utils.logging.enable_default_handler()
-        transformers.utils.logging.enable_explicit_format()
-    logger.info("Training/evaluation parameters %s", training_args)
+        transformers.utils.logging.disable_default_handler()
+        transformers.utils.logging.enable_propagation()
 
 
 def get_model(training_args, config, tokenizer):
@@ -150,7 +138,7 @@ class CollaborativeCallback(transformers.TrainerCallback):
                     loss=self.loss,
                     mini_steps=self.steps,
                 )
-                logger.info(f"Step {self.collaborative_optimizer.local_step}")
+                logger.info(f"Step #{self.collaborative_optimizer.local_step}")
                 logger.info(f"Your current contribution: {self.total_samples_processed} samples")
                 logger.info(f"Performance: {samples_per_second} samples per second.")
                 if self.steps:
@@ -204,7 +192,6 @@ class NoOpScheduler(LRSchedulerBase):
             return self.optimizer.scheduler.print_lr(*args, **kwargs)
 
     def step(self):
-        logger.debug("Called NoOpScheduler.step")
         self._last_lr = self.get_lr()
 
     def state_dict(self):
@@ -222,7 +209,8 @@ def main():
     if len(collaboration_args.initial_peers) == 0:
         raise ValueError("Please specify at least one network endpoint in initial peers.")
 
-    setup_logging(training_args)
+    setup_transformers_logging(training_args.local_rank)
+    logger.info(f"Training/evaluation parameters:\n{training_args}")
 
     # Set seed before initializing model.
     set_seed(training_args.seed)
@@ -263,7 +251,7 @@ def main():
         dht=dht,
         scheduler=scheduler,
         prefix=collaboration_args.experiment_prefix,
-        compression_type=CompressionType.Value(collaboration_args.compression),
+        compression=hivemind.Float16Compression(),
         batch_size_per_step=total_batch_size_per_step,
         bandwidth=collaboration_args.bandwidth,
         target_batch_size=adjusted_target_batch_size,
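
The commit hard-codes Float16Compression() above; a hypothetical variant could combine the adaptive strategies from hivemind.compression instead (threshold and strategy choices below are illustrative, not part of this change):

    import hivemind

    # quantize large tensors to 8 bits, keep small ones in float16
    adaptive_compression = hivemind.SizeAdaptiveCompression(
        threshold=2 ** 16,
        less=hivemind.Float16Compression(),
        greater_equal=hivemind.Uniform8BitQuantization(),
    )
    # then pass compression=adaptive_compression to the collaborative optimizer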

+ 5 - 5
examples/albert/run_training_monitor.py

@@ -1,6 +1,5 @@
 #!/usr/bin/env python
 
-import logging
 import time
 from dataclasses import asdict, dataclass, field
 from ipaddress import ip_address
@@ -13,12 +12,13 @@ from torch_optimizer import Lamb
 from transformers import AlbertConfig, AlbertForPreTraining, HfArgumentParser
 
 import hivemind
-from hivemind.utils.compression import CompressionType
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
 import utils
 from arguments import AveragerArguments, BaseTrainingArguments, CollaborativeOptimizerArguments
 
-logger = logging.getLogger(__name__)
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger()
 
 
 @dataclass
@@ -101,7 +101,7 @@ class CheckpointHandler:
             opt=opt,
             dht=dht,
             prefix=experiment_prefix,
-            compression_type=CompressionType.Value(collab_optimizer_args.compression),
+            compression_type=hivemind.Float16Compression(),
             bandwidth=collab_optimizer_args.bandwidth,
             target_batch_size=adjusted_target_batch_size,
             client_mode=collab_optimizer_args.client_mode,
@@ -140,7 +140,7 @@ class CheckpointHandler:
         self.model.push_to_hub(
             repo_name=self.repo_path,
             repo_url=self.repo_url,
-            commit_message=f"Step {current_step}, loss {current_loss:.3f}",
+            commit_message=f"Step #{current_step}, loss {current_loss:.3f}",
         )
         logger.info("Finished uploading to Model Hub")
 

+ 1 - 7
examples/albert/utils.py

@@ -7,7 +7,7 @@ from hivemind import choose_ip_address
 from hivemind.dht.crypto import RSASignatureValidator
 from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
 from hivemind.dht.validation import RecordValidatorBase
-from hivemind.utils.logging import get_logger
+from hivemind.utils.logging import TextStyle, get_logger
 
 logger = get_logger(__name__)
 
@@ -30,12 +30,6 @@ def make_validators(experiment_prefix: str) -> Tuple[List[RecordValidatorBase],
     return validators, signature_validator.local_public_key
 
 
-class TextStyle:
-    BOLD = "\033[1m"
-    BLUE = "\033[34m"
-    RESET = "\033[0m"
-
-
 def log_visible_maddrs(visible_maddrs: List[Multiaddr], only_p2p: bool) -> None:
     if only_p2p:
         unique_addrs = {addr["p2p"] for addr in visible_maddrs}
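
TextStyle is now imported from hivemind.utils.logging instead of being redefined locally; assuming it keeps the same BOLD/BLUE/RESET attributes, a usage sketch:

    from hivemind.utils.logging import TextStyle, get_logger

    logger = get_logger(__name__)
    logger.info(f"Peer address: {TextStyle.BOLD}{TextStyle.BLUE}/ip4/127.0.0.1/tcp/1337{TextStyle.RESET}")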

+ 1 - 0
hivemind/__init__.py

@@ -1,4 +1,5 @@
 from hivemind.averaging import DecentralizedAverager, TrainingAverager
+from hivemind.compression import *
 from hivemind.dht import DHT
 from hivemind.moe import (
     ExpertBackend,

+ 11 - 7
hivemind/averaging/allreduce.py

@@ -5,11 +5,11 @@ from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Type
 import torch
 
 from hivemind.averaging.partition import AllreduceException, TensorPartContainer, TensorPartReducer
+from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase, StubBase
 from hivemind.proto import averaging_pb2
 from hivemind.utils import get_logger
-from hivemind.utils.asyncio import achain, aenumerate, afirst, aiter, amap_in_executor, anext
-from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.utils.asyncio import achain, aenumerate, afirst, amap_in_executor, anext, as_aiter
 
 # flavour types
 GroupID = bytes
@@ -153,13 +153,17 @@ class AllReduceRunner(ServicerBase):
                 self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part - tensor_part)
 
         else:
-            loop = asyncio.get_event_loop()
             code = None
             stream = self._get_peer_stub(peer_id).rpc_aggregate_part(self._generate_input_for_peer(peer_index))
-            async for part_index, msg in aenumerate(stream):
+            async for part_index, (averaged_part_delta, msg) in aenumerate(
+                amap_in_executor(
+                    lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg),
+                    stream,
+                    max_prefetch=self.tensor_part_container.prefetch,
+                )
+            ):
                 if code is None:
                     code = msg.code
-                averaged_part_delta = await loop.run_in_executor(None, deserialize_torch_tensor, msg.tensor_part)
                 self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part_delta)
 
             if code != averaging_pb2.AVERAGED_PART:
@@ -193,7 +197,7 @@ class AllReduceRunner(ServicerBase):
         elif request.code == averaging_pb2.PART_FOR_AVERAGING:
             try:
                 sender_index = self.sender_peer_ids.index(context.remote_id)
-                async for msg in self._accumulate_parts_streaming(achain(aiter(request), stream), sender_index):
+                async for msg in self._accumulate_parts_streaming(achain(as_aiter(request), stream), sender_index):
                     yield msg
 
             except Exception as e:
@@ -232,7 +236,7 @@ class AllReduceRunner(ServicerBase):
     async def _send_error_to_peer(self, peer_id: PeerID, code: averaging_pb2.MessageCode):
         error = averaging_pb2.AveragingData(group_id=self.group_id, code=code)
         # Coroutines are lazy, so we take the first item to start the coroutine's execution
-        await afirst(self._get_peer_stub(peer_id).rpc_aggregate_part(aiter(error)))
+        await afirst(self._get_peer_stub(peer_id).rpc_aggregate_part(as_aiter(error)))
 
     def finalize(self, *, cancel: bool = False, exception: Optional[BaseException] = None):
         """finish or terminate AllReduceRunner, propagate any errors / cancellations to peers."""

+ 95 - 60
hivemind/averaging/averager.py

@@ -9,7 +9,6 @@ import multiprocessing as mp
 import os
 import threading
 import weakref
-from concurrent.futures.thread import ThreadPoolExecutor
 from dataclasses import asdict
 from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Union
 
@@ -21,12 +20,18 @@ from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.load_balancing import load_balance_peers
 from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
 from hivemind.averaging.partition import DEFAULT_PART_SIZE_BYTES
+from hivemind.compression import (
+    CompressionBase,
+    CompressionInfo,
+    NoCompression,
+    deserialize_torch_tensor,
+    serialize_torch_tensor,
+)
 from hivemind.dht import DHT, DHTID
 from hivemind.p2p import P2PContext, P2PHandlerError, PeerID, ServicerBase
-from hivemind.proto import averaging_pb2, runtime_pb2
+from hivemind.proto import averaging_pb2
 from hivemind.utils import MPFuture, TensorDescriptor, get_logger
-from hivemind.utils.asyncio import achain, aiter, anext, switch_to_uvloop
-from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.utils.asyncio import achain, aiter_with_timeout, anext, as_aiter, switch_to_uvloop
 from hivemind.utils.grpc import combine_from_streaming, split_for_streaming
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration, get_dht_time
@@ -51,7 +56,9 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     :param initial_group_bits: a string of bits ('0' and '1') that define the initial group key (bucket index)
     :param averaging_expiration: attempt to find a group for this many seconds, otherwise try again
       note - this expiration time only applies to looking for group, passing tensors in allreduce may take more time
-    :param compression_type: optionally compress tensors with this compression algorithm before sending them to peers
+    :param compression: optionally compress tensors with this compression algorithm before running all-reduce
+    :param state_compression: a separate compression strategy for load_state_from_peers (default = no compression)
+    :param tensor_infos: CompressionInfo for each respective tensor; this determines how the tensor will be compressed
     :param allreduce_timeout: spend at most this many seconds for allreduce (after group is formed)
     :param averaging_alpha: optional "learning rate" for averaging. If specified, local parameters will be shifted
       towards the (estimated) average by this coefficient. By default, local parameters are set equal to average.
@@ -102,7 +109,9 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         averaging_alpha: float = 1.0,
         part_size_bytes: int = DEFAULT_PART_SIZE_BYTES,
         allreduce_timeout: Optional[float] = None,
-        compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE,
+        compression: CompressionBase = NoCompression(),
+        state_compression: CompressionBase = NoCompression(),
+        tensor_infos: Optional[Sequence[CompressionInfo]] = None,
         bandwidth: Optional[float] = None,
         min_vector_size: int = 0,
         auxiliary: bool = False,
@@ -158,7 +167,9 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             request_timeout=request_timeout,
         )
         self.allreduce_kwargs = dict(
-            compression_type=compression_type, part_size_bytes=part_size_bytes, min_vector_size=min_vector_size
+            compression=compression,
+            part_size_bytes=part_size_bytes,
+            min_vector_size=min_vector_size,
         )
         self._averaging_alpha, self._allreduce_timeout = averaging_alpha, allreduce_timeout
         self._running_groups: Dict[GroupID, AllReduceRunner] = {}  # one or more assembled groups that run all-reduce
@@ -169,6 +180,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         if allow_state_sharing is None:
             allow_state_sharing = not client_mode and not auxiliary
         self.allow_state_sharing = allow_state_sharing
+        self.state_compression = state_compression
+        self.tensor_infos = tensor_infos
 
         self._ready = MPFuture()
         # note: we create a background thread weakref and with daemon=True to ensure garbage collection
@@ -197,6 +210,10 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     def peer_id(self) -> PeerID:
         return self.dht.peer_id
 
+    @property
+    def request_timeout(self):
+        return self._matchmaking.request_timeout
+
     def run(self):
         """
         Run averager function in a background thread; this is needed to avoid a heisenbug with broken OMP on fork
@@ -211,48 +228,56 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         """Serve DecentralizedAverager forever. This function will not return until the averager is shut down"""
         loop = switch_to_uvloop()
         # initialize asyncio synchronization primitives in this event loop
-        with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
 
-            async def _run():
+        pipe_semaphore = asyncio.Semaphore(value=0)
+        loop.add_reader(self._inner_pipe.fileno(), pipe_semaphore.release)
+
+        async def _run():
+            try:
+                self._p2p = await self.dht.replicate_p2p()
+                if not self.client_mode:
+                    await self.add_p2p_handlers(self._p2p, namespace=self.prefix)
+                else:
+                    logger.debug(f"The averager is running in client mode.")
+
+                self._matchmaking = Matchmaking(
+                    self._p2p,
+                    self.schema_hash,
+                    self.dht,
+                    client_mode=self.client_mode,
+                    **self.matchmaking_kwargs,
+                )
+                if not self.client_mode:
+                    asyncio.create_task(self._declare_for_download_periodically())
+
+                self._pending_group_assembled = asyncio.Event()
+                self._pending_group_assembled.set()
+            except Exception as e:
+                # Loglevel is DEBUG since normally the exception is propagated to the caller
+                logger.debug(e, exc_info=True)
+                self._ready.set_exception(e)
+                return
+            self._ready.set_result(None)
+
+            while True:
                 try:
-                    self._p2p = await self.dht.replicate_p2p()
-                    if not self.client_mode:
-                        await self.add_p2p_handlers(self._p2p, namespace=self.prefix)
-                    else:
-                        logger.debug(f"The averager is running in client mode.")
-
-                    self._matchmaking = Matchmaking(
-                        self._p2p,
-                        self.schema_hash,
-                        self.dht,
-                        client_mode=self.client_mode,
-                        **self.matchmaking_kwargs,
-                    )
-                    if not self.client_mode:
-                        asyncio.create_task(self._declare_for_download_periodically())
-
-                    self._pending_group_assembled = asyncio.Event()
-                    self._pending_group_assembled.set()
-                except Exception as e:
-                    # Loglevel is DEBUG since normally the exception is propagated to the caller
-                    logger.debug(e, exc_info=True)
-                    self._ready.set_exception(e)
-                    return
-                self._ready.set_result(None)
-
-                while True:
-                    try:
-                        method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
-                    except (OSError, ConnectionError) as e:
-                        logger.exception(e)
-                        await asyncio.sleep(self._matchmaking.request_timeout)
-                        continue
-                    task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
-                    if method == "_shutdown":
-                        await task
-                        break
-
-            loop.run_until_complete(_run())
+                    await asyncio.wait_for(pipe_semaphore.acquire(), timeout=self.request_timeout)
+                except asyncio.TimeoutError:
+                    pass
+                if not self._inner_pipe.poll():
+                    continue
+                try:
+                    method, args, kwargs = self._inner_pipe.recv()
+                except (OSError, ConnectionError, RuntimeError) as e:
+                    logger.exception(e)
+                    await asyncio.sleep(self.request_timeout)
+                    continue
+                task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
+                if method == "_shutdown":
+                    await task
+                    break
+
+        loop.run_until_complete(_run())
 
     def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None:
         """
@@ -351,7 +376,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
                     future.set_result(
                         await asyncio.wait_for(
-                            self._run_allreduce(group_info, **self.allreduce_kwargs), self._allreduce_timeout
+                            self._run_allreduce(group_info, tensor_infos=self.tensor_infos, **self.allreduce_kwargs),
+                            timeout=self._allreduce_timeout,
                         )
                     )
                     # averaging is finished, loop will now exit
@@ -484,7 +510,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             yield averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
             return
 
-        async for message in group.rpc_aggregate_part(achain(aiter(request), stream), context):
+        async for message in group.rpc_aggregate_part(achain(as_aiter(request), stream), context):
             yield message
 
     async def _declare_for_download_periodically(self):
@@ -517,24 +543,27 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         """
         if not self.allow_state_sharing:
             return  # deny request and direct peer to the next prospective averager
-        metadata, tensors = await self._get_current_state_from_host_process()
+        metadata, tensors, infos = await self._get_current_state_from_host_process()
+        if infos is None:
+            infos = [CompressionInfo.from_tensor(tensor, key=i) for i, tensor in enumerate(tensors)]
+        assert len(tensors) == len(infos)
 
-        for tensor in tensors:
-            for part in split_for_streaming(serialize_torch_tensor(tensor)):
+        for tensor, info in zip(tensors, infos):
+            for part in split_for_streaming(self.state_compression.compress(tensor, info, allow_inplace=False)):
                 if metadata is not None:
                     yield averaging_pb2.DownloadData(tensor_part=part, metadata=metadata)
                     metadata = None
                 else:
                     yield averaging_pb2.DownloadData(tensor_part=part)
 
-    def get_current_state(self) -> Tuple[Any, Sequence[torch.Tensor]]:
+    def get_current_state(self) -> Tuple[Any, Sequence[torch.Tensor], Sequence[CompressionInfo]]:
         """
         Get current state and send it to a peer. Executed in the host process. Meant to be overridden.
         :returns: a tuple of (small metadata, sequence of torch tensors)
         :note: metadata must be serializable with self.serializer (default = MSGPackSerializer)
         """
         with self.get_tensors() as tensors:
-            return dict(group_key=self.get_group_bits()), tensors
+            return dict(group_key=self.get_group_bits()), tensors, self.tensor_infos
 
     async def _get_current_state_from_host_process(self):
         """Executed in the averager process inside rpc_download_state"""
@@ -542,7 +571,9 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         self._inner_pipe.send(("_TRIGGER_GET_CURRENT_STATE", future))
         return await future
 
-    def load_state_from_peers(self, wait=True) -> Optional[Tuple[Any, Sequence[torch.Tensor]]]:
+    def load_state_from_peers(
+        self, wait: bool = True, timeout: Optional[float] = None
+    ) -> Optional[Tuple[Any, Sequence[torch.Tensor]]]:
         """
         Try to download the latest optimizer state from one of the existing peers.
         :returns: on success, return a 2-tuple with (metadata, tensors), where
@@ -554,7 +585,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         """
         future = MPFuture()
         self._outer_pipe.send(("_load_state_from_peers", [], dict(future=future)))
-        return future.result() if wait else future
+        return future.result(timeout=timeout) if wait else future
 
     async def _load_state_from_peers(self, future: MPFuture):
         try:
@@ -579,7 +610,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                         stub = self.get_stub(self._p2p, peer, namespace=self.prefix)
                         stream = stub.rpc_download_state(averaging_pb2.DownloadRequest())
                         current_tensor_parts, tensors = [], []
-                        async for message in stream:
+
+                        async for message in aiter_with_timeout(stream, timeout=self.request_timeout):
                             if message.metadata:
                                 metadata = self.serializer.loads(message.metadata)
                             if message.tensor_part.dtype and current_tensor_parts:
@@ -603,7 +635,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
         finally:
             if not future.done():
-                logger.warning("Averager could not load state from peers: all requests have failed.")
                 future.set_result(None)
 
     def get_group_bits(self, wait: bool = True):
@@ -666,7 +697,11 @@ def _background_thread_fetch_current_state(
             get_current_state = get_current_state_ref()
             if get_current_state is None:
                 break
-            state_metadata, state_tensors = get_current_state()
+            state = get_current_state()
+            assert 0 < len(state) <= 3
+            if len(state) != 3:
+                state = tuple(state + (None,) * (3 - len(state)))
+            state_metadata, state_tensors, tensor_infos = state
             del get_current_state
 
             state_metadata = serializer.dumps(state_metadata)
@@ -674,7 +709,7 @@ def _background_thread_fetch_current_state(
                 tensor.cpu().detach().requires_grad_(tensor.requires_grad) for tensor in state_tensors
             )
             # note: we cast tensors to CPU on host side to avoid initializing cuda in the guest process
-            future.set_result((state_metadata, state_tensors))
+            future.set_result((state_metadata, state_tensors, tensor_infos))
         except BaseException as e:
             future.set_exception(e)
             logger.warning(e)
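
A rough sketch of how the new constructor arguments fit together (hypothetical values; assumes a local hivemind.DHT and the usual prefix/target_group_size/start arguments):

    import torch
    import hivemind

    dht = hivemind.DHT(start=True)
    averager = hivemind.DecentralizedAverager(
        averaged_tensors=[torch.zeros(1024)],
        dht=dht,
        prefix="demo_averaging",
        target_group_size=4,
        # all-reduce parts are sent as float16, while load_state_from_peers
        # keeps the (default) uncompressed representation
        compression=hivemind.Float16Compression(),
        state_compression=hivemind.NoCompression(),
        start=True,
    )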

+ 24 - 18
hivemind/averaging/partition.py

@@ -3,17 +3,17 @@ Auxiliary data structures for AllReduceRunner
 """
 import asyncio
 from collections import deque
-from typing import AsyncIterable, AsyncIterator, Optional, Sequence, Tuple, TypeVar, Union
+from typing import AsyncIterable, AsyncIterator, Optional, Sequence, Tuple, TypeVar
 
 import numpy as np
 import torch
 
-from hivemind.proto.runtime_pb2 import CompressionType, Tensor
+from hivemind.compression import CompressionBase, CompressionInfo, NoCompression
+from hivemind.proto import runtime_pb2
 from hivemind.utils.asyncio import amap_in_executor
-from hivemind.utils.compression import get_nbytes_per_value, serialize_torch_tensor
 
 T = TypeVar("T")
-DEFAULT_PART_SIZE_BYTES = 2 ** 19
+DEFAULT_PART_SIZE_BYTES = 2 ** 16
 
 
 class TensorPartContainer:
@@ -22,8 +22,9 @@ class TensorPartContainer:
     The class is designed to avoid excessive memory allocation and run all heavy computation in background
     :param tensors: local tensors to be split and aggregated
     :param peer_fractions: for each peer, a target fraction of vector elements that this peer should average
-    :param compression_type: optionally compress tensors with this compression algorithm before sending them to peers
+    :param compression: optionally compress tensors with this compression algorithm before sending them to peers
     :param part_size_bytes: greedily split tensors into parts of up to this many bytes (after compression)
+    :param tensor_infos: CompressionInfo for each respective tensor; this determines how the tensor will be compressed
     :param prefetch: when compressing, pre-compute this many compressed tensors in background
     """
 
@@ -31,16 +32,19 @@ class TensorPartContainer:
         self,
         tensors: Sequence[torch.Tensor],
         peer_fractions: Sequence[float],
-        compression_type: Union["CompressionType", Sequence["CompressionType"]] = CompressionType.NONE,
+        compression: CompressionBase = NoCompression(),
         part_size_bytes: int = DEFAULT_PART_SIZE_BYTES,
-        prefetch: int = 1,
+        tensor_infos: Optional[Sequence[CompressionInfo]] = None,
+        prefetch: int = 5,
     ):
-        if not isinstance(compression_type, Sequence):
-            compression_type = [compression_type] * len(tensors)
-        assert len(compression_type) == len(tensors), "compression types do not match the number of tensors"
+        if tensor_infos is None:
+            tensor_infos = tuple(CompressionInfo.from_tensor(x, key=i) for i, x in enumerate(tensors))
+        assert len(tensor_infos) == len(tensors), "tensor infos do not match the number of tensors"
         self.local_tensors, self.peer_fractions, self.group_size = tensors, peer_fractions, len(peer_fractions)
-        self.compression_type, self.part_size_bytes, self.prefetch = compression_type, part_size_bytes, prefetch
+        self.compression, self.part_size_bytes, self.tensor_infos = compression, part_size_bytes, tensor_infos
         self.total_size = sum(tensor.numel() for tensor in tensors)
+        self.prefetch = prefetch
+
         self._input_parts_by_peer = [deque() for _ in range(self.group_size)]
         self._output_parts_by_peer = [deque() for _ in range(self.group_size)]
         self._inputs_consumed_by_peer = [False for _ in range(self.group_size)]
@@ -56,11 +60,13 @@ class TensorPartContainer:
         pivots = (np.cumsum(peer_fractions) / np.sum(peer_fractions) * self.total_size).astype(np.int64)
         pivots[-1] = self.total_size
 
-        for tensor, tensor_compression in zip(self.local_tensors, compression_type):
-            part_size_values = int(part_size_bytes / get_nbytes_per_value(tensor.dtype, tensor_compression))
+        for tensor, info in zip(self.local_tensors, self.tensor_infos):
+            bytes_per_value = tensor.element_size() * compression.estimate_compression_ratio(info)
+            part_size_values = int(part_size_bytes / bytes_per_value)
             tensor_parts = tensor.detach().view(-1).split(part_size_values)
             self.num_parts_by_tensor.append(len(tensor_parts))
-            for part in tensor_parts:
+            for part_index, part in enumerate(tensor_parts):
+                part_info = info.get_part(part_index, part_size_values)
                 if current_length + len(part) > pivots[current_peer_index]:
                     # switch to next peer; if a part lands between parts of two or
                     # more peers, assign that part to the peer with highest intersection
@@ -71,9 +77,9 @@ class TensorPartContainer:
                         current_peer_part_end = min(current_length + len(part), pivots[current_peer_index])
                         peer_intersections.append(current_peer_part_end - pivots[current_peer_index - 1])
                     assigned_peer_index = prev_peer_index + np.argmax(peer_intersections)
-                    self._input_parts_by_peer[assigned_peer_index].append((part, tensor_compression))
+                    self._input_parts_by_peer[assigned_peer_index].append((part, part_info))
                 else:
-                    self._input_parts_by_peer[current_peer_index].append((part, tensor_compression))
+                    self._input_parts_by_peer[current_peer_index].append((part, part_info))
                 current_length += len(part)
 
         assert current_length == self.total_size
@@ -89,7 +95,7 @@ class TensorPartContainer:
         return input_parts
 
     @torch.no_grad()
-    async def iterate_input_parts_for(self, peer_index: int) -> AsyncIterator[Tensor]:
+    async def iterate_input_parts_for(self, peer_index: int) -> AsyncIterator[runtime_pb2.Tensor]:
         """iterate serialized tensor parts for a peer at a given index. Run serialization in background."""
         assert not self._inputs_consumed_by_peer[peer_index], "input parts of a given peer are already deallocated."
         self._inputs_consumed_by_peer[peer_index] = True
@@ -99,7 +105,7 @@ class TensorPartContainer:
                 yield self._input_parts_by_peer[peer_index].popleft()
 
         async for serialized_part in amap_in_executor(
-            lambda x_and_compr: serialize_torch_tensor(*x_and_compr), _aiterate_parts(), max_prefetch=self.prefetch
+            lambda x_and_info: self.compression.compress(*x_and_info), _aiterate_parts(), max_prefetch=self.prefetch
         ):
             yield serialized_part
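
When no tensor_infos are given, the container builds one CompressionInfo per tensor and derives per-part infos, roughly as in this sketch (values are arbitrary):

    import torch
    from hivemind.compression import CompressionInfo

    tensors = [torch.randn(10), torch.randn(20, 5)]
    infos = tuple(CompressionInfo.from_tensor(x, key=i) for i, x in enumerate(tensors))
    # info describing the first slice of tensor #1, with at most 50 values per part
    part_info = infos[1].get_part(part_index=0, part_size=50)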
 

+ 47 - 12
hivemind/averaging/training.py

@@ -8,6 +8,7 @@ from typing import Dict, Iterator, Optional, Sequence
 import torch
 
 from hivemind.averaging import DecentralizedAverager
+from hivemind.compression import CompressionInfo, TensorRole
 from hivemind.utils import get_logger, nested_flatten, nested_pack
 
 logger = get_logger(__name__)
@@ -41,23 +42,28 @@ class TrainingAverager(DecentralizedAverager):
         average_gradients: bool,
         average_opt_statistics: Sequence[str] = (),
         extra_tensors: Sequence[torch.Tensor] = (),
+        parameter_names: Optional[Sequence[str]] = None,
         initialize_optimizer: bool = True,
         **kwargs
     ):
+        if initialize_optimizer:
+            initialize_optimizer_state(opt)  # note: this will run one optimizer step!
+        if parameter_names is None:
+            parameter_names = tuple(i for group in opt.param_groups for i in range(len(group["params"])))
 
         self.opt, self.extra_tensors, self.local_step = opt, tuple(extra_tensors), 0
         self.opt_statistics = tuple(average_opt_statistics)
         self.average_parameters, self.average_gradients = average_parameters, average_gradients
+        self.parameter_names = parameter_names
         self.step_executor = ThreadPoolExecutor(max_workers=1)
         self.lock_averager_step = Lock()
         self.pending_updates_done = Event()
         self.pending_updates_done.set()
-        if initialize_optimizer:
-            initialize_optimizer_state(opt)  # note: this will run one optimizer step!
 
         with torch.no_grad():
             averaged_tensors = [tensor.detach().cpu().float().clone() for tensor in self.local_tensors()]
-        super().__init__(averaged_tensors=averaged_tensors, **kwargs)
+
+        super().__init__(averaged_tensors=averaged_tensors, tensor_infos=list(self.tensor_infos()), **kwargs)
 
     def step(self, data_lock: Optional[Lock] = None, wait: bool = True, **kwargs):
         """
@@ -119,13 +125,8 @@ class TrainingAverager(DecentralizedAverager):
             self.local_step += 1
             return gathered
 
-    def local_tensors(self, replace_none: bool = True) -> Iterator[torch.Tensor]:
-        """
-        Iterate local trainer's tensors that should be averaged with peers
-
-        :param replace_none: if True and average_gradients is True, None grads will be replaced with a zero tensors
-          Otherwise, such gradients will be skipped. (this may cause inconsistencies with averaged_tensors)
-        """
+    def local_tensors(self) -> Iterator[torch.Tensor]:
+        """Iterate local trainer's tensors that should be averaged with peers"""
         if self.average_parameters:
             for param_group in self.opt.param_groups:
                 yield from param_group["params"]
@@ -134,7 +135,7 @@ class TrainingAverager(DecentralizedAverager):
                 for param in param_group["params"]:
                     if param.grad is not None:
                         yield param.grad
-                    elif replace_none:
+                    else:
                         yield torch.zeros_like(param)
         for stats in self.opt_statistics:
             for param_group in self.opt.param_groups:
@@ -142,6 +143,26 @@ class TrainingAverager(DecentralizedAverager):
                     yield self.opt.state[param][stats]
         yield from iter(self.extra_tensors)
 
+    def tensor_infos(self):
+        """Get CompressionInfo for each tensor, accounting for its role and specification"""
+        params = tuple(param for param_group in self.opt.param_groups for param in param_group["params"])
+        assert len(params) == len(self.parameter_names)
+        if self.average_parameters:
+            for param, key in zip(params, self.parameter_names):
+                yield CompressionInfo.from_tensor(param, key=key, role=TensorRole.PARAMETER)
+        if self.average_gradients:
+            for param, key in zip(params, self.parameter_names):
+                grad = param.grad if param.grad is not None else torch.zeros_like(param)
+                yield CompressionInfo.from_tensor(grad, key=key, role=TensorRole.GRADIENT)
+        for stats in self.opt_statistics:
+            for param, key in zip(params, self.parameter_names):
+                yield CompressionInfo.from_tensor(
+                    self.opt.state[param][stats], key=(key, stats), role=TensorRole.OPTIMIZER
+                )
+        for i, extra_tensor in enumerate(self.extra_tensors):
+            yield CompressionInfo.from_tensor(extra_tensor, key=i, role=TensorRole.UNSPECIFIED)
+
     def get_current_state(self):
         """
         Get current model/optimizer state to send it when requested by a newbie peer. Executed in the host process.
@@ -151,11 +172,25 @@ class TrainingAverager(DecentralizedAverager):
             optimized_parameters = tuple(
                 param.detach().cpu() for param_group in self.opt.param_groups for param in param_group["params"]
             )
+            parameter_infos = [
+                CompressionInfo.from_tensor(param, key=key, role=TensorRole.PARAMETER)
+                for param, key in zip(optimized_parameters, self.parameter_names)
+            ]
             extra_tensors = tuple(tensor.detach().cpu() for tensor in self.extra_tensors)
+            extra_infos = [
+                CompressionInfo.from_tensor(extra_tensor, key=i, role=TensorRole.UNSPECIFIED)
+                for i, extra_tensor in enumerate(extra_tensors)
+            ]
             optimizer_metadata, optimizer_tensors = dump_optimizer_state(self.opt)
+            optimizer_infos = [
+                CompressionInfo.from_tensor(opt_tensor, key=i, role=TensorRole.OPTIMIZER)
+                for i, opt_tensor in enumerate(optimizer_tensors)
+            ]
 
         metadata = dict(step=self.local_step, group_bits=self.get_group_bits(), optimizer_metadata=optimizer_metadata)
-        return metadata, list(chain(optimized_parameters, extra_tensors, optimizer_tensors))
+        all_tensors = list(chain(optimized_parameters, extra_tensors, optimizer_tensors))
+        all_tensor_infos = list(chain(parameter_infos, extra_infos, optimizer_infos))
+        return metadata, all_tensors, all_tensor_infos
 
     def load_state_from_peers(self, **kwargs):
         """

+ 52 - 0
hivemind/compression/__init__.py

@@ -0,0 +1,52 @@
+"""
+Compression strategies that reduce the network communication in .averaging, .optim and .moe
+"""
+
+import warnings
+from typing import Dict, Optional
+
+import torch
+
+from hivemind.compression.adaptive import PerTensorCompression, RoleAdaptiveCompression, SizeAdaptiveCompression
+from hivemind.compression.base import CompressionBase, CompressionInfo, NoCompression, TensorRole
+from hivemind.compression.floating import Float16Compression, ScaledFloat16Compression
+from hivemind.compression.quantization import Quantile8BitQuantization, Uniform8BitQuantization
+from hivemind.proto import runtime_pb2
+
+warnings.filterwarnings("ignore", message="The given NumPy array is not writeable", category=UserWarning)
+
+
+BASE_COMPRESSION_TYPES: Dict[str, CompressionBase] = dict(
+    NONE=NoCompression(),
+    FLOAT16=Float16Compression(),
+    MEANSTD_16BIT=ScaledFloat16Compression(),
+    QUANTILE_8BIT=Quantile8BitQuantization(),
+    UNIFORM_8BIT=Uniform8BitQuantization(),
+)
+
+for key in runtime_pb2.CompressionType.keys():
+    assert key in BASE_COMPRESSION_TYPES, f"Compression type {key} does not have a registered deserializer."
+    actual_compression_type = BASE_COMPRESSION_TYPES[key].compression_type
+    assert (
+        runtime_pb2.CompressionType.Name(actual_compression_type) == key
+    ), f"Compression strategy for {key} has inconsistent type"
+
+
+def serialize_torch_tensor(
+    tensor: torch.Tensor,
+    compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE,
+    info: Optional[CompressionInfo] = None,
+    allow_inplace: bool = False,
+    **kwargs,
+) -> runtime_pb2.Tensor:
+    """Serialize a given tensor into a protobuf message using the specified compression strategy"""
+    assert tensor.device == torch.device("cpu")
+    compression = BASE_COMPRESSION_TYPES[runtime_pb2.CompressionType.Name(compression_type)]
+    info = info or CompressionInfo.from_tensor(tensor, **kwargs)
+    return compression.compress(tensor, info, allow_inplace)
+
+
+def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
+    """Restore a pytorch tensor from a protobuf message"""
+    compression = BASE_COMPRESSION_TYPES[runtime_pb2.CompressionType.Name(serialized_tensor.compression)]
+    return compression.extract(serialized_tensor).requires_grad_(serialized_tensor.requires_grad)
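
The enum-based serialize_torch_tensor above is a thin dispatcher over these strategy objects; an equivalent direct call (assuming the mapping stays keyed by the protobuf enum names):

    import torch
    from hivemind.compression import BASE_COMPRESSION_TYPES, CompressionInfo, deserialize_torch_tensor

    tensor = torch.randn(256)
    strategy = BASE_COMPRESSION_TYPES["MEANSTD_16BIT"]  # the object serialize_torch_tensor would pick
    serialized = strategy.compress(tensor, CompressionInfo.from_tensor(tensor, key=0), allow_inplace=False)
    restored = deserialize_torch_tensor(serialized)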

+ 67 - 0
hivemind/compression/adaptive.py

@@ -0,0 +1,67 @@
+from abc import ABC, abstractmethod
+from typing import Mapping, Sequence, Union
+
+import torch
+
+import hivemind
+from hivemind.compression.base import CompressionBase, CompressionInfo, Key, NoCompression, TensorRole
+from hivemind.proto import runtime_pb2
+
+
+class AdaptiveCompressionBase(CompressionBase, ABC):
+    @abstractmethod
+    def choose_compression(self, info: CompressionInfo) -> CompressionBase:
+        ...
+
+    def estimate_compression_ratio(self, info: CompressionInfo) -> float:
+        return self.choose_compression(info).estimate_compression_ratio(info)
+
+    def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
+        return self.choose_compression(info).compress(tensor, info=info, allow_inplace=allow_inplace)
+
+    def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
+        return hivemind.compression.deserialize_torch_tensor(serialized_tensor)
+
+
+class SizeAdaptiveCompression(AdaptiveCompressionBase):
+    """Apply compression strategy 1 if tensor has more than :threshold: elements and strategy 2 otherwise"""
+
+    def __init__(self, threshold: int, less: CompressionBase, greater_equal: CompressionBase):
+        self.threshold, self.less, self.greater_equal = threshold, less, greater_equal
+
+    def choose_compression(self, info: CompressionInfo) -> CompressionBase:
+        return self.greater_equal if info.descriptor.numel() >= self.threshold else self.less
+
+
+class RoleAdaptiveCompression(AdaptiveCompressionBase):
+    """Compress a tensor based on its role in training. Any non-specified compressions will use the "default" option"""
+
+    def __init__(
+        self,
+        *,
+        activation: CompressionBase = None,
+        parameter: CompressionBase = None,
+        gradient: CompressionBase = None,
+        optimizer: CompressionBase = None,
+        default: CompressionBase = NoCompression()
+    ):
+        self.role_compressions = {
+            TensorRole.ACTIVATION: activation or default,
+            TensorRole.PARAMETER: parameter or default,
+            TensorRole.GRADIENT: gradient or default,
+            TensorRole.OPTIMIZER: optimizer or default,
+            TensorRole.UNSPECIFIED: default,
+        }
+
+    def choose_compression(self, info: CompressionInfo) -> CompressionBase:
+        return self.role_compressions[info.role]
+
+
+class PerTensorCompression(AdaptiveCompressionBase):
+    """Manually specify the compression strategy depending on tensor key"""
+
+    def __init__(self, tensor_compressions: Union[Sequence[CompressionBase], Mapping[Key, CompressionBase]]):
+        self.tensor_compressions = tensor_compressions
+
+    def choose_compression(self, info: CompressionInfo) -> CompressionBase:
+        return self.tensor_compressions[info.key]
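
A composition sketch with illustrative thresholds: role-based dispatch with a size-adaptive fallback for gradients and optimizer statistics.

    from hivemind.compression import (
        Float16Compression,
        NoCompression,
        RoleAdaptiveCompression,
        SizeAdaptiveCompression,
        Uniform8BitQuantization,
    )

    # quantize large gradient/optimizer tensors to 8 bits, keep small ones in float16
    size_adaptive = SizeAdaptiveCompression(
        threshold=2 ** 16, less=Float16Compression(), greater_equal=Uniform8BitQuantization()
    )
    compression = RoleAdaptiveCompression(
        parameter=Float16Compression(),
        gradient=size_adaptive,
        optimizer=size_adaptive,
        default=NoCompression(),
    )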

+ 89 - 0
hivemind/compression/base.py

@@ -0,0 +1,89 @@
+import dataclasses
+from abc import ABC, abstractmethod
+from enum import Enum, auto
+from typing import Any, Optional
+
+import numpy as np
+import torch
+
+from hivemind.proto import runtime_pb2
+from hivemind.utils.tensor_descr import TensorDescriptor
+
+Key = Any
+
+
+class TensorRole(Enum):
+    ACTIVATION = auto()
+    PARAMETER = auto()
+    GRADIENT = auto()
+    OPTIMIZER = auto()
+    UNSPECIFIED = auto()
+
+
+@dataclasses.dataclass(frozen=True)
+class CompressionInfo:
+    """Auxiliary data structure that contains information about the tensor that determines how it is compressed"""
+
+    key: Key  # name or index of the tensor from named parameters, optimizer state dict or i/o structure
+    descriptor: TensorDescriptor  # data structure that defines shape, dtype, layout and device information
+    role: TensorRole = TensorRole.UNSPECIFIED  # which role does the tensor play with respect to the model
+    part_index: int = 0  # if tensor is sliced into parts, this represents the index within one tensor
+    part_size: Optional[int] = None  # if tensor is sliced into parts, this is the _maximum_ number of values per part
+
+    @classmethod
+    def from_tensor(cls, tensor: torch.Tensor, key: Key = None, descriptor: TensorDescriptor = None, **kwargs):
+        return cls(key, descriptor or TensorDescriptor.from_tensor(tensor), **kwargs)
+
+    def get_part(self, part_index: int, part_size: Optional[int]):
+        return CompressionInfo(self.key, self.descriptor, self.role, part_index=part_index, part_size=part_size)
+
+
+class CompressionBase(ABC):
+    """A base class that applies compression algorithm to a pytorch tensor"""
+
+    compression_type: runtime_pb2.CompressionType
+
+    @abstractmethod
+    def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
+        """
+        Applies compression algorithm to a tensor based on their meta-parameters
+
+        :param tensor: a pytorch tensor to compress; depending on the application, it is a full tensor or a part
+        :param info: meta-information about the tensor; if partitioning is used, this still describes the full tensor
+        :param allow_inplace: if True, compression can (but doesn't have to) modify the tensor in-place for efficiency
+        :returns: a protobuf message that encodes the tensor
+        """
+        ...
+
+    @abstractmethod
+    def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
+        """Create a pytorch tensor from the serialized outputs of .compress"""
+        ...
+
+    @abstractmethod
+    def estimate_compression_ratio(self, info: CompressionInfo) -> float:
+        """Estimate the compression ratio without doing the actual compression; lower ratio = better compression"""
+        ...
+
+
+class NoCompression(CompressionBase):
+    """A dummy compression strategy that preserves the original tensor as is."""
+
+    compression_type = runtime_pb2.CompressionType.NONE
+
+    def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
+        array = tensor.numpy()
+        return runtime_pb2.Tensor(
+            compression=self.compression_type,
+            buffer=array.tobytes(),
+            size=array.shape,
+            dtype=array.dtype.name,
+            requires_grad=tensor.requires_grad,
+        )
+
+    def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
+        array = np.frombuffer(serialized_tensor.buffer, dtype=np.dtype(serialized_tensor.dtype))
+        return torch.as_tensor(array).reshape(tuple(serialized_tensor.size))
+
+    def estimate_compression_ratio(self, info: CompressionInfo) -> float:
+        return 1.0
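
An illustrative round-trip through the base interface (tensor values and the key are arbitrary):

    import torch
    from hivemind.compression import CompressionInfo, NoCompression, TensorRole

    tensor = torch.randn(4, 4)
    info = CompressionInfo.from_tensor(tensor, key="layer1.weight", role=TensorRole.PARAMETER)
    serialized = NoCompression().compress(tensor, info)
    restored = NoCompression().extract(serialized)
    assert torch.equal(restored, tensor)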

+ 92 - 0
hivemind/compression/floating.py

@@ -0,0 +1,92 @@
+import math
+
+import numpy as np
+import torch
+
+from hivemind.compression.base import CompressionBase, CompressionInfo
+from hivemind.proto import runtime_pb2
+
+
+class Float16Compression(CompressionBase):
+    compression_type = runtime_pb2.CompressionType.FLOAT16
+    FP16_MIN, FP16_MAX = torch.finfo(torch.float16).min, torch.finfo(torch.float16).max
+
+    def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
+        dtype_name = tensor.numpy().dtype.name
+        tensor = tensor.detach().cpu().float()
+        tensor = tensor if allow_inplace else tensor.clone()
+        tensor = tensor.clamp_(self.FP16_MIN, self.FP16_MAX).to(torch.float16)
+        return runtime_pb2.Tensor(
+            compression=self.compression_type,
+            buffer=tensor.numpy().tobytes(),
+            size=tensor.shape,
+            dtype=dtype_name,
+            requires_grad=tensor.requires_grad,
+        )
+
+    def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
+        original_dtype = np.dtype(serialized_tensor.dtype)
+        array = np.frombuffer(serialized_tensor.buffer, dtype=np.float16)
+        return torch.as_tensor(np.asarray(array, dtype=original_dtype)).reshape(tuple(serialized_tensor.size))
+
+    def estimate_compression_ratio(self, info: CompressionInfo) -> float:
+        return 16.0 / get_num_bits(info.descriptor.dtype)
+
+
+class ScaledFloat16Compression(Float16Compression):
+    """A compression strategy that applies mean-std scaling over last axis before casting to float16"""
+
+    compression_type = runtime_pb2.CompressionType.MEANSTD_16BIT
+    FP32_BYTES = torch.finfo(torch.float32).bits // 8
+    FP32_EPS = torch.finfo(torch.float32).eps
+
+    def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
+        dtype_name = tensor.numpy().dtype.name
+        tensor = tensor.detach().cpu().float()
+        tensor = tensor if allow_inplace else tensor.clone()
+        means = torch.mean(tensor, dim=-1, keepdim=True)
+        tensor.sub_(means)
+        stds = tensor.norm(dim=-1, keepdim=True) / math.sqrt(tensor.shape[-1])
+        stds.clamp_min_(self.FP32_EPS)
+        tensor.div_(stds)
+        tensor = tensor.clamp_(self.FP16_MIN, self.FP16_MAX).to(torch.float16)
+
+        data = b"".join((tensor.numpy().tobytes(), means.float().numpy().tobytes(), stds.float().numpy().tobytes()))
+
+        return runtime_pb2.Tensor(
+            compression=self.compression_type,
+            buffer=data,
+            size=tensor.shape,
+            dtype=dtype_name,
+            requires_grad=tensor.requires_grad,
+        )
+
+    def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
+        stats_shape = list(serialized_tensor.size)
+        stats_shape[-1] = 1
+        stats_count = np.prod(stats_shape)
+        means_offset = len(serialized_tensor.buffer) - 2 * stats_count * self.FP32_BYTES
+        stds_offset = len(serialized_tensor.buffer) - stats_count * self.FP32_BYTES
+
+        array = np.frombuffer(serialized_tensor.buffer, dtype=np.float16, count=np.prod(serialized_tensor.size))
+        means = np.frombuffer(serialized_tensor.buffer, dtype=np.float32, offset=means_offset, count=stats_count)
+        stds = np.frombuffer(serialized_tensor.buffer, dtype=np.float32, offset=stds_offset, count=stats_count)
+
+        means = torch.as_tensor(means).reshape(stats_shape)
+        stds = torch.as_tensor(stds).reshape(stats_shape)
+        tensor = torch.as_tensor(np.asarray(array, dtype=serialized_tensor.dtype)).reshape(
+            list(serialized_tensor.size)
+        )
+        return tensor.mul_(stds).add_(means)
+
+
+def get_num_bits(dtype: torch.dtype) -> int:
+    if dtype == torch.bool:
+        return 8  # see https://github.com/pytorch/pytorch/issues/41571
+    elif dtype.is_floating_point:
+        return torch.finfo(dtype).bits
+    else:
+        try:
+            return torch.iinfo(dtype).bits
+        except TypeError:
+            raise TypeError(f"Could not infer size for tensor type {dtype}")
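
An illustrative comparison of the two float16 strategies on a tensor with a large offset, where mean-std scaling helps (values are arbitrary):

    import torch
    from hivemind.compression import CompressionInfo, Float16Compression, ScaledFloat16Compression

    tensor = torch.randn(1024) * 10 + 500  # large offset where plain float16 loses precision
    info = CompressionInfo.from_tensor(tensor, key=0)
    for strategy in (Float16Compression(), ScaledFloat16Compression()):
        restored = strategy.extract(strategy.compress(tensor, info))
        print(type(strategy).__name__, (restored - tensor).abs().max().item())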

+ 114 - 0
hivemind/compression/quantization.py

@@ -0,0 +1,114 @@
+import math
+import os
+from abc import ABC, abstractmethod
+from concurrent.futures import ThreadPoolExecutor
+from typing import Tuple
+
+import numpy as np
+import torch
+
+from hivemind.compression.base import CompressionBase, CompressionInfo
+from hivemind.proto import runtime_pb2
+
+EXECUTOR = ThreadPoolExecutor(max_workers=int(os.environ.get("QUANTIZATION_THREADS", 128)))
+
+
+class Quantization(CompressionBase, ABC):
+    codebook_dtype, indices_dtype = np.float32, np.uint8
+
+    @abstractmethod
+    def quantize(self, tensor: torch.Tensor, allow_inplace: bool = False) -> Tuple[np.ndarray, np.ndarray]:
+        """Convert tensor into a pair of (indices, codebook)"""
+        ...
+
+    def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
+        quantized, codebook = self.quantize(tensor.detach(), allow_inplace=allow_inplace)
+        return runtime_pb2.Tensor(
+            compression=self.compression_type,
+            buffer=b"".join((np.int64(len(codebook)).tobytes(), codebook.tobytes(), quantized.tobytes())),
+            size=tensor.shape,
+            dtype=tensor.numpy().dtype.name,
+            requires_grad=tensor.requires_grad,
+        )
+
+    def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
+        codebook_size = int(np.frombuffer(serialized_tensor.buffer, count=1, dtype=np.int64))
+        codebook = np.frombuffer(serialized_tensor.buffer, offset=8, count=codebook_size, dtype=self.codebook_dtype)
+        quantized = np.frombuffer(serialized_tensor.buffer, offset=8 + codebook.nbytes, dtype=self.indices_dtype)
+        quantized = torch.as_tensor(quantized, dtype=torch.int64).reshape(tuple(serialized_tensor.size))
+        codebook = torch.as_tensor(np.asarray(codebook, dtype=serialized_tensor.dtype))
+        return codebook[quantized]
+
+    def estimate_compression_ratio(self, info: CompressionInfo) -> float:
+        return self.n_bits / torch.finfo(info.descriptor.dtype).bits
+
+    @property
+    def n_bits(self):
+        return self.indices_dtype(1).itemsize * 8
+
+    @property
+    def n_bins(self):
+        return 2 ** self.n_bits
+
+
+class Uniform8BitQuantization(Quantization):
+    RANGE_IN_SIGMAS: int = 6
+    compression_type = runtime_pb2.UNIFORM_8BIT
+
+    def quantize(self, tensor: torch.Tensor, allow_inplace: bool = False) -> Tuple[np.ndarray, np.ndarray]:
+        offset = self.n_bins // 2
+        shift = tensor.mean()
+        centered_tensor = tensor.sub_(shift) if allow_inplace else tensor - shift
+        std_unbiased = centered_tensor.norm() / math.sqrt(centered_tensor.numel() - 1)
+        scale = self.RANGE_IN_SIGMAS * std_unbiased / self.n_bins
+        quantized = torch.quantize_per_tensor(centered_tensor, scale, offset, torch.quint8).int_repr()
+        lookup = average_buckets(tensor, quantized, self.n_bins)
+        return np.asarray(quantized, dtype=self.indices_dtype), np.asarray(lookup, dtype=self.codebook_dtype)
+
+
+class Quantile8BitQuantization(Quantization):
+    compression_type = runtime_pb2.QUANTILE_8BIT
+
+    def quantize(self, tensor: torch.Tensor, allow_inplace: bool = False) -> Tuple[np.ndarray, np.ndarray]:
+        tensor = tensor.detach().float()
+        borders = torch.as_tensor(quantile_qq_approximation(tensor.numpy(), self.n_bins + 1)[1:-1])
+        quantized = torch.clamp_(torch.bucketize(tensor, borders), 0, self.n_bins - 1)
+        codebook = average_buckets(tensor, quantized, self.n_bins)
+        return quantized.numpy().astype(np.uint8), codebook.numpy()
+
+
+def average_buckets(tensor: torch.Tensor, quant_weight: torch.Tensor, n_bins: int):
+    """Return the average value in each bucket"""
+    bin_sums = torch.zeros(n_bins).scatter_add_(0, quant_weight.flatten().long(), tensor.flatten())
+    bin_counts = torch.clamp_min_(torch.bincount(quant_weight.flatten(), minlength=n_bins), 1)
+    lookup = bin_sums / bin_counts
+    return lookup
+
+
+def get_chunk_size(num_elements: int, min_chunk_size: int) -> int:
+    """Adjust chunk_size to minimize imbalance between chunk sizes"""
+    if min_chunk_size >= num_elements:
+        return min_chunk_size
+    leftover_elements = num_elements % min_chunk_size
+    num_chunks = num_elements // min_chunk_size
+    return min_chunk_size + (leftover_elements - 1) // num_chunks + 1
+
+
+def quantile_qq_approximation(array: np.ndarray, n_quantiles: int, min_chunk_size: int = 10 ** 5) -> np.ndarray:
+    """Estimate uniform quantiles of data using quantile-of-quantiles. Runs in parallel."""
+    if not array.data.c_contiguous and array.data.f_contiguous:
+        array = array.T
+    array = np.ascontiguousarray(array.reshape(-1))
+    quantiles = np.linspace(0.0, 1.0, num=n_quantiles, dtype=array.dtype)
+    chunk_size = get_chunk_size(len(array), min_chunk_size)
+    num_chunks = (len(array) - 1) // chunk_size + 1
+    partition_quantiles = np.empty((num_chunks, len(quantiles)), dtype=array.dtype)
+
+    jobs = []
+    for i in range(num_chunks):
+        chunk = slice(chunk_size * i, chunk_size * (i + 1))
+        jobs.append(EXECUTOR.submit(np.quantile, array[chunk], quantiles, out=partition_quantiles[i]))
+
+    for job in jobs:
+        job.result()
+    return np.quantile(partition_quantiles, quantiles)
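A quick round-trip sketch of the new quantizers (not part of this diff; the codebook lookup simply repeats what extract() does after deserialization):

import torch
from hivemind.compression import Uniform8BitQuantization

quantizer = Uniform8BitQuantization()
x = torch.randn(10_000)

indices, codebook = quantizer.quantize(x)  # uint8 bucket indices + 256 bucket averages
restored = torch.as_tensor(codebook)[torch.as_tensor(indices, dtype=torch.int64)]
print(torch.mean((restored - x) ** 2))  # small but non-zero quantization error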

+ 43 - 38
hivemind/dht/__init__.py

@@ -17,7 +17,6 @@ from __future__ import annotations
 import asyncio
 import multiprocessing as mp
 import os
-from concurrent.futures import ThreadPoolExecutor
 from functools import partial
 from typing import Awaitable, Callable, Iterable, List, Optional, Sequence, TypeVar, Union
 
@@ -102,45 +101,51 @@ class DHT(mp.Process):
 
     def run(self) -> None:
         """Serve DHT forever. This function will not return until DHT node is shut down"""
-        loop = switch_to_uvloop()
-
-        with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
 
-            async def _run():
+        loop = switch_to_uvloop()
+        pipe_semaphore = asyncio.Semaphore(value=0)
+        loop.add_reader(self._inner_pipe.fileno(), pipe_semaphore.release)
+
+        async def _run():
+            try:
+                if self._daemon_listen_maddr is not None:
+                    replicated_p2p = await P2P.replicate(self._daemon_listen_maddr)
+                else:
+                    replicated_p2p = None
+
+                self._node = await DHTNode.create(
+                    initial_peers=self.initial_peers,
+                    num_workers=self.num_workers,
+                    record_validator=self._record_validator,
+                    p2p=replicated_p2p,
+                    **self.kwargs,
+                )
+            except Exception as e:
+                # Loglevel is DEBUG since normally the exception is propagated to the caller
+                logger.debug(e, exc_info=True)
+                self._ready.set_exception(e)
+                return
+            self._ready.set_result(None)
+
+            while True:
+                try:
+                    await asyncio.wait_for(pipe_semaphore.acquire(), timeout=self._node.protocol.wait_timeout)
+                except asyncio.TimeoutError:
+                    pass
+                if not self._inner_pipe.poll():
+                    continue
                 try:
-                    if self._daemon_listen_maddr is not None:
-                        replicated_p2p = await P2P.replicate(self._daemon_listen_maddr)
-                    else:
-                        replicated_p2p = None
-
-                    self._node = await DHTNode.create(
-                        initial_peers=self.initial_peers,
-                        num_workers=self.num_workers,
-                        record_validator=self._record_validator,
-                        p2p=replicated_p2p,
-                        **self.kwargs,
-                    )
-                except Exception as e:
-                    # Loglevel is DEBUG since normally the exception is propagated to the caller
-                    logger.debug(e, exc_info=True)
-                    self._ready.set_exception(e)
-                    return
-                self._ready.set_result(None)
-
-                while True:
-                    try:
-                        method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
-                    except (OSError, ConnectionError) as e:
-                        logger.exception(e)
-                        await asyncio.sleep(self._node.protocol.wait_timeout)
-                        continue
-                    task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
-                    if method == "_shutdown":
-                        await task
-                        break
-
-            coro = _run()
-            loop.run_until_complete(coro)
+                    method, args, kwargs = self._inner_pipe.recv()
+                except (OSError, ConnectionError, RuntimeError) as e:
+                    logger.exception(e)
+                    await asyncio.sleep(self._node.protocol.wait_timeout)
+                    continue
+                task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
+                if method == "_shutdown":
+                    await task
+                    break
+
+        loop.run_until_complete(_run())
 
     def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None:
         """

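For reference, the wake-up pattern introduced in run() above, distilled into a standalone sketch (generic asyncio/multiprocessing code, not hivemind API): a reader callback releases a semaphore whenever the pipe becomes readable, so the loop never blocks inside recv() and still wakes up periodically.

import asyncio
import multiprocessing as mp

async def serve(inner_pipe, wait_timeout: float = 1.0) -> None:
    loop = asyncio.get_running_loop()
    wakeup = asyncio.Semaphore(value=0)
    loop.add_reader(inner_pipe.fileno(), wakeup.release)  # fires whenever the pipe has data
    while True:
        try:
            await asyncio.wait_for(wakeup.acquire(), timeout=wait_timeout)
        except asyncio.TimeoutError:
            pass  # no messages for a while; a real server could run maintenance here
        if not inner_pipe.poll():
            continue
        method, args, kwargs = inner_pipe.recv()
        if method == "_shutdown":
            break

outer_pipe, inner_pipe = mp.Pipe()
# asyncio.run(serve(inner_pipe)) returns once outer_pipe.send(("_shutdown", (), {})) is called
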
+ 2 - 0
hivemind/dht/node.py

@@ -187,6 +187,8 @@ class DHTNode:
         if p2p is None:
             if not kwargs.get("use_ipfs"):
                 kwargs["initial_peers"] = initial_peers
+            if client_mode:
+                kwargs.setdefault("dht_mode", "client")
             p2p = await P2P.create(**kwargs)
             self._should_shutdown_p2p = True
         else:

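A hedged usage sketch for the change above (the multiaddress below is a placeholder): a DHT node created in client mode now also runs the libp2p Kademlia DHT of its p2pd daemon in client mode.

import hivemind

client_dht = hivemind.DHT(
    initial_peers=["/ip4/127.0.0.1/tcp/31337/p2p/Qm..."],  # placeholder multiaddress
    client_mode=True,  # with this change, also forwarded to p2pd as dht_mode="client"
    start=True,
)
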
+ 3 - 2
hivemind/moe/client/expert.py

@@ -11,8 +11,9 @@ from torch.autograd.function import once_differentiable
 import hivemind
 from hivemind.p2p import P2P, PeerInfo, StubBase
 from hivemind.proto import runtime_pb2
-from hivemind.utils import nested_compare, nested_flatten, nested_pack, switch_to_uvloop
-from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.utils import switch_to_uvloop
+from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.utils import nested_compare, nested_flatten, nested_pack
 
 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert
 

+ 1 - 1
hivemind/moe/client/moe.py

@@ -10,12 +10,12 @@ import torch.nn as nn
 from torch.autograd.function import once_differentiable
 
 import hivemind
+from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.moe.client.beam_search import MoEBeamSearcher
 from hivemind.moe.client.expert import DUMMY, RemoteExpert, _get_expert_stub
 from hivemind.moe.server.expert_uid import UID_DELIMITER
 from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
 from hivemind.utils import nested_flatten, nested_map, nested_pack
-from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.logging import get_logger
 
 logger = get_logger(__name__)

+ 1 - 1
hivemind/moe/server/connection_handler.py

@@ -6,12 +6,12 @@ from typing import AsyncIterator, Dict
 import torch
 
 from hivemind.dht import DHT
+from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.moe.server.expert_backend import ExpertBackend
 from hivemind.p2p import P2PContext, ServicerBase
 from hivemind.proto import runtime_pb2
 from hivemind.utils import MPFuture, get_logger, nested_flatten
 from hivemind.utils.asyncio import switch_to_uvloop
-from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 
 logger = get_logger(__name__)
 

+ 23 - 12
hivemind/optim/collaborative.py

@@ -85,6 +85,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
     :param averaging_expiration: peer's requests for averaging will be valid for this many seconds
     :param metadata_expiration: peer's metadata (e.g. samples processed) is stored onto DHT for this many seconds
     :param averaging_timeout: if an averaging step hangs for this long, it will be cancelled.
+    :param load_state_timeout: wait for at most this many seconds before giving up on load_state_from_peers
     :param scheduler: if specified, use this scheduler to update optimizer learning rate
     :param reuse_grad_buffers: if True, use model's .grad buffers for gradient accumulation.
       This is more memory efficient, but it requires that the user does *NOT* call model/opt zero_grad at all
@@ -114,6 +115,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         performance_ema_alpha: float = 0.1,
         metadata_expiration: float = 60.0,
         averaging_timeout: Optional[float] = None,
+        load_state_timeout: float = 600.0,
         step_tolerance: int = 1,
         reuse_grad_buffers: bool = False,
         accumulate_grads_on: Optional[torch.device] = None,
@@ -137,7 +139,9 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             default_refresh_period,
         )
         self.expected_drift_peers, self.expected_drift_rate = expected_drift_peers, expected_drift_rate
-        self.averaging_timeout, self.metadata_expiration = averaging_timeout, metadata_expiration
+        self.averaging_timeout = averaging_timeout
+        self.load_state_timeout = load_state_timeout
+        self.metadata_expiration = metadata_expiration
         self._grads, self.reuse_grad_buffers, self.accumulate_grads_on = None, reuse_grad_buffers, accumulate_grads_on
         self.client_mode, self.step_tolerance = client_mode, step_tolerance
         self.status_loglevel = logging.INFO if verbose else logging.DEBUG
@@ -149,7 +153,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         self.performance_ema = PerformanceEMA(alpha=performance_ema_alpha)
         self.last_step_time = None
 
-        self.collaboration_state = self.fetch_collaboration_state()
+        self.collaboration_state = self._fetch_state()
         self.lock_collaboration_state, self.collaboration_state_updated = Lock(), Event()
         self.lock_local_progress, self.should_report_progress = Lock(), Event()
         self.progress_reporter = Thread(target=self.report_training_progress, daemon=True, name=f"{self}.reporter")
@@ -185,7 +189,14 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
     def load_state_from_peers(self, **kwargs):
         """Attempt to fetch the newest collaboration state from other peers"""
         with self.lock_collaboration_state:
-            self.averager.load_state_from_peers(**kwargs)
+            while True:
+                try:
+                    self.averager.load_state_from_peers(timeout=self.load_state_timeout, **kwargs)
+                    break
+                except BaseException as e:
+                    logger.exception(f"Failed to load state from peers: {e}, retrying ...")
+                    continue
+
             self.local_samples_accumulated = self.local_steps_accumulated = 0
             self.reset_accumulated_grads_()
             self.update_scheduler()
@@ -226,8 +237,8 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         if not self.collaboration_state.ready_for_step:
             return
 
-        logger.log(self.status_loglevel, f"Beginning global optimizer step {self.collaboration_state.optimizer_step}")
-        self.collaboration_state = self.fetch_collaboration_state()
+        logger.log(self.status_loglevel, f"Beginning global optimizer step #{self.collaboration_state.optimizer_step}")
+        self.collaboration_state = self._fetch_state()
         self.collaboration_state_updated.set()
 
         if not self.is_synchronized:
@@ -277,8 +288,8 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         if not self.collaboration_state.ready_for_step:
             return
 
-        logger.log(self.status_loglevel, f"Beginning global optimizer step {self.collaboration_state.optimizer_step}")
-        self.collaboration_state = self.fetch_collaboration_state()
+        logger.log(self.status_loglevel, f"Beginning global optimizer step #{self.collaboration_state.optimizer_step}")
+        self.collaboration_state = self._fetch_state()
         self.collaboration_state_updated.set()
 
         with self.lock_collaboration_state:
@@ -381,9 +392,9 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                 continue  # if state was updated externally, reset timer
 
             with self.lock_collaboration_state:
-                self.collaboration_state = self.fetch_collaboration_state()
+                self.collaboration_state = self._fetch_state()
 
-    def fetch_collaboration_state(self) -> CollaborationState:
+    def _fetch_state(self) -> CollaborationState:
         """Read performance statistics reported by peers, estimate progress towards next batch"""
         response, _expiration = self.dht.get(self.training_progress_key, latest=True) or (None, -float("inf"))
         current_time = get_dht_time()
@@ -441,9 +452,9 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         )
         logger.log(
             self.status_loglevel,
-            f"Collaboration accumulated {total_samples_accumulated} samples from "
-            f"{num_peers} peers; ETA {estimated_time_to_next_step:.2f} seconds "
-            f"(refresh in {time_to_next_fetch:.2f}s.)",
+            f"{self.prefix} accumulated {total_samples_accumulated} samples from "
+            f"{num_peers} peers for step #{global_optimizer_step}. "
+            f"ETA {estimated_time_to_next_step:.2f} sec (refresh in {time_to_next_fetch:.2f} sec)",
         )
         return CollaborationState(
             global_optimizer_step,

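A hedged sketch of the new load_state_timeout parameter (all other argument values are illustrative): each load_state_from_peers() attempt is now bounded, and failed attempts are retried instead of aborting the join.

import torch
import hivemind

model = torch.nn.Linear(16, 4)
dht = hivemind.DHT(start=True)
base_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

collaborative_opt = hivemind.CollaborativeOptimizer(
    opt=base_optimizer,
    dht=dht,
    prefix="my_experiment",
    target_batch_size=4096,
    load_state_timeout=600.0,  # give up on a single state transfer after 10 minutes
    verbose=True,
)
collaborative_opt.load_state_from_peers()  # retries internally until a transfer succeeds
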
+ 64 - 11
hivemind/p2p/p2p_daemon.py

@@ -1,9 +1,12 @@
 import asyncio
+import json
+import logging
 import os
 import secrets
 from collections.abc import AsyncIterable as AsyncIterableABC
 from contextlib import closing, suppress
 from dataclasses import dataclass
+from datetime import datetime
 from importlib.resources import path
 from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional, Sequence, Tuple, Type, TypeVar, Union
 
@@ -12,11 +15,11 @@ from multiaddr import Multiaddr
 
 import hivemind.hivemind_cli as cli
 import hivemind.p2p.p2p_daemon_bindings.p2pclient as p2pclient
-from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError, P2PHandlerError
+from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, P2PDaemonError, P2PHandlerError
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
 from hivemind.proto.p2pd_pb2 import RPCError
-from hivemind.utils.asyncio import aiter, asingle
-from hivemind.utils.logging import get_logger
+from hivemind.utils.asyncio import as_aiter, asingle
+from hivemind.utils.logging import get_logger, golog_level_to_python, loglevel, python_level_to_golog
 
 logger = get_logger(__name__)
 
@@ -54,9 +57,9 @@ class P2P:
     END_OF_STREAM = RPCError()
 
     DHT_MODE_MAPPING = {
-        "dht": {"dht": 1},
-        "dht_server": {"dhtServer": 1},
-        "dht_client": {"dhtClient": 1},
+        "auto": {"dht": 1},
+        "server": {"dhtServer": 1},
+        "client": {"dhtClient": 1},
     }
     FORCE_REACHABILITY_MAPPING = {
         "public": {"forceReachabilityPublic": 1},
@@ -80,7 +83,7 @@ class P2P:
         announce_maddrs: Optional[Sequence[Union[Multiaddr, str]]] = None,
         auto_nat: bool = True,
         conn_manager: bool = True,
-        dht_mode: str = "dht_server",
+        dht_mode: str = "server",
         force_reachability: Optional[str] = None,
         host_maddrs: Optional[Sequence[Union[Multiaddr, str]]] = ("/ip4/127.0.0.1/tcp/0",),
         identity_path: Optional[str] = None,
@@ -95,6 +98,7 @@ class P2P:
         use_relay: bool = True,
         use_relay_hop: bool = False,
         use_relay_discovery: bool = False,
+        persistent_conn_max_msg_size: int = DEFAULT_MAX_MSG_SIZE,
     ) -> "P2P":
         """
         Start a new p2pd process and connect to it.
@@ -103,7 +107,9 @@ class P2P:
         :param announce_maddrs: Visible multiaddrs that the peer will announce
                                 for external connections from other p2p instances
         :param conn_manager: Enables the Connection Manager
-        :param dht_mode: DHT mode (dht_client/dht_server/dht)
+        :param dht_mode: libp2p DHT mode (auto/client/server).
+                         Defaults to "server" to make collaborations work in local networks.
+                         Details: https://pkg.go.dev/github.com/libp2p/go-libp2p-kad-dht#ModeOpt
         :param force_reachability: Force reachability mode (public/private)
         :param host_maddrs: Multiaddrs to listen for external connections from other p2p instances
         :param identity_path: Path to a pre-generated private key file. If defined, makes the peer ID deterministic.
@@ -163,11 +169,17 @@ class P2P:
             relayHop=use_relay_hop,
             relayHopLimit=relay_hop_limit,
             tls=tls,
+            persistentConnMaxMsgSize=persistent_conn_max_msg_size,
             **process_kwargs,
         )
 
+        env = os.environ.copy()
+        env.setdefault("GOLOG_LOG_LEVEL", python_level_to_golog(loglevel))
+        env["GOLOG_LOG_FMT"] = "json"
+
+        logger.debug(f"Launching {proc_args}")
         self._child = await asyncio.subprocess.create_subprocess_exec(
-            *proc_args, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.STDOUT
+            *proc_args, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.STDOUT, env=env
         )
         self._alive = True
 
@@ -179,7 +191,12 @@ class P2P:
             await self.shutdown()
             raise P2PDaemonError(f"Daemon failed to start in {startup_timeout:.1f} seconds")
 
-        self._client = await p2pclient.Client.create(self._daemon_listen_maddr, self._client_listen_maddr)
+        self._client = await p2pclient.Client.create(
+            control_maddr=self._daemon_listen_maddr,
+            listen_maddr=self._client_listen_maddr,
+            persistent_conn_max_msg_size=persistent_conn_max_msg_size,
+        )
+
         await self._ping_daemon()
         return self
 
@@ -480,7 +497,7 @@ class P2P:
         input: Union[TInputProtobuf, TInputStream],
         output_protobuf_type: Type[Message],
     ) -> TOutputStream:
-        requests = input if isinstance(input, AsyncIterableABC) else aiter(input)
+        requests = input if isinstance(input, AsyncIterableABC) else as_aiter(input)
         return self._iterate_protobuf_stream_handler(peer_id, name, requests, output_protobuf_type)
 
     def _start_listening(self) -> None:
@@ -558,8 +575,44 @@ class P2P:
                 break
             last_line = line.rstrip().decode(errors="ignore")
 
+            self._log_p2pd_message(last_line)
             if last_line.startswith("Peer ID:"):
                 ready.set_result(None)
 
         if not ready.done():
             ready.set_exception(P2PDaemonError(f"Daemon failed to start: {last_line}"))
+
+    @staticmethod
+    def _log_p2pd_message(line: str) -> None:
+        if '"logger"' not in line:  # User-friendly info from p2pd stdout
+            logger.debug(line, extra={"caller": "p2pd"})
+            return
+
+        try:
+            record = json.loads(line)
+            caller = record["caller"]
+
+            level = golog_level_to_python(record["level"])
+            if level <= logging.WARNING:
+                # Many Go loggers are excessively verbose (e.g. show warnings for unreachable peers),
+                # so we downgrade INFO and WARNING messages to DEBUG.
+                # The Go verbosity can still be controlled via the GOLOG_LOG_LEVEL env variable.
+                # Details: https://github.com/ipfs/go-log#golog_log_level
+                level = logging.DEBUG
+
+            message = record["msg"]
+            if "error" in record:
+                message += f": {record['error']}"
+
+            logger.log(
+                level,
+                message,
+                extra={
+                    "origin_created": datetime.strptime(record["ts"], "%Y-%m-%dT%H:%M:%S.%f%z").timestamp(),
+                    "caller": caller,
+                },
+            )
+        except Exception:
+            # Parsing errors are unlikely, but we don't want to lose these messages anyway
+            logger.warning(line, extra={"caller": "p2pd"})
+            logger.exception("Failed to parse go-log message:")

+ 26 - 6
hivemind/p2p/p2p_daemon_bindings/control.py

@@ -26,6 +26,8 @@ SUPPORT_CONN_PROTOCOLS = (
 SUPPORTED_PROTOS = (protocols.protocol_with_code(proto) for proto in SUPPORT_CONN_PROTOCOLS)
 logger = get_logger(__name__)
 
+DEFAULT_MAX_MSG_SIZE = 4 * 1024 ** 2
+
 
 def parse_conn_protocol(maddr: Multiaddr) -> int:
     proto_codes = set(proto.code for proto in maddr.protocols())
@@ -84,10 +86,13 @@ class ControlClient:
         daemon_connector: DaemonConnector,
         listen_maddr: Multiaddr = Multiaddr(DEFAULT_LISTEN_MADDR),
         *,
-        _initialized_with_create=False,
+        _initialized_with_create: bool = False,
+        persistent_conn_max_msg_size: int = DEFAULT_MAX_MSG_SIZE,
     ) -> None:
         assert _initialized_with_create, "Please use ControlClient.create coroutine to spawn new control instances"
 
+        self.persistent_conn_max_msg_size = persistent_conn_max_msg_size
+
         self.listen_maddr = listen_maddr
         self.daemon_connector = daemon_connector
         self.handlers: Dict[str, StreamHandler] = {}
@@ -107,8 +112,14 @@ class ControlClient:
         daemon_connector: DaemonConnector,
         listen_maddr: Multiaddr = Multiaddr(DEFAULT_LISTEN_MADDR),
         use_persistent_conn: bool = True,
+        persistent_conn_max_msg_size: int = DEFAULT_MAX_MSG_SIZE,
     ) -> "ControlClient":
-        control = cls(daemon_connector, listen_maddr, _initialized_with_create=True)
+        control = cls(
+            daemon_connector,
+            listen_maddr,
+            _initialized_with_create=True,
+            persistent_conn_max_msg_size=persistent_conn_max_msg_size,
+        )
 
         if use_persistent_conn:
             await control._ensure_persistent_conn()
@@ -207,12 +218,18 @@ class ControlClient:
         except Exception as e:
             response = p2pd_pb.CallUnaryResponse(error=repr(e).encode())
 
-        await self._pending_messages.put(
-            p2pd_pb.PersistentConnectionRequest(
+        payload = p2pd_pb.PersistentConnectionRequest(callId=call_id.bytes, unaryResponse=response)
+        if payload.ByteSize() <= self.persistent_conn_max_msg_size:
+            await self._pending_messages.put(payload)
+        else:
+            error_msg = p2pd_pb.PersistentConnectionRequest(
                 callId=call_id.bytes,
-                unaryResponse=response,
+                unaryResponse=p2pd_pb.CallUnaryResponse(
+                    error=b"response size exceeds message size limit",
+                ),
             )
-        )
+            await self._pending_messages.put(error_msg)
+
         self._handler_tasks.pop(call_id)
 
     async def _cancel_unary_call(self, call_id: UUID):
@@ -255,6 +272,9 @@ class ControlClient:
             callUnary=call_unary_req,
         )
 
+        if req.ByteSize() > self.persistent_conn_max_msg_size:
+            raise P2PDaemonError(f"Message size exceeds set limit {self.persistent_conn_max_msg_size}")
+
         try:
             self._pending_calls[call_id] = asyncio.Future()
             await self._pending_messages.put(req)

+ 19 - 3
hivemind/p2p/p2p_daemon_bindings/p2pclient.py

@@ -10,7 +10,13 @@ from typing import AsyncIterator, Iterable, Sequence, Tuple
 
 from multiaddr import Multiaddr
 
-from hivemind.p2p.p2p_daemon_bindings.control import ControlClient, DaemonConnector, StreamHandler, TUnaryHandler
+from hivemind.p2p.p2p_daemon_bindings.control import (
+    DEFAULT_MAX_MSG_SIZE,
+    ControlClient,
+    DaemonConnector,
+    StreamHandler,
+    TUnaryHandler,
+)
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
 
 
@@ -22,11 +28,21 @@ class Client:
         self.control = None
 
     @classmethod
-    async def create(cls, control_maddr: Multiaddr = None, listen_maddr: Multiaddr = None) -> "Client":
+    async def create(
+        cls,
+        control_maddr: Multiaddr = None,
+        listen_maddr: Multiaddr = None,
+        *,
+        persistent_conn_max_msg_size: int = DEFAULT_MAX_MSG_SIZE,
+    ) -> "Client":
         client = cls(_initialized_with_create=True)
 
         daemon_connector = DaemonConnector(control_maddr=control_maddr)
-        client.control = await ControlClient.create(daemon_connector=daemon_connector, listen_maddr=listen_maddr)
+        client.control = await ControlClient.create(
+            daemon_connector=daemon_connector,
+            listen_maddr=listen_maddr,
+            persistent_conn_max_msg_size=persistent_conn_max_msg_size,
+        )
 
         return client
 

+ 0 - 1
hivemind/utils/__init__.py

@@ -1,5 +1,4 @@
 from hivemind.utils.asyncio import *
-from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.grpc import *
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger

+ 12 - 1
hivemind/utils/asyncio.py

@@ -28,7 +28,7 @@ async def anext(aiter: AsyncIterator[T]) -> Union[T, StopAsyncIteration]:
     return await aiter.__anext__()
 
 
-async def aiter(*args: T) -> AsyncIterator[T]:
+async def as_aiter(*args: T) -> AsyncIterator[T]:
     """create an asynchronous iterator from a sequence of values"""
     for arg in args:
         yield arg
@@ -127,3 +127,14 @@ async def amap_in_executor(
     finally:
         if not task.done():
             task.cancel()
+
+
+async def aiter_with_timeout(iterable: AsyncIterable[T], timeout: float) -> AsyncIterator[T]:
+    """Iterate over an async iterable, raise TimeoutError if another portion of data does not arrive within timeout"""
+    # based on https://stackoverflow.com/a/50245879
+    iterator = iterable.__aiter__()
+    while True:
+        try:
+            yield await asyncio.wait_for(iterator.__anext__(), timeout=timeout)
+        except StopAsyncIteration:
+            break
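A minimal usage sketch for aiter_with_timeout (the slow_counter generator is made up for the example):

import asyncio
from hivemind.utils.asyncio import aiter_with_timeout

async def slow_counter():
    for i in range(3):
        await asyncio.sleep(0.1)
        yield i

async def main():
    # raises asyncio.TimeoutError if the next item takes longer than 1 second
    async for item in aiter_with_timeout(slow_counter(), timeout=1.0):
        print(item)

asyncio.run(main())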

+ 0 - 209
hivemind/utils/compression.py

@@ -1,209 +0,0 @@
-import os
-import warnings
-from concurrent.futures import ThreadPoolExecutor
-from typing import Optional, Sequence, Tuple
-
-import numpy as np
-import torch
-
-from hivemind.proto import runtime_pb2
-from hivemind.proto.runtime_pb2 import CompressionType
-
-FP32_EPS = 1e-06
-NUM_BYTES_FLOAT32 = 4
-NUM_BYTES_FLOAT16 = 2
-NUM_BITS_QUANTILE_COMPRESSION = 8
-NUM_COMPRESSION_QUANTILES = 2 ** NUM_BITS_QUANTILE_COMPRESSION
-UNIFORM_BUCKETS_STD_RANGE = 6
-FP16_MAX = 65_504
-UINT8_RANGE = 256
-
-COMPRESSION_EXECUTOR = ThreadPoolExecutor(max_workers=int(os.environ.get("QUANTILE_COMPRESSION_THREADS", 128)))
-
-warnings.filterwarnings("ignore", message="The given NumPy array is not writeable", category=UserWarning)
-
-
-def _quantile_encode_approx(tensor: torch.Tensor, n_bits: int) -> Tuple[torch.Tensor, torch.Tensor]:
-    n_bins = 2 ** n_bits
-    borders = torch.as_tensor(_quantile_qq_approximation(tensor.numpy(), n_bins + 1)[1:-1])
-    quant_weight = torch.clamp_(torch.bucketize(tensor, borders), 0, n_bins - 1)
-    lookup = average_buckets(tensor, quant_weight, n_bins)
-    return quant_weight, lookup
-
-
-def average_buckets(tensor: torch.Tensor, quant_weight: torch.Tensor, n_bins: int):
-    bin_sums = torch.zeros(n_bins).scatter_add_(0, quant_weight.flatten().long(), tensor.flatten())
-    bin_counts = torch.clamp_min_(torch.bincount(quant_weight.flatten(), minlength=n_bins), 1)
-    lookup = bin_sums / bin_counts
-    return lookup
-
-
-def _quantile_qq_approximation(array: np.array, n_quantiles: int, min_chunk_size: int = 10 ** 5) -> np.ndarray:
-    """Estimate uniform quantiles of data using quantile-of-quantiles. Runs in parallel."""
-    if not array.data.c_contiguous and array.data.f_contiguous:
-        array = array.T
-    array = np.ascontiguousarray(array.reshape(-1))
-    quantiles = np.linspace(0.0, 1.0, num=n_quantiles, dtype=array.dtype)
-    chunk_size = _get_chunk_size(len(array), min_chunk_size)
-    num_chunks = (len(array) - 1) // chunk_size + 1
-    partition_quantiles = np.empty((num_chunks, len(quantiles)), dtype=array.dtype)
-
-    jobs = []
-    for i in range(num_chunks):
-        chunk = slice(chunk_size * i, chunk_size * (i + 1))
-        jobs.append(COMPRESSION_EXECUTOR.submit(np.quantile, array[chunk], quantiles, out=partition_quantiles[i]))
-
-    for job in jobs:
-        job.result()
-    return np.quantile(partition_quantiles, quantiles)
-
-
-def _get_chunk_size(num_elements: int, min_chunk_size: int) -> int:
-    """Adjust chunk_size to minimize imbalance between chunk sizes"""
-    if min_chunk_size >= num_elements:
-        return min_chunk_size
-    leftover_elements = num_elements % min_chunk_size
-    num_chunks = num_elements // min_chunk_size
-    return min_chunk_size + (leftover_elements - 1) // num_chunks + 1
-
-
-def _uint8_uniform_buckets_encode(tensor: torch.Tensor, range_in_sigmas: float):
-    offset = UINT8_RANGE // 2
-    shift = tensor.mean()
-    scale = range_in_sigmas * tensor.std() / UINT8_RANGE
-
-    quant_weight = torch.quantize_per_tensor(tensor - shift, scale, offset, torch.quint8).int_repr()
-    lookup = average_buckets(tensor, quant_weight, UINT8_RANGE)
-    return quant_weight, lookup
-
-
-def serialize_torch_tensor(
-    tensor: torch.Tensor, compression_type=CompressionType.NONE, allow_inplace=False
-) -> runtime_pb2.Tensor:
-    assert tensor.device == torch.device("cpu")
-    if compression_type == CompressionType.MEANSTD_16BIT:
-        assert tensor.dtype == torch.float32
-
-        tensor = tensor if allow_inplace else tensor.clone()
-        means = torch.mean(tensor, dim=-1, keepdim=True)
-        tensor.sub_(means)
-
-        stds = torch.square(tensor).sum(dim=-1, keepdim=True).div_(tensor.shape[-1]).sqrt_()
-        stds.clamp_min_(FP32_EPS)
-        tensor.div_(stds)
-        tensor = tensor.clamp_(-FP16_MAX, FP16_MAX).to(torch.float16)
-
-        data = b"".join((tensor.numpy().tobytes(), means.numpy().tobytes(), stds.numpy().tobytes()))
-
-        proto = runtime_pb2.Tensor(
-            compression=compression_type,
-            buffer=data,
-            size=tensor.shape,
-            dtype="compressed_float32",
-            requires_grad=tensor.requires_grad,
-        )
-    elif compression_type == CompressionType.FLOAT16:
-        assert tensor.dtype == torch.float32
-
-        tensor = tensor if allow_inplace else tensor.clone()
-        tensor = tensor.clamp_(-FP16_MAX, FP16_MAX).to(torch.float16)
-
-        data = tensor.numpy().tobytes()
-
-        proto = runtime_pb2.Tensor(
-            compression=compression_type,
-            buffer=data,
-            size=tensor.shape,
-            dtype="clamped_float32",
-            requires_grad=tensor.requires_grad,
-        )
-    elif compression_type == CompressionType.NONE:
-        array = tensor.numpy()
-        proto = runtime_pb2.Tensor(
-            compression=compression_type,
-            buffer=array.tobytes(),
-            size=array.shape,
-            dtype=array.dtype.name,
-            requires_grad=tensor.requires_grad,
-        )
-    elif compression_type in (CompressionType.QUANTILE_8BIT, CompressionType.UNIFORM_8BIT):
-        assert tensor.dtype == torch.float32
-
-        if compression_type == CompressionType.QUANTILE_8BIT:
-            quantized, lookup = _quantile_encode_approx(tensor.detach(), NUM_BITS_QUANTILE_COMPRESSION)
-        elif compression_type == CompressionType.UNIFORM_8BIT:
-            quantized, lookup = _uint8_uniform_buckets_encode(tensor.detach(), UNIFORM_BUCKETS_STD_RANGE)
-        data = b"".join((lookup.numpy().tobytes(), quantized.numpy().astype(np.uint8).tobytes()))
-
-        proto = runtime_pb2.Tensor(
-            compression=compression_type,
-            buffer=data,
-            size=tensor.shape,
-            dtype="compressed_float32",
-            requires_grad=tensor.requires_grad,
-        )
-    else:
-        raise ValueError(f"Unknown compression type: {compression_type}")
-
-    return proto
-
-
-def construct_torch_tensor(array: np.ndarray, size: Sequence, dtype: Optional[torch.dtype] = None):
-    """Helper conversion function that handles edge case with scalar deserialization"""
-    if size:
-        return torch.as_tensor(array, dtype=dtype).view(*size)
-    else:
-        return torch.as_tensor(array, dtype=dtype)
-
-
-def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
-    if serialized_tensor.compression == CompressionType.NONE:
-        array = np.frombuffer(serialized_tensor.buffer, dtype=np.dtype(serialized_tensor.dtype))
-        tensor = construct_torch_tensor(array, serialized_tensor.size)
-
-    elif serialized_tensor.compression == CompressionType.MEANSTD_16BIT:
-        stats_size = list(serialized_tensor.size)
-        stats_size[-1] = 1
-        stats_count = np.prod(stats_size)
-
-        means = serialized_tensor.buffer[-2 * NUM_BYTES_FLOAT32 * stats_count : -NUM_BYTES_FLOAT32 * stats_count]
-        stds = serialized_tensor.buffer[-NUM_BYTES_FLOAT32 * stats_count :]
-        means = construct_torch_tensor(np.frombuffer(means, dtype=np.float32), stats_size)
-        stds = construct_torch_tensor(np.frombuffer(stds, dtype=np.float32), stats_size)
-
-        array = np.frombuffer(serialized_tensor.buffer[: -8 * stats_count], dtype=np.float16)
-        tensor = construct_torch_tensor(array, serialized_tensor.size, torch.float32).mul_(stds).add_(means)
-
-    elif serialized_tensor.compression == CompressionType.FLOAT16:
-        array = np.frombuffer(serialized_tensor.buffer, dtype=np.float16)
-        tensor = construct_torch_tensor(array, serialized_tensor.size, torch.float32)
-
-    elif serialized_tensor.compression in (CompressionType.QUANTILE_8BIT, CompressionType.UNIFORM_8BIT):
-        if serialized_tensor.compression == CompressionType.QUANTILE_8BIT:
-            lookup_size = NUM_COMPRESSION_QUANTILES * NUM_BYTES_FLOAT32
-        else:
-            lookup_size = UINT8_RANGE * NUM_BYTES_FLOAT32
-        lookup = serialized_tensor.buffer[:lookup_size]
-        quantized = serialized_tensor.buffer[lookup_size:]
-        lookup = torch.as_tensor(np.frombuffer(lookup, dtype=np.float32))
-        quantized = np.frombuffer(quantized, dtype=np.uint8)
-        quantized = construct_torch_tensor(quantized, serialized_tensor.size, dtype=torch.int64)
-        tensor = lookup[quantized]
-
-    else:
-        raise ValueError(f"Unknown compression type: {serialized_tensor.compression}")
-
-    tensor.requires_grad_(serialized_tensor.requires_grad)
-    return tensor
-
-
-def get_nbytes_per_value(dtype: torch.dtype, compression: CompressionType) -> int:
-    """returns the number of bytes per value for a given tensor (excluding metadata)"""
-    if compression in (CompressionType.QUANTILE_8BIT, CompressionType.UNIFORM_8BIT):
-        return 1
-    elif compression in (CompressionType.FLOAT16, CompressionType.MEANSTD_16BIT):
-        return 2
-    elif compression == CompressionType.NONE:
-        return torch.finfo(dtype).bits // 8
-    else:
-        raise NotImplementedError(f"Unknown compression type: {CompressionType.Name(compression)}")

+ 171 - 17
hivemind/utils/logging.py

@@ -1,22 +1,176 @@
 import logging
 import os
+import sys
+import threading
+from enum import Enum
+from typing import Optional, Union
 
+logging.addLevelName(logging.WARNING, "WARN")
 
-def get_logger(module_name: str) -> logging.Logger:
-    # trim package name
-    name_without_prefix = ".".join(module_name.split(".")[1:])
-    loglevel = os.getenv("LOGLEVEL", "INFO")
-
-    logging.addLevelName(logging.WARNING, "WARN")
-    formatter = logging.Formatter(
-        fmt="[{asctime}.{msecs:03.0f}][{levelname}][{name}.{funcName}:{lineno}] {message}",
-        style="{",
-        datefmt="%Y/%m/%d %H:%M:%S",
-    )
-    handler = logging.StreamHandler()
-    handler.setFormatter(formatter)
-    logger = logging.getLogger(name_without_prefix)
-    logger.setLevel(loglevel)
-    logger.addHandler(handler)
+loglevel = os.getenv("LOGLEVEL", "INFO")
+
+_env_colors = os.getenv("HIVEMIND_COLORS")
+if _env_colors is not None:
+    use_colors = _env_colors.lower() == "true"
+else:
+    use_colors = sys.stderr.isatty()
+
+
+class HandlerMode(Enum):
+    NOWHERE = 0
+    IN_HIVEMIND = 1
+    IN_ROOT_LOGGER = 2
+
+
+_init_lock = threading.RLock()
+_current_mode = HandlerMode.IN_HIVEMIND
+_default_handler = None
+
+
+class TextStyle:
+    """
+    ANSI escape codes. Details: https://en.wikipedia.org/wiki/ANSI_escape_code#Colors
+    """
+
+    RESET = "\033[0m"
+    BOLD = "\033[1m"
+    RED = "\033[31m"
+    BLUE = "\033[34m"
+    PURPLE = "\033[35m"
+    ORANGE = "\033[38;5;208m"  # From 8-bit palette
+
+    if not use_colors:
+        # Set the constants above to empty strings
+        _codes = locals()
+        _codes.update({_name: "" for _name in list(_codes) if _name.isupper()})
+
+
+class CustomFormatter(logging.Formatter):
+    """
+    A formatter that allows the log timestamp and caller info to be overridden via
+    ``logger.log(level, message, extra={"origin_created": ..., "caller": ...})``.
+    """
+
+    # Details: https://en.wikipedia.org/wiki/ANSI_escape_code#Colors
+    _LEVEL_TO_COLOR = {
+        logging.DEBUG: TextStyle.PURPLE,
+        logging.INFO: TextStyle.BLUE,
+        logging.WARNING: TextStyle.ORANGE,
+        logging.ERROR: TextStyle.RED,
+        logging.CRITICAL: TextStyle.RED,
+    }
+
+    def format(self, record: logging.LogRecord) -> str:
+        if hasattr(record, "origin_created"):
+            record.created = record.origin_created
+            record.msecs = (record.created - int(record.created)) * 1000
+
+        if not hasattr(record, "caller"):
+            record.caller = f"{record.name}.{record.funcName}:{record.lineno}"
+
+        # Aliases for the format argument
+        record.levelcolor = self._LEVEL_TO_COLOR[record.levelno]
+        record.bold = TextStyle.BOLD
+        record.reset = TextStyle.RESET
+
+        return super().format(record)
+
+
+def _initialize_if_necessary():
+    global _current_mode, _default_handler
+
+    with _init_lock:
+        if _default_handler is not None:
+            return
+
+        formatter = CustomFormatter(
+            fmt="{asctime}.{msecs:03.0f} [{bold}{levelcolor}{levelname}{reset}] [{bold}{caller}{reset}] {message}",
+            style="{",
+            datefmt="%b %d %H:%M:%S",
+        )
+        _default_handler = logging.StreamHandler()
+        _default_handler.setFormatter(formatter)
+
+        _enable_default_handler("hivemind")
+
+
+def get_logger(name: Optional[str] = None) -> logging.Logger:
+    """
+    Same as ``logging.getLogger()`` but ensures that the default log handler is initialized.
+    """
+
+    _initialize_if_necessary()
+    return logging.getLogger(name)
+
+
+def _enable_default_handler(name: str) -> None:
+    logger = get_logger(name)
+    logger.addHandler(_default_handler)
     logger.propagate = False
-    return logger
+    logger.setLevel(loglevel)
+
+
+def _disable_default_handler(name: str) -> None:
+    logger = get_logger(name)
+    logger.removeHandler(_default_handler)
+    logger.propagate = True
+    logger.setLevel(logging.NOTSET)
+
+
+def use_hivemind_log_handler(where: Union[HandlerMode, str]) -> None:
+    """
+    Choose loggers where the default hivemind log handler is applied. Options for the ``where`` argument are:
+
+    * "in_hivemind" (default): Use the hivemind log handler in the loggers of the ``hivemind`` package.
+                               Don't propagate their messages to the root logger.
+    * "nowhere": Don't use the hivemind log handler anywhere.
+                 Propagate the ``hivemind`` messages to the root logger.
+    * "in_root_logger": Use the hivemind log handler in the root logger
+                        (that is, in all application loggers until they disable propagation to the root logger).
+                        Propagate the ``hivemind`` messages to the root logger.
+
+    The options may be defined as strings (case-insensitive) or values from the HandlerMode enum.
+    """
+
+    global _current_mode
+
+    if isinstance(where, str):
+        # We allow `where` to be a string, so a developer does not have to import the enum for one usage
+        where = HandlerMode[where.upper()]
+
+    if where == _current_mode:
+        return
+
+    if _current_mode == HandlerMode.IN_HIVEMIND:
+        _disable_default_handler("hivemind")
+    elif _current_mode == HandlerMode.IN_ROOT_LOGGER:
+        _disable_default_handler(None)
+
+    _current_mode = where
+
+    if _current_mode == HandlerMode.IN_HIVEMIND:
+        _enable_default_handler("hivemind")
+    elif _current_mode == HandlerMode.IN_ROOT_LOGGER:
+        _enable_default_handler(None)
+
+
+def golog_level_to_python(level: str) -> int:
+    level = level.upper()
+    if level in ["DPANIC", "PANIC", "FATAL"]:
+        return logging.CRITICAL
+
+    level = logging.getLevelName(level)
+    if not isinstance(level, int):
+        raise ValueError(f"Unknown go-log level: {level}")
+    return level
+
+
+def python_level_to_golog(level: str) -> str:
+    if not isinstance(level, str):
+        raise ValueError("`level` is expected to be a Python log level in the string form")
+
+    if level == "CRITICAL":
+        return "FATAL"
+    if level == "WARNING":
+        return "WARN"
+    return level
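A short sanity check for the conversion helpers above: Python and go-log only disagree on the names of WARNING and the most severe levels.

import logging
from hivemind.utils.logging import golog_level_to_python, python_level_to_golog

assert python_level_to_golog("WARNING") == "WARN"
assert python_level_to_golog("CRITICAL") == "FATAL"
assert golog_level_to_python("warn") == logging.WARNING
assert golog_level_to_python("dpanic") == logging.CRITICAL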

+ 4 - 2
hivemind/utils/mpfuture.py

@@ -180,8 +180,10 @@ class MPFuture(base.Future, Generic[ResultType]):
                     future = future_ref()
 
                 if future is None:
-                    logger.debug(f"Ignoring update to future with uid={uid}: the future is already done or destroyed")
-                elif update_type == UpdateType.RESULT:
+                    # The MPFuture instance is already destroyed in this process
+                    # (the caller is not interested in the result)
+                    continue
+                if update_type == UpdateType.RESULT:
                     future.set_result(payload)
                 elif update_type == UpdateType.EXCEPTION:
                     future.set_exception(payload)

+ 11 - 4
hivemind/utils/tensor_descr.py

@@ -1,6 +1,10 @@
+from __future__ import annotations
+
 import warnings
 from dataclasses import asdict, dataclass
+from typing import Tuple
 
+import numpy as np
 import torch
 
 from hivemind.proto.runtime_pb2 import CompressionType
@@ -29,11 +33,14 @@ class TensorDescriptor(DescriptorBase):
     compression: CompressionType = CompressionType.NONE
 
     @property
-    def shape(self):
+    def shape(self) -> Tuple[int, ...]:
         return self.size
 
+    def numel(self) -> int:
+        return int(np.prod(self.size))
+
     @classmethod
-    def from_tensor(cls, tensor: torch.Tensor):
+    def from_tensor(cls, tensor: torch.Tensor) -> TensorDescriptor:
         return cls(
             tensor.shape, tensor.dtype, tensor.layout, tensor.device, tensor.requires_grad, _safe_check_pinned(tensor)
         )
@@ -55,7 +62,7 @@ class BatchTensorDescriptor(TensorDescriptor):
         super().__init__((None, *instance_size), **kwargs)
 
     @classmethod
-    def from_tensor(cls, tensor: torch.Tensor, compression=CompressionType.NONE):
+    def from_tensor(cls, tensor: torch.Tensor, compression=CompressionType.NONE) -> BatchTensorDescriptor:
         return cls(
             *tensor.shape[1:],
             dtype=tensor.dtype,
@@ -66,7 +73,7 @@ class BatchTensorDescriptor(TensorDescriptor):
             compression=compression if tensor.is_floating_point() else CompressionType.NONE
         )
 
-    def make_empty(self, *batch_size, **kwargs):
+    def make_empty(self, *batch_size: int, **kwargs) -> torch.Tensor:
         assert self.shape[0] is None, "Make sure 0-th dimension is not specified (set to None)"
         return super().make_empty(size=(*batch_size, *self.shape[1:]), **kwargs)
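A brief sketch of the descriptor helpers touched above (shapes are arbitrary, and make_empty() is assumed to allocate an uninitialized tensor as in the existing API): from_tensor() records metadata, numel() counts elements, and make_empty() rebuilds a tensor with a new batch size.

import torch
from hivemind.utils.tensor_descr import BatchTensorDescriptor, TensorDescriptor

x = torch.randn(16, 3, 224)
assert TensorDescriptor.from_tensor(x).numel() == 16 * 3 * 224

batch_descr = BatchTensorDescriptor.from_tensor(x)  # 0-th (batch) dimension becomes None
empty = batch_descr.make_empty(4)                   # same layout, batch size of 4
assert empty.shape == (4, 3, 224)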
 

+ 5 - 5
setup.py

@@ -14,9 +14,10 @@ from setuptools import find_packages, setup
 from setuptools.command.build_py import build_py
 from setuptools.command.develop import develop
 
-P2PD_VERSION = "v0.3.4"
-P2PD_CHECKSUM = "194dca06116fdd36bc4b681d18f3b9cb"
+P2PD_VERSION = "v0.3.6"
+P2PD_CHECKSUM = "627d0c3b475a29331fdfd1667e828f6d"
 LIBP2P_TAR_URL = f"https://github.com/learning-at-home/go-libp2p-daemon/archive/refs/tags/{P2PD_VERSION}.tar.gz"
+P2PD_BINARY_URL = f"https://github.com/learning-at-home/go-libp2p-daemon/releases/download/{P2PD_VERSION}/p2pd"
 
 here = os.path.abspath(os.path.dirname(__file__))
 
@@ -85,11 +86,10 @@ def download_p2p_daemon():
     binary_path = os.path.join(install_path, "p2pd")
     if not os.path.exists(binary_path) or md5(binary_path) != P2PD_CHECKSUM:
         print("Downloading Peer to Peer Daemon")
-        url = f"https://github.com/learning-at-home/go-libp2p-daemon/releases/download/{P2PD_VERSION}/p2pd"
-        urllib.request.urlretrieve(url, binary_path)
+        urllib.request.urlretrieve(P2PD_BINARY_URL, binary_path)
         os.chmod(binary_path, 0o777)
         if md5(binary_path) != P2PD_CHECKSUM:
-            raise RuntimeError(f"Downloaded p2pd binary from {url} does not match with md5 checksum")
+            raise RuntimeError(f"Downloaded p2pd binary from {P2PD_BINARY_URL} does not match with md5 checksum")
 
 
 class BuildPy(build_py):

+ 3 - 3
tests/test_allreduce.py

@@ -6,12 +6,12 @@ from typing import Sequence
 import pytest
 import torch
 
-from hivemind import aenumerate
+from hivemind import Quantile8BitQuantization, aenumerate
 from hivemind.averaging.allreduce import AllReduceRunner, AveragingMode
 from hivemind.averaging.partition import TensorPartContainer, TensorPartReducer
+from hivemind.compression import deserialize_torch_tensor
 from hivemind.p2p import P2P, StubBase
 from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.utils import deserialize_torch_tensor
 
 
 @pytest.mark.forked
@@ -83,7 +83,7 @@ async def test_partitioning_asynchronous():
     tensors = [torch.randn(2048, 2048), torch.randn(1024, 4096), torch.randn(4096, 1024), torch.randn(30_000, 1024)]
     peer_fractions = [0.4, 0.3, 0.2, 0.1]
 
-    partition = TensorPartContainer(tensors, peer_fractions, compression_type=CompressionType.QUANTILE_8BIT)
+    partition = TensorPartContainer(tensors, peer_fractions, compression=Quantile8BitQuantization())
     read_started, read_finished = asyncio.Event(), asyncio.Event()
 
     async def write_tensors():

+ 0 - 57
tests/test_averaging.py

@@ -1,5 +1,4 @@
 import random
-import time
 
 import numpy as np
 import pytest
@@ -12,7 +11,6 @@ from hivemind.averaging.key_manager import GroupKeyManager
 from hivemind.averaging.load_balancing import load_balance_peers
 from hivemind.averaging.partition import AllreduceException
 from hivemind.p2p import PeerID
-from hivemind.proto.runtime_pb2 import CompressionType
 
 from test_utils.dht_swarms import launch_dht_instances
 
@@ -169,61 +167,6 @@ def test_allreduce_weighted(n_client_mode_peers: int = 2):
         process.shutdown()
 
 
-@pytest.mark.forked
-def test_allreduce_compression():
-    """this test ensures that compression works correctly when multiple tensors have different compression types"""
-
-    tensors1 = [torch.linspace(0, 500, 1000) ** 0.5, torch.randn(1000)]
-    tensors2 = [torch.linspace(300, 800, 1000) ** 0.5, torch.randn(1000)]
-    results = {}
-
-    FLOAT16, UINT8 = CompressionType.FLOAT16, CompressionType.UNIFORM_8BIT
-
-    for compression_type_pair in [(FLOAT16, FLOAT16), (FLOAT16, UINT8), (UINT8, FLOAT16), (UINT8, UINT8)]:
-        dht_instances = launch_dht_instances(2)
-        averager1 = hivemind.averaging.DecentralizedAverager(
-            [x.clone() for x in tensors1],
-            dht=dht_instances[0],
-            compression_type=compression_type_pair,
-            client_mode=True,
-            target_group_size=2,
-            prefix="mygroup",
-            start=True,
-        )
-        averager2 = hivemind.averaging.DecentralizedAverager(
-            [x.clone() for x in tensors2],
-            dht=dht_instances[1],
-            compression_type=compression_type_pair,
-            target_group_size=2,
-            prefix="mygroup",
-            start=True,
-        )
-
-        for future in averager1.step(wait=False), averager2.step(wait=False):
-            future.result()
-
-        with averager1.get_tensors() as averaged_tensors:
-            results[compression_type_pair] = averaged_tensors
-
-        for instance in [averager1, averager2] + dht_instances:
-            instance.shutdown()
-
-    assert torch.allclose(results[UINT8, FLOAT16][0], results[UINT8, UINT8][0])
-    assert torch.allclose(results[UINT8, FLOAT16][1], results[FLOAT16, FLOAT16][1])
-    assert torch.allclose(results[UINT8, UINT8][1], results[FLOAT16, UINT8][1])
-    assert torch.allclose(results[FLOAT16, UINT8][0], results[FLOAT16, FLOAT16][0])
-
-    assert not torch.allclose(results[UINT8, FLOAT16][1], results[UINT8, UINT8][1])
-    assert not torch.allclose(results[UINT8, FLOAT16][0], results[FLOAT16, FLOAT16][0])
-    assert not torch.allclose(results[UINT8, UINT8][0], results[FLOAT16, UINT8][0])
-    assert not torch.allclose(results[FLOAT16, UINT8][1], results[FLOAT16, FLOAT16][1])
-
-    reference = [(tensors1[i] + tensors2[i]) / 2 for i in range(len(tensors1))]
-    for i in range(2):
-        assert 0 < torch.mean(torch.square(results[FLOAT16, FLOAT16][i] - reference[i])).item() <= 1e-5
-        assert 1e-5 < torch.mean(torch.square(results[UINT8, UINT8][i] - reference[i])).item() <= 1e-2
-
-
 def compute_mean_std(averagers, unbiased=True):
     results = []
     for averager in averagers:

+ 213 - 0
tests/test_compression.py

@@ -0,0 +1,213 @@
+import multiprocessing as mp
+from ctypes import c_int32
+
+import pytest
+import torch
+import torch.nn as nn
+
+import hivemind
+from hivemind.compression import (
+    CompressionBase,
+    CompressionInfo,
+    Float16Compression,
+    NoCompression,
+    PerTensorCompression,
+    RoleAdaptiveCompression,
+    SizeAdaptiveCompression,
+    Uniform8BitQuantization,
+    deserialize_torch_tensor,
+    serialize_torch_tensor,
+)
+from hivemind.compression.adaptive import AdaptiveCompressionBase
+from hivemind.proto.runtime_pb2 import CompressionType
+
+from test_utils.dht_swarms import launch_dht_instances
+
+
+@pytest.mark.forked
+def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
+    torch.manual_seed(0)
+    X = torch.randn(*size)
+    assert torch.allclose(deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.NONE)), X)
+    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.MEANSTD_16BIT)) - X
+    assert error.square().mean() < alpha
+    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.FLOAT16)) - X
+    assert error.square().mean() < alpha
+    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.QUANTILE_8BIT)) - X
+    assert error.square().mean() < beta
+    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.UNIFORM_8BIT)) - X
+    assert error.square().mean() < beta
+
+    zeros = torch.zeros(5, 5)
+    for compression_type in CompressionType.values():
+        assert deserialize_torch_tensor(serialize_torch_tensor(zeros, compression_type)).isfinite().all()
+
+
+@pytest.mark.forked
+def test_serialize_tensor():
+    def _check(tensor, compression, rtol=1e-5, atol=1e-8, chunk_size=30 * 1024):
+        serialized_tensor = serialize_torch_tensor(tensor, compression)
+        chunks = list(hivemind.split_for_streaming(serialized_tensor, chunk_size))
+        assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
+        restored = hivemind.combine_from_streaming(chunks)
+        assert torch.allclose(deserialize_torch_tensor(restored), tensor, rtol=rtol, atol=atol)
+
+    tensor = torch.randn(512, 12288)
+    for chunk_size in [1024, 64 * 1024, 64 * 1024 + 1, 10 ** 9]:
+        _check(tensor, CompressionType.NONE, chunk_size=chunk_size)
+
+    _check(tensor, CompressionType.FLOAT16, rtol=0.0, atol=1e-2)
+    _check(torch.randint(0, 100, (512, 1, 1)), CompressionType.NONE)
+    _check(torch.tensor(1.0), CompressionType.NONE)
+    _check(torch.tensor(1.0), CompressionType.FLOAT16)
+
+
+@pytest.mark.forked
+def test_allreduce_compression():
+    """this test ensures that compression works correctly when multiple tensors have different compression types"""
+
+    tensors1 = [torch.linspace(0, 500, 1000) ** 0.5, torch.randn(1000)]
+    tensors2 = [torch.linspace(300, 800, 1000) ** 0.5, torch.randn(1000)]
+    results = {}
+
+    FLOAT16, UINT8 = Float16Compression(), Uniform8BitQuantization()
+
+    for compression_type_pair in [(FLOAT16, FLOAT16), (FLOAT16, UINT8), (UINT8, FLOAT16), (UINT8, UINT8)]:
+        dht_instances = launch_dht_instances(2)
+        averager1 = hivemind.averaging.DecentralizedAverager(
+            [x.clone() for x in tensors1],
+            dht=dht_instances[0],
+            compression=PerTensorCompression(compression_type_pair),
+            client_mode=True,
+            target_group_size=2,
+            prefix="mygroup",
+            start=True,
+        )
+        averager2 = hivemind.averaging.DecentralizedAverager(
+            [x.clone() for x in tensors2],
+            dht=dht_instances[1],
+            compression=PerTensorCompression(compression_type_pair),
+            target_group_size=2,
+            prefix="mygroup",
+            start=True,
+        )
+
+        for future in averager1.step(wait=False), averager2.step(wait=False):
+            future.result()
+
+        with averager1.get_tensors() as averaged_tensors:
+            results[compression_type_pair] = averaged_tensors
+
+        for instance in [averager1, averager2] + dht_instances:
+            instance.shutdown()
+
+    assert torch.allclose(results[UINT8, FLOAT16][0], results[UINT8, UINT8][0])
+    assert torch.allclose(results[UINT8, FLOAT16][1], results[FLOAT16, FLOAT16][1])
+    assert torch.allclose(results[UINT8, UINT8][1], results[FLOAT16, UINT8][1])
+    assert torch.allclose(results[FLOAT16, UINT8][0], results[FLOAT16, FLOAT16][0])
+
+    assert not torch.allclose(results[UINT8, FLOAT16][1], results[UINT8, UINT8][1])
+    assert not torch.allclose(results[UINT8, FLOAT16][0], results[FLOAT16, FLOAT16][0])
+    assert not torch.allclose(results[UINT8, UINT8][0], results[FLOAT16, UINT8][0])
+    assert not torch.allclose(results[FLOAT16, UINT8][1], results[FLOAT16, FLOAT16][1])
+
+    reference = [(tensors1[i] + tensors2[i]) / 2 for i in range(len(tensors1))]
+    for i in range(2):
+        assert 0 < torch.mean(torch.square(results[FLOAT16, FLOAT16][i] - reference[i])).item() <= 1e-5
+        assert 1e-5 < torch.mean(torch.square(results[UINT8, UINT8][i] - reference[i])).item() <= 1e-2
+
+
+class TrackedCompression(AdaptiveCompressionBase):
+    def __init__(self, compression: CompressionBase):
+        self.compression = compression
+        self.mp_counter, self.mp_part_size = mp.Value(c_int32, 0), mp.Value(c_int32, 0)
+        super().__init__()
+
+    def choose_compression(self, info: CompressionInfo) -> CompressionBase:
+        return self.compression
+
+    def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False):
+        self.mp_counter.value += 1
+        if info.part_size is not None:
+            self.mp_part_size.value = max(self.mp_part_size.value, info.part_size)
+        return self.compression.compress(tensor, info=info, allow_inplace=allow_inplace)
+
+
+def make_params():
+    return [
+        nn.Parameter(x)
+        for x in (
+            torch.randn([]),
+            torch.randn(1),
+            torch.randn(100),
+            torch.randn(1_000),
+            torch.randn(5_000),
+            torch.randn(10_000),
+        )
+    ]
+
+
+@pytest.mark.forked
+def test_adaptive_compression():
+    UINT8 = TrackedCompression(Uniform8BitQuantization())
+    FLOAT16 = TrackedCompression(Float16Compression())
+    FLOAT32 = TrackedCompression(NoCompression())
+    STATE_FP16 = TrackedCompression(Float16Compression())
+    STATE_FP32 = TrackedCompression(NoCompression())
+
+    averaging_compression_adaptive = RoleAdaptiveCompression(
+        parameter=FLOAT16,
+        gradient=SizeAdaptiveCompression(threshold=1_000, less=FLOAT16, greater_equal=UINT8),
+        optimizer=FLOAT32,
+        default=FLOAT32,
+    )
+
+    state_compression_adaptive = SizeAdaptiveCompression(
+        threshold=500,
+        less=STATE_FP32,
+        greater_equal=STATE_FP16,
+    )
+
+    averager1 = hivemind.TrainingAverager(
+        opt=torch.optim.Adam(make_params()),
+        average_parameters=True,
+        average_gradients=True,
+        average_opt_statistics=("exp_avg",),
+        compression=averaging_compression_adaptive,
+        state_compression=state_compression_adaptive,
+        prefix="test_avgr",
+        target_group_size=2,
+        part_size_bytes=5_000,
+        start=True,
+        dht=hivemind.DHT(start=True),
+    )
+
+    averager2 = hivemind.TrainingAverager(
+        opt=torch.optim.Adam(make_params()),
+        average_parameters=True,
+        average_gradients=True,
+        average_opt_statistics=("exp_avg",),
+        compression=averaging_compression_adaptive,
+        state_compression=state_compression_adaptive,
+        prefix="test_avgr",
+        target_group_size=2,
+        part_size_bytes=5_000,
+        start=True,
+        dht=hivemind.DHT(initial_peers=averager1.dht.get_visible_maddrs(), start=True),
+    )
+
+    futures = [averager1.step(wait=False), averager2.step(wait=False)]
+
+    for future in futures:
+        future.result()
+
+    assert UINT8.mp_counter.value == 4  # the 3 large gradients; the 10_000-element one is split into two parts
+    assert UINT8.mp_part_size.value == 5_000  # uint8 elements take 1 byte, so a 5_000-byte part holds 5_000 values
+    assert FLOAT16.mp_counter.value == 13  # 6 parameters (10 parts) + the 3 small gradients
+    assert FLOAT16.mp_part_size.value == 2_500  # float16 elements take 2 bytes: 2_500 values per part
+    assert FLOAT32.mp_counter.value == 16  # 6 exp_avg statistics (16 parts)
+    assert FLOAT32.mp_part_size.value == 1_250  # float32 elements take 4 bytes: 1_250 values per part
+
+    averager1.load_state_from_peers()
+    assert STATE_FP16.mp_counter.value == STATE_FP32.mp_counter.value == 9
+    assert STATE_FP16.mp_part_size.value == STATE_FP32.mp_part_size.value == 0  # not partitioned

+ 34 - 64
tests/test_util_modules.py

@@ -9,6 +9,7 @@ import pytest
 import torch
 
 import hivemind
+from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.proto.dht_pb2_grpc import DHTStub
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
@@ -17,14 +18,14 @@ from hivemind.utils.asyncio import (
     achain,
     aenumerate,
     afirst,
-    aiter,
+    aiter_with_timeout,
     amap_in_executor,
     anext,
+    as_aiter,
     asingle,
     azip,
     cancel_and_wait,
 )
-from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.mpfuture import InvalidStateError
 
 
@@ -322,24 +323,6 @@ def test_many_futures():
     p.join()
 
 
-def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
-    torch.manual_seed(0)
-    X = torch.randn(*size)
-    assert torch.allclose(deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.NONE)), X)
-    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.MEANSTD_16BIT)) - X
-    assert error.square().mean() < alpha
-    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.FLOAT16)) - X
-    assert error.square().mean() < alpha
-    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.QUANTILE_8BIT)) - X
-    assert error.square().mean() < beta
-    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.UNIFORM_8BIT)) - X
-    assert error.square().mean() < beta
-
-    zeros = torch.zeros(5, 5)
-    for compression_type in CompressionType.values():
-        assert deserialize_torch_tensor(serialize_torch_tensor(zeros, compression_type)).isfinite().all()
-
-
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_channel_cache():
@@ -384,38 +367,6 @@ async def test_channel_cache():
             assert (ci is cj) == ((ci, cj) in duplicates), (i, j)
 
 
-def test_serialize_tensor():
-    tensor = torch.randn(512, 12288)
-
-    serialized_tensor = serialize_torch_tensor(tensor, CompressionType.NONE)
-    for chunk_size in [1024, 64 * 1024, 64 * 1024 + 1, 10 ** 9]:
-        chunks = list(hivemind.split_for_streaming(serialized_tensor, chunk_size))
-        assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
-        restored = hivemind.combine_from_streaming(chunks)
-        assert torch.allclose(deserialize_torch_tensor(restored), tensor)
-
-    chunk_size = 30 * 1024
-    serialized_tensor = serialize_torch_tensor(tensor, CompressionType.FLOAT16)
-    chunks = list(hivemind.split_for_streaming(serialized_tensor, chunk_size))
-    assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
-    restored = hivemind.combine_from_streaming(chunks)
-    assert torch.allclose(deserialize_torch_tensor(restored), tensor, rtol=0, atol=1e-2)
-
-    tensor = torch.randint(0, 100, (512, 1, 1))
-    serialized_tensor = serialize_torch_tensor(tensor, CompressionType.NONE)
-    chunks = list(hivemind.split_for_streaming(serialized_tensor, chunk_size))
-    assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
-    restored = hivemind.combine_from_streaming(chunks)
-    assert torch.allclose(deserialize_torch_tensor(restored), tensor)
-
-    scalar = torch.tensor(1.0)
-    serialized_scalar = serialize_torch_tensor(scalar, CompressionType.NONE)
-    assert torch.allclose(deserialize_torch_tensor(serialized_scalar), scalar)
-
-    serialized_scalar = serialize_torch_tensor(scalar, CompressionType.FLOAT16)
-    assert torch.allclose(deserialize_torch_tensor(serialized_scalar), scalar)
-
-
 def test_serialize_tuple():
     test_pairs = (
         ((1, 2, 3), [1, 2, 3]),
@@ -478,20 +429,23 @@ def test_generic_data_classes():
 
 @pytest.mark.asyncio
 async def test_asyncio_utils():
-    res = [i async for i, item in aenumerate(aiter("a", "b", "c"))]
+    res = [i async for i, item in aenumerate(as_aiter("a", "b", "c"))]
     assert res == list(range(len(res)))
 
     num_steps = 0
-    async for elem in amap_in_executor(lambda x: x ** 2, aiter(*range(100)), max_prefetch=5):
+    async for elem in amap_in_executor(lambda x: x ** 2, as_aiter(*range(100)), max_prefetch=5):
         assert elem == num_steps ** 2
         num_steps += 1
     assert num_steps == 100
 
-    ours = [elem async for elem in amap_in_executor(max, aiter(*range(7)), aiter(*range(-50, 50, 10)), max_prefetch=1)]
+    ours = [
+        elem
+        async for elem in amap_in_executor(max, as_aiter(*range(7)), as_aiter(*range(-50, 50, 10)), max_prefetch=1)
+    ]
     ref = list(map(max, range(7), range(-50, 50, 10)))
     assert ours == ref
 
-    ours = [row async for row in azip(aiter("a", "b", "c"), aiter(1, 2, 3))]
+    ours = [row async for row in azip(as_aiter("a", "b", "c"), as_aiter(1, 2, 3))]
     ref = list(zip(["a", "b", "c"], [1, 2, 3]))
     assert ours == ref
 
@@ -507,18 +461,34 @@ async def test_asyncio_utils():
     with pytest.raises(StopAsyncIteration):
         await anext(iterator)
 
-    assert [item async for item in achain(_aiterate(), aiter(*range(5)))] == ["foo", "bar", "baz"] + list(range(5))
+    assert [item async for item in achain(_aiterate(), as_aiter(*range(5)))] == ["foo", "bar", "baz"] + list(range(5))
 
-    assert await asingle(aiter(1)) == 1
+    assert await asingle(as_aiter(1)) == 1
     with pytest.raises(ValueError):
-        await asingle(aiter())
+        await asingle(as_aiter())
     with pytest.raises(ValueError):
-        await asingle(aiter(1, 2, 3))
+        await asingle(as_aiter(1, 2, 3))
+
+    assert await afirst(as_aiter(1)) == 1
+    assert await afirst(as_aiter()) is None
+    assert await afirst(as_aiter(), -1) == -1
+    assert await afirst(as_aiter(1, 2, 3)) == 1
+
+    async def iterate_with_delays(delays):
+        for i, delay in enumerate(delays):
+            await asyncio.sleep(delay)
+            yield i
+
+    async for _ in aiter_with_timeout(iterate_with_delays([0.1] * 5), timeout=0.2):
+        pass
+
+    sleepy_aiter = iterate_with_delays([0.1, 0.1, 0.3, 0.1, 0.1])
+    num_steps = 0
+    with pytest.raises(asyncio.TimeoutError):
+        async for _ in aiter_with_timeout(sleepy_aiter, timeout=0.2):
+            num_steps += 1
 
-    assert await afirst(aiter(1)) == 1
-    assert await afirst(aiter()) is None
-    assert await afirst(aiter(), -1) == -1
-    assert await afirst(aiter(1, 2, 3)) == 1
+    assert num_steps == 2
 
 
 @pytest.mark.asyncio